Commit: Hook the quantized model.
Repository: https://github.com/huggingface/candle.git

This commit touches three files: the flux example gains a `--quantized` flag and a second weight-loading path, a few helpers in flux/model.rs become crate-visible, and flux/quantized_model.rs drops its duplicated config and helper code in favor of imports from the sibling model module.

candle-examples/examples/flux/main.rs:
@@ -23,6 +23,10 @@ struct Args {
     #[arg(long)]
     cpu: bool,
 
+    /// Use the quantized model.
+    #[arg(long)]
+    quantized: bool,
+
     /// Enable tracing (generates a trace-timestamp.json file).
     #[arg(long)]
     tracing: bool,
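With clap's derive API, each `#[arg(long)]` field becomes a long flag named after the field, so the new field surfaces on the command line as `--quantized` (the exact invocation, e.g. `cargo run --release --example flux -- --quantized`, is an assumption, not shown in the commit). A minimal, self-contained sketch of the mechanism, using an illustrative subset of the fields:

```rust
use clap::Parser;

/// Illustrative subset of the flux example's CLI arguments.
#[derive(Parser, Debug)]
struct Args {
    /// Use the quantized model.
    #[arg(long)]
    quantized: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,
}

fn main() {
    // Passing `--quantized` on the command line flips the flag to true.
    let args = Args::parse();
    println!("quantized: {}, tracing: {}", args.quantized, args.tracing);
}
```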
@@ -60,6 +64,7 @@ fn run(args: Args) -> Result<()> {
         tracing,
         decode_only,
         model,
+        quantized,
     } = args;
     let width = width.unwrap_or(1360);
     let height = height.unwrap_or(768);
@@ -146,12 +151,6 @@ fn run(args: Args) -> Result<()> {
     };
     println!("CLIP\n{clip_emb}");
     let img = {
-        let model_file = match model {
-            Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?,
-            Model::Dev => bf_repo.get("flux1-dev.safetensors")?,
-        };
-        let vb =
-            unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
         let cfg = match model {
             Model::Dev => flux::model::Config::dev(),
             Model::Schnell => flux::model::Config::schnell(),
@@ -164,10 +163,18 @@ fn run(args: Args) -> Result<()> {
             }
             Model::Schnell => flux::sampling::get_schedule(4, None),
         };
-        let model = flux::model::Flux::new(&cfg, vb)?;
-
         println!("{state:?}");
         println!("{timesteps:?}");
+        if quantized {
+            let model_file = match model {
+                Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?,
+                Model::Dev => bf_repo.get("flux1-dev.safetensors")?,
+            };
+            let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
+                model_file, &device,
+            )?;
+
+            let model = flux::quantized_model::Flux::new(&cfg, vb)?;
         flux::sampling::denoise(
             &model,
             &state.img,
@@ -178,6 +185,26 @@ fn run(args: Args) -> Result<()> {
             &timesteps,
             4.,
         )?
+        } else {
+            let model_file = match model {
+                Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?,
+                Model::Dev => bf_repo.get("flux1-dev.safetensors")?,
+            };
+            let vb = unsafe {
+                VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)?
+            };
+            let model = flux::model::Flux::new(&cfg, vb)?;
+            flux::sampling::denoise(
+                &model,
+                &state.img,
+                &state.img_ids,
+                &state.txt,
+                &state.txt_ids,
+                &state.vec,
+                &timesteps,
+                4.,
+            )?
+        }
     };
     flux::sampling::unpack(&img, height, width)?
 }
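The hunks above split weight loading into two paths: with `--quantized`, a GGUF file is read through `candle_transformers::quantized_var_builder::VarBuilder::from_gguf`; otherwise the safetensors weights are memory-mapped as before. Note that the quantized branch still fetches the `.safetensors` files from `bf_repo` at this point, presumably a placeholder until GGUF weights are wired up, and that the `denoise` call is duplicated in both branches because `flux::model::Flux` and `flux::quantized_model::Flux` are distinct types. A minimal sketch of just the two builder constructions (file names and the BF16 dtype are placeholders, not the commit's values):

```rust
use candle::{DType, Device, Result};

/// Sketch of the two weight-loading paths used by the example.
fn load_var_builders(quantized: bool, device: &Device) -> Result<()> {
    if quantized {
        // GGUF stores quantized tensors; they are dequantized at use sites.
        let _vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
            "flux1-schnell-q4_0.gguf",
            device,
        )?;
    } else {
        // Memory-mapped safetensors: `unsafe` because the mapping is only
        // sound if the underlying file is not modified while it is in use.
        let _vb = unsafe {
            candle_nn::VarBuilder::from_mmaped_safetensors(
                &["flux1-schnell.safetensors"],
                DType::BF16,
                device,
            )?
        };
    }
    Ok(())
}
```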
candle-transformers/src/models/flux/model.rs:

@@ -109,14 +109,14 @@ fn apply_rope(x: &Tensor, freq_cis: &Tensor) -> Result<Tensor> {
     (fr0.broadcast_mul(&x0)? + fr1.broadcast_mul(&x1)?)?.reshape(dims.to_vec())
 }
 
-fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
+pub(crate) fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
     let q = apply_rope(q, pe)?.contiguous()?;
     let k = apply_rope(k, pe)?.contiguous()?;
     let x = scaled_dot_product_attention(&q, &k, v)?;
     x.transpose(1, 2)?.flatten_from(2)
 }
 
-fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
+pub(crate) fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
     const TIME_FACTOR: f64 = 1000.;
     const MAX_PERIOD: f64 = 10000.;
     if dim % 2 == 1 {
@@ -144,7 +144,7 @@ pub struct EmbedNd {
 }
 
 impl EmbedNd {
-    fn new(dim: usize, theta: usize, axes_dim: Vec<usize>) -> Self {
+    pub fn new(dim: usize, theta: usize, axes_dim: Vec<usize>) -> Self {
         Self {
             dim,
             theta,
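These visibility changes are what enable the deduplication in the next file: `attention` and `timestep_embedding` go from private to `pub(crate)` so the sibling `quantized_model` module can import them through `super::`, and `EmbedNd::new` becomes `pub`. A tiny self-contained illustration of the pattern (module and function names mirror the commit, the bodies are stand-ins):

```rust
mod model {
    // `pub(crate)` is enough here: the item stays hidden from downstream
    // crates but is importable from any module inside the same crate.
    pub(crate) fn attention() -> &'static str {
        "shared with sibling modules"
    }
}

mod quantized_model {
    // Mirrors the commit's `use super::model::{attention, ...};`.
    use super::model::attention;

    pub fn forward() -> &'static str {
        attention()
    }
}

fn main() {
    println!("{}", quantized_model::forward());
}
```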
candle-transformers/src/models/flux/quantized_model.rs:

@@ -1,177 +1,14 @@
+use super::model::{attention, timestep_embedding, Config, EmbedNd};
 use crate::quantized_nn::{linear, linear_b, Linear};
 use crate::quantized_var_builder::VarBuilder;
 use candle::{DType, IndexOp, Result, Tensor, D};
 use candle_nn::{LayerNorm, RmsNorm};
 
-// https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/model.py#L12
-#[derive(Debug, Clone)]
-pub struct Config {
-    pub in_channels: usize,
-    pub vec_in_dim: usize,
-    pub context_in_dim: usize,
-    pub hidden_size: usize,
-    pub mlp_ratio: f64,
-    pub num_heads: usize,
-    pub depth: usize,
-    pub depth_single_blocks: usize,
-    pub axes_dim: Vec<usize>,
-    pub theta: usize,
-    pub qkv_bias: bool,
-    pub guidance_embed: bool,
-}
-
-impl Config {
-    // https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L32
-    pub fn dev() -> Self {
-        Self {
-            in_channels: 64,
-            vec_in_dim: 768,
-            context_in_dim: 4096,
-            hidden_size: 3072,
-            mlp_ratio: 4.0,
-            num_heads: 24,
-            depth: 19,
-            depth_single_blocks: 38,
-            axes_dim: vec![16, 56, 56],
-            theta: 10_000,
-            qkv_bias: true,
-            guidance_embed: true,
-        }
-    }
-
-    // https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L64
-    pub fn schnell() -> Self {
-        Self {
-            in_channels: 64,
-            vec_in_dim: 768,
-            context_in_dim: 4096,
-            hidden_size: 3072,
-            mlp_ratio: 4.0,
-            num_heads: 24,
-            depth: 19,
-            depth_single_blocks: 38,
-            axes_dim: vec![16, 56, 56],
-            theta: 10_000,
-            qkv_bias: true,
-            guidance_embed: false,
-        }
-    }
-}
-
 fn layer_norm(dim: usize, vb: VarBuilder) -> Result<LayerNorm> {
     let ws = Tensor::ones(dim, DType::F32, vb.device())?;
     Ok(LayerNorm::new_no_bias(ws, 1e-6))
 }
 
-fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
-    let dim = q.dim(D::Minus1)?;
-    let scale_factor = 1.0 / (dim as f64).sqrt();
-    let mut batch_dims = q.dims().to_vec();
-    batch_dims.pop();
-    batch_dims.pop();
-    let q = q.flatten_to(batch_dims.len() - 1)?;
-    let k = k.flatten_to(batch_dims.len() - 1)?;
-    let v = v.flatten_to(batch_dims.len() - 1)?;
-    let attn_weights = (q.matmul(&k.t()?)? * scale_factor)?;
-    let attn_scores = candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(&v)?;
-    batch_dims.push(attn_scores.dim(D::Minus2)?);
-    batch_dims.push(attn_scores.dim(D::Minus1)?);
-    attn_scores.reshape(batch_dims)
-}
-
-fn rope(pos: &Tensor, dim: usize, theta: usize) -> Result<Tensor> {
-    if dim % 2 == 1 {
-        candle::bail!("dim {dim} is odd")
-    }
-    let dev = pos.device();
-    let theta = theta as f64;
-    let inv_freq: Vec<_> = (0..dim)
-        .step_by(2)
-        .map(|i| 1f32 / theta.powf(i as f64 / dim as f64) as f32)
-        .collect();
-    let inv_freq_len = inv_freq.len();
-    let inv_freq = Tensor::from_vec(inv_freq, (1, 1, inv_freq_len), dev)?;
-    let inv_freq = inv_freq.to_dtype(pos.dtype())?;
-    let freqs = pos.unsqueeze(2)?.broadcast_mul(&inv_freq)?;
-    let cos = freqs.cos()?;
-    let sin = freqs.sin()?;
-    let out = Tensor::stack(&[&cos, &sin.neg()?, &sin, &cos], 3)?;
-    let (b, n, d, _ij) = out.dims4()?;
-    out.reshape((b, n, d, 2, 2))
-}
-
-fn apply_rope(x: &Tensor, freq_cis: &Tensor) -> Result<Tensor> {
-    let dims = x.dims();
-    let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
-    let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
-    let x0 = x.narrow(D::Minus1, 0, 1)?;
-    let x1 = x.narrow(D::Minus1, 1, 1)?;
-    let fr0 = freq_cis.get_on_dim(D::Minus1, 0)?;
-    let fr1 = freq_cis.get_on_dim(D::Minus1, 1)?;
-    (fr0.broadcast_mul(&x0)? + fr1.broadcast_mul(&x1)?)?.reshape(dims.to_vec())
-}
-
-fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
-    let q = apply_rope(q, pe)?.contiguous()?;
-    let k = apply_rope(k, pe)?.contiguous()?;
-    let x = scaled_dot_product_attention(&q, &k, v)?;
-    x.transpose(1, 2)?.flatten_from(2)
-}
-
-fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
-    const TIME_FACTOR: f64 = 1000.;
-    const MAX_PERIOD: f64 = 10000.;
-    if dim % 2 == 1 {
-        candle::bail!("{dim} is odd")
-    }
-    let dev = t.device();
-    let half = dim / 2;
-    let t = (t * TIME_FACTOR)?;
-    let arange = Tensor::arange(0, half as u32, dev)?.to_dtype(candle::DType::F32)?;
-    let freqs = (arange * (-MAX_PERIOD.ln() / half as f64))?.exp()?;
-    let args = t
-        .unsqueeze(1)?
-        .to_dtype(candle::DType::F32)?
-        .broadcast_mul(&freqs.unsqueeze(0)?)?;
-    let emb = Tensor::cat(&[args.cos()?, args.sin()?], D::Minus1)?.to_dtype(dtype)?;
-    Ok(emb)
-}
-
-#[derive(Debug, Clone)]
-pub struct EmbedNd {
-    #[allow(unused)]
-    dim: usize,
-    theta: usize,
-    axes_dim: Vec<usize>,
-}
-
-impl EmbedNd {
-    fn new(dim: usize, theta: usize, axes_dim: Vec<usize>) -> Self {
-        Self {
-            dim,
-            theta,
-            axes_dim,
-        }
-    }
-}
-
-impl candle::Module for EmbedNd {
-    fn forward(&self, ids: &Tensor) -> Result<Tensor> {
-        let n_axes = ids.dim(D::Minus1)?;
-        let mut emb = Vec::with_capacity(n_axes);
-        for idx in 0..n_axes {
-            let r = rope(
-                &ids.get_on_dim(D::Minus1, idx)?,
-                self.axes_dim[idx],
-                self.theta,
-            )?;
-            emb.push(r)
-        }
-        let emb = Tensor::cat(&emb, 2)?;
-        emb.unsqueeze(1)
-    }
-}
-
 #[derive(Debug, Clone)]
 pub struct MlpEmbedder {
     in_layer: Linear,
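With this hunk the quantized model drops its copies of the config and the rope/attention/embedding helpers and imports them from `super::model` instead. For reference, the shared `timestep_embedding` computes the standard sinusoidal embedding; reading it off the deleted code, with `TIME_FACTOR = 1000`, `MAX_PERIOD = 10000`, and `half = dim / 2` written below as d/2:

```latex
\omega_j = \exp\!\left(-\frac{\ln 10000}{d/2}\, j\right), \qquad j = 0, \dots, \tfrac{d}{2} - 1,

\mathrm{emb}(t) = \Bigl[\cos(1000\, t\, \omega_0), \dots, \cos(1000\, t\, \omega_{d/2-1}),\;
                        \sin(1000\, t\, \omega_0), \dots, \sin(1000\, t\, \omega_{d/2-1})\Bigr].
```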