Compare commits

...

40 Commits

Author SHA1 Message Date
17313a4226 Fix cuda memory error for Qwen3 non-quantized (#2987)
* Update KvCache initialization in Qwen3 model to use a fixed max position embedding value of 512

* add doc
2025-06-07 16:02:58 +02:00
0224a749f0 Add Qwen3 MoE (#2934)
* qwen-moe rebase

* lint

* fixed rebase error

* swapped normal MoE model with CausalMoE Model in example, and swapped the tie word embeddings if statement

* updated readme
2025-05-31 15:33:28 +02:00
cd7b877d6b candle-onnx: Implement Trilu and ScatterND ops (#2952)
* onnx attention

* setup an example, adding and fixing onnx ops bit by bit

* model working, output is garbage data

* trilu working

* close but not quite, Issues still with scatterND

* closer but the outputs are still slightly wrong

* added tests for trilu and scatterND

* lint

* readme

* clippy

* removed unnecessary comments

* changed device selection, took hyperparameters from model config
2025-05-30 07:36:09 +02:00
5aed817f1b feat: enhance linear algebra operations (#2972)
- Add `dot()` for vector/matrix products
- Implement the `Frobenius` norm
- Add `mv()` for matrix-vector multiply
2025-05-29 09:41:01 +02:00
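The new ops named in the commit above can be illustrated with a short sketch; the exact signatures of `dot` and `mv` are assumptions based on the commit summary, not on the diff shown on this page.

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::Cpu;

    // Vector dot product: 1*4 + 2*5 + 3*6 = 32 (assumed signature: a.dot(&b)).
    let a = Tensor::new(&[1f32, 2., 3.], &dev)?;
    let b = Tensor::new(&[4f32, 5., 6.], &dev)?;
    println!("dot = {}", a.dot(&b)?);

    // Matrix-vector product (assumed signature: m.mv(&v)).
    let m = Tensor::new(&[[1f32, 0.], [0., 2.]], &dev)?;
    let v = Tensor::new(&[3f32, 4.], &dev)?;
    println!("mv = {}", m.mv(&v)?);
    Ok(())
}
```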
1a183c988a Add fine-tuned text classifier to xlm roberta example (#2969) 2025-05-28 06:17:07 +02:00
cac51fe16a (hotfix) fix the doc test for indexer (#2970) 2025-05-28 06:13:26 +02:00
61ddb9535e Use a tanh activation in the xlm-roberta classification head. (#2968) 2025-05-26 08:54:31 +02:00
9a62c91643 Proper support for phi-4 (#2960)
* Add phi-4 support.

* Long-rope support.

* Get clippy to be happy.
2025-05-21 10:18:33 +02:00
92106c8762 Fixes for clippy 1.87. (#2956) 2025-05-15 21:50:27 +02:00
9ce4fe6194 Fix docs quantized qwen3 (#2955)
* fixed docs quantized-qwen3 README

* fixed docs quantized-qwen2-instruct README
2025-05-15 07:58:03 +02:00
450a49ed1a Olmo 2 model (#2954)
* OLMo 2 model

* Update olmo-2 to example

* Clippy fix.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2025-05-14 19:18:02 +02:00
6bd61727bc Make tensor contiguous before the repeat_kv calls to avoid strided copies (#2953) 2025-05-14 10:47:28 +02:00
485ddf2996 Fixed Quantized Qwen3 Model (#2951)
* optimize KV cache to reduce GPU memory usage

* revert to using candle_nn::kv_cache::KvCache with initial capacity of 512
2025-05-13 05:53:42 +02:00
36508a2c93 Add Resize to onnx ops (#2946)
* added resize to candle-onnx, not currently working

* changed unreachable to bail, and bailed when both scales and sizes are set

* cleanup and added other unused options for this op

* cleanup

* fixed image loading to make output work

* cleanup and removed unused variables

* removed path creation code, and changed unwrap to ?
2025-05-10 07:05:03 +02:00
3d05f5cf3d Qwen3 quantized implementation (#2939)
* fixed quantized_phi3 implementation

* quantized_qwen3 implementation

* Update quantized_phi3.rs

* Update quantized_phi3.rs

* add quantized_qwen3 example

* Clippy fixes.

* Cleanup.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2025-05-08 15:06:10 +02:00
637473cb5e Bump cudarc to 0.16.3. (#2942) 2025-05-04 09:14:28 +02:00
e27b4700ad Indexing with max-value results in zero/no-op. (#2940)
* Indexing with max-value results in zero/no-op.

* Add some testing.

* Also adapt the metal kernels.

* Another test.

* Fix.
2025-05-03 11:36:31 +02:00
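The CPU gather/index-select changes further down in this diff implement the behaviour described in the commit above by writing zeros whenever an index equals the integer dtype's maximum value. A minimal sketch of the resulting behaviour (the concrete values are illustrative, not the tests added in this PR):

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::Cpu;
    let src = Tensor::new(&[10f32, 20., 30.], &dev)?;
    // u32::MAX acts as a "no-op" index: the corresponding output element is
    // zero instead of triggering an out-of-bounds error.
    let ids = Tensor::new(&[0u32, u32::MAX, 2], &dev)?;
    let out = src.index_select(&ids, 0)?;
    assert_eq!(out.to_vec1::<f32>()?, [10., 0., 30.]);
    Ok(())
}
```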
1fdfb58de5 Updating Add qwen3 (PR 2903) to use HF weights (#2930)
* add Qwen3.rs

* fixed compile error

* attempting to get PR 2903 working with qwen weights

* different qwen variants working

* added moe model

* clippy

* added additional eos token

* translated Korean comments to English as well as I can

* removed specialized Qwen3RmsNorm and replaced with generic Candle RmsNorm

* replaced custom repeat_kv implementation with candle's repeat_kv implementation

* replace linear with linear_b in attention initialization

* replaced custom kv_cache implementation with candle kv_cache

* style

* replaced explicit broadcast add with normal add in decoder layer

* removed keeping the Rotary embedding layer in the model struct

* used tie_word_embeddings bool from config instead of relying on existence of weights for lm head in CausalLM

* removed duplicate code from qwen3_moe

* removed sliding window from qwen3 attention

* removed MoE code

* removed unused option

* Fixed Typo

Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>

* fixed tie word embeddings to use the correct embedding weights instead of the opposite

---------

Co-authored-by: Max <naturale@hufs.ac.kr>
Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>
2025-05-02 06:05:53 +02:00
cd96fa80da Add a scattered kv cache. (#2936)
* Add a scattered kv cache.

* Update some comments.
2025-05-01 10:20:48 +02:00
8a19bb7df2 Bump the candle version to 0.9.1. (#2935) 2025-05-01 10:08:16 +02:00
38fc86621c Add support for Helium-v1. (#2932) 2025-04-30 19:38:44 +02:00
5029ac52bb Added tracing page to the candle book. (#2922)
* tracing page

* warned about asynchronous execution

* cleanup

* added Nsight Systems recommendation
2025-04-29 21:35:36 +02:00
de23d34a28 Switch Tensor::full to return a contiguous tensor. (#2929) 2025-04-28 21:36:39 +02:00
d4bac37a61 Fix the gumbel softmax by casting to f32. (#2928) 2025-04-28 19:48:51 +02:00
e98754fc5a Optimize Tensor::new when called on nested Vec<..>. (#2927)
* Optimize Tensor::new when called on nested Vec<..>.

* Improve performance.

* Similar flattening for the 4d case.

* More tweaks.

* Add some dummy test.
2025-04-28 09:19:45 +02:00
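The `NdArray` implementations for nested `Vec`s that appear further down in this diff flatten the data in a single pass instead of concatenating intermediary storages. A small usage sketch of the commit above (values are illustrative):

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::Cpu;
    // A Vec<Vec<f32>> maps to a rank-2 tensor; all inner vectors must have
    // the same length, otherwise construction fails.
    let rows: Vec<Vec<f32>> = vec![vec![1., 2., 3.], vec![4., 5., 6.]];
    let t = Tensor::new(rows, &dev)?;
    assert_eq!(t.dims(), &[2, 3]);
    Ok(())
}
```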
e3db30021f Support for "unbatched" rope. (#2926)
* Support for (un)-batched rope.

* Use 3d rope in the rope/ropei/rope_thd functions.

* Get the CPU versions to work.

* Fix the cuda version.

* Adapt the metal side.

* Fix the metal tests.
2025-04-27 15:12:02 +02:00
6e0646c208 Remove redundant mlx gemm dtype check (#2925) 2025-04-27 06:14:57 +02:00
fbaf0b0e32 Bump the crate version to 0.9.0. (#2924) 2025-04-26 11:01:21 +02:00
a2e925462c Add the scatter in place ops. (#2923)
* Add the scatter_set op.

* Metal op.

* Cuda version.

* Merge the checks.

* Add the actual ops.
2025-04-26 07:36:49 +02:00
3827685524 Add the scatter op. (#2921)
* Add the scatter op.

* Backprop support.

* Cuda support.
2025-04-25 21:46:58 +02:00
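Based on the backprop code further down in this diff (`mask.scatter(indexes, &mask.zeros_like()?, *dim)?`), the new op takes an index tensor, a source tensor, and a dimension. A hedged usage sketch with illustrative values:

```rust
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::Cpu;
    let init = Tensor::zeros((2, 4), DType::F32, &dev)?;
    let src = Tensor::new(&[[1f32, 2., 3., 4.], [5., 6., 7., 8.]], &dev)?;
    let ids = Tensor::new(&[[0u32, 1, 2, 3], [3, 2, 1, 0]], &dev)?;
    // scatter writes src values into init at the positions given by ids along
    // dim 1; the pre-existing scatter_add accumulates into them instead.
    let out = init.scatter(&ids, &src, 1)?;
    println!("{out}");
    Ok(())
}
```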
3aeb9575c7 Fixed Quantized Gemma3 Model and example (#2918)
* removed scale factor from computation and made quantized gemma3 work similarly to non-quantized gemma3

* created default consts, replaced is_sliding with Option holding a window_size
2025-04-25 05:47:48 +02:00
6ff0a6999c Fixed Gemma3 model and example (#2917)
* gemma3: changed RotaryEmbedding base freq based on layer and sliding window

* Changed attention mask per layer, either normal or sliding

* made attention mask creation slightly more efficient by only creating them once per model iteration

* changed is_sliding to an Option

* clippy

* changed to stop on both <eos> and <end_of_turn> instead of either or
2025-04-25 05:35:08 +02:00
82def7ae38 Cudarc update. (#2915) 2025-04-23 07:03:26 +02:00
99bd69f383 fixed quantized-gemma example (#2914)
* fixed quantized-gemma example

* lint
2025-04-23 05:39:03 +02:00
a4c56a958e Add the const-set op. (#2910)
* Add the const-set op.

* Cuda implementation.

* Bugfix.

* Metal cleanup.

* Add the metal kernels.

* Add some testing.

* Finish the metal implementation.

* Bump the version.
2025-04-19 10:07:02 +02:00
b2904a830b implemented quantized-gemma3 (#2902)
* implemented quantized-gemma, inference not working

* Fixed a few modeling bugs: the model was outputting the correct tokens for a few iterations, then garbage

* lint

* clippy

* quantized-gemma3 example working

* added readme

* clippy
2025-04-19 07:46:41 +02:00
21055b5697 Add PRelu operation (#2904)
* Add PRelu operation

* Apply rustfmt.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2025-04-19 07:24:10 +02:00
9dbaf958dc Add an enum for scalar values. (#2909)
* Add a scalar enum type.

* Add a bit more to the scalar type.

* Small tweak.

* More scalar usage.
2025-04-18 22:13:38 +02:00
ce5f8dd129 Check the bounds in the cuda indexing kernels. (#2908)
* Check the bounds in the cuda indexing kernels.

* Another check.
2025-04-18 20:08:17 +02:00
9954981327 Allow from_vec/from_slice to use a ShapeWithOneHole as shape. (#2905) 2025-04-17 08:59:18 +02:00
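`ShapeWithOneHole` is the same mechanism `reshape` already uses, where `()` marks the dimension to infer; assuming the commit above makes `from_vec` behave the same way, a sketch could look like:

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::Cpu;
    let data: Vec<f32> = (0..12).map(|v| v as f32).collect();
    // `()` marks the hole: 12 elements with a trailing dim of 3 gives (4, 3).
    let t = Tensor::from_vec(data, ((), 3), &dev)?;
    assert_eq!(t.dims(), &[4, 3]);
    Ok(())
}
```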
80 changed files with 6566 additions and 633 deletions

View File

@ -20,7 +20,7 @@ exclude = [
resolver = "2"
[workspace.package]
version = "0.9.0-alpha.4"
version = "0.9.1"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@ -33,17 +33,17 @@ ab_glyph = "0.2.23"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.9.0-alpha.4" }
candle-datasets = { path = "./candle-datasets", version = "0.9.0-alpha.4" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0-alpha.4" }
candle-kernels = { path = "./candle-kernels", version = "0.9.0-alpha.4" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0-alpha.4" }
candle-nn = { path = "./candle-nn", version = "0.9.0-alpha.4" }
candle-onnx = { path = "./candle-onnx", version = "0.9.0-alpha.4" }
candle-transformers = { path = "./candle-transformers", version = "0.9.0-alpha.4" }
candle = { path = "./candle-core", package = "candle-core", version = "0.9.1" }
candle-datasets = { path = "./candle-datasets", version = "0.9.1" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.1" }
candle-kernels = { path = "./candle-kernels", version = "0.9.1" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.1" }
candle-nn = { path = "./candle-nn", version = "0.9.1" }
candle-onnx = { path = "./candle-onnx", version = "0.9.1" }
candle-transformers = { path = "./candle-transformers", version = "0.9.1" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.16.0", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
cudarc = { version = "0.16.3", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = "0.4.1"

View File

@ -16,6 +16,7 @@
- [Running a model](inference/inference.md)
- [Using the hub](inference/hub.md)
- [Error management](error_manage.md)
- [Tracing](tracing.md)
- [Training](training/training.md)
- [Simplified](training/simplified.md)
- [MNIST](training/mnist.md)

View File

@ -0,0 +1,68 @@
# Tracing
Tracing is a powerful tool for identifying performance issues and bottlenecks in code.
> Profiling on GPUs is trickier due to asynchronous execution, see the [GPU section](#gpu).
## Overview
Candle uses the [tracing](https://docs.rs/tracing/latest/tracing/) crate for instrumentation.
To try it out, run an example in `candle-examples` with the `--tracing` flag.
This generates a trace file, typically named `trace-<timestamp>.json`.
You can view the trace in Chrome by navigating to `chrome://tracing/`, clicking **Load**, and selecting the generated trace file.
## Adding Tracing
Candle includes built-in tracing for many internal operations, using [spans](https://docs.rs/tracing/latest/tracing/struct.Span.html) to mark key points of execution.
To add custom tracing in your code, you can define a span like this:
```rust
let span = tracing::span!(tracing::Level::TRACE, name);
```
Then, to record the span during execution, create a guard:
```rust
let _enter = span.enter();
```
This guard will record the span's duration, from when it is created to when it is dropped, into a global data structure managed by the tracing crate.
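Putting the two snippets together, a minimal sketch of instrumenting a layer (the struct and span name below are illustrative, assuming `tracing` is already a dependency):
```rust
struct MyLayer {
    span: tracing::Span,
}

impl MyLayer {
    fn new() -> Self {
        // The span is created once and stored on the layer.
        let span = tracing::span!(tracing::Level::TRACE, "my-layer");
        Self { span }
    }

    fn forward(&self) {
        // The guard records the time between its creation and its drop,
        // i.e. the duration of this forward call.
        let _enter = self.span.enter();
        // ... actual computation goes here ...
    }
}
```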
## Recording and Saving a Trace
To capture and save trace data, you need to configure the tracing system with an output format. Candle uses the [tracing_subscriber](https://docs.rs/tracing-subscriber/latest/tracing_subscriber/) and [tracing_chrome](https://docs.rs/tracing-chrome/latest/tracing_chrome/) crates.
The snippet below sets up a Chrome-compatible recorder that logs all tracing activity between creation and drop of the guard:
```rust
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let _guard = {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
guard
};
```
## GPU
When using CUDA, Metal, or other asynchronous GPU backends, tracing may produce misleading timing data because operations are queued rather than executed immediately.
### CUDA
For CUDA-specific profiling, you have two options:
1. Set the environment variable `CUDA_LAUNCH_BLOCKING=1` which forces synchronous execution. This makes trace timings more accurate, at the cost of reduced performance.
2. Use [NVIDIA's Nsight Systems](https://developer.nvidia.com/nsight-systems) (`nsys profile` and `nsys-ui`) which are designed specifically for profiling asynchronous CUDA executions.
We recommend using NVIDIA's Nsight Systems when possible, as it offers accurate performance data without altering typical execution patterns. In contrast, setting the `CUDA_LAUNCH_BLOCKING` environment variable forces synchronous execution, which can significantly alter execution behavior.
#### Performance Profiling with NVIDIA Nsight Systems
1. Generate an `.nsys-rep` file containing performance data ([docs](https://docs.nvidia.com/nsight-systems/UserGuide/index.html#example-single-command-lines))
- Run `nsys profile --trace cuda,nvtx,osrt --gpu-metrics-device=all --output profile_run ./target/debug/... --prompt "whatever "`
1. Open the generated `.nsys-rep` report file in Nsight Systems GUI
- File > Open

View File

@ -4,11 +4,12 @@ use criterion::criterion_main;
criterion_main!(
benchmarks::affine::benches,
benchmarks::copy::benches,
benchmarks::conv_transpose2d::benches,
benchmarks::matmul::benches,
benchmarks::qmatmul::benches,
benchmarks::random::benches,
benchmarks::reduce::benches,
benchmarks::unary::benches,
benchmarks::where_cond::benches,
benchmarks::conv_transpose2d::benches,
benchmarks::qmatmul::benches,
benchmarks::unary::benches
);

View File

@ -0,0 +1,38 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{Device, Tensor, WithDType};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn run_copy_mask_benchmark<D: WithDType>(c: &mut Criterion, device: &Device, name: &str) {
let batch_size = 128;
let in_seq_len = 1;
let kv_seq_len = 1024;
let attn_mask = vec![vec![vec![D::zero(); kv_seq_len]; in_seq_len]; batch_size];
let size_in_bytes = batch_size * in_seq_len * kv_seq_len * D::DTYPE.size_in_bytes();
let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(size_in_bytes as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let attn_masks = vec![attn_mask.clone(); iters as usize];
let start = Instant::now();
for attn_mask in attn_masks.into_iter() {
let tensor = Tensor::new(black_box(attn_mask), device).unwrap();
black_box(tensor);
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
run_copy_mask_benchmark::<f32>(c, &device, "copy_mask");
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -1,5 +1,6 @@
pub(crate) mod affine;
pub(crate) mod conv_transpose2d;
pub(crate) mod copy;
pub(crate) mod matmul;
pub(crate) mod qmatmul;
pub(crate) mod random;

View File

@ -71,15 +71,27 @@ pub trait BackendStorage: Sized {
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
fn scatter_add(
&self,
fn scatter_set(
&mut self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self>;
) -> Result<()>;
fn scatter_add_set(
&mut self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<()>;
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self>;
fn index_add(
&self,
@ -113,6 +125,8 @@ pub trait BackendStorage: Sized {
_src_offset: usize,
_dst_offset: usize,
) -> Result<()>;
fn const_set(&mut self, _: crate::scalar::Scalar, _: &Layout) -> Result<()>;
}
pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
@ -127,8 +141,6 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
/// # Safety
/// This function is unsafe as it doesn't initialize the underlying data store.
/// The caller should ensure that the data is properly initialized as early as possible

View File

@ -53,6 +53,7 @@ impl Tensor {
} else if let Some(op) = node.op() {
match op {
Op::IndexAdd(t1, t2, t3, _)
| Op::Scatter(t1, t2, t3, _)
| Op::ScatterAdd(t1, t2, t3, _)
| Op::CustomOp3(t1, t2, t3, _)
| Op::WhereCond(t1, t2, t3) => {
@ -419,7 +420,7 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
}
Op::ScatterAdd(init, indexes, src, dim) => {
Op::Scatter(init, indexes, src, dim) => {
let init_sum_grad = grads.or_insert(init)?;
*init_sum_grad = init_sum_grad.add(&grad)?;
@ -427,6 +428,16 @@ impl Tensor {
let src_sum_grad = grads.or_insert(src)?;
*src_sum_grad = src_sum_grad.add(&src_grad)?;
}
Op::ScatterAdd(init, indexes, src, dim) => {
let init_sum_grad = grads.or_insert(init)?;
let mask = init.ones_like()?;
let mask = mask.scatter(indexes, &mask.zeros_like()?, *dim)?;
*init_sum_grad = init_sum_grad.add(&grad.mul(&mask)?)?;
let src_grad = grad.gather(indexes, *dim)?;
let src_sum_grad = grads.or_insert(src)?;
*src_sum_grad = src_sum_grad.add(&src_grad)?;
}
Op::IndexAdd(init, indexes, src, dim) => {
let init_sum_grad = grads.or_insert(init)?;
*init_sum_grad = init_sum_grad.add(&grad)?;

View File

@ -7,7 +7,7 @@ use rayon::prelude::*;
mod utils;
pub use utils::{
binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2U8,
binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2InPlace, Map2U8,
};
const USE_IM2COL_CONV1D: bool = true;
@ -483,17 +483,22 @@ impl<I: IntDType> Map1 for Gather<'_, I> {
let start_dst_idx = start_dst_idx + i * dst_right_len;
for right_i in 0..dst_right_len {
let dst_idx = start_dst_idx + right_i;
let index = ids[dst_idx].as_usize();
if index >= src_dim_len {
Err(Error::InvalidIndex {
index,
size: src_dim_len,
op: "gather",
let index = ids[dst_idx];
if index == I::max_value() {
dst[dst_idx] = T::zero();
} else {
let index = index.as_usize();
if index >= src_dim_len {
Err(Error::InvalidIndex {
index,
size: src_dim_len,
op: "gather",
}
.bt())?
}
.bt())?
let src_idx = start_src_idx + index * src_right_len + right_i;
dst[dst_idx] = src[src_idx]
}
let src_idx = start_src_idx + index * src_right_len + right_i;
dst[dst_idx] = src[src_idx]
}
}
}
@ -535,45 +540,89 @@ impl<I: IntDType> Map1 for IndexSelect<'_, I> {
let start_src_idx = left_i * right_len * src_dim;
let start_dst_idx = left_i * right_len * n_ids;
for i in 0..n_ids {
let index = self.ids[self.ids_l.start_offset() + stride_ids * i].as_usize();
if index >= src_dim {
Err(Error::InvalidIndex {
index,
size: src_dim,
op: "index-select",
}
.bt())?
}
let start_src_idx = start_src_idx + index * right_len;
let start_dst_idx = start_dst_idx + i * right_len;
dst[start_dst_idx..start_dst_idx + right_len]
.copy_from_slice(&src[start_src_idx..start_src_idx + right_len])
let index = self.ids[self.ids_l.start_offset() + stride_ids * i];
if index == I::max_value() {
dst[start_dst_idx..start_dst_idx + right_len].fill(T::zero());
} else {
let index = index.as_usize();
if index >= src_dim {
Err(Error::InvalidIndex {
index,
size: src_dim,
op: "index-select",
}
.bt())?
}
let start_src_idx = start_src_idx + index * right_len;
dst[start_dst_idx..start_dst_idx + right_len]
.copy_from_slice(&src[start_src_idx..start_src_idx + right_len])
}
}
}
Ok(dst)
}
}
struct ScatterAdd<'a, I: IntDType> {
trait ElemUpdate {
fn f<T: WithDType>(dst: &mut T, src: T);
}
struct Set;
struct Add;
impl ElemUpdate for Set {
fn f<T: WithDType>(dst: &mut T, src: T) {
*dst = src
}
}
impl ElemUpdate for Add {
fn f<T: WithDType>(dst: &mut T, src: T) {
*dst += src
}
}
struct Scatter<'a, I: IntDType, M: ElemUpdate> {
ids: &'a [I],
ids_l: &'a Layout,
dim: usize,
_phantom: std::marker::PhantomData<M>,
}
impl<I: IntDType> Map2 for ScatterAdd<'_, I> {
const OP: &'static str = "scatter-add";
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
let dst_len = l1.shape().elem_count();
let mut dst = vec![T::zero(); dst_len];
copy_strided_src_(v1, &mut dst, 0, l1);
impl<'a, I: IntDType, M: ElemUpdate> Scatter<'a, I, M> {
fn new(ids: &'a [I], ids_l: &'a Layout, dim: usize) -> Self {
Self {
ids,
ids_l,
dim,
_phantom: Default::default(),
}
}
}
impl<I: IntDType, M: ElemUpdate> Map2InPlace for Scatter<'_, I, M> {
const OP: &'static str = "scatter";
fn f<T: WithDType>(
&self,
dst: &mut [T],
dst_l: &Layout,
src: &[T],
src_l: &Layout,
) -> Result<()> {
let dst = match dst_l.contiguous_offsets() {
None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
Some((o1, o2)) => &mut dst[o1..o2],
};
let src = match src_l.contiguous_offsets() {
None => Err(Error::RequiresContiguous { op: "scatter-add" }.bt())?,
None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
Some((o1, o2)) => &src[o1..o2],
};
let dim = self.dim;
let ids_dims = self.ids_l.dims();
let dst_dims = l1.dims();
let dst_dims = dst_l.dims();
let dst_dim_len = dst_dims[dim];
let dst_right_len: usize = dst_dims[dim + 1..].iter().product();
@ -592,7 +641,11 @@ impl<I: IntDType> Map2 for ScatterAdd<'_, I> {
let start_ids_idx = start_ids_idx + i * ids_right_len;
for right_i in 0..dst_right_len {
let ids_idx = start_ids_idx + right_i;
let index = ids[ids_idx].as_usize();
let index = ids[ids_idx];
if index == I::max_value() {
continue;
}
let index = index.as_usize();
if index >= dst_dim_len {
Err(Error::InvalidIndex {
index,
@ -602,12 +655,12 @@ impl<I: IntDType> Map2 for ScatterAdd<'_, I> {
.bt())?
}
let dst_idx = start_dst_idx + index * dst_right_len + right_i;
dst[dst_idx] += src[ids_idx]
M::f(&mut dst[dst_idx], src[ids_idx])
}
}
}
Ok(dst)
Ok(())
}
}
@ -635,6 +688,9 @@ impl<I: IntDType> Map2 for IndexAdd<'_, I> {
let post_dim = src_l.dims()[dim + 1..].iter().product::<usize>();
if dim == 0 {
for (src_idx, dst_idx) in self.ids.iter().enumerate() {
if *dst_idx == I::max_value() {
continue;
}
let dst_idx = dst_idx.as_usize();
if dst_idx >= max_idx {
Err(Error::InvalidIndex {
@ -653,6 +709,9 @@ impl<I: IntDType> Map2 for IndexAdd<'_, I> {
}
} else {
for (src_idx, dst_idx) in self.ids.iter().enumerate() {
if *dst_idx == I::max_value() {
continue;
}
let dst_idx = dst_idx.as_usize();
if dst_idx >= max_idx {
Err(Error::InvalidIndex {
@ -2381,19 +2440,36 @@ impl BackendStorage for CpuStorage {
}
}
fn scatter_add(
&self,
fn scatter_set(
&mut self,
l: &Layout,
ids: &Self,
ids_l: &Layout,
src: &Self,
src_l: &Layout,
dim: usize,
) -> Result<Self> {
) -> Result<()> {
match ids {
Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
Self::U8(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
Self::U32(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
Self::I64(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter").bt()),
}
}
fn scatter_add_set(
&mut self,
l: &Layout,
ids: &Self,
ids_l: &Layout,
src: &Self,
src_l: &Layout,
dim: usize,
) -> Result<()> {
match ids {
Self::U8(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
Self::U32(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
Self::I64(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()),
}
}
@ -2454,6 +2530,48 @@ impl BackendStorage for CpuStorage {
fn to_cpu_storage(&self) -> Result<CpuStorage> {
Ok(self.clone())
}
fn const_set(&mut self, s: crate::scalar::Scalar, l: &Layout) -> Result<()> {
use crate::scalar::Scalar;
fn set<T: crate::WithDType>(src: &mut [T], l: &Layout, s: T) {
match l.strided_blocks() {
crate::StridedBlocks::SingleBlock { start_offset, len } => {
src[start_offset..start_offset + len].fill(s)
}
crate::StridedBlocks::MultipleBlocks {
block_start_index,
block_len: 1,
} => {
for src_index in block_start_index {
src[src_index] = s
}
}
crate::StridedBlocks::MultipleBlocks {
block_start_index,
block_len,
} => {
for src_index in block_start_index {
src[src_index..src_index + block_len].fill(s)
}
}
}
}
match (self, s) {
(Self::BF16(storage), Scalar::BF16(v)) => set(storage, l, v),
(Self::F16(storage), Scalar::F16(v)) => set(storage, l, v),
(Self::F32(storage), Scalar::F32(v)) => set(storage, l, v),
(Self::F64(storage), Scalar::F64(v)) => set(storage, l, v),
(Self::U8(storage), Scalar::U8(v)) => set(storage, l, v),
(Self::U32(storage), Scalar::U32(v)) => set(storage, l, v),
(Self::I64(storage), Scalar::I64(v)) => set(storage, l, v),
(st, s) => crate::bail!(
"const_set dtype mismatch, expected {:?} but got {:?}",
st.dtype(),
s
),
}
Ok(())
}
}
impl BackendDevice for CpuDevice {
@ -2628,20 +2746,6 @@ impl BackendDevice for CpuDevice {
Ok(storage)
}
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
let elem_count = shape.elem_count();
let storage = match dtype {
DType::U8 => CpuStorage::U8(vec![1u8; elem_count]),
DType::U32 => CpuStorage::U32(vec![1u32; elem_count]),
DType::I64 => CpuStorage::I64(vec![1i64; elem_count]),
DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]),
DType::F16 => CpuStorage::F16(vec![f16::ONE; elem_count]),
DType::F32 => CpuStorage::F32(vec![1f32; elem_count]),
DType::F64 => CpuStorage::F64(vec![1f64; elem_count]),
};
Ok(storage)
}
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
let elem_count = shape.elem_count();
let storage = match dtype {

View File

@ -58,6 +58,30 @@ pub trait Map2 {
}
}
pub trait Map2InPlace {
const OP: &'static str;
fn f<T: WithDType>(&self, v1: &mut [T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<()>;
fn map(&self, v1: &mut C, l1: &Layout, v2: &C, l2: &Layout) -> Result<()> {
match (v1, v2) {
(C::U8(v1), C::U8(v2)) => self.f(v1, l1, v2, l2)?,
(C::U32(v1), C::U32(v2)) => self.f(v1, l1, v2, l2)?,
(C::I64(v1), C::I64(v2)) => self.f(v1, l1, v2, l2)?,
(C::BF16(v1), C::BF16(v2)) => self.f(v1, l1, v2, l2)?,
(C::F16(v1), C::F16(v2)) => self.f(v1, l1, v2, l2)?,
(C::F32(v1), C::F32(v2)) => self.f(v1, l1, v2, l2)?,
(C::F64(v1), C::F64(v2)) => self.f(v1, l1, v2, l2)?,
(v1, v2) => Err(Error::DTypeMismatchBinaryOp {
lhs: v1.dtype(),
rhs: v2.dtype(),
op: Self::OP,
}
.bt())?,
};
Ok(())
}
}
pub trait Map2U8 {
const OP: &'static str;
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;

View File

@ -2,7 +2,7 @@ use crate::backend::BackendDevice;
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
pub use candle_kernels as kernels;
pub use cudarc;
use cudarc::driver::{CudaFunction, LaunchConfig, PushKernelArg};
use cudarc::driver::CudaFunction;
use half::{bf16, f16};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
@ -188,100 +188,6 @@ impl CudaDevice {
self.id
}
fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
let slice = match dtype {
DType::U8 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u8>(elem_count)? };
let func = self.get_or_load_func("fill_u8", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as u8;
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::U8(data)
}
DType::U32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u32>(elem_count)? };
let func = self.get_or_load_func("fill_u32", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as u32;
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::U32(data)
}
DType::I64 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<i64>(elem_count)? };
let func = self.get_or_load_func("fill_i64", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as i64;
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<bf16>(elem_count)? };
let func = self.get_or_load_func("fill_bf16", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = bf16::from_f64(v);
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f16>(elem_count)? };
let func = self.get_or_load_func("fill_f16", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = f16::from_f64(v);
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f32>(elem_count)? };
let func = self.get_or_load_func("fill_f32", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as f32;
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f64>(elem_count) }?;
let func = self.get_or_load_func("fill_f64", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
pub fn get_or_load_custom_func(
&self,
fn_name: &str,
@ -504,10 +410,6 @@ impl BackendDevice for CudaDevice {
})
}
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
self.const_impl(1., shape, dtype)
}
unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
let elem_count = shape.elem_count();
let slice = match dtype {

View File

@ -2,7 +2,7 @@
//!
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{builder_arg as barg, CpuStorage, DType, Layout, Result, Shape, WithDType};
use crate::{builder_arg as barg, CpuStorage, DType, Layout, Result, WithDType};
pub use candle_kernels as kernels;
pub use cudarc;
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
@ -34,6 +34,21 @@ impl<T: DeviceRepr> SlicePtrOrNull<T> {
}
}
impl crate::scalar::Scalar {
pub fn builder_arg<'a, 'b: 'a>(&'b self, builder: &mut cudarc::driver::LaunchArgs<'a>) {
use crate::scalar::Scalar;
match self {
Scalar::U8(v) => builder.arg(v),
Scalar::U32(v) => builder.arg(v),
Scalar::I64(v) => builder.arg(v),
Scalar::F32(v) => builder.arg(v),
Scalar::F64(v) => builder.arg(v),
Scalar::F16(v) => builder.arg(v),
Scalar::BF16(v) => builder.arg(v),
};
}
}
impl SlicePtrOrNull<usize> {
pub fn params_from_layout(dev: &CudaDevice, l: &Layout) -> Result<Self> {
let ds = if l.is_contiguous() {
@ -395,7 +410,7 @@ impl Map1 for IndexSelect<'_> {
CudaStorageSlice::U8(slice) => ("is_u8", slice_ptr(slice, ids_l.start_offset())),
CudaStorageSlice::I64(slice) => ("is_i64", slice_ptr(slice, ids_l.start_offset())),
_ => Err(CudaError::UnexpectedDType {
msg: "index_select ids should be u8 or u32",
msg: "index_select ids should be u8, u32, or i64",
expected: DType::U32,
got: self.0.dtype(),
})
@ -492,7 +507,7 @@ impl Map2InPlace for IndexAdd<'_> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
dst: &mut CudaSlice<T>,
dst_shape: &Shape,
dst_l: &Layout,
src: &CudaSlice<T>,
src_l: &Layout,
dev: &CudaDevice,
@ -514,6 +529,10 @@ impl Map2InPlace for IndexAdd<'_> {
got: ids.dtype(),
})?,
};
let dst = match dst_l.contiguous_offsets() {
Some((o1, o2)) => dst.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
};
let src = match src_l.contiguous_offsets() {
Some((o1, o2)) => src.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
@ -521,7 +540,7 @@ impl Map2InPlace for IndexAdd<'_> {
let left_sz: usize = src_l.dims()[..dim].iter().product();
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
let src_dim_sz = src_l.dims()[dim];
let dst_dim_sz = dst_shape.dims()[dim];
let dst_dim_sz = dst_l.dims()[dim];
let ids_dim_sz = ids_l.dims()[0];
let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
@ -529,7 +548,59 @@ impl Map2InPlace for IndexAdd<'_> {
barg!(builder, ids);
barg!(builder, ids_dim_sz);
builder.arg(&src);
builder.arg(dst);
builder.arg(&dst);
barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz);
// SAFETY: ffi.
unsafe { builder.launch(cfg) }.w()?;
Ok(())
}
}
struct Scatter<'a>(&'a CudaStorage, &'a Layout, usize);
impl Map2InPlace for Scatter<'_> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
dst: &mut CudaSlice<T>,
dst_l: &Layout,
src: &CudaSlice<T>,
src_l: &Layout,
dev: &CudaDevice,
) -> Result<()> {
let ids = &self.0;
let ids_l = &self.1;
let dim = self.2;
let (ids_o1, _) = match ids_l.contiguous_offsets() {
Some(o12) => o12,
None => Err(crate::Error::RequiresContiguous { op: "scatter" }.bt())?,
};
let (name, (ids, _guard)) = match &ids.slice {
CudaStorageSlice::U32(slice) => ("s_u32", slice_ptr(slice, ids_o1)),
CudaStorageSlice::I64(slice) => ("s_i64", slice_ptr(slice, ids_o1)),
CudaStorageSlice::U8(slice) => ("s_u8", slice_ptr(slice, ids_o1)),
_ => Err(CudaError::UnexpectedDType {
msg: "scatter ids should be u8/u32/i64",
expected: DType::U32,
got: ids.dtype(),
})?,
};
let dst = match dst_l.contiguous_offsets() {
Some((o1, o2)) => dst.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "scatter" }.bt())?,
};
let src = match src_l.contiguous_offsets() {
Some((o1, o2)) => src.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "scatter" }.bt())?,
};
let left_sz: usize = src_l.dims()[..dim].iter().product();
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
let src_dim_sz = src_l.dims()[dim];
let dst_dim_sz = dst_l.dims()[dim];
let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
let mut builder = func.builder();
barg!(builder, ids);
builder.arg(&src);
builder.arg(&dst);
barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz);
// SAFETY: ffi.
unsafe { builder.launch(cfg) }.w()?;
@ -542,7 +613,7 @@ impl Map2InPlace for ScatterAdd<'_> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
dst: &mut CudaSlice<T>,
dst_shape: &Shape,
dst_l: &Layout,
src: &CudaSlice<T>,
src_l: &Layout,
dev: &CudaDevice,
@ -564,6 +635,10 @@ impl Map2InPlace for ScatterAdd<'_> {
got: ids.dtype(),
})?,
};
let dst = match dst_l.contiguous_offsets() {
Some((o1, o2)) => dst.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
};
let src = match src_l.contiguous_offsets() {
Some((o1, o2)) => src.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
@ -571,13 +646,13 @@ impl Map2InPlace for ScatterAdd<'_> {
let left_sz: usize = src_l.dims()[..dim].iter().product();
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
let src_dim_sz = src_l.dims()[dim];
let dst_dim_sz = dst_shape.dims()[dim];
let dst_dim_sz = dst_l.dims()[dim];
let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
let mut builder = func.builder();
barg!(builder, ids);
builder.arg(&src);
builder.arg(dst);
builder.arg(&dst);
barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz);
// SAFETY: ffi.
unsafe { builder.launch(cfg) }.w()?;
@ -1235,6 +1310,36 @@ impl BackendStorage for CudaStorage {
&self.device
}
fn const_set(&mut self, s: crate::scalar::Scalar, layout: &Layout) -> Result<()> {
let dev = &self.device;
let shape = layout.shape();
let dims = shape.dims();
let el_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el_count as u32);
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
let src_o = layout.start_offset();
let ((src, _guard_src), kernel_name) = match &mut self.slice {
S::U8(s) => (slice_ptr(s, src_o), "const_set_u8"),
S::U32(s) => (slice_ptr(s, src_o), "const_set_u32"),
S::I64(s) => (slice_ptr(s, src_o), "const_set_i64"),
S::BF16(s) => (slice_ptr(s, src_o), "const_set_bf16"),
S::F16(s) => (slice_ptr(s, src_o), "const_set_f16"),
S::F32(s) => (slice_ptr(s, src_o), "const_set_f32"),
S::F64(s) => (slice_ptr(s, src_o), "const_set_f64"),
};
let func = dev.get_or_load_func(kernel_name, &kernels::FILL)?;
let mut builder = func.builder();
barg!(builder, el_count);
barg!(builder, dims.len());
ds.builder_arg(&mut builder);
s.builder_arg(&mut builder);
barg!(builder, src);
// SAFETY: ffi.
unsafe { builder.launch(cfg) }.w()?;
Ok(())
}
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
let shape = layout.shape();
let dims = shape.dims();
@ -1793,20 +1898,29 @@ impl BackendStorage for CudaStorage {
let slice = Gather(ids, ids_l, dim).map(&self.slice, &device, l)?;
Ok(Self { slice, device })
}
fn scatter_add(
&self,
fn scatter_set(
&mut self,
l: &Layout,
ids: &Self,
ids_l: &Layout,
src: &Self,
src_l: &Layout,
dim: usize,
) -> Result<Self> {
) -> Result<()> {
let device = self.device().clone();
let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
self.copy_strided_src(&mut acc, 0, l)?;
ScatterAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
Ok(acc)
Scatter(ids, ids_l, dim).map(&mut self.slice, l, &src.slice, src_l, &device)
}
fn scatter_add_set(
&mut self,
l: &Layout,
ids: &Self,
ids_l: &Layout,
src: &Self,
src_l: &Layout,
dim: usize,
) -> Result<()> {
let device = self.device().clone();
ScatterAdd(ids, ids_l, dim).map(&mut self.slice, l, &src.slice, src_l, &device)
}
fn index_add(
&self,
@ -1820,7 +1934,7 @@ impl BackendStorage for CudaStorage {
let device = self.device().clone();
let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
self.copy_strided_src(&mut acc, 0, l)?;
IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l, &src.slice, src_l, &device)?;
Ok(acc)
}

View File

@ -1,5 +1,5 @@
/// Helper functions to plug cuda kernels in candle.
use crate::{Layout, Result, Shape, WithDType};
use crate::{Layout, Result, WithDType};
pub use cudarc;
use cudarc::driver::{CudaSlice, DeviceRepr, ValidAsZeroBits};
@ -96,7 +96,7 @@ pub trait Map2InPlace {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
dst: &mut CudaSlice<T>,
dst_shape: &Shape,
dst_l: &Layout,
src: &CudaSlice<T>,
src_l: &Layout,
dev: &CudaDevice,
@ -105,19 +105,19 @@ pub trait Map2InPlace {
fn map(
&self,
dst: &mut S,
dst_s: &Shape,
dst_l: &Layout,
src: &S,
src_l: &Layout,
d: &CudaDevice,
) -> Result<()> {
match (dst, src) {
(S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
(S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
(S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d),
(S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d),
(S::U8(dst), S::U8(src)) => self.f(dst, dst_l, src, src_l, d),
(S::U32(dst), S::U32(src)) => self.f(dst, dst_l, src, src_l, d),
(S::I64(dst), S::I64(src)) => self.f(dst, dst_l, src, src_l, d),
(S::BF16(dst), S::BF16(src)) => self.f(dst, dst_l, src, src_l, d),
(S::F16(dst), S::F16(src)) => self.f(dst, dst_l, src, src_l, d),
(S::F32(dst), S::F32(src)) => self.f(dst, dst_l, src, src_l, d),
(S::F64(dst), S::F64(src)) => self.f(dst, dst_l, src, src_l, d),
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
}
}

View File

@ -103,7 +103,63 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4:
}
}
impl<S: NdArray> NdArray for Vec<S> {
impl<S: WithDType> NdArray for Vec<S> {
fn shape(&self) -> Result<Shape> {
Ok(Shape::from(self.len()))
}
fn to_cpu_storage(&self) -> CpuStorage {
S::to_cpu_storage(self.as_slice())
}
}
impl<S: WithDType> NdArray for Vec<&[S]> {
fn shape(&self) -> Result<Shape> {
if self.is_empty() {
crate::bail!("empty array")
}
let n = self.len();
let m = self[0].len();
for v in self.iter() {
if v.len() != m {
crate::bail!("two elements have different len {m} {}", v.len())
}
}
Ok(Shape::from((n, m)))
}
fn to_cpu_storage(&self) -> CpuStorage {
let data = self.iter().copied().flatten().copied().collect::<Vec<_>>();
S::to_cpu_storage_owned(data)
}
}
impl<S: WithDType> NdArray for Vec<Vec<S>> {
fn shape(&self) -> Result<Shape> {
if self.is_empty() {
crate::bail!("empty array")
}
let n = self.len();
let m = self[0].len();
for v in self.iter() {
if v.len() != m {
crate::bail!("two elements have different len {m} {}", v.len())
}
}
Ok(Shape::from((n, m)))
}
fn to_cpu_storage(&self) -> CpuStorage {
let len: usize = self.iter().map(|v| v.len()).sum();
let mut dst = Vec::with_capacity(len);
for v in self.iter() {
dst.extend(v.iter().copied());
}
S::to_cpu_storage_owned(dst)
}
}
impl<S: WithDType> NdArray for Vec<Vec<Vec<S>>> {
fn shape(&self) -> Result<Shape> {
if self.is_empty() {
crate::bail!("empty array")
@ -120,9 +176,57 @@ impl<S: NdArray> NdArray for Vec<S> {
}
fn to_cpu_storage(&self) -> CpuStorage {
// This allocates intermediary memory and shouldn't be necessary.
let storages = self.iter().map(|v| v.to_cpu_storage()).collect::<Vec<_>>();
CpuStorage::concat(storages.as_slice()).unwrap()
if self.is_empty() {
return S::to_cpu_storage_owned(vec![]);
}
let len: usize = self
.iter()
.map(|v| v.iter().map(|v| v.len()).sum::<usize>())
.sum();
let mut dst = Vec::with_capacity(len);
for v1 in self.iter() {
for v2 in v1.iter() {
dst.extend(v2.iter().copied());
}
}
S::to_cpu_storage_owned(dst)
}
}
impl<S: WithDType> NdArray for Vec<Vec<Vec<Vec<S>>>> {
fn shape(&self) -> Result<Shape> {
if self.is_empty() {
crate::bail!("empty array")
}
let shape0 = self[0].shape()?;
let n = self.len();
for v in self.iter() {
let shape = v.shape()?;
if shape != shape0 {
crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
}
}
Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
}
fn to_cpu_storage(&self) -> CpuStorage {
let len: usize = self
.iter()
.map(|v| {
v.iter()
.map(|v| v.iter().map(|v| v.len()).sum::<usize>())
.sum::<usize>()
})
.sum();
let mut dst = Vec::with_capacity(len);
for v1 in self.iter() {
for v2 in v1.iter() {
for v3 in v2.iter() {
dst.extend(v3.iter().copied());
}
}
}
S::to_cpu_storage_owned(dst)
}
}
@ -292,23 +396,6 @@ impl Device {
self.rand_normal_f64(mean.to_f64(), std.to_f64(), shape, T::DTYPE)
}
pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
match self {
Device::Cpu => {
let storage = CpuDevice.ones_impl(shape, dtype)?;
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
let storage = device.ones_impl(shape, dtype)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = device.ones_impl(shape, dtype)?;
Ok(Storage::Metal(storage))
}
}
}
pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
match self {
Device::Cpu => {

View File

@ -107,6 +107,7 @@ pub trait WithDType:
fn from_f64(v: f64) -> Self;
fn to_f64(self) -> f64;
fn to_scalar(self) -> crate::scalar::Scalar;
fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_>;
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
@ -131,6 +132,10 @@ macro_rules! with_dtype {
$to_f64(self)
}
fn to_scalar(self) -> crate::scalar::Scalar {
crate::scalar::Scalar::$dtype(self)
}
fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_> {
CpuStorageRef::$dtype(data)
}
@ -175,7 +180,7 @@ with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64);
with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64);
with_dtype!(f64, F64, |v: f64| v, |v: f64| v);
pub trait IntDType: WithDType {
pub trait IntDType: WithDType + num_traits::Bounded {
fn is_true(&self) -> bool;
fn as_usize(&self) -> usize;
}

View File

@ -37,6 +37,10 @@ impl crate::backend::BackendStorage for CudaStorage {
fail!()
}
fn const_set(&mut self, _: crate::scalar::Scalar, _: &Layout) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
fn to_cpu_storage(&self) -> Result<CpuStorage> {
Err(Error::NotCompiledWithCudaSupport)
}
@ -124,15 +128,27 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
fn scatter_add(
&self,
fn scatter_set(
&mut self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
fn scatter_add_set(
&mut self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
@ -214,10 +230,6 @@ impl crate::backend::BackendDevice for CudaDevice {
Err(Error::NotCompiledWithCudaSupport)
}
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}

View File

@ -41,6 +41,10 @@ impl crate::backend::BackendStorage for MetalStorage {
fail!()
}
fn const_set(&mut self, _: crate::scalar::Scalar, _: &Layout) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}
fn to_cpu_storage(&self) -> Result<CpuStorage> {
Err(Error::NotCompiledWithMetalSupport)
}
@ -128,15 +132,27 @@ impl crate::backend::BackendStorage for MetalStorage {
Err(Error::NotCompiledWithMetalSupport)
}
fn scatter_add(
&self,
fn scatter_set(
&mut self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}
fn scatter_add_set(
&mut self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}
@ -218,10 +234,6 @@ impl crate::backend::BackendDevice for MetalDevice {
Err(Error::NotCompiledWithMetalSupport)
}
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}

View File

@ -226,8 +226,8 @@ where
/// assert_eq!(c.to_vec1::<f32>()?, &[1., 4.]);
///
/// let d = a.i((2.., ..))?;
/// assert_eq!(c.shape().dims(), &[2]);
/// assert_eq!(c.to_vec1::<f32>()?, &[1., 4.]);
/// assert_eq!(d.shape().dims(), &[1, 3]);
/// assert_eq!(d.to_vec2::<f32>()?, &[[6., 7., 8.]]);
/// # Ok::<(), candle_core::Error>(())
/// ```
fn i(&self, (a, b): (A, B)) -> Result<Tensor, Error> {

View File

@ -413,6 +413,100 @@ impl BackendStorage for MetalStorage {
self.binary(name, rhs, lhs_l, rhs_l)
}
fn const_set(&mut self, s: crate::scalar::Scalar, l: &Layout) -> Result<()> {
use crate::scalar::Scalar;
fn set<S: crate::WithDType + candle_metal_kernels::utils::EncoderParam>(
self_: &mut MetalStorage,
s: S,
l: &Layout,
) -> Result<()> {
let device = self_.device();
let dtype = self_.dtype;
let shape = l.shape();
let el_count = shape.elem_count();
let command_buffer = device.command_buffer()?;
command_buffer.set_label("const-set");
let dst = buffer_o(&self_.buffer, l, self_.dtype);
match (el_count % 2, dtype, l.is_contiguous()) {
(0, DType::BF16 | DType::F16, true) => {
use candle_metal_kernels::unary::contiguous_tiled;
let kernel_name = match dtype {
DType::F16 => contiguous_tiled::const_set::HALF,
DType::BF16 => contiguous_tiled::const_set::BFLOAT,
_ => crate::bail!("internal bug in const_set"),
};
candle_metal_kernels::call_const_set_contiguous_tiled(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
el_count,
s,
dst,
)
.map_err(MetalError::from)?;
}
(_, _, true) => {
use candle_metal_kernels::unary::contiguous;
let kernel_name = match dtype {
DType::F16 => contiguous::const_set::HALF,
DType::BF16 => contiguous::const_set::BFLOAT,
DType::F32 => contiguous::const_set::FLOAT,
DType::I64 => contiguous::const_set::I64,
DType::U32 => contiguous::const_set::U32,
DType::U8 => contiguous::const_set::U8,
DType::F64 => crate::bail!("unsupported const-set f64"),
};
candle_metal_kernels::call_const_set_contiguous(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
el_count,
s,
dst,
)
.map_err(MetalError::from)?;
}
(_, _, false) => {
use candle_metal_kernels::unary::strided;
let kernel_name = match dtype {
DType::F16 => strided::const_set::HALF,
DType::BF16 => strided::const_set::BFLOAT,
DType::F32 => strided::const_set::FLOAT,
DType::I64 => strided::const_set::I64,
DType::U32 => strided::const_set::U32,
DType::U8 => strided::const_set::U8,
DType::F64 => crate::bail!("unsupported const-set f64"),
};
candle_metal_kernels::call_const_set_strided(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
l.dims(),
s,
l.stride(),
dst,
)
.map_err(MetalError::from)?;
}
}
Ok(())
}
match (self.dtype, s) {
(DType::U8, Scalar::U8(s)) => set(self, s, l),
(DType::U32, Scalar::U32(s)) => set(self, s, l),
(DType::I64, Scalar::I64(s)) => set(self, s, l),
(DType::F16, Scalar::F16(s)) => set(self, s, l),
(DType::BF16, Scalar::BF16(s)) => set(self, s, l),
(DType::F32, Scalar::F32(s)) => set(self, s, l),
(DType::F64, Scalar::F64(s)) => set(self, s, l),
_ => crate::bail!("dtype mismatch, expected {:?}, got {:?}", self.dtype, s),
}
}
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
let device = self.device();
let shape = layout.shape();
@ -1332,18 +1426,65 @@ impl BackendStorage for MetalStorage {
Ok(Self::new(buffer, device.clone(), dst_el, dtype))
}
fn scatter_add(
&self,
fn scatter_set(
&mut self,
l: &Layout,
ids: &Self,
ids_l: &Layout,
src: &Self,
src_l: &Layout,
dim: usize,
) -> Result<Self> {
let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
self.copy_strided_src(&mut acc, 0, l)?;
if !ids_l.is_contiguous() || !src_l.is_contiguous() {
) -> Result<()> {
if !l.is_contiguous() || !ids_l.is_contiguous() || !src_l.is_contiguous() {
return Err(crate::Error::RequiresContiguous { op: "scatter" }.bt());
};
let name = match (ids.dtype, self.dtype) {
(DType::U8, DType::F32) => "s_u8_f32",
(DType::U8, DType::F16) => "s_u8_f16",
(DType::U8, DType::BF16) => "s_u8_bf16",
(DType::U32, DType::U32) => "s_u32_u32",
(DType::U32, DType::F32) => "s_u32_f32",
(DType::U32, DType::F16) => "s_u32_f16",
(DType::U32, DType::BF16) => "s_u32_bf16",
(DType::I64, DType::F32) => "s_i64_f32",
(DType::I64, DType::F16) => "s_i64_f16",
(DType::I64, DType::BF16) => "s_i64_bf16",
_ => Err(MetalError::UnexpectedDType {
msg: "scatter ids should be u8/u32/i64",
expected: DType::U32,
got: ids.dtype(),
})?,
};
let command_buffer = self.device.command_buffer()?;
let dst = buffer_o(&self.buffer, l, self.dtype);
let src = buffer_o(&src.buffer, src_l, src.dtype);
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
candle_metal_kernels::call_scatter(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
src_l.dims(),
l.dims(),
dim,
src,
ids,
dst,
)
.map_err(MetalError::from)?;
Ok(())
}
fn scatter_add_set(
&mut self,
l: &Layout,
ids: &Self,
ids_l: &Layout,
src: &Self,
src_l: &Layout,
dim: usize,
) -> Result<()> {
if !l.is_contiguous() || !ids_l.is_contiguous() || !src_l.is_contiguous() {
return Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt());
};
let name = match (ids.dtype, self.dtype) {
@ -1364,9 +1505,10 @@ impl BackendStorage for MetalStorage {
})?,
};
let command_buffer = self.device.command_buffer()?;
let dst = buffer_o(&self.buffer, l, self.dtype);
let src = buffer_o(&src.buffer, src_l, src.dtype);
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
candle_metal_kernels::call_scatter_add(
candle_metal_kernels::call_scatter(
&self.device.device,
&command_buffer,
&self.device.kernels,
@ -1376,10 +1518,10 @@ impl BackendStorage for MetalStorage {
dim,
src,
ids,
&acc.buffer,
dst,
)
.map_err(MetalError::from)?;
Ok(acc)
Ok(())
}
fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
@ -1513,50 +1655,32 @@ impl BackendStorage for MetalStorage {
let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?;
let command_buffer = self.device.command_buffer()?;
command_buffer.set_label("matmul");
if self.dtype == DType::BF16 {
candle_metal_kernels::call_mlx_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
candle_metal_kernels::GemmDType::BF16,
(b, m, n, k),
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
} else {
let dtype = match self.dtype {
DType::F32 => candle_metal_kernels::GemmDType::F32,
DType::F16 => candle_metal_kernels::GemmDType::F16,
DType::BF16 => candle_metal_kernels::GemmDType::BF16,
dtype => {
return Err(MetalError::Message(format!(
"mlx matmul doesn't support {dtype:?}"
))
.into())
}
};
candle_metal_kernels::call_mlx_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
dtype,
(b, m, n, k),
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
}
let dtype = match self.dtype {
DType::F32 => candle_metal_kernels::GemmDType::F32,
DType::F16 => candle_metal_kernels::GemmDType::F16,
DType::BF16 => candle_metal_kernels::GemmDType::BF16,
dtype => {
return Err(
MetalError::Message(format!("mlx matmul doesn't support {dtype:?}")).into(),
)
}
};
candle_metal_kernels::call_mlx_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
dtype,
(b, m, n, k),
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::new(
buffer,
self.device.clone(),
@ -1965,40 +2089,6 @@ impl BackendDevice for MetalDevice {
))
}
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
let name = match dtype {
DType::U8 => "fill_u8",
DType::U32 => "fill_u32",
DType::I64 => "fill_i64",
DType::F16 => "fill_f16",
DType::BF16 => "fill_bf16",
DType::F32 => "fill_f32",
DType::F64 => {
let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
return self.storage_from_cpu_storage(&cpu_storage);
}
};
let buffer = self.new_buffer(shape.elem_count(), dtype, "alloc-ones")?;
let command_buffer = self.command_buffer()?;
candle_metal_kernels::call_const_fill(
&self.device,
&command_buffer,
&self.kernels,
name,
shape.elem_count(),
&buffer,
1.,
)
.map_err(MetalError::from)?;
Ok(MetalStorage::new(
buffer,
self.clone(),
shape.elem_count(),
dtype,
))
}
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
let (count, buffer) = match T::cpu_storage_ref(s) {
CpuStorageRef::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),

View File

@ -80,6 +80,7 @@ pub enum Op {
Reduce(Tensor, ReduceOp, Vec<usize>),
Matmul(Tensor, Tensor),
Gather(Tensor, Tensor, usize),
Scatter(Tensor, Tensor, Tensor, usize),
ScatterAdd(Tensor, Tensor, Tensor, usize),
IndexSelect(Tensor, Tensor, usize),
IndexAdd(Tensor, Tensor, Tensor, usize),

View File

@ -1,6 +1,74 @@
//! TensorScalar Enum and Trait
//!
use crate::{Result, Tensor, WithDType};
use crate::{DType, Result, Tensor, WithDType};
use half::{bf16, f16};
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Scalar {
U8(u8),
U32(u32),
I64(i64),
BF16(bf16),
F16(f16),
F32(f32),
F64(f64),
}
impl<T: WithDType> From<T> for Scalar {
fn from(value: T) -> Self {
value.to_scalar()
}
}
impl Scalar {
pub fn zero(dtype: DType) -> Self {
match dtype {
DType::U8 => Scalar::U8(0),
DType::U32 => Scalar::U32(0),
DType::I64 => Scalar::I64(0),
DType::BF16 => Scalar::BF16(bf16::ZERO),
DType::F16 => Scalar::F16(f16::ZERO),
DType::F32 => Scalar::F32(0.0),
DType::F64 => Scalar::F64(0.0),
}
}
pub fn one(dtype: DType) -> Self {
match dtype {
DType::U8 => Scalar::U8(1),
DType::U32 => Scalar::U32(1),
DType::I64 => Scalar::I64(1),
DType::BF16 => Scalar::BF16(bf16::ONE),
DType::F16 => Scalar::F16(f16::ONE),
DType::F32 => Scalar::F32(1.0),
DType::F64 => Scalar::F64(1.0),
}
}
pub fn dtype(&self) -> DType {
match self {
Scalar::U8(_) => DType::U8,
Scalar::U32(_) => DType::U32,
Scalar::I64(_) => DType::I64,
Scalar::BF16(_) => DType::BF16,
Scalar::F16(_) => DType::F16,
Scalar::F32(_) => DType::F32,
Scalar::F64(_) => DType::F64,
}
}
pub fn to_f64(&self) -> f64 {
match self {
Scalar::U8(v) => *v as f64,
Scalar::U32(v) => *v as f64,
Scalar::I64(v) => *v as f64,
Scalar::BF16(v) => v.to_f64(),
Scalar::F16(v) => v.to_f64(),
Scalar::F32(v) => *v as f64,
Scalar::F64(v) => *v,
}
}
}
pub enum TensorScalar {
Tensor(Tensor),

View File

@ -1,5 +1,6 @@
use crate::backend::BackendStorage;
use crate::op::{self, CmpOp, ReduceOp};
use crate::scalar::Scalar;
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
use crate::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
@ -73,6 +74,14 @@ impl Storage {
}
}
pub(crate) fn const_set(&mut self, v: Scalar, l: &Layout) -> Result<()> {
match self {
Storage::Cpu(storage) => storage.const_set(v, l),
Storage::Cuda(storage) => storage.const_set(v, l),
Storage::Metal(storage) => storage.const_set(v, l),
}
}
pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
@ -619,32 +628,56 @@ impl Storage {
}
}
pub(crate) fn scatter_add(
&self,
pub(crate) fn scatter_set(
&mut self,
l: &Layout,
indexes: &Self,
indexes_l: &Layout,
source: &Self,
source_l: &Layout,
d: usize,
) -> Result<Self> {
) -> Result<()> {
self.same_device(indexes, "scatter-set")?;
self.same_device(source, "scatter-set")?;
match (self, indexes, source) {
(Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
}
(Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
}
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
}
_ => unreachable!(),
}
Ok(())
}
pub(crate) fn scatter_add(
&mut self,
l: &Layout,
indexes: &Self,
indexes_l: &Layout,
source: &Self,
source_l: &Layout,
d: usize,
) -> Result<()> {
self.same_device(indexes, "scatter-add")?;
self.same_device(source, "scatter-add")?;
match (self, indexes, source) {
(Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Cpu(storage))
s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
}
(Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Cuda(storage))
s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
}
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Metal(storage))
s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
}
_ => unreachable!(),
}
Ok(())
}
pub(crate) fn index_add(

View File

@ -3,7 +3,7 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BackpropOp, BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
use crate::scalar::TensorOrScalar;
use crate::shape::{Dim, Dims};
use crate::shape::{Dim, Dims, ShapeWithOneHole};
use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::{Arc, RwLock};
@ -185,7 +185,9 @@ impl Tensor {
) -> Result<Self> {
let none = BackpropOp::none();
let shape = shape.into();
let storage = device.ones(&shape, dtype)?;
let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
let layout = Layout::contiguous(shape.clone());
storage.const_set(crate::scalar::Scalar::one(dtype), &layout)?;
Ok(from_storage(storage, shape, none, is_variable))
}
@ -202,6 +204,18 @@ impl Tensor {
Self::ones_impl(shape, dtype, device, false)
}
pub fn const_set(&self, value: crate::scalar::Scalar) -> Result<()> {
self.storage_mut().const_set(value, self.layout())
}
pub fn zero_set(&self) -> Result<()> {
self.const_set(crate::scalar::Scalar::zero(self.dtype()))
}
pub fn one_set(&self) -> Result<()> {
self.const_set(crate::scalar::Scalar::one(self.dtype()))
}
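// A minimal usage sketch of the in-place fill helpers above, assuming a u32 CPU tensor;
// `42u32.into()` goes through the `Scalar: From<T>` impl:
//     let t = Tensor::zeros((2, 3), DType::U32, &Device::Cpu)?;
//     t.const_set(42u32.into())?; // every element becomes 42
//     t.zero_set()?;              // reset all elements to 0
//     t.one_set()?;               // fill all elements with 1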
/// Creates a new tensor filled with ones with the same shape, dtype, and device as the other tensor.
///
/// ```rust
@ -368,8 +382,7 @@ impl Tensor {
Self::new_impl(array, shape, device, false)
}
/// Returns a new tensor with all the elements having the same specified value. Note that
/// the tensor is not contiguous so you would have to call `.contiguous()` on it if needed.
/// Returns a new tensor with all the elements having the same specified value.
///```rust
/// use candle_core::{Tensor, Device};
/// let a = Tensor::full(3.5, (2, 4), &Device::Cpu)?;
@ -384,7 +397,12 @@ impl Tensor {
shape: S,
device: &Device,
) -> Result<Self> {
Self::from_vec_impl(vec![value], (), device, false)?.broadcast_as(shape)
let none = BackpropOp::none();
let shape = shape.into();
let mut storage = unsafe { device.alloc_uninit(&shape, D::DTYPE)? };
let layout = Layout::contiguous(shape.clone());
storage.const_set(value.to_scalar(), &layout)?;
Ok(from_storage(storage, shape, none, false))
}
/// Creates a new 1D tensor from an iterator.
@ -452,17 +470,13 @@ impl Tensor {
Self::from_vec_impl(data, len, device, false)
}
pub(crate) fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>(
pub(crate) fn from_vec_impl<S: ShapeWithOneHole, D: crate::WithDType>(
data: Vec<D>,
shape: S,
device: &Device,
is_variable: bool,
) -> Result<Self> {
let shape = shape.into();
let buffer_size = data.len();
if buffer_size != shape.elem_count() {
return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
}
let shape = shape.into_shape(data.len())?;
let storage = device.storage_owned(data)?;
let none = BackpropOp::none();
Ok(from_storage(storage, shape, none, is_variable))
@ -481,7 +495,7 @@ impl Tensor {
/// ]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn from_vec<S: Into<Shape>, D: crate::WithDType>(
pub fn from_vec<S: ShapeWithOneHole, D: crate::WithDType>(
data: Vec<D>,
shape: S,
device: &Device,
@ -502,17 +516,12 @@ impl Tensor {
/// ]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
pub fn from_slice<S: ShapeWithOneHole, D: crate::WithDType>(
array: &[D],
shape: S,
device: &Device,
) -> Result<Self> {
let shape = shape.into();
let n: usize = shape.elem_count();
let buffer_size: usize = array.len();
if buffer_size != n {
return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
}
let shape = shape.into_shape(array.len())?;
let storage = device.storage_from_slice(array)?;
let none = BackpropOp::none();
Ok(from_storage(storage, shape, none, false))
@ -1226,6 +1235,83 @@ impl Tensor {
Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
}
/// Computes the dot product of two 1D tensors.
///
/// - If inputs are 1D vectors (`[n]`), returns their scalar dot product.
/// - Returns an error if the shapes are not compatible.
/// - Not supported for integer dtypes
///
/// # Example (vectors)
/// ```rust
/// use candle_core::{Tensor, Device};
/// let t1 = Tensor::new(&[1.0, 2.0, 3.0], &Device::Cpu)?;
/// let t2 = Tensor::new(&[4.0, 5.0, 6.0], &Device::Cpu)?;
/// let res = t1.dot(&t2)?;
/// assert_eq!(res.to_scalar::<f64>()?, 32.);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn dot(&self, rhs: &Self) -> Result<Self> {
if self.dims().len() != 1 || rhs.dims().len() != 1 {
return Err(Error::ShapeMismatchBinaryOp {
lhs: self.shape().clone(),
rhs: rhs.shape().clone(),
op: "dot",
});
}
(self * rhs).and_then(|ret| ret.sum_all())
}
/// Computes the **Frobenius norm** (L2 norm of all elements) of the tensor.
/// - Output is `sqrt(sum(x^2))`.
/// - Always returns a scalar (`[]` shape).
///
/// # Example
/// ```rust
/// use candle_core::{Tensor, Device};
/// let t = Tensor::new(&[[3., 4.], [0., 0.]], &Device::Cpu)?;
/// let norm = t.norm()?;
/// assert_eq!(norm.to_scalar::<f64>()?, 5.);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn norm(&self) -> Result<Self> {
if self.dtype().is_int() {
bail!("norm not supported for integer dtypes");
}
self.sqr().and_then(|x| x.sum_all()).and_then(|x| x.sqrt())
}
/// Performs strict matrix-vector multiplication (`[m, n] * [n] = [m]`).
///
/// - If `self` is a matrix (`[m, n]`) and `rhs` is a vector (`[n]`), returns a vector (`[m]`).
/// - **No broadcasting**: returns an error if `self` is not 2D or if `rhs` is not 1D with matching size.
///
/// # Example
/// ```rust
/// use candle_core::{Tensor, Device};
/// let mat = Tensor::new(&[[1., 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
/// let vec = Tensor::new(&[1., 1., 1.], &Device::Cpu)?;
/// let res = mat.mv(&vec)?;
/// assert_eq!(res.to_vec1::<f64>()?, [6., 15.]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn mv(&self, rhs: &Self) -> Result<Self> {
// Strict shape checks
let lhs_dims = self.dims();
let rhs_dims = rhs.dims();
if lhs_dims.len() != 2 || rhs_dims.len() != 1 || lhs_dims[1] != rhs_dims[0] {
return Err(Error::ShapeMismatchBinaryOp {
lhs: self.shape().clone(),
rhs: rhs.shape().clone(),
op: "mv",
});
}
// Direct matmul after ensuring rhs is column vector
self.matmul(&rhs.unsqueeze(1)?)?.squeeze(1)
}
/// Returns the matrix-multiplication of the input tensor with the other provided tensor.
///
/// # Arguments
@ -1349,8 +1435,7 @@ impl Tensor {
self.index_select(ids, 0)
}
pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "scatter-add")?;
fn scatter_checks(&self, indexes: &Self, source: &Self, dim: usize) -> Result<()> {
let source_dims = source.dims();
let self_dims = self.dims();
let mismatch = if source_dims.len() != self_dims.len() {
@ -1367,7 +1452,7 @@ impl Tensor {
};
if mismatch {
Err(Error::ShapeMismatchBinaryOp {
op: "scatter-add (self, src)",
op: "scatter (self, src)",
lhs: self.shape().clone(),
rhs: source.shape().clone(),
}
@ -1375,13 +1460,44 @@ impl Tensor {
}
if indexes.dims() != source.dims() {
Err(Error::ShapeMismatchBinaryOp {
op: "scatter-add (indexes, src)",
op: "scatter (indexes, src)",
lhs: indexes.shape().clone(),
rhs: source.shape().clone(),
}
.bt())?
}
let storage = self.storage().scatter_add(
Ok(())
}
pub fn scatter<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "scatter")?;
self.scatter_checks(indexes, source, dim)?;
let shape = self.shape();
let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
let layout = Layout::contiguous(shape);
storage.scatter_set(
&layout,
&indexes.storage(),
indexes.layout(),
&source.storage(),
source.layout(),
dim,
)?;
let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
Op::Scatter(t1, t2, t3, dim)
});
Ok(from_storage(storage, self.shape(), op, false))
}
pub fn scatter_set<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> {
if self.same_storage(source) {
crate::bail!("cannot use slice_set when self and src share their storage")
}
let dim = dim.to_index(self.shape(), "scatter-set")?;
self.scatter_checks(indexes, source, dim)?;
self.storage_mut().scatter_set(
self.layout(),
&indexes.storage(),
indexes.layout(),
@ -1389,12 +1505,48 @@ impl Tensor {
source.layout(),
dim,
)?;
Ok(())
}
pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "scatter-add")?;
self.scatter_checks(indexes, source, dim)?;
let shape = self.shape();
let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
let layout = Layout::contiguous(shape);
storage.scatter_add(
&layout,
&indexes.storage(),
indexes.layout(),
&source.storage(),
source.layout(),
dim,
)?;
let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
Op::ScatterAdd(t1, t2, t3, dim)
});
Ok(from_storage(storage, self.shape(), op, false))
}
pub fn scatter_add_set<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> {
if self.same_storage(source) {
crate::bail!("cannot use slice_set when self and src share their storage")
}
let dim = dim.to_index(self.shape(), "scatter-add-set")?;
self.scatter_checks(indexes, source, dim)?;
self.storage_mut().scatter_add(
self.layout(),
&indexes.storage(),
indexes.layout(),
&source.storage(),
source.layout(),
dim,
)?;
Ok(())
}
/// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension.
pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: D, start: usize) -> Result<Self> {
let dim = dim.to_index(self.shape(), "slice-scatter")?;
@ -2197,7 +2349,7 @@ impl Tensor {
///
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn reshape<S: crate::shape::ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
pub fn reshape<S: ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
let shape = s.into_shape(self.elem_count())?;
if shape.elem_count() != self.elem_count() {
return Err(Error::ShapeMismatchBinaryOp {

View File

@ -241,7 +241,7 @@ impl Tensor {
/// `self` and `src` must have the same shape except on dimension `dim` where the `self` size
/// has to be greater than or equal to `offset` plus the `src` size.
///
/// Note that this modifies `self` in place and as such is not compatibel with
/// Note that this modifies `self` in place and as such is not compatible with
/// back-propagation.
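/// A minimal sketch (assuming a CPU device and f32 tensors):
/// ```rust
/// use candle_core::{Tensor, Device, DType};
/// let dst = Tensor::zeros((4, 3), DType::F32, &Device::Cpu)?;
/// let src = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
/// dst.slice_set(&src, 0, 1)?; // rows 1 and 2 of `dst` are now ones
/// assert_eq!(dst.to_vec2::<f32>()?[1], [1.0, 1.0, 1.0]);
/// # Ok::<(), candle_core::Error>(())
/// ```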
pub fn slice_set<D: Dim>(&self, src: &Self, dim: D, offset: usize) -> Result<()> {
let dim = dim.to_index(self.shape(), "slice-set")?;

View File

@ -82,6 +82,26 @@ fn broadcast_matmul(device: &Device) -> Result<()> {
Ok(())
}
#[test]
fn tensor_dot() -> Result<()> {
let lhs = Tensor::new(&[1., 2., 3.], &Device::Cpu)?;
let rhs = Tensor::new(&[4., 5., 6.], &Device::Cpu)?;
let expected = Tensor::new(32., &Device::Cpu)?;
let dot_ret = lhs.dot(&rhs)?;
candle_core::test_utils::assert_tensor_eq(&dot_ret, &expected)?;
Ok(())
}
#[test]
fn tensor_mv() -> Result<()> {
let mat = Tensor::new(&[[1., 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
let vec = Tensor::new(&[1., 1., 1.], &Device::Cpu)?;
let expected = Tensor::new(&[6., 15.], &Device::Cpu)?;
let mv_ret = mat.mv(&vec)?;
candle_core::test_utils::assert_tensor_eq(&mv_ret, &expected)?;
Ok(())
}
// https://github.com/huggingface/candle/issues/1948
fn squeeze_mm(device: &Device) -> Result<()> {
let seq_len = 8_usize;

View File

@ -25,10 +25,12 @@ fn ones(device: &Device) -> Result<()> {
Tensor::ones((2, 3), DType::F32, device)?.to_vec2::<f32>()?,
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
);
assert_eq!(
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
);
if !device.is_metal() {
assert_eq!(
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
);
}
assert_eq!(
Tensor::ones((2, 3), DType::F16, device)?.to_vec2::<half::f16>()?,
[
@ -63,6 +65,26 @@ fn ones(device: &Device) -> Result<()> {
}
fn full(device: &Device) -> Result<()> {
let tensor = Tensor::zeros((3, 4), DType::U32, device)?;
tensor.const_set(42u32.into())?;
assert_eq!(
tensor.to_vec2::<u32>()?,
[[42, 42, 42, 42], [42, 42, 42, 42], [42, 42, 42, 42]]
);
tensor.i((.., 2))?.const_set(1337u32.into())?;
assert_eq!(
tensor.to_vec2::<u32>()?,
[[42, 42, 1337, 42], [42, 42, 1337, 42], [42, 42, 1337, 42]]
);
tensor.i((2, ..))?.const_set(1u32.into())?;
assert_eq!(
tensor.to_vec2::<u32>()?,
[[42, 42, 1337, 42], [42, 42, 1337, 42], [1, 1, 1, 1]]
);
Ok(())
}
fn const_set(device: &Device) -> Result<()> {
assert_eq!(
Tensor::full(42u32, (2, 3), device)?.to_vec2::<u32>()?,
[[42, 42, 42], [42, 42, 42]],
@ -823,9 +845,37 @@ fn embeddings(device: &Device) -> Result<()> {
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
let hs = t.index_select(&ids.to_dtype(DType::I64)?, 0)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
let ids = Tensor::new(&[u32::MAX, 2u32, u32::MAX], device)?;
let hs = t.index_select(&ids, 0)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 0.0], [4.0, 5.0], [0.0, 0.0]]);
Ok(())
}
#[test]
fn index_select_fail() -> Result<()> {
// Check that an error is properly reported on out of bounds.
let ids = Tensor::new(&[4u32, 2u32, 1u32], &Device::Cpu)?;
let t = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], &Device::Cpu)?;
let hs = t.index_select(&ids, 0);
assert!(hs.is_err());
Ok(())
}
// The test below triggers an unwinding panic as there is a panic within the cuda implementation itself, so it is kept commented out.
// #[cfg(feature = "cuda")]
// #[test]
// #[should_panic]
// fn index_select_fail_gpu() {
// // Check that a panic happens for out of bounds in cuda
// if let Ok(device) = Device::new_cuda(0) {
// if let Ok(ids) = Tensor::new(&[4u32, 2u32, 1u32], &device) {
// if let Ok(t) = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], &device) {
// let _ = t.index_select(&ids, 0);
// }
// }
// }
// }
fn cmp(device: &Device) -> Result<()> {
let t1 = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], device)?;
let t2 = Tensor::new(&[[1f32, 0f32], [3f32, 3f32], [4f32, 7f32]], device)?;
@ -980,7 +1030,7 @@ fn slice_scatter(device: &Device) -> Result<()> {
Ok(())
}
fn scatter_add(device: &Device) -> Result<()> {
fn scatter(device: &Device) -> Result<()> {
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
assert_eq!(
t.to_vec2::<f32>()?,
@ -1004,6 +1054,17 @@ fn scatter_add(device: &Device) -> Result<()> {
]
);
let hs = init.scatter(&ids, &t, 1)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[0.0, 1.0, 2.0, 1.0, 1.0],
[5.0, 1.0, 1.0, 3.0, 4.0],
[1.0, 8.0, 1.0, 7.0, 1.0],
[10.0, 1.0, 9.0, 1.0, 11.0]
]
);
let init = Tensor::ones((6, 3), DType::F32, device)?;
let hs = init.scatter_add(&ids, &t, 0)?;
assert_eq!(
@ -1017,6 +1078,56 @@ fn scatter_add(device: &Device) -> Result<()> {
[1.0, 1.0, 1.0]
]
);
let hs = init.scatter(&ids, &t, 0)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[0.0, 10.0, 5.0],
[1.0, 1.0, 8.0],
[9.0, 1.0, 2.0],
[6.0, 7.0, 1.0],
[1.0, 4.0, 11.0],
[1.0, 1.0, 1.0]
]
);
let hs = {
let ids = Tensor::new(
&[
[0u32, u32::MAX, 2],
[3, 4, u32::MAX],
[3, 3, 1],
[u32::MAX, u32::MAX, 4],
],
device,
)?;
init.scatter(&ids, &t, 0)?
};
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[0.0, 1.0, 1.0],
[1.0, 1.0, 8.0],
[1.0, 1.0, 2.0],
[6.0, 7.0, 1.0],
[1.0, 4.0, 11.0],
[1.0, 1.0, 1.0]
]
);
init.scatter_set(&ids, &t, 0)?;
assert_eq!(
init.to_vec2::<f32>()?,
&[
[0.0, 10.0, 5.0],
[1.0, 1.0, 8.0],
[9.0, 1.0, 2.0],
[6.0, 7.0, 1.0],
[1.0, 4.0, 11.0],
[1.0, 1.0, 1.0]
]
);
Ok(())
}
@ -1050,6 +1161,23 @@ fn gather(device: &Device) -> Result<()> {
let hs = t.gather(&ids, 0)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 7.0, 2.0], [0.0, 4.0, 5.0]]);
let hs = {
let ids = Tensor::new(
&[
[0u32, 0u32],
[2u32, u32::MAX],
[u32::MAX, 1u32],
[0u32, 2u32],
],
device,
)?;
t.gather(&ids, 1)?
};
assert_eq!(
hs.to_vec2::<f32>()?,
&[[0.0, 0.0], [5.0, 0.0], [0.0, 7.0], [9.0, 11.0]]
);
// Random data
// Dim: 0
@ -1484,6 +1612,7 @@ fn zero_dim(device: &Device) -> Result<()> {
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
test_device!(full, full_cpu, full_gpu, full_metal);
test_device!(const_set, cs_cpu, cs_gpu, cs_metal);
test_device!(arange, arange_cpu, arange_gpu, arange_metal);
test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
@ -1515,12 +1644,7 @@ test_device!(
);
test_device!(index_add, index_add_cpu, index_add_gpu, index_add_metal);
test_device!(gather, gather_cpu, gather_gpu, gather_metal);
test_device!(
scatter_add,
scatter_add_cpu,
scatter_add_gpu,
scatter_add_metal
);
test_device!(scatter, scatter_cpu, scatter_gpu, scatter_metal);
test_device!(
slice_scatter,
slice_scatter_cpu,
@ -1733,3 +1857,34 @@ fn test_flip_3d_channels() -> Result<()> {
candle_core::test_utils::assert_tensor_eq(&flipped, &expected)?;
Ok(())
}
#[test]
fn tensor_new() -> Result<()> {
let t1 = Tensor::new(vec![1f32, 2.0, 3.0], &Device::Cpu)?;
assert_eq!(t1.to_vec1::<f32>()?, [1.0, 2.0, 3.0]);
let t2 = Tensor::new(vec![vec![1f32, 2., 3.], vec![4., 5., 6.]], &Device::Cpu)?;
assert_eq!(t2.to_vec2::<f32>()?, [[1., 2., 3.], [4., 5., 6.]]);
let t3 = Tensor::new(
vec![
vec![vec![1f32, 2., 3.], vec![4., 5., 6.]],
vec![vec![3f32, 1., 4.], vec![1., 5., 9.]],
],
&Device::Cpu,
)?;
assert_eq!(
t3.to_vec3::<f32>()?,
[
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
[[3.0, 1.0, 4.0], [1.0, 5.0, 9.0]]
]
);
Ok(())
}
#[test]
fn tensor_norm() -> Result<()> {
let t = Tensor::new(&[[3., 4.], [0., 0.]], &Device::Cpu)?;
let norm = t.norm()?;
assert_eq!(norm.to_scalar::<f64>()?, 5.);
Ok(())
}

View File

@ -16,10 +16,9 @@ fn read_u32<T: Read>(reader: &mut T) -> std::io::Result<u32> {
fn check_magic_number<T: Read>(reader: &mut T, expected: u32) -> Result<()> {
let magic_number = read_u32(reader)?;
if magic_number != expected {
Err(io::Error::new(
io::ErrorKind::Other,
format!("incorrect magic number {magic_number} != {expected}"),
))?;
Err(io::Error::other(format!(
"incorrect magic number {magic_number} != {expected}"
)))?;
}
Ok(())
}

View File

@ -84,6 +84,10 @@ required-features = ["pyo3"]
name = "onnx"
required-features = ["onnx"]
[[example]]
name = "onnx-llm"
required-features = ["onnx"]
[[example]]
name = "onnx_basics"
required-features = ["onnx"]

View File

@ -20,8 +20,8 @@ use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::{Encoding, PaddingParams, Tokenizer};
enum TaskType {
Ner(DebertaV2NERModel),
TextClassification(DebertaV2SeqClassificationModel),
Ner(Box<DebertaV2NERModel>),
TextClassification(Box<DebertaV2SeqClassificationModel>),
}
#[derive(Parser, Debug, Clone, ValueEnum)]
@ -169,21 +169,16 @@ impl Args {
match self.task {
ArgsTask::Ner => Ok((
TaskType::Ner(DebertaV2NERModel::load(
vb,
&config,
Some(id2label.clone()),
)?),
TaskType::Ner(DebertaV2NERModel::load(vb, &config, Some(id2label.clone()))?.into()),
config,
tokenizer,
id2label,
)),
ArgsTask::TextClassification => Ok((
TaskType::TextClassification(DebertaV2SeqClassificationModel::load(
vb,
&config,
Some(id2label.clone()),
)?),
TaskType::TextClassification(
DebertaV2SeqClassificationModel::load(vb, &config, Some(id2label.clone()))?
.into(),
),
config,
tokenizer,
id2label,

View File

@ -16,8 +16,8 @@ use std::path::PathBuf;
use tokenizers::Tokenizer;
enum ModelType {
Masked(DistilBertForMaskedLM),
UnMasked(DistilBertModel),
Masked(Box<DistilBertForMaskedLM>),
UnMasked(Box<DistilBertModel>),
}
impl ModelType {
@ -144,10 +144,12 @@ impl Args {
fn create_model(&self, config: &Config, vb: VarBuilder) -> Result<ModelType> {
match self.model {
Which::DistilbertForMaskedLM => {
Ok(ModelType::Masked(DistilBertForMaskedLM::load(vb, config)?))
}
Which::DistilBert => Ok(ModelType::UnMasked(DistilBertModel::load(vb, config)?)),
Which::DistilbertForMaskedLM => Ok(ModelType::Masked(
DistilBertForMaskedLM::load(vb, config)?.into(),
)),
Which::DistilBert => Ok(ModelType::UnMasked(
DistilBertModel::load(vb, config)?.into(),
)),
}
}
}

View File

@ -124,6 +124,17 @@ impl TextGeneration {
Some(token) => token,
None => anyhow::bail!("cannot find the <eos> token"),
};
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
Some(token) => token,
None => {
println!(
"Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup"
);
eos_token
}
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
@ -146,7 +157,7 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
if next_token == eos_token || next_token == eot_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
@ -350,6 +361,31 @@ fn main() -> Result<()> {
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
let prompt = match args.which {
Which::Base2B
| Which::Base7B
| Which::Instruct2B
| Which::Instruct7B
| Which::InstructV1_1_2B
| Which::InstructV1_1_7B
| Which::CodeBase2B
| Which::CodeBase7B
| Which::CodeInstruct2B
| Which::CodeInstruct7B
| Which::BaseV2_2B
| Which::InstructV2_2B
| Which::BaseV2_9B
| Which::InstructV2_9B
| Which::BaseV3_1B => args.prompt,
Which::InstructV3_1B => {
format!(
"<start_of_turn> user\n{}<end_of_turn>\n<start_of_turn> model\n",
args.prompt
)
}
};
pipeline.run(&prompt, args.sample_len)?;
Ok(())
}

View File

@ -7,7 +7,10 @@ extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::helium::{Config, Model};
use candle_transformers::models::helium::{Config as ConfigPreview, Model as ModelPreview};
use candle_transformers::models::llama::{
Cache as CacheV1, Llama as ModelV1, LlamaConfig as ConfigV1, LlamaEosToks,
};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
@ -16,6 +19,44 @@ use candle_transformers::generation::{LogitsProcessor, Sampling};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
#[derive(Debug, Clone)]
enum Model {
V1 { model: ModelV1, cache: CacheV1 },
Preview(ModelPreview),
}
impl Model {
fn forward(&mut self, input: &Tensor, start_pos: usize) -> Result<Tensor> {
let model = match self {
Model::V1 { model, cache } => model.forward(input, start_pos, cache)?,
Model::Preview(m) => m.forward(input, start_pos)?,
};
Ok(model)
}
}
#[derive(Debug, Clone)]
enum Config {
V1(ConfigV1),
Preview(ConfigPreview),
}
impl Config {
fn bos_token_id(&self) -> Option<u32> {
match self {
Config::V1(c) => c.bos_token_id,
Config::Preview(c) => Some(c.bos_token_id),
}
}
fn eos_token_id(&self) -> Option<LlamaEosToks> {
match self {
Config::V1(c) => c.eos_token_id.clone(),
Config::Preview(c) => Some(LlamaEosToks::Single(c.eos_token_id)),
}
}
}
struct TextGeneration {
model: Model,
device: Device,
@ -106,7 +147,15 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == self.config.bos_token_id || next_token == self.config.eos_token_id {
let is_eos = self
.config
.eos_token_id()
.as_ref()
.is_some_and(|v| match v {
LlamaEosToks::Single(eos) => *eos == next_token,
LlamaEosToks::Multiple(eos) => eos.contains(&next_token),
});
if Some(next_token) == self.config.bos_token_id() || is_eos {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
@ -131,6 +180,8 @@ impl TextGeneration {
enum Which {
#[value(name = "v1-preview")]
V1Preview,
#[value(name = "v1")]
V1,
}
#[derive(Parser, Debug)]
@ -144,9 +195,6 @@ struct Args {
#[arg(long)]
tracing: bool,
#[arg(long)]
use_flash_attn: bool,
#[arg(long)]
prompt: String,
@ -171,7 +219,7 @@ struct Args {
sample_len: usize,
/// The model size to use.
#[arg(long, default_value = "v1-preview")]
#[arg(long, default_value = "v1")]
which: Which,
#[arg(long)]
@ -230,6 +278,7 @@ fn main() -> Result<()> {
None => {
let name = match args.which {
Which::V1Preview => "kyutai/helium-1-preview-2b",
Which::V1 => "kyutai/helium-1-2b",
};
name.to_string()
}
@ -254,18 +303,27 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config: Config = match args.config {
Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?,
None => {
let config_file = repo.get("config.json")?;
serde_json::from_slice(&std::fs::read(config_file)?)?
}
let config_file = match args.config {
Some(config_file) => std::path::PathBuf::from(config_file),
None => repo.get("config.json")?,
};
let config = match args.which {
Which::V1Preview => Config::Preview(serde_json::from_slice(&std::fs::read(config_file)?)?),
Which::V1 => Config::V1(serde_json::from_slice(&std::fs::read(config_file)?)?),
};
let device = candle_examples::device(args.cpu)?;
let (model, device) = {
let dtype = device.bf16_default_to_f32();
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
let model = match &config {
Config::V1(c) => {
let c = c.clone().into_config(false);
let model = ModelV1::load(vb, &c)?;
let cache = CacheV1::new(true, dtype, &c, &device)?;
Model::V1 { model, cache }
}
Config::Preview(c) => Model::Preview(ModelPreview::new(c, vb)?),
};
(model, device)
};

View File

@ -3,7 +3,7 @@
OLMo is a series of Open Language Models designed to enable the science of language models.
- **Project Page:** https://allenai.org/olmo
- **Paper:** [Link](https://arxiv.org/abs/2402.00838)
- **Papers:** [OLMo](https://arxiv.org/abs/2402.00838), [OLMo 2](https://arxiv.org/abs/2501.00656)
- **Technical blog post:** https://blog.allenai.org/olmo-open-language-model-87ccfc95f580
- **W&B Logs:** https://wandb.ai/ai2-llm/OLMo-1B/reports/OLMo-1B--Vmlldzo2NzY1Njk1
<!-- - **Press release:** TODO -->

View File

@ -8,6 +8,7 @@ use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};
use candle_transformers::models::olmo::{Config, Model as OLMo};
use candle_transformers::models::olmo2::{Config as Config2, Model as OLMo2};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
@ -18,6 +19,7 @@ use tokenizers::Tokenizer;
enum Model {
OLMo(OLMo),
OLMo2(OLMo2),
}
struct TextGeneration {
@ -82,6 +84,7 @@ impl TextGeneration {
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = match &mut self.model {
Model::OLMo(m) => m.forward(&input, start_pos)?,
Model::OLMo2(m) => m.forward(&input, start_pos)?,
};
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
@ -129,6 +132,8 @@ enum Which {
W7bTwin2T,
#[value(name = "1.7-7b")]
V1_7W7b,
#[value(name = "2-1b")]
V2W1b,
}
#[derive(Parser, Debug)]
@ -220,6 +225,7 @@ fn main() -> Result<()> {
Which::W7b => "allenai/OLMo-7B-hf".to_string(),
Which::W7bTwin2T => "allenai/OLMo-7B-Twin-2T-hf".to_string(),
Which::V1_7W7b => "allenai/OLMo-1.7-7B-hf".to_string(),
Which::V2W1b => "allenai/OLMo-2-0425-1B-Instruct".to_string(),
},
};
@ -238,33 +244,36 @@ fn main() -> Result<()> {
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => match args.model {
Which::W1b => {
Which::W1b | Which::V2W1b => {
vec![repo.get("model.safetensors")?]
}
_ => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
},
};
let config_filename = repo.get("config.json")?;
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = {
let config_filename = repo.get("config.json")?;
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
config
};
let device = candle_examples::device(args.cpu)?;
let model = {
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = OLMo::new(&config, vb)?;
Model::OLMo(model)
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = match args.model {
Which::W1b | Which::W7b | Which::W7bTwin2T | Which::V1_7W7b => {
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let model = OLMo::new(&config, vb)?;
Model::OLMo(model)
}
Which::V2W1b => {
let config: Config2 = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let model = OLMo2::new(&config, vb)?;
Model::OLMo2(model)
}
};
println!("loaded the model in {:?}", start.elapsed());

View File

@ -0,0 +1,11 @@
## Using ONNX models in Candle
This example demonstrates how to run [ONNX](https://github.com/onnx/onnx)-based LLM models in Candle.
It currently only implements SmolLM-135M.
You can run the example with the following command:
```bash
cargo run --example onnx-llm --features onnx
```
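A few generation flags from the example's `Args` struct (`--prompt`, `--max-tokens`, `--temperature`, `--top-p`, `--top-k`, `--seed`, `--cpu`) can be passed after `--`; for instance, with an arbitrary prompt:
```bash
cargo run --example onnx-llm --features onnx -- --prompt "The Riemann hypothesis states that" --max-tokens 50
```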

View File

@ -0,0 +1,209 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Result;
use candle::{DType, Tensor};
use candle_transformers::generation::{LogitsProcessor, Sampling};
use clap::{Parser, ValueEnum};
use hf_hub::api::sync::Api;
use serde::Deserialize;
use std::io::Write;
use tokenizers::Tokenizer;
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
pub num_hidden_layers: usize,
pub num_key_value_heads: usize,
pub hidden_size: usize,
pub num_attention_heads: usize,
}
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
SmolLM135M,
}
#[derive(Parser)]
struct Args {
/// The prompt to be used.
#[arg(long, default_value = "My favorite theorem is ")]
prompt: String,
/// The model to be used.
#[arg(value_enum, long, default_value_t = Which::SmolLM135M)]
which: Which,
/// Run on CPU rather than GPU.
#[arg(long)]
cpu: bool,
/// The number of tokens to generate.
#[arg(long, default_value_t = 100)]
max_tokens: usize,
/// The temperature used for sampling.
#[arg(long, default_value_t = 0.8)]
temperature: f32,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// Only sample among the top K samples.
#[arg(long)]
top_k: Option<usize>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
}
pub fn main() -> Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let (model_id, tokenizer_id) = match args.which {
Which::SmolLM135M => ("HuggingFaceTB/SmolLM-135M", "HuggingFaceTB/SmolLM-135M"),
};
let api = Api::new()?;
let model_repo = api.model(model_id.to_string());
let tokenizer_repo = api.model(tokenizer_id.to_string());
let model_path = model_repo.get("onnx/model.onnx")?;
let config_file = model_repo.get("config.json")?;
let config: Config = serde_json::from_reader(std::fs::File::open(config_file)?)?;
let tokenizer_path = tokenizer_repo.get("tokenizer.json")?;
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)?;
let tokens_u32 = tokenizer
.encode(args.prompt.as_str(), true)
.map_err(anyhow::Error::msg)?
.get_ids()
.to_vec();
let tokens: Vec<i64> = tokens_u32.iter().map(|&t| t as i64).collect();
println!("Loading ONNX model from {:?}", model_path);
let model = candle_onnx::read_file(model_path)?;
let mut generated_tokens = tokens.clone();
print!("{}", args.prompt);
std::io::stdout().flush()?;
let mut logits_processor = {
let temperature = args.temperature as f64;
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (args.top_k, args.top_p) {
(None, None) => Sampling::All { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(args.seed, sampling)
};
let mut past_key_values: Option<Vec<(Tensor, Tensor)>> = None;
let num_layers = config.num_hidden_layers;
for _ in 0..args.max_tokens {
let mut inputs = std::collections::HashMap::new();
if let Some(past_kv) = &past_key_values {
let last_token = vec![generated_tokens[generated_tokens.len() - 1]];
let input_tensor = Tensor::new(last_token, &device)?.unsqueeze(0)?;
inputs.insert("input_ids".to_string(), input_tensor);
let seq_len = generated_tokens.len();
let attention_mask = vec![vec![1i64; seq_len]];
let attention_mask_tensor = Tensor::new(attention_mask, &device)?;
inputs.insert("attention_mask".to_string(), attention_mask_tensor);
let position_ids = vec![vec![(seq_len - 1) as i64]];
let position_ids_tensor = Tensor::new(position_ids, &device)?;
inputs.insert("position_ids".to_string(), position_ids_tensor);
for (i, (key, value)) in past_kv.iter().enumerate() {
inputs.insert(format!("past_key_values.{}.key", i), key.clone());
inputs.insert(format!("past_key_values.{}.value", i), value.clone());
}
} else {
let input_tensor = Tensor::new(generated_tokens.clone(), &device)?.unsqueeze(0)?;
inputs.insert("input_ids".to_string(), input_tensor);
let seq_len = generated_tokens.len();
let attention_mask = vec![vec![1i64; seq_len]];
let attention_mask_tensor = Tensor::new(attention_mask, &device)?;
inputs.insert("attention_mask".to_string(), attention_mask_tensor);
let position_ids: Vec<i64> = (0..seq_len as i64).collect();
let position_ids_tensor = Tensor::new(position_ids, &device)?.unsqueeze(0)?;
inputs.insert("position_ids".to_string(), position_ids_tensor);
// Create empty key and value tensors
for i in 0..num_layers {
let batch_size = 1;
let num_heads = config.num_key_value_heads;
let head_dim = config.hidden_size / config.num_attention_heads;
let seq_len = 0;
let empty_key = Tensor::zeros(
&[batch_size, num_heads, seq_len, head_dim],
DType::F32,
&device,
)?;
let empty_value = Tensor::zeros(
&[batch_size, num_heads, seq_len, head_dim],
DType::F32,
&device,
)?;
inputs.insert(format!("past_key_values.{}.key", i), empty_key);
inputs.insert(format!("past_key_values.{}.value", i), empty_value);
}
}
let outputs = candle_onnx::simple_eval(&model, inputs)?;
let logits = outputs.get("logits").unwrap();
let mut new_past_kv = Vec::with_capacity(num_layers);
for i in 0..num_layers {
let key = outputs
.get(&format!("present.{}.key", i))
.ok_or_else(|| anyhow::anyhow!("Missing present.{}.key", i))?;
let value = outputs
.get(&format!("present.{}.value", i))
.ok_or_else(|| anyhow::anyhow!("Missing present.{}.value", i))?;
new_past_kv.push((key.clone(), value.clone()));
}
past_key_values = Some(new_past_kv);
let logits_dim = logits.dims();
let seq_len = logits_dim[1];
let next_token_id = logits_processor.sample(&logits.get(0)?.get(seq_len - 1)?)?;
generated_tokens.push(next_token_id as i64);
if let Some(token_str) = tokenizer.decode(&[next_token_id], true).ok() {
print!("{}", token_str);
std::io::stdout().flush()?;
}
if let Some(eos_id) = tokenizer.token_to_id("<|endoftext|>") {
if next_token_id == eos_id {
break;
}
}
}
println!("\nGeneration complete!");
Ok(())
}

View File

@ -5,12 +5,14 @@ extern crate intel_mkl_src;
extern crate accelerate_src;
use candle::{IndexOp, D};
use candle_examples::save_image;
use clap::{Parser, ValueEnum};
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
SqueezeNet,
EfficientNet,
EsrGan,
}
#[derive(Parser)]
@ -28,10 +30,21 @@ struct Args {
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let image = candle_examples::imagenet::load_image224(args.image)?;
let image = match args.which {
Which::SqueezeNet | Which::EfficientNet => {
candle_examples::imagenet::load_image224(&args.image)?
}
Which::EsrGan => candle_examples::imagenet::load_image_with_std_mean(
&args.image,
128,
&[0.0f32, 0.0, 0.0],
&[1.0f32, 1.0, 1.0],
)?,
};
let image = match args.which {
Which::SqueezeNet => image,
Which::EfficientNet => image.permute((1, 2, 0))?,
Which::EsrGan => image,
};
println!("loaded image {image:?}");
@ -45,6 +58,9 @@ pub fn main() -> anyhow::Result<()> {
Which::EfficientNet => hf_hub::api::sync::Api::new()?
.model("onnx/EfficientNet-Lite4".into())
.get("efficientnet-lite4-11.onnx")?,
Which::EsrGan => hf_hub::api::sync::Api::new()?
.model("qualcomm/Real-ESRGAN-x4plus".into())
.get("Real-ESRGAN-x4plus.onnx")?,
},
};
@ -57,21 +73,40 @@ pub fn main() -> anyhow::Result<()> {
let prs = match args.which {
Which::SqueezeNet => candle_nn::ops::softmax(&output, D::Minus1)?,
Which::EfficientNet => output,
Which::EsrGan => output,
};
let prs = prs.i(0)?.to_vec1::<f32>()?;
// Sort the predictions and take the top 5
let mut top: Vec<_> = prs.iter().enumerate().collect();
top.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
let top = top.into_iter().take(5).collect::<Vec<_>>();
match args.which {
Which::EfficientNet | Which::SqueezeNet => {
let prs = prs.i(0)?.to_vec1::<f32>()?;
// Print the top predictions
for &(i, p) in &top {
println!(
"{:50}: {:.2}%",
candle_examples::imagenet::CLASSES[i],
p * 100.0
);
// Sort the predictions and take the top 5
let mut top: Vec<_> = prs.iter().enumerate().collect();
top.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
let top = top.into_iter().take(5).collect::<Vec<_>>();
// Print the top predictions
for &(i, p) in &top {
println!(
"{:50}: {:.2}%",
candle_examples::imagenet::CLASSES[i],
p * 100.0
);
}
}
Which::EsrGan => {
let max_pixel_val = candle::Tensor::try_from(255.0f32)?
.to_device(prs.device())?
.broadcast_as(prs.shape())?;
let out = (prs * max_pixel_val)?.i(0)?.to_dtype(candle::DType::U8)?;
let pb = std::path::PathBuf::from(args.image);
let input_file_name = pb.file_name().unwrap();
let mut output_file_name = std::ffi::OsString::from("super_");
output_file_name.push(input_file_name);
save_image(&out, output_file_name)?;
}
}
Ok(())

View File

@ -147,9 +147,9 @@ enum WhichModel {
V3,
#[value(name = "3-medium")]
V3Medium,
#[value(name = "2-old")]
V4Mini,
#[value(name = "4-mini")]
V4Mini,
#[value(name = "2-old")]
V2Old,
PuffinPhiV2,
PhiHermes,

View File

@ -0,0 +1,18 @@
# candle-quantized-gemma
Candle implementation of quantized Gemma.
## Running an example
```bash
$ cargo run --example quantized-gemma -- --prompt "Write a function to calculate fibonacci numbers. "
> ```python
> def fibonacci(n):
> """Calculates the nth Fibonacci number using recursion."""
> if n <= 1:
> return n
> else:
> return fibonacci(n-1) + fibonacci(n-2
> ```
```
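The example also supports an interactive loop: passing `interactive` or `chat` as the `--prompt` value switches to multi-turn input (see the `Prompt` handling in the example source), and the usual sampling flags (`--temperature`, `--sample-len`, `--which`, `--cpu`) are available. A possible invocation, assuming the default `gemma3-4b-it` weights:

```bash
cargo run --example quantized-gemma --release -- --prompt chat --temperature 0.7
```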

View File

@ -0,0 +1,344 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use std::io::Write;
use tokenizers::Tokenizer;
use candle::quantized::gguf_file;
use candle::Tensor;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_transformers::models::quantized_gemma3::ModelWeights;
const DEFAULT_PROMPT: &str = "Write a function to calculate fibonacci num";
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
#[value(name = "gemma3-4b-it")]
Gemma3_4bIt,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// GGUF file to load, typically a .gguf file generated by quantization
#[arg(long)]
model: Option<String>,
/// The initial prompt, use 'interactive' for entering multiple prompts in an interactive way
/// and 'chat' for an interactive mode where the history of previous prompts and generated tokens
/// is preserved.
#[arg(long)]
prompt: Option<String>,
/// The length of the sample to generate (in tokens).
#[arg(short = 'n', long, default_value_t = 1000)]
sample_len: usize,
/// The tokenizer config in json format.
#[arg(long)]
tokenizer: Option<String>,
/// The temperature used to generate samples, use 0 for greedy sampling.
#[arg(long, default_value_t = 0.8)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// Only sample among the top K samples.
#[arg(long)]
top_k: Option<usize>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Process prompt elements separately.
#[arg(long)]
split_prompt: bool,
/// Run on CPU rather than GPU even if a GPU is available.
#[arg(long)]
cpu: bool,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
/// The model size to use.
#[arg(long, default_value = "gemma3-4b-it")]
which: Which,
}
impl Args {
fn tokenizer(&self) -> anyhow::Result<Tokenizer> {
let tokenizer_path = match &self.tokenizer {
Some(config) => std::path::PathBuf::from(config),
None => {
let api = hf_hub::api::sync::Api::new()?;
let repo = "google/gemma-3-4b-it";
println!("DEBUG: Downloading tokenizer from {}", repo);
let api = api.model(repo.to_string());
api.get("tokenizer.json")?
}
};
println!("DEBUG: Loading tokenizer from {:?}", tokenizer_path);
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)?;
Ok(tokenizer)
}
fn model(&self) -> anyhow::Result<std::path::PathBuf> {
let model_path = match &self.model {
Some(config) => std::path::PathBuf::from(config),
None => {
let (repo, filename) = match self.which {
Which::Gemma3_4bIt => (
"google/gemma-3-4b-it-qat-q4_0-gguf",
"gemma-3-4b-it-q4_0.gguf",
),
};
let api = hf_hub::api::sync::Api::new()?;
api.repo(hf_hub::Repo::with_revision(
repo.to_string(),
hf_hub::RepoType::Model,
"main".to_string(),
))
.get(filename)?
}
};
Ok(model_path)
}
}
fn format_size(size_in_bytes: usize) -> String {
if size_in_bytes < 1_000 {
format!("{}B", size_in_bytes)
} else if size_in_bytes < 1_000_000 {
format!("{:.2}KB", size_in_bytes as f64 / 1e3)
} else if size_in_bytes < 1_000_000_000 {
format!("{:.2}MB", size_in_bytes as f64 / 1e6)
} else {
format!("{:.2}GB", size_in_bytes as f64 / 1e9)
}
}
#[derive(Debug)]
enum Prompt {
Interactive,
Chat,
One(String),
}
fn main() -> anyhow::Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature, args.repeat_penalty, args.repeat_last_n
);
let model_path = args.model()?;
let mut file = std::fs::File::open(&model_path)?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let mut model = {
let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(&model_path))?;
let mut total_size_in_bytes = 0;
for (_, tensor) in model.tensor_infos.iter() {
let elem_count = tensor.shape.elem_count();
total_size_in_bytes +=
elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
}
println!(
"loaded {:?} tensors ({}) in {:.2}s",
model.tensor_infos.len(),
&format_size(total_size_in_bytes),
start.elapsed().as_secs_f32(),
);
ModelWeights::from_gguf(model, &mut file, &device)?
};
println!("model built");
let tokenizer = args.tokenizer()?;
let mut tos = TokenOutputStream::new(tokenizer);
println!(
"DEBUG: Tokenizer vocabulary size: {}",
tos.tokenizer().get_vocab(true).len()
);
let prompt = match args.prompt.as_deref() {
Some("chat") => Prompt::Chat,
Some("interactive") => Prompt::Interactive,
Some(s) => Prompt::One(s.to_string()),
None => Prompt::One(DEFAULT_PROMPT.to_string()),
};
let mut pre_prompt_tokens = vec![];
for _ in 0.. {
let prompt_str = match &prompt {
Prompt::One(prompt) => prompt.clone(),
Prompt::Interactive | Prompt::Chat => {
print!("> ");
std::io::stdout().flush()?;
let mut prompt = String::new();
std::io::stdin().read_line(&mut prompt)?;
if prompt.ends_with('\n') {
prompt.pop();
if prompt.ends_with('\r') {
prompt.pop();
}
}
// Format the prompt with the Gemma 3 chat/instruction template
format!("<start_of_turn> user\n{prompt}<end_of_turn>\n<start_of_turn> model\n")
}
};
print!("{}", &prompt_str);
let tokens = tos
.tokenizer()
.encode(prompt_str, true)
.map_err(anyhow::Error::msg)?;
let prompt_tokens = [&pre_prompt_tokens, tokens.get_ids()].concat();
let to_sample = args.sample_len.saturating_sub(1);
let max_seq_len = 8192; // Gemma 3 context length
let prompt_tokens = if prompt_tokens.len() + to_sample > max_seq_len - 10 {
let to_remove = prompt_tokens.len() + to_sample + 10 - max_seq_len;
prompt_tokens[prompt_tokens.len().saturating_sub(to_remove)..].to_vec()
} else {
prompt_tokens
};
let mut all_tokens = vec![];
let mut logits_processor = {
let temperature = args.temperature;
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (args.top_k, args.top_p) {
(None, None) => Sampling::All { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(args.seed, sampling)
};
let start_prompt_processing = std::time::Instant::now();
let mut next_token = if !args.split_prompt {
let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
let logits = model.forward(&input, 0)?;
let logits = logits.squeeze(0)?;
logits_processor.sample(&logits)?
} else {
let mut next_token = 0;
for (pos, token) in prompt_tokens.iter().enumerate() {
let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
let logits = model.forward(&input, pos)?;
let logits = logits.squeeze(0)?;
next_token = logits_processor.sample(&logits)?
}
next_token
};
let prompt_dt = start_prompt_processing.elapsed();
all_tokens.push(next_token);
if let Some(t) = tos.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
// For Gemma 3, use the correct end of sequence token
let eos_token = *tos
.tokenizer()
.get_vocab(true)
.get("<end_of_turn>")
.unwrap();
let start_post_prompt = std::time::Instant::now();
let mut sampled = 0;
for index in 0..to_sample {
let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
let logits = model.forward(&input, prompt_tokens.len() + index)?;
let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
logits
} else {
let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
args.repeat_penalty,
&all_tokens[start_at..],
)?
};
next_token = logits_processor.sample(&logits)?;
all_tokens.push(next_token);
if let Some(t) = tos.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
sampled += 1;
if next_token == eos_token {
break;
};
}
if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
let dt = start_post_prompt.elapsed();
println!(
"\n\n{:4} prompt tokens processed: {:.2} token/s",
prompt_tokens.len(),
prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
);
println!(
"{sampled:4} tokens generated: {:.2} token/s",
sampled as f64 / dt.as_secs_f64(),
);
match prompt {
Prompt::One(_) => break,
Prompt::Interactive => {}
Prompt::Chat => {
pre_prompt_tokens = [prompt_tokens.as_slice(), all_tokens.as_slice()].concat()
}
}
}
Ok(())
}

View File

@ -8,4 +8,8 @@
cargo run --example quantized-qwen2-instruct --release -- --prompt "Write a function to count prime numbers up to N."
```
0.5b, 1.5b, 7b and 72b models are available via `--model` argument.
0.5b, 1.5b, 7b and 72b models are available via the `--which` argument.
```bash
cargo run --release --example quantized-qwen2-instruct -- --which 0.5b --prompt "Write a function to count prime numbers up to N."
```

View File

@ -0,0 +1,17 @@
# candle-quantized-qwen3
[Qwen3](https://qwenlm.github.io/blog/qwen3/) is an upgraded version of Qwen2.5, released by Alibaba Cloud.
## Running the example
```bash
cargo run --example quantized-qwen3 --release -- --prompt "Write a function to count prime numbers up to N."
```
0.6b is used by default; 1.7b, 4b, 8b, 14b, and 32b models are available via the `--which` argument.
```bash
cargo run --example quantized-qwen3 --release -- --which 4b --prompt "A train is travelling at 120mph, how far does it travel in 3 minutes 30 seconds?"
```
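Greedy decoding and the other sampling knobs from the example's `Args` struct (`--temperature`, `--top-k`, `--top-p`, `--sample-len`, `--split-prompt`, `--seed`, `--cpu`) can be combined with `--which`; for instance, a sketch of a greedy run might look like:
```bash
cargo run --example quantized-qwen3 --release -- --which 1.7b --temperature 0 --sample-len 500 --prompt "Explain the borrow checker in Rust."
```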

View File

@ -0,0 +1,314 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use std::io::Write;
use tokenizers::Tokenizer;
use candle::quantized::gguf_file;
use candle::Tensor;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_transformers::models::quantized_qwen3::ModelWeights as Qwen3;
const DEFAULT_PROMPT: &str = "Write a Rust function to calculate the factorial of a given number.";
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
#[value(name = "0.6b")]
W3_0_6b,
#[value(name = "1.7b")]
W3_1_7b,
#[value(name = "4b")]
W3_4b,
#[value(name = "8b")]
W3_8b,
#[value(name = "14b")]
W3_14b,
#[value(name = "32b")]
W3_32b,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// GGUF file to load, typically a .gguf file generated by the quantize command from llama.cpp
#[arg(long)]
model: Option<String>,
/// The initial prompt, use 'interactive' for entering multiple prompts in an interactive way
/// and 'chat' for an interactive mode where the history of previous prompts and generated tokens
/// is preserved.
#[arg(long)]
prompt: Option<String>,
/// The length of the sample to generate (in tokens).
#[arg(short = 'n', long, default_value_t = 1000)]
sample_len: usize,
/// The tokenizer config in json format.
#[arg(long)]
tokenizer: Option<String>,
/// The temperature used to generate samples, use 0 for greedy sampling.
#[arg(long, default_value_t = 0.8)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// Only sample among the top K samples.
#[arg(long)]
top_k: Option<usize>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Process prompt elements separately.
#[arg(long)]
split_prompt: bool,
/// Run on CPU rather than GPU even if a GPU is available.
#[arg(long)]
cpu: bool,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
/// The model size to use.
#[arg(long, default_value = "0.6b")]
which: Which,
}
impl Args {
fn tokenizer(&self) -> anyhow::Result<Tokenizer> {
let tokenizer_path = match &self.tokenizer {
Some(config) => std::path::PathBuf::from(config),
None => {
let api = hf_hub::api::sync::Api::new()?;
let repo = match self.which {
Which::W3_0_6b => "Qwen/Qwen3-0.6B",
Which::W3_1_7b => "Qwen/Qwen3-1.7B",
Which::W3_4b => "Qwen/Qwen3-4B",
Which::W3_8b => "Qwen/Qwen3-8B",
Which::W3_14b => "Qwen/Qwen3-14B",
Which::W3_32b => "Qwen/Qwen3-32B",
};
let api = api.model(repo.to_string());
api.get("tokenizer.json")?
}
};
Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)
}
fn model(&self) -> anyhow::Result<std::path::PathBuf> {
let model_path = match &self.model {
Some(config) => std::path::PathBuf::from(config),
None => {
let (repo, filename, revision) = match self.which {
Which::W3_0_6b => ("unsloth/Qwen3-0.6B-GGUF", "Qwen3-0.6B-Q4_K_M.gguf", "main"),
Which::W3_1_7b => ("unsloth/Qwen3-1.7B-GGUF", "Qwen3-1.7B-Q4_K_M.gguf", "main"),
Which::W3_4b => ("unsloth/Qwen3-4B-GGUF", "Qwen3-4B-Q4_K_M.gguf", "main"),
Which::W3_8b => ("unsloth/Qwen3-8B-GGUF", "Qwen3-8B-Q4_K_M.gguf", "main"),
Which::W3_14b => ("unsloth/Qwen3-14B-GGUF", "Qwen3-14B-Q4_K_M.gguf", "main"),
Which::W3_32b => ("unsloth/Qwen3-32B-GGUF", "Qwen3-32B-Q4_K_M.gguf", "main"),
};
let api = hf_hub::api::sync::Api::new()?;
api.repo(hf_hub::Repo::with_revision(
repo.to_string(),
hf_hub::RepoType::Model,
revision.to_string(),
))
.get(filename)?
}
};
Ok(model_path)
}
}
fn format_size(size_in_bytes: usize) -> String {
if size_in_bytes < 1_000 {
format!("{}B", size_in_bytes)
} else if size_in_bytes < 1_000_000 {
format!("{:.2}KB", size_in_bytes as f64 / 1e3)
} else if size_in_bytes < 1_000_000_000 {
format!("{:.2}MB", size_in_bytes as f64 / 1e6)
} else {
format!("{:.2}GB", size_in_bytes as f64 / 1e9)
}
}
fn main() -> anyhow::Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature, args.repeat_penalty, args.repeat_last_n
);
let model_path = args.model()?;
let mut file = std::fs::File::open(&model_path)?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let mut model = {
let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
let mut total_size_in_bytes = 0;
for (_, tensor) in model.tensor_infos.iter() {
let elem_count = tensor.shape.elem_count();
total_size_in_bytes +=
elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
}
println!(
"loaded {:?} tensors ({}) in {:.2}s",
model.tensor_infos.len(),
&format_size(total_size_in_bytes),
start.elapsed().as_secs_f32(),
);
Qwen3::from_gguf(model, &mut file, &device)?
};
println!("model built");
let tokenizer = args.tokenizer()?;
let mut tos = TokenOutputStream::new(tokenizer);
let prompt_str = args
.prompt
.clone()
.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
let prompt_str = format!("<|im_start|>user\n{prompt_str}<|im_end|>\n<|im_start|>assistant\n");
print!("formatted prompt: {}", &prompt_str);
let tokens = tos
.tokenizer()
.encode(prompt_str, true)
.map_err(anyhow::Error::msg)?;
let tokens = tokens.get_ids();
let to_sample = args.sample_len.saturating_sub(1);
let mut all_tokens = vec![];
let mut logits_processor = {
let temperature = args.temperature;
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (args.top_k, args.top_p) {
(None, None) => Sampling::All { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(args.seed, sampling)
};
let start_prompt_processing = std::time::Instant::now();
let mut next_token = if !args.split_prompt {
let input = Tensor::new(tokens, &device)?.unsqueeze(0)?;
let logits = model.forward(&input, 0)?;
let logits = logits.squeeze(0)?;
logits_processor.sample(&logits)?
} else {
let mut next_token = 0;
for (pos, token) in tokens.iter().enumerate() {
let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
let logits = model.forward(&input, pos)?;
let logits = logits.squeeze(0)?;
next_token = logits_processor.sample(&logits)?
}
next_token
};
let prompt_dt = start_prompt_processing.elapsed();
all_tokens.push(next_token);
if let Some(t) = tos.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
let eos_token = *tos.tokenizer().get_vocab(true).get("<|im_end|>").unwrap();
let start_post_prompt = std::time::Instant::now();
let mut sampled = 0;
for index in 0..to_sample {
let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
let logits = model.forward(&input, tokens.len() + index)?;
let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
logits
} else {
let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
args.repeat_penalty,
&all_tokens[start_at..],
)?
};
next_token = logits_processor.sample(&logits)?;
all_tokens.push(next_token);
if let Some(t) = tos.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
sampled += 1;
if next_token == eos_token {
break;
};
}
if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
let dt = start_post_prompt.elapsed();
println!(
"\n\n{:4} prompt tokens processed: {:.2} token/s",
tokens.len(),
tokens.len() as f64 / prompt_dt.as_secs_f64(),
);
println!(
"{sampled:4} tokens generated: {:.2} token/s",
sampled as f64 / dt.as_secs_f64(),
);
Ok(())
}

View File

@ -25,3 +25,28 @@ def print_prime(n: int): # n is the number of primes to be printed
print(i)
```
The qwen3 MoE variant is also an option.
```bash
$ cargo run --example qwen --features metal --release -- --prompt "Write a poem about butterflies. <think></think>." --model "3-moe-a3b"
> In morning's hush, where daisies sleep,
> A fleeting dance through sunlit deep—
> They flutter soft on gossamer thread,
> The messengers of springs own head.
>
> With painted sails and delicate grace,
> They drift from bloom to blossom's face.
> Each wing a tale in hues unseen,
> Of ancient dreams and secrets between.
>
> No sound they make, yet still they speak—
> Of time that flies, of life so brief.
> A fleeting kiss on summers breath,
> A whisper lost before death.
>
> Yet in their flight, the soul takes wing,
> And for a moment, all is spring.
> For though they fade, they never die—
> Their beauty lives where hearts can fly.
> 161 tokens generated (3.00 token/s)
```

View File

@ -9,6 +9,8 @@ use clap::Parser;
use candle_transformers::models::qwen2::{Config as ConfigBase, ModelForCausalLM as ModelBase};
use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe};
use candle_transformers::models::qwen3::{Config as Config3, ModelForCausalLM as Model3};
use candle_transformers::models::qwen3_moe::{Config as ConfigMoe3, ModelForCausalLM as ModelMoe3};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
@ -20,6 +22,8 @@ use tokenizers::Tokenizer;
enum Model {
Base(ModelBase),
Moe(ModelMoe),
Base3(Model3),
Moe3(ModelMoe3),
}
impl Model {
@ -27,6 +31,8 @@ impl Model {
match self {
Self::Moe(ref mut m) => m.forward(xs, s),
Self::Base(ref mut m) => m.forward(xs, s),
Self::Base3(ref mut m) => m.forward(xs, s),
Self::Moe3(ref mut m) => m.forward(xs, s),
}
}
}
@ -85,6 +91,10 @@ impl TextGeneration {
Some(token) => token,
None => anyhow::bail!("cannot find the <|endoftext|> token"),
};
let eos_token2 = match self.tokenizer.get_token("<|im_end|>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <|im_end|> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
@ -107,7 +117,7 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
if next_token == eos_token || next_token == eos_token2 {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
@ -152,6 +162,16 @@ enum WhichModel {
W2_7b,
#[value(name = "2-72b")]
W2_72b,
#[value(name = "3-0.6b")]
W3_0_6b,
#[value(name = "3-1.7b")]
W3_1_7b,
#[value(name = "3-4b")]
W3_4b,
#[value(name = "3-8b")]
W3_8b,
#[value(name = "3-moe-a3b")]
W3MoeA3b,
}
#[derive(Parser, Debug)]
@ -254,6 +274,11 @@ fn main() -> Result<()> {
WhichModel::W14b => ("1.5", "14B"),
WhichModel::W72b => ("1.5", "72B"),
WhichModel::MoeA27b => ("1.5", "MoE-A2.7B"),
WhichModel::W3_0_6b => ("3", "0.6B"),
WhichModel::W3_1_7b => ("3", "1.7B"),
WhichModel::W3_4b => ("3", "4B"),
WhichModel::W3_8b => ("3", "8B"),
WhichModel::W3MoeA3b => ("3", "30B-A3B"),
};
format!("Qwen/Qwen{version}-{size}")
}
@ -273,7 +298,11 @@ fn main() -> Result<()> {
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => match args.model {
WhichModel::W0_5b | WhichModel::W2_0_5b | WhichModel::W2_1_5b | WhichModel::W1_8b => {
WhichModel::W0_5b
| WhichModel::W2_0_5b
| WhichModel::W2_1_5b
| WhichModel::W1_8b
| WhichModel::W3_0_6b => {
vec![repo.get("model.safetensors")?]
}
WhichModel::W4b
@ -282,7 +311,11 @@ fn main() -> Result<()> {
| WhichModel::W14b
| WhichModel::W72b
| WhichModel::W2_72b
| WhichModel::MoeA27b => {
| WhichModel::MoeA27b
| WhichModel::W3_1_7b
| WhichModel::W3_4b
| WhichModel::W3_8b
| WhichModel::W3MoeA3b => {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
}
},
@ -304,6 +337,14 @@ fn main() -> Result<()> {
let config: ConfigMoe = serde_json::from_slice(&std::fs::read(config_file)?)?;
Model::Moe(ModelMoe::new(&config, vb)?)
}
WhichModel::W3_0_6b | WhichModel::W3_1_7b | WhichModel::W3_4b | WhichModel::W3_8b => {
let config: Config3 = serde_json::from_slice(&std::fs::read(config_file)?)?;
Model::Base3(Model3::new(&config, vb)?)
}
WhichModel::W3MoeA3b => {
let config: ConfigMoe3 = serde_json::from_slice(&std::fs::read(config_file)?)?;
Model::Moe3(ModelMoe3::new(&config, vb)?)
}
_ => {
let config: ConfigBase = serde_json::from_slice(&std::fs::read(config_file)?)?;
Model::Base(ModelBase::new(&config, vb)?)

View File

@ -28,3 +28,26 @@ Ranking Results:
> Rank #1 | Score: 0.9990 | The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.
--------------------------------------------------------------------------------
```
Text-Classification:
```bash
cargo run --example xlm-roberta -- --task text-classification --model xlmr-formality-classifier
```
```markdown
Formality Scores:
Text 1: "I like you. I love you"
formal: 0.9933
informal: 0.0067
Text 2: "Hey, what's up?"
formal: 0.8812
informal: 0.1188
Text 3: "Siema, co porabiasz?"
formal: 0.9358
informal: 0.0642
Text 4: "I feel deep regret and sadness about the situation in international politics."
formal: 0.9987
informal: 0.0013
```

View File

@ -2,6 +2,7 @@ use std::path::PathBuf;
use anyhow::{Error as E, Result};
use candle::{Device, Tensor};
use candle_nn::ops::softmax;
use candle_nn::VarBuilder;
use candle_transformers::models::xlm_roberta::{
Config, XLMRobertaForMaskedLM, XLMRobertaForSequenceClassification,
@ -17,12 +18,14 @@ enum Model {
BgeRerankerBaseV2,
XLMRobertaBase,
XLMRobertaLarge,
XLMRFormalityClassifier,
}
#[derive(Debug, Clone, ValueEnum)]
enum Task {
FillMask,
Reranker,
TextClassification,
}
#[derive(Parser, Debug)]
@ -83,6 +86,12 @@ fn main() -> Result<()> {
Model::BgeRerankerBaseV2 => "BAAI/bge-reranker-base-v2-m3".to_string(),
_ => anyhow::bail!("XLM-RoBERTa models are not supported for reranker task"),
},
Task::TextClassification => match args.model {
Model::XLMRFormalityClassifier => "s-nlp/xlmr_formality_classifier".to_string(),
_ => anyhow::bail!(
"XLM-RoBERTa models are not supported for text classification task"
),
},
},
};
let repo = api.repo(Repo::with_revision(
@ -217,6 +226,36 @@ fn main() -> Result<()> {
});
println!("{:-<80}", "");
}
Task::TextClassification => {
let sentences = vec![
"I like you. I love you".to_string(),
"Hey, what's up?".to_string(),
"Siema, co porabiasz?".to_string(),
"I feel deep regret and sadness about the situation in international politics."
.to_string(),
];
let model = XLMRobertaForSequenceClassification::new(2, &config, vb)?;
let input_ids = tokenize_batch(&tokenizer, TokenizeInput::Single(&sentences), &device)?;
let attention_mask =
get_attention_mask(&tokenizer, TokenizeInput::Single(&sentences), &device)?;
let token_type_ids = Tensor::zeros(input_ids.dims(), input_ids.dtype(), &device)?;
let logits = model
.forward(&input_ids, &attention_mask, &token_type_ids)?
.to_dtype(candle::DType::F32)?;
let probabilities = softmax(&logits, 1)?;
let probs_vec = probabilities.to_vec2::<f32>()?;
println!("Formality Scores:");
for (i, (text, probs)) in sentences.iter().zip(probs_vec.iter()).enumerate() {
println!("Text {}: \"{}\"", i + 1, text);
println!(" formal: {:.4}", probs[0]);
println!(" informal: {:.4}", probs[1]);
println!();
}
}
}
Ok(())
}

View File

@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.9.0-alpha.4"
version = "0.9.1"
edition = "2021"
description = "Flash attention layer for the candle ML framework."
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"
[dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0-alpha.4" }
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.1" }
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]

View File

@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.9.0-alpha.4"
version = "0.9.1"
edition = "2021"
description = "CUDA kernels for Candle"

View File

@ -1,5 +1,6 @@
#include<stdint.h>
#include "cuda_fp16.h"
#include "cuda_utils.cuh"
template<typename T>
__device__ void fill_with(T *buf, T value, const size_t numel) {
@ -36,13 +37,45 @@ COPY2D_OP(uint8_t, copy2d_u8)
COPY2D_OP(uint32_t, copy2d_u32)
COPY2D_OP(int64_t, copy2d_i64)
#define CONST_SET_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t numel, \
const size_t num_dims, \
const size_t *info, \
const TYPENAME inp, \
TYPENAME *out \
) { \
const size_t *dims = info; \
const size_t *strides = info + num_dims; \
if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
out[i] = inp; \
} \
} \
else { \
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
out[strided_i] = inp; \
} \
} \
} \
CONST_SET_OP(float, const_set_f32)
CONST_SET_OP(double, const_set_f64)
CONST_SET_OP(uint8_t, const_set_u8)
CONST_SET_OP(uint32_t, const_set_u32)
CONST_SET_OP(int64_t, const_set_i64)
#if __CUDA_ARCH__ >= 530
extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
COPY2D_OP(__half, copy2d_f16)
CONST_SET_OP(__half, const_set_f16)
#endif
#if __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
extern "C" __global__ void fill_bf16(__nv_bfloat16 *buf, __nv_bfloat16 value, const size_t numel) { fill_with(buf, value, numel); }
COPY2D_OP(__nv_bfloat16, copy2d_bf16)
CONST_SET_OP(__nv_bfloat16, const_set_bf16)
#endif

View File

@ -3,6 +3,28 @@
#include "cuda_utils.cuh"
#include<stdint.h>
template <typename T>
__host__ __device__
constexpr T max_value();
template <>
__host__ __device__
constexpr int64_t max_value<int64_t>() {
return 0x7FFFFFFFFFFFFFFFLL;
}
template <>
__host__ __device__
constexpr uint32_t max_value<uint32_t>() {
return 0xFFFFFFFFu;
}
template <>
__host__ __device__
constexpr uint8_t max_value<uint8_t>() {
return 0xFFu;
}
template<typename T, typename I>
__device__ void index_select(
const size_t numel,
@ -23,9 +45,14 @@ __device__ void index_select(
unsigned int left_i = dst_i / (ids_dim_size * right_size);
unsigned int id_i = dst_i / right_size % ids_dim_size;
unsigned int right_i = dst_i % right_size;
unsigned int src_i = left_i * (src_dim_size * right_size) + ids[id_i] * right_size + right_i;
unsigned strided_i = b ? src_i : get_strided_index(src_i, num_dims, dims, strides);
out[dst_i] = inp[strided_i];
if (ids[id_i] == max_value<I>()) {
out[dst_i] = static_cast<T>(0);
} else {
assert(ids[id_i] < src_dim_size);
unsigned int src_i = left_i * (src_dim_size * right_size) + ids[id_i] * right_size + right_i;
unsigned strided_i = b ? src_i : get_strided_index(src_i, num_dims, dims, strides);
out[dst_i] = inp[strided_i];
}
}
}
@ -56,10 +83,15 @@ __device__ void gather(
) {
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
size_t post = i % right_size;
size_t idx = ids[i];
size_t pre = i / (right_size * ids_dim_size);
size_t src_i = (pre * src_dim_size + idx) * right_size + post;
out[i] = inp[src_i];
const I idx = ids[i];
if (ids[i] == max_value<I>()) {
out[i] = static_cast<T>(0);
} else {
assert(idx < src_dim_size);
size_t pre = i / (right_size * ids_dim_size);
size_t src_i = (pre * src_dim_size + idx) * right_size + post;
out[i] = inp[src_i];
}
}
}
@ -91,10 +123,13 @@ __device__ void index_add(
const size_t pre = i / right_size;
const size_t post = i % right_size;
for (unsigned int j = 0; j < ids_dim_size; ++j) {
const size_t idx = ids[j];
const I idx = ids[j];
const size_t src_i = (pre * ids_dim_size + j) * right_size + post;
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
out[dst_i] += inp[src_i];
if (idx < max_value<I>()) {
assert(idx < dst_dim_size);
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
out[dst_i] += inp[src_i];
}
}
}
}
@ -111,6 +146,32 @@ extern "C" __global__ void FN_NAME( \
const size_t right_size \
) { index_add(ids, ids_dim_size, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \
template<typename T, typename I>
__device__ void scatter(
const I *ids,
const T *inp,
T *out,
const size_t left_size,
const size_t src_dim_size,
const size_t dst_dim_size,
const size_t right_size
) {
const size_t numel = left_size * right_size;
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
const size_t pre = i / right_size;
const size_t post = i % right_size;
for (unsigned int j = 0; j < src_dim_size; ++j) {
const size_t src_i = (pre * src_dim_size + j) * right_size + post;
const I idx = ids[src_i];
if (idx < max_value<I>()) {
assert(idx < dst_dim_size);
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
out[dst_i] = inp[src_i];
}
}
}
}
template<typename T, typename I>
__device__ void scatter_add(
const I *ids,
@ -127,13 +188,27 @@ __device__ void scatter_add(
const size_t post = i % right_size;
for (unsigned int j = 0; j < src_dim_size; ++j) {
const size_t src_i = (pre * src_dim_size + j) * right_size + post;
const size_t idx = ids[src_i];
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
out[dst_i] += inp[src_i];
const I idx = ids[src_i];
if (idx < max_value<I>()) {
assert(idx < dst_dim_size);
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
out[dst_i] += inp[src_i];
}
}
}
}
#define S_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const INDEX_TYPENAME *ids, \
const TYPENAME *inp, \
TYPENAME *out, \
const size_t left_size, \
const size_t src_dim_size, \
const size_t dst_dim_size, \
const size_t right_size \
) { scatter(ids, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \
#define SA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const INDEX_TYPENAME *ids, \
@ -159,6 +234,9 @@ IA_OP(__nv_bfloat16, uint8_t, ia_u8_bf16)
SA_OP(__nv_bfloat16, int64_t, sa_i64_bf16)
SA_OP(__nv_bfloat16, uint32_t, sa_u32_bf16)
SA_OP(__nv_bfloat16, uint8_t, sa_u8_bf16)
S_OP(__nv_bfloat16, int64_t, s_i64_bf16)
S_OP(__nv_bfloat16, uint32_t, s_u32_bf16)
S_OP(__nv_bfloat16, uint8_t, s_u8_bf16)
#endif
#if __CUDA_ARCH__ >= 530
@ -174,6 +252,9 @@ IA_OP(__half, uint8_t, ia_u8_f16)
SA_OP(__half, int64_t, sa_i64_f16)
SA_OP(__half, uint32_t, sa_u32_f16)
SA_OP(__half, uint8_t, sa_u8_f16)
S_OP(__half, int64_t, s_i64_f16)
S_OP(__half, uint32_t, s_u32_f16)
S_OP(__half, uint8_t, s_u8_f16)
#endif
IS_OP(float, int64_t, is_i64_f32)
@ -247,3 +328,21 @@ SA_OP(double, uint8_t, sa_u8_f64)
SA_OP(uint8_t, uint8_t, sa_u8_u8)
SA_OP(uint32_t, uint8_t, sa_u8_u32)
SA_OP(int64_t, uint8_t, sa_u8_i64)
S_OP(float, int64_t, s_i64_f32)
S_OP(double, int64_t, s_i64_f64)
S_OP(uint8_t, int64_t, s_i64_u8)
S_OP(int64_t, int64_t, s_i64_i64)
S_OP(uint32_t, int64_t, s_i64_u32)
S_OP(float, uint32_t, s_u32_f32)
S_OP(double, uint32_t, s_u32_f64)
S_OP(uint8_t, uint32_t, s_u32_u8)
S_OP(int64_t, uint32_t, s_u32_i64)
S_OP(uint32_t, uint32_t, s_u32_u32)
S_OP(float, uint8_t, s_u8_f32)
S_OP(double, uint8_t, s_u8_f64)
S_OP(uint8_t, uint8_t, s_u8_u8)
S_OP(uint32_t, uint8_t, s_u8_u32)
S_OP(int64_t, uint8_t, s_u8_i64)
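At the tensor level, the guarded kernels above make an index equal to the maximum value of the index dtype behave as a no-op: gathers and index-selects write zeros for that position, while scatters and index-adds simply skip it. Below is a minimal sketch of that behaviour through the public API, assuming the CPU backend applies the same sentinel rule; the `candle` alias for `candle-core` follows the examples elsewhere in this changeset.
```rust
use candle::{Device, Tensor};

fn main() -> candle::Result<()> {
    let dev = Device::Cpu;
    // A 4x2 source tensor: [[0, 1], [2, 3], [4, 5], [6, 7]].
    let src = Tensor::from_vec((0..8).map(|v| v as f32).collect::<Vec<f32>>(), (4, 2), &dev)?;
    // u32::MAX acts as the "skip" sentinel, so the middle output row stays zero.
    let ids = Tensor::from_vec(vec![1u32, u32::MAX, 3u32], 3, &dev)?;
    let picked = src.index_select(&ids, 0)?;
    println!("{picked}"); // expected: [[2, 3], [0, 0], [6, 7]]
    Ok(())
}
```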

View File

@ -219,11 +219,15 @@ __device__ void softmax(const T * x, T * dst, const int ncols) {
}
template <typename T>
__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td) {
__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t stride_b) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (2 * idx >= bh * td) return;
uint32_t rope_idx = idx % (td / 2);
if (stride_b > 0) {
uint32_t b_idx = (2 * idx) / stride_b;
rope_idx += b_idx * (td / 2);
}
T c = cos[rope_idx];
T s = sin[rope_idx];
@ -232,7 +236,7 @@ __device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, cons
}
template <typename T>
__device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t d) {
__device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t d, const uint32_t stride_b) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (2 * idx >= bh * td) return;
@ -243,6 +247,10 @@ __device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const
uint32_t i1 = i_bh * td + i_t * d + i_d;
uint32_t i2 = i1 + d / 2;
uint32_t i_cs = i_t * (d / 2) + i_d;
if (stride_b > 0) {
uint32_t b_idx = (2 * idx) / stride_b;
i_cs += b_idx * (td / 2);
}
T c = cos[i_cs];
T s = sin[i_cs];
@ -259,7 +267,8 @@ __device__ void rope_thd(
const uint32_t b,
const uint32_t t,
const uint32_t h,
const uint32_t d
const uint32_t d,
const uint32_t stride_b
) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (2 * idx >= b * t * h * d) return;
@ -270,6 +279,10 @@ __device__ void rope_thd(
uint32_t i1 = i_bth * d + i_d;
uint32_t i2 = i1 + d / 2;
uint32_t i_cs = i_t * (d / 2) + i_d;
if (stride_b > 0) {
uint32_t b_idx = (2 * idx) / stride_b;
i_cs += b_idx * ((t * d) / 2);
}
T c = cos[i_cs];
T s = sin[i_cs];
@ -546,8 +559,9 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
const TYPENAME *sin, \
TYPENAME *dst, \
const uint32_t bh, \
const uint32_t td) { \
ropei<TYPENAME>(src, cos, sin, dst, bh, td); \
const uint32_t td, \
const uint32_t stride_b) { \
ropei<TYPENAME>(src, cos, sin, dst, bh, td, stride_b); \
} \
extern "C" __global__ void FN_NAME( \
const TYPENAME *src, \
@ -556,8 +570,9 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
TYPENAME *dst, \
const uint32_t bh, \
const uint32_t td, \
const uint32_t d) { \
rope<TYPENAME>(src, cos, sin, dst, bh, td, d); \
const uint32_t d, \
const uint32_t stride_b) { \
rope<TYPENAME>(src, cos, sin, dst, bh, td, d, stride_b); \
} \
extern "C" __global__ void FN_NAME_THD( \
const TYPENAME *src, \
@ -567,8 +582,9 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
const uint32_t b, \
const uint32_t t, \
const uint32_t h, \
const uint32_t d) { \
rope_thd<TYPENAME>(src, cos, sin, dst, b, t, h, d); \
const uint32_t d, \
const uint32_t stride_b) { \
rope_thd<TYPENAME>(src, cos, sin, dst, b, t, h, d, stride_b); \
} \
#if __CUDA_ARCH__ >= 800

View File

@ -1,6 +1,6 @@
[package]
name = "candle-metal-kernels"
version = "0.9.0-alpha.4"
version = "0.9.1"
edition = "2021"
description = "Metal kernels for Candle"
@ -12,6 +12,7 @@ license = "MIT OR Apache-2.0"
[dependencies]
metal = { version = "0.27.0", features = ["mps"] }
half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_distr"] }
once_cell = "1.18.0"
thiserror = "1"
tracing = "0.1.37"

View File

@ -4,20 +4,20 @@ using namespace metal;
template<typename T> METAL_FUNC void fill_with(
device T *out,
constant float &value,
constant T &value,
constant size_t &numel,
uint tid [[thread_position_in_grid]]
) {
if (tid >= numel) {
return;
}
out[tid] = static_cast<T>(value);
out[tid] = value;
}
#define FILL_OP(NAME, T) \
kernel void fill_##NAME( \
device T *out, \
constant float &value, \
constant T &value, \
constant size_t &numel, \
uint tid [[thread_position_in_grid]] \
) { \

View File

@ -1,6 +1,24 @@
#include <metal_stdlib>
using namespace metal;
template <typename T>
inline T max_value();
template <>
inline int64_t max_value<int64_t>() {
return 0x7FFFFFFFFFFFFFFF;
}
template <>
inline uint32_t max_value<uint32_t>() {
return 0xFFFFFFFFu;
}
template <>
inline uint8_t max_value<uint8_t>() {
return 0xFF;
}
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
@ -35,17 +53,21 @@ METAL_FUNC void index(
return;
}
const size_t id_i = (tid / right_size) % ids_size;
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1));
const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size / ids_size;
/*
// Force prevent out of bounds indexing
// since there doesn't seem to be a good way to force crash
// No need to check for zero we're only allowing unsized.
*/
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i;
const size_t strided_src_i = contiguous ? src_i : get_strided_index(src_i, src_dim_size, src_dims, src_strides);
output[tid] = input[strided_src_i];
if (input_ids[id_i] == max_value<INDEX_TYPENAME>()) {
output[tid] = static_cast<TYPENAME>(0);
} else {
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1));
const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size / ids_size;
/*
// Force prevent out of bounds indexing
// since there doesn't seem to be a good way to force crash
// No need to check for zero we're only allowing unsized.
*/
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i;
const size_t strided_src_i = contiguous ? src_i : get_strided_index(src_i, src_dim_size, src_dims, src_strides);
output[tid] = input[strided_src_i];
}
}
# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
@ -83,10 +105,14 @@ METAL_FUNC void gather(
return;
}
const INDEX_TYPENAME input_i = input_ids[tid];
const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size / ids_size;
const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i;
output[tid] = input[src_i];
if (input_i == max_value<INDEX_TYPENAME>()) {
output[tid] = static_cast<TYPENAME>(0);
} else {
const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size / ids_size;
const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i;
output[tid] = input[src_i];
}
}
# define GATHER_OP(NAME, INDEX_TYPENAME, TYPENAME) \
@ -104,6 +130,33 @@ kernel void NAME( \
gather<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \
}
template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void scatter(
constant size_t &dst_size,
constant size_t &left_size,
constant size_t &src_dim_size,
constant size_t &right_size,
constant size_t &dst_dim_size,
const device TYPENAME *input,
const device INDEX_TYPENAME *input_ids,
device TYPENAME *output,
uint tid [[ thread_position_in_grid ]]
) {
if (tid >= dst_size) {
return;
}
const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size;
for (unsigned int j = 0; j < src_dim_size; ++j) {
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
const INDEX_TYPENAME idx = input_ids[src_i];
if (idx < max_value<INDEX_TYPENAME>()) {
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
output[dst_i] = input[src_i];
}
}
}
template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void scatter_add(
constant size_t &dst_size,
@ -124,11 +177,28 @@ METAL_FUNC void scatter_add(
for (unsigned int j = 0; j < src_dim_size; ++j) {
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
const INDEX_TYPENAME idx = input_ids[src_i];
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
output[dst_i] += input[src_i];
if (idx < max_value<INDEX_TYPENAME>()) {
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
output[dst_i] += input[src_i];
}
}
}
# define SCATTER_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
constant size_t &dst_size, \
constant size_t &left_size, \
constant size_t &src_dim_size, \
constant size_t &right_size, \
constant size_t &dst_dim_size, \
const device TYPENAME *input, \
const device INDEX_TYPENAME *input_ids, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
scatter<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, dst_dim_size, input, input_ids, output, tid); \
}
# define SCATTER_ADD_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
constant size_t &dst_size, \
@ -164,9 +234,11 @@ METAL_FUNC void index_add(
const size_t left_rank_i = tid / right_size;
for (unsigned int j = 0; j < ids_dim_size; ++j) {
const INDEX_TYPENAME idx = input_ids[j];
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
output[dst_i] += input[src_i];
if (idx < max_value<INDEX_TYPENAME>()) {
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
output[dst_i] += input[src_i];
}
}
}
@ -235,6 +307,19 @@ SCATTER_ADD_OP(sa_u8_bf16, uint8_t, bfloat)
SCATTER_ADD_OP(sa_i64_bf16, int64_t, bfloat)
#endif
SCATTER_OP(s_u32_f32, uint32_t, float)
SCATTER_OP(s_u8_f32, uint8_t, float)
SCATTER_OP(s_i64_f32, int64_t, float)
SCATTER_OP(s_u32_u32, uint32_t, uint32_t)
SCATTER_OP(s_u32_f16, uint32_t, half)
SCATTER_OP(s_u8_f16, uint8_t, half)
SCATTER_OP(s_i64_f16, int64_t, half)
#if defined(__HAVE_BFLOAT__)
SCATTER_OP(s_u32_bf16, uint32_t, bfloat)
SCATTER_OP(s_u8_bf16, uint8_t, bfloat)
SCATTER_OP(s_i64_bf16, int64_t, bfloat)
#endif
// i64
INDEX_ADD_OP(ia_i64_f16, int64_t, half)
INDEX_ADD_OP(ia_i64_f32, int64_t, float)

View File

@ -161,7 +161,7 @@ macro_rules! ops{
pub mod unary {
ops!(
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
tanh, recip, silu, sign, sigmoid
tanh, recip, silu, sign, sigmoid, const_set
);
}
pub mod binary {
@ -419,6 +419,82 @@ pub fn call_copy2d(
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_const_set_contiguous_tiled(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: unary::contiguous_tiled::Kernel,
length: usize,
input: impl EncoderParam,
output: BufferOffset,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
let tile_size = 2;
let tiles = length.div_ceil(tile_size);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, input, &output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, tiles);
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_const_set_contiguous(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: unary::contiguous::Kernel,
length: usize,
input: impl EncoderParam,
output: BufferOffset,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, input, &output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_const_set_strided(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
name: unary::strided::Kernel,
shape: &[usize],
input: impl EncoderParam,
strides: &[usize],
output: BufferOffset,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
let length: usize = shape.iter().product();
let num_dims: usize = shape.len();
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, num_dims, shape, strides, input, &output));
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_unary_contiguous_tiled(
device: &Device,
@ -915,6 +991,7 @@ pub fn call_rope_i(
kernel_name: &'static str,
bh: usize,
td: usize,
stride_b: usize,
src: &Buffer,
src_offset: usize,
cos: &Buffer,
@ -933,6 +1010,7 @@ pub fn call_rope_i(
(
bh,
td,
stride_b,
(src, src_offset),
(cos, cos_offset),
(sin, sin_offset),
@ -958,6 +1036,7 @@ pub fn call_rope_thd(
t: usize,
h: usize,
d: usize,
stride_b: usize,
src: &Buffer,
src_offset: usize,
cos: &Buffer,
@ -978,6 +1057,7 @@ pub fn call_rope_thd(
t,
h,
d,
stride_b,
(src, src_offset),
(cos, cos_offset),
(sin, sin_offset),
@ -1002,6 +1082,7 @@ pub fn call_rope(
bh: usize,
td: usize,
d: usize,
stride_b: usize,
src: &Buffer,
src_offset: usize,
cos: &Buffer,
@ -1021,6 +1102,7 @@ pub fn call_rope(
bh,
td,
d,
stride_b,
(src, src_offset),
(cos, cos_offset),
(sin, sin_offset),
@ -1371,7 +1453,7 @@ pub fn call_gather(
}
#[allow(clippy::too_many_arguments)]
pub fn call_scatter_add(
pub fn call_scatter(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
@ -1381,7 +1463,7 @@ pub fn call_scatter_add(
dim: usize,
input: BufferOffset,
ids: BufferOffset,
output: &Buffer,
output: BufferOffset,
) -> Result<(), MetalKernelError> {
let left_size: usize = src_shape[..dim].iter().product();
let right_size: usize = src_shape[dim + 1..].iter().product();
@ -1406,7 +1488,7 @@ pub fn call_scatter_add(
dst_dim_size,
&input,
&ids,
output
&output
)
);
@ -1414,7 +1496,7 @@ pub fn call_scatter_add(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
Ok(())
}
@ -2570,7 +2652,7 @@ pub fn call_const_fill(
name: &'static str,
length: usize,
output: &Buffer,
v: f32,
v: impl EncoderParam,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Fill, name)?;
let encoder = ep.encoder();

View File

@ -1097,6 +1097,7 @@ template<typename T>
METAL_FUNC void ropei(
constant size_t &bh,
constant size_t &td,
constant size_t &stride_b,
device const T *src,
device const T *cos,
device const T *sin,
@ -1107,6 +1108,10 @@ METAL_FUNC void ropei(
return;
}
size_t rope_idx = tid % (td / 2);
if (stride_b > 0) {
size_t b_idx = (2 * tid) / stride_b;
rope_idx += b_idx * (td / 2);
}
T c = cos[rope_idx];
T s = sin[rope_idx];
dst[2 * tid] = src[2 * tid] * c - src[2 * tid + 1] * s;
@ -1118,6 +1123,7 @@ METAL_FUNC void rope(
constant size_t &bh,
constant size_t &td,
constant size_t &d,
constant size_t &stride_b,
device const T *src,
device const T *cos,
device const T *sin,
@ -1134,6 +1140,10 @@ METAL_FUNC void rope(
size_t i1 = i_bh * td + i_t * d + i_d;
size_t i2 = i1 + d / 2;
size_t i_cs = i_t * (d / 2) + i_d;
if (stride_b > 0) {
size_t b_idx = (2 * idx) / stride_b;
i_cs += b_idx * (td / 2);
}
T c = cos[i_cs];
T s = sin[i_cs];
dst[i1] = src[i1] * c - src[i2] * s;
@ -1146,6 +1156,7 @@ METAL_FUNC void rope_thd(
constant size_t &t,
constant size_t &h,
constant size_t &d,
constant size_t &stride_b,
device const T *src,
device const T *cos,
device const T *sin,
@ -1160,8 +1171,12 @@ METAL_FUNC void rope_thd(
const size_t i_t = (i_bth / h) % t;
const size_t i1 = i_bth * d + i_d;
const size_t i2 = i1 + d / 2;
const size_t i_cs = i_t * (d / 2) + i_d;
T c = cos[i_cs];
size_t i_cs = i_t * (d / 2) + i_d;
if (stride_b > 0) {
const size_t b_idx = (2 * idx) / stride_b;
i_cs += b_idx * ((t * d) / 2);
}
T c = cos[i_cs];
T s = sin[i_cs];
dst[i1] = src[i1] * c - src[i2] * s;
dst[i2] = src[i1] * s + src[i2] * c;
@ -1171,38 +1186,41 @@ METAL_FUNC void rope_thd(
kernel void FN_NAME_I( \
constant size_t &bh, \
constant size_t &td, \
constant size_t &stride_b, \
device const TYPENAME *src, \
device const TYPENAME *cos, \
device const TYPENAME *sin, \
device TYPENAME *dst, \
uint tid [[ thread_position_in_grid ]] \
) { \
ropei<TYPENAME>(bh, td, src, cos, sin, dst, tid); \
ropei<TYPENAME>(bh, td, stride_b, src, cos, sin, dst, tid); \
}\
kernel void FN_NAME( \
constant size_t &bh, \
constant size_t &td, \
constant size_t &d, \
constant size_t &stride_b, \
device const TYPENAME *src, \
device const TYPENAME *cos, \
device const TYPENAME *sin, \
device TYPENAME *dst, \
uint idx [[ thread_position_in_grid ]] \
) { \
rope<TYPENAME>(bh, td, d, src, cos, sin, dst, idx); \
rope<TYPENAME>(bh, td, d, stride_b, src, cos, sin, dst, idx); \
}\
kernel void FN_NAME_THD( \
constant size_t &b, \
constant size_t &t, \
constant size_t &h, \
constant size_t &d, \
constant size_t &stride_b, \
device const TYPENAME *src, \
device const TYPENAME *cos, \
device const TYPENAME *sin, \
device TYPENAME *dst, \
uint idx [[ thread_position_in_grid ]] \
) { \
rope_thd<TYPENAME>(b, t, h, d, src, cos, sin, dst, idx); \
rope_thd<TYPENAME>(b, t, h, d, stride_b, src, cos, sin, dst, idx); \
}\
RMSNORM(rmsnorm_f32, float)

View File

@ -1574,7 +1574,7 @@ fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>(
let input_buffer = new_buffer(&device, input);
let ids_buffer = new_buffer(&device, ids);
let output = device.new_buffer(std::mem::size_of_val(input) as u64, options);
call_scatter_add(
call_scatter(
&device,
command_buffer,
&kernels,
@ -1584,7 +1584,7 @@ fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>(
dim,
BufferOffset::zero_offset(&input_buffer),
BufferOffset::zero_offset(&ids_buffer),
&output,
BufferOffset::zero_offset(&output),
)
.unwrap();
command_buffer.commit();
@ -2343,7 +2343,7 @@ fn conv_transpose1d_u32() {
#[test]
fn const_fill() {
fn constant_fill<T: Clone>(name: &'static str, len: usize, value: f32) -> Vec<T> {
fn constant_fill<T: Clone + EncoderParam>(name: &'static str, len: usize, value: T) -> Vec<T> {
let dev = device();
let kernels = Kernels::new();
let command_queue = dev.new_command_queue();
@ -2357,11 +2357,15 @@ fn const_fill() {
command_buffer.wait_until_completed();
read_to_vec::<T>(&buffer, len)
}
fn test<T: Clone + PartialEq + std::fmt::Debug, F: FnOnce(f32) -> T>(name: &'static str, f: F) {
fn test<T: Clone + Copy + EncoderParam + PartialEq + std::fmt::Debug, F: FnOnce(f32) -> T>(
name: &'static str,
f: F,
) {
let len = rand::thread_rng().gen_range(2..16) * rand::thread_rng().gen_range(4..16);
let value = rand::thread_rng().gen_range(1. ..19.);
let value = f(value);
let v = constant_fill::<T>(name, len, value);
assert_eq!(v, vec![f(value); len])
assert_eq!(v, vec![value; len])
}
test::<u8, _>("fill_u8", |v| v as u8);
test::<u32, _>("fill_u32", |v| v as u32);

View File

@ -73,6 +73,44 @@ template <typename T> METAL_FUNC T sigmoid(T in) {
#define TILE_SIZE 2
#define CONST_SET(TYPENAME, FN_NAME) \
kernel void FN_NAME( \
constant size_t &dim, \
constant TYPENAME &input, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
if (tid >= dim) { \
return; \
} \
output[tid] = input; \
} \
kernel void FN_NAME##_##strided( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant TYPENAME &input, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
if (tid >= dim) { \
return; \
} \
output[get_strided_index(tid, num_dims, dims, strides)] = input; \
} \
kernel void FN_NAME##_##tiled( \
constant size_t &dim, \
constant TYPENAME &input, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
for (uint i = 0; i < TILE_SIZE; i++) { \
const uint idx = tid * TILE_SIZE + i; \
output[idx] = input; \
} \
}
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
constant size_t &dim, \
@ -139,6 +177,11 @@ COPY2D(copy2d_f16, half)
COPY2D(copy2d_u8, uint8_t)
COPY2D(copy2d_u32, uint32_t)
CONST_SET(float, const_set_f32)
CONST_SET(half, const_set_f16)
CONST_SET(uint8_t, const_set_u8)
CONST_SET(uint32_t, const_set_u32)
UNARY_OP(cos)
UNARY_OP(sin)
UNARY_OP(sqr)
@ -171,6 +214,7 @@ UNARY(precise::tanh, half, tanh_f16, tanh_f16_strided);
#if __METAL_VERSION__ >= 220
UNARY(id, int64_t, copy_i64, copy_i64_strided)
COPY2D(copy2d_i64, int64_t)
CONST_SET(int64_t, const_set_i64)
#endif
#if defined(__HAVE_BFLOAT__)
@ -199,4 +243,5 @@ UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
UNARY(precise::tanh, bfloat, tanh_bf16, tanh_bf16_strided);
COPY2D(copy2d_bf16, bfloat)
CONST_SET(bfloat, const_set_bf16)
#endif

View File

@ -88,9 +88,13 @@ primitive!(bool);
primitive!(usize);
primitive!(i32);
primitive!(i64);
primitive!(u8);
primitive!(u32);
primitive!(u64);
primitive!(f32);
primitive!(f64);
primitive!(half::bf16);
primitive!(half::f16);
pub struct BufferOffset<'a> {
pub buffer: &'a Buffer,

View File

@ -71,6 +71,8 @@ impl candle::Module for PReLU {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let weight = if self.is_scalar {
self.weight.reshape(())?
} else if xs.shape() == self.weight.shape() {
self.weight.clone()
} else if xs.rank() >= 2 {
let num_channels = xs.dim(1)?;
let num_weights = self.weight.elem_count();
@ -78,7 +80,7 @@ impl candle::Module for PReLU {
candle::bail!("error in prelu: unexpected number of channels for the input, got {num_channels}, weight dim is {num_weights}")
}
let mut s = vec![1; xs.rank()];
s[1] = self.weight.elem_count();
s[1] = num_weights;
self.weight.reshape(s)?
} else {
self.weight.clone()
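For reference, the added branch lets a PReLU whose weight already has the same shape as the input be applied directly, without the per-channel reshape. A small sketch of that case, assuming the usual `candle_nn::PReLU::new(weight, is_scalar)` constructor:
```rust
use candle::{Device, Module, Result, Tensor};
use candle_nn::PReLU;

fn prelu_full_shape() -> Result<Tensor> {
    let dev = Device::Cpu;
    // Weight and input share the exact same (2, 3) shape, so the new branch is
    // taken and the weight is used as-is rather than reshaped to (1, C, 1, ...).
    let xs = Tensor::new(&[[-1f32, 2.0, -3.0], [4.0, -5.0, 6.0]], &dev)?;
    let weight = Tensor::full(0.25f32, (2, 3), &dev)?;
    let prelu = PReLU::new(weight, false);
    prelu.forward(&xs) // negative entries are scaled by 0.25, positives pass through
}
```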

View File

@ -1,6 +1,6 @@
//! Cache Implementations
//!
use candle::{Device, Result, Tensor};
use candle::{DType, Device, Result, Tensor};
#[derive(Debug, Clone)]
pub struct Cache {
@ -399,3 +399,322 @@ impl RotatingKvCache {
self.v.reset();
}
}
#[derive(Debug, Clone)]
pub struct IndicesAndMask {
indices: Tensor,
mask: Tensor,
}
impl IndicesAndMask {
pub fn mask(&self) -> &Tensor {
&self.mask
}
}
#[derive(Debug, Clone)]
pub struct ScatteredKvCache {
k: Tensor,
v: Tensor,
context: usize,
}
impl ScatteredKvCache {
pub fn append(
&mut self,
k: &Tensor,
v: &Tensor,
iam: &IndicesAndMask,
) -> Result<(Tensor, Tensor)> {
if self.context <= k.dim(2)? {
return Ok((k.clone(), v.clone()));
}
let indices = iam.indices.unsqueeze(2)?.unsqueeze(1)?;
let indices = indices.broadcast_as(k.shape())?.contiguous()?;
self.k.scatter_set(&indices, k, 2)?;
self.v.scatter_set(&indices, v, 2)?;
Ok((self.k.clone(), self.v.clone()))
}
pub fn k(&self) -> &Tensor {
&self.k
}
pub fn v(&self) -> &Tensor {
&self.v
}
}
#[derive(Debug, Clone)]
pub struct ScatteredCacheBuilder {
context: usize,
// The current position in the stream, this can be larger than context.
positions: Vec<usize>,
// The index where the next element will be stored.
indices: Vec<usize>,
dtype: DType,
device: Device,
}
impl ScatteredCacheBuilder {
pub fn new(batch_size: usize, context: usize, dtype: DType, device: &Device) -> Result<Self> {
let positions = vec![0; batch_size];
let indices = vec![0; batch_size];
Ok(Self {
positions,
indices,
context,
dtype,
device: device.clone(),
})
}
pub fn make_cache(&self, num_heads: usize, head_dim: usize) -> Result<ScatteredKvCache> {
let batch_size = self.batch_size();
let shape = (batch_size, num_heads, self.context, head_dim);
let k = Tensor::zeros(shape, self.dtype, self.device())?;
let v = Tensor::zeros(shape, self.dtype, self.device())?;
Ok(ScatteredKvCache {
k,
v,
context: self.context,
})
}
pub fn positions(&self) -> &[usize] {
&self.positions
}
pub fn reset(&mut self) {
self.positions.fill(0);
self.indices.fill(0);
}
pub fn batch_size(&self) -> usize {
self.positions.len()
}
pub fn reset_batch_index(&mut self, batch_index: usize) {
self.positions[batch_index] = 0;
self.indices[batch_index] = 0;
}
#[allow(clippy::needless_range_loop)]
pub fn indices_and_mask(
&mut self,
seq_len: usize,
batch_mask: &[bool],
) -> Result<IndicesAndMask> {
// mask shape is (b, h, t, k)
let context = self.context;
if self.context <= seq_len {
return self.indices_and_mask_abs(seq_len, batch_mask);
}
let mut attention_masks = Vec::with_capacity(self.batch_size());
let mut cache_indices = Vec::with_capacity(self.batch_size());
for (batch_i, &batch_mask) in batch_mask.iter().enumerate() {
if !batch_mask {
let masks: Vec<Vec<f32>> = vec![vec![0.0; context]; seq_len];
let indices = vec![self.indices[batch_i] as u32; seq_len];
attention_masks.push(masks);
cache_indices.push(indices);
} else {
let start_index = self.indices[batch_i];
let start_pos = self.positions[batch_i];
let mut masks: Vec<Vec<f32>> = Vec::with_capacity(seq_len);
let mut indices = Vec::with_capacity(seq_len);
let mut all_pos = vec![usize::MAX; context];
if start_pos < context {
for i in 0..start_pos {
all_pos[i] = i;
}
} else {
let offset = start_pos - start_index;
for i in 0..context {
all_pos[i] = if i < start_index {
i + offset
} else {
i + offset - context
};
}
}
for seq_i in 0..seq_len {
let index = self.indices[batch_i];
all_pos[index] = seq_i + start_pos;
indices.push(index as u32);
self.indices[batch_i] += 1;
self.positions[batch_i] += 1;
if self.indices[batch_i] >= self.context {
self.indices[batch_i] = 0;
}
}
for seq_i in 0..seq_len {
let my_pos = seq_i + start_pos;
let mask = all_pos
.iter()
.map(|&pos| {
if pos <= my_pos {
0.0
} else {
f32::NEG_INFINITY
}
})
.collect::<Vec<f32>>();
masks.push(mask);
}
attention_masks.push(masks);
cache_indices.push(indices);
}
}
// Flattening the attention mask then using Tensor::from_vec rather than using Tensor::new ends
// up being almost 10x faster with candle 0.9.0. This has been fixed in candle 0.9.1.
let attention_masks = attention_masks
.into_iter()
.flat_map(|m| m.into_iter().flatten())
.collect::<Vec<f32>>();
let mask = Tensor::from_vec(attention_masks, ((), 1, seq_len, context), self.device())?
.to_dtype(self.dtype)?;
let indices = Tensor::new(cache_indices, self.device())?;
Ok(IndicesAndMask { indices, mask })
}
pub fn device(&self) -> &Device {
&self.device
}
#[allow(clippy::needless_range_loop)]
fn indices_and_mask_abs(
&mut self,
seq_len: usize,
batch_mask: &[bool],
) -> Result<IndicesAndMask> {
let mask = self.get_mask_abs(seq_len, seq_len)?;
let mut cache_indices = Vec::with_capacity(self.batch_size());
for (batch_i, &batch_mask) in batch_mask.iter().enumerate() {
if !batch_mask {
let indices = vec![self.indices[batch_i] as u32; seq_len];
cache_indices.push(indices);
} else {
let mut indices = Vec::with_capacity(seq_len);
for _ in 0..seq_len {
let index = self.indices[batch_i];
indices.push(index as u32);
self.indices[batch_i] += 1;
self.positions[batch_i] += 1;
if self.indices[batch_i] >= self.context {
self.indices[batch_i] = 0;
}
}
cache_indices.push(indices);
}
}
let indices = Tensor::new(cache_indices, self.device())?;
Ok(IndicesAndMask { indices, mask })
}
fn get_mask_abs(&self, size1: usize, size2: usize) -> Result<Tensor> {
let context = self.context;
let mask: Vec<_> = (0..size1)
.flat_map(|i| {
(0..size2).map(move |j| {
if size1 + j > size2 + i || size1 + j + context < size2 + i {
f32::NEG_INFINITY
} else {
0.0
}
})
})
.collect();
Tensor::from_slice(&mask, (size1, size2), self.device())
}
}
#[cfg(test)]
mod tests {
use super::*;
use candle::IndexOp;
#[test]
fn test_scattered_kv_cache() -> Result<()> {
let device = Device::Cpu;
let mut cache = ScatteredCacheBuilder::new(2, 5, DType::F32, &device)?;
let inf = f32::INFINITY;
let iam = cache.indices_and_mask(1, &[true, false])?;
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
assert_eq!(iam.indices.to_vec2::<u32>()?, [[0], [0]]);
assert_eq!(
mask,
[[[0.0, -inf, -inf, -inf, -inf]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]
);
let iam = cache.indices_and_mask(1, &[true, false])?;
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
assert_eq!(iam.indices.to_vec2::<u32>()?, [[1], [0]]);
assert_eq!(
mask,
[[[0.0, 0.0, -inf, -inf, -inf]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]
);
let iam = cache.indices_and_mask(3, &[false, true])?;
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
assert_eq!(iam.indices.to_vec2::<u32>()?, [[2, 2, 2], [0, 1, 2]]);
assert_eq!(
mask,
[
[
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]
],
[
[0.0, -inf, -inf, -inf, -inf],
[0.0, 0.0, -inf, -inf, -inf],
[0.0, 0.0, 0.0, -inf, -inf]
]
]
);
let iam = cache.indices_and_mask(3, &[true, true])?;
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
assert_eq!(iam.indices.to_vec2::<u32>()?, [[2, 3, 4], [3, 4, 0]]);
assert_eq!(
mask,
[
[
[0.0, 0.0, 0.0, -inf, -inf],
[0.0, 0.0, 0.0, 0.0, -inf],
[0.0, 0.0, 0.0, 0.0, 0.0]
],
[
[-inf, 0.0, 0.0, 0.0, -inf],
[-inf, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]
]
]
);
let iam = cache.indices_and_mask(1, &[true, false])?;
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
assert_eq!(iam.indices.to_vec2::<u32>()?, [[0], [1]]);
assert_eq!(
mask,
[[[0.0, 0.0, 0.0, 0.0, 0.0]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]
);
let iam = cache.indices_and_mask(2, &[true, false])?;
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
assert_eq!(iam.indices.to_vec2::<u32>()?, [[1, 2], [1, 1]]);
assert_eq!(
mask,
[
[[0.0, 0.0, -inf, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]],
[[0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]]
]
);
Ok(())
}
}
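A minimal usage sketch for the new scattered cache in a decoding loop, assuming `ScatteredCacheBuilder` and `ScatteredKvCache` are exported from `candle_nn::kv_cache` alongside the existing caches in this file:
```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::kv_cache::ScatteredCacheBuilder;

fn decode_step() -> Result<()> {
    let device = Device::Cpu;
    let (batch, heads, head_dim, context) = (2, 4, 8, 16);
    let mut builder = ScatteredCacheBuilder::new(batch, context, DType::F32, &device)?;
    let mut cache = builder.make_cache(heads, head_dim)?;

    // One new token per sequence; both sequences in the batch are active.
    let seq_len = 1;
    let iam = builder.indices_and_mask(seq_len, &[true, true])?;
    let k = Tensor::zeros((batch, heads, seq_len, head_dim), DType::F32, &device)?;
    let v = Tensor::zeros((batch, heads, seq_len, head_dim), DType::F32, &device)?;

    // The new k/v are scattered into the fixed-size cache at the per-sequence write
    // positions; the returned tensors always span the full `context` length.
    let (k_all, v_all) = cache.append(&k, &v, &iam)?;
    assert_eq!(k_all.dims(), &[batch, heads, context, head_dim]);
    assert_eq!(v_all.dims(), &[batch, heads, context, head_dim]);

    // iam.mask() has shape (batch, 1, seq_len, context) and is added to the attention
    // scores so that empty slots and future positions are masked out.
    let _mask = iam.mask();
    Ok(())
}
```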

View File

@ -46,15 +46,23 @@ impl candle::CustomOp3 for RotaryEmbI {
Some((o1, o2)) => &sin[o1..o2],
};
let (b, h, t, d) = l_src.shape().dims4()?;
let unbatched_rope = l_cos.dims().len() == 3 && l_sin.dims().len() == 3;
let el_count = b * h * t * d;
let mut dst = vec![T::zero(); el_count];
src.par_chunks(t * d)
.zip(dst.par_chunks_mut(t * d))
.for_each(|(src, dst)| {
.enumerate()
.for_each(|(bh_i, (src, dst))| {
for i_over_2 in 0..t * d / 2 {
let i = 2 * i_over_2;
dst[i] = src[i] * cos[i_over_2] - src[i + 1] * sin[i_over_2];
dst[i + 1] = src[i] * sin[i_over_2] + src[i + 1] * cos[i_over_2];
let rope_i = if unbatched_rope {
let b_i = bh_i / h;
i_over_2 + b_i * t * d / 2
} else {
i_over_2
};
dst[i] = src[i] * cos[rope_i] - src[i + 1] * sin[rope_i];
dst[i + 1] = src[i] * sin[rope_i] + src[i + 1] * cos[rope_i];
}
});
let storage = candle::WithDType::to_cpu_storage_owned(dst);
@ -115,6 +123,11 @@ impl candle::CustomOp3 for RotaryEmbI {
Some((o1, o2)) => sin.slice(o1..o2),
};
let (b, h, t, d) = l_src.shape().dims4()?;
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
(h * t * d) as u32
} else {
0u32
};
let el = b * h * t * d;
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), &kernels::REDUCE)?;
@ -125,7 +138,7 @@ impl candle::CustomOp3 for RotaryEmbI {
builder.arg(&cos);
builder.arg(&sin);
builder.arg(&dst);
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32);
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, stride_b);
// SAFETY: ffi.
unsafe { builder.launch(cfg) }.w()?;
Ok(dst)
@ -182,6 +195,11 @@ impl candle::CustomOp3 for RotaryEmbI {
dtype => candle::bail!("rope-i is not implemented for {dtype:?}"),
};
let (b, h, t, d) = l_src.shape().dims4()?;
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
h * t * d
} else {
0usize
};
let el = b * h * t * d;
let output = device.new_buffer(el, src.dtype(), "rope-i")?;
candle_metal_kernels::call_rope_i(
@ -191,6 +209,7 @@ impl candle::CustomOp3 for RotaryEmbI {
name,
b * h,
t * d,
stride_b,
src.buffer(),
l_src.start_offset() * src.dtype().size_in_bytes(),
cos.buffer(),
@ -205,10 +224,23 @@ impl candle::CustomOp3 for RotaryEmbI {
}
}
fn rope_check_cs(cs: &Tensor, b_sz: usize) -> Result<(usize, usize)> {
match *cs.dims() {
[t, d] => Ok((t, d)),
[b, t, d] => {
if b != b_sz {
candle::bail!("inconsistent batch size in rope {b_sz} {cs:?}",)
}
Ok((t, d))
}
_ => candle::bail!("cos/sin has to be 2D or 3D in rope {b_sz} {cs:?}"),
}
}
pub fn rope_i(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = cos.dims2()?;
let (sin_seq_len, sin_n_embd) = cos.dims2()?;
let (b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = rope_check_cs(cos, b_sz)?;
let (sin_seq_len, sin_n_embd) = rope_check_cs(sin, b_sz)?;
if cos_n_embd * 2 != n_embd
|| sin_n_embd * 2 != n_embd
|| seq_len > cos_seq_len
@ -292,16 +324,24 @@ impl candle::CustomOp3 for RotaryEmb {
Some((o1, o2)) => &sin[o1..o2],
};
let (b, h, t, d) = l_src.shape().dims4()?;
let unbatched_rope = l_cos.dims().len() == 3 && l_sin.dims().len() == 3;
let el_count = b * h * t * d;
let mut dst = vec![T::zero(); el_count];
src.par_chunks(t * d)
.zip(dst.par_chunks_mut(t * d))
.for_each(|(src, dst)| {
.enumerate()
.for_each(|(bh_i, (src, dst))| {
for i_t in 0..t {
for i_d in 0..d / 2 {
let i1 = i_t * d + i_d;
let i2 = i1 + d / 2;
let i_cs = i_t * (d / 2) + i_d;
let i_cs = if unbatched_rope {
let b_i = bh_i / h;
i_cs + b_i * t * d / 2
} else {
i_cs
};
dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs];
dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs];
}
@ -365,6 +405,11 @@ impl candle::CustomOp3 for RotaryEmb {
Some((o1, o2)) => sin.slice(o1..o2),
};
let (b, h, t, d) = l_src.shape().dims4()?;
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
(h * t * d) as u32
} else {
0u32
};
let el = b * h * t * d;
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("rope"), &kernels::REDUCE)?;
@ -375,7 +420,7 @@ impl candle::CustomOp3 for RotaryEmb {
builder.arg(&cos);
builder.arg(&sin);
builder.arg(&dst);
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, d as u32);
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, d as u32, stride_b);
// SAFETY: ffi.
unsafe { builder.launch(cfg) }.w()?;
Ok(dst)
@ -432,6 +477,11 @@ impl candle::CustomOp3 for RotaryEmb {
dtype => candle::bail!("rope is not implemented for {dtype:?}"),
};
let (b, h, t, d) = l_src.shape().dims4()?;
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
h * t * d
} else {
0usize
};
let el = b * h * t * d;
let output = device.new_buffer(el, src.dtype(), "rope-i")?;
candle_metal_kernels::call_rope(
@ -442,6 +492,7 @@ impl candle::CustomOp3 for RotaryEmb {
b * h,
t * d,
d,
stride_b,
src.buffer(),
l_src.start_offset() * src.dtype().size_in_bytes(),
cos.buffer(),
@ -457,9 +508,9 @@ impl candle::CustomOp3 for RotaryEmb {
}
pub fn rope(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = cos.dims2()?;
let (sin_seq_len, sin_n_embd) = sin.dims2()?;
let (b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = rope_check_cs(cos, b_sz)?;
let (sin_seq_len, sin_n_embd) = rope_check_cs(sin, b_sz)?;
if cos_n_embd * 2 != n_embd
|| sin_n_embd * 2 != n_embd
|| seq_len > cos_seq_len
@ -541,14 +592,21 @@ impl candle::CustomOp3 for RotaryEmbThd {
Some((o1, o2)) => &sin[o1..o2],
};
let (b, t, h, d) = l_src.shape().dims4()?;
let unbatched_rope = l_cos.dims().len() == 3 && l_sin.dims().len() == 3;
let el_count = b * h * t * d;
let mut dst = vec![T::zero(); el_count];
src.par_chunks(t * h * d)
.zip(dst.par_chunks_mut(t * h * d))
.for_each(|(src, dst)| {
.enumerate()
.for_each(|(b_i, (src, dst))| {
for i_t in 0..t {
for i_d in 0..d / 2 {
let i_cs = i_t * (d / 2) + i_d;
let i_cs = if unbatched_rope {
i_cs + b_i * t * d / 2
} else {
i_cs
};
for i_h in 0..h {
let i1 = i_t * h * d + i_h * d + i_d;
let i2 = i1 + d / 2;
@ -616,6 +674,11 @@ impl candle::CustomOp3 for RotaryEmbThd {
Some((o1, o2)) => sin.slice(o1..o2),
};
let (b, t, h, d) = l_src.shape().dims4()?;
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
(h * t * d) as u32
} else {
0u32
};
let el = b * h * t * d;
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), &kernels::REDUCE)?;
@ -626,7 +689,7 @@ impl candle::CustomOp3 for RotaryEmbThd {
builder.arg(&cos);
builder.arg(&sin);
builder.arg(&dst);
candle::builder_arg!(builder, b as u32, t as u32, h as u32, d as u32);
candle::builder_arg!(builder, b as u32, t as u32, h as u32, d as u32, stride_b);
// SAFETY: ffi.
unsafe { builder.launch(cfg) }.w()?;
Ok(dst)
@ -683,6 +746,11 @@ impl candle::CustomOp3 for RotaryEmbThd {
dtype => candle::bail!("rope_thd is not implemented for {dtype:?}"),
};
let (b, t, h, d) = l_src.shape().dims4()?;
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
h * t * d
} else {
0usize
};
let el = b * h * t * d;
let output = device.new_buffer(el, src.dtype(), "rope-thd")?;
candle_metal_kernels::call_rope_thd(
@ -694,6 +762,7 @@ impl candle::CustomOp3 for RotaryEmbThd {
t,
h,
d,
stride_b,
src.buffer(),
l_src.start_offset() * src.dtype().size_in_bytes(),
cos.buffer(),
@ -709,9 +778,9 @@ impl candle::CustomOp3 for RotaryEmbThd {
}
pub fn rope_thd(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
let (_b_sz, seq_len, _n_head, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = cos.dims2()?;
let (sin_seq_len, sin_n_embd) = sin.dims2()?;
let (b_sz, seq_len, _n_head, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = rope_check_cs(cos, b_sz)?;
let (sin_seq_len, sin_n_embd) = rope_check_cs(sin, b_sz)?;
if cos_n_embd * 2 != n_embd
|| sin_n_embd * 2 != n_embd
|| seq_len > cos_seq_len

View File

@ -8,13 +8,16 @@ pub fn gumbel_softmax<D: candle::shape::Dim>(
) -> Result<Tensor> {
if temperature <= 0.0 {
logits.argmax(dim)
} else if temperature == 1.0 {
let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
let sampled = (logits - minus_g)?.argmax(dim)?;
Ok(sampled)
} else {
// Cast to f32, doing the Gumbel softmax in bf16 is a bit unstable.
let logits = logits.to_dtype(candle::DType::F32)?;
let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
let sampled = (logits + minus_g * (-temperature))?.argmax(dim)?;
Ok(sampled)
if temperature == 1.0 {
let sampled = (logits - minus_g)?.argmax(dim)?;
Ok(sampled)
} else {
let sampled = (logits + minus_g * (-temperature))?.argmax(dim)?;
Ok(sampled)
}
}
}
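
A minimal usage sketch for the sampler above, assuming it is exported as candle_nn::sampling::gumbel_softmax (the logits values are made up):

use candle::{Device, Result, Tensor, D};

fn sample_token() -> Result<u32> {
    let logits = Tensor::new(&[0.1f32, 2.0, 0.5, 1.5], &Device::Cpu)?;
    // A temperature <= 0.0 falls back to a plain argmax; 1.0 adds unscaled Gumbel noise.
    let token = candle_nn::sampling::gumbel_softmax(&logits, 1.0, D::Minus1)?;
    token.to_vec0::<u32>()
}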

View File

@ -4,7 +4,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle::{test_device, test_utils::to_vec3_round, Device, Result, Tensor};
use candle::{test_device, test_utils::to_vec3_round, Device, IndexOp, Result, Tensor};
fn softmax(device: &Device) -> Result<()> {
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
@ -179,6 +179,28 @@ fn ropei(device: &Device) -> Result<()> {
} else {
assert!(sum_diff < 1e-4);
}
// Test with a 3d cos/sin
let cos2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let sin2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let cos2 = Tensor::from_vec(cos2, (seq_len, head_dim / 2), device)?;
let sin2 = Tensor::from_vec(sin2, (seq_len, head_dim / 2), device)?;
let rope1 = candle_nn::rotary_emb::rope_i(&src.i(0..1)?, &cos, &sin)?;
let rope2 = candle_nn::rotary_emb::rope_i(&src.i(1..2)?, &cos2, &sin2)?;
let both_cos = Tensor::stack(&[cos, cos2], 0)?;
let both_sin = Tensor::stack(&[sin, sin2], 0)?;
let both_rope = candle_nn::rotary_emb::rope_i(&src, &both_cos, &both_sin)?;
let both_rope2 = Tensor::cat(&[rope1, rope2], 0)?;
let sum_diff = (both_rope - both_rope2)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(sum_diff, 0.);
Ok(())
}
@ -206,6 +228,28 @@ fn rope(device: &Device) -> Result<()> {
} else {
assert!(sum_diff < 1e-4);
}
// Test with a 3d cos/sin
let cos2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let sin2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let cos2 = Tensor::from_vec(cos2, (seq_len, head_dim / 2), device)?;
let sin2 = Tensor::from_vec(sin2, (seq_len, head_dim / 2), device)?;
let rope1 = candle_nn::rotary_emb::rope(&src.i(0..1)?, &cos, &sin)?;
let rope2 = candle_nn::rotary_emb::rope(&src.i(1..2)?, &cos2, &sin2)?;
let both_cos = Tensor::stack(&[cos, cos2], 0)?;
let both_sin = Tensor::stack(&[sin, sin2], 0)?;
let both_rope = candle_nn::rotary_emb::rope(&src, &both_cos, &both_sin)?;
let both_rope2 = Tensor::cat(&[rope1, rope2], 0)?;
let sum_diff = (both_rope - both_rope2)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(sum_diff, 0.);
Ok(())
}
@ -236,6 +280,37 @@ fn rope_thd(device: &Device) -> Result<()> {
} else {
assert!(sum_diff < 1e-4);
}
// Test with a 3d cos/sin
let cos2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let sin2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let cos2 = Tensor::from_vec(cos2, (seq_len, head_dim / 2), device)?;
let sin2 = Tensor::from_vec(sin2, (seq_len, head_dim / 2), device)?;
let rope1 = {
let src = src.transpose(1, 2)?.contiguous()?;
candle_nn::rotary_emb::rope_thd(&src.i(0..1)?, &cos, &sin)?
};
let rope2 = {
let src = src.transpose(1, 2)?.contiguous()?;
candle_nn::rotary_emb::rope_thd(&src.i(1..2)?, &cos2, &sin2)?
};
let both_cos = Tensor::stack(&[cos, cos2], 0)?;
let both_sin = Tensor::stack(&[sin, sin2], 0)?;
let both_rope = {
let src = src.transpose(1, 2)?.contiguous()?;
candle_nn::rotary_emb::rope_thd(&src, &both_cos, &both_sin)?
};
let both_rope2 = Tensor::cat(&[rope1, rope2], 0)?;
let sum_diff = (both_rope - both_rope2)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(sum_diff, 0.);
Ok(())
}
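
A short sketch of how the batched cos/sin path exercised by these tests can be used from model code (shapes are illustrative and random tables stand in for real cos/sin values):

use candle::{Device, Result, Tensor};

fn batched_rope() -> Result<Tensor> {
    let dev = Device::Cpu;
    let (b, h, t, d) = (2, 4, 8, 16);
    let xs = Tensor::randn(0f32, 1f32, (b, h, t, d), &dev)?;
    // A 2d table of shape (t, d / 2) is shared across the batch; a 3d table of shape
    // (b, t, d / 2) gives each batch entry its own rotation, as checked in the tests above.
    let cos = Tensor::randn(0f32, 1f32, (b, t, d / 2), &dev)?;
    let sin = Tensor::randn(0f32, 1f32, (b, t, d / 2), &dev)?;
    candle_nn::rotary_emb::rope(&xs, &cos, &sin)
}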

View File

@ -1,6 +1,6 @@
[package]
name = "candle-onnx"
version = "0.9.0-alpha.4"
version = "0.9.1"
edition = "2021"
description = "ONNX support for Candle"
@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
candle = { path = "../candle-core", package = "candle-core", version = "0.9.0-alpha.4" }
candle-nn = { path = "../candle-nn", version = "0.9.0-alpha.4" }
candle = { path = "../candle-core", package = "candle-core", version = "0.9.1" }
candle-nn = { path = "../candle-nn", version = "0.9.1" }
prost = "0.12.1"
[build-dependencies]

View File

@ -1,7 +1,9 @@
use crate::onnx::attribute_proto::AttributeType;
use crate::onnx::tensor_proto::DataType;
use crate::onnx::{self, GraphProto};
use candle::Module;
use candle::{bail, DType, Device, Result, Tensor};
use candle_nn::activation::PReLU;
use std::collections::{HashMap, HashSet};
pub type Value = Tensor;
@ -581,7 +583,13 @@ fn simple_eval_(
&Device::Cpu,
)?);
let xs = Tensor::ones(input.shape(), value.dtype(), input.device())?
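                // ConstantOfShape's input is a 1d tensor whose values give the target output
                // shape, so read those values rather than using the input tensor's own shape.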
let shape_vec: Vec<usize> = input
.to_vec1::<i64>()?
.iter()
.map(|&x| x as usize)
.collect();
let xs = Tensor::ones(shape_vec, value.dtype(), input.device())?
.broadcast_mul(&value)?;
values.insert(node.output[0].clone(), xs);
}
@ -991,6 +999,14 @@ fn simple_eval_(
let output = input.relu()?;
values.insert(node.output[0].clone(), output);
}
"PRelu" => {
// https://onnx.ai/onnx/operators/onnx__PRelu.html
let input = get(&node.input[0])?;
let slope = get(&node.input[1])?;
let output = PReLU::new(slope.clone(), false).forward(input)?;
values.insert(node.output[0].clone(), output);
}
"Ceil" => {
let input = get(&node.input[0])?;
let output = input.ceil()?;
@ -1228,7 +1244,7 @@ fn simple_eval_(
}
let indexes = Tensor::arange_step(s, e, p, data.device())?;
out = out.index_select(&indexes, axis)?
out = out.contiguous()?.index_select(&indexes, axis)?
}
values.insert(node.output[0].clone(), out);
}
@ -1950,6 +1966,273 @@ fn simple_eval_(
let output = input.sign()?;
values.insert(node.output[0].clone(), output);
}
"Resize" => {
let input = get(&node.input[0])?;
if input.rank() != 4 {
bail!("Unsupported rank for nearest resize: {}", input.rank());
}
let scales = if node.input.len() > 2 && !node.input[2].is_empty() {
Some(get(&node.input[2])?)
} else {
None
};
let sizes = if node.input.len() > 3 && !node.input[3].is_empty() {
Some(get(&node.input[3])?)
} else {
None
};
let output_dims = match (scales, sizes) {
(Some(_), Some(_)) => {
bail!("Scales and sizes cannot both be set for Resize operation")
}
(Some(scales_tensor), None) => {
let scale_values = scales_tensor.to_vec1::<f32>()?;
input
.dims()
.iter()
.enumerate()
.map(|(i, &d)| (d as f32 * scale_values[i]) as usize)
.collect::<Vec<_>>()
}
(None, Some(sizes_tensor)) => sizes_tensor
.to_vec1::<i64>()?
.iter()
.map(|&d| d as usize)
.collect::<Vec<_>>(),
(None, None) => bail!("Either scales or sizes should be present"),
};
let coordinate_transformation_mode =
get_attr_opt::<str>(node, "coordinate_transformation_mode")?
.unwrap_or("half_pixel");
// Interpolation mode: nearest, linear, or cubic.
let mode = get_attr_opt::<str>(node, "mode")?.unwrap_or("nearest");
// How to determine the "nearest" pixel in nearest interpolation mode.
let nearest_mode =
get_attr_opt::<str>(node, "nearest_mode")?.unwrap_or("round_prefer_floor");
if mode != "nearest" {
bail!("Unsupported resize mode: {}", mode);
}
if nearest_mode != "floor" {
bail!("Unsupported nearest_mode for resize: {}", nearest_mode);
}
if coordinate_transformation_mode != "asymmetric" {
bail!(
"Unsupported coordinate_transformation_mode for resize: {}",
coordinate_transformation_mode
);
}
let h = output_dims[2];
let w = output_dims[3];
let output = input.upsample_nearest2d(h, w)?;
values.insert(node.output[0].clone(), output);
}
"Trilu" => {
let input = get(&node.input[0])?;
// Get the diagonal offset 'k' from the second input if provided
let k = if node.input.len() > 1 && !node.input[1].is_empty() {
get(&node.input[1])?.to_vec0::<i64>()?
} else {
0
};
// Get the 'upper' attribute
let upper = get_attr_opt::<i64>(node, "upper")?.copied().unwrap_or(1);
// For batched inputs, we need to handle each matrix separately
let dims = input.dims();
if dims.len() < 2 {
bail!("Trilu expects input with at least 2 dimensions: {:?}", dims);
}
// Get the last two dimensions which represent the matrix
let n = dims[dims.len() - 2];
let m = dims[dims.len() - 1];
let max_dim = std::cmp::max(n, m);
// Handle the diagonal offset k
let mask = if k != 0 {
let mut data = vec![0u32; n * m];
for i in 0..n {
for j in 0..m {
if (upper != 0 && (j as i64) >= (i as i64) + k)
|| (upper == 0 && (j as i64) <= (i as i64) + k)
{
data[i * m + j] = 1u32;
}
}
}
Tensor::from_vec(data, (n, m), input.device())?.to_dtype(input.dtype())?
} else if upper == 0 {
Tensor::tril2(max_dim, input.dtype(), input.device())?
} else {
Tensor::triu2(max_dim, input.dtype(), input.device())?
};
let final_mask = if n != m {
mask.narrow(0, 0, n)?.narrow(1, 0, m)?
} else {
mask
};
let output = (input * &final_mask)?;
values.insert(node.output[0].clone(), output);
}
"ScatterND" => {
let data = get(&node.input[0])?;
let indices = get(&node.input[1])?;
let indices = indices.to_dtype(DType::I64)?;
let updates = get(&node.input[2])?;
let reduction = get_attr_opt::<str>(node, "reduction")?.unwrap_or("none");
let indices_shape = indices.dims();
let data_shape = data.dims();
let updates_shape = updates.dims();
// Last dimension of indices represents the depth of indexing
let k = indices_shape.last().unwrap().clone();
if k > data.rank() {
bail!("ScatterND expects k (indices.shape[-1]) to be at most the rank of data");
}
let num_updates = indices_shape[..indices_shape.len() - 1]
.iter()
.product::<usize>();
let flat_indices = if indices.rank() == 1 && k == 1 {
indices.unsqueeze(0)?
} else {
indices.reshape((num_updates, k))?
};
// Calculate the shape of each update element
let update_element_shape = if k < data_shape.len() {
data_shape[k..].to_vec()
} else {
vec![]
};
// Expected shape for updates based on indices and target tensor
let expected_updates_shape = {
let mut shape = indices_shape[..indices_shape.len() - 1].to_vec();
shape.extend(&update_element_shape);
shape
};
// Validate or reshape updates to expected shape
let updates = if updates.dims() != expected_updates_shape {
if updates.rank() == 0 {
// Handle scalar updates
let mut target_shape = vec![num_updates];
target_shape.extend(&update_element_shape);
updates.broadcast_as(target_shape)?
} else {
// Try to broadcast or reshape updates to expected shape
let flat_shape =
vec![num_updates, update_element_shape.iter().product::<usize>()];
let flattened = updates.reshape(flat_shape)?;
flattened.reshape(expected_updates_shape)?
}
} else {
updates.clone()
};
let mut output = data.clone();
// convert indices to flat indices
let mut flat_output = output.flatten_all()?;
let flat_updates = if update_element_shape.is_empty() {
updates.reshape(num_updates)?
} else {
let product = update_element_shape.iter().product::<usize>();
updates.reshape((num_updates, product))?
};
// Calculate strides for the output tensor
let mut strides: Vec<usize> = vec![1];
for i in (0..data_shape.len() - 1).rev() {
strides.push(strides.last().unwrap() * data_shape[i + 1]);
}
strides.reverse();
// Process each update
for i in 0..num_updates {
let index_slice = flat_indices.narrow(0, i, 1)?;
let indices_vec = index_slice.squeeze(0)?.to_vec1::<i64>()?;
// Convert multi-dimensional indices to flat index
let mut flat_idx: usize = 0;
for (dim, &idx) in indices_vec.iter().enumerate() {
let dim_size = data_shape[dim] as i64;
let norm_idx = if idx < 0 { dim_size + idx } else { idx };
if norm_idx < 0 || norm_idx >= dim_size {
bail!(
"Index {} out of bounds for dimension {} with size {}",
idx,
dim,
dim_size
);
}
flat_idx += (norm_idx as usize) * strides[dim];
}
// Extract current update
let update_slice = if update_element_shape.is_empty() {
flat_updates.narrow(0, i, 1)?.squeeze(0)?
} else {
flat_updates.narrow(0, i, 1)?
};
match reduction {
"add" => {
if update_element_shape.is_empty() {
let existing = flat_output.narrow(0, flat_idx, 1)?;
let new_value = existing.add(&update_slice.unsqueeze(0)?)?;
flat_output = flat_output.slice_scatter(&new_value, 0, flat_idx)?;
} else {
let slice_size = update_element_shape.iter().product::<usize>();
let existing = flat_output.narrow(0, flat_idx, slice_size)?;
let new_value = existing.add(&update_slice)?;
flat_output = flat_output.slice_scatter(&new_value, 0, flat_idx)?;
}
}
"none" | _ => {
if update_element_shape.is_empty() {
flat_output = flat_output.slice_scatter(
&update_slice.unsqueeze(0)?,
0,
flat_idx,
)?;
} else {
flat_output =
flat_output.slice_scatter(&update_slice, 0, flat_idx)?;
}
}
}
}
// Reshape flat output back to original shape
output = flat_output.reshape(data_shape.to_vec())?;
values.insert(node.output[0].clone(), output);
}
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
}
}
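
A small standalone check of the stride/flat-index bookkeeping used by the ScatterND branch above, against the 3d case in the tests further down: data of shape (3, 2, 2) has row-major strides [4, 2, 1], so the index [1, 1, 0] lands at flat position 6.

let data_shape = [3usize, 2, 2];
let mut strides = vec![1usize];
for i in (0..data_shape.len() - 1).rev() {
    strides.push(strides.last().unwrap() * data_shape[i + 1]);
}
strides.reverse();
assert_eq!(strides, vec![4, 2, 1]);
let flat_idx: usize = [1usize, 1, 0].iter().zip(&strides).map(|(i, s)| i * s).sum();
assert_eq!(flat_idx, 6); // the element 7.0 that the test below replaces with 200.0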

View File

@ -842,13 +842,22 @@ fn test_flatten_operation() -> Result<()> {
#[test]
fn test_constant_of_shape() -> Result<()> {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-31
test(&[4i64, 3, 2], Some(1.), &[1., 1., 1.])?;
test(
&[4i64, 3, 2],
Some(1.),
&[
[[1., 1.], [1., 1.], [1., 1.]],
[[1., 1.], [1., 1.], [1., 1.]],
[[1., 1.], [1., 1.], [1., 1.]],
[[1., 1.], [1., 1.], [1., 1.]],
],
)?;
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-31
test(&[0.], Some(0i64), &[0i64])?;
test(&[1i64], Some(0i64), &[0i64])?;
// "value" defaults to 0 f32
test(&[1i64, 2, 3, 4], None as Option<i64>, &[0., 0., 0., 0.])?;
test(&[4i64], None as Option<i64>, &[0., 0., 0., 0.])?;
fn test(
input: impl NdArray,
@ -1846,6 +1855,64 @@ fn test_relu_operation() -> Result<()> {
Ok(())
}
// "PRelu"
#[test]
fn test_prelu_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "PRelu".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x: Tensor = Tensor::from_vec(
vec![-1.0f32, 1.0f32, -2.0f32, 3.0f32],
&[2, 2],
&Device::Cpu,
)?;
let y: Tensor = Tensor::from_vec(vec![1.0f32, 1.1f32, 1.2f32, 1.3f32], &[2, 2], &Device::Cpu)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
inputs.insert(INPUT_Y.to_string(), y);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
assert_eq!(results, vec![vec![-1.0, 1.0], vec![-2.4, 3.0]]);
Ok(())
}
// "Constant"
// #[test]
@ -5910,3 +5977,512 @@ fn test_sign_operation() -> Result<()> {
);
Ok(())
}
#[test]
fn test_scatternd_operation() -> Result<()> {
// Example 1 based on ONNX documentation
test(
&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
&[[4i64], [3], [1], [7]],
&[9.0f32, 10.0, 11.0, 12.0],
&[1.0f32, 11.0, 3.0, 10.0, 9.0, 6.0, 7.0, 12.0],
)?;
// A more complex example with 2D data
test(
&[[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]],
&[[0i64, 1], [1, 0]],
&[10.0f32, 20.0],
&[[1.0f32, 10.0], [20.0, 4.0], [5.0, 6.0]],
)?;
// 3D example with indices pointing to specific locations
test(
&[
[[1.0f32, 2.0], [3.0, 4.0]],
[[5.0, 6.0], [7.0, 8.0]],
[[9.0, 10.0], [11.0, 12.0]],
],
&[[0i64, 0, 1], [1, 1, 0]],
&[100.0f32, 200.0],
&[
[[1.0f32, 100.0], [3.0, 4.0]],
[[5.0, 6.0], [200.0, 8.0]],
[[9.0, 10.0], [11.0, 12.0]],
],
)?;
fn test(
data: impl NdArray,
indices: impl NdArray,
updates: impl NdArray,
expected: impl NdArray,
) -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "ScatterND".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![
INPUT_X.to_string(),
INPUT_Y.to_string(),
INPUT_A.to_string(),
],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
inputs.insert(INPUT_Y.to_string(), Tensor::new(indices, &Device::Cpu)?);
inputs.insert(INPUT_A.to_string(), Tensor::new(updates, &Device::Cpu)?);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let expected = Tensor::new(expected, &Device::Cpu)?;
match expected.dims().len() {
1 => assert_eq!(z.to_vec1::<f32>()?, expected.to_vec1::<f32>()?),
2 => assert_eq!(z.to_vec2::<f32>()?, expected.to_vec2::<f32>()?),
3 => assert_eq!(z.to_vec3::<f32>()?, expected.to_vec3::<f32>()?),
_ => unreachable!(),
};
Ok(())
}
Ok(())
}
#[test]
fn test_trilu_operation() -> Result<()> {
// Test 1: Upper triangular matrix (default behavior with upper=true)
{
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Trilu".to_string(),
domain: "".to_string(),
attribute: vec![], // empty attribute means default upper=true
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
vec![
4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 0, 8, 7, 4, 3, 4, 2, 4,
],
&[4, 5],
&Device::Cpu,
)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<i64>()?;
assert_eq!(
results,
vec![
vec![4, 7, 3, 7, 9],
vec![0, 2, 8, 6, 9],
vec![0, 0, 0, 8, 7],
vec![0, 0, 0, 2, 4]
]
);
}
// Test 2: Upper triangular with positive k=1 (diagonal above main)
{
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Trilu".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
vec![1i64, 4, 9, 7, 1, 9, 2, 8, 8, 4, 3, 9, 7, 4, 2],
&[3, 5],
&Device::Cpu,
)?;
let k = Tensor::from_vec(vec![1i64], (), &Device::Cpu)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
inputs.insert(INPUT_Y.to_string(), k);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<i64>()?;
assert_eq!(
results,
vec![
vec![0, 4, 9, 7, 1],
vec![0, 0, 8, 8, 4],
vec![0, 0, 0, 4, 2]
]
);
}
// Test 3: Upper triangular with negative k=-1 (one diagonal below main)
{
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Trilu".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
vec![
4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 0, 8, 7, 4, 3, 4, 2, 4,
],
&[4, 5],
&Device::Cpu,
)?;
let k = Tensor::from_vec(vec![-1i64], (), &Device::Cpu)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
inputs.insert(INPUT_Y.to_string(), k);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<i64>()?;
assert_eq!(
results,
vec![
vec![4, 7, 3, 7, 9],
vec![1, 2, 8, 6, 9],
vec![0, 4, 0, 8, 7],
vec![0, 0, 4, 2, 4]
]
);
}
// Test 4: Lower triangular matrix (upper=0)
{
let att_upper = AttributeProto {
name: "upper".to_string(),
ref_attr_name: "upper".to_string(),
i: 0, // 0 means false, use lower triangular
doc_string: "upper".to_string(),
r#type: 2,
f: 0.0,
s: vec![],
t: None,
g: None,
sparse_tensor: None,
tp: None,
floats: vec![],
ints: vec![],
strings: vec![],
tensors: vec![],
graphs: vec![],
sparse_tensors: vec![],
type_protos: vec![],
};
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Trilu".to_string(),
domain: "".to_string(),
attribute: vec![att_upper],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
vec![
4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 1, 8, 7, 4, 3, 4, 2, 4,
],
&[4, 5],
&Device::Cpu,
)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<i64>()?;
// Lower triangular matrix (default k=0)
assert_eq!(
results,
vec![
vec![4, 0, 0, 0, 0],
vec![1, 2, 0, 0, 0],
vec![9, 4, 1, 0, 0],
vec![4, 3, 4, 2, 0]
]
);
}
// Test 5: Lower triangular with negative k=-1
{
let att_upper = AttributeProto {
name: "upper".to_string(),
ref_attr_name: "upper".to_string(),
i: 0,
doc_string: "upper".to_string(),
r#type: 2,
f: 0.0,
s: vec![],
t: None,
g: None,
sparse_tensor: None,
tp: None,
floats: vec![],
ints: vec![],
strings: vec![],
tensors: vec![],
graphs: vec![],
sparse_tensors: vec![],
type_protos: vec![],
};
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Trilu".to_string(),
domain: "".to_string(),
attribute: vec![att_upper],
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
vec![
4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 1, 8, 7, 4, 3, 4, 2, 4,
],
&[4, 5],
&Device::Cpu,
)?;
let k = Tensor::from_vec(vec![-1i64], (), &Device::Cpu)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
inputs.insert(INPUT_Y.to_string(), k);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<i64>()?;
assert_eq!(
results,
vec![
vec![0, 0, 0, 0, 0],
vec![1, 0, 0, 0, 0],
vec![9, 4, 0, 0, 0],
vec![4, 3, 4, 0, 0]
]
);
}
// Test 6: Lower triangular with positive k=2
{
let att_upper = AttributeProto {
name: "upper".to_string(),
ref_attr_name: "upper".to_string(),
i: 0,
doc_string: "upper".to_string(),
r#type: 2,
f: 0.0,
s: vec![],
t: None,
g: None,
sparse_tensor: None,
tp: None,
floats: vec![],
ints: vec![],
strings: vec![],
tensors: vec![],
graphs: vec![],
sparse_tensors: vec![],
type_protos: vec![],
};
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Trilu".to_string(),
domain: "".to_string(),
attribute: vec![att_upper],
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
vec![
4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 1, 8, 7, 4, 3, 4, 2, 4,
],
&[4, 5],
&Device::Cpu,
)?;
let k = Tensor::from_vec(vec![2i64], (), &Device::Cpu)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
inputs.insert(INPUT_Y.to_string(), k);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<i64>()?;
assert_eq!(
results,
vec![
vec![4, 7, 3, 0, 0],
vec![1, 2, 8, 6, 0],
vec![9, 4, 1, 8, 7],
vec![4, 3, 4, 2, 4]
]
);
}
Ok(())
}

View File

@ -869,8 +869,8 @@ impl Moe {
}
enum MoeOrMlp {
Moe(Moe),
Mlp(Mlp),
Moe(Box<Moe>),
Mlp(Box<Mlp>),
}
impl MoeOrMlp {
@ -908,14 +908,17 @@ impl DecoderLayer {
&& layer_idx >= cfg.first_k_dense_replace
&& layer_idx % cfg.moe_layer_freq == 0
{
MoeOrMlp::Moe(Moe::new(
cfg,
vb.pp("mlp"),
cfg.n_shared_experts,
cfg.n_routed_experts.unwrap(),
)?)
MoeOrMlp::Moe(
Moe::new(
cfg,
vb.pp("mlp"),
cfg.n_shared_experts,
cfg.n_routed_experts.unwrap(),
)?
.into(),
)
} else {
MoeOrMlp::Mlp(Mlp::new(cfg, vb.pp("mlp"), None, None)?)
MoeOrMlp::Mlp(Mlp::new(cfg, vb.pp("mlp"), None, None)?.into())
};
Ok(Self {

View File

@ -21,6 +21,7 @@ pub struct Config {
pub num_key_value_heads: usize,
pub rms_norm_eps: f64,
pub rope_theta: f64,
pub rope_local_base_freq: f64,
pub vocab_size: usize,
pub final_logit_softcapping: Option<f64>,
pub attn_logit_softcapping: Option<f64>,
@ -67,12 +68,22 @@ struct RotaryEmbedding {
}
impl RotaryEmbedding {
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
fn new(
dtype: DType,
cfg: &Config,
dev: &Device,
sliding_window: Option<usize>,
) -> Result<Self> {
let dim = cfg.head_dim;
let max_seq_len = cfg.max_position_embeddings;
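        // Sliding-window layers use the local RoPE base frequency (10_000 in the released
        // Gemma 3 configs) while global-attention layers use rope_theta (1_000_000).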
let rope_freq = if sliding_window.is_some() {
cfg.rope_local_base_freq
} else {
cfg.rope_theta
};
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
.map(|i| 1f32 / rope_freq.powf(i as f64 / dim as f64) as f32)
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
@ -162,8 +173,8 @@ impl Attention {
fn new(
rotary_emb: Arc<RotaryEmbedding>,
use_flash_attn: bool,
is_sliding: bool,
cfg: &Config,
sliding_window: Option<usize>,
vb: VarBuilder,
) -> Result<Self> {
let hidden_sz = cfg.hidden_size;
@ -178,13 +189,13 @@ impl Attention {
let o_proj = linear(num_heads * head_dim, hidden_sz, bias, vb.pp("o_proj"))?;
let q_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("q_norm"))?;
let k_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?;
let kv_cache = if is_sliding {
KvCache::Rotating(candle_nn::kv_cache::RotatingKvCache::new(
2,
cfg.sliding_window,
))
let kv_cache = if let Some(sliding_window) = sliding_window {
KvCache::Rotating(candle_nn::kv_cache::RotatingKvCache::new(2, sliding_window))
} else {
KvCache::Normal(candle_nn::kv_cache::KvCache::new(2, cfg.sliding_window))
KvCache::Normal(candle_nn::kv_cache::KvCache::new(
2,
cfg.max_position_embeddings,
))
};
Ok(Self {
q_proj,
@ -302,21 +313,27 @@ struct DecoderLayer {
pre_feedforward_layernorm: RmsNorm,
post_feedforward_layernorm: RmsNorm,
post_attention_layernorm: RmsNorm,
sliding_window: Option<usize>,
}
impl DecoderLayer {
fn new(
rotary_emb: Arc<RotaryEmbedding>,
use_flash_attn: bool,
is_sliding: bool,
cfg: &Config,
vb: VarBuilder,
sliding_window: Option<usize>,
) -> Result<Self> {
let rotary_emb = Arc::new(RotaryEmbedding::new(
vb.dtype(),
cfg,
vb.device(),
sliding_window,
)?);
let self_attn = Attention::new(
rotary_emb,
use_flash_attn,
is_sliding,
cfg,
sliding_window,
vb.pp("self_attn"),
)?;
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
@ -344,6 +361,7 @@ impl DecoderLayer {
pre_feedforward_layernorm,
post_feedforward_layernorm,
post_attention_layernorm,
sliding_window,
})
}
@ -370,6 +388,42 @@ impl DecoderLayer {
}
}
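// Illustration of the mask built below: with sliding_window = 2 and tgt_len = 4
// (illustrative values), position i may attend to position j only when j <= i and
// i - j <= 2, giving a banded causal pattern:
//   row 0: x . . .
//   row 1: x x . .
//   row 2: x x x .
//   row 3: . x x x   (token 3 can no longer see token 0)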
fn prepare_decoder_attention_mask(
b_size: usize,
tgt_len: usize,
seqlen_offset: usize,
sliding_window: Option<usize>,
dtype: DType,
device: &Device,
) -> Result<Tensor> {
let mask: Vec<_> = if let Some(sliding_window) = sliding_window {
(0..tgt_len)
.flat_map(|i| {
(0..tgt_len).map(move |j| {
if i < j || j + sliding_window < i {
f32::NEG_INFINITY
} else {
0.
}
})
})
.collect()
} else {
(0..tgt_len)
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0f32 }))
.collect()
};
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), device)?;
let mask = if seqlen_offset > 0 {
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
};
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
.to_dtype(dtype)
}
#[derive(Debug, Clone)]
pub struct Model {
embed_tokens: candle_nn::Embedding,
@ -388,17 +442,15 @@ impl Model {
let vb_m = vb.pp("model");
let embed_tokens =
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb_l = vb_m.pp("layers");
for layer_idx in 0..cfg.num_hidden_layers {
let is_sliding = (layer_idx + 1) % cfg.sliding_window_pattern > 0;
let sliding_window = (layer_idx + 1) % cfg.sliding_window_pattern > 0;
let layer = DecoderLayer::new(
rotary_emb.clone(),
use_flash_attn,
is_sliding,
cfg,
vb_l.pp(layer_idx),
sliding_window.then_some(cfg.sliding_window),
)?;
layers.push(layer)
}
@ -417,51 +469,52 @@ impl Model {
})
}
fn prepare_decoder_attention_mask(
fn create_attention_masks(
&self,
b_size: usize,
tgt_len: usize,
batch_size: usize,
seq_len: usize,
seqlen_offset: usize,
) -> Result<Tensor> {
let mask: Vec<_> = match Some(self.sliding_window) {
None => (0..tgt_len)
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
.collect(),
Some(sliding_window) => (0..tgt_len)
.flat_map(|i| {
(0..tgt_len).map(move |j| {
if i < j || j + sliding_window < i {
f32::NEG_INFINITY
} else {
0.
}
})
})
.collect(),
};
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
let mask = if seqlen_offset > 0 {
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
};
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
.to_dtype(self.dtype)
) -> Result<(Option<Tensor>, Option<Tensor>)> {
if seq_len <= 1 {
return Ok((None, None));
}
let mask = prepare_decoder_attention_mask(
batch_size,
seq_len,
seqlen_offset,
None,
self.dtype,
&self.device,
)?;
let sliding_mask = prepare_decoder_attention_mask(
batch_size,
seq_len,
seqlen_offset,
Some(self.sliding_window),
self.dtype,
&self.device,
)?;
Ok((Some(mask), Some(sliding_mask)))
}
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let (b_size, seq_len) = input_ids.dims2()?;
let attention_mask = if seq_len <= 1 {
None
} else {
let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
Some(mask)
};
let xs = self.embed_tokens.forward(input_ids)?;
let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
let (attention_mask, sliding_attention_mask) =
self.create_attention_masks(b_size, seq_len, seqlen_offset)?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
let mask = if layer.sliding_window.is_some() {
&sliding_attention_mask
} else {
&attention_mask
};
xs = layer.forward(&xs, mask.as_ref(), seqlen_offset)?
}
let logits = xs
.narrow(1, seq_len - 1, 1)?

View File

@ -70,6 +70,7 @@ pub mod moondream;
pub mod mpt;
pub mod nvembed_v2;
pub mod olmo;
pub mod olmo2;
pub mod openclip;
pub mod paligemma;
pub mod parler_tts;
@ -79,6 +80,7 @@ pub mod phi3;
pub mod pixtral;
pub mod quantized_blip;
pub mod quantized_blip_text;
pub mod quantized_gemma3;
pub mod quantized_llama;
pub mod quantized_llama2_c;
pub mod quantized_metavoice;
@ -89,6 +91,7 @@ pub mod quantized_mpt;
pub mod quantized_phi;
pub mod quantized_phi3;
pub mod quantized_qwen2;
pub mod quantized_qwen3;
pub mod quantized_recurrent_gemma;
pub mod quantized_rwkv_v5;
pub mod quantized_rwkv_v6;
@ -96,6 +99,8 @@ pub mod quantized_stable_lm;
pub mod quantized_t5;
pub mod qwen2;
pub mod qwen2_moe;
pub mod qwen3;
pub mod qwen3_moe;
pub mod recurrent_gemma;
pub mod repvgg;
pub mod resnet;

View File

@ -0,0 +1,348 @@
//! OLMo 2 (Open Language Model) implementation
//!
//! See OLMo 2 model details at:
//! - [Hugging Face Collection](https://huggingface.co/collections/allenai/olmo-2-674117b93ab84e98afc72edc)
//! - [OLMo 2 Paper](https://arxiv.org/abs/2501.00656)
//!
//!
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{linear_b, linear_no_bias, rms_norm, Activation, Linear, RmsNorm, VarBuilder};
use std::sync::Arc;
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
pub vocab_size: usize,
pub hidden_size: usize,
pub intermediate_size: usize,
pub attention_bias: bool,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub num_key_value_heads: usize,
pub rms_norm_eps: f64,
pub hidden_act: candle_nn::Activation,
pub max_position_embeddings: usize,
pub rope_theta: f64,
pub tie_word_embeddings: bool,
pub clip_qkv: Option<f64>,
}
#[derive(Debug, Clone)]
struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
}
impl RotaryEmbedding {
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
let dim = cfg.hidden_size / cfg.num_attention_heads;
let max_seq_len = cfg.max_position_embeddings;
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
})
}
fn apply_rotary_emb_qkv(
&self,
q: &Tensor,
k: &Tensor,
seqlen_offset: usize,
) -> Result<(Tensor, Tensor)> {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
Ok((q_embed, k_embed))
}
}
#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
gate_proj: Linear,
up_proj: Linear,
down_proj: Linear,
act_fn: Activation,
}
impl MLP {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let hidden_sz = cfg.hidden_size;
let intermediate_sz = cfg.intermediate_size;
let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
Ok(Self {
gate_proj,
up_proj,
down_proj,
act_fn: cfg.hidden_act,
})
}
}
impl Module for MLP {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
let rhs = xs.apply(&self.up_proj)?;
(lhs * rhs)?.apply(&self.down_proj)
}
}
#[derive(Debug, Clone)]
struct Attention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
q_norm: RmsNorm,
k_norm: RmsNorm,
num_heads: usize,
num_kv_heads: usize,
num_kv_groups: usize,
head_dim: usize,
hidden_size: usize,
rotary_emb: Arc<RotaryEmbedding>,
kv_cache: Option<(Tensor, Tensor)>,
}
impl Attention {
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let hidden_sz = cfg.hidden_size;
let num_heads = cfg.num_attention_heads;
let num_kv_heads = cfg.num_key_value_heads;
let num_kv_groups = num_heads / num_kv_heads;
let head_dim = hidden_sz / num_heads;
let b = cfg.attention_bias;
let q_proj = linear_b(hidden_sz, num_heads * head_dim, b, vb.pp("q_proj"))?;
let k_proj = linear_b(hidden_sz, num_kv_heads * head_dim, b, vb.pp("k_proj"))?;
let v_proj = linear_b(hidden_sz, num_kv_heads * head_dim, b, vb.pp("v_proj"))?;
let o_proj = linear_b(num_heads * head_dim, hidden_sz, b, vb.pp("o_proj"))?;
let q_norm = rms_norm(hidden_sz, cfg.rms_norm_eps, vb.pp("q_norm"))?;
let k_norm = rms_norm(num_kv_heads * head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?;
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
q_norm,
k_norm,
num_heads,
num_kv_heads,
num_kv_groups,
head_dim,
hidden_size: hidden_sz,
rotary_emb,
kv_cache: None,
})
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let (b_sz, q_len, _) = xs.dims3()?;
let query_states = self.q_proj.forward(xs)?;
let key_states = self.k_proj.forward(xs)?;
let value_states = self.v_proj.forward(xs)?;
let query_states = self.q_norm.forward(&query_states)?;
let key_states = self.k_norm.forward(&key_states)?;
let query_states = query_states
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
.transpose(1, 2)?;
let key_states = key_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let value_states = value_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let (query_states, key_states) =
self.rotary_emb
.apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
let (key_states, value_states) = match &self.kv_cache {
None => (key_states, value_states),
Some((prev_k, prev_v)) => {
let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
(key_states, value_states)
}
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));
let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
let value_states =
crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
let attn_output = {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
let attn_weights = match attention_mask {
None => attn_weights,
Some(mask) => attn_weights.broadcast_add(mask)?,
};
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
attn_weights.matmul(&value_states)?
};
attn_output
.transpose(1, 2)?
.reshape((b_sz, q_len, self.hidden_size))?
.apply(&self.o_proj)
}
fn clear_kv_cache(&mut self) {
self.kv_cache = None
}
}
#[derive(Debug, Clone)]
struct DecoderLayer {
self_attn: Attention,
mlp: MLP,
post_attention_layernorm: RmsNorm,
post_feedforward_layernorm: RmsNorm,
}
impl DecoderLayer {
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
let post_feedforward_layernorm = rms_norm(
cfg.hidden_size,
cfg.rms_norm_eps,
vb.pp("post_feedforward_layernorm"),
)?;
let post_attention_layernorm = rms_norm(
cfg.hidden_size,
cfg.rms_norm_eps,
vb.pp("post_attention_layernorm"),
)?;
Ok(Self {
self_attn,
mlp,
post_attention_layernorm,
post_feedforward_layernorm,
})
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let residual = xs;
let xs = self.self_attn.forward(xs, attention_mask, seqlen_offset)?;
let xs = self.post_attention_layernorm.forward(&xs)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = self.mlp.forward(&xs)?;
let xs = self.post_feedforward_layernorm.forward(&xs)?;
residual + xs
}
fn clear_kv_cache(&mut self) {
self.self_attn.clear_kv_cache()
}
}
#[derive(Debug, Clone)]
pub struct Model {
embed_tokens: candle_nn::Embedding,
layers: Vec<DecoderLayer>,
norm: RmsNorm,
lm_head: Linear,
device: Device,
dtype: DType,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb_m = vb.pp("model");
let embed_tokens =
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb_l = vb_m.pp("layers");
for layer_idx in 0..cfg.num_hidden_layers {
let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
layers.push(layer)
}
let norm = rms_norm(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
let lm_head = if cfg.tie_word_embeddings {
Linear::new(embed_tokens.embeddings().clone(), None)
} else {
linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
};
Ok(Self {
embed_tokens,
layers,
norm,
lm_head,
device: vb.device().clone(),
dtype: vb.dtype(),
})
}
fn prepare_decoder_attention_mask(
&self,
b_size: usize,
tgt_len: usize,
seqlen_offset: usize,
) -> Result<Tensor> {
// Sliding window mask?
let mask: Vec<_> = (0..tgt_len)
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
.collect();
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
let mask = if seqlen_offset > 0 {
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), self.dtype, &self.device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
};
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
.to_dtype(self.dtype)
}
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let (b_size, seq_len) = input_ids.dims2()?;
let attention_mask = if seq_len <= 1 {
None
} else {
let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
Some(mask)
};
let mut xs = self.embed_tokens.forward(input_ids)?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
}
xs.narrow(1, seq_len - 1, 1)?
.apply(&self.norm)?
.apply(&self.lm_head)
}
pub fn clear_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {
layer.clear_kv_cache()
}
}
}
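
A minimal, hypothetical loading sketch for the olmo2 module above; the file handling and the prompt tokens are assumptions, not part of this change:

use candle::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::olmo2::{Config, Model};

fn run(weights: &[std::path::PathBuf], config_json: &[u8]) -> Result<()> {
    let device = Device::Cpu;
    let cfg: Config = serde_json::from_slice(config_json).expect("valid config.json");
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(weights, DType::F32, &device)? };
    let mut model = Model::new(&cfg, vb)?;
    // Prompt tokens would normally come from a tokenizer.
    let input = Tensor::new(&[[1u32, 2, 3]], &device)?;
    let logits = model.forward(&input, 0)?; // (1, 1, vocab_size), logits for the last position
    println!("{:?}", logits.dims());
    Ok(())
}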

View File

@ -20,10 +20,24 @@
// This implementation is based on:
// https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py
use crate::models::with_tracing::{linear_no_bias as linear, Linear, RmsNorm};
use candle::{DType, Device, Module, Result, Tensor, D};
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::VarBuilder;
use std::sync::Arc;
#[derive(Debug, Clone, serde::Deserialize)]
pub enum RopeScalingType {
#[serde(rename = "longrope")]
LongRope,
}
#[derive(Debug, Clone, serde::Deserialize)]
pub struct RopeScaling {
pub short_factor: Vec<f32>,
pub long_factor: Vec<f32>,
#[serde(rename = "type")]
pub type_: RopeScalingType,
}
// https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
@ -38,8 +52,12 @@ pub struct Config {
pub rope_theta: f64,
pub bos_token_id: Option<u32>,
pub eos_token_id: Option<u32>,
pub rope_scaling: Option<String>,
pub rope_scaling: Option<RopeScaling>,
pub max_position_embeddings: usize,
pub original_max_position_embeddings: Option<usize>,
pub partial_rotary_factor: Option<f64>,
#[serde(default)]
pub tie_word_embeddings: bool,
}
impl Config {
@ -50,30 +68,88 @@ impl Config {
#[derive(Debug, Clone)]
pub struct RotaryEmbedding {
partial_dim: Option<usize>,
sin: Tensor,
cos: Tensor,
}
impl RotaryEmbedding {
pub fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
let dim = cfg.head_dim();
let max_seq_len = cfg.max_position_embeddings;
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
let partial_dim = cfg
.partial_rotary_factor
.as_ref()
.map(|v| (v * cfg.head_dim() as f64) as usize);
let dim = partial_dim.unwrap_or(cfg.head_dim());
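        // e.g. a partial_rotary_factor of 0.75 with head_dim = 128 (illustrative values)
        // gives a rotary dim of 96; the remaining 32 features per head are left unrotated.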
let freqs = match cfg.rope_scaling.as_ref() {
None => {
let max_seq_len = cfg.max_position_embeddings;
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
.collect();
let inv_freq = Tensor::from_vec(inv_freq, (1, ()), dev)?.to_dtype(dtype)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
t.matmul(&inv_freq)?
}
Some(rope_scaling) => {
let inv_freq_s: Vec<_> = (0..dim)
.step_by(2)
.zip(rope_scaling.short_factor.iter())
.map(|(i, &f)| f / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
.collect();
let inv_freq_s = Tensor::from_vec(inv_freq_s, (1, ()), dev)?.to_dtype(dtype)?;
let max_seq_len = cfg.max_position_embeddings;
match cfg.original_max_position_embeddings {
None => {
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
t.matmul(&inv_freq_s)?
}
Some(original_max_seq_len) => {
let t_s = Tensor::arange(0u32, original_max_seq_len as u32, dev)?
.to_dtype(dtype)?
.reshape((original_max_seq_len, 1))?;
let freq_s = t_s.matmul(&inv_freq_s)?;
let inv_freq_l: Vec<_> = (0..dim)
.step_by(2)
.zip(rope_scaling.long_factor.iter())
.map(|(i, &f)| f / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
.collect();
let inv_freq_l =
Tensor::from_vec(inv_freq_l, (1, ()), dev)?.to_dtype(dtype)?;
let t_l =
Tensor::arange(original_max_seq_len as u32, max_seq_len as u32, dev)?
.to_dtype(dtype)?
.reshape(((), 1))?;
let freq_l = t_l.matmul(&inv_freq_l)?;
Tensor::cat(&[&freq_s, &freq_l], 0)?
}
}
}
};
Ok(Self {
partial_dim,
sin: freqs.sin()?,
cos: freqs.cos()?,
})
}
fn rope(&self, xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
let x = match self.partial_dim {
None => candle_nn::rotary_emb::rope(&xs.contiguous()?, cos, sin)?,
Some(dim) => {
let xs_rot = xs.i((.., .., .., ..dim))?.contiguous()?;
let xs_pass = xs.i((.., .., .., dim..))?;
let xs_rot = candle_nn::rotary_emb::rope(&xs_rot, cos, sin)?;
Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)?.contiguous()?
}
};
Ok(x)
}
pub fn apply_rotary_emb_qkv(
&self,
q: &Tensor,
@ -83,8 +159,8 @@ impl RotaryEmbedding {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
let q_embed = self.rope(&q.contiguous()?, &cos, &sin)?;
let k_embed = self.rope(&k.contiguous()?, &cos, &sin)?;
Ok((q_embed, k_embed))
}
}
@ -292,7 +368,11 @@ impl Model {
layers.push(layer)
}
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
let lm_head = if cfg.tie_word_embeddings {
Linear::from_weights(embed_tokens.embeddings().clone(), None)
} else {
linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
};
Ok(Self {
embed_tokens,
layers,

View File

@ -0,0 +1,466 @@
//! Gemma 3 model implementation with quantization support.
//!
//! Gemma 3 is a family of multimodal language models developed by Google.
//! This implementation provides quantization for reduced memory usage and faster inference.
//!
//! Key characteristics:
//! - Grouped-Query Attention (GQA) with specialized key-value heads
//! - RMSNorm for layer normalization
//! - Specialized attention patterns with separate normalization for Q/K/V
//! - Feed-forward network with SwiGLU activation
//! - Support for 2/3/4/8-bit quantization
//!
//! References:
//! - [Gemma 3 Models](https://blog.google/technology/developers/gemma-3/)
//!
use crate::quantized_nn::RmsNorm;
use candle::quantized::gguf_file;
use candle::quantized::QTensor;
use candle::D;
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::{Embedding, Module};
pub const MAX_SEQ_LEN: usize = 131072; // Gemma 3 supports 128K context window
pub const DEFAULT_SLIDING_WINDOW_TYPE: usize = 6;
pub const DEFAULT_ROPE_FREQUENCY: f32 = 1_000_000.;
pub const DEFAULT_ROPE_FREQUENCY_SLIDING: f32 = 10_000.;
pub const DEFAULT_ROPE_FREQUENCY_SCALE_FACTOR: f32 = 1.;
#[derive(Debug, Clone)]
struct QMatMul {
inner: candle::quantized::QMatMul,
span: tracing::Span,
}
impl QMatMul {
fn from_qtensor(qtensor: QTensor) -> Result<Self> {
let inner = candle::quantized::QMatMul::from_qtensor(qtensor)?;
let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
Ok(Self { inner, span })
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(xs)
}
}
#[derive(Debug, Clone)]
struct Mlp {
feed_forward_gate: QMatMul, // ffn_gate in GGUF
feed_forward_up: QMatMul, // ffn_up in GGUF
feed_forward_down: QMatMul, // ffn_down in GGUF
}
impl Module for Mlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let gate = self.feed_forward_gate.forward(xs)?;
let up = self.feed_forward_up.forward(xs)?;
let silu = candle_nn::ops::silu(&gate)?;
let gated = (silu * up)?;
self.feed_forward_down.forward(&gated)
}
}
#[derive(Debug, Clone)]
struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
}
impl RotaryEmbedding {
fn new(head_dim: usize, rope_frequency: f32, device: &Device) -> Result<Self> {
let theta: Vec<_> = (0..head_dim)
.step_by(2)
.map(|i| 1f32 / rope_frequency.powf(i as f32 / head_dim as f32))
.collect();
let theta = Tensor::new(theta.as_slice(), device)?;
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
.to_dtype(DType::F32)?
.reshape((MAX_SEQ_LEN, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
let cos = idx_theta.cos()?;
let sin = idx_theta.sin()?;
Ok(Self { sin, cos })
}
fn apply_rotary_emb_qkv(
&self,
q: &Tensor,
k: &Tensor,
index_pos: usize,
) -> Result<(Tensor, Tensor)> {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, index_pos, seq_len)?;
let sin = self.sin.narrow(0, index_pos, seq_len)?;
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
Ok((q_embed, k_embed))
}
}
#[derive(Debug, Clone)]
struct LayerWeights {
// Attention components
attention_wq: QMatMul,
attention_wk: QMatMul,
attention_wv: QMatMul,
attention_wo: QMatMul,
// Specialized normalization for Q and K
attention_q_norm: RmsNorm,
attention_k_norm: RmsNorm,
// Layer normalization
attention_norm: RmsNorm, // Applied before attention
post_attention_norm: RmsNorm, // Applied after attention
ffn_norm: RmsNorm, // Applied before feedforward
post_ffn_norm: RmsNorm, // Applied after feedforward
// Feed-forward network
mlp: Mlp,
// Attention parameters
n_head: usize, // Number of query heads
n_kv_head: usize, // Number of key-value heads
head_dim: usize, // Dimension of each head
q_dim: usize, // Total dimension for queries
sliding_window_size: Option<usize>,
rotary_embedding: RotaryEmbedding,
neg_inf: Tensor,
// Cache
kv_cache: Option<(Tensor, Tensor)>,
// Tracing
span_attn: tracing::Span,
span_mlp: tracing::Span,
}
impl LayerWeights {
fn mask(
&self,
b_sz: usize,
seq_len: usize,
index_pos: usize,
dtype: DType,
device: &Device,
) -> Result<Tensor> {
let mask: Vec<_> = if let Some(sliding_window_size) = self.sliding_window_size {
(0..seq_len)
.flat_map(|i| {
(0..seq_len).map(move |j| {
if i < j || j + sliding_window_size < i {
0u32
} else {
1u32
}
})
})
.collect()
} else {
(0..seq_len)
.flat_map(|i| (0..seq_len).map(move |j| if i < j { 0u32 } else { 1u32 }))
.collect()
};
let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
let mask = if index_pos > 0 {
            // Positions already in the KV cache stay attendable (1 = keep); the dtype has to
            // match the u32 mask above for the concatenation below to be valid.
            let mask0 = Tensor::ones((seq_len, index_pos), DType::U32, device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
};
mask.expand((b_sz, 1, seq_len, seq_len + index_pos))?
.to_dtype(dtype)
}
fn forward_attn(
&mut self,
x: &Tensor,
mask: Option<&Tensor>,
index_pos: usize,
) -> Result<Tensor> {
let _enter = self.span_attn.enter();
let (b_sz, seq_len, _) = x.dims3()?;
let q = self.attention_wq.forward(x)?;
let k = self.attention_wk.forward(x)?;
let v = self.attention_wv.forward(x)?;
let q = q
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
.transpose(1, 2)?;
let k = k
.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
.transpose(1, 2)?;
let v = v
.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
.transpose(1, 2)?;
let q = self.attention_q_norm.forward(&q.contiguous()?)?;
let k = self.attention_k_norm.forward(&k.contiguous()?)?;
let (q, k) = self
.rotary_embedding
.apply_rotary_emb_qkv(&q, &k, index_pos)?;
let (k, v) = match &self.kv_cache {
None => (k, v),
Some((k_cache, v_cache)) => {
if index_pos == 0 {
(k, v)
} else {
let k = Tensor::cat(&[k_cache, &k], 2)?; // concat on seq dim
let v = Tensor::cat(&[v_cache, &v], 2)?;
(k, v)
}
}
};
self.kv_cache = Some((k.clone(), v.clone())); // update cache
// Repeat KV for GQA
let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;
// Scaled Dot-Product Attention
let scale = 1.0 / (self.head_dim as f64).sqrt();
let mut attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
if let Some(mask) = mask {
let mask = mask.broadcast_as(attn_weights.shape())?;
let neg_inf = self.neg_inf.broadcast_as(attn_weights.dims())?;
attn_weights = mask.eq(0u32)?.where_cond(&neg_inf, &attn_weights)?;
}
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
let attn_output = attn_weights.matmul(&v)?;
let attn_output = attn_output
.transpose(1, 2)?
.reshape((b_sz, seq_len, self.q_dim))?;
self.attention_wo.forward(&attn_output)
}
}
#[derive(Debug, Clone)]
pub struct ModelWeights {
tok_embeddings: Embedding,
embedding_length: usize,
layers: Vec<LayerWeights>,
norm: RmsNorm,
output: QMatMul,
span: tracing::Span,
span_output: tracing::Span,
}
impl ModelWeights {
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
ct: gguf_file::Content,
reader: &mut R,
device: &Device,
) -> Result<Self> {
let md_get = |s: &str| match ct.metadata.get(s) {
None => candle::bail!("cannot find {s} in metadata"),
Some(v) => Ok(v),
};
let head_count = md_get("gemma3.attention.head_count")?.to_u32()? as usize;
let head_count_kv = md_get("gemma3.attention.head_count_kv")?.to_u32()? as usize;
let block_count = md_get("gemma3.block_count")?.to_u32()? as usize;
let embedding_length = md_get("gemma3.embedding_length")?.to_u32()? as usize;
let key_length = md_get("gemma3.attention.key_length")?.to_u32()? as usize;
let _value_length = md_get("gemma3.attention.value_length")?.to_u32()? as usize;
let rms_norm_eps = md_get("gemma3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
let sliding_window_size = md_get("gemma3.attention.sliding_window")?.to_u32()? as usize;
let sliding_window_type = md_get("gemma3.attention.sliding_window_type")
.and_then(|m| Ok(m.to_u32()? as usize))
.unwrap_or(DEFAULT_SLIDING_WINDOW_TYPE);
let rope_freq_base = md_get("gemma3.rope.freq_base")
.and_then(|m| m.to_f32())
.unwrap_or(DEFAULT_ROPE_FREQUENCY);
let rope_freq_base_sliding = md_get("gemma3.rope.local_freq_base")
.and_then(|m| m.to_f32())
.unwrap_or(DEFAULT_ROPE_FREQUENCY_SLIDING);
// Unused in Llama.cpp so we aren't using it here.
let _rope_freq_scaling_factor = md_get("gemma3.rope.scaling.factor")
.and_then(|m| m.to_f32())
.unwrap_or(DEFAULT_ROPE_FREQUENCY_SCALE_FACTOR);
// Compute the dimensions for queries, keys, and values
// These are the total dimensions when projected across all heads
let q_dim = head_count * key_length;
let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
// Load token embeddings and output projection
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
let tok_embeddings = tok_embeddings.dequantize(device)?;
let norm = RmsNorm::from_qtensor(
ct.tensor(reader, "output_norm.weight", device)?,
rms_norm_eps,
)?;
let output = match ct.tensor(reader, "output.weight", device) {
Ok(tensor) => tensor,
Err(_) => ct.tensor(reader, "token_embd.weight", device)?, // Use tied weights if output.weight doesn't exist
};
let mut layers = Vec::with_capacity(block_count);
for layer_idx in 0..block_count {
let prefix = format!("blk.{layer_idx}");
let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
let attention_wo =
ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
let attention_q_norm = RmsNorm::from_qtensor(
ct.tensor(reader, &format!("{prefix}.attn_q_norm.weight"), device)?,
rms_norm_eps,
)?;
let attention_k_norm = RmsNorm::from_qtensor(
ct.tensor(reader, &format!("{prefix}.attn_k_norm.weight"), device)?,
rms_norm_eps,
)?;
let attention_norm = RmsNorm::from_qtensor(
ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?,
rms_norm_eps,
)?;
let post_attention_norm = RmsNorm::from_qtensor(
ct.tensor(
reader,
&format!("{prefix}.post_attention_norm.weight"),
device,
)?,
rms_norm_eps,
)?;
let ffn_norm = RmsNorm::from_qtensor(
ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?,
rms_norm_eps,
)?;
let post_ffn_norm = RmsNorm::from_qtensor(
ct.tensor(reader, &format!("{prefix}.post_ffw_norm.weight"), device)?,
rms_norm_eps,
)?;
let feed_forward_gate =
ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
let feed_forward_up = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
let feed_forward_down =
ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
let mlp = Mlp {
feed_forward_gate: QMatMul::from_qtensor(feed_forward_gate)?,
feed_forward_up: QMatMul::from_qtensor(feed_forward_up)?,
feed_forward_down: QMatMul::from_qtensor(feed_forward_down)?,
};
            // Every `sliding_window_type`-th layer uses full attention; the others use
            // sliding-window attention. The pattern defaults to 6 when the metadata does
            // not define it explicitly.
            let is_sliding = (layer_idx + 1) % sliding_window_type > 0;
let sliding_window_size = is_sliding.then_some(sliding_window_size);
let layer_rope_frequency = if is_sliding {
rope_freq_base_sliding
} else {
rope_freq_base
};
let rotary_embedding = RotaryEmbedding::new(key_length, layer_rope_frequency, device)?;
// Tracing spans
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
layers.push(LayerWeights {
attention_wq: QMatMul::from_qtensor(attention_wq)?,
attention_wk: QMatMul::from_qtensor(attention_wk)?,
attention_wv: QMatMul::from_qtensor(attention_wv)?,
attention_wo: QMatMul::from_qtensor(attention_wo)?,
attention_q_norm,
attention_k_norm,
attention_norm,
post_attention_norm,
ffn_norm,
post_ffn_norm,
mlp,
n_head: head_count,
n_kv_head: head_count_kv,
head_dim: key_length,
q_dim,
sliding_window_size,
rotary_embedding,
neg_inf: neg_inf.clone(),
kv_cache: None,
span_attn,
span_mlp,
})
}
let span = tracing::span!(tracing::Level::TRACE, "model");
let span_output = tracing::span!(tracing::Level::TRACE, "output");
Ok(Self {
tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
embedding_length,
layers,
norm,
output: QMatMul::from_qtensor(output)?,
span,
span_output,
})
}
pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (b_sz, seq_len) = x.dims2()?;
let _enter = self.span.enter();
let mut layer_in = self.tok_embeddings.forward(x)?;
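        // Gemma scales the token embeddings by sqrt(hidden_size) before the first layer.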
layer_in = (layer_in * (self.embedding_length as f64).sqrt())?;
for layer in self.layers.iter_mut() {
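            // A mask is only needed for prompt processing; a single decoded token may
            // attend to the entire cache.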
let attention_mask = if seq_len == 1 {
None
} else {
Some(layer.mask(b_sz, seq_len, index_pos, x.dtype(), x.device())?)
};
// Attention block
let residual = &layer_in;
let x = layer.attention_norm.forward(&layer_in)?;
let x = layer.forward_attn(&x, attention_mask.as_ref(), index_pos)?;
let x = layer.post_attention_norm.forward(&x)?;
let x = (x + residual)?;
// Feed-forward block
let _enter = layer.span_mlp.enter();
let residual = &x;
let x = layer.ffn_norm.forward(&x)?;
let x = layer.mlp.forward(&x)?;
let x = layer.post_ffn_norm.forward(&x)?;
let x = (x + residual)?;
drop(_enter);
layer_in = x;
}
let _enter = self.span_output.enter();
let x = layer_in.i((.., seq_len - 1, ..))?;
let x = self.norm.forward(&x)?;
let output = self.output.forward(&x)?;
Ok(output)
}
}

View File

@ -0,0 +1,429 @@
//! Qwen3 implementation with quantization support.
//!
//! Based on the Qwen3 architecture and implemented with quantized weights
//! for reduced memory usage and faster inference on compatible hardware.
//!
//! References:
//! - [Qwen3 Models](https://huggingface.co/Qwen/Qwen3-0.6B) (architecture based on official implementations)
//!
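//! A minimal usage sketch (the file name and token id below are placeholders, and
//! tokenization/sampling are omitted):
//!
//! ```ignore
//! use candle::quantized::gguf_file;
//! use candle::{Device, Tensor};
//!
//! let device = Device::Cpu;
//! let mut file = std::fs::File::open("qwen3-0.6b-q4_0.gguf")?;
//! let content = gguf_file::Content::read(&mut file)?;
//! let mut model = ModelWeights::from_gguf(content, &mut file, &device)?;
//! // Prompt tokens go in at offset 0; later calls pass the running offset so the
//! // KV cache is reused for incremental decoding.
//! let tokens = Tensor::new(&[9707u32], &device)?.unsqueeze(0)?;
//! let logits = model.forward(&tokens, 0)?;
//! ```
//!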
use super::with_tracing::QMatMul;
use crate::{quantized_nn::RmsNorm, utils::repeat_kv};
use candle::quantized::{gguf_file, QTensor};
use candle::{DType, Device, Result, Tensor};
use candle_nn::{kv_cache::KvCache, Activation, Embedding, Module};
use std::io::{Read, Seek};
use std::sync::Arc;
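/// Small loader helper bundling the GGUF content, its reader, and the target device so
/// quantized tensors and norms can be fetched by name.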
struct Gguf<R: Read + Seek> {
ct: gguf_file::Content,
reader: R,
device: Device,
}
impl<R: Read + Seek> Gguf<R> {
fn new(ct: gguf_file::Content, reader: R, device: Device) -> Self {
Self { ct, reader, device }
}
fn qmatmul(&mut self, name: &str) -> Result<QMatMul> {
let ws = self.ct.tensor(&mut self.reader, name, &self.device)?;
QMatMul::from_weights(ws.into())
}
fn rms_norm(&mut self, name: &str, eps: f64) -> Result<RmsNorm> {
let ws = self.ct.tensor(&mut self.reader, name, &self.device)?;
RmsNorm::from_qtensor(ws, eps)
}
fn metadata(&self) -> &std::collections::HashMap<String, gguf_file::Value> {
&self.ct.metadata
}
fn tensor(&mut self, name: &str) -> Result<QTensor> {
self.ct.tensor(&mut self.reader, name, &self.device)
}
}
#[derive(Debug, Clone)]
struct MlpWeights {
gate_proj: QMatMul,
up_proj: QMatMul,
down_proj: QMatMul,
act_fn: Activation,
span: tracing::Span,
}
impl MlpWeights {
fn new<R: Read + Seek>(gg: &mut Gguf<R>, prefix: &str) -> Result<Self> {
let gate_proj = gg.qmatmul(&format!("{prefix}.ffn_gate.weight"))?;
let up_proj = gg.qmatmul(&format!("{prefix}.ffn_up.weight"))?;
let down_proj = gg.qmatmul(&format!("{prefix}.ffn_down.weight"))?;
let act_fn = Activation::Silu;
let span = tracing::span!(tracing::Level::TRACE, "mlp");
Ok(Self {
gate_proj,
up_proj,
down_proj,
act_fn,
span,
})
}
}
impl Module for MlpWeights {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let gate = self.gate_proj.forward(x)?.apply(&self.act_fn)?;
let up = self.up_proj.forward(x)?;
let gated = (gate * up)?;
self.down_proj.forward(&gated)
}
}
#[derive(Debug, Clone)]
struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
}
impl RotaryEmbedding {
fn new(
dtype: DType,
head_dim: usize,
max_position_embeddings: usize,
rope_theta: f64,
dev: &Device,
) -> Result<Self> {
let dim = head_dim;
let max_seq_len = max_position_embeddings;
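        // Standard RoPE tables: inv_freq[k] = 1 / rope_theta^(2k / dim), turned into
        // per-position sin/cos matrices of shape (max_seq_len, dim / 2).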
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / rope_theta.powf(i as f64 / dim as f64) as f32)
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
})
}
/// Apply RoPE (q, k shape: B x H x L x D)
fn apply(&self, q: &Tensor, k: &Tensor, offset: usize) -> Result<(Tensor, Tensor)> {
let (_, _, seq_len, _) = q.dims4()?;
let cos = self.cos.narrow(0, offset, seq_len)?.to_dtype(q.dtype())?;
let sin = self.sin.narrow(0, offset, seq_len)?.to_dtype(q.dtype())?;
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
Ok((q_embed, k_embed))
}
}
#[derive(Debug, Clone)]
struct AttentionWeights {
q_proj: QMatMul,
k_proj: QMatMul,
v_proj: QMatMul,
o_proj: QMatMul,
q_norm: RmsNorm,
k_norm: RmsNorm,
num_heads: usize,
num_kv_heads: usize,
num_kv_groups: usize,
head_dim: usize,
rotary_emb: Arc<RotaryEmbedding>,
kv_cache: KvCache,
span_attn: tracing::Span,
}
impl AttentionWeights {
fn new<R: Read + Seek>(
gg: &mut Gguf<R>,
num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
rms_norm_eps: f64,
rotary_emb: Arc<RotaryEmbedding>,
prefix: &str,
) -> Result<Self> {
let num_kv_groups = num_heads / num_kv_heads;
let q_proj = gg.qmatmul(&format!("{prefix}.attn_q.weight"))?;
let k_proj = gg.qmatmul(&format!("{prefix}.attn_k.weight"))?;
let v_proj = gg.qmatmul(&format!("{prefix}.attn_v.weight"))?;
let o_proj = gg.qmatmul(&format!("{prefix}.attn_output.weight"))?;
let q_norm = gg.rms_norm(&format!("{prefix}.attn_q_norm.weight"), rms_norm_eps)?;
let k_norm = gg.rms_norm(&format!("{prefix}.attn_k_norm.weight"), rms_norm_eps)?;
// Initialize KV cache with 512 tokens capacity to reduce initial memory allocation.
// The cache will grow in chunks of 512 tokens when needed.
let kv_cache = KvCache::new(2, 512);
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
q_norm,
k_norm,
num_heads,
num_kv_heads,
num_kv_groups,
head_dim,
rotary_emb,
kv_cache,
span_attn,
})
}
fn forward(&mut self, x: &Tensor, attn_mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
let _enter = self.span_attn.enter();
let (b, l, _) = x.dims3()?;
let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?;
let v = self.v_proj.forward(x)?;
let q = q
.reshape((b, l, self.num_heads, self.head_dim))?
.transpose(1, 2)?;
let k = k
.reshape((b, l, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let v = v
.reshape((b, l, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
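        // Per-head RMSNorm: flatten q/k to (B*H*L, D), normalize over the head dim,
        // then reshape back to (B, H, L, D).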
let q_flat = q.flatten(0, 2)?;
let k_flat = k.flatten(0, 2)?;
let q_flat = self.q_norm.forward(&q_flat)?;
let k_flat = self.k_norm.forward(&k_flat)?;
let q = q_flat.reshape((b, self.num_heads, l, self.head_dim))?;
let k = k_flat.reshape((b, self.num_kv_heads, l, self.head_dim))?;
let (q, k) = self.rotary_emb.apply(&q, &k, offset)?;
// Reset KV cache if we're at the first position
if offset == 0 {
self.kv_cache.reset();
}
let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;
// Make tensor contiguous to avoid some strided copies
let k = k.contiguous()?;
let v = v.contiguous()?;
let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;
let scale = 1.0 / (self.head_dim as f64).sqrt();
let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
if let Some(m) = attn_mask {
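            // The mask is additive (0.0 for visible positions, -inf for blocked ones);
            // cast it to the score dtype if needed before adding.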
let m_dtype = m.dtype();
let scores_dtype = scores.dtype();
let mask = if m_dtype != scores_dtype {
m.to_dtype(scores_dtype)?
} else {
m.clone()
};
scores = scores.broadcast_add(&mask)?;
}
let probs = candle_nn::ops::softmax_last_dim(&scores)?;
let ctx = probs.matmul(&v)?; // (B, H, L, D)
let reshaped_ctx = ctx
.transpose(1, 2)?
.reshape((b, l, self.num_heads * self.head_dim))?;
self.o_proj.forward(&reshaped_ctx)
}
}
#[derive(Debug, Clone)]
struct LayerWeights {
self_attn: AttentionWeights,
mlp: MlpWeights,
ln1: RmsNorm,
ln2: RmsNorm,
}
impl LayerWeights {
fn new<R: Read + Seek>(
gg: &mut Gguf<R>,
num_attention_heads: usize,
num_key_value_heads: usize,
head_dim: usize,
rms_norm_eps: f64,
rotary: Arc<RotaryEmbedding>,
layer_idx: usize,
) -> Result<Self> {
let prefix = format!("blk.{layer_idx}");
let ln1 = gg.rms_norm(&format!("{prefix}.attn_norm.weight"), rms_norm_eps)?;
let ln2 = gg.rms_norm(&format!("{prefix}.ffn_norm.weight"), rms_norm_eps)?;
let self_attn = AttentionWeights::new(
gg,
num_attention_heads,
num_key_value_heads,
head_dim,
rms_norm_eps,
rotary,
&prefix,
)?;
let mlp = MlpWeights::new(gg, &prefix)?;
Ok(Self {
self_attn,
mlp,
ln1,
ln2,
})
}
fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
let h = self.ln1.forward(x)?;
let h = self.self_attn.forward(&h, mask, offset)?;
let x = (x + h)?;
let h2 = self.ln2.forward(&x)?;
let h2 = h2.apply(&self.mlp)?;
x + h2
}
}
#[derive(Debug, Clone)]
pub struct ModelWeights {
embed_tokens: Embedding,
layers: Vec<LayerWeights>,
norm: RmsNorm,
lm_head: QMatMul,
device: Device,
dtype: DType,
span: tracing::Span,
span_output: tracing::Span,
}
impl ModelWeights {
pub fn from_gguf<R: Read + Seek>(
ct: gguf_file::Content,
reader: &mut R,
device: &Device,
) -> Result<Self> {
let mut gg = Gguf::new(ct, reader, device.clone());
let md_get = |s: &str| match gg.metadata().get(s) {
None => candle::bail!("cannot find {s} in metadata"),
Some(v) => Ok(v),
};
let num_attention_heads = md_get("qwen3.attention.head_count")?.to_u32()? as usize;
let num_kv_heads = md_get("qwen3.attention.head_count_kv")?.to_u32()? as usize;
let head_dim = md_get("qwen3.attention.key_length")?.to_u32()? as usize;
let num_layers = md_get("qwen3.block_count")?.to_u32()? as usize;
let hidden_size = md_get("qwen3.embedding_length")?.to_u32()? as usize;
let max_position_embeddings = md_get("qwen3.context_length")?.to_u32()? as usize;
let rms_norm_eps = md_get("qwen3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
let rope_freq_base = md_get("qwen3.rope.freq_base")?.to_f32()? as f64;
let dtype = match gg.metadata().get("general.dtype") {
Some(v) => match v.to_u32() {
Ok(0) => DType::F32,
Ok(1) => DType::F16,
_ => DType::F16,
},
None => DType::F16,
};
let embed_tensor = gg.tensor("token_embd.weight")?;
let embed_tokens = Embedding::new(embed_tensor.dequantize(device)?, hidden_size);
let rotary = Arc::new(RotaryEmbedding::new(
dtype,
head_dim,
max_position_embeddings,
rope_freq_base,
device,
)?);
let mut layers = Vec::with_capacity(num_layers);
for i in 0..num_layers {
layers.push(LayerWeights::new(
&mut gg,
num_attention_heads,
num_kv_heads,
head_dim,
rms_norm_eps,
rotary.clone(),
i,
)?);
}
let norm = gg.rms_norm("output_norm.weight", rms_norm_eps)?;
// Load output projection tensor, falling back to tied embeddings like gemma3
let lm_head_tensor = match gg.tensor("output.weight") {
Ok(tensor) => tensor,
Err(_) => gg.tensor("token_embd.weight")?,
};
let lm_head = QMatMul::from_weights(lm_head_tensor.into())?;
let span = tracing::span!(tracing::Level::TRACE, "model");
let span_output = tracing::span!(tracing::Level::TRACE, "output");
Ok(Self {
embed_tokens,
layers,
norm,
lm_head,
device: device.clone(),
dtype,
span,
span_output,
})
}
fn causal_mask(
&self,
b: usize,
tgt: usize,
offset: usize,
sw: Option<usize>,
) -> Result<Tensor> {
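        // Build a (b, 1, tgt, tgt + offset) additive mask: 0.0 where key position j is
        // visible from query position i (j <= i + offset, within the optional sliding
        // window `sw`), f32::NEG_INFINITY elsewhere.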
let minf = f32::NEG_INFINITY;
let mask: Vec<_> = (0..tgt)
.flat_map(|i| {
(0..(tgt + offset)).map(move |j| {
let past_ok = j <= i + offset;
let sw_ok = match sw {
Some(w) => (i + offset) as i64 - j as i64 <= w as i64,
None => true,
};
if past_ok && sw_ok {
0.
} else {
minf
}
})
})
.collect();
Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype)
}
pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
let _enter = self.span.enter();
let (b, l) = input.dims2()?;
let mut h = self.embed_tokens.forward(input)?;
let causal_mask = if l == 1 {
None
} else {
Some(self.causal_mask(b, l, offset, None)?)
};
for layer in &mut self.layers {
h = layer.forward(&h, causal_mask.as_ref(), offset)?;
}
let h = self.norm.forward(&h)?;
let _enter = self.span_output.enter();
let last_hidden = h.narrow(1, l - 1, 1)?;
self.lm_head.forward(&last_hidden)?.squeeze(1)
}
}

View File

@ -0,0 +1,389 @@
use crate::{
models::with_tracing::{linear_b, linear_no_bias, Linear, RmsNorm},
utils::repeat_kv,
};
use candle::{DType, Device, Module, Result, Tensor};
use candle_nn::{kv_cache::KvCache, Activation, VarBuilder};
use std::sync::Arc;
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct Config {
pub vocab_size: usize,
pub hidden_size: usize,
pub intermediate_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub head_dim: usize,
pub attention_bias: bool,
pub num_key_value_heads: usize,
pub max_position_embeddings: usize,
pub sliding_window: Option<usize>,
pub max_window_layers: usize,
pub tie_word_embeddings: bool,
pub rope_theta: f64,
pub rms_norm_eps: f64,
pub use_sliding_window: bool,
pub hidden_act: Activation,
}
#[derive(Debug, Clone)]
pub(crate) struct Qwen3RotaryEmbedding {
sin: Tensor,
cos: Tensor,
}
impl Qwen3RotaryEmbedding {
pub(crate) fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
let dim = cfg.head_dim;
let max_seq_len = cfg.max_position_embeddings;
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
})
}
/// Apply RoPE (q, k shape: B x H x L x D)
pub(crate) fn apply(&self, q: &Tensor, k: &Tensor, offset: usize) -> Result<(Tensor, Tensor)> {
let (_, _, seq_len, _) = q.dims4()?;
let cos = self.cos.narrow(0, offset, seq_len)?;
let sin = self.sin.narrow(0, offset, seq_len)?;
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
Ok((q_embed, k_embed))
}
}
#[derive(Debug, Clone)]
pub(crate) struct Qwen3MLP {
gate_proj: Linear,
up_proj: Linear,
down_proj: Linear,
act_fn: Activation,
}
impl Qwen3MLP {
pub(crate) fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
Ok(Self {
gate_proj: linear_no_bias(cfg.hidden_size, cfg.intermediate_size, vb.pp("gate_proj"))?,
up_proj: linear_no_bias(cfg.hidden_size, cfg.intermediate_size, vb.pp("up_proj"))?,
down_proj: linear_no_bias(cfg.intermediate_size, cfg.hidden_size, vb.pp("down_proj"))?,
act_fn: cfg.hidden_act,
})
}
}
impl Module for Qwen3MLP {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let lhs = x.apply(&self.gate_proj)?.apply(&self.act_fn)?;
let rhs = x.apply(&self.up_proj)?;
(lhs * rhs)?.apply(&self.down_proj)
}
}
#[derive(Debug, Clone)]
pub(crate) struct Qwen3Attention {
// projections
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
// norms
q_norm: RmsNorm,
k_norm: RmsNorm,
// hyper params
num_heads: usize,
num_kv_heads: usize,
num_kv_groups: usize,
head_dim: usize,
hidden_size: usize,
// utils
rotary_emb: Arc<Qwen3RotaryEmbedding>,
kv_cache: KvCache,
}
impl Qwen3Attention {
pub(crate) fn new(
cfg: &Config,
rotary_emb: Arc<Qwen3RotaryEmbedding>,
vb: VarBuilder,
) -> Result<Self> {
if cfg.use_sliding_window {
candle::bail!("sliding window is not suppored")
}
let head_dim = cfg.head_dim;
let num_heads = cfg.num_attention_heads;
let num_kv_heads = cfg.num_key_value_heads;
let num_kv_groups = num_heads / num_kv_heads;
let q_proj = linear_b(
cfg.hidden_size,
num_heads * head_dim,
cfg.attention_bias,
vb.pp("q_proj"),
)?;
let k_proj = linear_b(
cfg.hidden_size,
num_kv_heads * head_dim,
cfg.attention_bias,
vb.pp("k_proj"),
)?;
let v_proj = linear_b(
cfg.hidden_size,
num_kv_heads * head_dim,
cfg.attention_bias,
vb.pp("v_proj"),
)?;
let o_proj = linear_b(
num_heads * head_dim,
cfg.hidden_size,
cfg.attention_bias,
vb.pp("o_proj"),
)?;
let q_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("q_norm"))?;
let k_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?;
// Necessary because the hidden_size in the config isn't always accurate
let hidden_size = head_dim * cfg.num_attention_heads;
// Initialize KV cache with 512 tokens capacity to reduce initial memory allocation.
// The cache will grow in chunks of 512 tokens when needed.
let kv_cache = KvCache::new(2, 512);
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
q_norm,
k_norm,
num_heads,
num_kv_heads,
num_kv_groups,
head_dim,
hidden_size,
rotary_emb,
kv_cache,
})
}
pub(crate) fn forward(
&mut self,
x: &Tensor,
attn_mask: Option<&Tensor>,
offset: usize,
) -> Result<Tensor> {
let (b, l, _) = x.dims3()?;
// 1. Proj
let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?;
let v = self.v_proj.forward(x)?;
// 2. Reshape: (B, L, H, D) -> (B, H, L, D)
let q = q
.reshape((b, l, self.num_heads, self.head_dim))?
.transpose(1, 2)?;
let k = k
.reshape((b, l, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let v = v
.reshape((b, l, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
        // 3. Per-head RMSNorm
        let q_flat = q.flatten(0, 2)?; // (B, H, L, D) -> (B*H*L, D)
let k_flat = k.flatten(0, 2)?;
let q_flat = self.q_norm.forward(&q_flat)?;
let k_flat = self.k_norm.forward(&k_flat)?;
let q = q_flat.reshape((b, self.num_heads, l, self.head_dim))?;
let k = k_flat.reshape((b, self.num_kv_heads, l, self.head_dim))?;
// 4. RoPE
let (q, k) = self.rotary_emb.apply(&q, &k, offset)?;
// 5. Accumulate KV cache
let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;
// 6. GQA repeat_kv
let k = repeat_kv(k, self.num_kv_groups)?;
let v = repeat_kv(v, self.num_kv_groups)?;
// 7. Attention score
let scale = 1.0 / (self.head_dim as f64).sqrt();
let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
if let Some(m) = attn_mask {
scores = scores.broadcast_add(m)?;
}
let probs = candle_nn::ops::softmax_last_dim(&scores)?;
let ctx = probs.matmul(&v)?; // (B, H, L, D)
// 8. Output proj
ctx.transpose(1, 2)?
.reshape((b, l, self.hidden_size))?
.apply(&self.o_proj)
}
pub(crate) fn clear_kv_cache(&mut self) {
self.kv_cache.reset();
}
}
#[derive(Debug, Clone)]
struct DecoderLayer {
self_attn: Qwen3Attention,
mlp: Qwen3MLP,
ln1: RmsNorm,
ln2: RmsNorm,
}
impl DecoderLayer {
fn new(cfg: &Config, rotary: Arc<Qwen3RotaryEmbedding>, vb: VarBuilder) -> Result<Self> {
let self_attn = Qwen3Attention::new(cfg, rotary, vb.pp("self_attn"))?;
let mlp = Qwen3MLP::new(cfg, vb.pp("mlp"))?;
let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
let ln2 = RmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
vb.pp("post_attention_layernorm"),
)?;
Ok(Self {
self_attn,
mlp,
ln1,
ln2,
})
}
fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
let h = self.ln1.forward(x)?;
let h = self.self_attn.forward(&h, mask, offset)?;
let x = (x + h)?;
let h2 = self.ln2.forward(&x)?;
let h2 = h2.apply(&self.mlp)?;
x + h2
}
fn clear_kv_cache(&mut self) {
self.self_attn.clear_kv_cache();
}
}
#[derive(Debug, Clone)]
pub struct Model {
embed_tokens: candle_nn::Embedding,
layers: Vec<DecoderLayer>,
norm: RmsNorm,
device: Device,
dtype: DType,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let embed_tokens =
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
let rotary = Arc::new(Qwen3RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?);
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb_l = vb.pp("model.layers");
for i in 0..cfg.num_hidden_layers {
layers.push(DecoderLayer::new(cfg, rotary.clone(), vb_l.pp(i))?);
}
Ok(Self {
embed_tokens,
layers,
norm: RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?,
device: vb.device().clone(),
dtype: vb.dtype(),
})
}
fn clear_kv_cache(&mut self) {
for l in &mut self.layers {
l.clear_kv_cache();
}
}
fn causal_mask(
&self,
b: usize,
tgt: usize,
offset: usize,
sw: Option<usize>,
) -> Result<Tensor> {
let minf = f32::NEG_INFINITY;
let mask: Vec<_> = (0..tgt)
.flat_map(|i| {
(0..(tgt + offset)).map(move |j| {
let past_ok = j <= i + offset;
let sw_ok = match sw {
Some(w) => (i + offset) as i64 - j as i64 <= w as i64,
None => true,
};
if past_ok && sw_ok {
0.
} else {
minf
}
})
})
.collect();
Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype)
}
pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
let (b, l) = input.dims2()?;
let mut h = self.embed_tokens.forward(input)?;
let causal = if l == 1 {
None
} else {
Some(self.causal_mask(b, l, offset, None)?)
};
for layer in &mut self.layers {
h = layer.forward(&h, causal.as_ref(), offset)?;
}
self.norm.forward(&h)
}
}
#[derive(Debug, Clone)]
pub struct ModelForCausalLM {
base: Model,
lm_head: Linear,
}
impl ModelForCausalLM {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let base = Model::new(cfg, vb.clone())?;
let lm_head = if cfg.tie_word_embeddings {
Linear::from_weights(base.embed_tokens.embeddings().clone(), None)
} else {
linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
};
Ok(Self { base, lm_head })
}
pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
let (_, l) = input.dims2()?;
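        // Only the last position's hidden state is projected to vocabulary logits.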
self.base
.forward(input, offset)?
.narrow(1, l - 1, 1)?
.apply(&self.lm_head)
}
pub fn clear_kv_cache(&mut self) {
self.base.clear_kv_cache();
}
}

View File

@ -0,0 +1,355 @@
use crate::models::{
qwen3::{Config as Qwen3Config, Qwen3Attention, Qwen3MLP, Qwen3RotaryEmbedding},
with_tracing::{linear_no_bias, Linear, RmsNorm},
};
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use std::sync::Arc;
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct Config {
pub vocab_size: usize,
pub hidden_size: usize,
pub intermediate_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub head_dim: usize,
pub attention_bias: bool,
pub num_key_value_heads: usize,
pub max_position_embeddings: usize,
pub sliding_window: Option<usize>,
pub max_window_layers: usize,
pub tie_word_embeddings: bool,
pub rope_theta: f64,
pub rms_norm_eps: f64,
pub use_sliding_window: bool,
pub hidden_act: Activation,
// MoE specific configuration
pub decoder_sparse_step: usize,
pub moe_intermediate_size: usize,
pub num_experts_per_tok: usize,
pub num_experts: usize,
pub norm_topk_prob: bool,
}
impl From<&Config> for Qwen3Config {
fn from(val: &Config) -> Self {
Qwen3Config {
vocab_size: val.vocab_size,
hidden_size: val.hidden_size,
intermediate_size: val.intermediate_size,
num_hidden_layers: val.num_hidden_layers,
num_attention_heads: val.num_attention_heads,
head_dim: val.head_dim,
attention_bias: val.attention_bias,
num_key_value_heads: val.num_key_value_heads,
max_position_embeddings: val.max_position_embeddings,
sliding_window: val.sliding_window,
max_window_layers: val.max_window_layers,
tie_word_embeddings: val.tie_word_embeddings,
rope_theta: val.rope_theta,
rms_norm_eps: val.rms_norm_eps,
use_sliding_window: val.use_sliding_window,
hidden_act: val.hidden_act,
}
}
}
#[derive(Debug, Clone)]
struct Qwen3MLPExpert {
gate_proj: Linear,
up_proj: Linear,
down_proj: Linear,
act_fn: Activation,
}
impl Qwen3MLPExpert {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
Ok(Self {
gate_proj: linear_no_bias(
cfg.hidden_size,
cfg.moe_intermediate_size,
vb.pp("gate_proj"),
)?,
up_proj: linear_no_bias(cfg.hidden_size, cfg.moe_intermediate_size, vb.pp("up_proj"))?,
down_proj: linear_no_bias(
cfg.moe_intermediate_size,
cfg.hidden_size,
vb.pp("down_proj"),
)?,
act_fn: cfg.hidden_act,
})
}
}
impl Module for Qwen3MLPExpert {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let lhs = x.apply(&self.gate_proj)?.apply(&self.act_fn)?;
let rhs = x.apply(&self.up_proj)?;
(lhs * rhs)?.apply(&self.down_proj)
}
}
// Qwen3 Sparse MoE Block implementation
#[derive(Debug, Clone)]
struct Qwen3SparseMoeBlock {
gate: Linear,
experts: Vec<Qwen3MLPExpert>,
norm_topk_prob: bool,
num_experts_per_tok: usize,
}
impl Qwen3SparseMoeBlock {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let gate = linear_no_bias(cfg.hidden_size, cfg.num_experts, vb.pp("gate"))?;
let mut experts = Vec::with_capacity(cfg.num_experts);
let vb_e = vb.pp("experts");
for idx in 0..cfg.num_experts {
let expert = Qwen3MLPExpert::new(cfg, vb_e.pp(idx))?;
experts.push(expert)
}
Ok(Self {
gate,
experts,
norm_topk_prob: cfg.norm_topk_prob,
num_experts_per_tok: cfg.num_experts_per_tok,
})
}
}
impl Module for Qwen3SparseMoeBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (b_size, seq_len, hidden_dim) = xs.dims3()?;
let xs = xs.reshape(((), hidden_dim))?;
let router_logits = xs.apply(&self.gate)?;
let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;
// Extract topk experts per token
let experts_per_tok = routing_weights
.arg_sort_last_dim(false)?
.narrow(D::Minus1, 0, self.num_experts_per_tok)?
.contiguous()?;
let routing_weights = routing_weights.gather(&experts_per_tok, D::Minus1)?;
// Extract needed data
let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
let experts_per_tok = experts_per_tok.to_vec2::<u32>()?;
let mut top_x = vec![vec![]; self.experts.len()];
let mut selected_experts = vec![vec![]; self.experts.len()];
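        // Bucket tokens per expert: `top_x[e]` holds the row indices routed to expert e
        // and `selected_experts[e]` their (optionally renormalized) routing weights.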
for (row_idx, (rw, expert_idxs)) in routing_weights
.iter()
.zip(experts_per_tok.iter())
.enumerate()
{
let sum_rw = rw.iter().sum::<f32>();
for (&rw, &expert_idx) in rw.iter().zip(expert_idxs.iter()) {
top_x[expert_idx as usize].push(row_idx as u32);
let rw = if self.norm_topk_prob { rw / sum_rw } else { rw };
selected_experts[expert_idx as usize].push(rw)
}
}
// Process through experts
let mut ys = xs.zeros_like()?;
for (expert_idx, expert_layer) in self.experts.iter().enumerate() {
let top_x = &top_x[expert_idx];
if top_x.is_empty() {
continue;
}
let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
let selected_experts =
Tensor::new(selected_experts[expert_idx].as_slice(), xs.device())?
.reshape(((), 1))?
.to_dtype(xs.dtype())?;
let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
let current_hidden_states = expert_layer.forward(&current_state)?;
let current_hidden_states = current_hidden_states.broadcast_mul(&selected_experts)?;
ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
}
ys.reshape((b_size, seq_len, hidden_dim))
}
}
// MLP or MoE decision enum
#[derive(Debug, Clone)]
enum Qwen3FeedForward {
Mlp(Qwen3MLP),
MoE(Qwen3SparseMoeBlock),
}
impl Module for Qwen3FeedForward {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
Self::Mlp(m) => m.forward(xs),
Self::MoE(m) => m.forward(xs),
}
}
}
#[derive(Debug, Clone)]
struct DecoderLayer {
self_attn: Qwen3Attention,
feed_forward: Qwen3FeedForward,
ln1: RmsNorm,
ln2: RmsNorm,
}
impl DecoderLayer {
fn new(
layer_idx: usize,
cfg: &Config,
rotary: Arc<Qwen3RotaryEmbedding>,
vb: VarBuilder,
) -> Result<Self> {
let self_attn = Qwen3Attention::new(&cfg.into(), rotary, vb.pp("self_attn"))?;
// Decide whether to use MoE or regular MLP based on layer_idx and decoder_sparse_step
let feed_forward = if cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0
{
Qwen3FeedForward::MoE(Qwen3SparseMoeBlock::new(cfg, vb.pp("mlp"))?)
} else {
Qwen3FeedForward::Mlp(Qwen3MLP::new(&cfg.into(), vb.pp("mlp"))?)
};
let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
let ln2 = RmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
vb.pp("post_attention_layernorm"),
)?;
Ok(Self {
self_attn,
feed_forward,
ln1,
ln2,
})
}
fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
let h = self.ln1.forward(x)?;
let h = self.self_attn.forward(&h, mask, offset)?;
let x = (x + h)?;
let h2 = self.ln2.forward(&x)?;
let h2 = h2.apply(&self.feed_forward)?;
x + h2
}
fn clear_kv_cache(&mut self) {
self.self_attn.clear_kv_cache();
}
}
#[derive(Debug, Clone)]
pub struct Model {
embed_tokens: candle_nn::Embedding,
layers: Vec<DecoderLayer>,
norm: RmsNorm,
device: Device,
dtype: DType,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let embed_tokens =
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
let rotary = Arc::new(Qwen3RotaryEmbedding::new(
vb.dtype(),
&cfg.into(),
vb.device(),
)?);
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb_l = vb.pp("model.layers");
for i in 0..cfg.num_hidden_layers {
layers.push(DecoderLayer::new(i, cfg, rotary.clone(), vb_l.pp(i))?);
}
Ok(Self {
embed_tokens,
layers,
norm: RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?,
device: vb.device().clone(),
dtype: vb.dtype(),
})
}
fn clear_kv_cache(&mut self) {
for l in &mut self.layers {
l.clear_kv_cache();
}
}
fn causal_mask(
&self,
b: usize,
tgt: usize,
offset: usize,
sw: Option<usize>,
) -> Result<Tensor> {
let minf = f32::NEG_INFINITY;
let mask: Vec<_> = (0..tgt)
.flat_map(|i| {
(0..(tgt + offset)).map(move |j| {
let past_ok = j <= i + offset;
let sw_ok = match sw {
Some(w) => (i + offset) as i64 - j as i64 <= w as i64,
None => true,
};
if past_ok && sw_ok {
0.
} else {
minf
}
})
})
.collect();
Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype)
}
pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
let (b, l) = input.dims2()?;
let mut h = self.embed_tokens.forward(input)?;
let causal = if l == 1 {
None
} else {
Some(self.causal_mask(b, l, offset, None)?)
};
for layer in &mut self.layers {
h = layer.forward(&h, causal.as_ref(), offset)?;
}
self.norm.forward(&h)
}
}
#[derive(Debug, Clone)]
pub struct ModelForCausalLM {
base: Model,
lm_head: Linear,
}
impl ModelForCausalLM {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let base = Model::new(cfg, vb.clone())?;
let lm_head = if cfg.tie_word_embeddings {
Linear::from_weights(base.embed_tokens.embeddings().clone(), None)
} else {
linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
};
Ok(Self { base, lm_head })
}
pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
let (_, l) = input.dims2()?;
self.base
.forward(input, offset)?
.narrow(1, l - 1, 1)?
.apply(&self.lm_head)
}
pub fn clear_kv_cache(&mut self) {
self.base.clear_kv_cache();
}
}

View File

@ -17,8 +17,8 @@ const CROP_NMS_THRESH: f32 = 0.7;
#[derive(Debug)]
enum ImageEncoder {
Original(ImageEncoderViT),
TinyViT(TinyViT),
Original(Box<ImageEncoderViT>),
TinyViT(Box<TinyViT>),
}
impl Module for ImageEncoder {
@ -83,7 +83,7 @@ impl Sam {
let pixel_std =
Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
Ok(Self {
image_encoder: ImageEncoder::Original(image_encoder),
image_encoder: ImageEncoder::Original(image_encoder.into()),
prompt_encoder,
mask_decoder,
pixel_std,
@ -114,7 +114,7 @@ impl Sam {
let pixel_std =
Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
Ok(Self {
image_encoder: ImageEncoder::TinyViT(image_encoder),
image_encoder: ImageEncoder::TinyViT(image_encoder.into()),
prompt_encoder,
mask_decoder,
pixel_std,

View File

@ -134,12 +134,7 @@ impl Scheduler for DDIMScheduler {
timestep
};
// https://github.com/huggingface/diffusers/blob/6e099e2c8ce4c4f5c7318e970a8c093dc5c7046e/src/diffusers/schedulers/scheduling_ddim.py#L195
let prev_timestep = if timestep > self.step_ratio {
timestep - self.step_ratio
} else {
0
};
let prev_timestep = timestep.saturating_sub(self.step_ratio);
let alpha_prod_t = self.alphas_cumprod[timestep];
let alpha_prod_t_prev = self.alphas_cumprod[prev_timestep];
let beta_prod_t = 1. - alpha_prod_t;

View File

@ -482,8 +482,10 @@ impl XLMRobertaClassificationHead {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let cls_states = hidden_states.get_on_dim(1, 0)?.contiguous()?;
let hidden_states = self.dense.forward(&cls_states)?;
let hidden_states = candle_nn::Activation::GeluPytorchTanh.forward(&hidden_states)?;
let hidden_states = self.out_proj.forward(&hidden_states)?;
// The activation used in the classification head is tanh, as per the original
// implementation.
// https://github.com/huggingface/transformers/blob/6e3063422c4b1c014aa60c32b9254fd2902f0f28/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py#L1454
let hidden_states = self.out_proj.forward(&hidden_states.tanh()?)?;
Ok(hidden_states)
}
}