mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Make the RNN configs accessible from the models. (#2541)
This commit is contained in:
@ -1,4 +1,3 @@
|
|||||||
#![allow(unused)]
|
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
#![allow(unused)]
|
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
|
@ -116,7 +116,7 @@ impl LSTMConfig {
|
|||||||
/// A Long Short-Term Memory (LSTM) layer.
|
/// A Long Short-Term Memory (LSTM) layer.
|
||||||
///
|
///
|
||||||
/// <https://en.wikipedia.org/wiki/Long_short-term_memory>
|
/// <https://en.wikipedia.org/wiki/Long_short-term_memory>
|
||||||
#[allow(clippy::upper_case_acronyms, unused)]
|
#[allow(clippy::upper_case_acronyms)]
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct LSTM {
|
pub struct LSTM {
|
||||||
w_ih: Tensor,
|
w_ih: Tensor,
|
||||||
@ -129,6 +129,62 @@ pub struct LSTM {
|
|||||||
dtype: DType,
|
dtype: DType,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl LSTM {
|
||||||
|
/// Creates a LSTM layer.
|
||||||
|
pub fn new(
|
||||||
|
in_dim: usize,
|
||||||
|
hidden_dim: usize,
|
||||||
|
config: LSTMConfig,
|
||||||
|
vb: crate::VarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let layer_idx = config.layer_idx;
|
||||||
|
let direction_str = match config.direction {
|
||||||
|
Direction::Forward => "",
|
||||||
|
Direction::Backward => "_reverse",
|
||||||
|
};
|
||||||
|
let w_ih = vb.get_with_hints(
|
||||||
|
(4 * hidden_dim, in_dim),
|
||||||
|
&format!("weight_ih_l{layer_idx}{direction_str}"), // Only a single layer is supported.
|
||||||
|
config.w_ih_init,
|
||||||
|
)?;
|
||||||
|
let w_hh = vb.get_with_hints(
|
||||||
|
(4 * hidden_dim, hidden_dim),
|
||||||
|
&format!("weight_hh_l{layer_idx}{direction_str}"), // Only a single layer is supported.
|
||||||
|
config.w_hh_init,
|
||||||
|
)?;
|
||||||
|
let b_ih = match config.b_ih_init {
|
||||||
|
Some(init) => Some(vb.get_with_hints(
|
||||||
|
4 * hidden_dim,
|
||||||
|
&format!("bias_ih_l{layer_idx}{direction_str}"),
|
||||||
|
init,
|
||||||
|
)?),
|
||||||
|
None => None,
|
||||||
|
};
|
||||||
|
let b_hh = match config.b_hh_init {
|
||||||
|
Some(init) => Some(vb.get_with_hints(
|
||||||
|
4 * hidden_dim,
|
||||||
|
&format!("bias_hh_l{layer_idx}{direction_str}"),
|
||||||
|
init,
|
||||||
|
)?),
|
||||||
|
None => None,
|
||||||
|
};
|
||||||
|
Ok(Self {
|
||||||
|
w_ih,
|
||||||
|
w_hh,
|
||||||
|
b_ih,
|
||||||
|
b_hh,
|
||||||
|
hidden_dim,
|
||||||
|
config,
|
||||||
|
device: vb.device().clone(),
|
||||||
|
dtype: vb.dtype(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn config(&self) -> &LSTMConfig {
|
||||||
|
&self.config
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Creates a LSTM layer.
|
/// Creates a LSTM layer.
|
||||||
pub fn lstm(
|
pub fn lstm(
|
||||||
in_dim: usize,
|
in_dim: usize,
|
||||||
@ -136,47 +192,7 @@ pub fn lstm(
|
|||||||
config: LSTMConfig,
|
config: LSTMConfig,
|
||||||
vb: crate::VarBuilder,
|
vb: crate::VarBuilder,
|
||||||
) -> Result<LSTM> {
|
) -> Result<LSTM> {
|
||||||
let layer_idx = config.layer_idx;
|
LSTM::new(in_dim, hidden_dim, config, vb)
|
||||||
let direction_str = match config.direction {
|
|
||||||
Direction::Forward => "",
|
|
||||||
Direction::Backward => "_reverse",
|
|
||||||
};
|
|
||||||
let w_ih = vb.get_with_hints(
|
|
||||||
(4 * hidden_dim, in_dim),
|
|
||||||
&format!("weight_ih_l{layer_idx}{direction_str}"), // Only a single layer is supported.
|
|
||||||
config.w_ih_init,
|
|
||||||
)?;
|
|
||||||
let w_hh = vb.get_with_hints(
|
|
||||||
(4 * hidden_dim, hidden_dim),
|
|
||||||
&format!("weight_hh_l{layer_idx}{direction_str}"), // Only a single layer is supported.
|
|
||||||
config.w_hh_init,
|
|
||||||
)?;
|
|
||||||
let b_ih = match config.b_ih_init {
|
|
||||||
Some(init) => Some(vb.get_with_hints(
|
|
||||||
4 * hidden_dim,
|
|
||||||
&format!("bias_ih_l{layer_idx}{direction_str}"),
|
|
||||||
init,
|
|
||||||
)?),
|
|
||||||
None => None,
|
|
||||||
};
|
|
||||||
let b_hh = match config.b_hh_init {
|
|
||||||
Some(init) => Some(vb.get_with_hints(
|
|
||||||
4 * hidden_dim,
|
|
||||||
&format!("bias_hh_l{layer_idx}{direction_str}"),
|
|
||||||
init,
|
|
||||||
)?),
|
|
||||||
None => None,
|
|
||||||
};
|
|
||||||
Ok(LSTM {
|
|
||||||
w_ih,
|
|
||||||
w_hh,
|
|
||||||
b_ih,
|
|
||||||
b_hh,
|
|
||||||
hidden_dim,
|
|
||||||
config,
|
|
||||||
device: vb.device().clone(),
|
|
||||||
dtype: vb.dtype(),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RNN for LSTM {
|
impl RNN for LSTM {
|
||||||
@ -270,7 +286,7 @@ impl GRUConfig {
|
|||||||
/// A Gated Recurrent Unit (GRU) layer.
|
/// A Gated Recurrent Unit (GRU) layer.
|
||||||
///
|
///
|
||||||
/// <https://en.wikipedia.org/wiki/Gated_recurrent_unit>
|
/// <https://en.wikipedia.org/wiki/Gated_recurrent_unit>
|
||||||
#[allow(clippy::upper_case_acronyms, unused)]
|
#[allow(clippy::upper_case_acronyms)]
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct GRU {
|
pub struct GRU {
|
||||||
w_ih: Tensor,
|
w_ih: Tensor,
|
||||||
@ -283,41 +299,56 @@ pub struct GRU {
|
|||||||
dtype: DType,
|
dtype: DType,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a GRU layer.
|
impl GRU {
|
||||||
|
/// Creates a GRU layer.
|
||||||
|
pub fn new(
|
||||||
|
in_dim: usize,
|
||||||
|
hidden_dim: usize,
|
||||||
|
config: GRUConfig,
|
||||||
|
vb: crate::VarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let w_ih = vb.get_with_hints(
|
||||||
|
(3 * hidden_dim, in_dim),
|
||||||
|
"weight_ih_l0", // Only a single layer is supported.
|
||||||
|
config.w_ih_init,
|
||||||
|
)?;
|
||||||
|
let w_hh = vb.get_with_hints(
|
||||||
|
(3 * hidden_dim, hidden_dim),
|
||||||
|
"weight_hh_l0", // Only a single layer is supported.
|
||||||
|
config.w_hh_init,
|
||||||
|
)?;
|
||||||
|
let b_ih = match config.b_ih_init {
|
||||||
|
Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_ih_l0", init)?),
|
||||||
|
None => None,
|
||||||
|
};
|
||||||
|
let b_hh = match config.b_hh_init {
|
||||||
|
Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_hh_l0", init)?),
|
||||||
|
None => None,
|
||||||
|
};
|
||||||
|
Ok(Self {
|
||||||
|
w_ih,
|
||||||
|
w_hh,
|
||||||
|
b_ih,
|
||||||
|
b_hh,
|
||||||
|
hidden_dim,
|
||||||
|
config,
|
||||||
|
device: vb.device().clone(),
|
||||||
|
dtype: vb.dtype(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn config(&self) -> &GRUConfig {
|
||||||
|
&self.config
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn gru(
|
pub fn gru(
|
||||||
in_dim: usize,
|
in_dim: usize,
|
||||||
hidden_dim: usize,
|
hidden_dim: usize,
|
||||||
config: GRUConfig,
|
config: GRUConfig,
|
||||||
vb: crate::VarBuilder,
|
vb: crate::VarBuilder,
|
||||||
) -> Result<GRU> {
|
) -> Result<GRU> {
|
||||||
let w_ih = vb.get_with_hints(
|
GRU::new(in_dim, hidden_dim, config, vb)
|
||||||
(3 * hidden_dim, in_dim),
|
|
||||||
"weight_ih_l0", // Only a single layer is supported.
|
|
||||||
config.w_ih_init,
|
|
||||||
)?;
|
|
||||||
let w_hh = vb.get_with_hints(
|
|
||||||
(3 * hidden_dim, hidden_dim),
|
|
||||||
"weight_hh_l0", // Only a single layer is supported.
|
|
||||||
config.w_hh_init,
|
|
||||||
)?;
|
|
||||||
let b_ih = match config.b_ih_init {
|
|
||||||
Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_ih_l0", init)?),
|
|
||||||
None => None,
|
|
||||||
};
|
|
||||||
let b_hh = match config.b_hh_init {
|
|
||||||
Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_hh_l0", init)?),
|
|
||||||
None => None,
|
|
||||||
};
|
|
||||||
Ok(GRU {
|
|
||||||
w_ih,
|
|
||||||
w_hh,
|
|
||||||
b_ih,
|
|
||||||
b_hh,
|
|
||||||
hidden_dim,
|
|
||||||
config,
|
|
||||||
device: vb.device().clone(),
|
|
||||||
dtype: vb.dtype(),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RNN for GRU {
|
impl RNN for GRU {
|
||||||
|
Reference in New Issue
Block a user