expose the dtype
Esse commit está contido em:
+7
-3
@@ -73,10 +73,14 @@ fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, face_detec
|
||||
In order to specify the device (GPU or CPU) on which the code will run one can explicitly pass the device flag:
|
||||
|
||||
```python
|
||||
import torch
|
||||
import face_alignment
|
||||
|
||||
# cuda for CUDA
|
||||
# cuda for CUDA, mps for Apple M1/2 GPUs.
|
||||
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, device='cpu')
|
||||
|
||||
# running using lower precision
|
||||
fa = fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, dtype=torch.bfloat16, device='cuda')
|
||||
```
|
||||
|
||||
Please also see the ``examples`` folder
|
||||
@@ -85,10 +89,10 @@ Please also see the ``examples`` folder
|
||||
|
||||
```python
|
||||
|
||||
# dlib
|
||||
# dlib (fast, may miss faces)
|
||||
model = FaceAlignment(landmarks_type= LandmarksType.TWO_D, face_detector='dlib')
|
||||
|
||||
# SFD
|
||||
# SFD (likely best results, but slowest)
|
||||
model = FaceAlignment(landmarks_type= LandmarksType.TWO_D, face_detector='sfd')
|
||||
|
||||
# Blazeface (front camera model)
|
||||
|
||||
@@ -50,11 +50,12 @@ models_urls = {
|
||||
|
||||
class FaceAlignment:
|
||||
def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
|
||||
device='cuda', flip_input=False, face_detector='sfd', face_detector_kwargs=None, verbose=False):
|
||||
device='cuda', dtype=torch.float32, flip_input=False, face_detector='sfd', face_detector_kwargs=None, verbose=False):
|
||||
self.device = device
|
||||
self.flip_input = flip_input
|
||||
self.landmarks_type = landmarks_type
|
||||
self.verbose = verbose
|
||||
self.dtype = dtype
|
||||
|
||||
if version.parse(torch.__version__) < version.parse('1.5.0'):
|
||||
raise ImportError(f'Unsupported pytorch version detected. Minimum supported version of pytorch: 1.5.0\
|
||||
@@ -84,7 +85,7 @@ class FaceAlignment:
|
||||
self.face_alignment_net = torch.jit.load(
|
||||
load_file_from_url(models_urls.get(pytorch_version, default_model_urls)[network_name]))
|
||||
|
||||
self.face_alignment_net.to(device)
|
||||
self.face_alignment_net.to(device, dtype=dtype)
|
||||
self.face_alignment_net.eval()
|
||||
|
||||
# Initialiase the depth prediciton network
|
||||
@@ -92,7 +93,7 @@ class FaceAlignment:
|
||||
self.depth_prediciton_net = torch.jit.load(
|
||||
load_file_from_url(models_urls.get(pytorch_version, default_model_urls)['depth']))
|
||||
|
||||
self.depth_prediciton_net.to(device)
|
||||
self.depth_prediciton_net.to(device, dtype=dtype)
|
||||
self.depth_prediciton_net.eval()
|
||||
|
||||
def get_landmarks(self, image_or_path, detected_faces=None, return_bboxes=False, return_landmark_score=False):
|
||||
@@ -159,13 +160,13 @@ class FaceAlignment:
|
||||
inp = torch.from_numpy(inp.transpose(
|
||||
(2, 0, 1))).float()
|
||||
|
||||
inp = inp.to(self.device)
|
||||
inp = inp.to(self.device, dtype=self.dtype)
|
||||
inp.div_(255.0).unsqueeze_(0)
|
||||
|
||||
out = self.face_alignment_net(inp).detach()
|
||||
if self.flip_input:
|
||||
out += flip(self.face_alignment_net(flip(inp)).detach(), is_label=True)
|
||||
out = out.cpu().numpy()
|
||||
out = out.to(device='cpu', dtype=torch.float32).numpy()
|
||||
|
||||
pts, pts_img, scores = get_preds_fromhm(out, center.numpy(), scale)
|
||||
pts, pts_img = torch.from_numpy(pts), torch.from_numpy(pts_img)
|
||||
@@ -181,9 +182,9 @@ class FaceAlignment:
|
||||
heatmaps = torch.from_numpy(
|
||||
heatmaps).unsqueeze_(0)
|
||||
|
||||
heatmaps = heatmaps.to(self.device)
|
||||
heatmaps = heatmaps.to(self.device, dtype=self.dtype)
|
||||
depth_pred = self.depth_prediciton_net(
|
||||
torch.cat((inp, heatmaps), 1)).data.cpu().view(68, 1)
|
||||
torch.cat((inp, heatmaps), 1)).data.cpu().view(68, 1).to(dtype=torch.float32)
|
||||
pts_img = torch.cat(
|
||||
(pts_img, depth_pred * (1.0 / (256.0 / (200.0 * scale)))), 1)
|
||||
|
||||
|
||||
@@ -114,38 +114,13 @@ class BlazeFace(nn.Module):
|
||||
self.backbone = nn.Sequential(
|
||||
nn.Conv2d(in_channels=3, out_channels=24, kernel_size=5, stride=2, padding=0, bias=True),
|
||||
nn.ReLU(inplace=True),
|
||||
|
||||
BlazeBlock(24, 24),
|
||||
BlazeBlock(24, 24),
|
||||
BlazeBlock(24, 24),
|
||||
BlazeBlock(24, 24),
|
||||
BlazeBlock(24, 24),
|
||||
BlazeBlock(24, 24),
|
||||
BlazeBlock(24, 24),
|
||||
*[BlazeBlock(24, 24) for _ in range(7)],
|
||||
BlazeBlock(24, 24, stride=2),
|
||||
BlazeBlock(24, 24),
|
||||
BlazeBlock(24, 24),
|
||||
BlazeBlock(24, 24),
|
||||
BlazeBlock(24, 24),
|
||||
BlazeBlock(24, 24),
|
||||
BlazeBlock(24, 24),
|
||||
BlazeBlock(24, 24),
|
||||
*[BlazeBlock(24, 24) for _ in range(7)],
|
||||
BlazeBlock(24, 48, stride=2),
|
||||
BlazeBlock(48, 48),
|
||||
BlazeBlock(48, 48),
|
||||
BlazeBlock(48, 48),
|
||||
BlazeBlock(48, 48),
|
||||
BlazeBlock(48, 48),
|
||||
BlazeBlock(48, 48),
|
||||
BlazeBlock(48, 48),
|
||||
*[BlazeBlock(48, 48) for _ in range(7)],
|
||||
BlazeBlock(48, 96, stride=2),
|
||||
BlazeBlock(96, 96),
|
||||
BlazeBlock(96, 96),
|
||||
BlazeBlock(96, 96),
|
||||
BlazeBlock(96, 96),
|
||||
BlazeBlock(96, 96),
|
||||
BlazeBlock(96, 96),
|
||||
BlazeBlock(96, 96),
|
||||
*[BlazeBlock(96, 96) for _ in range(7)],
|
||||
)
|
||||
self.final = FinalBlazeBlock(96)
|
||||
self.classifier_8 = nn.Conv2d(96, 2, 1, bias=True)
|
||||
|
||||
Referência em uma Nova Issue
Bloquear um usuário