From c798184c2bab5be6a02781aaa736ab1d666991ed Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 25 Sep 2023 21:31:14 +0100 Subject: [PATCH] Configurable layer idx for the lstm layer. (#962) --- candle-nn/src/rnn.rs | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/candle-nn/src/rnn.rs b/candle-nn/src/rnn.rs index 18a4a71c..10ba48f3 100644 --- a/candle-nn/src/rnn.rs +++ b/candle-nn/src/rnn.rs @@ -57,6 +57,7 @@ pub struct LSTMConfig { pub w_hh_init: super::Init, pub b_ih_init: Option, pub b_hh_init: Option, + pub layer_idx: usize, } impl Default for LSTMConfig { @@ -66,6 +67,7 @@ impl Default for LSTMConfig { w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM, b_ih_init: Some(super::Init::Const(0.)), b_hh_init: Some(super::Init::Const(0.)), + layer_idx: 0, } } } @@ -77,6 +79,7 @@ impl LSTMConfig { w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM, b_ih_init: None, b_hh_init: None, + layer_idx: 0, } } } @@ -104,22 +107,27 @@ pub fn lstm( config: LSTMConfig, vb: crate::VarBuilder, ) -> Result { + let layer_idx = config.layer_idx; let w_ih = vb.get_with_hints( (4 * hidden_dim, in_dim), - "weight_ih_l0", // Only a single layer is supported. + &format!("weight_ih_l{layer_idx}"), // Only a single layer is supported. config.w_ih_init, )?; let w_hh = vb.get_with_hints( (4 * hidden_dim, hidden_dim), - "weight_hh_l0", // Only a single layer is supported. + &format!("weight_hh_l{layer_idx}"), // Only a single layer is supported. config.w_hh_init, )?; let b_ih = match config.b_ih_init { - Some(init) => Some(vb.get_with_hints(4 * hidden_dim, "bias_ih_l0", init)?), + Some(init) => { + Some(vb.get_with_hints(4 * hidden_dim, &format!("bias_ih_l{layer_idx}"), init)?) + } None => None, }; let b_hh = match config.b_hh_init { - Some(init) => Some(vb.get_with_hints(4 * hidden_dim, "bias_hh_l0", init)?), + Some(init) => { + Some(vb.get_with_hints(4 * hidden_dim, &format!("bias_hh_l{layer_idx}"), init)?) + } None => None, }; Ok(LSTM {