mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Compare commits
9 Commits
Author | SHA1 | Date | |
---|---|---|---|
3a3c48b14b | |||
261ed65f36 | |||
62525e8352 | |||
2c25754281 | |||
ed48f54b54 | |||
ad8a4c5e5a | |||
c3c392f45c | |||
a0184a4fe4 | |||
10d47183c0 |
3
.github/workflows/ci_cuda.yaml
vendored
3
.github/workflows/ci_cuda.yaml
vendored
@ -9,7 +9,8 @@ jobs:
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
runs-on: [single-gpu, nvidia-gpu, t4, ci]
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: nvidia/cuda:12.3.1-devel-ubuntu22.04
|
||||
options: --gpus 0
|
||||
|
18
Cargo.toml
18
Cargo.toml
@ -20,7 +20,7 @@ exclude = [
|
||||
resolver = "2"
|
||||
|
||||
[workspace.package]
|
||||
version = "0.7.1"
|
||||
version = "0.7.2"
|
||||
edition = "2021"
|
||||
description = "Minimalist ML framework."
|
||||
repository = "https://github.com/huggingface/candle"
|
||||
@ -33,14 +33,14 @@ ab_glyph = "0.2.23"
|
||||
accelerate-src = { version = "0.3.2" }
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
byteorder = "1.4.3"
|
||||
candle = { path = "./candle-core", package = "candle-core", version = "0.7.1" }
|
||||
candle-datasets = { path = "./candle-datasets", version = "0.7.1" }
|
||||
candle-flash-attn = { path = "./candle-flash-attn", version = "0.7.1" }
|
||||
candle-kernels = { path = "./candle-kernels", version = "0.7.1" }
|
||||
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.7.1" }
|
||||
candle-nn = { path = "./candle-nn", version = "0.7.1" }
|
||||
candle-onnx = { path = "./candle-onnx", version = "0.7.1" }
|
||||
candle-transformers = { path = "./candle-transformers", version = "0.7.1" }
|
||||
candle = { path = "./candle-core", package = "candle-core", version = "0.7.2" }
|
||||
candle-datasets = { path = "./candle-datasets", version = "0.7.2" }
|
||||
candle-flash-attn = { path = "./candle-flash-attn", version = "0.7.2" }
|
||||
candle-kernels = { path = "./candle-kernels", version = "0.7.2" }
|
||||
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.7.2" }
|
||||
candle-nn = { path = "./candle-nn", version = "0.7.2" }
|
||||
candle-onnx = { path = "./candle-onnx", version = "0.7.2" }
|
||||
candle-transformers = { path = "./candle-transformers", version = "0.7.2" }
|
||||
clap = { version = "4.2.4", features = ["derive"] }
|
||||
criterion = { version = "0.5.1", default-features=false }
|
||||
cudarc = { version = "0.12.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
|
||||
|
@ -12,7 +12,6 @@ use candle_nn::{ops::softmax, VarBuilder};
|
||||
use candle_transformers::models::clip;
|
||||
|
||||
use tokenizers::Tokenizer;
|
||||
use tracing::info;
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
@ -40,15 +39,12 @@ fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::
|
||||
height as u32,
|
||||
image::imageops::FilterType::Triangle,
|
||||
);
|
||||
|
||||
let img = img.to_rgb8();
|
||||
|
||||
let img = img.into_raw();
|
||||
let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
|
||||
.permute((2, 0, 1))?
|
||||
.to_dtype(DType::F32)?
|
||||
.affine(2. / 255., -1.)?;
|
||||
// .unsqueeze(0)?;
|
||||
Ok(img)
|
||||
}
|
||||
|
||||
@ -57,24 +53,16 @@ fn load_images<T: AsRef<std::path::Path>>(
|
||||
image_size: usize,
|
||||
) -> anyhow::Result<Tensor> {
|
||||
let mut images = vec![];
|
||||
|
||||
for path in paths {
|
||||
let tensor = load_image(path, image_size)?;
|
||||
images.push(tensor);
|
||||
}
|
||||
|
||||
let images = Tensor::stack(&images, 0)?;
|
||||
|
||||
Ok(images)
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
// std::env::set_var("RUST_BACKTRACE", "full");
|
||||
|
||||
let args = Args::parse();
|
||||
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
let model_file = match args.model {
|
||||
None => {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
@ -89,13 +77,9 @@ pub fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
Some(model) => model.into(),
|
||||
};
|
||||
|
||||
let tokenizer = get_tokenizer(args.tokenizer)?;
|
||||
|
||||
let config = clip::ClipConfig::vit_base_patch32();
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let vec_imgs = match args.images {
|
||||
Some(imgs) => imgs,
|
||||
None => vec![
|
||||
@ -103,43 +87,29 @@ pub fn main() -> anyhow::Result<()> {
|
||||
"candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
|
||||
],
|
||||
};
|
||||
|
||||
// let image = load_image(args.image, config.image_size)?.to_device(&device)?;
|
||||
let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?;
|
||||
|
||||
let vb =
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? };
|
||||
|
||||
let model = clip::ClipModel::new(vb, &config)?;
|
||||
|
||||
let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;
|
||||
|
||||
let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
|
||||
|
||||
let softmax_image = softmax(&logits_per_image, 1)?;
|
||||
|
||||
let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
|
||||
|
||||
info!("softmax_image_vec: {:?}", softmax_image_vec);
|
||||
|
||||
println!("softmax_image_vec: {:?}", softmax_image_vec);
|
||||
let probability_vec = softmax_image_vec
|
||||
.iter()
|
||||
.map(|v| v * 100.0)
|
||||
.collect::<Vec<f32>>();
|
||||
|
||||
let probability_per_image = probability_vec.len() / vec_imgs.len();
|
||||
|
||||
for (i, img) in vec_imgs.iter().enumerate() {
|
||||
let start = i * probability_per_image;
|
||||
let end = start + probability_per_image;
|
||||
let prob = &probability_vec[start..end];
|
||||
info!("\n\nResults for image: {}\n", img);
|
||||
|
||||
println!("\n\nResults for image: {}\n", img);
|
||||
for (i, p) in prob.iter().enumerate() {
|
||||
info!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
|
||||
println!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -156,7 +126,6 @@ pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
|
||||
}
|
||||
Some(file) => file.into(),
|
||||
};
|
||||
|
||||
Tokenizer::from_file(tokenizer).map_err(E::msg)
|
||||
}
|
||||
|
||||
@ -169,7 +138,6 @@ pub fn tokenize_sequences(
|
||||
.get_vocab(true)
|
||||
.get("<|endoftext|>")
|
||||
.ok_or(E::msg("No pad token"))?;
|
||||
|
||||
let vec_seq = match sequences {
|
||||
Some(seq) => seq,
|
||||
None => vec![
|
||||
@ -178,16 +146,12 @@ pub fn tokenize_sequences(
|
||||
"a robot holding a candle".to_string(),
|
||||
],
|
||||
};
|
||||
|
||||
let mut tokens = vec![];
|
||||
|
||||
for seq in vec_seq.clone() {
|
||||
let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
|
||||
tokens.push(encoding.get_ids().to_vec());
|
||||
}
|
||||
|
||||
let max_len = tokens.iter().map(|v| v.len()).max().unwrap_or(0);
|
||||
|
||||
// Pad the sequences to have the same length
|
||||
for token_vec in tokens.iter_mut() {
|
||||
let len_diff = max_len - token_vec.len();
|
||||
@ -195,8 +159,6 @@ pub fn tokenize_sequences(
|
||||
token_vec.extend(vec![pad_id; len_diff]);
|
||||
}
|
||||
}
|
||||
|
||||
let input_ids = Tensor::new(tokens, device)?;
|
||||
|
||||
Ok((input_ids, vec_seq))
|
||||
}
|
||||
|
@ -13,7 +13,7 @@ descriptions,
|
||||
|
||||
```bash
|
||||
cargo run --features cuda --example flux -r -- \
|
||||
--height 1024 --width 1024
|
||||
--height 1024 --width 1024 \
|
||||
--prompt "a rusty robot walking on a beach holding a small torch, the robot has the word "rust" written on it, high quality, 4k"
|
||||
```
|
||||
|
||||
|
@ -23,6 +23,10 @@ struct Args {
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Use the quantized model.
|
||||
#[arg(long)]
|
||||
quantized: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
@ -40,6 +44,10 @@ struct Args {
|
||||
|
||||
#[arg(long, value_enum, default_value = "schnell")]
|
||||
model: Model,
|
||||
|
||||
/// Use the faster kernels which are buggy at the moment.
|
||||
#[arg(long)]
|
||||
no_dmmv: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)]
|
||||
@ -60,6 +68,8 @@ fn run(args: Args) -> Result<()> {
|
||||
tracing,
|
||||
decode_only,
|
||||
model,
|
||||
quantized,
|
||||
..
|
||||
} = args;
|
||||
let width = width.unwrap_or(1360);
|
||||
let height = height.unwrap_or(768);
|
||||
@ -146,38 +156,71 @@ fn run(args: Args) -> Result<()> {
|
||||
};
|
||||
println!("CLIP\n{clip_emb}");
|
||||
let img = {
|
||||
let model_file = match model {
|
||||
Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?,
|
||||
Model::Dev => bf_repo.get("flux1-dev.safetensors")?,
|
||||
};
|
||||
let vb =
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
|
||||
let cfg = match model {
|
||||
Model::Dev => flux::model::Config::dev(),
|
||||
Model::Schnell => flux::model::Config::schnell(),
|
||||
};
|
||||
let img = flux::sampling::get_noise(1, height, width, &device)?.to_dtype(dtype)?;
|
||||
let state = flux::sampling::State::new(&t5_emb, &clip_emb, &img)?;
|
||||
let state = if quantized {
|
||||
flux::sampling::State::new(
|
||||
&t5_emb.to_dtype(candle::DType::F32)?,
|
||||
&clip_emb.to_dtype(candle::DType::F32)?,
|
||||
&img.to_dtype(candle::DType::F32)?,
|
||||
)?
|
||||
} else {
|
||||
flux::sampling::State::new(&t5_emb, &clip_emb, &img)?
|
||||
};
|
||||
let timesteps = match model {
|
||||
Model::Dev => {
|
||||
flux::sampling::get_schedule(50, Some((state.img.dim(1)?, 0.5, 1.15)))
|
||||
}
|
||||
Model::Schnell => flux::sampling::get_schedule(4, None),
|
||||
};
|
||||
let model = flux::model::Flux::new(&cfg, vb)?;
|
||||
|
||||
println!("{state:?}");
|
||||
println!("{timesteps:?}");
|
||||
flux::sampling::denoise(
|
||||
&model,
|
||||
&state.img,
|
||||
&state.img_ids,
|
||||
&state.txt,
|
||||
&state.txt_ids,
|
||||
&state.vec,
|
||||
×teps,
|
||||
4.,
|
||||
)?
|
||||
if quantized {
|
||||
let model_file = match model {
|
||||
Model::Schnell => api
|
||||
.repo(hf_hub::Repo::model("lmz/candle-flux".to_string()))
|
||||
.get("flux1-schnell.gguf")?,
|
||||
Model::Dev => todo!(),
|
||||
};
|
||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
|
||||
model_file, &device,
|
||||
)?;
|
||||
|
||||
let model = flux::quantized_model::Flux::new(&cfg, vb)?;
|
||||
flux::sampling::denoise(
|
||||
&model,
|
||||
&state.img,
|
||||
&state.img_ids,
|
||||
&state.txt,
|
||||
&state.txt_ids,
|
||||
&state.vec,
|
||||
×teps,
|
||||
4.,
|
||||
)?
|
||||
.to_dtype(dtype)?
|
||||
} else {
|
||||
let model_file = match model {
|
||||
Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?,
|
||||
Model::Dev => bf_repo.get("flux1-dev.safetensors")?,
|
||||
};
|
||||
let vb = unsafe {
|
||||
VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)?
|
||||
};
|
||||
let model = flux::model::Flux::new(&cfg, vb)?;
|
||||
flux::sampling::denoise(
|
||||
&model,
|
||||
&state.img,
|
||||
&state.img_ids,
|
||||
&state.txt,
|
||||
&state.txt_ids,
|
||||
&state.vec,
|
||||
×teps,
|
||||
4.,
|
||||
)?
|
||||
}
|
||||
};
|
||||
flux::sampling::unpack(&img, height, width)?
|
||||
}
|
||||
@ -206,5 +249,7 @@ fn run(args: Args) -> Result<()> {
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
#[cfg(feature = "cuda")]
|
||||
candle::quantized::cuda::set_force_dmmv(!args.no_dmmv);
|
||||
run(args)
|
||||
}
|
||||
|
@ -35,6 +35,10 @@ enum Which {
|
||||
V31,
|
||||
V3Instruct,
|
||||
V31Instruct,
|
||||
V32_1b,
|
||||
V32_1bInstruct,
|
||||
V32_3b,
|
||||
V32_3bInstruct,
|
||||
#[value(name = "solar-10.7b")]
|
||||
Solar10_7B,
|
||||
#[value(name = "tiny-llama-1.1b-chat")]
|
||||
@ -137,6 +141,10 @@ fn main() -> Result<()> {
|
||||
Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(),
|
||||
Which::V31 => "meta-llama/Meta-Llama-3.1-8B".to_string(),
|
||||
Which::V31Instruct => "meta-llama/Meta-Llama-3.1-8B-Instruct".to_string(),
|
||||
Which::V32_1b => "meta-llama/Llama-3.2-1B".to_string(),
|
||||
Which::V32_1bInstruct => "meta-llama/Llama-3.2-1B-Instruct".to_string(),
|
||||
Which::V32_3b => "meta-llama/Llama-3.2-3B".to_string(),
|
||||
Which::V32_3bInstruct => "meta-llama/Llama-3.2-3B-Instruct".to_string(),
|
||||
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
|
||||
Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
|
||||
});
|
||||
@ -156,10 +164,14 @@ fn main() -> Result<()> {
|
||||
| Which::V3Instruct
|
||||
| Which::V31
|
||||
| Which::V31Instruct
|
||||
| Which::V32_3b
|
||||
| Which::V32_3bInstruct
|
||||
| Which::Solar10_7B => {
|
||||
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
|
||||
}
|
||||
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
|
||||
Which::V32_1b | Which::V32_1bInstruct | Which::TinyLlama1_1BChat => {
|
||||
vec![api.get("model.safetensors")?]
|
||||
}
|
||||
};
|
||||
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
|
||||
|
||||
|
@ -60,7 +60,6 @@ fn load_images<T: AsRef<std::path::Path>>(
|
||||
image_size: usize,
|
||||
) -> anyhow::Result<Tensor> {
|
||||
let mut images = vec![];
|
||||
|
||||
for path in paths {
|
||||
let tensor = candle_examples::imagenet::load_image_with_std_mean(
|
||||
path,
|
||||
@ -70,9 +69,7 @@ fn load_images<T: AsRef<std::path::Path>>(
|
||||
)?;
|
||||
images.push(tensor);
|
||||
}
|
||||
|
||||
let images = Tensor::stack(&images, 0)?;
|
||||
|
||||
Ok(images)
|
||||
}
|
||||
|
||||
@ -80,24 +77,17 @@ pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
let model_name = args.which.model_name();
|
||||
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model(model_name);
|
||||
|
||||
let model_file = if args.use_pth {
|
||||
api.get("open_clip_pytorch_model.bin")?
|
||||
} else {
|
||||
api.get("open_clip_model.safetensors")?
|
||||
};
|
||||
|
||||
let tokenizer = api.get("tokenizer.json")?;
|
||||
|
||||
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
||||
|
||||
let config = &args.which.config();
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let vec_imgs = match args.images {
|
||||
Some(imgs) => imgs,
|
||||
None => vec![
|
||||
@ -105,9 +95,7 @@ pub fn main() -> anyhow::Result<()> {
|
||||
"candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
|
||||
],
|
||||
};
|
||||
|
||||
let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?;
|
||||
|
||||
let vb = if args.use_pth {
|
||||
VarBuilder::from_pth(&model_file, DType::F32, &device)?
|
||||
} else {
|
||||
@ -115,22 +103,15 @@ pub fn main() -> anyhow::Result<()> {
|
||||
};
|
||||
|
||||
let model = mobileclip::MobileClipModel::new(vb, config)?;
|
||||
|
||||
let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;
|
||||
|
||||
let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
|
||||
|
||||
let softmax_image = softmax(&logits_per_image, 1)?;
|
||||
|
||||
let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
|
||||
|
||||
println!("softmax_image_vec: {:?}", softmax_image_vec);
|
||||
|
||||
let probability_vec = softmax_image_vec
|
||||
.iter()
|
||||
.map(|v| v * 100.0)
|
||||
.collect::<Vec<f32>>();
|
||||
|
||||
let probability_per_image = probability_vec.len() / vec_imgs.len();
|
||||
|
||||
for (i, img) in vec_imgs.iter().enumerate() {
|
||||
@ -171,7 +152,6 @@ pub fn tokenize_sequences(
|
||||
};
|
||||
|
||||
let mut tokens = vec![];
|
||||
|
||||
for seq in vec_seq.clone() {
|
||||
let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
|
||||
tokens.push(encoding.get_ids().to_vec());
|
||||
@ -185,8 +165,6 @@ pub fn tokenize_sequences(
|
||||
token_vec.extend(vec![pad_id; len_diff]);
|
||||
}
|
||||
}
|
||||
|
||||
let input_ids = Tensor::new(tokens, device)?;
|
||||
|
||||
Ok((input_ids, vec_seq))
|
||||
}
|
||||
|
24
candle-examples/examples/siglip/README.md
Normal file
24
candle-examples/examples/siglip/README.md
Normal file
@ -0,0 +1,24 @@
|
||||
## SigLIP
|
||||
|
||||
SigLIP is multi-modal text-vision model that improves over CLIP by using a sigmoid based loss,
|
||||
[HuggingFace](https://huggingface.co/google/siglip-base-patch16-224).
|
||||
|
||||
### Running an example
|
||||
```
|
||||
$ cargo run --features cuda -r --example siglip -
|
||||
softmax_image_vec: [2.1912122e-14, 2.3624872e-14, 1.0, 1.0, 2.4787932e-8, 3.2784535e-12]
|
||||
|
||||
|
||||
Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
|
||||
|
||||
Probability: 0.0000% Text: a cycling race
|
||||
Probability: 0.0000% Text: a photo of two cats
|
||||
Probability: 100.0000% Text: a robot holding a candle
|
||||
|
||||
|
||||
Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||
|
||||
Probability: 100.0000% Text: a cycling race
|
||||
Probability: 0.0000% Text: a photo of two cats
|
||||
Probability: 0.0000% Text: a robot holding a candle
|
||||
```
|
153
candle-examples/examples/siglip/main.rs
Normal file
153
candle-examples/examples/siglip/main.rs
Normal file
@ -0,0 +1,153 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::Error as E;
|
||||
use clap::Parser;
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_nn::{ops::softmax, VarBuilder};
|
||||
use candle_transformers::models::siglip;
|
||||
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer: Option<String>,
|
||||
|
||||
#[arg(long, use_value_delimiter = true)]
|
||||
images: Option<Vec<String>>,
|
||||
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
#[arg(long, use_value_delimiter = true)]
|
||||
sequences: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::Result<Tensor> {
|
||||
let img = image::ImageReader::open(path)?.decode()?;
|
||||
let (height, width) = (image_size, image_size);
|
||||
let img = img.resize_to_fill(
|
||||
width as u32,
|
||||
height as u32,
|
||||
image::imageops::FilterType::Triangle,
|
||||
);
|
||||
let img = img.to_rgb8();
|
||||
let img = img.into_raw();
|
||||
let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
|
||||
.permute((2, 0, 1))?
|
||||
.to_dtype(DType::F32)?
|
||||
.affine(2. / 255., -1.)?;
|
||||
Ok(img)
|
||||
}
|
||||
|
||||
fn load_images<T: AsRef<std::path::Path>>(
|
||||
paths: &Vec<T>,
|
||||
image_size: usize,
|
||||
) -> anyhow::Result<Tensor> {
|
||||
let mut images = vec![];
|
||||
for path in paths {
|
||||
let tensor = load_image(path, image_size)?;
|
||||
images.push(tensor);
|
||||
}
|
||||
let images = Tensor::stack(&images, 0)?;
|
||||
Ok(images)
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
let model_file = match args.model {
|
||||
None => {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model("google/siglip-base-patch16-224".to_string());
|
||||
api.get("model.safetensors")?
|
||||
}
|
||||
Some(model) => model.into(),
|
||||
};
|
||||
let tokenizer = get_tokenizer(args.tokenizer)?;
|
||||
let config = siglip::Config::base_patch16_224();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let vec_imgs = match args.images {
|
||||
Some(imgs) => imgs,
|
||||
None => vec![
|
||||
"candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(),
|
||||
"candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
|
||||
],
|
||||
};
|
||||
let images = load_images(&vec_imgs, config.vision_config.image_size)?.to_device(&device)?;
|
||||
let vb =
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? };
|
||||
let model = siglip::Model::new(&config, vb)?;
|
||||
let (input_ids, vec_seq) = tokenize_sequences(&config, args.sequences, &tokenizer, &device)?;
|
||||
let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
|
||||
let softmax_image = softmax(&logits_per_image, 1)?;
|
||||
let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
|
||||
println!("softmax_image_vec: {:?}", softmax_image_vec);
|
||||
let probability_vec = softmax_image_vec
|
||||
.iter()
|
||||
.map(|v| v * 100.0)
|
||||
.collect::<Vec<f32>>();
|
||||
let probability_per_image = probability_vec.len() / vec_imgs.len();
|
||||
for (i, img) in vec_imgs.iter().enumerate() {
|
||||
let start = i * probability_per_image;
|
||||
let end = start + probability_per_image;
|
||||
let prob = &probability_vec[start..end];
|
||||
println!("\n\nResults for image: {}\n", img);
|
||||
for (i, p) in prob.iter().enumerate() {
|
||||
println!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
|
||||
let tokenizer = match tokenizer {
|
||||
None => {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model("google/siglip-base-patch16-224".to_string());
|
||||
api.get("tokenizer.json")?
|
||||
}
|
||||
Some(file) => file.into(),
|
||||
};
|
||||
|
||||
Tokenizer::from_file(tokenizer).map_err(E::msg)
|
||||
}
|
||||
|
||||
pub fn tokenize_sequences(
|
||||
config: &siglip::Config,
|
||||
sequences: Option<Vec<String>>,
|
||||
tokenizer: &Tokenizer,
|
||||
device: &Device,
|
||||
) -> anyhow::Result<(Tensor, Vec<String>)> {
|
||||
let pad_id = config.text_config.pad_token_id;
|
||||
let vec_seq = match sequences {
|
||||
Some(seq) => seq,
|
||||
None => vec![
|
||||
"a cycling race".to_string(),
|
||||
"a photo of two cats".to_string(),
|
||||
"a robot holding a candle".to_string(),
|
||||
],
|
||||
};
|
||||
let mut tokens = vec![];
|
||||
for seq in vec_seq.clone() {
|
||||
let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
|
||||
tokens.push(encoding.get_ids().to_vec());
|
||||
}
|
||||
let max_len = config.text_config.max_position_embeddings;
|
||||
// Pad the sequences to have the same length
|
||||
for token_vec in tokens.iter_mut() {
|
||||
let len_diff = max_len - token_vec.len();
|
||||
if len_diff > 0 {
|
||||
token_vec.extend(vec![pad_id; len_diff]);
|
||||
}
|
||||
}
|
||||
let input_ids = Tensor::new(tokens, device)?;
|
||||
Ok((input_ids, vec_seq))
|
||||
}
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-flash-attn"
|
||||
version = "0.7.1"
|
||||
version = "0.7.2"
|
||||
edition = "2021"
|
||||
|
||||
description = "Flash attention layer for the candle ML framework."
|
||||
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
|
||||
readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.7.1" }
|
||||
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.7.2" }
|
||||
half = { version = "2.3.1", features = ["num-traits"] }
|
||||
|
||||
[build-dependencies]
|
||||
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-kernels"
|
||||
version = "0.7.1"
|
||||
version = "0.7.2"
|
||||
edition = "2021"
|
||||
|
||||
description = "CUDA kernels for Candle"
|
||||
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-metal-kernels"
|
||||
version = "0.7.1"
|
||||
version = "0.7.2"
|
||||
edition = "2021"
|
||||
|
||||
description = "Metal kernels for Candle"
|
||||
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-onnx"
|
||||
version = "0.7.1"
|
||||
version = "0.7.2"
|
||||
edition = "2021"
|
||||
|
||||
description = "ONNX support for Candle"
|
||||
@ -10,8 +10,8 @@ categories = ["science"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../candle-core", package = "candle-core", version = "0.7.1" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.7.1" }
|
||||
candle = { path = "../candle-core", package = "candle-core", version = "0.7.2" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.7.2" }
|
||||
prost = "0.12.1"
|
||||
|
||||
[build-dependencies]
|
||||
|
@ -2,7 +2,7 @@ use crate::onnx::attribute_proto::AttributeType;
|
||||
use crate::onnx::tensor_proto::DataType;
|
||||
use crate::onnx::{self, GraphProto};
|
||||
use candle::{bail, DType, Device, Result, Tensor};
|
||||
use std::{collections::HashMap, usize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub type Value = Tensor;
|
||||
|
||||
@ -321,8 +321,15 @@ fn simple_eval_(
|
||||
for node in graph.node.iter() {
|
||||
let get = |input_name: &str| match values.get(input_name) {
|
||||
Some(value) => Ok(value),
|
||||
None => bail!("cannot find {input_name} for op {}", node.name),
|
||||
None => bail!("cannot find {input_name} for op '{}'", node.name),
|
||||
};
|
||||
let get_opt = |i: usize| {
|
||||
node.input
|
||||
.get(i)
|
||||
.filter(|s: &&String| !s.is_empty())
|
||||
.map(|s| get(s))
|
||||
};
|
||||
|
||||
// TODO: Validate node.input for each operator.
|
||||
match node.op_type.as_str() {
|
||||
"Add" => {
|
||||
@ -355,7 +362,7 @@ fn simple_eval_(
|
||||
// HACK: current implementation of broadcast_pow cannot handle negative base,
|
||||
// so we use powf where we can, which *does* correctly handle negative base.
|
||||
if let Ok(exp) = (|| input1.to_dtype(DType::F64)?.to_scalar::<f64>())() {
|
||||
let output = input0.powf(exp as f64)?;
|
||||
let output = input0.powf(exp)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
} else {
|
||||
let output = input0.broadcast_pow(input1)?;
|
||||
@ -608,15 +615,13 @@ fn simple_eval_(
|
||||
}
|
||||
"Clip" => {
|
||||
let xs = get(&node.input[0])?;
|
||||
let xs = if node.input.len() >= 2 {
|
||||
let mins = get(&node.input[1])?;
|
||||
xs.broadcast_maximum(mins)?
|
||||
let xs = if let Some(mins) = get_opt(1) {
|
||||
xs.broadcast_maximum(mins?)?
|
||||
} else {
|
||||
xs.clone()
|
||||
};
|
||||
let xs = if node.input.len() >= 3 {
|
||||
let maxs = get(&node.input[2])?;
|
||||
xs.broadcast_minimum(maxs)?
|
||||
let xs = if let Some(maxs) = get_opt(2) {
|
||||
xs.broadcast_minimum(maxs?)?
|
||||
} else {
|
||||
xs.clone()
|
||||
};
|
||||
@ -638,7 +643,7 @@ fn simple_eval_(
|
||||
let mask = indices.lt(&zeros)?;
|
||||
mask.to_dtype(indices.dtype())?
|
||||
.broadcast_mul(&max)?
|
||||
.add(&indices)?
|
||||
.add(indices)?
|
||||
};
|
||||
|
||||
// In Pytorch or Numpy this can be done by indexing the xs tensor using the indices
|
||||
@ -759,7 +764,14 @@ fn simple_eval_(
|
||||
let cond = get(&node.input[0])?;
|
||||
let a = get(&node.input[1])?;
|
||||
let b = get(&node.input[2])?;
|
||||
let output = cond.where_cond(a, b)?;
|
||||
|
||||
// where_cond requires that all inputs are the same shape.
|
||||
// In contrast, the Where op in ONNX only requires that they are broadcastable.
|
||||
let shape = broadcast_shape_from_many(&[cond.dims(), a.dims(), b.dims()])?;
|
||||
let cond = cond.broadcast_as(shape.clone())?;
|
||||
let a = a.broadcast_as(shape.clone())?;
|
||||
let b = b.broadcast_as(shape)?;
|
||||
let output = cond.where_cond(&a, &b)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Conv" => {
|
||||
@ -962,6 +974,7 @@ fn simple_eval_(
|
||||
}
|
||||
rtype => bail!("unsupported 'value' type {rtype:?} for {}", node.name),
|
||||
};
|
||||
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast
|
||||
@ -1199,6 +1212,151 @@ fn simple_eval_(
|
||||
};
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
//https://github.com/onnx/onnx/blob/main/docs/Operators.md#Split
|
||||
// Version 18 impl
|
||||
"Split" => {
|
||||
let input_tensor = get(&node.input[0])?;
|
||||
let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
|
||||
let axis = input_tensor.normalize_axis(axis)?;
|
||||
|
||||
// Determine split sizes
|
||||
let splits = if node.input.len() > 1 {
|
||||
// If the split tensor is provided, use it to determine sizes
|
||||
let split_tensor = get(&node.input[1])?.to_vec1::<i64>()?;
|
||||
split_tensor.iter().map(|&x| x as usize).collect::<Vec<_>>()
|
||||
} else {
|
||||
let num_outputs = if let Some(&num_outputs_attrib) =
|
||||
get_attr_opt::<i64>(node, "num_outputs")?
|
||||
{
|
||||
num_outputs_attrib as usize
|
||||
} else {
|
||||
node.output.len()
|
||||
};
|
||||
|
||||
let input_dim = input_tensor.dim(axis)?;
|
||||
|
||||
let mut split_sizes =
|
||||
vec![input_dim / num_outputs as usize; num_outputs as usize];
|
||||
let remainder = input_dim % num_outputs as usize;
|
||||
if remainder > 0 {
|
||||
// If there's a remainder, add it to the last split size
|
||||
split_sizes[num_outputs as usize - 1] += remainder;
|
||||
}
|
||||
|
||||
split_sizes
|
||||
};
|
||||
|
||||
// Perform the split operation
|
||||
let mut outputs = vec![];
|
||||
let mut start = 0;
|
||||
for &size in &splits {
|
||||
let end = start + size;
|
||||
let slice = input_tensor.narrow(axis, start, size)?;
|
||||
outputs.push(slice);
|
||||
start = end;
|
||||
}
|
||||
|
||||
// Insert the split outputs into the values map
|
||||
for (output, slice) in node.output.iter().zip(outputs.into_iter()) {
|
||||
values.insert(output.clone(), slice);
|
||||
}
|
||||
}
|
||||
//https://github.com/onnx/onnx/blob/main/docs/Operators.md#Expand
|
||||
// Version 13 impl
|
||||
"Expand" => {
|
||||
// unlike broadcast_to, expand allows for the output shape to
|
||||
// be different from the specified shape.
|
||||
let input_tensor = get(&node.input[0])?;
|
||||
let input_shape = get(&node.input[1])?;
|
||||
|
||||
// Check that the shape tensor is 1D
|
||||
if input_shape.rank() != 1 {
|
||||
bail!(
|
||||
"Expand expects 'shape' input to be 1D tensor: {:?}",
|
||||
input_shape
|
||||
);
|
||||
}
|
||||
let input_tensor_dims = input_tensor.dims();
|
||||
let input_shape_dims = input_shape
|
||||
.to_vec1::<i64>()?
|
||||
.into_iter()
|
||||
.map(|x| x as usize)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let target_shape = broadcast_shape(input_tensor_dims, input_shape_dims.as_slice())?;
|
||||
|
||||
let expanded_tensor = input_tensor.broadcast_as(target_shape)?;
|
||||
|
||||
values.insert(node.output[0].clone(), expanded_tensor);
|
||||
}
|
||||
//https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceSum
|
||||
// Version 13 impl
|
||||
"ReduceSum" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let axes = get_opt(1);
|
||||
let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1);
|
||||
let noop_with_empty_axes = get_attr_opt::<i64>(node, "noop_with_empty_axes")?
|
||||
.copied()
|
||||
.unwrap_or(0);
|
||||
|
||||
let axes = match axes {
|
||||
Some(Ok(axes)) => axes
|
||||
.to_vec1::<i64>()?
|
||||
.into_iter()
|
||||
.map(|x| x as usize)
|
||||
.collect::<Vec<_>>(),
|
||||
Some(Err(_)) | None => {
|
||||
if noop_with_empty_axes == 1 {
|
||||
vec![]
|
||||
} else {
|
||||
(0..input.rank()).collect()
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let output = if keepdims == 1 {
|
||||
input.sum_keepdim(axes)?
|
||||
} else {
|
||||
input.sum(axes)?
|
||||
};
|
||||
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceL2
|
||||
// Version 18 impl
|
||||
"ReduceL2" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let axes = get_opt(1);
|
||||
let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1);
|
||||
let noop_with_empty_axes = get_attr_opt::<i64>(node, "noop_with_empty_axes")?
|
||||
.copied()
|
||||
.unwrap_or(0);
|
||||
|
||||
let input_sq = input.sqr()?;
|
||||
|
||||
let axes = match axes {
|
||||
Some(axes) => axes?
|
||||
.to_vec1::<i64>()?
|
||||
.into_iter()
|
||||
.map(|x| x as usize)
|
||||
.collect::<Vec<_>>(),
|
||||
None => {
|
||||
if noop_with_empty_axes == 1 {
|
||||
vec![]
|
||||
} else {
|
||||
(0..input_sq.rank()).collect()
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let output = if keepdims == 1 {
|
||||
input_sq.sum_keepdim(axes)?.sqrt()?
|
||||
} else {
|
||||
input_sq.sum(axes)?.sqrt()?
|
||||
};
|
||||
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
random_type @ ("RandomUniform" | "RandomNormal") => {
|
||||
let dt: i64 = get_attr_opt(node, "dtype")?.copied().unwrap_or(1); // 1 is float
|
||||
// type by
|
||||
@ -1395,13 +1553,6 @@ fn simple_eval_(
|
||||
// This tensor has shape `[num_directions, 4*hidden_size, hidden_size]`.
|
||||
let r = get(&node.input[2])?;
|
||||
|
||||
let get_opt = |i: usize| {
|
||||
node.input
|
||||
.get(i)
|
||||
.filter(|s: &&String| !s.is_empty())
|
||||
.map(|s| get(s))
|
||||
};
|
||||
|
||||
// The bias tensor for input gate.
|
||||
// Concatenation of `[Wb[iofc], Rb[iofc]]`, and `[WBb[iofc], RBb[iofc]]` (if bidirectional) along dimension 0.
|
||||
// This tensor has shape `[num_directions, 8*hidden_size]`.
|
||||
@ -1488,7 +1639,7 @@ fn simple_eval_(
|
||||
let w = w.get(0)?; // w[iofc] has shape [4*hidden_size, input_size]
|
||||
let r = r.get(0)?; // r[iofc] has shape [4*hidden_size, hidden_size]
|
||||
let b = b.get(0)?; // concat of [wb[iofc],rb[iofc]] has shape [8*hidden_size]
|
||||
let idx_wb = Tensor::arange(0 * hidden_size, 4 * hidden_size, x.device())?;
|
||||
let idx_wb = Tensor::arange(0, 4 * hidden_size, x.device())?;
|
||||
let idx_rb = Tensor::arange(4 * hidden_size, 8 * hidden_size, x.device())?;
|
||||
let wb = b.index_select(&idx_wb, 0)?;
|
||||
let rb = b.index_select(&idx_rb, 0)?;
|
||||
@ -1497,8 +1648,8 @@ fn simple_eval_(
|
||||
|
||||
// w, r, wb, rb are all iofc but lstm expects ifco
|
||||
// so we need to move some stuff around
|
||||
let idx_i = Tensor::arange(0 * hidden_size, 1 * hidden_size, x.device())?;
|
||||
let idx_o = Tensor::arange(1 * hidden_size, 2 * hidden_size, x.device())?;
|
||||
let idx_i = Tensor::arange(0, hidden_size, x.device())?;
|
||||
let idx_o = Tensor::arange(hidden_size, 2 * hidden_size, x.device())?;
|
||||
let idx_f = Tensor::arange(2 * hidden_size, 3 * hidden_size, x.device())?;
|
||||
let idx_c = Tensor::arange(3 * hidden_size, 4 * hidden_size, x.device())?;
|
||||
let idx_ifco = Tensor::cat(&[&idx_i, &idx_f, &idx_c, &idx_o], 0)?;
|
||||
@ -1522,7 +1673,7 @@ fn simple_eval_(
|
||||
)?;
|
||||
|
||||
let mut lstm_state = candle_nn::rnn::LSTMState::new(h, c);
|
||||
let mut h_acc = if node.output.get(0).map(String::as_str).unwrap_or("") != "" {
|
||||
let mut h_acc = if node.output.first().map(String::as_str).unwrap_or("") != "" {
|
||||
Some(vec![])
|
||||
} else {
|
||||
None
|
||||
@ -1536,7 +1687,7 @@ fn simple_eval_(
|
||||
}
|
||||
|
||||
assert_eq!(num_directions, 1, "if support for bidirectional is ever added, outputs will have to be concatenated, not simply reshaped");
|
||||
if let Some(name) = node.output.get(0) {
|
||||
if let Some(name) = node.output.first() {
|
||||
let h_acc = h_acc.as_ref().unwrap();
|
||||
let h_acc = lstm.states_to_tensor(h_acc)?;
|
||||
let h_acc = h_acc.reshape((
|
||||
@ -1580,3 +1731,36 @@ fn simple_eval_(
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn broadcast_shape(shape_a: &[usize], shape_b: &[usize]) -> Result<Vec<usize>> {
|
||||
let (longest, shortest) = if shape_a.len() > shape_b.len() {
|
||||
(shape_a, shape_b)
|
||||
} else {
|
||||
(shape_b, shape_a)
|
||||
};
|
||||
let diff = longest.len() - shortest.len();
|
||||
let mut target_shape = longest[0..diff].to_vec();
|
||||
for (dim1, dim2) in longest[diff..].iter().zip(shortest.iter()) {
|
||||
if *dim1 == *dim2 || *dim2 == 1 || *dim1 == 1 {
|
||||
target_shape.push(usize::max(*dim1, *dim2));
|
||||
} else {
|
||||
bail!(
|
||||
"Expand: incompatible shapes for broadcast, {:?} and {:?}",
|
||||
shape_a,
|
||||
shape_b
|
||||
);
|
||||
}
|
||||
}
|
||||
Ok(target_shape)
|
||||
}
|
||||
|
||||
fn broadcast_shape_from_many(shapes: &[&[usize]]) -> Result<Vec<usize>> {
|
||||
if shapes.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let mut shape_out = shapes[0].to_vec();
|
||||
for shape in shapes[1..].iter() {
|
||||
shape_out = broadcast_shape(&shape_out, shape)?;
|
||||
}
|
||||
Ok(shape_out)
|
||||
}
|
||||
|
@ -1,12 +1,5 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use candle::test_utils::to_vec2_round;
|
||||
use candle::{DType, Device, NdArray, Result, Tensor};
|
||||
use candle_onnx::eval::Value;
|
||||
use candle_onnx::onnx::attribute_proto::AttributeType;
|
||||
use candle_onnx::onnx::tensor_proto::DataType;
|
||||
use candle_onnx::onnx::tensor_shape_proto::{dimension, Dimension};
|
||||
@ -3574,312 +3567,312 @@ fn test_lstm() -> Result<()> {
|
||||
let number_directions = 1;
|
||||
let weight_ih_l0 = Tensor::from_vec::<_, f32>(
|
||||
vec![
|
||||
-1.5255959033966064,
|
||||
-0.7502318024635315,
|
||||
-0.6539809107780457,
|
||||
-1.6094847917556763,
|
||||
-0.1001671776175499,
|
||||
-0.6091889142990112,
|
||||
-0.9797722697257996,
|
||||
-1.6090962886810303,
|
||||
-0.7121446132659912,
|
||||
0.30372199416160583,
|
||||
-0.777314305305481,
|
||||
-0.25145524740219116,
|
||||
-0.22227048873901367,
|
||||
1.6871134042739868,
|
||||
0.22842517495155334,
|
||||
0.46763551235198975,
|
||||
-0.6969724297523499,
|
||||
-1.1607614755630493,
|
||||
0.6995424032211304,
|
||||
0.1990816295146942,
|
||||
0.8656923770904541,
|
||||
0.2444038987159729,
|
||||
-0.6629113554954529,
|
||||
0.8073082566261292,
|
||||
1.1016806364059448,
|
||||
-0.1759360432624817,
|
||||
-2.2455577850341797,
|
||||
-1.4464579820632935,
|
||||
0.0611552819609642,
|
||||
-0.6177444458007812,
|
||||
-0.7980698347091675,
|
||||
-0.13162320852279663,
|
||||
1.8793457746505737,
|
||||
-0.07213178277015686,
|
||||
0.15777060389518738,
|
||||
-0.7734549045562744,
|
||||
0.1990565061569214,
|
||||
0.04570277780294418,
|
||||
0.15295691788196564,
|
||||
-0.47567880153656006,
|
||||
-0.11101982742547989,
|
||||
0.2927352488040924,
|
||||
-0.1578451544046402,
|
||||
-0.028787139803171158,
|
||||
0.4532545804977417,
|
||||
1.1421611309051514,
|
||||
0.2486107051372528,
|
||||
-1.7754007577896118,
|
||||
-0.025502461940050125,
|
||||
-1.023330569267273,
|
||||
-0.5961851477622986,
|
||||
-1.0055307149887085,
|
||||
0.42854228615760803,
|
||||
1.4760777950286865,
|
||||
-1.7868678569793701,
|
||||
1.610317587852478,
|
||||
-0.703956663608551,
|
||||
-0.18526579439640045,
|
||||
-0.9962350726127625,
|
||||
-0.8312552571296692,
|
||||
-1.525_595_9,
|
||||
-0.750_231_8,
|
||||
-0.653_980_9,
|
||||
-1.609_484_8,
|
||||
-0.100_167_18,
|
||||
-0.609_188_9,
|
||||
-0.979_772_27,
|
||||
-1.609_096_3,
|
||||
-0.712_144_6,
|
||||
0.303_722,
|
||||
-0.777_314_3,
|
||||
-0.251_455_25,
|
||||
-0.222_270_49,
|
||||
1.687_113_4,
|
||||
0.228_425_17,
|
||||
0.467_635_5,
|
||||
-0.696_972_4,
|
||||
-1.160_761_5,
|
||||
0.699_542_4,
|
||||
0.199_081_63,
|
||||
0.865_692_4,
|
||||
0.244_403_9,
|
||||
-0.662_911_36,
|
||||
0.807_308_26,
|
||||
1.101_680_6,
|
||||
-0.175_936_04,
|
||||
-2.245_557_8,
|
||||
-1.446_458,
|
||||
0.061_155_282,
|
||||
-0.617_744_45,
|
||||
-0.798_069_83,
|
||||
-0.131_623_21,
|
||||
1.879_345_8,
|
||||
-0.072_131_78,
|
||||
0.157_770_6,
|
||||
-0.773_454_9,
|
||||
0.199_056_5,
|
||||
0.045_702_778,
|
||||
0.152_956_92,
|
||||
-0.475_678_8,
|
||||
-0.111_019_83,
|
||||
0.292_735_25,
|
||||
-0.157_845_15,
|
||||
-0.028_787_14,
|
||||
0.453_254_58,
|
||||
1.142_161_1,
|
||||
0.248_610_7,
|
||||
-1.775_400_8,
|
||||
-0.025_502_462,
|
||||
-1.023_330_6,
|
||||
-0.596_185_15,
|
||||
-1.005_530_7,
|
||||
0.428_542_3,
|
||||
1.476_077_8,
|
||||
-1.786_867_9,
|
||||
1.610_317_6,
|
||||
-0.703_956_66,
|
||||
-0.185_265_8,
|
||||
-0.996_235_1,
|
||||
-0.831_255_26,
|
||||
],
|
||||
(20, 3),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
let weight_hh_l0 = Tensor::from_vec::<_, f32>(
|
||||
vec![
|
||||
0.4099724292755127,
|
||||
0.4084506630897522,
|
||||
0.25786539912223816,
|
||||
1.095021367073059,
|
||||
-0.5064865946769714,
|
||||
0.09977540373802185,
|
||||
-0.653973400592804,
|
||||
0.731693685054779,
|
||||
-1.456732988357544,
|
||||
1.6089353561401367,
|
||||
0.09376997500658035,
|
||||
-1.2597490549087524,
|
||||
0.25463348627090454,
|
||||
-0.5019572973251343,
|
||||
-1.041200041770935,
|
||||
0.7322672009468079,
|
||||
1.3075355291366577,
|
||||
-1.1627987623214722,
|
||||
0.11963611096143723,
|
||||
-0.1631353348493576,
|
||||
0.6614453196525574,
|
||||
1.1899205446243286,
|
||||
0.8165339231491089,
|
||||
-0.9135236144065857,
|
||||
-0.3538065254688263,
|
||||
0.7639270424842834,
|
||||
-0.5889506936073303,
|
||||
-0.7635973691940308,
|
||||
1.3352056741714478,
|
||||
0.6042736172676086,
|
||||
-0.10344208031892776,
|
||||
-0.15121692419052124,
|
||||
1.2465683221817017,
|
||||
0.505721390247345,
|
||||
0.9505112171173096,
|
||||
1.2966482639312744,
|
||||
0.873796284198761,
|
||||
-0.5602594017982483,
|
||||
1.2857844829559326,
|
||||
0.8168238401412964,
|
||||
-1.464799404144287,
|
||||
-1.2629283666610718,
|
||||
1.122018814086914,
|
||||
1.5663341283798218,
|
||||
2.558138370513916,
|
||||
-0.23336388170719147,
|
||||
-0.013472129590809345,
|
||||
1.8606348037719727,
|
||||
1.549620509147644,
|
||||
0.34762924909591675,
|
||||
0.09300802648067474,
|
||||
0.6147403120994568,
|
||||
0.7123645544052124,
|
||||
-1.7765072584152222,
|
||||
0.3538645803928375,
|
||||
1.1996132135391235,
|
||||
-0.7122589349746704,
|
||||
-0.620034396648407,
|
||||
-0.22813494503498077,
|
||||
-0.7892746329307556,
|
||||
-1.6111117601394653,
|
||||
-1.8716129064559937,
|
||||
0.5430836081504822,
|
||||
0.6606786251068115,
|
||||
0.270527720451355,
|
||||
0.5596919655799866,
|
||||
-0.31839630007743835,
|
||||
1.5117206573486328,
|
||||
-1.363267183303833,
|
||||
-0.9832196235656738,
|
||||
1.5112667083740234,
|
||||
0.6418707370758057,
|
||||
-0.7474458813667297,
|
||||
-0.923438549041748,
|
||||
0.5733984112739563,
|
||||
-0.10929951071739197,
|
||||
0.5181121230125427,
|
||||
0.10653535276651382,
|
||||
0.26924076676368713,
|
||||
1.3247679471969604,
|
||||
0.037456899881362915,
|
||||
-0.6378393173217773,
|
||||
-0.8147554397583008,
|
||||
-0.6895065307617188,
|
||||
0.8436542749404907,
|
||||
1.1657012701034546,
|
||||
0.5269321799278259,
|
||||
1.6192532777786255,
|
||||
-0.963976263999939,
|
||||
0.14152038097381592,
|
||||
-0.1636609584093094,
|
||||
-0.3582225739955902,
|
||||
1.7222793102264404,
|
||||
-0.3035756051540375,
|
||||
0.23887419700622559,
|
||||
1.3440011739730835,
|
||||
0.1032256931066513,
|
||||
1.1003541946411133,
|
||||
-0.3416801989078522,
|
||||
0.947338879108429,
|
||||
0.409_972_43,
|
||||
0.408_450_66,
|
||||
0.257_865_4,
|
||||
1.095_021_4,
|
||||
-0.506_486_6,
|
||||
0.099_775_404,
|
||||
-0.653_973_4,
|
||||
0.731_693_7,
|
||||
-1.456_733,
|
||||
1.608_935_4,
|
||||
0.093_769_975,
|
||||
-1.259_749,
|
||||
0.254_633_5,
|
||||
-0.501_957_3,
|
||||
-1.041_2,
|
||||
0.732_267_2,
|
||||
1.307_535_5,
|
||||
-1.162_798_8,
|
||||
0.119_636_11,
|
||||
-0.163_135_33,
|
||||
0.661_445_3,
|
||||
1.189_920_5,
|
||||
0.816_533_9,
|
||||
-0.913_523_6,
|
||||
-0.353_806_53,
|
||||
0.763_927_04,
|
||||
-0.588_950_7,
|
||||
-0.763_597_37,
|
||||
1.335_205_7,
|
||||
0.604_273_6,
|
||||
-0.103_442_08,
|
||||
-0.151_216_92,
|
||||
1.246_568_3,
|
||||
0.505_721_4,
|
||||
0.950_511_2,
|
||||
1.296_648_3,
|
||||
0.873_796_3,
|
||||
-0.560_259_4,
|
||||
1.285_784_5,
|
||||
0.816_823_84,
|
||||
-1.464_799_4,
|
||||
-1.262_928_4,
|
||||
1.122_018_8,
|
||||
1.566_334_1,
|
||||
2.558_138_4,
|
||||
-0.233_363_88,
|
||||
-0.013_472_13,
|
||||
1.860_634_8,
|
||||
1.549_620_5,
|
||||
0.347_629_25,
|
||||
0.093_008_03,
|
||||
0.614_740_3,
|
||||
0.712_364_55,
|
||||
-1.776_507_3,
|
||||
0.353_864_58,
|
||||
1.199_613_2,
|
||||
-0.712_258_93,
|
||||
-0.620_034_4,
|
||||
-0.228_134_95,
|
||||
-0.789_274_63,
|
||||
-1.611_111_8,
|
||||
-1.871_612_9,
|
||||
0.543_083_6,
|
||||
0.660_678_6,
|
||||
0.270_527_72,
|
||||
0.559_691_97,
|
||||
-0.318_396_3,
|
||||
1.511_720_7,
|
||||
-1.363_267_2,
|
||||
-0.983_219_6,
|
||||
1.511_266_7,
|
||||
0.641_870_74,
|
||||
-0.747_445_9,
|
||||
-0.923_438_55,
|
||||
0.573_398_4,
|
||||
-0.109_299_51,
|
||||
0.518_112_1,
|
||||
0.106_535_35,
|
||||
0.269_240_77,
|
||||
1.324_768,
|
||||
0.037_456_9,
|
||||
-0.637_839_3,
|
||||
-0.814_755_44,
|
||||
-0.689_506_53,
|
||||
0.843_654_3,
|
||||
1.165_701_3,
|
||||
0.526_932_2,
|
||||
1.619_253_3,
|
||||
-0.963_976_26,
|
||||
0.141_520_38,
|
||||
-0.163_660_96,
|
||||
-0.358_222_57,
|
||||
1.722_279_3,
|
||||
-0.303_575_6,
|
||||
0.238_874_2,
|
||||
1.344_001_2,
|
||||
0.103_225_69,
|
||||
1.100_354_2,
|
||||
-0.341_680_2,
|
||||
0.947_338_9,
|
||||
],
|
||||
(20, 5),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
let bias_ih_l0 = Tensor::from_vec::<_, f32>(
|
||||
vec![
|
||||
-0.568515956401825,
|
||||
0.8375961780548096,
|
||||
1.783660650253296,
|
||||
-0.1954246610403061,
|
||||
0.235193133354187,
|
||||
1.9142433404922485,
|
||||
1.8364111185073853,
|
||||
1.324532389640808,
|
||||
-0.07051458209753036,
|
||||
0.34697940945625305,
|
||||
-0.653679609298706,
|
||||
1.5586202144622803,
|
||||
0.2185661494731903,
|
||||
-0.5743072628974915,
|
||||
1.4571250677108765,
|
||||
1.7709556818008423,
|
||||
-2.0172998905181885,
|
||||
0.42350319027900696,
|
||||
0.5730220079421997,
|
||||
-1.7962429523468018,
|
||||
-0.568_515_96,
|
||||
0.837_596_2,
|
||||
1.783_660_7,
|
||||
-0.195_424_66,
|
||||
0.235_193_13,
|
||||
1.914_243_3,
|
||||
1.836_411_1,
|
||||
1.324_532_4,
|
||||
-0.070_514_58,
|
||||
0.346_979_4,
|
||||
-0.653_679_6,
|
||||
1.558_620_2,
|
||||
0.218_566_15,
|
||||
-0.574_307_26,
|
||||
1.457_125_1,
|
||||
1.770_955_7,
|
||||
-2.017_3,
|
||||
0.423_503_2,
|
||||
0.573_022,
|
||||
-1.796_243,
|
||||
],
|
||||
(20,),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
let bias_hh_l0 = Tensor::from_vec::<_, f32>(
|
||||
vec![
|
||||
1.2470403909683228,
|
||||
1.2738511562347412,
|
||||
0.3909492492675781,
|
||||
0.387210488319397,
|
||||
0.14440394937992096,
|
||||
0.7771684527397156,
|
||||
-2.3381125926971436,
|
||||
-0.829120397567749,
|
||||
1.1661391258239746,
|
||||
1.4786574840545654,
|
||||
0.26760873198509216,
|
||||
0.7561198472976685,
|
||||
-0.5873361229896545,
|
||||
-2.061920642852783,
|
||||
0.4304734766483307,
|
||||
0.3376566171646118,
|
||||
-0.3437853455543518,
|
||||
-0.6172260642051697,
|
||||
1.2529692649841309,
|
||||
-0.05141742154955864,
|
||||
1.247_040_4,
|
||||
1.273_851_2,
|
||||
0.390_949_25,
|
||||
0.387_210_5,
|
||||
0.144_403_95,
|
||||
0.777_168_45,
|
||||
-2.338_112_6,
|
||||
-0.829_120_4,
|
||||
1.166_139_1,
|
||||
1.478_657_5,
|
||||
0.267_608_73,
|
||||
0.756_119_85,
|
||||
-0.587_336_1,
|
||||
-2.061_920_6,
|
||||
0.430_473_48,
|
||||
0.337_656_62,
|
||||
-0.343_785_35,
|
||||
-0.617_226_06,
|
||||
1.252_969_3,
|
||||
-0.051_417_42,
|
||||
],
|
||||
(20,),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
let input = Tensor::from_vec::<_, f32>(
|
||||
vec![
|
||||
0.6472128033638,
|
||||
-0.04116716980934143,
|
||||
-0.17749308049678802,
|
||||
-0.500039279460907,
|
||||
0.8672749400138855,
|
||||
-0.27319222688674927,
|
||||
-0.4607681334018707,
|
||||
-0.0990937128663063,
|
||||
0.47284480929374695,
|
||||
1.0049484968185425,
|
||||
-0.2871420383453369,
|
||||
-1.1618621349334717,
|
||||
0.647_212_8,
|
||||
-0.041_167_17,
|
||||
-0.177_493_08,
|
||||
-0.500_039_3,
|
||||
0.867_274_94,
|
||||
-0.273_192_23,
|
||||
-0.460_768_13,
|
||||
-0.099_093_71,
|
||||
0.472_844_8,
|
||||
1.004_948_5,
|
||||
-0.287_142_04,
|
||||
-1.161_862_1,
|
||||
],
|
||||
(4, 1, 3),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
let h0 = Tensor::from_vec::<_, f32>(
|
||||
vec![
|
||||
0.02758178487420082,
|
||||
0.5652382373809814,
|
||||
-0.011487378738820553,
|
||||
0.6706400513648987,
|
||||
-0.4929250478744507,
|
||||
0.027_581_785,
|
||||
0.565_238_24,
|
||||
-0.011_487_379,
|
||||
0.670_640_05,
|
||||
-0.492_925_05,
|
||||
],
|
||||
(1, 1, 5),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
let c0 = Tensor::from_vec::<_, f32>(
|
||||
vec![
|
||||
1.505028486251831,
|
||||
-2.32635498046875,
|
||||
1.6168899536132812,
|
||||
-0.9026237726211548,
|
||||
0.17366823554039001,
|
||||
1.505_028_5,
|
||||
-2.326_355,
|
||||
1.616_89,
|
||||
-0.902_623_8,
|
||||
0.173_668_24,
|
||||
],
|
||||
(1, 1, 5),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
let output = Tensor::from_vec::<_, f32>(
|
||||
vec![
|
||||
0.5956016778945923,
|
||||
-0.01723279245197773,
|
||||
0.11035571992397308,
|
||||
-0.49323174357414246,
|
||||
0.047632161527872086,
|
||||
0.6358451843261719,
|
||||
0.040328118950128555,
|
||||
-0.3788611590862274,
|
||||
-0.7464339733123779,
|
||||
0.20080909132957458,
|
||||
0.5840265154838562,
|
||||
0.1453288197517395,
|
||||
-0.7345298528671265,
|
||||
-0.5214304327964783,
|
||||
0.21903817355632782,
|
||||
0.7420451641082764,
|
||||
0.31943878531455994,
|
||||
-0.04726646468043327,
|
||||
-0.2823849618434906,
|
||||
0.2713133990764618,
|
||||
0.595_601_7,
|
||||
-0.017_232_792,
|
||||
0.110_355_72,
|
||||
-0.493_231_74,
|
||||
0.047_632_16,
|
||||
0.635_845_2,
|
||||
0.040_328_12,
|
||||
-0.378_861_16,
|
||||
-0.746_434,
|
||||
0.200_809_09,
|
||||
0.584_026_5,
|
||||
0.145_328_82,
|
||||
-0.734_529_85,
|
||||
-0.521_430_43,
|
||||
0.219_038_17,
|
||||
0.742_045_16,
|
||||
0.319_438_8,
|
||||
-0.047_266_465,
|
||||
-0.282_384_96,
|
||||
0.271_313_4,
|
||||
],
|
||||
(4, 1, 5),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
let hn = Tensor::from_vec::<_, f32>(
|
||||
vec![
|
||||
0.7420451641082764,
|
||||
0.31943878531455994,
|
||||
-0.04726646468043327,
|
||||
-0.2823849618434906,
|
||||
0.2713133990764618,
|
||||
0.742_045_16,
|
||||
0.319_438_8,
|
||||
-0.047_266_465,
|
||||
-0.282_384_96,
|
||||
0.271_313_4,
|
||||
],
|
||||
(1, 1, 5),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
let cn = Tensor::from_vec::<_, f32>(
|
||||
vec![
|
||||
0.9630558490753174,
|
||||
1.0033069849014282,
|
||||
-1.754899024963379,
|
||||
-1.5967122316360474,
|
||||
0.8252924680709839,
|
||||
0.963_055_85,
|
||||
1.003_307,
|
||||
-1.754_899,
|
||||
-1.596_712_2,
|
||||
0.825_292_47,
|
||||
],
|
||||
(1, 1, 5),
|
||||
&Device::Cpu,
|
||||
@ -3929,8 +3922,8 @@ fn test_lstm() -> Result<()> {
|
||||
let idx_iofc = {
|
||||
let stride = hidden_size as i64;
|
||||
let dev = weight_ih_l0.device();
|
||||
let idx_i = Tensor::arange(0 * stride, 1 * stride, dev)?;
|
||||
let idx_f = Tensor::arange(1 * stride, 2 * stride, dev)?;
|
||||
let idx_i = Tensor::arange(0, stride, dev)?;
|
||||
let idx_f = Tensor::arange(stride, 2 * stride, dev)?;
|
||||
let idx_g = Tensor::arange(2 * stride, 3 * stride, dev)?;
|
||||
let idx_o = Tensor::arange(3 * stride, 4 * stride, dev)?;
|
||||
|
||||
@ -3966,17 +3959,346 @@ fn test_lstm() -> Result<()> {
|
||||
Ok(diffs.iter().all(|f| f.abs() < 0.0001))
|
||||
};
|
||||
assert!(
|
||||
diff_close_enough(&output, &actual_output)?,
|
||||
diff_close_enough(&output, actual_output)?,
|
||||
"output did not match expected\n{actual_output}\n{output}",
|
||||
);
|
||||
assert!(
|
||||
diff_close_enough(&hn, &actual_hn)?,
|
||||
diff_close_enough(&hn, actual_hn)?,
|
||||
"hn did not match expected\n{actual_hn}\n{hn}",
|
||||
);
|
||||
assert!(
|
||||
diff_close_enough(&cn, &actual_cn)?,
|
||||
diff_close_enough(&cn, actual_cn)?,
|
||||
"cn did not match expected\n{actual_cn}\n{cn}",
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_dim_changed() -> Result<()> {
|
||||
// Create a manual graph for the Expand operation
|
||||
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||
node: vec![NodeProto {
|
||||
op_type: "Expand".to_string(),
|
||||
domain: "".to_string(),
|
||||
attribute: vec![],
|
||||
input: vec!["data".to_string(), "new_shape".to_string()],
|
||||
output: vec!["expanded".to_string()],
|
||||
name: "".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
}],
|
||||
input: vec![
|
||||
ValueInfoProto {
|
||||
name: "data".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
},
|
||||
ValueInfoProto {
|
||||
name: "new_shape".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
},
|
||||
],
|
||||
output: vec![ValueInfoProto {
|
||||
name: "expanded".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
}],
|
||||
..GraphProto::default()
|
||||
}));
|
||||
|
||||
// Input tensor with shape [3, 1]
|
||||
let data = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32], (3, 1), &Device::Cpu)?;
|
||||
|
||||
// New shape tensor: [2, 1, 6]
|
||||
let new_shape = Tensor::from_vec(vec![2i64, 1, 6], (3,), &Device::Cpu)?;
|
||||
|
||||
// Expected output after expansion
|
||||
let expected = Tensor::from_vec(
|
||||
vec![
|
||||
1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32,
|
||||
2.0f32, 3.0f32, 3.0f32, 3.0f32, 3.0f32, 3.0f32, 3.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32,
|
||||
1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 3.0f32, 3.0f32, 3.0f32,
|
||||
3.0f32, 3.0f32, 3.0f32,
|
||||
],
|
||||
(2, 3, 6),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
|
||||
// Execute the model evaluation
|
||||
let inputs = HashMap::from_iter([
|
||||
("data".to_string(), data),
|
||||
("new_shape".to_string(), new_shape),
|
||||
]);
|
||||
let result = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
|
||||
// Retrieve and compare the result
|
||||
let expanded = result.get("expanded").expect("Output 'expanded' not found");
|
||||
|
||||
assert_eq!(expanded.to_vec3::<f32>()?, expected.to_vec3::<f32>()?);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn make_graph_helper(
|
||||
op_name: &str,
|
||||
inputs: &[&str],
|
||||
outputs: &[&str],
|
||||
attribs: Vec<AttributeProto>,
|
||||
) -> ModelProto {
|
||||
create_model_proto_with_graph(Some(GraphProto {
|
||||
node: vec![NodeProto {
|
||||
op_type: op_name.to_string(),
|
||||
domain: "".to_string(),
|
||||
attribute: attribs,
|
||||
input: inputs.iter().map(|s| s.to_string()).collect(),
|
||||
output: outputs.iter().map(|s| s.to_string()).collect(),
|
||||
name: "".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
}],
|
||||
input: inputs
|
||||
.iter()
|
||||
.map(|name| ValueInfoProto {
|
||||
name: name.to_string(),
|
||||
..ValueInfoProto::default()
|
||||
})
|
||||
.collect(),
|
||||
output: outputs
|
||||
.iter()
|
||||
.map(|name| ValueInfoProto {
|
||||
name: name.to_string(),
|
||||
..ValueInfoProto::default()
|
||||
})
|
||||
.collect(),
|
||||
..GraphProto::default()
|
||||
}))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_dim_unchanged() -> Result<()> {
|
||||
// Create a manual graph for the Expand operation
|
||||
let manual_graph = make_graph_helper("Expand", &["data", "new_shape"], &["expanded"], vec![]);
|
||||
|
||||
// Input tensor with shape [3, 1] and dtype f32
|
||||
let data = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32], (3, 1), &Device::Cpu)?;
|
||||
|
||||
// New shape tensor: [3, 4]
|
||||
let new_shape = Tensor::from_vec(vec![3i64, 4], (2,), &Device::Cpu)?;
|
||||
|
||||
// Expected output after expansion, dtype f32
|
||||
let expected = Tensor::from_vec(
|
||||
vec![
|
||||
1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 3.0f32, 3.0f32, 3.0f32,
|
||||
3.0f32,
|
||||
],
|
||||
(3, 4),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
|
||||
// Execute the model evaluation
|
||||
let inputs = HashMap::from_iter([
|
||||
("data".to_string(), data),
|
||||
("new_shape".to_string(), new_shape),
|
||||
]);
|
||||
let result = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
|
||||
// Retrieve and compare the result
|
||||
let expanded = result.get("expanded").expect("Output 'expanded' not found");
|
||||
assert_eq!(expanded.to_vec2::<f32>()?, expected.to_vec2::<f32>()?);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn make_split_graph_helper(inputs: &[&str], outputs: &[&str], axis: i64) -> ModelProto {
|
||||
let attribs = vec![AttributeProto {
|
||||
name: "axis".to_string(),
|
||||
r#type: AttributeType::Int.into(),
|
||||
i: axis,
|
||||
..AttributeProto::default()
|
||||
}];
|
||||
|
||||
make_graph_helper("Split", inputs, outputs, attribs)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_split_equal_parts_1d_opset13() -> Result<()> {
|
||||
let input = Tensor::from_vec(
|
||||
vec![1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32, 6.0f32],
|
||||
(6,),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
let mut inputs = HashMap::new();
|
||||
inputs.insert("input".to_string(), input);
|
||||
|
||||
{
|
||||
let manual_graph =
|
||||
make_split_graph_helper(&["input"], &["output_1", "output_2", "output_3"], 0);
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs.clone())?;
|
||||
assert_eq!(eval.len(), 3);
|
||||
|
||||
let out1 = eval.get("output_1").expect("Output 'output_1' not found");
|
||||
let out2 = eval.get("output_2").expect("Output 'output_2' not found");
|
||||
let out3 = eval.get("output_3").expect("Output 'output_3' not found");
|
||||
|
||||
assert_eq!(out1.to_vec1::<f32>()?, vec![1.0f32, 2.0f32]);
|
||||
assert_eq!(out2.to_vec1::<f32>()?, vec![3.0f32, 4.0f32]);
|
||||
assert_eq!(out3.to_vec1::<f32>()?, vec![5.0f32, 6.0f32]);
|
||||
}
|
||||
|
||||
{
|
||||
let splits = Tensor::from_vec(vec![2i64, 4], (2,), &Device::Cpu)?;
|
||||
inputs.insert("split".to_string(), splits);
|
||||
|
||||
let manual_graph =
|
||||
make_split_graph_helper(&["input", "split"], &["output_1", "output_2"], 0);
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
assert_eq!(eval.len(), 2);
|
||||
|
||||
let out1 = eval.get("output_1").expect("Output 'output_1' not found");
|
||||
let out2 = eval.get("output_2").expect("Output 'output_2' not found");
|
||||
|
||||
assert_eq!(out1.to_vec1::<f32>()?, vec![1.0f32, 2.0f32]);
|
||||
assert_eq!(out2.to_vec1::<f32>()?, vec![3.0f32, 4.0f32, 5.0f32, 6.0f32]);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn make_reduce_sum_graph_helper(
|
||||
inputs: &[&str],
|
||||
outputs: &[&str],
|
||||
keepdims: Option<i64>,
|
||||
noop_with_empty_axes: Option<i64>,
|
||||
) -> ModelProto {
|
||||
let mut attribs = vec![];
|
||||
if let Some(keepdims) = keepdims {
|
||||
attribs.push(AttributeProto {
|
||||
name: "keepdims".to_string(),
|
||||
r#type: AttributeType::Int.into(),
|
||||
i: keepdims,
|
||||
..AttributeProto::default()
|
||||
});
|
||||
}
|
||||
if let Some(noop_with_empty_axes) = noop_with_empty_axes {
|
||||
attribs.push(AttributeProto {
|
||||
name: "noop_with_empty_axes".to_string(),
|
||||
r#type: AttributeType::Ints.into(),
|
||||
i: noop_with_empty_axes,
|
||||
..AttributeProto::default()
|
||||
});
|
||||
}
|
||||
make_graph_helper("ReduceSum", inputs, outputs, attribs)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reduce_sum_default_axes_keepdims() -> Result<()> {
|
||||
let manual_graph = make_reduce_sum_graph_helper(&["data", "axes"], &["reduced"], Some(1), None);
|
||||
|
||||
// Test with example data
|
||||
{
|
||||
let data = Tensor::from_vec(
|
||||
vec![
|
||||
1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
|
||||
],
|
||||
(3, 2, 2),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
// let axes = Tensor::from_vec(Vec::<i64>::new(), (0,), &Device::Cpu)?;
|
||||
|
||||
let mut inputs = HashMap::new();
|
||||
inputs.insert("data".to_string(), data);
|
||||
// inputs.insert("axes".to_string(), axes);
|
||||
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let reduced = eval.get("reduced").expect("Output 'reduced' not found");
|
||||
let expected = Tensor::from_vec(vec![78.0f32], (1, 1, 1), &Device::Cpu)?;
|
||||
|
||||
assert_eq!(reduced.to_vec3::<f32>()?, expected.to_vec3::<f32>()?);
|
||||
}
|
||||
|
||||
{
|
||||
let data = Tensor::from_vec(
|
||||
vec![
|
||||
-5.2f32, 7.8, -3.1, 9.4, 2.6, -8.7, 4.3, -1.9, 6.5, -0.8, -7.2, 3.6,
|
||||
],
|
||||
(3, 2, 2),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
|
||||
let mut inputs = HashMap::new();
|
||||
inputs.insert("data".to_string(), data.clone());
|
||||
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let reduced = eval.get("reduced").expect("Output 'reduced' not found");
|
||||
let expected = data.sum_all()?.reshape((1, 1, 1))?;
|
||||
|
||||
assert_eq!(reduced.to_vec3::<f32>()?, expected.to_vec3::<f32>()?);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reduce_sum_do_not_keep_dims() -> Result<()> {
|
||||
let manual_graph = make_reduce_sum_graph_helper(&["data", "axes"], &["reduced"], Some(0), None);
|
||||
|
||||
// Test with example data
|
||||
{
|
||||
let data = Tensor::from_vec(
|
||||
vec![
|
||||
1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
|
||||
],
|
||||
(3, 2, 2),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
let axes = Tensor::from_vec(vec![1i64], (1,), &Device::Cpu)?;
|
||||
|
||||
let mut inputs = HashMap::new();
|
||||
inputs.insert("data".to_string(), data);
|
||||
inputs.insert("axes".to_string(), axes);
|
||||
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let reduced = eval.get("reduced").expect("Output 'reduced' not found");
|
||||
let expected = Tensor::from_vec(
|
||||
vec![4.0f32, 6.0, 12.0, 14.0, 20.0, 22.0],
|
||||
(3, 2),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
|
||||
assert_eq!(reduced.to_vec2::<f32>()?, expected.to_vec2::<f32>()?);
|
||||
}
|
||||
|
||||
// Test with random data
|
||||
{
|
||||
let _shape = (3, 2, 2);
|
||||
let data = Tensor::from_vec(
|
||||
vec![
|
||||
-5.2f32, 7.8, -3.1, 9.4, 2.6, -8.7, 4.3, -1.9, 6.5, -0.8, -7.2, 3.6,
|
||||
],
|
||||
(3, 2, 2),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
let axes = Tensor::from_vec(vec![1i64], (1,), &Device::Cpu)?;
|
||||
|
||||
let mut inputs = HashMap::new();
|
||||
inputs.insert("data".to_string(), data.clone());
|
||||
inputs.insert("axes".to_string(), axes);
|
||||
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let reduced = eval.get("reduced").expect("Output 'reduced' not found");
|
||||
|
||||
// Calculate expected result
|
||||
let expected = data.sum(1)?;
|
||||
|
||||
assert_eq!(reduced.to_vec2::<f32>()?, expected.to_vec2::<f32>()?);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -92,28 +92,23 @@ impl ClipConfig {
|
||||
impl ClipModel {
|
||||
pub fn new(vs: candle_nn::VarBuilder, c: &ClipConfig) -> Result<Self> {
|
||||
let text_model = ClipTextTransformer::new(vs.pp("text_model"), &c.text_config)?;
|
||||
|
||||
let vision_model = ClipVisionTransformer::new(vs.pp("vision_model"), &c.vision_config)?;
|
||||
|
||||
let visual_projection = candle_nn::linear_no_bias(
|
||||
c.vision_config.embed_dim,
|
||||
c.vision_config.projection_dim,
|
||||
vs.pp("visual_projection"),
|
||||
)?;
|
||||
|
||||
let text_projection = candle_nn::linear_no_bias(
|
||||
c.text_config.embed_dim,
|
||||
c.text_config.projection_dim,
|
||||
vs.pp("text_projection"),
|
||||
)?;
|
||||
|
||||
// originally nn.Parameter
|
||||
let logit_scale = if vs.contains_tensor("logit_scale") {
|
||||
vs.get(&[], "logit_scale")?
|
||||
} else {
|
||||
Tensor::new(&[c.logit_scale_init_value], vs.device())?
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
text_model,
|
||||
vision_model,
|
||||
|
@ -77,7 +77,7 @@ impl ClipTextEmbeddings {
|
||||
)?;
|
||||
let position_ids =
|
||||
Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
|
||||
Ok(ClipTextEmbeddings {
|
||||
Ok(Self {
|
||||
token_embedding,
|
||||
position_embedding,
|
||||
position_ids,
|
||||
@ -298,7 +298,7 @@ impl ClipTextTransformer {
|
||||
})
|
||||
}
|
||||
|
||||
// TODO: rewrrite to newer version
|
||||
// TODO: rewrite to newer version
|
||||
fn build_causal_attention_mask(
|
||||
bsz: usize,
|
||||
seq_len: usize,
|
||||
|
@ -11,13 +11,13 @@ use candle_nn::{
|
||||
BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder,
|
||||
};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
|
||||
pub struct Config {
|
||||
exp_ratio: usize,
|
||||
in_channels: usize,
|
||||
blocks: [usize; 4],
|
||||
attn: bool,
|
||||
lkc_use_act: bool,
|
||||
pub exp_ratio: usize,
|
||||
pub in_channels: usize,
|
||||
pub blocks: [usize; 4],
|
||||
pub attn: bool,
|
||||
pub lkc_use_act: bool,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
@ -495,7 +495,6 @@ fn fastvit_model(cfg: &Config, nclasses: Option<usize>, vb: VarBuilder) -> Resul
|
||||
.apply(&stage3)?
|
||||
.apply(&stage4)?
|
||||
.apply(&final_conv)?;
|
||||
|
||||
match &cls {
|
||||
None => Ok(xs),
|
||||
Some(cls) => xs.mean(D::Minus2)?.mean(D::Minus1)?.apply(cls),
|
||||
|
@ -1,3 +1,20 @@
|
||||
use candle::{Result, Tensor};
|
||||
|
||||
pub trait WithForward {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn forward(
|
||||
&self,
|
||||
img: &Tensor,
|
||||
img_ids: &Tensor,
|
||||
txt: &Tensor,
|
||||
txt_ids: &Tensor,
|
||||
timesteps: &Tensor,
|
||||
y: &Tensor,
|
||||
guidance: Option<&Tensor>,
|
||||
) -> Result<Tensor>;
|
||||
}
|
||||
|
||||
pub mod autoencoder;
|
||||
pub mod model;
|
||||
pub mod quantized_model;
|
||||
pub mod sampling;
|
||||
|
@ -109,14 +109,14 @@ fn apply_rope(x: &Tensor, freq_cis: &Tensor) -> Result<Tensor> {
|
||||
(fr0.broadcast_mul(&x0)? + fr1.broadcast_mul(&x1)?)?.reshape(dims.to_vec())
|
||||
}
|
||||
|
||||
fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
|
||||
pub(crate) fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
|
||||
let q = apply_rope(q, pe)?.contiguous()?;
|
||||
let k = apply_rope(k, pe)?.contiguous()?;
|
||||
let x = scaled_dot_product_attention(&q, &k, v)?;
|
||||
x.transpose(1, 2)?.flatten_from(2)
|
||||
}
|
||||
|
||||
fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
|
||||
pub(crate) fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
|
||||
const TIME_FACTOR: f64 = 1000.;
|
||||
const MAX_PERIOD: f64 = 10000.;
|
||||
if dim % 2 == 1 {
|
||||
@ -144,7 +144,7 @@ pub struct EmbedNd {
|
||||
}
|
||||
|
||||
impl EmbedNd {
|
||||
fn new(dim: usize, theta: usize, axes_dim: Vec<usize>) -> Self {
|
||||
pub fn new(dim: usize, theta: usize, axes_dim: Vec<usize>) -> Self {
|
||||
Self {
|
||||
dim,
|
||||
theta,
|
||||
@ -575,9 +575,11 @@ impl Flux {
|
||||
final_layer,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl super::WithForward for Flux {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn forward(
|
||||
fn forward(
|
||||
&self,
|
||||
img: &Tensor,
|
||||
img_ids: &Tensor,
|
||||
|
465
candle-transformers/src/models/flux/quantized_model.rs
Normal file
465
candle-transformers/src/models/flux/quantized_model.rs
Normal file
@ -0,0 +1,465 @@
|
||||
use super::model::{attention, timestep_embedding, Config, EmbedNd};
|
||||
use crate::quantized_nn::{linear, linear_b, Linear};
|
||||
use crate::quantized_var_builder::VarBuilder;
|
||||
use candle::{DType, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{LayerNorm, RmsNorm};
|
||||
|
||||
fn layer_norm(dim: usize, vb: VarBuilder) -> Result<LayerNorm> {
|
||||
let ws = Tensor::ones(dim, DType::F32, vb.device())?;
|
||||
Ok(LayerNorm::new_no_bias(ws, 1e-6))
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MlpEmbedder {
|
||||
in_layer: Linear,
|
||||
out_layer: Linear,
|
||||
}
|
||||
|
||||
impl MlpEmbedder {
|
||||
fn new(in_sz: usize, h_sz: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let in_layer = linear(in_sz, h_sz, vb.pp("in_layer"))?;
|
||||
let out_layer = linear(h_sz, h_sz, vb.pp("out_layer"))?;
|
||||
Ok(Self {
|
||||
in_layer,
|
||||
out_layer,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl candle::Module for MlpEmbedder {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
xs.apply(&self.in_layer)?.silu()?.apply(&self.out_layer)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct QkNorm {
|
||||
query_norm: RmsNorm,
|
||||
key_norm: RmsNorm,
|
||||
}
|
||||
|
||||
impl QkNorm {
|
||||
fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let query_norm = vb.get(dim, "query_norm.scale")?.dequantize(vb.device())?;
|
||||
let query_norm = RmsNorm::new(query_norm, 1e-6);
|
||||
let key_norm = vb.get(dim, "key_norm.scale")?.dequantize(vb.device())?;
|
||||
let key_norm = RmsNorm::new(key_norm, 1e-6);
|
||||
Ok(Self {
|
||||
query_norm,
|
||||
key_norm,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct ModulationOut {
|
||||
shift: Tensor,
|
||||
scale: Tensor,
|
||||
gate: Tensor,
|
||||
}
|
||||
|
||||
impl ModulationOut {
|
||||
fn scale_shift(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
xs.broadcast_mul(&(&self.scale + 1.)?)?
|
||||
.broadcast_add(&self.shift)
|
||||
}
|
||||
|
||||
fn gate(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
self.gate.broadcast_mul(xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Modulation1 {
|
||||
lin: Linear,
|
||||
}
|
||||
|
||||
impl Modulation1 {
|
||||
fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let lin = linear(dim, 3 * dim, vb.pp("lin"))?;
|
||||
Ok(Self { lin })
|
||||
}
|
||||
|
||||
fn forward(&self, vec_: &Tensor) -> Result<ModulationOut> {
|
||||
let ys = vec_
|
||||
.silu()?
|
||||
.apply(&self.lin)?
|
||||
.unsqueeze(1)?
|
||||
.chunk(3, D::Minus1)?;
|
||||
if ys.len() != 3 {
|
||||
candle::bail!("unexpected len from chunk {ys:?}")
|
||||
}
|
||||
Ok(ModulationOut {
|
||||
shift: ys[0].clone(),
|
||||
scale: ys[1].clone(),
|
||||
gate: ys[2].clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Modulation2 {
|
||||
lin: Linear,
|
||||
}
|
||||
|
||||
impl Modulation2 {
|
||||
fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let lin = linear(dim, 6 * dim, vb.pp("lin"))?;
|
||||
Ok(Self { lin })
|
||||
}
|
||||
|
||||
fn forward(&self, vec_: &Tensor) -> Result<(ModulationOut, ModulationOut)> {
|
||||
let ys = vec_
|
||||
.silu()?
|
||||
.apply(&self.lin)?
|
||||
.unsqueeze(1)?
|
||||
.chunk(6, D::Minus1)?;
|
||||
if ys.len() != 6 {
|
||||
candle::bail!("unexpected len from chunk {ys:?}")
|
||||
}
|
||||
let mod1 = ModulationOut {
|
||||
shift: ys[0].clone(),
|
||||
scale: ys[1].clone(),
|
||||
gate: ys[2].clone(),
|
||||
};
|
||||
let mod2 = ModulationOut {
|
||||
shift: ys[3].clone(),
|
||||
scale: ys[4].clone(),
|
||||
gate: ys[5].clone(),
|
||||
};
|
||||
Ok((mod1, mod2))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SelfAttention {
|
||||
qkv: Linear,
|
||||
norm: QkNorm,
|
||||
proj: Linear,
|
||||
num_heads: usize,
|
||||
}
|
||||
|
||||
impl SelfAttention {
|
||||
fn new(dim: usize, num_heads: usize, qkv_bias: bool, vb: VarBuilder) -> Result<Self> {
|
||||
let head_dim = dim / num_heads;
|
||||
let qkv = linear_b(dim, dim * 3, qkv_bias, vb.pp("qkv"))?;
|
||||
let norm = QkNorm::new(head_dim, vb.pp("norm"))?;
|
||||
let proj = linear(dim, dim, vb.pp("proj"))?;
|
||||
Ok(Self {
|
||||
qkv,
|
||||
norm,
|
||||
proj,
|
||||
num_heads,
|
||||
})
|
||||
}
|
||||
|
||||
fn qkv(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
|
||||
let qkv = xs.apply(&self.qkv)?;
|
||||
let (b, l, _khd) = qkv.dims3()?;
|
||||
let qkv = qkv.reshape((b, l, 3, self.num_heads, ()))?;
|
||||
let q = qkv.i((.., .., 0))?.transpose(1, 2)?;
|
||||
let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
|
||||
let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
|
||||
let q = q.apply(&self.norm.query_norm)?;
|
||||
let k = k.apply(&self.norm.key_norm)?;
|
||||
Ok((q, k, v))
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
fn forward(&self, xs: &Tensor, pe: &Tensor) -> Result<Tensor> {
|
||||
let (q, k, v) = self.qkv(xs)?;
|
||||
attention(&q, &k, &v, pe)?.apply(&self.proj)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Mlp {
|
||||
lin1: Linear,
|
||||
lin2: Linear,
|
||||
}
|
||||
|
||||
impl Mlp {
|
||||
fn new(in_sz: usize, mlp_sz: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let lin1 = linear(in_sz, mlp_sz, vb.pp("0"))?;
|
||||
let lin2 = linear(mlp_sz, in_sz, vb.pp("2"))?;
|
||||
Ok(Self { lin1, lin2 })
|
||||
}
|
||||
}
|
||||
|
||||
impl candle::Module for Mlp {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
xs.apply(&self.lin1)?.gelu()?.apply(&self.lin2)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DoubleStreamBlock {
|
||||
img_mod: Modulation2,
|
||||
img_norm1: LayerNorm,
|
||||
img_attn: SelfAttention,
|
||||
img_norm2: LayerNorm,
|
||||
img_mlp: Mlp,
|
||||
txt_mod: Modulation2,
|
||||
txt_norm1: LayerNorm,
|
||||
txt_attn: SelfAttention,
|
||||
txt_norm2: LayerNorm,
|
||||
txt_mlp: Mlp,
|
||||
}
|
||||
|
||||
impl DoubleStreamBlock {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let h_sz = cfg.hidden_size;
|
||||
let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize;
|
||||
let img_mod = Modulation2::new(h_sz, vb.pp("img_mod"))?;
|
||||
let img_norm1 = layer_norm(h_sz, vb.pp("img_norm1"))?;
|
||||
let img_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("img_attn"))?;
|
||||
let img_norm2 = layer_norm(h_sz, vb.pp("img_norm2"))?;
|
||||
let img_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("img_mlp"))?;
|
||||
let txt_mod = Modulation2::new(h_sz, vb.pp("txt_mod"))?;
|
||||
let txt_norm1 = layer_norm(h_sz, vb.pp("txt_norm1"))?;
|
||||
let txt_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("txt_attn"))?;
|
||||
let txt_norm2 = layer_norm(h_sz, vb.pp("txt_norm2"))?;
|
||||
let txt_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("txt_mlp"))?;
|
||||
Ok(Self {
|
||||
img_mod,
|
||||
img_norm1,
|
||||
img_attn,
|
||||
img_norm2,
|
||||
img_mlp,
|
||||
txt_mod,
|
||||
txt_norm1,
|
||||
txt_attn,
|
||||
txt_norm2,
|
||||
txt_mlp,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
img: &Tensor,
|
||||
txt: &Tensor,
|
||||
vec_: &Tensor,
|
||||
pe: &Tensor,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let (img_mod1, img_mod2) = self.img_mod.forward(vec_)?; // shift, scale, gate
|
||||
let (txt_mod1, txt_mod2) = self.txt_mod.forward(vec_)?; // shift, scale, gate
|
||||
let img_modulated = img.apply(&self.img_norm1)?;
|
||||
let img_modulated = img_mod1.scale_shift(&img_modulated)?;
|
||||
let (img_q, img_k, img_v) = self.img_attn.qkv(&img_modulated)?;
|
||||
|
||||
let txt_modulated = txt.apply(&self.txt_norm1)?;
|
||||
let txt_modulated = txt_mod1.scale_shift(&txt_modulated)?;
|
||||
let (txt_q, txt_k, txt_v) = self.txt_attn.qkv(&txt_modulated)?;
|
||||
|
||||
let q = Tensor::cat(&[txt_q, img_q], 2)?;
|
||||
let k = Tensor::cat(&[txt_k, img_k], 2)?;
|
||||
let v = Tensor::cat(&[txt_v, img_v], 2)?;
|
||||
|
||||
let attn = attention(&q, &k, &v, pe)?;
|
||||
let txt_attn = attn.narrow(1, 0, txt.dim(1)?)?;
|
||||
let img_attn = attn.narrow(1, txt.dim(1)?, attn.dim(1)? - txt.dim(1)?)?;
|
||||
|
||||
let img = (img + img_mod1.gate(&img_attn.apply(&self.img_attn.proj)?))?;
|
||||
let img = (&img
|
||||
+ img_mod2.gate(
|
||||
&img_mod2
|
||||
.scale_shift(&img.apply(&self.img_norm2)?)?
|
||||
.apply(&self.img_mlp)?,
|
||||
)?)?;
|
||||
|
||||
let txt = (txt + txt_mod1.gate(&txt_attn.apply(&self.txt_attn.proj)?))?;
|
||||
let txt = (&txt
|
||||
+ txt_mod2.gate(
|
||||
&txt_mod2
|
||||
.scale_shift(&txt.apply(&self.txt_norm2)?)?
|
||||
.apply(&self.txt_mlp)?,
|
||||
)?)?;
|
||||
|
||||
Ok((img, txt))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SingleStreamBlock {
|
||||
linear1: Linear,
|
||||
linear2: Linear,
|
||||
norm: QkNorm,
|
||||
pre_norm: LayerNorm,
|
||||
modulation: Modulation1,
|
||||
h_sz: usize,
|
||||
mlp_sz: usize,
|
||||
num_heads: usize,
|
||||
}
|
||||
|
||||
impl SingleStreamBlock {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let h_sz = cfg.hidden_size;
|
||||
let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize;
|
||||
let head_dim = h_sz / cfg.num_heads;
|
||||
let linear1 = linear(h_sz, h_sz * 3 + mlp_sz, vb.pp("linear1"))?;
|
||||
let linear2 = linear(h_sz + mlp_sz, h_sz, vb.pp("linear2"))?;
|
||||
let norm = QkNorm::new(head_dim, vb.pp("norm"))?;
|
||||
let pre_norm = layer_norm(h_sz, vb.pp("pre_norm"))?;
|
||||
let modulation = Modulation1::new(h_sz, vb.pp("modulation"))?;
|
||||
Ok(Self {
|
||||
linear1,
|
||||
linear2,
|
||||
norm,
|
||||
pre_norm,
|
||||
modulation,
|
||||
h_sz,
|
||||
mlp_sz,
|
||||
num_heads: cfg.num_heads,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor, vec_: &Tensor, pe: &Tensor) -> Result<Tensor> {
|
||||
let mod_ = self.modulation.forward(vec_)?;
|
||||
let x_mod = mod_.scale_shift(&xs.apply(&self.pre_norm)?)?;
|
||||
let x_mod = x_mod.apply(&self.linear1)?;
|
||||
let qkv = x_mod.narrow(D::Minus1, 0, 3 * self.h_sz)?;
|
||||
let (b, l, _khd) = qkv.dims3()?;
|
||||
let qkv = qkv.reshape((b, l, 3, self.num_heads, ()))?;
|
||||
let q = qkv.i((.., .., 0))?.transpose(1, 2)?;
|
||||
let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
|
||||
let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
|
||||
let mlp = x_mod.narrow(D::Minus1, 3 * self.h_sz, self.mlp_sz)?;
|
||||
let q = q.apply(&self.norm.query_norm)?;
|
||||
let k = k.apply(&self.norm.key_norm)?;
|
||||
let attn = attention(&q, &k, &v, pe)?;
|
||||
let output = Tensor::cat(&[attn, mlp.gelu()?], 2)?.apply(&self.linear2)?;
|
||||
xs + mod_.gate(&output)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LastLayer {
|
||||
norm_final: LayerNorm,
|
||||
linear: Linear,
|
||||
ada_ln_modulation: Linear,
|
||||
}
|
||||
|
||||
impl LastLayer {
|
||||
fn new(h_sz: usize, p_sz: usize, out_c: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let norm_final = layer_norm(h_sz, vb.pp("norm_final"))?;
|
||||
let linear_ = linear(h_sz, p_sz * p_sz * out_c, vb.pp("linear"))?;
|
||||
let ada_ln_modulation = linear(h_sz, 2 * h_sz, vb.pp("adaLN_modulation.1"))?;
|
||||
Ok(Self {
|
||||
norm_final,
|
||||
linear: linear_,
|
||||
ada_ln_modulation,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor, vec: &Tensor) -> Result<Tensor> {
|
||||
let chunks = vec.silu()?.apply(&self.ada_ln_modulation)?.chunk(2, 1)?;
|
||||
let (shift, scale) = (&chunks[0], &chunks[1]);
|
||||
let xs = xs
|
||||
.apply(&self.norm_final)?
|
||||
.broadcast_mul(&(scale.unsqueeze(1)? + 1.0)?)?
|
||||
.broadcast_add(&shift.unsqueeze(1)?)?;
|
||||
xs.apply(&self.linear)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Flux {
|
||||
img_in: Linear,
|
||||
txt_in: Linear,
|
||||
time_in: MlpEmbedder,
|
||||
vector_in: MlpEmbedder,
|
||||
guidance_in: Option<MlpEmbedder>,
|
||||
pe_embedder: EmbedNd,
|
||||
double_blocks: Vec<DoubleStreamBlock>,
|
||||
single_blocks: Vec<SingleStreamBlock>,
|
||||
final_layer: LastLayer,
|
||||
}
|
||||
|
||||
impl Flux {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let img_in = linear(cfg.in_channels, cfg.hidden_size, vb.pp("img_in"))?;
|
||||
let txt_in = linear(cfg.context_in_dim, cfg.hidden_size, vb.pp("txt_in"))?;
|
||||
let mut double_blocks = Vec::with_capacity(cfg.depth);
|
||||
let vb_d = vb.pp("double_blocks");
|
||||
for idx in 0..cfg.depth {
|
||||
let db = DoubleStreamBlock::new(cfg, vb_d.pp(idx))?;
|
||||
double_blocks.push(db)
|
||||
}
|
||||
let mut single_blocks = Vec::with_capacity(cfg.depth_single_blocks);
|
||||
let vb_s = vb.pp("single_blocks");
|
||||
for idx in 0..cfg.depth_single_blocks {
|
||||
let sb = SingleStreamBlock::new(cfg, vb_s.pp(idx))?;
|
||||
single_blocks.push(sb)
|
||||
}
|
||||
let time_in = MlpEmbedder::new(256, cfg.hidden_size, vb.pp("time_in"))?;
|
||||
let vector_in = MlpEmbedder::new(cfg.vec_in_dim, cfg.hidden_size, vb.pp("vector_in"))?;
|
||||
let guidance_in = if cfg.guidance_embed {
|
||||
let mlp = MlpEmbedder::new(256, cfg.hidden_size, vb.pp("guidance_in"))?;
|
||||
Some(mlp)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let final_layer =
|
||||
LastLayer::new(cfg.hidden_size, 1, cfg.in_channels, vb.pp("final_layer"))?;
|
||||
let pe_dim = cfg.hidden_size / cfg.num_heads;
|
||||
let pe_embedder = EmbedNd::new(pe_dim, cfg.theta, cfg.axes_dim.to_vec());
|
||||
Ok(Self {
|
||||
img_in,
|
||||
txt_in,
|
||||
time_in,
|
||||
vector_in,
|
||||
guidance_in,
|
||||
pe_embedder,
|
||||
double_blocks,
|
||||
single_blocks,
|
||||
final_layer,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl super::WithForward for Flux {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn forward(
|
||||
&self,
|
||||
img: &Tensor,
|
||||
img_ids: &Tensor,
|
||||
txt: &Tensor,
|
||||
txt_ids: &Tensor,
|
||||
timesteps: &Tensor,
|
||||
y: &Tensor,
|
||||
guidance: Option<&Tensor>,
|
||||
) -> Result<Tensor> {
|
||||
if txt.rank() != 3 {
|
||||
candle::bail!("unexpected shape for txt {:?}", txt.shape())
|
||||
}
|
||||
if img.rank() != 3 {
|
||||
candle::bail!("unexpected shape for img {:?}", img.shape())
|
||||
}
|
||||
let dtype = img.dtype();
|
||||
let pe = {
|
||||
let ids = Tensor::cat(&[txt_ids, img_ids], 1)?;
|
||||
ids.apply(&self.pe_embedder)?
|
||||
};
|
||||
let mut txt = txt.apply(&self.txt_in)?;
|
||||
let mut img = img.apply(&self.img_in)?;
|
||||
let vec_ = timestep_embedding(timesteps, 256, dtype)?.apply(&self.time_in)?;
|
||||
let vec_ = match (self.guidance_in.as_ref(), guidance) {
|
||||
(Some(g_in), Some(guidance)) => {
|
||||
(vec_ + timestep_embedding(guidance, 256, dtype)?.apply(g_in))?
|
||||
}
|
||||
_ => vec_,
|
||||
};
|
||||
let vec_ = (vec_ + y.apply(&self.vector_in))?;
|
||||
|
||||
// Double blocks
|
||||
for block in self.double_blocks.iter() {
|
||||
(img, txt) = block.forward(&img, &txt, &vec_, &pe)?
|
||||
}
|
||||
// Single blocks
|
||||
let mut img = Tensor::cat(&[&txt, &img], 1)?;
|
||||
for block in self.single_blocks.iter() {
|
||||
img = block.forward(&img, &vec_, &pe)?;
|
||||
}
|
||||
let img = img.i((.., txt.dim(1)?..))?;
|
||||
self.final_layer.forward(&img, &vec_)
|
||||
}
|
||||
}
|
@ -92,8 +92,8 @@ pub fn unpack(xs: &Tensor, height: usize, width: usize) -> Result<Tensor> {
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn denoise(
|
||||
model: &super::model::Flux,
|
||||
pub fn denoise<M: super::WithForward>(
|
||||
model: &M,
|
||||
img: &Tensor,
|
||||
img_ids: &Tensor,
|
||||
txt: &Tensor,
|
||||
|
@ -44,6 +44,7 @@ pub struct LlamaConfig {
|
||||
pub eos_token_id: Option<LlamaEosToks>,
|
||||
pub rope_scaling: Option<Llama3RopeConfig>,
|
||||
pub max_position_embeddings: usize,
|
||||
pub tie_word_embeddings: Option<bool>,
|
||||
}
|
||||
|
||||
impl LlamaConfig {
|
||||
@ -72,6 +73,7 @@ impl LlamaConfig {
|
||||
eos_token_id: self.eos_token_id,
|
||||
rope_scaling: self.rope_scaling,
|
||||
max_position_embeddings: self.max_position_embeddings,
|
||||
tie_word_embeddings: self.tie_word_embeddings.unwrap_or(false),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -91,6 +93,7 @@ pub struct Config {
|
||||
pub eos_token_id: Option<LlamaEosToks>,
|
||||
pub rope_scaling: Option<Llama3RopeConfig>,
|
||||
pub max_position_embeddings: usize,
|
||||
pub tie_word_embeddings: bool,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
@ -109,6 +112,7 @@ impl Config {
|
||||
eos_token_id: None,
|
||||
rope_scaling: None,
|
||||
max_position_embeddings: DEFAULT_MAX_SEQ_LEN,
|
||||
tie_word_embeddings: false,
|
||||
}
|
||||
}
|
||||
|
||||
@ -127,6 +131,7 @@ impl Config {
|
||||
eos_token_id: None,
|
||||
rope_scaling: None,
|
||||
max_position_embeddings: DEFAULT_MAX_SEQ_LEN,
|
||||
tie_word_embeddings: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -504,7 +509,11 @@ impl Llama {
|
||||
|
||||
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
|
||||
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
|
||||
let lm_head = if cfg.tie_word_embeddings {
|
||||
Linear::from_weights(wte.embeddings().clone(), None)
|
||||
} else {
|
||||
linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
|
||||
};
|
||||
let ln_f = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
|
||||
let blocks: Vec<_> = (0..cfg.num_hidden_layers)
|
||||
.map(|i| Block::load(vb.pp(format!("model.layers.{i}")), cfg).unwrap())
|
||||
|
@ -43,6 +43,7 @@ pub struct LLaVAConfig {
|
||||
pub image_token_index: isize,
|
||||
#[serde(default = "default_hf")]
|
||||
pub hf: bool,
|
||||
pub tie_word_embeddings: Option<bool>,
|
||||
}
|
||||
|
||||
fn default_hf() -> bool {
|
||||
@ -77,6 +78,7 @@ impl LLaVAConfig {
|
||||
use_flash_attn: false,
|
||||
rope_scaling: None, // Assume we don't have LLaVA for Llama 3.1
|
||||
max_position_embeddings: self.max_position_embeddings,
|
||||
tie_word_embeddings: self.tie_word_embeddings.unwrap_or(false),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -264,6 +266,7 @@ impl HFLLaVAConfig {
|
||||
use_cache: self.text_config.use_cache,
|
||||
vocab_size: self.vocab_size,
|
||||
image_token_index: self.image_token_index,
|
||||
tie_word_embeddings: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -22,7 +22,6 @@ impl MobileClipConfig {
|
||||
pub fn s1() -> Self {
|
||||
let text_config = text_model::Config::vit_base_patch32();
|
||||
let vision_config = fastvit::Config::mci1();
|
||||
|
||||
Self {
|
||||
text_config,
|
||||
vision_config,
|
||||
@ -32,7 +31,6 @@ impl MobileClipConfig {
|
||||
pub fn s2() -> Self {
|
||||
let text_config = text_model::Config::vit_base_patch32();
|
||||
let vision_config = fastvit::Config::mci2();
|
||||
|
||||
Self {
|
||||
text_config,
|
||||
vision_config,
|
||||
@ -45,12 +43,10 @@ impl MobileClipModel {
|
||||
pub fn new(vs: VarBuilder, c: &MobileClipConfig) -> Result<Self> {
|
||||
let vision_model = fastvit::fastvit(&c.vision_config, 512, vs.pp("visual.trunk"))?;
|
||||
let text_model = text_model::OpenClipTextTransformer::new(vs.pp("text"), &c.text_config)?;
|
||||
|
||||
let text_projection = vs.get(
|
||||
(c.text_config.embed_dim, c.text_config.projection_dim),
|
||||
"text.text_projection",
|
||||
)?;
|
||||
|
||||
let logit_scale = vs.get(&[], "logit_scale")?;
|
||||
Ok(Self {
|
||||
text_model,
|
||||
|
@ -76,6 +76,7 @@ pub mod rwkv_v5;
|
||||
pub mod rwkv_v6;
|
||||
pub mod segformer;
|
||||
pub mod segment_anything;
|
||||
pub mod siglip;
|
||||
pub mod stable_diffusion;
|
||||
pub mod stable_lm;
|
||||
pub mod starcoder2;
|
||||
|
608
candle-transformers/src/models/siglip.rs
Normal file
608
candle-transformers/src/models/siglip.rs
Normal file
@ -0,0 +1,608 @@
|
||||
use crate::models::clip::div_l2_norm;
|
||||
use candle::{IndexOp, Module, Result, Tensor, D};
|
||||
use candle_nn::{layer_norm, linear, LayerNorm, Linear, VarBuilder};
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L27
|
||||
#[derive(serde::Deserialize, Clone, Debug)]
|
||||
pub struct TextConfig {
|
||||
pub vocab_size: usize,
|
||||
pub hidden_size: usize,
|
||||
pub intermediate_size: usize,
|
||||
pub num_hidden_layers: usize,
|
||||
pub num_attention_heads: usize,
|
||||
pub max_position_embeddings: usize,
|
||||
pub hidden_act: candle_nn::Activation,
|
||||
pub layer_norm_eps: f64,
|
||||
pub pad_token_id: u32,
|
||||
pub bos_token_id: u32,
|
||||
pub eos_token_id: u32,
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L132
|
||||
#[derive(serde::Deserialize, Clone, Debug)]
|
||||
pub struct VisionConfig {
|
||||
pub hidden_size: usize,
|
||||
pub intermediate_size: usize,
|
||||
pub num_hidden_layers: usize,
|
||||
pub num_attention_heads: usize,
|
||||
pub num_channels: usize,
|
||||
pub image_size: usize,
|
||||
pub patch_size: usize,
|
||||
pub hidden_act: candle_nn::Activation,
|
||||
pub layer_norm_eps: f64,
|
||||
}
|
||||
|
||||
trait TransformerConfig {
|
||||
fn hidden_size(&self) -> usize;
|
||||
fn intermediate_size(&self) -> usize;
|
||||
fn num_attention_heads(&self) -> usize;
|
||||
fn num_hidden_layers(&self) -> usize;
|
||||
fn layer_norm_eps(&self) -> f64;
|
||||
fn hidden_act(&self) -> candle_nn::Activation;
|
||||
}
|
||||
|
||||
impl TransformerConfig for TextConfig {
|
||||
fn hidden_size(&self) -> usize {
|
||||
self.hidden_size
|
||||
}
|
||||
fn intermediate_size(&self) -> usize {
|
||||
self.intermediate_size
|
||||
}
|
||||
fn num_attention_heads(&self) -> usize {
|
||||
self.num_attention_heads
|
||||
}
|
||||
fn num_hidden_layers(&self) -> usize {
|
||||
self.num_hidden_layers
|
||||
}
|
||||
fn layer_norm_eps(&self) -> f64 {
|
||||
self.layer_norm_eps
|
||||
}
|
||||
fn hidden_act(&self) -> candle_nn::Activation {
|
||||
self.hidden_act
|
||||
}
|
||||
}
|
||||
|
||||
impl TransformerConfig for VisionConfig {
|
||||
fn hidden_size(&self) -> usize {
|
||||
self.hidden_size
|
||||
}
|
||||
fn intermediate_size(&self) -> usize {
|
||||
self.intermediate_size
|
||||
}
|
||||
fn num_attention_heads(&self) -> usize {
|
||||
self.num_attention_heads
|
||||
}
|
||||
fn num_hidden_layers(&self) -> usize {
|
||||
self.num_hidden_layers
|
||||
}
|
||||
fn layer_norm_eps(&self) -> f64 {
|
||||
self.layer_norm_eps
|
||||
}
|
||||
fn hidden_act(&self) -> candle_nn::Activation {
|
||||
self.hidden_act
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L228
|
||||
#[derive(serde::Deserialize, Clone, Debug)]
|
||||
pub struct Config {
|
||||
pub text_config: TextConfig,
|
||||
pub vision_config: VisionConfig,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn base_patch16_224() -> Self {
|
||||
let text_config = TextConfig {
|
||||
// https://huggingface.co/google/siglip-base-patch16-224/blob/main/config.json
|
||||
hidden_size: 768,
|
||||
intermediate_size: 3072,
|
||||
num_attention_heads: 12,
|
||||
vocab_size: 32000,
|
||||
// Default values.
|
||||
pad_token_id: 1,
|
||||
bos_token_id: 49406,
|
||||
eos_token_id: 49407,
|
||||
layer_norm_eps: 1e-6,
|
||||
hidden_act: candle_nn::Activation::GeluPytorchTanh,
|
||||
max_position_embeddings: 64,
|
||||
num_hidden_layers: 12,
|
||||
};
|
||||
let vision_config = VisionConfig {
|
||||
patch_size: 16,
|
||||
// Default values.
|
||||
hidden_size: 768,
|
||||
intermediate_size: 3072,
|
||||
num_hidden_layers: 12,
|
||||
num_attention_heads: 12,
|
||||
num_channels: 3,
|
||||
image_size: 224,
|
||||
hidden_act: candle_nn::Activation::GeluPytorchTanh,
|
||||
layer_norm_eps: 1e-6,
|
||||
};
|
||||
Self {
|
||||
text_config,
|
||||
vision_config,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct MultiheadAttention {
|
||||
q_proj: Linear,
|
||||
k_proj: Linear,
|
||||
v_proj: Linear,
|
||||
out_proj: Linear,
|
||||
num_heads: usize,
|
||||
}
|
||||
|
||||
impl MultiheadAttention {
|
||||
fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
|
||||
let h = cfg.hidden_size;
|
||||
let num_heads = cfg.num_attention_heads;
|
||||
let w_in_proj = vb.get((3 * h, h), "in_proj_weight")?.chunk(3, 0)?;
|
||||
let b_in_proj = vb.get(3 * h, "in_proj_bias")?.chunk(3, 0)?;
|
||||
let q_proj = Linear::new(w_in_proj[0].clone(), Some(b_in_proj[0].clone()));
|
||||
let k_proj = Linear::new(w_in_proj[1].clone(), Some(b_in_proj[1].clone()));
|
||||
let v_proj = Linear::new(w_in_proj[2].clone(), Some(b_in_proj[2].clone()));
|
||||
let out_proj = linear(h, h, vb.pp("out_proj"))?;
|
||||
Ok(Self {
|
||||
q_proj,
|
||||
k_proj,
|
||||
v_proj,
|
||||
out_proj,
|
||||
num_heads,
|
||||
})
|
||||
}
|
||||
|
||||
fn separate_heads(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let (b, n, c) = x.dims3()?;
|
||||
x.reshape((b, n, self.num_heads, c / self.num_heads))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()
|
||||
}
|
||||
|
||||
fn recombine_heads(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let (b, n_heads, n_tokens, c_per_head) = x.dims4()?;
|
||||
x.transpose(1, 2)?
|
||||
.reshape((b, n_tokens, n_heads * c_per_head))
|
||||
}
|
||||
|
||||
fn forward(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
|
||||
let q = self.q_proj.forward(&q.contiguous()?)?;
|
||||
let k = self.k_proj.forward(&k.contiguous()?)?;
|
||||
let v = self.v_proj.forward(&v.contiguous()?)?;
|
||||
|
||||
let q = self.separate_heads(&q)?;
|
||||
let k = self.separate_heads(&k)?;
|
||||
let v = self.separate_heads(&v)?;
|
||||
|
||||
let (_, _, _, c_per_head) = q.dims4()?;
|
||||
let attn = (q.matmul(&k.t()?)? / (c_per_head as f64).sqrt())?;
|
||||
let attn = candle_nn::ops::softmax_last_dim(&attn)?;
|
||||
|
||||
let out = attn.matmul(&v)?;
|
||||
self.recombine_heads(&out)?.apply(&self.out_proj)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct MultiheadAttentionPoolingHead {
|
||||
probe: Tensor,
|
||||
attention: MultiheadAttention,
|
||||
layernorm: LayerNorm,
|
||||
mlp: Mlp,
|
||||
}
|
||||
|
||||
impl MultiheadAttentionPoolingHead {
|
||||
fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
|
||||
let mlp = Mlp::new(cfg, vb.pp("mlp"))?;
|
||||
let layernorm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("layernorm"))?;
|
||||
let probe = vb.get((1, 1, cfg.hidden_size), "probe")?;
|
||||
let attention = MultiheadAttention::new(cfg, vb.pp("attention"))?;
|
||||
Ok(Self {
|
||||
probe,
|
||||
attention,
|
||||
layernorm,
|
||||
mlp,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for MultiheadAttentionPoolingHead {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let batch_size = xs.dim(0)?;
|
||||
let probe = self.probe.repeat((batch_size, 1, 1))?;
|
||||
let xs = self.attention.forward(&probe, xs, xs)?;
|
||||
let residual = &xs;
|
||||
let xs = xs.apply(&self.layernorm)?.apply(&self.mlp)?;
|
||||
(xs + residual)?.i((.., 0))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Attention {
|
||||
q_proj: Linear,
|
||||
k_proj: Linear,
|
||||
v_proj: Linear,
|
||||
out_proj: Linear,
|
||||
num_heads: usize,
|
||||
head_dim: usize,
|
||||
scale: f64,
|
||||
}
|
||||
|
||||
impl Attention {
|
||||
fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
|
||||
let embed_dim = cfg.hidden_size();
|
||||
let q_proj = linear(embed_dim, embed_dim, vb.pp("q_proj"))?;
|
||||
let k_proj = linear(embed_dim, embed_dim, vb.pp("k_proj"))?;
|
||||
let v_proj = linear(embed_dim, embed_dim, vb.pp("v_proj"))?;
|
||||
let out_proj = linear(embed_dim, embed_dim, vb.pp("out_proj"))?;
|
||||
let num_heads = cfg.num_attention_heads();
|
||||
let head_dim = embed_dim / num_heads;
|
||||
Ok(Self {
|
||||
q_proj,
|
||||
k_proj,
|
||||
v_proj,
|
||||
out_proj,
|
||||
num_heads,
|
||||
head_dim,
|
||||
scale: (head_dim as f64).powf(-0.5),
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
|
||||
let (batch_size, q_len, _) = xs.dims3()?;
|
||||
let query_states = xs.apply(&self.q_proj)?;
|
||||
let key_states = xs.apply(&self.k_proj)?;
|
||||
let value_states = xs.apply(&self.v_proj)?;
|
||||
|
||||
let shape = (batch_size, q_len, self.num_heads, self.head_dim);
|
||||
let query_states = query_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
|
||||
let key_states = key_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
|
||||
let value_states = value_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
|
||||
|
||||
let attn_weights = (query_states.matmul(&key_states.t()?)? * self.scale)?;
|
||||
let attn_weights = match attention_mask {
|
||||
None => attn_weights,
|
||||
Some(mask) => attn_weights.broadcast_add(mask)?,
|
||||
};
|
||||
// The original implementation upcasts to f32 but candle_nn::ops::softmax should handle this properly.
|
||||
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
||||
let attn_outputs = attn_weights
|
||||
.matmul(&value_states)?
|
||||
.transpose(1, 2)?
|
||||
.reshape((batch_size, q_len, ()))?
|
||||
.apply(&self.out_proj)?;
|
||||
Ok(attn_outputs)
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/modeling_siglip.py#L599
|
||||
#[derive(Debug, Clone)]
|
||||
struct Mlp {
|
||||
fc1: Linear,
|
||||
fc2: Linear,
|
||||
activation_fn: candle_nn::Activation,
|
||||
}
|
||||
|
||||
impl Mlp {
|
||||
fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
|
||||
let hidden_size = cfg.hidden_size();
|
||||
let intermediate_size = cfg.intermediate_size();
|
||||
let fc1 = candle_nn::linear(hidden_size, intermediate_size, vb.pp("fc1"))?;
|
||||
let fc2 = candle_nn::linear(intermediate_size, hidden_size, vb.pp("fc2"))?;
|
||||
Ok(Self {
|
||||
fc1,
|
||||
fc2,
|
||||
activation_fn: cfg.hidden_act(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Mlp {
|
||||
fn forward(&self, xs: &candle::Tensor) -> Result<candle::Tensor> {
|
||||
xs.apply(&self.fc1)?
|
||||
.apply(&self.activation_fn)?
|
||||
.apply(&self.fc2)
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/modeling_siglip.py#L614
|
||||
#[derive(Debug, Clone)]
|
||||
struct EncoderLayer {
|
||||
self_attn: Attention,
|
||||
layer_norm1: LayerNorm,
|
||||
mlp: Mlp,
|
||||
layer_norm2: LayerNorm,
|
||||
}
|
||||
|
||||
impl EncoderLayer {
|
||||
fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
|
||||
let hidden_size = cfg.hidden_size();
|
||||
let layer_norm_eps = cfg.layer_norm_eps();
|
||||
let self_attn = Attention::new(cfg, vb.pp("self_attn"))?;
|
||||
let layer_norm1 = layer_norm(hidden_size, layer_norm_eps, vb.pp("layer_norm1"))?;
|
||||
let mlp = Mlp::new(cfg, vb.pp("mlp"))?;
|
||||
let layer_norm2 = layer_norm(hidden_size, layer_norm_eps, vb.pp("layer_norm2"))?;
|
||||
Ok(Self {
|
||||
self_attn,
|
||||
layer_norm1,
|
||||
mlp,
|
||||
layer_norm2,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
|
||||
let residual = xs;
|
||||
let xs = xs.apply(&self.layer_norm1)?;
|
||||
let xs = self.self_attn.forward(&xs, attention_mask)?;
|
||||
let xs = (residual + xs)?;
|
||||
let residual = &xs;
|
||||
let xs = xs.apply(&self.layer_norm2)?.apply(&self.mlp)?;
|
||||
let xs = (xs + residual)?;
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Encoder {
|
||||
layers: Vec<EncoderLayer>,
|
||||
}
|
||||
|
||||
impl Encoder {
|
||||
fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
|
||||
let mut layers = vec![];
|
||||
let vb = vb.pp("layers");
|
||||
for layer_idx in 0..cfg.num_hidden_layers() {
|
||||
let layer = EncoderLayer::new(cfg, vb.pp(layer_idx))?;
|
||||
layers.push(layer)
|
||||
}
|
||||
Ok(Self { layers })
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
|
||||
let mut xs = xs.clone();
|
||||
for layer in self.layers.iter() {
|
||||
xs = layer.forward(&xs, attention_mask)?
|
||||
}
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct VisionEmbeddings {
|
||||
patch_embedding: candle_nn::Conv2d,
|
||||
position_embedding: candle_nn::Embedding,
|
||||
position_ids: Tensor,
|
||||
}
|
||||
|
||||
impl VisionEmbeddings {
|
||||
fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
|
||||
let conv2d_cfg = candle_nn::Conv2dConfig {
|
||||
stride: cfg.patch_size,
|
||||
..Default::default()
|
||||
};
|
||||
let patch_embedding = candle_nn::conv2d(
|
||||
cfg.num_channels,
|
||||
cfg.hidden_size,
|
||||
cfg.patch_size,
|
||||
conv2d_cfg,
|
||||
vb.pp("patch_embedding"),
|
||||
)?;
|
||||
let num_patches = (cfg.image_size / cfg.patch_size).pow(2);
|
||||
let position_ids = Tensor::arange(0, num_patches as i64, vb.device())?;
|
||||
let position_embedding =
|
||||
candle_nn::embedding(num_patches, cfg.hidden_size(), vb.pp("position_embedding"))?;
|
||||
Ok(Self {
|
||||
patch_embedding,
|
||||
position_embedding,
|
||||
position_ids,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for VisionEmbeddings {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (_batch, _channels, _height, _width) = xs.dims4()?;
|
||||
let embeddings = xs.apply(&self.patch_embedding)?;
|
||||
let embeddings = embeddings.flatten_from(2)?.transpose(1, 2)?;
|
||||
let position_embedding = self.position_embedding.forward(&self.position_ids)?;
|
||||
embeddings.broadcast_add(&position_embedding)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct VisionTransformer {
|
||||
embeddings: VisionEmbeddings,
|
||||
encoder: Encoder,
|
||||
post_layernorm: LayerNorm,
|
||||
head: Option<MultiheadAttentionPoolingHead>,
|
||||
}
|
||||
|
||||
impl VisionTransformer {
|
||||
fn new(cfg: &VisionConfig, use_head: bool, vb: VarBuilder) -> Result<Self> {
|
||||
let embeddings = VisionEmbeddings::new(cfg, vb.pp("embeddings"))?;
|
||||
let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
|
||||
let post_layernorm =
|
||||
layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("post_layernorm"))?;
|
||||
let head = if use_head {
|
||||
Some(MultiheadAttentionPoolingHead::new(cfg, vb.pp("head"))?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(Self {
|
||||
embeddings,
|
||||
encoder,
|
||||
post_layernorm,
|
||||
head,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for VisionTransformer {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = xs.apply(&self.embeddings)?;
|
||||
let xs = self.encoder.forward(&xs, None)?;
|
||||
let xs = xs.apply(&self.post_layernorm)?;
|
||||
match self.head.as_ref() {
|
||||
None => Ok(xs),
|
||||
Some(h) => xs.apply(h),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct VisionModel {
|
||||
vision_model: VisionTransformer,
|
||||
}
|
||||
|
||||
impl VisionModel {
|
||||
pub fn new(cfg: &VisionConfig, use_head: bool, vb: VarBuilder) -> Result<Self> {
|
||||
let vision_model = VisionTransformer::new(cfg, use_head, vb)?;
|
||||
Ok(Self { vision_model })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for VisionModel {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
xs.apply(&self.vision_model)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct TextEmbeddings {
|
||||
token_embedding: candle_nn::Embedding,
|
||||
position_embedding: candle_nn::Embedding,
|
||||
position_ids: Tensor,
|
||||
}
|
||||
|
||||
impl TextEmbeddings {
|
||||
fn new(cfg: &TextConfig, vb: VarBuilder) -> Result<Self> {
|
||||
let token_embedding =
|
||||
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("token_embedding"))?;
|
||||
let position_embedding = candle_nn::embedding(
|
||||
cfg.max_position_embeddings,
|
||||
cfg.hidden_size,
|
||||
vb.pp("position_embedding"),
|
||||
)?;
|
||||
let position_ids =
|
||||
Tensor::arange(0u32, cfg.max_position_embeddings as u32, vb.device())?.unsqueeze(0)?;
|
||||
Ok(Self {
|
||||
token_embedding,
|
||||
position_embedding,
|
||||
position_ids,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for TextEmbeddings {
|
||||
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
|
||||
let seq_length = input_ids.dim(D::Minus1)?;
|
||||
let inputs_embeds = self.token_embedding.forward(input_ids)?;
|
||||
let position_ids = self.position_ids.narrow(1, 0, seq_length)?;
|
||||
let position_embedding = self.position_embedding.forward(&position_ids)?;
|
||||
inputs_embeds.broadcast_add(&position_embedding)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TextTransformer {
|
||||
embeddings: TextEmbeddings,
|
||||
encoder: Encoder,
|
||||
final_layer_norm: LayerNorm,
|
||||
pub head: Linear,
|
||||
}
|
||||
|
||||
impl TextTransformer {
|
||||
fn new(cfg: &TextConfig, vb: VarBuilder) -> Result<Self> {
|
||||
let embeddings = TextEmbeddings::new(cfg, vb.pp("embeddings"))?;
|
||||
let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
|
||||
let final_layer_norm = layer_norm(
|
||||
cfg.hidden_size,
|
||||
cfg.layer_norm_eps,
|
||||
vb.pp("final_layer_norm"),
|
||||
)?;
|
||||
let head = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("head"))?;
|
||||
Ok(Self {
|
||||
embeddings,
|
||||
encoder,
|
||||
final_layer_norm,
|
||||
head,
|
||||
})
|
||||
}
|
||||
}
|
||||
impl Module for TextTransformer {
|
||||
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
|
||||
let (_bsz, seq_len) = input_ids.dims2()?;
|
||||
let input_ids = self.embeddings.forward(input_ids)?;
|
||||
let input_ids = self.encoder.forward(&input_ids, None)?;
|
||||
let last_hidden_state = self.final_layer_norm.forward(&input_ids)?;
|
||||
last_hidden_state
|
||||
.i((.., seq_len - 1, ..))?
|
||||
.contiguous()?
|
||||
.apply(&self.head)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TextModel {
|
||||
pub text_model: TextTransformer,
|
||||
}
|
||||
|
||||
impl TextModel {
|
||||
pub fn new(cfg: &TextConfig, vb: VarBuilder) -> Result<Self> {
|
||||
let text_model = TextTransformer::new(cfg, vb)?;
|
||||
Ok(Self { text_model })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for TextModel {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
xs.apply(&self.text_model)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Model {
|
||||
text_model: TextModel,
|
||||
vision_model: VisionModel,
|
||||
logit_bias: Tensor,
|
||||
logit_scale: Tensor,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let text_model = TextModel::new(&cfg.text_config, vb.pp("text_model"))?;
|
||||
let vision_model = VisionModel::new(&cfg.vision_config, true, vb.pp("vision_model"))?;
|
||||
let logit_scale = vb.get(&[1], "logit_scale")?;
|
||||
let logit_bias = vb.get(&[1], "logit_bias")?;
|
||||
Ok(Self {
|
||||
text_model,
|
||||
vision_model,
|
||||
logit_bias,
|
||||
logit_scale,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_text_features(&self, input_ids: &Tensor) -> Result<Tensor> {
|
||||
input_ids.apply(&self.text_model)
|
||||
}
|
||||
|
||||
pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
|
||||
pixel_values.apply(&self.vision_model)
|
||||
}
|
||||
|
||||
pub fn forward(&self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<(Tensor, Tensor)> {
|
||||
let image_features = self.get_image_features(pixel_values)?;
|
||||
let text_features = self.get_text_features(input_ids)?;
|
||||
let image_features_normalized = div_l2_norm(&image_features)?;
|
||||
let text_features_normalized = div_l2_norm(&text_features)?;
|
||||
let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?;
|
||||
let logit_scale = self.logit_scale.exp()?;
|
||||
let logits_per_text = logits_per_text
|
||||
.broadcast_mul(&logit_scale)?
|
||||
.broadcast_add(&self.logit_bias)?;
|
||||
let logits_per_image = logits_per_text.t()?;
|
||||
Ok((logits_per_text, logits_per_image))
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user