Improvements to style transfer as discussed in https://github.com/fchollet/keras/pull/6872 (#6877)
Esse commit está contido em:
commit de
François Chollet
pai
720ed1adc4
commit
53303fdb10
@@ -57,7 +57,7 @@ from scipy.optimize import fmin_l_bfgs_b
|
||||
import time
|
||||
import argparse
|
||||
|
||||
from keras.applications import vgg16
|
||||
from keras.applications import vgg19
|
||||
from keras import backend as K
|
||||
|
||||
parser = argparse.ArgumentParser(description='Neural style transfer with Keras.')
|
||||
@@ -99,7 +99,7 @@ def preprocess_image(image_path):
|
||||
img = load_img(image_path, target_size=(img_nrows, img_ncols))
|
||||
img = img_to_array(img)
|
||||
img = np.expand_dims(img, axis=0)
|
||||
img = vgg16.preprocess_input(img)
|
||||
img = vgg19.preprocess_input(img)
|
||||
return img
|
||||
|
||||
# util function to convert a tensor into a valid image
|
||||
@@ -137,7 +137,7 @@ input_tensor = K.concatenate([base_image,
|
||||
|
||||
# build the VGG16 network with our 3 images as input
|
||||
# the model will be loaded with pre-trained ImageNet weights
|
||||
model = vgg16.VGG16(input_tensor=input_tensor,
|
||||
model = vgg19.VGG19(input_tensor=input_tensor,
|
||||
weights='imagenet', include_top=False)
|
||||
print('Model loaded.')
|
||||
|
||||
@@ -199,7 +199,7 @@ def total_variation_loss(x):
|
||||
|
||||
# combine these loss functions into a single scalar
|
||||
loss = K.variable(0.)
|
||||
layer_features = outputs_dict['block4_conv2']
|
||||
layer_features = outputs_dict['block5_conv2']
|
||||
base_image_features = layer_features[0, :, :, :]
|
||||
combination_features = layer_features[2, :, :, :]
|
||||
loss += content_weight * content_loss(base_image_features,
|
||||
@@ -273,10 +273,7 @@ evaluator = Evaluator()
|
||||
|
||||
# run scipy-based optimization (L-BFGS) over the pixels of the generated image
|
||||
# so as to minimize the neural style loss
|
||||
if K.image_data_format() == 'channels_first':
|
||||
x = np.random.uniform(0, 255, (1, 3, img_nrows, img_ncols)) - 128.
|
||||
else:
|
||||
x = np.random.uniform(0, 255, (1, img_nrows, img_ncols, 3)) - 128.
|
||||
x = preprocess_image(base_image_path)
|
||||
|
||||
for i in range(iterations):
|
||||
print('Start of iteration', i)
|
||||
|
||||
Referência em uma Nova Issue
Bloquear um usuário