Comparar commits
46 Commits
development
...
master
| Autor | SHA1 | Data | |
|---|---|---|---|
| 0b31eed0d1 | |||
| a5fbefea5a | |||
| 3d8265d804 | |||
| 34c93cb718 | |||
| d3e2ea8488 | |||
| 4cfd5a828f | |||
| 24807634e8 | |||
| ebac6d8c33 | |||
| b0a9559beb | |||
| 5a6c487102 | |||
| 9a06d558aa | |||
| 0a544dbc9e | |||
| 3c433021e5 | |||
| e2240eb7b3 | |||
| efb642c2a7 | |||
| 697de49de0 | |||
| 82d64f27a9 | |||
| 116df545fe | |||
| e6953b5028 | |||
| 47745614e3 | |||
| 5041da2c25 | |||
| 05113ec720 | |||
| d3044fbe2a | |||
| f6e8d79f2a | |||
| e1ff6b86df | |||
| c233227b10 | |||
| 8ff4f28075 | |||
| 533d77aaea | |||
| 29bd2ecf28 | |||
| 8a7d82b25c | |||
| 45e4cc2ab3 | |||
| bd5062bdaa | |||
| 0345fe648d | |||
| b820493823 | |||
| 7ee75f9984 | |||
| 25e1e151a2 | |||
| c5be6ffa3d | |||
| 50d28c328f | |||
| 5524e29646 | |||
| abdce5117a | |||
| a8d237f577 | |||
| 238be47c90 | |||
| 4a8301d3f9 | |||
| 6fc3bbf898 | |||
| 8895db5052 | |||
| dab9030d19 |
@@ -16,3 +16,5 @@ demo.js
|
||||
|
||||
# Degub
|
||||
debug.html
|
||||
|
||||
.settings
|
||||
|
||||
+7
-1
@@ -1,3 +1,9 @@
|
||||
language: node_js
|
||||
script: "npm run test:travis"
|
||||
node_js:
|
||||
- "0.10"
|
||||
# always latest release
|
||||
- "node"
|
||||
# previous releases
|
||||
- "6"
|
||||
- "5"
|
||||
- "4"
|
||||
Arquivo executável → Arquivo normal
+29
-5
@@ -1,6 +1,6 @@
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2014 Juan Cazala (juancazala.com)
|
||||
Copyright (c) 2016 Juan Cazala - juancazala.com
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
@@ -9,14 +9,38 @@ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE
|
||||
|
||||
|
||||
|
||||
********************************************************************************************
|
||||
SYNAPTIC (v1.0.8)
|
||||
********************************************************************************************
|
||||
|
||||
Synaptic is a javascript neural network library for node.js and the browser, its generalized
|
||||
algorithm is architecture-free, so you can build and train basically any type of first order
|
||||
or even second order neural network architectures.
|
||||
|
||||
http://en.wikipedia.org/wiki/Recurrent_neural_network#Second_Order_Recurrent_Neural_Network
|
||||
|
||||
The library includes a few built-in architectures like multilayer perceptrons, multilayer
|
||||
long-short term memory networks (LSTM) or liquid state machines, and a trainer capable of
|
||||
training any given network, and includes built-in training tasks/tests like solving an XOR,
|
||||
passing a Distracted Sequence Recall test or an Embeded Reber Grammar test.
|
||||
|
||||
The algorithm implemented by this library has been taken from Derek D. Monner's paper:
|
||||
|
||||
|
||||
A generalized LSTM-like training algorithm for second-order recurrent neural networks
|
||||
http://www.overcomplete.net/papers/nn2012.pdf
|
||||
|
||||
There are references to the equations in that paper commented through the source code.
|
||||
|
||||
+29
-19
@@ -1,6 +1,8 @@
|
||||
Synaptic [](https://travis-ci.org/cazala/synaptic)
|
||||
Synaptic [](https://travis-ci.org/cazala/synaptic) [](https://synaptic-slack-ugiqacqvmd.now.sh/)
|
||||
========
|
||||
|
||||
## Important: [Synaptic 2.x](https://github.com/cazala/synaptic/issues/140) is in stage of discussion now! Feel free to participate
|
||||
|
||||
Synaptic is a javascript neural network library for **node.js** and the **browser**, its generalized algorithm is architecture-free, so you can build and train basically any type of first order or even [second order neural network](http://en.wikipedia.org/wiki/Recurrent_neural_network#Second_Order_Recurrent_Neural_Network) architectures.
|
||||
|
||||
This library includes a few built-in architectures like [multilayer perceptrons](http://en.wikipedia.org/wiki/Multilayer_perceptron), [multilayer long-short term memory](http://en.wikipedia.org/wiki/Long_short_term_memory) networks (LSTM), [liquid state machines](http://en.wikipedia.org/wiki/Liquid_state_machine) or [Hopfield](http://en.wikipedia.org/wiki/Hopfield_network) networks, and a trainer capable of training any given network, which includes built-in training tasks/tests like solving an XOR, completing a Distracted Sequence Recall task or an [Embedded Reber Grammar](http://www.willamette.edu/~gorr/classes/cs449/reber.html) test, so you can easily test and compare the performance of different architectures.
|
||||
@@ -17,14 +19,19 @@ There are references to the equations in that paper commented through the source
|
||||
|
||||
If you have no prior knowledge about Neural Networks, you should start by [reading this guide](https://github.com/cazala/synaptic/wiki/Neural-Networks-101).
|
||||
|
||||
|
||||
If you want a practical example on how to feed data to a neural network, then take a look at [this article](https://github.com/cazala/synaptic/wiki/Normalization-101).
|
||||
|
||||
You may also want to take a look at [this article](http://blog.webkid.io/neural-networks-in-javascript/).
|
||||
|
||||
####Demos
|
||||
|
||||
- [Solve an XOR](http://synaptic.juancazala.com/#/xor)
|
||||
- [Discrete Sequence Recall Task](http://synaptic.juancazala.com/#/dsr)
|
||||
- [Learn Image Filters](http://synaptic.juancazala.com/#/image-filters)
|
||||
- [Paint an Image](http://synaptic.juancazala.com/#/paint-an-image)
|
||||
- [Self Organizing Map](http://synaptic.juancazala.com/#/self-organizing-map)
|
||||
- [Read from Wikipedia](http://synaptic.juancazala.com/#/wikipedia)
|
||||
- [Solve an XOR](http://caza.la/synaptic/#/xor)
|
||||
- [Discrete Sequence Recall Task](http://caza.la/synaptic/#/dsr)
|
||||
- [Learn Image Filters](http://caza.la/synaptic/#/image-filters)
|
||||
- [Paint an Image](http://caza.la/synaptic/#/paint-an-image)
|
||||
- [Self Organizing Map](http://caza.la/synaptic/#/self-organizing-map)
|
||||
- [Read from Wikipedia](http://caza.la/synaptic/#/wikipedia)
|
||||
|
||||
The source code of these demos can be found in [this branch](https://github.com/cazala/synaptic/tree/gh-pages/scripts).
|
||||
|
||||
@@ -36,12 +43,17 @@ The source code of these demos can be found in [this branch](https://github.com/
|
||||
- [Trainer](https://github.com/cazala/synaptic/wiki/Trainer/)
|
||||
- [Architect](https://github.com/cazala/synaptic/wiki/Architect/)
|
||||
|
||||
To try out the examples, checkout the [gh-pages](https://github.com/cazala/synaptic/tree/gh-pages) branch.
|
||||
|
||||
`git checkout gh-pages`
|
||||
|
||||
|
||||
##Overview
|
||||
|
||||
###Installation
|
||||
|
||||
#####In node
|
||||
|
||||
You can install synaptic with [npm](http://npmjs.org):
|
||||
|
||||
```cmd
|
||||
@@ -49,10 +61,17 @@ npm install synaptic --save
|
||||
```
|
||||
|
||||
#####In the browser
|
||||
Just include the file synaptic.js from `/dist` directory with a script tag in your HTML:
|
||||
|
||||
You can install synaptic with [bower](http://bower.io):
|
||||
|
||||
```cmd
|
||||
bower install synaptic
|
||||
```
|
||||
|
||||
Or you can simply use the CDN link, kindly provided by [CDNjs](https://cdnjs.com/)
|
||||
|
||||
```html
|
||||
<script src="synaptic.js"></script>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/synaptic/1.0.8/synaptic.js"></script>
|
||||
```
|
||||
|
||||
###Usage
|
||||
@@ -68,15 +87,6 @@ var Neuron = synaptic.Neuron,
|
||||
|
||||
Now you can start to create networks, train them, or use built-in networks from the [Architect](http://github.com/cazala/synaptic#architect).
|
||||
|
||||
###Gulp Tasks
|
||||
|
||||
- **gulp**: runs all the tests and builds the minified and unminified bundles into `/dist`.
|
||||
- **gulp build**: builds the bundle: `/dist/synaptic.js`.
|
||||
- **gulp min**: builds the minified bundle: `/dist/synaptic.min.js`.
|
||||
- **gulp debug**: builds the bundle `/dist/synaptic.js` with sourcemaps.
|
||||
- **gulp dev**: same as `gulp debug`, but watches the source files and rebuilds when any change is detected.
|
||||
- **gulp test**: runs all the tests.
|
||||
|
||||
###Examples
|
||||
|
||||
#####Perceptron
|
||||
@@ -186,6 +196,6 @@ Multilayer LSTM network architectures.
|
||||
|
||||
**Synaptic** is an Open Source project that started in Buenos Aires, Argentina. Anybody in the world is welcome to contribute to the development of the project.
|
||||
|
||||
If you want to contribute feel free to send PR's, just make sure to run the default **gulp** task before submiting it. This way you'll run all the test specs and build the web distribution files.
|
||||
If you want to contribute feel free to send PR's, just make sure to run **npm run test** and **npm run build** before submiting it. This way you'll run all the test specs and build the web distribution files.
|
||||
|
||||
<3
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "synaptic",
|
||||
"version": "1.0.4",
|
||||
"version": "1.0.8",
|
||||
"homepage": "https://github.com/cazala/synaptic",
|
||||
"authors": [
|
||||
"Juan Cazala <juancazala@gmail.com>"
|
||||
|
||||
externo
+2848
-2769
Diferenças do arquivo suprimidas por serem muito extensas
Carregar Diff
externo
+2831
-36
Diff do arquivo suprimido porque uma ou mais linhas são muito longas
@@ -1,60 +0,0 @@
|
||||
'use strict';
|
||||
|
||||
var license = '/*\n\nThe MIT License (MIT)\n\nCopyright (c) 2014 Juan Cazala - juancazala.com\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the "Software"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in\nall copies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\nTHE SOFTWARE\n\n\n\n********************************************************************************************\n SYNAPTIC\n********************************************************************************************\n\nSynaptic is a javascript neural network library for node.js and the browser, its generalized\nalgorithm is architecture-free, so you can build and train basically any type of first order\nor even second order neural network architectures.\n\nhttp://en.wikipedia.org/wiki/Recurrent_neural_network#Second_Order_Recurrent_Neural_Network\n\nThe library includes a few built-in architectures like multilayer perceptrons, multilayer\nlong-short term memory networks (LSTM) or liquid state machines, and a trainer capable of\ntraining any given network, and includes built-in training tasks/tests like solving an XOR,\npassing a Distracted Sequence Recall test or an Embeded Reber Grammar test.\n\nThe algorithm implemented by this library has been taken from Derek D. Monner\'s paper:\n\n\nA generalized LSTM-like training algorithm for second-order recurrent neural networks\nhttp://www.overcomplete.net/papers/nn2012.pdf\n\nThere are references to the equations in that paper commented through the source code.\n\n\n********************************************************************************************/\n'
|
||||
var globals = 'var Neuron = synaptic.Neuron, Layer = synaptic.Layer, Network = synaptic.Network, Trainer = synaptic.Trainer, Architect = synaptic.Architect;';
|
||||
|
||||
// import
|
||||
var gulp = require('gulp');
|
||||
var browserify = require('browserify');
|
||||
var uglify = require('gulp-uglify');
|
||||
var mocha = require('gulp-mocha');
|
||||
var prepend = require('gulp-insert').prepend;
|
||||
var append = require('gulp-insert').append;
|
||||
var source = require('vinyl-source-stream');
|
||||
var buffer = require('vinyl-buffer');
|
||||
|
||||
// default task: runs all the tests, and builds all the files into dist (minified and unminifed)
|
||||
gulp.task('default', ['test', 'build', 'min']);
|
||||
|
||||
// build source into /dist for the web
|
||||
gulp.task('build', function () {
|
||||
return browserify({ entries: ['./src/synaptic.js'] })
|
||||
.bundle()
|
||||
.pipe(source('synaptic.js'))
|
||||
.pipe(buffer())
|
||||
.pipe(append(globals))
|
||||
.pipe(gulp.dest('./dist'));
|
||||
});
|
||||
|
||||
// build source into /dist for web (minified)
|
||||
gulp.task('min', function () {
|
||||
return browserify({ entries: ['./src/synaptic.js'] })
|
||||
.bundle()
|
||||
.pipe(source('synaptic.min.js'))
|
||||
.pipe(buffer())
|
||||
.pipe(uglify())
|
||||
.pipe(prepend(license))
|
||||
.pipe(append(globals))
|
||||
.pipe(gulp.dest('./dist'));
|
||||
});
|
||||
|
||||
// build source into /dist with sourcemaps for debugging
|
||||
gulp.task('debug', function () {
|
||||
return browserify({ entries: ['./src/synaptic.js'], debug: true })
|
||||
.bundle()
|
||||
.pipe(source('synaptic.js'))
|
||||
.pipe(buffer())
|
||||
.pipe(append(globals))
|
||||
.pipe(gulp.dest('./dist'));
|
||||
});
|
||||
|
||||
// run all the tests with mocha
|
||||
gulp.task('test', function () {
|
||||
return gulp.src('test/synaptic.js', {read: false})
|
||||
.pipe(mocha());
|
||||
});
|
||||
|
||||
// watch for changes and re-build (debug)
|
||||
gulp.task('dev', function () {
|
||||
gulp.watch('./src/*.js', ['debug']);
|
||||
});
|
||||
@@ -0,0 +1,25 @@
|
||||
// Karma configuration
|
||||
|
||||
module.exports = function(config) {
|
||||
config.set({
|
||||
basePath: '',
|
||||
frameworks: ['mocha'],
|
||||
files: [
|
||||
'dist/synaptic.js',
|
||||
'test/[^_]*.js'
|
||||
],
|
||||
exclude: [
|
||||
],
|
||||
preprocessors: {
|
||||
'test/*.js': ['webpack'],
|
||||
},
|
||||
reporters: ['progress'],
|
||||
port: 9876,
|
||||
colors: true,
|
||||
logLevel: config.LOG_INFO,
|
||||
autoWatch: true,
|
||||
singleRun: false,
|
||||
concurrency: Infinity,
|
||||
browserNoActivityTimeout: 60000,
|
||||
})
|
||||
}
|
||||
+27
-13
@@ -1,22 +1,36 @@
|
||||
{
|
||||
"name": "synaptic",
|
||||
"version": "1.0.4",
|
||||
"version": "1.0.9",
|
||||
"description": "architecture-free neural network library",
|
||||
"main": "./src/synaptic",
|
||||
"scripts": {
|
||||
"test": "mocha test"
|
||||
"test": "npm run test:src",
|
||||
"test:src": "mocha test --require src/synaptic.js ./test",
|
||||
"test:dist": "npm run build && npm run test:mocha:dist && npm run test:karma:browsers",
|
||||
"test:mocha:src": "mocha test --require src/synaptic.js ./test",
|
||||
"test:mocha:dist": "mocha test --require dist/synaptic.js ./test",
|
||||
"test:karma:browsers": "karma start --single-run --browsers Chrome,Firefox,SafariPrivate",
|
||||
"test:karma:phantomjs": "karma start --single-run --browsers PhantomJS",
|
||||
"test:travis": "npm run test:mocha:src && npm run build && npm run test:mocha:dist",
|
||||
"build": "webpack --config webpack.config.js"
|
||||
},
|
||||
"prepush": [
|
||||
"test",
|
||||
"build"
|
||||
],
|
||||
"devDependencies": {
|
||||
"browserify": "^10.1.3",
|
||||
"gulp": "^3.8.11",
|
||||
"gulp-insert": "^0.4.0",
|
||||
"gulp-mocha": "^2.0.1",
|
||||
"gulp-sourcemaps": "^1.5.2",
|
||||
"gulp-uglify": "^1.2.0",
|
||||
"gulp-util": "^3.0.4",
|
||||
"vinyl-buffer": "^1.0.0",
|
||||
"vinyl-source-stream": "^1.1.0",
|
||||
"mocha": "^2.2.4"
|
||||
"chai": "^3.5.0",
|
||||
"chai-stats": "^0.3.0",
|
||||
"karma": "^1.1.2",
|
||||
"karma-chrome-launcher": "^1.0.1",
|
||||
"karma-firefox-launcher": "^1.0.0",
|
||||
"karma-mocha": "^1.1.1",
|
||||
"karma-phantomjs-launcher": "^1.0.1",
|
||||
"karma-safari-launcher": "^1.0.0",
|
||||
"karma-webpack": "^1.7.0",
|
||||
"mocha": "^2.2.4",
|
||||
"pre-push": "^0.1.1",
|
||||
"webpack": "^1.13.1"
|
||||
},
|
||||
"repository": {
|
||||
"type": "git",
|
||||
@@ -39,6 +53,6 @@
|
||||
},
|
||||
"homepage": "http://synaptic.juancazala.com",
|
||||
"engines": {
|
||||
"node": ">=0.10"
|
||||
"node": ">=4"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
// update license year and version
|
||||
var fs = require('fs')
|
||||
module.exports = function() {
|
||||
var year = (new Date).getFullYear()
|
||||
var version = require('./package.json').version
|
||||
// LICENSE
|
||||
var license = fs.readFileSync('LICENSE', 'utf-8')
|
||||
.replace(/\(c\) ([0-9]+)/, `(c) ${year}`)
|
||||
.replace(/SYNAPTIC \(v(.*)\)/, `SYNAPTIC (v${version})`)
|
||||
fs.writeFileSync('LICENSE', license)
|
||||
// bower.json
|
||||
var bower = fs.readFileSync('bower.json', 'utf-8')
|
||||
.replace(/\"version\": \"(.*)\",/, `"version": "${version}",`)
|
||||
fs.writeFileSync('bower.json', bower)
|
||||
// README.md
|
||||
var readme = fs.readFileSync('README.md', 'utf-8')
|
||||
.replace(/ajax\/libs\/synaptic\/(.*)\/synaptic.js/, `ajax/libs/synaptic/${version}/synaptic.js`)
|
||||
fs.writeFileSync('README.md', readme)
|
||||
// return license for dist banner
|
||||
return license
|
||||
}
|
||||
+5
-5
@@ -7,7 +7,7 @@ var Layer = require('./layer')
|
||||
ARCHITECT
|
||||
*******************************************************************************************/
|
||||
|
||||
// Colection of useful built-in architectures
|
||||
// Collection of useful built-in architectures
|
||||
var Architect = {
|
||||
|
||||
// Multilayer Perceptron
|
||||
@@ -28,7 +28,7 @@ var Architect = {
|
||||
var previous = input;
|
||||
|
||||
// generate hidden layers
|
||||
for (level in layers) {
|
||||
for (var level in layers) {
|
||||
var size = layers[level];
|
||||
var layer = new Layer(size);
|
||||
hidden.push(layer);
|
||||
@@ -217,8 +217,8 @@ var Architect = {
|
||||
this.trainer = new Trainer(this);
|
||||
},
|
||||
|
||||
Hopfield: function Hopfield(size)
|
||||
{
|
||||
Hopfield: function Hopfield(size) {
|
||||
|
||||
var inputLayer = new Layer(size);
|
||||
var outputLayer = new Layer(size);
|
||||
|
||||
@@ -248,7 +248,7 @@ var Architect = {
|
||||
error: .00005,
|
||||
rate: 1
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
proto.feed = proto.feed || function(pattern)
|
||||
{
|
||||
|
||||
+92
-27
@@ -4,6 +4,7 @@ if (module) module.exports = Network;
|
||||
// import
|
||||
var Neuron = require('./neuron')
|
||||
, Layer = require('./layer')
|
||||
, Trainer = require('./trainer')
|
||||
|
||||
/*******************************************************************************************
|
||||
NETWORK
|
||||
@@ -420,7 +421,6 @@ Network.prototype = {
|
||||
code += " " + layerID + " -> " + layerToID + " [label = " + size + "]\n";
|
||||
for (var from in connection.gatedfrom) { // gatings
|
||||
var layerfrom = connection.gatedfrom[from].layer;
|
||||
var type = connection.gatedfrom[from].type;
|
||||
var layerfromID = layers.indexOf(layerfrom);
|
||||
code += " " + layerfromID + " -> " + fakeNode + " [color = blue]\n";
|
||||
}
|
||||
@@ -428,7 +428,6 @@ Network.prototype = {
|
||||
code += " " + layerID + " -> " + layerToID + " [label = " + size + "]\n";
|
||||
for (var from in connection.gatedfrom) { // gatings
|
||||
var layerfrom = connection.gatedfrom[from].layer;
|
||||
var type = connection.gatedfrom[from].type;
|
||||
var layerfromID = layers.indexOf(layerfrom);
|
||||
code += " " + layerfromID + " -> " + layerToID + " [color = blue]\n";
|
||||
}
|
||||
@@ -492,32 +491,65 @@ Network.prototype = {
|
||||
return new Function(hardcode)();
|
||||
},
|
||||
|
||||
worker: function() {
|
||||
|
||||
// Return a HTML5 WebWorker specialized on training the network stored in `memory`.
|
||||
// Train based on the given dataSet and options.
|
||||
// The worker returns the updated `memory` when done.
|
||||
worker: function(memory, set, options) {
|
||||
|
||||
// Copy the options and set defaults (options might be different for each worker)
|
||||
var workerOptions = {};
|
||||
if(options) workerOptions = options;
|
||||
workerOptions.rate = options.rate || .2;
|
||||
workerOptions.iterations = options.iterations || 100000;
|
||||
workerOptions.error = options.error || .005;
|
||||
workerOptions.cost = options.cost || null;
|
||||
workerOptions.crossValidate = options.crossValidate || null;
|
||||
|
||||
// Cost function might be different for each worker
|
||||
costFunction = "var cost = " + (options && options.cost || this.cost || Trainer.cost.MSE) + ";\n";
|
||||
var workerFunction = Network.getWorkerSharedFunctions();
|
||||
workerFunction = workerFunction.replace(/var cost = options && options\.cost \|\| this\.cost \|\| Trainer\.cost\.MSE;/g, costFunction);
|
||||
|
||||
// Set what we do when training is finished
|
||||
workerFunction = workerFunction.replace('return results;',
|
||||
'postMessage({action: "done", message: results, memoryBuffer: F}, [F.buffer]);');
|
||||
|
||||
// Replace log with postmessage
|
||||
workerFunction = workerFunction.replace("console.log('iterations', iterations, 'error', error, 'rate', currentRate)",
|
||||
"postMessage({action: 'log', message: {\n" +
|
||||
"iterations: iterations,\n" +
|
||||
"error: error,\n" +
|
||||
"rate: currentRate\n" +
|
||||
"}\n" +
|
||||
"})");
|
||||
|
||||
// Replace schedule with postmessage
|
||||
workerFunction = workerFunction.replace("abort = this.schedule.do({ error: error, iterations: iterations, rate: currentRate })",
|
||||
"postMessage({action: 'schedule', message: {\n" +
|
||||
"iterations: iterations,\n" +
|
||||
"error: error,\n" +
|
||||
"rate: currentRate\n" +
|
||||
"}\n" +
|
||||
"})");
|
||||
|
||||
if (!this.optimized)
|
||||
this.optimize();
|
||||
|
||||
var hardcode = "var inputs = " + this.optimized.data.inputs.length +
|
||||
";\n";
|
||||
hardcode += "var outputs = " + this.optimized.data.outputs.length +
|
||||
";\n";
|
||||
hardcode += "var F = null;\n";
|
||||
hardcode += "var activate = " + this.optimized.activate.toString() +
|
||||
";\n";
|
||||
hardcode += "var propagate = " + this.optimized.propagate.toString() +
|
||||
";\n";
|
||||
hardcode += "onmessage = function(e){\n";
|
||||
hardcode += "F = e.data.memoryBuffer;\n";
|
||||
hardcode += "if (e.data.action == 'activate'){\n";
|
||||
hardcode += "if (e.data.input.length == inputs){\n";
|
||||
var hardcode = "var inputs = " + this.optimized.data.inputs.length + ";\n";
|
||||
hardcode += "var outputs = " + this.optimized.data.outputs.length + ";\n";
|
||||
hardcode += "var F = new Float64Array([" + this.optimized.memory.toString() + "]);\n";
|
||||
hardcode += "var activate = " + this.optimized.activate.toString() + ";\n";
|
||||
hardcode += "var propagate = " + this.optimized.propagate.toString() + ";\n";
|
||||
hardcode +=
|
||||
"postMessage( { action: 'activate', output: activate(e.data.input), memoryBuffer: F }, [F.buffer]);\n";
|
||||
hardcode += "}\n}\nelse if (e.data.action == 'propagate'){\n";
|
||||
hardcode += "propagate(e.data.rate, e.data.target);\n";
|
||||
hardcode +=
|
||||
"postMessage({ action: 'propagate', memoryBuffer: F }, [F.buffer]);\n";
|
||||
hardcode += "}\n}\n";
|
||||
"onmessage = function(e) {\n" +
|
||||
"if (e.data.action == 'startTraining') {\n" +
|
||||
"train(" + JSON.stringify(set) + "," + JSON.stringify(workerOptions) + ");\n" +
|
||||
"}\n" +
|
||||
"}";
|
||||
|
||||
var blob = new Blob([hardcode]);
|
||||
var workerSourceCode = workerFunction + '\n' + hardcode;
|
||||
var blob = new Blob([workerSourceCode]);
|
||||
var blobURL = window.URL.createObjectURL(blob);
|
||||
|
||||
return new Worker(blobURL);
|
||||
@@ -527,7 +559,40 @@ Network.prototype = {
|
||||
clone: function() {
|
||||
return Network.fromJSON(this.toJSON());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates a static String to store the source code of the functions
|
||||
* that are identical for all the workers (train, _trainSet, test)
|
||||
*
|
||||
* @return {String} Source code that can train a network inside a worker.
|
||||
* @static
|
||||
*/
|
||||
Network.getWorkerSharedFunctions = function() {
|
||||
// If we already computed the source code for the shared functions
|
||||
if(typeof Network._SHARED_WORKER_FUNCTIONS !== 'undefined')
|
||||
return Network._SHARED_WORKER_FUNCTIONS;
|
||||
|
||||
// Otherwise compute and return the source code
|
||||
// We compute them by simply copying the source code of the train, _trainSet and test functions
|
||||
// using the .toString() method
|
||||
|
||||
// Load and name the train function
|
||||
var train_f = Trainer.prototype.train.toString();
|
||||
train_f = train_f.replace('function (set', 'function train(set') + '\n';
|
||||
|
||||
// Load and name the _trainSet function
|
||||
var _trainSet_f = Trainer.prototype._trainSet.toString().replace(/this.network./g, '');
|
||||
_trainSet_f = _trainSet_f.replace('function (set', 'function _trainSet(set') + '\n';
|
||||
_trainSet_f = _trainSet_f.replace('this.crossValidate', 'crossValidate');
|
||||
_trainSet_f = _trainSet_f.replace('crossValidate = true', 'crossValidate = { }');
|
||||
|
||||
// Load and name the test function
|
||||
var test_f = Trainer.prototype.test.toString().replace(/this.network./g, '');
|
||||
test_f = test_f.replace('function (set', 'function test(set') + '\n';
|
||||
|
||||
return Network._SHARED_WORKER_FUNCTIONS = train_f + _trainSet_f + test_f;
|
||||
};
|
||||
|
||||
// rebuild a network that has been stored in a json using the method toJSON()
|
||||
Network.fromJSON = function(json) {
|
||||
@@ -538,7 +603,7 @@ Network.fromJSON = function(json) {
|
||||
input: new Layer(),
|
||||
hidden: [],
|
||||
output: new Layer()
|
||||
}
|
||||
};
|
||||
|
||||
for (var i in json.neurons) {
|
||||
var config = json.neurons[i];
|
||||
@@ -568,7 +633,7 @@ Network.fromJSON = function(json) {
|
||||
var config = json.connections[i];
|
||||
var from = neurons[config.from];
|
||||
var to = neurons[config.to];
|
||||
var weight = config.weight
|
||||
var weight = config.weight;
|
||||
var gater = neurons[config.gater];
|
||||
|
||||
var connection = from.project(to, weight);
|
||||
@@ -577,4 +642,4 @@ Network.fromJSON = function(json) {
|
||||
}
|
||||
|
||||
return new Network(layers);
|
||||
}
|
||||
};
|
||||
|
||||
+2
-7
@@ -66,7 +66,6 @@ Neuron.prototype = {
|
||||
var influences = [];
|
||||
for (var id in this.trace.extended) {
|
||||
// extended elegibility trace
|
||||
var xtrace = this.trace.extended[id];
|
||||
var neuron = this.neighboors[id];
|
||||
|
||||
// if gated neuron's selfconnection is gated by this unit, the influence keeps track of the neuron's old state
|
||||
@@ -304,7 +303,6 @@ Neuron.prototype = {
|
||||
optimize: function(optimized, layer) {
|
||||
|
||||
optimized = optimized || {};
|
||||
var that = this;
|
||||
var store_activation = [];
|
||||
var store_trace = [];
|
||||
var store_propagation = [];
|
||||
@@ -327,7 +325,7 @@ Neuron.prototype = {
|
||||
layers.__count = store.push([]) - 1;
|
||||
layers[layer] = layers.__count;
|
||||
}
|
||||
}
|
||||
};
|
||||
allocate(activation_sentences);
|
||||
allocate(trace_sentences);
|
||||
allocate(propagation_sentences);
|
||||
@@ -386,7 +384,7 @@ Neuron.prototype = {
|
||||
sentence += 'F[' + args[i].id + ']';
|
||||
|
||||
store.push(sentence + ';');
|
||||
}
|
||||
};
|
||||
|
||||
// helper to check if an object is empty
|
||||
var isEmpty = function(obj) {
|
||||
@@ -474,7 +472,6 @@ Neuron.prototype = {
|
||||
for (var id in this.trace.extended) {
|
||||
// calculate extended elegibility traces in advance
|
||||
|
||||
var xtrace = this.trace.extended[id];
|
||||
var neuron = this.neighboors[id];
|
||||
var influence = getVar('influences[' + neuron.ID + ']');
|
||||
var neuron_old = getVar(neuron, 'old');
|
||||
@@ -532,10 +529,8 @@ Neuron.prototype = {
|
||||
}
|
||||
for (var id in this.trace.extended) {
|
||||
// extended elegibility trace
|
||||
var xtrace = this.trace.extended[id];
|
||||
var neuron = this.neighboors[id];
|
||||
var influence = getVar('influences[' + neuron.ID + ']');
|
||||
var neuron_old = getVar(neuron, 'old');
|
||||
|
||||
var trace = getVar(this, 'trace', 'elegibility', input.ID, this.trace
|
||||
.elegibility[input.ID]);
|
||||
|
||||
+4
-55
@@ -1,54 +1,3 @@
|
||||
/*
|
||||
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2014 Juan Cazala - juancazala.com
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE
|
||||
|
||||
|
||||
|
||||
********************************************************************************************
|
||||
SYNAPTIC
|
||||
********************************************************************************************
|
||||
|
||||
Synaptic is a javascript neural network library for node.js and the browser, its generalized
|
||||
algorithm is architecture-free, so you can build and train basically any type of first order
|
||||
or even second order neural network architectures.
|
||||
|
||||
http://en.wikipedia.org/wiki/Recurrent_neural_network#Second_Order_Recurrent_Neural_Network
|
||||
|
||||
The library includes a few built-in architectures like multilayer perceptrons, multilayer
|
||||
long-short term memory networks (LSTM) or liquid state machines, and a trainer capable of
|
||||
training any given network, and includes built-in training tasks/tests like solving an XOR,
|
||||
passing a Distracted Sequence Recall test or an Embeded Reber Grammar test.
|
||||
|
||||
The algorithm implemented by this library has been taken from Derek D. Monner's paper:
|
||||
|
||||
A generalized LSTM-like training algorithm for second-order recurrent neural networks
|
||||
http://www.overcomplete.net/papers/nn2012.pdf
|
||||
|
||||
There are references to the equations in that paper commented through the source code.
|
||||
|
||||
|
||||
********************************************************************************************/
|
||||
|
||||
var Synaptic = {
|
||||
Neuron: require('./neuron'),
|
||||
Layer: require('./layer'),
|
||||
@@ -72,12 +21,12 @@ if (typeof module !== 'undefined' && module.exports)
|
||||
// Browser
|
||||
if (typeof window == 'object')
|
||||
{
|
||||
(function(){
|
||||
(function(){
|
||||
var oldSynaptic = window['synaptic'];
|
||||
Synaptic.ninja = function(){
|
||||
window['synaptic'] = oldSynaptic;
|
||||
Synaptic.ninja = function(){
|
||||
window['synaptic'] = oldSynaptic;
|
||||
return Synaptic;
|
||||
};
|
||||
};
|
||||
})();
|
||||
|
||||
window['synaptic'] = Synaptic;
|
||||
|
||||
+89
-155
@@ -10,7 +10,7 @@ function Trainer(network, options) {
|
||||
this.network = network;
|
||||
this.rate = options.rate || .2;
|
||||
this.iterations = options.iterations || 100000;
|
||||
this.error = options.error || .005
|
||||
this.error = options.error || .005;
|
||||
this.cost = options.cost || null;
|
||||
this.crossValidate = options.crossValidate || null;
|
||||
}
|
||||
@@ -23,7 +23,7 @@ Trainer.prototype = {
|
||||
var error = 1;
|
||||
var iterations = bucketSize = 0;
|
||||
var abort = false;
|
||||
var input, output, target, currentRate;
|
||||
var currentRate;
|
||||
var cost = options && options.cost || this.cost || Trainer.cost.MSE;
|
||||
var crossValidate = false, testSet, trainSet;
|
||||
|
||||
@@ -53,7 +53,8 @@ Trainer.prototype = {
|
||||
console.log('Deprecated: use schedule instead of customLog')
|
||||
this.schedule = options.customLog;
|
||||
}
|
||||
if (this.crossValidate) {
|
||||
if (this.crossValidate || options.crossValidate) {
|
||||
if(!this.crossValidate) this.crossValidate = {};
|
||||
crossValidate = true;
|
||||
if (options.crossValidate.testSize)
|
||||
this.crossValidate.testSize = options.crossValidate.testSize;
|
||||
@@ -64,7 +65,7 @@ Trainer.prototype = {
|
||||
|
||||
currentRate = this.rate;
|
||||
if(Array.isArray(this.rate)) {
|
||||
bucketSize = Math.floor(this.iterations / this.rate.length);
|
||||
var bucketSize = Math.floor(this.iterations / this.rate.length);
|
||||
}
|
||||
|
||||
if(crossValidate) {
|
||||
@@ -73,6 +74,7 @@ Trainer.prototype = {
|
||||
testSet = set.slice(numTrain);
|
||||
}
|
||||
|
||||
var lastError = 0;
|
||||
while ((!abort && iterations < this.iterations && error > this.error)) {
|
||||
if (crossValidate && error <= this.crossValidate.testError) {
|
||||
break;
|
||||
@@ -80,11 +82,16 @@ Trainer.prototype = {
|
||||
|
||||
var currentSetSize = set.length;
|
||||
error = 0;
|
||||
iterations++;
|
||||
|
||||
if(bucketSize > 0) {
|
||||
var currentBucket = Math.floor(iterations / bucketSize);
|
||||
currentRate = this.rate[currentBucket] || currentRate;
|
||||
}
|
||||
|
||||
if(typeof this.rate === 'function') {
|
||||
currentRate = this.rate(iterations, lastError);
|
||||
}
|
||||
|
||||
if (crossValidate) {
|
||||
this._trainSet(trainSet, currentRate, cost);
|
||||
@@ -96,17 +103,13 @@ Trainer.prototype = {
|
||||
}
|
||||
|
||||
// check error
|
||||
iterations++;
|
||||
error /= currentSetSize;
|
||||
lastError = error;
|
||||
|
||||
if (options) {
|
||||
if (this.schedule && this.schedule.every && iterations %
|
||||
this.schedule.every == 0)
|
||||
abort = this.schedule.do({
|
||||
error: error,
|
||||
iterations: iterations,
|
||||
rate: currentRate
|
||||
});
|
||||
abort = this.schedule.do({ error: error, iterations: iterations, rate: currentRate });
|
||||
else if (options.log && iterations % options.log == 0) {
|
||||
console.log('iterations', iterations, 'error', error, 'rate', currentRate);
|
||||
};
|
||||
@@ -119,19 +122,31 @@ Trainer.prototype = {
|
||||
error: error,
|
||||
iterations: iterations,
|
||||
time: Date.now() - start
|
||||
}
|
||||
};
|
||||
|
||||
return results;
|
||||
},
|
||||
|
||||
// trains any given set to a network, using a WebWorker (only for the browser). Returns a Promise of the results.
|
||||
trainAsync: function(set, options) {
|
||||
var train = this.workerTrain.bind(this);
|
||||
return new Promise(function(resolve, reject) {
|
||||
try {
|
||||
train(set, resolve, options, true)
|
||||
} catch(e) {
|
||||
reject(e)
|
||||
}
|
||||
})
|
||||
},
|
||||
|
||||
// preforms one training epoch and returns the error (private function used in this.train)
|
||||
_trainSet: function(set, currentRate, costFunction) {
|
||||
var errorSum = 0;
|
||||
for (var train in set) {
|
||||
input = set[train].input;
|
||||
target = set[train].output;
|
||||
var input = set[train].input;
|
||||
var target = set[train].output;
|
||||
|
||||
output = this.network.activate(input);
|
||||
var output = this.network.activate(input);
|
||||
this.network.propagate(currentRate, target);
|
||||
|
||||
errorSum += costFunction(target, output);
|
||||
@@ -143,7 +158,6 @@ Trainer.prototype = {
|
||||
test: function(set, options) {
|
||||
|
||||
var error = 0;
|
||||
var abort = false;
|
||||
var input, output, target;
|
||||
var cost = options && options.cost || this.cost || Trainer.cost.MSE;
|
||||
|
||||
@@ -161,142 +175,60 @@ Trainer.prototype = {
|
||||
var results = {
|
||||
error: error,
|
||||
time: Date.now() - start
|
||||
}
|
||||
};
|
||||
|
||||
return results;
|
||||
},
|
||||
|
||||
// trains any given set to a network using a WebWorker
|
||||
workerTrain: function(set, callback, options) {
|
||||
// trains any given set to a network using a WebWorker [deprecated: use trainAsync instead]
|
||||
workerTrain: function(set, callback, options, suppressWarning) {
|
||||
|
||||
if (!suppressWarning) {
|
||||
console.warn('Deprecated: do not use `workerTrain`, use `trainAsync` instead.')
|
||||
}
|
||||
var that = this;
|
||||
var error = 1;
|
||||
var iterations = bucketSize = 0;
|
||||
var input, output, target, currentRate;
|
||||
var length = set.length;
|
||||
var abort = false;
|
||||
var cost = options && options.cost || that.cost || Trainer.cost.MSE;
|
||||
|
||||
var start = Date.now();
|
||||
if (!this.network.optimized)
|
||||
this.network.optimize();
|
||||
|
||||
if (options) {
|
||||
if (options.shuffle) {
|
||||
//+ Jonas Raoni Soares Silva
|
||||
//@ http://jsfromhell.com/array/shuffle [v1.0]
|
||||
function shuffle(o) { //v1.0
|
||||
for (var j, x, i = o.length; i; j = Math.floor(Math.random() *
|
||||
i), x = o[--i], o[i] = o[j], o[j] = x);
|
||||
return o;
|
||||
};
|
||||
}
|
||||
if (options.iterations)
|
||||
that.iterations = options.iterations;
|
||||
if (options.error)
|
||||
that.error = options.error;
|
||||
if (options.rate)
|
||||
that.rate = options.rate;
|
||||
if (options.cost)
|
||||
that.cost = options.cost;
|
||||
if (options.schedule)
|
||||
that.schedule = options.schedule;
|
||||
if (options.customLog)
|
||||
{
|
||||
// for backward compatibility with code that used customLog
|
||||
console.log('Deprecated: use schedule instead of customLog')
|
||||
that.schedule = options.customLog;
|
||||
}
|
||||
}
|
||||
|
||||
// dynamic learning rate
|
||||
currentRate = that.rate;
|
||||
if(Array.isArray(that.rate)) {
|
||||
bucketSize = Math.floor(that.iterations / that.rate.length);
|
||||
}
|
||||
|
||||
// create a worker
|
||||
var worker = that.network.worker();
|
||||
|
||||
// activate the network
|
||||
function activateWorker(input)
|
||||
{
|
||||
worker.postMessage({
|
||||
action: "activate",
|
||||
input: input,
|
||||
memoryBuffer: that.network.optimized.memory
|
||||
}, [that.network.optimized.memory.buffer]);
|
||||
}
|
||||
|
||||
// backpropagate the network
|
||||
function propagateWorker(target){
|
||||
if(bucketSize > 0) {
|
||||
var currentBucket = Math.floor(iterations / bucketSize);
|
||||
currentRate = that.rate[currentBucket] || currentRate;
|
||||
}
|
||||
worker.postMessage({
|
||||
action: "propagate",
|
||||
target: target,
|
||||
rate: currentRate,
|
||||
memoryBuffer: that.network.optimized.memory
|
||||
}, [that.network.optimized.memory.buffer]);
|
||||
}
|
||||
// Create a new worker
|
||||
var worker = this.network.worker(this.network.optimized.memory, set, options);
|
||||
|
||||
// train the worker
|
||||
worker.onmessage = function(e){
|
||||
// give control of the memory back to the network
|
||||
that.network.optimized.ownership(e.data.memoryBuffer);
|
||||
worker.onmessage = function(e) {
|
||||
switch(e.data.action) {
|
||||
case 'done':
|
||||
var iterations = e.data.message.iterations;
|
||||
var error = e.data.message.error;
|
||||
var time = e.data.message.time;
|
||||
|
||||
if (e.data.action == "propagate")
|
||||
{
|
||||
if (index >= length)
|
||||
{
|
||||
index = 0;
|
||||
iterations++;
|
||||
error /= set.length;
|
||||
that.network.optimized.ownership(e.data.memoryBuffer);
|
||||
|
||||
// log
|
||||
if (options) {
|
||||
if (that.schedule && that.schedule.every && iterations % that.schedule.every == 0)
|
||||
abort = that.schedule.do({
|
||||
error: error,
|
||||
iterations: iterations,
|
||||
rate: currentRate
|
||||
});
|
||||
else if (options.log && iterations % options.log == 0) {
|
||||
console.log('iterations', iterations, 'error', error);
|
||||
};
|
||||
if (options.shuffle)
|
||||
shuffle(set);
|
||||
}
|
||||
// Done callback
|
||||
callback({
|
||||
error: error,
|
||||
iterations: iterations,
|
||||
time: time
|
||||
});
|
||||
|
||||
if (!abort && iterations < that.iterations && error > that.error)
|
||||
{
|
||||
activateWorker(set[index].input);
|
||||
} else {
|
||||
// callback
|
||||
callback({
|
||||
error: error,
|
||||
iterations: iterations,
|
||||
time: Date.now() - start
|
||||
})
|
||||
}
|
||||
error = 0;
|
||||
} else {
|
||||
activateWorker(set[index].input);
|
||||
// Delete the worker and all its associated memory
|
||||
worker.terminate();
|
||||
break;
|
||||
|
||||
case 'log':
|
||||
console.log(e.data.message);
|
||||
|
||||
case 'schedule':
|
||||
if (options && options.schedule && typeof options.schedule.do === 'function') {
|
||||
var scheduled = options.schedule.do
|
||||
scheduled(e.data.message)
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
if (e.data.action == "activate")
|
||||
{
|
||||
error += cost(set[index].output, e.data.output);
|
||||
propagateWorker(set[index].output);
|
||||
index++;
|
||||
}
|
||||
}
|
||||
|
||||
// kick it
|
||||
var index = 0;
|
||||
var iterations = 0;
|
||||
activateWorker(set[index].input);
|
||||
// Start the worker
|
||||
worker.postMessage({action: 'startTraining'});
|
||||
},
|
||||
|
||||
// trains an XOR to the network
|
||||
@@ -310,7 +242,7 @@ Trainer.prototype = {
|
||||
log: false,
|
||||
shuffle: true,
|
||||
cost: Trainer.cost.MSE
|
||||
}
|
||||
};
|
||||
|
||||
if (options)
|
||||
for (var i in options)
|
||||
@@ -346,8 +278,9 @@ Trainer.prototype = {
|
||||
var schedule = options.schedule || {};
|
||||
var cost = options.cost || this.cost || Trainer.cost.CROSS_ENTROPY;
|
||||
|
||||
var trial = correct = i = j = success = 0,
|
||||
error = 1,
|
||||
var trial, correct, i, j, success;
|
||||
trial = correct = i = j = success = 0;
|
||||
var error = 1,
|
||||
symbols = targets.length + distractors.length + prompts.length;
|
||||
|
||||
var noRepeat = function(range, avoid) {
|
||||
@@ -357,14 +290,14 @@ Trainer.prototype = {
|
||||
if (number == avoid[i])
|
||||
used = true;
|
||||
return used ? noRepeat(range, avoid) : number;
|
||||
}
|
||||
};
|
||||
|
||||
var equal = function(prediction, output) {
|
||||
for (var i in prediction)
|
||||
if (Math.round(prediction[i]) != output[i])
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
var start = Date.now();
|
||||
|
||||
@@ -389,6 +322,7 @@ Trainer.prototype = {
|
||||
}
|
||||
|
||||
//train sequence
|
||||
var distractorsCorrect;
|
||||
var targetsCorrect = distractorsCorrect = 0;
|
||||
error = 0;
|
||||
for (i = 0; i < length; i++) {
|
||||
@@ -470,7 +404,7 @@ Trainer.prototype = {
|
||||
// gramar node
|
||||
var Node = function() {
|
||||
this.paths = [];
|
||||
}
|
||||
};
|
||||
Node.prototype = {
|
||||
connect: function(node, value) {
|
||||
this.paths.push({
|
||||
@@ -491,7 +425,7 @@ Trainer.prototype = {
|
||||
return this.paths[i];
|
||||
return false;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
var reberGrammar = function() {
|
||||
|
||||
@@ -500,19 +434,19 @@ Trainer.prototype = {
|
||||
var n1 = (new Node()).connect(output, "E");
|
||||
var n2 = (new Node()).connect(n1, "S");
|
||||
var n3 = (new Node()).connect(n1, "V").connect(n2, "P");
|
||||
var n4 = (new Node()).connect(n2, "X")
|
||||
var n4 = (new Node()).connect(n2, "X");
|
||||
n4.connect(n4, "S");
|
||||
var n5 = (new Node()).connect(n3, "V")
|
||||
var n5 = (new Node()).connect(n3, "V");
|
||||
n5.connect(n5, "T");
|
||||
n2.connect(n5, "X")
|
||||
n2.connect(n5, "X");
|
||||
var n6 = (new Node()).connect(n4, "T").connect(n5, "P");
|
||||
var input = (new Node()).connect(n6, "B")
|
||||
var input = (new Node()).connect(n6, "B");
|
||||
|
||||
return {
|
||||
input: input,
|
||||
output: output
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// build an embeded reber grammar
|
||||
var embededReberGrammar = function() {
|
||||
@@ -532,7 +466,7 @@ Trainer.prototype = {
|
||||
output: output
|
||||
}
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
// generate an ERG sequence
|
||||
var generate = function() {
|
||||
@@ -544,7 +478,7 @@ Trainer.prototype = {
|
||||
next = next.node.any();
|
||||
}
|
||||
return str;
|
||||
}
|
||||
};
|
||||
|
||||
// test if a string matches an embeded reber grammar
|
||||
var test = function(str) {
|
||||
@@ -559,7 +493,7 @@ Trainer.prototype = {
|
||||
ch = str.charAt(++i);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
// helper to check if the output and the target vectors match
|
||||
var different = function(array1, array2) {
|
||||
@@ -579,7 +513,7 @@ Trainer.prototype = {
|
||||
}
|
||||
|
||||
return i1 != i2;
|
||||
}
|
||||
};
|
||||
|
||||
var iteration = 0;
|
||||
var error = 1;
|
||||
@@ -590,7 +524,7 @@ Trainer.prototype = {
|
||||
"X": 3,
|
||||
"S": 4,
|
||||
"E": 5
|
||||
}
|
||||
};
|
||||
|
||||
var start = Date.now();
|
||||
while (iteration < iterations && error > criterion) {
|
||||
@@ -649,7 +583,7 @@ Trainer.prototype = {
|
||||
throw new Error("Invalid Network: must have 2 inputs and one output");
|
||||
|
||||
if (typeof options == 'undefined')
|
||||
var options = {};
|
||||
options = {};
|
||||
|
||||
// helper
|
||||
function getSamples (trainingSize, testSize){
|
||||
@@ -659,7 +593,7 @@ Trainer.prototype = {
|
||||
|
||||
// generate samples
|
||||
var t = 0;
|
||||
var set = [];
|
||||
var set = [];
|
||||
for (var i = 0; i < size; i++) {
|
||||
set.push({ input: [0,0], output: [0] });
|
||||
}
|
||||
|
||||
+2
-2
@@ -1,6 +1,6 @@
|
||||
Test using gulp, from root directory:
|
||||
Test using mocha, from root directory:
|
||||
|
||||
`gulp test`
|
||||
`mocha test`
|
||||
|
||||
To test the web version, start a web server at the root dir of this repo, then use your OS browser.
|
||||
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
global.synaptic = require('../dist/synaptic');
|
||||
@@ -0,0 +1 @@
|
||||
global.synaptic = require('../src/synaptic');
|
||||
@@ -0,0 +1 @@
|
||||
[^_]*.js
|
||||
+208
-184
@@ -1,62 +1,66 @@
|
||||
// import
|
||||
var chai = require('chai');
|
||||
chai.use(require('chai-stats'));
|
||||
var assert = chai.assert;
|
||||
|
||||
var assert = require('assert'),
|
||||
synaptic = require('../src/synaptic');
|
||||
var Perceptron = synaptic.Architect.Perceptron;
|
||||
var LSTM = synaptic.Architect.LSTM;
|
||||
var Layer = synaptic.Layer;
|
||||
var Network = synaptic.Network;
|
||||
var Trainer = synaptic.Trainer;
|
||||
|
||||
var Perceptron = synaptic.Architect.Perceptron,
|
||||
LSTM = synaptic.Architect.LSTM,
|
||||
Layer = synaptic.Layer,
|
||||
Network = synaptic.Network,
|
||||
Trainer = synaptic.Trainer;
|
||||
|
||||
|
||||
var learningRate = .5;
|
||||
|
||||
|
||||
// utils
|
||||
|
||||
function noRepeat (range, avoid) {
|
||||
function noRepeat(range, avoid) {
|
||||
var number = Math.random() * range | 0;
|
||||
for (var i in avoid){
|
||||
if (number == avoid[i]){
|
||||
return noRepeat(range,avoid);
|
||||
for (var i in avoid) {
|
||||
if (number == avoid[i]) {
|
||||
return noRepeat(range, avoid);
|
||||
}
|
||||
}
|
||||
return number;
|
||||
}
|
||||
|
||||
function equal (prediction, output) {
|
||||
function equal(prediction, output) {
|
||||
for (var i in prediction)
|
||||
if (Math.round(prediction[i]) != output[i])
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
function generateRandomArray (size){
|
||||
var array = [];
|
||||
for (var j = 0; j < size; j++)
|
||||
array.push(Math.random() + .5 | 0);
|
||||
return array;
|
||||
function generateRandomArray(size) {
|
||||
var array = [];
|
||||
for (var j = 0; j < size; j++)
|
||||
array.push(Math.random() + .5 | 0);
|
||||
return array;
|
||||
}
|
||||
|
||||
function compare (a, b) {
|
||||
function calculateMse(a, b) {
|
||||
var mse = 0;
|
||||
for (var k in a)
|
||||
mse += Math.pow(a[k] - b[k], 2);
|
||||
mse /= a.length;
|
||||
|
||||
return mse < 1e-10;
|
||||
return mse;
|
||||
}
|
||||
|
||||
function equalWithError (output, expected, error) {
|
||||
function equalWithError(output, expected, error) {
|
||||
return Math.abs(output - expected) <= error;
|
||||
}
|
||||
|
||||
// specs
|
||||
|
||||
describe('Basic Neural Network', function() {
|
||||
describe('Basic Neural Network', function () {
|
||||
|
||||
it("trains an AND gate", function() {
|
||||
it("trains an AND gate", function () {
|
||||
|
||||
var inputLayer = new Layer(2),
|
||||
outputLayer = new Layer(1);
|
||||
outputLayer = new Layer(1);
|
||||
|
||||
inputLayer.project(outputLayer);
|
||||
|
||||
@@ -99,7 +103,7 @@ describe('Basic Neural Network', function() {
|
||||
assert.equal(test11, 1, "[1,1] did not output 1");
|
||||
});
|
||||
|
||||
it("trains an OR gate", function() {
|
||||
it("trains an OR gate", function () {
|
||||
|
||||
var inputLayer = new Layer(2),
|
||||
outputLayer = new Layer(1);
|
||||
@@ -145,11 +149,10 @@ describe('Basic Neural Network', function() {
|
||||
assert.equal(test11, 1, "[1,1] did not output 1");
|
||||
});
|
||||
|
||||
it("trains a NOT gate", function() {
|
||||
it("trains a NOT gate", function () {
|
||||
|
||||
var inputLayer = new Layer(1),
|
||||
outputLayer = new Layer(1),
|
||||
network;
|
||||
outputLayer = new Layer(1);
|
||||
|
||||
inputLayer.project(outputLayer);
|
||||
|
||||
@@ -180,50 +183,43 @@ describe('Basic Neural Network', function() {
|
||||
});
|
||||
});
|
||||
|
||||
describe("Perceptron - XOR", function() {
|
||||
describe("Perceptron - XOR", function () {
|
||||
|
||||
var perceptron = new Perceptron(2, 3, 1);
|
||||
perceptron.trainer.XOR();
|
||||
|
||||
var test00 = Math.round(perceptron.activate([0, 0]));
|
||||
it("input: [0,0] output: " + test00, function() {
|
||||
|
||||
assert.equal(test00, 0, "[0,0] did not output 0");
|
||||
it("should return near-0 value on [0,0]", function () {
|
||||
assert.isAtMost(perceptron.activate([0, 0]), .49, "[0,0] did not output 0");
|
||||
});
|
||||
|
||||
var test01 = Math.round(perceptron.activate([0, 1]));
|
||||
it("input: [0,1] output: " + test01, function() {
|
||||
|
||||
assert.equal(test01, 1, "[0,1] did not output 1");
|
||||
it("should return near-1 value on [0,1]", function () {
|
||||
assert.isAtLeast(perceptron.activate([0, 1]), .51, "[0,1] did not output 1");
|
||||
});
|
||||
|
||||
var test10 = Math.round(perceptron.activate([1, 0]));
|
||||
it("input: [1,0] output: " + test10, function() {
|
||||
|
||||
assert.equal(test10, 1, "[1,0] did not output 1");
|
||||
it("should return near-1 value on [1,0]", function () {
|
||||
assert.isAtLeast(perceptron.activate([1, 0]), .51, "[1,0] did not output 1");
|
||||
});
|
||||
|
||||
var test11 = Math.round(perceptron.activate([1, 1]));
|
||||
it("input: [1,1] output: " + test11, function() {
|
||||
|
||||
assert.equal(test11, 0, "[1,1] did not output 0");
|
||||
it("should return near-0 value on [1,1]", function () {
|
||||
assert.isAtMost(perceptron.activate([1, 1]), .49, "[1,1] did not output 0");
|
||||
});
|
||||
});
|
||||
|
||||
describe("Perceptron - SIN", function() {
|
||||
var mySin = function(x) {
|
||||
return (Math.sin(x)+1)/2;
|
||||
describe("Perceptron - SIN", function () {
|
||||
var mySin = function (x) {
|
||||
return (Math.sin(x) + 1) / 2;
|
||||
};
|
||||
|
||||
var sinNetwork = new Perceptron(1, 12, 1);
|
||||
var trainingSet = [];
|
||||
|
||||
var trainingSet = Array.apply(null, Array(800)).map(function () {
|
||||
while (trainingSet.length < 800) {
|
||||
var inputValue = Math.random() * Math.PI * 2;
|
||||
return {
|
||||
trainingSet.push({
|
||||
input: [inputValue],
|
||||
output: [mySin(inputValue)]
|
||||
};
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
var results = sinNetwork.trainer.train(trainingSet, {
|
||||
iterations: 2000,
|
||||
@@ -232,38 +228,26 @@ describe("Perceptron - SIN", function() {
|
||||
cost: Trainer.cost.MSE,
|
||||
});
|
||||
|
||||
var test0 = sinNetwork.activate([0])[0];
|
||||
var expected0 = mySin(0);
|
||||
it("input: [0] output: " + test0 + ", expected: " + expected0, function() {
|
||||
var eq = equalWithError(test0, expected0, .035);
|
||||
assert.equal(eq, true, "[0] did not output " + expected0);
|
||||
});
|
||||
|
||||
var test05PI = sinNetwork.activate([.5*Math.PI])[0];
|
||||
var expected05PI = mySin(.5*Math.PI);
|
||||
it("input: [0.5*Math.PI] output: " + test05PI + ", expected: " + expected05PI, function() {
|
||||
var eq = equalWithError(test05PI, expected05PI, .035);
|
||||
assert.equal(eq, true, "[0.5*Math.PI] did not output " + expected05PI);
|
||||
});
|
||||
|
||||
var test2 = sinNetwork.activate([2])[0];
|
||||
var expected2 = mySin(2);
|
||||
it("input: [2] output: " + test2 + ", expected: " + expected2, function() {
|
||||
var eq = equalWithError(test2, expected2, .035);
|
||||
assert.equal(eq, true, "[2] did not output " + expected2);
|
||||
});
|
||||
[0, .5 * Math.PI, 2]
|
||||
.forEach(function (x) {
|
||||
var y = mySin(x);
|
||||
it("should return value around " + y + " when [" + x + "] is on input", function () {
|
||||
// near scalability: abs(expected-actual) < 0.5 * 10**(-decimal)
|
||||
// 0.5 * Math.pow(10, -.15) => 0.35397289219206896
|
||||
assert.almostEqual(sinNetwork.activate([x])[0], y, .15);
|
||||
});
|
||||
});
|
||||
|
||||
var errorResult = results.error;
|
||||
it("Sin error: " + errorResult, function() {
|
||||
var lessThanOrEqualError = errorResult <= .001;
|
||||
assert.equal(lessThanOrEqualError, true, "Sin error not less than or equal to desired error.");
|
||||
it("Sin error: " + errorResult, function () {
|
||||
assert.isAtMost(errorResult, .001, "Sin error not less than or equal to desired error.");
|
||||
});
|
||||
});
|
||||
|
||||
describe("Perceptron - SIN - CrossValidate", function() {
|
||||
describe("Perceptron - SIN - CrossValidate", function () {
|
||||
|
||||
var mySin = function(x) {
|
||||
return (Math.sin(x)+1)/2;
|
||||
var mySin = function (x) {
|
||||
return (Math.sin(x) + 1) / 2;
|
||||
};
|
||||
|
||||
var sinNetwork = new Perceptron(1, 12, 1);
|
||||
@@ -289,33 +273,31 @@ describe("Perceptron - SIN - CrossValidate", function() {
|
||||
|
||||
var test0 = sinNetwork.activate([0])[0];
|
||||
var expected0 = mySin(0);
|
||||
it("input: [0] output: " + test0 + ", expected: " + expected0, function() {
|
||||
var eq = equalWithError(test0, expected0, .035);
|
||||
assert.equal(eq, true, "[0] did not output " + expected0);
|
||||
it("input: [0] output: " + test0 + ", expected: " + expected0, function () {
|
||||
assert.isAtMost(Math.abs(test0 - expected0), .035, "[0] did not output " + expected0);
|
||||
});
|
||||
|
||||
var test05PI = sinNetwork.activate([.5*Math.PI])[0];
|
||||
var expected05PI = mySin(.5*Math.PI);
|
||||
it("input: [0.5*Math.PI] output: " + test05PI + ", expected: " + expected05PI, function() {
|
||||
var eq = equalWithError(test05PI, expected05PI, .035);
|
||||
assert.equal(eq, true, "[0.5*Math.PI] did not output " + expected05PI);
|
||||
var test05PI = sinNetwork.activate([.5 * Math.PI])[0];
|
||||
var expected05PI = mySin(.5 * Math.PI);
|
||||
it("input: [0.5*Math.PI] output: " + test05PI + ", expected: " + expected05PI, function () {
|
||||
assert.isAtMost(Math.abs(test05PI - expected05PI), .035, "[0.5*Math.PI] did not output " + expected05PI);
|
||||
});
|
||||
|
||||
var test2 = sinNetwork.activate([2])[0];
|
||||
var expected2 = mySin(2);
|
||||
it("input: [2] output: " + test2 + ", expected: " + expected2, function() {
|
||||
it("input: [2] output: " + test2 + ", expected: " + expected2, function () {
|
||||
var eq = equalWithError(test2, expected2, .035);
|
||||
assert.equal(eq, true, "[2] did not output " + expected2);
|
||||
});
|
||||
|
||||
var errorResult = results.error;
|
||||
it("CrossValidation error: " + errorResult, function() {
|
||||
it("CrossValidation error: " + errorResult, function () {
|
||||
var lessThanOrEqualError = errorResult <= .001;
|
||||
assert.equal(lessThanOrEqualError, true, "CrossValidation error not less than or equal to desired error.");
|
||||
});
|
||||
});
|
||||
|
||||
describe("LSTM - Discrete Sequence Recall", function() {
|
||||
describe("LSTM - Discrete Sequence Recall", function () {
|
||||
|
||||
var targets = [2, 4];
|
||||
var distractors = [3, 5];
|
||||
@@ -353,7 +335,7 @@ describe("LSTM - Discrete Sequence Recall", function() {
|
||||
sequence.push(prompts[i]);
|
||||
}
|
||||
|
||||
var check = function(which) {
|
||||
var check = function (which) {
|
||||
// generate input from sequence
|
||||
var input = [];
|
||||
for (j = 0; j < symbols; j++)
|
||||
@@ -378,7 +360,7 @@ describe("LSTM - Discrete Sequence Recall", function() {
|
||||
};
|
||||
};
|
||||
|
||||
var value = function(array) {
|
||||
var value = function (array) {
|
||||
var max = .5;
|
||||
var res = -1;
|
||||
for (var i in array)
|
||||
@@ -389,190 +371,232 @@ describe("LSTM - Discrete Sequence Recall", function() {
|
||||
return res == -1 ? '-' : targets[res];
|
||||
};
|
||||
|
||||
it("targets: " + targets, function() {
|
||||
it("targets: " + targets, function () {
|
||||
assert(true);
|
||||
});
|
||||
it("distractors: " + distractors, function() {
|
||||
it("distractors: " + distractors, function () {
|
||||
assert(true);
|
||||
});
|
||||
it("prompts: " + prompts, function() {
|
||||
it("prompts: " + prompts, function () {
|
||||
assert(true);
|
||||
});
|
||||
it("length: " + length + "\n", function() {
|
||||
it("length: " + length + "\n", function () {
|
||||
assert(true);
|
||||
});
|
||||
|
||||
for (var i = 0; i < length; i++) {
|
||||
var test = check(i);
|
||||
it((i + 1) + ") input: " + sequence[i] + " output: " + value(test.prediction),
|
||||
function() {
|
||||
function () {
|
||||
var ok = equal(test.prediction, test.output);
|
||||
assert(ok);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
describe("LSTM - Timing Task", function() {
|
||||
var network = new synaptic.Architect.LSTM(2,7,1);
|
||||
describe("LSTM - Timing Task", function () {
|
||||
var network = new LSTM(2, 7, 1);
|
||||
var result = network.trainer.timingTask({
|
||||
log: false,
|
||||
trainSamples: 4000,
|
||||
testSamples: 500
|
||||
});
|
||||
|
||||
it("should complete the training in less than 200 iterations", function() {
|
||||
it("should complete the training in less than 200 iterations", function () {
|
||||
assert(result.train.iterations <= 200);
|
||||
});
|
||||
|
||||
it("should pass the test with an error smaller than 0.05", function() {
|
||||
it("should pass the test with an error smaller than 0.05", function () {
|
||||
assert(result.test.error < .05);
|
||||
});
|
||||
});
|
||||
|
||||
describe("Optimized and Unoptimized Networks Equivalency", function() {
|
||||
var optimized = new LSTM(2,1,1)
|
||||
describe("Optimized and Unoptimized Networks Equivalency", function () {
|
||||
|
||||
var unoptimized = optimized.clone();
|
||||
unoptimized.setOptimize(false);
|
||||
var optimized;
|
||||
var unoptimized;
|
||||
beforeEach(function () {
|
||||
optimized = new LSTM(2, 1, 1);
|
||||
unoptimized = optimized.clone();
|
||||
unoptimized.setOptimize(false);
|
||||
});
|
||||
|
||||
var learningRate = .5;
|
||||
var iterations = 1000;
|
||||
|
||||
for (var i = 1; i <= iterations; i++)
|
||||
{
|
||||
//random input
|
||||
it('should produce the same output for both networks', function () {
|
||||
this.timeout(30000);
|
||||
for (var i = 0; i < 1000; i++) {
|
||||
var input = generateRandomArray(2);
|
||||
|
||||
// activate networks
|
||||
var output1 = optimized.activate(input);
|
||||
var output2 = unoptimized.activate(input);
|
||||
|
||||
if (i % 100 == 0)
|
||||
it('should produce the same output for both networks after ' + i + ' iterations', function(){
|
||||
assert(compare(output1, output2));
|
||||
});
|
||||
|
||||
// random target
|
||||
var target = generateRandomArray(1);
|
||||
|
||||
// propagate networks
|
||||
optimized.activate(input);
|
||||
unoptimized.activate(input);
|
||||
optimized.propagate(learningRate, target);
|
||||
unoptimized.propagate(learningRate, target);
|
||||
}
|
||||
}
|
||||
var mse = calculateMse(optimized.activate(input), unoptimized.activate(input));
|
||||
assert.isAtMost(mse, 1e-9, 'output should be same for both networks after ' + i + ' iterations');
|
||||
});
|
||||
});
|
||||
|
||||
describe("toJSON/fromJSON Networks Equivalency", function() {
|
||||
var original = new LSTM(10,5,5);
|
||||
describe("toJSON/fromJSON Networks Equivalency", function () {
|
||||
var original;
|
||||
var imported;
|
||||
beforeEach(function () {
|
||||
original = new LSTM(10, 5, 5);
|
||||
imported = Network.fromJSON(original.toJSON());
|
||||
});
|
||||
|
||||
var exported = original.toJSON();
|
||||
var imported = Network.fromJSON(exported);
|
||||
|
||||
var learningRate = .5;
|
||||
var iterations = 1000;
|
||||
|
||||
for (var i = 1; i <= iterations; i++)
|
||||
{
|
||||
//random input
|
||||
it('should produce the same output for both networks', function () {
|
||||
this.timeout(30000);
|
||||
for (var i = 0; i < 1000; i++) {
|
||||
var input = generateRandomArray(10);
|
||||
|
||||
// activate networks
|
||||
var output1 = original.activate(input);
|
||||
var output2 = imported.activate(input);
|
||||
|
||||
if (i % 100 == 0)
|
||||
it('should produce the same output for both networks after ' + i + ' iterations', function(){
|
||||
assert(compare(output1, output2));
|
||||
});
|
||||
|
||||
// random target
|
||||
var target = generateRandomArray(5);
|
||||
|
||||
// propagate networks
|
||||
original.propagate(learningRate, target);
|
||||
imported.propagate(learningRate, target);
|
||||
}
|
||||
|
||||
assert.isAtMost(calculateMse(output1, output2), 1e-10,
|
||||
'output should be same for both networks after ' + i + ' iterations');
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe("Cloned Networks Equivalency", function() {
|
||||
describe("Cloned Networks Equivalency", function () {
|
||||
var original;
|
||||
var cloned;
|
||||
beforeEach(function () {
|
||||
original = new LSTM(10, 5, 5);
|
||||
cloned = Network.fromJSON(original.toJSON());
|
||||
});
|
||||
|
||||
var original = new LSTM(10,5,5);
|
||||
|
||||
var cloned = original.clone();
|
||||
|
||||
var learningRate = .5;
|
||||
var iterations = 1000;
|
||||
|
||||
for (var i = 1; i <= iterations; i++)
|
||||
{
|
||||
//random input
|
||||
it('should produce the same output for both networks', function () {
|
||||
this.timeout(30000);
|
||||
for (var i = 0; i < 1000; i++) {
|
||||
var input = generateRandomArray(10);
|
||||
|
||||
// activate networks
|
||||
var output1 = original.activate(input);
|
||||
var output2 = cloned.activate(input);
|
||||
|
||||
if (i % 100 == 0)
|
||||
it('should produce the same output for both networks after ' + i + ' iterations', function(){
|
||||
assert(compare(output1, output2));
|
||||
});
|
||||
|
||||
// random target
|
||||
var target = generateRandomArray(5);
|
||||
|
||||
// propagate networks
|
||||
original.propagate(learningRate, target);
|
||||
cloned.propagate(learningRate, target);
|
||||
}
|
||||
|
||||
assert.isAtMost(calculateMse(output1, output2), 1e-10,
|
||||
'output should be same for both networks after ' + i + ' iterations');
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe("Scheduled Tasks", function() {
|
||||
describe("Scheduled Tasks", function () {
|
||||
var perceptron = new Perceptron(2, 3, 1);
|
||||
|
||||
it('should stop training at 3000 iterations', function(){
|
||||
it('should stop training at 3000 iterations', function () {
|
||||
var final_stats = perceptron.trainer.XOR({
|
||||
iterations: 3000,
|
||||
rate: 0.000001,
|
||||
error: 0.000001,
|
||||
schedule: {
|
||||
every: 1000,
|
||||
do: function(data) {
|
||||
if( data.iterations == 20000){
|
||||
return true
|
||||
}
|
||||
}
|
||||
every: 1000,
|
||||
do: function (data) {
|
||||
return data.iterations == 20000;
|
||||
}
|
||||
}
|
||||
});
|
||||
assert.equal( final_stats.iterations, 3000 )
|
||||
assert.equal(final_stats.iterations, 3000)
|
||||
});
|
||||
|
||||
it('should abort the training at 2000 iterations', function(){
|
||||
it('should abort the training at 2000 iterations', function () {
|
||||
var final_stats = perceptron.trainer.XOR({
|
||||
iterations: 3000,
|
||||
rate: 0.000001,
|
||||
error: 0.000001,
|
||||
schedule: {
|
||||
every: 1000,
|
||||
do: function(data) {
|
||||
if( data.iterations == 2000){
|
||||
return true
|
||||
}
|
||||
}
|
||||
every: 1000,
|
||||
do: function (data) {
|
||||
return data.iterations == 2000;
|
||||
}
|
||||
}
|
||||
});
|
||||
assert.equal( final_stats.iterations, 2000 )
|
||||
assert.equal(final_stats.iterations, 2000)
|
||||
});
|
||||
|
||||
it('should work even if shedule.do() returns no value', function(){
|
||||
it('should work even if schedule.do() returns no value', function () {
|
||||
var final_stats = perceptron.trainer.XOR({
|
||||
iterations: 3000,
|
||||
rate: 0.000001,
|
||||
error: 0.000001,
|
||||
schedule: {
|
||||
every: 1000,
|
||||
do: function(data) {}
|
||||
}
|
||||
every: 1000,
|
||||
do: function (data) {}
|
||||
}
|
||||
});
|
||||
assert.equal( final_stats.iterations, 3000 )
|
||||
assert.equal(final_stats.iterations, 3000)
|
||||
});
|
||||
|
||||
});
|
||||
|
||||
describe("Rate Callback Check", function () {
|
||||
var perceptron = new Perceptron(2, 3, 1);
|
||||
|
||||
it('should switch rate from 0.01 to 0.005 after 1000 iterations', function () {
|
||||
var final_stats = perceptron.trainer.XOR({
|
||||
iterations: 2000,
|
||||
rate: function (iterations, error) {
|
||||
return iterations < 1000 ? 0.01 : 0.005
|
||||
},
|
||||
error: 0.000001,
|
||||
schedule: {
|
||||
every: 1,
|
||||
do: function (data) {
|
||||
switch (data.iterations) {
|
||||
case 1:
|
||||
case 500:
|
||||
case 999:
|
||||
assert.equal(data.rate, 0.01);
|
||||
break;
|
||||
|
||||
case 1000:
|
||||
case 1500:
|
||||
case 2000:
|
||||
assert.equal(data.rate, 0.005);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Rate Array Check", function () {
|
||||
var perceptron = new Perceptron(2, 3, 1);
|
||||
|
||||
it('should switch rate from 0.01 to 0.005 after 1000 iterations', function () {
|
||||
var final_stats = perceptron.trainer.XOR({
|
||||
iterations: 2000,
|
||||
rate: [0.01, 0.005],
|
||||
error: 0.000001,
|
||||
schedule: {
|
||||
every: 1,
|
||||
do: function (data) {
|
||||
switch (data.iterations) {
|
||||
case 1:
|
||||
case 500:
|
||||
case 999:
|
||||
assert.equal(data.rate, 0.01);
|
||||
break;
|
||||
|
||||
case 1000:
|
||||
case 1500:
|
||||
case 2000:
|
||||
assert.equal(data.rate, 0.005);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
var webpack = require('webpack')
|
||||
var license = require('./prebuild.js')
|
||||
module.exports = {
|
||||
context: __dirname,
|
||||
entry: {
|
||||
synaptic: './src/synaptic.js',
|
||||
'synaptic.min': './src/synaptic.js'
|
||||
},
|
||||
output: {
|
||||
path: 'dist',
|
||||
filename: '[name].js',
|
||||
},
|
||||
plugins: [
|
||||
new webpack.NoErrorsPlugin(),
|
||||
new webpack.BannerPlugin(license())
|
||||
]
|
||||
}
|
||||
Referência em uma Nova Issue
Bloquear um usuário