Make tensor contiguous before the repeat_kv calls to avoid strided copies (#2953)

This commit is contained in:
Borek Požár
2025-05-14 10:47:28 +02:00
committed by GitHub
parent 485ddf2996
commit 6bd61727bc

View File

@@ -217,6 +217,10 @@ impl AttentionWeights {
}
let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;
// Make the k and v tensors contiguous before the repeat_kv calls to avoid some strided copies
let k = k.contiguous()?;
let v = v.contiguous()?;
let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;