Clean the code & Add some features and annotations
Esse commit está contido em:
+59
-78
@@ -1,111 +1,92 @@
|
||||
from torch.utils import data
|
||||
from torchvision import transforms as T
|
||||
from torchvision.datasets import ImageFolder
|
||||
from PIL import Image
|
||||
import torch
|
||||
import os
|
||||
import random
|
||||
from torch.utils.data import Dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import ImageFolder
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class CelebDataset(Dataset):
|
||||
def __init__(self, image_path, metadata_path, transform, mode):
|
||||
self.image_path = image_path
|
||||
class CelebA(data.Dataset):
|
||||
"""Dataset class for the CelebA dataset."""
|
||||
|
||||
def __init__(self, image_dir, attr_path, selected_attrs, transform, mode):
|
||||
"""Initialize and preprocess the CelebA dataset."""
|
||||
self.image_dir = image_dir
|
||||
self.attr_path = attr_path
|
||||
self.selected_attrs = selected_attrs
|
||||
self.transform = transform
|
||||
self.mode = mode
|
||||
self.lines = open(metadata_path, 'r').readlines()
|
||||
self.num_data = int(self.lines[0])
|
||||
self.train_dataset = []
|
||||
self.test_dataset = []
|
||||
self.attr2idx = {}
|
||||
self.idx2attr = {}
|
||||
|
||||
print ('Start preprocessing dataset..!')
|
||||
random.seed(1234)
|
||||
self.preprocess()
|
||||
print ('Finished preprocessing dataset..!')
|
||||
|
||||
if self.mode == 'train':
|
||||
self.num_data = len(self.train_filenames)
|
||||
elif self.mode == 'test':
|
||||
self.num_data = len(self.test_filenames)
|
||||
if mode == 'train':
|
||||
self.num_images = len(self.train_dataset)
|
||||
else:
|
||||
self.num_images = len(self.test_dataset)
|
||||
|
||||
def preprocess(self):
|
||||
attrs = self.lines[1].split()
|
||||
for i, attr in enumerate(attrs):
|
||||
self.attr2idx[attr] = i
|
||||
self.idx2attr[i] = attr
|
||||
"""Preprocess the CelebA attribute file."""
|
||||
lines = [line.rstrip() for line in open(self.attr_path, 'r')]
|
||||
all_attr_names = lines[1].split()
|
||||
for i, attr_name in enumerate(all_attr_names):
|
||||
self.attr2idx[attr_name] = i
|
||||
self.idx2attr[i] = attr_name
|
||||
|
||||
self.selected_attrs = ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young']
|
||||
self.train_filenames = []
|
||||
self.train_labels = []
|
||||
self.test_filenames = []
|
||||
self.test_labels = []
|
||||
|
||||
lines = self.lines[2:]
|
||||
random.shuffle(lines) # random shuffling
|
||||
lines = lines[2:]
|
||||
random.seed(1234)
|
||||
random.shuffle(lines)
|
||||
for i, line in enumerate(lines):
|
||||
|
||||
splits = line.split()
|
||||
filename = splits[0]
|
||||
values = splits[1:]
|
||||
split = line.split()
|
||||
filename = split[0]
|
||||
values = split[1:]
|
||||
|
||||
label = []
|
||||
for idx, value in enumerate(values):
|
||||
attr = self.idx2attr[idx]
|
||||
|
||||
if attr in self.selected_attrs:
|
||||
if value == '1':
|
||||
label.append(1)
|
||||
else:
|
||||
label.append(0)
|
||||
for attr_name in self.selected_attrs:
|
||||
idx = self.attr2idx[attr_name]
|
||||
label.append(values[idx] == '1')
|
||||
|
||||
if (i+1) < 2000:
|
||||
self.test_filenames.append(filename)
|
||||
self.test_labels.append(label)
|
||||
self.test_dataset.append([filename, label])
|
||||
else:
|
||||
self.train_filenames.append(filename)
|
||||
self.train_labels.append(label)
|
||||
self.train_dataset.append([filename, label])
|
||||
|
||||
print('Finished preprocessing the CelebA dataset...')
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.mode == 'train':
|
||||
image = Image.open(os.path.join(self.image_path, self.train_filenames[index]))
|
||||
label = self.train_labels[index]
|
||||
elif self.mode in ['test']:
|
||||
image = Image.open(os.path.join(self.image_path, self.test_filenames[index]))
|
||||
label = self.test_labels[index]
|
||||
|
||||
"""Return one image and its corresponding attribute label."""
|
||||
dataset = self.train_dataset if self.mode == 'train' else self.test_dataset
|
||||
filename, label = dataset[index]
|
||||
image = Image.open(os.path.join(self.image_dir, filename))
|
||||
return self.transform(image), torch.FloatTensor(label)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_data
|
||||
"""Return the number of images."""
|
||||
return self.num_images
|
||||
|
||||
|
||||
def get_loader(image_path, metadata_path, crop_size, image_size, batch_size, dataset='CelebA', mode='train'):
|
||||
"""Build and return data loader."""
|
||||
|
||||
def get_loader(image_dir, attr_path, selected_attrs, crop_size=178, image_size=128,
|
||||
batch_size=16, dataset='CelebA', mode='train', num_workers=1):
|
||||
"""Build and return a data loader."""
|
||||
transform = []
|
||||
if mode == 'train':
|
||||
transform = transforms.Compose([
|
||||
transforms.CenterCrop(crop_size),
|
||||
transforms.Resize(image_size, interpolation=Image.ANTIALIAS),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
||||
else:
|
||||
transform = transforms.Compose([
|
||||
transforms.CenterCrop(crop_size),
|
||||
transforms.Scale(image_size, interpolation=Image.ANTIALIAS),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
||||
transform.append(T.RandomHorizontalFlip())
|
||||
transform.append(T.CenterCrop(crop_size))
|
||||
transform.append(T.Resize(image_size))
|
||||
transform.append(T.ToTensor())
|
||||
transform.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
|
||||
transform = T.Compose(transform)
|
||||
|
||||
if dataset == 'CelebA':
|
||||
dataset = CelebDataset(image_path, metadata_path, transform, mode)
|
||||
dataset = CelebA(image_dir, attr_path, selected_attrs, transform, mode)
|
||||
elif dataset == 'RaFD':
|
||||
dataset = ImageFolder(image_path, transform)
|
||||
dataset = ImageFolder(image_dir, transform)
|
||||
|
||||
shuffle = False
|
||||
if mode == 'train':
|
||||
shuffle = True
|
||||
|
||||
data_loader = DataLoader(dataset=dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle)
|
||||
data_loader = data.DataLoader(dataset=dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=(mode=='train'),
|
||||
num_workers=num_workers)
|
||||
return data_loader
|
||||
+5
-62
@@ -1,71 +1,14 @@
|
||||
# Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
import scipy.misc
|
||||
try:
|
||||
from StringIO import StringIO # Python 2.7
|
||||
except ImportError:
|
||||
from io import BytesIO # Python 3.5+
|
||||
|
||||
|
||||
class Logger(object):
|
||||
|
||||
"""Tensorboard logger."""
|
||||
|
||||
def __init__(self, log_dir):
|
||||
"""Create a summary writer logging to log_dir."""
|
||||
"""Initialize summary writer."""
|
||||
self.writer = tf.summary.FileWriter(log_dir)
|
||||
|
||||
def scalar_summary(self, tag, value, step):
|
||||
"""Log a scalar variable."""
|
||||
"""Add scalar summary."""
|
||||
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
|
||||
self.writer.add_summary(summary, step)
|
||||
|
||||
def image_summary(self, tag, images, step):
|
||||
"""Log a list of images."""
|
||||
|
||||
img_summaries = []
|
||||
for i, img in enumerate(images):
|
||||
# Write the image to a string
|
||||
try:
|
||||
s = StringIO()
|
||||
except:
|
||||
s = BytesIO()
|
||||
scipy.misc.toimage(img).save(s, format="png")
|
||||
|
||||
# Create an Image object
|
||||
img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
|
||||
height=img.shape[0],
|
||||
width=img.shape[1])
|
||||
# Create a Summary value
|
||||
img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))
|
||||
|
||||
# Create and write Summary
|
||||
summary = tf.Summary(value=img_summaries)
|
||||
self.writer.add_summary(summary, step)
|
||||
|
||||
def histo_summary(self, tag, values, step, bins=1000):
|
||||
"""Log a histogram of the tensor of values."""
|
||||
|
||||
# Create a histogram using numpy
|
||||
counts, bin_edges = np.histogram(values, bins=bins)
|
||||
|
||||
# Fill the fields of the histogram proto
|
||||
hist = tf.HistogramProto()
|
||||
hist.min = float(np.min(values))
|
||||
hist.max = float(np.max(values))
|
||||
hist.num = int(np.prod(values.shape))
|
||||
hist.sum = float(np.sum(values))
|
||||
hist.sum_squares = float(np.sum(values**2))
|
||||
|
||||
# Drop the start of the first bin
|
||||
bin_edges = bin_edges[1:]
|
||||
|
||||
# Add bin edges and counts
|
||||
for edge in bin_edges:
|
||||
hist.bucket_limit.append(edge)
|
||||
for c in counts:
|
||||
hist.bucket.append(c)
|
||||
|
||||
# Create and write Summary
|
||||
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
|
||||
self.writer.add_summary(summary, step)
|
||||
self.writer.flush()
|
||||
self.writer.add_summary(summary, step)
|
||||
+65
-61
@@ -9,32 +9,35 @@ def str2bool(v):
|
||||
return v.lower() in ('true')
|
||||
|
||||
def main(config):
|
||||
# For fast training
|
||||
# For fast training.
|
||||
cudnn.benchmark = True
|
||||
|
||||
# Create directories if not exist
|
||||
if not os.path.exists(config.log_path):
|
||||
os.makedirs(config.log_path)
|
||||
if not os.path.exists(config.model_save_path):
|
||||
os.makedirs(config.model_save_path)
|
||||
if not os.path.exists(config.sample_path):
|
||||
os.makedirs(config.sample_path)
|
||||
if not os.path.exists(config.result_path):
|
||||
os.makedirs(config.result_path)
|
||||
# Create directories if not exist.
|
||||
if not os.path.exists(config.log_dir):
|
||||
os.makedirs(config.log_dir)
|
||||
if not os.path.exists(config.model_save_dir):
|
||||
os.makedirs(config.model_save_dir)
|
||||
if not os.path.exists(config.sample_dir):
|
||||
os.makedirs(config.sample_dir)
|
||||
if not os.path.exists(config.result_dir):
|
||||
os.makedirs(config.result_dir)
|
||||
|
||||
# Data loader
|
||||
celebA_loader = None
|
||||
# Data loader.
|
||||
celeba_loader = None
|
||||
rafd_loader = None
|
||||
|
||||
if config.dataset in ['CelebA', 'Both']:
|
||||
celebA_loader = get_loader(config.celebA_image_path, config.metadata_path, config.celebA_crop_size,
|
||||
config.image_size, config.batch_size, 'CelebA', config.mode)
|
||||
celeba_loader = get_loader(config.celeba_image_dir, config.attr_path, config.selected_attrs,
|
||||
config.celeba_crop_size, config.image_size, config.batch_size,
|
||||
'CelebA', config.mode, config.num_workers)
|
||||
if config.dataset in ['RaFD', 'Both']:
|
||||
rafd_loader = get_loader(config.rafd_image_path, None, config.rafd_crop_size,
|
||||
config.image_size, config.batch_size, 'RaFD', config.mode)
|
||||
rafd_loader = get_loader(config.rafd_image_dir, None, None,
|
||||
config.rafd_crop_size, config.image_size, config.batch_size,
|
||||
'RaFD', config.mode, config.num_workers)
|
||||
|
||||
|
||||
# Solver
|
||||
solver = Solver(celebA_loader, rafd_loader, config)
|
||||
# Solver for training and testing StarGAN.
|
||||
solver = Solver(celeba_loader, rafd_loader, config)
|
||||
|
||||
if config.mode == 'train':
|
||||
if config.dataset in ['CelebA', 'RaFD']:
|
||||
@@ -51,55 +54,56 @@ def main(config):
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Model hyper-parameters
|
||||
parser.add_argument('--c_dim', type=int, default=5)
|
||||
parser.add_argument('--c2_dim', type=int, default=8)
|
||||
parser.add_argument('--celebA_crop_size', type=int, default=178)
|
||||
parser.add_argument('--rafd_crop_size', type=int, default=256)
|
||||
parser.add_argument('--image_size', type=int, default=128)
|
||||
parser.add_argument('--g_conv_dim', type=int, default=64)
|
||||
parser.add_argument('--d_conv_dim', type=int, default=64)
|
||||
parser.add_argument('--g_repeat_num', type=int, default=6)
|
||||
parser.add_argument('--d_repeat_num', type=int, default=6)
|
||||
parser.add_argument('--g_lr', type=float, default=0.0001)
|
||||
parser.add_argument('--d_lr', type=float, default=0.0001)
|
||||
parser.add_argument('--lambda_cls', type=float, default=1)
|
||||
parser.add_argument('--lambda_rec', type=float, default=10)
|
||||
parser.add_argument('--lambda_gp', type=float, default=10)
|
||||
parser.add_argument('--d_train_repeat', type=int, default=5)
|
||||
|
||||
# Training settings
|
||||
# Model configuration.
|
||||
parser.add_argument('--c_dim', type=int, default=5, help='dimension of domain labels (1st dataset)')
|
||||
parser.add_argument('--c2_dim', type=int, default=8, help='dimension of domain labels (2nd dataset)')
|
||||
parser.add_argument('--celeba_crop_size', type=int, default=178, help='crop size for the CelebA dataset')
|
||||
parser.add_argument('--rafd_crop_size', type=int, default=256, help='crop size for the RaFD dataset')
|
||||
parser.add_argument('--image_size', type=int, default=128, help='image resolution')
|
||||
parser.add_argument('--g_conv_dim', type=int, default=64, help='number of conv filters in the first layer of G')
|
||||
parser.add_argument('--d_conv_dim', type=int, default=64, help='number of conv filters in the first layer of D')
|
||||
parser.add_argument('--g_repeat_num', type=int, default=6, help='number of residual blocks in G')
|
||||
parser.add_argument('--d_repeat_num', type=int, default=6, help='number of strided conv layers in D')
|
||||
parser.add_argument('--lambda_cls', type=float, default=1, help='weight for domain classification loss')
|
||||
parser.add_argument('--lambda_rec', type=float, default=10, help='weight for reconstruction loss')
|
||||
parser.add_argument('--lambda_gp', type=float, default=10, help='weight for gradient penalty')
|
||||
|
||||
# Training configuration.
|
||||
parser.add_argument('--dataset', type=str, default='CelebA', choices=['CelebA', 'RaFD', 'Both'])
|
||||
parser.add_argument('--num_epochs', type=int, default=20)
|
||||
parser.add_argument('--num_epochs_decay', type=int, default=10)
|
||||
parser.add_argument('--num_iters', type=int, default=200000)
|
||||
parser.add_argument('--num_iters_decay', type=int, default=100000)
|
||||
parser.add_argument('--batch_size', type=int, default=16)
|
||||
parser.add_argument('--batch_size', type=int, default=16, help='mini-batch size')
|
||||
parser.add_argument('--num_iters', type=int, default=200000, help='number of total iterations for training D')
|
||||
parser.add_argument('--num_iters_decay', type=int, default=100000, help='number of iterations for decaying lr')
|
||||
parser.add_argument('--g_lr', type=float, default=0.0001, help='learning rate for G')
|
||||
parser.add_argument('--d_lr', type=float, default=0.0001, help='learning rate for D')
|
||||
parser.add_argument('--n_critic', type=int, default=5, help='number of D updates per each G update')
|
||||
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for Adam optimizer')
|
||||
parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer')
|
||||
parser.add_argument('--resume_iters', type=int, default=None, help='resume training from this step')
|
||||
parser.add_argument('--selected_attrs', '--list', nargs='+', help='selected attributes for the CelebA dataset',
|
||||
default=['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young'])
|
||||
|
||||
# Test configuration.
|
||||
parser.add_argument('--test_iters', type=int, default=200000, help='test model from this step')
|
||||
|
||||
# Miscellaneous.
|
||||
parser.add_argument('--num_workers', type=int, default=1)
|
||||
parser.add_argument('--beta1', type=float, default=0.5)
|
||||
parser.add_argument('--beta2', type=float, default=0.999)
|
||||
parser.add_argument('--pretrained_model', type=str, default=None)
|
||||
|
||||
# Test settings
|
||||
parser.add_argument('--test_model', type=str, default='20_1000')
|
||||
|
||||
# Misc
|
||||
parser.add_argument('--mode', type=str, default='train', choices=['train', 'test'])
|
||||
parser.add_argument('--use_tensorboard', type=str2bool, default=False)
|
||||
parser.add_argument('--use_tensorboard', type=str2bool, default=True)
|
||||
|
||||
# Path
|
||||
parser.add_argument('--celebA_image_path', type=str, default='./data/CelebA_nocrop/images')
|
||||
parser.add_argument('--rafd_image_path', type=str, default='./data/RaFD/train')
|
||||
parser.add_argument('--metadata_path', type=str, default='./data/list_attr_celeba.txt')
|
||||
parser.add_argument('--log_path', type=str, default='./stargan/logs')
|
||||
parser.add_argument('--model_save_path', type=str, default='./stargan/models')
|
||||
parser.add_argument('--sample_path', type=str, default='./stargan/samples')
|
||||
parser.add_argument('--result_path', type=str, default='./stargan/results')
|
||||
# Directories.
|
||||
parser.add_argument('--celeba_image_dir', type=str, default='data/CelebA_nocrop/images')
|
||||
parser.add_argument('--attr_path', type=str, default='data/list_attr_celeba.txt')
|
||||
parser.add_argument('--rafd_image_dir', type=str, default='data/RaFD/train')
|
||||
parser.add_argument('--log_dir', type=str, default='stargan/logs')
|
||||
parser.add_argument('--model_save_dir', type=str, default='stargan/models')
|
||||
parser.add_argument('--sample_dir', type=str, default='stargan/samples')
|
||||
parser.add_argument('--result_dir', type=str, default='stargan/results')
|
||||
|
||||
# Step size
|
||||
# Step size.
|
||||
parser.add_argument('--log_step', type=int, default=10)
|
||||
parser.add_argument('--sample_step', type=int, default=500)
|
||||
parser.add_argument('--model_save_step', type=int, default=1000)
|
||||
parser.add_argument('--sample_step', type=int, default=1000)
|
||||
parser.add_argument('--model_save_step', type=int, default=10000)
|
||||
parser.add_argument('--lr_update_step', type=int, default=1000)
|
||||
|
||||
config = parser.parse_args()
|
||||
print(config)
|
||||
|
||||
+17
-18
@@ -5,7 +5,7 @@ import numpy as np
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
"""Residual Block."""
|
||||
"""Residual Block with instance normalization."""
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super(ResidualBlock, self).__init__()
|
||||
self.main = nn.Sequential(
|
||||
@@ -20,7 +20,7 @@ class ResidualBlock(nn.Module):
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
"""Generator. Encoder-Decoder Architecture."""
|
||||
"""Generator network."""
|
||||
def __init__(self, conv_dim=64, c_dim=5, repeat_num=6):
|
||||
super(Generator, self).__init__()
|
||||
|
||||
@@ -29,7 +29,7 @@ class Generator(nn.Module):
|
||||
layers.append(nn.InstanceNorm2d(conv_dim, affine=True))
|
||||
layers.append(nn.ReLU(inplace=True))
|
||||
|
||||
# Down-Sampling
|
||||
# Down-sampling layers.
|
||||
curr_dim = conv_dim
|
||||
for i in range(2):
|
||||
layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1, bias=False))
|
||||
@@ -37,11 +37,11 @@ class Generator(nn.Module):
|
||||
layers.append(nn.ReLU(inplace=True))
|
||||
curr_dim = curr_dim * 2
|
||||
|
||||
# Bottleneck
|
||||
# Bottleneck layers.
|
||||
for i in range(repeat_num):
|
||||
layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim))
|
||||
|
||||
# Up-Sampling
|
||||
# Up-sampling layers.
|
||||
for i in range(2):
|
||||
layers.append(nn.ConvTranspose2d(curr_dim, curr_dim//2, kernel_size=4, stride=2, padding=1, bias=False))
|
||||
layers.append(nn.InstanceNorm2d(curr_dim//2, affine=True))
|
||||
@@ -53,35 +53,34 @@ class Generator(nn.Module):
|
||||
self.main = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x, c):
|
||||
# replicate spatially and concatenate domain information
|
||||
c = c.unsqueeze(2).unsqueeze(3)
|
||||
c = c.expand(c.size(0), c.size(1), x.size(2), x.size(3))
|
||||
# Replicate spatially and concatenate domain information.
|
||||
c = c.view(c.size(0), c.size(1), 1, 1)
|
||||
c = c.repeat(1, 1, x.size(2), x.size(3))
|
||||
x = torch.cat([x, c], dim=1)
|
||||
return self.main(x)
|
||||
|
||||
|
||||
class Discriminator(nn.Module):
|
||||
"""Discriminator. PatchGAN."""
|
||||
"""Discriminator network with PatchGAN."""
|
||||
def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6):
|
||||
super(Discriminator, self).__init__()
|
||||
|
||||
layers = []
|
||||
layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
|
||||
layers.append(nn.LeakyReLU(0.01, inplace=True))
|
||||
layers.append(nn.LeakyReLU(0.01))
|
||||
|
||||
curr_dim = conv_dim
|
||||
for i in range(1, repeat_num):
|
||||
layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1))
|
||||
layers.append(nn.LeakyReLU(0.01, inplace=True))
|
||||
layers.append(nn.LeakyReLU(0.01))
|
||||
curr_dim = curr_dim * 2
|
||||
|
||||
k_size = int(image_size / np.power(2, repeat_num))
|
||||
kernel_size = int(image_size / np.power(2, repeat_num))
|
||||
self.main = nn.Sequential(*layers)
|
||||
self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.conv2 = nn.Conv2d(curr_dim, c_dim, kernel_size=k_size, bias=False)
|
||||
|
||||
self.conv2 = nn.Conv2d(curr_dim, c_dim, kernel_size=kernel_size, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
h = self.main(x)
|
||||
out_real = self.conv1(h)
|
||||
out_aux = self.conv2(h)
|
||||
return out_real.squeeze(), out_aux.squeeze()
|
||||
out_src = self.conv1(h)
|
||||
out_cls = self.conv2(h)
|
||||
return out_src, out_cls.view(out_cls.size(0), out_cls.size(1))
|
||||
+428
-575
Diferenças do arquivo suprimidas por serem muito extensas
Carregar Diff
Referência em uma Nova Issue
Bloquear um usuário