Allow custom print functions for summary
Esse commit está contido em:
@@ -2637,10 +2637,25 @@ class Container(Layer):
|
||||
"""
|
||||
return yaml.dump(self._updated_config(), **kwargs)
|
||||
|
||||
def summary(self, line_length=None, positions=None):
|
||||
print_layer_summary(self,
|
||||
line_length=line_length,
|
||||
positions=positions)
|
||||
def summary(self, line_length=None, positions=None, print_fn=print):
|
||||
"""Prints a string summary of the network.
|
||||
|
||||
# Arguments
|
||||
line_length: Total length of printed lines
|
||||
(e.g. set this to adapt the display to different
|
||||
terminal window sizes).
|
||||
positions: Relative or absolute positions of log elements
|
||||
in each line. If not provided,
|
||||
defaults to `[.33, .55, .67, 1.]`.
|
||||
print_fn: Print function to use.
|
||||
It will be called on each line of the summary.
|
||||
You can set it to a custom function
|
||||
in order to capture the string summary.
|
||||
"""
|
||||
return print_layer_summary(self,
|
||||
line_length=line_length,
|
||||
positions=positions,
|
||||
print_fn=print_fn)
|
||||
|
||||
|
||||
def get_source_inputs(tensor, layer=None, node_index=None):
|
||||
|
||||
@@ -5,14 +5,20 @@ from .. import backend as K
|
||||
import numpy as np
|
||||
|
||||
|
||||
def print_summary(model, line_length=None, positions=None):
|
||||
def print_summary(model, line_length=None, positions=None, print_fn=print):
|
||||
"""Prints a summary of a model.
|
||||
|
||||
# Arguments
|
||||
model: Keras model instance.
|
||||
line_length: total length of printed lines
|
||||
positions: relative or absolute positions of log elements in each line.
|
||||
line_length: Total length of printed lines
|
||||
(e.g. set this to adapt the display to different
|
||||
terminal window sizes).
|
||||
positions: Relative or absolute positions of log elements in each line.
|
||||
If not provided, defaults to `[.33, .55, .67, 1.]`.
|
||||
print_fn: Print function to use.
|
||||
It will be called on each line of the summary.
|
||||
You can set it to a custom function
|
||||
in order to capture the string summary.
|
||||
"""
|
||||
if model.__class__.__name__ == 'Sequential':
|
||||
sequential_like = True
|
||||
@@ -51,11 +57,11 @@ def print_summary(model, line_length=None, positions=None):
|
||||
line += str(fields[i])
|
||||
line = line[:positions[i]]
|
||||
line += ' ' * (positions[i] - len(line))
|
||||
print(line)
|
||||
print_fn(line)
|
||||
|
||||
print('_' * line_length)
|
||||
print_fn('_' * line_length)
|
||||
print_row(to_display, positions)
|
||||
print('=' * line_length)
|
||||
print_fn('=' * line_length)
|
||||
|
||||
def print_layer_summary(layer):
|
||||
try:
|
||||
@@ -108,19 +114,19 @@ def print_summary(model, line_length=None, positions=None):
|
||||
else:
|
||||
print_layer_summary_with_connections(layers[i])
|
||||
if i == len(layers) - 1:
|
||||
print('=' * line_length)
|
||||
print_fn('=' * line_length)
|
||||
else:
|
||||
print('_' * line_length)
|
||||
print_fn('_' * line_length)
|
||||
|
||||
trainable_count = int(
|
||||
np.sum([K.count_params(p) for p in set(model.trainable_weights)]))
|
||||
non_trainable_count = int(
|
||||
np.sum([K.count_params(p) for p in set(model.non_trainable_weights)]))
|
||||
|
||||
print('Total params: {:,}'.format(trainable_count + non_trainable_count))
|
||||
print('Trainable params: {:,}'.format(trainable_count))
|
||||
print('Non-trainable params: {:,}'.format(non_trainable_count))
|
||||
print('_' * line_length)
|
||||
print_fn('Total params: {:,}'.format(trainable_count + non_trainable_count))
|
||||
print_fn('Trainable params: {:,}'.format(trainable_count))
|
||||
print_fn('Non-trainable params: {:,}'.format(non_trainable_count))
|
||||
print_fn('_' * line_length)
|
||||
|
||||
|
||||
def convert_all_kernels_in_model(model):
|
||||
|
||||
Referência em uma Nova Issue
Bloquear um usuário