Stable diffusion fix. (#1993)

* Stable diffusion fix.

* And add a comment.
This commit is contained in:
Laurent Mazare
2024-04-02 14:36:28 +02:00
committed by GitHub
parent be9c200cbb
commit b23436bf90

View File

@@ -533,7 +533,9 @@ impl Module for AttentionBlock {
let attention_scores = (query_states * scale)?.matmul(&(key_states.t()? * scale)?)?;
let attention_probs = nn::ops::softmax(&attention_scores, D::Minus1)?;
let xs = attention_probs.matmul(&value_states.contiguous()?)?;
// TODO: revert the call to force_contiguous once the three matmul kernels have been
// adapted to handle layout with some dims set to 1.
let xs = attention_probs.matmul(&value_states.force_contiguous()?)?;
let xs = xs.to_dtype(in_dtype)?;
let xs = xs.transpose(1, 2)?.contiguous()?;
let xs = xs.flatten_from(D::Minus2)?;