From 21e1c738928eb6ad0266d63ae10f9d8d849bb124 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Wed, 30 Aug 2023 20:05:42 +0200
Subject: [PATCH] Add a LSTM test. (#681)

* Add a LSTM test.

* Clippy.
---
 candle-nn/src/rnn.rs   |  2 +-
 candle-nn/tests/rnn.rs | 42 ++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 43 insertions(+), 1 deletion(-)
 create mode 100644 candle-nn/tests/rnn.rs

diff --git a/candle-nn/src/rnn.rs b/candle-nn/src/rnn.rs
index 3c82e794..681f2b2b 100644
--- a/candle-nn/src/rnn.rs
+++ b/candle-nn/src/rnn.rs
@@ -110,7 +110,7 @@ pub fn lstm(
         config.w_ih_init,
     )?;
     let w_hh = vb.get_with_hints(
-        (4 * hidden_dim, in_dim),
+        (4 * hidden_dim, hidden_dim),
         "weight_hh_l0", // Only a single layer is supported.
         config.w_hh_init,
     )?;
diff --git a/candle-nn/tests/rnn.rs b/candle-nn/tests/rnn.rs
new file mode 100644
index 00000000..0f0cca38
--- /dev/null
+++ b/candle-nn/tests/rnn.rs
@@ -0,0 +1,42 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use candle::{test_utils::to_vec2_round, DType, Device, Result, Tensor};
+use candle_nn::RNN;
+
+#[test]
+fn lstm() -> Result<()> {
+    let cpu = &Device::Cpu;
+    let w_ih = Tensor::arange(0f32, 24f32, cpu)?.reshape((12, 2))?;
+    let w_ih = w_ih.cos()?;
+    let w_hh = Tensor::arange(0f32, 36f32, cpu)?.reshape((12, 3))?;
+    let w_hh = w_hh.sin()?;
+    let b_ih = Tensor::new(
+        &[-1f32, 1., -0.5, 2., -1., 1., -0.5, 2., -1., 1., -0.5, 2.],
+        cpu,
+    )?;
+    let b_hh = b_ih.cos()?;
+    let tensors: std::collections::HashMap<_, _> = [
+        ("weight_ih_l0".to_string(), w_ih),
+        ("weight_hh_l0".to_string(), w_hh),
+        ("bias_ih_l0".to_string(), b_ih),
+        ("bias_hh_l0".to_string(), b_hh),
+    ]
+    .into_iter()
+    .collect();
+    let vb = candle_nn::VarBuilder::from_tensors(tensors, DType::F32, cpu);
+    let lstm = candle_nn::lstm(2, 3, Default::default(), vb)?;
+    let mut state = lstm.zero_state(1)?;
+    for inp in [3f32, 1., 4., 1., 5., 9., 2.] {
+        let inp = Tensor::new(&[[inp, inp * 0.5]], cpu)?;
+        state = lstm.step(&inp, &state)?
+    }
+    let h = state.h();
+    let c = state.c();
+    assert_eq!(to_vec2_round(h, 4)?, &[[0.9919, 0.1738, -0.1451]]);
+    assert_eq!(to_vec2_round(c, 4)?, &[[5.725, 0.4458, -0.2908]]);
+    Ok(())
+}