feat: Simplify tiny YOLO training
Added some convenience features that make training the tiny model easier.
Esse commit está contido em:
@@ -8,8 +8,8 @@ import argparse
|
||||
|
||||
|
||||
def get_parent_dir(n=1):
|
||||
""" returns the n-th parent dicrectory of the current
|
||||
working directory """
|
||||
"""returns the n-th parent dicrectory of the current
|
||||
working directory"""
|
||||
current_path = os.path.dirname(os.path.abspath(__file__))
|
||||
for k in range(n):
|
||||
current_path = os.path.dirname(current_path)
|
||||
|
||||
@@ -27,29 +27,46 @@ if __name__ == "__main__":
|
||||
help="Folder to download weights to. Default is " + download_folder,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--is_tiny",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Use the tiny Yolo version for better performance and less accuracy. Default is False.",
|
||||
)
|
||||
|
||||
FLAGS = parser.parse_args()
|
||||
|
||||
if not os.path.isfile(os.path.join(download_folder, "yolov3.weights")):
|
||||
print("\n", "Downloading Raw YOLOv3 Weights", "\n")
|
||||
if not FLAGS.is_tiny:
|
||||
weights_file = "yolov3.weights"
|
||||
h5_file = "yolo.h5"
|
||||
cfg_file = "yolov3.cfg"
|
||||
# Original URL: https://pjreddie.com/media/files/yolov3.weights
|
||||
gdrive_id = "1ENKguLZbkgvM8unU3Hq1BoFzoLeGWvE_"
|
||||
|
||||
else:
|
||||
weights_file = "yolov3-tiny.weights"
|
||||
h5_file = "yolo-tiny.h5"
|
||||
cfg_file = "yolov3-tiny.cfg"
|
||||
# Original URL: https://pjreddie.com/media/files/yolov3-tiny.weights
|
||||
gdrive_id = "1mIEZthXBcEguMvuVAHKLXQX3mA1oZUuC"
|
||||
|
||||
if not os.path.isfile(os.path.join(download_folder, weights_file)):
|
||||
print(f"\nDownloading Raw {weights_file}\n")
|
||||
start = time.time()
|
||||
call_string = " ".join(
|
||||
[
|
||||
"python",
|
||||
download_script,
|
||||
"1ENKguLZbkgvM8unU3Hq1BoFzoLeGWvE_",
|
||||
os.path.join(download_folder, "yolov3.weights"),
|
||||
gdrive_id,
|
||||
os.path.join(download_folder, weights_file),
|
||||
]
|
||||
)
|
||||
|
||||
subprocess.call(call_string, shell=True)
|
||||
|
||||
end = time.time()
|
||||
print(
|
||||
"Downloaded Raw YOLOv3 Weights in {0:.1f} seconds".format(end - start), "\n"
|
||||
)
|
||||
print(f"Downloaded Raw {weights_file} in {end - start:.1f} seconds\n")
|
||||
|
||||
# Original URL: https://pjreddie.com/media/files/yolov3.weights
|
||||
|
||||
call_string = "python convert.py yolov3.cfg yolov3.weights yolo.h5"
|
||||
call_string = f"python convert.py {cfg_file} {weights_file} {h5_file}"
|
||||
|
||||
subprocess.call(call_string, shell=True, cwd=download_folder)
|
||||
|
||||
@@ -158,8 +158,15 @@ if __name__ == "__main__":
|
||||
|
||||
class_names = get_classes(FLAGS.classes_file)
|
||||
num_classes = len(class_names)
|
||||
anchors = get_anchors(FLAGS.anchors_path)
|
||||
weights_path = FLAGS.weights_path
|
||||
|
||||
if FLAGS.is_tiny and FLAGS.weights_path == weights_path:
|
||||
weights_path = os.path.join(os.path.dirname(FLAGS.weights_path), "yolo-tiny.h5")
|
||||
if FLAGS.is_tiny and FLAGS.anchors_path == anchors_path:
|
||||
anchors_path = os.path.join(
|
||||
os.path.dirname(FLAGS.anchors_path), "yolo-tiny_anchors.txt"
|
||||
)
|
||||
|
||||
anchors = get_anchors(anchors_path)
|
||||
|
||||
input_shape = (416, 416) # multiple of 32, height, width
|
||||
epoch1, epoch2 = FLAGS.epochs, FLAGS.epochs
|
||||
|
||||
@@ -272,7 +272,7 @@ def _main(args):
|
||||
remaining_weights = len(weights_file.read()) / 4
|
||||
weights_file.close()
|
||||
print(
|
||||
"Read {} of {} from Darknet weights.".format(count, count + remaining_weights)
|
||||
f"Read {count:0.0f} of {count + remaining_weights:0.0f} from Darknet weights."
|
||||
)
|
||||
if remaining_weights > 0:
|
||||
print("Warning: {} unused weights".format(remaining_weights))
|
||||
|
||||
@@ -21,6 +21,7 @@ import tensorflow.python.keras.backend as K
|
||||
|
||||
tf.disable_eager_execution()
|
||||
|
||||
|
||||
class YOLO(object):
|
||||
_defaults = {
|
||||
"model_path": "model_data/yolo.h5",
|
||||
|
||||
@@ -3,8 +3,8 @@ import sys
|
||||
|
||||
|
||||
def get_parent_dir(n=1):
|
||||
""" returns the n-th parent dicrectory of the current
|
||||
working directory """
|
||||
"""returns the n-th parent dicrectory of the current
|
||||
working directory"""
|
||||
current_path = os.path.dirname(os.path.abspath(__file__))
|
||||
for k in range(n):
|
||||
current_path = os.path.dirname(current_path)
|
||||
@@ -28,6 +28,7 @@ import pandas as pd
|
||||
import numpy as np
|
||||
from Get_File_Paths import GetFileList
|
||||
import random
|
||||
from Train_Utils import get_anchors
|
||||
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||
|
||||
@@ -141,6 +142,13 @@ if __name__ == "__main__":
|
||||
help='Specify the postfix for images with bounding boxes. Default is "_catface"',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--is_tiny",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Use the tiny Yolo version for better performance and less accuracy. Default is False.",
|
||||
)
|
||||
|
||||
FLAGS = parser.parse_args()
|
||||
|
||||
save_img = not FLAGS.no_save_img
|
||||
@@ -168,11 +176,17 @@ if __name__ == "__main__":
|
||||
if not os.path.exists(output_path):
|
||||
os.makedirs(output_path)
|
||||
|
||||
if FLAGS.is_tiny and FLAGS.anchors_path == anchors_path:
|
||||
anchors_path = os.path.join(
|
||||
os.path.dirname(FLAGS.anchors_path), "yolo-tiny_anchors.txt"
|
||||
)
|
||||
|
||||
anchors = get_anchors(anchors_path)
|
||||
# define YOLO detector
|
||||
yolo = YOLO(
|
||||
**{
|
||||
"model_path": FLAGS.model_path,
|
||||
"anchors_path": FLAGS.anchors_path,
|
||||
"anchors_path": anchors_path,
|
||||
"classes_path": FLAGS.classes_path,
|
||||
"score": FLAGS.score,
|
||||
"gpu_num": FLAGS.gpu_num,
|
||||
|
||||
@@ -9,7 +9,14 @@ from Get_File_Paths import GetFileList, ChangeToOtherMachine
|
||||
|
||||
def convert_vott_csv_to_yolo(
|
||||
vott_df,
|
||||
labeldict=dict(zip(["Cat_Face"], [0,])),
|
||||
labeldict=dict(
|
||||
zip(
|
||||
["Cat_Face"],
|
||||
[
|
||||
0,
|
||||
],
|
||||
)
|
||||
),
|
||||
path="",
|
||||
target_name="data_train.txt",
|
||||
abs_path=False,
|
||||
@@ -106,16 +113,16 @@ def crop_and_save(
|
||||
):
|
||||
"""Takes a vott_csv file with image names, labels and crop_boxes
|
||||
and crops the images accordingly
|
||||
|
||||
|
||||
Input csv file format:
|
||||
|
||||
|
||||
image xmin ymin xmax ymax label
|
||||
im.jpg 0 10 100 500 house
|
||||
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
df : pd.Dataframe
|
||||
df : pd.Dataframe
|
||||
The input dataframe with file_names, bounding box info
|
||||
and label
|
||||
source_path : str
|
||||
@@ -175,7 +182,14 @@ def crop_and_save(
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Prepare the houses dataset for YOLO
|
||||
labeldict = dict(zip(["house"], [0,]))
|
||||
labeldict = dict(
|
||||
zip(
|
||||
["house"],
|
||||
[
|
||||
0,
|
||||
],
|
||||
)
|
||||
)
|
||||
multi_df = pd.read_csv(
|
||||
"C:/Users/Anton/Documents/Insight/eq/EQ_new/Train_Housing_detector/2/vott-csv-export/Housing_cropping-export.csv"
|
||||
)
|
||||
|
||||
@@ -40,7 +40,7 @@ def ChangeToOtherMachine(filelist, repo="TrainYourOwnYOLO", remote_machine=""):
|
||||
'/home/ubuntu/TrainYourOwnYOLO/Data/Street_View_Images/vulnerable/test.jpg'
|
||||
|
||||
Get's converted to
|
||||
|
||||
|
||||
'C:/Users/Anton/TrainYourOwnYOLO/Data/Street_View_Images/vulnerable/test.jpg'
|
||||
|
||||
"""
|
||||
|
||||
@@ -9,8 +9,8 @@ import sys
|
||||
|
||||
|
||||
def get_parent_dir(n=1):
|
||||
""" returns the n-th parent dicrectory of the current
|
||||
working directory """
|
||||
"""returns the n-th parent dicrectory of the current
|
||||
working directory"""
|
||||
current_path = os.getcwd()
|
||||
for k in range(n):
|
||||
current_path = os.path.dirname(current_path)
|
||||
@@ -221,7 +221,7 @@ def ChangeToOtherMachine(filelist, repo="TrainYourOwnYOLO", remote_machine=""):
|
||||
'/home/ubuntu/TrainYourOwnYOLO/Data/Street_View_Images/vulnerable/test.jpg'
|
||||
|
||||
Get's converted to
|
||||
|
||||
|
||||
'C:/Users/Anton/TrainYourOwnYOLO/Data/Street_View_Images/vulnerable/test.jpg'
|
||||
|
||||
"""
|
||||
|
||||
+1
-2
@@ -78,8 +78,7 @@ def load_extractor_model(model_name="InceptionV3", flavor=1):
|
||||
flavor: int specifying the model variant and input_shape.
|
||||
For InceptionV3, the map is {0: default, 1: 200*200, truncate last Inception block,
|
||||
2: 200*200, truncate last 2 blocks, 3: 200*200, truncate last 3 blocks, 4: 200*200}
|
||||
For VGG16, it only changes the input size, {0: 224 (default), 1: 128, 2: 64}.
|
||||
"""
|
||||
For VGG16, it only changes the input size, {0: 224 (default), 1: 128, 2: 64}."""
|
||||
start = timer()
|
||||
if model_name == "InceptionV3":
|
||||
from keras.applications.inception_v3 import InceptionV3
|
||||
|
||||
Referência em uma Nova Issue
Bloquear um usuário