# File info (from source viewer): 697 lines, 31 KiB, Python
from ast import literal_eval
import os.path

import numpy as np
import pandas as pd
import torch
import torch.utils.data as data
from PIL import Image
from tqdm import tqdm

from ..detection.sign_detection import *
# from utils.cython_bbox import bbox_overlaps
from ..utils.bbox_utils import clip_boxes
from ..utils.transform_utils import convert2binaryPIL
from ..utils.torchcv.transforms.crop_box import crop_box
from ..utils.torchcv.transforms.resize import resize
from ..utils.torchcv.transforms_lm.crop_box import crop_box_lm
from ..utils.torchcv.transforms_lm.resize import resize_lm

from .lines_dataset import collect_line_coords, create_line_trafo

from ..detection.line_detection import compute_image_label_map
|
|
|
|
|
|
# helper functions
|
|
|
|
|
|
def convert_bbox_global2local(gbbox, seg_bbox):
    """Translate a box from tablet (global) into segment (local) coordinates.

    Args:
        gbbox: box ``[x1, y1, x2, y2]`` in tablet image coordinates.
        seg_bbox: segment box whose first two entries are its upper-left corner ``(x, y)``.

    Returns:
        list: ``gbbox`` shifted by the segment origin.
    """
    origin_x, origin_y = seg_bbox[:2]
    # the same origin offset applies to both corners of the box
    offset = np.array([origin_x, origin_y] * 2)
    return (np.array(gbbox) - offset).tolist()
|
|
|
|
|
|
def get_segment_meta(segment_rec):
    """Unpack the metadata of a tablet segment record.

    Args:
        segment_rec: a segments DataFrame row (any object exposing the attributes
            ``tablet_CDLI``, ``scale``, ``bbox``, ``im_path`` and ``view_desc``).

    Returns:
        tuple: ``(image_name, scale, seg_bbox, path_to_image, view_desc)``. A missing
        ``view_desc`` (NaN, which formats as ``"nan"``) is returned as ``""``.
    """
    image_name = segment_rec.tablet_CDLI
    # this should control which scale is used in consecutive processing
    scale = segment_rec.scale
    seg_bbox = segment_rec.bbox
    path_to_image = segment_rec.im_path
    # Compare the whole string: a substring replace of "nan" would also corrupt
    # descriptors that merely contain "nan".
    view_desc = "{}".format(segment_rec.view_desc)
    if view_desc == "nan":
        view_desc = ""
    return image_name, scale, seg_bbox, path_to_image, view_desc
|
|
|
|
|
|
def compute_tiles(imw, imh, scale, tile_shape=(600, 600), border_sz=100, w_step_sz=300, h_step_sz=400):
    """Compute a regular grid of tile boxes over a segment image.

    All sizes are given at working scale and converted to segment-image pixels
    by dividing by ``scale``.

    Args:
        imw: width of the (unscaled) segment image in pixels.
        imh: height of the (unscaled) segment image in pixels.
        scale: scale later applied to the tiles; tile geometry is divided by it.
        tile_shape: tile ``(width, height)`` at working scale.
        border_sz: margin at the image border where no tile starts.
        w_step_sz: horizontal distance between neighbouring tile origins.
        h_step_sz: vertical distance between neighbouring tile origins.

    Returns:
        bboxes: ``[x1, y1, x2, y2]`` tile boxes (one row per tile), clipped via ``clip_boxes``.
        XX, YY: meshgrids of the tile upper-left corners.
    """
    # TODO: improve using linspace and allow overlap to vary
    # signs height should be around 130px, however, length can be up to 300px
    # -> overlap along lines (300px) should be larger than between lines (200px)
    # -> this means for step sizes: w_step_sz < h_step_sz
    inv_scale = 1. / scale
    tile_shape = np.array(tile_shape) * inv_scale
    border_sz *= inv_scale
    w_step_sz *= inv_scale
    h_step_sz *= inv_scale

    # overlap between neighbouring tiles; tile_shape is (width, height), so the
    # vertical overlap must use the tile height (index 1), not the width
    tile_ol_w = tile_shape[0] - w_step_sz
    tile_ol_h = tile_shape[1] - h_step_sz
    w_list = np.arange(border_sz, imw - border_sz - tile_ol_w, step=w_step_sz)
    h_list = np.arange(border_sz, imh - border_sz - tile_ol_h, step=h_step_sz)

    # grid pts represent upper left corner of tile box
    XX, YY = np.meshgrid(w_list, h_list)

    # compute bboxes; tiles may extend beyond the image and are clipped below
    ul_corner = np.rint(np.stack([XX.ravel(), YY.ravel()], axis=1)).astype(int)
    lr_corner = ul_corner + np.rint(tile_shape)
    bboxes = np.hstack([ul_corner, lr_corner])
    # make sure tiles inside image boundaries
    bboxes = clip_boxes(bboxes, [imh, imw])  # [imh, imw] is correct order for this function

    return bboxes, XX, YY
|
|
|
|
|
|
def bbox_ctr_overlaps(boxes1, boxes2):
    """Mark which centers of ``boxes2`` lie inside each box of ``boxes1``.

    Boxes are ``[x1, y1, x2, y2]``; points on the border count as inside.
    If the center of a sign box is not inside a tile box, it is not possible to
    get IoU >= 0.5, so such a box is treated as background anyway.

    Args:
        boxes1: array-like of shape ``(N, 4)`` (e.g. tile boxes).
        boxes2: array-like of shape ``(M, 4)`` (e.g. sign boxes).

    Returns:
        np.ndarray: float matrix of shape ``(N, M)``; element ``[i, j]`` is 1.0 if the
        center of ``boxes2[j]`` lies inside ``boxes1[i]``, else 0.0.
    """
    boxes1 = np.asarray(boxes1, dtype=float)
    boxes2 = np.asarray(boxes2, dtype=float)

    # centers only depend on boxes2: compute them once, not per box of boxes1
    centers = (boxes2[:, :2] + boxes2[:, 2:]) / 2
    cx = centers[np.newaxis, :, 0]  # (1, M)
    cy = centers[np.newaxis, :, 1]  # (1, M)
    x1, y1, x2, y2 = (boxes1[:, k, np.newaxis] for k in range(4))  # each (N, 1)

    # broadcasting (N, 1) against (1, M) yields the full (N, M) matrix
    inside = (cx >= x1) & (cx <= x2) & (cy >= y1) & (cy <= y2)
    return inside.astype(float)
|
|
|
|
|
|
# Cuneiform SSD dataset
|
|
|
|
|
|
class CuneiformSSD(data.Dataset):
    """Tiled cuneiform sign dataset for SSD-style detector training.

    Merges manual (ground truth) and generated sign annotations of several
    collections into a single dataset. Every annotated tablet segment is
    preloaded as a gray-scale image and split into tiles; one dataset item is
    one tile:

        f(idx) -> image, bboxes, labels
        f(idx) -> image, bboxes, labels, line_map    (use_linemaps=True)

    Three levels of supervision for sign annotations are joined:
      1) tablets completely annotated (no need for line annotations or detections)
      2) tablets partly annotated and line annotations available
      3) tablets partly annotated and line detections required
    """

    def __init__(self, collections=None, gen_file_path=None, gen_collections=None, gen_folder=None, transform=None,
                 relative_path='../', use_balanced_idx=True, tile_shape=(600, 600), use_linemaps=False,
                 remove_empty_tiles=False, min_align_ratio=0.6, filter_nms=False, compl_thresh=-1, ncompl_thresh=-1,
                 num_top_ncompl=0, min_ncompl_thresh=10):
        """Load annotations, segment images and (optionally) line maps and build the tile index.

        Args:
            collections: collections whose manual annotation files
                ``data/annotations/bbox_annotations_<collection>.csv`` are used (default ``['train']``).
            gen_file_path: csv file with generated boxes, appended to ``<relative_path>results``.
            gen_collections: collections whose generated boxes are loaded from ``gen_folder``
                (default ``[]``); ignored when ``gen_file_path`` is given.
            gen_folder: sub-folder of ``results/`` holding the per-collection generated boxes;
                required when ``gen_collections`` is used.
            transform: optional callable applied to each cropped and rescaled tile.
            relative_path: prefix of all ``data/`` and ``results/`` paths.
            use_balanced_idx: sample each tile once per sign center it contains
                (tiles without sign centers are never sampled).
            tile_shape: tile ``(width, height)`` at working scale.
            use_linemaps: load line detections and additionally return a line map per tile.
            remove_empty_tiles: drop tiles without sign centers (only when not balanced).
            min_align_ratio: keep generated boxes with ``align_ratio`` above this value.
            filter_nms: keep only generated boxes whose ``nms_keep`` flag is set.
            compl_thresh: if > -1, keep generated boxes with ``compl`` above this value.
            ncompl_thresh: if > -1 and ``compl_thresh`` is unused, keep generated boxes with
                ``ncompl`` above this value plus the per-class top ``num_top_ncompl`` boxes.
            num_top_ncompl: number of boxes per class always kept by ``ncompl`` ranking.
            min_ncompl_thresh: minimum ``ncompl`` for a box to enter the per-class ranking.
        """
        # fresh lists per instance (no shared mutable defaults)
        if collections is None:
            collections = ['train']
        if gen_collections is None:
            gen_collections = []

        # transforms for data preparation
        self.transform = transform
        self.use_linemaps = use_linemaps
        self.min_align_ratio = min_align_ratio
        self.filter_nms = filter_nms
        self.compl_thresh = compl_thresh
        self.ncompl_thresh = ncompl_thresh
        self.num_top_ncompl = num_top_ncompl
        self.min_ncompl_thresh = min_ncompl_thresh

        # version of the line detector whose results are used for line maps
        line_model_version = 'v007'
        self.line_model_version = line_model_version
        # classes [0, num_classes) are inliers
        num_classes = 240

        ###################
        # I) load generated and manual annotations
        gen_df = self._load_generated_annotations(gen_file_path, gen_collections, gen_folder,
                                                  relative_path, num_classes)
        list_gen_collection = [] if gen_df is None else gen_df.collection.unique().tolist()

        list_anno_collections, sign_anno_df_list = self._load_sign_annotations(collections, relative_path,
                                                                               num_classes)

        # sorted list of elementary collections
        collections_ext = np.unique(list_gen_collection + list_anno_collections).tolist()

        ###################
        # II) on collection level: load segments meta data and line annotations (optional)
        segments_df_list, line_anno_df_list = self._load_segments(collections_ext, relative_path)

        # keep only segments that carry (manual or generated) annotations
        list_segments_df_anno = []
        for coll_idx, collection in enumerate(collections_ext):
            list_segm_indices = []
            if collection in list_anno_collections:
                anno_df = sign_anno_df_list[list_anno_collections.index(collection)]
                segm_indices_anno = anno_df.segm_idx.unique()
                # filter annotations without assigned segment
                list_segm_indices.append(segm_indices_anno[segm_indices_anno >= 0])
            if collection in list_gen_collection:
                list_segm_indices.append(gen_df.loc[gen_df.collection == collection, 'segm_idx'].unique())

            segm_indices = np.unique(np.hstack(list_segm_indices))
            if len(segm_indices) > 0:
                list_segments_df_anno.append(segments_df_list[coll_idx].loc[segm_indices])

        if not list_segments_df_anno:
            raise ValueError('No annotated segments found for collections {}'.format(collections_ext))
        # new dataset index; the old (per collection) segment index becomes column 'segm_idx'
        ssd_segments_df = pd.concat(list_segments_df_anno).rename_axis('segm_idx').reset_index()

        ###################
        # III) on segment level: collect sign annotations, tag them with their global segment index
        sign_anno_df_cols = ['tablet_CDLI', 'mzl_label', 'train_label', 'segm_idx', 'collection',
                             'generated', 'relative_bbox', 'global_segm_idx']
        list_ssd_sign_anno_df = []
        num_segments = len(ssd_segments_df)
        list_lines_annotated_per_segm = np.zeros(num_segments, dtype=bool)
        list_num_anno_per_segm = np.zeros(num_segments, dtype=int)

        for global_seg_idx, seg_rec in ssd_segments_df.iterrows():
            _, _, seg_bbox, _, _ = get_segment_meta(seg_rec)
            collection = seg_rec.collection
            segm_idx = seg_rec.segm_idx

            # manual annotations of this segment (already in segment coordinates)
            if collection in list_anno_collections:
                sign_anno_df = sign_anno_df_list[list_anno_collections.index(collection)]
                segm_select = sign_anno_df.segm_idx == segm_idx
                if segm_select.any():
                    sign_anno_df.loc[segm_select, 'global_segm_idx'] = global_seg_idx
                    sign_anno_seg = sign_anno_df[segm_select]
                    list_num_anno_per_segm[global_seg_idx] += len(sign_anno_seg)
                    list_ssd_sign_anno_df.append(sign_anno_seg[sign_anno_df_cols])

            # generated annotations of this segment (tablet coordinates -> segment coordinates)
            if collection in list_gen_collection:
                segm_select = (gen_df.segm_idx == segm_idx) & (gen_df.collection == collection)
                if segm_select.any():
                    gen_df.loc[segm_select, 'global_segm_idx'] = global_seg_idx
                    relative_boxes = gen_df.loc[segm_select, 'relative_bbox'].apply(
                        lambda box: np.rint(convert_bbox_global2local(box, list(seg_bbox))).astype(int))
                    gen_df.loc[segm_select, 'relative_bbox'] = relative_boxes
                    sign_anno_seg = gen_df[segm_select]
                    # += : a segment may carry manual AND generated annotations
                    list_num_anno_per_segm[global_seg_idx] += len(sign_anno_seg)
                    list_ssd_sign_anno_df.append(sign_anno_seg[sign_anno_df_cols])

            # remember whether line annotations exist for the segment
            line_anno_df = line_anno_df_list[collections_ext.index(collection)]
            if len(line_anno_df) > 0 and (line_anno_df.segm_idx == segm_idx).any():
                list_lines_annotated_per_segm[global_seg_idx] = True

        ssd_segments_df['num_anno'] = list_num_anno_per_segm
        ssd_segments_df['with_line_anno'] = list_lines_annotated_per_segm

        # assemble ssd_sign_anno_df (drop old index)
        ssd_sign_anno_df = pd.concat(list_ssd_sign_anno_df, ignore_index=True)

        ###################
        # IV) Preload segment crops as gray-scale images (without scaling, because memory)
        image_data_list = []
        for _, seg_rec in tqdm(ssd_segments_df.iterrows(), total=num_segments):
            _, _, seg_bbox, path_to_image, _ = get_segment_meta(seg_rec)
            # the context manager closes the file handle kept open by Image.open
            with Image.open(path_to_image) as pil_im:
                tablet_seg, _ = crop_segment_from_tablet_im(pil_im, seg_bbox)
                image_data_list.append(tablet_seg.convert('L'))

        # preload line maps (speed/memory trade-off: built once, resized per item)
        line_map_dict = {}
        if self.use_linemaps:
            line_detection_dict = {}
            # line maps are ALWAYS built from line detections, even where line annotations exist
            use_line_detections = True
            for global_seg_idx, seg_rec in tqdm(ssd_segments_df.iterrows(), total=num_segments):
                image_name, scale, _, _, view_desc = get_segment_meta(seg_rec)
                res_name = "{}{}".format(image_name, view_desc)
                # (height, width) of the preloaded segment image
                input_shape = np.array(image_data_list[global_seg_idx].size[::-1])

                if use_line_detections:
                    line_res_path = "{}results/results_line/{}/{}".format(relative_path, line_model_version,
                                                                          seg_rec.collection)
                    lines_file = "{}/{}_lbl_ind.npy".format(line_res_path, res_name)
                    line_detection_dict[global_seg_idx] = np.load(lines_file).astype(int)
                    lbl_im = create_line_map_from_line_det(line_detection_dict, global_seg_idx, scale, input_shape)
                else:
                    coll_idx = collections_ext.index(seg_rec.collection)
                    lbl_im = create_line_map_from_line_anno(line_anno_df_list, coll_idx, seg_rec.segm_idx,
                                                            input_shape)
                line_map_dict[global_seg_idx] = lbl_im

        ###################
        # V) Tiling: split segments into tiles and count the sign centers inside each tile
        list_tile_boxes = []
        list_tile_support = []
        list_tile_seg_idx = []
        for global_seg_idx, seg_rec in tqdm(ssd_segments_df.iterrows(), total=num_segments):
            _, scale, _, _, _ = get_segment_meta(seg_rec)
            imw, imh = image_data_list[global_seg_idx].size
            tile_boxes, _, _ = compute_tiles(imw, imh, scale, tile_shape=tile_shape)
            list_tile_boxes.append(tile_boxes)
            list_tile_seg_idx.append([global_seg_idx] * len(tile_boxes))

            # a sign supports a tile if its center lies inside the tile (strict, instead of IoU)
            seg_sign_annos = ssd_sign_anno_df[ssd_sign_anno_df.global_segm_idx == global_seg_idx]
            sign_bboxes = np.stack(seg_sign_annos.relative_bbox.values)
            tiles_sign_ctrs = bbox_ctr_overlaps(tile_boxes.astype(float), sign_bboxes.astype(float))
            list_tile_support.append(np.sum(tiles_sign_ctrs, axis=1).astype(int))

        tile_df = pd.DataFrame({'global_segm_idx': np.hstack(list_tile_seg_idx).astype(int),
                                'tile_bbox': np.vstack(list_tile_boxes).tolist(),
                                'num_anno': np.hstack(list_tile_support)})

        # OPTIONAL: filter tiles without support; reset the index so labels equal positions
        if remove_empty_tiles and not use_balanced_idx:
            tile_df = tile_df[tile_df.num_anno > 0].reset_index(drop=True)

        ###################
        # VI) Dataset index
        if use_balanced_idx:
            # each tile appears once per supporting sign -> tiles without support are never sampled
            sample2tile_list = np.repeat(np.arange(len(tile_df)), tile_df.num_anno.values).tolist()
        else:
            sample2tile_list = tile_df.index.values

        ###################
        # attach resulting data structures to class
        self.collections = collections
        self.collections_ext = collections_ext
        self.ssd_segments_df = ssd_segments_df
        self.ssd_sign_anno_df = ssd_sign_anno_df
        self.tile_df = tile_df
        self.image_data_list = image_data_list
        self.line_map_dict = line_map_dict
        self.line_anno_df_list = line_anno_df_list
        self.sample2tile_list = sample2tile_list

        # setup finished
        print("Setup dataset spanning {} collections with {} annotations [{} segments, {} tiles, {} indices]".format(
            len(collections_ext), len(ssd_sign_anno_df), len(ssd_segments_df), len(tile_df), len(sample2tile_list)))

    def _load_generated_annotations(self, gen_file_path, gen_collections, gen_folder, relative_path, num_classes):
        """Load and filter generated sign annotations.

        Returns:
            None if neither ``gen_file_path`` nor ``gen_collections`` is given. Otherwise a
            DataFrame filtered by align ratio, class range and the NMS/completeness settings
            stored on ``self``, extended with the columns shared with manual annotations
            (``relative_bbox`` still in tablet coordinates).
        """
        gen_cols = ['imageName', 'folder', 'image_path', 'label', 'train_label',
                    'x1', 'y1', 'x2', 'y2', 'width', 'height', 'segm_idx',
                    'line_idx', 'pos_idx', 'det_score', 'm_score', 'align_ratio', 'nms_keep', 'compl', 'ncompl']

        if gen_file_path:
            # OPT I : use csv file that contains list of generated boxes
            gen_df = pd.read_csv("{}results{}".format(relative_path, gen_file_path), engine='python',
                                 header=None, names=gen_cols)
        elif len(gen_collections) > 0:
            # OPT II : load collection specific csv files and concatenate
            assert gen_folder is not None, 'When using gen_collections, user needs to provide gen_folder!'
            df_list = []
            for gen_coll in gen_collections:
                coll_file_path = "{}results/{}line_generated_bboxes_refined80_{}.csv".format(
                    relative_path, gen_folder, gen_coll)
                # regex delimiter for legacy support: reads ',' as well as ', ' separated files
                df_list.append(pd.read_csv(coll_file_path, engine='python', delimiter=r',\s*',
                                           header=None, names=gen_cols))
            gen_df = pd.concat(df_list, ignore_index=True)
        else:
            return None

        # IMPORTANT: filter gen data according to align ratio
        num_before_filter = len(gen_df)
        gen_df = gen_df[gen_df.align_ratio > self.min_align_ratio]
        print('Align Ratio {} :: Removed {} samples. [{}]'.format(
            self.min_align_ratio, num_before_filter - len(gen_df), len(gen_df)))

        # only keep inlier classes (only required when using null hypos); copy -> safe column writes
        num_before_filter = len(gen_df)
        gen_df = gen_df[gen_df.train_label < num_classes].copy()
        print('Class Range {} :: Removed {} samples. [{}]'.format(
            num_classes, num_before_filter - len(gen_df), len(gen_df)))

        # IMPORTANT: fill nan values in a way that avoids filtering
        gen_df['nms_keep'] = gen_df['nms_keep'].fillna(1).astype(bool)
        gen_df['compl'] = gen_df['compl'].fillna(50)
        gen_df['ncompl'] = gen_df['ncompl'].fillna(100)

        num_before_filter = len(gen_df)
        if self.filter_nms:
            gen_df = gen_df[gen_df.nms_keep]
            print('NMS :: Removed {} samples. [{}]'.format(num_before_filter - len(gen_df), len(gen_df)))
            num_before_filter = len(gen_df)

        # per class, the boxes with the highest ncompl (above min_ncompl_thresh) are always kept
        select_topn = False
        if self.num_top_ncompl > 0:
            candidates = gen_df[gen_df.ncompl > self.min_ncompl_thresh]
            top_ncompl = candidates.groupby('train_label').ncompl.nlargest(self.num_top_ncompl)
            # last index level = row labels of gen_df (also works when there are no candidates)
            select_topn = gen_df.index.isin(top_ncompl.index.get_level_values(-1))

        if self.compl_thresh > -1:
            gen_df = gen_df[gen_df.compl > self.compl_thresh]
            print('Completeness {} :: Removed {} samples. [{}]'.format(
                self.compl_thresh, num_before_filter - len(gen_df), len(gen_df)))
        elif self.ncompl_thresh > -1:
            gen_df = gen_df[(gen_df.ncompl > self.ncompl_thresh) | select_topn]
            print('Completeness (norm.) {} :: Removed {} samples. [{}]'.format(
                self.ncompl_thresh, num_before_filter - len(gen_df), len(gen_df)))
        print('class sample count stats: ')
        print(gen_df.train_label.value_counts().describe())

        # add columns shared with manual annotations (copy: gen_df may be a filtered slice)
        gen_df = gen_df.copy()
        gen_df['collection'] = gen_df.folder.str.split('/').str[0]
        gen_df['generated'] = True
        gen_df['global_segm_idx'] = -1
        gen_df['relative_bbox'] = gen_df[['x1', 'y1', 'x2', 'y2']].values.tolist()
        gen_df['relative_bbox'] = gen_df['relative_bbox'].apply(np.array)
        gen_df['mzl_label'] = gen_df['label']
        gen_df['tablet_CDLI'] = gen_df['imageName']
        return gen_df

    @staticmethod
    def _load_sign_annotations(collections, relative_path, num_classes):
        """Load manual sign annotations and split them by the collection they belong to.

        Missing annotation files are skipped. Rows with ``train_label >= num_classes``
        are always dropped.

        Returns:
            (list_anno_collections, sign_anno_df_list): aligned lists of sub-collection
            names and their annotation DataFrames (independent copies).
        """
        list_anno_collections = []
        sign_anno_df_list = []
        for collection in collections:
            annotation_file = '{}data/annotations/bbox_annotations_{}.csv'.format(relative_path, collection)
            # ATTENTION: only use gt annotations if collection is provided in collections parameter
            if not os.path.exists(annotation_file):
                continue

            sign_anno_df = pd.read_csv(annotation_file, engine='python')
            sign_anno_df['generated'] = False
            sign_anno_df['global_segm_idx'] = -1
            # string of list -> ndarray
            sign_anno_df['relative_bbox'] = sign_anno_df['relative_bbox'].apply(literal_eval).apply(np.array)

            # only keep inlier classes; drop outliers whenever there are any
            inlier_select = sign_anno_df.train_label < num_classes
            num_outliers = int(np.sum(~inlier_select))
            if num_outliers > 0:
                print('Drop {} outlier class samples from {}!'.format(num_outliers, collection))
            sign_anno_df = sign_anno_df[inlier_select]

            # slice sign_anno_df if there are multiple different collections contained
            for sub_collection in sign_anno_df.collection.unique():
                list_anno_collections.append(sub_collection)
                sign_anno_df_list.append(sign_anno_df[sign_anno_df.collection == sub_collection].copy())
        return list_anno_collections, sign_anno_df_list

    @staticmethod
    def _load_segments(collections_ext, relative_path):
        """Load assigned tablet segments and optional line annotations per collection.

        Returns:
            (segments_df_list, line_anno_df_list): lists aligned with ``collections_ext``.
            Segment frames keep the csv index; a collection without line annotation
            file gets ``[]`` as placeholder.
        """
        segments_df_columns = ['tablet_CDLI', 'view_desc', 'bbox', 'collection', 'scale', 'im_path']
        segments_df_list = []
        line_anno_df_list = []
        for collection in collections_ext:
            # load segment metadata; string of list -> ndarray
            segments_file = '{}data/segments/tablet_segments_{}.csv'.format(relative_path, collection)
            tablet_segments_df = pd.read_csv(segments_file, engine='python', index_col=0)
            tablet_segments_df['bbox'] = tablet_segments_df['bbox'].apply(literal_eval).apply(np.array)
            tablet_segments_df['imageName'] = tablet_segments_df['tablet_CDLI'] + '.jpg'
            tablet_segments_df['im_path'] = ('{}data/images/'.format(relative_path) + tablet_segments_df['collection']
                                             + '/' + tablet_segments_df['imageName'])
            # assigned segments only (can be edited from outside without harm)
            assigned_segments_df = tablet_segments_df[tablet_segments_df.assigned == True]  # noqa: E712

            line_file = '{}data/annotations/line_annotations_{}.csv'.format(relative_path, collection)
            line_anno_df = pd.read_csv(line_file, engine='python') if os.path.exists(line_file) else []

            segments_df_list.append(assigned_segments_df[segments_df_columns])
            line_anno_df_list.append(line_anno_df)
        return segments_df_list, line_anno_df_list

    def __getitem__(self, index):
        """Return the tile for dataset ``index``, cropped from its segment and rescaled.

        Returns:
            ``(im, boxes, labels, linemap)`` if ``use_linemaps`` is set, otherwise
            ``(im, boxes, labels)``; passed through ``transform`` when one is given.
        """
        # dataset index -> tile
        tile_rec = self.tile_df.loc[self.sample2tile_list[index]]
        tile_bbox = tile_rec.tile_bbox

        # tile -> segment and its preloaded (unscaled) image
        global_seg_idx = tile_rec.global_segm_idx
        seg_rec = self.ssd_segments_df.loc[global_seg_idx]
        _, scale, _, _, _ = get_segment_meta(seg_rec)
        pil_im = self.image_data_list[global_seg_idx]

        # sign annotations of the segment as torch tensors
        segm_annos = self.ssd_sign_anno_df[self.ssd_sign_anno_df.global_segm_idx == global_seg_idx]
        seg_boxes = torch.from_numpy(np.stack(segm_annos.relative_bbox.values)).float()
        labels = torch.from_numpy(segm_annos.train_label.values)

        # tensor transforms adapted from kuangliu's code
        # https://github.com/kuangliu/torchcv/tree/master/torchcv/transforms
        if not self.use_linemaps:
            im, boxes, labels = crop_box(pil_im, seg_boxes, labels, tile_bbox)
            im, boxes = resize(im, boxes, None, scale=scale)
            if self.transform is not None:
                im, boxes, labels = self.transform(im, boxes, labels)
            return im, boxes, labels

        if segm_annos.generated.any():
            # incomplete annotations -> use line detections to avoid false negatives
            lbl_im = self.line_map_dict[global_seg_idx].resize(pil_im.size)
        else:
            # assume all ground truth signs are annotated -> dummy (empty) label map
            lbl_im = Image.new('1', pil_im.size, 0)

        im, boxes, labels, linemap = crop_box_lm(pil_im, seg_boxes, labels, lbl_im, tile_bbox)
        im, boxes, linemap = resize_lm(im, boxes, linemap, None, scale=scale)
        if self.transform is not None:
            im, boxes, labels, linemap = self.transform(im, boxes, labels, linemap)
        return im, boxes, labels, linemap

    def __len__(self):
        """Number of samples (tile draws) in the dataset index."""
        return len(self.sample2tile_list)
|
|
|
|
|
|
# helper functions
|
|
|
|
def create_line_map_from_line_anno(line_anno_df_list, coll_idx, segm_idx, input_shape):
    """Rasterize the annotated text lines of one segment into a binary line map.

    Args:
        line_anno_df_list: per-collection line annotation DataFrames.
        coll_idx: position of the segment's collection in ``line_anno_df_list``.
        segm_idx: value of the ``segm_idx`` column selecting the segment's lines.
        input_shape: numpy array with the segment image shape (CuneiformSSD passes
            (height, width)); it is divided by 16 before building the line transform.

    Returns:
        Binary PIL image produced by ``convert2binaryPIL``; no resizing happens here.
    """
    downscale = 16
    # threshold on the line transform; presumably a distance in downscaled pixels — confirm
    line_height = 3

    coll_lines_df = line_anno_df_list[coll_idx]
    seg_line_df = coll_lines_df[coll_lines_df.segm_idx == segm_idx]

    # line coordinates at 1/16 resolution (the line boxes are not needed)
    rr, cc, _ = collect_line_coords(seg_line_df, scale=1. / downscale)
    line_trafo = create_line_trafo(rr, cc, input_shape / downscale)
    return convert2binaryPIL(line_trafo < line_height)
|
|
|
|
|
|
def create_line_map_from_line_det(line_detection_dict, global_seg_idx, scale, input_shape):
    """Turn the stored line detection of one segment into a binary line map.

    Args:
        line_detection_dict: mapping from global segment index to line detection array.
        global_seg_idx: key of the segment in ``line_detection_dict``.
        scale: factor applied to ``input_shape`` to obtain the map size.
        input_shape: numpy array with the segment image shape (CuneiformSSD passes
            (height, width)).

    Returns:
        Binary PIL image produced by ``convert2binaryPIL``.
    """
    detection = line_detection_dict[global_seg_idx]
    # target size, truncated to integers
    target_shape = np.array(input_shape * scale, dtype=int)
    # padding 5 instead of the default 16 (other values tried: 16, 20, 24)
    label_map = compute_image_label_map(detection, target_shape, padding=5)
    return convert2binaryPIL(label_map)
|
|
|
|
|
|
# run test
|
|
def test(collections=None, gen_collections=None, gen_folder=None, use_balanced_idx=True, use_linemaps=False,
         remove_empty_tiles=False, min_align_ratio=0.2, relative_path='../../'):
    """Build a CuneiformSSD dataset with 600x600 tiles for quick manual checks.

    ``collections`` defaults to ``['train']`` and ``gen_collections`` to ``[]``.
    Fresh lists are created on every call, because the dataset stores
    ``collections`` as an attribute and must not share a module-level default.

    Returns:
        CuneiformSSD: the constructed dataset.
    """
    if collections is None:
        collections = ['train']
    if gen_collections is None:
        gen_collections = []

    ssd_dataset = CuneiformSSD(collections=collections, gen_file_path=None, gen_collections=gen_collections,
                               gen_folder=gen_folder, relative_path=relative_path,
                               use_balanced_idx=use_balanced_idx, tile_shape=[600, 600], use_linemaps=use_linemaps,
                               remove_empty_tiles=remove_empty_tiles, min_align_ratio=min_align_ratio)
    return ssd_dataset
|