From b23436bf90b99eb17aed36aaa219875d3c962a7e Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Tue, 2 Apr 2024 14:36:28 +0200
Subject: [PATCH] Stable diffusion fix. (#1993)

* Stable diffusion fix.

* And add a comment.
---
 candle-transformers/src/models/stable_diffusion/attention.rs | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/candle-transformers/src/models/stable_diffusion/attention.rs b/candle-transformers/src/models/stable_diffusion/attention.rs
index 07ce0fe4..05e51e44 100644
--- a/candle-transformers/src/models/stable_diffusion/attention.rs
+++ b/candle-transformers/src/models/stable_diffusion/attention.rs
@@ -533,7 +533,9 @@ impl Module for AttentionBlock {
         let attention_scores = (query_states * scale)?.matmul(&(key_states.t()? * scale)?)?;
         let attention_probs = nn::ops::softmax(&attention_scores, D::Minus1)?;
 
-        let xs = attention_probs.matmul(&value_states.contiguous()?)?;
+        // TODO: revert the call to force_contiguous once the three matmul kernels have been
+        // adapted to handle layout with some dims set to 1.
+        let xs = attention_probs.matmul(&value_states.force_contiguous()?)?;
         let xs = xs.to_dtype(in_dtype)?;
         let xs = xs.transpose(1, 2)?.contiguous()?;
         let xs = xs.flatten_from(D::Minus2)?;
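
Note (not part of the patch): below is a minimal, standalone sketch of the call pattern the fix touches, assuming candle-core as a dependency; the tensor names and shapes are made up for illustration. The relevant distinction is that contiguous() can be a no-op when the layout already reports itself as contiguous, while force_contiguous() always copies into a fresh standard layout, which is what the TODO comment relies on until the matmul kernels handle layouts with some dims set to 1.

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;

    // Fake "attention" inputs with a dimension of size 1, i.e. the kind of
    // layout the TODO comment says the matmul kernels do not yet handle.
    let attention_probs = Tensor::randn(0f32, 1f32, (1, 1, 4, 4), &device)?;
    let value_states = Tensor::randn(0f32, 1f32, (1, 4, 1, 8), &device)?
        .transpose(1, 2)?; // (1, 1, 4, 8), no longer in a standard layout

    // Mirror the patched call site: force a fresh contiguous copy before the
    // batched matmul instead of relying on contiguous(), which may be a no-op.
    let xs = attention_probs.matmul(&value_states.force_contiguous()?)?;
    println!("{:?}", xs.shape()); // expected: (1, 1, 4, 8)
    Ok(())
}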