Compare commits

...

9 Commits
0.7.1 ... 0.7.2

Author SHA1 Message Date
3a3c48b14b Bump the crate version to 0.7.2. (#2517) 2024-09-29 10:56:50 +02:00
261ed65f36 Add the SigLIP model. (#2515)
* Add the SigLIP model.

* Add more to the forward pass of the vision model.

* Complete the forward pass.

* Add the siglip example.

* Fix.

* Another fix.

* Get everything in place.

* Add a readme.
2024-09-28 23:48:00 +02:00
62525e8352 Remove some extra whitelines. (#2513) 2024-09-28 14:41:28 +02:00
2c25754281 Clippy fixes for onnx + fix a broken test. (#2510) 2024-09-26 23:37:59 +02:00
ed48f54b54 Expand split ops (#2505)
* candle-onnx: Add Split and Expand operators, Fix Where Op

Implemented based on https://github.com/onnx/onnx/blob/main/docs/Operators.md
Test cases based on those examples.

TODO: Should add the remaining Split examples as tests
TODO: Add.test case that motivates Where fix

* candle-onnx: Add ReduceSum operator

Implemented based on https://github.com/onnx/onnx/blob/main/docs/Operators.md
Test cases based on those examples.

TODO: Should add the remaining ReduceSum examples as tests

* candle-onnx: Add ReduceL2 operator

Implemented based on https://github.com/onnx/onnx/blob/main/docs/Operators.md
Test cases based on those examples.

TODO: Should add the remaining ReduceSum examples as tests

* candle-onnx: Fix Clip operator empty string as default arg issue

Optional input args may be signified by an empty string. The length of the input array is not enough because non optional args may follow optional ones.

I encountered this when trying to use the ONNX model found at https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2 for example.

The LSTM op has a utility which I factored to be more generally accessible, and I have used it in the ops I have recently created or debugged.

I believe it is likely that this issue may also manifest in other ops, but I didn't want to change anything that I'm not testing.

* fix formatting

* fix small mistake made during refactor
2024-09-26 22:57:55 +02:00
ad8a4c5e5a Add some llama-3.2 examples. (#2508)
* Add some llama-3.2 examples.

* Support tie-word-embeddings for llama.
2024-09-26 21:00:18 +02:00
c3c392f45c Merge pull request #2507 from huggingface/ci-move
move CI/Cuda runner
2024-09-26 18:48:52 +02:00
a0184a4fe4 move CI/Cuda runner 2024-09-26 17:09:26 +02:00
10d47183c0 Quantized version of flux. (#2500)
* Quantized version of flux.

* More generic sampling.

* Hook the quantized model.

* Use the newly minted gguf file.

* Fix for the quantized model.

* Default to avoid the faster cuda kernels.
2024-09-26 10:23:43 +02:00
27 changed files with 2189 additions and 413 deletions

View File

@ -9,7 +9,8 @@ jobs:
concurrency:
group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
runs-on: [single-gpu, nvidia-gpu, t4, ci]
runs-on:
group: aws-g4dn-2xlarge
container:
image: nvidia/cuda:12.3.1-devel-ubuntu22.04
options: --gpus 0

View File

@ -20,7 +20,7 @@ exclude = [
resolver = "2"
[workspace.package]
version = "0.7.1"
version = "0.7.2"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@ -33,14 +33,14 @@ ab_glyph = "0.2.23"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.7.1" }
candle-datasets = { path = "./candle-datasets", version = "0.7.1" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.7.1" }
candle-kernels = { path = "./candle-kernels", version = "0.7.1" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.7.1" }
candle-nn = { path = "./candle-nn", version = "0.7.1" }
candle-onnx = { path = "./candle-onnx", version = "0.7.1" }
candle-transformers = { path = "./candle-transformers", version = "0.7.1" }
candle = { path = "./candle-core", package = "candle-core", version = "0.7.2" }
candle-datasets = { path = "./candle-datasets", version = "0.7.2" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.7.2" }
candle-kernels = { path = "./candle-kernels", version = "0.7.2" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.7.2" }
candle-nn = { path = "./candle-nn", version = "0.7.2" }
candle-onnx = { path = "./candle-onnx", version = "0.7.2" }
candle-transformers = { path = "./candle-transformers", version = "0.7.2" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.12.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }

View File

@ -12,7 +12,6 @@ use candle_nn::{ops::softmax, VarBuilder};
use candle_transformers::models::clip;
use tokenizers::Tokenizer;
use tracing::info;
#[derive(Parser)]
struct Args {
@ -40,15 +39,12 @@ fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::
height as u32,
image::imageops::FilterType::Triangle,
);
let img = img.to_rgb8();
let img = img.into_raw();
let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
.permute((2, 0, 1))?
.to_dtype(DType::F32)?
.affine(2. / 255., -1.)?;
// .unsqueeze(0)?;
Ok(img)
}
@ -57,24 +53,16 @@ fn load_images<T: AsRef<std::path::Path>>(
image_size: usize,
) -> anyhow::Result<Tensor> {
let mut images = vec![];
for path in paths {
let tensor = load_image(path, image_size)?;
images.push(tensor);
}
let images = Tensor::stack(&images, 0)?;
Ok(images)
}
pub fn main() -> anyhow::Result<()> {
// std::env::set_var("RUST_BACKTRACE", "full");
let args = Args::parse();
tracing_subscriber::fmt::init();
let model_file = match args.model {
None => {
let api = hf_hub::api::sync::Api::new()?;
@ -89,13 +77,9 @@ pub fn main() -> anyhow::Result<()> {
}
Some(model) => model.into(),
};
let tokenizer = get_tokenizer(args.tokenizer)?;
let config = clip::ClipConfig::vit_base_patch32();
let device = candle_examples::device(args.cpu)?;
let vec_imgs = match args.images {
Some(imgs) => imgs,
None => vec![
@ -103,43 +87,29 @@ pub fn main() -> anyhow::Result<()> {
"candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
],
};
// let image = load_image(args.image, config.image_size)?.to_device(&device)?;
let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?;
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? };
let model = clip::ClipModel::new(vb, &config)?;
let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;
let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
let softmax_image = softmax(&logits_per_image, 1)?;
let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
info!("softmax_image_vec: {:?}", softmax_image_vec);
println!("softmax_image_vec: {:?}", softmax_image_vec);
let probability_vec = softmax_image_vec
.iter()
.map(|v| v * 100.0)
.collect::<Vec<f32>>();
let probability_per_image = probability_vec.len() / vec_imgs.len();
for (i, img) in vec_imgs.iter().enumerate() {
let start = i * probability_per_image;
let end = start + probability_per_image;
let prob = &probability_vec[start..end];
info!("\n\nResults for image: {}\n", img);
println!("\n\nResults for image: {}\n", img);
for (i, p) in prob.iter().enumerate() {
info!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
println!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
}
}
Ok(())
}
@ -156,7 +126,6 @@ pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
}
Some(file) => file.into(),
};
Tokenizer::from_file(tokenizer).map_err(E::msg)
}
@ -169,7 +138,6 @@ pub fn tokenize_sequences(
.get_vocab(true)
.get("<|endoftext|>")
.ok_or(E::msg("No pad token"))?;
let vec_seq = match sequences {
Some(seq) => seq,
None => vec![
@ -178,16 +146,12 @@ pub fn tokenize_sequences(
"a robot holding a candle".to_string(),
],
};
let mut tokens = vec![];
for seq in vec_seq.clone() {
let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
tokens.push(encoding.get_ids().to_vec());
}
let max_len = tokens.iter().map(|v| v.len()).max().unwrap_or(0);
// Pad the sequences to have the same length
for token_vec in tokens.iter_mut() {
let len_diff = max_len - token_vec.len();
@ -195,8 +159,6 @@ pub fn tokenize_sequences(
token_vec.extend(vec![pad_id; len_diff]);
}
}
let input_ids = Tensor::new(tokens, device)?;
Ok((input_ids, vec_seq))
}

View File

@ -13,7 +13,7 @@ descriptions,
```bash
cargo run --features cuda --example flux -r -- \
--height 1024 --width 1024
--height 1024 --width 1024 \
--prompt "a rusty robot walking on a beach holding a small torch, the robot has the word "rust" written on it, high quality, 4k"
```

View File

@ -23,6 +23,10 @@ struct Args {
#[arg(long)]
cpu: bool,
/// Use the quantized model.
#[arg(long)]
quantized: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
@ -40,6 +44,10 @@ struct Args {
#[arg(long, value_enum, default_value = "schnell")]
model: Model,
/// Use the faster kernels which are buggy at the moment.
#[arg(long)]
no_dmmv: bool,
}
#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)]
@ -60,6 +68,8 @@ fn run(args: Args) -> Result<()> {
tracing,
decode_only,
model,
quantized,
..
} = args;
let width = width.unwrap_or(1360);
let height = height.unwrap_or(768);
@ -146,38 +156,71 @@ fn run(args: Args) -> Result<()> {
};
println!("CLIP\n{clip_emb}");
let img = {
let model_file = match model {
Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?,
Model::Dev => bf_repo.get("flux1-dev.safetensors")?,
};
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
let cfg = match model {
Model::Dev => flux::model::Config::dev(),
Model::Schnell => flux::model::Config::schnell(),
};
let img = flux::sampling::get_noise(1, height, width, &device)?.to_dtype(dtype)?;
let state = flux::sampling::State::new(&t5_emb, &clip_emb, &img)?;
let state = if quantized {
flux::sampling::State::new(
&t5_emb.to_dtype(candle::DType::F32)?,
&clip_emb.to_dtype(candle::DType::F32)?,
&img.to_dtype(candle::DType::F32)?,
)?
} else {
flux::sampling::State::new(&t5_emb, &clip_emb, &img)?
};
let timesteps = match model {
Model::Dev => {
flux::sampling::get_schedule(50, Some((state.img.dim(1)?, 0.5, 1.15)))
}
Model::Schnell => flux::sampling::get_schedule(4, None),
};
let model = flux::model::Flux::new(&cfg, vb)?;
println!("{state:?}");
println!("{timesteps:?}");
flux::sampling::denoise(
&model,
&state.img,
&state.img_ids,
&state.txt,
&state.txt_ids,
&state.vec,
&timesteps,
4.,
)?
if quantized {
let model_file = match model {
Model::Schnell => api
.repo(hf_hub::Repo::model("lmz/candle-flux".to_string()))
.get("flux1-schnell.gguf")?,
Model::Dev => todo!(),
};
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
model_file, &device,
)?;
let model = flux::quantized_model::Flux::new(&cfg, vb)?;
flux::sampling::denoise(
&model,
&state.img,
&state.img_ids,
&state.txt,
&state.txt_ids,
&state.vec,
&timesteps,
4.,
)?
.to_dtype(dtype)?
} else {
let model_file = match model {
Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?,
Model::Dev => bf_repo.get("flux1-dev.safetensors")?,
};
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)?
};
let model = flux::model::Flux::new(&cfg, vb)?;
flux::sampling::denoise(
&model,
&state.img,
&state.img_ids,
&state.txt,
&state.txt_ids,
&state.vec,
&timesteps,
4.,
)?
}
};
flux::sampling::unpack(&img, height, width)?
}
@ -206,5 +249,7 @@ fn run(args: Args) -> Result<()> {
fn main() -> Result<()> {
let args = Args::parse();
#[cfg(feature = "cuda")]
candle::quantized::cuda::set_force_dmmv(!args.no_dmmv);
run(args)
}

View File

@ -35,6 +35,10 @@ enum Which {
V31,
V3Instruct,
V31Instruct,
V32_1b,
V32_1bInstruct,
V32_3b,
V32_3bInstruct,
#[value(name = "solar-10.7b")]
Solar10_7B,
#[value(name = "tiny-llama-1.1b-chat")]
@ -137,6 +141,10 @@ fn main() -> Result<()> {
Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(),
Which::V31 => "meta-llama/Meta-Llama-3.1-8B".to_string(),
Which::V31Instruct => "meta-llama/Meta-Llama-3.1-8B-Instruct".to_string(),
Which::V32_1b => "meta-llama/Llama-3.2-1B".to_string(),
Which::V32_1bInstruct => "meta-llama/Llama-3.2-1B-Instruct".to_string(),
Which::V32_3b => "meta-llama/Llama-3.2-3B".to_string(),
Which::V32_3bInstruct => "meta-llama/Llama-3.2-3B-Instruct".to_string(),
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
});
@ -156,10 +164,14 @@ fn main() -> Result<()> {
| Which::V3Instruct
| Which::V31
| Which::V31Instruct
| Which::V32_3b
| Which::V32_3bInstruct
| Which::Solar10_7B => {
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
}
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
Which::V32_1b | Which::V32_1bInstruct | Which::TinyLlama1_1BChat => {
vec![api.get("model.safetensors")?]
}
};
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;

View File

@ -60,7 +60,6 @@ fn load_images<T: AsRef<std::path::Path>>(
image_size: usize,
) -> anyhow::Result<Tensor> {
let mut images = vec![];
for path in paths {
let tensor = candle_examples::imagenet::load_image_with_std_mean(
path,
@ -70,9 +69,7 @@ fn load_images<T: AsRef<std::path::Path>>(
)?;
images.push(tensor);
}
let images = Tensor::stack(&images, 0)?;
Ok(images)
}
@ -80,24 +77,17 @@ pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let model_name = args.which.model_name();
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(model_name);
let model_file = if args.use_pth {
api.get("open_clip_pytorch_model.bin")?
} else {
api.get("open_clip_model.safetensors")?
};
let tokenizer = api.get("tokenizer.json")?;
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let config = &args.which.config();
let device = candle_examples::device(args.cpu)?;
let vec_imgs = match args.images {
Some(imgs) => imgs,
None => vec![
@ -105,9 +95,7 @@ pub fn main() -> anyhow::Result<()> {
"candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
],
};
let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?;
let vb = if args.use_pth {
VarBuilder::from_pth(&model_file, DType::F32, &device)?
} else {
@ -115,22 +103,15 @@ pub fn main() -> anyhow::Result<()> {
};
let model = mobileclip::MobileClipModel::new(vb, config)?;
let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;
let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
let softmax_image = softmax(&logits_per_image, 1)?;
let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
println!("softmax_image_vec: {:?}", softmax_image_vec);
let probability_vec = softmax_image_vec
.iter()
.map(|v| v * 100.0)
.collect::<Vec<f32>>();
let probability_per_image = probability_vec.len() / vec_imgs.len();
for (i, img) in vec_imgs.iter().enumerate() {
@ -171,7 +152,6 @@ pub fn tokenize_sequences(
};
let mut tokens = vec![];
for seq in vec_seq.clone() {
let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
tokens.push(encoding.get_ids().to_vec());
@ -185,8 +165,6 @@ pub fn tokenize_sequences(
token_vec.extend(vec![pad_id; len_diff]);
}
}
let input_ids = Tensor::new(tokens, device)?;
Ok((input_ids, vec_seq))
}

View File

@ -0,0 +1,24 @@
## SigLIP
SigLIP is multi-modal text-vision model that improves over CLIP by using a sigmoid based loss,
[HuggingFace](https://huggingface.co/google/siglip-base-patch16-224).
### Running an example
```
$ cargo run --features cuda -r --example siglip -
softmax_image_vec: [2.1912122e-14, 2.3624872e-14, 1.0, 1.0, 2.4787932e-8, 3.2784535e-12]
Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
Probability: 0.0000% Text: a cycling race
Probability: 0.0000% Text: a photo of two cats
Probability: 100.0000% Text: a robot holding a candle
Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
Probability: 100.0000% Text: a cycling race
Probability: 0.0000% Text: a photo of two cats
Probability: 0.0000% Text: a robot holding a candle
```

View File

@ -0,0 +1,153 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Error as E;
use clap::Parser;
use candle::{DType, Device, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use candle_transformers::models::siglip;
use tokenizers::Tokenizer;
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
tokenizer: Option<String>,
#[arg(long, use_value_delimiter = true)]
images: Option<Vec<String>>,
#[arg(long)]
cpu: bool,
#[arg(long, use_value_delimiter = true)]
sequences: Option<Vec<String>>,
}
fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::Result<Tensor> {
let img = image::ImageReader::open(path)?.decode()?;
let (height, width) = (image_size, image_size);
let img = img.resize_to_fill(
width as u32,
height as u32,
image::imageops::FilterType::Triangle,
);
let img = img.to_rgb8();
let img = img.into_raw();
let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
.permute((2, 0, 1))?
.to_dtype(DType::F32)?
.affine(2. / 255., -1.)?;
Ok(img)
}
fn load_images<T: AsRef<std::path::Path>>(
paths: &Vec<T>,
image_size: usize,
) -> anyhow::Result<Tensor> {
let mut images = vec![];
for path in paths {
let tensor = load_image(path, image_size)?;
images.push(tensor);
}
let images = Tensor::stack(&images, 0)?;
Ok(images)
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let model_file = match args.model {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("google/siglip-base-patch16-224".to_string());
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
let tokenizer = get_tokenizer(args.tokenizer)?;
let config = siglip::Config::base_patch16_224();
let device = candle_examples::device(args.cpu)?;
let vec_imgs = match args.images {
Some(imgs) => imgs,
None => vec![
"candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(),
"candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
],
};
let images = load_images(&vec_imgs, config.vision_config.image_size)?.to_device(&device)?;
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? };
let model = siglip::Model::new(&config, vb)?;
let (input_ids, vec_seq) = tokenize_sequences(&config, args.sequences, &tokenizer, &device)?;
let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
let softmax_image = softmax(&logits_per_image, 1)?;
let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
println!("softmax_image_vec: {:?}", softmax_image_vec);
let probability_vec = softmax_image_vec
.iter()
.map(|v| v * 100.0)
.collect::<Vec<f32>>();
let probability_per_image = probability_vec.len() / vec_imgs.len();
for (i, img) in vec_imgs.iter().enumerate() {
let start = i * probability_per_image;
let end = start + probability_per_image;
let prob = &probability_vec[start..end];
println!("\n\nResults for image: {}\n", img);
for (i, p) in prob.iter().enumerate() {
println!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
}
}
Ok(())
}
pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
let tokenizer = match tokenizer {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("google/siglip-base-patch16-224".to_string());
api.get("tokenizer.json")?
}
Some(file) => file.into(),
};
Tokenizer::from_file(tokenizer).map_err(E::msg)
}
pub fn tokenize_sequences(
config: &siglip::Config,
sequences: Option<Vec<String>>,
tokenizer: &Tokenizer,
device: &Device,
) -> anyhow::Result<(Tensor, Vec<String>)> {
let pad_id = config.text_config.pad_token_id;
let vec_seq = match sequences {
Some(seq) => seq,
None => vec![
"a cycling race".to_string(),
"a photo of two cats".to_string(),
"a robot holding a candle".to_string(),
],
};
let mut tokens = vec![];
for seq in vec_seq.clone() {
let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
tokens.push(encoding.get_ids().to_vec());
}
let max_len = config.text_config.max_position_embeddings;
// Pad the sequences to have the same length
for token_vec in tokens.iter_mut() {
let len_diff = max_len - token_vec.len();
if len_diff > 0 {
token_vec.extend(vec![pad_id; len_diff]);
}
}
let input_ids = Tensor::new(tokens, device)?;
Ok((input_ids, vec_seq))
}

View File

@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.7.1"
version = "0.7.2"
edition = "2021"
description = "Flash attention layer for the candle ML framework."
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"
[dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.7.1" }
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.7.2" }
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]

View File

@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.7.1"
version = "0.7.2"
edition = "2021"
description = "CUDA kernels for Candle"

View File

@ -1,6 +1,6 @@
[package]
name = "candle-metal-kernels"
version = "0.7.1"
version = "0.7.2"
edition = "2021"
description = "Metal kernels for Candle"

View File

@ -1,6 +1,6 @@
[package]
name = "candle-onnx"
version = "0.7.1"
version = "0.7.2"
edition = "2021"
description = "ONNX support for Candle"
@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
candle = { path = "../candle-core", package = "candle-core", version = "0.7.1" }
candle-nn = { path = "../candle-nn", version = "0.7.1" }
candle = { path = "../candle-core", package = "candle-core", version = "0.7.2" }
candle-nn = { path = "../candle-nn", version = "0.7.2" }
prost = "0.12.1"
[build-dependencies]

View File

@ -2,7 +2,7 @@ use crate::onnx::attribute_proto::AttributeType;
use crate::onnx::tensor_proto::DataType;
use crate::onnx::{self, GraphProto};
use candle::{bail, DType, Device, Result, Tensor};
use std::{collections::HashMap, usize};
use std::collections::HashMap;
pub type Value = Tensor;
@ -321,8 +321,15 @@ fn simple_eval_(
for node in graph.node.iter() {
let get = |input_name: &str| match values.get(input_name) {
Some(value) => Ok(value),
None => bail!("cannot find {input_name} for op {}", node.name),
None => bail!("cannot find {input_name} for op '{}'", node.name),
};
let get_opt = |i: usize| {
node.input
.get(i)
.filter(|s: &&String| !s.is_empty())
.map(|s| get(s))
};
// TODO: Validate node.input for each operator.
match node.op_type.as_str() {
"Add" => {
@ -355,7 +362,7 @@ fn simple_eval_(
// HACK: current implementation of broadcast_pow cannot handle negative base,
// so we use powf where we can, which *does* correctly handle negative base.
if let Ok(exp) = (|| input1.to_dtype(DType::F64)?.to_scalar::<f64>())() {
let output = input0.powf(exp as f64)?;
let output = input0.powf(exp)?;
values.insert(node.output[0].clone(), output);
} else {
let output = input0.broadcast_pow(input1)?;
@ -608,15 +615,13 @@ fn simple_eval_(
}
"Clip" => {
let xs = get(&node.input[0])?;
let xs = if node.input.len() >= 2 {
let mins = get(&node.input[1])?;
xs.broadcast_maximum(mins)?
let xs = if let Some(mins) = get_opt(1) {
xs.broadcast_maximum(mins?)?
} else {
xs.clone()
};
let xs = if node.input.len() >= 3 {
let maxs = get(&node.input[2])?;
xs.broadcast_minimum(maxs)?
let xs = if let Some(maxs) = get_opt(2) {
xs.broadcast_minimum(maxs?)?
} else {
xs.clone()
};
@ -638,7 +643,7 @@ fn simple_eval_(
let mask = indices.lt(&zeros)?;
mask.to_dtype(indices.dtype())?
.broadcast_mul(&max)?
.add(&indices)?
.add(indices)?
};
// In Pytorch or Numpy this can be done by indexing the xs tensor using the indices
@ -759,7 +764,14 @@ fn simple_eval_(
let cond = get(&node.input[0])?;
let a = get(&node.input[1])?;
let b = get(&node.input[2])?;
let output = cond.where_cond(a, b)?;
// where_cond requires that all inputs are the same shape.
// In contrast, the Where op in ONNX only requires that they are broadcastable.
let shape = broadcast_shape_from_many(&[cond.dims(), a.dims(), b.dims()])?;
let cond = cond.broadcast_as(shape.clone())?;
let a = a.broadcast_as(shape.clone())?;
let b = b.broadcast_as(shape)?;
let output = cond.where_cond(&a, &b)?;
values.insert(node.output[0].clone(), output);
}
"Conv" => {
@ -962,6 +974,7 @@ fn simple_eval_(
}
rtype => bail!("unsupported 'value' type {rtype:?} for {}", node.name),
};
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast
@ -1199,6 +1212,151 @@ fn simple_eval_(
};
values.insert(node.output[0].clone(), output);
}
//https://github.com/onnx/onnx/blob/main/docs/Operators.md#Split
// Version 18 impl
"Split" => {
let input_tensor = get(&node.input[0])?;
let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
let axis = input_tensor.normalize_axis(axis)?;
// Determine split sizes
let splits = if node.input.len() > 1 {
// If the split tensor is provided, use it to determine sizes
let split_tensor = get(&node.input[1])?.to_vec1::<i64>()?;
split_tensor.iter().map(|&x| x as usize).collect::<Vec<_>>()
} else {
let num_outputs = if let Some(&num_outputs_attrib) =
get_attr_opt::<i64>(node, "num_outputs")?
{
num_outputs_attrib as usize
} else {
node.output.len()
};
let input_dim = input_tensor.dim(axis)?;
let mut split_sizes =
vec![input_dim / num_outputs as usize; num_outputs as usize];
let remainder = input_dim % num_outputs as usize;
if remainder > 0 {
// If there's a remainder, add it to the last split size
split_sizes[num_outputs as usize - 1] += remainder;
}
split_sizes
};
// Perform the split operation
let mut outputs = vec![];
let mut start = 0;
for &size in &splits {
let end = start + size;
let slice = input_tensor.narrow(axis, start, size)?;
outputs.push(slice);
start = end;
}
// Insert the split outputs into the values map
for (output, slice) in node.output.iter().zip(outputs.into_iter()) {
values.insert(output.clone(), slice);
}
}
//https://github.com/onnx/onnx/blob/main/docs/Operators.md#Expand
// Version 13 impl
"Expand" => {
// unlike broadcast_to, expand allows for the output shape to
// be different from the specified shape.
let input_tensor = get(&node.input[0])?;
let input_shape = get(&node.input[1])?;
// Check that the shape tensor is 1D
if input_shape.rank() != 1 {
bail!(
"Expand expects 'shape' input to be 1D tensor: {:?}",
input_shape
);
}
let input_tensor_dims = input_tensor.dims();
let input_shape_dims = input_shape
.to_vec1::<i64>()?
.into_iter()
.map(|x| x as usize)
.collect::<Vec<_>>();
let target_shape = broadcast_shape(input_tensor_dims, input_shape_dims.as_slice())?;
let expanded_tensor = input_tensor.broadcast_as(target_shape)?;
values.insert(node.output[0].clone(), expanded_tensor);
}
//https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceSum
// Version 13 impl
"ReduceSum" => {
let input = get(&node.input[0])?;
let axes = get_opt(1);
let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1);
let noop_with_empty_axes = get_attr_opt::<i64>(node, "noop_with_empty_axes")?
.copied()
.unwrap_or(0);
let axes = match axes {
Some(Ok(axes)) => axes
.to_vec1::<i64>()?
.into_iter()
.map(|x| x as usize)
.collect::<Vec<_>>(),
Some(Err(_)) | None => {
if noop_with_empty_axes == 1 {
vec![]
} else {
(0..input.rank()).collect()
}
}
};
let output = if keepdims == 1 {
input.sum_keepdim(axes)?
} else {
input.sum(axes)?
};
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceL2
// Version 18 impl
"ReduceL2" => {
let input = get(&node.input[0])?;
let axes = get_opt(1);
let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1);
let noop_with_empty_axes = get_attr_opt::<i64>(node, "noop_with_empty_axes")?
.copied()
.unwrap_or(0);
let input_sq = input.sqr()?;
let axes = match axes {
Some(axes) => axes?
.to_vec1::<i64>()?
.into_iter()
.map(|x| x as usize)
.collect::<Vec<_>>(),
None => {
if noop_with_empty_axes == 1 {
vec![]
} else {
(0..input_sq.rank()).collect()
}
}
};
let output = if keepdims == 1 {
input_sq.sum_keepdim(axes)?.sqrt()?
} else {
input_sq.sum(axes)?.sqrt()?
};
values.insert(node.output[0].clone(), output);
}
random_type @ ("RandomUniform" | "RandomNormal") => {
let dt: i64 = get_attr_opt(node, "dtype")?.copied().unwrap_or(1); // 1 is float
// type by
@ -1395,13 +1553,6 @@ fn simple_eval_(
// This tensor has shape `[num_directions, 4*hidden_size, hidden_size]`.
let r = get(&node.input[2])?;
let get_opt = |i: usize| {
node.input
.get(i)
.filter(|s: &&String| !s.is_empty())
.map(|s| get(s))
};
// The bias tensor for input gate.
// Concatenation of `[Wb[iofc], Rb[iofc]]`, and `[WBb[iofc], RBb[iofc]]` (if bidirectional) along dimension 0.
// This tensor has shape `[num_directions, 8*hidden_size]`.
@ -1488,7 +1639,7 @@ fn simple_eval_(
let w = w.get(0)?; // w[iofc] has shape [4*hidden_size, input_size]
let r = r.get(0)?; // r[iofc] has shape [4*hidden_size, hidden_size]
let b = b.get(0)?; // concat of [wb[iofc],rb[iofc]] has shape [8*hidden_size]
let idx_wb = Tensor::arange(0 * hidden_size, 4 * hidden_size, x.device())?;
let idx_wb = Tensor::arange(0, 4 * hidden_size, x.device())?;
let idx_rb = Tensor::arange(4 * hidden_size, 8 * hidden_size, x.device())?;
let wb = b.index_select(&idx_wb, 0)?;
let rb = b.index_select(&idx_rb, 0)?;
@ -1497,8 +1648,8 @@ fn simple_eval_(
// w, r, wb, rb are all iofc but lstm expects ifco
// so we need to move some stuff around
let idx_i = Tensor::arange(0 * hidden_size, 1 * hidden_size, x.device())?;
let idx_o = Tensor::arange(1 * hidden_size, 2 * hidden_size, x.device())?;
let idx_i = Tensor::arange(0, hidden_size, x.device())?;
let idx_o = Tensor::arange(hidden_size, 2 * hidden_size, x.device())?;
let idx_f = Tensor::arange(2 * hidden_size, 3 * hidden_size, x.device())?;
let idx_c = Tensor::arange(3 * hidden_size, 4 * hidden_size, x.device())?;
let idx_ifco = Tensor::cat(&[&idx_i, &idx_f, &idx_c, &idx_o], 0)?;
@ -1522,7 +1673,7 @@ fn simple_eval_(
)?;
let mut lstm_state = candle_nn::rnn::LSTMState::new(h, c);
let mut h_acc = if node.output.get(0).map(String::as_str).unwrap_or("") != "" {
let mut h_acc = if node.output.first().map(String::as_str).unwrap_or("") != "" {
Some(vec![])
} else {
None
@ -1536,7 +1687,7 @@ fn simple_eval_(
}
assert_eq!(num_directions, 1, "if support for bidirectional is ever added, outputs will have to be concatenated, not simply reshaped");
if let Some(name) = node.output.get(0) {
if let Some(name) = node.output.first() {
let h_acc = h_acc.as_ref().unwrap();
let h_acc = lstm.states_to_tensor(h_acc)?;
let h_acc = h_acc.reshape((
@ -1580,3 +1731,36 @@ fn simple_eval_(
})
.collect()
}
fn broadcast_shape(shape_a: &[usize], shape_b: &[usize]) -> Result<Vec<usize>> {
let (longest, shortest) = if shape_a.len() > shape_b.len() {
(shape_a, shape_b)
} else {
(shape_b, shape_a)
};
let diff = longest.len() - shortest.len();
let mut target_shape = longest[0..diff].to_vec();
for (dim1, dim2) in longest[diff..].iter().zip(shortest.iter()) {
if *dim1 == *dim2 || *dim2 == 1 || *dim1 == 1 {
target_shape.push(usize::max(*dim1, *dim2));
} else {
bail!(
"Expand: incompatible shapes for broadcast, {:?} and {:?}",
shape_a,
shape_b
);
}
}
Ok(target_shape)
}
fn broadcast_shape_from_many(shapes: &[&[usize]]) -> Result<Vec<usize>> {
if shapes.is_empty() {
return Ok(Vec::new());
}
let mut shape_out = shapes[0].to_vec();
for shape in shapes[1..].iter() {
shape_out = broadcast_shape(&shape_out, shape)?;
}
Ok(shape_out)
}

View File

@ -1,12 +1,5 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle::test_utils::to_vec2_round;
use candle::{DType, Device, NdArray, Result, Tensor};
use candle_onnx::eval::Value;
use candle_onnx::onnx::attribute_proto::AttributeType;
use candle_onnx::onnx::tensor_proto::DataType;
use candle_onnx::onnx::tensor_shape_proto::{dimension, Dimension};
@ -3574,312 +3567,312 @@ fn test_lstm() -> Result<()> {
let number_directions = 1;
let weight_ih_l0 = Tensor::from_vec::<_, f32>(
vec![
-1.5255959033966064,
-0.7502318024635315,
-0.6539809107780457,
-1.6094847917556763,
-0.1001671776175499,
-0.6091889142990112,
-0.9797722697257996,
-1.6090962886810303,
-0.7121446132659912,
0.30372199416160583,
-0.777314305305481,
-0.25145524740219116,
-0.22227048873901367,
1.6871134042739868,
0.22842517495155334,
0.46763551235198975,
-0.6969724297523499,
-1.1607614755630493,
0.6995424032211304,
0.1990816295146942,
0.8656923770904541,
0.2444038987159729,
-0.6629113554954529,
0.8073082566261292,
1.1016806364059448,
-0.1759360432624817,
-2.2455577850341797,
-1.4464579820632935,
0.0611552819609642,
-0.6177444458007812,
-0.7980698347091675,
-0.13162320852279663,
1.8793457746505737,
-0.07213178277015686,
0.15777060389518738,
-0.7734549045562744,
0.1990565061569214,
0.04570277780294418,
0.15295691788196564,
-0.47567880153656006,
-0.11101982742547989,
0.2927352488040924,
-0.1578451544046402,
-0.028787139803171158,
0.4532545804977417,
1.1421611309051514,
0.2486107051372528,
-1.7754007577896118,
-0.025502461940050125,
-1.023330569267273,
-0.5961851477622986,
-1.0055307149887085,
0.42854228615760803,
1.4760777950286865,
-1.7868678569793701,
1.610317587852478,
-0.703956663608551,
-0.18526579439640045,
-0.9962350726127625,
-0.8312552571296692,
-1.525_595_9,
-0.750_231_8,
-0.653_980_9,
-1.609_484_8,
-0.100_167_18,
-0.609_188_9,
-0.979_772_27,
-1.609_096_3,
-0.712_144_6,
0.303_722,
-0.777_314_3,
-0.251_455_25,
-0.222_270_49,
1.687_113_4,
0.228_425_17,
0.467_635_5,
-0.696_972_4,
-1.160_761_5,
0.699_542_4,
0.199_081_63,
0.865_692_4,
0.244_403_9,
-0.662_911_36,
0.807_308_26,
1.101_680_6,
-0.175_936_04,
-2.245_557_8,
-1.446_458,
0.061_155_282,
-0.617_744_45,
-0.798_069_83,
-0.131_623_21,
1.879_345_8,
-0.072_131_78,
0.157_770_6,
-0.773_454_9,
0.199_056_5,
0.045_702_778,
0.152_956_92,
-0.475_678_8,
-0.111_019_83,
0.292_735_25,
-0.157_845_15,
-0.028_787_14,
0.453_254_58,
1.142_161_1,
0.248_610_7,
-1.775_400_8,
-0.025_502_462,
-1.023_330_6,
-0.596_185_15,
-1.005_530_7,
0.428_542_3,
1.476_077_8,
-1.786_867_9,
1.610_317_6,
-0.703_956_66,
-0.185_265_8,
-0.996_235_1,
-0.831_255_26,
],
(20, 3),
&Device::Cpu,
)?;
let weight_hh_l0 = Tensor::from_vec::<_, f32>(
vec![
0.4099724292755127,
0.4084506630897522,
0.25786539912223816,
1.095021367073059,
-0.5064865946769714,
0.09977540373802185,
-0.653973400592804,
0.731693685054779,
-1.456732988357544,
1.6089353561401367,
0.09376997500658035,
-1.2597490549087524,
0.25463348627090454,
-0.5019572973251343,
-1.041200041770935,
0.7322672009468079,
1.3075355291366577,
-1.1627987623214722,
0.11963611096143723,
-0.1631353348493576,
0.6614453196525574,
1.1899205446243286,
0.8165339231491089,
-0.9135236144065857,
-0.3538065254688263,
0.7639270424842834,
-0.5889506936073303,
-0.7635973691940308,
1.3352056741714478,
0.6042736172676086,
-0.10344208031892776,
-0.15121692419052124,
1.2465683221817017,
0.505721390247345,
0.9505112171173096,
1.2966482639312744,
0.873796284198761,
-0.5602594017982483,
1.2857844829559326,
0.8168238401412964,
-1.464799404144287,
-1.2629283666610718,
1.122018814086914,
1.5663341283798218,
2.558138370513916,
-0.23336388170719147,
-0.013472129590809345,
1.8606348037719727,
1.549620509147644,
0.34762924909591675,
0.09300802648067474,
0.6147403120994568,
0.7123645544052124,
-1.7765072584152222,
0.3538645803928375,
1.1996132135391235,
-0.7122589349746704,
-0.620034396648407,
-0.22813494503498077,
-0.7892746329307556,
-1.6111117601394653,
-1.8716129064559937,
0.5430836081504822,
0.6606786251068115,
0.270527720451355,
0.5596919655799866,
-0.31839630007743835,
1.5117206573486328,
-1.363267183303833,
-0.9832196235656738,
1.5112667083740234,
0.6418707370758057,
-0.7474458813667297,
-0.923438549041748,
0.5733984112739563,
-0.10929951071739197,
0.5181121230125427,
0.10653535276651382,
0.26924076676368713,
1.3247679471969604,
0.037456899881362915,
-0.6378393173217773,
-0.8147554397583008,
-0.6895065307617188,
0.8436542749404907,
1.1657012701034546,
0.5269321799278259,
1.6192532777786255,
-0.963976263999939,
0.14152038097381592,
-0.1636609584093094,
-0.3582225739955902,
1.7222793102264404,
-0.3035756051540375,
0.23887419700622559,
1.3440011739730835,
0.1032256931066513,
1.1003541946411133,
-0.3416801989078522,
0.947338879108429,
0.409_972_43,
0.408_450_66,
0.257_865_4,
1.095_021_4,
-0.506_486_6,
0.099_775_404,
-0.653_973_4,
0.731_693_7,
-1.456_733,
1.608_935_4,
0.093_769_975,
-1.259_749,
0.254_633_5,
-0.501_957_3,
-1.041_2,
0.732_267_2,
1.307_535_5,
-1.162_798_8,
0.119_636_11,
-0.163_135_33,
0.661_445_3,
1.189_920_5,
0.816_533_9,
-0.913_523_6,
-0.353_806_53,
0.763_927_04,
-0.588_950_7,
-0.763_597_37,
1.335_205_7,
0.604_273_6,
-0.103_442_08,
-0.151_216_92,
1.246_568_3,
0.505_721_4,
0.950_511_2,
1.296_648_3,
0.873_796_3,
-0.560_259_4,
1.285_784_5,
0.816_823_84,
-1.464_799_4,
-1.262_928_4,
1.122_018_8,
1.566_334_1,
2.558_138_4,
-0.233_363_88,
-0.013_472_13,
1.860_634_8,
1.549_620_5,
0.347_629_25,
0.093_008_03,
0.614_740_3,
0.712_364_55,
-1.776_507_3,
0.353_864_58,
1.199_613_2,
-0.712_258_93,
-0.620_034_4,
-0.228_134_95,
-0.789_274_63,
-1.611_111_8,
-1.871_612_9,
0.543_083_6,
0.660_678_6,
0.270_527_72,
0.559_691_97,
-0.318_396_3,
1.511_720_7,
-1.363_267_2,
-0.983_219_6,
1.511_266_7,
0.641_870_74,
-0.747_445_9,
-0.923_438_55,
0.573_398_4,
-0.109_299_51,
0.518_112_1,
0.106_535_35,
0.269_240_77,
1.324_768,
0.037_456_9,
-0.637_839_3,
-0.814_755_44,
-0.689_506_53,
0.843_654_3,
1.165_701_3,
0.526_932_2,
1.619_253_3,
-0.963_976_26,
0.141_520_38,
-0.163_660_96,
-0.358_222_57,
1.722_279_3,
-0.303_575_6,
0.238_874_2,
1.344_001_2,
0.103_225_69,
1.100_354_2,
-0.341_680_2,
0.947_338_9,
],
(20, 5),
&Device::Cpu,
)?;
let bias_ih_l0 = Tensor::from_vec::<_, f32>(
vec![
-0.568515956401825,
0.8375961780548096,
1.783660650253296,
-0.1954246610403061,
0.235193133354187,
1.9142433404922485,
1.8364111185073853,
1.324532389640808,
-0.07051458209753036,
0.34697940945625305,
-0.653679609298706,
1.5586202144622803,
0.2185661494731903,
-0.5743072628974915,
1.4571250677108765,
1.7709556818008423,
-2.0172998905181885,
0.42350319027900696,
0.5730220079421997,
-1.7962429523468018,
-0.568_515_96,
0.837_596_2,
1.783_660_7,
-0.195_424_66,
0.235_193_13,
1.914_243_3,
1.836_411_1,
1.324_532_4,
-0.070_514_58,
0.346_979_4,
-0.653_679_6,
1.558_620_2,
0.218_566_15,
-0.574_307_26,
1.457_125_1,
1.770_955_7,
-2.017_3,
0.423_503_2,
0.573_022,
-1.796_243,
],
(20,),
&Device::Cpu,
)?;
let bias_hh_l0 = Tensor::from_vec::<_, f32>(
vec![
1.2470403909683228,
1.2738511562347412,
0.3909492492675781,
0.387210488319397,
0.14440394937992096,
0.7771684527397156,
-2.3381125926971436,
-0.829120397567749,
1.1661391258239746,
1.4786574840545654,
0.26760873198509216,
0.7561198472976685,
-0.5873361229896545,
-2.061920642852783,
0.4304734766483307,
0.3376566171646118,
-0.3437853455543518,
-0.6172260642051697,
1.2529692649841309,
-0.05141742154955864,
1.247_040_4,
1.273_851_2,
0.390_949_25,
0.387_210_5,
0.144_403_95,
0.777_168_45,
-2.338_112_6,
-0.829_120_4,
1.166_139_1,
1.478_657_5,
0.267_608_73,
0.756_119_85,
-0.587_336_1,
-2.061_920_6,
0.430_473_48,
0.337_656_62,
-0.343_785_35,
-0.617_226_06,
1.252_969_3,
-0.051_417_42,
],
(20,),
&Device::Cpu,
)?;
let input = Tensor::from_vec::<_, f32>(
vec![
0.6472128033638,
-0.04116716980934143,
-0.17749308049678802,
-0.500039279460907,
0.8672749400138855,
-0.27319222688674927,
-0.4607681334018707,
-0.0990937128663063,
0.47284480929374695,
1.0049484968185425,
-0.2871420383453369,
-1.1618621349334717,
0.647_212_8,
-0.041_167_17,
-0.177_493_08,
-0.500_039_3,
0.867_274_94,
-0.273_192_23,
-0.460_768_13,
-0.099_093_71,
0.472_844_8,
1.004_948_5,
-0.287_142_04,
-1.161_862_1,
],
(4, 1, 3),
&Device::Cpu,
)?;
let h0 = Tensor::from_vec::<_, f32>(
vec![
0.02758178487420082,
0.5652382373809814,
-0.011487378738820553,
0.6706400513648987,
-0.4929250478744507,
0.027_581_785,
0.565_238_24,
-0.011_487_379,
0.670_640_05,
-0.492_925_05,
],
(1, 1, 5),
&Device::Cpu,
)?;
let c0 = Tensor::from_vec::<_, f32>(
vec![
1.505028486251831,
-2.32635498046875,
1.6168899536132812,
-0.9026237726211548,
0.17366823554039001,
1.505_028_5,
-2.326_355,
1.616_89,
-0.902_623_8,
0.173_668_24,
],
(1, 1, 5),
&Device::Cpu,
)?;
let output = Tensor::from_vec::<_, f32>(
vec![
0.5956016778945923,
-0.01723279245197773,
0.11035571992397308,
-0.49323174357414246,
0.047632161527872086,
0.6358451843261719,
0.040328118950128555,
-0.3788611590862274,
-0.7464339733123779,
0.20080909132957458,
0.5840265154838562,
0.1453288197517395,
-0.7345298528671265,
-0.5214304327964783,
0.21903817355632782,
0.7420451641082764,
0.31943878531455994,
-0.04726646468043327,
-0.2823849618434906,
0.2713133990764618,
0.595_601_7,
-0.017_232_792,
0.110_355_72,
-0.493_231_74,
0.047_632_16,
0.635_845_2,
0.040_328_12,
-0.378_861_16,
-0.746_434,
0.200_809_09,
0.584_026_5,
0.145_328_82,
-0.734_529_85,
-0.521_430_43,
0.219_038_17,
0.742_045_16,
0.319_438_8,
-0.047_266_465,
-0.282_384_96,
0.271_313_4,
],
(4, 1, 5),
&Device::Cpu,
)?;
let hn = Tensor::from_vec::<_, f32>(
vec![
0.7420451641082764,
0.31943878531455994,
-0.04726646468043327,
-0.2823849618434906,
0.2713133990764618,
0.742_045_16,
0.319_438_8,
-0.047_266_465,
-0.282_384_96,
0.271_313_4,
],
(1, 1, 5),
&Device::Cpu,
)?;
let cn = Tensor::from_vec::<_, f32>(
vec![
0.9630558490753174,
1.0033069849014282,
-1.754899024963379,
-1.5967122316360474,
0.8252924680709839,
0.963_055_85,
1.003_307,
-1.754_899,
-1.596_712_2,
0.825_292_47,
],
(1, 1, 5),
&Device::Cpu,
@ -3929,8 +3922,8 @@ fn test_lstm() -> Result<()> {
let idx_iofc = {
let stride = hidden_size as i64;
let dev = weight_ih_l0.device();
let idx_i = Tensor::arange(0 * stride, 1 * stride, dev)?;
let idx_f = Tensor::arange(1 * stride, 2 * stride, dev)?;
let idx_i = Tensor::arange(0, stride, dev)?;
let idx_f = Tensor::arange(stride, 2 * stride, dev)?;
let idx_g = Tensor::arange(2 * stride, 3 * stride, dev)?;
let idx_o = Tensor::arange(3 * stride, 4 * stride, dev)?;
@ -3966,17 +3959,346 @@ fn test_lstm() -> Result<()> {
Ok(diffs.iter().all(|f| f.abs() < 0.0001))
};
assert!(
diff_close_enough(&output, &actual_output)?,
diff_close_enough(&output, actual_output)?,
"output did not match expected\n{actual_output}\n{output}",
);
assert!(
diff_close_enough(&hn, &actual_hn)?,
diff_close_enough(&hn, actual_hn)?,
"hn did not match expected\n{actual_hn}\n{hn}",
);
assert!(
diff_close_enough(&cn, &actual_cn)?,
diff_close_enough(&cn, actual_cn)?,
"cn did not match expected\n{actual_cn}\n{cn}",
);
Ok(())
}
#[test]
fn test_expand_dim_changed() -> Result<()> {
// Create a manual graph for the Expand operation
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Expand".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec!["data".to_string(), "new_shape".to_string()],
output: vec!["expanded".to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
input: vec![
ValueInfoProto {
name: "data".to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: "new_shape".to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: "expanded".to_string(),
doc_string: "".to_string(),
r#type: None,
}],
..GraphProto::default()
}));
// Input tensor with shape [3, 1]
let data = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32], (3, 1), &Device::Cpu)?;
// New shape tensor: [2, 1, 6]
let new_shape = Tensor::from_vec(vec![2i64, 1, 6], (3,), &Device::Cpu)?;
// Expected output after expansion
let expected = Tensor::from_vec(
vec![
1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32,
2.0f32, 3.0f32, 3.0f32, 3.0f32, 3.0f32, 3.0f32, 3.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32,
1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 3.0f32, 3.0f32, 3.0f32,
3.0f32, 3.0f32, 3.0f32,
],
(2, 3, 6),
&Device::Cpu,
)?;
// Execute the model evaluation
let inputs = HashMap::from_iter([
("data".to_string(), data),
("new_shape".to_string(), new_shape),
]);
let result = candle_onnx::simple_eval(&manual_graph, inputs)?;
// Retrieve and compare the result
let expanded = result.get("expanded").expect("Output 'expanded' not found");
assert_eq!(expanded.to_vec3::<f32>()?, expected.to_vec3::<f32>()?);
Ok(())
}
fn make_graph_helper(
op_name: &str,
inputs: &[&str],
outputs: &[&str],
attribs: Vec<AttributeProto>,
) -> ModelProto {
create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: op_name.to_string(),
domain: "".to_string(),
attribute: attribs,
input: inputs.iter().map(|s| s.to_string()).collect(),
output: outputs.iter().map(|s| s.to_string()).collect(),
name: "".to_string(),
doc_string: "".to_string(),
}],
input: inputs
.iter()
.map(|name| ValueInfoProto {
name: name.to_string(),
..ValueInfoProto::default()
})
.collect(),
output: outputs
.iter()
.map(|name| ValueInfoProto {
name: name.to_string(),
..ValueInfoProto::default()
})
.collect(),
..GraphProto::default()
}))
}
#[test]
fn test_expand_dim_unchanged() -> Result<()> {
// Create a manual graph for the Expand operation
let manual_graph = make_graph_helper("Expand", &["data", "new_shape"], &["expanded"], vec![]);
// Input tensor with shape [3, 1] and dtype f32
let data = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32], (3, 1), &Device::Cpu)?;
// New shape tensor: [3, 4]
let new_shape = Tensor::from_vec(vec![3i64, 4], (2,), &Device::Cpu)?;
// Expected output after expansion, dtype f32
let expected = Tensor::from_vec(
vec![
1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 3.0f32, 3.0f32, 3.0f32,
3.0f32,
],
(3, 4),
&Device::Cpu,
)?;
// Execute the model evaluation
let inputs = HashMap::from_iter([
("data".to_string(), data),
("new_shape".to_string(), new_shape),
]);
let result = candle_onnx::simple_eval(&manual_graph, inputs)?;
// Retrieve and compare the result
let expanded = result.get("expanded").expect("Output 'expanded' not found");
assert_eq!(expanded.to_vec2::<f32>()?, expected.to_vec2::<f32>()?);
Ok(())
}
fn make_split_graph_helper(inputs: &[&str], outputs: &[&str], axis: i64) -> ModelProto {
let attribs = vec![AttributeProto {
name: "axis".to_string(),
r#type: AttributeType::Int.into(),
i: axis,
..AttributeProto::default()
}];
make_graph_helper("Split", inputs, outputs, attribs)
}
#[test]
fn test_split_equal_parts_1d_opset13() -> Result<()> {
let input = Tensor::from_vec(
vec![1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32, 6.0f32],
(6,),
&Device::Cpu,
)?;
let mut inputs = HashMap::new();
inputs.insert("input".to_string(), input);
{
let manual_graph =
make_split_graph_helper(&["input"], &["output_1", "output_2", "output_3"], 0);
let eval = candle_onnx::simple_eval(&manual_graph, inputs.clone())?;
assert_eq!(eval.len(), 3);
let out1 = eval.get("output_1").expect("Output 'output_1' not found");
let out2 = eval.get("output_2").expect("Output 'output_2' not found");
let out3 = eval.get("output_3").expect("Output 'output_3' not found");
assert_eq!(out1.to_vec1::<f32>()?, vec![1.0f32, 2.0f32]);
assert_eq!(out2.to_vec1::<f32>()?, vec![3.0f32, 4.0f32]);
assert_eq!(out3.to_vec1::<f32>()?, vec![5.0f32, 6.0f32]);
}
{
let splits = Tensor::from_vec(vec![2i64, 4], (2,), &Device::Cpu)?;
inputs.insert("split".to_string(), splits);
let manual_graph =
make_split_graph_helper(&["input", "split"], &["output_1", "output_2"], 0);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 2);
let out1 = eval.get("output_1").expect("Output 'output_1' not found");
let out2 = eval.get("output_2").expect("Output 'output_2' not found");
assert_eq!(out1.to_vec1::<f32>()?, vec![1.0f32, 2.0f32]);
assert_eq!(out2.to_vec1::<f32>()?, vec![3.0f32, 4.0f32, 5.0f32, 6.0f32]);
}
Ok(())
}
fn make_reduce_sum_graph_helper(
inputs: &[&str],
outputs: &[&str],
keepdims: Option<i64>,
noop_with_empty_axes: Option<i64>,
) -> ModelProto {
let mut attribs = vec![];
if let Some(keepdims) = keepdims {
attribs.push(AttributeProto {
name: "keepdims".to_string(),
r#type: AttributeType::Int.into(),
i: keepdims,
..AttributeProto::default()
});
}
if let Some(noop_with_empty_axes) = noop_with_empty_axes {
attribs.push(AttributeProto {
name: "noop_with_empty_axes".to_string(),
r#type: AttributeType::Ints.into(),
i: noop_with_empty_axes,
..AttributeProto::default()
});
}
make_graph_helper("ReduceSum", inputs, outputs, attribs)
}
#[test]
fn test_reduce_sum_default_axes_keepdims() -> Result<()> {
let manual_graph = make_reduce_sum_graph_helper(&["data", "axes"], &["reduced"], Some(1), None);
// Test with example data
{
let data = Tensor::from_vec(
vec![
1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
],
(3, 2, 2),
&Device::Cpu,
)?;
// let axes = Tensor::from_vec(Vec::<i64>::new(), (0,), &Device::Cpu)?;
let mut inputs = HashMap::new();
inputs.insert("data".to_string(), data);
// inputs.insert("axes".to_string(), axes);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let reduced = eval.get("reduced").expect("Output 'reduced' not found");
let expected = Tensor::from_vec(vec![78.0f32], (1, 1, 1), &Device::Cpu)?;
assert_eq!(reduced.to_vec3::<f32>()?, expected.to_vec3::<f32>()?);
}
{
let data = Tensor::from_vec(
vec![
-5.2f32, 7.8, -3.1, 9.4, 2.6, -8.7, 4.3, -1.9, 6.5, -0.8, -7.2, 3.6,
],
(3, 2, 2),
&Device::Cpu,
)?;
let mut inputs = HashMap::new();
inputs.insert("data".to_string(), data.clone());
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let reduced = eval.get("reduced").expect("Output 'reduced' not found");
let expected = data.sum_all()?.reshape((1, 1, 1))?;
assert_eq!(reduced.to_vec3::<f32>()?, expected.to_vec3::<f32>()?);
}
Ok(())
}
#[test]
fn test_reduce_sum_do_not_keep_dims() -> Result<()> {
let manual_graph = make_reduce_sum_graph_helper(&["data", "axes"], &["reduced"], Some(0), None);
// Test with example data
{
let data = Tensor::from_vec(
vec![
1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
],
(3, 2, 2),
&Device::Cpu,
)?;
let axes = Tensor::from_vec(vec![1i64], (1,), &Device::Cpu)?;
let mut inputs = HashMap::new();
inputs.insert("data".to_string(), data);
inputs.insert("axes".to_string(), axes);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let reduced = eval.get("reduced").expect("Output 'reduced' not found");
let expected = Tensor::from_vec(
vec![4.0f32, 6.0, 12.0, 14.0, 20.0, 22.0],
(3, 2),
&Device::Cpu,
)?;
assert_eq!(reduced.to_vec2::<f32>()?, expected.to_vec2::<f32>()?);
}
// Test with random data
{
let _shape = (3, 2, 2);
let data = Tensor::from_vec(
vec![
-5.2f32, 7.8, -3.1, 9.4, 2.6, -8.7, 4.3, -1.9, 6.5, -0.8, -7.2, 3.6,
],
(3, 2, 2),
&Device::Cpu,
)?;
let axes = Tensor::from_vec(vec![1i64], (1,), &Device::Cpu)?;
let mut inputs = HashMap::new();
inputs.insert("data".to_string(), data.clone());
inputs.insert("axes".to_string(), axes);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let reduced = eval.get("reduced").expect("Output 'reduced' not found");
// Calculate expected result
let expected = data.sum(1)?;
assert_eq!(reduced.to_vec2::<f32>()?, expected.to_vec2::<f32>()?);
}
Ok(())
}

View File

@ -92,28 +92,23 @@ impl ClipConfig {
impl ClipModel {
pub fn new(vs: candle_nn::VarBuilder, c: &ClipConfig) -> Result<Self> {
let text_model = ClipTextTransformer::new(vs.pp("text_model"), &c.text_config)?;
let vision_model = ClipVisionTransformer::new(vs.pp("vision_model"), &c.vision_config)?;
let visual_projection = candle_nn::linear_no_bias(
c.vision_config.embed_dim,
c.vision_config.projection_dim,
vs.pp("visual_projection"),
)?;
let text_projection = candle_nn::linear_no_bias(
c.text_config.embed_dim,
c.text_config.projection_dim,
vs.pp("text_projection"),
)?;
// originally nn.Parameter
let logit_scale = if vs.contains_tensor("logit_scale") {
vs.get(&[], "logit_scale")?
} else {
Tensor::new(&[c.logit_scale_init_value], vs.device())?
};
Ok(Self {
text_model,
vision_model,

View File

@ -77,7 +77,7 @@ impl ClipTextEmbeddings {
)?;
let position_ids =
Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
Ok(ClipTextEmbeddings {
Ok(Self {
token_embedding,
position_embedding,
position_ids,
@ -298,7 +298,7 @@ impl ClipTextTransformer {
})
}
// TODO: rewrrite to newer version
// TODO: rewrite to newer version
fn build_causal_attention_mask(
bsz: usize,
seq_len: usize,

View File

@ -11,13 +11,13 @@ use candle_nn::{
BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder,
};
#[derive(Clone, Debug)]
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
pub struct Config {
exp_ratio: usize,
in_channels: usize,
blocks: [usize; 4],
attn: bool,
lkc_use_act: bool,
pub exp_ratio: usize,
pub in_channels: usize,
pub blocks: [usize; 4],
pub attn: bool,
pub lkc_use_act: bool,
}
impl Config {
@ -495,7 +495,6 @@ fn fastvit_model(cfg: &Config, nclasses: Option<usize>, vb: VarBuilder) -> Resul
.apply(&stage3)?
.apply(&stage4)?
.apply(&final_conv)?;
match &cls {
None => Ok(xs),
Some(cls) => xs.mean(D::Minus2)?.mean(D::Minus1)?.apply(cls),

View File

@ -1,3 +1,20 @@
use candle::{Result, Tensor};
pub trait WithForward {
#[allow(clippy::too_many_arguments)]
fn forward(
&self,
img: &Tensor,
img_ids: &Tensor,
txt: &Tensor,
txt_ids: &Tensor,
timesteps: &Tensor,
y: &Tensor,
guidance: Option<&Tensor>,
) -> Result<Tensor>;
}
pub mod autoencoder;
pub mod model;
pub mod quantized_model;
pub mod sampling;

View File

@ -109,14 +109,14 @@ fn apply_rope(x: &Tensor, freq_cis: &Tensor) -> Result<Tensor> {
(fr0.broadcast_mul(&x0)? + fr1.broadcast_mul(&x1)?)?.reshape(dims.to_vec())
}
fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
pub(crate) fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
let q = apply_rope(q, pe)?.contiguous()?;
let k = apply_rope(k, pe)?.contiguous()?;
let x = scaled_dot_product_attention(&q, &k, v)?;
x.transpose(1, 2)?.flatten_from(2)
}
fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
pub(crate) fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
const TIME_FACTOR: f64 = 1000.;
const MAX_PERIOD: f64 = 10000.;
if dim % 2 == 1 {
@ -144,7 +144,7 @@ pub struct EmbedNd {
}
impl EmbedNd {
fn new(dim: usize, theta: usize, axes_dim: Vec<usize>) -> Self {
pub fn new(dim: usize, theta: usize, axes_dim: Vec<usize>) -> Self {
Self {
dim,
theta,
@ -575,9 +575,11 @@ impl Flux {
final_layer,
})
}
}
impl super::WithForward for Flux {
#[allow(clippy::too_many_arguments)]
pub fn forward(
fn forward(
&self,
img: &Tensor,
img_ids: &Tensor,

View File

@ -0,0 +1,465 @@
use super::model::{attention, timestep_embedding, Config, EmbedNd};
use crate::quantized_nn::{linear, linear_b, Linear};
use crate::quantized_var_builder::VarBuilder;
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{LayerNorm, RmsNorm};
fn layer_norm(dim: usize, vb: VarBuilder) -> Result<LayerNorm> {
let ws = Tensor::ones(dim, DType::F32, vb.device())?;
Ok(LayerNorm::new_no_bias(ws, 1e-6))
}
#[derive(Debug, Clone)]
pub struct MlpEmbedder {
in_layer: Linear,
out_layer: Linear,
}
impl MlpEmbedder {
fn new(in_sz: usize, h_sz: usize, vb: VarBuilder) -> Result<Self> {
let in_layer = linear(in_sz, h_sz, vb.pp("in_layer"))?;
let out_layer = linear(h_sz, h_sz, vb.pp("out_layer"))?;
Ok(Self {
in_layer,
out_layer,
})
}
}
impl candle::Module for MlpEmbedder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.in_layer)?.silu()?.apply(&self.out_layer)
}
}
#[derive(Debug, Clone)]
pub struct QkNorm {
query_norm: RmsNorm,
key_norm: RmsNorm,
}
impl QkNorm {
fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
let query_norm = vb.get(dim, "query_norm.scale")?.dequantize(vb.device())?;
let query_norm = RmsNorm::new(query_norm, 1e-6);
let key_norm = vb.get(dim, "key_norm.scale")?.dequantize(vb.device())?;
let key_norm = RmsNorm::new(key_norm, 1e-6);
Ok(Self {
query_norm,
key_norm,
})
}
}
struct ModulationOut {
shift: Tensor,
scale: Tensor,
gate: Tensor,
}
impl ModulationOut {
fn scale_shift(&self, xs: &Tensor) -> Result<Tensor> {
xs.broadcast_mul(&(&self.scale + 1.)?)?
.broadcast_add(&self.shift)
}
fn gate(&self, xs: &Tensor) -> Result<Tensor> {
self.gate.broadcast_mul(xs)
}
}
#[derive(Debug, Clone)]
struct Modulation1 {
lin: Linear,
}
impl Modulation1 {
fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
let lin = linear(dim, 3 * dim, vb.pp("lin"))?;
Ok(Self { lin })
}
fn forward(&self, vec_: &Tensor) -> Result<ModulationOut> {
let ys = vec_
.silu()?
.apply(&self.lin)?
.unsqueeze(1)?
.chunk(3, D::Minus1)?;
if ys.len() != 3 {
candle::bail!("unexpected len from chunk {ys:?}")
}
Ok(ModulationOut {
shift: ys[0].clone(),
scale: ys[1].clone(),
gate: ys[2].clone(),
})
}
}
#[derive(Debug, Clone)]
struct Modulation2 {
lin: Linear,
}
impl Modulation2 {
fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
let lin = linear(dim, 6 * dim, vb.pp("lin"))?;
Ok(Self { lin })
}
fn forward(&self, vec_: &Tensor) -> Result<(ModulationOut, ModulationOut)> {
let ys = vec_
.silu()?
.apply(&self.lin)?
.unsqueeze(1)?
.chunk(6, D::Minus1)?;
if ys.len() != 6 {
candle::bail!("unexpected len from chunk {ys:?}")
}
let mod1 = ModulationOut {
shift: ys[0].clone(),
scale: ys[1].clone(),
gate: ys[2].clone(),
};
let mod2 = ModulationOut {
shift: ys[3].clone(),
scale: ys[4].clone(),
gate: ys[5].clone(),
};
Ok((mod1, mod2))
}
}
#[derive(Debug, Clone)]
pub struct SelfAttention {
qkv: Linear,
norm: QkNorm,
proj: Linear,
num_heads: usize,
}
impl SelfAttention {
fn new(dim: usize, num_heads: usize, qkv_bias: bool, vb: VarBuilder) -> Result<Self> {
let head_dim = dim / num_heads;
let qkv = linear_b(dim, dim * 3, qkv_bias, vb.pp("qkv"))?;
let norm = QkNorm::new(head_dim, vb.pp("norm"))?;
let proj = linear(dim, dim, vb.pp("proj"))?;
Ok(Self {
qkv,
norm,
proj,
num_heads,
})
}
fn qkv(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
let qkv = xs.apply(&self.qkv)?;
let (b, l, _khd) = qkv.dims3()?;
let qkv = qkv.reshape((b, l, 3, self.num_heads, ()))?;
let q = qkv.i((.., .., 0))?.transpose(1, 2)?;
let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
let q = q.apply(&self.norm.query_norm)?;
let k = k.apply(&self.norm.key_norm)?;
Ok((q, k, v))
}
#[allow(unused)]
fn forward(&self, xs: &Tensor, pe: &Tensor) -> Result<Tensor> {
let (q, k, v) = self.qkv(xs)?;
attention(&q, &k, &v, pe)?.apply(&self.proj)
}
}
#[derive(Debug, Clone)]
struct Mlp {
lin1: Linear,
lin2: Linear,
}
impl Mlp {
fn new(in_sz: usize, mlp_sz: usize, vb: VarBuilder) -> Result<Self> {
let lin1 = linear(in_sz, mlp_sz, vb.pp("0"))?;
let lin2 = linear(mlp_sz, in_sz, vb.pp("2"))?;
Ok(Self { lin1, lin2 })
}
}
impl candle::Module for Mlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.lin1)?.gelu()?.apply(&self.lin2)
}
}
#[derive(Debug, Clone)]
pub struct DoubleStreamBlock {
img_mod: Modulation2,
img_norm1: LayerNorm,
img_attn: SelfAttention,
img_norm2: LayerNorm,
img_mlp: Mlp,
txt_mod: Modulation2,
txt_norm1: LayerNorm,
txt_attn: SelfAttention,
txt_norm2: LayerNorm,
txt_mlp: Mlp,
}
impl DoubleStreamBlock {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let h_sz = cfg.hidden_size;
let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize;
let img_mod = Modulation2::new(h_sz, vb.pp("img_mod"))?;
let img_norm1 = layer_norm(h_sz, vb.pp("img_norm1"))?;
let img_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("img_attn"))?;
let img_norm2 = layer_norm(h_sz, vb.pp("img_norm2"))?;
let img_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("img_mlp"))?;
let txt_mod = Modulation2::new(h_sz, vb.pp("txt_mod"))?;
let txt_norm1 = layer_norm(h_sz, vb.pp("txt_norm1"))?;
let txt_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("txt_attn"))?;
let txt_norm2 = layer_norm(h_sz, vb.pp("txt_norm2"))?;
let txt_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("txt_mlp"))?;
Ok(Self {
img_mod,
img_norm1,
img_attn,
img_norm2,
img_mlp,
txt_mod,
txt_norm1,
txt_attn,
txt_norm2,
txt_mlp,
})
}
fn forward(
&self,
img: &Tensor,
txt: &Tensor,
vec_: &Tensor,
pe: &Tensor,
) -> Result<(Tensor, Tensor)> {
let (img_mod1, img_mod2) = self.img_mod.forward(vec_)?; // shift, scale, gate
let (txt_mod1, txt_mod2) = self.txt_mod.forward(vec_)?; // shift, scale, gate
let img_modulated = img.apply(&self.img_norm1)?;
let img_modulated = img_mod1.scale_shift(&img_modulated)?;
let (img_q, img_k, img_v) = self.img_attn.qkv(&img_modulated)?;
let txt_modulated = txt.apply(&self.txt_norm1)?;
let txt_modulated = txt_mod1.scale_shift(&txt_modulated)?;
let (txt_q, txt_k, txt_v) = self.txt_attn.qkv(&txt_modulated)?;
let q = Tensor::cat(&[txt_q, img_q], 2)?;
let k = Tensor::cat(&[txt_k, img_k], 2)?;
let v = Tensor::cat(&[txt_v, img_v], 2)?;
let attn = attention(&q, &k, &v, pe)?;
let txt_attn = attn.narrow(1, 0, txt.dim(1)?)?;
let img_attn = attn.narrow(1, txt.dim(1)?, attn.dim(1)? - txt.dim(1)?)?;
let img = (img + img_mod1.gate(&img_attn.apply(&self.img_attn.proj)?))?;
let img = (&img
+ img_mod2.gate(
&img_mod2
.scale_shift(&img.apply(&self.img_norm2)?)?
.apply(&self.img_mlp)?,
)?)?;
let txt = (txt + txt_mod1.gate(&txt_attn.apply(&self.txt_attn.proj)?))?;
let txt = (&txt
+ txt_mod2.gate(
&txt_mod2
.scale_shift(&txt.apply(&self.txt_norm2)?)?
.apply(&self.txt_mlp)?,
)?)?;
Ok((img, txt))
}
}
#[derive(Debug, Clone)]
pub struct SingleStreamBlock {
linear1: Linear,
linear2: Linear,
norm: QkNorm,
pre_norm: LayerNorm,
modulation: Modulation1,
h_sz: usize,
mlp_sz: usize,
num_heads: usize,
}
impl SingleStreamBlock {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let h_sz = cfg.hidden_size;
let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize;
let head_dim = h_sz / cfg.num_heads;
let linear1 = linear(h_sz, h_sz * 3 + mlp_sz, vb.pp("linear1"))?;
let linear2 = linear(h_sz + mlp_sz, h_sz, vb.pp("linear2"))?;
let norm = QkNorm::new(head_dim, vb.pp("norm"))?;
let pre_norm = layer_norm(h_sz, vb.pp("pre_norm"))?;
let modulation = Modulation1::new(h_sz, vb.pp("modulation"))?;
Ok(Self {
linear1,
linear2,
norm,
pre_norm,
modulation,
h_sz,
mlp_sz,
num_heads: cfg.num_heads,
})
}
fn forward(&self, xs: &Tensor, vec_: &Tensor, pe: &Tensor) -> Result<Tensor> {
let mod_ = self.modulation.forward(vec_)?;
let x_mod = mod_.scale_shift(&xs.apply(&self.pre_norm)?)?;
let x_mod = x_mod.apply(&self.linear1)?;
let qkv = x_mod.narrow(D::Minus1, 0, 3 * self.h_sz)?;
let (b, l, _khd) = qkv.dims3()?;
let qkv = qkv.reshape((b, l, 3, self.num_heads, ()))?;
let q = qkv.i((.., .., 0))?.transpose(1, 2)?;
let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
let mlp = x_mod.narrow(D::Minus1, 3 * self.h_sz, self.mlp_sz)?;
let q = q.apply(&self.norm.query_norm)?;
let k = k.apply(&self.norm.key_norm)?;
let attn = attention(&q, &k, &v, pe)?;
let output = Tensor::cat(&[attn, mlp.gelu()?], 2)?.apply(&self.linear2)?;
xs + mod_.gate(&output)
}
}
#[derive(Debug, Clone)]
pub struct LastLayer {
norm_final: LayerNorm,
linear: Linear,
ada_ln_modulation: Linear,
}
impl LastLayer {
fn new(h_sz: usize, p_sz: usize, out_c: usize, vb: VarBuilder) -> Result<Self> {
let norm_final = layer_norm(h_sz, vb.pp("norm_final"))?;
let linear_ = linear(h_sz, p_sz * p_sz * out_c, vb.pp("linear"))?;
let ada_ln_modulation = linear(h_sz, 2 * h_sz, vb.pp("adaLN_modulation.1"))?;
Ok(Self {
norm_final,
linear: linear_,
ada_ln_modulation,
})
}
fn forward(&self, xs: &Tensor, vec: &Tensor) -> Result<Tensor> {
let chunks = vec.silu()?.apply(&self.ada_ln_modulation)?.chunk(2, 1)?;
let (shift, scale) = (&chunks[0], &chunks[1]);
let xs = xs
.apply(&self.norm_final)?
.broadcast_mul(&(scale.unsqueeze(1)? + 1.0)?)?
.broadcast_add(&shift.unsqueeze(1)?)?;
xs.apply(&self.linear)
}
}
#[derive(Debug, Clone)]
pub struct Flux {
img_in: Linear,
txt_in: Linear,
time_in: MlpEmbedder,
vector_in: MlpEmbedder,
guidance_in: Option<MlpEmbedder>,
pe_embedder: EmbedNd,
double_blocks: Vec<DoubleStreamBlock>,
single_blocks: Vec<SingleStreamBlock>,
final_layer: LastLayer,
}
impl Flux {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let img_in = linear(cfg.in_channels, cfg.hidden_size, vb.pp("img_in"))?;
let txt_in = linear(cfg.context_in_dim, cfg.hidden_size, vb.pp("txt_in"))?;
let mut double_blocks = Vec::with_capacity(cfg.depth);
let vb_d = vb.pp("double_blocks");
for idx in 0..cfg.depth {
let db = DoubleStreamBlock::new(cfg, vb_d.pp(idx))?;
double_blocks.push(db)
}
let mut single_blocks = Vec::with_capacity(cfg.depth_single_blocks);
let vb_s = vb.pp("single_blocks");
for idx in 0..cfg.depth_single_blocks {
let sb = SingleStreamBlock::new(cfg, vb_s.pp(idx))?;
single_blocks.push(sb)
}
let time_in = MlpEmbedder::new(256, cfg.hidden_size, vb.pp("time_in"))?;
let vector_in = MlpEmbedder::new(cfg.vec_in_dim, cfg.hidden_size, vb.pp("vector_in"))?;
let guidance_in = if cfg.guidance_embed {
let mlp = MlpEmbedder::new(256, cfg.hidden_size, vb.pp("guidance_in"))?;
Some(mlp)
} else {
None
};
let final_layer =
LastLayer::new(cfg.hidden_size, 1, cfg.in_channels, vb.pp("final_layer"))?;
let pe_dim = cfg.hidden_size / cfg.num_heads;
let pe_embedder = EmbedNd::new(pe_dim, cfg.theta, cfg.axes_dim.to_vec());
Ok(Self {
img_in,
txt_in,
time_in,
vector_in,
guidance_in,
pe_embedder,
double_blocks,
single_blocks,
final_layer,
})
}
}
impl super::WithForward for Flux {
#[allow(clippy::too_many_arguments)]
fn forward(
&self,
img: &Tensor,
img_ids: &Tensor,
txt: &Tensor,
txt_ids: &Tensor,
timesteps: &Tensor,
y: &Tensor,
guidance: Option<&Tensor>,
) -> Result<Tensor> {
if txt.rank() != 3 {
candle::bail!("unexpected shape for txt {:?}", txt.shape())
}
if img.rank() != 3 {
candle::bail!("unexpected shape for img {:?}", img.shape())
}
let dtype = img.dtype();
let pe = {
let ids = Tensor::cat(&[txt_ids, img_ids], 1)?;
ids.apply(&self.pe_embedder)?
};
let mut txt = txt.apply(&self.txt_in)?;
let mut img = img.apply(&self.img_in)?;
let vec_ = timestep_embedding(timesteps, 256, dtype)?.apply(&self.time_in)?;
let vec_ = match (self.guidance_in.as_ref(), guidance) {
(Some(g_in), Some(guidance)) => {
(vec_ + timestep_embedding(guidance, 256, dtype)?.apply(g_in))?
}
_ => vec_,
};
let vec_ = (vec_ + y.apply(&self.vector_in))?;
// Double blocks
for block in self.double_blocks.iter() {
(img, txt) = block.forward(&img, &txt, &vec_, &pe)?
}
// Single blocks
let mut img = Tensor::cat(&[&txt, &img], 1)?;
for block in self.single_blocks.iter() {
img = block.forward(&img, &vec_, &pe)?;
}
let img = img.i((.., txt.dim(1)?..))?;
self.final_layer.forward(&img, &vec_)
}
}

View File

@ -92,8 +92,8 @@ pub fn unpack(xs: &Tensor, height: usize, width: usize) -> Result<Tensor> {
}
#[allow(clippy::too_many_arguments)]
pub fn denoise(
model: &super::model::Flux,
pub fn denoise<M: super::WithForward>(
model: &M,
img: &Tensor,
img_ids: &Tensor,
txt: &Tensor,

View File

@ -44,6 +44,7 @@ pub struct LlamaConfig {
pub eos_token_id: Option<LlamaEosToks>,
pub rope_scaling: Option<Llama3RopeConfig>,
pub max_position_embeddings: usize,
pub tie_word_embeddings: Option<bool>,
}
impl LlamaConfig {
@ -72,6 +73,7 @@ impl LlamaConfig {
eos_token_id: self.eos_token_id,
rope_scaling: self.rope_scaling,
max_position_embeddings: self.max_position_embeddings,
tie_word_embeddings: self.tie_word_embeddings.unwrap_or(false),
}
}
}
@ -91,6 +93,7 @@ pub struct Config {
pub eos_token_id: Option<LlamaEosToks>,
pub rope_scaling: Option<Llama3RopeConfig>,
pub max_position_embeddings: usize,
pub tie_word_embeddings: bool,
}
impl Config {
@ -109,6 +112,7 @@ impl Config {
eos_token_id: None,
rope_scaling: None,
max_position_embeddings: DEFAULT_MAX_SEQ_LEN,
tie_word_embeddings: false,
}
}
@ -127,6 +131,7 @@ impl Config {
eos_token_id: None,
rope_scaling: None,
max_position_embeddings: DEFAULT_MAX_SEQ_LEN,
tie_word_embeddings: false,
}
}
}
@ -504,7 +509,11 @@ impl Llama {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
let lm_head = if cfg.tie_word_embeddings {
Linear::from_weights(wte.embeddings().clone(), None)
} else {
linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
};
let ln_f = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.num_hidden_layers)
.map(|i| Block::load(vb.pp(format!("model.layers.{i}")), cfg).unwrap())

View File

@ -43,6 +43,7 @@ pub struct LLaVAConfig {
pub image_token_index: isize,
#[serde(default = "default_hf")]
pub hf: bool,
pub tie_word_embeddings: Option<bool>,
}
fn default_hf() -> bool {
@ -77,6 +78,7 @@ impl LLaVAConfig {
use_flash_attn: false,
rope_scaling: None, // Assume we don't have LLaVA for Llama 3.1
max_position_embeddings: self.max_position_embeddings,
tie_word_embeddings: self.tie_word_embeddings.unwrap_or(false),
}
}
}
@ -264,6 +266,7 @@ impl HFLLaVAConfig {
use_cache: self.text_config.use_cache,
vocab_size: self.vocab_size,
image_token_index: self.image_token_index,
tie_word_embeddings: None,
}
}
}

View File

@ -22,7 +22,6 @@ impl MobileClipConfig {
pub fn s1() -> Self {
let text_config = text_model::Config::vit_base_patch32();
let vision_config = fastvit::Config::mci1();
Self {
text_config,
vision_config,
@ -32,7 +31,6 @@ impl MobileClipConfig {
pub fn s2() -> Self {
let text_config = text_model::Config::vit_base_patch32();
let vision_config = fastvit::Config::mci2();
Self {
text_config,
vision_config,
@ -45,12 +43,10 @@ impl MobileClipModel {
pub fn new(vs: VarBuilder, c: &MobileClipConfig) -> Result<Self> {
let vision_model = fastvit::fastvit(&c.vision_config, 512, vs.pp("visual.trunk"))?;
let text_model = text_model::OpenClipTextTransformer::new(vs.pp("text"), &c.text_config)?;
let text_projection = vs.get(
(c.text_config.embed_dim, c.text_config.projection_dim),
"text.text_projection",
)?;
let logit_scale = vs.get(&[], "logit_scale")?;
Ok(Self {
text_model,

View File

@ -76,6 +76,7 @@ pub mod rwkv_v5;
pub mod rwkv_v6;
pub mod segformer;
pub mod segment_anything;
pub mod siglip;
pub mod stable_diffusion;
pub mod stable_lm;
pub mod starcoder2;

View File

@ -0,0 +1,608 @@
use crate::models::clip::div_l2_norm;
use candle::{IndexOp, Module, Result, Tensor, D};
use candle_nn::{layer_norm, linear, LayerNorm, Linear, VarBuilder};
// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L27
#[derive(serde::Deserialize, Clone, Debug)]
pub struct TextConfig {
pub vocab_size: usize,
pub hidden_size: usize,
pub intermediate_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub max_position_embeddings: usize,
pub hidden_act: candle_nn::Activation,
pub layer_norm_eps: f64,
pub pad_token_id: u32,
pub bos_token_id: u32,
pub eos_token_id: u32,
}
// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L132
#[derive(serde::Deserialize, Clone, Debug)]
pub struct VisionConfig {
pub hidden_size: usize,
pub intermediate_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub num_channels: usize,
pub image_size: usize,
pub patch_size: usize,
pub hidden_act: candle_nn::Activation,
pub layer_norm_eps: f64,
}
trait TransformerConfig {
fn hidden_size(&self) -> usize;
fn intermediate_size(&self) -> usize;
fn num_attention_heads(&self) -> usize;
fn num_hidden_layers(&self) -> usize;
fn layer_norm_eps(&self) -> f64;
fn hidden_act(&self) -> candle_nn::Activation;
}
impl TransformerConfig for TextConfig {
fn hidden_size(&self) -> usize {
self.hidden_size
}
fn intermediate_size(&self) -> usize {
self.intermediate_size
}
fn num_attention_heads(&self) -> usize {
self.num_attention_heads
}
fn num_hidden_layers(&self) -> usize {
self.num_hidden_layers
}
fn layer_norm_eps(&self) -> f64 {
self.layer_norm_eps
}
fn hidden_act(&self) -> candle_nn::Activation {
self.hidden_act
}
}
impl TransformerConfig for VisionConfig {
fn hidden_size(&self) -> usize {
self.hidden_size
}
fn intermediate_size(&self) -> usize {
self.intermediate_size
}
fn num_attention_heads(&self) -> usize {
self.num_attention_heads
}
fn num_hidden_layers(&self) -> usize {
self.num_hidden_layers
}
fn layer_norm_eps(&self) -> f64 {
self.layer_norm_eps
}
fn hidden_act(&self) -> candle_nn::Activation {
self.hidden_act
}
}
// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L228
#[derive(serde::Deserialize, Clone, Debug)]
pub struct Config {
pub text_config: TextConfig,
pub vision_config: VisionConfig,
}
impl Config {
pub fn base_patch16_224() -> Self {
let text_config = TextConfig {
// https://huggingface.co/google/siglip-base-patch16-224/blob/main/config.json
hidden_size: 768,
intermediate_size: 3072,
num_attention_heads: 12,
vocab_size: 32000,
// Default values.
pad_token_id: 1,
bos_token_id: 49406,
eos_token_id: 49407,
layer_norm_eps: 1e-6,
hidden_act: candle_nn::Activation::GeluPytorchTanh,
max_position_embeddings: 64,
num_hidden_layers: 12,
};
let vision_config = VisionConfig {
patch_size: 16,
// Default values.
hidden_size: 768,
intermediate_size: 3072,
num_hidden_layers: 12,
num_attention_heads: 12,
num_channels: 3,
image_size: 224,
hidden_act: candle_nn::Activation::GeluPytorchTanh,
layer_norm_eps: 1e-6,
};
Self {
text_config,
vision_config,
}
}
}
#[derive(Clone, Debug)]
struct MultiheadAttention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
out_proj: Linear,
num_heads: usize,
}
impl MultiheadAttention {
fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
let h = cfg.hidden_size;
let num_heads = cfg.num_attention_heads;
let w_in_proj = vb.get((3 * h, h), "in_proj_weight")?.chunk(3, 0)?;
let b_in_proj = vb.get(3 * h, "in_proj_bias")?.chunk(3, 0)?;
let q_proj = Linear::new(w_in_proj[0].clone(), Some(b_in_proj[0].clone()));
let k_proj = Linear::new(w_in_proj[1].clone(), Some(b_in_proj[1].clone()));
let v_proj = Linear::new(w_in_proj[2].clone(), Some(b_in_proj[2].clone()));
let out_proj = linear(h, h, vb.pp("out_proj"))?;
Ok(Self {
q_proj,
k_proj,
v_proj,
out_proj,
num_heads,
})
}
fn separate_heads(&self, x: &Tensor) -> Result<Tensor> {
let (b, n, c) = x.dims3()?;
x.reshape((b, n, self.num_heads, c / self.num_heads))?
.transpose(1, 2)?
.contiguous()
}
fn recombine_heads(&self, x: &Tensor) -> Result<Tensor> {
let (b, n_heads, n_tokens, c_per_head) = x.dims4()?;
x.transpose(1, 2)?
.reshape((b, n_tokens, n_heads * c_per_head))
}
fn forward(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
let q = self.q_proj.forward(&q.contiguous()?)?;
let k = self.k_proj.forward(&k.contiguous()?)?;
let v = self.v_proj.forward(&v.contiguous()?)?;
let q = self.separate_heads(&q)?;
let k = self.separate_heads(&k)?;
let v = self.separate_heads(&v)?;
let (_, _, _, c_per_head) = q.dims4()?;
let attn = (q.matmul(&k.t()?)? / (c_per_head as f64).sqrt())?;
let attn = candle_nn::ops::softmax_last_dim(&attn)?;
let out = attn.matmul(&v)?;
self.recombine_heads(&out)?.apply(&self.out_proj)
}
}
#[derive(Debug, Clone)]
struct MultiheadAttentionPoolingHead {
probe: Tensor,
attention: MultiheadAttention,
layernorm: LayerNorm,
mlp: Mlp,
}
impl MultiheadAttentionPoolingHead {
fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
let mlp = Mlp::new(cfg, vb.pp("mlp"))?;
let layernorm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("layernorm"))?;
let probe = vb.get((1, 1, cfg.hidden_size), "probe")?;
let attention = MultiheadAttention::new(cfg, vb.pp("attention"))?;
Ok(Self {
probe,
attention,
layernorm,
mlp,
})
}
}
impl Module for MultiheadAttentionPoolingHead {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let batch_size = xs.dim(0)?;
let probe = self.probe.repeat((batch_size, 1, 1))?;
let xs = self.attention.forward(&probe, xs, xs)?;
let residual = &xs;
let xs = xs.apply(&self.layernorm)?.apply(&self.mlp)?;
(xs + residual)?.i((.., 0))
}
}
#[derive(Debug, Clone)]
struct Attention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
out_proj: Linear,
num_heads: usize,
head_dim: usize,
scale: f64,
}
impl Attention {
fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
let embed_dim = cfg.hidden_size();
let q_proj = linear(embed_dim, embed_dim, vb.pp("q_proj"))?;
let k_proj = linear(embed_dim, embed_dim, vb.pp("k_proj"))?;
let v_proj = linear(embed_dim, embed_dim, vb.pp("v_proj"))?;
let out_proj = linear(embed_dim, embed_dim, vb.pp("out_proj"))?;
let num_heads = cfg.num_attention_heads();
let head_dim = embed_dim / num_heads;
Ok(Self {
q_proj,
k_proj,
v_proj,
out_proj,
num_heads,
head_dim,
scale: (head_dim as f64).powf(-0.5),
})
}
fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
let (batch_size, q_len, _) = xs.dims3()?;
let query_states = xs.apply(&self.q_proj)?;
let key_states = xs.apply(&self.k_proj)?;
let value_states = xs.apply(&self.v_proj)?;
let shape = (batch_size, q_len, self.num_heads, self.head_dim);
let query_states = query_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
let key_states = key_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
let value_states = value_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
let attn_weights = (query_states.matmul(&key_states.t()?)? * self.scale)?;
let attn_weights = match attention_mask {
None => attn_weights,
Some(mask) => attn_weights.broadcast_add(mask)?,
};
// The original implementation upcasts to f32 but candle_nn::ops::softmax should handle this properly.
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
let attn_outputs = attn_weights
.matmul(&value_states)?
.transpose(1, 2)?
.reshape((batch_size, q_len, ()))?
.apply(&self.out_proj)?;
Ok(attn_outputs)
}
}
// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/modeling_siglip.py#L599
#[derive(Debug, Clone)]
struct Mlp {
fc1: Linear,
fc2: Linear,
activation_fn: candle_nn::Activation,
}
impl Mlp {
fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
let hidden_size = cfg.hidden_size();
let intermediate_size = cfg.intermediate_size();
let fc1 = candle_nn::linear(hidden_size, intermediate_size, vb.pp("fc1"))?;
let fc2 = candle_nn::linear(intermediate_size, hidden_size, vb.pp("fc2"))?;
Ok(Self {
fc1,
fc2,
activation_fn: cfg.hidden_act(),
})
}
}
impl Module for Mlp {
fn forward(&self, xs: &candle::Tensor) -> Result<candle::Tensor> {
xs.apply(&self.fc1)?
.apply(&self.activation_fn)?
.apply(&self.fc2)
}
}
// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/modeling_siglip.py#L614
#[derive(Debug, Clone)]
struct EncoderLayer {
self_attn: Attention,
layer_norm1: LayerNorm,
mlp: Mlp,
layer_norm2: LayerNorm,
}
impl EncoderLayer {
fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
let hidden_size = cfg.hidden_size();
let layer_norm_eps = cfg.layer_norm_eps();
let self_attn = Attention::new(cfg, vb.pp("self_attn"))?;
let layer_norm1 = layer_norm(hidden_size, layer_norm_eps, vb.pp("layer_norm1"))?;
let mlp = Mlp::new(cfg, vb.pp("mlp"))?;
let layer_norm2 = layer_norm(hidden_size, layer_norm_eps, vb.pp("layer_norm2"))?;
Ok(Self {
self_attn,
layer_norm1,
mlp,
layer_norm2,
})
}
fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
let residual = xs;
let xs = xs.apply(&self.layer_norm1)?;
let xs = self.self_attn.forward(&xs, attention_mask)?;
let xs = (residual + xs)?;
let residual = &xs;
let xs = xs.apply(&self.layer_norm2)?.apply(&self.mlp)?;
let xs = (xs + residual)?;
Ok(xs)
}
}
#[derive(Debug, Clone)]
struct Encoder {
layers: Vec<EncoderLayer>,
}
impl Encoder {
fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
let mut layers = vec![];
let vb = vb.pp("layers");
for layer_idx in 0..cfg.num_hidden_layers() {
let layer = EncoderLayer::new(cfg, vb.pp(layer_idx))?;
layers.push(layer)
}
Ok(Self { layers })
}
fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
let mut xs = xs.clone();
for layer in self.layers.iter() {
xs = layer.forward(&xs, attention_mask)?
}
Ok(xs)
}
}
#[derive(Debug, Clone)]
struct VisionEmbeddings {
patch_embedding: candle_nn::Conv2d,
position_embedding: candle_nn::Embedding,
position_ids: Tensor,
}
impl VisionEmbeddings {
fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
let conv2d_cfg = candle_nn::Conv2dConfig {
stride: cfg.patch_size,
..Default::default()
};
let patch_embedding = candle_nn::conv2d(
cfg.num_channels,
cfg.hidden_size,
cfg.patch_size,
conv2d_cfg,
vb.pp("patch_embedding"),
)?;
let num_patches = (cfg.image_size / cfg.patch_size).pow(2);
let position_ids = Tensor::arange(0, num_patches as i64, vb.device())?;
let position_embedding =
candle_nn::embedding(num_patches, cfg.hidden_size(), vb.pp("position_embedding"))?;
Ok(Self {
patch_embedding,
position_embedding,
position_ids,
})
}
}
impl Module for VisionEmbeddings {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (_batch, _channels, _height, _width) = xs.dims4()?;
let embeddings = xs.apply(&self.patch_embedding)?;
let embeddings = embeddings.flatten_from(2)?.transpose(1, 2)?;
let position_embedding = self.position_embedding.forward(&self.position_ids)?;
embeddings.broadcast_add(&position_embedding)
}
}
#[derive(Debug, Clone)]
struct VisionTransformer {
embeddings: VisionEmbeddings,
encoder: Encoder,
post_layernorm: LayerNorm,
head: Option<MultiheadAttentionPoolingHead>,
}
impl VisionTransformer {
fn new(cfg: &VisionConfig, use_head: bool, vb: VarBuilder) -> Result<Self> {
let embeddings = VisionEmbeddings::new(cfg, vb.pp("embeddings"))?;
let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
let post_layernorm =
layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("post_layernorm"))?;
let head = if use_head {
Some(MultiheadAttentionPoolingHead::new(cfg, vb.pp("head"))?)
} else {
None
};
Ok(Self {
embeddings,
encoder,
post_layernorm,
head,
})
}
}
impl Module for VisionTransformer {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = xs.apply(&self.embeddings)?;
let xs = self.encoder.forward(&xs, None)?;
let xs = xs.apply(&self.post_layernorm)?;
match self.head.as_ref() {
None => Ok(xs),
Some(h) => xs.apply(h),
}
}
}
#[derive(Debug, Clone)]
pub struct VisionModel {
vision_model: VisionTransformer,
}
impl VisionModel {
pub fn new(cfg: &VisionConfig, use_head: bool, vb: VarBuilder) -> Result<Self> {
let vision_model = VisionTransformer::new(cfg, use_head, vb)?;
Ok(Self { vision_model })
}
}
impl Module for VisionModel {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.vision_model)
}
}
#[derive(Debug, Clone)]
struct TextEmbeddings {
token_embedding: candle_nn::Embedding,
position_embedding: candle_nn::Embedding,
position_ids: Tensor,
}
impl TextEmbeddings {
fn new(cfg: &TextConfig, vb: VarBuilder) -> Result<Self> {
let token_embedding =
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("token_embedding"))?;
let position_embedding = candle_nn::embedding(
cfg.max_position_embeddings,
cfg.hidden_size,
vb.pp("position_embedding"),
)?;
let position_ids =
Tensor::arange(0u32, cfg.max_position_embeddings as u32, vb.device())?.unsqueeze(0)?;
Ok(Self {
token_embedding,
position_embedding,
position_ids,
})
}
}
impl Module for TextEmbeddings {
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
let seq_length = input_ids.dim(D::Minus1)?;
let inputs_embeds = self.token_embedding.forward(input_ids)?;
let position_ids = self.position_ids.narrow(1, 0, seq_length)?;
let position_embedding = self.position_embedding.forward(&position_ids)?;
inputs_embeds.broadcast_add(&position_embedding)
}
}
#[derive(Debug, Clone)]
pub struct TextTransformer {
embeddings: TextEmbeddings,
encoder: Encoder,
final_layer_norm: LayerNorm,
pub head: Linear,
}
impl TextTransformer {
fn new(cfg: &TextConfig, vb: VarBuilder) -> Result<Self> {
let embeddings = TextEmbeddings::new(cfg, vb.pp("embeddings"))?;
let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
let final_layer_norm = layer_norm(
cfg.hidden_size,
cfg.layer_norm_eps,
vb.pp("final_layer_norm"),
)?;
let head = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("head"))?;
Ok(Self {
embeddings,
encoder,
final_layer_norm,
head,
})
}
}
impl Module for TextTransformer {
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
let (_bsz, seq_len) = input_ids.dims2()?;
let input_ids = self.embeddings.forward(input_ids)?;
let input_ids = self.encoder.forward(&input_ids, None)?;
let last_hidden_state = self.final_layer_norm.forward(&input_ids)?;
last_hidden_state
.i((.., seq_len - 1, ..))?
.contiguous()?
.apply(&self.head)
}
}
#[derive(Debug, Clone)]
pub struct TextModel {
pub text_model: TextTransformer,
}
impl TextModel {
pub fn new(cfg: &TextConfig, vb: VarBuilder) -> Result<Self> {
let text_model = TextTransformer::new(cfg, vb)?;
Ok(Self { text_model })
}
}
impl Module for TextModel {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.text_model)
}
}
#[derive(Clone, Debug)]
pub struct Model {
text_model: TextModel,
vision_model: VisionModel,
logit_bias: Tensor,
logit_scale: Tensor,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let text_model = TextModel::new(&cfg.text_config, vb.pp("text_model"))?;
let vision_model = VisionModel::new(&cfg.vision_config, true, vb.pp("vision_model"))?;
let logit_scale = vb.get(&[1], "logit_scale")?;
let logit_bias = vb.get(&[1], "logit_bias")?;
Ok(Self {
text_model,
vision_model,
logit_bias,
logit_scale,
})
}
pub fn get_text_features(&self, input_ids: &Tensor) -> Result<Tensor> {
input_ids.apply(&self.text_model)
}
pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
pixel_values.apply(&self.vision_model)
}
pub fn forward(&self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<(Tensor, Tensor)> {
let image_features = self.get_image_features(pixel_values)?;
let text_features = self.get_text_features(input_ids)?;
let image_features_normalized = div_l2_norm(&image_features)?;
let text_features_normalized = div_l2_norm(&text_features)?;
let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?;
let logit_scale = self.logit_scale.exp()?;
let logits_per_text = logits_per_text
.broadcast_mul(&logit_scale)?
.broadcast_add(&self.logit_bias)?;
let logits_per_image = logits_per_text.t()?;
Ok((logits_per_text, logits_per_image))
}
}