feat: CNTK Export Provider (#771)

Adds CNTK export provider into v2

Resolves #754
Esse commit está contido em:
Wallace Breza
2019-04-17 16:43:31 -07:00
commit c10c971caf
15 arquivos alterados com 353 adições e 28 exclusões
+7 -4
Ver Arquivo
@@ -287,6 +287,10 @@ export const english: IAppStrings = {
tagged: "Only tagged Assets",
},
},
testTrainSplit: {
title: "Test / Train Split",
description: "The test train split to use for exported data",
},
},
},
vottJson: {
@@ -344,15 +348,14 @@ export const english: IAppStrings = {
},
pascalVoc: {
displayName: "Pascal VOC",
testTrainSplit: {
title: "Test / Train Split",
description: "The test train split to use for exported data",
},
exportUnassigned: {
title: "Export Unassigned",
description: "Whether or not to include unassigned tags in exported data",
},
},
cntk: {
displayName: "Microsoft Cognitive Toolkit (CNTK)",
},
},
messages: {
saveSuccess: "Successfully saved export settings",
+7 -4
Ver Arquivo
@@ -289,6 +289,10 @@ export const spanish: IAppStrings = {
tagged: "Solo activos etiquetados",
},
},
testTrainSplit: {
title: "La división para entrenar y comprobar",
description: "La división de datos para utilizar entre el entrenamiento y la comprobación",
},
},
},
vottJson: {
@@ -346,15 +350,14 @@ export const spanish: IAppStrings = {
},
pascalVoc: {
displayName: "Pascal VOC",
testTrainSplit: {
title: "Prueba/tren Split",
description: "La división del tren de prueba que se utilizará para los datos exportados",
},
exportUnassigned: {
title: "Exportar sin asignar",
description: "Si se incluyen o no etiquetas no asignadas en los datos exportados",
},
},
cntk: {
displayName: "Microsoft Cognitive Toolkit (CNTK)",
},
},
messages: {
saveSuccess: "Configuración de exportación guardada correctamente",
+7 -4
Ver Arquivo
@@ -285,6 +285,10 @@ export interface IAppStrings {
tagged: string,
},
},
testTrainSplit: {
title: string,
description: string,
},
},
},
vottJson: {
@@ -342,15 +346,14 @@ export interface IAppStrings {
},
pascalVoc: {
displayName: string,
testTrainSplit: {
title: string,
description: string,
},
exportUnassigned: {
title: string,
description: string,
},
},
cntk: {
displayName: string,
},
},
messages: {
saveSuccess: string;
+1 -1
Ver Arquivo
@@ -6,7 +6,7 @@ import { ExportProviderFactory } from "./exportProviderFactory";
import MockFactory from "../../common/mockFactory";
import {
IProject, AssetState, IAsset, IAssetMetadata,
RegionType, IRegion, IExportProviderOptions, AssetType,
RegionType, IRegion, IExportProviderOptions,
} from "../../models/applicationState";
import { ExportAssetState } from "./exportProvider";
jest.mock("./azureCustomVision/azureCustomVisionService");
+28
Ver Arquivo
@@ -0,0 +1,28 @@
{
"type": "object",
"title": "${strings.export.providers.cntk.displayName}",
"properties": {
"assetState": {
"type": "string",
"title": "${strings.export.providers.common.properties.assetState.title}",
"description": "${strings.export.providers.common.properties.assetState.description}",
"enum": [
"all",
"visited",
"tagged"
],
"default": "visited",
"enumNames": [
"${strings.export.providers.common.properties.assetState.options.all}",
"${strings.export.providers.common.properties.assetState.options.visited}",
"${strings.export.providers.common.properties.assetState.options.tagged}"
]
},
"testTrainSplit": {
"title": "${strings.export.providers.common.properties.testTrainSplit.title}",
"description": "${strings.export.providers.common.properties.testTrainSplit.description}",
"type": "number",
"default": 80
}
}
}
+167
Ver Arquivo
@@ -0,0 +1,167 @@
import _ from "lodash";
import os from "os";
import { CntkExportProvider, ICntkExportProviderOptions } from "./cntk";
import { IProject, AssetState, IAssetMetadata } from "../../models/applicationState";
import { AssetProviderFactory } from "../storage/assetProviderFactory";
import { ExportAssetState } from "./exportProvider";
import MockFactory from "../../common/mockFactory";
import registerMixins from "../../registerMixins";
import registerProviders from "../../registerProviders";
import { ExportProviderFactory } from "./exportProviderFactory";
jest.mock("../../services/assetService");
import { AssetService } from "../../services/assetService";
jest.mock("../storage/localFileSystemProxy");
import { LocalFileSystemProxy } from "../storage/localFileSystemProxy";
import HtmlFileReader from "../../common/htmlFileReader";
import { appInfo } from "../../common/appInfo";
describe("CNTK Export Provider", () => {
const testAssets = MockFactory.createTestAssets(10, 1);
let testProject: IProject = null;
const defaultOptions: ICntkExportProviderOptions = {
assetState: ExportAssetState.Tagged,
testTrainSplit: 80,
};
function createProvider(project: IProject): CntkExportProvider {
return new CntkExportProvider(
project,
project.exportFormat.providerOptions as ICntkExportProviderOptions,
);
}
beforeAll(() => {
registerMixins();
registerProviders();
HtmlFileReader.getAssetBlob = jest.fn(() => {
return Promise.resolve(new Blob(["Some binary data"]));
});
});
beforeEach(() => {
jest.resetAllMocks();
testAssets.forEach((asset) => {
asset.state = AssetState.Tagged;
});
testProject = {
...MockFactory.createTestProject("TestProject"),
assets: _.keyBy(testAssets, (a) => a.id),
exportFormat: {
providerType: "cntk",
providerOptions: defaultOptions,
},
};
AssetProviderFactory.create = jest.fn(() => {
return {
getAssets: jest.fn(() => Promise.resolve(testAssets)),
};
});
const assetServiceMock = AssetService as jest.Mocked<typeof AssetService>;
assetServiceMock.prototype.getAssetMetadata = jest.fn((asset) => {
const assetMetadata = {
asset: { ...asset },
regions: [
MockFactory.createTestRegion("region-1", ["tag1"]),
MockFactory.createTestRegion("region-2", ["tag1"]),
],
version: appInfo.version,
};
return Promise.resolve(assetMetadata);
});
});
it("Is defined", () => {
expect(CntkExportProvider).toBeDefined();
});
it("Can be instantiated through the factory", () => {
const options: ICntkExportProviderOptions = {
assetState: ExportAssetState.All,
testTrainSplit: 80,
};
const exportProvider = ExportProviderFactory.create("cntk", testProject, options);
expect(exportProvider).not.toBeNull();
expect(exportProvider).toBeInstanceOf(CntkExportProvider);
});
it("Creates correct folder structure", async () => {
const provider = createProvider(testProject);
await provider.export();
const storageProviderMock = LocalFileSystemProxy as any;
const createContainerCalls = storageProviderMock.mock.instances[0].createContainer.mock.calls;
const createContainerArgs = createContainerCalls.map((args) => args[0]);
const expectedFolderPath = "Project-TestProject-CNTK-export";
expect(createContainerArgs).toContain(expectedFolderPath);
expect(createContainerArgs).toContain(`${expectedFolderPath}/positive`);
expect(createContainerArgs).toContain(`${expectedFolderPath}/negative`);
expect(createContainerArgs).toContain(`${expectedFolderPath}/testImages`);
});
it("Writes export files to storage provider", async () => {
const provider = createProvider(testProject);
const getAssetsSpy = jest.spyOn(provider, "getAssetsForExport");
await provider.export();
const assetsToExport = await getAssetsSpy.mock.results[0].value;
const testSplit = (100 - (defaultOptions.testTrainSplit || 80)) / 100;
const testCount = Math.ceil(assetsToExport.length * testSplit);
const testArray = assetsToExport.slice(0, testCount);
const trainArray = assetsToExport.slice(testCount, assetsToExport.length);
const storageProviderMock = LocalFileSystemProxy as any;
const writeBinaryCalls = storageProviderMock.mock.instances[0].writeBinary.mock.calls;
const writeTextFileCalls = storageProviderMock.mock.instances[0].writeText.mock.calls;
expect(writeBinaryCalls).toHaveLength(testAssets.length);
expect(writeTextFileCalls).toHaveLength(testAssets.length * 2);
testArray.forEach((assetMetadata) => {
const testFolderPath = "Project-TestProject-CNTK-export/testImages";
assertExportedAsset(testFolderPath, assetMetadata);
});
trainArray.forEach((assetMetadata) => {
const trainFolderPath = "Project-TestProject-CNTK-export/positive";
assertExportedAsset(trainFolderPath, assetMetadata);
});
});
function assertExportedAsset(folderPath: string, assetMetadata: IAssetMetadata) {
const storageProviderMock = LocalFileSystemProxy as any;
const writeBinaryCalls = storageProviderMock.mock.instances[0].writeBinary.mock.calls;
const writeBinaryFilenameArgs = writeBinaryCalls.map((args) => args[0]);
const writeTextFileCalls = storageProviderMock.mock.instances[0].writeText.mock.calls;
const writeTextFilenameArgs = writeTextFileCalls.map((args) => args[0]);
expect(writeBinaryFilenameArgs).toContain(`${folderPath}/${assetMetadata.asset.name}`);
expect(writeTextFilenameArgs).toContain(`${folderPath}/${assetMetadata.asset.name}.bboxes.labels.tsv`);
expect(writeTextFilenameArgs).toContain(`${folderPath}/${assetMetadata.asset.name}.bboxes.tsv`);
const writeLabelsCall = writeTextFileCalls
.find((args: string[]) => args[0].indexOf(`${assetMetadata.asset.name}.bboxes.labels.tsv`) >= 0);
const writeBoxesCall = writeTextFileCalls
.find((args: string[]) => args[0].indexOf(`${assetMetadata.asset.name}.bboxes.tsv`) >= 0);
const expectedLabelData = `${assetMetadata.regions[0].tags[0]}${os.EOL}${assetMetadata.regions[1].tags[0]}`;
expect(writeLabelsCall[1]).toEqual(expectedLabelData);
const expectedBoxData = [];
// tslint:disable-next-line:max-line-length
expectedBoxData.push(`${assetMetadata.regions[0].boundingBox.left}\t${assetMetadata.regions[0].boundingBox.left + assetMetadata.regions[0].boundingBox.width}\t${assetMetadata.regions[0].boundingBox.top}\t${assetMetadata.regions[0].boundingBox.top + assetMetadata.regions[0].boundingBox.height}`);
// tslint:disable-next-line:max-line-length
expectedBoxData.push(`${assetMetadata.regions[1].boundingBox.left}\t${assetMetadata.regions[1].boundingBox.left + assetMetadata.regions[1].boundingBox.width}\t${assetMetadata.regions[1].boundingBox.top}\t${assetMetadata.regions[1].boundingBox.top + assetMetadata.regions[1].boundingBox.height}`);
expect(writeBoxesCall[1]).toEqual(expectedBoxData.join(os.EOL));
}
});
+105
Ver Arquivo
@@ -0,0 +1,105 @@
import os from "os";
import { ExportProvider, IExportResults } from "./exportProvider";
import { IAssetMetadata, IExportProviderOptions, IProject } from "../../models/applicationState";
import HtmlFileReader from "../../common/htmlFileReader";
import Guard from "../../common/guard";
enum ExportSplit {
Test,
Train,
}
/**
* Export options for CNTK export provider
*/
export interface ICntkExportProviderOptions extends IExportProviderOptions {
/** The test / train split ratio for exporting data */
testTrainSplit?: number;
}
/**
* CNTK Export provider
*/
export class CntkExportProvider extends ExportProvider<ICntkExportProviderOptions> {
private exportFolderName: string;
constructor(project: IProject, options: ICntkExportProviderOptions) {
super(project, options);
Guard.null(options);
this.exportFolderName = `${this.project.name.replace(/\s/g, "-")}-CNTK-export`;
}
public async export(): Promise<IExportResults> {
await this.createFolderStructure();
const assetsToExport = await this.getAssetsForExport();
const testSplit = (100 - (this.options.testTrainSplit || 80)) / 100;
const testCount = Math.ceil(assetsToExport.length * testSplit);
const testArray = assetsToExport.slice(0, testCount);
const results = await assetsToExport.mapAsync(async (assetMetadata) => {
try {
const exportSplit = testArray.find((am) => am.asset.id === assetMetadata.asset.id)
? ExportSplit.Test
: ExportSplit.Train;
await this.exportAssetFrame(assetMetadata, exportSplit);
return {
asset: assetMetadata,
success: true,
};
} catch (e) {
return {
asset: assetMetadata,
success: false,
error: e,
};
}
});
return {
completed: results.filter((r) => r.success),
errors: results.filter((r) => !r.success),
count: results.length,
};
}
private async exportAssetFrame(assetMetadata: IAssetMetadata, exportSplit: ExportSplit) {
const labelData = [];
const boundingBoxData = [];
assetMetadata.regions.forEach((region) => {
region.tags.forEach((tagName) => {
labelData.push(tagName);
// tslint:disable-next-line:max-line-length
boundingBoxData.push(`${region.boundingBox.left}\t${region.boundingBox.left + region.boundingBox.width}\t${region.boundingBox.top}\t${region.boundingBox.top + region.boundingBox.height}`);
});
});
const folderName = exportSplit === ExportSplit.Train ? "positive" : "testImages";
const labelsPath = `${this.exportFolderName}/${folderName}/${assetMetadata.asset.name}.bboxes.labels.tsv`;
const boundingBoxPath = `${this.exportFolderName}/${folderName}/${assetMetadata.asset.name}.bboxes.tsv`;
const binaryPath = `${this.exportFolderName}/${folderName}/${assetMetadata.asset.name}`;
const buffer = await HtmlFileReader.getAssetArray(assetMetadata.asset);
await Promise.all([
this.storageProvider.writeText(labelsPath, labelData.join(os.EOL)),
this.storageProvider.writeText(boundingBoxPath, boundingBoxData.join(os.EOL)),
this.storageProvider.writeBinary(binaryPath, Buffer.from(buffer)),
]);
}
private async createFolderStructure(): Promise<void> {
const positiveFolder = `${this.exportFolderName}/positive`;
const negativeFolder = `${this.exportFolderName}/negative`;
const testImagesFolder = `${this.exportFolderName}/testImages`;
await this.storageProvider.createContainer(this.exportFolderName);
await [positiveFolder, negativeFolder, testImagesFolder]
.forEachAsync(async (folderPath) => {
await this.storageProvider.createContainer(folderPath);
});
}
}
+5
Ver Arquivo
@@ -0,0 +1,5 @@
{
"testTrainSplit": {
"ui:widget": "slider"
}
}
+2 -2
Ver Arquivo
@@ -19,8 +19,8 @@
]
},
"testTrainSplit": {
"title": "${strings.export.providers.pascalVoc.testTrainSplit.title}",
"description": "${strings.export.providers.pascalVoc.testTrainSplit.description}",
"title": "${strings.export.providers.common.properties.testTrainSplit.title}",
"description": "${strings.export.providers.common.properties.testTrainSplit.description}",
"type": "number",
"default": 80
},
+2 -3
Ver Arquivo
@@ -1,11 +1,10 @@
import _ from "lodash";
import { ExportProvider } from "./exportProvider";
import { IProject, IAssetMetadata, RegionType, ITag, IExportProviderOptions } from "../../models/applicationState";
import { IProject, IAssetMetadata, ITag, IExportProviderOptions } from "../../models/applicationState";
import Guard from "../../common/guard";
import HtmlFileReader from "../../common/htmlFileReader";
import { itemTemplate, annotationTemplate, objectTemplate } from "./pascalVOC/pascalVOCTemplates";
import { interpolate } from "../../common/strings";
import { PlatformType } from "../../common/hostProcess";
import os from "os";
interface IObjectInfo {
@@ -53,7 +52,7 @@ export class PascalVOCExportProvider extends ExportProvider<IPascalVOCExportProv
exportObject.assets = _.keyBy(allAssets, (assetMetadata) => assetMetadata.asset.id);
// Create Export Folder
const exportFolderName = `${this.project.name.replace(" ", "-")}-PascalVOC-export`;
const exportFolderName = `${this.project.name.replace(/\s/g, "-")}-PascalVOC-export`;
await this.storageProvider.createContainer(exportFolderName);
await this.exportImages(exportFolderName, allAssets);
+1 -1
Ver Arquivo
@@ -41,7 +41,7 @@ export class TFRecordsExportProvider extends ExportProvider {
exportObject.assets = _.keyBy(allAssets, (assetMetadata) => assetMetadata.asset.id);
// Create Export Folder
const exportFolderName = `${this.project.name.replace(" ", "-")}-TFRecords-export`;
const exportFolderName = `${this.project.name.replace(/\s/g, "-")}-TFRecords-export`;
await this.storageProvider.createContainer(exportFolderName);
await this.exportPBTXT(exportFolderName, this.project);
+1 -1
Ver Arquivo
@@ -3,7 +3,7 @@ import { VottJsonExportProvider, IVottJsonExportProviderOptions } from "./vottJs
import registerProviders from "../../registerProviders";
import { ExportAssetState } from "./exportProvider";
import { ExportProviderFactory } from "./exportProviderFactory";
import { IProject, IAssetMetadata, AssetState, IExportProviderOptions } from "../../models/applicationState";
import { IProject, IAssetMetadata, AssetState } from "../../models/applicationState";
import MockFactory from "../../common/mockFactory";
jest.mock("../../services/assetService");
+1
Ver Arquivo
@@ -96,6 +96,7 @@ describe("Project Redux Actions", () => {
providerType: "vottJson",
providerOptions: {
assetState: ExportAssetState.Visited,
includeImages: true,
},
});
});
+13 -8
Ver Arquivo
@@ -15,6 +15,8 @@ import { createAction, createPayloadAction, IPayloadAction } from "./actionCreat
import { ExportAssetState, IExportResults } from "../../providers/export/exportProvider";
import { appInfo } from "../../common/appInfo";
import { strings } from "../../common/strings";
import { IExportFormat } from "vott-react";
import { IVottJsonExportProviderOptions } from "../../providers/export/vottJson";
/**
* Actions to be performed in relation to projects
@@ -76,11 +78,14 @@ export function saveProject(project: IProject)
throw new AppError(ErrorCode.SecurityTokenNotFound, "Security Token Not Found");
}
const defaultExportFormat = {
const defaultExportProviderOptions: IVottJsonExportProviderOptions = {
assetState: ExportAssetState.Visited,
includeImages: true,
};
const defaultExportFormat: IExportFormat = {
providerType: "vottJson",
providerOptions: {
assetState: ExportAssetState.Visited,
},
providerOptions: defaultExportProviderOptions,
};
const newProject = {
@@ -130,7 +135,7 @@ export function deleteProject(project: IProject)
*/
export function closeProject(): (dispatch: Dispatch) => void {
return (dispatch: Dispatch): void => {
dispatch({type: ActionTypes.CLOSE_PROJECT_SUCCESS});
dispatch({ type: ActionTypes.CLOSE_PROJECT_SUCCESS });
};
}
@@ -159,7 +164,7 @@ export function loadAssetMetadata(project: IProject, asset: IAsset): (dispatch:
const assetMetadata = await assetService.getAssetMetadata(asset);
dispatch(loadAssetMetadataAction(assetMetadata));
return {...assetMetadata};
return { ...assetMetadata };
};
}
@@ -171,14 +176,14 @@ export function loadAssetMetadata(project: IProject, asset: IAsset): (dispatch:
export function saveAssetMetadata(
project: IProject,
assetMetadata: IAssetMetadata): (dispatch: Dispatch) => Promise<IAssetMetadata> {
const newAssetMetadata = {...assetMetadata, version: appInfo.version};
const newAssetMetadata = { ...assetMetadata, version: appInfo.version };
return async (dispatch: Dispatch) => {
const assetService = new AssetService(project);
const savedMetadata = await assetService.save(newAssetMetadata);
dispatch(saveAssetMetadataAction(savedMetadata));
return {...savedMetadata};
return { ...savedMetadata };
};
}
+6
Ver Arquivo
@@ -11,6 +11,7 @@ import registerToolbar from "./registerToolbar";
import { strings } from "./common/strings";
import { HostProcessType } from "./common/hostProcess";
import { AzureCustomVisionProvider } from "./providers/export/azureCustomVision";
import { CntkExportProvider } from "./providers/export/cntk";
/**
* Registers storage, asset and export providers, as well as all toolbar items
@@ -68,6 +69,11 @@ export default function registerProviders() {
displayName: strings.export.providers.azureCV.displayName,
factory: (project, options) => new AzureCustomVisionProvider(project, options),
});
ExportProviderFactory.register({
name: "cntk",
displayName: strings.export.providers.cntk.displayName,
factory: (project, options) => new CntkExportProvider(project, options),
});
registerToolbar();
}