Merge pull request #102 from Cristy94/master

Improve workerTrain performance
Esse commit está contido em:
Juan Cazala
2016-07-04 21:49:12 -03:00
commit de GitHub
5 arquivos alterados com 226 adições e 288 exclusões
+112 -143
Ver Arquivo
@@ -1045,32 +1045,56 @@ Network.prototype = {
return new Function(hardcode)();
},
worker: function() {
// Return a HTML5 WebWorker specialized on training the network stored in `memory`.
// Train based on the given dataSet and options.
// The worker returns the updated `memory` when done.
worker: function(memory, set, options) {
// Copy the options and set defaults (options might be different for each worker)
var workerOptions = {};
if(options) workerOptions = JSON.parse(JSON.stringify(options));
workerOptions.rate = options.rate || .2;
workerOptions.iterations = options.iterations || 100000;
workerOptions.error = options.error || .005;
workerOptions.cost = options.cost || null;
workerOptions.crossValidate = options.crossValidate || null;
// Cost function might be different for each worker
costFunction = "var cost = " + (options && options.cost || this.cost || Trainer.cost.MSE) + ";\n";
var workerFunction = Network.getWorkerSharedFunctions();
workerFunction = workerFunction.replace(/var cost = options && options\.cost \|\| this\.cost \|\| Trainer\.cost\.MSE;/g, costFunction);
// Set what we do when training is finished
workerFunction = workerFunction.replace('return results;',
'postMessage({action: "done", message: results, memoryBuffer: F}, [F.buffer]);');
// Replace log with postmessage
workerFunction = workerFunction.replace("console.log('iterations', iterations, 'error', error, 'rate', currentRate)",
"postMessage({action: 'log', message: {\n" +
"iterations: iterations,\n" +
"error: error,\n" +
"rate: currentRate\n" +
"}\n" +
"})");
if (!this.optimized)
this.optimize();
var hardcode = "var inputs = " + this.optimized.data.inputs.length +
";\n";
hardcode += "var outputs = " + this.optimized.data.outputs.length +
";\n";
hardcode += "var F = null;\n";
hardcode += "var activate = " + this.optimized.activate.toString() +
";\n";
hardcode += "var propagate = " + this.optimized.propagate.toString() +
";\n";
hardcode += "onmessage = function(e){\n";
hardcode += "F = e.data.memoryBuffer;\n";
hardcode += "if (e.data.action == 'activate'){\n";
hardcode += "if (e.data.input.length == inputs){\n";
hardcode +=
"postMessage( { action: 'activate', output: activate(e.data.input), memoryBuffer: F }, [F.buffer]);\n";
hardcode += "}\n}\nelse if (e.data.action == 'propagate'){\n";
hardcode += "propagate(e.data.rate, e.data.target);\n";
hardcode +=
"postMessage({ action: 'propagate', memoryBuffer: F }, [F.buffer]);\n";
hardcode += "}\n}\n";
var hardcode = "var inputs = " + this.optimized.data.inputs.length + ";\n";
hardcode += "var outputs = " + this.optimized.data.outputs.length + ";\n";
hardcode += "var F = new Float64Array([" + this.optimized.memory.toString() + "]);\n";
hardcode += "var activate = " + this.optimized.activate.toString() + ";\n";
hardcode += "var propagate = " + this.optimized.propagate.toString() + ";\n";
hardcode +=
"onmessage = function(e) {\n" +
"if (e.data.action == 'startTraining') {\n" +
"train(" + JSON.stringify(set) + "," + JSON.stringify(workerOptions) + ");\n" +
"}\n" +
"}";
var blob = new Blob([hardcode]);
var workerSourceCode = workerFunction + '\n' + hardcode;
var blob = new Blob([workerSourceCode]);
var blobURL = window.URL.createObjectURL(blob);
return new Worker(blobURL);
@@ -1082,6 +1106,39 @@ Network.prototype = {
}
}
/**
* Creates a static String to store the source code of the functions
* that are identical for all the workers (train, _trainSet, test)
*
* @return {String} Source code that can train a network inside a worker.
* @static
*/
Network.getWorkerSharedFunctions = function() {
// If we already computed the source code for the shared functions
if(typeof Network._SHARED_WORKER_FUNCTIONS !== 'undefined')
return Network._SHARED_WORKER_FUNCTIONS;
// Otherwise compute and return the source code
// We compute them by simply copying the source code of the train, _trainSet and test functions
// using the .toString() method
// Load and name the train function
var train_f = Trainer.prototype.train.toString();
train_f = train_f.replace('function (set', 'function train(set') + '\n';
// Load and name the _trainSet function
var _trainSet_f = Trainer.prototype._trainSet.toString().replace(/this.network./g, '');
_trainSet_f = _trainSet_f.replace('function (set', 'function _trainSet(set') + '\n';
_trainSet_f = _trainSet_f.replace('this.crossValidate', 'crossValidate');
_trainSet_f = _trainSet_f.replace('crossValidate = true', 'crossValidate = { }');
// Load and name the test function
var test_f = Trainer.prototype.test.toString().replace(/this.network./g, '');
test_f = test_f.replace('function (set', 'function test(set') + '\n';
return Network._SHARED_WORKER_FUNCTIONS = train_f + _trainSet_f + test_f;
};
// rebuild a network that has been stored in a json using the method toJSON()
Network.fromJSON = function(json) {
@@ -2034,7 +2091,7 @@ function Trainer(network, options) {
this.network = network;
this.rate = options.rate || .2;
this.iterations = options.iterations || 100000;
this.error = options.error || .005
this.error = options.error || .005;
this.cost = options.cost || null;
this.crossValidate = options.crossValidate || null;
}
@@ -2077,7 +2134,8 @@ Trainer.prototype = {
console.log('Deprecated: use schedule instead of customLog')
this.schedule = options.customLog;
}
if (this.crossValidate) {
if (this.crossValidate || options.crossValidate) {
if(!this.crossValidate) this.crossValidate = {};
crossValidate = true;
if (options.crossValidate.testSize)
this.crossValidate.testSize = options.crossValidate.testSize;
@@ -2193,134 +2251,45 @@ Trainer.prototype = {
// trains any given set to a network using a WebWorker
workerTrain: function(set, callback, options) {
console.log('WorkerTrain initiated!');
var that = this;
var error = 1;
var iterations = bucketSize = 0;
var input, output, target, currentRate;
var length = set.length;
var abort = false;
var cost = options && options.cost || that.cost || Trainer.cost.MSE;
if (!this.network.optimized)
this.network.optimize();
var start = Date.now();
if (options) {
if (options.shuffle) {
//+ Jonas Raoni Soares Silva
//@ http://jsfromhell.com/array/shuffle [v1.0]
function shuffle(o) { //v1.0
for (var j, x, i = o.length; i; j = Math.floor(Math.random() *
i), x = o[--i], o[i] = o[j], o[j] = x);
return o;
};
}
if (options.iterations)
that.iterations = options.iterations;
if (options.error)
that.error = options.error;
if (options.rate)
that.rate = options.rate;
if (options.cost)
that.cost = options.cost;
if (options.schedule)
that.schedule = options.schedule;
if (options.customLog)
{
// for backward compatibility with code that used customLog
console.log('Deprecated: use schedule instead of customLog')
that.schedule = options.customLog;
}
}
// dynamic learning rate
currentRate = that.rate;
if(Array.isArray(that.rate)) {
bucketSize = Math.floor(that.iterations / that.rate.length);
}
// create a worker
var worker = that.network.worker();
// activate the network
function activateWorker(input)
{
worker.postMessage({
action: "activate",
input: input,
memoryBuffer: that.network.optimized.memory
}, [that.network.optimized.memory.buffer]);
}
// backpropagate the network
function propagateWorker(target){
if(bucketSize > 0) {
var currentBucket = Math.floor(iterations / bucketSize);
currentRate = that.rate[currentBucket] || currentRate;
}
worker.postMessage({
action: "propagate",
target: target,
rate: currentRate,
memoryBuffer: that.network.optimized.memory
}, [that.network.optimized.memory.buffer]);
}
// Create a new worker
var worker = this.network.worker(this.network.optimized.memory, set, options);
// train the worker
worker.onmessage = function(e){
// give control of the memory back to the network
that.network.optimized.ownership(e.data.memoryBuffer);
worker.onmessage = function(e) {
switch(e.data.action) {
case 'done':
var iterations = e.data.message.iterations;
var error = e.data.message.error;
var time = e.data.message.time;
if (e.data.action == "propagate")
{
if (index >= length)
{
index = 0;
iterations++;
error /= set.length;
that.network.optimized.ownership(e.data.memoryBuffer);
// log
if (options) {
if (that.schedule && that.schedule.every && iterations % that.schedule.every == 0)
abort = that.schedule.do({
error: error,
iterations: iterations,
rate: currentRate
});
else if (options.log && iterations % options.log == 0) {
console.log('iterations', iterations, 'error', error);
};
if (options.shuffle)
shuffle(set);
}
// Done callback
callback({
error: error,
iterations: iterations,
time: time
});
if (!abort && iterations < that.iterations && error > that.error)
{
activateWorker(set[index].input);
} else {
// callback
callback({
error: error,
iterations: iterations,
time: Date.now() - start
})
}
error = 0;
} else {
activateWorker(set[index].input);
}
}
// Delete the worker and all its associated memory
worker.terminate();
break;
if (e.data.action == "activate")
{
error += cost(set[index].output, e.data.output);
propagateWorker(set[index].output);
index++;
}
case 'log':
console.log(e.data.message);
break;
}
}
// kick it
var index = 0;
var iterations = 0;
activateWorker(set[index].input);
// Start the worker
worker.postMessage({action: 'startTraining'});
},
// trains an XOR to the network
@@ -2766,4 +2735,4 @@ Trainer.cost = {
}
},{}]},{},[5]);
var Neuron = synaptic.Neuron, Layer = synaptic.Layer, Network = synaptic.Network, Trainer = synaptic.Trainer, Architect = synaptic.Architect;
var Neuron = synaptic.Neuron, Layer = synaptic.Layer, Network = synaptic.Network, Trainer = synaptic.Trainer, Architect = synaptic.Architect;
+2 -2
Ver Arquivo
Diff do arquivo suprimido porque uma ou mais linhas são muito longas
+1 -1
Ver Arquivo
@@ -7,7 +7,7 @@ var Layer = require('./layer')
ARCHITECT
*******************************************************************************************/
// Colection of useful built-in architectures
// Collection of useful built-in architectures
var Architect = {
// Multilayer Perceptron
+79 -22
Ver Arquivo
@@ -492,32 +492,56 @@ Network.prototype = {
return new Function(hardcode)();
},
worker: function() {
// Return a HTML5 WebWorker specialized on training the network stored in `memory`.
// Train based on the given dataSet and options.
// The worker returns the updated `memory` when done.
worker: function(memory, set, options) {
// Copy the options and set defaults (options might be different for each worker)
var workerOptions = {};
if(options) workerOptions = JSON.parse(JSON.stringify(options));
workerOptions.rate = options.rate || .2;
workerOptions.iterations = options.iterations || 100000;
workerOptions.error = options.error || .005;
workerOptions.cost = options.cost || null;
workerOptions.crossValidate = options.crossValidate || null;
// Cost function might be different for each worker
costFunction = "var cost = " + (options && options.cost || this.cost || Trainer.cost.MSE) + ";\n";
var workerFunction = Network.getWorkerSharedFunctions();
workerFunction = workerFunction.replace(/var cost = options && options\.cost \|\| this\.cost \|\| Trainer\.cost\.MSE;/g, costFunction);
// Set what we do when training is finished
workerFunction = workerFunction.replace('return results;',
'postMessage({action: "done", message: results, memoryBuffer: F}, [F.buffer]);');
// Replace log with postmessage
workerFunction = workerFunction.replace("console.log('iterations', iterations, 'error', error, 'rate', currentRate)",
"postMessage({action: 'log', message: {\n" +
"iterations: iterations,\n" +
"error: error,\n" +
"rate: currentRate\n" +
"}\n" +
"})");
if (!this.optimized)
this.optimize();
var hardcode = "var inputs = " + this.optimized.data.inputs.length +
";\n";
hardcode += "var outputs = " + this.optimized.data.outputs.length +
";\n";
hardcode += "var F = null;\n";
hardcode += "var activate = " + this.optimized.activate.toString() +
";\n";
hardcode += "var propagate = " + this.optimized.propagate.toString() +
";\n";
hardcode += "onmessage = function(e){\n";
hardcode += "F = e.data.memoryBuffer;\n";
hardcode += "if (e.data.action == 'activate'){\n";
hardcode += "if (e.data.input.length == inputs){\n";
hardcode +=
"postMessage( { action: 'activate', output: activate(e.data.input), memoryBuffer: F }, [F.buffer]);\n";
hardcode += "}\n}\nelse if (e.data.action == 'propagate'){\n";
hardcode += "propagate(e.data.rate, e.data.target);\n";
hardcode +=
"postMessage({ action: 'propagate', memoryBuffer: F }, [F.buffer]);\n";
hardcode += "}\n}\n";
var hardcode = "var inputs = " + this.optimized.data.inputs.length + ";\n";
hardcode += "var outputs = " + this.optimized.data.outputs.length + ";\n";
hardcode += "var F = new Float64Array([" + this.optimized.memory.toString() + "]);\n";
hardcode += "var activate = " + this.optimized.activate.toString() + ";\n";
hardcode += "var propagate = " + this.optimized.propagate.toString() + ";\n";
hardcode +=
"onmessage = function(e) {\n" +
"if (e.data.action == 'startTraining') {\n" +
"train(" + JSON.stringify(set) + "," + JSON.stringify(workerOptions) + ");\n" +
"}\n" +
"}";
var blob = new Blob([hardcode]);
var workerSourceCode = workerFunction + '\n' + hardcode;
var blob = new Blob([workerSourceCode]);
var blobURL = window.URL.createObjectURL(blob);
return new Worker(blobURL);
@@ -529,6 +553,39 @@ Network.prototype = {
}
}
/**
* Creates a static String to store the source code of the functions
* that are identical for all the workers (train, _trainSet, test)
*
* @return {String} Source code that can train a network inside a worker.
* @static
*/
Network.getWorkerSharedFunctions = function() {
// If we already computed the source code for the shared functions
if(typeof Network._SHARED_WORKER_FUNCTIONS !== 'undefined')
return Network._SHARED_WORKER_FUNCTIONS;
// Otherwise compute and return the source code
// We compute them by simply copying the source code of the train, _trainSet and test functions
// using the .toString() method
// Load and name the train function
var train_f = Trainer.prototype.train.toString();
train_f = train_f.replace('function (set', 'function train(set') + '\n';
// Load and name the _trainSet function
var _trainSet_f = Trainer.prototype._trainSet.toString().replace(/this.network./g, '');
_trainSet_f = _trainSet_f.replace('function (set', 'function _trainSet(set') + '\n';
_trainSet_f = _trainSet_f.replace('this.crossValidate', 'crossValidate');
_trainSet_f = _trainSet_f.replace('crossValidate = true', 'crossValidate = { }');
// Load and name the test function
var test_f = Trainer.prototype.test.toString().replace(/this.network./g, '');
test_f = test_f.replace('function (set', 'function test(set') + '\n';
return Network._SHARED_WORKER_FUNCTIONS = train_f + _trainSet_f + test_f;
};
// rebuild a network that has been stored in a json using the method toJSON()
Network.fromJSON = function(json) {
+32 -120
Ver Arquivo
@@ -10,7 +10,7 @@ function Trainer(network, options) {
this.network = network;
this.rate = options.rate || .2;
this.iterations = options.iterations || 100000;
this.error = options.error || .005
this.error = options.error || .005;
this.cost = options.cost || null;
this.crossValidate = options.crossValidate || null;
}
@@ -53,7 +53,8 @@ Trainer.prototype = {
console.log('Deprecated: use schedule instead of customLog')
this.schedule = options.customLog;
}
if (this.crossValidate) {
if (this.crossValidate || options.crossValidate) {
if(!this.crossValidate) this.crossValidate = {};
crossValidate = true;
if (options.crossValidate.testSize)
this.crossValidate.testSize = options.crossValidate.testSize;
@@ -169,134 +170,45 @@ Trainer.prototype = {
// trains any given set to a network using a WebWorker
workerTrain: function(set, callback, options) {
console.log('WorkerTrain initiated!');
var that = this;
var error = 1;
var iterations = bucketSize = 0;
var input, output, target, currentRate;
var length = set.length;
var abort = false;
var cost = options && options.cost || that.cost || Trainer.cost.MSE;
if (!this.network.optimized)
this.network.optimize();
var start = Date.now();
if (options) {
if (options.shuffle) {
//+ Jonas Raoni Soares Silva
//@ http://jsfromhell.com/array/shuffle [v1.0]
function shuffle(o) { //v1.0
for (var j, x, i = o.length; i; j = Math.floor(Math.random() *
i), x = o[--i], o[i] = o[j], o[j] = x);
return o;
};
}
if (options.iterations)
that.iterations = options.iterations;
if (options.error)
that.error = options.error;
if (options.rate)
that.rate = options.rate;
if (options.cost)
that.cost = options.cost;
if (options.schedule)
that.schedule = options.schedule;
if (options.customLog)
{
// for backward compatibility with code that used customLog
console.log('Deprecated: use schedule instead of customLog')
that.schedule = options.customLog;
}
}
// dynamic learning rate
currentRate = that.rate;
if(Array.isArray(that.rate)) {
bucketSize = Math.floor(that.iterations / that.rate.length);
}
// create a worker
var worker = that.network.worker();
// activate the network
function activateWorker(input)
{
worker.postMessage({
action: "activate",
input: input,
memoryBuffer: that.network.optimized.memory
}, [that.network.optimized.memory.buffer]);
}
// backpropagate the network
function propagateWorker(target){
if(bucketSize > 0) {
var currentBucket = Math.floor(iterations / bucketSize);
currentRate = that.rate[currentBucket] || currentRate;
}
worker.postMessage({
action: "propagate",
target: target,
rate: currentRate,
memoryBuffer: that.network.optimized.memory
}, [that.network.optimized.memory.buffer]);
}
// Create a new worker
var worker = this.network.worker(this.network.optimized.memory, set, options);
// train the worker
worker.onmessage = function(e){
// give control of the memory back to the network
that.network.optimized.ownership(e.data.memoryBuffer);
worker.onmessage = function(e) {
switch(e.data.action) {
case 'done':
var iterations = e.data.message.iterations;
var error = e.data.message.error;
var time = e.data.message.time;
if (e.data.action == "propagate")
{
if (index >= length)
{
index = 0;
iterations++;
error /= set.length;
that.network.optimized.ownership(e.data.memoryBuffer);
// log
if (options) {
if (that.schedule && that.schedule.every && iterations % that.schedule.every == 0)
abort = that.schedule.do({
error: error,
iterations: iterations,
rate: currentRate
});
else if (options.log && iterations % options.log == 0) {
console.log('iterations', iterations, 'error', error);
};
if (options.shuffle)
shuffle(set);
}
// Done callback
callback({
error: error,
iterations: iterations,
time: time
});
if (!abort && iterations < that.iterations && error > that.error)
{
activateWorker(set[index].input);
} else {
// callback
callback({
error: error,
iterations: iterations,
time: Date.now() - start
})
}
error = 0;
} else {
activateWorker(set[index].input);
}
}
// Delete the worker and all its associated memory
worker.terminate();
break;
if (e.data.action == "activate")
{
error += cost(set[index].output, e.data.output);
propagateWorker(set[index].output);
index++;
}
case 'log':
console.log(e.data.message);
break;
}
}
// kick it
var index = 0;
var iterations = 0;
activateWorker(set[index].input);
// Start the worker
worker.postMessage({action: 'startTraining'});
},
// trains an XOR to the network