mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 12:20:46 +00:00
Compare commits
11 Commits
qmm-pad-fi
...
cudarc-12-
Author | SHA1 | Date | |
---|---|---|---|
42c702a023 | |||
d6f01f625d | |||
3277844fd9 | |||
c79bf421c7 | |||
58c1e909d3 | |||
9964c6d86c | |||
fc877920ce | |||
6547c4bfc3 | |||
f9579f80be | |||
1bddd44cb8 | |||
9cfe3c7141 |
18
Cargo.toml
18
Cargo.toml
@ -20,7 +20,7 @@ exclude = [
|
|||||||
resolver = "2"
|
resolver = "2"
|
||||||
|
|
||||||
[workspace.package]
|
[workspace.package]
|
||||||
version = "0.7.1"
|
version = "0.7.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
description = "Minimalist ML framework."
|
description = "Minimalist ML framework."
|
||||||
repository = "https://github.com/huggingface/candle"
|
repository = "https://github.com/huggingface/candle"
|
||||||
@ -33,14 +33,14 @@ ab_glyph = "0.2.23"
|
|||||||
accelerate-src = { version = "0.3.2" }
|
accelerate-src = { version = "0.3.2" }
|
||||||
anyhow = { version = "1", features = ["backtrace"] }
|
anyhow = { version = "1", features = ["backtrace"] }
|
||||||
byteorder = "1.4.3"
|
byteorder = "1.4.3"
|
||||||
candle = { path = "./candle-core", package = "candle-core", version = "0.7.1" }
|
candle = { path = "./candle-core", package = "candle-core", version = "0.7.0" }
|
||||||
candle-datasets = { path = "./candle-datasets", version = "0.7.1" }
|
candle-datasets = { path = "./candle-datasets", version = "0.7.0" }
|
||||||
candle-flash-attn = { path = "./candle-flash-attn", version = "0.7.1" }
|
candle-flash-attn = { path = "./candle-flash-attn", version = "0.7.0" }
|
||||||
candle-kernels = { path = "./candle-kernels", version = "0.7.1" }
|
candle-kernels = { path = "./candle-kernels", version = "0.7.0" }
|
||||||
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.7.1" }
|
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.7.0" }
|
||||||
candle-nn = { path = "./candle-nn", version = "0.7.1" }
|
candle-nn = { path = "./candle-nn", version = "0.7.0" }
|
||||||
candle-onnx = { path = "./candle-onnx", version = "0.7.1" }
|
candle-onnx = { path = "./candle-onnx", version = "0.7.0" }
|
||||||
candle-transformers = { path = "./candle-transformers", version = "0.7.1" }
|
candle-transformers = { path = "./candle-transformers", version = "0.7.0" }
|
||||||
clap = { version = "4.2.4", features = ["derive"] }
|
clap = { version = "4.2.4", features = ["derive"] }
|
||||||
criterion = { version = "0.5.1", default-features=false }
|
criterion = { version = "0.5.1", default-features=false }
|
||||||
cudarc = { version = "0.12.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
|
cudarc = { version = "0.12.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
|
||||||
|
@ -34,10 +34,7 @@ fn ceil_div(p: usize, q: usize) -> usize {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn pad(p: usize, q: usize) -> usize {
|
fn pad(p: usize, q: usize) -> usize {
|
||||||
// Overallocate by q rather than just padding by q as this should pad the last row
|
ceil_div(p, q) * q
|
||||||
// and we don't have enough information here to know how many elements to add :(
|
|
||||||
// ceil_div(p, q) * q
|
|
||||||
p + q
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn quantize_q8_1(
|
fn quantize_q8_1(
|
||||||
@ -442,7 +439,7 @@ impl QCudaStorage {
|
|||||||
}
|
}
|
||||||
_ => crate::bail!("only f32 can be quantized"),
|
_ => crate::bail!("only f32 can be quantized"),
|
||||||
};
|
};
|
||||||
let src_len = pad(src.len(), MATRIX_ROW_PADDING);
|
let src_len = src.len();
|
||||||
let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
|
let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
|
||||||
let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?;
|
let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?;
|
||||||
qcpu_storage.quantize(&src)?;
|
qcpu_storage.quantize(&src)?;
|
||||||
|
@ -18,7 +18,7 @@ pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>(
|
|||||||
let actual_blocks = ys.len();
|
let actual_blocks = ys.len();
|
||||||
|
|
||||||
// Validate that the input is the right size
|
// Validate that the input is the right size
|
||||||
if actual_blocks < expected_blocks {
|
if expected_blocks != actual_blocks {
|
||||||
crate::bail!("quantize {dtype:?}: expected {expected_blocks} blocks but only {actual_blocks} were provided!")
|
crate::bail!("quantize {dtype:?}: expected {expected_blocks} blocks but only {actual_blocks} were provided!")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -13,7 +13,7 @@ descriptions,
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
cargo run --features cuda --example flux -r -- \
|
cargo run --features cuda --example flux -r -- \
|
||||||
--height 1024 --width 1024 \
|
--height 1024 --width 1024
|
||||||
--prompt "a rusty robot walking on a beach holding a small torch, the robot has the word "rust" written on it, high quality, 4k"
|
--prompt "a rusty robot walking on a beach holding a small torch, the robot has the word "rust" written on it, high quality, 4k"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -23,10 +23,6 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
cpu: bool,
|
cpu: bool,
|
||||||
|
|
||||||
/// Use the quantized model.
|
|
||||||
#[arg(long)]
|
|
||||||
quantized: bool,
|
|
||||||
|
|
||||||
/// Enable tracing (generates a trace-timestamp.json file).
|
/// Enable tracing (generates a trace-timestamp.json file).
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
tracing: bool,
|
tracing: bool,
|
||||||
@ -64,7 +60,6 @@ fn run(args: Args) -> Result<()> {
|
|||||||
tracing,
|
tracing,
|
||||||
decode_only,
|
decode_only,
|
||||||
model,
|
model,
|
||||||
quantized,
|
|
||||||
} = args;
|
} = args;
|
||||||
let width = width.unwrap_or(1360);
|
let width = width.unwrap_or(1360);
|
||||||
let height = height.unwrap_or(768);
|
let height = height.unwrap_or(768);
|
||||||
@ -151,71 +146,38 @@ fn run(args: Args) -> Result<()> {
|
|||||||
};
|
};
|
||||||
println!("CLIP\n{clip_emb}");
|
println!("CLIP\n{clip_emb}");
|
||||||
let img = {
|
let img = {
|
||||||
|
let model_file = match model {
|
||||||
|
Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?,
|
||||||
|
Model::Dev => bf_repo.get("flux1-dev.safetensors")?,
|
||||||
|
};
|
||||||
|
let vb =
|
||||||
|
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
|
||||||
let cfg = match model {
|
let cfg = match model {
|
||||||
Model::Dev => flux::model::Config::dev(),
|
Model::Dev => flux::model::Config::dev(),
|
||||||
Model::Schnell => flux::model::Config::schnell(),
|
Model::Schnell => flux::model::Config::schnell(),
|
||||||
};
|
};
|
||||||
let img = flux::sampling::get_noise(1, height, width, &device)?.to_dtype(dtype)?;
|
let img = flux::sampling::get_noise(1, height, width, &device)?.to_dtype(dtype)?;
|
||||||
let state = if quantized {
|
let state = flux::sampling::State::new(&t5_emb, &clip_emb, &img)?;
|
||||||
flux::sampling::State::new(
|
|
||||||
&t5_emb.to_dtype(candle::DType::F32)?,
|
|
||||||
&clip_emb.to_dtype(candle::DType::F32)?,
|
|
||||||
&img.to_dtype(candle::DType::F32)?,
|
|
||||||
)?
|
|
||||||
} else {
|
|
||||||
flux::sampling::State::new(&t5_emb, &clip_emb, &img)?
|
|
||||||
};
|
|
||||||
let timesteps = match model {
|
let timesteps = match model {
|
||||||
Model::Dev => {
|
Model::Dev => {
|
||||||
flux::sampling::get_schedule(50, Some((state.img.dim(1)?, 0.5, 1.15)))
|
flux::sampling::get_schedule(50, Some((state.img.dim(1)?, 0.5, 1.15)))
|
||||||
}
|
}
|
||||||
Model::Schnell => flux::sampling::get_schedule(4, None),
|
Model::Schnell => flux::sampling::get_schedule(4, None),
|
||||||
};
|
};
|
||||||
|
let model = flux::model::Flux::new(&cfg, vb)?;
|
||||||
|
|
||||||
println!("{state:?}");
|
println!("{state:?}");
|
||||||
println!("{timesteps:?}");
|
println!("{timesteps:?}");
|
||||||
if quantized {
|
flux::sampling::denoise(
|
||||||
let model_file = match model {
|
&model,
|
||||||
Model::Schnell => api
|
&state.img,
|
||||||
.repo(hf_hub::Repo::model("lmz/candle-flux".to_string()))
|
&state.img_ids,
|
||||||
.get("flux1-schnell.gguf")?,
|
&state.txt,
|
||||||
Model::Dev => todo!(),
|
&state.txt_ids,
|
||||||
};
|
&state.vec,
|
||||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
|
×teps,
|
||||||
model_file, &device,
|
4.,
|
||||||
)?;
|
)?
|
||||||
|
|
||||||
let model = flux::quantized_model::Flux::new(&cfg, vb)?;
|
|
||||||
flux::sampling::denoise(
|
|
||||||
&model,
|
|
||||||
&state.img,
|
|
||||||
&state.img_ids,
|
|
||||||
&state.txt,
|
|
||||||
&state.txt_ids,
|
|
||||||
&state.vec,
|
|
||||||
×teps,
|
|
||||||
4.,
|
|
||||||
)?
|
|
||||||
.to_dtype(dtype)?
|
|
||||||
} else {
|
|
||||||
let model_file = match model {
|
|
||||||
Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?,
|
|
||||||
Model::Dev => bf_repo.get("flux1-dev.safetensors")?,
|
|
||||||
};
|
|
||||||
let vb = unsafe {
|
|
||||||
VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)?
|
|
||||||
};
|
|
||||||
let model = flux::model::Flux::new(&cfg, vb)?;
|
|
||||||
flux::sampling::denoise(
|
|
||||||
&model,
|
|
||||||
&state.img,
|
|
||||||
&state.img_ids,
|
|
||||||
&state.txt,
|
|
||||||
&state.txt_ids,
|
|
||||||
&state.vec,
|
|
||||||
×teps,
|
|
||||||
4.,
|
|
||||||
)?
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
flux::sampling::unpack(&img, height, width)?
|
flux::sampling::unpack(&img, height, width)?
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-flash-attn"
|
name = "candle-flash-attn"
|
||||||
version = "0.7.1"
|
version = "0.7.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "Flash attention layer for the candle ML framework."
|
description = "Flash attention layer for the candle ML framework."
|
||||||
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
|
|||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.7.1" }
|
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.7.0" }
|
||||||
half = { version = "2.3.1", features = ["num-traits"] }
|
half = { version = "2.3.1", features = ["num-traits"] }
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-kernels"
|
name = "candle-kernels"
|
||||||
version = "0.7.1"
|
version = "0.7.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "CUDA kernels for Candle"
|
description = "CUDA kernels for Candle"
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-metal-kernels"
|
name = "candle-metal-kernels"
|
||||||
version = "0.7.1"
|
version = "0.7.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "Metal kernels for Candle"
|
description = "Metal kernels for Candle"
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
use candle::{Device, Result, Tensor};
|
use candle::{Result, Tensor};
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Cache {
|
pub struct Cache {
|
||||||
@ -255,56 +255,6 @@ impl RotatingCache {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_mask_abs(&self, size1: usize, size2: usize, device: &Device) -> Result<Tensor> {
|
|
||||||
let context = self.max_seq_len;
|
|
||||||
let mask: Vec<_> = (0..size1)
|
|
||||||
.flat_map(|i| {
|
|
||||||
(0..size2).map(move |j| {
|
|
||||||
u8::from(size1 + j > size2 + i || size1 + j + context < size2 + i)
|
|
||||||
})
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
Tensor::from_slice(&mask, (size1, size2), device)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_mask_rel(&self, size1: usize, size2: usize, device: &Device) -> Result<Tensor> {
|
|
||||||
let context = self.max_seq_len;
|
|
||||||
let upd_offset = (self.offset + size1) % self.max_seq_len;
|
|
||||||
let mask: Vec<_> = (0..size1)
|
|
||||||
.flat_map(|pos_src| {
|
|
||||||
// The absolute position of the elements that will get added to the cache.
|
|
||||||
let pos_src = self.current_seq_len + pos_src;
|
|
||||||
(0..size2).map(move |pos_cache_rel| {
|
|
||||||
// The absolute position of the cache elements after the addition.
|
|
||||||
let pos_cache = self.current_seq_len + size1 + pos_cache_rel - upd_offset;
|
|
||||||
let pos_cache = if pos_cache_rel < upd_offset {
|
|
||||||
pos_cache
|
|
||||||
} else {
|
|
||||||
pos_cache - self.max_seq_len
|
|
||||||
};
|
|
||||||
u8::from(pos_cache > pos_src || pos_cache + context < pos_src)
|
|
||||||
})
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
Tensor::from_slice(&mask, (size1, size2), device)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the attn_mask to be applied *after* adding `seq_len` to the cache.
|
|
||||||
pub fn attn_mask(&self, seq_len: usize, device: &Device) -> Result<Option<Tensor>> {
|
|
||||||
let mask = if seq_len == 1 {
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
let mask = if seq_len < self.max_seq_len {
|
|
||||||
let cache_out_len = (self.current_seq_len + seq_len).min(self.max_seq_len);
|
|
||||||
self.get_mask_rel(seq_len, cache_out_len, device)?
|
|
||||||
} else {
|
|
||||||
self.get_mask_abs(seq_len, seq_len, device)?
|
|
||||||
};
|
|
||||||
Some(mask)
|
|
||||||
};
|
|
||||||
Ok(mask)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@ -358,10 +308,6 @@ impl RotatingKvCache {
|
|||||||
self.k.current_seq_len()
|
self.k.current_seq_len()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn attn_mask(&self, seq_len: usize, device: &Device) -> Result<Option<Tensor>> {
|
|
||||||
self.k.attn_mask(seq_len, device)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn reset(&mut self) {
|
pub fn reset(&mut self) {
|
||||||
self.k.reset();
|
self.k.reset();
|
||||||
self.v.reset();
|
self.v.reset();
|
||||||
|
@ -69,36 +69,13 @@ fn rotating_kv_cache() -> Result<()> {
|
|||||||
assert_eq!(cache.current_seq_len(), 13);
|
assert_eq!(cache.current_seq_len(), 13);
|
||||||
assert_eq!(cache.offset(), 1);
|
assert_eq!(cache.offset(), 1);
|
||||||
|
|
||||||
let mask = cache.attn_mask(2, &Device::Cpu)?.unwrap();
|
|
||||||
assert_eq!(
|
|
||||||
mask.to_vec2::<u8>()?,
|
|
||||||
&[[0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 0, 0]]
|
|
||||||
);
|
|
||||||
let mask = cache.attn_mask(3, &Device::Cpu)?.unwrap();
|
|
||||||
assert_eq!(
|
|
||||||
mask.to_vec2::<u8>()?,
|
|
||||||
&[[0, 0, 1, 1, 0, 0], [0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 0]],
|
|
||||||
);
|
|
||||||
let t = Tensor::new(&[0., 1., 2., 3., 4., 5., 6., 7., 8.], &Device::Cpu)?;
|
let t = Tensor::new(&[0., 1., 2., 3., 4., 5., 6., 7., 8.], &Device::Cpu)?;
|
||||||
let data = cache.append(&t)?;
|
let data = cache.append(&t)?;
|
||||||
assert_eq!(data.to_vec1::<f64>()?, [0., 1., 2., 3., 4., 5., 6., 7., 8.]);
|
assert_eq!(data.to_vec1::<f64>()?, [0., 1., 2., 3., 4., 5., 6., 7., 8.]);
|
||||||
assert_eq!(cache.current_seq_len(), 22);
|
assert_eq!(cache.current_seq_len(), 22);
|
||||||
assert_eq!(cache.offset(), 0);
|
assert_eq!(cache.offset(), 0);
|
||||||
|
|
||||||
let mask = cache.attn_mask(1, &Device::Cpu)?;
|
|
||||||
assert!(mask.is_none());
|
|
||||||
let mask = cache.attn_mask(2, &Device::Cpu)?.unwrap();
|
|
||||||
assert_eq!(
|
|
||||||
mask.to_vec2::<u8>()?,
|
|
||||||
&[[0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]]
|
|
||||||
);
|
|
||||||
let mask = cache.attn_mask(3, &Device::Cpu)?.unwrap();
|
|
||||||
assert_eq!(
|
|
||||||
mask.to_vec2::<u8>()?,
|
|
||||||
&[[0, 1, 1, 0, 0, 0], [0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 0, 0]]
|
|
||||||
);
|
|
||||||
let t = Tensor::new(&[42.], &Device::Cpu)?;
|
let t = Tensor::new(&[42.], &Device::Cpu)?;
|
||||||
|
|
||||||
let data = cache.append(&t)?;
|
let data = cache.append(&t)?;
|
||||||
assert_eq!(data.to_vec1::<f64>()?, [42., 4., 5., 6., 7., 8.]);
|
assert_eq!(data.to_vec1::<f64>()?, [42., 4., 5., 6., 7., 8.]);
|
||||||
assert_eq!(cache.current_seq_len(), 23);
|
assert_eq!(cache.current_seq_len(), 23);
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-onnx"
|
name = "candle-onnx"
|
||||||
version = "0.7.1"
|
version = "0.7.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "ONNX support for Candle"
|
description = "ONNX support for Candle"
|
||||||
@ -10,8 +10,8 @@ categories = ["science"]
|
|||||||
license = "MIT OR Apache-2.0"
|
license = "MIT OR Apache-2.0"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../candle-core", package = "candle-core", version = "0.7.1" }
|
candle = { path = "../candle-core", package = "candle-core", version = "0.7.0" }
|
||||||
candle-nn = { path = "../candle-nn", version = "0.7.1" }
|
candle-nn = { path = "../candle-nn", version = "0.7.0" }
|
||||||
prost = "0.12.1"
|
prost = "0.12.1"
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
|
@ -1,20 +1,3 @@
|
|||||||
use candle::{Result, Tensor};
|
|
||||||
|
|
||||||
pub trait WithForward {
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
fn forward(
|
|
||||||
&self,
|
|
||||||
img: &Tensor,
|
|
||||||
img_ids: &Tensor,
|
|
||||||
txt: &Tensor,
|
|
||||||
txt_ids: &Tensor,
|
|
||||||
timesteps: &Tensor,
|
|
||||||
y: &Tensor,
|
|
||||||
guidance: Option<&Tensor>,
|
|
||||||
) -> Result<Tensor>;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub mod autoencoder;
|
pub mod autoencoder;
|
||||||
pub mod model;
|
pub mod model;
|
||||||
pub mod quantized_model;
|
|
||||||
pub mod sampling;
|
pub mod sampling;
|
||||||
|
@ -109,14 +109,14 @@ fn apply_rope(x: &Tensor, freq_cis: &Tensor) -> Result<Tensor> {
|
|||||||
(fr0.broadcast_mul(&x0)? + fr1.broadcast_mul(&x1)?)?.reshape(dims.to_vec())
|
(fr0.broadcast_mul(&x0)? + fr1.broadcast_mul(&x1)?)?.reshape(dims.to_vec())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
|
fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
|
||||||
let q = apply_rope(q, pe)?.contiguous()?;
|
let q = apply_rope(q, pe)?.contiguous()?;
|
||||||
let k = apply_rope(k, pe)?.contiguous()?;
|
let k = apply_rope(k, pe)?.contiguous()?;
|
||||||
let x = scaled_dot_product_attention(&q, &k, v)?;
|
let x = scaled_dot_product_attention(&q, &k, v)?;
|
||||||
x.transpose(1, 2)?.flatten_from(2)
|
x.transpose(1, 2)?.flatten_from(2)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
|
fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
|
||||||
const TIME_FACTOR: f64 = 1000.;
|
const TIME_FACTOR: f64 = 1000.;
|
||||||
const MAX_PERIOD: f64 = 10000.;
|
const MAX_PERIOD: f64 = 10000.;
|
||||||
if dim % 2 == 1 {
|
if dim % 2 == 1 {
|
||||||
@ -144,7 +144,7 @@ pub struct EmbedNd {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl EmbedNd {
|
impl EmbedNd {
|
||||||
pub fn new(dim: usize, theta: usize, axes_dim: Vec<usize>) -> Self {
|
fn new(dim: usize, theta: usize, axes_dim: Vec<usize>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
dim,
|
dim,
|
||||||
theta,
|
theta,
|
||||||
@ -575,11 +575,9 @@ impl Flux {
|
|||||||
final_layer,
|
final_layer,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
impl super::WithForward for Flux {
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
fn forward(
|
pub fn forward(
|
||||||
&self,
|
&self,
|
||||||
img: &Tensor,
|
img: &Tensor,
|
||||||
img_ids: &Tensor,
|
img_ids: &Tensor,
|
||||||
|
@ -1,465 +0,0 @@
|
|||||||
use super::model::{attention, timestep_embedding, Config, EmbedNd};
|
|
||||||
use crate::quantized_nn::{linear, linear_b, Linear};
|
|
||||||
use crate::quantized_var_builder::VarBuilder;
|
|
||||||
use candle::{DType, IndexOp, Result, Tensor, D};
|
|
||||||
use candle_nn::{LayerNorm, RmsNorm};
|
|
||||||
|
|
||||||
fn layer_norm(dim: usize, vb: VarBuilder) -> Result<LayerNorm> {
|
|
||||||
let ws = Tensor::ones(dim, DType::F32, vb.device())?;
|
|
||||||
Ok(LayerNorm::new_no_bias(ws, 1e-6))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct MlpEmbedder {
|
|
||||||
in_layer: Linear,
|
|
||||||
out_layer: Linear,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl MlpEmbedder {
|
|
||||||
fn new(in_sz: usize, h_sz: usize, vb: VarBuilder) -> Result<Self> {
|
|
||||||
let in_layer = linear(in_sz, h_sz, vb.pp("in_layer"))?;
|
|
||||||
let out_layer = linear(h_sz, h_sz, vb.pp("out_layer"))?;
|
|
||||||
Ok(Self {
|
|
||||||
in_layer,
|
|
||||||
out_layer,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl candle::Module for MlpEmbedder {
|
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
|
||||||
xs.apply(&self.in_layer)?.silu()?.apply(&self.out_layer)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct QkNorm {
|
|
||||||
query_norm: RmsNorm,
|
|
||||||
key_norm: RmsNorm,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl QkNorm {
|
|
||||||
fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
|
|
||||||
let query_norm = vb.get(dim, "query_norm.scale")?.dequantize(vb.device())?;
|
|
||||||
let query_norm = RmsNorm::new(query_norm, 1e-6);
|
|
||||||
let key_norm = vb.get(dim, "key_norm.scale")?.dequantize(vb.device())?;
|
|
||||||
let key_norm = RmsNorm::new(key_norm, 1e-6);
|
|
||||||
Ok(Self {
|
|
||||||
query_norm,
|
|
||||||
key_norm,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ModulationOut {
|
|
||||||
shift: Tensor,
|
|
||||||
scale: Tensor,
|
|
||||||
gate: Tensor,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ModulationOut {
|
|
||||||
fn scale_shift(&self, xs: &Tensor) -> Result<Tensor> {
|
|
||||||
xs.broadcast_mul(&(&self.scale + 1.)?)?
|
|
||||||
.broadcast_add(&self.shift)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn gate(&self, xs: &Tensor) -> Result<Tensor> {
|
|
||||||
self.gate.broadcast_mul(xs)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
struct Modulation1 {
|
|
||||||
lin: Linear,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Modulation1 {
|
|
||||||
fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
|
|
||||||
let lin = linear(dim, 3 * dim, vb.pp("lin"))?;
|
|
||||||
Ok(Self { lin })
|
|
||||||
}
|
|
||||||
|
|
||||||
fn forward(&self, vec_: &Tensor) -> Result<ModulationOut> {
|
|
||||||
let ys = vec_
|
|
||||||
.silu()?
|
|
||||||
.apply(&self.lin)?
|
|
||||||
.unsqueeze(1)?
|
|
||||||
.chunk(3, D::Minus1)?;
|
|
||||||
if ys.len() != 3 {
|
|
||||||
candle::bail!("unexpected len from chunk {ys:?}")
|
|
||||||
}
|
|
||||||
Ok(ModulationOut {
|
|
||||||
shift: ys[0].clone(),
|
|
||||||
scale: ys[1].clone(),
|
|
||||||
gate: ys[2].clone(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
struct Modulation2 {
|
|
||||||
lin: Linear,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Modulation2 {
|
|
||||||
fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
|
|
||||||
let lin = linear(dim, 6 * dim, vb.pp("lin"))?;
|
|
||||||
Ok(Self { lin })
|
|
||||||
}
|
|
||||||
|
|
||||||
fn forward(&self, vec_: &Tensor) -> Result<(ModulationOut, ModulationOut)> {
|
|
||||||
let ys = vec_
|
|
||||||
.silu()?
|
|
||||||
.apply(&self.lin)?
|
|
||||||
.unsqueeze(1)?
|
|
||||||
.chunk(6, D::Minus1)?;
|
|
||||||
if ys.len() != 6 {
|
|
||||||
candle::bail!("unexpected len from chunk {ys:?}")
|
|
||||||
}
|
|
||||||
let mod1 = ModulationOut {
|
|
||||||
shift: ys[0].clone(),
|
|
||||||
scale: ys[1].clone(),
|
|
||||||
gate: ys[2].clone(),
|
|
||||||
};
|
|
||||||
let mod2 = ModulationOut {
|
|
||||||
shift: ys[3].clone(),
|
|
||||||
scale: ys[4].clone(),
|
|
||||||
gate: ys[5].clone(),
|
|
||||||
};
|
|
||||||
Ok((mod1, mod2))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct SelfAttention {
|
|
||||||
qkv: Linear,
|
|
||||||
norm: QkNorm,
|
|
||||||
proj: Linear,
|
|
||||||
num_heads: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl SelfAttention {
|
|
||||||
fn new(dim: usize, num_heads: usize, qkv_bias: bool, vb: VarBuilder) -> Result<Self> {
|
|
||||||
let head_dim = dim / num_heads;
|
|
||||||
let qkv = linear_b(dim, dim * 3, qkv_bias, vb.pp("qkv"))?;
|
|
||||||
let norm = QkNorm::new(head_dim, vb.pp("norm"))?;
|
|
||||||
let proj = linear(dim, dim, vb.pp("proj"))?;
|
|
||||||
Ok(Self {
|
|
||||||
qkv,
|
|
||||||
norm,
|
|
||||||
proj,
|
|
||||||
num_heads,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn qkv(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
|
|
||||||
let qkv = xs.apply(&self.qkv)?;
|
|
||||||
let (b, l, _khd) = qkv.dims3()?;
|
|
||||||
let qkv = qkv.reshape((b, l, 3, self.num_heads, ()))?;
|
|
||||||
let q = qkv.i((.., .., 0))?.transpose(1, 2)?;
|
|
||||||
let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
|
|
||||||
let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
|
|
||||||
let q = q.apply(&self.norm.query_norm)?;
|
|
||||||
let k = k.apply(&self.norm.key_norm)?;
|
|
||||||
Ok((q, k, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(unused)]
|
|
||||||
fn forward(&self, xs: &Tensor, pe: &Tensor) -> Result<Tensor> {
|
|
||||||
let (q, k, v) = self.qkv(xs)?;
|
|
||||||
attention(&q, &k, &v, pe)?.apply(&self.proj)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
struct Mlp {
|
|
||||||
lin1: Linear,
|
|
||||||
lin2: Linear,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Mlp {
|
|
||||||
fn new(in_sz: usize, mlp_sz: usize, vb: VarBuilder) -> Result<Self> {
|
|
||||||
let lin1 = linear(in_sz, mlp_sz, vb.pp("0"))?;
|
|
||||||
let lin2 = linear(mlp_sz, in_sz, vb.pp("2"))?;
|
|
||||||
Ok(Self { lin1, lin2 })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl candle::Module for Mlp {
|
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
|
||||||
xs.apply(&self.lin1)?.gelu()?.apply(&self.lin2)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct DoubleStreamBlock {
|
|
||||||
img_mod: Modulation2,
|
|
||||||
img_norm1: LayerNorm,
|
|
||||||
img_attn: SelfAttention,
|
|
||||||
img_norm2: LayerNorm,
|
|
||||||
img_mlp: Mlp,
|
|
||||||
txt_mod: Modulation2,
|
|
||||||
txt_norm1: LayerNorm,
|
|
||||||
txt_attn: SelfAttention,
|
|
||||||
txt_norm2: LayerNorm,
|
|
||||||
txt_mlp: Mlp,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl DoubleStreamBlock {
|
|
||||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
|
||||||
let h_sz = cfg.hidden_size;
|
|
||||||
let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize;
|
|
||||||
let img_mod = Modulation2::new(h_sz, vb.pp("img_mod"))?;
|
|
||||||
let img_norm1 = layer_norm(h_sz, vb.pp("img_norm1"))?;
|
|
||||||
let img_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("img_attn"))?;
|
|
||||||
let img_norm2 = layer_norm(h_sz, vb.pp("img_norm2"))?;
|
|
||||||
let img_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("img_mlp"))?;
|
|
||||||
let txt_mod = Modulation2::new(h_sz, vb.pp("txt_mod"))?;
|
|
||||||
let txt_norm1 = layer_norm(h_sz, vb.pp("txt_norm1"))?;
|
|
||||||
let txt_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("txt_attn"))?;
|
|
||||||
let txt_norm2 = layer_norm(h_sz, vb.pp("txt_norm2"))?;
|
|
||||||
let txt_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("txt_mlp"))?;
|
|
||||||
Ok(Self {
|
|
||||||
img_mod,
|
|
||||||
img_norm1,
|
|
||||||
img_attn,
|
|
||||||
img_norm2,
|
|
||||||
img_mlp,
|
|
||||||
txt_mod,
|
|
||||||
txt_norm1,
|
|
||||||
txt_attn,
|
|
||||||
txt_norm2,
|
|
||||||
txt_mlp,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn forward(
|
|
||||||
&self,
|
|
||||||
img: &Tensor,
|
|
||||||
txt: &Tensor,
|
|
||||||
vec_: &Tensor,
|
|
||||||
pe: &Tensor,
|
|
||||||
) -> Result<(Tensor, Tensor)> {
|
|
||||||
let (img_mod1, img_mod2) = self.img_mod.forward(vec_)?; // shift, scale, gate
|
|
||||||
let (txt_mod1, txt_mod2) = self.txt_mod.forward(vec_)?; // shift, scale, gate
|
|
||||||
let img_modulated = img.apply(&self.img_norm1)?;
|
|
||||||
let img_modulated = img_mod1.scale_shift(&img_modulated)?;
|
|
||||||
let (img_q, img_k, img_v) = self.img_attn.qkv(&img_modulated)?;
|
|
||||||
|
|
||||||
let txt_modulated = txt.apply(&self.txt_norm1)?;
|
|
||||||
let txt_modulated = txt_mod1.scale_shift(&txt_modulated)?;
|
|
||||||
let (txt_q, txt_k, txt_v) = self.txt_attn.qkv(&txt_modulated)?;
|
|
||||||
|
|
||||||
let q = Tensor::cat(&[txt_q, img_q], 2)?;
|
|
||||||
let k = Tensor::cat(&[txt_k, img_k], 2)?;
|
|
||||||
let v = Tensor::cat(&[txt_v, img_v], 2)?;
|
|
||||||
|
|
||||||
let attn = attention(&q, &k, &v, pe)?;
|
|
||||||
let txt_attn = attn.narrow(1, 0, txt.dim(1)?)?;
|
|
||||||
let img_attn = attn.narrow(1, txt.dim(1)?, attn.dim(1)? - txt.dim(1)?)?;
|
|
||||||
|
|
||||||
let img = (img + img_mod1.gate(&img_attn.apply(&self.img_attn.proj)?))?;
|
|
||||||
let img = (&img
|
|
||||||
+ img_mod2.gate(
|
|
||||||
&img_mod2
|
|
||||||
.scale_shift(&img.apply(&self.img_norm2)?)?
|
|
||||||
.apply(&self.img_mlp)?,
|
|
||||||
)?)?;
|
|
||||||
|
|
||||||
let txt = (txt + txt_mod1.gate(&txt_attn.apply(&self.txt_attn.proj)?))?;
|
|
||||||
let txt = (&txt
|
|
||||||
+ txt_mod2.gate(
|
|
||||||
&txt_mod2
|
|
||||||
.scale_shift(&txt.apply(&self.txt_norm2)?)?
|
|
||||||
.apply(&self.txt_mlp)?,
|
|
||||||
)?)?;
|
|
||||||
|
|
||||||
Ok((img, txt))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct SingleStreamBlock {
|
|
||||||
linear1: Linear,
|
|
||||||
linear2: Linear,
|
|
||||||
norm: QkNorm,
|
|
||||||
pre_norm: LayerNorm,
|
|
||||||
modulation: Modulation1,
|
|
||||||
h_sz: usize,
|
|
||||||
mlp_sz: usize,
|
|
||||||
num_heads: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl SingleStreamBlock {
|
|
||||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
|
||||||
let h_sz = cfg.hidden_size;
|
|
||||||
let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize;
|
|
||||||
let head_dim = h_sz / cfg.num_heads;
|
|
||||||
let linear1 = linear(h_sz, h_sz * 3 + mlp_sz, vb.pp("linear1"))?;
|
|
||||||
let linear2 = linear(h_sz + mlp_sz, h_sz, vb.pp("linear2"))?;
|
|
||||||
let norm = QkNorm::new(head_dim, vb.pp("norm"))?;
|
|
||||||
let pre_norm = layer_norm(h_sz, vb.pp("pre_norm"))?;
|
|
||||||
let modulation = Modulation1::new(h_sz, vb.pp("modulation"))?;
|
|
||||||
Ok(Self {
|
|
||||||
linear1,
|
|
||||||
linear2,
|
|
||||||
norm,
|
|
||||||
pre_norm,
|
|
||||||
modulation,
|
|
||||||
h_sz,
|
|
||||||
mlp_sz,
|
|
||||||
num_heads: cfg.num_heads,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn forward(&self, xs: &Tensor, vec_: &Tensor, pe: &Tensor) -> Result<Tensor> {
|
|
||||||
let mod_ = self.modulation.forward(vec_)?;
|
|
||||||
let x_mod = mod_.scale_shift(&xs.apply(&self.pre_norm)?)?;
|
|
||||||
let x_mod = x_mod.apply(&self.linear1)?;
|
|
||||||
let qkv = x_mod.narrow(D::Minus1, 0, 3 * self.h_sz)?;
|
|
||||||
let (b, l, _khd) = qkv.dims3()?;
|
|
||||||
let qkv = qkv.reshape((b, l, 3, self.num_heads, ()))?;
|
|
||||||
let q = qkv.i((.., .., 0))?.transpose(1, 2)?;
|
|
||||||
let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
|
|
||||||
let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
|
|
||||||
let mlp = x_mod.narrow(D::Minus1, 3 * self.h_sz, self.mlp_sz)?;
|
|
||||||
let q = q.apply(&self.norm.query_norm)?;
|
|
||||||
let k = k.apply(&self.norm.key_norm)?;
|
|
||||||
let attn = attention(&q, &k, &v, pe)?;
|
|
||||||
let output = Tensor::cat(&[attn, mlp.gelu()?], 2)?.apply(&self.linear2)?;
|
|
||||||
xs + mod_.gate(&output)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct LastLayer {
|
|
||||||
norm_final: LayerNorm,
|
|
||||||
linear: Linear,
|
|
||||||
ada_ln_modulation: Linear,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl LastLayer {
|
|
||||||
fn new(h_sz: usize, p_sz: usize, out_c: usize, vb: VarBuilder) -> Result<Self> {
|
|
||||||
let norm_final = layer_norm(h_sz, vb.pp("norm_final"))?;
|
|
||||||
let linear_ = linear(h_sz, p_sz * p_sz * out_c, vb.pp("linear"))?;
|
|
||||||
let ada_ln_modulation = linear(h_sz, 2 * h_sz, vb.pp("adaLN_modulation.1"))?;
|
|
||||||
Ok(Self {
|
|
||||||
norm_final,
|
|
||||||
linear: linear_,
|
|
||||||
ada_ln_modulation,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn forward(&self, xs: &Tensor, vec: &Tensor) -> Result<Tensor> {
|
|
||||||
let chunks = vec.silu()?.apply(&self.ada_ln_modulation)?.chunk(2, 1)?;
|
|
||||||
let (shift, scale) = (&chunks[0], &chunks[1]);
|
|
||||||
let xs = xs
|
|
||||||
.apply(&self.norm_final)?
|
|
||||||
.broadcast_mul(&(scale.unsqueeze(1)? + 1.0)?)?
|
|
||||||
.broadcast_add(&shift.unsqueeze(1)?)?;
|
|
||||||
xs.apply(&self.linear)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct Flux {
|
|
||||||
img_in: Linear,
|
|
||||||
txt_in: Linear,
|
|
||||||
time_in: MlpEmbedder,
|
|
||||||
vector_in: MlpEmbedder,
|
|
||||||
guidance_in: Option<MlpEmbedder>,
|
|
||||||
pe_embedder: EmbedNd,
|
|
||||||
double_blocks: Vec<DoubleStreamBlock>,
|
|
||||||
single_blocks: Vec<SingleStreamBlock>,
|
|
||||||
final_layer: LastLayer,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Flux {
|
|
||||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
|
||||||
let img_in = linear(cfg.in_channels, cfg.hidden_size, vb.pp("img_in"))?;
|
|
||||||
let txt_in = linear(cfg.context_in_dim, cfg.hidden_size, vb.pp("txt_in"))?;
|
|
||||||
let mut double_blocks = Vec::with_capacity(cfg.depth);
|
|
||||||
let vb_d = vb.pp("double_blocks");
|
|
||||||
for idx in 0..cfg.depth {
|
|
||||||
let db = DoubleStreamBlock::new(cfg, vb_d.pp(idx))?;
|
|
||||||
double_blocks.push(db)
|
|
||||||
}
|
|
||||||
let mut single_blocks = Vec::with_capacity(cfg.depth_single_blocks);
|
|
||||||
let vb_s = vb.pp("single_blocks");
|
|
||||||
for idx in 0..cfg.depth_single_blocks {
|
|
||||||
let sb = SingleStreamBlock::new(cfg, vb_s.pp(idx))?;
|
|
||||||
single_blocks.push(sb)
|
|
||||||
}
|
|
||||||
let time_in = MlpEmbedder::new(256, cfg.hidden_size, vb.pp("time_in"))?;
|
|
||||||
let vector_in = MlpEmbedder::new(cfg.vec_in_dim, cfg.hidden_size, vb.pp("vector_in"))?;
|
|
||||||
let guidance_in = if cfg.guidance_embed {
|
|
||||||
let mlp = MlpEmbedder::new(256, cfg.hidden_size, vb.pp("guidance_in"))?;
|
|
||||||
Some(mlp)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
let final_layer =
|
|
||||||
LastLayer::new(cfg.hidden_size, 1, cfg.in_channels, vb.pp("final_layer"))?;
|
|
||||||
let pe_dim = cfg.hidden_size / cfg.num_heads;
|
|
||||||
let pe_embedder = EmbedNd::new(pe_dim, cfg.theta, cfg.axes_dim.to_vec());
|
|
||||||
Ok(Self {
|
|
||||||
img_in,
|
|
||||||
txt_in,
|
|
||||||
time_in,
|
|
||||||
vector_in,
|
|
||||||
guidance_in,
|
|
||||||
pe_embedder,
|
|
||||||
double_blocks,
|
|
||||||
single_blocks,
|
|
||||||
final_layer,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl super::WithForward for Flux {
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
fn forward(
|
|
||||||
&self,
|
|
||||||
img: &Tensor,
|
|
||||||
img_ids: &Tensor,
|
|
||||||
txt: &Tensor,
|
|
||||||
txt_ids: &Tensor,
|
|
||||||
timesteps: &Tensor,
|
|
||||||
y: &Tensor,
|
|
||||||
guidance: Option<&Tensor>,
|
|
||||||
) -> Result<Tensor> {
|
|
||||||
if txt.rank() != 3 {
|
|
||||||
candle::bail!("unexpected shape for txt {:?}", txt.shape())
|
|
||||||
}
|
|
||||||
if img.rank() != 3 {
|
|
||||||
candle::bail!("unexpected shape for img {:?}", img.shape())
|
|
||||||
}
|
|
||||||
let dtype = img.dtype();
|
|
||||||
let pe = {
|
|
||||||
let ids = Tensor::cat(&[txt_ids, img_ids], 1)?;
|
|
||||||
ids.apply(&self.pe_embedder)?
|
|
||||||
};
|
|
||||||
let mut txt = txt.apply(&self.txt_in)?;
|
|
||||||
let mut img = img.apply(&self.img_in)?;
|
|
||||||
let vec_ = timestep_embedding(timesteps, 256, dtype)?.apply(&self.time_in)?;
|
|
||||||
let vec_ = match (self.guidance_in.as_ref(), guidance) {
|
|
||||||
(Some(g_in), Some(guidance)) => {
|
|
||||||
(vec_ + timestep_embedding(guidance, 256, dtype)?.apply(g_in))?
|
|
||||||
}
|
|
||||||
_ => vec_,
|
|
||||||
};
|
|
||||||
let vec_ = (vec_ + y.apply(&self.vector_in))?;
|
|
||||||
|
|
||||||
// Double blocks
|
|
||||||
for block in self.double_blocks.iter() {
|
|
||||||
(img, txt) = block.forward(&img, &txt, &vec_, &pe)?
|
|
||||||
}
|
|
||||||
// Single blocks
|
|
||||||
let mut img = Tensor::cat(&[&txt, &img], 1)?;
|
|
||||||
for block in self.single_blocks.iter() {
|
|
||||||
img = block.forward(&img, &vec_, &pe)?;
|
|
||||||
}
|
|
||||||
let img = img.i((.., txt.dim(1)?..))?;
|
|
||||||
self.final_layer.forward(&img, &vec_)
|
|
||||||
}
|
|
||||||
}
|
|
@ -92,8 +92,8 @@ pub fn unpack(xs: &Tensor, height: usize, width: usize) -> Result<Tensor> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn denoise<M: super::WithForward>(
|
pub fn denoise(
|
||||||
model: &M,
|
model: &super::model::Flux,
|
||||||
img: &Tensor,
|
img: &Tensor,
|
||||||
img_ids: &Tensor,
|
img_ids: &Tensor,
|
||||||
txt: &Tensor,
|
txt: &Tensor,
|
||||||
|
@ -101,6 +101,21 @@ impl Module for LayerScale {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn get_mask(
|
||||||
|
size1: usize,
|
||||||
|
size2: usize,
|
||||||
|
context: usize,
|
||||||
|
device: &Device,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let mask: Vec<_> = (0..size1)
|
||||||
|
.flat_map(|i| {
|
||||||
|
(0..size2)
|
||||||
|
.map(move |j| u8::from(size1 + j > size2 + i || size1 + j + context < size2 + i))
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
Tensor::from_slice(&mask, (size1, size2), device)
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct StreamingMultiheadAttention {
|
pub struct StreamingMultiheadAttention {
|
||||||
q_proj: Linear,
|
q_proj: Linear,
|
||||||
@ -575,6 +590,7 @@ impl StreamingTransformerLayer {
|
|||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct StreamingTransformer {
|
pub struct StreamingTransformer {
|
||||||
layers: Vec<StreamingTransformerLayer>,
|
layers: Vec<StreamingTransformerLayer>,
|
||||||
|
context: usize,
|
||||||
positional_embedding: PositionalEmbedding,
|
positional_embedding: PositionalEmbedding,
|
||||||
max_period: usize,
|
max_period: usize,
|
||||||
}
|
}
|
||||||
@ -601,6 +617,7 @@ impl StreamingTransformer {
|
|||||||
}
|
}
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
layers,
|
layers,
|
||||||
|
context: cfg.context,
|
||||||
positional_embedding: cfg.positional_embedding,
|
positional_embedding: cfg.positional_embedding,
|
||||||
max_period: cfg.max_period,
|
max_period: cfg.max_period,
|
||||||
})
|
})
|
||||||
@ -612,11 +629,23 @@ impl StreamingTransformer {
|
|||||||
|
|
||||||
pub fn forward_ca(&mut self, xs: &Tensor, ca_src: Option<&Tensor>) -> Result<Tensor> {
|
pub fn forward_ca(&mut self, xs: &Tensor, ca_src: Option<&Tensor>) -> Result<Tensor> {
|
||||||
let (_b, t, c) = xs.dims3()?;
|
let (_b, t, c) = xs.dims3()?;
|
||||||
let pos = self.layers[0].self_attn.kv_cache.current_seq_len();
|
let pos = self.layers[0]
|
||||||
let mask = self.layers[0]
|
|
||||||
.self_attn
|
.self_attn
|
||||||
.kv_cache
|
.kv_cache
|
||||||
.attn_mask(t, xs.device())?;
|
.k_cache()
|
||||||
|
.current_seq_len();
|
||||||
|
let mask = if t == 1 {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
let cache_out_len = if t < self.context {
|
||||||
|
(pos + t).min(self.context)
|
||||||
|
} else {
|
||||||
|
t
|
||||||
|
};
|
||||||
|
// TODO: this is wrong, the mask depends on the kv-cache offset because of its rotating
|
||||||
|
// nature.
|
||||||
|
Some(get_mask(t, cache_out_len, self.context, xs.device())?)
|
||||||
|
};
|
||||||
let mut xs = match self.positional_embedding {
|
let mut xs = match self.positional_embedding {
|
||||||
PositionalEmbedding::Rope | PositionalEmbedding::None => xs.clone(),
|
PositionalEmbedding::Rope | PositionalEmbedding::None => xs.clone(),
|
||||||
PositionalEmbedding::Sin => {
|
PositionalEmbedding::Sin => {
|
||||||
|
Reference in New Issue
Block a user