mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Tweak the VarMap set type. (#1758)
This commit is contained in:
@ -70,7 +70,7 @@ impl VarMap {
|
||||
///
|
||||
/// If an error is returned, some of the variables might have already been set to their new
|
||||
/// values.
|
||||
pub fn set<I: Iterator<Item = (K, V)>, K: AsRef<String>, V: AsRef<Tensor>>(
|
||||
pub fn set<I: Iterator<Item = (K, V)>, K: AsRef<str>, V: AsRef<Tensor>>(
|
||||
&mut self,
|
||||
iter: I,
|
||||
) -> Result<()> {
|
||||
|
@ -7,7 +7,7 @@ extern crate accelerate_src;
|
||||
use candle::test_utils::{to_vec0_round, to_vec2_round};
|
||||
|
||||
use anyhow::Result;
|
||||
use candle::{Device, Tensor, Var};
|
||||
use candle::{DType, Device, Tensor, Var};
|
||||
use candle_nn::{AdamW, Linear, Module, Optimizer, ParamsAdamW, SGD};
|
||||
|
||||
#[test]
|
||||
@ -121,3 +121,40 @@ fn adamw_linear_regression() -> Result<()> {
|
||||
assert_eq!(to_vec0_round(b.as_tensor(), 4)?, 0.7873);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adamw_linear_regression_varmap() -> Result<()> {
|
||||
use candle_nn::Init::Const;
|
||||
|
||||
// Similar as the previous test but using a VarMap.
|
||||
let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
|
||||
let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
|
||||
let gen = Linear::new(w_gen, Some(b_gen));
|
||||
let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
|
||||
let sample_ys = gen.forward(&sample_xs)?;
|
||||
|
||||
let mut var_map = candle_nn::VarMap::new();
|
||||
|
||||
let w = var_map.get((1, 2), "w", Const(0.), DType::F32, &Device::Cpu)?;
|
||||
let b = var_map.get((), "b", Const(0.), DType::F32, &Device::Cpu)?;
|
||||
let params = ParamsAdamW {
|
||||
lr: 0.1,
|
||||
..Default::default()
|
||||
};
|
||||
let mut opt = AdamW::new(var_map.all_vars(), params)?;
|
||||
let lin = Linear::new(w, Some(b));
|
||||
for _step in 0..100 {
|
||||
let ys = lin.forward(&sample_xs)?;
|
||||
let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
|
||||
opt.backward_step(&loss)?;
|
||||
}
|
||||
assert_eq!(to_vec2_round(lin.weight(), 4)?, &[[2.7257, 0.7097]]);
|
||||
assert_eq!(to_vec0_round(lin.bias().unwrap(), 4)?, 0.7873);
|
||||
|
||||
var_map.set([("w", Tensor::zeros((1, 2), DType::F32, &Device::Cpu)?)].into_iter())?;
|
||||
var_map.set([("b", Tensor::ones((), DType::F32, &Device::Cpu)?)].into_iter())?;
|
||||
|
||||
assert_eq!(to_vec2_round(lin.weight(), 4)?, &[[0., 0.]]);
|
||||
assert_eq!(to_vec0_round(lin.bias().unwrap(), 4)?, 1.);
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user