mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Add the optimizer trait. (#702)
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
use crate::model::{Cache, Config, Llama};
|
||||
use candle::{DType, Device, Result};
|
||||
use candle_datasets::nlp::tinystories::{Dataset, DatasetRandomIter};
|
||||
use candle_nn::Optimizer;
|
||||
|
||||
fn valid_loss(
|
||||
dataset: &Dataset,
|
||||
|
@ -9,7 +9,7 @@ use clap::{Parser, ValueEnum};
|
||||
use rand::prelude::*;
|
||||
|
||||
use candle::{DType, Result, Tensor, D};
|
||||
use candle_nn::{loss, ops, Conv2d, Linear, Module, VarBuilder, VarMap};
|
||||
use candle_nn::{loss, ops, Conv2d, Linear, Module, Optimizer, VarBuilder, VarMap};
|
||||
|
||||
const IMAGE_DIM: usize = 784;
|
||||
const LABELS: usize = 10;
|
||||
@ -190,7 +190,7 @@ fn training_loop<M: Model>(
|
||||
varmap.load(load)?
|
||||
}
|
||||
|
||||
let sgd = candle_nn::SGD::new(varmap.all_vars(), args.learning_rate);
|
||||
let mut sgd = candle_nn::SGD::new(varmap.all_vars(), args.learning_rate)?;
|
||||
let test_images = m.test_images.to_device(&dev)?;
|
||||
let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
|
||||
for epoch in 1..args.epochs {
|
||||
|
@ -5,7 +5,7 @@ extern crate intel_mkl_src;
|
||||
extern crate accelerate_src;
|
||||
|
||||
use candle::{DType, Device, Result, Tensor};
|
||||
use candle_nn::{linear, AdamW, Linear, Module, ParamsAdamW, VarBuilder, VarMap};
|
||||
use candle_nn::{linear, AdamW, Linear, Module, Optimizer, ParamsAdamW, VarBuilder, VarMap};
|
||||
|
||||
fn gen_data() -> Result<(Tensor, Tensor)> {
|
||||
// Generate some sample linear data.
|
||||
|
@ -24,7 +24,7 @@ pub use init::Init;
|
||||
pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
|
||||
pub use linear::{linear, linear_no_bias, Linear};
|
||||
pub use ops::Dropout;
|
||||
pub use optim::{AdamW, ParamsAdamW, SGD};
|
||||
pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD};
|
||||
pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN};
|
||||
pub use var_builder::VarBuilder;
|
||||
pub use var_map::VarMap;
|
||||
|
@ -1,6 +1,33 @@
|
||||
//! Various optimization algorithms.
|
||||
use candle::{Result, Tensor, Var};
|
||||
|
||||
/// The interface optimizers should implement.
|
||||
pub trait Optimizer: Sized {
|
||||
type Config: Sized;
|
||||
|
||||
fn new(vars: Vec<Var>, config: Self::Config) -> Result<Self>;
|
||||
|
||||
fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()>;
|
||||
|
||||
fn learning_rate(&self) -> f64;
|
||||
|
||||
fn set_learning_rate(&mut self, lr: f64);
|
||||
|
||||
fn empty(config: Self::Config) -> Result<Self> {
|
||||
Self::new(vec![], config)
|
||||
}
|
||||
|
||||
fn backward_step(&mut self, loss: &Tensor) -> Result<()> {
|
||||
let grads = loss.backward()?;
|
||||
self.step(&grads)
|
||||
}
|
||||
|
||||
fn from_slice(vars: &[&Var], config: Self::Config) -> Result<Self> {
|
||||
let vars: Vec<_> = vars.iter().map(|&v| v.clone()).collect();
|
||||
Self::new(vars, config)
|
||||
}
|
||||
}
|
||||
|
||||
/// Optimizer for Stochastic Gradient Descent.
|
||||
///
|
||||
/// Contrary to the PyTorch implementation of SGD, this version does not support momentum.
|
||||
@ -10,42 +37,21 @@ pub struct SGD {
|
||||
learning_rate: f64,
|
||||
}
|
||||
|
||||
impl SGD {
|
||||
pub fn from_slice(vars: &[&Var], learning_rate: f64) -> Self {
|
||||
let vars: Vec<_> = vars.iter().map(|&v| v.clone()).collect();
|
||||
Self {
|
||||
impl Optimizer for SGD {
|
||||
type Config = f64;
|
||||
|
||||
fn new(vars: Vec<Var>, learning_rate: f64) -> Result<Self> {
|
||||
Ok(Self {
|
||||
vars,
|
||||
learning_rate,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn new(vars: Vec<Var>, learning_rate: f64) -> Self {
|
||||
Self {
|
||||
vars,
|
||||
learning_rate,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn empty(learning_rate: f64) -> Self {
|
||||
Self {
|
||||
vars: vec![],
|
||||
learning_rate,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> Vec<Var> {
|
||||
self.vars
|
||||
}
|
||||
|
||||
pub fn learning_rate(&self) -> f64 {
|
||||
fn learning_rate(&self) -> f64 {
|
||||
self.learning_rate
|
||||
}
|
||||
|
||||
pub fn push(&mut self, var: &Var) {
|
||||
self.vars.push(var.clone())
|
||||
}
|
||||
|
||||
pub fn step(&self, grads: &candle::backprop::GradStore) -> Result<()> {
|
||||
fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()> {
|
||||
for var in self.vars.iter() {
|
||||
if let Some(grad) = grads.get(var) {
|
||||
var.set(&var.sub(&(grad * self.learning_rate)?)?)?;
|
||||
@ -54,13 +60,18 @@ impl SGD {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn backward_step(&self, loss: &Tensor) -> Result<()> {
|
||||
let grads = loss.backward()?;
|
||||
self.step(&grads)
|
||||
fn set_learning_rate(&mut self, lr: f64) {
|
||||
self.learning_rate = lr
|
||||
}
|
||||
}
|
||||
|
||||
impl SGD {
|
||||
pub fn into_inner(self) -> Vec<Var> {
|
||||
self.vars
|
||||
}
|
||||
|
||||
pub fn set_learning_rate(&mut self, lr: f64) {
|
||||
self.learning_rate = lr
|
||||
pub fn push(&mut self, var: &Var) {
|
||||
self.vars.push(var.clone())
|
||||
}
|
||||
}
|
||||
|
||||
@ -99,8 +110,10 @@ pub struct AdamW {
|
||||
params: ParamsAdamW,
|
||||
}
|
||||
|
||||
impl AdamW {
|
||||
pub fn new(vars: Vec<Var>, params: ParamsAdamW) -> Result<Self> {
|
||||
impl Optimizer for AdamW {
|
||||
type Config = ParamsAdamW;
|
||||
|
||||
fn new(vars: Vec<Var>, params: ParamsAdamW) -> Result<Self> {
|
||||
let vars = vars
|
||||
.into_iter()
|
||||
.map(|var| {
|
||||
@ -123,19 +136,15 @@ impl AdamW {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn new_lr(vars: Vec<Var>, learning_rate: f64) -> Result<Self> {
|
||||
let params = ParamsAdamW {
|
||||
lr: learning_rate,
|
||||
..ParamsAdamW::default()
|
||||
};
|
||||
Self::new(vars, params)
|
||||
fn learning_rate(&self) -> f64 {
|
||||
self.params.lr
|
||||
}
|
||||
|
||||
pub fn set_learning_rate(&mut self, lr: f64) {
|
||||
fn set_learning_rate(&mut self, lr: f64) {
|
||||
self.params.lr = lr
|
||||
}
|
||||
|
||||
pub fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()> {
|
||||
fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()> {
|
||||
self.step_t += 1;
|
||||
let lr = self.params.lr;
|
||||
let lambda = self.params.weight_decay;
|
||||
@ -166,9 +175,14 @@ impl AdamW {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn backward_step(&mut self, loss: &Tensor) -> Result<()> {
|
||||
let grads = loss.backward()?;
|
||||
self.step(&grads)
|
||||
impl AdamW {
|
||||
pub fn new_lr(vars: Vec<Var>, learning_rate: f64) -> Result<Self> {
|
||||
let params = ParamsAdamW {
|
||||
lr: learning_rate,
|
||||
..ParamsAdamW::default()
|
||||
};
|
||||
Self::new(vars, params)
|
||||
}
|
||||
}
|
||||
|
@ -8,12 +8,12 @@ use candle::test_utils::{to_vec0_round, to_vec2_round};
|
||||
|
||||
use anyhow::Result;
|
||||
use candle::{Device, Tensor, Var};
|
||||
use candle_nn::{AdamW, Linear, Module, ParamsAdamW, SGD};
|
||||
use candle_nn::{AdamW, Linear, Module, Optimizer, ParamsAdamW, SGD};
|
||||
|
||||
#[test]
|
||||
fn sgd_optim() -> Result<()> {
|
||||
let x = Var::new(0f32, &Device::Cpu)?;
|
||||
let sgd = SGD::new(vec![x.clone()], 0.1);
|
||||
let mut sgd = SGD::new(vec![x.clone()], 0.1)?;
|
||||
let xt = x.as_tensor();
|
||||
for _step in 0..100 {
|
||||
let loss = ((xt - 4.2)? * (xt - 4.2)?)?;
|
||||
@ -59,7 +59,7 @@ fn sgd_linear_regression() -> Result<()> {
|
||||
// Now use backprop to run a linear regression between samples and get the coefficients back.
|
||||
let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
|
||||
let b = Var::new(0f32, &Device::Cpu)?;
|
||||
let sgd = SGD::new(vec![w.clone(), b.clone()], 0.004);
|
||||
let mut sgd = SGD::new(vec![w.clone(), b.clone()], 0.004)?;
|
||||
let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
|
||||
for _step in 0..1000 {
|
||||
let ys = lin.forward(&sample_xs)?;
|
||||
|
Reference in New Issue
Block a user