Pre-cleaning

Esse commit está contido em:
alexbrx
2020-04-30 20:27:30 +01:00
commit 1784bbea9b
4 arquivos alterados com 230 adições e 54 exclusões
+15 -4
Ver Arquivo
@@ -112,9 +112,20 @@ class AffectNet(data.Dataset):
elif self.affectnet_emo_descr in ['va-reg', 'va-cls']:
for mode, dataset in zip(['train_', 'test_'], [self.train_dataset, self.test_dataset]):
filenames = [line.rstrip() for line in open(os.path.join(self.image_dir, mode + 'images.txt'), 'r')]
labels = [int(line.rstrip()) for line in open(os.path.join(self.image_dir, mode + 'labels.txt'), 'r')]
labels = [[int(line.rstrip())] for line in open(os.path.join(self.image_dir, mode + 'labels.txt'), 'r')]
predictions = [[float(x) for x in line.rstrip().split()] for line in open(os.path.join(self.image_dir, mode + 'predictions.txt'), 'r')]
dataset += list(zip(filenames, labels, predictions))
elif self.affectnet_emo_descr == '64d_cls':
for mode, dataset in zip(['train', 'test'], [self.train_dataset, self.test_dataset]):
filenames = [line.rstrip() for line in open(os.path.join(self.image_dir, self.affectnet_emo_descr, mode, 'images.txt'), 'r')]
labels = [[int(line.rstrip())] for line in open(os.path.join(self.image_dir, self.affectnet_emo_descr, mode, 'labels.txt'), 'r')]
predictions = [[float(x) for x in line.rstrip().split()] for line in open(os.path.join(self.image_dir, self.affectnet_emo_descr, mode, 'predictions.txt'), 'r')]
dataset += list(zip(filenames, labels, predictions))
elif self.affectnet_emo_descr == '64d_reg':
for mode, dataset in zip(['train', 'test'], [self.train_dataset, self.test_dataset]):
filenames = [line.rstrip() for line in open(os.path.join(self.image_dir, self.affectnet_emo_descr, mode, 'images.txt'), 'r')]
labels = [[float(x) for x in line.rstrip().split()] for line in open(os.path.join(self.image_dir, self.affectnet_emo_descr, mode, 'labels.txt'), 'r')]
predictions = [[float(x) for x in line.rstrip().split()] for line in open(os.path.join(self.image_dir, self.affectnet_emo_descr, mode, 'predictions.txt'), 'r')]
dataset += list(zip(filenames, labels, predictions))
print('Finished preprocessing the AffectNet dataset...')
@@ -131,7 +142,7 @@ class AffectNet(data.Dataset):
filename, label = dataset[index]
image = Image.open(os.path.join(self.image_dir, 'validation', filename))
return self.transform(image), torch.tensor(label)
elif self.affectnet_emo_descr in ['va-reg', 'va-cls']:
elif self.affectnet_emo_descr in ['va-reg', 'va-cls', '64d_reg', '64d_cls']:
if self.mode == 'train':
dataset = self.train_dataset
filename, label, prediction = dataset[index]
@@ -140,7 +151,7 @@ class AffectNet(data.Dataset):
dataset = self.test_dataset
filename, label, prediction = dataset[index]
image = Image.open(os.path.join(self.image_dir, 'validation', filename))
return self.transform(image), torch.tensor([label] + prediction)
return self.transform(image), torch.tensor(label + prediction)
def __len__(self):
"""Return the number of images."""
+8 -2
Ver Arquivo
@@ -71,7 +71,7 @@ if __name__ == '__main__':
parser.add_argument('--lambda_cls', type=float, default=1, help='weight for domain classification loss')
parser.add_argument('--lambda_rec', type=float, default=10, help='weight for reconstruction loss')
parser.add_argument('--lambda_gp', type=float, default=10, help='weight for gradient penalty')
parser.add_argument('--lambda_cls2', type=float, default=1, help='second (extra) weight for domain classification loss')
# Training configuration.
parser.add_argument('--dataset', type=str, default='CelebA', choices=['CelebA', 'RaFD', 'AffectNet', 'Both'])
parser.add_argument('--batch_size', type=int, default=16, help='mini-batch size')
@@ -85,11 +85,17 @@ if __name__ == '__main__':
parser.add_argument('--resume_iters', type=int, default=None, help='resume training from this step')
parser.add_argument('--selected_attrs', '--list', nargs='+', help='selected attributes for the CelebA dataset',
default=['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young'])
#
parser.add_argument('--affectnet_emo_descr', type=str, default='emotiw', help='emotiw for categorical emotions, va for valence-arousal')
parser.add_argument('--use_ccc', type=bool, default=False, help='use ccc loss instead of mse')
parser.add_argument('--depth_concat', type=bool, default=True, help='perform depthwise concatenation in G')
parser.add_argument('--d_loss_cls_type', type=str, default='actv', help='chosen terms for classification loss')
# Test configuration.
parser.add_argument('--test_iters', type=int, default=200000, help='test model from this step')
#
parser.add_argument('--pca_n_components', type=int, default=0.99, help='PCA visualization number of components kept')
parser.add_argument('--pca_variant', type=str, default='quantiles', help='PCA visualization variant')
# Miscellaneous.
parser.add_argument('--num_workers', type=int, default=1)
parser.add_argument('--mode', type=str, default='train', choices=['train', 'test'])
+25 -5
Ver Arquivo
@@ -21,11 +21,15 @@ class ResidualBlock(nn.Module):
class Generator(nn.Module):
"""Generator network."""
def __init__(self, conv_dim=64, c_dim=5, repeat_num=6):
def __init__(self, conv_dim=64, c_dim=5, repeat_num=6, depth_concat=True):
super(Generator, self).__init__()
self.depth_concat = depth_concat
layers = []
layers.append(nn.Conv2d(3+c_dim, conv_dim, kernel_size=7, stride=1, padding=3, bias=False))
if self.depth_concat:
layers.append(nn.Conv2d(3+c_dim, conv_dim, kernel_size=7, stride=1, padding=3, bias=False))
else:
layers.append(nn.Conv2d(4, conv_dim, kernel_size=7, stride=1, padding=3, bias=False))
layers.append(nn.InstanceNorm2d(conv_dim, affine=True, track_running_stats=True))
layers.append(nn.ReLU(inplace=True))
@@ -56,15 +60,26 @@ class Generator(nn.Module):
# Replicate spatially and concatenate domain information.
# Note that this type of label conditioning does not work at all if we use reflection padding in Conv2d.
# This is because instance normalization ignores the shifting (or bias) effect.
c = c.view(c.size(0), c.size(1), 1, 1)
c = c.repeat(1, 1, x.size(2), x.size(3))
if self.depth_concat:
c = c.view(c.size(0), c.size(1), 1, 1)
c = c.repeat(1, 1, x.size(2), x.size(3))
else:
if c.size(1) == 2: #labels are assumed to have length 2 or 64 and images are 112x112
# c = c.view(c.size(0), 1, 1, 2)
# c = c.repeat(1, 1, 112, 56)
# c = c.view(-1).repeat_interleave(8).view(c.size(0),1,1,-1).repeat(1,1,112,7)
c = c.view(-1).repeat_interleave(56).view(c.size(0),1,1,-1).repeat(1,1,112,1)
elif c.size(1) == 64:
c = c.view(c.size(0), 1, 8, 8)
c = c.repeat(1, 1, 14, 14)
x = torch.cat([x, c], dim=1)
return self.main(x)
class Discriminator(nn.Module):
"""Discriminator network with PatchGAN."""
def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6):
def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6, affectnet_emo_descr=None, d_loss_cls_type=None):
super(Discriminator, self).__init__()
layers = []
layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
@@ -81,6 +96,11 @@ class Discriminator(nn.Module):
self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=3, stride=1, padding=1, bias=False)
self.conv2 = nn.Conv2d(curr_dim, c_dim, kernel_size=kernel_size, bias=False)
if affectnet_emo_descr == '64d_cls' and (d_loss_cls_type in ['both', 'pred']):
self.fc = torch.nn.Linear(c_dim, 7)
elif affectnet_emo_descr == '64d_reg' and (d_loss_cls_type in ['both', 'pred']):
self.fc = torch.nn.Linear(c_dim, 2)
def forward(self, x):
h = self.main(x)
out_src = self.conv1(h)
+180 -41
Ver Arquivo
@@ -10,24 +10,71 @@ import time
import datetime
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
import pandas as pd
class LabelTransform(object):
def __init__(self, max_components_kept=14):
self.image_dir = '/vol/bitbucket/apg416/affectnet'
self.max_components_kept = max_components_kept
def __init__(self, config):
self.image_dir = config.affectnet_image_dir
self.affectnet_emo_descr = config.affectnet_emo_descr
self.pca_variant = config.pca_variant
self.ss = StandardScaler()
self.pca = PCA()
self.pca = PCA(config.pca_n_components)
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
predictions_train = np.array([[float(x) for x in line.rstrip().split()] for line in open(os.path.join(self.image_dir,'train_predictions.txt'), 'r')])
self.pca.fit(self.ss.fit_transform(predictions_train))
if self.affectnet_emo_descr in ['64d_reg', '64d_cls']:
actvs_train = np.array([[float(x) for x in line.rstrip().split()] for line in open(os.path.join(self.image_dir, self.affectnet_emo_descr, 'train', 'predictions.txt'), 'r')])
elif self.affectnet_emo_descr in ['va-reg', 'va-cls']:
actvs_train = np.array([[float(x) for x in line.rstrip().split()] for line in open(os.path.join(self.image_dir,'train_predictions.txt'), 'r')])
actvs_latent_train = self.pca.fit_transform(self.ss.fit_transform(actvs_train))
self.stats = pd.DataFrame(actvs_latent_train).describe(percentiles=[0.1, 0.5, 0.9]).T.drop(columns=['count']) #components in rows
def create_labels(self, c_org):
c_trg_list = []
for i in range(self.max_components_kept):
transformed_labels = self.pca.transform(self.ss.transform(c_org))
transformed_labels[:,1+i:] = 0
c_trg_list.append(torch.FloatTensor(self.ss.inverse_transform(self.pca.inverse_transform(transformed_labels))).to(self.device))
c_org_latent = self.pca.transform(self.ss.transform(c_org))
if self.pca_variant == 'quantiles':
for i in range(self.pca.n_components_):
# transformed_labels_lo, transformed_labels_hi = c_org_latent, c_org_latent
# transformed_labels_lo[:,i] = self.stats.loc[i,'10%']
# c_trg_list.append(torch.FloatTensor(self.ss.inverse_transform(self.pca.inverse_transform(transformed_labels_lo))).to(self.device))
# transformed_labels_lo[:,i] = self.stats.loc[i,'50%']
# c_trg_list.append(torch.FloatTensor(self.ss.inverse_transform(self.pca.inverse_transform(transformed_labels_lo))).to(self.device))
# transformed_labels_hi[:,i] = self.stats.loc[i,'90%']
# c_trg_list.append(torch.FloatTensor(self.ss.inverse_transform(self.pca.inverse_transform(transformed_labels_hi))).to(self.device))
transformed_labels = c_org_latent
transformed_labels[:,i] = self.stats.loc[i,'10%']
c_trg_list.append(torch.FloatTensor(self.ss.inverse_transform(self.pca.inverse_transform(transformed_labels))).to(self.device))
transformed_labels = c_org_latent
transformed_labels[:,i] = self.stats.loc[i,'50%']
c_trg_list.append(torch.FloatTensor(self.ss.inverse_transform(self.pca.inverse_transform(transformed_labels))).to(self.device))
transformed_labels = c_org_latent
transformed_labels[:,i] = self.stats.loc[i,'90%']
c_trg_list.append(torch.FloatTensor(self.ss.inverse_transform(self.pca.inverse_transform(transformed_labels))).to(self.device))
# transformed_labels = c_org_latent
# transformed_labels[:,0] = self.stats.loc[i,'10%']
# transformed_labels[:,1] = self.stats.loc[i,'10%']
# c_trg_list.append(torch.FloatTensor(self.ss.inverse_transform(self.pca.inverse_transform(transformed_labels))).to(self.device))
# transformed_labels = c_org_latent
# transformed_labels[:,0] = self.stats.loc[i,'90%']
# transformed_labels[:,1] = self.stats.loc[i,'90%']
# c_trg_list.append(torch.FloatTensor(self.ss.inverse_transform(self.pca.inverse_transform(transformed_labels))).to(self.device))
# transformed_labels = c_org_latent
# transformed_labels[:,0] = self.stats.loc[i,'10%']
# transformed_labels[:,1] = self.stats.loc[i,'90%']
# c_trg_list.append(torch.FloatTensor(self.ss.inverse_transform(self.pca.inverse_transform(transformed_labels))).to(self.device))
# transformed_labels = c_org_latent
# transformed_labels[:,0] = self.stats.loc[i,'90%']
# transformed_labels[:,1] = self.stats.loc[i,'10%']
# c_trg_list.append(torch.FloatTensor(self.ss.inverse_transform(self.pca.inverse_transform(transformed_labels))).to(self.device))
else:
for i in range(1, self.pca.n_components_):
transformed_labels = c_org_latent
transformed_labels[:,i:] = 0
c_trg_list.append(torch.FloatTensor(self.ss.inverse_transform(self.pca.inverse_transform(transformed_labels))).to(self.device))
c_trg_list.append(torch.FloatTensor(c_org).to(self.device))
return c_trg_list
class Solver(object):
@@ -51,6 +98,7 @@ class Solver(object):
self.lambda_cls = config.lambda_cls
self.lambda_rec = config.lambda_rec
self.lambda_gp = config.lambda_gp
self.lambda_cls2 = config.lambda_cls2
# Training configurations.
self.dataset = config.dataset
@@ -66,13 +114,16 @@ class Solver(object):
self.selected_attrs = config.selected_attrs
self.affectnet_emo_descr = config.affectnet_emo_descr
self.use_ccc = config.use_ccc
self.depth_concat = config.depth_concat
self.d_loss_cls_type = config.d_loss_cls_type
# Test configurations.
self.test_iters = config.test_iters
# Miscellaneous.
self.use_tensorboard = config.use_tensorboard
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.label_transform = LabelTransform()
if self.affectnet_emo_descr in ['64d_reg', '64d_cls', 'va-reg', 'va-cls']:
self.label_transform = LabelTransform(config)
# Directories.
self.log_dir = config.log_dir
@@ -94,11 +145,17 @@ class Solver(object):
def build_model(self):
"""Create a generator and a discriminator."""
if self.dataset in ['CelebA', 'RaFD', 'AffectNet']:
self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num)
self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num)
elif self.dataset in ['Both']:
self.G = Generator(self.g_conv_dim, self.c_dim+self.c2_dim+2, self.g_repeat_num) # 2 for mask vector.
self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim+self.c2_dim, self.d_repeat_num)
self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num, self.depth_concat)
self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num, self.affectnet_emo_descr, self.d_loss_cls_type)
# if self.affectnet_emo_descr == '64d_cls' and (self.d_loss_cls_type in ['both', 'pred']):
# self.D.fc = torch.nn.Linear(self.c_dim, 7)
# elif self.affectnet_emo_descr == '64d_reg' and (self.d_loss_cls_type in ['both', 'pred']):
# self.D.fc = torch.nn.Linear(self.c_dim, 2)
# elif self.dataset in ['Both']:
# self.G = Generator(self.g_conv_dim, self.c_dim+self.c2_dim+2, self.g_repeat_num) # 2 for mask vector.
# self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim+self.c2_dim, self.d_repeat_num)
self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])
@@ -205,11 +262,13 @@ class Solver(object):
for a in [-1., -0.5, 0., 0.5, 1.]:
c_trg = torch.Tensor([v, a]).repeat(c_org.size(0), 1)
c_trg_list.append(c_trg.to(self.device))
elif self.affectnet_emo_descr in ['va-reg', 'va-cls']:
elif self.affectnet_emo_descr in ['va-reg', 'va-cls', '64d_cls']:
# c_trg_list = np.loadtxt('/vol/bitbucket/apg416/affectnet/vgg2pcs.txt')
# c_trg_list = torch.FloatTensor(c_trg_list).to(self.device)
# c_trg_list = [label_vec.repeat(c_org.size(0), 1) for label_vec in torch.unbind(c_trg_list)]
c_trg_list = self.label_transform.create_labels(c_org[:,1:])
elif self.affectnet_emo_descr in ['64d_reg']:
c_trg_list = self.label_transform.create_labels(c_org[:,2:])
return c_trg_list
def ccc_loss(self, x, y):
@@ -236,6 +295,34 @@ class Solver(object):
return self.ccc_loss(logit, target)
else:
return F.mse_loss(logit, target)
elif self.affectnet_emo_descr in ['64d_reg']:
logit_pred, logit_actv = logit
target_pred, target_actv = target
out = 0
if logit_pred is not None:
if self.use_ccc:
out += self.lambda_cls2/self.lambda_cls*self.ccc_loss(logit_pred, target_pred)
else:
out += self.lambda_cls2/self.lambda_cls*F.mse_loss(logit_pred, target_pred)
if logit_actv is not None:
if self.use_ccc:
out += self.ccc_loss(logit_actv, target_actv)
else:
out += F.mse_loss(logit_actv, target_actv)
return out
elif self.affectnet_emo_descr in ['64d_cls']:
logit_pred, logit_actv = logit
target_pred, target_actv = target
out = 0
if logit_pred is not None:
out += self.lambda_cls2/self.lambda_cls*F.cross_entropy(logit_pred, target_pred)
if logit_actv is not None:
if self.use_ccc:
out += self.ccc_loss(logit_actv, target_actv)
else:
out += F.mse_loss(logit_actv, target_actv)
return out
def train(self):
"""Train StarGAN within a single dataset."""
# Set data loader.
@@ -262,6 +349,18 @@ class Solver(object):
start_iters = self.resume_iters
self.restore_model(self.resume_iters)
# # --- converting
# if self.affectnet_emo_descr == '64d_cls' and (self.d_loss_cls_type in ['both', 'pred']):
# self.D.fc = torch.nn.Linear(self.c_dim, 7)
# elif self.affectnet_emo_descr == '64d_reg' and (self.d_loss_cls_type in ['both', 'pred']):
# self.D.fc = torch.nn.Linear(self.c_dim, 2)
# G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(self.resume_iters))
# D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(self.resume_iters))
# torch.save(self.G.state_dict(), G_path)
# torch.save(self.D.state_dict(), D_path)
# return 0
# # ---
# Start training.
print('Start training...')
start_time = time.time()
@@ -283,36 +382,56 @@ class Solver(object):
label_trg = label_org[rand_idx]
if self.dataset == 'CelebA':
c_org = label_org.clone()
c_trg = label_trg.clone()
c_org = label_org.clone().to(self.device)
c_trg = label_trg.clone().to(self.device)
label_org = label_org.to(self.device) # Labels for computing classification loss.
label_trg = label_trg.to(self.device) # Labels for computing classification loss.
elif self.dataset == 'RaFD':
c_org = self.label2onehot(label_org, self.c_dim)
c_trg = self.label2onehot(label_trg, self.c_dim)
c_org = self.label2onehot(label_org, self.c_dim).to(self.device)
c_trg = self.label2onehot(label_trg, self.c_dim).to(self.device)
label_org = label_org.to(self.device) # Labels for computing classification loss.
label_trg = label_trg.to(self.device) # Labels for computing classification loss.
elif self.dataset == 'AffectNet':
if self.affectnet_emo_descr == 'emotiw':
c_org = self.label2onehot(label_org, self.c_dim)
c_trg = self.label2onehot(label_trg, self.c_dim)
c_org = self.label2onehot(label_org, self.c_dim).to(self.device)
c_trg = self.label2onehot(label_trg, self.c_dim).to(self.device)
label_org = label_org.to(self.device) # Labels for computing classification loss.
label_trg = label_trg.to(self.device) # Labels for computing classification loss.
elif self.affectnet_emo_descr == 'va':
c_org = label_org.clone()
c_trg = label_trg.clone()
c_org = label_org.clone().to(self.device)
c_trg = label_trg.clone().to(self.device)
label_org = label_org.to(self.device) # Labels for computing classification loss.
label_trg = label_trg.to(self.device) # Labels for computing classification loss.
elif self.affectnet_emo_descr == 'va-reg':
label_org = label_org[:,1:]
label_trg = label_trg[:,1:]
c_org = label_org.clone()
c_trg = label_trg.clone()
label_org = label_org[:,1:].to(self.device)
label_trg = label_trg[:,1:].to(self.device)
c_org = label_org.clone().to(self.device)
c_trg = label_trg.clone().to(self.device)
elif self.affectnet_emo_descr == 'va-cls':
cls_org = label_org[:,0].clone().long().to(self.device)
cls_trg = label_trg[:,0].clone().long().to(self.device)
label_org = label_org[:,1:]
label_trg = label_trg[:,1:]
c_org = label_org.clone()
c_trg = label_trg.clone()
pred_org = label_org[:,0].clone().long().to(self.device)
pred_trg = label_trg[:,0].clone().long().to(self.device)
label_org = label_org[:,1:].to(self.device)
label_trg = label_trg[:,1:].to(self.device)
c_org = label_org.clone().to(self.device)
c_trg = label_trg.clone().to(self.device)
elif self.affectnet_emo_descr == '64d_cls':
pred_org = label_org[:,0].clone().long().to(self.device)
pred_trg = label_trg[:,0].clone().long().to(self.device)
actv_org = label_org[:,1:].to(self.device)
actv_trg = label_trg[:,1:].to(self.device)
c_org = actv_org.clone().to(self.device)
c_trg = actv_trg.clone().to(self.device)
elif self.affectnet_emo_descr == '64d_reg':
pred_org = label_org[:,:2].clone().float().to(self.device)
pred_trg = label_trg[:,:2].clone().float().to(self.device)
actv_org = label_org[:,2:].to(self.device)
actv_trg = label_trg[:,2:].to(self.device)
c_org = actv_org.clone().to(self.device)
c_trg = actv_trg.clone().to(self.device)
x_real = x_real.to(self.device) # Input images.
c_org = c_org.to(self.device) # Original domain labels.
c_trg = c_trg.to(self.device) # Target domain labels.
label_org = label_org.to(self.device) # Labels for computing classification loss.
label_trg = label_trg.to(self.device) # Labels for computing classification loss.
# =================================================================================== #
# 2. Train the discriminator #
@@ -322,7 +441,14 @@ class Solver(object):
out_src, out_cls = self.D(x_real)
d_loss_real = - torch.mean(out_src)
if self.affectnet_emo_descr == 'va-cls':
d_loss_cls = self.classification_loss(out_cls, cls_org, self.dataset)
d_loss_cls = self.classification_loss(out_cls, pred_org, self.dataset)
elif self.affectnet_emo_descr in ['64d_reg', '64d_cls']:
if self.d_loss_cls_type == 'actv':
d_loss_cls = self.classification_loss((None, out_cls), (None, actv_org), self.dataset)
elif self.d_loss_cls_type == 'pred':
d_loss_cls = self.classification_loss((self.D.fc(out_cls), None), (pred_org, None), self.dataset)
else:
d_loss_cls = self.classification_loss((self.D.fc(out_cls), out_cls), (pred_org, actv_org), self.dataset)
else:
d_loss_cls = self.classification_loss(out_cls, label_org, self.dataset)
@@ -359,8 +485,19 @@ class Solver(object):
x_fake = self.G(x_real, c_trg)
out_src, out_cls = self.D(x_fake)
g_loss_fake = - torch.mean(out_src)
# if self.affectnet_emo_descr == 'va-cls':
# g_loss_cls = self.classification_loss(out_cls, pred_trg, self.dataset)
# else:
# g_loss_cls = self.classification_loss(out_cls, label_trg, self.dataset)
if self.affectnet_emo_descr == 'va-cls':
g_loss_cls = self.classification_loss(out_cls, cls_trg, self.dataset)
g_loss_cls = self.classification_loss(out_cls, pred_trg, self.dataset)
elif self.affectnet_emo_descr in ['64d_reg', '64d_cls']:
if self.d_loss_cls_type == 'actv':
g_loss_cls = self.classification_loss((None, out_cls), (None, actv_trg), self.dataset)
elif self.d_loss_cls_type == 'pred':
g_loss_cls = self.classification_loss((self.D.fc(out_cls), None), (pred_trg, None), self.dataset)
else:
g_loss_cls = self.classification_loss((self.D.fc(out_cls), out_cls), (pred_trg, actv_trg), self.dataset)
else:
g_loss_cls = self.classification_loss(out_cls, label_trg, self.dataset)
@@ -632,9 +769,11 @@ class Solver(object):
# Save the translated images.
x_concat = torch.cat(x_fake_list, dim=3)
result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1))
# result_path = os.path.join(self.result_dir, '{}.jpg'.format(self.model_save_dir.replace("/vol/bitbucket/apg416/","").replace("/models","")))
save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
print('Saved real and fake images into {}...'.format(result_path))
# def test_multi(self):
# """Translate images using StarGAN trained on multiple datasets."""
# # Load the trained generator.