Configurable layer idx for the lstm layer. (#962)

This commit is contained in:
Laurent Mazare
2023-09-25 21:31:14 +01:00
committed by GitHub
parent c78a294323
commit c798184c2b

View File

@@ -57,6 +57,7 @@ pub struct LSTMConfig {
pub w_hh_init: super::Init,
pub b_ih_init: Option<super::Init>,
pub b_hh_init: Option<super::Init>,
pub layer_idx: usize,
}
impl Default for LSTMConfig {
@@ -66,6 +67,7 @@ impl Default for LSTMConfig {
w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM,
b_ih_init: Some(super::Init::Const(0.)),
b_hh_init: Some(super::Init::Const(0.)),
layer_idx: 0,
}
}
}
@@ -77,6 +79,7 @@ impl LSTMConfig {
w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM,
b_ih_init: None,
b_hh_init: None,
layer_idx: 0,
}
}
}
@@ -104,22 +107,27 @@ pub fn lstm(
config: LSTMConfig,
vb: crate::VarBuilder,
) -> Result<LSTM> {
let layer_idx = config.layer_idx;
let w_ih = vb.get_with_hints(
(4 * hidden_dim, in_dim),
"weight_ih_l0", // Only a single layer is supported.
&format!("weight_ih_l{layer_idx}"), // The name suffix comes from `config.layer_idx`.
config.w_ih_init,
)?;
let w_hh = vb.get_with_hints(
(4 * hidden_dim, hidden_dim),
"weight_hh_l0", // Only a single layer is supported.
&format!("weight_hh_l{layer_idx}"), // The name suffix comes from `config.layer_idx`.
config.w_hh_init,
)?;
let b_ih = match config.b_ih_init {
Some(init) => Some(vb.get_with_hints(4 * hidden_dim, "bias_ih_l0", init)?),
Some(init) => {
Some(vb.get_with_hints(4 * hidden_dim, &format!("bias_ih_l{layer_idx}"), init)?)
}
None => None,
};
let b_hh = match config.b_hh_init {
Some(init) => Some(vb.get_with_hints(4 * hidden_dim, "bias_hh_l0", init)?),
Some(init) => {
Some(vb.get_with_hints(4 * hidden_dim, &format!("bias_hh_l{layer_idx}"), init)?)
}
None => None,
};
Ok(LSTM {