Compare commits

...

22 Commits

Author SHA1 Message Date
cd96fa80da Add a scattered kv cache. (#2936)
* Add a scattered kv cache.

* Update some comments.
2025-05-01 10:20:48 +02:00
8a19bb7df2 Bump the candle version to 0.9.1. (#2935) 2025-05-01 10:08:16 +02:00
38fc86621c Add support for Helium-v1. (#2932) 2025-04-30 19:38:44 +02:00
5029ac52bb Added tracing page to the candle book. (#2922)
* tracing page

* warned about asynchronous execution

* cleanup

* added Nsignt Systems recommendation
2025-04-29 21:35:36 +02:00
de23d34a28 Switch Tensor::full to return a contiguous tensor. (#2929) 2025-04-28 21:36:39 +02:00
d4bac37a61 Fix the gumbel softmax by casting to f32. (#2928) 2025-04-28 19:48:51 +02:00
e98754fc5a Optimize Tensor::new when called on nested Vec<..>. (#2927)
* Optimize Tensor::new when called on nested Vec<..>.

* Improve performance.

* Similar flattening for the 4d case.

* More tweaks.

* Add some dummy test.
2025-04-28 09:19:45 +02:00
e3db30021f Support for "unbatched" rope. (#2926)
* Support for (un)-batched rope.

* Use 3d rope in the rope/ropei/rope_thd functions.

* Get the CPU versions to work.

* Fix the cuda version.

* Adapt the metal side.

* Fix the metal tests.
2025-04-27 15:12:02 +02:00
6e0646c208 Remove redundant mlx gemm dtype check (#2925) 2025-04-27 06:14:57 +02:00
fbaf0b0e32 Bump the crate version to 0.9.0. (#2924) 2025-04-26 11:01:21 +02:00
a2e925462c Add the scatter in place ops. (#2923)
* Add the scatter_set op.

* Metal op.

* Cuda version.

* Merge the checks.

* Add the actual ops.
2025-04-26 07:36:49 +02:00
3827685524 Add the scatter op. (#2921)
* Add the scatter op.

* Backprop support.

* Cuda support.
2025-04-25 21:46:58 +02:00
3aeb9575c7 Fixed Quantized Gemma3 Model and example (#2918)
* removed scale factor from computation and made quantized gemma3 work similarly to non-quantized gemma3

* created default consts, replaced is_sliding with Option holding a window_size
2025-04-25 05:47:48 +02:00
6ff0a6999c Fixed Gemma3 model and example (#2917)
* gemma3: changed RotaryEmbedding base freq based on layer and sliding window

* Changed attention mask per layer, either normal or sliding

* made attention mask creation slightly more efficient by only creating them once per model iteration

* changed is_sliding to an Option

* clippy

* changed to stop on both <eos> and <end_of_turn> instead of either or
2025-04-25 05:35:08 +02:00
82def7ae38 Cudarc update. (#2915) 2025-04-23 07:03:26 +02:00
99bd69f383 fixed quantized-gemma example (#2914)
* fixed quantized-gemma example

* lint
2025-04-23 05:39:03 +02:00
a4c56a958e Add the const-set op. (#2910)
* Add the const-set op.

* Cuda implementation.

* Bugfix.

* Metal cleanup.

* Add the metal kernels.

* Add some testing.

* Finish the metal implementation.

* Bump the version.
2025-04-19 10:07:02 +02:00
b2904a830b implemented quantized-gemma3 (#2902)
* implemented quantized-gemma, inference not working

* Fixed a few modeling bugs: outputing the correct tokens for a few iterations then garbage

* lint

* clippy

* quantized-gemma3 example working

* added readme

* clippy
2025-04-19 07:46:41 +02:00
21055b5697 Add PRelu operation (#2904)
* Add PRelu operation

* Apply rustfmt.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2025-04-19 07:24:10 +02:00
9dbaf958dc Add an enum for scalar values. (#2909)
* Add a scalar enum type.

* Add a bit more to the scalar type.

* Small tweak.

* More scalar usage.
2025-04-18 22:13:38 +02:00
ce5f8dd129 Check the bounds in the cuda indexing kernels. (#2908)
* Check the bounds in the cuda indexing kernels.

* Another check.
2025-04-18 20:08:17 +02:00
9954981327 Allow from_vec/from_slice to use a ShapeWithOneHole as shape. (#2905) 2025-04-17 08:59:18 +02:00
52 changed files with 3044 additions and 473 deletions

View File

@ -20,7 +20,7 @@ exclude = [
resolver = "2"
[workspace.package]
version = "0.9.0-alpha.4"
version = "0.9.1"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@ -33,17 +33,17 @@ ab_glyph = "0.2.23"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.9.0-alpha.4" }
candle-datasets = { path = "./candle-datasets", version = "0.9.0-alpha.4" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0-alpha.4" }
candle-kernels = { path = "./candle-kernels", version = "0.9.0-alpha.4" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0-alpha.4" }
candle-nn = { path = "./candle-nn", version = "0.9.0-alpha.4" }
candle-onnx = { path = "./candle-onnx", version = "0.9.0-alpha.4" }
candle-transformers = { path = "./candle-transformers", version = "0.9.0-alpha.4" }
candle = { path = "./candle-core", package = "candle-core", version = "0.9.1" }
candle-datasets = { path = "./candle-datasets", version = "0.9.1" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.1" }
candle-kernels = { path = "./candle-kernels", version = "0.9.1" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.1" }
candle-nn = { path = "./candle-nn", version = "0.9.1" }
candle-onnx = { path = "./candle-onnx", version = "0.9.1" }
candle-transformers = { path = "./candle-transformers", version = "0.9.1" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.16.0", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
cudarc = { version = "0.16.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = "0.4.1"

View File

@ -16,6 +16,7 @@
- [Running a model](inference/inference.md)
- [Using the hub](inference/hub.md)
- [Error management](error_manage.md)
- [Tracing](tracing.md)
- [Training](training/training.md)
- [Simplified](training/simplified.md)
- [MNIST](training/mnist.md)

View File

@ -0,0 +1,68 @@
# Tracing
Tracing is a powerful tool for identifying performance issues and bottlenecks in code.
> Profiling on GPUs is trickier due to asynchronous execution, see the [GPU section](#gpu).
## Overview
Candle uses the [tracing](https://docs.rs/tracing/latest/tracing/) crate for instrumentation.
To try it out, run an example in `candle-examples` with the `--tracing` flag.
This generates a trace file, typically named `trace-<timestamp>.json`.
You can view the trace in Chrome by navigating to `chrome://tracing/`, clicking **Load**, and selecting the generated trace file.
## Adding Tracing
Candle includes built-in tracing for many internal operations, using [spans](https://docs.rs/tracing/latest/tracing/struct.Span.html) to mark key points of execution.
To add custom tracing in your code, you can define a span like this:
```rust
let span = tracing::span!(tracing::Level::TRACE, name);
```
Then, to record the span during execution, create a guard:
```rust
let _enter = span.enter();
```
This guard will record the span's duration, from when it is created to when it is dropped, into a global data structure managed by the tracing crate.
## Recording and Saving a Trace
To capture and save trace data, you need to configure the tracing system with an output format. Candle uses the [tracing_subscriber](https://docs.rs/tracing-subscriber/latest/tracing_subscriber/) and [tracing_chrome](https://docs.rs/tracing-chrome/latest/tracing_chrome/) crates.
The snippet below sets up a Chrome compatible recorder that logs all tracing activity between creation and drop of the guard:
```rust
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let _guard = {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
guard
};
```
## GPU
When using CUDA, Metal, or other asynchronous GPU backends, tracing may produce misleading timing data because operations are queued rather than executed immediately.
### CUDA
For CUDA-specific profiling, you have two options:
1. Set the environment variable `CUDA_LAUNCH_BLOCKING=1` which forces synchronous execution. This makes trace timings more accurate, at the cost of reduced performance.
2. Use [NVIDIA's Nsight Systems](https://developer.nvidia.com/nsight-systems) (`nsys profile` and `nsys-ui`) which are designed specifically for profiling asynchronous CUDA executions.
We recommend using NVIDIA's Nsight Systems when possible, as it offers accurate performance data without altering typical execution patterns. In contrast, setting the `CUDA_LAUNCH_BLOCKING` environment variable forces synchronous execution, which can significantly alter execution behavior.
#### Performance Profiling with NVIDIA Nsight Systems
1. Generate an `.nsys-rep` file containing performance data ([docs](https://docs.nvidia.com/nsight-systems/UserGuide/index.html#example-single-command-lines))
- Run `nsys profile --trace cuda,nvtx,osrt --gpu-metrics-device=all --output profile_run ./target/debug/... --prompt "whatever "`
1. Open the generated `.nsys-rep` report file in Nsight Systems GUI
- File > Open

View File

@ -4,11 +4,12 @@ use criterion::criterion_main;
criterion_main!(
benchmarks::affine::benches,
benchmarks::copy::benches,
benchmarks::conv_transpose2d::benches,
benchmarks::matmul::benches,
benchmarks::qmatmul::benches,
benchmarks::random::benches,
benchmarks::reduce::benches,
benchmarks::unary::benches,
benchmarks::where_cond::benches,
benchmarks::conv_transpose2d::benches,
benchmarks::qmatmul::benches,
benchmarks::unary::benches
);

View File

@ -0,0 +1,38 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{Device, Tensor, WithDType};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn run_copy_mask_benchmark<D: WithDType>(c: &mut Criterion, device: &Device, name: &str) {
let batch_size = 128;
let in_seq_len = 1;
let kv_seq_len = 1024;
let attn_mask = vec![vec![vec![D::zero(); kv_seq_len]; in_seq_len]; batch_size];
let size_in_bytes = batch_size * in_seq_len * kv_seq_len * D::DTYPE.size_in_bytes();
let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(size_in_bytes as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let attn_masks = vec![attn_mask.clone(); iters as usize];
let start = Instant::now();
for attn_mask in attn_masks.into_iter() {
let tensor = Tensor::new(black_box(attn_mask), device).unwrap();
black_box(tensor);
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
run_copy_mask_benchmark::<f32>(c, &device, "copy_mask");
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -1,5 +1,6 @@
pub(crate) mod affine;
pub(crate) mod conv_transpose2d;
pub(crate) mod copy;
pub(crate) mod matmul;
pub(crate) mod qmatmul;
pub(crate) mod random;

View File

@ -71,15 +71,27 @@ pub trait BackendStorage: Sized {
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
fn scatter_add(
&self,
fn scatter_set(
&mut self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self>;
) -> Result<()>;
fn scatter_add_set(
&mut self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<()>;
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self>;
fn index_add(
&self,
@ -113,6 +125,8 @@ pub trait BackendStorage: Sized {
_src_offset: usize,
_dst_offset: usize,
) -> Result<()>;
fn const_set(&mut self, _: crate::scalar::Scalar, _: &Layout) -> Result<()>;
}
pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
@ -127,8 +141,6 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
/// # Safety
/// This function is unsafe as it doesn't initialize the underlying data store.
/// The caller should ensure that the data is properly initialized as early as possible

View File

@ -53,6 +53,7 @@ impl Tensor {
} else if let Some(op) = node.op() {
match op {
Op::IndexAdd(t1, t2, t3, _)
| Op::Scatter(t1, t2, t3, _)
| Op::ScatterAdd(t1, t2, t3, _)
| Op::CustomOp3(t1, t2, t3, _)
| Op::WhereCond(t1, t2, t3) => {
@ -419,7 +420,7 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
}
Op::ScatterAdd(init, indexes, src, dim) => {
Op::Scatter(init, indexes, src, dim) => {
let init_sum_grad = grads.or_insert(init)?;
*init_sum_grad = init_sum_grad.add(&grad)?;
@ -427,6 +428,16 @@ impl Tensor {
let src_sum_grad = grads.or_insert(src)?;
*src_sum_grad = src_sum_grad.add(&src_grad)?;
}
Op::ScatterAdd(init, indexes, src, dim) => {
let init_sum_grad = grads.or_insert(init)?;
let mask = init.ones_like()?;
let mask = mask.scatter(indexes, &mask.zeros_like()?, *dim)?;
*init_sum_grad = init_sum_grad.add(&grad.mul(&mask)?)?;
let src_grad = grad.gather(indexes, *dim)?;
let src_sum_grad = grads.or_insert(src)?;
*src_sum_grad = src_sum_grad.add(&src_grad)?;
}
Op::IndexAdd(init, indexes, src, dim) => {
let init_sum_grad = grads.or_insert(init)?;
*init_sum_grad = init_sum_grad.add(&grad)?;

View File

@ -7,7 +7,7 @@ use rayon::prelude::*;
mod utils;
pub use utils::{
binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2U8,
binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2InPlace, Map2U8,
};
const USE_IM2COL_CONV1D: bool = true;
@ -554,26 +554,65 @@ impl<I: IntDType> Map1 for IndexSelect<'_, I> {
}
}
struct ScatterAdd<'a, I: IntDType> {
trait ElemUpdate {
fn f<T: WithDType>(dst: &mut T, src: T);
}
struct Set;
struct Add;
impl ElemUpdate for Set {
fn f<T: WithDType>(dst: &mut T, src: T) {
*dst = src
}
}
impl ElemUpdate for Add {
fn f<T: WithDType>(dst: &mut T, src: T) {
*dst += src
}
}
struct Scatter<'a, I: IntDType, M: ElemUpdate> {
ids: &'a [I],
ids_l: &'a Layout,
dim: usize,
_phantom: std::marker::PhantomData<M>,
}
impl<I: IntDType> Map2 for ScatterAdd<'_, I> {
const OP: &'static str = "scatter-add";
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
let dst_len = l1.shape().elem_count();
let mut dst = vec![T::zero(); dst_len];
copy_strided_src_(v1, &mut dst, 0, l1);
impl<'a, I: IntDType, M: ElemUpdate> Scatter<'a, I, M> {
fn new(ids: &'a [I], ids_l: &'a Layout, dim: usize) -> Self {
Self {
ids,
ids_l,
dim,
_phantom: Default::default(),
}
}
}
impl<I: IntDType, M: ElemUpdate> Map2InPlace for Scatter<'_, I, M> {
const OP: &'static str = "scatter";
fn f<T: WithDType>(
&self,
dst: &mut [T],
dst_l: &Layout,
src: &[T],
src_l: &Layout,
) -> Result<()> {
let dst = match dst_l.contiguous_offsets() {
None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
Some((o1, o2)) => &mut dst[o1..o2],
};
let src = match src_l.contiguous_offsets() {
None => Err(Error::RequiresContiguous { op: "scatter-add" }.bt())?,
None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
Some((o1, o2)) => &src[o1..o2],
};
let dim = self.dim;
let ids_dims = self.ids_l.dims();
let dst_dims = l1.dims();
let dst_dims = dst_l.dims();
let dst_dim_len = dst_dims[dim];
let dst_right_len: usize = dst_dims[dim + 1..].iter().product();
@ -602,12 +641,12 @@ impl<I: IntDType> Map2 for ScatterAdd<'_, I> {
.bt())?
}
let dst_idx = start_dst_idx + index * dst_right_len + right_i;
dst[dst_idx] += src[ids_idx]
M::f(&mut dst[dst_idx], src[ids_idx])
}
}
}
Ok(dst)
Ok(())
}
}
@ -2381,19 +2420,36 @@ impl BackendStorage for CpuStorage {
}
}
fn scatter_add(
&self,
fn scatter_set(
&mut self,
l: &Layout,
ids: &Self,
ids_l: &Layout,
src: &Self,
src_l: &Layout,
dim: usize,
) -> Result<Self> {
) -> Result<()> {
match ids {
Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
Self::U8(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
Self::U32(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
Self::I64(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter").bt()),
}
}
fn scatter_add_set(
&mut self,
l: &Layout,
ids: &Self,
ids_l: &Layout,
src: &Self,
src_l: &Layout,
dim: usize,
) -> Result<()> {
match ids {
Self::U8(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
Self::U32(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
Self::I64(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()),
}
}
@ -2454,6 +2510,48 @@ impl BackendStorage for CpuStorage {
fn to_cpu_storage(&self) -> Result<CpuStorage> {
Ok(self.clone())
}
fn const_set(&mut self, s: crate::scalar::Scalar, l: &Layout) -> Result<()> {
use crate::scalar::Scalar;
fn set<T: crate::WithDType>(src: &mut [T], l: &Layout, s: T) {
match l.strided_blocks() {
crate::StridedBlocks::SingleBlock { start_offset, len } => {
src[start_offset..start_offset + len].fill(s)
}
crate::StridedBlocks::MultipleBlocks {
block_start_index,
block_len: 1,
} => {
for src_index in block_start_index {
src[src_index] = s
}
}
crate::StridedBlocks::MultipleBlocks {
block_start_index,
block_len,
} => {
for src_index in block_start_index {
src[src_index..src_index + block_len].fill(s)
}
}
}
}
match (self, s) {
(Self::BF16(storage), Scalar::BF16(v)) => set(storage, l, v),
(Self::F16(storage), Scalar::F16(v)) => set(storage, l, v),
(Self::F32(storage), Scalar::F32(v)) => set(storage, l, v),
(Self::F64(storage), Scalar::F64(v)) => set(storage, l, v),
(Self::U8(storage), Scalar::U8(v)) => set(storage, l, v),
(Self::U32(storage), Scalar::U32(v)) => set(storage, l, v),
(Self::I64(storage), Scalar::I64(v)) => set(storage, l, v),
(st, s) => crate::bail!(
"const_set dtype mismatch, expected {:?} but got {:?}",
st.dtype(),
s
),
}
Ok(())
}
}
impl BackendDevice for CpuDevice {
@ -2628,20 +2726,6 @@ impl BackendDevice for CpuDevice {
Ok(storage)
}
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
let elem_count = shape.elem_count();
let storage = match dtype {
DType::U8 => CpuStorage::U8(vec![1u8; elem_count]),
DType::U32 => CpuStorage::U32(vec![1u32; elem_count]),
DType::I64 => CpuStorage::I64(vec![1i64; elem_count]),
DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]),
DType::F16 => CpuStorage::F16(vec![f16::ONE; elem_count]),
DType::F32 => CpuStorage::F32(vec![1f32; elem_count]),
DType::F64 => CpuStorage::F64(vec![1f64; elem_count]),
};
Ok(storage)
}
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
let elem_count = shape.elem_count();
let storage = match dtype {

View File

@ -58,6 +58,30 @@ pub trait Map2 {
}
}
pub trait Map2InPlace {
const OP: &'static str;
fn f<T: WithDType>(&self, v1: &mut [T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<()>;
fn map(&self, v1: &mut C, l1: &Layout, v2: &C, l2: &Layout) -> Result<()> {
match (v1, v2) {
(C::U8(v1), C::U8(v2)) => self.f(v1, l1, v2, l2)?,
(C::U32(v1), C::U32(v2)) => self.f(v1, l1, v2, l2)?,
(C::I64(v1), C::I64(v2)) => self.f(v1, l1, v2, l2)?,
(C::BF16(v1), C::BF16(v2)) => self.f(v1, l1, v2, l2)?,
(C::F16(v1), C::F16(v2)) => self.f(v1, l1, v2, l2)?,
(C::F32(v1), C::F32(v2)) => self.f(v1, l1, v2, l2)?,
(C::F64(v1), C::F64(v2)) => self.f(v1, l1, v2, l2)?,
(v1, v2) => Err(Error::DTypeMismatchBinaryOp {
lhs: v1.dtype(),
rhs: v2.dtype(),
op: Self::OP,
}
.bt())?,
};
Ok(())
}
}
pub trait Map2U8 {
const OP: &'static str;
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;

View File

@ -2,7 +2,7 @@ use crate::backend::BackendDevice;
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
pub use candle_kernels as kernels;
pub use cudarc;
use cudarc::driver::{CudaFunction, LaunchConfig, PushKernelArg};
use cudarc::driver::CudaFunction;
use half::{bf16, f16};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
@ -188,100 +188,6 @@ impl CudaDevice {
self.id
}
fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
let slice = match dtype {
DType::U8 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u8>(elem_count)? };
let func = self.get_or_load_func("fill_u8", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as u8;
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::U8(data)
}
DType::U32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u32>(elem_count)? };
let func = self.get_or_load_func("fill_u32", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as u32;
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::U32(data)
}
DType::I64 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<i64>(elem_count)? };
let func = self.get_or_load_func("fill_i64", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as i64;
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<bf16>(elem_count)? };
let func = self.get_or_load_func("fill_bf16", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = bf16::from_f64(v);
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f16>(elem_count)? };
let func = self.get_or_load_func("fill_f16", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = f16::from_f64(v);
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f32>(elem_count)? };
let func = self.get_or_load_func("fill_f32", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as f32;
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f64>(elem_count) }?;
let func = self.get_or_load_func("fill_f64", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
pub fn get_or_load_custom_func(
&self,
fn_name: &str,
@ -504,10 +410,6 @@ impl BackendDevice for CudaDevice {
})
}
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
self.const_impl(1., shape, dtype)
}
unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
let elem_count = shape.elem_count();
let slice = match dtype {

View File

@ -2,7 +2,7 @@
//!
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{builder_arg as barg, CpuStorage, DType, Layout, Result, Shape, WithDType};
use crate::{builder_arg as barg, CpuStorage, DType, Layout, Result, WithDType};
pub use candle_kernels as kernels;
pub use cudarc;
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
@ -34,6 +34,21 @@ impl<T: DeviceRepr> SlicePtrOrNull<T> {
}
}
impl crate::scalar::Scalar {
pub fn builder_arg<'a, 'b: 'a>(&'b self, builder: &mut cudarc::driver::LaunchArgs<'a>) {
use crate::scalar::Scalar;
match self {
Scalar::U8(v) => builder.arg(v),
Scalar::U32(v) => builder.arg(v),
Scalar::I64(v) => builder.arg(v),
Scalar::F32(v) => builder.arg(v),
Scalar::F64(v) => builder.arg(v),
Scalar::F16(v) => builder.arg(v),
Scalar::BF16(v) => builder.arg(v),
};
}
}
impl SlicePtrOrNull<usize> {
pub fn params_from_layout(dev: &CudaDevice, l: &Layout) -> Result<Self> {
let ds = if l.is_contiguous() {
@ -395,7 +410,7 @@ impl Map1 for IndexSelect<'_> {
CudaStorageSlice::U8(slice) => ("is_u8", slice_ptr(slice, ids_l.start_offset())),
CudaStorageSlice::I64(slice) => ("is_i64", slice_ptr(slice, ids_l.start_offset())),
_ => Err(CudaError::UnexpectedDType {
msg: "index_select ids should be u8 or u32",
msg: "index_select ids should be u8, u32, or i64",
expected: DType::U32,
got: self.0.dtype(),
})
@ -492,7 +507,7 @@ impl Map2InPlace for IndexAdd<'_> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
dst: &mut CudaSlice<T>,
dst_shape: &Shape,
dst_l: &Layout,
src: &CudaSlice<T>,
src_l: &Layout,
dev: &CudaDevice,
@ -514,6 +529,10 @@ impl Map2InPlace for IndexAdd<'_> {
got: ids.dtype(),
})?,
};
let dst = match dst_l.contiguous_offsets() {
Some((o1, o2)) => dst.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
};
let src = match src_l.contiguous_offsets() {
Some((o1, o2)) => src.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
@ -521,7 +540,7 @@ impl Map2InPlace for IndexAdd<'_> {
let left_sz: usize = src_l.dims()[..dim].iter().product();
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
let src_dim_sz = src_l.dims()[dim];
let dst_dim_sz = dst_shape.dims()[dim];
let dst_dim_sz = dst_l.dims()[dim];
let ids_dim_sz = ids_l.dims()[0];
let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
@ -529,7 +548,59 @@ impl Map2InPlace for IndexAdd<'_> {
barg!(builder, ids);
barg!(builder, ids_dim_sz);
builder.arg(&src);
builder.arg(dst);
builder.arg(&dst);
barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz);
// SAFETY: ffi.
unsafe { builder.launch(cfg) }.w()?;
Ok(())
}
}
struct Scatter<'a>(&'a CudaStorage, &'a Layout, usize);
impl Map2InPlace for Scatter<'_> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
dst: &mut CudaSlice<T>,
dst_l: &Layout,
src: &CudaSlice<T>,
src_l: &Layout,
dev: &CudaDevice,
) -> Result<()> {
let ids = &self.0;
let ids_l = &self.1;
let dim = self.2;
let (ids_o1, _) = match ids_l.contiguous_offsets() {
Some(o12) => o12,
None => Err(crate::Error::RequiresContiguous { op: "scatter" }.bt())?,
};
let (name, (ids, _guard)) = match &ids.slice {
CudaStorageSlice::U32(slice) => ("s_u32", slice_ptr(slice, ids_o1)),
CudaStorageSlice::I64(slice) => ("s_i64", slice_ptr(slice, ids_o1)),
CudaStorageSlice::U8(slice) => ("s_u8", slice_ptr(slice, ids_o1)),
_ => Err(CudaError::UnexpectedDType {
msg: "scatter ids should be u8/u32/i64",
expected: DType::U32,
got: ids.dtype(),
})?,
};
let dst = match dst_l.contiguous_offsets() {
Some((o1, o2)) => dst.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "scatter" }.bt())?,
};
let src = match src_l.contiguous_offsets() {
Some((o1, o2)) => src.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "scatter" }.bt())?,
};
let left_sz: usize = src_l.dims()[..dim].iter().product();
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
let src_dim_sz = src_l.dims()[dim];
let dst_dim_sz = dst_l.dims()[dim];
let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
let mut builder = func.builder();
barg!(builder, ids);
builder.arg(&src);
builder.arg(&dst);
barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz);
// SAFETY: ffi.
unsafe { builder.launch(cfg) }.w()?;
@ -542,7 +613,7 @@ impl Map2InPlace for ScatterAdd<'_> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
dst: &mut CudaSlice<T>,
dst_shape: &Shape,
dst_l: &Layout,
src: &CudaSlice<T>,
src_l: &Layout,
dev: &CudaDevice,
@ -564,6 +635,10 @@ impl Map2InPlace for ScatterAdd<'_> {
got: ids.dtype(),
})?,
};
let dst = match dst_l.contiguous_offsets() {
Some((o1, o2)) => dst.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
};
let src = match src_l.contiguous_offsets() {
Some((o1, o2)) => src.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
@ -571,13 +646,13 @@ impl Map2InPlace for ScatterAdd<'_> {
let left_sz: usize = src_l.dims()[..dim].iter().product();
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
let src_dim_sz = src_l.dims()[dim];
let dst_dim_sz = dst_shape.dims()[dim];
let dst_dim_sz = dst_l.dims()[dim];
let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
let mut builder = func.builder();
barg!(builder, ids);
builder.arg(&src);
builder.arg(dst);
builder.arg(&dst);
barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz);
// SAFETY: ffi.
unsafe { builder.launch(cfg) }.w()?;
@ -1235,6 +1310,36 @@ impl BackendStorage for CudaStorage {
&self.device
}
fn const_set(&mut self, s: crate::scalar::Scalar, layout: &Layout) -> Result<()> {
let dev = &self.device;
let shape = layout.shape();
let dims = shape.dims();
let el_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el_count as u32);
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
let src_o = layout.start_offset();
let ((src, _guard_src), kernel_name) = match &mut self.slice {
S::U8(s) => (slice_ptr(s, src_o), "const_set_u8"),
S::U32(s) => (slice_ptr(s, src_o), "const_set_u32"),
S::I64(s) => (slice_ptr(s, src_o), "const_set_i64"),
S::BF16(s) => (slice_ptr(s, src_o), "const_set_bf16"),
S::F16(s) => (slice_ptr(s, src_o), "const_set_f16"),
S::F32(s) => (slice_ptr(s, src_o), "const_set_f32"),
S::F64(s) => (slice_ptr(s, src_o), "const_set_f64"),
};
let func = dev.get_or_load_func(kernel_name, &kernels::FILL)?;
let mut builder = func.builder();
barg!(builder, el_count);
barg!(builder, dims.len());
ds.builder_arg(&mut builder);
s.builder_arg(&mut builder);
barg!(builder, src);
// SAFETY: ffi.
unsafe { builder.launch(cfg) }.w()?;
Ok(())
}
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
let shape = layout.shape();
let dims = shape.dims();
@ -1793,20 +1898,29 @@ impl BackendStorage for CudaStorage {
let slice = Gather(ids, ids_l, dim).map(&self.slice, &device, l)?;
Ok(Self { slice, device })
}
fn scatter_add(
&self,
fn scatter_set(
&mut self,
l: &Layout,
ids: &Self,
ids_l: &Layout,
src: &Self,
src_l: &Layout,
dim: usize,
) -> Result<Self> {
) -> Result<()> {
let device = self.device().clone();
let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
self.copy_strided_src(&mut acc, 0, l)?;
ScatterAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
Ok(acc)
Scatter(ids, ids_l, dim).map(&mut self.slice, l, &src.slice, src_l, &device)
}
fn scatter_add_set(
&mut self,
l: &Layout,
ids: &Self,
ids_l: &Layout,
src: &Self,
src_l: &Layout,
dim: usize,
) -> Result<()> {
let device = self.device().clone();
ScatterAdd(ids, ids_l, dim).map(&mut self.slice, l, &src.slice, src_l, &device)
}
fn index_add(
&self,
@ -1820,7 +1934,7 @@ impl BackendStorage for CudaStorage {
let device = self.device().clone();
let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
self.copy_strided_src(&mut acc, 0, l)?;
IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l, &src.slice, src_l, &device)?;
Ok(acc)
}

View File

@ -1,5 +1,5 @@
/// Helper functions to plug cuda kernels in candle.
use crate::{Layout, Result, Shape, WithDType};
use crate::{Layout, Result, WithDType};
pub use cudarc;
use cudarc::driver::{CudaSlice, DeviceRepr, ValidAsZeroBits};
@ -96,7 +96,7 @@ pub trait Map2InPlace {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
dst: &mut CudaSlice<T>,
dst_shape: &Shape,
dst_l: &Layout,
src: &CudaSlice<T>,
src_l: &Layout,
dev: &CudaDevice,
@ -105,19 +105,19 @@ pub trait Map2InPlace {
fn map(
&self,
dst: &mut S,
dst_s: &Shape,
dst_l: &Layout,
src: &S,
src_l: &Layout,
d: &CudaDevice,
) -> Result<()> {
match (dst, src) {
(S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
(S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
(S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d),
(S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d),
(S::U8(dst), S::U8(src)) => self.f(dst, dst_l, src, src_l, d),
(S::U32(dst), S::U32(src)) => self.f(dst, dst_l, src, src_l, d),
(S::I64(dst), S::I64(src)) => self.f(dst, dst_l, src, src_l, d),
(S::BF16(dst), S::BF16(src)) => self.f(dst, dst_l, src, src_l, d),
(S::F16(dst), S::F16(src)) => self.f(dst, dst_l, src, src_l, d),
(S::F32(dst), S::F32(src)) => self.f(dst, dst_l, src, src_l, d),
(S::F64(dst), S::F64(src)) => self.f(dst, dst_l, src, src_l, d),
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
}
}

View File

@ -103,7 +103,63 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4:
}
}
impl<S: NdArray> NdArray for Vec<S> {
impl<S: WithDType> NdArray for Vec<S> {
fn shape(&self) -> Result<Shape> {
Ok(Shape::from(self.len()))
}
fn to_cpu_storage(&self) -> CpuStorage {
S::to_cpu_storage(self.as_slice())
}
}
impl<S: WithDType> NdArray for Vec<&[S]> {
fn shape(&self) -> Result<Shape> {
if self.is_empty() {
crate::bail!("empty array")
}
let n = self.len();
let m = self[0].len();
for v in self.iter() {
if v.len() != m {
crate::bail!("two elements have different len {m} {}", v.len())
}
}
Ok(Shape::from((n, m)))
}
fn to_cpu_storage(&self) -> CpuStorage {
let data = self.iter().copied().flatten().copied().collect::<Vec<_>>();
S::to_cpu_storage_owned(data)
}
}
impl<S: WithDType> NdArray for Vec<Vec<S>> {
fn shape(&self) -> Result<Shape> {
if self.is_empty() {
crate::bail!("empty array")
}
let n = self.len();
let m = self[0].len();
for v in self.iter() {
if v.len() != m {
crate::bail!("two elements have different len {m} {}", v.len())
}
}
Ok(Shape::from((n, m)))
}
fn to_cpu_storage(&self) -> CpuStorage {
let len: usize = self.iter().map(|v| v.len()).sum();
let mut dst = Vec::with_capacity(len);
for v in self.iter() {
dst.extend(v.iter().copied());
}
S::to_cpu_storage_owned(dst)
}
}
impl<S: WithDType> NdArray for Vec<Vec<Vec<S>>> {
fn shape(&self) -> Result<Shape> {
if self.is_empty() {
crate::bail!("empty array")
@ -120,9 +176,57 @@ impl<S: NdArray> NdArray for Vec<S> {
}
fn to_cpu_storage(&self) -> CpuStorage {
// This allocates intermediary memory and shouldn't be necessary.
let storages = self.iter().map(|v| v.to_cpu_storage()).collect::<Vec<_>>();
CpuStorage::concat(storages.as_slice()).unwrap()
if self.is_empty() {
return S::to_cpu_storage_owned(vec![]);
}
let len: usize = self
.iter()
.map(|v| v.iter().map(|v| v.len()).sum::<usize>())
.sum();
let mut dst = Vec::with_capacity(len);
for v1 in self.iter() {
for v2 in v1.iter() {
dst.extend(v2.iter().copied());
}
}
S::to_cpu_storage_owned(dst)
}
}
impl<S: WithDType> NdArray for Vec<Vec<Vec<Vec<S>>>> {
fn shape(&self) -> Result<Shape> {
if self.is_empty() {
crate::bail!("empty array")
}
let shape0 = self[0].shape()?;
let n = self.len();
for v in self.iter() {
let shape = v.shape()?;
if shape != shape0 {
crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
}
}
Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
}
fn to_cpu_storage(&self) -> CpuStorage {
let len: usize = self
.iter()
.map(|v| {
v.iter()
.map(|v| v.iter().map(|v| v.len()).sum::<usize>())
.sum::<usize>()
})
.sum();
let mut dst = Vec::with_capacity(len);
for v1 in self.iter() {
for v2 in v1.iter() {
for v3 in v2.iter() {
dst.extend(v3.iter().copied());
}
}
}
S::to_cpu_storage_owned(dst)
}
}
@ -292,23 +396,6 @@ impl Device {
self.rand_normal_f64(mean.to_f64(), std.to_f64(), shape, T::DTYPE)
}
pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
match self {
Device::Cpu => {
let storage = CpuDevice.ones_impl(shape, dtype)?;
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
let storage = device.ones_impl(shape, dtype)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = device.ones_impl(shape, dtype)?;
Ok(Storage::Metal(storage))
}
}
}
pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
match self {
Device::Cpu => {

View File

@ -107,6 +107,7 @@ pub trait WithDType:
fn from_f64(v: f64) -> Self;
fn to_f64(self) -> f64;
fn to_scalar(self) -> crate::scalar::Scalar;
fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_>;
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
@ -131,6 +132,10 @@ macro_rules! with_dtype {
$to_f64(self)
}
fn to_scalar(self) -> crate::scalar::Scalar {
crate::scalar::Scalar::$dtype(self)
}
fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_> {
CpuStorageRef::$dtype(data)
}

View File

@ -37,6 +37,10 @@ impl crate::backend::BackendStorage for CudaStorage {
fail!()
}
fn const_set(&mut self, _: crate::scalar::Scalar, _: &Layout) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
fn to_cpu_storage(&self) -> Result<CpuStorage> {
Err(Error::NotCompiledWithCudaSupport)
}
@ -124,15 +128,27 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
fn scatter_add(
&self,
fn scatter_set(
&mut self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
fn scatter_add_set(
&mut self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
@ -214,10 +230,6 @@ impl crate::backend::BackendDevice for CudaDevice {
Err(Error::NotCompiledWithCudaSupport)
}
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}

View File

@ -41,6 +41,10 @@ impl crate::backend::BackendStorage for MetalStorage {
fail!()
}
fn const_set(&mut self, _: crate::scalar::Scalar, _: &Layout) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}
fn to_cpu_storage(&self) -> Result<CpuStorage> {
Err(Error::NotCompiledWithMetalSupport)
}
@ -128,15 +132,27 @@ impl crate::backend::BackendStorage for MetalStorage {
Err(Error::NotCompiledWithMetalSupport)
}
fn scatter_add(
&self,
fn scatter_set(
&mut self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}
fn scatter_add_set(
&mut self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}
@ -218,10 +234,6 @@ impl crate::backend::BackendDevice for MetalDevice {
Err(Error::NotCompiledWithMetalSupport)
}
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}

View File

@ -413,6 +413,100 @@ impl BackendStorage for MetalStorage {
self.binary(name, rhs, lhs_l, rhs_l)
}
fn const_set(&mut self, s: crate::scalar::Scalar, l: &Layout) -> Result<()> {
use crate::scalar::Scalar;
fn set<S: crate::WithDType + candle_metal_kernels::utils::EncoderParam>(
self_: &mut MetalStorage,
s: S,
l: &Layout,
) -> Result<()> {
let device = self_.device();
let dtype = self_.dtype;
let shape = l.shape();
let el_count = shape.elem_count();
let command_buffer = device.command_buffer()?;
command_buffer.set_label("const-set");
let dst = buffer_o(&self_.buffer, l, self_.dtype);
match (el_count % 2, dtype, l.is_contiguous()) {
(0, DType::BF16 | DType::F16, true) => {
use candle_metal_kernels::unary::contiguous_tiled;
let kernel_name = match dtype {
DType::F16 => contiguous_tiled::const_set::HALF,
DType::BF16 => contiguous_tiled::const_set::BFLOAT,
_ => crate::bail!("internal bug in const_set"),
};
candle_metal_kernels::call_const_set_contiguous_tiled(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
el_count,
s,
dst,
)
.map_err(MetalError::from)?;
}
(_, _, true) => {
use candle_metal_kernels::unary::contiguous;
let kernel_name = match dtype {
DType::F16 => contiguous::const_set::HALF,
DType::BF16 => contiguous::const_set::BFLOAT,
DType::F32 => contiguous::const_set::FLOAT,
DType::I64 => contiguous::const_set::I64,
DType::U32 => contiguous::const_set::U32,
DType::U8 => contiguous::const_set::U8,
DType::F64 => crate::bail!("unsupported const-set f64"),
};
candle_metal_kernels::call_const_set_contiguous(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
el_count,
s,
dst,
)
.map_err(MetalError::from)?;
}
(_, _, false) => {
use candle_metal_kernels::unary::strided;
let kernel_name = match dtype {
DType::F16 => strided::const_set::HALF,
DType::BF16 => strided::const_set::BFLOAT,
DType::F32 => strided::const_set::FLOAT,
DType::I64 => strided::const_set::I64,
DType::U32 => strided::const_set::U32,
DType::U8 => strided::const_set::U8,
DType::F64 => crate::bail!("unsupported const-set f64"),
};
candle_metal_kernels::call_const_set_strided(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
l.dims(),
s,
l.stride(),
dst,
)
.map_err(MetalError::from)?;
}
}
Ok(())
}
match (self.dtype, s) {
(DType::U8, Scalar::U8(s)) => set(self, s, l),
(DType::U32, Scalar::U32(s)) => set(self, s, l),
(DType::I64, Scalar::I64(s)) => set(self, s, l),
(DType::F16, Scalar::F16(s)) => set(self, s, l),
(DType::BF16, Scalar::BF16(s)) => set(self, s, l),
(DType::F32, Scalar::F32(s)) => set(self, s, l),
(DType::F64, Scalar::F64(s)) => set(self, s, l),
_ => crate::bail!("dtype mismatch, expected {:?}, got {:?}", self.dtype, s),
}
}
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
let device = self.device();
let shape = layout.shape();
@ -1332,18 +1426,65 @@ impl BackendStorage for MetalStorage {
Ok(Self::new(buffer, device.clone(), dst_el, dtype))
}
fn scatter_add(
&self,
fn scatter_set(
&mut self,
l: &Layout,
ids: &Self,
ids_l: &Layout,
src: &Self,
src_l: &Layout,
dim: usize,
) -> Result<Self> {
let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
self.copy_strided_src(&mut acc, 0, l)?;
if !ids_l.is_contiguous() || !src_l.is_contiguous() {
) -> Result<()> {
if !l.is_contiguous() || !ids_l.is_contiguous() || !src_l.is_contiguous() {
return Err(crate::Error::RequiresContiguous { op: "scatter" }.bt());
};
let name = match (ids.dtype, self.dtype) {
(DType::U8, DType::F32) => "s_u8_f32",
(DType::U8, DType::F16) => "s_u8_f16",
(DType::U8, DType::BF16) => "s_u8_bf16",
(DType::U32, DType::U32) => "s_u32_u32",
(DType::U32, DType::F32) => "s_u32_f32",
(DType::U32, DType::F16) => "s_u32_f16",
(DType::U32, DType::BF16) => "s_u32_bf16",
(DType::I64, DType::F32) => "s_i64_f32",
(DType::I64, DType::F16) => "s_i64_f16",
(DType::I64, DType::BF16) => "s_i64_bf16",
_ => Err(MetalError::UnexpectedDType {
msg: "scatter ids should be u8/u32/i64",
expected: DType::U32,
got: ids.dtype(),
})?,
};
let command_buffer = self.device.command_buffer()?;
let dst = buffer_o(&self.buffer, l, self.dtype);
let src = buffer_o(&src.buffer, src_l, src.dtype);
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
candle_metal_kernels::call_scatter(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
src_l.dims(),
l.dims(),
dim,
src,
ids,
dst,
)
.map_err(MetalError::from)?;
Ok(())
}
fn scatter_add_set(
&mut self,
l: &Layout,
ids: &Self,
ids_l: &Layout,
src: &Self,
src_l: &Layout,
dim: usize,
) -> Result<()> {
if !l.is_contiguous() || !ids_l.is_contiguous() || !src_l.is_contiguous() {
return Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt());
};
let name = match (ids.dtype, self.dtype) {
@ -1364,9 +1505,10 @@ impl BackendStorage for MetalStorage {
})?,
};
let command_buffer = self.device.command_buffer()?;
let dst = buffer_o(&self.buffer, l, self.dtype);
let src = buffer_o(&src.buffer, src_l, src.dtype);
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
candle_metal_kernels::call_scatter_add(
candle_metal_kernels::call_scatter(
&self.device.device,
&command_buffer,
&self.device.kernels,
@ -1376,10 +1518,10 @@ impl BackendStorage for MetalStorage {
dim,
src,
ids,
&acc.buffer,
dst,
)
.map_err(MetalError::from)?;
Ok(acc)
Ok(())
}
fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
@ -1513,50 +1655,32 @@ impl BackendStorage for MetalStorage {
let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?;
let command_buffer = self.device.command_buffer()?;
command_buffer.set_label("matmul");
if self.dtype == DType::BF16 {
candle_metal_kernels::call_mlx_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
candle_metal_kernels::GemmDType::BF16,
(b, m, n, k),
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
} else {
let dtype = match self.dtype {
DType::F32 => candle_metal_kernels::GemmDType::F32,
DType::F16 => candle_metal_kernels::GemmDType::F16,
DType::BF16 => candle_metal_kernels::GemmDType::BF16,
dtype => {
return Err(MetalError::Message(format!(
"mlx matmul doesn't support {dtype:?}"
))
.into())
}
};
candle_metal_kernels::call_mlx_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
dtype,
(b, m, n, k),
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
}
let dtype = match self.dtype {
DType::F32 => candle_metal_kernels::GemmDType::F32,
DType::F16 => candle_metal_kernels::GemmDType::F16,
DType::BF16 => candle_metal_kernels::GemmDType::BF16,
dtype => {
return Err(
MetalError::Message(format!("mlx matmul doesn't support {dtype:?}")).into(),
)
}
};
candle_metal_kernels::call_mlx_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
dtype,
(b, m, n, k),
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::new(
buffer,
self.device.clone(),
@ -1965,40 +2089,6 @@ impl BackendDevice for MetalDevice {
))
}
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
let name = match dtype {
DType::U8 => "fill_u8",
DType::U32 => "fill_u32",
DType::I64 => "fill_i64",
DType::F16 => "fill_f16",
DType::BF16 => "fill_bf16",
DType::F32 => "fill_f32",
DType::F64 => {
let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
return self.storage_from_cpu_storage(&cpu_storage);
}
};
let buffer = self.new_buffer(shape.elem_count(), dtype, "alloc-ones")?;
let command_buffer = self.command_buffer()?;
candle_metal_kernels::call_const_fill(
&self.device,
&command_buffer,
&self.kernels,
name,
shape.elem_count(),
&buffer,
1.,
)
.map_err(MetalError::from)?;
Ok(MetalStorage::new(
buffer,
self.clone(),
shape.elem_count(),
dtype,
))
}
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
let (count, buffer) = match T::cpu_storage_ref(s) {
CpuStorageRef::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),

View File

@ -80,6 +80,7 @@ pub enum Op {
Reduce(Tensor, ReduceOp, Vec<usize>),
Matmul(Tensor, Tensor),
Gather(Tensor, Tensor, usize),
Scatter(Tensor, Tensor, Tensor, usize),
ScatterAdd(Tensor, Tensor, Tensor, usize),
IndexSelect(Tensor, Tensor, usize),
IndexAdd(Tensor, Tensor, Tensor, usize),

View File

@ -1,6 +1,74 @@
//! TensorScalar Enum and Trait
//!
use crate::{Result, Tensor, WithDType};
use crate::{DType, Result, Tensor, WithDType};
use half::{bf16, f16};
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Scalar {
U8(u8),
U32(u32),
I64(i64),
BF16(bf16),
F16(f16),
F32(f32),
F64(f64),
}
impl<T: WithDType> From<T> for Scalar {
fn from(value: T) -> Self {
value.to_scalar()
}
}
impl Scalar {
pub fn zero(dtype: DType) -> Self {
match dtype {
DType::U8 => Scalar::U8(0),
DType::U32 => Scalar::U32(0),
DType::I64 => Scalar::I64(0),
DType::BF16 => Scalar::BF16(bf16::ZERO),
DType::F16 => Scalar::F16(f16::ZERO),
DType::F32 => Scalar::F32(0.0),
DType::F64 => Scalar::F64(0.0),
}
}
pub fn one(dtype: DType) -> Self {
match dtype {
DType::U8 => Scalar::U8(1),
DType::U32 => Scalar::U32(1),
DType::I64 => Scalar::I64(1),
DType::BF16 => Scalar::BF16(bf16::ONE),
DType::F16 => Scalar::F16(f16::ONE),
DType::F32 => Scalar::F32(1.0),
DType::F64 => Scalar::F64(1.0),
}
}
pub fn dtype(&self) -> DType {
match self {
Scalar::U8(_) => DType::U8,
Scalar::U32(_) => DType::U32,
Scalar::I64(_) => DType::I64,
Scalar::BF16(_) => DType::BF16,
Scalar::F16(_) => DType::F16,
Scalar::F32(_) => DType::F32,
Scalar::F64(_) => DType::F64,
}
}
pub fn to_f64(&self) -> f64 {
match self {
Scalar::U8(v) => *v as f64,
Scalar::U32(v) => *v as f64,
Scalar::I64(v) => *v as f64,
Scalar::BF16(v) => v.to_f64(),
Scalar::F16(v) => v.to_f64(),
Scalar::F32(v) => *v as f64,
Scalar::F64(v) => *v,
}
}
}
pub enum TensorScalar {
Tensor(Tensor),

View File

@ -1,5 +1,6 @@
use crate::backend::BackendStorage;
use crate::op::{self, CmpOp, ReduceOp};
use crate::scalar::Scalar;
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
use crate::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
@ -73,6 +74,14 @@ impl Storage {
}
}
pub(crate) fn const_set(&mut self, v: Scalar, l: &Layout) -> Result<()> {
match self {
Storage::Cpu(storage) => storage.const_set(v, l),
Storage::Cuda(storage) => storage.const_set(v, l),
Storage::Metal(storage) => storage.const_set(v, l),
}
}
pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
@ -619,32 +628,56 @@ impl Storage {
}
}
pub(crate) fn scatter_add(
&self,
pub(crate) fn scatter_set(
&mut self,
l: &Layout,
indexes: &Self,
indexes_l: &Layout,
source: &Self,
source_l: &Layout,
d: usize,
) -> Result<Self> {
) -> Result<()> {
self.same_device(indexes, "scatter-set")?;
self.same_device(source, "scatter-set")?;
match (self, indexes, source) {
(Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
}
(Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
}
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
}
_ => unreachable!(),
}
Ok(())
}
pub(crate) fn scatter_add(
&mut self,
l: &Layout,
indexes: &Self,
indexes_l: &Layout,
source: &Self,
source_l: &Layout,
d: usize,
) -> Result<()> {
self.same_device(indexes, "scatter-add")?;
self.same_device(source, "scatter-add")?;
match (self, indexes, source) {
(Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Cpu(storage))
s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
}
(Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Cuda(storage))
s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
}
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Metal(storage))
s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
}
_ => unreachable!(),
}
Ok(())
}
pub(crate) fn index_add(

View File

@ -3,7 +3,7 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BackpropOp, BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
use crate::scalar::TensorOrScalar;
use crate::shape::{Dim, Dims};
use crate::shape::{Dim, Dims, ShapeWithOneHole};
use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::{Arc, RwLock};
@ -185,7 +185,9 @@ impl Tensor {
) -> Result<Self> {
let none = BackpropOp::none();
let shape = shape.into();
let storage = device.ones(&shape, dtype)?;
let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
let layout = Layout::contiguous(shape.clone());
storage.const_set(crate::scalar::Scalar::one(dtype), &layout)?;
Ok(from_storage(storage, shape, none, is_variable))
}
@ -202,6 +204,18 @@ impl Tensor {
Self::ones_impl(shape, dtype, device, false)
}
pub fn const_set(&self, value: crate::scalar::Scalar) -> Result<()> {
self.storage_mut().const_set(value, self.layout())
}
pub fn zero_set(&self) -> Result<()> {
self.const_set(crate::scalar::Scalar::zero(self.dtype()))
}
pub fn one_set(&self) -> Result<()> {
self.const_set(crate::scalar::Scalar::one(self.dtype()))
}
/// Creates a new tensor filled with ones with same shape, dtype, and device as the other tensor.
///
/// ```rust
@ -368,8 +382,7 @@ impl Tensor {
Self::new_impl(array, shape, device, false)
}
/// Returns a new tensor with all the elements having the same specified value. Note that
/// the tensor is not contiguous so you would have to call `.contiguous()` on it if needed.
/// Returns a new tensor with all the elements having the same specified value.
///```rust
/// use candle_core::{Tensor, Device};
/// let a = Tensor::full(3.5, (2, 4), &Device::Cpu)?;
@ -384,7 +397,12 @@ impl Tensor {
shape: S,
device: &Device,
) -> Result<Self> {
Self::from_vec_impl(vec![value], (), device, false)?.broadcast_as(shape)
let none = BackpropOp::none();
let shape = shape.into();
let mut storage = unsafe { device.alloc_uninit(&shape, D::DTYPE)? };
let layout = Layout::contiguous(shape.clone());
storage.const_set(value.to_scalar(), &layout)?;
Ok(from_storage(storage, shape, none, false))
}
/// Creates a new 1D tensor from an iterator.
@ -452,17 +470,13 @@ impl Tensor {
Self::from_vec_impl(data, len, device, false)
}
pub(crate) fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>(
pub(crate) fn from_vec_impl<S: ShapeWithOneHole, D: crate::WithDType>(
data: Vec<D>,
shape: S,
device: &Device,
is_variable: bool,
) -> Result<Self> {
let shape = shape.into();
let buffer_size = data.len();
if buffer_size != shape.elem_count() {
return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
}
let shape = shape.into_shape(data.len())?;
let storage = device.storage_owned(data)?;
let none = BackpropOp::none();
Ok(from_storage(storage, shape, none, is_variable))
@ -481,7 +495,7 @@ impl Tensor {
/// ]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn from_vec<S: Into<Shape>, D: crate::WithDType>(
pub fn from_vec<S: ShapeWithOneHole, D: crate::WithDType>(
data: Vec<D>,
shape: S,
device: &Device,
@ -502,17 +516,12 @@ impl Tensor {
/// ]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
pub fn from_slice<S: ShapeWithOneHole, D: crate::WithDType>(
array: &[D],
shape: S,
device: &Device,
) -> Result<Self> {
let shape = shape.into();
let n: usize = shape.elem_count();
let buffer_size: usize = array.len();
if buffer_size != n {
return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
}
let shape = shape.into_shape(array.len())?;
let storage = device.storage_from_slice(array)?;
let none = BackpropOp::none();
Ok(from_storage(storage, shape, none, false))
@ -1349,8 +1358,7 @@ impl Tensor {
self.index_select(ids, 0)
}
pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "scatter-add")?;
fn scatter_checks(&self, indexes: &Self, source: &Self, dim: usize) -> Result<()> {
let source_dims = source.dims();
let self_dims = self.dims();
let mismatch = if source_dims.len() != self_dims.len() {
@ -1367,7 +1375,7 @@ impl Tensor {
};
if mismatch {
Err(Error::ShapeMismatchBinaryOp {
op: "scatter-add (self, src)",
op: "scatter (self, src)",
lhs: self.shape().clone(),
rhs: source.shape().clone(),
}
@ -1375,13 +1383,44 @@ impl Tensor {
}
if indexes.dims() != source.dims() {
Err(Error::ShapeMismatchBinaryOp {
op: "scatter-add (indexes, src)",
op: "scatter (indexes, src)",
lhs: indexes.shape().clone(),
rhs: source.shape().clone(),
}
.bt())?
}
let storage = self.storage().scatter_add(
Ok(())
}
pub fn scatter<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "scatter")?;
self.scatter_checks(indexes, source, dim)?;
let shape = self.shape();
let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
let layout = Layout::contiguous(shape);
storage.scatter_set(
&layout,
&indexes.storage(),
indexes.layout(),
&source.storage(),
source.layout(),
dim,
)?;
let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
Op::Scatter(t1, t2, t3, dim)
});
Ok(from_storage(storage, self.shape(), op, false))
}
pub fn scatter_set<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> {
if self.same_storage(source) {
crate::bail!("cannot use slice_set when self and src share their storage")
}
let dim = dim.to_index(self.shape(), "scatter-set")?;
self.scatter_checks(indexes, source, dim)?;
self.storage_mut().scatter_set(
self.layout(),
&indexes.storage(),
indexes.layout(),
@ -1389,12 +1428,48 @@ impl Tensor {
source.layout(),
dim,
)?;
Ok(())
}
pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "scatter-add")?;
self.scatter_checks(indexes, source, dim)?;
let shape = self.shape();
let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
let layout = Layout::contiguous(shape);
storage.scatter_add(
&layout,
&indexes.storage(),
indexes.layout(),
&source.storage(),
source.layout(),
dim,
)?;
let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
Op::ScatterAdd(t1, t2, t3, dim)
});
Ok(from_storage(storage, self.shape(), op, false))
}
pub fn scatter_add_set<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> {
if self.same_storage(source) {
crate::bail!("cannot use slice_set when self and src share their storage")
}
let dim = dim.to_index(self.shape(), "scatter-add-set")?;
self.scatter_checks(indexes, source, dim)?;
self.storage_mut().scatter_add(
self.layout(),
&indexes.storage(),
indexes.layout(),
&source.storage(),
source.layout(),
dim,
)?;
Ok(())
}
/// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension.
pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: D, start: usize) -> Result<Self> {
let dim = dim.to_index(self.shape(), "slice-scatter")?;
@ -2197,7 +2272,7 @@ impl Tensor {
///
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn reshape<S: crate::shape::ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
pub fn reshape<S: ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
let shape = s.into_shape(self.elem_count())?;
if shape.elem_count() != self.elem_count() {
return Err(Error::ShapeMismatchBinaryOp {

View File

@ -241,7 +241,7 @@ impl Tensor {
/// `self` and `src` must have the same shape except on dimension `dim` where the `self` size
/// has to be greater than or equal to `offset` plus the `src` size.
///
/// Note that this modifies `self` in place and as such is not compatibel with
/// Note that this modifies `self` in place and as such is not compatible with
/// back-propagation.
pub fn slice_set<D: Dim>(&self, src: &Self, dim: D, offset: usize) -> Result<()> {
let dim = dim.to_index(self.shape(), "slice-set")?;

View File

@ -25,10 +25,12 @@ fn ones(device: &Device) -> Result<()> {
Tensor::ones((2, 3), DType::F32, device)?.to_vec2::<f32>()?,
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
);
assert_eq!(
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
);
if !device.is_metal() {
assert_eq!(
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
);
}
assert_eq!(
Tensor::ones((2, 3), DType::F16, device)?.to_vec2::<half::f16>()?,
[
@ -63,6 +65,26 @@ fn ones(device: &Device) -> Result<()> {
}
fn full(device: &Device) -> Result<()> {
let tensor = Tensor::zeros((3, 4), DType::U32, device)?;
tensor.const_set(42u32.into())?;
assert_eq!(
tensor.to_vec2::<u32>()?,
[[42, 42, 42, 42], [42, 42, 42, 42], [42, 42, 42, 42]]
);
tensor.i((.., 2))?.const_set(1337u32.into())?;
assert_eq!(
tensor.to_vec2::<u32>()?,
[[42, 42, 1337, 42], [42, 42, 1337, 42], [42, 42, 1337, 42]]
);
tensor.i((2, ..))?.const_set(1u32.into())?;
assert_eq!(
tensor.to_vec2::<u32>()?,
[[42, 42, 1337, 42], [42, 42, 1337, 42], [1, 1, 1, 1]]
);
Ok(())
}
fn const_set(device: &Device) -> Result<()> {
assert_eq!(
Tensor::full(42u32, (2, 3), device)?.to_vec2::<u32>()?,
[[42, 42, 42], [42, 42, 42]],
@ -826,6 +848,31 @@ fn embeddings(device: &Device) -> Result<()> {
Ok(())
}
#[test]
fn index_select_fail() -> Result<()> {
// Check that an error is properly reported on out of bounds.
let ids = Tensor::new(&[4u32, 2u32, 1u32], &Device::Cpu)?;
let t = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], &Device::Cpu)?;
let hs = t.index_select(&ids, 0);
assert!(hs.is_err());
Ok(())
}
// The test below triggers an unwinding panic as there is a panic within the
// #[cfg(feature = "cuda")]
// #[test]
// #[should_panic]
// fn index_select_fail_gpu() {
// // Check that a panic happens for out of bounds in cuda
// if let Ok(device) = Device::new_cuda(0) {
// if let Ok(ids) = Tensor::new(&[4u32, 2u32, 1u32], &device) {
// if let Ok(t) = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], &device) {
// let _ = t.index_select(&ids, 0);
// }
// }
// }
// }
fn cmp(device: &Device) -> Result<()> {
let t1 = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], device)?;
let t2 = Tensor::new(&[[1f32, 0f32], [3f32, 3f32], [4f32, 7f32]], device)?;
@ -980,7 +1027,7 @@ fn slice_scatter(device: &Device) -> Result<()> {
Ok(())
}
fn scatter_add(device: &Device) -> Result<()> {
fn scatter(device: &Device) -> Result<()> {
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
assert_eq!(
t.to_vec2::<f32>()?,
@ -1004,6 +1051,17 @@ fn scatter_add(device: &Device) -> Result<()> {
]
);
let hs = init.scatter(&ids, &t, 1)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[0.0, 1.0, 2.0, 1.0, 1.0],
[5.0, 1.0, 1.0, 3.0, 4.0],
[1.0, 8.0, 1.0, 7.0, 1.0],
[10.0, 1.0, 9.0, 1.0, 11.0]
]
);
let init = Tensor::ones((6, 3), DType::F32, device)?;
let hs = init.scatter_add(&ids, &t, 0)?;
assert_eq!(
@ -1017,6 +1075,30 @@ fn scatter_add(device: &Device) -> Result<()> {
[1.0, 1.0, 1.0]
]
);
let hs = init.scatter(&ids, &t, 0)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[0.0, 10.0, 5.0],
[1.0, 1.0, 8.0],
[9.0, 1.0, 2.0],
[6.0, 7.0, 1.0],
[1.0, 4.0, 11.0],
[1.0, 1.0, 1.0]
]
);
init.scatter_set(&ids, &t, 0)?;
assert_eq!(
init.to_vec2::<f32>()?,
&[
[0.0, 10.0, 5.0],
[1.0, 1.0, 8.0],
[9.0, 1.0, 2.0],
[6.0, 7.0, 1.0],
[1.0, 4.0, 11.0],
[1.0, 1.0, 1.0]
]
);
Ok(())
}
@ -1484,6 +1566,7 @@ fn zero_dim(device: &Device) -> Result<()> {
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
test_device!(full, full_cpu, full_gpu, full_metal);
test_device!(const_set, cs_cpu, cs_gpu, cs_metal);
test_device!(arange, arange_cpu, arange_gpu, arange_metal);
test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
@ -1515,12 +1598,7 @@ test_device!(
);
test_device!(index_add, index_add_cpu, index_add_gpu, index_add_metal);
test_device!(gather, gather_cpu, gather_gpu, gather_metal);
test_device!(
scatter_add,
scatter_add_cpu,
scatter_add_gpu,
scatter_add_metal
);
test_device!(scatter, scatter_cpu, scatter_gpu, scatter_metal);
test_device!(
slice_scatter,
slice_scatter_cpu,
@ -1733,3 +1811,26 @@ fn test_flip_3d_channels() -> Result<()> {
candle_core::test_utils::assert_tensor_eq(&flipped, &expected)?;
Ok(())
}
#[test]
fn tensor_new() -> Result<()> {
let t1 = Tensor::new(vec![1f32, 2.0, 3.0], &Device::Cpu)?;
assert_eq!(t1.to_vec1::<f32>()?, [1.0, 2.0, 3.0]);
let t2 = Tensor::new(vec![vec![1f32, 2., 3.], vec![4., 5., 6.]], &Device::Cpu)?;
assert_eq!(t2.to_vec2::<f32>()?, [[1., 2., 3.], [4., 5., 6.]]);
let t3 = Tensor::new(
vec![
vec![vec![1f32, 2., 3.], vec![4., 5., 6.]],
vec![vec![3f32, 1., 4.], vec![1., 5., 9.]],
],
&Device::Cpu,
)?;
assert_eq!(
t3.to_vec3::<f32>()?,
[
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
[[3.0, 1.0, 4.0], [1.0, 5.0, 9.0]]
]
);
Ok(())
}

View File

@ -124,6 +124,17 @@ impl TextGeneration {
Some(token) => token,
None => anyhow::bail!("cannot find the <eos> token"),
};
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
Some(token) => token,
None => {
println!(
"Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup"
);
eos_token
}
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
@ -146,7 +157,7 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
if next_token == eos_token || next_token == eot_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
@ -350,6 +361,31 @@ fn main() -> Result<()> {
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
let prompt = match args.which {
Which::Base2B
| Which::Base7B
| Which::Instruct2B
| Which::Instruct7B
| Which::InstructV1_1_2B
| Which::InstructV1_1_7B
| Which::CodeBase2B
| Which::CodeBase7B
| Which::CodeInstruct2B
| Which::CodeInstruct7B
| Which::BaseV2_2B
| Which::InstructV2_2B
| Which::BaseV2_9B
| Which::InstructV2_9B
| Which::BaseV3_1B => args.prompt,
Which::InstructV3_1B => {
format!(
"<start_of_turn> user\n{}<end_of_turn>\n<start_of_turn> model\n",
args.prompt
)
}
};
pipeline.run(&prompt, args.sample_len)?;
Ok(())
}

View File

@ -7,7 +7,10 @@ extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::helium::{Config, Model};
use candle_transformers::models::helium::{Config as ConfigPreview, Model as ModelPreview};
use candle_transformers::models::llama::{
Cache as CacheV1, Llama as ModelV1, LlamaConfig as ConfigV1, LlamaEosToks,
};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
@ -16,6 +19,44 @@ use candle_transformers::generation::{LogitsProcessor, Sampling};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
#[derive(Debug, Clone)]
enum Model {
V1 { model: ModelV1, cache: CacheV1 },
Preview(ModelPreview),
}
impl Model {
fn forward(&mut self, input: &Tensor, start_pos: usize) -> Result<Tensor> {
let model = match self {
Model::V1 { model, cache } => model.forward(input, start_pos, cache)?,
Model::Preview(m) => m.forward(input, start_pos)?,
};
Ok(model)
}
}
#[derive(Debug, Clone)]
enum Config {
V1(ConfigV1),
Preview(ConfigPreview),
}
impl Config {
fn bos_token_id(&self) -> Option<u32> {
match self {
Config::V1(c) => c.bos_token_id,
Config::Preview(c) => Some(c.bos_token_id),
}
}
fn eos_token_id(&self) -> Option<LlamaEosToks> {
match self {
Config::V1(c) => c.eos_token_id.clone(),
Config::Preview(c) => Some(LlamaEosToks::Single(c.eos_token_id)),
}
}
}
struct TextGeneration {
model: Model,
device: Device,
@ -106,7 +147,15 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == self.config.bos_token_id || next_token == self.config.eos_token_id {
let is_eos = self
.config
.eos_token_id()
.as_ref()
.is_some_and(|v| match v {
LlamaEosToks::Single(eos) => *eos == next_token,
LlamaEosToks::Multiple(eos) => eos.contains(&next_token),
});
if Some(next_token) == self.config.bos_token_id() || is_eos {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
@ -131,6 +180,8 @@ impl TextGeneration {
enum Which {
#[value(name = "v1-preview")]
V1Preview,
#[value(name = "v1")]
V1,
}
#[derive(Parser, Debug)]
@ -144,9 +195,6 @@ struct Args {
#[arg(long)]
tracing: bool,
#[arg(long)]
use_flash_attn: bool,
#[arg(long)]
prompt: String,
@ -171,7 +219,7 @@ struct Args {
sample_len: usize,
/// The model size to use.
#[arg(long, default_value = "v1-preview")]
#[arg(long, default_value = "v1")]
which: Which,
#[arg(long)]
@ -230,6 +278,7 @@ fn main() -> Result<()> {
None => {
let name = match args.which {
Which::V1Preview => "kyutai/helium-1-preview-2b",
Which::V1 => "kyutai/helium-1-2b",
};
name.to_string()
}
@ -254,18 +303,27 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config: Config = match args.config {
Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?,
None => {
let config_file = repo.get("config.json")?;
serde_json::from_slice(&std::fs::read(config_file)?)?
}
let config_file = match args.config {
Some(config_file) => std::path::PathBuf::from(config_file),
None => repo.get("config.json")?,
};
let config = match args.which {
Which::V1Preview => Config::Preview(serde_json::from_slice(&std::fs::read(config_file)?)?),
Which::V1 => Config::V1(serde_json::from_slice(&std::fs::read(config_file)?)?),
};
let device = candle_examples::device(args.cpu)?;
let (model, device) = {
let dtype = device.bf16_default_to_f32();
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
let model = match &config {
Config::V1(c) => {
let c = c.clone().into_config(false);
let model = ModelV1::load(vb, &c)?;
let cache = CacheV1::new(true, dtype, &c, &device)?;
Model::V1 { model, cache }
}
Config::Preview(c) => Model::Preview(ModelPreview::new(c, vb)?),
};
(model, device)
};

View File

@ -0,0 +1,18 @@
# candle-quantized-gemma
Candle implementation of quantized Gemma.
## Running an example
```bash
$ cargo run --example quantized-gemma -- --prompt "Write a function to calculate fibonacci numbers. "
> ```python
> def fibonacci(n):
> """Calculates the nth Fibonacci number using recursion."""
> if n <= 1:
> return n
> else:
> return fibonacci(n-1) + fibonacci(n-2
> ```
```

View File

@ -0,0 +1,344 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use std::io::Write;
use tokenizers::Tokenizer;
use candle::quantized::gguf_file;
use candle::Tensor;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_transformers::models::quantized_gemma3::ModelWeights;
const DEFAULT_PROMPT: &str = "Write a function to calculate fibonacci num";
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
#[value(name = "gemma3-4b-it")]
Gemma3_4bIt,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// GGUF file to load, typically a .gguf file generated by quantization
#[arg(long)]
model: Option<String>,
/// The initial prompt, use 'interactive' for entering multiple prompts in an interactive way
/// and 'chat' for an interactive model where history of previous prompts and generated tokens
/// is preserved.
#[arg(long)]
prompt: Option<String>,
/// The length of the sample to generate (in tokens).
#[arg(short = 'n', long, default_value_t = 1000)]
sample_len: usize,
/// The tokenizer config in json format.
#[arg(long)]
tokenizer: Option<String>,
/// The temperature used to generate samples, use 0 for greedy sampling.
#[arg(long, default_value_t = 0.8)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// Only sample among the top K samples.
#[arg(long)]
top_k: Option<usize>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Process prompt elements separately.
#[arg(long)]
split_prompt: bool,
/// Run on CPU rather than GPU even if a GPU is available.
#[arg(long)]
cpu: bool,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
/// The model size to use.
#[arg(long, default_value = "gemma3-4b-it")]
which: Which,
}
impl Args {
fn tokenizer(&self) -> anyhow::Result<Tokenizer> {
let tokenizer_path = match &self.tokenizer {
Some(config) => std::path::PathBuf::from(config),
None => {
let api = hf_hub::api::sync::Api::new()?;
let repo = "google/gemma-3-4b-it";
println!("DEBUG: Downloading tokenizer from {}", repo);
let api = api.model(repo.to_string());
api.get("tokenizer.json")?
}
};
println!("DEBUG: Loading tokenizer from {:?}", tokenizer_path);
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)?;
Ok(tokenizer)
}
fn model(&self) -> anyhow::Result<std::path::PathBuf> {
let model_path = match &self.model {
Some(config) => std::path::PathBuf::from(config),
None => {
let (repo, filename) = match self.which {
Which::Gemma3_4bIt => (
"google/gemma-3-4b-it-qat-q4_0-gguf",
"gemma-3-4b-it-q4_0.gguf",
),
};
let api = hf_hub::api::sync::Api::new()?;
api.repo(hf_hub::Repo::with_revision(
repo.to_string(),
hf_hub::RepoType::Model,
"main".to_string(),
))
.get(filename)?
}
};
Ok(model_path)
}
}
fn format_size(size_in_bytes: usize) -> String {
if size_in_bytes < 1_000 {
format!("{}B", size_in_bytes)
} else if size_in_bytes < 1_000_000 {
format!("{:.2}KB", size_in_bytes as f64 / 1e3)
} else if size_in_bytes < 1_000_000_000 {
format!("{:.2}MB", size_in_bytes as f64 / 1e6)
} else {
format!("{:.2}GB", size_in_bytes as f64 / 1e9)
}
}
#[derive(Debug)]
enum Prompt {
Interactive,
Chat,
One(String),
}
fn main() -> anyhow::Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature, args.repeat_penalty, args.repeat_last_n
);
let model_path = args.model()?;
let mut file = std::fs::File::open(&model_path)?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let mut model = {
let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(&model_path))?;
let mut total_size_in_bytes = 0;
for (_, tensor) in model.tensor_infos.iter() {
let elem_count = tensor.shape.elem_count();
total_size_in_bytes +=
elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
}
println!(
"loaded {:?} tensors ({}) in {:.2}s",
model.tensor_infos.len(),
&format_size(total_size_in_bytes),
start.elapsed().as_secs_f32(),
);
ModelWeights::from_gguf(model, &mut file, &device)?
};
println!("model built");
let tokenizer = args.tokenizer()?;
let mut tos = TokenOutputStream::new(tokenizer);
println!(
"DEBUG: Tokenizer vocabulary size: {}",
tos.tokenizer().get_vocab(true).len()
);
let prompt = match args.prompt.as_deref() {
Some("chat") => Prompt::Chat,
Some("interactive") => Prompt::Interactive,
Some(s) => Prompt::One(s.to_string()),
None => Prompt::One(DEFAULT_PROMPT.to_string()),
};
let mut pre_prompt_tokens = vec![];
for _ in 0.. {
let prompt_str = match &prompt {
Prompt::One(prompt) => prompt.clone(),
Prompt::Interactive | Prompt::Chat => {
print!("> ");
std::io::stdout().flush()?;
let mut prompt = String::new();
std::io::stdin().read_line(&mut prompt)?;
if prompt.ends_with('\n') {
prompt.pop();
if prompt.ends_with('\r') {
prompt.pop();
}
}
// Format for Gemma 3 chat/instruction format
format!("<start_of_turn> user\n{prompt}<end_of_turn>\n<start_of_turn> model\n")
}
};
print!("{}", &prompt_str);
let tokens = tos
.tokenizer()
.encode(prompt_str, true)
.map_err(anyhow::Error::msg)?;
let prompt_tokens = [&pre_prompt_tokens, tokens.get_ids()].concat();
let to_sample = args.sample_len.saturating_sub(1);
let max_seq_len = 8192; // Gemma 3 context length
let prompt_tokens = if prompt_tokens.len() + to_sample > max_seq_len - 10 {
let to_remove = prompt_tokens.len() + to_sample + 10 - max_seq_len;
prompt_tokens[prompt_tokens.len().saturating_sub(to_remove)..].to_vec()
} else {
prompt_tokens
};
let mut all_tokens = vec![];
let mut logits_processor = {
let temperature = args.temperature;
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (args.top_k, args.top_p) {
(None, None) => Sampling::All { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(args.seed, sampling)
};
let start_prompt_processing = std::time::Instant::now();
let mut next_token = if !args.split_prompt {
let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
let logits = model.forward(&input, 0)?;
let logits = logits.squeeze(0)?;
logits_processor.sample(&logits)?
} else {
let mut next_token = 0;
for (pos, token) in prompt_tokens.iter().enumerate() {
let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
let logits = model.forward(&input, pos)?;
let logits = logits.squeeze(0)?;
next_token = logits_processor.sample(&logits)?
}
next_token
};
let prompt_dt = start_prompt_processing.elapsed();
all_tokens.push(next_token);
if let Some(t) = tos.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
// For Gemma 3, use the correct end of sequence token
let eos_token = *tos
.tokenizer()
.get_vocab(true)
.get("<end_of_turn>")
.unwrap();
let start_post_prompt = std::time::Instant::now();
let mut sampled = 0;
for index in 0..to_sample {
let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
let logits = model.forward(&input, prompt_tokens.len() + index)?;
let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
logits
} else {
let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
args.repeat_penalty,
&all_tokens[start_at..],
)?
};
next_token = logits_processor.sample(&logits)?;
all_tokens.push(next_token);
if let Some(t) = tos.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
sampled += 1;
if next_token == eos_token {
break;
};
}
if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
let dt = start_post_prompt.elapsed();
println!(
"\n\n{:4} prompt tokens processed: {:.2} token/s",
prompt_tokens.len(),
prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
);
println!(
"{sampled:4} tokens generated: {:.2} token/s",
sampled as f64 / dt.as_secs_f64(),
);
match prompt {
Prompt::One(_) => break,
Prompt::Interactive => {}
Prompt::Chat => {
pre_prompt_tokens = [prompt_tokens.as_slice(), all_tokens.as_slice()].concat()
}
}
}
Ok(())
}

View File

@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.9.0-alpha.4"
version = "0.9.1"
edition = "2021"
description = "Flash attention layer for the candle ML framework."
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"
[dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0-alpha.4" }
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.1" }
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]

View File

@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.9.0-alpha.4"
version = "0.9.1"
edition = "2021"
description = "CUDA kernels for Candle"

View File

@ -1,5 +1,6 @@
#include<stdint.h>
#include "cuda_fp16.h"
#include "cuda_utils.cuh"
template<typename T>
__device__ void fill_with(T *buf, T value, const size_t numel) {
@ -36,13 +37,45 @@ COPY2D_OP(uint8_t, copy2d_u8)
COPY2D_OP(uint32_t, copy2d_u32)
COPY2D_OP(int64_t, copy2d_i64)
#define CONST_SET_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t numel, \
const size_t num_dims, \
const size_t *info, \
const TYPENAME inp, \
TYPENAME *out \
) { \
const size_t *dims = info; \
const size_t *strides = info + num_dims; \
if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
out[i] = inp; \
} \
} \
else { \
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
out[strided_i] = inp; \
} \
} \
} \
CONST_SET_OP(float, const_set_f32)
CONST_SET_OP(double, const_set_f64)
CONST_SET_OP(uint8_t, const_set_u8)
CONST_SET_OP(uint32_t, const_set_u32)
CONST_SET_OP(int64_t, const_set_i64)
#if __CUDA_ARCH__ >= 530
extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
COPY2D_OP(__half, copy2d_f16)
CONST_SET_OP(__half, const_set_f16)
#endif
#if __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
extern "C" __global__ void fill_bf16(__nv_bfloat16 *buf, __nv_bfloat16 value, const size_t numel) { fill_with(buf, value, numel); }
COPY2D_OP(__nv_bfloat16, copy2d_bf16)
CONST_SET_OP(__nv_bfloat16, const_set_bf16)
#endif

View File

@ -23,6 +23,7 @@ __device__ void index_select(
unsigned int left_i = dst_i / (ids_dim_size * right_size);
unsigned int id_i = dst_i / right_size % ids_dim_size;
unsigned int right_i = dst_i % right_size;
assert(ids[id_i] < src_dim_size);
unsigned int src_i = left_i * (src_dim_size * right_size) + ids[id_i] * right_size + right_i;
unsigned strided_i = b ? src_i : get_strided_index(src_i, num_dims, dims, strides);
out[dst_i] = inp[strided_i];
@ -57,6 +58,7 @@ __device__ void gather(
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
size_t post = i % right_size;
size_t idx = ids[i];
assert(idx < src_dim_size);
size_t pre = i / (right_size * ids_dim_size);
size_t src_i = (pre * src_dim_size + idx) * right_size + post;
out[i] = inp[src_i];
@ -92,6 +94,7 @@ __device__ void index_add(
const size_t post = i % right_size;
for (unsigned int j = 0; j < ids_dim_size; ++j) {
const size_t idx = ids[j];
assert(idx < dst_dim_size);
const size_t src_i = (pre * ids_dim_size + j) * right_size + post;
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
out[dst_i] += inp[src_i];
@ -111,6 +114,30 @@ extern "C" __global__ void FN_NAME( \
const size_t right_size \
) { index_add(ids, ids_dim_size, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \
template<typename T, typename I>
__device__ void scatter(
const I *ids,
const T *inp,
T *out,
const size_t left_size,
const size_t src_dim_size,
const size_t dst_dim_size,
const size_t right_size
) {
const size_t numel = left_size * right_size;
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
const size_t pre = i / right_size;
const size_t post = i % right_size;
for (unsigned int j = 0; j < src_dim_size; ++j) {
const size_t src_i = (pre * src_dim_size + j) * right_size + post;
const size_t idx = ids[src_i];
assert(idx < dst_dim_size);
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
out[dst_i] = inp[src_i];
}
}
}
template<typename T, typename I>
__device__ void scatter_add(
const I *ids,
@ -128,12 +155,24 @@ __device__ void scatter_add(
for (unsigned int j = 0; j < src_dim_size; ++j) {
const size_t src_i = (pre * src_dim_size + j) * right_size + post;
const size_t idx = ids[src_i];
assert(idx < dst_dim_size);
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
out[dst_i] += inp[src_i];
}
}
}
#define S_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const INDEX_TYPENAME *ids, \
const TYPENAME *inp, \
TYPENAME *out, \
const size_t left_size, \
const size_t src_dim_size, \
const size_t dst_dim_size, \
const size_t right_size \
) { scatter(ids, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \
#define SA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const INDEX_TYPENAME *ids, \
@ -159,6 +198,9 @@ IA_OP(__nv_bfloat16, uint8_t, ia_u8_bf16)
SA_OP(__nv_bfloat16, int64_t, sa_i64_bf16)
SA_OP(__nv_bfloat16, uint32_t, sa_u32_bf16)
SA_OP(__nv_bfloat16, uint8_t, sa_u8_bf16)
S_OP(__nv_bfloat16, int64_t, s_i64_bf16)
S_OP(__nv_bfloat16, uint32_t, s_u32_bf16)
S_OP(__nv_bfloat16, uint8_t, s_u8_bf16)
#endif
#if __CUDA_ARCH__ >= 530
@ -174,6 +216,9 @@ IA_OP(__half, uint8_t, ia_u8_f16)
SA_OP(__half, int64_t, sa_i64_f16)
SA_OP(__half, uint32_t, sa_u32_f16)
SA_OP(__half, uint8_t, sa_u8_f16)
S_OP(__half, int64_t, s_i64_f16)
S_OP(__half, uint32_t, s_u32_f16)
S_OP(__half, uint8_t, s_u8_f16)
#endif
IS_OP(float, int64_t, is_i64_f32)
@ -247,3 +292,21 @@ SA_OP(double, uint8_t, sa_u8_f64)
SA_OP(uint8_t, uint8_t, sa_u8_u8)
SA_OP(uint32_t, uint8_t, sa_u8_u32)
SA_OP(int64_t, uint8_t, sa_u8_i64)
S_OP(float, int64_t, s_i64_f32)
S_OP(double, int64_t, s_i64_f64)
S_OP(uint8_t, int64_t, s_i64_u8)
S_OP(int64_t, int64_t, s_i64_i64)
S_OP(uint32_t, int64_t, s_i64_u32)
S_OP(float, uint32_t, s_u32_f32)
S_OP(double, uint32_t, s_u32_f64)
S_OP(uint8_t, uint32_t, s_u32_u8)
S_OP(int64_t, uint32_t, s_u32_i64)
S_OP(uint32_t, uint32_t, s_u32_u32)
S_OP(float, uint8_t, s_u8_f32)
S_OP(double, uint8_t, s_u8_f64)
S_OP(uint8_t, uint8_t, s_u8_u8)
S_OP(uint32_t, uint8_t, s_u8_u32)
S_OP(int64_t, uint8_t, s_u8_i64)

View File

@ -219,11 +219,15 @@ __device__ void softmax(const T * x, T * dst, const int ncols) {
}
template <typename T>
__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td) {
__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t stride_b) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (2 * idx >= bh * td) return;
uint32_t rope_idx = idx % (td / 2);
if (stride_b > 0) {
uint32_t b_idx = (2 * idx) / stride_b;
rope_idx += b_idx * (td / 2);
}
T c = cos[rope_idx];
T s = sin[rope_idx];
@ -232,7 +236,7 @@ __device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, cons
}
template <typename T>
__device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t d) {
__device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t d, const uint32_t stride_b) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (2 * idx >= bh * td) return;
@ -243,6 +247,10 @@ __device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const
uint32_t i1 = i_bh * td + i_t * d + i_d;
uint32_t i2 = i1 + d / 2;
uint32_t i_cs = i_t * (d / 2) + i_d;
if (stride_b > 0) {
uint32_t b_idx = (2 * idx) / stride_b;
i_cs += b_idx * (td / 2);
}
T c = cos[i_cs];
T s = sin[i_cs];
@ -259,7 +267,8 @@ __device__ void rope_thd(
const uint32_t b,
const uint32_t t,
const uint32_t h,
const uint32_t d
const uint32_t d,
const uint32_t stride_b
) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (2 * idx >= b * t * h * d) return;
@ -270,6 +279,10 @@ __device__ void rope_thd(
uint32_t i1 = i_bth * d + i_d;
uint32_t i2 = i1 + d / 2;
uint32_t i_cs = i_t * (d / 2) + i_d;
if (stride_b > 0) {
uint32_t b_idx = (2 * idx) / stride_b;
i_cs += b_idx * ((t * d) / 2);
}
T c = cos[i_cs];
T s = sin[i_cs];
@ -546,8 +559,9 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
const TYPENAME *sin, \
TYPENAME *dst, \
const uint32_t bh, \
const uint32_t td) { \
ropei<TYPENAME>(src, cos, sin, dst, bh, td); \
const uint32_t td, \
const uint32_t stride_b) { \
ropei<TYPENAME>(src, cos, sin, dst, bh, td, stride_b); \
} \
extern "C" __global__ void FN_NAME( \
const TYPENAME *src, \
@ -556,8 +570,9 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
TYPENAME *dst, \
const uint32_t bh, \
const uint32_t td, \
const uint32_t d) { \
rope<TYPENAME>(src, cos, sin, dst, bh, td, d); \
const uint32_t d, \
const uint32_t stride_b) { \
rope<TYPENAME>(src, cos, sin, dst, bh, td, d, stride_b); \
} \
extern "C" __global__ void FN_NAME_THD( \
const TYPENAME *src, \
@ -567,8 +582,9 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
const uint32_t b, \
const uint32_t t, \
const uint32_t h, \
const uint32_t d) { \
rope_thd<TYPENAME>(src, cos, sin, dst, b, t, h, d); \
const uint32_t d, \
const uint32_t stride_b) { \
rope_thd<TYPENAME>(src, cos, sin, dst, b, t, h, d, stride_b); \
} \
#if __CUDA_ARCH__ >= 800

View File

@ -1,6 +1,6 @@
[package]
name = "candle-metal-kernels"
version = "0.9.0-alpha.4"
version = "0.9.1"
edition = "2021"
description = "Metal kernels for Candle"
@ -12,6 +12,7 @@ license = "MIT OR Apache-2.0"
[dependencies]
metal = { version = "0.27.0", features = ["mps"] }
half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_distr"] }
once_cell = "1.18.0"
thiserror = "1"
tracing = "0.1.37"

View File

@ -4,20 +4,20 @@ using namespace metal;
template<typename T> METAL_FUNC void fill_with(
device T *out,
constant float &value,
constant T &value,
constant size_t &numel,
uint tid [[thread_position_in_grid]]
) {
if (tid >= numel) {
return;
}
out[tid] = static_cast<T>(value);
out[tid] = value;
}
#define FILL_OP(NAME, T) \
kernel void fill_##NAME( \
device T *out, \
constant float &value, \
constant T &value, \
constant size_t &numel, \
uint tid [[thread_position_in_grid]] \
) { \

View File

@ -104,6 +104,31 @@ kernel void NAME( \
gather<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \
}
template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void scatter(
constant size_t &dst_size,
constant size_t &left_size,
constant size_t &src_dim_size,
constant size_t &right_size,
constant size_t &dst_dim_size,
const device TYPENAME *input,
const device INDEX_TYPENAME *input_ids,
device TYPENAME *output,
uint tid [[ thread_position_in_grid ]]
) {
if (tid >= dst_size) {
return;
}
const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size;
for (unsigned int j = 0; j < src_dim_size; ++j) {
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
const INDEX_TYPENAME idx = input_ids[src_i];
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
output[dst_i] = input[src_i];
}
}
template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void scatter_add(
constant size_t &dst_size,
@ -129,6 +154,21 @@ METAL_FUNC void scatter_add(
}
}
# define SCATTER_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
constant size_t &dst_size, \
constant size_t &left_size, \
constant size_t &src_dim_size, \
constant size_t &right_size, \
constant size_t &dst_dim_size, \
const device TYPENAME *input, \
const device INDEX_TYPENAME *input_ids, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
scatter<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, dst_dim_size, input, input_ids, output, tid); \
}
# define SCATTER_ADD_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
constant size_t &dst_size, \
@ -235,6 +275,19 @@ SCATTER_ADD_OP(sa_u8_bf16, uint8_t, bfloat)
SCATTER_ADD_OP(sa_i64_bf16, int64_t, bfloat)
#endif
SCATTER_OP(s_u32_f32, uint32_t, float)
SCATTER_OP(s_u8_f32, uint8_t, float)
SCATTER_OP(s_i64_f32, int64_t, float)
SCATTER_OP(s_u32_u32, uint32_t, uint32_t)
SCATTER_OP(s_u32_f16, uint32_t, half)
SCATTER_OP(s_u8_f16, uint8_t, half)
SCATTER_OP(s_i64_f16, int64_t, half)
#if defined(__HAVE_BFLOAT__)
SCATTER_OP(s_u32_bf16, uint32_t, bfloat)
SCATTER_OP(s_u8_bf16, uint8_t, bfloat)
SCATTER_OP(s_i64_bf16, int64_t, bfloat)
#endif
// i64
INDEX_ADD_OP(ia_i64_f16, int64_t, half)
INDEX_ADD_OP(ia_i64_f32, int64_t, float)

View File

@ -161,7 +161,7 @@ macro_rules! ops{
pub mod unary {
ops!(
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
tanh, recip, silu, sign, sigmoid
tanh, recip, silu, sign, sigmoid, const_set
);
}
pub mod binary {
@ -419,6 +419,82 @@ pub fn call_copy2d(
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_const_set_contiguous_tiled(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: unary::contiguous_tiled::Kernel,
length: usize,
input: impl EncoderParam,
output: BufferOffset,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
let tile_size = 2;
let tiles = length.div_ceil(tile_size);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, input, &output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, tiles);
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_const_set_contiguous(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: unary::contiguous::Kernel,
length: usize,
input: impl EncoderParam,
output: BufferOffset,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, input, &output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_const_set_strided(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
name: unary::strided::Kernel,
shape: &[usize],
input: impl EncoderParam,
strides: &[usize],
output: BufferOffset,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
let length: usize = shape.iter().product();
let num_dims: usize = shape.len();
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, num_dims, shape, strides, input, &output));
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_unary_contiguous_tiled(
device: &Device,
@ -915,6 +991,7 @@ pub fn call_rope_i(
kernel_name: &'static str,
bh: usize,
td: usize,
stride_b: usize,
src: &Buffer,
src_offset: usize,
cos: &Buffer,
@ -933,6 +1010,7 @@ pub fn call_rope_i(
(
bh,
td,
stride_b,
(src, src_offset),
(cos, cos_offset),
(sin, sin_offset),
@ -958,6 +1036,7 @@ pub fn call_rope_thd(
t: usize,
h: usize,
d: usize,
stride_b: usize,
src: &Buffer,
src_offset: usize,
cos: &Buffer,
@ -978,6 +1057,7 @@ pub fn call_rope_thd(
t,
h,
d,
stride_b,
(src, src_offset),
(cos, cos_offset),
(sin, sin_offset),
@ -1002,6 +1082,7 @@ pub fn call_rope(
bh: usize,
td: usize,
d: usize,
stride_b: usize,
src: &Buffer,
src_offset: usize,
cos: &Buffer,
@ -1021,6 +1102,7 @@ pub fn call_rope(
bh,
td,
d,
stride_b,
(src, src_offset),
(cos, cos_offset),
(sin, sin_offset),
@ -1371,7 +1453,7 @@ pub fn call_gather(
}
#[allow(clippy::too_many_arguments)]
pub fn call_scatter_add(
pub fn call_scatter(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
@ -1381,7 +1463,7 @@ pub fn call_scatter_add(
dim: usize,
input: BufferOffset,
ids: BufferOffset,
output: &Buffer,
output: BufferOffset,
) -> Result<(), MetalKernelError> {
let left_size: usize = src_shape[..dim].iter().product();
let right_size: usize = src_shape[dim + 1..].iter().product();
@ -1406,7 +1488,7 @@ pub fn call_scatter_add(
dst_dim_size,
&input,
&ids,
output
&output
)
);
@ -1414,7 +1496,7 @@ pub fn call_scatter_add(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
Ok(())
}
@ -2570,7 +2652,7 @@ pub fn call_const_fill(
name: &'static str,
length: usize,
output: &Buffer,
v: f32,
v: impl EncoderParam,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Fill, name)?;
let encoder = ep.encoder();

View File

@ -1097,6 +1097,7 @@ template<typename T>
METAL_FUNC void ropei(
constant size_t &bh,
constant size_t &td,
constant size_t &stride_b,
device const T *src,
device const T *cos,
device const T *sin,
@ -1107,6 +1108,10 @@ METAL_FUNC void ropei(
return;
}
size_t rope_idx = tid % (td / 2);
if (stride_b > 0) {
size_t b_idx = (2 * tid) / stride_b;
rope_idx += b_idx * (td / 2);
}
T c = cos[rope_idx];
T s = sin[rope_idx];
dst[2 * tid] = src[2 * tid] * c - src[2 * tid + 1] * s;
@ -1118,6 +1123,7 @@ METAL_FUNC void rope(
constant size_t &bh,
constant size_t &td,
constant size_t &d,
constant size_t &stride_b,
device const T *src,
device const T *cos,
device const T *sin,
@ -1134,6 +1140,10 @@ METAL_FUNC void rope(
size_t i1 = i_bh * td + i_t * d + i_d;
size_t i2 = i1 + d / 2;
size_t i_cs = i_t * (d / 2) + i_d;
if (stride_b > 0) {
size_t b_idx = (2 * idx) / stride_b;
i_cs += b_idx * (td / 2);
}
T c = cos[i_cs];
T s = sin[i_cs];
dst[i1] = src[i1] * c - src[i2] * s;
@ -1146,6 +1156,7 @@ METAL_FUNC void rope_thd(
constant size_t &t,
constant size_t &h,
constant size_t &d,
constant size_t &stride_b,
device const T *src,
device const T *cos,
device const T *sin,
@ -1160,8 +1171,12 @@ METAL_FUNC void rope_thd(
const size_t i_t = (i_bth / h) % t;
const size_t i1 = i_bth * d + i_d;
const size_t i2 = i1 + d / 2;
const size_t i_cs = i_t * (d / 2) + i_d;
T c = cos[i_cs];
size_t i_cs = i_t * (d / 2) + i_d;
if (stride_b > 0) {
const size_t b_idx = (2 * idx) / stride_b;
i_cs += b_idx * ((t * d) / 2);
}
T c = cos[i_cs];
T s = sin[i_cs];
dst[i1] = src[i1] * c - src[i2] * s;
dst[i2] = src[i1] * s + src[i2] * c;
@ -1171,38 +1186,41 @@ METAL_FUNC void rope_thd(
kernel void FN_NAME_I( \
constant size_t &bh, \
constant size_t &td, \
constant size_t &stride_b, \
device const TYPENAME *src, \
device const TYPENAME *cos, \
device const TYPENAME *sin, \
device TYPENAME *dst, \
uint tid [[ thread_position_in_grid ]] \
) { \
ropei<TYPENAME>(bh, td, src, cos, sin, dst, tid); \
ropei<TYPENAME>(bh, td, stride_b, src, cos, sin, dst, tid); \
}\
kernel void FN_NAME( \
constant size_t &bh, \
constant size_t &td, \
constant size_t &d, \
constant size_t &stride_b, \
device const TYPENAME *src, \
device const TYPENAME *cos, \
device const TYPENAME *sin, \
device TYPENAME *dst, \
uint idx [[ thread_position_in_grid ]] \
) { \
rope<TYPENAME>(bh, td, d, src, cos, sin, dst, idx); \
rope<TYPENAME>(bh, td, d, stride_b, src, cos, sin, dst, idx); \
}\
kernel void FN_NAME_THD( \
constant size_t &b, \
constant size_t &t, \
constant size_t &h, \
constant size_t &d, \
constant size_t &stride_b, \
device const TYPENAME *src, \
device const TYPENAME *cos, \
device const TYPENAME *sin, \
device TYPENAME *dst, \
uint idx [[ thread_position_in_grid ]] \
) { \
rope_thd<TYPENAME>(b, t, h, d, src, cos, sin, dst, idx); \
rope_thd<TYPENAME>(b, t, h, d, stride_b, src, cos, sin, dst, idx); \
}\
RMSNORM(rmsnorm_f32, float)

View File

@ -1574,7 +1574,7 @@ fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>(
let input_buffer = new_buffer(&device, input);
let ids_buffer = new_buffer(&device, ids);
let output = device.new_buffer(std::mem::size_of_val(input) as u64, options);
call_scatter_add(
call_scatter(
&device,
command_buffer,
&kernels,
@ -1584,7 +1584,7 @@ fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>(
dim,
BufferOffset::zero_offset(&input_buffer),
BufferOffset::zero_offset(&ids_buffer),
&output,
BufferOffset::zero_offset(&output),
)
.unwrap();
command_buffer.commit();
@ -2343,7 +2343,7 @@ fn conv_transpose1d_u32() {
#[test]
fn const_fill() {
fn constant_fill<T: Clone>(name: &'static str, len: usize, value: f32) -> Vec<T> {
fn constant_fill<T: Clone + EncoderParam>(name: &'static str, len: usize, value: T) -> Vec<T> {
let dev = device();
let kernels = Kernels::new();
let command_queue = dev.new_command_queue();
@ -2357,11 +2357,15 @@ fn const_fill() {
command_buffer.wait_until_completed();
read_to_vec::<T>(&buffer, len)
}
fn test<T: Clone + PartialEq + std::fmt::Debug, F: FnOnce(f32) -> T>(name: &'static str, f: F) {
fn test<T: Clone + Copy + EncoderParam + PartialEq + std::fmt::Debug, F: FnOnce(f32) -> T>(
name: &'static str,
f: F,
) {
let len = rand::thread_rng().gen_range(2..16) * rand::thread_rng().gen_range(4..16);
let value = rand::thread_rng().gen_range(1. ..19.);
let value = f(value);
let v = constant_fill::<T>(name, len, value);
assert_eq!(v, vec![f(value); len])
assert_eq!(v, vec![value; len])
}
test::<u8, _>("fill_u8", |v| v as u8);
test::<u32, _>("fill_u32", |v| v as u32);

View File

@ -73,6 +73,44 @@ template <typename T> METAL_FUNC T sigmoid(T in) {
#define TILE_SIZE 2
#define CONST_SET(TYPENAME, FN_NAME) \
kernel void FN_NAME( \
constant size_t &dim, \
constant TYPENAME &input, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
if (tid >= dim) { \
return; \
} \
output[tid] = input; \
} \
kernel void FN_NAME##_##strided( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant TYPENAME &input, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
if (tid >= dim) { \
return; \
} \
output[get_strided_index(tid, num_dims, dims, strides)] = input; \
} \
kernel void FN_NAME##_##tiled( \
constant size_t &dim, \
constant TYPENAME &input, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
for (uint i = 0; i < TILE_SIZE; i++) { \
const uint idx = tid * TILE_SIZE + i; \
output[idx] = input; \
} \
}
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
constant size_t &dim, \
@ -139,6 +177,11 @@ COPY2D(copy2d_f16, half)
COPY2D(copy2d_u8, uint8_t)
COPY2D(copy2d_u32, uint32_t)
CONST_SET(float, const_set_f32)
CONST_SET(half, const_set_f16)
CONST_SET(uint8_t, const_set_u8)
CONST_SET(uint32_t, const_set_u32)
UNARY_OP(cos)
UNARY_OP(sin)
UNARY_OP(sqr)
@ -171,6 +214,7 @@ UNARY(precise::tanh, half, tanh_f16, tanh_f16_strided);
#if __METAL_VERSION__ >= 220
UNARY(id, int64_t, copy_i64, copy_i64_strided)
COPY2D(copy2d_i64, int64_t)
CONST_SET(int64_t, const_set_i64)
#endif
#if defined(__HAVE_BFLOAT__)
@ -199,4 +243,5 @@ UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
UNARY(precise::tanh, bfloat, tanh_bf16, tanh_bf16_strided);
COPY2D(copy2d_bf16, bfloat)
CONST_SET(bfloat, const_set_bf16)
#endif

View File

@ -88,9 +88,13 @@ primitive!(bool);
primitive!(usize);
primitive!(i32);
primitive!(i64);
primitive!(u8);
primitive!(u32);
primitive!(u64);
primitive!(f32);
primitive!(f64);
primitive!(half::bf16);
primitive!(half::f16);
pub struct BufferOffset<'a> {
pub buffer: &'a Buffer,

View File

@ -71,6 +71,8 @@ impl candle::Module for PReLU {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let weight = if self.is_scalar {
self.weight.reshape(())?
} else if xs.shape() == self.weight.shape() {
self.weight.clone()
} else if xs.rank() >= 2 {
let num_channels = xs.dim(1)?;
let num_weights = self.weight.elem_count();
@ -78,7 +80,7 @@ impl candle::Module for PReLU {
candle::bail!("error in prelu: unexpected number of channels for the input, got {num_channels}, weight dim is {num_weights}")
}
let mut s = vec![1; xs.rank()];
s[1] = self.weight.elem_count();
s[1] = num_weights;
self.weight.reshape(s)?
} else {
self.weight.clone()

View File

@ -1,6 +1,6 @@
//! Cache Implementations
//!
use candle::{Device, Result, Tensor};
use candle::{DType, Device, Result, Tensor};
#[derive(Debug, Clone)]
pub struct Cache {
@ -399,3 +399,322 @@ impl RotatingKvCache {
self.v.reset();
}
}
#[derive(Debug, Clone)]
pub struct IndicesAndMask {
indices: Tensor,
mask: Tensor,
}
impl IndicesAndMask {
pub fn mask(&self) -> &Tensor {
&self.mask
}
}
#[derive(Debug, Clone)]
pub struct ScatteredKvCache {
k: Tensor,
v: Tensor,
context: usize,
}
impl ScatteredKvCache {
pub fn append(
&mut self,
k: &Tensor,
v: &Tensor,
iam: &IndicesAndMask,
) -> Result<(Tensor, Tensor)> {
if self.context <= k.dim(2)? {
return Ok((k.clone(), v.clone()));
}
let indices = iam.indices.unsqueeze(2)?.unsqueeze(1)?;
let indices = indices.broadcast_as(k.shape())?.contiguous()?;
self.k.scatter_set(&indices, k, 2)?;
self.v.scatter_set(&indices, v, 2)?;
Ok((self.k.clone(), self.v.clone()))
}
pub fn k(&self) -> &Tensor {
&self.k
}
pub fn v(&self) -> &Tensor {
&self.v
}
}
#[derive(Debug, Clone)]
pub struct ScatteredCacheBuilder {
context: usize,
// The current position in the stream, this can be larger than context.
positions: Vec<usize>,
// The index where the next element will be stored.
indices: Vec<usize>,
dtype: DType,
device: Device,
}
impl ScatteredCacheBuilder {
pub fn new(batch_size: usize, context: usize, dtype: DType, device: &Device) -> Result<Self> {
let positions = vec![0; batch_size];
let indices = vec![0; batch_size];
Ok(Self {
positions,
indices,
context,
dtype,
device: device.clone(),
})
}
pub fn make_cache(&self, num_heads: usize, head_dim: usize) -> Result<ScatteredKvCache> {
let batch_size = self.batch_size();
let shape = (batch_size, num_heads, self.context, head_dim);
let k = Tensor::zeros(shape, self.dtype, self.device())?;
let v = Tensor::zeros(shape, self.dtype, self.device())?;
Ok(ScatteredKvCache {
k,
v,
context: self.context,
})
}
pub fn positions(&self) -> &[usize] {
&self.positions
}
pub fn reset(&mut self) {
self.positions.fill(0);
self.indices.fill(0);
}
pub fn batch_size(&self) -> usize {
self.positions.len()
}
pub fn reset_batch_index(&mut self, batch_index: usize) {
self.positions[batch_index] = 0;
self.indices[batch_index] = 0;
}
#[allow(clippy::needless_range_loop)]
pub fn indices_and_mask(
&mut self,
seq_len: usize,
batch_mask: &[bool],
) -> Result<IndicesAndMask> {
// mask shape is (b, h, t, k)
let context = self.context;
if self.context <= seq_len {
return self.indices_and_mask_abs(seq_len, batch_mask);
}
let mut attention_masks = Vec::with_capacity(self.batch_size());
let mut cache_indices = Vec::with_capacity(self.batch_size());
for (batch_i, &batch_mask) in batch_mask.iter().enumerate() {
if !batch_mask {
let masks: Vec<Vec<f32>> = vec![vec![0.0; context]; seq_len];
let indices = vec![self.indices[batch_i] as u32; seq_len];
attention_masks.push(masks);
cache_indices.push(indices);
} else {
let start_index = self.indices[batch_i];
let start_pos = self.positions[batch_i];
let mut masks: Vec<Vec<f32>> = Vec::with_capacity(seq_len);
let mut indices = Vec::with_capacity(seq_len);
let mut all_pos = vec![usize::MAX; context];
if start_pos < context {
for i in 0..start_pos {
all_pos[i] = i;
}
} else {
let offset = start_pos - start_index;
for i in 0..context {
all_pos[i] = if i < start_index {
i + offset
} else {
i + offset - context
};
}
}
for seq_i in 0..seq_len {
let index = self.indices[batch_i];
all_pos[index] = seq_i + start_pos;
indices.push(index as u32);
self.indices[batch_i] += 1;
self.positions[batch_i] += 1;
if self.indices[batch_i] >= self.context {
self.indices[batch_i] = 0;
}
}
for seq_i in 0..seq_len {
let my_pos = seq_i + start_pos;
let mask = all_pos
.iter()
.map(|&pos| {
if pos <= my_pos {
0.0
} else {
f32::NEG_INFINITY
}
})
.collect::<Vec<f32>>();
masks.push(mask);
}
attention_masks.push(masks);
cache_indices.push(indices);
}
}
// Flattening the attention mask then using Tensor::from_vec rather using Tensor::new ends
// up being almost 10x faster with candle 0.9.0. This has been fixed in candle 0.9.1.
let attention_masks = attention_masks
.into_iter()
.flat_map(|m| m.into_iter().flatten())
.collect::<Vec<f32>>();
let mask = Tensor::from_vec(attention_masks, ((), 1, seq_len, context), self.device())?
.to_dtype(self.dtype)?;
let indices = Tensor::new(cache_indices, self.device())?;
Ok(IndicesAndMask { indices, mask })
}
pub fn device(&self) -> &Device {
&self.device
}
#[allow(clippy::needless_range_loop)]
fn indices_and_mask_abs(
&mut self,
seq_len: usize,
batch_mask: &[bool],
) -> Result<IndicesAndMask> {
let mask = self.get_mask_abs(seq_len, seq_len)?;
let mut cache_indices = Vec::with_capacity(self.batch_size());
for (batch_i, &batch_mask) in batch_mask.iter().enumerate() {
if !batch_mask {
let indices = vec![self.indices[batch_i] as u32; seq_len];
cache_indices.push(indices);
} else {
let mut indices = Vec::with_capacity(seq_len);
for _ in 0..seq_len {
let index = self.indices[batch_i];
indices.push(index as u32);
self.indices[batch_i] += 1;
self.positions[batch_i] += 1;
if self.indices[batch_i] >= self.context {
self.indices[batch_i] = 0;
}
}
cache_indices.push(indices);
}
}
let indices = Tensor::new(cache_indices, self.device())?;
Ok(IndicesAndMask { indices, mask })
}
fn get_mask_abs(&self, size1: usize, size2: usize) -> Result<Tensor> {
let context = self.context;
let mask: Vec<_> = (0..size1)
.flat_map(|i| {
(0..size2).map(move |j| {
if size1 + j > size2 + i || size1 + j + context < size2 + i {
f32::NEG_INFINITY
} else {
0.0
}
})
})
.collect();
Tensor::from_slice(&mask, (size1, size2), self.device())
}
}
#[cfg(test)]
mod tests {
use super::*;
use candle::IndexOp;
#[test]
fn test_scattered_kv_cache() -> Result<()> {
let device = Device::Cpu;
let mut cache = ScatteredCacheBuilder::new(2, 5, DType::F32, &device)?;
let inf = f32::INFINITY;
let iam = cache.indices_and_mask(1, &[true, false])?;
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
assert_eq!(iam.indices.to_vec2::<u32>()?, [[0], [0]]);
assert_eq!(
mask,
[[[0.0, -inf, -inf, -inf, -inf]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]
);
let iam = cache.indices_and_mask(1, &[true, false])?;
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
assert_eq!(iam.indices.to_vec2::<u32>()?, [[1], [0]]);
assert_eq!(
mask,
[[[0.0, 0.0, -inf, -inf, -inf]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]
);
let iam = cache.indices_and_mask(3, &[false, true])?;
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
assert_eq!(iam.indices.to_vec2::<u32>()?, [[2, 2, 2], [0, 1, 2]]);
assert_eq!(
mask,
[
[
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]
],
[
[0.0, -inf, -inf, -inf, -inf],
[0.0, 0.0, -inf, -inf, -inf],
[0.0, 0.0, 0.0, -inf, -inf]
]
]
);
let iam = cache.indices_and_mask(3, &[true, true])?;
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
assert_eq!(iam.indices.to_vec2::<u32>()?, [[2, 3, 4], [3, 4, 0]]);
assert_eq!(
mask,
[
[
[0.0, 0.0, 0.0, -inf, -inf],
[0.0, 0.0, 0.0, 0.0, -inf],
[0.0, 0.0, 0.0, 0.0, 0.0]
],
[
[-inf, 0.0, 0.0, 0.0, -inf],
[-inf, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]
]
]
);
let iam = cache.indices_and_mask(1, &[true, false])?;
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
assert_eq!(iam.indices.to_vec2::<u32>()?, [[0], [1]]);
assert_eq!(
mask,
[[[0.0, 0.0, 0.0, 0.0, 0.0]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]
);
let iam = cache.indices_and_mask(2, &[true, false])?;
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
assert_eq!(iam.indices.to_vec2::<u32>()?, [[1, 2], [1, 1]]);
assert_eq!(
mask,
[
[[0.0, 0.0, -inf, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]],
[[0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]]
]
);
Ok(())
}
}

View File

@ -46,15 +46,23 @@ impl candle::CustomOp3 for RotaryEmbI {
Some((o1, o2)) => &sin[o1..o2],
};
let (b, h, t, d) = l_src.shape().dims4()?;
let unbatched_rope = l_cos.dims().len() == 3 && l_sin.dims().len() == 3;
let el_count = b * h * t * d;
let mut dst = vec![T::zero(); el_count];
src.par_chunks(t * d)
.zip(dst.par_chunks_mut(t * d))
.for_each(|(src, dst)| {
.enumerate()
.for_each(|(bh_i, (src, dst))| {
for i_over_2 in 0..t * d / 2 {
let i = 2 * i_over_2;
dst[i] = src[i] * cos[i_over_2] - src[i + 1] * sin[i_over_2];
dst[i + 1] = src[i] * sin[i_over_2] + src[i + 1] * cos[i_over_2];
let rope_i = if unbatched_rope {
let b_i = bh_i / h;
i_over_2 + b_i * t * d / 2
} else {
i_over_2
};
dst[i] = src[i] * cos[rope_i] - src[i + 1] * sin[rope_i];
dst[i + 1] = src[i] * sin[rope_i] + src[i + 1] * cos[rope_i];
}
});
let storage = candle::WithDType::to_cpu_storage_owned(dst);
@ -115,6 +123,11 @@ impl candle::CustomOp3 for RotaryEmbI {
Some((o1, o2)) => sin.slice(o1..o2),
};
let (b, h, t, d) = l_src.shape().dims4()?;
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
(h * t * d) as u32
} else {
0u32
};
let el = b * h * t * d;
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), &kernels::REDUCE)?;
@ -125,7 +138,7 @@ impl candle::CustomOp3 for RotaryEmbI {
builder.arg(&cos);
builder.arg(&sin);
builder.arg(&dst);
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32);
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, stride_b);
// SAFETY: ffi.
unsafe { builder.launch(cfg) }.w()?;
Ok(dst)
@ -182,6 +195,11 @@ impl candle::CustomOp3 for RotaryEmbI {
dtype => candle::bail!("rope-i is not implemented for {dtype:?}"),
};
let (b, h, t, d) = l_src.shape().dims4()?;
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
h * t * d
} else {
0usize
};
let el = b * h * t * d;
let output = device.new_buffer(el, src.dtype(), "rope-i")?;
candle_metal_kernels::call_rope_i(
@ -191,6 +209,7 @@ impl candle::CustomOp3 for RotaryEmbI {
name,
b * h,
t * d,
stride_b,
src.buffer(),
l_src.start_offset() * src.dtype().size_in_bytes(),
cos.buffer(),
@ -205,10 +224,23 @@ impl candle::CustomOp3 for RotaryEmbI {
}
}
fn rope_check_cs(cs: &Tensor, b_sz: usize) -> Result<(usize, usize)> {
match *cs.dims() {
[t, d] => Ok((t, d)),
[b, t, d] => {
if b != b_sz {
candle::bail!("inconsistent batch size in rope {b_sz} {cs:?}",)
}
Ok((t, d))
}
_ => candle::bail!("cos/sin has to be 2D or 3D in rope {b_sz} {cs:?}"),
}
}
pub fn rope_i(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = cos.dims2()?;
let (sin_seq_len, sin_n_embd) = cos.dims2()?;
let (b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = rope_check_cs(cos, b_sz)?;
let (sin_seq_len, sin_n_embd) = rope_check_cs(sin, b_sz)?;
if cos_n_embd * 2 != n_embd
|| sin_n_embd * 2 != n_embd
|| seq_len > cos_seq_len
@ -292,16 +324,24 @@ impl candle::CustomOp3 for RotaryEmb {
Some((o1, o2)) => &sin[o1..o2],
};
let (b, h, t, d) = l_src.shape().dims4()?;
let unbatched_rope = l_cos.dims().len() == 3 && l_sin.dims().len() == 3;
let el_count = b * h * t * d;
let mut dst = vec![T::zero(); el_count];
src.par_chunks(t * d)
.zip(dst.par_chunks_mut(t * d))
.for_each(|(src, dst)| {
.enumerate()
.for_each(|(bh_i, (src, dst))| {
for i_t in 0..t {
for i_d in 0..d / 2 {
let i1 = i_t * d + i_d;
let i2 = i1 + d / 2;
let i_cs = i_t * (d / 2) + i_d;
let i_cs = if unbatched_rope {
let b_i = bh_i / h;
i_cs + b_i * t * d / 2
} else {
i_cs
};
dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs];
dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs];
}
@ -365,6 +405,11 @@ impl candle::CustomOp3 for RotaryEmb {
Some((o1, o2)) => sin.slice(o1..o2),
};
let (b, h, t, d) = l_src.shape().dims4()?;
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
(h * t * d) as u32
} else {
0u32
};
let el = b * h * t * d;
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("rope"), &kernels::REDUCE)?;
@ -375,7 +420,7 @@ impl candle::CustomOp3 for RotaryEmb {
builder.arg(&cos);
builder.arg(&sin);
builder.arg(&dst);
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, d as u32);
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, d as u32, stride_b);
// SAFETY: ffi.
unsafe { builder.launch(cfg) }.w()?;
Ok(dst)
@ -432,6 +477,11 @@ impl candle::CustomOp3 for RotaryEmb {
dtype => candle::bail!("rope is not implemented for {dtype:?}"),
};
let (b, h, t, d) = l_src.shape().dims4()?;
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
h * t * d
} else {
0usize
};
let el = b * h * t * d;
let output = device.new_buffer(el, src.dtype(), "rope-i")?;
candle_metal_kernels::call_rope(
@ -442,6 +492,7 @@ impl candle::CustomOp3 for RotaryEmb {
b * h,
t * d,
d,
stride_b,
src.buffer(),
l_src.start_offset() * src.dtype().size_in_bytes(),
cos.buffer(),
@ -457,9 +508,9 @@ impl candle::CustomOp3 for RotaryEmb {
}
pub fn rope(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = cos.dims2()?;
let (sin_seq_len, sin_n_embd) = sin.dims2()?;
let (b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = rope_check_cs(cos, b_sz)?;
let (sin_seq_len, sin_n_embd) = rope_check_cs(sin, b_sz)?;
if cos_n_embd * 2 != n_embd
|| sin_n_embd * 2 != n_embd
|| seq_len > cos_seq_len
@ -541,14 +592,21 @@ impl candle::CustomOp3 for RotaryEmbThd {
Some((o1, o2)) => &sin[o1..o2],
};
let (b, t, h, d) = l_src.shape().dims4()?;
let unbatched_rope = l_cos.dims().len() == 3 && l_sin.dims().len() == 3;
let el_count = b * h * t * d;
let mut dst = vec![T::zero(); el_count];
src.par_chunks(t * h * d)
.zip(dst.par_chunks_mut(t * h * d))
.for_each(|(src, dst)| {
.enumerate()
.for_each(|(b_i, (src, dst))| {
for i_t in 0..t {
for i_d in 0..d / 2 {
let i_cs = i_t * (d / 2) + i_d;
let i_cs = if unbatched_rope {
i_cs + b_i * t * d / 2
} else {
i_cs
};
for i_h in 0..h {
let i1 = i_t * h * d + i_h * d + i_d;
let i2 = i1 + d / 2;
@ -616,6 +674,11 @@ impl candle::CustomOp3 for RotaryEmbThd {
Some((o1, o2)) => sin.slice(o1..o2),
};
let (b, t, h, d) = l_src.shape().dims4()?;
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
(h * t * d) as u32
} else {
0u32
};
let el = b * h * t * d;
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), &kernels::REDUCE)?;
@ -626,7 +689,7 @@ impl candle::CustomOp3 for RotaryEmbThd {
builder.arg(&cos);
builder.arg(&sin);
builder.arg(&dst);
candle::builder_arg!(builder, b as u32, t as u32, h as u32, d as u32);
candle::builder_arg!(builder, b as u32, t as u32, h as u32, d as u32, stride_b);
// SAFETY: ffi.
unsafe { builder.launch(cfg) }.w()?;
Ok(dst)
@ -683,6 +746,11 @@ impl candle::CustomOp3 for RotaryEmbThd {
dtype => candle::bail!("rope_thd is not implemented for {dtype:?}"),
};
let (b, t, h, d) = l_src.shape().dims4()?;
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
h * t * d
} else {
0usize
};
let el = b * h * t * d;
let output = device.new_buffer(el, src.dtype(), "rope-thd")?;
candle_metal_kernels::call_rope_thd(
@ -694,6 +762,7 @@ impl candle::CustomOp3 for RotaryEmbThd {
t,
h,
d,
stride_b,
src.buffer(),
l_src.start_offset() * src.dtype().size_in_bytes(),
cos.buffer(),
@ -709,9 +778,9 @@ impl candle::CustomOp3 for RotaryEmbThd {
}
pub fn rope_thd(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
let (_b_sz, seq_len, _n_head, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = cos.dims2()?;
let (sin_seq_len, sin_n_embd) = sin.dims2()?;
let (b_sz, seq_len, _n_head, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = rope_check_cs(cos, b_sz)?;
let (sin_seq_len, sin_n_embd) = rope_check_cs(sin, b_sz)?;
if cos_n_embd * 2 != n_embd
|| sin_n_embd * 2 != n_embd
|| seq_len > cos_seq_len

View File

@ -8,13 +8,16 @@ pub fn gumbel_softmax<D: candle::shape::Dim>(
) -> Result<Tensor> {
if temperature <= 0.0 {
logits.argmax(dim)
} else if temperature == 1.0 {
let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
let sampled = (logits - minus_g)?.argmax(dim)?;
Ok(sampled)
} else {
// Cast to f32, doing the Gumbel softmax in bf16 is a bit unstable.
let logits = logits.to_dtype(candle::DType::F32)?;
let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
let sampled = (logits + minus_g * (-temperature))?.argmax(dim)?;
Ok(sampled)
if temperature == 1.0 {
let sampled = (logits - minus_g)?.argmax(dim)?;
Ok(sampled)
} else {
let sampled = (logits + minus_g * (-temperature))?.argmax(dim)?;
Ok(sampled)
}
}
}

View File

@ -4,7 +4,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle::{test_device, test_utils::to_vec3_round, Device, Result, Tensor};
use candle::{test_device, test_utils::to_vec3_round, Device, IndexOp, Result, Tensor};
fn softmax(device: &Device) -> Result<()> {
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
@ -179,6 +179,28 @@ fn ropei(device: &Device) -> Result<()> {
} else {
assert!(sum_diff < 1e-4);
}
// Test with a 3d cos/sin
let cos2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let sin2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let cos2 = Tensor::from_vec(cos2, (seq_len, head_dim / 2), device)?;
let sin2 = Tensor::from_vec(sin2, (seq_len, head_dim / 2), device)?;
let rope1 = candle_nn::rotary_emb::rope_i(&src.i(0..1)?, &cos, &sin)?;
let rope2 = candle_nn::rotary_emb::rope_i(&src.i(1..2)?, &cos2, &sin2)?;
let both_cos = Tensor::stack(&[cos, cos2], 0)?;
let both_sin = Tensor::stack(&[sin, sin2], 0)?;
let both_rope = candle_nn::rotary_emb::rope_i(&src, &both_cos, &both_sin)?;
let both_rope2 = Tensor::cat(&[rope1, rope2], 0)?;
let sum_diff = (both_rope - both_rope2)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(sum_diff, 0.);
Ok(())
}
@ -206,6 +228,28 @@ fn rope(device: &Device) -> Result<()> {
} else {
assert!(sum_diff < 1e-4);
}
// Test with a 3d cos/sin
let cos2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let sin2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let cos2 = Tensor::from_vec(cos2, (seq_len, head_dim / 2), device)?;
let sin2 = Tensor::from_vec(sin2, (seq_len, head_dim / 2), device)?;
let rope1 = candle_nn::rotary_emb::rope(&src.i(0..1)?, &cos, &sin)?;
let rope2 = candle_nn::rotary_emb::rope(&src.i(1..2)?, &cos2, &sin2)?;
let both_cos = Tensor::stack(&[cos, cos2], 0)?;
let both_sin = Tensor::stack(&[sin, sin2], 0)?;
let both_rope = candle_nn::rotary_emb::rope(&src, &both_cos, &both_sin)?;
let both_rope2 = Tensor::cat(&[rope1, rope2], 0)?;
let sum_diff = (both_rope - both_rope2)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(sum_diff, 0.);
Ok(())
}
@ -236,6 +280,37 @@ fn rope_thd(device: &Device) -> Result<()> {
} else {
assert!(sum_diff < 1e-4);
}
// Test with a 3d cos/sin
let cos2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let sin2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let cos2 = Tensor::from_vec(cos2, (seq_len, head_dim / 2), device)?;
let sin2 = Tensor::from_vec(sin2, (seq_len, head_dim / 2), device)?;
let rope1 = {
let src = src.transpose(1, 2)?.contiguous()?;
candle_nn::rotary_emb::rope_thd(&src.i(0..1)?, &cos, &sin)?
};
let rope2 = {
let src = src.transpose(1, 2)?.contiguous()?;
candle_nn::rotary_emb::rope_thd(&src.i(1..2)?, &cos2, &sin2)?
};
let both_cos = Tensor::stack(&[cos, cos2], 0)?;
let both_sin = Tensor::stack(&[sin, sin2], 0)?;
let both_rope = {
let src = src.transpose(1, 2)?.contiguous()?;
candle_nn::rotary_emb::rope_thd(&src, &both_cos, &both_sin)?
};
let both_rope2 = Tensor::cat(&[rope1, rope2], 0)?;
let sum_diff = (both_rope - both_rope2)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(sum_diff, 0.);
Ok(())
}

View File

@ -1,6 +1,6 @@
[package]
name = "candle-onnx"
version = "0.9.0-alpha.4"
version = "0.9.1"
edition = "2021"
description = "ONNX support for Candle"
@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
candle = { path = "../candle-core", package = "candle-core", version = "0.9.0-alpha.4" }
candle-nn = { path = "../candle-nn", version = "0.9.0-alpha.4" }
candle = { path = "../candle-core", package = "candle-core", version = "0.9.1" }
candle-nn = { path = "../candle-nn", version = "0.9.1" }
prost = "0.12.1"
[build-dependencies]

View File

@ -1,7 +1,9 @@
use crate::onnx::attribute_proto::AttributeType;
use crate::onnx::tensor_proto::DataType;
use crate::onnx::{self, GraphProto};
use candle::Module;
use candle::{bail, DType, Device, Result, Tensor};
use candle_nn::activation::PReLU;
use std::collections::{HashMap, HashSet};
pub type Value = Tensor;
@ -991,6 +993,14 @@ fn simple_eval_(
let output = input.relu()?;
values.insert(node.output[0].clone(), output);
}
"PRelu" => {
// https://onnx.ai/onnx/operators/onnx__PRelu.html
let input = get(&node.input[0])?;
let slope = get(&node.input[1])?;
let output = PReLU::new(slope.clone(), false).forward(input)?;
values.insert(node.output[0].clone(), output);
}
"Ceil" => {
let input = get(&node.input[0])?;
let output = input.ceil()?;

View File

@ -1846,6 +1846,64 @@ fn test_relu_operation() -> Result<()> {
Ok(())
}
// "PRelu"
#[test]
fn test_prelu_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "PRelu".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x: Tensor = Tensor::from_vec(
vec![-1.0f32, 1.0f32, -2.0f32, 3.0f32],
&[2, 2],
&Device::Cpu,
)?;
let y: Tensor = Tensor::from_vec(vec![1.0f32, 1.1f32, 1.2f32, 1.3f32], &[2, 2], &Device::Cpu)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
inputs.insert(INPUT_Y.to_string(), y);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
assert_eq!(results, vec![vec![-1.0, 1.0], vec![-2.4, 3.0]]);
Ok(())
}
// "Constant"
// #[test]

View File

@ -21,6 +21,7 @@ pub struct Config {
pub num_key_value_heads: usize,
pub rms_norm_eps: f64,
pub rope_theta: f64,
pub rope_local_base_freq: f64,
pub vocab_size: usize,
pub final_logit_softcapping: Option<f64>,
pub attn_logit_softcapping: Option<f64>,
@ -67,12 +68,22 @@ struct RotaryEmbedding {
}
impl RotaryEmbedding {
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
fn new(
dtype: DType,
cfg: &Config,
dev: &Device,
sliding_window: Option<usize>,
) -> Result<Self> {
let dim = cfg.head_dim;
let max_seq_len = cfg.max_position_embeddings;
let rope_freq = if sliding_window.is_some() {
cfg.rope_local_base_freq
} else {
cfg.rope_theta
};
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
.map(|i| 1f32 / rope_freq.powf(i as f64 / dim as f64) as f32)
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
@ -162,8 +173,8 @@ impl Attention {
fn new(
rotary_emb: Arc<RotaryEmbedding>,
use_flash_attn: bool,
is_sliding: bool,
cfg: &Config,
sliding_window: Option<usize>,
vb: VarBuilder,
) -> Result<Self> {
let hidden_sz = cfg.hidden_size;
@ -178,13 +189,13 @@ impl Attention {
let o_proj = linear(num_heads * head_dim, hidden_sz, bias, vb.pp("o_proj"))?;
let q_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("q_norm"))?;
let k_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?;
let kv_cache = if is_sliding {
KvCache::Rotating(candle_nn::kv_cache::RotatingKvCache::new(
2,
cfg.sliding_window,
))
let kv_cache = if let Some(sliding_window) = sliding_window {
KvCache::Rotating(candle_nn::kv_cache::RotatingKvCache::new(2, sliding_window))
} else {
KvCache::Normal(candle_nn::kv_cache::KvCache::new(2, cfg.sliding_window))
KvCache::Normal(candle_nn::kv_cache::KvCache::new(
2,
cfg.max_position_embeddings,
))
};
Ok(Self {
q_proj,
@ -302,21 +313,27 @@ struct DecoderLayer {
pre_feedforward_layernorm: RmsNorm,
post_feedforward_layernorm: RmsNorm,
post_attention_layernorm: RmsNorm,
sliding_window: Option<usize>,
}
impl DecoderLayer {
fn new(
rotary_emb: Arc<RotaryEmbedding>,
use_flash_attn: bool,
is_sliding: bool,
cfg: &Config,
vb: VarBuilder,
sliding_window: Option<usize>,
) -> Result<Self> {
let rotary_emb = Arc::new(RotaryEmbedding::new(
vb.dtype(),
cfg,
vb.device(),
sliding_window,
)?);
let self_attn = Attention::new(
rotary_emb,
use_flash_attn,
is_sliding,
cfg,
sliding_window,
vb.pp("self_attn"),
)?;
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
@ -344,6 +361,7 @@ impl DecoderLayer {
pre_feedforward_layernorm,
post_feedforward_layernorm,
post_attention_layernorm,
sliding_window,
})
}
@ -370,6 +388,42 @@ impl DecoderLayer {
}
}
fn prepare_decoder_attention_mask(
b_size: usize,
tgt_len: usize,
seqlen_offset: usize,
sliding_window: Option<usize>,
dtype: DType,
device: &Device,
) -> Result<Tensor> {
let mask: Vec<_> = if let Some(sliding_window) = sliding_window {
(0..tgt_len)
.flat_map(|i| {
(0..tgt_len).map(move |j| {
if i < j || j + sliding_window < i {
f32::NEG_INFINITY
} else {
0.
}
})
})
.collect()
} else {
(0..tgt_len)
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0f32 }))
.collect()
};
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), device)?;
let mask = if seqlen_offset > 0 {
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
};
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
.to_dtype(dtype)
}
#[derive(Debug, Clone)]
pub struct Model {
embed_tokens: candle_nn::Embedding,
@ -388,17 +442,15 @@ impl Model {
let vb_m = vb.pp("model");
let embed_tokens =
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb_l = vb_m.pp("layers");
for layer_idx in 0..cfg.num_hidden_layers {
let is_sliding = (layer_idx + 1) % cfg.sliding_window_pattern > 0;
let sliding_window = (layer_idx + 1) % cfg.sliding_window_pattern > 0;
let layer = DecoderLayer::new(
rotary_emb.clone(),
use_flash_attn,
is_sliding,
cfg,
vb_l.pp(layer_idx),
sliding_window.then_some(cfg.sliding_window),
)?;
layers.push(layer)
}
@ -417,51 +469,52 @@ impl Model {
})
}
fn prepare_decoder_attention_mask(
fn create_attention_masks(
&self,
b_size: usize,
tgt_len: usize,
batch_size: usize,
seq_len: usize,
seqlen_offset: usize,
) -> Result<Tensor> {
let mask: Vec<_> = match Some(self.sliding_window) {
None => (0..tgt_len)
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
.collect(),
Some(sliding_window) => (0..tgt_len)
.flat_map(|i| {
(0..tgt_len).map(move |j| {
if i < j || j + sliding_window < i {
f32::NEG_INFINITY
} else {
0.
}
})
})
.collect(),
};
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
let mask = if seqlen_offset > 0 {
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
};
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
.to_dtype(self.dtype)
) -> Result<(Option<Tensor>, Option<Tensor>)> {
if seq_len <= 1 {
return Ok((None, None));
}
let mask = prepare_decoder_attention_mask(
batch_size,
seq_len,
seqlen_offset,
None,
self.dtype,
&self.device,
)?;
let sliding_mask = prepare_decoder_attention_mask(
batch_size,
seq_len,
seqlen_offset,
Some(self.sliding_window),
self.dtype,
&self.device,
)?;
Ok((Some(mask), Some(sliding_mask)))
}
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let (b_size, seq_len) = input_ids.dims2()?;
let attention_mask = if seq_len <= 1 {
None
} else {
let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
Some(mask)
};
let xs = self.embed_tokens.forward(input_ids)?;
let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
let (attention_mask, sliding_attention_mask) =
self.create_attention_masks(b_size, seq_len, seqlen_offset)?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
let mask = if layer.sliding_window.is_some() {
&sliding_attention_mask
} else {
&attention_mask
};
xs = layer.forward(&xs, mask.as_ref(), seqlen_offset)?
}
let logits = xs
.narrow(1, seq_len - 1, 1)?

View File

@ -79,6 +79,7 @@ pub mod phi3;
pub mod pixtral;
pub mod quantized_blip;
pub mod quantized_blip_text;
pub mod quantized_gemma3;
pub mod quantized_llama;
pub mod quantized_llama2_c;
pub mod quantized_metavoice;

View File

@ -0,0 +1,466 @@
//! Gemma 3 model implementation with quantization support.
//!
//! Gemma 3 is a family of multimodal language models developed by Google.
//! This implementation provides quantization for reduced memory usage and faster inference.
//!
//! Key characteristics:
//! - Group-Query Attention (GQA) with specialized key-value heads
//! - RMSNorm for layer normalization
//! - Specialized attention patterns with separate normalization for Q/K/V
//! - Feed-forward network with SwiGLU activation
//! - Support for 2/3/4/8-bit quantization
//!
//! References:
//! - [Gemma 3 Models](https://blog.google/technology/developers/gemma-3/)
//!
use crate::quantized_nn::RmsNorm;
use candle::quantized::gguf_file;
use candle::quantized::QTensor;
use candle::D;
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::{Embedding, Module};
pub const MAX_SEQ_LEN: usize = 131072; // Gemma 3 supports 128K context window
pub const DEFAULT_SLIDING_WINDOW_TYPE: usize = 6;
pub const DEFAULT_ROPE_FREQUENCY: f32 = 1_000_000.;
pub const DEFAULT_ROPE_FREQUENCY_SLIDING: f32 = 10_000.;
pub const DEFAULT_ROPE_FREQUENCY_SCALE_FACTOR: f32 = 1.;
#[derive(Debug, Clone)]
struct QMatMul {
inner: candle::quantized::QMatMul,
span: tracing::Span,
}
impl QMatMul {
fn from_qtensor(qtensor: QTensor) -> Result<Self> {
let inner = candle::quantized::QMatMul::from_qtensor(qtensor)?;
let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
Ok(Self { inner, span })
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(xs)
}
}
#[derive(Debug, Clone)]
struct Mlp {
feed_forward_gate: QMatMul, // ffn_gate in GGUF
feed_forward_up: QMatMul, // ffn_up in GGUF
feed_forward_down: QMatMul, // ffn_down in GGUF
}
impl Module for Mlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let gate = self.feed_forward_gate.forward(xs)?;
let up = self.feed_forward_up.forward(xs)?;
let silu = candle_nn::ops::silu(&gate)?;
let gated = (silu * up)?;
self.feed_forward_down.forward(&gated)
}
}
#[derive(Debug, Clone)]
struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
}
impl RotaryEmbedding {
fn new(head_dim: usize, rope_frequency: f32, device: &Device) -> Result<Self> {
let theta: Vec<_> = (0..head_dim)
.step_by(2)
.map(|i| 1f32 / rope_frequency.powf(i as f32 / head_dim as f32))
.collect();
let theta = Tensor::new(theta.as_slice(), device)?;
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
.to_dtype(DType::F32)?
.reshape((MAX_SEQ_LEN, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
let cos = idx_theta.cos()?;
let sin = idx_theta.sin()?;
Ok(Self { sin, cos })
}
fn apply_rotary_emb_qkv(
&self,
q: &Tensor,
k: &Tensor,
index_pos: usize,
) -> Result<(Tensor, Tensor)> {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, index_pos, seq_len)?;
let sin = self.sin.narrow(0, index_pos, seq_len)?;
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
Ok((q_embed, k_embed))
}
}
#[derive(Debug, Clone)]
struct LayerWeights {
// Attention components
attention_wq: QMatMul,
attention_wk: QMatMul,
attention_wv: QMatMul,
attention_wo: QMatMul,
// Specialized normalization for Q and K
attention_q_norm: RmsNorm,
attention_k_norm: RmsNorm,
// Layer normalization
attention_norm: RmsNorm, // Applied before attention
post_attention_norm: RmsNorm, // Applied after attention
ffn_norm: RmsNorm, // Applied before feedforward
post_ffn_norm: RmsNorm, // Applied after feedforward
// Feed-forward network
mlp: Mlp,
// Attention parameters
n_head: usize, // Number of query heads
n_kv_head: usize, // Number of key-value heads
head_dim: usize, // Dimension of each head
q_dim: usize, // Total dimension for queries
sliding_window_size: Option<usize>,
rotary_embedding: RotaryEmbedding,
neg_inf: Tensor,
// Cache
kv_cache: Option<(Tensor, Tensor)>,
// Tracing
span_attn: tracing::Span,
span_mlp: tracing::Span,
}
impl LayerWeights {
fn mask(
&self,
b_sz: usize,
seq_len: usize,
index_pos: usize,
dtype: DType,
device: &Device,
) -> Result<Tensor> {
let mask: Vec<_> = if let Some(sliding_window_size) = self.sliding_window_size {
(0..seq_len)
.flat_map(|i| {
(0..seq_len).map(move |j| {
if i < j || j + sliding_window_size < i {
0u32
} else {
1u32
}
})
})
.collect()
} else {
(0..seq_len)
.flat_map(|i| (0..seq_len).map(move |j| if i < j { 0u32 } else { 1u32 }))
.collect()
};
let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
let mask = if index_pos > 0 {
let mask0 = Tensor::zeros((seq_len, index_pos), DType::F32, device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
};
mask.expand((b_sz, 1, seq_len, seq_len + index_pos))?
.to_dtype(dtype)
}
fn forward_attn(
&mut self,
x: &Tensor,
mask: Option<&Tensor>,
index_pos: usize,
) -> Result<Tensor> {
let _enter = self.span_attn.enter();
let (b_sz, seq_len, _) = x.dims3()?;
let q = self.attention_wq.forward(x)?;
let k = self.attention_wk.forward(x)?;
let v = self.attention_wv.forward(x)?;
let q = q
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
.transpose(1, 2)?;
let k = k
.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
.transpose(1, 2)?;
let v = v
.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
.transpose(1, 2)?;
let q = self.attention_q_norm.forward(&q.contiguous()?)?;
let k = self.attention_k_norm.forward(&k.contiguous()?)?;
let (q, k) = self
.rotary_embedding
.apply_rotary_emb_qkv(&q, &k, index_pos)?;
let (k, v) = match &self.kv_cache {
None => (k, v),
Some((k_cache, v_cache)) => {
if index_pos == 0 {
(k, v)
} else {
let k = Tensor::cat(&[k_cache, &k], 2)?; // concat on seq dim
let v = Tensor::cat(&[v_cache, &v], 2)?;
(k, v)
}
}
};
self.kv_cache = Some((k.clone(), v.clone())); // update cache
// Repeat KV for GQA
let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;
// Scaled Dot-Product Attention
let scale = 1.0 / (self.head_dim as f64).sqrt();
let mut attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
if let Some(mask) = mask {
let mask = mask.broadcast_as(attn_weights.shape())?;
let neg_inf = self.neg_inf.broadcast_as(attn_weights.dims())?;
attn_weights = mask.eq(0u32)?.where_cond(&neg_inf, &attn_weights)?;
}
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
let attn_output = attn_weights.matmul(&v)?;
let attn_output = attn_output
.transpose(1, 2)?
.reshape((b_sz, seq_len, self.q_dim))?;
self.attention_wo.forward(&attn_output)
}
}
#[derive(Debug, Clone)]
pub struct ModelWeights {
tok_embeddings: Embedding,
embedding_length: usize,
layers: Vec<LayerWeights>,
norm: RmsNorm,
output: QMatMul,
span: tracing::Span,
span_output: tracing::Span,
}
impl ModelWeights {
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
ct: gguf_file::Content,
reader: &mut R,
device: &Device,
) -> Result<Self> {
let md_get = |s: &str| match ct.metadata.get(s) {
None => candle::bail!("cannot find {s} in metadata"),
Some(v) => Ok(v),
};
let head_count = md_get("gemma3.attention.head_count")?.to_u32()? as usize;
let head_count_kv = md_get("gemma3.attention.head_count_kv")?.to_u32()? as usize;
let block_count = md_get("gemma3.block_count")?.to_u32()? as usize;
let embedding_length = md_get("gemma3.embedding_length")?.to_u32()? as usize;
let key_length = md_get("gemma3.attention.key_length")?.to_u32()? as usize;
let _value_length = md_get("gemma3.attention.value_length")?.to_u32()? as usize;
let rms_norm_eps = md_get("gemma3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
let sliding_window_size = md_get("gemma3.attention.sliding_window")?.to_u32()? as usize;
let sliding_window_type = md_get("gemma3.attention.sliding_window_type")
.and_then(|m| Ok(m.to_u32()? as usize))
.unwrap_or(DEFAULT_SLIDING_WINDOW_TYPE);
let rope_freq_base = md_get("gemma3.rope.freq_base")
.and_then(|m| m.to_f32())
.unwrap_or(DEFAULT_ROPE_FREQUENCY);
let rope_freq_base_sliding = md_get("gemma3.rope.local_freq_base")
.and_then(|m| m.to_f32())
.unwrap_or(DEFAULT_ROPE_FREQUENCY_SLIDING);
// Unused in Llama.cpp so we aren't using it here.
let _rope_freq_scaling_factor = md_get("gemma3.rope.scaling.factor")
.and_then(|m| m.to_f32())
.unwrap_or(DEFAULT_ROPE_FREQUENCY_SCALE_FACTOR);
// Compute the dimensions for queries, keys, and values
// These are the total dimensions when projected across all heads
let q_dim = head_count * key_length;
let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
// Load token embeddings and output projection
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
let tok_embeddings = tok_embeddings.dequantize(device)?;
let norm = RmsNorm::from_qtensor(
ct.tensor(reader, "output_norm.weight", device)?,
rms_norm_eps,
)?;
let output = match ct.tensor(reader, "output.weight", device) {
Ok(tensor) => tensor,
Err(_) => ct.tensor(reader, "token_embd.weight", device)?, // Use tied weights if output.weight doesn't exist
};
let mut layers = Vec::with_capacity(block_count);
for layer_idx in 0..block_count {
let prefix = format!("blk.{layer_idx}");
let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
let attention_wo =
ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
let attention_q_norm = RmsNorm::from_qtensor(
ct.tensor(reader, &format!("{prefix}.attn_q_norm.weight"), device)?,
rms_norm_eps,
)?;
let attention_k_norm = RmsNorm::from_qtensor(
ct.tensor(reader, &format!("{prefix}.attn_k_norm.weight"), device)?,
rms_norm_eps,
)?;
let attention_norm = RmsNorm::from_qtensor(
ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?,
rms_norm_eps,
)?;
let post_attention_norm = RmsNorm::from_qtensor(
ct.tensor(
reader,
&format!("{prefix}.post_attention_norm.weight"),
device,
)?,
rms_norm_eps,
)?;
let ffn_norm = RmsNorm::from_qtensor(
ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?,
rms_norm_eps,
)?;
let post_ffn_norm = RmsNorm::from_qtensor(
ct.tensor(reader, &format!("{prefix}.post_ffw_norm.weight"), device)?,
rms_norm_eps,
)?;
let feed_forward_gate =
ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
let feed_forward_up = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
let feed_forward_down =
ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
let mlp = Mlp {
feed_forward_gate: QMatMul::from_qtensor(feed_forward_gate)?,
feed_forward_up: QMatMul::from_qtensor(feed_forward_up)?,
feed_forward_down: QMatMul::from_qtensor(feed_forward_down)?,
};
// Sliding window pattern hardcoded to 6 because it's not explicitly defined
let is_sliding = (layer_idx + 1) % sliding_window_type > 0;
let sliding_window_size = is_sliding.then_some(sliding_window_size);
let layer_rope_frequency = if is_sliding {
rope_freq_base_sliding
} else {
rope_freq_base
};
let rotary_embedding = RotaryEmbedding::new(key_length, layer_rope_frequency, device)?;
// Tracing spans
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
layers.push(LayerWeights {
attention_wq: QMatMul::from_qtensor(attention_wq)?,
attention_wk: QMatMul::from_qtensor(attention_wk)?,
attention_wv: QMatMul::from_qtensor(attention_wv)?,
attention_wo: QMatMul::from_qtensor(attention_wo)?,
attention_q_norm,
attention_k_norm,
attention_norm,
post_attention_norm,
ffn_norm,
post_ffn_norm,
mlp,
n_head: head_count,
n_kv_head: head_count_kv,
head_dim: key_length,
q_dim,
sliding_window_size,
rotary_embedding,
neg_inf: neg_inf.clone(),
kv_cache: None,
span_attn,
span_mlp,
})
}
let span = tracing::span!(tracing::Level::TRACE, "model");
let span_output = tracing::span!(tracing::Level::TRACE, "output");
Ok(Self {
tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
embedding_length,
layers,
norm,
output: QMatMul::from_qtensor(output)?,
span,
span_output,
})
}
pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (b_sz, seq_len) = x.dims2()?;
let _enter = self.span.enter();
let mut layer_in = self.tok_embeddings.forward(x)?;
layer_in = (layer_in * (self.embedding_length as f64).sqrt())?;
for layer in self.layers.iter_mut() {
let attention_mask = if seq_len == 1 {
None
} else {
Some(layer.mask(b_sz, seq_len, index_pos, x.dtype(), x.device())?)
};
// Attention block
let residual = &layer_in;
let x = layer.attention_norm.forward(&layer_in)?;
let x = layer.forward_attn(&x, attention_mask.as_ref(), index_pos)?;
let x = layer.post_attention_norm.forward(&x)?;
let x = (x + residual)?;
// Feed-forward block
let _enter = layer.span_mlp.enter();
let residual = &x;
let x = layer.ffn_norm.forward(&x)?;
let x = layer.mlp.forward(&x)?;
let x = layer.post_ffn_norm.forward(&x)?;
let x = (x + residual)?;
drop(_enter);
layer_in = x;
}
let _enter = self.span_output.enter();
let x = layer_in.i((.., seq_len - 1, ..))?;
let x = self.norm.forward(&x)?;
let output = self.output.forward(&x)?;
Ok(output)
}
}