Added recurrent layers. Fixes #477 (#645)

WIP #477 Added skipArgs and rnn parsing support

WIP #477 Added update script and updated CreateTorchMeta

WIP #477 Run update script after nn-parser

WIP #477 Simplified nn-parser usage (just 'rnn' or 'nn')

WIP #477 Added 'all' option

WIP #477 Updated the nn project

WIP #477 `require 'rnn'` => nop in importer

WIP #477 Updated gen arch

WIP #477 Updated update --torch

WIP #477 Added rnn installation to torch install

WIP #477 Fixed code climate issues
Esse commit está contido em:
Brian Broll
2016-08-09 13:50:19 -05:00
commit de GitHub
commit 6d70728b54
19 arquivos alterados com 516 adições e 289 exclusões
+41 -25
Ver Arquivo
@@ -3,7 +3,7 @@
var Command = require('commander').Command,
program = new Command(),
childProcess = require('child_process'),
spawn = childProcess.spawn,
rawSpawn = childProcess.spawn,
execSync = childProcess.execSync,
path = require('path'),
fs = require('fs'),
@@ -115,7 +115,7 @@ var checkMongo = function(args) {
};
var startMongo = function(args, silent) {
var job = spawn('mongod', ['--dbpath', p(config.mongo.dir)], {
var job = rawSpawn('mongod', ['--dbpath', p(config.mongo.dir)], {
cwd: process.env.HOME
});
if (!silent) {
@@ -166,7 +166,8 @@ var _checkTorch = function(resolve, reject) {
process.chdir(tgtDir);
spawnMany([
'bash install-deps',
'./install.sh'
'./install.sh',
'luarocks install rnn'
], () => {
storeConfig('torch.dir', tgtDir);
resolve(true);
@@ -192,7 +193,7 @@ var spawnMany = function(cmds, succ, err) {
rawCmd = cmds.shift();
args = rawCmd.split(' ');
cmd = args.shift();
job = spawn(cmd, args);
job = rawSpawn(cmd, args);
job.stdout.on('data', data => process.stdout.write(data));
job.stderr.on('data', data => process.stderr.write(data));
job.on('close', code => {
@@ -206,6 +207,28 @@ var spawnMany = function(cmds, succ, err) {
};
var spawn = function(cmd, args, opts) {
var promise,
err;
args = args || [];
promise = new Promise((resolve, reject) => {
var job = opts ? rawSpawn(cmd, args, opts) : rawSpawn(cmd, args);
job.stdout.on('data', data => process.stdout.write(data));
job.stderr.on('data', data => process.stderr.write(data));
job.on('close', code => {
if (err) {
reject(err);
} else {
resolve(code);
}
});
job.on('error', e => err = e);
});
return promise;
};
program.command('start')
.description('start deepforge locally (default) or specific components')
.option('-p, --port <port>', 'specify the port to use')
@@ -264,7 +287,6 @@ program
.option('-s, --server', 'update deepforge')
.action(args => {
var pkg = 'deepforge',
job,
latestVersion;
// Install the project
@@ -287,16 +309,11 @@ program
}
}
job = spawn('npm', ['install', '-g', pkg]);
job.stdout.on('data', data => process.stdout.write(data.toString()));
job.stderr.on('data', data => process.stderr.write(data.toString()));
job.on('close', code => {
if (!code) {
spawn('npm', ['install', '-g', pkg])
.then(() => {
console.log('Upgrade successful!');
} else {
console.log('Upgrade failed w/ error code: ' + code);
}
});
})
.catch(code => console.log('Upgrade failed w/ error code: ' + code));
}
if (args.torch || !args.server) {
@@ -316,18 +333,17 @@ program
return;
}
job = spawn('bash', ['./update.sh'], {
cwd: p(config.torch.dir)
});
job.stdout.on('data', data => process.stdout.write(data.toString()));
job.stderr.on('data', data => process.stderr.write(data.toString()));
job.on('close', code => {
if (!code) {
spawn('bash', ['./update.sh'], {cwd: p(config.torch.dir)})
.catch(err => console.log('Upgrade failed w/ error code: ' + err.code))
.then(() => {
console.log('About to update rnn package...');
// Update rnn
return spawn('luarocks', ['install', 'rnn']);
})
.then(() => {
console.log('Upgrade successful!');
} else {
console.log('Upgrade failed w/ error code: ' + code);
}
});
})
.catch(code => console.log('Upgrade failed w/ error code: ' + code));
}
});
}
+95 -111
Ver Arquivo
@@ -7,7 +7,7 @@ define([
'js/RegistryKeys',
'js/Panels/MetaEditor/MetaEditorConstants',
'underscore',
'text!deepforge/layers.json',
'./schemas/index',
'text!./metadata.json'
], function (
PluginBase,
@@ -15,7 +15,7 @@ define([
REGISTRY_KEYS,
META_CONSTANTS,
_,
DEFAULT_LAYERS,
Schemas,
metadata
) {
'use strict';
@@ -51,110 +51,98 @@ define([
* @param {function(string, plugin.PluginResult)} callback - the result callback
*/
CreateTorchMeta.prototype.main = function (callback) {
// Use self to access core, project, result, logger etc from PluginBase.
// These are all instantiated at this point.
var self = this;
if (!this.META.Language) {
return callback('"Language" container required to run plugin', this.result);
}
// Extra layer names
this.getJsonLayers((err, text) => {
if (err) {
return callback(err, this.result);
// The format is...
// - (Abstract) CategoryLayerTypes
// - LayerName
// - Attributes (if exists)
var layers,
content = {},
categories,
config = this.getCurrentConfig(),
nodes = {};
try {
layers = this.getJsonLayers();
} catch (e) {
return callback('JSON parse error: ' + e, this.result);
}
layers.forEach(layer => {
if (!content[layer.type]) {
content[layer.type] = [];
}
// The format is...
// - (Abstract) CategoryLayerTypes
// - LayerName
// - Attributes (if exists)
var content = {},
categories,
config = this.getCurrentConfig(),
nodes = {},
layers;
try {
layers = JSON.parse(text);
} catch (e) {
return callback('JSON parse error: ' + e, this.result);
}
layers.forEach(layer => {
if (!content[layer.type]) {
content[layer.type] = [];
}
content[layer.type].push(layer);
});
categories = Object.keys(content);
// Create the base class, if needed
if (!this.META.Layer) {
this.META.Layer = this.createMetaNode('Layer', this.META.FCO);
}
// Create the category nodes
categories
.forEach(name => {
// Create a tab for each
this.metaSheets[name] = this.createMetaSheetTab(name);
this.sheetCounts[name] = 0;
nodes[name] = this.createMetaNode(name, this.META.Layer, name);
});
// Make them abstract
categories
.forEach(name => this.core.setRegistry(nodes[name], 'isAbstract', true));
if (config.removeOldLayers) {
var isNewLayer = {},
newLayers = layers.map(layer => layer.name),
oldLayers,
oldNames;
newLayers = newLayers.concat(categories); // add the category nodes
newLayers.forEach(name => isNewLayer[name] = true);
// Set the newLayer nodes 'base' to 'Layer' so we don't accidentally
// delete them
newLayers
.map(name => this.META[name])
.filter(layer => !!layer)
.forEach(layer => this.core.setBase(layer, this.META.Layer));
oldLayers = Object.keys(this.META)
.filter(name => name !== 'Layer')
.map(name => this.META[name])
.filter(node => this.isMetaTypeOf(node, this.META.Layer))
.filter(node => !isNewLayer[this.core.getAttribute(node, 'name')]);
oldNames = oldLayers.map(l => this.core.getAttribute(l, 'name'));
// Get the old layer names
this.logger.debug(`Removing layers: ${oldNames.join(', ')}`);
oldLayers.forEach(layer => this.core.deleteNode(layer));
}
// Create the actual nodes
categories.forEach(cat => {
content[cat]
.forEach(layer => {
var name = layer.name;
nodes[name] = this.createMetaNode(name, nodes[cat], cat, layer);
// Make the node non-abstract
this.core.setRegistry(nodes[name], 'isAbstract', false);
});
});
self.save('CreateTorchMeta updated model.', function (err) {
if (err) {
callback(err, self.result);
return;
}
self.result.setSuccess(true);
callback(null, self.result);
});
content[layer.type].push(layer);
});
categories = Object.keys(content);
// Create the base class, if needed
if (!this.META.Layer) {
this.META.Layer = this.createMetaNode('Layer', this.META.FCO);
}
// Create the category nodes
categories
.forEach(name => {
// Create a tab for each
this.metaSheets[name] = this.createMetaSheetTab(name);
this.sheetCounts[name] = 0;
nodes[name] = this.createMetaNode(name, this.META.Layer, name);
});
// Make them abstract
categories
.forEach(name => this.core.setRegistry(nodes[name], 'isAbstract', true));
if (config.removeOldLayers) {
var isNewLayer = {},
newLayers = layers.map(layer => layer.name),
oldLayers,
oldNames;
newLayers = newLayers.concat(categories); // add the category nodes
newLayers.forEach(name => isNewLayer[name] = true);
// Set the newLayer nodes 'base' to 'Layer' so we don't accidentally
// delete them
newLayers
.map(name => this.META[name])
.filter(layer => !!layer)
.forEach(layer => this.core.setBase(layer, this.META.Layer));
oldLayers = Object.keys(this.META)
.filter(name => name !== 'Layer')
.map(name => this.META[name])
.filter(node => this.isMetaTypeOf(node, this.META.Layer))
.filter(node => !isNewLayer[this.core.getAttribute(node, 'name')]);
oldNames = oldLayers.map(l => this.core.getAttribute(l, 'name'));
// Get the old layer names
this.logger.debug(`Removing layers: ${oldNames.join(', ')}`);
oldLayers.forEach(layer => this.core.deleteNode(layer));
}
// Create the actual nodes
categories.forEach(cat => {
content[cat]
.forEach(layer => {
var name = layer.name;
nodes[name] = this.createMetaNode(name, nodes[cat], cat, layer);
// Make the node non-abstract
this.core.setRegistry(nodes[name], 'isAbstract', false);
});
});
this.save('CreateTorchMeta updated model.')
.then(() => {
this.result.setSuccess(true);
callback(null, this.result);
})
.fail(err => callback(err, this.result));
};
CreateTorchMeta.prototype.removeFromMeta = function (nodeId) {
@@ -196,20 +184,16 @@ define([
return sheet.SetID;
};
CreateTorchMeta.prototype.getJsonLayers = function (callback) {
var config = this.getCurrentConfig();
CreateTorchMeta.prototype.getJsonLayers = function () {
var config = this.getCurrentConfig(),
schema = config.layerSchema;
if (config.layerNameHash) {
this.blobClient.getObject(config.layerNameHash, (err, buffer) => {
if (err) {
return callback(err, this.result);
}
var text = String.fromCharCode.apply(null, new Uint8Array(buffer));
return callback(null, text);
});
} else {
return callback(null, DEFAULT_LAYERS);
if (schema === 'all') {
return Object.keys(Schemas).map(key => JSON.parse(Schemas[key]))
.reduce((l1, l2) => l1.concat(l2), []);
}
return JSON.parse(Schemas[schema]);
};
var isBoolean = txt => {
+23 -18
Ver Arquivo
@@ -7,24 +7,29 @@
"src": "",
"class": "glyphicon glyphicon-ok-circle"
},
"disableServerSideExecution": false,
"disableServerSideExecution": true,
"disableBrowserSideExecution": false,
"configStructure": [
{
"name": "layerNameHash",
"displayName": "Torch Layers",
"description": "Yaml file of torch layer descriptors (optional)",
"value": "",
"valueType": "asset",
"readOnly": false
},
{
"name": "removeOldLayers",
"displayName": "Delete old layers",
"description": "Delete all layers not in the current description",
"value": true,
"valueType": "boolean",
"readOnly": false
}
{
"name": "layerSchema",
"displayName": "Torch Libraries",
"description": "Torch nn libraries to create layers from",
"value": "all",
"valueItems": [
"nn",
"rnn",
"all"
],
"valueType": "string",
"readOnly": false
},
{
"name": "removeOldLayers",
"displayName": "Delete old layers",
"description": "Delete all layers not in the current description",
"value": true,
"valueType": "boolean",
"readOnly": false
}
]
}
}
+13
Ver Arquivo
@@ -0,0 +1,13 @@
/*globals define*/
define([
'text!./nn.json',
'text!./rnn.json'
], function(
nn,
rnn
) {
return {
nn: nn,
rnn: rnn
};
});
@@ -324,9 +324,7 @@
{
"name": "GradientReversal",
"baseType": "Module",
"params": [
"lambda"
],
"params": [],
"setters": {},
"defaults": {},
"type": "Misc"
@@ -348,8 +346,7 @@
"baseType": "Module",
"params": [
"min_value",
"max_value",
"inplace"
"max_value"
],
"setters": {},
"defaults": {
@@ -458,7 +455,9 @@
{
"name": "Log",
"baseType": "Module",
"params": [],
"params": [
"inputSize"
],
"setters": {},
"defaults": {},
"type": "Misc"
@@ -832,16 +831,6 @@
"defaults": {},
"type": "Transfer"
},
{
"name": "ReLU6",
"baseType": "Module",
"params": [
"inplace"
],
"setters": {},
"defaults": {},
"type": "Transfer"
},
{
"name": "Replicate",
"baseType": "Module",
@@ -1036,17 +1025,6 @@
"defaults": {},
"type": "Convolution"
},
{
"name": "SpatialClassNLLCriterion",
"baseType": "Criterion",
"params": [
"weights",
"sizeAverage"
],
"setters": {},
"defaults": {},
"type": "Criterion"
},
{
"name": "SpatialContrastiveNormalization",
"baseType": "Module",
@@ -1152,28 +1130,6 @@
},
"type": "Convolution"
},
{
"name": "SpatialDilatedConvolution",
"baseType": "SpatialConvolution",
"params": [
"nInputPlane",
"nOutputPlane",
"kW",
"kH",
"dW",
"dH",
"padW",
"padH",
"dilationH",
"dilationW"
],
"setters": {},
"defaults": {
"dilationW": 1,
"dilationH": 1
},
"type": "Convolution"
},
{
"name": "SpatialDivisiveNormalization",
"baseType": "Module",
@@ -1692,21 +1648,6 @@
"defaults": {},
"type": "Convolution"
},
{
"name": "VolumetricReplicationPadding",
"baseType": "Module",
"params": [
"pleft",
"pright",
"ptop",
"pbottom",
"pfront",
"pback"
],
"setters": {},
"defaults": {},
"type": "Misc"
},
{
"name": "WeightedEuclidean",
"baseType": "Module",
+142
Ver Arquivo
@@ -0,0 +1,142 @@
[
{
"name": "CopyGrad",
"baseType": "Identity",
"setters": {},
"defaults": {},
"type": "Recurrent"
},
{
"name": "FastLSTM",
"baseType": "LSTM",
"params": [
"inputSize",
"outputSize",
"rho",
"eps",
"momentum",
"affine"
],
"setters": {},
"defaults": {
"momentum": 0.1,
"eps": 0.1
},
"type": "Recurrent"
},
{
"name": "LSTM",
"baseType": "AbstractRecurrent",
"params": [
"inputSize",
"outputSize",
"rho",
"cell2gate"
],
"setters": {},
"defaults": {
"rho": 9999
},
"type": "Recurrent"
},
{
"name": "LinearNoBias",
"baseType": "Linear",
"params": [
"inputSize",
"outputSize"
],
"setters": {},
"defaults": {},
"type": "Simple"
},
{
"name": "LookupTableMaskZero",
"baseType": "LookupTable",
"params": [
"nIndex",
"nOutput"
],
"setters": {},
"defaults": {},
"type": "Recurrent"
},
{
"name": "NormStabilizer",
"baseType": "AbstractRecurrent",
"params": [
"beta"
],
"setters": {},
"defaults": {},
"type": "Recurrent"
},
{
"name": "SAdd",
"baseType": "Module",
"params": [
"addend",
"negate"
],
"setters": {},
"defaults": {},
"type": "Recurrent"
},
{
"name": "SeqBRNN",
"baseType": "Container",
"params": [
"inputDim",
"hiddenDim",
"batchFirst"
],
"setters": {},
"defaults": {},
"type": "Recurrent"
},
{
"name": "SeqGRU",
"baseType": "Module",
"params": [
"inputSize",
"outputSize"
],
"setters": {},
"defaults": {},
"type": "Recurrent"
},
{
"name": "SeqLSTM",
"baseType": "Module",
"params": [
"inputsize",
"hiddensize",
"outputsize"
],
"setters": {},
"defaults": {},
"type": "Recurrent"
},
{
"name": "SeqLSTMP",
"baseType": "SeqLSTM",
"params": [
"inputsize",
"hiddensize",
"outputsize"
],
"setters": {},
"defaults": {},
"type": "Recurrent"
},
{
"name": "SeqReverseSequence",
"baseType": "Module",
"params": [
"dim"
],
"setters": {},
"defaults": {},
"type": "Recurrent"
}
]
@@ -0,0 +1,37 @@
/* eslint-disable no-console */
// Update the metadata and schemas/index based on the new schemas in schemas/
// Update metadata
var fs = require('fs'),
path = require('path'),
schemas,
metadata = require('./metadata.json'),
schemaList;
schemas = fs.readdirSync(__dirname + '/schemas/')
.filter(name => path.extname(name) === '.json')
.map(name => name.replace(/\.json$/, ''));
console.log('Discovered schemas: ' + schemas.join(', '));
schemaList = metadata.configStructure.find(struct => struct.name === 'layerSchema');
schemaList.valueItems = schemas.concat('all');
console.log('Updating metadata...');
fs.writeFileSync(__dirname + '/metadata.json', JSON.stringify(metadata, null, 2));
// Update index.js
var index =
`/*globals define*/
define([
${schemas.map(s => `'text!./${s}.json'`).join(',\n ')}
], function(
${schemas.map(s => s).join(',\n ')}
) {
return {
${schemas.map(s => s + ': ' + s).join(',\n ')}
};
});`;
console.log('Updating index.js...');
fs.writeFileSync(__dirname + '/schemas/index.js', index);
@@ -58,9 +58,8 @@ define([
var layers = tree[Constants.CHILDREN],
//initialLayers,
result = {},
code = 'require \'nn\'\n';
code = 'require \'nn\'\nrequire \'rnn\'\n';
//initialLayers = layers.filter(layer => layer[Constants.PREV].length === 0);
// Add an index to each layer
layers.forEach((l, index) => l[INDEX] = index);
+2 -6
Ver Arquivo
@@ -75,8 +75,6 @@ define([
this.bin = this.context.loadString(src);
this.bin();
this.afterExecution();
return this.save('ImportTorch updated model.');
})
.then(() => { // changes saved successfully
@@ -101,6 +99,8 @@ define([
this.context._G.get('package').set('searchers', [function(name) {
if (name === 'nn') {
return lib;
} else {
return () => {};
}
}]);
@@ -108,9 +108,5 @@ define([
// "nn" package to the global scope...
};
ImportTorch.prototype.afterExecution = function () {
// TODO
};
return ImportTorch;
});
Arquivo binário não exibido.
Arquivo binário não exibido.
Arquivo binário não exibido.
+62 -54
Ver Arquivo
@@ -1,7 +1,4 @@
/*jshint node:true, mocha:true*/
/**
* Generated by PluginGenerator 0.14.0 from webgme on Thu Mar 10 2016 04:16:02 GMT-0600 (CST).
*/
'use strict';
var testFixture = require('../../globals'),
@@ -103,60 +100,16 @@ describe('ImportTorch', function () {
});
var runTest = function(name, done) {
var manager = new PluginCliManager(null, logger, gmeConfig),
pluginConfig = {},
context = {
namespace: 'nn',
project: project,
branchName: 'test',
activeNode: ''
},
data = fs.readFileSync(path.join(TEST_CASE_DIR, name), 'utf8'),
var data = fs.readFileSync(path.join(TEST_CASE_DIR, name), 'utf8'),
ymlFile = path.join(YAML_DIR, name.replace(/lua$/, 'yml')),
yml = fs.readFileSync(ymlFile, 'utf8'),
initModels;
yml = fs.readFileSync(ymlFile, 'utf8');
// Load the children from the head of the 'test' branch
project.getBranchHash('test')
.then(function (branchHash) {
return Q.ninvoke(project, 'loadObject', branchHash);
})
.then(function (commitObject) {
return Q.ninvoke(core, 'loadRoot', commitObject.root);
})
.then(function (root) {
return core.loadChildren(root);
})
.then(children => {
initModels = children.map(core.getPath);
return blobClient.putFile(name, data); // upload the file
})
.then(hash => {
pluginConfig.srcHash = hash;
return Q.nfcall(
manager.executePlugin.bind(manager),
pluginName,
pluginConfig,
context
);
})
.then(pluginResult => {
expect(typeof pluginResult).to.equal('object');
expect(pluginResult.success).to.equal(true);
return project.getBranchHash('test');
})
importTorch(name, data)
// Use the check-model object to check the result models!
.then(function (branchHash) {
return Q.ninvoke(project, 'loadObject', branchHash);
})
.then(function (commitObject) {
return Q.ninvoke(core, 'loadRoot', commitObject.root);
})
.then(function (root) {
return core.loadChildren(root);
})
.then(children => {
var newModel = children.find(model =>
.then(groups => {
var children = groups[1],
initModels = groups[0],
newModel = children.find(model =>
initModels.indexOf(core.getPath(model)) === -1);
expect(initModels.length+1).to.equal(children.length);
@@ -176,6 +129,56 @@ describe('ImportTorch', function () {
.nodeify(done);
};
var importTorch = function(name, code) {
var manager = new PluginCliManager(null, logger, gmeConfig),
pluginConfig = {},
context = {
namespace: 'nn',
project: project,
branchName: 'test',
activeNode: ''
},
initModels;
// Load the children from the head of the 'test' branch
return project.getBranchHash('test')
.then(function (branchHash) {
return Q.ninvoke(project, 'loadObject', branchHash);
})
.then(function (commitObject) {
return Q.ninvoke(core, 'loadRoot', commitObject.root);
})
.then(function (root) {
return core.loadChildren(root);
})
.then(children => {
initModels = children.map(core.getPath);
return blobClient.putFile(name, code); // upload the file
})
.then(hash => {
pluginConfig.srcHash = hash;
return Q.nfcall(
manager.executePlugin.bind(manager),
pluginName,
pluginConfig,
context
);
})
.then(pluginResult => {
expect(typeof pluginResult).to.equal('object');
expect(pluginResult.success).to.equal(true);
return project.getBranchHash('test');
})
.then(function (branchHash) {
return Q.ninvoke(project, 'loadObject', branchHash);
})
.then(function (commitObject) {
return Q.ninvoke(core, 'loadRoot', commitObject.root);
})
.then(root => core.loadChildren(root))
.then(children => [initModels, children]);
};
describe('run test cases', function() {
var cases = fs.readdirSync(TEST_CASE_DIR)
.filter(name => path.extname(name) === '.lua')
@@ -186,4 +189,9 @@ describe('ImportTorch', function () {
// one test for each test name
cases.forEach(name => it(`should run test "${name}"`, runTest.bind(this, name)));
});
it('should support "require \'rnn\'"', function(done) {
importTorch('test', 'require \'nn\'\nrequire \'rnn\'')
.nodeify(done);
});
});
+2 -1
Ver Arquivo
@@ -1,4 +1,5 @@
require 'nn'
require 'rnn'
local net = nn.Sequential()
net:add(nn.SpatialConvolution(3, 64, 7, 7, 2, 2, 3, 3))
@@ -338,4 +339,4 @@ concat_183:add(net_31)
net:add(concat_183)
return net
return net
+2 -1
Ver Arquivo
@@ -1,4 +1,5 @@
require 'nn'
require 'rnn'
local net = nn.Sequential()
net:add(nn.SpatialConvolution(3, 96, 11, 11, 4, 4, 0))
@@ -24,4 +25,4 @@ net:add(nn.Threshold(0, 0.000001))
net:add(nn.Linear(4096, 7))
net:add(nn.LogSoftMax())
return net
return net
+29
Ver Arquivo
@@ -1,4 +1,32 @@
{
"Recurrent": [
"BiSequencer",
"BiSequencerLM",
"GRU",
"MaskZero",
"MaskZeroCriterion",
"Recurrence",
"Recurrent",
"RecurrentAttention",
"Recursor",
"Repeater",
"RepeaterCriterion",
"Sequencer",
"SequencerCriterion",
"TrimZero",
"CopyGrad",
"FastLSTM",
"LSTM",
"LookupTableMaskZero",
"NormStabilizer",
"SAdd",
"SeqBRNN",
"SeqGRU",
"SeqLSTM",
"SeqLSTMP",
"SeqReverseSequence"
],
"Convolution": [
"TemporalConvolution",
"TemporalMaxPooling",
@@ -65,6 +93,7 @@
],
"Simple": [
"Linear",
"LinearNoBias",
"SparseLinear",
"Dropout",
"Abs",
+37 -7
Ver Arquivo
@@ -7,20 +7,29 @@ var fs = require('fs'),
skipLayerList = require('./skipLayers.json'),
categories = require('./categories.json'),
SKIP_ARGS = require('./skipArgs.json'),
catNames = Object.keys(categories),
exists = require('exists-file'),
configDir = process.env.HOME + '/.deepforge/',
configPath = configDir + 'config.json',
layerToCategory = {},
outputName = 'nn',
outputDst = 'src/plugins/CreateTorchMeta/schemas/',
config;
// Check the deepforge config
if (process.argv[2]) {
outputName = process.argv[2];
}
// Find the given package in the torch installation
torchPath = process.env.HOME + '/torch';
if (exists.sync(configPath)) {
if (exists.sync(configPath)) { // Check the deepforge config
config = JSON.parse(fs.readFileSync(configPath, 'utf8'));
torchPath = (config.torch && config.torch.dir) || (configDir + 'torch');
}
torchPath += '/extra/nn/';
torchPath += `/install/share/lua/5.1/${outputName}/`;
console.log(`parsing ${outputName} from ${torchPath}`);
skipLayerList.forEach(name => SKIP_LAYERS[name] = true);
catNames.forEach(cat => // create layer -> category dictionary
@@ -40,7 +49,6 @@ fs.readdir(torchPath, function(err,files){
layerByName = {};
layers = files.filter(filename => path.extname(filename) === '.lua')
//.filter(filename => filename === 'SpatialAveragePooling.lua')
.map(filename => fs.readFileSync(torchPath + filename, 'utf8'))
.map(code => LayerParser.parse(code))
.filter(layer => !!layer && layer.name);
@@ -53,17 +61,39 @@ fs.readdir(torchPath, function(err,files){
// handle inheritance
layers.forEach(layer => {
var iter = layer,
params = layer.params;
params = layer.params,
unsupArgs = SKIP_ARGS[layer.name],
i;
while (iter && params === undefined) {
params = iter.params;
iter = layerByName[iter.baseType];
}
// Remove any unsupported (optional) args
if (unsupArgs) {
for (var k = params.length; k--;) {
i = unsupArgs.indexOf(params[k]);
if (i !== -1) {
// eslint-disable-next-line no-console
console.log(`Removing "${params[k]}" param from ${layer.name}`);
params = params.splice(0, k);
}
}
}
layer.params = params;
});
layers = layers.filter(layer => !SKIP_LAYERS[layer.name]);
outputDst += outputName + '.json';
// eslint-disable-next-line no-console
console.log('Saved nn interface to src/common/layers.json');
fs.writeFileSync('src/common/layers.json', JSON.stringify(layers, null, 2));
console.log('Saved nn interface to ' + outputDst);
fs.writeFileSync(outputDst, JSON.stringify(layers, null, 2));
// Update the CreateTorchMeta index
var updateSchemas = `${__dirname}/../src/plugins/CreateTorchMeta/update-schemas.js`,
job = require('child_process').fork(updateSchemas);
job.on('close', code => {
process.exit(code);
});
});
+5
Ver Arquivo
@@ -0,0 +1,5 @@
{
"SeqBRNN": [
"merge"
]
}
+20
Ver Arquivo
@@ -1,4 +1,24 @@
[
"AbstractRecurrent",
"AbstractSequencer",
"BiSequencer",
"BiSequencerLM",
"GRU",
"MaskZero",
"MaskZeroCriterion",
"Recurrence",
"Recurrent",
"RecurrentAttention",
"Recursor",
"Repeater",
"RepeaterCriterion",
"Sequencer",
"SequencerCriterion",
"TrimZero",
"Bottle",
"GPU",
"Sequential",
"Container",
"Criterion",