From 5d99026fd24506c6723d72aff4fee820c51b587a Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Thu, 17 Aug 2023 13:48:56 +0100
Subject: [PATCH] F16 support for stable diffusion (#488)

* F16 support for stable diffusion.

* Keep the attention bits in F32.

* Keep more of the attention bits in F32.

* More mixed precision support.
---
 .../examples/stable-diffusion/attention.rs    | 41 ++++++++---
 .../examples/stable-diffusion/clip.rs         | 14 ++--
 .../examples/stable-diffusion/embeddings.rs   |  6 +-
 .../examples/stable-diffusion/main.rs         | 68 +++++++++++++------
 .../stable-diffusion/stable_diffusion.rs      |  9 ++-
 .../examples/stable-diffusion/unet_2d.rs      |  4 +-
 6 files changed, 99 insertions(+), 43 deletions(-)

diff --git a/candle-examples/examples/stable-diffusion/attention.rs b/candle-examples/examples/stable-diffusion/attention.rs
index dc414889..d981b6f4 100644
--- a/candle-examples/examples/stable-diffusion/attention.rs
+++ b/candle-examples/examples/stable-diffusion/attention.rs
@@ -1,5 +1,5 @@
 //! Attention Based Building Blocks
-use candle::{IndexOp, Result, Tensor, D};
+use candle::{DType, IndexOp, Result, Tensor, D};
 use candle_nn as nn;
 
 #[derive(Debug)]
@@ -147,6 +147,10 @@ impl CrossAttention {
     ) -> Result<Tensor> {
         let batch_size_attention = query.dim(0)?;
         let mut hidden_states = Vec::with_capacity(batch_size_attention / slice_size);
+        let in_dtype = query.dtype();
+        let query = query.to_dtype(DType::F32)?;
+        let key = key.to_dtype(DType::F32)?;
+        let value = value.to_dtype(DType::F32)?;
 
         for i in 0..batch_size_attention / slice_size {
             let start_idx = i * slice_size;
@@ -158,7 +162,7 @@ impl CrossAttention {
             let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(&value.i(start_idx..end_idx)?)?;
             hidden_states.push(xs)
         }
-        let hidden_states = Tensor::stack(&hidden_states, 0)?;
+        let hidden_states = Tensor::stack(&hidden_states, 0)?.to_dtype(in_dtype)?;
         self.reshape_batch_dim_to_heads(&hidden_states)
     }
 
@@ -183,8 +187,14 @@ impl CrossAttention {
                 .squeeze(0)?
                 .to_dtype(init_dtype)?
         } else {
+            let in_dtype = query.dtype();
+            let query = query.to_dtype(DType::F32)?;
+            let key = key.to_dtype(DType::F32)?;
+            let value = value.to_dtype(DType::F32)?;
             let xs = query.matmul(&(key.t()? * self.scale)?)?;
-            nn::ops::softmax(&xs, D::Minus1)?.matmul(value)?
+            nn::ops::softmax(&xs, D::Minus1)?
+                .matmul(&value)?
+                .to_dtype(in_dtype)?
         };
         self.reshape_batch_dim_to_heads(&xs)
     }
@@ -457,10 +467,15 @@ impl AttentionBlock {
         let num_heads = channels / num_head_channels;
         let group_norm =
             nn::group_norm(config.num_groups, channels, config.eps, vs.pp("group_norm"))?;
-        let query = nn::linear(channels, channels, vs.pp("query"))?;
-        let key = nn::linear(channels, channels, vs.pp("key"))?;
-        let value = nn::linear(channels, channels, vs.pp("value"))?;
-        let proj_attn = nn::linear(channels, channels, vs.pp("proj_attn"))?;
+        let (q_path, k_path, v_path, out_path) = if vs.dtype() == DType::F16 {
+            ("to_q", "to_k", "to_v", "to_out.0")
+        } else {
+            ("query", "key", "value", "proj_attn")
+        };
+        let query = nn::linear(channels, channels, vs.pp(q_path))?;
+        let key = nn::linear(channels, channels, vs.pp(k_path))?;
+        let value = nn::linear(channels, channels, vs.pp(v_path))?;
+        let proj_attn = nn::linear(channels, channels, vs.pp(out_path))?;
         let span = tracing::span!(tracing::Level::TRACE, "attn-block");
         Ok(Self {
             group_norm,
@@ -483,6 +498,7 @@ impl AttentionBlock {
 
     pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
+        let in_dtype = xs.dtype();
         let residual = xs;
         let (batch, channel, height, width) = xs.dims4()?;
         let xs = self
@@ -495,9 +511,13 @@ impl AttentionBlock {
         let key_proj = self.key.forward(&xs)?;
         let value_proj = self.value.forward(&xs)?;
 
-        let query_states = self.transpose_for_scores(query_proj)?;
-        let key_states = self.transpose_for_scores(key_proj)?;
-        let value_states = self.transpose_for_scores(value_proj)?;
+        let query_states = self
+            .transpose_for_scores(query_proj)?
+            .to_dtype(DType::F32)?;
+        let key_states = self.transpose_for_scores(key_proj)?.to_dtype(DType::F32)?;
+        let value_states = self
+            .transpose_for_scores(value_proj)?
+            .to_dtype(DType::F32)?;
 
         let scale = f64::powf((self.channels as f64) / (self.num_heads as f64), -0.25);
         let attention_scores =
@@ -506,6 +526,7 @@ impl AttentionBlock {
         let attention_probs = nn::ops::softmax(&attention_scores, D::Minus1)?;
 
         let xs = attention_probs.matmul(&value_states.contiguous()?)?;
+        let xs = xs.to_dtype(in_dtype)?;
         let xs = xs.transpose(1, 2)?.contiguous()?;
         let xs = xs.flatten_from(D::Minus2)?;
         let xs = self
diff --git a/candle-examples/examples/stable-diffusion/clip.rs b/candle-examples/examples/stable-diffusion/clip.rs
index ca00b417..29591f55 100644
--- a/candle-examples/examples/stable-diffusion/clip.rs
+++ b/candle-examples/examples/stable-diffusion/clip.rs
@@ -5,7 +5,7 @@
 //! pairs of images with related texts.
 //!
 //! https://github.com/openai/CLIP
-use candle::{Device, Result, Tensor, D};
+use candle::{DType, Device, Result, Tensor, D};
 use candle_nn as nn;
 
 #[derive(Debug, Clone, Copy)]
@@ -146,18 +146,22 @@ impl ClipAttention {
     }
 
     fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
+        let in_dtype = xs.dtype();
         let (bsz, seq_len, embed_dim) = xs.dims3()?;
         let query_states = (self.q_proj.forward(xs)? * self.scale)?;
         let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim);
         let query_states = self
             .shape(&query_states, seq_len, bsz)?
-            .reshape(proj_shape)?;
+            .reshape(proj_shape)?
+            .to_dtype(DType::F32)?;
         let key_states = self
             .shape(&self.k_proj.forward(xs)?, seq_len, bsz)?
-            .reshape(proj_shape)?;
+            .reshape(proj_shape)?
+            .to_dtype(DType::F32)?;
         let value_states = self
             .shape(&self.v_proj.forward(xs)?, seq_len, bsz)?
-            .reshape(proj_shape)?;
+            .reshape(proj_shape)?
+            .to_dtype(DType::F32)?;
 
         let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
         let src_len = key_states.dim(1)?;
@@ -168,7 +172,7 @@ impl ClipAttention {
             attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?;
 
         let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
-        let attn_output = attn_weights.matmul(&value_states)?;
+        let attn_output = attn_weights.matmul(&value_states)?.to_dtype(in_dtype)?;
         let attn_output = attn_output
             .reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))?
             .transpose(1, 2)?
diff --git a/candle-examples/examples/stable-diffusion/embeddings.rs b/candle-examples/examples/stable-diffusion/embeddings.rs
index e3a339f5..c94f24f8 100644
--- a/candle-examples/examples/stable-diffusion/embeddings.rs
+++ b/candle-examples/examples/stable-diffusion/embeddings.rs
@@ -44,10 +44,10 @@ impl Timesteps {
 impl Timesteps {
     pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let half_dim = (self.num_channels / 2) as u32;
-        let exponent =
-            (Tensor::arange(0, half_dim, xs.device())?.to_dtype(xs.dtype())? * -f64::ln(10000.))?;
+        let exponent = (Tensor::arange(0, half_dim, xs.device())?.to_dtype(candle::DType::F32)?
+            * -f64::ln(10000.))?;
         let exponent = (exponent / (half_dim as f64 - self.downscale_freq_shift))?;
-        let emb = exponent.exp()?;
+        let emb = exponent.exp()?.to_dtype(xs.dtype())?;
         // emb = timesteps[:, None].float() * emb[None, :]
         let emb = xs.unsqueeze(D::Minus1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
         let (cos, sin) = (emb.cos()?, emb.sin()?);
diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs
index de20d4a7..6edd8ae6 100644
--- a/candle-examples/examples/stable-diffusion/main.rs
+++ b/candle-examples/examples/stable-diffusion/main.rs
@@ -93,6 +93,9 @@ struct Args {
 
     #[arg(long)]
     use_flash_attn: bool,
+
+    #[arg(long)]
+    use_f16: bool,
 }
 
 #[derive(Debug, Clone, Copy, clap::ValueEnum)]
@@ -117,21 +120,39 @@ impl StableDiffusionVersion {
         }
     }
 
-    fn unet_file(&self) -> &'static str {
+    fn unet_file(&self, use_f16: bool) -> &'static str {
         match self {
-            Self::V1_5 | Self::V2_1 => "unet/diffusion_pytorch_model.safetensors",
+            Self::V1_5 | Self::V2_1 => {
+                if use_f16 {
+                    "unet/diffusion_pytorch_model.fp16.safetensors"
+                } else {
+                    "unet/diffusion_pytorch_model.safetensors"
+                }
+            }
         }
     }
 
-    fn vae_file(&self) -> &'static str {
+    fn vae_file(&self, use_f16: bool) -> &'static str {
         match self {
-            Self::V1_5 | Self::V2_1 => "vae/diffusion_pytorch_model.safetensors",
+            Self::V1_5 | Self::V2_1 => {
+                if use_f16 {
+                    "vae/diffusion_pytorch_model.fp16.safetensors"
+                } else {
+                    "vae/diffusion_pytorch_model.safetensors"
+                }
+            }
         }
     }
 
-    fn clip_file(&self) -> &'static str {
+    fn clip_file(&self, use_f16: bool) -> &'static str {
         match self {
-            Self::V1_5 | Self::V2_1 => "text_encoder/model.safetensors",
+            Self::V1_5 | Self::V2_1 => {
+                if use_f16 {
+                    "text_encoder/model.fp16.safetensors"
+                } else {
+                    "text_encoder/model.safetensors"
+                }
+            }
         }
     }
 }
@@ -144,6 +165,7 @@ impl ModelFile {
         &self,
         filename: Option<String>,
         version: StableDiffusionVersion,
+        use_f16: bool,
     ) -> Result<std::path::PathBuf> {
         use hf_hub::api::sync::Api;
         match filename {
@@ -151,9 +173,9 @@ impl ModelFile {
             None => {
                 let (repo, path) = match self {
                     Self::Tokenizer => (Self::TOKENIZER_REPO, Self::TOKENIZER_PATH),
-                    Self::Clip => (version.repo(), version.clip_file()),
-                    Self::Unet => (version.repo(), version.unet_file()),
-                    Self::Vae => (version.repo(), version.vae_file()),
+                    Self::Clip => (version.repo(), version.clip_file(use_f16)),
+                    Self::Unet => (version.repo(), version.unet_file(use_f16)),
+                    Self::Vae => (version.repo(), version.vae_file(use_f16)),
                 };
                 let filename = Api::new()?.model(repo.to_string()).get(path)?;
                 Ok(filename)
@@ -209,6 +231,8 @@ fn run(args: Args) -> Result<()> {
         vae_weights,
         unet_weights,
         tracing,
+        use_f16,
+        use_flash_attn,
         ..
     } = args;
 
@@ -220,6 +244,7 @@ fn run(args: Args) -> Result<()> {
         None
     };
 
+    let dtype = if use_f16 { DType::F16 } else { DType::F32 };
    let sd_config = match sd_version {
         StableDiffusionVersion::V1_5 => {
             stable_diffusion::StableDiffusionConfig::v1_5(sliced_attention_size, height, width)
@@ -232,7 +257,7 @@ fn run(args: Args) -> Result<()> {
     let scheduler = sd_config.build_scheduler(n_steps)?;
     let device = candle_examples::device(cpu)?;
 
-    let tokenizer = ModelFile::Tokenizer.get(tokenizer, sd_version)?;
+    let tokenizer = ModelFile::Tokenizer.get(tokenizer, sd_version, use_f16)?;
     let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
     let pad_id = match &sd_config.clip.pad_with {
         Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
@@ -260,18 +285,20 @@ fn run(args: Args) -> Result<()> {
     let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), &device)?.unsqueeze(0)?;
 
     println!("Building the Clip transformer.");
-    let clip_weights = ModelFile::Clip.get(clip_weights, sd_version)?;
-    let text_model = sd_config.build_clip_transformer(&clip_weights, &device)?;
-    let text_embeddings = text_model.forward(&tokens)?;
-    let uncond_embeddings = text_model.forward(&uncond_tokens)?;
-    let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?;
+    let text_embeddings = {
+        let clip_weights = ModelFile::Clip.get(clip_weights, sd_version, false)?;
+        let text_model = sd_config.build_clip_transformer(&clip_weights, &device, DType::F32)?;
+        let text_embeddings = text_model.forward(&tokens)?;
+        let uncond_embeddings = text_model.forward(&uncond_tokens)?;
+        Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?
+    };
 
     println!("Building the autoencoder.");
-    let vae_weights = ModelFile::Vae.get(vae_weights, sd_version)?;
-    let vae = sd_config.build_vae(&vae_weights, &device)?;
+    let vae_weights = ModelFile::Vae.get(vae_weights, sd_version, use_f16)?;
+    let vae = sd_config.build_vae(&vae_weights, &device, dtype)?;
     println!("Building the unet.");
-    let unet_weights = ModelFile::Unet.get(unet_weights, sd_version)?;
-    let unet = sd_config.build_unet(&unet_weights, &device, 4, args.use_flash_attn)?;
+    let unet_weights = ModelFile::Unet.get(unet_weights, sd_version, use_f16)?;
+    let unet = sd_config.build_unet(&unet_weights, &device, 4, use_flash_attn, dtype)?;
 
     let bsize = 1;
     for idx in 0..num_samples {
@@ -280,7 +307,8 @@ fn run(args: Args) -> Result<()> {
             1f32,
             (bsize, 4, sd_config.height / 8, sd_config.width / 8),
             &device,
-        )?;
+        )?
+        .to_dtype(dtype)?;
 
         // scale the initial noise by the standard deviation required by the scheduler
         latents = (latents * scheduler.init_noise_sigma())?;
diff --git a/candle-examples/examples/stable-diffusion/stable_diffusion.rs b/candle-examples/examples/stable-diffusion/stable_diffusion.rs
index 05ba41cb..e159fa0a 100644
--- a/candle-examples/examples/stable-diffusion/stable_diffusion.rs
+++ b/candle-examples/examples/stable-diffusion/stable_diffusion.rs
@@ -159,10 +159,11 @@ impl StableDiffusionConfig {
         &self,
         vae_weights: P,
         device: &Device,
+        dtype: DType,
     ) -> Result<vae::AutoEncoderKL> {
         let weights = unsafe { candle::safetensors::MmapedFile::new(vae_weights)? };
         let weights = weights.deserialize()?;
-        let vs_ae = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
+        let vs_ae = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
         // https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json
         let autoencoder = vae::AutoEncoderKL::new(vs_ae, 3, 3, self.autoencoder.clone())?;
         Ok(autoencoder)
@@ -174,10 +175,11 @@ impl StableDiffusionConfig {
         device: &Device,
         in_channels: usize,
         use_flash_attn: bool,
+        dtype: DType,
     ) -> Result<unet_2d::UNet2DConditionModel> {
         let weights = unsafe { candle::safetensors::MmapedFile::new(unet_weights)? };
         let weights = weights.deserialize()?;
-        let vs_unet = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
+        let vs_unet = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
         let unet = unet_2d::UNet2DConditionModel::new(
             vs_unet,
             in_channels,
@@ -196,10 +198,11 @@ impl StableDiffusionConfig {
         &self,
         clip_weights: P,
         device: &Device,
+        dtype: DType,
     ) -> Result<clip::ClipTextTransformer> {
         let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };
         let weights = weights.deserialize()?;
-        let vs = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
+        let vs = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
         let text_model = clip::ClipTextTransformer::new(vs, &self.clip)?;
         Ok(text_model)
     }
diff --git a/candle-examples/examples/stable-diffusion/unet_2d.rs b/candle-examples/examples/stable-diffusion/unet_2d.rs
index e52ec281..0fa2f31a 100644
--- a/candle-examples/examples/stable-diffusion/unet_2d.rs
+++ b/candle-examples/examples/stable-diffusion/unet_2d.rs
@@ -5,7 +5,7 @@
 use crate::embeddings::{TimestepEmbedding, Timesteps};
 use crate::unet_2d_blocks::*;
 use crate::utils::{conv2d, Conv2d};
-use candle::{DType, Result, Tensor};
+use candle::{Result, Tensor};
 use candle_nn as nn;
 
 #[derive(Debug, Clone, Copy)]
@@ -316,7 +316,7 @@ impl UNet2DConditionModel {
             xs.clone()
         };
         // 1. time
-        let emb = (Tensor::ones(bsize, DType::F32, device)? * timestep)?;
+        let emb = (Tensor::ones(bsize, xs.dtype(), device)? * timestep)?;
         let emb = self.time_proj.forward(&emb)?;
         let emb = self.time_embedding.forward(&emb)?;
         // 2. pre-process
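
Editor's note (not part of the commit): the recurring pattern in this patch is to keep weights and activations in F16 while running the attention matmul/softmax in F32 and casting the result back to the caller's dtype. Below is a minimal, self-contained sketch of that pattern using the same crate aliases as the example (candle, candle_nn); the helper name attention_in_f32, the shapes, and the scale value are illustrative only, not part of the patch.

use candle::{DType, Device, Result, Tensor, D};
use candle_nn as nn;

// Cast F16 activations to F32, do the scaled matmul + softmax in F32,
// then cast the result back to the original dtype so the surrounding
// F16 pipeline is unchanged.
fn attention_in_f32(query: &Tensor, key: &Tensor, value: &Tensor, scale: f64) -> Result<Tensor> {
    let in_dtype = query.dtype();
    let query = query.to_dtype(DType::F32)?;
    let key = key.to_dtype(DType::F32)?;
    let value = value.to_dtype(DType::F32)?;
    let xs = query.matmul(&(key.t()? * scale)?)?;
    nn::ops::softmax(&xs, D::Minus1)?
        .matmul(&value)?
        .to_dtype(in_dtype)
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    // F16 inputs, as an F16 UNet would produce them; shapes are arbitrary.
    let q = Tensor::randn(0f32, 1f32, (2, 4, 8), &device)?.to_dtype(DType::F16)?;
    let k = Tensor::randn(0f32, 1f32, (2, 4, 8), &device)?.to_dtype(DType::F16)?;
    let v = Tensor::randn(0f32, 1f32, (2, 4, 8), &device)?.to_dtype(DType::F16)?;
    let out = attention_in_f32(&q, &k, &v, (8f64).powf(-0.5))?;
    // The caller sees F16 again, even though the softmax ran in F32.
    assert_eq!(out.dtype(), DType::F16);
    Ok(())
}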