Merge pull request #102 from kylemath/ml5

[ready to deploy] - Working KNN classification, two training classes, live prediction of state and animation of prediction
Esse commit está contido em:
Kory
2020-01-04 23:22:16 -05:00
commit de GitHub
7 arquivos alterados com 1273 adições e 26 exclusões
+1
Ver Arquivo
@@ -13,6 +13,7 @@
"file-saver": "^2.0.2",
"firebase": "^7.5.0",
"firebase-tools": "^7.9.0",
"ml5": "^0.4.3",
"handlebars": "^4.3.0",
"muse-js": "^3.0.1",
"p5": "^0.10.2",
+22 -1
Ver Arquivo
@@ -15,6 +15,7 @@ import * as funAnimate from "./components/EEGEduAnimate/EEGEduAnimate";
import * as funSpectro from "./components/EEGEduSpectro/EEGEduSpectro";
import * as funAlpha from "./components/EEGEduAlpha/EEGEduAlpha"
import * as funSsvep from "./components/EEGEduSsvep/EEGEduSsvep"
import * as funPredict from "./components/EEGEduPredict/EEGEduPredict"
const intro = translations.types.intro;
const raw = translations.types.raw;
@@ -24,6 +25,7 @@ const animate = translations.types.animate;
const spectro = translations.types.spectro;
const alpha = translations.types.alpha;
const ssvep = translations.types.ssvep;
const predict = translations.types.predict;
export function PageSwitcher() {
@@ -36,6 +38,7 @@ export function PageSwitcher() {
const [spectroData, setSpectroData] = useState(emptyChannelData);
const [alphaData, setAlphaData] = useState(emptyChannelData);
const [ssvepData, setSsvepData] = useState(emptyChannelData);
const [predictData, setPredictData] = useState(emptyChannelData);
// pipe settings
const [introSettings] = useState(funIntro.getSettings);
@@ -46,6 +49,7 @@ export function PageSwitcher() {
const [spectroSettings, setSpectroSettings] = useState(funSpectro.getSettings);
const [alphaSettings, setAlphaSettings] = useState(funAlpha.getSettings);
const [ssvepSettings, setSsvepSettings] = useState(funSsvep.getSettings);
const [predictSettings, setPredictSettings] = useState(funPredict.getSettings);
// connection status
const [status, setStatus] = useState(generalTranslations.connect);
@@ -65,6 +69,7 @@ export function PageSwitcher() {
if (window.subscriptionSpectro) window.subscriptionSpectro.unsubscribe();
if (window.subscriptionAlpha) window.subscriptionAlpha.unsubscribe();
if (window.subscriptionSsvep) window.subscriptionSsvep.unsubscribe();
if (window.subscriptionPredict) window.subscriptionPredict.unsubscribe();
subscriptionSetup(value);
// eslint-disable-next-line react-hooks/exhaustive-deps
@@ -86,7 +91,9 @@ export function PageSwitcher() {
{ label: animate, value: animate },
{ label: spectro, value: spectro },
{ label: alpha, value: alpha },
{ label: ssvep, value: ssvep }
{ label: ssvep, value: ssvep },
{ label: predict, value: predict }
];
function buildPipes(value) {
@@ -98,6 +105,7 @@ export function PageSwitcher() {
funSpectro.buildPipe(spectroSettings);
funAlpha.buildPipe(alphaSettings);
funSsvep.buildPipe(ssvepSettings);
funPredict.buildPipe(predictSettings);
}
function subscriptionSetup(value) {
@@ -126,6 +134,9 @@ export function PageSwitcher() {
case ssvep:
funSsvep.setup(setSsvepData, ssvepSettings);
break;
case predict:
funPredict.setup(setPredictData, predictSettings);
break;
default:
console.log(
"Error on handle Subscriptions. Couldn't switch to: " + value
@@ -201,6 +212,10 @@ export function PageSwitcher() {
return (
funSsvep.renderSliders(setSsvepData, setSsvepSettings, status, ssvepSettings)
);
case predict:
return (
funPredict.renderSliders(setPredictData, setPredictSettings, status, predictSettings)
);
default: console.log('Error rendering settings display');
}
}
@@ -223,6 +238,8 @@ export function PageSwitcher() {
return <funAlpha.renderModule data={alphaData} />;
case ssvep:
return <funSsvep.renderModule data={ssvepData} />;
case predict:
return <funPredict.renderModule data={predictData} />;
default:
console.log("Error on renderCharts switch.");
}
@@ -256,6 +273,10 @@ export function PageSwitcher() {
return (
funSsvep.renderRecord(recordPopChange, recordPop, status, ssvepSettings, recordTwoPopChange, recordTwoPop)
)
case predict:
return (
funPredict.renderRecord(status)
)
default:
console.log("Error on renderRecord.");
}
@@ -0,0 +1,233 @@
import React from "react";
import { catchError, multicast } from "rxjs/operators";
import { TextContainer, Card, Stack, Button, ButtonGroup } from "@shopify/polaris";
import { Subject } from "rxjs";
import { zipSamples } from "muse-js";
import {
bandpassFilter,
epoch,
fft,
sliceFFT
} from "@neurosity/pipes";
import { chartStyles } from "../chartOptions";
import * as generalTranslations from "../translations/en";
import * as specificTranslations from "./translations/en";
import P5Wrapper from 'react-p5-wrapper';
import sketchPredict from './sketchPredict';
import ml5 from 'ml5'
let knnClassifier = ml5.KNNClassifier();
export function getSettings() {
return {
cutOffLow: 2,
cutOffHigh: 20,
nbChannels: 4,
interval: 256,
bins: 256,
sliceFFTLow: 1,
sliceFFTHigh: 30,
duration: 512,
srate: 256,
name: 'Predict'
}
};
export function buildPipe(Settings) {
if (window.subscriptionPredict) window.subscriptionPredict.unsubscribe();
window.pipePredict$ = null;
window.multicastPredict$ = null;
window.subscriptionPredict = null;
// Build Pipe
window.pipePredict$ = zipSamples(window.source.eegReadings$).pipe(
bandpassFilter({
cutoffFrequencies: [Settings.cutOffLow, Settings.cutOffHigh],
nbChannels: Settings.nbChannels }),
epoch({
duration: Settings.duration,
interval: Settings.interval,
samplingRate: Settings.srate
}),
fft({ bins: Settings.bins }),
sliceFFT([Settings.sliceFFTLow, Settings.sliceFFTHigh]),
catchError(err => {
console.log(err);
})
);
window.multicastPredict$ = window.pipePredict$.pipe(
multicast(() => new Subject())
);
}
export function setup(setData, Settings) {
console.log("Subscribing to " + Settings.name);
if (window.multicastPredict$) {
window.subscriptionPredict = window.multicastPredict$.subscribe(data => {
setData(predictData => {
Object.values(predictData).forEach((channel, index) => {
if (index < 4) {
channel.datasets[0].data = data.psd[index];
channel.xLabels = data.freqs;
}
});
return {
ch0: predictData.ch0,
ch1: predictData.ch1,
ch2: predictData.ch2,
ch3: predictData.ch3
};
});
});
window.multicastPredict$.connect();
console.log("Subscribed to " + Settings.name);
}
}
export function renderModule(channels) {
function renderCharts() {
return Object.values(channels.data).map((channel, index) => {
if (index === 0) {
if (channel.datasets[0].data) {
window.psd = channel.datasets[0].data;
window.freqs = channel.xLabels;
if (channel.xLabels) {
window.bins = channel.xLabels.length;
}
}
return null
} else {
return null
}
});
}
return (
<Card title={specificTranslations.title}>
<Card.Section>
<Stack>
<TextContainer>
<p>{specificTranslations.description}</p>
</TextContainer>
</Stack>
</Card.Section>
<Card.Section>
<div style={chartStyles.wrapperStyle.style}>{renderCharts()}</div>
</Card.Section>
</Card>
);
}
export function renderSliders(setData, setSettings, status, Settings) {
return null
}
// Classification algorithm (using renderRecord function)
window.exampleCounts = {A: 0, B: 0};
window.thisLabel = 'A';
window.confidences = {A: 1, B: 0};
window.isPredicting = false;
window.enoughLabels = false;
export function renderRecord(status) {
const condA = "A";
const condB = "B";
// Adds example from current incoming psd
function addExample (label) {
if (window.psd) {
knnClassifier.addExample(window.psd, label);
window.exampleCounts[label]++;
const numLabels = knnClassifier.getNumLabels();
if (numLabels === 2) {
window.enoughLabels = true;
}
}
}
// Classifies current incoming psd and outputs results
function classify () {
window.isPredicting = true;
knnClassifier.classify(window.psd, gotResults)
}
// callback from classify to assign results to window and recurse
function gotResults(err, result) {
if (result.confidencesByLabel) {
window.confidences = result.confidencesByLabel;
if (result.label) {
window.thisLabel = result.label;
}
}
classify(); //recursive so it continues to run
}
//buttons for training at prediction
return(
<React.Fragment>
<Card title={'Record Training Data'} sectioned>
<Stack>
<ButtonGroup>
<Button
onClick={() => {
addExample('A');
}}
disabled={window.isPredicting || status === generalTranslations.connect}
>
{'Record ' + condA +' Data - Count: ' + window.exampleCounts['A']}
</Button>
<Button
onClick={() => {
addExample('B');
}}
disabled={window.isPredicting || status === generalTranslations.connect}
>
{'Record ' + condB + ' Data - Count: ' + window.exampleCounts['B']}
</Button>
</ButtonGroup>
</Stack>
</Card>
<Card title={'Predict current brain state after Training'} sectioned>
<Stack>
<ButtonGroup>
<Button
onClick={() => {
console.log('Attempting to classify state')
classify();
}}
disabled={!window.enoughLabels || status === generalTranslations.connect}
primary={true}
>
{'Predict State: ' + window.thisLabel + ', Confidence: ' + window.confidences[window.thisLabel].toFixed(2)}
</Button>
</ButtonGroup>
</Stack>
<Card.Section>
<P5Wrapper sketch={sketchPredict}
label={window.thisLabel}
confidences={window.confidences}
/>
</Card.Section>
</Card>
</React.Fragment>
)
}
@@ -0,0 +1,44 @@
export default function sketchPredict (p) {
let label;
let confidence;
p.setup = function () {
p.createCanvas(p.windowWidth*.6, 300);
};
p.windowResized = function() {
p.resizeCanvas(p.windowWidth*.6, 300);
}
p.myCustomRedrawAccordingToNewPropsHandler = function (props) {
label = props.label;
confidence = props.confidences[label];
};
p.draw = function () {
p.background(250, 250, 150);
p.fill(0);
p.strokeWeight(5);
p.line(p.width/2, 0, p.width/2, p.height);
p.textSize(30);
p.text('A', p.width/4, 30);
p.text('B', p.width-p.width/4, 30)
if (label === 'A') {
p.fill(120, 120, 250);
if (confidence > .8) {
p.ellipse(p.width/6, p.height/2, 60);
} else {
p.ellipse(p.width/3, p.height/2, 20);
}
} else {
p.fill(120, 250, 120);
if (confidence > .8) {
p.ellipse(p.width-p.width/6, p.height/2, 60);
} else {
p.ellipse(p.width-p.width/3, p.height/2, 20);
}
}
}
};
@@ -0,0 +1,13 @@
{
"title": "Predict brain states with a trained classifier",
"description": [
"In the next module we will train and test classifiers of brain data like we have been looking at so far. ",
"We will collect data in two different conditions, with the goal of inducing two different brain states. ",
"We will then train a classifier based on the pattern of activity over time, frequency, and space on the head. ",
"We will then use the classifier to predict on real time which of the two brain states are currently happening. ",
"That is, we will attempt to predict if peoples brain activity more closely resembles condition A or conditoin B ",
"Of the training data. "
],
"xlabel": "Frequency (Hz)",
"ylabel": "Power (\u03BCV\u00B2)"
}
@@ -1,13 +1,14 @@
{
"title": "Choose your Module",
"types": {
"intro": "Introduction",
"raw": "Raw and Filtered Data",
"spectra": "Frequency Spectra",
"bands": "Frequency Bands",
"animate": "Brain Controlled Animation",
"spectro": "Spectrogram (spectra over time)",
"alpha": "Eyes open vs. Eyes closed Experiment",
"ssvep": "Steady-State Visual Evoked Potential (SSVEP) Experiment"
"intro": "1. Introduction",
"raw": "2. Raw and Filtered Data",
"spectra": "3. Frequency Spectra",
"bands": "4. Frequency Bands",
"animate": "5. Brain Controlled Animation",
"spectro": "6. Spectrogram (spectra over time)",
"alpha": "7. Eyes open vs. Eyes closed Experiment",
"ssvep": "8. Steady-State Visual Evoked Potential (SSVEP) Experiment",
"predict": "9. Predict brain states with a trained classifier"
}
}
+951 -17
Ver Arquivo
Diferenças do arquivo suprimidas por serem muito extensas Carregar Diff