Add map location 'gpu'->'cpu' for retoring models
Esse commit está contido em:
+3
-3
@@ -82,7 +82,7 @@ class Solver(object):
|
||||
self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])
|
||||
self.print_network(self.G, 'G')
|
||||
self.print_network(self.D, 'D')
|
||||
|
||||
|
||||
if torch.cuda.is_available():
|
||||
self.G.cuda()
|
||||
self.D.cuda()
|
||||
@@ -101,8 +101,8 @@ class Solver(object):
|
||||
print('Loading the trained models from step {}...'.format(resume_iters))
|
||||
G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(resume_iters))
|
||||
D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(resume_iters))
|
||||
self.G.load_state_dict(torch.load(G_path))
|
||||
self.D.load_state_dict(torch.load(D_path))
|
||||
self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
|
||||
self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))
|
||||
|
||||
def build_tensorboard(self):
|
||||
"""Build a tensorboard logger."""
|
||||
|
||||
Referência em uma Nova Issue
Bloquear um usuário