Added AffectNet dataset w/ categorical and valence-arousal emotion descriptors

Esse commit está contido em:
Aleksander
2020-02-17 00:29:50 +00:00
commit a24d8688dd
4 arquivos alterados com 359 adições e 262 exclusões
+66 -3
Ver Arquivo
@@ -68,8 +68,69 @@ class CelebA(data.Dataset):
return self.num_images return self.num_images
def get_loader(image_dir, attr_path, selected_attrs, crop_size=178, image_size=128, class AffectNet(data.Dataset):
batch_size=16, dataset='CelebA', mode='train', num_workers=1): """Dataset class for the AffectNet dataset."""
def __init__(self, image_dir, affectnet_emo_descr, transform, mode):
"""Initialize and preprocess the AffectNet dataset."""
self.image_dir = image_dir
self.affectnet_emo_descr = affectnet_emo_descr
self.transform = transform
self.mode = mode
self.train_dataset = []
self.test_dataset = []
self.preprocess()
if mode == 'train':
self.num_images = len(self.train_dataset)
else:
self.num_images = len(self.test_dataset)
def preprocess(self):
"""Preprocess the AffectNet emotional description file."""
if self.affectnet_emo_descr == 'emotiw':
lines_train = [line.rstrip() for line in open(os.path.join(self.image_dir, '_affectnet_train_emotiw.txt'), 'r')]
for line in lines_train:
filename, label = line.split()
self.train_dataset.append([filename, int(label)])
lines_test = [line.rstrip() for line in open(os.path.join(self.image_dir, '_affectnet_test_emotiw.txt'), 'r')]
for line in lines_test:
filename, label = line.split()
self.test_dataset.append([filename, int(label)])
else:
lines_train = [line.rstrip() for line in open(os.path.join(self.image_dir, '_affectnet_train_va.txt'), 'r')]
for line in lines_train:
split = line.split()
filename, v, a = split
self.train_dataset.append([filename, [float(v), float(a)]])
lines_test = [line.rstrip() for line in open(os.path.join(self.image_dir, '_affectnet_test_va.txt'), 'r')]
for line in lines_test:
split = line.split()
filename, v, a = split
self.test_dataset.append([filename, [float(v), float(a)]])
print('Finished preprocessing the AffectNet dataset...')
def __getitem__(self, index):
"""Return one image and its corresponding attribute label."""
if self.mode == 'train':
dataset = self.train_dataset
filename, label = dataset[index]
image = Image.open(os.path.join(self.image_dir, 'train', filename))
else:
dataset = self.test_dataset
filename, label = dataset[index]
image = Image.open(os.path.join(self.image_dir, 'validation', filename))
return self.transform(image), torch.tensor(label)
def __len__(self):
"""Return the number of images."""
return self.num_images
def get_loader(image_dir, attr_path, selected_attrs, crop_size=178, image_size=128,
batch_size=16, dataset='CelebA', mode='train', affectnet_emo_descr='emotiw', num_workers=1):
"""Build and return a data loader.""" """Build and return a data loader."""
transform = [] transform = []
if mode == 'train': if mode == 'train':
@@ -84,9 +145,11 @@ def get_loader(image_dir, attr_path, selected_attrs, crop_size=178, image_size=1
dataset = CelebA(image_dir, attr_path, selected_attrs, transform, mode) dataset = CelebA(image_dir, attr_path, selected_attrs, transform, mode)
elif dataset == 'RaFD': elif dataset == 'RaFD':
dataset = ImageFolder(image_dir, transform) dataset = ImageFolder(image_dir, transform)
elif dataset == 'AffectNet':
dataset = AffectNet(image_dir, affectnet_emo_descr, transform, mode)
data_loader = data.DataLoader(dataset=dataset, data_loader = data.DataLoader(dataset=dataset,
batch_size=batch_size, batch_size=batch_size,
shuffle=(mode=='train'), shuffle=(mode=='train'),
num_workers=num_workers) num_workers=num_workers)
return data_loader return data_loader
+13 -7
Ver Arquivo
@@ -1,8 +1,14 @@
URL=$1 URL1=$1
URL2=$2
# Download affectnet # Download affectnet and labels
ZIP_FILE=./data/affectnet.zip mkdir -p ./data/affectnet
mkdir -p ./data/ ZIP_FILE1=./data/affectnet.zip
wget -N $URL -O $ZIP_FILE wget -N $URL1 -O $ZIP_FILE1
unzip $ZIP_FILE -d ./data/ unzip $ZIP_FILE1 -d ./data/affectnet
rm $ZIP_FILE rm $ZIP_FILE1
ZIP_FILE2=./data/affectnet_labels.zip
wget -N $URL2 -O $ZIP_FILE2
unzip $ZIP_FILE2 -d ./data/affectnet
rm $ZIP_FILE2
+20 -15
Ver Arquivo
@@ -25,30 +25,33 @@ def main(config):
# Data loader. # Data loader.
celeba_loader = None celeba_loader = None
rafd_loader = None rafd_loader = None
affectnet_loader = None
if config.dataset in ['CelebA', 'Both']: if config.dataset in ['CelebA', 'Both']:
celeba_loader = get_loader(config.celeba_image_dir, config.attr_path, config.selected_attrs, celeba_loader = get_loader(config.celeba_image_dir, config.attr_path, config.selected_attrs,
config.celeba_crop_size, config.image_size, config.batch_size, config.celeba_crop_size, config.image_size, config.batch_size,
'CelebA', config.mode, config.num_workers) 'CelebA', config.mode, None, config.num_workers)
if config.dataset in ['RaFD', 'Both']: if config.dataset in ['RaFD', 'Both']:
rafd_loader = get_loader(config.rafd_image_dir, None, None, rafd_loader = get_loader(config.rafd_image_dir, None, None,
config.rafd_crop_size, config.image_size, config.batch_size, config.rafd_crop_size, config.image_size, config.batch_size,
'RaFD', config.mode, config.num_workers) 'RaFD', config.mode, None, config.num_workers)
if config.dataset in ['AffectNet']:
affectnet_loader = get_loader(config.affectnet_image_dir, None, None,
config.affectnet_crop_size, config.image_size, config.batch_size,
'AffectNet', config.mode, config.affectnet_emo_descr, config.num_workers)
# Solver for training and testing StarGAN. # Solver for training and testing StarGAN.
solver = Solver(celeba_loader, rafd_loader, config) solver = Solver(celeba_loader, rafd_loader, affectnet_loader, config)
if config.mode == 'train': if config.mode == 'train':
if config.dataset in ['CelebA', 'RaFD']: if config.dataset in ['CelebA', 'RaFD', 'AffectNet']:
solver.train() solver.train()
elif config.dataset in ['Both']: # elif config.dataset in ['Both']:
solver.train_multi() # solver.train_multi()
elif config.mode == 'test': elif config.mode == 'test':
if config.dataset in ['CelebA', 'RaFD']: if config.dataset in ['CelebA', 'RaFD', 'AffectNet']:
solver.test() solver.test()
elif config.dataset in ['Both']: # elif config.dataset in ['Both']:
solver.test_multi() # solver.test_multi()
if __name__ == '__main__': if __name__ == '__main__':
@@ -59,6 +62,7 @@ if __name__ == '__main__':
parser.add_argument('--c2_dim', type=int, default=8, help='dimension of domain labels (2nd dataset)') parser.add_argument('--c2_dim', type=int, default=8, help='dimension of domain labels (2nd dataset)')
parser.add_argument('--celeba_crop_size', type=int, default=178, help='crop size for the CelebA dataset') parser.add_argument('--celeba_crop_size', type=int, default=178, help='crop size for the CelebA dataset')
parser.add_argument('--rafd_crop_size', type=int, default=256, help='crop size for the RaFD dataset') parser.add_argument('--rafd_crop_size', type=int, default=256, help='crop size for the RaFD dataset')
parser.add_argument('--affectnet_crop_size', type=int, default=112, help='crop size for the AffectNet dataset')
parser.add_argument('--image_size', type=int, default=128, help='image resolution') parser.add_argument('--image_size', type=int, default=128, help='image resolution')
parser.add_argument('--g_conv_dim', type=int, default=64, help='number of conv filters in the first layer of G') parser.add_argument('--g_conv_dim', type=int, default=64, help='number of conv filters in the first layer of G')
parser.add_argument('--d_conv_dim', type=int, default=64, help='number of conv filters in the first layer of D') parser.add_argument('--d_conv_dim', type=int, default=64, help='number of conv filters in the first layer of D')
@@ -67,9 +71,9 @@ if __name__ == '__main__':
parser.add_argument('--lambda_cls', type=float, default=1, help='weight for domain classification loss') parser.add_argument('--lambda_cls', type=float, default=1, help='weight for domain classification loss')
parser.add_argument('--lambda_rec', type=float, default=10, help='weight for reconstruction loss') parser.add_argument('--lambda_rec', type=float, default=10, help='weight for reconstruction loss')
parser.add_argument('--lambda_gp', type=float, default=10, help='weight for gradient penalty') parser.add_argument('--lambda_gp', type=float, default=10, help='weight for gradient penalty')
# Training configuration. # Training configuration.
parser.add_argument('--dataset', type=str, default='CelebA', choices=['CelebA', 'RaFD', 'Both']) parser.add_argument('--dataset', type=str, default='CelebA', choices=['CelebA', 'RaFD', 'AffectNet', 'Both'])
parser.add_argument('--batch_size', type=int, default=16, help='mini-batch size') parser.add_argument('--batch_size', type=int, default=16, help='mini-batch size')
parser.add_argument('--num_iters', type=int, default=200000, help='number of total iterations for training D') parser.add_argument('--num_iters', type=int, default=200000, help='number of total iterations for training D')
parser.add_argument('--num_iters_decay', type=int, default=100000, help='number of iterations for decaying lr') parser.add_argument('--num_iters_decay', type=int, default=100000, help='number of iterations for decaying lr')
@@ -81,7 +85,7 @@ if __name__ == '__main__':
parser.add_argument('--resume_iters', type=int, default=None, help='resume training from this step') parser.add_argument('--resume_iters', type=int, default=None, help='resume training from this step')
parser.add_argument('--selected_attrs', '--list', nargs='+', help='selected attributes for the CelebA dataset', parser.add_argument('--selected_attrs', '--list', nargs='+', help='selected attributes for the CelebA dataset',
default=['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young']) default=['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young'])
parser.add_argument('--affectnet_emo_descr', type=str, default='emotiw', help='emotiw for categorical emotions, va for valence-arousal')
# Test configuration. # Test configuration.
parser.add_argument('--test_iters', type=int, default=200000, help='test model from this step') parser.add_argument('--test_iters', type=int, default=200000, help='test model from this step')
@@ -94,6 +98,7 @@ if __name__ == '__main__':
parser.add_argument('--celeba_image_dir', type=str, default='data/celeba/images') parser.add_argument('--celeba_image_dir', type=str, default='data/celeba/images')
parser.add_argument('--attr_path', type=str, default='data/celeba/list_attr_celeba.txt') parser.add_argument('--attr_path', type=str, default='data/celeba/list_attr_celeba.txt')
parser.add_argument('--rafd_image_dir', type=str, default='data/RaFD/train') parser.add_argument('--rafd_image_dir', type=str, default='data/RaFD/train')
parser.add_argument('--affectnet_image_dir', type=str, default='data/affectnet')
parser.add_argument('--log_dir', type=str, default='stargan/logs') parser.add_argument('--log_dir', type=str, default='stargan/logs')
parser.add_argument('--model_save_dir', type=str, default='stargan/models') parser.add_argument('--model_save_dir', type=str, default='stargan/models')
parser.add_argument('--sample_dir', type=str, default='stargan/samples') parser.add_argument('--sample_dir', type=str, default='stargan/samples')
@@ -107,4 +112,4 @@ if __name__ == '__main__':
config = parser.parse_args() config = parser.parse_args()
print(config) print(config)
main(config) main(config)
+260 -237
Ver Arquivo
@@ -13,13 +13,13 @@ import datetime
class Solver(object): class Solver(object):
"""Solver for training and testing StarGAN.""" """Solver for training and testing StarGAN."""
def __init__(self, celeba_loader, rafd_loader, config): def __init__(self, celeba_loader, rafd_loader, affectnet_loader, config):
"""Initialize configurations.""" """Initialize configurations."""
# Data loader. # Data loader.
self.celeba_loader = celeba_loader self.celeba_loader = celeba_loader
self.rafd_loader = rafd_loader self.rafd_loader = rafd_loader
self.affectnet_loader = affectnet_loader
# Model configurations. # Model configurations.
self.c_dim = config.c_dim self.c_dim = config.c_dim
self.c2_dim = config.c2_dim self.c2_dim = config.c2_dim
@@ -44,7 +44,7 @@ class Solver(object):
self.beta2 = config.beta2 self.beta2 = config.beta2
self.resume_iters = config.resume_iters self.resume_iters = config.resume_iters
self.selected_attrs = config.selected_attrs self.selected_attrs = config.selected_attrs
self.affectnet_emo_descr = config.affectnet_emo_descr
# Test configurations. # Test configurations.
self.test_iters = config.test_iters self.test_iters = config.test_iters
@@ -71,9 +71,9 @@ class Solver(object):
def build_model(self): def build_model(self):
"""Create a generator and a discriminator.""" """Create a generator and a discriminator."""
if self.dataset in ['CelebA', 'RaFD']: if self.dataset in ['CelebA', 'RaFD', 'AffectNet']:
self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num) self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num)
self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num) self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num)
elif self.dataset in ['Both']: elif self.dataset in ['Both']:
self.G = Generator(self.g_conv_dim, self.c_dim+self.c2_dim+2, self.g_repeat_num) # 2 for mask vector. self.G = Generator(self.g_conv_dim, self.c_dim+self.c2_dim+2, self.g_repeat_num) # 2 for mask vector.
self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim+self.c2_dim, self.d_repeat_num) self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim+self.c2_dim, self.d_repeat_num)
@@ -82,7 +82,7 @@ class Solver(object):
self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2]) self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])
self.print_network(self.G, 'G') self.print_network(self.G, 'G')
self.print_network(self.D, 'D') self.print_network(self.D, 'D')
self.G.to(self.device) self.G.to(self.device)
self.D.to(self.device) self.D.to(self.device)
@@ -156,29 +156,41 @@ class Solver(object):
hair_color_indices.append(i) hair_color_indices.append(i)
c_trg_list = [] c_trg_list = []
for i in range(c_dim): if dataset == 'AffectNet' and self.affectnet_emo_descr == 'va':
if dataset == 'CelebA': for v in [-1., -0.5, 0., 0.5, 1.]:
c_trg = c_org.clone() for a in [-1., -0.5, 0., 0.5, 1.]:
if i in hair_color_indices: # Set one hair color to 1 and the rest to 0. c_trg = torch.Tensor([v, a]).repeat(c_org.size(0), 1)
c_trg[:, i] = 1 c_trg_list.append(c_trg.to(self.device))
for j in hair_color_indices: else:
if j != i: for i in range(c_dim):
c_trg[:, j] = 0 if dataset == 'CelebA':
else: c_trg = c_org.clone()
c_trg[:, i] = (c_trg[:, i] == 0) # Reverse attribute value. if i in hair_color_indices: # Set one hair color to 1 and the rest to 0.
elif dataset == 'RaFD': c_trg[:, i] = 1
c_trg = self.label2onehot(torch.ones(c_org.size(0))*i, c_dim) for j in hair_color_indices:
if j != i:
c_trg[:, j] = 0
else:
c_trg[:, i] = (c_trg[:, i] == 0) # Reverse attribute value.
elif dataset == 'RaFD':
c_trg = self.label2onehot(torch.ones(c_org.size(0))*i, c_dim)
elif dataset == 'AffectNet' and self.affectnet_emo_descr == 'emotiw':
c_trg = self.label2onehot(torch.ones(c_org.size(0))*i, c_dim)
c_trg_list.append(c_trg.to(self.device))
c_trg_list.append(c_trg.to(self.device))
return c_trg_list return c_trg_list
def classification_loss(self, logit, target, dataset='CelebA'): def classification_loss(self, logit, target, dataset='CelebA'):
"""Compute binary or softmax cross entropy loss.""" """Compute binary or softmax cross entropy loss."""
if dataset == 'CelebA': if dataset == 'CelebA':
return F.binary_cross_entropy_with_logits(logit, target, size_average=False) / logit.size(0) return F.binary_cross_entropy_with_logits(logit, target, reduction='sum') / logit.size(0)
elif dataset == 'RaFD': elif dataset == 'RaFD':
return F.cross_entropy(logit, target) return F.cross_entropy(logit, target)
elif dataset == 'AffectNet':
if self.affectnet_emo_descr == 'emotiw':
return F.cross_entropy(logit, target)
else:
return F.mse_loss(torch.tanh(logit), target)
def train(self): def train(self):
"""Train StarGAN within a single dataset.""" """Train StarGAN within a single dataset."""
# Set data loader. # Set data loader.
@@ -186,6 +198,8 @@ class Solver(object):
data_loader = self.celeba_loader data_loader = self.celeba_loader
elif self.dataset == 'RaFD': elif self.dataset == 'RaFD':
data_loader = self.rafd_loader data_loader = self.rafd_loader
elif self.dataset == 'AffectNet':
data_loader = self.affectnet_loader
# Fetch fixed inputs for debugging. # Fetch fixed inputs for debugging.
data_iter = iter(data_loader) data_iter = iter(data_loader)
@@ -229,6 +243,13 @@ class Solver(object):
elif self.dataset == 'RaFD': elif self.dataset == 'RaFD':
c_org = self.label2onehot(label_org, self.c_dim) c_org = self.label2onehot(label_org, self.c_dim)
c_trg = self.label2onehot(label_trg, self.c_dim) c_trg = self.label2onehot(label_trg, self.c_dim)
elif self.dataset == 'AffectNet':
if self.affectnet_emo_descr == 'emotiw':
c_org = self.label2onehot(label_org, self.c_dim)
c_trg = self.label2onehot(label_trg, self.c_dim)
else:
c_org = label_org.clone()
c_trg = label_trg.clone()
x_real = x_real.to(self.device) # Input images. x_real = x_real.to(self.device) # Input images.
c_org = c_org.to(self.device) # Original domain labels. c_org = c_org.to(self.device) # Original domain labels.
@@ -268,11 +289,11 @@ class Solver(object):
loss['D/loss_fake'] = d_loss_fake.item() loss['D/loss_fake'] = d_loss_fake.item()
loss['D/loss_cls'] = d_loss_cls.item() loss['D/loss_cls'] = d_loss_cls.item()
loss['D/loss_gp'] = d_loss_gp.item() loss['D/loss_gp'] = d_loss_gp.item()
# =================================================================================== # # =================================================================================== #
# 3. Train the generator # # 3. Train the generator #
# =================================================================================== # # =================================================================================== #
if (i+1) % self.n_critic == 0: if (i+1) % self.n_critic == 0:
# Original-to-target domain. # Original-to-target domain.
x_fake = self.G(x_real, c_trg) x_fake = self.G(x_real, c_trg)
@@ -338,199 +359,201 @@ class Solver(object):
self.update_lr(g_lr, d_lr) self.update_lr(g_lr, d_lr)
print ('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr)) print ('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))
def train_multi(self): # def train_multi(self):
"""Train StarGAN with multiple datasets.""" # """Train StarGAN with multiple datasets."""
# Data iterators. # # Data iterators.
celeba_iter = iter(self.celeba_loader) # celeba_iter = iter(self.celeba_loader)
rafd_iter = iter(self.rafd_loader) # rafd_iter = iter(self.rafd_loader)
#
# Fetch fixed inputs for debugging. # # Fetch fixed inputs for debugging.
x_fixed, c_org = next(celeba_iter) # x_fixed, c_org = next(celeba_iter)
x_fixed = x_fixed.to(self.device) # x_fixed = x_fixed.to(self.device)
c_celeba_list = self.create_labels(c_org, self.c_dim, 'CelebA', self.selected_attrs) # c_celeba_list = self.create_labels(c_org, self.c_dim, 'CelebA', self.selected_attrs)
c_rafd_list = self.create_labels(c_org, self.c2_dim, 'RaFD') # c_rafd_list = self.create_labels(c_org, self.c2_dim, 'RaFD')
zero_celeba = torch.zeros(x_fixed.size(0), self.c_dim).to(self.device) # Zero vector for CelebA. # zero_celeba = torch.zeros(x_fixed.size(0), self.c_dim).to(self.device) # Zero vector for CelebA.
zero_rafd = torch.zeros(x_fixed.size(0), self.c2_dim).to(self.device) # Zero vector for RaFD. # zero_rafd = torch.zeros(x_fixed.size(0), self.c2_dim).to(self.device) # Zero vector for RaFD.
mask_celeba = self.label2onehot(torch.zeros(x_fixed.size(0)), 2).to(self.device) # Mask vector: [1, 0]. # mask_celeba = self.label2onehot(torch.zeros(x_fixed.size(0)), 2).to(self.device) # Mask vector: [1, 0].
mask_rafd = self.label2onehot(torch.ones(x_fixed.size(0)), 2).to(self.device) # Mask vector: [0, 1]. # mask_rafd = self.label2onehot(torch.ones(x_fixed.size(0)), 2).to(self.device) # Mask vector: [0, 1].
#
# Learning rate cache for decaying. # # Learning rate cache for decaying.
g_lr = self.g_lr # g_lr = self.g_lr
d_lr = self.d_lr # d_lr = self.d_lr
#
# Start training from scratch or resume training. # # Start training from scratch or resume training.
start_iters = 0 # start_iters = 0
if self.resume_iters: # if self.resume_iters:
start_iters = self.resume_iters # start_iters = self.resume_iters
self.restore_model(self.resume_iters) # self.restore_model(self.resume_iters)
#
# Start training. # # Start training.
print('Start training...') # print('Start training...')
start_time = time.time() # start_time = time.time()
for i in range(start_iters, self.num_iters): # for i in range(start_iters, self.num_iters):
for dataset in ['CelebA', 'RaFD']: # for dataset in ['CelebA', 'RaFD']:
#
# =================================================================================== # # # =================================================================================== #
# 1. Preprocess input data # # # 1. Preprocess input data #
# =================================================================================== # # # =================================================================================== #
#
# Fetch real images and labels. # # Fetch real images and labels.
data_iter = celeba_iter if dataset == 'CelebA' else rafd_iter # data_iter = celeba_iter if dataset == 'CelebA' else rafd_iter
#
try: # try:
x_real, label_org = next(data_iter) # x_real, label_org = next(data_iter)
except: # except:
if dataset == 'CelebA': # if dataset == 'CelebA':
celeba_iter = iter(self.celeba_loader) # celeba_iter = iter(self.celeba_loader)
x_real, label_org = next(celeba_iter) # x_real, label_org = next(celeba_iter)
elif dataset == 'RaFD': # elif dataset == 'RaFD':
rafd_iter = iter(self.rafd_loader) # rafd_iter = iter(self.rafd_loader)
x_real, label_org = next(rafd_iter) # x_real, label_org = next(rafd_iter)
#
# Generate target domain labels randomly. # # Generate target domain labels randomly.
rand_idx = torch.randperm(label_org.size(0)) # rand_idx = torch.randperm(label_org.size(0))
label_trg = label_org[rand_idx] # label_trg = label_org[rand_idx]
#
if dataset == 'CelebA': # if dataset == 'CelebA':
c_org = label_org.clone() # c_org = label_org.clone()
c_trg = label_trg.clone() # c_trg = label_trg.clone()
zero = torch.zeros(x_real.size(0), self.c2_dim) # zero = torch.zeros(x_real.size(0), self.c2_dim)
mask = self.label2onehot(torch.zeros(x_real.size(0)), 2) # mask = self.label2onehot(torch.zeros(x_real.size(0)), 2)
c_org = torch.cat([c_org, zero, mask], dim=1) # c_org = torch.cat([c_org, zero, mask], dim=1)
c_trg = torch.cat([c_trg, zero, mask], dim=1) # c_trg = torch.cat([c_trg, zero, mask], dim=1)
elif dataset == 'RaFD': # elif dataset == 'RaFD':
c_org = self.label2onehot(label_org, self.c2_dim) # c_org = self.label2onehot(label_org, self.c2_dim)
c_trg = self.label2onehot(label_trg, self.c2_dim) # c_trg = self.label2onehot(label_trg, self.c2_dim)
zero = torch.zeros(x_real.size(0), self.c_dim) # zero = torch.zeros(x_real.size(0), self.c_dim)
mask = self.label2onehot(torch.ones(x_real.size(0)), 2) # mask = self.label2onehot(torch.ones(x_real.size(0)), 2)
c_org = torch.cat([zero, c_org, mask], dim=1) # c_org = torch.cat([zero, c_org, mask], dim=1)
c_trg = torch.cat([zero, c_trg, mask], dim=1) # c_trg = torch.cat([zero, c_trg, mask], dim=1)
#
x_real = x_real.to(self.device) # Input images. # x_real = x_real.to(self.device) # Input images.
c_org = c_org.to(self.device) # Original domain labels. # c_org = c_org.to(self.device) # Original domain labels.
c_trg = c_trg.to(self.device) # Target domain labels. # c_trg = c_trg.to(self.device) # Target domain labels.
label_org = label_org.to(self.device) # Labels for computing classification loss. # label_org = label_org.to(self.device) # Labels for computing classification loss.
label_trg = label_trg.to(self.device) # Labels for computing classification loss. # label_trg = label_trg.to(self.device) # Labels for computing classification loss.
#
# =================================================================================== # # # =================================================================================== #
# 2. Train the discriminator # # # 2. Train the discriminator #
# =================================================================================== # # # =================================================================================== #
#
# Compute loss with real images. # # Compute loss with real images.
out_src, out_cls = self.D(x_real) # out_src, out_cls = self.D(x_real)
out_cls = out_cls[:, :self.c_dim] if dataset == 'CelebA' else out_cls[:, self.c_dim:] # out_cls = out_cls[:, :self.c_dim] if dataset == 'CelebA' else out_cls[:, self.c_dim:]
d_loss_real = - torch.mean(out_src) # d_loss_real = - torch.mean(out_src)
d_loss_cls = self.classification_loss(out_cls, label_org, dataset) # d_loss_cls = self.classification_loss(out_cls, label_org, dataset)
#
# Compute loss with fake images. # # Compute loss with fake images.
x_fake = self.G(x_real, c_trg) # x_fake = self.G(x_real, c_trg)
out_src, _ = self.D(x_fake.detach()) # out_src, _ = self.D(x_fake.detach())
d_loss_fake = torch.mean(out_src) # d_loss_fake = torch.mean(out_src)
#
# Compute loss for gradient penalty. # # Compute loss for gradient penalty.
alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device) # alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True) # x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True)
out_src, _ = self.D(x_hat) # out_src, _ = self.D(x_hat)
d_loss_gp = self.gradient_penalty(out_src, x_hat) # d_loss_gp = self.gradient_penalty(out_src, x_hat)
#
# Backward and optimize. # # Backward and optimize.
d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls + self.lambda_gp * d_loss_gp # d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls + self.lambda_gp * d_loss_gp
self.reset_grad() # self.reset_grad()
d_loss.backward() # d_loss.backward()
self.d_optimizer.step() # self.d_optimizer.step()
#
# Logging. # # Logging.
loss = {} # loss = {}
loss['D/loss_real'] = d_loss_real.item() # loss['D/loss_real'] = d_loss_real.item()
loss['D/loss_fake'] = d_loss_fake.item() # loss['D/loss_fake'] = d_loss_fake.item()
loss['D/loss_cls'] = d_loss_cls.item() # loss['D/loss_cls'] = d_loss_cls.item()
loss['D/loss_gp'] = d_loss_gp.item() # loss['D/loss_gp'] = d_loss_gp.item()
#
# =================================================================================== # # # =================================================================================== #
# 3. Train the generator # # # 3. Train the generator #
# =================================================================================== # # # =================================================================================== #
#
if (i+1) % self.n_critic == 0: # if (i+1) % self.n_critic == 0:
# Original-to-target domain. # # Original-to-target domain.
x_fake = self.G(x_real, c_trg) # x_fake = self.G(x_real, c_trg)
out_src, out_cls = self.D(x_fake) # out_src, out_cls = self.D(x_fake)
out_cls = out_cls[:, :self.c_dim] if dataset == 'CelebA' else out_cls[:, self.c_dim:] # out_cls = out_cls[:, :self.c_dim] if dataset == 'CelebA' else out_cls[:, self.c_dim:]
g_loss_fake = - torch.mean(out_src) # g_loss_fake = - torch.mean(out_src)
g_loss_cls = self.classification_loss(out_cls, label_trg, dataset) # g_loss_cls = self.classification_loss(out_cls, label_trg, dataset)
#
# Target-to-original domain. # # Target-to-original domain.
x_reconst = self.G(x_fake, c_org) # x_reconst = self.G(x_fake, c_org)
g_loss_rec = torch.mean(torch.abs(x_real - x_reconst)) # g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))
#
# Backward and optimize. # # Backward and optimize.
g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls # g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
self.reset_grad() # self.reset_grad()
g_loss.backward() # g_loss.backward()
self.g_optimizer.step() # self.g_optimizer.step()
#
# Logging. # # Logging.
loss['G/loss_fake'] = g_loss_fake.item() # loss['G/loss_fake'] = g_loss_fake.item()
loss['G/loss_rec'] = g_loss_rec.item() # loss['G/loss_rec'] = g_loss_rec.item()
loss['G/loss_cls'] = g_loss_cls.item() # loss['G/loss_cls'] = g_loss_cls.item()
#
# =================================================================================== # # # =================================================================================== #
# 4. Miscellaneous # # # 4. Miscellaneous #
# =================================================================================== # # # =================================================================================== #
#
# Print out training info. # # Print out training info.
if (i+1) % self.log_step == 0: # if (i+1) % self.log_step == 0:
et = time.time() - start_time # et = time.time() - start_time
et = str(datetime.timedelta(seconds=et))[:-7] # et = str(datetime.timedelta(seconds=et))[:-7]
log = "Elapsed [{}], Iteration [{}/{}], Dataset [{}]".format(et, i+1, self.num_iters, dataset) # log = "Elapsed [{}], Iteration [{}/{}], Dataset [{}]".format(et, i+1, self.num_iters, dataset)
for tag, value in loss.items(): # for tag, value in loss.items():
log += ", {}: {:.4f}".format(tag, value) # log += ", {}: {:.4f}".format(tag, value)
print(log) # print(log)
#
if self.use_tensorboard: # if self.use_tensorboard:
for tag, value in loss.items(): # for tag, value in loss.items():
self.logger.scalar_summary(tag, value, i+1) # self.logger.scalar_summary(tag, value, i+1)
#
# Translate fixed images for debugging. # # Translate fixed images for debugging.
if (i+1) % self.sample_step == 0: # if (i+1) % self.sample_step == 0:
with torch.no_grad(): # with torch.no_grad():
x_fake_list = [x_fixed] # x_fake_list = [x_fixed]
for c_fixed in c_celeba_list: # for c_fixed in c_celeba_list:
c_trg = torch.cat([c_fixed, zero_rafd, mask_celeba], dim=1) # c_trg = torch.cat([c_fixed, zero_rafd, mask_celeba], dim=1)
x_fake_list.append(self.G(x_fixed, c_trg)) # x_fake_list.append(self.G(x_fixed, c_trg))
for c_fixed in c_rafd_list: # for c_fixed in c_rafd_list:
c_trg = torch.cat([zero_celeba, c_fixed, mask_rafd], dim=1) # c_trg = torch.cat([zero_celeba, c_fixed, mask_rafd], dim=1)
x_fake_list.append(self.G(x_fixed, c_trg)) # x_fake_list.append(self.G(x_fixed, c_trg))
x_concat = torch.cat(x_fake_list, dim=3) # x_concat = torch.cat(x_fake_list, dim=3)
sample_path = os.path.join(self.sample_dir, '{}-images.jpg'.format(i+1)) # sample_path = os.path.join(self.sample_dir, '{}-images.jpg'.format(i+1))
save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0) # save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)
print('Saved real and fake images into {}...'.format(sample_path)) # print('Saved real and fake images into {}...'.format(sample_path))
#
# Save model checkpoints. # # Save model checkpoints.
if (i+1) % self.model_save_step == 0: # if (i+1) % self.model_save_step == 0:
G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i+1)) # G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i+1))
D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i+1)) # D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i+1))
torch.save(self.G.state_dict(), G_path) # torch.save(self.G.state_dict(), G_path)
torch.save(self.D.state_dict(), D_path) # torch.save(self.D.state_dict(), D_path)
print('Saved model checkpoints into {}...'.format(self.model_save_dir)) # print('Saved model checkpoints into {}...'.format(self.model_save_dir))
#
# Decay learning rates. # # Decay learning rates.
if (i+1) % self.lr_update_step == 0 and (i+1) > (self.num_iters - self.num_iters_decay): # if (i+1) % self.lr_update_step == 0 and (i+1) > (self.num_iters - self.num_iters_decay):
g_lr -= (self.g_lr / float(self.num_iters_decay)) # g_lr -= (self.g_lr / float(self.num_iters_decay))
d_lr -= (self.d_lr / float(self.num_iters_decay)) # d_lr -= (self.d_lr / float(self.num_iters_decay))
self.update_lr(g_lr, d_lr) # self.update_lr(g_lr, d_lr)
print ('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr)) # print ('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))
def test(self): def test(self):
"""Translate images using StarGAN trained on a single dataset.""" """Translate images using StarGAN trained on a single dataset."""
# Load the trained generator. # Load the trained generator.
self.restore_model(self.test_iters) self.restore_model(self.test_iters)
# Set data loader. # Set data loader.
if self.dataset == 'CelebA': if self.dataset == 'CelebA':
data_loader = self.celeba_loader data_loader = self.celeba_loader
elif self.dataset == 'RaFD': elif self.dataset == 'RaFD':
data_loader = self.rafd_loader data_loader = self.rafd_loader
elif self.dataset == 'AffectNet':
data_loader = self.affectnet_loader
with torch.no_grad(): with torch.no_grad():
for i, (x_real, c_org) in enumerate(data_loader): for i, (x_real, c_org) in enumerate(data_loader):
@@ -549,34 +572,34 @@ class Solver(object):
save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0) save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
print('Saved real and fake images into {}...'.format(result_path)) print('Saved real and fake images into {}...'.format(result_path))
def test_multi(self): # def test_multi(self):
"""Translate images using StarGAN trained on multiple datasets.""" # """Translate images using StarGAN trained on multiple datasets."""
# Load the trained generator. # # Load the trained generator.
self.restore_model(self.test_iters) # self.restore_model(self.test_iters)
#
with torch.no_grad(): # with torch.no_grad():
for i, (x_real, c_org) in enumerate(self.celeba_loader): # for i, (x_real, c_org) in enumerate(self.celeba_loader):
#
# Prepare input images and target domain labels. # # Prepare input images and target domain labels.
x_real = x_real.to(self.device) # x_real = x_real.to(self.device)
c_celeba_list = self.create_labels(c_org, self.c_dim, 'CelebA', self.selected_attrs) # c_celeba_list = self.create_labels(c_org, self.c_dim, 'CelebA', self.selected_attrs)
c_rafd_list = self.create_labels(c_org, self.c2_dim, 'RaFD') # c_rafd_list = self.create_labels(c_org, self.c2_dim, 'RaFD')
zero_celeba = torch.zeros(x_real.size(0), self.c_dim).to(self.device) # Zero vector for CelebA. # zero_celeba = torch.zeros(x_real.size(0), self.c_dim).to(self.device) # Zero vector for CelebA.
zero_rafd = torch.zeros(x_real.size(0), self.c2_dim).to(self.device) # Zero vector for RaFD. # zero_rafd = torch.zeros(x_real.size(0), self.c2_dim).to(self.device) # Zero vector for RaFD.
mask_celeba = self.label2onehot(torch.zeros(x_real.size(0)), 2).to(self.device) # Mask vector: [1, 0]. # mask_celeba = self.label2onehot(torch.zeros(x_real.size(0)), 2).to(self.device) # Mask vector: [1, 0].
mask_rafd = self.label2onehot(torch.ones(x_real.size(0)), 2).to(self.device) # Mask vector: [0, 1]. # mask_rafd = self.label2onehot(torch.ones(x_real.size(0)), 2).to(self.device) # Mask vector: [0, 1].
#
# Translate images. # # Translate images.
x_fake_list = [x_real] # x_fake_list = [x_real]
for c_celeba in c_celeba_list: # for c_celeba in c_celeba_list:
c_trg = torch.cat([c_celeba, zero_rafd, mask_celeba], dim=1) # c_trg = torch.cat([c_celeba, zero_rafd, mask_celeba], dim=1)
x_fake_list.append(self.G(x_real, c_trg)) # x_fake_list.append(self.G(x_real, c_trg))
for c_rafd in c_rafd_list: # for c_rafd in c_rafd_list:
c_trg = torch.cat([zero_celeba, c_rafd, mask_rafd], dim=1) # c_trg = torch.cat([zero_celeba, c_rafd, mask_rafd], dim=1)
x_fake_list.append(self.G(x_real, c_trg)) # x_fake_list.append(self.G(x_real, c_trg))
#
# Save the translated images. # # Save the translated images.
x_concat = torch.cat(x_fake_list, dim=3) # x_concat = torch.cat(x_fake_list, dim=3)
result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1)) # result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1))
save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0) # save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
print('Saved real and fake images into {}...'.format(result_path)) # print('Saved real and fake images into {}...'.format(result_path))