mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Hook the quantized model.
This commit is contained in:
@ -23,6 +23,10 @@ struct Args {
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Use the quantized model.
|
||||
#[arg(long)]
|
||||
quantized: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
@ -60,6 +64,7 @@ fn run(args: Args) -> Result<()> {
|
||||
tracing,
|
||||
decode_only,
|
||||
model,
|
||||
quantized,
|
||||
} = args;
|
||||
let width = width.unwrap_or(1360);
|
||||
let height = height.unwrap_or(768);
|
||||
@ -146,12 +151,6 @@ fn run(args: Args) -> Result<()> {
|
||||
};
|
||||
println!("CLIP\n{clip_emb}");
|
||||
let img = {
|
||||
let model_file = match model {
|
||||
Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?,
|
||||
Model::Dev => bf_repo.get("flux1-dev.safetensors")?,
|
||||
};
|
||||
let vb =
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
|
||||
let cfg = match model {
|
||||
Model::Dev => flux::model::Config::dev(),
|
||||
Model::Schnell => flux::model::Config::schnell(),
|
||||
@ -164,10 +163,18 @@ fn run(args: Args) -> Result<()> {
|
||||
}
|
||||
Model::Schnell => flux::sampling::get_schedule(4, None),
|
||||
};
|
||||
let model = flux::model::Flux::new(&cfg, vb)?;
|
||||
|
||||
println!("{state:?}");
|
||||
println!("{timesteps:?}");
|
||||
if quantized {
|
||||
let model_file = match model {
|
||||
Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?,
|
||||
Model::Dev => bf_repo.get("flux1-dev.safetensors")?,
|
||||
};
|
||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
|
||||
model_file, &device,
|
||||
)?;
|
||||
|
||||
let model = flux::quantized_model::Flux::new(&cfg, vb)?;
|
||||
flux::sampling::denoise(
|
||||
&model,
|
||||
&state.img,
|
||||
@ -178,6 +185,26 @@ fn run(args: Args) -> Result<()> {
|
||||
×teps,
|
||||
4.,
|
||||
)?
|
||||
} else {
|
||||
let model_file = match model {
|
||||
Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?,
|
||||
Model::Dev => bf_repo.get("flux1-dev.safetensors")?,
|
||||
};
|
||||
let vb = unsafe {
|
||||
VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)?
|
||||
};
|
||||
let model = flux::model::Flux::new(&cfg, vb)?;
|
||||
flux::sampling::denoise(
|
||||
&model,
|
||||
&state.img,
|
||||
&state.img_ids,
|
||||
&state.txt,
|
||||
&state.txt_ids,
|
||||
&state.vec,
|
||||
×teps,
|
||||
4.,
|
||||
)?
|
||||
}
|
||||
};
|
||||
flux::sampling::unpack(&img, height, width)?
|
||||
}
|
||||
|
@ -109,14 +109,14 @@ fn apply_rope(x: &Tensor, freq_cis: &Tensor) -> Result<Tensor> {
|
||||
(fr0.broadcast_mul(&x0)? + fr1.broadcast_mul(&x1)?)?.reshape(dims.to_vec())
|
||||
}
|
||||
|
||||
fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
|
||||
pub(crate) fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
|
||||
let q = apply_rope(q, pe)?.contiguous()?;
|
||||
let k = apply_rope(k, pe)?.contiguous()?;
|
||||
let x = scaled_dot_product_attention(&q, &k, v)?;
|
||||
x.transpose(1, 2)?.flatten_from(2)
|
||||
}
|
||||
|
||||
fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
|
||||
pub(crate) fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
|
||||
const TIME_FACTOR: f64 = 1000.;
|
||||
const MAX_PERIOD: f64 = 10000.;
|
||||
if dim % 2 == 1 {
|
||||
@ -144,7 +144,7 @@ pub struct EmbedNd {
|
||||
}
|
||||
|
||||
impl EmbedNd {
|
||||
fn new(dim: usize, theta: usize, axes_dim: Vec<usize>) -> Self {
|
||||
pub fn new(dim: usize, theta: usize, axes_dim: Vec<usize>) -> Self {
|
||||
Self {
|
||||
dim,
|
||||
theta,
|
||||
|
@ -1,177 +1,14 @@
|
||||
use super::model::{attention, timestep_embedding, Config, EmbedNd};
|
||||
use crate::quantized_nn::{linear, linear_b, Linear};
|
||||
use crate::quantized_var_builder::VarBuilder;
|
||||
use candle::{DType, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{LayerNorm, RmsNorm};
|
||||
|
||||
// https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/model.py#L12
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Config {
|
||||
pub in_channels: usize,
|
||||
pub vec_in_dim: usize,
|
||||
pub context_in_dim: usize,
|
||||
pub hidden_size: usize,
|
||||
pub mlp_ratio: f64,
|
||||
pub num_heads: usize,
|
||||
pub depth: usize,
|
||||
pub depth_single_blocks: usize,
|
||||
pub axes_dim: Vec<usize>,
|
||||
pub theta: usize,
|
||||
pub qkv_bias: bool,
|
||||
pub guidance_embed: bool,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
// https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L32
|
||||
pub fn dev() -> Self {
|
||||
Self {
|
||||
in_channels: 64,
|
||||
vec_in_dim: 768,
|
||||
context_in_dim: 4096,
|
||||
hidden_size: 3072,
|
||||
mlp_ratio: 4.0,
|
||||
num_heads: 24,
|
||||
depth: 19,
|
||||
depth_single_blocks: 38,
|
||||
axes_dim: vec![16, 56, 56],
|
||||
theta: 10_000,
|
||||
qkv_bias: true,
|
||||
guidance_embed: true,
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L64
|
||||
pub fn schnell() -> Self {
|
||||
Self {
|
||||
in_channels: 64,
|
||||
vec_in_dim: 768,
|
||||
context_in_dim: 4096,
|
||||
hidden_size: 3072,
|
||||
mlp_ratio: 4.0,
|
||||
num_heads: 24,
|
||||
depth: 19,
|
||||
depth_single_blocks: 38,
|
||||
axes_dim: vec![16, 56, 56],
|
||||
theta: 10_000,
|
||||
qkv_bias: true,
|
||||
guidance_embed: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn layer_norm(dim: usize, vb: VarBuilder) -> Result<LayerNorm> {
|
||||
let ws = Tensor::ones(dim, DType::F32, vb.device())?;
|
||||
Ok(LayerNorm::new_no_bias(ws, 1e-6))
|
||||
}
|
||||
|
||||
fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
|
||||
let dim = q.dim(D::Minus1)?;
|
||||
let scale_factor = 1.0 / (dim as f64).sqrt();
|
||||
let mut batch_dims = q.dims().to_vec();
|
||||
batch_dims.pop();
|
||||
batch_dims.pop();
|
||||
let q = q.flatten_to(batch_dims.len() - 1)?;
|
||||
let k = k.flatten_to(batch_dims.len() - 1)?;
|
||||
let v = v.flatten_to(batch_dims.len() - 1)?;
|
||||
let attn_weights = (q.matmul(&k.t()?)? * scale_factor)?;
|
||||
let attn_scores = candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(&v)?;
|
||||
batch_dims.push(attn_scores.dim(D::Minus2)?);
|
||||
batch_dims.push(attn_scores.dim(D::Minus1)?);
|
||||
attn_scores.reshape(batch_dims)
|
||||
}
|
||||
|
||||
fn rope(pos: &Tensor, dim: usize, theta: usize) -> Result<Tensor> {
|
||||
if dim % 2 == 1 {
|
||||
candle::bail!("dim {dim} is odd")
|
||||
}
|
||||
let dev = pos.device();
|
||||
let theta = theta as f64;
|
||||
let inv_freq: Vec<_> = (0..dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / theta.powf(i as f64 / dim as f64) as f32)
|
||||
.collect();
|
||||
let inv_freq_len = inv_freq.len();
|
||||
let inv_freq = Tensor::from_vec(inv_freq, (1, 1, inv_freq_len), dev)?;
|
||||
let inv_freq = inv_freq.to_dtype(pos.dtype())?;
|
||||
let freqs = pos.unsqueeze(2)?.broadcast_mul(&inv_freq)?;
|
||||
let cos = freqs.cos()?;
|
||||
let sin = freqs.sin()?;
|
||||
let out = Tensor::stack(&[&cos, &sin.neg()?, &sin, &cos], 3)?;
|
||||
let (b, n, d, _ij) = out.dims4()?;
|
||||
out.reshape((b, n, d, 2, 2))
|
||||
}
|
||||
|
||||
fn apply_rope(x: &Tensor, freq_cis: &Tensor) -> Result<Tensor> {
|
||||
let dims = x.dims();
|
||||
let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
|
||||
let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
|
||||
let x0 = x.narrow(D::Minus1, 0, 1)?;
|
||||
let x1 = x.narrow(D::Minus1, 1, 1)?;
|
||||
let fr0 = freq_cis.get_on_dim(D::Minus1, 0)?;
|
||||
let fr1 = freq_cis.get_on_dim(D::Minus1, 1)?;
|
||||
(fr0.broadcast_mul(&x0)? + fr1.broadcast_mul(&x1)?)?.reshape(dims.to_vec())
|
||||
}
|
||||
|
||||
fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
|
||||
let q = apply_rope(q, pe)?.contiguous()?;
|
||||
let k = apply_rope(k, pe)?.contiguous()?;
|
||||
let x = scaled_dot_product_attention(&q, &k, v)?;
|
||||
x.transpose(1, 2)?.flatten_from(2)
|
||||
}
|
||||
|
||||
fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
|
||||
const TIME_FACTOR: f64 = 1000.;
|
||||
const MAX_PERIOD: f64 = 10000.;
|
||||
if dim % 2 == 1 {
|
||||
candle::bail!("{dim} is odd")
|
||||
}
|
||||
let dev = t.device();
|
||||
let half = dim / 2;
|
||||
let t = (t * TIME_FACTOR)?;
|
||||
let arange = Tensor::arange(0, half as u32, dev)?.to_dtype(candle::DType::F32)?;
|
||||
let freqs = (arange * (-MAX_PERIOD.ln() / half as f64))?.exp()?;
|
||||
let args = t
|
||||
.unsqueeze(1)?
|
||||
.to_dtype(candle::DType::F32)?
|
||||
.broadcast_mul(&freqs.unsqueeze(0)?)?;
|
||||
let emb = Tensor::cat(&[args.cos()?, args.sin()?], D::Minus1)?.to_dtype(dtype)?;
|
||||
Ok(emb)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EmbedNd {
|
||||
#[allow(unused)]
|
||||
dim: usize,
|
||||
theta: usize,
|
||||
axes_dim: Vec<usize>,
|
||||
}
|
||||
|
||||
impl EmbedNd {
|
||||
fn new(dim: usize, theta: usize, axes_dim: Vec<usize>) -> Self {
|
||||
Self {
|
||||
dim,
|
||||
theta,
|
||||
axes_dim,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl candle::Module for EmbedNd {
|
||||
fn forward(&self, ids: &Tensor) -> Result<Tensor> {
|
||||
let n_axes = ids.dim(D::Minus1)?;
|
||||
let mut emb = Vec::with_capacity(n_axes);
|
||||
for idx in 0..n_axes {
|
||||
let r = rope(
|
||||
&ids.get_on_dim(D::Minus1, idx)?,
|
||||
self.axes_dim[idx],
|
||||
self.theta,
|
||||
)?;
|
||||
emb.push(r)
|
||||
}
|
||||
let emb = Tensor::cat(&emb, 2)?;
|
||||
emb.unsqueeze(1)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MlpEmbedder {
|
||||
in_layer: Linear,
|
||||
|
Reference in New Issue
Block a user