mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Add a GRU layer. (#688)
* Add a GRU layer. * Fix the n gate computation.
This commit is contained in:
@ -25,7 +25,7 @@ pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
|
||||
pub use linear::{linear, linear_no_bias, Linear};
|
||||
pub use ops::Dropout;
|
||||
pub use optim::{AdamW, ParamsAdamW, SGD};
|
||||
pub use rnn::{lstm, LSTM, RNN};
|
||||
pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN};
|
||||
pub use var_builder::VarBuilder;
|
||||
pub use var_map::VarMap;
|
||||
|
||||
|
@ -184,3 +184,145 @@ impl RNN for LSTM {
|
||||
Ok((output, state))
|
||||
}
|
||||
}
|
||||
|
||||
/// The state for a GRU network, this contains a single tensor.
|
||||
#[allow(clippy::upper_case_acronyms)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GRUState {
|
||||
h: Tensor,
|
||||
}
|
||||
|
||||
impl GRUState {
|
||||
/// The hidden state vector, which is also the output of the LSTM.
|
||||
pub fn h(&self) -> &Tensor {
|
||||
&self.h
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::upper_case_acronyms)]
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct GRUConfig {
|
||||
pub w_ih_init: super::Init,
|
||||
pub w_hh_init: super::Init,
|
||||
pub b_ih_init: Option<super::Init>,
|
||||
pub b_hh_init: Option<super::Init>,
|
||||
}
|
||||
|
||||
impl Default for GRUConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
w_ih_init: super::init::DEFAULT_KAIMING_UNIFORM,
|
||||
w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM,
|
||||
b_ih_init: Some(super::Init::Const(0.)),
|
||||
b_hh_init: Some(super::Init::Const(0.)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl GRUConfig {
|
||||
pub fn default_no_bias() -> Self {
|
||||
Self {
|
||||
w_ih_init: super::init::DEFAULT_KAIMING_UNIFORM,
|
||||
w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM,
|
||||
b_ih_init: None,
|
||||
b_hh_init: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A Gated Recurrent Unit (GRU) layer.
|
||||
///
|
||||
/// <https://en.wikipedia.org/wiki/Gated_recurrent_unit>
|
||||
#[allow(clippy::upper_case_acronyms, unused)]
|
||||
#[derive(Debug)]
|
||||
pub struct GRU {
|
||||
w_ih: Tensor,
|
||||
w_hh: Tensor,
|
||||
b_ih: Option<Tensor>,
|
||||
b_hh: Option<Tensor>,
|
||||
hidden_dim: usize,
|
||||
config: GRUConfig,
|
||||
device: Device,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
/// Creates a GRU layer.
|
||||
pub fn gru(
|
||||
in_dim: usize,
|
||||
hidden_dim: usize,
|
||||
config: GRUConfig,
|
||||
vb: crate::VarBuilder,
|
||||
) -> Result<GRU> {
|
||||
let w_ih = vb.get_with_hints(
|
||||
(3 * hidden_dim, in_dim),
|
||||
"weight_ih_l0", // Only a single layer is supported.
|
||||
config.w_ih_init,
|
||||
)?;
|
||||
let w_hh = vb.get_with_hints(
|
||||
(3 * hidden_dim, hidden_dim),
|
||||
"weight_hh_l0", // Only a single layer is supported.
|
||||
config.w_hh_init,
|
||||
)?;
|
||||
let b_ih = match config.b_ih_init {
|
||||
Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_ih_l0", init)?),
|
||||
None => None,
|
||||
};
|
||||
let b_hh = match config.b_hh_init {
|
||||
Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_hh_l0", init)?),
|
||||
None => None,
|
||||
};
|
||||
Ok(GRU {
|
||||
w_ih,
|
||||
w_hh,
|
||||
b_ih,
|
||||
b_hh,
|
||||
hidden_dim,
|
||||
config,
|
||||
device: vb.device().clone(),
|
||||
dtype: vb.dtype(),
|
||||
})
|
||||
}
|
||||
|
||||
impl RNN for GRU {
|
||||
type State = GRUState;
|
||||
|
||||
fn zero_state(&self, batch_dim: usize) -> Result<Self::State> {
|
||||
let h = Tensor::zeros((batch_dim, self.hidden_dim), self.dtype, &self.device)?;
|
||||
Ok(Self::State { h })
|
||||
}
|
||||
|
||||
fn step(&self, input: &Tensor, in_state: &Self::State) -> Result<Self::State> {
|
||||
let w_ih = input.matmul(&self.w_ih.t()?)?;
|
||||
let w_hh = in_state.h.matmul(&self.w_hh.t()?)?;
|
||||
let w_ih = match &self.b_ih {
|
||||
None => w_ih,
|
||||
Some(b_ih) => w_ih.broadcast_add(b_ih)?,
|
||||
};
|
||||
let w_hh = match &self.b_hh {
|
||||
None => w_hh,
|
||||
Some(b_hh) => w_hh.broadcast_add(b_hh)?,
|
||||
};
|
||||
let chunks_ih = w_ih.chunk(3, 1)?;
|
||||
let chunks_hh = w_hh.chunk(3, 1)?;
|
||||
let r_gate = crate::ops::sigmoid(&(&chunks_ih[0] + &chunks_hh[0])?)?;
|
||||
let z_gate = crate::ops::sigmoid(&(&chunks_ih[1] + &chunks_hh[1])?)?;
|
||||
let n_gate = (&chunks_ih[2] + (r_gate * &chunks_hh[2])?)?.tanh();
|
||||
|
||||
let next_h = ((&z_gate * &in_state.h)? - ((&z_gate - 1.)? * n_gate)?)?;
|
||||
Ok(GRUState { h: next_h })
|
||||
}
|
||||
|
||||
/// The input should have dimensions [batch_size, seq_len, features].
|
||||
fn seq_init(&self, input: &Tensor, in_state: &Self::State) -> Result<(Tensor, Self::State)> {
|
||||
let (_b_size, seq_len, _features) = input.dims3()?;
|
||||
let mut state = in_state.clone();
|
||||
let mut output: Vec<Tensor> = Vec::with_capacity(seq_len);
|
||||
for seq_index in 0..seq_len {
|
||||
let input = input.i((.., seq_index, ..))?;
|
||||
state = self.step(&input, &state)?;
|
||||
output.push(state.h.clone());
|
||||
}
|
||||
let output = Tensor::cat(&output, 1)?;
|
||||
Ok((output, state))
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user