Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 11:08:52 +00:00)

Compare commits: 0.6.0...qmm-pad-fi (99 commits)
Commits (SHA1):
5221146cfa fd3b53f48b c6019e9635 8cc560bb8c 0bd61bae29 fa1e0e438e d01207dbf3 8097559c1a
829dcfa8dc c2fca0ca11 844d45cde4 af2104078f 5fc4f17727 c58c5d5b01 382c6b51af 6eea45a761
ebf722b446 c09afc211c b60faebea4 72d649058b 0cb0bd1dfa afb6575835 5635650d38 13b2a8a4a0
e3261216b1 c02b7c3272 86613c00e2 29e25c458d aafa24ed93 fdc2622686 ccdbe87639 2ec8729d51
e3c146ada6 1e96b8b695 a8288b7a72 6070278a31 b47c0bc475 14fd2d97e0 31a1075f4b 236b29ff15
58197e1896 736d8eb752 7cff5898ec b75ef051cf c1b9e07e35 69fdcfe96a 2b75dd9551 53ce65f706
68aa9c7320 35e5f31397 d3fe989d08 14db029494 6e6c1c99b0 b7d9af00cc 59bbc0d287 dfdce2b602
500c9f2882 2be9bd211e 89eae41efd c0a559d427 aa7ac1832d 19db6b9723 0fcb40b229 6991a37b94
9ca277a9d7 2e9c010609 ac51f477eb d4b6f6eef6 957d604a78 ce90287f45 1ba87a9450 bd80078acf
fea46cb719 8696cf6494 4a52aeb437 24d54d0ff9 636eff652a 0f5cbb08b3 ddafc61055 a925ae6bc6
6056fd5c90 ebc9aa60bc 2489a606fe 3c815b1dca 42891cc613 f25173d68b 6a4741bbf9 30cdd769f9
d74fbed334 c63048d374 a226a9736b 25960676ca 9cd54aa5d4 eec11ce2ce 9182f9f5c2 ecff05d72b
7f1ba8038c 74e9e41911 e27aac0a06
.github/workflows/python.yml (vendored, 2 changes)

@@ -20,7 +20,7 @@ jobs:
  os: [ubuntu-latest] # For now, only test on Linux
  steps:
  - name: Checkout repository
- uses: actions/checkout@v2
+ uses: actions/checkout@v4

  - name: Install Rust
  uses: actions-rs/toolchain@v1
.github/workflows/rust-ci.yml (vendored, 8 changes)

@@ -15,7 +15,7 @@ jobs:
  os: [ubuntu-latest, windows-latest, macOS-latest]
  rust: [stable]
  steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v4
  - uses: actions-rs/toolchain@v1
  with:
  profile: minimal
@@ -34,7 +34,7 @@ jobs:
  os: [ubuntu-latest, windows-latest, macOS-latest]
  rust: [stable]
  steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v4
  - uses: actions-rs/toolchain@v1
  with:
  profile: minimal
@@ -49,7 +49,7 @@ jobs:
  name: Rustfmt
  runs-on: ubuntu-latest
  steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v4
  - uses: actions-rs/toolchain@v1
  with:
  profile: minimal
@@ -65,7 +65,7 @@ jobs:
  name: Clippy
  runs-on: ubuntu-latest
  steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v4
  - uses: actions-rs/toolchain@v1
  with:
  profile: minimal
.gitignore (vendored, 10 changes)

@@ -9,6 +9,10 @@ target/
  # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
  Cargo.lock

+ # editor config
+ .helix
+ .vscode
+
  # These are backup files generated by rustfmt
  **/*.rs.bk

@@ -36,3 +40,9 @@ candle-wasm-examples/*/package-lock.json
  candle-wasm-examples/**/config*.json
  .DS_Store
  .idea/*
+ __pycache__
+ out.safetensors
+ out.wav
+ bria.mp3
+ bria.safetensors
+ bria.wav
Cargo.toml (22 changes)

@@ -20,7 +20,7 @@ exclude = [
  resolver = "2"

  [workspace.package]
- version = "0.6.0"
+ version = "0.7.1"
  edition = "2021"
  description = "Minimalist ML framework."
  repository = "https://github.com/huggingface/candle"
@@ -33,23 +33,23 @@ ab_glyph = "0.2.23"
  accelerate-src = { version = "0.3.2" }
  anyhow = { version = "1", features = ["backtrace"] }
  byteorder = "1.4.3"
- candle = { path = "./candle-core", package = "candle-core", version = "0.6.0" }
+ candle = { path = "./candle-core", package = "candle-core", version = "0.7.1" }
- candle-datasets = { path = "./candle-datasets", version = "0.6.0" }
+ candle-datasets = { path = "./candle-datasets", version = "0.7.1" }
- candle-flash-attn = { path = "./candle-flash-attn", version = "0.6.0" }
+ candle-flash-attn = { path = "./candle-flash-attn", version = "0.7.1" }
- candle-kernels = { path = "./candle-kernels", version = "0.6.0" }
+ candle-kernels = { path = "./candle-kernels", version = "0.7.1" }
- candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.6.0" }
+ candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.7.1" }
- candle-nn = { path = "./candle-nn", version = "0.6.0" }
+ candle-nn = { path = "./candle-nn", version = "0.7.1" }
- candle-onnx = { path = "./candle-onnx", version = "0.6.0" }
+ candle-onnx = { path = "./candle-onnx", version = "0.7.1" }
- candle-transformers = { path = "./candle-transformers", version = "0.6.0" }
+ candle-transformers = { path = "./candle-transformers", version = "0.7.1" }
  clap = { version = "4.2.4", features = ["derive"] }
  criterion = { version = "0.5.1", default-features=false }
- cudarc = { version = "0.11.4", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
+ cudarc = { version = "0.12.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
  fancy-regex = "0.13.0"
  gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
  hf-hub = "0.3.0"
  half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
  hound = "3.5.1"
- image = { version = "0.25.0", default-features = false, features = ["jpeg", "png"] }
+ image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }
  imageproc = { version = "0.24.0", default-features = false }
  intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
  libc = { version = "0.2.147" }
README.md (11 changes)

@@ -63,7 +63,9 @@ We also provide a some command line based examples using state of the art models
  - [LLaMA v1, v2, and v3](./candle-examples/examples/llama/): general LLM, includes
  the SOLAR-10.7B variant.
  - [Falcon](./candle-examples/examples/falcon/): general LLM.
- - [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google Deepmind.
+ - [Codegeex4](./candle-examples/examples/codegeex4-9b/): code completion, code interpreter, web search, function calling, repository-level
+ - [GLM4](./candle-examples/examples/glm4/): Open Multilingual Multimodal Chat LMs by THUDM
+ - [Gemma v1 and v2](./candle-examples/examples/gemma/): 2b and 7b+/9b general LLMs from Google Deepmind.
  - [RecurrentGemma](./candle-examples/examples/recurrent-gemma/): 2b and 7b
  Griffin based models from Google that mix attention with a RNN like state.
  - [Phi-1, Phi-1.5, Phi-2, and Phi-3](./candle-examples/examples/phi/): 1.3b,
@@ -118,6 +120,8 @@ We also provide a some command line based examples using state of the art models
  model using residual vector quantization.
  - [MetaVoice](./candle-examples/examples/metavoice/): foundational model for
  text-to-speech.
+ - [Parler-TTS](./candle-examples/examples/parler-tts/): large text-to-speech
+ model.
  - [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/),
  [JinaBert](./candle-examples/examples/jina-bert/) : useful for sentence embeddings.
  - [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
@@ -206,7 +210,7 @@ If you have an addition to this list, please submit a pull request.
  - StarCoder, StarCoder2.
  - Phi 1, 1.5, 2, and 3.
  - Mamba, Minimal Mamba
- - Gemma 2b and 7b.
+ - Gemma v1 2b and 7b+, v2 2b and 9b.
  - Mistral 7b v0.1.
  - Mixtral 8x7b v0.1.
  - StableLM-3B-4E1T, StableLM-2-1.6B, Stable-Code-3B.
@@ -234,9 +238,10 @@ If you have an addition to this list, please submit a pull request.
  - Whisper, multi-lingual speech-to-text.
  - EnCodec, audio compression model.
  - MetaVoice-1B, text-to-speech model.
+ - Parler-TTS, text-to-speech model.
  - Computer Vision Models.
  - DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT,
- ConvNeXTv2, MobileOne, EfficientVit (MSRA).
+ ConvNeXTv2, MobileOne, EfficientVit (MSRA), MobileNetv4, Hiera, FastViT.
  - yolo-v3, yolo-v8.
  - Segment-Anything Model (SAM).
  - SegFormer.
@@ -48,3 +48,7 @@ metal = ["dep:metal", "dep:candle-metal-kernels"]
  [[bench]]
  name = "bench_main"
  harness = false
+
+ [[example]]
+ name = "metal_basics"
+ required-features = ["metal"]
@@ -12,7 +12,7 @@ fn run_affine_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name:
  let m = 1024;
  let k = 1024;

- let tensor = Tensor::zeros((b, m, k), dtype, &device).unwrap();
+ let tensor = Tensor::zeros((b, m, k), dtype, device).unwrap();

  let flops = b * m * k * dtype.size_in_bytes();

@@ -7,7 +7,7 @@ use criterion::{black_box, criterion_group, Criterion, Throughput};
  use std::time::Instant;

  fn run(matmul: &QMatMul, x: &Tensor) {
- matmul.forward(&x).unwrap();
+ matmul.forward(x).unwrap();
  }

  fn run_bench(c: &mut Criterion, device: &Device, dtype: GgmlDType) {
@@ -50,7 +50,7 @@ fn run_bench(c: &mut Criterion, device: &Device, dtype: GgmlDType) {
  fn criterion_benchmark(c: &mut Criterion) {
  let handler = BenchDeviceHandler::new().unwrap();
  for device in handler.devices {
- for dtype in vec![
+ for dtype in [
  GgmlDType::F32,
  GgmlDType::F16,
  GgmlDType::Q4_0,
@@ -12,7 +12,7 @@ fn run_unary_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &
  let m = 1024;
  let k = 1024;

- let tensor = Tensor::arange(0.0f32, (b * m * k) as f32, &device)
+ let tensor = Tensor::arange(0.0f32, (b * m * k) as f32, device)
  .unwrap()
  .to_dtype(dtype)
  .unwrap()
@@ -25,9 +25,9 @@ const SIZE: usize = B * M * K;
  const DATA: [u8; SIZE] = create_cond_arr::<SIZE>();

  fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
- let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), &device).unwrap();
+ let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), device).unwrap();
- let on_true = Tensor::ones((B, M, K), dtype, &device).unwrap();
+ let on_true = Tensor::ones((B, M, K), dtype, device).unwrap();
- let on_false = Tensor::zeros((B, M, K), dtype, &device).unwrap();
+ let on_false = Tensor::zeros((B, M, K), dtype, device).unwrap();

  let elements = B * M * K;
  // E.g. 2 f32 tensors + 1 u8 tensor
candle-core/examples/metal_basics.rs (new file, 28 lines)

@@ -0,0 +1,28 @@
+ #[cfg(feature = "accelerate")]
+ extern crate accelerate_src;
+
+ #[cfg(feature = "mkl")]
+ extern crate intel_mkl_src;
+
+ use anyhow::Result;
+ use candle_core::{Device, Tensor};
+
+ fn main() -> Result<()> {
+     // This requires the code to be run with MTL_CAPTURE_ENABLED=1
+     let device = Device::new_metal(0)?;
+     let metal_device = match &device {
+         Device::Metal(m) => m,
+         _ => anyhow::bail!("unexpected device"),
+     };
+     metal_device.capture("/tmp/candle.gputrace")?;
+     // This first synchronize ensures that a new command buffer gets created after setting up the
+     // capture scope.
+     device.synchronize()?;
+     let x = Tensor::randn(0f32, 1.0, (128, 128), &device)?;
+     let x1 = x.add(&x)?;
+     println!("{x1:?}");
+     // This second synchronize ensures that the command buffer gets committed before the end of the
+     // capture scope.
+     device.synchronize()?;
+     Ok(())
+ }
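This example is meant to be launched with Metal capture enabled, e.g. MTL_CAPTURE_ENABLED=1 cargo run --example metal_basics --features metal; the required-features entry added above keeps cargo from building it without the metal feature, and the trace ends up at /tmp/candle.gputrace.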
@@ -320,13 +320,13 @@ impl Tensor {
  dilation,
  output_padding: _output_padding,
  } => {
- let grad_arg = grad.conv2d(kernel, *padding, *dilation, *stride, 1)?;
+ let grad_arg = grad.conv2d(kernel, *padding, *stride, *dilation, 1)?;
  let sum_grad = grads.or_insert(arg)?;
  *sum_grad = sum_grad.add(&grad_arg)?;

  let grad_kernel = grad
  .transpose(0, 1)?
- .conv2d(&arg.transpose(0, 1)?, *padding, *stride, *dilation, 1)?
+ .conv2d(&arg.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
  .transpose(0, 1)?;
  let sum_grad = grads.or_insert(kernel)?;
  let (_, _, k0, k1) = kernel.dims4()?;
@@ -623,9 +623,9 @@ impl Tensor {
  }
  Op::Unary(arg, UnaryOp::Silu) => {
  let sum_grad = grads.or_insert(arg)?;
- // d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
+ // d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x))) = sigmoid(x) * (1 - node) + node
  let sigmoid_arg = (arg.neg()?.exp()? + 1.)?.recip()?;
- let silu_grad = (&sigmoid_arg * (1. + (arg * (1. - &sigmoid_arg)?)?)?)?;
+ let silu_grad = &sigmoid_arg * (1. - *node) + *node;
  *sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
  }
  Op::Elu(arg, alpha) => {
@@ -634,7 +634,8 @@ impl Tensor {
  let zeros = arg.zeros_like()?;
  let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
  let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
- let negative_exp_mask = ((negative_mask * arg.exp())? * *alpha)?;
+ // node == alpha * (e^x - 1) for x <= 0, reuse it
+ let negative_exp_mask = (negative_mask * (*node + *alpha))?;
  let combined_mask = (positive_mask + negative_exp_mask)?;
  *sum_grad = sum_grad.add(&(grad * combined_mask)?)?
  }
@@ -755,4 +756,9 @@ impl GradStore {
  };
  Ok(grad)
  }
+
+ /// Get the tensor ids of the stored gradient tensors
+ pub fn get_ids(&self) -> impl Iterator<Item = &TensorId> {
+     self.0.keys()
+ }
  }
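The rewritten silu backward pass leans on an algebraic identity: the forward output node equals x sigma(x), so the derivative can be expressed in terms of node alone and the forward activations get reused instead of recomputed. A quick check of the identity quoted in the new comment:

$$ \frac{d}{dx}\bigl(x\,\sigma(x)\bigr) = \sigma(x) + x\,\sigma(x)\bigl(1-\sigma(x)\bigr) = \sigma(x)\bigl(1 - x\,\sigma(x)\bigr) + x\,\sigma(x) = \sigma(x)\,(1-\mathrm{node}) + \mathrm{node}. $$

The elu change follows the same pattern: for x <= 0 the forward output is node = alpha (e^x - 1), hence alpha e^x = node + alpha, which is exactly the factor the new negative_exp_mask multiplies in.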
@@ -1,6 +1,6 @@
  use crate::WithDType;
  use cudarc;
- use cudarc::cudnn::safe::{Conv2dForward, Cudnn};
+ use cudarc::cudnn::safe::{ConvForward, Cudnn};
  use cudarc::driver::{CudaSlice, CudaView, DeviceRepr, ValidAsZeroBits};
  use std::cell::RefCell;
  use std::collections::HashMap;
@@ -87,7 +87,7 @@ pub(crate) fn launch_conv2d<
  cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
  [params.b_size as i32, params.c_out as i32, h_out, w_out],
  )?;
- let conv2d = Conv2dForward {
+ let conv2d = ConvForward {
  conv: &conv,
  x: &x,
  w: &w,
@@ -174,6 +174,7 @@ impl Map1 for Im2Col1D
  }
  }

+ #[allow(unused)]
  struct Im2Col {
  h_k: usize,
  w_k: usize,
@@ -183,6 +184,7 @@ struct Im2Col {
  }

  impl Im2Col {
+ #[allow(unused)]
  fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
  let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
  let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
@@ -171,6 +171,22 @@ impl Device {
  matches!(self, Self::Metal(_))
  }

+ pub fn supports_bf16(&self) -> bool {
+     match self {
+         Self::Cuda(_) | Self::Metal(_) => true,
+         Self::Cpu => false,
+     }
+ }
+
+ /// Return `BF16` for devices that support it, otherwise default to `F32`.
+ pub fn bf16_default_to_f32(&self) -> DType {
+     if self.supports_bf16() {
+         DType::BF16
+     } else {
+         DType::F32
+     }
+ }
+
  pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
  if crate::utils::cuda_is_available() {
  Self::new_cuda(ordinal)
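A minimal sketch of how the new helper might be used when allocating activations, assuming the 0.7 API shown above (the zeros_best_dtype wrapper itself is hypothetical):

    use candle_core::{Device, Result, Tensor};

    // Hypothetical helper: allocate in BF16 where the device supports it,
    // falling back to F32 on the CPU backend.
    fn zeros_best_dtype(device: &Device) -> Result<Tensor> {
        let dtype = device.bf16_default_to_f32(); // BF16 on Cuda/Metal, F32 on Cpu
        Tensor::zeros((2, 3), dtype, device)
    }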
@@ -141,28 +141,117 @@ impl<T> IndexOp<T> for Tensor
  where
  T: Into<TensorIndexer>,
  {
+ ///```rust
+ /// use candle_core::{Tensor, DType, Device, IndexOp};
+ /// let a = Tensor::new(&[
+ ///     [0., 1.],
+ ///     [2., 3.],
+ ///     [4., 5.]
+ /// ], &Device::Cpu)?;
+ ///
+ /// let b = a.i(0)?;
+ /// assert_eq!(b.shape().dims(), &[2]);
+ /// assert_eq!(b.to_vec1::<f64>()?, &[0., 1.]);
+ ///
+ /// let c = a.i(..2)?;
+ /// assert_eq!(c.shape().dims(), &[2, 2]);
+ /// assert_eq!(c.to_vec2::<f64>()?, &[
+ ///     [0., 1.],
+ ///     [2., 3.]
+ /// ]);
+ ///
+ /// let d = a.i(1..)?;
+ /// assert_eq!(d.shape().dims(), &[2, 2]);
+ /// assert_eq!(d.to_vec2::<f64>()?, &[
+ ///     [2., 3.],
+ ///     [4., 5.]
+ /// ]);
+ /// # Ok::<(), candle_core::Error>(())
+ /// ```
  fn i(&self, index: T) -> Result<Tensor, Error> {
  self.index(&[index.into()])
  }
  }
+
+ impl<A> IndexOp<(A,)> for Tensor
+ where
+     A: Into<TensorIndexer>,
+ {
+ ///```rust
+ /// use candle_core::{Tensor, DType, Device, IndexOp};
+ /// let a = Tensor::new(&[
+ ///     [0f32, 1.],
+ ///     [2. , 3.],
+ ///     [4. , 5.]
+ /// ], &Device::Cpu)?;
+ ///
+ /// let b = a.i((0,))?;
+ /// assert_eq!(b.shape().dims(), &[2]);
+ /// assert_eq!(b.to_vec1::<f32>()?, &[0., 1.]);
+ ///
+ /// let c = a.i((..2,))?;
+ /// assert_eq!(c.shape().dims(), &[2, 2]);
+ /// assert_eq!(c.to_vec2::<f32>()?, &[
+ ///     [0., 1.],
+ ///     [2., 3.]
+ /// ]);
+ ///
+ /// let d = a.i((1..,))?;
+ /// assert_eq!(d.shape().dims(), &[2, 2]);
+ /// assert_eq!(d.to_vec2::<f32>()?, &[
+ ///     [2., 3.],
+ ///     [4., 5.]
+ /// ]);
+ /// # Ok::<(), candle_core::Error>(())
+ /// ```
+ fn i(&self, (a,): (A,)) -> Result<Tensor, Error> {
+     self.index(&[a.into()])
+ }
+ }
+ #[allow(non_snake_case)]
+ impl<A, B> IndexOp<(A, B)> for Tensor
+ where
+     A: Into<TensorIndexer>,
+     B: Into<TensorIndexer>,
+ {
+ ///```rust
+ /// use candle_core::{Tensor, DType, Device, IndexOp};
+ /// let a = Tensor::new(&[[0f32, 1., 2.], [3., 4., 5.], [6., 7., 8.]], &Device::Cpu)?;
+ ///
+ /// let b = a.i((1, 0))?;
+ /// assert_eq!(b.to_vec0::<f32>()?, 3.);
+ ///
+ /// let c = a.i((..2, 1))?;
+ /// assert_eq!(c.shape().dims(), &[2]);
+ /// assert_eq!(c.to_vec1::<f32>()?, &[1., 4.]);
+ ///
+ /// let d = a.i((2.., ..))?;
+ /// assert_eq!(c.shape().dims(), &[2]);
+ /// assert_eq!(c.to_vec1::<f32>()?, &[1., 4.]);
+ /// # Ok::<(), candle_core::Error>(())
+ /// ```
+ fn i(&self, (a, b): (A, B)) -> Result<Tensor, Error> {
+     self.index(&[a.into(), b.into()])
+ }
+ }

  macro_rules! index_op_tuple {
- ($($t:ident),+) => {
+ ($doc:tt, $($t:ident),+) => {
  #[allow(non_snake_case)]
  impl<$($t),*> IndexOp<($($t,)*)> for Tensor
  where
  $($t: Into<TensorIndexer>,)*
  {
+ #[doc=$doc]
  fn i(&self, ($($t,)*): ($($t,)*)) -> Result<Tensor, Error> {
  self.index(&[$($t.into(),)*])
  }
  }
  };
  }
- index_op_tuple!(A);
- index_op_tuple!(A, B);
- index_op_tuple!(A, B, C);
- index_op_tuple!(A, B, C, D);
- index_op_tuple!(A, B, C, D, E);
- index_op_tuple!(A, B, C, D, E, F);
- index_op_tuple!(A, B, C, D, E, F, G);
+ index_op_tuple!("see [TensorIndex#method.i]", A, B, C);
+ index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D);
+ index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E);
+ index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E, F);
+ index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E, F, G);
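The explicit 1- and 2-tuple impls above carry real doctests, while 3-tuples and beyond keep the macro. A small sketch exercising one of the macro-generated impls on a rank-3 tensor:

    use candle_core::{Device, IndexOp, Tensor};

    fn main() -> candle_core::Result<()> {
        let t = Tensor::arange(0f32, 24., &Device::Cpu)?.reshape((2, 3, 4))?;
        // Three indexers: handled by the index_op_tuple!(..., A, B, C) impl.
        let v = t.i((1, 0, 2))?;
        assert_eq!(v.to_vec0::<f32>()?, 14.); // 1*12 + 0*4 + 2
        Ok(())
    }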
@@ -65,6 +65,7 @@ pub mod scalar;
  pub mod shape;
  mod sort;
  mod storage;
+ pub mod streaming;
  mod strided_index;
  mod tensor;
  mod tensor_cat;
@@ -80,10 +81,11 @@ pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, Inp
  pub use device::{Device, DeviceLocation, NdArray};
  pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
  pub use error::{Error, Result};
- pub use indexer::IndexOp;
+ pub use indexer::{IndexOp, TensorIndexer};
  pub use layout::Layout;
  pub use shape::{Shape, D};
  pub use storage::Storage;
+ pub use streaming::{StreamTensor, StreamingBinOp, StreamingModule};
  pub use strided_index::{StridedBlocks, StridedIndex};
  pub use tensor::{Tensor, TensorId};
  pub use variable::Var;
@@ -4,7 +4,7 @@ use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}
  use std::collections::HashMap;
  use std::ffi::c_void;
  use std::path::Path;
- use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard};
+ use std::sync::{Arc, Mutex, RwLock};

  use super::MetalError;

@@ -22,7 +22,73 @@ impl DeviceId {
  }

  type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>;
- type AllocatedBuffers = Arc<RwLock<BufferMap>>;
+ pub(crate) struct Commands {
+     /// Single command queue for the entire device.
+     command_queue: CommandQueue,
+     /// One command buffer at a time.
+     /// The scheduler works by allowing multiple
+     /// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
+     /// on a single command buffer. Using a single command buffer would be fastest on the GPU but
+     /// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed
+     /// to start to work).
+     /// Despite what the documentation says, command buffers are NOT ordered. They are ordered
+     /// for their START time, but there's no guarantee that command buffer1 will finish before
+     /// command buffer2 starts (or there are metal bugs there)
+     command_buffer: CommandBuffer,
+     /// Keeps track of the current amount of compute command encoders on the current
+     /// command buffer
+     /// Arc, RwLock because of the interior mutability.
+     command_buffer_index: usize,
+     /// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
+     compute_per_buffer: usize,
+ }
+
+ impl Commands {
+     pub(crate) fn new(command_queue: CommandQueue) -> Result<Self> {
+         let command_buffer = command_queue.new_command_buffer().to_owned();
+         command_buffer.enqueue();
+         let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
+             Ok(val) => val.parse()?,
+             _ => 50,
+         };
+         Ok(Self {
+             command_queue,
+             command_buffer,
+             command_buffer_index: 0,
+             compute_per_buffer,
+         })
+     }
+
+     pub fn command_buffer(&mut self) -> Result<(bool, CommandBuffer)> {
+         let mut command_buffer = self.command_buffer.to_owned();
+         let mut flushed = false;
+         if self.command_buffer_index > self.compute_per_buffer {
+             self.command_buffer.commit();
+             command_buffer = self.command_queue.new_command_buffer().to_owned();
+             self.command_buffer = command_buffer.clone();
+             self.command_buffer_index = 0;
+             flushed = true;
+         }
+         self.command_buffer_index += 1;
+         Ok((flushed, command_buffer))
+     }
+
+     pub fn wait_until_completed(&mut self) -> Result<()> {
+         match self.command_buffer.status() {
+             metal::MTLCommandBufferStatus::Committed
+             | metal::MTLCommandBufferStatus::Scheduled
+             | metal::MTLCommandBufferStatus::Completed => {
+                 panic!("Already committed");
+             }
+             _ => {}
+         }
+         self.command_buffer.commit();
+         self.command_buffer.wait_until_completed();
+         self.command_buffer = self.command_queue.new_command_buffer().to_owned();
+
+         Ok(())
+     }
+ }
+
  #[derive(Clone)]
  pub struct MetalDevice {
@@ -33,27 +99,8 @@ pub struct MetalDevice {
  /// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
  pub(crate) device: metal::Device,

- /// Single command queue for the entire device.
- pub(crate) command_queue: CommandQueue,
- /// One command buffer at a time.
- /// The scheduler works by allowing multiple
- /// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
- /// on a single command buffer. Using a single command buffer would be fastest on the GPU but
- /// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed
- /// to start to work).
- /// Despite what the documentation says, command buffers are NOT ordered. They are ordered
- /// for their START time, but there's no guarantee that command buffer1 will finish before
- /// command buffer2 starts (or there are metal bugs there)
- pub(crate) command_buffer: Arc<RwLock<CommandBuffer>>,
- /// Keeps track of the current amount of compute command encoders on the current
- /// command buffer
- /// Arc, RwLock because of the interior mutability.
- pub(crate) command_buffer_index: Arc<RwLock<usize>>,
- /// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
- pub(crate) compute_per_buffer: usize,
- /// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
- /// Heavily used by [`candle_metal_kernels`]
- pub(crate) kernels: Arc<Kernels>,
+ pub(crate) commands: Arc<RwLock<Commands>>,
  /// Simple allocator struct.
  /// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
  /// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting
@@ -67,9 +114,15 @@ pub struct MetalDevice {
  ///
  /// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers
  /// (strong_count = 1).
- pub(crate) buffers: AllocatedBuffers,
+ pub(crate) buffers: Arc<RwLock<BufferMap>>,
+
+ /// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
+ /// Heavily used by [`candle_metal_kernels`]
+ pub(crate) kernels: Arc<Kernels>,
  /// Seed for random number generation.
  pub(crate) seed: Arc<Mutex<Buffer>>,
+ /// Whether to use the MLX matmul kernels instead of the MFA ones.
+ pub(crate) use_mlx_mm: bool,
  }

  impl std::fmt::Debug for MetalDevice {
@@ -87,6 +140,10 @@ impl std::ops::Deref for MetalDevice {
  }

  impl MetalDevice {
+ pub fn set_use_mlx_mm(&mut self, use_mlx_mm: bool) {
+     self.use_mlx_mm = use_mlx_mm
+ }
+
  pub fn id(&self) -> DeviceId {
  self.id
  }
@@ -95,44 +152,31 @@ impl MetalDevice {
  &self.device
  }

- pub fn command_queue(&self) -> &CommandQueue {
-     &self.command_queue
- }
+ fn drop_unused_buffers(&self) -> Result<()> {
+     let mut buffers = self.buffers.write().map_err(MetalError::from)?;
+     for subbuffers in buffers.values_mut() {
+         let newbuffers = subbuffers
+             .iter()
+             .filter(|s| Arc::strong_count(*s) > 1)
+             .map(Arc::clone)
+             .collect();
+         *subbuffers = newbuffers;
+     }
+     Ok(())
+ }

  pub fn command_buffer(&self) -> Result<CommandBuffer> {
-     let mut command_buffer_lock = self.command_buffer.write().map_err(MetalError::from)?;
-     let mut command_buffer = command_buffer_lock.to_owned();
-     let mut index = self
-         .command_buffer_index
-         .write()
-         .map_err(MetalError::from)?;
-     if *index > self.compute_per_buffer {
-         command_buffer.commit();
-         command_buffer = self.command_queue.new_command_buffer().to_owned();
-         *command_buffer_lock = command_buffer.clone();
-         *index = 0;
-         self.drop_unused_buffers()?;
+     let mut commands = self.commands.write().map_err(MetalError::from)?;
+     let (flushed, command_buffer) = commands.command_buffer()?;
+     if flushed {
+         self.drop_unused_buffers()?
      }
-     *index += 1;
      Ok(command_buffer)
  }

  pub fn wait_until_completed(&self) -> Result<()> {
-     let mut command_buffer = self.command_buffer.write().map_err(MetalError::from)?;
-     match command_buffer.status() {
-         metal::MTLCommandBufferStatus::Committed
-         | metal::MTLCommandBufferStatus::Scheduled
-         | metal::MTLCommandBufferStatus::Completed => {
-             panic!("Already committed");
-         }
-         _ => {}
-     }
-     command_buffer.commit();
-     command_buffer.wait_until_completed();
-     *command_buffer = self.command_queue.new_command_buffer().to_owned();
-
-     Ok(())
+     let mut commands = self.commands.write().map_err(MetalError::from)?;
+     commands.wait_until_completed()
  }

  pub fn kernels(&self) -> &Kernels {
@@ -180,6 +224,7 @@ impl MetalDevice {
  MTLResourceOptions::StorageModeManaged,
  );
  let mut buffers = self.buffers.write().map_err(MetalError::from)?;
+
  let subbuffers = buffers
  .entry((size, MTLResourceOptions::StorageModeManaged))
  .or_insert(vec![]);
@@ -210,40 +255,6 @@ impl MetalDevice {
  Ok(buffer)
  }

- fn find_available_buffer(
-     &self,
-     size: NSUInteger,
-     option: MTLResourceOptions,
-     buffers: &RwLockWriteGuard<BufferMap>,
- ) -> Option<Arc<Buffer>> {
-     let mut best_buffer: Option<&Arc<Buffer>> = None;
-     let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
-     for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
-         if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
-             for sub in subbuffers {
-                 if Arc::strong_count(sub) == 1 {
-                     best_buffer = Some(sub);
-                     best_buffer_size = *buffer_size;
-                 }
-             }
-         }
-     }
-     best_buffer.cloned()
- }
-
- fn drop_unused_buffers(&self) -> Result<()> {
-     let mut buffers = self.buffers.write().map_err(MetalError::from)?;
-     for subbuffers in buffers.values_mut() {
-         let newbuffers = subbuffers
-             .iter()
-             .filter(|s| Arc::strong_count(*s) > 1)
-             .map(Arc::clone)
-             .collect();
-         *subbuffers = newbuffers;
-     }
-     Ok(())
- }
-
  /// The critical allocator algorithm
  fn allocate_buffer(
  &self,
@@ -252,7 +263,7 @@ impl MetalDevice {
  _name: &str,
  ) -> Result<Arc<Buffer>> {
  let mut buffers = self.buffers.write().map_err(MetalError::from)?;
- if let Some(b) = self.find_available_buffer(size, option, &buffers) {
+ if let Some(b) = find_available_buffer(size, option, &buffers) {
  // Cloning also ensures we increment the strong count
  return Ok(b.clone());
  }
@@ -273,7 +284,13 @@ impl MetalDevice {
  let descriptor = metal::CaptureDescriptor::new();
  descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
  descriptor.set_capture_device(self);
- descriptor.set_output_url(path);
+ // The [set_output_url] call requires an absolute path so we convert it if needed.
+ if path.as_ref().is_absolute() {
+     descriptor.set_output_url(path);
+ } else {
+     let path = std::env::current_dir()?.join(path);
+     descriptor.set_output_url(path);
+ }

  capture
  .start_capture(&descriptor)
@@ -285,3 +302,23 @@ impl MetalDevice {
  fn buf_size(size: NSUInteger) -> NSUInteger {
  size.saturating_sub(1).next_power_of_two() as NSUInteger
  }
+
+ fn find_available_buffer(
+     size: NSUInteger,
+     option: MTLResourceOptions,
+     buffers: &BufferMap,
+ ) -> Option<Arc<Buffer>> {
+     let mut best_buffer: Option<&Arc<Buffer>> = None;
+     let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
+     for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
+         if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
+             for sub in subbuffers {
+                 if Arc::strong_count(sub) == 1 {
+                     best_buffer = Some(sub);
+                     best_buffer_size = *buffer_size;
+                 }
+             }
+         }
+     }
+     best_buffer.cloned()
+ }
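The allocator relies on Arc::strong_count to tell free buffers from live ones: a pooled buffer whose only owner is the pool itself has a strong count of exactly 1 and may be handed out again. A minimal illustration of that rule, with Vec<u8> standing in for a Metal Buffer:

    use std::sync::Arc;

    fn main() {
        let pool: Vec<Arc<Vec<u8>>> = vec![Arc::new(vec![0u8; 1024])];
        let lease = Arc::clone(&pool[0]); // a tensor currently using the buffer
        assert_eq!(Arc::strong_count(&pool[0]), 2); // in use, not eligible for reuse
        drop(lease);
        assert_eq!(Arc::strong_count(&pool[0]), 1); // free: find_available_buffer may return it
    }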
@@ -119,6 +119,8 @@ impl BackendStorage for MetalStorage {
  DType::F32 => "affine_f32",
  DType::F16 => "affine_f16",
  DType::BF16 => "affine_bf16",
+ DType::U8 => "affine_u8",
+ DType::U32 => "affine_u32",
  dtype => crate::bail!("Metal contiguous affine {dtype:?} not implemented"),
  };
  candle_metal_kernels::call_affine(
@@ -410,17 +412,42 @@ impl BackendStorage for MetalStorage {
  .map_err(MetalError::from)?;
  } else {
  let kernel_name = match (self.dtype, dtype) {
+ (DType::BF16, DType::F16) => "cast_bf16_f16_strided",
+ (DType::BF16, DType::F32) => "cast_bf16_f32_strided",
+ (DType::BF16, DType::I64) => "cast_bf16_i64_strided",
+ (DType::BF16, DType::U32) => "cast_bf16_u32_strided",
+ (DType::BF16, DType::U8) => "cast_bf16_u8_strided",
+
+ (DType::F16, DType::BF16) => "cast_f16_bf16_strided",
+ (DType::F16, DType::F32) => "cast_f16_f32_strided",
+ (DType::F16, DType::I64) => "cast_f16_i64_strided",
+ (DType::F16, DType::U32) => "cast_f16_u32_strided",
+ (DType::F16, DType::U8) => "cast_f16_u8_strided",
+
+ (DType::F32, DType::BF16) => "cast_f32_bf16_strided",
+ (DType::F32, DType::F16) => "cast_f32_f16_strided",
+ (DType::F32, DType::I64) => "cast_f32_i64_strided",
+ (DType::F32, DType::U32) => "cast_f32_u32_strided",
+ (DType::F32, DType::U8) => "cast_f32_u8_strided",
+
+ (DType::I64, DType::F32) => "cast_i64_f32_strided",
+ (DType::I64, DType::BF16) => "cast_i64_bf16_strided",
+ (DType::I64, DType::F16) => "cast_i64_f16_strided",
+ (DType::I64, DType::U32) => "cast_i64_u32_strided",
+ (DType::I64, DType::U8) => "cast_i64_u8_strided",
+
+ (DType::U32, DType::BF16) => "cast_u32_bf16_strided",
+ (DType::U32, DType::F16) => "cast_u32_f16_strided",
  (DType::U32, DType::F32) => "cast_u32_f32_strided",
  (DType::U32, DType::I64) => "cast_u32_i64_strided",
- (DType::U8, DType::U32) => "cast_u8_u32_strided",
+ (DType::U32, DType::U8) => "cast_u32_u8_strided",
+
+ (DType::U8, DType::BF16) => "cast_u8_bf16_strided",
+ (DType::U8, DType::F16) => "cast_u8_f16_strided",
  (DType::U8, DType::F32) => "cast_u8_f32_strided",
  (DType::U8, DType::I64) => "cast_u8_i64_strided",
- (DType::F32, DType::F16) => "cast_f32_f16_strided",
- (DType::F16, DType::F32) => "cast_f16_f32_strided",
- (DType::I64, DType::F32) => "cast_i64_f32_strided",
- (DType::F32, DType::BF16) => "cast_f32_bf16_strided",
- (DType::BF16, DType::F32) => "cast_bf16_f32_strided",
+ (DType::U8, DType::U32) => "cast_u8_u32_strided",
  (left, right) => {
  crate::bail!("Metal strided to_dtype {left:?} {right:?} not implemented")
  }
@@ -1396,6 +1423,7 @@ impl BackendStorage for MetalStorage {
  .map_err(MetalError::from)?;
  Ok(acc)
  }
+
  fn matmul(
  &self,
  rhs: &Self,
@@ -1404,31 +1432,78 @@ impl BackendStorage for MetalStorage {
  rhs_l: &Layout,
  ) -> Result<Self> {
  let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?;
  let command_buffer = self.device.command_buffer()?;
  command_buffer.set_label("matmul");
- let name = match self.dtype {
-     DType::F32 => "sgemm",
-     DType::F16 => "hgemm",
-     dtype => {
-         return Err(MetalError::Message(format!("matmul doesn't support {dtype:?}")).into())
-     }
- };
-
- candle_metal_kernels::call_gemm(
-     &self.device.device,
-     &command_buffer,
-     &self.device.kernels,
-     name,
-     (b, m, n, k),
-     lhs_l.stride(),
-     lhs_l.start_offset() * self.dtype.size_in_bytes(),
-     &self.buffer,
-     rhs_l.stride(),
-     rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
-     &rhs.buffer,
-     &buffer,
- )
- .map_err(MetalError::from)?;
+ if self.dtype == DType::BF16 {
+     candle_metal_kernels::call_mlx_gemm(
+         &self.device.device,
+         &command_buffer,
+         &self.device.kernels,
+         candle_metal_kernels::GemmDType::BF16,
+         (b, m, n, k),
+         lhs_l.stride(),
+         lhs_l.start_offset() * self.dtype.size_in_bytes(),
+         &self.buffer,
+         rhs_l.stride(),
+         rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
+         &rhs.buffer,
+         &buffer,
+     )
+     .map_err(MetalError::from)?;
+ } else if self.device.use_mlx_mm {
+     let dtype = match self.dtype {
+         DType::F32 => candle_metal_kernels::GemmDType::F32,
+         DType::F16 => candle_metal_kernels::GemmDType::F16,
+         DType::BF16 => candle_metal_kernels::GemmDType::BF16,
+         dtype => {
+             return Err(MetalError::Message(format!(
+                 "mlx matmul doesn't support {dtype:?}"
+             ))
+             .into())
+         }
+     };
+     candle_metal_kernels::call_mlx_gemm(
+         &self.device.device,
+         &command_buffer,
+         &self.device.kernels,
+         dtype,
+         (b, m, n, k),
+         lhs_l.stride(),
+         lhs_l.start_offset() * self.dtype.size_in_bytes(),
+         &self.buffer,
+         rhs_l.stride(),
+         rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
+         &rhs.buffer,
+         &buffer,
+     )
+     .map_err(MetalError::from)?;
+ } else {
+     let name = match self.dtype {
+         DType::F32 => "sgemm",
+         DType::F16 => "hgemm",
+         dtype => {
+             return Err(
+                 MetalError::Message(format!("matmul doesn't support {dtype:?}")).into(),
+             )
+         }
+     };
+
+     candle_metal_kernels::call_gemm(
+         &self.device.device,
+         &command_buffer,
+         &self.device.kernels,
+         name,
+         (b, m, n, k),
+         lhs_l.stride(),
+         lhs_l.start_offset() * self.dtype.size_in_bytes(),
+         &self.buffer,
+         rhs_l.stride(),
+         rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
+         &rhs.buffer,
+         &buffer,
+     )
+     .map_err(MetalError::from)?;
+ }
  Ok(Self::new(
  buffer,
  self.device.clone(),
|
|||||||
fn new(ordinal: usize) -> Result<Self> {
|
fn new(ordinal: usize) -> Result<Self> {
|
||||||
let device = metal::Device::all().swap_remove(ordinal);
|
let device = metal::Device::all().swap_remove(ordinal);
|
||||||
let command_queue = device.new_command_queue();
|
let command_queue = device.new_command_queue();
|
||||||
let command_buffer = command_queue.new_command_buffer().to_owned();
|
|
||||||
command_buffer.enqueue();
|
|
||||||
let command_buffer = Arc::new(RwLock::new(command_buffer));
|
|
||||||
let command_buffer_index = Arc::new(RwLock::new(0));
|
|
||||||
let kernels = Arc::new(Kernels::new());
|
let kernels = Arc::new(Kernels::new());
|
||||||
let buffers = Arc::new(RwLock::new(HashMap::new()));
|
let use_mlx_mm = match std::env::var("CANDLE_USE_MLX_MM").as_deref() {
|
||||||
let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
|
Ok("false") | Ok("False") | Ok("FALSE") | Ok("0") | Err(_) => false,
|
||||||
Ok(val) => val.parse()?,
|
Ok(_) => true,
|
||||||
_ => 50,
|
|
||||||
};
|
};
|
||||||
let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
|
let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
|
||||||
[299792458].as_ptr() as *const c_void,
|
[299792458].as_ptr() as *const c_void,
|
||||||
4,
|
4,
|
||||||
MTLResourceOptions::StorageModeManaged,
|
MTLResourceOptions::StorageModeManaged,
|
||||||
)));
|
)));
|
||||||
|
let commands = device::Commands::new(command_queue)?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
id: DeviceId::new(),
|
id: DeviceId::new(),
|
||||||
device,
|
device,
|
||||||
command_queue,
|
commands: Arc::new(RwLock::new(commands)),
|
||||||
command_buffer,
|
buffers: Arc::new(RwLock::new(HashMap::new())),
|
||||||
command_buffer_index,
|
|
||||||
compute_per_buffer,
|
|
||||||
buffers,
|
|
||||||
kernels,
|
kernels,
|
||||||
seed,
|
seed,
|
||||||
|
use_mlx_mm,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
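With this wiring in place the MLX matmul kernels can be toggled at runtime: leaving CANDLE_USE_MLX_MM unset (or setting it to false/0) keeps the MFA sgemm/hgemm path for F32 and F16, any other value routes them through call_mlx_gemm, and BF16 matmuls always use the MLX kernels regardless of the flag.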
@@ -34,7 +34,10 @@ fn ceil_div(p: usize, q: usize) -> usize {
  }

  fn pad(p: usize, q: usize) -> usize {
- ceil_div(p, q) * q
+ // Overallocate by q rather than just padding by q as this should pad the last row
+ // and we don't have enough information here to know how many elements to add :(
+ // ceil_div(p, q) * q
+ p + q
  }

  fn quantize_q8_1(
@@ -439,7 +442,7 @@ impl QCudaStorage {
  }
  _ => crate::bail!("only f32 can be quantized"),
  };
- let src_len = src.len();
+ let src_len = pad(src.len(), MATRIX_ROW_PADDING);
  let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
  let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?;
  qcpu_storage.quantize(&src)?;
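The numbers below illustrate the difference between the two padding rules (values are illustrative, and ceil_div is sketched as a standard ceiling division since its body is not shown above): rounding up to a multiple of q adds nothing when p is already aligned, while the new rule always reserves a full extra q elements for the padded last row.

    fn ceil_div(p: usize, q: usize) -> usize {
        (p + q - 1) / q
    }

    fn main() {
        let (p, q) = (1024, 512); // p already a multiple of q
        assert_eq!(ceil_div(p, q) * q, 1024); // old rule: no extra room
        assert_eq!(p + q, 1536); // new rule: q extra elements, always
    }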
@@ -18,7 +18,7 @@ pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>(
  let actual_blocks = ys.len();

  // Validate that the input is the right size
- if expected_blocks != actual_blocks {
+ if actual_blocks < expected_blocks {
  crate::bail!("quantize {dtype:?}: expected {expected_blocks} blocks but only {actual_blocks} were provided!")
  }

@@ -304,6 +304,7 @@ impl Dim for usize {
  pub enum D {
  Minus1,
  Minus2,
+ Minus(usize),
  }

  impl D {
@@ -311,6 +312,7 @@ impl D {
  let dim = match self {
  Self::Minus1 => -1,
  Self::Minus2 => -2,
+ Self::Minus(u) => -(*u as i32),
  };
  Error::DimOutOfRange {
  shape: shape.clone(),
@@ -327,6 +329,7 @@ impl Dim for D {
  match self {
  Self::Minus1 if rank >= 1 => Ok(rank - 1),
  Self::Minus2 if rank >= 2 => Ok(rank - 2),
+ Self::Minus(u) if *u > 0 && rank >= *u => Ok(rank - *u),
  _ => Err(self.out_of_range(shape, op)),
  }
  }
@@ -336,6 +339,7 @@ impl Dim for D {
  match self {
  Self::Minus1 => Ok(rank),
  Self::Minus2 if rank >= 1 => Ok(rank - 1),
+ Self::Minus(u) if *u > 0 && rank + 1 >= *u => Ok(rank + 1 - *u),
  _ => Err(self.out_of_range(shape, op)),
  }
  }
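A short sketch of the generalized accessor, assuming the candle_core exports shown earlier:

    use candle_core::{DType, Device, Tensor, D};

    fn main() -> candle_core::Result<()> {
        let t = Tensor::zeros((2, 3, 4), DType::F32, &Device::Cpu)?;
        assert_eq!(t.dim(D::Minus1)?, 4); // unchanged behavior
        assert_eq!(t.dim(D::Minus(2))?, 3); // same as D::Minus2
        assert_eq!(t.dim(D::Minus(3))?, 2); // now expressible for any offset
        Ok(())
    }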
candle-core/src/streaming.rs (new file, 206 lines)
@@ -0,0 +1,206 @@
use crate::{Result, Shape, Tensor};

pub trait Dim: crate::shape::Dim + Copy {}
impl<T: crate::shape::Dim + Copy> Dim for T {}

/// A stream tensor is used in streaming modules. It can either contain an actual tensor or be
/// empty.
#[derive(Clone)]
pub struct StreamTensor(Option<Tensor>);

impl std::fmt::Debug for StreamTensor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match &self.0 {
            Some(t) => write!(f, "{:?}", t.shape()),
            None => write!(f, "Empty"),
        }
    }
}

impl std::convert::From<Option<Tensor>> for StreamTensor {
    fn from(value: Option<Tensor>) -> Self {
        Self(value)
    }
}

impl std::convert::From<Tensor> for StreamTensor {
    fn from(value: Tensor) -> Self {
        Self(Some(value))
    }
}

impl std::convert::From<()> for StreamTensor {
    fn from(_value: ()) -> Self {
        Self(None)
    }
}

impl StreamTensor {
    pub fn empty() -> Self {
        Self(None)
    }

    pub fn from_tensor(tensor: Tensor) -> Self {
        Self(Some(tensor))
    }

    pub fn shape(&self) -> Option<&Shape> {
        self.0.as_ref().map(|t| t.shape())
    }

    pub fn cat2<D: Dim>(&self, rhs: &Self, dim: D) -> Result<Self> {
        let xs = match (&self.0, &rhs.0) {
            (Some(lhs), Some(rhs)) => {
                let xs = Tensor::cat(&[lhs, rhs], dim)?;
                Some(xs)
            }
            (Some(xs), None) | (None, Some(xs)) => Some(xs.clone()),
            (None, None) => None,
        };
        Ok(Self(xs))
    }

    pub fn seq_len<D: Dim>(&self, dim: D) -> Result<usize> {
        match &self.0 {
            None => Ok(0),
            Some(v) => v.dim(dim),
        }
    }

    pub fn reset(&mut self) {
        self.0 = None
    }

    pub fn narrow<D: Dim>(&self, dim: D, offset: usize, len: usize) -> Result<StreamTensor> {
        let t = match &self.0 {
            None => None,
            Some(t) => {
                let seq_len = t.dim(dim)?;
                if seq_len <= offset {
                    None
                } else {
                    let t = t.narrow(dim, offset, usize::min(len, seq_len - offset))?;
                    Some(t)
                }
            }
        };
        Ok(Self(t))
    }

    /// Splits the Streaming Tensor on the time axis `dim` with the first `lhs_len` elements
    /// returned in the first output and the remaining in the second output.
    pub fn split<D: Dim>(&self, dim: D, lhs_len: usize) -> Result<(Self, Self)> {
        match &self.0 {
            None => Ok((Self::empty(), Self::empty())),
            Some(t) => {
                let seq_len = t.dim(dim)?;
                let lhs_len = usize::min(seq_len, lhs_len);
                if lhs_len == 0 {
                    Ok((Self::empty(), t.clone().into()))
                } else {
                    let lhs = Self::from_tensor(t.narrow(dim, 0, lhs_len)?);
                    let rhs_len = seq_len - lhs_len;
                    let rhs = if rhs_len == 0 {
                        Self::empty()
                    } else {
                        Self::from_tensor(t.narrow(dim, lhs_len, rhs_len)?)
                    };
                    Ok((lhs, rhs))
                }
            }
        }
    }

    pub fn as_option(&self) -> Option<&Tensor> {
        self.0.as_ref()
    }

    pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
        match &self.0 {
            None => Ok(Self::empty()),
            Some(t) => Ok(Self::from_tensor(t.apply(m)?)),
        }
    }
}

/// Streaming modules take as input a stream tensor and return a stream tensor. They may perform
/// some internal buffering so that enough data has been received for the module to be able to
/// perform some operations.
pub trait StreamingModule {
    // TODO: Should we also have a flush method?
    fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor>;
    fn reset_state(&mut self);
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinOp {
    Add,
    Mul,
    Sub,
    Div,
}

#[derive(Debug, Clone)]
pub struct StreamingBinOp {
    prev_lhs: StreamTensor,
    prev_rhs: StreamTensor,
    pub op: BinOp,
    pub dim: crate::D,
}

impl StreamingBinOp {
    pub fn new(op: BinOp, dim: crate::D) -> Self {
        Self {
            prev_lhs: StreamTensor::empty(),
            prev_rhs: StreamTensor::empty(),
            op,
            dim,
        }
    }

    pub fn reset_state(&mut self) {
        self.prev_lhs.reset();
        self.prev_rhs.reset();
    }

    pub fn forward(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor> {
        match self.op {
            BinOp::Add => Tensor::add(lhs, rhs),
            BinOp::Mul => Tensor::mul(lhs, rhs),
            BinOp::Sub => Tensor::sub(lhs, rhs),
            BinOp::Div => Tensor::div(lhs, rhs),
        }
    }

    pub fn step(&mut self, lhs: &StreamTensor, rhs: &StreamTensor) -> Result<StreamTensor> {
        let lhs = StreamTensor::cat2(&self.prev_lhs, lhs, self.dim)?;
        let rhs = StreamTensor::cat2(&self.prev_rhs, rhs, self.dim)?;
        let lhs_len = lhs.seq_len(self.dim)?;
        let rhs_len = rhs.seq_len(self.dim)?;
        let common_len = usize::min(lhs_len, rhs_len);
        let (lhs, prev_lhs) = lhs.split(self.dim, common_len)?;
        let (rhs, prev_rhs) = rhs.split(self.dim, common_len)?;
        let ys = match (lhs.0, rhs.0) {
            (Some(lhs), Some(rhs)) => {
                let ys = self.forward(&lhs, &rhs)?;
                StreamTensor::from_tensor(ys)
            }
            (None, None) => StreamTensor::empty(),
            (lhs, rhs) => crate::bail!("INTERNAL ERROR inconsistent lhs and rhs {lhs:?} {rhs:?}"),
        };
        self.prev_lhs = prev_lhs;
        self.prev_rhs = prev_rhs;
        Ok(ys)
    }
}

/// Simple wrapper that doesn't do any buffering.
pub struct Map<T: crate::Module>(T);

impl<T: crate::Module> StreamingModule for Map<T> {
    fn reset_state(&mut self) {}

    fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
        xs.apply(&self.0)
    }
}
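A minimal sketch of how the streaming binary op buffers mismatched chunk lengths, assuming the new module is exported as `candle_core::streaming`:

```rust
use candle_core::streaming::{BinOp, StreamTensor, StreamingBinOp};
use candle_core::{Device, Result, Tensor, D};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let mut add = StreamingBinOp::new(BinOp::Add, D::Minus1);
    // Feed chunks of different lengths; the op buffers the length mismatch.
    let lhs = StreamTensor::from_tensor(Tensor::new(&[1f32, 2., 3.], &dev)?);
    let rhs = StreamTensor::from_tensor(Tensor::new(&[10f32, 20.], &dev)?);
    let ys = add.step(&lhs, &rhs)?;
    // Only the two overlapping positions are produced; the third lhs element
    // stays buffered until more rhs data arrives.
    assert_eq!(ys.as_option().unwrap().to_vec1::<f32>()?, [11., 22.]);
    Ok(())
}
```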
@@ -370,6 +370,15 @@ impl Tensor {
 
     /// Returns a new tensor with all the elements having the same specified value. Note that
     /// the tensor is not contiguous so you would have to call `.contiguous()` on it if needed.
+    ///```rust
+    /// use candle_core::{Tensor, Device};
+    /// let a = Tensor::full(3.5, (2, 4), &Device::Cpu)?;
+
+    /// assert_eq!(a.to_vec2::<f64>()?, &[
+    ///     [3.5, 3.5, 3.5, 3.5],
+    ///     [3.5, 3.5, 3.5, 3.5],
+    /// ]);
+    /// # Ok::<(), candle_core::Error>(())
+    /// ```
     pub fn full<D: crate::WithDType, S: Into<Shape>>(
         value: D,
         shape: S,
@@ -379,6 +388,13 @@ impl Tensor {
     }
 
     /// Creates a new 1D tensor from an iterator.
+    ///```rust
+    /// use candle_core::{Tensor, Device};
+    /// let a = Tensor::from_iter( [1.0, 2.0, 3.0, 4.0].into_iter(), &Device::Cpu)?;
+    ///
+    /// assert_eq!(a.to_vec1::<f64>()?, &[1.0, 2.0, 3.0, 4.0]);
+    /// # Ok::<(), candle_core::Error>(())
+    /// ```
     pub fn from_iter<D: crate::WithDType>(
         iter: impl IntoIterator<Item = D>,
         device: &Device,
@@ -390,12 +406,26 @@ impl Tensor {
 
     /// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common
     /// difference `1` from `start`.
+    ///```rust
+    /// use candle_core::{Tensor, Device};
+    /// let a = Tensor::arange(2., 5., &Device::Cpu)?;
+    ///
+    /// assert_eq!(a.to_vec1::<f64>()?, &[2., 3., 4.]);
+    /// # Ok::<(), candle_core::Error>(())
+    /// ```
     pub fn arange<D: crate::WithDType>(start: D, end: D, device: &Device) -> Result<Self> {
         Self::arange_step(start, end, D::one(), device)
     }
 
     /// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common
     /// difference `step` from `start`.
+    ///```rust
+    /// use candle_core::{Tensor, Device};
+    /// let a = Tensor::arange_step(2.0, 4.0, 0.5, &Device::Cpu)?;
+    ///
+    /// assert_eq!(a.to_vec1::<f64>()?, &[2.0, 2.5, 3.0, 3.5]);
+    /// # Ok::<(), candle_core::Error>(())
+    /// ```
     pub fn arange_step<D: crate::WithDType>(
         start: D,
         end: D,
@@ -441,6 +471,16 @@ impl Tensor {
     /// Creates a new tensor initialized with values from the input vector. The number of elements
     /// in this vector must be the same as the number of elements defined by the shape.
     /// If the device is cpu, no data copy is made.
+    ///```rust
+    /// use candle_core::{Tensor, Device};
+    /// let a = Tensor::from_vec(vec!{1., 2., 3., 4., 5., 6.}, (2, 3), &Device::Cpu)?;
+    ///
+    /// assert_eq!(a.to_vec2::<f64>()?, &[
+    ///     [1., 2., 3.],
+    ///     [4., 5., 6.]
+    /// ]);
+    /// # Ok::<(), candle_core::Error>(())
+    /// ```
     pub fn from_vec<S: Into<Shape>, D: crate::WithDType>(
         data: Vec<D>,
         shape: S,
@@ -451,6 +491,17 @@ impl Tensor {
 
     /// Creates a new tensor initialized with values from the input slice. The number of elements
     /// in this vector must be the same as the number of elements defined by the shape.
+    ///```rust
+    /// use candle_core::{Tensor, Device};
+    /// let values = vec![1., 2., 3., 4., 5., 6., 7., 8.];
+    /// let a = Tensor::from_slice(&values[1..7], (2, 3), &Device::Cpu)?;
+    ///
+    /// assert_eq!(a.to_vec2::<f64>()?, &[
+    ///     [2., 3., 4.],
+    ///     [5., 6., 7.]
+    /// ]);
+    /// # Ok::<(), candle_core::Error>(())
+    /// ```
     pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
         array: &[D],
         shape: S,
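Each doctest above ends with `# Ok::<(), candle_core::Error>(())`: in rustdoc, lines starting with `#` are compiled and run but hidden from the rendered page, which gives the example a `Result` context so the `?` operator can be used. A minimal illustration of the pattern (the `documented` item is a hypothetical placeholder):

```rust
/// ```rust
/// use candle_core::{Device, Tensor};
/// let t = Tensor::arange(0f64, 3., &Device::Cpu)?; // `?` needs a Result context...
/// # Ok::<(), candle_core::Error>(()) // ...supplied by this hidden line
/// ```
pub fn documented() {}
```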
@@ -590,9 +641,9 @@ impl Tensor {
     ///
     /// * `args` - A slice of 1D tensors.
     /// * `xy_indexing` - Whether to use xy indexing or ij indexing. If xy is selected, the
     ///   first dimension corresponds to the cardinality of the second input and the second
     ///   dimension corresponds to the cardinality of the first input. If ij is selected, the
     ///   dimensions are in the same order as the cardinality of the inputs.
     ///
     /// # Examples
     ///
@@ -732,6 +783,30 @@ impl Tensor {
 
     /// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
     /// ranges from `start` to `start + len`.
+    /// ```
+    /// use candle_core::{Tensor, Device};
+    /// let a = Tensor::new(&[
+    ///     [0f32, 1., 2.],
+    ///     [3. , 4., 5.],
+    ///     [6. , 7., 8.]
+    /// ], &Device::Cpu)?;
+    ///
+    /// let b = a.narrow(0, 1, 2)?;
+    /// assert_eq!(b.shape().dims(), &[2, 3]);
+    /// assert_eq!(b.to_vec2::<f32>()?, &[
+    ///     [3., 4., 5.],
+    ///     [6., 7., 8.]
+    /// ]);
+    ///
+    /// let c = a.narrow(1, 1, 1)?;
+    /// assert_eq!(c.shape().dims(), &[3, 1]);
+    /// assert_eq!(c.to_vec2::<f32>()?, &[
+    ///     [1.],
+    ///     [4.],
+    ///     [7.]
+    /// ]);
+    /// # Ok::<(), candle_core::Error>(())
+    /// ```
     pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
         let dims = self.dims();
         let dim = dim.to_index(self.shape(), "narrow")?;
@@ -1950,7 +2025,11 @@ impl Tensor {
             }
             (Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
             _ => {
-                bail!("not implemented yet")
+                bail!(
+                    "not implemented yet, self.device: {:?}, device: {:?}",
+                    self.device(),
+                    device
+                )
             }
         };
         let op = BackpropOp::new1(self, Op::ToDevice);
@@ -2440,9 +2519,19 @@ impl Tensor {
 
     /// Returns log(sum(exp(tensor), dim)).
     pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
-        let exp = self.exp()?;
-        let sum = exp.sum(sum_dims)?;
-        sum.log()
+        let sum_dims = sum_dims.to_indexes(self.shape(), "log-sum-exp")?;
+        if sum_dims.is_empty() {
+            return Ok(self.clone());
+        }
+        let max = sum_dims[1..]
+            .iter()
+            .try_fold(self.max_keepdim(sum_dims[0])?, |max, &dim| {
+                max.max_keepdim(dim)
+            })?;
+        let exp = self.broadcast_sub(&max)?.exp()?;
+        let sum = exp.sum(sum_dims.clone())?;
+
+        sum.log()? + max.squeeze_dims(&sum_dims)
     }
 
     /// Pointwise pow operation.
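The rewrite uses the standard max-shift identity, log Σ exp(xᵢ) = m + log Σ exp(xᵢ - m) with m = max xᵢ, so no exponential can overflow. A dependency-free sketch of the same trick:

```rust
fn log_sum_exp(xs: &[f64]) -> f64 {
    let m = xs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    // exp(x - m) <= 1 for every element, so the sum cannot overflow.
    m + xs.iter().map(|x| (x - m).exp()).sum::<f64>().ln()
}

fn main() {
    // The naive exp-sum-log would overflow to infinity here.
    let v = log_sum_exp(&[1000.0, 999.0, 1001.0]);
    assert!((v - 1001.4076).abs() < 1e-4);
}
```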
@@ -730,6 +730,103 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
             ]
         ]
     );
+
+    // Test the same, but then with the following properties, t & w are unmodified.
+    let padding = 1;
+    let outpadding = 1;
+    let dilation = 1;
+    let stride = 2;
+
+    let res = t.conv_transpose2d(&w, padding, outpadding, stride, dilation)?;
+    let loss = res.sqr()?.sum_all()?;
+    assert_eq!(test_utils::to_vec0_round(&loss, 0)?, 3627.0); // torch gives 3626.8560
+
+    let grads = loss.backward()?;
+
+    let grad_t = grads.get(&t).unwrap();
+    let grad_w = grads.get(&w).unwrap();
+    assert_eq!(grad_t.dims(), [1, 4, 7, 5]);
+    assert_eq!(grad_w.dims(), [4, 2, 3, 5]);
+
+    #[rustfmt::skip]
+    assert_eq!(
+        test_utils::to_vec3_round(&grad_t.i(0)?, 1)?,
+        [
+            [
+                [ 13.2, -40.7, -9.7, -47.3, -82.7],
+                [ -98.2, 9.7, 57.7, -6.2, 180.7],
+                [ 100.2, 24.1, 3.7, -100.5, -48.1],
+                [ -0.3, 13.5, -2.9, 80.0, -49.8],
+                [ 47.2, -25.6, -74.4, 61.2, -18.4],
+                [ 4.6, -69.5, 27.9, 66.5, -88.1],
+                // 4th column on next row; torch is 4.2
+                [ -12.0, 79.2, -40.0, 4.1, -97.1],
+            ],
+            [
+                [ -42.2, -36.5, -51.1, 7.5, 32.3],
+                [ 74.1, -44.6, -68.8, 19.5, 7.7],
+                [ 137.1, 54.2, 153.8, -58.0, 45.5],
+                [ 24.4, -56.8, 9.7, -41.0, -14.5],
+                [ -3.7, 72.6, 8.3, 134.8, 40.5],
+                [ 43.2, -56.9, -47.5, -89.4, -95.4],
+                [ 68.2, 108.1, -80.0, 57.0, -121.1]
+            ],
+            [
+                [ 31.1, -11.4, -34.8, 33.1, -44.2],
+                [ 29.4, -31.6, -40.2, 13.7, 13.1],
+                [ -0.8, -83.8, -7.8, -17.3, 78.2],
+                [ 12.0, -118.7, 137.5, -76.7, 50.8],
+                [ -28.7, -114.2, -3.7, -96.3, -13.8],
+                [ -31.8, 28.5, -14.3, 4.6, 13.4],
+                [ 28.0, -0.2, -38.9, -29.7, -59.0]
+            ],
+            [
+                [ -16.8, 38.5, 15.5, 26.6, 48.9],
+                [ 14.5, 49.6, -24.8, 65.6, 61.7],
+                [ 22.1, -64.7, -4.3, -51.0, 36.3],
+                [ 31.0, -88.9, 47.1, -123.5, -3.8],
+                [ -14.8, -39.8, 128.2, -110.3, 42.6],
+                // 1st column on next row; torch is -7.2
+                [ -7.1, 95.3, -21.3, -58.7, -13.9],
+                [ 26.9, 21.3, 16.1, 70.3, 32.1]
+            ]
+        ]
+    );
+
+    #[rustfmt::skip]
+    assert_eq!(
+        test_utils::to_vec1_round(&grad_w.flatten_all()?, 1)?,
+        [
+            // 2nd value; torch gets -3.2, 3rd value; torch gets 221.8
+            -2.460e+01, -3.100e+00, 2.219e+02, 7.400e+00, 5.620e+01,
+            7.420e+01, 7.830e+01, 8.900e+00, 1.050e+01, 2.810e+01,
+            5.100e+00, -1.046e+02, -1.572e+02, 8.710e+01, -9.840e+01,
+            -4.230e+01, -1.898e+02, 1.860e+01, -3.570e+01, 9.810e+01,
+            4.680e+01, 1.182e+02, 4.020e+01, -1.900e+00, 1.508e+02,
+            1.094e+02, 1.018e+02, -4.620e+01, 1.591e+02, -2.320e+01,
+            // 5th value; torch gets 7.1
+            -8.450e+01, -4.600e+00, 6.330e+01, 1.123e+02, -7.000e+00,
+            1.101e+02, -6.620e+01, 2.090e+01, -5.120e+01, 8.990e+01,
+            9.050e+01, -6.990e+01, 6.800e+01, -9.250e+01, 1.380e+02,
+            4.720e+01, 4.710e+01, 6.210e+01, 8.870e+01, 2.098e+02,
+            3.870e+01, -1.390e+01, 6.270e+01, 1.484e+02, -9.920e+01,
+            -4.200e+01, -1.505e+02, -1.480e+01, -2.620e+01, 8.220e+01,
+            -3.350e+01, -2.260e+01, -1.198e+02, -5.080e+01, 1.259e+02,
+            5.600e+01, 9.270e+01, 1.209e+02, 6.590e+01, -8.330e+01,
+            7.000e+00, -2.600e+01, -1.133e+02, 3.870e+01, 4.020e+01,
+            -6.300e+00, -8.710e+01, -5.150e+01, -8.510e+01, 2.000e-01,
+            3.640e+01, -6.100e+00, 6.590e+01, -2.700e+00, 6.550e+01,
+            // 4th value; torch gets 3.8
+            5.300e+00, -6.760e+01, -4.270e+01, -3.900e+00, 2.880e+01,
+            5.260e+01, 6.170e+01, -1.203e+02, -1.610e+01, 7.740e+01,
+            -1.008e+02, -1.070e+01, -9.900e+00, 3.300e+00, -2.620e+01,
+            -4.440e+01, 2.580e+01, -6.920e+01, -4.220e+01, 1.108e+02,
+            1.240e+01, -3.440e+01, -2.800e+00, 7.880e+01, -6.690e+01,
+            1.480e+01, 2.310e+01, -4.260e+01, -1.500e+00, -4.760e+01,
+            5.350e+01, -2.260e+01, 8.000e-01, -3.840e+01, -2.500e+00
+        ]
+    );
+
     Ok(())
 }
@@ -49,6 +49,20 @@ fn matmul(device: &Device) -> Result<()> {
     Ok(())
 }
 
+fn matmul_bf16(device: &Device) -> Result<()> {
+    if !device.supports_bf16() {
+        return Ok(());
+    }
+    let data = vec![1.0f32, 2.0, 3.0, 4.0];
+    let a = Tensor::from_slice(&data, (2, 2), device)?.to_dtype(DType::BF16)?;
+    let data = vec![1.0f32, 2.0, 3.0, 4.0];
+    let b = Tensor::from_slice(&data, (2, 2), device)?.to_dtype(DType::BF16)?;
+
+    let c = a.matmul(&b)?.to_dtype(DType::F32)?;
+    assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);
+    Ok(())
+}
+
 fn broadcast_matmul(device: &Device) -> Result<()> {
     let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?;
     let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?;
@@ -96,6 +110,12 @@ fn mm_layout(device: &Device) -> Result<()> {
 }
 
 test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
+test_device!(
+    matmul_bf16,
+    matmul_bf16_cpu,
+    matmul_bf16_gpu,
+    matmul_bf16_metal
+);
 test_device!(
     broadcast_matmul,
     broadcast_matmul_cpu,
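For context, `test_device!` fans a single test function out into per-backend `#[test]` wrappers. A simplified sketch of such a macro (illustrative, not candle's exact definition):

```rust
use candle::{Device, Result};

macro_rules! test_device {
    ($fn_name:ident, $cpu:ident, $cuda:ident, $metal:ident) => {
        #[test]
        fn $cpu() -> Result<()> {
            $fn_name(&Device::Cpu)
        }
        #[cfg(feature = "cuda")]
        #[test]
        fn $cuda() -> Result<()> {
            $fn_name(&Device::new_cuda(0)?)
        }
        #[cfg(feature = "metal")]
        #[test]
        fn $metal() -> Result<()> {
            $fn_name(&Device::new_metal(0)?)
        }
    };
}
```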
@@ -193,6 +193,19 @@ fn unary_op(device: &Device) -> Result<()> {
         tensor.sign()?.to_vec1::<f32>()?,
         [-1., -1., -1., 0., 0., 1., 1., 1., 1.]
     );
+    let tensor = Tensor::new(&[-1.0f32, 0., -2., 3.], device)?;
+    let y = tensor.elu(2.)?;
+    assert_eq!(
+        test_utils::to_vec1_round(&y, 4)?,
+        [-1.2642, 0.0000, -1.7293, 3.0000]
+    );
+    // This test failed on metal prior to the following PR:
+    // https://github.com/huggingface/candle/pull/2490
+    let y = tensor.reshape((2, 2))?.t()?.elu(2.)?.flatten_all()?;
+    assert_eq!(
+        test_utils::to_vec1_round(&y, 4)?,
+        [-1.2642, -1.7293, 0.0000, 3.0000]
+    );
     Ok(())
 }
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn log_sum_exp() -> Result<()> {
|
fn log_sum_exp() -> Result<()> {
|
||||||
let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
|
let input = Tensor::new(
|
||||||
|
&[
|
||||||
|
[[1f64, 2., 3.], [4., 5., 6.]],
|
||||||
|
[[-1000.0, -999.0, -1001.0], [1000.0, 999.0, 1001.0]],
|
||||||
|
],
|
||||||
|
&Device::Cpu,
|
||||||
|
)?;
|
||||||
|
|
||||||
let output = input.log_sum_exp(D::Minus1)?;
|
let output = input.log_sum_exp(D::Minus1)?;
|
||||||
// The expectations obtained from pytorch.
|
// The expectations obtained from pytorch.
|
||||||
let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?;
|
let expected = Tensor::new(&[[3.4076, 6.4076], [-998.5924, 1001.4076]], &Device::Cpu)?;
|
||||||
assert_close(&output, &expected, 0.00001)?;
|
assert_eq!(output.dims(), expected.dims());
|
||||||
|
assert_close(&output.flatten_all()?, &expected.flatten_all()?, 0.00001)?;
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
input.log_sum_exp((0, 1))?.to_vec1::<f64>()?,
|
||||||
|
[1000.0, 999.0, 1001.0]
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
input.log_sum_exp(())?.to_vec3::<f64>()?,
|
||||||
|
input.to_vec3::<f64>()?
|
||||||
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -35,7 +35,7 @@ serde = { workspace = true }
 serde_json = { workspace = true }
 symphonia = { version = "0.5.3", features = ["all"], optional = true }
 tokenizers = { workspace = true, features = ["onig"] }
-cpal= { version = "0.15.2", optional = true }
+cpal = { version = "0.15.2", optional = true }
 
 [dev-dependencies]
 anyhow = { workspace = true }
@@ -67,6 +67,7 @@ onnx = ["candle-onnx"]
 metal = ["candle/metal", "candle-nn/metal"]
 microphone = ["cpal"]
 encodec = ["cpal", "symphonia", "rubato"]
+mimi = ["cpal", "symphonia", "rubato"]
 depth_anything_v2 = ["palette", "enterpolation"]
 
 [[example]]
@@ -101,6 +102,10 @@ required-features = ["candle-datasets"]
 name = "llama2-c"
 required-features = ["candle-datasets"]
 
+[[example]]
+name = "mimi"
+required-features = ["mimi"]
+
 [[example]]
 name = "encodec"
 required-features = ["encodec"]
@@ -108,3 +113,7 @@ required-features = ["encodec"]
 [[example]]
 name = "depth_anything_v2"
 required-features = ["depth_anything_v2"]
+
+[[example]]
+name = "silero-vad"
+required-features = ["onnx"]
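With these manifest entries, the new examples only compile when their feature flags are enabled; for instance (any example-specific flags omitted):

```bash
cargo run --example mimi --features mimi --release
cargo run --example silero-vad --features onnx --release
```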
candle-examples/examples/based/README.md (new file, 20 lines)
@@ -0,0 +1,20 @@
# candle-based

An experimental small LLM from the Hazy Research group (not instruction-tuned), combining local and linear attention layers.

[Blogpost](https://hazyresearch.stanford.edu/blog/2024-03-03-based)

[Simple linear attention language models balance the recall-throughput tradeoff](https://arxiv.org/abs/2402.18668)

## Running an example

```bash
$ cargo run --example based --release -- --prompt "Flying monkeys are" --which 1b-50b --sample-len 100

Flying monkeys are a common sight in the wild, but they are also a threat to humans.

The new study, published today (July 31) in the journal Science Advances, shows that the monkeys are using their brains to solve the problem of how to get around the problem.

"We found that the monkeys were using a strategy called 'cognitive mapping' - they would use their brains to map out the route ahead," says lead author Dr. David J. Smith from the University of California

```
candle-examples/examples/based/main.rs (new file, 275 lines)
@@ -0,0 +1,275 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};

use candle_transformers::models::based::Model;

use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

struct TextGeneration {
    model: Model,
    device: Device,
    tokenizer: TokenOutputStream,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            tokenizer: TokenOutputStream::new(tokenizer),
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            device: device.clone(),
        }
    }

    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        self.tokenizer.clear();
        let mut tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        for &t in tokens.iter() {
            if let Some(t) = self.tokenizer.next_token(t)? {
                print!("{t}")
            }
        }
        std::io::stdout().flush()?;

        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
            Some(token) => token,
            None => anyhow::bail!("cannot find the <|endoftext|> token"),
        };
        let start_gen = std::time::Instant::now();
        for index in 0..sample_len {
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let start_pos = tokens.len().saturating_sub(context_size);
            let ctxt = &tokens[start_pos..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = self.model.forward(&input, start_pos)?;
            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token {
                break;
            }
            if let Some(t) = self.tokenizer.next_token(next_token)? {
                print!("{t}");
                std::io::stdout().flush()?;
            }
        }
        let dt = start_gen.elapsed();
        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
            print!("{rest}");
        }
        std::io::stdout().flush()?;
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        Ok(())
    }
}

#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
    #[value(name = "360m")]
    W360m,
    #[value(name = "1b")]
    W1b,
    #[value(name = "1b-50b")]
    W1b50b,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 10000)]
    sample_len: usize,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long, default_value = "refs/pr/1")]
    revision: String,

    #[arg(long)]
    config_file: Option<String>,

    #[arg(long)]
    tokenizer_file: Option<String>,

    #[arg(long)]
    weight_files: Option<String>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,

    #[arg(long, default_value = "360m")]
    which: Which,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let model_id = match args.model_id {
        Some(model_id) => model_id,
        None => match args.which {
            Which::W360m => "hazyresearch/based-360m".to_string(),
            Which::W1b => "hazyresearch/based-1b".to_string(),
            Which::W1b50b => "hazyresearch/based-1b-50b".to_string(),
        },
    };
    let repo = api.repo(Repo::with_revision(
        model_id,
        RepoType::Model,
        args.revision,
    ));
    let config_file = match args.config_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("config.json")?,
    };
    let filenames = match args.weight_files {
        Some(files) => files
            .split(',')
            .map(std::path::PathBuf::from)
            .collect::<Vec<_>>(),
        None => vec![repo.get("model.safetensors")?],
    };

    let repo = api.model("openai-community/gpt2".to_string());
    let tokenizer_file = match args.tokenizer_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("tokenizer.json")?,
    };

    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_file).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let config = serde_json::from_reader(std::fs::File::open(config_file)?)?;
    let device = candle_examples::device(args.cpu)?;
    let dtype = if device.is_cuda() {
        DType::BF16
    } else {
        DType::F32
    };

    let mut vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
    if args.which == Which::W1b50b {
        vb = vb.pp("model");
    };

    let model = Model::new(&config, vb)?;

    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        &device,
    );
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}
candle-examples/examples/beit/README.md (new file, 20 lines)
@@ -0,0 +1,20 @@
# candle-beit

[BEiT](https://arxiv.org/abs/2106.08254) is a computer vision model.
In this example, it is used as an ImageNet classifier: the model returns the
probability for the image to belong to each of the 1000 ImageNet categories.

## Running an example

```bash
cargo run --example beit --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg

> mountain bike, all-terrain bike, off-roader: 56.16%
> bicycle-built-for-two, tandem bicycle, tandem: 3.08%
> maillot      : 2.23%
> alp          : 0.88%
> crash helmet : 0.85%

```
candle-examples/examples/beit/main.rs (new file, 79 lines)
@@ -0,0 +1,79 @@
//! BEiT: BERT Pre-Training of Image Transformers
//! https://github.com/microsoft/unilm/tree/master/beit

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::Parser;

use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::beit;

/// Loads an image from disk using the image crate, this returns a tensor with shape
/// (3, 384, 384). Beit special normalization is applied.
pub fn load_image384_beit_norm<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
    let img = image::ImageReader::open(p)?
        .decode()
        .map_err(candle::Error::wrap)?
        .resize_to_fill(384, 384, image::imageops::FilterType::Triangle);
    let img = img.to_rgb8();
    let data = img.into_raw();
    let data = Tensor::from_vec(data, (384, 384, 3), &Device::Cpu)?.permute((2, 0, 1))?;
    let mean = Tensor::new(&[0.5f32, 0.5, 0.5], &Device::Cpu)?.reshape((3, 1, 1))?;
    let std = Tensor::new(&[0.5f32, 0.5, 0.5], &Device::Cpu)?.reshape((3, 1, 1))?;
    (data.to_dtype(candle::DType::F32)? / 255.)?
        .broadcast_sub(&mean)?
        .broadcast_div(&std)
}

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let device = candle_examples::device(args.cpu)?;

    let image = load_image384_beit_norm(args.image)?.to_device(&device)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model("vincent-espitalier/candle-beit".into());
            api.get("beit_base_patch16_384.in22k_ft_in22k_in1k.safetensors")?
        }
        Some(model) => model.into(),
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let model = beit::vit_base(vb)?;
    println!("model built");
    let logits = model.forward(&image.unsqueeze(0)?)?;
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;
    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    for &(category_idx, pr) in prs.iter().take(5) {
        println!(
            "{:24}: {:.2}%",
            candle_examples::imagenet::CLASSES[category_idx],
            100. * pr
        );
    }
    Ok(())
}
@@ -126,7 +126,7 @@ fn main() -> Result<()> {
     println!("Loaded and encoded {:?}", start.elapsed());
     for idx in 0..args.n {
         let start = std::time::Instant::now();
-        let ys = model.forward(&token_ids, &token_type_ids)?;
+        let ys = model.forward(&token_ids, &token_type_ids, None)?;
         if idx == 0 {
             println!("{ys}");
         }
@@ -163,11 +163,19 @@ fn main() -> Result<()> {
             Ok(Tensor::new(tokens.as_slice(), device)?)
         })
         .collect::<Result<Vec<_>>>()?;
+    let attention_mask = tokens
+        .iter()
+        .map(|tokens| {
+            let tokens = tokens.get_attention_mask().to_vec();
+            Ok(Tensor::new(tokens.as_slice(), device)?)
+        })
+        .collect::<Result<Vec<_>>>()?;
+
     let token_ids = Tensor::stack(&token_ids, 0)?;
+    let attention_mask = Tensor::stack(&attention_mask, 0)?;
     let token_type_ids = token_ids.zeros_like()?;
     println!("running inference on batch {:?}", token_ids.shape());
-    let embeddings = model.forward(&token_ids, &token_type_ids)?;
+    let embeddings = model.forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
     println!("generated embeddings {:?}", embeddings.shape());
     // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
     let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
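The avg-pooling comment above still averages over every token, padding included. With the attention mask now available, a masked mean is straightforward; here is a sketch assuming `attention_mask` is the (batch, seq) 0/1 tensor built above (`masked_mean` is a hypothetical helper, not part of the example):

```rust
use candle::{DType, Result, Tensor};

// Hypothetical helper: mean-pool embeddings while ignoring padding positions.
fn masked_mean(embeddings: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
    // (batch, seq) -> (batch, seq, 1) so it broadcasts over the hidden size.
    let mask = attention_mask.to_dtype(DType::F32)?.unsqueeze(2)?;
    let summed = embeddings.broadcast_mul(&mask)?.sum(1)?;
    let counts = mask.sum(1)?; // number of real (non-padding) tokens per sentence
    summed.broadcast_div(&counts)
}
```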
@@ -55,7 +55,7 @@ const SEP_TOKEN_ID: u32 = 102;
 /// Loads an image from disk using the image crate, this returns a tensor with shape
 /// (3, 384, 384). OpenAI normalization is applied.
 pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
-    let img = image::io::Reader::open(p)?
+    let img = image::ImageReader::open(p)?
         .decode()
         .map_err(candle::Error::wrap)?
         .resize_to_fill(384, 384, image::imageops::FilterType::Triangle);
@@ -1,4 +1,4 @@
-Contrastive Language-Image Pre-Training
+# candle-clip
 
 Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
 pairs of images with related texts.
@@ -33,7 +33,7 @@ struct Args {
 }
 
 fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::Result<Tensor> {
-    let img = image::io::Reader::open(path)?.decode()?;
+    let img = image::ImageReader::open(path)?.decode()?;
     let (height, width) = (image_size, image_size);
     let img = img.resize_to_fill(
         width as u32,
candle-examples/examples/codegeex4-9b/README.org (new file, 96 lines)
@@ -0,0 +1,96 @@
* candle-codegeex4_9b
THUDM/CodeGeeX4 is a versatile model for all AI software development scenarios, including code completion, code interpreter, web search, function calling, repository-level Q&A and much more.

- [[https://github.com/THUDM/CodeGeeX4][Github]]
- [[https://codegeex.cn/][HomePage]]
- [[https://huggingface.co/THUDM/codegeex4-all-9b][huggingface]]

** Running with ~cuda~

#+begin_src shell
cargo run --example codegeex4-9b --release --features cuda -- --prompt "please write an insertion sort in rust" --sample-len 300
#+end_src

** Running with ~cpu~
#+begin_src shell
cargo run --example codegeex4-9b --release --cpu -- --prompt "please write an insertion sort in rust" --sample-len 300
#+end_src

** Output Example
*** Input
#+begin_src shell
cargo run --release --features cuda -- --prompt 'please write a FFT in rust' --sample-len 500 --cache /root/autodl-tmp
#+end_src

*** Output
#+begin_src shell
avx: false, neon: false, simd128: false, f16c: false
temp: 0.95 repeat-penalty: 1.10 repeat-last-n: 64
cache path /root/autodl-tmp
Prompt: [please write a FFT in rust]
Using Seed 11511762269791786684
DType is BF16
transformer layers created
model fully loaded 4
starting the inference loop

generation started
samplelen 500

500 tokens generated (34.60 token/s)
Result:

Sure, I can help you with that. Here's an example of a Fast Fourier Transform (FFT) implementation in Rust:

```rust
use num_complex::Complex;

fn fft(input: &[Complex<f64> > ] ) -> Vec<Complex<f64> > > {
    let n = input.len();

    if n == 1 {
        return vec![input[0]]];
    }

    let mut even = vec![];
    let mut odd = vec![];

    for i in 0..n {

        if i % 2 == 0 {
            even.push(input[i]);
        } else {
            odd.push(input[i]);
        }
    }

    let even_fft = fft(&even);
    let odd_fft = fft(&odd);

    let mut output = vec![];

    for k in 0..n/2 {
        let t = Complex::new(0.0, -2.0 * std::f64::consts::PI * (k as f64) / (n as f64))) ).exp();

        output.push(even_fft[k] + odd_fft[k] * t]);
        output.push(even_fft[k] - odd_fft[k] * t]);
    }

    return output;
}
```

This implementation uses the Cooley-Tukey algorithm to perform the FFT. The function takes an array of complex numbers and returns an array of complex numbers which is the result of the FFT.
#+end_src


* Citation
#+begin_src
@inproceedings{zheng2023codegeex,
  title={CodeGeeX: A Pre-Trained Model for Code Generation with Multilingual Benchmarking on HumanEval-X},
  author={Qinkai Zheng and Xiao Xia and Xu Zou and Yuxiao Dong and Shan Wang and Yufei Xue and Zihan Wang and Lei Shen and Andi Wang and Yang Li and Teng Su and Zhilin Yang and Jie Tang},
  booktitle={Proceedings of the 29th ACM SIGKDD Conference on Knowledge Discovery and Data Mining},
  pages={5673--5684},
  year={2023}
}
#+end_src
252
candle-examples/examples/codegeex4-9b/main.rs
Normal file
252
candle-examples/examples/codegeex4-9b/main.rs
Normal file
@ -0,0 +1,252 @@
|
|||||||
|
use candle_transformers::models::codegeex4_9b::*;
|
||||||
|
use clap::Parser;
|
||||||
|
|
||||||
|
use candle::{DType, Device, Tensor};
|
||||||
|
use candle_nn::VarBuilder;
|
||||||
|
use candle_transformers::generation::LogitsProcessor;
|
||||||
|
use hf_hub::{Repo, RepoType};
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
struct TextGeneration {
|
||||||
|
model: Model,
|
||||||
|
device: Device,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
logits_processor: LogitsProcessor,
|
||||||
|
repeat_penalty: f32,
|
||||||
|
repeat_last_n: usize,
|
||||||
|
verbose_prompt: bool,
|
||||||
|
dtype: DType,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TextGeneration {
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
fn new(
|
||||||
|
model: Model,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
seed: u64,
|
||||||
|
temp: Option<f64>,
|
||||||
|
top_p: Option<f64>,
|
||||||
|
repeat_penalty: f32,
|
||||||
|
repeat_last_n: usize,
|
||||||
|
verbose_prompt: bool,
|
||||||
|
device: &Device,
|
||||||
|
dtype: DType,
|
||||||
|
) -> Self {
|
||||||
|
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||||
|
Self {
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
logits_processor,
|
||||||
|
repeat_penalty,
|
||||||
|
repeat_last_n,
|
||||||
|
verbose_prompt,
|
||||||
|
device: device.clone(),
|
||||||
|
dtype,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run(&mut self, prompt: &str, sample_len: usize) -> anyhow::Result<()> {
|
||||||
|
use std::io::Write;
|
||||||
|
println!("starting the inference loop");
|
||||||
|
let tokens = self.tokenizer.encode(prompt, true).expect("tokens error");
|
||||||
|
if tokens.is_empty() {
|
||||||
|
panic!("Empty prompts are not supported in the chatglm model.")
|
||||||
|
}
|
||||||
|
if self.verbose_prompt {
|
||||||
|
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
|
||||||
|
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
|
||||||
|
println!("{id:7} -> '{token}'");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
|
||||||
|
Some(token) => *token,
|
||||||
|
None => panic!("cannot find the endoftext token"),
|
||||||
|
};
|
||||||
|
let mut tokens = tokens.get_ids().to_vec();
|
||||||
|
let mut generated_tokens = 0usize;
|
||||||
|
|
||||||
|
print!("{prompt}");
|
||||||
|
std::io::stdout().flush().expect("output flush error");
|
||||||
|
let start_gen = std::time::Instant::now();
|
||||||
|
|
||||||
|
println!("\n start_gen");
|
||||||
|
println!("samplelen {}", sample_len);
|
||||||
|
let mut count = 0;
|
||||||
|
let mut result = vec![];
|
||||||
|
for index in 0..sample_len {
|
||||||
|
count += 1;
|
||||||
|
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||||
|
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||||
|
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||||
|
let logits = self.model.forward(&input)?;
|
||||||
|
let logits = logits.squeeze(0)?.to_dtype(self.dtype)?;
|
||||||
|
let logits = if self.repeat_penalty == 1. {
|
||||||
|
logits
|
||||||
|
} else {
|
||||||
|
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||||
|
candle_transformers::utils::apply_repeat_penalty(
|
||||||
|
&logits,
|
||||||
|
self.repeat_penalty,
|
||||||
|
&tokens[start_at..],
|
||||||
|
)?
|
||||||
|
};
|
||||||
|
|
||||||
|
let next_token = self.logits_processor.sample(&logits)?;
|
||||||
|
tokens.push(next_token);
|
||||||
|
generated_tokens += 1;
|
||||||
|
if next_token == eos_token {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
let token = self
|
||||||
|
.tokenizer
|
||||||
|
                .decode(&[next_token], true)
                .expect("Token error");
            if self.verbose_prompt {
                println!(
                    "[Count: {}] [Raw Token: {}] [Decode Token: {}]",
                    count, next_token, token
                );
            }
            result.push(token);
            std::io::stdout().flush()?;
        }
        let dt = start_gen.elapsed();
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        println!("Result:");
        for tokens in result {
            print!("{tokens}");
        }
        Ok(())
    }
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// The cache path used for downloaded model files.
    #[arg(name = "cache", short, long, default_value = ".")]
    cache_path: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Display the token for the specified prompt.
    #[arg(long)]
    verbose_prompt: bool,

    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 5000)]
    sample_len: usize,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long)]
    revision: Option<String>,

    #[arg(long)]
    weight_file: Option<String>,

    #[arg(long)]
    tokenizer: Option<String>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.95),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    println!("cache path {}", args.cache_path);
    let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into()))
        .build()
        .map_err(anyhow::Error::msg)?;

    let model_id = match args.model_id {
        Some(model_id) => model_id.to_string(),
        None => "THUDM/codegeex4-all-9b".to_string(),
    };
    let revision = match args.revision {
        Some(rev) => rev.to_string(),
        None => "main".to_string(),
    };
    let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
    let tokenizer_filename = match args.tokenizer {
        Some(file) => std::path::PathBuf::from(file),
        None => api
            .model("THUDM/codegeex4-all-9b".to_string())
            .get("tokenizer.json")
            .map_err(anyhow::Error::msg)?,
    };
    let filenames = match args.weight_file {
        Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
        None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_filename).expect("Tokenizer Error");

    let start = std::time::Instant::now();
    let config = Config::codegeex4();
    let device = candle_examples::device(args.cpu)?;
    let dtype = if device.is_cuda() {
        DType::BF16
    } else {
        DType::F32
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
    let model = Model::new(&config, vb)?;

    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        args.verbose_prompt,
        &device,
        dtype,
    );
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}
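For reference, the flags above compose into an invocation along these lines; a minimal sketch (the example name `codegeex4-9b` and the sample prompt are assumptions, not part of this diff):

```bash
cargo run --example codegeex4-9b --release --features cuda -- \
    --prompt "fn fib(n: u64) -> u64 {" --sample-len 256
```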
candle-examples/examples/dinov2reg4/README.md (new file)
@@ -0,0 +1,25 @@
# candle-dinov2-reg4

[DINOv2-reg4](https://arxiv.org/abs/2309.16588) is the latest version of DINOv2 with registers.
In this example, it is used as a plant species classifier: the model returns the
probability for the image to belong to each of the 7806 PlantCLEF2024 categories.

## Running an example

```bash
# Download the class names and a plant picture to identify
curl https://huggingface.co/vincent-espitalier/dino-v2-reg4-with-plantclef2024-weights/raw/main/species_id_mapping.txt --output candle-examples/examples/dinov2reg4/species_id_mapping.txt
curl https://bs.plantnet.org/image/o/bd2d3830ac3270218ba82fd24e2290becd01317c --output candle-examples/examples/dinov2reg4/bd2d3830ac3270218ba82fd24e2290becd01317c.jpg

# Perform inference
cargo run --example dinov2reg4 --release -- --image candle-examples/examples/dinov2reg4/bd2d3830ac3270218ba82fd24e2290becd01317c.jpg

> Orchis simia Lam.       : 45.55%
> Orchis × bergonii Nanteuil: 9.80%
> Orchis italica Poir.    : 9.66%
> Orchis × angusticruris Franch.: 2.76%
> Orchis × bivonae Tod.   : 2.54%

```

🌿
candle-examples/examples/dinov2reg4/main.rs (new file)
@@ -0,0 +1,70 @@
//! DINOv2 reg4 finetuned on PlantCLEF 2024
//! https://arxiv.org/abs/2309.16588
//! https://huggingface.co/spaces/BVRA/PlantCLEF2024
//! https://zenodo.org/records/10848263

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::Parser;

use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::dinov2reg4;

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image518(args.image)?.to_device(&device)?;
    println!("loaded image {image:?}");

    let f_species_id_mapping = "candle-examples/examples/dinov2reg4/species_id_mapping.txt";
    let classes: Vec<String> = std::fs::read_to_string(f_species_id_mapping)
        .expect("missing classes file")
        .split('\n')
        .map(|s| s.to_string())
        .collect();

    let model_file = match args.model {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api =
                api.model("vincent-espitalier/dino-v2-reg4-with-plantclef2024-weights".into());
            api.get(
                "vit_base_patch14_reg4_dinov2_lvd142m_pc24_onlyclassifier_then_all.safetensors",
            )?
        }
        Some(model) => model.into(),
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let model = dinov2reg4::vit_base(vb)?;
    println!("model built");
    let logits = model.forward(&image.unsqueeze(0)?)?;
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;
    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    for &(category_idx, pr) in prs.iter().take(5) {
        println!("{:24}: {:.2}%", classes[category_idx], 100. * pr);
    }
    Ok(())
}
candle-examples/examples/encodec/README.md
@@ -7,7 +7,7 @@ quantization.
 ## Running one example
 
 ```bash
-cargo run --example encodec --features symphonia --release -- code-to-audio \
+cargo run --example encodec --features encodec --release -- code-to-audio \
 candle-examples/examples/encodec/jfk-codes.safetensors \
 jfk.wav
 ```
Binary file not shown.
candle-examples/examples/eva2/README.md (new file)
@@ -0,0 +1,21 @@
# candle-eva2

[EVA-02](https://arxiv.org/abs/2303.11331) is a computer vision model.
In this example, it is used as an ImageNet classifier: the model returns the
probability for the image to belong to each of the 1000 ImageNet categories.

## Running an example

```bash
cargo run --example eva2 --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg

> mountain bike, all-terrain bike, off-roader: 37.09%
> maillot                 : 8.30%
> alp                     : 2.13%
> bicycle-built-for-two, tandem bicycle, tandem: 0.84%
> crash helmet            : 0.73%

```

🚴‍♂️
candle-examples/examples/eva2/main.rs (new file)
@@ -0,0 +1,82 @@
//! EVA-02: Explore the limits of Visual representation at scAle
//! https://github.com/baaivision/EVA

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::Parser;

use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::eva2;

/// Loads an image from disk using the image crate; this returns a tensor with shape
/// (3, 448, 448). OpenAI normalization is applied.
pub fn load_image448_openai_norm<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
    let img = image::ImageReader::open(p)?
        .decode()
        .map_err(candle::Error::wrap)?
        .resize_to_fill(448, 448, image::imageops::FilterType::Triangle);
    let img = img.to_rgb8();
    let data = img.into_raw();
    let data = Tensor::from_vec(data, (448, 448, 3), &Device::Cpu)?.permute((2, 0, 1))?;
    let mean =
        Tensor::new(&[0.48145466f32, 0.4578275, 0.40821073], &Device::Cpu)?.reshape((3, 1, 1))?;
    let std = Tensor::new(&[0.26862954f32, 0.261_302_6, 0.275_777_1], &Device::Cpu)?
        .reshape((3, 1, 1))?;
    (data.to_dtype(candle::DType::F32)? / 255.)?
        .broadcast_sub(&mean)?
        .broadcast_div(&std)
}

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let device = candle_examples::device(args.cpu)?;

    let image = load_image448_openai_norm(args.image)?.to_device(&device)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model("vincent-espitalier/candle-eva2".into());
            api.get("eva02_base_patch14_448.mim_in22k_ft_in22k_in1k_adapted.safetensors")?
        }
        Some(model) => model.into(),
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };

    let model = eva2::vit_base(vb)?;
    println!("model built");
    let logits = model.forward(&image.unsqueeze(0)?)?;
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;
    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    for &(category_idx, pr) in prs.iter().take(5) {
        println!(
            "{:24}: {:.2}%",
            candle_examples::imagenet::CLASSES[category_idx],
            100. * pr
        );
    }
    Ok(())
}
candle-examples/examples/fastvit/README.md (new file)
@@ -0,0 +1,20 @@
# candle-fastvit

[FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization](https://arxiv.org/abs/2303.14189).
This candle implementation uses a pre-trained FastViT network for inference. The
classification head has been trained on the ImageNet dataset and returns the
probabilities for the top-5 classes.

## Running an example

```
$ cargo run --example fastvit --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which sa12

loaded image Tensor[dims 3, 256, 256; f32]
model built
mountain bike, all-terrain bike, off-roader: 52.67%
bicycle-built-for-two, tandem bicycle, tandem: 7.93%
unicycle, monocycle     : 3.46%
maillot                 : 1.32%
crash helmet            : 1.28%
```
candle-examples/examples/fastvit/main.rs (new file)
@@ -0,0 +1,102 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::{Parser, ValueEnum};

use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::fastvit;

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
    T8,
    T12,
    S12,
    SA12,
    SA24,
    SA36,
    MA36,
}

impl Which {
    fn model_filename(&self) -> String {
        let name = match self {
            Self::T8 => "t8",
            Self::T12 => "t12",
            Self::S12 => "s12",
            Self::SA12 => "sa12",
            Self::SA24 => "sa24",
            Self::SA36 => "sa36",
            Self::MA36 => "ma36",
        };
        format!("timm/fastvit_{}.apple_in1k", name)
    }

    fn config(&self) -> fastvit::Config {
        match self {
            Self::T8 => fastvit::Config::t8(),
            Self::T12 => fastvit::Config::t12(),
            Self::S12 => fastvit::Config::s12(),
            Self::SA12 => fastvit::Config::sa12(),
            Self::SA24 => fastvit::Config::sa24(),
            Self::SA36 => fastvit::Config::sa36(),
            Self::MA36 => fastvit::Config::ma36(),
        }
    }
}

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    #[arg(value_enum, long, default_value_t=Which::S12)]
    which: Which,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image(args.image, 256)?.to_device(&device)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {
        None => {
            let model_name = args.which.model_filename();
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model(model_name);
            api.get("model.safetensors")?
        }
        Some(model) => model.into(),
    };

    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let model = fastvit::fastvit(&args.which.config(), 1000, vb)?;
    println!("model built");
    let logits = model.forward(&image.unsqueeze(0)?)?;
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;
    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    for &(category_idx, pr) in prs.iter().take(5) {
        println!(
            "{:24}: {:.2}%",
            candle_examples::imagenet::CLASSES[category_idx],
            100. * pr
        );
    }
    Ok(())
}
candle-examples/examples/flux/README.md (new file)
@@ -0,0 +1,19 @@
# candle-flux: image generation with latent rectified flow transformers

![rusty robot holding a candle](./assets/flux-robot.jpg)

Flux is a 12B rectified flow transformer capable of generating images from text
descriptions; see
[huggingface](https://huggingface.co/black-forest-labs/FLUX.1-schnell),
[github](https://github.com/black-forest-labs/flux), and the
[blog post](https://blackforestlabs.ai/announcing-black-forest-labs/).

## Running the model

```bash
cargo run --features cuda --example flux -r -- \
    --height 1024 --width 1024 \
    --prompt "a rusty robot walking on a beach holding a small torch, the robot has the word "rust" written on it, high quality, 4k"
```
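The example's `main.rs` (below) also exposes a `--quantized` flag that swaps in a gguf checkpoint; a sketch of such a run, assembled from the Args struct rather than taken from the README:

```bash
cargo run --features cuda --example flux -r -- \
    --quantized --prompt "A rusty robot walking on a beach"
```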
candle-examples/examples/flux/assets/flux-robot.jpg (new binary file, 90 KiB; not shown)
candle-examples/examples/flux/main.rs (new file)
@@ -0,0 +1,248 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use candle_transformers::models::{clip, flux, t5};

use anyhow::{Error as E, Result};
use candle::{IndexOp, Module, Tensor};
use candle_nn::VarBuilder;
use clap::Parser;
use tokenizers::Tokenizer;

#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// The prompt to be used for image generation.
    #[arg(long, default_value = "A rusty robot walking on a beach")]
    prompt: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Use the quantized model.
    #[arg(long)]
    quantized: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// The height in pixels of the generated image.
    #[arg(long)]
    height: Option<usize>,

    /// The width in pixels of the generated image.
    #[arg(long)]
    width: Option<usize>,

    #[arg(long)]
    decode_only: Option<String>,

    #[arg(long, value_enum, default_value = "schnell")]
    model: Model,
}

#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)]
enum Model {
    Schnell,
    Dev,
}

fn run(args: Args) -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let Args {
        prompt,
        cpu,
        height,
        width,
        tracing,
        decode_only,
        model,
        quantized,
    } = args;
    let width = width.unwrap_or(1360);
    let height = height.unwrap_or(768);

    let _guard = if tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    let api = hf_hub::api::sync::Api::new()?;
    let bf_repo = {
        let name = match model {
            Model::Dev => "black-forest-labs/FLUX.1-dev",
            Model::Schnell => "black-forest-labs/FLUX.1-schnell",
        };
        api.repo(hf_hub::Repo::model(name.to_string()))
    };
    let device = candle_examples::device(cpu)?;
    let dtype = device.bf16_default_to_f32();
    let img = match decode_only {
        None => {
            let t5_emb = {
                let repo = api.repo(hf_hub::Repo::with_revision(
                    "google/t5-v1_1-xxl".to_string(),
                    hf_hub::RepoType::Model,
                    "refs/pr/2".to_string(),
                ));
                let model_file = repo.get("model.safetensors")?;
                let vb =
                    unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
                let config_filename = repo.get("config.json")?;
                let config = std::fs::read_to_string(config_filename)?;
                let config: t5::Config = serde_json::from_str(&config)?;
                let mut model = t5::T5EncoderModel::load(vb, &config)?;
                let tokenizer_filename = api
                    .model("lmz/mt5-tokenizers".to_string())
                    .get("t5-v1_1-xxl.tokenizer.json")?;
                let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
                let mut tokens = tokenizer
                    .encode(prompt.as_str(), true)
                    .map_err(E::msg)?
                    .get_ids()
                    .to_vec();
                tokens.resize(256, 0);
                let input_token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
                println!("{input_token_ids}");
                model.forward(&input_token_ids)?
            };
            println!("T5\n{t5_emb}");
            let clip_emb = {
                let repo = api.repo(hf_hub::Repo::model(
                    "openai/clip-vit-large-patch14".to_string(),
                ));
                let model_file = repo.get("model.safetensors")?;
                let vb =
                    unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
                // https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
                let config = clip::text_model::ClipTextConfig {
                    vocab_size: 49408,
                    projection_dim: 768,
                    activation: clip::text_model::Activation::QuickGelu,
                    intermediate_size: 3072,
                    embed_dim: 768,
                    max_position_embeddings: 77,
                    pad_with: None,
                    num_hidden_layers: 12,
                    num_attention_heads: 12,
                };
                let model =
                    clip::text_model::ClipTextTransformer::new(vb.pp("text_model"), &config)?;
                let tokenizer_filename = repo.get("tokenizer.json")?;
                let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
                let tokens = tokenizer
                    .encode(prompt.as_str(), true)
                    .map_err(E::msg)?
                    .get_ids()
                    .to_vec();
                let input_token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
                println!("{input_token_ids}");
                model.forward(&input_token_ids)?
            };
            println!("CLIP\n{clip_emb}");
            let img = {
                let cfg = match model {
                    Model::Dev => flux::model::Config::dev(),
                    Model::Schnell => flux::model::Config::schnell(),
                };
                let img = flux::sampling::get_noise(1, height, width, &device)?.to_dtype(dtype)?;
                let state = if quantized {
                    flux::sampling::State::new(
                        &t5_emb.to_dtype(candle::DType::F32)?,
                        &clip_emb.to_dtype(candle::DType::F32)?,
                        &img.to_dtype(candle::DType::F32)?,
                    )?
                } else {
                    flux::sampling::State::new(&t5_emb, &clip_emb, &img)?
                };
                let timesteps = match model {
                    Model::Dev => {
                        flux::sampling::get_schedule(50, Some((state.img.dim(1)?, 0.5, 1.15)))
                    }
                    Model::Schnell => flux::sampling::get_schedule(4, None),
                };
                println!("{state:?}");
                println!("{timesteps:?}");
                if quantized {
                    let model_file = match model {
                        Model::Schnell => api
                            .repo(hf_hub::Repo::model("lmz/candle-flux".to_string()))
                            .get("flux1-schnell.gguf")?,
                        Model::Dev => todo!(),
                    };
                    let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
                        model_file, &device,
                    )?;

                    let model = flux::quantized_model::Flux::new(&cfg, vb)?;
                    flux::sampling::denoise(
                        &model,
                        &state.img,
                        &state.img_ids,
                        &state.txt,
                        &state.txt_ids,
                        &state.vec,
                        &timesteps,
                        4.,
                    )?
                    .to_dtype(dtype)?
                } else {
                    let model_file = match model {
                        Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?,
                        Model::Dev => bf_repo.get("flux1-dev.safetensors")?,
                    };
                    let vb = unsafe {
                        VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)?
                    };
                    let model = flux::model::Flux::new(&cfg, vb)?;
                    flux::sampling::denoise(
                        &model,
                        &state.img,
                        &state.img_ids,
                        &state.txt,
                        &state.txt_ids,
                        &state.vec,
                        &timesteps,
                        4.,
                    )?
                }
            };
            flux::sampling::unpack(&img, height, width)?
        }
        Some(file) => {
            let mut st = candle::safetensors::load(file, &device)?;
            st.remove("img").unwrap().to_dtype(dtype)?
        }
    };
    println!("latent img\n{img}");

    let img = {
        let model_file = bf_repo.get("ae.safetensors")?;
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
        let cfg = match model {
            Model::Dev => flux::autoencoder::Config::dev(),
            Model::Schnell => flux::autoencoder::Config::schnell(),
        };
        let model = flux::autoencoder::AutoEncoder::new(&cfg, vb)?;
        model.decode(&img)?
    };
    println!("img\n{img}");
    let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(candle::DType::U8)?;
    candle_examples::save_image(&img.i(0)?, "out.jpg")?;
    Ok(())
}

fn main() -> Result<()> {
    let args = Args::parse();
    run(args)
}
candle-examples/examples/flux/t5_tokenizer.py (new file)
@@ -0,0 +1,6 @@
from transformers import AutoTokenizer

BASE_MODEL = "google/t5-v1_1-xxl"
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
# The tokenizer will be saved in /tmp/tokenizer/tokenizer.json
tokenizer.save_pretrained("/tmp/tokenizer/")
candle-examples/examples/gemma/README.md
@@ -1,27 +1,27 @@
 # candle-gemma: 2b and 7b LLMs from Google DeepMind
 
 [Gemma](https://ai.google.dev/gemma/docs) is a collection of lightweight open
-models published by Google Deepmind with a 2b and a 7b variant.
-
-In order to use the example below, you have to accept the license on the
-[HuggingFace Hub Gemma repo](https://huggingface.co/google/gemma-7b) and set up
-your access token via the [HuggingFace cli login
-command](https://huggingface.co/docs/huggingface_hub/guides/cli#huggingface-cli-login).
+models published by Google Deepmind with a 2b and a 7b variant for the first
+version, and a 2b and a 9b variant for v2.
 
 ## Running the example
 
 ```bash
-$ cargo run --example gemma --release -- --prompt "fn count_primes(max_n: usize)"
-fn count_primes(max_n: usize) -> usize {
-    let mut primes = vec![true; max_n];
-    for i in 2..=max_n {
-        if primes[i] {
-            for j in i * i..max_n {
-                primes[j] = false;
-            }
-        }
-    }
-    primes.len()
-}
+$ cargo run --example gemma --features cuda -r -- \
+    --prompt "Here is a proof that square root of 2 is not rational: "
+
+Here is a proof that square root of 2 is not rational:
+
+Let us assume it to be rational. Then, we can write √2 = p/q where q ≠ 0 and p and q are integers with no common factors other than 1. Squaring both sides gives us (p/q)^2 = 2 or p^2/q^2 = 2. This implies that p^2 is divisible by 2, which means that p must be even. Let us write p = 2m where m is an integer. Substituting this in the above equation we get:
+
+(p^2)/q^2 = 2 or (4m^2)/q^2 = 2 or q^2/2m^2 = 1 which implies that q^2 must be divisible by 2, and hence q is even. This contradicts our assumption that p and q have no common factors other than 1. Hence we conclude that √2 cannot be rational.
 ```
+
+## Access restrictions
+
+In order to use the v1 examples, you have to accept the license on the
+[HuggingFace Hub Gemma repo](https://huggingface.co/google/gemma-7b) and set up
+your access token via the [HuggingFace cli login
+command](https://huggingface.co/docs/huggingface_hub/guides/cli#huggingface-cli-login).
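The login step referenced above is a one-liner with the Hugging Face CLI; a minimal sketch, assuming the `huggingface_hub` CLI is installed:

```bash
pip install -U "huggingface_hub[cli]"
huggingface-cli login  # paste an access token created in the Hub settings
```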
candle-examples/examples/gemma/main.rs
@@ -7,7 +7,8 @@ extern crate accelerate_src;
 use anyhow::{Error as E, Result};
 use clap::Parser;
 
-use candle_transformers::models::gemma::{Config, Model};
+use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
+use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
 
 use candle::{DType, Device, Tensor};
 use candle_examples::token_output_stream::TokenOutputStream;
@@ -38,6 +39,46 @@ enum Which {
     CodeInstruct2B,
     #[value(name = "code-7b-it")]
     CodeInstruct7B,
+    #[value(name = "2-2b")]
+    BaseV2_2B,
+    #[value(name = "2-2b-it")]
+    InstructV2_2B,
+    #[value(name = "2-9b")]
+    BaseV2_9B,
+    #[value(name = "2-9b-it")]
+    InstructV2_9B,
+}
+
+impl Which {
+    fn is_v1(&self) -> bool {
+        match self {
+            Self::Base2B
+            | Self::Base7B
+            | Self::Instruct2B
+            | Self::Instruct7B
+            | Self::InstructV1_1_2B
+            | Self::InstructV1_1_7B
+            | Self::CodeBase2B
+            | Self::CodeBase7B
+            | Self::CodeInstruct2B
+            | Self::CodeInstruct7B => true,
+            Self::BaseV2_2B | Self::InstructV2_2B | Self::BaseV2_9B | Self::InstructV2_9B => false,
+        }
+    }
+}
+
+enum Model {
+    V1(Model1),
+    V2(Model2),
+}
+
+impl Model {
+    fn forward(&mut self, input_ids: &Tensor, pos: usize) -> candle::Result<Tensor> {
+        match self {
+            Self::V1(m) => m.forward(input_ids, pos),
+            Self::V2(m) => m.forward(input_ids, pos),
+        }
+    }
 }
 
 struct TextGeneration {
@@ -191,7 +232,7 @@ struct Args {
     repeat_last_n: usize,
 
     /// The model to use.
-    #[arg(long, default_value = "2b")]
+    #[arg(long, default_value = "2-2b")]
     which: Which,
 
     #[arg(long)]
@@ -239,6 +280,10 @@ fn main() -> Result<()> {
             Which::CodeBase7B => "google/codegemma-7b".to_string(),
             Which::CodeInstruct2B => "google/codegemma-2b-it".to_string(),
             Which::CodeInstruct7B => "google/codegemma-7b-it".to_string(),
+            Which::BaseV2_2B => "google/gemma-2-2b".to_string(),
+            Which::InstructV2_2B => "google/gemma-2-2b-it".to_string(),
+            Which::BaseV2_9B => "google/gemma-2-9b".to_string(),
+            Which::InstructV2_9B => "google/gemma-2-9b-it".to_string(),
         },
     };
     let repo = api.repo(Repo::with_revision(
@@ -263,7 +308,6 @@ fn main() -> Result<()> {
     };
     println!("retrieved the files in {:?}", start.elapsed());
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
-    let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
 
     let start = std::time::Instant::now();
     let device = candle_examples::device(args.cpu)?;
@@ -273,7 +317,15 @@ fn main() -> Result<()> {
         DType::F32
     };
     let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
-    let model = Model::new(args.use_flash_attn, &config, vb)?;
+    let model = if args.which.is_v1() {
+        let config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
+        let model = Model1::new(args.use_flash_attn, &config, vb)?;
+        Model::V1(model)
+    } else {
+        let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
+        let model = Model2::new(args.use_flash_attn, &config, vb)?;
+        Model::V2(model)
+    };
 
     println!("loaded the model in {:?}", start.elapsed());
 
candle-examples/examples/glm4/README.org (new file)
@@ -0,0 +1,77 @@
* GLM4
GLM-4-9B is the open-source version of the latest generation of pre-trained models in the GLM-4 series launched by Zhipu AI.

- [[https://github.com/THUDM/GLM4][Github]]
- [[https://huggingface.co/THUDM/glm-4-9b][huggingface]]

** Running with ~cuda~

#+begin_src shell
  cargo run --example glm4 --release --features cuda
#+end_src

** Running with ~cpu~
#+begin_src shell
  cargo run --example glm4 --release -- --cpu
#+end_src

** Output Example
#+begin_src shell
cargo run --example glm4 --release --features cuda -- --sample-len 500 --cache .
    Finished release [optimized] target(s) in 0.24s
     Running `/root/candle/target/release/examples/glm4 --sample-len 500 --cache .`
avx: true, neon: false, simd128: false, f16c: true
temp: 0.60 repeat-penalty: 1.20 repeat-last-n: 64
cache path .
retrieved the files in 6.88963ms
loaded the model in 6.113752297s
starting the inference loop
[欢迎使用GLM-4,请输入prompt]
请你告诉我什么是FFT
266 tokens generated (34.50 token/s)
Result:
。Fast Fourier Transform (FFT) 是一种快速计算离散傅里叶变换(DFT)的方法,它广泛应用于信号处理、图像处理和数据分析等领域。

具体来说,FFT是一种将时域数据转换为频域数据的算法。在数字信号处理中,我们通常需要知道信号的频率成分,这就需要进行傅立叶变换。传统的傅立叶变换的计算复杂度较高,而 FFT 则大大提高了计算效率,使得大规模的 DFT 换成为可能。

以下是使用 Python 中的 numpy 进行 FFT 的简单示例:

```python
import numpy as np

# 创建一个时域信号
t = np.linspace(0, 1, num=100)
f = np.sin(2*np.pi*5*t) + 3*np.cos(2*np.pi*10*t)

# 对该信号做FFT变换,并计算其幅值谱
fft_result = np.fft.fftshift(np.abs(np.fft.fft(f)))

```

在这个例子中,我们首先创建了一个时域信号 f。然后我们对这个信号进行了 FFT 换,得到了一个频域结果 fft_result。
#+end_src

This example reads the prompt from stdin.
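Since the prompt is read from stdin, the example can also be driven non-interactively; a minimal sketch (the piped question is an arbitrary example, not from the transcript above):

#+begin_src shell
  echo "What is a fast Fourier transform?" | cargo run --example glm4 --release -- --cpu
#+end_src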

* Citation
#+begin_src
@misc{glm2024chatglm,
      title={ChatGLM: A Family of Large Language Models from GLM-130B to GLM-4 All Tools},
      author={Team GLM and Aohan Zeng and Bin Xu and Bowen Wang and Chenhui Zhang and Da Yin and Diego Rojas and Guanyu Feng and Hanlin Zhao and Hanyu Lai and Hao Yu and Hongning Wang and Jiadai Sun and Jiajie Zhang and Jiale Cheng and Jiayi Gui and Jie Tang and Jing Zhang and Juanzi Li and Lei Zhao and Lindong Wu and Lucen Zhong and Mingdao Liu and Minlie Huang and Peng Zhang and Qinkai Zheng and Rui Lu and Shuaiqi Duan and Shudan Zhang and Shulin Cao and Shuxun Yang and Weng Lam Tam and Wenyi Zhao and Xiao Liu and Xiao Xia and Xiaohan Zhang and Xiaotao Gu and Xin Lv and Xinghan Liu and Xinyi Liu and Xinyue Yang and Xixuan Song and Xunkai Zhang and Yifan An and Yifan Xu and Yilin Niu and Yuantao Yang and Yueyan Li and Yushi Bai and Yuxiao Dong and Zehan Qi and Zhaoyu Wang and Zhen Yang and Zhengxiao Du and Zhenyu Hou and Zihan Wang},
      year={2024},
      eprint={2406.12793},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
#+end_src

#+begin_src
@misc{wang2023cogvlm,
      title={CogVLM: Visual Expert for Pretrained Language Models},
      author={Weihan Wang and Qingsong Lv and Wenmeng Yu and Wenyi Hong and Ji Qi and Yan Wang and Junhui Ji and Zhuoyi Yang and Lei Zhao and Xixuan Song and Jiazheng Xu and Bin Xu and Juanzi Li and Yuxiao Dong and Ming Ding and Jie Tang},
      year={2023},
      eprint={2311.03079},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
#+end_src
candle-examples/examples/glm4/main.rs (new file)
@@ -0,0 +1,255 @@
use candle_transformers::models::glm4::*;
use clap::Parser;

use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{Repo, RepoType};
use tokenizers::Tokenizer;

struct TextGeneration {
    model: Model,
    device: Device,
    tokenizer: Tokenizer,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
    verbose_prompt: bool,
    dtype: DType,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        verbose_prompt: bool,
        device: &Device,
        dtype: DType,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            tokenizer,
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            verbose_prompt,
            device: device.clone(),
            dtype,
        }
    }

    fn run(&mut self, sample_len: usize) -> anyhow::Result<()> {
        use std::io::BufRead;
        use std::io::BufReader;
        use std::io::Write;
        println!("starting the inference loop");
        println!("[欢迎使用GLM-4,请输入prompt]");
        let stdin = std::io::stdin();
        let reader = BufReader::new(stdin);
        for line in reader.lines() {
            let line = line.expect("Failed to read line");

            let tokens = self.tokenizer.encode(line, true).expect("tokens error");
            if tokens.is_empty() {
                panic!("Empty prompts are not supported in the chatglm model.")
            }
            if self.verbose_prompt {
                for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
                    let token = token.replace('▁', " ").replace("<0x0A>", "\n");
                    println!("{id:7} -> '{token}'");
                }
            }
            let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
                Some(token) => *token,
                None => panic!("cannot find the endoftext token"),
            };
            let mut tokens = tokens.get_ids().to_vec();
            let mut generated_tokens = 0usize;

            std::io::stdout().flush().expect("output flush error");
            let start_gen = std::time::Instant::now();

            let mut count = 0;
            let mut result = vec![];
            for index in 0..sample_len {
                count += 1;
                let context_size = if index > 0 { 1 } else { tokens.len() };
                let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
                let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
                let logits = self.model.forward(&input)?;
                let logits = logits.squeeze(0)?.to_dtype(self.dtype)?;
                let logits = if self.repeat_penalty == 1. {
                    logits
                } else {
                    let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                    candle_transformers::utils::apply_repeat_penalty(
                        &logits,
                        self.repeat_penalty,
                        &tokens[start_at..],
                    )?
                };

                let next_token = self.logits_processor.sample(&logits)?;
                tokens.push(next_token);
                generated_tokens += 1;
                if next_token == eos_token {
                    break;
                }
                let token = self
                    .tokenizer
                    .decode(&[next_token], true)
                    .expect("Token error");
                if self.verbose_prompt {
                    println!(
                        "[Count: {}] [Raw Token: {}] [Decode Token: {}]",
                        count, next_token, token
                    );
                }
                result.push(token);
                std::io::stdout().flush()?;
            }
            let dt = start_gen.elapsed();
            println!(
                "\n{generated_tokens} tokens generated ({:.2} token/s)",
                generated_tokens as f64 / dt.as_secs_f64(),
            );
            println!("Result:");
            for tokens in result {
                print!("{tokens}");
            }
            self.model.reset_kv_cache(); // clean the cache
        }
        Ok(())
    }
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// The cache path used for downloaded model files.
    #[arg(name = "cache", short, long, default_value = ".")]
    cache_path: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Display the token for the specified prompt.
    #[arg(long)]
    verbose_prompt: bool,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 8192)]
    sample_len: usize,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long)]
    revision: Option<String>,

    #[arg(long)]
    weight_file: Option<String>,

    #[arg(long)]
    tokenizer: Option<String>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.2)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.6),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    println!("cache path {}", args.cache_path);
    let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into()))
        .build()
        .map_err(anyhow::Error::msg)?;

    let model_id = match args.model_id {
        Some(model_id) => model_id.to_string(),
        None => "THUDM/glm-4-9b".to_string(),
    };
    let revision = match args.revision {
        Some(rev) => rev.to_string(),
        None => "main".to_string(),
    };
    let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
    let tokenizer_filename = match args.tokenizer {
        Some(file) => std::path::PathBuf::from(file),
        None => api
            .model("THUDM/codegeex4-all-9b".to_string())
            .get("tokenizer.json")
            .map_err(anyhow::Error::msg)?,
    };
    let filenames = match args.weight_file {
        Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
        None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_filename).expect("Tokenizer Error");

    let start = std::time::Instant::now();
    let config = Config::glm4();
    let device = candle_examples::device(args.cpu)?;
    let dtype = if device.is_cuda() {
        DType::BF16
    } else {
        DType::F32
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
    let model = Model::new(&config, vb)?;

    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        args.verbose_prompt,
        &device,
        dtype,
    );
    pipeline.run(args.sample_len)?;
    Ok(())
}
candle-examples/examples/granite/README.md (new file)
@@ -0,0 +1,20 @@
# candle-granite: LLMs from IBM Research

[Granite](https://www.ibm.com/granite) is a family of Large Language Models built for business, to help drive trust and scalability in AI-driven applications.

## Running the example

```bash
$ cargo run --example granite --features metal -r -- --model-type "granite7b-instruct" \
    --prompt "Explain how quantum computing differs from classical computing, focusing on key concepts like qubits, superposition, and entanglement. Describe two potential breakthroughs in the fields of drug discovery and cryptography. Offer a convincing argument for why businesses and governments should invest in quantum computing research now, emphasizing its future benefits and the risks of falling behind"

Explain how quantum computing differs from classical computing, focusing on key concepts like qubits, superposition, and entanglement. Describe two potential breakthroughs in the fields of drug discovery and cryptography. Offer a convincing argument for why businesses and governments should invest in quantum computing research now, emphasizing its future benefits and the risks of falling behind competitors.

In recent years, there has been significant interest in quantum computing due to its potential to revolutionize various fields, including drug discovery, cryptography, and optimization problems. Quantum computers, which leverage the principles of quantum mechanics, differ fundamentally from classical computers. Here are some of the key differences:
```

## Supported Models
There are two different modalities for the Granite family models: Language and Code.

### Granite for language
1. [Granite 7b Instruct](https://huggingface.co/ibm-granite/granite-7b-instruct)
candle-examples/examples/granite/main.rs (new file)
@@ -0,0 +1,251 @@
// An implementation of different Granite models https://www.ibm.com/granite

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use anyhow::{bail, Error as E, Result};
use clap::{Parser, ValueEnum};

use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write;

use candle_transformers::models::granite as model;
use model::{Granite, GraniteConfig};

use std::time::Instant;

const EOS_TOKEN: &str = "</s>";
const DEFAULT_PROMPT: &str = "How Fault Tolerant Quantum Computers will help humanity?";

#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum GraniteModel {
    Granite7bInstruct,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// The temperature used to generate samples.
    #[arg(long, default_value_t = 0.8)]
    temperature: f64,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// Only sample among the top K samples.
    #[arg(long)]
    top_k: Option<usize>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(short = 'n', long, default_value_t = 10000)]
    sample_len: usize,

    /// Disable the key-value cache.
    #[arg(long)]
    no_kv_cache: bool,

    /// The initial prompt.
    #[arg(long)]
    prompt: Option<String>,

    /// Use different dtype than f16
    #[arg(long)]
    dtype: Option<String>,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long)]
    revision: Option<String>,

    #[arg(long, default_value = "granite7b-instruct")]
    model_type: GraniteModel,

    #[arg(long)]
    use_flash_attn: bool,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 128)]
    repeat_last_n: usize,
}

fn main() -> Result<()> {
    use tokenizers::Tokenizer;
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    let device = candle_examples::device(args.cpu)?;
    let dtype = match args.dtype.as_deref() {
        Some("f16") => DType::F16,
        Some("bf16") => DType::BF16,
        Some("f32") => DType::F32,
        Some(dtype) => bail!("Unsupported dtype {dtype}"),
        None => DType::F16,
    };
    let (granite, tokenizer_filename, mut cache, config) = {
        let api = Api::new()?;
        let model_id = args.model_id.unwrap_or_else(|| match args.model_type {
            GraniteModel::Granite7bInstruct => "ibm-granite/granite-7b-instruct".to_string(),
        });
        println!("loading the model weights from {model_id}");
        let revision = args.revision.unwrap_or("main".to_string());
        let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));

        let tokenizer_filename = api.get("tokenizer.json")?;
        let config_filename = api.get("config.json")?;
        let config: GraniteConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
        let config = config.into_config(args.use_flash_attn);

        let filenames = match args.model_type {
            GraniteModel::Granite7bInstruct => {
                candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
            }
        };
        let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;

        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
        (
            Granite::load(vb, &config)?,
            tokenizer_filename,
            cache,
            config,
        )
    };
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
    let eos_token_id = config.eos_token_id.or_else(|| {
        tokenizer
            .token_to_id(EOS_TOKEN)
            .map(model::GraniteEosToks::Single)
    });

    let default_prompt = match args.model_type {
        GraniteModel::Granite7bInstruct => DEFAULT_PROMPT,
    };

    let prompt = args.prompt.as_ref().map_or(default_prompt, |p| p.as_str());
    let mut tokens = tokenizer
        .encode(prompt, true)
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
    let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);

    println!("Starting the inference loop:");
    print!("{prompt}");
    let mut logits_processor = {
        let temperature = args.temperature;
        let sampling = if temperature <= 0. {
            Sampling::ArgMax
        } else {
            match (args.top_k, args.top_p) {
                (None, None) => Sampling::All { temperature },
                (Some(k), None) => Sampling::TopK { k, temperature },
                (None, Some(p)) => Sampling::TopP { p, temperature },
                (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
            }
        };
        LogitsProcessor::from_sampling(args.seed, sampling)
    };

    // The timer is restarted at the second token so the prompt-processing
    // step is excluded from the reported token/s.
    let mut start_gen = std::time::Instant::now();
    let mut index_pos = 0;
    let mut token_generated = 0;
    let use_cache_kv = cache.use_kv_cache;

    (0..args.sample_len)
        .inspect(|index| {
            if *index == 1 {
                start_gen = Instant::now();
            }
        })
        .try_for_each(|index| -> Result<()> {
            let (context_size, context_index) = if use_cache_kv && index > 0 {
                (1, index_pos)
            } else {
                (tokens.len(), 0)
            };
            let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
            let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
            let logits = granite
                .forward(&input, context_index, &mut cache)?
                .squeeze(0)?;

            let logits = if args.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(args.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    args.repeat_penalty,
                    &tokens[start_at..],
                )?
            };

            index_pos += ctxt.len();

            let next_token = logits_processor.sample(&logits)?;
            token_generated += 1;
            tokens.push(next_token);

            // Reaching EOS is signalled via an Err to break out of
            // try_for_each; it is swallowed by the unwrap_or below.
            if let Some(model::GraniteEosToks::Single(eos_tok_id)) = eos_token_id {
                if next_token == eos_tok_id {
                    return Err(E::msg("EOS token found"));
                }
            } else if let Some(model::GraniteEosToks::Multiple(ref eos_ids)) = eos_token_id {
                if eos_ids.contains(&next_token) {
                    return Err(E::msg("EOS token found"));
                }
            }

            if let Some(t) = tokenizer.next_token(next_token)? {
                print!("{t}");
                std::io::stdout().flush()?;
            }
            Ok(())
        })
        .unwrap_or(());

    if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
        print!("{rest}");
    }

    let dt = start_gen.elapsed();
    println!(
        "\n\n{} tokens generated ({} token/s)\n",
        token_generated,
        (token_generated - 1) as f64 / dt.as_secs_f64(),
    );
    Ok(())
}
18  candle-examples/examples/hiera/README.md  Normal file
@@ -0,0 +1,18 @@
# hiera

[Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles](https://arxiv.org/abs/2306.00989)
This candle implementation uses pre-trained Hiera models from timm for inference.
The classification head has been trained on the ImageNet dataset and returns the probabilities for the top-5 classes.

## Running an example

```
$ cargo run --example hiera --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which tiny
loaded image Tensor[dims 3, 224, 224; f32]
model built
mountain bike, all-terrain bike, off-roader: 71.15%
unicycle, monocycle     : 7.11%
knee pad                : 4.26%
crash helmet            : 1.48%
moped                   : 1.07%
```
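Programmatically, the whole example boils down to the flow below; this is a condensed sketch of the `main.rs` shown next, assuming the default `tiny` variant and a CPU device.

```rust
// Condensed from the example's main.rs: load an ImageNet-normalized image,
// build the tiny Hiera model from the timm safetensors weights, and report
// the most likely class.
use candle::{DType, Device, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::hiera;
use candle::IndexOp;

fn classify(image_path: &str) -> anyhow::Result<()> {
    let device = Device::Cpu;
    let image = candle_examples::imagenet::load_image224(image_path)?.to_device(&device)?;
    let weights = hf_hub::api::sync::Api::new()?
        .model("timm/hiera_tiny_224.mae_in1k_ft_in1k".to_string())
        .get("model.safetensors")?;
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights], DType::F32, &device)? };
    let model = hiera::hiera(&hiera::Config::tiny(), 1000, vb)?;
    let logits = model.forward(&image.unsqueeze(0)?)?;
    // Softmax over the 1000 ImageNet classes, then print the top prediction.
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;
    if let Some((idx, p)) = prs
        .iter()
        .copied()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.total_cmp(b))
    {
        println!("{}: {:.2}%", candle_examples::imagenet::CLASSES[idx], 100. * p);
    }
    Ok(())
}
```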
99  candle-examples/examples/hiera/main.rs  Normal file
@@ -0,0 +1,99 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::{Parser, ValueEnum};

use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::hiera;

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
    Tiny,
    Small,
    Base,
    BasePlus,
    Large,
    Huge,
}

impl Which {
    fn model_filename(&self) -> String {
        let name = match self {
            Self::Tiny => "tiny",
            Self::Small => "small",
            Self::Base => "base",
            Self::BasePlus => "base_plus",
            Self::Large => "large",
            Self::Huge => "huge",
        };
        format!("timm/hiera_{}_224.mae_in1k_ft_in1k", name)
    }

    fn config(&self) -> hiera::Config {
        match self {
            Self::Tiny => hiera::Config::tiny(),
            Self::Small => hiera::Config::small(),
            Self::Base => hiera::Config::base(),
            Self::BasePlus => hiera::Config::base_plus(),
            Self::Large => hiera::Config::large(),
            Self::Huge => hiera::Config::huge(),
        }
    }
}

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    #[arg(value_enum, long, default_value_t=Which::Tiny)]
    which: Which,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {
        None => {
            let model_name = args.which.model_filename();
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model(model_name);
            api.get("model.safetensors")?
        }
        Some(model) => model.into(),
    };

    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let model = hiera::hiera(&args.which.config(), 1000, vb)?;
    println!("model built");
    let logits = model.forward(&image.unsqueeze(0)?)?;
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;
    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    for &(category_idx, pr) in prs.iter().take(5) {
        println!(
            "{:24}: {:.2}%",
            candle_examples::imagenet::CLASSES[category_idx],
            100. * pr
        );
    }
    Ok(())
}
@@ -4,7 +4,7 @@ extern crate intel_mkl_src;
 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;

-use candle_transformers::models::jina_bert::{BertModel, Config};
+use candle_transformers::models::jina_bert::{BertModel, Config, PositionEmbeddingType};

 use anyhow::Error as E;
 use candle::{DType, Module, Tensor};
@@ -39,32 +39,47 @@ struct Args {

     #[arg(long)]
     model: Option<String>,

+    #[arg(long)]
+    model_file: Option<String>,
 }

 impl Args {
     fn build_model_and_tokenizer(&self) -> anyhow::Result<(BertModel, tokenizers::Tokenizer)> {
         use hf_hub::{api::sync::Api, Repo, RepoType};
-        let model = match &self.model {
+        let model_name = match self.model.as_ref() {
+            Some(model) => model.to_string(),
+            None => "jinaai/jina-embeddings-v2-base-en".to_string(),
+        };
+
+        let model = match &self.model_file {
             Some(model_file) => std::path::PathBuf::from(model_file),
             None => Api::new()?
-                .repo(Repo::new(
-                    "jinaai/jina-embeddings-v2-base-en".to_string(),
-                    RepoType::Model,
-                ))
+                .repo(Repo::new(model_name.to_string(), RepoType::Model))
                 .get("model.safetensors")?,
         };
         let tokenizer = match &self.tokenizer {
             Some(file) => std::path::PathBuf::from(file),
             None => Api::new()?
-                .repo(Repo::new(
-                    "sentence-transformers/all-MiniLM-L6-v2".to_string(),
-                    RepoType::Model,
-                ))
+                .repo(Repo::new(model_name.to_string(), RepoType::Model))
                 .get("tokenizer.json")?,
         };
         let device = candle_examples::device(self.cpu)?;
-        let config = Config::v2_base();
         let tokenizer = tokenizers::Tokenizer::from_file(tokenizer).map_err(E::msg)?;
+        let config = Config::new(
+            tokenizer.get_vocab_size(true),
+            768,
+            12,
+            12,
+            3072,
+            candle_nn::Activation::Gelu,
+            8192,
+            2,
+            0.02,
+            1e-12,
+            0,
+            PositionEmbeddingType::Alibi,
+        );
         let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
         let model = BertModel::new(vb, &config)?;
         Ok((model, tokenizer))
@@ -101,14 +116,20 @@ fn main() -> anyhow::Result<()> {
             .to_vec();
         let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
         println!("Loaded and encoded {:?}", start.elapsed());
-        for idx in 0..args.n {
-            let start = std::time::Instant::now();
-            let ys = model.forward(&token_ids)?;
-            if idx == 0 {
-                println!("{ys}");
-            }
-            println!("Took {:?}", start.elapsed());
+        let start = std::time::Instant::now();
+        let embeddings = model.forward(&token_ids)?;
+        let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
+        let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
+        println!("pooled_embeddigns: {embeddings}");
+        let embeddings = if args.normalize_embeddings {
+            normalize_l2(&embeddings)?
+        } else {
+            embeddings
+        };
+        if args.normalize_embeddings {
+            println!("normalized_embeddings: {embeddings}");
         }
+        println!("Took {:?}", start.elapsed());
     } else {
         let sentences = [
             "The cat sits outside",
@@ -32,7 +32,9 @@ enum Which {
     V1,
     V2,
     V3,
+    V31,
     V3Instruct,
+    V31Instruct,
     #[value(name = "solar-10.7b")]
     Solar10_7B,
     #[value(name = "tiny-llama-1.1b-chat")]
@@ -133,6 +135,8 @@ fn main() -> Result<()> {
                 Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
                 Which::V3 => "meta-llama/Meta-Llama-3-8B".to_string(),
                 Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(),
+                Which::V31 => "meta-llama/Meta-Llama-3.1-8B".to_string(),
+                Which::V31Instruct => "meta-llama/Meta-Llama-3.1-8B-Instruct".to_string(),
                 Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
                 Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
             });
@@ -146,7 +150,13 @@ fn main() -> Result<()> {
         let config = config.into_config(args.use_flash_attn);

         let filenames = match args.which {
-            Which::V1 | Which::V2 | Which::V3 | Which::V3Instruct | Which::Solar10_7B => {
+            Which::V1
+            | Which::V2
+            | Which::V3
+            | Which::V3Instruct
+            | Which::V31
+            | Which::V31Instruct
+            | Which::Solar10_7B => {
                 candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
             }
             Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
@@ -157,9 +167,11 @@ fn main() -> Result<()> {
         (Llama::load(vb, &config)?, tokenizer_filename, cache, config)
     };
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
-    let eos_token_id = config
-        .eos_token_id
-        .or_else(|| tokenizer.token_to_id(EOS_TOKEN));
+    let eos_token_id = config.eos_token_id.or_else(|| {
+        tokenizer
+            .token_to_id(EOS_TOKEN)
+            .map(model::LlamaEosToks::Single)
+    });
     let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
     let mut tokens = tokenizer
         .encode(prompt, true)
@@ -217,8 +229,14 @@ fn main() -> Result<()> {
         token_generated += 1;
         tokens.push(next_token);

-        if Some(next_token) == eos_token_id {
-            break;
+        match eos_token_id {
+            Some(model::LlamaEosToks::Single(eos_tok_id)) if next_token == eos_tok_id => {
+                break;
+            }
+            Some(model::LlamaEosToks::Multiple(ref eos_ids)) if eos_ids.contains(&next_token) => {
+                break;
+            }
+            _ => (),
         }
         if let Some(t) = tokenizer.next_token(next_token)? {
             print!("{t}");
@@ -14,6 +14,7 @@ use clap::{Parser, ValueEnum};

 use candle::{DType, Device, Tensor};
 use candle_transformers::generation::LogitsProcessor;
+use candle_transformers::models::llama::LlamaEosToks;
 use cudarc::driver::safe::CudaDevice;
 use cudarc::nccl::safe::{Comm, Id};
 use hf_hub::{api::sync::Api, Repo, RepoType};
@@ -219,9 +220,16 @@ fn main() -> Result<()> {
         let next_token = logits_processor.sample(&logits)?;
         tokens.push(next_token);
         new_tokens.push(next_token);
-        if Some(next_token) == config.eos_token_id {
-            break;
+        match config.eos_token_id {
+            Some(LlamaEosToks::Single(eos_tok_id)) if next_token == eos_tok_id => {
+                break;
+            }
+            Some(LlamaEosToks::Multiple(ref eos_ids)) if eos_ids.contains(&next_token) => {
+                break;
+            }
+            _ => (),
         }

         if rank == 0 {
             if let Some(t) = tokenizer.next_token(next_token)? {
                 print!("{t}");
@@ -57,7 +57,7 @@ fn load_image<T: AsRef<std::path::Path>>(
     llava_config: &LLaVAConfig,
     dtype: DType,
 ) -> Result<((u32, u32), Tensor)> {
-    let img = image::io::Reader::open(path)?.decode()?;
+    let img = image::ImageReader::open(path)?.decode()?;
     let img_tensor = process_image(&img, processor, llava_config)?;
     Ok(((img.width(), img.height()), img_tensor.to_dtype(dtype)?))
 }
@@ -43,6 +43,14 @@ def import_protobuf(error_message=""):
     else:
         raise ImportError(PROTOBUF_IMPORT_ERROR.format(error_message))

+def _get_prepend_scheme(add_prefix_space: bool, original_tokenizer) -> str:
+    if add_prefix_space:
+        prepend_scheme = "always"
+        if hasattr(original_tokenizer, "legacy") and not original_tokenizer.legacy:
+            prepend_scheme = "first"
+    else:
+        prepend_scheme = "never"
+    return prepend_scheme

 class SentencePieceExtractor:
     """
@@ -519,13 +527,15 @@ class SpmConverter(Converter):
         )

     def pre_tokenizer(self, replacement, add_prefix_space):
-        return pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
+        prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
+        return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)

     def post_processor(self):
         return None

     def decoder(self, replacement, add_prefix_space):
-        return decoders.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
+        prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
+        return decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)

     def converted(self) -> Tokenizer:
         tokenizer = self.tokenizer(self.proto)
@@ -636,7 +646,8 @@ class DebertaV2Converter(SpmConverter):
         list_pretokenizers = []
         if self.original_tokenizer.split_by_punct:
             list_pretokenizers.append(pre_tokenizers.Punctuation(behavior="isolated"))
-        list_pretokenizers.append(pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space))
+        prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
+        list_pretokenizers.append(pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme))
         return pre_tokenizers.Sequence(list_pretokenizers)

     def normalizer(self, proto):
@@ -929,10 +940,11 @@ class PegasusConverter(SpmConverter):
         return proto.trainer_spec.unk_id + self.original_tokenizer.offset

     def pre_tokenizer(self, replacement, add_prefix_space):
+        prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
         return pre_tokenizers.Sequence(
             [
                 pre_tokenizers.WhitespaceSplit(),
-                pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space),
+                pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme),
             ]
         )

20  candle-examples/examples/mimi/README.md  Normal file
@@ -0,0 +1,20 @@
# candle-mimi

[Mimi](https://huggingface.co/kyutai/mimi) is a state of the art audio
compression model using an encoder/decoder architecture with residual vector
quantization. The candle implementation supports streaming, meaning that it is
possible to encode or decode a stream of audio tokens on the fly to provide
low latency interaction with an audio model.

## Running one example

Generating some audio tokens from an audio file:

```bash
wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
cargo run --example mimi --features mimi --release -- audio-to-code bria.mp3 bria.safetensors
```

And decoding the audio tokens back into a sound file:

```bash
cargo run --example mimi --features mimi --release -- code-to-audio bria.safetensors bria.wav
```
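When `--streaming N` is passed, the example splits the pcm data into chunks of `N` samples and encodes each chunk separately, so tokens become available with low latency instead of only after the whole file is processed. A minimal sketch of that loop, mirroring the `main.rs` further below and assuming a `Model` built as in the example:

```rust
use candle::{Device, Tensor};
use candle_transformers::models::mimi::Model;

// Encode pcm samples chunk by chunk, as the example does in --streaming mode;
// the per-chunk codes are concatenated along the time axis at the end.
fn encode_streaming(
    model: &mut Model,
    pcm: &[f32],
    chunk_size: usize,
    device: &Device,
) -> anyhow::Result<Tensor> {
    let mut code_chunks = vec![];
    for chunk in pcm.chunks(chunk_size) {
        // Shape the chunk as (batch=1, channels=1, samples).
        let chunk = Tensor::new(chunk, device)?.reshape((1, 1, ()))?;
        code_chunks.push(model.encode(&chunk)?);
    }
    Ok(Tensor::cat(&code_chunks, candle::D::Minus1)?)
}
```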
275  candle-examples/examples/mimi/audio_io.rs  Normal file
@@ -0,0 +1,275 @@
#![allow(unused)]
use anyhow::{Context, Result};
use std::sync::{Arc, Mutex};

pub const SAMPLE_RATE: usize = 24_000;

pub(crate) struct AudioOutputData_ {
    resampled_data: std::collections::VecDeque<f32>,
    resampler: rubato::FastFixedIn<f32>,
    output_buffer: Vec<f32>,
    input_buffer: Vec<f32>,
    input_len: usize,
}

impl AudioOutputData_ {
    pub(crate) fn new(input_sample_rate: usize, output_sample_rate: usize) -> Result<Self> {
        use rubato::Resampler;

        let resampled_data = std::collections::VecDeque::with_capacity(output_sample_rate * 10);
        let resample_ratio = output_sample_rate as f64 / input_sample_rate as f64;
        let resampler = rubato::FastFixedIn::new(
            resample_ratio,
            f64::max(resample_ratio, 1.0),
            rubato::PolynomialDegree::Septic,
            1024,
            1,
        )?;
        let input_buffer = resampler.input_buffer_allocate(true).remove(0);
        let output_buffer = resampler.output_buffer_allocate(true).remove(0);
        Ok(Self {
            resampled_data,
            resampler,
            input_buffer,
            output_buffer,
            input_len: 0,
        })
    }

    pub fn reset(&mut self) {
        use rubato::Resampler;
        self.output_buffer.fill(0.);
        self.input_buffer.fill(0.);
        self.resampler.reset();
        self.resampled_data.clear();
    }

    pub(crate) fn take_all(&mut self) -> Vec<f32> {
        let mut data = Vec::with_capacity(self.resampled_data.len());
        while let Some(elem) = self.resampled_data.pop_back() {
            data.push(elem);
        }
        data
    }

    pub(crate) fn is_empty(&self) -> bool {
        self.resampled_data.is_empty()
    }

    // Assumes that the input buffer is large enough.
    fn push_input_buffer(&mut self, samples: &[f32]) {
        self.input_buffer[self.input_len..self.input_len + samples.len()].copy_from_slice(samples);
        self.input_len += samples.len()
    }

    pub(crate) fn push_samples(&mut self, samples: &[f32]) -> Result<()> {
        use rubato::Resampler;

        let mut pos_in = 0;
        loop {
            let rem = self.input_buffer.len() - self.input_len;
            let pos_end = usize::min(pos_in + rem, samples.len());
            self.push_input_buffer(&samples[pos_in..pos_end]);
            pos_in = pos_end;
            if self.input_len < self.input_buffer.len() {
                break;
            }
            let (_, out_len) = self.resampler.process_into_buffer(
                &[&self.input_buffer],
                &mut [&mut self.output_buffer],
                None,
            )?;
            for &elem in self.output_buffer[..out_len].iter() {
                self.resampled_data.push_front(elem)
            }
            self.input_len = 0;
        }
        Ok(())
    }
}

type AudioOutputData = Arc<Mutex<AudioOutputData_>>;

pub(crate) fn setup_output_stream() -> Result<(cpal::Stream, AudioOutputData)> {
    use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};

    println!("Setup audio output stream!");
    let host = cpal::default_host();
    let device = host
        .default_output_device()
        .context("no output device available")?;
    let mut supported_configs_range = device.supported_output_configs()?;
    let config_range = match supported_configs_range.find(|c| c.channels() == 1) {
        // On macOS, it's commonly the case that there are only stereo outputs.
        None => device
            .supported_output_configs()?
            .next()
            .context("no audio output available")?,
        Some(config_range) => config_range,
    };
    let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
        config_range.min_sample_rate(),
        config_range.max_sample_rate(),
    );
    let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
    let channels = config.channels as usize;
    println!(
        "cpal device: {} {} {config:?}",
        device.name().unwrap_or_else(|_| "unk".to_string()),
        config.sample_rate.0
    );
    let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
        SAMPLE_RATE,
        config.sample_rate.0 as usize,
    )?));
    let ad = audio_data.clone();
    let stream = device.build_output_stream(
        &config,
        move |data: &mut [f32], _: &cpal::OutputCallbackInfo| {
            data.fill(0.);
            let mut ad = ad.lock().unwrap();
            let mut last_elem = 0f32;
            for (idx, elem) in data.iter_mut().enumerate() {
                if idx % channels == 0 {
                    match ad.resampled_data.pop_back() {
                        None => break,
                        Some(v) => {
                            last_elem = v;
                            *elem = v
                        }
                    }
                } else {
                    *elem = last_elem
                }
            }
        },
        move |err| eprintln!("cpal error: {err}"),
        None, // None=blocking, Some(Duration)=timeout
    )?;
    stream.play()?;
    Ok((stream, audio_data))
}

pub(crate) fn setup_input_stream() -> Result<(cpal::Stream, AudioOutputData)> {
    use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};

    println!("Setup audio input stream!");
    let host = cpal::default_host();
    let device = host
        .default_input_device()
        .context("no input device available")?;
    let mut supported_configs_range = device.supported_input_configs()?;
    let config_range = supported_configs_range
        .find(|c| c.channels() == 1)
        .context("no audio input available")?;
    let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
        config_range.min_sample_rate(),
        config_range.max_sample_rate(),
    );
    let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
    println!(
        "cpal device: {} {} {config:?}",
        device.name().unwrap_or_else(|_| "unk".to_string()),
        config.sample_rate.0
    );
    let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
        config.sample_rate.0 as usize,
        SAMPLE_RATE,
    )?));
    let ad = audio_data.clone();
    let stream = device.build_input_stream(
        &config,
        move |data: &[f32], _: &cpal::InputCallbackInfo| {
            let mut ad = ad.lock().unwrap();
            if let Err(err) = ad.push_samples(data) {
                eprintln!("error processing audio input {err:?}")
            }
        },
        move |err| eprintln!("cpal error: {err}"),
        None, // None=blocking, Some(Duration)=timeout
    )?;
    stream.play()?;
    Ok((stream, audio_data))
}

fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
where
    T: symphonia::core::sample::Sample,
    f32: symphonia::core::conv::FromSample<T>,
{
    use symphonia::core::audio::Signal;
    use symphonia::core::conv::FromSample;
    samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
}

pub(crate) fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> Result<(Vec<f32>, u32)> {
    use symphonia::core::audio::{AudioBufferRef, Signal};

    let src = std::fs::File::open(path)?;
    let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
    let hint = symphonia::core::probe::Hint::new();
    let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
    let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
    let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
    let mut format = probed.format;
    let track = format
        .tracks()
        .iter()
        .find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
        .expect("no supported audio tracks");
    let mut decoder = symphonia::default::get_codecs()
        .make(&track.codec_params, &Default::default())
        .expect("unsupported codec");
    let track_id = track.id;
    let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
    let mut pcm_data = Vec::new();
    while let Ok(packet) = format.next_packet() {
        while !format.metadata().is_latest() {
            format.metadata().pop();
        }
        if packet.track_id() != track_id {
            continue;
        }
        match decoder.decode(&packet)? {
            AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
            AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
            AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
        }
    }
    Ok((pcm_data, sample_rate))
}

pub(crate) fn resample(pcm_in: &[f32], sr_in: usize, sr_out: usize) -> Result<Vec<f32>> {
    use rubato::Resampler;

    let mut pcm_out =
        Vec::with_capacity((pcm_in.len() as f64 * sr_out as f64 / sr_in as f64) as usize + 1024);

    let mut resampler = rubato::FftFixedInOut::<f32>::new(sr_in, sr_out, 1024, 1)?;
    let mut output_buffer = resampler.output_buffer_allocate(true);
    let mut pos_in = 0;
    while pos_in + resampler.input_frames_next() < pcm_in.len() {
        let (in_len, out_len) =
            resampler.process_into_buffer(&[&pcm_in[pos_in..]], &mut output_buffer, None)?;
        pos_in += in_len;
        pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
    }

    if pos_in < pcm_in.len() {
        let (_in_len, out_len) = resampler.process_partial_into_buffer(
            Some(&[&pcm_in[pos_in..]]),
            &mut output_buffer,
            None,
        )?;
        pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
    }

    Ok(pcm_out)
}
165  candle-examples/examples/mimi/main.rs  Normal file
@@ -0,0 +1,165 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Result;
use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::mimi::{Config, Model};
use clap::{Parser, ValueEnum};
use hf_hub::api::sync::Api;

mod audio_io;

#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Action {
    AudioToAudio,
    AudioToCode,
    CodeToAudio,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// The action to be performed, specifies the format for the input and output data.
    action: Action,

    /// The input file, either an audio file or some mimi tokens stored as safetensors.
    in_file: String,

    /// The output file, either a wave audio file or some mimi tokens stored as safetensors.
    out_file: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// The model weight file, in safetensor format.
    #[arg(long)]
    model: Option<String>,

    /// Whether to use streaming or not, when streaming slices of data of the given size are passed
    /// to the encoder/decoder one at a time.
    #[arg(long)]
    streaming: Option<usize>,
}

fn main() -> Result<()> {
    let args = Args::parse();
    let device = candle_examples::device(args.cpu)?;
    let model = match args.model {
        Some(model) => std::path::PathBuf::from(model),
        None => Api::new()?
            .model("kyutai/mimi".to_string())
            .get("model.safetensors")?,
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
    let config = Config::v0_1(None);
    let mut model = Model::new(config, vb)?;

    let codes = match args.action {
        Action::CodeToAudio => {
            let codes = candle::safetensors::load(args.in_file, &device)?;
            codes.get("codes").expect("no codes in input file").clone()
        }
        Action::AudioToCode | Action::AudioToAudio => {
            let pcm = if args.in_file == "-" {
                println!(">>>> RECORDING AUDIO, PRESS ENTER ONCE DONE <<<<");
                let (stream, input_audio) = audio_io::setup_input_stream()?;
                let mut pcms = vec![];
                let stdin = std::thread::spawn(|| {
                    let mut s = String::new();
                    std::io::stdin().read_line(&mut s)
                });
                while !stdin.is_finished() {
                    let input = input_audio.lock().unwrap().take_all();
                    if input.is_empty() {
                        std::thread::sleep(std::time::Duration::from_millis(100));
                        continue;
                    }
                    pcms.push(input)
                }
                drop(stream);
                pcms.concat()
            } else {
                let (pcm, sample_rate) = audio_io::pcm_decode(args.in_file)?;
                if sample_rate != 24_000 {
                    println!("WARNING: mimi uses a 24khz sample rate, input uses {sample_rate}, resampling...");
                    audio_io::resample(&pcm, sample_rate as usize, 24_000)?
                } else {
                    pcm
                }
            };
            match args.streaming {
                Some(chunk_size) => {
                    let mut code_chunks = vec![];
                    for pcm in pcm.chunks(chunk_size) {
                        let pcm = Tensor::new(pcm, &device)?.reshape((1, 1, ()))?;
                        let code_chunk = model.encode(&pcm)?;
                        code_chunks.push(code_chunk)
                    }
                    Tensor::cat(&code_chunks, candle::D::Minus1)?
                }
                None => {
                    let pcm_len = pcm.len();
                    let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
                    println!("input pcm shape: {:?}", pcm.shape());
                    model.encode(&pcm)?
                }
            }
        }
    };
    println!("codes shape: {:?}", codes.shape());
    model.reset_state();

    match args.action {
        Action::AudioToCode => {
            codes.save_safetensors("codes", &args.out_file)?;
        }
        Action::AudioToAudio | Action::CodeToAudio => {
            let pcm = match args.streaming {
                Some(chunk_size) => {
                    let seq_len = codes.dim(candle::D::Minus1)?;
                    let mut pcm_chunks = vec![];
                    for chunk_start in (0..seq_len).step_by(chunk_size) {
                        let chunk_len = usize::min(chunk_size, seq_len - chunk_start);
                        let codes = codes.narrow(candle::D::Minus1, chunk_start, chunk_len)?;
                        let pcm = model.decode_step(&codes.into())?;
                        if let Some(pcm) = pcm.as_option() {
                            pcm_chunks.push(pcm.clone())
                        }
                    }
                    Tensor::cat(&pcm_chunks, candle::D::Minus1)?
                }
                None => model.decode(&codes)?,
            };
            println!("output pcm shape: {:?}", pcm.shape());
            let pcm = pcm.i(0)?.i(0)?;
            let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
            let pcm = pcm.to_vec1::<f32>()?;
            if args.out_file == "-" {
                let (stream, ad) = audio_io::setup_output_stream()?;
                {
                    let mut ad = ad.lock().unwrap();
                    ad.push_samples(&pcm)?;
                }
                loop {
                    let ad = ad.lock().unwrap();
                    if ad.is_empty() {
                        break;
                    }
                    // That's very weird, calling thread::sleep here triggers the stream to stop
                    // playing (the callback doesn't seem to be called anymore).
                    // std::thread::sleep(std::time::Duration::from_millis(100));
                }
                drop(stream)
            } else {
                let mut output = std::fs::File::create(&args.out_file)?;
                candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
            }
        }
    }
    Ok(())
}
@@ -147,6 +147,12 @@ enum Which {
     Mistral7bInstructV01,
     #[value(name = "7b-instruct-v0.2")]
     Mistral7bInstructV02,
+    #[value(name = "7b-maths-v0.1")]
+    Mathstral7bV01,
+    #[value(name = "nemo-2407")]
+    MistralNemo2407,
+    #[value(name = "nemo-instruct-2407")]
+    MistralNemoInstruct2407,
 }

 #[derive(Parser, Debug)]
@@ -261,12 +267,16 @@ fn main() -> Result<()> {
             }
             "lmz/candle-mistral".to_string()
         } else {
-            match args.which {
-                Which::Mistral7bV01 => "mistralai/Mistral-7B-v0.1".to_string(),
-                Which::Mistral7bV02 => "mistralai/Mistral-7B-v0.2".to_string(),
-                Which::Mistral7bInstructV01 => "mistralai/Mistral-7B-Instruct-v0.1".to_string(),
-                Which::Mistral7bInstructV02 => "mistralai/Mistral-7B-Instruct-v0.2".to_string(),
-            }
+            let name = match args.which {
+                Which::Mistral7bV01 => "mistralai/Mistral-7B-v0.1",
+                Which::Mistral7bV02 => "mistralai/Mistral-7B-v0.2",
+                Which::Mistral7bInstructV01 => "mistralai/Mistral-7B-Instruct-v0.1",
+                Which::Mistral7bInstructV02 => "mistralai/Mistral-7B-Instruct-v0.2",
+                Which::Mathstral7bV01 => "mistralai/mathstral-7B-v0.1",
+                Which::MistralNemo2407 => "mistralai/Mistral-Nemo-Base-2407",
+                Which::MistralNemoInstruct2407 => "mistralai/Mistral-Nemo-Instruct-2407",
+            };
+            name.to_string()
         }
     }
 };
@@ -217,11 +217,7 @@ fn main() -> Result<()> {
     let start = std::time::Instant::now();
     let config = Config::v0_1_8x7b(args.use_flash_attn);
     let device = candle_examples::device(args.cpu)?;
-    let dtype = if device.is_cuda() {
-        DType::BF16
-    } else {
-        DType::F32
-    };
+    let dtype = device.bf16_default_to_f32();
     let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
     let model = Model::new(&config, vb)?;
     println!("loaded the model in {:?}", start.elapsed());
28  candle-examples/examples/mobileclip/README.md  Normal file
@@ -0,0 +1,28 @@
# candle-mobileclip

MobileCLIP is a family of efficient CLIP-like models using FastViT-based image encoders.

See [MobileCLIP: Fast Image-Text Models through Multi-Modal Reinforced Training](https://arxiv.org/abs/2311.17049)

## Running an example on cpu

```
$ cargo run --example mobileclip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "a cycling race","a photo of two cats","a robot holding a candle"

softmax_image_vec: [2.4819004e-5, 3.81081e-6, 0.9999714, 0.9999738, 2.382714e-5, 2.3317718e-6]

Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg

Probability: 0.0025% Text: a cycling race
Probability: 0.0004% Text: a photo of two cats
Probability: 99.9971% Text: a robot holding a candle

Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg

Probability: 99.9974% Text: a cycling race
Probability: 0.0024% Text: a photo of two cats
Probability: 0.0002% Text: a robot holding a candle
```
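The `softmax_image_vec` line above is a flattened image-by-text probability matrix: one row per image, one column per text sequence. A condensed sketch of how the example's `main.rs` below recovers the per-image percentages from it:

```rust
// Split the flattened softmax output back into one probability row per image
// and scale each entry to a percentage before printing.
fn per_image_probabilities(softmax_image_vec: &[f32], n_images: usize) -> Vec<Vec<f32>> {
    let per_image = softmax_image_vec.len() / n_images;
    softmax_image_vec
        .chunks(per_image)
        .map(|row| row.iter().map(|v| v * 100.0).collect())
        .collect()
}
```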
192  candle-examples/examples/mobileclip/main.rs  Normal file
@@ -0,0 +1,192 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Error as E;
use clap::{Parser, ValueEnum};

use candle::{DType, Device, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use candle_transformers::models::mobileclip;

use tokenizers::Tokenizer;

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
    S1,
    S2,
}

impl Which {
    fn model_name(&self) -> String {
        let name = match self {
            Self::S1 => "S1",
            Self::S2 => "S2",
        };
        format!("apple/MobileCLIP-{}-OpenCLIP", name)
    }

    fn config(&self) -> mobileclip::MobileClipConfig {
        match self {
            Self::S1 => mobileclip::MobileClipConfig::s1(),
            Self::S2 => mobileclip::MobileClipConfig::s2(),
        }
    }
}

#[derive(Parser)]
struct Args {
    #[arg(long, use_value_delimiter = true)]
    images: Option<Vec<String>>,

    #[arg(long)]
    cpu: bool,

    /// Use the pytorch weights rather than the safetensors ones
    #[arg(long)]
    use_pth: bool,

    #[arg(long, use_value_delimiter = true)]
    sequences: Option<Vec<String>>,

    #[arg(value_enum, long, default_value_t=Which::S1)]
    which: Which,
}

fn load_images<T: AsRef<std::path::Path>>(
    paths: &Vec<T>,
    image_size: usize,
) -> anyhow::Result<Tensor> {
    let mut images = vec![];

    for path in paths {
        let tensor = candle_examples::imagenet::load_image_with_std_mean(
            path,
            image_size,
            &[0.0, 0.0, 0.0],
            &[1.0, 1.0, 1.0],
        )?;
        images.push(tensor);
    }

    let images = Tensor::stack(&images, 0)?;

    Ok(images)
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let model_name = args.which.model_name();

    let api = hf_hub::api::sync::Api::new()?;
    let api = api.model(model_name);

    let model_file = if args.use_pth {
        api.get("open_clip_pytorch_model.bin")?
    } else {
        api.get("open_clip_model.safetensors")?
    };

    let tokenizer = api.get("tokenizer.json")?;

    let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;

    let config = &args.which.config();

    let device = candle_examples::device(args.cpu)?;

    let vec_imgs = match args.images {
        Some(imgs) => imgs,
        None => vec![
            "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(),
            "candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
        ],
    };

    let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?;

    let vb = if args.use_pth {
        VarBuilder::from_pth(&model_file, DType::F32, &device)?
    } else {
        unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? }
    };

    let model = mobileclip::MobileClipModel::new(vb, config)?;

    let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;

    let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;

    let softmax_image = softmax(&logits_per_image, 1)?;

    let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;

    println!("softmax_image_vec: {:?}", softmax_image_vec);

    let probability_vec = softmax_image_vec
        .iter()
        .map(|v| v * 100.0)
        .collect::<Vec<f32>>();

    let probability_per_image = probability_vec.len() / vec_imgs.len();

    for (i, img) in vec_imgs.iter().enumerate() {
        let start = i * probability_per_image;
        let end = start + probability_per_image;
        let prob = &probability_vec[start..end];
        println!("\n\nResults for image: {}\n", img);

        for (i, p) in prob.iter().enumerate() {
            println!("Probability: {:.4}% Text: {}", p, vec_seq[i]);
        }
    }

    Ok(())
}

pub fn tokenize_sequences(
    sequences: Option<Vec<String>>,
    tokenizer: &Tokenizer,
    device: &Device,
) -> anyhow::Result<(Tensor, Vec<String>)> {
    // let pad_id = *tokenizer
    //     .get_vocab(true)
    //     .get("<|endoftext|>")
    //     .ok_or(E::msg("No pad token"))?;

    // The model does not work well if the text is padded using the <|endoftext|> token, using 0
    // as the original OpenCLIP code.
    let pad_id = 0;

    let vec_seq = match sequences {
        Some(seq) => seq,
        None => vec![
            "a cycling race".to_string(),
            "a photo of two cats".to_string(),
            "a robot holding a candle".to_string(),
        ],
    };

    let mut tokens = vec![];

    for seq in vec_seq.clone() {
        let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
        tokens.push(encoding.get_ids().to_vec());
    }

    let max_len = tokens.iter().map(|v| v.len()).max().unwrap_or(0);
    // Pad the sequences to have the same length
    for token_vec in tokens.iter_mut() {
        let len_diff = max_len - token_vec.len();
        if len_diff > 0 {
            token_vec.extend(vec![pad_id; len_diff]);
        }
    }

    let input_ids = Tensor::new(tokens, device)?;

    Ok((input_ids, vec_seq))
}
18  candle-examples/examples/mobilenetv4/README.md  Normal file
@@ -0,0 +1,18 @@
# candle-mobilenetv4

[MobileNetV4 - Universal Models for the Mobile Ecosystem](https://arxiv.org/abs/2404.10518)
This candle implementation uses pre-trained MobileNetV4 models from timm for inference.
The classification head has been trained on the ImageNet dataset and returns the probabilities for the top-5 classes.

## Running an example

```
$ cargo run --example mobilenetv4 --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which medium
loaded image Tensor[dims 3, 256, 256; f32]
model built
unicycle, monocycle     : 20.18%
mountain bike, all-terrain bike, off-roader: 19.77%
bicycle-built-for-two, tandem bicycle, tandem: 15.91%
crash helmet            : 1.15%
tricycle, trike, velocipede: 0.67%
```
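Note that each variant expects a specific input resolution, which is why the loaded image above is 256x256 for `medium`; the mapping below mirrors the `resolution()` helper in the example's `main.rs` shown next:

```rust
#[derive(Clone, Copy)]
enum Which {
    Small,
    Medium,
    HybridMedium,
    Large,
    HybridLarge,
}

// Input resolution per variant; the image is resized to this size before
// being fed to the model (matching main.rs below).
fn resolution(which: Which) -> u32 {
    match which {
        Which::Small => 224,
        Which::Medium | Which::HybridMedium => 256,
        Which::Large | Which::HybridLarge => 384,
    }
}
```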
107  candle-examples/examples/mobilenetv4/main.rs  Normal file
@@ -0,0 +1,107 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::{Parser, ValueEnum};

use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::mobilenetv4;

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
    Small,
    Medium,
    Large,
    HybridMedium,
    HybridLarge,
}

impl Which {
    fn model_filename(&self) -> String {
        let name = match self {
            Self::Small => "conv_small.e2400_r224",
            Self::Medium => "conv_medium.e500_r256",
            Self::HybridMedium => "hybrid_medium.ix_e550_r256",
            Self::Large => "conv_large.e600_r384",
            Self::HybridLarge => "hybrid_large.ix_e600_r384",
        };
        format!("timm/mobilenetv4_{}_in1k", name)
    }

    fn resolution(&self) -> u32 {
        match self {
            Self::Small => 224,
            Self::Medium => 256,
            Self::HybridMedium => 256,
            Self::Large => 384,
            Self::HybridLarge => 384,
        }
    }
    fn config(&self) -> mobilenetv4::Config {
        match self {
            Self::Small => mobilenetv4::Config::small(),
            Self::Medium => mobilenetv4::Config::medium(),
            Self::HybridMedium => mobilenetv4::Config::hybrid_medium(),
            Self::Large => mobilenetv4::Config::large(),
            Self::HybridLarge => mobilenetv4::Config::hybrid_large(),
        }
    }
}

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    #[arg(value_enum, long, default_value_t=Which::Small)]
    which: Which,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let device = candle_examples::device(args.cpu)?;

    let image =
        candle_examples::imagenet::load_image(args.image, args.which.resolution() as usize)?
            .to_device(&device)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {
        None => {
            let model_name = args.which.model_filename();
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model(model_name);
            api.get("model.safetensors")?
        }
        Some(model) => model.into(),
    };

    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let model = mobilenetv4::mobilenetv4(&args.which.config(), 1000, vb)?;
    println!("model built");
    let logits = model.forward(&image.unsqueeze(0)?)?;
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;
    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    for &(category_idx, pr) in prs.iter().take(5) {
        println!(
            "{:24}: {:.2}%",
            candle_examples::imagenet::CLASSES[category_idx],
            100. * pr
        );
    }
    Ok(())
}
@@ -188,8 +188,8 @@ struct Args {
     #[arg(long)]
     model_id: Option<String>,

-    #[arg(long, default_value = "main")]
-    revision: String,
+    #[arg(long)]
+    revision: Option<String>,

     #[arg(long)]
     quantized: bool,
@@ -208,7 +208,7 @@ struct Args {
 /// Loads an image from disk using the image crate, this returns a tensor with shape
 /// (3, 378, 378).
 pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> candle::Result<Tensor> {
-    let img = image::io::Reader::open(p)?
+    let img = image::ImageReader::open(p)?
         .decode()
         .map_err(candle::Error::wrap)?
         .resize_to_fill(378, 378, image::imageops::FilterType::Triangle); // Adjusted to 378x378
@@ -252,20 +252,28 @@ async fn main() -> anyhow::Result<()> {

     let start = std::time::Instant::now();
     let api = hf_hub::api::tokio::Api::new()?;
-    let model_id = match args.model_id {
-        Some(model_id) => model_id.to_string(),
+    let (model_id, revision) = match args.model_id {
+        Some(model_id) => (model_id.to_string(), None),
         None => {
             if args.quantized {
-                "santiagomed/candle-moondream".to_string()
+                ("santiagomed/candle-moondream".to_string(), None)
             } else {
-                "vikhyatk/moondream2".to_string()
+                (
+                    "vikhyatk/moondream2".to_string(),
+                    Some("30c7cdf3fa6914f50bee3956694374143f5cc884"),
+                )
             }
         }
     };
+    let revision = match (args.revision, revision) {
+        (Some(r), _) => r,
+        (None, Some(r)) => r.to_string(),
+        (None, None) => "main".to_string(),
+    };
     let repo = api.repo(hf_hub::Repo::with_revision(
         model_id,
         hf_hub::RepoType::Model,
-        args.revision,
+        revision,
     ));
     let model_file = match args.model_file {
         Some(m) => m.into(),
@@ -284,11 +284,11 @@ impl MusicgenDecoder {
     };
     let embed_dim = cfg.vocab_size + 1;
     let embed_tokens = (0..cfg.num_codebooks)
-        .map(|i| embedding(embed_dim, h, vb.pp(&format!("embed_tokens.{i}"))))
+        .map(|i| embedding(embed_dim, h, vb.pp(format!("embed_tokens.{i}"))))
         .collect::<Result<Vec<_>>>()?;
     let embed_positions = MusicgenSinusoidalPositionalEmbedding::load(vb.clone(), cfg)?;
     let layers = (0..cfg.num_hidden_layers)
-        .map(|i| MusicgenDecoderLayer::load(vb.pp(&format!("layers.{i}")), cfg))
+        .map(|i| MusicgenDecoderLayer::load(vb.pp(format!("layers.{i}")), cfg))
         .collect::<Result<Vec<_>>>()?;
     let layer_norm = layer_norm(h, 1e-5, vb.pp("layer_norm"))?;
     Ok(Self {
@@ -341,7 +341,7 @@ impl MusicgenForCausalLM {
     let h = cfg.hidden_size;
     let decoder = MusicgenDecoder::load(vb.pp("model.decoder"), cfg)?;
     let lm_heads = (0..cfg.num_codebooks)
-        .map(|i| linear_no_bias(h, cfg.vocab_size, vb.pp(&format!("lm_heads.{i}"))))
+        .map(|i| linear_no_bias(h, cfg.vocab_size, vb.pp(format!("lm_heads.{i}"))))
         .collect::<Result<Vec<_>>>()?;
     Ok(Self {
         decoder,
candle-examples/examples/parler-tts/README.md (new file, 23 lines)
@@ -0,0 +1,23 @@
# candle-parler-tts

[Parler-TTS](https://huggingface.co/parler-tts/parler-tts-large-v1) is a large
text-to-speech model with 2.2B parameters trained on ~45K hours of audio data.
The voice can be controlled by a text prompt.

## Run an example

```bash
cargo run --example parler-tts -r -- \
  --prompt "Hey, how are you doing today?"
```

To control the voice characteristics, pass a text description via the `--description` argument.

```bash
cargo run --example parler-tts -r -- \
  --prompt "Hey, how are you doing today?" \
  --description "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch. The recording is of very high quality, with the speaker's voice sounding clear and very close up."
```

https://github.com/user-attachments/assets/1b16aeac-70a3-4803-8589-4563279bba33
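A smaller `mini-v1` checkpoint is also supported; based on the `--which` and `--out-file` flags defined in the example's main.rs below, a usage sketch:

```bash
# Run the mini checkpoint and write the audio to speech.wav instead of out.wav.
cargo run --example parler-tts -r -- \
  --prompt "Hey, how are you doing today?" \
  --which mini-v1 \
  --out-file speech.wav
```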
candle-examples/examples/parler-tts/hello.mp4 (new binary file, not shown)
candle-examples/examples/parler-tts/main.rs (new file, 206 lines)
@@ -0,0 +1,206 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Error as E;
use clap::Parser;

use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::parler_tts::{Config, Model};
use tokenizers::Tokenizer;

#[derive(Parser)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// Display the token for the specified prompt.
    #[arg(long)]
    verbose_prompt: bool,

    #[arg(long, default_value = "Hey, how are you doing today?")]
    prompt: String,

    #[arg(
        long,
        default_value = "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch. The recording is of very high quality, with the speaker's voice sounding clear and very close up."
    )]
    description: String,

    /// The temperature used to generate samples.
    #[arg(long, default_value_t = 0.0)]
    temperature: f64,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 0)]
    seed: u64,

    #[arg(long, default_value_t = 5000)]
    sample_len: usize,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.0)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long)]
    revision: Option<String>,

    #[arg(long)]
    quantized: bool,

    /// Use f16 precision for all the computations rather than f32.
    #[arg(long)]
    f16: bool,

    #[arg(long)]
    model_file: Option<String>,

    #[arg(long)]
    tokenizer_file: Option<String>,

    #[arg(long)]
    config_file: Option<String>,

    #[arg(long, default_value_t = 512)]
    max_steps: usize,

    /// The output wav file.
    #[arg(long, default_value = "out.wav")]
    out_file: String,

    #[arg(long, default_value = "large-v1")]
    which: Which,
}

#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
    #[value(name = "large-v1")]
    LargeV1,
    #[value(name = "mini-v1")]
    MiniV1,
}

fn main() -> anyhow::Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();

    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature, args.repeat_penalty, args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = hf_hub::api::sync::Api::new()?;
    let model_id = match args.model_id {
        Some(model_id) => model_id.to_string(),
        None => match args.which {
            Which::LargeV1 => "parler-tts/parler-tts-large-v1".to_string(),
            Which::MiniV1 => "parler-tts/parler-tts-mini-v1".to_string(),
        },
    };
    let revision = match args.revision {
        Some(r) => r,
        None => "main".to_string(),
    };
    let repo = api.repo(hf_hub::Repo::with_revision(
        model_id,
        hf_hub::RepoType::Model,
        revision,
    ));
    let model_files = match args.model_file {
        Some(m) => vec![m.into()],
        None => match args.which {
            Which::MiniV1 => vec![repo.get("model.safetensors")?],
            Which::LargeV1 => {
                candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
            }
        },
    };
    let config = match args.config_file {
        Some(m) => m.into(),
        None => repo.get("config.json")?,
    };
    let tokenizer = match args.tokenizer_file {
        Some(m) => m.into(),
        None => repo.get("tokenizer.json")?,
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let device = candle_examples::device(args.cpu)?;
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&model_files, DType::F32, &device)? };
    let config: Config = serde_json::from_reader(std::fs::File::open(config)?)?;
    let mut model = Model::new(&config, vb)?;
    println!("loaded the model in {:?}", start.elapsed());

    let description_tokens = tokenizer
        .encode(args.description, true)
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
    let description_tokens = Tensor::new(description_tokens, &device)?.unsqueeze(0)?;
    let prompt_tokens = tokenizer
        .encode(args.prompt, true)
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
    let prompt_tokens = Tensor::new(prompt_tokens, &device)?.unsqueeze(0)?;
    let lp = candle_transformers::generation::LogitsProcessor::new(
        args.seed,
        Some(args.temperature),
        args.top_p,
    );
    println!("starting generation...");
    let codes = model.generate(&prompt_tokens, &description_tokens, lp, args.max_steps)?;
    println!("generated codes\n{codes}");
    let codes = codes.to_dtype(DType::I64)?;
    codes.save_safetensors("codes", "out.safetensors")?;
    let codes = codes.unsqueeze(0)?;
    let pcm = model
        .audio_encoder
        .decode_codes(&codes.to_device(&device)?)?;
    println!("{pcm}");
    let pcm = pcm.i((0, 0))?;
    let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
    let pcm = pcm.to_vec1::<f32>()?;
    let mut output = std::fs::File::create(&args.out_file)?;
    candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, config.audio_encoder.sampling_rate)?;

    Ok(())
}
@@ -114,6 +114,10 @@ impl TextGeneration {
             tokens.push(next_token);
             generated_tokens += 1;
             if next_token == eos_token {
+                if let Some(t) = self.tokenizer.decode_rest()? {
+                    print!("{t}");
+                    std::io::stdout().flush()?;
+                }
                 break;
             }
             if let Some(t) = self.tokenizer.next_token(next_token)? {
|
|||||||
let dtype = match args.dtype {
|
let dtype = match args.dtype {
|
||||||
Some(dtype) => std::str::FromStr::from_str(&dtype)?,
|
Some(dtype) => std::str::FromStr::from_str(&dtype)?,
|
||||||
None => {
|
None => {
|
||||||
if (args.model == WhichModel::V3 || args.model == WhichModel::V3Medium)
|
if args.model == WhichModel::V3 || args.model == WhichModel::V3Medium {
|
||||||
&& device.is_cuda()
|
device.bf16_default_to_f32()
|
||||||
{
|
|
||||||
DType::BF16
|
|
||||||
} else {
|
} else {
|
||||||
DType::F32
|
DType::F32
|
||||||
}
|
}
|
||||||
|
candle-examples/examples/quantized-qwen2-instruct/README.md (new file, 11 lines)
@@ -0,0 +1,11 @@
# candle-quantized-qwen2-instruct

[Qwen2](https://qwenlm.github.io/blog/qwen2/) is an upgraded version of Qwen1.5, released by Alibaba Cloud.

## Running the example

```bash
cargo run --example quantized-qwen2-instruct --release -- --prompt "Write a function to count prime numbers up to N."
```

The 0.5b, 1.5b, 7b and 72b model sizes are available via the `--which` argument.
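For instance, based on the `--which` flag defined in the example's main.rs below, the 7b variant can be selected with:

```bash
# Pick the 7b GGUF checkpoint instead of the default 0.5b one.
cargo run --example quantized-qwen2-instruct --release -- \
  --which 7b \
  --prompt "Write a function to count prime numbers up to N."
```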
candle-examples/examples/quantized-qwen2-instruct/main.rs (new file, 306 lines)
@@ -0,0 +1,306 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::{Parser, ValueEnum};
use std::io::Write;
use tokenizers::Tokenizer;

use candle::quantized::gguf_file;
use candle::Tensor;
use candle_transformers::generation::{LogitsProcessor, Sampling};

use candle_examples::token_output_stream::TokenOutputStream;
use candle_transformers::models::quantized_qwen2::ModelWeights as Qwen2;

const DEFAULT_PROMPT: &str = "Write a function to count prime numbers up to N. ";

#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
    #[value(name = "0.5b")]
    W2_0_5b,
    #[value(name = "1.5b")]
    W2_1_5b,
    #[value(name = "7b")]
    W2_7b,
    #[value(name = "72b")]
    W2_72b,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// GGUF file to load, typically a .gguf file generated by the quantize command from llama.cpp
    #[arg(long)]
    model: Option<String>,

    /// The initial prompt, use 'interactive' for entering multiple prompts in an interactive way
    /// and 'chat' for an interactive model where history of previous prompts and generated tokens
    /// is preserved.
    #[arg(long)]
    prompt: Option<String>,

    /// The length of the sample to generate (in tokens).
    #[arg(short = 'n', long, default_value_t = 1000)]
    sample_len: usize,

    /// The tokenizer config in json format.
    #[arg(long)]
    tokenizer: Option<String>,

    /// The temperature used to generate samples, use 0 for greedy sampling.
    #[arg(long, default_value_t = 0.8)]
    temperature: f64,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// Only sample among the top K samples.
    #[arg(long)]
    top_k: Option<usize>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// Process prompt elements separately.
    #[arg(long)]
    split_prompt: bool,

    /// Run on CPU rather than GPU even if a GPU is available.
    #[arg(long)]
    cpu: bool,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,

    /// The model size to use.
    #[arg(long, default_value = "0.5b")]
    which: Which,
}

impl Args {
    fn tokenizer(&self) -> anyhow::Result<Tokenizer> {
        let tokenizer_path = match &self.tokenizer {
            Some(config) => std::path::PathBuf::from(config),
            None => {
                let api = hf_hub::api::sync::Api::new()?;
                let repo = match self.which {
                    Which::W2_0_5b => "Qwen/Qwen2-0.5B-Instruct",
                    Which::W2_1_5b => "Qwen/Qwen2-1.5B-Instruct",
                    Which::W2_7b => "Qwen/Qwen2-7B-Instruct",
                    Which::W2_72b => "Qwen/Qwen2-72B-Instruct",
                };
                let api = api.model(repo.to_string());
                api.get("tokenizer.json")?
            }
        };
        Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)
    }

    fn model(&self) -> anyhow::Result<std::path::PathBuf> {
        let model_path = match &self.model {
            Some(config) => std::path::PathBuf::from(config),
            None => {
                let (repo, filename, revision) = match self.which {
                    Which::W2_0_5b => (
                        "Qwen/Qwen2-0.5B-Instruct-GGUF",
                        "qwen2-0_5b-instruct-q4_0.gguf",
                        "main",
                    ),
                    Which::W2_1_5b => (
                        "Qwen/Qwen2-1.5B-Instruct-GGUF",
                        "qwen2-1_5b-instruct-q4_0.gguf",
                        "main",
                    ),
                    Which::W2_7b => (
                        "Qwen/Qwen2-7B-Instruct-GGUF",
                        "qwen2-7b-instruct-q4_0.gguf",
                        "main",
                    ),
                    Which::W2_72b => (
                        "Qwen/Qwen2-72B-Instruct-GGUF",
                        "qwen2-72b-instruct-q4_0.gguf",
                        "main",
                    ),
                };
                let api = hf_hub::api::sync::Api::new()?;
                api.repo(hf_hub::Repo::with_revision(
                    repo.to_string(),
                    hf_hub::RepoType::Model,
                    revision.to_string(),
                ))
                .get(filename)?
            }
        };
        Ok(model_path)
    }
}

fn format_size(size_in_bytes: usize) -> String {
    if size_in_bytes < 1_000 {
        format!("{}B", size_in_bytes)
    } else if size_in_bytes < 1_000_000 {
        format!("{:.2}KB", size_in_bytes as f64 / 1e3)
    } else if size_in_bytes < 1_000_000_000 {
        format!("{:.2}MB", size_in_bytes as f64 / 1e6)
    } else {
        format!("{:.2}GB", size_in_bytes as f64 / 1e9)
    }
}

fn main() -> anyhow::Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature, args.repeat_penalty, args.repeat_last_n
    );

    let model_path = args.model()?;
    let mut file = std::fs::File::open(&model_path)?;
    let start = std::time::Instant::now();
    let device = candle_examples::device(args.cpu)?;

    let mut model = {
        let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
        let mut total_size_in_bytes = 0;
        for (_, tensor) in model.tensor_infos.iter() {
            let elem_count = tensor.shape.elem_count();
            total_size_in_bytes +=
                elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
        }
        println!(
            "loaded {:?} tensors ({}) in {:.2}s",
            model.tensor_infos.len(),
            &format_size(total_size_in_bytes),
            start.elapsed().as_secs_f32(),
        );
        Qwen2::from_gguf(model, &mut file, &device)?
    };
    println!("model built");

    let tokenizer = args.tokenizer()?;
    let mut tos = TokenOutputStream::new(tokenizer);
    let prompt_str = args.prompt.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
    let prompt_str = format!(
        "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
        prompt_str
    );
    print!("formatted instruct prompt: {}", &prompt_str);
    let tokens = tos
        .tokenizer()
        .encode(prompt_str, true)
        .map_err(anyhow::Error::msg)?;
    let tokens = tokens.get_ids();
    let to_sample = args.sample_len.saturating_sub(1);
    let mut all_tokens = vec![];
    let mut logits_processor = {
        let temperature = args.temperature;
        let sampling = if temperature <= 0. {
            Sampling::ArgMax
        } else {
            match (args.top_k, args.top_p) {
                (None, None) => Sampling::All { temperature },
                (Some(k), None) => Sampling::TopK { k, temperature },
                (None, Some(p)) => Sampling::TopP { p, temperature },
                (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
            }
        };
        LogitsProcessor::from_sampling(args.seed, sampling)
    };
    let start_prompt_processing = std::time::Instant::now();
    let mut next_token = if !args.split_prompt {
        let input = Tensor::new(tokens, &device)?.unsqueeze(0)?;
        let logits = model.forward(&input, 0)?;
        let logits = logits.squeeze(0)?;
        logits_processor.sample(&logits)?
    } else {
        let mut next_token = 0;
        for (pos, token) in tokens.iter().enumerate() {
            let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
            let logits = model.forward(&input, pos)?;
            let logits = logits.squeeze(0)?;
            next_token = logits_processor.sample(&logits)?
        }
        next_token
    };
    let prompt_dt = start_prompt_processing.elapsed();
    all_tokens.push(next_token);
    if let Some(t) = tos.next_token(next_token)? {
        print!("{t}");
        std::io::stdout().flush()?;
    }
    let eos_token = *tos.tokenizer().get_vocab(true).get("<|im_end|>").unwrap();
    let start_post_prompt = std::time::Instant::now();
    let mut sampled = 0;
    for index in 0..to_sample {
        let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
        let logits = model.forward(&input, tokens.len() + index)?;
        let logits = logits.squeeze(0)?;
        let logits = if args.repeat_penalty == 1. {
            logits
        } else {
            let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
            candle_transformers::utils::apply_repeat_penalty(
                &logits,
                args.repeat_penalty,
                &all_tokens[start_at..],
            )?
        };
        next_token = logits_processor.sample(&logits)?;
        all_tokens.push(next_token);
        if let Some(t) = tos.next_token(next_token)? {
            print!("{t}");
            std::io::stdout().flush()?;
        }
        sampled += 1;
        if next_token == eos_token {
            break;
        };
    }
    if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
        print!("{rest}");
    }
    std::io::stdout().flush()?;
    let dt = start_post_prompt.elapsed();
    println!(
        "\n\n{:4} prompt tokens processed: {:.2} token/s",
        tokens.len(),
        tokens.len() as f64 / prompt_dt.as_secs_f64(),
    );
    println!(
        "{sampled:4} tokens generated: {:.2} token/s",
        sampled as f64 / dt.as_secs_f64(),
    );
    Ok(())
}
@@ -139,7 +139,7 @@ pub fn main() -> anyhow::Result<()> {
     let (_one, h, w) = mask.dims3()?;
     let mask = mask.expand((3, h, w))?;

-    let mut img = image::io::Reader::open(&args.image)?
+    let mut img = image::ImageReader::open(&args.image)?
         .decode()
         .map_err(candle::Error::wrap)?;
     let mask_pixels = mask.permute((1, 2, 0))?.flatten_all()?.to_vec1::<u8>()?;
candle-examples/examples/silero-vad/README.md (new file, 12 lines)
@@ -0,0 +1,12 @@
# silero-vad: Voice Activity Detection

[Silero VAD (v5)](https://github.com/snakers4/silero-vad) detects voice activity in streaming audio.

This example uses the model available in the Hugging Face [onnx-community/silero-vad](https://huggingface.co/onnx-community/silero-vad) repository.

## Running the example

```bash
$ arecord -t raw -f S16_LE -r 16000 -c 1 -d 5 - | cargo run --example silero-vad --release --features onnx -- --sample-rate 16000
```
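The detector also accepts 8 kHz input (see the `SampleRate` enum in the example's main.rs below); a sketch of the same pipeline at that rate:

```bash
# Record 5 seconds of 8 kHz mono audio and stream it to the VAD.
arecord -t raw -f S16_LE -r 8000 -c 1 -d 5 - | cargo run --example silero-vad --release --features onnx -- --sample-rate 8000
```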
candle-examples/examples/silero-vad/main.rs (new file, 199 lines)
@@ -0,0 +1,199 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Result;
use clap::Parser;

use candle::{DType, Tensor};

#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
    #[value(name = "silero")]
    Silero,
}

#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum SampleRate {
    #[value(name = "8000")]
    Sr8k,
    #[value(name = "16000")]
    Sr16k,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    input: Option<String>,

    #[arg(long)]
    sample_rate: SampleRate,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long)]
    config_file: Option<String>,

    /// The model to use.
    #[arg(long, default_value = "silero")]
    which: Which,
}

/// an iterator which reads consecutive frames of le i16 values from a reader
struct I16Frames<R> {
    rdr: R,
    buf: Box<[u8]>,
    len: usize,
    eof: bool,
}
impl<R> I16Frames<R> {
    fn new(rdr: R, frame_size: usize) -> Self {
        I16Frames {
            rdr,
            buf: vec![0; frame_size * std::mem::size_of::<i16>()].into_boxed_slice(),
            len: 0,
            eof: false,
        }
    }
}
impl<R: std::io::Read> Iterator for I16Frames<R> {
    type Item = std::io::Result<Vec<f32>>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.eof {
            return None;
        }
        self.len += match self.rdr.read(&mut self.buf[self.len..]) {
            Ok(0) => {
                self.eof = true;
                0
            }
            Ok(n) => n,
            Err(e) => return Some(Err(e)),
        };
        if self.eof || self.len == self.buf.len() {
            let buf = self.buf[..self.len]
                .chunks(2)
                .map(|bs| match bs {
                    [a, b] => i16::from_le_bytes([*a, *b]),
                    _ => unreachable!(),
                })
                .map(|i| i as f32 / i16::MAX as f32)
                .collect();
            self.len = 0;
            Some(Ok(buf))
        } else {
            self.next()
        }
    }
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );

    let start = std::time::Instant::now();
    let model_id = match &args.model_id {
        Some(model_id) => std::path::PathBuf::from(model_id),
        None => match args.which {
            Which::Silero => hf_hub::api::sync::Api::new()?
                .model("onnx-community/silero-vad".into())
                .get("onnx/model.onnx")?,
            // TODO: candle-onnx doesn't support Int8 dtype
            // Which::SileroQuantized => hf_hub::api::sync::Api::new()?
            //     .model("onnx-community/silero-vad".into())
            //     .get("onnx/model_quantized.onnx")?,
        },
    };
    let (sample_rate, frame_size, context_size): (i64, usize, usize) = match args.sample_rate {
        SampleRate::Sr8k => (8000, 256, 32),
        SampleRate::Sr16k => (16000, 512, 64),
    };
    println!("retrieved the files in {:?}", start.elapsed());

    let start = std::time::Instant::now();
    let device = candle_examples::device(args.cpu)?;
    let model = candle_onnx::read_file(model_id)?;

    println!("loaded the model in {:?}", start.elapsed());

    let start = std::time::Instant::now();
    struct State {
        frame_size: usize,
        sample_rate: Tensor,
        state: Tensor,
        context: Tensor,
    }

    let mut state = State {
        frame_size,
        sample_rate: Tensor::new(sample_rate, &device)?,
        state: Tensor::zeros((2, 1, 128), DType::F32, &device)?,
        context: Tensor::zeros((1, context_size), DType::F32, &device)?,
    };
    let mut res = vec![];
    for chunk in I16Frames::new(std::io::stdin().lock(), state.frame_size) {
        let chunk = chunk.unwrap();
        if chunk.len() < state.frame_size {
            continue;
        }
        let next_context = Tensor::from_slice(
            &chunk[state.frame_size - context_size..],
            (1, context_size),
            &device,
        )?;
        let chunk = Tensor::from_vec(chunk, (1, state.frame_size), &device)?;
        let chunk = Tensor::cat(&[&state.context, &chunk], 1)?;
        let inputs = std::collections::HashMap::from_iter([
            ("input".to_string(), chunk),
            ("sr".to_string(), state.sample_rate.clone()),
            ("state".to_string(), state.state.clone()),
        ]);
        let out = candle_onnx::simple_eval(&model, inputs).unwrap();
        let out_names = &model.graph.as_ref().unwrap().output;
        let output = out.get(&out_names[0].name).unwrap().clone();
        state.state = out.get(&out_names[1].name).unwrap().clone();
        assert_eq!(state.state.dims(), &[2, 1, 128]);
        state.context = next_context;

        let output = output.flatten_all()?.to_vec1::<f32>()?;
        assert_eq!(output.len(), 1);
        let output = output[0];
        println!("vad chunk prediction: {output}");
        res.push(output);
    }
    println!("calculated prediction in {:?}", start.elapsed());

    let res_len = res.len() as f32;
    let prediction = res.iter().sum::<f32>() / res_len;
    println!("vad average prediction: {prediction}");
    Ok(())
}
@@ -380,7 +380,7 @@ fn text_embeddings(
 }

 fn image_preprocess<T: AsRef<std::path::Path>>(path: T) -> anyhow::Result<Tensor> {
-    let img = image::io::Reader::open(path)?.decode()?;
+    let img = image::ImageReader::open(path)?.decode()?;
     let (height, width) = (img.height() as usize, img.width() as usize);
     let height = height - height % 32;
     let width = width - width % 32;
@@ -145,7 +145,7 @@ impl ViTImageProcessor {
     pub fn load_images(&self, image_path: Vec<&str>) -> Result<Vec<image::DynamicImage>> {
         let mut images: Vec<image::DynamicImage> = Vec::new();
         for path in image_path {
-            let img = image::io::Reader::open(path)?.decode().unwrap();
+            let img = image::ImageReader::open(path)?.decode().unwrap();
             images.push(img);
         }
@@ -123,7 +123,7 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)
     let padding = if pad != 0 { (size - 1) / 2 } else { 0 };
     let (bn, bias) = match b.parameters.get("batch_normalize") {
         Some(p) if p.parse::<usize>()? != 0 => {
-            let bn = batch_norm(filters, 1e-5, vb.pp(&format!("batch_norm_{index}")))?;
+            let bn = batch_norm(filters, 1e-5, vb.pp(format!("batch_norm_{index}")))?;
             (Some(bn), false)
         }
         Some(_) | None => (None, true),
@@ -135,9 +135,9 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)
         dilation: 1,
     };
     let conv = if bias {
-        conv2d(p, filters, size, conv_cfg, vb.pp(&format!("conv_{index}")))?
+        conv2d(p, filters, size, conv_cfg, vb.pp(format!("conv_{index}")))?
     } else {
-        conv2d_no_bias(p, filters, size, conv_cfg, vb.pp(&format!("conv_{index}")))?
+        conv2d_no_bias(p, filters, size, conv_cfg, vb.pp(format!("conv_{index}")))?
     };
     let leaky = match activation {
         "leaky" => true,
@@ -272,7 +272,7 @@ impl Darknet {
     let mut prev_channels: usize = 3;
     for (index, block) in self.blocks.iter().enumerate() {
         let channels_and_bl = match block.block_type.as_str() {
-            "convolutional" => conv(vb.pp(&index.to_string()), index, prev_channels, block)?,
+            "convolutional" => conv(vb.pp(index.to_string()), index, prev_channels, block)?,
             "upsample" => upsample(prev_channels)?,
             "shortcut" => shortcut(index, prev_channels, block)?,
             "route" => route(index, &blocks, block)?,
@@ -159,7 +159,7 @@ pub fn main() -> Result<()> {
     let net_width = darknet.width()?;
     let net_height = darknet.height()?;

-    let original_image = image::io::Reader::open(&image_name)?
+    let original_image = image::ImageReader::open(&image_name)?
         .decode()
         .map_err(candle::Error::wrap)?;
     let image = {
@@ -390,7 +390,7 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
     for image_name in args.images.iter() {
         println!("processing {image_name}");
         let mut image_name = std::path::PathBuf::from(image_name);
-        let original_image = image::io::Reader::open(&image_name)?
+        let original_image = image::ImageReader::open(&image_name)?
            .decode()
            .map_err(candle::Error::wrap)?;
         let (width, height) = {
@@ -161,7 +161,7 @@ impl C2f {
     let cv2 = ConvBlock::load(vb.pp("cv2"), (2 + n) * c, c2, 1, 1, None)?;
     let mut bottleneck = Vec::with_capacity(n);
     for idx in 0..n {
-        let b = Bottleneck::load(vb.pp(&format!("bottleneck.{idx}")), c, c, shortcut)?;
+        let b = Bottleneck::load(vb.pp(format!("bottleneck.{idx}")), c, c, shortcut)?;
         bottleneck.push(b)
     }
     Ok(Self {
@@ -1,20 +1,53 @@
 use candle::{Device, Result, Tensor};

+pub const IMAGENET_MEAN: [f32; 3] = [0.485f32, 0.456, 0.406];
+pub const IMAGENET_STD: [f32; 3] = [0.229f32, 0.224, 0.225];
+
+/// Loads an image from disk using the image crate at the requested resolution,
+/// using the given std and mean parameters.
+/// This returns a tensor with shape (3, res, res). imagenet normalization is applied.
+pub fn load_image_with_std_mean<P: AsRef<std::path::Path>>(
+    p: P,
+    res: usize,
+    mean: &[f32; 3],
+    std: &[f32; 3],
+) -> Result<Tensor> {
+    let img = image::ImageReader::open(p)?
+        .decode()
+        .map_err(candle::Error::wrap)?
+        .resize_to_fill(
+            res as u32,
+            res as u32,
+            image::imageops::FilterType::Triangle,
+        );
+    let img = img.to_rgb8();
+    let data = img.into_raw();
+    let data = Tensor::from_vec(data, (res, res, 3), &Device::Cpu)?.permute((2, 0, 1))?;
+    let mean = Tensor::new(mean, &Device::Cpu)?.reshape((3, 1, 1))?;
+    let std = Tensor::new(std, &Device::Cpu)?.reshape((3, 1, 1))?;
+    (data.to_dtype(candle::DType::F32)? / 255.)?
+        .broadcast_sub(&mean)?
+        .broadcast_div(&std)
+}
+
+/// Loads an image from disk using the image crate at the requested resolution.
+/// This returns a tensor with shape (3, res, res). imagenet normalization is applied.
+pub fn load_image<P: AsRef<std::path::Path>>(p: P, res: usize) -> Result<Tensor> {
+    load_image_with_std_mean(p, res, &IMAGENET_MEAN, &IMAGENET_STD)
+}
+
 /// Loads an image from disk using the image crate, this returns a tensor with shape
 /// (3, 224, 224). imagenet normalization is applied.
 pub fn load_image224<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
-    let img = image::io::Reader::open(p)?
-        .decode()
-        .map_err(candle::Error::wrap)?
-        .resize_to_fill(224, 224, image::imageops::FilterType::Triangle);
-    let img = img.to_rgb8();
-    let data = img.into_raw();
-    let data = Tensor::from_vec(data, (224, 224, 3), &Device::Cpu)?.permute((2, 0, 1))?;
-    let mean = Tensor::new(&[0.485f32, 0.456, 0.406], &Device::Cpu)?.reshape((3, 1, 1))?;
-    let std = Tensor::new(&[0.229f32, 0.224, 0.225], &Device::Cpu)?.reshape((3, 1, 1))?;
-    (data.to_dtype(candle::DType::F32)? / 255.)?
-        .broadcast_sub(&mean)?
-        .broadcast_div(&std)
+    load_image(p, 224)
+}
+
+/// Loads an image from disk using the image crate, this returns a tensor with shape
+/// (3, 518, 518). imagenet normalization is applied.
+/// The model dinov2 reg4 analyzes images with dimensions 3x518x518 (resulting in 37x37 transformer tokens).
+pub fn load_image518<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
+    load_image(p, 518)
 }

 pub const CLASS_COUNT: i64 = 1000;
@@ -34,7 +34,7 @@ pub fn load_image<P: AsRef<std::path::Path>>(
     p: P,
     resize_longest: Option<usize>,
 ) -> Result<(Tensor, usize, usize)> {
-    let img = image::io::Reader::open(p)?
+    let img = image::ImageReader::open(p)?
         .decode()
         .map_err(candle::Error::wrap)?;
     let (initial_h, initial_w) = (img.height() as usize, img.width() as usize);
@@ -65,7 +65,7 @@ pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
     width: usize,
     height: usize,
 ) -> Result<Tensor> {
-    let img = image::io::Reader::open(p)?
+    let img = image::ImageReader::open(p)?
         .decode()
         .map_err(candle::Error::wrap)?
         .resize_to_fill(
@@ -1,6 +1,6 @@
 [package]
 name = "candle-flash-attn"
-version = "0.6.0"
+version = "0.7.1"
 edition = "2021"

 description = "Flash attention layer for the candle ML framework."
@@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
 readme = "README.md"

 [dependencies]
-candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.6.0" }
+candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.7.1" }
 half = { version = "2.3.1", features = ["num-traits"] }

 [build-dependencies]
@@ -4,7 +4,7 @@
 use anyhow::{Context, Result};
 use std::path::PathBuf;

-const KERNEL_FILES: [&str; 17] = [
+const KERNEL_FILES: [&str; 33] = [
     "kernels/flash_api.cu",
     "kernels/flash_fwd_hdim128_fp16_sm80.cu",
     "kernels/flash_fwd_hdim160_fp16_sm80.cu",
@@ -22,6 +22,22 @@ const KERNEL_FILES: [&str; 17] = [
     "kernels/flash_fwd_hdim32_bf16_sm80.cu",
     "kernels/flash_fwd_hdim64_bf16_sm80.cu",
     "kernels/flash_fwd_hdim96_bf16_sm80.cu",
+    "kernels/flash_fwd_hdim128_fp16_causal_sm80.cu",
+    "kernels/flash_fwd_hdim160_fp16_causal_sm80.cu",
+    "kernels/flash_fwd_hdim192_fp16_causal_sm80.cu",
+    "kernels/flash_fwd_hdim224_fp16_causal_sm80.cu",
+    "kernels/flash_fwd_hdim256_fp16_causal_sm80.cu",
+    "kernels/flash_fwd_hdim32_fp16_causal_sm80.cu",
+    "kernels/flash_fwd_hdim64_fp16_causal_sm80.cu",
+    "kernels/flash_fwd_hdim96_fp16_causal_sm80.cu",
+    "kernels/flash_fwd_hdim128_bf16_causal_sm80.cu",
+    "kernels/flash_fwd_hdim160_bf16_causal_sm80.cu",
+    "kernels/flash_fwd_hdim192_bf16_causal_sm80.cu",
+    "kernels/flash_fwd_hdim224_bf16_causal_sm80.cu",
+    "kernels/flash_fwd_hdim256_bf16_causal_sm80.cu",
+    "kernels/flash_fwd_hdim32_bf16_causal_sm80.cu",
+    "kernels/flash_fwd_hdim64_bf16_causal_sm80.cu",
+    "kernels/flash_fwd_hdim96_bf16_causal_sm80.cu",
 ];

 fn main() -> Result<()> {
Submodule candle-flash-attn/cutlass updated: c4f6b8c6bc...7d49e6c7e2
@@ -13,50 +13,62 @@ using namespace cute;

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template <bool Is_causal, typename Engine, typename Layout>
-inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
-                                   const int col_idx_offset_,
-                                   const int max_seqlen_k,
-                                   const int row_idx_offset,
-                                   const int max_seqlen_q,
-                                   const int warp_row_stride,
-                                   const float alibi_slope) {
-    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
-    static_assert(Layout::rank == 2, "Only support 2D Tensor");
-    const int lane_id = threadIdx.x % 32;
-    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
-    if constexpr (Is_causal) {  // Simpler, we add the same bias vector to all rows
-        #pragma unroll
-        for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
-            const int col_idx_base = col_idx_offset + nj * 8;
+template <bool Is_causal>
+struct Alibi {
+
+    const float alibi_slope;
+    const int max_seqlen_k, max_seqlen_q;
+
+    __forceinline__ __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q)
+        : alibi_slope(alibi_slope)
+        , max_seqlen_k(max_seqlen_k)
+        , max_seqlen_q(max_seqlen_q) {
+    };
+
+    template <typename Engine, typename Layout>
+    __forceinline__ __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
+                                                const int col_idx_offset_,
+                                                const int row_idx_offset,
+                                                const int warp_row_stride) {
+        // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
+        static_assert(Layout::rank == 2, "Only support 2D Tensor");
+        const int lane_id = threadIdx.x % 32;
+        const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+        if constexpr (Is_causal) {  // Simpler, we add the same bias vector to all rows
             #pragma unroll
-            for (int j = 0; j < size<1, 0>(tensor); ++j) {
-                const int col_idx = col_idx_base + j;
+            for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+                const int col_idx_base = col_idx_offset + nj * 8;
                 #pragma unroll
-                for (int mi = 0; mi < size<0>(tensor); ++mi) {
-                    tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
+                for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                    const int col_idx = col_idx_base + j;
+                    #pragma unroll
+                    for (int mi = 0; mi < size<0>(tensor); ++mi) {
+                        tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
+                    }
                 }
             }
-        }
-    } else {  // Bias depends on both row_idx and col_idx
-        #pragma unroll
-        for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
-            const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+        } else {  // Bias depends on both row_idx and col_idx
             #pragma unroll
-            for (int i = 0; i < size<0, 0>(tensor); ++i) {
-                const int row_idx = row_idx_base + i * 8;
+            for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+                const int row_idx_base = row_idx_offset + mi * warp_row_stride;
                 #pragma unroll
-                for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
-                    const int col_idx_base = col_idx_offset + nj * 8;
+                for (int i = 0; i < size<0, 0>(tensor); ++i) {
+                    const int row_idx = row_idx_base + i * 8;
                     #pragma unroll
-                    for (int j = 0; j < size<1, 0>(tensor); ++j) {
-                        const int col_idx = col_idx_base + j;
-                        tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
+                    for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+                        const int col_idx_base = col_idx_offset + nj * 8;
+                        #pragma unroll
+                        for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                            const int col_idx = col_idx_base + j;
+                            tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
+                        }
                     }
                 }
             }
         }
     }
-}
+
+};

 }  // namespace flash
@@ -24,12 +24,12 @@ struct BlockInfo {
     }

     template <typename index_t>
-    inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
+    __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
         return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
     }

     template <typename index_t>
-    inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
+    __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
         return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
     }
94
candle-flash-attn/kernels/dropout.h
Normal file
94
candle-flash-attn/kernels/dropout.h
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
/******************************************************************************
|
||||||
|
* Copyright (c) 2024, Tri Dao.
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "philox.cuh"
|
||||||
|
#include "utils.h"
|
||||||
|
|
||||||
|
namespace flash {
|
||||||
|
|
||||||
|
struct Dropout {
|
||||||
|
|
||||||
|
const unsigned long long seed, offset;
|
||||||
|
const uint8_t p_dropout_in_uint8_t;
|
||||||
|
|
||||||
|
__forceinline__ __device__ Dropout(const unsigned long long seed, const unsigned long long offset,
|
||||||
|
const uint8_t p_dropout_in_uint8_t,
|
||||||
|
const int bid, const int hid, const int tid, const int nheads)
|
||||||
|
: seed(seed)
|
||||||
|
, offset(offset + (bid * nheads + hid) * 32 + tid % 32)
|
||||||
|
, p_dropout_in_uint8_t(p_dropout_in_uint8_t) {
|
||||||
|
}
|
||||||
|
|
||||||
|
template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
|
||||||
|
__forceinline__ __device__ void apply_dropout(Tensor<Engine, Layout> &tensor_,
|
||||||
|
int block_row_start, int block_col_start, int block_row_stride) {
|
||||||
|
// convert shape from (4, MMA_M, MMA_N) to (8, MMA_M, MMA_N / 2)
|
||||||
|
Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_dropout(tensor_.layout()));
|
||||||
|
using T = typename Engine::value_type;
|
||||||
|
auto encode_dropout = [](bool keep, T val) {
|
||||||
|
return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0));
|
||||||
|
};
|
||||||
|
static_assert(decltype(size<2>(tensor))::value % 2 == 0);
|
||||||
|
    const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t);
    const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t);
    // if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); }
    #pragma unroll
    for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) {
        uint2 rowcol = make_uint2(block_row_start, block_col_start);
        #pragma unroll
        for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) {
            // if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));}
            uint4 random_uint4 = flash::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
            // if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);}
            uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);
            // Special implementation for 16-bit types: we duplicate the threshold to the
            // low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction
            // to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000,
            // and the high 16 bits will be either 0xffff or 0x0000, depending on whether
            // the random value is less than the threshold.
            // We then do a bit-wise AND between the mask and the original value (in 32-bit).
            // We're exploiting the fact that floating point comparison is equivalent to integer
            // comparison, since we're comparing unsigned integers whose top 8-bits are zero.
            if (!encode_dropout_in_sign_bit
                && (std::is_same<T, cutlass::half_t>::value || std::is_same<T, cutlass::bfloat16_t>::value)) {
                uint16_t rnd_16[16];
                #pragma unroll
                for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); }
                uint32_t (&rnd_32)[8] = reinterpret_cast<uint32_t (&)[8]>(rnd_16);
                #pragma unroll
                for (int j = 0; j < 2; j++) {
                    Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
                    // if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); }
                    // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
                    #pragma unroll
                    for (int i = 0; i < 4; i++) {
                        uint32_t mask;
                        asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t));
                        tensor_uint32(i) &= mask;
                    }
                    // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
                }
            } else {
                #pragma unroll
                for (int j = 0; j < 2; j++) {
                    #pragma unroll
                    for (int i = 0; i < 8; i++) {
                        tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j));
                    }
                    Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
                    // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
                }
            }
            // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
            // //     printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w);
            // // }
        }
    }
}

};

} // namespace flash
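The comment block in the fast path above is the crux of this kernel: two 16-bit lanes are dropout-masked with a single PTX compare. A standalone sketch of the trick (our illustration, not part of the diff; `keep_mask_f16x2` is a hypothetical name, and `set.le.u32.f16x2` needs sm_53 or later):

```cuda
#include <cstdint>

// Sketch of the f16x2 dropout-mask trick used above: compare two 16-bit lanes
// against a duplicated 8-bit keep-threshold with one instruction.
__device__ uint32_t keep_mask_f16x2(uint32_t rnd_pair, uint8_t p_keep_u8) {
    // Duplicate the zero-extended threshold into both 16-bit halves.
    const uint32_t thr2 = (uint32_t(p_keep_u8) << 16) | uint32_t(p_keep_u8);
    uint32_t mask;
    // Each lane becomes 0xffff when its random byte <= threshold (keep) and
    // 0x0000 otherwise (drop). The fp16 compare agrees with an unsigned
    // integer compare because the top 8 bits of every lane are zero, so all
    // bit patterns are small non-negative fp16 values.
    asm volatile("set.le.u32.f16x2 %0, %1, %2;\n"
                 : "=r"(mask)
                 : "r"(rnd_pair), "r"(thr2));
    return mask;  // AND with two packed fp16/bf16 activations to drop lanes.
}
```

ANDing the mask into the recast `uint32_t` view of the accumulator zeroes dropped elements with no per-element branch; and because the randomness is Philox keyed on the (row, column) coordinates, the backward pass can regenerate the identical mask instead of storing it.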
8
candle-flash-attn/kernels/error.h
Normal file
@ -0,0 +1,8 @@
+#pragma once
+
+#define C10_CUDA_CHECK(EXPR)                \
+  do {                                      \
+    const cudaError_t __err = EXPR;         \
+  } while (0)
+
+#define C10_CUDA_KERNEL_LAUNCH_CHECK() C10_CUDA_CHECK(cudaGetLastError())
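Note that this new `C10_CUDA_CHECK` deliberately evaluates `EXPR` and discards the resulting `cudaError_t`: it exists so call sites inherited from the upstream PyTorch-flavoured flash-attention sources still compile, while candle reports errors through its own paths. For contrast, a strict variant could look like this sketch (our illustration, not what the commit adds):

```cuda
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

// Hypothetical strict counterpart of C10_CUDA_CHECK: print and abort on failure.
#define CUDA_CHECK_STRICT(EXPR)                                    \
  do {                                                             \
    const cudaError_t err_ = (EXPR);                               \
    if (err_ != cudaSuccess) {                                     \
      std::fprintf(stderr, "CUDA error '%s' at %s:%d\n",           \
                   cudaGetErrorString(err_), __FILE__, __LINE__);  \
      std::abort();                                                \
    }                                                              \
  } while (0)
```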
@ -7,6 +7,14 @@
 #include <cuda.h>
 #include <vector>

+// #ifdef OLD_GENERATOR_PATH
+// #include <ATen/CUDAGeneratorImpl.h>
+// #else
+// #include <ATen/cuda/CUDAGeneratorImpl.h>
+// #endif
+//
+// #include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
+
 constexpr int TOTAL_DIM = 0;
 constexpr int H_DIM = 1;
 constexpr int D_DIM = 2;
@ -14,7 +22,7 @@ constexpr int D_DIM = 2;
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 struct Qkv_params {
-    using index_t = uint32_t;
+    using index_t = int64_t;
     // The QKV matrices.
     void *__restrict__ q_ptr;
     void *__restrict__ k_ptr;
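Widening `index_t` from `uint32_t` to `int64_t` is not cosmetic: batch and head strides are measured in elements, and their products overflow 32 bits at sizes long-context models actually reach. A back-of-the-envelope check (standalone; the sizes are made up but realistic):

```cuda
#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical shape: batch 8, seqlen 32768, 32 heads, head dim 128.
    const int64_t b = 8, seqlen = 32768, nheads = 32, d = 128;
    const int64_t elems = b * seqlen * nheads * d;  // 2^30 elements
    const int64_t bytes = elems * 2;                // fp16: 2^31 bytes
    // The next doubling no longer fits in a 32-bit index and silently wraps:
    const uint32_t wrapped = uint32_t(bytes * 2);   // 2^32 truncates to 0
    std::printf("elems=%lld bytes=%lld wrapped=%u\n",
                (long long)elems, (long long)bytes, wrapped);
    return 0;
}
```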
@ -59,7 +67,7 @@ struct Flash_fwd_params : public Qkv_params {
     void * __restrict__ softmax_lseaccum_ptr;

     // The dimensions.
-    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
+    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q;

     // The scaling factors for the kernel.
     float scale_softmax;
@ -91,7 +99,12 @@ struct Flash_fwd_params : public Qkv_params {
     void * __restrict__ rotary_sin_ptr;

     // The indices to index into the KV cache.
-    int *__restrict__ cache_batch_idx;
+    int * __restrict__ cache_batch_idx;
+
+    // Paged KV cache
+    int * __restrict__ block_table;
+    index_t block_table_batch_stride;
+    int page_block_size;

     // The dropout probability (probability of keeping an activation).
     float p_dropout;
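These three fields add vLLM-style paged KV-cache support: per batch entry, `block_table` stores the physical page id of each logical block of `page_block_size` tokens, and `block_table_batch_stride` is the table's row stride. A sketch of the resulting address translation (hypothetical helper and stride names, not the kernel's code):

```cuda
#include <cstdint>

// Hypothetical lookup: element offset of logical KV token `t` of batch `b`
// in a pool of fixed-size pages, via a per-batch block table.
__host__ __device__ int64_t paged_kv_offset(
    const int* block_table, int64_t block_table_batch_stride,
    int page_block_size, int b, int t,
    int64_t page_stride,    // elements between consecutive pages in the pool
    int64_t row_stride) {   // elements between consecutive tokens in a page
    const int page = block_table[b * block_table_batch_stride + t / page_block_size];
    const int in_page = t % page_block_size;
    return int64_t(page) * page_stride + int64_t(in_page) * row_stride;
}
```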
@ -105,6 +118,13 @@ struct Flash_fwd_params : public Qkv_params {

     // Local window size
     int window_size_left, window_size_right;
+
+    float softcap;
+
+    // Random state.
+    // at::PhiloxCudaState philox_args;
+
+    // Pointer to the RNG seed (idx 0) and offset (idx 1).
+    uint64_t * rng_state;

     bool is_bf16;
     bool is_causal;
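`softcap` carries the attention-logit soft-capping used by models such as Gemma 2: scores are squashed through tanh before the softmax so their magnitude can never exceed the cap. A reference formulation of the usual semantics (our sketch; the kernel applies it inline over the score tile rather than through a scalar helper):

```cuda
#include <cmath>

// Reference semantics of logit soft-capping: smooth clamp into (-softcap, softcap).
inline float apply_softcap(float score, float softcap) {
    return softcap > 0.f ? softcap * std::tanh(score / softcap) : score;
}
```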
@ -119,6 +139,9 @@ struct Flash_fwd_params : public Qkv_params {

     void * __restrict__ alibi_slopes_ptr;
     index_t alibi_slopes_batch_stride;
+
+    bool unpadded_lse;            // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q].
+    bool seqlenq_ngroups_swapped; // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d).
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////
@ -165,7 +188,7 @@ struct Flash_bwd_params : public Flash_fwd_params {

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
-template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
+template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
+template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

-template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream, const bool configure);
+template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
@ -1,15 +1,15 @@
 #include "kernels.h"
 #include "kernel_helpers.h"
 #include "flash_fwd_launch_template.h"

-void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
+void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     FP16_SWITCH(!params.is_bf16, [&] {
-        FWD_HEADDIM_SWITCH(params.d, [&] {
-            // if (params.num_splits <= 1 && !force_split_kernel) {  // If we don't set it num_splits == 0
-                run_mha_fwd_<elem_type, kHeadDim>(params, stream);
-            // } else {
-            //     run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
-            // }
+        HEADDIM_SWITCH(params.d, [&] {
+            BOOL_SWITCH(params.is_causal, Is_causal, [&] {
+                run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
+            });
         });
     });
 }

 extern "C" void run_mha(
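The rewritten dispatcher lifts `params.is_causal` into a template parameter through `BOOL_SWITCH`, so every (dtype, head dim, causal) combination compiles to its own fully specialized kernel instead of branching at runtime. The switch macros reduce to a pattern like this sketch (simplified from the static_switch.h style used by these kernels):

```cuda
// Simplified shape of the BOOL_SWITCH pattern: branch once at runtime, then
// hand a constexpr bool to a generic lambda so the body is fully specialized.
#define BOOL_SWITCH_SKETCH(COND, CONST_NAME, ...)   \
  [&] {                                             \
    if (COND) {                                     \
      constexpr static bool CONST_NAME = true;      \
      return __VA_ARGS__();                         \
    } else {                                        \
      constexpr static bool CONST_NAME = false;     \
      return __VA_ARGS__();                         \
    }                                               \
  }()
```

At a call site such as `BOOL_SWITCH(params.is_causal, Is_causal, [&] { run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream); });`, both specializations are instantiated at compile time and the runtime flag merely selects one.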
Some files were not shown because too many files have changed in this diff.