Charts

A class used to represent and perform operations for model visualization, including selecting the best model based on loss, loading data, normalizing images, creating GIFs from training images, and plotting comparison images (noisy, clean, predicted) using Matplotlib.

Attributes:
  • device (str) –

    The device type (e.g., 'cpu', 'cuda', 'mps') for PyTorch operations.

  • infinity (float) –

    The lowest loss found so far, initialized to positive infinity so that any saved model's loss compares as smaller.

  • model (Module) –

    The neural network model used for image denoising.

  • best_model (Module) –

    The best performing model based on loss.

  • details (dict) –

    A dictionary containing details of the best model and its loss.

Methods:

Name                Description
select_best_model   Selects the best model based on minimum loss from saved models.
dataloader          Loads the dataloader from the processed data path.
image_normalized    Normalizes an image tensor to be in the range [0, 1].
create_gif          Creates a GIF from training images stored in a specified path.
plot                Plots comparison images (noisy, clean, predicted) and saves a comparison figure and a training images GIF.

Source code in test.py
class Charts:
    """
    A class used to represent and perform operations for model visualization, including
    selecting the best model based on loss, loading data, normalizing images, creating GIFs from
    training images, and plotting comparison images (noisy, clean, predicted) using Matplotlib.

    Attributes:
        device (str): The device type (e.g., 'cpu', 'cuda', 'mps') for PyTorch operations.
        infinity (float): A high initial value for comparing and finding minimum loss.
        model (torch.nn.Module): The neural network model used for image denoising.
        best_model (torch.nn.Module): The best performing model based on loss.
        details (dict): A dictionary containing details of the best model and its loss.

    Methods:
        select_best_model(self):
            Selects the best model based on minimum loss from saved models.

        dataloader(self):
            Loads the dataloader from the processed data path.

        image_normalized(self, image):
            Normalizes an image tensor to be in the range [0, 1].

        create_gif(self):
            Creates a GIF from training images stored in a specified path.

        plot(self):
            Plots comparison images (noisy, clean, predicted) and saves a comparison figure and a training images GIF.
    """

    def __init__(self, device="mps"):
        """
        Initializes the Charts object with specified device for PyTorch operations and
        sets initial values for attributes.

        Parameters:
            device (str): The device type for PyTorch operations. Defaults to 'mps'.
        """
        self.device = device_init(device=device)
        self.infinity = float("inf")
        self.model = None
        self.best_model = None
        self.details = {"best_model": list(), "best_loss": list()}

    def select_best_model(self):
        """
        Selects the best model based on minimum loss from models saved in the BEST_MODELS_PATH.

        Returns:
            dict: The state dict of the saved model with the lowest loss
                (plot passes it to load_state_dict).

        Raises:
            Exception: If the BEST_MODELS_PATH does not exist.
        """
        if os.path.exists(BEST_MODELS_PATH):

            models = os.listdir(BEST_MODELS_PATH)
            for model in models:
                # Load each checkpoint once and reuse it for the comparison.
                checkpoint = torch.load(os.path.join(BEST_MODELS_PATH, model))
                if self.infinity > checkpoint["loss"]:
                    self.infinity = checkpoint["loss"]
                    self.best_model = checkpoint["model"]
            self.details["best_model"].append(self.best_model)
            self.details["best_loss"].append(self.infinity)

            return self.best_model
        else:
            raise Exception("Best model path is not found".title())

    def dataloader(self):
        """
        Loads and returns the dataloader object from a pickled file in the PROCESSED_DATA_PATH.

        Returns:
            DataLoader: The dataloader containing the dataset.

        Raises:
            Exception: If the PROCESSED_DATA_PATH does not exist or the dataloader file is not found.
        """
        if os.path.exists(PROCESSED_DATA_PATH):
            return load(os.path.join(PROCESSED_DATA_PATH, "dataloader.pkl"))
        else:
            raise Exception("Processed data path is not found".title())

    def image_normalized(self, image):
        """
        Normalizes an image tensor to have values in the range [0, 1].

        Parameters:
            image (torch.Tensor): The image tensor to be normalized.

        Returns:
            numpy.ndarray: The normalized image as a numpy array.
        """
        image = image.cpu().detach().numpy()
        return (image - image.min()) / (image.max() - image.min())

    def create_gif(self):
        """
        Creates and saves a GIF from images stored in TRAIN_IMAGES_PATH to GIF_IMAGE_PATH.

        Raises:
            Exception: If the TRAIN_IMAGES_PATH or GIF_IMAGE_PATH does not exist.
        """
        if os.path.exists(TRAIN_IMAGES_PATH) and os.path.exists(GIF_IMAGE_PATH):
            # Sort the file names so the frames appear in a deterministic order;
            # os.listdir returns entries in arbitrary order.
            images = [
                imageio.imread(os.path.join(TRAIN_IMAGES_PATH, image))
                for image in sorted(os.listdir(TRAIN_IMAGES_PATH))
            ]
            imageio.mimsave(
                os.path.join(GIF_IMAGE_PATH, "train_masks.gif"), images, "GIF"
            )
        else:
            raise Exception("Train images path or GIF image path not found.")

    def plot(self):
        """
        Plots comparison images for each set of noisy, clean, and predicted images in the test dataset.
        Saves a comparison figure to TEST_IMAGES_PATH and a GIF of training images to GIF_IMAGE_PATH.

        Raises:
            Exception: If the TEST_IMAGES_PATH does not exist.
        """
        try:
            self.model = DnCNN().to(self.device)
        except Exception as e:
            print("The exception in the model is:", e)
            # Without a model the steps below cannot run, so re-raise.
            raise
        else:
            self.model.load_state_dict(self.select_best_model())

        self.test_dataloader = self.dataloader()

        noise_images, clean_images = next(iter(self.test_dataloader))

        noise_images = noise_images.to(self.device)
        clean_images = clean_images.to(self.device)

        # Inference only: use evaluation mode and skip gradient tracking.
        self.model.eval()
        with torch.no_grad():
            predicted_images = self.model(noise_images)

        plt.figure(figsize=(36, 24))

        for index, image in enumerate(predicted_images):
            noisy = noise_images[index].permute(1, 2, 0)
            noisy = self.image_normalized(noisy)

            plt.subplot(3 * 5, 3 * 8, 3 * index + 1)
            plt.imshow(noisy, cmap="gray")
            plt.title("Noisy")
            plt.axis("off")

            clean = clean_images[index].permute(1, 2, 0)
            clean = self.image_normalized(clean)

            plt.subplot(3 * 5, 3 * 8, 3 * index + 2)
            plt.imshow(clean, cmap="gray")
            plt.title("Clean")
            plt.axis("off")

            predicted = image.permute(1, 2, 0)
            predicted = self.image_normalized(predicted)

            plt.subplot(3 * 5, 3 * 8, 3 * index + 3)
            plt.imshow(predicted, cmap="gray")
            plt.title("Predict")
            plt.axis("off")

        plt.tight_layout()

        if os.path.exists(TEST_IMAGES_PATH):
            plt.savefig(os.path.join(TEST_IMAGES_PATH, "test.png"))

            try:
                self.create_gif()
            except Exception as e:
                print("The exception in the gif is:", e)
        else:
            raise Exception("Test images path is not found".title())

        plt.show()

__init__(device='mps')

Initializes the Charts object with specified device for PyTorch operations and sets initial values for attributes.

Parameters:
  • device (str, default: 'mps' ) –

    The device type for PyTorch operations. Defaults to 'mps'.

Source code in test.py
def __init__(self, device="mps"):
    """
    Initializes the Charts object with specified device for PyTorch operations and
    sets initial values for attributes.

    Parameters:
        device (str): The device type for PyTorch operations. Defaults to 'mps'.
    """
    self.device = device_init(device=device)
    self.infinity = float("inf")
    self.model = None
    self.best_model = None
    self.details = {"best_model": list(), "best_loss": list()}

create_gif()

Creates and saves a GIF from images stored in TRAIN_IMAGES_PATH to GIF_IMAGE_PATH.

Raises:
  • Exception

    If the TRAIN_IMAGES_PATH or GIF_IMAGE_PATH does not exist.
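
The frames appear in the GIF in the order the files are read, and os.listdir makes no guarantee about that order. The following is a minimal sketch of sorting frame file names in natural (numeric-aware) order before reading them. The helper names natural_key and ordered_frames and the example file names are hypothetical and not part of test.py.

```python
import re


def natural_key(name: str) -> list:
    """Split a file name into text and integer chunks for numeric-aware sorting."""
    return [
        int(part) if part.isdigit() else part.lower()
        for part in re.split(r"(\d+)", name)
    ]


def ordered_frames(file_names: list[str]) -> list[str]:
    """Return the file names sorted so that 'image_2' precedes 'image_10'."""
    return sorted(file_names, key=natural_key)


# Hypothetical training-image names, in an arbitrary directory-listing order.
frames = ordered_frames(["image_10.png", "image_2.png", "image_1.png"])
print(frames)  # ['image_1.png', 'image_2.png', 'image_10.png']
```

A plain sorted() call would place image_10.png before image_2.png, which matters only if the training images are numbered without zero padding.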

Source code in test.py
def create_gif(self):
    """
    Creates and saves a GIF from images stored in TRAIN_IMAGES_PATH to GIF_IMAGE_PATH.

    Raises:
        Exception: If the TRAIN_IMAGES_PATH or GIF_IMAGE_PATH does not exist.
    """
    if os.path.exists(TRAIN_IMAGES_PATH) and os.path.exists(GIF_IMAGE_PATH):
        # Sort the file names so the frames appear in a deterministic order;
        # os.listdir returns entries in arbitrary order.
        images = [
            imageio.imread(os.path.join(TRAIN_IMAGES_PATH, image))
            for image in sorted(os.listdir(TRAIN_IMAGES_PATH))
        ]
        imageio.mimsave(
            os.path.join(GIF_IMAGE_PATH, "train_masks.gif"), images, "GIF"
        )
    else:
        raise Exception("Train images path or GIF image path not found.")

dataloader()

Loads and returns the dataloader object from a pickled file in the PROCESSED_DATA_PATH.

Returns:
  • DataLoader

    The dataloader containing the dataset.

Raises:
  • Exception

    If the PROCESSED_DATA_PATH does not exist or the dataloader file is not found.
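
The method checks only that the directory exists; a missing dataloader.pkl would surface as whatever error the loader raises. As a rough sketch, the two cases can be reported separately. The example below assumes the standard pickle module as a stand-in for the load helper that test.py imports (its origin is not shown here), and the function name load_pickled is hypothetical.

```python
import os
import pickle
import tempfile


def load_pickled(directory: str, file_name: str = "dataloader.pkl"):
    """Load a pickled object, reporting a missing directory or file explicitly."""
    if not os.path.isdir(directory):
        raise FileNotFoundError(f"Processed data path not found: {directory}")
    path = os.path.join(directory, file_name)
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Dataloader file not found: {path}")
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Round trip with a stand-in list in place of a real DataLoader.
with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "dataloader.pkl"), "wb") as handle:
        pickle.dump([("noisy", "clean")], handle)
    restored = load_pickled(tmp)
    print(restored)  # [('noisy', 'clean')]
```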

Source code in test.py
def dataloader(self):
    """
    Loads and returns the dataloader object from a pickled file in the PROCESSED_DATA_PATH.

    Returns:
        DataLoader: The dataloader containing the dataset.

    Raises:
        Exception: If the PROCESSED_DATA_PATH does not exist or the dataloader file is not found.
    """
    if os.path.exists(PROCESSED_DATA_PATH):
        return load(os.path.join(PROCESSED_DATA_PATH, "dataloader.pkl"))
    else:
        raise Exception("Processed data path is not found".title())

image_normalized(image)

Normalizes an image tensor to have values in the range [0, 1].

Parameters:
  • image (Tensor) –

    The image tensor to be normalized.

Returns:
  • numpy.ndarray

    The normalized image as a NumPy array.
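
Min-max scaling maps the smallest value to 0 and the largest to 1, so a constant image would produce a zero denominator. Below is a minimal pure-Python sketch of the same arithmetic on a flat list of pixel values, used here as a stand-in for the tensor version. The function name min_max_normalize and the choice to map a constant input to zeros are assumptions, not behavior of test.py.

```python
def min_max_normalize(values: list[float]) -> list[float]:
    """Scale values linearly so the minimum maps to 0.0 and the maximum to 1.0."""
    low, high = min(values), max(values)
    span = high - low
    if span == 0:
        # Assumption: a constant image maps to all zeros instead of dividing by zero.
        return [0.0 for _ in values]
    return [(value - low) / span for value in values]


pixels = [-1.0, 0.0, 1.0, 3.0]
print(min_max_normalize(pixels))  # [0.0, 0.25, 0.5, 1.0]
```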

Source code in test.py
def image_normalized(self, image):
    """
    Normalizes an image tensor to have values in the range [0, 1].

    Parameters:
        image (torch.Tensor): The image tensor to be normalized.

    Returns:
        numpy.ndarray: The normalized image as a numpy array.
    """
    image = image.cpu().detach().numpy()
    return (image - image.min()) / (image.max() - image.min())

plot()

Plots comparison images for each set of noisy, clean, and predicted images in the test dataset. Saves a comparison figure to TEST_IMAGES_PATH and a GIF of training images to GIF_IMAGE_PATH.

Raises:
  • Exception

    If the TEST_IMAGES_PATH does not exist.
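
plot places the three panels of each sample in consecutive cells of a subplot grid with 3 * 5 rows and 3 * 8 columns, at positions 3 * index + 1 through 3 * index + 3. The sketch below works through that arithmetic and the largest batch that fits in the grid. The helper names are hypothetical and exist only for illustration.

```python
ROWS, COLS = 3 * 5, 3 * 8  # grid dimensions used by plot(): 15 rows, 24 columns


def panel_positions(index: int) -> tuple[int, int, int]:
    """Return the 1-based subplot positions of the noisy, clean, and predicted panels."""
    base = 3 * index
    return base + 1, base + 2, base + 3


def max_batch_size(rows: int = ROWS, cols: int = COLS) -> int:
    """Largest number of samples whose three panels all fit in the grid."""
    return (rows * cols) // 3


def row_and_column(position: int, cols: int = COLS) -> tuple[int, int]:
    """Convert a 1-based subplot position to a 0-based (row, column) pair."""
    return divmod(position - 1, cols)


print(panel_positions(0))   # (1, 2, 3)
print(panel_positions(8))   # (25, 26, 27)
print(row_and_column(25))   # (1, 0)
print(max_batch_size())     # 120
```

With 24 columns, each row holds eight samples, and a batch of more than 120 samples would request a position outside the 360-cell grid.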

Source code in test.py
def plot(self):
    """
    Plots comparison images for each set of noisy, clean, and predicted images in the test dataset.
    Saves a comparison figure to TEST_IMAGES_PATH and a GIF of training images to GIF_IMAGE_PATH.

    Raises:
        Exception: If the TEST_IMAGES_PATH does not exist.
    """
    try:
        self.model = DnCNN().to(self.device)
    except Exception as e:
        print("The exception in the model is:", e)
        # Without a model the steps below cannot run, so re-raise.
        raise
    else:
        self.model.load_state_dict(self.select_best_model())

    self.test_dataloader = self.dataloader()

    noise_images, clean_images = next(iter(self.test_dataloader))

    noise_images = noise_images.to(self.device)
    clean_images = clean_images.to(self.device)

    # Inference only: use evaluation mode and skip gradient tracking.
    self.model.eval()
    with torch.no_grad():
        predicted_images = self.model(noise_images)

    plt.figure(figsize=(36, 24))

    for index, image in enumerate(predicted_images):
        noisy = noise_images[index].permute(1, 2, 0)
        noisy = self.image_normalized(noisy)

        plt.subplot(3 * 5, 3 * 8, 3 * index + 1)
        plt.imshow(noisy, cmap="gray")
        plt.title("Noisy")
        plt.axis("off")

        clean = clean_images[index].permute(1, 2, 0)
        clean = self.image_normalized(clean)

        plt.subplot(3 * 5, 3 * 8, 3 * index + 2)
        plt.imshow(clean, cmap="gray")
        plt.title("Clean")
        plt.axis("off")

        predicted = image.permute(1, 2, 0)
        predicted = self.image_normalized(predicted)

        plt.subplot(3 * 5, 3 * 8, 3 * index + 3)
        plt.imshow(predicted, cmap="gray")
        plt.title("Predict")
        plt.axis("off")

    plt.tight_layout()

    if os.path.exists(TEST_IMAGES_PATH):
        plt.savefig(os.path.join(TEST_IMAGES_PATH, "test.png"))

        try:
            self.create_gif()
        except Exception as e:
            print("The exception in the gif is:", e)
    else:
        raise Exception("Test images path is not found".title())

    plt.show()

select_best_model()

Selects the best model based on minimum loss from models saved in the BEST_MODELS_PATH.

Returns:
  • dict

    The state dict of the saved model with the lowest loss, which plot() passes to load_state_dict.

Raises:
  • Exception

    If the BEST_MODELS_PATH does not exist.
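
The selection amounts to a single pass over the saved files that keeps the entry with the smallest loss. A minimal sketch of that pass, loading each file once, is shown below. It uses pickle as a stand-in for torch.load so that it runs without PyTorch; the "loss" and "model" keys follow test.py, while the function name, file names, and loss values are made up for illustration.

```python
import os
import pickle
import tempfile


def select_best_checkpoint(directory: str) -> tuple[float, object]:
    """Return (loss, model) from the checkpoint file with the lowest loss."""
    if not os.path.isdir(directory):
        raise FileNotFoundError(f"Best model path not found: {directory}")
    best_loss, best_model = float("inf"), None
    for file_name in sorted(os.listdir(directory)):
        with open(os.path.join(directory, file_name), "rb") as handle:
            checkpoint = pickle.load(handle)  # stand-in for torch.load
        if checkpoint["loss"] < best_loss:
            best_loss, best_model = checkpoint["loss"], checkpoint["model"]
    return best_loss, best_model


# Three hypothetical checkpoints with made-up losses and placeholder weights.
with tempfile.TemporaryDirectory() as tmp:
    for epoch, loss in enumerate([0.5, 0.125, 0.25], start=1):
        with open(os.path.join(tmp, f"model_{epoch}.pkl"), "wb") as handle:
            pickle.dump({"loss": loss, "model": {"epoch": epoch}}, handle)
    best_loss, best_model = select_best_checkpoint(tmp)
    print(best_loss, best_model)  # 0.125 {'epoch': 2}
```

For an empty directory this sketch returns (inf, None), so a caller would need to check for None before loading weights.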

Source code in test.py
def select_best_model(self):
    """
    Selects the best model based on minimum loss from models saved in the BEST_MODELS_PATH.

    Returns:
        dict: The state dict of the saved model with the lowest loss
            (plot passes it to load_state_dict).

    Raises:
        Exception: If the BEST_MODELS_PATH does not exist.
    """
    if os.path.exists(BEST_MODELS_PATH):

        models = os.listdir(BEST_MODELS_PATH)
        for model in models:
            # Load each checkpoint once and reuse it for the comparison.
            checkpoint = torch.load(os.path.join(BEST_MODELS_PATH, model))
            if self.infinity > checkpoint["loss"]:
                self.infinity = checkpoint["loss"]
                self.best_model = checkpoint["model"]
        self.details["best_model"].append(self.best_model)
        self.details["best_loss"].append(self.infinity)

        return self.best_model
    else:
        raise Exception("Best model path is not found".title())