From c53e59871bfb2c213118f260e4993f3ab959248f Mon Sep 17 00:00:00 2001 From: yunjey Date: Thu, 12 Apr 2018 12:36:57 +0900 Subject: [PATCH] Clean the code & Add some features and annotations --- data_loader.py | 137 +++---- logger.py | 67 +--- main.py | 126 +++--- model.py | 35 +- solver.py | 1003 +++++++++++++++++++++--------------------------- 5 files changed, 574 insertions(+), 794 deletions(-) diff --git a/data_loader.py b/data_loader.py index 108116c..d0c5eac 100644 --- a/data_loader.py +++ b/data_loader.py @@ -1,111 +1,92 @@ +from torch.utils import data +from torchvision import transforms as T +from torchvision.datasets import ImageFolder +from PIL import Image import torch import os import random -from torch.utils.data import Dataset -from torch.utils.data import DataLoader -from torchvision import transforms -from torchvision.datasets import ImageFolder -from PIL import Image -class CelebDataset(Dataset): - def __init__(self, image_path, metadata_path, transform, mode): - self.image_path = image_path +class CelebA(data.Dataset): + """Dataset class for the CelebA dataset.""" + + def __init__(self, image_dir, attr_path, selected_attrs, transform, mode): + """Initialize and preprocess the CelebA dataset.""" + self.image_dir = image_dir + self.attr_path = attr_path + self.selected_attrs = selected_attrs self.transform = transform self.mode = mode - self.lines = open(metadata_path, 'r').readlines() - self.num_data = int(self.lines[0]) + self.train_dataset = [] + self.test_dataset = [] self.attr2idx = {} self.idx2attr = {} - - print ('Start preprocessing dataset..!') - random.seed(1234) self.preprocess() - print ('Finished preprocessing dataset..!') - if self.mode == 'train': - self.num_data = len(self.train_filenames) - elif self.mode == 'test': - self.num_data = len(self.test_filenames) + if mode == 'train': + self.num_images = len(self.train_dataset) + else: + self.num_images = len(self.test_dataset) def preprocess(self): - attrs = self.lines[1].split() - for i, attr in enumerate(attrs): - self.attr2idx[attr] = i - self.idx2attr[i] = attr + """Preprocess the CelebA attribute file.""" + lines = [line.rstrip() for line in open(self.attr_path, 'r')] + all_attr_names = lines[1].split() + for i, attr_name in enumerate(all_attr_names): + self.attr2idx[attr_name] = i + self.idx2attr[i] = attr_name - self.selected_attrs = ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young'] - self.train_filenames = [] - self.train_labels = [] - self.test_filenames = [] - self.test_labels = [] - - lines = self.lines[2:] - random.shuffle(lines) # random shuffling + lines = lines[2:] + random.seed(1234) + random.shuffle(lines) for i, line in enumerate(lines): - - splits = line.split() - filename = splits[0] - values = splits[1:] + split = line.split() + filename = split[0] + values = split[1:] label = [] - for idx, value in enumerate(values): - attr = self.idx2attr[idx] - - if attr in self.selected_attrs: - if value == '1': - label.append(1) - else: - label.append(0) + for attr_name in self.selected_attrs: + idx = self.attr2idx[attr_name] + label.append(values[idx] == '1') if (i+1) < 2000: - self.test_filenames.append(filename) - self.test_labels.append(label) + self.test_dataset.append([filename, label]) else: - self.train_filenames.append(filename) - self.train_labels.append(label) + self.train_dataset.append([filename, label]) + + print('Finished preprocessing the CelebA dataset...') def __getitem__(self, index): - if self.mode == 'train': - image = Image.open(os.path.join(self.image_path, self.train_filenames[index])) - label = self.train_labels[index] - elif self.mode in ['test']: - image = Image.open(os.path.join(self.image_path, self.test_filenames[index])) - label = self.test_labels[index] - + """Return one image and its corresponding attribute label.""" + dataset = self.train_dataset if self.mode == 'train' else self.test_dataset + filename, label = dataset[index] + image = Image.open(os.path.join(self.image_dir, filename)) return self.transform(image), torch.FloatTensor(label) def __len__(self): - return self.num_data + """Return the number of images.""" + return self.num_images -def get_loader(image_path, metadata_path, crop_size, image_size, batch_size, dataset='CelebA', mode='train'): - """Build and return data loader.""" - +def get_loader(image_dir, attr_path, selected_attrs, crop_size=178, image_size=128, + batch_size=16, dataset='CelebA', mode='train', num_workers=1): + """Build and return a data loader.""" + transform = [] if mode == 'train': - transform = transforms.Compose([ - transforms.CenterCrop(crop_size), - transforms.Resize(image_size, interpolation=Image.ANTIALIAS), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) - else: - transform = transforms.Compose([ - transforms.CenterCrop(crop_size), - transforms.Scale(image_size, interpolation=Image.ANTIALIAS), - transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + transform.append(T.RandomHorizontalFlip()) + transform.append(T.CenterCrop(crop_size)) + transform.append(T.Resize(image_size)) + transform.append(T.ToTensor()) + transform.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))) + transform = T.Compose(transform) if dataset == 'CelebA': - dataset = CelebDataset(image_path, metadata_path, transform, mode) + dataset = CelebA(image_dir, attr_path, selected_attrs, transform, mode) elif dataset == 'RaFD': - dataset = ImageFolder(image_path, transform) + dataset = ImageFolder(image_dir, transform) - shuffle = False - if mode == 'train': - shuffle = True - - data_loader = DataLoader(dataset=dataset, - batch_size=batch_size, - shuffle=shuffle) + data_loader = data.DataLoader(dataset=dataset, + batch_size=batch_size, + shuffle=(mode=='train'), + num_workers=num_workers) return data_loader \ No newline at end of file diff --git a/logger.py b/logger.py index ece1005..f30431e 100644 --- a/logger.py +++ b/logger.py @@ -1,71 +1,14 @@ -# Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514 import tensorflow as tf -import numpy as np -import scipy.misc -try: - from StringIO import StringIO # Python 2.7 -except ImportError: - from io import BytesIO # Python 3.5+ class Logger(object): - + """Tensorboard logger.""" + def __init__(self, log_dir): - """Create a summary writer logging to log_dir.""" + """Initialize summary writer.""" self.writer = tf.summary.FileWriter(log_dir) def scalar_summary(self, tag, value, step): - """Log a scalar variable.""" + """Add scalar summary.""" summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) - self.writer.add_summary(summary, step) - - def image_summary(self, tag, images, step): - """Log a list of images.""" - - img_summaries = [] - for i, img in enumerate(images): - # Write the image to a string - try: - s = StringIO() - except: - s = BytesIO() - scipy.misc.toimage(img).save(s, format="png") - - # Create an Image object - img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(), - height=img.shape[0], - width=img.shape[1]) - # Create a Summary value - img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum)) - - # Create and write Summary - summary = tf.Summary(value=img_summaries) - self.writer.add_summary(summary, step) - - def histo_summary(self, tag, values, step, bins=1000): - """Log a histogram of the tensor of values.""" - - # Create a histogram using numpy - counts, bin_edges = np.histogram(values, bins=bins) - - # Fill the fields of the histogram proto - hist = tf.HistogramProto() - hist.min = float(np.min(values)) - hist.max = float(np.max(values)) - hist.num = int(np.prod(values.shape)) - hist.sum = float(np.sum(values)) - hist.sum_squares = float(np.sum(values**2)) - - # Drop the start of the first bin - bin_edges = bin_edges[1:] - - # Add bin edges and counts - for edge in bin_edges: - hist.bucket_limit.append(edge) - for c in counts: - hist.bucket.append(c) - - # Create and write Summary - summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)]) - self.writer.add_summary(summary, step) - self.writer.flush() \ No newline at end of file + self.writer.add_summary(summary, step) \ No newline at end of file diff --git a/main.py b/main.py index 46cb04d..5079b63 100644 --- a/main.py +++ b/main.py @@ -9,32 +9,35 @@ def str2bool(v): return v.lower() in ('true') def main(config): - # For fast training + # For fast training. cudnn.benchmark = True - # Create directories if not exist - if not os.path.exists(config.log_path): - os.makedirs(config.log_path) - if not os.path.exists(config.model_save_path): - os.makedirs(config.model_save_path) - if not os.path.exists(config.sample_path): - os.makedirs(config.sample_path) - if not os.path.exists(config.result_path): - os.makedirs(config.result_path) + # Create directories if not exist. + if not os.path.exists(config.log_dir): + os.makedirs(config.log_dir) + if not os.path.exists(config.model_save_dir): + os.makedirs(config.model_save_dir) + if not os.path.exists(config.sample_dir): + os.makedirs(config.sample_dir) + if not os.path.exists(config.result_dir): + os.makedirs(config.result_dir) - # Data loader - celebA_loader = None + # Data loader. + celeba_loader = None rafd_loader = None if config.dataset in ['CelebA', 'Both']: - celebA_loader = get_loader(config.celebA_image_path, config.metadata_path, config.celebA_crop_size, - config.image_size, config.batch_size, 'CelebA', config.mode) + celeba_loader = get_loader(config.celeba_image_dir, config.attr_path, config.selected_attrs, + config.celeba_crop_size, config.image_size, config.batch_size, + 'CelebA', config.mode, config.num_workers) if config.dataset in ['RaFD', 'Both']: - rafd_loader = get_loader(config.rafd_image_path, None, config.rafd_crop_size, - config.image_size, config.batch_size, 'RaFD', config.mode) + rafd_loader = get_loader(config.rafd_image_dir, None, None, + config.rafd_crop_size, config.image_size, config.batch_size, + 'RaFD', config.mode, config.num_workers) + - # Solver - solver = Solver(celebA_loader, rafd_loader, config) + # Solver for training and testing StarGAN. + solver = Solver(celeba_loader, rafd_loader, config) if config.mode == 'train': if config.dataset in ['CelebA', 'RaFD']: @@ -51,55 +54,56 @@ def main(config): if __name__ == '__main__': parser = argparse.ArgumentParser() - # Model hyper-parameters - parser.add_argument('--c_dim', type=int, default=5) - parser.add_argument('--c2_dim', type=int, default=8) - parser.add_argument('--celebA_crop_size', type=int, default=178) - parser.add_argument('--rafd_crop_size', type=int, default=256) - parser.add_argument('--image_size', type=int, default=128) - parser.add_argument('--g_conv_dim', type=int, default=64) - parser.add_argument('--d_conv_dim', type=int, default=64) - parser.add_argument('--g_repeat_num', type=int, default=6) - parser.add_argument('--d_repeat_num', type=int, default=6) - parser.add_argument('--g_lr', type=float, default=0.0001) - parser.add_argument('--d_lr', type=float, default=0.0001) - parser.add_argument('--lambda_cls', type=float, default=1) - parser.add_argument('--lambda_rec', type=float, default=10) - parser.add_argument('--lambda_gp', type=float, default=10) - parser.add_argument('--d_train_repeat', type=int, default=5) - - # Training settings + # Model configuration. + parser.add_argument('--c_dim', type=int, default=5, help='dimension of domain labels (1st dataset)') + parser.add_argument('--c2_dim', type=int, default=8, help='dimension of domain labels (2nd dataset)') + parser.add_argument('--celeba_crop_size', type=int, default=178, help='crop size for the CelebA dataset') + parser.add_argument('--rafd_crop_size', type=int, default=256, help='crop size for the RaFD dataset') + parser.add_argument('--image_size', type=int, default=128, help='image resolution') + parser.add_argument('--g_conv_dim', type=int, default=64, help='number of conv filters in the first layer of G') + parser.add_argument('--d_conv_dim', type=int, default=64, help='number of conv filters in the first layer of D') + parser.add_argument('--g_repeat_num', type=int, default=6, help='number of residual blocks in G') + parser.add_argument('--d_repeat_num', type=int, default=6, help='number of strided conv layers in D') + parser.add_argument('--lambda_cls', type=float, default=1, help='weight for domain classification loss') + parser.add_argument('--lambda_rec', type=float, default=10, help='weight for reconstruction loss') + parser.add_argument('--lambda_gp', type=float, default=10, help='weight for gradient penalty') + + # Training configuration. parser.add_argument('--dataset', type=str, default='CelebA', choices=['CelebA', 'RaFD', 'Both']) - parser.add_argument('--num_epochs', type=int, default=20) - parser.add_argument('--num_epochs_decay', type=int, default=10) - parser.add_argument('--num_iters', type=int, default=200000) - parser.add_argument('--num_iters_decay', type=int, default=100000) - parser.add_argument('--batch_size', type=int, default=16) + parser.add_argument('--batch_size', type=int, default=16, help='mini-batch size') + parser.add_argument('--num_iters', type=int, default=200000, help='number of total iterations for training D') + parser.add_argument('--num_iters_decay', type=int, default=100000, help='number of iterations for decaying lr') + parser.add_argument('--g_lr', type=float, default=0.0001, help='learning rate for G') + parser.add_argument('--d_lr', type=float, default=0.0001, help='learning rate for D') + parser.add_argument('--n_critic', type=int, default=5, help='number of D updates per each G update') + parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for Adam optimizer') + parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer') + parser.add_argument('--resume_iters', type=int, default=None, help='resume training from this step') + parser.add_argument('--selected_attrs', '--list', nargs='+', help='selected attributes for the CelebA dataset', + default=['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young']) + + # Test configuration. + parser.add_argument('--test_iters', type=int, default=200000, help='test model from this step') + + # Miscellaneous. parser.add_argument('--num_workers', type=int, default=1) - parser.add_argument('--beta1', type=float, default=0.5) - parser.add_argument('--beta2', type=float, default=0.999) - parser.add_argument('--pretrained_model', type=str, default=None) - - # Test settings - parser.add_argument('--test_model', type=str, default='20_1000') - - # Misc parser.add_argument('--mode', type=str, default='train', choices=['train', 'test']) - parser.add_argument('--use_tensorboard', type=str2bool, default=False) + parser.add_argument('--use_tensorboard', type=str2bool, default=True) - # Path - parser.add_argument('--celebA_image_path', type=str, default='./data/CelebA_nocrop/images') - parser.add_argument('--rafd_image_path', type=str, default='./data/RaFD/train') - parser.add_argument('--metadata_path', type=str, default='./data/list_attr_celeba.txt') - parser.add_argument('--log_path', type=str, default='./stargan/logs') - parser.add_argument('--model_save_path', type=str, default='./stargan/models') - parser.add_argument('--sample_path', type=str, default='./stargan/samples') - parser.add_argument('--result_path', type=str, default='./stargan/results') + # Directories. + parser.add_argument('--celeba_image_dir', type=str, default='data/CelebA_nocrop/images') + parser.add_argument('--attr_path', type=str, default='data/list_attr_celeba.txt') + parser.add_argument('--rafd_image_dir', type=str, default='data/RaFD/train') + parser.add_argument('--log_dir', type=str, default='stargan/logs') + parser.add_argument('--model_save_dir', type=str, default='stargan/models') + parser.add_argument('--sample_dir', type=str, default='stargan/samples') + parser.add_argument('--result_dir', type=str, default='stargan/results') - # Step size + # Step size. parser.add_argument('--log_step', type=int, default=10) - parser.add_argument('--sample_step', type=int, default=500) - parser.add_argument('--model_save_step', type=int, default=1000) + parser.add_argument('--sample_step', type=int, default=1000) + parser.add_argument('--model_save_step', type=int, default=10000) + parser.add_argument('--lr_update_step', type=int, default=1000) config = parser.parse_args() print(config) diff --git a/model.py b/model.py index f317dbc..ab3f6cf 100644 --- a/model.py +++ b/model.py @@ -5,7 +5,7 @@ import numpy as np class ResidualBlock(nn.Module): - """Residual Block.""" + """Residual Block with instance normalization.""" def __init__(self, dim_in, dim_out): super(ResidualBlock, self).__init__() self.main = nn.Sequential( @@ -20,7 +20,7 @@ class ResidualBlock(nn.Module): class Generator(nn.Module): - """Generator. Encoder-Decoder Architecture.""" + """Generator network.""" def __init__(self, conv_dim=64, c_dim=5, repeat_num=6): super(Generator, self).__init__() @@ -29,7 +29,7 @@ class Generator(nn.Module): layers.append(nn.InstanceNorm2d(conv_dim, affine=True)) layers.append(nn.ReLU(inplace=True)) - # Down-Sampling + # Down-sampling layers. curr_dim = conv_dim for i in range(2): layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1, bias=False)) @@ -37,11 +37,11 @@ class Generator(nn.Module): layers.append(nn.ReLU(inplace=True)) curr_dim = curr_dim * 2 - # Bottleneck + # Bottleneck layers. for i in range(repeat_num): layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim)) - # Up-Sampling + # Up-sampling layers. for i in range(2): layers.append(nn.ConvTranspose2d(curr_dim, curr_dim//2, kernel_size=4, stride=2, padding=1, bias=False)) layers.append(nn.InstanceNorm2d(curr_dim//2, affine=True)) @@ -53,35 +53,34 @@ class Generator(nn.Module): self.main = nn.Sequential(*layers) def forward(self, x, c): - # replicate spatially and concatenate domain information - c = c.unsqueeze(2).unsqueeze(3) - c = c.expand(c.size(0), c.size(1), x.size(2), x.size(3)) + # Replicate spatially and concatenate domain information. + c = c.view(c.size(0), c.size(1), 1, 1) + c = c.repeat(1, 1, x.size(2), x.size(3)) x = torch.cat([x, c], dim=1) return self.main(x) class Discriminator(nn.Module): - """Discriminator. PatchGAN.""" + """Discriminator network with PatchGAN.""" def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6): super(Discriminator, self).__init__() - layers = [] layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1)) - layers.append(nn.LeakyReLU(0.01, inplace=True)) + layers.append(nn.LeakyReLU(0.01)) curr_dim = conv_dim for i in range(1, repeat_num): layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1)) - layers.append(nn.LeakyReLU(0.01, inplace=True)) + layers.append(nn.LeakyReLU(0.01)) curr_dim = curr_dim * 2 - k_size = int(image_size / np.power(2, repeat_num)) + kernel_size = int(image_size / np.power(2, repeat_num)) self.main = nn.Sequential(*layers) self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=3, stride=1, padding=1, bias=False) - self.conv2 = nn.Conv2d(curr_dim, c_dim, kernel_size=k_size, bias=False) - + self.conv2 = nn.Conv2d(curr_dim, c_dim, kernel_size=kernel_size, bias=False) + def forward(self, x): h = self.main(x) - out_real = self.conv1(h) - out_aux = self.conv2(h) - return out_real.squeeze(), out_aux.squeeze() \ No newline at end of file + out_src = self.conv1(h) + out_cls = self.conv2(h) + return out_src, out_cls.view(out_cls.size(0), out_cls.size(1)) \ No newline at end of file diff --git a/solver.py b/solver.py index d615eac..a033514 100644 --- a/solver.py +++ b/solver.py @@ -1,27 +1,26 @@ +from model import Generator +from model import Discriminator +from torch.autograd import Variable +from torchvision.utils import save_image import torch -import torch.nn as nn import torch.nn.functional as F import numpy as np import os import time import datetime -from torch.autograd import grad -from torch.autograd import Variable -from torchvision.utils import save_image -from torchvision import transforms -from model import Generator -from model import Discriminator -from PIL import Image class Solver(object): + """Solver for training and testing StarGAN.""" - def __init__(self, celebA_loader, rafd_loader, config): - # Data loader - self.celebA_loader = celebA_loader + def __init__(self, celeba_loader, rafd_loader, config): + """Initialize configurations.""" + + # Data loader. + self.celeba_loader = celeba_loader self.rafd_loader = rafd_loader - # Model hyper-parameters + # Model configurations. self.c_dim = config.c_dim self.c2_dim = config.c2_dim self.image_size = config.image_size @@ -29,64 +28,58 @@ class Solver(object): self.d_conv_dim = config.d_conv_dim self.g_repeat_num = config.g_repeat_num self.d_repeat_num = config.d_repeat_num - self.d_train_repeat = config.d_train_repeat - - # Hyper-parameteres self.lambda_cls = config.lambda_cls self.lambda_rec = config.lambda_rec self.lambda_gp = config.lambda_gp - self.g_lr = config.g_lr - self.d_lr = config.d_lr - self.beta1 = config.beta1 - self.beta2 = config.beta2 - # Training settings + # Training configurations. self.dataset = config.dataset - self.num_epochs = config.num_epochs - self.num_epochs_decay = config.num_epochs_decay + self.batch_size = config.batch_size self.num_iters = config.num_iters self.num_iters_decay = config.num_iters_decay - self.batch_size = config.batch_size + self.g_lr = config.g_lr + self.d_lr = config.d_lr + self.n_critic = config.n_critic + self.beta1 = config.beta1 + self.beta2 = config.beta2 + self.resume_iters = config.resume_iters + self.selected_attrs = config.selected_attrs + + # Test configurations. + self.test_iters = config.test_iters + + # Miscellaneous. self.use_tensorboard = config.use_tensorboard - self.pretrained_model = config.pretrained_model + self.dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor - # Test settings - self.test_model = config.test_model + # Directories. + self.log_dir = config.log_dir + self.sample_dir = config.sample_dir + self.model_save_dir = config.model_save_dir + self.result_dir = config.result_dir - # Path - self.log_path = config.log_path - self.sample_path = config.sample_path - self.model_save_path = config.model_save_path - self.result_path = config.result_path - - # Step size + # Step size. self.log_step = config.log_step self.sample_step = config.sample_step self.model_save_step = config.model_save_step + self.lr_update_step = config.lr_update_step - # Build tensorboard if use + # Build the model and tensorboard. self.build_model() if self.use_tensorboard: self.build_tensorboard() - # Start with trained model - if self.pretrained_model: - self.load_pretrained_model() - def build_model(self): - # Define a generator and a discriminator - if self.dataset == 'Both': - self.G = Generator(self.g_conv_dim, self.c_dim+self.c2_dim+2, self.g_repeat_num) # 2 for mask vector - self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim+self.c2_dim, self.d_repeat_num) - else: + """Create a generator and a discriminator.""" + if self.dataset in ['CelebA', 'RaFD']: self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num) self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num) + elif self.dataset in ['Both']: + self.G = Generator(self.g_conv_dim, self.c_dim+self.c2_dim+2, self.g_repeat_num) # 2 for mask vector. + self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim+self.c2_dim, self.d_repeat_num) - # Optimizers self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2]) self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2]) - - # Print networks self.print_network(self.G, 'G') self.print_network(self.D, 'D') @@ -95,529 +88,229 @@ class Solver(object): self.D.cuda() def print_network(self, model, name): + """Print out the network information.""" num_params = 0 for p in model.parameters(): num_params += p.numel() - print(name) print(model) + print(name) print("The number of parameters: {}".format(num_params)) - def load_pretrained_model(self): - self.G.load_state_dict(torch.load(os.path.join( - self.model_save_path, '{}_G.pth'.format(self.pretrained_model)))) - self.D.load_state_dict(torch.load(os.path.join( - self.model_save_path, '{}_D.pth'.format(self.pretrained_model)))) - print('loaded trained models (step: {})..!'.format(self.pretrained_model)) + def restore_model(self, resume_iters): + """Restore the trained generator and discriminator.""" + print('Loading the trained models from step {}...'.format(resume_iters)) + G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(resume_iters)) + D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(resume_iters)) + self.G.load_state_dict(torch.load(G_path)) + self.D.load_state_dict(torch.load(D_path)) def build_tensorboard(self): + """Build a tensorboard logger.""" from logger import Logger - self.logger = Logger(self.log_path) + self.logger = Logger(self.log_dir) def update_lr(self, g_lr, d_lr): + """Decay learning rates of the generator and discriminator.""" for param_group in self.g_optimizer.param_groups: param_group['lr'] = g_lr for param_group in self.d_optimizer.param_groups: param_group['lr'] = d_lr def reset_grad(self): + """Reset the gradient buffers.""" self.g_optimizer.zero_grad() self.d_optimizer.zero_grad() - def to_var(self, x, volatile=False): + def tensor2var(self, x, volatile=False): + """Convert torch tensor to variable.""" if torch.cuda.is_available(): x = x.cuda() return Variable(x, volatile=volatile) def denorm(self, x): + """Convert the range from [-1, 1] to [0, 1].""" out = (x + 1) / 2 return out.clamp_(0, 1) - def threshold(self, x): - x = x.clone() - x = (x >= 0.5).float() - return x + def gradient_penalty(self, y, x, dtype): + """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2.""" + weight = torch.ones(y.size()).type(dtype) + dydx = torch.autograd.grad(outputs=y, + inputs=x, + grad_outputs=weight, + retain_graph=True, + create_graph=True, + only_inputs=True)[0] - def compute_accuracy(self, x, y, dataset): - if dataset == 'CelebA': - x = F.sigmoid(x) - predicted = self.threshold(x) - correct = (predicted == y).float() - accuracy = torch.mean(correct, dim=0) * 100.0 - else: - _, predicted = torch.max(x, dim=1) - correct = (predicted == y).float() - accuracy = torch.mean(correct) * 100.0 - return accuracy + dydx = dydx.view(dydx.size(0), -1) + dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1)) + return torch.mean((dydx_l2norm-1)**2) - def one_hot(self, labels, dim): - """Convert label indices to one-hot vector""" + def label2onehot(self, labels, dim): + """Convert label indices to one-hot vectors.""" batch_size = labels.size(0) out = torch.zeros(batch_size, dim) out[np.arange(batch_size), labels.long()] = 1 return out - def make_celeb_labels(self, real_c): - """Generate domain labels for CelebA for debugging/testing. - + def create_labels(self, c_org, c_dim=5, dataset='CelebA', selected_attrs=None): + """Generate target domain labels for debugging and testing.""" + # Get hair color indices. if dataset == 'CelebA': - return single and multiple attribute changes - elif dataset == 'Both': - return single attribute changes - """ - y = [torch.FloatTensor([1, 0, 0]), # black hair - torch.FloatTensor([0, 1, 0]), # blond hair - torch.FloatTensor([0, 0, 1])] # brown hair + hair_color_indices = [] + for i, attr_name in enumerate(selected_attrs): + if attr_name in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']: + hair_color_indices.append(i) - fixed_c_list = [] - - # single attribute transfer - for i in range(self.c_dim): - fixed_c = real_c.clone() - for c in fixed_c: - if i < 3: - c[:3] = y[i] + c_trg_list = [] + for i in range(c_dim): + if dataset == 'CelebA': + c_trg = c_org.clone() + if i in hair_color_indices: # Set one hair color to 1 and the rest to 0. + c_trg[:, i] = 1 + for j in hair_color_indices: + if j != i: + c_trg[:, j] = 0 else: - c[i] = 0 if c[i] == 1 else 1 # opposite value - fixed_c_list.append(self.to_var(fixed_c, volatile=True)) + c_trg[:, i] = (c_trg[:, i] == 0) # Reverse attribute value. + elif dataset == 'RaFD': + c_trg = self.label2onehot(torch.ones(c_org.size(0))*i, c_dim) - # multi-attribute transfer (H+G, H+A, G+A, H+G+A) - if self.dataset == 'CelebA': - for i in range(4): - fixed_c = real_c.clone() - for c in fixed_c: - if i in [0, 1, 3]: # Hair color to brown - c[:3] = y[2] - if i in [0, 2, 3]: # Gender - c[3] = 0 if c[3] == 1 else 1 - if i in [1, 2, 3]: # Aged - c[4] = 0 if c[4] == 1 else 1 - fixed_c_list.append(self.to_var(fixed_c, volatile=True)) - return fixed_c_list + c_trg_list.append(self.tensor2var(c_trg, volatile=True)) + return c_trg_list + + def classification_loss(self, logit, target, dataset='CelebA'): + """Compute binary or softmax cross entropy loss.""" + if dataset == 'CelebA': + return F.binary_cross_entropy_with_logits(logit, target, size_average=False) / logit.size(0) + elif dataset == 'RaFD': + return F.cross_entropy(logit, target) def train(self): """Train StarGAN within a single dataset.""" - - # Set dataloader + # Set data loader. if self.dataset == 'CelebA': - self.data_loader = self.celebA_loader - else: - self.data_loader = self.rafd_loader - - # The number of iterations per epoch - iters_per_epoch = len(self.data_loader) - - fixed_x = [] - real_c = [] - for i, (images, labels) in enumerate(self.data_loader): - fixed_x.append(images) - real_c.append(labels) - if i == 3: - break - - # Fixed inputs and target domain labels for debugging - fixed_x = torch.cat(fixed_x, dim=0) - fixed_x = self.to_var(fixed_x, volatile=True) - real_c = torch.cat(real_c, dim=0) - - if self.dataset == 'CelebA': - fixed_c_list = self.make_celeb_labels(real_c) + data_loader = self.celeba_loader elif self.dataset == 'RaFD': - fixed_c_list = [] - for i in range(self.c_dim): - fixed_c = self.one_hot(torch.ones(fixed_x.size(0)) * i, self.c_dim) - fixed_c_list.append(self.to_var(fixed_c, volatile=True)) + data_loader = self.rafd_loader - # lr cache for decaying + # Fetch fixed inputs for debugging. + data_iter = iter(data_loader) + x_fixed, c_org = next(data_iter) + x_fixed = self.tensor2var(x_fixed, volatile=True) + c_fixed_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs) + + # Learning rate cache for decaying. g_lr = self.g_lr d_lr = self.d_lr - # Start with trained model if exists - if self.pretrained_model: - start = int(self.pretrained_model.split('_')[0]) - else: - start = 0 + # Start training from scratch or resume training. + start_iters = 0 + if self.resume_iters: + start_iters = self.resume_iters + self.restore_model(self.resume_iters) - # Start training + # Start training. + print('Start training...') start_time = time.time() - for e in range(start, self.num_epochs): - for i, (real_x, real_label) in enumerate(self.data_loader): - - # Generat fake labels randomly (target domain labels) - rand_idx = torch.randperm(real_label.size(0)) - fake_label = real_label[rand_idx] + for i in range(start_iters, self.num_iters): - if self.dataset == 'CelebA': - real_c = real_label.clone() - fake_c = fake_label.clone() - else: - real_c = self.one_hot(real_label, self.c_dim) - fake_c = self.one_hot(fake_label, self.c_dim) + # =================================================================================== # + # 1. Preprocess input data # + # =================================================================================== # - # Convert tensor to variable - real_x = self.to_var(real_x) - real_c = self.to_var(real_c) # input for the generator - fake_c = self.to_var(fake_c) - real_label = self.to_var(real_label) # this is same as real_c if dataset == 'CelebA' - fake_label = self.to_var(fake_label) - - # ================== Train D ================== # - - # Compute loss with real images - out_src, out_cls = self.D(real_x) - d_loss_real = - torch.mean(out_src) - - if self.dataset == 'CelebA': - d_loss_cls = F.binary_cross_entropy_with_logits( - out_cls, real_label, size_average=False) / real_x.size(0) - else: - d_loss_cls = F.cross_entropy(out_cls, real_label) - - # Compute classification accuracy of the discriminator - if (i+1) % self.log_step == 0: - accuracies = self.compute_accuracy(out_cls, real_label, self.dataset) - log = ["{:.2f}".format(acc) for acc in accuracies.data.cpu().numpy()] - if self.dataset == 'CelebA': - print('Classification Acc (Black/Blond/Brown/Gender/Aged): ', end='') - else: - print('Classification Acc (8 emotional expressions): ', end='') - print(log) - - # Compute loss with fake images - fake_x = self.G(real_x, fake_c) - fake_x = Variable(fake_x.data) - out_src, out_cls = self.D(fake_x) - d_loss_fake = torch.mean(out_src) - - # Backward + Optimize - d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls - self.reset_grad() - d_loss.backward() - self.d_optimizer.step() - - # Compute gradient penalty - alpha = torch.rand(real_x.size(0), 1, 1, 1).cuda().expand_as(real_x) - interpolated = Variable(alpha * real_x.data + (1 - alpha) * fake_x.data, requires_grad=True) - out, out_cls = self.D(interpolated) - - grad = torch.autograd.grad(outputs=out, - inputs=interpolated, - grad_outputs=torch.ones(out.size()).cuda(), - retain_graph=True, - create_graph=True, - only_inputs=True)[0] - - grad = grad.view(grad.size(0), -1) - grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1)) - d_loss_gp = torch.mean((grad_l2norm - 1)**2) - - # Backward + Optimize - d_loss = self.lambda_gp * d_loss_gp - self.reset_grad() - d_loss.backward() - self.d_optimizer.step() - - # Logging - loss = {} - loss['D/loss_real'] = d_loss_real.data[0] - loss['D/loss_fake'] = d_loss_fake.data[0] - loss['D/loss_cls'] = d_loss_cls.data[0] - loss['D/loss_gp'] = d_loss_gp.data[0] - - # ================== Train G ================== # - if (i+1) % self.d_train_repeat == 0: - - # Original-to-target and target-to-original domain - fake_x = self.G(real_x, fake_c) - rec_x = self.G(fake_x, real_c) - - # Compute losses - out_src, out_cls = self.D(fake_x) - g_loss_fake = - torch.mean(out_src) - g_loss_rec = torch.mean(torch.abs(real_x - rec_x)) - - if self.dataset == 'CelebA': - g_loss_cls = F.binary_cross_entropy_with_logits( - out_cls, fake_label, size_average=False) / fake_x.size(0) - else: - g_loss_cls = F.cross_entropy(out_cls, fake_label) - - # Backward + Optimize - g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls - self.reset_grad() - g_loss.backward() - self.g_optimizer.step() - - # Logging - loss['G/loss_fake'] = g_loss_fake.data[0] - loss['G/loss_rec'] = g_loss_rec.data[0] - loss['G/loss_cls'] = g_loss_cls.data[0] - - # Print out log info - if (i+1) % self.log_step == 0: - elapsed = time.time() - start_time - elapsed = str(datetime.timedelta(seconds=elapsed)) - - log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format( - elapsed, e+1, self.num_epochs, i+1, iters_per_epoch) - - for tag, value in loss.items(): - log += ", {}: {:.4f}".format(tag, value) - print(log) - - if self.use_tensorboard: - for tag, value in loss.items(): - self.logger.scalar_summary(tag, value, e * iters_per_epoch + i + 1) - - # Translate fixed images for debugging - if (i+1) % self.sample_step == 0: - fake_image_list = [fixed_x] - for fixed_c in fixed_c_list: - fake_image_list.append(self.G(fixed_x, fixed_c)) - fake_images = torch.cat(fake_image_list, dim=3) - save_image(self.denorm(fake_images.data.cpu()), - os.path.join(self.sample_path, '{}_{}_fake.png'.format(e+1, i+1)),nrow=1, padding=0) - print('Translated images and saved into {}..!'.format(self.sample_path)) - - # Save model checkpoints - if (i+1) % self.model_save_step == 0: - torch.save(self.G.state_dict(), - os.path.join(self.model_save_path, '{}_{}_G.pth'.format(e+1, i+1))) - torch.save(self.D.state_dict(), - os.path.join(self.model_save_path, '{}_{}_D.pth'.format(e+1, i+1))) - - # Decay learning rate - if (e+1) > (self.num_epochs - self.num_epochs_decay): - g_lr -= (self.g_lr / float(self.num_epochs_decay)) - d_lr -= (self.d_lr / float(self.num_epochs_decay)) - self.update_lr(g_lr, d_lr) - print ('Decay learning rate to g_lr: {}, d_lr: {}.'.format(g_lr, d_lr)) - - def train_multi(self): - """Train StarGAN with multiple datasets. - In the code below, 1 is related to CelebA and 2 is releated to RaFD. - """ - # Fixed imagse and labels for debugging - fixed_x = [] - real_c = [] - - for i, (images, labels) in enumerate(self.celebA_loader): - fixed_x.append(images) - real_c.append(labels) - if i == 2: - break - - fixed_x = torch.cat(fixed_x, dim=0) - fixed_x = self.to_var(fixed_x, volatile=True) - real_c = torch.cat(real_c, dim=0) - fixed_c1_list = self.make_celeb_labels(real_c) - - fixed_c2_list = [] - for i in range(self.c2_dim): - fixed_c = self.one_hot(torch.ones(fixed_x.size(0)) * i, self.c2_dim) - fixed_c2_list.append(self.to_var(fixed_c, volatile=True)) - - fixed_zero1 = self.to_var(torch.zeros(fixed_x.size(0), self.c2_dim)) # zero vector when training with CelebA - fixed_mask1 = self.to_var(self.one_hot(torch.zeros(fixed_x.size(0)), 2)) # mask vector: [1, 0] - fixed_zero2 = self.to_var(torch.zeros(fixed_x.size(0), self.c_dim)) # zero vector when training with RaFD - fixed_mask2 = self.to_var(self.one_hot(torch.ones(fixed_x.size(0)), 2)) # mask vector: [0, 1] - - # lr cache for decaying - g_lr = self.g_lr - d_lr = self.d_lr - - # data iterator - data_iter1 = iter(self.celebA_loader) - data_iter2 = iter(self.rafd_loader) - - # Start with trained model - if self.pretrained_model: - start = int(self.pretrained_model) + 1 - else: - start = 0 - - # # Start training - start_time = time.time() - for i in range(start, self.num_iters): - - # Fetch mini-batch images and labels + # Fetch real images and labels. try: - real_x1, real_label1 = next(data_iter1) + x_real, label_org = next(data_iter) except: - data_iter1 = iter(self.celebA_loader) - real_x1, real_label1 = next(data_iter1) + data_iter = iter(data_loader) + x_real, label_org = next(data_iter) - try: - real_x2, real_label2 = next(data_iter2) - except: - data_iter2 = iter(self.rafd_loader) - real_x2, real_label2 = next(data_iter2) + # Generate target domain labels randomly. + rand_idx = torch.randperm(label_org.size(0)) + label_trg = label_org[rand_idx] - # Generate fake labels randomly (target domain labels) - rand_idx = torch.randperm(real_label1.size(0)) - fake_label1 = real_label1[rand_idx] - rand_idx = torch.randperm(real_label2.size(0)) - fake_label2 = real_label2[rand_idx] + if self.dataset == 'CelebA': + c_org = label_org.clone() + c_trg = label_trg.clone() + elif self.dataset == 'RaFD': + c_org = self.label2onehot(label_org, self.c_dim) + c_trg = self.label2onehot(label_trg, self.c_dim) - real_c1 = real_label1.clone() - fake_c1 = fake_label1.clone() - zero1 = torch.zeros(real_x1.size(0), self.c2_dim) - mask1 = self.one_hot(torch.zeros(real_x1.size(0)), 2) + x_real = self.tensor2var(x_real) # Input images. + c_org = self.tensor2var(c_org) # Original domain labels. + c_trg = self.tensor2var(c_trg) # Target domain labels. + label_org = self.tensor2var(label_org) # Labels for computing classification loss. + label_trg = self.tensor2var(label_trg) # Labels for computing classification loss. - real_c2 = self.one_hot(real_label2, self.c2_dim) - fake_c2 = self.one_hot(fake_label2, self.c2_dim) - zero2 = torch.zeros(real_x2.size(0), self.c_dim) - mask2 = self.one_hot(torch.ones(real_x2.size(0)), 2) + # =================================================================================== # + # 2. Train the discriminator # + # =================================================================================== # - # Convert tensor to variable - real_x1 = self.to_var(real_x1) - real_c1 = self.to_var(real_c1) - fake_c1 = self.to_var(fake_c1) - mask1 = self.to_var(mask1) - zero1 = self.to_var(zero1) + # Compute loss with real images. + out_src, out_cls = self.D(x_real) + d_loss_real = - torch.mean(out_src) + d_loss_cls = self.classification_loss(out_cls, label_org, self.dataset) - real_x2 = self.to_var(real_x2) - real_c2 = self.to_var(real_c2) - fake_c2 = self.to_var(fake_c2) - mask2 = self.to_var(mask2) - zero2 = self.to_var(zero2) + # Compute loss with fake images. + x_fake = self.G(x_real, c_trg) + out_src, out_cls = self.D(x_fake.detach()) + d_loss_fake = torch.mean(out_src) - real_label1 = self.to_var(real_label1) - fake_label1 = self.to_var(fake_label1) - real_label2 = self.to_var(real_label2) - fake_label2 = self.to_var(fake_label2) + # Compute loss for gradient penalty. + alpha = torch.rand(x_real.size(0), 1, 1, 1).type(self.dtype) + x_hat = Variable(alpha * x_real.data + (1 - alpha) * x_fake.data, requires_grad=True) + out_src, _ = self.D(x_hat) + d_loss_gp = self.gradient_penalty(out_src, x_hat, self.dtype) - # ================== Train D ================== # - - # Real images (CelebA) - out_real, out_cls = self.D(real_x1) - out_cls1 = out_cls[:, :self.c_dim] # celebA part - d_loss_real = - torch.mean(out_real) - d_loss_cls = F.binary_cross_entropy_with_logits(out_cls1, real_label1, size_average=False) / real_x1.size(0) - - # Real images (RaFD) - out_real, out_cls = self.D(real_x2) - out_cls2 = out_cls[:, self.c_dim:] # rafd part - d_loss_real += - torch.mean(out_real) - d_loss_cls += F.cross_entropy(out_cls2, real_label2) - - # Compute classification accuracy of the discriminator - if (i+1) % self.log_step == 0: - accuracies = self.compute_accuracy(out_cls1, real_label1, 'CelebA') - log = ["{:.2f}".format(acc) for acc in accuracies.data.cpu().numpy()] - print('Classification Acc (Black/Blond/Brown/Gender/Aged): ', end='') - print(log) - accuracies = self.compute_accuracy(out_cls2, real_label2, 'RaFD') - log = ["{:.2f}".format(acc) for acc in accuracies.data.cpu().numpy()] - print('Classification Acc (8 emotional expressions): ', end='') - print(log) - - # Fake images (CelebA) - fake_c = torch.cat([fake_c1, zero1, mask1], dim=1) - fake_x1 = self.G(real_x1, fake_c) - fake_x1 = Variable(fake_x1.data) - out_fake, _ = self.D(fake_x1) - d_loss_fake = torch.mean(out_fake) - - # Fake images (RaFD) - fake_c = torch.cat([zero2, fake_c2, mask2], dim=1) - fake_x2 = self.G(real_x2, fake_c) - out_fake, _ = self.D(fake_x2) - d_loss_fake += torch.mean(out_fake) - - # Backward + Optimize - d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls + # Backward and optimize. + d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls + self.lambda_gp * d_loss_gp self.reset_grad() d_loss.backward() self.d_optimizer.step() - # Compute gradient penalty - if (i+1) % 2 == 0: - real_x = real_x1 - fake_x = fake_x1 - else: - real_x = real_x2 - fake_x = fake_x2 - - alpha = torch.rand(real_x.size(0), 1, 1, 1).cuda().expand_as(real_x) - interpolated = Variable(alpha * real_x.data + (1 - alpha) * fake_x.data, requires_grad=True) - out, out_cls = self.D(interpolated) - - if (i+1) % 2 == 0: - out_cls = out_cls[:, :self.c_dim] # CelebA - else: - out_cls = out_cls[:, self.c_dim:] # RaFD - - grad = torch.autograd.grad(outputs=out, - inputs=interpolated, - grad_outputs=torch.ones(out.size()).cuda(), - retain_graph=True, - create_graph=True, - only_inputs=True)[0] - - grad = grad.view(grad.size(0), -1) - grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1)) - d_loss_gp = torch.mean((grad_l2norm - 1)**2) - - # Backward + Optimize - d_loss = self.lambda_gp * d_loss_gp - self.reset_grad() - d_loss.backward() - self.d_optimizer.step() - - # Logging + # Logging. loss = {} loss['D/loss_real'] = d_loss_real.data[0] loss['D/loss_fake'] = d_loss_fake.data[0] loss['D/loss_cls'] = d_loss_cls.data[0] loss['D/loss_gp'] = d_loss_gp.data[0] - # ================== Train G ================== # - if (i+1) % self.d_train_repeat == 0: - # Original-to-target and target-to-original domain (CelebA) - fake_c = torch.cat([fake_c1, zero1, mask1], dim=1) - real_c = torch.cat([real_c1, zero1, mask1], dim=1) - fake_x1 = self.G(real_x1, fake_c) - rec_x1 = self.G(fake_x1, real_c) + # =================================================================================== # + # 3. Train the generator # + # =================================================================================== # + + if (i+1) % self.n_critic == 0: + # Original-to-target domain. + x_fake = self.G(x_real, c_trg) + out_src, out_cls = self.D(x_fake) + g_loss_fake = - torch.mean(out_src) + g_loss_cls = self.classification_loss(out_cls, label_trg, self.dataset) - # Compute losses - out, out_cls = self.D(fake_x1) - out_cls1 = out_cls[:, :self.c_dim] - g_loss_fake = - torch.mean(out) - g_loss_rec = torch.mean(torch.abs(real_x1 - rec_x1)) - g_loss_cls = F.binary_cross_entropy_with_logits(out_cls1, fake_label1, size_average=False) / fake_x1.size(0) + # Target-to-original domain. + x_reconst = self.G(x_fake, c_org) + g_loss_rec = torch.mean(torch.abs(x_real - x_reconst)) - # Original-to-target and target-to-original domain (RaFD) - fake_c = torch.cat([zero2, fake_c2, mask2], dim=1) - real_c = torch.cat([zero2, real_c2, mask2], dim=1) - fake_x2 = self.G(real_x2, fake_c) - rec_x2 = self.G(fake_x2, real_c) - - # Compute losses - out, out_cls = self.D(fake_x2) - out_cls2 = out_cls[:, self.c_dim:] - g_loss_fake += - torch.mean(out) - g_loss_rec += torch.mean(torch.abs(real_x2 - rec_x2)) - g_loss_cls += F.cross_entropy(out_cls2, fake_label2) - - # Backward + Optimize - g_loss = g_loss_fake + self.lambda_cls * g_loss_cls + self.lambda_rec * g_loss_rec + # Backward and optimize. + g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls self.reset_grad() g_loss.backward() self.g_optimizer.step() - # Logging + # Logging. loss['G/loss_fake'] = g_loss_fake.data[0] - loss['G/loss_cls'] = g_loss_cls.data[0] loss['G/loss_rec'] = g_loss_rec.data[0] + loss['G/loss_cls'] = g_loss_cls.data[0] - # Print out log info + # =================================================================================== # + # 4. Miscellaneous # + # =================================================================================== # + + # Print out training information. if (i+1) % self.log_step == 0: - elapsed = time.time() - start_time - elapsed = str(datetime.timedelta(seconds=elapsed)) - - log = "Elapsed [{}], Iter [{}/{}]".format( - elapsed, i+1, self.num_iters) - + et = time.time() - start_time + et = str(datetime.timedelta(seconds=et))[:-7] + log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, self.num_iters) for tag, value in loss.items(): log += ", {}: {:.4f}".format(tag, value) print(log) @@ -626,107 +319,267 @@ class Solver(object): for tag, value in loss.items(): self.logger.scalar_summary(tag, value, i+1) - # Translate the images (debugging) + # Translate fixed images for debugging. if (i+1) % self.sample_step == 0: - fake_image_list = [fixed_x] + x_fake_list = [x_fixed] + for c_fixed in c_fixed_list: + x_fake_list.append(self.G(x_fixed, c_fixed)) + x_concat = torch.cat(x_fake_list, dim=3) + sample_path = os.path.join(self.sample_dir, '{}-images.jpg'.format(i+1)) + save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0) + print('Saved real and fake images into {}...'.format(sample_path)) - # Changing hair color, gender, and age - for j in range(self.c_dim): - fake_c = torch.cat([fixed_c1_list[j], fixed_zero1, fixed_mask1], dim=1) - fake_image_list.append(self.G(fixed_x, fake_c)) - # Changing emotional expressions - for j in range(self.c2_dim): - fake_c = torch.cat([fixed_zero2, fixed_c2_list[j], fixed_mask2], dim=1) - fake_image_list.append(self.G(fixed_x, fake_c)) - fake = torch.cat(fake_image_list, dim=3) - - # Save the translated images - save_image(self.denorm(fake.data.cpu()), - os.path.join(self.sample_path, '{}_fake.png'.format(i+1)), nrow=1, padding=0) - - # Save model checkpoints + # Save model checkpoints. if (i+1) % self.model_save_step == 0: - torch.save(self.G.state_dict(), - os.path.join(self.model_save_path, '{}_G.pth'.format(i+1))) - torch.save(self.D.state_dict(), - os.path.join(self.model_save_path, '{}_D.pth'.format(i+1))) + G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i+1)) + D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i+1)) + torch.save(self.G.state_dict(), G_path) + torch.save(self.D.state_dict(), D_path) + print('Saved model checkpoints into {}...'.format(self.model_save_dir)) - # Decay learning rate - decay_step = 1000 - if (i+1) > (self.num_iters - self.num_iters_decay) and (i+1) % decay_step==0: - g_lr -= (self.g_lr / float(self.num_iters_decay) * decay_step) - d_lr -= (self.d_lr / float(self.num_iters_decay) * decay_step) + # Decay learning rates. + if (i+1) % self.lr_update_step == 0 and (i+1) > (self.num_iters - self.num_iters_decay): + g_lr -= (self.g_lr / float(self.num_iters_decay)) + d_lr -= (self.d_lr / float(self.num_iters_decay)) self.update_lr(g_lr, d_lr) - print ('Decay learning rate to g_lr: {}, d_lr: {}.'.format(g_lr, d_lr)) + print ('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr)) + + def train_multi(self): + """Train StarGAN with multiple datasets.""" + # Data iterators. + celeba_iter = iter(self.celeba_loader) + rafd_iter = iter(self.rafd_loader) + + # Fetch fixed inputs for debugging. + x_fixed, c_org = next(celeba_iter) + x_fixed = self.tensor2var(x_fixed, volatile=True) + c_celeba_list = self.create_labels(c_org, self.c_dim, 'CelebA', self.selected_attrs) + c_rafd_list = self.create_labels(c_org, self.c2_dim, 'RaFD') + zero_celeba = self.tensor2var(torch.zeros(x_fixed.size(0), self.c_dim)) # Zero vector for CelebA. + zero_rafd = self.tensor2var(torch.zeros(x_fixed.size(0), self.c2_dim)) # Zero vector for RaFD. + mask_celeba = self.tensor2var(self.label2onehot(torch.zeros(x_fixed.size(0)), 2)) # Mask vector: [1, 0]. + mask_rafd = self.tensor2var(self.label2onehot(torch.ones(x_fixed.size(0)), 2)) # Mask vector: [0, 1]. + + # Learning rate cache for decaying. + g_lr = self.g_lr + d_lr = self.d_lr + + # Start training from scratch or resume training. + start_iters = 0 + if self.resume_iters: + start_iters = self.resume_iters + self.restore_model(self.resume_iters) + + # Start training. + print('Start training...') + start_time = time.time() + for i in range(start_iters, self.num_iters): + for dataset in ['CelebA', 'RaFD']: + + # =================================================================================== # + # 1. Preprocess input data # + # =================================================================================== # + + # Fetch real images and labels. + data_iter = celeba_iter if dataset == 'CelebA' else rafd_iter + + try: + x_real, label_org = next(data_iter) + except: + if dataset == 'CelebA': + celeba_iter = iter(self.celeba_loader) + x_real, label_org = next(celeba_iter) + elif dataset == 'RaFD': + rafd_iter = iter(self.rafd_loader) + x_real, label_org = next(rafd_iter) + + # Generate target domain labels randomly. + rand_idx = torch.randperm(label_org.size(0)) + label_trg = label_org[rand_idx] + + if dataset == 'CelebA': + c_org = label_org.clone() + c_trg = label_trg.clone() + zero = torch.zeros(x_real.size(0), self.c2_dim) + mask = self.label2onehot(torch.zeros(x_real.size(0)), 2) + c_org = torch.cat([c_org, zero, mask], dim=1) + c_trg = torch.cat([c_trg, zero, mask], dim=1) + elif dataset == 'RaFD': + c_org = self.label2onehot(label_org, self.c2_dim) + c_trg = self.label2onehot(label_trg, self.c2_dim) + zero = torch.zeros(x_real.size(0), self.c_dim) + mask = self.label2onehot(torch.ones(x_real.size(0)), 2) + c_org = torch.cat([zero, c_org, mask], dim=1) + c_trg = torch.cat([zero, c_trg, mask], dim=1) + + x_real = self.tensor2var(x_real) # Input images. + c_org = self.tensor2var(c_org) # Original domain labels. + c_trg = self.tensor2var(c_trg) # Target domain labels. + label_org = self.tensor2var(label_org) # Labels for computing classification loss. + label_trg = self.tensor2var(label_trg) # Labels for computing classification loss. + + # =================================================================================== # + # 2. Train the discriminator # + # =================================================================================== # + + # Compute loss with real images. + out_src, out_cls = self.D(x_real) + out_cls = out_cls[:, :self.c_dim] if dataset == 'CelebA' else out_cls[:, self.c_dim:] + d_loss_real = - torch.mean(out_src) + d_loss_cls = self.classification_loss(out_cls, label_org, dataset) + + # Compute loss with fake images. + x_fake = self.G(x_real, c_trg) + out_src, _ = self.D(x_fake.detach()) + d_loss_fake = torch.mean(out_src) + + # Compute loss for gradient penalty. + alpha = torch.rand(x_real.size(0), 1, 1, 1).type(self.dtype) + x_hat = Variable(alpha * x_real.data + (1 - alpha) * x_fake.data, requires_grad=True) + out_src, _ = self.D(x_hat) + d_loss_gp = self.gradient_penalty(out_src, x_hat, self.dtype) + + # Backward and optimize. + d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls + self.lambda_gp * d_loss_gp + self.reset_grad() + d_loss.backward() + self.d_optimizer.step() + + # Logging. + loss = {} + loss['D/loss_real'] = d_loss_real.data[0] + loss['D/loss_fake'] = d_loss_fake.data[0] + loss['D/loss_cls'] = d_loss_cls.data[0] + loss['D/loss_gp'] = d_loss_gp.data[0] + + # =================================================================================== # + # 3. Train the generator # + # =================================================================================== # + + if (i+1) % self.n_critic == 0: + # Original-to-target domain. + x_fake = self.G(x_real, c_trg) + out_src, out_cls = self.D(x_fake) + out_cls = out_cls[:, :self.c_dim] if dataset == 'CelebA' else out_cls[:, self.c_dim:] + g_loss_fake = - torch.mean(out_src) + g_loss_cls = self.classification_loss(out_cls, label_trg, dataset) + + # Target-to-original domain. + x_reconst = self.G(x_fake, c_org) + g_loss_rec = torch.mean(torch.abs(x_real - x_reconst)) + + # Backward and optimize. + g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls + self.reset_grad() + g_loss.backward() + self.g_optimizer.step() + + # Logging. + loss['G/loss_fake'] = g_loss_fake.data[0] + loss['G/loss_rec'] = g_loss_rec.data[0] + loss['G/loss_cls'] = g_loss_cls.data[0] + + # =================================================================================== # + # 4. Miscellaneous # + # =================================================================================== # + + # Print out training info. + if (i+1) % self.log_step == 0: + et = time.time() - start_time + et = str(datetime.timedelta(seconds=et))[:-7] + log = "Elapsed [{}], Iteration [{}/{}], Dataset [{}]".format(et, i+1, self.num_iters, dataset) + for tag, value in loss.items(): + log += ", {}: {:.4f}".format(tag, value) + print(log) + + if self.use_tensorboard: + for tag, value in loss.items(): + self.logger.scalar_summary(tag, value, i+1) + + # Translate fixed images for debugging. + if (i+1) % self.sample_step == 0: + x_fake_list = [x_fixed] + for c_fixed in c_celeba_list: + c_trg = torch.cat([c_fixed, zero_rafd, mask_celeba], dim=1) + x_fake_list.append(self.G(x_fixed, c_trg)) + for c_fixed in c_rafd_list: + c_trg = torch.cat([zero_celeba, c_fixed, mask_rafd], dim=1) + x_fake_list.append(self.G(x_fixed, c_trg)) + x_concat = torch.cat(x_fake_list, dim=3) + sample_path = os.path.join(self.sample_dir, '{}-images.jpg'.format(i+1)) + save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0) + print('Saved real and fake images into {}...'.format(sample_path)) + + # Save model checkpoints. + if (i+1) % self.model_save_step == 0: + G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i+1)) + D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i+1)) + torch.save(self.G.state_dict(), G_path) + torch.save(self.D.state_dict(), D_path) + print('Saved model checkpoints into {}...'.format(self.model_save_dir)) + + # Decay learning rates. + if (i+1) % self.lr_update_step == 0 and (i+1) > (self.num_iters - self.num_iters_decay): + g_lr -= (self.g_lr / float(self.num_iters_decay)) + d_lr -= (self.d_lr / float(self.num_iters_decay)) + self.update_lr(g_lr, d_lr) + print ('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr)) def test(self): - """Facial attribute transfer on CelebA or facial expression synthesis on RaFD.""" - # Load trained parameters - G_path = os.path.join(self.model_save_path, '{}_G.pth'.format(self.test_model)) - self.G.load_state_dict(torch.load(G_path)) - self.G.eval() - + """Translate images using StarGAN trained on a single dataset.""" + # Load the trained generator. + self.restore_model(self.test_iters) + + # Set data loader. if self.dataset == 'CelebA': - data_loader = self.celebA_loader - else: + data_loader = self.celeba_loader + elif self.dataset == 'RaFD': data_loader = self.rafd_loader - for i, (real_x, org_c) in enumerate(data_loader): - real_x = self.to_var(real_x, volatile=True) - - if self.dataset == 'CelebA': - target_c_list = self.make_celeb_labels(org_c) - else: - target_c_list = [] - for j in range(self.c_dim): - target_c = self.one_hot(torch.ones(real_x.size(0)) * j, self.c_dim) - target_c_list.append(self.to_var(target_c, volatile=True)) - - # Start translations - fake_image_list = [real_x] - for target_c in target_c_list: - fake_image_list.append(self.G(real_x, target_c)) - fake_images = torch.cat(fake_image_list, dim=3) - save_path = os.path.join(self.result_path, '{}_fake.png'.format(i+1)) - save_image(self.denorm(fake_images.data), save_path, nrow=1, padding=0) - print('Translated test images and saved into "{}"..!'.format(save_path)) + for i, (x_real, c_org) in enumerate(data_loader): + + # Prepare input images and target domain labels. + x_real = self.tensor2var(x_real, volatile=True) + c_trg_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs) + + # Translate images. + x_fake_list = [x_real] + for c_trg in c_trg_list: + x_fake_list.append(self.G(x_real, c_trg)) + + # Save the translated images. + x_concat = torch.cat(x_fake_list, dim=3) + result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1)) + save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0) + print('Saved real and fake images into {}...'.format(result_path)) def test_multi(self): - """Facial attribute transfer and expression synthesis on CelebA.""" - # Load trained parameters - G_path = os.path.join(self.model_save_path, '{}_G.pth'.format(self.test_model)) - self.G.load_state_dict(torch.load(G_path)) - self.G.eval() + """Translate images using StarGAN trained on multiple datasets.""" + # Load the trained generator. + self.restore_model(self.test_iters) - for i, (real_x, org_c) in enumerate(self.celebA_loader): + for i, (x_real, c_org) in enumerate(self.celeba_loader): - # Prepare input images and target domain labels - real_x = self.to_var(real_x, volatile=True) - target_c1_list = self.make_celeb_labels(org_c) - target_c2_list = [] - for j in range(self.c2_dim): - target_c = self.one_hot(torch.ones(real_x.size(0)) * j, self.c2_dim) - target_c2_list.append(self.to_var(target_c, volatile=True)) + # Prepare input images and target domain labels. + x_real = self.tensor2var(x_real, volatile=True) + c_celeba_list = self.create_labels(c_org, self.c_dim, 'CelebA', self.selected_attrs) + c_rafd_list = self.create_labels(c_org, self.c2_dim, 'RaFD') + zero_celeba = self.tensor2var(torch.zeros(x_real.size(0), self.c_dim)) # Zero vector for CelebA. + zero_rafd = self.tensor2var(torch.zeros(x_real.size(0), self.c2_dim)) # Zero vector for RaFD. + mask_celeba = self.tensor2var(self.label2onehot(torch.zeros(x_real.size(0)), 2)) # Mask vector: [1, 0]. + mask_rafd = self.tensor2var(self.label2onehot(torch.ones(x_real.size(0)), 2)) # Mask vector: [0, 1]. - # Zero vectors and mask vectors - zero1 = self.to_var(torch.zeros(real_x.size(0), self.c2_dim)) # zero vector for rafd expressions - mask1 = self.to_var(self.one_hot(torch.zeros(real_x.size(0)), 2)) # mask vector: [1, 0] - zero2 = self.to_var(torch.zeros(real_x.size(0), self.c_dim)) # zero vector for celebA attributes - mask2 = self.to_var(self.one_hot(torch.ones(real_x.size(0)), 2)) # mask vector: [0, 1] + # Translate images. + x_fake_list = [x_real] + for c_celeba in c_celeba_list: + c_trg = torch.cat([c_celeba, zero_rafd, mask_celeba], dim=1) + x_fake_list.append(self.G(x_real, c_trg)) + for c_rafd in c_rafd_list: + c_trg = torch.cat([zero_celeba, c_rafd, mask_rafd], dim=1) + x_fake_list.append(self.G(x_real, c_trg)) - # Changing hair color, gender, and age - fake_image_list = [real_x] - for j in range(self.c_dim): - target_c = torch.cat([target_c1_list[j], zero1, mask1], dim=1) - fake_image_list.append(self.G(real_x, target_c)) - - # Changing emotional expressions - for j in range(self.c2_dim): - target_c = torch.cat([zero2, target_c2_list[j], mask2], dim=1) - fake_image_list.append(self.G(real_x, target_c)) - fake_images = torch.cat(fake_image_list, dim=3) - - # Save the translated images - save_path = os.path.join(self.result_path, '{}_fake.png'.format(i+1)) - save_image(self.denorm(fake_images.data), save_path, nrow=1, padding=0) - print('Translated test images and saved into "{}"..!'.format(save_path)) + # Save the translated images. + x_concat = torch.cat(x_fake_list, dim=3) + result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1)) + save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0) + print('Saved real and fake images into {}...'.format(result_path)) \ No newline at end of file