diff --git a/candle-examples/examples/musicgen/musicgen_model.rs b/candle-examples/examples/musicgen/musicgen_model.rs
index d6d8ae15..dc8c3667 100644
--- a/candle-examples/examples/musicgen/musicgen_model.rs
+++ b/candle-examples/examples/musicgen/musicgen_model.rs
@@ -40,7 +40,7 @@ impl Default for Config {
             num_attention_heads: 16,
             layerdrop: 0.0,
             use_cache: true,
-            activation_function: Activation::Gelu, // TODO: Handle old style gelu.
+            activation_function: Activation::Gelu,
             hidden_size: 1024,
             dropout: 0.1,
             attention_dropout: 0.0,
@@ -66,7 +66,7 @@ impl Config {
             num_attention_heads: 16,
             layerdrop: 0.0,
             use_cache: true,
-            activation_function: Activation::Gelu, // TODO: Handle old style gelu.
+            activation_function: Activation::Gelu,
             hidden_size: 1024,
             dropout: 0.1,
             attention_dropout: 0.0,
diff --git a/candle-transformers/src/models/stable_diffusion/attention.rs b/candle-transformers/src/models/stable_diffusion/attention.rs
index b3ea91f9..07ce0fe4 100644
--- a/candle-transformers/src/models/stable_diffusion/attention.rs
+++ b/candle-transformers/src/models/stable_diffusion/attention.rs
@@ -527,10 +527,10 @@ impl Module for AttentionBlock {
            .transpose_for_scores(value_proj)?
            .to_dtype(DType::F32)?;
 
-        let scale = f64::powf((self.channels as f64) / (self.num_heads as f64), -0.25);
-        let attention_scores =
-            // TODO: Check that this needs two multiplication by `scale`.
-            (query_states * scale)?.matmul(&(key_states.t()? * scale)?)?;
+        // scale is applied twice, hence the -0.25 here rather than -0.5.
+        // https://github.com/huggingface/diffusers/blob/d3d22ce5a894becb951eec03e663951b28d45135/src/diffusers/models/attention.py#L87
+        let scale = f64::powf(self.channels as f64 / self.num_heads as f64, -0.25);
+        let attention_scores = (query_states * scale)?.matmul(&(key_states.t()? * scale)?)?;
         let attention_probs = nn::ops::softmax(&attention_scores, D::Minus1)?;
         let xs = attention_probs.matmul(&value_states.contiguous()?)?;
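
For context on the second hunk: the new comment notes that `scale` is multiplied into both the query and the key states, which is why the exponent is -0.25 rather than the usual -0.5; the two factors combine into the standard 1/sqrt(d) attention scaling. A minimal sketch in plain Rust (no candle dependency; `d`, `q`, and `k` are made-up toy values) illustrating that the two formulations agree:

// Multiplying both the query and the key by d^(-0.25) before the dot product
// is equivalent to dividing the raw score by sqrt(d).
fn main() {
    // Hypothetical per-head dimension and toy query/key vectors.
    let d = 64usize;
    let q: Vec<f64> = (0..d).map(|i| (i as f64).sin()).collect();
    let k: Vec<f64> = (0..d).map(|i| (i as f64).cos()).collect();

    let dot = |a: &[f64], b: &[f64]| a.iter().zip(b).map(|(x, y)| x * y).sum::<f64>();

    // Scale applied once to the raw score: q . k / sqrt(d).
    let scaled_once = dot(&q, &k) / (d as f64).sqrt();

    // Scale applied to each operand, as in the AttentionBlock code above.
    let scale = (d as f64).powf(-0.25);
    let q_scaled: Vec<f64> = q.iter().map(|x| x * scale).collect();
    let k_scaled: Vec<f64> = k.iter().map(|x| x * scale).collect();
    let scaled_twice = dot(&q_scaled, &k_scaled);

    // Identical up to floating-point rounding.
    assert!((scaled_once - scaled_twice).abs() < 1e-9);
    println!("scaled_once = {scaled_once}, scaled_twice = {scaled_twice}");
}

Scaling each operand before the matmul rather than scaling the product afterwards is a common way to keep intermediate magnitudes smaller; functionally, the result is the same standard attention scaling.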