mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Fix the rnn tests for accelerate. (#704)
This commit is contained in:
@ -138,7 +138,8 @@ impl RNN for LSTM {
|
||||
type State = LSTMState;
|
||||
|
||||
fn zero_state(&self, batch_dim: usize) -> Result<Self::State> {
|
||||
let zeros = Tensor::zeros((batch_dim, self.hidden_dim), self.dtype, &self.device)?;
|
||||
let zeros =
|
||||
Tensor::zeros((batch_dim, self.hidden_dim), self.dtype, &self.device)?.contiguous()?;
|
||||
Ok(Self::State {
|
||||
h: zeros.clone(),
|
||||
c: zeros.clone(),
|
||||
@ -287,7 +288,8 @@ impl RNN for GRU {
|
||||
type State = GRUState;
|
||||
|
||||
fn zero_state(&self, batch_dim: usize) -> Result<Self::State> {
|
||||
let h = Tensor::zeros((batch_dim, self.hidden_dim), self.dtype, &self.device)?;
|
||||
let h =
|
||||
Tensor::zeros((batch_dim, self.hidden_dim), self.dtype, &self.device)?.contiguous()?;
|
||||
Ok(Self::State { h })
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user