Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 02:38:10 +00:00)
Compare commits: 0.9.0-alph...main (40 commits)

Commits (SHA1):
17313a4226
0224a749f0
cd7b877d6b
5aed817f1b
1a183c988a
cac51fe16a
61ddb9535e
9a62c91643
92106c8762
9ce4fe6194
450a49ed1a
6bd61727bc
485ddf2996
36508a2c93
3d05f5cf3d
637473cb5e
e27b4700ad
1fdfb58de5
cd96fa80da
8a19bb7df2
38fc86621c
5029ac52bb
de23d34a28
d4bac37a61
e98754fc5a
e3db30021f
6e0646c208
fbaf0b0e32
a2e925462c
3827685524
3aeb9575c7
6ff0a6999c
82def7ae38
99bd69f383
a4c56a958e
b2904a830b
21055b5697
9dbaf958dc
ce5f8dd129
9954981327
Cargo.toml (20 lines changed)

@@ -20,7 +20,7 @@ exclude = [
 resolver = "2"
 
 [workspace.package]
-version = "0.9.0-alpha.4"
+version = "0.9.1"
 edition = "2021"
 description = "Minimalist ML framework."
 repository = "https://github.com/huggingface/candle"
@@ -33,17 +33,17 @@ ab_glyph = "0.2.23"
 accelerate-src = { version = "0.3.2" }
 anyhow = { version = "1", features = ["backtrace"] }
 byteorder = "1.4.3"
-candle = { path = "./candle-core", package = "candle-core", version = "0.9.0-alpha.4" }
-candle-datasets = { path = "./candle-datasets", version = "0.9.0-alpha.4" }
-candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0-alpha.4" }
-candle-kernels = { path = "./candle-kernels", version = "0.9.0-alpha.4" }
-candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0-alpha.4" }
-candle-nn = { path = "./candle-nn", version = "0.9.0-alpha.4" }
-candle-onnx = { path = "./candle-onnx", version = "0.9.0-alpha.4" }
-candle-transformers = { path = "./candle-transformers", version = "0.9.0-alpha.4" }
+candle = { path = "./candle-core", package = "candle-core", version = "0.9.1" }
+candle-datasets = { path = "./candle-datasets", version = "0.9.1" }
+candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.1" }
+candle-kernels = { path = "./candle-kernels", version = "0.9.1" }
+candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.1" }
+candle-nn = { path = "./candle-nn", version = "0.9.1" }
+candle-onnx = { path = "./candle-onnx", version = "0.9.1" }
+candle-transformers = { path = "./candle-transformers", version = "0.9.1" }
 clap = { version = "4.2.4", features = ["derive"] }
 criterion = { version = "0.5.1", default-features=false }
-cudarc = { version = "0.16.0", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
+cudarc = { version = "0.16.3", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
 fancy-regex = "0.13.0"
 gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
 hf-hub = "0.4.1"
@@ -16,6 +16,7 @@
 - [Running a model](inference/inference.md)
 - [Using the hub](inference/hub.md)
 - [Error management](error_manage.md)
+- [Tracing](tracing.md)
 - [Training](training/training.md)
 - [Simplified](training/simplified.md)
 - [MNIST](training/mnist.md)
candle-book/src/tracing.md (new file, 68 lines):

# Tracing

Tracing is a powerful tool for identifying performance issues and bottlenecks in code.

> Profiling on GPUs is trickier due to asynchronous execution; see the [GPU section](#gpu).

## Overview

Candle uses the [tracing](https://docs.rs/tracing/latest/tracing/) crate for instrumentation.

To try it out, run an example in `candle-examples` with the `--tracing` flag.
This generates a trace file, typically named `trace-<timestamp>.json`.
You can view the trace in Chrome by navigating to `chrome://tracing/`, clicking **Load**, and selecting the generated trace file.

## Adding Tracing

Candle includes built-in tracing for many internal operations, using [spans](https://docs.rs/tracing/latest/tracing/struct.Span.html) to mark key points of execution.

To add custom tracing in your code, you can define a span like this:

```rust
let span = tracing::span!(tracing::Level::TRACE, name);
```

Then, to record the span during execution, create a guard:

```rust
let _enter = span.enter();
```

This guard will record the span's duration, from when it is created to when it is dropped, into a global data structure managed by the tracing crate.
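For example, a minimal sketch of instrumenting a function of your own (the function and span name are illustrative, not part of Candle's API; wiring up the recorder is shown in the next section):

```rust
use tracing::{span, Level};

// Wrap a hot section of code in a span so that it shows up in the trace.
fn attention_step() {
    let span = span!(Level::TRACE, "attention-step");
    let _enter = span.enter();
    // ... the work to be measured goes here ...
} // `_enter` is dropped here, which closes the span
```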
## Recording and Saving a Trace

To capture and save trace data, you need to configure the tracing system with an output format. Candle uses the [tracing_subscriber](https://docs.rs/tracing-subscriber/latest/tracing_subscriber/) and [tracing_chrome](https://docs.rs/tracing-chrome/latest/tracing_chrome/) crates.

The snippet below sets up a Chrome-compatible recorder that logs all tracing activity between the creation and the drop of the guard:

```rust
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

let _guard = {
    let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();
    guard
};
```
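Putting the pieces together, a minimal end-to-end sketch (the span name and workload are illustrative) could look like this; the trace file is flushed when the guard is dropped at the end of `main`:

```rust
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // Keep the guard alive for the whole program; the trace file
    // (trace-<timestamp>.json) is written out when it is dropped.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();

    let span = tracing::span!(tracing::Level::TRACE, "demo-workload");
    let _enter = span.enter();
    // ... run the code you want to profile here ...
}
```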
## GPU

When using CUDA, Metal, or other asynchronous GPU backends, tracing may produce misleading timing data because operations are queued rather than executed immediately.
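A practical consequence: synchronize the device before reading a timer, so the measurement covers execution rather than just kernel launch. A minimal sketch, assuming a `candle_core::Device` handle and illustrative shapes (the benchmarks in this release use the same sync-before-elapsed pattern):

```rust
use candle_core::{DType, Device, Result, Tensor};
use std::time::Instant;

// Time a matmul with an explicit device sync so that queued GPU work
// is included in the measurement instead of just the launch overhead.
fn time_matmul(device: &Device) -> Result<std::time::Duration> {
    let a = Tensor::zeros((1024, 1024), DType::F32, device)?;
    let b = Tensor::zeros((1024, 1024), DType::F32, device)?;
    let start = Instant::now();
    let _c = a.matmul(&b)?;
    device.synchronize()?; // wait for the queued kernels to finish
    Ok(start.elapsed())
}
```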
### CUDA

For CUDA-specific profiling, you have two options:

1. Set the environment variable `CUDA_LAUNCH_BLOCKING=1`, which forces synchronous execution. This makes trace timings more accurate, at the cost of reduced performance.
2. Use [NVIDIA's Nsight Systems](https://developer.nvidia.com/nsight-systems) (`nsys profile` and `nsys-ui`), which are designed specifically for profiling asynchronous CUDA executions.

We recommend using NVIDIA's Nsight Systems when possible, as it offers accurate performance data without altering typical execution patterns. In contrast, setting the `CUDA_LAUNCH_BLOCKING` environment variable forces synchronous execution, which can significantly alter execution behavior.

#### Performance Profiling with NVIDIA Nsight Systems

1. Generate an `.nsys-rep` file containing performance data ([docs](https://docs.nvidia.com/nsight-systems/UserGuide/index.html#example-single-command-lines)):
   - Run `nsys profile --trace cuda,nvtx,osrt --gpu-metrics-device=all --output profile_run ./target/debug/... --prompt "whatever "`
2. Open the generated `.nsys-rep` report file in the Nsight Systems GUI:
   - File > Open
@@ -4,11 +4,12 @@ use criterion::criterion_main;
 
 criterion_main!(
     benchmarks::affine::benches,
+    benchmarks::copy::benches,
+    benchmarks::conv_transpose2d::benches,
     benchmarks::matmul::benches,
+    benchmarks::qmatmul::benches,
     benchmarks::random::benches,
     benchmarks::reduce::benches,
+    benchmarks::unary::benches,
     benchmarks::where_cond::benches,
-    benchmarks::conv_transpose2d::benches,
-    benchmarks::qmatmul::benches,
-    benchmarks::unary::benches
 );
candle-core/benches/benchmarks/copy.rs (new file, 38 lines):

use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{Device, Tensor, WithDType};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;

fn run_copy_mask_benchmark<D: WithDType>(c: &mut Criterion, device: &Device, name: &str) {
    let batch_size = 128;
    let in_seq_len = 1;
    let kv_seq_len = 1024;

    let attn_mask = vec![vec![vec![D::zero(); kv_seq_len]; in_seq_len]; batch_size];
    let size_in_bytes = batch_size * in_seq_len * kv_seq_len * D::DTYPE.size_in_bytes();

    let mut group = c.benchmark_group(device.bench_name(name));
    group.throughput(Throughput::Bytes(size_in_bytes as u64));
    group.bench_function("iter", move |b| {
        b.iter_custom(|iters| {
            let attn_masks = vec![attn_mask.clone(); iters as usize];
            let start = Instant::now();
            for attn_mask in attn_masks.into_iter() {
                let tensor = Tensor::new(black_box(attn_mask), device).unwrap();
                black_box(tensor);
            }
            device.sync().unwrap();
            start.elapsed()
        })
    });
    group.finish();
}

fn criterion_benchmark(c: &mut Criterion) {
    let handler = BenchDeviceHandler::new().unwrap();
    for device in handler.devices {
        run_copy_mask_benchmark::<f32>(c, &device, "copy_mask");
    }
}

criterion_group!(benches, criterion_benchmark);
@@ -1,5 +1,6 @@
 pub(crate) mod affine;
 pub(crate) mod conv_transpose2d;
+pub(crate) mod copy;
 pub(crate) mod matmul;
 pub(crate) mod qmatmul;
 pub(crate) mod random;
@@ -71,15 +71,27 @@ pub trait BackendStorage: Sized {
     fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;
 
     fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
-    fn scatter_add(
-        &self,
+
+    fn scatter_set(
+        &mut self,
         _: &Layout,
         _: &Self,
         _: &Layout,
         _: &Self,
         _: &Layout,
         _: usize,
-    ) -> Result<Self>;
+    ) -> Result<()>;
+
+    fn scatter_add_set(
+        &mut self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: usize,
+    ) -> Result<()>;
 
     fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self>;
     fn index_add(
         &self,
@@ -113,6 +125,8 @@ pub trait BackendStorage: Sized {
         _src_offset: usize,
         _dst_offset: usize,
     ) -> Result<()>;
+
+    fn const_set(&mut self, _: crate::scalar::Scalar, _: &Layout) -> Result<()>;
 }
 
 pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
@@ -127,8 +141,6 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
 
     fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
 
-    fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
-
     /// # Safety
     /// This function is unsafe as it doesn't initialize the underlying data store.
     /// The caller should ensure that the data is properly initialized as early as possible
@@ -53,6 +53,7 @@ impl Tensor {
             } else if let Some(op) = node.op() {
                 match op {
                     Op::IndexAdd(t1, t2, t3, _)
+                    | Op::Scatter(t1, t2, t3, _)
                     | Op::ScatterAdd(t1, t2, t3, _)
                     | Op::CustomOp3(t1, t2, t3, _)
                     | Op::WhereCond(t1, t2, t3) => {
@@ -419,7 +420,7 @@ impl Tensor {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
                 }
-                Op::ScatterAdd(init, indexes, src, dim) => {
+                Op::Scatter(init, indexes, src, dim) => {
                     let init_sum_grad = grads.or_insert(init)?;
                     *init_sum_grad = init_sum_grad.add(&grad)?;
@@ -427,6 +428,16 @@ impl Tensor {
                     let src_sum_grad = grads.or_insert(src)?;
                     *src_sum_grad = src_sum_grad.add(&src_grad)?;
                 }
+                Op::ScatterAdd(init, indexes, src, dim) => {
+                    let init_sum_grad = grads.or_insert(init)?;
+                    let mask = init.ones_like()?;
+                    let mask = mask.scatter(indexes, &mask.zeros_like()?, *dim)?;
+                    *init_sum_grad = init_sum_grad.add(&grad.mul(&mask)?)?;
+
+                    let src_grad = grad.gather(indexes, *dim)?;
+                    let src_sum_grad = grads.or_insert(src)?;
+                    *src_sum_grad = src_sum_grad.add(&src_grad)?;
+                }
                 Op::IndexAdd(init, indexes, src, dim) => {
                     let init_sum_grad = grads.or_insert(init)?;
                     *init_sum_grad = init_sum_grad.add(&grad)?;
@@ -7,7 +7,7 @@ use rayon::prelude::*;
 
 mod utils;
 pub use utils::{
-    binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2U8,
+    binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2InPlace, Map2U8,
 };
 
 const USE_IM2COL_CONV1D: bool = true;
@@ -483,17 +483,22 @@ impl<I: IntDType> Map1 for Gather<'_, I> {
                 let start_dst_idx = start_dst_idx + i * dst_right_len;
                 for right_i in 0..dst_right_len {
                     let dst_idx = start_dst_idx + right_i;
-                    let index = ids[dst_idx].as_usize();
-                    if index >= src_dim_len {
-                        Err(Error::InvalidIndex {
-                            index,
-                            size: src_dim_len,
-                            op: "gather",
-                        }
-                        .bt())?
-                    }
-                    let src_idx = start_src_idx + index * src_right_len + right_i;
-                    dst[dst_idx] = src[src_idx]
+                    let index = ids[dst_idx];
+                    if index == I::max_value() {
+                        dst[dst_idx] = T::zero();
+                    } else {
+                        let index = index.as_usize();
+                        if index >= src_dim_len {
+                            Err(Error::InvalidIndex {
+                                index,
+                                size: src_dim_len,
+                                op: "gather",
+                            }
+                            .bt())?
+                        }
+                        let src_idx = start_src_idx + index * src_right_len + right_i;
+                        dst[dst_idx] = src[src_idx]
+                    }
                 }
             }
         }
@@ -535,45 +540,89 @@ impl<I: IntDType> Map1 for IndexSelect<'_, I> {
             let start_src_idx = left_i * right_len * src_dim;
             let start_dst_idx = left_i * right_len * n_ids;
             for i in 0..n_ids {
-                let index = self.ids[self.ids_l.start_offset() + stride_ids * i].as_usize();
-                if index >= src_dim {
-                    Err(Error::InvalidIndex {
-                        index,
-                        size: src_dim,
-                        op: "index-select",
-                    }
-                    .bt())?
-                }
-                let start_src_idx = start_src_idx + index * right_len;
                 let start_dst_idx = start_dst_idx + i * right_len;
-                dst[start_dst_idx..start_dst_idx + right_len]
-                    .copy_from_slice(&src[start_src_idx..start_src_idx + right_len])
+                let index = self.ids[self.ids_l.start_offset() + stride_ids * i];
+                if index == I::max_value() {
+                    dst[start_dst_idx..start_dst_idx + right_len].fill(T::zero());
+                } else {
+                    let index = index.as_usize();
+                    if index >= src_dim {
+                        Err(Error::InvalidIndex {
+                            index,
+                            size: src_dim,
+                            op: "index-select",
+                        }
+                        .bt())?
+                    }
+                    let start_src_idx = start_src_idx + index * right_len;
+                    dst[start_dst_idx..start_dst_idx + right_len]
+                        .copy_from_slice(&src[start_src_idx..start_src_idx + right_len])
+                }
             }
         }
         Ok(dst)
     }
 }
 
-struct ScatterAdd<'a, I: IntDType> {
+trait ElemUpdate {
+    fn f<T: WithDType>(dst: &mut T, src: T);
+}
+
+struct Set;
+struct Add;
+
+impl ElemUpdate for Set {
+    fn f<T: WithDType>(dst: &mut T, src: T) {
+        *dst = src
+    }
+}
+
+impl ElemUpdate for Add {
+    fn f<T: WithDType>(dst: &mut T, src: T) {
+        *dst += src
+    }
+}
+
+struct Scatter<'a, I: IntDType, M: ElemUpdate> {
     ids: &'a [I],
     ids_l: &'a Layout,
     dim: usize,
+    _phantom: std::marker::PhantomData<M>,
 }
 
-impl<I: IntDType> Map2 for ScatterAdd<'_, I> {
-    const OP: &'static str = "scatter-add";
-    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
-        let dst_len = l1.shape().elem_count();
-        let mut dst = vec![T::zero(); dst_len];
-        copy_strided_src_(v1, &mut dst, 0, l1);
+impl<'a, I: IntDType, M: ElemUpdate> Scatter<'a, I, M> {
+    fn new(ids: &'a [I], ids_l: &'a Layout, dim: usize) -> Self {
+        Self {
+            ids,
+            ids_l,
+            dim,
+            _phantom: Default::default(),
+        }
+    }
+}
+
+impl<I: IntDType, M: ElemUpdate> Map2InPlace for Scatter<'_, I, M> {
+    const OP: &'static str = "scatter";
+    fn f<T: WithDType>(
+        &self,
+        dst: &mut [T],
+        dst_l: &Layout,
+        src: &[T],
+        src_l: &Layout,
+    ) -> Result<()> {
+        let dst = match dst_l.contiguous_offsets() {
+            None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
+            Some((o1, o2)) => &mut dst[o1..o2],
+        };
+
         let src = match src_l.contiguous_offsets() {
-            None => Err(Error::RequiresContiguous { op: "scatter-add" }.bt())?,
+            None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
             Some((o1, o2)) => &src[o1..o2],
         };
 
         let dim = self.dim;
         let ids_dims = self.ids_l.dims();
-        let dst_dims = l1.dims();
+        let dst_dims = dst_l.dims();
         let dst_dim_len = dst_dims[dim];
         let dst_right_len: usize = dst_dims[dim + 1..].iter().product();
 
@@ -592,7 +641,11 @@ impl<I: IntDType> Map2 for ScatterAdd<'_, I> {
             let start_ids_idx = start_ids_idx + i * ids_right_len;
             for right_i in 0..dst_right_len {
                 let ids_idx = start_ids_idx + right_i;
-                let index = ids[ids_idx].as_usize();
+                let index = ids[ids_idx];
+                if index == I::max_value() {
+                    continue;
+                }
+                let index = index.as_usize();
                 if index >= dst_dim_len {
                     Err(Error::InvalidIndex {
                         index,
@@ -602,12 +655,12 @@ impl<I: IntDType> Map2 for ScatterAdd<'_, I> {
                     .bt())?
                 }
                 let dst_idx = start_dst_idx + index * dst_right_len + right_i;
-                dst[dst_idx] += src[ids_idx]
+                M::f(&mut dst[dst_idx], src[ids_idx])
             }
         }
 
-        Ok(dst)
+        Ok(())
     }
 }
 
@@ -635,6 +688,9 @@ impl<I: IntDType> Map2 for IndexAdd<'_, I> {
         let post_dim = src_l.dims()[dim + 1..].iter().product::<usize>();
         if dim == 0 {
             for (src_idx, dst_idx) in self.ids.iter().enumerate() {
+                if *dst_idx == I::max_value() {
+                    continue;
+                }
                 let dst_idx = dst_idx.as_usize();
                 if dst_idx >= max_idx {
                     Err(Error::InvalidIndex {
@@ -653,6 +709,9 @@ impl<I: IntDType> Map2 for IndexAdd<'_, I> {
             }
         } else {
             for (src_idx, dst_idx) in self.ids.iter().enumerate() {
+                if *dst_idx == I::max_value() {
+                    continue;
+                }
                 let dst_idx = dst_idx.as_usize();
                 if dst_idx >= max_idx {
                     Err(Error::InvalidIndex {
@@ -2381,19 +2440,36 @@ impl BackendStorage for CpuStorage {
         }
     }
 
-    fn scatter_add(
-        &self,
+    fn scatter_set(
+        &mut self,
         l: &Layout,
         ids: &Self,
         ids_l: &Layout,
         src: &Self,
         src_l: &Layout,
         dim: usize,
-    ) -> Result<Self> {
+    ) -> Result<()> {
         match ids {
-            Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
-            Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
-            Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
+            Self::U8(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
+            Self::U32(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
+            Self::I64(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
             _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter").bt()),
         }
     }
+
+    fn scatter_add_set(
+        &mut self,
+        l: &Layout,
+        ids: &Self,
+        ids_l: &Layout,
+        src: &Self,
+        src_l: &Layout,
+        dim: usize,
+    ) -> Result<()> {
+        match ids {
+            Self::U8(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
+            Self::U32(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
+            Self::I64(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
+            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()),
+        }
+    }
@@ -2454,6 +2530,48 @@ impl BackendStorage for CpuStorage {
     fn to_cpu_storage(&self) -> Result<CpuStorage> {
         Ok(self.clone())
     }
+
+    fn const_set(&mut self, s: crate::scalar::Scalar, l: &Layout) -> Result<()> {
+        use crate::scalar::Scalar;
+        fn set<T: crate::WithDType>(src: &mut [T], l: &Layout, s: T) {
+            match l.strided_blocks() {
+                crate::StridedBlocks::SingleBlock { start_offset, len } => {
+                    src[start_offset..start_offset + len].fill(s)
+                }
+                crate::StridedBlocks::MultipleBlocks {
+                    block_start_index,
+                    block_len: 1,
+                } => {
+                    for src_index in block_start_index {
+                        src[src_index] = s
+                    }
+                }
+                crate::StridedBlocks::MultipleBlocks {
+                    block_start_index,
+                    block_len,
+                } => {
+                    for src_index in block_start_index {
+                        src[src_index..src_index + block_len].fill(s)
+                    }
+                }
+            }
+        }
+        match (self, s) {
+            (Self::BF16(storage), Scalar::BF16(v)) => set(storage, l, v),
+            (Self::F16(storage), Scalar::F16(v)) => set(storage, l, v),
+            (Self::F32(storage), Scalar::F32(v)) => set(storage, l, v),
+            (Self::F64(storage), Scalar::F64(v)) => set(storage, l, v),
+            (Self::U8(storage), Scalar::U8(v)) => set(storage, l, v),
+            (Self::U32(storage), Scalar::U32(v)) => set(storage, l, v),
+            (Self::I64(storage), Scalar::I64(v)) => set(storage, l, v),
+            (st, s) => crate::bail!(
+                "const_set dtype mismatch, expected {:?} but got {:?}",
+                st.dtype(),
+                s
+            ),
+        }
+        Ok(())
+    }
 }
 
 impl BackendDevice for CpuDevice {
@@ -2628,20 +2746,6 @@ impl BackendDevice for CpuDevice {
         Ok(storage)
     }
 
-    fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
-        let elem_count = shape.elem_count();
-        let storage = match dtype {
-            DType::U8 => CpuStorage::U8(vec![1u8; elem_count]),
-            DType::U32 => CpuStorage::U32(vec![1u32; elem_count]),
-            DType::I64 => CpuStorage::I64(vec![1i64; elem_count]),
-            DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]),
-            DType::F16 => CpuStorage::F16(vec![f16::ONE; elem_count]),
-            DType::F32 => CpuStorage::F32(vec![1f32; elem_count]),
-            DType::F64 => CpuStorage::F64(vec![1f64; elem_count]),
-        };
-        Ok(storage)
-    }
-
     fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
         let elem_count = shape.elem_count();
         let storage = match dtype {
@@ -58,6 +58,30 @@ pub trait Map2 {
     }
 }
 
+pub trait Map2InPlace {
+    const OP: &'static str;
+    fn f<T: WithDType>(&self, v1: &mut [T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<()>;
+
+    fn map(&self, v1: &mut C, l1: &Layout, v2: &C, l2: &Layout) -> Result<()> {
+        match (v1, v2) {
+            (C::U8(v1), C::U8(v2)) => self.f(v1, l1, v2, l2)?,
+            (C::U32(v1), C::U32(v2)) => self.f(v1, l1, v2, l2)?,
+            (C::I64(v1), C::I64(v2)) => self.f(v1, l1, v2, l2)?,
+            (C::BF16(v1), C::BF16(v2)) => self.f(v1, l1, v2, l2)?,
+            (C::F16(v1), C::F16(v2)) => self.f(v1, l1, v2, l2)?,
+            (C::F32(v1), C::F32(v2)) => self.f(v1, l1, v2, l2)?,
+            (C::F64(v1), C::F64(v2)) => self.f(v1, l1, v2, l2)?,
+            (v1, v2) => Err(Error::DTypeMismatchBinaryOp {
+                lhs: v1.dtype(),
+                rhs: v2.dtype(),
+                op: Self::OP,
+            }
+            .bt())?,
+        };
+        Ok(())
+    }
+}
+
 pub trait Map2U8 {
     const OP: &'static str;
     fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
@@ -2,7 +2,7 @@ use crate::backend::BackendDevice;
 use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
 pub use candle_kernels as kernels;
 pub use cudarc;
-use cudarc::driver::{CudaFunction, LaunchConfig, PushKernelArg};
+use cudarc::driver::CudaFunction;
 use half::{bf16, f16};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
@@ -188,100 +188,6 @@ impl CudaDevice {
         self.id
     }
 
-    fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
-        let elem_count = shape.elem_count();
-        let cfg = LaunchConfig::for_num_elems(elem_count as u32);
-        let slice = match dtype {
-            DType::U8 => {
-                // SAFETY: Set later by running the fill kernel.
-                let data = unsafe { self.alloc::<u8>(elem_count)? };
-                let func = self.get_or_load_func("fill_u8", &kernels::FILL)?;
-                let mut builder = self.stream.launch_builder(&func);
-                let v = v as u8;
-                builder.arg(&data);
-                builder.arg(&v);
-                builder.arg(&elem_count);
-                unsafe { builder.launch(cfg) }.w()?;
-                CudaStorageSlice::U8(data)
-            }
-            DType::U32 => {
-                // SAFETY: Set later by running the fill kernel.
-                let data = unsafe { self.alloc::<u32>(elem_count)? };
-                let func = self.get_or_load_func("fill_u32", &kernels::FILL)?;
-                let mut builder = self.stream.launch_builder(&func);
-                let v = v as u32;
-                builder.arg(&data);
-                builder.arg(&v);
-                builder.arg(&elem_count);
-                unsafe { builder.launch(cfg) }.w()?;
-                CudaStorageSlice::U32(data)
-            }
-            DType::I64 => {
-                // SAFETY: Set later by running the fill kernel.
-                let data = unsafe { self.alloc::<i64>(elem_count)? };
-                let func = self.get_or_load_func("fill_i64", &kernels::FILL)?;
-                let mut builder = self.stream.launch_builder(&func);
-                let v = v as i64;
-                builder.arg(&data);
-                builder.arg(&v);
-                builder.arg(&elem_count);
-                unsafe { builder.launch(cfg) }.w()?;
-                CudaStorageSlice::I64(data)
-            }
-            DType::BF16 => {
-                // SAFETY: Set later by running the fill kernel.
-                let data = unsafe { self.alloc::<bf16>(elem_count)? };
-                let func = self.get_or_load_func("fill_bf16", &kernels::FILL)?;
-                let mut builder = self.stream.launch_builder(&func);
-                let v = bf16::from_f64(v);
-                builder.arg(&data);
-                builder.arg(&v);
-                builder.arg(&elem_count);
-                unsafe { builder.launch(cfg) }.w()?;
-                CudaStorageSlice::BF16(data)
-            }
-            DType::F16 => {
-                // SAFETY: Set later by running the fill kernel.
-                let data = unsafe { self.alloc::<f16>(elem_count)? };
-                let func = self.get_or_load_func("fill_f16", &kernels::FILL)?;
-                let mut builder = self.stream.launch_builder(&func);
-                let v = f16::from_f64(v);
-                builder.arg(&data);
-                builder.arg(&v);
-                builder.arg(&elem_count);
-                unsafe { builder.launch(cfg) }.w()?;
-                CudaStorageSlice::F16(data)
-            }
-            DType::F32 => {
-                // SAFETY: Set later by running the fill kernel.
-                let data = unsafe { self.alloc::<f32>(elem_count)? };
-                let func = self.get_or_load_func("fill_f32", &kernels::FILL)?;
-                let mut builder = self.stream.launch_builder(&func);
-                let v = v as f32;
-                builder.arg(&data);
-                builder.arg(&v);
-                builder.arg(&elem_count);
-                unsafe { builder.launch(cfg) }.w()?;
-                CudaStorageSlice::F32(data)
-            }
-            DType::F64 => {
-                // SAFETY: Set later by running the fill kernel.
-                let data = unsafe { self.alloc::<f64>(elem_count) }?;
-                let func = self.get_or_load_func("fill_f64", &kernels::FILL)?;
-                let mut builder = self.stream.launch_builder(&func);
-                builder.arg(&data);
-                builder.arg(&v);
-                builder.arg(&elem_count);
-                unsafe { builder.launch(cfg) }.w()?;
-                CudaStorageSlice::F64(data)
-            }
-        };
-        Ok(CudaStorage {
-            slice,
-            device: self.clone(),
-        })
-    }
-
     pub fn get_or_load_custom_func(
         &self,
         fn_name: &str,
@@ -504,10 +410,6 @@ impl BackendDevice for CudaDevice {
         })
     }
 
-    fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
-        self.const_impl(1., shape, dtype)
-    }
-
     unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
         let elem_count = shape.elem_count();
         let slice = match dtype {
@@ -2,7 +2,7 @@
 //!
 use crate::backend::{BackendDevice, BackendStorage};
 use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
-use crate::{builder_arg as barg, CpuStorage, DType, Layout, Result, Shape, WithDType};
+use crate::{builder_arg as barg, CpuStorage, DType, Layout, Result, WithDType};
 pub use candle_kernels as kernels;
 pub use cudarc;
 use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
@@ -34,6 +34,21 @@ impl<T: DeviceRepr> SlicePtrOrNull<T> {
     }
 }
 
+impl crate::scalar::Scalar {
+    pub fn builder_arg<'a, 'b: 'a>(&'b self, builder: &mut cudarc::driver::LaunchArgs<'a>) {
+        use crate::scalar::Scalar;
+        match self {
+            Scalar::U8(v) => builder.arg(v),
+            Scalar::U32(v) => builder.arg(v),
+            Scalar::I64(v) => builder.arg(v),
+            Scalar::F32(v) => builder.arg(v),
+            Scalar::F64(v) => builder.arg(v),
+            Scalar::F16(v) => builder.arg(v),
+            Scalar::BF16(v) => builder.arg(v),
+        };
+    }
+}
+
 impl SlicePtrOrNull<usize> {
     pub fn params_from_layout(dev: &CudaDevice, l: &Layout) -> Result<Self> {
         let ds = if l.is_contiguous() {
@@ -395,7 +410,7 @@ impl Map1 for IndexSelect<'_> {
             CudaStorageSlice::U8(slice) => ("is_u8", slice_ptr(slice, ids_l.start_offset())),
             CudaStorageSlice::I64(slice) => ("is_i64", slice_ptr(slice, ids_l.start_offset())),
             _ => Err(CudaError::UnexpectedDType {
-                msg: "index_select ids should be u8 or u32",
+                msg: "index_select ids should be u8, u32, or i64",
                 expected: DType::U32,
                 got: self.0.dtype(),
             })
@@ -492,7 +507,7 @@ impl Map2InPlace for IndexAdd<'_> {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
         &self,
         dst: &mut CudaSlice<T>,
-        dst_shape: &Shape,
+        dst_l: &Layout,
         src: &CudaSlice<T>,
         src_l: &Layout,
         dev: &CudaDevice,
@@ -514,6 +529,10 @@ impl Map2InPlace for IndexAdd<'_> {
                 got: ids.dtype(),
             })?,
         };
+        let dst = match dst_l.contiguous_offsets() {
+            Some((o1, o2)) => dst.slice(o1..o2),
+            None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
+        };
         let src = match src_l.contiguous_offsets() {
            Some((o1, o2)) => src.slice(o1..o2),
            None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
@@ -521,7 +540,7 @@ impl Map2InPlace for IndexAdd<'_> {
         let left_sz: usize = src_l.dims()[..dim].iter().product();
         let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
         let src_dim_sz = src_l.dims()[dim];
-        let dst_dim_sz = dst_shape.dims()[dim];
+        let dst_dim_sz = dst_l.dims()[dim];
         let ids_dim_sz = ids_l.dims()[0];
         let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
         let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
@@ -529,7 +548,59 @@ impl Map2InPlace for IndexAdd<'_> {
         barg!(builder, ids);
         barg!(builder, ids_dim_sz);
         builder.arg(&src);
-        builder.arg(dst);
+        builder.arg(&dst);
         barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz);
         // SAFETY: ffi.
         unsafe { builder.launch(cfg) }.w()?;
         Ok(())
     }
 }
 
+struct Scatter<'a>(&'a CudaStorage, &'a Layout, usize);
+impl Map2InPlace for Scatter<'_> {
+    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+        &self,
+        dst: &mut CudaSlice<T>,
+        dst_l: &Layout,
+        src: &CudaSlice<T>,
+        src_l: &Layout,
+        dev: &CudaDevice,
+    ) -> Result<()> {
+        let ids = &self.0;
+        let ids_l = &self.1;
+        let dim = self.2;
+        let (ids_o1, _) = match ids_l.contiguous_offsets() {
+            Some(o12) => o12,
+            None => Err(crate::Error::RequiresContiguous { op: "scatter" }.bt())?,
+        };
+        let (name, (ids, _guard)) = match &ids.slice {
+            CudaStorageSlice::U32(slice) => ("s_u32", slice_ptr(slice, ids_o1)),
+            CudaStorageSlice::I64(slice) => ("s_i64", slice_ptr(slice, ids_o1)),
+            CudaStorageSlice::U8(slice) => ("s_u8", slice_ptr(slice, ids_o1)),
+            _ => Err(CudaError::UnexpectedDType {
+                msg: "scatter ids should be u8/u32/i64",
+                expected: DType::U32,
+                got: ids.dtype(),
+            })?,
+        };
+        let dst = match dst_l.contiguous_offsets() {
+            Some((o1, o2)) => dst.slice(o1..o2),
+            None => Err(crate::Error::RequiresContiguous { op: "scatter" }.bt())?,
+        };
+        let src = match src_l.contiguous_offsets() {
+            Some((o1, o2)) => src.slice(o1..o2),
+            None => Err(crate::Error::RequiresContiguous { op: "scatter" }.bt())?,
+        };
+        let left_sz: usize = src_l.dims()[..dim].iter().product();
+        let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
+        let src_dim_sz = src_l.dims()[dim];
+        let dst_dim_sz = dst_l.dims()[dim];
+        let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
+        let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
+        let mut builder = func.builder();
+        barg!(builder, ids);
+        builder.arg(&src);
+        builder.arg(&dst);
+        barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz);
+        // SAFETY: ffi.
+        unsafe { builder.launch(cfg) }.w()?;
@@ -542,7 +613,7 @@ impl Map2InPlace for ScatterAdd<'_> {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
         &self,
         dst: &mut CudaSlice<T>,
-        dst_shape: &Shape,
+        dst_l: &Layout,
         src: &CudaSlice<T>,
         src_l: &Layout,
         dev: &CudaDevice,
@@ -564,6 +635,10 @@ impl Map2InPlace for ScatterAdd<'_> {
                 got: ids.dtype(),
             })?,
         };
+        let dst = match dst_l.contiguous_offsets() {
+            Some((o1, o2)) => dst.slice(o1..o2),
+            None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
+        };
         let src = match src_l.contiguous_offsets() {
             Some((o1, o2)) => src.slice(o1..o2),
             None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
@@ -571,13 +646,13 @@ impl Map2InPlace for ScatterAdd<'_> {
         let left_sz: usize = src_l.dims()[..dim].iter().product();
         let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
         let src_dim_sz = src_l.dims()[dim];
-        let dst_dim_sz = dst_shape.dims()[dim];
+        let dst_dim_sz = dst_l.dims()[dim];
         let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
         let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
         let mut builder = func.builder();
         barg!(builder, ids);
         builder.arg(&src);
-        builder.arg(dst);
+        builder.arg(&dst);
         barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz);
         // SAFETY: ffi.
         unsafe { builder.launch(cfg) }.w()?;
@@ -1235,6 +1310,36 @@ impl BackendStorage for CudaStorage {
         &self.device
     }
 
+    fn const_set(&mut self, s: crate::scalar::Scalar, layout: &Layout) -> Result<()> {
+        let dev = &self.device;
+        let shape = layout.shape();
+        let dims = shape.dims();
+        let el_count = shape.elem_count();
+        let cfg = LaunchConfig::for_num_elems(el_count as u32);
+        let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
+        let src_o = layout.start_offset();
+        let ((src, _guard_src), kernel_name) = match &mut self.slice {
+            S::U8(s) => (slice_ptr(s, src_o), "const_set_u8"),
+            S::U32(s) => (slice_ptr(s, src_o), "const_set_u32"),
+            S::I64(s) => (slice_ptr(s, src_o), "const_set_i64"),
+            S::BF16(s) => (slice_ptr(s, src_o), "const_set_bf16"),
+            S::F16(s) => (slice_ptr(s, src_o), "const_set_f16"),
+            S::F32(s) => (slice_ptr(s, src_o), "const_set_f32"),
+            S::F64(s) => (slice_ptr(s, src_o), "const_set_f64"),
+        };
+
+        let func = dev.get_or_load_func(kernel_name, &kernels::FILL)?;
+        let mut builder = func.builder();
+        barg!(builder, el_count);
+        barg!(builder, dims.len());
+        ds.builder_arg(&mut builder);
+        s.builder_arg(&mut builder);
+        barg!(builder, src);
+        // SAFETY: ffi.
+        unsafe { builder.launch(cfg) }.w()?;
+        Ok(())
+    }
+
     fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
         let shape = layout.shape();
         let dims = shape.dims();
@@ -1793,20 +1898,29 @@ impl BackendStorage for CudaStorage {
         let slice = Gather(ids, ids_l, dim).map(&self.slice, &device, l)?;
         Ok(Self { slice, device })
     }
-    fn scatter_add(
-        &self,
+    fn scatter_set(
+        &mut self,
         l: &Layout,
         ids: &Self,
         ids_l: &Layout,
         src: &Self,
         src_l: &Layout,
         dim: usize,
-    ) -> Result<Self> {
+    ) -> Result<()> {
         let device = self.device().clone();
-        let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
-        self.copy_strided_src(&mut acc, 0, l)?;
-        ScatterAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
-        Ok(acc)
+        Scatter(ids, ids_l, dim).map(&mut self.slice, l, &src.slice, src_l, &device)
     }
+    fn scatter_add_set(
+        &mut self,
+        l: &Layout,
+        ids: &Self,
+        ids_l: &Layout,
+        src: &Self,
+        src_l: &Layout,
+        dim: usize,
+    ) -> Result<()> {
+        let device = self.device().clone();
+        ScatterAdd(ids, ids_l, dim).map(&mut self.slice, l, &src.slice, src_l, &device)
+    }
     fn index_add(
         &self,
@@ -1820,7 +1934,7 @@ impl BackendStorage for CudaStorage {
         let device = self.device().clone();
         let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
         self.copy_strided_src(&mut acc, 0, l)?;
-        IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
+        IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l, &src.slice, src_l, &device)?;
         Ok(acc)
     }
 
@@ -1,5 +1,5 @@
 /// Helper functions to plug cuda kernels in candle.
-use crate::{Layout, Result, Shape, WithDType};
+use crate::{Layout, Result, WithDType};
 pub use cudarc;
 use cudarc::driver::{CudaSlice, DeviceRepr, ValidAsZeroBits};
 
@@ -96,7 +96,7 @@ pub trait Map2InPlace {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
         &self,
         dst: &mut CudaSlice<T>,
-        dst_shape: &Shape,
+        dst_l: &Layout,
         src: &CudaSlice<T>,
         src_l: &Layout,
         dev: &CudaDevice,
@@ -105,19 +105,19 @@ pub trait Map2InPlace {
     fn map(
         &self,
         dst: &mut S,
-        dst_s: &Shape,
+        dst_l: &Layout,
         src: &S,
         src_l: &Layout,
         d: &CudaDevice,
     ) -> Result<()> {
         match (dst, src) {
-            (S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
-            (S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
-            (S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d),
-            (S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
-            (S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
-            (S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
-            (S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d),
+            (S::U8(dst), S::U8(src)) => self.f(dst, dst_l, src, src_l, d),
+            (S::U32(dst), S::U32(src)) => self.f(dst, dst_l, src, src_l, d),
+            (S::I64(dst), S::I64(src)) => self.f(dst, dst_l, src, src_l, d),
+            (S::BF16(dst), S::BF16(src)) => self.f(dst, dst_l, src, src_l, d),
+            (S::F16(dst), S::F16(src)) => self.f(dst, dst_l, src, src_l, d),
+            (S::F32(dst), S::F32(src)) => self.f(dst, dst_l, src, src_l, d),
+            (S::F64(dst), S::F64(src)) => self.f(dst, dst_l, src, src_l, d),
             _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
         }
     }
@@ -103,7 +103,63 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4:
     }
 }
 
-impl<S: NdArray> NdArray for Vec<S> {
+impl<S: WithDType> NdArray for Vec<S> {
+    fn shape(&self) -> Result<Shape> {
+        Ok(Shape::from(self.len()))
+    }
+
+    fn to_cpu_storage(&self) -> CpuStorage {
+        S::to_cpu_storage(self.as_slice())
+    }
+}
+
+impl<S: WithDType> NdArray for Vec<&[S]> {
+    fn shape(&self) -> Result<Shape> {
+        if self.is_empty() {
+            crate::bail!("empty array")
+        }
+        let n = self.len();
+        let m = self[0].len();
+        for v in self.iter() {
+            if v.len() != m {
+                crate::bail!("two elements have different len {m} {}", v.len())
+            }
+        }
+        Ok(Shape::from((n, m)))
+    }
+
+    fn to_cpu_storage(&self) -> CpuStorage {
+        let data = self.iter().copied().flatten().copied().collect::<Vec<_>>();
+        S::to_cpu_storage_owned(data)
+    }
+}
+
+impl<S: WithDType> NdArray for Vec<Vec<S>> {
+    fn shape(&self) -> Result<Shape> {
+        if self.is_empty() {
+            crate::bail!("empty array")
+        }
+        let n = self.len();
+        let m = self[0].len();
+        for v in self.iter() {
+            if v.len() != m {
+                crate::bail!("two elements have different len {m} {}", v.len())
+            }
+        }
+        Ok(Shape::from((n, m)))
+    }
+
+    fn to_cpu_storage(&self) -> CpuStorage {
+        let len: usize = self.iter().map(|v| v.len()).sum();
+        let mut dst = Vec::with_capacity(len);
+        for v in self.iter() {
+            dst.extend(v.iter().copied());
+        }
+        S::to_cpu_storage_owned(dst)
+    }
+}
+
+impl<S: WithDType> NdArray for Vec<Vec<Vec<S>>> {
     fn shape(&self) -> Result<Shape> {
         if self.is_empty() {
             crate::bail!("empty array")
@@ -120,9 +176,57 @@ impl<S: NdArray> NdArray for Vec<S> {
     }
 
     fn to_cpu_storage(&self) -> CpuStorage {
-        // This allocates intermediary memory and shouldn't be necessary.
-        let storages = self.iter().map(|v| v.to_cpu_storage()).collect::<Vec<_>>();
-        CpuStorage::concat(storages.as_slice()).unwrap()
+        if self.is_empty() {
+            return S::to_cpu_storage_owned(vec![]);
+        }
+        let len: usize = self
+            .iter()
+            .map(|v| v.iter().map(|v| v.len()).sum::<usize>())
+            .sum();
+        let mut dst = Vec::with_capacity(len);
+        for v1 in self.iter() {
+            for v2 in v1.iter() {
+                dst.extend(v2.iter().copied());
+            }
+        }
+        S::to_cpu_storage_owned(dst)
     }
 }
 
+impl<S: WithDType> NdArray for Vec<Vec<Vec<Vec<S>>>> {
+    fn shape(&self) -> Result<Shape> {
+        if self.is_empty() {
+            crate::bail!("empty array")
+        }
+        let shape0 = self[0].shape()?;
+        let n = self.len();
+        for v in self.iter() {
+            let shape = v.shape()?;
+            if shape != shape0 {
+                crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
+            }
+        }
+        Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
+    }
+
+    fn to_cpu_storage(&self) -> CpuStorage {
+        let len: usize = self
+            .iter()
+            .map(|v| {
+                v.iter()
+                    .map(|v| v.iter().map(|v| v.len()).sum::<usize>())
+                    .sum::<usize>()
+            })
+            .sum();
+        let mut dst = Vec::with_capacity(len);
+        for v1 in self.iter() {
+            for v2 in v1.iter() {
+                for v3 in v2.iter() {
+                    dst.extend(v3.iter().copied());
+                }
+            }
+        }
+        S::to_cpu_storage_owned(dst)
+    }
+}
+
@@ -292,23 +396,6 @@ impl Device {
         self.rand_normal_f64(mean.to_f64(), std.to_f64(), shape, T::DTYPE)
     }
 
-    pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
-        match self {
-            Device::Cpu => {
-                let storage = CpuDevice.ones_impl(shape, dtype)?;
-                Ok(Storage::Cpu(storage))
-            }
-            Device::Cuda(device) => {
-                let storage = device.ones_impl(shape, dtype)?;
-                Ok(Storage::Cuda(storage))
-            }
-            Device::Metal(device) => {
-                let storage = device.ones_impl(shape, dtype)?;
-                Ok(Storage::Metal(storage))
-            }
-        }
-    }
-
     pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
         match self {
             Device::Cpu => {
@@ -107,6 +107,7 @@ pub trait WithDType:
 
     fn from_f64(v: f64) -> Self;
     fn to_f64(self) -> f64;
+    fn to_scalar(self) -> crate::scalar::Scalar;
     fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_>;
     fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
 
@@ -131,6 +132,10 @@ macro_rules! with_dtype {
                 $to_f64(self)
             }
 
+            fn to_scalar(self) -> crate::scalar::Scalar {
+                crate::scalar::Scalar::$dtype(self)
+            }
+
             fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_> {
                 CpuStorageRef::$dtype(data)
             }
@@ -175,7 +180,7 @@ with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64);
 with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64);
 with_dtype!(f64, F64, |v: f64| v, |v: f64| v);
 
-pub trait IntDType: WithDType {
+pub trait IntDType: WithDType + num_traits::Bounded {
     fn is_true(&self) -> bool;
     fn as_usize(&self) -> usize;
 }
@@ -37,6 +37,10 @@ impl crate::backend::BackendStorage for CudaStorage {
         fail!()
     }
 
+    fn const_set(&mut self, _: crate::scalar::Scalar, _: &Layout) -> Result<()> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     fn to_cpu_storage(&self) -> Result<CpuStorage> {
         Err(Error::NotCompiledWithCudaSupport)
     }
@@ -124,15 +128,27 @@ impl crate::backend::BackendStorage for CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    fn scatter_add(
-        &self,
+    fn scatter_set(
+        &mut self,
         _: &Layout,
         _: &Self,
         _: &Layout,
         _: &Self,
         _: &Layout,
         _: usize,
-    ) -> Result<Self> {
+    ) -> Result<()> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
+    fn scatter_add_set(
+        &mut self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: usize,
+    ) -> Result<()> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
@@ -214,10 +230,6 @@ impl crate::backend::BackendDevice for CudaDevice {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
-        Err(Error::NotCompiledWithCudaSupport)
-    }
-
     unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
         Err(Error::NotCompiledWithCudaSupport)
    }
@@ -41,6 +41,10 @@ impl crate::backend::BackendStorage for MetalStorage {
         fail!()
     }
 
+    fn const_set(&mut self, _: crate::scalar::Scalar, _: &Layout) -> Result<()> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
     fn to_cpu_storage(&self) -> Result<CpuStorage> {
         Err(Error::NotCompiledWithMetalSupport)
     }
@@ -128,15 +132,27 @@ impl crate::backend::BackendStorage for MetalStorage {
         Err(Error::NotCompiledWithMetalSupport)
     }
 
-    fn scatter_add(
-        &self,
+    fn scatter_set(
+        &mut self,
         _: &Layout,
         _: &Self,
         _: &Layout,
         _: &Self,
         _: &Layout,
         _: usize,
-    ) -> Result<Self> {
+    ) -> Result<()> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    fn scatter_add_set(
+        &mut self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: usize,
+    ) -> Result<()> {
         Err(Error::NotCompiledWithMetalSupport)
     }
 
@@ -218,10 +234,6 @@ impl crate::backend::BackendDevice for MetalDevice {
         Err(Error::NotCompiledWithMetalSupport)
     }
 
-    fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
     unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
         Err(Error::NotCompiledWithMetalSupport)
     }
@@ -226,8 +226,8 @@ where
 /// assert_eq!(c.to_vec1::<f32>()?, &[1., 4.]);
 ///
 /// let d = a.i((2.., ..))?;
-/// assert_eq!(c.shape().dims(), &[2]);
-/// assert_eq!(c.to_vec1::<f32>()?, &[1., 4.]);
+/// assert_eq!(d.shape().dims(), &[1, 3]);
+/// assert_eq!(d.to_vec2::<f32>()?, &[[6., 7., 8.]]);
 /// # Ok::<(), candle_core::Error>(())
 /// ```
 fn i(&self, (a, b): (A, B)) -> Result<Tensor, Error> {
@ -413,6 +413,100 @@ impl BackendStorage for MetalStorage {
|
||||
self.binary(name, rhs, lhs_l, rhs_l)
|
||||
}
|
||||
|
||||
fn const_set(&mut self, s: crate::scalar::Scalar, l: &Layout) -> Result<()> {
|
||||
use crate::scalar::Scalar;
|
||||
fn set<S: crate::WithDType + candle_metal_kernels::utils::EncoderParam>(
|
||||
self_: &mut MetalStorage,
|
||||
s: S,
|
||||
l: &Layout,
|
||||
) -> Result<()> {
|
||||
let device = self_.device();
|
||||
let dtype = self_.dtype;
|
||||
let shape = l.shape();
|
||||
let el_count = shape.elem_count();
|
||||
let command_buffer = device.command_buffer()?;
|
||||
command_buffer.set_label("const-set");
|
||||
let dst = buffer_o(&self_.buffer, l, self_.dtype);
|
||||
|
||||
match (el_count % 2, dtype, l.is_contiguous()) {
|
||||
(0, DType::BF16 | DType::F16, true) => {
|
||||
use candle_metal_kernels::unary::contiguous_tiled;
|
||||
let kernel_name = match dtype {
|
||||
DType::F16 => contiguous_tiled::const_set::HALF,
|
||||
DType::BF16 => contiguous_tiled::const_set::BFLOAT,
|
||||
_ => crate::bail!("internal bug in const_set"),
|
||||
};
|
||||
candle_metal_kernels::call_const_set_contiguous_tiled(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
el_count,
|
||||
s,
|
||||
dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
(_, _, true) => {
|
||||
use candle_metal_kernels::unary::contiguous;
|
||||
let kernel_name = match dtype {
|
||||
DType::F16 => contiguous::const_set::HALF,
|
||||
DType::BF16 => contiguous::const_set::BFLOAT,
|
||||
DType::F32 => contiguous::const_set::FLOAT,
|
||||
DType::I64 => contiguous::const_set::I64,
|
||||
DType::U32 => contiguous::const_set::U32,
|
||||
DType::U8 => contiguous::const_set::U8,
|
||||
DType::F64 => crate::bail!("unsupported const-set f64"),
|
||||
};
|
||||
candle_metal_kernels::call_const_set_contiguous(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
el_count,
|
||||
s,
|
||||
dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
(_, _, false) => {
|
||||
use candle_metal_kernels::unary::strided;
|
||||
let kernel_name = match dtype {
|
||||
DType::F16 => strided::const_set::HALF,
|
||||
DType::BF16 => strided::const_set::BFLOAT,
|
||||
DType::F32 => strided::const_set::FLOAT,
|
||||
DType::I64 => strided::const_set::I64,
|
||||
DType::U32 => strided::const_set::U32,
|
||||
DType::U8 => strided::const_set::U8,
|
||||
DType::F64 => crate::bail!("unsupported const-set f64"),
|
||||
};
|
||||
candle_metal_kernels::call_const_set_strided(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
l.dims(),
|
||||
s,
|
||||
l.stride(),
|
||||
dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
match (self.dtype, s) {
|
||||
(DType::U8, Scalar::U8(s)) => set(self, s, l),
|
||||
(DType::U32, Scalar::U32(s)) => set(self, s, l),
|
||||
(DType::I64, Scalar::I64(s)) => set(self, s, l),
|
||||
(DType::F16, Scalar::F16(s)) => set(self, s, l),
|
||||
(DType::BF16, Scalar::BF16(s)) => set(self, s, l),
|
||||
(DType::F32, Scalar::F32(s)) => set(self, s, l),
|
||||
(DType::F64, Scalar::F64(s)) => set(self, s, l),
|
||||
_ => crate::bail!("dtype mismatch, expected {:?}, got {:?}", self.dtype, s),
|
||||
}
|
||||
}
|
||||
|
||||
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
|
||||
let device = self.device();
|
||||
let shape = layout.shape();
|
||||
@ -1332,18 +1426,65 @@ impl BackendStorage for MetalStorage {
|
||||
Ok(Self::new(buffer, device.clone(), dst_el, dtype))
|
||||
}
|
||||
|
||||
fn scatter_add(
|
||||
&self,
|
||||
fn scatter_set(
|
||||
&mut self,
|
||||
l: &Layout,
|
||||
ids: &Self,
|
||||
ids_l: &Layout,
|
||||
src: &Self,
|
||||
src_l: &Layout,
|
||||
dim: usize,
|
||||
) -> Result<Self> {
|
||||
let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
|
||||
self.copy_strided_src(&mut acc, 0, l)?;
|
||||
if !ids_l.is_contiguous() || !src_l.is_contiguous() {
|
||||
) -> Result<()> {
|
||||
if !l.is_contiguous() || !ids_l.is_contiguous() || !src_l.is_contiguous() {
|
||||
return Err(crate::Error::RequiresContiguous { op: "scatter" }.bt());
|
||||
};
|
||||
let name = match (ids.dtype, self.dtype) {
|
||||
(DType::U8, DType::F32) => "s_u8_f32",
|
||||
(DType::U8, DType::F16) => "s_u8_f16",
|
||||
(DType::U8, DType::BF16) => "s_u8_bf16",
|
||||
(DType::U32, DType::U32) => "s_u32_u32",
|
||||
(DType::U32, DType::F32) => "s_u32_f32",
|
||||
(DType::U32, DType::F16) => "s_u32_f16",
|
||||
(DType::U32, DType::BF16) => "s_u32_bf16",
|
||||
(DType::I64, DType::F32) => "s_i64_f32",
|
||||
(DType::I64, DType::F16) => "s_i64_f16",
|
||||
(DType::I64, DType::BF16) => "s_i64_bf16",
|
||||
_ => Err(MetalError::UnexpectedDType {
|
||||
msg: "scatter ids should be u8/u32/i64",
|
||||
expected: DType::U32,
|
||||
got: ids.dtype(),
|
||||
})?,
|
||||
};
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
let dst = buffer_o(&self.buffer, l, self.dtype);
|
||||
let src = buffer_o(&src.buffer, src_l, src.dtype);
|
||||
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
|
||||
candle_metal_kernels::call_scatter(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
name,
|
||||
src_l.dims(),
|
||||
l.dims(),
|
||||
dim,
|
||||
src,
|
||||
ids,
|
||||
dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn scatter_add_set(
|
||||
&mut self,
|
||||
l: &Layout,
|
||||
ids: &Self,
|
||||
ids_l: &Layout,
|
||||
src: &Self,
|
||||
src_l: &Layout,
|
||||
dim: usize,
|
||||
) -> Result<()> {
|
||||
if !l.is_contiguous() || !ids_l.is_contiguous() || !src_l.is_contiguous() {
|
||||
return Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt());
|
||||
};
|
||||
let name = match (ids.dtype, self.dtype) {
|
||||
@ -1364,9 +1505,10 @@ impl BackendStorage for MetalStorage {
|
||||
})?,
|
||||
};
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
let dst = buffer_o(&self.buffer, l, self.dtype);
|
||||
let src = buffer_o(&src.buffer, src_l, src.dtype);
|
||||
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
|
||||
candle_metal_kernels::call_scatter_add(
|
||||
candle_metal_kernels::call_scatter(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
@ -1376,10 +1518,10 @@ impl BackendStorage for MetalStorage {
|
||||
dim,
|
||||
src,
|
||||
ids,
|
||||
&acc.buffer,
|
||||
dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
Ok(acc)
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||
@@ -1513,50 +1655,32 @@
let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?;
let command_buffer = self.device.command_buffer()?;
command_buffer.set_label("matmul");
if self.dtype == DType::BF16 {
candle_metal_kernels::call_mlx_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
candle_metal_kernels::GemmDType::BF16,
(b, m, n, k),
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
} else {
let dtype = match self.dtype {
DType::F32 => candle_metal_kernels::GemmDType::F32,
DType::F16 => candle_metal_kernels::GemmDType::F16,
DType::BF16 => candle_metal_kernels::GemmDType::BF16,
dtype => {
return Err(MetalError::Message(format!(
"mlx matmul doesn't support {dtype:?}"
))
.into())
}
};
candle_metal_kernels::call_mlx_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
dtype,
(b, m, n, k),
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
}
let dtype = match self.dtype {
DType::F32 => candle_metal_kernels::GemmDType::F32,
DType::F16 => candle_metal_kernels::GemmDType::F16,
DType::BF16 => candle_metal_kernels::GemmDType::BF16,
dtype => {
return Err(
MetalError::Message(format!("mlx matmul doesn't support {dtype:?}")).into(),
)
}
};
candle_metal_kernels::call_mlx_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
dtype,
(b, m, n, k),
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;

Ok(Self::new(
buffer,
self.device.clone(),
@@ -1965,40 +2089,6 @@ impl BackendDevice for MetalDevice {
))
}

fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
let name = match dtype {
DType::U8 => "fill_u8",
DType::U32 => "fill_u32",
DType::I64 => "fill_i64",
DType::F16 => "fill_f16",
DType::BF16 => "fill_bf16",
DType::F32 => "fill_f32",
DType::F64 => {
let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
return self.storage_from_cpu_storage(&cpu_storage);
}
};
let buffer = self.new_buffer(shape.elem_count(), dtype, "alloc-ones")?;
let command_buffer = self.command_buffer()?;
candle_metal_kernels::call_const_fill(
&self.device,
&command_buffer,
&self.kernels,
name,
shape.elem_count(),
&buffer,
1.,
)
.map_err(MetalError::from)?;

Ok(MetalStorage::new(
buffer,
self.clone(),
shape.elem_count(),
dtype,
))
}

fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
let (count, buffer) = match T::cpu_storage_ref(s) {
CpuStorageRef::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),
@@ -80,6 +80,7 @@ pub enum Op {
Reduce(Tensor, ReduceOp, Vec<usize>),
Matmul(Tensor, Tensor),
Gather(Tensor, Tensor, usize),
Scatter(Tensor, Tensor, Tensor, usize),
ScatterAdd(Tensor, Tensor, Tensor, usize),
IndexSelect(Tensor, Tensor, usize),
IndexAdd(Tensor, Tensor, Tensor, usize),
@@ -1,6 +1,74 @@
//! TensorScalar Enum and Trait
//!
use crate::{Result, Tensor, WithDType};
use crate::{DType, Result, Tensor, WithDType};
use half::{bf16, f16};

#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Scalar {
U8(u8),
U32(u32),
I64(i64),
BF16(bf16),
F16(f16),
F32(f32),
F64(f64),
}

impl<T: WithDType> From<T> for Scalar {
fn from(value: T) -> Self {
value.to_scalar()
}
}

impl Scalar {
pub fn zero(dtype: DType) -> Self {
match dtype {
DType::U8 => Scalar::U8(0),
DType::U32 => Scalar::U32(0),
DType::I64 => Scalar::I64(0),
DType::BF16 => Scalar::BF16(bf16::ZERO),
DType::F16 => Scalar::F16(f16::ZERO),
DType::F32 => Scalar::F32(0.0),
DType::F64 => Scalar::F64(0.0),
}
}

pub fn one(dtype: DType) -> Self {
match dtype {
DType::U8 => Scalar::U8(1),
DType::U32 => Scalar::U32(1),
DType::I64 => Scalar::I64(1),
DType::BF16 => Scalar::BF16(bf16::ONE),
DType::F16 => Scalar::F16(f16::ONE),
DType::F32 => Scalar::F32(1.0),
DType::F64 => Scalar::F64(1.0),
}
}

pub fn dtype(&self) -> DType {
match self {
Scalar::U8(_) => DType::U8,
Scalar::U32(_) => DType::U32,
Scalar::I64(_) => DType::I64,
Scalar::BF16(_) => DType::BF16,
Scalar::F16(_) => DType::F16,
Scalar::F32(_) => DType::F32,
Scalar::F64(_) => DType::F64,
}
}

pub fn to_f64(&self) -> f64 {
match self {
Scalar::U8(v) => *v as f64,
Scalar::U32(v) => *v as f64,
Scalar::I64(v) => *v as f64,
Scalar::BF16(v) => v.to_f64(),
Scalar::F16(v) => v.to_f64(),
Scalar::F32(v) => *v as f64,
Scalar::F64(v) => *v,
}
}
}

pub enum TensorScalar {
Tensor(Tensor),
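A quick usage sketch for the new `Scalar` type, assuming the `scalar` module stays publicly exported as `candle_core::scalar`: `zero`/`one` build a typed constant, `dtype` reports its type, and `to_f64` widens it for inspection.

```rust
use candle_core::{scalar::Scalar, DType};

fn main() {
    // Typed constants for any supported dtype.
    let one = Scalar::one(DType::BF16);
    assert_eq!(one.dtype(), DType::BF16);
    assert_eq!(one.to_f64(), 1.0);
    // Any `WithDType` value converts through the blanket `From` impl.
    let s: Scalar = 3u32.into();
    assert_eq!(s.to_f64(), 3.0);
}
```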
@@ -1,5 +1,6 @@
use crate::backend::BackendStorage;
use crate::op::{self, CmpOp, ReduceOp};
use crate::scalar::Scalar;
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
use crate::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};

@@ -73,6 +74,14 @@ impl Storage {
}
}

pub(crate) fn const_set(&mut self, v: Scalar, l: &Layout) -> Result<()> {
match self {
Storage::Cpu(storage) => storage.const_set(v, l),
Storage::Cuda(storage) => storage.const_set(v, l),
Storage::Metal(storage) => storage.const_set(v, l),
}
}

pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
@@ -619,32 +628,56 @@ impl Storage {
}
}

pub(crate) fn scatter_add(
&self,
pub(crate) fn scatter_set(
&mut self,
l: &Layout,
indexes: &Self,
indexes_l: &Layout,
source: &Self,
source_l: &Layout,
d: usize,
) -> Result<Self> {
) -> Result<()> {
self.same_device(indexes, "scatter-set")?;
self.same_device(source, "scatter-set")?;
match (self, indexes, source) {
(Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
}
(Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
}
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
}
_ => unreachable!(),
}
Ok(())
}

pub(crate) fn scatter_add(
&mut self,
l: &Layout,
indexes: &Self,
indexes_l: &Layout,
source: &Self,
source_l: &Layout,
d: usize,
) -> Result<()> {
self.same_device(indexes, "scatter-add")?;
self.same_device(source, "scatter-add")?;
match (self, indexes, source) {
(Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Cpu(storage))
s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
}
(Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Cuda(storage))
s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
}
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Metal(storage))
s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
}
_ => unreachable!(),
}
Ok(())
}

pub(crate) fn index_add(
@@ -3,7 +3,7 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BackpropOp, BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
use crate::scalar::TensorOrScalar;
use crate::shape::{Dim, Dims};
use crate::shape::{Dim, Dims, ShapeWithOneHole};
use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::{Arc, RwLock};

@@ -185,7 +185,9 @@ impl Tensor {
) -> Result<Self> {
let none = BackpropOp::none();
let shape = shape.into();
let storage = device.ones(&shape, dtype)?;
let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
let layout = Layout::contiguous(shape.clone());
storage.const_set(crate::scalar::Scalar::one(dtype), &layout)?;
Ok(from_storage(storage, shape, none, is_variable))
}

@@ -202,6 +204,18 @@ impl Tensor {
Self::ones_impl(shape, dtype, device, false)
}

pub fn const_set(&self, value: crate::scalar::Scalar) -> Result<()> {
self.storage_mut().const_set(value, self.layout())
}

pub fn zero_set(&self) -> Result<()> {
self.const_set(crate::scalar::Scalar::zero(self.dtype()))
}

pub fn one_set(&self) -> Result<()> {
self.const_set(crate::scalar::Scalar::one(self.dtype()))
}

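A minimal sketch of the new in-place fill helpers, assuming a CPU device: `const_set` writes a `Scalar` through the tensor's layout, with `zero_set`/`one_set` as shorthands. These mutate the existing storage, so no gradient op is recorded.

```rust
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let t = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    t.one_set()?; // fill with ones in place
    t.const_set(0.5f32.into())?; // fill with an arbitrary scalar
    assert_eq!(t.to_vec2::<f32>()?, [[0.5; 3]; 2]);
    Ok(())
}
```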
/// Creates a new tensor filled with ones with same shape, dtype, and device as the other tensor.
///
/// ```rust
@@ -368,8 +382,7 @@ impl Tensor {
Self::new_impl(array, shape, device, false)
}

/// Returns a new tensor with all the elements having the same specified value. Note that
/// the tensor is not contiguous so you would have to call `.contiguous()` on it if needed.
/// Returns a new tensor with all the elements having the same specified value.
///```rust
/// use candle_core::{Tensor, Device};
/// let a = Tensor::full(3.5, (2, 4), &Device::Cpu)?;
@@ -384,7 +397,12 @@ impl Tensor {
shape: S,
device: &Device,
) -> Result<Self> {
Self::from_vec_impl(vec![value], (), device, false)?.broadcast_as(shape)
let none = BackpropOp::none();
let shape = shape.into();
let mut storage = unsafe { device.alloc_uninit(&shape, D::DTYPE)? };
let layout = Layout::contiguous(shape.clone());
storage.const_set(value.to_scalar(), &layout)?;
Ok(from_storage(storage, shape, none, false))
}

/// Creates a new 1D tensor from an iterator.
@@ -452,17 +470,13 @@ impl Tensor {
Self::from_vec_impl(data, len, device, false)
}

pub(crate) fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>(
pub(crate) fn from_vec_impl<S: ShapeWithOneHole, D: crate::WithDType>(
data: Vec<D>,
shape: S,
device: &Device,
is_variable: bool,
) -> Result<Self> {
let shape = shape.into();
let buffer_size = data.len();
if buffer_size != shape.elem_count() {
return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
}
let shape = shape.into_shape(data.len())?;
let storage = device.storage_owned(data)?;
let none = BackpropOp::none();
Ok(from_storage(storage, shape, none, is_variable))
@@ -481,7 +495,7 @@ impl Tensor {
/// ]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn from_vec<S: Into<Shape>, D: crate::WithDType>(
pub fn from_vec<S: ShapeWithOneHole, D: crate::WithDType>(
data: Vec<D>,
shape: S,
device: &Device,
@@ -502,17 +516,12 @@ impl Tensor {
/// ]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
pub fn from_slice<S: ShapeWithOneHole, D: crate::WithDType>(
array: &[D],
shape: S,
device: &Device,
) -> Result<Self> {
let shape = shape.into();
let n: usize = shape.elem_count();
let buffer_size: usize = array.len();
if buffer_size != n {
return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
}
let shape = shape.into_shape(array.len())?;
let storage = device.storage_from_slice(array)?;
let none = BackpropOp::none();
Ok(from_storage(storage, shape, none, false))
@@ -1226,6 +1235,83 @@ impl Tensor {
Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
}

/// Computes the dot product of two 1D tensors.
///
/// - If inputs are 1D vectors (`[n]`), returns their scalar dot product.
/// - Panics if shapes are not compatible
/// - Not supported for integer dtypes
///
/// # Example (vectors)
/// ```rust
/// use candle_core::{Tensor, Device};
/// let t1 = Tensor::new(&[1.0, 2.0, 3.0], &Device::Cpu)?;
/// let t2 = Tensor::new(&[4.0, 5.0, 6.0], &Device::Cpu)?;
/// let res = t1.dot(&t2)?;
/// assert_eq!(res.to_scalar::<f64>()?, 32.);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn dot(&self, rhs: &Self) -> Result<Self> {
if self.dims().len() != 1 || rhs.dims().len() != 1 {
return Err(Error::ShapeMismatchBinaryOp {
lhs: self.shape().clone(),
rhs: rhs.shape().clone(),
op: "dot",
});
}

(self * rhs).and_then(|ret| ret.sum_all())
}

/// Computes the **Frobenius norm** (L2 norm of all elements) of the tensor.
/// - Output is `sqrt(sum(x^2))`.
/// - Always returns a scalar (`[]` shape).
///
/// # Example
/// ```rust
/// use candle_core::{Tensor, Device};
/// let t = Tensor::new(&[[3., 4.], [0., 0.]], &Device::Cpu)?;
/// let norm = t.norm()?;
/// assert_eq!(norm.to_scalar::<f64>()?, 5.);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn norm(&self) -> Result<Self> {
if self.dtype().is_int() {
bail!("norm not supported for integer dtypes");
}

self.sqr().and_then(|x| x.sum_all()).and_then(|x| x.sqrt())
}

/// Performs strict matrix-vector multiplication (`[m, n] * [n] = [m]`).
///
/// - If `self` is a matrix (`[m, n]`) and `rhs` is a vector (`[n]`), returns a vector (`[m]`).
/// - **No broadcasting**: Panics if `self` is not 2D or if `rhs` is not 1D with matching size.
///
/// # Example
/// ```rust
/// use candle_core::{Tensor, Device};
/// let mat = Tensor::new(&[[1., 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
/// let vec = Tensor::new(&[1., 1., 1.], &Device::Cpu)?;
/// let res = mat.mv(&vec)?;
/// assert_eq!(res.to_vec1::<f64>()?, [6., 15.]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn mv(&self, rhs: &Self) -> Result<Self> {
// Strict shape checks
let lhs_dims = self.dims();
let rhs_dims = rhs.dims();
if lhs_dims.len() != 2 || rhs_dims.len() != 1 || lhs_dims[1] != rhs_dims[0] {
return Err(Error::ShapeMismatchBinaryOp {
lhs: self.shape().clone(),
rhs: rhs.shape().clone(),
op: "mv",
});
}

// Direct matmul after ensuring rhs is column vector
self.matmul(&rhs.unsqueeze(1)?)?.squeeze(1)
}

/// Returns the matrix-multiplication of the input tensor with the other provided tensor.
///
/// # Arguments
@@ -1349,8 +1435,7 @@ impl Tensor {
self.index_select(ids, 0)
}

pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "scatter-add")?;
fn scatter_checks(&self, indexes: &Self, source: &Self, dim: usize) -> Result<()> {
let source_dims = source.dims();
let self_dims = self.dims();
let mismatch = if source_dims.len() != self_dims.len() {
@@ -1367,7 +1452,7 @@ impl Tensor {
};
if mismatch {
Err(Error::ShapeMismatchBinaryOp {
op: "scatter-add (self, src)",
op: "scatter (self, src)",
lhs: self.shape().clone(),
rhs: source.shape().clone(),
}
@@ -1375,13 +1460,44 @@ impl Tensor {
}
if indexes.dims() != source.dims() {
Err(Error::ShapeMismatchBinaryOp {
op: "scatter-add (indexes, src)",
op: "scatter (indexes, src)",
lhs: indexes.shape().clone(),
rhs: source.shape().clone(),
}
.bt())?
}
let storage = self.storage().scatter_add(
Ok(())
}

pub fn scatter<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "scatter")?;
self.scatter_checks(indexes, source, dim)?;
let shape = self.shape();
let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
let layout = Layout::contiguous(shape);
storage.scatter_set(
&layout,
&indexes.storage(),
indexes.layout(),
&source.storage(),
source.layout(),
dim,
)?;
let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
Op::Scatter(t1, t2, t3, dim)
});
Ok(from_storage(storage, self.shape(), op, false))
}

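A short sketch of the difference between the two out-of-place variants, assuming a CPU device: along `dim`, `scatter` overwrites `self[ids[i][j]][j]` with `src[i][j]`, while `scatter_add` accumulates into it.

```rust
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::Cpu;
    let init = Tensor::ones((3, 2), DType::F32, &dev)?;
    let ids = Tensor::new(&[[0u32, 1], [2, 0]], &dev)?;
    let src = Tensor::new(&[[10f32, 20.], [30., 40.]], &dev)?;
    // scatter: destination rows picked by `ids` are overwritten.
    let hs = init.scatter(&ids, &src, 0)?;
    assert_eq!(hs.to_vec2::<f32>()?, [[10., 40.], [1., 20.], [30., 1.]]);
    // scatter_add: the same rows are accumulated on top of the ones.
    let hs = init.scatter_add(&ids, &src, 0)?;
    assert_eq!(hs.to_vec2::<f32>()?, [[11., 41.], [2., 21.], [31., 2.]]);
    Ok(())
}
```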
pub fn scatter_set<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> {
if self.same_storage(source) {
crate::bail!("cannot use slice_set when self and src share their storage")
}
let dim = dim.to_index(self.shape(), "scatter-set")?;
self.scatter_checks(indexes, source, dim)?;
self.storage_mut().scatter_set(
self.layout(),
&indexes.storage(),
indexes.layout(),
@@ -1389,12 +1505,48 @@
source.layout(),
dim,
)?;
Ok(())
}

pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "scatter-add")?;
self.scatter_checks(indexes, source, dim)?;
let shape = self.shape();
let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
let layout = Layout::contiguous(shape);
storage.scatter_add(
&layout,
&indexes.storage(),
indexes.layout(),
&source.storage(),
source.layout(),
dim,
)?;
let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
Op::ScatterAdd(t1, t2, t3, dim)
});
Ok(from_storage(storage, self.shape(), op, false))
}

pub fn scatter_add_set<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> {
if self.same_storage(source) {
crate::bail!("cannot use slice_set when self and src share their storage")
}
let dim = dim.to_index(self.shape(), "scatter-add-set")?;
self.scatter_checks(indexes, source, dim)?;
self.storage_mut().scatter_add(
self.layout(),
&indexes.storage(),
indexes.layout(),
&source.storage(),
source.layout(),
dim,
)?;
Ok(())
}

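The `_set` variants above mutate the receiver instead of allocating a result; since no `BackpropOp` is recorded they are not usable with back-propagation. A minimal sketch, assuming a CPU device:

```rust
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::Cpu;
    let t = Tensor::ones((3, 2), DType::F32, &dev)?;
    let ids = Tensor::new(&[[0u32, 1]], &dev)?;
    let src = Tensor::new(&[[5f32, 6.]], &dev)?;
    t.scatter_set(&ids, &src, 0)?; // overwrite in place
    assert_eq!(t.to_vec2::<f32>()?, [[5., 1.], [1., 6.], [1., 1.]]);
    t.scatter_add_set(&ids, &src, 0)?; // accumulate in place
    assert_eq!(t.to_vec2::<f32>()?, [[10., 1.], [1., 12.], [1., 1.]]);
    Ok(())
}
```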
/// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension.
pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: D, start: usize) -> Result<Self> {
let dim = dim.to_index(self.shape(), "slice-scatter")?;
@@ -2197,7 +2349,7 @@ impl Tensor {
///
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn reshape<S: crate::shape::ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
pub fn reshape<S: ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
let shape = s.into_shape(self.elem_count())?;
if shape.elem_count() != self.elem_count() {
return Err(Error::ShapeMismatchBinaryOp {

@@ -241,7 +241,7 @@ impl Tensor {
/// `self` and `src` must have the same shape except on dimension `dim` where the `self` size
/// has to be greater than or equal to `offset` plus the `src` size.
///
/// Note that this modifies `self` in place and as such is not compatibel with
/// Note that this modifies `self` in place and as such is not compatible with
/// back-propagation.
pub fn slice_set<D: Dim>(&self, src: &Self, dim: D, offset: usize) -> Result<()> {
let dim = dim.to_index(self.shape(), "slice-set")?;

@@ -82,6 +82,26 @@ fn broadcast_matmul(device: &Device) -> Result<()> {
Ok(())
}

#[test]
fn tensor_dot() -> Result<()> {
let lhs = Tensor::new(&[1., 2., 3.], &Device::Cpu)?;
let rhs = Tensor::new(&[4., 5., 6.], &Device::Cpu)?;
let expected = Tensor::new(32., &Device::Cpu)?;
let dot_ret = lhs.dot(&rhs)?;
candle_core::test_utils::assert_tensor_eq(&dot_ret, &expected)?;
Ok(())
}

#[test]
fn tensor_mv() -> Result<()> {
let mat = Tensor::new(&[[1., 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
let vec = Tensor::new(&[1., 1., 1.], &Device::Cpu)?;
let expected = Tensor::new(&[6., 15.], &Device::Cpu)?;
let mv_ret = mat.mv(&vec)?;
candle_core::test_utils::assert_tensor_eq(&mv_ret, &expected)?;
Ok(())
}

// https://github.com/huggingface/candle/issues/1948
fn squeeze_mm(device: &Device) -> Result<()> {
let seq_len = 8_usize;

@@ -25,10 +25,12 @@ fn ones(device: &Device) -> Result<()> {
Tensor::ones((2, 3), DType::F32, device)?.to_vec2::<f32>()?,
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
);
assert_eq!(
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
);
if !device.is_metal() {
assert_eq!(
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
);
}
assert_eq!(
Tensor::ones((2, 3), DType::F16, device)?.to_vec2::<half::f16>()?,
[
@@ -63,6 +65,26 @@ fn ones(device: &Device) -> Result<()> {
}

fn full(device: &Device) -> Result<()> {
let tensor = Tensor::zeros((3, 4), DType::U32, device)?;
tensor.const_set(42u32.into())?;
assert_eq!(
tensor.to_vec2::<u32>()?,
[[42, 42, 42, 42], [42, 42, 42, 42], [42, 42, 42, 42]]
);
tensor.i((.., 2))?.const_set(1337u32.into())?;
assert_eq!(
tensor.to_vec2::<u32>()?,
[[42, 42, 1337, 42], [42, 42, 1337, 42], [42, 42, 1337, 42]]
);
tensor.i((2, ..))?.const_set(1u32.into())?;
assert_eq!(
tensor.to_vec2::<u32>()?,
[[42, 42, 1337, 42], [42, 42, 1337, 42], [1, 1, 1, 1]]
);
Ok(())
}

fn const_set(device: &Device) -> Result<()> {
assert_eq!(
Tensor::full(42u32, (2, 3), device)?.to_vec2::<u32>()?,
[[42, 42, 42], [42, 42, 42]],
@@ -823,9 +845,37 @@ fn embeddings(device: &Device) -> Result<()> {
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
let hs = t.index_select(&ids.to_dtype(DType::I64)?, 0)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
let ids = Tensor::new(&[u32::MAX, 2u32, u32::MAX], device)?;
let hs = t.index_select(&ids, 0)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 0.0], [4.0, 5.0], [0.0, 0.0]]);
Ok(())
}

#[test]
fn index_select_fail() -> Result<()> {
// Check that an error is properly reported on out of bounds.
let ids = Tensor::new(&[4u32, 2u32, 1u32], &Device::Cpu)?;
let t = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], &Device::Cpu)?;
let hs = t.index_select(&ids, 0);
assert!(hs.is_err());
Ok(())
}

// The test below triggers an unwinding panic as there is a panic within the
// #[cfg(feature = "cuda")]
// #[test]
// #[should_panic]
// fn index_select_fail_gpu() {
// // Check that a panic happens for out of bounds in cuda
// if let Ok(device) = Device::new_cuda(0) {
// if let Ok(ids) = Tensor::new(&[4u32, 2u32, 1u32], &device) {
// if let Ok(t) = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], &device) {
// let _ = t.index_select(&ids, 0);
// }
// }
// }
// }

fn cmp(device: &Device) -> Result<()> {
let t1 = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], device)?;
let t2 = Tensor::new(&[[1f32, 0f32], [3f32, 3f32], [4f32, 7f32]], device)?;
@@ -980,7 +1030,7 @@ fn slice_scatter(device: &Device) -> Result<()> {
Ok(())
}

fn scatter_add(device: &Device) -> Result<()> {
fn scatter(device: &Device) -> Result<()> {
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
assert_eq!(
t.to_vec2::<f32>()?,
@@ -1004,6 +1054,17 @@ fn scatter_add(device: &Device) -> Result<()> {
]
);

let hs = init.scatter(&ids, &t, 1)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[0.0, 1.0, 2.0, 1.0, 1.0],
[5.0, 1.0, 1.0, 3.0, 4.0],
[1.0, 8.0, 1.0, 7.0, 1.0],
[10.0, 1.0, 9.0, 1.0, 11.0]
]
);

let init = Tensor::ones((6, 3), DType::F32, device)?;
let hs = init.scatter_add(&ids, &t, 0)?;
assert_eq!(
@@ -1017,6 +1078,56 @@ fn scatter_add(device: &Device) -> Result<()> {
[1.0, 1.0, 1.0]
]
);
let hs = init.scatter(&ids, &t, 0)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[0.0, 10.0, 5.0],
[1.0, 1.0, 8.0],
[9.0, 1.0, 2.0],
[6.0, 7.0, 1.0],
[1.0, 4.0, 11.0],
[1.0, 1.0, 1.0]
]
);

let hs = {
let ids = Tensor::new(
&[
[0u32, u32::MAX, 2],
[3, 4, u32::MAX],
[3, 3, 1],
[u32::MAX, u32::MAX, 4],
],
device,
)?;
init.scatter(&ids, &t, 0)?
};
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[0.0, 1.0, 1.0],
[1.0, 1.0, 8.0],
[1.0, 1.0, 2.0],
[6.0, 7.0, 1.0],
[1.0, 4.0, 11.0],
[1.0, 1.0, 1.0]
]
);

init.scatter_set(&ids, &t, 0)?;
assert_eq!(
init.to_vec2::<f32>()?,
&[
[0.0, 10.0, 5.0],
[1.0, 1.0, 8.0],
[9.0, 1.0, 2.0],
[6.0, 7.0, 1.0],
[1.0, 4.0, 11.0],
[1.0, 1.0, 1.0]
]
);

Ok(())
}

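The expected outputs above rely on a sentinel behavior worth calling out: index values equal to `u32::MAX` are treated as out-of-range and skipped by `scatter`/`scatter_set` (and produce zeros for `gather`/`index_select`). A minimal sketch of the same behavior, assuming a CPU device:

```rust
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let init = Tensor::ones((2, 2), DType::F32, &Device::Cpu)?;
    // The u32::MAX entry is skipped, so only init[0][0] is overwritten.
    let ids = Tensor::new(&[[0u32, u32::MAX]], &Device::Cpu)?;
    let src = Tensor::new(&[[7f32, 8.]], &Device::Cpu)?;
    let hs = init.scatter(&ids, &src, 0)?;
    assert_eq!(hs.to_vec2::<f32>()?, [[7., 1.], [1., 1.]]);
    Ok(())
}
```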
@@ -1050,6 +1161,23 @@ fn gather(device: &Device) -> Result<()> {
let hs = t.gather(&ids, 0)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 7.0, 2.0], [0.0, 4.0, 5.0]]);

let hs = {
let ids = Tensor::new(
&[
[0u32, 0u32],
[2u32, u32::MAX],
[u32::MAX, 1u32],
[0u32, 2u32],
],
device,
)?;
t.gather(&ids, 1)?
};
assert_eq!(
hs.to_vec2::<f32>()?,
&[[0.0, 0.0], [5.0, 0.0], [0.0, 7.0], [9.0, 11.0]]
);

// Random data

// Dim: 0
@@ -1484,6 +1612,7 @@ fn zero_dim(device: &Device) -> Result<()> {
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
test_device!(full, full_cpu, full_gpu, full_metal);
test_device!(const_set, cs_cpu, cs_gpu, cs_metal);
test_device!(arange, arange_cpu, arange_gpu, arange_metal);
test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
@@ -1515,12 +1644,7 @@ test_device!(
);
test_device!(index_add, index_add_cpu, index_add_gpu, index_add_metal);
test_device!(gather, gather_cpu, gather_gpu, gather_metal);
test_device!(
scatter_add,
scatter_add_cpu,
scatter_add_gpu,
scatter_add_metal
);
test_device!(scatter, scatter_cpu, scatter_gpu, scatter_metal);
test_device!(
slice_scatter,
slice_scatter_cpu,
@@ -1733,3 +1857,34 @@ fn test_flip_3d_channels() -> Result<()> {
candle_core::test_utils::assert_tensor_eq(&flipped, &expected)?;
Ok(())
}

#[test]
fn tensor_new() -> Result<()> {
let t1 = Tensor::new(vec![1f32, 2.0, 3.0], &Device::Cpu)?;
assert_eq!(t1.to_vec1::<f32>()?, [1.0, 2.0, 3.0]);
let t2 = Tensor::new(vec![vec![1f32, 2., 3.], vec![4., 5., 6.]], &Device::Cpu)?;
assert_eq!(t2.to_vec2::<f32>()?, [[1., 2., 3.], [4., 5., 6.]]);
let t3 = Tensor::new(
vec![
vec![vec![1f32, 2., 3.], vec![4., 5., 6.]],
vec![vec![3f32, 1., 4.], vec![1., 5., 9.]],
],
&Device::Cpu,
)?;
assert_eq!(
t3.to_vec3::<f32>()?,
[
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
[[3.0, 1.0, 4.0], [1.0, 5.0, 9.0]]
]
);
Ok(())
}

#[test]
fn tensor_norm() -> Result<()> {
let t = Tensor::new(&[[3., 4.], [0., 0.]], &Device::Cpu)?;
let norm = t.norm()?;
assert_eq!(norm.to_scalar::<f64>()?, 5.);
Ok(())
}

@@ -16,10 +16,9 @@ fn read_u32<T: Read>(reader: &mut T) -> std::io::Result<u32> {
fn check_magic_number<T: Read>(reader: &mut T, expected: u32) -> Result<()> {
let magic_number = read_u32(reader)?;
if magic_number != expected {
Err(io::Error::new(
io::ErrorKind::Other,
format!("incorrect magic number {magic_number} != {expected}"),
))?;
Err(io::Error::other(format!(
"incorrect magic number {magic_number} != {expected}"
)))?;
}
Ok(())
}

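The hunk above swaps in `io::Error::other`, a std shorthand (stable since Rust 1.74) for `io::Error::new(io::ErrorKind::Other, ...)`. A tiny illustration of the equivalence:

```rust
use std::io;

fn main() {
    // Same ErrorKind as the longer constructor it replaces.
    let err = io::Error::other("incorrect magic number 0 != 2049");
    assert_eq!(err.kind(), io::ErrorKind::Other);
}
```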
@@ -84,6 +84,10 @@ required-features = ["pyo3"]
name = "onnx"
required-features = ["onnx"]

[[example]]
name = "onnx-llm"
required-features = ["onnx"]

[[example]]
name = "onnx_basics"
required-features = ["onnx"]

@@ -20,8 +20,8 @@ use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::{Encoding, PaddingParams, Tokenizer};

enum TaskType {
Ner(DebertaV2NERModel),
TextClassification(DebertaV2SeqClassificationModel),
Ner(Box<DebertaV2NERModel>),
TextClassification(Box<DebertaV2SeqClassificationModel>),
}

#[derive(Parser, Debug, Clone, ValueEnum)]
@@ -169,21 +169,16 @@ impl Args {

match self.task {
ArgsTask::Ner => Ok((
TaskType::Ner(DebertaV2NERModel::load(
vb,
&config,
Some(id2label.clone()),
)?),
TaskType::Ner(DebertaV2NERModel::load(vb, &config, Some(id2label.clone()))?.into()),
config,
tokenizer,
id2label,
)),
ArgsTask::TextClassification => Ok((
TaskType::TextClassification(DebertaV2SeqClassificationModel::load(
vb,
&config,
Some(id2label.clone()),
)?),
TaskType::TextClassification(
DebertaV2SeqClassificationModel::load(vb, &config, Some(id2label.clone()))?
.into(),
),
config,
tokenizer,
id2label,

@@ -16,8 +16,8 @@ use std::path::PathBuf;
use tokenizers::Tokenizer;

enum ModelType {
Masked(DistilBertForMaskedLM),
UnMasked(DistilBertModel),
Masked(Box<DistilBertForMaskedLM>),
UnMasked(Box<DistilBertModel>),
}

impl ModelType {
@@ -144,10 +144,12 @@ impl Args {

fn create_model(&self, config: &Config, vb: VarBuilder) -> Result<ModelType> {
match self.model {
Which::DistilbertForMaskedLM => {
Ok(ModelType::Masked(DistilBertForMaskedLM::load(vb, config)?))
}
Which::DistilBert => Ok(ModelType::UnMasked(DistilBertModel::load(vb, config)?)),
Which::DistilbertForMaskedLM => Ok(ModelType::Masked(
DistilBertForMaskedLM::load(vb, config)?.into(),
)),
Which::DistilBert => Ok(ModelType::UnMasked(
DistilBertModel::load(vb, config)?.into(),
)),
}
}
}

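Both the DeBERTa and DistilBERT hunks apply the same pattern: boxing large model structs inside enum variants so the enum itself stays pointer-sized rather than as large as its biggest variant (the situation clippy's `large_enum_variant` lint flags). A self-contained sketch of the size effect, assuming a typical 64-bit target:

```rust
#[allow(dead_code)]
struct Big([u8; 1024]);

#[allow(dead_code)]
enum Unboxed {
    Small(u8),
    Large(Big), // inflates every value of the enum
}

#[allow(dead_code)]
enum Boxed {
    Small(u8),
    Large(Box<Big>), // pointer-sized variant
}

fn main() {
    assert!(std::mem::size_of::<Unboxed>() > 1024);
    assert!(std::mem::size_of::<Boxed>() <= 16);
}
```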
@@ -124,6 +124,17 @@ impl TextGeneration {
Some(token) => token,
None => anyhow::bail!("cannot find the <eos> token"),
};

let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
Some(token) => token,
None => {
println!(
"Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup"
);
eos_token
}
};

let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
@@ -146,7 +157,7 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
if next_token == eos_token || next_token == eot_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
@@ -350,6 +361,31 @@ fn main() -> Result<()> {
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;

let prompt = match args.which {
Which::Base2B
| Which::Base7B
| Which::Instruct2B
| Which::Instruct7B
| Which::InstructV1_1_2B
| Which::InstructV1_1_7B
| Which::CodeBase2B
| Which::CodeBase7B
| Which::CodeInstruct2B
| Which::CodeInstruct7B
| Which::BaseV2_2B
| Which::InstructV2_2B
| Which::BaseV2_9B
| Which::InstructV2_9B
| Which::BaseV3_1B => args.prompt,
Which::InstructV3_1B => {
format!(
"<start_of_turn> user\n{}<end_of_turn>\n<start_of_turn> model\n",
args.prompt
)
}
};

pipeline.run(&prompt, args.sample_len)?;
Ok(())
}

@@ -7,7 +7,10 @@ extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;

use candle_transformers::models::helium::{Config, Model};
use candle_transformers::models::helium::{Config as ConfigPreview, Model as ModelPreview};
use candle_transformers::models::llama::{
Cache as CacheV1, Llama as ModelV1, LlamaConfig as ConfigV1, LlamaEosToks,
};

use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
@@ -16,6 +19,44 @@ use candle_transformers::generation::{LogitsProcessor, Sampling};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

#[derive(Debug, Clone)]
enum Model {
V1 { model: ModelV1, cache: CacheV1 },
Preview(ModelPreview),
}

impl Model {
fn forward(&mut self, input: &Tensor, start_pos: usize) -> Result<Tensor> {
let model = match self {
Model::V1 { model, cache } => model.forward(input, start_pos, cache)?,
Model::Preview(m) => m.forward(input, start_pos)?,
};
Ok(model)
}
}

#[derive(Debug, Clone)]
enum Config {
V1(ConfigV1),
Preview(ConfigPreview),
}

impl Config {
fn bos_token_id(&self) -> Option<u32> {
match self {
Config::V1(c) => c.bos_token_id,
Config::Preview(c) => Some(c.bos_token_id),
}
}

fn eos_token_id(&self) -> Option<LlamaEosToks> {
match self {
Config::V1(c) => c.eos_token_id.clone(),
Config::Preview(c) => Some(LlamaEosToks::Single(c.eos_token_id)),
}
}
}

struct TextGeneration {
model: Model,
device: Device,
@@ -106,7 +147,15 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == self.config.bos_token_id || next_token == self.config.eos_token_id {
let is_eos = self
.config
.eos_token_id()
.as_ref()
.is_some_and(|v| match v {
LlamaEosToks::Single(eos) => *eos == next_token,
LlamaEosToks::Multiple(eos) => eos.contains(&next_token),
});
if Some(next_token) == self.config.bos_token_id() || is_eos {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
@@ -131,6 +180,8 @@ impl TextGeneration {
enum Which {
#[value(name = "v1-preview")]
V1Preview,
#[value(name = "v1")]
V1,
}

#[derive(Parser, Debug)]
@@ -144,9 +195,6 @@ struct Args {
#[arg(long)]
tracing: bool,

#[arg(long)]
use_flash_attn: bool,

#[arg(long)]
prompt: String,

@@ -171,7 +219,7 @@ struct Args {
sample_len: usize,

/// The model size to use.
#[arg(long, default_value = "v1-preview")]
#[arg(long, default_value = "v1")]
which: Which,

#[arg(long)]
@@ -230,6 +278,7 @@ fn main() -> Result<()> {
None => {
let name = match args.which {
Which::V1Preview => "kyutai/helium-1-preview-2b",
Which::V1 => "kyutai/helium-1-2b",
};
name.to_string()
}
@@ -254,18 +303,27 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

let start = std::time::Instant::now();
let config: Config = match args.config {
Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?,
None => {
let config_file = repo.get("config.json")?;
serde_json::from_slice(&std::fs::read(config_file)?)?
}
let config_file = match args.config {
Some(config_file) => std::path::PathBuf::from(config_file),
None => repo.get("config.json")?,
};
let config = match args.which {
Which::V1Preview => Config::Preview(serde_json::from_slice(&std::fs::read(config_file)?)?),
Which::V1 => Config::V1(serde_json::from_slice(&std::fs::read(config_file)?)?),
};
let device = candle_examples::device(args.cpu)?;
let (model, device) = {
let dtype = device.bf16_default_to_f32();
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
let model = match &config {
Config::V1(c) => {
let c = c.clone().into_config(false);
let model = ModelV1::load(vb, &c)?;
let cache = CacheV1::new(true, dtype, &c, &device)?;
Model::V1 { model, cache }
}
Config::Preview(c) => Model::Preview(ModelPreview::new(c, vb)?),
};
(model, device)
};

@@ -3,7 +3,7 @@
OLMo is a series of Open Language Models designed to enable the science of language models.

- **Project Page:** https://allenai.org/olmo
- **Paper:** [Link](https://arxiv.org/abs/2402.00838)
- **Papers:** [OLMo](https://arxiv.org/abs/2402.00838) [OLMo 2](https://arxiv.org/abs/2501.00656)
- **Technical blog post:** https://blog.allenai.org/olmo-open-language-model-87ccfc95f580
- **W&B Logs:** https://wandb.ai/ai2-llm/OLMo-1B/reports/OLMo-1B--Vmlldzo2NzY1Njk1
<!-- - **Press release:** TODO -->

@@ -8,6 +8,7 @@ use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};

use candle_transformers::models::olmo::{Config, Model as OLMo};
use candle_transformers::models::olmo2::{Config as Config2, Model as OLMo2};

use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
@@ -18,6 +19,7 @@ use tokenizers::Tokenizer;

enum Model {
OLMo(OLMo),
OLMo2(OLMo2),
}

struct TextGeneration {
@@ -82,6 +84,7 @@ impl TextGeneration {
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = match &mut self.model {
Model::OLMo(m) => m.forward(&input, start_pos)?,
Model::OLMo2(m) => m.forward(&input, start_pos)?,
};
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
@@ -129,6 +132,8 @@ enum Which {
W7bTwin2T,
#[value(name = "1.7-7b")]
V1_7W7b,
#[value(name = "2-1b")]
V2W1b,
}

#[derive(Parser, Debug)]
@@ -220,6 +225,7 @@ fn main() -> Result<()> {
Which::W7b => "allenai/OLMo-7B-hf".to_string(),
Which::W7bTwin2T => "allenai/OLMo-7B-Twin-2T-hf".to_string(),
Which::V1_7W7b => "allenai/OLMo-1.7-7B-hf".to_string(),
Which::V2W1b => "allenai/OLMo-2-0425-1B-Instruct".to_string(),
},
};

@@ -238,33 +244,36 @@ fn main() -> Result<()> {
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => match args.model {
Which::W1b => {
Which::W1b | Which::V2W1b => {
vec![repo.get("model.safetensors")?]
}
_ => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
},
};

let config_filename = repo.get("config.json")?;
println!("retrieved the files in {:?}", start.elapsed());

let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

let start = std::time::Instant::now();
let config = {
let config_filename = repo.get("config.json")?;
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
config
};

let device = candle_examples::device(args.cpu)?;
let model = {
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = OLMo::new(&config, vb)?;
Model::OLMo(model)
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = match args.model {
Which::W1b | Which::W7b | Which::W7bTwin2T | Which::V1_7W7b => {
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let model = OLMo::new(&config, vb)?;
Model::OLMo(model)
}
Which::V2W1b => {
let config: Config2 = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let model = OLMo2::new(&config, vb)?;
Model::OLMo2(model)
}
};

println!("loaded the model in {:?}", start.elapsed());

candle-examples/examples/onnx-llm/README.md (new file, 11 lines)
@@ -0,0 +1,11 @@
## Using ONNX models in Candle

This example demonstrates how to run [ONNX](https://github.com/onnx/onnx)-based LLMs in Candle.

This script only implements SmolLM-135M right now.

You can run the example with the following command:

```bash
cargo run --example onnx-llm --features onnx
```

candle-examples/examples/onnx-llm/main.rs (new file, 209 lines)
@@ -0,0 +1,209 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Result;
use candle::{DType, Tensor};
use candle_transformers::generation::{LogitsProcessor, Sampling};
use clap::{Parser, ValueEnum};
use hf_hub::api::sync::Api;
use serde::Deserialize;
use std::io::Write;
use tokenizers::Tokenizer;

#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
pub num_hidden_layers: usize,
pub num_key_value_heads: usize,
pub hidden_size: usize,
pub num_attention_heads: usize,
}

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
SmolLM135M,
}

#[derive(Parser)]
struct Args {
/// The prompt to be used.
#[arg(long, default_value = "My favorite theorem is ")]
prompt: String,

/// The model to be used.
#[arg(value_enum, long, default_value_t = Which::SmolLM135M)]
which: Which,

/// Run on CPU rather than GPU.
#[arg(long)]
cpu: bool,

/// The number of tokens to generate.
#[arg(long, default_value_t = 100)]
max_tokens: usize,

/// The temperature used for sampling.
#[arg(long, default_value_t = 0.8)]
temperature: f32,

/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,

/// Only sample among the top K samples.
#[arg(long)]
top_k: Option<usize>,

/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
}

pub fn main() -> Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;

let (model_id, tokenizer_id) = match args.which {
Which::SmolLM135M => ("HuggingFaceTB/SmolLM-135M", "HuggingFaceTB/SmolLM-135M"),
};

let api = Api::new()?;
let model_repo = api.model(model_id.to_string());
let tokenizer_repo = api.model(tokenizer_id.to_string());

let model_path = model_repo.get("onnx/model.onnx")?;
let config_file = model_repo.get("config.json")?;
let config: Config = serde_json::from_reader(std::fs::File::open(config_file)?)?;

let tokenizer_path = tokenizer_repo.get("tokenizer.json")?;
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)?;

let tokens_u32 = tokenizer
.encode(args.prompt.as_str(), true)
.map_err(anyhow::Error::msg)?
.get_ids()
.to_vec();

let tokens: Vec<i64> = tokens_u32.iter().map(|&t| t as i64).collect();

println!("Loading ONNX model from {:?}", model_path);
let model = candle_onnx::read_file(model_path)?;

let mut generated_tokens = tokens.clone();
print!("{}", args.prompt);
std::io::stdout().flush()?;

let mut logits_processor = {
let temperature = args.temperature as f64;
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (args.top_k, args.top_p) {
(None, None) => Sampling::All { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(args.seed, sampling)
};

let mut past_key_values: Option<Vec<(Tensor, Tensor)>> = None;
let num_layers = config.num_hidden_layers;

for _ in 0..args.max_tokens {
let mut inputs = std::collections::HashMap::new();

if let Some(past_kv) = &past_key_values {
let last_token = vec![generated_tokens[generated_tokens.len() - 1]];
let input_tensor = Tensor::new(last_token, &device)?.unsqueeze(0)?;
inputs.insert("input_ids".to_string(), input_tensor);

let seq_len = generated_tokens.len();
let attention_mask = vec![vec![1i64; seq_len]];
let attention_mask_tensor = Tensor::new(attention_mask, &device)?;
inputs.insert("attention_mask".to_string(), attention_mask_tensor);

let position_ids = vec![vec![(seq_len - 1) as i64]];
let position_ids_tensor = Tensor::new(position_ids, &device)?;
inputs.insert("position_ids".to_string(), position_ids_tensor);

for (i, (key, value)) in past_kv.iter().enumerate() {
inputs.insert(format!("past_key_values.{}.key", i), key.clone());
inputs.insert(format!("past_key_values.{}.value", i), value.clone());
}
} else {
let input_tensor = Tensor::new(generated_tokens.clone(), &device)?.unsqueeze(0)?;
inputs.insert("input_ids".to_string(), input_tensor);

let seq_len = generated_tokens.len();
let attention_mask = vec![vec![1i64; seq_len]];
let attention_mask_tensor = Tensor::new(attention_mask, &device)?;
inputs.insert("attention_mask".to_string(), attention_mask_tensor);

let position_ids: Vec<i64> = (0..seq_len as i64).collect();
let position_ids_tensor = Tensor::new(position_ids, &device)?.unsqueeze(0)?;
inputs.insert("position_ids".to_string(), position_ids_tensor);

// Create empty key and value tensors
for i in 0..num_layers {
let batch_size = 1;
let num_heads = config.num_key_value_heads;
let head_dim = config.hidden_size / config.num_attention_heads;
let seq_len = 0;

let empty_key = Tensor::zeros(
&[batch_size, num_heads, seq_len, head_dim],
DType::F32,
&device,
)?;
let empty_value = Tensor::zeros(
&[batch_size, num_heads, seq_len, head_dim],
DType::F32,
&device,
)?;

inputs.insert(format!("past_key_values.{}.key", i), empty_key);
inputs.insert(format!("past_key_values.{}.value", i), empty_value);
}
}

let outputs = candle_onnx::simple_eval(&model, inputs)?;

let logits = outputs.get("logits").unwrap();

let mut new_past_kv = Vec::with_capacity(num_layers);
for i in 0..num_layers {
let key = outputs
.get(&format!("present.{}.key", i))
.ok_or_else(|| anyhow::anyhow!("Missing present.{}.key", i))?;
let value = outputs
.get(&format!("present.{}.value", i))
.ok_or_else(|| anyhow::anyhow!("Missing present.{}.value", i))?;
new_past_kv.push((key.clone(), value.clone()));
}
past_key_values = Some(new_past_kv);

let logits_dim = logits.dims();
let seq_len = logits_dim[1];

let next_token_id = logits_processor.sample(&logits.get(0)?.get(seq_len - 1)?)?;
generated_tokens.push(next_token_id as i64);

if let Some(token_str) = tokenizer.decode(&[next_token_id], true).ok() {
print!("{}", token_str);
std::io::stdout().flush()?;
}

if let Some(eos_id) = tokenizer.token_to_id("<|endoftext|>") {
if next_token_id == eos_id {
break;
}
}
}

println!("\nGeneration complete!");
Ok(())
}

@@ -5,12 +5,14 @@ extern crate intel_mkl_src;
extern crate accelerate_src;

use candle::{IndexOp, D};
use candle_examples::save_image;
use clap::{Parser, ValueEnum};

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
SqueezeNet,
EfficientNet,
EsrGan,
}

#[derive(Parser)]
@@ -28,10 +30,21 @@ struct Args {

pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let image = candle_examples::imagenet::load_image224(args.image)?;
let image = match args.which {
Which::SqueezeNet | Which::EfficientNet => {
candle_examples::imagenet::load_image224(&args.image)?
}
Which::EsrGan => candle_examples::imagenet::load_image_with_std_mean(
&args.image,
128,
&[0.0f32, 0.0, 0.0],
&[1.0f32, 1.0, 1.0],
)?,
};
let image = match args.which {
Which::SqueezeNet => image,
Which::EfficientNet => image.permute((1, 2, 0))?,
Which::EsrGan => image,
};

println!("loaded image {image:?}");
@@ -45,6 +58,9 @@ pub fn main() -> anyhow::Result<()> {
Which::EfficientNet => hf_hub::api::sync::Api::new()?
.model("onnx/EfficientNet-Lite4".into())
.get("efficientnet-lite4-11.onnx")?,
Which::EsrGan => hf_hub::api::sync::Api::new()?
.model("qualcomm/Real-ESRGAN-x4plus".into())
.get("Real-ESRGAN-x4plus.onnx")?,
},
};

@@ -57,21 +73,40 @@ pub fn main() -> anyhow::Result<()> {
let prs = match args.which {
Which::SqueezeNet => candle_nn::ops::softmax(&output, D::Minus1)?,
Which::EfficientNet => output,
Which::EsrGan => output,
};
let prs = prs.i(0)?.to_vec1::<f32>()?;

// Sort the predictions and take the top 5
let mut top: Vec<_> = prs.iter().enumerate().collect();
top.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
let top = top.into_iter().take(5).collect::<Vec<_>>();
match args.which {
Which::EfficientNet | Which::SqueezeNet => {
let prs = prs.i(0)?.to_vec1::<f32>()?;

// Print the top predictions
for &(i, p) in &top {
println!(
"{:50}: {:.2}%",
candle_examples::imagenet::CLASSES[i],
p * 100.0
);
// Sort the predictions and take the top 5
let mut top: Vec<_> = prs.iter().enumerate().collect();
top.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
let top = top.into_iter().take(5).collect::<Vec<_>>();

// Print the top predictions
for &(i, p) in &top {
println!(
"{:50}: {:.2}%",
candle_examples::imagenet::CLASSES[i],
p * 100.0
);
}
}
Which::EsrGan => {
let max_pixel_val = candle::Tensor::try_from(255.0f32)?
.to_device(prs.device())?
.broadcast_as(prs.shape())?;
let out = (prs * max_pixel_val)?.i(0)?.to_dtype(candle::DType::U8)?;

let pb = std::path::PathBuf::from(args.image);
let input_file_name = pb.file_name().unwrap();
let mut output_file_name = std::ffi::OsString::from("super_");
output_file_name.push(input_file_name);

save_image(&out, output_file_name)?;
}
}

Ok(())

@@ -147,9 +147,9 @@ enum WhichModel {
V3,
#[value(name = "3-medium")]
V3Medium,
#[value(name = "2-old")]
V4Mini,
#[value(name = "4-mini")]
V4Mini,
#[value(name = "2-old")]
V2Old,
PuffinPhiV2,
PhiHermes,

candle-examples/examples/quantized-gemma/README.md (new file, 18 lines)
@@ -0,0 +1,18 @@
# candle-quantized-gemma

Candle implementation of quantized Gemma.

## Running an example

```bash
$ cargo run --example quantized-gemma -- --prompt "Write a function to calculate fibonacci numbers. "

> ```python
> def fibonacci(n):
>     """Calculates the nth Fibonacci number using recursion."""
>     if n <= 1:
>         return n
>     else:
>         return fibonacci(n-1) + fibonacci(n-2
> ```
```
candle-examples/examples/quantized-gemma/main.rs (new file, 344 lines)
@@ -0,0 +1,344 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::{Parser, ValueEnum};
use std::io::Write;
use tokenizers::Tokenizer;

use candle::quantized::gguf_file;
use candle::Tensor;
use candle_transformers::generation::{LogitsProcessor, Sampling};

use candle_examples::token_output_stream::TokenOutputStream;
use candle_transformers::models::quantized_gemma3::ModelWeights;

const DEFAULT_PROMPT: &str = "Write a function to calculate fibonacci num";

#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
#[value(name = "gemma3-4b-it")]
Gemma3_4bIt,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// GGUF file to load, typically a .gguf file generated by quantization
#[arg(long)]
model: Option<String>,

/// The initial prompt, use 'interactive' for entering multiple prompts in an interactive way
/// and 'chat' for an interactive model where history of previous prompts and generated tokens
/// is preserved.
#[arg(long)]
prompt: Option<String>,

/// The length of the sample to generate (in tokens).
#[arg(short = 'n', long, default_value_t = 1000)]
sample_len: usize,

/// The tokenizer config in json format.
#[arg(long)]
tokenizer: Option<String>,

/// The temperature used to generate samples, use 0 for greedy sampling.
#[arg(long, default_value_t = 0.8)]
temperature: f64,

/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,

/// Only sample among the top K samples.
#[arg(long)]
top_k: Option<usize>,

/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,

/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,

/// Process prompt elements separately.
#[arg(long)]
split_prompt: bool,

/// Run on CPU rather than GPU even if a GPU is available.
#[arg(long)]
cpu: bool,

/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,

/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,

/// The model size to use.
#[arg(long, default_value = "gemma3-4b-it")]
which: Which,
}

impl Args {
fn tokenizer(&self) -> anyhow::Result<Tokenizer> {
let tokenizer_path = match &self.tokenizer {
Some(config) => std::path::PathBuf::from(config),
None => {
let api = hf_hub::api::sync::Api::new()?;
let repo = "google/gemma-3-4b-it";
println!("DEBUG: Downloading tokenizer from {}", repo);
let api = api.model(repo.to_string());
api.get("tokenizer.json")?
}
};
println!("DEBUG: Loading tokenizer from {:?}", tokenizer_path);
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)?;

Ok(tokenizer)
}

fn model(&self) -> anyhow::Result<std::path::PathBuf> {
let model_path = match &self.model {
Some(config) => std::path::PathBuf::from(config),
None => {
let (repo, filename) = match self.which {
|
||||
Which::Gemma3_4bIt => (
|
||||
"google/gemma-3-4b-it-qat-q4_0-gguf",
|
||||
"gemma-3-4b-it-q4_0.gguf",
|
||||
),
|
||||
};
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
api.repo(hf_hub::Repo::with_revision(
|
||||
repo.to_string(),
|
||||
hf_hub::RepoType::Model,
|
||||
"main".to_string(),
|
||||
))
|
||||
.get(filename)?
|
||||
}
|
||||
};
|
||||
Ok(model_path)
|
||||
}
|
||||
}
|
||||
|
||||
fn format_size(size_in_bytes: usize) -> String {
|
||||
if size_in_bytes < 1_000 {
|
||||
format!("{}B", size_in_bytes)
|
||||
} else if size_in_bytes < 1_000_000 {
|
||||
format!("{:.2}KB", size_in_bytes as f64 / 1e3)
|
||||
} else if size_in_bytes < 1_000_000_000 {
|
||||
format!("{:.2}MB", size_in_bytes as f64 / 1e6)
|
||||
} else {
|
||||
format!("{:.2}GB", size_in_bytes as f64 / 1e9)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Prompt {
|
||||
Interactive,
|
||||
Chat,
|
||||
One(String),
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
println!(
|
||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||
args.temperature, args.repeat_penalty, args.repeat_last_n
|
||||
);
|
||||
|
||||
let model_path = args.model()?;
|
||||
let mut file = std::fs::File::open(&model_path)?;
|
||||
let start = std::time::Instant::now();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let mut model = {
|
||||
let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(&model_path))?;
|
||||
let mut total_size_in_bytes = 0;
|
||||
for (_, tensor) in model.tensor_infos.iter() {
|
||||
let elem_count = tensor.shape.elem_count();
|
||||
total_size_in_bytes +=
|
||||
elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
|
||||
}
|
||||
println!(
|
||||
"loaded {:?} tensors ({}) in {:.2}s",
|
||||
model.tensor_infos.len(),
|
||||
&format_size(total_size_in_bytes),
|
||||
start.elapsed().as_secs_f32(),
|
||||
);
|
||||
ModelWeights::from_gguf(model, &mut file, &device)?
|
||||
};
|
||||
println!("model built");
|
||||
|
||||
let tokenizer = args.tokenizer()?;
|
||||
|
||||
let mut tos = TokenOutputStream::new(tokenizer);
|
||||
println!(
|
||||
"DEBUG: Tokenizer vocabulary size: {}",
|
||||
tos.tokenizer().get_vocab(true).len()
|
||||
);
|
||||
|
||||
let prompt = match args.prompt.as_deref() {
|
||||
Some("chat") => Prompt::Chat,
|
||||
Some("interactive") => Prompt::Interactive,
|
||||
Some(s) => Prompt::One(s.to_string()),
|
||||
None => Prompt::One(DEFAULT_PROMPT.to_string()),
|
||||
};
|
||||
|
||||
let mut pre_prompt_tokens = vec![];
|
||||
for _ in 0.. {
|
||||
let prompt_str = match &prompt {
|
||||
Prompt::One(prompt) => prompt.clone(),
|
||||
Prompt::Interactive | Prompt::Chat => {
|
||||
print!("> ");
|
||||
std::io::stdout().flush()?;
|
||||
let mut prompt = String::new();
|
||||
std::io::stdin().read_line(&mut prompt)?;
|
||||
if prompt.ends_with('\n') {
|
||||
prompt.pop();
|
||||
if prompt.ends_with('\r') {
|
||||
prompt.pop();
|
||||
}
|
||||
}
|
||||
// Format for Gemma 3 chat/instruction format
|
||||
format!("<start_of_turn> user\n{prompt}<end_of_turn>\n<start_of_turn> model\n")
|
||||
}
|
||||
};
|
||||
print!("{}", &prompt_str);
|
||||
|
||||
let tokens = tos
|
||||
.tokenizer()
|
||||
.encode(prompt_str, true)
|
||||
.map_err(anyhow::Error::msg)?;
|
||||
let prompt_tokens = [&pre_prompt_tokens, tokens.get_ids()].concat();
|
||||
|
||||
let to_sample = args.sample_len.saturating_sub(1);
|
||||
let max_seq_len = 8192; // Gemma 3 context length
|
||||
let prompt_tokens = if prompt_tokens.len() + to_sample > max_seq_len - 10 {
|
||||
let to_remove = prompt_tokens.len() + to_sample + 10 - max_seq_len;
|
||||
prompt_tokens[prompt_tokens.len().saturating_sub(to_remove)..].to_vec()
|
||||
} else {
|
||||
prompt_tokens
|
||||
};
|
||||
let mut all_tokens = vec![];
|
||||
let mut logits_processor = {
|
||||
let temperature = args.temperature;
|
||||
let sampling = if temperature <= 0. {
|
||||
Sampling::ArgMax
|
||||
} else {
|
||||
match (args.top_k, args.top_p) {
|
||||
(None, None) => Sampling::All { temperature },
|
||||
(Some(k), None) => Sampling::TopK { k, temperature },
|
||||
(None, Some(p)) => Sampling::TopP { p, temperature },
|
||||
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
|
||||
}
|
||||
};
|
||||
LogitsProcessor::from_sampling(args.seed, sampling)
|
||||
};
|
||||
|
||||
let start_prompt_processing = std::time::Instant::now();
|
||||
let mut next_token = if !args.split_prompt {
|
||||
let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, 0)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
logits_processor.sample(&logits)?
|
||||
} else {
|
||||
let mut next_token = 0;
|
||||
for (pos, token) in prompt_tokens.iter().enumerate() {
|
||||
let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, pos)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
next_token = logits_processor.sample(&logits)?
|
||||
}
|
||||
next_token
|
||||
};
|
||||
let prompt_dt = start_prompt_processing.elapsed();
|
||||
all_tokens.push(next_token);
|
||||
if let Some(t) = tos.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
|
||||
// For Gemma 3, use the correct end of sequence token
|
||||
let eos_token = *tos
|
||||
.tokenizer()
|
||||
.get_vocab(true)
|
||||
.get("<end_of_turn>")
|
||||
.unwrap();
|
||||
|
||||
let start_post_prompt = std::time::Instant::now();
|
||||
let mut sampled = 0;
|
||||
for index in 0..to_sample {
|
||||
let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, prompt_tokens.len() + index)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
let logits = if args.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
args.repeat_penalty,
|
||||
&all_tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
next_token = logits_processor.sample(&logits)?;
|
||||
all_tokens.push(next_token);
|
||||
if let Some(t) = tos.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
sampled += 1;
|
||||
if next_token == eos_token {
|
||||
break;
|
||||
};
|
||||
}
|
||||
if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
|
||||
print!("{rest}");
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
let dt = start_post_prompt.elapsed();
|
||||
println!(
|
||||
"\n\n{:4} prompt tokens processed: {:.2} token/s",
|
||||
prompt_tokens.len(),
|
||||
prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
|
||||
);
|
||||
println!(
|
||||
"{sampled:4} tokens generated: {:.2} token/s",
|
||||
sampled as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
|
||||
match prompt {
|
||||
Prompt::One(_) => break,
|
||||
Prompt::Interactive => {}
|
||||
Prompt::Chat => {
|
||||
pre_prompt_tokens = [prompt_tokens.as_slice(), all_tokens.as_slice()].concat()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
@ -8,4 +8,8 @@
cargo run --example quantized-qwen2-instruct --release -- --prompt "Write a function to count prime numbers up to N."
```

0.5b, 1.5b, 7b and 72b models are available via `--model` argument.
0.5b, 1.5b, 7b and 72b models are available via `--which` argument.

```bash
cargo run --release --example quantized-qwen2-instruct -- --which 0.5b --prompt "Write a function to count prime numbers up to N."
```
17
candle-examples/examples/quantized-qwen3/README.md
Normal file
@ -0,0 +1,17 @@
# candle-quantized-qwen3

[Qwen3](https://qwenlm.github.io/blog/qwen3/) is an upgraded version of Qwen2.5, released by Alibaba Cloud.

## Running the example

```bash
cargo run --example quantized-qwen3 --release -- --prompt "Write a function to count prime numbers up to N."
```

0.6b is used by default; 1.7b, 4b, 8b, 14b, and 32b models are available via the `--which` argument.

```bash
cargo run --example quantized-qwen3 --release -- --which 4b --prompt "A train is travelling at 120mph, how far does it travel in 3 minutes 30 seconds?"
```
314
candle-examples/examples/quantized-qwen3/main.rs
Normal file
@ -0,0 +1,314 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::{Parser, ValueEnum};
use std::io::Write;
use tokenizers::Tokenizer;

use candle::quantized::gguf_file;
use candle::Tensor;
use candle_transformers::generation::{LogitsProcessor, Sampling};

use candle_examples::token_output_stream::TokenOutputStream;
use candle_transformers::models::quantized_qwen3::ModelWeights as Qwen3;

const DEFAULT_PROMPT: &str = "Write a Rust function to calculate the factorial of a given number.";

#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
    #[value(name = "0.6b")]
    W3_0_6b,
    #[value(name = "1.7b")]
    W3_1_7b,
    #[value(name = "4b")]
    W3_4b,
    #[value(name = "8b")]
    W3_8b,
    #[value(name = "14b")]
    W3_14b,
    #[value(name = "32b")]
    W3_32b,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// GGUF file to load, typically a .gguf file generated by the quantize command from llama.cpp
    #[arg(long)]
    model: Option<String>,

    /// The initial prompt, use 'interactive' for entering multiple prompts in an interactive way
    /// and 'chat' for an interactive mode where the history of previous prompts and generated
    /// tokens is preserved.
    #[arg(long)]
    prompt: Option<String>,

    /// The length of the sample to generate (in tokens).
    #[arg(short = 'n', long, default_value_t = 1000)]
    sample_len: usize,

    /// The tokenizer config in json format.
    #[arg(long)]
    tokenizer: Option<String>,

    /// The temperature used to generate samples, use 0 for greedy sampling.
    #[arg(long, default_value_t = 0.8)]
    temperature: f64,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// Only sample among the top K samples.
    #[arg(long)]
    top_k: Option<usize>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// Process prompt elements separately.
    #[arg(long)]
    split_prompt: bool,

    /// Run on CPU rather than GPU even if a GPU is available.
    #[arg(long)]
    cpu: bool,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,

    /// The model size to use.
    #[arg(long, default_value = "0.6b")]
    which: Which,
}

impl Args {
    fn tokenizer(&self) -> anyhow::Result<Tokenizer> {
        let tokenizer_path = match &self.tokenizer {
            Some(config) => std::path::PathBuf::from(config),
            None => {
                let api = hf_hub::api::sync::Api::new()?;
                let repo = match self.which {
                    Which::W3_0_6b => "Qwen/Qwen3-0.6B",
                    Which::W3_1_7b => "Qwen/Qwen3-1.7B",
                    Which::W3_4b => "Qwen/Qwen3-4B",
                    Which::W3_8b => "Qwen/Qwen3-8B",
                    Which::W3_14b => "Qwen/Qwen3-14B",
                    Which::W3_32b => "Qwen/Qwen3-32B",
                };
                let api = api.model(repo.to_string());
                api.get("tokenizer.json")?
            }
        };
        Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)
    }

    fn model(&self) -> anyhow::Result<std::path::PathBuf> {
        let model_path = match &self.model {
            Some(config) => std::path::PathBuf::from(config),
            None => {
                let (repo, filename, revision) = match self.which {
                    Which::W3_0_6b => ("unsloth/Qwen3-0.6B-GGUF", "Qwen3-0.6B-Q4_K_M.gguf", "main"),
                    Which::W3_1_7b => ("unsloth/Qwen3-1.7B-GGUF", "Qwen3-1.7B-Q4_K_M.gguf", "main"),
                    Which::W3_4b => ("unsloth/Qwen3-4B-GGUF", "Qwen3-4B-Q4_K_M.gguf", "main"),
                    Which::W3_8b => ("unsloth/Qwen3-8B-GGUF", "Qwen3-8B-Q4_K_M.gguf", "main"),
                    Which::W3_14b => ("unsloth/Qwen3-14B-GGUF", "Qwen3-14B-Q4_K_M.gguf", "main"),
                    Which::W3_32b => ("unsloth/Qwen3-32B-GGUF", "Qwen3-32B-Q4_K_M.gguf", "main"),
                };
                let api = hf_hub::api::sync::Api::new()?;
                api.repo(hf_hub::Repo::with_revision(
                    repo.to_string(),
                    hf_hub::RepoType::Model,
                    revision.to_string(),
                ))
                .get(filename)?
            }
        };
        Ok(model_path)
    }
}

fn format_size(size_in_bytes: usize) -> String {
    if size_in_bytes < 1_000 {
        format!("{}B", size_in_bytes)
    } else if size_in_bytes < 1_000_000 {
        format!("{:.2}KB", size_in_bytes as f64 / 1e3)
    } else if size_in_bytes < 1_000_000_000 {
        format!("{:.2}MB", size_in_bytes as f64 / 1e6)
    } else {
        format!("{:.2}GB", size_in_bytes as f64 / 1e9)
    }
}

fn main() -> anyhow::Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature, args.repeat_penalty, args.repeat_last_n
    );

    let model_path = args.model()?;
    let mut file = std::fs::File::open(&model_path)?;
    let start = std::time::Instant::now();
    let device = candle_examples::device(args.cpu)?;

    let mut model = {
        let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
        let mut total_size_in_bytes = 0;
        for (_, tensor) in model.tensor_infos.iter() {
            let elem_count = tensor.shape.elem_count();
            total_size_in_bytes +=
                elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
        }
        println!(
            "loaded {:?} tensors ({}) in {:.2}s",
            model.tensor_infos.len(),
            &format_size(total_size_in_bytes),
            start.elapsed().as_secs_f32(),
        );
        Qwen3::from_gguf(model, &mut file, &device)?
    };
    println!("model built");

    let tokenizer = args.tokenizer()?;
    let mut tos = TokenOutputStream::new(tokenizer);
    let prompt_str = args
        .prompt
        .clone()
        .unwrap_or_else(|| DEFAULT_PROMPT.to_string());

    let prompt_str = format!("<|im_start|>user\n{prompt_str}<|im_end|>\n<|im_start|>assistant\n");
    print!("formatted prompt: {}", &prompt_str);

    let tokens = tos
        .tokenizer()
        .encode(prompt_str, true)
        .map_err(anyhow::Error::msg)?;

    let tokens = tokens.get_ids();

    let to_sample = args.sample_len.saturating_sub(1);

    let mut all_tokens = vec![];

    let mut logits_processor = {
        let temperature = args.temperature;
        let sampling = if temperature <= 0. {
            Sampling::ArgMax
        } else {
            match (args.top_k, args.top_p) {
                (None, None) => Sampling::All { temperature },
                (Some(k), None) => Sampling::TopK { k, temperature },
                (None, Some(p)) => Sampling::TopP { p, temperature },
                (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
            }
        };
        LogitsProcessor::from_sampling(args.seed, sampling)
    };

    let start_prompt_processing = std::time::Instant::now();

    let mut next_token = if !args.split_prompt {
        let input = Tensor::new(tokens, &device)?.unsqueeze(0)?;
        let logits = model.forward(&input, 0)?;
        let logits = logits.squeeze(0)?;
        logits_processor.sample(&logits)?
    } else {
        let mut next_token = 0;
        for (pos, token) in tokens.iter().enumerate() {
            let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
            let logits = model.forward(&input, pos)?;
            let logits = logits.squeeze(0)?;
            next_token = logits_processor.sample(&logits)?
        }
        next_token
    };

    let prompt_dt = start_prompt_processing.elapsed();

    all_tokens.push(next_token);

    if let Some(t) = tos.next_token(next_token)? {
        print!("{t}");
        std::io::stdout().flush()?;
    }

    let eos_token = *tos.tokenizer().get_vocab(true).get("<|im_end|>").unwrap();

    let start_post_prompt = std::time::Instant::now();

    let mut sampled = 0;
    for index in 0..to_sample {
        let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
        let logits = model.forward(&input, tokens.len() + index)?;
        let logits = logits.squeeze(0)?;
        let logits = if args.repeat_penalty == 1. {
            logits
        } else {
            let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
            candle_transformers::utils::apply_repeat_penalty(
                &logits,
                args.repeat_penalty,
                &all_tokens[start_at..],
            )?
        };
        next_token = logits_processor.sample(&logits)?;
        all_tokens.push(next_token);
        if let Some(t) = tos.next_token(next_token)? {
            print!("{t}");
            std::io::stdout().flush()?;
        }
        sampled += 1;
        if next_token == eos_token {
            break;
        };
    }

    if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
        print!("{rest}");
    }

    std::io::stdout().flush()?;
    let dt = start_post_prompt.elapsed();
    println!(
        "\n\n{:4} prompt tokens processed: {:.2} token/s",
        tokens.len(),
        tokens.len() as f64 / prompt_dt.as_secs_f64(),
    );
    println!(
        "{sampled:4} tokens generated: {:.2} token/s",
        sampled as f64 / dt.as_secs_f64(),
    );
    Ok(())
}
@ -25,3 +25,28 @@ def print_prime(n: int): # n is the number of primes to be printed
    print(i)
```

The qwen3 MoE variant is also an option.

```bash
$ cargo run --example qwen --features metal --release -- --prompt "Write a poem about butterflies. <think></think>." --model "3-moe-a3b"
> In morning's hush, where daisies sleep,
> A fleeting dance through sunlit deep—
> They flutter soft on gossamer thread,
> The messengers of spring’s own head.
>
> With painted sails and delicate grace,
> They drift from bloom to blossom's face.
> Each wing a tale in hues unseen,
> Of ancient dreams and secrets between.
>
> No sound they make, yet still they speak—
> Of time that flies, of life so brief.
> A fleeting kiss on summer’s breath,
> A whisper lost before death.
>
> Yet in their flight, the soul takes wing,
> And for a moment, all is spring.
> For though they fade, they never die—
> Their beauty lives where hearts can fly.
> 161 tokens generated (3.00 token/s)
```
@ -9,6 +9,8 @@ use clap::Parser;

use candle_transformers::models::qwen2::{Config as ConfigBase, ModelForCausalLM as ModelBase};
use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe};
use candle_transformers::models::qwen3::{Config as Config3, ModelForCausalLM as Model3};
use candle_transformers::models::qwen3_moe::{Config as ConfigMoe3, ModelForCausalLM as ModelMoe3};

use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
@ -20,6 +22,8 @@ use tokenizers::Tokenizer;
enum Model {
    Base(ModelBase),
    Moe(ModelMoe),
    Base3(Model3),
    Moe3(ModelMoe3),
}

impl Model {
@ -27,6 +31,8 @@ impl Model {
        match self {
            Self::Moe(ref mut m) => m.forward(xs, s),
            Self::Base(ref mut m) => m.forward(xs, s),
            Self::Base3(ref mut m) => m.forward(xs, s),
            Self::Moe3(ref mut m) => m.forward(xs, s),
        }
    }
}
@ -85,6 +91,10 @@ impl TextGeneration {
            Some(token) => token,
            None => anyhow::bail!("cannot find the <|endoftext|> token"),
        };
        let eos_token2 = match self.tokenizer.get_token("<|im_end|>") {
            Some(token) => token,
            None => anyhow::bail!("cannot find the <|im_end|> token"),
        };
        let start_gen = std::time::Instant::now();
        for index in 0..sample_len {
            let context_size = if index > 0 { 1 } else { tokens.len() };
@ -107,7 +117,7 @@ impl TextGeneration {
            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token {
            if next_token == eos_token || next_token == eos_token2 {
                break;
            }
            if let Some(t) = self.tokenizer.next_token(next_token)? {
@ -152,6 +162,16 @@ enum WhichModel {
    W2_7b,
    #[value(name = "2-72b")]
    W2_72b,
    #[value(name = "3-0.6b")]
    W3_0_6b,
    #[value(name = "3-1.7b")]
    W3_1_7b,
    #[value(name = "3-4b")]
    W3_4b,
    #[value(name = "3-8b")]
    W3_8b,
    #[value(name = "3-moe-a3b")]
    W3MoeA3b,
}

#[derive(Parser, Debug)]
@ -254,6 +274,11 @@ fn main() -> Result<()> {
            WhichModel::W14b => ("1.5", "14B"),
            WhichModel::W72b => ("1.5", "72B"),
            WhichModel::MoeA27b => ("1.5", "MoE-A2.7B"),
            WhichModel::W3_0_6b => ("3", "0.6B"),
            WhichModel::W3_1_7b => ("3", "1.7B"),
            WhichModel::W3_4b => ("3", "4B"),
            WhichModel::W3_8b => ("3", "8B"),
            WhichModel::W3MoeA3b => ("3", "30B-A3B"),
        };
        format!("Qwen/Qwen{version}-{size}")
    }
@ -273,7 +298,11 @@ fn main() -> Result<()> {
            .map(std::path::PathBuf::from)
            .collect::<Vec<_>>(),
        None => match args.model {
            WhichModel::W0_5b | WhichModel::W2_0_5b | WhichModel::W2_1_5b | WhichModel::W1_8b => {
            WhichModel::W0_5b
            | WhichModel::W2_0_5b
            | WhichModel::W2_1_5b
            | WhichModel::W1_8b
            | WhichModel::W3_0_6b => {
                vec![repo.get("model.safetensors")?]
            }
            WhichModel::W4b
@ -282,7 +311,11 @@ fn main() -> Result<()> {
            | WhichModel::W14b
            | WhichModel::W72b
            | WhichModel::W2_72b
            | WhichModel::MoeA27b => {
            | WhichModel::MoeA27b
            | WhichModel::W3_1_7b
            | WhichModel::W3_4b
            | WhichModel::W3_8b
            | WhichModel::W3MoeA3b => {
                candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
            }
        },
@ -304,6 +337,14 @@ fn main() -> Result<()> {
            let config: ConfigMoe = serde_json::from_slice(&std::fs::read(config_file)?)?;
            Model::Moe(ModelMoe::new(&config, vb)?)
        }
        WhichModel::W3_0_6b | WhichModel::W3_1_7b | WhichModel::W3_4b | WhichModel::W3_8b => {
            let config: Config3 = serde_json::from_slice(&std::fs::read(config_file)?)?;
            Model::Base3(Model3::new(&config, vb)?)
        }
        WhichModel::W3MoeA3b => {
            let config: ConfigMoe3 = serde_json::from_slice(&std::fs::read(config_file)?)?;
            Model::Moe3(ModelMoe3::new(&config, vb)?)
        }
        _ => {
            let config: ConfigBase = serde_json::from_slice(&std::fs::read(config_file)?)?;
            Model::Base(ModelBase::new(&config, vb)?)
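The `eos_token2` addition above reflects that Qwen3 chat checkpoints can end a turn with either `<|endoftext|>` or `<|im_end|>`, so the generation loop now stops on whichever arrives first. A minimal sketch of that stop-token pattern, with placeholder ids standing in for the ones the example looks up in the tokenizer vocabulary:

```rust
fn main() {
    // Hypothetical token ids for "<|endoftext|>" and "<|im_end|>".
    let eos_tokens = [151643u32, 151645u32];
    let sampled = [10u32, 42, 151645, 7];

    let mut generated = Vec::new();
    for &tok in &sampled {
        // Stop on whichever end-of-turn token appears first.
        if eos_tokens.contains(&tok) {
            break;
        }
        generated.push(tok);
    }
    assert_eq!(generated, vec![10, 42]);
}
```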
@ -28,3 +28,26 @@ Ranking Results:
> Rank #1 | Score: 0.9990 | The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.
--------------------------------------------------------------------------------
```

Text-Classification:
```bash
cargo run --example xlm-roberta -- --task text-classification --model xlmr-formality-classifier
```
```markdown
Formality Scores:
Text 1: "I like you. I love you"
  formal: 0.9933
  informal: 0.0067

Text 2: "Hey, what's up?"
  formal: 0.8812
  informal: 0.1188

Text 3: "Siema, co porabiasz?"
  formal: 0.9358
  informal: 0.0642

Text 4: "I feel deep regret and sadness about the situation in international politics."
  formal: 0.9987
  informal: 0.0013
```
@ -2,6 +2,7 @@ use std::path::PathBuf;

use anyhow::{Error as E, Result};
use candle::{Device, Tensor};
use candle_nn::ops::softmax;
use candle_nn::VarBuilder;
use candle_transformers::models::xlm_roberta::{
    Config, XLMRobertaForMaskedLM, XLMRobertaForSequenceClassification,
@ -17,12 +18,14 @@ enum Model {
    BgeRerankerBaseV2,
    XLMRobertaBase,
    XLMRobertaLarge,
    XLMRFormalityClassifier,
}

#[derive(Debug, Clone, ValueEnum)]
enum Task {
    FillMask,
    Reranker,
    TextClassification,
}

#[derive(Parser, Debug)]
@ -83,6 +86,12 @@ fn main() -> Result<()> {
            Model::BgeRerankerBaseV2 => "BAAI/bge-reranker-base-v2-m3".to_string(),
            _ => anyhow::bail!("XLM-RoBERTa models are not supported for reranker task"),
        },
        Task::TextClassification => match args.model {
            Model::XLMRFormalityClassifier => "s-nlp/xlmr_formality_classifier".to_string(),
            _ => anyhow::bail!(
                "XLM-RoBERTa models are not supported for text classification task"
            ),
        },
    },
};
let repo = api.repo(Repo::with_revision(
@ -217,6 +226,36 @@ fn main() -> Result<()> {
            });
            println!("{:-<80}", "");
        }
        Task::TextClassification => {
            let sentences = vec![
                "I like you. I love you".to_string(),
                "Hey, what's up?".to_string(),
                "Siema, co porabiasz?".to_string(),
                "I feel deep regret and sadness about the situation in international politics."
                    .to_string(),
            ];
            let model = XLMRobertaForSequenceClassification::new(2, &config, vb)?;
            let input_ids = tokenize_batch(&tokenizer, TokenizeInput::Single(&sentences), &device)?;

            let attention_mask =
                get_attention_mask(&tokenizer, TokenizeInput::Single(&sentences), &device)?;
            let token_type_ids = Tensor::zeros(input_ids.dims(), input_ids.dtype(), &device)?;

            let logits = model
                .forward(&input_ids, &attention_mask, &token_type_ids)?
                .to_dtype(candle::DType::F32)?;

            let probabilities = softmax(&logits, 1)?;
            let probs_vec = probabilities.to_vec2::<f32>()?;

            println!("Formality Scores:");
            for (i, (text, probs)) in sentences.iter().zip(probs_vec.iter()).enumerate() {
                println!("Text {}: \"{}\"", i + 1, text);
                println!("  formal: {:.4}", probs[0]);
                println!("  informal: {:.4}", probs[1]);
                println!();
            }
        }
    }
    Ok(())
}
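The classifier head above returns two logits per sentence, which the example converts to probabilities with a softmax over dimension 1. For the two-class case this is numerically identical to a sigmoid of the logit difference; a small self-contained check in plain f32 arithmetic (no candle types involved):

```rust
// Verify that a 2-way softmax equals a sigmoid of the logit gap.
fn softmax2(a: f32, b: f32) -> (f32, f32) {
    let m = a.max(b); // subtract the max for numerical stability
    let (ea, eb) = ((a - m).exp(), (b - m).exp());
    let z = ea + eb;
    (ea / z, eb / z)
}

fn main() {
    let (formal_logit, informal_logit) = (2.3f32, -1.1);
    let (p_formal, p_informal) = softmax2(formal_logit, informal_logit);
    let sigmoid = 1.0 / (1.0 + (-(formal_logit - informal_logit)).exp());
    assert!((p_formal - sigmoid).abs() < 1e-6);
    println!("formal: {p_formal:.4}, informal: {p_informal:.4}");
}
```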
@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.9.0-alpha.4"
version = "0.9.1"
edition = "2021"

description = "Flash attention layer for the candle ML framework."
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"

[dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0-alpha.4" }
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.1" }
half = { version = "2.3.1", features = ["num-traits"] }

[build-dependencies]
@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.9.0-alpha.4"
version = "0.9.1"
edition = "2021"

description = "CUDA kernels for Candle"
@ -1,5 +1,6 @@
#include<stdint.h>
#include "cuda_fp16.h"
#include "cuda_utils.cuh"

template<typename T>
__device__ void fill_with(T *buf, T value, const size_t numel) {
@ -36,13 +37,45 @@ COPY2D_OP(uint8_t, copy2d_u8)
COPY2D_OP(uint32_t, copy2d_u32)
COPY2D_OP(int64_t, copy2d_i64)

#define CONST_SET_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
    const size_t numel, \
    const size_t num_dims, \
    const size_t *info, \
    const TYPENAME inp, \
    TYPENAME *out \
) { \
    const size_t *dims = info; \
    const size_t *strides = info + num_dims; \
    if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \
        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
            out[i] = inp; \
        } \
    } \
    else { \
        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
            unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
            out[strided_i] = inp; \
        } \
    } \
} \

CONST_SET_OP(float, const_set_f32)
CONST_SET_OP(double, const_set_f64)
CONST_SET_OP(uint8_t, const_set_u8)
CONST_SET_OP(uint32_t, const_set_u32)
CONST_SET_OP(int64_t, const_set_i64)


#if __CUDA_ARCH__ >= 530
extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
COPY2D_OP(__half, copy2d_f16)
CONST_SET_OP(__half, const_set_f16)
#endif

#if __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
extern "C" __global__ void fill_bf16(__nv_bfloat16 *buf, __nv_bfloat16 value, const size_t numel) { fill_with(buf, value, numel); }
COPY2D_OP(__nv_bfloat16, copy2d_bf16)
CONST_SET_OP(__nv_bfloat16, const_set_bf16)
#endif
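`const_set` writes one constant into every element of a possibly non-contiguous tensor: the contiguous branch writes linearly, while the strided branch routes each logical index through the dims/strides metadata the way `get_strided_index` in `cuda_utils.cuh` does. A CPU reference of the same semantics in Rust, with a hand-rolled strided-index helper standing in for the CUDA one:

```rust
// Map a logical element index to a physical offset through dims/strides,
// iterating from the innermost dimension outward.
fn strided_index(mut i: usize, dims: &[usize], strides: &[usize]) -> usize {
    let mut idx = 0;
    for (&d, &s) in dims.iter().zip(strides).rev() {
        idx += (i % d) * s;
        i /= d;
    }
    idx
}

// CPU analogue of the strided CONST_SET_OP branch.
fn const_set(out: &mut [f32], value: f32, dims: &[usize], strides: &[usize]) {
    let numel: usize = dims.iter().product();
    for i in 0..numel {
        out[strided_index(i, dims, strides)] = value;
    }
}

fn main() {
    // A 2x3 view over a 2x4 row-major buffer (row stride 4): only the
    // first three columns of each row are touched.
    let mut buf = [0.0f32; 8];
    const_set(&mut buf, 1.0, &[2, 3], &[4, 1]);
    assert_eq!(buf, [1., 1., 1., 0., 1., 1., 1., 0.]);
}
```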
@ -3,6 +3,28 @@
#include "cuda_utils.cuh"
#include<stdint.h>

template <typename T>
__host__ __device__
constexpr T max_value();

template <>
__host__ __device__
constexpr int64_t max_value<int64_t>() {
    return 0x7FFFFFFFFFFFFFFFLL;
}

template <>
__host__ __device__
constexpr uint32_t max_value<uint32_t>() {
    return 0xFFFFFFFFu;
}

template <>
__host__ __device__
constexpr uint8_t max_value<uint8_t>() {
    return 0xFFu;
}

template<typename T, typename I>
__device__ void index_select(
    const size_t numel,
@ -23,9 +45,14 @@ __device__ void index_select(
        unsigned int left_i = dst_i / (ids_dim_size * right_size);
        unsigned int id_i = dst_i / right_size % ids_dim_size;
        unsigned int right_i = dst_i % right_size;
        unsigned int src_i = left_i * (src_dim_size * right_size) + ids[id_i] * right_size + right_i;
        unsigned strided_i = b ? src_i : get_strided_index(src_i, num_dims, dims, strides);
        out[dst_i] = inp[strided_i];
        if (ids[id_i] == max_value<I>()) {
            out[dst_i] = static_cast<T>(0);
        } else {
            assert(ids[id_i] < src_dim_size);
            unsigned int src_i = left_i * (src_dim_size * right_size) + ids[id_i] * right_size + right_i;
            unsigned strided_i = b ? src_i : get_strided_index(src_i, num_dims, dims, strides);
            out[dst_i] = inp[strided_i];
        }
    }
}

@ -56,10 +83,15 @@ __device__ void gather(
) {
    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
        size_t post = i % right_size;
        size_t idx = ids[i];
        size_t pre = i / (right_size * ids_dim_size);
        size_t src_i = (pre * src_dim_size + idx) * right_size + post;
        out[i] = inp[src_i];
        const I idx = ids[i];
        if (ids[i] == max_value<I>()) {
            out[i] = static_cast<T>(0);
        } else {
            assert(idx < src_dim_size);
            size_t pre = i / (right_size * ids_dim_size);
            size_t src_i = (pre * src_dim_size + idx) * right_size + post;
            out[i] = inp[src_i];
        }
    }
}

@ -91,10 +123,13 @@ __device__ void index_add(
    const size_t pre = i / right_size;
    const size_t post = i % right_size;
    for (unsigned int j = 0; j < ids_dim_size; ++j) {
        const size_t idx = ids[j];
        const I idx = ids[j];
        const size_t src_i = (pre * ids_dim_size + j) * right_size + post;
        const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
        out[dst_i] += inp[src_i];
        if (idx < max_value<I>()) {
            assert(idx < dst_dim_size);
            const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
            out[dst_i] += inp[src_i];
        }
    }
}
@ -111,6 +146,32 @@ extern "C" __global__ void FN_NAME( \
    const size_t right_size \
) { index_add(ids, ids_dim_size, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \

template<typename T, typename I>
__device__ void scatter(
    const I *ids,
    const T *inp,
    T *out,
    const size_t left_size,
    const size_t src_dim_size,
    const size_t dst_dim_size,
    const size_t right_size
) {
    const size_t numel = left_size * right_size;
    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
        const size_t pre = i / right_size;
        const size_t post = i % right_size;
        for (unsigned int j = 0; j < src_dim_size; ++j) {
            const size_t src_i = (pre * src_dim_size + j) * right_size + post;
            const I idx = ids[src_i];
            if (idx < max_value<I>()) {
                assert(idx < dst_dim_size);
                const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
                out[dst_i] = inp[src_i];
            }
        }
    }
}

template<typename T, typename I>
__device__ void scatter_add(
    const I *ids,
@ -127,13 +188,27 @@ __device__ void scatter_add(
        const size_t post = i % right_size;
        for (unsigned int j = 0; j < src_dim_size; ++j) {
            const size_t src_i = (pre * src_dim_size + j) * right_size + post;
            const size_t idx = ids[src_i];
            const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
            out[dst_i] += inp[src_i];
            const I idx = ids[src_i];
            if (idx < max_value<I>()) {
                assert(idx < dst_dim_size);
                const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
                out[dst_i] += inp[src_i];
            }
        }
    }
}

#define S_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
    const INDEX_TYPENAME *ids, \
    const TYPENAME *inp, \
    TYPENAME *out, \
    const size_t left_size, \
    const size_t src_dim_size, \
    const size_t dst_dim_size, \
    const size_t right_size \
) { scatter(ids, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \

#define SA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
    const INDEX_TYPENAME *ids, \
@ -159,6 +234,9 @@ IA_OP(__nv_bfloat16, uint8_t, ia_u8_bf16)
SA_OP(__nv_bfloat16, int64_t, sa_i64_bf16)
SA_OP(__nv_bfloat16, uint32_t, sa_u32_bf16)
SA_OP(__nv_bfloat16, uint8_t, sa_u8_bf16)
S_OP(__nv_bfloat16, int64_t, s_i64_bf16)
S_OP(__nv_bfloat16, uint32_t, s_u32_bf16)
S_OP(__nv_bfloat16, uint8_t, s_u8_bf16)
#endif

#if __CUDA_ARCH__ >= 530
@ -174,6 +252,9 @@ IA_OP(__half, uint8_t, ia_u8_f16)
SA_OP(__half, int64_t, sa_i64_f16)
SA_OP(__half, uint32_t, sa_u32_f16)
SA_OP(__half, uint8_t, sa_u8_f16)
S_OP(__half, int64_t, s_i64_f16)
S_OP(__half, uint32_t, s_u32_f16)
S_OP(__half, uint8_t, s_u8_f16)
#endif

IS_OP(float, int64_t, is_i64_f32)
@ -247,3 +328,21 @@ SA_OP(double, uint8_t, sa_u8_f64)
SA_OP(uint8_t, uint8_t, sa_u8_u8)
SA_OP(uint32_t, uint8_t, sa_u8_u32)
SA_OP(int64_t, uint8_t, sa_u8_i64)

S_OP(float, int64_t, s_i64_f32)
S_OP(double, int64_t, s_i64_f64)
S_OP(uint8_t, int64_t, s_i64_u8)
S_OP(int64_t, int64_t, s_i64_i64)
S_OP(uint32_t, int64_t, s_i64_u32)

S_OP(float, uint32_t, s_u32_f32)
S_OP(double, uint32_t, s_u32_f64)
S_OP(uint8_t, uint32_t, s_u32_u8)
S_OP(int64_t, uint32_t, s_u32_i64)
S_OP(uint32_t, uint32_t, s_u32_u32)

S_OP(float, uint8_t, s_u8_f32)
S_OP(double, uint8_t, s_u8_f64)
S_OP(uint8_t, uint8_t, s_u8_u8)
S_OP(uint32_t, uint8_t, s_u8_u32)
S_OP(int64_t, uint8_t, s_u8_i64)
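These hunks introduce a sentinel convention for index tensors: an index equal to the maximum value of its integer type marks a padding slot, so `gather`/`index_select` write zero for it while `scatter`/`scatter_add`/`index_add` skip the write, instead of reading or writing out of bounds. A CPU analogue of the convention in Rust, using `u32::MAX` as the sentinel:

```rust
const SENTINEL: u32 = u32::MAX;

// Gather with sentinel handling: padding slots yield 0 instead of an
// out-of-bounds read.
fn gather(inp: &[f32], ids: &[u32]) -> Vec<f32> {
    ids.iter()
        .map(|&i| if i == SENTINEL { 0.0 } else { inp[i as usize] })
        .collect()
}

// Scatter-add with sentinel handling: padding slots are silently skipped.
// (The kernels additionally assert that real indices stay in bounds.)
fn scatter_add(out: &mut [f32], ids: &[u32], src: &[f32]) {
    for (&i, &v) in ids.iter().zip(src) {
        if i != SENTINEL {
            out[i as usize] += v;
        }
    }
}

fn main() {
    let table = [10.0f32, 20.0, 30.0];
    assert_eq!(gather(&table, &[2, SENTINEL, 0]), vec![30.0, 0.0, 10.0]);

    let mut acc = [0.0f32; 3];
    scatter_add(&mut acc, &[1, SENTINEL, 1], &[5.0, 9.0, 7.0]);
    assert_eq!(acc, [0.0, 12.0, 0.0]);
}
```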
@ -219,11 +219,15 @@ __device__ void softmax(const T * x, T * dst, const int ncols) {
}

template <typename T>
__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td) {
__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t stride_b) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (2 * idx >= bh * td) return;

    uint32_t rope_idx = idx % (td / 2);
    if (stride_b > 0) {
        uint32_t b_idx = (2 * idx) / stride_b;
        rope_idx += b_idx * (td / 2);
    }
    T c = cos[rope_idx];
    T s = sin[rope_idx];

@ -232,7 +236,7 @@ __device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, cons
}

template <typename T>
__device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t d) {
__device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t d, const uint32_t stride_b) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (2 * idx >= bh * td) return;

@ -243,6 +247,10 @@ __device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const
    uint32_t i1 = i_bh * td + i_t * d + i_d;
    uint32_t i2 = i1 + d / 2;
    uint32_t i_cs = i_t * (d / 2) + i_d;
    if (stride_b > 0) {
        uint32_t b_idx = (2 * idx) / stride_b;
        i_cs += b_idx * (td / 2);
    }
    T c = cos[i_cs];
    T s = sin[i_cs];

@ -259,7 +267,8 @@ __device__ void rope_thd(
    const uint32_t b,
    const uint32_t t,
    const uint32_t h,
    const uint32_t d
    const uint32_t d,
    const uint32_t stride_b
) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (2 * idx >= b * t * h * d) return;
@ -270,6 +279,10 @@ __device__ void rope_thd(
    uint32_t i1 = i_bth * d + i_d;
    uint32_t i2 = i1 + d / 2;
    uint32_t i_cs = i_t * (d / 2) + i_d;
    if (stride_b > 0) {
        uint32_t b_idx = (2 * idx) / stride_b;
        i_cs += b_idx * ((t * d) / 2);
    }
    T c = cos[i_cs];
    T s = sin[i_cs];

@ -546,8 +559,9 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
    const TYPENAME *sin, \
    TYPENAME *dst, \
    const uint32_t bh, \
    const uint32_t td) { \
    ropei<TYPENAME>(src, cos, sin, dst, bh, td); \
    const uint32_t td, \
    const uint32_t stride_b) { \
    ropei<TYPENAME>(src, cos, sin, dst, bh, td, stride_b); \
} \
extern "C" __global__ void FN_NAME( \
    const TYPENAME *src, \
@ -556,8 +570,9 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
    TYPENAME *dst, \
    const uint32_t bh, \
    const uint32_t td, \
    const uint32_t d) { \
    rope<TYPENAME>(src, cos, sin, dst, bh, td, d); \
    const uint32_t d, \
    const uint32_t stride_b) { \
    rope<TYPENAME>(src, cos, sin, dst, bh, td, d, stride_b); \
} \
extern "C" __global__ void FN_NAME_THD( \
    const TYPENAME *src, \
@ -567,8 +582,9 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
    const uint32_t b, \
    const uint32_t t, \
    const uint32_t h, \
    const uint32_t d) { \
    rope_thd<TYPENAME>(src, cos, sin, dst, b, t, h, d); \
    const uint32_t d, \
    const uint32_t stride_b) { \
    rope_thd<TYPENAME>(src, cos, sin, dst, b, t, h, d, stride_b); \
} \

#if __CUDA_ARCH__ >= 800
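The new `stride_b` parameter lets the cos/sin tables carry one row per batch element: when it is zero every batch shares a single table (the previous behaviour), and when it is non-zero the batch index, recovered as `(2 * idx) / stride_b`, shifts the lookup by whole `td / 2` half-tables. A sketch of the `ropei` index math in Rust, assuming `stride_b` counts the elements per batch row:

```rust
// Mirror of the ropei lookup: each GPU lane handles two elements, so
// element offset = 2 * idx, and the batch index is that offset divided
// by the per-batch element count.
fn rope_i_index(idx: u32, td: u32, stride_b: u32) -> u32 {
    let mut rope_idx = idx % (td / 2);
    if stride_b > 0 {
        let b_idx = (2 * idx) / stride_b;
        rope_idx += b_idx * (td / 2);
    }
    rope_idx
}

fn main() {
    let (td, stride_b) = (8u32, 8u32); // one head: 8 elements per batch row
    // Lanes 0..4 belong to batch 0, lanes 4..8 to batch 1.
    assert_eq!(rope_i_index(1, td, stride_b), 1);
    assert_eq!(rope_i_index(5, td, stride_b), 5); // batch 1: offset by td / 2
    // stride_b == 0 keeps the old shared-table behaviour.
    assert_eq!(rope_i_index(5, td, 0), 1);
}
```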
@ -1,6 +1,6 @@
[package]
name = "candle-metal-kernels"
version = "0.9.0-alpha.4"
version = "0.9.1"
edition = "2021"

description = "Metal kernels for Candle"
@ -12,6 +12,7 @@ license = "MIT OR Apache-2.0"

[dependencies]
metal = { version = "0.27.0", features = ["mps"] }
half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_distr"] }
once_cell = "1.18.0"
thiserror = "1"
tracing = "0.1.37"
@ -4,20 +4,20 @@ using namespace metal;

template<typename T> METAL_FUNC void fill_with(
    device T *out,
    constant float &value,
    constant T &value,
    constant size_t &numel,
    uint tid [[thread_position_in_grid]]
) {
    if (tid >= numel) {
        return;
    }
    out[tid] = static_cast<T>(value);
    out[tid] = value;
}

#define FILL_OP(NAME, T) \
kernel void fill_##NAME( \
    device T *out, \
    constant float &value, \
    constant T &value, \
    constant size_t &numel, \
    uint tid [[thread_position_in_grid]] \
) { \
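Taking the fill value as `constant T &` instead of always `constant float &` matters for integer dtypes: routing a wide integer through f32 silently rounds anything beyond the 24-bit significand. A tiny Rust demonstration of the round-trip loss the old signature risked:

```rust
fn main() {
    // f32 has a 24-bit significand, so u32 values above 2^24 do not
    // survive a round-trip through float.
    let v: u32 = 16_777_217; // 2^24 + 1
    let through_f32 = v as f32 as u32;
    assert_ne!(v, through_f32); // rounds down to 16_777_216
    println!("{v} -> {through_f32}");
}
```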
@ -1,6 +1,24 @@
#include <metal_stdlib>
using namespace metal;

template <typename T>
inline T max_value();

template <>
inline int64_t max_value<int64_t>() {
    return 0x7FFFFFFFFFFFFFFF;
}

template <>
inline uint32_t max_value<uint32_t>() {
    return 0xFFFFFFFFu;
}

template <>
inline uint8_t max_value<uint8_t>() {
    return 0xFF;
}

METAL_FUNC uint get_strided_index(
    uint idx,
    constant size_t &num_dims,
@ -35,17 +53,21 @@ METAL_FUNC void index(
        return;
    }
    const size_t id_i = (tid / right_size) % ids_size;
    const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1));
    const size_t right_rank_i = tid % right_size;
    const size_t left_rank_i = tid / right_size / ids_size;
    /*
    // Force prevent out of bounds indexing
    // since there doesn't seem to be a good way to force crash
    // No need to check for zero we're only allowing unsized.
    */
    const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i;
    const size_t strided_src_i = contiguous ? src_i : get_strided_index(src_i, src_dim_size, src_dims, src_strides);
    output[tid] = input[strided_src_i];
    if (input_ids[id_i] == max_value<INDEX_TYPENAME>()) {
        output[tid] = static_cast<TYPENAME>(0);
    } else {
        const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1));
        const size_t right_rank_i = tid % right_size;
        const size_t left_rank_i = tid / right_size / ids_size;
        /*
        // Force prevent out of bounds indexing
        // since there doesn't seem to be a good way to force crash
        // No need to check for zero we're only allowing unsized.
        */
        const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i;
        const size_t strided_src_i = contiguous ? src_i : get_strided_index(src_i, src_dim_size, src_dims, src_strides);
        output[tid] = input[strided_src_i];
    }
}

# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
@ -83,10 +105,14 @@ METAL_FUNC void gather(
        return;
    }
    const INDEX_TYPENAME input_i = input_ids[tid];
    const size_t right_rank_i = tid % right_size;
    const size_t left_rank_i = tid / right_size / ids_size;
    const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i;
    output[tid] = input[src_i];
    if (input_i == max_value<INDEX_TYPENAME>()) {
        output[tid] = static_cast<TYPENAME>(0);
    } else {
        const size_t right_rank_i = tid % right_size;
        const size_t left_rank_i = tid / right_size / ids_size;
        const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i;
        output[tid] = input[src_i];
    }
}

# define GATHER_OP(NAME, INDEX_TYPENAME, TYPENAME) \
@ -104,6 +130,33 @@ kernel void NAME( \
    gather<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \
}

template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void scatter(
    constant size_t &dst_size,
    constant size_t &left_size,
    constant size_t &src_dim_size,
    constant size_t &right_size,
    constant size_t &dst_dim_size,
    const device TYPENAME *input,
    const device INDEX_TYPENAME *input_ids,
    device TYPENAME *output,
    uint tid [[ thread_position_in_grid ]]
) {
    if (tid >= dst_size) {
        return;
    }
    const size_t right_rank_i = tid % right_size;
    const size_t left_rank_i = tid / right_size;
    for (unsigned int j = 0; j < src_dim_size; ++j) {
        const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
        const INDEX_TYPENAME idx = input_ids[src_i];
        if (idx < max_value<INDEX_TYPENAME>()) {
            const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
            output[dst_i] = input[src_i];
        }
    }
}

template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void scatter_add(
    constant size_t &dst_size,
@ -124,11 +177,28 @@ METAL_FUNC void scatter_add(
    for (unsigned int j = 0; j < src_dim_size; ++j) {
        const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
        const INDEX_TYPENAME idx = input_ids[src_i];
        const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
        output[dst_i] += input[src_i];
        if (idx < max_value<INDEX_TYPENAME>()) {
            const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
            output[dst_i] += input[src_i];
        }
    }
}

# define SCATTER_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
    constant size_t &dst_size, \
    constant size_t &left_size, \
    constant size_t &src_dim_size, \
    constant size_t &right_size, \
    constant size_t &dst_dim_size, \
    const device TYPENAME *input, \
    const device INDEX_TYPENAME *input_ids, \
    device TYPENAME *output, \
    uint tid [[ thread_position_in_grid ]] \
) { \
    scatter<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, dst_dim_size, input, input_ids, output, tid); \
}

# define SCATTER_ADD_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
    constant size_t &dst_size, \
@ -164,9 +234,11 @@ METAL_FUNC void index_add(
    const size_t left_rank_i = tid / right_size;
    for (unsigned int j = 0; j < ids_dim_size; ++j) {
        const INDEX_TYPENAME idx = input_ids[j];
        const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
        const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
        output[dst_i] += input[src_i];
        if (idx < max_value<INDEX_TYPENAME>()) {
            const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
            const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
            output[dst_i] += input[src_i];
        }
    }
}

@ -235,6 +307,19 @@ SCATTER_ADD_OP(sa_u8_bf16, uint8_t, bfloat)
SCATTER_ADD_OP(sa_i64_bf16, int64_t, bfloat)
#endif

SCATTER_OP(s_u32_f32, uint32_t, float)
SCATTER_OP(s_u8_f32, uint8_t, float)
SCATTER_OP(s_i64_f32, int64_t, float)
SCATTER_OP(s_u32_u32, uint32_t, uint32_t)
SCATTER_OP(s_u32_f16, uint32_t, half)
SCATTER_OP(s_u8_f16, uint8_t, half)
SCATTER_OP(s_i64_f16, int64_t, half)
#if defined(__HAVE_BFLOAT__)
SCATTER_OP(s_u32_bf16, uint32_t, bfloat)
SCATTER_OP(s_u8_bf16, uint8_t, bfloat)
SCATTER_OP(s_i64_bf16, int64_t, bfloat)
#endif

// i64
INDEX_ADD_OP(ia_i64_f16, int64_t, half)
INDEX_ADD_OP(ia_i64_f32, int64_t, float)
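Unlike `scatter_add`, the new Metal `scatter` kernel plainly overwrites the destination slot, so when several source positions carry the same index, whichever one the inner loop visits last wins; the same sentinel skip applies here as in the CUDA kernels. A CPU analogue in Rust:

```rust
// Plain scatter: overwrite, last write wins, sentinel indices skipped.
fn scatter(out: &mut [f32], ids: &[u32], src: &[f32]) {
    for (&i, &v) in ids.iter().zip(src) {
        if i != u32::MAX {
            out[i as usize] = v;
        }
    }
}

fn main() {
    let mut out = [0.0f32; 4];
    scatter(&mut out, &[1, 3, 1, u32::MAX], &[10.0, 20.0, 30.0, 40.0]);
    // Index 1 is written twice; 30.0 (the later write) wins.
    assert_eq!(out, [0.0, 30.0, 0.0, 20.0]);
}
```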
@ -161,7 +161,7 @@ macro_rules! ops{
|
||||
pub mod unary {
|
||||
ops!(
|
||||
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
|
||||
tanh, recip, silu, sign, sigmoid
|
||||
tanh, recip, silu, sign, sigmoid, const_set
|
||||
);
|
||||
}
|
||||
pub mod binary {
|
||||
@ -419,6 +419,82 @@ pub fn call_copy2d(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_const_set_contiguous_tiled(
|
||||
device: &Device,
|
||||
ep: impl EncoderProvider,
|
||||
kernels: &Kernels,
|
||||
kernel_name: unary::contiguous_tiled::Kernel,
|
||||
length: usize,
|
||||
input: impl EncoderParam,
|
||||
output: BufferOffset,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
let tile_size = 2;
|
||||
let tiles = length.div_ceil(tile_size);
|
||||
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, input, &output));
|
||||
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, tiles);
|
||||
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
pub fn call_const_set_contiguous(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: unary::contiguous::Kernel,
length: usize,
input: impl EncoderParam,
output: BufferOffset,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();

encoder.set_compute_pipeline_state(&pipeline);

set_params!(encoder, (length, input, &output));

let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_const_set_strided(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
name: unary::strided::Kernel,
shape: &[usize],
input: impl EncoderParam,
strides: &[usize],
output: BufferOffset,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;

let length: usize = shape.iter().product();
let num_dims: usize = shape.len();
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);

encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, num_dims, shape, strides, input, &output));
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
Ok(())
}

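The strided variant leans on the get_strided_index helper from the Metal sources; a Rust rendering of what that helper computes (a sketch, not the actual host code):

// Map a linear element id onto a strided layout: peel off the coordinates
// in row-major order over `dims`, then recombine them with `strides`.
fn get_strided_index(mut idx: usize, dims: &[usize], strides: &[usize]) -> usize {
    let mut strided_i = 0;
    for (&d, &s) in dims.iter().zip(strides.iter()).rev() {
        strided_i += (idx % d) * s;
        idx /= d;
    }
    strided_i
}
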
#[allow(clippy::too_many_arguments)]
pub fn call_unary_contiguous_tiled(
device: &Device,
@ -915,6 +991,7 @@ pub fn call_rope_i(
kernel_name: &'static str,
bh: usize,
td: usize,
stride_b: usize,
src: &Buffer,
src_offset: usize,
cos: &Buffer,
@ -933,6 +1010,7 @@ pub fn call_rope_i(
(
bh,
td,
stride_b,
(src, src_offset),
(cos, cos_offset),
(sin, sin_offset),
@ -958,6 +1036,7 @@ pub fn call_rope_thd(
t: usize,
h: usize,
d: usize,
stride_b: usize,
src: &Buffer,
src_offset: usize,
cos: &Buffer,
@ -978,6 +1057,7 @@ pub fn call_rope_thd(
t,
h,
d,
stride_b,
(src, src_offset),
(cos, cos_offset),
(sin, sin_offset),
@ -1002,6 +1082,7 @@ pub fn call_rope(
bh: usize,
td: usize,
d: usize,
stride_b: usize,
src: &Buffer,
src_offset: usize,
cos: &Buffer,
@ -1021,6 +1102,7 @@ pub fn call_rope(
bh,
td,
d,
stride_b,
(src, src_offset),
(cos, cos_offset),
(sin, sin_offset),
@ -1371,7 +1453,7 @@ pub fn call_gather(
}

#[allow(clippy::too_many_arguments)]
pub fn call_scatter_add(
pub fn call_scatter(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
@ -1381,7 +1463,7 @@ pub fn call_scatter_add(
dim: usize,
input: BufferOffset,
ids: BufferOffset,
output: &Buffer,
output: BufferOffset,
) -> Result<(), MetalKernelError> {
let left_size: usize = src_shape[..dim].iter().product();
let right_size: usize = src_shape[dim + 1..].iter().product();
@ -1406,7 +1488,7 @@ pub fn call_scatter_add(
dst_dim_size,
&input,
&ids,
output
&output
)
);

@ -1414,7 +1496,7 @@ pub fn call_scatter_add(

encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
Ok(())
}
@ -2570,7 +2652,7 @@ pub fn call_const_fill(
name: &'static str,
length: usize,
output: &Buffer,
v: f32,
v: impl EncoderParam,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Fill, name)?;
let encoder = ep.encoder();

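The grid bookkeeping in call_scatter splits the source shape around dim, exactly as the left_size/right_size lines above; for example:

fn main() {
    // View a (2, 5, 3) source as (left, dim, right) around dim = 1.
    let src_shape = [2usize, 5, 3];
    let dim = 1;
    let left_size: usize = src_shape[..dim].iter().product();
    let right_size: usize = src_shape[dim + 1..].iter().product();
    assert_eq!((left_size, right_size), (2, 3)); // 2 * 5 * 3 elements total
}
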
@ -1097,6 +1097,7 @@ template<typename T>
METAL_FUNC void ropei(
constant size_t &bh,
constant size_t &td,
constant size_t &stride_b,
device const T *src,
device const T *cos,
device const T *sin,
@ -1107,6 +1108,10 @@ METAL_FUNC void ropei(
return;
}
size_t rope_idx = tid % (td / 2);
if (stride_b > 0) {
size_t b_idx = (2 * tid) / stride_b;
rope_idx += b_idx * (td / 2);
}
T c = cos[rope_idx];
T s = sin[rope_idx];
dst[2 * tid] = src[2 * tid] * c - src[2 * tid + 1] * s;
@ -1118,6 +1123,7 @@ METAL_FUNC void rope(
constant size_t &bh,
constant size_t &td,
constant size_t &d,
constant size_t &stride_b,
device const T *src,
device const T *cos,
device const T *sin,
@ -1134,6 +1140,10 @@ METAL_FUNC void rope(
size_t i1 = i_bh * td + i_t * d + i_d;
size_t i2 = i1 + d / 2;
size_t i_cs = i_t * (d / 2) + i_d;
if (stride_b > 0) {
size_t b_idx = (2 * idx) / stride_b;
i_cs += b_idx * (td / 2);
}
T c = cos[i_cs];
T s = sin[i_cs];
dst[i1] = src[i1] * c - src[i2] * s;
@ -1146,6 +1156,7 @@ METAL_FUNC void rope_thd(
constant size_t &t,
constant size_t &h,
constant size_t &d,
constant size_t &stride_b,
device const T *src,
device const T *cos,
device const T *sin,
@ -1160,8 +1171,12 @@ METAL_FUNC void rope_thd(
const size_t i_t = (i_bth / h) % t;
const size_t i1 = i_bth * d + i_d;
const size_t i2 = i1 + d / 2;
const size_t i_cs = i_t * (d / 2) + i_d;
T c = cos[i_cs];
size_t i_cs = i_t * (d / 2) + i_d;
if (stride_b > 0) {
const size_t b_idx = (2 * idx) / stride_b;
i_cs += b_idx * ((t * d) / 2);
}
T c = cos[i_cs];
T s = sin[i_cs];
dst[i1] = src[i1] * c - src[i2] * s;
dst[i2] = src[i1] * s + src[i2] * c;
@ -1171,38 +1186,41 @@ METAL_FUNC void rope_thd(
kernel void FN_NAME_I( \
constant size_t &bh, \
constant size_t &td, \
constant size_t &stride_b, \
device const TYPENAME *src, \
device const TYPENAME *cos, \
device const TYPENAME *sin, \
device TYPENAME *dst, \
uint tid [[ thread_position_in_grid ]] \
) { \
ropei<TYPENAME>(bh, td, src, cos, sin, dst, tid); \
ropei<TYPENAME>(bh, td, stride_b, src, cos, sin, dst, tid); \
}\
kernel void FN_NAME( \
constant size_t &bh, \
constant size_t &td, \
constant size_t &d, \
constant size_t &stride_b, \
device const TYPENAME *src, \
device const TYPENAME *cos, \
device const TYPENAME *sin, \
device TYPENAME *dst, \
uint idx [[ thread_position_in_grid ]] \
) { \
rope<TYPENAME>(bh, td, d, src, cos, sin, dst, idx); \
rope<TYPENAME>(bh, td, d, stride_b, src, cos, sin, dst, idx); \
}\
kernel void FN_NAME_THD( \
constant size_t &b, \
constant size_t &t, \
constant size_t &h, \
constant size_t &d, \
constant size_t &stride_b, \
device const TYPENAME *src, \
device const TYPENAME *cos, \
device const TYPENAME *sin, \
device TYPENAME *dst, \
uint idx [[ thread_position_in_grid ]] \
) { \
rope_thd<TYPENAME>(b, t, h, d, src, cos, sin, dst, idx); \
rope_thd<TYPENAME>(b, t, h, d, stride_b, src, cos, sin, dst, idx); \
}\

RMSNORM(rmsnorm_f32, float)

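The stride_b plumbing above lets one (t, d/2) cos/sin table be shared across the batch (stride_b == 0) or selects a per-batch row when the tables are (b, t, d/2). A Rust sketch of the index math used by ropei, where each thread handles two consecutive elements:

// stride_b is h * t * d for per-batch tables and 0 for shared ones.
fn rope_cos_sin_index(tid: usize, td: usize, stride_b: usize) -> usize {
    let mut rope_idx = tid % (td / 2);
    if stride_b > 0 {
        // 2 * tid is the element offset, so this recovers the batch index.
        let b_idx = (2 * tid) / stride_b;
        rope_idx += b_idx * (td / 2);
    }
    rope_idx
}
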
@ -1574,7 +1574,7 @@ fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>(
let input_buffer = new_buffer(&device, input);
let ids_buffer = new_buffer(&device, ids);
let output = device.new_buffer(std::mem::size_of_val(input) as u64, options);
call_scatter_add(
call_scatter(
&device,
command_buffer,
&kernels,
@ -1584,7 +1584,7 @@ fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>(
dim,
BufferOffset::zero_offset(&input_buffer),
BufferOffset::zero_offset(&ids_buffer),
&output,
BufferOffset::zero_offset(&output),
)
.unwrap();
command_buffer.commit();
@ -2343,7 +2343,7 @@ fn conv_transpose1d_u32() {

#[test]
fn const_fill() {
fn constant_fill<T: Clone>(name: &'static str, len: usize, value: f32) -> Vec<T> {
fn constant_fill<T: Clone + EncoderParam>(name: &'static str, len: usize, value: T) -> Vec<T> {
let dev = device();
let kernels = Kernels::new();
let command_queue = dev.new_command_queue();
@ -2357,11 +2357,15 @@ fn const_fill() {
command_buffer.wait_until_completed();
read_to_vec::<T>(&buffer, len)
}
fn test<T: Clone + PartialEq + std::fmt::Debug, F: FnOnce(f32) -> T>(name: &'static str, f: F) {
fn test<T: Clone + Copy + EncoderParam + PartialEq + std::fmt::Debug, F: FnOnce(f32) -> T>(
name: &'static str,
f: F,
) {
let len = rand::thread_rng().gen_range(2..16) * rand::thread_rng().gen_range(4..16);
let value = rand::thread_rng().gen_range(1. ..19.);
let value = f(value);
let v = constant_fill::<T>(name, len, value);
assert_eq!(v, vec![f(value); len])
assert_eq!(v, vec![value; len])
}
test::<u8, _>("fill_u8", |v| v as u8);
test::<u32, _>("fill_u32", |v| v as u32);

@ -73,6 +73,44 @@ template <typename T> METAL_FUNC T sigmoid(T in) {

#define TILE_SIZE 2

#define CONST_SET(TYPENAME, FN_NAME) \
kernel void FN_NAME( \
constant size_t &dim, \
constant TYPENAME &input, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
if (tid >= dim) { \
return; \
} \
output[tid] = input; \
} \
kernel void FN_NAME##_##strided( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant TYPENAME &input, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
if (tid >= dim) { \
return; \
} \
output[get_strided_index(tid, num_dims, dims, strides)] = input; \
} \
kernel void FN_NAME##_##tiled( \
constant size_t &dim, \
constant TYPENAME &input, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
for (uint i = 0; i < TILE_SIZE; i++) { \
const uint idx = tid * TILE_SIZE + i; \
output[idx] = input; \
} \
}

#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
constant size_t &dim, \
@ -139,6 +177,11 @@ COPY2D(copy2d_f16, half)
COPY2D(copy2d_u8, uint8_t)
COPY2D(copy2d_u32, uint32_t)

CONST_SET(float, const_set_f32)
CONST_SET(half, const_set_f16)
CONST_SET(uint8_t, const_set_u8)
CONST_SET(uint32_t, const_set_u32)

UNARY_OP(cos)
UNARY_OP(sin)
UNARY_OP(sqr)
@ -171,6 +214,7 @@ UNARY(precise::tanh, half, tanh_f16, tanh_f16_strided);
#if __METAL_VERSION__ >= 220
UNARY(id, int64_t, copy_i64, copy_i64_strided)
COPY2D(copy2d_i64, int64_t)
CONST_SET(int64_t, const_set_i64)
#endif

#if defined(__HAVE_BFLOAT__)
@ -199,4 +243,5 @@ UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
UNARY(precise::tanh, bfloat, tanh_bf16, tanh_bf16_strided);

COPY2D(copy2d_bf16, bfloat)
CONST_SET(bfloat, const_set_bf16)
#endif

@ -88,9 +88,13 @@ primitive!(bool);
primitive!(usize);
primitive!(i32);
primitive!(i64);
primitive!(u8);
primitive!(u32);
primitive!(u64);
primitive!(f32);
primitive!(f64);
primitive!(half::bf16);
primitive!(half::f16);

pub struct BufferOffset<'a> {
pub buffer: &'a Buffer,

@ -71,6 +71,8 @@ impl candle::Module for PReLU {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let weight = if self.is_scalar {
self.weight.reshape(())?
} else if xs.shape() == self.weight.shape() {
self.weight.clone()
} else if xs.rank() >= 2 {
let num_channels = xs.dim(1)?;
let num_weights = self.weight.elem_count();
@ -78,7 +80,7 @@ impl candle::Module for PReLU {
candle::bail!("error in prelu: unexpected number of channels for the input, got {num_channels}, weight dim is {num_weights}")
}
let mut s = vec![1; xs.rank()];
s[1] = self.weight.elem_count();
s[1] = num_weights;
self.weight.reshape(s)?
} else {
self.weight.clone()

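The reshape in the channel-wise branch builds a broadcastable weight shape: all ones except dim 1, which holds the channel count. Sketched for a 4d input:

fn main() {
    // A per-channel weight of 3 elements applied to a (b, 3, h, w) input
    // is reshaped to (1, 3, 1, 1) so it broadcasts over b, h and w.
    let num_weights = 3usize;
    let rank = 4usize;
    let mut s = vec![1; rank];
    s[1] = num_weights;
    assert_eq!(s, vec![1, 3, 1, 1]);
}
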
@ -1,6 +1,6 @@
//! Cache Implementations
//!
use candle::{Device, Result, Tensor};
use candle::{DType, Device, Result, Tensor};

#[derive(Debug, Clone)]
pub struct Cache {
@ -399,3 +399,322 @@ impl RotatingKvCache {
self.v.reset();
}
}

#[derive(Debug, Clone)]
pub struct IndicesAndMask {
indices: Tensor,
mask: Tensor,
}

impl IndicesAndMask {
pub fn mask(&self) -> &Tensor {
&self.mask
}
}

#[derive(Debug, Clone)]
pub struct ScatteredKvCache {
k: Tensor,
v: Tensor,
context: usize,
}

impl ScatteredKvCache {
pub fn append(
&mut self,
k: &Tensor,
v: &Tensor,
iam: &IndicesAndMask,
) -> Result<(Tensor, Tensor)> {
if self.context <= k.dim(2)? {
return Ok((k.clone(), v.clone()));
}
let indices = iam.indices.unsqueeze(2)?.unsqueeze(1)?;
let indices = indices.broadcast_as(k.shape())?.contiguous()?;
self.k.scatter_set(&indices, k, 2)?;
self.v.scatter_set(&indices, v, 2)?;
Ok((self.k.clone(), self.v.clone()))
}

pub fn k(&self) -> &Tensor {
&self.k
}

pub fn v(&self) -> &Tensor {
&self.v
}
}

#[derive(Debug, Clone)]
pub struct ScatteredCacheBuilder {
context: usize,
// The current position in the stream, this can be larger than context.
positions: Vec<usize>,
// The index where the next element will be stored.
indices: Vec<usize>,
dtype: DType,
device: Device,
}

impl ScatteredCacheBuilder {
pub fn new(batch_size: usize, context: usize, dtype: DType, device: &Device) -> Result<Self> {
let positions = vec![0; batch_size];
let indices = vec![0; batch_size];
Ok(Self {
positions,
indices,
context,
dtype,
device: device.clone(),
})
}

pub fn make_cache(&self, num_heads: usize, head_dim: usize) -> Result<ScatteredKvCache> {
let batch_size = self.batch_size();
let shape = (batch_size, num_heads, self.context, head_dim);
let k = Tensor::zeros(shape, self.dtype, self.device())?;
let v = Tensor::zeros(shape, self.dtype, self.device())?;
Ok(ScatteredKvCache {
k,
v,
context: self.context,
})
}

pub fn positions(&self) -> &[usize] {
&self.positions
}

pub fn reset(&mut self) {
self.positions.fill(0);
self.indices.fill(0);
}

pub fn batch_size(&self) -> usize {
self.positions.len()
}

pub fn reset_batch_index(&mut self, batch_index: usize) {
self.positions[batch_index] = 0;
self.indices[batch_index] = 0;
}

#[allow(clippy::needless_range_loop)]
pub fn indices_and_mask(
&mut self,
seq_len: usize,
batch_mask: &[bool],
) -> Result<IndicesAndMask> {
// mask shape is (b, h, t, k)
let context = self.context;
if self.context <= seq_len {
return self.indices_and_mask_abs(seq_len, batch_mask);
}
let mut attention_masks = Vec::with_capacity(self.batch_size());
let mut cache_indices = Vec::with_capacity(self.batch_size());
for (batch_i, &batch_mask) in batch_mask.iter().enumerate() {
if !batch_mask {
let masks: Vec<Vec<f32>> = vec![vec![0.0; context]; seq_len];
let indices = vec![self.indices[batch_i] as u32; seq_len];
attention_masks.push(masks);
cache_indices.push(indices);
} else {
let start_index = self.indices[batch_i];
let start_pos = self.positions[batch_i];
let mut masks: Vec<Vec<f32>> = Vec::with_capacity(seq_len);
let mut indices = Vec::with_capacity(seq_len);
let mut all_pos = vec![usize::MAX; context];
if start_pos < context {
for i in 0..start_pos {
all_pos[i] = i;
}
} else {
let offset = start_pos - start_index;
for i in 0..context {
all_pos[i] = if i < start_index {
i + offset
} else {
i + offset - context
};
}
}
for seq_i in 0..seq_len {
let index = self.indices[batch_i];
all_pos[index] = seq_i + start_pos;
indices.push(index as u32);
self.indices[batch_i] += 1;
self.positions[batch_i] += 1;
if self.indices[batch_i] >= self.context {
self.indices[batch_i] = 0;
}
}

for seq_i in 0..seq_len {
let my_pos = seq_i + start_pos;
let mask = all_pos
.iter()
.map(|&pos| {
if pos <= my_pos {
0.0
} else {
f32::NEG_INFINITY
}
})
.collect::<Vec<f32>>();
masks.push(mask);
}

attention_masks.push(masks);
cache_indices.push(indices);
}
}
// Flattening the attention mask then using Tensor::from_vec rather than using Tensor::new
// ends up being almost 10x faster with candle 0.9.0. This has been fixed in candle 0.9.1.
let attention_masks = attention_masks
.into_iter()
.flat_map(|m| m.into_iter().flatten())
.collect::<Vec<f32>>();
let mask = Tensor::from_vec(attention_masks, ((), 1, seq_len, context), self.device())?
.to_dtype(self.dtype)?;
let indices = Tensor::new(cache_indices, self.device())?;
Ok(IndicesAndMask { indices, mask })
}

pub fn device(&self) -> &Device {
&self.device
}

#[allow(clippy::needless_range_loop)]
fn indices_and_mask_abs(
&mut self,
seq_len: usize,
batch_mask: &[bool],
) -> Result<IndicesAndMask> {
let mask = self.get_mask_abs(seq_len, seq_len)?;
let mut cache_indices = Vec::with_capacity(self.batch_size());
for (batch_i, &batch_mask) in batch_mask.iter().enumerate() {
if !batch_mask {
let indices = vec![self.indices[batch_i] as u32; seq_len];
cache_indices.push(indices);
} else {
let mut indices = Vec::with_capacity(seq_len);
for _ in 0..seq_len {
let index = self.indices[batch_i];
indices.push(index as u32);
self.indices[batch_i] += 1;
self.positions[batch_i] += 1;
if self.indices[batch_i] >= self.context {
self.indices[batch_i] = 0;
}
}
cache_indices.push(indices);
}
}
let indices = Tensor::new(cache_indices, self.device())?;
Ok(IndicesAndMask { indices, mask })
}

fn get_mask_abs(&self, size1: usize, size2: usize) -> Result<Tensor> {
let context = self.context;
let mask: Vec<_> = (0..size1)
.flat_map(|i| {
(0..size2).map(move |j| {
if size1 + j > size2 + i || size1 + j + context < size2 + i {
f32::NEG_INFINITY
} else {
0.0
}
})
})
.collect();
Tensor::from_slice(&mask, (size1, size2), self.device())
}
}

#[cfg(test)]
mod tests {
use super::*;
use candle::IndexOp;

#[test]
fn test_scattered_kv_cache() -> Result<()> {
let device = Device::Cpu;
let mut cache = ScatteredCacheBuilder::new(2, 5, DType::F32, &device)?;
let inf = f32::INFINITY;

let iam = cache.indices_and_mask(1, &[true, false])?;
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
assert_eq!(iam.indices.to_vec2::<u32>()?, [[0], [0]]);
assert_eq!(
mask,
[[[0.0, -inf, -inf, -inf, -inf]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]
);

let iam = cache.indices_and_mask(1, &[true, false])?;
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
assert_eq!(iam.indices.to_vec2::<u32>()?, [[1], [0]]);
assert_eq!(
mask,
[[[0.0, 0.0, -inf, -inf, -inf]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]
);

let iam = cache.indices_and_mask(3, &[false, true])?;
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
assert_eq!(iam.indices.to_vec2::<u32>()?, [[2, 2, 2], [0, 1, 2]]);
assert_eq!(
mask,
[
[
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]
],
[
[0.0, -inf, -inf, -inf, -inf],
[0.0, 0.0, -inf, -inf, -inf],
[0.0, 0.0, 0.0, -inf, -inf]
]
]
);

let iam = cache.indices_and_mask(3, &[true, true])?;
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
assert_eq!(iam.indices.to_vec2::<u32>()?, [[2, 3, 4], [3, 4, 0]]);
assert_eq!(
mask,
[
[
[0.0, 0.0, 0.0, -inf, -inf],
[0.0, 0.0, 0.0, 0.0, -inf],
[0.0, 0.0, 0.0, 0.0, 0.0]
],
[
[-inf, 0.0, 0.0, 0.0, -inf],
[-inf, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]
]
]
);

let iam = cache.indices_and_mask(1, &[true, false])?;
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
assert_eq!(iam.indices.to_vec2::<u32>()?, [[0], [1]]);
assert_eq!(
mask,
[[[0.0, 0.0, 0.0, 0.0, 0.0]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]
);

let iam = cache.indices_and_mask(2, &[true, false])?;
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
assert_eq!(iam.indices.to_vec2::<u32>()?, [[1, 2], [1, 1]]);
assert_eq!(
mask,
[
[[0.0, 0.0, -inf, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]],
[[0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]]
]
);

Ok(())
}
}

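The builder's bookkeeping is a per-batch ring buffer: positions counts tokens without bound while indices wraps modulo context. A minimal standalone sketch of that invariant (hypothetical names, not the builder's API):

struct Slot {
    position: usize, // total tokens seen, can exceed `context`
    index: usize,    // next write slot, always < `context`
    context: usize,
}

impl Slot {
    fn advance(&mut self) -> usize {
        let write_at = self.index;
        self.position += 1;
        self.index = (self.index + 1) % self.context;
        write_at
    }
}

fn main() {
    let mut s = Slot { position: 0, index: 0, context: 5 };
    let written: Vec<usize> = (0..7).map(|_| s.advance()).collect();
    assert_eq!(written, vec![0, 1, 2, 3, 4, 0, 1]); // wraps after 5 tokens
}
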
@ -46,15 +46,23 @@ impl candle::CustomOp3 for RotaryEmbI {
Some((o1, o2)) => &sin[o1..o2],
};
let (b, h, t, d) = l_src.shape().dims4()?;
let unbatched_rope = l_cos.dims().len() == 3 && l_sin.dims().len() == 3;
let el_count = b * h * t * d;
let mut dst = vec![T::zero(); el_count];
src.par_chunks(t * d)
.zip(dst.par_chunks_mut(t * d))
.for_each(|(src, dst)| {
.enumerate()
.for_each(|(bh_i, (src, dst))| {
for i_over_2 in 0..t * d / 2 {
let i = 2 * i_over_2;
dst[i] = src[i] * cos[i_over_2] - src[i + 1] * sin[i_over_2];
dst[i + 1] = src[i] * sin[i_over_2] + src[i + 1] * cos[i_over_2];
let rope_i = if unbatched_rope {
let b_i = bh_i / h;
i_over_2 + b_i * t * d / 2
} else {
i_over_2
};
dst[i] = src[i] * cos[rope_i] - src[i + 1] * sin[rope_i];
dst[i + 1] = src[i] * sin[rope_i] + src[i + 1] * cos[rope_i];
}
});
let storage = candle::WithDType::to_cpu_storage_owned(dst);
@ -115,6 +123,11 @@ impl candle::CustomOp3 for RotaryEmbI {
Some((o1, o2)) => sin.slice(o1..o2),
};
let (b, h, t, d) = l_src.shape().dims4()?;
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
(h * t * d) as u32
} else {
0u32
};
let el = b * h * t * d;
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), &kernels::REDUCE)?;
@ -125,7 +138,7 @@ impl candle::CustomOp3 for RotaryEmbI {
builder.arg(&cos);
builder.arg(&sin);
builder.arg(&dst);
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32);
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, stride_b);
// SAFETY: ffi.
unsafe { builder.launch(cfg) }.w()?;
Ok(dst)
@ -182,6 +195,11 @@ impl candle::CustomOp3 for RotaryEmbI {
dtype => candle::bail!("rope-i is not implemented for {dtype:?}"),
};
let (b, h, t, d) = l_src.shape().dims4()?;
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
h * t * d
} else {
0usize
};
let el = b * h * t * d;
let output = device.new_buffer(el, src.dtype(), "rope-i")?;
candle_metal_kernels::call_rope_i(
@ -191,6 +209,7 @@ impl candle::CustomOp3 for RotaryEmbI {
name,
b * h,
t * d,
stride_b,
src.buffer(),
l_src.start_offset() * src.dtype().size_in_bytes(),
cos.buffer(),
@ -205,10 +224,23 @@ impl candle::CustomOp3 for RotaryEmbI {
}
}

fn rope_check_cs(cs: &Tensor, b_sz: usize) -> Result<(usize, usize)> {
match *cs.dims() {
[t, d] => Ok((t, d)),
[b, t, d] => {
if b != b_sz {
candle::bail!("inconsistent batch size in rope {b_sz} {cs:?}",)
}
Ok((t, d))
}
_ => candle::bail!("cos/sin has to be 2D or 3D in rope {b_sz} {cs:?}"),
}
}

pub fn rope_i(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = cos.dims2()?;
let (sin_seq_len, sin_n_embd) = cos.dims2()?;
let (b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = rope_check_cs(cos, b_sz)?;
let (sin_seq_len, sin_n_embd) = rope_check_cs(sin, b_sz)?;
if cos_n_embd * 2 != n_embd
|| sin_n_embd * 2 != n_embd
|| seq_len > cos_seq_len
@ -292,16 +324,24 @@ impl candle::CustomOp3 for RotaryEmb {
Some((o1, o2)) => &sin[o1..o2],
};
let (b, h, t, d) = l_src.shape().dims4()?;
let unbatched_rope = l_cos.dims().len() == 3 && l_sin.dims().len() == 3;
let el_count = b * h * t * d;
let mut dst = vec![T::zero(); el_count];
src.par_chunks(t * d)
.zip(dst.par_chunks_mut(t * d))
.for_each(|(src, dst)| {
.enumerate()
.for_each(|(bh_i, (src, dst))| {
for i_t in 0..t {
for i_d in 0..d / 2 {
let i1 = i_t * d + i_d;
let i2 = i1 + d / 2;
let i_cs = i_t * (d / 2) + i_d;
let i_cs = if unbatched_rope {
let b_i = bh_i / h;
i_cs + b_i * t * d / 2
} else {
i_cs
};
dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs];
dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs];
}
@ -365,6 +405,11 @@ impl candle::CustomOp3 for RotaryEmb {
Some((o1, o2)) => sin.slice(o1..o2),
};
let (b, h, t, d) = l_src.shape().dims4()?;
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
(h * t * d) as u32
} else {
0u32
};
let el = b * h * t * d;
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("rope"), &kernels::REDUCE)?;
@ -375,7 +420,7 @@ impl candle::CustomOp3 for RotaryEmb {
builder.arg(&cos);
builder.arg(&sin);
builder.arg(&dst);
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, d as u32);
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, d as u32, stride_b);
// SAFETY: ffi.
unsafe { builder.launch(cfg) }.w()?;
Ok(dst)
@ -432,6 +477,11 @@ impl candle::CustomOp3 for RotaryEmb {
dtype => candle::bail!("rope is not implemented for {dtype:?}"),
};
let (b, h, t, d) = l_src.shape().dims4()?;
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
h * t * d
} else {
0usize
};
let el = b * h * t * d;
let output = device.new_buffer(el, src.dtype(), "rope-i")?;
candle_metal_kernels::call_rope(
@ -442,6 +492,7 @@ impl candle::CustomOp3 for RotaryEmb {
b * h,
t * d,
d,
stride_b,
src.buffer(),
l_src.start_offset() * src.dtype().size_in_bytes(),
cos.buffer(),
@ -457,9 +508,9 @@ impl candle::CustomOp3 for RotaryEmb {
}

pub fn rope(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = cos.dims2()?;
let (sin_seq_len, sin_n_embd) = sin.dims2()?;
let (b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = rope_check_cs(cos, b_sz)?;
let (sin_seq_len, sin_n_embd) = rope_check_cs(sin, b_sz)?;
if cos_n_embd * 2 != n_embd
|| sin_n_embd * 2 != n_embd
|| seq_len > cos_seq_len
@ -541,14 +592,21 @@ impl candle::CustomOp3 for RotaryEmbThd {
Some((o1, o2)) => &sin[o1..o2],
};
let (b, t, h, d) = l_src.shape().dims4()?;
let unbatched_rope = l_cos.dims().len() == 3 && l_sin.dims().len() == 3;
let el_count = b * h * t * d;
let mut dst = vec![T::zero(); el_count];
src.par_chunks(t * h * d)
.zip(dst.par_chunks_mut(t * h * d))
.for_each(|(src, dst)| {
.enumerate()
.for_each(|(b_i, (src, dst))| {
for i_t in 0..t {
for i_d in 0..d / 2 {
let i_cs = i_t * (d / 2) + i_d;
let i_cs = if unbatched_rope {
i_cs + b_i * t * d / 2
} else {
i_cs
};
for i_h in 0..h {
let i1 = i_t * h * d + i_h * d + i_d;
let i2 = i1 + d / 2;
@ -616,6 +674,11 @@ impl candle::CustomOp3 for RotaryEmbThd {
Some((o1, o2)) => sin.slice(o1..o2),
};
let (b, t, h, d) = l_src.shape().dims4()?;
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
(h * t * d) as u32
} else {
0u32
};
let el = b * h * t * d;
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), &kernels::REDUCE)?;
@ -626,7 +689,7 @@ impl candle::CustomOp3 for RotaryEmbThd {
builder.arg(&cos);
builder.arg(&sin);
builder.arg(&dst);
candle::builder_arg!(builder, b as u32, t as u32, h as u32, d as u32);
candle::builder_arg!(builder, b as u32, t as u32, h as u32, d as u32, stride_b);
// SAFETY: ffi.
unsafe { builder.launch(cfg) }.w()?;
Ok(dst)
@ -683,6 +746,11 @@ impl candle::CustomOp3 for RotaryEmbThd {
dtype => candle::bail!("rope_thd is not implemented for {dtype:?}"),
};
let (b, t, h, d) = l_src.shape().dims4()?;
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
h * t * d
} else {
0usize
};
let el = b * h * t * d;
let output = device.new_buffer(el, src.dtype(), "rope-thd")?;
candle_metal_kernels::call_rope_thd(
@ -694,6 +762,7 @@ impl candle::CustomOp3 for RotaryEmbThd {
t,
h,
d,
stride_b,
src.buffer(),
l_src.start_offset() * src.dtype().size_in_bytes(),
cos.buffer(),
@ -709,9 +778,9 @@ impl candle::CustomOp3 for RotaryEmbThd {
}

pub fn rope_thd(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
let (_b_sz, seq_len, _n_head, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = cos.dims2()?;
let (sin_seq_len, sin_n_embd) = sin.dims2()?;
let (b_sz, seq_len, _n_head, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = rope_check_cs(cos, b_sz)?;
let (sin_seq_len, sin_n_embd) = rope_check_cs(sin, b_sz)?;
if cos_n_embd * 2 != n_embd
|| sin_n_embd * 2 != n_embd
|| seq_len > cos_seq_len

@ -8,13 +8,16 @@ pub fn gumbel_softmax<D: candle::shape::Dim>(
) -> Result<Tensor> {
if temperature <= 0.0 {
logits.argmax(dim)
} else if temperature == 1.0 {
let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
let sampled = (logits - minus_g)?.argmax(dim)?;
Ok(sampled)
} else {
// Cast to f32 as doing the Gumbel softmax in bf16 is a bit unstable.
let logits = logits.to_dtype(candle::DType::F32)?;
let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
let sampled = (logits + minus_g * (-temperature))?.argmax(dim)?;
Ok(sampled)
if temperature == 1.0 {
let sampled = (logits - minus_g)?.argmax(dim)?;
Ok(sampled)
} else {
let sampled = (logits + minus_g * (-temperature))?.argmax(dim)?;
Ok(sampled)
}
}
}

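The sampling above is the Gumbel-max trick: with u ~ U(0, 1), g = -ln(-ln u) is Gumbel noise and argmax(logits + g) samples from softmax(logits); scaling g by the temperature, as the minus_g * (-temperature) line does, is equivalent to dividing the logits by the temperature before adding g. A sketch on plain slices:

// Gumbel-max sketch: minus_g = ln(-ln u), so l - t * minus_g = l + t * g.
// `uniforms` is assumed to hold samples strictly inside (0, 1).
fn gumbel_argmax(logits: &[f32], uniforms: &[f32], temperature: f32) -> usize {
    logits
        .iter()
        .zip(uniforms)
        .map(|(&l, &u)| l - temperature * (-u.ln()).ln())
        .enumerate()
        .max_by(|(_, a), (_, b)| a.total_cmp(b))
        .map(|(i, _)| i)
        .unwrap()
}
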
@ -4,7 +4,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle::{test_device, test_utils::to_vec3_round, Device, Result, Tensor};
use candle::{test_device, test_utils::to_vec3_round, Device, IndexOp, Result, Tensor};

fn softmax(device: &Device) -> Result<()> {
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
@ -179,6 +179,28 @@ fn ropei(device: &Device) -> Result<()> {
} else {
assert!(sum_diff < 1e-4);
}

// Test with a 3d cos/sin
let cos2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let sin2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let cos2 = Tensor::from_vec(cos2, (seq_len, head_dim / 2), device)?;
let sin2 = Tensor::from_vec(sin2, (seq_len, head_dim / 2), device)?;
let rope1 = candle_nn::rotary_emb::rope_i(&src.i(0..1)?, &cos, &sin)?;
let rope2 = candle_nn::rotary_emb::rope_i(&src.i(1..2)?, &cos2, &sin2)?;

let both_cos = Tensor::stack(&[cos, cos2], 0)?;
let both_sin = Tensor::stack(&[sin, sin2], 0)?;
let both_rope = candle_nn::rotary_emb::rope_i(&src, &both_cos, &both_sin)?;
let both_rope2 = Tensor::cat(&[rope1, rope2], 0)?;
let sum_diff = (both_rope - both_rope2)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(sum_diff, 0.);
Ok(())
}

@ -206,6 +228,28 @@ fn rope(device: &Device) -> Result<()> {
} else {
assert!(sum_diff < 1e-4);
}

// Test with a 3d cos/sin
let cos2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let sin2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let cos2 = Tensor::from_vec(cos2, (seq_len, head_dim / 2), device)?;
let sin2 = Tensor::from_vec(sin2, (seq_len, head_dim / 2), device)?;
let rope1 = candle_nn::rotary_emb::rope(&src.i(0..1)?, &cos, &sin)?;
let rope2 = candle_nn::rotary_emb::rope(&src.i(1..2)?, &cos2, &sin2)?;

let both_cos = Tensor::stack(&[cos, cos2], 0)?;
let both_sin = Tensor::stack(&[sin, sin2], 0)?;
let both_rope = candle_nn::rotary_emb::rope(&src, &both_cos, &both_sin)?;
let both_rope2 = Tensor::cat(&[rope1, rope2], 0)?;
let sum_diff = (both_rope - both_rope2)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(sum_diff, 0.);
Ok(())
}

@ -236,6 +280,37 @@ fn rope_thd(device: &Device) -> Result<()> {
} else {
assert!(sum_diff < 1e-4);
}

// Test with a 3d cos/sin
let cos2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let sin2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let cos2 = Tensor::from_vec(cos2, (seq_len, head_dim / 2), device)?;
let sin2 = Tensor::from_vec(sin2, (seq_len, head_dim / 2), device)?;
let rope1 = {
let src = src.transpose(1, 2)?.contiguous()?;
candle_nn::rotary_emb::rope_thd(&src.i(0..1)?, &cos, &sin)?
};
let rope2 = {
let src = src.transpose(1, 2)?.contiguous()?;
candle_nn::rotary_emb::rope_thd(&src.i(1..2)?, &cos2, &sin2)?
};

let both_cos = Tensor::stack(&[cos, cos2], 0)?;
let both_sin = Tensor::stack(&[sin, sin2], 0)?;
let both_rope = {
let src = src.transpose(1, 2)?.contiguous()?;
candle_nn::rotary_emb::rope_thd(&src, &both_cos, &both_sin)?
};
let both_rope2 = Tensor::cat(&[rope1, rope2], 0)?;
let sum_diff = (both_rope - both_rope2)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(sum_diff, 0.);
Ok(())
}

@ -1,6 +1,6 @@
[package]
name = "candle-onnx"
version = "0.9.0-alpha.4"
version = "0.9.1"
edition = "2021"

description = "ONNX support for Candle"
@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0"

[dependencies]
candle = { path = "../candle-core", package = "candle-core", version = "0.9.0-alpha.4" }
candle-nn = { path = "../candle-nn", version = "0.9.0-alpha.4" }
candle = { path = "../candle-core", package = "candle-core", version = "0.9.1" }
candle-nn = { path = "../candle-nn", version = "0.9.1" }
prost = "0.12.1"

[build-dependencies]

@ -1,7 +1,9 @@
use crate::onnx::attribute_proto::AttributeType;
use crate::onnx::tensor_proto::DataType;
use crate::onnx::{self, GraphProto};
use candle::Module;
use candle::{bail, DType, Device, Result, Tensor};
use candle_nn::activation::PReLU;
use std::collections::{HashMap, HashSet};

pub type Value = Tensor;
@ -581,7 +583,13 @@ fn simple_eval_(
&Device::Cpu,
)?);

let xs = Tensor::ones(input.shape(), value.dtype(), input.device())?
let shape_vec: Vec<usize> = input
.to_vec1::<i64>()?
.iter()
.map(|&x| x as usize)
.collect();

let xs = Tensor::ones(shape_vec, value.dtype(), input.device())?
.broadcast_mul(&value)?;
values.insert(node.output[0].clone(), xs);
}
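The fix matters because ConstantOfShape's 1d input contains the target shape; it is not itself shaped like the output. Reading the values instead of the input's own shape, as a sketch:

fn main() {
    // Input tensor holds [4, 3, 2]: the old code built a tensor shaped like
    // the input itself, i.e. (3,); the new code builds one of shape (4, 3, 2).
    let input_values: Vec<i64> = vec![4, 3, 2];
    let shape_vec: Vec<usize> = input_values.iter().map(|&x| x as usize).collect();
    assert_eq!(shape_vec, vec![4, 3, 2]);
}
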
@ -991,6 +999,14 @@ fn simple_eval_(
let output = input.relu()?;
values.insert(node.output[0].clone(), output);
}
"PRelu" => {
// https://onnx.ai/onnx/operators/onnx__PRelu.html
let input = get(&node.input[0])?;
let slope = get(&node.input[1])?;

let output = PReLU::new(slope.clone(), false).forward(input)?;
values.insert(node.output[0].clone(), output);
}
"Ceil" => {
let input = get(&node.input[0])?;
let output = input.ceil()?;
@ -1228,7 +1244,7 @@ fn simple_eval_(
}

let indexes = Tensor::arange_step(s, e, p, data.device())?;
out = out.index_select(&indexes, axis)?
out = out.contiguous()?.index_select(&indexes, axis)?
}
values.insert(node.output[0].clone(), out);
}
@ -1950,6 +1966,273 @@ fn simple_eval_(
let output = input.sign()?;
values.insert(node.output[0].clone(), output);
}
"Resize" => {
let input = get(&node.input[0])?;

if input.rank() != 4 {
bail!("Unsupported rank for nearest resize: {}", input.rank());
}

let scales = if node.input.len() > 2 && !node.input[2].is_empty() {
Some(get(&node.input[2])?)
} else {
None
};

let sizes = if node.input.len() > 3 && !node.input[3].is_empty() {
Some(get(&node.input[3])?)
} else {
None
};

let output_dims = match (scales, sizes) {
(Some(_), Some(_)) => {
bail!("Scales and sizes cannot both be set for Resize operation")
}
(Some(scales_tensor), None) => {
let scale_values = scales_tensor.to_vec1::<f32>()?;
input
.dims()
.iter()
.enumerate()
.map(|(i, &d)| (d as f32 * scale_values[i]) as usize)
.collect::<Vec<_>>()
}
(None, Some(sizes_tensor)) => sizes_tensor
.to_vec1::<i64>()?
.iter()
.map(|&d| d as usize)
.collect::<Vec<_>>(),
(None, None) => bail!("Either scales or sizes should be present"),
};

let coordinate_transformation_mode =
get_attr_opt::<str>(node, "coordinate_transformation_mode")?
.unwrap_or("half_pixel");
// Interpolation mode: nearest, linear, or cubic.
let mode = get_attr_opt::<str>(node, "mode")?.unwrap_or("nearest");
// How to determine the "nearest" pixel in nearest interpolation mode.
let nearest_mode =
get_attr_opt::<str>(node, "nearest_mode")?.unwrap_or("round_prefer_floor");

if mode != "nearest" {
bail!("Unsupported resize mode: {}", mode);
}

if nearest_mode != "floor" {
bail!("Unsupported nearest_mode for resize: {}", nearest_mode);
}

if coordinate_transformation_mode != "asymmetric" {
bail!(
"Unsupported coordinate_transformation_mode for resize: {}",
coordinate_transformation_mode
);
}

let h = output_dims[2];
let w = output_dims[3];
let output = input.upsample_nearest2d(h, w)?;

values.insert(node.output[0].clone(), output);
}
"Trilu" => {
let input = get(&node.input[0])?;

// Get the diagonal offset 'k' from the second input if provided
let k = if node.input.len() > 1 && !node.input[1].is_empty() {
get(&node.input[1])?.to_vec0::<i64>()?
} else {
0
};

// Get the 'upper' attribute
let upper = get_attr_opt::<i64>(node, "upper")?.copied().unwrap_or(1);

// For batched inputs, we need to handle each matrix separately
let dims = input.dims();
if dims.len() < 2 {
bail!("Trilu expects input with at least 2 dimensions: {:?}", dims);
}

// Get the last two dimensions which represent the matrix
let n = dims[dims.len() - 2];
let m = dims[dims.len() - 1];
let max_dim = std::cmp::max(n, m);

// Handle the diagonal offset k
let mask = if k != 0 {
let mut data = vec![0u32; n * m];
for i in 0..n {
for j in 0..m {
if (upper != 0 && (j as i64) >= (i as i64) + k)
|| (upper == 0 && (j as i64) <= (i as i64) + k)
{
data[i * m + j] = 1u32;
}
}
}
Tensor::from_vec(data, (n, m), input.device())?.to_dtype(input.dtype())?
} else if upper == 0 {
Tensor::tril2(max_dim, input.dtype(), input.device())?
} else {
Tensor::triu2(max_dim, input.dtype(), input.device())?
};

let final_mask = if n != m {
mask.narrow(0, 0, n)?.narrow(1, 0, m)?
} else {
mask
};

let output = (input * &final_mask)?;

values.insert(node.output[0].clone(), output);
}
"ScatterND" => {
let data = get(&node.input[0])?;

let indices = get(&node.input[1])?;
let indices = indices.to_dtype(DType::I64)?;

let updates = get(&node.input[2])?;

let reduction = get_attr_opt::<str>(node, "reduction")?.unwrap_or("none");

let indices_shape = indices.dims();
let data_shape = data.dims();
let updates_shape = updates.dims();

// Last dimension of indices represents the depth of indexing
let k = indices_shape.last().unwrap().clone();

if k > data.rank() {
bail!("ScatterND expects k (indices.shape[-1]) to be at most the rank of data");
}

let num_updates = indices_shape[..indices_shape.len() - 1]
.iter()
.product::<usize>();

let flat_indices = if indices.rank() == 1 && k == 1 {
indices.unsqueeze(0)?
} else {
indices.reshape((num_updates, k))?
};

// Calculate the shape of each update element
let update_element_shape = if k < data_shape.len() {
data_shape[k..].to_vec()
} else {
vec![]
};

// Expected shape for updates based on indices and target tensor
let expected_updates_shape = {
let mut shape = indices_shape[..indices_shape.len() - 1].to_vec();
shape.extend(&update_element_shape);
shape
};

// Validate or reshape updates to expected shape
let updates = if updates.dims() != expected_updates_shape {
if updates.rank() == 0 {
// Handle scalar updates
let mut target_shape = vec![num_updates];
target_shape.extend(&update_element_shape);
updates.broadcast_as(target_shape)?
} else {
// Try to broadcast or reshape updates to expected shape
let flat_shape =
vec![num_updates, update_element_shape.iter().product::<usize>()];
let flattened = updates.reshape(flat_shape)?;
flattened.reshape(expected_updates_shape)?
}
} else {
updates.clone()
};

let mut output = data.clone();

// convert indices to flat indices
let mut flat_output = output.flatten_all()?;
let flat_updates = if update_element_shape.is_empty() {
updates.reshape(num_updates)?
} else {
let product = update_element_shape.iter().product::<usize>();
updates.reshape((num_updates, product))?
};

// Calculate strides for the output tensor
let mut strides: Vec<usize> = vec![1];
for i in (0..data_shape.len() - 1).rev() {
strides.push(strides.last().unwrap() * data_shape[i + 1]);
}
strides.reverse();

// Process each update
for i in 0..num_updates {
let index_slice = flat_indices.narrow(0, i, 1)?;
let indices_vec = index_slice.squeeze(0)?.to_vec1::<i64>()?;

// Convert multi-dimensional indices to flat index
let mut flat_idx: usize = 0;
for (dim, &idx) in indices_vec.iter().enumerate() {
let dim_size = data_shape[dim] as i64;
let norm_idx = if idx < 0 { dim_size + idx } else { idx };

if norm_idx < 0 || norm_idx >= dim_size {
bail!(
"Index {} out of bounds for dimension {} with size {}",
idx,
dim,
dim_size
);
}

flat_idx += (norm_idx as usize) * strides[dim];
}

// Extract current update
let update_slice = if update_element_shape.is_empty() {
flat_updates.narrow(0, i, 1)?.squeeze(0)?
} else {
flat_updates.narrow(0, i, 1)?
};

match reduction {
"add" => {
if update_element_shape.is_empty() {
let existing = flat_output.narrow(0, flat_idx, 1)?;
let new_value = existing.add(&update_slice.unsqueeze(0)?)?;
flat_output = flat_output.slice_scatter(&new_value, 0, flat_idx)?;
} else {
let slice_size = update_element_shape.iter().product::<usize>();
let existing = flat_output.narrow(0, flat_idx, slice_size)?;
let new_value = existing.add(&update_slice)?;
flat_output = flat_output.slice_scatter(&new_value, 0, flat_idx)?;
}
}
"none" | _ => {
if update_element_shape.is_empty() {
flat_output = flat_output.slice_scatter(
&update_slice.unsqueeze(0)?,
0,
flat_idx,
)?;
} else {
flat_output =
flat_output.slice_scatter(&update_slice, 0, flat_idx)?;
}
}
}
}

// Reshape flat output back to original shape
output = flat_output.reshape(data_shape.to_vec())?;

values.insert(node.output[0].clone(), output);
}
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
}
}

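The flat-index computation in the ScatterND branch above uses row-major strides; a self-contained sketch of the same arithmetic, leaving out the negative-index normalization the real code also performs:

// flat_idx = sum(idx[d] * stride[d]) with row-major strides over data_shape.
fn flat_index(indices: &[i64], data_shape: &[usize]) -> usize {
    let mut strides = vec![1usize; data_shape.len()];
    for d in (0..data_shape.len() - 1).rev() {
        strides[d] = strides[d + 1] * data_shape[d + 1];
    }
    indices.iter().zip(&strides).map(|(&i, &s)| i as usize * s).sum()
}

fn main() {
    // Index [1, 0] into a (3, 2) tensor lands at flat position 2.
    assert_eq!(flat_index(&[1, 0], &[3, 2]), 2);
}
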
@ -842,13 +842,22 @@ fn test_flatten_operation() -> Result<()> {
|
||||
#[test]
|
||||
fn test_constant_of_shape() -> Result<()> {
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-31
|
||||
test(&[4i64, 3, 2], Some(1.), &[1., 1., 1.])?;
|
||||
test(
|
||||
&[4i64, 3, 2],
|
||||
Some(1.),
|
||||
&[
|
||||
[[1., 1.], [1., 1.], [1., 1.]],
|
||||
[[1., 1.], [1., 1.], [1., 1.]],
|
||||
[[1., 1.], [1., 1.], [1., 1.]],
|
||||
[[1., 1.], [1., 1.], [1., 1.]],
|
||||
],
|
||||
)?;
|
||||
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-31
|
||||
test(&[0.], Some(0i64), &[0i64])?;
|
||||
test(&[1i64], Some(0i64), &[0i64])?;
|
||||
|
||||
// "value" defaults to 0 f32
|
||||
test(&[1i64, 2, 3, 4], None as Option<i64>, &[0., 0., 0., 0.])?;
|
||||
test(&[4i64], None as Option<i64>, &[0., 0., 0., 0.])?;
|
||||
|
||||
fn test(
|
||||
input: impl NdArray,
|
||||
@ -1846,6 +1855,64 @@ fn test_relu_operation() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// "PRelu"
|
||||
#[test]
|
||||
fn test_prelu_operation() -> Result<()> {
|
||||
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||
node: vec![NodeProto {
|
||||
op_type: "PRelu".to_string(),
|
||||
domain: "".to_string(),
|
||||
attribute: vec![],
|
||||
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
|
||||
output: vec![OUTPUT_Z.to_string()],
|
||||
name: "".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
}],
|
||||
name: "".to_string(),
|
||||
initializer: vec![],
|
||||
input: vec![
|
||||
ValueInfoProto {
|
||||
name: INPUT_X.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
},
|
||||
ValueInfoProto {
|
||||
name: INPUT_Y.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
},
|
||||
],
|
||||
output: vec![ValueInfoProto {
|
||||
name: OUTPUT_Z.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
}],
|
||||
value_info: vec![],
|
||||
doc_string: "".to_string(),
|
||||
sparse_initializer: vec![],
|
||||
quantization_annotation: vec![],
|
||||
}));
|
||||
let x: Tensor = Tensor::from_vec(
|
||||
vec![-1.0f32, 1.0f32, -2.0f32, 3.0f32],
|
||||
&[2, 2],
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
|
||||
let y: Tensor = Tensor::from_vec(vec![1.0f32, 1.1f32, 1.2f32, 1.3f32], &[2, 2], &Device::Cpu)?;
|
||||
|
||||
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||
inputs.insert(INPUT_X.to_string(), x);
|
||||
inputs.insert(INPUT_Y.to_string(), y);
|
||||
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
let results = z.to_vec2::<f32>()?;
|
||||
assert_eq!(results, vec![vec![-1.0, 1.0], vec![-2.4, 3.0]]);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
// "Constant"
|
||||
// #[test]
|
||||
|
||||
@ -5910,3 +5977,512 @@ fn test_sign_operation() -> Result<()> {
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scatternd_operation() -> Result<()> {
|
||||
// Example 1 based on ONNX documentation
|
||||
test(
|
||||
&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
|
||||
&[[4i64], [3], [1], [7]],
|
||||
&[9.0f32, 10.0, 11.0, 12.0],
|
||||
&[1.0f32, 11.0, 3.0, 10.0, 9.0, 6.0, 7.0, 12.0],
|
||||
)?;
|
||||
|
||||
// A more complex example with 2D data
|
||||
test(
|
||||
&[[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]],
|
||||
&[[0i64, 1], [1, 0]],
|
||||
&[10.0f32, 20.0],
|
||||
&[[1.0f32, 10.0], [20.0, 4.0], [5.0, 6.0]],
|
||||
)?;
|
||||
|
||||
// 3D example with indices pointing to specific locations
|
||||
test(
|
||||
&[
|
||||
[[1.0f32, 2.0], [3.0, 4.0]],
|
||||
[[5.0, 6.0], [7.0, 8.0]],
|
||||
[[9.0, 10.0], [11.0, 12.0]],
|
||||
],
|
||||
&[[0i64, 0, 1], [1, 1, 0]],
|
||||
&[100.0f32, 200.0],
|
||||
&[
|
||||
[[1.0f32, 100.0], [3.0, 4.0]],
|
||||
[[5.0, 6.0], [200.0, 8.0]],
|
||||
[[9.0, 10.0], [11.0, 12.0]],
|
||||
],
|
||||
)?;
|
||||
|
||||
fn test(
|
||||
data: impl NdArray,
|
||||
indices: impl NdArray,
|
||||
updates: impl NdArray,
|
||||
expected: impl NdArray,
|
||||
) -> Result<()> {
|
||||
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||
node: vec![NodeProto {
|
||||
op_type: "ScatterND".to_string(),
|
||||
domain: "".to_string(),
|
||||
attribute: vec![],
|
||||
input: vec![
|
||||
INPUT_X.to_string(),
|
||||
INPUT_Y.to_string(),
|
||||
INPUT_A.to_string(),
|
||||
],
|
||||
output: vec![OUTPUT_Z.to_string()],
|
||||
name: "".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
}],
|
||||
name: "".to_string(),
|
||||
initializer: vec![],
|
||||
input: vec![],
|
||||
output: vec![ValueInfoProto {
|
||||
name: OUTPUT_Z.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
}],
|
||||
value_info: vec![],
|
||||
doc_string: "".to_string(),
|
||||
sparse_initializer: vec![],
|
||||
quantization_annotation: vec![],
|
||||
}));
|
||||
|
||||
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||
inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
|
||||
inputs.insert(INPUT_Y.to_string(), Tensor::new(indices, &Device::Cpu)?);
|
||||
inputs.insert(INPUT_A.to_string(), Tensor::new(updates, &Device::Cpu)?);
|
||||
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
let expected = Tensor::new(expected, &Device::Cpu)?;
|
||||
|
||||
match expected.dims().len() {
|
||||
1 => assert_eq!(z.to_vec1::<f32>()?, expected.to_vec1::<f32>()?),
|
||||
2 => assert_eq!(z.to_vec2::<f32>()?, expected.to_vec2::<f32>()?),
|
||||
3 => assert_eq!(z.to_vec3::<f32>()?, expected.to_vec3::<f32>()?),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}

#[test]
fn test_trilu_operation() -> Result<()> {
    // Test 1: Upper triangular matrix (default behavior with upper=true)
    {
        let manual_graph = create_model_proto_with_graph(Some(GraphProto {
            node: vec![NodeProto {
                op_type: "Trilu".to_string(),
                domain: "".to_string(),
                attribute: vec![], // empty attribute means default upper=true
                input: vec![INPUT_X.to_string()],
                output: vec![OUTPUT_Z.to_string()],
                name: "".to_string(),
                doc_string: "".to_string(),
            }],
            name: "".to_string(),
            initializer: vec![],
            input: vec![ValueInfoProto {
                name: INPUT_X.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            }],
            output: vec![ValueInfoProto {
                name: OUTPUT_Z.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            }],
            value_info: vec![],
            doc_string: "".to_string(),
            sparse_initializer: vec![],
            quantization_annotation: vec![],
        }));

        let x = Tensor::from_vec(
            vec![
                4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 0, 8, 7, 4, 3, 4, 2, 4,
            ],
            &[4, 5],
            &Device::Cpu,
        )?;

        let mut inputs: HashMap<String, Tensor> = HashMap::new();
        inputs.insert(INPUT_X.to_string(), x);

        let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
        assert_eq!(eval.len(), 1);

        let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
        let results = z.to_vec2::<i64>()?;

        assert_eq!(
            results,
            vec![
                vec![4, 7, 3, 7, 9],
                vec![0, 2, 8, 6, 9],
                vec![0, 0, 0, 8, 7],
                vec![0, 0, 0, 2, 4]
            ]
        );
    }

    // Test 2: Upper triangular with positive k=1 (diagonal above main)
    {
        let manual_graph = create_model_proto_with_graph(Some(GraphProto {
            node: vec![NodeProto {
                op_type: "Trilu".to_string(),
                domain: "".to_string(),
                attribute: vec![],
                input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
                output: vec![OUTPUT_Z.to_string()],
                name: "".to_string(),
                doc_string: "".to_string(),
            }],
            name: "".to_string(),
            initializer: vec![],
            input: vec![
                ValueInfoProto {
                    name: INPUT_X.to_string(),
                    doc_string: "".to_string(),
                    r#type: None,
                },
                ValueInfoProto {
                    name: INPUT_Y.to_string(),
                    doc_string: "".to_string(),
                    r#type: None,
                },
            ],
            output: vec![ValueInfoProto {
                name: OUTPUT_Z.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            }],
            value_info: vec![],
            doc_string: "".to_string(),
            sparse_initializer: vec![],
            quantization_annotation: vec![],
        }));

        let x = Tensor::from_vec(
            vec![1i64, 4, 9, 7, 1, 9, 2, 8, 8, 4, 3, 9, 7, 4, 2],
            &[3, 5],
            &Device::Cpu,
        )?;

        let k = Tensor::from_vec(vec![1i64], (), &Device::Cpu)?;

        let mut inputs: HashMap<String, Tensor> = HashMap::new();
        inputs.insert(INPUT_X.to_string(), x);
        inputs.insert(INPUT_Y.to_string(), k);

        let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
        assert_eq!(eval.len(), 1);

        let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
        let results = z.to_vec2::<i64>()?;

        assert_eq!(
            results,
            vec![
                vec![0, 4, 9, 7, 1],
                vec![0, 0, 8, 8, 4],
                vec![0, 0, 0, 4, 2]
            ]
        );
    }

    // Test 3: Upper triangular with negative k=-1 (one diagonal below main)
    {
        let manual_graph = create_model_proto_with_graph(Some(GraphProto {
            node: vec![NodeProto {
                op_type: "Trilu".to_string(),
                domain: "".to_string(),
                attribute: vec![],
                input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
                output: vec![OUTPUT_Z.to_string()],
                name: "".to_string(),
                doc_string: "".to_string(),
            }],
            name: "".to_string(),
            initializer: vec![],
            input: vec![],
            output: vec![ValueInfoProto {
                name: OUTPUT_Z.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            }],
            value_info: vec![],
            doc_string: "".to_string(),
            sparse_initializer: vec![],
            quantization_annotation: vec![],
        }));

        let x = Tensor::from_vec(
            vec![
                4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 0, 8, 7, 4, 3, 4, 2, 4,
            ],
            &[4, 5],
            &Device::Cpu,
        )?;

        let k = Tensor::from_vec(vec![-1i64], (), &Device::Cpu)?;

        let mut inputs: HashMap<String, Tensor> = HashMap::new();
        inputs.insert(INPUT_X.to_string(), x);
        inputs.insert(INPUT_Y.to_string(), k);

        let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
        assert_eq!(eval.len(), 1);

        let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
        let results = z.to_vec2::<i64>()?;

        assert_eq!(
            results,
            vec![
                vec![4, 7, 3, 7, 9],
                vec![1, 2, 8, 6, 9],
                vec![0, 4, 0, 8, 7],
                vec![0, 0, 4, 2, 4]
            ]
        );
    }

    // Test 4: Lower triangular matrix (upper=0)
    {
        let att_upper = AttributeProto {
            name: "upper".to_string(),
            ref_attr_name: "upper".to_string(),
            i: 0, // 0 means false, use lower triangular
            doc_string: "upper".to_string(),
            r#type: 2,
            f: 0.0,
            s: vec![],
            t: None,
            g: None,
            sparse_tensor: None,
            tp: None,
            floats: vec![],
            ints: vec![],
            strings: vec![],
            tensors: vec![],
            graphs: vec![],
            sparse_tensors: vec![],
            type_protos: vec![],
        };

        let manual_graph = create_model_proto_with_graph(Some(GraphProto {
            node: vec![NodeProto {
                op_type: "Trilu".to_string(),
                domain: "".to_string(),
                attribute: vec![att_upper],
                input: vec![INPUT_X.to_string()],
                output: vec![OUTPUT_Z.to_string()],
                name: "".to_string(),
                doc_string: "".to_string(),
            }],
            name: "".to_string(),
            initializer: vec![],
            input: vec![],
            output: vec![ValueInfoProto {
                name: OUTPUT_Z.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            }],
            value_info: vec![],
            doc_string: "".to_string(),
            sparse_initializer: vec![],
            quantization_annotation: vec![],
        }));

        let x = Tensor::from_vec(
            vec![
                4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 1, 8, 7, 4, 3, 4, 2, 4,
            ],
            &[4, 5],
            &Device::Cpu,
        )?;

        let mut inputs: HashMap<String, Tensor> = HashMap::new();
        inputs.insert(INPUT_X.to_string(), x);

        let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
        assert_eq!(eval.len(), 1);

        let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
        let results = z.to_vec2::<i64>()?;

        // Lower triangular matrix (default k=0)
        assert_eq!(
            results,
            vec![
                vec![4, 0, 0, 0, 0],
                vec![1, 2, 0, 0, 0],
                vec![9, 4, 1, 0, 0],
                vec![4, 3, 4, 2, 0]
            ]
        );
    }

    // Test 5: Lower triangular with negative k=-1
    {
        let att_upper = AttributeProto {
            name: "upper".to_string(),
            ref_attr_name: "upper".to_string(),
            i: 0,
            doc_string: "upper".to_string(),
            r#type: 2,
            f: 0.0,
            s: vec![],
            t: None,
            g: None,
            sparse_tensor: None,
            tp: None,
            floats: vec![],
            ints: vec![],
            strings: vec![],
            tensors: vec![],
            graphs: vec![],
            sparse_tensors: vec![],
            type_protos: vec![],
        };

        let manual_graph = create_model_proto_with_graph(Some(GraphProto {
            node: vec![NodeProto {
                op_type: "Trilu".to_string(),
                domain: "".to_string(),
                attribute: vec![att_upper],
                input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
                output: vec![OUTPUT_Z.to_string()],
                name: "".to_string(),
                doc_string: "".to_string(),
            }],
            name: "".to_string(),
            initializer: vec![],
            input: vec![],
            output: vec![ValueInfoProto {
                name: OUTPUT_Z.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            }],
            value_info: vec![],
            doc_string: "".to_string(),
            sparse_initializer: vec![],
            quantization_annotation: vec![],
        }));

        let x = Tensor::from_vec(
            vec![
                4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 1, 8, 7, 4, 3, 4, 2, 4,
            ],
            &[4, 5],
            &Device::Cpu,
        )?;

        let k = Tensor::from_vec(vec![-1i64], (), &Device::Cpu)?;

        let mut inputs: HashMap<String, Tensor> = HashMap::new();
        inputs.insert(INPUT_X.to_string(), x);
        inputs.insert(INPUT_Y.to_string(), k);

        let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
        assert_eq!(eval.len(), 1);

        let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
        let results = z.to_vec2::<i64>()?;

        assert_eq!(
            results,
            vec![
                vec![0, 0, 0, 0, 0],
                vec![1, 0, 0, 0, 0],
                vec![9, 4, 0, 0, 0],
                vec![4, 3, 4, 0, 0]
            ]
        );
    }

    // Test 6: Lower triangular with positive k=2
    {
        let att_upper = AttributeProto {
            name: "upper".to_string(),
            ref_attr_name: "upper".to_string(),
            i: 0,
            doc_string: "upper".to_string(),
            r#type: 2,
            f: 0.0,
            s: vec![],
            t: None,
            g: None,
            sparse_tensor: None,
            tp: None,
            floats: vec![],
            ints: vec![],
            strings: vec![],
            tensors: vec![],
            graphs: vec![],
            sparse_tensors: vec![],
            type_protos: vec![],
        };

        let manual_graph = create_model_proto_with_graph(Some(GraphProto {
            node: vec![NodeProto {
                op_type: "Trilu".to_string(),
                domain: "".to_string(),
                attribute: vec![att_upper],
                input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
                output: vec![OUTPUT_Z.to_string()],
                name: "".to_string(),
                doc_string: "".to_string(),
            }],
            name: "".to_string(),
            initializer: vec![],
            input: vec![],
            output: vec![ValueInfoProto {
                name: OUTPUT_Z.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            }],
            value_info: vec![],
            doc_string: "".to_string(),
            sparse_initializer: vec![],
            quantization_annotation: vec![],
        }));

        let x = Tensor::from_vec(
            vec![
                4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 1, 8, 7, 4, 3, 4, 2, 4,
            ],
            &[4, 5],
            &Device::Cpu,
        )?;

        let k = Tensor::from_vec(vec![2i64], (), &Device::Cpu)?;

        let mut inputs: HashMap<String, Tensor> = HashMap::new();
        inputs.insert(INPUT_X.to_string(), x);
        inputs.insert(INPUT_Y.to_string(), k);

        let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
        assert_eq!(eval.len(), 1);

        let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
        let results = z.to_vec2::<i64>()?;

        assert_eq!(
            results,
            vec![
                vec![4, 7, 3, 0, 0],
                vec![1, 2, 8, 6, 0],
                vec![9, 4, 1, 8, 7],
                vec![4, 3, 4, 2, 4]
            ]
        );
    }

    Ok(())
}
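A note on the operator exercised above: ONNX `Trilu` keeps the upper (or lower) triangle of a matrix, shifted by the diagonal offset `k`, and zeroes everything else. The sketch below is a minimal pure-Rust reference of that rule, matching the expected values in the tests; `trilu_reference` is an illustrative helper written for this note, not part of candle-onnx.

use candle::{Result, Tensor};

// Illustrative reference for the Trilu rule checked by the tests above:
// upper keeps elements with j >= i + k, lower keeps elements with j <= i + k.
fn trilu_reference(x: &Tensor, k: i64, upper: bool) -> Result<Tensor> {
    let (rows, cols) = x.dims2()?;
    let vals = x.to_vec2::<i64>()?;
    let out: Vec<i64> = (0..rows)
        .flat_map(|i| {
            let row = vals[i].clone();
            (0..cols).map(move |j| {
                let keep = if upper {
                    j as i64 >= i as i64 + k
                } else {
                    j as i64 <= i as i64 + k
                };
                if keep {
                    row[j]
                } else {
                    0
                }
            })
        })
        .collect();
    Tensor::from_vec(out, (rows, cols), x.device())
}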
@@ -869,8 +869,8 @@ impl Moe {
 }
 
 enum MoeOrMlp {
-    Moe(Moe),
-    Mlp(Mlp),
+    Moe(Box<Moe>),
+    Mlp(Box<Mlp>),
 }
 
 impl MoeOrMlp {
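A note on the change above: a Rust enum is as large as its largest variant, so wrapping the sizeable `Moe` struct in a `Box` keeps every `MoeOrMlp` value small. A minimal illustration (the types here are hypothetical, not from candle):

// Hypothetical types for illustration only.
struct Big([u8; 1024]);

enum Unboxed {
    Big(Big), // the enum must reserve ~1 KiB for this variant
    Small(u8),
}

enum Boxed {
    Big(Box<Big>), // the enum only stores a pointer
    Small(u8),
}

fn main() {
    assert!(std::mem::size_of::<Boxed>() < std::mem::size_of::<Unboxed>());
}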
@@ -908,14 +908,17 @@ impl DecoderLayer {
             && layer_idx >= cfg.first_k_dense_replace
             && layer_idx % cfg.moe_layer_freq == 0
         {
-            MoeOrMlp::Moe(Moe::new(
-                cfg,
-                vb.pp("mlp"),
-                cfg.n_shared_experts,
-                cfg.n_routed_experts.unwrap(),
-            )?)
+            MoeOrMlp::Moe(
+                Moe::new(
+                    cfg,
+                    vb.pp("mlp"),
+                    cfg.n_shared_experts,
+                    cfg.n_routed_experts.unwrap(),
+                )?
+                .into(),
+            )
         } else {
-            MoeOrMlp::Mlp(Mlp::new(cfg, vb.pp("mlp"), None, None)?)
+            MoeOrMlp::Mlp(Mlp::new(cfg, vb.pp("mlp"), None, None)?.into())
         };
 
         Ok(Self {
@@ -21,6 +21,7 @@ pub struct Config {
     pub num_key_value_heads: usize,
     pub rms_norm_eps: f64,
     pub rope_theta: f64,
+    pub rope_local_base_freq: f64,
     pub vocab_size: usize,
     pub final_logit_softcapping: Option<f64>,
     pub attn_logit_softcapping: Option<f64>,
@@ -67,12 +68,22 @@ struct RotaryEmbedding {
 }
 
 impl RotaryEmbedding {
-    fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
+    fn new(
+        dtype: DType,
+        cfg: &Config,
+        dev: &Device,
+        sliding_window: Option<usize>,
+    ) -> Result<Self> {
         let dim = cfg.head_dim;
         let max_seq_len = cfg.max_position_embeddings;
+        let rope_freq = if sliding_window.is_some() {
+            cfg.rope_local_base_freq
+        } else {
+            cfg.rope_theta
+        };
         let inv_freq: Vec<_> = (0..dim)
             .step_by(2)
-            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
+            .map(|i| 1f32 / rope_freq.powf(i as f64 / dim as f64) as f32)
             .collect();
         let inv_freq_len = inv_freq.len();
         let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
@@ -162,8 +173,8 @@ impl Attention {
     fn new(
         rotary_emb: Arc<RotaryEmbedding>,
         use_flash_attn: bool,
-        is_sliding: bool,
         cfg: &Config,
+        sliding_window: Option<usize>,
         vb: VarBuilder,
     ) -> Result<Self> {
         let hidden_sz = cfg.hidden_size;
@@ -178,13 +189,13 @@ impl Attention {
         let o_proj = linear(num_heads * head_dim, hidden_sz, bias, vb.pp("o_proj"))?;
         let q_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("q_norm"))?;
         let k_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?;
-        let kv_cache = if is_sliding {
-            KvCache::Rotating(candle_nn::kv_cache::RotatingKvCache::new(
-                2,
-                cfg.sliding_window,
-            ))
+        let kv_cache = if let Some(sliding_window) = sliding_window {
+            KvCache::Rotating(candle_nn::kv_cache::RotatingKvCache::new(2, sliding_window))
         } else {
-            KvCache::Normal(candle_nn::kv_cache::KvCache::new(2, cfg.sliding_window))
+            KvCache::Normal(candle_nn::kv_cache::KvCache::new(
+                2,
+                cfg.max_position_embeddings,
+            ))
         };
         Ok(Self {
             q_proj,
@@ -302,21 +313,27 @@ struct DecoderLayer {
     pre_feedforward_layernorm: RmsNorm,
     post_feedforward_layernorm: RmsNorm,
     post_attention_layernorm: RmsNorm,
+    sliding_window: Option<usize>,
 }
 
 impl DecoderLayer {
     fn new(
-        rotary_emb: Arc<RotaryEmbedding>,
        use_flash_attn: bool,
-        is_sliding: bool,
        cfg: &Config,
        vb: VarBuilder,
+        sliding_window: Option<usize>,
     ) -> Result<Self> {
+        let rotary_emb = Arc::new(RotaryEmbedding::new(
+            vb.dtype(),
+            cfg,
+            vb.device(),
+            sliding_window,
+        )?);
         let self_attn = Attention::new(
             rotary_emb,
             use_flash_attn,
-            is_sliding,
             cfg,
+            sliding_window,
             vb.pp("self_attn"),
         )?;
         let mlp = MLP::new(cfg, vb.pp("mlp"))?;
@@ -344,6 +361,7 @@ impl DecoderLayer {
             pre_feedforward_layernorm,
             post_feedforward_layernorm,
             post_attention_layernorm,
+            sliding_window,
         })
     }
 
@@ -370,6 +388,42 @@ impl DecoderLayer {
     }
 }
 
+fn prepare_decoder_attention_mask(
+    b_size: usize,
+    tgt_len: usize,
+    seqlen_offset: usize,
+    sliding_window: Option<usize>,
+    dtype: DType,
+    device: &Device,
+) -> Result<Tensor> {
+    let mask: Vec<_> = if let Some(sliding_window) = sliding_window {
+        (0..tgt_len)
+            .flat_map(|i| {
+                (0..tgt_len).map(move |j| {
+                    if i < j || j + sliding_window < i {
+                        f32::NEG_INFINITY
+                    } else {
+                        0.
+                    }
+                })
+            })
+            .collect()
+    } else {
+        (0..tgt_len)
+            .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0f32 }))
+            .collect()
+    };
+    let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), device)?;
+    let mask = if seqlen_offset > 0 {
+        let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, device)?;
+        Tensor::cat(&[&mask0, &mask], D::Minus1)?
+    } else {
+        mask
+    };
+    mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
+        .to_dtype(dtype)
+}
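The helper added above builds both the plain causal mask and the sliding-window variant from one rule: position `i` may attend to `j` only if `j <= i`, and under a sliding window additionally only if `j + sliding_window >= i`. A standalone sketch of the resulting pattern (the values `tgt_len = 4`, `sliding_window = 2` are illustrative):

// Prints the sliding-window causal pattern ('.' = attend, 'x' = -inf).
fn main() {
    let (tgt_len, sliding_window) = (4usize, 2usize);
    for i in 0..tgt_len {
        let row: String = (0..tgt_len)
            .map(|j| if i < j || j + sliding_window < i { 'x' } else { '.' })
            .collect();
        // Each query sees at most sliding_window + 1 of the most recent positions.
        println!("{row}");
    }
}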
 
 #[derive(Debug, Clone)]
 pub struct Model {
     embed_tokens: candle_nn::Embedding,
@@ -388,17 +442,15 @@ impl Model {
         let vb_m = vb.pp("model");
         let embed_tokens =
             candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
-        let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
         let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
         let vb_l = vb_m.pp("layers");
         for layer_idx in 0..cfg.num_hidden_layers {
-            let is_sliding = (layer_idx + 1) % cfg.sliding_window_pattern > 0;
+            let sliding_window = (layer_idx + 1) % cfg.sliding_window_pattern > 0;
             let layer = DecoderLayer::new(
-                rotary_emb.clone(),
                 use_flash_attn,
-                is_sliding,
                 cfg,
                 vb_l.pp(layer_idx),
+                sliding_window.then_some(cfg.sliding_window),
             )?;
             layers.push(layer)
         }
@@ -417,51 +469,52 @@ impl Model {
         })
     }
 
-    fn prepare_decoder_attention_mask(
+    fn create_attention_masks(
         &self,
-        b_size: usize,
-        tgt_len: usize,
+        batch_size: usize,
+        seq_len: usize,
         seqlen_offset: usize,
-    ) -> Result<Tensor> {
-        let mask: Vec<_> = match Some(self.sliding_window) {
-            None => (0..tgt_len)
-                .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
-                .collect(),
-            Some(sliding_window) => (0..tgt_len)
-                .flat_map(|i| {
-                    (0..tgt_len).map(move |j| {
-                        if i < j || j + sliding_window < i {
-                            f32::NEG_INFINITY
-                        } else {
-                            0.
-                        }
-                    })
-                })
-                .collect(),
-        };
-        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
-        let mask = if seqlen_offset > 0 {
-            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
-            Tensor::cat(&[&mask0, &mask], D::Minus1)?
-        } else {
-            mask
-        };
-        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
-            .to_dtype(self.dtype)
+    ) -> Result<(Option<Tensor>, Option<Tensor>)> {
+        if seq_len <= 1 {
+            return Ok((None, None));
+        }
+
+        let mask = prepare_decoder_attention_mask(
+            batch_size,
+            seq_len,
+            seqlen_offset,
+            None,
+            self.dtype,
+            &self.device,
+        )?;
+
+        let sliding_mask = prepare_decoder_attention_mask(
+            batch_size,
+            seq_len,
+            seqlen_offset,
+            Some(self.sliding_window),
+            self.dtype,
+            &self.device,
+        )?;
+
+        Ok((Some(mask), Some(sliding_mask)))
     }
 
     pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
         let (b_size, seq_len) = input_ids.dims2()?;
-        let attention_mask = if seq_len <= 1 {
-            None
-        } else {
-            let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
-            Some(mask)
-        };
         let xs = self.embed_tokens.forward(input_ids)?;
         let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
+
+        let (attention_mask, sliding_attention_mask) =
+            self.create_attention_masks(b_size, seq_len, seqlen_offset)?;
+
         for layer in self.layers.iter_mut() {
-            xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
+            let mask = if layer.sliding_window.is_some() {
+                &sliding_attention_mask
+            } else {
+                &attention_mask
+            };
+            xs = layer.forward(&xs, mask.as_ref(), seqlen_offset)?
         }
         let logits = xs
             .narrow(1, seq_len - 1, 1)?
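The sliding-window selection above follows Gemma 3's alternating layout: with `sliding_window_pattern = 6`, five local (sliding-window) layers are followed by one global layer, and each kind picks the matching mask in `forward`. A small sketch of the schedule, reusing the same expression as the model code (the layer count of 12 is illustrative):

// Illustrative: which layers are local (sliding) vs global for pattern = 6.
fn main() {
    let sliding_window_pattern = 6;
    for layer_idx in 0..12 {
        let is_sliding = (layer_idx + 1) % sliding_window_pattern > 0;
        println!(
            "layer {layer_idx}: {}",
            if is_sliding { "sliding" } else { "global" }
        );
    }
}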
@@ -70,6 +70,7 @@ pub mod moondream;
 pub mod mpt;
 pub mod nvembed_v2;
 pub mod olmo;
+pub mod olmo2;
 pub mod openclip;
 pub mod paligemma;
 pub mod parler_tts;
@@ -79,6 +80,7 @@ pub mod phi3;
 pub mod pixtral;
 pub mod quantized_blip;
 pub mod quantized_blip_text;
+pub mod quantized_gemma3;
 pub mod quantized_llama;
 pub mod quantized_llama2_c;
 pub mod quantized_metavoice;
@@ -89,6 +91,7 @@ pub mod quantized_mpt;
 pub mod quantized_phi;
 pub mod quantized_phi3;
 pub mod quantized_qwen2;
+pub mod quantized_qwen3;
 pub mod quantized_recurrent_gemma;
 pub mod quantized_rwkv_v5;
 pub mod quantized_rwkv_v6;
@@ -96,6 +99,8 @@ pub mod quantized_stable_lm;
 pub mod quantized_t5;
 pub mod qwen2;
 pub mod qwen2_moe;
+pub mod qwen3;
+pub mod qwen3_moe;
 pub mod recurrent_gemma;
 pub mod repvgg;
 pub mod resnet;
candle-transformers/src/models/olmo2.rs (new file)
@@ -0,0 +1,348 @@
//! OLMo 2 (Open Language Model) implementation
//!
//! See OLMo 2 model details at:
//! - [Hugging Face Collection](https://huggingface.co/collections/allenai/olmo-2-674117b93ab84e98afc72edc)
//! - [OLMo 2 Paper](https://arxiv.org/abs/2501.00656)
//!
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{linear_b, linear_no_bias, rms_norm, Activation, Linear, RmsNorm, VarBuilder};
use std::sync::Arc;

#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
    pub vocab_size: usize,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub attention_bias: bool,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub num_key_value_heads: usize,
    pub rms_norm_eps: f64,
    pub hidden_act: candle_nn::Activation,
    pub max_position_embeddings: usize,
    pub rope_theta: f64,
    pub tie_word_embeddings: bool,
    pub clip_qkv: Option<f64>,
}

#[derive(Debug, Clone)]
struct RotaryEmbedding {
    sin: Tensor,
    cos: Tensor,
}

impl RotaryEmbedding {
    fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
        let dim = cfg.hidden_size / cfg.num_attention_heads;
        let max_seq_len = cfg.max_position_embeddings;
        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(dtype)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        Ok(Self {
            sin: freqs.sin()?,
            cos: freqs.cos()?,
        })
    }

    fn apply_rotary_emb_qkv(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offset: usize,
    ) -> Result<(Tensor, Tensor)> {
        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
        let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
        let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
        let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
        let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
        Ok((q_embed, k_embed))
    }
}

#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    gate_proj: Linear,
    up_proj: Linear,
    down_proj: Linear,
    act_fn: Activation,
}

impl MLP {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let intermediate_sz = cfg.intermediate_size;
        let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
        let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
        let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
            act_fn: cfg.hidden_act,
        })
    }
}

impl Module for MLP {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
        let rhs = xs.apply(&self.up_proj)?;
        (lhs * rhs)?.apply(&self.down_proj)
    }
}

#[derive(Debug, Clone)]
struct Attention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    q_norm: RmsNorm,
    k_norm: RmsNorm,
    num_heads: usize,
    num_kv_heads: usize,
    num_kv_groups: usize,
    head_dim: usize,
    hidden_size: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    kv_cache: Option<(Tensor, Tensor)>,
}

impl Attention {
    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let num_kv_groups = num_heads / num_kv_heads;
        let head_dim = hidden_sz / num_heads;
        let b = cfg.attention_bias;
        let q_proj = linear_b(hidden_sz, num_heads * head_dim, b, vb.pp("q_proj"))?;
        let k_proj = linear_b(hidden_sz, num_kv_heads * head_dim, b, vb.pp("k_proj"))?;
        let v_proj = linear_b(hidden_sz, num_kv_heads * head_dim, b, vb.pp("v_proj"))?;
        let o_proj = linear_b(num_heads * head_dim, hidden_sz, b, vb.pp("o_proj"))?;
        let q_norm = rms_norm(hidden_sz, cfg.rms_norm_eps, vb.pp("q_norm"))?;
        let k_norm = rms_norm(num_kv_heads * head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            q_norm,
            k_norm,
            num_heads,
            num_kv_heads,
            num_kv_groups,
            head_dim,
            hidden_size: hidden_sz,
            rotary_emb,
            kv_cache: None,
        })
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let query_states = self.q_proj.forward(xs)?;
        let key_states = self.k_proj.forward(xs)?;
        let value_states = self.v_proj.forward(xs)?;

        let query_states = self.q_norm.forward(&query_states)?;
        let key_states = self.k_norm.forward(&key_states)?;

        let query_states = query_states
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        let key_states = key_states
            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;
        let value_states = value_states
            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;

        let (query_states, key_states) =
            self.rotary_emb
                .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;

        let (key_states, value_states) = match &self.kv_cache {
            None => (key_states, value_states),
            Some((prev_k, prev_v)) => {
                let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
                let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
                (key_states, value_states)
            }
        };
        self.kv_cache = Some((key_states.clone(), value_states.clone()));

        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
        let value_states =
            crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;

        let attn_output = {
            let scale = 1f64 / f64::sqrt(self.head_dim as f64);
            let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;

            let attn_weights = match attention_mask {
                None => attn_weights,
                Some(mask) => attn_weights.broadcast_add(mask)?,
            };
            let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
            attn_weights.matmul(&value_states)?
        };
        attn_output
            .transpose(1, 2)?
            .reshape((b_sz, q_len, self.hidden_size))?
            .apply(&self.o_proj)
    }

    fn clear_kv_cache(&mut self) {
        self.kv_cache = None
    }
}

#[derive(Debug, Clone)]
struct DecoderLayer {
    self_attn: Attention,
    mlp: MLP,
    post_attention_layernorm: RmsNorm,
    post_feedforward_layernorm: RmsNorm,
}

impl DecoderLayer {
    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
        let post_feedforward_layernorm = rms_norm(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            vb.pp("post_feedforward_layernorm"),
        )?;
        let post_attention_layernorm = rms_norm(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            vb.pp("post_attention_layernorm"),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            post_attention_layernorm,
            post_feedforward_layernorm,
        })
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.self_attn.forward(xs, attention_mask, seqlen_offset)?;
        let xs = self.post_attention_layernorm.forward(&xs)?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self.mlp.forward(&xs)?;
        let xs = self.post_feedforward_layernorm.forward(&xs)?;
        residual + xs
    }

    fn clear_kv_cache(&mut self) {
        self.self_attn.clear_kv_cache()
    }
}

#[derive(Debug, Clone)]
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Linear,
    device: Device,
    dtype: DType,
}

impl Model {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb_m = vb.pp("model");
        let embed_tokens =
            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
        let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        for layer_idx in 0..cfg.num_hidden_layers {
            let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
            layers.push(layer)
        }
        let norm = rms_norm(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
        let lm_head = if cfg.tie_word_embeddings {
            Linear::new(embed_tokens.embeddings().clone(), None)
        } else {
            linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
        };
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            device: vb.device().clone(),
            dtype: vb.dtype(),
        })
    }

    fn prepare_decoder_attention_mask(
        &self,
        b_size: usize,
        tgt_len: usize,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        // Sliding window mask?
        let mask: Vec<_> = (0..tgt_len)
            .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
            .collect();
        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
        let mask = if seqlen_offset > 0 {
            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), self.dtype, &self.device)?;
            Tensor::cat(&[&mask0, &mask], D::Minus1)?
        } else {
            mask
        };
        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
            .to_dtype(self.dtype)
    }

    pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
        let (b_size, seq_len) = input_ids.dims2()?;
        let attention_mask = if seq_len <= 1 {
            None
        } else {
            let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
            Some(mask)
        };
        let mut xs = self.embed_tokens.forward(input_ids)?;
        for layer in self.layers.iter_mut() {
            xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
        }
        xs.narrow(1, seq_len - 1, 1)?
            .apply(&self.norm)?
            .apply(&self.lm_head)
    }

    pub fn clear_kv_cache(&mut self) {
        for layer in self.layers.iter_mut() {
            layer.clear_kv_cache()
        }
    }
}
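A minimal usage sketch for the new `olmo2` module, assuming a local `config.json` and `model.safetensors` that match the `Config` above (the file names, the `anyhow`/`serde_json` dependencies, and the token ids are illustrative):

use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::olmo2::{Config, Model};

fn run() -> anyhow::Result<()> {
    let device = Device::Cpu;
    let cfg: Config = serde_json::from_str(&std::fs::read_to_string("config.json")?)?;
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
    };
    let mut model = Model::new(&cfg, vb)?;
    // Prompt pass: token ids shaped (batch, seq_len), offset 0.
    let input = Tensor::new(&[[1u32, 2, 3]], &device)?;
    let logits = model.forward(&input, 0)?; // (batch, 1, vocab_size)
    println!("{:?}", logits.dims());
    Ok(())
}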
@@ -20,10 +20,24 @@
 // This implementation is based on:
 // https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py
 use crate::models::with_tracing::{linear_no_bias as linear, Linear, RmsNorm};
-use candle::{DType, Device, Module, Result, Tensor, D};
+use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
 use candle_nn::VarBuilder;
 use std::sync::Arc;
 
+#[derive(Debug, Clone, serde::Deserialize)]
+pub enum RopeScalingType {
+    #[serde(rename = "longrope")]
+    LongRope,
+}
+
+#[derive(Debug, Clone, serde::Deserialize)]
+pub struct RopeScaling {
+    pub short_factor: Vec<f32>,
+    pub long_factor: Vec<f32>,
+    #[serde(rename = "type")]
+    pub type_: RopeScalingType,
+}
+
 // https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json
 #[derive(Debug, Clone, serde::Deserialize)]
 pub struct Config {
@@ -38,8 +52,12 @@ pub struct Config {
     pub rope_theta: f64,
     pub bos_token_id: Option<u32>,
     pub eos_token_id: Option<u32>,
-    pub rope_scaling: Option<String>,
+    pub rope_scaling: Option<RopeScaling>,
     pub max_position_embeddings: usize,
+    pub original_max_position_embeddings: Option<usize>,
+    pub partial_rotary_factor: Option<f64>,
+    #[serde(default)]
+    pub tie_word_embeddings: bool,
 }
 
 impl Config {
@@ -50,30 +68,88 @@ impl Config {
 
 #[derive(Debug, Clone)]
 pub struct RotaryEmbedding {
+    partial_dim: Option<usize>,
     sin: Tensor,
     cos: Tensor,
 }
 
 impl RotaryEmbedding {
     pub fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
-        let dim = cfg.head_dim();
-        let max_seq_len = cfg.max_position_embeddings;
-        let inv_freq: Vec<_> = (0..dim)
-            .step_by(2)
-            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
-            .collect();
-        let inv_freq_len = inv_freq.len();
-        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
-        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
-            .to_dtype(dtype)?
-            .reshape((max_seq_len, 1))?;
-        let freqs = t.matmul(&inv_freq)?;
+        let partial_dim = cfg
+            .partial_rotary_factor
+            .as_ref()
+            .map(|v| (v * cfg.head_dim() as f64) as usize);
+        let dim = partial_dim.unwrap_or(cfg.head_dim());
+        let freqs = match cfg.rope_scaling.as_ref() {
+            None => {
+                let max_seq_len = cfg.max_position_embeddings;
+                let inv_freq: Vec<_> = (0..dim)
+                    .step_by(2)
+                    .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
+                    .collect();
+                let inv_freq = Tensor::from_vec(inv_freq, (1, ()), dev)?.to_dtype(dtype)?;
+                let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
+                    .to_dtype(dtype)?
+                    .reshape((max_seq_len, 1))?;
+                t.matmul(&inv_freq)?
+            }
+            Some(rope_scaling) => {
+                let inv_freq_s: Vec<_> = (0..dim)
+                    .step_by(2)
+                    .zip(rope_scaling.short_factor.iter())
+                    .map(|(i, &f)| f / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
+                    .collect();
+                let inv_freq_s = Tensor::from_vec(inv_freq_s, (1, ()), dev)?.to_dtype(dtype)?;
+                let max_seq_len = cfg.max_position_embeddings;
+                match cfg.original_max_position_embeddings {
+                    None => {
+                        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
+                            .to_dtype(dtype)?
+                            .reshape((max_seq_len, 1))?;
+                        t.matmul(&inv_freq_s)?
+                    }
+                    Some(original_max_seq_len) => {
+                        let t_s = Tensor::arange(0u32, original_max_seq_len as u32, dev)?
+                            .to_dtype(dtype)?
+                            .reshape((original_max_seq_len, 1))?;
+                        let freq_s = t_s.matmul(&inv_freq_s)?;
+                        let inv_freq_l: Vec<_> = (0..dim)
+                            .step_by(2)
+                            .zip(rope_scaling.long_factor.iter())
+                            .map(|(i, &f)| f / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
+                            .collect();
+                        let inv_freq_l =
+                            Tensor::from_vec(inv_freq_l, (1, ()), dev)?.to_dtype(dtype)?;
+                        let t_l =
+                            Tensor::arange(original_max_seq_len as u32, max_seq_len as u32, dev)?
+                                .to_dtype(dtype)?
+                                .reshape(((), 1))?;
+                        let freq_l = t_l.matmul(&inv_freq_l)?;
+                        Tensor::cat(&[&freq_s, &freq_l], 0)?
+                    }
+                }
+            }
+        };
         Ok(Self {
+            partial_dim,
             sin: freqs.sin()?,
             cos: freqs.cos()?,
        })
    }
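The "longrope" construction above keeps two frequency tables: base inverse frequencies scaled per dimension by `short_factor` for positions up to `original_max_position_embeddings`, and by `long_factor` beyond that, concatenated along the position axis. A small sketch of the per-dimension scaling, mirroring the expression in the diff (the theta, dim, and factor values are illustrative):

// Illustrative: scale each base inverse frequency by its per-dimension factor,
// as in `f / rope_theta.powf(i as f64 / dim as f64)` above.
fn scaled_inv_freq(rope_theta: f64, dim: usize, factors: &[f32]) -> Vec<f32> {
    (0..dim)
        .step_by(2)
        .zip(factors.iter())
        .map(|(i, &f)| f / rope_theta.powf(i as f64 / dim as f64) as f32)
        .collect()
}

fn main() {
    // dim = 4 gives two rotary frequencies; one factor per frequency.
    let short = scaled_inv_freq(10_000.0, 4, &[1.0, 1.0]);
    let long = scaled_inv_freq(10_000.0, 4, &[4.0, 4.0]);
    println!("short: {short:?}, long: {long:?}");
}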
+
+    fn rope(&self, xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
+        let x = match self.partial_dim {
+            None => candle_nn::rotary_emb::rope(&xs.contiguous()?, cos, sin)?,
+            Some(dim) => {
+                let xs_rot = xs.i((.., .., .., ..dim))?.contiguous()?;
+                let xs_pass = xs.i((.., .., .., dim..))?;
+                let xs_rot = candle_nn::rotary_emb::rope(&xs_rot, cos, sin)?;
+                Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)?.contiguous()?
+            }
+        };
+        Ok(x)
+    }
 
     pub fn apply_rotary_emb_qkv(
         &self,
         q: &Tensor,
@@ -83,8 +159,8 @@ impl RotaryEmbedding {
         let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
         let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
         let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
-        let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
-        let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
+        let q_embed = self.rope(&q.contiguous()?, &cos, &sin)?;
+        let k_embed = self.rope(&k.contiguous()?, &cos, &sin)?;
         Ok((q_embed, k_embed))
     }
 }
@@ -292,7 +368,11 @@ impl Model {
             layers.push(layer)
         }
         let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
-        let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
+        let lm_head = if cfg.tie_word_embeddings {
+            Linear::from_weights(embed_tokens.embeddings().clone(), None)
+        } else {
+            linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
+        };
         Ok(Self {
             embed_tokens,
             layers,
candle-transformers/src/models/quantized_gemma3.rs (new file)
@@ -0,0 +1,466 @@
//! Gemma 3 model implementation with quantization support.
//!
//! Gemma 3 is a family of multimodal language models developed by Google.
//! This implementation provides quantization for reduced memory usage and faster inference.
//!
//! Key characteristics:
//! - Group-Query Attention (GQA) with specialized key-value heads
//! - RMSNorm for layer normalization
//! - Specialized attention patterns with separate normalization for Q/K/V
//! - Feed-forward network with SwiGLU activation
//! - Support for 2/3/4/8-bit quantization
//!
//! References:
//! - [Gemma 3 Models](https://blog.google/technology/developers/gemma-3/)
//!

use crate::quantized_nn::RmsNorm;
use candle::quantized::gguf_file;
use candle::quantized::QTensor;
use candle::D;
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::{Embedding, Module};

pub const MAX_SEQ_LEN: usize = 131072; // Gemma 3 supports 128K context window
pub const DEFAULT_SLIDING_WINDOW_TYPE: usize = 6;
pub const DEFAULT_ROPE_FREQUENCY: f32 = 1_000_000.;
pub const DEFAULT_ROPE_FREQUENCY_SLIDING: f32 = 10_000.;
pub const DEFAULT_ROPE_FREQUENCY_SCALE_FACTOR: f32 = 1.;

#[derive(Debug, Clone)]
struct QMatMul {
    inner: candle::quantized::QMatMul,
    span: tracing::Span,
}

impl QMatMul {
    fn from_qtensor(qtensor: QTensor) -> Result<Self> {
        let inner = candle::quantized::QMatMul::from_qtensor(qtensor)?;
        let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
        Ok(Self { inner, span })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        self.inner.forward(xs)
    }
}

#[derive(Debug, Clone)]
struct Mlp {
    feed_forward_gate: QMatMul, // ffn_gate in GGUF
    feed_forward_up: QMatMul,   // ffn_up in GGUF
    feed_forward_down: QMatMul, // ffn_down in GGUF
}

impl Module for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let gate = self.feed_forward_gate.forward(xs)?;
        let up = self.feed_forward_up.forward(xs)?;
        let silu = candle_nn::ops::silu(&gate)?;
        let gated = (silu * up)?;
        self.feed_forward_down.forward(&gated)
    }
}

#[derive(Debug, Clone)]
struct RotaryEmbedding {
    sin: Tensor,
    cos: Tensor,
}

impl RotaryEmbedding {
    fn new(head_dim: usize, rope_frequency: f32, device: &Device) -> Result<Self> {
        let theta: Vec<_> = (0..head_dim)
            .step_by(2)
            .map(|i| 1f32 / rope_frequency.powf(i as f32 / head_dim as f32))
            .collect();
        let theta = Tensor::new(theta.as_slice(), device)?;
        let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
            .to_dtype(DType::F32)?
            .reshape((MAX_SEQ_LEN, 1))?
            .matmul(&theta.reshape((1, theta.elem_count()))?)?;
        let cos = idx_theta.cos()?;
        let sin = idx_theta.sin()?;
        Ok(Self { sin, cos })
    }

    fn apply_rotary_emb_qkv(
        &self,
        q: &Tensor,
        k: &Tensor,
        index_pos: usize,
    ) -> Result<(Tensor, Tensor)> {
        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
        let cos = self.cos.narrow(0, index_pos, seq_len)?;
        let sin = self.sin.narrow(0, index_pos, seq_len)?;
        let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
        let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
        Ok((q_embed, k_embed))
    }
}

#[derive(Debug, Clone)]
struct LayerWeights {
    // Attention components
    attention_wq: QMatMul,
    attention_wk: QMatMul,
    attention_wv: QMatMul,
    attention_wo: QMatMul,

    // Specialized normalization for Q and K
    attention_q_norm: RmsNorm,
    attention_k_norm: RmsNorm,

    // Layer normalization
    attention_norm: RmsNorm,      // Applied before attention
    post_attention_norm: RmsNorm, // Applied after attention
    ffn_norm: RmsNorm,            // Applied before feedforward
    post_ffn_norm: RmsNorm,       // Applied after feedforward

    // Feed-forward network
    mlp: Mlp,

    // Attention parameters
    n_head: usize,    // Number of query heads
    n_kv_head: usize, // Number of key-value heads
    head_dim: usize,  // Dimension of each head
    q_dim: usize,     // Total dimension for queries

    sliding_window_size: Option<usize>,

    rotary_embedding: RotaryEmbedding,
    neg_inf: Tensor,

    // Cache
    kv_cache: Option<(Tensor, Tensor)>,

    // Tracing
    span_attn: tracing::Span,
    span_mlp: tracing::Span,
}

impl LayerWeights {
    fn mask(
        &self,
        b_sz: usize,
        seq_len: usize,
        index_pos: usize,
        dtype: DType,
        device: &Device,
    ) -> Result<Tensor> {
        let mask: Vec<_> = if let Some(sliding_window_size) = self.sliding_window_size {
            (0..seq_len)
                .flat_map(|i| {
                    (0..seq_len).map(move |j| {
                        if i < j || j + sliding_window_size < i {
                            0u32
                        } else {
                            1u32
                        }
                    })
                })
                .collect()
        } else {
            (0..seq_len)
                .flat_map(|i| (0..seq_len).map(move |j| if i < j { 0u32 } else { 1u32 }))
                .collect()
        };
        let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
        let mask = if index_pos > 0 {
            let mask0 = Tensor::zeros((seq_len, index_pos), DType::F32, device)?;
            Tensor::cat(&[&mask0, &mask], D::Minus1)?
        } else {
            mask
        };
        mask.expand((b_sz, 1, seq_len, seq_len + index_pos))?
            .to_dtype(dtype)
    }

    fn forward_attn(
        &mut self,
        x: &Tensor,
        mask: Option<&Tensor>,
        index_pos: usize,
    ) -> Result<Tensor> {
        let _enter = self.span_attn.enter();
        let (b_sz, seq_len, _) = x.dims3()?;

        let q = self.attention_wq.forward(x)?;
        let k = self.attention_wk.forward(x)?;
        let v = self.attention_wv.forward(x)?;

        let q = q
            .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
            .transpose(1, 2)?;
        let k = k
            .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
            .transpose(1, 2)?;
        let v = v
            .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
            .transpose(1, 2)?;

        let q = self.attention_q_norm.forward(&q.contiguous()?)?;
        let k = self.attention_k_norm.forward(&k.contiguous()?)?;

        let (q, k) = self
            .rotary_embedding
            .apply_rotary_emb_qkv(&q, &k, index_pos)?;

        let (k, v) = match &self.kv_cache {
            None => (k, v),
            Some((k_cache, v_cache)) => {
                if index_pos == 0 {
                    (k, v)
                } else {
                    let k = Tensor::cat(&[k_cache, &k], 2)?; // concat on seq dim
                    let v = Tensor::cat(&[v_cache, &v], 2)?;
                    (k, v)
                }
            }
        };
        self.kv_cache = Some((k.clone(), v.clone())); // update cache

        // Repeat KV for GQA
        let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
        let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;

        // Scaled Dot-Product Attention
        let scale = 1.0 / (self.head_dim as f64).sqrt();
        let mut attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;

        if let Some(mask) = mask {
            let mask = mask.broadcast_as(attn_weights.shape())?;
            let neg_inf = self.neg_inf.broadcast_as(attn_weights.dims())?;
            attn_weights = mask.eq(0u32)?.where_cond(&neg_inf, &attn_weights)?;
        }

        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
        let attn_output = attn_weights.matmul(&v)?;

        let attn_output = attn_output
            .transpose(1, 2)?
            .reshape((b_sz, seq_len, self.q_dim))?;

        self.attention_wo.forward(&attn_output)
    }
}

#[derive(Debug, Clone)]
pub struct ModelWeights {
    tok_embeddings: Embedding,
    embedding_length: usize,
    layers: Vec<LayerWeights>,
    norm: RmsNorm,
    output: QMatMul,
    span: tracing::Span,
    span_output: tracing::Span,
}

impl ModelWeights {
    pub fn from_gguf<R: std::io::Seek + std::io::Read>(
        ct: gguf_file::Content,
        reader: &mut R,
        device: &Device,
    ) -> Result<Self> {
        let md_get = |s: &str| match ct.metadata.get(s) {
            None => candle::bail!("cannot find {s} in metadata"),
            Some(v) => Ok(v),
        };

        let head_count = md_get("gemma3.attention.head_count")?.to_u32()? as usize;
        let head_count_kv = md_get("gemma3.attention.head_count_kv")?.to_u32()? as usize;
        let block_count = md_get("gemma3.block_count")?.to_u32()? as usize;
        let embedding_length = md_get("gemma3.embedding_length")?.to_u32()? as usize;
        let key_length = md_get("gemma3.attention.key_length")?.to_u32()? as usize;
        let _value_length = md_get("gemma3.attention.value_length")?.to_u32()? as usize;
        let rms_norm_eps = md_get("gemma3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
        let sliding_window_size = md_get("gemma3.attention.sliding_window")?.to_u32()? as usize;

        let sliding_window_type = md_get("gemma3.attention.sliding_window_type")
            .and_then(|m| Ok(m.to_u32()? as usize))
            .unwrap_or(DEFAULT_SLIDING_WINDOW_TYPE);

        let rope_freq_base = md_get("gemma3.rope.freq_base")
            .and_then(|m| m.to_f32())
            .unwrap_or(DEFAULT_ROPE_FREQUENCY);

        let rope_freq_base_sliding = md_get("gemma3.rope.local_freq_base")
            .and_then(|m| m.to_f32())
            .unwrap_or(DEFAULT_ROPE_FREQUENCY_SLIDING);

        // Unused in Llama.cpp so we aren't using it here.
        let _rope_freq_scaling_factor = md_get("gemma3.rope.scaling.factor")
            .and_then(|m| m.to_f32())
            .unwrap_or(DEFAULT_ROPE_FREQUENCY_SCALE_FACTOR);

        // Compute the dimensions for queries, keys, and values
        // These are the total dimensions when projected across all heads
        let q_dim = head_count * key_length;

        let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;

        // Load token embeddings and output projection
        let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
        let tok_embeddings = tok_embeddings.dequantize(device)?;
        let norm = RmsNorm::from_qtensor(
            ct.tensor(reader, "output_norm.weight", device)?,
            rms_norm_eps,
        )?;
        let output = match ct.tensor(reader, "output.weight", device) {
            Ok(tensor) => tensor,
            Err(_) => ct.tensor(reader, "token_embd.weight", device)?, // Use tied weights if output.weight doesn't exist
        };

        let mut layers = Vec::with_capacity(block_count);
        for layer_idx in 0..block_count {
            let prefix = format!("blk.{layer_idx}");

            let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
            let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
            let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
            let attention_wo =
                ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;

            let attention_q_norm = RmsNorm::from_qtensor(
                ct.tensor(reader, &format!("{prefix}.attn_q_norm.weight"), device)?,
                rms_norm_eps,
            )?;

            let attention_k_norm = RmsNorm::from_qtensor(
                ct.tensor(reader, &format!("{prefix}.attn_k_norm.weight"), device)?,
                rms_norm_eps,
            )?;

            let attention_norm = RmsNorm::from_qtensor(
                ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?,
                rms_norm_eps,
            )?;

            let post_attention_norm = RmsNorm::from_qtensor(
                ct.tensor(
                    reader,
                    &format!("{prefix}.post_attention_norm.weight"),
                    device,
                )?,
                rms_norm_eps,
            )?;

            let ffn_norm = RmsNorm::from_qtensor(
                ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?,
                rms_norm_eps,
            )?;

            let post_ffn_norm = RmsNorm::from_qtensor(
                ct.tensor(reader, &format!("{prefix}.post_ffw_norm.weight"), device)?,
                rms_norm_eps,
            )?;

            let feed_forward_gate =
                ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
            let feed_forward_up = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
            let feed_forward_down =
                ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;

            let mlp = Mlp {
                feed_forward_gate: QMatMul::from_qtensor(feed_forward_gate)?,
                feed_forward_up: QMatMul::from_qtensor(feed_forward_up)?,
                feed_forward_down: QMatMul::from_qtensor(feed_forward_down)?,
            };

            // Sliding window pattern hardcoded to 6 because it's not explicitly defined
            let is_sliding = (layer_idx + 1) % sliding_window_type > 0;
            let sliding_window_size = is_sliding.then_some(sliding_window_size);
            let layer_rope_frequency = if is_sliding {
                rope_freq_base_sliding
            } else {
                rope_freq_base
            };

            let rotary_embedding = RotaryEmbedding::new(key_length, layer_rope_frequency, device)?;

            // Tracing spans
            let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
            let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");

            layers.push(LayerWeights {
                attention_wq: QMatMul::from_qtensor(attention_wq)?,
                attention_wk: QMatMul::from_qtensor(attention_wk)?,
                attention_wv: QMatMul::from_qtensor(attention_wv)?,
                attention_wo: QMatMul::from_qtensor(attention_wo)?,
                attention_q_norm,
                attention_k_norm,
                attention_norm,
                post_attention_norm,
                ffn_norm,
                post_ffn_norm,
                mlp,
                n_head: head_count,
                n_kv_head: head_count_kv,
                head_dim: key_length,
                q_dim,
                sliding_window_size,
                rotary_embedding,
                neg_inf: neg_inf.clone(),
                kv_cache: None,
                span_attn,
                span_mlp,
            })
        }

        let span = tracing::span!(tracing::Level::TRACE, "model");
        let span_output = tracing::span!(tracing::Level::TRACE, "output");

        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
            embedding_length,
            layers,
            norm,
            output: QMatMul::from_qtensor(output)?,
            span,
            span_output,
        })
    }

    pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
        let (b_sz, seq_len) = x.dims2()?;
        let _enter = self.span.enter();

        let mut layer_in = self.tok_embeddings.forward(x)?;
        layer_in = (layer_in * (self.embedding_length as f64).sqrt())?;

        for layer in self.layers.iter_mut() {
            let attention_mask = if seq_len == 1 {
                None
            } else {
                Some(layer.mask(b_sz, seq_len, index_pos, x.dtype(), x.device())?)
            };

            // Attention block
            let residual = &layer_in;
            let x = layer.attention_norm.forward(&layer_in)?;
            let x = layer.forward_attn(&x, attention_mask.as_ref(), index_pos)?;
            let x = layer.post_attention_norm.forward(&x)?;
            let x = (x + residual)?;

            // Feed-forward block
            let _enter = layer.span_mlp.enter();
            let residual = &x;
            let x = layer.ffn_norm.forward(&x)?;
            let x = layer.mlp.forward(&x)?;
            let x = layer.post_ffn_norm.forward(&x)?;
            let x = (x + residual)?;
            drop(_enter);

            layer_in = x;
        }

        let _enter = self.span_output.enter();

        let x = layer_in.i((.., seq_len - 1, ..))?;
        let x = self.norm.forward(&x)?;
        let output = self.output.forward(&x)?;

        Ok(output)
    }
}
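A minimal usage sketch for the quantized Gemma 3 weights, assuming a local GGUF file (the file name and token ids are illustrative; `anyhow` is an assumed dependency):

use candle::quantized::gguf_file;
use candle::{Device, Tensor};
use candle_transformers::models::quantized_gemma3::ModelWeights;

fn run() -> anyhow::Result<()> {
    let device = Device::Cpu;
    let mut file = std::fs::File::open("gemma3-4b-it.Q4_K_M.gguf")?;
    let content = gguf_file::Content::read(&mut file)?;
    let mut model = ModelWeights::from_gguf(content, &mut file, &device)?;
    // Prompt pass at index_pos 0; subsequent steps pass the running offset.
    let input = Tensor::new(&[[2u32, 106, 108]], &device)?;
    let logits = model.forward(&input, 0)?; // logits for the last position only
    println!("{:?}", logits.dims());
    Ok(())
}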
candle-transformers/src/models/quantized_qwen3.rs (new file)
@@ -0,0 +1,429 @@
//! Qwen3 implementation with quantization support.
//!
//! Based on the Qwen3 architecture and implemented with quantized weights
//! for reduced memory usage and faster inference on compatible hardware.
//!
//! References:
//! - [Qwen3 Models](https://huggingface.co/Qwen/Qwen3-0.6B) (architecture based on official implementations)
//!
use super::with_tracing::QMatMul;
use crate::{quantized_nn::RmsNorm, utils::repeat_kv};
use candle::quantized::{gguf_file, QTensor};
use candle::{DType, Device, Result, Tensor};
use candle_nn::{kv_cache::KvCache, Activation, Embedding, Module};
use std::io::{Read, Seek};
use std::sync::Arc;

struct Gguf<R: Read + Seek> {
    ct: gguf_file::Content,
    reader: R,
    device: Device,
}

impl<R: Read + Seek> Gguf<R> {
    fn new(ct: gguf_file::Content, reader: R, device: Device) -> Self {
        Self { ct, reader, device }
    }

    fn qmatmul(&mut self, name: &str) -> Result<QMatMul> {
        let ws = self.ct.tensor(&mut self.reader, name, &self.device)?;
        QMatMul::from_weights(ws.into())
    }

    fn rms_norm(&mut self, name: &str, eps: f64) -> Result<RmsNorm> {
        let ws = self.ct.tensor(&mut self.reader, name, &self.device)?;
        RmsNorm::from_qtensor(ws, eps)
    }

    fn metadata(&self) -> &std::collections::HashMap<String, gguf_file::Value> {
        &self.ct.metadata
    }

    fn tensor(&mut self, name: &str) -> Result<QTensor> {
        self.ct.tensor(&mut self.reader, name, &self.device)
    }
}

#[derive(Debug, Clone)]
struct MlpWeights {
    gate_proj: QMatMul,
    up_proj: QMatMul,
    down_proj: QMatMul,
    act_fn: Activation,
    span: tracing::Span,
}

impl MlpWeights {
    fn new<R: Read + Seek>(gg: &mut Gguf<R>, prefix: &str) -> Result<Self> {
        let gate_proj = gg.qmatmul(&format!("{prefix}.ffn_gate.weight"))?;
        let up_proj = gg.qmatmul(&format!("{prefix}.ffn_up.weight"))?;
        let down_proj = gg.qmatmul(&format!("{prefix}.ffn_down.weight"))?;
        let act_fn = Activation::Silu;
        let span = tracing::span!(tracing::Level::TRACE, "mlp");
        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
            act_fn,
            span,
        })
    }
}

impl Module for MlpWeights {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let gate = self.gate_proj.forward(x)?.apply(&self.act_fn)?;
        let up = self.up_proj.forward(x)?;
        let gated = (gate * up)?;
        self.down_proj.forward(&gated)
    }
}

#[derive(Debug, Clone)]
struct RotaryEmbedding {
    sin: Tensor,
    cos: Tensor,
}

impl RotaryEmbedding {
    fn new(
        dtype: DType,
        head_dim: usize,
        max_position_embeddings: usize,
        rope_theta: f64,
        dev: &Device,
    ) -> Result<Self> {
        let dim = head_dim;
        let max_seq_len = max_position_embeddings;
        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / rope_theta.powf(i as f64 / dim as f64) as f32)
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(dtype)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        Ok(Self {
            sin: freqs.sin()?,
            cos: freqs.cos()?,
        })
    }

    /// Apply RoPE (q, k shape: B x H x L x D)
    fn apply(&self, q: &Tensor, k: &Tensor, offset: usize) -> Result<(Tensor, Tensor)> {
        let (_, _, seq_len, _) = q.dims4()?;
        let cos = self.cos.narrow(0, offset, seq_len)?.to_dtype(q.dtype())?;
        let sin = self.sin.narrow(0, offset, seq_len)?.to_dtype(q.dtype())?;
        let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
        let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
        Ok((q_embed, k_embed))
    }
}

#[derive(Debug, Clone)]
struct AttentionWeights {
    q_proj: QMatMul,
    k_proj: QMatMul,
    v_proj: QMatMul,
    o_proj: QMatMul,
    q_norm: RmsNorm,
    k_norm: RmsNorm,
    num_heads: usize,
    num_kv_heads: usize,
    num_kv_groups: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    kv_cache: KvCache,
    span_attn: tracing::Span,
}

impl AttentionWeights {
    fn new<R: Read + Seek>(
        gg: &mut Gguf<R>,
        num_heads: usize,
        num_kv_heads: usize,
        head_dim: usize,
        rms_norm_eps: f64,
        rotary_emb: Arc<RotaryEmbedding>,
        prefix: &str,
    ) -> Result<Self> {
        let num_kv_groups = num_heads / num_kv_heads;

        let q_proj = gg.qmatmul(&format!("{prefix}.attn_q.weight"))?;
        let k_proj = gg.qmatmul(&format!("{prefix}.attn_k.weight"))?;
        let v_proj = gg.qmatmul(&format!("{prefix}.attn_v.weight"))?;
        let o_proj = gg.qmatmul(&format!("{prefix}.attn_output.weight"))?;

        let q_norm = gg.rms_norm(&format!("{prefix}.attn_q_norm.weight"), rms_norm_eps)?;
        let k_norm = gg.rms_norm(&format!("{prefix}.attn_k_norm.weight"), rms_norm_eps)?;

        // Initialize the KV cache with 512 tokens capacity to reduce the initial memory allocation.
        // The cache will grow in chunks of 512 tokens when needed.
        let kv_cache = KvCache::new(2, 512);

        let span_attn = tracing::span!(tracing::Level::TRACE, "attn");

        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            q_norm,
            k_norm,
            num_heads,
            num_kv_heads,
            num_kv_groups,
            head_dim,
            rotary_emb,
            kv_cache,
            span_attn,
        })
    }

    fn forward(&mut self, x: &Tensor, attn_mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
        let _enter = self.span_attn.enter();
        let (b, l, _) = x.dims3()?;

        let q = self.q_proj.forward(x)?;
        let k = self.k_proj.forward(x)?;
        let v = self.v_proj.forward(x)?;

        let q = q
            .reshape((b, l, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        let k = k
            .reshape((b, l, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;
        let v = v
            .reshape((b, l, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;

        let q_flat = q.flatten(0, 2)?;
        let k_flat = k.flatten(0, 2)?;

        let q_flat = self.q_norm.forward(&q_flat)?;
        let k_flat = self.k_norm.forward(&k_flat)?;
        let q = q_flat.reshape((b, self.num_heads, l, self.head_dim))?;
        let k = k_flat.reshape((b, self.num_kv_heads, l, self.head_dim))?;

        let (q, k) = self.rotary_emb.apply(&q, &k, offset)?;

        // Reset the KV cache if we're at the first position
        if offset == 0 {
            self.kv_cache.reset();
        }
        let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;

        // Make the tensors contiguous to avoid some strided copies
        let k = k.contiguous()?;
        let v = v.contiguous()?;

        let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
        let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;

        let scale = 1.0 / (self.head_dim as f64).sqrt();
        let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
        if let Some(m) = attn_mask {
            let m_dtype = m.dtype();
            let scores_dtype = scores.dtype();
            let mask = if m_dtype != scores_dtype {
                m.to_dtype(scores_dtype)?
            } else {
                m.clone()
            };
            scores = scores.broadcast_add(&mask)?;
        }
        let probs = candle_nn::ops::softmax_last_dim(&scores)?;
        let ctx = probs.matmul(&v)?; // (B, H, L, D)
        let reshaped_ctx = ctx
            .transpose(1, 2)?
            .reshape((b, l, self.num_heads * self.head_dim))?;
        self.o_proj.forward(&reshaped_ctx)
    }
}
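
// For intuition, a shape walk-through of the grouped-query attention path above
// (the concrete sizes are hypothetical, not taken from this diff): with
// num_heads = 16, num_kv_heads = 8 and head_dim = 128, q is (B, 16, L, 128)
// while k and v come out of the cache as (B, 8, L_total, 128). repeat_kv
// duplicates each KV head num_kv_groups = 16 / 8 = 2 times, giving
// (B, 16, L_total, 128), so q.matmul(&k.transpose(2, 3)?) produces scores of
// shape (B, 16, L, L_total), which the (b, 1, L, L_total) mask broadcasts over.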

#[derive(Debug, Clone)]
struct LayerWeights {
    self_attn: AttentionWeights,
    mlp: MlpWeights,
    ln1: RmsNorm,
    ln2: RmsNorm,
}

impl LayerWeights {
    fn new<R: Read + Seek>(
        gg: &mut Gguf<R>,
        num_attention_heads: usize,
        num_key_value_heads: usize,
        head_dim: usize,
        rms_norm_eps: f64,
        rotary: Arc<RotaryEmbedding>,
        layer_idx: usize,
    ) -> Result<Self> {
        let prefix = format!("blk.{layer_idx}");

        let ln1 = gg.rms_norm(&format!("{prefix}.attn_norm.weight"), rms_norm_eps)?;
        let ln2 = gg.rms_norm(&format!("{prefix}.ffn_norm.weight"), rms_norm_eps)?;
        let self_attn = AttentionWeights::new(
            gg,
            num_attention_heads,
            num_key_value_heads,
            head_dim,
            rms_norm_eps,
            rotary,
            &prefix,
        )?;
        let mlp = MlpWeights::new(gg, &prefix)?;
        Ok(Self {
            self_attn,
            mlp,
            ln1,
            ln2,
        })
    }

    fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
        let h = self.ln1.forward(x)?;
        let h = self.self_attn.forward(&h, mask, offset)?;
        let x = (x + h)?;
        let h2 = self.ln2.forward(&x)?;
        let h2 = h2.apply(&self.mlp)?;
        x + h2
    }
}

#[derive(Debug, Clone)]
pub struct ModelWeights {
    embed_tokens: Embedding,
    layers: Vec<LayerWeights>,
    norm: RmsNorm,
    lm_head: QMatMul,
    device: Device,
    dtype: DType,
    span: tracing::Span,
    span_output: tracing::Span,
}

impl ModelWeights {
    pub fn from_gguf<R: Read + Seek>(
        ct: gguf_file::Content,
        reader: &mut R,
        device: &Device,
    ) -> Result<Self> {
        let mut gg = Gguf::new(ct, reader, device.clone());
        let md_get = |s: &str| match gg.metadata().get(s) {
            None => candle::bail!("cannot find {s} in metadata"),
            Some(v) => Ok(v),
        };

        let num_attention_heads = md_get("qwen3.attention.head_count")?.to_u32()? as usize;
        let num_kv_heads = md_get("qwen3.attention.head_count_kv")?.to_u32()? as usize;
        let head_dim = md_get("qwen3.attention.key_length")?.to_u32()? as usize;
        let num_layers = md_get("qwen3.block_count")?.to_u32()? as usize;
        let hidden_size = md_get("qwen3.embedding_length")?.to_u32()? as usize;
        let max_position_embeddings = md_get("qwen3.context_length")?.to_u32()? as usize;
        let rms_norm_eps = md_get("qwen3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
        let rope_freq_base = md_get("qwen3.rope.freq_base")?.to_f32()? as f64;

        let dtype = match gg.metadata().get("general.dtype") {
            Some(v) => match v.to_u32() {
                Ok(0) => DType::F32,
                Ok(1) => DType::F16,
                _ => DType::F16,
            },
            None => DType::F16,
        };

        let embed_tensor = gg.tensor("token_embd.weight")?;
        let embed_tokens = Embedding::new(embed_tensor.dequantize(device)?, hidden_size);

        let rotary = Arc::new(RotaryEmbedding::new(
            dtype,
            head_dim,
            max_position_embeddings,
            rope_freq_base,
            device,
        )?);

        let mut layers = Vec::with_capacity(num_layers);
        for i in 0..num_layers {
            layers.push(LayerWeights::new(
                &mut gg,
                num_attention_heads,
                num_kv_heads,
                head_dim,
                rms_norm_eps,
                rotary.clone(),
                i,
            )?);
        }

        let norm = gg.rms_norm("output_norm.weight", rms_norm_eps)?;
        // Load the output projection tensor, falling back to tied embeddings like gemma3
        let lm_head_tensor = match gg.tensor("output.weight") {
            Ok(tensor) => tensor,
            Err(_) => gg.tensor("token_embd.weight")?,
        };
        let lm_head = QMatMul::from_weights(lm_head_tensor.into())?;
        let span = tracing::span!(tracing::Level::TRACE, "model");
        let span_output = tracing::span!(tracing::Level::TRACE, "output");
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            device: device.clone(),
            dtype,
            span,
            span_output,
        })
    }

    fn causal_mask(
        &self,
        b: usize,
        tgt: usize,
        offset: usize,
        sw: Option<usize>,
    ) -> Result<Tensor> {
        let minf = f32::NEG_INFINITY;
        let mask: Vec<_> = (0..tgt)
            .flat_map(|i| {
                (0..(tgt + offset)).map(move |j| {
                    let past_ok = j <= i + offset;
                    let sw_ok = match sw {
                        Some(w) => (i + offset) as i64 - j as i64 <= w as i64,
                        None => true,
                    };
                    if past_ok && sw_ok {
                        0.
                    } else {
                        minf
                    }
                })
            })
            .collect();
        Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype)
    }

    pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
        let _enter = self.span.enter();
        let (b, l) = input.dims2()?;
        let mut h = self.embed_tokens.forward(input)?;
        let causal_mask = if l == 1 {
            None
        } else {
            Some(self.causal_mask(b, l, offset, None)?)
        };
        for layer in &mut self.layers {
            h = layer.forward(&h, causal_mask.as_ref(), offset)?;
        }
        let h = self.norm.forward(&h)?;
        let _enter = self.span_output.enter();
        let last_hidden = h.narrow(1, l - 1, 1)?;
        self.lm_head.forward(&last_hidden)?.squeeze(1)
    }
}
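
For context, here is a minimal sketch of how `ModelWeights::from_gguf` is typically driven once this file lands. The checkpoint path, the placeholder token ids, and the use of `anyhow` are assumptions for illustration, not part of the diff:

use candle::quantized::gguf_file;
use candle::{Device, Tensor};
use candle_transformers::models::quantized_qwen3::ModelWeights;

fn main() -> anyhow::Result<()> {
    let device = Device::Cpu;
    let path = "qwen3-0.6b-q4_0.gguf"; // hypothetical local checkpoint
    let mut file = std::fs::File::open(path)?;
    let content = gguf_file::Content::read(&mut file)?;
    let mut model = ModelWeights::from_gguf(content, &mut file, &device)?;

    // Token ids are placeholders; a real caller would run the Qwen3 tokenizer.
    let tokens = [151644u32, 872, 198];
    let input = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?; // (1, L)
    let logits = model.forward(&input, 0)?; // (1, vocab): logits for the last position
    println!("{:?}", logits.dims());
    Ok(())
}

Since the attention layers reset their KV cache whenever `offset == 0`, a fresh prompt can be fed to the same model instance without rebuilding it.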
candle-transformers/src/models/qwen3.rs (new file, 389 lines)
@@ -0,0 +1,389 @@
use crate::{
    models::with_tracing::{linear_b, linear_no_bias, Linear, RmsNorm},
    utils::repeat_kv,
};
use candle::{DType, Device, Module, Result, Tensor};
use candle_nn::{kv_cache::KvCache, Activation, VarBuilder};
use std::sync::Arc;

#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct Config {
    pub vocab_size: usize,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub head_dim: usize,
    pub attention_bias: bool,
    pub num_key_value_heads: usize,
    pub max_position_embeddings: usize,
    pub sliding_window: Option<usize>,
    pub max_window_layers: usize,
    pub tie_word_embeddings: bool,
    pub rope_theta: f64,
    pub rms_norm_eps: f64,
    pub use_sliding_window: bool,
    pub hidden_act: Activation,
}

#[derive(Debug, Clone)]
pub(crate) struct Qwen3RotaryEmbedding {
    sin: Tensor,
    cos: Tensor,
}

impl Qwen3RotaryEmbedding {
    pub(crate) fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
        let dim = cfg.head_dim;
        let max_seq_len = cfg.max_position_embeddings;
        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(dtype)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        Ok(Self {
            sin: freqs.sin()?,
            cos: freqs.cos()?,
        })
    }

    /// Apply RoPE (q, k shape: B x H x L x D)
    pub(crate) fn apply(&self, q: &Tensor, k: &Tensor, offset: usize) -> Result<(Tensor, Tensor)> {
        let (_, _, seq_len, _) = q.dims4()?;
        let cos = self.cos.narrow(0, offset, seq_len)?;
        let sin = self.sin.narrow(0, offset, seq_len)?;
        let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
        let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
        Ok((q_embed, k_embed))
    }
}
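
// For reference, the table built above is the standard RoPE frequency schedule:
// inv_freq[i] = rope_theta^(-2i / head_dim) for i = 0 .. head_dim / 2, and a
// token at absolute position m = offset + row is rotated by the angle
// m * inv_freq[i]; the outer product t.matmul(&inv_freq) tabulates exactly
// these angles, whose sines and cosines are cached once up front.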

#[derive(Debug, Clone)]
pub(crate) struct Qwen3MLP {
    gate_proj: Linear,
    up_proj: Linear,
    down_proj: Linear,
    act_fn: Activation,
}

impl Qwen3MLP {
    pub(crate) fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        Ok(Self {
            gate_proj: linear_no_bias(cfg.hidden_size, cfg.intermediate_size, vb.pp("gate_proj"))?,
            up_proj: linear_no_bias(cfg.hidden_size, cfg.intermediate_size, vb.pp("up_proj"))?,
            down_proj: linear_no_bias(cfg.intermediate_size, cfg.hidden_size, vb.pp("down_proj"))?,
            act_fn: cfg.hidden_act,
        })
    }
}

impl Module for Qwen3MLP {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let lhs = x.apply(&self.gate_proj)?.apply(&self.act_fn)?;
        let rhs = x.apply(&self.up_proj)?;
        (lhs * rhs)?.apply(&self.down_proj)
    }
}

#[derive(Debug, Clone)]
pub(crate) struct Qwen3Attention {
    // projections
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    // norms
    q_norm: RmsNorm,
    k_norm: RmsNorm,
    // hyper params
    num_heads: usize,
    num_kv_heads: usize,
    num_kv_groups: usize,
    head_dim: usize,
    hidden_size: usize,
    // utils
    rotary_emb: Arc<Qwen3RotaryEmbedding>,
    kv_cache: KvCache,
}

impl Qwen3Attention {
    pub(crate) fn new(
        cfg: &Config,
        rotary_emb: Arc<Qwen3RotaryEmbedding>,
        vb: VarBuilder,
    ) -> Result<Self> {
        if cfg.use_sliding_window {
            candle::bail!("sliding window is not supported")
        }

        let head_dim = cfg.head_dim;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let num_kv_groups = num_heads / num_kv_heads;

        let q_proj = linear_b(
            cfg.hidden_size,
            num_heads * head_dim,
            cfg.attention_bias,
            vb.pp("q_proj"),
        )?;
        let k_proj = linear_b(
            cfg.hidden_size,
            num_kv_heads * head_dim,
            cfg.attention_bias,
            vb.pp("k_proj"),
        )?;
        let v_proj = linear_b(
            cfg.hidden_size,
            num_kv_heads * head_dim,
            cfg.attention_bias,
            vb.pp("v_proj"),
        )?;
        let o_proj = linear_b(
            num_heads * head_dim,
            cfg.hidden_size,
            cfg.attention_bias,
            vb.pp("o_proj"),
        )?;

        let q_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("q_norm"))?;
        let k_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?;

        // Necessary because the hidden_size in the config isn't always accurate
        let hidden_size = head_dim * cfg.num_attention_heads;

        // Initialize the KV cache with 512 tokens capacity to reduce the initial memory allocation.
        // The cache will grow in chunks of 512 tokens when needed.
        let kv_cache = KvCache::new(2, 512);

        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            q_norm,
            k_norm,
            num_heads,
            num_kv_heads,
            num_kv_groups,
            head_dim,
            hidden_size,
            rotary_emb,
            kv_cache,
        })
    }

    pub(crate) fn forward(
        &mut self,
        x: &Tensor,
        attn_mask: Option<&Tensor>,
        offset: usize,
    ) -> Result<Tensor> {
        let (b, l, _) = x.dims3()?;

        // 1. Projections
        let q = self.q_proj.forward(x)?;
        let k = self.k_proj.forward(x)?;
        let v = self.v_proj.forward(x)?;

        // 2. Reshape: (B, L, H, D) -> (B, H, L, D)
        let q = q
            .reshape((b, l, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        let k = k
            .reshape((b, l, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;
        let v = v
            .reshape((b, l, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;

        // 3. Per-head RMSNorm: flatten (B, H, L, D) -> (B*H*L, D) so the norm
        // applies over the last dim only, then restore the original layout.
        let q_flat = q.flatten(0, 2)?;
        let k_flat = k.flatten(0, 2)?;
        let q_flat = self.q_norm.forward(&q_flat)?;
        let k_flat = self.k_norm.forward(&k_flat)?;
        let q = q_flat.reshape((b, self.num_heads, l, self.head_dim))?;
        let k = k_flat.reshape((b, self.num_kv_heads, l, self.head_dim))?;

        // 4. RoPE
        let (q, k) = self.rotary_emb.apply(&q, &k, offset)?;

        // 5. Accumulate the KV cache
        let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;

        // 6. GQA repeat_kv
        let k = repeat_kv(k, self.num_kv_groups)?;
        let v = repeat_kv(v, self.num_kv_groups)?;

        // 7. Attention scores
        let scale = 1.0 / (self.head_dim as f64).sqrt();
        let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
        if let Some(m) = attn_mask {
            scores = scores.broadcast_add(m)?;
        }
        let probs = candle_nn::ops::softmax_last_dim(&scores)?;
        let ctx = probs.matmul(&v)?; // (B, H, L, D)

        // 8. Output projection
        ctx.transpose(1, 2)?
            .reshape((b, l, self.hidden_size))?
            .apply(&self.o_proj)
    }

    pub(crate) fn clear_kv_cache(&mut self) {
        self.kv_cache.reset();
    }
}
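
// A note on the cache semantics assumed above: KvCache::append returns the
// full accumulated k and v (all cached positions plus the new ones) rather
// than just the appended slice, and the 2 passed to KvCache::new is the
// sequence axis of the (B, H, L, D) tensors. Unlike the quantized variant,
// which resets its cache whenever offset == 0, this module relies on callers
// invoking clear_kv_cache() between unrelated sequences.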

#[derive(Debug, Clone)]
struct DecoderLayer {
    self_attn: Qwen3Attention,
    mlp: Qwen3MLP,
    ln1: RmsNorm,
    ln2: RmsNorm,
}

impl DecoderLayer {
    fn new(cfg: &Config, rotary: Arc<Qwen3RotaryEmbedding>, vb: VarBuilder) -> Result<Self> {
        let self_attn = Qwen3Attention::new(cfg, rotary, vb.pp("self_attn"))?;
        let mlp = Qwen3MLP::new(cfg, vb.pp("mlp"))?;
        let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
        let ln2 = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            vb.pp("post_attention_layernorm"),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            ln1,
            ln2,
        })
    }

    fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
        let h = self.ln1.forward(x)?;
        let h = self.self_attn.forward(&h, mask, offset)?;
        let x = (x + h)?;
        let h2 = self.ln2.forward(&x)?;
        let h2 = h2.apply(&self.mlp)?;
        x + h2
    }

    fn clear_kv_cache(&mut self) {
        self.self_attn.clear_kv_cache();
    }
}

#[derive(Debug, Clone)]
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    device: Device,
    dtype: DType,
}

impl Model {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let embed_tokens =
            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
        let rotary = Arc::new(Qwen3RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?);
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb.pp("model.layers");
        for i in 0..cfg.num_hidden_layers {
            layers.push(DecoderLayer::new(cfg, rotary.clone(), vb_l.pp(i))?);
        }
        Ok(Self {
            embed_tokens,
            layers,
            norm: RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?,
            device: vb.device().clone(),
            dtype: vb.dtype(),
        })
    }

    fn clear_kv_cache(&mut self) {
        for l in &mut self.layers {
            l.clear_kv_cache();
        }
    }

    fn causal_mask(
        &self,
        b: usize,
        tgt: usize,
        offset: usize,
        sw: Option<usize>,
    ) -> Result<Tensor> {
        let minf = f32::NEG_INFINITY;
        let mask: Vec<_> = (0..tgt)
            .flat_map(|i| {
                (0..(tgt + offset)).map(move |j| {
                    let past_ok = j <= i + offset;
                    let sw_ok = match sw {
                        Some(w) => (i + offset) as i64 - j as i64 <= w as i64,
                        None => true,
                    };
                    if past_ok && sw_ok {
                        0.
                    } else {
                        minf
                    }
                })
            })
            .collect();
        Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype)
    }
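
    // For intuition: with tgt = 3 new tokens and offset = 2 cached positions
    // (no sliding window), the mask built above has shape (b, 1, 3, 5) and,
    // per row i, allows column j whenever j <= i + offset:
    //
    //      0    0    0  -inf -inf
    //      0    0    0    0  -inf
    //      0    0    0    0    0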

    pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
        let (b, l) = input.dims2()?;
        let mut h = self.embed_tokens.forward(input)?;

        let causal = if l == 1 {
            None
        } else {
            Some(self.causal_mask(b, l, offset, None)?)
        };

        for layer in &mut self.layers {
            h = layer.forward(&h, causal.as_ref(), offset)?;
        }
        self.norm.forward(&h)
    }
}

#[derive(Debug, Clone)]
pub struct ModelForCausalLM {
    base: Model,
    lm_head: Linear,
}

impl ModelForCausalLM {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let base = Model::new(cfg, vb.clone())?;
        let lm_head = if cfg.tie_word_embeddings {
            Linear::from_weights(base.embed_tokens.embeddings().clone(), None)
        } else {
            linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
        };
        Ok(Self { base, lm_head })
    }

    pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
        let (_, l) = input.dims2()?;
        self.base
            .forward(input, offset)?
            .narrow(1, l - 1, 1)?
            .apply(&self.lm_head)
    }

    pub fn clear_kv_cache(&mut self) {
        self.base.clear_kv_cache();
    }
}
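
To show how the `offset` convention is meant to be used by callers, here is a hedged sketch of a greedy decode loop over `ModelForCausalLM`. The model construction is omitted and greedy `argmax` stands in for real sampling; both are assumptions for illustration:

use candle::{Device, Result, Tensor};
use candle_transformers::models::qwen3::ModelForCausalLM;

// Greedy decoding sketch; real callers would sample from the logits instead.
fn generate(
    model: &mut ModelForCausalLM,
    prompt: &[u32],
    steps: usize,
    device: &Device,
) -> Result<Vec<u32>> {
    model.clear_kv_cache(); // fresh sequence: drop any cached positions
    let mut tokens = prompt.to_vec();
    let mut offset = 0usize;
    // Prefill: feed the whole prompt at once; the causal mask covers (L, L).
    let input = Tensor::new(prompt, device)?.unsqueeze(0)?; // (1, L)
    let logits = model.forward(&input, offset)?; // (1, 1, vocab)
    offset += prompt.len();
    let mut next = logits.squeeze(0)?.squeeze(0)?.argmax(0)?.to_scalar::<u32>()?;
    for _ in 0..steps {
        tokens.push(next);
        // Decode: one token at a time; offset tells RoPE and the mask where we are.
        let input = Tensor::new(&[next], device)?.unsqueeze(0)?; // (1, 1)
        let logits = model.forward(&input, offset)?;
        offset += 1;
        next = logits.squeeze(0)?.squeeze(0)?.argmax(0)?.to_scalar::<u32>()?;
    }
    Ok(tokens)
}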
candle-transformers/src/models/qwen3_moe.rs (new file, 355 lines)
@@ -0,0 +1,355 @@
use crate::models::{
    qwen3::{Config as Qwen3Config, Qwen3Attention, Qwen3MLP, Qwen3RotaryEmbedding},
    with_tracing::{linear_no_bias, Linear, RmsNorm},
};
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use std::sync::Arc;

#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct Config {
    pub vocab_size: usize,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub head_dim: usize,
    pub attention_bias: bool,
    pub num_key_value_heads: usize,
    pub max_position_embeddings: usize,
    pub sliding_window: Option<usize>,
    pub max_window_layers: usize,
    pub tie_word_embeddings: bool,
    pub rope_theta: f64,
    pub rms_norm_eps: f64,
    pub use_sliding_window: bool,
    pub hidden_act: Activation,
    // MoE specific configuration
    pub decoder_sparse_step: usize,
    pub moe_intermediate_size: usize,
    pub num_experts_per_tok: usize,
    pub num_experts: usize,
    pub norm_topk_prob: bool,
}

impl From<&Config> for Qwen3Config {
    fn from(val: &Config) -> Self {
        Qwen3Config {
            vocab_size: val.vocab_size,
            hidden_size: val.hidden_size,
            intermediate_size: val.intermediate_size,
            num_hidden_layers: val.num_hidden_layers,
            num_attention_heads: val.num_attention_heads,
            head_dim: val.head_dim,
            attention_bias: val.attention_bias,
            num_key_value_heads: val.num_key_value_heads,
            max_position_embeddings: val.max_position_embeddings,
            sliding_window: val.sliding_window,
            max_window_layers: val.max_window_layers,
            tie_word_embeddings: val.tie_word_embeddings,
            rope_theta: val.rope_theta,
            rms_norm_eps: val.rms_norm_eps,
            use_sliding_window: val.use_sliding_window,
            hidden_act: val.hidden_act,
        }
    }
}

#[derive(Debug, Clone)]
struct Qwen3MLPExpert {
    gate_proj: Linear,
    up_proj: Linear,
    down_proj: Linear,
    act_fn: Activation,
}

impl Qwen3MLPExpert {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        Ok(Self {
            gate_proj: linear_no_bias(
                cfg.hidden_size,
                cfg.moe_intermediate_size,
                vb.pp("gate_proj"),
            )?,
            up_proj: linear_no_bias(cfg.hidden_size, cfg.moe_intermediate_size, vb.pp("up_proj"))?,
            down_proj: linear_no_bias(
                cfg.moe_intermediate_size,
                cfg.hidden_size,
                vb.pp("down_proj"),
            )?,
            act_fn: cfg.hidden_act,
        })
    }
}

impl Module for Qwen3MLPExpert {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let lhs = x.apply(&self.gate_proj)?.apply(&self.act_fn)?;
        let rhs = x.apply(&self.up_proj)?;
        (lhs * rhs)?.apply(&self.down_proj)
    }
}

// Qwen3 Sparse MoE Block implementation
#[derive(Debug, Clone)]
struct Qwen3SparseMoeBlock {
    gate: Linear,
    experts: Vec<Qwen3MLPExpert>,
    norm_topk_prob: bool,
    num_experts_per_tok: usize,
}

impl Qwen3SparseMoeBlock {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let gate = linear_no_bias(cfg.hidden_size, cfg.num_experts, vb.pp("gate"))?;
        let mut experts = Vec::with_capacity(cfg.num_experts);
        let vb_e = vb.pp("experts");
        for idx in 0..cfg.num_experts {
            let expert = Qwen3MLPExpert::new(cfg, vb_e.pp(idx))?;
            experts.push(expert)
        }
        Ok(Self {
            gate,
            experts,
            norm_topk_prob: cfg.norm_topk_prob,
            num_experts_per_tok: cfg.num_experts_per_tok,
        })
    }
}

impl Module for Qwen3SparseMoeBlock {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (b_size, seq_len, hidden_dim) = xs.dims3()?;
        let xs = xs.reshape(((), hidden_dim))?;
        let router_logits = xs.apply(&self.gate)?;
        let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;

        // Extract the top-k experts per token
        let experts_per_tok = routing_weights
            .arg_sort_last_dim(false)?
            .narrow(D::Minus1, 0, self.num_experts_per_tok)?
            .contiguous()?;
        let routing_weights = routing_weights.gather(&experts_per_tok, D::Minus1)?;

        // Extract the needed data
        let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
        let experts_per_tok = experts_per_tok.to_vec2::<u32>()?;
        let mut top_x = vec![vec![]; self.experts.len()];
        let mut selected_experts = vec![vec![]; self.experts.len()];
        for (row_idx, (rw, expert_idxs)) in routing_weights
            .iter()
            .zip(experts_per_tok.iter())
            .enumerate()
        {
            let sum_rw = rw.iter().sum::<f32>();
            for (&rw, &expert_idx) in rw.iter().zip(expert_idxs.iter()) {
                top_x[expert_idx as usize].push(row_idx as u32);
                let rw = if self.norm_topk_prob { rw / sum_rw } else { rw };
                selected_experts[expert_idx as usize].push(rw)
            }
        }

        // Process through the experts
        let mut ys = xs.zeros_like()?;
        for (expert_idx, expert_layer) in self.experts.iter().enumerate() {
            let top_x = &top_x[expert_idx];
            if top_x.is_empty() {
                continue;
            }
            let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
            let selected_experts =
                Tensor::new(selected_experts[expert_idx].as_slice(), xs.device())?
                    .reshape(((), 1))?
                    .to_dtype(xs.dtype())?;

            let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
            let current_hidden_states = expert_layer.forward(&current_state)?;
            let current_hidden_states = current_hidden_states.broadcast_mul(&selected_experts)?;
            ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
        }

        ys.reshape((b_size, seq_len, hidden_dim))
    }
}
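
// A note on the dispatch above: top_x[e] collects the flattened token rows
// routed to expert e and selected_experts[e] their (optionally re-normalized)
// routing weights; index_select gathers those rows, the expert MLP transforms
// them, broadcast_mul applies the per-token weight, and index_add scatters the
// weighted outputs back into ys, summing the contributions of the
// num_experts_per_tok experts chosen for each token.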

// MLP or MoE decision enum
#[derive(Debug, Clone)]
enum Qwen3FeedForward {
    Mlp(Qwen3MLP),
    MoE(Qwen3SparseMoeBlock),
}

impl Module for Qwen3FeedForward {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::Mlp(m) => m.forward(xs),
            Self::MoE(m) => m.forward(xs),
        }
    }
}

#[derive(Debug, Clone)]
struct DecoderLayer {
    self_attn: Qwen3Attention,
    feed_forward: Qwen3FeedForward,
    ln1: RmsNorm,
    ln2: RmsNorm,
}

impl DecoderLayer {
    fn new(
        layer_idx: usize,
        cfg: &Config,
        rotary: Arc<Qwen3RotaryEmbedding>,
        vb: VarBuilder,
    ) -> Result<Self> {
        let self_attn = Qwen3Attention::new(&cfg.into(), rotary, vb.pp("self_attn"))?;

        // Decide whether to use MoE or a regular MLP based on layer_idx and decoder_sparse_step
        let feed_forward = if cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0
        {
            Qwen3FeedForward::MoE(Qwen3SparseMoeBlock::new(cfg, vb.pp("mlp"))?)
        } else {
            Qwen3FeedForward::Mlp(Qwen3MLP::new(&cfg.into(), vb.pp("mlp"))?)
        };

        let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
        let ln2 = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            vb.pp("post_attention_layernorm"),
        )?;

        Ok(Self {
            self_attn,
            feed_forward,
            ln1,
            ln2,
        })
    }

    fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
        let h = self.ln1.forward(x)?;
        let h = self.self_attn.forward(&h, mask, offset)?;
        let x = (x + h)?;
        let h2 = self.ln2.forward(&x)?;
        let h2 = h2.apply(&self.feed_forward)?;
        x + h2
    }

    fn clear_kv_cache(&mut self) {
        self.self_attn.clear_kv_cache();
    }
}

#[derive(Debug, Clone)]
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    device: Device,
    dtype: DType,
}

impl Model {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let embed_tokens =
            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
        let rotary = Arc::new(Qwen3RotaryEmbedding::new(
            vb.dtype(),
            &cfg.into(),
            vb.device(),
        )?);
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb.pp("model.layers");
        for i in 0..cfg.num_hidden_layers {
            layers.push(DecoderLayer::new(i, cfg, rotary.clone(), vb_l.pp(i))?);
        }
        Ok(Self {
            embed_tokens,
            layers,
            norm: RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?,
            device: vb.device().clone(),
            dtype: vb.dtype(),
        })
    }

    fn clear_kv_cache(&mut self) {
        for l in &mut self.layers {
            l.clear_kv_cache();
        }
    }

    fn causal_mask(
        &self,
        b: usize,
        tgt: usize,
        offset: usize,
        sw: Option<usize>,
    ) -> Result<Tensor> {
        let minf = f32::NEG_INFINITY;
        let mask: Vec<_> = (0..tgt)
            .flat_map(|i| {
                (0..(tgt + offset)).map(move |j| {
                    let past_ok = j <= i + offset;
                    let sw_ok = match sw {
                        Some(w) => (i + offset) as i64 - j as i64 <= w as i64,
                        None => true,
                    };
                    if past_ok && sw_ok {
                        0.
                    } else {
                        minf
                    }
                })
            })
            .collect();
        Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype)
    }

    pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
        let (b, l) = input.dims2()?;
        let mut h = self.embed_tokens.forward(input)?;

        let causal = if l == 1 {
            None
        } else {
            Some(self.causal_mask(b, l, offset, None)?)
        };

        for layer in &mut self.layers {
            h = layer.forward(&h, causal.as_ref(), offset)?;
        }
        self.norm.forward(&h)
    }
}

#[derive(Debug, Clone)]
pub struct ModelForCausalLM {
    base: Model,
    lm_head: Linear,
}

impl ModelForCausalLM {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let base = Model::new(cfg, vb.clone())?;
        let lm_head = if cfg.tie_word_embeddings {
            Linear::from_weights(base.embed_tokens.embeddings().clone(), None)
        } else {
            linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
        };
        Ok(Self { base, lm_head })
    }

    pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
        let (_, l) = input.dims2()?;
        self.base
            .forward(input, offset)?
            .narrow(1, l - 1, 1)?
            .apply(&self.lm_head)
    }

    pub fn clear_kv_cache(&mut self) {
        self.base.clear_kv_cache();
    }
}
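
The routing in `Qwen3SparseMoeBlock::forward` reduces to a softmax over the router logits, a top-k selection, and an optional re-normalization. The following self-contained sketch replays that arithmetic in plain Rust on made-up numbers (no candle types involved):

// A plain-Rust illustration of the MoE routing: softmax over router logits,
// top-k expert selection, optional re-normalization. Values are made up.
fn main() {
    let router_logits = [1.0f32, 0.2, -0.5, 0.9]; // one token, num_experts = 4
    let max = router_logits.iter().cloned().fold(f32::MIN, f32::max);
    let exps: Vec<f32> = router_logits.iter().map(|l| (l - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    let probs: Vec<f32> = exps.iter().map(|e| e / sum).collect();

    // top-k (k = num_experts_per_tok = 2): indices sorted by descending weight
    let mut idx: Vec<usize> = (0..probs.len()).collect();
    idx.sort_by(|&a, &b| probs[b].total_cmp(&probs[a]));
    let topk = &idx[..2];

    // norm_topk_prob: rescale the kept weights so they sum to one
    let kept: f32 = topk.iter().map(|&i| probs[i]).sum();
    for &i in topk {
        println!("expert {i}: weight {:.3}", probs[i] / kept);
    }
}

The `decoder_sparse_step` rule in `DecoderLayer::new` then decides which layers receive such a block: layer `i` uses MoE when `num_experts > 0` and `(i + 1) % decoder_sparse_step == 0`, so a step of 1 makes every decoder layer sparse.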
@@ -17,8 +17,8 @@ const CROP_NMS_THRESH: f32 = 0.7;
 
 #[derive(Debug)]
 enum ImageEncoder {
-    Original(ImageEncoderViT),
-    TinyViT(TinyViT),
+    Original(Box<ImageEncoderViT>),
+    TinyViT(Box<TinyViT>),
 }
 
 impl Module for ImageEncoder {
@@ -83,7 +83,7 @@ impl Sam {
         let pixel_std =
             Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
         Ok(Self {
-            image_encoder: ImageEncoder::Original(image_encoder),
+            image_encoder: ImageEncoder::Original(image_encoder.into()),
             prompt_encoder,
             mask_decoder,
             pixel_std,
@@ -114,7 +114,7 @@ impl Sam {
         let pixel_std =
             Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
         Ok(Self {
-            image_encoder: ImageEncoder::TinyViT(image_encoder),
+            image_encoder: ImageEncoder::TinyViT(image_encoder.into()),
             prompt_encoder,
             mask_decoder,
             pixel_std,
@@ -134,12 +134,7 @@ impl Scheduler for DDIMScheduler {
             timestep
         };
         // https://github.com/huggingface/diffusers/blob/6e099e2c8ce4c4f5c7318e970a8c093dc5c7046e/src/diffusers/schedulers/scheduling_ddim.py#L195
-        let prev_timestep = if timestep > self.step_ratio {
-            timestep - self.step_ratio
-        } else {
-            0
-        };
-
+        let prev_timestep = timestep.saturating_sub(self.step_ratio);
         let alpha_prod_t = self.alphas_cumprod[timestep];
         let alpha_prod_t_prev = self.alphas_cumprod[prev_timestep];
         let beta_prod_t = 1. - alpha_prod_t;
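
The rewrite above is behavior-preserving: `usize::saturating_sub` clamps at zero instead of underflowing, which is exactly what the removed if/else computed. A quick check:

fn main() {
    assert_eq!(10usize.saturating_sub(3), 7);
    assert_eq!(2usize.saturating_sub(3), 0); // clamped at zero instead of underflowing
}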
@@ -482,8 +482,10 @@ impl XLMRobertaClassificationHead {
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let cls_states = hidden_states.get_on_dim(1, 0)?.contiguous()?;
         let hidden_states = self.dense.forward(&cls_states)?;
-        let hidden_states = candle_nn::Activation::GeluPytorchTanh.forward(&hidden_states)?;
-        let hidden_states = self.out_proj.forward(&hidden_states)?;
+        // The activation used in the classification head is tanh, as per the original
+        // implementation.
+        // https://github.com/huggingface/transformers/blob/6e3063422c4b1c014aa60c32b9254fd2902f0f28/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py#L1454
+        let hidden_states = self.out_proj.forward(&hidden_states.tanh()?)?;
         Ok(hidden_states)
     }
 }