Caution added

Esse commit está contido em:
Yunjey Choi
2019-04-22 18:07:17 +09:00
commit de GitHub
commit 6faf5e234f
+3 -1
Ver Arquivo
@@ -54,6 +54,8 @@ class Generator(nn.Module):
def forward(self, x, c):
# Replicate spatially and concatenate domain information.
# Note that this type of label conditioning does not work at all if we use reflection padding in Conv2d.
# This is because instance normalization ignores the shifting (or bias) effect.
c = c.view(c.size(0), c.size(1), 1, 1)
c = c.repeat(1, 1, x.size(2), x.size(3))
x = torch.cat([x, c], dim=1)
@@ -83,4 +85,4 @@ class Discriminator(nn.Module):
h = self.main(x)
out_src = self.conv1(h)
out_cls = self.conv2(h)
return out_src, out_cls.view(out_cls.size(0), out_cls.size(1))
return out_src, out_cls.view(out_cls.size(0), out_cls.size(1))