--[[
Commit 09b6ca0ad0 (squashed "WIP #3" history):
  - Added import tests
  - Added more test cases
  - Added more tests
  - Fixed concat.lua test
  - Minor changes
  - Fixed concat-parallel.lua
  - Added check-model helper
  - Added more tests for model checker
  - Added extra tests
  - Changed check-model to GraphChecker
  - Multiple cases fail for ImportTorch...
  - Fixed ImportTorch batch test case running
  - Changed graph checker to use gme path for id
  - Updated tests
  - Tweaked to get all examples working locally with 'th'
  - Fixed tests
File: 24 lines, 406 B, Lua
]]
require 'nn'
--- Default width of the hidden layer built by createSeq.
local nhiddens = 150

--- Build one branch of the network: Linear -> Tanh -> Linear.
-- Left as a global function so the rest of the script (and anything that
-- loads it) can call it, as before.
-- @tparam number input   number of inputs to the first nn.Linear
-- @tparam number output  number of outputs of the last nn.Linear
-- @tparam[opt=nhiddens] number hidden  width of the hidden layer
-- @return nn.Sequential holding the three layers, in that order
function createSeq(input, output, hidden)
  hidden = hidden or nhiddens

  -- `local` so each call builds its own container instead of overwriting
  -- a shared global `seq`.
  local seq = nn.Sequential()
  seq:add(nn.Linear(input, hidden))
  seq:add(nn.Tanh())
  seq:add(nn.Linear(hidden, output))
  return seq
end
-- Layer sizes. nn.Concat(1) joins the two branch outputs along dimension 1,
-- so the layer after the merge must accept their sum; deriving it from the
-- branch widths keeps the three numbers from drifting apart.
-- NOTE(review): this only lines up for a single, non-batched 1-D input; a
-- batched (N x inputSize) input would need nn.Concat(2) — confirm intended.
local inputSize = 100
local branch1Out, branch2Out = 50, 30
local nclasses = 7

-- Top-level model. `mlp` and `concat` are kept global; presumably whatever
-- loads this script reads them — confirm before making them local.
mlp = nn.Sequential()

-- concat layer: two parallel branches fed the same input
concat = nn.Concat(1)
concat:add(createSeq(inputSize, branch1Out))
concat:add(createSeq(inputSize, branch2Out))

-- merge
mlp:add(concat)
mlp:add(nn.Tanh())
mlp:add(nn.Linear(branch1Out + branch2Out, nclasses))