Discriminator
Bases: Module
A Discriminator class for a Generative Adversarial Network (GAN), designed to differentiate between real and fake images.
This module stacks convolutional layers with LeakyReLU activations and batch normalization, doubling the number of feature maps at each stage while halving the spatial dimensions of the input image. A final Sigmoid activation outputs a probability that the input image is real.
Parameters:
- image_size (int): The height and width of the square input images. Default is 64. This parameter also determines ndf, the base number of feature maps (channels), and so indirectly controls the width of the discriminator.
Attributes:
- ndf (int): The base number of feature maps (channels) in the discriminator, derived from image_size.
- layer_config (OrderedDict): An ordered dictionary that defines the architecture of the discriminator, including convolutional layers, batch normalization layers, and activation functions.
- main (nn.Sequential): The sequential container of layers as defined in layer_config.
The architecture starts with a convolutional layer with ndf (number of discriminator features) channels, followed by layers with 2*ndf, 4*ndf, and 8*ndf channels, before concluding with a final convolution to a single output channel. Batch normalization follows the second, third, and fourth convolutional layers; the first and final convolutions have none.
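To make the channel progression concrete, the pure-Python sketch below counts the learnable parameters of the layer stack for a given ndf. It assumes ndf=64 for the example call; this page does not state exactly how ndf is derived from image_size, and the helper name is hypothetical.

```python
def discriminator_param_count(ndf: int, in_channels: int = 3, k: int = 4) -> int:
    """Count learnable parameters of the layer stack listed under Layer Details.

    Convolutions have no bias, so each contributes in_ch * out_ch * k * k weights.
    Each BatchNorm2d layer contributes a weight and a bias per channel (2 * C).
    """
    # Channel progression: 3 -> ndf -> 2*ndf -> 4*ndf -> 8*ndf -> 1
    channels = [in_channels, ndf, 2 * ndf, 4 * ndf, 8 * ndf, 1]
    conv = sum(c_in * c_out * k * k for c_in, c_out in zip(channels, channels[1:]))
    # BatchNorm follows conv2, conv3 and conv4 (2*ndf, 4*ndf and 8*ndf channels)
    bn = sum(2 * c for c in channels[2:5])
    return conv + bn


print(discriminator_param_count(64))  # 2765568
```

Because the middle convolutions scale with the product of their input and output channels, doubling ndf roughly quadruples the parameter count.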
Methods:
- forward(input): Defines the forward pass of the discriminator.
Example
```python
discriminator = Discriminator(image_size=64)
# Assuming images is a batch of real or generated images
predictions = discriminator(images)
```
Note:
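The input images are expected to be 3-channel RGB images of size [image_size, image_size]. Because ndf is derived from image_size, changing image_size changes the width (channel count) of every layer, but the depth stays fixed at five convolutions. With four stride-2 convolutions followed by a 4x4 convolution without padding, only a 64x64 input reduces to a single 1x1 output per image.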
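The spatial sizes can be traced with the standard convolution output-size formula. This is a small pure-Python sketch; the helper names are hypothetical and the kernel, stride, and padding values are taken from the Layer Details list.

```python
from typing import List


def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    # Output-size formula for a 2D convolution with dilation 1
    return (size + 2 * padding - kernel) // stride + 1


def trace_spatial(image_size: int) -> List[int]:
    sizes = [image_size]
    # Conv1-Conv4: 4x4 kernel, stride 2, padding 1 (halves even sizes)
    for _ in range(4):
        sizes.append(conv_out(sizes[-1], kernel=4, stride=2, padding=1))
    # Conv5: 4x4 kernel, stride 1, no padding
    sizes.append(conv_out(sizes[-1], kernel=4, stride=1, padding=0))
    return sizes


print(trace_spatial(64))   # [64, 32, 16, 8, 4, 1]
print(trace_spatial(128))  # [128, 64, 32, 16, 8, 5]
```

A 128x128 input therefore ends in a 5x5 map rather than a single value per image.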
Layer Details:
- Conv1: Input 3 channels, output ndf channels, 4x4 kernel, stride 2, padding 1, no bias.
- LeakyReLU1: Negative slope 0.2, inplace.
- Conv2: Input ndf channels, output 2*ndf channels, 4x4 kernel, stride 2, padding 1, no bias.
- BN2: BatchNorm on 2*ndf channels.
- LeakyReLU2: Negative slope 0.2, inplace.
- Conv3: Input 2*ndf channels, output 4*ndf channels, 4x4 kernel, stride 2, padding 1, no bias.
- BN3: BatchNorm on 4*ndf channels.
- LeakyReLU3: Negative slope 0.2, inplace.
- Conv4: Input 4*ndf channels, output 8*ndf channels, 4x4 kernel, stride 2, padding 1, no bias.
- BN4: BatchNorm on 8*ndf channels.
- LeakyReLU4: Negative slope 0.2, inplace.
- Conv5: Input 8*ndf channels, output 1 channel, 4x4 kernel, stride 1, no padding, no bias.
- Sigmoid: Applied to the final layer output to obtain a probability.
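The layer list above can be expressed as an OrderedDict and wrapped in nn.Sequential, mirroring the layer_config and main attributes. This is a minimal sketch assuming ndf=64; the key names ("conv1", "bn2", ...) and the helper build_layer_config are hypothetical and may not match discriminator.py.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def build_layer_config(ndf: int) -> "OrderedDict[str, nn.Module]":
    # Hypothetical key names; the real layer_config may use different ones.
    return OrderedDict([
        ("conv1", nn.Conv2d(3, ndf, 4, 2, 1, bias=False)),
        ("lrelu1", nn.LeakyReLU(0.2, inplace=True)),
        ("conv2", nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False)),
        ("bn2", nn.BatchNorm2d(ndf * 2)),
        ("lrelu2", nn.LeakyReLU(0.2, inplace=True)),
        ("conv3", nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False)),
        ("bn3", nn.BatchNorm2d(ndf * 4)),
        ("lrelu3", nn.LeakyReLU(0.2, inplace=True)),
        ("conv4", nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False)),
        ("bn4", nn.BatchNorm2d(ndf * 8)),
        ("lrelu4", nn.LeakyReLU(0.2, inplace=True)),
        ("conv5", nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False)),
        ("sigmoid", nn.Sigmoid()),
    ])


# nn.Sequential accepts an OrderedDict and keeps the given layer names
main = nn.Sequential(build_layer_config(ndf=64))
x = torch.randn(8, 3, 64, 64)  # a batch of 8 random 64x64 RGB "images"
out = main(x)
print(out.shape)  # torch.Size([8, 1, 1, 1])
```

Naming the layers through an OrderedDict lets individual layers be accessed by name, for example main.conv1.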
Source code in discriminator.py
forward(input)
Forward pass of the discriminator. Takes an image tensor and returns the discriminator's prediction.
Parameters:
- input (torch.Tensor): A batch of images of shape (N, 3, image_size, image_size).
Returns:
- torch.Tensor: A tensor of shape (N,) containing the probability that each image in the batch is real.
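This page does not show the body of forward. One common idiom in DCGAN-style code for obtaining an (N,) result is to run the layer stack and flatten its (N, 1, 1, 1) output, sketched below. The helper discriminator_forward and the one-layer stand_in network are hypothetical, used only to keep the example self-contained.

```python
import torch
import torch.nn as nn


def discriminator_forward(main: nn.Module, input: torch.Tensor) -> torch.Tensor:
    # main maps (N, 3, 64, 64) images to (N, 1, 1, 1) probabilities
    output = main(input)
    # Collapse the singleton channel and spatial dims so the result has shape (N,)
    return output.view(-1, 1).squeeze(1)


# Stand-in for the real layer stack: a single 64x64 convolution plus Sigmoid,
# which also maps each 64x64 RGB image to one 1x1 probability.
stand_in = nn.Sequential(nn.Conv2d(3, 1, kernel_size=64, bias=False), nn.Sigmoid())
preds = discriminator_forward(stand_in, torch.randn(4, 3, 64, 64))
print(preds.shape)  # torch.Size([4])
```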
Source code in discriminator.py