Comparar commits

...

6 Commits

Autor SHA1 Mensagem Data
Brian Broll 708ef3f48a WIP Minor changes for GenArch 2016-11-26 14:26:07 -06:00
Brian Broll 4b14c74733 WIP Changed code content to use python 2016-11-26 12:11:14 -06:00
Brian Broll cf6e6dd4e5 WIP Updated code editors to use python for comments and content 2016-11-26 12:10:22 -06:00
Brian Broll 88a57a5af9 WIP Fixed the boolean values to use True/False 2016-11-26 11:42:30 -06:00
Brian Broll d30d990330 WIP Updated nn-parser and nn for pytorch layers 2016-11-26 11:12:02 -06:00
Brian Broll 96720c3140 WIP Added some basic layer parsing support
name, baseType, defaults, and types (for the defaults). Still need
to verify named args work though (and more of the types work)...
2016-11-25 16:22:53 -06:00
22 arquivos alterados com 1164 adições e 2792 exclusões
+1
Ver Arquivo
@@ -34,6 +34,7 @@
"webgme-simple-nodes": "^2.1.0"
},
"devDependencies": {
"brython": "^3.2.7",
"chai": "^3.0.0",
"jszip": "^2.5.0",
"mocha": "^2.2.5",
+97 -348
Ver Arquivo
@@ -1,378 +1,127 @@
/* globals define*/
(function(root, factory){
if(typeof define === 'function' && define.amd) {
define(['./lua'], function(luajs){
return (root.LayerParser = factory(luajs));
// TODO: Load the brython script
define(['./lua'], function(brython){
return (root.LayerParser = factory(brython, console.assert));
});
} else if(typeof module === 'object' && module.exports) {
var luajs = require('./lua');
module.exports = (root.LayerParser = factory(luajs));
var brython = require('./node-brython'),
assert = require('assert');
module.exports = (root.LayerParser = factory(brython, assert));
}
}(this, function(luajs) {
}(this, function(brython, assert) {
var LayerParser = {};
//////////////////////// Setters ////////////////////////
var returnsSelf = function(fnNode){
var stats = fnNode.block.stats,
last = stats[stats.length-1];
function build_ast(src) {
brython.$py_module_path['__main__']='./'
return brython.py2js(src,'__main__', '__main__', '__builtins__')
}
if (last.type === 'stat.return') {
return last.nret[0].type === 'variable' && last.nret[0].val === 'self';
}
return false;
};
var isAttrSetter = function(node){
if (node.type === 'stat.assignment' && node.lefts.length === 1) {
var left = node.lefts[0];
return left.type === 'expr.index' && left.self.val === 'self';
}
return false;
};
var getSettingAttrName = function(node){
if (isAttrSetter(node)) {
var left = node.lefts[0];
return left.key.val;
}
return null;
};
var getSettingAttrValue = function(node){
if (isAttrSetter(node)) {
return node.right;
}
return null;
};
var isSetterMethod = function(curr, parent, className){
if (parent && parent.type === 'stat.method') {
// is it a fn w/ two statements (stats)
if (parent.self.val === className && curr.type === 'function' &&
curr.block.stats.length === 2) {
// Is the first statement setting a value?
return returnsSelf(curr) && getSettingAttrName(curr.block.stats[0]); // does it return itself?
// The provided tree gives us contexts which can have associated 'C'
function traverse(node, fn) {
var i;
if (node.children) {
for (i = node.children.length; i--;) {
traverse(node.children[i], fn);
fn(node.children[i]);
}
}
return false;
};
var isFnArg = function(method, name) {
return method.args.indexOf(name) !== -1;
};
var getSetterSchema = function(node, method) {
var setterType,
setterFn,
value = getSettingAttrValue(node);
if (value[0].type === 'variable' && isFnArg(method.func, value[0].val)) {
setterType = 'arg';
setterFn = method.key.val;
} else {
setterType = 'const';
setterFn = {};
setterFn[value[0].val] = method.key.val;
}
return {
setterType,
setterFn
};
};
//////////////////////// Setters END ////////////////////////
var isInitFn = function(node, className) {
if (node.type === 'stat.method' && node.self.val === className) {
return node.key.val === '__init';
}
return false;
};
var getClassAttrDefs = function(method) {
var fn = method.func,
dict = {},
attr,
right,
value;
luajs.codegen.traverse(curr => {
if (isAttrSetter(curr)) {
// Store the value if it is set to a constant
attr = curr.lefts[0].key.val;
right = curr.right[0];
if (right.type.indexOf('const.') !== -1) {
value = right.val;
if (right.type === 'const.nil') {
value = null;
}
dict[attr] = value;
}
}
})(fn);
return dict;
};
var getAttrsAndVals = function(method) {
// Given a method, get the 'self' attributes and the default values
var fn = method.func,
dict = {},
varName,
value,
varUsageCnt = {};
// Get the variables that are used only once (or updating themselves)
luajs.codegen.traverse(curr => {
if (curr.type === 'variable') {
varUsageCnt[curr.val] = varUsageCnt[curr.val] ?
varUsageCnt[curr.val] + 1 : 1;
}
})(method);
luajs.codegen.traverse(curr => {
// If the variable is only used once and is 'or'-ed w/ a constant
// during this use, we can infer that this is the default value
if (curr.type === 'expr.op' && curr.op === 'op.or' &&
curr.left.type === 'variable' && curr.right.type.indexOf('const') !== -1) {
varName = curr.left.val;
if (varUsageCnt[varName] === 1) {
value = curr.right.type === 'const.nil' ? null : curr.right;
dict[varName] = value;
}
}
})(fn);
return dict;
};
var copyNodeValues = function(attrs, from, to) {
var value;
for (var i = attrs.length; i--;) {
value = from[attrs[i]] || null;
if (value) {
value = (value && value.hasOwnProperty('val')) ? value.val : value;
to[attrs[i]] = value;
if (node.C && node.C.tree) {
for (i = node.C.tree.length; i--;) {
traverse(node.C.tree[i], fn);
fn(node.C.tree[i]);
}
}
return to;
};
}
var getTypeCheckInfo = function(cond) {
var caller,
method,
target,
expType;
var types = {},
layers = [],
pCtx,
classNode,
params;
// Check for torch.isTypeOf:
if (cond.type === 'expr.call' && cond.func.type === 'expr.index') {
caller = cond.func.self.val;
method = cond.func.key.val;
function isClass(node) {
return node.type === 'class';
}
if (cond.type === 'expr.call' && caller === 'torch') {
target = cond.args[0].val;
if (method === 'isTypeOf' && target) {
expType = cond.args[1].val;
return {
target,
type: expType
};
}
}
} else if (cond.type === 'expr.op') { // torch.type() === ''
// Check right side, too!
var sides = [cond.left, cond.right],
side,
otherSide;
function isInitFn(node) {
return node.type === 'def' && node.name === '__init__';
}
for (var i = sides.length; i--;) {
side = sides[i];
otherSide = sides[(i+1)%2];
if (side.type === 'expr.call' && side.func.type === 'expr.index') {
// Is it torch?
caller = side.func.self.val;
method = side.func.key.val;
if (caller === 'torch' && method === 'type') {
if (side.args[0].type === 'variable') {
target = side.args[0].val;
if (otherSide.type === 'const.string') {
expType = otherSide.val;
function getBaseClass(node) {
assert(node.type === 'class');
return node.args.tree[0].tree[0].tree[0].value;
}
return {
target: target,
type: expType
};
function findTorchLayers(root) {
var defaults = {},
layers = [],
defTypes,
args,
def;
traverse(root, node => {
// Get the class for the given function
if (isInitFn(node)) {
// TODO: What if there is no constructor? Is this a potential problem?
pCtx = node.parent.node.parent;
classNode = pCtx.C.tree[0];
if (isClass(classNode)) {
// remove the 'self' variable
// TODO: May need to update this for kwargs
// (use positional_list)
args = node.tree[1].tree;
defaults = {};
params = node.args.slice(1);
defTypes = {};
for (var i = args.length; i--;) {
if (args[i].tree[0]) {
def = args[i].tree[0].tree[0];
if (def.type === 'int') {
defaults[params[i-1]] = parseInt.apply(null, def.value.reverse());
} else {
defaults[params[i-1]] = def.value;
}
if (/^(True|False)$/.test(defaults[params[i-1]])) {
defTypes[params[i-1]] = 'boolean';
} else {
defTypes[params[i-1]] = def.type;
}
}
}
layers.push({
name: classNode.name,
baseType: getBaseClass(classNode),
//doc: classNode.doc_string || '',
defaults: defaults,
types: defTypes,
setters: {},
params: params
});
}
}
return null;
}
};
var isError = function(stat) {
var fn;
if (stat.type === 'stat.expr' && stat.expr.type === 'expr.call') {
fn = stat.expr.func.val;
return fn === 'error';
}
return false;
};
var inferParamTypes = function(node, paramDefs) {
var types = {},
check,
cond;
// Infer from assertions
luajs.codegen.traverse(curr => {
// check for 'assert's that check type
if (curr.type === 'expr.call' && curr.func.val === 'assert') {
cond = curr.args[0];
check = getTypeCheckInfo(cond);
if (check) {
types[check.target] = check.type;
}
} else if (curr.type === 'stat.if' && curr.cond.op === 'uop.not') {
// if statements throwing errors on type mismatch
cond = curr.cond.operand; // non-negated version
// Check that it throws an error on true
if (curr.tblock.stats.some(isError)) {
check = getTypeCheckInfo(cond);
if (check) {
types[check.target] = check.type;
}
}
}
})(node);
// Infer from defaults
Object.keys(paramDefs).forEach(param => {
var val = paramDefs[param];
if (val) { // initialized to 'null' doesn't help us...
types[param] = val.type.replace('const.', '');
}
});
return types;
};
return layers;
}
var findTorchClass = function(ast){
var torchClassArgs, // args for `torch.class(...)`
name = '',
alias,
baseType,
params,
setters = {},
defaults = {},
paramDefs,
attrDefs;
if(ast.type == 'function'){
ast.block.stats.forEach(function(func){
if(func.type == 'stat.local' && func.right && func.right[0] &&
func.right[0].func && func.right[0].func.self &&
func.right[0].func.self.val == 'torch' &&
func.right[0].func.key.val == 'class'){
torchClassArgs = func.right[0].args.map(arg => arg.val);
name = torchClassArgs[0];
if(name !== ''){
name = name.replace('nn.', '');
alias = func.names[0] || name;
if (torchClassArgs.length > 1) {
baseType = torchClassArgs[1].replace('nn.', '');
}
}
}
});
}
// Get the setters, defaults and type info (inferred)
var setterNames,
schema,
types,
values;
luajs.codegen.traverse((curr, parent) => {
var firstLine,
attrName;
// Record the setter functions
if (isSetterMethod(curr, parent, alias)) {
firstLine = curr.block.stats[0];
// just use the attribute attrName for now...
attrName = getSettingAttrName(firstLine);
// merge schemas
schema = getSetterSchema(firstLine, parent);
if (setters[attrName] && setters[attrName].setterType === 'const') { // merge
for (var val in schema.setterFn) {
setters[attrName].setterFn[val] = schema.setterFn[val];
}
} else {
setters[attrName] = schema;
}
} else if (isInitFn(curr, alias)) { // Record the defaults
paramDefs = getAttrsAndVals(curr);
attrDefs = getClassAttrDefs(curr);
types = inferParamTypes(curr, paramDefs);
// get ctor args
params = curr.func.args;
if(params.length === 0 && curr.func.varargs){
params.push('params');
}
}
})(ast);
// Get the defaults for the params from defs
if (paramDefs && params) {
copyNodeValues(params, paramDefs, defaults);
}
// Get the defaults for the setters from attrDefs
if (attrDefs) {
setterNames = Object.keys(setters);
copyNodeValues(setterNames, attrDefs, defaults);
}
// Remove any const setters w/ only one value and no default
setterNames = Object.keys(setters);
for (var i = setterNames.length; i--;) {
schema = setters[setterNames[i]];
if (schema.setterType === 'const') {
values = Object.keys(schema.setterFn);
if (values.length === 1 &&
// boolean setters can have the default value inferred
values[0] !== 'true' && values[0] !== 'false' &&
!defaults[setterNames[i]]) {
delete setters[setterNames[i]];
}
}
}
return {
name,
baseType,
params,
setters,
types,
defaults
};
};
LayerParser.parse = function(text) {
// Try to find the class definitions...
//
// Need to create:
//
// setters: (I don't think these are used in pytorch!
// types:
// type:
//////////////////////// Setters ////////////////////////
LayerParser.parse = function(src) {
try {
var ast = luajs.parser.parse(text);
return findTorchClass(ast);
brython.$py_module_path['__main__']='./';
var ast = brython.py2js(src,'__main__', '__main__', '__builtins__');
var layers = findTorchLayers(ast);
return layers;
} catch (e) {
return null;
}
+176
Ver Arquivo
@@ -0,0 +1,176 @@
/*
Author: Billy Earney
Date: 04/19/2013
License: MIT
Description: This file can work as a "bridge" between nodejs and brython
so that client side brython code can be executed on the server side.
Will brython replace Cython one day? Only time will tell.
:)
*/
var fs = require('fs'),
path = require('path'),
brythonSrcPath = path.join(__dirname, '..', '..', 'node_modules', 'brython', 'www', 'src', 'brython.js');
document={};
document.getElementsByTagName = () => [{src: ''}];
window={};
window.location = {href: ''};
window.navigator={}
window.confirm = () => true;
window.console = console;
document.$py_src = {}
document.$debug = 0
self={};
__BRYTHON__={}
__BRYTHON__.$py_module_path = {}
__BRYTHON__.$py_module_alias = {}
__BRYTHON__.$py_next_hash = -Math.pow(2,53)
__BRYTHON__.exception_stack = []
__BRYTHON__.scope = {}
__BRYTHON__.modules = {}
// Read and eval library
jscode = fs.readFileSync(brythonSrcPath, 'utf8');
eval(jscode);
//function node_import(module,alias,names) {
function $import_single(module) {
var search_path=['../src/libs', '../src/Lib'];
var ext=['.js', '.py'];
var mods=[module, module+'/__init__'];
for(var i=0, _len_i = search_path.length; i < _len_i; i++) {
for (var j=0, _len_j = ext.length; j < _len_j; j++) {
for (var k=0, _len_k = mods.length; k < _len_k; k++) {
var path=search_path[i]+'/'+mods[k]+ext[j]
//console.log("searching for " + path);
var module_contents;
try {
module_contents=fs.readFileSync(path, 'utf8')
} catch(err) {}
if (module_contents !== undefined) {
console.log("imported " + module)
//console.log(module_contents);
if (ext[j] == '.js') {
return $import_js_module(module,alias,names,path,module_contents)
}
return $import_py_module(module,alias,names,path,module_contents)
}
}
}
}
console.log("error time!");
res = Error()
res.name = 'NotFoundError'
res.message = "No module named '"+module+"'"
throw res
}
$compile_python=function(module_contents,module) {
var root = __BRYTHON__.py2js(module_contents,module)
var body = root.children
root.children = []
// use the module pattern : module name returns the results of an anonymous function
var mod_node = new $Node('expression')
//if(names!==undefined){alias='$module'}
new $NodeJSCtx(mod_node,'$module=(function()')
root.insert(0,mod_node)
mod_node.children = body
// search for module-level names : functions, classes and variables
var mod_names = []
for(var i=0, _len_i = mod_node.children.length; i < _len_i;i++){
var node = mod_node.children[i]
// use function get_ctx()
// because attribute 'context' is renamed by make_dist...
var ctx = node.get_ctx().tree[0]
if(ctx.type==='def'||ctx.type==='class'){
if(mod_names.indexOf(ctx.name)===-1){mod_names.push(ctx.name)}
} else if(ctx.type==='from') {
for (var j=0, _len_j = ctx.names.length; j < _len_j; j++) {
var name=ctx.names[j];
if (name === '*') {
// just pass, we don't want to include '*'
} else if (ctx.aliases[name] !== undefined) {
if (mod_names.indexOf(ctx.aliases[name])===-1){
mod_names.push(ctx.aliases[name])
}
} else {
if (mod_names.indexOf(ctx.names[j])===-1){
mod_names.push(ctx.names[j])
}
}
}
}else if(ctx.type==='assign'){
var left = ctx.tree[0]
if(left.type==='expr'&&left.tree[0].type==='id'&&left.tree[0].tree.length===0){
var id_name = left.tree[0].value
if(mod_names.indexOf(id_name)===-1){mod_names.push(id_name)}
}
}
}
// create the object that will be returned when the anonymous function is run
var ret_code = 'return {'
for(var i=0, _len_i = mod_names.length; i < _len_i;i++){
ret_code += mod_names[i]+':'+mod_names[i]+','
}
ret_code += '__getattr__:function(attr){return this[attr]},'
ret_code += '__setattr__:function(attr,value){this[attr]=value}'
ret_code += '}'
var ret_node = new $Node('expression')
new $NodeJSCtx(ret_node,ret_code)
mod_node.add(ret_node)
// add parenthesis for anonymous function execution
var ex_node = new $Node('expression')
new $NodeJSCtx(ex_node,')()')
root.add(ex_node)
try{
var js = root.to_js()
return js;
}catch(err){
eval('throw '+err.name+'(err.message)')
}
return undefined;
}
function build_ast(src) {
__BRYTHON__.$py_module_path['__main__']='./'
return __BRYTHON__.py2js(src,'__main__', '__main__', '__builtins__')
}
function execute_python_script(filename) {
_py_src=fs.readFileSync(filename, 'utf8')
var root = build_ast(_py_src)
var js = root.to_js()
//eval(js);
}
//console.log("try to execute compile script");
__BRYTHON__.$py_module_path = __BRYTHON__.$py_module_path || {}
__BRYTHON__.$py_module_alias = __BRYTHON__.$py_module_alias || {}
__BRYTHON__.exception_stack = __BRYTHON__.exception_stack || []
__BRYTHON__.scope = __BRYTHON__.scope || {}
__BRYTHON__.imported = __BRYTHON__.imported || {}
__BRYTHON__.modules = __BRYTHON__.modules || {}
__BRYTHON__.compile_python=$compile_python
__BRYTHON__.debug = 0
__BRYTHON__.$options = {}
__BRYTHON__.$options.debug = 0
// other import algs don't work in node
//import_funcs=[node_import]
if (!module.parent) {
var filename=process.argv[2];
execute_python_script(filename)
}
module.exports = __BRYTHON__;
+9 -3
Ver Arquivo
@@ -205,9 +205,10 @@ define([
};
// Some helper methods w/ attribute handling
var LUA_TO_GME = {
var PYTHON_TO_GME = {
boolean: 'boolean',
number: 'float',
float: 'float',
int: 'integer',
string: 'string'
};
@@ -301,7 +302,7 @@ define([
attrs.forEach(name => {
desc = {};
defVal = defaults.hasOwnProperty(name) ? defaults[name] : '';
type = LUA_TO_GME[types[name]];
type = PYTHON_TO_GME[types[name]];
if (type) {
desc.type = type;
}
@@ -376,6 +377,11 @@ define([
// Set the min, max
schema.max = +schema.max;
}
// Add the enum for booleans so we use python style True/False
if (schema.type === 'boolean') {
schema.enum = ['True', 'False'];
schema.type = 'string';
}
// Create the attribute and set the schema
this.core.setAttributeMeta(node, name, schema);
-1
Ver Arquivo
@@ -17,7 +17,6 @@
"value": "all",
"valueItems": [
"nn",
"rnn",
"all"
],
"valueType": "string",
+3 -6
Ver Arquivo
@@ -1,13 +1,10 @@
/*globals define*/
define([
'text!./nn.json',
'text!./rnn.json'
'text!./nn.json'
], function(
nn,
rnn
nn
) {
return {
nn: nn,
rnn: rnn
nn: nn
};
});
Diferenças do arquivo suprimidas por serem muito extensas Carregar Diff
-178
Ver Arquivo
@@ -1,178 +0,0 @@
[
{
"name": "CopyGrad",
"baseType": "Identity",
"setters": {},
"defaults": {},
"type": "RNN"
},
{
"name": "FastLSTM",
"baseType": "LSTM",
"params": [
"inputSize",
"outputSize",
"rho",
"eps",
"momentum",
"affine"
],
"setters": {},
"types": {
"eps": "number",
"momentum": "number"
},
"defaults": {
"momentum": 0.1,
"eps": 0.1
},
"type": "RNN"
},
{
"name": "LSTM",
"baseType": "AbstractRecurrent",
"params": [
"inputSize",
"outputSize",
"rho",
"cell2gate"
],
"setters": {},
"types": {
"rho": "number"
},
"defaults": {
"rho": 9999
},
"type": "RNN"
},
{
"name": "LinearNoBias",
"baseType": "Linear",
"params": [
"inputSize",
"outputSize"
],
"setters": {},
"types": {},
"defaults": {},
"type": "Simple"
},
{
"name": "LookupTableMaskZero",
"baseType": "LookupTable",
"params": [
"nIndex",
"nOutput"
],
"setters": {},
"types": {},
"defaults": {},
"type": "RNN"
},
{
"name": "NormStabilizer",
"baseType": "AbstractRecurrent",
"params": [
"beta"
],
"setters": {},
"defaults": {},
"type": "RNN"
},
{
"name": "Recurrent",
"baseType": "AbstractRecurrent",
"params": [
"start",
"input",
"feedback",
"transfer",
"rho",
"merge"
],
"setters": {},
"types": {
"start": "nn.Module",
"transfer": "nn.Module",
"feedback": "nn.Module",
"input": "nn.Module"
},
"defaults": {},
"type": "RNN"
},
{
"name": "SAdd",
"baseType": "Module",
"params": [
"addend",
"negate"
],
"setters": {},
"types": {},
"defaults": {},
"type": "RNN"
},
{
"name": "SeqBRNN",
"baseType": "Container",
"params": [
"inputDim",
"hiddenDim",
"batchFirst"
],
"setters": {},
"types": {},
"defaults": {},
"type": "RNN"
},
{
"name": "SeqGRU",
"baseType": "Module",
"params": [
"inputSize",
"outputSize"
],
"setters": {},
"types": {},
"defaults": {},
"type": "RNN"
},
{
"name": "SeqLSTM",
"baseType": "Module",
"params": [
"inputsize",
"hiddensize",
"outputsize"
],
"setters": {},
"types": {},
"defaults": {},
"type": "RNN"
},
{
"name": "SeqLSTMP",
"baseType": "SeqLSTM",
"params": [
"inputsize",
"hiddensize",
"outputsize"
],
"setters": {},
"types": {},
"defaults": {},
"type": "RNN"
},
{
"name": "SeqReverseSequence",
"baseType": "Module",
"params": [
"dim"
],
"setters": {},
"types": {},
"defaults": {},
"type": "RNN"
}
]
@@ -72,8 +72,8 @@ define([
code = '';
this.definitions = [
'require \'nn\'',
'require \'rnn\''
'import torch',
'import torch.nn as nn'
];
// Add an index to each layer
@@ -85,6 +85,7 @@ define([
code += this.genLayerDefinitions(layers);
}
// TODO: Define the network w/ 'class ARCHITECTURE_NAME'
this.logger.debug('Generating architecture code...');
code += this.genArchCode(layers);
this.logger.debug('Prepending hoisted code...');
Arquivo binário não exibido.
+1 -1
Ver Arquivo
@@ -1 +1 @@
0.5.0
1.0.3
Arquivo binário não exibido.
+1 -1
Ver Arquivo
@@ -1 +1 @@
0.4.0
0.4.1
@@ -15,7 +15,7 @@ define([
'use strict';
var NO_CODE_MESSAGE = '-- <%= name %> is not an editable layer!',
var NO_CODE_MESSAGE = '<%= name %> is not an editable layer!',
LayerEditorControl;
LayerEditorControl = function (options) {
@@ -45,10 +45,10 @@ define([
// Retrieve the template from the mixin
template = node.getMixinPaths()
.map(id => this._client.getNode(id).getAttribute('code'))
.find(code => !!code) || NO_CODE_MESSAGE;
.find(code => !!code) || this.comment(NO_CODE_MESSAGE);
}
} else {
template = NO_CODE_MESSAGE;
template = this.comment(NO_CODE_MESSAGE);
}
if (template) {
@@ -19,9 +19,7 @@ define([
_.extend(ClassCodeEditorWidget.prototype, TextEditorWidget.prototype);
ClassCodeEditorWidget.prototype.getHeader = function(desc) {
return [
`-- The class definition for ${desc.name}`
].join('\n');
return this.comment(`The class definition for ${desc.name}`);
};
ClassCodeEditorWidget.prototype.updateNode = function() {
@@ -26,13 +26,13 @@ define([
DeserializeEditorWidget.prototype.getHeader = function(desc) {
this._name = desc.name;
return [
`-- The deserialization function for ${desc.name}`,
'-- Globals:',
'-- `path` - target filename to load',
'--',
`-- return the loaded ${desc.name}`
].join('\n');
return this.comment([
`The deserialization function for ${desc.name}`,
'Globals:',
' `path` - target filename to load',
'',
`return the loaded ${desc.name}`
].join('\n'));
};
DeserializeEditorWidget.prototype.getNameRegex = function() {
@@ -32,24 +32,24 @@ define([
OperationCodeEditorWidget.prototype.getHeader = function (desc) {
// Add comment about the inputs, attributes and references
var inputs = desc.inputs.map(pair => `-- ${pair[0]} (${pair[1]})`).join('\n'),
refs = desc.references.map(name => `-- ${name}`).join('\n'),
var inputs = desc.inputs.map(pair => `${pair[0]} (${pair[1]})`).join('\n'),
refs = desc.references.map(name => `${name}`).join('\n'),
header = [
`-- Editing "${desc.name}" Implementation`
`Editing "${desc.name}" Implementation`
];
if (inputs.length) {
header.push('--');
header.push('-- Defined variables:');
header.push('');
header.push('Defined variables:');
header.push(inputs);
}
if (refs) {
header.push(refs);
}
header.push('--');
header.push('-- The following will be executed when the operation is run:');
header.push('');
header.push('The following will be executed when the operation is run:');
return header.join('\n');
return this.comment(header.join('\n'));
};
OperationCodeEditorWidget.prototype.canAddReturnTmpl = function (desc) {
@@ -23,12 +23,12 @@ define([
SerializeEditorWidget.prototype.getHeader = function(desc) {
this._name = desc.name;
return [
`-- The serialization function for ${desc.name}`,
'-- Globals:',
'-- `path` - target filename',
`-- \`data\` - the ${desc.name} to store`
].join('\n');
return this.comment([
`The serialization function for ${desc.name}`,
'Globals:',
' `path` - target filename',
` \`data\` - the ${desc.name} to store`
].join('\n'));
};
SerializeEditorWidget.prototype.getNameRegex = function () {
@@ -14,11 +14,16 @@ define([
'use strict';
var TextEditorWidget,
WIDGET_CLASS = 'text-editor';
WIDGET_CLASS = 'text-editor',
LINE_COMMENT = {
python: '#',
lua: '--'
};
TextEditorWidget = function (logger, container) {
this._logger = logger.fork('Widget');
this.language = this.language || 'python';
this._el = container;
this._el.css({height: '100%'});
this.$editor = $('<div/>');
@@ -74,8 +79,8 @@ define([
TextEditorWidget.prototype.getSessionOptions = function () {
return {
mode: 'ace/mode/lua',
tabSize: 3,
mode: 'ace/mode/' + this.language,
tabSize: 4,
useSoftTabs: true
};
};
@@ -90,8 +95,16 @@ define([
};
// Adding/Removing/Updating items
TextEditorWidget.prototype.comment = function (text) {
var prefix = LINE_COMMENT[this.language] + ' ';
return text.replace(
new RegExp('^(' + LINE_COMMENT[this.language] + ')?','mg'),
prefix
);
};
TextEditorWidget.prototype.getHeader = function (desc) {
return `-- Editing "${desc.name}"`;
return this.comment(`Editing "${desc.name}"`);
};
TextEditorWidget.prototype.addNode = function (desc) {
+49 -138
Ver Arquivo
@@ -1,142 +1,53 @@
{
"RNN": [
"BiSequencer",
"BiSequencerLM",
"GRU",
"MaskZero",
"MaskZeroCriterion",
"Recurrence",
"Recurrent",
"RecurrentAttention",
"Recursor",
"Repeater",
"RepeaterCriterion",
"Sequencer",
"SequencerCriterion",
"TrimZero",
"Transfer": [
"PReLU",
"Softshrink",
"Softplus",
"LeakyReLU",
"Hardshrink",
"ELU",
"ReLU6",
"Hardtanh",
"RReLU",
"ReLU",
"Threshold"
],
"CopyGrad",
"FastLSTM",
"LSTM",
"LookupTableMaskZero",
"NormStabilizer",
"SAdd",
"SeqBRNN",
"SeqGRU",
"SeqLSTM",
"SeqLSTMP",
"SeqReverseSequence"
],
"Convolution": [
"TemporalConvolution",
"TemporalMaxPooling",
"TemporalSubSampling",
"LookupTable",
"SpatialConvolutionMM",
"SpatialConvolution",
"SpatialConvolutionMap",
"SpatialFullConvolutionMap",
"SpatialLPPooling",
"SpatialMaxPooling",
"SpatialAveragePooling",
"SpatialAdaptiveMaxPooling",
"SpatialSubSampling",
"SpatialUpSamplingNearest",
"SpatialZeroPadding",
"SpatialReflectionPadding",
"SpatialReplicationPadding",
"SpatialSubtractiveNormalization",
"SpatialCrossMapLRN",
"SpatialConvolutionLocal",
"SpatialDropout",
"SpatialDilatedConvolution",
"SpatialFractionalMaxPooling",
"SpatialDivisiveNormalization",
"SpatialContrastiveNormalization",
"SpatialBatchNormalization",
"SpatialFullConvolution",
"SpatialMaxUnpooling",
"VolumetricConvolution",
"VolumetricMaxPooling",
"VolumetricAveragePooling",
"VolumetricBatchNormalization",
"VolumetricDropout",
"Convolution": [
"ConvTranspose3d",
"Conv3d",
"ConvTranspose2d",
"Conv2d",
"Conv1d",
"VolumetricFullConvolution",
"VolumetricMaxUnpooling"
],
"Criterion": [
"BCECriterion",
"WeightedMSECriterion",
"SmoothL1Criterion",
"MSECriterion",
"AbsCriterion",
"MultiCriterion",
"DistKLDivCriterion",
"HingeEmbeddingCriterion",
"CriterionTable",
"MultiMarginCriterion",
"MultiLabelMarginCriterion",
"L1HingeEmbeddingCriterion",
"CosineEmbeddingCriterion",
"MarginRankingCriterion",
"CrossEntropyCriterion",
"MarginCriterion",
"ClassNLLCriterion",
"ParallelCriterion",
"SpatialClassNLLCriterion",
"SoftMarginCriterion",
"MultiLabelSoftMarginCriterion"
],
"Simple": [
"Linear",
"LinearNoBias",
"SparseLinear",
"Dropout",
"Concat",
"Abs",
"Add",
"Mul",
"CMul",
"Max",
"Min",
"Mean",
"Sum",
"Euclidean",
"WeightedEuclidean",
"Identity",
"Copy",
"Narrow",
"Replicate",
"Reshape",
"View",
"Select",
"Exp",
"Square",
"Sqrt",
"Power",
"MM",
"AddConstant",
"MulConstant"
],
"Transfer": [
"Threshold",
"HardTanh",
"HardShrink",
"SoftShrink",
"SoftMax",
"SpatialSoftMax",
"SoftMin",
"SoftPlus",
"SoftSign",
"LogSigmoid",
"LogSoftMax",
"Sigmoid",
"Tanh",
"ReLU",
"ReLU6",
"PReLU",
"RReLU",
"LeakyReLU"
]
"FractionalMaxPool2d",
"LPPool2d",
"MaxUnpool3d",
"AvgPool3d",
"MaxPool3d",
"AvgPool2d",
"MaxUnpool2d",
"MaxPool2d",
"MaxPool1d",
"ReplicationPad3d",
"ReplicationPad2d",
"ReflectionPad2d"
],
"Simple": [
"Dropout3d",
"Dropout2d",
"Dropout",
"Linear",
"Embedding"
],
"Criterion": [
"MultiMarginLoss",
"MarginRankingLoss",
"CosineEmbeddingLoss",
"CrossMapLRN2d"
]
}
+33 -16
Ver Arquivo
@@ -28,7 +28,10 @@ if (exists.sync(configPath)) { // Check the deepforge config
config = JSON.parse(fs.readFileSync(configPath, 'utf8'));
torchPath = (config.torch && config.torch.dir) || (configDir + 'torch');
}
torchPath += `/install/share/lua/5.1/${outputName}/`;
// FIXME: Get the pytorch root path...
torchPath = process.env.HOME + '/projects/pytorch';
// check 'modules', 'parallel'
torchPath += '/torch/nn/';
console.log(`parsing ${outputName} from ${torchPath}`);
@@ -51,14 +54,18 @@ var lookupType = function(layer){
return layerType || 'Misc';
};
fs.readdir(torchPath, function(err,files){
if(err) throw err;
var layers,
var parseLayerFiles = function(layerDir) {
var files = fs.readdirSync(layerDir),
layers,
layerByName = {};
layers = files.filter(filename => path.extname(filename) === '.lua')
.map(filename => fs.readFileSync(torchPath + filename, 'utf8'))
console.log('parsing', layerDir);
layers = files.filter(filename => path.extname(filename) === '.py' &&
filename[0] !== '_')
.map(filename => fs.readFileSync(layerDir + filename, 'utf8'))
.map(code => LayerParser.parse(code))
.filter(list => list !== null)
.reduce((l1, l2) => l1.concat(l2), [])
.filter(layer => !!layer && layer.name);
layers.forEach(layer => {
@@ -104,17 +111,27 @@ fs.readdir(torchPath, function(err,files){
}
});
layers = layers.filter(layer => !SKIP_LAYERS[layer.name]);
return layers;
};
outputDst += outputName + '.json';
// eslint-disable-next-line no-console
console.log('Saved nn interface to ' + outputDst);
fs.writeFileSync(outputDst, JSON.stringify(layers, null, 2));
var layers = ['modules']
.map(dir => torchPath + dir + '/')
.map(path => parseLayerFiles(path))
.reduce((l1, l2) => l1.concat(l2))
.filter(layer => layer.name[0] !== '_'); // skip hidden/abstract layers
// Update the CreateTorchMeta index
var updateSchemas = `${__dirname}/../src/plugins/CreateTorchMeta/update-schemas.js`,
job = require('child_process').fork(updateSchemas);
job.on('close', code => {
process.exit(code);
});
// eslint-disable-next-line no-console
console.log('discovered', layers.length, 'layers');
outputDst += outputName + '.json';
// eslint-disable-next-line no-console
console.log('Saved nn interface to ' + outputDst);
fs.writeFileSync(outputDst, JSON.stringify(layers, null, 2));
// Update the CreateTorchMeta index
var updateSchemas = `${__dirname}/../src/plugins/CreateTorchMeta/update-schemas.js`,
job = require('child_process').fork(updateSchemas);
job.on('close', code => {
process.exit(code);
});
+104
Ver Arquivo
@@ -0,0 +1,104 @@
var brython = require('./node-brython'),
fs = require('fs'),
assert = require('assert'),
src = fs.readFileSync(process.env.HOME + '/projects/pytorch/torch/nn/modules/conv.py', 'utf8'),
root = build_ast(src);
function build_ast(src) {
brython.$py_module_path['__main__']='./'
return brython.py2js(src,'__main__', '__main__', '__builtins__')
}
// The provided tree gives us contexts which can have associated 'C'
function traverse (node, fn) {
var i;
if (node.children) {
for (i = node.children.length; i--;) {
traverse(node.children[i], fn);
fn(node.children[i]);
}
}
if (node.C && node.C.tree) {
for (i = node.C.tree.length; i--;) {
traverse(node.C.tree[i], fn);
fn(node.C.tree[i]);
}
}
}
var types = {},
layers = [],
pCtx,
classNode,
params;
function isClass(node) {
return node.type === 'class';
}
function isInitFn(node) {
return node.type === 'def' && node.name === '__init__';
}
function getBaseClass(node) {
assert(node.type === 'class');
return node.args.tree[0].tree[0].tree[0].value;
}
var defaults = {},
defTypes,
args,
def;
traverse(root, node => {
if (node.type) types[node.type] = true;
// Get the class for the given function
if (isInitFn(node)) {
pCtx = node.parent.node.parent;
classNode = pCtx.C.tree[0];
if (isClass(classNode)) {
// remove the 'self' variable
// TODO: May need to update this for kwargs
// (use positional_list)
args = node.tree[1].tree;
defaults = {};
params = node.args.slice(1);
defTypes = {};
for (var i = args.length; i--;) {
if (args[i].tree[0]) {
def = args[i].tree[0].tree[0];
console.log('setting type of ', params[i-1], 'to', def.type);
defTypes[params[i-1]] = def.type;
if (def.type === 'int') {
defaults[params[i-1]] = parseInt.apply(null, def.value.reverse());
} else {
defaults[params[i-1]] = def.value;
}
}
}
layers.push({
name: classNode.name,
baseType: getBaseClass(classNode),
//doc: classNode.doc_string || '',
defaults: defaults,
types: defTypes,
params: params
});
}
}
// TODO: What if there is no constructor? Is this a potential problem?
});
console.log('layers:', layers);
fs.writeFileSync('./testPyTorchLayers.json', JSON.stringify(layers, null, 2));
//console.log('layers:', layers.map(l => l.name));
// Try to find the class definitions...
//
// Need to create:
//
// setters: (I don't think these are used in pytorch!
// types:
// defaults:
// type: