Some CLIP fixes for stable diffusion. (#338)

* Some CLIP fixes for stable diffusion.

* Add the avg-pool2d operation on cpu.
This commit is contained in:
Laurent Mazare
2023-08-07 19:31:45 +02:00
committed by GitHub
parent 2345b8ce3f
commit fc265d9dcf
7 changed files with 81 additions and 18 deletions

View File

@ -103,7 +103,7 @@ impl ClipTextEmbeddings {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let token_embedding = self.token_embedding.forward(xs)?;
let position_embedding = self.position_embedding.forward(&self.position_ids)?;
token_embedding + position_embedding
token_embedding.broadcast_add(&position_embedding)
}
}
@ -161,9 +161,9 @@ impl ClipAttention {
let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
let src_len = key_states.dim(1)?;
let attn_weights =
(attn_weights.reshape((bsz, self.num_attention_heads, seq_len, src_len))?
+ causal_attention_mask)?;
let attn_weights = attn_weights
.reshape((bsz, self.num_attention_heads, seq_len, src_len))?
.broadcast_add(causal_attention_mask)?;
let attn_weights =
attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?;
let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
@ -287,7 +287,7 @@ impl ClipTextTransformer {
// https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py#L678
fn build_causal_attention_mask(bsz: usize, seq_len: usize, device: &Device) -> Result<Tensor> {
let mask: Vec<_> = (0..seq_len)
.flat_map(|i| (0..seq_len).map(move |j| u8::from(j > i)))
.flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::MIN } else { 0. }))
.collect();
let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
mask.broadcast_as((bsz, seq_len, seq_len))