Esse commit está contido em:
Charlie Hewitt
2017-12-19 18:17:10 +00:00
2 arquivos alterados com 18 adições e 15 exclusões
+1 -1
Ver Arquivo
@@ -1,2 +1,2 @@
__pycache__
.DS_Store
+17 -14
Ver Arquivo
@@ -16,7 +16,7 @@ CLASSIFY = 0
REGRESS = 1
# OTPIONS #
BATCH_SIZE = 256 # TODO whatever GPU can handle
BATCH_SIZE = 400
EPOCHS = 24
def load_images(C_or_R, paths, labels, batch_size=32, eval=False):
@@ -130,22 +130,22 @@ def base_model(C_or_R):
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
# CONV BLOCK 4
# model.add(Conv2D(256, (3, 3)))
# model.add(BatchNormalization(axis=-1))
# model.add(Activation('relu'))
# model.add(Conv2D(256, (3, 3)))
# model.add(BatchNormalization(axis=-1))
# model.add(Activation('relu'))
# model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Conv2D(256, (3, 3)))
model.add(BatchNormalization(axis=-1))
model.add(Activation('relu'))
model.add(Conv2D(256, (3, 3)))
model.add(BatchNormalization(axis=-1))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2,2)))
# FLATTEN
model.add(Flatten())
# DENSE 1
model.add(Dense(256))
model.add(Dense(1024))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(Dropout(0.5))
# DENSE 2
model.add(Dense(256))
model.add(Dense(1024))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(Dropout(0.5))
@@ -243,7 +243,7 @@ def train(C_or_R, mode=0):
steps_per_epoch=len(t_labels) // BATCH_SIZE,
class_weight=t_weights,
epochs=EPOCHS,
validation_data=load_images(C_or_R, v_paths, v_labels, BATCH_SIZE),
validation_data=load_images(C_or_R, v_paths, v_labels, BATCH_SIZE, eval=True),
validation_steps=len(v_labels) // BATCH_SIZE,
callbacks=[ModelCheckpoint('AFF_NET_C_WIP.h5', monitor='val_acc', save_best_only=True)])
else:
@@ -252,7 +252,7 @@ def train(C_or_R, mode=0):
load_images(C_or_R, t_paths, t_labels, BATCH_SIZE, eval=True),
steps_per_epoch=len(t_labels) // BATCH_SIZE,
epochs=EPOCHS,
validation_data=load_images(C_or_R, v_paths, v_labels, BATCH_SIZE),
validation_data=load_images(C_or_R, v_paths, v_labels, BATCH_SIZE, eval=True),
validation_steps=len(v_labels) // BATCH_SIZE,
callbacks=[ModelCheckpoint('AFF_NET_R_WIP.h5', save_best_only=True)])
@@ -267,7 +267,10 @@ def train(C_or_R, mode=0):
model.save('AFF_NET_' + ns + '_FULL.h5')
# eval(CLASSIFY)
# train(CLASSIFY, mode=0)
#
# train(REGRESS, mode=1)
# load_and_save()
# eval(CLASSIFY)
# eval(CLASSIFY)
if __name__ == '__main__':
train(CLASSIFY, mode=0)