Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 10:38:54 +00:00)
More general seq forward functions for RNNs. (#1050)
This commit is contained in:
@ -4,7 +4,7 @@ use candle::{DType, Device, IndexOp, Result, Tensor};
|
|||||||
/// Trait for Recurrent Neural Networks.
|
/// Trait for Recurrent Neural Networks.
|
||||||
#[allow(clippy::upper_case_acronyms)]
|
#[allow(clippy::upper_case_acronyms)]
|
||||||
pub trait RNN {
|
pub trait RNN {
|
||||||
type State;
|
type State: Clone;
|
||||||
|
|
||||||
/// A zero state from which the recurrent network is usually initialized.
|
/// A zero state from which the recurrent network is usually initialized.
|
||||||
fn zero_state(&self, batch_dim: usize) -> Result<Self::State>;
|
fn zero_state(&self, batch_dim: usize) -> Result<Self::State>;
|
||||||
@ -18,7 +18,7 @@ pub trait RNN {
|
|||||||
///
|
///
|
||||||
/// The input should have dimensions [batch_size, seq_len, features].
|
/// The input should have dimensions [batch_size, seq_len, features].
|
||||||
/// The initial state is the result of applying zero_state.
|
/// The initial state is the result of applying zero_state.
|
||||||
fn seq(&self, input: &Tensor) -> Result<(Tensor, Self::State)> {
|
fn seq(&self, input: &Tensor) -> Result<Vec<Self::State>> {
|
||||||
let batch_dim = input.dim(0)?;
|
let batch_dim = input.dim(0)?;
|
||||||
let state = self.zero_state(batch_dim)?;
|
let state = self.zero_state(batch_dim)?;
|
||||||
self.seq_init(input, &state)
|
self.seq_init(input, &state)
|
||||||
@ -27,7 +27,23 @@ pub trait RNN {
|
|||||||
/// Applies multiple steps of the recurrent network.
|
/// Applies multiple steps of the recurrent network.
|
||||||
///
|
///
|
||||||
/// The input should have dimensions [batch_size, seq_len, features].
|
/// The input should have dimensions [batch_size, seq_len, features].
|
||||||
fn seq_init(&self, input: &Tensor, state: &Self::State) -> Result<(Tensor, Self::State)>;
|
fn seq_init(&self, input: &Tensor, init_state: &Self::State) -> Result<Vec<Self::State>> {
|
||||||
|
let (_b_size, seq_len, _features) = input.dims3()?;
|
||||||
|
let mut output = Vec::with_capacity(seq_len);
|
||||||
|
for seq_index in 0..seq_len {
|
||||||
|
let input = input.i((.., seq_index, ..))?;
|
||||||
|
let state = if seq_index == 0 {
|
||||||
|
self.step(&input, init_state)?
|
||||||
|
} else {
|
||||||
|
self.step(&input, &output[seq_index - 1])?
|
||||||
|
};
|
||||||
|
output.push(state);
|
||||||
|
}
|
||||||
|
Ok(output)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Converts a sequence of states to a tensor.
|
||||||
|
fn states_to_tensor(&self, states: &[Self::State]) -> Result<Tensor>;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The state for an LSTM network; this contains two tensors.
|
/// The state for an LSTM network; this contains two tensors.
|
||||||
@ -179,18 +195,9 @@ impl RNN for LSTM {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The input should have dimensions [batch_size, seq_len, features].
|
fn states_to_tensor(&self, states: &[Self::State]) -> Result<Tensor> {
|
||||||
fn seq_init(&self, input: &Tensor, in_state: &Self::State) -> Result<(Tensor, Self::State)> {
|
let states = states.iter().map(|s| s.h.clone()).collect::<Vec<_>>();
|
||||||
let (_b_size, seq_len, _features) = input.dims3()?;
|
Tensor::cat(&states, 1)
|
||||||
let mut state = in_state.clone();
|
|
||||||
let mut output: Vec<Tensor> = Vec::with_capacity(seq_len);
|
|
||||||
for seq_index in 0..seq_len {
|
|
||||||
let input = input.i((.., seq_index, ..))?;
|
|
||||||
state = self.step(&input, &state)?;
|
|
||||||
output.push(state.h.clone());
|
|
||||||
}
|
|
||||||
let output = Tensor::cat(&output, 1)?;
|
|
||||||
Ok((output, state))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -322,17 +329,8 @@ impl RNN for GRU {
|
|||||||
Ok(GRUState { h: next_h })
|
Ok(GRUState { h: next_h })
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The input should have dimensions [batch_size, seq_len, features].
|
fn states_to_tensor(&self, states: &[Self::State]) -> Result<Tensor> {
|
||||||
fn seq_init(&self, input: &Tensor, in_state: &Self::State) -> Result<(Tensor, Self::State)> {
|
let states = states.iter().map(|s| s.h.clone()).collect::<Vec<_>>();
|
||||||
let (_b_size, seq_len, _features) = input.dims3()?;
|
Tensor::cat(&states, 1)
|
||||||
let mut state = in_state.clone();
|
|
||||||
let mut output: Vec<Tensor> = Vec::with_capacity(seq_len);
|
|
||||||
for seq_index in 0..seq_len {
|
|
||||||
let input = input.i((.., seq_index, ..))?;
|
|
||||||
state = self.step(&input, &state)?;
|
|
||||||
output.push(state.h.clone());
|
|
||||||
}
|
|
||||||
let output = Tensor::cat(&output, 1)?;
|
|
||||||
Ok((output, state))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user