mirror of https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
F16 support for stable diffusion (#488)
* F16 support for stable diffusion.
* Keep the attention bits in F32.
* Keep more of the attention bits in F32.
* More mixed precision support.
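The change keeps activations in the model dtype (F16 when requested) but runs the numerically sensitive parts of attention, the score matmul and the softmax, in F32 before casting back. A minimal sketch of that pattern, written as a hypothetical free-standing helper (the name and signature are assumptions; the body mirrors the non-sliced branch of CrossAttention::attention changed below):

use candle::{DType, Result, Tensor, D};
use candle_nn as nn;

// Hypothetical helper (not part of the diff) showing the mixed precision pattern:
// compute the scaled scores, softmax and weighted sum in F32, then return to the
// caller's dtype.
fn attention_in_f32(query: &Tensor, key: &Tensor, value: &Tensor, scale: f64) -> Result<Tensor> {
    let in_dtype = query.dtype(); // e.g. F16 when the pipeline runs in half precision
    let query = query.to_dtype(DType::F32)?;
    let key = key.to_dtype(DType::F32)?;
    let value = value.to_dtype(DType::F32)?;
    let xs = query.matmul(&(key.t()? * scale)?)?; // attention scores in F32
    nn::ops::softmax(&xs, D::Minus1)?
        .matmul(&value)?
        .to_dtype(in_dtype) // cast the result back to the input dtype
}

In CrossAttention the cast back to the input dtype happens on the stacked slices or at the end of the softmax/matmul chain; in AttentionBlock the query/key/value states are upcast right after transpose_for_scores and the result is cast back before the output projection.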
@@ -1,5 +1,5 @@
 //! Attention Based Building Blocks
-use candle::{IndexOp, Result, Tensor, D};
+use candle::{DType, IndexOp, Result, Tensor, D};
 use candle_nn as nn;

 #[derive(Debug)]
@@ -147,6 +147,10 @@ impl CrossAttention {
     ) -> Result<Tensor> {
         let batch_size_attention = query.dim(0)?;
         let mut hidden_states = Vec::with_capacity(batch_size_attention / slice_size);
+        let in_dtype = query.dtype();
+        let query = query.to_dtype(DType::F32)?;
+        let key = key.to_dtype(DType::F32)?;
+        let value = value.to_dtype(DType::F32)?;

         for i in 0..batch_size_attention / slice_size {
             let start_idx = i * slice_size;
@@ -158,7 +162,7 @@ impl CrossAttention {
             let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(&value.i(start_idx..end_idx)?)?;
             hidden_states.push(xs)
         }
-        let hidden_states = Tensor::stack(&hidden_states, 0)?;
+        let hidden_states = Tensor::stack(&hidden_states, 0)?.to_dtype(in_dtype)?;
         self.reshape_batch_dim_to_heads(&hidden_states)
     }

@@ -183,8 +187,14 @@ impl CrossAttention {
                 .squeeze(0)?
                 .to_dtype(init_dtype)?
         } else {
+            let in_dtype = query.dtype();
+            let query = query.to_dtype(DType::F32)?;
+            let key = key.to_dtype(DType::F32)?;
+            let value = value.to_dtype(DType::F32)?;
             let xs = query.matmul(&(key.t()? * self.scale)?)?;
-            nn::ops::softmax(&xs, D::Minus1)?.matmul(value)?
+            nn::ops::softmax(&xs, D::Minus1)?
+                .matmul(&value)?
+                .to_dtype(in_dtype)?
         };
         self.reshape_batch_dim_to_heads(&xs)
     }
@@ -457,10 +467,15 @@ impl AttentionBlock {
         let num_heads = channels / num_head_channels;
         let group_norm =
             nn::group_norm(config.num_groups, channels, config.eps, vs.pp("group_norm"))?;
-        let query = nn::linear(channels, channels, vs.pp("query"))?;
-        let key = nn::linear(channels, channels, vs.pp("key"))?;
-        let value = nn::linear(channels, channels, vs.pp("value"))?;
-        let proj_attn = nn::linear(channels, channels, vs.pp("proj_attn"))?;
+        let (q_path, k_path, v_path, out_path) = if vs.dtype() == DType::F16 {
+            ("to_q", "to_k", "to_v", "to_out.0")
+        } else {
+            ("query", "key", "value", "proj_attn")
+        };
+        let query = nn::linear(channels, channels, vs.pp(q_path))?;
+        let key = nn::linear(channels, channels, vs.pp(k_path))?;
+        let value = nn::linear(channels, channels, vs.pp(v_path))?;
+        let proj_attn = nn::linear(channels, channels, vs.pp(out_path))?;
         let span = tracing::span!(tracing::Level::TRACE, "attn-block");
         Ok(Self {
             group_norm,
@@ -483,6 +498,7 @@ impl AttentionBlock {

     pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
+        let in_dtype = xs.dtype();
         let residual = xs;
         let (batch, channel, height, width) = xs.dims4()?;
         let xs = self
@@ -495,9 +511,13 @@ impl AttentionBlock {
         let key_proj = self.key.forward(&xs)?;
         let value_proj = self.value.forward(&xs)?;

-        let query_states = self.transpose_for_scores(query_proj)?;
-        let key_states = self.transpose_for_scores(key_proj)?;
-        let value_states = self.transpose_for_scores(value_proj)?;
+        let query_states = self
+            .transpose_for_scores(query_proj)?
+            .to_dtype(DType::F32)?;
+        let key_states = self.transpose_for_scores(key_proj)?.to_dtype(DType::F32)?;
+        let value_states = self
+            .transpose_for_scores(value_proj)?
+            .to_dtype(DType::F32)?;

         let scale = f64::powf((self.channels as f64) / (self.num_heads as f64), -0.25);
         let attention_scores =
@@ -506,6 +526,7 @@ impl AttentionBlock {
         let attention_probs = nn::ops::softmax(&attention_scores, D::Minus1)?;

         let xs = attention_probs.matmul(&value_states.contiguous()?)?;
+        let xs = xs.to_dtype(in_dtype)?;
         let xs = xs.transpose(1, 2)?.contiguous()?;
         let xs = xs.flatten_from(D::Minus2)?;
         let xs = self