Test

This class is designed for testing and visualizing synthetic images generated by the Generator model.

Attributes:
  • image_size (int) –

    The size of the generated images (height and width).

  • num_samples (int) –

    The number of synthetic images to generate.

  • latent_space (int) –

    The dimensionality of the latent space for the generator.

  • device (torch.device) –

    The device (CPU, GPU, etc.) on which the generator model will be loaded and run.

  • generator (Generator) –

    The generator model for synthetic image generation.

Methods:

  • get_the_best_model –

    Loads the best performing generator model from a predefined directory.

  • test –

    Generates and visualizes synthetic images using the generator model.

Source code in test.py (lines 21–106)
class Test:
    """
    This class is designed for testing and visualizing synthetic images generated by the Generator model.

    Attributes:
        image_size (int): The size of the generated images (height and width).
        num_samples (int): The number of synthetic images to generate.
        latent_space (int): The dimensionality of the latent space for the generator.
        device (torch.device): The device (CPU, GPU, etc.) on which the generator model will be loaded and run.
        generator (Generator): The generator model for synthetic image generation.

    Methods:
        get_the_best_model: Loads the best performing generator model from a predefined directory.
        test: Generates and visualizes synthetic images using the generator model.
    """

    def __init__(self, num_samples=20, image_size=64, latent_space=100, device="cpu"):
        """
        Initializes the Test class with the specified configuration.

        Args:
            num_samples (int): Number of synthetic images to generate for testing.
            image_size (int): The height and width of the images to generate (assumed square).
            latent_space (int): Dimensionality of the latent space for the generator.
            device (str): The device to run the generator model on ('cpu', 'cuda', 'mps', etc.).
        """
        self.image_size = image_size
        self.num_samples = num_samples
        self.latent_space = latent_space
        self.device = device_init(device=device)
        self.generator = Generator(
            image_size=self.image_size, latent_space=self.latent_space
        ).to(self.device)

    def get_the_best_model(self, model_path="./models/best_model/"):
        """
        Load the generator checkpoint stored in ``model_path`` into ``self.generator``.

        Hidden files and sub-directories are ignored. If several checkpoint files
        are present, the alphabetically first one is used, so the choice is
        deterministic rather than depending on ``os.listdir`` ordering.

        Args:
            model_path (str): Directory containing the saved generator
                ``state_dict``. Defaults to ``"./models/best_model/"``.

        Raises:
            Exception: If the directory does not exist or contains no checkpoint file.
        """
        if not os.path.isdir(model_path):
            raise Exception("No Best Model Found".title())

        checkpoints = sorted(
            name
            for name in os.listdir(model_path)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(model_path, name))
        )
        if not checkpoints:
            # An existing but empty directory previously crashed with IndexError.
            raise Exception("No Best Model Found".title())

        logging.info("Best model found at {}".format(model_path).title())

        # map_location lets a checkpoint saved on a GPU be loaded on the device
        # this instance runs on (e.g. a CPU-only machine).
        state_dict = torch.load(
            os.path.join(model_path, checkpoints[0]), map_location=self.device
        )
        self.generator.load_state_dict(state_dict)

    def test(self):
        """
        Generates synthetic images using the best generator model and visualizes them.

        Loads the best checkpoint and samples ``num_samples`` latent vectors. The
        generated images are plotted on a 5-column grid with as many rows as
        needed. The figure is saved to ``./outputs/fake_image.png`` and then
        shown. Errors while loading the model are logged and printed rather than
        raised. Errors while saving the figure are logged, and the figure is
        still shown.
        """
        try:
            self.get_the_best_model()
            random_noise = torch.randn(self.num_samples, self.latent_space, 1, 1).to(
                self.device
            )
        except Exception as e:
            logging.exception("Exception caught: {}".format(e))
            print("Exception caught: {}".format(e))
        else:
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                synthetic_images = self.generator(random_noise)

            # Size the grid from the sample count; a fixed 4x5 grid raises
            # once more than 20 images are plotted.
            num_cols = 5
            num_rows = max(1, (len(synthetic_images) + num_cols - 1) // num_cols)

            plt.figure(figsize=(10, 10))
            for index, image in enumerate(synthetic_images):
                plt.subplot(num_rows, num_cols, index + 1)
                image_to_plot = image.cpu().detach().permute(1, 2, 0).numpy()
                # Min-max scale into [0, 1] for imshow; guard against a
                # constant image, which would otherwise divide by zero.
                low, high = image_to_plot.min(), image_to_plot.max()
                image_to_plot = image_to_plot - low
                if high > low:
                    image_to_plot = image_to_plot / (high - low)
                plt.imshow(image_to_plot)
                plt.axis("off")

            try:
                # savefig does not create missing directories.
                os.makedirs("./outputs", exist_ok=True)
                plt.savefig("./outputs/fake_image.png")
            except Exception as e:
                logging.exception("Exception caught: {}".format(e))

            plt.show()

__init__(num_samples=20, image_size=64, latent_space=100, device='cpu')

Initializes the Test class with the specified configuration.

Parameters:
  • num_samples (int, default: 20 ) –

    Number of synthetic images to generate for testing.

  • image_size (int, default: 64 ) –

    The height and width of the images to generate (assumed square).

  • latent_space (int, default: 100 ) –

    Dimensionality of the latent space for the generator.

  • device (str, default: 'cpu' ) –

    The device to run the generator model on ('cpu', 'cuda', 'mps', etc.).

Source code in test.py (lines 37–53)
def __init__(self, num_samples=20, image_size=64, latent_space=100, device="cpu"):
    """
    Store the tester configuration and build a generator on the resolved device.

    Args:
        num_samples (int): How many synthetic images are generated for testing.
        image_size (int): Side length of the square images to generate.
        latent_space (int): Size of the generator's latent noise vector.
        device (str): Requested device name ('cpu', 'cuda', 'mps', etc.),
            resolved through ``device_init``.
    """
    self.image_size, self.num_samples, self.latent_space = (
        image_size,
        num_samples,
        latent_space,
    )
    self.device = device_init(device=device)

    model = Generator(image_size=image_size, latent_space=latent_space)
    self.generator = model.to(self.device)

get_the_best_model()

Searches for and loads the best performing generator model from a specified directory.

Raises:
  • Exception

    If no model is found in the specified directory.

Source code in test.py (lines 55–72)
def get_the_best_model(self, model_path="./models/best_model/"):
    """
    Load the generator checkpoint stored in ``model_path`` into ``self.generator``.

    Hidden files and sub-directories are ignored. If several checkpoint files
    are present, the alphabetically first one is used, so the choice is
    deterministic rather than depending on ``os.listdir`` ordering.

    Args:
        model_path (str): Directory containing the saved generator
            ``state_dict``. Defaults to ``"./models/best_model/"``.

    Raises:
        Exception: If the directory does not exist or contains no checkpoint file.
    """
    if not os.path.isdir(model_path):
        raise Exception("No Best Model Found".title())

    checkpoints = sorted(
        name
        for name in os.listdir(model_path)
        if not name.startswith(".")
        and os.path.isfile(os.path.join(model_path, name))
    )
    if not checkpoints:
        # An existing but empty directory previously crashed with IndexError.
        raise Exception("No Best Model Found".title())

    logging.info("Best model found at {}".format(model_path).title())

    # map_location lets a checkpoint saved on a GPU be loaded on the device
    # this instance runs on (e.g. a CPU-only machine).
    state_dict = torch.load(
        os.path.join(model_path, checkpoints[0]), map_location=self.device
    )
    self.generator.load_state_dict(state_dict)

test()

Generates synthetic images using the best generator model and visualizes them.

This method also attempts to save the generated images to a file and logs any exceptions encountered.

Source code in test.py (lines 74–106)
def test(self):
    """
    Generates synthetic images using the best generator model and visualizes them.

    Loads the best checkpoint and samples ``num_samples`` latent vectors. The
    generated images are plotted on a 5-column grid with as many rows as
    needed. The figure is saved to ``./outputs/fake_image.png`` and then shown.
    Errors while loading the model are logged and printed rather than raised.
    Errors while saving the figure are logged, and the figure is still shown.
    """
    try:
        self.get_the_best_model()
        random_noise = torch.randn(self.num_samples, self.latent_space, 1, 1).to(
            self.device
        )
    except Exception as e:
        logging.exception("Exception caught: {}".format(e))
        print("Exception caught: {}".format(e))
    else:
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            synthetic_images = self.generator(random_noise)

        # Size the grid from the sample count; a fixed 4x5 grid raises once
        # more than 20 images are plotted.
        num_cols = 5
        num_rows = max(1, (len(synthetic_images) + num_cols - 1) // num_cols)

        plt.figure(figsize=(10, 10))
        for index, image in enumerate(synthetic_images):
            plt.subplot(num_rows, num_cols, index + 1)
            image_to_plot = image.cpu().detach().permute(1, 2, 0).numpy()
            # Min-max scale into [0, 1] for imshow; guard against a constant
            # image, which would otherwise divide by zero.
            low, high = image_to_plot.min(), image_to_plot.max()
            image_to_plot = image_to_plot - low
            if high > low:
                image_to_plot = image_to_plot / (high - low)
            plt.imshow(image_to_plot)
            plt.axis("off")

        try:
            # savefig does not create missing directories.
            os.makedirs("./outputs", exist_ok=True)
            plt.savefig("./outputs/fake_image.png")
        except Exception as e:
            logging.exception("Exception caught: {}".format(e))

        plt.show()