fix: test asset distribution to include all tags on test/train split (#823)
* fix: test asset distribution to include all tags on test/train split The test asset may not included all tags when export with test/train split option in current venison (2.1.0). * Extract the same split logic into helper function * Formatting * Inverting if statement
Esse commit está contido em:
@@ -115,9 +115,37 @@ describe("CNTK Export Provider", () => {
|
||||
|
||||
const assetsToExport = await getAssetsSpy.mock.results[0].value;
|
||||
const testSplit = (100 - (defaultOptions.testTrainSplit || 80)) / 100;
|
||||
const testCount = Math.ceil(assetsToExport.length * testSplit);
|
||||
const testArray = assetsToExport.slice(0, testCount);
|
||||
const trainArray = assetsToExport.slice(testCount, assetsToExport.length);
|
||||
|
||||
const trainArray = [];
|
||||
const testArray = [];
|
||||
const tagsAssestList: {
|
||||
[index: string]: {
|
||||
assetSet: Set<string>,
|
||||
testArray: string[],
|
||||
trainArray: string[],
|
||||
},
|
||||
} = {};
|
||||
testProject.tags.forEach((tag) =>
|
||||
tagsAssestList[tag.name] = {
|
||||
assetSet: new Set(), testArray: [],
|
||||
trainArray: [],
|
||||
});
|
||||
assetsToExport.forEach((assetMetadata) => {
|
||||
assetMetadata.regions.forEach((region) => {
|
||||
region.tags.forEach((tagName) => {
|
||||
if (tagsAssestList[tagName]) {
|
||||
tagsAssestList[tagName].assetSet.add(assetMetadata.asset.name);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
for (const tagKey of Object.keys(tagsAssestList)) {
|
||||
const assetSet = tagsAssestList[tagKey].assetSet;
|
||||
const testCount = Math.ceil(assetSet.size * testSplit);
|
||||
testArray.push(...Array.from(assetSet).slice(0, testCount));
|
||||
trainArray.push(...Array.from(assetSet).slice(testCount, assetSet.size));
|
||||
}
|
||||
|
||||
const storageProviderMock = LocalFileSystemProxy as any;
|
||||
const writeBinaryCalls = storageProviderMock.mock.instances[0].writeBinary.mock.calls;
|
||||
|
||||
@@ -3,6 +3,7 @@ import { ExportProvider, IExportResults } from "./exportProvider";
|
||||
import { IAssetMetadata, IExportProviderOptions, IProject } from "../../models/applicationState";
|
||||
import HtmlFileReader from "../../common/htmlFileReader";
|
||||
import Guard from "../../common/guard";
|
||||
import { splitTestAsset } from "./testAssetsSplitHelper";
|
||||
|
||||
enum ExportSplit {
|
||||
Test,
|
||||
@@ -33,13 +34,17 @@ export class CntkExportProvider extends ExportProvider<ICntkExportProviderOption
|
||||
public async export(): Promise<IExportResults> {
|
||||
await this.createFolderStructure();
|
||||
const assetsToExport = await this.getAssetsForExport();
|
||||
const testAssets: string[] = [];
|
||||
|
||||
const testSplit = (100 - (this.options.testTrainSplit || 80)) / 100;
|
||||
const testCount = Math.ceil(assetsToExport.length * testSplit);
|
||||
const testArray = assetsToExport.slice(0, testCount);
|
||||
if (testSplit > 0 && testSplit <= 1) {
|
||||
const splittedAssets = splitTestAsset(assetsToExport, this.project.tags, testSplit);
|
||||
testAssets.push(...splittedAssets);
|
||||
}
|
||||
|
||||
const results = await assetsToExport.mapAsync(async (assetMetadata) => {
|
||||
try {
|
||||
const exportSplit = testArray.find((am) => am.asset.id === assetMetadata.asset.id)
|
||||
const exportSplit = testAssets.find((am) => am === assetMetadata.asset.id)
|
||||
? ExportSplit.Test
|
||||
: ExportSplit.Train;
|
||||
|
||||
|
||||
@@ -69,7 +69,9 @@ describe("PascalVOC Json Export Provider", () => {
|
||||
beforeEach(() => {
|
||||
const assetServiceMock = AssetService as jest.Mocked<typeof AssetService>;
|
||||
assetServiceMock.prototype.getAssetMetadata = jest.fn((asset) => {
|
||||
const mockTag = MockFactory.createTestTag();
|
||||
const mockTag1 = MockFactory.createTestTag("1");
|
||||
const mockTag2 = MockFactory.createTestTag("2");
|
||||
const mockTag = Number(asset.id.split("-")[1]) > 7 ? mockTag1 : mockTag2;
|
||||
const mockRegion1 = MockFactory.createTestRegion("region-1", [mockTag.name]);
|
||||
const mockRegion2 = MockFactory.createTestRegion("region-2", [mockTag.name]);
|
||||
|
||||
@@ -352,27 +354,70 @@ describe("PascalVOC Json Export Provider", () => {
|
||||
};
|
||||
|
||||
const testProject = { ...baseTestProject };
|
||||
const testAssets = MockFactory.createTestAssets(10, 0);
|
||||
const testAssets = MockFactory.createTestAssets(13, 0);
|
||||
testAssets.forEach((asset) => asset.state = AssetState.Tagged);
|
||||
testProject.assets = _.keyBy(testAssets, (asset) => asset.id);
|
||||
testProject.tags = [MockFactory.createTestTag("1")];
|
||||
testProject.tags = MockFactory.createTestTags(3);
|
||||
|
||||
const exportProvider = new PascalVOCExportProvider(testProject, options);
|
||||
const getAssetsSpy = jest.spyOn(exportProvider, "getAssetsForExport");
|
||||
|
||||
await exportProvider.export();
|
||||
|
||||
const storageProviderMock = LocalFileSystemProxy as any;
|
||||
const writeTextFileCalls = storageProviderMock.mock.instances[0].writeText.mock.calls as any[];
|
||||
|
||||
const valDataIndex = writeTextFileCalls
|
||||
const valDataIndex1 = writeTextFileCalls
|
||||
.findIndex((args) => args[0].endsWith("/ImageSets/Main/Tag 1_val.txt"));
|
||||
const trainDataIndex = writeTextFileCalls
|
||||
const trainDataIndex1 = writeTextFileCalls
|
||||
.findIndex((args) => args[0].endsWith("/ImageSets/Main/Tag 1_train.txt"));
|
||||
const valDataIndex2 = writeTextFileCalls
|
||||
.findIndex((args) => args[0].endsWith("/ImageSets/Main/Tag 2_val.txt"));
|
||||
const trainDataIndex2 = writeTextFileCalls
|
||||
.findIndex((args) => args[0].endsWith("/ImageSets/Main/Tag 2_train.txt"));
|
||||
|
||||
const expectedTrainCount = (testTrainSplit / 100) * testAssets.length;
|
||||
const expectedTestCount = ((100 - testTrainSplit) / 100) * testAssets.length;
|
||||
const assetsToExport = await getAssetsSpy.mock.results[0].value;
|
||||
const trainArray = [];
|
||||
const testArray = [];
|
||||
const tagsAssestList: {
|
||||
[index: string]: {
|
||||
assetSet: Set<string>,
|
||||
testArray: string[],
|
||||
trainArray: string[],
|
||||
},
|
||||
} = {};
|
||||
testProject.tags.forEach((tag) =>
|
||||
tagsAssestList[tag.name] = {
|
||||
assetSet: new Set(), testArray: [],
|
||||
trainArray: [],
|
||||
});
|
||||
assetsToExport.forEach((assetMetadata) => {
|
||||
assetMetadata.regions.forEach((region) => {
|
||||
region.tags.forEach((tagName) => {
|
||||
if (tagsAssestList[tagName]) {
|
||||
tagsAssestList[tagName].assetSet.add(assetMetadata.asset.name);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
expect(writeTextFileCalls[valDataIndex][1].split("\n")).toHaveLength(expectedTestCount);
|
||||
expect(writeTextFileCalls[trainDataIndex][1].split("\n")).toHaveLength(expectedTrainCount);
|
||||
for (const tagKey of Object.keys(tagsAssestList)) {
|
||||
const assetSet = tagsAssestList[tagKey].assetSet;
|
||||
const testCount = Math.ceil(((100 - testTrainSplit) / 100) * assetSet.size);
|
||||
tagsAssestList[tagKey].testArray = Array.from(assetSet).slice(0, testCount);
|
||||
tagsAssestList[tagKey].trainArray = Array.from(assetSet).slice(testCount, assetSet.size);
|
||||
testArray.push(...tagsAssestList[tagKey].testArray);
|
||||
trainArray.push(...tagsAssestList[tagKey].trainArray);
|
||||
}
|
||||
|
||||
expect(writeTextFileCalls[valDataIndex1][1].split(/\r?\n/).filter((line) =>
|
||||
line.endsWith(" 1"))).toHaveLength(tagsAssestList["Tag 1"].testArray.length);
|
||||
expect(writeTextFileCalls[trainDataIndex1][1].split(/\r?\n/).filter((line) =>
|
||||
line.endsWith(" 1"))).toHaveLength(tagsAssestList["Tag 1"].trainArray.length);
|
||||
expect(writeTextFileCalls[valDataIndex2][1].split(/\r?\n/).filter((line) =>
|
||||
line.endsWith(" 1"))).toHaveLength(tagsAssestList["Tag 2"].testArray.length);
|
||||
expect(writeTextFileCalls[trainDataIndex2][1].split(/\r?\n/).filter((line) =>
|
||||
line.endsWith(" 1"))).toHaveLength(tagsAssestList["Tag 2"].trainArray.length);
|
||||
}
|
||||
|
||||
it("Correctly generated files based on 50/50 test / train split", async () => {
|
||||
|
||||
@@ -6,6 +6,7 @@ import HtmlFileReader from "../../common/htmlFileReader";
|
||||
import { itemTemplate, annotationTemplate, objectTemplate } from "./pascalVOC/pascalVOCTemplates";
|
||||
import { interpolate } from "../../common/strings";
|
||||
import os from "os";
|
||||
import { splitTestAsset } from "./testAssetsSplitHelper";
|
||||
|
||||
interface IObjectInfo {
|
||||
name: string;
|
||||
@@ -253,40 +254,58 @@ export class PascalVOCExportProvider extends ExportProvider<IPascalVOCExportProv
|
||||
}
|
||||
});
|
||||
|
||||
// Save ImageSets
|
||||
await tags.forEachAsync(async (tag) => {
|
||||
const tagInstances = tagUsage.get(tag.name) || 0;
|
||||
if (!exportUnassignedTags && tagInstances === 0) {
|
||||
return;
|
||||
}
|
||||
if (testSplit > 0 && testSplit <= 1) {
|
||||
const tags = this.project.tags;
|
||||
const testAssets: string[] = splitTestAsset(allAssets, tags, testSplit);
|
||||
|
||||
const assetList = [];
|
||||
assetUsage.forEach((tags, assetName) => {
|
||||
if (tags.has(tag.name)) {
|
||||
assetList.push(`${assetName} 1`);
|
||||
} else {
|
||||
assetList.push(`${assetName} -1`);
|
||||
await tags.forEachAsync(async (tag) => {
|
||||
const tagInstances = tagUsage.get(tag.name) || 0;
|
||||
if (!exportUnassignedTags && tagInstances === 0) {
|
||||
return;
|
||||
}
|
||||
});
|
||||
|
||||
if (testSplit > 0 && testSplit <= 1) {
|
||||
// Split in Test and Train sets
|
||||
const totalAssets = assetUsage.size;
|
||||
const testCount = Math.ceil(totalAssets * testSplit);
|
||||
|
||||
const testArray = assetList.slice(0, testCount);
|
||||
const trainArray = assetList.slice(testCount, totalAssets);
|
||||
const testArray = [];
|
||||
const trainArray = [];
|
||||
assetUsage.forEach((tags, assetName) => {
|
||||
let assetString = "";
|
||||
if (tags.has(tag.name)) {
|
||||
assetString = `${assetName} 1`;
|
||||
} else {
|
||||
assetString = `${assetName} -1`;
|
||||
}
|
||||
if (testAssets.find((am) => am === assetName)) {
|
||||
testArray.push(assetString);
|
||||
} else {
|
||||
trainArray.push(assetString);
|
||||
}
|
||||
});
|
||||
|
||||
const testImageSetFileName = `${imageSetsMainFolderName}/${tag.name}_val.txt`;
|
||||
await this.storageProvider.writeText(testImageSetFileName, testArray.join(os.EOL));
|
||||
|
||||
const trainImageSetFileName = `${imageSetsMainFolderName}/${tag.name}_train.txt`;
|
||||
await this.storageProvider.writeText(trainImageSetFileName, trainArray.join(os.EOL));
|
||||
});
|
||||
} else {
|
||||
|
||||
// Save ImageSets
|
||||
await tags.forEachAsync(async (tag) => {
|
||||
const tagInstances = tagUsage.get(tag.name) || 0;
|
||||
if (!exportUnassignedTags && tagInstances === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const assetList = [];
|
||||
assetUsage.forEach((tags, assetName) => {
|
||||
if (tags.has(tag.name)) {
|
||||
assetList.push(`${assetName} 1`);
|
||||
} else {
|
||||
assetList.push(`${assetName} -1`);
|
||||
}
|
||||
});
|
||||
|
||||
} else {
|
||||
const imageSetFileName = `${imageSetsMainFolderName}/${tag.name}.txt`;
|
||||
await this.storageProvider.writeText(imageSetFileName, assetList.join(os.EOL));
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
import _ from "lodash";
|
||||
import {
|
||||
IAssetMetadata, AssetState, IRegion,
|
||||
RegionType, IPoint, IExportProviderOptions,
|
||||
} from "../../models/applicationState";
|
||||
import MockFactory from "../../common/mockFactory";
|
||||
import { splitTestAsset } from "./testAssetsSplitHelper";
|
||||
import { appInfo } from "../../common/appInfo";
|
||||
|
||||
describe("splitTestAsset Helper tests", () => {
|
||||
|
||||
describe("Test Train Splits", () => {
|
||||
async function testTestTrainSplit(testTrainSplit: number): Promise<void> {
|
||||
const assetArray = MockFactory.createTestAssets(13, 0);
|
||||
const tags = MockFactory.createTestTags(2);
|
||||
assetArray.forEach((asset) => asset.state = AssetState.Tagged);
|
||||
|
||||
const testSplit = (100 - testTrainSplit) / 100;
|
||||
const testCount = Math.ceil(testSplit * assetArray.length);
|
||||
|
||||
const assetMetadatas = assetArray.map((asset, i) =>
|
||||
MockFactory.createTestAssetMetadata(asset,
|
||||
i < (assetArray.length - testCount) ?
|
||||
[MockFactory.createTestRegion("Region" + i, [tags[0].name])] :
|
||||
[MockFactory.createTestRegion("Region" + i, [tags[1].name])]));
|
||||
const testAssetsNames = splitTestAsset(assetMetadatas, tags, testSplit);
|
||||
|
||||
const trainAssetsArray = assetMetadatas.filter((assetMetadata) =>
|
||||
testAssetsNames.indexOf(assetMetadata.asset.name) < 0);
|
||||
const testAssetsArray = assetMetadatas.filter((assetMetadata) =>
|
||||
testAssetsNames.indexOf(assetMetadata.asset.name) >= 0);
|
||||
|
||||
const expectedTestCount = Math.ceil(testSplit * testCount) +
|
||||
Math.ceil(testSplit * (assetArray.length - testCount));
|
||||
expect(testAssetsNames).toHaveLength(expectedTestCount);
|
||||
expect(trainAssetsArray.length + testAssetsArray.length).toEqual(assetMetadatas.length);
|
||||
expect(testAssetsArray).toHaveLength(expectedTestCount);
|
||||
|
||||
expect(testAssetsArray.filter((assetMetadata) => assetMetadata.regions[0].tags[0] === tags[0].name).length)
|
||||
.toBeGreaterThan(0);
|
||||
expect(testAssetsArray.filter((assetMetadata) => assetMetadata.regions[0].tags[0] === tags[1].name).length)
|
||||
.toBeGreaterThan(0);
|
||||
}
|
||||
|
||||
it("Correctly generated files based on 50/50 test / train split", async () => {
|
||||
await testTestTrainSplit(50);
|
||||
});
|
||||
|
||||
it("Correctly generated files based on 60/40 test / train split", async () => {
|
||||
await testTestTrainSplit(60);
|
||||
});
|
||||
|
||||
it("Correctly generated files based on 80/20 test / train split", async () => {
|
||||
await testTestTrainSplit(80);
|
||||
});
|
||||
|
||||
it("Correctly generated files based on 90/10 test / train split", async () => {
|
||||
await testTestTrainSplit(90);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,30 @@
|
||||
import { IAssetMetadata, ITag } from "../../models/applicationState";
|
||||
|
||||
/**
|
||||
* A helper function to split train and test assets
|
||||
* @param template String containing variables
|
||||
* @param params Params containing substitution values
|
||||
*/
|
||||
export function splitTestAsset(allAssets: IAssetMetadata[], tags: ITag[], testSplitRatio: number): string[] {
|
||||
if (testSplitRatio <= 0 || testSplitRatio > 1) { return []; }
|
||||
|
||||
const testAssets: string[] = [];
|
||||
const tagsAssetDict: { [index: string]: { assetList: Set<string> } } = {};
|
||||
tags.forEach((tag) => tagsAssetDict[tag.name] = { assetList: new Set() });
|
||||
allAssets.forEach((assetMetadata) => {
|
||||
assetMetadata.regions.forEach((region) => {
|
||||
region.tags.forEach((tagName) => {
|
||||
if (tagsAssetDict[tagName]) {
|
||||
tagsAssetDict[tagName].assetList.add(assetMetadata.asset.name);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
for (const tagKey of Object.keys(tagsAssetDict)) {
|
||||
const assetList = tagsAssetDict[tagKey].assetList;
|
||||
const testCount = Math.ceil(assetList.size * testSplitRatio);
|
||||
testAssets.push(...Array.from(assetList).slice(0, testCount));
|
||||
}
|
||||
return testAssets;
|
||||
}
|
||||
Referência em uma Nova Issue
Bloquear um usuário