Add all code
Esse commit está contido em:
+56
@@ -0,0 +1,56 @@
|
|||||||
|
import numpy as np
|
||||||
|
import h5py
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import torch.utils.data as data
|
||||||
|
|
||||||
|
DATASET = '/home/liusheng/Dataset/EEG/'
|
||||||
|
|
||||||
|
|
||||||
|
class EEGImage(data.Dataset):
|
||||||
|
"""DEAP EEG dataset"""
|
||||||
|
VALID_SPLIT = ('train', 'val', 'test')
|
||||||
|
|
||||||
|
def __init__(self, data_dir, split='train'):
|
||||||
|
super(EEGImage, self).__init__()
|
||||||
|
|
||||||
|
if split not in self.VALID_SPLIT:
|
||||||
|
raise ValueError('Unknown split {:s}'.format(split))
|
||||||
|
if not os.path.exists(data_dir):
|
||||||
|
raise ValueError('{:s} does not exist'.format(data_dir))
|
||||||
|
|
||||||
|
self.split = split
|
||||||
|
self.data = h5py.File(data_dir, 'r')
|
||||||
|
self.key = list(self.data.keys())
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
key = self.key[index]
|
||||||
|
label = int(self.data[key]['label'][...])
|
||||||
|
video = self.data[key]['video'][...]
|
||||||
|
# (D, H, W, C) -> (C, D, H, W): (3, 63, 32, 32)
|
||||||
|
video = np.transpose(video, (3, 0, 1, 2))
|
||||||
|
return label, video
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.key)
|
||||||
|
|
||||||
|
def collate_fn(self, batch):
|
||||||
|
label = [b[0] for b in batch]
|
||||||
|
# (C, D, H, W): (3, 63, 32, 32)
|
||||||
|
video = [b[1] / 255.0 for b in batch]
|
||||||
|
video = [torch.FloatTensor(v).unsqueeze(0) for v in video]
|
||||||
|
return torch.LongTensor(label), torch.cat(video, dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
dataset = EEGImage("/home/liusheng/Dataset/EEG/train.h5", 'train')
|
||||||
|
print len(dataset)
|
||||||
|
label, video = dataset.__getitem__(0)
|
||||||
|
print label
|
||||||
|
print video.shape
|
||||||
|
|
||||||
|
loader = data.DataLoader(dataset, batch_size=2, shuffle=False, num_workers=1, collate_fn=dataset.collate_fn)
|
||||||
|
sample = next(iter(loader))
|
||||||
|
print sample[0]
|
||||||
|
print sample[1]
|
||||||
|
|
||||||
@@ -0,0 +1,71 @@
|
|||||||
|
from __future__ import absolute_import
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class EEGNet(nn.Module):
|
||||||
|
"""2D convolutional neural network for single EEG frame."""
|
||||||
|
VALIDENDPOINT = ('logit', 'predict')
|
||||||
|
|
||||||
|
def __init__(self, num_class,
|
||||||
|
input_channel,
|
||||||
|
hidden_size,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
avgpool_size=4,
|
||||||
|
dropout=0.1):
|
||||||
|
super(EEGNet, self).__init__()
|
||||||
|
|
||||||
|
assert len(kernel_size) == len(hidden_size)
|
||||||
|
assert len(kernel_size) == len(stride)
|
||||||
|
self.num_layer = len(kernel_size)
|
||||||
|
self.num_class = num_class
|
||||||
|
self.input_channel = input_channel
|
||||||
|
self.dropout = dropout
|
||||||
|
|
||||||
|
in_channel = self.input_channel
|
||||||
|
layer = 1
|
||||||
|
self.projections = nn.ModuleList()
|
||||||
|
self.residualnorms = nn.ModuleList()
|
||||||
|
self.convolutions = nn.ModuleList()
|
||||||
|
self.batchnorms = nn.ModuleList()
|
||||||
|
for (out_channel, kernel_width, s) in zip(hidden_size, kernel_size, stride):
|
||||||
|
pad = (kernel_size - 1) // 2
|
||||||
|
self.projections.append(nn.Conv1d(in_channel, out_channel, kernel_size=1, stride=s, bias=False)
|
||||||
|
if in_channel != out_channel or s != 1 else None)
|
||||||
|
self.residualnorms.append(nn.BatchNorm2d(out_channel)
|
||||||
|
if in_channel != out_channel or s != 1 else None)
|
||||||
|
self.convolutions.append(nn.Conv2d(in_channel, out_channel,
|
||||||
|
kernel_size=kernel_width, stride=s, padding=pad, bias=False))
|
||||||
|
self.batchnorms.append(nn.BatchNorm2d(out_channel, eps=1e-5, affine=True))
|
||||||
|
in_channel = out_channel
|
||||||
|
# avgpool_size should equal to size of the feature map,
|
||||||
|
# otherwise self.predict will break.
|
||||||
|
self.avgpool = nn.AvgPool2d(kernel_size=avgpool_size)
|
||||||
|
self.predict = nn.Linear(hidden_size[-1], self.num_class, bias=True)
|
||||||
|
|
||||||
|
def forward(self, x, endpoint='predict'):
|
||||||
|
if endpoint not in self.VALIDENDPOINT:
|
||||||
|
raise ValueError('Unknown endpoint {:s}'.format(endpoint))
|
||||||
|
|
||||||
|
for proj, rbn, conv, bn in zip(self.projections, self.residualnorms,
|
||||||
|
self.convolutions, self.batchnorms):
|
||||||
|
if proj is not None:
|
||||||
|
residual = proj(x)
|
||||||
|
residual = rbn(residual)
|
||||||
|
else:
|
||||||
|
residual = x
|
||||||
|
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||||
|
x = conv(x)
|
||||||
|
x = bn(x)
|
||||||
|
x = (x + residual)
|
||||||
|
x = F.relu(x)
|
||||||
|
x = self.avgpool(x)
|
||||||
|
if endpoint == 'logit':
|
||||||
|
return x
|
||||||
|
x = self.predict(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
from __future__ import absolute_import
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
import model.eegnet
|
||||||
|
|
||||||
|
|
||||||
|
class EEGLSTM(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, num_class,
|
||||||
|
num_layer=1,
|
||||||
|
input_channel=128,
|
||||||
|
hidden_size=128,
|
||||||
|
dropout=0.3):
|
||||||
|
super(EEGLSTM, self).__init__()
|
||||||
|
self.num_class = num_class
|
||||||
|
self.num_layer = num_layer
|
||||||
|
self.input_channel = input_channel
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.dropout = dropout
|
||||||
|
self.lstm = nn.LSTM(input_channel, hidden_size, num_layer,
|
||||||
|
batch_first=True, bidirectional=False)
|
||||||
|
self.predict = nn.Linear(hidden_size, num_class)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x, (h, c) = self.lstm(x)
|
||||||
|
x = self.predict(x)
|
||||||
|
return x
|
||||||
+28
@@ -0,0 +1,28 @@
|
|||||||
|
from __future__ import absolute_import
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
import data
|
||||||
|
from model import eegnet, lstm
|
||||||
|
|
||||||
|
|
||||||
|
def train(params):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--learning-rate', metavar='lr', type=float, default=1e-3,
|
||||||
|
help='learning rate')
|
||||||
|
parser.add_argument('--weight-decay', metavar='wd', type=float, default=1e-5,
|
||||||
|
help='weight decay')
|
||||||
|
parser.add_argument('--batch-size', metavar='bs', type=int, default=32,
|
||||||
|
help='batch size')
|
||||||
|
parser.add_argument('--shuffle', action='store_true', default=True,
|
||||||
|
help='whether shuffle EEG dataset or not')
|
||||||
|
parser.add_argument('--seed', type=int, default=1)
|
||||||
|
params = parser.parse_args()
|
||||||
|
train(params)
|
||||||
Referência em uma Nova Issue
Bloquear um usuário