Allow custom print functions for summary

Esse commit está contido em:
Francois Chollet
2017-06-23 13:24:20 -07:00
commit 7dcd2982b2
2 arquivos alterados com 37 adições e 16 exclusões
+19 -4
Ver Arquivo
@@ -2637,10 +2637,25 @@ class Container(Layer):
"""
return yaml.dump(self._updated_config(), **kwargs)
def summary(self, line_length=None, positions=None):
print_layer_summary(self,
line_length=line_length,
positions=positions)
def summary(self, line_length=None, positions=None, print_fn=print):
"""Prints a string summary of the network.
# Arguments
line_length: Total length of printed lines
(e.g. set this to adapt the display to different
terminal window sizes).
positions: Relative or absolute positions of log elements
in each line. If not provided,
defaults to `[.33, .55, .67, 1.]`.
print_fn: Print function to use.
It will be called on each line of the summary.
You can set it to a custom function
in order to capture the string summary.
"""
return print_layer_summary(self,
line_length=line_length,
positions=positions,
print_fn=print_fn)
def get_source_inputs(tensor, layer=None, node_index=None):
+18 -12
Ver Arquivo
@@ -5,14 +5,20 @@ from .. import backend as K
import numpy as np
def print_summary(model, line_length=None, positions=None):
def print_summary(model, line_length=None, positions=None, print_fn=print):
"""Prints a summary of a model.
# Arguments
model: Keras model instance.
line_length: total length of printed lines
positions: relative or absolute positions of log elements in each line.
line_length: Total length of printed lines
(e.g. set this to adapt the display to different
terminal window sizes).
positions: Relative or absolute positions of log elements in each line.
If not provided, defaults to `[.33, .55, .67, 1.]`.
print_fn: Print function to use.
It will be called on each line of the summary.
You can set it to a custom function
in order to capture the string summary.
"""
if model.__class__.__name__ == 'Sequential':
sequential_like = True
@@ -51,11 +57,11 @@ def print_summary(model, line_length=None, positions=None):
line += str(fields[i])
line = line[:positions[i]]
line += ' ' * (positions[i] - len(line))
print(line)
print_fn(line)
print('_' * line_length)
print_fn('_' * line_length)
print_row(to_display, positions)
print('=' * line_length)
print_fn('=' * line_length)
def print_layer_summary(layer):
try:
@@ -108,19 +114,19 @@ def print_summary(model, line_length=None, positions=None):
else:
print_layer_summary_with_connections(layers[i])
if i == len(layers) - 1:
print('=' * line_length)
print_fn('=' * line_length)
else:
print('_' * line_length)
print_fn('_' * line_length)
trainable_count = int(
np.sum([K.count_params(p) for p in set(model.trainable_weights)]))
non_trainable_count = int(
np.sum([K.count_params(p) for p in set(model.non_trainable_weights)]))
print('Total params: {:,}'.format(trainable_count + non_trainable_count))
print('Trainable params: {:,}'.format(trainable_count))
print('Non-trainable params: {:,}'.format(non_trainable_count))
print('_' * line_length)
print_fn('Total params: {:,}'.format(trainable_count + non_trainable_count))
print_fn('Trainable params: {:,}'.format(trainable_count))
print_fn('Non-trainable params: {:,}'.format(non_trainable_count))
print_fn('_' * line_length)
def convert_all_kernels_in_model(model):