Add/lstm direction (#2455)

* add: direction for lstm layer

* lint: remove unused Error import

* refactor: remove unnecessary int assignment to Direction enum:

* refactor: use &'static str type instead of String for direction_str:

* Run cargofmt.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
Justin Sing
2024-09-30 16:44:07 -04:00
committed by GitHub
parent 724650446c
commit aa35bf2ff5

View File

@ -70,6 +70,12 @@ impl LSTMState {
} }
} }
#[derive(Debug, Clone, Copy)]
pub enum Direction {
Forward,
Backward,
}
#[allow(clippy::upper_case_acronyms)] #[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub struct LSTMConfig { pub struct LSTMConfig {
@ -78,6 +84,7 @@ pub struct LSTMConfig {
pub b_ih_init: Option<super::Init>, pub b_ih_init: Option<super::Init>,
pub b_hh_init: Option<super::Init>, pub b_hh_init: Option<super::Init>,
pub layer_idx: usize, pub layer_idx: usize,
pub direction: Direction,
} }
impl Default for LSTMConfig { impl Default for LSTMConfig {
@ -88,6 +95,7 @@ impl Default for LSTMConfig {
b_ih_init: Some(super::Init::Const(0.)), b_ih_init: Some(super::Init::Const(0.)),
b_hh_init: Some(super::Init::Const(0.)), b_hh_init: Some(super::Init::Const(0.)),
layer_idx: 0, layer_idx: 0,
direction: Direction::Forward,
} }
} }
} }
@ -100,6 +108,7 @@ impl LSTMConfig {
b_ih_init: None, b_ih_init: None,
b_hh_init: None, b_hh_init: None,
layer_idx: 0, layer_idx: 0,
direction: Direction::Forward,
} }
} }
} }
@ -128,26 +137,34 @@ pub fn lstm(
vb: crate::VarBuilder, vb: crate::VarBuilder,
) -> Result<LSTM> { ) -> Result<LSTM> {
let layer_idx = config.layer_idx; let layer_idx = config.layer_idx;
let direction_str = match config.direction {
Direction::Forward => "",
Direction::Backward => "_reverse",
};
let w_ih = vb.get_with_hints( let w_ih = vb.get_with_hints(
(4 * hidden_dim, in_dim), (4 * hidden_dim, in_dim),
&format!("weight_ih_l{layer_idx}"), // Only a single layer is supported. &format!("weight_ih_l{layer_idx}{direction_str}"), // Only a single layer is supported.
config.w_ih_init, config.w_ih_init,
)?; )?;
let w_hh = vb.get_with_hints( let w_hh = vb.get_with_hints(
(4 * hidden_dim, hidden_dim), (4 * hidden_dim, hidden_dim),
&format!("weight_hh_l{layer_idx}"), // Only a single layer is supported. &format!("weight_hh_l{layer_idx}{direction_str}"), // Only a single layer is supported.
config.w_hh_init, config.w_hh_init,
)?; )?;
let b_ih = match config.b_ih_init { let b_ih = match config.b_ih_init {
Some(init) => { Some(init) => Some(vb.get_with_hints(
Some(vb.get_with_hints(4 * hidden_dim, &format!("bias_ih_l{layer_idx}"), init)?) 4 * hidden_dim,
} &format!("bias_ih_l{layer_idx}{direction_str}"),
init,
)?),
None => None, None => None,
}; };
let b_hh = match config.b_hh_init { let b_hh = match config.b_hh_init {
Some(init) => { Some(init) => Some(vb.get_with_hints(
Some(vb.get_with_hints(4 * hidden_dim, &format!("bias_hh_l{layer_idx}"), init)?) 4 * hidden_dim,
} &format!("bias_hh_l{layer_idx}{direction_str}"),
init,
)?),
None => None, None => None,
}; };
Ok(LSTM { Ok(LSTM {