Compare commits

...

9 Commits

Author SHA1 Message Date
5221146cfa Cuda quantization padding fix. 2024-09-25 23:35:16 +02:00
fd3b53f48b Fix for the quantized model. 2024-09-25 12:34:46 +02:00
c6019e9635 Use the newly minted gguf file. 2024-09-25 12:08:20 +02:00
8cc560bb8c Hook the quantized model. 2024-09-25 11:24:50 +02:00
0bd61bae29 More generic sampling. 2024-09-25 11:15:37 +02:00
fa1e0e438e Quantized version of flux. 2024-09-25 11:07:49 +02:00
d01207dbf3 Add a RotatingKVCache. (#2493)
* Add a RotatingKVCache.

* Add some KvCache tests.

* Test the reset too.

* More kv-cache testing.

* More tests for the rotating kv-cache.

* Improve the api for the rotating cache so that the whole src tensor gets returned when it's overlarge.

* Handle contiguity + bugfix + use in mimi.

* Add a way to test the mimi streaming mode.

* Mimi streaming fixes.

* More rotating kv-cache.

* Fix the attn mask generation.

* Handle the abs case.

* Add some tests for the generated mask.
2024-09-23 13:14:32 +02:00
8097559c1a Move the candle version to 0.7.1. (#2495) 2024-09-22 20:44:39 +02:00
829dcfa8dc Update cudarc to 0.12.1. (#2494) 2024-09-22 20:32:29 +02:00
17 changed files with 950 additions and 84 deletions

View File

@ -20,7 +20,7 @@ exclude = [
resolver = "2"
[workspace.package]
version = "0.7.0"
version = "0.7.1"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@ -33,17 +33,17 @@ ab_glyph = "0.2.23"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.7.0" }
candle-datasets = { path = "./candle-datasets", version = "0.7.0" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.7.0" }
candle-kernels = { path = "./candle-kernels", version = "0.7.0" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.7.0" }
candle-nn = { path = "./candle-nn", version = "0.7.0" }
candle-onnx = { path = "./candle-onnx", version = "0.7.0" }
candle-transformers = { path = "./candle-transformers", version = "0.7.0" }
candle = { path = "./candle-core", package = "candle-core", version = "0.7.1" }
candle-datasets = { path = "./candle-datasets", version = "0.7.1" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.7.1" }
candle-kernels = { path = "./candle-kernels", version = "0.7.1" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.7.1" }
candle-nn = { path = "./candle-nn", version = "0.7.1" }
candle-onnx = { path = "./candle-onnx", version = "0.7.1" }
candle-transformers = { path = "./candle-transformers", version = "0.7.1" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.12.0", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
cudarc = { version = "0.12.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = "0.3.0"

View File

@ -34,7 +34,10 @@ fn ceil_div(p: usize, q: usize) -> usize {
}
fn pad(p: usize, q: usize) -> usize {
ceil_div(p, q) * q
// Overallocate by q rather than just padding by q as this should pad the last row
// and we don't have enough information here to know how many elements to add :(
// ceil_div(p, q) * q
p + q
}
fn quantize_q8_1(
@ -439,7 +442,7 @@ impl QCudaStorage {
}
_ => crate::bail!("only f32 can be quantized"),
};
let src_len = src.len();
let src_len = pad(src.len(), MATRIX_ROW_PADDING);
let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?;
qcpu_storage.quantize(&src)?;

View File

@ -18,7 +18,7 @@ pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>(
let actual_blocks = ys.len();
// Validate that the input is the right size
if expected_blocks != actual_blocks {
if actual_blocks < expected_blocks {
crate::bail!("quantize {dtype:?}: expected {expected_blocks} blocks but only {actual_blocks} were provided!")
}

View File

@ -13,7 +13,7 @@ descriptions,
```bash
cargo run --features cuda --example flux -r -- \
--height 1024 --width 1024
--height 1024 --width 1024 \
--prompt "a rusty robot walking on a beach holding a small torch, the robot has the word "rust" written on it, high quality, 4k"
```

View File

@ -23,6 +23,10 @@ struct Args {
#[arg(long)]
cpu: bool,
/// Use the quantized model.
#[arg(long)]
quantized: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
@ -60,6 +64,7 @@ fn run(args: Args) -> Result<()> {
tracing,
decode_only,
model,
quantized,
} = args;
let width = width.unwrap_or(1360);
let height = height.unwrap_or(768);
@ -146,38 +151,71 @@ fn run(args: Args) -> Result<()> {
};
println!("CLIP\n{clip_emb}");
let img = {
let model_file = match model {
Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?,
Model::Dev => bf_repo.get("flux1-dev.safetensors")?,
};
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
let cfg = match model {
Model::Dev => flux::model::Config::dev(),
Model::Schnell => flux::model::Config::schnell(),
};
let img = flux::sampling::get_noise(1, height, width, &device)?.to_dtype(dtype)?;
let state = flux::sampling::State::new(&t5_emb, &clip_emb, &img)?;
let state = if quantized {
flux::sampling::State::new(
&t5_emb.to_dtype(candle::DType::F32)?,
&clip_emb.to_dtype(candle::DType::F32)?,
&img.to_dtype(candle::DType::F32)?,
)?
} else {
flux::sampling::State::new(&t5_emb, &clip_emb, &img)?
};
let timesteps = match model {
Model::Dev => {
flux::sampling::get_schedule(50, Some((state.img.dim(1)?, 0.5, 1.15)))
}
Model::Schnell => flux::sampling::get_schedule(4, None),
};
let model = flux::model::Flux::new(&cfg, vb)?;
println!("{state:?}");
println!("{timesteps:?}");
flux::sampling::denoise(
&model,
&state.img,
&state.img_ids,
&state.txt,
&state.txt_ids,
&state.vec,
&timesteps,
4.,
)?
if quantized {
let model_file = match model {
Model::Schnell => api
.repo(hf_hub::Repo::model("lmz/candle-flux".to_string()))
.get("flux1-schnell.gguf")?,
Model::Dev => todo!(),
};
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
model_file, &device,
)?;
let model = flux::quantized_model::Flux::new(&cfg, vb)?;
flux::sampling::denoise(
&model,
&state.img,
&state.img_ids,
&state.txt,
&state.txt_ids,
&state.vec,
&timesteps,
4.,
)?
.to_dtype(dtype)?
} else {
let model_file = match model {
Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?,
Model::Dev => bf_repo.get("flux1-dev.safetensors")?,
};
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)?
};
let model = flux::model::Flux::new(&cfg, vb)?;
flux::sampling::denoise(
&model,
&state.img,
&state.img_ids,
&state.txt,
&state.txt_ids,
&state.vec,
&timesteps,
4.,
)?
}
};
flux::sampling::unpack(&img, height, width)?
}

View File

@ -39,6 +39,11 @@ struct Args {
/// The model weight file, in safetensor format.
#[arg(long)]
model: Option<String>,
/// Whether to use streaming or not, when streaming slices of data of the given size are passed
/// to the encoder/decoder one at a time.
#[arg(long)]
streaming: Option<usize>,
}
fn main() -> Result<()> {
@ -87,20 +92,49 @@ fn main() -> Result<()> {
pcm
}
};
let pcm_len = pcm.len();
let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
println!("input pcm shape: {:?}", pcm.shape());
model.encode(&pcm)?
match args.streaming {
Some(chunk_size) => {
let mut code_chunks = vec![];
for pcm in pcm.chunks(chunk_size) {
let pcm = Tensor::new(pcm, &device)?.reshape((1, 1, ()))?;
let code_chunk = model.encode(&pcm)?;
code_chunks.push(code_chunk)
}
Tensor::cat(&code_chunks, candle::D::Minus1)?
}
None => {
let pcm_len = pcm.len();
let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
println!("input pcm shape: {:?}", pcm.shape());
model.encode(&pcm)?
}
}
}
};
println!("codes shape: {:?}", codes.shape());
model.reset_state();
match args.action {
Action::AudioToCode => {
codes.save_safetensors("codes", &args.out_file)?;
}
Action::AudioToAudio | Action::CodeToAudio => {
let pcm = model.decode(&codes)?;
let pcm = match args.streaming {
Some(chunk_size) => {
let seq_len = codes.dim(candle::D::Minus1)?;
let mut pcm_chunks = vec![];
for chunk_start in (0..seq_len).step_by(chunk_size) {
let chunk_len = usize::min(chunk_size, seq_len - chunk_start);
let codes = codes.narrow(candle::D::Minus1, chunk_start, chunk_len)?;
let pcm = model.decode_step(&codes.into())?;
if let Some(pcm) = pcm.as_option() {
pcm_chunks.push(pcm.clone())
}
}
Tensor::cat(&pcm_chunks, candle::D::Minus1)?
}
None => model.decode(&codes)?,
};
println!("output pcm shape: {:?}", pcm.shape());
let pcm = pcm.i(0)?.i(0)?;
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;

View File

@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.7.0"
version = "0.7.1"
edition = "2021"
description = "Flash attention layer for the candle ML framework."
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"
[dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.7.0" }
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.7.1" }
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]

View File

@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.7.0"
version = "0.7.1"
edition = "2021"
description = "CUDA kernels for Candle"

View File

@ -1,6 +1,6 @@
[package]
name = "candle-metal-kernels"
version = "0.7.0"
version = "0.7.1"
edition = "2021"
description = "Metal kernels for Candle"

View File

@ -1,4 +1,4 @@
use candle::{Result, Tensor};
use candle::{Device, Result, Tensor};
#[derive(Debug, Clone)]
pub struct Cache {
@ -145,3 +145,225 @@ impl KvCache {
self.v.reset();
}
}
#[derive(Debug, Clone)]
pub struct RotatingCache {
all_data: Option<Tensor>,
dim: usize,
// `offset` is the current write index in the buffer
offset: usize,
// The total size of the sequence seen so far.
current_seq_len: usize,
// max_seq_len is the size of the rotating buffer, it is actually allowed for the full
// sequence to grow past this limit.
max_seq_len: usize,
}
impl RotatingCache {
pub fn new(dim: usize, max_seq_len: usize) -> Self {
Self {
all_data: None,
dim,
offset: 0,
current_seq_len: 0,
max_seq_len,
}
}
pub fn offset(&self) -> usize {
self.offset
}
pub fn dim(&self) -> usize {
self.dim
}
pub fn current_seq_len(&self) -> usize {
self.current_seq_len
}
pub fn max_seq_len(&self) -> usize {
self.max_seq_len
}
pub fn all_data(&self) -> &Option<Tensor> {
&self.all_data
}
pub fn current_data(&self) -> Result<Option<Tensor>> {
let data = match self.all_data.as_ref() {
None => None,
Some(d) => {
if self.current_seq_len >= self.max_seq_len {
Some(d.clone())
} else {
Some(d.narrow(self.dim, 0, self.current_seq_len)?)
}
}
};
Ok(data)
}
pub fn reset(&mut self) {
self.offset = 0;
self.current_seq_len = 0;
self.all_data = None;
}
pub fn append(&mut self, src: &Tensor) -> Result<Tensor> {
let seq_len = src.dim(self.dim)?;
// This doesn't seem very idiomatic but because the creation can fail, it's tricky to use
// self.all_data.get_or_insert_with.
if self.all_data.is_none() {
let mut shape = src.dims().to_vec();
shape[self.dim] = self.max_seq_len;
let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
self.all_data = Some(ad)
};
let ad = self.all_data.as_mut().unwrap();
self.current_seq_len += seq_len;
if seq_len >= self.max_seq_len {
let to_copy = src
.narrow(self.dim, seq_len - self.max_seq_len, self.max_seq_len)?
.contiguous()?;
ad.slice_set(&to_copy, self.dim, 0)?;
self.offset = 0;
// Here we return `src` rather than `ad` so that all the past can be used.
Ok(src.clone())
} else {
let rem_len = self.max_seq_len - self.offset;
if seq_len <= rem_len {
ad.slice_set(&src.contiguous()?, self.dim, self.offset)?;
self.offset = (self.offset + seq_len) % self.max_seq_len;
} else {
// We have to make two copies here as we go over the boundary of the cache.
if rem_len > 0 {
let src1 = src.narrow(self.dim, 0, rem_len)?.contiguous()?;
ad.slice_set(&src1, self.dim, self.offset)?;
}
let src2 = src
.narrow(self.dim, rem_len, seq_len - rem_len)?
.contiguous()?;
ad.slice_set(&src2, self.dim, 0)?;
self.offset = seq_len - rem_len;
}
if self.current_seq_len >= self.max_seq_len {
Ok(ad.clone())
} else {
Ok(ad.narrow(self.dim, 0, self.current_seq_len)?)
}
}
}
fn get_mask_abs(&self, size1: usize, size2: usize, device: &Device) -> Result<Tensor> {
let context = self.max_seq_len;
let mask: Vec<_> = (0..size1)
.flat_map(|i| {
(0..size2).map(move |j| {
u8::from(size1 + j > size2 + i || size1 + j + context < size2 + i)
})
})
.collect();
Tensor::from_slice(&mask, (size1, size2), device)
}
fn get_mask_rel(&self, size1: usize, size2: usize, device: &Device) -> Result<Tensor> {
let context = self.max_seq_len;
let upd_offset = (self.offset + size1) % self.max_seq_len;
let mask: Vec<_> = (0..size1)
.flat_map(|pos_src| {
// The absolute position of the elements that will get added to the cache.
let pos_src = self.current_seq_len + pos_src;
(0..size2).map(move |pos_cache_rel| {
// The absolute position of the cache elements after the addition.
let pos_cache = self.current_seq_len + size1 + pos_cache_rel - upd_offset;
let pos_cache = if pos_cache_rel < upd_offset {
pos_cache
} else {
pos_cache - self.max_seq_len
};
u8::from(pos_cache > pos_src || pos_cache + context < pos_src)
})
})
.collect();
Tensor::from_slice(&mask, (size1, size2), device)
}
/// Returns the attn_mask to be applied *after* adding `seq_len` to the cache.
pub fn attn_mask(&self, seq_len: usize, device: &Device) -> Result<Option<Tensor>> {
let mask = if seq_len == 1 {
None
} else {
let mask = if seq_len < self.max_seq_len {
let cache_out_len = (self.current_seq_len + seq_len).min(self.max_seq_len);
self.get_mask_rel(seq_len, cache_out_len, device)?
} else {
self.get_mask_abs(seq_len, seq_len, device)?
};
Some(mask)
};
Ok(mask)
}
}
#[derive(Debug, Clone)]
pub struct RotatingKvCache {
k: RotatingCache,
v: RotatingCache,
}
impl RotatingKvCache {
pub fn new(dim: usize, max_seq_len: usize) -> Self {
let k = RotatingCache::new(dim, max_seq_len);
let v = RotatingCache::new(dim, max_seq_len);
Self { k, v }
}
pub fn k_cache(&self) -> &RotatingCache {
&self.k
}
pub fn v_cache(&self) -> &RotatingCache {
&self.v
}
pub fn k_cache_mut(&mut self) -> &mut RotatingCache {
&mut self.k
}
pub fn v_cache_mut(&mut self) -> &mut RotatingCache {
&mut self.v
}
pub fn k(&self) -> Result<Option<Tensor>> {
self.k.current_data()
}
pub fn v(&self) -> Result<Option<Tensor>> {
self.v.current_data()
}
pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
let out_k = self.k.append(k)?;
let out_v = self.v.append(v)?;
Ok((out_k, out_v))
}
pub fn offset(&self) -> usize {
self.k.offset()
}
pub fn current_seq_len(&self) -> usize {
self.k.current_seq_len()
}
pub fn attn_mask(&self, seq_len: usize, device: &Device) -> Result<Option<Tensor>> {
self.k.attn_mask(seq_len, device)
}
pub fn reset(&mut self) {
self.k.reset();
self.v.reset();
}
}

110
candle-nn/tests/kv_cache.rs Normal file
View File

@ -0,0 +1,110 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle::{Device, Result, Tensor};
#[test]
fn kv_cache() -> Result<()> {
let mut cache = candle_nn::kv_cache::Cache::new(0, 16);
for _ in [0, 1] {
assert_eq!(cache.current_seq_len(), 0);
let data = cache.current_data()?;
assert!(data.is_none());
let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
cache.append(&t)?;
let data = cache.current_data()?.unwrap();
assert_eq!(data.to_vec1::<f32>()?, [1., 2., 3.]);
let t = Tensor::new(&[4f32], &Device::Cpu)?;
cache.append(&t)?;
let data = cache.current_data()?.unwrap();
assert_eq!(data.to_vec1::<f32>()?, [1., 2., 3., 4.]);
let t = Tensor::new(&[0f32, 5., 6., 7.], &Device::Cpu)?;
cache.append(&t)?;
let data = cache.current_data()?.unwrap();
assert_eq!(data.to_vec1::<f32>()?, [1., 2., 3., 4., 0., 5., 6., 7.]);
assert_eq!(cache.current_seq_len(), 8);
cache.reset();
}
Ok(())
}
#[test]
fn rotating_kv_cache() -> Result<()> {
let mut cache = candle_nn::kv_cache::RotatingCache::new(0, 6);
for _ in [0, 1] {
assert_eq!(cache.offset(), 0);
assert_eq!(cache.current_seq_len(), 0);
let data = cache.current_data()?;
assert!(data.is_none());
let t = Tensor::new(&[1., 2., 3.], &Device::Cpu)?;
let data = cache.append(&t)?;
assert_eq!(data.to_vec1::<f64>()?, [1., 2., 3.]);
let t = Tensor::new(&[4.], &Device::Cpu)?;
let data = cache.append(&t)?;
assert_eq!(data.to_vec1::<f64>()?, [1., 2., 3., 4.]);
let t = Tensor::new(&[0., 5., 6., 7.], &Device::Cpu)?;
let data = cache.append(&t)?;
assert_eq!(data.to_vec1::<f64>()?, [6., 7., 3., 4., 0., 5.]);
assert_eq!(cache.current_seq_len(), 8);
assert_eq!(cache.offset(), 2);
let t = Tensor::new(&[8.], &Device::Cpu)?;
let data = cache.append(&t)?;
assert_eq!(data.to_vec1::<f64>()?, [6., 7., 8., 4., 0., 5.]);
assert_eq!(cache.current_seq_len(), 9);
assert_eq!(cache.offset(), 3);
let t = Tensor::new(&[9., 10., 11.], &Device::Cpu)?;
let data = cache.append(&t)?;
assert_eq!(data.to_vec1::<f64>()?, [6., 7., 8., 9., 10., 11.]);
assert_eq!(cache.current_seq_len(), 12);
assert_eq!(cache.offset(), 0);
let t = Tensor::new(&[12.], &Device::Cpu)?;
let data = cache.append(&t)?;
assert_eq!(data.to_vec1::<f64>()?, [12., 7., 8., 9., 10., 11.]);
assert_eq!(cache.current_seq_len(), 13);
assert_eq!(cache.offset(), 1);
let mask = cache.attn_mask(2, &Device::Cpu)?.unwrap();
assert_eq!(
mask.to_vec2::<u8>()?,
&[[0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 0, 0]]
);
let mask = cache.attn_mask(3, &Device::Cpu)?.unwrap();
assert_eq!(
mask.to_vec2::<u8>()?,
&[[0, 0, 1, 1, 0, 0], [0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 0]],
);
let t = Tensor::new(&[0., 1., 2., 3., 4., 5., 6., 7., 8.], &Device::Cpu)?;
let data = cache.append(&t)?;
assert_eq!(data.to_vec1::<f64>()?, [0., 1., 2., 3., 4., 5., 6., 7., 8.]);
assert_eq!(cache.current_seq_len(), 22);
assert_eq!(cache.offset(), 0);
let mask = cache.attn_mask(1, &Device::Cpu)?;
assert!(mask.is_none());
let mask = cache.attn_mask(2, &Device::Cpu)?.unwrap();
assert_eq!(
mask.to_vec2::<u8>()?,
&[[0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]]
);
let mask = cache.attn_mask(3, &Device::Cpu)?.unwrap();
assert_eq!(
mask.to_vec2::<u8>()?,
&[[0, 1, 1, 0, 0, 0], [0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 0, 0]]
);
let t = Tensor::new(&[42.], &Device::Cpu)?;
let data = cache.append(&t)?;
assert_eq!(data.to_vec1::<f64>()?, [42., 4., 5., 6., 7., 8.]);
assert_eq!(cache.current_seq_len(), 23);
assert_eq!(cache.offset(), 1);
cache.reset();
}
Ok(())
}

View File

@ -1,6 +1,6 @@
[package]
name = "candle-onnx"
version = "0.7.0"
version = "0.7.1"
edition = "2021"
description = "ONNX support for Candle"
@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
candle = { path = "../candle-core", package = "candle-core", version = "0.7.0" }
candle-nn = { path = "../candle-nn", version = "0.7.0" }
candle = { path = "../candle-core", package = "candle-core", version = "0.7.1" }
candle-nn = { path = "../candle-nn", version = "0.7.1" }
prost = "0.12.1"
[build-dependencies]

View File

@ -1,3 +1,20 @@
use candle::{Result, Tensor};
pub trait WithForward {
#[allow(clippy::too_many_arguments)]
fn forward(
&self,
img: &Tensor,
img_ids: &Tensor,
txt: &Tensor,
txt_ids: &Tensor,
timesteps: &Tensor,
y: &Tensor,
guidance: Option<&Tensor>,
) -> Result<Tensor>;
}
pub mod autoencoder;
pub mod model;
pub mod quantized_model;
pub mod sampling;

View File

@ -109,14 +109,14 @@ fn apply_rope(x: &Tensor, freq_cis: &Tensor) -> Result<Tensor> {
(fr0.broadcast_mul(&x0)? + fr1.broadcast_mul(&x1)?)?.reshape(dims.to_vec())
}
fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
pub(crate) fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
let q = apply_rope(q, pe)?.contiguous()?;
let k = apply_rope(k, pe)?.contiguous()?;
let x = scaled_dot_product_attention(&q, &k, v)?;
x.transpose(1, 2)?.flatten_from(2)
}
fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
pub(crate) fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
const TIME_FACTOR: f64 = 1000.;
const MAX_PERIOD: f64 = 10000.;
if dim % 2 == 1 {
@ -144,7 +144,7 @@ pub struct EmbedNd {
}
impl EmbedNd {
fn new(dim: usize, theta: usize, axes_dim: Vec<usize>) -> Self {
pub fn new(dim: usize, theta: usize, axes_dim: Vec<usize>) -> Self {
Self {
dim,
theta,
@ -575,9 +575,11 @@ impl Flux {
final_layer,
})
}
}
impl super::WithForward for Flux {
#[allow(clippy::too_many_arguments)]
pub fn forward(
fn forward(
&self,
img: &Tensor,
img_ids: &Tensor,

View File

@ -0,0 +1,465 @@
use super::model::{attention, timestep_embedding, Config, EmbedNd};
use crate::quantized_nn::{linear, linear_b, Linear};
use crate::quantized_var_builder::VarBuilder;
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{LayerNorm, RmsNorm};
fn layer_norm(dim: usize, vb: VarBuilder) -> Result<LayerNorm> {
let ws = Tensor::ones(dim, DType::F32, vb.device())?;
Ok(LayerNorm::new_no_bias(ws, 1e-6))
}
#[derive(Debug, Clone)]
pub struct MlpEmbedder {
in_layer: Linear,
out_layer: Linear,
}
impl MlpEmbedder {
fn new(in_sz: usize, h_sz: usize, vb: VarBuilder) -> Result<Self> {
let in_layer = linear(in_sz, h_sz, vb.pp("in_layer"))?;
let out_layer = linear(h_sz, h_sz, vb.pp("out_layer"))?;
Ok(Self {
in_layer,
out_layer,
})
}
}
impl candle::Module for MlpEmbedder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.in_layer)?.silu()?.apply(&self.out_layer)
}
}
#[derive(Debug, Clone)]
pub struct QkNorm {
query_norm: RmsNorm,
key_norm: RmsNorm,
}
impl QkNorm {
fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
let query_norm = vb.get(dim, "query_norm.scale")?.dequantize(vb.device())?;
let query_norm = RmsNorm::new(query_norm, 1e-6);
let key_norm = vb.get(dim, "key_norm.scale")?.dequantize(vb.device())?;
let key_norm = RmsNorm::new(key_norm, 1e-6);
Ok(Self {
query_norm,
key_norm,
})
}
}
struct ModulationOut {
shift: Tensor,
scale: Tensor,
gate: Tensor,
}
impl ModulationOut {
fn scale_shift(&self, xs: &Tensor) -> Result<Tensor> {
xs.broadcast_mul(&(&self.scale + 1.)?)?
.broadcast_add(&self.shift)
}
fn gate(&self, xs: &Tensor) -> Result<Tensor> {
self.gate.broadcast_mul(xs)
}
}
#[derive(Debug, Clone)]
struct Modulation1 {
lin: Linear,
}
impl Modulation1 {
fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
let lin = linear(dim, 3 * dim, vb.pp("lin"))?;
Ok(Self { lin })
}
fn forward(&self, vec_: &Tensor) -> Result<ModulationOut> {
let ys = vec_
.silu()?
.apply(&self.lin)?
.unsqueeze(1)?
.chunk(3, D::Minus1)?;
if ys.len() != 3 {
candle::bail!("unexpected len from chunk {ys:?}")
}
Ok(ModulationOut {
shift: ys[0].clone(),
scale: ys[1].clone(),
gate: ys[2].clone(),
})
}
}
#[derive(Debug, Clone)]
struct Modulation2 {
lin: Linear,
}
impl Modulation2 {
fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
let lin = linear(dim, 6 * dim, vb.pp("lin"))?;
Ok(Self { lin })
}
fn forward(&self, vec_: &Tensor) -> Result<(ModulationOut, ModulationOut)> {
let ys = vec_
.silu()?
.apply(&self.lin)?
.unsqueeze(1)?
.chunk(6, D::Minus1)?;
if ys.len() != 6 {
candle::bail!("unexpected len from chunk {ys:?}")
}
let mod1 = ModulationOut {
shift: ys[0].clone(),
scale: ys[1].clone(),
gate: ys[2].clone(),
};
let mod2 = ModulationOut {
shift: ys[3].clone(),
scale: ys[4].clone(),
gate: ys[5].clone(),
};
Ok((mod1, mod2))
}
}
#[derive(Debug, Clone)]
pub struct SelfAttention {
qkv: Linear,
norm: QkNorm,
proj: Linear,
num_heads: usize,
}
impl SelfAttention {
fn new(dim: usize, num_heads: usize, qkv_bias: bool, vb: VarBuilder) -> Result<Self> {
let head_dim = dim / num_heads;
let qkv = linear_b(dim, dim * 3, qkv_bias, vb.pp("qkv"))?;
let norm = QkNorm::new(head_dim, vb.pp("norm"))?;
let proj = linear(dim, dim, vb.pp("proj"))?;
Ok(Self {
qkv,
norm,
proj,
num_heads,
})
}
fn qkv(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
let qkv = xs.apply(&self.qkv)?;
let (b, l, _khd) = qkv.dims3()?;
let qkv = qkv.reshape((b, l, 3, self.num_heads, ()))?;
let q = qkv.i((.., .., 0))?.transpose(1, 2)?;
let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
let q = q.apply(&self.norm.query_norm)?;
let k = k.apply(&self.norm.key_norm)?;
Ok((q, k, v))
}
#[allow(unused)]
fn forward(&self, xs: &Tensor, pe: &Tensor) -> Result<Tensor> {
let (q, k, v) = self.qkv(xs)?;
attention(&q, &k, &v, pe)?.apply(&self.proj)
}
}
#[derive(Debug, Clone)]
struct Mlp {
lin1: Linear,
lin2: Linear,
}
impl Mlp {
fn new(in_sz: usize, mlp_sz: usize, vb: VarBuilder) -> Result<Self> {
let lin1 = linear(in_sz, mlp_sz, vb.pp("0"))?;
let lin2 = linear(mlp_sz, in_sz, vb.pp("2"))?;
Ok(Self { lin1, lin2 })
}
}
impl candle::Module for Mlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.lin1)?.gelu()?.apply(&self.lin2)
}
}
#[derive(Debug, Clone)]
pub struct DoubleStreamBlock {
img_mod: Modulation2,
img_norm1: LayerNorm,
img_attn: SelfAttention,
img_norm2: LayerNorm,
img_mlp: Mlp,
txt_mod: Modulation2,
txt_norm1: LayerNorm,
txt_attn: SelfAttention,
txt_norm2: LayerNorm,
txt_mlp: Mlp,
}
impl DoubleStreamBlock {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let h_sz = cfg.hidden_size;
let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize;
let img_mod = Modulation2::new(h_sz, vb.pp("img_mod"))?;
let img_norm1 = layer_norm(h_sz, vb.pp("img_norm1"))?;
let img_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("img_attn"))?;
let img_norm2 = layer_norm(h_sz, vb.pp("img_norm2"))?;
let img_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("img_mlp"))?;
let txt_mod = Modulation2::new(h_sz, vb.pp("txt_mod"))?;
let txt_norm1 = layer_norm(h_sz, vb.pp("txt_norm1"))?;
let txt_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("txt_attn"))?;
let txt_norm2 = layer_norm(h_sz, vb.pp("txt_norm2"))?;
let txt_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("txt_mlp"))?;
Ok(Self {
img_mod,
img_norm1,
img_attn,
img_norm2,
img_mlp,
txt_mod,
txt_norm1,
txt_attn,
txt_norm2,
txt_mlp,
})
}
fn forward(
&self,
img: &Tensor,
txt: &Tensor,
vec_: &Tensor,
pe: &Tensor,
) -> Result<(Tensor, Tensor)> {
let (img_mod1, img_mod2) = self.img_mod.forward(vec_)?; // shift, scale, gate
let (txt_mod1, txt_mod2) = self.txt_mod.forward(vec_)?; // shift, scale, gate
let img_modulated = img.apply(&self.img_norm1)?;
let img_modulated = img_mod1.scale_shift(&img_modulated)?;
let (img_q, img_k, img_v) = self.img_attn.qkv(&img_modulated)?;
let txt_modulated = txt.apply(&self.txt_norm1)?;
let txt_modulated = txt_mod1.scale_shift(&txt_modulated)?;
let (txt_q, txt_k, txt_v) = self.txt_attn.qkv(&txt_modulated)?;
let q = Tensor::cat(&[txt_q, img_q], 2)?;
let k = Tensor::cat(&[txt_k, img_k], 2)?;
let v = Tensor::cat(&[txt_v, img_v], 2)?;
let attn = attention(&q, &k, &v, pe)?;
let txt_attn = attn.narrow(1, 0, txt.dim(1)?)?;
let img_attn = attn.narrow(1, txt.dim(1)?, attn.dim(1)? - txt.dim(1)?)?;
let img = (img + img_mod1.gate(&img_attn.apply(&self.img_attn.proj)?))?;
let img = (&img
+ img_mod2.gate(
&img_mod2
.scale_shift(&img.apply(&self.img_norm2)?)?
.apply(&self.img_mlp)?,
)?)?;
let txt = (txt + txt_mod1.gate(&txt_attn.apply(&self.txt_attn.proj)?))?;
let txt = (&txt
+ txt_mod2.gate(
&txt_mod2
.scale_shift(&txt.apply(&self.txt_norm2)?)?
.apply(&self.txt_mlp)?,
)?)?;
Ok((img, txt))
}
}
#[derive(Debug, Clone)]
pub struct SingleStreamBlock {
linear1: Linear,
linear2: Linear,
norm: QkNorm,
pre_norm: LayerNorm,
modulation: Modulation1,
h_sz: usize,
mlp_sz: usize,
num_heads: usize,
}
impl SingleStreamBlock {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let h_sz = cfg.hidden_size;
let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize;
let head_dim = h_sz / cfg.num_heads;
let linear1 = linear(h_sz, h_sz * 3 + mlp_sz, vb.pp("linear1"))?;
let linear2 = linear(h_sz + mlp_sz, h_sz, vb.pp("linear2"))?;
let norm = QkNorm::new(head_dim, vb.pp("norm"))?;
let pre_norm = layer_norm(h_sz, vb.pp("pre_norm"))?;
let modulation = Modulation1::new(h_sz, vb.pp("modulation"))?;
Ok(Self {
linear1,
linear2,
norm,
pre_norm,
modulation,
h_sz,
mlp_sz,
num_heads: cfg.num_heads,
})
}
fn forward(&self, xs: &Tensor, vec_: &Tensor, pe: &Tensor) -> Result<Tensor> {
let mod_ = self.modulation.forward(vec_)?;
let x_mod = mod_.scale_shift(&xs.apply(&self.pre_norm)?)?;
let x_mod = x_mod.apply(&self.linear1)?;
let qkv = x_mod.narrow(D::Minus1, 0, 3 * self.h_sz)?;
let (b, l, _khd) = qkv.dims3()?;
let qkv = qkv.reshape((b, l, 3, self.num_heads, ()))?;
let q = qkv.i((.., .., 0))?.transpose(1, 2)?;
let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
let mlp = x_mod.narrow(D::Minus1, 3 * self.h_sz, self.mlp_sz)?;
let q = q.apply(&self.norm.query_norm)?;
let k = k.apply(&self.norm.key_norm)?;
let attn = attention(&q, &k, &v, pe)?;
let output = Tensor::cat(&[attn, mlp.gelu()?], 2)?.apply(&self.linear2)?;
xs + mod_.gate(&output)
}
}
#[derive(Debug, Clone)]
pub struct LastLayer {
norm_final: LayerNorm,
linear: Linear,
ada_ln_modulation: Linear,
}
impl LastLayer {
fn new(h_sz: usize, p_sz: usize, out_c: usize, vb: VarBuilder) -> Result<Self> {
let norm_final = layer_norm(h_sz, vb.pp("norm_final"))?;
let linear_ = linear(h_sz, p_sz * p_sz * out_c, vb.pp("linear"))?;
let ada_ln_modulation = linear(h_sz, 2 * h_sz, vb.pp("adaLN_modulation.1"))?;
Ok(Self {
norm_final,
linear: linear_,
ada_ln_modulation,
})
}
fn forward(&self, xs: &Tensor, vec: &Tensor) -> Result<Tensor> {
let chunks = vec.silu()?.apply(&self.ada_ln_modulation)?.chunk(2, 1)?;
let (shift, scale) = (&chunks[0], &chunks[1]);
let xs = xs
.apply(&self.norm_final)?
.broadcast_mul(&(scale.unsqueeze(1)? + 1.0)?)?
.broadcast_add(&shift.unsqueeze(1)?)?;
xs.apply(&self.linear)
}
}
#[derive(Debug, Clone)]
pub struct Flux {
img_in: Linear,
txt_in: Linear,
time_in: MlpEmbedder,
vector_in: MlpEmbedder,
guidance_in: Option<MlpEmbedder>,
pe_embedder: EmbedNd,
double_blocks: Vec<DoubleStreamBlock>,
single_blocks: Vec<SingleStreamBlock>,
final_layer: LastLayer,
}
impl Flux {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let img_in = linear(cfg.in_channels, cfg.hidden_size, vb.pp("img_in"))?;
let txt_in = linear(cfg.context_in_dim, cfg.hidden_size, vb.pp("txt_in"))?;
let mut double_blocks = Vec::with_capacity(cfg.depth);
let vb_d = vb.pp("double_blocks");
for idx in 0..cfg.depth {
let db = DoubleStreamBlock::new(cfg, vb_d.pp(idx))?;
double_blocks.push(db)
}
let mut single_blocks = Vec::with_capacity(cfg.depth_single_blocks);
let vb_s = vb.pp("single_blocks");
for idx in 0..cfg.depth_single_blocks {
let sb = SingleStreamBlock::new(cfg, vb_s.pp(idx))?;
single_blocks.push(sb)
}
let time_in = MlpEmbedder::new(256, cfg.hidden_size, vb.pp("time_in"))?;
let vector_in = MlpEmbedder::new(cfg.vec_in_dim, cfg.hidden_size, vb.pp("vector_in"))?;
let guidance_in = if cfg.guidance_embed {
let mlp = MlpEmbedder::new(256, cfg.hidden_size, vb.pp("guidance_in"))?;
Some(mlp)
} else {
None
};
let final_layer =
LastLayer::new(cfg.hidden_size, 1, cfg.in_channels, vb.pp("final_layer"))?;
let pe_dim = cfg.hidden_size / cfg.num_heads;
let pe_embedder = EmbedNd::new(pe_dim, cfg.theta, cfg.axes_dim.to_vec());
Ok(Self {
img_in,
txt_in,
time_in,
vector_in,
guidance_in,
pe_embedder,
double_blocks,
single_blocks,
final_layer,
})
}
}
impl super::WithForward for Flux {
#[allow(clippy::too_many_arguments)]
fn forward(
&self,
img: &Tensor,
img_ids: &Tensor,
txt: &Tensor,
txt_ids: &Tensor,
timesteps: &Tensor,
y: &Tensor,
guidance: Option<&Tensor>,
) -> Result<Tensor> {
if txt.rank() != 3 {
candle::bail!("unexpected shape for txt {:?}", txt.shape())
}
if img.rank() != 3 {
candle::bail!("unexpected shape for img {:?}", img.shape())
}
let dtype = img.dtype();
let pe = {
let ids = Tensor::cat(&[txt_ids, img_ids], 1)?;
ids.apply(&self.pe_embedder)?
};
let mut txt = txt.apply(&self.txt_in)?;
let mut img = img.apply(&self.img_in)?;
let vec_ = timestep_embedding(timesteps, 256, dtype)?.apply(&self.time_in)?;
let vec_ = match (self.guidance_in.as_ref(), guidance) {
(Some(g_in), Some(guidance)) => {
(vec_ + timestep_embedding(guidance, 256, dtype)?.apply(g_in))?
}
_ => vec_,
};
let vec_ = (vec_ + y.apply(&self.vector_in))?;
// Double blocks
for block in self.double_blocks.iter() {
(img, txt) = block.forward(&img, &txt, &vec_, &pe)?
}
// Single blocks
let mut img = Tensor::cat(&[&txt, &img], 1)?;
for block in self.single_blocks.iter() {
img = block.forward(&img, &vec_, &pe)?;
}
let img = img.i((.., txt.dim(1)?..))?;
self.final_layer.forward(&img, &vec_)
}
}

View File

@ -92,8 +92,8 @@ pub fn unpack(xs: &Tensor, height: usize, width: usize) -> Result<Tensor> {
}
#[allow(clippy::too_many_arguments)]
pub fn denoise(
model: &super::model::Flux,
pub fn denoise<M: super::WithForward>(
model: &M,
img: &Tensor,
img_ids: &Tensor,
txt: &Tensor,

View File

@ -101,21 +101,6 @@ impl Module for LayerScale {
}
}
pub(crate) fn get_mask(
size1: usize,
size2: usize,
context: usize,
device: &Device,
) -> Result<Tensor> {
let mask: Vec<_> = (0..size1)
.flat_map(|i| {
(0..size2)
.map(move |j| u8::from(size1 + j > size2 + i || size1 + j + context < size2 + i))
})
.collect();
Tensor::from_slice(&mask, (size1, size2), device)
}
#[derive(Debug, Clone)]
pub struct StreamingMultiheadAttention {
q_proj: Linear,
@ -127,7 +112,7 @@ pub struct StreamingMultiheadAttention {
context: usize,
neg_inf: Tensor,
rope: Option<Arc<RotaryEmbedding>>,
kv_cache: candle_nn::kv_cache::KvCache,
kv_cache: candle_nn::kv_cache::RotatingKvCache,
pos: usize,
use_flash_attn: bool,
span: tracing::Span,
@ -153,7 +138,7 @@ impl StreamingMultiheadAttention {
num_heads: cfg.num_heads,
context: cfg.context,
neg_inf,
kv_cache: candle_nn::kv_cache::KvCache::new(2, cfg.max_seq_len),
kv_cache: candle_nn::kv_cache::RotatingKvCache::new(2, cfg.context),
pos: 0,
use_flash_attn: false,
span: tracing::span!(tracing::Level::TRACE, "mha"),
@ -236,7 +221,7 @@ impl StreamingMultiheadAttention {
self.kv_cache.reset()
}
pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::KvCache) {
pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::RotatingKvCache) {
self.kv_cache = kv_cache
}
}
@ -582,7 +567,7 @@ impl StreamingTransformerLayer {
self.self_attn.reset_kv_cache()
}
pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::KvCache) {
pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::RotatingKvCache) {
self.self_attn.set_kv_cache(kv_cache)
}
}
@ -590,7 +575,6 @@ impl StreamingTransformerLayer {
#[derive(Debug, Clone)]
pub struct StreamingTransformer {
layers: Vec<StreamingTransformerLayer>,
context: usize,
positional_embedding: PositionalEmbedding,
max_period: usize,
}
@ -617,7 +601,6 @@ impl StreamingTransformer {
}
Ok(Self {
layers,
context: cfg.context,
positional_embedding: cfg.positional_embedding,
max_period: cfg.max_period,
})
@ -629,19 +612,11 @@ impl StreamingTransformer {
pub fn forward_ca(&mut self, xs: &Tensor, ca_src: Option<&Tensor>) -> Result<Tensor> {
let (_b, t, c) = xs.dims3()?;
// We will extract at most "context" from the kv_cache.
// Note that the mask will discard the values that are before context.
let pos = self.layers[0]
let pos = self.layers[0].self_attn.kv_cache.current_seq_len();
let mask = self.layers[0]
.self_attn
.kv_cache
.k_cache()
.current_seq_len()
.min(self.context);
let mask = if t == 1 {
None
} else {
Some(get_mask(t, pos + t, self.context, xs.device())?)
};
.attn_mask(t, xs.device())?;
let mut xs = match self.positional_embedding {
PositionalEmbedding::Rope | PositionalEmbedding::None => xs.clone(),
PositionalEmbedding::Sin => {