F16 support for stable diffusion (#488)

* F16 support for stable diffusion.

* Keep the attention bits in F32.

* Keep more of the attention bits in F32.

* More mixed precision support.
Author: Laurent Mazare
Date: 2023-08-17 13:48:56 +01:00 (committed by GitHub)
Parent: c3176f0dfb
Commit: 5d99026fd2
6 changed files with 99 additions and 43 deletions


@@ -5,7 +5,7 @@
 //! pairs of images with related texts.
 //!
 //! https://github.com/openai/CLIP
-use candle::{Device, Result, Tensor, D};
+use candle::{DType, Device, Result, Tensor, D};
 use candle_nn as nn;
 
 #[derive(Debug, Clone, Copy)]
@@ -146,18 +146,22 @@ impl ClipAttention {
     }
 
     fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
+        let in_dtype = xs.dtype();
         let (bsz, seq_len, embed_dim) = xs.dims3()?;
         let query_states = (self.q_proj.forward(xs)? * self.scale)?;
         let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim);
         let query_states = self
             .shape(&query_states, seq_len, bsz)?
-            .reshape(proj_shape)?;
+            .reshape(proj_shape)?
+            .to_dtype(DType::F32)?;
         let key_states = self
             .shape(&self.k_proj.forward(xs)?, seq_len, bsz)?
-            .reshape(proj_shape)?;
+            .reshape(proj_shape)?
+            .to_dtype(DType::F32)?;
         let value_states = self
             .shape(&self.v_proj.forward(xs)?, seq_len, bsz)?
-            .reshape(proj_shape)?;
+            .reshape(proj_shape)?
+            .to_dtype(DType::F32)?;
         let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
 
         let src_len = key_states.dim(1)?;
@@ -168,7 +172,7 @@ impl ClipAttention {
             attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?;
         let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
 
-        let attn_output = attn_weights.matmul(&value_states)?;
+        let attn_output = attn_weights.matmul(&value_states)?.to_dtype(in_dtype)?;
         let attn_output = attn_output
             .reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))?
             .transpose(1, 2)?
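
The pattern these hunks introduce is: keep the model weights and activations in F16, promote the attention inputs to F32, compute the QK^T product and the softmax in F32, then cast the result back to the input dtype so the rest of the network stays in half precision. The snippet below is a minimal sketch of that pattern, not code from this repository: the attention_mixed_precision helper, the tensor shapes, and the scale value are made up for the example, and it only relies on candle operations already visible in the diff (to_dtype, matmul, transpose, candle_nn::ops::softmax).

use candle::{DType, Device, Result, Tensor, D};

// Hypothetical helper: q, k, v are (batch*heads, seq_len, head_dim) tensors,
// possibly in F16; the attention math itself runs in F32.
fn attention_mixed_precision(q: &Tensor, k: &Tensor, v: &Tensor, scale: f64) -> Result<Tensor> {
    let in_dtype = q.dtype(); // e.g. F16 when the weights are half precision
    // Promote to F32 before the matmul/softmax to limit precision loss.
    let q = (q.to_dtype(DType::F32)? * scale)?;
    let k = k.to_dtype(DType::F32)?;
    let v = v.to_dtype(DType::F32)?;
    let attn_weights = q.matmul(&k.transpose(1, 2)?)?;
    let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
    // Cast back so downstream layers keep running in the original dtype.
    attn_weights.matmul(&v)?.to_dtype(in_dtype)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Illustrative shapes: 2 batch*heads, seq_len 4, head_dim 8, cast to F16.
    let q = Tensor::randn(0f32, 1f32, (2, 4, 8), &dev)?.to_dtype(DType::F16)?;
    let k = Tensor::randn(0f32, 1f32, (2, 4, 8), &dev)?.to_dtype(DType::F16)?;
    let v = Tensor::randn(0f32, 1f32, (2, 4, 8), &dev)?.to_dtype(DType::F16)?;
    let out = attention_mixed_precision(&q, &k, &v, 1.0 / (8f64).sqrt())?;
    assert_eq!(out.dtype(), DType::F16); // output is back in the input dtype
    Ok(())
}

Keeping the attention logits and softmax in F32 avoids the overflow and rounding issues F16 is prone to at that step, while the casts at the boundaries preserve the memory and speed benefits of running the rest of the model in half precision.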