Allow arbitrary channel dimensions in ImageDataGenerator

Esse commit está contido em:
Francois Chollet
2017-06-06 16:29:08 -07:00
commit 552978dc58
2 arquivos alterados com 3 adições e 14 exclusões
+2 -2
Ver Arquivo
@@ -636,8 +636,8 @@ class ImageDataGenerator(object):
if x.ndim != 4:
raise ValueError('Input to `.fit()` should have rank 4. '
'Got array with shape: ' + str(x.shape))
if x.shape[self.channel_axis] not in {1, 3, 4}:
raise ValueError(
if x.shape[self.channel_axis] not in {3, 4}:
warnings.warn(
'Expected input to be images (as Numpy array) '
'following the data format convention "' + self.data_format + '" '
'(channels on axis ' + str(self.channel_axis) + '), i.e. expected '
+1 -12
Ver Arquivo
@@ -73,22 +73,11 @@ class TestImage:
with pytest.raises(ValueError):
x = np.random.random((3, 10, 10))
generator.fit(x)
with pytest.raises(ValueError):
x = np.random.random((32, 3, 10, 10))
generator.fit(x)
with pytest.raises(ValueError):
x = np.random.random((32, 10, 10, 5))
generator.fit(x)
# Test flow with invalid data
with pytest.raises(ValueError):
x = np.random.random((32, 10, 10, 5))
generator.flow(np.arange(x.shape[0]))
with pytest.raises(ValueError):
x = np.random.random((32, 10, 10))
generator.flow(np.arange(x.shape[0]))
with pytest.raises(ValueError):
x = np.random.random((32, 3, 10, 10))
generator.flow(np.arange(x.shape[0]))
def test_image_data_generator_fit(self):
generator = image.ImageDataGenerator(