cli()

Main CLI for GAN Operations

This script is the main command-line interface (CLI) for working with Generative Adversarial Networks (GANs). It covers data loading, model training and synthetic image generation.

Features:

  • Argument parsing for flexible configuration of data loading, training and synthetic image generation.
  • Downloads and loads the MNIST dataset.
  • Trains the GAN (a Wasserstein GAN with a critic).
  • Generates synthetic images with a trained generator.

Usage:

To download MNIST and train the model, run the script with the desired arguments. For example:

  • python main_cli.py --download_mnist --batch_size 32 --epochs 100 --latent_space 100 --lr 0.0002 --samples 20

To generate synthetic images with a trained generator, run, for example:

  • python main_cli.py --samples 20 --latent_space 100 --test
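The two example commands select different code paths. `--download_mnist` triggers training and `--test` triggers image generation, and each path checks its arguments first. The sketch below is a hypothetical, stand-alone illustration of that dispatch; `select_modes` is not part of the project. The thresholds batch_size > 10, latent_space > 50, critic_steps > 1 and an even sample count are taken from `cli()`. Epochs and lr are assumed to have to be positive, and device and display are not checked here.

```python
from types import SimpleNamespace


def select_modes(args) -> list[str]:
    """Hypothetical: return the modes cli() would run, or raise ValueError."""
    modes = []
    if args.download_mnist:
        # Training checks (thresholds taken from cli()).
        if not (
            args.batch_size > 10
            and args.epochs > 0
            and args.latent_space > 50
            and args.lr > 0
            and args.critic_steps > 1
        ):
            raise ValueError("invalid training arguments")
        modes.append("train")
    if args.test:
        # Generation checks: even number of samples and a large enough latent space.
        if not (args.samples % 2 == 0 and args.latent_space > 50):
            raise ValueError("invalid test arguments")
        modes.append("generate")
    return modes


if __name__ == "__main__":
    train = SimpleNamespace(download_mnist=True, test=False, batch_size=32, epochs=100,
                            latent_space=100, lr=0.0002, critic_steps=5, samples=20)
    print(select_modes(train))  # → ['train']
    gen = SimpleNamespace(download_mnist=False, test=True, samples=20, latent_space=100)
    print(select_modes(gen))  # → ['generate']
```

Both flags can be combined in one invocation. Training then runs before generation, which matches the order of the two `if` blocks in `cli()`.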

Arguments:

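  • --batch_size: Batch size for the DataLoader (default: 32).
  • --download_mnist: Flag to download the MNIST dataset and train the model.
  • --epochs: Number of training epochs (default: 100).
  • --latent_space: Dimension of the generator's latent space (default: 100).
  • --lr: Learning rate for the optimizer (default: 0.0002).
  • --samples: Number of synthetic samples to generate (default: 20).
  • --test: Flag to generate and plot synthetic images with a trained generator.
  • --device: Device to train on: CPU, GPU or MPS (default: CPU).
  • --critic_steps: Number of critic updates per generator update, which gives the critic priority over the generator (default: 5).
  • --display: Whether to display the critic and generator losses at each iteration (default: enabled).

Training (--download_mnist) requires batch_size > 10, latent_space > 50 and critic_steps > 1. Generation (--test) requires an even number of samples and latent_space > 50. Otherwise the script raises an exception.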
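The flags above can be reproduced with a stand-alone `argparse` parser. This is a sketch for illustration, not the project's code. The helper name `build_parser` is hypothetical, `--device` is taken as a plain string, and `--display` is modelled with `argparse.BooleanOptionalAction` (Python 3.9+) so that `--no-display` can switch it off.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical stand-alone parser mirroring the documented flags and defaults."""
    parser = argparse.ArgumentParser(description="Train a WGAN on MNIST or generate samples.")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size for the DataLoader")
    parser.add_argument("--download_mnist", action="store_true", help="Download MNIST and train")
    parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
    parser.add_argument("--latent_space", type=int, default=100, help="Latent vector size")
    parser.add_argument("--lr", type=float, default=0.0002, help="Learning rate")
    parser.add_argument("--samples", type=int, default=20, help="Number of samples to generate")
    parser.add_argument("--test", action="store_true", help="Generate synthetic images")
    parser.add_argument("--device", type=str, default="cpu", help="Device: CPU, GPU or MPS")
    parser.add_argument("--critic_steps", type=int, default=5,
                        help="Critic updates per generator update")
    # BooleanOptionalAction registers both --display and --no-display.
    parser.add_argument("--display", action=argparse.BooleanOptionalAction, default=True,
                        help="Print critic and generator losses")
    return parser


if __name__ == "__main__":
    # Parse the training example from the Usage section instead of sys.argv.
    argv = ("--download_mnist --batch_size 32 --epochs 100 "
            "--latent_space 100 --lr 0.0002 --samples 20").split()
    args = build_parser().parse_args(argv)
    print(args.download_mnist, args.batch_size, args.lr, args.critic_steps, args.display)
    # → True 32 0.0002 5 True
```

Passing an explicit list to `parse_args` makes the parser testable without touching `sys.argv`. When no list is given, argparse falls back to the real command line, as `cli()` does.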
Source code in cli.py
```python
import argparse

import torch

# Loader, Trainer and Test are classes from this project; their imports are
# not part of this listing.


def cli():
    """
    # Main CLI for GAN Operations

    This script is the main command-line interface (CLI) for working with Generative Adversarial Networks (GANs). It covers data loading, model training and synthetic image generation.

    ## Features:
    - Argument parsing for flexible configuration of data loading, training and synthetic image generation.
    - Downloads and loads the MNIST dataset.
    - Trains the GAN (a Wasserstein GAN with a critic).
    - Generates synthetic images with a trained generator.

    ## Usage:
    To download MNIST and train the model, run the script with the desired arguments. For example:
        - `python main_cli.py --download_mnist --batch_size 32 --epochs 100 --latent_space 100 --lr 0.0002 --samples 20`

    To generate synthetic images with a trained generator, run, for example:
        - `python main_cli.py --samples 20 --latent_space 100 --test`

    ## Arguments:
    - `--batch_size`: Batch size for the DataLoader (default: 32).
    - `--download_mnist`: Flag to download the MNIST dataset and train the model.
    - `--epochs`: Number of training epochs (default: 100).
    - `--latent_space`: Dimension of the generator's latent space (default: 100).
    - `--lr`: Learning rate for the optimizer (default: 0.0002).
    - `--samples`: Number of synthetic samples to generate (default: 20).
    - `--test`: Flag to generate and plot synthetic images with a trained generator.
    - `--device`: Device to train on: CPU, GPU or MPS (default: CPU).
    - `--critic_steps`: Number of critic updates per generator update, which gives the critic priority over the generator (default: 5).
    - `--display`: Display the critic and generator losses at each iteration (default: enabled; use `--no-display` to turn it off).

    Training requires batch_size > 10, epochs > 0, latent_space > 50, lr > 0 and critic_steps > 1.
    Generation requires an even number of samples and latent_space > 50.
    Otherwise a ValueError is raised.
    """
    parser = argparse.ArgumentParser(
        description="Train a WGAN on MNIST or generate synthetic images"
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size for the DataLoader",
    )
    parser.add_argument(
        "--download_mnist",
        action="store_true",
        help="Download the MNIST dataset and train the model",
    )
    parser.add_argument("--epochs", type=int, default=100, help="Number of epochs")
    parser.add_argument("--latent_space", type=int, default=100, help="Latent size")
    parser.add_argument("--lr", type=float, default=0.0002, help="Learning rate")
    parser.add_argument(
        "--samples",
        type=int,
        default=20,
        help="Number of samples to generate",
    )
    parser.add_argument(
        "--test", action="store_true", help="Generate and plot synthetic images"
    )
    parser.add_argument(
        "--device",
        default=torch.device("cpu"),
        help="Device to train on (CPU, GPU or MPS)",
    )
    parser.add_argument(
        "--critic_steps", type=int, default=5, help="Critic updates per generator update"
    )
    parser.add_argument(
        "--display",
        # BooleanOptionalAction (Python 3.9+) adds --no-display; a bare default=True
        # would turn any supplied value, even "False", into a truthy string.
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Display the critic and generator losses at each iteration",
    )

    args = parser.parse_args()

    # Training mode: download MNIST, build the DataLoader and train the WGAN.
    if args.download_mnist:
        if (
            args.batch_size > 10
            and args.epochs > 0
            and args.latent_space > 50
            and args.lr > 0
            and args.device
            and args.critic_steps > 1
        ):
            loader = Loader(batch_size=args.batch_size)
            loader.create_loader(mnist_data=loader.download_mnist())

            trainer = Trainer(
                latent_space=args.latent_space,
                epochs=args.epochs,
                lr=args.lr,
                device=args.device,
                n_critic_step=args.critic_steps,
            )
            trainer.train_WGAN()
        else:
            raise ValueError(
                "Invalid training arguments: batch_size must be > 10, epochs > 0, "
                "latent_space > 50, lr > 0 and critic_steps > 1"
            )

    # Generation mode: sample and plot synthetic images with the trained generator.
    if args.test:
        if args.samples % 2 == 0 and args.latent_space > 50:
            test = Test(num_samples=args.samples, latent_space=args.latent_space)
            test.plot_synthetic_image()
        else:
            raise ValueError(
                "Invalid test arguments: samples must be even and latent_space must be > 50"
            )
```