72 linhas
3.0 KiB
Python
72 linhas
3.0 KiB
Python
from __future__ import print_function
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class SSDLoss(nn.Module):
    '''SSD training loss: SmoothL1 localisation loss on positive anchors plus
    cross-entropy classification loss on positives and hard-mined negatives.

    Args:
      num_classes: (int) number of classes, i.e. the size of the last
        dimension of `cls_preds`.
      neg_pos_ratio: (number) how many negatives are kept per positive anchor
        during hard negative mining. Defaults to 3.
      verbose: (bool) when True (default) the per-call loss breakdown is
        printed to stdout from `forward`.
    '''

    def __init__(self, num_classes, neg_pos_ratio=3, verbose=True):
        super(SSDLoss, self).__init__()
        self.num_classes = num_classes
        self.neg_pos_ratio = neg_pos_ratio
        self.verbose = verbose

    def _hard_negative_mining(self, cls_loss, pos):
        '''Return a mask of the hardest non-positive anchors, `neg_pos_ratio`
        times as many per sample as there are positive anchors.

        Args:
          cls_loss: (tensor) cross entropy loss between cls_preds and
            cls_targets, sized [N,#anchors].
          pos: (tensor) positive class mask, sized [N,#anchors].

        Returns:
          (tensor) boolean mask of selected anchors, sized [N,#anchors].
          When a sample has fewer non-positive anchors than requested, the
          mask may also cover positive anchors.
        '''
        # Positives become 0 and every other anchor becomes -loss, so an
        # ascending sort places the highest-loss non-positive anchors first.
        masked_loss = cls_loss * (pos.float() - 1)

        _, order = masked_loss.sort(1)   # anchor indices, hardest first
        _, rank = order.sort(1)          # rank of each anchor, [N,#anchors]

        num_neg = self.neg_pos_ratio * pos.long().sum(1)  # [N,]
        return rank < num_neg[:, None]                    # [N,#anchors]

    def forward(self, loc_preds, loc_targets, cls_preds, cls_targets):
        '''Compute loss between (loc_preds, loc_targets) and (cls_preds, cls_targets).

        Args:
          loc_preds: (tensor) predicted locations, sized [N, #anchors, 4].
          loc_targets: (tensor) encoded target locations, sized [N, #anchors, 4].
          cls_preds: (tensor) predicted class confidences, sized [N, #anchors, #classes].
          cls_targets: (tensor) encoded target labels, sized [N, #anchors].
            Labels > 0 are positives, 0 is background and labels < 0 are
            ignored (they contribute zero classification loss).

        Returns:
          (tensor) scalar loss =
            (SmoothL1Loss(loc_preds, loc_targets) + CrossEntropyLoss(cls_preds, cls_targets))
            divided by the number of positive anchors (left undivided when
            there are no positives).
        '''
        pos = cls_targets > 0  # [N,#anchors]
        batch_size = pos.size(0)
        num_pos = pos.sum().item()

        # ---- localisation loss: summed SmoothL1 over positive anchors only
        mask = pos.unsqueeze(2).expand_as(loc_preds)  # [N,#anchors,4]
        loc_loss = F.smooth_l1_loss(loc_preds[mask], loc_targets[mask],
                                    reduction='sum')

        # ---- classification loss: per-anchor cross entropy
        # Targets are clamped because cross_entropy rejects negative class
        # indices; the loss at those (ignored) anchors is zeroed right after.
        cls_loss = F.cross_entropy(cls_preds.reshape(-1, self.num_classes),
                                   cls_targets.clamp(min=0).reshape(-1),
                                   reduction='none')  # [N*#anchors,]
        cls_loss = cls_loss.view(batch_size, -1)
        cls_loss = cls_loss.masked_fill(cls_targets < 0, 0)  # ignored -> 0

        # Mining only needs the ordering of the losses, so it runs on a
        # detached copy and adds nothing to the autograd graph.
        neg = self._hard_negative_mining(cls_loss.detach(), pos)  # [N,#anchors]
        cls_loss = cls_loss[pos | neg].sum()

        if self.verbose:
            if num_pos > 0:
                print('loc_loss: %.3f | cls_loss: %.3f' % (loc_loss.item()/num_pos, cls_loss.item()/num_pos), end=' | ')
            else:
                print('num_pos zero exception')

        # Normalise by the number of positives; guard against division by
        # zero when the batch contains no positive anchor.
        return (loc_loss + cls_loss) / max(num_pos, 1)
|