16 linhas
368 B
Python
16 linhas
368 B
Python
import torch
|
|
|
|
|
|
def one_hot_embedding(labels, num_classes):
|
|
'''Embedding labels to one-hot.
|
|
|
|
Args:
|
|
labels: (LongTensor) class labels, sized [N,].
|
|
num_classes: (int) number of classes.
|
|
|
|
Returns:
|
|
(tensor) encoded labels, sized [N,#classes].
|
|
'''
|
|
y = torch.eye(num_classes, device=labels.device) # [D,D]
|
|
return y[labels] # [N,D]
|