split hybrid into pytorch module and our model

Esse commit está contido em:
Robin Tibor Schirrmeister
2018-09-17 17:23:04 +02:00
commit 81678c4ab5
+22 -5
Ver Arquivo
@@ -8,7 +8,24 @@ from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
from braindecode.models.util import to_dense_prediction_model
class HybridNet(nn.Module, BaseModel):
class HybridNet(BaseModel):
"""
Wrapper for HybridNetModule
"""
def __init__(self, in_chans, n_classes, input_time_length):
self.in_chans = in_chans
self.n_classes = n_classes
self.input_time_length = input_time_length
def create_network(self):
return HybridNetModule(
in_chans=self.in_chans,
n_classes=self.n_classes,
input_time_length=self.input_time_length
)
class HybridNetModule(nn.Module):
"""
Hybrid ConvNet model from [3]_.
@@ -23,16 +40,16 @@ class HybridNet(nn.Module, BaseModel):
visualization.
Human Brain Mapping , Aug. 2017. Online: http://dx.doi.org/10.1002/hbm.23730
"""
def __init__(self, n_chans, n_classes, input_time_length):
super(HybridNet, self).__init__()
deep_model = Deep4Net(n_chans, n_classes, n_filters_time=20,
def __init__(self, in_chans, n_classes, input_time_length):
super(HybridNetModule, self).__init__()
deep_model = Deep4Net(in_chans, n_classes, n_filters_time=20,
n_filters_spat=30,
n_filters_2=40,
n_filters_3=50,
n_filters_4=60,
input_time_length=input_time_length,
final_conv_length=2).create_network()
shallow_model = ShallowFBCSPNet(n_chans, n_classes,
shallow_model = ShallowFBCSPNet(in_chans, n_classes,
input_time_length=input_time_length,
n_filters_time=30,
n_filters_spat=40,