Merge pull request #102 from kylemath/ml5
[ready to deploy] - Working KNN classification, two training classes, live prediction of state and animation of prediction
Esse commit está contido em:
@@ -13,6 +13,7 @@
|
||||
"file-saver": "^2.0.2",
|
||||
"firebase": "^7.5.0",
|
||||
"firebase-tools": "^7.9.0",
|
||||
"ml5": "^0.4.3",
|
||||
"handlebars": "^4.3.0",
|
||||
"muse-js": "^3.0.1",
|
||||
"p5": "^0.10.2",
|
||||
|
||||
@@ -15,6 +15,7 @@ import * as funAnimate from "./components/EEGEduAnimate/EEGEduAnimate";
|
||||
import * as funSpectro from "./components/EEGEduSpectro/EEGEduSpectro";
|
||||
import * as funAlpha from "./components/EEGEduAlpha/EEGEduAlpha"
|
||||
import * as funSsvep from "./components/EEGEduSsvep/EEGEduSsvep"
|
||||
import * as funPredict from "./components/EEGEduPredict/EEGEduPredict"
|
||||
|
||||
const intro = translations.types.intro;
|
||||
const raw = translations.types.raw;
|
||||
@@ -24,6 +25,7 @@ const animate = translations.types.animate;
|
||||
const spectro = translations.types.spectro;
|
||||
const alpha = translations.types.alpha;
|
||||
const ssvep = translations.types.ssvep;
|
||||
const predict = translations.types.predict;
|
||||
|
||||
export function PageSwitcher() {
|
||||
|
||||
@@ -36,6 +38,7 @@ export function PageSwitcher() {
|
||||
const [spectroData, setSpectroData] = useState(emptyChannelData);
|
||||
const [alphaData, setAlphaData] = useState(emptyChannelData);
|
||||
const [ssvepData, setSsvepData] = useState(emptyChannelData);
|
||||
const [predictData, setPredictData] = useState(emptyChannelData);
|
||||
|
||||
// pipe settings
|
||||
const [introSettings] = useState(funIntro.getSettings);
|
||||
@@ -46,6 +49,7 @@ export function PageSwitcher() {
|
||||
const [spectroSettings, setSpectroSettings] = useState(funSpectro.getSettings);
|
||||
const [alphaSettings, setAlphaSettings] = useState(funAlpha.getSettings);
|
||||
const [ssvepSettings, setSsvepSettings] = useState(funSsvep.getSettings);
|
||||
const [predictSettings, setPredictSettings] = useState(funPredict.getSettings);
|
||||
|
||||
// connection status
|
||||
const [status, setStatus] = useState(generalTranslations.connect);
|
||||
@@ -65,6 +69,7 @@ export function PageSwitcher() {
|
||||
if (window.subscriptionSpectro) window.subscriptionSpectro.unsubscribe();
|
||||
if (window.subscriptionAlpha) window.subscriptionAlpha.unsubscribe();
|
||||
if (window.subscriptionSsvep) window.subscriptionSsvep.unsubscribe();
|
||||
if (window.subscriptionPredict) window.subscriptionPredict.unsubscribe();
|
||||
|
||||
subscriptionSetup(value);
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
@@ -86,7 +91,9 @@ export function PageSwitcher() {
|
||||
{ label: animate, value: animate },
|
||||
{ label: spectro, value: spectro },
|
||||
{ label: alpha, value: alpha },
|
||||
{ label: ssvep, value: ssvep }
|
||||
{ label: ssvep, value: ssvep },
|
||||
{ label: predict, value: predict }
|
||||
|
||||
];
|
||||
|
||||
function buildPipes(value) {
|
||||
@@ -98,6 +105,7 @@ export function PageSwitcher() {
|
||||
funSpectro.buildPipe(spectroSettings);
|
||||
funAlpha.buildPipe(alphaSettings);
|
||||
funSsvep.buildPipe(ssvepSettings);
|
||||
funPredict.buildPipe(predictSettings);
|
||||
}
|
||||
|
||||
function subscriptionSetup(value) {
|
||||
@@ -126,6 +134,9 @@ export function PageSwitcher() {
|
||||
case ssvep:
|
||||
funSsvep.setup(setSsvepData, ssvepSettings);
|
||||
break;
|
||||
case predict:
|
||||
funPredict.setup(setPredictData, predictSettings);
|
||||
break;
|
||||
default:
|
||||
console.log(
|
||||
"Error on handle Subscriptions. Couldn't switch to: " + value
|
||||
@@ -201,6 +212,10 @@ export function PageSwitcher() {
|
||||
return (
|
||||
funSsvep.renderSliders(setSsvepData, setSsvepSettings, status, ssvepSettings)
|
||||
);
|
||||
case predict:
|
||||
return (
|
||||
funPredict.renderSliders(setPredictData, setPredictSettings, status, predictSettings)
|
||||
);
|
||||
default: console.log('Error rendering settings display');
|
||||
}
|
||||
}
|
||||
@@ -223,6 +238,8 @@ export function PageSwitcher() {
|
||||
return <funAlpha.renderModule data={alphaData} />;
|
||||
case ssvep:
|
||||
return <funSsvep.renderModule data={ssvepData} />;
|
||||
case predict:
|
||||
return <funPredict.renderModule data={predictData} />;
|
||||
default:
|
||||
console.log("Error on renderCharts switch.");
|
||||
}
|
||||
@@ -256,6 +273,10 @@ export function PageSwitcher() {
|
||||
return (
|
||||
funSsvep.renderRecord(recordPopChange, recordPop, status, ssvepSettings, recordTwoPopChange, recordTwoPop)
|
||||
)
|
||||
case predict:
|
||||
return (
|
||||
funPredict.renderRecord(status)
|
||||
)
|
||||
default:
|
||||
console.log("Error on renderRecord.");
|
||||
}
|
||||
|
||||
@@ -0,0 +1,233 @@
|
||||
import React from "react";
|
||||
import { catchError, multicast } from "rxjs/operators";
|
||||
|
||||
import { TextContainer, Card, Stack, Button, ButtonGroup } from "@shopify/polaris";
|
||||
import { Subject } from "rxjs";
|
||||
|
||||
import { zipSamples } from "muse-js";
|
||||
|
||||
import {
|
||||
bandpassFilter,
|
||||
epoch,
|
||||
fft,
|
||||
sliceFFT
|
||||
} from "@neurosity/pipes";
|
||||
|
||||
import { chartStyles } from "../chartOptions";
|
||||
|
||||
import * as generalTranslations from "../translations/en";
|
||||
import * as specificTranslations from "./translations/en";
|
||||
|
||||
import P5Wrapper from 'react-p5-wrapper';
|
||||
import sketchPredict from './sketchPredict';
|
||||
|
||||
import ml5 from 'ml5'
|
||||
|
||||
let knnClassifier = ml5.KNNClassifier();
|
||||
|
||||
|
||||
export function getSettings() {
|
||||
return {
|
||||
cutOffLow: 2,
|
||||
cutOffHigh: 20,
|
||||
nbChannels: 4,
|
||||
interval: 256,
|
||||
bins: 256,
|
||||
sliceFFTLow: 1,
|
||||
sliceFFTHigh: 30,
|
||||
duration: 512,
|
||||
srate: 256,
|
||||
name: 'Predict'
|
||||
}
|
||||
};
|
||||
|
||||
export function buildPipe(Settings) {
|
||||
if (window.subscriptionPredict) window.subscriptionPredict.unsubscribe();
|
||||
|
||||
window.pipePredict$ = null;
|
||||
window.multicastPredict$ = null;
|
||||
window.subscriptionPredict = null;
|
||||
|
||||
// Build Pipe
|
||||
window.pipePredict$ = zipSamples(window.source.eegReadings$).pipe(
|
||||
bandpassFilter({
|
||||
cutoffFrequencies: [Settings.cutOffLow, Settings.cutOffHigh],
|
||||
nbChannels: Settings.nbChannels }),
|
||||
epoch({
|
||||
duration: Settings.duration,
|
||||
interval: Settings.interval,
|
||||
samplingRate: Settings.srate
|
||||
}),
|
||||
fft({ bins: Settings.bins }),
|
||||
sliceFFT([Settings.sliceFFTLow, Settings.sliceFFTHigh]),
|
||||
catchError(err => {
|
||||
console.log(err);
|
||||
})
|
||||
);
|
||||
|
||||
window.multicastPredict$ = window.pipePredict$.pipe(
|
||||
multicast(() => new Subject())
|
||||
);
|
||||
}
|
||||
|
||||
export function setup(setData, Settings) {
|
||||
console.log("Subscribing to " + Settings.name);
|
||||
|
||||
if (window.multicastPredict$) {
|
||||
window.subscriptionPredict = window.multicastPredict$.subscribe(data => {
|
||||
setData(predictData => {
|
||||
Object.values(predictData).forEach((channel, index) => {
|
||||
if (index < 4) {
|
||||
channel.datasets[0].data = data.psd[index];
|
||||
channel.xLabels = data.freqs;
|
||||
}
|
||||
});
|
||||
|
||||
return {
|
||||
ch0: predictData.ch0,
|
||||
ch1: predictData.ch1,
|
||||
ch2: predictData.ch2,
|
||||
ch3: predictData.ch3
|
||||
};
|
||||
});
|
||||
});
|
||||
|
||||
window.multicastPredict$.connect();
|
||||
console.log("Subscribed to " + Settings.name);
|
||||
}
|
||||
}
|
||||
|
||||
export function renderModule(channels) {
|
||||
function renderCharts() {
|
||||
return Object.values(channels.data).map((channel, index) => {
|
||||
if (index === 0) {
|
||||
|
||||
if (channel.datasets[0].data) {
|
||||
window.psd = channel.datasets[0].data;
|
||||
window.freqs = channel.xLabels;
|
||||
if (channel.xLabels) {
|
||||
window.bins = channel.xLabels.length;
|
||||
}
|
||||
}
|
||||
return null
|
||||
} else {
|
||||
return null
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return (
|
||||
<Card title={specificTranslations.title}>
|
||||
<Card.Section>
|
||||
<Stack>
|
||||
<TextContainer>
|
||||
<p>{specificTranslations.description}</p>
|
||||
</TextContainer>
|
||||
</Stack>
|
||||
</Card.Section>
|
||||
<Card.Section>
|
||||
<div style={chartStyles.wrapperStyle.style}>{renderCharts()}</div>
|
||||
</Card.Section>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
|
||||
export function renderSliders(setData, setSettings, status, Settings) {
|
||||
return null
|
||||
}
|
||||
|
||||
// Classification algorithm (using renderRecord function)
|
||||
window.exampleCounts = {A: 0, B: 0};
|
||||
window.thisLabel = 'A';
|
||||
window.confidences = {A: 1, B: 0};
|
||||
|
||||
window.isPredicting = false;
|
||||
window.enoughLabels = false;
|
||||
|
||||
export function renderRecord(status) {
|
||||
const condA = "A";
|
||||
const condB = "B";
|
||||
|
||||
// Adds example from current incoming psd
|
||||
function addExample (label) {
|
||||
if (window.psd) {
|
||||
knnClassifier.addExample(window.psd, label);
|
||||
window.exampleCounts[label]++;
|
||||
|
||||
const numLabels = knnClassifier.getNumLabels();
|
||||
if (numLabels === 2) {
|
||||
window.enoughLabels = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Classifies current incoming psd and outputs results
|
||||
function classify () {
|
||||
window.isPredicting = true;
|
||||
knnClassifier.classify(window.psd, gotResults)
|
||||
}
|
||||
|
||||
// callback from classify to assign results to window and recurse
|
||||
function gotResults(err, result) {
|
||||
if (result.confidencesByLabel) {
|
||||
window.confidences = result.confidencesByLabel;
|
||||
if (result.label) {
|
||||
window.thisLabel = result.label;
|
||||
}
|
||||
}
|
||||
classify(); //recursive so it continues to run
|
||||
}
|
||||
|
||||
//buttons for training at prediction
|
||||
return(
|
||||
<React.Fragment>
|
||||
<Card title={'Record Training Data'} sectioned>
|
||||
<Stack>
|
||||
<ButtonGroup>
|
||||
<Button
|
||||
onClick={() => {
|
||||
addExample('A');
|
||||
}}
|
||||
disabled={window.isPredicting || status === generalTranslations.connect}
|
||||
>
|
||||
{'Record ' + condA +' Data - Count: ' + window.exampleCounts['A']}
|
||||
</Button>
|
||||
<Button
|
||||
onClick={() => {
|
||||
addExample('B');
|
||||
}}
|
||||
disabled={window.isPredicting || status === generalTranslations.connect}
|
||||
>
|
||||
{'Record ' + condB + ' Data - Count: ' + window.exampleCounts['B']}
|
||||
</Button>
|
||||
</ButtonGroup>
|
||||
</Stack>
|
||||
</Card>
|
||||
|
||||
<Card title={'Predict current brain state after Training'} sectioned>
|
||||
<Stack>
|
||||
<ButtonGroup>
|
||||
<Button
|
||||
onClick={() => {
|
||||
console.log('Attempting to classify state')
|
||||
classify();
|
||||
}}
|
||||
disabled={!window.enoughLabels || status === generalTranslations.connect}
|
||||
primary={true}
|
||||
>
|
||||
{'Predict State: ' + window.thisLabel + ', Confidence: ' + window.confidences[window.thisLabel].toFixed(2)}
|
||||
</Button>
|
||||
</ButtonGroup>
|
||||
</Stack>
|
||||
<Card.Section>
|
||||
<P5Wrapper sketch={sketchPredict}
|
||||
label={window.thisLabel}
|
||||
confidences={window.confidences}
|
||||
|
||||
/>
|
||||
</Card.Section>
|
||||
</Card>
|
||||
|
||||
</React.Fragment>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
export default function sketchPredict (p) {
|
||||
|
||||
let label;
|
||||
let confidence;
|
||||
|
||||
p.setup = function () {
|
||||
p.createCanvas(p.windowWidth*.6, 300);
|
||||
|
||||
};
|
||||
|
||||
p.windowResized = function() {
|
||||
p.resizeCanvas(p.windowWidth*.6, 300);
|
||||
}
|
||||
|
||||
p.myCustomRedrawAccordingToNewPropsHandler = function (props) {
|
||||
label = props.label;
|
||||
confidence = props.confidences[label];
|
||||
};
|
||||
|
||||
p.draw = function () {
|
||||
p.background(250, 250, 150);
|
||||
p.fill(0);
|
||||
p.strokeWeight(5);
|
||||
p.line(p.width/2, 0, p.width/2, p.height);
|
||||
p.textSize(30);
|
||||
p.text('A', p.width/4, 30);
|
||||
p.text('B', p.width-p.width/4, 30)
|
||||
if (label === 'A') {
|
||||
p.fill(120, 120, 250);
|
||||
if (confidence > .8) {
|
||||
p.ellipse(p.width/6, p.height/2, 60);
|
||||
} else {
|
||||
p.ellipse(p.width/3, p.height/2, 20);
|
||||
}
|
||||
} else {
|
||||
p.fill(120, 250, 120);
|
||||
if (confidence > .8) {
|
||||
p.ellipse(p.width-p.width/6, p.height/2, 60);
|
||||
} else {
|
||||
p.ellipse(p.width-p.width/3, p.height/2, 20);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"title": "Predict brain states with a trained classifier",
|
||||
"description": [
|
||||
"In the next module we will train and test classifiers of brain data like we have been looking at so far. ",
|
||||
"We will collect data in two different conditions, with the goal of inducing two different brain states. ",
|
||||
"We will then train a classifier based on the pattern of activity over time, frequency, and space on the head. ",
|
||||
"We will then use the classifier to predict on real time which of the two brain states are currently happening. ",
|
||||
"That is, we will attempt to predict if peoples brain activity more closely resembles condition A or conditoin B ",
|
||||
"Of the training data. "
|
||||
],
|
||||
"xlabel": "Frequency (Hz)",
|
||||
"ylabel": "Power (\u03BCV\u00B2)"
|
||||
}
|
||||
@@ -1,13 +1,14 @@
|
||||
{
|
||||
"title": "Choose your Module",
|
||||
"types": {
|
||||
"intro": "Introduction",
|
||||
"raw": "Raw and Filtered Data",
|
||||
"spectra": "Frequency Spectra",
|
||||
"bands": "Frequency Bands",
|
||||
"animate": "Brain Controlled Animation",
|
||||
"spectro": "Spectrogram (spectra over time)",
|
||||
"alpha": "Eyes open vs. Eyes closed Experiment",
|
||||
"ssvep": "Steady-State Visual Evoked Potential (SSVEP) Experiment"
|
||||
"intro": "1. Introduction",
|
||||
"raw": "2. Raw and Filtered Data",
|
||||
"spectra": "3. Frequency Spectra",
|
||||
"bands": "4. Frequency Bands",
|
||||
"animate": "5. Brain Controlled Animation",
|
||||
"spectro": "6. Spectrogram (spectra over time)",
|
||||
"alpha": "7. Eyes open vs. Eyes closed Experiment",
|
||||
"ssvep": "8. Steady-State Visual Evoked Potential (SSVEP) Experiment",
|
||||
"predict": "9. Predict brain states with a trained classifier"
|
||||
}
|
||||
}
|
||||
|
||||
+951
-17
Diferenças do arquivo suprimidas por serem muito extensas
Carregar Diff
Referência em uma Nova Issue
Bloquear um usuário