Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 19:18:50 +00:00)

Compare commits: phi2-gguf...metal-gemm (142 commits)
Commit SHA1s:

7ec4f64d38, 8712ceb84f, f9b2bb4d46, aefca7f8e6, c3b0757995, 560f666d29, 16b49dfd89,
fc0deede31, e178aacead, bb191c25d5, 2be9bd211e, 89eae41efd, c0a559d427, aa7ac1832d,
19db6b9723, 0fcb40b229, 6991a37b94, 9ca277a9d7, 2e9c010609, ac51f477eb, d4b6f6eef6,
957d604a78, ce90287f45, 1ba87a9450, bd80078acf, fea46cb719, 8696cf6494, 4a52aeb437,
24d54d0ff9, 636eff652a, 0f5cbb08b3, ddafc61055, a925ae6bc6, 6056fd5c90, ebc9aa60bc,
2489a606fe, 3c815b1dca, 42891cc613, f25173d68b, 6a4741bbf9, 30cdd769f9, d74fbed334,
c63048d374, a226a9736b, 25960676ca, 9cd54aa5d4, eec11ce2ce, 9182f9f5c2, ecff05d72b,
7f1ba8038c, 74e9e41911, e27aac0a06, a3dd87f15e, 242e006bbb, 6baa1d486b, 36cf54525d,
2b10aaa05d, 9f804af29d, 54ff971e35, b9fac7ec00, f65e90e7ef, d39462856b, cb180eb23a,
9182c828e6, 3f13ad3d79, cd4d941ed1, 03344d3c19, 1ec3b2cc18, f7773d498a, 7abc3b8cd7,
46012ed31f, f3fade3b03, ea260aeffd, 0814dfd148, 3ceca9901a, 1df2bddccf, 6f0b807ffd,
d54e02d73d, 45e235a747, 31cf64147b, 77ea479a18, 72e7ca529a, 7ff921c538, 9b8537a62f,
7ebc3548e1, eefc1c77ef, 01545f7303, 349c3e806a, bdaa34216a, cc80e065e5, 13c64f6828,
21f82a5155, 9cff7bc3f4, d9bc5ec151, 84328e2b60, 82b641fd27, 01794dc16e, a75cd8164f,
b13a82a438, 59b18d974e, 89f53b9d7b, a09d451d11, fa06f5f5f9, 09d4845aa8, a0d03aded1,
3bbb88fcb4, ed7b99f525, 287013ef28, eb26e2467e, c68ed8963f, e5c8b88f90, 805f3be8e1,
3b429f3023, 96a48e5cc4, 6cf82fd7a3, cfab6e7616, 11d4a3c588, 9d3f1c8af5, 7211009179,
6fadaf2eff, 8a05743a21, b2e816752b, 618ecf5e23, 267601eec1, 08a15cb79e, c388be93e7,
d22f1d4f4e, 0067fe00a8, 587ee3bb6f, dd78422701, 9215e9ce8c, 52ae332910, 8b390ddd29,
c97d639fa0, b45c710dbf, 9c532aef47, f7a6468238, 2b93dffe64, e6ee7ba4d4, 1690ab45d2,
8de0ce6cba, ce6d08df94
.github/workflows/trufflehog.yml (vendored, new file, 15 lines)
@ -0,0 +1,15 @@
+on:
+  push:
+
+name: Secret Leaks
+
+jobs:
+  trufflehog:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+      - name: Secret Scanning
+        uses: trufflesecurity/trufflehog@main
.gitignore (vendored, 4 changed lines)
@ -9,6 +9,10 @@ target/
 # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
 Cargo.lock
 
+# editor config
+.helix
+.vscode
+
 # These are backup files generated by rustfmt
 **/*.rs.bk
 
Cargo.toml (28 changed lines)
@ -20,7 +20,7 @@ exclude = [
 resolver = "2"
 
 [workspace.package]
-version = "0.5.0"
+version = "0.6.0"
 edition = "2021"
 description = "Minimalist ML framework."
 repository = "https://github.com/huggingface/candle"

@ -33,22 +33,23 @@ ab_glyph = "0.2.23"
 accelerate-src = { version = "0.3.2" }
 anyhow = { version = "1", features = ["backtrace"] }
 byteorder = "1.4.3"
-candle = { path = "./candle-core", package = "candle-core", version = "0.5.0" }
-candle-datasets = { path = "./candle-datasets", version = "0.5.0" }
-candle-flash-attn = { path = "./candle-flash-attn", version = "0.5.0" }
-candle-kernels = { path = "./candle-kernels", version = "0.5.0" }
-candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.5.0" }
-candle-nn = { path = "./candle-nn", version = "0.5.0" }
-candle-onnx = { path = "./candle-onnx", version = "0.5.0" }
-candle-transformers = { path = "./candle-transformers", version = "0.5.0" }
+candle = { path = "./candle-core", package = "candle-core", version = "0.6.0" }
+candle-datasets = { path = "./candle-datasets", version = "0.6.0" }
+candle-flash-attn = { path = "./candle-flash-attn", version = "0.6.0" }
+candle-kernels = { path = "./candle-kernels", version = "0.6.0" }
+candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.6.0" }
+candle-nn = { path = "./candle-nn", version = "0.6.0" }
+candle-onnx = { path = "./candle-onnx", version = "0.6.0" }
+candle-transformers = { path = "./candle-transformers", version = "0.6.0" }
 clap = { version = "4.2.4", features = ["derive"] }
 criterion = { version = "0.5.1", default-features=false }
-cudarc = { version = "0.10.0", features = ["f16"] }
+cudarc = { version = "=0.11.6", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
 fancy-regex = "0.13.0"
 gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
 hf-hub = "0.3.0"
 half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
-image = { version = "0.25.0", default-features = false, features = ["jpeg", "png"] }
+hound = "3.5.1"
+image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }
 imageproc = { version = "0.24.0", default-features = false }
 intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
 libc = { version = "0.2.147" }

@ -65,13 +66,12 @@ serde = { version = "1.0.171", features = ["derive"] }
 serde_plain = "1.0.2"
 serde_json = "1.0.99"
 thiserror = "1"
-tokenizers = { version = "0.15.0", default-features = false }
+tokenizers = { version = "0.19.1", default-features = false }
 tracing = "0.1.37"
 tracing-chrome = "0.7.1"
 tracing-subscriber = "0.3.7"
-wav = "1.0.0"
 yoke = { version = "0.7.2", features = ["derive"] }
-zip = { version = "0.6.6", default-features = false }
+zip = { version = "1.1.1", default-features = false }
 metal = { version = "0.27.0", features = ["mps"]}
 
 [profile.release-with-debug]
README.md (24 changed lines)
@ -60,13 +60,14 @@ These online demos run entirely in your browser:
 
 We also provide a some command line based examples using state of the art models:
 
-- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM, includes
+- [LLaMA v1, v2, and v3](./candle-examples/examples/llama/): general LLM, includes
   the SOLAR-10.7B variant.
 - [Falcon](./candle-examples/examples/falcon/): general LLM.
 - [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google Deepmind.
 - [RecurrentGemma](./candle-examples/examples/recurrent-gemma/): 2b and 7b
   Griffin based models from Google that mix attention with a RNN like state.
-- [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
+- [Phi-1, Phi-1.5, Phi-2, and Phi-3](./candle-examples/examples/phi/): 1.3b,
+  2.7b, and 3.8b general LLMs with performance on par with 7b models.
 - [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
   pre-trained on 1T tokens of English and code datasets. Also supports
   StableLM-2, a 1.6b LLM trained on 2T tokens, as well as the code variants.

@ -111,7 +112,7 @@ We also provide a some command line based examples using state of the art models
 
   <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">
 
-- [SegFormer](./candle-examples/examples/segformer/): transformer based semantic segmantation model.
+- [SegFormer](./candle-examples/examples/segformer/): transformer based semantic segmentation model.
 - [Whisper](./candle-examples/examples/whisper/): speech recognition model.
 - [EnCodec](./candle-examples/examples/encodec/): high-quality audio compression
   model using residual vector quantization.

@ -200,10 +201,10 @@ If you have an addition to this list, please submit a pull request.
 - WASM support, run your models in a browser.
 - Included models.
   - Language Models.
-    - LLaMA v1 and v2 with variants such as SOLAR-10.7B.
+    - LLaMA v1, v2, and v3 with variants such as SOLAR-10.7B.
     - Falcon.
     - StarCoder, StarCoder2.
-    - Phi 1, 1.5, and 2.
+    - Phi 1, 1.5, 2, and 3.
     - Mamba, Minimal Mamba
     - Gemma 2b and 7b.
     - Mistral 7b v0.1.

@ -235,7 +236,7 @@ If you have an addition to this list, please submit a pull request.
     - MetaVoice-1B, text-to-speech model.
   - Computer Vision Models.
     - DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT,
-      ConvNeXTv2, MobileOne, EfficientVit (MSRA).
+      ConvNeXTv2, MobileOne, EfficientVit (MSRA), MobileNetv4, Hiera.
     - yolo-v3, yolo-v8.
     - Segment-Anything Model (SAM).
     - SegFormer.

@ -375,9 +376,9 @@ git submodule update --init
 /usr/include/c++/11/bits/std_function.h:530:146: error: parameter packs not expanded with ‘...’:
 ```
 
-This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the CANDLE_NVCC_CCBIN environment variable.
+This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the NVCC_CCBIN environment variable.
 ```
-env CANDLE_NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
+env NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
 ```
 
 #### Linking error on windows when running rustdoc or mdbook tests

@ -407,3 +408,10 @@ This may be caused by the models being loaded from `/mnt/c`, more details on
 
 You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle
 error is generated.
+
+#### CudaRC error
+
+If you encounter an error like this one `called `Result::unwrap()` on an `Err` value: LoadLibraryExW { source: Os { code: 126, kind: Uncategorized, message: "The specified module could not be found." } }` on windows. To fix copy and rename these 3 files (make sure they are in path). The paths depend on your cuda version.
+`c:\Windows\System32\nvcuda.dll` -> `cuda.dll`
+`c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\cublas64_12.dll` -> `cublas.dll`
+`c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\curand64_10.dll` -> `curand.dll`
@ -37,7 +37,6 @@ tokenizers = { workspace = true, features = ["onig"] }
 tracing = { workspace = true }
 tracing-chrome = { workspace = true }
 tracing-subscriber = { workspace = true }
-wav = { workspace = true }
 # Necessary to disambiguate with tokio in wasm examples which are 1.28.1
 parquet = { workspace = true }
 image = { workspace = true }
@ -81,7 +81,7 @@ let mut tp_shape = view.shape().to_vec();
 let size = tp_shape[0];
 
 if size % world_size != 0 {
-    panic!("The dimension is not divisble by `world_size`");
+    panic!("The dimension is not divisible by `world_size`");
 }
 let block_size = size / world_size;
 let start = rank * block_size;

@ -106,8 +106,8 @@ let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).un
 }
 }
 
+#[allow(unused)]
 #[rustfmt::skip]
 #[test]
 fn book_training_1() -> Result<()>{
 // ANCHOR: book_training_1
 use hf_hub::{api::sync::Api, Repo, RepoType};
@ -48,3 +48,7 @@ metal = ["dep:metal", "dep:candle-metal-kernels"]
 [[bench]]
 name = "bench_main"
 harness = false
+
+[[example]]
+name = "metal_basics"
+required-features = ["metal"]
@ -8,4 +8,5 @@ criterion_main!(
     benchmarks::where_cond::benches,
     benchmarks::conv_transpose2d::benches,
     benchmarks::qmatmul::benches,
+    benchmarks::unary::benches
 );
@ -12,7 +12,7 @@ fn run_affine_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name:
     let m = 1024;
     let k = 1024;
 
-    let tensor = Tensor::zeros((b, m, k), dtype, &device).unwrap();
+    let tensor = Tensor::zeros((b, m, k), dtype, device).unwrap();
 
     let flops = b * m * k * dtype.size_in_bytes();
 
@ -3,6 +3,7 @@ pub(crate) mod conv_transpose2d;
 pub(crate) mod matmul;
 pub(crate) mod qmatmul;
 pub(crate) mod random;
+pub(crate) mod unary;
 pub(crate) mod where_cond;
 
 use candle_core::{Device, Result};
@ -7,7 +7,7 @@ use criterion::{black_box, criterion_group, Criterion, Throughput};
 use std::time::Instant;
 
 fn run(matmul: &QMatMul, x: &Tensor) {
-    matmul.forward(&x).unwrap();
+    matmul.forward(x).unwrap();
 }
 
 fn run_bench(c: &mut Criterion, device: &Device, dtype: GgmlDType) {

@ -50,7 +50,7 @@ fn run_bench(c: &mut Criterion, device: &Device, dtype: GgmlDType) {
 fn criterion_benchmark(c: &mut Criterion) {
     let handler = BenchDeviceHandler::new().unwrap();
     for device in handler.devices {
-        for dtype in vec![
+        for dtype in [
             GgmlDType::F32,
             GgmlDType::F16,
             GgmlDType::Q4_0,
candle-core/benches/benchmarks/unary.rs (new file, 49 lines)
@ -0,0 +1,49 @@
+use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
+use candle_core::{DType, Device, Tensor};
+use criterion::{black_box, criterion_group, Criterion, Throughput};
+use std::time::Instant;
+
+fn run(a: &Tensor) {
+    a.sqrt().unwrap();
+}
+
+fn run_unary_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
+    let b = 1;
+    let m = 1024;
+    let k = 1024;
+
+    let tensor = Tensor::arange(0.0f32, (b * m * k) as f32, device)
+        .unwrap()
+        .to_dtype(dtype)
+        .unwrap()
+        .reshape((b, m, k))
+        .unwrap();
+
+    let flops = b * m * k * dtype.size_in_bytes();
+
+    let mut group = c.benchmark_group(device.bench_name(name));
+    group.throughput(Throughput::Bytes(flops as u64));
+    group.bench_function("iter", move |b| {
+        b.iter_custom(|iters| {
+            let start = Instant::now();
+            for _i in 0..iters {
+                run(black_box(&tensor));
+            }
+            device.sync().unwrap();
+            start.elapsed()
+        })
+    });
+    group.finish();
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let handler = BenchDeviceHandler::new().unwrap();
+    for device in handler.devices {
+        for dtype in [DType::F32, DType::BF16, DType::F16] {
+            let name = format!("sqrt_{:?}", dtype);
+            run_unary_benchmark(c, &device, dtype, &name);
+        }
+    }
+}
+
+criterion_group!(benches, criterion_benchmark);
@ -25,9 +25,9 @@ const SIZE: usize = B * M * K;
 const DATA: [u8; SIZE] = create_cond_arr::<SIZE>();
 
 fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
-    let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), &device).unwrap();
-    let on_true = Tensor::ones((B, M, K), dtype, &device).unwrap();
-    let on_false = Tensor::zeros((B, M, K), dtype, &device).unwrap();
+    let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), device).unwrap();
+    let on_true = Tensor::ones((B, M, K), dtype, device).unwrap();
+    let on_false = Tensor::zeros((B, M, K), dtype, device).unwrap();
 
     let elements = B * M * K;
     // E.g. 2 f32 tensors + 1 u8 tensor
@ -5,32 +5,29 @@ extern crate accelerate_src;
 extern crate intel_mkl_src;
 
 use anyhow::Result;
-use candle_core::{Device, Module, Tensor};
-
-use candle_core::quantized::{QMatMul, QTensor};
+use candle_core::{Device, Tensor};
 
 fn main() -> Result<()> {
     let device = Device::new_cuda(0)?;
-    let q = Tensor::randn(0f32, 1.0, (72, 256), &device)?;
-    let q_cpu = q.to_device(&Device::Cpu)?;
-    let q = QTensor::quantize(&q, candle_core::quantized::GgmlDType::Q8K)?;
-    let q = QMatMul::from_qtensor(q)?;
-    let x = Tensor::randn(0f32, 1.0, (5, 256), &device)?;
-    let res_q_cuda = q.forward(&x)?;
-    println!("{res_q_cuda}");
-
-    let q_cpu = QTensor::quantize(&q_cpu, candle_core::quantized::GgmlDType::Q8K)?;
-    let q_cpu_tensor = q_cpu.dequantize(&Device::Cpu)?;
-    let q_cpu = QMatMul::from_qtensor(q_cpu)?;
-    let x_cpu = x.to_device(&Device::Cpu)?;
-    let res_q_cpu = q_cpu.forward(&x_cpu)?;
-    println!("{res_q_cpu}");
-
-    let res_mm = x_cpu.matmul(&q_cpu_tensor.t()?)?;
-    let diff = (res_mm - res_q_cuda.to_device(&Device::Cpu))?
-        .abs()?
-        .flatten_all()?
-        .max(0)?;
-    println!("{diff}");
+    let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?
+        .to_dtype(candle_core::DType::BF16)?;
+    candle_core::cuda::set_gemm_reduced_precision_f32(false);
+    candle_core::cuda::set_gemm_reduced_precision_bf16(false);
+    let _x1 = x.matmul(&x)?;
+    drop(_x1);
+    let start_time = std::time::Instant::now();
+    let _x1 = x.matmul(&x)?;
+    device.synchronize()?;
+    println!("fp32: {:?}", start_time.elapsed());
+    drop(_x1);
+    candle_core::cuda::set_gemm_reduced_precision_f32(true);
+    candle_core::cuda::set_gemm_reduced_precision_bf16(true);
+    let _x1 = x.matmul(&x)?;
+    drop(_x1);
+    let start_time = std::time::Instant::now();
+    let _x1 = x.matmul(&x)?;
+    device.synchronize()?;
+    println!("tf32: {:?}", start_time.elapsed());
+    drop(_x1);
     Ok(())
 }
candle-core/examples/metal_basics.rs (new file, 28 lines)
@ -0,0 +1,28 @@
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+use anyhow::Result;
+use candle_core::{Device, Tensor};
+
+fn main() -> Result<()> {
+    // This requires the code to be run with MTL_CAPTURE_ENABLED=1
+    let device = Device::new_metal(0)?;
+    let metal_device = match &device {
+        Device::Metal(m) => m,
+        _ => anyhow::bail!("unexpected device"),
+    };
+    metal_device.capture("/tmp/candle.gputrace")?;
+    // This first synchronize ensures that a new command buffer gets created after setting up the
+    // capture scope.
+    device.synchronize()?;
+    let x = Tensor::randn(0f32, 1.0, (128, 128), &device)?;
+    let x1 = x.add(&x)?;
+    println!("{x1:?}");
+    // This second synchronize ensures that the command buffer gets commited before the end of the
+    // capture scope.
+    device.synchronize()?;
+    Ok(())
+}
@ -133,6 +133,8 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
     /// after this call.
     unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
 
+    fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage>;
+
     fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage>;
 
     fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage>;
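The new `storage_from_slice` entry point lets each backend build device storage straight from a borrowed host slice instead of first materializing an owned `CpuStorage`. A minimal sketch of the user-visible path, assuming a small f32 tensor (the shape and values are illustrative only):

```rust
use candle_core::{Device, Result, Tensor};

fn upload(device: &Device) -> Result<Tensor> {
    let data = vec![0f32, 1.0, 2.0, 3.0];
    // Tensor::from_slice reaches the backend through storage_from_slice,
    // so on CUDA this becomes a single host-to-device copy of the slice.
    Tensor::from_slice(data.as_slice(), (2, 2), device)
}
```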
@ -320,13 +320,13 @@ impl Tensor {
             dilation,
             output_padding: _output_padding,
         } => {
-            let grad_arg = grad.conv2d(kernel, *padding, *dilation, *stride, 1)?;
+            let grad_arg = grad.conv2d(kernel, *padding, *stride, *dilation, 1)?;
             let sum_grad = grads.or_insert(arg)?;
             *sum_grad = sum_grad.add(&grad_arg)?;
 
             let grad_kernel = grad
                 .transpose(0, 1)?
-                .conv2d(&arg.transpose(0, 1)?, *padding, *stride, *dilation, 1)?
+                .conv2d(&arg.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
                 .transpose(0, 1)?;
             let sum_grad = grads.or_insert(kernel)?;
             let (_, _, k0, k1) = kernel.dims4()?;

@ -623,9 +623,9 @@ impl Tensor {
         }
         Op::Unary(arg, UnaryOp::Silu) => {
             let sum_grad = grads.or_insert(arg)?;
-            // d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
-            let sigmoid_arg = (*node / arg)?;
-            let silu_grad = (&sigmoid_arg * (1. + (arg * (1. - &sigmoid_arg)?)?)?)?;
+            // d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x))) = sigmoid(x) * (1 - node) + node
+            let sigmoid_arg = (arg.neg()?.exp()? + 1.)?.recip()?;
+            let silu_grad = &sigmoid_arg * (1. - *node) + *node;
             *sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
         }
         Op::Elu(arg, alpha) => {
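It is worth spelling out why the rewritten SiLU backward is valid. With σ the sigmoid and silu(x) = x·σ(x) (the stored forward value, `node` above):

$$
\frac{d}{dx}\,\mathrm{silu}(x)
  = \sigma(x) + x\,\sigma(x)\bigl(1-\sigma(x)\bigr)
  = \sigma(x)\bigl(1 + x\,(1-\sigma(x))\bigr)
  = \sigma(x)\bigl(1 - \mathrm{silu}(x)\bigr) + \mathrm{silu}(x)
$$

since x·σ(x) = silu(x). The last form needs only σ(x) and the already-computed forward output, so the backward pass can compute the sigmoid directly from `arg` instead of dividing `node` by `arg`, which was ill-defined at x = 0.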
@ -634,7 +634,8 @@ impl Tensor {
             let zeros = arg.zeros_like()?;
             let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
             let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
-            let negative_exp_mask = ((negative_mask * arg.exp())? * *alpha)?;
+            // node == alpha * (e^x - 1) for x <= 0, reuse it
+            let negative_exp_mask = (negative_mask * (*node + *alpha))?;
             let combined_mask = (positive_mask + negative_exp_mask)?;
             *sum_grad = sum_grad.add(&(grad * combined_mask)?)?
         }
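The ELU rewrite uses the same trick of reusing the forward value: for x ≤ 0 the stored output is node = α(eˣ − 1), so the gradient there is

$$
\frac{d}{dx}\,\mathrm{elu}_\alpha(x) = \alpha e^{x} = \alpha\bigl(e^{x}-1\bigr) + \alpha = \mathrm{node} + \alpha \qquad (x \le 0),
$$

which is exactly the `negative_mask * (*node + *alpha)` term, with the gradient of 1 for x > 0 supplied by `positive_mask`.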
@ -755,4 +756,9 @@ impl GradStore {
         };
         Ok(grad)
     }
+
+    /// Get the tensor ids of the stored gradient tensors
+    pub fn get_ids(&self) -> impl Iterator<Item = &TensorId> {
+        self.0.keys()
+    }
 }
@ -1,6 +1,7 @@
 pub mod erf;
 pub mod kernels;
 
+#[allow(unused)]
 trait Cpu<const ARR: usize> {
     type Unit;
     type Array;

@ -18,6 +19,7 @@ trait Cpu<const ARR: usize> {
     unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit);
 }
 
+#[allow(unused)]
 trait CpuF16<const ARR: usize> {
     type Unit;
     type Array;
@ -10,7 +10,7 @@ pub use utils::{
 };
 
 const USE_IM2COL_CONV1D: bool = true;
-const USE_IM2COL_CONV1D_TR: bool = true;
+const USE_COL2IM_CONV1D_TR: bool = true;
 const USE_IM2COL_CONV2D: bool = true;
 
 // TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +

@ -26,6 +26,17 @@ pub enum CpuStorage {
     F64(Vec<f64>),
 }
 
+#[derive(Debug, Clone)]
+pub enum CpuStorageRef<'a> {
+    U8(&'a [u8]),
+    U32(&'a [u32]),
+    I64(&'a [i64]),
+    BF16(&'a [bf16]),
+    F16(&'a [f16]),
+    F32(&'a [f32]),
+    F64(&'a [f64]),
+}
+
 #[derive(Debug, Clone)]
 pub struct CpuDevice;
 

@ -110,7 +121,8 @@ impl ReduceIndex {
         let dst_len = src_l.shape().elem_count() / reduce_dim_size;
         let mut dst: Vec<U> = Vec::with_capacity(dst_len);
         let dst_to_set = dst.spare_capacity_mut();
-        let dst_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(dst_to_set) };
+        let dst_to_set =
+            unsafe { std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(dst_to_set) };
         match src_l.contiguous_offsets() {
             Some((o1, o2)) => {
                 let src = &src[o1..o2];

@ -2238,7 +2250,7 @@ impl BackendStorage for CpuStorage {
             && params.dilation == 1
             && params.padding == 0
             && params.output_padding == 0;
-        if USE_IM2COL_CONV1D_TR && can_use_col2im {
+        if USE_COL2IM_CONV1D_TR && can_use_col2im {
             let (b_size, c_in, l_in) = l.shape().dims3()?;
             let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
             if !kernel_l.is_contiguous() {

@ -2445,6 +2457,10 @@ impl BackendDevice for CpuDevice {
         true
     }
 
+    fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
+        Ok(T::to_cpu_storage(s))
+    }
+
     fn storage_from_cpu_storage(&self, s: &CpuStorage) -> Result<Self::Storage> {
         Ok(s.clone())
     }
@ -174,7 +174,9 @@ pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [
         (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
             let mut ys: Vec<T> = Vec::with_capacity(el_count);
             let ys_to_set = ys.spare_capacity_mut();
-            let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
+            let ys_to_set = unsafe {
+                std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
+            };
             f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
             // SAFETY: values are all set by f_vec.
             unsafe { ys.set_len(el_count) };

@ -185,7 +187,9 @@ pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [
             let rhs = &rhs[ob.start..ob.start + ob.len];
             let mut ys: Vec<T> = Vec::with_capacity(el_count);
             let ys_to_set = ys.spare_capacity_mut();
-            let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
+            let ys_to_set = unsafe {
+                std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
+            };
             let mut dst_i = 0;
             for src_i in (o_l1..o_l2).step_by(ob.len) {
                 f_vec(

@ -224,7 +228,9 @@ pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [
             let lhs = &lhs[ob.start..ob.start + ob.len];
             let mut ys: Vec<T> = Vec::with_capacity(el_count);
             let ys_to_set = ys.spare_capacity_mut();
-            let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
+            let ys_to_set = unsafe {
+                std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
+            };
             let mut dst_i = 0;
             for src_i in (o_r1..o_r2).step_by(ob.len) {
                 f_vec(

@ -311,7 +317,9 @@ pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U
         crate::StridedBlocks::SingleBlock { start_offset, len } => {
             let mut ys: Vec<U> = Vec::with_capacity(len);
             let ys_to_set = ys.spare_capacity_mut();
-            let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
+            let ys_to_set = unsafe {
+                std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(ys_to_set)
+            };
             f_vec(&vs[start_offset..start_offset + len], ys_to_set);
             // SAFETY: values are all set by f_vec.
             unsafe { ys.set_len(len) };

@ -333,7 +341,9 @@ pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U
         } else {
             let mut ys: Vec<U> = Vec::with_capacity(el_count);
             let ys_to_set = ys.spare_capacity_mut();
-            let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
+            let ys_to_set = unsafe {
+                std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(ys_to_set)
+            };
             let mut dst_index = 0;
             for src_index in block_start_index {
                 let vs = &vs[src_index..src_index + block_len];
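These hunks only make the `transmute` source type explicit; the surrounding idiom deserves a self-contained sketch. A minimal version of the write-into-spare-capacity pattern used by `binary_map_vec` and `unary_map_vec` (the function name and the `f32` element type are illustrative, not part of the crate):

```rust
use std::mem::MaybeUninit;

/// Build a Vec by letting `fill` write directly into uninitialized capacity.
/// SAFETY contract: `fill` must initialize every one of the `len` elements.
fn filled_with<F: FnMut(&mut [f32])>(len: usize, mut fill: F) -> Vec<f32> {
    let mut ys: Vec<f32> = Vec::with_capacity(len);
    let spare: &mut [MaybeUninit<f32>] = ys.spare_capacity_mut();
    // Spelling out both transmute types (as the diff now does) keeps the cast
    // explicit instead of relying on type inference.
    let to_set = unsafe { std::mem::transmute::<&mut [MaybeUninit<f32>], &mut [f32]>(spare) };
    fill(to_set);
    // SAFETY: `fill` initialized all `len` elements above.
    unsafe { ys.set_len(len) };
    ys
}
```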
@ -1,5 +1,5 @@
 use crate::backend::BackendDevice;
-use crate::{CpuStorage, DType, Layout, Result, Shape};
+use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
 pub use candle_kernels as kernels;
 pub use cudarc;
 use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig};

@ -334,6 +334,43 @@ impl BackendDevice for CudaDevice {
         })
     }
 
+    fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
+        let slice = match T::cpu_storage_ref(s) {
+            CpuStorageRef::U8(storage) => {
+                let data = self.htod_sync_copy(storage).w()?;
+                CudaStorageSlice::U8(data)
+            }
+            CpuStorageRef::U32(storage) => {
+                let data = self.htod_sync_copy(storage).w()?;
+                CudaStorageSlice::U32(data)
+            }
+            CpuStorageRef::I64(storage) => {
+                let data = self.htod_sync_copy(storage).w()?;
+                CudaStorageSlice::I64(data)
+            }
+            CpuStorageRef::BF16(storage) => {
+                let data = self.htod_sync_copy(storage).w()?;
+                CudaStorageSlice::BF16(data)
+            }
+            CpuStorageRef::F16(storage) => {
+                let data = self.htod_sync_copy(storage).w()?;
+                CudaStorageSlice::F16(data)
+            }
+            CpuStorageRef::F32(storage) => {
+                let data = self.htod_sync_copy(storage).w()?;
+                CudaStorageSlice::F32(data)
+            }
+            CpuStorageRef::F64(storage) => {
+                let data = self.htod_sync_copy(storage).w()?;
+                CudaStorageSlice::F64(data)
+            }
+        };
+        Ok(CudaStorage {
+            slice,
+            device: self.clone(),
+        })
+    }
+
     fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
         let slice = match storage {
             CpuStorage::U8(storage) => {
@ -16,9 +16,9 @@ mod error;
 mod utils;
 pub use device::{CudaDevice, DeviceId};
 pub use error::{CudaError, WrapErr};
-pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, S};
+pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, Map3, S};
 
-enum SlicePtrOrNull<T> {
+pub enum SlicePtrOrNull<T> {
     Ptr(CudaSlice<T>),
     Null,
 }

@ -33,7 +33,7 @@ unsafe impl<T: DeviceRepr> DeviceRepr for &SlicePtrOrNull<T> {
 }
 
 impl SlicePtrOrNull<usize> {
-    fn params_from_layout(dev: &CudaDevice, l: &Layout) -> Result<Self> {
+    pub fn params_from_layout(dev: &CudaDevice, l: &Layout) -> Result<Self> {
         let ds = if l.is_contiguous() {
             SlicePtrOrNull::Null
         } else {

@ -250,44 +250,6 @@ impl Map1 for Powf {
     }
 }
 
-struct Sum<'a>(&'a [usize]);
-impl<'a> Map1 for Sum<'a> {
-    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
-        &self,
-        src: &CudaSlice<T>,
-        dev: &CudaDevice,
-        layout: &Layout,
-    ) -> Result<CudaSlice<T>> {
-        let shape = layout.shape();
-        let src_dims = shape.dims();
-        let el = shape.elem_count();
-        let mut dst_el = el;
-        for &sum_dim in self.0.iter() {
-            dst_el /= src_dims[sum_dim];
-        }
-        let mut sum_dims = self.0.to_vec();
-        // Sort the sum_dims as they have to be processed from left to right when converting the
-        // indexes.
-        sum_dims.sort();
-        let sum_dims_l: Vec<usize> = sum_dims.iter().map(|&d| src_dims[d]).collect();
-        let sum_dims_s: Vec<usize> = sum_dims
-            .iter()
-            .map(|&d| src_dims[d + 1..].iter().product::<usize>())
-            .collect();
-        let cfg = LaunchConfig::for_num_elems(el as u32);
-        let ds = dev
-            .htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())
-            .w()?;
-        let src = &src.slice(layout.start_offset()..);
-        let func = dev.get_or_load_func(&kernel_name::<T>("sum"), kernels::REDUCE)?;
-        let out = dev.alloc_zeros::<T>(dst_el).w()?;
-        let params = (el, src_dims.len(), sum_dims.len(), &ds, src, &out);
-        // SAFETY: ffi.
-        unsafe { func.launch(cfg, params) }.w()?;
-        Ok(out)
-    }
-}
-
 struct FastReduce<'a>(&'a [usize], ReduceOp);
 impl<'a> Map1Any for FastReduce<'a> {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(

@ -668,6 +630,31 @@ impl<'a> Map2 for Conv2D<'a> {
     }
 }
 
+struct Col2Im1D {
+    stride: usize,
+}
+
+impl Map1 for Col2Im1D {
+    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+        &self,
+        col: &CudaSlice<T>,
+        dev: &CudaDevice,
+        l: &Layout,
+    ) -> Result<CudaSlice<T>> {
+        let (b_size, l_in, c_out, k_size) = l.shape().dims4()?;
+        let stride = self.stride;
+        let l_out = (l_in - 1) * stride + k_size;
+        let dst_el = b_size * c_out * l_out;
+        let mut im = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+
+        let cfg = LaunchConfig::for_num_elems(dst_el as u32);
+        let params = (dst_el, l_out, l_in, c_out, k_size, stride, col, &mut im);
+        let func = dev.get_or_load_func(&kernel_name::<T>("col2im1d"), kernels::CONV)?;
+        unsafe { func.launch(cfg, params) }.w()?;
+        Ok(im)
+    }
+}
+
 struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
 impl<'a> Map2 for ConvTranspose1D<'a> {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
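For intuition, the `l_out` computed by `Col2Im1D` is the usual transposed-convolution output length in the only case this fast path handles (dilation 1, no padding, no output padding):

$$
l_{\mathrm{out}} = (l_{\mathrm{in}} - 1)\cdot \mathrm{stride} + k_{\mathrm{size}},
\qquad\text{e.g. } l_{\mathrm{in}}=4,\ \mathrm{stride}=2,\ k_{\mathrm{size}}=3 \;\Rightarrow\; l_{\mathrm{out}}=9.
$$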
@ -1404,9 +1391,55 @@ impl BackendStorage for CudaStorage {
         kernel_l: &Layout,
         params: &crate::conv::ParamsConvTranspose1D,
     ) -> Result<Self> {
+        const USE_COL2IM_CONV1D_TR: bool = true;
+
         let device = self.device().clone();
-        let slice =
-            ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
+        let can_use_col2im = kernel_l.is_contiguous()
+            && params.dilation == 1
+            && params.padding == 0
+            && params.output_padding == 0;
+        let slice = if USE_COL2IM_CONV1D_TR && can_use_col2im {
+            let (b_size, c_in, l_in) = l.shape().dims3()?;
+            let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
+            if !kernel_l.is_contiguous() {
+                crate::bail!(
+                    "convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}"
+                )
+            }
+            if c_in != c_in2 {
+                crate::bail!(
+                    "convtr1d: shape mismatch on c_in {:?} {:?}",
+                    l.shape(),
+                    kernel_l.shape()
+                )
+            }
+            let col = {
+                // This merges the last two dimensions of the kernel together.
+                let kernel_l_mm = Layout::new(
+                    (b_size, c_in, k_size * c_out).into(),
+                    vec![0, k_size * c_out, 1],
+                    kernel_l.start_offset(),
+                );
+                self.matmul(
+                    kernel,
+                    (
+                        b_size,
+                        /* m */ l_in,
+                        /* n */ c_out * k_size,
+                        /* k */ c_in,
+                    ),
+                    &l.transpose(1, 2)?,
+                    &kernel_l_mm,
+                )?
+            };
+            let col_l = Layout::contiguous((b_size, l_in, c_out, k_size));
+            Col2Im1D {
+                stride: params.stride,
+            }
+            .map(&col.slice, &device, &col_l)?
+        } else {
+            ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?
+        };
         Ok(Self { slice, device })
     }
 

@ -1635,12 +1668,8 @@ impl BackendStorage for CudaStorage {
             let rhs = &rhs.slice(rhs_l.start_offset()..);
             let cfg = gemm_config(bf16::ONE, bf16::ZERO, (b, m, n, k), lhs_l, rhs_l)?;
             let mut out = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
-            unsafe {
-                self.device
-                    .blas
-                    .gemm_strided_batched(cfg, rhs, lhs, &mut out)
-            }
-            .w()?;
+            unsafe { gemm_strided_batched_bf16(&self.device.blas, cfg, rhs, lhs, &mut out) }
+                .w()?;
             CudaStorageSlice::BF16(out)
         }
         (CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => {

@ -1648,12 +1677,8 @@ impl BackendStorage for CudaStorage {
             let rhs = &rhs.slice(rhs_l.start_offset()..);
             let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_l, rhs_l)?;
             let mut out = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
-            unsafe {
-                self.device
-                    .blas
-                    .gemm_strided_batched(cfg, rhs, lhs, &mut out)
-            }
-            .w()?;
+            unsafe { gemm_strided_batched_f16(&self.device.blas, cfg, rhs, lhs, &mut out) }
+                .w()?;
             CudaStorageSlice::F16(out)
         }
         (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {

@ -1661,12 +1686,8 @@ impl BackendStorage for CudaStorage {
             let rhs = &rhs.slice(rhs_l.start_offset()..);
             let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
             let mut out = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
-            unsafe {
-                self.device
-                    .blas
-                    .gemm_strided_batched(cfg, rhs, lhs, &mut out)
-            }
-            .w()?;
+            unsafe { gemm_strided_batched_f32(&self.device.blas, cfg, rhs, lhs, &mut out) }
+                .w()?;
            CudaStorageSlice::F32(out)
         }
         (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {

@ -1856,3 +1877,203 @@ impl BackendStorage for CudaStorage {
         Ok(())
     }
 }
+
+// Default for the reduced precision setting is false, similar to pytorch.
+// https://github.com/pytorch/pytorch/issues/123157
+static MM_F16_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
+    std::sync::atomic::AtomicBool::new(false);
+static MM_BF16_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
+    std::sync::atomic::AtomicBool::new(false);
+static MM_F32_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
+    std::sync::atomic::AtomicBool::new(false);
+
+/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
+/// allowed with f32 GEMMs.
+pub fn gemm_reduced_precision_f32() -> bool {
+    MM_F32_REDUCED_PRECISION.load(std::sync::atomic::Ordering::Relaxed)
+}
+
+/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
+/// allowed with f32 GEMMs.
+pub fn set_gemm_reduced_precision_f32(b: bool) {
+    MM_F32_REDUCED_PRECISION.store(b, std::sync::atomic::Ordering::Relaxed)
+}
+
+/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
+/// allowed with f16 GEMMs.
+pub fn gemm_reduced_precision_f16() -> bool {
+    MM_F16_REDUCED_PRECISION.load(std::sync::atomic::Ordering::Relaxed)
+}
+
+/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
+/// allowed with f16 GEMMs.
+pub fn set_gemm_reduced_precision_f16(b: bool) {
+    MM_F16_REDUCED_PRECISION.store(b, std::sync::atomic::Ordering::Relaxed)
+}
+
+/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
+/// allowed with bf16 GEMMs.
+pub fn gemm_reduced_precision_bf16() -> bool {
+    MM_BF16_REDUCED_PRECISION.load(std::sync::atomic::Ordering::Relaxed)
+}
+
+/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
+/// allowed with bf16 GEMMs.
+pub fn set_gemm_reduced_precision_bf16(b: bool) {
+    MM_BF16_REDUCED_PRECISION.store(b, std::sync::atomic::Ordering::Relaxed)
+}
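Putting the pieces together, a hedged usage sketch of the new f32 toggle (the shapes are illustrative; this mirrors the reworked cuda_basics example earlier in the diff):

```rust
use candle_core::{Device, Result, Tensor};

fn demo() -> Result<()> {
    let device = Device::new_cuda(0)?;
    let a = Tensor::randn(0f32, 1.0, (4096, 4096), &device)?;
    // Allow tf32 accumulation for f32 matmuls (faster on Ampere+, slightly less precise).
    candle_core::cuda::set_gemm_reduced_precision_f32(true);
    let _fast = a.matmul(&a)?;
    // Back to full f32 accumulation.
    candle_core::cuda::set_gemm_reduced_precision_f32(false);
    let _exact = a.matmul(&a)?;
    Ok(())
}
```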
+unsafe fn gemm_strided_batched_f32(
+    cublas: &cudarc::cublas::CudaBlas,
+    cfg: StridedBatchedConfig<f32>,
+    a: &cudarc::driver::CudaView<f32>,
+    b: &cudarc::driver::CudaView<f32>,
+    c: &mut CudaSlice<f32>,
+) -> std::result::Result<(), cudarc::cublas::result::CublasError> {
+    use cudarc::cublas::sys;
+    use cudarc::driver::DevicePtrMut;
+
+    let compute_type = if gemm_reduced_precision_f32() {
+        sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_TF32
+    } else {
+        sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
+    };
+    let alpha = &cfg.gemm.alpha as *const f32 as *const _;
+    let beta = &cfg.gemm.beta as *const f32 as *const _;
+
+    cudarc::cublas::result::gemm_strided_batched_ex(
+        *cublas.handle(),
+        cfg.gemm.transa,
+        cfg.gemm.transb,
+        cfg.gemm.m,
+        cfg.gemm.n,
+        cfg.gemm.k,
+        alpha,
+        *a.device_ptr() as *const _,
+        sys::cudaDataType_t::CUDA_R_32F,
+        cfg.gemm.lda,
+        cfg.stride_a,
+        *b.device_ptr() as *const _,
+        sys::cudaDataType_t::CUDA_R_32F,
+        cfg.gemm.ldb,
+        cfg.stride_b,
+        beta,
+        *c.device_ptr_mut() as *mut _,
+        sys::cudaDataType_t::CUDA_R_32F,
+        cfg.gemm.ldc,
+        cfg.stride_c,
+        cfg.batch_size,
+        compute_type,
+        sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP,
+    )
+}
+
+unsafe fn gemm_strided_batched_f16(
+    cublas: &cudarc::cublas::CudaBlas,
+    cfg: StridedBatchedConfig<f16>,
+    a: &cudarc::driver::CudaView<f16>,
+    b: &cudarc::driver::CudaView<f16>,
+    c: &mut CudaSlice<f16>,
+) -> std::result::Result<(), cudarc::cublas::result::CublasError> {
+    use cudarc::cublas::sys;
+    use cudarc::driver::DevicePtrMut;
+
+    let alpha = cfg.gemm.alpha;
+    let beta = cfg.gemm.beta;
+    let alpha_f32: f32 = cfg.gemm.alpha.to_f32();
+    let beta_f32: f32 = cfg.gemm.beta.to_f32();
+    let (compute_type, alpha, beta) = if gemm_reduced_precision_f16() {
+        (
+            sys::cublasComputeType_t::CUBLAS_COMPUTE_16F,
+            (&alpha) as *const f16 as *const _,
+            (&beta) as *const f16 as *const _,
+        )
+    } else {
+        (
+            sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
+            (&alpha_f32) as *const f32 as *const _,
+            (&beta_f32) as *const f32 as *const _,
+        )
+    };
+
+    cudarc::cublas::result::gemm_strided_batched_ex(
+        *cublas.handle(),
+        cfg.gemm.transa,
+        cfg.gemm.transb,
+        cfg.gemm.m,
+        cfg.gemm.n,
+        cfg.gemm.k,
+        alpha,
+        *a.device_ptr() as *const _,
+        sys::cudaDataType_t::CUDA_R_16F,
+        cfg.gemm.lda,
+        cfg.stride_a,
+        *b.device_ptr() as *const _,
+        sys::cudaDataType_t::CUDA_R_16F,
+        cfg.gemm.ldb,
+        cfg.stride_b,
+        beta,
+        *c.device_ptr_mut() as *mut _,
+        sys::cudaDataType_t::CUDA_R_16F,
+        cfg.gemm.ldc,
+        cfg.stride_c,
+        cfg.batch_size,
+        compute_type,
+        sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP,
+    )
+}
+
+unsafe fn gemm_strided_batched_bf16(
+    cublas: &cudarc::cublas::CudaBlas,
+    cfg: StridedBatchedConfig<bf16>,
+    a: &cudarc::driver::CudaView<bf16>,
+    b: &cudarc::driver::CudaView<bf16>,
+    c: &mut CudaSlice<bf16>,
+) -> std::result::Result<(), cudarc::cublas::result::CublasError> {
+    use cudarc::cublas::sys;
+    use cudarc::driver::DevicePtrMut;
+
+    let alpha_f32: f32 = cfg.gemm.alpha.to_f32();
+    let beta_f32: f32 = cfg.gemm.beta.to_f32();
+    // The type for alpha and beta depends on the computeType.
+    // https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmstridedbatchedex
+    let (compute_type, alpha, beta) = if gemm_reduced_precision_bf16() {
+        (
+            sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_16BF,
+            (&alpha_f32) as *const f32 as *const _,
+            (&beta_f32) as *const f32 as *const _,
+        )
+    } else {
+        (
+            sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
+            (&alpha_f32) as *const f32 as *const _,
+            (&beta_f32) as *const f32 as *const _,
+        )
+    };
+
+    cudarc::cublas::result::gemm_strided_batched_ex(
+        *cublas.handle(),
+        cfg.gemm.transa,
+        cfg.gemm.transb,
+        cfg.gemm.m,
+        cfg.gemm.n,
+        cfg.gemm.k,
+        alpha,
+        *a.device_ptr() as *const _,
+        sys::cudaDataType_t::CUDA_R_16BF,
+        cfg.gemm.lda,
+        cfg.stride_a,
+        *b.device_ptr() as *const _,
+        sys::cudaDataType_t::CUDA_R_16BF,
+        cfg.gemm.ldb,
+        cfg.stride_b,
+        beta,
+        *c.device_ptr_mut() as *mut _,
+        sys::cudaDataType_t::CUDA_R_16BF,
+        cfg.gemm.ldc,
+        cfg.stride_c,
+        cfg.batch_size,
+        compute_type,
+        sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP,
+    )
+}
@ -54,6 +54,44 @@ pub trait Map2 {
     }
 }
 
+pub trait Map3 {
+    #[allow(clippy::too_many_arguments)]
+    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+        &self,
+        src1: &CudaSlice<T>,
+        layout1: &Layout,
+        src2: &CudaSlice<T>,
+        layout2: &Layout,
+        src3: &CudaSlice<T>,
+        layout3: &Layout,
+        dev: &CudaDevice,
+    ) -> Result<CudaSlice<T>>;
+
+    #[allow(clippy::too_many_arguments)]
+    fn map(
+        &self,
+        s1: &S,
+        l1: &Layout,
+        s2: &S,
+        l2: &Layout,
+        s3: &S,
+        l3: &Layout,
+        d: &CudaDevice,
+    ) -> Result<S> {
+        let out = match (s1, s2, s3) {
+            (S::U8(s1), S::U8(s2), S::U8(s3)) => S::U8(self.f(s1, l1, s2, l2, s3, l3, d)?),
+            (S::U32(s1), S::U32(s2), S::U32(s3)) => S::U32(self.f(s1, l1, s2, l2, s3, l3, d)?),
+            (S::I64(s1), S::I64(s2), S::I64(s3)) => S::I64(self.f(s1, l1, s2, l2, s3, l3, d)?),
+            (S::BF16(s1), S::BF16(s2), S::BF16(s3)) => S::BF16(self.f(s1, l1, s2, l2, s3, l3, d)?),
+            (S::F16(s1), S::F16(s2), S::F16(s3)) => S::F16(self.f(s1, l1, s2, l2, s3, l3, d)?),
+            (S::F32(s1), S::F32(s2), S::F32(s3)) => S::F32(self.f(s1, l1, s2, l2, s3, l3, d)?),
+            (S::F64(s1), S::F64(s2), S::F64(s3)) => S::F64(self.f(s1, l1, s2, l2, s3, l3, d)?),
+            _ => Err(CudaError::InternalError("dtype mismatch in ternary op"))?,
+        };
+        Ok(out)
+    }
+}
+
 pub trait Map2InPlace {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
         &self,
@ -171,6 +171,22 @@ impl Device {
         matches!(self, Self::Metal(_))
     }
 
+    pub fn supports_bf16(&self) -> bool {
+        match self {
+            Self::Cuda(_) | Self::Metal(_) => true,
+            Self::Cpu => false,
+        }
+    }
+
+    /// Return `BF16` for devices that support it, otherwise default to `F32`.
+    pub fn bf16_default_to_f32(&self) -> DType {
+        if self.supports_bf16() {
+            DType::BF16
+        } else {
+            DType::F32
+        }
+    }
+
     pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
         if crate::utils::cuda_is_available() {
             Self::new_cuda(ordinal)
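A hedged sketch of the intended use of the two new helpers: choose a weight dtype once, based on what the device can actually run (the shape and zero-init are illustrative):

```rust
use candle_core::{Device, Result, Tensor};

fn init_weights() -> Result<Tensor> {
    let device = Device::cuda_if_available(0)?;
    // BF16 on CUDA/Metal, F32 on CPU where bf16 kernels are not available.
    let dtype = device.bf16_default_to_f32();
    Tensor::zeros((1024, 1024), dtype, &device)
}
```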
@ -306,6 +322,20 @@ impl Device {
         }
     }
 
+    pub(crate) fn storage_from_slice<D: WithDType>(&self, data: &[D]) -> Result<Storage> {
+        match self {
+            Device::Cpu => Ok(Storage::Cpu(data.to_cpu_storage())),
+            Device::Cuda(device) => {
+                let storage = device.storage_from_slice(data)?;
+                Ok(Storage::Cuda(storage))
+            }
+            Device::Metal(device) => {
+                let storage = device.storage_from_slice(data)?;
+                Ok(Storage::Metal(storage))
+            }
+        }
+    }
+
     pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
         match self {
             Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
@ -1,7 +1,7 @@
 //! Types for elements that can be stored and manipulated using tensors.
 #![allow(clippy::redundant_closure_call)]
 use crate::backend::BackendStorage;
-use crate::{CpuStorage, Error, Result};
+use crate::{CpuStorage, CpuStorageRef, Error, Result};
 
 /// The different types of elements allowed in tensors.
 #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]

@ -100,12 +100,14 @@ pub trait WithDType:
     + 'static
     + Send
     + Sync
+    + std::any::Any
     + crate::cpu::kernels::VecOps
 {
     const DTYPE: DType;
 
     fn from_f64(v: f64) -> Self;
     fn to_f64(self) -> f64;
+    fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_>;
     fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
 
     fn to_cpu_storage(data: &[Self]) -> CpuStorage {

@ -129,6 +131,10 @@ macro_rules! with_dtype {
             $to_f64(self)
         }
 
+        fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_> {
+            CpuStorageRef::$dtype(data)
+        }
+
         fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
             CpuStorage::$dtype(data)
         }
@ -214,6 +214,10 @@ impl crate::backend::BackendDevice for CudaDevice {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
+    fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
         Err(Error::NotCompiledWithCudaSupport)
     }

@ -234,3 +238,33 @@ impl crate::backend::BackendDevice for CudaDevice {
         Ok(())
     }
 }
+
+/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
+/// allowed with f16 GEMMs.
+pub fn gemm_reduced_precision_f16() -> bool {
+    true
+}
+
+/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
+/// allowed with f16 GEMMs.
+pub fn set_gemm_reduced_precision_f16(_: bool) {}
+
+/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
+/// allowed with bf16 GEMMs.
+pub fn gemm_reduced_precision_bf16() -> bool {
+    true
+}
+
+/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
+/// allowed with bf16 GEMMs.
+pub fn set_gemm_reduced_precision_bf16(_: bool) {}
+
+/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
+/// allowed with f32 GEMMs.
+pub fn gemm_reduced_precision_f32() -> bool {
+    true
+}
+
+/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
+/// allowed with f32 GEMMs.
+pub fn set_gemm_reduced_precision_f32(_b: bool) {}
@ -226,6 +226,10 @@ impl crate::backend::BackendDevice for MetalDevice {
         Err(Error::NotCompiledWithMetalSupport)
     }
 
+    fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
     fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
         Err(Error::NotCompiledWithMetalSupport)
     }
@ -219,10 +219,14 @@ impl Error {
         Self::Wrapped(Box::new(err)).bt()
     }
 
-    pub fn msg(err: impl std::error::Error + Send + Sync + 'static) -> Self {
+    pub fn msg(err: impl std::error::Error) -> Self {
         Self::Msg(err.to_string()).bt()
     }
 
+    pub fn debug(err: impl std::fmt::Debug) -> Self {
+        Self::Msg(format!("{err:?}")).bt()
+    }
+
     pub fn bt(self) -> Self {
         let backtrace = std::backtrace::Backtrace::capture();
         match backtrace.status() {
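With the `Send + Sync + 'static` bounds dropped, `Error::msg` accepts any `std::error::Error` by value (only its formatted message is kept), and the new `Error::debug` covers types that are merely `Debug`. A small sketch (the wrapped values are illustrative):

```rust
use candle_core::Error;

fn examples() {
    let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "weights file missing");
    let _e1 = Error::msg(io_err); // any std::error::Error, no Send/Sync/'static needed
    let _e2 = Error::debug(("shape", [2usize, 3])); // anything that implements Debug
}
```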
@ -47,7 +47,7 @@ mod custom_op;
 mod device;
 pub mod display;
 mod dtype;
-mod dummy_cuda_backend;
+pub mod dummy_cuda_backend;
 mod dummy_metal_backend;
 pub mod error;
 mod indexer;

@ -63,6 +63,7 @@ pub mod quantized;
 pub mod safetensors;
 pub mod scalar;
 pub mod shape;
+mod sort;
 mod storage;
 mod strided_index;
 mod tensor;

@ -74,7 +75,7 @@ mod variable;
 #[cfg(feature = "cudnn")]
 pub use cuda_backend::cudnn;
 
-pub use cpu_backend::CpuStorage;
+pub use cpu_backend::{CpuStorage, CpuStorageRef};
 pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
 pub use device::{Device, DeviceLocation, NdArray};
 pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};

@ -88,10 +89,12 @@ pub use tensor::{Tensor, TensorId};
 pub use variable::Var;
 
 #[cfg(feature = "cuda")]
-pub use cuda_backend::{CudaDevice, CudaStorage};
+pub use cuda_backend as cuda;
 
 #[cfg(not(feature = "cuda"))]
-pub use dummy_cuda_backend::{CudaDevice, CudaStorage};
+pub use dummy_cuda_backend as cuda;
+
+pub use cuda::{CudaDevice, CudaStorage};
 
 #[cfg(feature = "metal")]
 pub use metal_backend::{MetalDevice, MetalError, MetalStorage};
@ -100,11 +100,11 @@ impl MetalDevice {
     }
 
     pub fn command_buffer(&self) -> Result<CommandBuffer> {
-        let mut command_buffer_lock = self.command_buffer.try_write().map_err(MetalError::from)?;
+        let mut command_buffer_lock = self.command_buffer.write().map_err(MetalError::from)?;
         let mut command_buffer = command_buffer_lock.to_owned();
         let mut index = self
             .command_buffer_index
-            .try_write()
+            .write()
             .map_err(MetalError::from)?;
         if *index > self.compute_per_buffer {
             command_buffer.commit();

@ -119,7 +119,7 @@ impl MetalDevice {
     }
 
     pub fn wait_until_completed(&self) -> Result<()> {
-        let mut command_buffer = self.command_buffer.try_write().map_err(MetalError::from)?;
+        let mut command_buffer = self.command_buffer.write().map_err(MetalError::from)?;
         match command_buffer.status() {
             metal::MTLCommandBufferStatus::Committed
             | metal::MTLCommandBufferStatus::Scheduled

@ -179,7 +179,7 @@ impl MetalDevice {
             size,
             MTLResourceOptions::StorageModeManaged,
         );
-        let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
+        let mut buffers = self.buffers.write().map_err(MetalError::from)?;
         let subbuffers = buffers
             .entry((size, MTLResourceOptions::StorageModeManaged))
             .or_insert(vec![]);

@ -232,7 +232,7 @@ impl MetalDevice {
     }
 
     fn drop_unused_buffers(&self) -> Result<()> {
-        let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
+        let mut buffers = self.buffers.write().map_err(MetalError::from)?;
         for subbuffers in buffers.values_mut() {
             let newbuffers = subbuffers
                 .iter()

@ -251,7 +251,7 @@ impl MetalDevice {
         option: MTLResourceOptions,
         _name: &str,
     ) -> Result<Arc<Buffer>> {
-        let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
+        let mut buffers = self.buffers.write().map_err(MetalError::from)?;
         if let Some(b) = self.find_available_buffer(size, option, &buffers) {
             // Cloning also ensures we increment the strong count
             return Ok(b.clone());

@ -273,7 +273,13 @@ impl MetalDevice {
         let descriptor = metal::CaptureDescriptor::new();
         descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
         descriptor.set_capture_device(self);
-        descriptor.set_output_url(path);
+        // The [set_output_url] call requires an absolute path so we convert it if needed.
+        if path.as_ref().is_absolute() {
+            descriptor.set_output_url(path);
+        } else {
+            let path = std::env::current_dir()?.join(path);
+            descriptor.set_output_url(path);
+        }
 
         capture
             .start_capture(&descriptor)
@ -1,17 +1,17 @@
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
|
||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
||||
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
|
||||
use candle_metal_kernels::{BufferOffset, CallConvTranspose2dCfg, Kernels};
|
||||
use metal::{Buffer, MTLResourceOptions, NSUInteger};
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::c_void;
|
||||
use std::sync::{Arc, Mutex, RwLock, TryLockError};
|
||||
use std::sync::{Arc, Mutex, PoisonError, RwLock, TryLockError};
|
||||
|
||||
mod device;
|
||||
pub use device::{DeviceId, MetalDevice};
|
||||
|
||||
fn buffer_o<'a>(buffer: &'a Buffer, l: &Layout, dtype: DType) -> BufferOffset<'a> {
|
||||
pub fn buffer_o<'a>(buffer: &'a Buffer, l: &Layout, dtype: DType) -> BufferOffset<'a> {
|
||||
BufferOffset {
|
||||
buffer,
|
||||
offset_in_bytes: l.start_offset() * dtype.size_in_bytes(),
|
||||
@ -36,6 +36,12 @@ impl<T> From<TryLockError<T>> for MetalError {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<PoisonError<T>> for MetalError {
|
||||
fn from(p: PoisonError<T>) -> Self {
|
||||
MetalError::LockError(LockError::Poisoned(p.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Metal related errors
#[derive(thiserror::Error, Debug)]
pub enum MetalError {
@@ -113,6 +119,8 @@ impl BackendStorage for MetalStorage {
            DType::F32 => "affine_f32",
            DType::F16 => "affine_f16",
            DType::BF16 => "affine_bf16",
            DType::U8 => "affine_u8",
            DType::U32 => "affine_u32",
            dtype => crate::bail!("Metal contiguous affine {dtype:?} not implemented"),
        };
        candle_metal_kernels::call_affine(
@@ -444,156 +452,238 @@ impl BackendStorage for MetalStorage {
        let command_buffer = device.command_buffer()?;
        command_buffer.set_label(B::KERNEL);
        let src = buffer_o(&self.buffer, layout, self.dtype);
        if layout.is_contiguous() {
            use candle_metal_kernels::unary::contiguous;

            let kernel_name = match (B::KERNEL, dtype) {
                ("uabs", DType::F16) => contiguous::abs::HALF,
                ("uabs", DType::F32) => contiguous::abs::FLOAT,
                ("uabs", DType::BF16) => contiguous::abs::BFLOAT,
                ("uceil", DType::F16) => contiguous::ceil::HALF,
                ("uceil", DType::F32) => contiguous::ceil::FLOAT,
                ("uceil", DType::BF16) => contiguous::ceil::BFLOAT,
                ("ucos", DType::F16) => contiguous::cos::HALF,
                ("ucos", DType::F32) => contiguous::cos::FLOAT,
                ("ucos", DType::BF16) => contiguous::cos::BFLOAT,
                ("uerf", DType::F16) => contiguous::erf::HALF,
                ("uerf", DType::F32) => contiguous::erf::FLOAT,
                ("uerf", DType::BF16) => contiguous::erf::BFLOAT,
                ("uexp", DType::F16) => contiguous::exp::HALF,
                ("uexp", DType::F32) => contiguous::exp::FLOAT,
                ("uexp", DType::BF16) => contiguous::exp::BFLOAT,
                ("ufloor", DType::F16) => contiguous::floor::HALF,
                ("ufloor", DType::F32) => contiguous::floor::FLOAT,
                ("ufloor", DType::BF16) => contiguous::floor::BFLOAT,
                ("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
                ("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
                ("ugelu_erf", DType::BF16) => contiguous::gelu_erf::BFLOAT,
                ("ugelu", DType::F16) => contiguous::gelu::HALF,
                ("ugelu", DType::F32) => contiguous::gelu::FLOAT,
                ("ugelu", DType::BF16) => contiguous::gelu::BFLOAT,
                ("ulog", DType::F16) => contiguous::log::HALF,
                ("ulog", DType::F32) => contiguous::log::FLOAT,
                ("ulog", DType::BF16) => contiguous::log::BFLOAT,
                ("uneg", DType::F16) => contiguous::neg::HALF,
                ("uneg", DType::F32) => contiguous::neg::FLOAT,
                ("uneg", DType::BF16) => contiguous::neg::BFLOAT,
                ("urecip", DType::F16) => contiguous::recip::HALF,
                ("urecip", DType::F32) => contiguous::recip::FLOAT,
                ("urecip", DType::BF16) => contiguous::recip::BFLOAT,
                ("urelu", DType::F16) => contiguous::relu::HALF,
                ("urelu", DType::F32) => contiguous::relu::FLOAT,
                ("urelu", DType::BF16) => contiguous::relu::BFLOAT,
                ("uround", DType::F16) => contiguous::round::HALF,
                ("uround", DType::F32) => contiguous::round::FLOAT,
                ("uround", DType::BF16) => contiguous::round::BFLOAT,
                ("usilu", DType::F16) => contiguous::silu::HALF,
                ("usilu", DType::F32) => contiguous::silu::FLOAT,
                ("usilu", DType::BF16) => contiguous::silu::BFLOAT,
                ("usin", DType::F16) => contiguous::sin::HALF,
                ("usin", DType::F32) => contiguous::sin::FLOAT,
                ("usin", DType::BF16) => contiguous::sin::BFLOAT,
                ("usqr", DType::F16) => contiguous::sqr::HALF,
                ("usqr", DType::F32) => contiguous::sqr::FLOAT,
                ("usqr", DType::BF16) => contiguous::sqr::BFLOAT,
                ("usqrt", DType::F16) => contiguous::sqrt::HALF,
                ("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
                ("usqrt", DType::BF16) => contiguous::sqrt::BFLOAT,
                ("utanh", DType::F16) => contiguous::tanh::HALF,
                ("utanh", DType::F32) => contiguous::tanh::FLOAT,
                ("utanh", DType::BF16) => contiguous::tanh::BFLOAT,
                ("usign", DType::F16) => contiguous::sign::HALF,
                ("usign", DType::F32) => contiguous::sign::FLOAT,
                ("usign", DType::BF16) => contiguous::sign::BFLOAT,
                ("usign", DType::I64) => contiguous::sign::I64,
                (name, dtype) => {
                    crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented")
                }
            };
            candle_metal_kernels::call_unary_contiguous(
                &device.device,
                &command_buffer,
                &device.kernels,
                kernel_name,
                el_count,
                src,
                &buffer,
            )
            .map_err(MetalError::from)?;
        } else {
            use candle_metal_kernels::unary::strided;
            let kernel_name = match (B::KERNEL, dtype) {
                ("ucos", DType::F32) => strided::cos::FLOAT,
                ("usin", DType::F32) => strided::sin::FLOAT,
                ("usqr", DType::F32) => strided::sqr::FLOAT,
                ("usqrt", DType::F32) => strided::sqrt::FLOAT,
                ("uneg", DType::F32) => strided::neg::FLOAT,
                ("uexp", DType::F32) => strided::exp::FLOAT,
                ("ulog", DType::F32) => strided::log::FLOAT,
                ("ugelu", DType::F32) => strided::gelu::FLOAT,
                ("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
                ("uerf", DType::F32) => strided::erf::FLOAT,
                ("usilu", DType::F32) => strided::silu::FLOAT,
                ("uabs", DType::F32) => strided::abs::FLOAT,
                ("uceil", DType::F32) => strided::ceil::FLOAT,
                ("ufloor", DType::F32) => strided::floor::FLOAT,
                ("urelu", DType::F32) => strided::relu::FLOAT,
                ("uround", DType::F32) => strided::round::FLOAT,
                ("utanh", DType::F32) => strided::tanh::FLOAT,
        match (el_count % 2, dtype, layout.is_contiguous()) {
            (0, DType::BF16 | DType::F16, true) => {
                use candle_metal_kernels::unary::contiguous_tiled;
                let kernel_name = match (B::KERNEL, dtype) {
                    ("uabs", DType::F16) => contiguous_tiled::abs::HALF,
                    ("uabs", DType::F32) => contiguous_tiled::abs::FLOAT,
                    ("uabs", DType::BF16) => contiguous_tiled::abs::BFLOAT,
                    ("uceil", DType::F16) => contiguous_tiled::ceil::HALF,
                    ("uceil", DType::F32) => contiguous_tiled::ceil::FLOAT,
                    ("uceil", DType::BF16) => contiguous_tiled::ceil::BFLOAT,
                    ("ucos", DType::F16) => contiguous_tiled::cos::HALF,
                    ("ucos", DType::F32) => contiguous_tiled::cos::FLOAT,
                    ("ucos", DType::BF16) => contiguous_tiled::cos::BFLOAT,
                    ("uerf", DType::F16) => contiguous_tiled::erf::HALF,
                    ("uerf", DType::F32) => contiguous_tiled::erf::FLOAT,
                    ("uerf", DType::BF16) => contiguous_tiled::erf::BFLOAT,
                    ("uexp", DType::F16) => contiguous_tiled::exp::HALF,
                    ("uexp", DType::F32) => contiguous_tiled::exp::FLOAT,
                    ("uexp", DType::BF16) => contiguous_tiled::exp::BFLOAT,
                    ("ufloor", DType::F16) => contiguous_tiled::floor::HALF,
                    ("ufloor", DType::F32) => contiguous_tiled::floor::FLOAT,
                    ("ufloor", DType::BF16) => contiguous_tiled::floor::BFLOAT,
                    ("ugelu_erf", DType::F16) => contiguous_tiled::gelu_erf::HALF,
                    ("ugelu_erf", DType::F32) => contiguous_tiled::gelu_erf::FLOAT,
                    ("ugelu_erf", DType::BF16) => contiguous_tiled::gelu_erf::BFLOAT,
                    ("ugelu", DType::F16) => contiguous_tiled::gelu::HALF,
                    ("ugelu", DType::F32) => contiguous_tiled::gelu::FLOAT,
                    ("ugelu", DType::BF16) => contiguous_tiled::gelu::BFLOAT,
                    ("ulog", DType::F16) => contiguous_tiled::log::HALF,
                    ("ulog", DType::F32) => contiguous_tiled::log::FLOAT,
                    ("ulog", DType::BF16) => contiguous_tiled::log::BFLOAT,
                    ("uneg", DType::F16) => contiguous_tiled::neg::HALF,
                    ("uneg", DType::F32) => contiguous_tiled::neg::FLOAT,
                    ("uneg", DType::BF16) => contiguous_tiled::neg::BFLOAT,
                    ("urecip", DType::F16) => contiguous_tiled::recip::HALF,
                    ("urecip", DType::F32) => contiguous_tiled::recip::FLOAT,
                    ("urecip", DType::BF16) => contiguous_tiled::recip::BFLOAT,
                    ("urelu", DType::F16) => contiguous_tiled::relu::HALF,
                    ("urelu", DType::F32) => contiguous_tiled::relu::FLOAT,
                    ("urelu", DType::BF16) => contiguous_tiled::relu::BFLOAT,
                    ("uround", DType::F16) => contiguous_tiled::round::HALF,
                    ("uround", DType::F32) => contiguous_tiled::round::FLOAT,
                    ("uround", DType::BF16) => contiguous_tiled::round::BFLOAT,
                    ("usilu", DType::F16) => contiguous_tiled::silu::HALF,
                    ("usilu", DType::F32) => contiguous_tiled::silu::FLOAT,
                    ("usilu", DType::BF16) => contiguous_tiled::silu::BFLOAT,
                    ("usin", DType::F16) => contiguous_tiled::sin::HALF,
                    ("usin", DType::F32) => contiguous_tiled::sin::FLOAT,
                    ("usin", DType::BF16) => contiguous_tiled::sin::BFLOAT,
                    ("usqr", DType::F16) => contiguous_tiled::sqr::HALF,
                    ("usqr", DType::F32) => contiguous_tiled::sqr::FLOAT,
                    ("usqr", DType::BF16) => contiguous_tiled::sqr::BFLOAT,
                    ("usqrt", DType::F16) => contiguous_tiled::sqrt::HALF,
                    ("usqrt", DType::F32) => contiguous_tiled::sqrt::FLOAT,
                    ("usqrt", DType::BF16) => contiguous_tiled::sqrt::BFLOAT,
                    ("utanh", DType::F16) => contiguous_tiled::tanh::HALF,
                    ("utanh", DType::F32) => contiguous_tiled::tanh::FLOAT,
                    ("utanh", DType::BF16) => contiguous_tiled::tanh::BFLOAT,
                    ("usign", DType::F16) => contiguous_tiled::sign::HALF,
                    ("usign", DType::F32) => contiguous_tiled::sign::FLOAT,
                    ("usign", DType::BF16) => contiguous_tiled::sign::BFLOAT,
                    ("usign", DType::I64) => contiguous_tiled::sign::I64,
                    (name, dtype) => {
                        crate::bail!(
                            "Metal contiguous_tiled unary {name} {dtype:?} not implemented"
                        )
                    }
                };
                candle_metal_kernels::call_unary_contiguous_tiled(
                    &device.device,
                    &command_buffer,
                    &device.kernels,
                    kernel_name,
                    el_count,
                    src,
                    &buffer,
                )
                .map_err(MetalError::from)?;
            }
            (_, _, true) => {
                use candle_metal_kernels::unary::contiguous;
                let kernel_name = match (B::KERNEL, dtype) {
                    ("uabs", DType::F16) => contiguous::abs::HALF,
                    ("uabs", DType::F32) => contiguous::abs::FLOAT,
                    ("uabs", DType::BF16) => contiguous::abs::BFLOAT,
                    ("uceil", DType::F16) => contiguous::ceil::HALF,
                    ("uceil", DType::F32) => contiguous::ceil::FLOAT,
                    ("uceil", DType::BF16) => contiguous::ceil::BFLOAT,
                    ("ucos", DType::F16) => contiguous::cos::HALF,
                    ("ucos", DType::F32) => contiguous::cos::FLOAT,
                    ("ucos", DType::BF16) => contiguous::cos::BFLOAT,
                    ("uerf", DType::F16) => contiguous::erf::HALF,
                    ("uerf", DType::F32) => contiguous::erf::FLOAT,
                    ("uerf", DType::BF16) => contiguous::erf::BFLOAT,
                    ("uexp", DType::F16) => contiguous::exp::HALF,
                    ("uexp", DType::F32) => contiguous::exp::FLOAT,
                    ("uexp", DType::BF16) => contiguous::exp::BFLOAT,
                    ("ufloor", DType::F16) => contiguous::floor::HALF,
                    ("ufloor", DType::F32) => contiguous::floor::FLOAT,
                    ("ufloor", DType::BF16) => contiguous::floor::BFLOAT,
                    ("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
                    ("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
                    ("ugelu_erf", DType::BF16) => contiguous::gelu_erf::BFLOAT,
                    ("ugelu", DType::F16) => contiguous::gelu::HALF,
                    ("ugelu", DType::F32) => contiguous::gelu::FLOAT,
                    ("ugelu", DType::BF16) => contiguous::gelu::BFLOAT,
                    ("ulog", DType::F16) => contiguous::log::HALF,
                    ("ulog", DType::F32) => contiguous::log::FLOAT,
                    ("ulog", DType::BF16) => contiguous::log::BFLOAT,
                    ("uneg", DType::F16) => contiguous::neg::HALF,
                    ("uneg", DType::F32) => contiguous::neg::FLOAT,
                    ("uneg", DType::BF16) => contiguous::neg::BFLOAT,
                    ("urecip", DType::F16) => contiguous::recip::HALF,
                    ("urecip", DType::F32) => contiguous::recip::FLOAT,
                    ("urecip", DType::BF16) => contiguous::recip::BFLOAT,
                    ("urelu", DType::F16) => contiguous::relu::HALF,
                    ("urelu", DType::F32) => contiguous::relu::FLOAT,
                    ("urelu", DType::BF16) => contiguous::relu::BFLOAT,
                    ("uround", DType::F16) => contiguous::round::HALF,
                    ("uround", DType::F32) => contiguous::round::FLOAT,
                    ("uround", DType::BF16) => contiguous::round::BFLOAT,
                    ("usilu", DType::F16) => contiguous::silu::HALF,
                    ("usilu", DType::F32) => contiguous::silu::FLOAT,
                    ("usilu", DType::BF16) => contiguous::silu::BFLOAT,
                    ("usin", DType::F16) => contiguous::sin::HALF,
                    ("usin", DType::F32) => contiguous::sin::FLOAT,
                    ("usin", DType::BF16) => contiguous::sin::BFLOAT,
                    ("usqr", DType::F16) => contiguous::sqr::HALF,
                    ("usqr", DType::F32) => contiguous::sqr::FLOAT,
                    ("usqr", DType::BF16) => contiguous::sqr::BFLOAT,
                    ("usqrt", DType::F16) => contiguous::sqrt::HALF,
                    ("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
                    ("usqrt", DType::BF16) => contiguous::sqrt::BFLOAT,
                    ("utanh", DType::F16) => contiguous::tanh::HALF,
                    ("utanh", DType::F32) => contiguous::tanh::FLOAT,
                    ("utanh", DType::BF16) => contiguous::tanh::BFLOAT,
                    ("usign", DType::F16) => contiguous::sign::HALF,
                    ("usign", DType::F32) => contiguous::sign::FLOAT,
                    ("usign", DType::BF16) => contiguous::sign::BFLOAT,
                    ("usign", DType::I64) => contiguous::sign::I64,
                    (name, dtype) => {
                        crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented")
                    }
                };
                candle_metal_kernels::call_unary_contiguous(
                    &device.device,
                    &command_buffer,
                    &device.kernels,
                    kernel_name,
                    el_count,
                    src,
                    &buffer,
                )
                .map_err(MetalError::from)?;
            }
            (_, _, false) => {
                use candle_metal_kernels::unary::strided;
                let kernel_name = match (B::KERNEL, dtype) {
                    ("ucos", DType::F32) => strided::cos::FLOAT,
                    ("usin", DType::F32) => strided::sin::FLOAT,
                    ("usqr", DType::F32) => strided::sqr::FLOAT,
                    ("usqrt", DType::F32) => strided::sqrt::FLOAT,
                    ("uneg", DType::F32) => strided::neg::FLOAT,
                    ("uexp", DType::F32) => strided::exp::FLOAT,
                    ("ulog", DType::F32) => strided::log::FLOAT,
                    ("ugelu", DType::F32) => strided::gelu::FLOAT,
                    ("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
                    ("uerf", DType::F32) => strided::erf::FLOAT,
                    ("usilu", DType::F32) => strided::silu::FLOAT,
                    ("uabs", DType::F32) => strided::abs::FLOAT,
                    ("uceil", DType::F32) => strided::ceil::FLOAT,
                    ("ufloor", DType::F32) => strided::floor::FLOAT,
                    ("urelu", DType::F32) => strided::relu::FLOAT,
                    ("uround", DType::F32) => strided::round::FLOAT,
                    ("utanh", DType::F32) => strided::tanh::FLOAT,

("ucos", DType::F16) => strided::cos::HALF,
|
||||
("usin", DType::F16) => strided::sin::HALF,
|
||||
("usqr", DType::F16) => strided::sqr::HALF,
|
||||
("usqrt", DType::F16) => strided::sqrt::HALF,
|
||||
("uneg", DType::F16) => strided::neg::HALF,
|
||||
("uexp", DType::F16) => strided::exp::HALF,
|
||||
("ulog", DType::F16) => strided::log::HALF,
|
||||
("ugelu", DType::F16) => strided::gelu::HALF,
|
||||
("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
|
||||
("uerf", DType::F16) => strided::erf::HALF,
|
||||
("usilu", DType::F16) => strided::silu::HALF,
|
||||
("uabs", DType::F16) => strided::abs::HALF,
|
||||
("uceil", DType::F16) => strided::ceil::HALF,
|
||||
("ufloor", DType::F16) => strided::floor::HALF,
|
||||
("urelu", DType::F16) => strided::relu::HALF,
|
||||
("uround", DType::F16) => strided::round::HALF,
|
||||
("utanh", DType::F16) => strided::tanh::HALF,
|
||||
("ucos", DType::F16) => strided::cos::HALF,
|
||||
("usin", DType::F16) => strided::sin::HALF,
|
||||
("usqr", DType::F16) => strided::sqr::HALF,
|
||||
("usqrt", DType::F16) => strided::sqrt::HALF,
|
||||
("uneg", DType::F16) => strided::neg::HALF,
|
||||
("uexp", DType::F16) => strided::exp::HALF,
|
||||
("ulog", DType::F16) => strided::log::HALF,
|
||||
("ugelu", DType::F16) => strided::gelu::HALF,
|
||||
("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
|
||||
("uerf", DType::F16) => strided::erf::HALF,
|
||||
("usilu", DType::F16) => strided::silu::HALF,
|
||||
("uabs", DType::F16) => strided::abs::HALF,
|
||||
("uceil", DType::F16) => strided::ceil::HALF,
|
||||
("ufloor", DType::F16) => strided::floor::HALF,
|
||||
("urelu", DType::F16) => strided::relu::HALF,
|
||||
("uround", DType::F16) => strided::round::HALF,
|
||||
("utanh", DType::F16) => strided::tanh::HALF,
|
||||
|
||||
("ucos", DType::BF16) => strided::cos::BFLOAT,
|
||||
("usin", DType::BF16) => strided::sin::BFLOAT,
|
||||
("usqr", DType::BF16) => strided::sqr::BFLOAT,
|
||||
("usqrt", DType::BF16) => strided::sqrt::BFLOAT,
|
||||
("uneg", DType::BF16) => strided::neg::BFLOAT,
|
||||
("uexp", DType::BF16) => strided::exp::BFLOAT,
|
||||
("ulog", DType::BF16) => strided::log::BFLOAT,
|
||||
("ugelu", DType::BF16) => strided::gelu::BFLOAT,
|
||||
("ugelu_erf", DType::BF16) => strided::gelu_erf::BFLOAT,
|
||||
("uerf", DType::BF16) => strided::erf::BFLOAT,
|
||||
("usilu", DType::BF16) => strided::silu::BFLOAT,
|
||||
("uabs", DType::BF16) => strided::abs::BFLOAT,
|
||||
("uceil", DType::BF16) => strided::ceil::BFLOAT,
|
||||
("ufloor", DType::BF16) => strided::floor::BFLOAT,
|
||||
("urelu", DType::BF16) => strided::relu::BFLOAT,
|
||||
("uround", DType::BF16) => strided::round::BFLOAT,
|
||||
("utanh", DType::BF16) => strided::tanh::BFLOAT,
|
||||
("ucos", DType::BF16) => strided::cos::BFLOAT,
|
||||
("usin", DType::BF16) => strided::sin::BFLOAT,
|
||||
("usqr", DType::BF16) => strided::sqr::BFLOAT,
|
||||
("usqrt", DType::BF16) => strided::sqrt::BFLOAT,
|
||||
("uneg", DType::BF16) => strided::neg::BFLOAT,
|
||||
("uexp", DType::BF16) => strided::exp::BFLOAT,
|
||||
("ulog", DType::BF16) => strided::log::BFLOAT,
|
||||
("ugelu", DType::BF16) => strided::gelu::BFLOAT,
|
||||
("ugelu_erf", DType::BF16) => strided::gelu_erf::BFLOAT,
|
||||
("uerf", DType::BF16) => strided::erf::BFLOAT,
|
||||
("usilu", DType::BF16) => strided::silu::BFLOAT,
|
||||
("uabs", DType::BF16) => strided::abs::BFLOAT,
|
||||
("uceil", DType::BF16) => strided::ceil::BFLOAT,
|
||||
("ufloor", DType::BF16) => strided::floor::BFLOAT,
|
||||
("urelu", DType::BF16) => strided::relu::BFLOAT,
|
||||
("uround", DType::BF16) => strided::round::BFLOAT,
|
||||
("utanh", DType::BF16) => strided::tanh::BFLOAT,
|
||||
|
||||
                    (name, dtype) => {
                        crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
                    }
                };
                let dst = BufferOffset::zero_offset(&buffer);
                candle_metal_kernels::call_unary_strided(
                    &device.device,
                    &command_buffer,
                    &device.kernels,
                    kernel_name,
                    layout.dims(),
                    src,
                    layout.stride(),
                    dst,
                )
                .map_err(MetalError::from)?;
            }
        }

        Ok(Self::new(buffer, device.clone(), el_count, dtype))
    }

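Taken together, the rewrite replaces the old two-way contiguous/strided split with a three-way dispatch keyed on element-count parity, dtype, and contiguity. A compact sketch of just the dispatch predicate, with illustrative stand-in types (the even-count requirement presumably reflects the tiled kernels processing values in pairs):

// Illustrative sketch of the kernel-dispatch rule above; the enums here are
// stand-ins, not candle's actual types.
#[derive(Debug, Clone, Copy, PartialEq)]
enum DType { F16, BF16, F32 }

#[derive(Debug, PartialEq)]
enum Dispatch { ContiguousTiled, Contiguous, Strided }

fn pick_kernel(el_count: usize, dtype: DType, contiguous: bool) -> Dispatch {
    match (el_count % 2, dtype, contiguous) {
        // Tiled half-precision kernels need an even element count.
        (0, DType::BF16 | DType::F16, true) => Dispatch::ContiguousTiled,
        (_, _, true) => Dispatch::Contiguous,
        (_, _, false) => Dispatch::Strided,
    }
}

fn main() {
    assert_eq!(pick_kernel(1024, DType::F16, true), Dispatch::ContiguousTiled);
    assert_eq!(pick_kernel(1023, DType::F16, true), Dispatch::Contiguous);
    assert_eq!(pick_kernel(1024, DType::F32, true), Dispatch::Contiguous);
    assert_eq!(pick_kernel(1024, DType::F16, false), Dispatch::Strided);
}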
@@ -630,6 +720,7 @@ impl BackendStorage for MetalStorage {
        }
        let name = match (self.dtype, t.dtype()) {
            (DType::U8, DType::F32) => "where_u8_f32",
            (DType::U32, DType::F32) => "where_u32_f32",
            (DType::U8, DType::BF16) => "where_u8_bf16",
            (DType::U8, DType::F16) => "where_u8_f16",
            (DType::U8, DType::I64) => "where_u8_i64",
@@ -736,44 +827,107 @@ impl BackendStorage for MetalStorage {
        k_layout: &Layout,
        params: &ParamsConvTranspose1D,
    ) -> Result<Self> {
        const USE_COL2IM_CONV1D_TR: bool = true;

        let can_use_col2im = k_layout.is_contiguous()
            && params.dilation == 1
            && params.padding == 0
            && params.output_padding == 0;
        let l_out = params.l_out();
        let dst_el = params.c_out * l_out * params.b_size;
        let buffer = self
            .device
            .new_buffer(dst_el, self.dtype, "conv_transpose1d")?;

        let command_buffer = self.device.command_buffer()?;
        let name = match self.dtype {
            DType::F32 => "conv_transpose1d_f32",
            DType::F16 => "conv_transpose1d_f16",
            DType::BF16 => "conv_transpose1d_bf16",
            DType::U32 => "conv_transpose1d_u32",
            DType::U8 => "conv_transpose1d_u8",
            dtype => crate::bail!("Metal conv_transpose1d {dtype:?} not implemented"),
        let buffer = if USE_COL2IM_CONV1D_TR && can_use_col2im {
            let (b_size, c_in, l_in) = layout.shape().dims3()?;
            let (c_in2, c_out, k_size) = k_layout.shape().dims3()?;
            if c_in != c_in2 {
                crate::bail!(
                    "convtr1d: shape mismatch on c_in {:?} {:?}",
                    layout.shape(),
                    k_layout.shape()
                )
            }
            let buffer = self
                .device
                .new_buffer(dst_el, self.dtype, "conv_transpose1d")?;

            let name = match self.dtype {
                DType::F32 => "col2im1d_f32",
                DType::U32 => "col2im1d_u32",
                DType::U8 => "col2im1d_u8",
                dtype => crate::bail!("metal col2im1d {dtype:?} not implemented"),
            };
            let col = {
                // This merges the last two dimensions of the kernel.
                let kernel_l_mm = Layout::new(
                    (b_size, c_in, k_size * c_out).into(),
                    vec![0, k_size * c_out, 1],
                    k_layout.start_offset(),
                );
                self.matmul(
                    k,
                    (b_size, l_in, c_out * k_size, c_in),
                    &layout.transpose(1, 2)?,
                    &kernel_l_mm,
                )?
            };
            // It is important for the command buffer to be obtained *after* the matmul
            // kernel has run, otherwise we might use a command buffer that has already
            // been committed, resulting in the following error.
            // _status < MTLCommandBufferStatusCommitted >
            // -[IOGPUMetalCommandBuffer setCurrentCommandEncoder:]
            let command_buffer = self.device.command_buffer()?;
            candle_metal_kernels::call_col2im1d(
                &self.device.device,
                &command_buffer,
                &self.device.kernels,
                name,
                &[b_size, l_in, c_out, k_size],
                params.k_size,
                params.stride,
                BufferOffset::zero_offset(&col.buffer),
                &buffer,
            )
            .map_err(MetalError::from)?;
            buffer
        } else {
            let buffer = self
                .device
                .new_buffer(dst_el, self.dtype, "conv_transpose1d")?;

            let command_buffer = self.device.command_buffer()?;
            let name = match self.dtype {
                DType::F32 => "conv_transpose1d_f32",
                DType::F16 => "conv_transpose1d_f16",
                DType::BF16 => "conv_transpose1d_bf16",
                DType::U32 => "conv_transpose1d_u32",
                DType::U8 => "conv_transpose1d_u8",
                dtype => crate::bail!("Metal conv_transpose1d {dtype:?} not implemented"),
            };
            candle_metal_kernels::call_conv_transpose1d(
                &self.device.device,
                &command_buffer,
                &self.device.kernels,
                name,
                params.dilation,
                params.stride,
                params.padding,
                params.output_padding,
                params.c_out,
                l_out,
                params.b_size,
                layout.dims(),
                layout.stride(),
                k_layout.dims(),
                k_layout.stride(),
                &self.buffer,
                layout.start_offset() * self.dtype.size_in_bytes(),
                &k.buffer,
                k_layout.start_offset() * k.dtype.size_in_bytes(),
                &buffer,
            )
            .map_err(MetalError::from)?;
            buffer
        };
        candle_metal_kernels::call_conv_transpose1d(
            &self.device.device,
            &command_buffer,
            &self.device.kernels,
            name,
            params.dilation,
            params.stride,
            params.padding,
            params.output_padding,
            params.c_out,
            l_out,
            params.b_size,
            layout.dims(),
            layout.stride(),
            k_layout.dims(),
            k_layout.stride(),
            &self.buffer,
            layout.start_offset() * self.dtype.size_in_bytes(),
            &k.buffer,
            k_layout.start_offset() * k.dtype.size_in_bytes(),
            &buffer,
        )
        .map_err(MetalError::from)?;
        Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype))
    }

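For the sizing above, `params.l_out()` should follow the standard transposed-convolution output length; a hedged restatement and sanity check (the helper below is illustrative, not candle's actual implementation):

// Illustrative helper mirroring ParamsConvTranspose1D::l_out; the parameter
// names follow the fields used above.
fn l_out(l_in: usize, stride: usize, padding: usize, output_padding: usize,
         dilation: usize, k_size: usize) -> usize {
    (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + output_padding + 1
}

fn main() {
    // With stride 1, no padding and no dilation, the output grows by k - 1.
    assert_eq!(l_out(10, 1, 0, 0, 1, 3), 12);
    // The col2im fast path above requires dilation == 1 and zero paddings.
    assert_eq!(l_out(10, 2, 0, 0, 1, 3), 21);
}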
@@ -1255,6 +1409,7 @@ impl BackendStorage for MetalStorage {
        let name = match self.dtype {
            DType::F32 => "sgemm",
            DType::F16 => "hgemm",
            DType::BF16 => "bgemm",
            dtype => {
                return Err(MetalError::Message(format!("matmul doesn't support {dtype:?}")).into())
            }
@@ -1702,6 +1857,19 @@ impl BackendDevice for MetalDevice {
        self.storage_from_cpu_storage(&cpu_storage)
    }

    fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
        let (count, buffer) = match T::cpu_storage_ref(s) {
            CpuStorageRef::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),
            CpuStorageRef::U32(storage) => (storage.len(), self.new_buffer_with_data(storage)),
            CpuStorageRef::I64(storage) => (storage.len(), self.new_buffer_with_data(storage)),
            CpuStorageRef::BF16(storage) => (storage.len(), self.new_buffer_with_data(storage)),
            CpuStorageRef::F16(storage) => (storage.len(), self.new_buffer_with_data(storage)),
            CpuStorageRef::F32(storage) => (storage.len(), self.new_buffer_with_data(storage)),
            CpuStorageRef::F64(storage) => (storage.len(), self.new_buffer_with_data(storage)),
        };
        Ok(Self::Storage::new(buffer?, self.clone(), count, T::DTYPE))
    }

    fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
        let (count, buffer) = match storage {
            CpuStorage::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),
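A short usage sketch of the slice-based construction path added here; device selection is illustrative and falls back to CPU when no accelerator is compiled in:

// Illustrative usage: `Tensor::from_slice` now routes through
// `storage_from_slice`, copying the host slice straight into a device buffer.
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::cuda_if_available(0)?;
    let t = Tensor::from_slice(&[1f32, 2., 3., 4., 5., 6.], (2, 3), &device)?;
    assert_eq!(t.dims(), [2, 3]);
    Ok(())
}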
@@ -330,7 +330,7 @@ impl Tensor {
        path: P,
    ) -> Result<()> {
        let mut zip = zip::ZipWriter::new(File::create(path.as_ref())?);
        let options =
        let options: zip::write::FileOptions<()> =
            zip::write::FileOptions::default().compression_method(zip::CompressionMethod::Stored);

        for (name, tensor) in ts.iter() {
@@ -2,6 +2,7 @@ use super::{GgmlDType, QStorage};
use crate::quantized::k_quants::GgmlType;
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
use crate::{CudaDevice, CudaStorage, Result};
use half::f16;

use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};

@@ -59,7 +60,7 @@ fn quantize_q8_1(
    Ok(())
}

fn dequantize(
fn dequantize_f32(
    data: &CudaSlice<u8>,
    dtype: GgmlDType,
    elem_count: usize,
@@ -69,27 +70,27 @@ fn dequantize(

    let nb = (elem_count + 255) / 256;
    let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
        GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32, nb),
        GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32, nb),
        GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb),
        GgmlDType::Q4_1 => ("dequantize_block_q4_1_f32", false, 32, nb),
        GgmlDType::Q5_0 => (
            "dequantize_block_q5_0",
            "dequantize_block_q5_0_f32",
            false,
            CUDA_DEQUANTIZE_BLOCK_SIZE,
            ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
        ),
        GgmlDType::Q5_1 => (
            "dequantize_block_q5_1",
            "dequantize_block_q5_1_f32",
            false,
            CUDA_DEQUANTIZE_BLOCK_SIZE,
            ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
        ),
        GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32, nb),
        GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64, nb),
        GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64, nb),
        GgmlDType::Q4K => ("dequantize_block_q4_K", true, 32, nb),
        GgmlDType::Q5K => ("dequantize_block_q5_K", true, 64, nb),
        GgmlDType::Q6K => ("dequantize_block_q6_K", true, 64, nb),
        GgmlDType::Q8K => ("dequantize_block_q8_K", true, 32, nb),
        GgmlDType::Q8_0 => ("dequantize_block_q8_0_f32", false, 32, nb),
        GgmlDType::Q2K => ("dequantize_block_q2_K_f32", true, 64, nb),
        GgmlDType::Q3K => ("dequantize_block_q3_K_f32", true, 64, nb),
        GgmlDType::Q4K => ("dequantize_block_q4_K_f32", true, 32, nb),
        GgmlDType::Q5K => ("dequantize_block_q5_K_f32", true, 64, nb),
        GgmlDType::Q6K => ("dequantize_block_q6_K_f32", true, 64, nb),
        GgmlDType::Q8K => ("dequantize_block_q8_K_f32", true, 32, nb),
        _ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
    };
    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
@@ -116,6 +117,63 @@ fn dequantize(
    Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}

fn dequantize_f16(
    data: &CudaSlice<u8>,
    dtype: GgmlDType,
    elem_count: usize,
    dev: &CudaDevice,
) -> Result<CudaStorage> {
    use cudarc::driver::LaunchAsync;

    let nb = (elem_count + 255) / 256;
    let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
        GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb),
        GgmlDType::Q4_1 => ("dequantize_block_q4_1_f16", false, 32, nb),
        GgmlDType::Q5_0 => (
            "dequantize_block_q5_0_f16",
            false,
            CUDA_DEQUANTIZE_BLOCK_SIZE,
            ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
        ),
        GgmlDType::Q5_1 => (
            "dequantize_block_q5_1_f16",
            false,
            CUDA_DEQUANTIZE_BLOCK_SIZE,
            ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
        ),
        GgmlDType::Q8_0 => ("dequantize_block_q8_0_f16", false, 32, nb),
        GgmlDType::Q2K => ("dequantize_block_q2_K_f16", true, 64, nb),
        GgmlDType::Q3K => ("dequantize_block_q3_K_f16", true, 64, nb),
        GgmlDType::Q4K => ("dequantize_block_q4_K_f16", true, 32, nb),
        GgmlDType::Q5K => ("dequantize_block_q5_K_f16", true, 64, nb),
        GgmlDType::Q6K => ("dequantize_block_q6_K_f16", true, 64, nb),
        GgmlDType::Q8K => ("dequantize_block_q8_K_f16", true, 32, nb),
        _ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
    };
    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
    let dst = unsafe { dev.alloc::<f16>(elem_count).w()? };
    // See e.g.
    // https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
    let cfg = cudarc::driver::LaunchConfig {
        grid_dim: (num_blocks as u32, 1, 1),
        block_dim: (block_dim as u32, 1, 1),
        shared_mem_bytes: 0,
    };

    if is_k {
        let params = (data, &dst);
        unsafe { func.launch(cfg, params) }.w()?;
    } else {
        let nb32 = match dtype {
            GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
            _ => elem_count / 32,
        };
        let params = (data, &dst, nb32 as i32);
        unsafe { func.launch(cfg, params) }.w()?;
    }
    Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}

fn dequantize_mul_mat_vec(
    data: &CudaSlice<u8>,
    y: &CudaView<f32>,
@@ -178,8 +236,8 @@ fn mul_mat_vec_via_q8_1(
    if y.len() != ncols * b_size {
        crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
    }
    if b_size == 0 || b_size > 4 {
        crate::bail!("only bsize between 1 and 4 are supported, got {b_size}")
    if b_size == 0 || b_size > 8 {
        crate::bail!("only bsize between 1 and 8 are supported, got {b_size}")
    }
    // Start by quantizing y
    let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
@@ -204,14 +262,16 @@ fn mul_mat_vec_via_q8_1(
    let kernel_name = format!("{kernel_name}{b_size}");
    let func = dev.get_or_load_func(&kernel_name, candle_kernels::QUANTIZED)?;
    let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
    let nblocks = if b_size == 1 {
        nrows as u32
    } else {
        (nrows as u32 + 1) / 2
    // https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
    let (nblocks, nwarps) = match b_size {
        1 => (nrows as u32, 4),
        2..=4 => ((nrows as u32 + 1) / 2, 4),
        5..=8 => ((nrows as u32 + 1) / 2, 2),
        _ => crate::bail!("unexpected bsize {b_size}"),
    };
    let cfg = cudarc::driver::LaunchConfig {
        grid_dim: (nblocks, 1, 1),
        block_dim: (WARP_SIZE as u32, 4, 1),
        block_dim: (WARP_SIZE as u32, nwarps, 1),
        shared_mem_bytes: 0,
    };

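A standalone restatement of the new launch-geometry rule, assuming the same batch-size buckets as above (`div_ceil` mirrors the `(n + 1) / 2` rounding):

// Illustrative re-statement of the (nblocks, nwarps) choice above.
fn launch_geometry(nrows: u32, b_size: usize) -> Option<(u32, u32)> {
    match b_size {
        1 => Some((nrows, 4)),                 // one row per block, 4 warps
        2..=4 => Some((nrows.div_ceil(2), 4)), // two rows per block
        5..=8 => Some((nrows.div_ceil(2), 2)), // two rows, fewer warps
        _ => None,
    }
}

fn main() {
    assert_eq!(launch_geometry(7, 1), Some((7, 4)));
    assert_eq!(launch_geometry(7, 4), Some((4, 4)));
    assert_eq!(launch_geometry(7, 8), Some((4, 2)));
    assert_eq!(launch_geometry(7, 9), None);
}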
@@ -339,7 +399,7 @@ impl QCudaStorage {
            | GgmlDType::Q8K
        );
        if fast_kernel {
            return dequantize(&self.data, self.dtype, elem_count, self.device());
            return dequantize_f32(&self.data, self.dtype, elem_count, self.device());
        }
        // Run the dequantization on cpu.

@@ -367,6 +427,10 @@ impl QCudaStorage {
            .storage_from_cpu_storage(&crate::CpuStorage::F32(out))
    }

    pub fn dequantize_f16(&self, elem_count: usize) -> Result<CudaStorage> {
        dequantize_f16(&self.data, self.dtype, elem_count, self.device())
    }

    pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
        // Run the quantization on cpu.
        let src = match &src.slice {
@@ -398,7 +462,7 @@ impl QCudaStorage {
    let max_bm = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
        1
    } else {
        4
        8
    };
    let use_vec_kernel = match layout.shape().dims() {
        [b, m, _k] => b * m <= max_bm,

@@ -24,6 +24,10 @@ impl QCudaStorage {
        Err(Error::NotCompiledWithCudaSupport)
    }

    pub fn dequantize_f16(&self, _elem_count: usize) -> Result<CudaStorage> {
        Err(Error::NotCompiledWithCudaSupport)
    }

    pub fn quantize(&mut self, _src: &CudaStorage) -> Result<()> {
        Err(Error::NotCompiledWithCudaSupport)
    }
@@ -135,7 +135,6 @@ pub enum ValueType {
    // The value is a UTF-8 non-null-terminated string, with length prepended.
    String,
    // The value is an array of other values, with the length and type prepended.
    ///
    // Arrays can be nested, and the length of the array is the number of elements in the array, not the number of bytes.
    Array,
}
@@ -218,10 +217,16 @@ impl Value {
        }
    }

    /// This will also automatically upcast any integral type that will not truncate.
    pub fn to_u64(&self) -> Result<u64> {
        match self {
            Self::U64(v) => Ok(*v),
            v => crate::bail!("not a u64 {v:?}"),
            // Autoupcast cases here
            Self::U8(v) => Ok(*v as u64),
            Self::U16(v) => Ok(*v as u64),
            Self::U32(v) => Ok(*v as u64),
            Self::Bool(v) => Ok(*v as u64),
            v => crate::bail!("not a u64 or upcastable to u64 {v:?}"),
        }
    }

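A quick illustration of the widened `to_u64`, assuming the public `gguf_file::Value` variants shown above:

// Illustrative usage of the upcasting `to_u64`.
use candle_core::quantized::gguf_file::Value;

fn main() -> candle_core::Result<()> {
    // Narrow integral variants now upcast losslessly instead of erroring.
    assert_eq!(Value::U8(7).to_u64()?, 7);
    assert_eq!(Value::U32(42).to_u64()?, 42);
    assert_eq!(Value::U64(1 << 40).to_u64()?, 1 << 40);
    // Non-integral variants still fail.
    assert!(Value::F32(1.5).to_u64().is_err());
    Ok(())
}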
@@ -152,9 +152,9 @@ impl QMetalStorage {
        // We always use a single batch dimension and stack all the tensors in the batch on the
        // second dimension as the implementation in candle-metal-kernels doesn't handle batch
        // properly.
        let (b, m) = match dst_shape.len() {
            3 => (1, dst_shape[0] * dst_shape[1]),
            2 => (1, dst_shape[0]),
        let m = match dst_shape.len() {
            3 => dst_shape[0] * dst_shape[1],
            2 => dst_shape[0],
            n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
        };
        let last_k = dst_shape.pop().unwrap();
@@ -166,18 +166,23 @@ impl QMetalStorage {
        let device = storage.device().clone();
        let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
        let command_buffer = device.command_buffer()?;
        candle_metal_kernels::call_quantized_matmul_t(
            device.device(),
            &command_buffer,
            device.kernels(),
            self.dtype.into(),
            (b, m, n, k),
            storage.buffer(),
            layout.start_offset() * storage.dtype().size_in_bytes(),
            &self.buffer,
            &dst,
        )
        .map_err(MetalError::from)?;
        // In some cases it would be better to use the mm variant, though it has its drawbacks
        // around memory alignment.
        for batch_id in 0..m {
            candle_metal_kernels::call_quantized_matmul_mv_t(
                device.device(),
                &command_buffer,
                device.kernels(),
                self.dtype.into(),
                (1, 1, n, k),
                storage.buffer(),
                (layout.start_offset() + batch_id * k) * storage.dtype().size_in_bytes(),
                &self.buffer,
                batch_id * n * DType::F32.size_in_bytes(),
                &dst,
            )
            .map_err(MetalError::from)?;
        }
        let dst_storage = crate::MetalStorage::new(dst, device, dst_shape.elem_count(), DType::F32);
        Ok((dst_storage, dst_shape))
    }
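Each iteration of the loop above advances the activation pointer by `k` elements and the output pointer by `n` elements, converted to bytes. A small sketch of that offset arithmetic with illustrative names and an f32-only simplification:

// Illustrative byte offsets fed to the mat-vec kernel for each stacked row.
fn offsets(batch_id: usize, k: usize, n: usize, start: usize) -> (usize, usize) {
    let f32_sz = std::mem::size_of::<f32>();
    let src_offset = (start + batch_id * k) * f32_sz; // input row (activations)
    let dst_offset = batch_id * n * f32_sz; // output row (f32 results)
    (src_offset, dst_offset)
}

fn main() {
    // Row 2 of an (m, k) = (4, 8) by (k, n) = (8, 3) product.
    assert_eq!(offsets(2, 8, 3, 0), (64, 24));
}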
@@ -1,4 +1,4 @@
use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor};
use crate::{CpuStorage, DType, Device, Result, Shape, Storage, Tensor};
use k_quants::*;
use std::borrow::Cow;

@@ -360,9 +360,24 @@ impl QTensor {
    pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
        let storage = self.storage.dequantize(self.shape.elem_count())?;
        let none = crate::op::BackpropOp::none();
        let is_variable = false;
        crate::tensor::from_storage(storage, self.shape.clone(), none, is_variable)
            .to_device(device)
        crate::tensor::from_storage(storage, self.shape.clone(), none, false).to_device(device)
    }

    pub fn dequantize_f16(&self, device: &Device) -> Result<Tensor> {
        // In the CUDA case, we have a specialized kernel as this can be useful for volta
        // architectures. https://github.com/huggingface/candle/issues/2136
        match &self.storage {
            QStorage::Cuda(s) => {
                let s = s.dequantize_f16(self.shape.elem_count())?;
                let none = crate::op::BackpropOp::none();
                crate::tensor::from_storage(Storage::Cuda(s), self.shape.clone(), none, false)
                    .to_device(device)
            }
            _ => {
                let s = self.dequantize(device)?.to_dtype(crate::DType::F16)?;
                Ok(s)
            }
        }
    }

    pub fn storage_size_in_bytes(&self) -> usize {
@@ -378,6 +393,7 @@ impl QTensor {
pub enum QMatMul {
    QTensor(std::sync::Arc<QTensor>),
    Tensor(Tensor),
    TensorF16(Tensor),
}

thread_local! {
@@ -391,6 +407,17 @@ thread_local! {
    }
}

thread_local! {
    static DEQUANTIZE_ALL_F16: bool = {
        match std::env::var("CANDLE_DEQUANTIZE_ALL_F16") {
            Ok(s) => {
                !s.is_empty() && s != "0"
            },
            Err(_) => false,
        }
    }
}

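Because the flag is read from the environment, the f16 path can be enabled without code changes. A hedged sketch of the lookup (the shell invocation in the comment is illustrative):

// Run with the opt-in flag, e.g.:
//   CANDLE_DEQUANTIZE_ALL_F16=1 cargo run --release --example quantized
fn dequantize_all_f16() -> bool {
    // Mirrors the thread_local logic above: any non-empty value other than
    // "0" enables the f16 dequantize path.
    match std::env::var("CANDLE_DEQUANTIZE_ALL_F16") {
        Ok(s) => !s.is_empty() && s != "0",
        Err(_) => false,
    }
}

fn main() {
    println!("dequantize all to f16: {}", dequantize_all_f16());
}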
impl QMatMul {
    pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
        let dequantize = match qtensor.dtype() {
@@ -400,6 +427,9 @@ impl QMatMul {
        let t = if dequantize {
            let tensor = qtensor.dequantize(&qtensor.device())?;
            Self::Tensor(tensor)
        } else if DEQUANTIZE_ALL_F16.with(|b| *b) {
            let tensor = qtensor.dequantize_f16(&qtensor.device())?;
            Self::TensorF16(tensor)
        } else {
            Self::QTensor(qtensor)
        };
@@ -409,6 +439,25 @@ impl QMatMul {
    pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
        Self::from_arc(std::sync::Arc::new(qtensor))
    }

    pub fn dequantize_f16(&self) -> Result<Tensor> {
        match self {
            Self::QTensor(t) => t.dequantize_f16(&t.device()),
            Self::Tensor(t) => t.to_dtype(DType::F16),
            Self::TensorF16(t) => Ok(t.clone()),
        }
    }

    pub fn forward_via_f16(&self, xs: &Tensor) -> Result<Tensor> {
        let w = self.dequantize_f16()?;
        let in_dtype = xs.dtype();
        let w = match *xs.dims() {
            [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
            [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
            _ => w.t()?,
        };
        xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
    }
}

impl crate::CustomOp1 for QTensor {
@@ -486,6 +535,15 @@ impl crate::Module for QMatMul {
            };
            xs.matmul(&w)
        }
        Self::TensorF16(w) => {
            let in_dtype = xs.dtype();
            let w = match *xs.dims() {
                [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
                [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
                _ => w.t()?,
            };
            xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
        }
    }
}
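A hedged usage sketch of the f16 forward path; shapes, dtype, and the Q4_0 choice are illustrative:

// Illustrative usage of QMatMul::forward_via_f16 on CPU.
use candle_core::quantized::{GgmlDType, QMatMul, QTensor};
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let w = Tensor::randn(0f32, 1., (64, 64), &dev)?;
    let q = QTensor::quantize(&w, GgmlDType::Q4_0)?;
    let mm = QMatMul::from_qtensor(q)?;
    let xs = Tensor::randn(0f32, 1., (2, 8, 64), &dev)?;
    // Dequantizes the weights to f16, runs the matmul in f16, and casts the
    // result back to the input dtype.
    let ys = mm.forward_via_f16(&xs)?;
    assert_eq!(ys.dims(), [2, 8, 64]);
    Ok(())
}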
@@ -349,6 +349,30 @@ impl MmapedSafetensors {
    }
}

pub struct SliceSafetensors<'a> {
    safetensors: SafeTensors<'a>,
}

impl<'a> SliceSafetensors<'a> {
    /// Creates a wrapper around a binary buffer and deserializes the safetensors header.
    pub fn new(buffer: &'a [u8]) -> Result<Self> {
        let safetensors = safetensors::SafeTensors::deserialize(buffer)?;
        Ok(Self { safetensors })
    }

    pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
        self.safetensors.tensor(name)?.load(dev)
    }

    pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
        self.safetensors.tensors()
    }

    pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
        Ok(self.safetensors.tensor(name)?)
    }
}

pub struct BufferedSafetensors {
    safetensors: yoke::Yoke<SafeTensors_<'static>, Vec<u8>>,
}
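A hedged round-trip sketch for `SliceSafetensors`; it assumes the `safetensors` crate's `serialize` helper and candle's `View` impl for `Tensor`:

// Illustrative round trip: serialize a tensor into an in-memory safetensors
// buffer, then read it back through the new zero-copy wrapper.
use candle_core::safetensors::SliceSafetensors;
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let t = Tensor::arange(0f32, 6., &dev)?.reshape((2, 3))?;
    // candle tensors implement `safetensors::View`, so the safetensors crate
    // can serialize them directly.
    let bytes = safetensors::serialize([("w", &t)], &None)?;
    let st = SliceSafetensors::new(&bytes)?;
    let roundtrip = st.load("w", &dev)?;
    assert_eq!(roundtrip.dims(), [2, 3]);
    Ok(())
}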
candle-core/src/sort.rs (new file, 239 lines)
@@ -0,0 +1,239 @@
use crate::{Result, Tensor};
use rayon::prelude::*;

#[derive(Debug, Clone, Copy)]
struct ArgSort {
    asc: bool,
    last_dim: usize,
}

impl ArgSort {
    fn asort<T: crate::WithDType>(&self, vs: &[T], layout: &crate::Layout) -> Vec<u32> {
        #[allow(clippy::uninit_vec)]
        // Safety: indexes are set later in the parallelized section.
        let mut sort_indexes = unsafe {
            let el_count = layout.shape().elem_count();
            let mut v = Vec::with_capacity(el_count);
            v.set_len(el_count);
            v
        };
        if self.asc {
            sort_indexes
                .par_chunks_exact_mut(self.last_dim)
                .zip(vs.par_chunks_exact(self.last_dim))
                .for_each(|(indexes, vs)| {
                    indexes
                        .iter_mut()
                        .enumerate()
                        .for_each(|(i, v)| *v = i as u32);
                    indexes.sort_by(|&i, &j| {
                        vs[i as usize]
                            .partial_cmp(&vs[j as usize])
                            .unwrap_or(std::cmp::Ordering::Greater)
                    })
                });
        } else {
            sort_indexes
                .par_chunks_exact_mut(self.last_dim)
                .zip(vs.par_chunks_exact(self.last_dim))
                .for_each(|(indexes, vs)| {
                    indexes
                        .iter_mut()
                        .enumerate()
                        .for_each(|(i, v)| *v = i as u32);
                    indexes.sort_by(|&j, &i| {
                        vs[i as usize]
                            .partial_cmp(&vs[j as usize])
                            .unwrap_or(std::cmp::Ordering::Greater)
                    })
                });
        }
        sort_indexes
    }
}

impl crate::CustomOp1 for ArgSort {
    fn name(&self) -> &'static str {
        "argsort"
    }

    fn cpu_fwd(
        &self,
        storage: &crate::CpuStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::CpuStorage, crate::Shape)> {
        let sort_indexes = match storage {
            crate::CpuStorage::U8(vs) => self.asort(vs, layout),
            crate::CpuStorage::U32(vs) => self.asort(vs, layout),
            crate::CpuStorage::I64(vs) => self.asort(vs, layout),
            crate::CpuStorage::BF16(vs) => self.asort(vs, layout),
            crate::CpuStorage::F16(vs) => self.asort(vs, layout),
            crate::CpuStorage::F32(vs) => self.asort(vs, layout),
            crate::CpuStorage::F64(vs) => self.asort(vs, layout),
        };
        let sort_indexes = crate::CpuStorage::U32(sort_indexes);
        Ok((sort_indexes, layout.shape().into()))
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &crate::CudaStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::CudaStorage, crate::Shape)> {
        use crate::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
        };
        use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, Map1Any, WrapErr};
        use crate::{CudaDevice, WithDType};

        impl Map1Any for ArgSort {
            fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
                &self,
                src: &CudaSlice<T>,
                dev: &CudaDevice,
                layout: &crate::Layout,
                _wrap: W,
            ) -> Result<S> {
                let slice = match layout.contiguous_offsets() {
                    None => crate::bail!("input has to be contiguous"),
                    Some((o1, o2)) => src.slice(o1..o2),
                };
                let elem_count = layout.shape().elem_count();
                let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
                let func = if self.asc {
                    dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
                } else {
                    dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
                };
                let ncols = self.last_dim;
                let nrows = elem_count / ncols;
                let ncols_pad = next_power_of_2(ncols);
                let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
                let cfg = LaunchConfig {
                    grid_dim: (1, nrows as u32, 1),
                    block_dim: (ncols_pad as u32, 1, 1),
                    shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
                };
                unsafe { func.launch(cfg, params) }.w()?;
                Ok(S::U32(dst))
            }
        }

        use crate::backend::BackendStorage;
        let dev = storage.device();
        let slice = self.map(&storage.slice, dev, layout)?;
        let dst = crate::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, layout.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        storage: &crate::MetalStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::MetalStorage, crate::Shape)> {
        use crate::backend::BackendStorage;
        use crate::DType;

        let name = {
            if self.asc {
                match storage.dtype() {
                    DType::BF16 => "asort_asc_bf16",
                    DType::F16 => "asort_asc_f16",
                    DType::F32 => "asort_asc_f32",
                    DType::F64 => "asort_asc_f64",
                    DType::U8 => "asort_asc_u8",
                    DType::U32 => "asort_asc_u32",
                    DType::I64 => "asort_asc_i64",
                }
            } else {
                match storage.dtype() {
                    DType::BF16 => "asort_desc_bf16",
                    DType::F16 => "asort_desc_f16",
                    DType::F32 => "asort_desc_f32",
                    DType::F64 => "asort_desc_f64",
                    DType::U8 => "asort_desc_u8",
                    DType::U32 => "asort_desc_u32",
                    DType::I64 => "asort_desc_i64",
                }
            }
        };
        let device = storage.device();
        let kernels = device.kernels();
        let command_buffer = device.command_buffer()?;
        let el = layout.shape().elem_count();
        let ncols = self.last_dim;
        let nrows = el / ncols;
        let src = crate::metal_backend::buffer_o(storage.buffer(), layout, storage.dtype());
        let dst = device.new_buffer(el, DType::U32, "asort")?;
        let mut ncols_pad = 1;
        while ncols_pad < ncols {
            ncols_pad *= 2;
        }
        candle_metal_kernels::call_arg_sort(
            device.metal_device(),
            &command_buffer,
            kernels,
            name,
            nrows,
            ncols,
            ncols_pad,
            src,
            &dst,
        )
        .map_err(crate::Error::wrap)?;
        let dst = crate::MetalStorage::new(dst, device.clone(), el, DType::U32);
        Ok((dst, layout.shape().clone()))
    }
}

#[allow(unused)]
fn next_power_of_2(x: usize) -> usize {
    let mut n = 1;
    while n < x {
        n *= 2
    }
    n
}

impl Tensor {
    /// Returns the indices that sort the tensor along the last dimension.
    ///
    /// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
    /// descending order. The sort is unstable so there are no guarantees on the final order
    /// when it comes to ties.
    pub fn arg_sort_last_dim(&self, asc: bool) -> Result<Tensor> {
        if !self.is_contiguous() {
            return Err(crate::Error::RequiresContiguous {
                op: "arg_sort_last_dim",
            });
        }
        let last_dim = match self.dims().last() {
            None => crate::bail!("empty last-dim in arg-sort"),
            Some(last_dim) => *last_dim,
        };
        // No need for a backward pass for arg sort.
        self.apply_op1_no_bwd(&ArgSort { asc, last_dim })
    }

    /// Sorts the tensor along the last dimension, returns the sorted tensor together with the
    /// sorted indexes.
    ///
    /// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
    /// descending order. The sort is unstable so there are no guarantees on the final order
    /// when it comes to ties.
    pub fn sort_last_dim(&self, asc: bool) -> Result<(Tensor, Tensor)> {
        if !self.is_contiguous() {
            return Err(crate::Error::RequiresContiguous {
                op: "sort_last_dim",
            });
        }
        let asort = self.arg_sort_last_dim(asc)?;
        let sorted = self.gather(&asort, crate::D::Minus1)?;
        Ok((sorted, asort))
    }
}
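A quick usage sketch of the new sorting entry points on CPU:

// Illustrative usage of sort_last_dim / arg_sort_last_dim.
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::new(&[[3f32, 1., 2.], [0., 5., 4.]], &Device::Cpu)?;
    let (sorted, indexes) = t.sort_last_dim(/* asc= */ true)?;
    assert_eq!(sorted.to_vec2::<f32>()?, [[1., 2., 3.], [0., 4., 5.]]);
    assert_eq!(indexes.to_vec2::<u32>()?, [[1, 2, 0], [0, 2, 1]]);
    Ok(())
}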
@@ -456,7 +456,15 @@ impl Tensor {
        shape: S,
        device: &Device,
    ) -> Result<Self> {
        Self::new_impl(array, shape.into(), device, false)
        let shape = shape.into();
        let n: usize = shape.elem_count();
        let buffer_size: usize = array.len();
        if buffer_size != n {
            return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
        }
        let storage = device.storage_from_slice(array)?;
        let none = BackpropOp::none();
        Ok(from_storage(storage, shape, none, false))
    }

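With the early length check, a mismatched slice now fails before any storage is allocated; a small sketch:

// Illustrative failure: 5 elements cannot fill a (2, 3) tensor, so this
// errors with ShapeMismatch before any device buffer is created.
use candle_core::{Device, Tensor};

fn main() {
    let r = Tensor::from_slice(&[1f32, 2., 3., 4., 5.], (2, 3), &Device::Cpu);
    assert!(r.is_err());
}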
    pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
@@ -582,9 +590,9 @@ impl Tensor {
    ///
    /// * `args` - A slice of 1D tensors.
    /// * `xy_indexing` - Whether to use xy indexing or ij indexing. If xy is selected, the
    ///   first dimension corresponds to the cardinality of the second input and the second
    ///   dimension corresponds to the cardinality of the first input. If ij is selected, the
    ///   dimensions are in the same order as the cardinality of the inputs.
    ///
    /// # Examples
    ///
@@ -2432,9 +2440,19 @@ impl Tensor {

    /// Returns log(sum(exp(tensor), dim)).
    pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
        let exp = self.exp()?;
        let sum = exp.sum(sum_dims)?;
        sum.log()
        let sum_dims = sum_dims.to_indexes(self.shape(), "log-sum-exp")?;
        if sum_dims.is_empty() {
            return Ok(self.clone());
        }
        let max = sum_dims[1..]
            .iter()
            .try_fold(self.max_keepdim(sum_dims[0])?, |max, &dim| {
                max.max_keepdim(dim)
            })?;
        let exp = self.broadcast_sub(&max)?.exp()?;
        let sum = exp.sum(sum_dims.clone())?;

        sum.log()? + max.squeeze_dims(&sum_dims)
    }

    /// Pointwise pow operation.
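The rewrite uses the max-shifted identity log(sum(exp(x))) = m + log(sum(exp(x - m))) with m the maximum over the reduced dims, which keeps the intermediate exponentials bounded. A scalar sketch of why that matters:

// Illustrative check of the numerically stable log-sum-exp identity.
fn log_sum_exp(xs: &[f64]) -> f64 {
    // exp(x - m) stays in [0, 1], so no overflow even for inputs like 1e3
    // where a naive exp() would be infinite.
    let m = xs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    m + xs.iter().map(|x| (x - m).exp()).sum::<f64>().ln()
}

fn main() {
    let naive = (1000f64.exp() + 1001f64.exp()).ln(); // exp overflows to inf
    assert!(naive.is_infinite());
    let stable = log_sum_exp(&[1000., 1001.]);
    assert!((stable - 1001.3132616875).abs() < 1e-9);
}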
@@ -235,4 +235,66 @@ impl Tensor {
        }
        Ok(crate::tensor::from_storage(storage, shape, op, false))
    }

    /// Set the values on `self` using values from `src`. The copy starts at the specified
    /// `offset` for the target dimension `dim` on `self`.
    /// `self` and `src` must have the same shape except on dimension `dim` where the `self` size
    /// has to be greater than or equal to `offset` plus the `src` size.
    ///
    /// Note that this modifies `self` in place and as such is not compatible with
    /// back-propagation.
    pub fn slice_set<D: Dim>(&self, src: &Self, dim: D, offset: usize) -> Result<()> {
        let dim = dim.to_index(self.shape(), "slice-set")?;
        if !self.is_contiguous() || !src.is_contiguous() {
            Err(Error::RequiresContiguous { op: "slice-set" }.bt())?
        }
        if self.dtype() != src.dtype() {
            Err(Error::DTypeMismatchBinaryOp {
                lhs: self.dtype(),
                rhs: src.dtype(),
                op: "slice-set",
            }
            .bt())?
        }
        if self.device().location() != src.device().location() {
            Err(Error::DeviceMismatchBinaryOp {
                lhs: self.device().location(),
                rhs: src.device().location(),
                op: "slice-set",
            }
            .bt())?
        }
        if self.rank() != src.rank() {
            Err(Error::UnexpectedNumberOfDims {
                expected: self.rank(),
                got: src.rank(),
                shape: self.shape().clone(),
            }
            .bt())?
        }
        for (dim_idx, (v1, v2)) in self.dims().iter().zip(src.dims().iter()).enumerate() {
            if dim_idx == dim && *v2 + offset > *v1 {
                crate::bail!("shape mismatch on target dim, dst: {v1}, src: {v2} + {offset}")
            }
            if dim_idx != dim && v1 != v2 {
                crate::bail!("shape mismatch on dim {dim_idx}, {v1} <> {v2}")
            }
        }
        let block_size: usize = src.dims().iter().skip(1 + dim).product();
        let d1: usize = src.dims().iter().take(dim).product();
        let d2 = block_size * src.dims()[dim];
        let dst_o = self.layout().start_offset() + offset * block_size;
        let src_o = src.layout().start_offset();
        src.storage().copy2d(
            &mut self.storage_mut(),
            d1,
            d2,
            /* src_s */ d2,
            /* dst_s */ block_size * self.dims()[dim],
            src_o,
            dst_o,
        )?;

        Ok(())
    }
}
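A hedged usage sketch of `slice_set`, overwriting a row block in place:

// Illustrative usage: copy `src` into rows 1..3 of `dst` along dim 0.
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let dst = Tensor::zeros((4, 3), DType::F32, &dev)?;
    let src = Tensor::ones((2, 3), DType::F32, &dev)?;
    dst.slice_set(&src, 0, 1)?;
    assert_eq!(
        dst.to_vec2::<f32>()?,
        [[0., 0., 0.], [1., 1., 1.], [1., 1., 1.], [0., 0., 0.]]
    );
    Ok(())
}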
@@ -34,9 +34,14 @@ impl Var {
        Ok(Self(inner))
    }

    // Converts a tensor to a variable; if the tensor is already a variable it is returned as-is.
    pub fn from_tensor(t: &Tensor) -> Result<Self> {
        let inner = t.make_var()?;
        Ok(Self(inner))
        if t.is_variable() {
            Ok(Self(t.clone()))
        } else {
            let inner = t.make_var()?;
            Ok(Self(inner))
        }
    }

    pub fn rand_f64<S: Into<Shape>>(
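With the change above, wrapping an existing variable no longer re-allocates its storage; a small sketch:

// Illustrative usage: re-wrapping a variable is now a cheap clone.
use candle_core::{Device, Result, Var};

fn main() -> Result<()> {
    let v = Var::new(&[1f32, 2., 3.], &Device::Cpu)?;
    // `as_tensor` exposes the underlying tensor, which is already a variable,
    // so `from_tensor` returns it as-is instead of copying.
    let v2 = Var::from_tensor(v.as_tensor())?;
    assert_eq!(v2.to_vec1::<f32>()?, [1., 2., 3.]);
    Ok(())
}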
@ -730,6 +730,103 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
|
||||
]
|
||||
]
|
||||
);
|
||||
|
||||
// Test the same, but then with the following properties, t & w are unmodified.
|
||||
let padding = 1;
|
||||
let outpadding = 1;
|
||||
let dilation = 1;
|
||||
let stride = 2;
|
||||
|
||||
let res = t.conv_transpose2d(&w, padding, outpadding, stride, dilation)?;
|
||||
let loss = res.sqr()?.sum_all()?;
|
||||
assert_eq!(test_utils::to_vec0_round(&loss, 0)?, 3627.0); // torch gives 3626.8560
|
||||
|
||||
let grads = loss.backward()?;
|
||||
|
||||
let grad_t = grads.get(&t).unwrap();
|
||||
let grad_w = grads.get(&w).unwrap();
|
||||
assert_eq!(grad_t.dims(), [1, 4, 7, 5]);
|
||||
assert_eq!(grad_w.dims(), [4, 2, 3, 5]);
|
||||
|
||||
#[rustfmt::skip]
|
||||
assert_eq!(
|
||||
test_utils::to_vec3_round(&grad_t.i(0)?, 1)?,
|
||||
[
|
||||
[
|
||||
[ 13.2, -40.7, -9.7, -47.3, -82.7],
|
||||
[ -98.2, 9.7, 57.7, -6.2, 180.7],
|
||||
[ 100.2, 24.1, 3.7, -100.5, -48.1],
|
||||
[ -0.3, 13.5, -2.9, 80.0, -49.8],
|
||||
[ 47.2, -25.6, -74.4, 61.2, -18.4],
|
||||
[ 4.6, -69.5, 27.9, 66.5, -88.1],
|
||||
// 4th column on next row; torch is 4.2
|
||||
[ -12.0, 79.2, -40.0, 4.1, -97.1],
|
||||
],
|
||||
[
|
||||
[ -42.2, -36.5, -51.1, 7.5, 32.3],
|
||||
[ 74.1, -44.6, -68.8, 19.5, 7.7],
|
||||
[ 137.1, 54.2, 153.8, -58.0, 45.5],
|
||||
[ 24.4, -56.8, 9.7, -41.0, -14.5],
|
||||
[ -3.7, 72.6, 8.3, 134.8, 40.5],
|
||||
[ 43.2, -56.9, -47.5, -89.4, -95.4],
|
||||
[ 68.2, 108.1, -80.0, 57.0, -121.1]
|
||||
],
|
||||
[
|
||||
[ 31.1, -11.4, -34.8, 33.1, -44.2],
|
||||
[ 29.4, -31.6, -40.2, 13.7, 13.1],
|
||||
[ -0.8, -83.8, -7.8, -17.3, 78.2],
|
||||
[ 12.0, -118.7, 137.5, -76.7, 50.8],
|
||||
[ -28.7, -114.2, -3.7, -96.3, -13.8],
|
||||
[ -31.8, 28.5, -14.3, 4.6, 13.4],
|
||||
[ 28.0, -0.2, -38.9, -29.7, -59.0]
|
||||
],
|
||||
[
|
||||
[ -16.8, 38.5, 15.5, 26.6, 48.9],
|
||||
[ 14.5, 49.6, -24.8, 65.6, 61.7],
|
||||
[ 22.1, -64.7, -4.3, -51.0, 36.3],
|
||||
[ 31.0, -88.9, 47.1, -123.5, -3.8],
|
||||
[ -14.8, -39.8, 128.2, -110.3, 42.6],
|
||||
// 1st column on next row; torch is -7.2
|
||||
[ -7.1, 95.3, -21.3, -58.7, -13.9],
|
||||
[ 26.9, 21.3, 16.1, 70.3, 32.1]
|
||||
]
|
||||
]
|
||||
);
|
||||
|
||||
#[rustfmt::skip]
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&grad_w.flatten_all()?, 1)?,
|
||||
[
|
||||
// 2nd value; torch gets -3.2, 3rd value; torch gets 221.8
|
||||
-2.460e+01, -3.100e+00, 2.219e+02, 7.400e+00, 5.620e+01,
|
||||
7.420e+01, 7.830e+01, 8.900e+00, 1.050e+01, 2.810e+01,
|
||||
5.100e+00, -1.046e+02, -1.572e+02, 8.710e+01, -9.840e+01,
|
||||
-4.230e+01, -1.898e+02, 1.860e+01, -3.570e+01, 9.810e+01,
|
||||
4.680e+01, 1.182e+02, 4.020e+01, -1.900e+00, 1.508e+02,
|
||||
1.094e+02, 1.018e+02, -4.620e+01, 1.591e+02, -2.320e+01,
|
||||
// 5th value; torch gets 7.1
|
||||
-8.450e+01, -4.600e+00, 6.330e+01, 1.123e+02, -7.000e+00,
|
||||
1.101e+02, -6.620e+01, 2.090e+01, -5.120e+01, 8.990e+01,
|
||||
9.050e+01, -6.990e+01, 6.800e+01, -9.250e+01, 1.380e+02,
|
||||
4.720e+01, 4.710e+01, 6.210e+01, 8.870e+01, 2.098e+02,
|
||||
3.870e+01, -1.390e+01, 6.270e+01, 1.484e+02, -9.920e+01,
|
||||
-4.200e+01, -1.505e+02, -1.480e+01, -2.620e+01, 8.220e+01,
|
||||
-3.350e+01, -2.260e+01, -1.198e+02, -5.080e+01, 1.259e+02,
|
||||
5.600e+01, 9.270e+01, 1.209e+02, 6.590e+01, -8.330e+01,
|
||||
7.000e+00, -2.600e+01, -1.133e+02, 3.870e+01, 4.020e+01,
|
||||
-6.300e+00, -8.710e+01, -5.150e+01, -8.510e+01, 2.000e-01,
|
||||
3.640e+01, -6.100e+00, 6.590e+01, -2.700e+00, 6.550e+01,
|
||||
// 4th value; torch gets 3.8
|
||||
5.300e+00, -6.760e+01, -4.270e+01, -3.900e+00, 2.880e+01,
|
||||
5.260e+01, 6.170e+01, -1.203e+02, -1.610e+01, 7.740e+01,
|
||||
-1.008e+02, -1.070e+01, -9.900e+00, 3.300e+00, -2.620e+01,
|
||||
-4.440e+01, 2.580e+01, -6.920e+01, -4.220e+01, 1.108e+02,
|
||||
1.240e+01, -3.440e+01, -2.800e+00, 7.880e+01, -6.690e+01,
|
||||
1.480e+01, 2.310e+01, -4.260e+01, -1.500e+00, -4.760e+01,
|
||||
5.350e+01, -2.260e+01, 8.000e-01, -3.840e+01, -2.500e+00
|
||||
]
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
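The shapes asserted above follow from the standard transposed-convolution size formula, as documented by PyTorch: L_out = (L_in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1. A quick check against the parameters of this test:

```rust
// Transposed-convolution output length for one spatial dimension.
fn conv_transpose_out(
    l_in: usize, k: usize, stride: usize,
    padding: usize, outpadding: usize, dilation: usize,
) -> usize {
    (l_in - 1) * stride - 2 * padding + dilation * (k - 1) + outpadding + 1
}

fn main() {
    // t: (1, 4, 7, 5) and w: (4, 2, 3, 5) with the parameters above give a
    // 14 x 12 spatial output.
    assert_eq!(conv_transpose_out(7, 3, 2, 1, 1, 1), 14);
    assert_eq!(conv_transpose_out(5, 5, 2, 1, 1, 1), 12);
}
```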

@@ -49,6 +49,20 @@ fn matmul(device: &Device) -> Result<()> {
    Ok(())
}

fn matmul_bf16(device: &Device) -> Result<()> {
    if !device.supports_bf16() {
        return Ok(());
    }
    let data = vec![1.0f32, 2.0, 3.0, 4.0];
    let a = Tensor::from_slice(&data, (2, 2), device)?.to_dtype(DType::BF16)?;
    let data = vec![1.0f32, 2.0, 3.0, 4.0];
    let b = Tensor::from_slice(&data, (2, 2), device)?.to_dtype(DType::BF16)?;

    let c = a.matmul(&b)?.to_dtype(DType::F32)?;
    assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);
    Ok(())
}

fn broadcast_matmul(device: &Device) -> Result<()> {
    let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?;
    let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?;

@@ -96,6 +110,12 @@ fn mm_layout(device: &Device) -> Result<()> {
}

test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
test_device!(
    matmul_bf16,
    matmul_bf16_cpu,
    matmul_bf16_gpu,
    matmul_bf16_metal
);
test_device!(
    broadcast_matmul,
    broadcast_matmul_cpu,

@@ -3,7 +3,7 @@ use candle_core::{
    quantized::{self, GgmlDType},
    test_device,
    test_utils::to_vec2_round,
    Device, IndexOp, Module, Result, Tensor,
    DType, Device, IndexOp, Module, Result, Tensor,
};
use quantized::{k_quants, GgmlType};
use rand::prelude::*;

@@ -193,17 +193,25 @@ fn qmm_batch(dev: &Device) -> Result<()> {
    let mm3 = rhs.forward(&lhs3)?;
    assert_eq!(mm3.shape().dims(), [6, 6]);
    let diff3 = (mm3.i(2..4)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
    if dev.is_cuda() {
        assert!(diff3 < 1e-4)
    } else {
        assert_eq!(diff3, 0.0)
    };
    assert_eq!(diff3, 0.0);
    let diff3 = (mm3.i(4..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
    assert_eq!(diff3, 0.0);
    let lhs4 = Tensor::cat(&[&lhs3, &lhs3], 0)?;
    let mm4 = rhs.forward(&lhs4)?;
    assert_eq!(mm4.shape().dims(), [12, 6]);
    let diff4 = (mm4.i(..6)? - &mm3)?.abs()?.sum_all()?.to_vec0::<f32>()?;
    if dev.is_cuda() {
        assert!(diff3 < 1e-4)
        // We use a different kernel for sizes from 1 to 8 on cuda which explains
        // the difference here.
        assert!(0. < diff4 && diff4 < 1e-4)
    } else {
        assert_eq!(diff3, 0.0)
        assert_eq!(diff4, 0.0)
    };
    let diff4 = (mm4.i(6..)? - &mm4.i(..6)?)?
        .abs()?
        .sum_all()?
        .to_vec0::<f32>()?;
    assert_eq!(diff4, 0.0);
    Ok(())
}
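Every quantization hunk below adds the same consistency check: dequantizing straight to f16 must agree exactly with dequantizing to f32 and casting down. A sketch of that invariant factored into one helper (hypothetical refactor; all API calls are the ones used in the hunks themselves):

```rust
use candle_core::quantized::{self, GgmlDType};
use candle_core::{DType, Device, Result, Tensor};

// Quantize `src`, then check that dequantize_f16 matches dequantize + cast.
fn check_f16_consistency(src: &Tensor, dtype: GgmlDType, device: &Device) -> Result<()> {
    let quant = quantized::QTensor::quantize(src, dtype)?;
    let dst = quant.dequantize(device)?;
    let dst_f16 = quant.dequantize_f16(device)?;
    let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
        .to_dtype(DType::F32)?
        .abs()?
        .sum_all()?
        .to_vec0::<f32>()?;
    assert_eq!(diff, 0.);
    Ok(())
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Q4_0 uses blocks of 32 values, so pick a multiple of 32.
    let src: Vec<f32> = (0..128).map(|i| i as f32 * 0.1).collect();
    let src = Tensor::from_slice(&src, (128,), &dev)?;
    check_f16_consistency(&src, GgmlDType::Q4_0, &dev)
}
```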

@@ -217,6 +225,13 @@ fn quantize_q4_0(device: &Device) -> Result<()> {
    let src = Tensor::from_slice(&src, (32 * 4,), device)?;
    let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_0)?;
    let dst = quant.dequantize(device)?;
    let dst_f16 = quant.dequantize_f16(device)?;
    let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
        .to_dtype(DType::F32)?
        .abs()?
        .sum_all()?
        .to_vec0::<f32>()?;
    assert_eq!(diff, 0.);
    assert_eq!(
        dst.to_vec1::<f32>()?,
        &[

@@ -243,6 +258,13 @@ fn quantize_q4_1(device: &Device) -> Result<()> {
    let src = Tensor::from_slice(&src, (32 * 4,), device)?;
    let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?;
    let dst = quant.dequantize(device)?;
    let dst_f16 = quant.dequantize_f16(device)?;
    let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
        .to_dtype(DType::F32)?
        .abs()?
        .sum_all()?
        .to_vec0::<f32>()?;
    assert_eq!(diff, 0.);
    assert_eq!(
        round_vector(&dst.to_vec1::<f32>()?),
        &[

@@ -269,6 +291,13 @@ fn quantize_q5_0(device: &Device) -> Result<()> {
    let src = Tensor::from_slice(&src, (32 * 4,), device)?;
    let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?;
    let dst = quant.dequantize(device)?;
    let dst_f16 = quant.dequantize_f16(device)?;
    let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
        .to_dtype(DType::F32)?
        .abs()?
        .sum_all()?
        .to_vec0::<f32>()?;
    assert_eq!(diff, 0.);
    assert_eq!(
        round_vector(&dst.to_vec1::<f32>()?),
        &[

@@ -295,6 +324,13 @@ fn quantize_q5_1(device: &Device) -> Result<()> {
    let src = Tensor::from_slice(&src, (32 * 4,), device)?;
    let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?;
    let dst = quant.dequantize(device)?;
    let dst_f16 = quant.dequantize_f16(device)?;
    let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
        .to_dtype(DType::F32)?
        .abs()?
        .sum_all()?
        .to_vec0::<f32>()?;
    assert_eq!(diff, 0.);
    assert_eq!(
        round_vector(&dst.to_vec1::<f32>()?),
        &[

@@ -379,6 +415,13 @@ fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f3
    let src = Tensor::from_slice(&src, (GGML_TEST_SIZE,), device)?;
    let quant = quantized::QTensor::quantize(&src, dtype)?;
    let dst = quant.dequantize(device)?;
    let dst_f16 = quant.dequantize_f16(device)?;
    let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
        .to_dtype(DType::F32)?
        .abs()?
        .sum_all()?
        .to_vec0::<f32>()?;
    assert_eq!(diff, 0.);
    let error = calculate_rmse(&src.to_vec1::<f32>()?, &dst.to_vec1::<f32>()?);
    if error > max_error {
        bail!(

@@ -396,6 +439,13 @@ fn quantize_q2k(device: &Device) -> Result<()> {
    let src = get_test_vector2(0.5, 1024, device)?;
    let quant = quantized::QTensor::quantize(&src, dtype)?;
    let dst = quant.dequantize(device)?;
    let dst_f16 = quant.dequantize_f16(device)?;
    let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
        .to_dtype(DType::F32)?
        .abs()?
        .sum_all()?
        .to_vec0::<f32>()?;
    assert_eq!(diff, 0.);

    let src = src.to_vec1::<f32>()?;
    let dst = dst.to_vec1::<f32>()?;

@@ -415,6 +465,13 @@ fn quantize_q2k(device: &Device) -> Result<()> {
    let src_big = get_test_vector2(128.0, 1024, device)?;
    let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
    let dst_big = quant_big.dequantize(device)?;
    let dst_big_f16 = quant_big.dequantize_f16(device)?;
    let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
        .to_dtype(DType::F32)?
        .abs()?
        .sum_all()?
        .to_vec0::<f32>()?;
    assert_eq!(diff, 0.);

    let src_big = src_big.to_vec1::<f32>()?;
    let dst_big = dst_big.to_vec1::<f32>()?;

@@ -429,6 +486,13 @@ fn quantize_q3k(device: &Device) -> Result<()> {
    let src = get_test_vector2(0.5, 1024, device)?;
    let quant = quantized::QTensor::quantize(&src, dtype)?;
    let dst = quant.dequantize(device)?;
    let dst_f16 = quant.dequantize_f16(device)?;
    let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
        .to_dtype(DType::F32)?
        .abs()?
        .sum_all()?
        .to_vec0::<f32>()?;
    assert_eq!(diff, 0.);

    let src = src.to_vec1::<f32>()?;
    let dst = dst.to_vec1::<f32>()?;

@@ -448,6 +512,13 @@ fn quantize_q3k(device: &Device) -> Result<()> {
    let src_big = get_test_vector2(128.0, 1024, device)?;
    let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
    let dst_big = quant_big.dequantize(device)?;
    let dst_big_f16 = quant_big.dequantize_f16(device)?;
    let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
        .to_dtype(DType::F32)?
        .abs()?
        .sum_all()?
        .to_vec0::<f32>()?;
    assert_eq!(diff, 0.);

    let src_big = src_big.to_vec1::<f32>()?;
    let dst_big = dst_big.to_vec1::<f32>()?;

@@ -462,6 +533,13 @@ fn quantize_q4k(device: &Device) -> Result<()> {
    let src = get_test_vector2(0.5, 1024, device)?;
    let quant = quantized::QTensor::quantize(&src, dtype)?;
    let dst = quant.dequantize(device)?;
    let dst_f16 = quant.dequantize_f16(device)?;
    let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
        .to_dtype(DType::F32)?
        .abs()?
        .sum_all()?
        .to_vec0::<f32>()?;
    assert_eq!(diff, 0.);

    let src = src.to_vec1::<f32>()?;
    let dst = dst.to_vec1::<f32>()?;

@@ -481,6 +559,13 @@ fn quantize_q4k(device: &Device) -> Result<()> {
    let src_big = get_test_vector2(128.0, 1024, device)?;
    let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
    let dst_big = quant_big.dequantize(device)?;
    let dst_big_f16 = quant_big.dequantize_f16(device)?;
    let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
        .to_dtype(DType::F32)?
        .abs()?
        .sum_all()?
        .to_vec0::<f32>()?;
    assert_eq!(diff, 0.);

    let src_big = src_big.to_vec1::<f32>()?;
    let dst_big = dst_big.to_vec1::<f32>()?;

@@ -495,6 +580,13 @@ fn quantize_q5k(device: &Device) -> Result<()> {
    let src = get_test_vector2(0.5, 1024, device)?;
    let quant = quantized::QTensor::quantize(&src, dtype)?;
    let dst = quant.dequantize(device)?;
    let dst_f16 = quant.dequantize_f16(device)?;
    let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
        .to_dtype(DType::F32)?
        .abs()?
        .sum_all()?
        .to_vec0::<f32>()?;
    assert_eq!(diff, 0.);

    let src = src.to_vec1::<f32>()?;
    let dst = dst.to_vec1::<f32>()?;

@@ -514,6 +606,13 @@ fn quantize_q5k(device: &Device) -> Result<()> {
    let src_big = get_test_vector2(128.0, 1024, device)?;
    let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
    let dst_big = quant_big.dequantize(device)?;
    let dst_big_f16 = quant_big.dequantize_f16(device)?;
    let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
        .to_dtype(DType::F32)?
        .abs()?
        .sum_all()?
        .to_vec0::<f32>()?;
    assert_eq!(diff, 0.);

    let src_big = src_big.to_vec1::<f32>()?;
    let dst_big = dst_big.to_vec1::<f32>()?;

@@ -528,6 +627,13 @@ fn quantize_q6k(device: &Device) -> Result<()> {
    let src = get_test_vector2(0.5, 1024, device)?;
    let quant = quantized::QTensor::quantize(&src, dtype)?;
    let dst = quant.dequantize(device)?;
    let dst_f16 = quant.dequantize_f16(device)?;
    let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
        .to_dtype(DType::F32)?
        .abs()?
        .sum_all()?
        .to_vec0::<f32>()?;
    assert_eq!(diff, 0.);

    let src = src.to_vec1::<f32>()?;
    let dst = dst.to_vec1::<f32>()?;

@@ -547,6 +653,13 @@ fn quantize_q6k(device: &Device) -> Result<()> {
    let src_big = get_test_vector2(128.0, 1024, device)?;
    let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
    let dst_big = quant_big.dequantize(device)?;
    let dst_big_f16 = quant_big.dequantize_f16(device)?;
    let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
        .to_dtype(DType::F32)?
        .abs()?
        .sum_all()?
        .to_vec0::<f32>()?;
    assert_eq!(diff, 0.);

    let src_big = src_big.to_vec1::<f32>()?;
    let dst_big = dst_big.to_vec1::<f32>()?;

@@ -561,6 +674,13 @@ fn quantize_q8k(device: &Device) -> Result<()> {
    let src = get_test_vector2(0.5, 1024, device)?;
    let quant = quantized::QTensor::quantize(&src, dtype)?;
    let dst = quant.dequantize(device)?;
    let dst_f16 = quant.dequantize_f16(device)?;
    let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
        .to_dtype(DType::F32)?
        .abs()?
        .sum_all()?
        .to_vec0::<f32>()?;
    assert_eq!(diff, 0.);

    let src = src.to_vec1::<f32>()?;
    let dst = dst.to_vec1::<f32>()?;

@@ -580,6 +700,13 @@ fn quantize_q8k(device: &Device) -> Result<()> {
    let src_big = get_test_vector2(128.0, 1024, device)?;
    let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
    let dst_big = quant_big.dequantize(device)?;
    let dst_big_f16 = quant_big.dequantize_f16(device)?;
    let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
        .to_dtype(DType::F32)?
        .abs()?
        .sum_all()?
        .to_vec0::<f32>()?;
    assert_eq!(diff, 0.);

    let src_big = src_big.to_vec1::<f32>()?;
    let dst_big = dst_big.to_vec1::<f32>()?;

@@ -1,5 +1,31 @@
use candle_core::{DType, Result, Tensor};

struct TmpFile(std::path::PathBuf);

impl TmpFile {
    fn create(base: &str) -> TmpFile {
        let filename = std::env::temp_dir().join(format!(
            "candle-{}-{}-{:?}",
            base,
            std::process::id(),
            std::thread::current().id(),
        ));
        TmpFile(filename)
    }
}

impl std::convert::AsRef<std::path::Path> for TmpFile {
    fn as_ref(&self) -> &std::path::Path {
        self.0.as_path()
    }
}

impl Drop for TmpFile {
    fn drop(&mut self) {
        std::fs::remove_file(&self.0).unwrap()
    }
}

#[test]
fn npy() -> Result<()> {
    let npy = Tensor::read_npy("tests/test.npy")?;

@@ -22,3 +48,24 @@ fn npz() -> Result<()> {
    );
    Ok(())
}

#[test]
fn safetensors() -> Result<()> {
    use candle_core::safetensors::Load;

    let tmp_file = TmpFile::create("st");
    let t = Tensor::arange(0f32, 24f32, &candle_core::Device::Cpu)?;
    t.save_safetensors("t", &tmp_file)?;
    // Load from file.
    let st = candle_core::safetensors::load(&tmp_file, &candle_core::Device::Cpu)?;
    let t2 = st.get("t").unwrap();
    let diff = (&t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
    assert_eq!(diff, 0f32);
    // Load from bytes.
    let bytes = std::fs::read(tmp_file)?;
    let st = candle_core::safetensors::SliceSafetensors::new(&bytes)?;
    let t2 = st.get("t").unwrap().load(&candle_core::Device::Cpu);
    let diff = (&t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
    assert_eq!(diff, 0f32);
    Ok(())
}

@@ -96,6 +96,40 @@ fn clamp(device: &Device) -> Result<()> {
    Ok(())
}

fn asort(device: &Device) -> Result<()> {
    let data = &[[3f32, 1., 4., 1.1, 5.], [2.1, 1., 7., 8., 2.]];
    let tensor = Tensor::new(data, device)?;
    let indexes = tensor.arg_sort_last_dim(true)?;
    assert_eq!(
        indexes.to_vec2::<u32>()?,
        [[1, 3, 0, 2, 4], [1, 4, 0, 2, 3]],
    );
    let indexes = tensor.arg_sort_last_dim(false)?;
    assert_eq!(
        indexes.to_vec2::<u32>()?,
        [[4, 2, 0, 3, 1], [3, 2, 0, 4, 1]],
    );
    let (sorted, indexes) = tensor.sort_last_dim(true)?;
    assert_eq!(
        indexes.to_vec2::<u32>()?,
        [[1, 3, 0, 2, 4], [1, 4, 0, 2, 3]],
    );
    assert_eq!(
        sorted.to_vec2::<f32>()?,
        [[1.0, 1.1, 3.0, 4.0, 5.0], [1.0, 2.0, 2.1, 7.0, 8.0]]
    );
    let (sorted, indexes) = tensor.sort_last_dim(false)?;
    assert_eq!(
        indexes.to_vec2::<u32>()?,
        [[4, 2, 0, 3, 1], [3, 2, 0, 4, 1]],
    );
    assert_eq!(
        sorted.to_vec2::<f32>()?,
        [[5.0, 4.0, 3.0, 1.1, 1.0], [8.0, 7.0, 2.1, 2.0, 1.0]]
    );
    Ok(())
}

fn unary_op(device: &Device) -> Result<()> {
    let data = &[[-3f32, 1., 4., -0.1, 0.5], [2.7, -1.8, -0.28, 1.8, 2.8]];
    let tensor = Tensor::new(data, device)?;

@@ -631,6 +665,30 @@ fn broadcast(device: &Device) -> Result<()> {
    Ok(())
}

fn slice_set(device: &Device) -> Result<()> {
    let (b, h, max_t, d) = (2, 4, 7, 3);
    let cache = Tensor::zeros((b, h, max_t, d), DType::F32, device)?;
    let tensor = Tensor::randn(0f32, 1f32, (b, h, 4, d), device)?;
    cache.slice_set(&tensor, 2, 0)?;
    let cache_t = cache.narrow(2, 0, 4)?;
    let diff = (cache_t - &tensor)?.abs()?.sum_all()?.to_vec0::<f32>()?;
    assert_eq!(diff, 0.);
    cache.slice_set(&tensor, 2, 1)?;
    let cache_t = cache.narrow(2, 1, 4)?;
    let diff = (cache_t - &tensor)?.abs()?.sum_all()?.to_vec0::<f32>()?;
    assert_eq!(diff, 0.);
    let ones = Tensor::ones((b, h, 1, d), DType::F32, device)?;
    cache.slice_set(&ones, 2, 6)?;
    let diff = cache.narrow(2, 5, 1)?.abs()?.sum_all()?.to_vec0::<f32>()?;
    assert_eq!(diff, 0.);
    let diff = (cache.narrow(2, 6, 1)? - 1.)?
        .abs()?
        .sum_all()?
        .to_vec0::<f32>()?;
    assert_eq!(diff, 0.);
    Ok(())
}

fn cat(device: &Device) -> Result<()> {
    // 1D
    let t1 = Tensor::new(&[3f32, 1., 4.], device)?;

@@ -1112,6 +1170,7 @@ test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal);
test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal);
test_device!(slice_set, ss_cpu, ss_gpu, ss_metal);
test_device!(cat, cat_cpu, cat_gpu, cat_metal);
test_device!(sum, sum_cpu, sum_gpu, sum_metal);
test_device!(min, min_cpu, min_gpu, min_metal);

@@ -1151,6 +1210,7 @@ test_device!(
);
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
test_device!(asort, asort_cpu, asort_gpu, asort_metal);
test_device!(var, var_cpu, var_gpu, var_metal);
test_device!(zero_dim, zero_dim_cpu, zero_dim_gpu, zero_dim_metal);

@@ -1266,11 +1326,29 @@ fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> {

#[test]
fn log_sum_exp() -> Result<()> {
    let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
    let input = Tensor::new(
        &[
            [[1f64, 2., 3.], [4., 5., 6.]],
            [[-1000.0, -999.0, -1001.0], [1000.0, 999.0, 1001.0]],
        ],
        &Device::Cpu,
    )?;

    let output = input.log_sum_exp(D::Minus1)?;
    // The expected values were obtained from pytorch.
    let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?;
    assert_close(&output, &expected, 0.00001)?;
    let expected = Tensor::new(&[[3.4076, 6.4076], [-998.5924, 1001.4076]], &Device::Cpu)?;
    assert_eq!(output.dims(), expected.dims());
    assert_close(&output.flatten_all()?, &expected.flatten_all()?, 0.00001)?;

    assert_eq!(
        input.log_sum_exp((0, 1))?.to_vec1::<f64>()?,
        [1000.0, 999.0, 1001.0]
    );
    assert_eq!(
        input.log_sum_exp(())?.to_vec3::<f64>()?,
        input.to_vec3::<f64>()?
    );

    Ok(())
}
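The rows with magnitude around 1000 in the updated test only work if log_sum_exp uses the numerically stable max-shift form; a naive exponentiation would overflow or underflow. A minimal sketch of that standard trick, independent of candle:

```rust
// Stable log-sum-exp: shift by the maximum so every exponent is <= 0.
fn log_sum_exp(xs: &[f64]) -> f64 {
    let m = xs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    m + xs.iter().map(|x| (x - m).exp()).sum::<f64>().ln()
}

fn main() {
    // Matches the expected values in the test above, up to rounding.
    assert!((log_sum_exp(&[1.0, 2.0, 3.0]) - 3.4076).abs() < 1e-4);
    assert!((log_sum_exp(&[-1000.0, -999.0, -1001.0]) + 998.5924).abs() < 1e-4);
}
```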

@@ -89,7 +89,7 @@ fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor,

pub fn load() -> Result<crate::vision::Dataset> {
    let api = Api::new().map_err(|e| Error::Msg(format!("Api error: {e}")))?;
    let dataset_id = "mnist".to_string();
    let dataset_id = "ylecun/mnist".to_string();
    let repo = Repo::with_revision(
        dataset_id,
        RepoType::Dataset,

@@ -25,6 +25,8 @@ hf-hub = { workspace = true, features = ["tokio"] }
image = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
palette = { version = "0.7.6", optional = true }
enterpolation = { version = "0.2.1", optional = true}
pyo3 = { version = "0.21.0", features = ["auto-initialize"], optional = true }
rayon = { workspace = true }
rubato = { version = "0.15.0", optional = true }

@@ -33,7 +35,7 @@ serde = { workspace = true }
serde_json = { workspace = true }
symphonia = { version = "0.5.3", features = ["all"], optional = true }
tokenizers = { workspace = true, features = ["onig"] }
cpal= { version = "0.15.2", optional = true }
cpal = { version = "0.15.2", optional = true }

[dev-dependencies]
anyhow = { workspace = true }

@@ -65,6 +67,7 @@ onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"]
microphone = ["cpal"]
encodec = ["cpal", "symphonia", "rubato"]
depth_anything_v2 = ["palette", "enterpolation"]

[[example]]
name = "llama_multiprocess"

@@ -101,3 +104,7 @@ required-features = ["candle-datasets"]

[[example]]
name = "encodec"
required-features = ["encodec"]

[[example]]
name = "depth_anything_v2"
required-features = ["depth_anything_v2"]

candle-examples/examples/beit/README.md (new file, 20 lines)

@@ -0,0 +1,20 @@
# candle-beit

[Beit](https://arxiv.org/abs/2106.08254) is a computer vision model.
In this example, it is used as an ImageNet classifier: the model returns the
probability for the image to belong to each of the 1000 ImageNet categories.

## Running an example

```bash
cargo run --example beit --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg

> mountain bike, all-terrain bike, off-roader: 56.16%
> bicycle-built-for-two, tandem bicycle, tandem: 3.08%
> maillot : 2.23%
> alp : 0.88%
> crash helmet : 0.85%

```

candle-examples/examples/beit/main.rs (new file, 79 lines)

@@ -0,0 +1,79 @@
//! BEiT: BERT Pre-Training of Image Transformers
//! https://github.com/microsoft/unilm/tree/master/beit

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::Parser;

use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::beit;

/// Loads an image from disk using the image crate; this returns a tensor with shape
/// (3, 384, 384). Beit special normalization is applied.
pub fn load_image384_beit_norm<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
    let img = image::ImageReader::open(p)?
        .decode()
        .map_err(candle::Error::wrap)?
        .resize_to_fill(384, 384, image::imageops::FilterType::Triangle);
    let img = img.to_rgb8();
    let data = img.into_raw();
    let data = Tensor::from_vec(data, (384, 384, 3), &Device::Cpu)?.permute((2, 0, 1))?;
    let mean = Tensor::new(&[0.5f32, 0.5, 0.5], &Device::Cpu)?.reshape((3, 1, 1))?;
    let std = Tensor::new(&[0.5f32, 0.5, 0.5], &Device::Cpu)?.reshape((3, 1, 1))?;
    (data.to_dtype(candle::DType::F32)? / 255.)?
        .broadcast_sub(&mean)?
        .broadcast_div(&std)
}

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let device = candle_examples::device(args.cpu)?;

    let image = load_image384_beit_norm(args.image)?.to_device(&device)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model("vincent-espitalier/candle-beit".into());
            api.get("beit_base_patch16_384.in22k_ft_in22k_in1k.safetensors")?
        }
        Some(model) => model.into(),
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let model = beit::vit_base(vb)?;
    println!("model built");
    let logits = model.forward(&image.unsqueeze(0)?)?;
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;
    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    for &(category_idx, pr) in prs.iter().take(5) {
        println!(
            "{:24}: {:.2}%",
            candle_examples::imagenet::CLASSES[category_idx],
            100. * pr
        );
    }
    Ok(())
}
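With mean and std both set to 0.5, the normalization in load_image384_beit_norm maps byte pixel values from [0, 255] to [-1, 1]. A scalar check of that mapping:

```rust
fn main() {
    // (p / 255 - mean) / std with mean = std = 0.5.
    let norm = |p: f32| (p / 255.0 - 0.5) / 0.5;
    assert_eq!(norm(0.0), -1.0);
    assert_eq!(norm(127.5), 0.0);
    assert_eq!(norm(255.0), 1.0);
}
```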

@@ -126,7 +126,7 @@ fn main() -> Result<()> {
    println!("Loaded and encoded {:?}", start.elapsed());
    for idx in 0..args.n {
        let start = std::time::Instant::now();
        let ys = model.forward(&token_ids, &token_type_ids)?;
        let ys = model.forward(&token_ids, &token_type_ids, None)?;
        if idx == 0 {
            println!("{ys}");
        }

@@ -163,11 +163,19 @@ fn main() -> Result<()> {
            Ok(Tensor::new(tokens.as_slice(), device)?)
        })
        .collect::<Result<Vec<_>>>()?;
    let attention_mask = tokens
        .iter()
        .map(|tokens| {
            let tokens = tokens.get_attention_mask().to_vec();
            Ok(Tensor::new(tokens.as_slice(), device)?)
        })
        .collect::<Result<Vec<_>>>()?;

    let token_ids = Tensor::stack(&token_ids, 0)?;
    let attention_mask = Tensor::stack(&attention_mask, 0)?;
    let token_type_ids = token_ids.zeros_like()?;
    println!("running inference on batch {:?}", token_ids.shape());
    let embeddings = model.forward(&token_ids, &token_type_ids)?;
    let embeddings = model.forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
    println!("generated embeddings {:?}", embeddings.shape());
    // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
    let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
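Passing the attention mask lets the model ignore padding positions during attention; note that the pooling right below still averages over all tokens, padding included. For comparison, a sketch of mask-aware mean pooling (a plain-Rust illustration, not what this example does):

```rust
// Average only the embeddings whose mask entry is 1 (real tokens).
fn masked_mean(embeddings: &[Vec<f32>], mask: &[u8]) -> Vec<f32> {
    let dim = embeddings[0].len();
    let mut acc = vec![0.0f32; dim];
    let mut count = 0.0f32;
    for (e, &m) in embeddings.iter().zip(mask) {
        if m == 1 {
            for (a, v) in acc.iter_mut().zip(e) {
                *a += v;
            }
            count += 1.0;
        }
    }
    acc.iter().map(|a| a / count).collect()
}

fn main() {
    let emb = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![0.0, 0.0]];
    let mask = [1u8, 1, 0]; // the last token is padding
    assert_eq!(masked_mean(&emb, &mask), vec![2.0, 3.0]);
}
```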

@@ -55,7 +55,7 @@ const SEP_TOKEN_ID: u32 = 102;

/// Loads an image from disk using the image crate; this returns a tensor with shape
/// (3, 384, 384). OpenAI normalization is applied.
pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
    let img = image::io::Reader::open(p)?
    let img = image::ImageReader::open(p)?
        .decode()
        .map_err(candle::Error::wrap)?
        .resize_to_fill(384, 384, image::imageops::FilterType::Triangle);

@@ -1,4 +1,4 @@
Contrastive Language-Image Pre-Training
# candle-clip

Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
pairs of images with related texts.

@@ -33,7 +33,7 @@ struct Args {
}

fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::Result<Tensor> {
    let img = image::io::Reader::open(path)?.decode()?;
    let img = image::ImageReader::open(path)?.decode()?;
    let (height, width) = (image_size, image_size);
    let img = img.resize_to_fill(
        width as u32,

candle-examples/examples/codegeex4-9b/README.org (new file, 96 lines)

@@ -0,0 +1,96 @@
* candle-codegeex4_9b
THUDM/CodeGeeX4 is a versatile model for all AI software development scenarios, including code completion, code interpreter, web search, function calling, repository-level Q&A and much more.

- [[https://github.com/THUDM/CodeGeeX4][Github]]
- [[https://codegeex.cn/][HomePage]]
- [[https://huggingface.co/THUDM/codegeex4-all-9b][huggingface]]

** Running with ~cuda~

#+begin_src shell
cargo run --example codegeex4-9b --release --features cuda -- --prompt "please write an insertion sort in rust" --sample-len 300
#+end_src

** Running with ~cpu~
#+begin_src shell
cargo run --example codegeex4-9b --release -- --cpu --prompt "please write an insertion sort in rust" --sample-len 300
#+end_src

** Output Example
*** Input
#+begin_src shell
cargo run --release --features cuda -- --prompt 'please write a FFT in rust' --sample-len 500 --cache /root/autodl-tmp
#+end_src

*** Output
#+begin_src shell
avx: false, neon: false, simd128: false, f16c: false
temp: 0.95 repeat-penalty: 1.10 repeat-last-n: 64
cache path /root/autodl-tmp
Prompt: [please write a FFT in rust]
Using Seed 11511762269791786684
DType is BF16
transofrmer layers create
模型加载完毕 4
starting the inference loop

开始生成
samplelen 500

500 tokens generated (34.60 token/s)
Result:

Sure, I can help you with that. Here's an example of a Fast Fourier Transform (FFT) implementation in Rust:

```rust
use num_complex::Complex;

fn fft(input: &[Complex<f64> > ] ) -> Vec<Complex<f64> > > {
    let n = input.len();

    if n == 1 {
        return vec![input[0]]];
    }

    let mut even = vec![];
    let mut odd = vec![];

    for i in 0..n {

        if i % 2 == 0 {
            even.push(input[i]);
        } else {
            odd.push(input[i]);
        }
    }

    let even_fft = fft(&even);
    let odd_fft = fft(&odd);

    let mut output = vec![];

    for k in 0..n/2 {
        let t = Complex::new(0.0, -2.0 * std::f64::consts::PI * (k as f64) / (n as f64))) ).exp();

        output.push(even_fft[k] + odd_fft[k] * t]);
        output.push(even_fft[k] - odd_fft[k] * t]);
    }

    return output;
}
```

This implementation uses the Cooley-Tukey algorithm to perform the FFT. The function takes an array of complex numbers and returns an array of complex numbers which is the result of the FFT.
#+end_src

* Citation
#+begin_src
@inproceedings{zheng2023codegeex,
  title={CodeGeeX: A Pre-Trained Model for Code Generation with Multilingual Benchmarking on HumanEval-X},
  author={Qinkai Zheng and Xiao Xia and Xu Zou and Yuxiao Dong and Shan Wang and Yufei Xue and Zihan Wang and Lei Shen and Andi Wang and Yang Li and Teng Su and Zhilin Yang and Jie Tang},
  booktitle={Proceedings of the 29th ACM SIGKDD Conference on Knowledge Discovery and Data Mining},
  pages={5673--5684},
  year={2023}
}
#+end_src
candle-examples/examples/codegeex4-9b/main.rs (new file, 252 lines)

@@ -0,0 +1,252 @@
use candle_transformers::models::codegeex4_9b::*;
use clap::Parser;

use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{Repo, RepoType};
use tokenizers::Tokenizer;

struct TextGeneration {
    model: Model,
    device: Device,
    tokenizer: Tokenizer,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
    verbose_prompt: bool,
    dtype: DType,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        verbose_prompt: bool,
        device: &Device,
        dtype: DType,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            tokenizer,
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            verbose_prompt,
            device: device.clone(),
            dtype,
        }
    }

    fn run(&mut self, prompt: &str, sample_len: usize) -> anyhow::Result<()> {
        use std::io::Write;
        println!("starting the inference loop");
        let tokens = self.tokenizer.encode(prompt, true).expect("tokens error");
        if tokens.is_empty() {
            panic!("Empty prompts are not supported in the chatglm model.")
        }
        if self.verbose_prompt {
            for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
                let token = token.replace('▁', " ").replace("<0x0A>", "\n");
                println!("{id:7} -> '{token}'");
            }
        }
        let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
            Some(token) => *token,
            None => panic!("cannot find the endoftext token"),
        };
        let mut tokens = tokens.get_ids().to_vec();
        let mut generated_tokens = 0usize;

        print!("{prompt}");
        std::io::stdout().flush().expect("output flush error");
        let start_gen = std::time::Instant::now();

        println!("\n start_gen");
        println!("samplelen {}", sample_len);
        let mut count = 0;
        let mut result = vec![];
        for index in 0..sample_len {
            count += 1;
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = self.model.forward(&input)?;
            let logits = logits.squeeze(0)?.to_dtype(self.dtype)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token {
                break;
            }
            let token = self
                .tokenizer
                .decode(&[next_token], true)
                .expect("Token error");
            if self.verbose_prompt {
                println!(
                    "[Count: {}] [Raw Token: {}] [Decode Token: {}]",
                    count, next_token, token
                );
            }
            result.push(token);
            std::io::stdout().flush()?;
        }
        let dt = start_gen.elapsed();
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        println!("Result:");
        for tokens in result {
            print!("{tokens}");
        }
        Ok(())
    }
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    #[arg(name = "cache", short, long, default_value = ".")]
    cache_path: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Display the token for the specified prompt.
    #[arg(long)]
    verbose_prompt: bool,

    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 5000)]
    sample_len: usize,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long)]
    revision: Option<String>,

    #[arg(long)]
    weight_file: Option<String>,

    #[arg(long)]
    tokenizer: Option<String>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.95),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    println!("cache path {}", args.cache_path);
    let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into()))
        .build()
        .map_err(anyhow::Error::msg)?;

    let model_id = match args.model_id {
        Some(model_id) => model_id.to_string(),
        None => "THUDM/codegeex4-all-9b".to_string(),
    };
    let revision = match args.revision {
        Some(rev) => rev.to_string(),
        None => "main".to_string(),
    };
    let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
    let tokenizer_filename = match args.tokenizer {
        Some(file) => std::path::PathBuf::from(file),
        None => api
            .model("THUDM/codegeex4-all-9b".to_string())
            .get("tokenizer.json")
            .map_err(anyhow::Error::msg)?,
    };
    let filenames = match args.weight_file {
        Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
        None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_filename).expect("Tokenizer Error");

    let start = std::time::Instant::now();
    let config = Config::codegeex4();
    let device = candle_examples::device(args.cpu)?;
    let dtype = if device.is_cuda() {
        DType::BF16
    } else {
        DType::F32
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
    let model = Model::new(&config, vb)?;

    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        args.verbose_prompt,
        &device,
        dtype,
    );
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}
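The generation loop in run above feeds the whole prompt on the first step and then a single token per step, relying on the model's internal KV cache for all earlier positions. A schematic of that pattern with the model call stubbed out (forward_one is hypothetical):

```rust
// Incremental decoding: full prompt once, then one token at a time.
fn decode(prompt: &[u32], sample_len: usize) -> Vec<u32> {
    let mut tokens = prompt.to_vec();
    for index in 0..sample_len {
        // Step 0 processes the whole prompt; later steps only the newest
        // token, since earlier positions are cached inside the model.
        let context_size = if index > 0 { 1 } else { tokens.len() };
        let ctxt = &tokens[tokens.len() - context_size..];
        let next = forward_one(ctxt);
        tokens.push(next);
    }
    tokens
}

fn forward_one(_ctxt: &[u32]) -> u32 {
    0 // stub standing in for the model forward pass + sampling
}

fn main() {
    assert_eq!(decode(&[1, 2, 3], 2).len(), 5);
}
```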

candle-examples/examples/depth_anything_v2/README.md (new file, 13 lines)

@@ -0,0 +1,13 @@
# candle-depth-anything-v2

[Depth Anything V2] is a model for Monocular Depth Estimation (MDE, i.e. using just a single image) which
builds on the [DINOv2](https://github.com/facebookresearch/dinov2) vision transformer.

This example first instantiates the DINOv2 model and then creates the DepthAnythingV2 model and runs it.

## Running an example with color map and CUDA

```bash
cargo run --features cuda,depth_anything_v2 --package candle-examples --example depth_anything_v2 -- --color-map --image candle-examples/examples/yolo-v8/assets/bike.jpg
```

candle-examples/examples/depth_anything_v2/color_map.rs (new file, 50 lines)

@@ -0,0 +1,50 @@
use enterpolation::linear::ConstEquidistantLinear;
use enterpolation::Generator;
use palette::LinSrgb;

use candle::Tensor;

pub struct SpectralRColormap {
    gradient: ConstEquidistantLinear<f32, LinSrgb, 9>,
}

impl SpectralRColormap {
    pub(crate) fn new() -> Self {
        // Define a colormap similar to 'Spectral_r' by specifying key colors.
        // The colors were taken from ChatGPT-4o.
        let gradient = ConstEquidistantLinear::<f32, _, 9>::equidistant_unchecked([
            LinSrgb::new(0.3686, 0.3098, 0.6353), // Dark blue
            LinSrgb::new(0.1961, 0.5333, 0.7412), // Blue
            LinSrgb::new(0.4000, 0.7608, 0.6471), // Cyan
            LinSrgb::new(0.6706, 0.8667, 0.6431), // Green
            LinSrgb::new(0.9020, 0.9608, 0.5961), // Yellow
            LinSrgb::new(0.9961, 0.8784, 0.5451), // Orange
            LinSrgb::new(0.9922, 0.6824, 0.3804), // Red
            LinSrgb::new(0.9569, 0.4275, 0.2627), // Dark red
            LinSrgb::new(0.8353, 0.2431, 0.3098), // Dark purple
        ]);
        Self { gradient }
    }

    fn get_color(&self, value: f32) -> LinSrgb {
        self.gradient.gen(value)
    }

    pub fn gray2color(&self, gray: &Tensor) -> candle::Result<Tensor> {
        println!("Gray: {:?}", gray.dims());
        let gray_values: Vec<f32> = gray.flatten_all()?.to_vec1()?;
        let rgb_values: Vec<f32> = gray_values
            .iter()
            .map(|g| self.get_color(*g))
            .flat_map(|rgb| [rgb.red, rgb.green, rgb.blue])
            .collect();

        let [.., height, width] = gray.dims() else {
            candle::bail!("Not enough dims!")
        };

        let color = Tensor::from_vec(rgb_values, (*height, *width, 3), gray.device())?;

        color.permute((2, 0, 1))
    }
}
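The gradient above interpolates linearly between nine equidistant key colors over [0, 1], which is what the enterpolation call provides per channel. A scalar sketch of that piecewise-linear lookup (hypothetical helper, one channel only):

```rust
// Piecewise-linear interpolation over equidistant key values on [0, 1].
fn lerp_equidistant(keys: &[f32], t: f32) -> f32 {
    let n = keys.len() - 1;
    let x = t.clamp(0.0, 1.0) * n as f32;
    let i = (x.floor() as usize).min(n - 1);
    let frac = x - i as f32;
    keys[i] * (1.0 - frac) + keys[i + 1] * frac
}

fn main() {
    // E.g. the green channel of a tiny three-color map.
    let keys = [0.0f32, 1.0, 0.0];
    assert_eq!(lerp_equidistant(&keys, 0.25), 0.5);
    assert_eq!(lerp_equidistant(&keys, 0.5), 1.0);
}
```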

candle-examples/examples/depth_anything_v2/main.rs (new file, 187 lines)

@@ -0,0 +1,187 @@
//! Depth Anything V2
//! https://huggingface.co/spaces/depth-anything/Depth-Anything-V2

#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use std::ffi::OsString;
use std::path::PathBuf;

use clap::Parser;

use candle::DType::{F32, U8};
use candle::{DType, Device, Module, Result, Tensor};
use candle_examples::{load_image, load_image_and_resize, save_image};
use candle_nn::VarBuilder;
use candle_transformers::models::depth_anything_v2::{DepthAnythingV2, DepthAnythingV2Config};
use candle_transformers::models::dinov2;

use crate::color_map::SpectralRColormap;

mod color_map;

// taken these from: https://huggingface.co/spaces/depth-anything/Depth-Anything-V2/blob/main/depth_anything_v2/dpt.py#L207
const MAGIC_MEAN: [f32; 3] = [0.485, 0.456, 0.406];
const MAGIC_STD: [f32; 3] = [0.229, 0.224, 0.225];

const DINO_IMG_SIZE: usize = 518;

#[derive(Parser)]
struct Args {
    #[arg(long)]
    dinov2_model: Option<PathBuf>,

    #[arg(long)]
    depth_anything_v2_model: Option<PathBuf>,

    #[arg(long)]
    image: PathBuf,

    #[arg(long)]
    output_dir: Option<PathBuf>,

    #[arg(long)]
    cpu: bool,

    #[arg(long)]
    color_map: bool,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    let device = candle_examples::device(args.cpu)?;

    let dinov2_model_file = match args.dinov2_model {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model("lmz/candle-dino-v2".into());
            api.get("dinov2_vits14.safetensors")?
        }
        Some(dinov2_model) => dinov2_model,
    };
    println!("Using file {:?}", dinov2_model_file);

    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[dinov2_model_file], F32, &device)? };
    let dinov2 = dinov2::vit_small(vb)?;
    println!("DinoV2 model built");

    let depth_anything_model_file = match args.depth_anything_v2_model {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model("jeroenvlek/depth-anything-v2-safetensors".into());
            api.get("depth_anything_v2_vits.safetensors")?
        }
        Some(depth_anything_model) => depth_anything_model,
    };
    println!("Using file {:?}", depth_anything_model_file);

    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&[depth_anything_model_file], DType::F32, &device)?
    };

    let config = DepthAnythingV2Config::vit_small();
    let depth_anything = DepthAnythingV2::new(&dinov2, &config, vb)?;

    let (original_height, original_width, image) = load_and_prep_image(&args.image, &device)?;

    println!("Loaded image {image:?}");

    let depth = depth_anything.forward(&image)?;

    println!("Got predictions {:?}", depth.shape());

    let output_image = post_process_image(&depth, original_height, original_width, args.color_map)?;

    let output_path = full_output_path(&args.image, &args.output_dir);
    println!("Saving image to {}", output_path.to_string_lossy());
    save_image(&output_image, output_path)?;

    Ok(())
}

fn full_output_path(image_path: &PathBuf, output_dir: &Option<PathBuf>) -> PathBuf {
    let input_file_name = image_path.file_name().unwrap();
    let mut output_file_name = OsString::from("depth_");
    output_file_name.push(input_file_name);
    let mut output_path = match output_dir {
        None => image_path.parent().unwrap().to_path_buf(),
        Some(output_path) => output_path.clone(),
    };
    output_path.push(output_file_name);

    output_path
}

fn load_and_prep_image(
    image_path: &PathBuf,
    device: &Device,
) -> anyhow::Result<(usize, usize, Tensor)> {
    let (_original_image, original_height, original_width) = load_image(&image_path, None)?;

    let image = load_image_and_resize(&image_path, DINO_IMG_SIZE, DINO_IMG_SIZE)?
        .unsqueeze(0)?
        .to_dtype(F32)?
        .to_device(&device)?;

    let max_pixel_val = Tensor::try_from(255.0f32)?
        .to_device(&device)?
        .broadcast_as(image.shape())?;
    let image = (image / max_pixel_val)?;
    let image = normalize_image(&image, &MAGIC_MEAN, &MAGIC_STD)?;

    Ok((original_height, original_width, image))
}

fn normalize_image(image: &Tensor, mean: &[f32; 3], std: &[f32; 3]) -> Result<Tensor> {
    let mean_tensor =
        Tensor::from_vec(mean.to_vec(), (3, 1, 1), &image.device())?.broadcast_as(image.shape())?;
    let std_tensor =
        Tensor::from_vec(std.to_vec(), (3, 1, 1), &image.device())?.broadcast_as(image.shape())?;
    image.sub(&mean_tensor)?.div(&std_tensor)
}

fn post_process_image(
    image: &Tensor,
    original_height: usize,
    original_width: usize,
    color_map: bool,
) -> Result<Tensor> {
    let out = image.interpolate2d(original_height, original_width)?;
    let out = scale_image(&out)?;

    let out = if color_map {
        let spectral_r = SpectralRColormap::new();
        spectral_r.gray2color(&out)?
    } else {
        let rgb_slice = [&out, &out, &out];
        Tensor::cat(&rgb_slice, 0)?.squeeze(1)?
    };

    let max_pixel_val = Tensor::try_from(255.0f32)?
        .to_device(out.device())?
        .broadcast_as(out.shape())?;
    let out = (out * max_pixel_val)?;

    out.to_dtype(U8)
}

fn scale_image(depth: &Tensor) -> Result<Tensor> {
    let flat_values: Vec<f32> = depth.flatten_all()?.to_vec1()?;

    let min_val = flat_values.iter().min_by(|a, b| a.total_cmp(b)).unwrap();
    let max_val = flat_values.iter().max_by(|a, b| a.total_cmp(b)).unwrap();

    let min_val_tensor = Tensor::try_from(*min_val)?
        .to_device(depth.device())?
        .broadcast_as(depth.shape())?;
    let depth = (depth - min_val_tensor)?;

    let range = max_val - min_val;
    let range_tensor = Tensor::try_from(range)?
        .to_device(depth.device())?
        .broadcast_as(depth.shape())?;

    depth / range_tensor
}
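scale_image above is a plain min-max normalization of the depth map to [0, 1] before it is mapped to colors or grayscale. The same operation on scalars:

```rust
// Min-max scaling: map the smallest value to 0 and the largest to 1.
fn scale(xs: &mut [f32]) {
    let min = xs.iter().cloned().fold(f32::INFINITY, f32::min);
    let max = xs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    for x in xs.iter_mut() {
        *x = (*x - min) / (max - min);
    }
}

fn main() {
    let mut v = [2.0f32, 4.0, 6.0];
    scale(&mut v);
    assert_eq!(v, [0.0, 0.5, 1.0]);
}
```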

candle-examples/examples/dinov2reg4/README.md (new file, 25 lines)

@@ -0,0 +1,25 @@
# candle-dinov2-reg4

[DINOv2-reg4](https://arxiv.org/abs/2309.16588) is the latest version of DINOv2 with registers.
In this example, it is used as a plant species classifier: the model returns the
probability for the image to belong to each of the 7806 PlantCLEF2024 categories.

## Running an example

```bash
# Download class names and a plant picture to identify
curl https://huggingface.co/vincent-espitalier/dino-v2-reg4-with-plantclef2024-weights/raw/main/species_id_mapping.txt --output candle-examples/examples/dinov2reg4/species_id_mapping.txt
curl https://bs.plantnet.org/image/o/bd2d3830ac3270218ba82fd24e2290becd01317c --output candle-examples/examples/dinov2reg4/bd2d3830ac3270218ba82fd24e2290becd01317c.jpg

# Perform inference
cargo run --example dinov2reg4 --release -- --image candle-examples/examples/dinov2reg4/bd2d3830ac3270218ba82fd24e2290becd01317c.jpg

> Orchis simia Lam. : 45.55%
> Orchis × bergonii Nanteuil: 9.80%
> Orchis italica Poir. : 9.66%
> Orchis × angusticruris Franch.: 2.76%
> Orchis × bivonae Tod. : 2.54%

```

candle-examples/examples/dinov2reg4/main.rs (new file, 70 lines)

@@ -0,0 +1,70 @@
//! DINOv2 reg4 finetuned on PlantCLEF 2024
//! https://arxiv.org/abs/2309.16588
//! https://huggingface.co/spaces/BVRA/PlantCLEF2024
//! https://zenodo.org/records/10848263

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::Parser;

use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::dinov2reg4;

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image518(args.image)?.to_device(&device)?;
    println!("loaded image {image:?}");

    let f_species_id_mapping = "candle-examples/examples/dinov2reg4/species_id_mapping.txt";
    let classes: Vec<String> = std::fs::read_to_string(f_species_id_mapping)
        .expect("missing classes file")
        .split('\n')
        .map(|s| s.to_string())
        .collect();

    let model_file = match args.model {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api =
                api.model("vincent-espitalier/dino-v2-reg4-with-plantclef2024-weights".into());
            api.get(
                "vit_base_patch14_reg4_dinov2_lvd142m_pc24_onlyclassifier_then_all.safetensors",
            )?
        }
        Some(model) => model.into(),
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let model = dinov2reg4::vit_base(vb)?;
    println!("model built");
    let logits = model.forward(&image.unsqueeze(0)?)?;
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;
    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    for &(category_idx, pr) in prs.iter().take(5) {
        println!("{:24}: {:.2}%", classes[category_idx], 100. * pr);
    }
    Ok(())
}

candle-examples/examples/eva2/README.md (new file, 21 lines)

@@ -0,0 +1,21 @@
# candle-eva2

[EVA-02](https://arxiv.org/abs/2303.11331) is a computer vision model.
In this example, it is used as an ImageNet classifier: the model returns the
probability for the image to belong to each of the 1000 ImageNet categories.

## Running an example

```bash
cargo run --example eva2 --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg

> mountain bike, all-terrain bike, off-roader: 37.09%
> maillot : 8.30%
> alp : 2.13%
> bicycle-built-for-two, tandem bicycle, tandem: 0.84%
> crash helmet : 0.73%

```

candle-examples/examples/eva2/main.rs (new file, 82 lines)

@@ -0,0 +1,82 @@
//! EVA-02: Explore the limits of Visual representation at scAle
//! https://github.com/baaivision/EVA

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::Parser;

use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::eva2;

/// Loads an image from disk using the image crate; this returns a tensor with shape
/// (3, 448, 448). OpenAI normalization is applied.
pub fn load_image448_openai_norm<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
    let img = image::ImageReader::open(p)?
        .decode()
        .map_err(candle::Error::wrap)?
        .resize_to_fill(448, 448, image::imageops::FilterType::Triangle);
    let img = img.to_rgb8();
    let data = img.into_raw();
    let data = Tensor::from_vec(data, (448, 448, 3), &Device::Cpu)?.permute((2, 0, 1))?;
    let mean =
        Tensor::new(&[0.48145466f32, 0.4578275, 0.40821073], &Device::Cpu)?.reshape((3, 1, 1))?;
    let std = Tensor::new(&[0.26862954f32, 0.261_302_6, 0.275_777_1], &Device::Cpu)?
        .reshape((3, 1, 1))?;
    (data.to_dtype(candle::DType::F32)? / 255.)?
        .broadcast_sub(&mean)?
        .broadcast_div(&std)
}

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let device = candle_examples::device(args.cpu)?;

    let image = load_image448_openai_norm(args.image)?.to_device(&device)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model("vincent-espitalier/candle-eva2".into());
            api.get("eva02_base_patch14_448.mim_in22k_ft_in22k_in1k_adapted.safetensors")?
        }
        Some(model) => model.into(),
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };

    let model = eva2::vit_base(vb)?;
    println!("model built");
    let logits = model.forward(&image.unsqueeze(0)?)?;
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;
    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    for &(category_idx, pr) in prs.iter().take(5) {
        println!(
            "{:24}: {:.2}%",
            candle_examples::imagenet::CLASSES[category_idx],
            100. * pr
        );
    }
    Ok(())
}
|
candle-examples/examples/flux/README.md (new file, 19 lines)
@@ -0,0 +1,19 @@
# candle-flux: image generation with latent rectified flow transformers

![](assets/flux-robot.jpg)

Flux is a 12B rectified flow transformer capable of generating images from text
descriptions. See the model card on
[huggingface](https://huggingface.co/black-forest-labs/FLUX.1-schnell), the
[github](https://github.com/black-forest-labs/flux) repository, and the
[blog post](https://blackforestlabs.ai/announcing-black-forest-labs/).

## Running the model

```bash
cargo run --features cuda --example flux -r -- \
  --height 1024 --width 1024 \
  --prompt "a rusty robot walking on a beach holding a small torch, the robot has the word \"rust\" written on it, high quality, 4k"
```
candle-examples/examples/flux/assets/flux-robot.jpg (new binary file, 90 KiB, not shown)
candle-examples/examples/flux/main.rs (new file, 210 lines)
@@ -0,0 +1,210 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use candle_transformers::models::{clip, flux, t5};

use anyhow::{Error as E, Result};
use candle::{IndexOp, Module, Tensor};
use candle_nn::VarBuilder;
use clap::Parser;
use tokenizers::Tokenizer;

#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// The prompt to be used for image generation.
    #[arg(long, default_value = "A rusty robot walking on a beach")]
    prompt: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// The height in pixels of the generated image.
    #[arg(long)]
    height: Option<usize>,

    /// The width in pixels of the generated image.
    #[arg(long)]
    width: Option<usize>,

    #[arg(long)]
    decode_only: Option<String>,

    #[arg(long, value_enum, default_value = "schnell")]
    model: Model,
}

#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)]
enum Model {
    Schnell,
    Dev,
}

fn run(args: Args) -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let Args {
        prompt,
        cpu,
        height,
        width,
        tracing,
        decode_only,
        model,
    } = args;
    let width = width.unwrap_or(1360);
    let height = height.unwrap_or(768);

    let _guard = if tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    let api = hf_hub::api::sync::Api::new()?;
    let bf_repo = {
        let name = match model {
            Model::Dev => "black-forest-labs/FLUX.1-dev",
            Model::Schnell => "black-forest-labs/FLUX.1-schnell",
        };
        api.repo(hf_hub::Repo::model(name.to_string()))
    };
    let device = candle_examples::device(cpu)?;
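    // Prefer bf16 on devices that support it and fall back to f32 elsewhere.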
    let dtype = device.bf16_default_to_f32();
    let img = match decode_only {
        None => {
            let t5_emb = {
                let repo = api.repo(hf_hub::Repo::with_revision(
                    "google/t5-v1_1-xxl".to_string(),
                    hf_hub::RepoType::Model,
                    "refs/pr/2".to_string(),
                ));
                let model_file = repo.get("model.safetensors")?;
                let vb =
                    unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
                let config_filename = repo.get("config.json")?;
                let config = std::fs::read_to_string(config_filename)?;
                let config: t5::Config = serde_json::from_str(&config)?;
                let mut model = t5::T5EncoderModel::load(vb, &config)?;
                let tokenizer_filename = api
                    .model("lmz/mt5-tokenizers".to_string())
                    .get("t5-v1_1-xxl.tokenizer.json")?;
                let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
                let mut tokens = tokenizer
                    .encode(prompt.as_str(), true)
                    .map_err(E::msg)?
                    .get_ids()
                    .to_vec();
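                // Pad (or truncate) the prompt to a fixed 256-token context for the T5 encoder.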
                tokens.resize(256, 0);
                let input_token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
                println!("{input_token_ids}");
                model.forward(&input_token_ids)?
            };
            println!("T5\n{t5_emb}");
            let clip_emb = {
                let repo = api.repo(hf_hub::Repo::model(
                    "openai/clip-vit-large-patch14".to_string(),
                ));
                let model_file = repo.get("model.safetensors")?;
                let vb =
                    unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
                // https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
                let config = clip::text_model::ClipTextConfig {
                    vocab_size: 49408,
                    projection_dim: 768,
                    activation: clip::text_model::Activation::QuickGelu,
                    intermediate_size: 3072,
                    embed_dim: 768,
                    max_position_embeddings: 77,
                    pad_with: None,
                    num_hidden_layers: 12,
                    num_attention_heads: 12,
                };
                let model =
                    clip::text_model::ClipTextTransformer::new(vb.pp("text_model"), &config)?;
                let tokenizer_filename = repo.get("tokenizer.json")?;
                let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
                let tokens = tokenizer
                    .encode(prompt.as_str(), true)
                    .map_err(E::msg)?
                    .get_ids()
                    .to_vec();
                let input_token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
                println!("{input_token_ids}");
                model.forward(&input_token_ids)?
            };
            println!("CLIP\n{clip_emb}");
            let img = {
                let model_file = match model {
                    Model::Schnell => bf_repo.get("flux1-schnell.sft")?,
                    Model::Dev => bf_repo.get("flux1-dev.sft")?,
                };
                let vb =
                    unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
                let cfg = match model {
                    Model::Dev => flux::model::Config::dev(),
                    Model::Schnell => flux::model::Config::schnell(),
                };
                let img = flux::sampling::get_noise(1, height, width, &device)?.to_dtype(dtype)?;
                let state = flux::sampling::State::new(&t5_emb, &clip_emb, &img)?;
                let timesteps = match model {
                    Model::Dev => {
                        flux::sampling::get_schedule(50, Some((state.img.dim(1)?, 0.5, 1.15)))
                    }
                    Model::Schnell => flux::sampling::get_schedule(4, None),
                };
                let model = flux::model::Flux::new(&cfg, vb)?;

                println!("{state:?}");
                println!("{timesteps:?}");
                flux::sampling::denoise(
                    &model,
                    &state.img,
                    &state.img_ids,
                    &state.txt,
                    &state.txt_ids,
                    &state.vec,
                    &timesteps,
                    4.,
                )?
            };
            flux::sampling::unpack(&img, height, width)?
        }
        Some(file) => {
            let mut st = candle::safetensors::load(file, &device)?;
            st.remove("img").unwrap().to_dtype(dtype)?
        }
    };
    println!("latent img\n{img}");

    let img = {
        let model_file = bf_repo.get("ae.sft")?;
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
        let cfg = match model {
            Model::Dev => flux::autoencoder::Config::dev(),
            Model::Schnell => flux::autoencoder::Config::schnell(),
        };
        let model = flux::autoencoder::AutoEncoder::new(&cfg, vb)?;
        model.decode(&img)?
    };
    println!("img\n{img}");
    let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(candle::DType::U8)?;
    candle_examples::save_image(&img.i(0)?, "out.jpg")?;
    Ok(())
}

fn main() -> Result<()> {
    let args = Args::parse();
    run(args)
}
@@ -193,6 +193,9 @@ struct Args {
     /// The model to use.
     #[arg(long, default_value = "2b")]
     which: Which,
+
+    #[arg(long)]
+    use_flash_attn: bool,
 }

 fn main() -> Result<()> {
@@ -270,7 +273,7 @@ fn main() -> Result<()> {
         DType::F32
     };
     let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
-    let model = Model::new(&config, vb)?;
+    let model = Model::new(args.use_flash_attn, &config, vb)?;

     println!("loaded the model in {:?}", start.elapsed());
candle-examples/examples/gte-qwen/README.md (new file, 19 lines)
@@ -0,0 +1,19 @@
# gte-Qwen1.5-7B-instruct

gte-Qwen1.5-7B-instruct is a variant of the GTE embedding model family.

- [Model card](https://huggingface.co/Alibaba-NLP/gte-Qwen1.5-7B-instruct) on the HuggingFace Hub.
- [Technical report](https://arxiv.org/abs/2308.03281) *Towards General Text Embeddings with Multi-stage Contrastive Learning*

## Running the example

Automatically download the model from the HuggingFace hub:
```bash
$ cargo run --example gte-qwen --release
```

or, load the model from a local directory:
```bash
cargo run --example gte-qwen --release --features cuda -- --local-repo /path/to/gte_Qwen1.5-7B-instruct/
```
candle-examples/examples/gte-qwen/main.rs (new file, 178 lines)
@@ -0,0 +1,178 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::Parser;

use candle_transformers::models::qwen2::{Config, Model};

use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::{
    utils::padding::{PaddingDirection, PaddingParams, PaddingStrategy},
    Tokenizer,
};

// gte-Qwen1.5-7B-instruct uses the EOS token as its padding token.
const EOS_TOKEN: &str = "<|endoftext|>";
const EOS_TOKEN_ID: u32 = 151643;

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long, default_value = "Alibaba-NLP/gte-Qwen1.5-7B-instruct")]
    model_id: String,

    #[arg(long, default_value = "main")]
    revision: String,

    #[arg(long)]
    local_repo: Option<String>,
}

#[derive(Debug)]
struct ConfigFiles {
    pub config: std::path::PathBuf,
    pub tokenizer: std::path::PathBuf,
    pub weights: Vec<std::path::PathBuf>,
}

// Loading the model from the HuggingFace Hub. Network access is required.
fn load_from_hub(model_id: &str, revision: &str) -> Result<ConfigFiles> {
    let api = Api::new()?;
    let repo = api.repo(Repo::with_revision(
        model_id.to_string(),
        RepoType::Model,
        revision.to_string(),
    ));
    Ok(ConfigFiles {
        config: repo.get("config.json")?,
        tokenizer: repo.get("tokenizer.json")?,
        weights: candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
    })
}

// Loading the model from a local directory.
fn load_from_local(local_path: &str) -> Result<ConfigFiles> {
    let local_path = std::path::PathBuf::from(local_path);
    let weight_path = local_path.join("model.safetensors.index.json");
    let json: serde_json::Value = serde_json::from_str(&std::fs::read_to_string(weight_path)?)?;
    let weight_map = match json.get("weight_map") {
        Some(serde_json::Value::Object(map)) => map,
        Some(_) => panic!("`weight map` is not a map"),
        None => panic!("`weight map` not found"),
    };
    let mut safetensors_files = std::collections::HashSet::new();
    for value in weight_map.values() {
        safetensors_files.insert(
            value
                .as_str()
                .expect("Weight files should be parsed as strings"),
        );
    }
    let safetensors_paths = safetensors_files
        .iter()
        .map(|v| local_path.join(v))
        .collect::<Vec<_>>();
    Ok(ConfigFiles {
        config: local_path.join("config.json"),
        tokenizer: local_path.join("tokenizer.json"),
        weights: safetensors_paths,
    })
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    // Fetch the model. Do this offline if a local path is provided.
    println!("Fetching model files...");
    let start = std::time::Instant::now();
    let config_files = match args.local_repo {
        Some(local_path) => load_from_local(&local_path)?,
        None => load_from_hub(&args.model_id, &args.revision)?,
    };
    println!("Model file retrieved in {:?}", start.elapsed());

    // Inputs will be padded to the longest sequence in the batch.
    let padding = PaddingParams {
        strategy: PaddingStrategy::BatchLongest,
        direction: PaddingDirection::Left,
        pad_to_multiple_of: None,
        pad_id: EOS_TOKEN_ID,
        pad_type_id: 0,
        pad_token: String::from(EOS_TOKEN),
    };

    // Tokenizer setup
    let mut tokenizer = Tokenizer::from_file(config_files.tokenizer).map_err(E::msg)?;
    tokenizer.with_padding(Some(padding));

    // Model initialization
    let device = candle_examples::device(args.cpu)?;
    let dtype = if device.is_cuda() {
        DType::BF16
    } else {
        DType::F32
    };
    let config: Config = serde_json::from_slice(&std::fs::read(config_files.config)?)?;
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&config_files.weights, dtype, &device)? };
    let mut model = Model::new(&config, vb)?;
    println!("Model loaded in {:?}", start.elapsed());

    // Encode the queries and the targets
    let instruct = "Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery: ";
    let documents = vec![
        format!("{instruct}how much protein should a female eat{EOS_TOKEN}"),
        format!("{instruct}summit define{EOS_TOKEN}"),
        format!("As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.{EOS_TOKEN}"),
        format!("Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.{EOS_TOKEN}"),
    ];
    let encoded = tokenizer.encode_batch(documents, true).map_err(E::msg)?;
    let tokens: Vec<&[u32]> = encoded.iter().map(|x| x.get_ids()).collect();
    let tokens = Tensor::new(tokens, &device)?;
    let mask: Vec<&[u32]> = encoded.iter().map(|x| x.get_attention_mask()).collect();
    let mask = Tensor::new(mask, &device)?;

    // Inference
    let start_gen = std::time::Instant::now();
    let logits = model.forward(&tokens, 0, Some(&mask))?;

    // Extract the last hidden states as embeddings since inputs are padded left.
    let (_, seq_len, _) = logits.dims3()?;
    let embd = logits
        .narrow(1, seq_len - 1, 1)?
        .squeeze(1)?
        .to_dtype(DType::F32)?;

    // Calculate the relevancy scores. Note that the embeddings should be normalized.
    let norm = embd.broadcast_div(&embd.sqr()?.sum_keepdim(1)?.sqrt()?)?;
    let scores = norm.narrow(0, 0, 2)?.matmul(&norm.narrow(0, 2, 2)?.t()?)?;

    // Print the results
    println!("Embedding done in {:?}", start_gen.elapsed());
    println!("Scores: {:?}", scores.to_vec2::<f32>()?);

    Ok(())
}
candle-examples/examples/hiera/README.md (new file, 18 lines)
@@ -0,0 +1,18 @@
# hiera

[Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles](https://arxiv.org/abs/2306.00989)
This candle implementation uses pre-trained Hiera models from timm for inference.
The classification head has been trained on the ImageNet dataset and returns the probabilities for the top-5 classes.

## Running an example

```
$ cargo run --example hiera --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which tiny
loaded image Tensor[dims 3, 224, 224; f32]
model built
mountain bike, all-terrain bike, off-roader: 71.15%
unicycle, monocycle     : 7.11%
knee pad                : 4.26%
crash helmet            : 1.48%
moped                   : 1.07%
```
candle-examples/examples/hiera/main.rs (new file, 99 lines)
@@ -0,0 +1,99 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::{Parser, ValueEnum};

use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::hiera;

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
    Tiny,
    Small,
    Base,
    BasePlus,
    Large,
    Huge,
}

impl Which {
    fn model_filename(&self) -> String {
        let name = match self {
            Self::Tiny => "tiny",
            Self::Small => "small",
            Self::Base => "base",
            Self::BasePlus => "base_plus",
            Self::Large => "large",
            Self::Huge => "huge",
        };
        format!("timm/hiera_{}_224.mae_in1k_ft_in1k", name)
    }

    fn config(&self) -> hiera::Config {
        match self {
            Self::Tiny => hiera::Config::tiny(),
            Self::Small => hiera::Config::small(),
            Self::Base => hiera::Config::base(),
            Self::BasePlus => hiera::Config::base_plus(),
            Self::Large => hiera::Config::large(),
            Self::Huge => hiera::Config::huge(),
        }
    }
}

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    #[arg(value_enum, long, default_value_t=Which::Tiny)]
    which: Which,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {
        None => {
            let model_name = args.which.model_filename();
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model(model_name);
            api.get("model.safetensors")?
        }
        Some(model) => model.into(),
    };

    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let model = hiera::hiera(&args.which.config(), 1000, vb)?;
    println!("model built");
    let logits = model.forward(&image.unsqueeze(0)?)?;
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;
    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    for &(category_idx, pr) in prs.iter().take(5) {
        println!(
            "{:24}: {:.2}%",
            candle_examples::imagenet::CLASSES[category_idx],
            100. * pr
        );
    }
    Ok(())
}
@@ -4,7 +4,7 @@ extern crate intel_mkl_src;
 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;

-use candle_transformers::models::jina_bert::{BertModel, Config};
+use candle_transformers::models::jina_bert::{BertModel, Config, PositionEmbeddingType};

 use anyhow::Error as E;
 use candle::{DType, Module, Tensor};
@@ -39,32 +39,47 @@ struct Args {

     #[arg(long)]
     model: Option<String>,
+
+    #[arg(long)]
+    model_file: Option<String>,
 }

 impl Args {
     fn build_model_and_tokenizer(&self) -> anyhow::Result<(BertModel, tokenizers::Tokenizer)> {
         use hf_hub::{api::sync::Api, Repo, RepoType};
-        let model = match &self.model {
+        let model_name = match self.model.as_ref() {
+            Some(model) => model.to_string(),
+            None => "jinaai/jina-embeddings-v2-base-en".to_string(),
+        };
+
+        let model = match &self.model_file {
             Some(model_file) => std::path::PathBuf::from(model_file),
             None => Api::new()?
-                .repo(Repo::new(
-                    "jinaai/jina-embeddings-v2-base-en".to_string(),
-                    RepoType::Model,
-                ))
+                .repo(Repo::new(model_name.to_string(), RepoType::Model))
                 .get("model.safetensors")?,
         };
         let tokenizer = match &self.tokenizer {
             Some(file) => std::path::PathBuf::from(file),
             None => Api::new()?
-                .repo(Repo::new(
-                    "sentence-transformers/all-MiniLM-L6-v2".to_string(),
-                    RepoType::Model,
-                ))
+                .repo(Repo::new(model_name.to_string(), RepoType::Model))
                 .get("tokenizer.json")?,
         };
         let device = candle_examples::device(self.cpu)?;
-        let config = Config::v2_base();
         let tokenizer = tokenizers::Tokenizer::from_file(tokenizer).map_err(E::msg)?;
+        let config = Config::new(
+            tokenizer.get_vocab_size(true),
+            768,
+            12,
+            12,
+            3072,
+            candle_nn::Activation::Gelu,
+            8192,
+            2,
+            0.02,
+            1e-12,
+            0,
+            PositionEmbeddingType::Alibi,
+        );
         let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
         let model = BertModel::new(vb, &config)?;
         Ok((model, tokenizer))
@@ -101,14 +116,20 @@ fn main() -> anyhow::Result<()> {
             .to_vec();
         let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
         println!("Loaded and encoded {:?}", start.elapsed());
-        for idx in 0..args.n {
-            let start = std::time::Instant::now();
-            let ys = model.forward(&token_ids)?;
-            if idx == 0 {
-                println!("{ys}");
-            }
-            println!("Took {:?}", start.elapsed());
-        }
+        let start = std::time::Instant::now();
+        let embeddings = model.forward(&token_ids)?;
+        let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
+        let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
+        println!("pooled_embeddings: {embeddings}");
+        let embeddings = if args.normalize_embeddings {
+            normalize_l2(&embeddings)?
+        } else {
+            embeddings
+        };
+        if args.normalize_embeddings {
+            println!("normalized_embeddings: {embeddings}");
+        }
+        println!("Took {:?}", start.elapsed());
     } else {
         let sentences = [
             "The cat sits outside",
@@ -17,7 +17,7 @@ use clap::{Parser, ValueEnum};

 use candle::{DType, Tensor};
 use candle_nn::VarBuilder;
-use candle_transformers::generation::LogitsProcessor;
+use candle_transformers::generation::{LogitsProcessor, Sampling};
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use std::io::Write;

@@ -31,6 +31,10 @@ const DEFAULT_PROMPT: &str = "My favorite theorem is ";
 enum Which {
     V1,
     V2,
+    V3,
+    V31,
+    V3Instruct,
+    V31Instruct,
     #[value(name = "solar-10.7b")]
     Solar10_7B,
     #[value(name = "tiny-llama-1.1b-chat")]
@@ -45,19 +49,23 @@ struct Args {
     cpu: bool,

     /// The temperature used to generate samples.
-    #[arg(long)]
-    temperature: Option<f64>,
+    #[arg(long, default_value_t = 0.8)]
+    temperature: f64,

     /// Nucleus sampling probability cutoff.
     #[arg(long)]
     top_p: Option<f64>,

+    /// Only sample among the top K samples.
+    #[arg(long)]
+    top_k: Option<usize>,
+
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,

     /// The length of the sample to generate (in tokens).
-    #[arg(long, default_value_t = 10000)]
+    #[arg(short = 'n', long, default_value_t = 10000)]
     sample_len: usize,

     /// Disable the key-value cache.
@@ -83,18 +91,18 @@ struct Args {
     revision: Option<String>,

     /// The model size to use.
-    #[arg(long, default_value = "v2")]
+    #[arg(long, default_value = "v3")]
     which: Which,

     #[arg(long)]
     use_flash_attn: bool,

     /// Penalty to be applied for repeating tokens, 1. means no penalty.
-    #[arg(long, default_value_t = 1.0)]
+    #[arg(long, default_value_t = 1.1)]
     repeat_penalty: f32,

     /// The context size to consider for the repeat penalty.
-    #[arg(long, default_value_t = 64)]
+    #[arg(long, default_value_t = 128)]
     repeat_last_n: usize,
 }

@@ -120,11 +128,15 @@ fn main() -> Result<()> {
         Some(dtype) => bail!("Unsupported dtype {dtype}"),
         None => DType::F16,
     };
-    let (llama, tokenizer_filename, mut cache) = {
+    let (llama, tokenizer_filename, mut cache, config) = {
         let api = Api::new()?;
         let model_id = args.model_id.unwrap_or_else(|| match args.which {
             Which::V1 => "Narsil/amall-7b".to_string(),
             Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
+            Which::V3 => "meta-llama/Meta-Llama-3-8B".to_string(),
+            Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(),
+            Which::V31 => "meta-llama/Meta-Llama-3.1-8B".to_string(),
+            Which::V31Instruct => "meta-llama/Meta-Llama-3.1-8B-Instruct".to_string(),
             Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
             Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
         });
@@ -138,7 +150,13 @@ fn main() -> Result<()> {
         let config = config.into_config(args.use_flash_attn);

         let filenames = match args.which {
-            Which::V1 | Which::V2 | Which::Solar10_7B => {
+            Which::V1
+            | Which::V2
+            | Which::V3
+            | Which::V3Instruct
+            | Which::V31
+            | Which::V31Instruct
+            | Which::Solar10_7B => {
                 candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
             }
             Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
@@ -146,10 +164,14 @@ fn main() -> Result<()> {
         let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;

         let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
-        (Llama::load(vb, &config)?, tokenizer_filename, cache)
+        (Llama::load(vb, &config)?, tokenizer_filename, cache, config)
     };
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
-    let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
+    let eos_token_id = config.eos_token_id.or_else(|| {
+        tokenizer
+            .token_to_id(EOS_TOKEN)
+            .map(model::LlamaEosToks::Single)
+    });
     let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
     let mut tokens = tokenizer
         .encode(prompt, true)
@@ -160,8 +182,22 @@ fn main() -> Result<()> {

     println!("starting the inference loop");
     print!("{prompt}");
-    let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
-    let start_gen = std::time::Instant::now();
+    let mut logits_processor = {
+        let temperature = args.temperature;
+        let sampling = if temperature <= 0. {
+            Sampling::ArgMax
+        } else {
+            match (args.top_k, args.top_p) {
+                (None, None) => Sampling::All { temperature },
+                (Some(k), None) => Sampling::TopK { k, temperature },
+                (None, Some(p)) => Sampling::TopP { p, temperature },
+                (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
+            }
+        };
+        LogitsProcessor::from_sampling(args.seed, sampling)
+    };
+
+    let mut start_gen = std::time::Instant::now();
     let mut index_pos = 0;
     let mut token_generated = 0;
     for index in 0..args.sample_len {
@@ -170,6 +206,9 @@ fn main() -> Result<()> {
         } else {
             (tokens.len(), 0)
         };
+        if index == 1 {
+            start_gen = std::time::Instant::now()
+        }
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
         let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
         let logits = llama.forward(&input, context_index, &mut cache)?;
@@ -190,8 +229,14 @@ fn main() -> Result<()> {
         token_generated += 1;
         tokens.push(next_token);

-        if Some(next_token) == eos_token_id {
-            break;
+        match eos_token_id {
+            Some(model::LlamaEosToks::Single(eos_tok_id)) if next_token == eos_tok_id => {
+                break;
+            }
+            Some(model::LlamaEosToks::Multiple(ref eos_ids)) if eos_ids.contains(&next_token) => {
+                break;
+            }
+            _ => (),
         }
         if let Some(t) = tokenizer.next_token(next_token)? {
             print!("{t}");
@@ -205,7 +250,7 @@ fn main() -> Result<()> {
     println!(
         "\n\n{} tokens generated ({} token/s)\n",
         token_generated,
-        token_generated as f64 / dt.as_secs_f64(),
+        (token_generated - 1) as f64 / dt.as_secs_f64(),
     );
     Ok(())
 }
@@ -10,7 +10,7 @@
 extern crate intel_mkl_src;

 use anyhow::{bail, Error as E, Result};
-use clap::Parser;
+use clap::{Parser, ValueEnum};

 use candle::{DType, Device, Tensor};
 use candle_transformers::generation::LogitsProcessor;
@@ -24,57 +24,15 @@ mod model;
 use model::{Config, Llama};

 const MAX_SEQ_LEN: usize = 4096;
-const DEFAULT_PROMPT: &str = r"
-EDWARD:
-I wonder how our princely father 'scaped,
-Or whether he be 'scaped away or no
-From Clifford's and Northumberland's pursuit:
-Had he been ta'en, we should have heard the news;
-Had he been slain, we should have heard the news;
-Or had he 'scaped, methinks we should have heard
-The happy tidings of his good escape.
-How fares my brother? why is he so sad?
-
-RICHARD:
-I cannot joy, until I be resolved
-Where our right valiant father is become.
-I saw him in the battle range about;
-And watch'd him how he singled Clifford forth.
-Methought he bore him in the thickest troop
-As doth a lion in a herd of neat;
-Or as a bear, encompass'd round with dogs,
-Who having pinch'd a few and made them cry,
-The rest stand all aloof, and bark at him.
-So fared our father with his enemies;
-So fled his enemies my warlike father:
-Methinks, 'tis prize enough to be his son.
-See how the morning opes her golden gates,
-And takes her farewell of the glorious sun!
-How well resembles it the prime of youth,
-Trimm'd like a younker prancing to his love!
-
-EDWARD:
-Dazzle mine eyes, or do I see three suns?
-
-RICHARD:
-Three glorious suns, each one a perfect sun;
-Not separated with the racking clouds,
-But sever'd in a pale clear-shining sky.
-See, see! they join, embrace, and seem to kiss,
-As if they vow'd some league inviolable:
-Now are they but one lamp, one light, one sun.
-In this the heaven figures some event.
-
-EDWARD:
-'Tis wondrous strange, the like yet never heard of.
-I think it cites us, brother, to the field,
-That we, the sons of brave Plantagenet,
-Each one already blazing by our meeds,
-Should notwithstanding join our lights together
-And over-shine the earth as this the world.
-Whate'er it bodes, henceforward will I bear
-Upon my target three fair-shining suns.
-";
+const DEFAULT_PROMPT: &str = "My favorite theorem is ";
+
+#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
+enum Which {
+    V2_7b,
+    V2_70b,
+    V3_8b,
+    V3_70b,
+}

 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
@@ -86,8 +44,8 @@ struct Args {
     rank: Option<usize>,

     /// The temperature used to generate samples.
-    #[arg(long)]
-    temperature: Option<f64>,
+    #[arg(long, default_value_t = 0.8)]
+    temperature: f64,

     /// Nucleus sampling probability cutoff.
     #[arg(long)]
@@ -117,6 +75,12 @@ struct Args {

     #[arg(long)]
     dtype: Option<String>,
+
+    #[arg(long, default_value = "v3-8b")]
+    which: Which,
+
+    #[arg(long, default_value = "nccl_id.txt")]
+    comm_file: String,
 }

 fn main() -> Result<()> {
@@ -129,14 +93,27 @@ fn main() -> Result<()> {
         Some("bf16") => DType::BF16,
         Some("f32") => DType::F32,
         Some(dtype) => bail!("Unsupported dtype {dtype}"),
-        None => DType::F16,
+        None => match args.which {
+            Which::V2_7b | Which::V2_70b => DType::F16,
+            Which::V3_8b | Which::V3_70b => DType::BF16,
+        },
     };

-    let api = Api::new()?;
+    let comm_file = std::path::PathBuf::from(&args.comm_file);
+    if comm_file.exists() {
+        bail!("comm file {comm_file:?} already exists, please remove it first")
+    }

-    let model_id = args
-        .model_id
-        .unwrap_or_else(|| "meta-llama/Llama-2-7b-hf".to_string());
+    let api = Api::new()?;
+    let model_id = match args.model_id {
+        Some(model) => model,
+        None => match args.which {
+            Which::V2_7b => "meta-llama/Llama-2-7b-hf".to_string(),
+            Which::V2_70b => "meta-llama/Llama-2-70b-hf".to_string(),
+            Which::V3_8b => "meta-llama/Meta-Llama-3-8B".to_string(),
+            Which::V3_70b => "meta-llama/Meta-Llama-3-70B".to_string(),
+        },
+    };
     println!("loading the model weights from {model_id}");
     let revision = args.revision.unwrap_or("main".to_string());
     let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
@@ -145,39 +122,40 @@ fn main() -> Result<()> {
     let tokenizer_filename = api.get("tokenizer.json")?;
     let filenames = candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;

-    if args.rank.is_none() {
-        let children: Vec<_> = (0..args.num_shards)
-            .map(|rank| {
-                let mut args: std::collections::VecDeque<_> = std::env::args().collect();
-                args.push_back("--rank".to_string());
-                args.push_back(format!("{rank}"));
-                let name = args.pop_front().unwrap();
-                std::process::Command::new(name).args(args).spawn().unwrap()
-            })
-            .collect();
-        for mut child in children {
-            child.wait().unwrap();
+    let rank = match args.rank {
+        None => {
+            println!("creating {} child processes", args.num_shards);
+            let children: Vec<_> = (0..args.num_shards)
+                .map(|rank| {
+                    let mut args: std::collections::VecDeque<_> = std::env::args().collect();
+                    args.push_back("--rank".to_string());
+                    args.push_back(format!("{rank}"));
+                    let name = args.pop_front().unwrap();
+                    std::process::Command::new(name).args(args).spawn().unwrap()
+                })
+                .collect();
+            for mut child in children {
+                child.wait()?;
+            }
+            return Ok(());
         }
-        return Ok(());
-    }
+        Some(rank) => rank,
+    };

-    let i = args.rank.unwrap();
     let num_shards = args.num_shards;
-    let rank = i;
     // Primitive IPC
     let id = if rank == 0 {
         let id = Id::new().unwrap();
-        std::fs::File::create("nccl_id.txt.tmp")?
-            .write_all(&id.internal().iter().map(|&i| i as u8).collect::<Vec<_>>())
-            .unwrap();
-        std::fs::rename("nccl_id.txt.tmp", "nccl_id.txt")?;
+        let tmp_file = comm_file.with_extension(".comm.tgz");
+        std::fs::File::create(&tmp_file)?
+            .write_all(&id.internal().iter().map(|&i| i as u8).collect::<Vec<_>>())?;
+        std::fs::rename(&tmp_file, &comm_file)?;
         id
     } else {
-        let path = std::path::PathBuf::from("nccl_id.txt");
-        while !path.exists() {
+        while !comm_file.exists() {
             std::thread::sleep(std::time::Duration::from_secs(1));
         }
-        let data = std::fs::read("nccl_id.txt")?;
+        let data = std::fs::read(&comm_file)?;
         let internal: [i8; 128] = data
             .into_iter()
             .map(|i| i as i8)
@@ -187,14 +165,17 @@ fn main() -> Result<()> {
         let id: Id = Id::uninit(internal);
         id
     };
-    let device = CudaDevice::new(i)?;
-    let comm = Rc::new(Comm::from_rank(device, i, num_shards, id).unwrap());
+    let device = CudaDevice::new(rank)?;
+    let comm = match Comm::from_rank(device, rank, num_shards, id) {
+        Ok(comm) => Rc::new(comm),
+        Err(err) => anyhow::bail!("nccl error {:?}", err.0),
+    };
     if rank == 0 {
-        std::fs::remove_file("nccl_id.txt")?;
+        std::fs::remove_file(comm_file)?;
     }
     println!("Rank {rank:?} spawned");

-    let device = Device::new_cuda(i)?;
+    let device = Device::new_cuda(rank)?;
     let cache = model::Cache::new(dtype, &config, &device)?;

     println!("building the model");
@@ -210,14 +191,24 @@ fn main() -> Result<()> {
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
+    let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);

     println!("starting the inference loop");
-    let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
+    let temperature = if args.temperature <= 0. {
+        None
+    } else {
+        Some(args.temperature)
+    };
+    let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
-    let mut new_tokens = vec![];
-    let start_gen = std::time::Instant::now();
+    let mut start_gen = std::time::Instant::now();
     let mut index_pos = 0;
     for index in 0..args.sample_len {
-        let start_gen = std::time::Instant::now();
+        // Only start timing at the second token as processing the first token waits for all the
+        // weights to be loaded in an async way.
+        if index == 1 {
+            start_gen = std::time::Instant::now()
+        };
         let context_size = if index > 0 { 1 } else { tokens.len() };
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
         let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
@@ -228,25 +219,23 @@ fn main() -> Result<()> {
         let next_token = logits_processor.sample(&logits)?;
         tokens.push(next_token);
-        new_tokens.push(next_token);
+        if Some(next_token) == config.eos_token_id {
+            break;
+        }
         if rank == 0 {
-            println!("> {:?}", start_gen.elapsed());
-            println!(
-                "{} token: {} '{}'",
-                index + 1,
-                next_token,
-                tokenizer.decode(&[next_token], true).map_err(E::msg)?
-            );
+            if let Some(t) = tokenizer.next_token(next_token)? {
+                print!("{t}");
+                std::io::stdout().flush()?;
+            }
         }
     }
-    let dt = start_gen.elapsed();
     println!();
     if rank == 0 {
+        let dt = start_gen.elapsed();
         println!(
-            "{} tokens generated ({} token/s)\n----\n{}\n----",
+            "\n\n{} tokens generated ({} token/s)\n",
             args.sample_len,
-            args.sample_len as f64 / dt.as_secs_f64(),
-            tokenizer
-                .decode(new_tokens.as_slice(), true)
-                .map_err(E::msg)?
+            (args.sample_len - 1) as f64 / dt.as_secs_f64(),
         );
     }
     Ok(())
@@ -1,15 +1,14 @@
 use candle::backend::BackendStorage;
 use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
+use candle_nn::var_builder::ShardedVarBuilder as VarBuilder;
 use candle_nn::{Embedding, Linear, Module, RmsNorm};
 use cudarc::nccl::safe::{Comm, ReduceOp};
 use half::f16;
-use serde::Deserialize;
 use std::rc::Rc;
 use std::sync::{Arc, Mutex};

 use super::MAX_SEQ_LEN;
-
-use candle_nn::var_builder::ShardedVarBuilder as VarBuilder;
+pub type Config = candle_transformers::models::llama::LlamaConfig;

 struct TensorParallelColumnLinear {
     linear: Linear,
@@ -26,7 +25,7 @@ impl TensorParallelColumnLinear {

 struct TensorParallelRowLinear {
     linear: Linear,
-    comm: Rc<Comm>,
+    all_reduce: AllReduce,
 }

 struct AllReduce {
@@ -36,8 +35,6 @@ struct AllReduce {
 /// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
 /// But for this example purposes, this will work
 unsafe impl Sync for AllReduce {}
-/// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
-/// But for this example purposes, this will work
 unsafe impl Send for AllReduce {}

 impl CustomOp1 for AllReduce {
@@ -46,7 +43,7 @@ impl CustomOp1 for AllReduce {
     }

     fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
-        todo!("implement allreduce for cpu is not necessary for single node");
+        candle::bail!("AllReduce is never used on cpu")
     }

     #[cfg(feature = "cuda")]
@@ -56,31 +53,49 @@ impl CustomOp1 for AllReduce {
         l: &Layout,
     ) -> Result<(candle::CudaStorage, Shape)> {
         use candle::cuda_backend::WrapErr;
-        use cudarc::driver::DeviceSlice;
-        use half::f16;
+        use half::{bf16, f16};

         let elem_count = l.shape().elem_count();
         let dev = s.device().clone();
-        let s = s.as_cuda_slice::<f16>()?;
-        // let s = match l.contiguous_offsets() {
-        //     None => Err(Error::Wrapped("input has to be contiguous".into()))?,
-        //     Some((o1, o2)) => s.slice(o1..o2),
-        // };
-        let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
-        self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap();
-        let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
+        let dst = match s.dtype() {
+            DType::BF16 => {
+                let s = s.as_cuda_slice::<bf16>()?;
+                let s = match l.contiguous_offsets() {
+                    Some((0, l)) if l == s.len() => s,
+                    Some(_) | None => candle::bail!("input has to be contiguous"),
+                };
+                let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
+                self.comm
+                    .all_reduce(s, &mut dst, &ReduceOp::Sum)
+                    .map_err(candle::Error::debug)?;
+                candle::CudaStorage::wrap_cuda_slice(dst, dev)
+            }
+            DType::F16 => {
+                let s = s.as_cuda_slice::<f16>()?;
+                let s = match l.contiguous_offsets() {
+                    Some((0, l)) if l == s.len() => s,
+                    Some(_) | None => candle::bail!("input has to be contiguous"),
+                };
+                let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
+                self.comm
+                    .all_reduce(s, &mut dst, &ReduceOp::Sum)
+                    .map_err(candle::Error::debug)?;
+                candle::CudaStorage::wrap_cuda_slice(dst, dev)
+            }
+            dtype => candle::bail!("unsupported dtype {dtype:?}"),
+        };
         Ok((dst, l.shape().clone()))
     }
 }

 fn all_reduce_sum(x: &Tensor, comm: &Rc<Comm>) -> Result<Tensor> {
     x.apply_op1(AllReduce { comm: comm.clone() })
 }

 impl TensorParallelRowLinear {
     fn new(linear: Linear, comm: Rc<Comm>) -> Self {
-        Self { linear, comm }
+        let all_reduce = AllReduce { comm };
+        Self { linear, all_reduce }
     }
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = self.linear.forward(x)?;
-        all_reduce_sum(&x, &self.comm)
+        self.linear.forward(x)?.apply_op1_no_bwd(&self.all_reduce)
     }
 }

@@ -121,23 +136,6 @@ impl TensorParallelRowLinear {
     }
 }

-#[derive(Deserialize)]
-pub struct Config {
-    pub hidden_size: usize,
-    pub intermediate_size: usize,
-    pub vocab_size: usize,
-    pub num_hidden_layers: usize,
-    pub num_attention_heads: usize,
-    pub num_key_value_heads: usize,
-    pub rms_norm_eps: f64,
-    #[serde(default = "default_rope")]
-    pub rope_theta: f32,
-}
-
-fn default_rope() -> f32 {
-    10_000.0
-}
-
 #[derive(Clone)]
 pub struct Cache {
     #[allow(clippy::type_complexity)]
@@ -161,7 +159,6 @@ impl Cache {
             .matmul(&theta.reshape((1, theta.elem_count()))?)?;
         // This is different from the paper, see:
         // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
-        let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
         let cos = idx_theta.cos()?.to_dtype(dtype)?;
         let sin = idx_theta.sin()?.to_dtype(dtype)?;
         Ok(Self {
@@ -197,16 +194,10 @@ struct CausalSelfAttention {

 impl CausalSelfAttention {
     fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
-        let (b_sz, _, seq_len, hidden_size) = x.shape().dims4()?;
+        let (_b_sz, _, seq_len, _hidden_size) = x.shape().dims4()?;
         let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
         let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
-        let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
-        let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
-        let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?;
-        let x2 = x.narrow(D::Minus1, hidden_size / 2, hidden_size / 2)?;
-        let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
-        let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
-        Ok(rope)
+        candle_nn::rotary_emb::rope(x, &cos, &sin)
     }

     fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
@@ -232,13 +223,16 @@ impl CausalSelfAttention {

         let q = q
             .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
-            .transpose(1, 2)?;
+            .transpose(1, 2)?
+            .contiguous()?;
         let k = k
             .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
-            .transpose(1, 2)?;
+            .transpose(1, 2)?
+            .contiguous()?;
         let mut v = v
             .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
-            .transpose(1, 2)?;
+            .transpose(1, 2)?
+            .contiguous()?;

         let q = self.apply_rotary_emb(&q, index_pos)?;
         let mut k = self.apply_rotary_emb(&k, index_pos)?;
@@ -269,25 +263,14 @@ impl CausalSelfAttention {
         let v = v.transpose(1, 2)?;
         let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
         let y = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?
-            .transpose(1, 2)?;
-        // Convert to contiguous as matmul doesn't support strided vs for now.
-        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?;
+            .reshape((b_sz, seq_len, hidden_size))?;
         let y = self.o_proj.forward(&y)?;
         Ok(y)
     }

     fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
         let n_rep = self.num_attention_heads / self.num_key_value_heads;
-        if n_rep == 1 {
-            Ok(x)
-        } else {
-            let (b_sz, n_kv_head, seq_len, head_dim) = x.shape().dims4()?;
-            let x = x
-                .unsqueeze(2)?
-                .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, n_kv_head, n_rep, seq_len, head_dim))?;
-            Ok(x)
-        }
+        candle_transformers::utils::repeat_kv(x, n_rep)
     }

     fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
@@ -301,7 +284,7 @@ impl CausalSelfAttention {
             qkv_proj,
             o_proj,
             num_attention_heads: cfg.num_attention_heads / comm.world_size(),
-            num_key_value_heads: cfg.num_key_value_heads / comm.world_size(),
+            num_key_value_heads: cfg.num_key_value_heads() / comm.world_size(),
             head_dim: cfg.hidden_size / cfg.num_attention_heads,
             cache: cache.clone(),
         })
@@ -315,18 +298,6 @@ struct Mlp {
 }

 impl Mlp {
-    fn new(
-        c_fc1: TensorParallelColumnLinear,
-        c_fc2: TensorParallelColumnLinear,
-        c_proj: TensorParallelRowLinear,
-    ) -> Self {
-        Self {
-            c_fc1,
-            c_fc2,
-            c_proj,
-        }
-    }
-
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
         self.c_proj.forward(&x)
@@ -336,7 +307,11 @@ impl Mlp {
         let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?;
         let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?;
         let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm)?;
-        Ok(Self::new(c_fc1, c_fc2, c_proj))
+        Ok(Self {
+            c_fc1,
+            c_fc2,
+            c_proj,
+        })
     }
 }

@@ -427,10 +402,8 @@ impl Llama {
                     cfg,
                     comm.clone(),
                 )
-                .unwrap()
             })
-            .collect();
-
+            .collect::<Result<Vec<_>>>()?;
         Ok(Self::new(wte, blocks, norm, lm_head))
     }
 }
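
Aside from the diff itself, the `AllReduce` change above is an instance of candle's `CustomOp1` extension point. Below is a minimal, self-contained sketch of the same pattern; the `Negate` op and its f32-only CPU forward are illustrative stand-ins for the NCCL all-reduce, not part of the commit:

```rust
use candle::{CpuStorage, CustomOp1, Device, Layout, Result, Shape, Tensor};

// Illustrative custom op: element-wise negation, standing in for the NCCL all-reduce.
struct Negate;

impl CustomOp1 for Negate {
    fn name(&self) -> &'static str {
        "negate"
    }

    fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> {
        // This sketch assumes a contiguous f32 tensor.
        let data = s.as_slice::<f32>()?;
        let out = data.iter().map(|v| -v).collect::<Vec<_>>();
        Ok((CpuStorage::F32(out), l.shape().clone()))
    }
}

fn main() -> Result<()> {
    let t = Tensor::new(&[1f32, -2., 3.], &Device::Cpu)?;
    // `apply_op1_no_bwd` runs the op without registering a backward pass,
    // which is how the row-parallel linear above invokes its all-reduce.
    let y = t.apply_op1_no_bwd(&Negate)?;
    println!("{y}");
    Ok(())
}
```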
candle-examples/examples/llava/constants.rs (new file, 4 lines)
@@ -0,0 +1,4 @@
pub const DEFAULT_IMAGE_TOKEN: &str = "<image>";
pub const DEFAULT_IM_START_TOKEN: &str = "<im_start>";
pub const DEFAULT_IM_END_TOKEN: &str = "<im_end>";
pub const IMAGE_PLACEHOLDER: &str = "<image-placeholder>";
candle-examples/examples/llava/conversation.rs (new file, 114 lines)
@@ -0,0 +1,114 @@
pub enum SeparatorStyle {
    Two,
    Mpt,
}
pub struct Conversation {
    pub system: String,
    pub roles: Vec<String>,
    pub messages: Vec<(String, Option<String>)>,
    pub offset: i32,
    pub sep_style: SeparatorStyle,
    pub sep: String,
    pub sep2: Option<String>,
    pub version: String,
}

impl Conversation {
    pub fn new(
        system: &str,
        roles: &[String],
        offset: i32,
        sep_style: SeparatorStyle,
        sep: &str,
        sep2: Option<&str>,
        version: &str,
    ) -> Self {
        Conversation {
            system: system.to_string(),
            roles: roles.to_vec(),
            messages: Vec::new(),
            offset,
            sep_style,
            sep: sep.to_string(),
            sep2: sep2.map(|s| s.to_string()),
            version: version.to_string(),
        }
    }

    pub fn conv_chatml_direct() -> Self {
        Conversation::new(
            "<|im_start|>system\nAnswer the questions.",
            &[
                "<|im_start|>user\n".to_string(),
                "<|im_start|>assistant\n".to_string(),
            ],
            0,
            SeparatorStyle::Mpt,
            "<|im_end|>",
            None,
            "mpt",
        )
    }

    pub fn conv_llava_v1() -> Self {
        Conversation::new(
            "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.",
            &[
                "USER".to_string(),
                "ASSISTANT".to_string(),
            ],
            0,
            SeparatorStyle::Two,
            " ",
            Some("</s>"),
            "v1"
        )
    }

    pub fn append_message(&mut self, role: String, message: Option<&str>) {
        self.messages.push((role, message.map(|s| s.to_string())))
    }

    pub fn append_user_message(&mut self, message: Option<&str>) {
        self.append_message(self.roles[0].clone(), message);
    }

    pub fn append_assistant_message(&mut self, message: Option<&str>) {
        self.append_message(self.roles[1].clone(), message);
    }
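
    /// Renders the full prompt: the system message followed by each (role, message)
    /// turn, joined with the style-specific separators.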
    pub fn get_prompt(&self) -> String {
        match self.sep_style {
            SeparatorStyle::Mpt => {
                let mut ret = String::new();
                ret.push_str(&self.system);
                ret.push_str(&self.sep);
                for (role, message) in &self.messages {
                    ret.push_str(role);
                    if let Some(message) = message {
                        ret.push_str(message);
                    };
                    ret.push_str(&self.sep);
                }
                ret
            }
            SeparatorStyle::Two => {
                let seps = [self.sep.clone(), self.sep2.clone().unwrap()];
                let mut ret = String::new();
                ret.push_str(&self.system);
                ret.push_str(&seps[0]);
                for (i, (role, message)) in self.messages.iter().enumerate() {
                    ret.push_str(role);
                    if let Some(message) = message {
                        ret.push_str(": "); // Strictly follow the Python implementation; otherwise the tokenization differs slightly.
                        ret.push_str(message);
                        ret.push_str(&seps[i % 2]);
                    } else {
                        ret.push(':')
                    }
                }
                ret
            }
        }
    }
}
candle-examples/examples/llava/image_processor.rs (new file, 317 lines)
@@ -0,0 +1,317 @@
use std::cmp::min;

use candle::{bail, DType, Device, Result, Tensor};
use candle_transformers::models::llava::{
    config::{HFPreProcessorConfig, LLaVAConfig},
    utils::select_best_resolution,
};
use hf_hub::api::sync::Api;
use image::{imageops::overlay, DynamicImage, GenericImageView, Rgb, RgbImage};
use serde::{Deserialize, Serialize};

// This struct is mainly for LLaVA applications, hence it is not fully compatible with the
// Python transformers CLIPImageProcessor: it only covers the few preprocessing variants that
// LLaVA uses, including "openai/clip-vit-large-patch14-336" and "openai/clip-vit-large-patch14".

#[derive(Serialize, Deserialize, Debug)]
pub struct ImageProcessor {
    #[serde(default = "default_size")]
    pub size: u32, // this is not the same as in the python transformers implementation
    #[serde(default = "default_do_resize")]
    pub do_resize: bool,

    // resample: u32 // 3 for PIL bicubic, equivalent to rust CatmullRom. Hence below we use CatmullRom.
    #[serde(default = "default_do_center_crop")]
    pub do_center_crop: bool,
    #[serde(default = "default_crop_size")]
    pub crop_size: u32, // this is not the same as in the python transformers implementation
    #[serde(default = "default_do_rescale")]
    pub do_rescale: bool,
    #[serde(default = "default_rescale_factor")]
    pub rescale_factor: f32,
    #[serde(default = "default_do_normalize")]
    pub do_normalize: bool,
    #[serde(default = "default_image_mean")]
    pub image_mean: Vec<f32>,
    #[serde(default = "default_image_std")]
    pub image_std: Vec<f32>,
}

fn default_size() -> u32 {
    224
}

fn default_do_resize() -> bool {
    true
}

fn default_do_center_crop() -> bool {
    true
}

fn default_crop_size() -> u32 {
    224
}

fn default_do_rescale() -> bool {
    true
}

fn default_rescale_factor() -> f32 {
    1.0 / 255.0
}

fn default_do_normalize() -> bool {
    true
}

fn default_image_mean() -> Vec<f32> {
    vec![0.48145466, 0.4578275, 0.40821073]
}

fn default_image_std() -> Vec<f32> {
    vec![0.26862954, 0.2613026, 0.2757771]
}

impl ImageProcessor {
    pub fn from_pretrained(clip_id: &str) -> Result<Self> {
        let api = Api::new().map_err(|e| candle::Error::Msg(e.to_string()))?;
        let api = api.model(clip_id.to_string());
        let config_filename = api
            .get("preprocessor_config.json")
            .map_err(|e| candle::Error::Msg(e.to_string()))?;
        let image_processor =
            serde_json::from_slice(&std::fs::read(config_filename).map_err(candle::Error::Io)?)
                .map_err(|e| candle::Error::Msg(e.to_string()))?;
        Ok(image_processor)
    }

    pub fn from_hf_preprocessor_config(hf_preprocessor_config: &HFPreProcessorConfig) -> Self {
        Self {
            size: hf_preprocessor_config.size["shortest_edge"] as u32,
            do_resize: hf_preprocessor_config.do_resize,
            do_center_crop: hf_preprocessor_config.do_center_crop,
            crop_size: hf_preprocessor_config.crop_size["height"] as u32,
            do_rescale: hf_preprocessor_config.do_rescale,
            rescale_factor: hf_preprocessor_config.rescale_factor,
            do_normalize: hf_preprocessor_config.do_normalize,
            image_mean: hf_preprocessor_config.image_mean.clone(),
            image_std: hf_preprocessor_config.image_std.clone(),
        }
    }

    /// Resizes the shortest edge to `self.size`; the other edge is scaled to maintain the aspect ratio.
    pub fn resize(&self, image: &DynamicImage) -> DynamicImage {
        let (width, height) = image.dimensions();
        let size = self.size;
        if width == size && height == size {
            image.clone()
        } else {
            let (new_width, new_height) = if width < height {
                (
                    size,
                    (((size * height) as f32) / width as f32).ceil() as u32,
                )
            } else {
                (
                    (((size * width) as f32) / height as f32).ceil() as u32,
                    size,
                )
            };
            image.resize(
                new_width,
                new_height,
                image::imageops::FilterType::CatmullRom,
            )
        }
    }
|
||||
|
||||
pub fn center_crop(&self, image: &DynamicImage) -> DynamicImage {
|
||||
let (width, height) = image.dimensions();
|
||||
let crop_size = self.crop_size;
|
||||
let (left, top) = calculate_middle((width, height), (crop_size, crop_size));
|
||||
image.crop_imm(left, top, crop_size, crop_size)
|
||||
}
|
||||
|
||||
pub fn to_tensor(&self, image: &DynamicImage) -> Result<Tensor> {
|
||||
let img = image.to_rgb8().into_raw();
|
||||
let (width, height) = image.dimensions();
|
||||
Tensor::from_vec(img, (height as usize, width as usize, 3), &Device::Cpu)?
|
||||
.to_dtype(DType::F32) // only for internal compute
|
||||
}
|
||||
|
||||
pub fn rescale(&self, tensor: &Tensor) -> Result<Tensor> {
|
||||
let rescale_factor = self.rescale_factor as f64;
|
||||
tensor.affine(rescale_factor, 0.0)
|
||||
}
|
||||
|
||||
pub fn normalize(&self, tensor: &Tensor) -> Result<Tensor> {
|
||||
let image_mean = self.image_mean.clone();
|
||||
let image_std = self.image_std.clone();
|
||||
let mean = Tensor::from_vec(image_mean, (3,), &Device::Cpu)?;
|
||||
let std = Tensor::from_vec(image_std, (3,), &Device::Cpu)?;
|
||||
tensor.broadcast_sub(&mean)?.broadcast_div(&std)
|
||||
}
|
||||
|
||||
pub fn to_channel_dimension_format(&self, tensor: &Tensor) -> Result<Tensor> {
|
||||
tensor.permute((2, 0, 1))
|
||||
}
|
||||
|
||||
pub fn preprocess(&self, image: &DynamicImage) -> Result<Tensor> {
|
||||
let image = if self.do_resize {
|
||||
self.resize(image)
|
||||
} else {
|
||||
image.clone()
|
||||
};
|
||||
let image = if self.do_center_crop {
|
||||
self.center_crop(&image)
|
||||
} else {
|
||||
image
|
||||
};
|
||||
let tensor = self.to_tensor(&image)?;
|
||||
let tensor = if self.do_rescale {
|
||||
self.rescale(&tensor)?
|
||||
} else {
|
||||
tensor
|
||||
};
|
||||
let tensor = if self.do_normalize {
|
||||
self.normalize(&tensor)?
|
||||
} else {
|
||||
tensor
|
||||
};
|
||||
self.to_channel_dimension_format(&tensor)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn calculate_middle(image_size: (u32, u32), center_size: (u32, u32)) -> (u32, u32) {
|
||||
let (width, height) = image_size;
|
||||
let (center_width, center_height) = center_size;
|
||||
let left = if width <= center_width {
|
||||
0
|
||||
} else {
|
||||
((width as f32 - center_width as f32) / 2.0).ceil() as u32
|
||||
};
|
||||
let top = if height <= center_height {
|
||||
0
|
||||
} else {
|
||||
((height as f32 - center_height as f32) / 2.0).ceil() as u32
|
||||
};
|
||||
(left, top)
|
||||
}
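
// Worked example (illustrative): center-cropping a 640x480 image to 224x224
// gives (left, top) = ((640 - 224) / 2, (480 - 224) / 2) = (208, 128).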

pub fn process_image(
    image: &DynamicImage,
    processor: &ImageProcessor,
    llava_config: &LLaVAConfig,
) -> candle::Result<Tensor> {
    if llava_config.image_aspect_ratio == *"square" {
        processor.preprocess(image)?.unsqueeze(0)
    } else if llava_config.image_aspect_ratio == *"anyres" {
        process_anyres_image(image, processor, &llava_config.image_grid_pinpoints)
    } else if llava_config.image_aspect_ratio == *"pad" {
        process_pad_image(image, processor)
    } else {
        bail!("Invalid image aspect ratio")
    }
}

fn process_pad_image(image: &DynamicImage, processor: &ImageProcessor) -> Result<Tensor> {
    let mean_color = processor
        .image_mean
        .iter()
        .map(|x| ((*x) * 255.0) as u8)
        .collect::<Vec<u8>>();
    let mean_color = Rgb::from([mean_color[0], mean_color[1], mean_color[2]]);
    let image_padded = expand2square(image, mean_color);
    processor.preprocess(&image_padded)
}

fn process_anyres_image(
    image: &DynamicImage,
    processor: &ImageProcessor,
    grid_pinpoints: &[(u32, u32)],
) -> Result<Tensor> {
    let original_size = image.dimensions();
    let best_resolution = select_best_resolution(original_size, grid_pinpoints);
    let image_padded = resize_and_pad_image(image, best_resolution);
    let image_original_resize = image.resize_exact(
        processor.size,
        processor.size,
        image::imageops::FilterType::CatmullRom,
    );
    let mut patches = vec![image_original_resize];
    for patch in divide_to_patches(&image_padded, processor.crop_size) {
        patches.push(patch);
    }
    let tensors = patches
        .iter()
        .map(|patch| processor.preprocess(patch))
        .collect::<Result<Vec<Tensor>>>()?;
    Tensor::stack(&tensors, 0)
}

fn expand2square(image: &DynamicImage, background_color: Rgb<u8>) -> DynamicImage {
    let (width, height) = image.dimensions();
    match width.cmp(&height) {
        std::cmp::Ordering::Less => {
            let mut new_image =
                DynamicImage::from(RgbImage::from_pixel(height, height, background_color));
            overlay(&mut new_image, image, ((height - width) / 2) as i64, 0);
            new_image
        }
        std::cmp::Ordering::Equal => image.clone(),
        std::cmp::Ordering::Greater => {
            let mut new_image =
                DynamicImage::from(RgbImage::from_pixel(width, width, background_color));
            overlay(&mut new_image, image, 0, ((width - height) / 2) as i64);
            new_image
        }
    }
}

fn resize_and_pad_image(image: &DynamicImage, target_resolution: (u32, u32)) -> DynamicImage {
    let (original_width, original_height) = image.dimensions();
    let original_width_f = original_width as f32;
    let original_height_f = original_height as f32;
    let (target_width, target_height) = target_resolution;
    let target_width_f = target_width as f32;
    let target_height_f = target_height as f32;
    let scale_w = target_width_f / original_width_f;
    let scale_h = target_height_f / original_height_f;
    let (new_width, new_height) = if scale_w < scale_h {
        (
            target_width,
            min((original_height_f * scale_w).ceil() as u32, target_height),
        )
    } else {
        (
            min((original_width_f * scale_h).ceil() as u32, target_width),
            target_height,
        )
    };
    let resized_image = image.resize_exact(
        new_width,
        new_height,
        image::imageops::FilterType::CatmullRom,
    );
    let mut new_image = DynamicImage::new_rgb8(target_width, target_height);
    let (paste_x, paste_y) =
        calculate_middle((target_width, target_height), (new_width, new_height));
    overlay(
        &mut new_image,
        &resized_image,
        paste_x.into(),
        paste_y.into(),
    );
    new_image
}

fn divide_to_patches(image: &DynamicImage, patch_size: u32) -> Vec<DynamicImage> {
    let (width, height) = image.dimensions();
    let mut patches = Vec::new();
    for y in (0..height).step_by(patch_size as usize) {
        for x in (0..width).step_by(patch_size as usize) {
            let patch = image.crop_imm(x, y, patch_size, patch_size);
            patches.push(patch);
        }
    }
    patches
}
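
A minimal sketch of how this preprocessor is driven end to end (the image path is a placeholder, an `llava_config: LLaVAConfig` is assumed to be in scope, and error handling relies on the caller's `Result` context):

```rust
// Load the CLIP preprocessor config from the hub, then turn an image on disk
// into the tensor(s) the vision tower expects: (1, 3, H, W) for "square",
// or a stack of patches for "anyres".
let processor = ImageProcessor::from_pretrained("openai/clip-vit-large-patch14-336")?;
let img = image::ImageReader::open("llava_logo.png")?.decode()?;
let tensor = process_image(&img, &processor, &llava_config)?;
```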

316  candle-examples/examples/llava/main.rs  Normal file
@@ -0,0 +1,316 @@
pub mod constants;
pub mod conversation;
pub mod image_processor;

use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_transformers::models::llama::Cache;

use anyhow::{bail, Error as E, Result};
use candle::{DType, Device, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::llava::config::{
    HFGenerationConfig, HFLLaVAConfig, HFPreProcessorConfig,
};
use candle_transformers::models::llava::{config::LLaVAConfig, LLaVA};
use clap::Parser;
use constants::*;
use conversation::Conversation;
use hf_hub::api::sync::Api;
use image_processor::{process_image, ImageProcessor};
use std::io::Write;
use tokenizers::Tokenizer;

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    #[arg(long, default_value = "llava-hf/llava-v1.6-vicuna-7b-hf")]
    model_path: String,
    #[arg(long, default_value = "tokenizer/tokenizer.json")]
    tokenizer_path: String,
    #[arg(long)]
    model_base: Option<String>,
    #[arg(long)]
    image_file: String, // Required
    #[arg(long)]
    conv_mode: Option<String>,
    #[arg(long, default_value_t = 0.2)]
    temperature: f32,
    #[arg(long, default_value_t = 512)]
    max_new_tokens: usize,
    #[arg(long, action)]
    hf: bool,
    #[arg(long, action)]
    cpu: bool,
    #[arg(long, action)]
    no_kv_cache: bool,
    #[arg(long)]
    prompt: String,
    /// The seed to use when generating random samples. Copied from the candle llama example; it does not exist in the Python llava.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,
}

// from https://github.com/huggingface/candle/blob/main/candle-examples/examples/clip/main.rs
fn load_image<T: AsRef<std::path::Path>>(
    path: T,
    processor: &ImageProcessor,
    llava_config: &LLaVAConfig,
    dtype: DType,
) -> Result<((u32, u32), Tensor)> {
    let img = image::ImageReader::open(path)?.decode()?;
    let img_tensor = process_image(&img, processor, llava_config)?;
    Ok(((img.width(), img.height()), img_tensor.to_dtype(dtype)?))
}

fn get_model_name_from_path(model_path: &str) -> String {
    let model_paths: Vec<String> = model_path
        .trim_matches('/')
        .split('/')
        .map(|s| s.to_string())
        .collect();
    if model_paths.last().unwrap().starts_with("checkpoint-") {
        format!(
            "{}_{}",
            model_paths[model_paths.len() - 2],
            model_paths.last().unwrap()
        )
    } else {
        model_paths.last().unwrap().to_string()
    }
}

fn duplicate_vec<T>(vec: &[T], n: usize) -> Vec<T>
where
    T: Clone,
{
    let mut res = Vec::new();
    for _ in 0..n {
        res.extend(vec.to_owned());
    }
    res
}

fn insert_separator<T>(x: Vec<Vec<T>>, sep: Vec<T>) -> Vec<Vec<T>>
where
    T: Clone,
{
    let sep = vec![sep];
    let sep = duplicate_vec(&sep, x.len());
    let mut res = x
        .iter()
        .zip(sep.iter())
        .flat_map(|(x, y)| vec![x.clone(), y.clone()])
        .collect::<Vec<Vec<T>>>();
    res.pop();
    res
}

fn tokenizer_image_token(
    prompt: &str,
    tokenizer: &Tokenizer,
    image_token_index: i64,
    llava_config: &LLaVAConfig,
) -> Result<Tensor> {
    let prompt_chunks = prompt
        .split("<image>")
        .map(|s| {
            tokenizer
                .encode(s, true)
                .unwrap()
                .get_ids()
                .to_vec()
                .iter()
                .map(|x| *x as i64)
                .collect()
        })
        .collect::<Vec<Vec<i64>>>();
    let mut input_ids = Vec::new();
    let mut offset = 0;
    if !prompt_chunks.is_empty()
        && !prompt_chunks[0].is_empty()
        && prompt_chunks[0][0] == llava_config.bos_token_id as i64
    {
        offset = 1;
        input_ids.push(prompt_chunks[0][0]);
    }

    for x in insert_separator(
        prompt_chunks,
        duplicate_vec(&[image_token_index], offset + 1),
    )
    .iter()
    {
        input_ids.extend(x[1..].to_vec())
    }
    let input_len = input_ids.len();
    Tensor::from_vec(input_ids, (1, input_len), &Device::Cpu).map_err(E::msg)
}
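
// Illustrative trace (hypothetical token ids): encoding "A<image>B" with
// add_special_tokens = true yields chunks [[bos, a], [bos, b]]; the separator
// logic interleaves [image_token_index, image_token_index] between them, and
// `x[1..]` drops each chunk's leading token, giving [bos, a, image_token_index, b].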

fn main() -> Result<()> {
    let mut args = Args::parse();
    let device = candle_examples::device(args.cpu)?;
    println!("Start loading model");
    let api = Api::new()?;
    let api = api.model(args.model_path.clone());
    let (llava_config, tokenizer, clip_vision_config, image_processor) = if args.hf {
        let config_filename = api.get("config.json")?;
        let hf_llava_config: HFLLaVAConfig =
            serde_json::from_slice(&std::fs::read(config_filename)?)?;
        let generation_config_filename = api.get("generation_config.json")?;
        let generation_config: HFGenerationConfig =
            serde_json::from_slice(&std::fs::read(generation_config_filename)?)?;
        let preprocessor_config_filename = api.get("preprocessor_config.json")?;
        let preprocessor_config: HFPreProcessorConfig =
            serde_json::from_slice(&std::fs::read(preprocessor_config_filename)?)?;
        let llava_config =
            hf_llava_config.to_llava_config(&generation_config, &preprocessor_config);
        let tokenizer_filename = api.get("tokenizer.json")?;
        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
        let clip_vision_config = hf_llava_config.to_clip_vision_config();
        (
            llava_config,
            tokenizer,
            Some(clip_vision_config),
            ImageProcessor::from_hf_preprocessor_config(&preprocessor_config),
        )
    } else {
        let config_filename = api.get("config.json")?;
        let llava_config: LLaVAConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
        let tokenizer = Tokenizer::from_file(&args.tokenizer_path)
            .map_err(|e| E::msg(format!("Error loading {}: {}", &args.tokenizer_path, e)))?;
        (
            llava_config.clone(),
            tokenizer,
            None,
            ImageProcessor::from_pretrained(&llava_config.mm_vision_tower.unwrap())?,
        )
    };

    let llama_config = llava_config.to_llama_config();
    let dtype: DType = match llava_config.torch_dtype.as_str() {
        "float16" => DType::F16,
        "bfloat16" => DType::BF16,
        _ => bail!("unsupported dtype"),
    };

    let eos_token_id = llava_config.eos_token_id;

    println!("setting kv cache");
    let mut cache = Cache::new(!args.no_kv_cache, dtype, &llama_config, &device)?;

    println!("loading model weights");

    let weight_filenames =
        candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&weight_filenames, dtype, &device)? };
    let llava: LLaVA = LLaVA::load(vb, &llava_config, clip_vision_config)?;

    println!("generating conv template");
    let image_token_se = format!(
        "{}{}{}",
        DEFAULT_IM_START_TOKEN, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_END_TOKEN
    );
    let qs = if args.prompt.contains(IMAGE_PLACEHOLDER) {
        if llava_config.mm_use_im_start_end {
            args.prompt.replace(IMAGE_PLACEHOLDER, &image_token_se)
        } else {
            args.prompt.replace(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN)
        }
    } else if llava_config.mm_use_im_start_end {
        format!("{}\n{}", image_token_se, args.prompt)
    } else {
        format!("{}\n{}", DEFAULT_IMAGE_TOKEN, args.prompt)
    };

    let model_name = get_model_name_from_path(&args.model_path).to_lowercase();
    let conv_mode = if model_name.contains("llama-2") {
        "llava_llama_2"
    } else if model_name.contains("mistral") {
        "mistral_instruct"
    } else if model_name.contains("v1.6-34b") {
        "chatml_direct"
    } else if model_name.contains("v1") {
        "llava_v1"
    } else if model_name.contains("mpt") {
        "mpt"
    } else {
        "llava_v0"
    };
    if args.conv_mode.is_some() && args.conv_mode.as_deref() != Some(conv_mode) {
        println!(
            "Warning: the model is trained with {}, but you are using {}",
            conv_mode,
            args.conv_mode.as_deref().unwrap()
        );
    } else {
        args.conv_mode = Some(conv_mode.to_string());
    }

    let mut conv = match args.conv_mode {
        Some(conv_mode) => match conv_mode.as_str() {
            "chatml_direct" => Conversation::conv_chatml_direct(),
            "llava_v1" => Conversation::conv_llava_v1(),
            _ => todo!("not implemented yet"),
        },
        None => bail!("conv_mode is required"),
    };
    conv.append_user_message(Some(&qs));
    conv.append_assistant_message(None);
    let prompt = conv.get_prompt();
    println!("loading image");
    let (image_size, image_tensor) =
        load_image(&args.image_file, &image_processor, &llava_config, dtype)
            .map_err(|e| E::msg(format!("Error loading {}: {}", &args.image_file, e)))?;
    let image_tensor = image_tensor.to_device(&device)?;

    let mut logits_processor = {
        let temperature = f64::from(args.temperature);
        let sampling = if temperature <= 0. {
            Sampling::ArgMax
        } else {
            Sampling::All { temperature }
        };
        LogitsProcessor::from_sampling(args.seed, sampling)
    };

    // get input tokens
    let tokens = tokenizer_image_token(
        &prompt,
        &tokenizer,
        llava_config.image_token_index as i64,
        &llava_config,
    )?;
    let mut input_embeds =
        llava.prepare_inputs_labels_for_multimodal(&tokens, &[image_tensor], &[image_size])?;
    // inference loop, based on https://github.com/huggingface/candle/blob/main/candle-examples/examples/llama/main.rs
    let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
    let mut index_pos = 0;
    for index in 0..args.max_new_tokens {
        let (_, input_embeds_len, _) = input_embeds.dims3()?;
        let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
            (1, index_pos)
        } else {
            (input_embeds_len, 0)
        };
        let input = input_embeds.i((.., input_embeds_len.saturating_sub(context_size).., ..))?;
        let logits = llava.forward(&input, context_index, &mut cache)?; // [1, 32000]
        let logits = logits.squeeze(0)?;
        let (_, input_len, _) = input.dims3()?;
        index_pos += input_len;
        let next_token = logits_processor.sample(&logits)?;
        let next_token_tensor = Tensor::from_vec(vec![next_token], 1, &device)?;
        let next_embeds = llava.llama.embed(&next_token_tensor)?.unsqueeze(0)?;
        input_embeds = Tensor::cat(&[input_embeds, next_embeds], 1)?;
        if next_token == eos_token_id as u32 {
            break;
        }
        if let Some(t) = tokenizer.next_token(next_token)? {
            print!("{t}");
            std::io::stdout().flush()?;
        }
    }
    if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
        print!("{rest}");
    }
    Ok(())
}

40  candle-examples/examples/llava/readme.md  Normal file
@@ -0,0 +1,40 @@
# candle-llava

LLaVA (Large Language-and-Vision Assistant) is an end-to-end trained large
multimodal model. This example is from [candle-llava](https://github.com/chenwanqq/candle-llava)

The code is based on [https://github.com/haotian-liu/LLaVA](https://github.com/haotian-liu/LLaVA), hence the llava-hf version of the config may behave differently.

## model zoo
* [liuhaotian/LLaVA](https://huggingface.co/liuhaotian)
* [llava-hf](https://huggingface.co/llava-hf)

Right now this has been tested on `liuhaotian/llava-v1.6-vicuna-7b` and
`llava-hf/llava-v1.6-vicuna-7b-hf`. Memory usage might have room for optimization.

## Tokenizer Setup
The llava-hf models contain a `tokenizer.json` file, so they can be used directly with
the `-hf` command line flag.

For the original llava models, you can use the following code to generate the `tokenizer.json` file.

```bash
conda create -n llava python=3.10
conda activate llava
pip install transformers protobuf
python -c "from transformers import AutoTokenizer;tokenizer=AutoTokenizer.from_pretrained('liuhaotian/llava-v1.6-vicuna-7b');tokenizer.save_pretrained('tokenizer')"
```
Then the `tokenizer.json` file should be in `tokenizer/tokenizer.json` (which is the default path).

## eval

```bash
cargo run --example llava --features cuda -- --image-file "llava_logo.png" --prompt "is this a cat?" --hf # default args, use llava-hf/llava-v1.6-vicuna-7b-hf. image-file is required^_^
cargo run --example llava --features cuda -- --model-path liuhaotian/llava-v1.6-vicuna-7b --image-file "llava_logo.png" --prompt "is this a cat?" # use liuhaotian/llava-v1.6-vicuna-7b, tokenizer setup should be done
```

## Major Limitations
1. Currently only the llama-2/vicuna LLMs are supported; Mistral is not supported yet.
2. Some ops such as split, nonzero, and where are not supported by candle.
3. Lack of quantization and LoRA support.
@@ -147,6 +147,12 @@ enum Which {
     Mistral7bInstructV01,
     #[value(name = "7b-instruct-v0.2")]
     Mistral7bInstructV02,
+    #[value(name = "7b-maths-v0.1")]
+    Mathstral7bV01,
+    #[value(name = "nemo-2407")]
+    MistralNemo2407,
+    #[value(name = "nemo-instruct-2407")]
+    MistralNemoInstruct2407,
 }

 #[derive(Parser, Debug)]
@@ -261,12 +267,16 @@ fn main() -> Result<()> {
             }
             "lmz/candle-mistral".to_string()
         } else {
-            match args.which {
-                Which::Mistral7bV01 => "mistralai/Mistral-7B-v0.1".to_string(),
-                Which::Mistral7bV02 => "mistralai/Mistral-7B-v0.2".to_string(),
-                Which::Mistral7bInstructV01 => "mistralai/Mistral-7B-Instruct-v0.1".to_string(),
-                Which::Mistral7bInstructV02 => "mistralai/Mistral-7B-Instruct-v0.2".to_string(),
-            }
+            let name = match args.which {
+                Which::Mistral7bV01 => "mistralai/Mistral-7B-v0.1",
+                Which::Mistral7bV02 => "mistralai/Mistral-7B-v0.2",
+                Which::Mistral7bInstructV01 => "mistralai/Mistral-7B-Instruct-v0.1",
+                Which::Mistral7bInstructV02 => "mistralai/Mistral-7B-Instruct-v0.2",
+                Which::Mathstral7bV01 => "mistralai/mathstral-7B-v0.1",
+                Which::MistralNemo2407 => "mistralai/Mistral-Nemo-Base-2407",
+                Which::MistralNemoInstruct2407 => "mistralai/Mistral-Nemo-Instruct-2407",
+            };
+            name.to_string()
         }
     }
 };
@@ -217,11 +217,7 @@ fn main() -> Result<()> {
     let start = std::time::Instant::now();
     let config = Config::v0_1_8x7b(args.use_flash_attn);
     let device = candle_examples::device(args.cpu)?;
-    let dtype = if device.is_cuda() {
-        DType::BF16
-    } else {
-        DType::F32
-    };
+    let dtype = device.bf16_default_to_f32();
     let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
     let model = Model::new(&config, vb)?;
     println!("loaded the model in {:?}", start.elapsed());
18  candle-examples/examples/mobilenetv4/README.md  Normal file
@@ -0,0 +1,18 @@
# candle-mobilenetv4

[MobileNetV4 - Universal Models for the Mobile Ecosystem](https://arxiv.org/abs/2404.10518)
This candle implementation uses pre-trained MobileNetV4 models from timm for inference.
The classification head has been trained on the ImageNet dataset and returns the probabilities for the top-5 classes.

## Running an example

```
$ cargo run --example mobilenetv4 --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which medium
loaded image Tensor[dims 3, 256, 256; f32]
model built
unicycle, monocycle : 20.18%
mountain bike, all-terrain bike, off-roader: 19.77%
bicycle-built-for-two, tandem bicycle, tandem: 15.91%
crash helmet : 1.15%
tricycle, trike, velocipede: 0.67%
```
106  candle-examples/examples/mobilenetv4/main.rs  Normal file
@@ -0,0 +1,106 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::{Parser, ValueEnum};

use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::mobilenetv4;

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
    Small,
    Medium,
    Large,
    HybridMedium,
    HybridLarge,
}

impl Which {
    fn model_filename(&self) -> String {
        let name = match self {
            Self::Small => "conv_small.e2400_r224",
            Self::Medium => "conv_medium.e500_r256",
            Self::HybridMedium => "hybrid_medium.ix_e550_r256",
            Self::Large => "conv_large.e600_r384",
            Self::HybridLarge => "hybrid_large.ix_e600_r384",
        };
        format!("timm/mobilenetv4_{}_in1k", name)
    }
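
    // e.g. Which::Medium maps to the hub repo "timm/mobilenetv4_conv_medium.e500_r256_in1k".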

    fn resolution(&self) -> u32 {
        match self {
            Self::Small => 224,
            Self::Medium => 256,
            Self::HybridMedium => 256,
            Self::Large => 384,
            Self::HybridLarge => 384,
        }
    }

    fn config(&self) -> mobilenetv4::Config {
        match self {
            Self::Small => mobilenetv4::Config::small(),
            Self::Medium => mobilenetv4::Config::medium(),
            Self::HybridMedium => mobilenetv4::Config::hybrid_medium(),
            Self::Large => mobilenetv4::Config::large(),
            Self::HybridLarge => mobilenetv4::Config::hybrid_large(),
        }
    }
}

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    #[arg(value_enum, long, default_value_t=Which::Small)]
    which: Which,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image(args.image, args.which.resolution())?
        .to_device(&device)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {
        None => {
            let model_name = args.which.model_filename();
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model(model_name);
            api.get("model.safetensors")?
        }
        Some(model) => model.into(),
    };

    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let model = mobilenetv4::mobilenetv4(&args.which.config(), 1000, vb)?;
    println!("model built");
    let logits = model.forward(&image.unsqueeze(0)?)?;
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;
    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    for &(category_idx, pr) in prs.iter().take(5) {
        println!(
            "{:24}: {:.2}%",
            candle_examples::imagenet::CLASSES[category_idx],
            100. * pr
        );
    }
    Ok(())
}
@@ -188,8 +188,8 @@ struct Args {
     #[arg(long)]
     model_id: Option<String>,

-    #[arg(long, default_value = "main")]
-    revision: String,
+    #[arg(long)]
+    revision: Option<String>,

     #[arg(long)]
     quantized: bool,
@@ -208,7 +208,7 @@ struct Args {
 /// Loads an image from disk using the image crate, this returns a tensor with shape
 /// (3, 378, 378).
 pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> candle::Result<Tensor> {
-    let img = image::io::Reader::open(p)?
+    let img = image::ImageReader::open(p)?
         .decode()
         .map_err(candle::Error::wrap)?
         .resize_to_fill(378, 378, image::imageops::FilterType::Triangle); // Adjusted to 378x378
@@ -252,20 +252,28 @@ async fn main() -> anyhow::Result<()> {

     let start = std::time::Instant::now();
     let api = hf_hub::api::tokio::Api::new()?;
-    let model_id = match args.model_id {
-        Some(model_id) => model_id.to_string(),
+    let (model_id, revision) = match args.model_id {
+        Some(model_id) => (model_id.to_string(), None),
         None => {
             if args.quantized {
-                "santiagomed/candle-moondream".to_string()
+                ("santiagomed/candle-moondream".to_string(), None)
             } else {
-                "vikhyatk/moondream2".to_string()
+                (
+                    "vikhyatk/moondream2".to_string(),
+                    Some("30c7cdf3fa6914f50bee3956694374143f5cc884"),
+                )
             }
         }
     };
+    let revision = match (args.revision, revision) {
+        (Some(r), _) => r,
+        (None, Some(r)) => r.to_string(),
+        (None, None) => "main".to_string(),
+    };
     let repo = api.repo(hf_hub::Repo::with_revision(
         model_id,
         hf_hub::RepoType::Model,
-        args.revision,
+        revision,
     ));
     let model_file = match args.model_file {
         Some(m) => m.into(),
36  candle-examples/examples/olmo/README.md  Normal file
@@ -0,0 +1,36 @@
# candle-olmo: Open Language Models designed to enable the science of language models

OLMo is a series of Open Language Models designed to enable the science of language models.

- **Project Page:** https://allenai.org/olmo
- **Paper:** [Link](https://arxiv.org/abs/2402.00838)
- **Technical blog post:** https://blog.allenai.org/olmo-open-language-model-87ccfc95f580
- **W&B Logs:** https://wandb.ai/ai2-llm/OLMo-1B/reports/OLMo-1B--Vmlldzo2NzY1Njk1
<!-- - **Press release:** TODO -->

## Running the example

```bash
$ cargo run --example olmo --release -- --prompt "It is only with the heart that one can see rightly"

avx: true, neon: false, simd128: false, f16c: true
temp: 0.20 repeat-penalty: 1.10 repeat-last-n: 64
retrieved the files in 354.977µs
loaded the model in 19.87779666s
It is only with the heart that one can see rightly; what is essential is invisible to the eye.
```

Various model sizes are available via the `--model` argument.

```bash
$ cargo run --example olmo --release -- --model 1.7-7b --prompt 'It is only with the heart that one can see rightly'

avx: true, neon: false, simd128: false, f16c: true
temp: 0.20 repeat-penalty: 1.10 repeat-last-n: 64
retrieved the files in 1.226087ms
loaded the model in 171.274578609s
It is only with the heart that one can see rightly; what is essential is invisible to the eye.”
~ Antoine de Saint-Exupery, The Little Prince
I am a big fan of this quote. It reminds me that I need to be open and aware of my surroundings in order to truly appreciate them.
```
284  candle-examples/examples/olmo/main.rs  Normal file
@@ -0,0 +1,284 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};

use candle_transformers::models::olmo::{Config, Model as OLMo};

use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

enum Model {
    OLMo(OLMo),
}

struct TextGeneration {
    model: Model,
    device: Device,
    tokenizer: TokenOutputStream,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            tokenizer: TokenOutputStream::new(tokenizer),
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            device: device.clone(),
        }
    }

    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        self.tokenizer.clear();
        let mut tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, false)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        for &t in tokens.iter() {
            if let Some(t) = self.tokenizer.next_token(t)? {
                print!("{t}")
            }
        }
        std::io::stdout().flush()?;

        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
            Some(token) => token,
            None => anyhow::bail!("cannot find the <|endoftext|> token"),
        };
        let start_gen = std::time::Instant::now();
        for index in 0..sample_len {
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let start_pos = tokens.len().saturating_sub(context_size);
            let ctxt = &tokens[start_pos..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = match &mut self.model {
                Model::OLMo(m) => m.forward(&input, start_pos)?,
            };
            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };
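            // Note: with repeat_last_n = 64 (the default below) only the most
            // recent 64 tokens are penalized, so early context does not keep
            // suppressing common tokens for the whole generation.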

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token {
                break;
            }
            if let Some(t) = self.tokenizer.next_token(next_token)? {
                print!("{t}");
                std::io::stdout().flush()?;
            }
        }
        let dt = start_gen.elapsed();
        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
            print!("{rest}");
        }
        std::io::stdout().flush()?;
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        Ok(())
    }
}

#[derive(Clone, Copy, Debug, ValueEnum, PartialEq, Eq)]
enum Which {
    #[value(name = "1b")]
    W1b,
    #[value(name = "7b")]
    W7b,
    #[value(name = "7b-twin-2t")]
    W7bTwin2T,
    #[value(name = "1.7-7b")]
    V1_7W7b,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 1000)]
    sample_len: usize,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long, default_value = "main")]
    revision: String,

    #[arg(long, default_value = "1b")]
    model: Which,

    #[arg(long)]
    tokenizer_file: Option<String>,

    #[arg(long)]
    weight_files: Option<String>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let model_id = match args.model_id {
        Some(model_id) => model_id,
        None => match args.model {
            Which::W1b => "allenai/OLMo-1B-hf".to_string(),
            Which::W7b => "allenai/OLMo-7B-hf".to_string(),
            Which::W7bTwin2T => "allenai/OLMo-7B-Twin-2T-hf".to_string(),
            Which::V1_7W7b => "allenai/OLMo-1.7-7B-hf".to_string(),
        },
    };

    let repo = api.repo(Repo::with_revision(
        model_id,
        RepoType::Model,
        args.revision,
    ));
    let tokenizer_filename = match args.tokenizer_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("tokenizer.json")?,
    };
    let filenames = match args.weight_files {
        Some(files) => files
            .split(',')
            .map(std::path::PathBuf::from)
            .collect::<Vec<_>>(),
        None => match args.model {
            Which::W1b => {
                vec![repo.get("model.safetensors")?]
            }
            _ => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
        },
    };

    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let config = {
        let config_filename = repo.get("config.json")?;
        let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
        config
    };

    let device = candle_examples::device(args.cpu)?;
    let model = {
        let dtype = if device.is_cuda() {
            DType::BF16
        } else {
            DType::F32
        };
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
        let model = OLMo::new(&config, vb)?;
        Model::OLMo(model)
    };

    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        &device,
    );
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}
@@ -1,8 +1,9 @@
 # candle-phi: 1.3b and 2.7b LLM with state of the art performance for <10b models.

-[Phi-1.5](https://huggingface.co/microsoft/phi-1_5) and
-[Phi-2](https://huggingface.co/microsoft/phi-2) are language models using
-only 1.3 and 2.7 billion parameters but with state of the art performance compared to
+[Phi-1.5](https://huggingface.co/microsoft/phi-1_5),
+[Phi-2](https://huggingface.co/microsoft/phi-2), and
+[Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) are language models using
+only 1.3, 2.7, and 3.8 billion parameters but with state of the art performance compared to
 models with up to 10 billion parameters.

 The candle implementation provides both the standard version as well as a
@@ -7,11 +7,13 @@ extern crate accelerate_src;
 use anyhow::{Error as E, Result};
 use clap::{Parser, ValueEnum};

+use candle_examples::token_output_stream::TokenOutputStream;
 use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
 use candle_transformers::models::phi::{Config as PhiConfig, Model as Phi};
+use candle_transformers::models::phi3::{Config as Phi3Config, Model as Phi3};
 use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;

-use candle::{DType, Device, Tensor};
+use candle::{DType, Device, IndexOp, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{api::sync::Api, Repo, RepoType};
@@ -20,13 +22,14 @@ use tokenizers::Tokenizer;
 enum Model {
     MixFormer(MixFormer),
     Phi(Phi),
+    Phi3(Phi3),
     Quantized(QMixFormer),
 }

 struct TextGeneration {
     model: Model,
     device: Device,
-    tokenizer: Tokenizer,
+    tokenizer: TokenOutputStream,
     logits_processor: LogitsProcessor,
     repeat_penalty: f32,
     repeat_last_n: usize,
@@ -49,7 +52,7 @@ impl TextGeneration {
         let logits_processor = LogitsProcessor::new(seed, temp, top_p);
         Self {
             model,
-            tokenizer,
+            tokenizer: TokenOutputStream::new(tokenizer),
             logits_processor,
             repeat_penalty,
             repeat_last_n,
@@ -61,7 +64,11 @@ impl TextGeneration {
     fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
         use std::io::Write;
         println!("starting the inference loop");
-        let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
+        let tokens = self
+            .tokenizer
+            .tokenizer()
+            .encode(prompt, true)
+            .map_err(E::msg)?;
         if tokens.is_empty() {
             anyhow::bail!("Empty prompts are not supported in the phi model.")
         }
@@ -73,13 +80,14 @@ impl TextGeneration {
         }
         let mut tokens = tokens.get_ids().to_vec();
         let mut generated_tokens = 0usize;
-        let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
-            Some(token) => *token,
+        let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
+            Some(token) => token,
             None => anyhow::bail!("cannot find the endoftext token"),
         };
         print!("{prompt}");
         std::io::stdout().flush()?;
         let start_gen = std::time::Instant::now();
+        let mut pos = 0;
         for index in 0..sample_len {
             let context_size = if index > 0 { 1 } else { tokens.len() };
             let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
@@ -88,6 +96,7 @@ impl TextGeneration {
                 Model::MixFormer(m) => m.forward(&input)?,
                 Model::Phi(m) => m.forward(&input)?,
                 Model::Quantized(m) => m.forward(&input)?,
+                Model::Phi3(m) => m.forward(&input, pos)?.i((.., 0, ..))?,
             };
             let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
             let logits = if self.repeat_penalty == 1. {
@@ -105,11 +114,17 @@ impl TextGeneration {
             tokens.push(next_token);
             generated_tokens += 1;
             if next_token == eos_token {
+                if let Some(t) = self.tokenizer.decode_rest()? {
+                    print!("{t}");
+                    std::io::stdout().flush()?;
+                }
                 break;
             }
-            let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
-            print!("{token}");
-            std::io::stdout().flush()?;
+            if let Some(t) = self.tokenizer.next_token(next_token)? {
+                print!("{t}");
+                std::io::stdout().flush()?;
+            }
+            pos += context_size;
         }
         let dt = start_gen.elapsed();
         println!(
@@ -128,6 +143,10 @@ enum WhichModel {
     V1_5,
     #[value(name = "2")]
     V2,
+    #[value(name = "3")]
+    V3,
+    #[value(name = "3-medium")]
+    V3Medium,
     #[value(name = "2-old")]
     V2Old,
     PuffinPhiV2,
@@ -196,6 +215,10 @@ struct Args {
     /// The context size to consider for the repeat penalty.
     #[arg(long, default_value_t = 64)]
     repeat_last_n: usize,
+
+    /// The dtype to be used for running the model, e.g. f32, bf16, or f16.
+    #[arg(long)]
+    dtype: Option<String>,
 }

 fn main() -> Result<()> {
@@ -236,6 +259,8 @@ fn main() -> Result<()> {
             WhichModel::V1 => "microsoft/phi-1".to_string(),
             WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
             WhichModel::V2 | WhichModel::V2Old => "microsoft/phi-2".to_string(),
+            WhichModel::V3 => "microsoft/Phi-3-mini-4k-instruct".to_string(),
+            WhichModel::V3Medium => "microsoft/Phi-3-medium-4k-instruct".to_string(),
             WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
                 "lmz/candle-quantized-phi".to_string()
             }
@@ -253,9 +278,11 @@ fn main() -> Result<()> {
                 WhichModel::V1 => "refs/pr/8".to_string(),
                 WhichModel::V1_5 => "refs/pr/73".to_string(),
                 WhichModel::V2Old => "834565c23f9b28b96ccbeabe614dd906b6db551a".to_string(),
-                WhichModel::V2 | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
-                    "main".to_string()
-                }
+                WhichModel::V2
+                | WhichModel::V3
+                | WhichModel::V3Medium
+                | WhichModel::PuffinPhiV2
+                | WhichModel::PhiHermes => "main".to_string(),
             }
         }
     }
@@ -264,9 +291,12 @@ fn main() -> Result<()> {
     let tokenizer_filename = match args.tokenizer {
         Some(file) => std::path::PathBuf::from(file),
         None => match args.model {
-            WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 | WhichModel::V2Old => {
-                repo.get("tokenizer.json")?
-            }
+            WhichModel::V1
+            | WhichModel::V1_5
+            | WhichModel::V2
+            | WhichModel::V2Old
+            | WhichModel::V3
+            | WhichModel::V3Medium => repo.get("tokenizer.json")?,
             WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
                 repo.get("tokenizer-puffin-phi-v2.json")?
             }
@@ -282,14 +312,19 @@ fn main() -> Result<()> {
             WhichModel::V2 | WhichModel::V2Old => vec![repo.get("model-v2-q4k.gguf")?],
             WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
             WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
+            WhichModel::V3 | WhichModel::V3Medium => anyhow::bail!(
+                "use the quantized or quantized-phi examples for quantized phi-v3"
+            ),
         }
     } else {
         match args.model {
             WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
-            WhichModel::V2 | WhichModel::V2Old => candle_examples::hub_load_safetensors(
-                &repo,
-                "model.safetensors.index.json",
-            )?,
+            WhichModel::V2 | WhichModel::V2Old | WhichModel::V3 | WhichModel::V3Medium => {
+                candle_examples::hub_load_safetensors(
+                    &repo,
+                    "model.safetensors.index.json",
+                )?
+            }
             WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2.safetensors")?],
             WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B.safetensors")?],
         }
@@ -306,6 +341,9 @@ fn main() -> Result<()> {
         WhichModel::V2 | WhichModel::V2Old => Config::v2(),
         WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
         WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
+        WhichModel::V3 | WhichModel::V3Medium => {
+            panic!("use the quantized or quantized-phi examples for quantized phi-v3")
+        }
     };
     let device = candle_examples::device(args.cpu)?;
     let model = if args.quantized {
@@ -320,7 +358,17 @@ fn main() -> Result<()> {
         };
         Model::Quantized(model)
     } else {
-        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
+        let dtype = match args.dtype {
+            Some(dtype) => std::str::FromStr::from_str(&dtype)?,
+            None => {
+                if args.model == WhichModel::V3 || args.model == WhichModel::V3Medium {
+                    device.bf16_default_to_f32()
+                } else {
+                    DType::F32
+                }
+            }
+        };
+        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
         match args.model {
             WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => {
                 let config_filename = repo.get("config.json")?;
@@ -329,6 +377,13 @@ fn main() -> Result<()> {
                 let phi = Phi::new(&config, vb)?;
                 Model::Phi(phi)
             }
+            WhichModel::V3 | WhichModel::V3Medium => {
+                let config_filename = repo.get("config.json")?;
+                let config = std::fs::read_to_string(config_filename)?;
+                let config: Phi3Config = serde_json::from_str(&config)?;
+                let phi3 = Phi3::new(&config, vb)?;
+                Model::Phi3(phi3)
+            }
             WhichModel::V2Old => {
                 let config = config();
                 Model::MixFormer(MixFormer::new_v2(&config, vb)?)
@@ -421,6 +476,10 @@ fn mmlu<P: AsRef<std::path::Path>>(
                 m.clear_kv_cache();
                 m.forward(&input)?
             }
+            Model::Phi3(m) => {
+                m.clear_kv_cache();
+                m.forward(&input, 0)?
+            }
             Model::Quantized(m) => {
                 m.clear_kv_cache();
                 m.forward(&input)?
325  candle-examples/examples/quantized-phi/main.rs  Normal file
@@ -0,0 +1,325 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::{Parser, ValueEnum};
use std::io::Write;
use tokenizers::Tokenizer;

use candle::quantized::gguf_file;
use candle::Tensor;
use candle_transformers::generation::{LogitsProcessor, Sampling};

use candle_examples::token_output_stream::TokenOutputStream;
use candle_transformers::models::quantized_llama::ModelWeights as Phi3b;
use candle_transformers::models::quantized_phi::ModelWeights as Phi2;
use candle_transformers::models::quantized_phi3::ModelWeights as Phi3;

const DEFAULT_PROMPT: &str = "Write a function to count prime numbers up to N. ";

#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
    #[value(name = "phi-2")]
    Phi2,
    #[value(name = "phi-3")]
    Phi3,
    /// Alternative implementation of phi-3, based on llama.
    #[value(name = "phi-3b")]
    Phi3b,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// GGUF file to load, typically a .gguf file generated by the quantize command from llama.cpp
    #[arg(long)]
    model: Option<String>,

    /// The initial prompt, use 'interactive' for entering multiple prompts in an interactive way
    /// and 'chat' for an interactive model where history of previous prompts and generated tokens
    /// is preserved.
    #[arg(long)]
    prompt: Option<String>,

    /// The length of the sample to generate (in tokens).
    #[arg(short = 'n', long, default_value_t = 1000)]
    sample_len: usize,

    /// The tokenizer config in json format.
    #[arg(long)]
    tokenizer: Option<String>,

    /// The temperature used to generate samples, use 0 for greedy sampling.
    #[arg(long, default_value_t = 0.8)]
    temperature: f64,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// Only sample among the top K samples.
    #[arg(long)]
    top_k: Option<usize>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// Process prompt elements separately.
    #[arg(long)]
    split_prompt: bool,

    /// Run on CPU rather than GPU even if a GPU is available.
    #[arg(long)]
    cpu: bool,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,

    /// The model size to use.
    #[arg(long, default_value = "phi-3b")]
    which: Which,

    #[arg(long)]
    use_flash_attn: bool,
}

impl Args {
    fn tokenizer(&self) -> anyhow::Result<Tokenizer> {
        let tokenizer_path = match &self.tokenizer {
            Some(config) => std::path::PathBuf::from(config),
            None => {
                let api = hf_hub::api::sync::Api::new()?;
                let repo = match self.which {
                    Which::Phi2 => "microsoft/phi-2",
                    Which::Phi3 | Which::Phi3b => "microsoft/Phi-3-mini-4k-instruct",
                };
                let api = api.model(repo.to_string());
                api.get("tokenizer.json")?
            }
        };
        Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)
    }

    fn model(&self) -> anyhow::Result<std::path::PathBuf> {
        let model_path = match &self.model {
            Some(config) => std::path::PathBuf::from(config),
            None => {
                let (repo, filename, revision) = match self.which {
                    Which::Phi2 => ("TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf", "main"),
                    Which::Phi3 => (
                        "microsoft/Phi-3-mini-4k-instruct-gguf",
                        "Phi-3-mini-4k-instruct-q4.gguf",
                        "main",
                    ),
                    Which::Phi3b => (
                        "microsoft/Phi-3-mini-4k-instruct-gguf",
                        "Phi-3-mini-4k-instruct-q4.gguf",
                        "5eef2ce24766d31909c0b269fe90c817a8f263fb",
                    ),
                };
                let api = hf_hub::api::sync::Api::new()?;
                api.repo(hf_hub::Repo::with_revision(
                    repo.to_string(),
                    hf_hub::RepoType::Model,
                    revision.to_string(),
                ))
                .get(filename)?
            }
        };
        Ok(model_path)
    }
}

fn format_size(size_in_bytes: usize) -> String {
    if size_in_bytes < 1_000 {
        format!("{}B", size_in_bytes)
    } else if size_in_bytes < 1_000_000 {
        format!("{:.2}KB", size_in_bytes as f64 / 1e3)
    } else if size_in_bytes < 1_000_000_000 {
        format!("{:.2}MB", size_in_bytes as f64 / 1e6)
    } else {
        format!("{:.2}GB", size_in_bytes as f64 / 1e9)
    }
}
|
||||
|
||||
enum Model {
|
||||
Phi2(Phi2),
|
||||
Phi3(Phi3),
|
||||
Phi3b(Phi3b),
|
||||
}
|
||||
|
||||
impl Model {
|
||||
fn forward(&mut self, xs: &Tensor, pos: usize) -> candle::Result<Tensor> {
|
||||
match self {
|
||||
Self::Phi2(m) => m.forward(xs, pos),
|
||||
Self::Phi3(m) => m.forward(xs, pos),
|
||||
Self::Phi3b(m) => m.forward(xs, pos),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
println!(
|
||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||
args.temperature, args.repeat_penalty, args.repeat_last_n
|
||||
);
|
||||
|
||||
let model_path = args.model()?;
|
||||
let mut file = std::fs::File::open(&model_path)?;
|
||||
let start = std::time::Instant::now();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let mut model = {
|
||||
let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
|
||||
let mut total_size_in_bytes = 0;
|
||||
for (_, tensor) in model.tensor_infos.iter() {
|
||||
let elem_count = tensor.shape.elem_count();
|
||||
total_size_in_bytes +=
|
||||
elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
|
||||
}
|
||||
println!(
|
||||
"loaded {:?} tensors ({}) in {:.2}s",
|
||||
model.tensor_infos.len(),
|
||||
&format_size(total_size_in_bytes),
|
||||
start.elapsed().as_secs_f32(),
|
||||
);
|
||||
match args.which {
|
||||
Which::Phi2 => Model::Phi2(Phi2::from_gguf(model, &mut file, &device)?),
|
||||
Which::Phi3 => Model::Phi3(Phi3::from_gguf(
|
||||
args.use_flash_attn,
|
||||
model,
|
||||
&mut file,
|
||||
&device,
|
||||
)?),
|
||||
Which::Phi3b => Model::Phi3b(Phi3b::from_gguf(model, &mut file, &device)?),
|
||||
}
|
||||
};
|
||||
println!("model built");
|
||||
|
||||
let tokenizer = args.tokenizer()?;
|
||||
let mut tos = TokenOutputStream::new(tokenizer);
|
||||
let prompt_str = args.prompt.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
|
||||
print!("{}", &prompt_str);
|
||||
let tokens = tos
|
||||
.tokenizer()
|
||||
.encode(prompt_str, true)
|
||||
.map_err(anyhow::Error::msg)?;
|
||||
let tokens = tokens.get_ids();
|
||||
let to_sample = args.sample_len.saturating_sub(1);
|
||||
let mut all_tokens = vec![];
|
||||
let mut logits_processor = {
|
||||
let temperature = args.temperature;
|
||||
let sampling = if temperature <= 0. {
|
||||
Sampling::ArgMax
|
||||
} else {
|
||||
match (args.top_k, args.top_p) {
|
||||
(None, None) => Sampling::All { temperature },
|
||||
(Some(k), None) => Sampling::TopK { k, temperature },
|
||||
(None, Some(p)) => Sampling::TopP { p, temperature },
|
||||
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
|
||||
}
|
||||
};
|
||||
LogitsProcessor::from_sampling(args.seed, sampling)
|
||||
};
|
||||
|
||||
let start_prompt_processing = std::time::Instant::now();
|
||||
let mut next_token = if !args.split_prompt {
|
||||
let input = Tensor::new(tokens, &device)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, 0)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
logits_processor.sample(&logits)?
|
||||
} else {
|
||||
let mut next_token = 0;
|
||||
for (pos, token) in tokens.iter().enumerate() {
|
||||
let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, pos)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
next_token = logits_processor.sample(&logits)?
|
||||
}
|
||||
next_token
|
||||
};
|
||||
let prompt_dt = start_prompt_processing.elapsed();
|
||||
all_tokens.push(next_token);
|
||||
if let Some(t) = tos.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
let eos_token = *tos
|
||||
.tokenizer()
|
||||
.get_vocab(true)
|
||||
.get("<|endoftext|>")
|
||||
.unwrap();
|
||||
let start_post_prompt = std::time::Instant::now();
|
||||
let mut sampled = 0;
|
||||
for index in 0..to_sample {
|
||||
let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, tokens.len() + index)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
let logits = if args.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
args.repeat_penalty,
|
||||
&all_tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
next_token = logits_processor.sample(&logits)?;
|
||||
all_tokens.push(next_token);
|
||||
if let Some(t) = tos.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
sampled += 1;
|
||||
if next_token == eos_token {
|
||||
break;
|
||||
};
|
||||
}
|
||||
if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
|
||||
print!("{rest}");
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
let dt = start_post_prompt.elapsed();
|
||||
println!(
|
||||
"\n\n{:4} prompt tokens processed: {:.2} token/s",
|
||||
tokens.len(),
|
||||
tokens.len() as f64 / prompt_dt.as_secs_f64(),
|
||||
);
|
||||
println!(
|
||||
"{sampled:4} tokens generated: {:.2} token/s",
|
||||
sampled as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
Ok(())
|
||||
}
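For reference, a typical invocation of this example is sketched below (an assumption: it is wired up as `quantized-phi` in candle-examples; the flag names and the `phi-3` value come from the `Args` and `Which` definitions above):

```bash
cargo run --example quantized-phi --release -- --which phi-3 --prompt "Write a haiku about tensors."
```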

11
candle-examples/examples/quantized-qwen2-instruct/README.md
Normal file
@ -0,0 +1,11 @@
# candle-quantized-qwen2-instruct

[Qwen2](https://qwenlm.github.io/blog/qwen2/) is an upgraded version of Qwen1.5, released by Alibaba Cloud.

## Running the example

```bash
cargo run --example quantized-qwen2-instruct --release -- --prompt "Write a function to count prime numbers up to N."
```

0.5b, 1.5b, 7b and 72b models are available via the `--which` argument.
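For instance, to pick the 7b variant explicitly (a minimal sketch; `--which` selects the size per the `Args` struct in `main.rs` below, and larger models take correspondingly longer to download and load):

```bash
cargo run --example quantized-qwen2-instruct --release -- --which 7b --prompt "Implement quicksort in Rust."
```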
306
candle-examples/examples/quantized-qwen2-instruct/main.rs
Normal file
306
candle-examples/examples/quantized-qwen2-instruct/main.rs
Normal file
@ -0,0 +1,306 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::{Parser, ValueEnum};
use std::io::Write;
use tokenizers::Tokenizer;

use candle::quantized::gguf_file;
use candle::Tensor;
use candle_transformers::generation::{LogitsProcessor, Sampling};

use candle_examples::token_output_stream::TokenOutputStream;
use candle_transformers::models::quantized_qwen2::ModelWeights as Qwen2;

const DEFAULT_PROMPT: &str = "Write a function to count prime numbers up to N. ";

#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
    #[value(name = "0.5b")]
    W2_0_5b,
    #[value(name = "1.5b")]
    W2_1_5b,
    #[value(name = "7b")]
    W2_7b,
    #[value(name = "72b")]
    W2_72b,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// GGUF file to load, typically a .gguf file generated by the quantize command from llama.cpp
    #[arg(long)]
    model: Option<String>,

    /// The initial prompt, use 'interactive' for entering multiple prompts in an interactive way
    /// and 'chat' for an interactive mode where the history of previous prompts and generated
    /// tokens is preserved.
    #[arg(long)]
    prompt: Option<String>,

    /// The length of the sample to generate (in tokens).
    #[arg(short = 'n', long, default_value_t = 1000)]
    sample_len: usize,

    /// The tokenizer config in json format.
    #[arg(long)]
    tokenizer: Option<String>,

    /// The temperature used to generate samples, use 0 for greedy sampling.
    #[arg(long, default_value_t = 0.8)]
    temperature: f64,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// Only sample among the top K samples.
    #[arg(long)]
    top_k: Option<usize>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// Process prompt elements separately.
    #[arg(long)]
    split_prompt: bool,

    /// Run on CPU rather than GPU even if a GPU is available.
    #[arg(long)]
    cpu: bool,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,

    /// The model size to use.
    #[arg(long, default_value = "0.5b")]
    which: Which,
}

impl Args {
    fn tokenizer(&self) -> anyhow::Result<Tokenizer> {
        let tokenizer_path = match &self.tokenizer {
            Some(config) => std::path::PathBuf::from(config),
            None => {
                let api = hf_hub::api::sync::Api::new()?;
                let repo = match self.which {
                    Which::W2_0_5b => "Qwen/Qwen2-0.5B-Instruct",
                    Which::W2_1_5b => "Qwen/Qwen2-1.5B-Instruct",
                    Which::W2_7b => "Qwen/Qwen2-7B-Instruct",
                    Which::W2_72b => "Qwen/Qwen2-72B-Instruct",
                };
                let api = api.model(repo.to_string());
                api.get("tokenizer.json")?
            }
        };
        Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)
    }

    fn model(&self) -> anyhow::Result<std::path::PathBuf> {
        let model_path = match &self.model {
            Some(config) => std::path::PathBuf::from(config),
            None => {
                let (repo, filename, revision) = match self.which {
                    Which::W2_0_5b => (
                        "Qwen/Qwen2-0.5B-Instruct-GGUF",
                        "qwen2-0_5b-instruct-q4_0.gguf",
                        "main",
                    ),
                    Which::W2_1_5b => (
                        "Qwen/Qwen2-1.5B-Instruct-GGUF",
                        "qwen2-1_5b-instruct-q4_0.gguf",
                        "main",
                    ),
                    Which::W2_7b => (
                        "Qwen/Qwen2-7B-Instruct-GGUF",
                        "qwen2-7b-instruct-q4_0.gguf",
                        "main",
                    ),
                    Which::W2_72b => (
                        "Qwen/Qwen2-72B-Instruct-GGUF",
                        "qwen2-72b-instruct-q4_0.gguf",
                        "main",
                    ),
                };
                let api = hf_hub::api::sync::Api::new()?;
                api.repo(hf_hub::Repo::with_revision(
                    repo.to_string(),
                    hf_hub::RepoType::Model,
                    revision.to_string(),
                ))
                .get(filename)?
            }
        };
        Ok(model_path)
    }
}

fn format_size(size_in_bytes: usize) -> String {
    if size_in_bytes < 1_000 {
        format!("{}B", size_in_bytes)
    } else if size_in_bytes < 1_000_000 {
        format!("{:.2}KB", size_in_bytes as f64 / 1e3)
    } else if size_in_bytes < 1_000_000_000 {
        format!("{:.2}MB", size_in_bytes as f64 / 1e6)
    } else {
        format!("{:.2}GB", size_in_bytes as f64 / 1e9)
    }
}

fn main() -> anyhow::Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature, args.repeat_penalty, args.repeat_last_n
    );

    let model_path = args.model()?;
    let mut file = std::fs::File::open(&model_path)?;
    let start = std::time::Instant::now();
    let device = candle_examples::device(args.cpu)?;

    let mut model = {
        let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
        let mut total_size_in_bytes = 0;
        for (_, tensor) in model.tensor_infos.iter() {
            let elem_count = tensor.shape.elem_count();
            total_size_in_bytes +=
                elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
        }
        println!(
            "loaded {:?} tensors ({}) in {:.2}s",
            model.tensor_infos.len(),
            &format_size(total_size_in_bytes),
            start.elapsed().as_secs_f32(),
        );
        Qwen2::from_gguf(model, &mut file, &device)?
    };
    println!("model built");

    let tokenizer = args.tokenizer()?;
    let mut tos = TokenOutputStream::new(tokenizer);
    let prompt_str = args.prompt.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
    let prompt_str = format!(
        "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
        prompt_str
    );
    print!("formatted instruct prompt: {}", &prompt_str);
    let tokens = tos
        .tokenizer()
        .encode(prompt_str, true)
        .map_err(anyhow::Error::msg)?;
    let tokens = tokens.get_ids();
    let to_sample = args.sample_len.saturating_sub(1);
    let mut all_tokens = vec![];
    let mut logits_processor = {
        let temperature = args.temperature;
        let sampling = if temperature <= 0. {
            Sampling::ArgMax
        } else {
            match (args.top_k, args.top_p) {
                (None, None) => Sampling::All { temperature },
                (Some(k), None) => Sampling::TopK { k, temperature },
                (None, Some(p)) => Sampling::TopP { p, temperature },
                (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
            }
        };
        LogitsProcessor::from_sampling(args.seed, sampling)
    };
    let start_prompt_processing = std::time::Instant::now();
    let mut next_token = if !args.split_prompt {
        let input = Tensor::new(tokens, &device)?.unsqueeze(0)?;
        let logits = model.forward(&input, 0)?;
        let logits = logits.squeeze(0)?;
        logits_processor.sample(&logits)?
    } else {
        let mut next_token = 0;
        for (pos, token) in tokens.iter().enumerate() {
            let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
            let logits = model.forward(&input, pos)?;
            let logits = logits.squeeze(0)?;
            next_token = logits_processor.sample(&logits)?
        }
        next_token
    };
    let prompt_dt = start_prompt_processing.elapsed();
    all_tokens.push(next_token);
    if let Some(t) = tos.next_token(next_token)? {
        print!("{t}");
        std::io::stdout().flush()?;
    }
    let eos_token = *tos.tokenizer().get_vocab(true).get("<|im_end|>").unwrap();
    let start_post_prompt = std::time::Instant::now();
    let mut sampled = 0;
    for index in 0..to_sample {
        let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
        let logits = model.forward(&input, tokens.len() + index)?;
        let logits = logits.squeeze(0)?;
        let logits = if args.repeat_penalty == 1. {
            logits
        } else {
            let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
            candle_transformers::utils::apply_repeat_penalty(
                &logits,
                args.repeat_penalty,
                &all_tokens[start_at..],
            )?
        };
        next_token = logits_processor.sample(&logits)?;
        all_tokens.push(next_token);
        if let Some(t) = tos.next_token(next_token)? {
            print!("{t}");
            std::io::stdout().flush()?;
        }
        sampled += 1;
        if next_token == eos_token {
            break;
        };
    }
    if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
        print!("{rest}");
    }
    std::io::stdout().flush()?;
    let dt = start_post_prompt.elapsed();
    println!(
        "\n\n{:4} prompt tokens processed: {:.2} token/s",
        tokens.len(),
        tokens.len() as f64 / prompt_dt.as_secs_f64(),
    );
    println!(
        "{sampled:4} tokens generated: {:.2} token/s",
        sampled as f64 / dt.as_secs_f64(),
    );
    Ok(())
}
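Note that this example wraps the raw prompt in Qwen2's ChatML-style instruct template before tokenization; for the default prompt, the string handed to the tokenizer looks roughly like this (a sketch, ignoring any tokenizer-side special handling):

```text
<|im_start|>user
Write a function to count prime numbers up to N. <|im_end|>
<|im_start|>assistant
```

Generation then stops at `<|im_end|>`, which the code looks up in the vocabulary as the end-of-turn token.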
@ -67,8 +67,10 @@ enum Which {
    Mixtral,
    #[value(name = "mixtral-instruct")]
    MixtralInstruct,
    #[value(name = "phi-2")]
    Phi2,
    #[value(name = "llama3-8b")]
    L8b,
    #[value(name = "phi3")]
    Phi3,
}

impl Which {
@ -85,7 +87,8 @@ impl Which {
            | Self::L34bCode
            | Self::Leo7b
            | Self::Leo13b
            | Self::Phi2 => false,
            | Self::L8b
            | Self::Phi3 => false,
            // Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
            // same way. Starling is a fine tuned version of OpenChat.
            Self::OpenChat35
@ -119,8 +122,9 @@ impl Which {
            | Self::Mistral7bInstruct
            | Self::Mistral7bInstructV02
            | Self::OpenChat35
            | Self::Phi2
            | Self::Starling7bAlpha => false,
            | Self::Starling7bAlpha
            | Self::L8b
            | Self::Phi3 => false,
            Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
        }
    }
@ -143,9 +147,10 @@ impl Which {
            | Self::Mistral7b
            | Self::Mistral7bInstruct
            | Self::Mistral7bInstructV02
            | Self::Phi2
            | Self::Zephyr7bAlpha
            | Self::Zephyr7bBeta => false,
            | Self::Zephyr7bBeta
            | Self::L8b
            | Self::Phi3 => false,
            Self::OpenChat35 | Self::Starling7bAlpha => true,
        }
    }
@ -172,7 +177,8 @@ impl Which {
            | Self::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
            Self::OpenChat35 => "openchat/openchat_3.5",
            Self::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
            Self::Phi2 => "microsoft/phi-2",
            Self::L8b => "meta-llama/Meta-Llama-3-8B",
            Self::Phi3 => "microsoft/Phi-3-mini-4k-instruct",
        }
    }
}
@ -328,11 +334,28 @@ impl Args {
                        "TheBloke/Starling-LM-7B-alpha-GGUF",
                        "starling-lm-7b-alpha.Q4_K_M.gguf",
                    ),
                    Which::Phi2 => ("TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf"),
                    // TODO: swap to TheBloke model when available
                    Which::L8b => (
                        "QuantFactory/Meta-Llama-3-8B-GGUF",
                        "Meta-Llama-3-8B.Q4_K_S.gguf",
                    ),
                    Which::Phi3 => (
                        "microsoft/Phi-3-mini-4k-instruct-gguf",
                        "Phi-3-mini-4k-instruct-q4.gguf",
                    ),
                };
                let revision = if self.which == Which::Phi3 {
                    "5eef2ce24766d31909c0b269fe90c817a8f263fb"
                } else {
                    "main"
                };
                let api = hf_hub::api::sync::Api::new()?;
                let api = api.model(repo.to_string());
                api.get(filename)?
                api.repo(hf_hub::Repo::with_revision(
                    repo.to_string(),
                    hf_hub::RepoType::Model,
                    revision.to_string(),
                ))
                .get(filename)?
            }
        };
        Ok(model_path)
@ -360,6 +383,9 @@ fn main() -> anyhow::Result<()> {
    #[cfg(feature = "cuda")]
    candle::quantized::cuda::set_force_dmmv(args.force_dmmv);

    candle::cuda::set_gemm_reduced_precision_f16(true);
    candle::cuda::set_gemm_reduced_precision_bf16(true);

    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
@ -428,7 +454,8 @@ fn main() -> anyhow::Result<()> {
        | Which::L34bCode
        | Which::Leo7b
        | Which::Leo13b
        | Which::Phi2 => 1,
        | Which::L8b
        | Which::Phi3 => 1,
        Which::Mixtral
        | Which::MixtralInstruct
        | Which::Mistral7b
@ -545,11 +572,14 @@ fn main() -> anyhow::Result<()> {
        std::io::stdout().flush()?;
    }

    let eos_token = if args.which.is_open_chat() {
        "<|end_of_turn|>"
    } else {
        "</s>"
    let eos_token = match args.which {
        Which::L8b => "<|end_of_text|>",
        _ => match args.which.is_open_chat() {
            true => "<|end_of_turn|>",
            false => "</s>",
        },
    };

    let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap();
    let start_post_prompt = std::time::Instant::now();
    let mut sampled = 0;

@ -7,7 +7,7 @@ extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;

use candle_transformers::models::qwen2::{Config as ConfigBase, Model as ModelBase};
use candle_transformers::models::qwen2::{Config as ConfigBase, ModelForCausalLM as ModelBase};
use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe};

use candle::{DType, Device, Tensor};
@ -144,6 +144,14 @@ enum WhichModel {
    W72b,
    #[value(name = "moe-a2.7b")]
    MoeA27b,
    #[value(name = "2-0.5b")]
    W2_0_5b,
    #[value(name = "2-1.5b")]
    W2_1_5b,
    #[value(name = "2-7b")]
    W2_7b,
    #[value(name = "2-72b")]
    W2_72b,
}

#[derive(Parser, Debug)]
@ -234,16 +242,20 @@ fn main() -> Result<()> {
    let model_id = match args.model_id {
        Some(model_id) => model_id,
        None => {
            let size = match args.model {
                WhichModel::W0_5b => "0.5B",
                WhichModel::W1_8b => "1.8B",
                WhichModel::W4b => "4B",
                WhichModel::W7b => "7B",
                WhichModel::W14b => "14B",
                WhichModel::W72b => "72B",
                WhichModel::MoeA27b => "MoE-A2.7B",
            let (version, size) = match args.model {
                WhichModel::W2_0_5b => ("2", "0.5B"),
                WhichModel::W2_1_5b => ("2", "1.5B"),
                WhichModel::W2_7b => ("2", "7B"),
                WhichModel::W2_72b => ("2", "72B"),
                WhichModel::W0_5b => ("1.5", "0.5B"),
                WhichModel::W1_8b => ("1.5", "1.8B"),
                WhichModel::W4b => ("1.5", "4B"),
                WhichModel::W7b => ("1.5", "7B"),
                WhichModel::W14b => ("1.5", "14B"),
                WhichModel::W72b => ("1.5", "72B"),
                WhichModel::MoeA27b => ("1.5", "MoE-A2.7B"),
            };
            format!("Qwen/Qwen1.5-{size}")
            format!("Qwen/Qwen{version}-{size}")
        }
    };
    let repo = api.repo(Repo::with_revision(
@ -261,11 +273,15 @@ fn main() -> Result<()> {
            .map(std::path::PathBuf::from)
            .collect::<Vec<_>>(),
        None => match args.model {
            WhichModel::W0_5b | WhichModel::W1_8b => vec![repo.get("model.safetensors")?],
            WhichModel::W0_5b | WhichModel::W2_0_5b | WhichModel::W2_1_5b | WhichModel::W1_8b => {
                vec![repo.get("model.safetensors")?]
            }
            WhichModel::W4b
            | WhichModel::W7b
            | WhichModel::W2_7b
            | WhichModel::W14b
            | WhichModel::W72b
            | WhichModel::W2_72b
            | WhichModel::MoeA27b => {
                candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
            }

@ -39,7 +39,7 @@ struct Args {

    /// The detection threshold for the mask, 0 is the default value, negative values mean a larger
    /// mask, positive makes the mask more selective.
    #[arg(long, default_value_t = 0.)]
    #[arg(long, allow_hyphen_values = true, default_value_t = 0.)]
    threshold: f32,

    /// Enable tracing (generates a trace-timestamp.json file).
@ -139,7 +139,7 @@ pub fn main() -> anyhow::Result<()> {
    let (_one, h, w) = mask.dims3()?;
    let mask = mask.expand((3, h, w))?;

    let mut img = image::io::Reader::open(&args.image)?
    let mut img = image::ImageReader::open(&args.image)?
        .decode()
        .map_err(candle::Error::wrap)?;
    let mask_pixels = mask.permute((1, 2, 0))?.flatten_all()?.to_vec1::<u8>()?;

@ -380,7 +380,7 @@ fn text_embeddings(
}

fn image_preprocess<T: AsRef<std::path::Path>>(path: T) -> anyhow::Result<Tensor> {
    let img = image::io::Reader::open(path)?.decode()?;
    let img = image::ImageReader::open(path)?.decode()?;
    let (height, width) = (img.height() as usize, img.width() as usize);
    let height = height - height % 32;
    let width = width - width % 32;

@ -145,7 +145,7 @@ impl ViTImageProcessor {
    pub fn load_images(&self, image_path: Vec<&str>) -> Result<Vec<image::DynamicImage>> {
        let mut images: Vec<image::DynamicImage> = Vec::new();
        for path in image_path {
            let img = image::io::Reader::open(path)?.decode().unwrap();
            let img = image::ImageReader::open(path)?.decode().unwrap();
            images.push(img);
        }

@ -13,7 +13,7 @@ struct Block {

impl Block {
    fn get(&self, key: &str) -> Result<&str> {
        match self.parameters.get(&key.to_string()) {
        match self.parameters.get(key) {
            None => candle::bail!("cannot find {} in {}", key, self.block_type),
            Some(value) => Ok(value),
        }
@ -28,7 +28,7 @@ pub struct Darknet {

impl Darknet {
    fn get(&self, key: &str) -> Result<&str> {
        match self.parameters.get(&key.to_string()) {
        match self.parameters.get(key) {
            None => candle::bail!("cannot find {} in net parameters", key),
            Some(value) => Ok(value),
        }
@ -272,7 +272,7 @@ impl Darknet {
        let mut prev_channels: usize = 3;
        for (index, block) in self.blocks.iter().enumerate() {
            let channels_and_bl = match block.block_type.as_str() {
                "convolutional" => conv(vb.pp(&index.to_string()), index, prev_channels, block)?,
                "convolutional" => conv(vb.pp(index.to_string()), index, prev_channels, block)?,
                "upsample" => upsample(prev_channels)?,
                "shortcut" => shortcut(index, prev_channels, block)?,
                "route" => route(index, &blocks, block)?,

@ -159,7 +159,7 @@ pub fn main() -> Result<()> {
    let net_width = darknet.width()?;
    let net_height = darknet.height()?;

    let original_image = image::io::Reader::open(&image_name)?
    let original_image = image::ImageReader::open(&image_name)?
        .decode()
        .map_err(candle::Error::wrap)?;
    let image = {
Some files were not shown because too many files have changed in this diff.