21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
class Charts:
    """
    Visualization helper for the trained denoising model.

    Selects the saved checkpoint with the lowest recorded loss, loads the saved
    dataloader, runs the model on one batch, plots noisy / clean / predicted
    images side by side, and assembles saved training images into a GIF.

    Attributes:
        device: Device returned by ``device_init`` for the requested name
            (e.g. 'cpu', 'cuda', 'mps'); the model and inputs are moved to it.
        infinity (float): Lowest checkpoint loss seen so far. It starts at
            ``float("inf")`` so that the first checkpoint always wins.
        model (torch.nn.Module | None): The ``DnCNN`` instance built by ``plot``.
        best_model (dict | None): The ``"model"`` entry of the best checkpoint.
            It is passed to ``load_state_dict``, so it is expected to be a
            state dict.
        details (dict): Lists under ``"best_model"`` / ``"best_loss"`` recording
            every checkpoint that improved on the running minimum, in scan order.
        test_dataloader: Dataloader loaded by ``plot`` (None until then).

    Methods:
        select_best_model(self):
            Return the state dict of the lowest-loss checkpoint in BEST_MODELS_PATH.
        dataloader(self):
            Load the saved dataloader from PROCESSED_DATA_PATH.
        image_normalized(self, image):
            Min-max scale an image tensor to [0, 1] as a numpy array.
        create_gif(self):
            Build a GIF from the images in TRAIN_IMAGES_PATH.
        plot(self):
            Plot noisy/clean/predicted comparisons, save them, and build the GIF.
    """

    def __init__(self, device="mps"):
        """
        Initialize the chart helper.

        Parameters:
            device (str): Device name passed to ``device_init``. Defaults to 'mps'.
        """
        self.device = device_init(device=device)
        self.infinity = float("inf")
        self.model = None
        self.best_model = None
        self.details = {"best_model": list(), "best_loss": list()}
        self.test_dataloader = None

    @staticmethod
    def _visible_entries(directory):
        """Return the non-hidden entry names of ``directory`` in natural sort order."""
        import re

        def natural_key(name):
            # re.split with a capture group alternates text / digit-run pieces,
            # so odd positions are always digit runs: compare those numerically
            # ("img_2" before "img_10") while types stay aligned per position.
            return [
                int(piece) if position % 2 else piece
                for position, piece in enumerate(re.split(r"(\d+)", name))
            ]

        # Skip dotfiles (e.g. macOS ".DS_Store"), which are not loadable data,
        # and sort because os.listdir order is arbitrary.
        return sorted(
            (name for name in os.listdir(directory) if not name.startswith(".")),
            key=natural_key,
        )

    def select_best_model(self):
        """
        Select the checkpoint with the lowest ``"loss"`` in BEST_MODELS_PATH.

        Each checkpoint file is expected to hold a dict with ``"loss"`` and
        ``"model"`` keys. Every improvement on the running minimum is recorded
        in ``self.details``.

        Returns:
            dict: The ``"model"`` entry (state dict) of the best checkpoint.

        Raises:
            Exception: If BEST_MODELS_PATH does not exist, or if it contains no
                checkpoint and no best model was selected previously.
        """
        if not os.path.exists(BEST_MODELS_PATH):
            raise Exception("Best model path is not found".title())

        for filename in self._visible_entries(BEST_MODELS_PATH):
            # Load each checkpoint once; map_location lets checkpoints saved on
            # another device (e.g. cuda/mps) load on the current one.
            checkpoint = torch.load(
                os.path.join(BEST_MODELS_PATH, filename), map_location=self.device
            )
            if self.infinity > checkpoint["loss"]:
                self.infinity = checkpoint["loss"]
                self.best_model = checkpoint["model"]
                self.details["best_model"].append(self.best_model)
                self.details["best_loss"].append(self.infinity)

        if self.best_model is None:
            raise Exception("No saved model checkpoint was found".title())
        return self.best_model

    def dataloader(self):
        """
        Load the dataloader stored as ``dataloader.pkl`` in PROCESSED_DATA_PATH.

        Uses the project's ``load`` helper (presumably unpickling — confirm).

        Returns:
            The object returned by ``load`` (expected to be a DataLoader).

        Raises:
            Exception: If PROCESSED_DATA_PATH does not exist. Errors raised by
                ``load`` (e.g. a missing file) propagate unchanged.
        """
        if not os.path.exists(PROCESSED_DATA_PATH):
            raise Exception("Processed data path is not found".title())
        return load(os.path.join(PROCESSED_DATA_PATH, "dataloader.pkl"))

    def image_normalized(self, image):
        """
        Min-max scale an image tensor into the range [0, 1].

        A constant image (max == min) maps to all zeros instead of NaN.

        Parameters:
            image (torch.Tensor): The image tensor to normalize.

        Returns:
            numpy.ndarray: The normalized image.
        """
        array = image.detach().cpu().numpy()
        low = array.min()
        span = array.max() - low
        # Guard the constant-image case, where dividing by a zero span gives NaN.
        return (array - low) / (span if span > 0 else 1)

    def create_gif(self):
        """
        Build ``train_masks.gif`` in GIF_IMAGE_PATH from the images in TRAIN_IMAGES_PATH.

        Frames are taken in natural filename order, skipping hidden files.

        Raises:
            Exception: If TRAIN_IMAGES_PATH or GIF_IMAGE_PATH does not exist, or
                if TRAIN_IMAGES_PATH holds no images.
        """
        if not os.path.exists(TRAIN_IMAGES_PATH):
            raise Exception("Train images path not found.")
        if not os.path.exists(GIF_IMAGE_PATH):
            raise Exception("GIF image path not found.")

        frames = [
            imageio.imread(os.path.join(TRAIN_IMAGES_PATH, name))
            for name in self._visible_entries(TRAIN_IMAGES_PATH)
        ]
        if not frames:
            raise Exception("No training images found to build the GIF.")
        imageio.mimsave(os.path.join(GIF_IMAGE_PATH, "train_masks.gif"), frames, "GIF")

    def plot(self):
        """
        Plot noisy / clean / predicted images for one batch and save the figure.

        Loads the best checkpoint into a fresh ``DnCNN``, and runs it in
        inference mode on the first batch of the saved dataloader. It draws each
        sample as a (noisy, clean, predicted) triplet, saves the figure to
        ``TEST_IMAGES_PATH/test.png``, and then tries to build the training GIF;
        a GIF failure is printed, not raised.

        Raises:
            Exception: If TEST_IMAGES_PATH does not exist, or if model
                construction, checkpoint selection, or data loading fails.
        """
        # Build and load the model up front: any failure propagates here
        # rather than surfacing later as a call on a None/unloaded model.
        self.model = DnCNN().to(self.device)
        self.model.load_state_dict(self.select_best_model())
        # Inference behavior for layers such as batch norm/dropout (if present).
        self.model.eval()

        self.test_dataloader = self.dataloader()
        noise_images, clean_images = next(iter(self.test_dataloader))
        noise_images = noise_images.to(self.device)
        clean_images = clean_images.to(self.device)
        with torch.no_grad():
            predicted_images = self.model(noise_images)

        # 8 triplets per grid row; keep the original 15-row grid for batches of
        # 120 or fewer, and grow the grid for larger batches instead of failing.
        triplets_per_row = 8
        columns = 3 * triplets_per_row
        rows = max(3 * 5, -(-len(predicted_images) // triplets_per_row))

        plt.figure(figsize=(36, 24))
        for index, (noisy, clean, predicted) in enumerate(
            zip(noise_images, clean_images, predicted_images)
        ):
            panels = ((noisy, "Noisy"), (clean, "Clean"), (predicted, "Predict"))
            for offset, (tensor, title) in enumerate(panels, start=1):
                plt.subplot(rows, columns, 3 * index + offset)
                # CHW -> HWC for imshow.
                plt.imshow(self.image_normalized(tensor.permute(1, 2, 0)), cmap="gray")
                plt.title(title)
                plt.axis("off")
        plt.tight_layout()

        if not os.path.exists(TEST_IMAGES_PATH):
            raise Exception("Test images path is not found".title())
        plt.savefig(os.path.join(TEST_IMAGES_PATH, "test.png"))

        try:
            self.create_gif()
        except Exception as e:
            # The GIF is best-effort; its failure should not abort the plot.
            print("The exception in the gif is:", e)
        plt.show()
|