From 1a6043af5123bf9e189063d3baf110b39cf47617 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sun, 25 Feb 2024 20:50:08 +0100
Subject: [PATCH] Tweak the VarMap set type. (#1758)

---
 candle-nn/src/var_map.rs                   |  2 +-
 candle-nn/tests/optim.rs                   | 39 ++++++++++++++++++++++-
 candle-transformers/src/models/mamba.rs    |  6 ++--
 candle-transformers/src/models/rwkv_v5.rs  | 12 +++----
 4 files changed, 48 insertions(+), 11 deletions(-)

diff --git a/candle-nn/src/var_map.rs b/candle-nn/src/var_map.rs
index d34cee78..3cb27c63 100644
--- a/candle-nn/src/var_map.rs
+++ b/candle-nn/src/var_map.rs
@@ -70,7 +70,7 @@ impl VarMap {
     ///
     /// If an error is returned, some of the variables might have already been set to their new
     /// values.
-    pub fn set<I: Iterator<Item = (K, V)>, K: AsRef<String>, V: AsRef<Tensor>>(
+    pub fn set<I: Iterator<Item = (K, V)>, K: AsRef<str>, V: AsRef<Tensor>>(
         &mut self,
         iter: I,
     ) -> Result<()> {
diff --git a/candle-nn/tests/optim.rs b/candle-nn/tests/optim.rs
index 841f65c8..4eb14ed8 100644
--- a/candle-nn/tests/optim.rs
+++ b/candle-nn/tests/optim.rs
@@ -7,7 +7,7 @@ extern crate accelerate_src;
 use candle::test_utils::{to_vec0_round, to_vec2_round};
 
 use anyhow::Result;
-use candle::{Device, Tensor, Var};
+use candle::{DType, Device, Tensor, Var};
 use candle_nn::{AdamW, Linear, Module, Optimizer, ParamsAdamW, SGD};
 
 #[test]
@@ -121,3 +121,40 @@ fn adamw_linear_regression() -> Result<()> {
     assert_eq!(to_vec0_round(b.as_tensor(), 4)?, 0.7873);
     Ok(())
 }
+
+#[test]
+fn adamw_linear_regression_varmap() -> Result<()> {
+    use candle_nn::Init::Const;
+
+    // Similar to the previous test but using a VarMap.
+    let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
+    let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
+    let gen = Linear::new(w_gen, Some(b_gen));
+    let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
+    let sample_ys = gen.forward(&sample_xs)?;
+
+    let mut var_map = candle_nn::VarMap::new();
+
+    let w = var_map.get((1, 2), "w", Const(0.), DType::F32, &Device::Cpu)?;
+    let b = var_map.get((), "b", Const(0.), DType::F32, &Device::Cpu)?;
+    let params = ParamsAdamW {
+        lr: 0.1,
+        ..Default::default()
+    };
+    let mut opt = AdamW::new(var_map.all_vars(), params)?;
+    let lin = Linear::new(w, Some(b));
+    for _step in 0..100 {
+        let ys = lin.forward(&sample_xs)?;
+        let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
+        opt.backward_step(&loss)?;
+    }
+    assert_eq!(to_vec2_round(lin.weight(), 4)?, &[[2.7257, 0.7097]]);
+    assert_eq!(to_vec0_round(lin.bias().unwrap(), 4)?, 0.7873);
+
+    var_map.set([("w", Tensor::zeros((1, 2), DType::F32, &Device::Cpu)?)].into_iter())?;
+    var_map.set([("b", Tensor::ones((), DType::F32, &Device::Cpu)?)].into_iter())?;
+
+    assert_eq!(to_vec2_round(lin.weight(), 4)?, &[[0., 0.]]);
+    assert_eq!(to_vec0_round(lin.bias().unwrap(), 4)?, 1.);
+    Ok(())
+}
diff --git a/candle-transformers/src/models/mamba.rs b/candle-transformers/src/models/mamba.rs
index da254bd1..81828ad5 100644
--- a/candle-transformers/src/models/mamba.rs
+++ b/candle-transformers/src/models/mamba.rs
@@ -32,9 +32,9 @@ impl Config {
 }
 
 pub struct State {
-    hs: Vec<Tensor>,
-    prev_xs: Vec<[Tensor; D_CONV]>,
-    pos: usize,
+    pub hs: Vec<Tensor>,
+    pub prev_xs: Vec<[Tensor; D_CONV]>,
+    pub pos: usize,
 }
 
 impl State {
diff --git a/candle-transformers/src/models/rwkv_v5.rs b/candle-transformers/src/models/rwkv_v5.rs
index d11cdedd..38b1e450 100644
--- a/candle-transformers/src/models/rwkv_v5.rs
+++ b/candle-transformers/src/models/rwkv_v5.rs
@@ -22,15 +22,15 @@ pub struct Config {
     pub rescale_every: usize,
 }
 
-struct StatePerLayer {
-    extract_key_value: Tensor,
-    linear_attention: Tensor,
-    feed_forward: Tensor,
+pub struct StatePerLayer {
+    pub extract_key_value: Tensor,
+    pub linear_attention: Tensor,
+    pub feed_forward: Tensor,
 }
 
 pub struct State {
-    per_layer: Vec<StatePerLayer>,
-    pos: usize,
+    pub per_layer: Vec<StatePerLayer>,
+    pub pos: usize,
 }
 
 impl State {
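
With the key bound relaxed to `K: AsRef<str>`, plain string literals now work as keys for `VarMap::set`, which is exactly what the new test exercises. Below is a minimal usage sketch distilled from that test; the `reset_params` helper and the "w"/"b" names and shapes are illustrative, not part of the patch:

// Minimal sketch (hypothetical helper, not part of this patch): overwrite
// variables that already exist in a VarMap, using &str keys as enabled by
// the new K: AsRef<str> bound. The previous AsRef<String> bound was awkward
// for callers to satisfy with string literals.
use candle::{DType, Device, Result, Tensor};
use candle_nn::VarMap;

fn reset_params(var_map: &mut VarMap) -> Result<()> {
    let w = Tensor::zeros((1, 2), DType::F32, &Device::Cpu)?;
    let b = Tensor::ones((), DType::F32, &Device::Cpu)?;
    // `set` overwrites entries already present in the map; per its doc
    // comment, an error may leave earlier entries already updated.
    var_map.set([("w", w), ("b", b)].into_iter())?;
    Ok(())
}

Note that `set` takes `&mut self`, so layers built from the map's `Var`s (such as the `Linear` in the test) observe the new values immediately; this is what lets the test assert the zeroed weights without rebuilding the model.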