Added AffectNet dataset w/ categorical and valence-arousal emotion descriptors
Esse commit está contido em:
+64
-1
@@ -68,8 +68,69 @@ class CelebA(data.Dataset):
|
||||
return self.num_images
|
||||
|
||||
|
||||
class AffectNet(data.Dataset):
    """Dataset class for the AffectNet dataset.

    Label files are read from ``image_dir`` itself; images are opened from
    its ``train`` sub-directory (mode ``'train'``) or its ``validation``
    sub-directory (any other mode).

    Each non-blank line of a label file is whitespace-separated:
      * ``emotiw`` descriptor: ``<filename> <int label>``
      * any other descriptor (valence-arousal): ``<filename> <valence> <arousal>``
    """

    # (train split, test split) label-file names for each emotion descriptor.
    _EMOTIW_FILES = ('_affectnet_train_emotiw.txt', '_affectnet_test_emotiw.txt')
    _VA_FILES = ('_affectnet_train_va.txt', '_affectnet_test_va.txt')

    def __init__(self, image_dir, affectnet_emo_descr, transform, mode):
        """Initialize and preprocess the AffectNet dataset.

        Args:
            image_dir: Root directory holding the label files and the
                ``train`` / ``validation`` image folders.
            affectnet_emo_descr: ``'emotiw'`` for categorical labels; any
                other value selects the valence-arousal label files.
            transform: Callable applied to each opened image in
                ``__getitem__``.
            mode: ``'train'`` to serve the training split; any other value
                serves the test split.
        """
        self.image_dir = image_dir
        self.affectnet_emo_descr = affectnet_emo_descr
        self.transform = transform
        self.mode = mode
        self.train_dataset = []
        self.test_dataset = []
        self.preprocess()

        active = self.train_dataset if mode == 'train' else self.test_dataset
        self.num_images = len(active)

    @staticmethod
    def _parse_emotiw(fields):
        """Turn ``[filename, label]`` into ``[filename, int(label)]``."""
        filename, label = fields
        return [filename, int(label)]

    @staticmethod
    def _parse_va(fields):
        """Turn ``[filename, v, a]`` into ``[filename, [float(v), float(a)]]``."""
        filename, valence, arousal = fields
        return [filename, [float(valence), float(arousal)]]

    def _read_label_file(self, name, parse):
        """Parse one label file located in ``image_dir`` into a list of entries."""
        entries = []
        # The context manager guarantees the handle is closed after parsing.
        with open(os.path.join(self.image_dir, name), 'r') as label_file:
            for line in label_file:
                fields = line.split()
                if not fields:
                    # Blank (e.g. trailing) lines carry no sample; skip them
                    # instead of failing on tuple unpacking.
                    continue
                entries.append(parse(fields))
        return entries

    def preprocess(self):
        """Preprocess the AffectNet emotional description file.

        Both the train and the test label files are always read, so both
        ``train_dataset`` and ``test_dataset`` are populated regardless of
        ``mode``.

        Raises:
            OSError: If a label file cannot be opened.
            ValueError: If a non-blank line has the wrong number of fields
                or a label that cannot be converted to a number.
        """
        if self.affectnet_emo_descr == 'emotiw':
            (train_name, test_name), parse = self._EMOTIW_FILES, self._parse_emotiw
        else:
            (train_name, test_name), parse = self._VA_FILES, self._parse_va

        self.train_dataset.extend(self._read_label_file(train_name, parse))
        self.test_dataset.extend(self._read_label_file(test_name, parse))

        print('Finished preprocessing the AffectNet dataset...')

    def __getitem__(self, index):
        """Return one transformed image and its label as a tensor."""
        if self.mode == 'train':
            dataset, subdir = self.train_dataset, 'train'
        else:
            dataset, subdir = self.test_dataset, 'validation'
        filename, label = dataset[index]
        image = Image.open(os.path.join(self.image_dir, subdir, filename))
        return self.transform(image), torch.tensor(label)

    def __len__(self):
        """Return the number of images."""
        return self.num_images
|
||||
|
||||
|
||||
def get_loader(image_dir, attr_path, selected_attrs, crop_size=178, image_size=128,
|
||||
batch_size=16, dataset='CelebA', mode='train', num_workers=1):
|
||||
batch_size=16, dataset='CelebA', mode='train', affectnet_emo_descr='emotiw', num_workers=1):
|
||||
"""Build and return a data loader."""
|
||||
transform = []
|
||||
if mode == 'train':
|
||||
@@ -84,6 +145,8 @@ def get_loader(image_dir, attr_path, selected_attrs, crop_size=178, image_size=1
|
||||
dataset = CelebA(image_dir, attr_path, selected_attrs, transform, mode)
|
||||
elif dataset == 'RaFD':
|
||||
dataset = ImageFolder(image_dir, transform)
|
||||
elif dataset == 'AffectNet':
|
||||
dataset = AffectNet(image_dir, affectnet_emo_descr, transform, mode)
|
||||
|
||||
data_loader = data.DataLoader(dataset=dataset,
|
||||
batch_size=batch_size,
|
||||
|
||||
+13
-7
@@ -1,8 +1,14 @@
|
||||
# Download the AffectNet images and labels.
# Usage: bash <this script> IMAGES_ZIP_URL LABELS_ZIP_URL
#
# Stop at the first failing command or unset variable so a failed download is
# never followed by unzip/rm on a missing archive.
set -eu

if [ "$#" -lt 2 ]; then
    echo "usage: $0 IMAGES_ZIP_URL LABELS_ZIP_URL" >&2
    exit 1
fi

URL1=$1
URL2=$2

# Download affectnet and labels
mkdir -p ./data/affectnet

# Images archive. (-N is omitted: timestamping has no effect together with -O.)
ZIP_FILE1=./data/affectnet.zip
wget "$URL1" -O "$ZIP_FILE1"
unzip "$ZIP_FILE1" -d ./data/affectnet
rm "$ZIP_FILE1"

# Labels archive, extracted into the same directory as the images.
ZIP_FILE2=./data/affectnet_labels.zip
wget "$URL2" -O "$ZIP_FILE2"
unzip "$ZIP_FILE2" -d ./data/affectnet
rm "$ZIP_FILE2"
|
||||
|
||||
+18
-13
@@ -25,30 +25,33 @@ def main(config):
|
||||
# Data loader.
|
||||
celeba_loader = None
|
||||
rafd_loader = None
|
||||
affectnet_loader = None
|
||||
|
||||
if config.dataset in ['CelebA', 'Both']:
|
||||
celeba_loader = get_loader(config.celeba_image_dir, config.attr_path, config.selected_attrs,
|
||||
config.celeba_crop_size, config.image_size, config.batch_size,
|
||||
'CelebA', config.mode, config.num_workers)
|
||||
'CelebA', config.mode, None, config.num_workers)
|
||||
if config.dataset in ['RaFD', 'Both']:
|
||||
rafd_loader = get_loader(config.rafd_image_dir, None, None,
|
||||
config.rafd_crop_size, config.image_size, config.batch_size,
|
||||
'RaFD', config.mode, config.num_workers)
|
||||
|
||||
|
||||
'RaFD', config.mode, None, config.num_workers)
|
||||
if config.dataset in ['AffectNet']:
|
||||
affectnet_loader = get_loader(config.affectnet_image_dir, None, None,
|
||||
config.affectnet_crop_size, config.image_size, config.batch_size,
|
||||
'AffectNet', config.mode, config.affectnet_emo_descr, config.num_workers)
|
||||
# Solver for training and testing StarGAN.
|
||||
solver = Solver(celeba_loader, rafd_loader, config)
|
||||
solver = Solver(celeba_loader, rafd_loader, affectnet_loader, config)
|
||||
|
||||
if config.mode == 'train':
|
||||
if config.dataset in ['CelebA', 'RaFD']:
|
||||
if config.dataset in ['CelebA', 'RaFD', 'AffectNet']:
|
||||
solver.train()
|
||||
elif config.dataset in ['Both']:
|
||||
solver.train_multi()
|
||||
# elif config.dataset in ['Both']:
|
||||
# solver.train_multi()
|
||||
elif config.mode == 'test':
|
||||
if config.dataset in ['CelebA', 'RaFD']:
|
||||
if config.dataset in ['CelebA', 'RaFD', 'AffectNet']:
|
||||
solver.test()
|
||||
elif config.dataset in ['Both']:
|
||||
solver.test_multi()
|
||||
# elif config.dataset in ['Both']:
|
||||
# solver.test_multi()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
@@ -59,6 +62,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--c2_dim', type=int, default=8, help='dimension of domain labels (2nd dataset)')
|
||||
parser.add_argument('--celeba_crop_size', type=int, default=178, help='crop size for the CelebA dataset')
|
||||
parser.add_argument('--rafd_crop_size', type=int, default=256, help='crop size for the RaFD dataset')
|
||||
parser.add_argument('--affectnet_crop_size', type=int, default=112, help='crop size for the AffectNet dataset')
|
||||
parser.add_argument('--image_size', type=int, default=128, help='image resolution')
|
||||
parser.add_argument('--g_conv_dim', type=int, default=64, help='number of conv filters in the first layer of G')
|
||||
parser.add_argument('--d_conv_dim', type=int, default=64, help='number of conv filters in the first layer of D')
|
||||
@@ -69,7 +73,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--lambda_gp', type=float, default=10, help='weight for gradient penalty')
|
||||
|
||||
# Training configuration.
|
||||
parser.add_argument('--dataset', type=str, default='CelebA', choices=['CelebA', 'RaFD', 'Both'])
|
||||
parser.add_argument('--dataset', type=str, default='CelebA', choices=['CelebA', 'RaFD', 'AffectNet', 'Both'])
|
||||
parser.add_argument('--batch_size', type=int, default=16, help='mini-batch size')
|
||||
parser.add_argument('--num_iters', type=int, default=200000, help='number of total iterations for training D')
|
||||
parser.add_argument('--num_iters_decay', type=int, default=100000, help='number of iterations for decaying lr')
|
||||
@@ -81,7 +85,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--resume_iters', type=int, default=None, help='resume training from this step')
|
||||
parser.add_argument('--selected_attrs', '--list', nargs='+', help='selected attributes for the CelebA dataset',
|
||||
default=['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young'])
|
||||
|
||||
parser.add_argument('--affectnet_emo_descr', type=str, default='emotiw', help='emotiw for categorical emotions, va for valence-arousal')
|
||||
# Test configuration.
|
||||
parser.add_argument('--test_iters', type=int, default=200000, help='test model from this step')
|
||||
|
||||
@@ -94,6 +98,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--celeba_image_dir', type=str, default='data/celeba/images')
|
||||
parser.add_argument('--attr_path', type=str, default='data/celeba/list_attr_celeba.txt')
|
||||
parser.add_argument('--rafd_image_dir', type=str, default='data/RaFD/train')
|
||||
parser.add_argument('--affectnet_image_dir', type=str, default='data/affectnet')
|
||||
parser.add_argument('--log_dir', type=str, default='stargan/logs')
|
||||
parser.add_argument('--model_save_dir', type=str, default='stargan/models')
|
||||
parser.add_argument('--sample_dir', type=str, default='stargan/samples')
|
||||
|
||||
+254
-231
@@ -13,13 +13,13 @@ import datetime
|
||||
class Solver(object):
|
||||
"""Solver for training and testing StarGAN."""
|
||||
|
||||
def __init__(self, celeba_loader, rafd_loader, config):
|
||||
def __init__(self, celeba_loader, rafd_loader, affectnet_loader, config):
|
||||
"""Initialize configurations."""
|
||||
|
||||
# Data loader.
|
||||
self.celeba_loader = celeba_loader
|
||||
self.rafd_loader = rafd_loader
|
||||
|
||||
self.affectnet_loader = affectnet_loader
|
||||
# Model configurations.
|
||||
self.c_dim = config.c_dim
|
||||
self.c2_dim = config.c2_dim
|
||||
@@ -44,7 +44,7 @@ class Solver(object):
|
||||
self.beta2 = config.beta2
|
||||
self.resume_iters = config.resume_iters
|
||||
self.selected_attrs = config.selected_attrs
|
||||
|
||||
self.affectnet_emo_descr = config.affectnet_emo_descr
|
||||
# Test configurations.
|
||||
self.test_iters = config.test_iters
|
||||
|
||||
@@ -71,7 +71,7 @@ class Solver(object):
|
||||
|
||||
def build_model(self):
|
||||
"""Create a generator and a discriminator."""
|
||||
if self.dataset in ['CelebA', 'RaFD']:
|
||||
if self.dataset in ['CelebA', 'RaFD', 'AffectNet']:
|
||||
self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num)
|
||||
self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num)
|
||||
elif self.dataset in ['Both']:
|
||||
@@ -156,29 +156,41 @@ class Solver(object):
|
||||
hair_color_indices.append(i)
|
||||
|
||||
c_trg_list = []
|
||||
for i in range(c_dim):
|
||||
if dataset == 'CelebA':
|
||||
c_trg = c_org.clone()
|
||||
if i in hair_color_indices: # Set one hair color to 1 and the rest to 0.
|
||||
c_trg[:, i] = 1
|
||||
for j in hair_color_indices:
|
||||
if j != i:
|
||||
c_trg[:, j] = 0
|
||||
else:
|
||||
c_trg[:, i] = (c_trg[:, i] == 0) # Reverse attribute value.
|
||||
elif dataset == 'RaFD':
|
||||
c_trg = self.label2onehot(torch.ones(c_org.size(0))*i, c_dim)
|
||||
if dataset == 'AffectNet' and self.affectnet_emo_descr == 'va':
|
||||
for v in [-1., -0.5, 0., 0.5, 1.]:
|
||||
for a in [-1., -0.5, 0., 0.5, 1.]:
|
||||
c_trg = torch.Tensor([v, a]).repeat(c_org.size(0), 1)
|
||||
c_trg_list.append(c_trg.to(self.device))
|
||||
else:
|
||||
for i in range(c_dim):
|
||||
if dataset == 'CelebA':
|
||||
c_trg = c_org.clone()
|
||||
if i in hair_color_indices: # Set one hair color to 1 and the rest to 0.
|
||||
c_trg[:, i] = 1
|
||||
for j in hair_color_indices:
|
||||
if j != i:
|
||||
c_trg[:, j] = 0
|
||||
else:
|
||||
c_trg[:, i] = (c_trg[:, i] == 0) # Reverse attribute value.
|
||||
elif dataset == 'RaFD':
|
||||
c_trg = self.label2onehot(torch.ones(c_org.size(0))*i, c_dim)
|
||||
elif dataset == 'AffectNet' and self.affectnet_emo_descr == 'emotiw':
|
||||
c_trg = self.label2onehot(torch.ones(c_org.size(0))*i, c_dim)
|
||||
c_trg_list.append(c_trg.to(self.device))
|
||||
|
||||
c_trg_list.append(c_trg.to(self.device))
|
||||
return c_trg_list
|
||||
|
||||
def classification_loss(self, logit, target, dataset='CelebA'):
    """Compute the domain-classification loss for the given dataset.

    Args:
        logit: Raw classifier outputs.
        target: Ground-truth labels in the form the selected loss expects.
        dataset: ``'CelebA'``, ``'RaFD'`` or ``'AffectNet'``.

    Returns:
        A scalar loss tensor:
          * CelebA: summed binary cross entropy divided by ``logit.size(0)``.
          * RaFD: softmax cross entropy.
          * AffectNet: softmax cross entropy when
            ``self.affectnet_emo_descr == 'emotiw'``; otherwise mean squared
            error between ``tanh(logit)`` and ``target``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names
            (previously this silently returned ``None``).
    """
    if dataset == 'CelebA':
        # `reduction='sum'` is the replacement for the deprecated
        # `size_average=False`; only the modern spelling is kept.
        return F.binary_cross_entropy_with_logits(logit, target, reduction='sum') / logit.size(0)
    if dataset == 'RaFD':
        return F.cross_entropy(logit, target)
    if dataset == 'AffectNet':
        if self.affectnet_emo_descr == 'emotiw':
            return F.cross_entropy(logit, target)
        # tanh squashes the prediction into (-1, 1) before the regression loss.
        return F.mse_loss(torch.tanh(logit), target)
    raise ValueError('Unsupported dataset for classification loss: {!r}'.format(dataset))
|
||||
def train(self):
|
||||
"""Train StarGAN within a single dataset."""
|
||||
# Set data loader.
|
||||
@@ -186,6 +198,8 @@ class Solver(object):
|
||||
data_loader = self.celeba_loader
|
||||
elif self.dataset == 'RaFD':
|
||||
data_loader = self.rafd_loader
|
||||
elif self.dataset == 'AffectNet':
|
||||
data_loader = self.affectnet_loader
|
||||
|
||||
# Fetch fixed inputs for debugging.
|
||||
data_iter = iter(data_loader)
|
||||
@@ -229,6 +243,13 @@ class Solver(object):
|
||||
elif self.dataset == 'RaFD':
|
||||
c_org = self.label2onehot(label_org, self.c_dim)
|
||||
c_trg = self.label2onehot(label_trg, self.c_dim)
|
||||
elif self.dataset == 'AffectNet':
|
||||
if self.affectnet_emo_descr == 'emotiw':
|
||||
c_org = self.label2onehot(label_org, self.c_dim)
|
||||
c_trg = self.label2onehot(label_trg, self.c_dim)
|
||||
else:
|
||||
c_org = label_org.clone()
|
||||
c_trg = label_trg.clone()
|
||||
|
||||
x_real = x_real.to(self.device) # Input images.
|
||||
c_org = c_org.to(self.device) # Original domain labels.
|
||||
@@ -338,187 +359,187 @@ class Solver(object):
|
||||
self.update_lr(g_lr, d_lr)
|
||||
print ('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))
|
||||
|
||||
def train_multi(self):
    """Train StarGAN with multiple datasets.

    Every iteration processes one CelebA batch and then one RaFD batch.
    The conditioning vector fed to the generator is the concatenation
    ``[CelebA labels | RaFD one-hot labels | 2-dim dataset mask]``; the
    part belonging to the dataset that is not in use is zero-filled.
    Sampling, checkpointing and learning-rate decay run once per iteration,
    after both datasets have been processed.
    """
    # Data iterators.
    celeba_iter = iter(self.celeba_loader)
    rafd_iter = iter(self.rafd_loader)

    # Fetch fixed inputs for debugging.
    x_fixed, c_org = next(celeba_iter)
    x_fixed = x_fixed.to(self.device)
    c_celeba_list = self.create_labels(c_org, self.c_dim, 'CelebA', self.selected_attrs)
    c_rafd_list = self.create_labels(c_org, self.c2_dim, 'RaFD')
    zero_celeba = torch.zeros(x_fixed.size(0), self.c_dim).to(self.device)            # Zero vector for CelebA.
    zero_rafd = torch.zeros(x_fixed.size(0), self.c2_dim).to(self.device)             # Zero vector for RaFD.
    mask_celeba = self.label2onehot(torch.zeros(x_fixed.size(0)), 2).to(self.device)  # Mask vector: [1, 0].
    mask_rafd = self.label2onehot(torch.ones(x_fixed.size(0)), 2).to(self.device)     # Mask vector: [0, 1].

    # Learning rate cache for decaying.
    g_lr = self.g_lr
    d_lr = self.d_lr

    # Start training from scratch or resume training.
    start_iters = 0
    if self.resume_iters:
        start_iters = self.resume_iters
        self.restore_model(self.resume_iters)

    # Start training.
    print('Start training...')
    start_time = time.time()
    for i in range(start_iters, self.num_iters):
        for dataset in ['CelebA', 'RaFD']:

            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #

            # Fetch real images and labels.
            data_iter = celeba_iter if dataset == 'CelebA' else rafd_iter

            try:
                x_real, label_org = next(data_iter)
            except StopIteration:
                # Only loader exhaustion restarts the iterator; any other
                # error (including KeyboardInterrupt) propagates to the caller.
                if dataset == 'CelebA':
                    celeba_iter = iter(self.celeba_loader)
                    x_real, label_org = next(celeba_iter)
                elif dataset == 'RaFD':
                    rafd_iter = iter(self.rafd_loader)
                    x_real, label_org = next(rafd_iter)

            # Generate target domain labels randomly.
            rand_idx = torch.randperm(label_org.size(0))
            label_trg = label_org[rand_idx]

            if dataset == 'CelebA':
                c_org = label_org.clone()
                c_trg = label_trg.clone()
                zero = torch.zeros(x_real.size(0), self.c2_dim)
                mask = self.label2onehot(torch.zeros(x_real.size(0)), 2)
                c_org = torch.cat([c_org, zero, mask], dim=1)
                c_trg = torch.cat([c_trg, zero, mask], dim=1)
            elif dataset == 'RaFD':
                c_org = self.label2onehot(label_org, self.c2_dim)
                c_trg = self.label2onehot(label_trg, self.c2_dim)
                zero = torch.zeros(x_real.size(0), self.c_dim)
                mask = self.label2onehot(torch.ones(x_real.size(0)), 2)
                c_org = torch.cat([zero, c_org, mask], dim=1)
                c_trg = torch.cat([zero, c_trg, mask], dim=1)

            x_real = x_real.to(self.device)             # Input images.
            c_org = c_org.to(self.device)               # Original domain labels.
            c_trg = c_trg.to(self.device)               # Target domain labels.
            label_org = label_org.to(self.device)       # Labels for computing classification loss.
            label_trg = label_trg.to(self.device)       # Labels for computing classification loss.

            # =================================================================================== #
            #                             2. Train the discriminator                              #
            # =================================================================================== #

            # Compute loss with real images.
            out_src, out_cls = self.D(x_real)
            # Keep only the classifier outputs that belong to the current dataset.
            out_cls = out_cls[:, :self.c_dim] if dataset == 'CelebA' else out_cls[:, self.c_dim:]
            d_loss_real = - torch.mean(out_src)
            d_loss_cls = self.classification_loss(out_cls, label_org, dataset)

            # Compute loss with fake images.
            x_fake = self.G(x_real, c_trg)
            out_src, _ = self.D(x_fake.detach())
            d_loss_fake = torch.mean(out_src)

            # Compute loss for gradient penalty.
            alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
            x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True)
            out_src, _ = self.D(x_hat)
            d_loss_gp = self.gradient_penalty(out_src, x_hat)

            # Backward and optimize.
            d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls + self.lambda_gp * d_loss_gp
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Logging.
            loss = {}
            loss['D/loss_real'] = d_loss_real.item()
            loss['D/loss_fake'] = d_loss_fake.item()
            loss['D/loss_cls'] = d_loss_cls.item()
            loss['D/loss_gp'] = d_loss_gp.item()

            # =================================================================================== #
            #                               3. Train the generator                                #
            # =================================================================================== #

            # The generator is updated once every `n_critic` discriminator updates.
            if (i+1) % self.n_critic == 0:
                # Original-to-target domain.
                x_fake = self.G(x_real, c_trg)
                out_src, out_cls = self.D(x_fake)
                out_cls = out_cls[:, :self.c_dim] if dataset == 'CelebA' else out_cls[:, self.c_dim:]
                g_loss_fake = - torch.mean(out_src)
                g_loss_cls = self.classification_loss(out_cls, label_trg, dataset)

                # Target-to-original domain.
                x_reconst = self.G(x_fake, c_org)
                g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))

                # Backward and optimize.
                g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Logging.
                loss['G/loss_fake'] = g_loss_fake.item()
                loss['G/loss_rec'] = g_loss_rec.item()
                loss['G/loss_cls'] = g_loss_cls.item()

            # =================================================================================== #
            #                                 4. Miscellaneous                                    #
            # =================================================================================== #

            # Print out training info.
            if (i+1) % self.log_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}], Dataset [{}]".format(et, i+1, self.num_iters, dataset)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, i+1)

        # Translate fixed images for debugging.
        if (i+1) % self.sample_step == 0:
            with torch.no_grad():
                x_fake_list = [x_fixed]
                for c_fixed in c_celeba_list:
                    c_trg = torch.cat([c_fixed, zero_rafd, mask_celeba], dim=1)
                    x_fake_list.append(self.G(x_fixed, c_trg))
                for c_fixed in c_rafd_list:
                    c_trg = torch.cat([zero_celeba, c_fixed, mask_rafd], dim=1)
                    x_fake_list.append(self.G(x_fixed, c_trg))
                x_concat = torch.cat(x_fake_list, dim=3)
                sample_path = os.path.join(self.sample_dir, '{}-images.jpg'.format(i+1))
                save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)
                print('Saved real and fake images into {}...'.format(sample_path))

        # Save model checkpoints.
        if (i+1) % self.model_save_step == 0:
            G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i+1))
            D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i+1))
            torch.save(self.G.state_dict(), G_path)
            torch.save(self.D.state_dict(), D_path)
            print('Saved model checkpoints into {}...'.format(self.model_save_dir))

        # Decay learning rates.
        if (i+1) % self.lr_update_step == 0 and (i+1) > (self.num_iters - self.num_iters_decay):
            g_lr -= (self.g_lr / float(self.num_iters_decay))
            d_lr -= (self.d_lr / float(self.num_iters_decay))
            self.update_lr(g_lr, d_lr)
            print ('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))
|
||||
# def train_multi(self):
|
||||
# """Train StarGAN with multiple datasets."""
|
||||
# # Data iterators.
|
||||
# celeba_iter = iter(self.celeba_loader)
|
||||
# rafd_iter = iter(self.rafd_loader)
|
||||
#
|
||||
# # Fetch fixed inputs for debugging.
|
||||
# x_fixed, c_org = next(celeba_iter)
|
||||
# x_fixed = x_fixed.to(self.device)
|
||||
# c_celeba_list = self.create_labels(c_org, self.c_dim, 'CelebA', self.selected_attrs)
|
||||
# c_rafd_list = self.create_labels(c_org, self.c2_dim, 'RaFD')
|
||||
# zero_celeba = torch.zeros(x_fixed.size(0), self.c_dim).to(self.device) # Zero vector for CelebA.
|
||||
# zero_rafd = torch.zeros(x_fixed.size(0), self.c2_dim).to(self.device) # Zero vector for RaFD.
|
||||
# mask_celeba = self.label2onehot(torch.zeros(x_fixed.size(0)), 2).to(self.device) # Mask vector: [1, 0].
|
||||
# mask_rafd = self.label2onehot(torch.ones(x_fixed.size(0)), 2).to(self.device) # Mask vector: [0, 1].
|
||||
#
|
||||
# # Learning rate cache for decaying.
|
||||
# g_lr = self.g_lr
|
||||
# d_lr = self.d_lr
|
||||
#
|
||||
# # Start training from scratch or resume training.
|
||||
# start_iters = 0
|
||||
# if self.resume_iters:
|
||||
# start_iters = self.resume_iters
|
||||
# self.restore_model(self.resume_iters)
|
||||
#
|
||||
# # Start training.
|
||||
# print('Start training...')
|
||||
# start_time = time.time()
|
||||
# for i in range(start_iters, self.num_iters):
|
||||
# for dataset in ['CelebA', 'RaFD']:
|
||||
#
|
||||
# # =================================================================================== #
|
||||
# # 1. Preprocess input data #
|
||||
# # =================================================================================== #
|
||||
#
|
||||
# # Fetch real images and labels.
|
||||
# data_iter = celeba_iter if dataset == 'CelebA' else rafd_iter
|
||||
#
|
||||
# try:
|
||||
# x_real, label_org = next(data_iter)
|
||||
# except:
|
||||
# if dataset == 'CelebA':
|
||||
# celeba_iter = iter(self.celeba_loader)
|
||||
# x_real, label_org = next(celeba_iter)
|
||||
# elif dataset == 'RaFD':
|
||||
# rafd_iter = iter(self.rafd_loader)
|
||||
# x_real, label_org = next(rafd_iter)
|
||||
#
|
||||
# # Generate target domain labels randomly.
|
||||
# rand_idx = torch.randperm(label_org.size(0))
|
||||
# label_trg = label_org[rand_idx]
|
||||
#
|
||||
# if dataset == 'CelebA':
|
||||
# c_org = label_org.clone()
|
||||
# c_trg = label_trg.clone()
|
||||
# zero = torch.zeros(x_real.size(0), self.c2_dim)
|
||||
# mask = self.label2onehot(torch.zeros(x_real.size(0)), 2)
|
||||
# c_org = torch.cat([c_org, zero, mask], dim=1)
|
||||
# c_trg = torch.cat([c_trg, zero, mask], dim=1)
|
||||
# elif dataset == 'RaFD':
|
||||
# c_org = self.label2onehot(label_org, self.c2_dim)
|
||||
# c_trg = self.label2onehot(label_trg, self.c2_dim)
|
||||
# zero = torch.zeros(x_real.size(0), self.c_dim)
|
||||
# mask = self.label2onehot(torch.ones(x_real.size(0)), 2)
|
||||
# c_org = torch.cat([zero, c_org, mask], dim=1)
|
||||
# c_trg = torch.cat([zero, c_trg, mask], dim=1)
|
||||
#
|
||||
# x_real = x_real.to(self.device) # Input images.
|
||||
# c_org = c_org.to(self.device) # Original domain labels.
|
||||
# c_trg = c_trg.to(self.device) # Target domain labels.
|
||||
# label_org = label_org.to(self.device) # Labels for computing classification loss.
|
||||
# label_trg = label_trg.to(self.device) # Labels for computing classification loss.
|
||||
#
|
||||
# # =================================================================================== #
|
||||
# # 2. Train the discriminator #
|
||||
# # =================================================================================== #
|
||||
#
|
||||
# # Compute loss with real images.
|
||||
# out_src, out_cls = self.D(x_real)
|
||||
# out_cls = out_cls[:, :self.c_dim] if dataset == 'CelebA' else out_cls[:, self.c_dim:]
|
||||
# d_loss_real = - torch.mean(out_src)
|
||||
# d_loss_cls = self.classification_loss(out_cls, label_org, dataset)
|
||||
#
|
||||
# # Compute loss with fake images.
|
||||
# x_fake = self.G(x_real, c_trg)
|
||||
# out_src, _ = self.D(x_fake.detach())
|
||||
# d_loss_fake = torch.mean(out_src)
|
||||
#
|
||||
# # Compute loss for gradient penalty.
|
||||
# alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
|
||||
# x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True)
|
||||
# out_src, _ = self.D(x_hat)
|
||||
# d_loss_gp = self.gradient_penalty(out_src, x_hat)
|
||||
#
|
||||
# # Backward and optimize.
|
||||
# d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls + self.lambda_gp * d_loss_gp
|
||||
# self.reset_grad()
|
||||
# d_loss.backward()
|
||||
# self.d_optimizer.step()
|
||||
#
|
||||
# # Logging.
|
||||
# loss = {}
|
||||
# loss['D/loss_real'] = d_loss_real.item()
|
||||
# loss['D/loss_fake'] = d_loss_fake.item()
|
||||
# loss['D/loss_cls'] = d_loss_cls.item()
|
||||
# loss['D/loss_gp'] = d_loss_gp.item()
|
||||
#
|
||||
# # =================================================================================== #
|
||||
# # 3. Train the generator #
|
||||
# # =================================================================================== #
|
||||
#
|
||||
# if (i+1) % self.n_critic == 0:
|
||||
# # Original-to-target domain.
|
||||
# x_fake = self.G(x_real, c_trg)
|
||||
# out_src, out_cls = self.D(x_fake)
|
||||
# out_cls = out_cls[:, :self.c_dim] if dataset == 'CelebA' else out_cls[:, self.c_dim:]
|
||||
# g_loss_fake = - torch.mean(out_src)
|
||||
# g_loss_cls = self.classification_loss(out_cls, label_trg, dataset)
|
||||
#
|
||||
# # Target-to-original domain.
|
||||
# x_reconst = self.G(x_fake, c_org)
|
||||
# g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))
|
||||
#
|
||||
# # Backward and optimize.
|
||||
# g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
|
||||
# self.reset_grad()
|
||||
# g_loss.backward()
|
||||
# self.g_optimizer.step()
|
||||
#
|
||||
# # Logging.
|
||||
# loss['G/loss_fake'] = g_loss_fake.item()
|
||||
# loss['G/loss_rec'] = g_loss_rec.item()
|
||||
# loss['G/loss_cls'] = g_loss_cls.item()
|
||||
#
|
||||
# # =================================================================================== #
|
||||
# # 4. Miscellaneous #
|
||||
# # =================================================================================== #
|
||||
#
|
||||
# # Print out training info.
|
||||
# if (i+1) % self.log_step == 0:
|
||||
# et = time.time() - start_time
|
||||
# et = str(datetime.timedelta(seconds=et))[:-7]
|
||||
# log = "Elapsed [{}], Iteration [{}/{}], Dataset [{}]".format(et, i+1, self.num_iters, dataset)
|
||||
# for tag, value in loss.items():
|
||||
# log += ", {}: {:.4f}".format(tag, value)
|
||||
# print(log)
|
||||
#
|
||||
# if self.use_tensorboard:
|
||||
# for tag, value in loss.items():
|
||||
# self.logger.scalar_summary(tag, value, i+1)
|
||||
#
|
||||
# # Translate fixed images for debugging.
|
||||
# if (i+1) % self.sample_step == 0:
|
||||
# with torch.no_grad():
|
||||
# x_fake_list = [x_fixed]
|
||||
# for c_fixed in c_celeba_list:
|
||||
# c_trg = torch.cat([c_fixed, zero_rafd, mask_celeba], dim=1)
|
||||
# x_fake_list.append(self.G(x_fixed, c_trg))
|
||||
# for c_fixed in c_rafd_list:
|
||||
# c_trg = torch.cat([zero_celeba, c_fixed, mask_rafd], dim=1)
|
||||
# x_fake_list.append(self.G(x_fixed, c_trg))
|
||||
# x_concat = torch.cat(x_fake_list, dim=3)
|
||||
# sample_path = os.path.join(self.sample_dir, '{}-images.jpg'.format(i+1))
|
||||
# save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)
|
||||
# print('Saved real and fake images into {}...'.format(sample_path))
|
||||
#
|
||||
# # Save model checkpoints.
|
||||
# if (i+1) % self.model_save_step == 0:
|
||||
# G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i+1))
|
||||
# D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i+1))
|
||||
# torch.save(self.G.state_dict(), G_path)
|
||||
# torch.save(self.D.state_dict(), D_path)
|
||||
# print('Saved model checkpoints into {}...'.format(self.model_save_dir))
|
||||
#
|
||||
# # Decay learning rates.
|
||||
# if (i+1) % self.lr_update_step == 0 and (i+1) > (self.num_iters - self.num_iters_decay):
|
||||
# g_lr -= (self.g_lr / float(self.num_iters_decay))
|
||||
# d_lr -= (self.d_lr / float(self.num_iters_decay))
|
||||
# self.update_lr(g_lr, d_lr)
|
||||
# print ('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))
|
||||
|
||||
def test(self):
|
||||
"""Translate images using StarGAN trained on a single dataset."""
|
||||
@@ -530,6 +551,8 @@ class Solver(object):
|
||||
data_loader = self.celeba_loader
|
||||
elif self.dataset == 'RaFD':
|
||||
data_loader = self.rafd_loader
|
||||
elif self.dataset == 'AffectNet':
|
||||
data_loader = self.affectnet_loader
|
||||
|
||||
with torch.no_grad():
|
||||
for i, (x_real, c_org) in enumerate(data_loader):
|
||||
@@ -549,34 +572,34 @@ class Solver(object):
|
||||
save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
|
||||
print('Saved real and fake images into {}...'.format(result_path))
|
||||
|
||||
def test_multi(self):
    """Translate CelebA images with a generator trained on multiple datasets.

    For every batch of ``self.celeba_loader`` the generator is run once per
    CelebA target label and once per RaFD target label; the input batch and
    all translations are concatenated along dim 3 and saved as one image
    file in ``self.result_dir``.
    """
    # Restore the checkpoint selected by `test_iters`.
    self.restore_model(self.test_iters)

    with torch.no_grad():
        for batch_idx, (x_real, c_org) in enumerate(self.celeba_loader):
            x_real = x_real.to(self.device)
            batch_size = x_real.size(0)

            # Target labels for each dataset.
            celeba_targets = self.create_labels(c_org, self.c_dim, 'CelebA', self.selected_attrs)
            rafd_targets = self.create_labels(c_org, self.c2_dim, 'RaFD')

            # Zero padding for the unused dataset's labels, plus dataset masks.
            zero_celeba = torch.zeros(batch_size, self.c_dim).to(self.device)
            zero_rafd = torch.zeros(batch_size, self.c2_dim).to(self.device)
            mask_celeba = self.label2onehot(torch.zeros(batch_size), 2).to(self.device)  # [1, 0]
            mask_rafd = self.label2onehot(torch.ones(batch_size), 2).to(self.device)     # [0, 1]

            # Full conditioning vectors: CelebA targets first, then RaFD targets.
            conditions = [torch.cat([c_celeba, zero_rafd, mask_celeba], dim=1)
                          for c_celeba in celeba_targets]
            conditions += [torch.cat([zero_celeba, c_rafd, mask_rafd], dim=1)
                           for c_rafd in rafd_targets]

            # Original batch followed by one translation per condition.
            x_fake_list = [x_real]
            for c_trg in conditions:
                x_fake_list.append(self.G(x_real, c_trg))

            # Save the translated images.
            x_concat = torch.cat(x_fake_list, dim=3)
            result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(batch_idx + 1))
            save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
            print('Saved real and fake images into {}...'.format(result_path))
|
||||
# def test_multi(self):
|
||||
# """Translate images using StarGAN trained on multiple datasets."""
|
||||
# # Load the trained generator.
|
||||
# self.restore_model(self.test_iters)
|
||||
#
|
||||
# with torch.no_grad():
|
||||
# for i, (x_real, c_org) in enumerate(self.celeba_loader):
|
||||
#
|
||||
# # Prepare input images and target domain labels.
|
||||
# x_real = x_real.to(self.device)
|
||||
# c_celeba_list = self.create_labels(c_org, self.c_dim, 'CelebA', self.selected_attrs)
|
||||
# c_rafd_list = self.create_labels(c_org, self.c2_dim, 'RaFD')
|
||||
# zero_celeba = torch.zeros(x_real.size(0), self.c_dim).to(self.device) # Zero vector for CelebA.
|
||||
# zero_rafd = torch.zeros(x_real.size(0), self.c2_dim).to(self.device) # Zero vector for RaFD.
|
||||
# mask_celeba = self.label2onehot(torch.zeros(x_real.size(0)), 2).to(self.device) # Mask vector: [1, 0].
|
||||
# mask_rafd = self.label2onehot(torch.ones(x_real.size(0)), 2).to(self.device) # Mask vector: [0, 1].
|
||||
#
|
||||
# # Translate images.
|
||||
# x_fake_list = [x_real]
|
||||
# for c_celeba in c_celeba_list:
|
||||
# c_trg = torch.cat([c_celeba, zero_rafd, mask_celeba], dim=1)
|
||||
# x_fake_list.append(self.G(x_real, c_trg))
|
||||
# for c_rafd in c_rafd_list:
|
||||
# c_trg = torch.cat([zero_celeba, c_rafd, mask_rafd], dim=1)
|
||||
# x_fake_list.append(self.G(x_real, c_trg))
|
||||
#
|
||||
# # Save the translated images.
|
||||
# x_concat = torch.cat(x_fake_list, dim=3)
|
||||
# result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1))
|
||||
# save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
|
||||
# print('Saved real and fake images into {}...'.format(result_path))
|
||||
|
||||
Referência em uma Nova Issue
Bloquear um usuário