57 linhas
1.6 KiB
Python
57 linhas
1.6 KiB
Python
import numpy as np
|
|
import h5py
|
|
import os
|
|
import torch
|
|
import torch.utils.data as data
|
|
|
|
# Directory that holds the EEG HDF5 files (e.g. ``train.h5``). This is a
# machine-specific absolute path from the original author's environment;
# adjust it before running the demo on another machine.
DATASET = '/home/liusheng/Dataset/EEG/'
|
|
|
|
|
|
class EEGImage(data.Dataset):
    """DEAP EEG dataset backed by a single HDF5 file.

    Every top-level key of the file is treated as one sample; each sample
    group is expected to expose a ``'label'`` and a ``'video'`` dataset.

    The HDF5 file is opened lazily and once per process: h5py handles
    cannot be pickled and should not be shared with forked ``DataLoader``
    workers, so every process (re)opens the file on first access.
    """

    VALID_SPLIT = ('train', 'val', 'test')

    def __init__(self, data_dir, split='train'):
        """Validate the arguments and index the sample keys.

        Args:
            data_dir: Path of the HDF5 file to read. Despite its name it
                is handed directly to ``h5py.File``, so it must point at
                a file rather than a directory.
            split: One of ``VALID_SPLIT``. It is only validated and
                stored; it does not influence which file is read.

        Raises:
            ValueError: If ``split`` is not a valid split name or
                ``data_dir`` does not exist.
        """
        super(EEGImage, self).__init__()

        # Plain '{}' instead of '{:s}': the latter raises TypeError for
        # non-str values (None, pathlib.Path, ...) on Python 3 and would
        # hide the ValueError we actually want to report.
        if split not in self.VALID_SPLIT:
            raise ValueError('Unknown split {}'.format(split))
        if not os.path.exists(data_dir):
            raise ValueError('{} does not exist'.format(data_dir))

        self.split = split
        self.data_dir = data_dir
        self._handle = None
        self._handle_pid = None
        self.key = list(self.data.keys())

    @property
    def data(self):
        """Return the open HDF5 file, (re)opening it in a new process."""
        pid = os.getpid()
        if self._handle is None or self._handle_pid != pid:
            self._handle = h5py.File(self.data_dir, 'r')
            self._handle_pid = pid
        return self._handle

    @data.setter
    def data(self, handle):
        # Kept so that code assigning ``dataset.data`` keeps working.
        self._handle = handle
        self._handle_pid = os.getpid()

    def close(self):
        """Close the HDF5 file if this process opened it."""
        handle, owner = self._handle, self._handle_pid
        self._handle = None
        self._handle_pid = None
        # A handle inherited through fork belongs to the parent process;
        # it is only dropped here, never closed.
        if handle is not None and owner == os.getpid():
            handle.close()

    def __getstate__(self):
        # Drop the open handle when the dataset is pickled (e.g. for
        # spawned workers); the receiving process reopens it on demand.
        state = self.__dict__.copy()
        state['_handle'] = None
        state['_handle_pid'] = None
        return state

    def __getitem__(self, index):
        """Return ``(label, video)`` for the sample at ``index``.

        ``label`` is converted to a Python ``int``. ``video`` is the
        array stored under ``'video'`` with its last axis moved to the
        front.
        """
        sample = self.data[self.key[index]]
        label = int(sample['label'][...])
        video = sample['video'][...]
        # (D, H, W, C) -> (C, D, H, W): (3, 63, 32, 32) per the original
        # author's note.
        video = np.transpose(video, (3, 0, 1, 2))
        return label, video

    def __len__(self):
        """Return the number of samples (top-level keys in the file)."""
        return len(self.key)

    def collate_fn(self, batch):
        """Merge ``(label, video)`` pairs into batched tensors.

        Args:
            batch: Sequence of ``(label, video)`` pairs as produced by
                ``__getitem__``.

        Returns:
            A ``(labels, videos)`` tuple: a ``LongTensor`` holding the
            labels and a ``FloatTensor`` holding the videos divided by
            255.0 and stacked along a new leading batch axis.
        """
        labels = torch.LongTensor([item[0] for item in batch])
        # Each video is (C, D, H, W); stacking adds the batch dimension.
        videos = torch.stack(
            [torch.FloatTensor(item[1] / 255.0) for item in batch], dim=0)
        return labels, videos
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke test: load the training file, inspect one raw sample,
    # then pull a single collated batch through a DataLoader.
    # print(...) with a single argument is valid on both Python 2 and 3,
    # unlike the print statement.
    dataset = EEGImage(os.path.join(DATASET, 'train.h5'), 'train')
    print(len(dataset))

    label, video = dataset[0]
    print(label)
    print(video.shape)

    loader = data.DataLoader(dataset, batch_size=2, shuffle=False,
                             num_workers=1, collate_fn=dataset.collate_fn)
    sample = next(iter(loader))
    print(sample[0])
    print(sample[1])
|
|
|