Mirror of https://github.com/huggingface/candle.git
Some CLIP fixes for stable diffusion. (#338)
* Some CLIP fixes for stable diffusion.
* Add the avg-pool2d operation on cpu.
@@ -103,7 +103,7 @@ impl ClipTextEmbeddings {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let token_embedding = self.token_embedding.forward(xs)?;
         let position_embedding = self.position_embedding.forward(&self.position_ids)?;
-        token_embedding + position_embedding
+        token_embedding.broadcast_add(&position_embedding)
     }
 }
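Why `+` fails here: in candle, the plain `+` operator on tensors requires matching shapes, while `broadcast_add` expands missing or singleton leading dimensions. The token embeddings carry a per-batch dimension that the shared position embeddings presumably lack. A minimal standalone sketch of the difference, with illustrative shapes that are assumptions rather than values taken from the model (the `candle_core` crate name is the published package):

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Hypothetical (batch, seq_len, hidden) token embeddings...
    let token_embedding = Tensor::zeros((2, 4, 8), DType::F32, &device)?;
    // ...and a single (seq_len, hidden) table of position embeddings.
    let position_embedding = Tensor::ones((4, 8), DType::F32, &device)?;
    // `broadcast_add` stretches (4, 8) to (2, 4, 8); a plain `+` on these
    // mismatched shapes would return a shape error instead.
    let embeddings = token_embedding.broadcast_add(&position_embedding)?;
    assert_eq!(embeddings.dims(), &[2, 4, 8]);
    Ok(())
}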
@@ -161,9 +161,9 @@ impl ClipAttention {
         let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;

         let src_len = key_states.dim(1)?;
-        let attn_weights =
-            (attn_weights.reshape((bsz, self.num_attention_heads, seq_len, src_len))?
-                + causal_attention_mask)?;
+        let attn_weights = attn_weights
+            .reshape((bsz, self.num_attention_heads, seq_len, src_len))?
+            .broadcast_add(causal_attention_mask)?;
         let attn_weights =
             attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?;
         let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
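The same pattern applies to the causal mask: the mask built by `build_causal_attention_mask` has no per-head dimension, so adding it to the `(bsz, num_attention_heads, seq_len, src_len)` attention weights needs broadcasting rather than an exact-shape `+`. A small sketch under assumed sizes (`bsz = 1`, values invented), mirroring the mask construction from the hunk below:

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let (bsz, num_heads, seq_len) = (1, 2, 3);
    // Per-head attention logits, all zeros for the demo.
    let attn_weights =
        Tensor::zeros((bsz, num_heads, seq_len, seq_len), DType::F32, &device)?;
    // Causal mask without a head dimension: 0.0 on and below the diagonal,
    // f32::MIN above it.
    let mask: Vec<f32> = (0..seq_len)
        .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::MIN } else { 0. }))
        .collect();
    let mask = Tensor::from_slice(&mask, (seq_len, seq_len), &device)?
        .broadcast_as((bsz, seq_len, seq_len))?;
    // Broadcasting right-aligns dimensions, so here the (1, 3, 3) mask
    // expands across the head dimension of the (1, 2, 3, 3) weights.
    let masked = attn_weights.broadcast_add(&mask)?;
    assert_eq!(masked.dims(), &[bsz, num_heads, seq_len, seq_len]);
    Ok(())
}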
@@ -287,7 +287,7 @@ impl ClipTextTransformer {
     // https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py#L678
     fn build_causal_attention_mask(bsz: usize, seq_len: usize, device: &Device) -> Result<Tensor> {
         let mask: Vec<_> = (0..seq_len)
-            .flat_map(|i| (0..seq_len).map(move |j| u8::from(j > i)))
+            .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::MIN } else { 0. }))
             .collect();
         let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
         mask.broadcast_as((bsz, seq_len, seq_len))
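The value change is the substance of this hunk: `u8::from(j > i)` produced a 0/1 integer mask, which both has the wrong dtype to be added to `f32` attention weights and the wrong semantics for an additive mask. Since the mask is added to the logits before softmax, unmasked positions need 0.0 and masked positions a very negative value such as `f32::MIN`, so that softmax sends their probability to (numerically) zero. A small sketch with invented values:

use candle_core::{D, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Three logits; we want to mask out the last position.
    let logits = Tensor::from_slice(&[1.0f32, 2.0, 3.0], (1, 3), &device)?;
    let mask = Tensor::from_slice(&[0.0f32, 0.0, f32::MIN], (1, 3), &device)?;
    // Additive mask: 0.0 leaves a logit unchanged, f32::MIN drowns it out.
    let masked = (logits + mask)?;
    let probs = candle_nn::ops::softmax(&masked, D::Minus1)?;
    // The third probability comes out as ~0; the first two renormalize
    // to sum to ~1.
    println!("{probs}");
    Ok(())
}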