fix references for AdamW

Esse commit está contido em:
robintibor
2019-01-07 15:48:44 +01:00
commit de GitHub
commit c32ac5b0d5
+23 -16
Ver Arquivo
@@ -4,22 +4,29 @@ from torch.optim.optimizer import Optimizer
class AdamW(Optimizer):
"""Implements Adam algorithm with fixed as in `AdamW`
"""Implements Adam algorithm with weight decay fixed as in [AdamW]_` .
It has been proposed in `Adam: A Method for Stochastic Optimization`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
Parameters
----------
params: iterable
Iterable of parameters to optimize or dicts defining parameter groups
lr: float, optional
Learning rate.
betas: Tuple[float, float], optional
Coefficients used for computing running averages of gradient and its square
eps: float, optional
Term added to the denominator to improve numerical stability
weight_decay: float, optional
The "fixed" weight decay.
References
----------
.. [AdamW] Loshchilov, I. & Hutter, F. (2017).
Fixing Weight Decay Regularization in Adam.
arXiv preprint arXiv:1711.05101.
Online: https://arxiv.org/abs/1711.05101
"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
@@ -74,4 +81,4 @@ class AdamW(Optimizer):
if group['weight_decay'] != 0:
p.data.add_(-group['weight_decay'], p.data)
return loss
return loss