DualGAN model

Defines the DualGAN model architecture.

We use the models that were introduced in the DualGAN paper. The original implementation is here.

Generator



weights_init_normal

 weights_init_normal (m)
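Only the signature of weights_init_normal is shown here. Judging by its name, it is presumably applied to a model with Module.apply to draw the initial weights from a normal distribution. The sketch below follows the widely used DCGAN convention: convolution weights from N(0, 0.02), and batch-norm scale from N(1, 0.02) with zero shift. Which layer types the library's version matches, and its exact parameters, are assumptions.

```python
import torch
import torch.nn as nn


def weights_init_normal(m: nn.Module) -> None:
    """Sketch of a DCGAN-style initializer (assumed; the library's rules may differ)."""
    classname = m.__class__.__name__
    if "Conv" in classname:
        # Conv2d / ConvTranspose2d weights ~ N(0, 0.02); biases keep their defaults
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif "BatchNorm2d" in classname:
        # BatchNorm scale ~ N(1, 0.02), shift set to 0
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.constant_(m.bias, 0.0)


torch.manual_seed(0)
# Module.apply calls the function on every submodule recursively
net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
net.apply(weights_init_normal)
print(net[0].weight.std().item())  # close to 0.02
```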


UNetUp

 UNetUp (in_size, out_size, dropout=0.0)

Expanding (upsampling) layers of the U-Net used in DualGAN
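As an illustration of what an expanding U-Net block typically does, here is a minimal sketch under assumed layer choices: a stride-2 transposed convolution doubles the spatial size, followed by InstanceNorm, ReLU and optional dropout, and the result is concatenated channel-wise with the skip connection from the contracting path. The class name UNetUpSketch and the exact layers are hypothetical, not the library's code.

```python
import torch
import torch.nn as nn


class UNetUpSketch(nn.Module):
    """Hypothetical expanding block: upsample 2x, then concatenate the skip input."""

    def __init__(self, in_size: int, out_size: int, dropout: float = 0.0):
        super().__init__()
        layers = [
            # kernel 4, stride 2, padding 1 exactly doubles height and width
            nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(out_size),
            nn.ReLU(inplace=True),
        ]
        if dropout:
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor, skip_input: torch.Tensor) -> torch.Tensor:
        x = self.model(x)
        # Channel-wise concatenation with the matching contracting-path features
        return torch.cat((x, skip_input), dim=1)


up = UNetUpSketch(512, 256)
x = torch.randn(2, 512, 16, 16)
skip = torch.randn(2, 256, 32, 32)
out = up(x, skip)
print(out.shape)  # torch.Size([2, 512, 32, 32])
```

Because of the concatenation, the block outputs out_size plus the skip's channel count, which is why the next expanding layer's in_size is usually twice the previous out_size.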



UNetDown

 UNetDown (in_size, out_size, normalize=True, dropout=0.0)

Contracting (downsampling) layers of the U-Net used in DualGAN
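The signature suggests each contracting block is a downsampling convolution with optional normalization and dropout. A minimal sketch under assumed layer choices (stride-2 4x4 convolution, InstanceNorm, LeakyReLU with slope 0.2) looks like this; UNetDownSketch is a hypothetical name, and the library's exact layers may differ.

```python
import torch
import torch.nn as nn


class UNetDownSketch(nn.Module):
    """Hypothetical contracting block: halve the spatial size, widen the channels."""

    def __init__(self, in_size: int, out_size: int, normalize: bool = True, dropout: float = 0.0):
        super().__init__()
        # kernel 4, stride 2, padding 1 halves height and width
        layers = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)]
        if normalize:
            layers.append(nn.InstanceNorm2d(out_size))
        layers.append(nn.LeakyReLU(0.2))
        if dropout:
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)


down = UNetDownSketch(3, 64, normalize=False)  # first layer: no normalization
y = down(torch.randn(2, 3, 64, 64))
print(y.shape)  # torch.Size([2, 64, 32, 32])
print(len(UNetDownSketch(64, 128, dropout=0.5).model))  # 4
```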



DualGANGenerator

 DualGANGenerator (channels=3)

Generator model for DualGAN
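To show how contracting and expanding blocks are wired into a U-Net generator, here is a deliberately shallow, hypothetical version. The names down, up and TinyUNetGenerator, the depth and the channel widths are all assumptions for illustration; DualGANGenerator itself is likely deeper. The main point is the channel bookkeeping: each expanding layer receives its predecessor's output concatenated with the same-resolution skip features.

```python
import torch
import torch.nn as nn


def down(in_ch: int, out_ch: int, normalize: bool = True) -> nn.Sequential:
    layers = [nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False)]
    if normalize:
        layers.append(nn.InstanceNorm2d(out_ch))
    layers.append(nn.LeakyReLU(0.2))
    return nn.Sequential(*layers)


def up(in_ch: int, out_ch: int) -> nn.Sequential:
    return nn.Sequential(
        nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1, bias=False),
        nn.InstanceNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )


class TinyUNetGenerator(nn.Module):
    """Hypothetical 4-level U-Net generator (sketch, not the library's model)."""

    def __init__(self, channels: int = 3):
        super().__init__()
        self.d1 = down(channels, 64, normalize=False)  # H/2
        self.d2 = down(64, 128)                        # H/4
        self.d3 = down(128, 256)                       # H/8
        self.d4 = down(256, 256, normalize=False)      # H/16 (bottleneck)
        self.u1 = up(256, 256)  # H/8, concat with d3 -> 512 channels
        self.u2 = up(512, 128)  # H/4, concat with d2 -> 256 channels
        self.u3 = up(256, 64)   # H/2, concat with d1 -> 128 channels
        self.final = nn.Sequential(
            nn.ConvTranspose2d(128, channels, 4, 2, 1),  # back to H
            nn.Tanh(),  # outputs in [-1, 1]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d1 = self.d1(x)
        d2 = self.d2(d1)
        d3 = self.d3(d2)
        d4 = self.d4(d3)
        u1 = torch.cat((self.u1(d4), d3), dim=1)
        u2 = torch.cat((self.u2(u1), d2), dim=1)
        u3 = torch.cat((self.u3(u2), d1), dim=1)
        return self.final(u3)


gen = TinyUNetGenerator(3)
with torch.no_grad():
    fake = gen(torch.randn(2, 3, 64, 64))
print(fake.shape)  # torch.Size([2, 3, 64, 64])
```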

Test generator

Let’s test for a few things:

1. The generator can be initialized correctly.
2. A random image can be passed through the model, producing an output of the correct size.

First, let’s create a random batch and pass it through the generator:

img1 = torch.randn(4,3,256,256)
m = DualGANGenerator(3)
with torch.no_grad():
    out1 = m(img1)
out1.shape
torch.Size([4, 3, 256, 256])
test_eq(out1.shape, torch.Size([4, 3, 256, 256]))

Discriminator

As described in the DualGAN paper, we will use a 70x70 PatchGAN, the same discriminator used for CycleGAN.

D = discriminator(3)
print(D)
Sequential(
  (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (1): LeakyReLU(negative_slope=0.2, inplace=True)
  (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  (4): LeakyReLU(negative_slope=0.2, inplace=True)
  (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  (7): LeakyReLU(negative_slope=0.2, inplace=True)
  (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  (9): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  (10): LeakyReLU(negative_slope=0.2, inplace=True)
  (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
)
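To see where the "70x70" comes from, the following sketch rebuilds the printed layer stack by hand and computes two things: the spatial size of the output score map for a 256x256 input, and the receptive field of a single output unit. The helpers patchgan_sketch and receptive_field are hypothetical and not part of the library; the layer hyperparameters are copied from the printout above.

```python
import torch
import torch.nn as nn


def patchgan_sketch(ch_in: int = 3) -> nn.Sequential:
    """Rebuilds the printed discriminator layer stack (hypothetical helper)."""

    def block(ci: int, co: int, stride: int, norm: bool = True) -> list:
        layers = [nn.Conv2d(ci, co, 4, stride, 1)]
        if norm:
            layers.append(nn.InstanceNorm2d(co))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        return layers

    return nn.Sequential(
        *block(ch_in, 64, 2, norm=False),
        *block(64, 128, 2),
        *block(128, 256, 2),
        *block(256, 512, 1),
        nn.Conv2d(512, 1, 4, 1, 1),  # one score per patch
    )


def receptive_field(convs: list) -> int:
    """convs: (kernel, stride) pairs in order; returns one output unit's receptive field."""
    rf, jump = 1, 1
    for k, s in convs:
        rf += (k - 1) * jump  # each layer widens the field by (k - 1) input-space steps
        jump *= s             # distance in input pixels between adjacent units
    return rf


D = patchgan_sketch(3)
with torch.no_grad():
    scores = D(torch.randn(1, 3, 256, 256))
print(scores.shape)  # torch.Size([1, 1, 30, 30])
print(receptive_field([(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]))  # 70
```

So for a 256x256 image the discriminator produces a 30x30 grid of scores, each of which judges an overlapping 70x70 patch of the input rather than the whole image.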


DualGAN

 DualGAN (ch_in:int=3, n_features:int=64, disc_layers:int=3,
          lsgan:bool=False, drop:float=0.0,
          norm_layer:torch.nn.modules.module.Module=None)

DualGAN model.

When called, it takes a batch of real images from both domains and uses the generators to produce fake images for the opposite domains.

Attributes:

G_A (nn.Module): takes real input B and generates fake input A

G_B (nn.Module): takes real input A and generates fake input B

D_A (nn.Module): trained to distinguish between real input A and fake input A

D_B (nn.Module): trained to distinguish between real input B and fake input B



DualGAN.__init__

 DualGAN.__init__ (ch_in:int=3, n_features:int=64, disc_layers:int=3,
                   lsgan:bool=False, drop:float=0.0,
                   norm_layer:torch.nn.modules.module.Module=None)

Constructor for DualGAN model.

Arguments:

ch_in (int): Number of input channels (default=3)

n_features (int): Number of input features (default=64)

disc_layers (int): Number of discriminator layers (default=3)

lsgan (bool): Whether to use the LSGAN training objective, in which the discriminator outputs an unnormalized float (default=False)

drop (float): Level of dropout (default=0.0)

norm_layer (nn.Module): Type of normalization layer to use in the models (default=None)
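The lsgan flag chooses between two adversarial objectives. As a hedged illustration, assuming lsgan=True means a least-squares (MSE) loss on the raw discriminator score and lsgan=False means binary cross-entropy on a sigmoid probability, the two losses can be compared on toy values. The helper adversarial_loss is hypothetical and not part of the library.

```python
import torch
import torch.nn.functional as F


def adversarial_loss(pred: torch.Tensor, target_is_real: bool, lsgan: bool) -> torch.Tensor:
    """Hypothetical helper: least-squares loss if lsgan, else binary cross-entropy.

    Assumes pred is an unnormalized score when lsgan=True and a
    probability in (0, 1) (e.g. after a sigmoid) when lsgan=False.
    """
    target = torch.full_like(pred, 1.0 if target_is_real else 0.0)
    if lsgan:
        return F.mse_loss(pred, target)
    return F.binary_cross_entropy(pred, target)


raw = torch.tensor([0.5, 1.5])
print(adversarial_loss(raw, True, lsgan=True))    # tensor(0.2500): mean of 0.25 and 0.25
prob = torch.tensor([0.5, 0.5])
print(adversarial_loss(prob, True, lsgan=False))  # tensor(0.6931): -ln(0.5)
```

The least-squares objective keeps penalizing samples that are already classified correctly but lie far from the target, which is the usual argument for its more stable gradients.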



DualGAN.forward

 DualGAN.forward (input)

Forward function for DualGAN model. The input is a tuple of a batch of real images from both domains A and B.
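The quick model test on this page checks for four outputs, each with the same shape as the inputs. One plausible reading, assumed here rather than confirmed by this page, is two translated batches plus two reconstructions (A to B to A, and B to A to B), since reconstruction is central to DualGAN's dual-learning setup. The class DualGANForwardSketch is hypothetical, omits the discriminators, and uses 1x1 convolutions as stand-in generators.

```python
import torch
import torch.nn as nn


class DualGANForwardSketch(nn.Module):
    """Hypothetical stand-in for the forward pass; discriminators omitted."""

    def __init__(self, G_A: nn.Module, G_B: nn.Module):
        super().__init__()
        self.G_A = G_A  # maps domain B to domain A
        self.G_B = G_B  # maps domain A to domain B

    def forward(self, input):
        real_A, real_B = input
        fake_A = self.G_A(real_B)
        fake_B = self.G_B(real_A)
        # Assumption: the remaining outputs map each fake back to its source domain
        rec_A = self.G_A(fake_B)
        rec_B = self.G_B(fake_A)
        return [fake_A, fake_B, rec_A, rec_B]


# 1x1 convolutions stand in for the U-Net generators
model = DualGANForwardSketch(nn.Conv2d(3, 3, 1), nn.Conv2d(3, 3, 1))
img_A, img_B = torch.randn(4, 3, 64, 64), torch.randn(4, 3, 64, 64)
with torch.no_grad():
    outputs = model((img_A, img_B))
print(len(outputs), outputs[0].shape)  # 4 torch.Size([4, 3, 64, 64])
```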

Quick model tests

Again, let’s check that the model can be called successfully and outputs the correct shapes.

dualgan_model = DualGAN()
img1 = torch.randn(4,3,256,256)
img2 = torch.randn(4,3,256,256)
with torch.no_grad(): dualgan_output = dualgan_model((img1,img2))
CPU times: user 18.9 s, sys: 1.57 s, total: 20.5 s
Wall time: 447 ms
test_eq(len(dualgan_output),4)
for output_batch in dualgan_output:
    test_eq(output_batch.shape,img1.shape)
dualgan_model.push_to_hub('upit-dualgan-test')
Cloning https://huggingface.co/tmabraham/upit-dualgan-test into local empty directory.
To https://huggingface.co/tmabraham/upit-dualgan-test
   dccaa0f..f8d92db  main -> main
'https://huggingface.co/tmabraham/upit-dualgan-test/commit/f8d92db7854429ca64335e9ab698d7e7f2f44feb'
dualgan_model = DualGAN.from_pretrained('tmabraham/upit-dualgan-test')