mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Add/lstm direction (#2455)
* add: direction for lstm layer
* lint: remove unused Error import
* refactor: remove unnecessary int assignment to Direction enum
* refactor: use &'static str type instead of String for direction_str
* Run cargo fmt.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -70,6 +70,12 @@ impl LSTMState {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub enum Direction {
|
||||||
|
Forward,
|
||||||
|
Backward,
|
||||||
|
}
|
||||||
|
|
||||||
#[allow(clippy::upper_case_acronyms)]
|
#[allow(clippy::upper_case_acronyms)]
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
pub struct LSTMConfig {
|
pub struct LSTMConfig {
|
||||||
@ -78,6 +84,7 @@ pub struct LSTMConfig {
|
|||||||
pub b_ih_init: Option<super::Init>,
|
pub b_ih_init: Option<super::Init>,
|
||||||
pub b_hh_init: Option<super::Init>,
|
pub b_hh_init: Option<super::Init>,
|
||||||
pub layer_idx: usize,
|
pub layer_idx: usize,
|
||||||
|
pub direction: Direction,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for LSTMConfig {
|
impl Default for LSTMConfig {
|
||||||
@ -88,6 +95,7 @@ impl Default for LSTMConfig {
|
|||||||
b_ih_init: Some(super::Init::Const(0.)),
|
b_ih_init: Some(super::Init::Const(0.)),
|
||||||
b_hh_init: Some(super::Init::Const(0.)),
|
b_hh_init: Some(super::Init::Const(0.)),
|
||||||
layer_idx: 0,
|
layer_idx: 0,
|
||||||
|
direction: Direction::Forward,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -100,6 +108,7 @@ impl LSTMConfig {
|
|||||||
b_ih_init: None,
|
b_ih_init: None,
|
||||||
b_hh_init: None,
|
b_hh_init: None,
|
||||||
layer_idx: 0,
|
layer_idx: 0,
|
||||||
|
direction: Direction::Forward,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -128,26 +137,34 @@ pub fn lstm(
|
|||||||
vb: crate::VarBuilder,
|
vb: crate::VarBuilder,
|
||||||
) -> Result<LSTM> {
|
) -> Result<LSTM> {
|
||||||
let layer_idx = config.layer_idx;
|
let layer_idx = config.layer_idx;
|
||||||
|
let direction_str = match config.direction {
|
||||||
|
Direction::Forward => "",
|
||||||
|
Direction::Backward => "_reverse",
|
||||||
|
};
|
||||||
let w_ih = vb.get_with_hints(
|
let w_ih = vb.get_with_hints(
|
||||||
(4 * hidden_dim, in_dim),
|
(4 * hidden_dim, in_dim),
|
||||||
&format!("weight_ih_l{layer_idx}"), // Only a single layer is supported.
|
&format!("weight_ih_l{layer_idx}{direction_str}"), // Only a single layer is supported.
|
||||||
config.w_ih_init,
|
config.w_ih_init,
|
||||||
)?;
|
)?;
|
||||||
let w_hh = vb.get_with_hints(
|
let w_hh = vb.get_with_hints(
|
||||||
(4 * hidden_dim, hidden_dim),
|
(4 * hidden_dim, hidden_dim),
|
||||||
&format!("weight_hh_l{layer_idx}"), // Only a single layer is supported.
|
&format!("weight_hh_l{layer_idx}{direction_str}"), // Only a single layer is supported.
|
||||||
config.w_hh_init,
|
config.w_hh_init,
|
||||||
)?;
|
)?;
|
||||||
let b_ih = match config.b_ih_init {
|
let b_ih = match config.b_ih_init {
|
||||||
Some(init) => {
|
Some(init) => Some(vb.get_with_hints(
|
||||||
Some(vb.get_with_hints(4 * hidden_dim, &format!("bias_ih_l{layer_idx}"), init)?)
|
4 * hidden_dim,
|
||||||
}
|
&format!("bias_ih_l{layer_idx}{direction_str}"),
|
||||||
|
init,
|
||||||
|
)?),
|
||||||
None => None,
|
None => None,
|
||||||
};
|
};
|
||||||
let b_hh = match config.b_hh_init {
|
let b_hh = match config.b_hh_init {
|
||||||
Some(init) => {
|
Some(init) => Some(vb.get_with_hints(
|
||||||
Some(vb.get_with_hints(4 * hidden_dim, &format!("bias_hh_l{layer_idx}"), init)?)
|
4 * hidden_dim,
|
||||||
}
|
&format!("bias_hh_l{layer_idx}{direction_str}"),
|
||||||
|
init,
|
||||||
|
)?),
|
||||||
None => None,
|
None => None,
|
||||||
};
|
};
|
||||||
Ok(LSTM {
|
Ok(LSTM {
|
||||||
|
Reference in New Issue
Block a user