Hook the quantized model.

This commit is contained in:
Laurent
2024-09-25 11:24:50 +02:00
parent 0bd61bae29
commit 8cc560bb8c
3 changed files with 49 additions and 185 deletions

View File

@ -23,6 +23,10 @@ struct Args {
#[arg(long)]
cpu: bool,
/// Use the quantized model.
#[arg(long)]
quantized: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
@ -60,6 +64,7 @@ fn run(args: Args) -> Result<()> {
tracing,
decode_only,
model,
quantized,
} = args;
let width = width.unwrap_or(1360);
let height = height.unwrap_or(768);
@ -146,12 +151,6 @@ fn run(args: Args) -> Result<()> {
};
println!("CLIP\n{clip_emb}");
let img = {
let model_file = match model {
Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?,
Model::Dev => bf_repo.get("flux1-dev.safetensors")?,
};
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
let cfg = match model {
Model::Dev => flux::model::Config::dev(),
Model::Schnell => flux::model::Config::schnell(),
@ -164,10 +163,18 @@ fn run(args: Args) -> Result<()> {
}
Model::Schnell => flux::sampling::get_schedule(4, None),
};
let model = flux::model::Flux::new(&cfg, vb)?;
println!("{state:?}");
println!("{timesteps:?}");
if quantized {
let model_file = match model {
Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?,
Model::Dev => bf_repo.get("flux1-dev.safetensors")?,
};
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
model_file, &device,
)?;
let model = flux::quantized_model::Flux::new(&cfg, vb)?;
flux::sampling::denoise(
&model,
&state.img,
@ -178,6 +185,26 @@ fn run(args: Args) -> Result<()> {
&timesteps,
4.,
)?
} else {
let model_file = match model {
Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?,
Model::Dev => bf_repo.get("flux1-dev.safetensors")?,
};
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)?
};
let model = flux::model::Flux::new(&cfg, vb)?;
flux::sampling::denoise(
&model,
&state.img,
&state.img_ids,
&state.txt,
&state.txt_ids,
&state.vec,
&timesteps,
4.,
)?
}
};
flux::sampling::unpack(&img, height, width)?
}

View File

@ -109,14 +109,14 @@ fn apply_rope(x: &Tensor, freq_cis: &Tensor) -> Result<Tensor> {
(fr0.broadcast_mul(&x0)? + fr1.broadcast_mul(&x1)?)?.reshape(dims.to_vec())
}
fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
pub(crate) fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
let q = apply_rope(q, pe)?.contiguous()?;
let k = apply_rope(k, pe)?.contiguous()?;
let x = scaled_dot_product_attention(&q, &k, v)?;
x.transpose(1, 2)?.flatten_from(2)
}
fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
pub(crate) fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
const TIME_FACTOR: f64 = 1000.;
const MAX_PERIOD: f64 = 10000.;
if dim % 2 == 1 {
@ -144,7 +144,7 @@ pub struct EmbedNd {
}
impl EmbedNd {
fn new(dim: usize, theta: usize, axes_dim: Vec<usize>) -> Self {
pub fn new(dim: usize, theta: usize, axes_dim: Vec<usize>) -> Self {
Self {
dim,
theta,

View File

@ -1,177 +1,14 @@
use super::model::{attention, timestep_embedding, Config, EmbedNd};
use crate::quantized_nn::{linear, linear_b, Linear};
use crate::quantized_var_builder::VarBuilder;
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{LayerNorm, RmsNorm};
// https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/model.py#L12
#[derive(Debug, Clone)]
pub struct Config {
pub in_channels: usize,
pub vec_in_dim: usize,
pub context_in_dim: usize,
pub hidden_size: usize,
pub mlp_ratio: f64,
pub num_heads: usize,
pub depth: usize,
pub depth_single_blocks: usize,
pub axes_dim: Vec<usize>,
pub theta: usize,
pub qkv_bias: bool,
pub guidance_embed: bool,
}
impl Config {
// https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L32
pub fn dev() -> Self {
Self {
in_channels: 64,
vec_in_dim: 768,
context_in_dim: 4096,
hidden_size: 3072,
mlp_ratio: 4.0,
num_heads: 24,
depth: 19,
depth_single_blocks: 38,
axes_dim: vec![16, 56, 56],
theta: 10_000,
qkv_bias: true,
guidance_embed: true,
}
}
// https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L64
pub fn schnell() -> Self {
Self {
in_channels: 64,
vec_in_dim: 768,
context_in_dim: 4096,
hidden_size: 3072,
mlp_ratio: 4.0,
num_heads: 24,
depth: 19,
depth_single_blocks: 38,
axes_dim: vec![16, 56, 56],
theta: 10_000,
qkv_bias: true,
guidance_embed: false,
}
}
}
fn layer_norm(dim: usize, vb: VarBuilder) -> Result<LayerNorm> {
let ws = Tensor::ones(dim, DType::F32, vb.device())?;
Ok(LayerNorm::new_no_bias(ws, 1e-6))
}
fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
let dim = q.dim(D::Minus1)?;
let scale_factor = 1.0 / (dim as f64).sqrt();
let mut batch_dims = q.dims().to_vec();
batch_dims.pop();
batch_dims.pop();
let q = q.flatten_to(batch_dims.len() - 1)?;
let k = k.flatten_to(batch_dims.len() - 1)?;
let v = v.flatten_to(batch_dims.len() - 1)?;
let attn_weights = (q.matmul(&k.t()?)? * scale_factor)?;
let attn_scores = candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(&v)?;
batch_dims.push(attn_scores.dim(D::Minus2)?);
batch_dims.push(attn_scores.dim(D::Minus1)?);
attn_scores.reshape(batch_dims)
}
fn rope(pos: &Tensor, dim: usize, theta: usize) -> Result<Tensor> {
if dim % 2 == 1 {
candle::bail!("dim {dim} is odd")
}
let dev = pos.device();
let theta = theta as f64;
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / theta.powf(i as f64 / dim as f64) as f32)
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, 1, inv_freq_len), dev)?;
let inv_freq = inv_freq.to_dtype(pos.dtype())?;
let freqs = pos.unsqueeze(2)?.broadcast_mul(&inv_freq)?;
let cos = freqs.cos()?;
let sin = freqs.sin()?;
let out = Tensor::stack(&[&cos, &sin.neg()?, &sin, &cos], 3)?;
let (b, n, d, _ij) = out.dims4()?;
out.reshape((b, n, d, 2, 2))
}
fn apply_rope(x: &Tensor, freq_cis: &Tensor) -> Result<Tensor> {
let dims = x.dims();
let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
let x0 = x.narrow(D::Minus1, 0, 1)?;
let x1 = x.narrow(D::Minus1, 1, 1)?;
let fr0 = freq_cis.get_on_dim(D::Minus1, 0)?;
let fr1 = freq_cis.get_on_dim(D::Minus1, 1)?;
(fr0.broadcast_mul(&x0)? + fr1.broadcast_mul(&x1)?)?.reshape(dims.to_vec())
}
fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
let q = apply_rope(q, pe)?.contiguous()?;
let k = apply_rope(k, pe)?.contiguous()?;
let x = scaled_dot_product_attention(&q, &k, v)?;
x.transpose(1, 2)?.flatten_from(2)
}
fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
const TIME_FACTOR: f64 = 1000.;
const MAX_PERIOD: f64 = 10000.;
if dim % 2 == 1 {
candle::bail!("{dim} is odd")
}
let dev = t.device();
let half = dim / 2;
let t = (t * TIME_FACTOR)?;
let arange = Tensor::arange(0, half as u32, dev)?.to_dtype(candle::DType::F32)?;
let freqs = (arange * (-MAX_PERIOD.ln() / half as f64))?.exp()?;
let args = t
.unsqueeze(1)?
.to_dtype(candle::DType::F32)?
.broadcast_mul(&freqs.unsqueeze(0)?)?;
let emb = Tensor::cat(&[args.cos()?, args.sin()?], D::Minus1)?.to_dtype(dtype)?;
Ok(emb)
}
#[derive(Debug, Clone)]
pub struct EmbedNd {
#[allow(unused)]
dim: usize,
theta: usize,
axes_dim: Vec<usize>,
}
impl EmbedNd {
fn new(dim: usize, theta: usize, axes_dim: Vec<usize>) -> Self {
Self {
dim,
theta,
axes_dim,
}
}
}
impl candle::Module for EmbedNd {
fn forward(&self, ids: &Tensor) -> Result<Tensor> {
let n_axes = ids.dim(D::Minus1)?;
let mut emb = Vec::with_capacity(n_axes);
for idx in 0..n_axes {
let r = rope(
&ids.get_on_dim(D::Minus1, idx)?,
self.axes_dim[idx],
self.theta,
)?;
emb.push(r)
}
let emb = Tensor::cat(&emb, 2)?;
emb.unsqueeze(1)
}
}
#[derive(Debug, Clone)]
pub struct MlpEmbedder {
in_layer: Linear,