Pre-cleaning
Esse commit está contido em:
+15
-4
@@ -112,9 +112,20 @@ class AffectNet(data.Dataset):
|
||||
elif self.affectnet_emo_descr in ['va-reg', 'va-cls']:
|
||||
for mode, dataset in zip(['train_', 'test_'], [self.train_dataset, self.test_dataset]):
|
||||
filenames = [line.rstrip() for line in open(os.path.join(self.image_dir, mode + 'images.txt'), 'r')]
|
||||
labels = [int(line.rstrip()) for line in open(os.path.join(self.image_dir, mode + 'labels.txt'), 'r')]
|
||||
labels = [[int(line.rstrip())] for line in open(os.path.join(self.image_dir, mode + 'labels.txt'), 'r')]
|
||||
predictions = [[float(x) for x in line.rstrip().split()] for line in open(os.path.join(self.image_dir, mode + 'predictions.txt'), 'r')]
|
||||
|
||||
dataset += list(zip(filenames, labels, predictions))
|
||||
elif self.affectnet_emo_descr == '64d_cls':
|
||||
for mode, dataset in zip(['train', 'test'], [self.train_dataset, self.test_dataset]):
|
||||
filenames = [line.rstrip() for line in open(os.path.join(self.image_dir, self.affectnet_emo_descr, mode, 'images.txt'), 'r')]
|
||||
labels = [[int(line.rstrip())] for line in open(os.path.join(self.image_dir, self.affectnet_emo_descr, mode, 'labels.txt'), 'r')]
|
||||
predictions = [[float(x) for x in line.rstrip().split()] for line in open(os.path.join(self.image_dir, self.affectnet_emo_descr, mode, 'predictions.txt'), 'r')]
|
||||
dataset += list(zip(filenames, labels, predictions))
|
||||
elif self.affectnet_emo_descr == '64d_reg':
|
||||
for mode, dataset in zip(['train', 'test'], [self.train_dataset, self.test_dataset]):
|
||||
filenames = [line.rstrip() for line in open(os.path.join(self.image_dir, self.affectnet_emo_descr, mode, 'images.txt'), 'r')]
|
||||
labels = [[float(x) for x in line.rstrip().split()] for line in open(os.path.join(self.image_dir, self.affectnet_emo_descr, mode, 'labels.txt'), 'r')]
|
||||
predictions = [[float(x) for x in line.rstrip().split()] for line in open(os.path.join(self.image_dir, self.affectnet_emo_descr, mode, 'predictions.txt'), 'r')]
|
||||
dataset += list(zip(filenames, labels, predictions))
|
||||
|
||||
print('Finished preprocessing the AffectNet dataset...')
|
||||
@@ -131,7 +142,7 @@ class AffectNet(data.Dataset):
|
||||
filename, label = dataset[index]
|
||||
image = Image.open(os.path.join(self.image_dir, 'validation', filename))
|
||||
return self.transform(image), torch.tensor(label)
|
||||
elif self.affectnet_emo_descr in ['va-reg', 'va-cls']:
|
||||
elif self.affectnet_emo_descr in ['va-reg', 'va-cls', '64d_reg', '64d_cls']:
|
||||
if self.mode == 'train':
|
||||
dataset = self.train_dataset
|
||||
filename, label, prediction = dataset[index]
|
||||
@@ -140,7 +151,7 @@ class AffectNet(data.Dataset):
|
||||
dataset = self.test_dataset
|
||||
filename, label, prediction = dataset[index]
|
||||
image = Image.open(os.path.join(self.image_dir, 'validation', filename))
|
||||
return self.transform(image), torch.tensor([label] + prediction)
|
||||
return self.transform(image), torch.tensor(label + prediction)
|
||||
|
||||
def __len__(self):
|
||||
"""Return the number of images."""
|
||||
|
||||
+8
-2
@@ -71,7 +71,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--lambda_cls', type=float, default=1, help='weight for domain classification loss')
|
||||
parser.add_argument('--lambda_rec', type=float, default=10, help='weight for reconstruction loss')
|
||||
parser.add_argument('--lambda_gp', type=float, default=10, help='weight for gradient penalty')
|
||||
|
||||
parser.add_argument('--lambda_cls2', type=float, default=1, help='second (extra) weight for domain classification loss')
|
||||
# Training configuration.
|
||||
parser.add_argument('--dataset', type=str, default='CelebA', choices=['CelebA', 'RaFD', 'AffectNet', 'Both'])
|
||||
parser.add_argument('--batch_size', type=int, default=16, help='mini-batch size')
|
||||
@@ -85,11 +85,17 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--resume_iters', type=int, default=None, help='resume training from this step')
|
||||
parser.add_argument('--selected_attrs', '--list', nargs='+', help='selected attributes for the CelebA dataset',
|
||||
default=['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young'])
|
||||
#
|
||||
parser.add_argument('--affectnet_emo_descr', type=str, default='emotiw', help='emotiw for categorical emotions, va for valence-arousal')
|
||||
parser.add_argument('--use_ccc', type=bool, default=False, help='use ccc loss instead of mse')
|
||||
parser.add_argument('--depth_concat', type=bool, default=True, help='perform depthwise concatenation in G')
|
||||
parser.add_argument('--d_loss_cls_type', type=str, default='actv', help='chosen terms for classification loss')
|
||||
|
||||
# Test configuration.
|
||||
parser.add_argument('--test_iters', type=int, default=200000, help='test model from this step')
|
||||
|
||||
#
|
||||
parser.add_argument('--pca_n_components', type=int, default=0.99, help='PCA visualization number of components kept')
|
||||
parser.add_argument('--pca_variant', type=str, default='quantiles', help='PCA visualization variant')
|
||||
# Miscellaneous.
|
||||
parser.add_argument('--num_workers', type=int, default=1)
|
||||
parser.add_argument('--mode', type=str, default='train', choices=['train', 'test'])
|
||||
|
||||
+25
-5
@@ -21,11 +21,15 @@ class ResidualBlock(nn.Module):
|
||||
|
||||
class Generator(nn.Module):
|
||||
"""Generator network."""
|
||||
def __init__(self, conv_dim=64, c_dim=5, repeat_num=6):
|
||||
def __init__(self, conv_dim=64, c_dim=5, repeat_num=6, depth_concat=True):
|
||||
super(Generator, self).__init__()
|
||||
|
||||
self.depth_concat = depth_concat
|
||||
layers = []
|
||||
layers.append(nn.Conv2d(3+c_dim, conv_dim, kernel_size=7, stride=1, padding=3, bias=False))
|
||||
if self.depth_concat:
|
||||
layers.append(nn.Conv2d(3+c_dim, conv_dim, kernel_size=7, stride=1, padding=3, bias=False))
|
||||
else:
|
||||
layers.append(nn.Conv2d(4, conv_dim, kernel_size=7, stride=1, padding=3, bias=False))
|
||||
layers.append(nn.InstanceNorm2d(conv_dim, affine=True, track_running_stats=True))
|
||||
layers.append(nn.ReLU(inplace=True))
|
||||
|
||||
@@ -56,15 +60,26 @@ class Generator(nn.Module):
|
||||
# Replicate spatially and concatenate domain information.
|
||||
# Note that this type of label conditioning does not work at all if we use reflection padding in Conv2d.
|
||||
# This is because instance normalization ignores the shifting (or bias) effect.
|
||||
c = c.view(c.size(0), c.size(1), 1, 1)
|
||||
c = c.repeat(1, 1, x.size(2), x.size(3))
|
||||
if self.depth_concat:
|
||||
c = c.view(c.size(0), c.size(1), 1, 1)
|
||||
c = c.repeat(1, 1, x.size(2), x.size(3))
|
||||
else:
|
||||
if c.size(1) == 2: #labels are assumed to have length 2 or 64 and images are 112x112
|
||||
# c = c.view(c.size(0), 1, 1, 2)
|
||||
# c = c.repeat(1, 1, 112, 56)
|
||||
# c = c.view(-1).repeat_interleave(8).view(c.size(0),1,1,-1).repeat(1,1,112,7)
|
||||
c = c.view(-1).repeat_interleave(56).view(c.size(0),1,1,-1).repeat(1,1,112,1)
|
||||
elif c.size(1) == 64:
|
||||
c = c.view(c.size(0), 1, 8, 8)
|
||||
c = c.repeat(1, 1, 14, 14)
|
||||
|
||||
x = torch.cat([x, c], dim=1)
|
||||
return self.main(x)
|
||||
|
||||
|
||||
class Discriminator(nn.Module):
|
||||
"""Discriminator network with PatchGAN."""
|
||||
def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6):
|
||||
def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6, affectnet_emo_descr=None, d_loss_cls_type=None):
|
||||
super(Discriminator, self).__init__()
|
||||
layers = []
|
||||
layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
|
||||
@@ -81,6 +96,11 @@ class Discriminator(nn.Module):
|
||||
self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.conv2 = nn.Conv2d(curr_dim, c_dim, kernel_size=kernel_size, bias=False)
|
||||
|
||||
if affectnet_emo_descr == '64d_cls' and (d_loss_cls_type in ['both', 'pred']):
|
||||
self.fc = torch.nn.Linear(c_dim, 7)
|
||||
elif affectnet_emo_descr == '64d_reg' and (d_loss_cls_type in ['both', 'pred']):
|
||||
self.fc = torch.nn.Linear(c_dim, 2)
|
||||
|
||||
def forward(self, x):
|
||||
h = self.main(x)
|
||||
out_src = self.conv1(h)
|
||||
|
||||
+180
-41
@@ -10,24 +10,71 @@ import time
|
||||
import datetime
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.decomposition import PCA
|
||||
import pandas as pd
|
||||
|
||||
class LabelTransform(object):
    """Build target label vectors by editing labels in a PCA latent space.

    A ``StandardScaler`` followed by a ``PCA`` is fitted on the training
    activations found under ``config.affectnet_image_dir``. ``create_labels``
    projects a batch of labels into that latent space, alters principal
    components and maps the result back to the original label space.

    NOTE(review): reconstructed from a rendered diff whose indentation and
    +/- markers were lost; follows the ``LabelTransform(config)`` interface.
    """

    def __init__(self, config):
        """Fit the scaler and the PCA on the training activations.

        Args:
            config: object exposing ``affectnet_image_dir``,
                ``affectnet_emo_descr``, ``pca_variant`` and
                ``pca_n_components``.

        Raises:
            ValueError: if ``config.affectnet_emo_descr`` is not one of
                '64d_reg', '64d_cls', 'va-reg' or 'va-cls'.
        """
        self.image_dir = config.affectnet_image_dir
        self.affectnet_emo_descr = config.affectnet_emo_descr
        self.pca_variant = config.pca_variant
        self.ss = StandardScaler()
        self.pca = PCA(config.pca_n_components)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        if self.affectnet_emo_descr in ['64d_reg', '64d_cls']:
            train_path = os.path.join(self.image_dir, self.affectnet_emo_descr, 'train', 'predictions.txt')
        elif self.affectnet_emo_descr in ['va-reg', 'va-cls']:
            train_path = os.path.join(self.image_dir, 'train_predictions.txt')
        else:
            # Fail with a clear message instead of an unbound local further down.
            raise ValueError(
                'LabelTransform does not support affectnet_emo_descr={!r}'.format(self.affectnet_emo_descr))

        actvs_train = self._read_float_rows(train_path)
        actvs_latent_train = self.pca.fit_transform(self.ss.fit_transform(actvs_train))

        # One row per principal component; columns are the describe() statistics
        # (including the '10%', '50%' and '90%' quantiles) without 'count'.
        self.stats = pd.DataFrame(actvs_latent_train).describe(percentiles=[0.1, 0.5, 0.9]).T.drop(columns=['count'])

    @staticmethod
    def _read_float_rows(path):
        """Read a whitespace-separated text file into a 2-D float array."""
        with open(path, 'r') as handle:
            return np.array([[float(x) for x in line.rstrip().split()] for line in handle])

    def _decode(self, latent):
        """Map latent rows back to label space as a float tensor on ``self.device``."""
        restored = self.ss.inverse_transform(self.pca.inverse_transform(latent))
        return torch.tensor(np.asarray(restored), dtype=torch.float32, device=self.device)

    def create_labels(self, c_org):
        """Return a list of target label tensors derived from ``c_org``.

        With ``pca_variant == 'quantiles'`` each principal component in turn
        is set to its training 10%, 50% and 90% quantile while all other
        components keep the value of the input. Otherwise the input is
        re-synthesised from its first ``i`` components for
        ``i = 1 .. n_components_ - 1``. In both cases the unmodified input is
        appended as the last entry.

        Args:
            c_org: 2-D batch of labels (array-like or tensor), one row per sample.

        Returns:
            list of float32 tensors on ``self.device``.
        """
        if torch.is_tensor(c_org):
            # scikit-learn operates on host arrays.
            c_org = c_org.detach().cpu().numpy()
        c_org_latent = np.asarray(self.pca.transform(self.ss.transform(c_org)))

        c_trg_list = []
        if self.pca_variant == 'quantiles':
            for i in range(self.pca.n_components_):
                for quantile in ('10%', '50%', '90%'):
                    # Work on a copy: editing c_org_latent in place would leak
                    # this component's value into every later target.
                    latent = c_org_latent.copy()
                    latent[:, i] = self.stats.loc[i, quantile]
                    c_trg_list.append(self._decode(latent))
        else:
            for i in range(1, self.pca.n_components_):
                # Copy for the same reason: zeroing in place would make every
                # target identical to the first one.
                latent = c_org_latent.copy()
                latent[:, i:] = 0
                c_trg_list.append(self._decode(latent))

        c_trg_list.append(torch.tensor(np.asarray(c_org), dtype=torch.float32, device=self.device))
        return c_trg_list
|
||||
|
||||
class Solver(object):
|
||||
@@ -51,6 +98,7 @@ class Solver(object):
|
||||
self.lambda_cls = config.lambda_cls
|
||||
self.lambda_rec = config.lambda_rec
|
||||
self.lambda_gp = config.lambda_gp
|
||||
self.lambda_cls2 = config.lambda_cls2
|
||||
|
||||
# Training configurations.
|
||||
self.dataset = config.dataset
|
||||
@@ -66,13 +114,16 @@ class Solver(object):
|
||||
self.selected_attrs = config.selected_attrs
|
||||
self.affectnet_emo_descr = config.affectnet_emo_descr
|
||||
self.use_ccc = config.use_ccc
|
||||
self.depth_concat = config.depth_concat
|
||||
self.d_loss_cls_type = config.d_loss_cls_type
|
||||
# Test configurations.
|
||||
self.test_iters = config.test_iters
|
||||
|
||||
# Miscellaneous.
|
||||
self.use_tensorboard = config.use_tensorboard
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
self.label_transform = LabelTransform()
|
||||
if self.affectnet_emo_descr in ['64d_reg', '64d_cls', 'va-reg', 'va-cls']:
|
||||
self.label_transform = LabelTransform(config)
|
||||
|
||||
# Directories.
|
||||
self.log_dir = config.log_dir
|
||||
@@ -94,11 +145,17 @@ class Solver(object):
|
||||
def build_model(self):
|
||||
"""Create a generator and a discriminator."""
|
||||
if self.dataset in ['CelebA', 'RaFD', 'AffectNet']:
|
||||
self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num)
|
||||
self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num)
|
||||
elif self.dataset in ['Both']:
|
||||
self.G = Generator(self.g_conv_dim, self.c_dim+self.c2_dim+2, self.g_repeat_num) # 2 for mask vector.
|
||||
self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim+self.c2_dim, self.d_repeat_num)
|
||||
self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num, self.depth_concat)
|
||||
self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num, self.affectnet_emo_descr, self.d_loss_cls_type)
|
||||
|
||||
# if self.affectnet_emo_descr == '64d_cls' and (self.d_loss_cls_type in ['both', 'pred']):
|
||||
# self.D.fc = torch.nn.Linear(self.c_dim, 7)
|
||||
# elif self.affectnet_emo_descr == '64d_reg' and (self.d_loss_cls_type in ['both', 'pred']):
|
||||
# self.D.fc = torch.nn.Linear(self.c_dim, 2)
|
||||
|
||||
# elif self.dataset in ['Both']:
|
||||
# self.G = Generator(self.g_conv_dim, self.c_dim+self.c2_dim+2, self.g_repeat_num) # 2 for mask vector.
|
||||
# self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim+self.c2_dim, self.d_repeat_num)
|
||||
|
||||
self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
|
||||
self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])
|
||||
@@ -205,11 +262,13 @@ class Solver(object):
|
||||
for a in [-1., -0.5, 0., 0.5, 1.]:
|
||||
c_trg = torch.Tensor([v, a]).repeat(c_org.size(0), 1)
|
||||
c_trg_list.append(c_trg.to(self.device))
|
||||
elif self.affectnet_emo_descr in ['va-reg', 'va-cls']:
|
||||
elif self.affectnet_emo_descr in ['va-reg', 'va-cls', '64d_cls']:
|
||||
# c_trg_list = np.loadtxt('/vol/bitbucket/apg416/affectnet/vgg2pcs.txt')
|
||||
# c_trg_list = torch.FloatTensor(c_trg_list).to(self.device)
|
||||
# c_trg_list = [label_vec.repeat(c_org.size(0), 1) for label_vec in torch.unbind(c_trg_list)]
|
||||
c_trg_list = self.label_transform.create_labels(c_org[:,1:])
|
||||
elif self.affectnet_emo_descr in ['64d_reg']:
|
||||
c_trg_list = self.label_transform.create_labels(c_org[:,2:])
|
||||
return c_trg_list
|
||||
|
||||
def ccc_loss(self, x, y):
|
||||
@@ -236,6 +295,34 @@ class Solver(object):
|
||||
return self.ccc_loss(logit, target)
|
||||
else:
|
||||
return F.mse_loss(logit, target)
|
||||
elif self.affectnet_emo_descr in ['64d_reg']:
|
||||
logit_pred, logit_actv = logit
|
||||
target_pred, target_actv = target
|
||||
out = 0
|
||||
if logit_pred is not None:
|
||||
if self.use_ccc:
|
||||
out += self.lambda_cls2/self.lambda_cls*self.ccc_loss(logit_pred, target_pred)
|
||||
else:
|
||||
out += self.lambda_cls2/self.lambda_cls*F.mse_loss(logit_pred, target_pred)
|
||||
if logit_actv is not None:
|
||||
if self.use_ccc:
|
||||
out += self.ccc_loss(logit_actv, target_actv)
|
||||
else:
|
||||
out += F.mse_loss(logit_actv, target_actv)
|
||||
return out
|
||||
elif self.affectnet_emo_descr in ['64d_cls']:
|
||||
logit_pred, logit_actv = logit
|
||||
target_pred, target_actv = target
|
||||
out = 0
|
||||
if logit_pred is not None:
|
||||
out += self.lambda_cls2/self.lambda_cls*F.cross_entropy(logit_pred, target_pred)
|
||||
if logit_actv is not None:
|
||||
if self.use_ccc:
|
||||
out += self.ccc_loss(logit_actv, target_actv)
|
||||
else:
|
||||
out += F.mse_loss(logit_actv, target_actv)
|
||||
return out
|
||||
|
||||
def train(self):
|
||||
"""Train StarGAN within a single dataset."""
|
||||
# Set data loader.
|
||||
@@ -262,6 +349,18 @@ class Solver(object):
|
||||
start_iters = self.resume_iters
|
||||
self.restore_model(self.resume_iters)
|
||||
|
||||
# # --- converting
|
||||
# if self.affectnet_emo_descr == '64d_cls' and (self.d_loss_cls_type in ['both', 'pred']):
|
||||
# self.D.fc = torch.nn.Linear(self.c_dim, 7)
|
||||
# elif self.affectnet_emo_descr == '64d_reg' and (self.d_loss_cls_type in ['both', 'pred']):
|
||||
# self.D.fc = torch.nn.Linear(self.c_dim, 2)
|
||||
# G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(self.resume_iters))
|
||||
# D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(self.resume_iters))
|
||||
# torch.save(self.G.state_dict(), G_path)
|
||||
# torch.save(self.D.state_dict(), D_path)
|
||||
# return 0
|
||||
# # ---
|
||||
|
||||
# Start training.
|
||||
print('Start training...')
|
||||
start_time = time.time()
|
||||
@@ -283,36 +382,56 @@ class Solver(object):
|
||||
label_trg = label_org[rand_idx]
|
||||
|
||||
if self.dataset == 'CelebA':
|
||||
c_org = label_org.clone()
|
||||
c_trg = label_trg.clone()
|
||||
c_org = label_org.clone().to(self.device)
|
||||
c_trg = label_trg.clone().to(self.device)
|
||||
label_org = label_org.to(self.device) # Labels for computing classification loss.
|
||||
label_trg = label_trg.to(self.device) # Labels for computing classification loss.
|
||||
elif self.dataset == 'RaFD':
|
||||
c_org = self.label2onehot(label_org, self.c_dim)
|
||||
c_trg = self.label2onehot(label_trg, self.c_dim)
|
||||
c_org = self.label2onehot(label_org, self.c_dim).to(self.device)
|
||||
c_trg = self.label2onehot(label_trg, self.c_dim).to(self.device)
|
||||
label_org = label_org.to(self.device) # Labels for computing classification loss.
|
||||
label_trg = label_trg.to(self.device) # Labels for computing classification loss.
|
||||
elif self.dataset == 'AffectNet':
|
||||
if self.affectnet_emo_descr == 'emotiw':
|
||||
c_org = self.label2onehot(label_org, self.c_dim)
|
||||
c_trg = self.label2onehot(label_trg, self.c_dim)
|
||||
c_org = self.label2onehot(label_org, self.c_dim).to(self.device)
|
||||
c_trg = self.label2onehot(label_trg, self.c_dim).to(self.device)
|
||||
label_org = label_org.to(self.device) # Labels for computing classification loss.
|
||||
label_trg = label_trg.to(self.device) # Labels for computing classification loss.
|
||||
elif self.affectnet_emo_descr == 'va':
|
||||
c_org = label_org.clone()
|
||||
c_trg = label_trg.clone()
|
||||
c_org = label_org.clone().to(self.device)
|
||||
c_trg = label_trg.clone().to(self.device)
|
||||
label_org = label_org.to(self.device) # Labels for computing classification loss.
|
||||
label_trg = label_trg.to(self.device) # Labels for computing classification loss.
|
||||
elif self.affectnet_emo_descr == 'va-reg':
|
||||
label_org = label_org[:,1:]
|
||||
label_trg = label_trg[:,1:]
|
||||
c_org = label_org.clone()
|
||||
c_trg = label_trg.clone()
|
||||
label_org = label_org[:,1:].to(self.device)
|
||||
label_trg = label_trg[:,1:].to(self.device)
|
||||
c_org = label_org.clone().to(self.device)
|
||||
c_trg = label_trg.clone().to(self.device)
|
||||
elif self.affectnet_emo_descr == 'va-cls':
|
||||
cls_org = label_org[:,0].clone().long().to(self.device)
|
||||
cls_trg = label_trg[:,0].clone().long().to(self.device)
|
||||
label_org = label_org[:,1:]
|
||||
label_trg = label_trg[:,1:]
|
||||
c_org = label_org.clone()
|
||||
c_trg = label_trg.clone()
|
||||
pred_org = label_org[:,0].clone().long().to(self.device)
|
||||
pred_trg = label_trg[:,0].clone().long().to(self.device)
|
||||
label_org = label_org[:,1:].to(self.device)
|
||||
label_trg = label_trg[:,1:].to(self.device)
|
||||
c_org = label_org.clone().to(self.device)
|
||||
c_trg = label_trg.clone().to(self.device)
|
||||
elif self.affectnet_emo_descr == '64d_cls':
|
||||
pred_org = label_org[:,0].clone().long().to(self.device)
|
||||
pred_trg = label_trg[:,0].clone().long().to(self.device)
|
||||
actv_org = label_org[:,1:].to(self.device)
|
||||
actv_trg = label_trg[:,1:].to(self.device)
|
||||
c_org = actv_org.clone().to(self.device)
|
||||
c_trg = actv_trg.clone().to(self.device)
|
||||
elif self.affectnet_emo_descr == '64d_reg':
|
||||
pred_org = label_org[:,:2].clone().float().to(self.device)
|
||||
pred_trg = label_trg[:,:2].clone().float().to(self.device)
|
||||
actv_org = label_org[:,2:].to(self.device)
|
||||
actv_trg = label_trg[:,2:].to(self.device)
|
||||
c_org = actv_org.clone().to(self.device)
|
||||
c_trg = actv_trg.clone().to(self.device)
|
||||
|
||||
|
||||
x_real = x_real.to(self.device) # Input images.
|
||||
c_org = c_org.to(self.device) # Original domain labels.
|
||||
c_trg = c_trg.to(self.device) # Target domain labels.
|
||||
label_org = label_org.to(self.device) # Labels for computing classification loss.
|
||||
label_trg = label_trg.to(self.device) # Labels for computing classification loss.
|
||||
|
||||
|
||||
# =================================================================================== #
|
||||
# 2. Train the discriminator #
|
||||
@@ -322,7 +441,14 @@ class Solver(object):
|
||||
out_src, out_cls = self.D(x_real)
|
||||
d_loss_real = - torch.mean(out_src)
|
||||
if self.affectnet_emo_descr == 'va-cls':
|
||||
d_loss_cls = self.classification_loss(out_cls, cls_org, self.dataset)
|
||||
d_loss_cls = self.classification_loss(out_cls, pred_org, self.dataset)
|
||||
elif self.affectnet_emo_descr in ['64d_reg', '64d_cls']:
|
||||
if self.d_loss_cls_type == 'actv':
|
||||
d_loss_cls = self.classification_loss((None, out_cls), (None, actv_org), self.dataset)
|
||||
elif self.d_loss_cls_type == 'pred':
|
||||
d_loss_cls = self.classification_loss((self.D.fc(out_cls), None), (pred_org, None), self.dataset)
|
||||
else:
|
||||
d_loss_cls = self.classification_loss((self.D.fc(out_cls), out_cls), (pred_org, actv_org), self.dataset)
|
||||
else:
|
||||
d_loss_cls = self.classification_loss(out_cls, label_org, self.dataset)
|
||||
|
||||
@@ -359,8 +485,19 @@ class Solver(object):
|
||||
x_fake = self.G(x_real, c_trg)
|
||||
out_src, out_cls = self.D(x_fake)
|
||||
g_loss_fake = - torch.mean(out_src)
|
||||
# if self.affectnet_emo_descr == 'va-cls':
|
||||
# g_loss_cls = self.classification_loss(out_cls, pred_trg, self.dataset)
|
||||
# else:
|
||||
# g_loss_cls = self.classification_loss(out_cls, label_trg, self.dataset)
|
||||
if self.affectnet_emo_descr == 'va-cls':
|
||||
g_loss_cls = self.classification_loss(out_cls, cls_trg, self.dataset)
|
||||
g_loss_cls = self.classification_loss(out_cls, pred_trg, self.dataset)
|
||||
elif self.affectnet_emo_descr in ['64d_reg', '64d_cls']:
|
||||
if self.d_loss_cls_type == 'actv':
|
||||
g_loss_cls = self.classification_loss((None, out_cls), (None, actv_trg), self.dataset)
|
||||
elif self.d_loss_cls_type == 'pred':
|
||||
g_loss_cls = self.classification_loss((self.D.fc(out_cls), None), (pred_trg, None), self.dataset)
|
||||
else:
|
||||
g_loss_cls = self.classification_loss((self.D.fc(out_cls), out_cls), (pred_trg, actv_trg), self.dataset)
|
||||
else:
|
||||
g_loss_cls = self.classification_loss(out_cls, label_trg, self.dataset)
|
||||
|
||||
@@ -632,9 +769,11 @@ class Solver(object):
|
||||
# Save the translated images.
|
||||
x_concat = torch.cat(x_fake_list, dim=3)
|
||||
result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1))
|
||||
# result_path = os.path.join(self.result_dir, '{}.jpg'.format(self.model_save_dir.replace("/vol/bitbucket/apg416/","").replace("/models","")))
|
||||
save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
|
||||
print('Saved real and fake images into {}...'.format(result_path))
|
||||
|
||||
|
||||
# def test_multi(self):
|
||||
# """Translate images using StarGAN trained on multiple datasets."""
|
||||
# # Load the trained generator.
|
||||
|
||||
Referência em uma Nova Issue
Bloquear um usuário