mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Make the RNN configs accessible from the models. (#2541)
This commit is contained in:
@ -1,4 +1,3 @@
|
||||
#![allow(unused)]
|
||||
use anyhow::{Context, Result};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
|
@ -1,4 +1,3 @@
|
||||
#![allow(unused)]
|
||||
use anyhow::{Context, Result};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
|
@ -116,7 +116,7 @@ impl LSTMConfig {
|
||||
/// A Long Short-Term Memory (LSTM) layer.
|
||||
///
|
||||
/// <https://en.wikipedia.org/wiki/Long_short-term_memory>
|
||||
#[allow(clippy::upper_case_acronyms, unused)]
|
||||
#[allow(clippy::upper_case_acronyms)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct LSTM {
|
||||
w_ih: Tensor,
|
||||
@ -129,13 +129,14 @@ pub struct LSTM {
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl LSTM {
|
||||
/// Creates a LSTM layer.
|
||||
pub fn lstm(
|
||||
pub fn new(
|
||||
in_dim: usize,
|
||||
hidden_dim: usize,
|
||||
config: LSTMConfig,
|
||||
vb: crate::VarBuilder,
|
||||
) -> Result<LSTM> {
|
||||
) -> Result<Self> {
|
||||
let layer_idx = config.layer_idx;
|
||||
let direction_str = match config.direction {
|
||||
Direction::Forward => "",
|
||||
@ -167,7 +168,7 @@ pub fn lstm(
|
||||
)?),
|
||||
None => None,
|
||||
};
|
||||
Ok(LSTM {
|
||||
Ok(Self {
|
||||
w_ih,
|
||||
w_hh,
|
||||
b_ih,
|
||||
@ -179,6 +180,21 @@ pub fn lstm(
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns a reference to the `LSTMConfig` stored in this layer.
pub fn config(&self) -> &LSTMConfig {
|
||||
&self.config
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates an LSTM layer.
///
/// Free-function form of `LSTM::new`: all four arguments are forwarded
/// unchanged and its result is returned as-is.
|
||||
pub fn lstm(
|
||||
in_dim: usize,
|
||||
hidden_dim: usize,
|
||||
config: LSTMConfig,
|
||||
vb: crate::VarBuilder,
|
||||
) -> Result<LSTM> {
|
||||
LSTM::new(in_dim, hidden_dim, config, vb)
|
||||
}
|
||||
|
||||
impl RNN for LSTM {
|
||||
type State = LSTMState;
|
||||
|
||||
@ -270,7 +286,7 @@ impl GRUConfig {
|
||||
/// A Gated Recurrent Unit (GRU) layer.
|
||||
///
|
||||
/// <https://en.wikipedia.org/wiki/Gated_recurrent_unit>
|
||||
#[allow(clippy::upper_case_acronyms, unused)]
|
||||
#[allow(clippy::upper_case_acronyms)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct GRU {
|
||||
w_ih: Tensor,
|
||||
@ -283,13 +299,14 @@ pub struct GRU {
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl GRU {
|
||||
/// Creates a GRU layer.
|
||||
pub fn gru(
|
||||
pub fn new(
|
||||
in_dim: usize,
|
||||
hidden_dim: usize,
|
||||
config: GRUConfig,
|
||||
vb: crate::VarBuilder,
|
||||
) -> Result<GRU> {
|
||||
) -> Result<Self> {
|
||||
let w_ih = vb.get_with_hints(
|
||||
(3 * hidden_dim, in_dim),
|
||||
"weight_ih_l0", // Only a single layer is supported.
|
||||
@ -308,7 +325,7 @@ pub fn gru(
|
||||
Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_hh_l0", init)?),
|
||||
None => None,
|
||||
};
|
||||
Ok(GRU {
|
||||
Ok(Self {
|
||||
w_ih,
|
||||
w_hh,
|
||||
b_ih,
|
||||
@ -320,6 +337,20 @@ pub fn gru(
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns a reference to the `GRUConfig` stored in this layer.
pub fn config(&self) -> &GRUConfig {
|
||||
&self.config
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a GRU layer.
///
/// Free-function form of `GRU::new`: all four arguments are forwarded
/// unchanged and its result is returned as-is.
pub fn gru(
|
||||
in_dim: usize,
|
||||
hidden_dim: usize,
|
||||
config: GRUConfig,
|
||||
vb: crate::VarBuilder,
|
||||
) -> Result<GRU> {
|
||||
GRU::new(in_dim, hidden_dim, config, vb)
|
||||
}
|
||||
|
||||
impl RNN for GRU {
|
||||
type State = GRUState;
|
||||
|
||||
|
Reference in New Issue
Block a user