version changed (pytorch 0.3)
Esse commit está contido em:
+3
-2
@@ -19,6 +19,7 @@ class CelebDataset(Dataset):
|
||||
self.idx2attr = {}
|
||||
|
||||
print ('Start preprocessing dataset..!')
|
||||
random.seed(1234)
|
||||
self.preprocess()
|
||||
print ('Finished preprocessing dataset..!')
|
||||
|
||||
@@ -84,14 +85,14 @@ def get_loader(image_path, metadata_path, crop_size, image_size, batch_size, dat
|
||||
if mode == 'train':
|
||||
transform = transforms.Compose([
|
||||
transforms.CenterCrop(crop_size),
|
||||
transforms.Scale(image_size),
|
||||
transforms.Resize(image_size, interpolation=Image.ANTIALIAS),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
||||
else:
|
||||
transform = transforms.Compose([
|
||||
transforms.CenterCrop(crop_size),
|
||||
transforms.Scale(image_size),
|
||||
transforms.Scale(image_size, interpolation=Image.ANTIALIAS),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
||||
|
||||
|
||||
+1
-2
@@ -134,8 +134,7 @@ class Solver(object):
|
||||
|
||||
def threshold(self, x):
|
||||
x = x.clone()
|
||||
x[x >= 0.5] = 1
|
||||
x[x < 0.5] = 0
|
||||
x = (x >= 0.5).float()
|
||||
return x
|
||||
|
||||
def compute_accuracy(self, x, y, dataset):
|
||||
|
||||
Referência em uma Nova Issue
Bloquear um usuário