diff --git a/data_loader.py b/data_loader.py index bbe1bd2..268540c 100644 --- a/data_loader.py +++ b/data_loader.py @@ -112,11 +112,22 @@ class AffectNet(data.Dataset): elif self.affectnet_emo_descr in ['va-reg', 'va-cls']: for mode, dataset in zip(['train_', 'test_'], [self.train_dataset, self.test_dataset]): filenames = [line.rstrip() for line in open(os.path.join(self.image_dir, mode + 'images.txt'), 'r')] - labels = [int(line.rstrip()) for line in open(os.path.join(self.image_dir, mode + 'labels.txt'), 'r')] + labels = [[int(line.rstrip())] for line in open(os.path.join(self.image_dir, mode + 'labels.txt'), 'r')] predictions = [[float(x) for x in line.rstrip().split()] for line in open(os.path.join(self.image_dir, mode + 'predictions.txt'), 'r')] - dataset += list(zip(filenames, labels, predictions)) - + elif self.affectnet_emo_descr == '64d_cls': + for mode, dataset in zip(['train', 'test'], [self.train_dataset, self.test_dataset]): + filenames = [line.rstrip() for line in open(os.path.join(self.image_dir, self.affectnet_emo_descr, mode, 'images.txt'), 'r')] + labels = [[int(line.rstrip())] for line in open(os.path.join(self.image_dir, self.affectnet_emo_descr, mode, 'labels.txt'), 'r')] + predictions = [[float(x) for x in line.rstrip().split()] for line in open(os.path.join(self.image_dir, self.affectnet_emo_descr, mode, 'predictions.txt'), 'r')] + dataset += list(zip(filenames, labels, predictions)) + elif self.affectnet_emo_descr == '64d_reg': + for mode, dataset in zip(['train', 'test'], [self.train_dataset, self.test_dataset]): + filenames = [line.rstrip() for line in open(os.path.join(self.image_dir, self.affectnet_emo_descr, mode, 'images.txt'), 'r')] + labels = [[float(x) for x in line.rstrip().split()] for line in open(os.path.join(self.image_dir, self.affectnet_emo_descr, mode, 'labels.txt'), 'r')] + predictions = [[float(x) for x in line.rstrip().split()] for line in open(os.path.join(self.image_dir, self.affectnet_emo_descr, mode, 'predictions.txt'), 'r')] + dataset += list(zip(filenames, labels, predictions)) + print('Finished preprocessing the AffectNet dataset...') def __getitem__(self, index): @@ -131,7 +142,7 @@ class AffectNet(data.Dataset): filename, label = dataset[index] image = Image.open(os.path.join(self.image_dir, 'validation', filename)) return self.transform(image), torch.tensor(label) - elif self.affectnet_emo_descr in ['va-reg', 'va-cls']: + elif self.affectnet_emo_descr in ['va-reg', 'va-cls', '64d_reg', '64d_cls']: if self.mode == 'train': dataset = self.train_dataset filename, label, prediction = dataset[index] @@ -140,7 +151,7 @@ class AffectNet(data.Dataset): dataset = self.test_dataset filename, label, prediction = dataset[index] image = Image.open(os.path.join(self.image_dir, 'validation', filename)) - return self.transform(image), torch.tensor([label] + prediction) + return self.transform(image), torch.tensor(label + prediction) def __len__(self): """Return the number of images.""" diff --git a/main.py b/main.py index fb79b04..9a1b85f 100644 --- a/main.py +++ b/main.py @@ -71,7 +71,7 @@ if __name__ == '__main__': parser.add_argument('--lambda_cls', type=float, default=1, help='weight for domain classification loss') parser.add_argument('--lambda_rec', type=float, default=10, help='weight for reconstruction loss') parser.add_argument('--lambda_gp', type=float, default=10, help='weight for gradient penalty') - + parser.add_argument('--lambda_cls2', type=float, default=1, help='second (extra) weight for domain classification loss') # Training configuration. parser.add_argument('--dataset', type=str, default='CelebA', choices=['CelebA', 'RaFD', 'AffectNet', 'Both']) parser.add_argument('--batch_size', type=int, default=16, help='mini-batch size') @@ -85,11 +85,17 @@ if __name__ == '__main__': parser.add_argument('--resume_iters', type=int, default=None, help='resume training from this step') parser.add_argument('--selected_attrs', '--list', nargs='+', help='selected attributes for the CelebA dataset', default=['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young']) + # parser.add_argument('--affectnet_emo_descr', type=str, default='emotiw', help='emotiw for categorical emotions, va for valence-arousal') parser.add_argument('--use_ccc', type=bool, default=False, help='use ccc loss instead of mse') + parser.add_argument('--depth_concat', type=bool, default=True, help='perform depthwise concatenation in G') + parser.add_argument('--d_loss_cls_type', type=str, default='actv', help='chosen terms for classification loss') + # Test configuration. parser.add_argument('--test_iters', type=int, default=200000, help='test model from this step') - + # + parser.add_argument('--pca_n_components', type=int, default=0.99, help='PCA visualization number of components kept') + parser.add_argument('--pca_variant', type=str, default='quantiles', help='PCA visualization variant') # Miscellaneous. parser.add_argument('--num_workers', type=int, default=1) parser.add_argument('--mode', type=str, default='train', choices=['train', 'test']) diff --git a/model.py b/model.py index 3d0e627..51e9974 100644 --- a/model.py +++ b/model.py @@ -21,11 +21,15 @@ class ResidualBlock(nn.Module): class Generator(nn.Module): """Generator network.""" - def __init__(self, conv_dim=64, c_dim=5, repeat_num=6): + def __init__(self, conv_dim=64, c_dim=5, repeat_num=6, depth_concat=True): super(Generator, self).__init__() + self.depth_concat = depth_concat layers = [] - layers.append(nn.Conv2d(3+c_dim, conv_dim, kernel_size=7, stride=1, padding=3, bias=False)) + if self.depth_concat: + layers.append(nn.Conv2d(3+c_dim, conv_dim, kernel_size=7, stride=1, padding=3, bias=False)) + else: + layers.append(nn.Conv2d(4, conv_dim, kernel_size=7, stride=1, padding=3, bias=False)) layers.append(nn.InstanceNorm2d(conv_dim, affine=True, track_running_stats=True)) layers.append(nn.ReLU(inplace=True)) @@ -56,15 +60,26 @@ class Generator(nn.Module): # Replicate spatially and concatenate domain information. # Note that this type of label conditioning does not work at all if we use reflection padding in Conv2d. # This is because instance normalization ignores the shifting (or bias) effect. - c = c.view(c.size(0), c.size(1), 1, 1) - c = c.repeat(1, 1, x.size(2), x.size(3)) + if self.depth_concat: + c = c.view(c.size(0), c.size(1), 1, 1) + c = c.repeat(1, 1, x.size(2), x.size(3)) + else: + if c.size(1) == 2: #labels are assumed to have length 2 or 64 and images are 112x112 + # c = c.view(c.size(0), 1, 1, 2) + # c = c.repeat(1, 1, 112, 56) + # c = c.view(-1).repeat_interleave(8).view(c.size(0),1,1,-1).repeat(1,1,112,7) + c = c.view(-1).repeat_interleave(56).view(c.size(0),1,1,-1).repeat(1,1,112,1) + elif c.size(1) == 64: + c = c.view(c.size(0), 1, 8, 8) + c = c.repeat(1, 1, 14, 14) + x = torch.cat([x, c], dim=1) return self.main(x) class Discriminator(nn.Module): """Discriminator network with PatchGAN.""" - def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6): + def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6, affectnet_emo_descr=None, d_loss_cls_type=None): super(Discriminator, self).__init__() layers = [] layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1)) @@ -80,7 +95,12 @@ class Discriminator(nn.Module): self.main = nn.Sequential(*layers) self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=3, stride=1, padding=1, bias=False) self.conv2 = nn.Conv2d(curr_dim, c_dim, kernel_size=kernel_size, bias=False) - + + if affectnet_emo_descr == '64d_cls' and (d_loss_cls_type in ['both', 'pred']): + self.fc = torch.nn.Linear(c_dim, 7) + elif affectnet_emo_descr == '64d_reg' and (d_loss_cls_type in ['both', 'pred']): + self.fc = torch.nn.Linear(c_dim, 2) + def forward(self, x): h = self.main(x) out_src = self.conv1(h) diff --git a/solver.py b/solver.py index 74a5e6e..d69a550 100644 --- a/solver.py +++ b/solver.py @@ -10,24 +10,71 @@ import time import datetime from sklearn.preprocessing import StandardScaler from sklearn.decomposition import PCA +import pandas as pd class LabelTransform(object): - def __init__(self, max_components_kept=14): - self.image_dir = '/vol/bitbucket/apg416/affectnet' - self.max_components_kept = max_components_kept + def __init__(self, config): + self.image_dir = config.affectnet_image_dir + self.affectnet_emo_descr = config.affectnet_emo_descr + self.pca_variant = config.pca_variant self.ss = StandardScaler() - self.pca = PCA() + self.pca = PCA(config.pca_n_components) self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - predictions_train = np.array([[float(x) for x in line.rstrip().split()] for line in open(os.path.join(self.image_dir,'train_predictions.txt'), 'r')]) - self.pca.fit(self.ss.fit_transform(predictions_train)) + if self.affectnet_emo_descr in ['64d_reg', '64d_cls']: + actvs_train = np.array([[float(x) for x in line.rstrip().split()] for line in open(os.path.join(self.image_dir, self.affectnet_emo_descr, 'train', 'predictions.txt'), 'r')]) + elif self.affectnet_emo_descr in ['va-reg', 'va-cls']: + actvs_train = np.array([[float(x) for x in line.rstrip().split()] for line in open(os.path.join(self.image_dir,'train_predictions.txt'), 'r')]) + actvs_latent_train = self.pca.fit_transform(self.ss.fit_transform(actvs_train)) + + self.stats = pd.DataFrame(actvs_latent_train).describe(percentiles=[0.1, 0.5, 0.9]).T.drop(columns=['count']) #components in rows def create_labels(self, c_org): c_trg_list = [] - for i in range(self.max_components_kept): - transformed_labels = self.pca.transform(self.ss.transform(c_org)) - transformed_labels[:,1+i:] = 0 - c_trg_list.append(torch.FloatTensor(self.ss.inverse_transform(self.pca.inverse_transform(transformed_labels))).to(self.device)) + c_org_latent = self.pca.transform(self.ss.transform(c_org)) + + if self.pca_variant == 'quantiles': + for i in range(self.pca.n_components_): + # transformed_labels_lo, transformed_labels_hi = c_org_latent, c_org_latent + # transformed_labels_lo[:,i] = self.stats.loc[i,'10%'] + # c_trg_list.append(torch.FloatTensor(self.ss.inverse_transform(self.pca.inverse_transform(transformed_labels_lo))).to(self.device)) + # transformed_labels_lo[:,i] = self.stats.loc[i,'50%'] + # c_trg_list.append(torch.FloatTensor(self.ss.inverse_transform(self.pca.inverse_transform(transformed_labels_lo))).to(self.device)) + # transformed_labels_hi[:,i] = self.stats.loc[i,'90%'] + # c_trg_list.append(torch.FloatTensor(self.ss.inverse_transform(self.pca.inverse_transform(transformed_labels_hi))).to(self.device)) + transformed_labels = c_org_latent + transformed_labels[:,i] = self.stats.loc[i,'10%'] + c_trg_list.append(torch.FloatTensor(self.ss.inverse_transform(self.pca.inverse_transform(transformed_labels))).to(self.device)) + transformed_labels = c_org_latent + transformed_labels[:,i] = self.stats.loc[i,'50%'] + c_trg_list.append(torch.FloatTensor(self.ss.inverse_transform(self.pca.inverse_transform(transformed_labels))).to(self.device)) + transformed_labels = c_org_latent + transformed_labels[:,i] = self.stats.loc[i,'90%'] + c_trg_list.append(torch.FloatTensor(self.ss.inverse_transform(self.pca.inverse_transform(transformed_labels))).to(self.device)) + + # transformed_labels = c_org_latent + # transformed_labels[:,0] = self.stats.loc[i,'10%'] + # transformed_labels[:,1] = self.stats.loc[i,'10%'] + # c_trg_list.append(torch.FloatTensor(self.ss.inverse_transform(self.pca.inverse_transform(transformed_labels))).to(self.device)) + # transformed_labels = c_org_latent + # transformed_labels[:,0] = self.stats.loc[i,'90%'] + # transformed_labels[:,1] = self.stats.loc[i,'90%'] + # c_trg_list.append(torch.FloatTensor(self.ss.inverse_transform(self.pca.inverse_transform(transformed_labels))).to(self.device)) + # transformed_labels = c_org_latent + # transformed_labels[:,0] = self.stats.loc[i,'10%'] + # transformed_labels[:,1] = self.stats.loc[i,'90%'] + # c_trg_list.append(torch.FloatTensor(self.ss.inverse_transform(self.pca.inverse_transform(transformed_labels))).to(self.device)) + # transformed_labels = c_org_latent + # transformed_labels[:,0] = self.stats.loc[i,'90%'] + # transformed_labels[:,1] = self.stats.loc[i,'10%'] + # c_trg_list.append(torch.FloatTensor(self.ss.inverse_transform(self.pca.inverse_transform(transformed_labels))).to(self.device)) + + else: + for i in range(1, self.pca.n_components_): + transformed_labels = c_org_latent + transformed_labels[:,i:] = 0 + c_trg_list.append(torch.FloatTensor(self.ss.inverse_transform(self.pca.inverse_transform(transformed_labels))).to(self.device)) + c_trg_list.append(torch.FloatTensor(c_org).to(self.device)) return c_trg_list class Solver(object): @@ -51,6 +98,7 @@ class Solver(object): self.lambda_cls = config.lambda_cls self.lambda_rec = config.lambda_rec self.lambda_gp = config.lambda_gp + self.lambda_cls2 = config.lambda_cls2 # Training configurations. self.dataset = config.dataset @@ -66,13 +114,16 @@ class Solver(object): self.selected_attrs = config.selected_attrs self.affectnet_emo_descr = config.affectnet_emo_descr self.use_ccc = config.use_ccc + self.depth_concat = config.depth_concat + self.d_loss_cls_type = config.d_loss_cls_type # Test configurations. self.test_iters = config.test_iters # Miscellaneous. self.use_tensorboard = config.use_tensorboard self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - self.label_transform = LabelTransform() + if self.affectnet_emo_descr in ['64d_reg', '64d_cls', 'va-reg', 'va-cls']: + self.label_transform = LabelTransform(config) # Directories. self.log_dir = config.log_dir @@ -94,11 +145,17 @@ class Solver(object): def build_model(self): """Create a generator and a discriminator.""" if self.dataset in ['CelebA', 'RaFD', 'AffectNet']: - self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num) - self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num) - elif self.dataset in ['Both']: - self.G = Generator(self.g_conv_dim, self.c_dim+self.c2_dim+2, self.g_repeat_num) # 2 for mask vector. - self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim+self.c2_dim, self.d_repeat_num) + self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num, self.depth_concat) + self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num, self.affectnet_emo_descr, self.d_loss_cls_type) + + # if self.affectnet_emo_descr == '64d_cls' and (self.d_loss_cls_type in ['both', 'pred']): + # self.D.fc = torch.nn.Linear(self.c_dim, 7) + # elif self.affectnet_emo_descr == '64d_reg' and (self.d_loss_cls_type in ['both', 'pred']): + # self.D.fc = torch.nn.Linear(self.c_dim, 2) + + # elif self.dataset in ['Both']: + # self.G = Generator(self.g_conv_dim, self.c_dim+self.c2_dim+2, self.g_repeat_num) # 2 for mask vector. + # self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim+self.c2_dim, self.d_repeat_num) self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2]) self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2]) @@ -205,11 +262,13 @@ class Solver(object): for a in [-1., -0.5, 0., 0.5, 1.]: c_trg = torch.Tensor([v, a]).repeat(c_org.size(0), 1) c_trg_list.append(c_trg.to(self.device)) - elif self.affectnet_emo_descr in ['va-reg', 'va-cls']: + elif self.affectnet_emo_descr in ['va-reg', 'va-cls', '64d_cls']: # c_trg_list = np.loadtxt('/vol/bitbucket/apg416/affectnet/vgg2pcs.txt') # c_trg_list = torch.FloatTensor(c_trg_list).to(self.device) # c_trg_list = [label_vec.repeat(c_org.size(0), 1) for label_vec in torch.unbind(c_trg_list)] c_trg_list = self.label_transform.create_labels(c_org[:,1:]) + elif self.affectnet_emo_descr in ['64d_reg']: + c_trg_list = self.label_transform.create_labels(c_org[:,2:]) return c_trg_list def ccc_loss(self, x, y): @@ -236,6 +295,34 @@ class Solver(object): return self.ccc_loss(logit, target) else: return F.mse_loss(logit, target) + elif self.affectnet_emo_descr in ['64d_reg']: + logit_pred, logit_actv = logit + target_pred, target_actv = target + out = 0 + if logit_pred is not None: + if self.use_ccc: + out += self.lambda_cls2/self.lambda_cls*self.ccc_loss(logit_pred, target_pred) + else: + out += self.lambda_cls2/self.lambda_cls*F.mse_loss(logit_pred, target_pred) + if logit_actv is not None: + if self.use_ccc: + out += self.ccc_loss(logit_actv, target_actv) + else: + out += F.mse_loss(logit_actv, target_actv) + return out + elif self.affectnet_emo_descr in ['64d_cls']: + logit_pred, logit_actv = logit + target_pred, target_actv = target + out = 0 + if logit_pred is not None: + out += self.lambda_cls2/self.lambda_cls*F.cross_entropy(logit_pred, target_pred) + if logit_actv is not None: + if self.use_ccc: + out += self.ccc_loss(logit_actv, target_actv) + else: + out += F.mse_loss(logit_actv, target_actv) + return out + def train(self): """Train StarGAN within a single dataset.""" # Set data loader. @@ -262,6 +349,18 @@ class Solver(object): start_iters = self.resume_iters self.restore_model(self.resume_iters) + # # --- converting + # if self.affectnet_emo_descr == '64d_cls' and (self.d_loss_cls_type in ['both', 'pred']): + # self.D.fc = torch.nn.Linear(self.c_dim, 7) + # elif self.affectnet_emo_descr == '64d_reg' and (self.d_loss_cls_type in ['both', 'pred']): + # self.D.fc = torch.nn.Linear(self.c_dim, 2) + # G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(self.resume_iters)) + # D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(self.resume_iters)) + # torch.save(self.G.state_dict(), G_path) + # torch.save(self.D.state_dict(), D_path) + # return 0 + # # --- + # Start training. print('Start training...') start_time = time.time() @@ -283,36 +382,56 @@ class Solver(object): label_trg = label_org[rand_idx] if self.dataset == 'CelebA': - c_org = label_org.clone() - c_trg = label_trg.clone() + c_org = label_org.clone().to(self.device) + c_trg = label_trg.clone().to(self.device) + label_org = label_org.to(self.device) # Labels for computing classification loss. + label_trg = label_trg.to(self.device) # Labels for computing classification loss. elif self.dataset == 'RaFD': - c_org = self.label2onehot(label_org, self.c_dim) - c_trg = self.label2onehot(label_trg, self.c_dim) + c_org = self.label2onehot(label_org, self.c_dim).to(self.device) + c_trg = self.label2onehot(label_trg, self.c_dim).to(self.device) + label_org = label_org.to(self.device) # Labels for computing classification loss. + label_trg = label_trg.to(self.device) # Labels for computing classification loss. elif self.dataset == 'AffectNet': if self.affectnet_emo_descr == 'emotiw': - c_org = self.label2onehot(label_org, self.c_dim) - c_trg = self.label2onehot(label_trg, self.c_dim) + c_org = self.label2onehot(label_org, self.c_dim).to(self.device) + c_trg = self.label2onehot(label_trg, self.c_dim).to(self.device) + label_org = label_org.to(self.device) # Labels for computing classification loss. + label_trg = label_trg.to(self.device) # Labels for computing classification loss. elif self.affectnet_emo_descr == 'va': - c_org = label_org.clone() - c_trg = label_trg.clone() + c_org = label_org.clone().to(self.device) + c_trg = label_trg.clone().to(self.device) + label_org = label_org.to(self.device) # Labels for computing classification loss. + label_trg = label_trg.to(self.device) # Labels for computing classification loss. elif self.affectnet_emo_descr == 'va-reg': - label_org = label_org[:,1:] - label_trg = label_trg[:,1:] - c_org = label_org.clone() - c_trg = label_trg.clone() + label_org = label_org[:,1:].to(self.device) + label_trg = label_trg[:,1:].to(self.device) + c_org = label_org.clone().to(self.device) + c_trg = label_trg.clone().to(self.device) elif self.affectnet_emo_descr == 'va-cls': - cls_org = label_org[:,0].clone().long().to(self.device) - cls_trg = label_trg[:,0].clone().long().to(self.device) - label_org = label_org[:,1:] - label_trg = label_trg[:,1:] - c_org = label_org.clone() - c_trg = label_trg.clone() + pred_org = label_org[:,0].clone().long().to(self.device) + pred_trg = label_trg[:,0].clone().long().to(self.device) + label_org = label_org[:,1:].to(self.device) + label_trg = label_trg[:,1:].to(self.device) + c_org = label_org.clone().to(self.device) + c_trg = label_trg.clone().to(self.device) + elif self.affectnet_emo_descr == '64d_cls': + pred_org = label_org[:,0].clone().long().to(self.device) + pred_trg = label_trg[:,0].clone().long().to(self.device) + actv_org = label_org[:,1:].to(self.device) + actv_trg = label_trg[:,1:].to(self.device) + c_org = actv_org.clone().to(self.device) + c_trg = actv_trg.clone().to(self.device) + elif self.affectnet_emo_descr == '64d_reg': + pred_org = label_org[:,:2].clone().float().to(self.device) + pred_trg = label_trg[:,:2].clone().float().to(self.device) + actv_org = label_org[:,2:].to(self.device) + actv_trg = label_trg[:,2:].to(self.device) + c_org = actv_org.clone().to(self.device) + c_trg = actv_trg.clone().to(self.device) + x_real = x_real.to(self.device) # Input images. - c_org = c_org.to(self.device) # Original domain labels. - c_trg = c_trg.to(self.device) # Target domain labels. - label_org = label_org.to(self.device) # Labels for computing classification loss. - label_trg = label_trg.to(self.device) # Labels for computing classification loss. + # =================================================================================== # # 2. Train the discriminator # @@ -322,7 +441,14 @@ class Solver(object): out_src, out_cls = self.D(x_real) d_loss_real = - torch.mean(out_src) if self.affectnet_emo_descr == 'va-cls': - d_loss_cls = self.classification_loss(out_cls, cls_org, self.dataset) + d_loss_cls = self.classification_loss(out_cls, pred_org, self.dataset) + elif self.affectnet_emo_descr in ['64d_reg', '64d_cls']: + if self.d_loss_cls_type == 'actv': + d_loss_cls = self.classification_loss((None, out_cls), (None, actv_org), self.dataset) + elif self.d_loss_cls_type == 'pred': + d_loss_cls = self.classification_loss((self.D.fc(out_cls), None), (pred_org, None), self.dataset) + else: + d_loss_cls = self.classification_loss((self.D.fc(out_cls), out_cls), (pred_org, actv_org), self.dataset) else: d_loss_cls = self.classification_loss(out_cls, label_org, self.dataset) @@ -359,8 +485,19 @@ class Solver(object): x_fake = self.G(x_real, c_trg) out_src, out_cls = self.D(x_fake) g_loss_fake = - torch.mean(out_src) + # if self.affectnet_emo_descr == 'va-cls': + # g_loss_cls = self.classification_loss(out_cls, pred_trg, self.dataset) + # else: + # g_loss_cls = self.classification_loss(out_cls, label_trg, self.dataset) if self.affectnet_emo_descr == 'va-cls': - g_loss_cls = self.classification_loss(out_cls, cls_trg, self.dataset) + g_loss_cls = self.classification_loss(out_cls, pred_trg, self.dataset) + elif self.affectnet_emo_descr in ['64d_reg', '64d_cls']: + if self.d_loss_cls_type == 'actv': + g_loss_cls = self.classification_loss((None, out_cls), (None, actv_trg), self.dataset) + elif self.d_loss_cls_type == 'pred': + g_loss_cls = self.classification_loss((self.D.fc(out_cls), None), (pred_trg, None), self.dataset) + else: + g_loss_cls = self.classification_loss((self.D.fc(out_cls), out_cls), (pred_trg, actv_trg), self.dataset) else: g_loss_cls = self.classification_loss(out_cls, label_trg, self.dataset) @@ -632,9 +769,11 @@ class Solver(object): # Save the translated images. x_concat = torch.cat(x_fake_list, dim=3) result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1)) + # result_path = os.path.join(self.result_dir, '{}.jpg'.format(self.model_save_dir.replace("/vol/bitbucket/apg416/","").replace("/models",""))) save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0) print('Saved real and fake images into {}...'.format(result_path)) + # def test_multi(self): # """Translate images using StarGAN trained on multiple datasets.""" # # Load the trained generator.