rename fixed weights to pool_weights in AvgPool2dWithConv
Esse commit está contido em:
@@ -52,7 +52,10 @@ class AvgPool2dWithConv(torch.nn.Module):
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self.dilation = dilation
|
||||
self.weights = None
|
||||
# don't name them "weights" to
|
||||
# make sure these are not accidentally used by some procedure
|
||||
# that initializes parameters or something
|
||||
self._pool_weights = None
|
||||
|
||||
def forward(self, x):
|
||||
# Create weights for the convolution on demand:
|
||||
@@ -60,11 +63,11 @@ class AvgPool2dWithConv(torch.nn.Module):
|
||||
in_channels = x.size()[1]
|
||||
weight_shape = (in_channels, 1,
|
||||
self.kernel_size[0], self.kernel_size[1])
|
||||
if self.weights is None or (
|
||||
(tuple(self.weights.size()) != tuple(weight_shape)) or (
|
||||
self.weights.is_cuda != x.is_cuda
|
||||
if self._pool_weights is None or (
|
||||
(tuple(self._pool_weights.size()) != tuple(weight_shape)) or (
|
||||
self._pool_weights.is_cuda != x.is_cuda
|
||||
) or (
|
||||
self.weights.data.type() != x.data.type()
|
||||
self._pool_weights.data.type() != x.data.type()
|
||||
)):
|
||||
n_pool = np.prod(self.kernel_size)
|
||||
weights = np_to_var(
|
||||
@@ -72,9 +75,9 @@ class AvgPool2dWithConv(torch.nn.Module):
|
||||
weights = weights.type_as(x)
|
||||
if x.is_cuda:
|
||||
weights = weights.cuda()
|
||||
self.weights = weights
|
||||
self._pool_weights = weights
|
||||
|
||||
pooled = F.conv2d(x, self.weights, bias=None, stride=self.stride,
|
||||
pooled = F.conv2d(x, self._pool_weights, bias=None, stride=self.stride,
|
||||
dilation=self.dilation,
|
||||
groups=in_channels,)
|
||||
return pooled
|
||||
|
||||
Referência em uma Nova Issue
Bloquear um usuário