F16 support for stable diffusion (#488)
* F16 support for stable diffusion.
* Keep the attention bits in F32.
* Keep more of the attention bits in F32.
* More mixed precision support.
@@ -5,7 +5,7 @@
 //! pairs of images with related texts.
 //!
 //! https://github.com/openai/CLIP
-use candle::{Device, Result, Tensor, D};
+use candle::{DType, Device, Result, Tensor, D};
 use candle_nn as nn;
 
 #[derive(Debug, Clone, Copy)]
@@ -146,18 +146,22 @@ impl ClipAttention {
     }
 
     fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
+        let in_dtype = xs.dtype();
         let (bsz, seq_len, embed_dim) = xs.dims3()?;
         let query_states = (self.q_proj.forward(xs)? * self.scale)?;
         let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim);
         let query_states = self
             .shape(&query_states, seq_len, bsz)?
-            .reshape(proj_shape)?;
+            .reshape(proj_shape)?
+            .to_dtype(DType::F32)?;
         let key_states = self
             .shape(&self.k_proj.forward(xs)?, seq_len, bsz)?
-            .reshape(proj_shape)?;
+            .reshape(proj_shape)?
+            .to_dtype(DType::F32)?;
         let value_states = self
             .shape(&self.v_proj.forward(xs)?, seq_len, bsz)?
-            .reshape(proj_shape)?;
+            .reshape(proj_shape)?
+            .to_dtype(DType::F32)?;
         let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
 
         let src_len = key_states.dim(1)?;
@@ -168,7 +172,7 @@ impl ClipAttention {
             attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?;
         let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
 
-        let attn_output = attn_weights.matmul(&value_states)?;
+        let attn_output = attn_weights.matmul(&value_states)?.to_dtype(in_dtype)?;
         let attn_output = attn_output
             .reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))?
             .transpose(1, 2)?
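Below is a minimal, self-contained sketch of the mixed-precision pattern the diff applies: project in the model dtype (e.g. F16), up-cast the query/key/value tensors to F32 for the matmul/softmax steps, and cast the attention output back to the input dtype. It only uses candle's public Tensor API (to_dtype, matmul, transpose, candle_nn::ops::softmax); the shapes and the helper name mixed_precision_attention are illustrative and not part of the commit.

use candle::{DType, Device, Result, Tensor, D};

/// Illustrative helper: computes softmax(q k^T) v in F32 and returns the
/// result in the dtype of the inputs (e.g. F16), mirroring the commit's
/// approach of keeping the attention bits in F32.
fn mixed_precision_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
    let in_dtype = q.dtype();
    // Up-cast so the matmul accumulation and the softmax run in F32
    // rather than half precision.
    let q = q.to_dtype(DType::F32)?;
    let k = k.to_dtype(DType::F32)?;
    let v = v.to_dtype(DType::F32)?;
    let attn_weights = q.matmul(&k.transpose(1, 2)?)?;
    let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
    // Cast the output back so downstream layers keep running in the
    // original (half precision) dtype.
    attn_weights.matmul(&v)?.to_dtype(in_dtype)
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Toy shapes: (batch * heads, seq_len, head_dim).
    let q = Tensor::randn(0f32, 1., (2, 4, 8), &device)?.to_dtype(DType::F16)?;
    let k = Tensor::randn(0f32, 1., (2, 4, 8), &device)?.to_dtype(DType::F16)?;
    let v = Tensor::randn(0f32, 1., (2, 4, 8), &device)?.to_dtype(DType::F16)?;
    let out = mixed_precision_attention(&q, &k, &v)?;
    assert_eq!(out.dtype(), DType::F16);
    println!("attention output shape: {:?}", out.shape());
    Ok(())
}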