mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
@ -110,7 +110,7 @@ pub fn lstm(
|
|||||||
config.w_ih_init,
|
config.w_ih_init,
|
||||||
)?;
|
)?;
|
||||||
let w_hh = vb.get_with_hints(
|
let w_hh = vb.get_with_hints(
|
||||||
(4 * hidden_dim, in_dim),
|
(4 * hidden_dim, hidden_dim),
|
||||||
"weight_hh_l0", // Only a single layer is supported.
|
"weight_hh_l0", // Only a single layer is supported.
|
||||||
config.w_hh_init,
|
config.w_hh_init,
|
||||||
)?;
|
)?;
|
||||||
|
42
candle-nn/tests/rnn.rs
Normal file
42
candle-nn/tests/rnn.rs
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
|
||||||
|
use candle::{test_utils::to_vec2_round, DType, Device, Result, Tensor};
|
||||||
|
use candle_nn::RNN;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn lstm() -> Result<()> {
|
||||||
|
let cpu = &Device::Cpu;
|
||||||
|
let w_ih = Tensor::arange(0f32, 24f32, cpu)?.reshape((12, 2))?;
|
||||||
|
let w_ih = w_ih.cos()?;
|
||||||
|
let w_hh = Tensor::arange(0f32, 36f32, cpu)?.reshape((12, 3))?;
|
||||||
|
let w_hh = w_hh.sin()?;
|
||||||
|
let b_ih = Tensor::new(
|
||||||
|
&[-1f32, 1., -0.5, 2., -1., 1., -0.5, 2., -1., 1., -0.5, 2.],
|
||||||
|
cpu,
|
||||||
|
)?;
|
||||||
|
let b_hh = b_ih.cos()?;
|
||||||
|
let tensors: std::collections::HashMap<_, _> = [
|
||||||
|
("weight_ih_l0".to_string(), w_ih),
|
||||||
|
("weight_hh_l0".to_string(), w_hh),
|
||||||
|
("bias_ih_l0".to_string(), b_ih),
|
||||||
|
("bias_hh_l0".to_string(), b_hh),
|
||||||
|
]
|
||||||
|
.into_iter()
|
||||||
|
.collect();
|
||||||
|
let vb = candle_nn::VarBuilder::from_tensors(tensors, DType::F32, cpu);
|
||||||
|
let lstm = candle_nn::lstm(2, 3, Default::default(), vb)?;
|
||||||
|
let mut state = lstm.zero_state(1)?;
|
||||||
|
for inp in [3f32, 1., 4., 1., 5., 9., 2.] {
|
||||||
|
let inp = Tensor::new(&[[inp, inp * 0.5]], cpu)?;
|
||||||
|
state = lstm.step(&inp, &state)?
|
||||||
|
}
|
||||||
|
let h = state.h();
|
||||||
|
let c = state.c();
|
||||||
|
assert_eq!(to_vec2_round(h, 4)?, &[[0.9919, 0.1738, -0.1451]]);
|
||||||
|
assert_eq!(to_vec2_round(c, 4)?, &[[5.725, 0.4458, -0.2908]]);
|
||||||
|
Ok(())
|
||||||
|
}
|
Reference in New Issue
Block a user