Add labels from vggface
Esse commit está contido em:
+28
-10
@@ -98,7 +98,7 @@ class AffectNet(data.Dataset):
|
||||
for line in lines_test:
|
||||
filename, label = line.split()
|
||||
self.test_dataset.append([filename, int(label)])
|
||||
else:
|
||||
elif self.affectnet_emo_descr == 'va':
|
||||
lines_train = [line.rstrip() for line in open(os.path.join(self.image_dir, '_affectnet_train_va.txt'), 'r')]
|
||||
for line in lines_train:
|
||||
split = line.split()
|
||||
@@ -109,20 +109,38 @@ class AffectNet(data.Dataset):
|
||||
split = line.split()
|
||||
filename, v, a = split
|
||||
self.test_dataset.append([filename, [float(v), float(a)]])
|
||||
elif self.affectnet_emo_descr in ['va-reg', 'va-cls']:
|
||||
for mode, dataset in zip(['train_', 'test_'], [self.train_dataset, self.test_dataset]):
|
||||
filenames = [line.rstrip() for line in open(os.path.join(self.image_dir, mode + 'images.txt'), 'r')]
|
||||
labels = [int(line.rstrip()) for line in open(os.path.join(self.image_dir, mode + 'labels.txt'), 'r')]
|
||||
predictions = [[float(x) for x in line.rstrip().split()] for line in open(os.path.join(self.image_dir, mode + 'predictions.txt'), 'r')]
|
||||
|
||||
dataset += list(zip(filenames, labels, predictions))
|
||||
|
||||
print('Finished preprocessing the AffectNet dataset...')
|
||||
|
||||
def __getitem__(self, index):
    """Return one transformed image and its label tensor.

    The record layout depends on ``self.affectnet_emo_descr``:

    * ``'emotiw'`` / ``'va'``: records are ``(filename, label)`` and the
      target is ``torch.tensor(label)``.
    * ``'va-reg'`` / ``'va-cls'``: records are
      ``(filename, label, prediction)`` and the target is the label
      followed by the prediction values in a single tensor.

    Images are read from the ``train`` sub-directory of ``self.image_dir``
    when ``self.mode == 'train'`` and from ``validation`` otherwise.

    Raises:
        ValueError: if ``self.affectnet_emo_descr`` is not one of the
            supported descriptors (previously this silently returned
            ``None``).
    """
    descr = self.affectnet_emo_descr
    # Fail early with a clear message instead of falling through.
    if descr not in ('emotiw', 'va', 'va-reg', 'va-cls'):
        raise ValueError(
            'Unsupported affectnet_emo_descr: {!r}'.format(descr))

    # The split only selects the record list and the image sub-directory.
    if self.mode == 'train':
        record, subdir = self.train_dataset[index], 'train'
    else:
        record, subdir = self.test_dataset[index], 'validation'

    if descr in ('emotiw', 'va'):
        filename, label = record
        target = torch.tensor(label)
    else:
        # 'va-reg' / 'va-cls' records carry an extra prediction list.
        filename, label, prediction = record
        target = torch.tensor([label] + list(prediction))

    image = Image.open(os.path.join(self.image_dir, subdir, filename))
    return self.transform(image), target
|
||||
|
||||
def __len__(self):
|
||||
"""Return the number of images."""
|
||||
|
||||
@@ -86,6 +86,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--selected_attrs', '--list', nargs='+', help='selected attributes for the CelebA dataset',
|
||||
default=['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young'])
|
||||
parser.add_argument('--affectnet_emo_descr', type=str, default='emotiw', help='emotiw for categorical emotions, va for valence-arousal')
|
||||
parser.add_argument('--use_ccc', type=bool, default=False, help='use ccc loss instead of mse')
|
||||
# Test configuration.
|
||||
parser.add_argument('--test_iters', type=int, default=200000, help='test model from this step')
|
||||
|
||||
|
||||
+58
-21
@@ -45,6 +45,7 @@ class Solver(object):
|
||||
self.resume_iters = config.resume_iters
|
||||
self.selected_attrs = config.selected_attrs
|
||||
self.affectnet_emo_descr = config.affectnet_emo_descr
|
||||
self.use_ccc = config.use_ccc
|
||||
# Test configurations.
|
||||
self.test_iters = config.test_iters
|
||||
|
||||
@@ -148,20 +149,15 @@ class Solver(object):
|
||||
|
||||
def create_labels(self, c_org, c_dim=5, dataset='CelebA', selected_attrs=None):
|
||||
"""Generate target domain labels for debugging and testing."""
|
||||
# Get hair color indices.
|
||||
if dataset == 'CelebA':
|
||||
hair_color_indices = []
|
||||
for i, attr_name in enumerate(selected_attrs):
|
||||
if attr_name in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']:
|
||||
hair_color_indices.append(i)
|
||||
if dataset in ['CelebA', 'RaFD']:
|
||||
# Get hair color indices.
|
||||
if dataset == 'CelebA':
|
||||
hair_color_indices = []
|
||||
for i, attr_name in enumerate(selected_attrs):
|
||||
if attr_name in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']:
|
||||
hair_color_indices.append(i)
|
||||
|
||||
c_trg_list = []
|
||||
if dataset == 'AffectNet' and self.affectnet_emo_descr == 'va':
|
||||
for v in [-1., -0.5, 0., 0.5, 1.]:
|
||||
for a in [-1., -0.5, 0., 0.5, 1.]:
|
||||
c_trg = torch.Tensor([v, a]).repeat(c_org.size(0), 1)
|
||||
c_trg_list.append(c_trg.to(self.device))
|
||||
else:
|
||||
c_trg_list = []
|
||||
for i in range(c_dim):
|
||||
if dataset == 'CelebA':
|
||||
c_trg = c_org.clone()
|
||||
@@ -174,12 +170,32 @@ class Solver(object):
|
||||
c_trg[:, i] = (c_trg[:, i] == 0) # Reverse attribute value.
|
||||
elif dataset == 'RaFD':
|
||||
c_trg = self.label2onehot(torch.ones(c_org.size(0))*i, c_dim)
|
||||
elif dataset == 'AffectNet' and self.affectnet_emo_descr == 'emotiw':
|
||||
c_trg = self.label2onehot(torch.ones(c_org.size(0))*i, c_dim)
|
||||
|
||||
c_trg_list.append(c_trg.to(self.device))
|
||||
|
||||
elif dataset == 'AffectNet':
|
||||
c_trg_list = []
|
||||
if self.affectnet_emo_descr in ['emotiw', 'va-reg', 'va-cls']:
|
||||
for i in range(c_dim): #One hot for simplicty, change to PCA later
|
||||
c_trg = self.label2onehot(torch.ones(c_org.size(0))*i, c_dim)
|
||||
c_trg_list.append(c_trg.to(self.device))
|
||||
elif self.affectnet_emo_descr == 'va':
|
||||
for v in [-1., -0.5, 0., 0.5, 1.]:
|
||||
for a in [-1., -0.5, 0., 0.5, 1.]:
|
||||
c_trg = torch.Tensor([v, a]).repeat(c_org.size(0), 1)
|
||||
c_trg_list.append(c_trg.to(self.device))
|
||||
return c_trg_list
|
||||
|
||||
def ccc_loss(self, x, y):
    """Return the concordance correlation coefficient (CCC) loss of x vs. y.

    Means, variances and the covariance are computed along dim 0; the
    per-column loss ``1 - ccc`` is then summed over the remaining entries.

    Args:
        x: tensor of predictions.
        y: tensor of targets, broadcast-compatible with ``x``.

    Returns:
        A scalar tensor equal to ``sum(1 - ccc)``.
    """
    mu_x = x.mean(0)
    mu_y = y.mean(0)
    dx = x - mu_x
    dy = y - mu_y
    # Use population (divide-by-n) moments for the variances as well as the
    # covariance. Mixing the unbiased variance with a mean-based covariance
    # makes CCC equal (n - 1) / n for identical inputs, so a perfect
    # prediction would still be penalised.
    var_x = torch.mean(dx * dx, 0)
    var_y = torch.mean(dy * dy, 0)
    cov_xy = torch.mean(dx * dy, 0)
    ccc = 2 * cov_xy / (var_x + var_y + (mu_x - mu_y) ** 2)
    return (1 - ccc).sum()
|
||||
|
||||
def classification_loss(self, logit, target, dataset='CelebA'):
|
||||
"""Compute binary or softmax cross entropy loss."""
|
||||
if dataset == 'CelebA':
|
||||
@@ -187,10 +203,13 @@ class Solver(object):
|
||||
elif dataset == 'RaFD':
|
||||
return F.cross_entropy(logit, target)
|
||||
elif dataset == 'AffectNet':
|
||||
if self.affectnet_emo_descr == 'emotiw':
|
||||
if self.affectnet_emo_descr in ['emotiw', 'va-cls']:
|
||||
return F.cross_entropy(logit, target)
|
||||
else:
|
||||
return F.mse_loss(torch.tanh(logit), target)
|
||||
elif self.affectnet_emo_descr in ['va', 'va-reg']:
|
||||
if self.use_ccc:
|
||||
return self.ccc_loss(logit, target)
|
||||
else:
|
||||
return F.mse_loss(logit, target)
|
||||
def train(self):
|
||||
"""Train StarGAN within a single dataset."""
|
||||
# Set data loader.
|
||||
@@ -247,7 +266,19 @@ class Solver(object):
|
||||
if self.affectnet_emo_descr == 'emotiw':
|
||||
c_org = self.label2onehot(label_org, self.c_dim)
|
||||
c_trg = self.label2onehot(label_trg, self.c_dim)
|
||||
else:
|
||||
elif self.affectnet_emo_descr == 'va':
|
||||
c_org = label_org.clone()
|
||||
c_trg = label_trg.clone()
|
||||
elif self.affectnet_emo_descr == 'va-reg':
|
||||
label_org = label_org[:,1:]
|
||||
label_trg = label_trg[:,1:]
|
||||
c_org = label_org.clone()
|
||||
c_trg = label_trg.clone()
|
||||
elif self.affectnet_emo_descr == 'va-cls':
|
||||
cls_org = label_org[:,0].clone().long().to(self.device)
|
||||
cls_trg = label_trg[:,0].clone().long().to(self.device)
|
||||
label_org = label_org[:,1:]
|
||||
label_trg = label_trg[:,1:]
|
||||
c_org = label_org.clone()
|
||||
c_trg = label_trg.clone()
|
||||
|
||||
@@ -264,7 +295,10 @@ class Solver(object):
|
||||
# Compute loss with real images.
|
||||
out_src, out_cls = self.D(x_real)
|
||||
d_loss_real = - torch.mean(out_src)
|
||||
d_loss_cls = self.classification_loss(out_cls, label_org, self.dataset)
|
||||
if self.affectnet_emo_descr == 'va-cls':
|
||||
d_loss_cls = self.classification_loss(out_cls, cls_org, self.dataset)
|
||||
else:
|
||||
d_loss_cls = self.classification_loss(out_cls, label_org, self.dataset)
|
||||
|
||||
# Compute loss with fake images.
|
||||
x_fake = self.G(x_real, c_trg)
|
||||
@@ -299,7 +333,10 @@ class Solver(object):
|
||||
x_fake = self.G(x_real, c_trg)
|
||||
out_src, out_cls = self.D(x_fake)
|
||||
g_loss_fake = - torch.mean(out_src)
|
||||
g_loss_cls = self.classification_loss(out_cls, label_trg, self.dataset)
|
||||
if self.affectnet_emo_descr == 'va-cls':
|
||||
g_loss_cls = self.classification_loss(out_cls, cls_trg, self.dataset)
|
||||
else:
|
||||
g_loss_cls = self.classification_loss(out_cls, label_trg, self.dataset)
|
||||
|
||||
# Target-to-original domain.
|
||||
x_reconst = self.G(x_fake, c_org)
|
||||
|
||||
Referência em uma Nova Issue
Bloquear um usuário