//! # AdamW (Adam with decoupled weight decay) optimizer
//!
//! AdamW modifies the standard Adam optimizer by decoupling weight decay from the
//! gradient update step. In standard Adam, weight decay is typically implemented
//! by adding an L2 penalty to the loss, which interacts with the adaptive learning
//! rates in a way that often results in suboptimal model convergence.
//!
//! AdamW explicitly decays the weights prior to the gradient update, restoring
//! the original mathematical definition of weight decay and generally enabling
//! better performance on complex models such as transformers.
//!
//! ## Resources
//! - Decoupled Weight Decay Regularization (Ilya Loshchilov and Frank Hutter):
//!   - <https://arxiv.org/abs/1711.05101>
//! - PyTorch AdamW optimizer:
//!   - <https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html>
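//!
//! ## Update rule
//! Concretely, each call to [`AdamW::step`] applies the following per-parameter
//! update (a plain-text sketch of the implementation below, where `theta` is the
//! parameter, `g` its gradient, `alpha` the learning rate and `lambda` the weight
//! decay coefficient):
//!
//! ```text
//! theta <- theta - alpha * lambda * theta                   // decoupled weight decay
//! m     <- beta1 * m + (1 - beta1) * g                      // first moment estimate
//! v     <- beta2 * v + (1 - beta2) * g^2                    // second moment estimate
//! m_hat  = m / (1 - beta1^t),  v_hat = v / (1 - beta2^t)    // bias correction
//! theta <- theta - alpha * m_hat / (sqrt(v_hat) + epsilon)  // adaptive step
//! ```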

#[allow(dead_code)]
pub struct AdamW {
    learning_rate: f64, // alpha: initial step size
    betas: (f64, f64),  // betas: exponential decay rates for the moment estimates
    epsilon: f64,       // epsilon: small constant that prevents division by zero
    weight_decay: f64,  // lambda: decoupled weight decay coefficient
    m: Vec<f64>,        // m: biased first moment estimate of the gradient
    v: Vec<f64>,        // v: biased second raw moment estimate of the gradient
    t: usize,           // t: time step
}

#[allow(dead_code)]
impl AdamW {
    pub fn new(
        learning_rate: Option<f64>,
        betas: Option<(f64, f64)>,
        epsilon: Option<f64>,
        weight_decay: Option<f64>,
        params_len: usize,
    ) -> Self {
        AdamW {
            learning_rate: learning_rate.unwrap_or(1e-3),
            betas: betas.unwrap_or((0.9, 0.999)),
            epsilon: epsilon.unwrap_or(1e-8),
            weight_decay: weight_decay.unwrap_or(1e-2), // default weight decay coefficient
            m: vec![0.0; params_len],
            v: vec![0.0; params_len],
            t: 0,
        }
    }

    /// Performs one AdamW step, updating the model parameters in place so that the
    /// decoupled weight decay can be applied directly to them.
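    ///
    /// A minimal usage sketch (the gradient values here are illustrative only; in
    /// practice they would come from backpropagation):
    ///
    /// ```ignore
    /// let mut params = vec![0.5, -0.5];
    /// let gradients = vec![0.1, -0.2];
    /// let mut optimizer = AdamW::new(None, None, None, None, params.len());
    /// optimizer.step(&mut params, &gradients);
    /// ```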
    pub fn step(&mut self, params: &mut [f64], gradients: &[f64]) {
        assert_eq!(
            params.len(),
            gradients.len(),
            "Parameters and gradients must have the same length."
        );
        self.t += 1;
        for i in 0..gradients.len() {
            // Apply decoupled weight decay (the 'W' in AdamW) inline
            params[i] -= self.learning_rate * self.weight_decay * params[i];
            // Update the biased first and second moment estimates
            self.m[i] = self.betas.0 * self.m[i] + (1.0 - self.betas.0) * gradients[i];
            self.v[i] = self.betas.1 * self.v[i] + (1.0 - self.betas.1) * gradients[i].powi(2);
            // Bias correction
            let m_hat = self.m[i] / (1.0 - self.betas.0.powi(self.t as i32));
            let v_hat = self.v[i] / (1.0 - self.betas.1.powi(self.t as i32));
            // Apply the standard Adam adaptive learning rate step
            params[i] -= self.learning_rate * m_hat / (v_hat.sqrt() + self.epsilon);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_adamw_init_default_values() {
        let optimizer = AdamW::new(None, None, None, None, 1);
        assert_eq!(optimizer.learning_rate, 0.001);
        assert_eq!(optimizer.betas, (0.9, 0.999));
        assert_eq!(optimizer.epsilon, 1e-8);
        assert_eq!(optimizer.weight_decay, 1e-2);
        assert_eq!(optimizer.m, vec![0.0; 1]);
        assert_eq!(optimizer.v, vec![0.0; 1]);
        assert_eq!(optimizer.t, 0);
    }

    #[test]
    fn test_adamw_init_custom_values() {
        let optimizer = AdamW::new(Some(0.1), Some((0.8, 0.888)), Some(1e-4), Some(0.005), 3);
        assert_eq!(optimizer.learning_rate, 0.1);
        assert_eq!(optimizer.betas, (0.8, 0.888));
        assert_eq!(optimizer.epsilon, 1e-4);
        assert_eq!(optimizer.weight_decay, 0.005);
        assert_eq!(optimizer.m, vec![0.0; 3]);
        assert_eq!(optimizer.v, vec![0.0; 3]);
        assert_eq!(optimizer.t, 0);
    }

    #[test]
    fn test_adamw_step_default_params() {
        let gradients = vec![-1.0, 2.0, -3.0];
        let mut params = vec![0.5, -0.5, 0.0]; // non-zero starting params to test weight decay
        let mut optimizer = AdamW::new(None, None, None, None, 3);
        optimizer.step(&mut params, &gradients);
        // Expected values, worked out by hand for i = 0 (param = 0.5, grad = -1.0):
        // weight decay: param = 0.5 - (0.001 * 0.01 * 0.5) = 0.5 - 0.000005 = 0.499995
        // m = 0.9 * 0 + 0.1 * (-1.0) = -0.1
        // v = 0.999 * 0 + 0.001 * 1.0 = 0.001
        // m_hat = -0.1 / (1 - 0.9) = -1.0
        // v_hat = 0.001 / (1 - 0.999) = 1.0
        // param -= 0.001 * -1.0 / (1.0 + 1e-8)
        // final param ≈ 0.499995 + 0.001 ≈ 0.500995
        assert!(params[0] > 0.5);
        assert!(params[1] < -0.5);
    }

    #[test]
    fn test_adamw_step_zero_gradients_with_weight_decay() {
        // If the gradients are zero, the params should strictly decay toward zero.
        let gradients = vec![0.0, 0.0];
        let mut params = vec![100.0, -100.0];
        let mut optimizer = AdamW::new(Some(1.0), None, None, Some(0.1), 2); // 10% decay per step
        optimizer.step(&mut params, &gradients);
        assert_eq!(params, vec![90.0, -90.0]); // 10% closer to 0
        optimizer.step(&mut params, &gradients);
        assert_eq!(params, vec![81.0, -81.0]);
    }

    #[test]
    fn test_adamw_step_iteratively_until_convergence() {
        let gradients = vec![1.0, 2.0, 3.0, 4.0];
        // High learning rate to force large movement quickly (weight decay kept at its 0.01 default)
        let mut optimizer = AdamW::new(Some(0.1), None, None, Some(0.01), 4);
        let mut model_params = vec![5.0; 4];
        for _ in 0..100 {
            optimizer.step(&mut model_params, &gradients);
        }
        // Because the gradient constantly pushes in the positive direction while the
        // weight decay pulls towards zero, the parameters should be pushed below 5.0
        // and eventually settle at a stable equilibrium.
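        // Rough estimate of that equilibrium: with a constant positive gradient,
        // m_hat / (sqrt(v_hat) + epsilon) approaches 1, so the fixed point is where
        // weight_decay * theta ≈ -1, i.e. theta ≈ -1 / 0.01 = -100 (not reached
        // within 100 steps, hence only the directional assertions below).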
        assert!(model_params[0] < 5.0);
        assert!(model_params[1] < 5.0);
        assert!(model_params[2] < 5.0);
        assert!(model_params[3] < 5.0);
    }
}