Fixes #5 Added basic arch generation

WIP #5 Added basic arch generation

WIP #5 Added tests for GenerateArchitecture

WIP #5. Added more transfer function support and test

WIP #5 Added "Add" layer support
Esse commit está contido em:
Brian Broll
2016-03-21 17:06:41 -05:00
commit d1a2c1ce80
20 arquivos alterados com 433 adições e 3 exclusões
+5
Ver Arquivo
@@ -1,3 +1,4 @@
// jshint node: true
'use strict';
var config = require('./config.base'),
@@ -7,5 +8,9 @@ var config = require('./config.base'),
// config.server.port = 8080;
// config.mongo.uri = 'mongodb://127.0.0.1:27017/webgme_my_app';
// Seeds for development are prefixed with 'dev'
config.seedProjects.basePaths = config.seedProjects.basePaths
.filter(path => path.indexOf('dev') === -1);
validateConfig(config);
module.exports = config;
+3
Ver Arquivo
@@ -9,9 +9,11 @@ var config = require('webgme/config/config.default'),
// The paths can be loaded from the webgme-setup.json
config.plugin.basePaths.push('src/plugins');
config.plugin.basePaths.push('node_modules/webgme-simple-nodes/src/plugins');
config.visualization.layout.basePaths.push('node_modules/webgme-chflayout/src/layouts');
config.seedProjects.basePaths.push('src/seeds/DevMinimal');
config.seedProjects.basePaths.push('src/seeds/nn');
config.seedProjects.basePaths.push('src/seeds/devTests');
@@ -28,6 +30,7 @@ config.requirejsPaths = {
'BreadcrumbHeader': 'panels/BreadcrumbHeader/BreadcrumbHeaderPanel',
'FloatingActionButton': 'panels/FloatingActionButton/FloatingActionButtonPanel',
'CHFLayout': 'node_modules/webgme-chflayout/src/layouts/CHFLayout',
'TemplateCreator': 'node_modules/webgme-simple-nodes/src/plugins/TemplateCreator',
'panels': './src/visualizers/panels',
'widgets': './src/visualizers/widgets',
'panels/AutoViz': './node_modules/webgme-autoviz/src/visualizers/panels/AutoViz',
+3 -1
Ver Arquivo
@@ -11,9 +11,11 @@
"webgme-autoviz": "dfst/webgme-autoviz",
"webgme-breadcrumbheader": "dfst/webgme-breadcrumbheader",
"webgme-chflayout": "^1.0.0",
"webgme-fab": "dfst/webgme-fab"
"webgme-fab": "dfst/webgme-fab",
"webgme-simple-nodes": "brollb/webgme-simple-nodes"
},
"devDependencies": {
"jszip": "^2.5.0",
"mocha": "^2.2.5",
"rimraf": "^2.4.0",
"chai": "^3.0.0"
+60
Ver Arquivo
@@ -0,0 +1,60 @@
define([
], function(
) {
'use strict';
var dimensionality = function(type, attr, prev) {
if (!dimensionality[type]) {
// This will be tricky with custom layers...
// TODO
throw 'Cannot determine dimensionality of ' + type;
}
return dimensionality[type](type, attr, prev);
};
// Currently, this is done by the meta type of the given layer.
// It would probably be more extensible to have "types of types" or,
// rather, an enumeration of dimensionality calculation techniques
// that the layer's meta type registers to FIXME
dimensionality.Reshape = function(type, attr, prev) {
return attr.dimensions || 1;
};
dimensionality.Linear = function(type, attr, prev) {
return attr.output || 1;
};
var PassThru = function(type, attr, prev) {
if (!prev) {
throw 'Cannot determine prev args of ' + type;
}
return dimensionality.apply(null, prev());
};
[ // pass through layers -> same dim as predecessor
'HardTanh',
'HardShrink',
'SoftShrink',
'SoftMax',
'SoftMin',
'SoftPlus',
'SoftSign',
'LogSigmoid',
'LogSoftMax',
'Sigmoid',
'Tanh',
'ReLU',
'PReLU',
'RReLU',
'LeakyReLU',
'AddConstant',
'MulConstant',
// Math
'Mul', // Does this really leave the size the same?
'CMul',
'Add'
].forEach(layer => dimensionality[layer] = PassThru);
return dimensionality;
});
+42
Ver Arquivo
@@ -0,0 +1,42 @@
define([
'text!deepforge/layers.yml',
'deepforge/js-yaml.min'
], function(
LAYER_TEXT,
yaml
) {
'use strict';
// default arg
var cleanArg = function(arg) {
var name = Object.keys(arg)[0],
result = arg[name] || {type: 'integer'};
result.name = name;
return result;
};
// Create the layer dictionary
var LayerDict = {},
layerObj = yaml.load(LAYER_TEXT),
absLayers = Object.keys(layerObj),
layer,
layers;
// Basically, create a dictionary of the second level of keys
for (var i = absLayers.length; i--;) {
layers = layerObj[absLayers[i]];
for (var j = layers.length; j--;) {
layer = layers[j];
if (typeof layer === 'string') {
LayerDict[layers[j]] = [];
} else {
layer = Object.keys(layer)[0];
LayerDict[layer] = (layers[j][layer] || [])
.map(cleanArg);
}
}
}
return LayerDict;
});
+2
Ver Arquivo
@@ -1,3 +1,5 @@
# This file should actually be an alternative way of viewing the metamodel.
#
# This contains metadata about the Torch nn library used for
# creating the metamodel
#
+1 -1
Ver Arquivo
@@ -8,7 +8,7 @@
define([
'plugin/PluginConfig',
'plugin/PluginBase',
'./js-yaml.min.js',
'deepforge/js-yaml.min',
'text!deepforge/layers.yml'
], function (
PluginConfig,
@@ -0,0 +1,131 @@
/*globals define*/
/*jshint node:true, browser:true*/
/**
* Generated by PluginGenerator 0.14.0 from webgme on Sun Mar 20 2016 16:49:12 GMT-0500 (CDT).
*/
define([
'TemplateCreator/TemplateCreator',
//'TemplateCreator/Constants'
'TemplateCreator/templates/Constants',
'deepforge/layer-args',
'deepforge/dimensionality',
'underscore'
], function (
PluginBase,
Constants,
LAYER_ARGS,
dimensionality,
_
) {
'use strict';
/**
* Initializes a new instance of GenerateArchitecture.
* @class
* @augments {PluginBase}
* @classdesc This class represents the plugin GenerateArchitecture.
* @constructor
*/
var GenerateArchitecture = function () {
// Call base class' constructor.
PluginBase.call(this);
this.generator = this;
};
// Prototypal inheritance from PluginBase.
GenerateArchitecture.prototype = Object.create(PluginBase.prototype);
GenerateArchitecture.prototype.constructor = GenerateArchitecture;
/**
* Gets the name of the GenerateArchitecture.
* @returns {string} The name of the plugin.
* @public
*/
GenerateArchitecture.prototype.getName = function () {
return 'GenerateArchitecture';
};
/**
* Gets the semantic version (semver.org) of the GenerateArchitecture.
* @returns {string} The version of the plugin.
* @public
*/
GenerateArchitecture.prototype.getVersion = function () {
return '0.1.0';
};
GenerateArchitecture.prototype.createOutputFiles = function (tree) {
var layers = tree[Constants.CHILDREN],
result = {},
template,
snippet,
code,
args;
code = [
'require \'nn\'',
'',
'model = nn.Sequential()'
].join('\n');
// Start with sequential (just one input)
for (var i = 0; i < layers.length; i++) {
if (layers[i][Constants.NEXT].length > 1) {
// no support for
console.error('No support for parallel layers... yet');
break;
} else {
// args
args = GenerateArchitecture.createArgString(layers[i]);
template = _.template('model:add(nn.{{= name }}' + args + ')');
snippet = template(layers[i]);
code += '\n' + snippet;
}
}
result[tree.name + '.lua'] = code;
return result;
};
GenerateArchitecture.createArgString = function (layer) {
if (CreateLayerArgs[layer.name]) {
return '(' + CreateLayerArgs[layer.name](layer).join(', ') + ')';
}
// fall back on default...
return '(' + LAYER_ARGS[layer.name].map(arg => layer[arg.name]) + ')';
};
GenerateArchitecture.getDimArgs = function (layer) {
var prev = layer[Constants.PREV][0], // Assuming all inputs have the same dims
fn = null;
// Only return getDimArgs if
if (prev[Constants.PREV][0]) {
fn = GenerateArchitecture.getDimArgs.bind(null, prev);
}
return [prev.name, prev, fn];
};
// Custom Layer Argument Generators
// These return an array of argument values
var CreateLayerArgs = {};
CreateLayerArgs.Linear = function(layer) {
var args = GenerateArchitecture.getDimArgs(layer),
dims = dimensionality.apply(null, args);
return [dims, layer.output];
};
CreateLayerArgs.Add = function(layer) {
var args = GenerateArchitecture.getDimArgs(layer),
dims = dimensionality.apply(null, args);
return [dims, layer.isScalar];
};
return GenerateArchitecture;
});
Arquivo binário não exibido.
@@ -0,0 +1,137 @@
/*jshint node:true, mocha:true*/
/**
* Generated by PluginGenerator 0.14.0 from webgme on Sun Mar 20 2016 16:49:12 GMT-0500 (CDT).
*/
'use strict';
var testFixture = require('../../globals'),
path = testFixture.path,
jszip = require('jszip'),
fs = require('fs'),
TEST_CASE_DIR = path.join(__dirname, '..', 'test-cases'),
SEED_DIR = path.join(testFixture.SEED_DIR, '..', 'devTests');
describe('GenerateArchitecture', function () {
var gmeConfig = testFixture.getGmeConfig(),
expect = testFixture.expect,
logger = testFixture.logger.fork('GenerateArchitecture'),
PluginCliManager = testFixture.WebGME.PluginCliManager,
BlobClient = require('webgme/src/server/middleware/blob/BlobClientWithFSBackend'),
projectName = 'testProject',
pluginName = 'GenerateArchitecture',
project,
gmeAuth,
storage,
commitHash;
before(function (done) {
testFixture.clearDBAndGetGMEAuth(gmeConfig, projectName)
.then(function (gmeAuth_) {
gmeAuth = gmeAuth_;
// This uses in memory storage. Use testFixture.getMongoStorage to persist test to database.
storage = testFixture.getMemoryStorage(logger, gmeConfig, gmeAuth);
return storage.openDatabase();
})
.then(function () {
var importParam = {
projectSeed: testFixture.path.join(SEED_DIR, 'devTests.zip'),
projectName: projectName,
branchName: 'master',
logger: logger,
gmeConfig: gmeConfig
};
return testFixture.importProject(storage, importParam);
})
.then(function (importResult) {
project = importResult.project;
commitHash = importResult.commitHash;
return project.createBranch('test', commitHash);
})
.nodeify(done);
});
after(function (done) {
storage.closeDatabase()
.then(function () {
return gmeAuth.unload();
})
.nodeify(done);
});
it('should run plugin and not update the branch', function (done) {
var manager = new PluginCliManager(null, logger, gmeConfig),
pluginConfig = {
},
context = {
project: project,
commitHash: commitHash,
branchName: 'test',
activeNode: '/960660211',
};
manager.executePlugin(pluginName, pluginConfig, context, function (err, pluginResult) {
expect(err).to.equal(null);
expect(typeof pluginResult).to.equal('object');
expect(pluginResult.success).to.equal(true);
project.getBranchHash('test')
.then(function (branchHash) {
expect(branchHash).to.equal(commitHash);
})
.nodeify(done);
});
});
describe('test cases', function() {
var cases = [
['/R', 'basic.lua'],
['/e', 'basic-transfers.lua']
];
var runTest = function(pair, done) {
var id = pair[0],
name = pair[1],
manager = new PluginCliManager(null, logger, gmeConfig),
pluginConfig = {
},
context = {
project: project,
commitHash: commitHash,
branchName: 'test',
activeNode: id,
},
expected = fs.readFileSync(path.join(TEST_CASE_DIR, name), 'utf8');
manager.executePlugin(pluginName, pluginConfig, context, function (err, pluginResult) {
var codeHash = pluginResult.artifacts[0];
expect(err).to.equal(null);
expect(typeof pluginResult).to.equal('object');
expect(pluginResult.success).to.equal(true);
// Retrieve the code from the blob and check it!
var blobClient = new BlobClient(gmeConfig, logger);
blobClient.getObject(codeHash, (err, obj) => {
// Unzip first...
var zip = new jszip(),
filename,
actual;
zip.load(obj);
filename = Object.keys(zip.files)
.filter(name => name.indexOf('.lua') > -1)[0];
actual = zip.files[filename].asText();
expect(actual).to.equal(expected);
done();
});
});
};
cases.forEach(pair => {
it(`should correctly evaluate ${pair[0]} (${pair[1]})`,
runTest.bind(this, pair));
});
});
});
+14
Ver Arquivo
@@ -0,0 +1,14 @@
require 'nn'
model = nn.Sequential()
model:add(nn.Reshape(100))
model:add(nn.Linear(100, 300))
model:add(nn.RReLU())
model:add(nn.Linear(300, 100))
model:add(nn.ReLU())
model:add(nn.Linear(100, 100))
model:add(nn.Sigmoid())
model:add(nn.Linear(100, 120))
model:add(nn.LeakyReLU())
model:add(nn.Linear(120, 5))
model:add(nn.SoftMax())
+7
Ver Arquivo
@@ -0,0 +1,7 @@
require 'nn'
model = nn.Sequential()
model:add(nn.Reshape(100))
model:add(nn.Linear(100, 300))
model:add(nn.HardTanh())
model:add(nn.Linear(300, 10))
+6
Ver Arquivo
@@ -0,0 +1,6 @@
# Misc Thoughts
+ Should I create this yaml stuff by hand?
+ For now, I think it is a good idea - I can't get things like default values programmatically
+ In the future, it should be an alternative representation for the meta sheet
+ May need to rethink the syntax a little... Currently, it mixes inheritance and attributes...
+ A node with attributes cannot be subclassed
+9
Ver Arquivo
@@ -0,0 +1,9 @@
-- Script to validate a layers.yml file
require 'nn'
-- Parse the yaml file
-- TODO
-- Check that the number of arguments are correct for each constructor
-- TODO
+13 -1
Ver Arquivo
@@ -8,6 +8,10 @@
"CreateTorchMeta": {
"src": "src/plugins/CreateTorchMeta",
"test": "test/plugins/CreateTorchMeta"
},
"GenerateArchitecture": {
"src": "src/plugins/GenerateArchitecture",
"test": "test/plugins/GenerateArchitecture"
}
},
"layouts": {},
@@ -20,11 +24,19 @@
},
"nn": {
"src": "src/seeds/nn"
},
"devTests": {
"src": "src/seeds/devTests"
}
}
},
"dependencies": {
"plugins": {},
"plugins": {
"TemplateCreator": {
"project": "webgme-simple-nodes",
"path": "node_modules/webgme-simple-nodes/src/plugins/TemplateCreator"
}
},
"layouts": {
"CHFLayout": {
"project": "webgme-chflayout",