Compare commits


1 Commit

Author SHA1 Message Date
f7980abbcd Improve the sampling methods. 2024-05-04 10:53:30 +02:00
277 changed files with 1804 additions and 28718 deletions

View File

@ -18,9 +18,9 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest] # For now, only test on Linux
steps:
steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@v2
- name: Install Rust
uses: actions-rs/toolchain@v1
@ -65,4 +65,4 @@ jobs:
working-directory: ./candle-pyo3
run: |
source .env/bin/activate
python -m pytest -s -v tests
python -m pytest -s -v tests

View File

@ -1,6 +1,6 @@
on:
on:
push:
branches:
branches:
- main
pull_request:
@ -15,7 +15,7 @@ jobs:
os: [ubuntu-latest, windows-latest, macOS-latest]
rust: [stable]
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal
@ -34,7 +34,7 @@ jobs:
os: [ubuntu-latest, windows-latest, macOS-latest]
rust: [stable]
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal
@ -49,7 +49,7 @@ jobs:
name: Rustfmt
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal
@ -65,7 +65,7 @@ jobs:
name: Clippy
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal

View File

@ -1,15 +0,0 @@
on:
push:
name: Secret Leaks
jobs:
trufflehog:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Secret Scanning
uses: trufflesecurity/trufflehog@main

.gitignore
View File

@ -9,10 +9,6 @@ target/
# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
Cargo.lock
# editor config
.helix
.vscode
# These are backup files generated by rustfmt
**/*.rs.bk
@ -40,9 +36,3 @@ candle-wasm-examples/*/package-lock.json
candle-wasm-examples/**/config*.json
.DS_Store
.idea/*
__pycache__
out.safetensors
out.wav
bria.mp3
bria.safetensors
bria.wav

View File

@ -20,7 +20,7 @@ exclude = [
resolver = "2"
[workspace.package]
version = "0.7.1"
version = "0.5.1"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@ -33,23 +33,22 @@ ab_glyph = "0.2.23"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.7.1" }
candle-datasets = { path = "./candle-datasets", version = "0.7.1" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.7.1" }
candle-kernels = { path = "./candle-kernels", version = "0.7.1" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.7.1" }
candle-nn = { path = "./candle-nn", version = "0.7.1" }
candle-onnx = { path = "./candle-onnx", version = "0.7.1" }
candle-transformers = { path = "./candle-transformers", version = "0.7.1" }
candle = { path = "./candle-core", package = "candle-core", version = "0.5.1" }
candle-datasets = { path = "./candle-datasets", version = "0.5.1" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.5.1" }
candle-kernels = { path = "./candle-kernels", version = "0.5.1" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.5.1" }
candle-nn = { path = "./candle-nn", version = "0.5.1" }
candle-onnx = { path = "./candle-onnx", version = "0.5.1" }
candle-transformers = { path = "./candle-transformers", version = "0.5.1" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.12.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
cudarc = { version = "0.10.0", features = ["f16"] }
fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = "0.3.0"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
hound = "3.5.1"
image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }
image = { version = "0.25.0", default-features = false, features = ["jpeg", "png"] }
imageproc = { version = "0.24.0", default-features = false }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
libc = { version = "0.2.147" }
@ -70,6 +69,7 @@ tokenizers = { version = "0.19.1", default-features = false }
tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"
wav = "1.0.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "1.1.1", default-features = false }
metal = { version = "0.27.0", features = ["mps"]}

View File

@ -63,9 +63,7 @@ We also provide a some command line based examples using state of the art models
- [LLaMA v1, v2, and v3](./candle-examples/examples/llama/): general LLM, includes
the SOLAR-10.7B variant.
- [Falcon](./candle-examples/examples/falcon/): general LLM.
- [Codegeex4](./candle-examples/examples/codegeex4-9b/): Code completion, code interpreter, web search, function calling, repository-level
- [GLM4](./candle-examples/examples/glm4/): Open Multilingual Multimodal Chat LMs by THUDM
- [Gemma v1 and v2](./candle-examples/examples/gemma/): 2b and 7b+/9b general LLMs from Google Deepmind.
- [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google Deepmind.
- [RecurrentGemma](./candle-examples/examples/recurrent-gemma/): 2b and 7b
Griffin based models from Google that mix attention with a RNN like state.
- [Phi-1, Phi-1.5, Phi-2, and Phi-3](./candle-examples/examples/phi/): 1.3b,
@ -120,8 +118,6 @@ We also provide a some command line based examples using state of the art models
model using residual vector quantization.
- [MetaVoice](./candle-examples/examples/metavoice/): foundational model for
text-to-speech.
- [Parler-TTS](./candle-examples/examples/parler-tts/): large text-to-speech
model.
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/),
[JinaBert](./candle-examples/examples/jina-bert/) : useful for sentence embeddings.
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
@ -210,7 +206,7 @@ If you have an addition to this list, please submit a pull request.
- StarCoder, StarCoder2.
- Phi 1, 1.5, 2, and 3.
- Mamba, Minimal Mamba
- Gemma v1 2b and 7b+, v2 2b and 9b.
- Gemma 2b and 7b.
- Mistral 7b v0.1.
- Mixtral 8x7b v0.1.
- StableLM-3B-4E1T, StableLM-2-1.6B, Stable-Code-3B.
@ -238,10 +234,9 @@ If you have an addition to this list, please submit a pull request.
- Whisper, multi-lingual speech-to-text.
- EnCodec, audio compression model.
- MetaVoice-1B, text-to-speech model.
- Parler-TTS, text-to-speech model.
- Computer Vision Models.
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT,
ConvNeXTv2, MobileOne, EfficientVit (MSRA), MobileNetv4, Hiera, FastViT.
ConvNeXTv2, MobileOne, EfficientVit (MSRA).
- yolo-v3, yolo-v8.
- Segment-Anything Model (SAM).
- SegFormer.
@ -413,10 +408,3 @@ This may be caused by the models being loaded from `/mnt/c`, more details on
You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle
error is generated.
#### CudaRC error
If you encounter an error like ``called `Result::unwrap()` on an `Err` value: LoadLibraryExW { source: Os { code: 126, kind: Uncategorized, message: "The specified module could not be found." } }`` on Windows, copy and rename these 3 files to fix it (make sure they are in your path). The paths depend on your CUDA version.
`c:\Windows\System32\nvcuda.dll` -> `cuda.dll`
`c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\cublas64_12.dll` -> `cublas.dll`
`c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\curand64_10.dll` -> `curand.dll`

View File

@ -37,6 +37,7 @@ tokenizers = { workspace = true, features = ["onig"] }
tracing = { workspace = true }
tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
wav = { workspace = true }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
parquet = { workspace = true }
image = { workspace = true }

View File

@ -106,8 +106,8 @@ let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).un
}
}
#[allow(unused)]
#[rustfmt::skip]
#[test]
fn book_training_1() -> Result<()>{
// ANCHOR: book_training_1
use hf_hub::{api::sync::Api, Repo, RepoType};

View File

@ -48,7 +48,3 @@ metal = ["dep:metal", "dep:candle-metal-kernels"]
[[bench]]
name = "bench_main"
harness = false
[[example]]
name = "metal_basics"
required-features = ["metal"]

View File

@ -12,7 +12,7 @@ fn run_affine_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name:
let m = 1024;
let k = 1024;
let tensor = Tensor::zeros((b, m, k), dtype, device).unwrap();
let tensor = Tensor::zeros((b, m, k), dtype, &device).unwrap();
let flops = b * m * k * dtype.size_in_bytes();

View File

@ -7,7 +7,7 @@ use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn run(matmul: &QMatMul, x: &Tensor) {
matmul.forward(x).unwrap();
matmul.forward(&x).unwrap();
}
fn run_bench(c: &mut Criterion, device: &Device, dtype: GgmlDType) {
@ -50,7 +50,7 @@ fn run_bench(c: &mut Criterion, device: &Device, dtype: GgmlDType) {
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
for dtype in [
for dtype in vec![
GgmlDType::F32,
GgmlDType::F16,
GgmlDType::Q4_0,

View File

@ -12,7 +12,7 @@ fn run_unary_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &
let m = 1024;
let k = 1024;
let tensor = Tensor::arange(0.0f32, (b * m * k) as f32, device)
let tensor = Tensor::arange(0.0f32, (b * m * k) as f32, &device)
.unwrap()
.to_dtype(dtype)
.unwrap()

View File

@ -25,9 +25,9 @@ const SIZE: usize = B * M * K;
const DATA: [u8; SIZE] = create_cond_arr::<SIZE>();
fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), device).unwrap();
let on_true = Tensor::ones((B, M, K), dtype, device).unwrap();
let on_false = Tensor::zeros((B, M, K), dtype, device).unwrap();
let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), &device).unwrap();
let on_true = Tensor::ones((B, M, K), dtype, &device).unwrap();
let on_false = Tensor::zeros((B, M, K), dtype, &device).unwrap();
let elements = B * M * K;
// E.g. 2 f32 tensors + 1 u8 tensor

View File

@ -5,29 +5,32 @@ extern crate accelerate_src;
extern crate intel_mkl_src;
use anyhow::Result;
use candle_core::{Device, Tensor};
use candle_core::{Device, Module, Tensor};
use candle_core::quantized::{QMatMul, QTensor};
fn main() -> Result<()> {
let device = Device::new_cuda(0)?;
let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?
.to_dtype(candle_core::DType::BF16)?;
candle_core::cuda::set_gemm_reduced_precision_f32(false);
candle_core::cuda::set_gemm_reduced_precision_bf16(false);
let _x1 = x.matmul(&x)?;
drop(_x1);
let start_time = std::time::Instant::now();
let _x1 = x.matmul(&x)?;
device.synchronize()?;
println!("fp32: {:?}", start_time.elapsed());
drop(_x1);
candle_core::cuda::set_gemm_reduced_precision_f32(true);
candle_core::cuda::set_gemm_reduced_precision_bf16(true);
let _x1 = x.matmul(&x)?;
drop(_x1);
let start_time = std::time::Instant::now();
let _x1 = x.matmul(&x)?;
device.synchronize()?;
println!("tf32: {:?}", start_time.elapsed());
drop(_x1);
let q = Tensor::randn(0f32, 1.0, (72, 256), &device)?;
let q_cpu = q.to_device(&Device::Cpu)?;
let q = QTensor::quantize(&q, candle_core::quantized::GgmlDType::Q8K)?;
let q = QMatMul::from_qtensor(q)?;
let x = Tensor::randn(0f32, 1.0, (5, 256), &device)?;
let res_q_cuda = q.forward(&x)?;
println!("{res_q_cuda}");
let q_cpu = QTensor::quantize(&q_cpu, candle_core::quantized::GgmlDType::Q8K)?;
let q_cpu_tensor = q_cpu.dequantize(&Device::Cpu)?;
let q_cpu = QMatMul::from_qtensor(q_cpu)?;
let x_cpu = x.to_device(&Device::Cpu)?;
let res_q_cpu = q_cpu.forward(&x_cpu)?;
println!("{res_q_cpu}");
let res_mm = x_cpu.matmul(&q_cpu_tensor.t()?)?;
let diff = (res_mm - res_q_cuda.to_device(&Device::Cpu))?
.abs()?
.flatten_all()?
.max(0)?;
println!("{diff}");
Ok(())
}

View File

@ -1,28 +0,0 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use anyhow::Result;
use candle_core::{Device, Tensor};
fn main() -> Result<()> {
// This requires the code to be run with MTL_CAPTURE_ENABLED=1
let device = Device::new_metal(0)?;
let metal_device = match &device {
Device::Metal(m) => m,
_ => anyhow::bail!("unexpected device"),
};
metal_device.capture("/tmp/candle.gputrace")?;
// This first synchronize ensures that a new command buffer gets created after setting up the
// capture scope.
device.synchronize()?;
let x = Tensor::randn(0f32, 1.0, (128, 128), &device)?;
let x1 = x.add(&x)?;
println!("{x1:?}");
// This second synchronize ensures that the command buffer gets committed before the end of the
// capture scope.
device.synchronize()?;
Ok(())
}

View File

@ -320,13 +320,13 @@ impl Tensor {
dilation,
output_padding: _output_padding,
} => {
let grad_arg = grad.conv2d(kernel, *padding, *stride, *dilation, 1)?;
let grad_arg = grad.conv2d(kernel, *padding, *dilation, *stride, 1)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
let grad_kernel = grad
.transpose(0, 1)?
.conv2d(&arg.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
.conv2d(&arg.transpose(0, 1)?, *padding, *stride, *dilation, 1)?
.transpose(0, 1)?;
let sum_grad = grads.or_insert(kernel)?;
let (_, _, k0, k1) = kernel.dims4()?;
@ -623,9 +623,9 @@ impl Tensor {
}
Op::Unary(arg, UnaryOp::Silu) => {
let sum_grad = grads.or_insert(arg)?;
// d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x))) = sigmoid(x) * (1 - node) + node
// d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
let sigmoid_arg = (arg.neg()?.exp()? + 1.)?.recip()?;
let silu_grad = &sigmoid_arg * (1. - *node) + *node;
let silu_grad = (&sigmoid_arg * (1. + (arg * (1. - &sigmoid_arg)?)?)?)?;
*sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
}
Op::Elu(arg, alpha) => {
@ -634,8 +634,7 @@ impl Tensor {
let zeros = arg.zeros_like()?;
let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
// node == alpha * (e^x - 1) for x <= 0, reuse it
let negative_exp_mask = (negative_mask * (*node + *alpha))?;
let negative_exp_mask = ((negative_mask * arg.exp())? * *alpha)?;
let combined_mask = (positive_mask + negative_exp_mask)?;
*sum_grad = sum_grad.add(&(grad * combined_mask)?)?
}
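
For reference, the two SiLU gradient expressions in the hunk above are the same function written two ways (with node = silu(x) = x * sigmoid(x)), and the ELU shortcut simply reuses node = alpha * (e^x - 1) on the negative branch; a quick check:

```latex
\begin{aligned}
\frac{d}{dx}\,\mathrm{silu}(x) &= \sigma(x) + x\,\sigma(x)\bigl(1-\sigma(x)\bigr)
  = \sigma(x)\bigl(1 + x\,(1-\sigma(x))\bigr) \\
 &= \sigma(x)\bigl(1-\mathrm{silu}(x)\bigr) + \mathrm{silu}(x)
  \qquad \text{since } \mathrm{silu}(x) = x\,\sigma(x), \\
\frac{d}{dx}\,\mathrm{elu}(x) &= \alpha e^{x}
  = \underbrace{\alpha\,(e^{x}-1)}_{\text{node}} + \alpha
  \qquad \text{for } x \le 0 .
\end{aligned}
```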
@ -756,9 +755,4 @@ impl GradStore {
};
Ok(grad)
}
/// Get the tensor ids of the stored gradient tensors
pub fn get_ids(&self) -> impl Iterator<Item = &TensorId> {
self.0.keys()
}
}

View File

@ -10,7 +10,7 @@ pub use utils::{
};
const USE_IM2COL_CONV1D: bool = true;
const USE_COL2IM_CONV1D_TR: bool = true;
const USE_IM2COL_CONV1D_TR: bool = true;
const USE_IM2COL_CONV2D: bool = true;
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
@ -121,8 +121,7 @@ impl ReduceIndex {
let dst_len = src_l.shape().elem_count() / reduce_dim_size;
let mut dst: Vec<U> = Vec::with_capacity(dst_len);
let dst_to_set = dst.spare_capacity_mut();
let dst_to_set =
unsafe { std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(dst_to_set) };
let dst_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(dst_to_set) };
match src_l.contiguous_offsets() {
Some((o1, o2)) => {
let src = &src[o1..o2];
@ -2250,7 +2249,7 @@ impl BackendStorage for CpuStorage {
&& params.dilation == 1
&& params.padding == 0
&& params.output_padding == 0;
if USE_COL2IM_CONV1D_TR && can_use_col2im {
if USE_IM2COL_CONV1D_TR && can_use_col2im {
let (b_size, c_in, l_in) = l.shape().dims3()?;
let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
if !kernel_l.is_contiguous() {

View File

@ -174,9 +174,7 @@ pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
let mut ys: Vec<T> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe {
std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
};
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(el_count) };
@ -187,9 +185,7 @@ pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [
let rhs = &rhs[ob.start..ob.start + ob.len];
let mut ys: Vec<T> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe {
std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
};
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
let mut dst_i = 0;
for src_i in (o_l1..o_l2).step_by(ob.len) {
f_vec(
@ -228,9 +224,7 @@ pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [
let lhs = &lhs[ob.start..ob.start + ob.len];
let mut ys: Vec<T> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe {
std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
};
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
let mut dst_i = 0;
for src_i in (o_r1..o_r2).step_by(ob.len) {
f_vec(
@ -317,9 +311,7 @@ pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U
crate::StridedBlocks::SingleBlock { start_offset, len } => {
let mut ys: Vec<U> = Vec::with_capacity(len);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe {
std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(ys_to_set)
};
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
f_vec(&vs[start_offset..start_offset + len], ys_to_set);
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(len) };
@ -341,9 +333,7 @@ pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U
} else {
let mut ys: Vec<U> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe {
std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(ys_to_set)
};
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
let mut dst_index = 0;
for src_index in block_start_index {
let vs = &vs[src_index..src_index + block_len];

View File

@ -1,6 +1,6 @@
use crate::WithDType;
use cudarc;
use cudarc::cudnn::safe::{ConvForward, Cudnn};
use cudarc::cudnn::safe::{Conv2dForward, Cudnn};
use cudarc::driver::{CudaSlice, CudaView, DeviceRepr, ValidAsZeroBits};
use std::cell::RefCell;
use std::collections::HashMap;
@ -87,7 +87,7 @@ pub(crate) fn launch_conv2d<
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
[params.b_size as i32, params.c_out as i32, h_out, w_out],
)?;
let conv2d = ConvForward {
let conv2d = Conv2dForward {
conv: &conv,
x: &x,
w: &w,

View File

@ -16,7 +16,7 @@ mod error;
mod utils;
pub use device::{CudaDevice, DeviceId};
pub use error::{CudaError, WrapErr};
pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, Map3, S};
pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, S};
pub enum SlicePtrOrNull<T> {
Ptr(CudaSlice<T>),
@ -174,7 +174,6 @@ impl Map1 for Im2Col1D {
}
}
#[allow(unused)]
struct Im2Col {
h_k: usize,
w_k: usize,
@ -184,7 +183,6 @@ struct Im2Col {
}
impl Im2Col {
#[allow(unused)]
fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
@ -632,31 +630,6 @@ impl<'a> Map2 for Conv2D<'a> {
}
}
struct Col2Im1D {
stride: usize,
}
impl Map1 for Col2Im1D {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
col: &CudaSlice<T>,
dev: &CudaDevice,
l: &Layout,
) -> Result<CudaSlice<T>> {
let (b_size, l_in, c_out, k_size) = l.shape().dims4()?;
let stride = self.stride;
let l_out = (l_in - 1) * stride + k_size;
let dst_el = b_size * c_out * l_out;
let mut im = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let params = (dst_el, l_out, l_in, c_out, k_size, stride, col, &mut im);
let func = dev.get_or_load_func(&kernel_name::<T>("col2im1d"), kernels::CONV)?;
unsafe { func.launch(cfg, params) }.w()?;
Ok(im)
}
}
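
Restating the shape bookkeeping used in the hunks above (nothing new, just the formulas the kernels rely on): the col2im1d path only runs when dilation = 1, padding = 0 and output_padding = 0, and the im2col helper uses the usual convolution output size.

```latex
l_{\text{out}} = (l_{\text{in}} - 1)\cdot \text{stride} + k_{\text{size}},
\qquad
h_{\text{out}} = \left\lfloor \frac{h + 2\,\text{padding} - \text{dilation}\,(h_k - 1) - 1}{\text{stride}} \right\rfloor + 1 .
```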
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
impl<'a> Map2 for ConvTranspose1D<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@ -1393,55 +1366,9 @@ impl BackendStorage for CudaStorage {
kernel_l: &Layout,
params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
const USE_COL2IM_CONV1D_TR: bool = true;
let device = self.device().clone();
let can_use_col2im = kernel_l.is_contiguous()
&& params.dilation == 1
&& params.padding == 0
&& params.output_padding == 0;
let slice = if USE_COL2IM_CONV1D_TR && can_use_col2im {
let (b_size, c_in, l_in) = l.shape().dims3()?;
let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
if !kernel_l.is_contiguous() {
crate::bail!(
"convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}"
)
}
if c_in != c_in2 {
crate::bail!(
"convtr1d: shape mismatch on c_in {:?} {:?}",
l.shape(),
kernel_l.shape()
)
}
let col = {
// This merges the last two dimensions of the kernel together.
let kernel_l_mm = Layout::new(
(b_size, c_in, k_size * c_out).into(),
vec![0, k_size * c_out, 1],
kernel_l.start_offset(),
);
self.matmul(
kernel,
(
b_size,
/* m */ l_in,
/* n */ c_out * k_size,
/* k */ c_in,
),
&l.transpose(1, 2)?,
&kernel_l_mm,
)?
};
let col_l = Layout::contiguous((b_size, l_in, c_out, k_size));
Col2Im1D {
stride: params.stride,
}
.map(&col.slice, &device, &col_l)?
} else {
ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?
};
let slice =
ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
Ok(Self { slice, device })
}
@ -1688,8 +1615,12 @@ impl BackendStorage for CudaStorage {
let rhs = &rhs.slice(rhs_l.start_offset()..);
let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
let mut out = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
unsafe { gemm_strided_batched_f32(&self.device.blas, cfg, rhs, lhs, &mut out) }
.w()?;
unsafe {
self.device
.blas
.gemm_strided_batched(cfg, rhs, lhs, &mut out)
}
.w()?;
CudaStorageSlice::F32(out)
}
(CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
@ -1886,20 +1817,6 @@ static MM_F16_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
std::sync::atomic::AtomicBool::new(false);
static MM_BF16_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
std::sync::atomic::AtomicBool::new(false);
static MM_F32_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
std::sync::atomic::AtomicBool::new(false);
/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
/// allowed with f32 GEMMs.
pub fn gemm_reduced_precision_f32() -> bool {
MM_F32_REDUCED_PRECISION.load(std::sync::atomic::Ordering::Relaxed)
}
/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
/// allowed with f32 GEMMs.
pub fn set_gemm_reduced_precision_f32(b: bool) {
MM_F32_REDUCED_PRECISION.store(b, std::sync::atomic::Ordering::Relaxed)
}
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
/// allowed with f16 GEMMs.
@ -1925,51 +1842,6 @@ pub fn set_gemm_reduced_precision_bf16(b: bool) {
MM_BF16_REDUCED_PRECISION.store(b, std::sync::atomic::Ordering::Relaxed)
}
unsafe fn gemm_strided_batched_f32(
cublas: &cudarc::cublas::CudaBlas,
cfg: StridedBatchedConfig<f32>,
a: &cudarc::driver::CudaView<f32>,
b: &cudarc::driver::CudaView<f32>,
c: &mut CudaSlice<f32>,
) -> std::result::Result<(), cudarc::cublas::result::CublasError> {
use cudarc::cublas::sys;
use cudarc::driver::DevicePtrMut;
let compute_type = if gemm_reduced_precision_f32() {
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_TF32
} else {
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
};
let alpha = &cfg.gemm.alpha as *const f32 as *const _;
let beta = &cfg.gemm.beta as *const f32 as *const _;
cudarc::cublas::result::gemm_strided_batched_ex(
*cublas.handle(),
cfg.gemm.transa,
cfg.gemm.transb,
cfg.gemm.m,
cfg.gemm.n,
cfg.gemm.k,
alpha,
*a.device_ptr() as *const _,
sys::cudaDataType_t::CUDA_R_32F,
cfg.gemm.lda,
cfg.stride_a,
*b.device_ptr() as *const _,
sys::cudaDataType_t::CUDA_R_32F,
cfg.gemm.ldb,
cfg.stride_b,
beta,
*c.device_ptr_mut() as *mut _,
sys::cudaDataType_t::CUDA_R_32F,
cfg.gemm.ldc,
cfg.stride_c,
cfg.batch_size,
compute_type,
sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP,
)
}
unsafe fn gemm_strided_batched_f16(
cublas: &cudarc::cublas::CudaBlas,
cfg: StridedBatchedConfig<f16>,
@ -2037,13 +1909,15 @@ unsafe fn gemm_strided_batched_bf16(
let alpha_f32: f32 = cfg.gemm.alpha.to_f32();
let beta_f32: f32 = cfg.gemm.beta.to_f32();
let alpha = f16::from_f32(alpha_f32);
let beta = f16::from_f32(beta_f32);
// The type for alpha and beta depends on the computeType.
// https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmstridedbatchedex
let (compute_type, alpha, beta) = if gemm_reduced_precision_bf16() {
(
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_16BF,
(&alpha_f32) as *const f32 as *const _,
(&beta_f32) as *const f32 as *const _,
sys::cublasComputeType_t::CUBLAS_COMPUTE_16F,
(&alpha) as *const f16 as *const _,
(&beta) as *const f16 as *const _,
)
} else {
(

View File

@ -54,44 +54,6 @@ pub trait Map2 {
}
}
pub trait Map3 {
#[allow(clippy::too_many_arguments)]
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src1: &CudaSlice<T>,
layout1: &Layout,
src2: &CudaSlice<T>,
layout2: &Layout,
src3: &CudaSlice<T>,
layout3: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>>;
#[allow(clippy::too_many_arguments)]
fn map(
&self,
s1: &S,
l1: &Layout,
s2: &S,
l2: &Layout,
s3: &S,
l3: &Layout,
d: &CudaDevice,
) -> Result<S> {
let out = match (s1, s2, s3) {
(S::U8(s1), S::U8(s2), S::U8(s3)) => S::U8(self.f(s1, l1, s2, l2, s3, l3, d)?),
(S::U32(s1), S::U32(s2), S::U32(s3)) => S::U32(self.f(s1, l1, s2, l2, s3, l3, d)?),
(S::I64(s1), S::I64(s2), S::I64(s3)) => S::I64(self.f(s1, l1, s2, l2, s3, l3, d)?),
(S::BF16(s1), S::BF16(s2), S::BF16(s3)) => S::BF16(self.f(s1, l1, s2, l2, s3, l3, d)?),
(S::F16(s1), S::F16(s2), S::F16(s3)) => S::F16(self.f(s1, l1, s2, l2, s3, l3, d)?),
(S::F32(s1), S::F32(s2), S::F32(s3)) => S::F32(self.f(s1, l1, s2, l2, s3, l3, d)?),
(S::F64(s1), S::F64(s2), S::F64(s3)) => S::F64(self.f(s1, l1, s2, l2, s3, l3, d)?),
_ => Err(CudaError::InternalError("dtype mismatch in ternary op"))?,
};
Ok(out)
}
}
pub trait Map2InPlace {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,

View File

@ -171,22 +171,6 @@ impl Device {
matches!(self, Self::Metal(_))
}
pub fn supports_bf16(&self) -> bool {
match self {
Self::Cuda(_) | Self::Metal(_) => true,
Self::Cpu => false,
}
}
/// Return `BF16` for devices that support it, otherwise default to `F32`.
pub fn bf16_default_to_f32(&self) -> DType {
if self.supports_bf16() {
DType::BF16
} else {
DType::F32
}
}
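
A minimal usage sketch for the helpers above (the function name is hypothetical; it only relies on the `bf16_default_to_f32` method shown in this hunk and on `Tensor::zeros` as used elsewhere in this diff):

```rust
use candle_core::{Device, Result, Tensor};

// Hypothetical helper: allocate in BF16 on CUDA/Metal and fall back to F32 on CPU.
fn zeros_in_preferred_dtype(device: &Device) -> Result<Tensor> {
    let dtype = device.bf16_default_to_f32();
    Tensor::zeros((2, 2), dtype, device)
}
```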
pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
if crate::utils::cuda_is_available() {
Self::new_cuda(ordinal)

View File

@ -258,13 +258,3 @@ pub fn gemm_reduced_precision_bf16() -> bool {
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
/// allowed with bf16 GEMMs.
pub fn set_gemm_reduced_precision_bf16(_: bool) {}
/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
/// allowed with f32 GEMMs.
pub fn gemm_reduced_precision_f32() -> bool {
true
}
/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
/// allowed with f32 GEMMs.
pub fn set_gemm_reduced_precision_f32(_b: bool) {}

View File

@ -141,117 +141,28 @@ impl<T> IndexOp<T> for Tensor
where
T: Into<TensorIndexer>,
{
///```rust
/// use candle_core::{Tensor, DType, Device, IndexOp};
/// let a = Tensor::new(&[
/// [0., 1.],
/// [2., 3.],
/// [4., 5.]
/// ], &Device::Cpu)?;
///
/// let b = a.i(0)?;
/// assert_eq!(b.shape().dims(), &[2]);
/// assert_eq!(b.to_vec1::<f64>()?, &[0., 1.]);
///
/// let c = a.i(..2)?;
/// assert_eq!(c.shape().dims(), &[2, 2]);
/// assert_eq!(c.to_vec2::<f64>()?, &[
/// [0., 1.],
/// [2., 3.]
/// ]);
///
/// let d = a.i(1..)?;
/// assert_eq!(d.shape().dims(), &[2, 2]);
/// assert_eq!(d.to_vec2::<f64>()?, &[
/// [2., 3.],
/// [4., 5.]
/// ]);
/// # Ok::<(), candle_core::Error>(())
/// ```
fn i(&self, index: T) -> Result<Tensor, Error> {
self.index(&[index.into()])
}
}
impl<A> IndexOp<(A,)> for Tensor
where
A: Into<TensorIndexer>,
{
///```rust
/// use candle_core::{Tensor, DType, Device, IndexOp};
/// let a = Tensor::new(&[
/// [0f32, 1.],
/// [2. , 3.],
/// [4. , 5.]
/// ], &Device::Cpu)?;
///
/// let b = a.i((0,))?;
/// assert_eq!(b.shape().dims(), &[2]);
/// assert_eq!(b.to_vec1::<f32>()?, &[0., 1.]);
///
/// let c = a.i((..2,))?;
/// assert_eq!(c.shape().dims(), &[2, 2]);
/// assert_eq!(c.to_vec2::<f32>()?, &[
/// [0., 1.],
/// [2., 3.]
/// ]);
///
/// let d = a.i((1..,))?;
/// assert_eq!(d.shape().dims(), &[2, 2]);
/// assert_eq!(d.to_vec2::<f32>()?, &[
/// [2., 3.],
/// [4., 5.]
/// ]);
/// # Ok::<(), candle_core::Error>(())
/// ```
fn i(&self, (a,): (A,)) -> Result<Tensor, Error> {
self.index(&[a.into()])
}
}
#[allow(non_snake_case)]
impl<A, B> IndexOp<(A, B)> for Tensor
where
A: Into<TensorIndexer>,
B: Into<TensorIndexer>,
{
///```rust
/// use candle_core::{Tensor, DType, Device, IndexOp};
/// let a = Tensor::new(&[[0f32, 1., 2.], [3., 4., 5.], [6., 7., 8.]], &Device::Cpu)?;
///
/// let b = a.i((1, 0))?;
/// assert_eq!(b.to_vec0::<f32>()?, 3.);
///
/// let c = a.i((..2, 1))?;
/// assert_eq!(c.shape().dims(), &[2]);
/// assert_eq!(c.to_vec1::<f32>()?, &[1., 4.]);
///
/// let d = a.i((2.., ..))?;
/// assert_eq!(d.shape().dims(), &[1, 3]);
/// assert_eq!(d.to_vec2::<f32>()?, &[[6., 7., 8.]]);
/// # Ok::<(), candle_core::Error>(())
/// ```
fn i(&self, (a, b): (A, B)) -> Result<Tensor, Error> {
self.index(&[a.into(), b.into()])
}
}
macro_rules! index_op_tuple {
($doc:tt, $($t:ident),+) => {
($($t:ident),+) => {
#[allow(non_snake_case)]
impl<$($t),*> IndexOp<($($t,)*)> for Tensor
where
$($t: Into<TensorIndexer>,)*
{
#[doc=$doc]
fn i(&self, ($($t,)*): ($($t,)*)) -> Result<Tensor, Error> {
self.index(&[$($t.into(),)*])
}
}
};
}
index_op_tuple!("see [TensorIndex#method.i]", A, B, C);
index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D);
index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E);
index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E, F);
index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E, F, G);
index_op_tuple!(A);
index_op_tuple!(A, B);
index_op_tuple!(A, B, C);
index_op_tuple!(A, B, C, D);
index_op_tuple!(A, B, C, D, E);
index_op_tuple!(A, B, C, D, E, F);
index_op_tuple!(A, B, C, D, E, F, G);

View File

@ -65,7 +65,6 @@ pub mod scalar;
pub mod shape;
mod sort;
mod storage;
pub mod streaming;
mod strided_index;
mod tensor;
mod tensor_cat;
@ -81,11 +80,10 @@ pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, Inp
pub use device::{Device, DeviceLocation, NdArray};
pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
pub use error::{Error, Result};
pub use indexer::{IndexOp, TensorIndexer};
pub use indexer::IndexOp;
pub use layout::Layout;
pub use shape::{Shape, D};
pub use storage::Storage;
pub use streaming::{StreamTensor, StreamingBinOp, StreamingModule};
pub use strided_index::{StridedBlocks, StridedIndex};
pub use tensor::{Tensor, TensorId};
pub use variable::Var;

View File

@ -4,7 +4,7 @@ use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}
use std::collections::HashMap;
use std::ffi::c_void;
use std::path::Path;
use std::sync::{Arc, Mutex, RwLock};
use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard};
use super::MetalError;
@ -22,73 +22,7 @@ impl DeviceId {
}
type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>;
pub(crate) struct Commands {
/// Single command queue for the entire device.
command_queue: CommandQueue,
/// One command buffer at a time.
/// The scheduler works by allowing multiple
/// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
/// on a single command buffer. Using a single command buffer would be fastest on the GPU but
/// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed
/// to start to work).
/// Despite what the documentation says, command buffers are NOT ordered. They are ordered
/// by their start time, but there's no guarantee that command buffer 1 will finish before
/// command buffer 2 starts (or there are Metal bugs there)
command_buffer: CommandBuffer,
/// Keeps track of the current amount of compute command encoders on the current
/// command buffer
/// Arc, RwLock because of the interior mutability.
command_buffer_index: usize,
/// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
compute_per_buffer: usize,
}
impl Commands {
pub(crate) fn new(command_queue: CommandQueue) -> Result<Self> {
let command_buffer = command_queue.new_command_buffer().to_owned();
command_buffer.enqueue();
let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
Ok(val) => val.parse()?,
_ => 50,
};
Ok(Self {
command_queue,
command_buffer,
command_buffer_index: 0,
compute_per_buffer,
})
}
pub fn command_buffer(&mut self) -> Result<(bool, CommandBuffer)> {
let mut command_buffer = self.command_buffer.to_owned();
let mut flushed = false;
if self.command_buffer_index > self.compute_per_buffer {
self.command_buffer.commit();
command_buffer = self.command_queue.new_command_buffer().to_owned();
self.command_buffer = command_buffer.clone();
self.command_buffer_index = 0;
flushed = true;
}
self.command_buffer_index += 1;
Ok((flushed, command_buffer))
}
pub fn wait_until_completed(&mut self) -> Result<()> {
match self.command_buffer.status() {
metal::MTLCommandBufferStatus::Committed
| metal::MTLCommandBufferStatus::Scheduled
| metal::MTLCommandBufferStatus::Completed => {
panic!("Already committed");
}
_ => {}
}
self.command_buffer.commit();
self.command_buffer.wait_until_completed();
self.command_buffer = self.command_queue.new_command_buffer().to_owned();
Ok(())
}
}
type AllocatedBuffers = Arc<RwLock<BufferMap>>;
#[derive(Clone)]
pub struct MetalDevice {
@ -99,8 +33,27 @@ pub struct MetalDevice {
/// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
pub(crate) device: metal::Device,
pub(crate) commands: Arc<RwLock<Commands>>,
/// Single command queue for the entire device.
pub(crate) command_queue: CommandQueue,
/// One command buffer at a time.
/// The scheduler works by allowing multiple
/// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
/// on a single command buffer. Using a single command buffer would be fastest on the GPU but
/// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed
/// to start to work).
/// Despite what the documentation says, command buffers are NOT ordered. They are ordered
/// by their start time, but there's no guarantee that command buffer 1 will finish before
/// command buffer 2 starts (or there are Metal bugs there)
pub(crate) command_buffer: Arc<RwLock<CommandBuffer>>,
/// Keeps track of the current amount of compute command encoders on the current
/// command buffer
/// Arc, RwLock because of the interior mutability.
pub(crate) command_buffer_index: Arc<RwLock<usize>>,
/// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
pub(crate) compute_per_buffer: usize,
/// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
/// Heavily used by [`candle_metal_kernels`]
pub(crate) kernels: Arc<Kernels>,
/// Simple allocator struct.
/// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
/// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting
@ -114,15 +67,9 @@ pub struct MetalDevice {
///
/// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers
/// (strong_count = 1).
pub(crate) buffers: Arc<RwLock<BufferMap>>,
/// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
/// Heavily used by [`candle_metal_kernels`]
pub(crate) kernels: Arc<Kernels>,
pub(crate) buffers: AllocatedBuffers,
/// Seed for random number generation.
pub(crate) seed: Arc<Mutex<Buffer>>,
/// Whether to use the MLX matmul kernels instead of the MFA ones.
pub(crate) use_mlx_mm: bool,
}
impl std::fmt::Debug for MetalDevice {
@ -140,10 +87,6 @@ impl std::ops::Deref for MetalDevice {
}
impl MetalDevice {
pub fn set_use_mlx_mm(&mut self, use_mlx_mm: bool) {
self.use_mlx_mm = use_mlx_mm
}
pub fn id(&self) -> DeviceId {
self.id
}
@ -152,31 +95,44 @@ impl MetalDevice {
&self.device
}
fn drop_unused_buffers(&self) -> Result<()> {
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
for subbuffers in buffers.values_mut() {
let newbuffers = subbuffers
.iter()
.filter(|s| Arc::strong_count(*s) > 1)
.map(Arc::clone)
.collect();
*subbuffers = newbuffers;
}
Ok(())
pub fn command_queue(&self) -> &CommandQueue {
&self.command_queue
}
pub fn command_buffer(&self) -> Result<CommandBuffer> {
let mut commands = self.commands.write().map_err(MetalError::from)?;
let (flushed, command_buffer) = commands.command_buffer()?;
if flushed {
self.drop_unused_buffers()?
let mut command_buffer_lock = self.command_buffer.try_write().map_err(MetalError::from)?;
let mut command_buffer = command_buffer_lock.to_owned();
let mut index = self
.command_buffer_index
.try_write()
.map_err(MetalError::from)?;
if *index > self.compute_per_buffer {
command_buffer.commit();
command_buffer = self.command_queue.new_command_buffer().to_owned();
*command_buffer_lock = command_buffer.clone();
*index = 0;
self.drop_unused_buffers()?;
}
*index += 1;
Ok(command_buffer)
}
pub fn wait_until_completed(&self) -> Result<()> {
let mut commands = self.commands.write().map_err(MetalError::from)?;
commands.wait_until_completed()
let mut command_buffer = self.command_buffer.try_write().map_err(MetalError::from)?;
match command_buffer.status() {
metal::MTLCommandBufferStatus::Committed
| metal::MTLCommandBufferStatus::Scheduled
| metal::MTLCommandBufferStatus::Completed => {
panic!("Already committed");
}
_ => {}
}
command_buffer.commit();
command_buffer.wait_until_completed();
*command_buffer = self.command_queue.new_command_buffer().to_owned();
Ok(())
}
pub fn kernels(&self) -> &Kernels {
@ -223,8 +179,7 @@ impl MetalDevice {
size,
MTLResourceOptions::StorageModeManaged,
);
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
let subbuffers = buffers
.entry((size, MTLResourceOptions::StorageModeManaged))
.or_insert(vec![]);
@ -255,6 +210,40 @@ impl MetalDevice {
Ok(buffer)
}
fn find_available_buffer(
&self,
size: NSUInteger,
option: MTLResourceOptions,
buffers: &RwLockWriteGuard<BufferMap>,
) -> Option<Arc<Buffer>> {
let mut best_buffer: Option<&Arc<Buffer>> = None;
let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
for sub in subbuffers {
if Arc::strong_count(sub) == 1 {
best_buffer = Some(sub);
best_buffer_size = *buffer_size;
}
}
}
}
best_buffer.cloned()
}
fn drop_unused_buffers(&self) -> Result<()> {
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
for subbuffers in buffers.values_mut() {
let newbuffers = subbuffers
.iter()
.filter(|s| Arc::strong_count(*s) > 1)
.map(Arc::clone)
.collect();
*subbuffers = newbuffers;
}
Ok(())
}
/// The critical allocator algorithm
fn allocate_buffer(
&self,
@ -262,8 +251,8 @@ impl MetalDevice {
option: MTLResourceOptions,
_name: &str,
) -> Result<Arc<Buffer>> {
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
if let Some(b) = find_available_buffer(size, option, &buffers) {
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
if let Some(b) = self.find_available_buffer(size, option, &buffers) {
// Cloning also ensures we increment the strong count
return Ok(b.clone());
}
@ -284,13 +273,7 @@ impl MetalDevice {
let descriptor = metal::CaptureDescriptor::new();
descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
descriptor.set_capture_device(self);
// The [set_output_url] call requires an absolute path so we convert it if needed.
if path.as_ref().is_absolute() {
descriptor.set_output_url(path);
} else {
let path = std::env::current_dir()?.join(path);
descriptor.set_output_url(path);
}
descriptor.set_output_url(path);
capture
.start_capture(&descriptor)
@ -302,23 +285,3 @@ impl MetalDevice {
fn buf_size(size: NSUInteger) -> NSUInteger {
size.saturating_sub(1).next_power_of_two() as NSUInteger
}
fn find_available_buffer(
size: NSUInteger,
option: MTLResourceOptions,
buffers: &BufferMap,
) -> Option<Arc<Buffer>> {
let mut best_buffer: Option<&Arc<Buffer>> = None;
let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
for sub in subbuffers {
if Arc::strong_count(sub) == 1 {
best_buffer = Some(sub);
best_buffer_size = *buffer_size;
}
}
}
}
best_buffer.cloned()
}

View File

@ -6,7 +6,7 @@ use candle_metal_kernels::{BufferOffset, CallConvTranspose2dCfg, Kernels};
use metal::{Buffer, MTLResourceOptions, NSUInteger};
use std::collections::HashMap;
use std::ffi::c_void;
use std::sync::{Arc, Mutex, PoisonError, RwLock, TryLockError};
use std::sync::{Arc, Mutex, RwLock, TryLockError};
mod device;
pub use device::{DeviceId, MetalDevice};
@ -36,12 +36,6 @@ impl<T> From<TryLockError<T>> for MetalError {
}
}
impl<T> From<PoisonError<T>> for MetalError {
fn from(p: PoisonError<T>) -> Self {
MetalError::LockError(LockError::Poisoned(p.to_string()))
}
}
/// Metal related errors
#[derive(thiserror::Error, Debug)]
pub enum MetalError {
@ -119,8 +113,6 @@ impl BackendStorage for MetalStorage {
DType::F32 => "affine_f32",
DType::F16 => "affine_f16",
DType::BF16 => "affine_bf16",
DType::U8 => "affine_u8",
DType::U32 => "affine_u32",
dtype => crate::bail!("Metal contiguous affine {dtype:?} not implemented"),
};
candle_metal_kernels::call_affine(
@ -412,42 +404,17 @@ impl BackendStorage for MetalStorage {
.map_err(MetalError::from)?;
} else {
let kernel_name = match (self.dtype, dtype) {
(DType::BF16, DType::F16) => "cast_bf16_f16_strided",
(DType::BF16, DType::F32) => "cast_bf16_f32_strided",
(DType::BF16, DType::I64) => "cast_bf16_i64_strided",
(DType::BF16, DType::U32) => "cast_bf16_u32_strided",
(DType::BF16, DType::U8) => "cast_bf16_u8_strided",
(DType::F16, DType::BF16) => "cast_f16_bf16_strided",
(DType::F16, DType::F32) => "cast_f16_f32_strided",
(DType::F16, DType::I64) => "cast_f16_i64_strided",
(DType::F16, DType::U32) => "cast_f16_u32_strided",
(DType::F16, DType::U8) => "cast_f16_u8_strided",
(DType::F32, DType::BF16) => "cast_f32_bf16_strided",
(DType::F32, DType::F16) => "cast_f32_f16_strided",
(DType::F32, DType::I64) => "cast_f32_i64_strided",
(DType::F32, DType::U32) => "cast_f32_u32_strided",
(DType::F32, DType::U8) => "cast_f32_u8_strided",
(DType::I64, DType::F32) => "cast_i64_f32_strided",
(DType::I64, DType::BF16) => "cast_i64_bf16_strided",
(DType::I64, DType::F16) => "cast_i64_f16_strided",
(DType::I64, DType::U32) => "cast_i64_u32_strided",
(DType::I64, DType::U8) => "cast_i64_u8_strided",
(DType::U32, DType::BF16) => "cast_u32_bf16_strided",
(DType::U32, DType::F16) => "cast_u32_f16_strided",
(DType::U32, DType::F32) => "cast_u32_f32_strided",
(DType::U32, DType::I64) => "cast_u32_i64_strided",
(DType::U32, DType::U8) => "cast_u32_u8_strided",
(DType::U8, DType::BF16) => "cast_u8_bf16_strided",
(DType::U8, DType::F16) => "cast_u8_f16_strided",
(DType::U32, DType::I64) => "cast_u32_i64_strided",
(DType::U8, DType::U32) => "cast_u8_u32_strided",
(DType::U8, DType::F32) => "cast_u8_f32_strided",
(DType::U8, DType::I64) => "cast_u8_i64_strided",
(DType::U8, DType::U32) => "cast_u8_u32_strided",
(DType::F32, DType::F16) => "cast_f32_f16_strided",
(DType::F16, DType::F32) => "cast_f16_f32_strided",
(DType::I64, DType::F32) => "cast_i64_f32_strided",
(DType::F32, DType::BF16) => "cast_f32_bf16_strided",
(DType::BF16, DType::F32) => "cast_bf16_f32_strided",
(left, right) => {
crate::bail!("Metal strided to_dtype {left:?} {right:?} not implemented")
}
@ -745,7 +712,6 @@ impl BackendStorage for MetalStorage {
}
let name = match (self.dtype, t.dtype()) {
(DType::U8, DType::F32) => "where_u8_f32",
(DType::U32, DType::F32) => "where_u32_f32",
(DType::U8, DType::BF16) => "where_u8_bf16",
(DType::U8, DType::F16) => "where_u8_f16",
(DType::U8, DType::I64) => "where_u8_i64",
@ -852,107 +818,44 @@ impl BackendStorage for MetalStorage {
k_layout: &Layout,
params: &ParamsConvTranspose1D,
) -> Result<Self> {
const USE_COL2IM_CONV1D_TR: bool = true;
let can_use_col2im = k_layout.is_contiguous()
&& params.dilation == 1
&& params.padding == 0
&& params.output_padding == 0;
let l_out = params.l_out();
let dst_el = params.c_out * l_out * params.b_size;
let buffer = self
.device
.new_buffer(dst_el, self.dtype, "conv_transpose1d")?;
let buffer = if USE_COL2IM_CONV1D_TR && can_use_col2im {
let (b_size, c_in, l_in) = layout.shape().dims3()?;
let (c_in2, c_out, k_size) = k_layout.shape().dims3()?;
if c_in != c_in2 {
crate::bail!(
"convtr1d: shape mismatch on c_in {:?} {:?}",
layout.shape(),
k_layout.shape()
)
}
let buffer = self
.device
.new_buffer(dst_el, self.dtype, "conv_transpose1d")?;
let name = match self.dtype {
DType::F32 => "col2im1d_f32",
DType::U32 => "col2im1d_u32",
DType::U8 => "col2im1d_u8",
dtype => crate::bail!("metal col2im1d {dtype:?} not implemented"),
};
let col = {
// This merges the last two dimensions of the kernel together.
let kernel_l_mm = Layout::new(
(b_size, c_in, k_size * c_out).into(),
vec![0, k_size * c_out, 1],
k_layout.start_offset(),
);
self.matmul(
k,
(b_size, l_in, c_out * k_size, c_in),
&layout.transpose(1, 2)?,
&kernel_l_mm,
)?
};
// It is important for the command buffer to be obtained *after* the matmul
// kernel has run, otherwise we might use a command-buffer that has been committed
// already resulting in the following error.
// _status < MTLCommandBufferStatusCommitted >
// -[IOGPUMetalCommandBuffer setCurrentCommandEncoder:]
let command_buffer = self.device.command_buffer()?;
candle_metal_kernels::call_col2im1d(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
&[b_size, l_in, c_out, k_size],
params.k_size,
params.stride,
BufferOffset::zero_offset(&col.buffer),
&buffer,
)
.map_err(MetalError::from)?;
buffer
} else {
let buffer = self
.device
.new_buffer(dst_el, self.dtype, "conv_transpose1d")?;
let command_buffer = self.device.command_buffer()?;
let name = match self.dtype {
DType::F32 => "conv_transpose1d_f32",
DType::F16 => "conv_transpose1d_f16",
DType::BF16 => "conv_transpose1d_bf16",
DType::U32 => "conv_transpose1d_u32",
DType::U8 => "conv_transpose1d_u8",
dtype => crate::bail!("Metal conv_transpose1d {dtype:?} not implemented"),
};
candle_metal_kernels::call_conv_transpose1d(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
params.dilation,
params.stride,
params.padding,
params.output_padding,
params.c_out,
l_out,
params.b_size,
layout.dims(),
layout.stride(),
k_layout.dims(),
k_layout.stride(),
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
&k.buffer,
k_layout.start_offset() * k.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;
buffer
let command_buffer = self.device.command_buffer()?;
let name = match self.dtype {
DType::F32 => "conv_transpose1d_f32",
DType::F16 => "conv_transpose1d_f16",
DType::BF16 => "conv_transpose1d_bf16",
DType::U32 => "conv_transpose1d_u32",
DType::U8 => "conv_transpose1d_u8",
dtype => crate::bail!("Metal conv_transpose1d {dtype:?} not implemented"),
};
candle_metal_kernels::call_conv_transpose1d(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
params.dilation,
params.stride,
params.padding,
params.output_padding,
params.c_out,
l_out,
params.b_size,
layout.dims(),
layout.stride(),
k_layout.dims(),
k_layout.stride(),
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
&k.buffer,
k_layout.start_offset() * k.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype))
}
@ -1423,7 +1326,6 @@ impl BackendStorage for MetalStorage {
.map_err(MetalError::from)?;
Ok(acc)
}
fn matmul(
&self,
rhs: &Self,
@ -1432,78 +1334,31 @@ impl BackendStorage for MetalStorage {
rhs_l: &Layout,
) -> Result<Self> {
let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?;
let name = match self.dtype {
DType::F32 => "sgemm",
DType::F16 => "hgemm",
dtype => {
return Err(MetalError::Message(format!("matmul doesn't support {dtype:?}")).into())
}
};
let command_buffer = self.device.command_buffer()?;
command_buffer.set_label("matmul");
if self.dtype == DType::BF16 {
candle_metal_kernels::call_mlx_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
candle_metal_kernels::GemmDType::BF16,
(b, m, n, k),
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
} else if self.device.use_mlx_mm {
let dtype = match self.dtype {
DType::F32 => candle_metal_kernels::GemmDType::F32,
DType::F16 => candle_metal_kernels::GemmDType::F16,
DType::BF16 => candle_metal_kernels::GemmDType::BF16,
dtype => {
return Err(MetalError::Message(format!(
"mlx matmul doesn't support {dtype:?}"
))
.into())
}
};
candle_metal_kernels::call_mlx_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
dtype,
(b, m, n, k),
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
} else {
let name = match self.dtype {
DType::F32 => "sgemm",
DType::F16 => "hgemm",
dtype => {
return Err(
MetalError::Message(format!("matmul doesn't support {dtype:?}")).into(),
)
}
};
candle_metal_kernels::call_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
(b, m, n, k),
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
}
candle_metal_kernels::call_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
(b, m, n, k),
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::new(
buffer,
self.device.clone(),
@ -1864,25 +1719,31 @@ impl BackendDevice for MetalDevice {
fn new(ordinal: usize) -> Result<Self> {
let device = metal::Device::all().swap_remove(ordinal);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer().to_owned();
command_buffer.enqueue();
let command_buffer = Arc::new(RwLock::new(command_buffer));
let command_buffer_index = Arc::new(RwLock::new(0));
let kernels = Arc::new(Kernels::new());
let use_mlx_mm = match std::env::var("CANDLE_USE_MLX_MM").as_deref() {
Ok("false") | Ok("False") | Ok("FALSE") | Ok("0") | Err(_) => false,
Ok(_) => true,
let buffers = Arc::new(RwLock::new(HashMap::new()));
let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
Ok(val) => val.parse()?,
_ => 50,
};
let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
[299792458].as_ptr() as *const c_void,
4,
MTLResourceOptions::StorageModeManaged,
)));
let commands = device::Commands::new(command_queue)?;
Ok(Self {
id: DeviceId::new(),
device,
commands: Arc::new(RwLock::new(commands)),
buffers: Arc::new(RwLock::new(HashMap::new())),
command_queue,
command_buffer,
command_buffer_index,
compute_per_buffer,
buffers,
kernels,
seed,
use_mlx_mm,
})
}

View File

@ -34,10 +34,7 @@ fn ceil_div(p: usize, q: usize) -> usize {
}
fn pad(p: usize, q: usize) -> usize {
// Overallocate by q rather than just padding by q as this should pad the last row
// and we don't have enough information here to know how many elements to add :(
// ceil_div(p, q) * q
p + q
ceil_div(p, q) * q
}
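
A small worked example of the two padding strategies contrasted above (the `ceil_div` body below is the usual round-up formulation and is an assumption here, since only its signature appears in the hunk):

```rust
fn ceil_div(p: usize, q: usize) -> usize {
    (p + q - 1) / q
}

fn main() {
    // Round up to the next multiple of q (the "ceil_div(p, q) * q" variant)...
    assert_eq!(ceil_div(10, 8) * 8, 16);
    // ...versus overallocating by a full extra q (the "p + q" variant).
    assert_eq!(10 + 8, 18);
}
```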
fn quantize_q8_1(
@ -442,7 +439,7 @@ impl QCudaStorage {
}
_ => crate::bail!("only f32 can be quantized"),
};
let src_len = pad(src.len(), MATRIX_ROW_PADDING);
let src_len = src.len();
let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?;
qcpu_storage.quantize(&src)?;

View File

@ -217,16 +217,10 @@ impl Value {
}
}
/// This will also automatically upcast any integral type whose conversion to u64 cannot truncate.
pub fn to_u64(&self) -> Result<u64> {
match self {
Self::U64(v) => Ok(*v),
// Autoupcast cases here
Self::U8(v) => Ok(*v as u64),
Self::U16(v) => Ok(*v as u64),
Self::U32(v) => Ok(*v as u64),
Self::Bool(v) => Ok(*v as u64),
v => crate::bail!("not a u64 or upcastable to u64 {v:?}"),
v => crate::bail!("not a u64 {v:?}"),
}
}
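
A hedged illustration of the upcasting behaviour described above, assuming the `Value` enum from this file is in scope (the function name is hypothetical and the exact crate path is not shown in the diff):

```rust
fn widen_examples() -> Result<()> {
    // On the upcasting side of the diff these all succeed; on the other side,
    // only the U64 case does.
    assert_eq!(Value::U64(42).to_u64()?, 42);
    assert_eq!(Value::U8(7).to_u64()?, 7);
    assert_eq!(Value::Bool(true).to_u64()?, 1);
    Ok(())
}
```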

View File

@ -18,7 +18,7 @@ pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>(
let actual_blocks = ys.len();
// Validate that the input is the right size
if actual_blocks < expected_blocks {
if expected_blocks != actual_blocks {
crate::bail!("quantize {dtype:?}: expected {expected_blocks} blocks but only {actual_blocks} were provided!")
}

View File

@ -349,30 +349,6 @@ impl MmapedSafetensors {
}
}
pub struct SliceSafetensors<'a> {
safetensors: SafeTensors<'a>,
}
impl<'a> SliceSafetensors<'a> {
/// Creates a wrapper around a binary buffer and deserialize the safetensors header.
pub fn new(buffer: &'a [u8]) -> Result<Self> {
let safetensors = safetensors::SafeTensors::deserialize(buffer)?;
Ok(Self { safetensors })
}
pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
self.safetensors.tensor(name)?.load(dev)
}
pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
self.safetensors.tensors()
}
pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
Ok(self.safetensors.tensor(name)?)
}
}
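
A minimal usage sketch for `SliceSafetensors` (the tensor name is hypothetical, and the module path is assumed to be `candle_core::safetensors`; the type itself is the one added above):

```rust
use candle_core::{Device, Result, Tensor};

fn load_from_bytes(bytes: &[u8]) -> Result<Tensor> {
    // Deserialize the safetensors header from an in-memory buffer, then load one tensor.
    let st = candle_core::safetensors::SliceSafetensors::new(bytes)?;
    st.load("model.embed_tokens.weight", &Device::Cpu)
}
```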
pub struct BufferedSafetensors {
safetensors: yoke::Yoke<SafeTensors_<'static>, Vec<u8>>,
}

View File

@ -304,7 +304,6 @@ impl Dim for usize {
pub enum D {
Minus1,
Minus2,
Minus(usize),
}
impl D {
@ -312,7 +311,6 @@ impl D {
let dim = match self {
Self::Minus1 => -1,
Self::Minus2 => -2,
Self::Minus(u) => -(*u as i32),
};
Error::DimOutOfRange {
shape: shape.clone(),
@ -329,7 +327,6 @@ impl Dim for D {
match self {
Self::Minus1 if rank >= 1 => Ok(rank - 1),
Self::Minus2 if rank >= 2 => Ok(rank - 2),
Self::Minus(u) if *u > 0 && rank >= *u => Ok(rank - *u),
_ => Err(self.out_of_range(shape, op)),
}
}
@ -339,7 +336,6 @@ impl Dim for D {
match self {
Self::Minus1 => Ok(rank),
Self::Minus2 if rank >= 1 => Ok(rank - 1),
Self::Minus(u) if *u > 0 && rank + 1 >= *u => Ok(rank + 1 - *u),
_ => Err(self.out_of_range(shape, op)),
}
}
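
A short hedged sketch of what the `D` enum above buys you when indexing from the end of a shape (`narrow`'s signature is the one used elsewhere in this diff; on the side of the diff that has `D::Minus(usize)`, `D::Minus(1)` behaves like `D::Minus1`):

```rust
use candle_core::{Device, Result, Tensor, D};

// Keep only the first slice of the last dimension, whatever the rank is.
fn first_of_last_dim(t: &Tensor) -> Result<Tensor> {
    t.narrow(D::Minus1, 0, 1)
}

fn main() -> Result<()> {
    let t = Tensor::zeros((2, 3, 4), candle_core::DType::F32, &Device::Cpu)?;
    let u = first_of_last_dim(&t)?;
    assert_eq!(u.dims(), &[2, 3, 1]);
    Ok(())
}
```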

View File

@ -1,206 +0,0 @@
use crate::{Result, Shape, Tensor};
pub trait Dim: crate::shape::Dim + Copy {}
impl<T: crate::shape::Dim + Copy> Dim for T {}
/// A stream tensor is used in streaming modules. It can either contain an actual tensor or be
/// empty.
#[derive(Clone)]
pub struct StreamTensor(Option<Tensor>);
impl std::fmt::Debug for StreamTensor {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self.0 {
Some(t) => write!(f, "{:?}", t.shape()),
None => write!(f, "Empty"),
}
}
}
impl std::convert::From<Option<Tensor>> for StreamTensor {
fn from(value: Option<Tensor>) -> Self {
Self(value)
}
}
impl std::convert::From<Tensor> for StreamTensor {
fn from(value: Tensor) -> Self {
Self(Some(value))
}
}
impl std::convert::From<()> for StreamTensor {
fn from(_value: ()) -> Self {
Self(None)
}
}
impl StreamTensor {
pub fn empty() -> Self {
Self(None)
}
pub fn from_tensor(tensor: Tensor) -> Self {
Self(Some(tensor))
}
pub fn shape(&self) -> Option<&Shape> {
self.0.as_ref().map(|t| t.shape())
}
pub fn cat2<D: Dim>(&self, rhs: &Self, dim: D) -> Result<Self> {
let xs = match (&self.0, &rhs.0) {
(Some(lhs), Some(rhs)) => {
let xs = Tensor::cat(&[lhs, rhs], dim)?;
Some(xs)
}
(Some(xs), None) | (None, Some(xs)) => Some(xs.clone()),
(None, None) => None,
};
Ok(Self(xs))
}
pub fn seq_len<D: Dim>(&self, dim: D) -> Result<usize> {
match &self.0 {
None => Ok(0),
Some(v) => v.dim(dim),
}
}
pub fn reset(&mut self) {
self.0 = None
}
pub fn narrow<D: Dim>(&self, dim: D, offset: usize, len: usize) -> Result<StreamTensor> {
let t = match &self.0 {
None => None,
Some(t) => {
let seq_len = t.dim(dim)?;
if seq_len <= offset {
None
} else {
let t = t.narrow(dim, offset, usize::min(len, seq_len - offset))?;
Some(t)
}
}
};
Ok(Self(t))
}
/// Splits the Streaming Tensor on the time axis `dim` with the first `lhs_len` elements
/// returned in the first output and the remaining in the second output.
pub fn split<D: Dim>(&self, dim: D, lhs_len: usize) -> Result<(Self, Self)> {
match &self.0 {
None => Ok((Self::empty(), Self::empty())),
Some(t) => {
let seq_len = t.dim(dim)?;
let lhs_len = usize::min(seq_len, lhs_len);
if lhs_len == 0 {
Ok((Self::empty(), t.clone().into()))
} else {
let lhs = Self::from_tensor(t.narrow(dim, 0, lhs_len)?);
let rhs_len = seq_len - lhs_len;
let rhs = if rhs_len == 0 {
Self::empty()
} else {
Self::from_tensor(t.narrow(dim, lhs_len, rhs_len)?)
};
Ok((lhs, rhs))
}
}
}
}
pub fn as_option(&self) -> Option<&Tensor> {
self.0.as_ref()
}
pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
match &self.0 {
None => Ok(Self::empty()),
Some(t) => Ok(Self::from_tensor(t.apply(m)?)),
}
}
}
/// Streaming modules take a stream tensor as input and return a stream tensor. They may buffer
/// data internally until enough of it has been received for the module to perform its
/// computation.
pub trait StreamingModule {
// TODO: Should we also have a flush method?
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor>;
fn reset_state(&mut self);
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinOp {
Add,
Mul,
Sub,
Div,
}
#[derive(Debug, Clone)]
pub struct StreamingBinOp {
prev_lhs: StreamTensor,
prev_rhs: StreamTensor,
pub op: BinOp,
pub dim: crate::D,
}
impl StreamingBinOp {
pub fn new(op: BinOp, dim: crate::D) -> Self {
Self {
prev_lhs: StreamTensor::empty(),
prev_rhs: StreamTensor::empty(),
op,
dim,
}
}
pub fn reset_state(&mut self) {
self.prev_lhs.reset();
self.prev_rhs.reset();
}
pub fn forward(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor> {
match self.op {
BinOp::Add => Tensor::add(lhs, rhs),
BinOp::Mul => Tensor::mul(lhs, rhs),
BinOp::Sub => Tensor::sub(lhs, rhs),
BinOp::Div => Tensor::div(lhs, rhs),
}
}
pub fn step(&mut self, lhs: &StreamTensor, rhs: &StreamTensor) -> Result<StreamTensor> {
let lhs = StreamTensor::cat2(&self.prev_lhs, lhs, self.dim)?;
let rhs = StreamTensor::cat2(&self.prev_rhs, rhs, self.dim)?;
let lhs_len = lhs.seq_len(self.dim)?;
let rhs_len = rhs.seq_len(self.dim)?;
let common_len = usize::min(lhs_len, rhs_len);
let (lhs, prev_lhs) = lhs.split(self.dim, common_len)?;
let (rhs, prev_rhs) = rhs.split(self.dim, common_len)?;
let ys = match (lhs.0, rhs.0) {
(Some(lhs), Some(rhs)) => {
let ys = self.forward(&lhs, &rhs)?;
StreamTensor::from_tensor(ys)
}
(None, None) => StreamTensor::empty(),
(lhs, rhs) => crate::bail!("INTERNAL ERROR inconsistent lhs and rhs {lhs:?} {rhs:?}"),
};
self.prev_lhs = prev_lhs;
self.prev_rhs = prev_rhs;
Ok(ys)
}
}
/// Simple wrapper that doesn't do any buffering.
pub struct Map<T: crate::Module>(T);
impl<T: crate::Module> StreamingModule for Map<T> {
fn reset_state(&mut self) {}
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
xs.apply(&self.0)
}
}
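A minimal usage sketch of the streaming pieces above, assuming this module is exposed as `candle_core::streaming`:

```rust
use candle_core::streaming::{BinOp, StreamTensor, StreamingBinOp};
use candle_core::{Device, Result, Tensor, D};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let mut add = StreamingBinOp::new(BinOp::Add, D::Minus1);
    // First chunk: the lhs has 3 elements but the rhs only 2, so one lhs element is buffered.
    let lhs: StreamTensor = Tensor::new(&[1f32, 2., 3.], &dev)?.into();
    let rhs: StreamTensor = Tensor::new(&[10f32, 20.], &dev)?.into();
    let ys = add.step(&lhs, &rhs)?;
    assert_eq!(ys.seq_len(D::Minus1)?, 2); // [11., 22.]
    // Second chunk: the buffered lhs element pairs with the newly arrived rhs element.
    let tail: StreamTensor = Tensor::new(&[30f32], &dev)?.into();
    let ys = add.step(&StreamTensor::empty(), &tail)?;
    assert_eq!(ys.seq_len(D::Minus1)?, 1); // [33.]
    Ok(())
}
```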

View File

@ -370,15 +370,6 @@ impl Tensor {
/// Returns a new tensor with all the elements having the same specified value. Note that
/// the tensor is not contiguous so you would have to call `.contiguous()` on it if needed.
///```rust
/// use candle_core::{Tensor, Device};
/// let a = Tensor::full(3.5, (2, 4), &Device::Cpu)?;
///
/// assert_eq!(a.to_vec2::<f64>()?, &[
/// [3.5, 3.5, 3.5, 3.5],
/// [3.5, 3.5, 3.5, 3.5],
/// ]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn full<D: crate::WithDType, S: Into<Shape>>(
value: D,
shape: S,
@ -388,13 +379,6 @@ impl Tensor {
}
/// Creates a new 1D tensor from an iterator.
///```rust
/// use candle_core::{Tensor, Device};
/// let a = Tensor::from_iter( [1.0, 2.0, 3.0, 4.0].into_iter(), &Device::Cpu)?;
///
/// assert_eq!(a.to_vec1::<f64>()?, &[1.0, 2.0, 3.0, 4.0]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn from_iter<D: crate::WithDType>(
iter: impl IntoIterator<Item = D>,
device: &Device,
@ -406,26 +390,12 @@ impl Tensor {
/// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common
/// difference `1` from `start`.
///```rust
/// use candle_core::{Tensor, Device};
/// let a = Tensor::arange(2., 5., &Device::Cpu)?;
///
/// assert_eq!(a.to_vec1::<f64>()?, &[2., 3., 4.]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn arange<D: crate::WithDType>(start: D, end: D, device: &Device) -> Result<Self> {
Self::arange_step(start, end, D::one(), device)
}
/// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common
/// difference `step` from `start`.
///```rust
/// use candle_core::{Tensor, Device};
/// let a = Tensor::arange_step(2.0, 4.0, 0.5, &Device::Cpu)?;
///
/// assert_eq!(a.to_vec1::<f64>()?, &[2.0, 2.5, 3.0, 3.5]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn arange_step<D: crate::WithDType>(
start: D,
end: D,
@ -471,16 +441,6 @@ impl Tensor {
/// Creates a new tensor initialized with values from the input vector. The number of elements
/// in this vector must be the same as the number of elements defined by the shape.
/// If the device is cpu, no data copy is made.
///```rust
/// use candle_core::{Tensor, Device};
/// let a = Tensor::from_vec(vec!{1., 2., 3., 4., 5., 6.}, (2, 3), &Device::Cpu)?;
///
/// assert_eq!(a.to_vec2::<f64>()?, &[
/// [1., 2., 3.],
/// [4., 5., 6.]
/// ]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn from_vec<S: Into<Shape>, D: crate::WithDType>(
data: Vec<D>,
shape: S,
@ -491,17 +451,6 @@ impl Tensor {
/// Creates a new tensor initialized with values from the input slice. The number of elements
/// in this slice must be the same as the number of elements defined by the shape.
///```rust
/// use candle_core::{Tensor, Device};
/// let values = vec![1., 2., 3., 4., 5., 6., 7., 8.];
/// let a = Tensor::from_slice(&values[1..7], (2, 3), &Device::Cpu)?;
///
/// assert_eq!(a.to_vec2::<f64>()?, &[
/// [2., 3., 4.],
/// [5., 6., 7.]
/// ]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
array: &[D],
shape: S,
@ -641,9 +590,9 @@ impl Tensor {
///
/// * `args` - A slice of 1D tensors.
/// * `xy_indexing` - Whether to use xy indexing or ij indexing. If xy is selected, the
/// first dimension corresponds to the cardinality of the second input and the second
/// dimension corresponds to the cardinality of the first input. If ij is selected, the
/// dimensions are in the same order as the cardinality of the inputs.
/// first dimension corresponds to the cardinality of the second input and the second
/// dimension corresponds to the cardinality of the first input. If ij is selected, the
/// dimensions are in the same order as the cardinality of the inputs.
///
/// # Examples
///
@ -783,30 +732,6 @@ impl Tensor {
/// Returns a new tensor that is a narrowed version of the input; the dimension `dim`
/// ranges from `start` to `start + len`.
/// ```
/// use candle_core::{Tensor, Device};
/// let a = Tensor::new(&[
/// [0f32, 1., 2.],
/// [3. , 4., 5.],
/// [6. , 7., 8.]
/// ], &Device::Cpu)?;
///
/// let b = a.narrow(0, 1, 2)?;
/// assert_eq!(b.shape().dims(), &[2, 3]);
/// assert_eq!(b.to_vec2::<f32>()?, &[
/// [3., 4., 5.],
/// [6., 7., 8.]
/// ]);
///
/// let c = a.narrow(1, 1, 1)?;
/// assert_eq!(c.shape().dims(), &[3, 1]);
/// assert_eq!(c.to_vec2::<f32>()?, &[
/// [1.],
/// [4.],
/// [7.]
/// ]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
let dims = self.dims();
let dim = dim.to_index(self.shape(), "narrow")?;
@ -2025,11 +1950,7 @@ impl Tensor {
}
(Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
_ => {
bail!(
"not implemented yet, self.device: {:?}, device: {:?}",
self.device(),
device
)
bail!("not implemented yet")
}
};
let op = BackpropOp::new1(self, Op::ToDevice);
@ -2519,19 +2440,9 @@ impl Tensor {
/// Returns log(sum(exp(tensor), dim)).
pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
let sum_dims = sum_dims.to_indexes(self.shape(), "log-sum-exp")?;
if sum_dims.is_empty() {
return Ok(self.clone());
}
let max = sum_dims[1..]
.iter()
.try_fold(self.max_keepdim(sum_dims[0])?, |max, &dim| {
max.max_keepdim(dim)
})?;
let exp = self.broadcast_sub(&max)?.exp()?;
let sum = exp.sum(sum_dims.clone())?;
sum.log()? + max.squeeze_dims(&sum_dims)
let exp = self.exp()?;
let sum = exp.sum(sum_dims)?;
sum.log()
}
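A small usage sketch of the stabilized version; the values are chosen so that a naive `exp` would overflow, and they match the log-sum-exp test further down in this diff:

```rust
use candle_core::{Device, Result, Tensor, D};

fn main() -> Result<()> {
    // log(sum_j exp(x_j)) = m + log(sum_j exp(x_j - m)) with m = max_j x_j,
    // so rows with entries around 1000 no longer overflow to infinity.
    let xs = Tensor::new(&[[1000f64, 999., 1001.], [1., 2., 3.]], &Device::Cpu)?;
    let lse = xs.log_sum_exp(D::Minus1)?;
    println!("{:?}", lse.to_vec1::<f64>()?); // roughly [1001.4076, 3.4076]
    Ok(())
}
```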
/// Pointwise pow operation.

View File

@ -235,66 +235,4 @@ impl Tensor {
}
Ok(crate::tensor::from_storage(storage, shape, op, false))
}
/// Set the values on `self` using values from `src`. The copy starts at the specified
/// `offset` for the target dimension `dim` on `self`.
/// `self` and `src` must have the same shape except on dimension `dim` where the `self` size
/// has to be greater than or equal to `offset` plus the `src` size.
///
/// Note that this modifies `self` in place and as such is not compatible with
/// back-propagation.
pub fn slice_set<D: Dim>(&self, src: &Self, dim: D, offset: usize) -> Result<()> {
let dim = dim.to_index(self.shape(), "slice-set")?;
if !self.is_contiguous() || !src.is_contiguous() {
Err(Error::RequiresContiguous { op: "slice-set" }.bt())?
}
if self.dtype() != src.dtype() {
Err(Error::DTypeMismatchBinaryOp {
lhs: self.dtype(),
rhs: src.dtype(),
op: "slice-set",
}
.bt())?
}
if self.device().location() != src.device().location() {
Err(Error::DeviceMismatchBinaryOp {
lhs: self.device().location(),
rhs: src.device().location(),
op: "slice-set",
}
.bt())?
}
if self.rank() != src.rank() {
Err(Error::UnexpectedNumberOfDims {
expected: self.rank(),
got: src.rank(),
shape: self.shape().clone(),
}
.bt())?
}
for (dim_idx, (v1, v2)) in self.dims().iter().zip(src.dims().iter()).enumerate() {
if dim_idx == dim && *v2 + offset > *v1 {
crate::bail!("shape mismatch on target dim, dst: {v1}, src: {v2} + {offset}")
}
if dim_idx != dim && v1 != v2 {
crate::bail!("shape mismatch on dim {dim_idx}, {v1} <> {v2}")
}
}
let block_size: usize = src.dims().iter().skip(1 + dim).product();
let d1: usize = src.dims().iter().take(dim).product();
let d2 = block_size * src.dims()[dim];
let dst_o = self.layout().start_offset() + offset * block_size;
let src_o = src.layout().start_offset();
src.storage().copy2d(
&mut self.storage_mut(),
d1,
d2,
/* src_s */ d2,
/* dst_s */ block_size * self.dims()[dim],
src_o,
dst_o,
)?;
Ok(())
}
}
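A minimal usage sketch mirroring the `slice_set` test further down in this diff:

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // A zeroed (2, 4, 7, 3) cache; copy a (2, 4, 4, 3) block into positions 1..5 of dim 2.
    let cache = Tensor::zeros((2, 4, 7, 3), DType::F32, &dev)?;
    let src = Tensor::randn(0f32, 1f32, (2, 4, 4, 3), &dev)?;
    cache.slice_set(&src, 2, 1)?;
    // The written slice now matches `src`; the rest of the cache is untouched.
    let diff = (cache.narrow(2, 1, 4)? - &src)?.abs()?.sum_all()?.to_vec0::<f32>()?;
    assert_eq!(diff, 0.);
    Ok(())
}
```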

View File

@ -730,103 +730,6 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
]
]
);
// Test the same, but then with the following properties, t & w are unmodified.
let padding = 1;
let outpadding = 1;
let dilation = 1;
let stride = 2;
let res = t.conv_transpose2d(&w, padding, outpadding, stride, dilation)?;
let loss = res.sqr()?.sum_all()?;
assert_eq!(test_utils::to_vec0_round(&loss, 0)?, 3627.0); // torch gives 3626.8560
let grads = loss.backward()?;
let grad_t = grads.get(&t).unwrap();
let grad_w = grads.get(&w).unwrap();
assert_eq!(grad_t.dims(), [1, 4, 7, 5]);
assert_eq!(grad_w.dims(), [4, 2, 3, 5]);
#[rustfmt::skip]
assert_eq!(
test_utils::to_vec3_round(&grad_t.i(0)?, 1)?,
[
[
[ 13.2, -40.7, -9.7, -47.3, -82.7],
[ -98.2, 9.7, 57.7, -6.2, 180.7],
[ 100.2, 24.1, 3.7, -100.5, -48.1],
[ -0.3, 13.5, -2.9, 80.0, -49.8],
[ 47.2, -25.6, -74.4, 61.2, -18.4],
[ 4.6, -69.5, 27.9, 66.5, -88.1],
// 4th column on next row; torch is 4.2
[ -12.0, 79.2, -40.0, 4.1, -97.1],
],
[
[ -42.2, -36.5, -51.1, 7.5, 32.3],
[ 74.1, -44.6, -68.8, 19.5, 7.7],
[ 137.1, 54.2, 153.8, -58.0, 45.5],
[ 24.4, -56.8, 9.7, -41.0, -14.5],
[ -3.7, 72.6, 8.3, 134.8, 40.5],
[ 43.2, -56.9, -47.5, -89.4, -95.4],
[ 68.2, 108.1, -80.0, 57.0, -121.1]
],
[
[ 31.1, -11.4, -34.8, 33.1, -44.2],
[ 29.4, -31.6, -40.2, 13.7, 13.1],
[ -0.8, -83.8, -7.8, -17.3, 78.2],
[ 12.0, -118.7, 137.5, -76.7, 50.8],
[ -28.7, -114.2, -3.7, -96.3, -13.8],
[ -31.8, 28.5, -14.3, 4.6, 13.4],
[ 28.0, -0.2, -38.9, -29.7, -59.0]
],
[
[ -16.8, 38.5, 15.5, 26.6, 48.9],
[ 14.5, 49.6, -24.8, 65.6, 61.7],
[ 22.1, -64.7, -4.3, -51.0, 36.3],
[ 31.0, -88.9, 47.1, -123.5, -3.8],
[ -14.8, -39.8, 128.2, -110.3, 42.6],
// 1st column on next row; torch is -7.2
[ -7.1, 95.3, -21.3, -58.7, -13.9],
[ 26.9, 21.3, 16.1, 70.3, 32.1]
]
]
);
#[rustfmt::skip]
assert_eq!(
test_utils::to_vec1_round(&grad_w.flatten_all()?, 1)?,
[
// 2nd value; torch gets -3.2, 3rd value; torch gets 221.8
-2.460e+01, -3.100e+00, 2.219e+02, 7.400e+00, 5.620e+01,
7.420e+01, 7.830e+01, 8.900e+00, 1.050e+01, 2.810e+01,
5.100e+00, -1.046e+02, -1.572e+02, 8.710e+01, -9.840e+01,
-4.230e+01, -1.898e+02, 1.860e+01, -3.570e+01, 9.810e+01,
4.680e+01, 1.182e+02, 4.020e+01, -1.900e+00, 1.508e+02,
1.094e+02, 1.018e+02, -4.620e+01, 1.591e+02, -2.320e+01,
// 5th value; torch gets 7.1
-8.450e+01, -4.600e+00, 6.330e+01, 1.123e+02, -7.000e+00,
1.101e+02, -6.620e+01, 2.090e+01, -5.120e+01, 8.990e+01,
9.050e+01, -6.990e+01, 6.800e+01, -9.250e+01, 1.380e+02,
4.720e+01, 4.710e+01, 6.210e+01, 8.870e+01, 2.098e+02,
3.870e+01, -1.390e+01, 6.270e+01, 1.484e+02, -9.920e+01,
-4.200e+01, -1.505e+02, -1.480e+01, -2.620e+01, 8.220e+01,
-3.350e+01, -2.260e+01, -1.198e+02, -5.080e+01, 1.259e+02,
5.600e+01, 9.270e+01, 1.209e+02, 6.590e+01, -8.330e+01,
7.000e+00, -2.600e+01, -1.133e+02, 3.870e+01, 4.020e+01,
-6.300e+00, -8.710e+01, -5.150e+01, -8.510e+01, 2.000e-01,
3.640e+01, -6.100e+00, 6.590e+01, -2.700e+00, 6.550e+01,
// 4th value; torch gets 3.8
5.300e+00, -6.760e+01, -4.270e+01, -3.900e+00, 2.880e+01,
5.260e+01, 6.170e+01, -1.203e+02, -1.610e+01, 7.740e+01,
-1.008e+02, -1.070e+01, -9.900e+00, 3.300e+00, -2.620e+01,
-4.440e+01, 2.580e+01, -6.920e+01, -4.220e+01, 1.108e+02,
1.240e+01, -3.440e+01, -2.800e+00, 7.880e+01, -6.690e+01,
1.480e+01, 2.310e+01, -4.260e+01, -1.500e+00, -4.760e+01,
5.350e+01, -2.260e+01, 8.000e-01, -3.840e+01, -2.500e+00
]
);
Ok(())
}

View File

@ -49,20 +49,6 @@ fn matmul(device: &Device) -> Result<()> {
Ok(())
}
fn matmul_bf16(device: &Device) -> Result<()> {
if !device.supports_bf16() {
return Ok(());
}
let data = vec![1.0f32, 2.0, 3.0, 4.0];
let a = Tensor::from_slice(&data, (2, 2), device)?.to_dtype(DType::BF16)?;
let data = vec![1.0f32, 2.0, 3.0, 4.0];
let b = Tensor::from_slice(&data, (2, 2), device)?.to_dtype(DType::BF16)?;
let c = a.matmul(&b)?.to_dtype(DType::F32)?;
assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);
Ok(())
}
fn broadcast_matmul(device: &Device) -> Result<()> {
let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?;
let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?;
@ -110,12 +96,6 @@ fn mm_layout(device: &Device) -> Result<()> {
}
test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
test_device!(
matmul_bf16,
matmul_bf16_cpu,
matmul_bf16_gpu,
matmul_bf16_metal
);
test_device!(
broadcast_matmul,
broadcast_matmul_cpu,

View File

@ -1,31 +1,5 @@
use candle_core::{DType, Result, Tensor};
struct TmpFile(std::path::PathBuf);
impl TmpFile {
fn create(base: &str) -> TmpFile {
let filename = std::env::temp_dir().join(format!(
"candle-{}-{}-{:?}",
base,
std::process::id(),
std::thread::current().id(),
));
TmpFile(filename)
}
}
impl std::convert::AsRef<std::path::Path> for TmpFile {
fn as_ref(&self) -> &std::path::Path {
self.0.as_path()
}
}
impl Drop for TmpFile {
fn drop(&mut self) {
std::fs::remove_file(&self.0).unwrap()
}
}
#[test]
fn npy() -> Result<()> {
let npy = Tensor::read_npy("tests/test.npy")?;
@ -48,24 +22,3 @@ fn npz() -> Result<()> {
);
Ok(())
}
#[test]
fn safetensors() -> Result<()> {
use candle_core::safetensors::Load;
let tmp_file = TmpFile::create("st");
let t = Tensor::arange(0f32, 24f32, &candle_core::Device::Cpu)?;
t.save_safetensors("t", &tmp_file)?;
// Load from file.
let st = candle_core::safetensors::load(&tmp_file, &candle_core::Device::Cpu)?;
let t2 = st.get("t").unwrap();
let diff = (&t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff, 0f32);
// Load from bytes.
let bytes = std::fs::read(tmp_file)?;
let st = candle_core::safetensors::SliceSafetensors::new(&bytes)?;
let t2 = st.get("t").unwrap().load(&candle_core::Device::Cpu);
let diff = (&t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff, 0f32);
Ok(())
}

View File

@ -193,19 +193,6 @@ fn unary_op(device: &Device) -> Result<()> {
tensor.sign()?.to_vec1::<f32>()?,
[-1., -1., -1., 0., 0., 1., 1., 1., 1.]
);
let tensor = Tensor::new(&[-1.0f32, 0., -2., 3.], device)?;
let y = tensor.elu(2.)?;
assert_eq!(
test_utils::to_vec1_round(&y, 4)?,
[-1.2642, 0.0000, -1.7293, 3.0000]
);
// This test failed on metal prior to the following PR:
// https://github.com/huggingface/candle/pull/2490
let y = tensor.reshape((2, 2))?.t()?.elu(2.)?.flatten_all()?;
assert_eq!(
test_utils::to_vec1_round(&y, 4)?,
[-1.2642, -1.7293, 0.0000, 3.0000]
);
Ok(())
}
@ -678,30 +665,6 @@ fn broadcast(device: &Device) -> Result<()> {
Ok(())
}
fn slice_set(device: &Device) -> Result<()> {
let (b, h, max_t, d) = (2, 4, 7, 3);
let cache = Tensor::zeros((b, h, max_t, d), DType::F32, device)?;
let tensor = Tensor::randn(0f32, 1f32, (b, h, 4, d), device)?;
cache.slice_set(&tensor, 2, 0)?;
let cache_t = cache.narrow(2, 0, 4)?;
let diff = (cache_t - &tensor)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
cache.slice_set(&tensor, 2, 1)?;
let cache_t = cache.narrow(2, 1, 4)?;
let diff = (cache_t - &tensor)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let ones = Tensor::ones((b, h, 1, d), DType::F32, device)?;
cache.slice_set(&ones, 2, 6)?;
let diff = cache.narrow(2, 5, 1)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let diff = (cache.narrow(2, 6, 1)? - 1.)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
Ok(())
}
fn cat(device: &Device) -> Result<()> {
// 1D
let t1 = Tensor::new(&[3f32, 1., 4.], device)?;
@ -1183,7 +1146,6 @@ test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal);
test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal);
test_device!(slice_set, ss_cpu, ss_gpu, ss_metal);
test_device!(cat, cat_cpu, cat_gpu, cat_metal);
test_device!(sum, sum_cpu, sum_gpu, sum_metal);
test_device!(min, min_cpu, min_gpu, min_metal);
@ -1339,29 +1301,11 @@ fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> {
#[test]
fn log_sum_exp() -> Result<()> {
let input = Tensor::new(
&[
[[1f64, 2., 3.], [4., 5., 6.]],
[[-1000.0, -999.0, -1001.0], [1000.0, 999.0, 1001.0]],
],
&Device::Cpu,
)?;
let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
let output = input.log_sum_exp(D::Minus1)?;
// The expected values were obtained from PyTorch.
let expected = Tensor::new(&[[3.4076, 6.4076], [-998.5924, 1001.4076]], &Device::Cpu)?;
assert_eq!(output.dims(), expected.dims());
assert_close(&output.flatten_all()?, &expected.flatten_all()?, 0.00001)?;
assert_eq!(
input.log_sum_exp((0, 1))?.to_vec1::<f64>()?,
[1000.0, 999.0, 1001.0]
);
assert_eq!(
input.log_sum_exp(())?.to_vec3::<f64>()?,
input.to_vec3::<f64>()?
);
let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?;
assert_close(&output, &expected, 0.00001)?;
Ok(())
}

View File

@ -89,7 +89,7 @@ fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor,
pub fn load() -> Result<crate::vision::Dataset> {
let api = Api::new().map_err(|e| Error::Msg(format!("Api error: {e}")))?;
let dataset_id = "ylecun/mnist".to_string();
let dataset_id = "mnist".to_string();
let repo = Repo::with_revision(
dataset_id,
RepoType::Dataset,

View File

@ -25,8 +25,6 @@ hf-hub = { workspace = true, features = ["tokio"] }
image = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
palette = { version = "0.7.6", optional = true }
enterpolation = { version = "0.2.1", optional = true}
pyo3 = { version = "0.21.0", features = ["auto-initialize"], optional = true }
rayon = { workspace = true }
rubato = { version = "0.15.0", optional = true }
@ -35,7 +33,7 @@ serde = { workspace = true }
serde_json = { workspace = true }
symphonia = { version = "0.5.3", features = ["all"], optional = true }
tokenizers = { workspace = true, features = ["onig"] }
cpal = { version = "0.15.2", optional = true }
cpal= { version = "0.15.2", optional = true }
[dev-dependencies]
anyhow = { workspace = true }
@ -67,8 +65,6 @@ onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"]
microphone = ["cpal"]
encodec = ["cpal", "symphonia", "rubato"]
mimi = ["cpal", "symphonia", "rubato"]
depth_anything_v2 = ["palette", "enterpolation"]
[[example]]
name = "llama_multiprocess"
@ -102,18 +98,6 @@ required-features = ["candle-datasets"]
name = "llama2-c"
required-features = ["candle-datasets"]
[[example]]
name = "mimi"
required-features = ["mimi"]
[[example]]
name = "encodec"
required-features = ["encodec"]
[[example]]
name = "depth_anything_v2"
required-features = ["depth_anything_v2"]
[[example]]
name = "silero-vad"
required-features = ["onnx"]

View File

@ -1,20 +0,0 @@
# candle-based
An experimental, non-instruction-tuned small LLM from the Hazy Research group, combining local and linear attention layers.
[Blogpost](https://hazyresearch.stanford.edu/blog/2024-03-03-based)
[Simple linear attention language models balance the recall-throughput tradeoff](https://arxiv.org/abs/2402.18668)
## Running an example
```bash
$ cargo run --example based --release -- --prompt "Flying monkeys are" --which 1b-50b --sample-len 100
Flying monkeys are a common sight in the wild, but they are also a threat to humans.
The new study, published today (July 31) in the journal Science Advances, shows that the monkeys are using their brains to solve the problem of how to get around the problem.
"We found that the monkeys were using a strategy called 'cognitive mapping' - they would use their brains to map out the route ahead," says lead author Dr. David J. Smith from the University of California
```

View File

@ -1,275 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};
use candle_transformers::models::based::Model;
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: Model,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <|endoftext|> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
#[value(name = "360m")]
W360m,
#[value(name = "1b")]
W1b,
#[value(name = "1b-50b")]
W1b50b,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "refs/pr/1")]
revision: String,
#[arg(long)]
config_file: Option<String>,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
#[arg(long, default_value = "360m")]
which: Which,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id,
None => match args.which {
Which::W360m => "hazyresearch/based-360m".to_string(),
Which::W1b => "hazyresearch/based-1b".to_string(),
Which::W1b50b => "hazyresearch/based-1b-50b".to_string(),
},
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let config_file = match args.config_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("config.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => vec![repo.get("model.safetensors")?],
};
let repo = api.model("openai-community/gpt2".to_string());
let tokenizer_file = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_file).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = serde_json::from_reader(std::fs::File::open(config_file)?)?;
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let mut vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
if args.which == Which::W1b50b {
vb = vb.pp("model");
};
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -1,20 +0,0 @@
# candle-beit
[Beit](https://arxiv.org/abs/2106.08254) is a computer vision model.
In this example, it is used as an ImageNet classifier: the model returns the
probability for the image to belong to each of the 1000 ImageNet categories.
## Running an example
```bash
cargo run --example beit --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
> mountain bike, all-terrain bike, off-roader: 56.16%
> bicycle-built-for-two, tandem bicycle, tandem: 3.08%
> maillot : 2.23%
> alp : 0.88%
> crash helmet : 0.85%
```
![Leading group, Giro d'Italia 2021](../yolo-v8/assets/bike.jpg)

View File

@ -1,79 +0,0 @@
//! BEiT: BERT Pre-Training of Image Transformers
//! https://github.com/microsoft/unilm/tree/master/beit
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::Parser;
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::beit;
/// Loads an image from disk using the image crate, this returns a tensor with shape
/// (3, 384, 384). Beit special normalization is applied.
pub fn load_image384_beit_norm<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
let img = image::ImageReader::open(p)?
.decode()
.map_err(candle::Error::wrap)?
.resize_to_fill(384, 384, image::imageops::FilterType::Triangle);
let img = img.to_rgb8();
let data = img.into_raw();
let data = Tensor::from_vec(data, (384, 384, 3), &Device::Cpu)?.permute((2, 0, 1))?;
let mean = Tensor::new(&[0.5f32, 0.5, 0.5], &Device::Cpu)?.reshape((3, 1, 1))?;
let std = Tensor::new(&[0.5f32, 0.5, 0.5], &Device::Cpu)?.reshape((3, 1, 1))?;
(data.to_dtype(candle::DType::F32)? / 255.)?
.broadcast_sub(&mean)?
.broadcast_div(&std)
}
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = load_image384_beit_norm(args.image)?.to_device(&device)?;
println!("loaded image {image:?}");
let model_file = match args.model {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("vincent-espitalier/candle-beit".into());
api.get("beit_base_patch16_384.in22k_ft_in22k_in1k.safetensors")?
}
Some(model) => model.into(),
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = beit::vit_base(vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for &(category_idx, pr) in prs.iter().take(5) {
println!(
"{:24}: {:.2}%",
candle_examples::imagenet::CLASSES[category_idx],
100. * pr
);
}
Ok(())
}

View File

@ -126,7 +126,7 @@ fn main() -> Result<()> {
println!("Loaded and encoded {:?}", start.elapsed());
for idx in 0..args.n {
let start = std::time::Instant::now();
let ys = model.forward(&token_ids, &token_type_ids, None)?;
let ys = model.forward(&token_ids, &token_type_ids)?;
if idx == 0 {
println!("{ys}");
}
@ -163,19 +163,11 @@ fn main() -> Result<()> {
Ok(Tensor::new(tokens.as_slice(), device)?)
})
.collect::<Result<Vec<_>>>()?;
let attention_mask = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_attention_mask().to_vec();
Ok(Tensor::new(tokens.as_slice(), device)?)
})
.collect::<Result<Vec<_>>>()?;
let token_ids = Tensor::stack(&token_ids, 0)?;
let attention_mask = Tensor::stack(&attention_mask, 0)?;
let token_type_ids = token_ids.zeros_like()?;
println!("running inference on batch {:?}", token_ids.shape());
let embeddings = model.forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
let embeddings = model.forward(&token_ids, &token_type_ids)?;
println!("generated embeddings {:?}", embeddings.shape());
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;

View File

@ -55,7 +55,7 @@ const SEP_TOKEN_ID: u32 = 102;
/// Loads an image from disk using the image crate, this returns a tensor with shape
/// (3, 384, 384). OpenAI normalization is applied.
pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
let img = image::ImageReader::open(p)?
let img = image::io::Reader::open(p)?
.decode()
.map_err(candle::Error::wrap)?
.resize_to_fill(384, 384, image::imageops::FilterType::Triangle);

View File

@ -1,4 +1,4 @@
# candle-clip
Contrastive Language-Image Pre-Training
Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
pairs of images with related texts.

View File

@ -33,7 +33,7 @@ struct Args {
}
fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::Result<Tensor> {
let img = image::ImageReader::open(path)?.decode()?;
let img = image::io::Reader::open(path)?.decode()?;
let (height, width) = (image_size, image_size);
let img = img.resize_to_fill(
width as u32,

View File

@ -1,96 +0,0 @@
* candle-codegeex4_9b
THUDM/CodeGeeX4 is a versatile model for all AI software development scenarios, including code completion, code interpreter, web search, function calling, repository-level Q&A and much more.
- [[https://github.com/THUDM/CodeGeeX4][Github]]
- [[https://codegeex.cn/][HomePage]]
- [[https://huggingface.co/THUDM/codegeex4-all-9b][huggingface]]
** Running with ~cuda~
#+begin_src shell
cargo run --example codegeex4-9b --release --features cuda -- --prompt "please write a insertion sort in rust" --sample-len 300
#+end_src
** Running with ~cpu~
#+begin_src shell
cargo run --example codegeex4-9b --release -- --cpu --prompt "please write a insertion sort in rust" --sample-len 300
#+end_src
** Output_Example
*** Input
#+begin_src shell
cargo run --release --features cuda -- --prompt 'please write a FFT in rust' --sample-len 500 --cache /root/autodl-tmp
#+end_src
*** Output
#+begin_src shell
avx: false, neon: false, simd128: false, f16c: false
temp: 0.95 repeat-penalty: 1.10 repeat-last-n: 64
cache path /root/autodl-tmp
Prompt: [please write a FFT in rust]
Using Seed 11511762269791786684
DType is BF16
transofrmer layers create
模型加载完毕 4
starting the inference loop
开始生成
samplelen 500
500 tokens generated (34.60 token/s)
Result:
Sure, I can help you with that. Here's an example of a Fast Fourier Transform (FFT) implementation in Rust:
```rust
use num_complex::Complex;
fn fft(input: &[Complex<f64> > ] ) -> Vec<Complex<f64> > > {
let n = input.len();
if n == 1 {
return vec![input[0]]];
}
let mut even = vec![];
let mut odd = vec![];
for i in 0..n {
if i % 2 == 0 {
even.push(input[i]);
} else {
odd.push(input[i]);
}
}
let even_fft = fft(&even);
let odd_fft = fft(&odd);
let mut output = vec![];
for k in 0..n/2 {
let t = Complex::new(0.0, -2.0 * std::f64::consts::PI * (k as f64) / (n as f64))) ).exp();
output.push(even_fft[k] + odd_fft[k] * t]);
output.push(even_fft[k] - odd_fft[k] * t]);
}
return output;
}
```
This implementation uses the Cooley-Tukey algorithm to perform the FFT. The function takes an array of complex numbers and returns an array of complex numbers which is the result of the FFT.
#+end_src
* Citation
#+begin_src
@inproceedings{zheng2023codegeex,
title={CodeGeeX: A Pre-Trained Model for Code Generation with Multilingual Benchmarking on HumanEval-X},
author={Qinkai Zheng and Xiao Xia and Xu Zou and Yuxiao Dong and Shan Wang and Yufei Xue and Zihan Wang and Lei Shen and Andi Wang and Yang Li and Teng Su and Zhilin Yang and Jie Tang},
booktitle={Proceedings of the 29th ACM SIGKDD Conference on Knowledge Discovery and Data Mining},
pages={5673--5684},
year={2023}
}
#+end_src

View File

@ -1,252 +0,0 @@
use candle_transformers::models::codegeex4_9b::*;
use clap::Parser;
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: Model,
device: Device,
tokenizer: Tokenizer,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
dtype: DType,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
device: &Device,
dtype: DType,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer,
logits_processor,
repeat_penalty,
repeat_last_n,
verbose_prompt,
device: device.clone(),
dtype,
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> anyhow::Result<()> {
use std::io::Write;
println!("starting the inference loop");
let tokens = self.tokenizer.encode(prompt, true).expect("tokens error");
if tokens.is_empty() {
panic!("Empty prompts are not supported in the chatglm model.")
}
if self.verbose_prompt {
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
println!("{id:7} -> '{token}'");
}
}
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
Some(token) => *token,
None => panic!("cannot find the endoftext token"),
};
let mut tokens = tokens.get_ids().to_vec();
let mut generated_tokens = 0usize;
print!("{prompt}");
std::io::stdout().flush().expect("output flush error");
let start_gen = std::time::Instant::now();
println!("\n start_gen");
println!("samplelen {}", sample_len);
let mut count = 0;
let mut result = vec![];
for index in 0..sample_len {
count += 1;
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input)?;
let logits = logits.squeeze(0)?.to_dtype(self.dtype)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
let token = self
.tokenizer
.decode(&[next_token], true)
.expect("Token error");
if self.verbose_prompt {
println!(
"[Count: {}] [Raw Token: {}] [Decode Token: {}]",
count, next_token, token
);
}
result.push(token);
std::io::stdout().flush()?;
}
let dt = start_gen.elapsed();
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
println!("Result:");
for tokens in result {
print!("{tokens}");
}
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Local cache directory for the Hugging Face Hub files.
#[arg(name = "cache", short, long, default_value = ".")]
cache_path: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Display the token for the specified prompt.
#[arg(long)]
verbose_prompt: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 5000)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
weight_file: Option<String>,
#[arg(long)]
tokenizer: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> anyhow::Result<()> {
let args = Args::parse();
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.95),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
println!("cache path {}", args.cache_path);
let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into()))
.build()
.map_err(anyhow::Error::msg)?;
let model_id = match args.model_id {
Some(model_id) => model_id.to_string(),
None => "THUDM/codegeex4-all-9b".to_string(),
};
let revision = match args.revision {
Some(rev) => rev.to_string(),
None => "main".to_string(),
};
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let tokenizer_filename = match args.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => api
.model("THUDM/codegeex4-all-9b".to_string())
.get("tokenizer.json")
.map_err(anyhow::Error::msg)?,
};
let filenames = match args.weight_file {
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).expect("Tokenizer Error");
let start = std::time::Instant::now();
let config = Config::codegeex4();
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
args.verbose_prompt,
&device,
dtype,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -1,13 +0,0 @@
# candle-depth-anything-v2
[Depth Anything V2](https://huggingface.co/spaces/depth-anything/Depth-Anything-V2) is a model for Monocular Depth Estimation (MDE, i.e. estimating depth from a single image) which
builds on the [DINOv2](https://github.com/facebookresearch/dinov2) vision transformer.
This example first instantiates the DINOv2 model and then proceeds to create DepthAnythingV2 and run it.
## Running an example with color map and CUDA
```bash
cargo run --features cuda,depth_anything_v2 --package candle-examples --example depth_anything_v2 -- --color-map --image candle-examples/examples/yolo-v8/assets/bike.jpg
```

View File

@ -1,50 +0,0 @@
use enterpolation::linear::ConstEquidistantLinear;
use enterpolation::Generator;
use palette::LinSrgb;
use candle::Tensor;
pub struct SpectralRColormap {
gradient: ConstEquidistantLinear<f32, LinSrgb, 9>,
}
impl SpectralRColormap {
pub(crate) fn new() -> Self {
// Define a colormap similar to 'Spectral_r' by specifying key colors.
// got the colors from ChatGPT-4o
let gradient = ConstEquidistantLinear::<f32, _, 9>::equidistant_unchecked([
LinSrgb::new(0.3686, 0.3098, 0.6353), // Dark blue
LinSrgb::new(0.1961, 0.5333, 0.7412), // Blue
LinSrgb::new(0.4000, 0.7608, 0.6471), // Cyan
LinSrgb::new(0.6706, 0.8667, 0.6431), // Green
LinSrgb::new(0.9020, 0.9608, 0.5961), // Yellow
LinSrgb::new(0.9961, 0.8784, 0.5451), // Orange
LinSrgb::new(0.9922, 0.6824, 0.3804), // Red
LinSrgb::new(0.9569, 0.4275, 0.2627), // Dark red
LinSrgb::new(0.8353, 0.2431, 0.3098), // Dark purple
]);
Self { gradient }
}
fn get_color(&self, value: f32) -> LinSrgb {
self.gradient.gen(value)
}
pub fn gray2color(&self, gray: &Tensor) -> candle::Result<Tensor> {
println!("Gray: {:?}", gray.dims());
let gray_values: Vec<f32> = gray.flatten_all()?.to_vec1()?;
let rgb_values: Vec<f32> = gray_values
.iter()
.map(|g| self.get_color(*g))
.flat_map(|rgb| [rgb.red, rgb.green, rgb.blue])
.collect();
let [.., height, width] = gray.dims() else {
candle::bail!("Not enough dims!")
};
let color = Tensor::from_vec(rgb_values, (*height, *width, 3), gray.device())?;
color.permute((2, 0, 1))
}
}

View File

@ -1,187 +0,0 @@
//! Depth Anything V2
//! https://huggingface.co/spaces/depth-anything/Depth-Anything-V2
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use std::ffi::OsString;
use std::path::PathBuf;
use clap::Parser;
use candle::DType::{F32, U8};
use candle::{DType, Device, Module, Result, Tensor};
use candle_examples::{load_image, load_image_and_resize, save_image};
use candle_nn::VarBuilder;
use candle_transformers::models::depth_anything_v2::{DepthAnythingV2, DepthAnythingV2Config};
use candle_transformers::models::dinov2;
use crate::color_map::SpectralRColormap;
mod color_map;
// taken these from: https://huggingface.co/spaces/depth-anything/Depth-Anything-V2/blob/main/depth_anything_v2/dpt.py#L207
const MAGIC_MEAN: [f32; 3] = [0.485, 0.456, 0.406];
const MAGIC_STD: [f32; 3] = [0.229, 0.224, 0.225];
const DINO_IMG_SIZE: usize = 518;
#[derive(Parser)]
struct Args {
#[arg(long)]
dinov2_model: Option<PathBuf>,
#[arg(long)]
depth_anything_v2_model: Option<PathBuf>,
#[arg(long)]
image: PathBuf,
#[arg(long)]
output_dir: Option<PathBuf>,
#[arg(long)]
cpu: bool,
#[arg(long)]
color_map: bool,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let dinov2_model_file = match args.dinov2_model {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("lmz/candle-dino-v2".into());
api.get("dinov2_vits14.safetensors")?
}
Some(dinov2_model) => dinov2_model,
};
println!("Using file {:?}", dinov2_model_file);
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[dinov2_model_file], F32, &device)? };
let dinov2 = dinov2::vit_small(vb)?;
println!("DinoV2 model built");
let depth_anything_model_file = match args.depth_anything_v2_model {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("jeroenvlek/depth-anything-v2-safetensors".into());
api.get("depth_anything_v2_vits.safetensors")?
}
Some(depth_anything_model) => depth_anything_model,
};
println!("Using file {:?}", depth_anything_model_file);
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[depth_anything_model_file], DType::F32, &device)?
};
let config = DepthAnythingV2Config::vit_small();
let depth_anything = DepthAnythingV2::new(&dinov2, &config, vb)?;
let (original_height, original_width, image) = load_and_prep_image(&args.image, &device)?;
println!("Loaded image {image:?}");
let depth = depth_anything.forward(&image)?;
println!("Got predictions {:?}", depth.shape());
let output_image = post_process_image(&depth, original_height, original_width, args.color_map)?;
let output_path = full_output_path(&args.image, &args.output_dir);
println!("Saving image to {}", output_path.to_string_lossy());
save_image(&output_image, output_path)?;
Ok(())
}
fn full_output_path(image_path: &PathBuf, output_dir: &Option<PathBuf>) -> PathBuf {
let input_file_name = image_path.file_name().unwrap();
let mut output_file_name = OsString::from("depth_");
output_file_name.push(input_file_name);
let mut output_path = match output_dir {
None => image_path.parent().unwrap().to_path_buf(),
Some(output_path) => output_path.clone(),
};
output_path.push(output_file_name);
output_path
}
fn load_and_prep_image(
image_path: &PathBuf,
device: &Device,
) -> anyhow::Result<(usize, usize, Tensor)> {
let (_original_image, original_height, original_width) = load_image(&image_path, None)?;
let image = load_image_and_resize(&image_path, DINO_IMG_SIZE, DINO_IMG_SIZE)?
.unsqueeze(0)?
.to_dtype(F32)?
.to_device(&device)?;
let max_pixel_val = Tensor::try_from(255.0f32)?
.to_device(&device)?
.broadcast_as(image.shape())?;
let image = (image / max_pixel_val)?;
let image = normalize_image(&image, &MAGIC_MEAN, &MAGIC_STD)?;
Ok((original_height, original_width, image))
}
fn normalize_image(image: &Tensor, mean: &[f32; 3], std: &[f32; 3]) -> Result<Tensor> {
let mean_tensor =
Tensor::from_vec(mean.to_vec(), (3, 1, 1), &image.device())?.broadcast_as(image.shape())?;
let std_tensor =
Tensor::from_vec(std.to_vec(), (3, 1, 1), &image.device())?.broadcast_as(image.shape())?;
image.sub(&mean_tensor)?.div(&std_tensor)
}
fn post_process_image(
image: &Tensor,
original_height: usize,
original_width: usize,
color_map: bool,
) -> Result<Tensor> {
let out = image.interpolate2d(original_height, original_width)?;
let out = scale_image(&out)?;
let out = if color_map {
let spectral_r = SpectralRColormap::new();
spectral_r.gray2color(&out)?
} else {
let rgb_slice = [&out, &out, &out];
Tensor::cat(&rgb_slice, 0)?.squeeze(1)?
};
let max_pixel_val = Tensor::try_from(255.0f32)?
.to_device(out.device())?
.broadcast_as(out.shape())?;
let out = (out * max_pixel_val)?;
out.to_dtype(U8)
}
fn scale_image(depth: &Tensor) -> Result<Tensor> {
let flat_values: Vec<f32> = depth.flatten_all()?.to_vec1()?;
let min_val = flat_values.iter().min_by(|a, b| a.total_cmp(b)).unwrap();
let max_val = flat_values.iter().max_by(|a, b| a.total_cmp(b)).unwrap();
let min_val_tensor = Tensor::try_from(*min_val)?
.to_device(depth.device())?
.broadcast_as(depth.shape())?;
let depth = (depth - min_val_tensor)?;
let range = max_val - min_val;
let range_tensor = Tensor::try_from(range)?
.to_device(depth.device())?
.broadcast_as(depth.shape())?;
depth / range_tensor
}

View File

@ -1,25 +0,0 @@
# candle-dinov2-reg4
[DINOv2-reg4](https://arxiv.org/abs/2309.16588) is the latest version of DINOv2 with registers.
In this example, it is used as a plant species classifier: the model returns the
probability for the image to belong to each of the 7806 PlantCLEF2024 categories.
## Running an example
```bash
# Download classes names and a plant picture to identify
curl https://huggingface.co/vincent-espitalier/dino-v2-reg4-with-plantclef2024-weights/raw/main/species_id_mapping.txt --output candle-examples/examples/dinov2reg4/species_id_mapping.txt
curl https://bs.plantnet.org/image/o/bd2d3830ac3270218ba82fd24e2290becd01317c --output candle-examples/examples/dinov2reg4/bd2d3830ac3270218ba82fd24e2290becd01317c.jpg
# Perform inference
cargo run --example dinov2reg4 --release -- --image candle-examples/examples/dinov2reg4/bd2d3830ac3270218ba82fd24e2290becd01317c.jpg
> Orchis simia Lam. : 45.55%
> Orchis × bergonii Nanteuil: 9.80%
> Orchis italica Poir. : 9.66%
> Orchis × angusticruris Franch.: 2.76%
> Orchis × bivonae Tod. : 2.54%
```
![Orchis Simia](https://bs.plantnet.org/image/o/bd2d3830ac3270218ba82fd24e2290becd01317c)

View File

@ -1,70 +0,0 @@
//! DINOv2 reg4 finetuned on PlantCLEF 2024
//! https://arxiv.org/abs/2309.16588
//! https://huggingface.co/spaces/BVRA/PlantCLEF2024
//! https://zenodo.org/records/10848263
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::Parser;
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::dinov2reg4;
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image518(args.image)?.to_device(&device)?;
println!("loaded image {image:?}");
let f_species_id_mapping = "candle-examples/examples/dinov2reg4/species_id_mapping.txt";
let classes: Vec<String> = std::fs::read_to_string(f_species_id_mapping)
.expect("missing classes file")
.split('\n')
.map(|s| s.to_string())
.collect();
let model_file = match args.model {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api =
api.model("vincent-espitalier/dino-v2-reg4-with-plantclef2024-weights".into());
api.get(
"vit_base_patch14_reg4_dinov2_lvd142m_pc24_onlyclassifier_then_all.safetensors",
)?
}
Some(model) => model.into(),
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = dinov2reg4::vit_base(vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for &(category_idx, pr) in prs.iter().take(5) {
println!("{:24}: {:.2}%", classes[category_idx], 100. * pr);
}
Ok(())
}

View File

@ -7,7 +7,7 @@ quantization.
## Running one example
```bash
cargo run --example encodec --features encodec --release -- code-to-audio \
cargo run --example encodec --features symphonia --release -- code-to-audio \
candle-examples/examples/encodec/jfk-codes.safetensors \
jfk.wav
```

View File

@ -1,21 +0,0 @@
# candle-eva2
[EVA-02](https://arxiv.org/abs/2303.11331) is a computer vision model.
In this example, it is used as an ImageNet classifier: the model returns the
probability for the image to belong to each of the 1000 ImageNet categories.
## Running an example
```bash
cargo run --example eva2 --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
> mountain bike, all-terrain bike, off-roader: 37.09%
> maillot : 8.30%
> alp : 2.13%
> bicycle-built-for-two, tandem bicycle, tandem: 0.84%
> crash helmet : 0.73%
```
![Leading group, Giro d'Italia 2021](../yolo-v8/assets/bike.jpg)

View File

@ -1,82 +0,0 @@
//! EVA-02: Explore the limits of Visual representation at scAle
//! https://github.com/baaivision/EVA
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::Parser;
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::eva2;
/// Loads an image from disk using the image crate, this returns a tensor with shape
/// (3, 448, 448). OpenAI normalization is applied.
pub fn load_image448_openai_norm<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
let img = image::ImageReader::open(p)?
.decode()
.map_err(candle::Error::wrap)?
.resize_to_fill(448, 448, image::imageops::FilterType::Triangle);
let img = img.to_rgb8();
let data = img.into_raw();
let data = Tensor::from_vec(data, (448, 448, 3), &Device::Cpu)?.permute((2, 0, 1))?;
let mean =
Tensor::new(&[0.48145466f32, 0.4578275, 0.40821073], &Device::Cpu)?.reshape((3, 1, 1))?;
let std = Tensor::new(&[0.26862954f32, 0.261_302_6, 0.275_777_1], &Device::Cpu)?
.reshape((3, 1, 1))?;
(data.to_dtype(candle::DType::F32)? / 255.)?
.broadcast_sub(&mean)?
.broadcast_div(&std)
}
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = load_image448_openai_norm(args.image)?.to_device(&device)?;
println!("loaded image {image:?}");
let model_file = match args.model {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("vincent-espitalier/candle-eva2".into());
api.get("eva02_base_patch14_448.mim_in22k_ft_in22k_in1k_adapted.safetensors")?
}
Some(model) => model.into(),
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = eva2::vit_base(vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for &(category_idx, pr) in prs.iter().take(5) {
println!(
"{:24}: {:.2}%",
candle_examples::imagenet::CLASSES[category_idx],
100. * pr
);
}
Ok(())
}

View File

@ -1,20 +0,0 @@
# candle-fastvit
[FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization](https://arxiv.org/abs/2303.14189).
This candle implementation uses a pre-trained FastViT network for inference. The
classification head has been trained on the ImageNet dataset; the example prints the
probabilities of the top-5 classes.
## Running an example
```
$ cargo run --example fastvit --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which sa12
loaded image Tensor[dims 3, 256, 256; f32]
model built
mountain bike, all-terrain bike, off-roader: 52.67%
bicycle-built-for-two, tandem bicycle, tandem: 7.93%
unicycle, monocycle : 3.46%
maillot : 1.32%
crash helmet : 1.28%
```
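For reference, the printed probabilities come from a plain softmax over the model
logits followed by a sort. A minimal sketch of that step (the `top5` helper and the
`classes` slice are only illustrative; the tensor ops mirror the example's `main.rs`):

```rust
use candle::{IndexOp, Result, Tensor, D};

/// Illustrative helper: softmax over the logits, then keep the five most
/// probable classes as (label, percentage) pairs.
fn top5(logits: &Tensor, classes: &[&str]) -> Result<Vec<(String, f32)>> {
    let prs = candle_nn::ops::softmax(logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;
    let mut prs = prs.into_iter().enumerate().collect::<Vec<_>>();
    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    Ok(prs
        .into_iter()
        .take(5)
        .map(|(idx, p)| (classes[idx].to_string(), 100. * p))
        .collect())
}
```

The example binary uses `candle_examples::imagenet::CLASSES` for the label strings.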

View File

@ -1,102 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::fastvit;
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
T8,
T12,
S12,
SA12,
SA24,
SA36,
MA36,
}
impl Which {
fn model_filename(&self) -> String {
let name = match self {
Self::T8 => "t8",
Self::T12 => "t12",
Self::S12 => "s12",
Self::SA12 => "sa12",
Self::SA24 => "sa24",
Self::SA36 => "sa36",
Self::MA36 => "ma36",
};
format!("timm/fastvit_{}.apple_in1k", name)
}
fn config(&self) -> fastvit::Config {
match self {
Self::T8 => fastvit::Config::t8(),
Self::T12 => fastvit::Config::t12(),
Self::S12 => fastvit::Config::s12(),
Self::SA12 => fastvit::Config::sa12(),
Self::SA24 => fastvit::Config::sa24(),
Self::SA36 => fastvit::Config::sa36(),
Self::MA36 => fastvit::Config::ma36(),
}
}
}
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
#[arg(value_enum, long, default_value_t=Which::S12)]
which: Which,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image(args.image, 256)?.to_device(&device)?;
println!("loaded image {image:?}");
let model_file = match args.model {
None => {
let model_name = args.which.model_filename();
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(model_name);
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = fastvit::fastvit(&args.which.config(), 1000, vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for &(category_idx, pr) in prs.iter().take(5) {
println!(
"{:24}: {:.2}%",
candle_examples::imagenet::CLASSES[category_idx],
100. * pr
);
}
Ok(())
}

View File

@ -1,19 +0,0 @@
# candle-flux: image generation with latent rectified flow transformers
![rusty robot holding a candle](./assets/flux-robot.jpg)
Flux is a 12B rectified flow transformer capable of generating images from text
descriptions; see the
[huggingface](https://huggingface.co/black-forest-labs/FLUX.1-schnell),
[github](https://github.com/black-forest-labs/flux), and
[blog post](https://blackforestlabs.ai/announcing-black-forest-labs/) links for details.
## Running the model
```bash
cargo run --features cuda --example flux -r -- \
--height 1024 --width 1024 \
--prompt "a rusty robot walking on a beach holding a small torch, the robot has the word "rust" written on it, high quality, 4k"
```
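For a rough picture of what happens after the text encoders run, here is a sketch of
the sampling path for the schnell variant. The `sample` helper below is only
illustrative; the `flux::sampling` calls mirror the full example further below, and
`t5_emb`/`clip_emb` are assumed to be the prompt embeddings produced by the T5 and
CLIP encoders.

```rust
use candle::{DType, Device, Result, Tensor};
use candle_transformers::models::flux;

// Illustrative sketch of the FLUX.1-schnell sampling path.
fn sample(
    model: &flux::model::Flux,
    t5_emb: &Tensor,
    clip_emb: &Tensor,
    height: usize,
    width: usize,
    device: &Device,
    dtype: DType,
) -> Result<Tensor> {
    // Start from Gaussian noise in the packed latent space.
    let img = flux::sampling::get_noise(1, height, width, device)?.to_dtype(dtype)?;
    let state = flux::sampling::State::new(t5_emb, clip_emb, &img)?;
    // The schnell variant uses a short 4-step schedule without timestep shifting.
    let timesteps = flux::sampling::get_schedule(4, None);
    let img = flux::sampling::denoise(
        model,
        &state.img,
        &state.img_ids,
        &state.txt,
        &state.txt_ids,
        &state.vec,
        &timesteps,
        4.,
    )?;
    // Unpack the latent patches back into a latent image for the autoencoder.
    flux::sampling::unpack(&img, height, width)
}
```

The unpacked latent is then decoded by the autoencoder and rescaled to u8 pixels, as
in the full example.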

Binary file not shown.

Before

Width:  |  Height:  |  Size: 90 KiB

View File

@ -1,248 +0,0 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use candle_transformers::models::{clip, flux, t5};
use anyhow::{Error as E, Result};
use candle::{IndexOp, Module, Tensor};
use candle_nn::VarBuilder;
use clap::Parser;
use tokenizers::Tokenizer;
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Args {
/// The prompt to be used for image generation.
#[arg(long, default_value = "A rusty robot walking on a beach")]
prompt: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Use the quantized model.
#[arg(long)]
quantized: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// The height in pixels of the generated image.
#[arg(long)]
height: Option<usize>,
/// The width in pixels of the generated image.
#[arg(long)]
width: Option<usize>,
#[arg(long)]
decode_only: Option<String>,
#[arg(long, value_enum, default_value = "schnell")]
model: Model,
}
#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)]
enum Model {
Schnell,
Dev,
}
fn run(args: Args) -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let Args {
prompt,
cpu,
height,
width,
tracing,
decode_only,
model,
quantized,
} = args;
let width = width.unwrap_or(1360);
let height = height.unwrap_or(768);
let _guard = if tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let api = hf_hub::api::sync::Api::new()?;
let bf_repo = {
let name = match model {
Model::Dev => "black-forest-labs/FLUX.1-dev",
Model::Schnell => "black-forest-labs/FLUX.1-schnell",
};
api.repo(hf_hub::Repo::model(name.to_string()))
};
let device = candle_examples::device(cpu)?;
let dtype = device.bf16_default_to_f32();
let img = match decode_only {
None => {
let t5_emb = {
let repo = api.repo(hf_hub::Repo::with_revision(
"google/t5-v1_1-xxl".to_string(),
hf_hub::RepoType::Model,
"refs/pr/2".to_string(),
));
let model_file = repo.get("model.safetensors")?;
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
let config_filename = repo.get("config.json")?;
let config = std::fs::read_to_string(config_filename)?;
let config: t5::Config = serde_json::from_str(&config)?;
let mut model = t5::T5EncoderModel::load(vb, &config)?;
let tokenizer_filename = api
.model("lmz/mt5-tokenizers".to_string())
.get("t5-v1_1-xxl.tokenizer.json")?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let mut tokens = tokenizer
.encode(prompt.as_str(), true)
.map_err(E::msg)?
.get_ids()
.to_vec();
tokens.resize(256, 0);
let input_token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
println!("{input_token_ids}");
model.forward(&input_token_ids)?
};
println!("T5\n{t5_emb}");
let clip_emb = {
let repo = api.repo(hf_hub::Repo::model(
"openai/clip-vit-large-patch14".to_string(),
));
let model_file = repo.get("model.safetensors")?;
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
// https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
let config = clip::text_model::ClipTextConfig {
vocab_size: 49408,
projection_dim: 768,
activation: clip::text_model::Activation::QuickGelu,
intermediate_size: 3072,
embed_dim: 768,
max_position_embeddings: 77,
pad_with: None,
num_hidden_layers: 12,
num_attention_heads: 12,
};
let model =
clip::text_model::ClipTextTransformer::new(vb.pp("text_model"), &config)?;
let tokenizer_filename = repo.get("tokenizer.json")?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let tokens = tokenizer
.encode(prompt.as_str(), true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let input_token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
println!("{input_token_ids}");
model.forward(&input_token_ids)?
};
println!("CLIP\n{clip_emb}");
let img = {
let cfg = match model {
Model::Dev => flux::model::Config::dev(),
Model::Schnell => flux::model::Config::schnell(),
};
let img = flux::sampling::get_noise(1, height, width, &device)?.to_dtype(dtype)?;
let state = if quantized {
flux::sampling::State::new(
&t5_emb.to_dtype(candle::DType::F32)?,
&clip_emb.to_dtype(candle::DType::F32)?,
&img.to_dtype(candle::DType::F32)?,
)?
} else {
flux::sampling::State::new(&t5_emb, &clip_emb, &img)?
};
let timesteps = match model {
Model::Dev => {
flux::sampling::get_schedule(50, Some((state.img.dim(1)?, 0.5, 1.15)))
}
Model::Schnell => flux::sampling::get_schedule(4, None),
};
println!("{state:?}");
println!("{timesteps:?}");
if quantized {
let model_file = match model {
Model::Schnell => api
.repo(hf_hub::Repo::model("lmz/candle-flux".to_string()))
.get("flux1-schnell.gguf")?,
Model::Dev => todo!(),
};
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
model_file, &device,
)?;
let model = flux::quantized_model::Flux::new(&cfg, vb)?;
flux::sampling::denoise(
&model,
&state.img,
&state.img_ids,
&state.txt,
&state.txt_ids,
&state.vec,
&timesteps,
4.,
)?
.to_dtype(dtype)?
} else {
let model_file = match model {
Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?,
Model::Dev => bf_repo.get("flux1-dev.safetensors")?,
};
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)?
};
let model = flux::model::Flux::new(&cfg, vb)?;
flux::sampling::denoise(
&model,
&state.img,
&state.img_ids,
&state.txt,
&state.txt_ids,
&state.vec,
&timesteps,
4.,
)?
}
};
flux::sampling::unpack(&img, height, width)?
}
Some(file) => {
let mut st = candle::safetensors::load(file, &device)?;
st.remove("img").unwrap().to_dtype(dtype)?
}
};
println!("latent img\n{img}");
let img = {
let model_file = bf_repo.get("ae.safetensors")?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
let cfg = match model {
Model::Dev => flux::autoencoder::Config::dev(),
Model::Schnell => flux::autoencoder::Config::schnell(),
};
let model = flux::autoencoder::AutoEncoder::new(&cfg, vb)?;
model.decode(&img)?
};
println!("img\n{img}");
let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(candle::DType::U8)?;
candle_examples::save_image(&img.i(0)?, "out.jpg")?;
Ok(())
}
fn main() -> Result<()> {
let args = Args::parse();
run(args)
}

View File

@ -1,6 +0,0 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
BASE_MODEL = "google/t5-v1_1-xxl"
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
# The tokenizer will be saved in /tmp/tokenizer/tokenizer.json
tokenizer.save_pretrained("/tmp/tokenizer/")

View File

@ -1,27 +1,27 @@
# candle-gemma: 2b and 7b LLMs from Google DeepMind
[Gemma](https://ai.google.dev/gemma/docs) is a collection of lightweight open
models published by Google DeepMind with a 2b and a 7b variant for the first
version, and a 2b and a 9b variant for v2.
models published by Google DeepMind with a 2b and a 7b variant.
## Running the example
```bash
$ cargo run --example gemma --features cuda -r -- \
--prompt "Here is a proof that square root of 2 is not rational: "
Here is a proof that square root of 2 is not rational:
Let us assume it to be rational. Then, we can write √2 = p/q where q ≠ 0 and p and q are integers with no common factors other than 1. Squaring both sides gives us (p/q)^2 = 2 or p^2/q^2 = 2. This implies that p^2 is divisible by 2, which means that p must be even. Let us write p = 2m where m is an integer. Substituting this in the above equation we get:
(p^2)/q^2 = 2 or (4m^2)/q^2 = 2 or q^2/2m^2 = 1 which implies that q^2 must be divisible by 2, and hence q is even. This contradicts our assumption that p and q have no common factors other than 1. Hence we conclude that √2 cannot be rational.
```
## Access restrictions
In order to use the v1 examples, you have to accept the license on the
In order to use the example below, you have to accept the license on the
[HuggingFace Hub Gemma repo](https://huggingface.co/google/gemma-7b) and set up
your access token via the [HuggingFace cli login
command](https://huggingface.co/docs/huggingface_hub/guides/cli#huggingface-cli-login).
## Running the example
```bash
$ cargo run --example gemma --release -- --prompt "fn count_primes(max_n: usize)"
fn count_primes(max_n: usize) -> usize {
let mut primes = vec![true; max_n];
for i in 2..=max_n {
if primes[i] {
for j in i * i..max_n {
primes[j] = false;
}
}
}
primes.len()
}
```

View File

@ -7,8 +7,7 @@ extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma::{Config, Model};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
@ -39,46 +38,6 @@ enum Which {
CodeInstruct2B,
#[value(name = "code-7b-it")]
CodeInstruct7B,
#[value(name = "2-2b")]
BaseV2_2B,
#[value(name = "2-2b-it")]
InstructV2_2B,
#[value(name = "2-9b")]
BaseV2_9B,
#[value(name = "2-9b-it")]
InstructV2_9B,
}
impl Which {
fn is_v1(&self) -> bool {
match self {
Self::Base2B
| Self::Base7B
| Self::Instruct2B
| Self::Instruct7B
| Self::InstructV1_1_2B
| Self::InstructV1_1_7B
| Self::CodeBase2B
| Self::CodeBase7B
| Self::CodeInstruct2B
| Self::CodeInstruct7B => true,
Self::BaseV2_2B | Self::InstructV2_2B | Self::BaseV2_9B | Self::InstructV2_9B => false,
}
}
}
enum Model {
V1(Model1),
V2(Model2),
}
impl Model {
fn forward(&mut self, input_ids: &Tensor, pos: usize) -> candle::Result<Tensor> {
match self {
Self::V1(m) => m.forward(input_ids, pos),
Self::V2(m) => m.forward(input_ids, pos),
}
}
}
struct TextGeneration {
@ -232,11 +191,8 @@ struct Args {
repeat_last_n: usize,
/// The model to use.
#[arg(long, default_value = "2-2b")]
#[arg(long, default_value = "2b")]
which: Which,
#[arg(long)]
use_flash_attn: bool,
}
fn main() -> Result<()> {
@ -280,10 +236,6 @@ fn main() -> Result<()> {
Which::CodeBase7B => "google/codegemma-7b".to_string(),
Which::CodeInstruct2B => "google/codegemma-2b-it".to_string(),
Which::CodeInstruct7B => "google/codegemma-7b-it".to_string(),
Which::BaseV2_2B => "google/gemma-2-2b".to_string(),
Which::InstructV2_2B => "google/gemma-2-2b-it".to_string(),
Which::BaseV2_9B => "google/gemma-2-9b".to_string(),
Which::InstructV2_9B => "google/gemma-2-9b-it".to_string(),
},
};
let repo = api.repo(Repo::with_revision(
@ -308,6 +260,7 @@ fn main() -> Result<()> {
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
@ -317,15 +270,7 @@ fn main() -> Result<()> {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = if args.which.is_v1() {
let config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model1::new(args.use_flash_attn, &config, vb)?;
Model::V1(model)
} else {
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model2::new(args.use_flash_attn, &config, vb)?;
Model::V2(model)
};
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());

View File

@ -1,77 +0,0 @@
* GLM4
GLM-4-9B is the open-source version of the latest generation of pre-trained models in the GLM-4 series launched by Zhipu AI.
- [[https://github.com/THUDM/GLM4][Github]]
- [[https://huggingface.co/THUDM/glm-4-9b][huggingface]]
** Running with ~cuda~
#+begin_src shell
cargo run --example glm4 --release --features cuda
#+end_src
** Running with ~cpu~
#+begin_src shell
cargo run --example glm4 --release -- --cpu
#+end_src
** Output Example
#+begin_src shell
cargo run --example glm4 --release --features cuda -- --sample-len 500 --cache .
Finished release [optimized] target(s) in 0.24s
Running `/root/candle/target/release/examples/glm4 --sample-len 500 --cache .`
avx: true, neon: false, simd128: false, f16c: true
temp: 0.60 repeat-penalty: 1.20 repeat-last-n: 64
cache path .
retrieved the files in 6.88963ms
loaded the model in 6.113752297s
starting the inference loop
[欢迎使用GLM-4,请输入prompt]
请你告诉我什么是FFT
266 tokens generated (34.50 token/s)
Result:
。Fast Fourier Transform (FFT) 是一种快速计算离散傅里叶变换DFT的方法它广泛应用于信号处理、图像处理和数据分析等领域。
具体来说FFT是一种将时域数据转换为频域数据的算法。在数字信号处理中我们通常需要知道信号的频率成分这就需要进行傅立叶变换。传统的傅立叶变换的计算复杂度较高而 FFT 则大大提高了计算效率,使得大规模的 DFT 换成为可能。
以下是使用 Python 中的 numpy 进行 FFT 的简单示例:
```python
import numpy as np
# 创建一个时域信号
t = np.linspace(0, 1, num=100)
f = np.sin(2*np.pi*5*t) + 3*np.cos(2*np.pi*10*t)
# 对该信号做FFT变换并计算其幅值谱
fft_result = np.fft.fftshift(np.abs(np.fft.fft(f)))
```
在这个例子中,我们首先创建了一个时域信号 f。然后我们对这个信号进行了 FFT 换,得到了一个频域结果 fft_result。
#+end_src
This example reads prompts from stdin.
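The loop itself is simple: each line read from stdin becomes a prompt that is tokenized and run through autoregressive generation. An illustrative sketch (the real logic lives in ~TextGeneration::run~ in the example below):

#+begin_src rust
use std::io::{BufRead, BufReader};

// Illustrative only: one prompt per stdin line, generation details elided.
fn prompt_loop() -> anyhow::Result<()> {
    let stdin = std::io::stdin();
    for line in BufReader::new(stdin).lines() {
        let prompt = line?;
        // 1. tokenize `prompt` and feed the tokens to the model,
        // 2. sample tokens until <|endoftext|> or the sample length is reached,
        // 3. reset the KV cache before handling the next prompt.
        let _ = prompt;
    }
    Ok(())
}
#+end_src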
* Citation
#+begin_src
@misc{glm2024chatglm,
title={ChatGLM: A Family of Large Language Models from GLM-130B to GLM-4 All Tools},
author={Team GLM and Aohan Zeng and Bin Xu and Bowen Wang and Chenhui Zhang and Da Yin and Diego Rojas and Guanyu Feng and Hanlin Zhao and Hanyu Lai and Hao Yu and Hongning Wang and Jiadai Sun and Jiajie Zhang and Jiale Cheng and Jiayi Gui and Jie Tang and Jing Zhang and Juanzi Li and Lei Zhao and Lindong Wu and Lucen Zhong and Mingdao Liu and Minlie Huang and Peng Zhang and Qinkai Zheng and Rui Lu and Shuaiqi Duan and Shudan Zhang and Shulin Cao and Shuxun Yang and Weng Lam Tam and Wenyi Zhao and Xiao Liu and Xiao Xia and Xiaohan Zhang and Xiaotao Gu and Xin Lv and Xinghan Liu and Xinyi Liu and Xinyue Yang and Xixuan Song and Xunkai Zhang and Yifan An and Yifan Xu and Yilin Niu and Yuantao Yang and Yueyan Li and Yushi Bai and Yuxiao Dong and Zehan Qi and Zhaoyu Wang and Zhen Yang and Zhengxiao Du and Zhenyu Hou and Zihan Wang},
year={2024},
eprint={2406.12793},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
#+end_src
#+begin_src
@misc{wang2023cogvlm,
title={CogVLM: Visual Expert for Pretrained Language Models},
author={Weihan Wang and Qingsong Lv and Wenmeng Yu and Wenyi Hong and Ji Qi and Yan Wang and Junhui Ji and Zhuoyi Yang and Lei Zhao and Xixuan Song and Jiazheng Xu and Bin Xu and Juanzi Li and Yuxiao Dong and Ming Ding and Jie Tang},
year={2023},
eprint={2311.03079},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
#+end_src

View File

@ -1,255 +0,0 @@
use candle_transformers::models::glm4::*;
use clap::Parser;
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: Model,
device: Device,
tokenizer: Tokenizer,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
dtype: DType,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
device: &Device,
dtype: DType,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer,
logits_processor,
repeat_penalty,
repeat_last_n,
verbose_prompt,
device: device.clone(),
dtype,
}
}
fn run(&mut self, sample_len: usize) -> anyhow::Result<()> {
use std::io::BufRead;
use std::io::BufReader;
use std::io::Write;
println!("starting the inference loop");
println!("[欢迎使用GLM-4,请输入prompt]");
let stdin = std::io::stdin();
let reader = BufReader::new(stdin);
for line in reader.lines() {
let line = line.expect("Failed to read line");
let tokens = self.tokenizer.encode(line, true).expect("tokens error");
if tokens.is_empty() {
panic!("Empty prompts are not supported in the chatglm model.")
}
if self.verbose_prompt {
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
println!("{id:7} -> '{token}'");
}
}
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
Some(token) => *token,
None => panic!("cannot find the endoftext token"),
};
let mut tokens = tokens.get_ids().to_vec();
let mut generated_tokens = 0usize;
std::io::stdout().flush().expect("output flush error");
let start_gen = std::time::Instant::now();
let mut count = 0;
let mut result = vec![];
for index in 0..sample_len {
count += 1;
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input)?;
let logits = logits.squeeze(0)?.to_dtype(self.dtype)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
let token = self
.tokenizer
.decode(&[next_token], true)
.expect("Token error");
if self.verbose_prompt {
println!(
"[Count: {}] [Raw Token: {}] [Decode Token: {}]",
count, next_token, token
);
}
result.push(token);
std::io::stdout().flush()?;
}
let dt = start_gen.elapsed();
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
println!("Result:");
for tokens in result {
print!("{tokens}");
}
self.model.reset_kv_cache(); // clean the cache
}
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// The hub cache folder to use for the model files.
    #[arg(name = "cache", short, long, default_value = ".")]
    cache_path: String,
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,
/// Display the token for the specified prompt.
#[arg(long)]
verbose_prompt: bool,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 8192)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
weight_file: Option<String>,
#[arg(long)]
tokenizer: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.2)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> anyhow::Result<()> {
let args = Args::parse();
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.6),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
println!("cache path {}", args.cache_path);
let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into()))
.build()
.map_err(anyhow::Error::msg)?;
let model_id = match args.model_id {
Some(model_id) => model_id.to_string(),
None => "THUDM/glm-4-9b".to_string(),
};
let revision = match args.revision {
Some(rev) => rev.to_string(),
None => "main".to_string(),
};
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let tokenizer_filename = match args.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => api
.model("THUDM/codegeex4-all-9b".to_string())
.get("tokenizer.json")
.map_err(anyhow::Error::msg)?,
};
let filenames = match args.weight_file {
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).expect("Tokenizer Error");
let start = std::time::Instant::now();
let config = Config::glm4();
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
args.verbose_prompt,
&device,
dtype,
);
pipeline.run(args.sample_len)?;
Ok(())
}

View File

@ -1,20 +0,0 @@
# candle-granite LLMs from IBM Research
[Granite](https://www.ibm.com/granite) is a family of Large Language Models built for business, to help drive trust and scalability in AI-driven applications.
## Running the example
```bash
$ cargo run --example granite --features metal -r -- --model-type "granite7b-instruct" \
--prompt "Explain how quantum computing differs from classical computing, focusing on key concepts like qubits, superposition, and entanglement. Describe two potential breakthroughs in the fields of drug discovery and cryptography. Offer a convincing argument for why businesses and governments should invest in quantum computing research now, emphasizing its future benefits and the risks of falling behind"
Explain how quantum computing differs from classical computing, focusing on key concepts like qubits, superposition, and entanglement. Describe two potential breakthroughs in the fields of drug discovery and cryptography. Offer a convincing argument for why businesses and governments should invest in quantum computing research now, emphasizing its future benefits and the risks of falling behind competitors.
In recent years, there has been significant interest in quantum computing due to its potential to revolutionize various fields, including drug discovery, cryptography, and optimization problems. Quantum computers, which leverage the principles of quantum mechanics, differ fundamentally from classical computers. Here are some of the key differences:
```
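The generation behaviour is controlled by the `--temperature`, `--top-k` and `--top-p`
flags. A minimal sketch of how the example maps them onto a sampling strategy (the
`build_logits_processor` helper is only illustrative; the mapping mirrors the
example's `main.rs` below):

```rust
use candle_transformers::generation::{LogitsProcessor, Sampling};

// Illustrative helper mapping the CLI flags onto a Sampling strategy.
fn build_logits_processor(
    seed: u64,
    temperature: f64,
    top_k: Option<usize>,
    top_p: Option<f64>,
) -> LogitsProcessor {
    let sampling = if temperature <= 0. {
        // A non-positive temperature degenerates to greedy decoding.
        Sampling::ArgMax
    } else {
        match (top_k, top_p) {
            (None, None) => Sampling::All { temperature },
            (Some(k), None) => Sampling::TopK { k, temperature },
            (None, Some(p)) => Sampling::TopP { p, temperature },
            (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
        }
    };
    LogitsProcessor::from_sampling(seed, sampling)
}
```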
## Supported Models
There are two different modalities for the Granite family models: Language and Code.
### Granite for language
1. [Granite 7b Instruct](https://huggingface.co/ibm-granite/granite-7b-instruct)

View File

@ -1,251 +0,0 @@
// An implementation of different Granite models https://www.ibm.com/granite
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use anyhow::{bail, Error as E, Result};
use clap::{Parser, ValueEnum};
use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write;
use candle_transformers::models::granite as model;
use model::{Granite, GraniteConfig};
use std::time::Instant;
const EOS_TOKEN: &str = "</s>";
const DEFAULT_PROMPT: &str = "How Fault Tolerant Quantum Computers will help humanity?";
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum GraniteModel {
Granite7bInstruct,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// The temperature used to generate samples.
#[arg(long, default_value_t = 0.8)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// Only sample among the top K samples.
#[arg(long)]
top_k: Option<usize>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(short = 'n', long, default_value_t = 10000)]
sample_len: usize,
/// Disable the key-value cache.
#[arg(long)]
no_kv_cache: bool,
/// The initial prompt.
#[arg(long)]
prompt: Option<String>,
/// Use different dtype than f16
#[arg(long)]
dtype: Option<String>,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long, default_value = "granite7b-instruct")]
model_type: GraniteModel,
#[arg(long)]
use_flash_attn: bool,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 128)]
repeat_last_n: usize,
}
fn main() -> Result<()> {
use tokenizers::Tokenizer;
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let device = candle_examples::device(args.cpu)?;
let dtype = match args.dtype.as_deref() {
Some("f16") => DType::F16,
Some("bf16") => DType::BF16,
Some("f32") => DType::F32,
Some(dtype) => bail!("Unsupported dtype {dtype}"),
None => DType::F16,
};
let (granite, tokenizer_filename, mut cache, config) = {
let api = Api::new()?;
let model_id = args.model_id.unwrap_or_else(|| match args.model_type {
GraniteModel::Granite7bInstruct => "ibm-granite/granite-7b-instruct".to_string(),
});
println!("loading the model weights from {model_id}");
let revision = args.revision.unwrap_or("main".to_string());
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let tokenizer_filename = api.get("tokenizer.json")?;
let config_filename = api.get("config.json")?;
let config: GraniteConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let config = config.into_config(args.use_flash_attn);
let filenames = match args.model_type {
GraniteModel::Granite7bInstruct => {
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
}
};
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
(
Granite::load(vb, &config)?,
tokenizer_filename,
cache,
config,
)
};
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let eos_token_id = config.eos_token_id.or_else(|| {
tokenizer
.token_to_id(EOS_TOKEN)
.map(model::GraniteEosToks::Single)
});
let default_prompt = match args.model_type {
GraniteModel::Granite7bInstruct => DEFAULT_PROMPT,
};
let prompt = args.prompt.as_ref().map_or(default_prompt, |p| p.as_str());
let mut tokens = tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
println!("Starting the inference loop:");
print!("{prompt}");
let mut logits_processor = {
let temperature = args.temperature;
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (args.top_k, args.top_p) {
(None, None) => Sampling::All { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(args.seed, sampling)
};
let mut start_gen = std::time::Instant::now();
let mut index_pos = 0;
let mut token_generated = 0;
let use_cache_kv = cache.use_kv_cache;
(0..args.sample_len)
.inspect(|index| {
if *index == 1 {
start_gen = Instant::now();
}
})
.try_for_each(|index| -> Result<()> {
let (context_size, context_index) = if use_cache_kv && index > 0 {
(1, index_pos)
} else {
(tokens.len(), 0)
};
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = granite
.forward(&input, context_index, &mut cache)?
.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(args.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
args.repeat_penalty,
&tokens[start_at..],
)?
};
index_pos += ctxt.len();
let next_token = logits_processor.sample(&logits)?;
token_generated += 1;
tokens.push(next_token);
if let Some(model::GraniteEosToks::Single(eos_tok_id)) = eos_token_id {
if next_token == eos_tok_id {
return Err(E::msg("EOS token found"));
}
} else if let Some(model::GraniteEosToks::Multiple(ref eos_ids)) = eos_token_id {
if eos_ids.contains(&next_token) {
return Err(E::msg("EOS token found"));
}
}
if let Some(t) = tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
Ok(())
})
.unwrap_or(());
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
let dt = start_gen.elapsed();
println!(
"\n\n{} tokens generated ({} token/s)\n",
token_generated,
(token_generated - 1) as f64 / dt.as_secs_f64(),
);
Ok(())
}

View File

@ -1,19 +0,0 @@
# gte-Qwen1.5-7B-instruct
gte-Qwen1.5-7B-instruct is a text embedding model from the GTE (General Text Embeddings) family.
- [Model card](https://huggingface.co/Alibaba-NLP/gte-Qwen1.5-7B-instruct) on the HuggingFace Hub.
- [Technical report](https://arxiv.org/abs/2308.03281) *Towards General Text Embeddings with Multi-stage Contrastive Learning*
## Running the example
Automatically download the model from the HuggingFace hub:
```bash
$ cargo run --example gte-qwen --release
```
or, load the model from a local directory:
```bash
cargo run --example gte-qwen --release --features cuda -- --local-repo /path/to/gte_Qwen1.5-7B-instruct/
```
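Since the inputs are left-padded, the example takes the last hidden state of each
sequence as its embedding, and then scores query/document pairs with a dot product
between L2-normalized embeddings. A small sketch of the scoring step (the
`relevance_scores` helper is only illustrative; the tensor ops mirror the example's
`main.rs` below):

```rust
use candle::{Result, Tensor};

// Illustrative helper: `embd` holds one embedding per row, the two queries first
// and then the two documents, as in the example.
fn relevance_scores(embd: &Tensor) -> Result<Tensor> {
    // L2-normalize each row so the dot product becomes a cosine similarity.
    let norm = embd.broadcast_div(&embd.sqr()?.sum_keepdim(1)?.sqrt()?)?;
    norm.narrow(0, 0, 2)?.matmul(&norm.narrow(0, 2, 2)?.t()?)
}
```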

View File

@ -1,178 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::qwen2::{Config, Model};
use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::{
utils::padding::{PaddingDirection, PaddingParams, PaddingStrategy},
Tokenizer,
};
// gte-Qwen1.5-7B-instruct uses the EOS token as the padding token
const EOS_TOKEN: &str = "<|endoftext|>";
const EOS_TOKEN_ID: u32 = 151643;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long, default_value = "Alibaba-NLP/gte-Qwen1.5-7B-instruct")]
model_id: String,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
local_repo: Option<String>,
}
#[derive(Debug)]
struct ConfigFiles {
pub config: std::path::PathBuf,
pub tokenizer: std::path::PathBuf,
pub weights: Vec<std::path::PathBuf>,
}
// Loading the model from the HuggingFace Hub. Network access is required.
fn load_from_hub(model_id: &str, revision: &str) -> Result<ConfigFiles> {
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
model_id.to_string(),
RepoType::Model,
revision.to_string(),
));
Ok(ConfigFiles {
config: repo.get("config.json")?,
tokenizer: repo.get("tokenizer.json")?,
weights: candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
})
}
// Loading the model from a local directory.
fn load_from_local(local_path: &str) -> Result<ConfigFiles> {
let local_path = std::path::PathBuf::from(local_path);
let weight_path = local_path.join("model.safetensors.index.json");
let json: serde_json::Value = serde_json::from_str(&std::fs::read_to_string(weight_path)?)?;
let weight_map = match json.get("weight_map") {
Some(serde_json::Value::Object(map)) => map,
Some(_) => panic!("`weight map` is not a map"),
None => panic!("`weight map` not found"),
};
let mut safetensors_files = std::collections::HashSet::new();
for value in weight_map.values() {
safetensors_files.insert(
value
.as_str()
.expect("Weight files should be parsed as strings"),
);
}
let safetensors_paths = safetensors_files
.iter()
.map(|v| local_path.join(v))
.collect::<Vec<_>>();
Ok(ConfigFiles {
config: local_path.join("config.json"),
tokenizer: local_path.join("tokenizer.json"),
weights: safetensors_paths,
})
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
    // Fetch the model files, or use local copies if a local path was provided.
println!("Fetching model files...");
let start = std::time::Instant::now();
let config_files = match args.local_repo {
Some(local_path) => load_from_local(&local_path)?,
None => load_from_hub(&args.model_id, &args.revision)?,
};
println!("Model file retrieved in {:?}", start.elapsed());
// Inputs will be padded to the longest sequence in the batch.
let padding = PaddingParams {
strategy: PaddingStrategy::BatchLongest,
direction: PaddingDirection::Left,
pad_to_multiple_of: None,
pad_id: EOS_TOKEN_ID,
pad_type_id: 0,
pad_token: String::from(EOS_TOKEN),
};
// Tokenizer setup
let mut tokenizer = Tokenizer::from_file(config_files.tokenizer).map_err(E::msg)?;
tokenizer.with_padding(Some(padding));
// Model initialization
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let config: Config = serde_json::from_slice(&std::fs::read(config_files.config)?)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&config_files.weights, dtype, &device)? };
let mut model = Model::new(&config, vb)?;
println!("Model loaded in {:?}", start.elapsed());
// Encode the queries and the targets
let instruct = "Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery: ";
let documents = vec![
format!("{instruct}how much protein should a female eat{EOS_TOKEN}"),
format!("{instruct}summit define{EOS_TOKEN}"),
format!("As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.{EOS_TOKEN}"),
format!("Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.{EOS_TOKEN}"),
];
let encoded = tokenizer.encode_batch(documents, true).map_err(E::msg)?;
let tokens: Vec<&[u32]> = encoded.iter().map(|x| x.get_ids()).collect();
let tokens = Tensor::new(tokens, &device)?;
let mask: Vec<&[u32]> = encoded.iter().map(|x| x.get_attention_mask()).collect();
let mask = Tensor::new(mask, &device)?;
// Inference
let start_gen = std::time::Instant::now();
let logits = model.forward(&tokens, 0, Some(&mask))?;
// Extract the last hidden states as embeddings since inputs are padded left.
let (_, seq_len, _) = logits.dims3()?;
let embd = logits
.narrow(1, seq_len - 1, 1)?
.squeeze(1)?
.to_dtype(DType::F32)?;
    // Calculate the relevance scores. Note that the embeddings should be normalized first.
let norm = embd.broadcast_div(&embd.sqr()?.sum_keepdim(1)?.sqrt()?)?;
let scores = norm.narrow(0, 0, 2)?.matmul(&norm.narrow(0, 2, 2)?.t()?)?;
// Print the results
println!("Embedding done in {:?}", start_gen.elapsed());
println!("Scores: {:?}", scores.to_vec2::<f32>()?);
Ok(())
}

View File

@ -1,18 +0,0 @@
# hiera
[Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles](https://arxiv.org/abs/2306.00989)
This candle implementation uses pre-trained Hiera models from timm for inference.
The classification head has been trained on the ImageNet dataset; the example prints the probabilities of the top-5 classes.
## Running an example
```
$ cargo run --example hiera --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which tiny
loaded image Tensor[dims 3, 224, 224; f32]
model built
mountain bike, all-terrain bike, off-roader: 71.15%
unicycle, monocycle : 7.11%
knee pad : 4.26%
crash helmet : 1.48%
moped : 1.07%
```

View File

@ -1,99 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::hiera;
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
Tiny,
Small,
Base,
BasePlus,
Large,
Huge,
}
impl Which {
fn model_filename(&self) -> String {
let name = match self {
Self::Tiny => "tiny",
Self::Small => "small",
Self::Base => "base",
Self::BasePlus => "base_plus",
Self::Large => "large",
Self::Huge => "huge",
};
format!("timm/hiera_{}_224.mae_in1k_ft_in1k", name)
}
fn config(&self) -> hiera::Config {
match self {
Self::Tiny => hiera::Config::tiny(),
Self::Small => hiera::Config::small(),
Self::Base => hiera::Config::base(),
Self::BasePlus => hiera::Config::base_plus(),
Self::Large => hiera::Config::large(),
Self::Huge => hiera::Config::huge(),
}
}
}
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
#[arg(value_enum, long, default_value_t=Which::Tiny)]
which: Which,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
println!("loaded image {image:?}");
let model_file = match args.model {
None => {
let model_name = args.which.model_filename();
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(model_name);
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = hiera::hiera(&args.which.config(), 1000, vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for &(category_idx, pr) in prs.iter().take(5) {
println!(
"{:24}: {:.2}%",
candle_examples::imagenet::CLASSES[category_idx],
100. * pr
);
}
Ok(())
}

View File

@ -4,7 +4,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle_transformers::models::jina_bert::{BertModel, Config, PositionEmbeddingType};
use candle_transformers::models::jina_bert::{BertModel, Config};
use anyhow::Error as E;
use candle::{DType, Module, Tensor};
@ -39,47 +39,32 @@ struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
model_file: Option<String>,
}
impl Args {
fn build_model_and_tokenizer(&self) -> anyhow::Result<(BertModel, tokenizers::Tokenizer)> {
use hf_hub::{api::sync::Api, Repo, RepoType};
let model_name = match self.model.as_ref() {
Some(model) => model.to_string(),
None => "jinaai/jina-embeddings-v2-base-en".to_string(),
};
let model = match &self.model_file {
let model = match &self.model {
Some(model_file) => std::path::PathBuf::from(model_file),
None => Api::new()?
.repo(Repo::new(model_name.to_string(), RepoType::Model))
.repo(Repo::new(
"jinaai/jina-embeddings-v2-base-en".to_string(),
RepoType::Model,
))
.get("model.safetensors")?,
};
let tokenizer = match &self.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => Api::new()?
.repo(Repo::new(model_name.to_string(), RepoType::Model))
.repo(Repo::new(
"sentence-transformers/all-MiniLM-L6-v2".to_string(),
RepoType::Model,
))
.get("tokenizer.json")?,
};
let device = candle_examples::device(self.cpu)?;
let config = Config::v2_base();
let tokenizer = tokenizers::Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let config = Config::new(
tokenizer.get_vocab_size(true),
768,
12,
12,
3072,
candle_nn::Activation::Gelu,
8192,
2,
0.02,
1e-12,
0,
PositionEmbeddingType::Alibi,
);
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
let model = BertModel::new(vb, &config)?;
Ok((model, tokenizer))
@ -116,20 +101,14 @@ fn main() -> anyhow::Result<()> {
.to_vec();
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
println!("Loaded and encoded {:?}", start.elapsed());
let start = std::time::Instant::now();
let embeddings = model.forward(&token_ids)?;
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
println!("pooled_embeddigns: {embeddings}");
let embeddings = if args.normalize_embeddings {
normalize_l2(&embeddings)?
} else {
embeddings
};
if args.normalize_embeddings {
println!("normalized_embeddings: {embeddings}");
for idx in 0..args.n {
let start = std::time::Instant::now();
let ys = model.forward(&token_ids)?;
if idx == 0 {
println!("{ys}");
}
println!("Took {:?}", start.elapsed());
}
println!("Took {:?}", start.elapsed());
} else {
let sentences = [
"The cat sits outside",

View File

@ -32,9 +32,7 @@ enum Which {
V1,
V2,
V3,
V31,
V3Instruct,
V31Instruct,
#[value(name = "solar-10.7b")]
Solar10_7B,
#[value(name = "tiny-llama-1.1b-chat")]
@ -135,8 +133,6 @@ fn main() -> Result<()> {
Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
Which::V3 => "meta-llama/Meta-Llama-3-8B".to_string(),
Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(),
Which::V31 => "meta-llama/Meta-Llama-3.1-8B".to_string(),
Which::V31Instruct => "meta-llama/Meta-Llama-3.1-8B-Instruct".to_string(),
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
});
@ -150,13 +146,7 @@ fn main() -> Result<()> {
let config = config.into_config(args.use_flash_attn);
let filenames = match args.which {
Which::V1
| Which::V2
| Which::V3
| Which::V3Instruct
| Which::V31
| Which::V31Instruct
| Which::Solar10_7B => {
Which::V1 | Which::V2 | Which::V3 | Which::V3Instruct | Which::Solar10_7B => {
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
}
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
@ -167,11 +157,9 @@ fn main() -> Result<()> {
(Llama::load(vb, &config)?, tokenizer_filename, cache, config)
};
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let eos_token_id = config.eos_token_id.or_else(|| {
tokenizer
.token_to_id(EOS_TOKEN)
.map(model::LlamaEosToks::Single)
});
let eos_token_id = config
.eos_token_id
.or_else(|| tokenizer.token_to_id(EOS_TOKEN));
let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
let mut tokens = tokenizer
.encode(prompt, true)
@ -229,14 +217,8 @@ fn main() -> Result<()> {
token_generated += 1;
tokens.push(next_token);
match eos_token_id {
Some(model::LlamaEosToks::Single(eos_tok_id)) if next_token == eos_tok_id => {
break;
}
Some(model::LlamaEosToks::Multiple(ref eos_ids)) if eos_ids.contains(&next_token) => {
break;
}
_ => (),
if Some(next_token) == eos_token_id {
break;
}
if let Some(t) = tokenizer.next_token(next_token)? {
print!("{t}");

View File

@ -14,7 +14,6 @@ use clap::{Parser, ValueEnum};
use candle::{DType, Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::llama::LlamaEosToks;
use cudarc::driver::safe::CudaDevice;
use cudarc::nccl::safe::{Comm, Id};
use hf_hub::{api::sync::Api, Repo, RepoType};
@ -220,16 +219,9 @@ fn main() -> Result<()> {
let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token);
new_tokens.push(next_token);
match config.eos_token_id {
Some(LlamaEosToks::Single(eos_tok_id)) if next_token == eos_tok_id => {
break;
}
Some(LlamaEosToks::Multiple(ref eos_ids)) if eos_ids.contains(&next_token) => {
break;
}
_ => (),
if Some(next_token) == config.eos_token_id {
break;
}
if rank == 0 {
if let Some(t) = tokenizer.next_token(next_token)? {
print!("{t}");

View File

@ -1,4 +0,0 @@
pub const DEFAULT_IMAGE_TOKEN: &str = "<image>";
pub const DEFAULT_IM_START_TOKEN: &str = "<im_start>";
pub const DEFAULT_IM_END_TOKEN: &str = "<im_end>";
pub const IMAGE_PLACEHOLDER: &str = "<image-placeholder>";

View File

@ -1,114 +0,0 @@
pub enum SeparatorStyle {
Two,
Mpt,
}
pub struct Conversation {
pub system: String,
pub roles: Vec<String>,
pub messages: Vec<(String, Option<String>)>,
pub offset: i32,
pub sep_style: SeparatorStyle,
pub sep: String,
pub sep2: Option<String>,
pub version: String,
}
impl Conversation {
pub fn new(
system: &str,
roles: &[String],
offset: i32,
sep_style: SeparatorStyle,
sep: &str,
sep2: Option<&str>,
version: &str,
) -> Self {
Conversation {
system: system.to_string(),
roles: roles.to_vec(),
messages: Vec::new(),
offset,
sep_style,
sep: sep.to_string(),
sep2: sep2.map(|s| s.to_string()),
version: version.to_string(),
}
}
pub fn conv_chatml_direct() -> Self {
Conversation::new(
"<|im_start|>system\nAnswer the questions.",
&[
"<|im_start|>user\n".to_string(),
"<|im_start|>assistant\n".to_string(),
],
0,
SeparatorStyle::Mpt,
"<|im_end|>",
None,
"mpt",
)
}
pub fn conv_llava_v1() -> Self {
Conversation::new(
"A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.",
&[
"USER".to_string(),
"ASSISTANT".to_string(),
],
0,
SeparatorStyle::Two,
" ",
Some("</s>"),
"v1"
)
}
pub fn append_message(&mut self, role: String, message: Option<&str>) {
self.messages.push((role, message.map(|s| s.to_string())))
}
pub fn append_user_message(&mut self, message: Option<&str>) {
self.append_message(self.roles[0].clone(), message);
}
pub fn append_assistant_message(&mut self, message: Option<&str>) {
self.append_message(self.roles[1].clone(), message);
}
pub fn get_prompt(&self) -> String {
match self.sep_style {
SeparatorStyle::Mpt => {
let mut ret = String::new();
ret.push_str(&self.system);
ret.push_str(&self.sep);
for (role, message) in &self.messages {
ret.push_str(role);
if let Some(message) = message {
ret.push_str(message);
};
ret.push_str(&self.sep);
}
ret
}
SeparatorStyle::Two => {
let seps = [self.sep.clone(), self.sep2.clone().unwrap()];
let mut ret = String::new();
ret.push_str(&self.system);
ret.push_str(&seps[0]);
for (i, (role, message)) in self.messages.iter().enumerate() {
ret.push_str(role);
if let Some(message) = message {
ret.push_str(": "); // strictly follow the python implementation, otherwise it will cause some minor difference between tokens ^_^
ret.push_str(message);
ret.push_str(&seps[i % 2]);
} else {
ret.push(':')
}
}
ret
}
}
}
}

View File

@ -1,317 +0,0 @@
use std::cmp::min;
use candle::{bail, DType, Device, Result, Tensor};
use candle_transformers::models::llava::{
config::{HFPreProcessorConfig, LLaVAConfig},
utils::select_best_resolution,
};
use hf_hub::api::sync::Api;
use image::{imageops::overlay, DynamicImage, GenericImageView, Rgb, RgbImage};
use serde::{Deserialize, Serialize};
// This struct is mainly intended for LLaVA applications, so it is not fully compatible with the Python transformers CLIPImageProcessor. It only covers the preprocessing pipelines that LLaVA uses, including "openai/clip-vit-large-patch14-336" and "openai/clip-vit-large-patch14".
#[derive(Serialize, Deserialize, Debug)]
pub struct ImageProcessor {
#[serde(default = "default_size")]
    pub size: u32, // a single edge length, unlike the size dict used by the Python transformers config
#[serde(default = "default_do_resize")]
pub do_resize: bool,
    //resample: u32 // 3 means PIL bicubic, which is equivalent to Rust's CatmullRom filter, hence CatmullRom is used below
#[serde(default = "default_do_center_crop")]
pub do_center_crop: bool,
#[serde(default = "default_crop_size")]
    pub crop_size: u32, // a single edge length, unlike the crop_size dict used by the Python transformers config
#[serde(default = "default_do_rescale")]
pub do_rescale: bool,
#[serde(default = "default_rescale_factor")]
pub rescale_factor: f32,
#[serde(default = "default_do_normalize")]
pub do_normalize: bool,
#[serde(default = "default_image_mean")]
pub image_mean: Vec<f32>,
#[serde(default = "default_image_std")]
pub image_std: Vec<f32>,
}
fn default_size() -> u32 {
224
}
fn default_do_resize() -> bool {
true
}
fn default_do_center_crop() -> bool {
true
}
fn default_crop_size() -> u32 {
224
}
fn default_do_rescale() -> bool {
true
}
fn default_rescale_factor() -> f32 {
1.0 / 255.0
}
fn default_do_normalize() -> bool {
true
}
fn default_image_mean() -> Vec<f32> {
vec![0.48145466, 0.4578275, 0.40821073]
}
fn default_image_std() -> Vec<f32> {
vec![0.26862954, 0.2613026, 0.2757771]
}
impl ImageProcessor {
pub fn from_pretrained(clip_id: &str) -> Result<Self> {
let api = Api::new().map_err(|e| candle::Error::Msg(e.to_string()))?;
let api = api.model(clip_id.to_string());
let config_filename = api
.get("preprocessor_config.json")
.map_err(|e| candle::Error::Msg(e.to_string()))?;
let image_processor =
serde_json::from_slice(&std::fs::read(config_filename).map_err(candle::Error::Io)?)
.map_err(|e| candle::Error::Msg(e.to_string()))?;
Ok(image_processor)
}
pub fn from_hf_preprocessor_config(hf_preprocessor_config: &HFPreProcessorConfig) -> Self {
Self {
size: hf_preprocessor_config.size["shortest_edge"] as u32,
do_resize: hf_preprocessor_config.do_resize,
do_center_crop: hf_preprocessor_config.do_center_crop,
crop_size: hf_preprocessor_config.crop_size["height"] as u32,
do_rescale: hf_preprocessor_config.do_rescale,
rescale_factor: hf_preprocessor_config.rescale_factor,
do_normalize: hf_preprocessor_config.do_normalize,
image_mean: hf_preprocessor_config.image_mean.clone(),
image_std: hf_preprocessor_config.image_std.clone(),
}
}
    /// Resize so that the shortest edge matches `self.size`; the other edge is scaled to preserve the aspect ratio.
pub fn resize(&self, image: &DynamicImage) -> DynamicImage {
let (width, height) = image.dimensions();
let size = self.size;
if width == size && height == size {
image.clone()
} else {
let (new_width, new_height) = if width < height {
(
size,
(((size * height) as f32) / width as f32).ceil() as u32,
)
} else {
(
(((size * width) as f32) / height as f32).ceil() as u32,
size,
)
};
image.resize(
new_width,
new_height,
image::imageops::FilterType::CatmullRom,
)
}
}
pub fn center_crop(&self, image: &DynamicImage) -> DynamicImage {
let (width, height) = image.dimensions();
let crop_size = self.crop_size;
let (left, top) = calculate_middle((width, height), (crop_size, crop_size));
image.crop_imm(left, top, crop_size, crop_size)
}
pub fn to_tensor(&self, image: &DynamicImage) -> Result<Tensor> {
let img = image.to_rgb8().into_raw();
let (width, height) = image.dimensions();
Tensor::from_vec(img, (height as usize, width as usize, 3), &Device::Cpu)?
.to_dtype(DType::F32) // only for internal compute
}
pub fn rescale(&self, tensor: &Tensor) -> Result<Tensor> {
let rescale_factor = self.rescale_factor as f64;
tensor.affine(rescale_factor, 0.0)
}
pub fn normalize(&self, tensor: &Tensor) -> Result<Tensor> {
let image_mean = self.image_mean.clone();
let image_std = self.image_std.clone();
let mean = Tensor::from_vec(image_mean, (3,), &Device::Cpu)?;
let std = Tensor::from_vec(image_std, (3,), &Device::Cpu)?;
tensor.broadcast_sub(&mean)?.broadcast_div(&std)
}
pub fn to_channel_dimension_format(&self, tensor: &Tensor) -> Result<Tensor> {
tensor.permute((2, 0, 1))
}
pub fn preprocess(&self, image: &DynamicImage) -> Result<Tensor> {
let image = if self.do_resize {
self.resize(image)
} else {
image.clone()
};
let image = if self.do_center_crop {
self.center_crop(&image)
} else {
image
};
let tensor = self.to_tensor(&image)?;
let tensor = if self.do_rescale {
self.rescale(&tensor)?
} else {
tensor
};
let tensor = if self.do_normalize {
self.normalize(&tensor)?
} else {
tensor
};
self.to_channel_dimension_format(&tensor)
}
}
pub fn calculate_middle(image_size: (u32, u32), center_size: (u32, u32)) -> (u32, u32) {
let (width, height) = image_size;
let (center_width, center_height) = center_size;
let left = if width <= center_width {
0
} else {
((width as f32 - center_width as f32) / 2.0).ceil() as u32
};
let top = if height <= center_height {
0
} else {
((height as f32 - center_height as f32) / 2.0).ceil() as u32
};
(left, top)
}
pub fn process_image(
image: &DynamicImage,
processor: &ImageProcessor,
llava_config: &LLaVAConfig,
) -> candle::Result<Tensor> {
if llava_config.image_aspect_ratio == *"square" {
processor.preprocess(image)?.unsqueeze(0)
} else if llava_config.image_aspect_ratio == *"anyres" {
process_anyres_image(image, processor, &llava_config.image_grid_pinpoints)
} else if llava_config.image_aspect_ratio == *"pad" {
process_pad_image(image, processor)
} else {
bail!("Invalid image aspect ratio")
}
}
fn process_pad_image(image: &DynamicImage, processor: &ImageProcessor) -> Result<Tensor> {
let mean_color = processor
.image_mean
.iter()
.map(|x| ((*x) * 255.0) as u8)
.collect::<Vec<u8>>();
let mean_color = Rgb::from([mean_color[0], mean_color[1], mean_color[2]]);
let image_padded = expand2square(image, mean_color);
processor.preprocess(&image_padded)
}
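// Implements the "anyres" strategy: pick the best grid resolution from the pinpoints,
// resize and pad the image to it, split it into crop-sized patches and stack them
// together with a resized copy of the original image.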
fn process_anyres_image(
image: &DynamicImage,
processor: &ImageProcessor,
grid_pinpoints: &[(u32, u32)],
) -> Result<Tensor> {
let original_size = image.dimensions();
let best_resolution = select_best_resolution(original_size, grid_pinpoints);
let image_padded = resize_and_pad_image(image, best_resolution);
let image_original_resize = image.resize_exact(
processor.size,
processor.size,
image::imageops::FilterType::CatmullRom,
);
let mut patches = vec![image_original_resize];
for patch in divide_to_patches(&image_padded, processor.crop_size) {
patches.push(patch);
}
let tensors = patches
.iter()
.map(|patch| processor.preprocess(patch))
.collect::<Result<Vec<Tensor>>>()?;
Tensor::stack(&tensors, 0)
}
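// Pads the shorter side with the given background color so that the image becomes a
// square with the original content centered.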
fn expand2square(image: &DynamicImage, background_color: Rgb<u8>) -> DynamicImage {
let (width, height) = image.dimensions();
match width.cmp(&height) {
std::cmp::Ordering::Less => {
let mut new_image =
DynamicImage::from(RgbImage::from_pixel(height, height, background_color));
overlay(&mut new_image, image, ((height - width) / 2) as i64, 0);
new_image
}
std::cmp::Ordering::Equal => image.clone(),
std::cmp::Ordering::Greater => {
let mut new_image =
DynamicImage::from(RgbImage::from_pixel(width, width, background_color));
overlay(&mut new_image, image, 0, ((width - height) / 2) as i64);
new_image
}
}
}
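// Resizes the image to fit within the target resolution while preserving its aspect
// ratio, then centers it on a canvas of the target size.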
fn resize_and_pad_image(image: &DynamicImage, target_resolution: (u32, u32)) -> DynamicImage {
let (original_width, original_height) = image.dimensions();
let original_width_f = original_width as f32;
let original_height_f = original_height as f32;
let (target_width, target_height) = target_resolution;
let target_width_f = target_width as f32;
let target_height_f = target_height as f32;
let scale_w = target_width_f / original_width_f;
let scale_h = target_height_f / original_height_f;
let (new_width, new_height) = if scale_w < scale_h {
(
target_width,
min((original_height_f * scale_w).ceil() as u32, target_height),
)
} else {
(
min((original_width_f * scale_h).ceil() as u32, target_width),
target_height,
)
};
let resized_image = image.resize_exact(
new_width,
new_height,
image::imageops::FilterType::CatmullRom,
);
let mut new_image = DynamicImage::new_rgb8(target_width, target_height);
let (paste_x, paste_y) =
calculate_middle((target_width, target_height), (new_width, new_height));
overlay(
&mut new_image,
&resized_image,
paste_x.into(),
paste_y.into(),
);
new_image
}
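// Splits the image into `patch_size` x `patch_size` crops, scanning left to right and
// top to bottom.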
fn divide_to_patches(image: &DynamicImage, patch_size: u32) -> Vec<DynamicImage> {
let (width, height) = image.dimensions();
let mut patches = Vec::new();
for y in (0..height).step_by(patch_size as usize) {
for x in (0..width).step_by(patch_size as usize) {
let patch = image.crop_imm(x, y, patch_size, patch_size);
patches.push(patch);
}
}
patches
}

View File

@ -1,316 +0,0 @@
pub mod constants;
pub mod conversation;
pub mod image_processor;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_transformers::models::llama::Cache;
use anyhow::{bail, Error as E, Result};
use candle::{DType, Device, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::llava::config::{
HFGenerationConfig, HFLLaVAConfig, HFPreProcessorConfig,
};
use candle_transformers::models::llava::{config::LLaVAConfig, LLaVA};
use clap::Parser;
use constants::*;
use conversation::Conversation;
use hf_hub::api::sync::Api;
use image_processor::{process_image, ImageProcessor};
use std::io::Write;
use tokenizers::Tokenizer;
#[derive(Parser, Debug)]
#[command(author, version, about,long_about=None)]
struct Args {
#[arg(long, default_value = "llava-hf/llava-v1.6-vicuna-7b-hf")]
model_path: String,
#[arg(long, default_value = "tokenizer/tokenizer.json")]
tokenizer_path: String,
#[arg(long)]
model_base: Option<String>,
#[arg(long)]
image_file: String, // Required
#[arg(long)]
conv_mode: Option<String>,
#[arg(long, default_value_t = 0.2)]
temperature: f32,
#[arg(long, default_value_t = 512)]
max_new_tokens: usize,
#[arg(long, action)]
hf: bool,
#[arg(long, action)]
cpu: bool,
#[arg(long, action)]
no_kv_cache: bool,
#[arg(long)]
prompt: String,
/// The seed to use when generating random samples. Copied from the candle llama example; this option does not exist in the Python llava implementation.
#[arg(long, default_value_t = 299792458)]
seed: u64,
}
//from https://github.com/huggingface/candle/blob/main/candle-examples/examples/clip/main.rs
fn load_image<T: AsRef<std::path::Path>>(
path: T,
processor: &ImageProcessor,
llava_config: &LLaVAConfig,
dtype: DType,
) -> Result<((u32, u32), Tensor)> {
let img = image::ImageReader::open(path)?.decode()?;
let img_tensor = process_image(&img, processor, llava_config)?;
Ok(((img.width(), img.height()), img_tensor.to_dtype(dtype)?))
}
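// Derives a model name from the model path; for checkpoint directories the parent
// directory name is prepended.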
fn get_model_name_from_path(model_path: &str) -> String {
let model_paths: Vec<String> = model_path
.trim_matches('/')
.split('/')
.map(|s| s.to_string())
.collect();
if model_paths.last().unwrap().starts_with("checkpoint-") {
format!(
"{}_{}",
model_paths[model_paths.len() - 2],
model_paths.last().unwrap()
)
} else {
model_paths.last().unwrap().to_string()
}
}
fn duplicate_vec<T>(vec: &[T], n: usize) -> Vec<T>
where
T: Clone,
{
let mut res = Vec::new();
for _ in 0..n {
res.extend(vec.to_owned());
}
res
}
fn insert_separator<T>(x: Vec<Vec<T>>, sep: Vec<T>) -> Vec<Vec<T>>
where
T: Clone,
{
let sep = vec![sep];
let sep = duplicate_vec(&sep, x.len());
let mut res = x
.iter()
.zip(sep.iter())
.flat_map(|(x, y)| vec![x.clone(), y.clone()])
.collect::<Vec<Vec<T>>>();
res.pop();
res
}
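// Tokenizes a prompt containing "<image>" placeholders: each text chunk is encoded
// separately and the image token index is spliced in between the encoded chunks.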
fn tokenizer_image_token(
prompt: &str,
tokenizer: &Tokenizer,
image_token_index: i64,
llava_config: &LLaVAConfig,
) -> Result<Tensor> {
let prompt_chunks = prompt
.split("<image>")
.map(|s| {
tokenizer
.encode(s, true)
.unwrap()
.get_ids()
.to_vec()
.iter()
.map(|x| *x as i64)
.collect()
})
.collect::<Vec<Vec<i64>>>();
let mut input_ids = Vec::new();
let mut offset = 0;
if !prompt_chunks.is_empty()
&& !prompt_chunks[0].is_empty()
&& prompt_chunks[0][0] == llava_config.bos_token_id as i64
{
offset = 1;
input_ids.push(prompt_chunks[0][0]);
}
for x in insert_separator(
prompt_chunks,
duplicate_vec(&[image_token_index], offset + 1),
)
.iter()
{
input_ids.extend(x[1..].to_vec())
}
let input_len = input_ids.len();
Tensor::from_vec(input_ids, (1, input_len), &Device::Cpu).map_err(E::msg)
}
fn main() -> Result<()> {
let mut args = Args::parse();
let device = candle_examples::device(args.cpu)?;
println!("Start loading model");
let api = Api::new()?;
let api = api.model(args.model_path.clone());
let (llava_config, tokenizer, clip_vision_config, image_processor) = if args.hf {
let config_filename = api.get("config.json")?;
let hf_llava_config: HFLLaVAConfig =
serde_json::from_slice(&std::fs::read(config_filename)?)?;
let generation_config_filename = api.get("generation_config.json")?;
let generation_config: HFGenerationConfig =
serde_json::from_slice(&std::fs::read(generation_config_filename)?)?;
let preprocessor_config_filename = api.get("preprocessor_config.json")?;
let preprocessor_config: HFPreProcessorConfig =
serde_json::from_slice(&std::fs::read(preprocessor_config_filename)?)?;
let llava_config =
hf_llava_config.to_llava_config(&generation_config, &preprocessor_config);
let tokenizer_filename = api.get("tokenizer.json")?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let clip_vision_config = hf_llava_config.to_clip_vision_config();
(
llava_config,
tokenizer,
Some(clip_vision_config),
ImageProcessor::from_hf_preprocessor_config(&preprocessor_config),
)
} else {
let config_filename = api.get("config.json")?;
let llava_config: LLaVAConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let tokenizer = Tokenizer::from_file(&args.tokenizer_path)
.map_err(|e| E::msg(format!("Error loading {}: {}", &args.tokenizer_path, e)))?;
(
llava_config.clone(),
tokenizer,
None,
ImageProcessor::from_pretrained(&llava_config.mm_vision_tower.unwrap())?,
)
};
let llama_config = llava_config.to_llama_config();
let dtype: DType = match llava_config.torch_dtype.as_str() {
"float16" => DType::F16,
"bfloat16" => DType::BF16,
_ => bail!("unsupported dtype"),
};
let eos_token_id = llava_config.eos_token_id;
println!("setting kv cache");
let mut cache = Cache::new(!args.no_kv_cache, dtype, &llama_config, &device)?;
println!("loading model weights");
let weight_filenames =
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&weight_filenames, dtype, &device)? };
let llava: LLaVA = LLaVA::load(vb, &llava_config, clip_vision_config)?;
println!("generating conv template");
let image_token_se = format!(
"{}{}{}",
DEFAULT_IM_START_TOKEN, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_END_TOKEN
);
let qs = if args.prompt.contains(IMAGE_PLACEHOLDER) {
if llava_config.mm_use_im_start_end {
args.prompt.replace(IMAGE_PLACEHOLDER, &image_token_se)
} else {
args.prompt.replace(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN)
}
} else if llava_config.mm_use_im_start_end {
format!("{}\n{}", image_token_se, args.prompt)
} else {
format!("{}\n{}", DEFAULT_IMAGE_TOKEN, args.prompt)
};
let model_name = get_model_name_from_path(&args.model_path).to_lowercase();
let conv_mode = if model_name.contains("llama-2") {
"llava_llama_2"
} else if model_name.contains("mistral") {
"mistral_instruct"
} else if model_name.contains("v1.6-34b") {
"chatml_direct"
} else if model_name.contains("v1") {
"llava_v1"
} else if model_name.contains("mpt") {
"mpt"
} else {
"llava_v0"
};
if args.conv_mode.is_some() && args.conv_mode.as_deref() != Some(conv_mode) {
println!(
"Warning: the model is trained with {}, but you are using {}",
conv_mode,
args.conv_mode.as_deref().unwrap()
);
} else {
args.conv_mode = Some(conv_mode.to_string());
}
let mut conv = match args.conv_mode {
Some(conv_mode) => match conv_mode.as_str() {
"chatml_direct" => Conversation::conv_chatml_direct(),
"llava_v1" => Conversation::conv_llava_v1(),
_ => todo!("not implemented yet"),
},
None => bail!("conv_mode is required"),
};
conv.append_user_message(Some(&qs));
conv.append_assistant_message(None);
let prompt = conv.get_prompt();
println!("loading image");
let (image_size, image_tensor) =
load_image(&args.image_file, &image_processor, &llava_config, dtype)
.map_err(|e| E::msg(format!("Error loading {}: {}", &args.image_file, e)))?;
let image_tensor = image_tensor.to_device(&device)?;
let mut logits_processor = {
let temperature = f64::from(args.temperature);
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
Sampling::All { temperature }
};
LogitsProcessor::from_sampling(args.seed, sampling)
};
// get input tokens
let tokens = tokenizer_image_token(
&prompt,
&tokenizer,
llava_config.image_token_index as i64,
&llava_config,
)?;
let mut input_embeds =
llava.prepare_inputs_labels_for_multimodal(&tokens, &[image_tensor], &[image_size])?;
//inference loop, based on https://github.com/huggingface/candle/blob/main/candle-examples/examples/llama/main.rs
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
let mut index_pos = 0;
for index in 0..args.max_new_tokens {
let (_, input_embeds_len, _) = input_embeds.dims3()?;
let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
(1, index_pos)
} else {
(input_embeds_len, 0)
};
let input = input_embeds.i((.., input_embeds_len.saturating_sub(context_size).., ..))?;
let logits = llava.forward(&input, context_index, &mut cache)?; //[1,32000]
let logits = logits.squeeze(0)?;
let (_, input_len, _) = input.dims3()?;
index_pos += input_len;
let next_token = logits_processor.sample(&logits)?;
let next_token_tensor = Tensor::from_vec(vec![next_token], 1, &device)?;
let next_embeds = llava.llama.embed(&next_token_tensor)?.unsqueeze(0)?;
input_embeds = Tensor::cat(&[input_embeds, next_embeds], 1)?;
if next_token == eos_token_id as u32 {
break;
}
if let Some(t) = tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
Ok(())
}

View File

@ -1,40 +0,0 @@
# candle-llava
LLaVA (Large Language-and-Vision Assistant) is an end-to-end trained large
multimodal model. This example is adapted from [candle-llava](https://github.com/chenwanqq/candle-llava).
The code is based on [https://github.com/haotian-liu/LLaVA](https://github.com/haotian-liu/LLaVA), hence the llava-hf variants of the config may behave slightly differently.
## model zoo
* [liuhaotian/LLaVA](https://huggingface.co/liuhaotian)
* [llava-hf](https://huggingface.co/llava-hf)
Right now this has been tested on `liuhaotian/llava-v1.6-vicuna-7b` and
`llava-hf/llava-v1.6-vicuna-7b-hf`. Memory usage might have room for optimization.
## Tokenizer Setup
The llava-hf models contain a `tokenizer.json` file so can be used directly with
the `-hf` command line flag.
For the original llava models, you can use the following code to generate the `tokenizer.json` file.
```bash
conda create -n llava python=3.10
pip install transformers protobuf
conda activate llava
python -c "from transformers import AutoTokenizer;tokenizer=AutoTokenizer.from_pretrained('liuhaotian/llava-v1.6-vicuna-7b');tokenizer.save_pretrained('tokenizer')"
```
Then the `tokenizer.json` file should be in `tokenizer/tokenizer.json` (which is the default path).
## eval
```bash
cargo run --example llava --features cuda -- --image-file "llava_logo.png" --prompt "is this a cat?" --hf # default args, use llava-hf/llava-v1.6-vicuna-7b-hf. image-file is required^_^
cargo run --example llava --features cuda -- --model-path liuhaotian/llava-v1.6-vicuna-7b --image-file "llava_logo.png" --prompt "is this a cat?" # use liuhaotian/llava-v1.6-vicuna-7b, tokenizer setup should be done
```
## Major Limitations
1. Currently only the llama-2/vicuna LLMs are supported; Mistral is not supported yet.
2. Some ops such as split, nonzero and where are not supported by candle.
3. Lack of quantization and LoRA support.

View File

@ -43,14 +43,6 @@ def import_protobuf(error_message=""):
else:
raise ImportError(PROTOBUF_IMPORT_ERROR.format(error_message))
def _get_prepend_scheme(add_prefix_space: bool, original_tokenizer) -> str:
if add_prefix_space:
prepend_scheme = "always"
if hasattr(original_tokenizer, "legacy") and not original_tokenizer.legacy:
prepend_scheme = "first"
else:
prepend_scheme = "never"
return prepend_scheme
class SentencePieceExtractor:
"""
@ -527,15 +519,13 @@ class SpmConverter(Converter):
)
def pre_tokenizer(self, replacement, add_prefix_space):
prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
return pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
def post_processor(self):
return None
def decoder(self, replacement, add_prefix_space):
prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
return decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
return decoders.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
def converted(self) -> Tokenizer:
tokenizer = self.tokenizer(self.proto)
@ -646,8 +636,7 @@ class DebertaV2Converter(SpmConverter):
list_pretokenizers = []
if self.original_tokenizer.split_by_punct:
list_pretokenizers.append(pre_tokenizers.Punctuation(behavior="isolated"))
prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
list_pretokenizers.append(pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme))
list_pretokenizers.append(pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space))
return pre_tokenizers.Sequence(list_pretokenizers)
def normalizer(self, proto):
@ -940,11 +929,10 @@ class PegasusConverter(SpmConverter):
return proto.trainer_spec.unk_id + self.original_tokenizer.offset
def pre_tokenizer(self, replacement, add_prefix_space):
prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
return pre_tokenizers.Sequence(
[
pre_tokenizers.WhitespaceSplit(),
pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme),
pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space),
]
)

View File

@ -1,20 +0,0 @@
# candle-mimi
[Mimi](https://huggingface.co/kyutai/mimi) is a state-of-the-art audio
compression model using an encoder/decoder architecture with residual vector
quantization. The candle implementation supports streaming, meaning that it is
possible to encode or decode a stream of audio tokens on the fly to provide
low-latency interaction with an audio model.
## Running one example
Generating some audio tokens from an audio file.
```bash
wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
cargo run --example mimi --features mimi --release -- audio-to-code bria.mp3 bria.safetensors
```
And decoding the audio tokens back into a sound file.
```bash
cargo run --example mimi --features mimi --release -- code-to-audio bria.safetensors bria.wav
```
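The example also supports a streaming mode in which the audio is encoded or decoded in
fixed-size chunks rather than in a single pass. A possible invocation is sketched below;
the value passed to `--streaming` is a chunk size in samples, and 1920 (80ms at 24kHz) is
only an illustrative choice.
```bash
cargo run --example mimi --features mimi --release -- audio-to-code bria.mp3 bria.safetensors --streaming 1920
```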

View File

@ -1,275 +0,0 @@
#![allow(unused)]
use anyhow::{Context, Result};
use std::sync::{Arc, Mutex};
pub const SAMPLE_RATE: usize = 24_000;
pub(crate) struct AudioOutputData_ {
resampled_data: std::collections::VecDeque<f32>,
resampler: rubato::FastFixedIn<f32>,
output_buffer: Vec<f32>,
input_buffer: Vec<f32>,
input_len: usize,
}
impl AudioOutputData_ {
pub(crate) fn new(input_sample_rate: usize, output_sample_rate: usize) -> Result<Self> {
use rubato::Resampler;
let resampled_data = std::collections::VecDeque::with_capacity(output_sample_rate * 10);
let resample_ratio = output_sample_rate as f64 / input_sample_rate as f64;
let resampler = rubato::FastFixedIn::new(
resample_ratio,
f64::max(resample_ratio, 1.0),
rubato::PolynomialDegree::Septic,
1024,
1,
)?;
let input_buffer = resampler.input_buffer_allocate(true).remove(0);
let output_buffer = resampler.output_buffer_allocate(true).remove(0);
Ok(Self {
resampled_data,
resampler,
input_buffer,
output_buffer,
input_len: 0,
})
}
pub fn reset(&mut self) {
use rubato::Resampler;
self.output_buffer.fill(0.);
self.input_buffer.fill(0.);
self.resampler.reset();
self.resampled_data.clear();
}
pub(crate) fn take_all(&mut self) -> Vec<f32> {
let mut data = Vec::with_capacity(self.resampled_data.len());
while let Some(elem) = self.resampled_data.pop_back() {
data.push(elem);
}
data
}
pub(crate) fn is_empty(&self) -> bool {
self.resampled_data.is_empty()
}
// Assumes that the input buffer is large enough.
fn push_input_buffer(&mut self, samples: &[f32]) {
self.input_buffer[self.input_len..self.input_len + samples.len()].copy_from_slice(samples);
self.input_len += samples.len()
}
pub(crate) fn push_samples(&mut self, samples: &[f32]) -> Result<()> {
use rubato::Resampler;
let mut pos_in = 0;
loop {
let rem = self.input_buffer.len() - self.input_len;
let pos_end = usize::min(pos_in + rem, samples.len());
self.push_input_buffer(&samples[pos_in..pos_end]);
pos_in = pos_end;
if self.input_len < self.input_buffer.len() {
break;
}
let (_, out_len) = self.resampler.process_into_buffer(
&[&self.input_buffer],
&mut [&mut self.output_buffer],
None,
)?;
for &elem in self.output_buffer[..out_len].iter() {
self.resampled_data.push_front(elem)
}
self.input_len = 0;
}
Ok(())
}
}
type AudioOutputData = Arc<Mutex<AudioOutputData_>>;
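/// Opens the default audio output device and starts a cpal stream whose callback drains
/// the shared resampled-sample buffer; returns both the stream and the buffer handle.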
pub(crate) fn setup_output_stream() -> Result<(cpal::Stream, AudioOutputData)> {
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
println!("Setup audio output stream!");
let host = cpal::default_host();
let device = host
.default_output_device()
.context("no output device available")?;
let mut supported_configs_range = device.supported_output_configs()?;
let config_range = match supported_configs_range.find(|c| c.channels() == 1) {
// On macOS, it's commonly the case that there are only stereo outputs.
None => device
.supported_output_configs()?
.next()
.context("no audio output available")?,
Some(config_range) => config_range,
};
let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
config_range.min_sample_rate(),
config_range.max_sample_rate(),
);
let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
let channels = config.channels as usize;
println!(
"cpal device: {} {} {config:?}",
device.name().unwrap_or_else(|_| "unk".to_string()),
config.sample_rate.0
);
let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
SAMPLE_RATE,
config.sample_rate.0 as usize,
)?));
let ad = audio_data.clone();
let stream = device.build_output_stream(
&config,
move |data: &mut [f32], _: &cpal::OutputCallbackInfo| {
data.fill(0.);
let mut ad = ad.lock().unwrap();
let mut last_elem = 0f32;
for (idx, elem) in data.iter_mut().enumerate() {
if idx % channels == 0 {
match ad.resampled_data.pop_back() {
None => break,
Some(v) => {
last_elem = v;
*elem = v
}
}
} else {
*elem = last_elem
}
}
},
move |err| eprintln!("cpal error: {err}"),
None, // None=blocking, Some(Duration)=timeout
)?;
stream.play()?;
Ok((stream, audio_data))
}
pub(crate) fn setup_input_stream() -> Result<(cpal::Stream, AudioOutputData)> {
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
println!("Setup audio input stream!");
let host = cpal::default_host();
let device = host
.default_input_device()
.context("no input device available")?;
let mut supported_configs_range = device.supported_input_configs()?;
let config_range = supported_configs_range
.find(|c| c.channels() == 1)
.context("no audio input available")?;
let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
config_range.min_sample_rate(),
config_range.max_sample_rate(),
);
let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
println!(
"cpal device: {} {} {config:?}",
device.name().unwrap_or_else(|_| "unk".to_string()),
config.sample_rate.0
);
let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
config.sample_rate.0 as usize,
SAMPLE_RATE,
)?));
let ad = audio_data.clone();
let stream = device.build_input_stream(
&config,
move |data: &[f32], _: &cpal::InputCallbackInfo| {
let mut ad = ad.lock().unwrap();
if let Err(err) = ad.push_samples(data) {
eprintln!("error processing audio input {err:?}")
}
},
move |err| eprintln!("cpal error: {err}"),
None, // None=blocking, Some(Duration)=timeout
)?;
stream.play()?;
Ok((stream, audio_data))
}
fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
where
T: symphonia::core::sample::Sample,
f32: symphonia::core::conv::FromSample<T>,
{
use symphonia::core::audio::Signal;
use symphonia::core::conv::FromSample;
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
}
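/// Decodes an audio file with symphonia and returns the first channel as f32 PCM samples
/// together with the file's sample rate.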
pub(crate) fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> Result<(Vec<f32>, u32)> {
use symphonia::core::audio::{AudioBufferRef, Signal};
let src = std::fs::File::open(path)?;
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
let hint = symphonia::core::probe::Hint::new();
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
let mut format = probed.format;
let track = format
.tracks()
.iter()
.find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
.expect("no supported audio tracks");
let mut decoder = symphonia::default::get_codecs()
.make(&track.codec_params, &Default::default())
.expect("unsupported codec");
let track_id = track.id;
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
let mut pcm_data = Vec::new();
while let Ok(packet) = format.next_packet() {
while !format.metadata().is_latest() {
format.metadata().pop();
}
if packet.track_id() != track_id {
continue;
}
match decoder.decode(&packet)? {
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
}
}
Ok((pcm_data, sample_rate))
}
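/// Resamples mono PCM data from `sr_in` to `sr_out` using rubato's FFT-based resampler.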
pub(crate) fn resample(pcm_in: &[f32], sr_in: usize, sr_out: usize) -> Result<Vec<f32>> {
use rubato::Resampler;
let mut pcm_out =
Vec::with_capacity((pcm_in.len() as f64 * sr_out as f64 / sr_in as f64) as usize + 1024);
let mut resampler = rubato::FftFixedInOut::<f32>::new(sr_in, sr_out, 1024, 1)?;
let mut output_buffer = resampler.output_buffer_allocate(true);
let mut pos_in = 0;
while pos_in + resampler.input_frames_next() < pcm_in.len() {
let (in_len, out_len) =
resampler.process_into_buffer(&[&pcm_in[pos_in..]], &mut output_buffer, None)?;
pos_in += in_len;
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
}
if pos_in < pcm_in.len() {
let (_in_len, out_len) = resampler.process_partial_into_buffer(
Some(&[&pcm_in[pos_in..]]),
&mut output_buffer,
None,
)?;
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
}
Ok(pcm_out)
}

View File

@ -1,165 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Result;
use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::mimi::{Config, Model};
use clap::{Parser, ValueEnum};
use hf_hub::api::sync::Api;
mod audio_io;
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Action {
AudioToAudio,
AudioToCode,
CodeToAudio,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// The action to be performed; this specifies the format of the input and output data.
action: Action,
/// The input file, either an audio file or some mimi tokens stored as safetensors.
in_file: String,
/// The output file, either a wave audio file or some mimi tokens stored as safetensors.
out_file: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// The model weight file, in safetensor format.
#[arg(long)]
model: Option<String>,
/// Whether to use streaming or not; when streaming, slices of data of the given size are passed
/// to the encoder/decoder one at a time.
#[arg(long)]
streaming: Option<usize>,
}
fn main() -> Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
None => Api::new()?
.model("kyutai/mimi".to_string())
.get("model.safetensors")?,
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
let config = Config::v0_1(None);
let mut model = Model::new(config, vb)?;
let codes = match args.action {
Action::CodeToAudio => {
let codes = candle::safetensors::load(args.in_file, &device)?;
codes.get("codes").expect("no codes in input file").clone()
}
Action::AudioToCode | Action::AudioToAudio => {
let pcm = if args.in_file == "-" {
println!(">>>> RECORDING AUDIO, PRESS ENTER ONCE DONE <<<<");
let (stream, input_audio) = audio_io::setup_input_stream()?;
let mut pcms = vec![];
let stdin = std::thread::spawn(|| {
let mut s = String::new();
std::io::stdin().read_line(&mut s)
});
while !stdin.is_finished() {
let input = input_audio.lock().unwrap().take_all();
if input.is_empty() {
std::thread::sleep(std::time::Duration::from_millis(100));
continue;
}
pcms.push(input)
}
drop(stream);
pcms.concat()
} else {
let (pcm, sample_rate) = audio_io::pcm_decode(args.in_file)?;
if sample_rate != 24_000 {
println!("WARNING: mimi uses a 24khz sample rate, input uses {sample_rate}, resampling...");
audio_io::resample(&pcm, sample_rate as usize, 24_000)?
} else {
pcm
}
};
match args.streaming {
Some(chunk_size) => {
let mut code_chunks = vec![];
for pcm in pcm.chunks(chunk_size) {
let pcm = Tensor::new(pcm, &device)?.reshape((1, 1, ()))?;
let code_chunk = model.encode(&pcm)?;
code_chunks.push(code_chunk)
}
Tensor::cat(&code_chunks, candle::D::Minus1)?
}
None => {
let pcm_len = pcm.len();
let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
println!("input pcm shape: {:?}", pcm.shape());
model.encode(&pcm)?
}
}
}
};
println!("codes shape: {:?}", codes.shape());
model.reset_state();
match args.action {
Action::AudioToCode => {
codes.save_safetensors("codes", &args.out_file)?;
}
Action::AudioToAudio | Action::CodeToAudio => {
let pcm = match args.streaming {
Some(chunk_size) => {
let seq_len = codes.dim(candle::D::Minus1)?;
let mut pcm_chunks = vec![];
for chunk_start in (0..seq_len).step_by(chunk_size) {
let chunk_len = usize::min(chunk_size, seq_len - chunk_start);
let codes = codes.narrow(candle::D::Minus1, chunk_start, chunk_len)?;
let pcm = model.decode_step(&codes.into())?;
if let Some(pcm) = pcm.as_option() {
pcm_chunks.push(pcm.clone())
}
}
Tensor::cat(&pcm_chunks, candle::D::Minus1)?
}
None => model.decode(&codes)?,
};
println!("output pcm shape: {:?}", pcm.shape());
let pcm = pcm.i(0)?.i(0)?;
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
let pcm = pcm.to_vec1::<f32>()?;
if args.out_file == "-" {
let (stream, ad) = audio_io::setup_output_stream()?;
{
let mut ad = ad.lock().unwrap();
ad.push_samples(&pcm)?;
}
loop {
let ad = ad.lock().unwrap();
if ad.is_empty() {
break;
}
// That's very weird, calling thread::sleep here triggers the stream to stop
// playing (the callback doesn't seem to be called anymore).
// std::thread::sleep(std::time::Duration::from_millis(100));
}
drop(stream)
} else {
let mut output = std::fs::File::create(&args.out_file)?;
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
}
}
}
Ok(())
}

View File

@ -147,12 +147,6 @@ enum Which {
Mistral7bInstructV01,
#[value(name = "7b-instruct-v0.2")]
Mistral7bInstructV02,
#[value(name = "7b-maths-v0.1")]
Mathstral7bV01,
#[value(name = "nemo-2407")]
MistralNemo2407,
#[value(name = "nemo-instruct-2407")]
MistralNemoInstruct2407,
}
#[derive(Parser, Debug)]
@ -267,16 +261,12 @@ fn main() -> Result<()> {
}
"lmz/candle-mistral".to_string()
} else {
let name = match args.which {
Which::Mistral7bV01 => "mistralai/Mistral-7B-v0.1",
Which::Mistral7bV02 => "mistralai/Mistral-7B-v0.2",
Which::Mistral7bInstructV01 => "mistralai/Mistral-7B-Instruct-v0.1",
Which::Mistral7bInstructV02 => "mistralai/Mistral-7B-Instruct-v0.2",
Which::Mathstral7bV01 => "mistralai/mathstral-7B-v0.1",
Which::MistralNemo2407 => "mistralai/Mistral-Nemo-Base-2407",
Which::MistralNemoInstruct2407 => "mistralai/Mistral-Nemo-Instruct-2407",
};
name.to_string()
match args.which {
Which::Mistral7bV01 => "mistralai/Mistral-7B-v0.1".to_string(),
Which::Mistral7bV02 => "mistralai/Mistral-7B-v0.2".to_string(),
Which::Mistral7bInstructV01 => "mistralai/Mistral-7B-Instruct-v0.1".to_string(),
Which::Mistral7bInstructV02 => "mistralai/Mistral-7B-Instruct-v0.2".to_string(),
}
}
}
};

View File

@ -217,7 +217,11 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let config = Config::v0_1_8x7b(args.use_flash_attn);
let device = candle_examples::device(args.cpu)?;
let dtype = device.bf16_default_to_f32();
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());

View File

@ -1,28 +0,0 @@
# candle-mobileclip
MobileCLIP is a family of efficient CLIP-like models using FastViT-based image encoders.
See [MobileCLIP: Fast Image-Text Models through Multi-Modal Reinforced Training](https://arxiv.org/abs/2311.17049)
## Running an example on cpu
```
$ cargo run --example mobileclip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "a cycling race","a photo of two cats","a robot holding a candle"
softmax_image_vec: [2.4819004e-5, 3.81081e-6, 0.9999714, 0.9999738, 2.382714e-5, 2.3317718e-6]
Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
Probability: 0.0025% Text: a cycling race
Probability: 0.0004% Text: a photo of two cats
Probability: 99.9971% Text: a robot holding a candle
Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
Probability: 99.9974% Text: a cycling race
Probability: 0.0024% Text: a photo of two cats
Probability: 0.0002% Text: a robot holding a candle
```

View File

@ -1,192 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Error as E;
use clap::{Parser, ValueEnum};
use candle::{DType, Device, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use candle_transformers::models::mobileclip;
use tokenizers::Tokenizer;
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
S1,
S2,
}
impl Which {
fn model_name(&self) -> String {
let name = match self {
Self::S1 => "S1",
Self::S2 => "S2",
};
format!("apple/MobileCLIP-{}-OpenCLIP", name)
}
fn config(&self) -> mobileclip::MobileClipConfig {
match self {
Self::S1 => mobileclip::MobileClipConfig::s1(),
Self::S2 => mobileclip::MobileClipConfig::s2(),
}
}
}
#[derive(Parser)]
struct Args {
#[arg(long, use_value_delimiter = true)]
images: Option<Vec<String>>,
#[arg(long)]
cpu: bool,
/// Use the pytorch weights rather than the safetensors ones
#[arg(long)]
use_pth: bool,
#[arg(long, use_value_delimiter = true)]
sequences: Option<Vec<String>>,
#[arg(value_enum, long, default_value_t=Which::S1)]
which: Which,
}
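// Loads each image at the requested square resolution and stacks them into a single batch
// tensor; a zero mean and unit standard deviation are passed so no extra normalization is
// applied.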
fn load_images<T: AsRef<std::path::Path>>(
paths: &Vec<T>,
image_size: usize,
) -> anyhow::Result<Tensor> {
let mut images = vec![];
for path in paths {
let tensor = candle_examples::imagenet::load_image_with_std_mean(
path,
image_size,
&[0.0, 0.0, 0.0],
&[1.0, 1.0, 1.0],
)?;
images.push(tensor);
}
let images = Tensor::stack(&images, 0)?;
Ok(images)
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let model_name = args.which.model_name();
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(model_name);
let model_file = if args.use_pth {
api.get("open_clip_pytorch_model.bin")?
} else {
api.get("open_clip_model.safetensors")?
};
let tokenizer = api.get("tokenizer.json")?;
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let config = &args.which.config();
let device = candle_examples::device(args.cpu)?;
let vec_imgs = match args.images {
Some(imgs) => imgs,
None => vec![
"candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(),
"candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
],
};
let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?;
let vb = if args.use_pth {
VarBuilder::from_pth(&model_file, DType::F32, &device)?
} else {
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? }
};
let model = mobileclip::MobileClipModel::new(vb, config)?;
let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;
let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
let softmax_image = softmax(&logits_per_image, 1)?;
let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
println!("softmax_image_vec: {:?}", softmax_image_vec);
let probability_vec = softmax_image_vec
.iter()
.map(|v| v * 100.0)
.collect::<Vec<f32>>();
let probability_per_image = probability_vec.len() / vec_imgs.len();
for (i, img) in vec_imgs.iter().enumerate() {
let start = i * probability_per_image;
let end = start + probability_per_image;
let prob = &probability_vec[start..end];
println!("\n\nResults for image: {}\n", img);
for (i, p) in prob.iter().enumerate() {
println!("Probability: {:.4}% Text: {}", p, vec_seq[i]);
}
}
Ok(())
}
pub fn tokenize_sequences(
sequences: Option<Vec<String>>,
tokenizer: &Tokenizer,
device: &Device,
) -> anyhow::Result<(Tensor, Vec<String>)> {
// let pad_id = *tokenizer
// .get_vocab(true)
// .get("<|endoftext|>")
// .ok_or(E::msg("No pad token"))?;
// The model does not work well if the text is padded using the <|endoftext|> token, so we
// pad with 0 instead, as in the original OpenCLIP code.
let pad_id = 0;
let vec_seq = match sequences {
Some(seq) => seq,
None => vec![
"a cycling race".to_string(),
"a photo of two cats".to_string(),
"a robot holding a candle".to_string(),
],
};
let mut tokens = vec![];
for seq in vec_seq.clone() {
let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
tokens.push(encoding.get_ids().to_vec());
}
let max_len = tokens.iter().map(|v| v.len()).max().unwrap_or(0);
// Pad the sequences to have the same length
for token_vec in tokens.iter_mut() {
let len_diff = max_len - token_vec.len();
if len_diff > 0 {
token_vec.extend(vec![pad_id; len_diff]);
}
}
let input_ids = Tensor::new(tokens, device)?;
Ok((input_ids, vec_seq))
}

View File

@ -1,18 +0,0 @@
# candle-mobilenetv4
[MobileNetV4 - Universal Models for the Mobile Ecosystem](https://arxiv.org/abs/2404.10518)
This candle implementation uses pre-trained MobileNetV4 models from timm for inference.
The classification head has been trained on the ImageNet dataset and returns the probabilities for the top-5 classes.
## Running an example
```
$ cargo run --example mobilenetv4 --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which medium
loaded image Tensor[dims 3, 256, 256; f32]
model built
unicycle, monocycle : 20.18%
mountain bike, all-terrain bike, off-roader: 19.77%
bicycle-built-for-two, tandem bicycle, tandem: 15.91%
crash helmet : 1.15%
tricycle, trike, velocipede: 0.67%
```

View File

@ -1,107 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::mobilenetv4;
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
Small,
Medium,
Large,
HybridMedium,
HybridLarge,
}
impl Which {
fn model_filename(&self) -> String {
let name = match self {
Self::Small => "conv_small.e2400_r224",
Self::Medium => "conv_medium.e500_r256",
Self::HybridMedium => "hybrid_medium.ix_e550_r256",
Self::Large => "conv_large.e600_r384",
Self::HybridLarge => "hybrid_large.ix_e600_r384",
};
format!("timm/mobilenetv4_{}_in1k", name)
}
fn resolution(&self) -> u32 {
match self {
Self::Small => 224,
Self::Medium => 256,
Self::HybridMedium => 256,
Self::Large => 384,
Self::HybridLarge => 384,
}
}
fn config(&self) -> mobilenetv4::Config {
match self {
Self::Small => mobilenetv4::Config::small(),
Self::Medium => mobilenetv4::Config::medium(),
Self::HybridMedium => mobilenetv4::Config::hybrid_medium(),
Self::Large => mobilenetv4::Config::large(),
Self::HybridLarge => mobilenetv4::Config::hybrid_large(),
}
}
}
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
#[arg(value_enum, long, default_value_t=Which::Small)]
which: Which,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image =
candle_examples::imagenet::load_image(args.image, args.which.resolution() as usize)?
.to_device(&device)?;
println!("loaded image {image:?}");
let model_file = match args.model {
None => {
let model_name = args.which.model_filename();
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(model_name);
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = mobilenetv4::mobilenetv4(&args.which.config(), 1000, vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for &(category_idx, pr) in prs.iter().take(5) {
println!(
"{:24}: {:.2}%",
candle_examples::imagenet::CLASSES[category_idx],
100. * pr
);
}
Ok(())
}

View File

@ -188,8 +188,8 @@ struct Args {
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
quantized: bool,
@ -208,7 +208,7 @@ struct Args {
/// Loads an image from disk using the image crate, this returns a tensor with shape
/// (3, 378, 378).
pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> candle::Result<Tensor> {
let img = image::ImageReader::open(p)?
let img = image::io::Reader::open(p)?
.decode()
.map_err(candle::Error::wrap)?
.resize_to_fill(378, 378, image::imageops::FilterType::Triangle); // Adjusted to 378x378
@ -252,28 +252,20 @@ async fn main() -> anyhow::Result<()> {
let start = std::time::Instant::now();
let api = hf_hub::api::tokio::Api::new()?;
let (model_id, revision) = match args.model_id {
Some(model_id) => (model_id.to_string(), None),
let model_id = match args.model_id {
Some(model_id) => model_id.to_string(),
None => {
if args.quantized {
("santiagomed/candle-moondream".to_string(), None)
"santiagomed/candle-moondream".to_string()
} else {
(
"vikhyatk/moondream2".to_string(),
Some("30c7cdf3fa6914f50bee3956694374143f5cc884"),
)
"vikhyatk/moondream2".to_string()
}
}
};
let revision = match (args.revision, revision) {
(Some(r), _) => r,
(None, Some(r)) => r.to_string(),
(None, None) => "main".to_string(),
};
let repo = api.repo(hf_hub::Repo::with_revision(
model_id,
hf_hub::RepoType::Model,
revision,
args.revision,
));
let model_file = match args.model_file {
Some(m) => m.into(),

View File

@ -284,11 +284,11 @@ impl MusicgenDecoder {
};
let embed_dim = cfg.vocab_size + 1;
let embed_tokens = (0..cfg.num_codebooks)
.map(|i| embedding(embed_dim, h, vb.pp(format!("embed_tokens.{i}"))))
.map(|i| embedding(embed_dim, h, vb.pp(&format!("embed_tokens.{i}"))))
.collect::<Result<Vec<_>>>()?;
let embed_positions = MusicgenSinusoidalPositionalEmbedding::load(vb.clone(), cfg)?;
let layers = (0..cfg.num_hidden_layers)
.map(|i| MusicgenDecoderLayer::load(vb.pp(format!("layers.{i}")), cfg))
.map(|i| MusicgenDecoderLayer::load(vb.pp(&format!("layers.{i}")), cfg))
.collect::<Result<Vec<_>>>()?;
let layer_norm = layer_norm(h, 1e-5, vb.pp("layer_norm"))?;
Ok(Self {
@ -341,7 +341,7 @@ impl MusicgenForCausalLM {
let h = cfg.hidden_size;
let decoder = MusicgenDecoder::load(vb.pp("model.decoder"), cfg)?;
let lm_heads = (0..cfg.num_codebooks)
.map(|i| linear_no_bias(h, cfg.vocab_size, vb.pp(format!("lm_heads.{i}"))))
.map(|i| linear_no_bias(h, cfg.vocab_size, vb.pp(&format!("lm_heads.{i}"))))
.collect::<Result<Vec<_>>>()?;
Ok(Self {
decoder,

View File

@ -1,23 +0,0 @@
# candle-parler-tts
[Parler-TTS](https://huggingface.co/parler-tts/parler-tts-large-v1) is a large
text-to-speech model with 2.2B parameters trained on ~45K hours of audio data.
The voice can be controlled by a text prompt.
## Run an example
```bash
cargo run --example parler-tts -r -- \
--prompt "Hey, how are you doing today?"
```
To specify a prompt describing the voice, use the `--description` argument.
```bash
cargo run --example parler-tts -r -- \
--prompt "Hey, how are you doing today?" \
--description "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch. The recording is of very high quality, with the speaker's voice sounding clear and very close up."
```
https://github.com/user-attachments/assets/1b16aeac-70a3-4803-8589-4563279bba33

View File

@ -1,206 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Error as E;
use clap::Parser;
use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::parler_tts::{Config, Model};
use tokenizers::Tokenizer;
#[derive(Parser)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Display the token for the specified prompt.
#[arg(long)]
verbose_prompt: bool,
#[arg(long, default_value = "Hey, how are you doing today?")]
prompt: String,
#[arg(
long,
default_value = "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch. The recording is of very high quality, with the speaker's voice sounding clear and very close up."
)]
description: String,
/// The temperature used to generate samples.
#[arg(long, default_value_t = 0.0)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 0)]
seed: u64,
#[arg(long, default_value_t = 5000)]
sample_len: usize,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.0)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
quantized: bool,
/// Use f16 precision for all the computations rather than f32.
#[arg(long)]
f16: bool,
#[arg(long)]
model_file: Option<String>,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
config_file: Option<String>,
#[arg(long, default_value_t = 512)]
max_steps: usize,
/// The output wav file.
#[arg(long, default_value = "out.wav")]
out_file: String,
#[arg(long, default_value = "large-v1")]
which: Which,
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "large-v1")]
LargeV1,
#[value(name = "mini-v1")]
MiniV1,
}
fn main() -> anyhow::Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature, args.repeat_penalty, args.repeat_last_n
);
let start = std::time::Instant::now();
let api = hf_hub::api::sync::Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id.to_string(),
None => match args.which {
Which::LargeV1 => "parler-tts/parler-tts-large-v1".to_string(),
Which::MiniV1 => "parler-tts/parler-tts-mini-v1".to_string(),
},
};
let revision = match args.revision {
Some(r) => r,
None => "main".to_string(),
};
let repo = api.repo(hf_hub::Repo::with_revision(
model_id,
hf_hub::RepoType::Model,
revision,
));
let model_files = match args.model_file {
Some(m) => vec![m.into()],
None => match args.which {
Which::MiniV1 => vec![repo.get("model.safetensors")?],
Which::LargeV1 => {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
}
},
};
let config = match args.config_file {
Some(m) => m.into(),
None => repo.get("config.json")?,
};
let tokenizer = match args.tokenizer_file {
Some(m) => m.into(),
None => repo.get("tokenizer.json")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&model_files, DType::F32, &device)? };
let config: Config = serde_json::from_reader(std::fs::File::open(config)?)?;
let mut model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let description_tokens = tokenizer
.encode(args.description, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let description_tokens = Tensor::new(description_tokens, &device)?.unsqueeze(0)?;
let prompt_tokens = tokenizer
.encode(args.prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let prompt_tokens = Tensor::new(prompt_tokens, &device)?.unsqueeze(0)?;
let lp = candle_transformers::generation::LogitsProcessor::new(
args.seed,
Some(args.temperature),
args.top_p,
);
println!("starting generation...");
let codes = model.generate(&prompt_tokens, &description_tokens, lp, args.max_steps)?;
println!("generated codes\n{codes}");
let codes = codes.to_dtype(DType::I64)?;
codes.save_safetensors("codes", "out.safetensors")?;
let codes = codes.unsqueeze(0)?;
let pcm = model
.audio_encoder
.decode_codes(&codes.to_device(&device)?)?;
println!("{pcm}");
let pcm = pcm.i((0, 0))?;
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
let pcm = pcm.to_vec1::<f32>()?;
let mut output = std::fs::File::create(&args.out_file)?;
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, config.audio_encoder.sampling_rate)?;
Ok(())
}

View File

@ -114,10 +114,6 @@ impl TextGeneration {
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
if let Some(t) = self.tokenizer.decode_rest()? {
print!("{t}");
std::io::stdout().flush()?;
}
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
@ -145,8 +141,6 @@ enum WhichModel {
V2,
#[value(name = "3")]
V3,
#[value(name = "3-medium")]
V3Medium,
#[value(name = "2-old")]
V2Old,
PuffinPhiV2,
@ -260,7 +254,6 @@ fn main() -> Result<()> {
WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
WhichModel::V2 | WhichModel::V2Old => "microsoft/phi-2".to_string(),
WhichModel::V3 => "microsoft/Phi-3-mini-4k-instruct".to_string(),
WhichModel::V3Medium => "microsoft/Phi-3-medium-4k-instruct".to_string(),
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
"lmz/candle-quantized-phi".to_string()
}
@ -280,7 +273,6 @@ fn main() -> Result<()> {
WhichModel::V2Old => "834565c23f9b28b96ccbeabe614dd906b6db551a".to_string(),
WhichModel::V2
| WhichModel::V3
| WhichModel::V3Medium
| WhichModel::PuffinPhiV2
| WhichModel::PhiHermes => "main".to_string(),
}
@ -295,8 +287,7 @@ fn main() -> Result<()> {
| WhichModel::V1_5
| WhichModel::V2
| WhichModel::V2Old
| WhichModel::V3
| WhichModel::V3Medium => repo.get("tokenizer.json")?,
| WhichModel::V3 => repo.get("tokenizer.json")?,
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
repo.get("tokenizer-puffin-phi-v2.json")?
}
@ -312,14 +303,14 @@ fn main() -> Result<()> {
WhichModel::V2 | WhichModel::V2Old => vec![repo.get("model-v2-q4k.gguf")?],
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
WhichModel::V3 | WhichModel::V3Medium => anyhow::bail!(
WhichModel::V3 => anyhow::bail!(
"use the quantized or quantized-phi examples for quantized phi-v3"
),
}
} else {
match args.model {
WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
WhichModel::V2 | WhichModel::V2Old | WhichModel::V3 | WhichModel::V3Medium => {
WhichModel::V2 | WhichModel::V2Old | WhichModel::V3 => {
candle_examples::hub_load_safetensors(
&repo,
"model.safetensors.index.json",
@ -341,7 +332,7 @@ fn main() -> Result<()> {
WhichModel::V2 | WhichModel::V2Old => Config::v2(),
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
WhichModel::V3 | WhichModel::V3Medium => {
WhichModel::V3 => {
panic!("use the quantized or quantized-phi examples for quantized phi-v3")
}
};
@ -361,8 +352,8 @@ fn main() -> Result<()> {
let dtype = match args.dtype {
Some(dtype) => std::str::FromStr::from_str(&dtype)?,
None => {
if args.model == WhichModel::V3 || args.model == WhichModel::V3Medium {
device.bf16_default_to_f32()
if args.model == WhichModel::V3 && device.is_cuda() {
DType::BF16
} else {
DType::F32
}
@ -377,7 +368,7 @@ fn main() -> Result<()> {
let phi = Phi::new(&config, vb)?;
Model::Phi(phi)
}
WhichModel::V3 | WhichModel::V3Medium => {
WhichModel::V3 => {
let config_filename = repo.get("config.json")?;
let config = std::fs::read_to_string(config_filename)?;
let config: Phi3Config = serde_json::from_str(&config)?;

Some files were not shown because too many files have changed in this diff.