version changed (pytorch 0.3)

Esse commit está contido em:
yunjey
2018-01-15 18:19:46 +09:00
commit 1a40a165f7
2 arquivos alterados com 4 adições e 4 exclusões
+3 -2
Ver Arquivo
@@ -19,6 +19,7 @@ class CelebDataset(Dataset):
self.idx2attr = {}
print ('Start preprocessing dataset..!')
random.seed(1234)
self.preprocess()
print ('Finished preprocessing dataset..!')
@@ -84,14 +85,14 @@ def get_loader(image_path, metadata_path, crop_size, image_size, batch_size, dat
if mode == 'train':
transform = transforms.Compose([
transforms.CenterCrop(crop_size),
transforms.Scale(image_size),
transforms.Resize(image_size, interpolation=Image.ANTIALIAS),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
else:
transform = transforms.Compose([
transforms.CenterCrop(crop_size),
transforms.Scale(image_size),
transforms.Scale(image_size, interpolation=Image.ANTIALIAS),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+1 -2
Ver Arquivo
@@ -134,8 +134,7 @@ class Solver(object):
def threshold(self, x):
x = x.clone()
x[x >= 0.5] = 1
x[x < 0.5] = 0
x = (x >= 0.5).float()
return x
def compute_accuracy(self, x, y, dataset):