Make the r, k, v tensors contiguous. (#1719)

This commit is contained in:
Laurent Mazare
2024-02-16 09:17:35 +01:00
committed by GitHub
parent 7c7400fb63
commit 5ebcfeaf0f

View File

@ -165,9 +165,9 @@ impl SelfAttention {
let mut out: Vec<Tensor> = Vec::with_capacity(t);
for t_ in 0..t {
//
let rt = receptance.i((.., .., t_..t_ + 1))?;
let kt = key.i((.., .., .., t_..t_ + 1))?;
let vt = value.i((.., .., t_..t_ + 1))?;
let rt = receptance.i((.., .., t_..t_ + 1))?.contiguous()?;
let kt = key.i((.., .., .., t_..t_ + 1))?.contiguous()?;
let vt = value.i((.., .., t_..t_ + 1))?.contiguous()?;
let at = kt.matmul(&vt)?;
let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
let out_ = rt.matmul(&rhs)?.squeeze(2)?;