Improve the handling of matmul with squeezed layouts. (#1998)

* Improve the handling of matmul with squeezed layouts.

* Fix for the CUDA backend.

* Revert the temporary fix.
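
For context, the "squeezed layouts" mentioned above are layouts with some dims set to 1 (per the TODO removed in the diff below) whose strides are no longer the canonical contiguous ones, so a strict contiguity check rejects tensors that a strided matmul kernel could still read. Below is a minimal sketch of how such a layout can arise; it assumes the candle-core crate with its public Tensor API (zeros, narrow, is_contiguous, dims), and the shapes are purely illustrative:

```rust
// Hypothetical illustration (not from this commit): a layout with a dim set
// to 1 but non-canonical strides, assuming the candle-core crate and the CPU
// device.
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Contiguous (2, 3, 4) tensor with strides (12, 4, 1).
    let t = Tensor::zeros((2, 3, 4), DType::F32, &dev)?;
    // Narrow the middle dim to length 1: the shape becomes (2, 1, 4) while the
    // strides stay (12, 4, 1). The canonical contiguous strides would be
    // (4, 4, 1), so a strict contiguity check rejects this view even though a
    // strided matmul kernel could step through it directly.
    let squeezed = t.narrow(1, 0, 1)?;
    println!("{:?} contiguous={}", squeezed.dims(), squeezed.is_contiguous());
    Ok(())
}
```
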
Author: Laurent Mazare
Date: 2024-04-02 23:17:05 +02:00
Committed by: GitHub
Parent: d17b2cdad9
Commit: 08c049def3
5 changed files with 151 additions and 139 deletions

@@ -535,7 +535,7 @@ impl Module for AttentionBlock {
-        // TODO: revert the call to force_contiguous once the three matmul kernels have been
-        // adapted to handle layout with some dims set to 1.
-        let xs = attention_probs.matmul(&value_states.force_contiguous()?)?;
+        let xs = attention_probs.matmul(&value_states)?;
         let xs = xs.to_dtype(in_dtype)?;
         let xs = xs.transpose(1, 2)?.contiguous()?;
         let xs = xs.flatten_from(D::Minus2)?;
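
A hypothetical usage-level sketch of the change, not taken from the repository: the shapes, variable names, and randn initialization are assumptions, but the pattern follows the call site above, where the right-hand side of the matmul arrives through a transpose and previously had to be passed through force_contiguous:

```rust
// Hypothetical sketch (not from the repository's tests): invented shapes and
// names that mirror the attention call site, assuming the candle-core crate.
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Attention probabilities: (batch = 1, heads = 8, seq = 16, seq = 16).
    let probs = Tensor::randn(0f32, 1f32, (1, 8, 16, 16), &dev)?;
    // Value states built as (1, seq, heads, head_dim) and transposed to
    // (1, heads, seq, head_dim); the resulting layout is non-contiguous and
    // includes a dim of size 1.
    let values = Tensor::randn(0f32, 1f32, (1, 16, 8, 32), &dev)?.transpose(1, 2)?;
    // The call site above previously wrapped the right-hand side in
    // force_contiguous() as a workaround; with the adapted kernels the matmul
    // is expected to accept such a layout directly.
    let out = probs.matmul(&values)?;
    println!("{:?}", out.dims()); // [1, 8, 16, 32]
    Ok(())
}
```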