# Cyclic GANs
In this notebook, we're going to define and train a CycleGAN to read in an image from a set 𝑋 and transform it so that it looks as if it belongs in set 𝑌. Specifically, we'll look at a set of images of Yosemite National Park taken either during the summer or the winter. The two seasons are our two domains!

## Algorithm
1. Get the data, pre-process it, and create data loaders.
2. Define the CycleGAN model: Discriminator and Generator
    * **Discriminator**
        * This network sees a 128x128x3 image and passes it through 5 convolutional layers that downsample the image by a factor of 2.
        * The first four convolutional layers have BatchNorm and a ReLU activation applied to their output.
        * The last acts as a classification layer that outputs a single value.
    * **Generator**
        * There are 2 generators: G_XtoY and G_YtoX.
        * Each generator is made of an Encoder and a Decoder.
            * **Encoder:** A conv net that is responsible for turning an image into a smaller feature representation.
            * **Decoder:** A transpose-conv net that is responsible for turning that representation into a transformed image.
            * **Residual Blocks:** Connect the encoder and decoder parts.
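The pieces above can be sketched in PyTorch roughly as follows. This is a minimal sketch, not the notebook's exact architecture: the layer widths (`conv_dim`), number of residual blocks, and kernel sizes here are illustrative assumptions.

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Skip-connected conv block linking the encoder and decoder parts."""
    def __init__(self, dim):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim, dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(dim),
        )

    def forward(self, x):
        return x + self.block(x)  # skip connection

class Generator(nn.Module):
    """Encoder -> residual blocks -> decoder (sizes are assumptions)."""
    def __init__(self, conv_dim=64, n_res_blocks=2):
        super().__init__()
        # Encoder: downsample 128x128x3 to a smaller feature representation
        self.encoder = nn.Sequential(
            nn.Conv2d(3, conv_dim, 4, stride=2, padding=1),
            nn.BatchNorm2d(conv_dim), nn.ReLU(inplace=True),
            nn.Conv2d(conv_dim, conv_dim * 2, 4, stride=2, padding=1),
            nn.BatchNorm2d(conv_dim * 2), nn.ReLU(inplace=True),
        )
        self.res = nn.Sequential(
            *[ResidualBlock(conv_dim * 2) for _ in range(n_res_blocks)])
        # Decoder: transpose convs upsample back to a 128x128x3 image
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(conv_dim * 2, conv_dim, 4, stride=2, padding=1),
            nn.BatchNorm2d(conv_dim), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(conv_dim, 3, 4, stride=2, padding=1),
            nn.Tanh(),  # pixel values in [-1, 1]
        )

    def forward(self, x):
        return self.decoder(self.res(self.encoder(x)))

class Discriminator(nn.Module):
    """Five conv layers; the first four halve the spatial size,
    the last acts as the classifier producing one value."""
    def __init__(self, conv_dim=64):
        super().__init__()
        layers, in_c, out_c = [], 3, conv_dim
        for _ in range(4):  # first four: conv + BatchNorm + ReLU
            layers += [nn.Conv2d(in_c, out_c, 4, stride=2, padding=1),
                       nn.BatchNorm2d(out_c), nn.ReLU(inplace=True)]
            in_c, out_c = out_c, out_c * 2
        # classification layer: 128 / 2**4 = 8, so an 8x8 kernel yields 1x1
        layers.append(nn.Conv2d(in_c, 1, 8))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
```

A 128x128x3 batch passes through the generator unchanged in shape, while the discriminator reduces it to a single value per image.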
3. Compute the Generator, Discriminator, and Cycle-Consistency losses (the cycle-consistency loss determines how good a reconstructed image is when compared to the original image).
```python
import torch

def real_mse_loss(D_out):
    # How close is the output to being classified as "real"? (target = 1)
    return torch.mean((D_out - 1)**2)

def fake_mse_loss(D_out):
    # How close is the output to being classified as "fake"? (target = 0)
    return torch.mean(D_out**2)

def cycle_consistency_loss(real_im, reconstructed_im, lambda_weight):
    # L1 reconstruction loss, weighted by lambda_weight
    reconstr_loss = torch.mean(torch.abs(real_im - reconstructed_im))
    return lambda_weight * reconstr_loss
```
4. Define the optimizers.
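A common CycleGAN setup (an assumption here, including the hyperparameter values) uses one Adam optimizer shared by both generators and one per discriminator. The `nn.Linear` modules below are placeholders standing in for the notebook's `G_XtoY`, `G_YtoX`, `D_X`, and `D_Y`:

```python
import itertools
import torch.nn as nn
import torch.optim as optim

# Placeholder modules standing in for the notebook's actual models
G_XtoY, G_YtoX = nn.Linear(4, 4), nn.Linear(4, 4)
D_X, D_Y = nn.Linear(4, 1), nn.Linear(4, 1)

lr, beta1, beta2 = 0.0002, 0.5, 0.999  # assumed hyperparameters

# One optimizer updates both generators jointly; each discriminator gets its own
g_optimizer = optim.Adam(
    itertools.chain(G_XtoY.parameters(), G_YtoX.parameters()),
    lr=lr, betas=(beta1, beta2))
d_x_optimizer = optim.Adam(D_X.parameters(), lr=lr, betas=(beta1, beta2))
d_y_optimizer = optim.Adam(D_Y.parameters(), lr=lr, betas=(beta1, beta2))
```

Chaining the generator parameters lets a single `g_optimizer.step()` update both mapping directions at once.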
5. Train the CycleGAN:
    * **Training the Discriminators**
        1. Compute the discriminator 𝐷𝑋 loss on real images.
        2. Generate fake images that look like domain 𝑋 based on real images in domain 𝑌.
        3. Compute the fake loss for 𝐷𝑋.
        4. Compute the total loss, then perform backpropagation and 𝐷𝑋 optimization.
        5. Repeat steps 1-4, only with 𝐷𝑌 and your domains switched!
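The discriminator steps above can be sketched as a single training-step function. The loss terms follow `real_mse_loss` and `fake_mse_loss` from step 3, but the function name and signature are illustrative assumptions, not the notebook's exact code:

```python
import torch

def train_discriminator_step(D_X, G_YtoX, d_x_optimizer, real_X, real_Y):
    """One D_X update; repeat with D_Y and the domains switched."""
    d_x_optimizer.zero_grad()
    # 1. Loss on real X images (real target = 1)
    real_loss = torch.mean((D_X(real_X) - 1) ** 2)
    # 2. Generate fake X images from real Y images
    fake_X = G_YtoX(real_Y)
    # 3. Fake loss for D_X (fake target = 0); detach so no gradients flow to G
    fake_loss = torch.mean(D_X(fake_X.detach()) ** 2)
    # 4. Total loss, backpropagation, optimization
    d_x_loss = real_loss + fake_loss
    d_x_loss.backward()
    d_x_optimizer.step()
    return d_x_loss.item()
```

Detaching the fake images is the standard trick that keeps the discriminator update from also pushing gradients into the generator.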
    * **Training the Generators**
        1. Generate fake images that look like domain 𝑋 based on real images in domain 𝑌.
        2. Compute the generator loss based on how 𝐷𝑋 responds to the fake 𝑋 images.
        3. Generate reconstructed 𝑌̂ images based on the fake 𝑋 images generated in step 1.
        4. Compute the cycle-consistency loss by comparing the reconstructions with the real 𝑌 images.
        5. Repeat steps 1-4, only swapping domains.
        6. Add up all of the generator and reconstruction losses, then perform backpropagation and optimization.
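The generator steps above can likewise be sketched as one joint update. The loss terms mirror the functions from step 3; the function name, signature, and the default `lambda_weight` of 10 are assumptions for illustration:

```python
import torch

def train_generators_step(G_XtoY, G_YtoX, D_X, D_Y, g_optimizer,
                          real_X, real_Y, lambda_weight=10):
    """One joint update of both generators."""
    g_optimizer.zero_grad()
    # 1-4. Y -> X: fool D_X, then reconstruct Y and measure cycle consistency
    fake_X = G_YtoX(real_Y)
    g_YtoX_loss = torch.mean((D_X(fake_X) - 1) ** 2)
    reconstructed_Y = G_XtoY(fake_X)
    cycle_Y_loss = lambda_weight * torch.mean(torch.abs(real_Y - reconstructed_Y))
    # 5. Repeat with the domains swapped (X -> Y)
    fake_Y = G_XtoY(real_X)
    g_XtoY_loss = torch.mean((D_Y(fake_Y) - 1) ** 2)
    reconstructed_X = G_YtoX(fake_Y)
    cycle_X_loss = lambda_weight * torch.mean(torch.abs(real_X - reconstructed_X))
    # 6. Add up all losses, backpropagate, and optimize
    g_total_loss = g_YtoX_loss + g_XtoY_loss + cycle_Y_loss + cycle_X_loss
    g_total_loss.backward()
    g_optimizer.step()
    return g_total_loss.item()
```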
6. Compute the training losses and save sample images from the generators.
## Results
1. Training losses of the generators and discriminators are as follows:
    * Discriminator loss: 0.3021 (X), 0.3054 (Y)
    * Generator loss: 2.4846
2. Training loss plotted vs. the number of epochs trained:
<img src="./images/training_loss.png" />
3. Final results of the trained CycleGAN:
### Image-to-image translation (Converting summer images to winter images)
<img src="./images/summer_to_winter.png" />
### Image-to-image translation (Converting winter images to summer images)
<img src="./images/winter_to_summer.png" />