Make the RNN configs accessible from the models. (#2541)

This commit is contained in:
Laurent Mazare
2024-10-04 14:22:23 +02:00
committed by GitHub
parent 6faecaa616
commit 56aacb05da
3 changed files with 103 additions and 74 deletions

View File

@ -1,4 +1,3 @@
#![allow(unused)]
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};

View File

@ -1,4 +1,3 @@
#![allow(unused)]
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};

View File

@ -116,7 +116,7 @@ impl LSTMConfig {
/// A Long Short-Term Memory (LSTM) layer. /// A Long Short-Term Memory (LSTM) layer.
/// ///
/// <https://en.wikipedia.org/wiki/Long_short-term_memory> /// <https://en.wikipedia.org/wiki/Long_short-term_memory>
#[allow(clippy::upper_case_acronyms, unused)] #[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct LSTM { pub struct LSTM {
w_ih: Tensor, w_ih: Tensor,
@ -129,13 +129,14 @@ pub struct LSTM {
dtype: DType, dtype: DType,
} }
/// Creates a LSTM layer. impl LSTM {
pub fn lstm( /// Creates a LSTM layer.
pub fn new(
in_dim: usize, in_dim: usize,
hidden_dim: usize, hidden_dim: usize,
config: LSTMConfig, config: LSTMConfig,
vb: crate::VarBuilder, vb: crate::VarBuilder,
) -> Result<LSTM> { ) -> Result<Self> {
let layer_idx = config.layer_idx; let layer_idx = config.layer_idx;
let direction_str = match config.direction { let direction_str = match config.direction {
Direction::Forward => "", Direction::Forward => "",
@ -167,7 +168,7 @@ pub fn lstm(
)?), )?),
None => None, None => None,
}; };
Ok(LSTM { Ok(Self {
w_ih, w_ih,
w_hh, w_hh,
b_ih, b_ih,
@ -177,6 +178,21 @@ pub fn lstm(
device: vb.device().clone(), device: vb.device().clone(),
dtype: vb.dtype(), dtype: vb.dtype(),
}) })
}
pub fn config(&self) -> &LSTMConfig {
&self.config
}
}
/// Creates a LSTM layer.
pub fn lstm(
in_dim: usize,
hidden_dim: usize,
config: LSTMConfig,
vb: crate::VarBuilder,
) -> Result<LSTM> {
LSTM::new(in_dim, hidden_dim, config, vb)
} }
impl RNN for LSTM { impl RNN for LSTM {
@ -270,7 +286,7 @@ impl GRUConfig {
/// A Gated Recurrent Unit (GRU) layer. /// A Gated Recurrent Unit (GRU) layer.
/// ///
/// <https://en.wikipedia.org/wiki/Gated_recurrent_unit> /// <https://en.wikipedia.org/wiki/Gated_recurrent_unit>
#[allow(clippy::upper_case_acronyms, unused)] #[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct GRU { pub struct GRU {
w_ih: Tensor, w_ih: Tensor,
@ -283,13 +299,14 @@ pub struct GRU {
dtype: DType, dtype: DType,
} }
/// Creates a GRU layer. impl GRU {
pub fn gru( /// Creates a GRU layer.
pub fn new(
in_dim: usize, in_dim: usize,
hidden_dim: usize, hidden_dim: usize,
config: GRUConfig, config: GRUConfig,
vb: crate::VarBuilder, vb: crate::VarBuilder,
) -> Result<GRU> { ) -> Result<Self> {
let w_ih = vb.get_with_hints( let w_ih = vb.get_with_hints(
(3 * hidden_dim, in_dim), (3 * hidden_dim, in_dim),
"weight_ih_l0", // Only a single layer is supported. "weight_ih_l0", // Only a single layer is supported.
@ -308,7 +325,7 @@ pub fn gru(
Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_hh_l0", init)?), Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_hh_l0", init)?),
None => None, None => None,
}; };
Ok(GRU { Ok(Self {
w_ih, w_ih,
w_hh, w_hh,
b_ih, b_ih,
@ -318,6 +335,20 @@ pub fn gru(
device: vb.device().clone(), device: vb.device().clone(),
dtype: vb.dtype(), dtype: vb.dtype(),
}) })
}
pub fn config(&self) -> &GRUConfig {
&self.config
}
}
pub fn gru(
in_dim: usize,
hidden_dim: usize,
config: GRUConfig,
vb: crate::VarBuilder,
) -> Result<GRU> {
GRU::new(in_dim, hidden_dim, config, vb)
} }
impl RNN for GRU { impl RNN for GRU {