mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 20:22:49 +00:00
Remove some todos. (#1042)
This commit is contained in:
@ -527,10 +527,10 @@ impl Module for AttentionBlock {
|
||||
.transpose_for_scores(value_proj)?
|
||||
.to_dtype(DType::F32)?;
|
||||
|
||||
let scale = f64::powf((self.channels as f64) / (self.num_heads as f64), -0.25);
|
||||
let attention_scores =
|
||||
// TODO: Check that this needs two multiplication by `scale`.
|
||||
(query_states * scale)?.matmul(&(key_states.t()? * scale)?)?;
|
||||
// scale is applied twice, hence the -0.25 here rather than -0.5.
|
||||
// https://github.com/huggingface/diffusers/blob/d3d22ce5a894becb951eec03e663951b28d45135/src/diffusers/models/attention.py#L87
|
||||
let scale = f64::powf(self.channels as f64 / self.num_heads as f64, -0.25);
|
||||
let attention_scores = (query_states * scale)?.matmul(&(key_states.t()? * scale)?)?;
|
||||
let attention_probs = nn::ops::softmax(&attention_scores, D::Minus1)?;
|
||||
|
||||
let xs = attention_probs.matmul(&value_states.contiguous()?)?;
|
||||
|
Reference in New Issue
Block a user