mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Configurable layer idx for the lstm layer. (#962)
This commit is contained in:
@ -57,6 +57,7 @@ pub struct LSTMConfig {
|
||||
pub w_hh_init: super::Init,
|
||||
pub b_ih_init: Option<super::Init>,
|
||||
pub b_hh_init: Option<super::Init>,
|
||||
pub layer_idx: usize,
|
||||
}
|
||||
|
||||
impl Default for LSTMConfig {
|
||||
@ -66,6 +67,7 @@ impl Default for LSTMConfig {
|
||||
w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM,
|
||||
b_ih_init: Some(super::Init::Const(0.)),
|
||||
b_hh_init: Some(super::Init::Const(0.)),
|
||||
layer_idx: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -77,6 +79,7 @@ impl LSTMConfig {
|
||||
w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM,
|
||||
b_ih_init: None,
|
||||
b_hh_init: None,
|
||||
layer_idx: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -104,22 +107,27 @@ pub fn lstm(
|
||||
config: LSTMConfig,
|
||||
vb: crate::VarBuilder,
|
||||
) -> Result<LSTM> {
|
||||
let layer_idx = config.layer_idx;
|
||||
let w_ih = vb.get_with_hints(
|
||||
(4 * hidden_dim, in_dim),
|
||||
"weight_ih_l0", // Only a single layer is supported.
|
||||
&format!("weight_ih_l{layer_idx}"), // Still a single layer; `layer_idx` only selects the variable name.
|
||||
config.w_ih_init,
|
||||
)?;
|
||||
let w_hh = vb.get_with_hints(
|
||||
(4 * hidden_dim, hidden_dim),
|
||||
"weight_hh_l0", // Only a single layer is supported.
|
||||
&format!("weight_hh_l{layer_idx}"), // Still a single layer; `layer_idx` only selects the variable name.
|
||||
config.w_hh_init,
|
||||
)?;
|
||||
let b_ih = match config.b_ih_init {
|
||||
Some(init) => Some(vb.get_with_hints(4 * hidden_dim, "bias_ih_l0", init)?),
|
||||
Some(init) => {
|
||||
Some(vb.get_with_hints(4 * hidden_dim, &format!("bias_ih_l{layer_idx}"), init)?)
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
let b_hh = match config.b_hh_init {
|
||||
Some(init) => Some(vb.get_with_hints(4 * hidden_dim, "bias_hh_l0", init)?),
|
||||
Some(init) => {
|
||||
Some(vb.get_with_hints(4 * hidden_dim, &format!("bias_hh_l{layer_idx}"), init)?)
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
Ok(LSTM {
|
||||
|
Reference in New Issue
Block a user