some change to notebookdocs

Esse commit está contido em:
Robin Tibor Schirrmeister
2018-08-27 14:21:22 +02:00
commit b0f186e3e0
7 arquivos alterados com 417 adições e 81 exclusões
+1 -1
Ver Arquivo
@@ -447,7 +447,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
"version": "3.6.6"
}
},
"nbformat": 4,
+1 -1
Ver Arquivo
@@ -359,7 +359,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
"version": "3.6.6"
}
},
"nbformat": 4,
+1 -1
Ver Arquivo
@@ -346,7 +346,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
"version": "3.6.6"
}
},
"nbformat": 4,
+81 -73
Ver Arquivo
@@ -50,7 +50,9 @@
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import mne\n",
@@ -93,7 +95,9 @@
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
@@ -195,7 +199,9 @@
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from braindecode.datautil.iterators import CropsFromTrialsIterator\n",
@@ -227,7 +233,9 @@
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from braindecode.torch_ext.optimizers import AdamW\n",
@@ -269,105 +277,105 @@
"output_type": "stream",
"text": [
"Epoch 0\n",
"Train Loss: 3.88829\n",
"Train Loss: 3.79010\n",
"Train Accuracy: 50.0%\n",
"Valid Loss: 3.18595\n",
"Valid Accuracy: 50.0%\n",
"Valid Loss: 3.12765\n",
"Valid Accuracy: 46.7%\n",
"Epoch 1\n",
"Train Loss: 1.33431\n",
"Train Accuracy: 55.0%\n",
"Valid Loss: 0.96900\n",
"Valid Accuracy: 60.0%\n",
"Train Loss: 1.91778\n",
"Train Accuracy: 50.0%\n",
"Valid Loss: 1.52058\n",
"Valid Accuracy: 46.7%\n",
"Epoch 2\n",
"Train Loss: 0.60838\n",
"Train Accuracy: 77.5%\n",
"Valid Loss: 0.55685\n",
"Valid Accuracy: 76.7%\n",
"Train Loss: 1.09943\n",
"Train Accuracy: 60.0%\n",
"Valid Loss: 0.99602\n",
"Valid Accuracy: 56.7%\n",
"Epoch 3\n",
"Train Loss: 0.49553\n",
"Train Accuracy: 77.5%\n",
"Valid Loss: 0.57463\n",
"Valid Accuracy: 76.7%\n",
"Train Loss: 0.73015\n",
"Train Accuracy: 67.5%\n",
"Valid Loss: 0.82321\n",
"Valid Accuracy: 56.7%\n",
"Epoch 4\n",
"Train Loss: 0.46749\n",
"Train Accuracy: 80.0%\n",
"Valid Loss: 0.62036\n",
"Valid Accuracy: 76.7%\n",
"Train Loss: 0.55965\n",
"Train Accuracy: 75.0%\n",
"Valid Loss: 0.77549\n",
"Valid Accuracy: 63.3%\n",
"Epoch 5\n",
"Train Loss: 0.36088\n",
"Train Accuracy: 85.0%\n",
"Valid Loss: 0.59054\n",
"Valid Accuracy: 76.7%\n",
"Train Loss: 0.38075\n",
"Train Accuracy: 82.5%\n",
"Valid Loss: 0.65490\n",
"Valid Accuracy: 73.3%\n",
"Epoch 6\n",
"Train Loss: 0.26471\n",
"Train Loss: 0.28886\n",
"Train Accuracy: 90.0%\n",
"Valid Loss: 0.55745\n",
"Valid Accuracy: 76.7%\n",
"Valid Loss: 0.59481\n",
"Valid Accuracy: 73.3%\n",
"Epoch 7\n",
"Train Loss: 0.19023\n",
"Train Accuracy: 95.0%\n",
"Valid Loss: 0.53162\n",
"Train Loss: 0.21871\n",
"Train Accuracy: 97.5%\n",
"Valid Loss: 0.52357\n",
"Valid Accuracy: 83.3%\n",
"Epoch 8\n",
"Train Loss: 0.13684\n",
"Train Loss: 0.16713\n",
"Train Accuracy: 97.5%\n",
"Valid Loss: 0.50510\n",
"Valid Accuracy: 83.3%\n",
"Valid Loss: 0.44890\n",
"Valid Accuracy: 86.7%\n",
"Epoch 9\n",
"Train Loss: 0.10682\n",
"Train Loss: 0.13149\n",
"Train Accuracy: 97.5%\n",
"Valid Loss: 0.47980\n",
"Valid Accuracy: 86.7%\n",
"Valid Loss: 0.40492\n",
"Valid Accuracy: 83.3%\n",
"Epoch 10\n",
"Train Loss: 0.08165\n",
"Train Loss: 0.09581\n",
"Train Accuracy: 100.0%\n",
"Valid Loss: 0.45300\n",
"Valid Accuracy: 86.7%\n",
"Valid Loss: 0.36021\n",
"Valid Accuracy: 90.0%\n",
"Epoch 11\n",
"Train Loss: 0.05927\n",
"Train Loss: 0.07818\n",
"Train Accuracy: 100.0%\n",
"Valid Loss: 0.43593\n",
"Valid Accuracy: 86.7%\n",
"Valid Loss: 0.34625\n",
"Valid Accuracy: 90.0%\n",
"Epoch 12\n",
"Train Loss: 0.04447\n",
"Train Loss: 0.07454\n",
"Train Accuracy: 100.0%\n",
"Valid Loss: 0.43543\n",
"Valid Accuracy: 83.3%\n",
"Valid Loss: 0.34489\n",
"Valid Accuracy: 90.0%\n",
"Epoch 13\n",
"Train Loss: 0.03776\n",
"Train Loss: 0.06694\n",
"Train Accuracy: 100.0%\n",
"Valid Loss: 0.43941\n",
"Valid Accuracy: 83.3%\n",
"Valid Loss: 0.33878\n",
"Valid Accuracy: 90.0%\n",
"Epoch 14\n",
"Train Loss: 0.03341\n",
"Train Loss: 0.05971\n",
"Train Accuracy: 100.0%\n",
"Valid Loss: 0.43653\n",
"Valid Accuracy: 83.3%\n",
"Valid Loss: 0.33128\n",
"Valid Accuracy: 90.0%\n",
"Epoch 15\n",
"Train Loss: 0.02990\n",
"Train Loss: 0.05269\n",
"Train Accuracy: 100.0%\n",
"Valid Loss: 0.43027\n",
"Valid Accuracy: 83.3%\n",
"Valid Loss: 0.32202\n",
"Valid Accuracy: 90.0%\n",
"Epoch 16\n",
"Train Loss: 0.02713\n",
"Train Loss: 0.04354\n",
"Train Accuracy: 100.0%\n",
"Valid Loss: 0.42535\n",
"Valid Accuracy: 83.3%\n",
"Valid Loss: 0.31063\n",
"Valid Accuracy: 90.0%\n",
"Epoch 17\n",
"Train Loss: 0.02483\n",
"Train Loss: 0.03759\n",
"Train Accuracy: 100.0%\n",
"Valid Loss: 0.42535\n",
"Valid Accuracy: 83.3%\n",
"Valid Loss: 0.30314\n",
"Valid Accuracy: 90.0%\n",
"Epoch 18\n",
"Train Loss: 0.02301\n",
"Train Loss: 0.03401\n",
"Train Accuracy: 100.0%\n",
"Valid Loss: 0.42749\n",
"Valid Accuracy: 83.3%\n",
"Valid Loss: 0.29997\n",
"Valid Accuracy: 90.0%\n",
"Epoch 19\n",
"Train Loss: 0.02164\n",
"Train Loss: 0.03145\n",
"Train Accuracy: 100.0%\n",
"Valid Loss: 0.42855\n",
"Valid Accuracy: 86.7%\n"
"Valid Loss: 0.29764\n",
"Valid Accuracy: 90.0%\n"
]
}
],
@@ -381,7 +389,7 @@
"for i_epoch in range(20):\n",
" # Set model to training mode\n",
" model.train()\n",
" for batch_X, batch_y in iterator.get_batches(train_set, shuffle=False):\n",
" for batch_X, batch_y in iterator.get_batches(train_set, shuffle=True):\n",
" net_in = np_to_var(batch_X)\n",
" if cuda:\n",
" net_in = net_in.cuda()\n",
@@ -443,7 +451,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Eventually, we arrive at 86.7% accuracy, so 26 from 30 trials are correctly predicted, 4 more than for the trialwise decoding method."
"Eventually, we arrive at 90.0% accuracy, so 27 from 30 trials are correctly predicted, 5 more than for the trialwise decoding method."
]
},
{
@@ -462,14 +470,14 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test Loss: 0.43439\n",
"Test Loss: 0.42105\n",
"Test Accuracy: 85.0%\n"
]
}
@@ -0,0 +1,318 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"\n",
"import numpy as np\n",
"import matplotlib\n",
"from matplotlib import pyplot as plt\n",
"from matplotlib import cm\n",
"%matplotlib inline\n",
"%config InlineBackend.figure_format = 'png'\n",
"matplotlib.rcParams['figure.figsize'] = (12.0, 4.0)\n",
"matplotlib.rcParams['font.size'] = 7\n",
"\n",
"import matplotlib.lines as mlines\n",
"import seaborn\n",
"seaborn.set_style('darkgrid')\n",
"import logging\n",
"import importlib\n",
"importlib.reload(logging) # see https://stackoverflow.com/a/21475297/1469195\n",
"log = logging.getLogger()\n",
"log.setLevel('DEBUG')\n",
"import sys\n",
"logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',\n",
" level=logging.DEBUG, stream=sys.stdout)\n",
"seaborn.set_palette('colorblind')\n",
"import os\n",
"# add the repo itself\n",
"os.sys.path.insert(0, '/home/schirrmr/code/explaining/reversible//')\n",
"os.sys.path.insert(0, '/home/schirrmr/braindecode/code/braindecode/')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from braindecode.datasets.bbci import BBCIDataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"cnt = BBCIDataset('/data/schirrmr/schirrmr/HGD-public/reduced/train/8.mat', load_sensor_names=['C3', 'C4', 'Cz']).load()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In trialwise decoding, you forward trials through your network and get one prediction per trial. Then you train the network using the loss between the predictions and the labels for the trials.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(5,5))\n",
"data = cnt.get_data()[0,100:201]\n",
"plt.plot(data - np.mean(data))\n",
"ax = plt.gca()\n",
"plt.plot([0,0],[-20,20], color='black')\n",
"plt.plot([100,100],[-20,20], color='black')\n",
"points = [[0, -20], [50, -80], [100, -20]]\n",
"polygon = plt.Polygon(points, facecolor=seaborn.color_palette()[2] + (0.2,), edgecolor=seaborn.color_palette()[2])\n",
"ax.add_artist(polygon)\n",
"plt.ylim(-100,22)\n",
"plt.text(50,30, \"Trial\", fontsize=20, ha='center')\n",
"plt.text(50,-40, \"Network\", fontsize=20, ha='center')\n",
"plt.plot(50,-80, ls='', marker='o', color=seaborn.color_palette()[2])\n",
"plt.text(45,-80, \"Prediction\", fontsize=20, ha='right', va='center')\n",
"plt.plot(50,-95, ls='', marker='o')\n",
"plt.text(45,-95, \"Target\", fontsize=20, ha='right', va='center')\n",
"plt.axis('off')\n",
"\n",
"plt.plot([50,65],[-80, -87.5], color='black')\n",
"plt.plot([50,65],[-95, -87.5], color='black')\n",
"plt.text(67,-87.5, \"Loss\", fontsize=20, ha='left', va='center')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(5,5))\n",
"data = cnt.get_data()[0,100:201]\n",
"plt.plot(data - np.mean(data))\n",
"ax = plt.gca()\n",
"plt.plot([0,0],[-20,20], color='black')\n",
"plt.plot([100,100],[-20,20], color='black')\n",
"points = [[0, -20], [25, -80], [50, -20]]\n",
"polygon = plt.Polygon(points, facecolor=seaborn.color_palette()[2] + (0.2,), edgecolor=seaborn.color_palette()[2])\n",
"ax.add_artist(polygon)\n",
"points = [[3, -20], [28, -80], [53, -20]]\n",
"polygon = plt.Polygon(points, facecolor=seaborn.color_palette()[2] + (0.2,), edgecolor=seaborn.color_palette()[2])\n",
"ax.add_artist(polygon)\n",
"points = [[6, -20], [31, -80], [56, -20]]\n",
"polygon = plt.Polygon(points, facecolor=seaborn.color_palette()[2] + (0.2,), edgecolor=seaborn.color_palette()[2])\n",
"ax.add_artist(polygon)\n",
"plt.plot(25,-80, ls='', marker='o', color=seaborn.color_palette()[2])\n",
"plt.plot(28,-80, ls='', marker='o', color=seaborn.color_palette()[2])\n",
"plt.plot(31,-80, ls='', marker='o', color=seaborn.color_palette()[2])\n",
"plt.ylim(-100,22)\n",
"plt.text(50,35, \"Trial\", fontsize=20, ha='center')\n",
"plt.text(28,-38, \"Network\", fontsize=16, ha='center')\n",
"plt.text(20,-80, \"Predictions\", fontsize=20, ha='right', va='center')\n",
"plt.plot(28,-95, ls='', marker='o')\n",
"plt.text(20,-95, \"Target\", fontsize=20, ha='right', va='center')\n",
"\n",
"# from https://stackoverflow.com/a/20308475/1469195\n",
"def range_brace(x_min, x_max, mid=0.75, \n",
" beta1=50.0, beta2=100.0, height=1, \n",
" initial_divisions=11, resolution_factor=1.5):\n",
" NP = np\n",
" # determine x0 adaptively values using second derivitive\n",
" # could be replaced with less snazzy:\n",
" # x0 = NP.arange(0, 0.5, .001)\n",
" x0 = NP.array(())\n",
" tmpx = NP.linspace(0, 0.5, initial_divisions)\n",
" tmp = beta1**2 * (NP.exp(beta1*tmpx)) * (1-NP.exp(beta1*tmpx)) / NP.power((1+NP.exp(beta1*tmpx)),3)\n",
" tmp += beta2**2 * (NP.exp(beta2*(tmpx-0.5))) * (1-NP.exp(beta2*(tmpx-0.5))) / NP.power((1+NP.exp(beta2*(tmpx-0.5))),3)\n",
" for i in range(0, len(tmpx)-1):\n",
" t = int(NP.ceil(resolution_factor*max(NP.abs(tmp[i:i+2]))/float(initial_divisions)))\n",
" x0 = NP.append(x0, NP.linspace(tmpx[i],tmpx[i+1],t))\n",
" x0 = NP.sort(NP.unique(x0)) # sort and remove dups\n",
" # half brace using sum of two logistic functions\n",
" y0 = mid*2*((1/(1.+NP.exp(-1*beta1*x0)))-0.5)\n",
" y0 += (1-mid)*2*(1/(1.+NP.exp(-1*beta2*(x0-0.5))))\n",
" # concat and scale x\n",
" x = NP.concatenate((x0, 1-x0[::-1])) * float((x_max-x_min)) + x_min\n",
" y = NP.concatenate((y0, y0[::-1])) * float(height)\n",
" return (x,y)\n",
"\n",
"\n",
"x,y = range_brace(0, 50, height=6)\n",
"ax.plot(x, y-20,'-')\n",
"\n",
"ax.annotate('Crop', xy=(25, -10), xycoords='data', \n",
" fontsize=13, ha='center', va='bottom', color=seaborn.color_palette()[2],\n",
" bbox=dict(boxstyle='square', fc='white', ec='None',))\n",
"\n",
"x,y = range_brace(0, 56, height=6, mid=0.5, beta1=50)\n",
"ax.plot(x, y+10,'-', color=seaborn.color_palette()[2])\n",
"plt.plot([56,56], [-20,10], ls='--', color=seaborn.color_palette()[2])\n",
"ax.annotate('Supercrop', xy=(28, 20), xycoords='data', \n",
" fontsize=13, ha='center', va='bottom', color=seaborn.color_palette()[2],\n",
" bbox=dict(boxstyle='square', fc='white', ec='None',))\n",
"x,y = range_brace(25, 31, height=3.5, mid=0.5, beta1=50)\n",
"ax.plot(x, -y-82,'-', color=seaborn.color_palette()[2])\n",
"plt.plot(28,-88, ls='', marker='o', color=seaborn.color_palette()[2])\n",
"\n",
"plt.plot([28,40],[-88, -91.5], color='black')\n",
"plt.plot([28,40],[-95, -91.5], color='black')\n",
"plt.text(42,-91.5, \"Loss\", fontsize=20, ha='left', va='center')\n",
"plt.axis('off')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"\n",
"import matplotlib.transforms as mtrans\n",
"from matplotlib.text import TextPath\n",
"from matplotlib.patches import PathPatch\n",
"def curly(x,y, size, scale, rotation, ax=None):\n",
" tp = TextPath((0, 0), \"}\", size=size,)\n",
" #trans = mtrans.Affine2D().scale(1, scale) + mtrans.Affine2D().rotate_deg(rotation) + \\\n",
" # mtrans.Affine2D().translate(x,y) + ax.transData\n",
" trans = mtrans.Affine2D().rotate_deg(rotation) + mtrans.Affine2D().translate(x,y) + ax.transData\n",
" pp = PathPatch(tp, lw=0, fc=\"k\", transform=trans)\n",
" ax.add_artist(pp)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(5,5))\n",
"data = cnt.get_data()[0,100:201]\n",
"plt.plot(data - np.mean(data))\n",
"ax = plt.gca()\n",
"plt.plot([0,0],[-20,20], color='black')\n",
"plt.plot([100,100],[-20,20], color='black')\n",
"points = [[0, -20], [25, -80], [50, -20]]\n",
"polygon = plt.Polygon(points, facecolor=seaborn.color_palette()[2] + (0.2,), edgecolor=seaborn.color_palette()[2])\n",
"ax.add_artist(polygon)\n",
"points = [[3, -20], [28, -80], [53, -20]]\n",
"polygon = plt.Polygon(points, facecolor=seaborn.color_palette()[2] + (0.2,), edgecolor=seaborn.color_palette()[2])\n",
"ax.add_artist(polygon)\n",
"points = [[6, -20], [31, -80], [56, -20]]\n",
"polygon = plt.Polygon(points, facecolor=seaborn.color_palette()[2] + (0.2,), edgecolor=seaborn.color_palette()[2])\n",
"ax.add_artist(polygon)\n",
"plt.plot(25,-80, ls='', marker='o', color=seaborn.color_palette()[2])\n",
"plt.plot(28,-80, ls='', marker='o', color=seaborn.color_palette()[2])\n",
"plt.plot(31,-80, ls='', marker='o', color=seaborn.color_palette()[2])\n",
"plt.ylim(-100,22)\n",
"plt.text(50,30, \"Trial\", fontsize=20, ha='center')\n",
"plt.text(20,-80, \"Predictions\", fontsize=20, ha='right', va='center')\n",
"plt.plot(25,-95, ls='', marker='o')\n",
"plt.text(20,-95, \"Target\", fontsize=20, ha='right', va='center')\n",
"\n",
"\n",
"# Here is the label and arrow code of interest\n",
"\n",
"ax.annotate('SDL', xy=(0.5, 0.90), xytext=(0.5, 1.00), xycoords='data', \n",
" fontsize=13, ha='center', va='bottom',\n",
" bbox=dict(boxstyle='square', fc='white'),\n",
" arrowprops=dict(arrowstyle='-[, widthB=7.0, lengthB=1.5', lw=2.0))\n",
"\"\"\"\n",
"plt.annotate(r\"$\\}$\",fontsize=46,\n",
" xy=(0.27, 0.77), xycoords='figure fraction',\n",
" rotation=90)\"\"\"\n",
"\n",
"import matplotlib.transforms as mtrans\n",
"from matplotlib.text import TextPath\n",
"from matplotlib.patches import PathPatch\n",
"def curly(x,y, size, scale, rotation, ax=None):\n",
" tp = TextPath((0, 0), \"}\", size=size)\n",
" #trans = mtrans.Affine2D().scale(1, scale) + mtrans.Affine2D().rotate_deg(rotation) + \\\n",
" # mtrans.Affine2D().translate(x,y) + ax.transData\n",
" trans = ax.transData + mtrans.Affine2D().translate(x,y) \n",
" pp = PathPatch(tp, lw=0, fc=\"k\", transform=trans)\n",
" ax.add_artist(pp)\n",
"\n",
"\n",
"curly(10,-20,20, 0.5, 90, ax)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"plt.text(45,-80, \"Prediction\", fontsize=20, ha='right', va='center')\n",
"plt.plot(50,-95, ls='', marker='o')\n",
"plt.text(45,-95, \"Target\", fontsize=20, ha='right', va='center')\n",
"plt.axis('off')\n",
"\n",
"plt.plot([50,65],[-80, -87.5], color='black')\n",
"plt.plot([50,65],[-95, -87.5], color='black')\n",
"plt.text(67,-87.5, \"Loss\", fontsize=20, ha='left', va='center')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"plt.plot(cnt.get_data()[0,100:201])\n",
"plt.axvline(x=0, color='black')\n",
"plt.axvline(x=100, color='black')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@@ -92,7 +92,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import mne\n",
@@ -226,7 +228,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from braindecode.torch_ext.optimizers import AdamW\n",
+9 -3
Ver Arquivo
@@ -174,7 +174,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from braindecode.torch_ext.optimizers import AdamW\n",
@@ -219,7 +221,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from braindecode.datautil.iterators import CropsFromTrialsIterator\n",
@@ -303,7 +307,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"\n",