Arquivos
2020-04-30 21:23:26 +01:00

744 linhas
36 KiB
Python

from model import Generator
from model import Discriminator
from torch.autograd import Variable
from torchvision.utils import save_image
import torch
import torch.nn.functional as F
import numpy as np
import os
import time
import datetime
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
import pandas as pd
class LabelTransform(object):
def __init__(self, config):
self.image_dir = config.affectnet_image_dir
self.affectnet_emo_descr = config.affectnet_emo_descr
self.pca_variant = config.pca_variant
self.ss = StandardScaler()
self.pca = PCA(config.pca_n_components)
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if self.affectnet_emo_descr in ['64d_reg', '64d_cls']:
actvs_train = np.array([[float(x) for x in line.rstrip().split()] for line in open(os.path.join(self.image_dir, self.affectnet_emo_descr, 'train', 'predictions.txt'), 'r')])
actvs_latent_train = self.pca.fit_transform(self.ss.fit_transform(actvs_train))
self.stats = pd.DataFrame(actvs_latent_train).describe(percentiles=[0.1, 0.5, 0.9]).T.drop(columns=['count']) #components in rows
def create_labels(self, c_org):
c_trg_list = []
c_org_latent = self.pca.transform(self.ss.transform(c_org))
if self.pca_variant == 'quantiles':
for i in range(self.pca.n_components_):
transformed_labels = c_org_latent
transformed_labels[:,i] = self.stats.loc[i,'10%']
c_trg_list.append(torch.FloatTensor(self.ss.inverse_transform(self.pca.inverse_transform(transformed_labels))).to(self.device))
transformed_labels = c_org_latent
transformed_labels[:,i] = self.stats.loc[i,'50%']
c_trg_list.append(torch.FloatTensor(self.ss.inverse_transform(self.pca.inverse_transform(transformed_labels))).to(self.device))
transformed_labels = c_org_latent
transformed_labels[:,i] = self.stats.loc[i,'90%']
c_trg_list.append(torch.FloatTensor(self.ss.inverse_transform(self.pca.inverse_transform(transformed_labels))).to(self.device))
else:
for i in range(1, self.pca.n_components_):
transformed_labels = c_org_latent
transformed_labels[:,i:] = 0
c_trg_list.append(torch.FloatTensor(self.ss.inverse_transform(self.pca.inverse_transform(transformed_labels))).to(self.device))
c_trg_list.append(torch.FloatTensor(c_org).to(self.device))
return c_trg_list
class Solver(object):
"""Solver for training and testing StarGAN."""
def __init__(self, celeba_loader, rafd_loader, affectnet_loader, config):
"""Initialize configurations."""
# Data loader.
self.celeba_loader = celeba_loader
self.rafd_loader = rafd_loader
self.affectnet_loader = affectnet_loader
# Model configurations.
self.c_dim = config.c_dim
self.c2_dim = config.c2_dim
self.image_size = config.image_size
self.g_conv_dim = config.g_conv_dim
self.d_conv_dim = config.d_conv_dim
self.g_repeat_num = config.g_repeat_num
self.d_repeat_num = config.d_repeat_num
self.lambda_cls = config.lambda_cls
self.lambda_rec = config.lambda_rec
self.lambda_gp = config.lambda_gp
self.lambda_cls2 = config.lambda_cls2
# Training configurations.
self.dataset = config.dataset
self.batch_size = config.batch_size
self.num_iters = config.num_iters
self.num_iters_decay = config.num_iters_decay
self.g_lr = config.g_lr
self.d_lr = config.d_lr
self.n_critic = config.n_critic
self.beta1 = config.beta1
self.beta2 = config.beta2
self.resume_iters = config.resume_iters
self.selected_attrs = config.selected_attrs
self.affectnet_emo_descr = config.affectnet_emo_descr
self.use_ccc = config.use_ccc
self.depth_concat = config.depth_concat
self.d_loss_cls_type = config.d_loss_cls_type
# Test configurations.
self.test_iters = config.test_iters
# Miscellaneous.
self.use_tensorboard = config.use_tensorboard
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if self.affectnet_emo_descr in ['64d_reg', '64d_cls']:
self.label_transform = LabelTransform(config)
# Directories.
self.log_dir = config.log_dir
self.sample_dir = config.sample_dir
self.model_save_dir = config.model_save_dir
self.result_dir = config.result_dir
# Step size.
self.log_step = config.log_step
self.sample_step = config.sample_step
self.model_save_step = config.model_save_step
self.lr_update_step = config.lr_update_step
# Build the model and tensorboard.
self.build_model()
if self.use_tensorboard:
self.build_tensorboard()
def build_model(self):
"""Create a generator and a discriminator."""
if self.dataset in ['CelebA', 'RaFD', 'AffectNet']:
self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num, self.depth_concat)
self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num, self.affectnet_emo_descr, self.d_loss_cls_type)
self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])
self.print_network(self.G, 'G')
self.print_network(self.D, 'D')
self.G.to(self.device)
self.D.to(self.device)
def print_network(self, model, name):
"""Print out the network information."""
num_params = 0
for p in model.parameters():
num_params += p.numel()
print(model)
print(name)
print("The number of parameters: {}".format(num_params))
def restore_model(self, resume_iters):
"""Restore the trained generator and discriminator."""
print('Loading the trained models from step {}...'.format(resume_iters))
G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(resume_iters))
D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(resume_iters))
self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))
def build_tensorboard(self):
"""Build a tensorboard logger."""
from logger import Logger
self.logger = Logger(self.log_dir)
def update_lr(self, g_lr, d_lr):
"""Decay learning rates of the generator and discriminator."""
for param_group in self.g_optimizer.param_groups:
param_group['lr'] = g_lr
for param_group in self.d_optimizer.param_groups:
param_group['lr'] = d_lr
def reset_grad(self):
"""Reset the gradient buffers."""
self.g_optimizer.zero_grad()
self.d_optimizer.zero_grad()
def denorm(self, x):
"""Convert the range from [-1, 1] to [0, 1]."""
out = (x + 1) / 2
return out.clamp_(0, 1)
def gradient_penalty(self, y, x):
"""Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
weight = torch.ones(y.size()).to(self.device)
dydx = torch.autograd.grad(outputs=y,
inputs=x,
grad_outputs=weight,
retain_graph=True,
create_graph=True,
only_inputs=True)[0]
dydx = dydx.view(dydx.size(0), -1)
dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
return torch.mean((dydx_l2norm-1)**2)
def label2onehot(self, labels, dim):
"""Convert label indices to one-hot vectors."""
batch_size = labels.size(0)
out = torch.zeros(batch_size, dim)
out[np.arange(batch_size), labels.long()] = 1
return out
def create_labels(self, c_org, c_dim=5, dataset='CelebA', selected_attrs=None):
"""Generate target domain labels for debugging and testing."""
if dataset in ['CelebA', 'RaFD']:
# Get hair color indices.
if dataset == 'CelebA':
hair_color_indices = []
for i, attr_name in enumerate(selected_attrs):
if attr_name in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']:
hair_color_indices.append(i)
c_trg_list = []
for i in range(c_dim):
if dataset == 'CelebA':
c_trg = c_org.clone()
if i in hair_color_indices: # Set one hair color to 1 and the rest to 0.
c_trg[:, i] = 1
for j in hair_color_indices:
if j != i:
c_trg[:, j] = 0
else:
c_trg[:, i] = (c_trg[:, i] == 0) # Reverse attribute value.
elif dataset == 'RaFD':
c_trg = self.label2onehot(torch.ones(c_org.size(0))*i, c_dim)
c_trg_list.append(c_trg.to(self.device))
elif dataset == 'AffectNet':
c_trg_list = []
if self.affectnet_emo_descr == 'emotiw':
for i in range(c_dim):
c_trg = self.label2onehot(torch.ones(c_org.size(0))*i, c_dim)
c_trg_list.append(c_trg.to(self.device))
elif self.affectnet_emo_descr == 'va':
for v in [-1., -0.5, 0., 0.5, 1.]:
for a in [-1., -0.5, 0., 0.5, 1.]:
c_trg = torch.Tensor([v, a]).repeat(c_org.size(0), 1)
c_trg_list.append(c_trg.to(self.device))
elif self.affectnet_emo_descr in ['64d_cls']:
# c_trg_list = np.loadtxt('/vol/bitbucket/apg416/affectnet/vgg2pcs.txt')
# c_trg_list = torch.FloatTensor(c_trg_list).to(self.device)
# c_trg_list = [label_vec.repeat(c_org.size(0), 1) for label_vec in torch.unbind(c_trg_list)]
c_trg_list = self.label_transform.create_labels(c_org[:,1:])
elif self.affectnet_emo_descr in ['64d_reg']:
c_trg_list = self.label_transform.create_labels(c_org[:,2:])
return c_trg_list
def ccc_loss(self, x, y):
mu_x = x.mean(0)
var_x = x.var(0)
mu_y = y.mean(0)
var_y = y.var(0)
cov_xy = torch.mean((x - mu_x.unsqueeze(0))*(y - mu_y.unsqueeze(0)), 0)
ccc = 2*cov_xy/(var_x + var_y + (mu_x - mu_y)**2)
loss = 1 - ccc
return loss.sum()
def classification_loss(self, logit, target, dataset='CelebA'):
"""Compute binary or softmax cross entropy loss."""
if dataset == 'CelebA':
return F.binary_cross_entropy_with_logits(logit, target, reduction='sum') / logit.size(0)
elif dataset == 'RaFD':
return F.cross_entropy(logit, target)
elif dataset == 'AffectNet':
if self.affectnet_emo_descr in ['emotiw']:
return F.cross_entropy(logit, target)
elif self.affectnet_emo_descr in ['va']:
if self.use_ccc:
return self.ccc_loss(logit, target)
else:
return F.mse_loss(logit, target)
elif self.affectnet_emo_descr in ['64d_reg']:
logit_pred, logit_actv = logit
target_pred, target_actv = target
out = 0
if logit_pred is not None:
if self.use_ccc:
out += self.lambda_cls2/self.lambda_cls*self.ccc_loss(logit_pred, target_pred)
else:
out += self.lambda_cls2/self.lambda_cls*F.mse_loss(logit_pred, target_pred)
if logit_actv is not None:
if self.use_ccc:
out += self.ccc_loss(logit_actv, target_actv)
else:
out += F.mse_loss(logit_actv, target_actv)
return out
elif self.affectnet_emo_descr in ['64d_cls']:
logit_pred, logit_actv = logit
target_pred, target_actv = target
out = 0
if logit_pred is not None:
out += self.lambda_cls2/self.lambda_cls*F.cross_entropy(logit_pred, target_pred)
if logit_actv is not None:
if self.use_ccc:
out += self.ccc_loss(logit_actv, target_actv)
else:
out += F.mse_loss(logit_actv, target_actv)
return out
def train(self):
"""Train StarGAN within a single dataset."""
# Set data loader.
if self.dataset == 'CelebA':
data_loader = self.celeba_loader
elif self.dataset == 'RaFD':
data_loader = self.rafd_loader
elif self.dataset == 'AffectNet':
data_loader = self.affectnet_loader
# Fetch fixed inputs for debugging.
data_iter = iter(data_loader)
x_fixed, c_org = next(data_iter)
x_fixed = x_fixed.to(self.device)
c_fixed_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs)
# Learning rate cache for decaying.
g_lr = self.g_lr
d_lr = self.d_lr
# Start training from scratch or resume training.
start_iters = 0
if self.resume_iters:
start_iters = self.resume_iters
self.restore_model(self.resume_iters)
# Start training.
print('Start training...')
start_time = time.time()
for i in range(start_iters, self.num_iters):
# =================================================================================== #
# 1. Preprocess input data #
# =================================================================================== #
# Fetch real images and labels.
try:
x_real, label_org = next(data_iter)
except:
data_iter = iter(data_loader)
x_real, label_org = next(data_iter)
# Generate target domain labels randomly.
rand_idx = torch.randperm(label_org.size(0))
label_trg = label_org[rand_idx]
if self.dataset == 'CelebA':
c_org = label_org.clone().to(self.device)
c_trg = label_trg.clone().to(self.device)
label_org = label_org.to(self.device) # Labels for computing classification loss.
label_trg = label_trg.to(self.device) # Labels for computing classification loss.
elif self.dataset == 'RaFD':
c_org = self.label2onehot(label_org, self.c_dim).to(self.device)
c_trg = self.label2onehot(label_trg, self.c_dim).to(self.device)
label_org = label_org.to(self.device) # Labels for computing classification loss.
label_trg = label_trg.to(self.device) # Labels for computing classification loss.
elif self.dataset == 'AffectNet':
if self.affectnet_emo_descr == 'emotiw':
c_org = self.label2onehot(label_org, self.c_dim).to(self.device)
c_trg = self.label2onehot(label_trg, self.c_dim).to(self.device)
label_org = label_org.to(self.device) # Labels for computing classification loss.
label_trg = label_trg.to(self.device) # Labels for computing classification loss.
elif self.affectnet_emo_descr == 'va':
c_org = label_org.clone().to(self.device)
c_trg = label_trg.clone().to(self.device)
label_org = label_org.to(self.device) # Labels for computing classification loss.
label_trg = label_trg.to(self.device) # Labels for computing classification loss.
elif self.affectnet_emo_descr == '64d_cls':
pred_org = label_org[:,0].clone().long().to(self.device)
pred_trg = label_trg[:,0].clone().long().to(self.device)
actv_org = label_org[:,1:].to(self.device)
actv_trg = label_trg[:,1:].to(self.device)
c_org = actv_org.clone().to(self.device)
c_trg = actv_trg.clone().to(self.device)
elif self.affectnet_emo_descr == '64d_reg':
pred_org = label_org[:,:2].clone().float().to(self.device)
pred_trg = label_trg[:,:2].clone().float().to(self.device)
actv_org = label_org[:,2:].to(self.device)
actv_trg = label_trg[:,2:].to(self.device)
c_org = actv_org.clone().to(self.device)
c_trg = actv_trg.clone().to(self.device)
x_real = x_real.to(self.device) # Input images.
# =================================================================================== #
# 2. Train the discriminator #
# =================================================================================== #
# Compute loss with real images.
out_src, out_cls = self.D(x_real)
d_loss_real = - torch.mean(out_src)
if self.affectnet_emo_descr in ['64d_reg', '64d_cls']:
if self.d_loss_cls_type == 'actv':
d_loss_cls = self.classification_loss((None, out_cls), (None, actv_org), self.dataset)
elif self.d_loss_cls_type == 'pred':
d_loss_cls = self.classification_loss((self.D.fc(out_cls), None), (pred_org, None), self.dataset)
else:
d_loss_cls = self.classification_loss((self.D.fc(out_cls), out_cls), (pred_org, actv_org), self.dataset)
else:
d_loss_cls = self.classification_loss(out_cls, label_org, self.dataset)
# Compute loss with fake images.
x_fake = self.G(x_real, c_trg)
out_src, out_cls = self.D(x_fake.detach())
d_loss_fake = torch.mean(out_src)
# Compute loss for gradient penalty.
alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True)
out_src, _ = self.D(x_hat)
d_loss_gp = self.gradient_penalty(out_src, x_hat)
# Backward and optimize.
d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls + self.lambda_gp * d_loss_gp
self.reset_grad()
d_loss.backward()
self.d_optimizer.step()
# Logging.
loss = {}
loss['D/loss_real'] = d_loss_real.item()
loss['D/loss_fake'] = d_loss_fake.item()
loss['D/loss_cls'] = d_loss_cls.item()
loss['D/loss_gp'] = d_loss_gp.item()
# =================================================================================== #
# 3. Train the generator #
# =================================================================================== #
if (i+1) % self.n_critic == 0:
# Original-to-target domain.
x_fake = self.G(x_real, c_trg)
out_src, out_cls = self.D(x_fake)
g_loss_fake = - torch.mean(out_src)
if self.affectnet_emo_descr in ['64d_reg', '64d_cls']:
if self.d_loss_cls_type == 'actv':
g_loss_cls = self.classification_loss((None, out_cls), (None, actv_trg), self.dataset)
elif self.d_loss_cls_type == 'pred':
g_loss_cls = self.classification_loss((self.D.fc(out_cls), None), (pred_trg, None), self.dataset)
else:
g_loss_cls = self.classification_loss((self.D.fc(out_cls), out_cls), (pred_trg, actv_trg), self.dataset)
else:
g_loss_cls = self.classification_loss(out_cls, label_trg, self.dataset)
# Target-to-original domain.
x_reconst = self.G(x_fake, c_org)
g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))
# Backward and optimize.
g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
self.reset_grad()
g_loss.backward()
self.g_optimizer.step()
# Logging.
loss['G/loss_fake'] = g_loss_fake.item()
loss['G/loss_rec'] = g_loss_rec.item()
loss['G/loss_cls'] = g_loss_cls.item()
# =================================================================================== #
# 4. Miscellaneous #
# =================================================================================== #
# Print out training information.
if (i+1) % self.log_step == 0:
et = time.time() - start_time
et = str(datetime.timedelta(seconds=et))[:-7]
log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, self.num_iters)
for tag, value in loss.items():
log += ", {}: {:.4f}".format(tag, value)
print(log)
if self.use_tensorboard:
for tag, value in loss.items():
self.logger.scalar_summary(tag, value, i+1)
# Translate fixed images for debugging.
if (i+1) % self.sample_step == 0:
with torch.no_grad():
x_fake_list = [x_fixed]
for c_fixed in c_fixed_list:
x_fake_list.append(self.G(x_fixed, c_fixed))
x_concat = torch.cat(x_fake_list, dim=3)
sample_path = os.path.join(self.sample_dir, '{}-images.jpg'.format(i+1))
save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)
print('Saved real and fake images into {}...'.format(sample_path))
# Save model checkpoints.
if (i+1) % self.model_save_step == 0:
G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i+1))
D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i+1))
torch.save(self.G.state_dict(), G_path)
torch.save(self.D.state_dict(), D_path)
print('Saved model checkpoints into {}...'.format(self.model_save_dir))
# Decay learning rates.
if (i+1) % self.lr_update_step == 0 and (i+1) > (self.num_iters - self.num_iters_decay):
g_lr -= (self.g_lr / float(self.num_iters_decay))
d_lr -= (self.d_lr / float(self.num_iters_decay))
self.update_lr(g_lr, d_lr)
print ('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))
# def train_multi(self):
# """Train StarGAN with multiple datasets."""
# # Data iterators.
# celeba_iter = iter(self.celeba_loader)
# rafd_iter = iter(self.rafd_loader)
#
# # Fetch fixed inputs for debugging.
# x_fixed, c_org = next(celeba_iter)
# x_fixed = x_fixed.to(self.device)
# c_celeba_list = self.create_labels(c_org, self.c_dim, 'CelebA', self.selected_attrs)
# c_rafd_list = self.create_labels(c_org, self.c2_dim, 'RaFD')
# zero_celeba = torch.zeros(x_fixed.size(0), self.c_dim).to(self.device) # Zero vector for CelebA.
# zero_rafd = torch.zeros(x_fixed.size(0), self.c2_dim).to(self.device) # Zero vector for RaFD.
# mask_celeba = self.label2onehot(torch.zeros(x_fixed.size(0)), 2).to(self.device) # Mask vector: [1, 0].
# mask_rafd = self.label2onehot(torch.ones(x_fixed.size(0)), 2).to(self.device) # Mask vector: [0, 1].
#
# # Learning rate cache for decaying.
# g_lr = self.g_lr
# d_lr = self.d_lr
#
# # Start training from scratch or resume training.
# start_iters = 0
# if self.resume_iters:
# start_iters = self.resume_iters
# self.restore_model(self.resume_iters)
#
# # Start training.
# print('Start training...')
# start_time = time.time()
# for i in range(start_iters, self.num_iters):
# for dataset in ['CelebA', 'RaFD']:
#
# # =================================================================================== #
# # 1. Preprocess input data #
# # =================================================================================== #
#
# # Fetch real images and labels.
# data_iter = celeba_iter if dataset == 'CelebA' else rafd_iter
#
# try:
# x_real, label_org = next(data_iter)
# except:
# if dataset == 'CelebA':
# celeba_iter = iter(self.celeba_loader)
# x_real, label_org = next(celeba_iter)
# elif dataset == 'RaFD':
# rafd_iter = iter(self.rafd_loader)
# x_real, label_org = next(rafd_iter)
#
# # Generate target domain labels randomly.
# rand_idx = torch.randperm(label_org.size(0))
# label_trg = label_org[rand_idx]
#
# if dataset == 'CelebA':
# c_org = label_org.clone()
# c_trg = label_trg.clone()
# zero = torch.zeros(x_real.size(0), self.c2_dim)
# mask = self.label2onehot(torch.zeros(x_real.size(0)), 2)
# c_org = torch.cat([c_org, zero, mask], dim=1)
# c_trg = torch.cat([c_trg, zero, mask], dim=1)
# elif dataset == 'RaFD':
# c_org = self.label2onehot(label_org, self.c2_dim)
# c_trg = self.label2onehot(label_trg, self.c2_dim)
# zero = torch.zeros(x_real.size(0), self.c_dim)
# mask = self.label2onehot(torch.ones(x_real.size(0)), 2)
# c_org = torch.cat([zero, c_org, mask], dim=1)
# c_trg = torch.cat([zero, c_trg, mask], dim=1)
#
# x_real = x_real.to(self.device) # Input images.
# c_org = c_org.to(self.device) # Original domain labels.
# c_trg = c_trg.to(self.device) # Target domain labels.
# label_org = label_org.to(self.device) # Labels for computing classification loss.
# label_trg = label_trg.to(self.device) # Labels for computing classification loss.
#
# # =================================================================================== #
# # 2. Train the discriminator #
# # =================================================================================== #
#
# # Compute loss with real images.
# out_src, out_cls = self.D(x_real)
# out_cls = out_cls[:, :self.c_dim] if dataset == 'CelebA' else out_cls[:, self.c_dim:]
# d_loss_real = - torch.mean(out_src)
# d_loss_cls = self.classification_loss(out_cls, label_org, dataset)
#
# # Compute loss with fake images.
# x_fake = self.G(x_real, c_trg)
# out_src, _ = self.D(x_fake.detach())
# d_loss_fake = torch.mean(out_src)
#
# # Compute loss for gradient penalty.
# alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
# x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True)
# out_src, _ = self.D(x_hat)
# d_loss_gp = self.gradient_penalty(out_src, x_hat)
#
# # Backward and optimize.
# d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls + self.lambda_gp * d_loss_gp
# self.reset_grad()
# d_loss.backward()
# self.d_optimizer.step()
#
# # Logging.
# loss = {}
# loss['D/loss_real'] = d_loss_real.item()
# loss['D/loss_fake'] = d_loss_fake.item()
# loss['D/loss_cls'] = d_loss_cls.item()
# loss['D/loss_gp'] = d_loss_gp.item()
#
# # =================================================================================== #
# # 3. Train the generator #
# # =================================================================================== #
#
# if (i+1) % self.n_critic == 0:
# # Original-to-target domain.
# x_fake = self.G(x_real, c_trg)
# out_src, out_cls = self.D(x_fake)
# out_cls = out_cls[:, :self.c_dim] if dataset == 'CelebA' else out_cls[:, self.c_dim:]
# g_loss_fake = - torch.mean(out_src)
# g_loss_cls = self.classification_loss(out_cls, label_trg, dataset)
#
# # Target-to-original domain.
# x_reconst = self.G(x_fake, c_org)
# g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))
#
# # Backward and optimize.
# g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
# self.reset_grad()
# g_loss.backward()
# self.g_optimizer.step()
#
# # Logging.
# loss['G/loss_fake'] = g_loss_fake.item()
# loss['G/loss_rec'] = g_loss_rec.item()
# loss['G/loss_cls'] = g_loss_cls.item()
#
# # =================================================================================== #
# # 4. Miscellaneous #
# # =================================================================================== #
#
# # Print out training info.
# if (i+1) % self.log_step == 0:
# et = time.time() - start_time
# et = str(datetime.timedelta(seconds=et))[:-7]
# log = "Elapsed [{}], Iteration [{}/{}], Dataset [{}]".format(et, i+1, self.num_iters, dataset)
# for tag, value in loss.items():
# log += ", {}: {:.4f}".format(tag, value)
# print(log)
#
# if self.use_tensorboard:
# for tag, value in loss.items():
# self.logger.scalar_summary(tag, value, i+1)
#
# # Translate fixed images for debugging.
# if (i+1) % self.sample_step == 0:
# with torch.no_grad():
# x_fake_list = [x_fixed]
# for c_fixed in c_celeba_list:
# c_trg = torch.cat([c_fixed, zero_rafd, mask_celeba], dim=1)
# x_fake_list.append(self.G(x_fixed, c_trg))
# for c_fixed in c_rafd_list:
# c_trg = torch.cat([zero_celeba, c_fixed, mask_rafd], dim=1)
# x_fake_list.append(self.G(x_fixed, c_trg))
# x_concat = torch.cat(x_fake_list, dim=3)
# sample_path = os.path.join(self.sample_dir, '{}-images.jpg'.format(i+1))
# save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)
# print('Saved real and fake images into {}...'.format(sample_path))
#
# # Save model checkpoints.
# if (i+1) % self.model_save_step == 0:
# G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i+1))
# D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i+1))
# torch.save(self.G.state_dict(), G_path)
# torch.save(self.D.state_dict(), D_path)
# print('Saved model checkpoints into {}...'.format(self.model_save_dir))
#
# # Decay learning rates.
# if (i+1) % self.lr_update_step == 0 and (i+1) > (self.num_iters - self.num_iters_decay):
# g_lr -= (self.g_lr / float(self.num_iters_decay))
# d_lr -= (self.d_lr / float(self.num_iters_decay))
# self.update_lr(g_lr, d_lr)
# print ('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))
def test(self):
"""Translate images using StarGAN trained on a single dataset."""
# Load the trained generator.
self.restore_model(self.test_iters)
# Set data loader.
if self.dataset == 'CelebA':
data_loader = self.celeba_loader
elif self.dataset == 'RaFD':
data_loader = self.rafd_loader
elif self.dataset == 'AffectNet':
data_loader = self.affectnet_loader
with torch.no_grad():
for i, (x_real, c_org) in enumerate(data_loader):
# Prepare input images and target domain labels.
x_real = x_real.to(self.device)
c_trg_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs)
# Translate images.
x_fake_list = [x_real]
for c_trg in c_trg_list:
x_fake_list.append(self.G(x_real, c_trg))
# Save the translated images.
x_concat = torch.cat(x_fake_list, dim=3)
result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1))
# result_path = os.path.join(self.result_dir, '{}.jpg'.format(self.model_save_dir.replace("/vol/bitbucket/apg416/","").replace("/models","")))
save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
print('Saved real and fake images into {}...'.format(result_path))
# def test_multi(self):
# """Translate images using StarGAN trained on multiple datasets."""
# # Load the trained generator.
# self.restore_model(self.test_iters)
#
# with torch.no_grad():
# for i, (x_real, c_org) in enumerate(self.celeba_loader):
#
# # Prepare input images and target domain labels.
# x_real = x_real.to(self.device)
# c_celeba_list = self.create_labels(c_org, self.c_dim, 'CelebA', self.selected_attrs)
# c_rafd_list = self.create_labels(c_org, self.c2_dim, 'RaFD')
# zero_celeba = torch.zeros(x_real.size(0), self.c_dim).to(self.device) # Zero vector for CelebA.
# zero_rafd = torch.zeros(x_real.size(0), self.c2_dim).to(self.device) # Zero vector for RaFD.
# mask_celeba = self.label2onehot(torch.zeros(x_real.size(0)), 2).to(self.device) # Mask vector: [1, 0].
# mask_rafd = self.label2onehot(torch.ones(x_real.size(0)), 2).to(self.device) # Mask vector: [0, 1].
#
# # Translate images.
# x_fake_list = [x_real]
# for c_celeba in c_celeba_list:
# c_trg = torch.cat([c_celeba, zero_rafd, mask_celeba], dim=1)
# x_fake_list.append(self.G(x_real, c_trg))
# for c_rafd in c_rafd_list:
# c_trg = torch.cat([zero_celeba, c_rafd, mask_rafd], dim=1)
# x_fake_list.append(self.G(x_real, c_trg))
#
# # Save the translated images.
# x_concat = torch.cat(x_fake_list, dim=3)
# result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1))
# save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
# print('Saved real and fake images into {}...'.format(result_path))