# Cyclic GANs

In this notebook, we're going to define and train a CycleGAN to read in an image from a set 𝑋 and transform it so that it looks as if it belongs in set 𝑌. Specifically, we'll look at a set of images of Yosemite National Park taken either during summer or winter. The seasons are our two domains!

## Algorithm

1. Get data, pre-process it, and create data-loaders (sketched below).
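A minimal data-loading sketch for step 1. The directory layout (`./summer_winter_yosemite/...`), batch size, and the scaling to [-1, 1] are assumptions, not taken from the notebook:
```python
import torch
from torchvision import datasets, transforms

# pre-processing: resize to the 128x128 input size and (assumed) scale
# pixel values to [-1, 1] to match the generators' tanh output
transform = transforms.Compose([
    transforms.Resize(128),
    transforms.CenterCrop(128),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# one dataset and one loader per domain (paths are assumed)
dataset_X = datasets.ImageFolder('./summer_winter_yosemite/summer', transform)
dataset_Y = datasets.ImageFolder('./summer_winter_yosemite/winter', transform)

dataloader_X = torch.utils.data.DataLoader(dataset_X, batch_size=16, shuffle=True)
dataloader_Y = torch.utils.data.DataLoader(dataset_Y, batch_size=16, shuffle=True)
```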
2. Define the CycleGAN model: Discriminator & Generator (both sketched below)
    * **Discriminator**
        * This network sees a 128x128x3 image and passes it through 5 convolutional layers, each downsampling the image by a factor of 2.
        * The first four convolutional layers have a BatchNorm and a ReLU activation function applied to their output.
        * The last acts as a classification layer that outputs one value.
    * **Generator**
        * There are 2 generators: G_XtoY and G_YtoX
        * Both generators are made of an encoder and a decoder, joined by residual blocks:
            * **Encoder:** a conv net responsible for turning an image into a smaller feature representation.
            * **Decoder:** a transpose-conv net responsible for turning that representation into a transformed image.
            * **Residual blocks:** connect the encoder and decoder parts.
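A minimal PyTorch sketch of the two networks in step 2. The layer widths (`conv_dim=64`), 4x4 kernels, and the number of residual blocks are assumptions; only the overall structure follows the description above. The final discriminator layer here emits a small patch of scores that the MSE losses average over, rather than literally a single value:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def conv(in_channels, out_channels, kernel_size=4, stride=2, padding=1, batch_norm=True):
    """A conv layer that halves the spatial size, optionally with batch norm."""
    layers = [nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,
                        bias=not batch_norm)]
    if batch_norm:
        layers.append(nn.BatchNorm2d(out_channels))
    return nn.Sequential(*layers)

def deconv(in_channels, out_channels, kernel_size=4, stride=2, padding=1, batch_norm=True):
    """A transpose-conv layer that doubles the spatial size."""
    layers = [nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding,
                                 bias=not batch_norm)]
    if batch_norm:
        layers.append(nn.BatchNorm2d(out_channels))
    return nn.Sequential(*layers)

class Discriminator(nn.Module):
    def __init__(self, conv_dim=64):
        super().__init__()
        # 128x128x3 input; every layer downsamples by a factor of 2
        self.conv1 = conv(3, conv_dim)                   # -> 64x64
        self.conv2 = conv(conv_dim, conv_dim * 2)        # -> 32x32
        self.conv3 = conv(conv_dim * 2, conv_dim * 4)    # -> 16x16
        self.conv4 = conv(conv_dim * 4, conv_dim * 8)    # -> 8x8
        # classification layer: single output channel, no norm or activation
        self.conv5 = conv(conv_dim * 8, 1, batch_norm=False)  # -> 4x4 scores

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        return self.conv5(x)

class ResidualBlock(nn.Module):
    """Two 3x3 convs whose output is added back to the block's input."""
    def __init__(self, conv_dim):
        super().__init__()
        self.conv1 = conv(conv_dim, conv_dim, kernel_size=3, stride=1, padding=1)
        self.conv2 = conv(conv_dim, conv_dim, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        return x + self.conv2(out)

class Generator(nn.Module):
    def __init__(self, conv_dim=64, n_res_blocks=6):
        super().__init__()
        # encoder: image -> smaller feature representation
        self.conv1 = conv(3, conv_dim)                   # -> 64x64
        self.conv2 = conv(conv_dim, conv_dim * 2)        # -> 32x32
        # residual blocks connect the encoder and decoder
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(conv_dim * 2) for _ in range(n_res_blocks)])
        # decoder: feature representation -> transformed image
        self.deconv1 = deconv(conv_dim * 2, conv_dim)    # -> 64x64
        self.deconv2 = deconv(conv_dim, 3, batch_norm=False)  # -> 128x128

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.res_blocks(x)
        x = F.relu(self.deconv1(x))
        return torch.tanh(self.deconv2(x))  # pixel values in [-1, 1]
```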
3. Compute the Generator, Discriminator, and Cycle-Consistency losses (the cycle-consistency loss determines how good a reconstructed image is when compared to the original image).
```python
import torch

def real_mse_loss(D_out):
    # how close is the discriminator's output to the "real" target of 1?
    return torch.mean((D_out - 1)**2)

def fake_mse_loss(D_out):
    # how close is the discriminator's output to the "fake" target of 0?
    return torch.mean(D_out**2)

def cycle_consistency_loss(real_im, reconstructed_im, lambda_weight):
    # L1 distance between a real image and its reconstruction, weighted by lambda
    reconstr_loss = torch.mean(torch.abs(real_im - reconstructed_im))
    return lambda_weight * reconstr_loss
```
4. Define the optimizers (sketched below).
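A sketch of step 4, assuming Adam with hyperparameters commonly used for GAN training (lr=0.0002, betas=(0.5, 0.999)); the notebook's values may differ. The two generators share one optimizer since their losses are summed in step 5:
```python
import itertools
import torch.optim as optim

lr, beta1, beta2 = 0.0002, 0.5, 0.999  # assumed values

# instantiate the networks (names G_XtoY, G_YtoX, D_X, D_Y as used above)
G_XtoY, G_YtoX = Generator(), Generator()
D_X, D_Y = Discriminator(), Discriminator()

# one optimizer for both generators, one per discriminator
g_optimizer = optim.Adam(itertools.chain(G_XtoY.parameters(), G_YtoX.parameters()),
                         lr, betas=(beta1, beta2))
d_x_optimizer = optim.Adam(D_X.parameters(), lr, betas=(beta1, beta2))
d_y_optimizer = optim.Adam(D_Y.parameters(), lr, betas=(beta1, beta2))
```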
5. Train the CycleGAN (one full training step is sketched after these lists): <br>
    * **Training the Discriminators**
        * Compute the discriminator 𝐷𝑋 loss on real images
        * Generate fake images that look like domain 𝑋 based on real images in domain 𝑌
        * Compute the fake loss for 𝐷𝑋
        * Compute the total loss and perform backpropagation and 𝐷𝑋 optimization
        * Repeat steps 1-4, only with 𝐷𝑌 and your domains switched!

    * **Training the Generators**
        * Generate fake images that look like domain 𝑋 based on real images in domain 𝑌
        * Compute the generator loss based on how 𝐷𝑋 responds to fake 𝑋
        * Generate reconstructed 𝑌̂ images based on the fake 𝑋 images generated in step 1
        * Compute the cycle consistency loss by comparing the reconstructions with real 𝑌 images
        * Repeat steps 1-4, only swapping domains
        * Add up all the generator and reconstruction losses and perform backpropagation + optimization
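Putting the two lists above together, one training step might look like the following sketch. It reuses the loss functions from step 3 and the optimizers from step 4; `lambda_weight=10` is an assumed value:
```python
def training_step(images_X, images_Y):
    """One full CycleGAN update, mirroring the two lists above.
    images_X / images_Y are batches of real images from each domain."""

    # -- Train the discriminators --
    d_x_optimizer.zero_grad()
    d_x_real_loss = real_mse_loss(D_X(images_X))    # 1. D_X loss on real X
    fake_X = G_YtoX(images_Y)                       # 2. fake X from real Y
    d_x_fake_loss = fake_mse_loss(D_X(fake_X))      # 3. fake loss for D_X
    d_x_loss = d_x_real_loss + d_x_fake_loss        # 4. total loss,
    d_x_loss.backward()                             #    backprop, optimize
    d_x_optimizer.step()

    d_y_optimizer.zero_grad()                       # 5. repeat for D_Y,
    d_y_real_loss = real_mse_loss(D_Y(images_Y))    #    domains switched
    fake_Y = G_XtoY(images_X)
    d_y_fake_loss = fake_mse_loss(D_Y(fake_Y))
    d_y_loss = d_y_real_loss + d_y_fake_loss
    d_y_loss.backward()
    d_y_optimizer.step()

    # -- Train the generators --
    g_optimizer.zero_grad()
    fake_X = G_YtoX(images_Y)                       # 1. fake X from real Y
    g_YtoX_loss = real_mse_loss(D_X(fake_X))        # 2. how real does D_X find it?
    reconstructed_Y = G_XtoY(fake_X)                # 3. reconstruct Y-hat from fake X
    y_cycle_loss = cycle_consistency_loss(images_Y, reconstructed_Y,
                                          lambda_weight=10)  # 4. (lambda assumed)

    fake_Y = G_XtoY(images_X)                       # 5. repeat, domains swapped
    g_XtoY_loss = real_mse_loss(D_Y(fake_Y))
    reconstructed_X = G_YtoX(fake_Y)
    x_cycle_loss = cycle_consistency_loss(images_X, reconstructed_X,
                                          lambda_weight=10)

    # 6. add up all generator and reconstruction losses, backprop, optimize
    g_total_loss = g_YtoX_loss + g_XtoY_loss + y_cycle_loss + x_cycle_loss
    g_total_loss.backward()
    g_optimizer.step()

    return d_x_loss.item(), d_y_loss.item(), g_total_loss.item()
```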

6. Compute the training losses and save sample images from the generators.

## Results

1. Training losses of the generators and discriminators are as follows: <br>
    * Discriminator loss: 0.3021 (X), 0.3054 (Y)
    * Generator loss: 2.4846

2. Training loss plotted vs. number of epochs trained: <br>
<img src="./images/training_loss.png" />

3. Final results from the trained CycleGAN: <br>

### Image-to-image translation (Converting summer images to winter images)
<img src="./images/summer_to_winter.png" /><br>


### Image-to-image translation (Converting winter images to summer images)
<img src="./images/winter_to_summer.png" /><br>
