F16 support for stable diffusion (#488)

* F16 support for stable diffusion.

* Keep the attention bits in F32.

* Keep more of the attention bits in F32.

* More mixed precision support.
This commit is contained in:
Laurent Mazare
2023-08-17 13:48:56 +01:00
committed by GitHub
parent c3176f0dfb
commit 5d99026fd2
6 changed files with 99 additions and 43 deletions

View File

@ -1,5 +1,5 @@
//! Attention Based Building Blocks
use candle::{IndexOp, Result, Tensor, D};
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn as nn;
#[derive(Debug)]
@ -147,6 +147,10 @@ impl CrossAttention {
) -> Result<Tensor> {
let batch_size_attention = query.dim(0)?;
let mut hidden_states = Vec::with_capacity(batch_size_attention / slice_size);
let in_dtype = query.dtype();
let query = query.to_dtype(DType::F32)?;
let key = key.to_dtype(DType::F32)?;
let value = value.to_dtype(DType::F32)?;
for i in 0..batch_size_attention / slice_size {
let start_idx = i * slice_size;
@ -158,7 +162,7 @@ impl CrossAttention {
let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(&value.i(start_idx..end_idx)?)?;
hidden_states.push(xs)
}
let hidden_states = Tensor::stack(&hidden_states, 0)?;
let hidden_states = Tensor::stack(&hidden_states, 0)?.to_dtype(in_dtype)?;
self.reshape_batch_dim_to_heads(&hidden_states)
}
@ -183,8 +187,14 @@ impl CrossAttention {
.squeeze(0)?
.to_dtype(init_dtype)?
} else {
let in_dtype = query.dtype();
let query = query.to_dtype(DType::F32)?;
let key = key.to_dtype(DType::F32)?;
let value = value.to_dtype(DType::F32)?;
let xs = query.matmul(&(key.t()? * self.scale)?)?;
nn::ops::softmax(&xs, D::Minus1)?.matmul(value)?
nn::ops::softmax(&xs, D::Minus1)?
.matmul(&value)?
.to_dtype(in_dtype)?
};
self.reshape_batch_dim_to_heads(&xs)
}
@ -457,10 +467,15 @@ impl AttentionBlock {
let num_heads = channels / num_head_channels;
let group_norm =
nn::group_norm(config.num_groups, channels, config.eps, vs.pp("group_norm"))?;
let query = nn::linear(channels, channels, vs.pp("query"))?;
let key = nn::linear(channels, channels, vs.pp("key"))?;
let value = nn::linear(channels, channels, vs.pp("value"))?;
let proj_attn = nn::linear(channels, channels, vs.pp("proj_attn"))?;
let (q_path, k_path, v_path, out_path) = if vs.dtype() == DType::F16 {
("to_q", "to_k", "to_v", "to_out.0")
} else {
("query", "key", "value", "proj_attn")
};
let query = nn::linear(channels, channels, vs.pp(q_path))?;
let key = nn::linear(channels, channels, vs.pp(k_path))?;
let value = nn::linear(channels, channels, vs.pp(v_path))?;
let proj_attn = nn::linear(channels, channels, vs.pp(out_path))?;
let span = tracing::span!(tracing::Level::TRACE, "attn-block");
Ok(Self {
group_norm,
@ -483,6 +498,7 @@ impl AttentionBlock {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let in_dtype = xs.dtype();
let residual = xs;
let (batch, channel, height, width) = xs.dims4()?;
let xs = self
@ -495,9 +511,13 @@ impl AttentionBlock {
let key_proj = self.key.forward(&xs)?;
let value_proj = self.value.forward(&xs)?;
let query_states = self.transpose_for_scores(query_proj)?;
let key_states = self.transpose_for_scores(key_proj)?;
let value_states = self.transpose_for_scores(value_proj)?;
let query_states = self
.transpose_for_scores(query_proj)?
.to_dtype(DType::F32)?;
let key_states = self.transpose_for_scores(key_proj)?.to_dtype(DType::F32)?;
let value_states = self
.transpose_for_scores(value_proj)?
.to_dtype(DType::F32)?;
let scale = f64::powf((self.channels as f64) / (self.num_heads as f64), -0.25);
let attention_scores =
@ -506,6 +526,7 @@ impl AttentionBlock {
let attention_probs = nn::ops::softmax(&attention_scores, D::Minus1)?;
let xs = attention_probs.matmul(&value_states.contiguous()?)?;
let xs = xs.to_dtype(in_dtype)?;
let xs = xs.transpose(1, 2)?.contiguous()?;
let xs = xs.flatten_from(D::Minus2)?;
let xs = self

View File

@ -5,7 +5,7 @@
//! pairs of images with related texts.
//!
//! https://github.com/openai/CLIP
use candle::{Device, Result, Tensor, D};
use candle::{DType, Device, Result, Tensor, D};
use candle_nn as nn;
#[derive(Debug, Clone, Copy)]
@ -146,18 +146,22 @@ impl ClipAttention {
}
fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
let in_dtype = xs.dtype();
let (bsz, seq_len, embed_dim) = xs.dims3()?;
let query_states = (self.q_proj.forward(xs)? * self.scale)?;
let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim);
let query_states = self
.shape(&query_states, seq_len, bsz)?
.reshape(proj_shape)?;
.reshape(proj_shape)?
.to_dtype(DType::F32)?;
let key_states = self
.shape(&self.k_proj.forward(xs)?, seq_len, bsz)?
.reshape(proj_shape)?;
.reshape(proj_shape)?
.to_dtype(DType::F32)?;
let value_states = self
.shape(&self.v_proj.forward(xs)?, seq_len, bsz)?
.reshape(proj_shape)?;
.reshape(proj_shape)?
.to_dtype(DType::F32)?;
let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
let src_len = key_states.dim(1)?;
@ -168,7 +172,7 @@ impl ClipAttention {
attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?;
let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
let attn_output = attn_weights.matmul(&value_states)?;
let attn_output = attn_weights.matmul(&value_states)?.to_dtype(in_dtype)?;
let attn_output = attn_output
.reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))?
.transpose(1, 2)?

View File

@ -44,10 +44,10 @@ impl Timesteps {
impl Timesteps {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let half_dim = (self.num_channels / 2) as u32;
let exponent =
(Tensor::arange(0, half_dim, xs.device())?.to_dtype(xs.dtype())? * -f64::ln(10000.))?;
let exponent = (Tensor::arange(0, half_dim, xs.device())?.to_dtype(candle::DType::F32)?
* -f64::ln(10000.))?;
let exponent = (exponent / (half_dim as f64 - self.downscale_freq_shift))?;
let emb = exponent.exp()?;
let emb = exponent.exp()?.to_dtype(xs.dtype())?;
// emb = timesteps[:, None].float() * emb[None, :]
let emb = xs.unsqueeze(D::Minus1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
let (cos, sin) = (emb.cos()?, emb.sin()?);

View File

@ -93,6 +93,9 @@ struct Args {
#[arg(long)]
use_flash_attn: bool,
#[arg(long)]
use_f16: bool,
}
#[derive(Debug, Clone, Copy, clap::ValueEnum)]
@ -117,21 +120,39 @@ impl StableDiffusionVersion {
}
}
fn unet_file(&self) -> &'static str {
fn unet_file(&self, use_f16: bool) -> &'static str {
match self {
Self::V1_5 | Self::V2_1 => "unet/diffusion_pytorch_model.safetensors",
Self::V1_5 | Self::V2_1 => {
if use_f16 {
"unet/diffusion_pytorch_model.fp16.safetensors"
} else {
"unet/diffusion_pytorch_model.safetensors"
}
}
}
}
fn vae_file(&self) -> &'static str {
fn vae_file(&self, use_f16: bool) -> &'static str {
match self {
Self::V1_5 | Self::V2_1 => "vae/diffusion_pytorch_model.safetensors",
Self::V1_5 | Self::V2_1 => {
if use_f16 {
"vae/diffusion_pytorch_model.fp16.safetensors"
} else {
"vae/diffusion_pytorch_model.safetensors"
}
}
}
}
fn clip_file(&self) -> &'static str {
fn clip_file(&self, use_f16: bool) -> &'static str {
match self {
Self::V1_5 | Self::V2_1 => "text_encoder/model.safetensors",
Self::V1_5 | Self::V2_1 => {
if use_f16 {
"text_encoder/model.fp16.safetensors"
} else {
"text_encoder/model.safetensors"
}
}
}
}
}
@ -144,6 +165,7 @@ impl ModelFile {
&self,
filename: Option<String>,
version: StableDiffusionVersion,
use_f16: bool,
) -> Result<std::path::PathBuf> {
use hf_hub::api::sync::Api;
match filename {
@ -151,9 +173,9 @@ impl ModelFile {
None => {
let (repo, path) = match self {
Self::Tokenizer => (Self::TOKENIZER_REPO, Self::TOKENIZER_PATH),
Self::Clip => (version.repo(), version.clip_file()),
Self::Unet => (version.repo(), version.unet_file()),
Self::Vae => (version.repo(), version.vae_file()),
Self::Clip => (version.repo(), version.clip_file(use_f16)),
Self::Unet => (version.repo(), version.unet_file(use_f16)),
Self::Vae => (version.repo(), version.vae_file(use_f16)),
};
let filename = Api::new()?.model(repo.to_string()).get(path)?;
Ok(filename)
@ -209,6 +231,8 @@ fn run(args: Args) -> Result<()> {
vae_weights,
unet_weights,
tracing,
use_f16,
use_flash_attn,
..
} = args;
@ -220,6 +244,7 @@ fn run(args: Args) -> Result<()> {
None
};
let dtype = if use_f16 { DType::F16 } else { DType::F32 };
let sd_config = match sd_version {
StableDiffusionVersion::V1_5 => {
stable_diffusion::StableDiffusionConfig::v1_5(sliced_attention_size, height, width)
@ -232,7 +257,7 @@ fn run(args: Args) -> Result<()> {
let scheduler = sd_config.build_scheduler(n_steps)?;
let device = candle_examples::device(cpu)?;
let tokenizer = ModelFile::Tokenizer.get(tokenizer, sd_version)?;
let tokenizer = ModelFile::Tokenizer.get(tokenizer, sd_version, use_f16)?;
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let pad_id = match &sd_config.clip.pad_with {
Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
@ -260,18 +285,20 @@ fn run(args: Args) -> Result<()> {
let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), &device)?.unsqueeze(0)?;
println!("Building the Clip transformer.");
let clip_weights = ModelFile::Clip.get(clip_weights, sd_version)?;
let text_model = sd_config.build_clip_transformer(&clip_weights, &device)?;
let text_embeddings = text_model.forward(&tokens)?;
let uncond_embeddings = text_model.forward(&uncond_tokens)?;
let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?;
let text_embeddings = {
let clip_weights = ModelFile::Clip.get(clip_weights, sd_version, false)?;
let text_model = sd_config.build_clip_transformer(&clip_weights, &device, DType::F32)?;
let text_embeddings = text_model.forward(&tokens)?;
let uncond_embeddings = text_model.forward(&uncond_tokens)?;
Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?
};
println!("Building the autoencoder.");
let vae_weights = ModelFile::Vae.get(vae_weights, sd_version)?;
let vae = sd_config.build_vae(&vae_weights, &device)?;
let vae_weights = ModelFile::Vae.get(vae_weights, sd_version, use_f16)?;
let vae = sd_config.build_vae(&vae_weights, &device, dtype)?;
println!("Building the unet.");
let unet_weights = ModelFile::Unet.get(unet_weights, sd_version)?;
let unet = sd_config.build_unet(&unet_weights, &device, 4, args.use_flash_attn)?;
let unet_weights = ModelFile::Unet.get(unet_weights, sd_version, use_f16)?;
let unet = sd_config.build_unet(&unet_weights, &device, 4, use_flash_attn, dtype)?;
let bsize = 1;
for idx in 0..num_samples {
@ -280,7 +307,8 @@ fn run(args: Args) -> Result<()> {
1f32,
(bsize, 4, sd_config.height / 8, sd_config.width / 8),
&device,
)?;
)?
.to_dtype(dtype)?;
// scale the initial noise by the standard deviation required by the scheduler
latents = (latents * scheduler.init_noise_sigma())?;

View File

@ -159,10 +159,11 @@ impl StableDiffusionConfig {
&self,
vae_weights: P,
device: &Device,
dtype: DType,
) -> Result<vae::AutoEncoderKL> {
let weights = unsafe { candle::safetensors::MmapedFile::new(vae_weights)? };
let weights = weights.deserialize()?;
let vs_ae = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
let vs_ae = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
// https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json
let autoencoder = vae::AutoEncoderKL::new(vs_ae, 3, 3, self.autoencoder.clone())?;
Ok(autoencoder)
@ -174,10 +175,11 @@ impl StableDiffusionConfig {
device: &Device,
in_channels: usize,
use_flash_attn: bool,
dtype: DType,
) -> Result<unet_2d::UNet2DConditionModel> {
let weights = unsafe { candle::safetensors::MmapedFile::new(unet_weights)? };
let weights = weights.deserialize()?;
let vs_unet = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
let vs_unet = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
let unet = unet_2d::UNet2DConditionModel::new(
vs_unet,
in_channels,
@ -196,10 +198,11 @@ impl StableDiffusionConfig {
&self,
clip_weights: P,
device: &Device,
dtype: DType,
) -> Result<clip::ClipTextTransformer> {
let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };
let weights = weights.deserialize()?;
let vs = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
let vs = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
let text_model = clip::ClipTextTransformer::new(vs, &self.clip)?;
Ok(text_model)
}

View File

@ -5,7 +5,7 @@
use crate::embeddings::{TimestepEmbedding, Timesteps};
use crate::unet_2d_blocks::*;
use crate::utils::{conv2d, Conv2d};
use candle::{DType, Result, Tensor};
use candle::{Result, Tensor};
use candle_nn as nn;
#[derive(Debug, Clone, Copy)]
@ -316,7 +316,7 @@ impl UNet2DConditionModel {
xs.clone()
};
// 1. time
let emb = (Tensor::ones(bsize, DType::F32, device)? * timestep)?;
let emb = (Tensor::ones(bsize, xs.dtype(), device)? * timestep)?;
let emb = self.time_proj.forward(&emb)?;
let emb = self.time_embedding.forward(&emb)?;
// 2. pre-process