From 7efceb00301f5a5aca103adb94afc819ac42662b Mon Sep 17 00:00:00 2001 From: weilheim Date: Tue, 5 Dec 2017 12:03:53 +0800 Subject: [PATCH] Add all code --- data.py | 56 +++++++++++++++++++++++++++++++++++++ model/__init__.py | 0 model/eegnet.py | 71 +++++++++++++++++++++++++++++++++++++++++++++++ model/lstm.py | 30 ++++++++++++++++++++ train.py | 28 +++++++++++++++++++ 5 files changed, 185 insertions(+) create mode 100644 data.py create mode 100644 model/__init__.py create mode 100644 model/eegnet.py create mode 100644 model/lstm.py create mode 100644 train.py diff --git a/data.py b/data.py new file mode 100644 index 0000000..36c0a2e --- /dev/null +++ b/data.py @@ -0,0 +1,56 @@ +import numpy as np +import h5py +import os +import torch +import torch.utils.data as data + +DATASET = '/home/liusheng/Dataset/EEG/' + + +class EEGImage(data.Dataset): + """DEAP EEG dataset""" + VALID_SPLIT = ('train', 'val', 'test') + + def __init__(self, data_dir, split='train'): + super(EEGImage, self).__init__() + + if split not in self.VALID_SPLIT: + raise ValueError('Unknown split {:s}'.format(split)) + if not os.path.exists(data_dir): + raise ValueError('{:s} does not exist'.format(data_dir)) + + self.split = split + self.data = h5py.File(data_dir, 'r') + self.key = list(self.data.keys()) + + def __getitem__(self, index): + key = self.key[index] + label = int(self.data[key]['label'][...]) + video = self.data[key]['video'][...] + # (D, H, W, C) -> (C, D, H, W): (3, 63, 32, 32) + video = np.transpose(video, (3, 0, 1, 2)) + return label, video + + def __len__(self): + return len(self.key) + + def collate_fn(self, batch): + label = [b[0] for b in batch] + # (C, D, H, W): (3, 63, 32, 32) + video = [b[1] / 255.0 for b in batch] + video = [torch.FloatTensor(v).unsqueeze(0) for v in video] + return torch.LongTensor(label), torch.cat(video, dim=0) + + +if __name__ == '__main__': + dataset = EEGImage("/home/liusheng/Dataset/EEG/train.h5", 'train') + print len(dataset) + label, video = dataset.__getitem__(0) + print label + print video.shape + + loader = data.DataLoader(dataset, batch_size=2, shuffle=False, num_workers=1, collate_fn=dataset.collate_fn) + sample = next(iter(loader)) + print sample[0] + print sample[1] + diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/model/eegnet.py b/model/eegnet.py new file mode 100644 index 0000000..bcd5b46 --- /dev/null +++ b/model/eegnet.py @@ -0,0 +1,71 @@ +from __future__ import absolute_import + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class EEGNet(nn.Module): + """2D convolutional neural network for single EEG frame.""" + VALIDENDPOINT = ('logit', 'predict') + + def __init__(self, num_class, + input_channel, + hidden_size, + kernel_size, + stride, + avgpool_size=4, + dropout=0.1): + super(EEGNet, self).__init__() + + assert len(kernel_size) == len(hidden_size) + assert len(kernel_size) == len(stride) + self.num_layer = len(kernel_size) + self.num_class = num_class + self.input_channel = input_channel + self.dropout = dropout + + in_channel = self.input_channel + layer = 1 + self.projections = nn.ModuleList() + self.residualnorms = nn.ModuleList() + self.convolutions = nn.ModuleList() + self.batchnorms = nn.ModuleList() + for (out_channel, kernel_width, s) in zip(hidden_size, kernel_size, stride): + pad = (kernel_size - 1) // 2 + self.projections.append(nn.Conv1d(in_channel, out_channel, kernel_size=1, stride=s, bias=False) + if in_channel != out_channel or s != 1 else None) + self.residualnorms.append(nn.BatchNorm2d(out_channel) + if in_channel != out_channel or s != 1 else None) + self.convolutions.append(nn.Conv2d(in_channel, out_channel, + kernel_size=kernel_width, stride=s, padding=pad, bias=False)) + self.batchnorms.append(nn.BatchNorm2d(out_channel, eps=1e-5, affine=True)) + in_channel = out_channel + # avgpool_size should equal to size of the feature map, + # otherwise self.predict will break. + self.avgpool = nn.AvgPool2d(kernel_size=avgpool_size) + self.predict = nn.Linear(hidden_size[-1], self.num_class, bias=True) + + def forward(self, x, endpoint='predict'): + if endpoint not in self.VALIDENDPOINT: + raise ValueError('Unknown endpoint {:s}'.format(endpoint)) + + for proj, rbn, conv, bn in zip(self.projections, self.residualnorms, + self.convolutions, self.batchnorms): + if proj is not None: + residual = proj(x) + residual = rbn(residual) + else: + residual = x + x = F.dropout(x, p=self.dropout, training=self.training) + x = conv(x) + x = bn(x) + x = (x + residual) + x = F.relu(x) + x = self.avgpool(x) + if endpoint == 'logit': + return x + x = self.predict(x) + return x + + diff --git a/model/lstm.py b/model/lstm.py new file mode 100644 index 0000000..b9050ea --- /dev/null +++ b/model/lstm.py @@ -0,0 +1,30 @@ +from __future__ import absolute_import + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import model.eegnet + + +class EEGLSTM(nn.Module): + + def __init__(self, num_class, + num_layer=1, + input_channel=128, + hidden_size=128, + dropout=0.3): + super(EEGLSTM, self).__init__() + self.num_class = num_class + self.num_layer = num_layer + self.input_channel = input_channel + self.hidden_size = hidden_size + self.dropout = dropout + self.lstm = nn.LSTM(input_channel, hidden_size, num_layer, + batch_first=True, bidirectional=False) + self.predict = nn.Linear(hidden_size, num_class) + + def forward(self, x): + x, (h, c) = self.lstm(x) + x = self.predict(x) + return x diff --git a/train.py b/train.py new file mode 100644 index 0000000..b05f6d7 --- /dev/null +++ b/train.py @@ -0,0 +1,28 @@ +from __future__ import absolute_import + +import argparse +import torch +import torch.nn as nn +import torch.nn.functional as F + +import data +from model import eegnet, lstm + + +def train(params): + pass + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--learning-rate', metavar='lr', type=float, default=1e-3, + help='learning rate') + parser.add_argument('--weight-decay', metavar='wd', type=float, default=1e-5, + help='weight decay') + parser.add_argument('--batch-size', metavar='bs', type=int, default=32, + help='batch size') + parser.add_argument('--shuffle', action='store_true', default=True, + help='whether shuffle EEG dataset or not') + parser.add_argument('--seed', type=int, default=1) + params = parser.parse_args() + train(params) \ No newline at end of file