Fix the rnn tests for accelerate. (#704)

This commit is contained in:
Laurent Mazare
2023-09-01 14:21:38 +02:00
committed by GitHub
parent 7529531056
commit af552a5274

View File

@ -138,7 +138,8 @@ impl RNN for LSTM {
type State = LSTMState;
fn zero_state(&self, batch_dim: usize) -> Result<Self::State> {
let zeros = Tensor::zeros((batch_dim, self.hidden_dim), self.dtype, &self.device)?;
let zeros =
Tensor::zeros((batch_dim, self.hidden_dim), self.dtype, &self.device)?.contiguous()?;
Ok(Self::State {
h: zeros.clone(),
c: zeros.clone(),
@ -287,7 +288,8 @@ impl RNN for GRU {
type State = GRUState;
fn zero_state(&self, batch_dim: usize) -> Result<Self::State> {
let h = Tensor::zeros((batch_dim, self.hidden_dim), self.dtype, &self.device)?;
let h =
Tensor::zeros((batch_dim, self.hidden_dim), self.dtype, &self.device)?.contiguous()?;
Ok(Self::State { h })
}