From 8cc560bb8ccbf1979cf34bf79fed778f52408285 Mon Sep 17 00:00:00 2001
From: Laurent
Date: Wed, 25 Sep 2024 11:24:50 +0200
Subject: [PATCH] Hook the quantized model.

---
 candle-examples/examples/flux/main.rs         |  63 +++++--
 candle-transformers/src/models/flux/model.rs  |   6 +-
 .../src/models/flux/quantized_model.rs        | 165 +-----------------
 3 files changed, 49 insertions(+), 185 deletions(-)

diff --git a/candle-examples/examples/flux/main.rs b/candle-examples/examples/flux/main.rs
index 539ae6f2..641b72f5 100644
--- a/candle-examples/examples/flux/main.rs
+++ b/candle-examples/examples/flux/main.rs
@@ -23,6 +23,10 @@ struct Args {
     #[arg(long)]
     cpu: bool,

+    /// Use the quantized model.
+    #[arg(long)]
+    quantized: bool,
+
     /// Enable tracing (generates a trace-timestamp.json file).
     #[arg(long)]
     tracing: bool,
@@ -60,6 +64,7 @@ fn run(args: Args) -> Result<()> {
         tracing,
         decode_only,
         model,
+        quantized,
     } = args;
     let width = width.unwrap_or(1360);
     let height = height.unwrap_or(768);
@@ -146,12 +151,6 @@ fn run(args: Args) -> Result<()> {
         };
         println!("CLIP\n{clip_emb}");
         let img = {
-            let model_file = match model {
-                Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?,
-                Model::Dev => bf_repo.get("flux1-dev.safetensors")?,
-            };
-            let vb =
-                unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
             let cfg = match model {
                 Model::Dev => flux::model::Config::dev(),
                 Model::Schnell => flux::model::Config::schnell(),
@@ -164,20 +163,48 @@ fn run(args: Args) -> Result<()> {
                 }
                 Model::Schnell => flux::sampling::get_schedule(4, None),
             };
-            let model = flux::model::Flux::new(&cfg, vb)?;
-            println!("{state:?}");
             println!("{timesteps:?}");
-            flux::sampling::denoise(
-                &model,
-                &state.img,
-                &state.img_ids,
-                &state.txt,
-                &state.txt_ids,
-                &state.vec,
-                &timesteps,
-                4.,
-            )?
+            if quantized {
+                let model_file = match model {
+                    Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?,
+                    Model::Dev => bf_repo.get("flux1-dev.safetensors")?,
+                };
+                let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
+                    model_file, &device,
+                )?;
+
+                let model = flux::quantized_model::Flux::new(&cfg, vb)?;
+                flux::sampling::denoise(
+                    &model,
+                    &state.img,
+                    &state.img_ids,
+                    &state.txt,
+                    &state.txt_ids,
+                    &state.vec,
+                    &timesteps,
+                    4.,
+                )?
+            } else {
+                let model_file = match model {
+                    Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?,
+                    Model::Dev => bf_repo.get("flux1-dev.safetensors")?,
+                };
+                let vb = unsafe {
+                    VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)?
+                };
+                let model = flux::model::Flux::new(&cfg, vb)?;
+                flux::sampling::denoise(
+                    &model,
+                    &state.img,
+                    &state.img_ids,
+                    &state.txt,
+                    &state.txt_ids,
+                    &state.vec,
+                    &timesteps,
+                    4.,
+                )?
+            }
         };
         flux::sampling::unpack(&img, height, width)?
     }

diff --git a/candle-transformers/src/models/flux/model.rs b/candle-transformers/src/models/flux/model.rs
index 02835be5..17b4eb25 100644
--- a/candle-transformers/src/models/flux/model.rs
+++ b/candle-transformers/src/models/flux/model.rs
@@ -109,14 +109,14 @@ fn apply_rope(x: &Tensor, freq_cis: &Tensor) -> Result<Tensor> {
     (fr0.broadcast_mul(&x0)? + fr1.broadcast_mul(&x1)?)?.reshape(dims.to_vec())
 }

-fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
+pub(crate) fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
     let q = apply_rope(q, pe)?.contiguous()?;
     let k = apply_rope(k, pe)?.contiguous()?;
     let x = scaled_dot_product_attention(&q, &k, v)?;
     x.transpose(1, 2)?.flatten_from(2)
 }

-fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
+pub(crate) fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
     const TIME_FACTOR: f64 = 1000.;
     const MAX_PERIOD: f64 = 10000.;
     if dim % 2 == 1 {
@@ -144,7 +144,7 @@ pub struct EmbedNd {
 }

 impl EmbedNd {
-    fn new(dim: usize, theta: usize, axes_dim: Vec<usize>) -> Self {
+    pub fn new(dim: usize, theta: usize, axes_dim: Vec<usize>) -> Self {
         Self {
             dim,
             theta,
diff --git a/candle-transformers/src/models/flux/quantized_model.rs b/candle-transformers/src/models/flux/quantized_model.rs
index 366182eb..0efeeab5 100644
--- a/candle-transformers/src/models/flux/quantized_model.rs
+++ b/candle-transformers/src/models/flux/quantized_model.rs
@@ -1,177 +1,14 @@
+use super::model::{attention, timestep_embedding, Config, EmbedNd};
 use crate::quantized_nn::{linear, linear_b, Linear};
 use crate::quantized_var_builder::VarBuilder;
 use candle::{DType, IndexOp, Result, Tensor, D};
 use candle_nn::{LayerNorm, RmsNorm};

-// https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/model.py#L12
-#[derive(Debug, Clone)]
-pub struct Config {
-    pub in_channels: usize,
-    pub vec_in_dim: usize,
-    pub context_in_dim: usize,
-    pub hidden_size: usize,
-    pub mlp_ratio: f64,
-    pub num_heads: usize,
-    pub depth: usize,
-    pub depth_single_blocks: usize,
-    pub axes_dim: Vec<usize>,
-    pub theta: usize,
-    pub qkv_bias: bool,
-    pub guidance_embed: bool,
-}
-
-impl Config {
-    // https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L32
-    pub fn dev() -> Self {
-        Self {
-            in_channels: 64,
-            vec_in_dim: 768,
-            context_in_dim: 4096,
-            hidden_size: 3072,
-            mlp_ratio: 4.0,
-            num_heads: 24,
-            depth: 19,
-            depth_single_blocks: 38,
-            axes_dim: vec![16, 56, 56],
-            theta: 10_000,
-            qkv_bias: true,
-            guidance_embed: true,
-        }
-    }
-
-    // https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L64
-    pub fn schnell() -> Self {
-        Self {
-            in_channels: 64,
-            vec_in_dim: 768,
-            context_in_dim: 4096,
-            hidden_size: 3072,
-            mlp_ratio: 4.0,
-            num_heads: 24,
-            depth: 19,
-            depth_single_blocks: 38,
-            axes_dim: vec![16, 56, 56],
-            theta: 10_000,
-            qkv_bias: true,
-            guidance_embed: false,
-        }
-    }
-}
-
 fn layer_norm(dim: usize, vb: VarBuilder) -> Result<LayerNorm> {
     let ws = Tensor::ones(dim, DType::F32, vb.device())?;
     Ok(LayerNorm::new_no_bias(ws, 1e-6))
 }

-fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
-    let dim = q.dim(D::Minus1)?;
-    let scale_factor = 1.0 / (dim as f64).sqrt();
-    let mut batch_dims = q.dims().to_vec();
-    batch_dims.pop();
-    batch_dims.pop();
-    let q = q.flatten_to(batch_dims.len() - 1)?;
-    let k = k.flatten_to(batch_dims.len() - 1)?;
-    let v = v.flatten_to(batch_dims.len() - 1)?;
-    let attn_weights = (q.matmul(&k.t()?)? * scale_factor)?;
-    let attn_scores = candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(&v)?;
-    batch_dims.push(attn_scores.dim(D::Minus2)?);
-    batch_dims.push(attn_scores.dim(D::Minus1)?);
-    attn_scores.reshape(batch_dims)
-}
-
-fn rope(pos: &Tensor, dim: usize, theta: usize) -> Result<Tensor> {
-    if dim % 2 == 1 {
-        candle::bail!("dim {dim} is odd")
-    }
-    let dev = pos.device();
-    let theta = theta as f64;
-    let inv_freq: Vec<_> = (0..dim)
-        .step_by(2)
-        .map(|i| 1f32 / theta.powf(i as f64 / dim as f64) as f32)
-        .collect();
-    let inv_freq_len = inv_freq.len();
-    let inv_freq = Tensor::from_vec(inv_freq, (1, 1, inv_freq_len), dev)?;
-    let inv_freq = inv_freq.to_dtype(pos.dtype())?;
-    let freqs = pos.unsqueeze(2)?.broadcast_mul(&inv_freq)?;
-    let cos = freqs.cos()?;
-    let sin = freqs.sin()?;
-    let out = Tensor::stack(&[&cos, &sin.neg()?, &sin, &cos], 3)?;
-    let (b, n, d, _ij) = out.dims4()?;
-    out.reshape((b, n, d, 2, 2))
-}
-
-fn apply_rope(x: &Tensor, freq_cis: &Tensor) -> Result<Tensor> {
-    let dims = x.dims();
-    let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
-    let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
-    let x0 = x.narrow(D::Minus1, 0, 1)?;
-    let x1 = x.narrow(D::Minus1, 1, 1)?;
-    let fr0 = freq_cis.get_on_dim(D::Minus1, 0)?;
-    let fr1 = freq_cis.get_on_dim(D::Minus1, 1)?;
-    (fr0.broadcast_mul(&x0)? + fr1.broadcast_mul(&x1)?)?.reshape(dims.to_vec())
-}
-
-fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
-    let q = apply_rope(q, pe)?.contiguous()?;
-    let k = apply_rope(k, pe)?.contiguous()?;
-    let x = scaled_dot_product_attention(&q, &k, v)?;
-    x.transpose(1, 2)?.flatten_from(2)
-}
-
-fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
-    const TIME_FACTOR: f64 = 1000.;
-    const MAX_PERIOD: f64 = 10000.;
-    if dim % 2 == 1 {
-        candle::bail!("{dim} is odd")
-    }
-    let dev = t.device();
-    let half = dim / 2;
-    let t = (t * TIME_FACTOR)?;
-    let arange = Tensor::arange(0, half as u32, dev)?.to_dtype(candle::DType::F32)?;
-    let freqs = (arange * (-MAX_PERIOD.ln() / half as f64))?.exp()?;
-    let args = t
-        .unsqueeze(1)?
-        .to_dtype(candle::DType::F32)?
-        .broadcast_mul(&freqs.unsqueeze(0)?)?;
-    let emb = Tensor::cat(&[args.cos()?, args.sin()?], D::Minus1)?.to_dtype(dtype)?;
-    Ok(emb)
-}
-
-#[derive(Debug, Clone)]
-pub struct EmbedNd {
-    #[allow(unused)]
-    dim: usize,
-    theta: usize,
-    axes_dim: Vec<usize>,
-}
-
-impl EmbedNd {
-    fn new(dim: usize, theta: usize, axes_dim: Vec<usize>) -> Self {
-        Self {
-            dim,
-            theta,
-            axes_dim,
-        }
-    }
-}
-
-impl candle::Module for EmbedNd {
-    fn forward(&self, ids: &Tensor) -> Result<Tensor> {
-        let n_axes = ids.dim(D::Minus1)?;
-        let mut emb = Vec::with_capacity(n_axes);
-        for idx in 0..n_axes {
-            let r = rope(
-                &ids.get_on_dim(D::Minus1, idx)?,
-                self.axes_dim[idx],
-                self.theta,
-            )?;
-            emb.push(r)
-        }
-        let emb = Tensor::cat(&emb, 2)?;
-        emb.unsqueeze(1)
-    }
-}
-
 #[derive(Debug, Clone)]
 pub struct MlpEmbedder {
     in_layer: Linear,
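
Note, not part of the commit: the quantized path wired up above boils down to swapping the mmap'ed safetensors VarBuilder for the GGUF-backed one from candle_transformers::quantized_var_builder. A minimal stand-alone sketch of that construction is below; the file name flux1-schnell-q8.gguf is a hypothetical placeholder (the patch itself still fetches the .safetensors file), everything else follows the APIs already used in the diff.

    use candle::Device;
    use candle_transformers::models::flux;

    // Sketch: build the quantized Flux transformer from a local GGUF checkpoint.
    // The path is a placeholder; a GGUF export of the schnell weights is assumed.
    fn load_quantized_flux(device: &Device) -> candle::Result<flux::quantized_model::Flux> {
        let cfg = flux::model::Config::schnell();
        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
            "flux1-schnell-q8.gguf",
            device,
        )?;
        flux::quantized_model::Flux::new(&cfg, vb)
    }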