Add all code
Esse commit está contido em:
+56
@@ -0,0 +1,56 @@
|
||||
import numpy as np
|
||||
import h5py
|
||||
import os
|
||||
import torch
|
||||
import torch.utils.data as data
|
||||
|
||||
DATASET = '/home/liusheng/Dataset/EEG/'
|
||||
|
||||
|
||||
class EEGImage(data.Dataset):
|
||||
"""DEAP EEG dataset"""
|
||||
VALID_SPLIT = ('train', 'val', 'test')
|
||||
|
||||
def __init__(self, data_dir, split='train'):
|
||||
super(EEGImage, self).__init__()
|
||||
|
||||
if split not in self.VALID_SPLIT:
|
||||
raise ValueError('Unknown split {:s}'.format(split))
|
||||
if not os.path.exists(data_dir):
|
||||
raise ValueError('{:s} does not exist'.format(data_dir))
|
||||
|
||||
self.split = split
|
||||
self.data = h5py.File(data_dir, 'r')
|
||||
self.key = list(self.data.keys())
|
||||
|
||||
def __getitem__(self, index):
|
||||
key = self.key[index]
|
||||
label = int(self.data[key]['label'][...])
|
||||
video = self.data[key]['video'][...]
|
||||
# (D, H, W, C) -> (C, D, H, W): (3, 63, 32, 32)
|
||||
video = np.transpose(video, (3, 0, 1, 2))
|
||||
return label, video
|
||||
|
||||
def __len__(self):
|
||||
return len(self.key)
|
||||
|
||||
def collate_fn(self, batch):
|
||||
label = [b[0] for b in batch]
|
||||
# (C, D, H, W): (3, 63, 32, 32)
|
||||
video = [b[1] / 255.0 for b in batch]
|
||||
video = [torch.FloatTensor(v).unsqueeze(0) for v in video]
|
||||
return torch.LongTensor(label), torch.cat(video, dim=0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
dataset = EEGImage("/home/liusheng/Dataset/EEG/train.h5", 'train')
|
||||
print len(dataset)
|
||||
label, video = dataset.__getitem__(0)
|
||||
print label
|
||||
print video.shape
|
||||
|
||||
loader = data.DataLoader(dataset, batch_size=2, shuffle=False, num_workers=1, collate_fn=dataset.collate_fn)
|
||||
sample = next(iter(loader))
|
||||
print sample[0]
|
||||
print sample[1]
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class EEGNet(nn.Module):
|
||||
"""2D convolutional neural network for single EEG frame."""
|
||||
VALIDENDPOINT = ('logit', 'predict')
|
||||
|
||||
def __init__(self, num_class,
|
||||
input_channel,
|
||||
hidden_size,
|
||||
kernel_size,
|
||||
stride,
|
||||
avgpool_size=4,
|
||||
dropout=0.1):
|
||||
super(EEGNet, self).__init__()
|
||||
|
||||
assert len(kernel_size) == len(hidden_size)
|
||||
assert len(kernel_size) == len(stride)
|
||||
self.num_layer = len(kernel_size)
|
||||
self.num_class = num_class
|
||||
self.input_channel = input_channel
|
||||
self.dropout = dropout
|
||||
|
||||
in_channel = self.input_channel
|
||||
layer = 1
|
||||
self.projections = nn.ModuleList()
|
||||
self.residualnorms = nn.ModuleList()
|
||||
self.convolutions = nn.ModuleList()
|
||||
self.batchnorms = nn.ModuleList()
|
||||
for (out_channel, kernel_width, s) in zip(hidden_size, kernel_size, stride):
|
||||
pad = (kernel_size - 1) // 2
|
||||
self.projections.append(nn.Conv1d(in_channel, out_channel, kernel_size=1, stride=s, bias=False)
|
||||
if in_channel != out_channel or s != 1 else None)
|
||||
self.residualnorms.append(nn.BatchNorm2d(out_channel)
|
||||
if in_channel != out_channel or s != 1 else None)
|
||||
self.convolutions.append(nn.Conv2d(in_channel, out_channel,
|
||||
kernel_size=kernel_width, stride=s, padding=pad, bias=False))
|
||||
self.batchnorms.append(nn.BatchNorm2d(out_channel, eps=1e-5, affine=True))
|
||||
in_channel = out_channel
|
||||
# avgpool_size should equal to size of the feature map,
|
||||
# otherwise self.predict will break.
|
||||
self.avgpool = nn.AvgPool2d(kernel_size=avgpool_size)
|
||||
self.predict = nn.Linear(hidden_size[-1], self.num_class, bias=True)
|
||||
|
||||
def forward(self, x, endpoint='predict'):
|
||||
if endpoint not in self.VALIDENDPOINT:
|
||||
raise ValueError('Unknown endpoint {:s}'.format(endpoint))
|
||||
|
||||
for proj, rbn, conv, bn in zip(self.projections, self.residualnorms,
|
||||
self.convolutions, self.batchnorms):
|
||||
if proj is not None:
|
||||
residual = proj(x)
|
||||
residual = rbn(residual)
|
||||
else:
|
||||
residual = x
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = conv(x)
|
||||
x = bn(x)
|
||||
x = (x + residual)
|
||||
x = F.relu(x)
|
||||
x = self.avgpool(x)
|
||||
if endpoint == 'logit':
|
||||
return x
|
||||
x = self.predict(x)
|
||||
return x
|
||||
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import model.eegnet
|
||||
|
||||
|
||||
class EEGLSTM(nn.Module):
|
||||
|
||||
def __init__(self, num_class,
|
||||
num_layer=1,
|
||||
input_channel=128,
|
||||
hidden_size=128,
|
||||
dropout=0.3):
|
||||
super(EEGLSTM, self).__init__()
|
||||
self.num_class = num_class
|
||||
self.num_layer = num_layer
|
||||
self.input_channel = input_channel
|
||||
self.hidden_size = hidden_size
|
||||
self.dropout = dropout
|
||||
self.lstm = nn.LSTM(input_channel, hidden_size, num_layer,
|
||||
batch_first=True, bidirectional=False)
|
||||
self.predict = nn.Linear(hidden_size, num_class)
|
||||
|
||||
def forward(self, x):
|
||||
x, (h, c) = self.lstm(x)
|
||||
x = self.predict(x)
|
||||
return x
|
||||
+28
@@ -0,0 +1,28 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
import argparse
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import data
|
||||
from model import eegnet, lstm
|
||||
|
||||
|
||||
def train(params):
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--learning-rate', metavar='lr', type=float, default=1e-3,
|
||||
help='learning rate')
|
||||
parser.add_argument('--weight-decay', metavar='wd', type=float, default=1e-5,
|
||||
help='weight decay')
|
||||
parser.add_argument('--batch-size', metavar='bs', type=int, default=32,
|
||||
help='batch size')
|
||||
parser.add_argument('--shuffle', action='store_true', default=True,
|
||||
help='whether shuffle EEG dataset or not')
|
||||
parser.add_argument('--seed', type=int, default=1)
|
||||
params = parser.parse_args()
|
||||
train(params)
|
||||
Referência em uma Nova Issue
Bloquear um usuário