Allow arbitrary channel dimensions in ImageDataGenerator
Esse commit está contido em:
@@ -636,8 +636,8 @@ class ImageDataGenerator(object):
|
||||
if x.ndim != 4:
|
||||
raise ValueError('Input to `.fit()` should have rank 4. '
|
||||
'Got array with shape: ' + str(x.shape))
|
||||
if x.shape[self.channel_axis] not in {1, 3, 4}:
|
||||
raise ValueError(
|
||||
if x.shape[self.channel_axis] not in {3, 4}:
|
||||
warnings.warn(
|
||||
'Expected input to be images (as Numpy array) '
|
||||
'following the data format convention "' + self.data_format + '" '
|
||||
'(channels on axis ' + str(self.channel_axis) + '), i.e. expected '
|
||||
|
||||
@@ -73,22 +73,11 @@ class TestImage:
|
||||
with pytest.raises(ValueError):
|
||||
x = np.random.random((3, 10, 10))
|
||||
generator.fit(x)
|
||||
with pytest.raises(ValueError):
|
||||
x = np.random.random((32, 3, 10, 10))
|
||||
generator.fit(x)
|
||||
with pytest.raises(ValueError):
|
||||
x = np.random.random((32, 10, 10, 5))
|
||||
generator.fit(x)
|
||||
|
||||
# Test flow with invalid data
|
||||
with pytest.raises(ValueError):
|
||||
x = np.random.random((32, 10, 10, 5))
|
||||
generator.flow(np.arange(x.shape[0]))
|
||||
with pytest.raises(ValueError):
|
||||
x = np.random.random((32, 10, 10))
|
||||
generator.flow(np.arange(x.shape[0]))
|
||||
with pytest.raises(ValueError):
|
||||
x = np.random.random((32, 3, 10, 10))
|
||||
generator.flow(np.arange(x.shape[0]))
|
||||
|
||||
def test_image_data_generator_fit(self):
|
||||
generator = image.ImageDataGenerator(
|
||||
|
||||
Referência em uma Nova Issue
Bloquear um usuário