From 5ebcfeaf0f5af69bb2f74385e8d6b020d4a3b8df Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Fri, 16 Feb 2024 09:17:35 +0100
Subject: [PATCH] Make the r, k, v tensors contiguous. (#1719)

---
 candle-transformers/src/models/rwkv_v5.rs | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/candle-transformers/src/models/rwkv_v5.rs b/candle-transformers/src/models/rwkv_v5.rs
index 04dbfc45..d11cdedd 100644
--- a/candle-transformers/src/models/rwkv_v5.rs
+++ b/candle-transformers/src/models/rwkv_v5.rs
@@ -165,9 +165,9 @@ impl SelfAttention {
         let mut out: Vec<Tensor> = Vec::with_capacity(t);
         for t_ in 0..t {
             //
-            let rt = receptance.i((.., .., t_..t_ + 1))?;
-            let kt = key.i((.., .., .., t_..t_ + 1))?;
-            let vt = value.i((.., .., t_..t_ + 1))?;
+            let rt = receptance.i((.., .., t_..t_ + 1))?.contiguous()?;
+            let kt = key.i((.., .., .., t_..t_ + 1))?.contiguous()?;
+            let vt = value.i((.., .., t_..t_ + 1))?.contiguous()?;
             let at = kt.matmul(&vt)?;
             let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
             let out_ = rt.matmul(&rhs)?.squeeze(2)?;
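
Note on the fix: indexing an inner dimension with .i(...) returns a strided
view over the original buffer rather than a row-major copy, and candle's
matmul kernels expect contiguous inputs, so the per-timestep slices need an
explicit copy before the kt.matmul(&vt)? call. The snippet below is a minimal
standalone sketch of that behavior, not part of the patch; it assumes a crate
with candle-core as a dependency and uses a small made-up tensor shape purely
for illustration:

    use candle_core::{Device, IndexOp, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        // A (2, 4, 3) tensor filled with 0..24 in row-major order.
        let x = Tensor::arange(0f32, 24., &dev)?.reshape((2, 4, 3))?;
        // Narrowing the last dimension keeps the original strides, so the
        // resulting view is not laid out contiguously in memory.
        let view = x.i((.., .., 1..2))?;
        assert!(!view.is_contiguous());
        // .contiguous() copies the data into a fresh row-major buffer,
        // which is the layout the matmul kernels require.
        let copy = view.contiguous()?;
        assert!(copy.is_contiguous());
        Ok(())
    }

The copy costs an extra allocation per timestep, but it keeps the inner loop
on the fast contiguous-matmul path instead of erroring on strided inputs.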