Mirror of https://github.com/huggingface/candle.git

Commit: Add the optimizer trait. (#702)
@@ -1,6 +1,7 @@
 use crate::model::{Cache, Config, Llama};
 use candle::{DType, Device, Result};
 use candle_datasets::nlp::tinystories::{Dataset, DatasetRandomIter};
+use candle_nn::Optimizer;
 
 fn valid_loss(
     dataset: &Dataset,
@@ -9,7 +9,7 @@ use clap::{Parser, ValueEnum};
 use rand::prelude::*;
 
 use candle::{DType, Result, Tensor, D};
-use candle_nn::{loss, ops, Conv2d, Linear, Module, VarBuilder, VarMap};
+use candle_nn::{loss, ops, Conv2d, Linear, Module, Optimizer, VarBuilder, VarMap};
 
 const IMAGE_DIM: usize = 784;
 const LABELS: usize = 10;
@@ -190,7 +190,7 @@ fn training_loop<M: Model>(
         varmap.load(load)?
     }
 
-    let sgd = candle_nn::SGD::new(varmap.all_vars(), args.learning_rate);
+    let mut sgd = candle_nn::SGD::new(varmap.all_vars(), args.learning_rate)?;
     let test_images = m.test_images.to_device(&dev)?;
     let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
     for epoch in 1..args.epochs {
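With the fallible constructor and the `&mut self` step methods, the surrounding training loop only needs a `mut` binding and a `?`. The sketch below shows the resulting call pattern for a single training step; the function and variable names are illustrative rather than taken from the example, but only trait methods are used, so the same code works for SGD, AdamW, or any other `Optimizer` implementation.

use candle::{D, Result, Tensor};
use candle_nn::{loss, ops, Module, Optimizer};

// One optimization step over a batch. Only `Optimizer` trait methods are
// used, so `opt` can be SGD, AdamW, or any other implementation.
fn train_step<M: Module>(
    model: &M,
    opt: &mut impl Optimizer,
    images: &Tensor,
    labels: &Tensor,
) -> Result<f32> {
    let logits = model.forward(images)?;
    let log_sm = ops::log_softmax(&logits, D::Minus1)?;
    let loss = loss::nll(&log_sm, labels)?;
    // Default trait method: runs backprop, then `step` on the gradients.
    opt.backward_step(&loss)?;
    loss.to_scalar::<f32>()
}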
@@ -5,7 +5,7 @@ extern crate intel_mkl_src;
 extern crate accelerate_src;
 
 use candle::{DType, Device, Result, Tensor};
-use candle_nn::{linear, AdamW, Linear, Module, ParamsAdamW, VarBuilder, VarMap};
+use candle_nn::{linear, AdamW, Linear, Module, Optimizer, ParamsAdamW, VarBuilder, VarMap};
 
 fn gen_data() -> Result<(Tensor, Tensor)> {
     // Generate some sample linear data.
@@ -24,7 +24,7 @@ pub use init::Init;
 pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
 pub use linear::{linear, linear_no_bias, Linear};
 pub use ops::Dropout;
-pub use optim::{AdamW, ParamsAdamW, SGD};
+pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD};
 pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN};
 pub use var_builder::VarBuilder;
 pub use var_map::VarMap;
@@ -1,6 +1,33 @@
 //! Various optimization algorithms.
 use candle::{Result, Tensor, Var};
 
+/// The interface optimizers should implement.
+pub trait Optimizer: Sized {
+    type Config: Sized;
+
+    fn new(vars: Vec<Var>, config: Self::Config) -> Result<Self>;
+
+    fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()>;
+
+    fn learning_rate(&self) -> f64;
+
+    fn set_learning_rate(&mut self, lr: f64);
+
+    fn empty(config: Self::Config) -> Result<Self> {
+        Self::new(vec![], config)
+    }
+
+    fn backward_step(&mut self, loss: &Tensor) -> Result<()> {
+        let grads = loss.backward()?;
+        self.step(&grads)
+    }
+
+    fn from_slice(vars: &[&Var], config: Self::Config) -> Result<Self> {
+        let vars: Vec<_> = vars.iter().map(|&v| v.clone()).collect();
+        Self::new(vars, config)
+    }
+}
+
 /// Optimizer for Stochastic Gradient Descent.
 ///
 /// Contrary to the PyTorch implementation of SGD, this version does not support momentum.
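For an optimizer outside this crate, only the four required methods have to be written; `empty`, `backward_step`, and `from_slice` are inherited from the default implementations. The following is a hypothetical sketch (not part of this commit) of SGD with decoupled weight decay, mainly to show how `type Config` accommodates a richer parameter struct.

use candle::{Result, Var};
use candle_nn::Optimizer;

/// Hypothetical config: learning rate plus decoupled weight decay.
#[derive(Clone, Copy, Debug)]
pub struct SgdWConfig {
    pub lr: f64,
    pub weight_decay: f64,
}

/// Hypothetical SGD-with-weight-decay optimizer, for illustration only.
pub struct SgdW {
    vars: Vec<Var>,
    config: SgdWConfig,
}

impl Optimizer for SgdW {
    type Config = SgdWConfig;

    fn new(vars: Vec<Var>, config: SgdWConfig) -> Result<Self> {
        Ok(Self { vars, config })
    }

    fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()> {
        for var in self.vars.iter() {
            if let Some(grad) = grads.get(var) {
                // Decoupled weight decay, then a plain gradient step.
                let decayed =
                    (var.as_tensor() * (1.0 - self.config.lr * self.config.weight_decay))?;
                var.set(&decayed.sub(&(grad * self.config.lr)?)?)?;
            }
        }
        Ok(())
    }

    fn learning_rate(&self) -> f64 {
        self.config.lr
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.config.lr = lr
    }
}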
@@ -10,42 +10,21 @@ pub struct SGD {
     learning_rate: f64,
 }
 
-impl SGD {
-    pub fn from_slice(vars: &[&Var], learning_rate: f64) -> Self {
-        let vars: Vec<_> = vars.iter().map(|&v| v.clone()).collect();
-        Self {
-            vars,
-            learning_rate,
-        }
-    }
-
-    pub fn new(vars: Vec<Var>, learning_rate: f64) -> Self {
-        Self {
-            vars,
-            learning_rate,
-        }
-    }
-
-    pub fn empty(learning_rate: f64) -> Self {
-        Self {
-            vars: vec![],
-            learning_rate,
-        }
-    }
-
-    pub fn into_inner(self) -> Vec<Var> {
-        self.vars
-    }
-
-    pub fn learning_rate(&self) -> f64 {
-        self.learning_rate
-    }
-
-    pub fn push(&mut self, var: &Var) {
-        self.vars.push(var.clone())
-    }
-
-    pub fn step(&self, grads: &candle::backprop::GradStore) -> Result<()> {
+impl Optimizer for SGD {
+    type Config = f64;
+
+    fn new(vars: Vec<Var>, learning_rate: f64) -> Result<Self> {
+        Ok(Self {
+            vars,
+            learning_rate,
+        })
+    }
+
+    fn learning_rate(&self) -> f64 {
+        self.learning_rate
+    }
+
+    fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()> {
         for var in self.vars.iter() {
             if let Some(grad) = grads.get(var) {
                 var.set(&var.sub(&(grad * self.learning_rate)?)?)?;
@@ -54,13 +60,18 @@ impl SGD {
         Ok(())
     }
 
-    pub fn backward_step(&self, loss: &Tensor) -> Result<()> {
-        let grads = loss.backward()?;
-        self.step(&grads)
-    }
-
-    pub fn set_learning_rate(&mut self, lr: f64) {
+    fn set_learning_rate(&mut self, lr: f64) {
         self.learning_rate = lr
     }
 }
 
+impl SGD {
+    pub fn into_inner(self) -> Vec<Var> {
+        self.vars
+    }
+
+    pub fn push(&mut self, var: &Var) {
+        self.vars.push(var.clone())
+    }
+}
+
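`into_inner` and `push` are SGD-specific helpers rather than trait methods, so they move to an inherent `impl SGD` block and stay callable alongside the trait. A small usage sketch under that assumption (the values are illustrative):

use candle::{Device, Result, Var};
use candle_nn::{Optimizer, SGD};

fn sgd_helpers_demo() -> Result<()> {
    let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
    // `empty` is a default method of the trait; `push` is SGD-specific.
    let mut sgd = SGD::empty(0.01)?;
    sgd.push(&w);
    sgd.set_learning_rate(0.005);
    // Hand the tracked vars back once the optimizer is no longer needed.
    let _vars: Vec<Var> = sgd.into_inner();
    Ok(())
}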
@@ -99,8 +110,10 @@ pub struct AdamW {
     params: ParamsAdamW,
 }
 
-impl AdamW {
-    pub fn new(vars: Vec<Var>, params: ParamsAdamW) -> Result<Self> {
+impl Optimizer for AdamW {
+    type Config = ParamsAdamW;
+
+    fn new(vars: Vec<Var>, params: ParamsAdamW) -> Result<Self> {
         let vars = vars
             .into_iter()
             .map(|var| {
@@ -123,19 +136,15 @@ impl AdamW {
         })
     }
 
-    pub fn new_lr(vars: Vec<Var>, learning_rate: f64) -> Result<Self> {
-        let params = ParamsAdamW {
-            lr: learning_rate,
-            ..ParamsAdamW::default()
-        };
-        Self::new(vars, params)
-    }
-
-    pub fn set_learning_rate(&mut self, lr: f64) {
+    fn learning_rate(&self) -> f64 {
+        self.params.lr
+    }
+
+    fn set_learning_rate(&mut self, lr: f64) {
         self.params.lr = lr
     }
 
-    pub fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()> {
+    fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()> {
         self.step_t += 1;
         let lr = self.params.lr;
         let lambda = self.params.weight_decay;
@@ -166,9 +175,14 @@ impl AdamW {
         }
         Ok(())
     }
-
-    pub fn backward_step(&mut self, loss: &Tensor) -> Result<()> {
-        let grads = loss.backward()?;
-        self.step(&grads)
-    }
 }
+
+impl AdamW {
+    pub fn new_lr(vars: Vec<Var>, learning_rate: f64) -> Result<Self> {
+        let params = ParamsAdamW {
+            lr: learning_rate,
+            ..ParamsAdamW::default()
+        };
+        Self::new(vars, params)
+    }
+}
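For AdamW the trait now carries construction and stepping, while `new_lr` stays as an inherent convenience constructor. A minimal usage sketch (the quadratic loss below is a stand-in, not code from this commit):

use candle::{Device, Result, Tensor, Var};
use candle_nn::{AdamW, Optimizer, ParamsAdamW};

fn adamw_demo() -> Result<()> {
    let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
    // Either the convenience constructor...
    let mut opt = AdamW::new_lr(vec![w.clone()], 0.01)?;
    // ...or the trait constructor with a full parameter struct.
    let params = ParamsAdamW {
        lr: 0.01,
        ..ParamsAdamW::default()
    };
    let _opt2 = AdamW::new(vec![w.clone()], params)?;

    let target = Tensor::new(&[[1f32, -1.]], &Device::Cpu)?;
    for _step in 0..10 {
        // Stand-in loss: squared distance between `w` and `target`.
        let diff = (w.as_tensor() - &target)?;
        let loss = (&diff * &diff)?.sum_all()?;
        opt.backward_step(&loss)?;
    }
    Ok(())
}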
@@ -8,12 +8,12 @@ use candle::test_utils::{to_vec0_round, to_vec2_round};
 
 use anyhow::Result;
 use candle::{Device, Tensor, Var};
-use candle_nn::{AdamW, Linear, Module, ParamsAdamW, SGD};
+use candle_nn::{AdamW, Linear, Module, Optimizer, ParamsAdamW, SGD};
 
 #[test]
 fn sgd_optim() -> Result<()> {
     let x = Var::new(0f32, &Device::Cpu)?;
-    let sgd = SGD::new(vec![x.clone()], 0.1);
+    let mut sgd = SGD::new(vec![x.clone()], 0.1)?;
     let xt = x.as_tensor();
     for _step in 0..100 {
         let loss = ((xt - 4.2)? * (xt - 4.2)?)?;
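The rest of the test body falls outside this hunk; presumably the loop keeps driving the parameter towards 4.2 through the default `backward_step`, with the `mut` binding being the only visible change. A sketch of that pattern, using a tolerance-based assertion rather than the test's exact expected value:

use anyhow::Result;
use candle::{Device, Var};
use candle_nn::{Optimizer, SGD};

fn sgd_optim_sketch() -> Result<()> {
    let x = Var::new(0f32, &Device::Cpu)?;
    let mut sgd = SGD::new(vec![x.clone()], 0.1)?;
    let xt = x.as_tensor();
    for _step in 0..100 {
        // Minimize (x - 4.2)^2, so x should converge towards 4.2.
        let loss = ((xt - 4.2)? * (xt - 4.2)?)?;
        sgd.backward_step(&loss)?;
    }
    assert!((x.to_scalar::<f32>()? - 4.2).abs() < 1e-3);
    Ok(())
}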
@@ -59,7 +59,7 @@ fn sgd_linear_regression() -> Result<()> {
     // Now use backprop to run a linear regression between samples and get the coefficients back.
     let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
    let b = Var::new(0f32, &Device::Cpu)?;
-    let sgd = SGD::new(vec![w.clone(), b.clone()], 0.004);
+    let mut sgd = SGD::new(vec![w.clone(), b.clone()], 0.004)?;
     let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
     for _step in 0..1000 {
         let ys = lin.forward(&sample_xs)?;