Merge pull request #102 from Cristy94/master
Improve workerTrain performance
Esse commit está contido em:
externo
+112
-143
@@ -1045,32 +1045,56 @@ Network.prototype = {
|
||||
return new Function(hardcode)();
|
||||
},
|
||||
|
||||
worker: function() {
|
||||
|
||||
// Return a HTML5 WebWorker specialized on training the network stored in `memory`.
|
||||
// Train based on the given dataSet and options.
|
||||
// The worker returns the updated `memory` when done.
|
||||
worker: function(memory, set, options) {
|
||||
|
||||
// Copy the options and set defaults (options might be different for each worker)
|
||||
var workerOptions = {};
|
||||
if(options) workerOptions = JSON.parse(JSON.stringify(options));
|
||||
workerOptions.rate = options.rate || .2;
|
||||
workerOptions.iterations = options.iterations || 100000;
|
||||
workerOptions.error = options.error || .005;
|
||||
workerOptions.cost = options.cost || null;
|
||||
workerOptions.crossValidate = options.crossValidate || null;
|
||||
|
||||
// Cost function might be different for each worker
|
||||
costFunction = "var cost = " + (options && options.cost || this.cost || Trainer.cost.MSE) + ";\n";
|
||||
var workerFunction = Network.getWorkerSharedFunctions();
|
||||
workerFunction = workerFunction.replace(/var cost = options && options\.cost \|\| this\.cost \|\| Trainer\.cost\.MSE;/g, costFunction);
|
||||
|
||||
// Set what we do when training is finished
|
||||
workerFunction = workerFunction.replace('return results;',
|
||||
'postMessage({action: "done", message: results, memoryBuffer: F}, [F.buffer]);');
|
||||
|
||||
// Replace log with postmessage
|
||||
workerFunction = workerFunction.replace("console.log('iterations', iterations, 'error', error, 'rate', currentRate)",
|
||||
"postMessage({action: 'log', message: {\n" +
|
||||
"iterations: iterations,\n" +
|
||||
"error: error,\n" +
|
||||
"rate: currentRate\n" +
|
||||
"}\n" +
|
||||
"})");
|
||||
|
||||
if (!this.optimized)
|
||||
this.optimize();
|
||||
|
||||
var hardcode = "var inputs = " + this.optimized.data.inputs.length +
|
||||
";\n";
|
||||
hardcode += "var outputs = " + this.optimized.data.outputs.length +
|
||||
";\n";
|
||||
hardcode += "var F = null;\n";
|
||||
hardcode += "var activate = " + this.optimized.activate.toString() +
|
||||
";\n";
|
||||
hardcode += "var propagate = " + this.optimized.propagate.toString() +
|
||||
";\n";
|
||||
hardcode += "onmessage = function(e){\n";
|
||||
hardcode += "F = e.data.memoryBuffer;\n";
|
||||
hardcode += "if (e.data.action == 'activate'){\n";
|
||||
hardcode += "if (e.data.input.length == inputs){\n";
|
||||
hardcode +=
|
||||
"postMessage( { action: 'activate', output: activate(e.data.input), memoryBuffer: F }, [F.buffer]);\n";
|
||||
hardcode += "}\n}\nelse if (e.data.action == 'propagate'){\n";
|
||||
hardcode += "propagate(e.data.rate, e.data.target);\n";
|
||||
hardcode +=
|
||||
"postMessage({ action: 'propagate', memoryBuffer: F }, [F.buffer]);\n";
|
||||
hardcode += "}\n}\n";
|
||||
var hardcode = "var inputs = " + this.optimized.data.inputs.length + ";\n";
|
||||
hardcode += "var outputs = " + this.optimized.data.outputs.length + ";\n";
|
||||
hardcode += "var F = new Float64Array([" + this.optimized.memory.toString() + "]);\n";
|
||||
hardcode += "var activate = " + this.optimized.activate.toString() + ";\n";
|
||||
hardcode += "var propagate = " + this.optimized.propagate.toString() + ";\n";
|
||||
hardcode +=
|
||||
"onmessage = function(e) {\n" +
|
||||
"if (e.data.action == 'startTraining') {\n" +
|
||||
"train(" + JSON.stringify(set) + "," + JSON.stringify(workerOptions) + ");\n" +
|
||||
"}\n" +
|
||||
"}";
|
||||
|
||||
var blob = new Blob([hardcode]);
|
||||
var workerSourceCode = workerFunction + '\n' + hardcode;
|
||||
var blob = new Blob([workerSourceCode]);
|
||||
var blobURL = window.URL.createObjectURL(blob);
|
||||
|
||||
return new Worker(blobURL);
|
||||
@@ -1082,6 +1106,39 @@ Network.prototype = {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a static String to store the source code of the functions
|
||||
* that are identical for all the workers (train, _trainSet, test)
|
||||
*
|
||||
* @return {String} Source code that can train a network inside a worker.
|
||||
* @static
|
||||
*/
|
||||
Network.getWorkerSharedFunctions = function() {
|
||||
// If we already computed the source code for the shared functions
|
||||
if(typeof Network._SHARED_WORKER_FUNCTIONS !== 'undefined')
|
||||
return Network._SHARED_WORKER_FUNCTIONS;
|
||||
|
||||
// Otherwise compute and return the source code
|
||||
// We compute them by simply copying the source code of the train, _trainSet and test functions
|
||||
// using the .toString() method
|
||||
|
||||
// Load and name the train function
|
||||
var train_f = Trainer.prototype.train.toString();
|
||||
train_f = train_f.replace('function (set', 'function train(set') + '\n';
|
||||
|
||||
// Load and name the _trainSet function
|
||||
var _trainSet_f = Trainer.prototype._trainSet.toString().replace(/this.network./g, '');
|
||||
_trainSet_f = _trainSet_f.replace('function (set', 'function _trainSet(set') + '\n';
|
||||
_trainSet_f = _trainSet_f.replace('this.crossValidate', 'crossValidate');
|
||||
_trainSet_f = _trainSet_f.replace('crossValidate = true', 'crossValidate = { }');
|
||||
|
||||
// Load and name the test function
|
||||
var test_f = Trainer.prototype.test.toString().replace(/this.network./g, '');
|
||||
test_f = test_f.replace('function (set', 'function test(set') + '\n';
|
||||
|
||||
return Network._SHARED_WORKER_FUNCTIONS = train_f + _trainSet_f + test_f;
|
||||
};
|
||||
|
||||
// rebuild a network that has been stored in a json using the method toJSON()
|
||||
Network.fromJSON = function(json) {
|
||||
|
||||
@@ -2034,7 +2091,7 @@ function Trainer(network, options) {
|
||||
this.network = network;
|
||||
this.rate = options.rate || .2;
|
||||
this.iterations = options.iterations || 100000;
|
||||
this.error = options.error || .005
|
||||
this.error = options.error || .005;
|
||||
this.cost = options.cost || null;
|
||||
this.crossValidate = options.crossValidate || null;
|
||||
}
|
||||
@@ -2077,7 +2134,8 @@ Trainer.prototype = {
|
||||
console.log('Deprecated: use schedule instead of customLog')
|
||||
this.schedule = options.customLog;
|
||||
}
|
||||
if (this.crossValidate) {
|
||||
if (this.crossValidate || options.crossValidate) {
|
||||
if(!this.crossValidate) this.crossValidate = {};
|
||||
crossValidate = true;
|
||||
if (options.crossValidate.testSize)
|
||||
this.crossValidate.testSize = options.crossValidate.testSize;
|
||||
@@ -2193,134 +2251,45 @@ Trainer.prototype = {
|
||||
// trains any given set to a network using a WebWorker
|
||||
workerTrain: function(set, callback, options) {
|
||||
|
||||
console.log('WorkerTrain initiated!');
|
||||
|
||||
var that = this;
|
||||
var error = 1;
|
||||
var iterations = bucketSize = 0;
|
||||
var input, output, target, currentRate;
|
||||
var length = set.length;
|
||||
var abort = false;
|
||||
var cost = options && options.cost || that.cost || Trainer.cost.MSE;
|
||||
|
||||
if (!this.network.optimized)
|
||||
this.network.optimize();
|
||||
|
||||
var start = Date.now();
|
||||
|
||||
if (options) {
|
||||
if (options.shuffle) {
|
||||
//+ Jonas Raoni Soares Silva
|
||||
//@ http://jsfromhell.com/array/shuffle [v1.0]
|
||||
function shuffle(o) { //v1.0
|
||||
for (var j, x, i = o.length; i; j = Math.floor(Math.random() *
|
||||
i), x = o[--i], o[i] = o[j], o[j] = x);
|
||||
return o;
|
||||
};
|
||||
}
|
||||
if (options.iterations)
|
||||
that.iterations = options.iterations;
|
||||
if (options.error)
|
||||
that.error = options.error;
|
||||
if (options.rate)
|
||||
that.rate = options.rate;
|
||||
if (options.cost)
|
||||
that.cost = options.cost;
|
||||
if (options.schedule)
|
||||
that.schedule = options.schedule;
|
||||
if (options.customLog)
|
||||
{
|
||||
// for backward compatibility with code that used customLog
|
||||
console.log('Deprecated: use schedule instead of customLog')
|
||||
that.schedule = options.customLog;
|
||||
}
|
||||
}
|
||||
|
||||
// dynamic learning rate
|
||||
currentRate = that.rate;
|
||||
if(Array.isArray(that.rate)) {
|
||||
bucketSize = Math.floor(that.iterations / that.rate.length);
|
||||
}
|
||||
|
||||
// create a worker
|
||||
var worker = that.network.worker();
|
||||
|
||||
// activate the network
|
||||
function activateWorker(input)
|
||||
{
|
||||
worker.postMessage({
|
||||
action: "activate",
|
||||
input: input,
|
||||
memoryBuffer: that.network.optimized.memory
|
||||
}, [that.network.optimized.memory.buffer]);
|
||||
}
|
||||
|
||||
// backpropagate the network
|
||||
function propagateWorker(target){
|
||||
if(bucketSize > 0) {
|
||||
var currentBucket = Math.floor(iterations / bucketSize);
|
||||
currentRate = that.rate[currentBucket] || currentRate;
|
||||
}
|
||||
worker.postMessage({
|
||||
action: "propagate",
|
||||
target: target,
|
||||
rate: currentRate,
|
||||
memoryBuffer: that.network.optimized.memory
|
||||
}, [that.network.optimized.memory.buffer]);
|
||||
}
|
||||
// Create a new worker
|
||||
var worker = this.network.worker(this.network.optimized.memory, set, options);
|
||||
|
||||
// train the worker
|
||||
worker.onmessage = function(e){
|
||||
// give control of the memory back to the network
|
||||
that.network.optimized.ownership(e.data.memoryBuffer);
|
||||
worker.onmessage = function(e) {
|
||||
switch(e.data.action) {
|
||||
case 'done':
|
||||
var iterations = e.data.message.iterations;
|
||||
var error = e.data.message.error;
|
||||
var time = e.data.message.time;
|
||||
|
||||
if (e.data.action == "propagate")
|
||||
{
|
||||
if (index >= length)
|
||||
{
|
||||
index = 0;
|
||||
iterations++;
|
||||
error /= set.length;
|
||||
that.network.optimized.ownership(e.data.memoryBuffer);
|
||||
|
||||
// log
|
||||
if (options) {
|
||||
if (that.schedule && that.schedule.every && iterations % that.schedule.every == 0)
|
||||
abort = that.schedule.do({
|
||||
error: error,
|
||||
iterations: iterations,
|
||||
rate: currentRate
|
||||
});
|
||||
else if (options.log && iterations % options.log == 0) {
|
||||
console.log('iterations', iterations, 'error', error);
|
||||
};
|
||||
if (options.shuffle)
|
||||
shuffle(set);
|
||||
}
|
||||
// Done callback
|
||||
callback({
|
||||
error: error,
|
||||
iterations: iterations,
|
||||
time: time
|
||||
});
|
||||
|
||||
if (!abort && iterations < that.iterations && error > that.error)
|
||||
{
|
||||
activateWorker(set[index].input);
|
||||
} else {
|
||||
// callback
|
||||
callback({
|
||||
error: error,
|
||||
iterations: iterations,
|
||||
time: Date.now() - start
|
||||
})
|
||||
}
|
||||
error = 0;
|
||||
} else {
|
||||
activateWorker(set[index].input);
|
||||
}
|
||||
}
|
||||
// Delete the worker and all its associated memory
|
||||
worker.terminate();
|
||||
break;
|
||||
|
||||
if (e.data.action == "activate")
|
||||
{
|
||||
error += cost(set[index].output, e.data.output);
|
||||
propagateWorker(set[index].output);
|
||||
index++;
|
||||
}
|
||||
case 'log':
|
||||
console.log(e.data.message);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// kick it
|
||||
var index = 0;
|
||||
var iterations = 0;
|
||||
activateWorker(set[index].input);
|
||||
// Start the worker
|
||||
worker.postMessage({action: 'startTraining'});
|
||||
},
|
||||
|
||||
// trains an XOR to the network
|
||||
@@ -2766,4 +2735,4 @@ Trainer.cost = {
|
||||
}
|
||||
|
||||
},{}]},{},[5]);
|
||||
var Neuron = synaptic.Neuron, Layer = synaptic.Layer, Network = synaptic.Network, Trainer = synaptic.Trainer, Architect = synaptic.Architect;
|
||||
var Neuron = synaptic.Neuron, Layer = synaptic.Layer, Network = synaptic.Network, Trainer = synaptic.Trainer, Architect = synaptic.Architect;
|
||||
externo
+2
-2
Diff do arquivo suprimido porque uma ou mais linhas são muito longas
+1
-1
@@ -7,7 +7,7 @@ var Layer = require('./layer')
|
||||
ARCHITECT
|
||||
*******************************************************************************************/
|
||||
|
||||
// Colection of useful built-in architectures
|
||||
// Collection of useful built-in architectures
|
||||
var Architect = {
|
||||
|
||||
// Multilayer Perceptron
|
||||
|
||||
+79
-22
@@ -492,32 +492,56 @@ Network.prototype = {
|
||||
return new Function(hardcode)();
|
||||
},
|
||||
|
||||
worker: function() {
|
||||
|
||||
// Return a HTML5 WebWorker specialized on training the network stored in `memory`.
|
||||
// Train based on the given dataSet and options.
|
||||
// The worker returns the updated `memory` when done.
|
||||
worker: function(memory, set, options) {
|
||||
|
||||
// Copy the options and set defaults (options might be different for each worker)
|
||||
var workerOptions = {};
|
||||
if(options) workerOptions = JSON.parse(JSON.stringify(options));
|
||||
workerOptions.rate = options.rate || .2;
|
||||
workerOptions.iterations = options.iterations || 100000;
|
||||
workerOptions.error = options.error || .005;
|
||||
workerOptions.cost = options.cost || null;
|
||||
workerOptions.crossValidate = options.crossValidate || null;
|
||||
|
||||
// Cost function might be different for each worker
|
||||
costFunction = "var cost = " + (options && options.cost || this.cost || Trainer.cost.MSE) + ";\n";
|
||||
var workerFunction = Network.getWorkerSharedFunctions();
|
||||
workerFunction = workerFunction.replace(/var cost = options && options\.cost \|\| this\.cost \|\| Trainer\.cost\.MSE;/g, costFunction);
|
||||
|
||||
// Set what we do when training is finished
|
||||
workerFunction = workerFunction.replace('return results;',
|
||||
'postMessage({action: "done", message: results, memoryBuffer: F}, [F.buffer]);');
|
||||
|
||||
// Replace log with postmessage
|
||||
workerFunction = workerFunction.replace("console.log('iterations', iterations, 'error', error, 'rate', currentRate)",
|
||||
"postMessage({action: 'log', message: {\n" +
|
||||
"iterations: iterations,\n" +
|
||||
"error: error,\n" +
|
||||
"rate: currentRate\n" +
|
||||
"}\n" +
|
||||
"})");
|
||||
|
||||
if (!this.optimized)
|
||||
this.optimize();
|
||||
|
||||
var hardcode = "var inputs = " + this.optimized.data.inputs.length +
|
||||
";\n";
|
||||
hardcode += "var outputs = " + this.optimized.data.outputs.length +
|
||||
";\n";
|
||||
hardcode += "var F = null;\n";
|
||||
hardcode += "var activate = " + this.optimized.activate.toString() +
|
||||
";\n";
|
||||
hardcode += "var propagate = " + this.optimized.propagate.toString() +
|
||||
";\n";
|
||||
hardcode += "onmessage = function(e){\n";
|
||||
hardcode += "F = e.data.memoryBuffer;\n";
|
||||
hardcode += "if (e.data.action == 'activate'){\n";
|
||||
hardcode += "if (e.data.input.length == inputs){\n";
|
||||
hardcode +=
|
||||
"postMessage( { action: 'activate', output: activate(e.data.input), memoryBuffer: F }, [F.buffer]);\n";
|
||||
hardcode += "}\n}\nelse if (e.data.action == 'propagate'){\n";
|
||||
hardcode += "propagate(e.data.rate, e.data.target);\n";
|
||||
hardcode +=
|
||||
"postMessage({ action: 'propagate', memoryBuffer: F }, [F.buffer]);\n";
|
||||
hardcode += "}\n}\n";
|
||||
var hardcode = "var inputs = " + this.optimized.data.inputs.length + ";\n";
|
||||
hardcode += "var outputs = " + this.optimized.data.outputs.length + ";\n";
|
||||
hardcode += "var F = new Float64Array([" + this.optimized.memory.toString() + "]);\n";
|
||||
hardcode += "var activate = " + this.optimized.activate.toString() + ";\n";
|
||||
hardcode += "var propagate = " + this.optimized.propagate.toString() + ";\n";
|
||||
hardcode +=
|
||||
"onmessage = function(e) {\n" +
|
||||
"if (e.data.action == 'startTraining') {\n" +
|
||||
"train(" + JSON.stringify(set) + "," + JSON.stringify(workerOptions) + ");\n" +
|
||||
"}\n" +
|
||||
"}";
|
||||
|
||||
var blob = new Blob([hardcode]);
|
||||
var workerSourceCode = workerFunction + '\n' + hardcode;
|
||||
var blob = new Blob([workerSourceCode]);
|
||||
var blobURL = window.URL.createObjectURL(blob);
|
||||
|
||||
return new Worker(blobURL);
|
||||
@@ -529,6 +553,39 @@ Network.prototype = {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a static String to store the source code of the functions
|
||||
* that are identical for all the workers (train, _trainSet, test)
|
||||
*
|
||||
* @return {String} Source code that can train a network inside a worker.
|
||||
* @static
|
||||
*/
|
||||
Network.getWorkerSharedFunctions = function() {
|
||||
// If we already computed the source code for the shared functions
|
||||
if(typeof Network._SHARED_WORKER_FUNCTIONS !== 'undefined')
|
||||
return Network._SHARED_WORKER_FUNCTIONS;
|
||||
|
||||
// Otherwise compute and return the source code
|
||||
// We compute them by simply copying the source code of the train, _trainSet and test functions
|
||||
// using the .toString() method
|
||||
|
||||
// Load and name the train function
|
||||
var train_f = Trainer.prototype.train.toString();
|
||||
train_f = train_f.replace('function (set', 'function train(set') + '\n';
|
||||
|
||||
// Load and name the _trainSet function
|
||||
var _trainSet_f = Trainer.prototype._trainSet.toString().replace(/this.network./g, '');
|
||||
_trainSet_f = _trainSet_f.replace('function (set', 'function _trainSet(set') + '\n';
|
||||
_trainSet_f = _trainSet_f.replace('this.crossValidate', 'crossValidate');
|
||||
_trainSet_f = _trainSet_f.replace('crossValidate = true', 'crossValidate = { }');
|
||||
|
||||
// Load and name the test function
|
||||
var test_f = Trainer.prototype.test.toString().replace(/this.network./g, '');
|
||||
test_f = test_f.replace('function (set', 'function test(set') + '\n';
|
||||
|
||||
return Network._SHARED_WORKER_FUNCTIONS = train_f + _trainSet_f + test_f;
|
||||
};
|
||||
|
||||
// rebuild a network that has been stored in a json using the method toJSON()
|
||||
Network.fromJSON = function(json) {
|
||||
|
||||
|
||||
+32
-120
@@ -10,7 +10,7 @@ function Trainer(network, options) {
|
||||
this.network = network;
|
||||
this.rate = options.rate || .2;
|
||||
this.iterations = options.iterations || 100000;
|
||||
this.error = options.error || .005
|
||||
this.error = options.error || .005;
|
||||
this.cost = options.cost || null;
|
||||
this.crossValidate = options.crossValidate || null;
|
||||
}
|
||||
@@ -53,7 +53,8 @@ Trainer.prototype = {
|
||||
console.log('Deprecated: use schedule instead of customLog')
|
||||
this.schedule = options.customLog;
|
||||
}
|
||||
if (this.crossValidate) {
|
||||
if (this.crossValidate || options.crossValidate) {
|
||||
if(!this.crossValidate) this.crossValidate = {};
|
||||
crossValidate = true;
|
||||
if (options.crossValidate.testSize)
|
||||
this.crossValidate.testSize = options.crossValidate.testSize;
|
||||
@@ -169,134 +170,45 @@ Trainer.prototype = {
|
||||
// trains any given set to a network using a WebWorker
|
||||
workerTrain: function(set, callback, options) {
|
||||
|
||||
console.log('WorkerTrain initiated!');
|
||||
|
||||
var that = this;
|
||||
var error = 1;
|
||||
var iterations = bucketSize = 0;
|
||||
var input, output, target, currentRate;
|
||||
var length = set.length;
|
||||
var abort = false;
|
||||
var cost = options && options.cost || that.cost || Trainer.cost.MSE;
|
||||
|
||||
if (!this.network.optimized)
|
||||
this.network.optimize();
|
||||
|
||||
var start = Date.now();
|
||||
|
||||
if (options) {
|
||||
if (options.shuffle) {
|
||||
//+ Jonas Raoni Soares Silva
|
||||
//@ http://jsfromhell.com/array/shuffle [v1.0]
|
||||
function shuffle(o) { //v1.0
|
||||
for (var j, x, i = o.length; i; j = Math.floor(Math.random() *
|
||||
i), x = o[--i], o[i] = o[j], o[j] = x);
|
||||
return o;
|
||||
};
|
||||
}
|
||||
if (options.iterations)
|
||||
that.iterations = options.iterations;
|
||||
if (options.error)
|
||||
that.error = options.error;
|
||||
if (options.rate)
|
||||
that.rate = options.rate;
|
||||
if (options.cost)
|
||||
that.cost = options.cost;
|
||||
if (options.schedule)
|
||||
that.schedule = options.schedule;
|
||||
if (options.customLog)
|
||||
{
|
||||
// for backward compatibility with code that used customLog
|
||||
console.log('Deprecated: use schedule instead of customLog')
|
||||
that.schedule = options.customLog;
|
||||
}
|
||||
}
|
||||
|
||||
// dynamic learning rate
|
||||
currentRate = that.rate;
|
||||
if(Array.isArray(that.rate)) {
|
||||
bucketSize = Math.floor(that.iterations / that.rate.length);
|
||||
}
|
||||
|
||||
// create a worker
|
||||
var worker = that.network.worker();
|
||||
|
||||
// activate the network
|
||||
function activateWorker(input)
|
||||
{
|
||||
worker.postMessage({
|
||||
action: "activate",
|
||||
input: input,
|
||||
memoryBuffer: that.network.optimized.memory
|
||||
}, [that.network.optimized.memory.buffer]);
|
||||
}
|
||||
|
||||
// backpropagate the network
|
||||
function propagateWorker(target){
|
||||
if(bucketSize > 0) {
|
||||
var currentBucket = Math.floor(iterations / bucketSize);
|
||||
currentRate = that.rate[currentBucket] || currentRate;
|
||||
}
|
||||
worker.postMessage({
|
||||
action: "propagate",
|
||||
target: target,
|
||||
rate: currentRate,
|
||||
memoryBuffer: that.network.optimized.memory
|
||||
}, [that.network.optimized.memory.buffer]);
|
||||
}
|
||||
// Create a new worker
|
||||
var worker = this.network.worker(this.network.optimized.memory, set, options);
|
||||
|
||||
// train the worker
|
||||
worker.onmessage = function(e){
|
||||
// give control of the memory back to the network
|
||||
that.network.optimized.ownership(e.data.memoryBuffer);
|
||||
worker.onmessage = function(e) {
|
||||
switch(e.data.action) {
|
||||
case 'done':
|
||||
var iterations = e.data.message.iterations;
|
||||
var error = e.data.message.error;
|
||||
var time = e.data.message.time;
|
||||
|
||||
if (e.data.action == "propagate")
|
||||
{
|
||||
if (index >= length)
|
||||
{
|
||||
index = 0;
|
||||
iterations++;
|
||||
error /= set.length;
|
||||
that.network.optimized.ownership(e.data.memoryBuffer);
|
||||
|
||||
// log
|
||||
if (options) {
|
||||
if (that.schedule && that.schedule.every && iterations % that.schedule.every == 0)
|
||||
abort = that.schedule.do({
|
||||
error: error,
|
||||
iterations: iterations,
|
||||
rate: currentRate
|
||||
});
|
||||
else if (options.log && iterations % options.log == 0) {
|
||||
console.log('iterations', iterations, 'error', error);
|
||||
};
|
||||
if (options.shuffle)
|
||||
shuffle(set);
|
||||
}
|
||||
// Done callback
|
||||
callback({
|
||||
error: error,
|
||||
iterations: iterations,
|
||||
time: time
|
||||
});
|
||||
|
||||
if (!abort && iterations < that.iterations && error > that.error)
|
||||
{
|
||||
activateWorker(set[index].input);
|
||||
} else {
|
||||
// callback
|
||||
callback({
|
||||
error: error,
|
||||
iterations: iterations,
|
||||
time: Date.now() - start
|
||||
})
|
||||
}
|
||||
error = 0;
|
||||
} else {
|
||||
activateWorker(set[index].input);
|
||||
}
|
||||
}
|
||||
// Delete the worker and all its associated memory
|
||||
worker.terminate();
|
||||
break;
|
||||
|
||||
if (e.data.action == "activate")
|
||||
{
|
||||
error += cost(set[index].output, e.data.output);
|
||||
propagateWorker(set[index].output);
|
||||
index++;
|
||||
}
|
||||
case 'log':
|
||||
console.log(e.data.message);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// kick it
|
||||
var index = 0;
|
||||
var iterations = 0;
|
||||
activateWorker(set[index].input);
|
||||
// Start the worker
|
||||
worker.postMessage({action: 'startTraining'});
|
||||
},
|
||||
|
||||
// trains an XOR to the network
|
||||
|
||||
Referência em uma Nova Issue
Bloquear um usuário