Synthetic Data Generator Script
This script generates synthetic data using a trained Generative Adversarial Network (GAN). It parses command-line arguments for customization and uses the trained generator model to produce synthetic images.
Features:
- Command-line argument parsing for specifying the number of samples and latent space dimensions.
- Loading the best-performing generator model from checkpoints.
- Generating and saving synthetic images.
Usage:
To use this script, run it from the command line with the desired arguments, for example:
python synthetic_data_generator.py --samples 20 --latent_space 100
Arguments:
--samples: Number of synthetic samples to generate.
--latent_space: Dimension of the latent space for the generator.
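The argument-parsing code itself does not appear in the listing below. As a minimal sketch, the two documented flags could be parsed with the standard-library `argparse` module like this. The function name `parse_arguments` and the defaults (copied from `Test.__init__`) are assumptions, not the script's actual code.

```python
import argparse


def parse_arguments(argv=None):
    """Parse the two documented flags (hypothetical helper).

    Defaults mirror Test.__init__(num_samples=20, latent_space=100); the real
    script's defaults are not shown in the source.
    """
    parser = argparse.ArgumentParser(
        description="Generate synthetic images with a trained GAN generator."
    )
    parser.add_argument(
        "--samples",
        type=int,
        default=20,
        help="Number of synthetic samples to generate.",
    )
    parser.add_argument(
        "--latent_space",
        type=int,
        default=100,
        help="Dimension of the latent space for the generator.",
    )
    # argv=None makes argparse read sys.argv; pass a list to parse explicitly.
    return parser.parse_args(argv)
```

For example, `parse_arguments(["--samples", "20", "--latent_space", "100"])` yields `args.samples == 20` and `args.latent_space == 100`. Those values would then presumably be passed on as `Test(num_samples=args.samples, latent_space=args.latent_space)`.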
Source code in test.py
import math
import os
import re

import matplotlib.pyplot as plt
import torch

# `Generator` is the project's generator network; its import is not shown in this listing.


class Test:
    """
    # Synthetic Data Generator Script
    This script generates synthetic data using a trained Generative Adversarial Network (GAN). It parses command-line arguments for customization and uses the trained generator model to produce synthetic images.
    ## Features:
    - Command-line argument parsing for specifying the number of samples and latent space dimensions.
    - Loading the best-performing generator model from checkpoints.
    - Generating and saving synthetic images.
    ## Usage:
    To use this script, run it from the command line with the desired arguments, for example:
    python synthetic_data_generator.py --samples 20 --latent_space 100
    ## Arguments:
    - `--samples`: Number of synthetic samples to generate.
    - `--latent_space`: Dimension of the latent space for the generator.
    """

    def __init__(
        self,
        num_samples=20,
        latent_space=100,
    ):
        """
        Initializes the Test class with the specified configuration for generating synthetic images.
        ### Parameters:
        - `num_samples` (int): Number of synthetic samples to generate.
        - `latent_space` (int): Size of the latent space (input vector for the generator).
        """
        self.num_samples = num_samples
        self.latent_space = latent_space
        self.generator = Generator()

    def get_best_model(self):
        """
        Retrieves the most recent generator checkpoint, which is treated as the best-performing model.
        ### Returns:
        - `model` (str): Path to the highest-numbered `generator_<n>.pth` checkpoint.
        ### Raises:
        - Exception: If no model is found in the checkpoints directory.
        """
        model_checkpoints = "./models/checkpoints"
        pattern = re.compile(r"generator_(\d+)\.pth")
        try:
            filenames = os.listdir(model_checkpoints)
        except OSError as e:
            raise Exception("No model found") from e
        # Select the highest index explicitly; len(os.listdir(...)) - 1 breaks on an
        # empty directory, unrelated files, or gaps in the numbering.
        indexed = []
        for name in filenames:
            match = pattern.fullmatch(name)
            if match:
                indexed.append((int(match.group(1)), name))
        if not indexed:
            raise Exception("No model found")
        _, latest = max(indexed)
        return os.path.join(model_checkpoints, latest)

    def saved_images(self, **kwargs):
        """
        Saves and displays the generated synthetic images.
        ### Parameters (passed as keyword arguments):
        - `synthetic_samples` (Tensor): Synthetic samples generated by the generator, one flattened 28x28 image per row.
        - `batch_size` (int): Number of samples in the batch.
        - `real_labels` (list), `labels` (Tensor): Accepted but not used by this method.
        ### Side Effects:
        - Saves the generated images to `./outputs/synthetic_image.png`.
        - Displays the generated images in a plot.
        """
        batch_size = kwargs["batch_size"]
        samples = kwargs["synthetic_samples"]
        plt.figure(figsize=(10, 5))
        num_rows = 2
        # Round up so an odd batch_size still gets enough subplot cells.
        num_columns = math.ceil(batch_size / num_rows)
        for index in range(batch_size):
            plt.subplot(num_rows, num_columns, index + 1)
            # Each sample is assumed to be a flattened 28x28 (784-value) image.
            plt.imshow(samples[index].detach().cpu().numpy().reshape(28, 28))
            plt.axis("off")
        plt.tight_layout()
        try:
            os.makedirs("./outputs", exist_ok=True)
            plt.savefig("./outputs/synthetic_image.png")
            plt.show()
        except Exception as e:
            raise Exception("Could not save the synthetic images") from e

    def plot_synthetic_image(self):
        """
        Generates and plots synthetic images using the best-performing generator model.
        ### Process:
        - Loads the best-performing generator model.
        - Samples latent noise and passes it through the generator.
        - Calls `saved_images` to save and display the generated images.
        ### Side Effects:
        - Loads the checkpoint weights into `self.generator` and switches it to evaluation mode.
        - Calls `saved_images`, which saves and displays images.
        """
        model = self.get_best_model()
        # map_location="cpu" lets a GPU-trained checkpoint load on a CPU-only machine.
        self.generator.load_state_dict(torch.load(model, map_location="cpu"))
        self.generator.eval()
        noise_samples = torch.randn(self.num_samples, self.latent_space)
        # Inference only: no gradients are needed for sampling.
        with torch.no_grad():
            synthetic_samples = self.generator(noise_samples)
        self.saved_images(
            synthetic_samples=synthetic_samples,
            batch_size=self.num_samples,
        )
__init__(num_samples=20, latent_space=100)
Initializes the Test class with the specified configuration for generating synthetic images.
Parameters:
num_samples (int): Number of synthetic samples to generate.
latent_space (int): Size of the latent space (input vector for the generator).
Source code in test.py
    def __init__(
        self,
        num_samples=20,
        latent_space=100,
    ):
        """
        Initializes the Test class with the specified configuration for generating synthetic images.
        ### Parameters:
        - `num_samples` (int): Number of synthetic samples to generate.
        - `latent_space` (int): Size of the latent space (input vector for the generator).
        """
        self.num_samples = num_samples
        self.latent_space = latent_space
        self.generator = Generator()
get_best_model()
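Retrieves the most recent generator checkpoint from ./models/checkpoints, which the script treats as the best-performing model.
Returns:
model (str): Path to the highest-numbered generator_<n>.pth checkpoint.
Raises:
- Exception: If no checkpoint can be found (for example, when the checkpoints directory does not exist).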
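Below is a sketch of a more defensive checkpoint lookup, assuming checkpoints are named `generator_<n>.pth` as `get_best_model()` implies. It picks the highest index by parsing filenames instead of counting directory entries. As a result, stray files and gaps in the numbering are handled, and an empty directory raises an error instead of yielding a nonexistent path such as `generator_-1.pth`. The helper name `latest_generator_checkpoint` is hypothetical, and "highest index" still means "most recent", not "best" by any metric.

```python
import os
import re

# Assumed naming scheme, as implied by get_best_model(): "generator_<n>.pth".
CHECKPOINT_PATTERN = re.compile(r"generator_(\d+)\.pth")


def latest_generator_checkpoint(checkpoint_dir="./models/checkpoints"):
    """Return the path of the generator checkpoint with the highest index.

    Raises FileNotFoundError if the directory is missing or holds no
    matching checkpoint files.
    """
    try:
        filenames = os.listdir(checkpoint_dir)
    except FileNotFoundError as error:
        raise FileNotFoundError(f"No checkpoint directory at {checkpoint_dir}") from error

    # Keep only files that match the naming scheme, paired with their index.
    indexed = []
    for name in filenames:
        match = CHECKPOINT_PATTERN.fullmatch(name)
        if match:
            indexed.append((int(match.group(1)), name))

    if not indexed:
        raise FileNotFoundError(f"No generator checkpoints found in {checkpoint_dir}")

    # Tuples compare by index first, so max() picks the highest-numbered file.
    _, latest_name = max(indexed)
    return os.path.join(checkpoint_dir, latest_name)
```

Parsing the index as an integer matters: a plain string sort would rank `generator_9.pth` above `generator_10.pth`.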
Source code in test.py
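    def get_best_model(self):
        """
        Retrieves the most recent generator checkpoint, which is treated as the best-performing model.
        ### Returns:
        - `model` (str): Path to the highest-numbered `generator_<n>.pth` checkpoint.
        ### Raises:
        - Exception: If no model is found in the checkpoints directory.
        """
        model_checkpoints = "./models/checkpoints"
        pattern = re.compile(r"generator_(\d+)\.pth")
        try:
            filenames = os.listdir(model_checkpoints)
        except OSError as e:
            raise Exception("No model found") from e
        # Select the highest index explicitly; len(os.listdir(...)) - 1 breaks on an
        # empty directory, unrelated files, or gaps in the numbering.
        indexed = []
        for name in filenames:
            match = pattern.fullmatch(name)
            if match:
                indexed.append((int(match.group(1)), name))
        if not indexed:
            raise Exception("No model found")
        _, latest = max(indexed)
        return os.path.join(model_checkpoints, latest)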
plot_synthetic_image()
Generates and plots synthetic images using the best-performing generator model.
Process:
- Loads the best-performing generator model.
- Samples latent noise and passes it through the generator to produce synthetic samples.
- Calls saved_images to save and display the generated images.
Side Effects:
- Loads the checkpoint weights into the generator model.
- Calls saved_images, which saves and displays images.
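`plot_synthetic_image()` feeds the generator a noise batch drawn with `torch.randn(self.num_samples, self.latent_space)`: one row of `latent_space` standard-normal values per image. The pure-Python stand-in below only illustrates that shape contract. The name `sample_latent_batch` and the `seed` parameter are hypothetical; in the real method the tensor comes from PyTorch.

```python
import random


def sample_latent_batch(num_samples=20, latent_space=100, seed=None):
    """Draw a num_samples x latent_space batch of standard-normal noise.

    Stand-in for torch.randn(num_samples, latent_space): each inner list is
    the latent vector for one synthetic image.
    """
    rng = random.Random(seed)  # a fixed seed makes the batch reproducible
    return [
        [rng.gauss(0.0, 1.0) for _ in range(latent_space)]
        for _ in range(num_samples)
    ]
```

The row count decides how many images come out, and the column count must match the input size the generator was trained with. A `--latent_space` value that differs from training will fail at the generator's first layer.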
Source code in test.py
    def plot_synthetic_image(self):
        """
        Generates and plots synthetic images using the best-performing generator model.
        ### Process:
        - Loads the best-performing generator model.
        - Samples latent noise and passes it through the generator.
        - Calls `saved_images` to save and display the generated images.
        ### Side Effects:
        - Loads the checkpoint weights into `self.generator` and switches it to evaluation mode.
        - Calls `saved_images`, which saves and displays images.
        """
        model = self.get_best_model()
        # map_location="cpu" lets a GPU-trained checkpoint load on a CPU-only machine.
        self.generator.load_state_dict(torch.load(model, map_location="cpu"))
        self.generator.eval()
        noise_samples = torch.randn(self.num_samples, self.latent_space)
        # Inference only: no gradients are needed for sampling.
        with torch.no_grad():
            synthetic_samples = self.generator(noise_samples)
        self.saved_images(
            synthetic_samples=synthetic_samples,
            batch_size=self.num_samples,
        )
saved_images(**kwargs)
Saves and displays the generated synthetic images.
Parameters (passed as keyword arguments):
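synthetic_samples (Tensor): Synthetic samples generated by the generator; each sample is reshaped to a 28x28 image.
real_labels (list): List of class labels (accepted but not used by this method).
labels (Tensor): Labels for the synthetic samples (accepted but not used by this method).
batch_size (int): Number of samples in the batch; the images are laid out on a grid with two rows.
Side Effects:
- Saves the generated images to ./outputs/synthetic_image.png.
- Displays the generated images in a plot.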
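`saved_images()` lays the batch out on a two-row grid and reshapes each sample to 28x28, which assumes 784-value generator outputs. In the listed code the column count is `batch_size // num_rows`, so an odd batch (or a batch of one) gets too few subplot cells and `plt.subplot` raises an error. The hypothetical helpers below sketch the layout arithmetic with rounding up, plus the row-major reshape that `.reshape(28, 28)` performs.

```python
import math


def grid_shape(batch_size, num_rows=2):
    """Return (rows, columns) for a subplot grid that fits batch_size images.

    Rounding the column count up guarantees rows * columns >= batch_size.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    rows = min(num_rows, batch_size)  # avoid an empty second row for one image
    columns = math.ceil(batch_size / rows)
    return rows, columns


def to_image_rows(flat_pixels, height=28, width=28):
    """Split a flat list of height * width pixels into rows (row-major order)."""
    if len(flat_pixels) != height * width:
        raise ValueError(f"expected {height * width} values, got {len(flat_pixels)}")
    return [flat_pixels[r * width:(r + 1) * width] for r in range(height)]
```

With the default `--samples 20`, `grid_shape(20)` gives `(2, 10)`, the same layout the original code produces. An odd batch such as 5 becomes `(2, 3)` instead of an undersized `(2, 2)` grid.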
Source code in test.py
    def saved_images(self, **kwargs):
        """
        Saves and displays the generated synthetic images.
        ### Parameters (passed as keyword arguments):
        - `synthetic_samples` (Tensor): Synthetic samples generated by the generator, one flattened 28x28 image per row.
        - `batch_size` (int): Number of samples in the batch.
        - `real_labels` (list), `labels` (Tensor): Accepted but not used by this method.
        ### Side Effects:
        - Saves the generated images to `./outputs/synthetic_image.png`.
        - Displays the generated images in a plot.
        """
        batch_size = kwargs["batch_size"]
        samples = kwargs["synthetic_samples"]
        plt.figure(figsize=(10, 5))
        num_rows = 2
        # Round up so an odd batch_size still gets enough subplot cells.
        num_columns = math.ceil(batch_size / num_rows)
        for index in range(batch_size):
            plt.subplot(num_rows, num_columns, index + 1)
            # Each sample is assumed to be a flattened 28x28 (784-value) image.
            plt.imshow(samples[index].detach().cpu().numpy().reshape(28, 28))
            plt.axis("off")
        plt.tight_layout()
        try:
            os.makedirs("./outputs", exist_ok=True)
            plt.savefig("./outputs/synthetic_image.png")
            plt.show()
        except Exception as e:
            raise Exception("Could not save the synthetic images") from e