Trainer
The Trainer class encapsulates the training process for a Generative Adversarial Network (GAN): initialization, the training steps for the discriminator and generator, and saving the model. It trains over a specified number of epochs, optimizes both the generator and discriminator models, and logs training progress.
Parameters:
- device (torch.device): The device to train on, e.g., 'cpu' or 'cuda'.
- latent_space (int, optional): Dimension of the latent space vector. Defaults to 100.
- image_size (int, optional): Height and width of the images to generate. Defaults to 64.
- lr (float, optional): Learning rate for the Adam optimizers. Defaults to 0.0002.
- epochs (int, optional): Number of training epochs. Defaults to 100.
Attributes:
- netG (Generator): The generator model.
- netD (Discriminator): The discriminator model.
- optimizerD (torch.optim.Optimizer): Optimizer for the discriminator.
- optimizerG (torch.optim.Optimizer): Optimizer for the generator.
- criterion (nn.Module): Loss function (Binary Cross Entropy Loss).
- real_label (float): Label for real images (1.0).
- fake_label (float): Label for fake images (0.0).
- nz (int): Size of the latent vector (z).
- num_epochs (int): Number of epochs for training.
Methods:
- model_init(): Initializes the models and applies weight initialization.
- optimizer_init(generator, discriminator): Initializes the optimizers for both models.
- train_discriminator(data): Performs a single training step for the discriminator.
- train_generator(fake): Performs a single training step for the generator.
- display_results(epoch, i, dataloader, errD, errG, D_x, D_G_z1, D_G_z2): Logs training progress to the console.
- save_generator_model(epoch): Saves the current state of the generator model.
- dataloader(): Loads and returns a dataloader instance.
- train(): Executes the training loop over the specified number of epochs.
Example
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    trainer = Trainer(device=device, epochs=30, lr=0.0002)
    trainer.train()
Note:
This class assumes the presence of Generator and Discriminator classes, along with a weights_init function for model weight initialization. The dataloader is expected to be loaded using joblib from a specified path.
Source code in trainer.py
dataloader()
Loads and returns the training data loader from a serialized file.
This method is responsible for loading the training data loader, which has been previously saved to disk using serialization (e.g., with joblib). It allows for quick loading of preprocessed and prepared batches of data for training.
Returns: - DataLoader: The loaded DataLoader object ready for iteration. This dataloader is expected to yield batches of training data during the training loop.
Note:
- The dataloader is loaded from a predefined path '../data/processed/dataloader.pkl'. This path must exist and contain a serialized DataLoader object. The method assumes the preprocessing and preparation of data are already completed and saved to this location.
- This method performs a file I/O operation to read the DataLoader object from disk. Ensure the specified path is accessible and the file format is compatible with the joblib library.
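The loading step described above can be sketched as a joblib round trip. This is a minimal illustration, not the code in trainer.py: the helper name `load_dataloader`, the explicit existence check, the random stand-in images, and the temporary path are all assumptions made for the example.

```python
import os
import tempfile

import joblib
import torch
from torch.utils.data import DataLoader, TensorDataset


def load_dataloader(path):
    """Load a previously serialized DataLoader, failing early if the file is missing."""
    if not os.path.exists(path):
        raise FileNotFoundError(f"No serialized DataLoader found at {path}")
    return joblib.load(path)


# Stand-in data (assumption): 8 random 3x64x64 "images" served in batches of 4.
images = torch.randn(8, 3, 64, 64)
loader = DataLoader(TensorDataset(images), batch_size=4, shuffle=False)

# Serialize to a temporary file, then read it back the way the trainer would.
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "dataloader.pkl")
    joblib.dump(loader, path)
    restored = load_dataloader(path)

# TensorDataset yields tuples, so the image tensor is the first element of a batch.
first_batch = next(iter(restored))[0]
print(first_batch.shape)
```

Checking for the file before calling `joblib.load` gives a clearer error message than the one raised from inside the deserializer.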
Source code in trainer.py
display_results(epoch, i, dataloader, errD, errG, D_x, D_G_z1, D_G_z2)
Displays the training results and progress metrics for the current batch and epoch.
This method logs the losses of the discriminator and generator, as well as the discriminator's performance on real and fake images. It provides insights into how well the discriminator and generator are learning and adapting during the training process.
Parameters:
- epoch (int): The current epoch number during training.
- i (int): The current batch number within the epoch.
- dataloader (DataLoader): The DataLoader used for training, utilized here to determine the total number of batches.
- errD (float): The current loss of the discriminator.
- errG (float): The current loss of the generator.
- D_x (float): The average output of the discriminator for real images. Closer to 1 indicates better performance on real images.
- D_G_z1 (float): The average output of the discriminator for fake images before the generator update. Closer to 0 indicates better discrimination of fake images.
- D_G_z2 (float): The average output of the discriminator for fake images after the generator update. Closer to 1 indicates the generator is improving in fooling the discriminator.
Output: - The method prints a formatted string to the console, summarizing the training metrics for the current batch within the ongoing epoch.
Note: - This method is intended for logging purposes and does not return any values. It provides a snapshot of the training progress at the moment it is called, allowing for monitoring of the GAN's learning dynamics.
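As an illustration of the kind of line such a method might print, here is a small formatting sketch. The exact layout used in trainer.py is not shown in this page, so the format string below is an assumption, as are the helper name `format_progress` and the explicit `num_epochs` and `num_batches` arguments (the real method reads these from the instance and from the dataloader).

```python
def format_progress(epoch, num_epochs, i, num_batches, errD, errG, D_x, D_G_z1, D_G_z2):
    """Build a one-line summary of the current batch (assumed layout)."""
    return (
        f"[{epoch}/{num_epochs}][{i}/{num_batches}] "
        f"Loss_D: {errD:.4f} Loss_G: {errG:.4f} "
        f"D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f} / {D_G_z2:.4f}"
    )


# Made-up metric values, used only to show the formatting.
line = format_progress(1, 30, 50, 200, 0.8231, 1.9054, 0.71, 0.32, 0.18)
print(line)
# [1/30][50/200] Loss_D: 0.8231 Loss_G: 1.9054 D(x): 0.7100 D(G(z)): 0.3200 / 0.1800
```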
Source code in trainer.py
model_init()
Initializes the Generator and Discriminator models for the GAN. This method constructs the models with the specified latent space size and image size, moves them to the appropriate device (CPU or GPU), and applies a predefined weight initialization function to both models.
The models are defined by the Generator and Discriminator classes, which should be available in the same scope as this Trainer class. The latent space size and image size are used to configure the models according to the specifics of the GAN architecture being trained.
Side effects
- Instantiates the Generator and Discriminator models with the specified configurations.
- Applies a predefined weight initialization function to both models so that training starts from a controlled weight distribution.
- Moves the models to the specified device, which is typically determined by whether a GPU is available for training.
Note
The device used for training is determined by the 'device' attribute of the Trainer class instance. The weights initialization function applied to both models is defined externally and must be available in the same scope as this Trainer class.
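The initialization flow can be sketched as follows. The real Generator, Discriminator, and weights_init are defined outside this page, so everything here is a stand-in: `TinyGenerator` and `TinyDiscriminator` are hypothetical single-layer models, and the `weights_init` body follows the common DCGAN convention (normal initialization with standard deviation 0.02), which is an assumption about what the project's function does. Returning the two models is also a choice made for the example.

```python
import torch
import torch.nn as nn


def weights_init(m):
    """Assumed DCGAN-style init: N(0, 0.02) for conv weights, N(1, 0.02) for batch-norm scales."""
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)


class TinyGenerator(nn.Module):
    """Hypothetical stand-in: maps noise of shape (N, nz, 1, 1) to images of shape (N, 3, 4, 4)."""

    def __init__(self, nz=100):
        super().__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(nz, 3, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Tanh(),
        )

    def forward(self, z):
        return self.main(z)


class TinyDiscriminator(nn.Module):
    """Hypothetical stand-in: maps images of shape (N, 3, 4, 4) to N probabilities."""

    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.main(x).view(-1)


def model_init(device, nz=100):
    """Build both models, move them to the device, then apply the weight init to every submodule."""
    netG = TinyGenerator(nz).to(device)
    netD = TinyDiscriminator().to(device)
    netG.apply(weights_init)
    netD.apply(weights_init)
    return netG, netD


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
netG, netD = model_init(device)
```

`nn.Module.apply` visits every submodule recursively, which is why a single function keyed on the class name can initialize all convolution and batch-norm layers.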
Source code in trainer.py
optimizer_init(generator, discriminator)
Initializes the optimizers for both the generator and discriminator models. This method sets up Adam optimizers with specified learning rates and betas parameters, which are critical for the training dynamics of the Generative Adversarial Network (GAN).
Parameters:
- generator (torch.nn.Module): The generator model for which the optimizer will be initialized. This model should already be instantiated and configured with the appropriate architecture for generating images.
- discriminator (torch.nn.Module): The discriminator model for which the optimizer will be initialized. This model should already be instantiated and configured with the appropriate architecture for discriminating between real and generated images.
Returns:
- tuple: A tuple containing two optimizer objects:
  - optimizerD (torch.optim.Adam): The Adam optimizer configured for the discriminator model, including learning rate and betas parameters.
  - optimizerG (torch.optim.Adam): The Adam optimizer configured for the generator model, including learning rate and betas parameters.
Note:
- The learning rate (lr) and betas of the Adam optimizers are hyperparameters that strongly affect the training stability and convergence of the GAN. The values used here are common starting points and may need adjustment for a specific dataset or model architecture.
- This method assumes that the lr attribute (learning rate) is already set in the Trainer class instance and uses this value for both optimizers. The betas parameters are fixed in this implementation but could be exposed as parameters or attributes for more flexibility.
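A minimal sketch of this setup is shown below. The betas value is not stated in this page; `(0.5, 0.999)` is the setting commonly used for DCGAN-style training and is an assumption here, as are the stand-in `nn.Linear` models and the standalone function signature with `lr` and `betas` arguments.

```python
import torch.nn as nn
import torch.optim as optim


def optimizer_init(generator, discriminator, lr=0.0002, betas=(0.5, 0.999)):
    """Create one Adam optimizer per model; betas=(0.5, 0.999) is an assumed default."""
    optimizerD = optim.Adam(discriminator.parameters(), lr=lr, betas=betas)
    optimizerG = optim.Adam(generator.parameters(), lr=lr, betas=betas)
    return optimizerD, optimizerG


# Stand-in models (assumption): any nn.Module with parameters works here.
generator = nn.Linear(100, 16)
discriminator = nn.Linear(16, 1)
optimizerD, optimizerG = optimizer_init(generator, discriminator)
```

Keeping the two optimizers separate is what lets each training step update one network while leaving the other network's weights untouched.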
Source code in trainer.py
save_generator_model(epoch)
Saves the state dictionary of the generator model to a file, capturing its current weights.
This method is typically called at the end of each training epoch to persist the state of the generator model, allowing for later use or further training from the saved state. The filename includes the epoch number for easy identification and versioning.
Parameters: - epoch (int): The current epoch number. This is used to name the saved model file, indicating at which point in the training process the model was saved.
Output:
- The method saves the generator's state dictionary to a file in the current working directory. The file is named 'generator_epoch_{epoch}.pth', where {epoch} is replaced with the current epoch number.
Note: - This method does not return any value. It performs a file I/O operation to write the generator model's state dictionary to disk.
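The save step, and the matching reload, can be sketched like this. The filename pattern comes from the description above; the `directory` argument, the returned path, and the stand-in `nn.Linear` generator are assumptions added so the example can run in a temporary directory.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_generator_model(netG, epoch, directory="."):
    """Write the generator's state dict to generator_epoch_{epoch}.pth and return the path."""
    path = os.path.join(directory, f"generator_epoch_{epoch}.pth")
    torch.save(netG.state_dict(), path)
    return path


# Stand-in generator (assumption); the real one is the project's Generator class.
netG = nn.Linear(100, 16)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = save_generator_model(netG, epoch=3, directory=tmp_dir)
    saved_name = os.path.basename(path)

    # Reloading requires a model with the same architecture as the saved one.
    restored = nn.Linear(100, 16)
    restored.load_state_dict(torch.load(path))
```

Because only the state dictionary is stored, the file holds weights but not the class definition, so the Generator class must be importable wherever the checkpoint is loaded.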
Source code in trainer.py
train()
Executes the training loop for the Generative Adversarial Network (GAN).
This method orchestrates the training process by iterating over a specified number of epochs, during which it trains the discriminator and generator models in sequence. At each step of the training, it logs the progress, including the losses of both models and the discriminator's performance metrics. At the end of each epoch, it saves the current state of the generator model.
The training loop follows these steps:
1. Loads the data using the dataloader method, which should return an iterable DataLoader object containing the training data.
2. Iterates over the specified number of epochs (as defined by self.num_epochs).
   a. For each batch in the DataLoader:
      i. Trains the discriminator on both real and fake data, computing its loss.
      ii. Trains the generator on the fake batch, attempting to fool the discriminator, and computes its loss.
      iii. Logs the current losses and discriminator performance metrics using the display_results method.
3. At the end of each epoch, saves the state of the generator model using the save_generator_model method.
Note:
- The actual training of the discriminator and generator is performed by the train_discriminator and train_generator methods, respectively. This method coordinates these calls and handles logging and model state saving.
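- Progress logging provides insight into the training process, and the per-epoch generator checkpoints preserve the generator's weights if training is interrupted. Only the generator's state is saved, so resuming adversarial training would also require the discriminator and optimizer states.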
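The order of calls in the loop can be sketched without any models at all. In this illustration the four methods are passed in as plain callables and replaced by stubs that return made-up values, so the function name `run_training` and its signature are assumptions; only the call sequence mirrors the steps listed above.

```python
def run_training(dataloader, num_epochs, train_discriminator, train_generator,
                 display_results, save_generator_model):
    """Run the epoch/batch loop: discriminator step, generator step, log, then save per epoch."""
    for epoch in range(num_epochs):
        for i, data in enumerate(dataloader):
            # The discriminator step also produces the fake batch reused by the generator step.
            errD, D_x, D_G_z1, fake = train_discriminator(data)
            errG, D_G_z2 = train_generator(fake)
            display_results(epoch, i, dataloader, errD, errG, D_x, D_G_z1, D_G_z2)
        save_generator_model(epoch)


# Stubs (assumptions) that record what was called instead of training anything.
calls = []
batches = ["batch0", "batch1"]
run_training(
    dataloader=batches,
    num_epochs=2,
    train_discriminator=lambda data: (0.7, 0.6, 0.4, f"fake_for_{data}"),
    train_generator=lambda fake: (1.2, 0.3),
    display_results=lambda epoch, i, *rest: calls.append(("log", epoch, i)),
    save_generator_model=lambda epoch: calls.append(("save", epoch)),
)
print(calls)
```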
Source code in trainer.py
train_discriminator(data)
Trains the discriminator model on both real and generated (fake) images. This method performs a forward pass with real images from the dataset and fake images generated by the generator, computes the loss for both, backpropagates to update the discriminator's weights, and returns the losses and discriminator outputs.
Parameters:
- data (torch.Tensor): A batch of real images from the dataset. This tensor should have the shape (N, C, H, W), where N is the batch size, C is the number of channels, and H and W are the height and width of the images.
Returns:
- tuple: A tuple containing the following elements:
  - errD (torch.Tensor): The total discriminator loss calculated as the sum of the loss for real and fake images.
  - D_x (float): The mean output of the discriminator for real images. This value is used to evaluate the discriminator's performance on real data.
  - D_G_z1 (float): The mean output of the discriminator for fake images before the generator update. This value is used to evaluate the discriminator's performance on fake data.
  - fake (torch.Tensor): A batch of fake images generated by the generator.
The method performs the following steps:
1. Zeroes the gradients of the discriminator.
2. Processes a batch of real images, computes the loss against the true labels, backpropagates the error, and calculates the mean discriminator output (D_x).
3. Generates a batch of fake images using the generator, computes the loss against the false labels, backpropagates the error, and calculates the mean discriminator output for the fake images (D_G_z1).
4. Updates the discriminator's weights based on the total loss.
Note:
- This method updates the discriminator's weights once per call, using the combined loss from both real and fake images.
- The real_label and fake_label attributes of the Trainer class are used to denote the true and false labels, respectively, for computing the loss.
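The four steps above can be sketched with tiny stand-in networks. This is an illustration under assumptions, not the code in trainer.py: the linear `netG` and `netD`, the 8x8 image size, the betas value, and the use of `fake.detach()` to keep the discriminator step from backpropagating into the generator are all choices made for the example.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
nz = 100

# Stand-in models (assumption): linear layers on 3x8x8 images instead of convolutional nets.
netG = nn.Sequential(nn.Linear(nz, 3 * 8 * 8), nn.Tanh(), nn.Unflatten(1, (3, 8, 8)))
netD = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1), nn.Sigmoid())
criterion = nn.BCELoss()
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
real_label, fake_label = 1.0, 0.0


def train_discriminator(data):
    # 1. Clear gradients left over from the previous step.
    netD.zero_grad()
    batch_size = data.size(0)

    # 2. Real batch: loss against the real labels, then backpropagate.
    label = torch.full((batch_size,), real_label)
    output = netD(data).view(-1)
    errD_real = criterion(output, label)
    errD_real.backward()
    D_x = output.mean().item()

    # 3. Fake batch: detach so this backward pass does not reach the generator.
    noise = torch.randn(batch_size, nz)
    fake = netG(noise)
    label.fill_(fake_label)
    output = netD(fake.detach()).view(-1)
    errD_fake = criterion(output, label)
    errD_fake.backward()
    D_G_z1 = output.mean().item()

    # 4. One optimizer step using the gradients accumulated from both passes.
    errD = errD_real + errD_fake
    optimizerD.step()
    return errD, D_x, D_G_z1, fake


errD, D_x, D_G_z1, fake = train_discriminator(torch.randn(4, 3, 8, 8))
```

The returned `fake` is not detached, so it still carries the graph back to the generator and can be reused in the generator step.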
Source code in trainer.py
train_generator(fake)
Trains the generator model by attempting to fool the discriminator.
This method updates the generator's weights based on its ability to generate fake images that the discriminator classifies as real. It computes the loss using the output of the discriminator on the generated (fake) images, performs backpropagation to calculate the gradients, and updates the generator's weights to reduce this loss.
Parameters:
- fake (torch.Tensor): A batch of fake images generated by the generator model. The tensor should have dimensions (N, C, H, W), where N is the batch size, C is the number of channels, and H and W are the height and width of the images, respectively.
Returns:
- tuple: A tuple containing the following elements:
  - errG (torch.Tensor): The loss of the generator computed as the binary cross-entropy loss between the discriminator's output on the fake images and the real labels.
  - D_G_z2 (float): The average output of the discriminator for the fake images. This metric indicates how well the generator is fooling the discriminator, with higher values suggesting better performance.
The training step involves:
1. Zeroing the gradients of the generator to ensure that previous training steps do not affect the current update.
2. Creating a tensor of real labels (since the generator's goal is to have its fake images classified as real by the discriminator) and computing the loss against the discriminator's predictions on the fake images.
3. Performing backpropagation to calculate the gradients with respect to the generator's parameters.
4. Updating the generator's weights using the optimizer to reduce the loss, thereby improving the generator's ability to produce realistic images.
Note:
- The method assumes the use of the binary cross-entropy loss (BCELoss) to quantify how well the generator is fooling the discriminator. The labels for the fake images are set to 'real' (self.real_label) for the purpose of this loss calculation.
- This method directly influences the generator's performance by adjusting its ability to create images that are indistinguishable from real images to the discriminator.
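A compact sketch of the generator step follows, again with stand-in networks. The linear `netG` and `netD`, the 8x8 image size, and the betas value are assumptions for the example; the label trick (scoring fake images against the real label) is the part taken from the description above.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
nz = 100

# Stand-in models (assumption): linear layers on 3x8x8 images instead of convolutional nets.
netG = nn.Sequential(nn.Linear(nz, 3 * 8 * 8), nn.Tanh(), nn.Unflatten(1, (3, 8, 8)))
netD = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1), nn.Sigmoid())
criterion = nn.BCELoss()
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))
real_label = 1.0


def train_generator(fake):
    # 1. Clear the generator's gradients from the previous step.
    netG.zero_grad()

    # 2. Score the fake batch against the *real* label: the loss is low when D is fooled.
    label = torch.full((fake.size(0),), real_label)
    output = netD(fake).view(-1)
    errG = criterion(output, label)

    # 3. Backpropagate through the discriminator into the generator.
    errG.backward()
    D_G_z2 = output.mean().item()

    # 4. Only the generator's optimizer steps; the discriminator's weights stay fixed here.
    optimizerG.step()
    return errG, D_G_z2


# The fake batch must still be attached to the generator's graph (not detached).
fake = netG(torch.randn(4, nz))
before = [p.detach().clone() for p in netG.parameters()]
errG, D_G_z2 = train_generator(fake)
changed = any(not torch.equal(b, p) for b, p in zip(before, netG.parameters()))
```

The backward pass also leaves gradients on the discriminator's parameters; they are harmless in this sketch because no discriminator optimizer steps here, and a discriminator step that begins by zeroing its gradients would discard them.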
Source code in trainer.py