Arquivos
cuneiform-sign-detection-code/lib/utils/torchcv/one_hot_embedding.py
T
2020-11-19 12:18:53 +01:00

16 linhas
368 B
Python

import torch
def one_hot_embedding(labels, num_classes, dtype=None):
    """Embed integer class labels as one-hot vectors.

    Args:
        labels: (LongTensor) class labels, sized [N,]. Any shape works; the
            one-hot axis is appended as the last dimension. Other integer or
            bool dtypes (and array-likes such as lists) are converted to
            int64 first, so they are always used as class indices rather
            than as a boolean mask.
        num_classes: (int) number of classes.
        dtype: (torch.dtype, optional) dtype of the result. ``None`` (the
            default) keeps torch's default floating dtype.

    Returns:
        (tensor) encoded labels, sized [N, #classes] (in general
        ``labels.shape + (num_classes,)``), on the same device as labels.

    Note:
        Labels are expected to lie in [0, num_classes). Negative labels are
        not rejected: like any tensor index they wrap around, so -1 maps to
        the last class. Labels >= num_classes raise an indexing error.
    """
    # Cast to int64: uint8/bool index tensors would otherwise be treated as a
    # mask over the rows of the identity matrix instead of as row indices.
    # For a tensor that is already int64 this returns it unchanged.
    labels = torch.as_tensor(labels, dtype=torch.long)
    y = torch.eye(num_classes, dtype=dtype, device=labels.device)  # [D, D]
    # Row i of the identity matrix is the one-hot vector for class i.
    return y[labels]  # [N, D]