Add all code

Esse commit está contido em:
weilheim
2017-12-05 12:03:53 +08:00
commit 7efceb0030
5 arquivos alterados com 185 adições e 0 exclusões
+56
Ver Arquivo
@@ -0,0 +1,56 @@
import numpy as np
import h5py
import os
import torch
import torch.utils.data as data
DATASET = '/home/liusheng/Dataset/EEG/'
class EEGImage(data.Dataset):
"""DEAP EEG dataset"""
VALID_SPLIT = ('train', 'val', 'test')
def __init__(self, data_dir, split='train'):
super(EEGImage, self).__init__()
if split not in self.VALID_SPLIT:
raise ValueError('Unknown split {:s}'.format(split))
if not os.path.exists(data_dir):
raise ValueError('{:s} does not exist'.format(data_dir))
self.split = split
self.data = h5py.File(data_dir, 'r')
self.key = list(self.data.keys())
def __getitem__(self, index):
key = self.key[index]
label = int(self.data[key]['label'][...])
video = self.data[key]['video'][...]
# (D, H, W, C) -> (C, D, H, W): (3, 63, 32, 32)
video = np.transpose(video, (3, 0, 1, 2))
return label, video
def __len__(self):
return len(self.key)
def collate_fn(self, batch):
label = [b[0] for b in batch]
# (C, D, H, W): (3, 63, 32, 32)
video = [b[1] / 255.0 for b in batch]
video = [torch.FloatTensor(v).unsqueeze(0) for v in video]
return torch.LongTensor(label), torch.cat(video, dim=0)
if __name__ == '__main__':
dataset = EEGImage("/home/liusheng/Dataset/EEG/train.h5", 'train')
print len(dataset)
label, video = dataset.__getitem__(0)
print label
print video.shape
loader = data.DataLoader(dataset, batch_size=2, shuffle=False, num_workers=1, collate_fn=dataset.collate_fn)
sample = next(iter(loader))
print sample[0]
print sample[1]
Ver Arquivo
+71
Ver Arquivo
@@ -0,0 +1,71 @@
from __future__ import absolute_import
import torch
import torch.nn as nn
import torch.nn.functional as F
class EEGNet(nn.Module):
"""2D convolutional neural network for single EEG frame."""
VALIDENDPOINT = ('logit', 'predict')
def __init__(self, num_class,
input_channel,
hidden_size,
kernel_size,
stride,
avgpool_size=4,
dropout=0.1):
super(EEGNet, self).__init__()
assert len(kernel_size) == len(hidden_size)
assert len(kernel_size) == len(stride)
self.num_layer = len(kernel_size)
self.num_class = num_class
self.input_channel = input_channel
self.dropout = dropout
in_channel = self.input_channel
layer = 1
self.projections = nn.ModuleList()
self.residualnorms = nn.ModuleList()
self.convolutions = nn.ModuleList()
self.batchnorms = nn.ModuleList()
for (out_channel, kernel_width, s) in zip(hidden_size, kernel_size, stride):
pad = (kernel_size - 1) // 2
self.projections.append(nn.Conv1d(in_channel, out_channel, kernel_size=1, stride=s, bias=False)
if in_channel != out_channel or s != 1 else None)
self.residualnorms.append(nn.BatchNorm2d(out_channel)
if in_channel != out_channel or s != 1 else None)
self.convolutions.append(nn.Conv2d(in_channel, out_channel,
kernel_size=kernel_width, stride=s, padding=pad, bias=False))
self.batchnorms.append(nn.BatchNorm2d(out_channel, eps=1e-5, affine=True))
in_channel = out_channel
# avgpool_size should equal to size of the feature map,
# otherwise self.predict will break.
self.avgpool = nn.AvgPool2d(kernel_size=avgpool_size)
self.predict = nn.Linear(hidden_size[-1], self.num_class, bias=True)
def forward(self, x, endpoint='predict'):
if endpoint not in self.VALIDENDPOINT:
raise ValueError('Unknown endpoint {:s}'.format(endpoint))
for proj, rbn, conv, bn in zip(self.projections, self.residualnorms,
self.convolutions, self.batchnorms):
if proj is not None:
residual = proj(x)
residual = rbn(residual)
else:
residual = x
x = F.dropout(x, p=self.dropout, training=self.training)
x = conv(x)
x = bn(x)
x = (x + residual)
x = F.relu(x)
x = self.avgpool(x)
if endpoint == 'logit':
return x
x = self.predict(x)
return x
+30
Ver Arquivo
@@ -0,0 +1,30 @@
from __future__ import absolute_import
import torch
import torch.nn as nn
import torch.nn.functional as F
import model.eegnet
class EEGLSTM(nn.Module):
def __init__(self, num_class,
num_layer=1,
input_channel=128,
hidden_size=128,
dropout=0.3):
super(EEGLSTM, self).__init__()
self.num_class = num_class
self.num_layer = num_layer
self.input_channel = input_channel
self.hidden_size = hidden_size
self.dropout = dropout
self.lstm = nn.LSTM(input_channel, hidden_size, num_layer,
batch_first=True, bidirectional=False)
self.predict = nn.Linear(hidden_size, num_class)
def forward(self, x):
x, (h, c) = self.lstm(x)
x = self.predict(x)
return x
+28
Ver Arquivo
@@ -0,0 +1,28 @@
from __future__ import absolute_import
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import data
from model import eegnet, lstm
def train(params):
pass
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--learning-rate', metavar='lr', type=float, default=1e-3,
help='learning rate')
parser.add_argument('--weight-decay', metavar='wd', type=float, default=1e-5,
help='weight decay')
parser.add_argument('--batch-size', metavar='bs', type=int, default=32,
help='batch size')
parser.add_argument('--shuffle', action='store_true', default=True,
help='whether shuffle EEG dataset or not')
parser.add_argument('--seed', type=int, default=1)
params = parser.parse_args()
train(params)