Remove some todos. (#1042)

Author: Laurent Mazare (committed by GitHub)
Date: 2023-10-05 22:42:20 +01:00
Parent: 716883e9b0
Commit: 4631c48273
2 changed files with 6 additions and 6 deletions

File 1 of 2:

@@ -40,7 +40,7 @@ impl Default for Config {
 num_attention_heads: 16,
 layerdrop: 0.0,
 use_cache: true,
-activation_function: Activation::Gelu, // TODO: Handle old style gelu.
+activation_function: Activation::Gelu,
 hidden_size: 1024,
 dropout: 0.1,
 attention_dropout: 0.0,
@@ -66,7 +66,7 @@ impl Config {
 num_attention_heads: 16,
 layerdrop: 0.0,
 use_cache: true,
-activation_function: Activation::Gelu, // TODO: Handle old style gelu.
+activation_function: Activation::Gelu,
 hidden_size: 1024,
 dropout: 0.1,
 attention_dropout: 0.0,

File 2 of 2:

@@ -527,10 +527,10 @@ impl Module for AttentionBlock {
 .transpose_for_scores(value_proj)?
 .to_dtype(DType::F32)?;
-let scale = f64::powf((self.channels as f64) / (self.num_heads as f64), -0.25);
-let attention_scores =
-    // TODO: Check that this needs two multiplication by `scale`.
-    (query_states * scale)?.matmul(&(key_states.t()? * scale)?)?;
+// scale is applied twice, hence the -0.25 here rather than -0.5.
+// https://github.com/huggingface/diffusers/blob/d3d22ce5a894becb951eec03e663951b28d45135/src/diffusers/models/attention.py#L87
+let scale = f64::powf(self.channels as f64 / self.num_heads as f64, -0.25);
+let attention_scores = (query_states * scale)?.matmul(&(key_states.t()? * scale)?)?;
 let attention_probs = nn::ops::softmax(&attention_scores, D::Minus1)?;
 let xs = attention_probs.matmul(&value_states.contiguous()?)?;
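The new comment relies on a small algebraic identity: with d = channels / num_heads, scaling both the query and the key by d^(-0.25) multiplies every dot product in the matmul by d^(-0.5), i.e. it divides the raw attention scores by sqrt(d) exactly once. Below is a minimal standalone sketch of that equivalence, using plain f64 vectors with made-up values rather than candle tensors:

// Checks that applying `scale = d^(-1/4)` to both query and key equals a
// single division of the dot product by sqrt(d); the vectors are arbitrary.
fn dot(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

fn main() {
    let q = [0.3, -1.2, 0.7, 2.0];
    let k = [1.1, 0.4, -0.5, 0.9];
    let d = q.len() as f64; // stands in for channels / num_heads

    let scale = d.powf(-0.25);
    let q_scaled: Vec<f64> = q.iter().map(|x| x * scale).collect();
    let k_scaled: Vec<f64> = k.iter().map(|x| x * scale).collect();

    let twice = dot(&q_scaled, &k_scaled); // scale applied twice
    let once = dot(&q, &k) / d.sqrt();     // scaled once by 1/sqrt(d)

    assert!((twice - once).abs() < 1e-12);
    println!("scaled twice: {twice}, divided once: {once}");
}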