From a24d8688ddd6e7c63c0bb07d69417d8caffc1977 Mon Sep 17 00:00:00 2001
From: Aleksander
Date: Mon, 17 Feb 2020 00:29:50 +0000
Subject: [PATCH] Added AffectNet dataset w/ categorical and valence-arousal emotion descriptors

---
 data_loader.py        |  72 +++++-
 download_affectnet.sh |  20 +-
 main.py               |  35 +--
 solver.py             | 497 ++++++++++++++++++++++--------------------
 4 files changed, 362 insertions(+), 262 deletions(-)

diff --git a/data_loader.py b/data_loader.py
index d0c5eac..0189111 100644
--- a/data_loader.py
+++ b/data_loader.py
@@ -68,8 +68,72 @@ class CelebA(data.Dataset):
         return self.num_images
 
 
-def get_loader(image_dir, attr_path, selected_attrs, crop_size=178, image_size=128, 
-               batch_size=16, dataset='CelebA', mode='train', num_workers=1):
+class AffectNet(data.Dataset):
+    """Dataset class for the AffectNet dataset."""
+
+    def __init__(self, image_dir, affectnet_emo_descr, transform, mode):
+        """Initialize and preprocess the AffectNet dataset."""
+        self.image_dir = image_dir
+        self.affectnet_emo_descr = affectnet_emo_descr
+        self.transform = transform
+        self.mode = mode
+        self.train_dataset = []
+        self.test_dataset = []
+        self.preprocess()
+
+        if mode == 'train':
+            self.num_images = len(self.train_dataset)
+        else:
+            self.num_images = len(self.test_dataset)
+
+    def _read_labels(self, list_name):
+        """Parse one label list file into [filename, label] pairs.
+
+        Each non-empty line holds a filename followed by one integer class
+        ('emotiw' descriptor) or by valence and arousal floats (otherwise).
+        """
+        categorical = self.affectnet_emo_descr == 'emotiw'
+        entries = []
+        # A context manager closes the list file instead of leaking the handle.
+        with open(os.path.join(self.image_dir, list_name), 'r') as f:
+            for line in f:
+                fields = line.split()
+                if not fields:
+                    continue  # Skip blank lines such as a trailing newline.
+                if categorical:
+                    filename, label = fields
+                    entries.append([filename, int(label)])
+                else:
+                    filename, v, a = fields
+                    entries.append([filename, [float(v), float(a)]])
+        return entries
+
+    def preprocess(self):
+        """Preprocess the AffectNet emotional description file."""
+        suffix = 'emotiw' if self.affectnet_emo_descr == 'emotiw' else 'va'
+        self.train_dataset = self._read_labels('_affectnet_train_{}.txt'.format(suffix))
+        self.test_dataset = self._read_labels('_affectnet_test_{}.txt'.format(suffix))
+        print('Finished preprocessing the AffectNet dataset...')
+
+    def __getitem__(self, index):
+        """Return one image and its corresponding attribute label."""
+        if self.mode == 'train':
+            dataset = self.train_dataset
+            filename, label = dataset[index]
+            image = Image.open(os.path.join(self.image_dir, 'train', filename))
+        else:
+            dataset = self.test_dataset
+            filename, label = dataset[index]
+            image = Image.open(os.path.join(self.image_dir, 'validation', filename))
+        return self.transform(image), torch.tensor(label)
+
+    def __len__(self):
+        """Return the number of images."""
+        return self.num_images
+
+
+def get_loader(image_dir, attr_path, selected_attrs, crop_size=178, image_size=128,
+               batch_size=16, dataset='CelebA', mode='train', affectnet_emo_descr='emotiw', num_workers=1):
     """Build and return a data loader."""
     transform = []
     if mode == 'train':
@@ -84,9 +148,11 @@ def get_loader(image_dir, attr_path, selected_attrs, crop_size=178, image_size=1
         dataset = CelebA(image_dir, attr_path, selected_attrs, transform, mode)
     elif dataset == 'RaFD':
         dataset = ImageFolder(image_dir, transform)
+    elif dataset == 'AffectNet':
+        dataset = AffectNet(image_dir, affectnet_emo_descr, transform, mode)
 
     data_loader = data.DataLoader(dataset=dataset,
                                   batch_size=batch_size,
                                   shuffle=(mode=='train'),
                                   num_workers=num_workers)
-    return data_loader
\ No newline at end of file
+    return data_loader
diff --git a/download_affectnet.sh b/download_affectnet.sh
index cd9553a..949c258 100644
--- a/download_affectnet.sh
+++ b/download_affectnet.sh
@@ -1,8 +1,14 @@
-URL=$1
+URL1=$1
+URL2=$2
 
-# Download affectnet
-ZIP_FILE=./data/affectnet.zip
-mkdir -p ./data/
-wget -N $URL -O $ZIP_FILE
-unzip $ZIP_FILE -d ./data/
-rm $ZIP_FILE
+# Download affectnet and labels
+mkdir -p ./data/affectnet
+ZIP_FILE1=./data/affectnet.zip
+wget "$URL1" -O "$ZIP_FILE1"
+unzip "$ZIP_FILE1" -d ./data/affectnet
+rm "$ZIP_FILE1"
+
+ZIP_FILE2=./data/affectnet_labels.zip
+wget "$URL2" -O "$ZIP_FILE2"
+unzip "$ZIP_FILE2" -d ./data/affectnet
+rm "$ZIP_FILE2"
diff --git a/main.py b/main.py
index 8400abb..aac52e7 100644
--- a/main.py
+++ b/main.py
@@ -25,30 +25,33 @@ def main(config):
     # Data loader.
     celeba_loader = None
     rafd_loader = None
+    affectnet_loader = None
 
     if config.dataset in ['CelebA', 'Both']:
         celeba_loader = get_loader(config.celeba_image_dir, config.attr_path, config.selected_attrs,
                                    config.celeba_crop_size, config.image_size, config.batch_size,
-                                   'CelebA', config.mode, config.num_workers)
+                                   'CelebA', config.mode, None, config.num_workers)
     if config.dataset in ['RaFD', 'Both']:
         rafd_loader = get_loader(config.rafd_image_dir, None, None,
                                  config.rafd_crop_size, config.image_size, config.batch_size,
-                                 'RaFD', config.mode, config.num_workers)
-    
-
+                                 'RaFD', config.mode, None, config.num_workers)
+    if config.dataset in ['AffectNet']:
+        affectnet_loader = get_loader(config.affectnet_image_dir, None, None,
+                                      config.affectnet_crop_size, config.image_size, config.batch_size,
+                                      'AffectNet', config.mode, config.affectnet_emo_descr, config.num_workers)
     # Solver for training and testing StarGAN.
-    solver = Solver(celeba_loader, rafd_loader, config)
+    solver = Solver(celeba_loader, rafd_loader, affectnet_loader, config)
 
     if config.mode == 'train':
-        if config.dataset in ['CelebA', 'RaFD']:
+        if config.dataset in ['CelebA', 'RaFD', 'AffectNet']:
             solver.train()
-        elif config.dataset in ['Both']:
-            solver.train_multi()
+        # elif config.dataset in ['Both']:
+        #     solver.train_multi()
     elif config.mode == 'test':
-        if config.dataset in ['CelebA', 'RaFD']:
+        if config.dataset in ['CelebA', 'RaFD', 'AffectNet']:
             solver.test()
-        elif config.dataset in ['Both']:
-            solver.test_multi()
+        # elif config.dataset in ['Both']:
+        #     solver.test_multi()
 
 
 if __name__ == '__main__':
@@ -59,6 +62,7 @@ if __name__ == '__main__':
     parser.add_argument('--c2_dim', type=int, default=8, help='dimension of domain labels (2nd dataset)')
     parser.add_argument('--celeba_crop_size', type=int, default=178, help='crop size for the CelebA dataset')
     parser.add_argument('--rafd_crop_size', type=int, default=256, help='crop size for the RaFD dataset')
+    parser.add_argument('--affectnet_crop_size', type=int, default=112, help='crop size for the AffectNet dataset')
     parser.add_argument('--image_size', type=int, default=128, help='image resolution')
     parser.add_argument('--g_conv_dim', type=int, default=64, help='number of conv filters in the first layer of G')
     parser.add_argument('--d_conv_dim', type=int, default=64, help='number of conv filters in the first layer of D')
@@ -67,9 +71,9 @@ if __name__ == '__main__':
     parser.add_argument('--lambda_cls', type=float, default=1, help='weight for domain classification loss')
     parser.add_argument('--lambda_rec', type=float, default=10, help='weight for reconstruction loss')
     parser.add_argument('--lambda_gp', type=float, default=10, help='weight for gradient penalty')
-    
+
     # Training configuration.
-    parser.add_argument('--dataset', type=str, default='CelebA', choices=['CelebA', 'RaFD', 'Both'])
+    parser.add_argument('--dataset', type=str, default='CelebA', choices=['CelebA', 'RaFD', 'AffectNet'])
     parser.add_argument('--batch_size', type=int, default=16, help='mini-batch size')
     parser.add_argument('--num_iters', type=int, default=200000, help='number of total iterations for training D')
     parser.add_argument('--num_iters_decay', type=int, default=100000, help='number of iterations for decaying lr')
@@ -81,7 +85,7 @@ if __name__ == '__main__':
     parser.add_argument('--resume_iters', type=int, default=None, help='resume training from this step')
     parser.add_argument('--selected_attrs', '--list', nargs='+', help='selected attributes for the CelebA dataset',
                         default=['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young'])
-
+    parser.add_argument('--affectnet_emo_descr', type=str, default='emotiw', choices=['emotiw', 'va'], help='emotiw for categorical emotions, va for valence-arousal')
     # Test configuration.
parser.add_argument('--test_iters', type=int, default=200000, help='test model from this step') @@ -94,6 +98,7 @@ if __name__ == '__main__': parser.add_argument('--celeba_image_dir', type=str, default='data/celeba/images') parser.add_argument('--attr_path', type=str, default='data/celeba/list_attr_celeba.txt') parser.add_argument('--rafd_image_dir', type=str, default='data/RaFD/train') + parser.add_argument('--affectnet_image_dir', type=str, default='data/affectnet') parser.add_argument('--log_dir', type=str, default='stargan/logs') parser.add_argument('--model_save_dir', type=str, default='stargan/models') parser.add_argument('--sample_dir', type=str, default='stargan/samples') @@ -107,4 +112,4 @@ if __name__ == '__main__': config = parser.parse_args() print(config) - main(config) \ No newline at end of file + main(config) diff --git a/solver.py b/solver.py index f2079b2..ec780e7 100644 --- a/solver.py +++ b/solver.py @@ -13,13 +13,13 @@ import datetime class Solver(object): """Solver for training and testing StarGAN.""" - def __init__(self, celeba_loader, rafd_loader, config): + def __init__(self, celeba_loader, rafd_loader, affectnet_loader, config): """Initialize configurations.""" # Data loader. self.celeba_loader = celeba_loader self.rafd_loader = rafd_loader - + self.affectnet_loader = affectnet_loader # Model configurations. self.c_dim = config.c_dim self.c2_dim = config.c2_dim @@ -44,7 +44,7 @@ class Solver(object): self.beta2 = config.beta2 self.resume_iters = config.resume_iters self.selected_attrs = config.selected_attrs - + self.affectnet_emo_descr = config.affectnet_emo_descr # Test configurations. 
self.test_iters = config.test_iters @@ -71,9 +71,9 @@ class Solver(object): def build_model(self): """Create a generator and a discriminator.""" - if self.dataset in ['CelebA', 'RaFD']: + if self.dataset in ['CelebA', 'RaFD', 'AffectNet']: self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num) - self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num) + self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num) elif self.dataset in ['Both']: self.G = Generator(self.g_conv_dim, self.c_dim+self.c2_dim+2, self.g_repeat_num) # 2 for mask vector. self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim+self.c2_dim, self.d_repeat_num) @@ -82,7 +82,7 @@ class Solver(object): self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2]) self.print_network(self.G, 'G') self.print_network(self.D, 'D') - + self.G.to(self.device) self.D.to(self.device) @@ -156,29 +156,41 @@ class Solver(object): hair_color_indices.append(i) c_trg_list = [] - for i in range(c_dim): - if dataset == 'CelebA': - c_trg = c_org.clone() - if i in hair_color_indices: # Set one hair color to 1 and the rest to 0. - c_trg[:, i] = 1 - for j in hair_color_indices: - if j != i: - c_trg[:, j] = 0 - else: - c_trg[:, i] = (c_trg[:, i] == 0) # Reverse attribute value. - elif dataset == 'RaFD': - c_trg = self.label2onehot(torch.ones(c_org.size(0))*i, c_dim) + if dataset == 'AffectNet' and self.affectnet_emo_descr == 'va': + for v in [-1., -0.5, 0., 0.5, 1.]: + for a in [-1., -0.5, 0., 0.5, 1.]: + c_trg = torch.Tensor([v, a]).repeat(c_org.size(0), 1) + c_trg_list.append(c_trg.to(self.device)) + else: + for i in range(c_dim): + if dataset == 'CelebA': + c_trg = c_org.clone() + if i in hair_color_indices: # Set one hair color to 1 and the rest to 0. + c_trg[:, i] = 1 + for j in hair_color_indices: + if j != i: + c_trg[:, j] = 0 + else: + c_trg[:, i] = (c_trg[:, i] == 0) # Reverse attribute value. 
+                elif dataset == 'RaFD':
+                    c_trg = self.label2onehot(torch.ones(c_org.size(0))*i, c_dim)
+                elif dataset == 'AffectNet' and self.affectnet_emo_descr == 'emotiw':
+                    c_trg = self.label2onehot(torch.ones(c_org.size(0))*i, c_dim)
+                c_trg_list.append(c_trg.to(self.device))
 
-            c_trg_list.append(c_trg.to(self.device))
         return c_trg_list
 
     def classification_loss(self, logit, target, dataset='CelebA'):
         """Compute binary or softmax cross entropy loss."""
         if dataset == 'CelebA':
-            return F.binary_cross_entropy_with_logits(logit, target, size_average=False) / logit.size(0)
+            return F.binary_cross_entropy_with_logits(logit, target, reduction='sum') / logit.size(0)
         elif dataset == 'RaFD':
             return F.cross_entropy(logit, target)
-
+        elif dataset == 'AffectNet':
+            if self.affectnet_emo_descr == 'emotiw':
+                return F.cross_entropy(logit, target)
+            else:
+                return F.mse_loss(torch.tanh(logit), target)
     def train(self):
         """Train StarGAN within a single dataset."""
         # Set data loader.
@@ -186,6 +198,8 @@ class Solver(object):
             data_loader = self.celeba_loader
         elif self.dataset == 'RaFD':
             data_loader = self.rafd_loader
+        elif self.dataset == 'AffectNet':
+            data_loader = self.affectnet_loader
 
         # Fetch fixed inputs for debugging.
         data_iter = iter(data_loader)
@@ -229,6 +243,13 @@ class Solver(object):
             elif self.dataset == 'RaFD':
                 c_org = self.label2onehot(label_org, self.c_dim)
                 c_trg = self.label2onehot(label_trg, self.c_dim)
+            elif self.dataset == 'AffectNet':
+                if self.affectnet_emo_descr == 'emotiw':
+                    c_org = self.label2onehot(label_org, self.c_dim)
+                    c_trg = self.label2onehot(label_trg, self.c_dim)
+                else:
+                    c_org = label_org.clone()
+                    c_trg = label_trg.clone()
 
             x_real = x_real.to(self.device)           # Input images.
             c_org = c_org.to(self.device)             # Original domain labels.
@@ -268,11 +289,11 @@ class Solver(object):
             loss['D/loss_fake'] = d_loss_fake.item()
             loss['D/loss_cls'] = d_loss_cls.item()
             loss['D/loss_gp'] = d_loss_gp.item()
-            
+
             # =================================================================================== #
             #                               3. Train the generator                                #
             # =================================================================================== #
-            
+
             if (i+1) % self.n_critic == 0:
                 # Original-to-target domain.
                 x_fake = self.G(x_real, c_trg)
@@ -338,199 +359,201 @@ class Solver(object):
                 self.update_lr(g_lr, d_lr)
                 print ('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))
 
-    def train_multi(self):
-        """Train StarGAN with multiple datasets."""
-        # Data iterators.
-        celeba_iter = iter(self.celeba_loader)
-        rafd_iter = iter(self.rafd_loader)
-
-        # Fetch fixed inputs for debugging.
-        x_fixed, c_org = next(celeba_iter)
-        x_fixed = x_fixed.to(self.device)
-        c_celeba_list = self.create_labels(c_org, self.c_dim, 'CelebA', self.selected_attrs)
-        c_rafd_list = self.create_labels(c_org, self.c2_dim, 'RaFD')
-        zero_celeba = torch.zeros(x_fixed.size(0), self.c_dim).to(self.device)            # Zero vector for CelebA.
-        zero_rafd = torch.zeros(x_fixed.size(0), self.c2_dim).to(self.device)             # Zero vector for RaFD.
-        mask_celeba = self.label2onehot(torch.zeros(x_fixed.size(0)), 2).to(self.device)  # Mask vector: [1, 0].
-        mask_rafd = self.label2onehot(torch.ones(x_fixed.size(0)), 2).to(self.device)     # Mask vector: [0, 1].
-
-        # Learning rate cache for decaying.
-        g_lr = self.g_lr
-        d_lr = self.d_lr
-
-        # Start training from scratch or resume training.
-        start_iters = 0
-        if self.resume_iters:
-            start_iters = self.resume_iters
-            self.restore_model(self.resume_iters)
-
-        # Start training.
-        print('Start training...')
-        start_time = time.time()
-        for i in range(start_iters, self.num_iters):
-            for dataset in ['CelebA', 'RaFD']:
-
-                # =================================================================================== #
-                #                             1. Preprocess input data                                #
-                # =================================================================================== #
-
-                # Fetch real images and labels.
-                data_iter = celeba_iter if dataset == 'CelebA' else rafd_iter
-
-                try:
-                    x_real, label_org = next(data_iter)
-                except:
-                    if dataset == 'CelebA':
-                        celeba_iter = iter(self.celeba_loader)
-                        x_real, label_org = next(celeba_iter)
-                    elif dataset == 'RaFD':
-                        rafd_iter = iter(self.rafd_loader)
-                        x_real, label_org = next(rafd_iter)
-
-                # Generate target domain labels randomly.
-                rand_idx = torch.randperm(label_org.size(0))
-                label_trg = label_org[rand_idx]
-
-                if dataset == 'CelebA':
-                    c_org = label_org.clone()
-                    c_trg = label_trg.clone()
-                    zero = torch.zeros(x_real.size(0), self.c2_dim)
-                    mask = self.label2onehot(torch.zeros(x_real.size(0)), 2)
-                    c_org = torch.cat([c_org, zero, mask], dim=1)
-                    c_trg = torch.cat([c_trg, zero, mask], dim=1)
-                elif dataset == 'RaFD':
-                    c_org = self.label2onehot(label_org, self.c2_dim)
-                    c_trg = self.label2onehot(label_trg, self.c2_dim)
-                    zero = torch.zeros(x_real.size(0), self.c_dim)
-                    mask = self.label2onehot(torch.ones(x_real.size(0)), 2)
-                    c_org = torch.cat([zero, c_org, mask], dim=1)
-                    c_trg = torch.cat([zero, c_trg, mask], dim=1)
-
-                x_real = x_real.to(self.device)           # Input images.
-                c_org = c_org.to(self.device)             # Original domain labels.
-                c_trg = c_trg.to(self.device)             # Target domain labels.
-                label_org = label_org.to(self.device)     # Labels for computing classification loss.
-                label_trg = label_trg.to(self.device)     # Labels for computing classification loss.
-
-                # =================================================================================== #
-                #                             2. Train the discriminator                              #
-                # =================================================================================== #
-
-                # Compute loss with real images.
-                out_src, out_cls = self.D(x_real)
-                out_cls = out_cls[:, :self.c_dim] if dataset == 'CelebA' else out_cls[:, self.c_dim:]
-                d_loss_real = - torch.mean(out_src)
-                d_loss_cls = self.classification_loss(out_cls, label_org, dataset)
-
-                # Compute loss with fake images.
-                x_fake = self.G(x_real, c_trg)
-                out_src, _ = self.D(x_fake.detach())
-                d_loss_fake = torch.mean(out_src)
-
-                # Compute loss for gradient penalty.
-                alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
-                x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True)
-                out_src, _ = self.D(x_hat)
-                d_loss_gp = self.gradient_penalty(out_src, x_hat)
-
-                # Backward and optimize.
-                d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls + self.lambda_gp * d_loss_gp
-                self.reset_grad()
-                d_loss.backward()
-                self.d_optimizer.step()
-
-                # Logging.
-                loss = {}
-                loss['D/loss_real'] = d_loss_real.item()
-                loss['D/loss_fake'] = d_loss_fake.item()
-                loss['D/loss_cls'] = d_loss_cls.item()
-                loss['D/loss_gp'] = d_loss_gp.item()
-
-                # =================================================================================== #
-                #                               3. Train the generator                                #
-                # =================================================================================== #
-
-                if (i+1) % self.n_critic == 0:
-                    # Original-to-target domain.
-                    x_fake = self.G(x_real, c_trg)
-                    out_src, out_cls = self.D(x_fake)
-                    out_cls = out_cls[:, :self.c_dim] if dataset == 'CelebA' else out_cls[:, self.c_dim:]
-                    g_loss_fake = - torch.mean(out_src)
-                    g_loss_cls = self.classification_loss(out_cls, label_trg, dataset)
-
-                    # Target-to-original domain.
-                    x_reconst = self.G(x_fake, c_org)
-                    g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))
-
-                    # Backward and optimize.
-                    g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
-                    self.reset_grad()
-                    g_loss.backward()
-                    self.g_optimizer.step()
-
-                    # Logging.
-                    loss['G/loss_fake'] = g_loss_fake.item()
-                    loss['G/loss_rec'] = g_loss_rec.item()
-                    loss['G/loss_cls'] = g_loss_cls.item()
-
-                # =================================================================================== #
-                #                                 4. Miscellaneous                                    #
-                # =================================================================================== #
-
-                # Print out training info.
-                if (i+1) % self.log_step == 0:
-                    et = time.time() - start_time
-                    et = str(datetime.timedelta(seconds=et))[:-7]
-                    log = "Elapsed [{}], Iteration [{}/{}], Dataset [{}]".format(et, i+1, self.num_iters, dataset)
-                    for tag, value in loss.items():
-                        log += ", {}: {:.4f}".format(tag, value)
-                    print(log)
-
-                    if self.use_tensorboard:
-                        for tag, value in loss.items():
-                            self.logger.scalar_summary(tag, value, i+1)
-
-            # Translate fixed images for debugging.
-            if (i+1) % self.sample_step == 0:
-                with torch.no_grad():
-                    x_fake_list = [x_fixed]
-                    for c_fixed in c_celeba_list:
-                        c_trg = torch.cat([c_fixed, zero_rafd, mask_celeba], dim=1)
-                        x_fake_list.append(self.G(x_fixed, c_trg))
-                    for c_fixed in c_rafd_list:
-                        c_trg = torch.cat([zero_celeba, c_fixed, mask_rafd], dim=1)
-                        x_fake_list.append(self.G(x_fixed, c_trg))
-                    x_concat = torch.cat(x_fake_list, dim=3)
-                    sample_path = os.path.join(self.sample_dir, '{}-images.jpg'.format(i+1))
-                    save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)
-                    print('Saved real and fake images into {}...'.format(sample_path))
-
-            # Save model checkpoints.
-            if (i+1) % self.model_save_step == 0:
-                G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i+1))
-                D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i+1))
-                torch.save(self.G.state_dict(), G_path)
-                torch.save(self.D.state_dict(), D_path)
-                print('Saved model checkpoints into {}...'.format(self.model_save_dir))
-
-            # Decay learning rates.
-            if (i+1) % self.lr_update_step == 0 and (i+1) > (self.num_iters - self.num_iters_decay):
-                g_lr -= (self.g_lr / float(self.num_iters_decay))
-                d_lr -= (self.d_lr / float(self.num_iters_decay))
-                self.update_lr(g_lr, d_lr)
-                print ('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))
+    # def train_multi(self):
+    #     """Train StarGAN with multiple datasets."""
+    #     # Data iterators.
+    #     celeba_iter = iter(self.celeba_loader)
+    #     rafd_iter = iter(self.rafd_loader)
+    #
+    #     # Fetch fixed inputs for debugging.
+    #     x_fixed, c_org = next(celeba_iter)
+    #     x_fixed = x_fixed.to(self.device)
+    #     c_celeba_list = self.create_labels(c_org, self.c_dim, 'CelebA', self.selected_attrs)
+    #     c_rafd_list = self.create_labels(c_org, self.c2_dim, 'RaFD')
+    #     zero_celeba = torch.zeros(x_fixed.size(0), self.c_dim).to(self.device)            # Zero vector for CelebA.
+    #     zero_rafd = torch.zeros(x_fixed.size(0), self.c2_dim).to(self.device)             # Zero vector for RaFD.
+    #     mask_celeba = self.label2onehot(torch.zeros(x_fixed.size(0)), 2).to(self.device)  # Mask vector: [1, 0].
+    #     mask_rafd = self.label2onehot(torch.ones(x_fixed.size(0)), 2).to(self.device)     # Mask vector: [0, 1].
+    #
+    #     # Learning rate cache for decaying.
+    #     g_lr = self.g_lr
+    #     d_lr = self.d_lr
+    #
+    #     # Start training from scratch or resume training.
+    #     start_iters = 0
+    #     if self.resume_iters:
+    #         start_iters = self.resume_iters
+    #         self.restore_model(self.resume_iters)
+    #
+    #     # Start training.
+    #     print('Start training...')
+    #     start_time = time.time()
+    #     for i in range(start_iters, self.num_iters):
+    #         for dataset in ['CelebA', 'RaFD']:
+    #
+    #             # =================================================================================== #
+    #             #                             1. Preprocess input data                                #
+    #             # =================================================================================== #
+    #
+    #             # Fetch real images and labels.
+    #             data_iter = celeba_iter if dataset == 'CelebA' else rafd_iter
+    #
+    #             try:
+    #                 x_real, label_org = next(data_iter)
+    #             except:
+    #                 if dataset == 'CelebA':
+    #                     celeba_iter = iter(self.celeba_loader)
+    #                     x_real, label_org = next(celeba_iter)
+    #                 elif dataset == 'RaFD':
+    #                     rafd_iter = iter(self.rafd_loader)
+    #                     x_real, label_org = next(rafd_iter)
+    #
+    #             # Generate target domain labels randomly.
+    #             rand_idx = torch.randperm(label_org.size(0))
+    #             label_trg = label_org[rand_idx]
+    #
+    #             if dataset == 'CelebA':
+    #                 c_org = label_org.clone()
+    #                 c_trg = label_trg.clone()
+    #                 zero = torch.zeros(x_real.size(0), self.c2_dim)
+    #                 mask = self.label2onehot(torch.zeros(x_real.size(0)), 2)
+    #                 c_org = torch.cat([c_org, zero, mask], dim=1)
+    #                 c_trg = torch.cat([c_trg, zero, mask], dim=1)
+    #             elif dataset == 'RaFD':
+    #                 c_org = self.label2onehot(label_org, self.c2_dim)
+    #                 c_trg = self.label2onehot(label_trg, self.c2_dim)
+    #                 zero = torch.zeros(x_real.size(0), self.c_dim)
+    #                 mask = self.label2onehot(torch.ones(x_real.size(0)), 2)
+    #                 c_org = torch.cat([zero, c_org, mask], dim=1)
+    #                 c_trg = torch.cat([zero, c_trg, mask], dim=1)
+    #
+    #             x_real = x_real.to(self.device)           # Input images.
+    #             c_org = c_org.to(self.device)             # Original domain labels.
+    #             c_trg = c_trg.to(self.device)             # Target domain labels.
+    #             label_org = label_org.to(self.device)     # Labels for computing classification loss.
+    #             label_trg = label_trg.to(self.device)     # Labels for computing classification loss.
+    #
+    #             # =================================================================================== #
+    #             #                             2. Train the discriminator                              #
+    #             # =================================================================================== #
+    #
+    #             # Compute loss with real images.
+    #             out_src, out_cls = self.D(x_real)
+    #             out_cls = out_cls[:, :self.c_dim] if dataset == 'CelebA' else out_cls[:, self.c_dim:]
+    #             d_loss_real = - torch.mean(out_src)
+    #             d_loss_cls = self.classification_loss(out_cls, label_org, dataset)
+    #
+    #             # Compute loss with fake images.
+    #             x_fake = self.G(x_real, c_trg)
+    #             out_src, _ = self.D(x_fake.detach())
+    #             d_loss_fake = torch.mean(out_src)
+    #
+    #             # Compute loss for gradient penalty.
+    #             alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
+    #             x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True)
+    #             out_src, _ = self.D(x_hat)
+    #             d_loss_gp = self.gradient_penalty(out_src, x_hat)
+    #
+    #             # Backward and optimize.
+    #             d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls + self.lambda_gp * d_loss_gp
+    #             self.reset_grad()
+    #             d_loss.backward()
+    #             self.d_optimizer.step()
+    #
+    #             # Logging.
+    #             loss = {}
+    #             loss['D/loss_real'] = d_loss_real.item()
+    #             loss['D/loss_fake'] = d_loss_fake.item()
+    #             loss['D/loss_cls'] = d_loss_cls.item()
+    #             loss['D/loss_gp'] = d_loss_gp.item()
+    #
+    #             # =================================================================================== #
+    #             #                               3. Train the generator                                #
+    #             # =================================================================================== #
+    #
+    #             if (i+1) % self.n_critic == 0:
+    #                 # Original-to-target domain.
+    #                 x_fake = self.G(x_real, c_trg)
+    #                 out_src, out_cls = self.D(x_fake)
+    #                 out_cls = out_cls[:, :self.c_dim] if dataset == 'CelebA' else out_cls[:, self.c_dim:]
+    #                 g_loss_fake = - torch.mean(out_src)
+    #                 g_loss_cls = self.classification_loss(out_cls, label_trg, dataset)
+    #
+    #                 # Target-to-original domain.
+    #                 x_reconst = self.G(x_fake, c_org)
+    #                 g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))
+    #
+    #                 # Backward and optimize.
+    #                 g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
+    #                 self.reset_grad()
+    #                 g_loss.backward()
+    #                 self.g_optimizer.step()
+    #
+    #                 # Logging.
+    #                 loss['G/loss_fake'] = g_loss_fake.item()
+    #                 loss['G/loss_rec'] = g_loss_rec.item()
+    #                 loss['G/loss_cls'] = g_loss_cls.item()
+    #
+    #             # =================================================================================== #
+    #             #                                 4. Miscellaneous                                    #
+    #             # =================================================================================== #
+    #
+    #             # Print out training info.
+    #             if (i+1) % self.log_step == 0:
+    #                 et = time.time() - start_time
+    #                 et = str(datetime.timedelta(seconds=et))[:-7]
+    #                 log = "Elapsed [{}], Iteration [{}/{}], Dataset [{}]".format(et, i+1, self.num_iters, dataset)
+    #                 for tag, value in loss.items():
+    #                     log += ", {}: {:.4f}".format(tag, value)
+    #                 print(log)
+    #
+    #                 if self.use_tensorboard:
+    #                     for tag, value in loss.items():
+    #                         self.logger.scalar_summary(tag, value, i+1)
+    #
+    #         # Translate fixed images for debugging.
+    #         if (i+1) % self.sample_step == 0:
+    #             with torch.no_grad():
+    #                 x_fake_list = [x_fixed]
+    #                 for c_fixed in c_celeba_list:
+    #                     c_trg = torch.cat([c_fixed, zero_rafd, mask_celeba], dim=1)
+    #                     x_fake_list.append(self.G(x_fixed, c_trg))
+    #                 for c_fixed in c_rafd_list:
+    #                     c_trg = torch.cat([zero_celeba, c_fixed, mask_rafd], dim=1)
+    #                     x_fake_list.append(self.G(x_fixed, c_trg))
+    #                 x_concat = torch.cat(x_fake_list, dim=3)
+    #                 sample_path = os.path.join(self.sample_dir, '{}-images.jpg'.format(i+1))
+    #                 save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)
+    #                 print('Saved real and fake images into {}...'.format(sample_path))
+    #
+    #         # Save model checkpoints.
+    #         if (i+1) % self.model_save_step == 0:
+    #             G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i+1))
+    #             D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i+1))
+    #             torch.save(self.G.state_dict(), G_path)
+    #             torch.save(self.D.state_dict(), D_path)
+    #             print('Saved model checkpoints into {}...'.format(self.model_save_dir))
+    #
+    #         # Decay learning rates.
+    #         if (i+1) % self.lr_update_step == 0 and (i+1) > (self.num_iters - self.num_iters_decay):
+    #             g_lr -= (self.g_lr / float(self.num_iters_decay))
+    #             d_lr -= (self.d_lr / float(self.num_iters_decay))
+    #             self.update_lr(g_lr, d_lr)
+    #             print ('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))
 
     def test(self):
         """Translate images using StarGAN trained on a single dataset."""
         # Load the trained generator.
         self.restore_model(self.test_iters)
-        
+
         # Set data loader.
         if self.dataset == 'CelebA':
             data_loader = self.celeba_loader
         elif self.dataset == 'RaFD':
             data_loader = self.rafd_loader
-        
+        elif self.dataset == 'AffectNet':
+            data_loader = self.affectnet_loader
+
         with torch.no_grad():
             for i, (x_real, c_org) in enumerate(data_loader):
 
@@ -549,34 +572,34 @@ class Solver(object):
                 save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
                 print('Saved real and fake images into {}...'.format(result_path))
 
-    def test_multi(self):
-        """Translate images using StarGAN trained on multiple datasets."""
-        # Load the trained generator.
-        self.restore_model(self.test_iters)
-
-        with torch.no_grad():
-            for i, (x_real, c_org) in enumerate(self.celeba_loader):
-
-                # Prepare input images and target domain labels.
-                x_real = x_real.to(self.device)
-                c_celeba_list = self.create_labels(c_org, self.c_dim, 'CelebA', self.selected_attrs)
-                c_rafd_list = self.create_labels(c_org, self.c2_dim, 'RaFD')
-                zero_celeba = torch.zeros(x_real.size(0), self.c_dim).to(self.device)            # Zero vector for CelebA.
-                zero_rafd = torch.zeros(x_real.size(0), self.c2_dim).to(self.device)             # Zero vector for RaFD.
-                mask_celeba = self.label2onehot(torch.zeros(x_real.size(0)), 2).to(self.device)  # Mask vector: [1, 0].
-                mask_rafd = self.label2onehot(torch.ones(x_real.size(0)), 2).to(self.device)     # Mask vector: [0, 1].
-
-                # Translate images.
-                x_fake_list = [x_real]
-                for c_celeba in c_celeba_list:
-                    c_trg = torch.cat([c_celeba, zero_rafd, mask_celeba], dim=1)
-                    x_fake_list.append(self.G(x_real, c_trg))
-                for c_rafd in c_rafd_list:
-                    c_trg = torch.cat([zero_celeba, c_rafd, mask_rafd], dim=1)
-                    x_fake_list.append(self.G(x_real, c_trg))
-
-                # Save the translated images.
-                x_concat = torch.cat(x_fake_list, dim=3)
-                result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1))
-                save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
-                print('Saved real and fake images into {}...'.format(result_path))
\ No newline at end of file
+    # def test_multi(self):
+    #     """Translate images using StarGAN trained on multiple datasets."""
+    #     # Load the trained generator.
+    #     self.restore_model(self.test_iters)
+    #
+    #     with torch.no_grad():
+    #         for i, (x_real, c_org) in enumerate(self.celeba_loader):
+    #
+    #             # Prepare input images and target domain labels.
+    #             x_real = x_real.to(self.device)
+    #             c_celeba_list = self.create_labels(c_org, self.c_dim, 'CelebA', self.selected_attrs)
+    #             c_rafd_list = self.create_labels(c_org, self.c2_dim, 'RaFD')
+    #             zero_celeba = torch.zeros(x_real.size(0), self.c_dim).to(self.device)            # Zero vector for CelebA.
+    #             zero_rafd = torch.zeros(x_real.size(0), self.c2_dim).to(self.device)             # Zero vector for RaFD.
+    #             mask_celeba = self.label2onehot(torch.zeros(x_real.size(0)), 2).to(self.device)  # Mask vector: [1, 0].
+    #             mask_rafd = self.label2onehot(torch.ones(x_real.size(0)), 2).to(self.device)     # Mask vector: [0, 1].
+    #
+    #             # Translate images.
+    #             x_fake_list = [x_real]
+    #             for c_celeba in c_celeba_list:
+    #                 c_trg = torch.cat([c_celeba, zero_rafd, mask_celeba], dim=1)
+    #                 x_fake_list.append(self.G(x_real, c_trg))
+    #             for c_rafd in c_rafd_list:
+    #                 c_trg = torch.cat([zero_celeba, c_rafd, mask_rafd], dim=1)
+    #                 x_fake_list.append(self.G(x_real, c_trg))
+    #
+    #             # Save the translated images.
+    #             x_concat = torch.cat(x_fake_list, dim=3)
+    #             result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1))
+    #             save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
+    #             print('Saved real and fake images into {}...'.format(result_path))