Compare commits

..

39 Commits

Author SHA1 Message Date
6d6d87f8b3 Use BF16 for llama v3 by default. 2024-04-19 14:22:01 +02:00
9c532aef47 Also enable llama-v3 8b instruct. (#2088) 2024-04-19 08:50:06 +02:00
f7a6468238 Add support for llama3 on the quantized example (#2086)
* add support for l3b, new tokenizer

* add todo

* Add todo and use k_s model

* Use the official tokenizers.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-04-18 22:52:00 +02:00
2b93dffe64 Use faster rotary embeddings for llama like models. (#2087) 2024-04-18 22:34:29 +02:00
e6ee7ba4d4 Llama v3. (#2085)
* Llama v3.

* Tweak the default params + handle special tokens.

* Small tweak.
2024-04-18 22:19:54 +02:00
1690ab45d2 Fix the silu gradient issue on 0. (#2083) 2024-04-18 14:31:41 +02:00
8de0ce6cba Add more QMMV cuda kernels. (#2077)
* Add more QMMV cuda kernels.

* Enable the new kernels.

* Adapt the testing.
2024-04-18 08:36:43 +02:00
ce6d08df94 Minor fix to the readme. (#2080)
Co-authored-by: Jane Doe <jane.doe@example.org>
2024-04-17 22:43:00 +02:00
2817643db9 Add the mmv kernels for small batch sizes. (#2075)
* Add the mmv kernels for smaller sizes.

* Support more mmv kernels.

* Use the new kernels.

* Fix the call.

* Silly fix.

* Improve the testing.

* Fix for dmmv.

* Add another dedicated test for the batching mmv.
2024-04-16 21:30:51 +02:00
4d14777673 Utilize batches in Stable Diffusion (#2071)
* Utilize batches in Stable Diffusion that were already there, but unutilized.

Also refactor out the `save_image` function.

* Clippy + cosmetic fixes.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-04-16 06:49:04 +02:00
f135b7963d Fix for the batch dim in the quantized matmul example. (#2073)
* Fix for the batch dim in the quantized matmul example.

* Enable more tests on cuda.

* Add a test for qmm with a batch.

* Fix the zeros-dim test on metal.
2024-04-15 20:00:28 +02:00
af955f260c Make the falcon model cloneable. (#2067) 2024-04-15 09:39:03 +02:00
8ad822a983 Add a function to clear the KV cache in falcon. (#2066)
* Add a function to clear the KV cache in falcon.

* Clippy.
2024-04-15 09:29:25 +02:00
e198bb0816 Handle zero dims in some simple operations. (#2064)
* Handle zero dims in some simple operations.

* Handle zero-dims in matmul.

* More testing.
2024-04-15 09:18:54 +02:00
f7d5bf5b97 Faster kernels for quantized matmul on cuda (#2060)
* Hook the quantized matmul cuda kernels.

* Add a (currently broken) test.

* Kernel fixes.

* Fix by transposing the rhs matrix.

* Add the q4-1 kernels.

* Proper block sizes.

* More details in the tests.
2024-04-15 08:32:47 +02:00
c119600d6e Move image tensor to device in trocr example (#2063)
Signed-off-by: Harry Stern <harry@harrystern.net>
2024-04-15 06:50:32 +02:00
c449f65b12 Expose the synchronize function on the generic device. (#2062) 2024-04-14 23:02:03 +02:00
db7dbf3071 Add missing bfloat unary strided kernels and fix typo (#2058) 2024-04-14 20:01:13 +02:00
4ecedb1598 Add the full quantized matmul kernels for cuda. (#2057) 2024-04-14 17:52:08 +02:00
53e5380bf6 Add a synchronize method to devices. (#2055)
* Add a synchronize method to devices.

* Metal version.
2024-04-14 16:32:55 +02:00
50e49ecc5f Add a quantized version of recurrent-gemma. (#2054)
* Add a quantized version of recurrent-gemma.

* Share the rglru part.

* Get the quantized gemma model to work.
2024-04-13 20:07:01 +02:00
4c88c3ce06 Add benchmarks for qmatmul operations (#2048)
* Add qmatmul bench

* add all dtypes
2024-04-13 12:30:14 +02:00
8b8fb630df Add a convenient way to rename tensors accessed through a varbuilder. (#2052) 2024-04-13 12:09:41 +02:00
fb805b8ca2 Avoid crashes when running T5 models with F16 tensors on CPU (#2047)
* This change avoids crashes when running T5 models with F16 tensors on CPU.

* This enables running ProstT5's (https://huggingface.co/Rostlab/ProstT5) encoder-only mode in Candle. This ProstT5 mode stores it's embed_tokens weights within the encoder, as its decoding stage was replaced with a CNN. You could write more, like: This alone is not sufficient to run ProstT5 within Candle examples. We will develop a ProstT5 runner outside candle for now, but would be willing to upstream it to candle-examples at a later point.

* Revert "This enables running ProstT5's (https://huggingface.co/Rostlab/ProstT5) encoder-only mode in Candle. This ProstT5 mode stores it's embed_tokens weights within the encoder, as its decoding stage was replaced with a CNN. You could write more, like: This alone is not sufficient to run ProstT5 within Candle examples. We will develop a ProstT5 runner outside candle for now, but would be willing to upstream it to candle-examples at a later point."

This reverts commit d886d3ce5e.
2024-04-13 11:07:28 +02:00
79e3bec789 Change for the encoder-only ProstT5 model (#2045)
* This change avoids crashes when running T5 models with F16 tensors on CPU.

* This enables running ProstT5's (https://huggingface.co/Rostlab/ProstT5) encoder-only mode in Candle. This ProstT5 mode stores it's embed_tokens weights within the encoder, as its decoding stage was replaced with a CNN.  This alone is not sufficient to run ProstT5 within Candle examples. We will develop a ProstT5 runner outside candle for now, but would be willing to upstream it to candle-examples at a later point.
2024-04-13 11:06:24 +02:00
e6d412b156 Add ReduceMean onnx operation (#2049)
* Add ReduceMean onnx operation

* Format code with rustfmt
2024-04-13 11:00:25 +02:00
26cbbf8d84 Mandatory topk sampling for recurrent-gemma. (#2051) 2024-04-13 10:31:39 +02:00
2bf413caa3 Add the recurrent-gemma model. (#2039)
* Start adding the recurrent-gemma model.

* More griffin.

* Add the example + get the weights to load from the HF version.

* More inference code.

* Rope + kv-cache on the attention side.

* Add to the inference code.

* Add more to the recurrent gemma inference.

* Get some first inference to run.

* Add the softcap on logits.

* Fixes.

* Use partial rotary embeddings.

* Get inference to work.

* Add a comment.

* And add a readme.
2024-04-13 00:05:21 +02:00
3ad4770eb6 Use cat for faster MQA computation. (#2043)
* Use cat for faster MQA computation.

* Move the function to utils + use it in mistral.

* Use the shared repeat-kv in a few more models.

* Fix.
2024-04-12 09:15:10 +02:00
a0460cd2b1 Add the code-gemma models. (#2038)
* Add the code-gemma models.

* Tweak to the gemma config.
2024-04-10 21:19:21 +02:00
b81ecf712d Support alternative dtypes for mamba (#2036)
* Allow different dtypes in mamba.

* Add a dtype flag.
2024-04-10 18:10:01 +02:00
a4d5a414e3 Support gather on bf16 for metal. (#2035) 2024-04-10 12:49:25 +02:00
798e0335cd Handle more tensor shapes in onnx "Gather" operation (#2026)
* Handle more tensor shapes in onnx "Gather" operation

* Add more tests

* Add comment

* Fix typo
2024-04-08 14:06:14 +02:00
718671a0d5 Use BufferOffset in metal backend ops. (#2029)
* Use BufferOffset in the metal backend.

* More BufferOffset usage.

* Use in where-cond.
2024-04-08 09:37:25 +02:00
c5fe4a7f89 Rework the buffer offset logic for metal kernels (#2028)
* Move the metal kernels utils in a separate module.

* Use the BufferOffset for unary ops.

* Fix clippy lints.

* Use the new BufferOffset.

* Adapt the binary ops.

* Affine.

* More ops (powf, elu, cast).
2024-04-07 22:37:53 +02:00
7f354473cf Optimize copy-2d for metal. (#2024)
* Optimize copy-2d for metal.

* Add a hacky stopping rule for moondream.
2024-04-07 12:34:16 +02:00
33c9b66554 Add the new gemma models. (#2023)
* Add the new gemma models.

* Revert the lightning changes.

* Support for the 1.1 models.
2024-04-06 21:25:38 +02:00
9fd52b3b71 Handle the batch dimension in quantized MMV on metal. (#2022) 2024-04-06 20:02:24 +02:00
e662431acf Fix the final rmsnorm for quantized-metavoice. (#2021) 2024-04-06 19:35:01 +02:00
62 changed files with 4752 additions and 812 deletions

View File

@ -63,8 +63,9 @@ We also provide a some command line based examples using state of the art models
- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM, includes
the SOLAR-10.7B variant.
- [Falcon](./candle-examples/examples/falcon/): general LLM.
- [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google
Deepmind.
- [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google Deepmind.
- [RecurrentGemma](./candle-examples/examples/recurrent-gemma/): 2b and 7b
Griffin based models from Google that mix attention with a RNN like state.
- [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
pre-trained on 1T tokens of English and code datasets. Also supports
@ -374,9 +375,9 @@ git submodule update --init
/usr/include/c++/11/bits/std_function.h:530:146: error: parameter packs not expanded with ...:
```
This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the CANDLE_NVCC_CCBIN environment variable.
This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the NVCC_CCBIN environment variable.
```
env CANDLE_NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
env NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
```
#### Linking error on windows when running rustdoc or mdbook tests

View File

@ -7,4 +7,5 @@ criterion_main!(
benchmarks::random::benches,
benchmarks::where_cond::benches,
benchmarks::conv_transpose2d::benches,
benchmarks::qmatmul::benches,
);

View File

@ -1,6 +1,7 @@
pub(crate) mod affine;
pub(crate) mod conv_transpose2d;
pub(crate) mod matmul;
pub(crate) mod qmatmul;
pub(crate) mod random;
pub(crate) mod where_cond;

View File

@ -0,0 +1,72 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{
quantized::{self, GgmlDType, QMatMul},
Device, Module, Tensor,
};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn run(matmul: &QMatMul, x: &Tensor) {
matmul.forward(&x).unwrap();
}
fn run_bench(c: &mut Criterion, device: &Device, dtype: GgmlDType) {
let b = 1;
let m = 1;
let n = 1024;
let k = 1024;
let lhs = (0..(m * k))
.map(|v| v as f32 / (m * k) as f32)
.collect::<Vec<_>>();
let rhs = (0..(k * n))
.map(|v| v as f32 / (n * k) as f32)
.collect::<Vec<_>>();
let lhs = Tensor::from_slice(&lhs, (m, k), device).unwrap();
let rhs = Tensor::from_slice(&rhs, (k, n), device).unwrap();
let qtensor = quantized::QTensor::quantize(&rhs.t().unwrap(), dtype).unwrap();
let matmul = quantized::QMatMul::from_qtensor(qtensor).unwrap();
let flops = b * m * n * k;
let mut group = c.benchmark_group(device.bench_name(format!("qmatmul_{:?}", dtype)));
group.sample_size(200);
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run(black_box(&matmul), black_box(&lhs));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
for dtype in vec![
GgmlDType::F32,
GgmlDType::F16,
GgmlDType::Q4_0,
GgmlDType::Q4_1,
GgmlDType::Q5_0,
GgmlDType::Q5_1,
GgmlDType::Q8_0,
GgmlDType::Q2K,
GgmlDType::Q3K,
GgmlDType::Q4K,
GgmlDType::Q5K,
GgmlDType::Q6K,
] {
run_bench(c, &device, dtype);
}
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -142,4 +142,7 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
fn set_seed(&self, _: u64) -> Result<()>;
/// Synchronize should block until all the operations on the device are completed.
fn synchronize(&self) -> Result<()>;
}

View File

@ -624,7 +624,7 @@ impl Tensor {
Op::Unary(arg, UnaryOp::Silu) => {
let sum_grad = grads.or_insert(arg)?;
// d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
let sigmoid_arg = (*node / arg)?;
let sigmoid_arg = (arg.neg()?.exp()? + 1.)?.recip()?;
let silu_grad = (&sigmoid_arg * (1. + (arg * (1. - &sigmoid_arg)?)?)?)?;
*sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
}

View File

@ -2628,6 +2628,10 @@ impl BackendDevice for CpuDevice {
};
Ok(storage)
}
fn synchronize(&self) -> Result<()> {
Ok(())
}
}
#[macro_export]

View File

@ -407,4 +407,9 @@ impl BackendDevice for CudaDevice {
device: self.clone(),
})
}
fn synchronize(&self) -> Result<()> {
self.device.synchronize().map_err(crate::Error::wrap)?;
Ok(())
}
}

View File

@ -337,4 +337,12 @@ impl Device {
}
}
}
pub fn synchronize(&self) -> Result<()> {
match self {
Self::Cpu => Ok(()),
Self::Cuda(d) => d.synchronize(),
Self::Metal(d) => d.synchronize(),
}
}
}

View File

@ -229,4 +229,8 @@ impl crate::backend::BackendDevice for CudaDevice {
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
fn synchronize(&self) -> Result<()> {
Ok(())
}
}

View File

@ -241,4 +241,8 @@ impl crate::backend::BackendDevice for MetalDevice {
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn synchronize(&self) -> Result<()> {
Ok(())
}
}

View File

@ -283,5 +283,5 @@ impl MetalDevice {
}
fn buf_size(size: NSUInteger) -> NSUInteger {
(size - 1).next_power_of_two() as NSUInteger
size.saturating_sub(1).next_power_of_two() as NSUInteger
}

View File

@ -2,8 +2,7 @@ use crate::backend::{BackendDevice, BackendStorage};
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};
use candle_metal_kernels::CallConvTranspose2dCfg;
use candle_metal_kernels::Kernels;
use candle_metal_kernels::{BufferOffset, CallConvTranspose2dCfg, Kernels};
use metal::{Buffer, MTLResourceOptions, NSUInteger};
use std::collections::HashMap;
use std::ffi::c_void;
@ -12,6 +11,12 @@ use std::sync::{Arc, Mutex, RwLock, TryLockError};
mod device;
pub use device::{DeviceId, MetalDevice};
fn buffer_o<'a>(buffer: &'a Buffer, l: &Layout, dtype: DType) -> BufferOffset<'a> {
BufferOffset {
buffer,
offset_in_bytes: l.start_offset() * dtype.size_in_bytes(),
}
}
/// Simple way to catch lock error without
/// depending on T
#[derive(thiserror::Error, Debug)]
@ -102,7 +107,8 @@ impl BackendStorage for MetalStorage {
let buffer = device.new_buffer(el, self.dtype, "affine")?;
let command_buffer = self.device.command_buffer()?;
if layout.is_contiguous() && layout.start_offset() == 0 {
let src = buffer_o(&self.buffer, layout, dtype);
if layout.is_contiguous() {
let name = match self.dtype {
DType::F32 => "affine_f32",
DType::F16 => "affine_f16",
@ -115,7 +121,7 @@ impl BackendStorage for MetalStorage {
&device.kernels,
name,
el,
&self.buffer,
src,
&buffer,
mul as f32,
add as f32,
@ -134,9 +140,8 @@ impl BackendStorage for MetalStorage {
&device.kernels,
name,
layout.dims(),
&self.buffer,
src,
layout.stride(),
layout.start_offset() * dtype.size_in_bytes(),
&buffer,
mul as f32,
add as f32,
@ -155,7 +160,8 @@ impl BackendStorage for MetalStorage {
let buffer = device.new_buffer(el, self.dtype, "powf")?;
let command_buffer = self.device.command_buffer()?;
if layout.is_contiguous() && layout.start_offset() == 0 {
let src = buffer_o(&self.buffer, layout, dtype);
if layout.is_contiguous() {
let name = match self.dtype {
DType::F32 => "powf_f32",
DType::F16 => "powf_f16",
@ -168,7 +174,7 @@ impl BackendStorage for MetalStorage {
&device.kernels,
name,
el,
&self.buffer,
src,
&buffer,
pow as f32,
)
@ -186,9 +192,8 @@ impl BackendStorage for MetalStorage {
&device.kernels,
name,
layout.dims(),
&self.buffer,
src,
layout.stride(),
layout.start_offset() * dtype.size_in_bytes(),
&buffer,
pow as f32,
)
@ -206,7 +211,8 @@ impl BackendStorage for MetalStorage {
let buffer = device.new_buffer(el, self.dtype, "elu")?;
let command_buffer = self.device.command_buffer()?;
if layout.is_contiguous() && layout.start_offset() == 0 {
let src = buffer_o(&self.buffer, layout, self.dtype);
if layout.is_contiguous() {
let name = match self.dtype {
DType::F32 => "elu_f32",
DType::F16 => "elu_f16",
@ -219,7 +225,7 @@ impl BackendStorage for MetalStorage {
&device.kernels,
name,
el,
&self.buffer,
src,
&buffer,
alpha as f32,
)
@ -237,9 +243,8 @@ impl BackendStorage for MetalStorage {
&device.kernels,
name,
layout.dims(),
&self.buffer,
src,
layout.stride(),
layout.start_offset() * dtype.size_in_bytes(),
&buffer,
alpha as f32,
)
@ -309,6 +314,7 @@ impl BackendStorage for MetalStorage {
let dtype = if return_index { DType::U32 } else { self.dtype };
let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&self.buffer, layout, self.dtype);
candle_metal_kernels::call_reduce_strided(
&device.device,
&command_buffer,
@ -317,8 +323,7 @@ impl BackendStorage for MetalStorage {
&dims,
&stride,
dst_el,
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
src,
&buffer,
)
.map_err(MetalError::from)?;
@ -344,7 +349,8 @@ impl BackendStorage for MetalStorage {
let el_count = shape.elem_count();
let buffer = device.new_buffer(el_count, dtype, "todtype")?;
let command_buffer = device.command_buffer()?;
if layout.is_contiguous() && layout.start_offset() == 0 {
let src = buffer_o(&self.buffer, layout, self.dtype);
if layout.is_contiguous() {
let kernel_name = match (self.dtype, dtype) {
(DType::U32, DType::BF16) => "cast_u32_bf16",
(DType::U32, DType::F16) => "cast_u32_f16",
@ -392,8 +398,7 @@ impl BackendStorage for MetalStorage {
&device.kernels,
kernel_name,
el_count,
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
src,
&buffer,
)
.map_err(MetalError::from)?;
@ -420,9 +425,8 @@ impl BackendStorage for MetalStorage {
&device.kernels,
kernel_name,
layout.dims(),
&self.buffer,
src,
layout.stride(),
layout.start_offset() * self.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;
@ -439,7 +443,8 @@ impl BackendStorage for MetalStorage {
let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?;
let command_buffer = device.command_buffer()?;
command_buffer.set_label(B::KERNEL);
if layout.is_contiguous() && layout.start_offset() == 0 {
let src = buffer_o(&self.buffer, layout, self.dtype);
if layout.is_contiguous() {
use candle_metal_kernels::unary::contiguous;
let kernel_name = match (B::KERNEL, dtype) {
@ -511,7 +516,7 @@ impl BackendStorage for MetalStorage {
&device.kernels,
kernel_name,
el_count,
&self.buffer,
src,
&buffer,
)
.map_err(MetalError::from)?;
@ -535,6 +540,7 @@ impl BackendStorage for MetalStorage {
("urelu", DType::F32) => strided::relu::FLOAT,
("uround", DType::F32) => strided::round::FLOAT,
("utanh", DType::F32) => strided::tanh::FLOAT,
("ucos", DType::F16) => strided::cos::HALF,
("usin", DType::F16) => strided::sin::HALF,
("usqr", DType::F16) => strided::sqr::HALF,
@ -552,21 +558,39 @@ impl BackendStorage for MetalStorage {
("urelu", DType::F16) => strided::relu::HALF,
("uround", DType::F16) => strided::round::HALF,
("utanh", DType::F16) => strided::tanh::HALF,
("ucos", DType::BF16) => strided::cos::BFLOAT,
("usin", DType::BF16) => strided::sin::BFLOAT,
("usqr", DType::BF16) => strided::sqr::BFLOAT,
("usqrt", DType::BF16) => strided::sqrt::BFLOAT,
("uneg", DType::BF16) => strided::neg::BFLOAT,
("uexp", DType::BF16) => strided::exp::BFLOAT,
("ulog", DType::BF16) => strided::log::BFLOAT,
("ugelu", DType::BF16) => strided::gelu::BFLOAT,
("ugelu_erf", DType::BF16) => strided::gelu_erf::BFLOAT,
("uerf", DType::BF16) => strided::erf::BFLOAT,
("usilu", DType::BF16) => strided::silu::BFLOAT,
("uabs", DType::BF16) => strided::abs::BFLOAT,
("uceil", DType::BF16) => strided::ceil::BFLOAT,
("ufloor", DType::BF16) => strided::floor::BFLOAT,
("urelu", DType::BF16) => strided::relu::BFLOAT,
("uround", DType::BF16) => strided::round::BFLOAT,
("utanh", DType::BF16) => strided::tanh::BFLOAT,
(name, dtype) => {
crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
}
};
let dst = BufferOffset::zero_offset(&buffer);
candle_metal_kernels::call_unary_strided(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
layout.dims(),
&self.buffer,
src,
layout.stride(),
layout.start_offset() * self.dtype.size_in_bytes(),
&buffer,
0,
dst,
)
.map_err(MetalError::from)?;
}
@ -613,21 +637,21 @@ impl BackendStorage for MetalStorage {
(DType::U8, DType::U8) => "where_u8_u8",
(left, right) => crate::bail!("Metal where_cond {left:?} {right:?} not implemented"),
};
let src = buffer_o(&self.buffer, layout, self.dtype);
let t = buffer_o(&t.buffer, t_l, t.dtype);
let f = buffer_o(&f.buffer, f_l, f.dtype);
candle_metal_kernels::call_where_cond_strided(
&device.device,
&command_buffer,
&device.kernels,
name,
dims,
&self.buffer,
(
layout.stride(),
layout.start_offset() * self.dtype.size_in_bytes(),
),
&t.buffer,
(t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
&f.buffer,
(f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
src,
layout.stride(),
t,
t_l.stride(),
f,
f_l.stride(),
&buffer,
)
.map_err(MetalError::from)?;
@ -660,6 +684,7 @@ impl BackendStorage for MetalStorage {
DType::F32 => "im2col1d_f32",
dtype => crate::bail!("Metal conv1d {dtype:?} not implemented"),
};
let src = buffer_o(&self.buffer, layout, self.dtype);
candle_metal_kernels::call_im2col1d_strided(
&self.device.device,
&command_buffer,
@ -668,8 +693,7 @@ impl BackendStorage for MetalStorage {
layout.shape().dims(),
strides,
(k_size, stride, padding, dilation),
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
src,
&dst,
)
.map_err(MetalError::from)?;
@ -787,6 +811,7 @@ impl BackendStorage for MetalStorage {
DType::U32 => "im2col_u32",
dtype => crate::bail!("Metal conv2d {dtype:?} not implemented"),
};
let src = buffer_o(&self.buffer, layout, self.dtype);
candle_metal_kernels::call_im2col_strided(
&self.device.device,
&command_buffer,
@ -795,8 +820,7 @@ impl BackendStorage for MetalStorage {
layout.shape().dims(),
layout.stride(),
(h_k, w_k, stride, padding, dilation),
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
src,
&dst,
)
.map_err(MetalError::from)?;
@ -1009,6 +1033,7 @@ impl BackendStorage for MetalStorage {
.device
.new_buffer(dst_el, self.dtype, "upsample_nearest2d")?;
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&self.buffer, inp_l, self.dtype);
candle_metal_kernels::call_upsample_nearest_2d(
&self.device.device,
&command_buffer,
@ -1018,8 +1043,7 @@ impl BackendStorage for MetalStorage {
strides,
out_w,
out_h,
&self.buffer,
inp_l.start_offset() * self.dtype.size_in_bytes(),
src,
&buffer,
)
.map_err(MetalError::from)?;
@ -1027,9 +1051,8 @@ impl BackendStorage for MetalStorage {
}
fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
let (ids_o1, _) = match ids_l.contiguous_offsets() {
Some(o12) => o12,
None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?,
if !ids_l.is_contiguous() {
return Err(crate::Error::RequiresContiguous { op: "gather" }.bt());
};
let ids_el = ids_l.dims()[dim];
let dst_el = ids_l.shape().elem_count();
@ -1039,9 +1062,12 @@ impl BackendStorage for MetalStorage {
let name = match (ids.dtype, self.dtype) {
(DType::U32, DType::F32) => "gather_u32_f32",
(DType::U32, DType::F16) => "gather_u32_f16",
(DType::U32, DType::BF16) => "gather_u32_bf16",
(left, right) => crate::bail!("Metal gather {left:?} {right:?} not implemented"),
};
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&self.buffer, src_l, dtype);
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
candle_metal_kernels::call_gather(
&device.device,
&command_buffer,
@ -1050,10 +1076,8 @@ impl BackendStorage for MetalStorage {
src_l.dims(),
ids_el,
dim,
&self.buffer,
src_l.start_offset() * dtype.size_in_bytes(),
&ids.buffer,
ids_o1 * ids.dtype.size_in_bytes(),
src,
ids,
&buffer,
)
.map_err(MetalError::from)?;
@ -1071,13 +1095,8 @@ impl BackendStorage for MetalStorage {
) -> Result<Self> {
let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
self.copy_strided_src(&mut acc, 0, l)?;
let (ids_offset, _) = match ids_l.contiguous_offsets() {
Some(o12) => o12,
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
};
let src_offset = match src_l.contiguous_offsets() {
Some((o1, _)) => o1,
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
if !ids_l.is_contiguous() || !src_l.is_contiguous() {
return Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt());
};
let name = match (ids.dtype, self.dtype) {
(DType::U8, DType::F32) => "sa_u8_f32",
@ -1096,6 +1115,8 @@ impl BackendStorage for MetalStorage {
})?,
};
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&src.buffer, src_l, src.dtype);
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
candle_metal_kernels::call_scatter_add(
&self.device.device,
&command_buffer,
@ -1104,10 +1125,8 @@ impl BackendStorage for MetalStorage {
src_l.dims(),
l.dims(),
dim,
&src.buffer,
src_offset * src.dtype.size_in_bytes(),
&ids.buffer,
ids_offset * ids.dtype.size_in_bytes(),
src,
ids,
&acc.buffer,
)
.map_err(MetalError::from)?;
@ -1143,6 +1162,8 @@ impl BackendStorage for MetalStorage {
}
};
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&self.buffer, src_l, dtype);
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
candle_metal_kernels::call_index_select(
&device.device,
&command_buffer,
@ -1154,10 +1175,8 @@ impl BackendStorage for MetalStorage {
src_l.is_contiguous(),
src_l.dims(),
src_l.stride(),
&self.buffer,
src_l.start_offset() * dtype.size_in_bytes(),
&ids.buffer,
ids_l.start_offset() * ids.dtype.size_in_bytes(),
src,
ids,
&buffer,
)
.map_err(MetalError::from)?;
@ -1175,13 +1194,8 @@ impl BackendStorage for MetalStorage {
) -> Result<Self> {
let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
self.copy_strided_src(&mut acc, 0, l)?;
let (ids_offset, _) = match ids_l.contiguous_offsets() {
Some(o12) => o12,
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
};
let src_offset = match src_l.contiguous_offsets() {
Some((o1, _)) => o1,
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
if !ids_l.is_contiguous() || !src_l.is_contiguous() {
return Err(crate::Error::RequiresContiguous { op: "index-add" }.bt());
};
let name = match (ids.dtype, self.dtype) {
(DType::I64, DType::BF16) => "ia_i64_bf16",
@ -1212,6 +1226,8 @@ impl BackendStorage for MetalStorage {
})?,
};
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&src.buffer, src_l, src.dtype);
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
candle_metal_kernels::call_index_add(
&self.device.device,
&command_buffer,
@ -1221,10 +1237,8 @@ impl BackendStorage for MetalStorage {
l.dims(),
ids_l.dims(),
dim,
&src.buffer,
src_offset * src.dtype.size_in_bytes(),
&ids.buffer,
ids_offset * ids.dtype.size_in_bytes(),
src,
ids,
&acc.buffer,
)
.map_err(MetalError::from)?;
@ -1358,17 +1372,20 @@ impl BackendStorage for MetalStorage {
DType::U8 => candle_metal_kernels::unary::strided::copy::U8,
dtype => crate::bail!("Metal copy_strided {dtype:?} not implemented"),
};
let src = buffer_o(&self.buffer, src_l, self.dtype);
let dst = BufferOffset {
buffer: &dst.buffer,
offset_in_bytes: dst_offset * dst.dtype.size_in_bytes(),
};
candle_metal_kernels::call_unary_strided(
&self.device.device,
&command_buffer,
&self.device.kernels,
kernel_name,
src_l.dims(),
&self.buffer,
src,
src_l.stride(),
src_l.start_offset() * self.dtype.size_in_bytes(),
&dst.buffer,
dst_offset * dst.dtype.size_in_bytes(),
dst,
)
.map_err(MetalError::from)?;
command_buffer.set_label("copy_strided");
@ -1402,10 +1419,9 @@ impl MetalStorage {
let shape = lhs_l.shape();
let el_count = shape.elem_count();
let command_buffer = device.command_buffer()?;
let (buffer, dtype) = if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
&& (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
&& &op[..1] != "b"
{
let lhs = buffer_o(&self.buffer, lhs_l, self.dtype);
let rhs = buffer_o(&rhs.buffer, rhs_l, rhs.dtype);
let (buffer, dtype) = if lhs_l.is_contiguous() && rhs_l.is_contiguous() && &op[..1] != "b" {
use candle_metal_kernels::binary::contiguous;
let (kernel_name, dtype) = match (op, self.dtype) {
@ -1486,8 +1502,8 @@ impl MetalStorage {
&device.kernels,
kernel_name,
el_count,
&self.buffer,
&rhs.buffer,
lhs,
rhs,
&buffer,
)
.map_err(MetalError::from)?;
@ -1585,12 +1601,10 @@ impl MetalStorage {
&device.kernels,
kernel_name,
lhs_l.dims(),
&self.buffer,
lhs,
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&rhs.buffer,
rhs,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;
@ -1796,6 +1810,10 @@ impl BackendDevice for MetalDevice {
Ok(())
}
fn synchronize(&self) -> Result<()> {
self.wait_until_completed()
}
}
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {

View File

@ -40,6 +40,7 @@ fn quantize_q8_1(
src: &CudaView<f32>,
dst: &mut CudaSlice<u8>,
elem_count: usize,
ky: usize,
dev: &CudaDevice,
) -> Result<()> {
use cudarc::driver::LaunchAsync;
@ -49,7 +50,7 @@ fn quantize_q8_1(
let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE);
let func = dev.get_or_load_func("quantize_q8_1", candle_kernels::QUANTIZED)?;
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks as u32, 1, 1),
grid_dim: (num_blocks as u32, ky as u32, 1),
block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
shared_mem_bytes: 0,
};
@ -165,6 +166,7 @@ fn mul_mat_vec_via_q8_1(
dtype: GgmlDType,
ncols: usize,
nrows: usize,
b_size: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
@ -173,14 +175,18 @@ fn mul_mat_vec_via_q8_1(
if data_elems < ncols * nrows {
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
}
if y.len() != ncols {
if y.len() != ncols * b_size {
crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
}
if b_size == 0 || b_size > 8 {
crate::bail!("only bsize between 1 and 8 are supported, got {b_size}")
}
// Start by quantizing y
let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
let y_size_in_bytes = ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let y_size_in_bytes =
b_size * ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
quantize_q8_1(y, &mut y_q8_1, ncols, dev)?;
quantize_q8_1(y, &mut y_q8_1, ncols, b_size, dev)?;
let kernel_name = match dtype {
GgmlDType::Q4_0 => "mul_mat_vec_q4_0_q8_1_cuda",
@ -195,11 +201,19 @@ fn mul_mat_vec_via_q8_1(
GgmlDType::Q6K => "mul_mat_vec_q6_K_q8_1_cuda",
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
let kernel_name = format!("{kernel_name}{b_size}");
let func = dev.get_or_load_func(&kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
// https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
let (nblocks, nwarps) = match b_size {
1 => (nrows as u32, 4),
2..=4 => ((nrows as u32 + 1) / 2, 4),
5..=8 => ((nrows as u32 + 1) / 2, 2),
_ => crate::bail!("unexpected bsize {b_size}"),
};
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (nrows as u32, 1, 1),
block_dim: (WARP_SIZE as u32, 4, 1),
grid_dim: (nblocks, 1, 1),
block_dim: (WARP_SIZE as u32, nwarps, 1),
shared_mem_bytes: 0,
};
@ -209,13 +223,83 @@ fn mul_mat_vec_via_q8_1(
&dst,
/* ncols_x */ ncols as i32,
/* nrows_x */ nrows as i32,
/* nrows_y */ ncols as i32,
/* nrows_y */ ncols_padded as i32,
/* nrows_dst */ nrows as i32,
);
unsafe { func.launch(cfg, params) }.w()?;
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
#[allow(clippy::too_many_arguments)]
fn mul_mat_via_q8_1(
data: &CudaSlice<u8>,
y: &CudaView<f32>,
dtype: GgmlDType,
x_rows: usize,
x_cols: usize,
y_rows: usize,
y_cols: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let data_elems = data.len() / dtype.type_size() * dtype.block_size();
if data_elems < x_rows * x_cols {
crate::bail!("unexpected lhs size {}, {x_rows} {x_cols}", data_elems)
}
if y.len() != y_rows * y_cols {
crate::bail!("unexpected y size {}, {y_rows} {y_cols}", y.len())
}
if x_cols != y_rows {
crate::bail!("unexpected x/y size {x_rows} {x_cols} {y_rows} {y_cols}")
}
let k = x_cols;
// Start by quantizing y
let k_padded = pad(k, MATRIX_ROW_PADDING);
let y_size_in_bytes =
k_padded * y_rows * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
quantize_q8_1(y, &mut y_q8_1, k, y_cols, dev)?;
let (kernel_name, mmq_x, mmq_y) = match dtype {
GgmlDType::Q4_0 => ("mul_mat_q4_0", 64, 128),
GgmlDType::Q4_1 => ("mul_mat_q4_1", 64, 128),
GgmlDType::Q5_0 => ("mul_mat_q5_0", 128, 64),
GgmlDType::Q5_1 => ("mul_mat_q5_1", 128, 64),
GgmlDType::Q8_0 => ("mul_mat_q8_0", 128, 64),
GgmlDType::Q2K => ("mul_mat_q2_K", 64, 128),
GgmlDType::Q3K => ("mul_mat_q3_K", 128, 128),
GgmlDType::Q4K => ("mul_mat_q4_K", 64, 128),
GgmlDType::Q5K => ("mul_mat_q5_K", 64, 128),
GgmlDType::Q6K => ("mul_mat_q6_K", 64, 64),
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols).w()? };
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (
ceil_div(x_rows, mmq_y) as u32,
ceil_div(y_cols, mmq_x) as u32,
1,
),
block_dim: (WARP_SIZE as u32, 4, 1),
shared_mem_bytes: 0,
};
let params = (
/* vx */ data,
/* vy */ &y_q8_1,
/* dst */ &dst,
/* ncols_x */ x_cols as i32,
/* nrows_x */ x_rows as i32,
/* ncols_y */ y_cols as i32,
/* nrows_y */ k_padded as i32,
/* nrows_dst */ x_rows as i32,
);
unsafe { func.launch(cfg, params) }.w()?;
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
impl QCudaStorage {
pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> {
let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size();
@ -313,7 +397,17 @@ impl QCudaStorage {
storage: &CudaStorage,
layout: &crate::Layout,
) -> Result<(CudaStorage, crate::Shape)> {
if matches!(layout.shape().dims(), [1, 1, _] | [1, _]) {
let max_bm = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
1
} else {
8
};
let use_vec_kernel = match layout.shape().dims() {
[b, m, _k] => b * m <= max_bm,
[b, _k] => *b <= max_bm,
_ => false,
};
if use_vec_kernel {
self.dequantize_matmul_vec(self_shape, storage, layout)
} else {
self.dequantize_matmul(self_shape, storage, layout)
@ -334,25 +428,31 @@ impl QCudaStorage {
Some((o1, o2)) => rhs.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "dmmv" }.bt())?,
};
let (with_batch, k) = match rhs_l.shape().dims() {
[1, 1, k] => (true, k),
[1, k] => (false, k),
let (b_size, k) = match rhs_l.shape().dims() {
[b, m, k] => (b * m, *k),
[b, k] => (*b, *k),
_ => crate::bail!("unexpected rhs shape in dmmv {:?}", rhs_l.shape()),
};
if ncols != *k {
if ncols != k {
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape())
}
let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
dequantize_mul_mat_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
} else {
mul_mat_vec_via_q8_1(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
};
let out_shape = if with_batch {
vec![1, 1, nrows]
} else {
vec![1, nrows]
mul_mat_vec_via_q8_1(
&self.data,
&rhs,
self.dtype,
ncols,
nrows,
b_size,
self.device(),
)?
};
let mut out_shape = rhs_l.shape().dims().to_vec();
out_shape.pop();
out_shape.push(nrows);
Ok((out, out_shape.into()))
}
@ -373,9 +473,30 @@ impl QCudaStorage {
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", layout.shape())
}
let data_f32 = self.dequantize(n * k)?;
let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?;
let out = storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?;
let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
let data_f32 = self.dequantize(n * k)?;
let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?;
storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?
} else {
let storage = storage.as_cuda_slice::<f32>()?;
let storage = match layout.contiguous_offsets() {
Some((o1, o2)) => storage.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous {
op: "quantized-matmul",
}
.bt())?,
};
mul_mat_via_q8_1(
&self.data,
&storage,
self.dtype,
/* x_rows */ n,
/* x_cols */ k,
/* y_rows */ k,
/* y_cols */ b * m,
self.device(),
)?
};
let mut out_shape = layout.shape().dims().to_vec();
out_shape.pop();
out_shape.push(n);
@ -412,7 +533,7 @@ mod test {
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
let y = dev.htod_sync_copy(&vs).w()?;
quantize_q8_1(&y.slice(..), &mut y_q8_1, el, &dev)?;
quantize_q8_1(&y.slice(..), &mut y_q8_1, el, 1, &dev)?;
Ok(())
}
@ -430,6 +551,7 @@ mod test {
/* dtype */ GgmlDType::Q4_0,
/* ncols */ ncols,
/* nrows */ 1,
/* b_size */ 1,
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
@ -453,4 +575,44 @@ mod test {
assert_eq!(vs[0], 5561851.0);
Ok(())
}
#[test]
fn cuda_mm_q8_1() -> Result<()> {
let dev = CudaDevice::new(0)?;
let ncols = 256;
let vs: Vec<f32> = (0..ncols * 4).map(|v| v as f32 / 4.).collect();
let y = dev.htod_sync_copy(&vs).w()?;
let mut xs = QCudaStorage::zeros(&dev, ncols * 4, GgmlDType::Q4_0)?;
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_via_q8_1(
&xs.data,
&y.slice(..),
/* dtype */ GgmlDType::Q4_0,
/* x_rows */ 4,
/* x_cols */ ncols,
/* y_rows */ ncols,
/* y_cols */ 4,
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
/*
x = torch.tensor([float(v) for v in range(1024)]).reshape(4, 256)
x @ x.t() / 16
tensor([[ 347480.0000, 869720.0000, 1391960.0000, 1914200.0000],
[ 869720.0000, 2440536.0000, 4011352.0000, 5582166.5000],
[ 1391960.0000, 4011352.0000, 6630742.0000, 9250132.0000],
[ 1914200.0000, 5582166.5000, 9250132.0000, 12918099.0000]])
*/
assert_eq!(vs.len(), 16);
assert_eq!(vs[0], 347604.0);
assert_eq!(vs[1], 888153.06);
assert_eq!(vs[4], 869780.7);
assert_eq!(vs[5], 2483145.0);
assert_eq!(vs[11], 9407368.0);
assert_eq!(vs[14], 9470856.0);
assert_eq!(vs[15], 13138824.0);
Ok(())
}
}

View File

@ -149,8 +149,11 @@ impl QMetalStorage {
let (n, k) = self_shape.dims2()?;
let mut dst_shape = src_shape.dims().to_vec();
// We always use a single batch dimension and stack all the tensors in the batch on the
// second dimension as the implementation in candle-metal-kernels doesn't handle batch
// properly.
let (b, m) = match dst_shape.len() {
3 => (dst_shape[0], dst_shape[1]),
3 => (1, dst_shape[0] * dst_shape[1]),
2 => (1, dst_shape[0]),
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
};

View File

@ -79,6 +79,9 @@ macro_rules! unary_op {
($fn_name:ident, $op_name:ident) => {
pub fn $fn_name(&self) -> Result<Self> {
let shape = self.shape();
if shape.elem_count() == 0 {
return Ok(self.clone());
}
let storage = self
.storage()
.unary_impl::<crate::op::$op_name>(self.layout())?;
@ -92,6 +95,9 @@ macro_rules! binary_op {
($fn_name:ident, $op_name:ident) => {
pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
if shape.elem_count() == 0 {
return Ok(self.clone());
}
let storage = self.storage().binary_impl::<crate::op::$op_name>(
&*rhs.storage(),
self.layout(),
@ -114,6 +120,9 @@ macro_rules! binary_op_scalar {
.broadcast_as(self.shape())?,
};
let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?;
if self.elem_count() == 0 {
return Ok(self.clone());
}
let storage = self.storage().binary_impl::<crate::op::$op_name>(
&*rhs.storage(),
self.layout(),
@ -646,6 +655,9 @@ impl Tensor {
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
if self.elem_count() == 0 {
return Ok(self.clone());
}
let storage = self.storage().affine(self.layout(), mul, add)?;
let op = BackpropOp::new1(self, |arg| Op::Affine { arg, mul, add });
Ok(from_storage(storage, self.shape(), op, false))
@ -653,6 +665,9 @@ impl Tensor {
/// Applies the Exponential Linear Unit (ELU) function on each element of the input tensor.
pub fn elu(&self, alpha: f64) -> Result<Self> {
if self.elem_count() == 0 {
return Ok(self.clone());
}
let storage = self.storage().elu(self.layout(), alpha)?;
let op = BackpropOp::new1(self, |t| Op::Elu(t, alpha));
Ok(from_storage(storage, self.shape(), op, false))
@ -660,6 +675,9 @@ impl Tensor {
/// Raise the tensor to some float exponent `e`.
pub fn powf(&self, e: f64) -> Result<Self> {
if self.elem_count() == 0 {
return Ok(self.clone());
}
let storage = self.storage().powf(self.layout(), e)?;
let op = BackpropOp::new1(self, |t| Op::Powf(t, e));
Ok(from_storage(storage, self.shape(), op, false))
@ -1154,6 +1172,9 @@ impl Tensor {
let n = b_dims[dim - 1];
let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
if c_shape.elem_count() == 0 || k == 0 {
return Tensor::zeros(c_shape, self.dtype(), self.device());
}
let batching: usize = a_dims[..dim - 2].iter().product();
let batching_b: usize = b_dims[..dim - 2].iter().product();
if k != k2 || batching != batching_b {

View File

@ -3,7 +3,7 @@ use candle_core::{
quantized::{self, GgmlDType},
test_device,
test_utils::to_vec2_round,
Device, Module, Result, Tensor,
Device, IndexOp, Module, Result, Tensor,
};
use quantized::{k_quants, GgmlType};
use rand::prelude::*;
@ -47,18 +47,14 @@ fn test_matmul(
}
fn quantized_matmul(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let (m, k, n) = (3, 64, 4);
let lhs = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
let lhs_s = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
let lhs = Tensor::from_slice(&lhs_s, (m, k), device)?;
let mut dst = vec![42.; 3 * 4];
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
let rhs = (0..(k * n)).map(|v| v as f32).collect::<Vec<_>>();
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
k_quants::matmul((m, k, n), &lhs_s, &rhs_t, &mut dst)?;
assert_eq!(
dst.iter().map(|x| x.round()).collect::<Vec<_>>(),
&[
@ -67,7 +63,7 @@ fn quantized_matmul(device: &Device) -> Result<()> {
]
);
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
let mm = tensor_lhs.matmul(&tensor_rhs)?;
let mm = lhs.matmul(&tensor_rhs)?;
assert_eq!(
mm.to_vec2::<f32>()?,
&[
@ -79,7 +75,7 @@ fn quantized_matmul(device: &Device) -> Result<()> {
let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
let res = matmul.forward(&tensor_lhs)?;
let res = matmul.forward(&lhs)?;
match device {
Device::Metal(_) => assert_eq!(
to_vec2_round(&res, 0)?,
@ -89,7 +85,15 @@ fn quantized_matmul(device: &Device) -> Result<()> {
[341970.0, 994574.0, 1656181.0, 2302182.0]
]
),
_ => assert_eq!(
Device::Cuda(_) => assert_eq!(
to_vec2_round(&res, 0)?,
&[
[84866.0, 214045.0, 344676.0, 473707.0],
[213425.0, 604313.0, 1000431.0, 1387960.0],
[342030.0, 994630.0, 1656248.0, 2302250.0]
]
),
Device::Cpu => assert_eq!(
to_vec2_round(&res, 0)?,
&[
[85120.0, 214562.0, 345455.0, 474748.0],
@ -98,22 +102,16 @@ fn quantized_matmul(device: &Device) -> Result<()> {
]
),
}
test_matmul(device, (1, 3, 4, 256), GgmlDType::Q4_0)?;
Ok(())
}
fn quantized_matmul_neg(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let (m, k, n) = (3, 64, 4);
let lhs = (0..(m * k))
let lhs_s = (0..(m * k))
.map(|v| v as f32 - (m * k) as f32 / 2.0)
.collect::<Vec<_>>();
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
let lhs = Tensor::from_slice(&lhs_s, (m, k), device)?;
let mut dst = vec![42.; 3 * 4];
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
let rhs = (0..k * n)
@ -121,7 +119,7 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
.collect::<Vec<_>>();
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
k_quants::matmul((m, k, n), &lhs_s, &rhs_t, &mut dst)?;
assert_eq!(
dst.iter().map(|x| x.round()).collect::<Vec<_>>(),
&[
@ -129,7 +127,7 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
-196472.0, 63012.0, 324585.0, 587902.0
]
);
let mm = tensor_lhs.matmul(&tensor_rhs)?;
let mm = lhs.matmul(&tensor_rhs)?;
assert_eq!(
to_vec2_round(&mm, 0)?,
&[
@ -141,7 +139,7 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
let res = matmul.forward(&tensor_lhs)?;
let res = matmul.forward(&lhs)?;
match device {
Device::Metal(_) => assert_eq!(
to_vec2_round(&res, 0)?,
@ -151,7 +149,15 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
[-196102.0, 63022.0, 324233.0, 587191.0]
]
),
_ => assert_eq!(
Device::Cuda(_) => assert_eq!(
to_vec2_round(&res, 0)?,
&[
[243740.0, -19762.0, -285476.0, -550498.0],
[23774.0, 21645.0, 19395.0, 18364.0],
[-196045.0, 63030.0, 324120.0, 587079.0]
]
),
Device::Cpu => assert_eq!(
to_vec2_round(&res, 0)?,
&[
[243524.0, -19596.0, -285051.0, -549815.0],
@ -160,22 +166,58 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
]
),
}
let lhs2 = Tensor::stack(&[&lhs, &lhs], 0)?;
let res2 = matmul.forward(&lhs2)?;
let res2 = res2.i(1)?;
let diff = (res - res2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
if device.is_cuda() {
assert!(diff < 0.1);
} else {
assert_eq!(diff, 0.);
}
Ok(())
}
test_device!(
quantized_matmul,
quantized_matmul_cpu,
quantized_matmul_cuda,
quantized_matmul_metal
);
test_device!(
quantized_matmul_neg,
quantized_matmul_neg_cpu,
quantized_matmul_neg_cuda,
quantized_matmul_neg_metal
);
fn qmm_batch(dev: &Device) -> Result<()> {
let (lhs, rhs, _mm) = get_random_tensors(2, 256, 6, dev)?;
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q2K)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?;
assert_eq!(mm.shape().dims(), [2, 6]);
let lhs2 = Tensor::cat(&[&lhs, &lhs], 0)?;
let mm2 = rhs.forward(&lhs2)?;
assert_eq!(mm2.shape().dims(), [4, 6]);
let diff2 = (mm2.i(2..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff2, 0.0);
let lhs3 = Tensor::cat(&[&lhs2, &lhs], 0)?;
let mm3 = rhs.forward(&lhs3)?;
assert_eq!(mm3.shape().dims(), [6, 6]);
let diff3 = (mm3.i(2..4)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff3, 0.0);
let diff3 = (mm3.i(4..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff3, 0.0);
let lhs4 = Tensor::cat(&[&lhs3, &lhs3], 0)?;
let mm4 = rhs.forward(&lhs4)?;
assert_eq!(mm4.shape().dims(), [12, 6]);
let diff4 = (mm4.i(..6)? - &mm3)?.abs()?.sum_all()?.to_vec0::<f32>()?;
if dev.is_cuda() {
// We use a different kernel for sizes from 1 to 8 on cuda which explains
// the difference here.
assert!(0. < diff4 && diff4 < 1e-4)
} else {
assert_eq!(diff4, 0.0)
};
let diff4 = (mm4.i(6..)? - &mm4.i(..6)?)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff4, 0.0);
Ok(())
}
test_device!(quantized_matmul, qmm_cpu, qmm_cuda, qmm_metal);
test_device!(quantized_matmul_neg, qmm_n_cpu, qmm_n_cuda, qmm_n_metal);
test_device!(qmm_batch, qmm_b_cpu, qmm_b_cuda, qmm_b_metal);
fn quantize_q4_0(device: &Device) -> Result<()> {
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();

View File

@ -1083,6 +1083,27 @@ fn randn(device: &Device) -> Result<()> {
Ok(())
}
fn zero_dim(device: &Device) -> Result<()> {
let t = Tensor::zeros((4, 0, 1), DType::F32, device)?;
assert_eq!(t.dims3()?, (4, 0, 1));
let t2 = Tensor::zeros((4, 3, 1), DType::F32, device)?;
let t_cat = Tensor::cat(&[&t, &t2], 1)?;
assert_eq!(t_cat.dims3()?, (4, 3, 1));
let t_cat = Tensor::cat(&[&t, &t], 1)?;
assert_eq!(t_cat.dims3()?, (4, 0, 1));
let t_unary = t.sqrt()?;
assert_eq!(t_unary.dims3()?, (4, 0, 1));
let t_plus = (&t + 1.)?;
assert_eq!(t_plus.dims3()?, (4, 0, 1));
let t_mm = t2.matmul(&t.t()?)?;
assert_eq!(t_mm.dims3()?, (4, 3, 0));
let t_mm = t.matmul(&t2.t()?)?;
assert_eq!(t_mm.dims3()?, (4, 0, 3));
let t_mm = t.t()?.matmul(&t)?;
assert_eq!(t_mm.dims3()?, (4, 1, 1));
Ok(())
}
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
test_device!(full, full_cpu, full_gpu, full_metal);
@ -1131,6 +1152,7 @@ test_device!(
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
test_device!(var, var_cpu, var_gpu, var_metal);
test_device!(zero_dim, zero_dim_cpu, zero_dim_gpu, zero_dim_metal);
// There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381

View File

@ -16,6 +16,30 @@ use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "2b")]
Base2B,
#[value(name = "7b")]
Base7B,
#[value(name = "2b-it")]
Instruct2B,
#[value(name = "7b-it")]
Instruct7B,
#[value(name = "1.1-2b-it")]
InstructV1_1_2B,
#[value(name = "1.1-7b-it")]
InstructV1_1_7B,
#[value(name = "code-2b")]
CodeBase2B,
#[value(name = "code-7b")]
CodeBase7B,
#[value(name = "code-2b-it")]
CodeInstruct2B,
#[value(name = "code-7b-it")]
CodeInstruct7B,
}
struct TextGeneration {
model: Model,
device: Device,
@ -165,6 +189,10 @@ struct Args {
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
/// The model to use.
#[arg(long, default_value = "2b")]
which: Which,
}
fn main() -> Result<()> {
@ -196,14 +224,19 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match &args.model_id {
Some(model_id) => match model_id.as_str() {
"7b-it" => "google/gemma-7b-it".to_string(),
"7b" => "google/gemma-7b".to_string(),
"2b-it" => "google/gemma-2b-it".to_string(),
"2b" => "google/gemma-2b".to_string(),
_ => model_id.to_string(),
Some(model_id) => model_id.to_string(),
None => match args.which {
Which::InstructV1_1_2B => "google/gemma-1.1-2b-it".to_string(),
Which::InstructV1_1_7B => "google/gemma-1.1-7b-it".to_string(),
Which::Base2B => "google/gemma-2b".to_string(),
Which::Base7B => "google/gemma-7b".to_string(),
Which::Instruct2B => "google/gemma-2b-it".to_string(),
Which::Instruct7B => "google/gemma-7b-it".to_string(),
Which::CodeBase2B => "google/codegemma-2b".to_string(),
Which::CodeBase7B => "google/codegemma-7b".to_string(),
Which::CodeInstruct2B => "google/codegemma-2b-it".to_string(),
Which::CodeInstruct7B => "google/codegemma-7b-it".to_string(),
},
None => "google/gemma-2b".to_string(),
};
let repo = api.repo(Repo::with_revision(
model_id,

View File

@ -31,6 +31,8 @@ const DEFAULT_PROMPT: &str = "My favorite theorem is ";
enum Which {
V1,
V2,
V3,
V3Instruct,
#[value(name = "solar-10.7b")]
Solar10_7B,
#[value(name = "tiny-llama-1.1b-chat")]
@ -45,8 +47,8 @@ struct Args {
cpu: bool,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
#[arg(long, default_value_t = 0.8)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
@ -90,11 +92,11 @@ struct Args {
use_flash_attn: bool,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.0)]
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
#[arg(long, default_value_t = 128)]
repeat_last_n: usize,
}
@ -118,13 +120,18 @@ fn main() -> Result<()> {
Some("bf16") => DType::BF16,
Some("f32") => DType::F32,
Some(dtype) => bail!("Unsupported dtype {dtype}"),
None => DType::F16,
None => match args.which {
Which::V3 | Which::V3Instruct => DType::BF16,
Which::V1 | Which::V2 | Which::Solar10_7B | Which::TinyLlama1_1BChat => DType::F16,
},
};
let (llama, tokenizer_filename, mut cache) = {
let (llama, tokenizer_filename, mut cache, config) = {
let api = Api::new()?;
let model_id = args.model_id.unwrap_or_else(|| match args.which {
Which::V1 => "Narsil/amall-7b".to_string(),
Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
Which::V3 => "meta-llama/Meta-Llama-3-8B".to_string(),
Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(),
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
});
@ -138,7 +145,7 @@ fn main() -> Result<()> {
let config = config.into_config(args.use_flash_attn);
let filenames = match args.which {
Which::V1 | Which::V2 | Which::Solar10_7B => {
Which::V1 | Which::V2 | Which::V3 | Which::V3Instruct | Which::Solar10_7B => {
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
}
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
@ -146,10 +153,12 @@ fn main() -> Result<()> {
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
(Llama::load(vb, &config)?, tokenizer_filename, cache)
(Llama::load(vb, &config)?, tokenizer_filename, cache, config)
};
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
let eos_token_id = config
.eos_token_id
.or_else(|| tokenizer.token_to_id(EOS_TOKEN));
let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
let mut tokens = tokenizer
.encode(prompt, true)
@ -160,7 +169,7 @@ fn main() -> Result<()> {
println!("starting the inference loop");
print!("{prompt}");
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), args.top_p);
let start_gen = std::time::Instant::now();
let mut index_pos = 0;
let mut token_generated = 0;

View File

@ -54,6 +54,7 @@ impl TextGeneration {
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let dtype = self.model.dtype();
let mut tokens = self
.tokenizer
.tokenizer()
@ -66,7 +67,7 @@ impl TextGeneration {
Some(token) => token,
None => anyhow::bail!("cannot find the </s> token"),
};
let mut state = State::new(1, &self.config, &self.device)?;
let mut state = State::new(1, &self.config, dtype, &self.device)?;
let mut next_logits = None;
for &t in tokens.iter() {
let input = Tensor::new(&[t], &self.device)?;
@ -84,7 +85,7 @@ impl TextGeneration {
Some(logits) => logits,
None => anyhow::bail!("cannot work on an empty prompt"),
};
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
let logits = logits.squeeze(0)?.to_dtype(dtype)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
@ -210,6 +211,9 @@ struct Args {
#[arg(long)]
config_file: Option<String>,
#[arg(long, default_value = "f32")]
dtype: String,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
@ -220,6 +224,7 @@ struct Args {
}
fn main() -> Result<()> {
use std::str::FromStr;
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
@ -279,7 +284,8 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
let dtype = DType::from_str(&args.dtype)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb.pp("backbone"))?;
println!("loaded the model in {:?}", start.elapsed());

View File

@ -123,7 +123,7 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
if next_token == eos_token || tokens.ends_with(&[27, 10619, 29] /* <END> */) {
break;
}
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;

View File

@ -67,6 +67,8 @@ enum Which {
Mixtral,
#[value(name = "mixtral-instruct")]
MixtralInstruct,
#[value(name = "llama3-8b")]
L8b,
}
impl Which {
@ -82,7 +84,8 @@ impl Which {
| Self::L13bCode
| Self::L34bCode
| Self::Leo7b
| Self::Leo13b => false,
| Self::Leo13b
| Self::L8b => false,
// Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
// same way. Starling is a fine tuned version of OpenChat.
Self::OpenChat35
@ -116,7 +119,8 @@ impl Which {
| Self::Mistral7bInstruct
| Self::Mistral7bInstructV02
| Self::OpenChat35
| Self::Starling7bAlpha => false,
| Self::Starling7bAlpha
| Self::L8b => false,
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
}
}
@ -140,7 +144,8 @@ impl Which {
| Self::Mistral7bInstruct
| Self::Mistral7bInstructV02
| Self::Zephyr7bAlpha
| Self::Zephyr7bBeta => false,
| Self::Zephyr7bBeta
| Self::L8b => false,
Self::OpenChat35 | Self::Starling7bAlpha => true,
}
}
@ -167,6 +172,7 @@ impl Which {
| Which::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
Which::OpenChat35 => "openchat/openchat_3.5",
Which::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
Self::L8b => "meta-llama/Meta-Llama-3-8B",
}
}
}
@ -322,6 +328,11 @@ impl Args {
"TheBloke/Starling-LM-7B-alpha-GGUF",
"starling-lm-7b-alpha.Q4_K_M.gguf",
),
// TODO: swap to TheBloke model when available
Which::L8b => (
"QuantFactory/Meta-Llama-3-8B-GGUF",
"Meta-Llama-3-8B.Q4_K_S.gguf",
),
};
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(repo.to_string());
@ -420,7 +431,8 @@ fn main() -> anyhow::Result<()> {
| Which::L13bCode
| Which::L34bCode
| Which::Leo7b
| Which::Leo13b => 1,
| Which::Leo13b
| Which::L8b => 1,
Which::Mixtral
| Which::MixtralInstruct
| Which::Mistral7b
@ -537,11 +549,14 @@ fn main() -> anyhow::Result<()> {
std::io::stdout().flush()?;
}
let eos_token = if args.which.is_open_chat() {
"<|end_of_turn|>"
} else {
"</s>"
let eos_token = match args.which {
Which::L8b => "<|end_of_text|>",
_ => match args.which.is_open_chat() {
true => "<|end_of_turn|>",
false => "</s>",
},
};
let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap();
let start_post_prompt = std::time::Instant::now();
let mut sampled = 0;

View File

@ -0,0 +1,9 @@
# candle-recurrent-gemma
This model card corresponds to the 2B base version of the RecurrentGemma model
[huggingface model card](https://huggingface.co/google/recurrentgemma-2b).
```bash
cargo run --features cuda -r --example recurrent-gemma -- \
--prompt "Write me a poem about Machine Learning."
```

View File

@ -0,0 +1,321 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::quantized_recurrent_gemma::Model as QModel;
use candle_transformers::models::recurrent_gemma::{Config, Model as BModel};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
enum Model {
B(BModel),
Q(QModel),
}
impl Model {
fn forward(&mut self, xs: &Tensor, pos: usize) -> candle::Result<Tensor> {
match self {
Self::B(m) => m.forward(xs, pos),
Self::Q(m) => m.forward(xs, pos),
}
}
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "2b")]
Base2B,
#[value(name = "2b-it")]
Instruct2B,
}
struct TextGeneration {
model: Model,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
top_k: usize,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let sampling = match temp {
None => candle_transformers::generation::Sampling::ArgMax,
Some(temperature) => match top_p {
None => candle_transformers::generation::Sampling::TopK {
temperature,
k: top_k,
},
Some(top_p) => candle_transformers::generation::Sampling::TopKThenTopP {
temperature,
k: top_k,
p: top_p,
},
},
};
let logits_processor = LogitsProcessor::from_sampling(seed, sampling);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<eos>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <eos> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
#[arg(long, default_value_t = 250)]
top_k: usize,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 8000)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
config_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
/// The model to use.
#[arg(long, default_value = "2b")]
which: Which,
#[arg(long)]
quantized: bool,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match &args.model_id {
Some(model_id) => model_id.to_string(),
None => match args.which {
Which::Base2B => "google/recurrentgemma-2b".to_string(),
Which::Instruct2B => "google/recurrentgemma-2b-it".to_string(),
},
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let config_filename = match args.config_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("config.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => {
if args.quantized {
let filename = match args.which {
Which::Base2B => "recurrent-gemma-2b-q4k.gguf",
Which::Instruct2B => "recurrent-gemma-7b-q4k.gguf",
};
let filename = api.model("lmz/candle-gemma".to_string()).get(filename)?;
vec![filename]
} else {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
}
}
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let model = if args.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
&filenames[0],
&device,
)?;
Model::Q(QModel::new(&config, vb.pp("model"))?)
} else {
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
Model::B(BModel::new(&config, vb.pp("model"))?)
};
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.top_k,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -46,7 +46,8 @@ The default scheduler for the XL Turbo version is the Euler Ancestral scheduler.
- `--cpu`: use the cpu rather than the gpu (much slower).
- `--height`, `--width`: set the height and width for the generated image.
- `--n-steps`: the number of steps to be used in the diffusion process.
- `--num-samples`: the number of samples to generate.
- `--num-samples`: the number of samples to generate iteratively.
- `--bsize`: the numbers of samples to generate simultaneously.
- `--final-image`: the filename for the generated image(s).
### Using flash-attention

View File

@ -9,6 +9,7 @@ use candle_transformers::models::stable_diffusion;
use anyhow::{Error as E, Result};
use candle::{DType, Device, IndexOp, Module, Tensor, D};
use clap::Parser;
use stable_diffusion::vae::AutoEncoderKL;
use tokenizers::Tokenizer;
#[derive(Parser)]
@ -64,9 +65,13 @@ struct Args {
#[arg(long)]
n_steps: Option<usize>,
/// The number of samples to generate.
/// The number of samples to generate iteratively.
#[arg(long, default_value_t = 1)]
num_samples: i64,
num_samples: usize,
/// The numbers of samples to generate simultaneously.
#[arg[long, default_value_t = 1]]
bsize: usize,
/// The name of the final image to generate.
#[arg(long, value_name = "FILE", default_value = "sd_final.png")]
@ -236,8 +241,8 @@ impl ModelFile {
fn output_filename(
basename: &str,
sample_idx: i64,
num_samples: i64,
sample_idx: usize,
num_samples: usize,
timestep_idx: Option<usize>,
) -> String {
let filename = if num_samples > 1 {
@ -261,6 +266,33 @@ fn output_filename(
}
}
#[allow(clippy::too_many_arguments)]
fn save_image(
vae: &AutoEncoderKL,
latents: &Tensor,
vae_scale: f64,
bsize: usize,
idx: usize,
final_image: &str,
num_samples: usize,
timestep_ids: Option<usize>,
) -> Result<()> {
let images = vae.decode(&(latents / vae_scale)?)?;
let images = ((images / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
let images = (images.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?;
for batch in 0..bsize {
let image = images.i(batch)?;
let image_filename = output_filename(
final_image,
(bsize * idx) + batch + 1,
batch + num_samples,
timestep_ids,
);
candle_examples::save_image(&image, image_filename)?;
}
Ok(())
}
#[allow(clippy::too_many_arguments)]
fn text_embeddings(
prompt: &str,
@ -382,6 +414,7 @@ fn run(args: Args) -> Result<()> {
final_image,
sliced_attention_size,
num_samples,
bsize,
sd_version,
clip_weights,
vae_weights,
@ -475,6 +508,7 @@ fn run(args: Args) -> Result<()> {
.collect::<Result<Vec<_>>>()?;
let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?;
let text_embeddings = text_embeddings.repeat((bsize, 1, 1))?;
println!("{text_embeddings:?}");
println!("Building the autoencoder.");
@ -496,7 +530,6 @@ fn run(args: Args) -> Result<()> {
} else {
0
};
let bsize = 1;
let vae_scale = match sd_version {
StableDiffusionVersion::V1_5
@ -560,12 +593,16 @@ fn run(args: Args) -> Result<()> {
println!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt);
if args.intermediary_images {
let image = vae.decode(&(&latents / vae_scale)?)?;
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
let image_filename =
output_filename(&final_image, idx + 1, num_samples, Some(timestep_index + 1));
candle_examples::save_image(&image, image_filename)?
save_image(
&vae,
&latents,
vae_scale,
bsize,
idx,
&final_image,
num_samples,
Some(timestep_index + 1),
)?;
}
}
@ -574,11 +611,16 @@ fn run(args: Args) -> Result<()> {
idx + 1,
num_samples
);
let image = vae.decode(&(&latents / vae_scale)?)?;
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
let image = (image.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?.i(0)?;
let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
candle_examples::save_image(&image, image_filename)?
save_image(
&vae,
&latents,
vae_scale,
bsize,
idx,
&final_image,
num_samples,
None,
)?;
}
Ok(())
}

View File

@ -115,7 +115,7 @@ pub fn main() -> anyhow::Result<()> {
let processor = image_processor::ViTImageProcessor::new(&processor_config);
let image = vec![args.image.as_str()];
let image = processor.preprocess(image)?;
let image = processor.preprocess(image)?.to_device(&device)?;
let encoder_xs = model.encoder().forward(&image)?;

Binary file not shown.

After

Width:  |  Height:  |  Size: 175 KiB

File diff suppressed because it is too large Load Diff

View File

@ -207,6 +207,9 @@ INDEX_OP(is_u8_bf16, uint8_t, bfloat)
GATHER_OP(gather_u32_f32, uint, float)
GATHER_OP(gather_u32_f16, uint, half)
#if defined(__HAVE_BFLOAT__)
GATHER_OP(gather_u32_bf16, uint, bfloat)
#endif
SCATTER_ADD_OP(sa_u32_f32, uint32_t, float)
SCATTER_ADD_OP(sa_u8_f32, uint8_t, float)

View File

@ -1,11 +1,15 @@
use metal::{
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
Buffer, CommandBufferRef, CompileOptions, ComputePipelineState, Device, Function,
FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
};
use std::collections::HashMap;
use std::ffi::c_void;
use std::sync::RwLock;
mod utils;
pub use utils::BufferOffset;
use utils::{get_block_dims, linear_split};
const AFFINE: &str = include_str!("affine.metal");
const INDEXING: &str = include_str!("indexing.metal");
const UNARY: &str = include_str!("unary.metal");
@ -18,100 +22,6 @@ const RANDOM: &str = include_str!("random.metal");
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
const QUANTIZED: &str = include_str!("quantized.metal");
/// Most kernels apply similarly across the tensors
/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
/// actual total buffer length).
/// Then kernels can just do their op on their single point in the buffer.
fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
let size = length as u64;
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size);
let count = (size + width - 1) / width;
let thread_group_count = MTLSize {
width: count,
height: 1,
depth: 1,
};
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
(thread_group_count, thread_group_size)
}
fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {
<P as EncoderParam>::set_param(encoder, position, data)
}
/// Helper functions to create the various objects on the compute command encoder
/// on a single line.
/// Prevents getting wrong some arguments number and mixing length and size in bytes.
trait EncoderParam {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
}
macro_rules! primitive {
($type:ty) => {
impl EncoderParam for $type {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_bytes(
position,
core::mem::size_of::<$type>() as u64,
&data as *const $type as *const c_void,
);
}
}
};
}
primitive!(bool);
primitive!(usize);
primitive!(i32);
primitive!(i64);
primitive!(u32);
primitive!(u64);
primitive!(f32);
impl<T> EncoderParam for &[T] {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_bytes(
position,
core::mem::size_of_val(data) as u64,
data.as_ptr() as *const c_void,
);
}
}
impl EncoderParam for &Buffer {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_buffer(position, Some(data), 0);
}
}
impl EncoderParam for (&Buffer, usize) {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_buffer(position, Some(data.0), data.1 as u64);
}
}
impl EncoderParam for &mut Buffer {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_buffer(position, Some(data), 0);
}
}
impl EncoderParam for (&mut Buffer, usize) {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_buffer(position, Some(data.0), data.1 as u64);
}
}
macro_rules! set_params {
($encoder:ident, ($($param:expr),+)) => (
let mut _index = 0;
$(
set_param($encoder, _index, $param);
_index += 1;
)*
);
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source {
Affine,
@ -235,6 +145,12 @@ pub struct Kernels {
pipelines: RwLock<Pipelines>,
}
impl Default for Kernels {
fn default() -> Self {
Self::new()
}
}
impl Kernels {
pub fn new() -> Self {
let libraries = RwLock::new(Libraries::new());
@ -358,17 +274,17 @@ pub fn call_unary_contiguous(
kernels: &Kernels,
kernel_name: unary::contiguous::Kernel,
length: usize,
input: &Buffer,
input: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, input, output));
set_params!(encoder, (length, &input, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -396,21 +312,24 @@ pub fn call_copy2d(
set_params!(
encoder,
(
d1,
d2,
src_s,
dst_s,
d1 as i64,
d2 as i64,
src_s as i64,
dst_s as i64,
(input, src_o_in_bytes),
(output, dst_o_in_bytes)
)
);
let width: usize = d1 * d2;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width / 4);
let grid_dims = MTLSize {
width: d1 as u64,
height: d2 as u64,
depth: 1,
};
let group_dims = get_block_dims(d1 as u64, d2 as u64, 1);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.dispatch_threads(grid_dims, group_dims);
encoder.end_encoding();
Ok(())
}
@ -422,11 +341,9 @@ pub fn call_unary_strided(
kernels: &Kernels,
name: unary::strided::Kernel,
shape: &[usize],
input: &Buffer,
input: BufferOffset,
strides: &[usize],
offset: usize,
output: &Buffer,
output_offset: usize,
output: BufferOffset,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
@ -435,23 +352,13 @@ pub fn call_unary_strided(
encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product();
set_params!(
encoder,
(
length,
num_dims,
shape,
strides,
(input, offset),
(output, output_offset)
)
);
set_params!(encoder, (length, num_dims, shape, strides, &input, &output));
let width: usize = shape.iter().product();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
@ -464,8 +371,8 @@ pub fn call_binary_contiguous(
kernels: &Kernels,
kernel_name: binary::contiguous::Kernel,
length: usize,
left: &Buffer,
right: &Buffer,
left: BufferOffset,
right: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
@ -473,12 +380,12 @@ pub fn call_binary_contiguous(
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, left, right, output));
set_params!(encoder, (length, &left, &right, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.use_resource(left, metal::MTLResourceUsage::Read);
encoder.use_resource(right, metal::MTLResourceUsage::Read);
encoder.use_resource(left.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -492,12 +399,10 @@ pub fn call_binary_strided(
kernels: &Kernels,
name: binary::strided::Kernel,
shape: &[usize],
left_input: &Buffer,
left_input: BufferOffset,
left_strides: &[usize],
left_offset: usize,
right_input: &Buffer,
right_input: BufferOffset,
right_strides: &[usize],
right_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;
@ -517,16 +422,16 @@ pub fn call_binary_strided(
shape,
left_strides,
right_strides,
(left_input, left_offset),
(right_input, right_offset),
&left_input,
&right_input,
output
)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
encoder.use_resource(left_input, metal::MTLResourceUsage::Read);
encoder.use_resource(right_input, metal::MTLResourceUsage::Read);
encoder.use_resource(left_input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(right_input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -540,8 +445,7 @@ pub fn call_cast_contiguous(
kernels: &Kernels,
kernel_name: &'static str,
length: usize,
input: &Buffer,
input_offset: usize,
input: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
@ -549,10 +453,10 @@ pub fn call_cast_contiguous(
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, (input, input_offset), output));
set_params!(encoder, (length, &input, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -566,9 +470,8 @@ pub fn call_cast_strided(
kernels: &Kernels,
kernel_name: &'static str,
shape: &[usize],
input: &Buffer,
input: BufferOffset,
input_strides: &[usize],
input_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
@ -580,25 +483,19 @@ pub fn call_cast_strided(
set_params!(
encoder,
(
length,
shape.len(),
shape,
input_strides,
(input, input_offset),
output
)
(length, shape.len(), shape, input_strides, &input, output)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_reduce_contiguous(
device: &Device,
command_buffer: &CommandBufferRef,
@ -606,8 +503,7 @@ pub fn call_reduce_contiguous(
kernel_name: &'static str,
length: usize,
out_length: usize,
input: &Buffer,
input_offset: usize,
input: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
@ -616,10 +512,7 @@ pub fn call_reduce_contiguous(
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(length, elements_to_sum, (input, input_offset), output)
);
set_params!(encoder, (length, elements_to_sum, &input, output));
let thread_group_count = MTLSize {
width: out_length as u64,
@ -639,13 +532,14 @@ pub fn call_reduce_contiguous(
depth: 1,
};
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_reduce_strided(
device: &Device,
command_buffer: &CommandBufferRef,
@ -654,8 +548,7 @@ pub fn call_reduce_strided(
shape: &[usize],
strides: &[usize],
out_length: usize,
input: &Buffer,
input_offset: usize,
input: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let length: usize = shape.iter().product();
@ -667,14 +560,7 @@ pub fn call_reduce_strided(
set_params!(
encoder,
(
shape.len(),
shape,
strides,
elements_to_sum,
(input, input_offset),
output
)
(shape.len(), shape, strides, elements_to_sum, &input, output)
);
let thread_group_count = MTLSize {
@ -695,7 +581,7 @@ pub fn call_reduce_strided(
depth: 1,
};
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -944,7 +830,7 @@ pub fn call_affine(
kernels: &Kernels,
name: &'static str,
size: usize,
input: &Buffer,
input: BufferOffset,
output: &Buffer,
mul: f32,
add: f32,
@ -954,10 +840,10 @@ pub fn call_affine(
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, add, input, output));
set_params!(encoder, (size, mul, add, &input, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -971,9 +857,8 @@ pub fn call_affine_strided(
kernels: &Kernels,
name: &'static str,
shape: &[usize],
input: &Buffer,
input: BufferOffset,
input_stride: &[usize],
input_offset: usize,
output: &Buffer,
mul: f32,
add: f32,
@ -993,13 +878,13 @@ pub fn call_affine_strided(
input_stride,
mul,
add,
(input, input_offset),
&input,
output
)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -1013,7 +898,7 @@ pub fn call_powf(
kernels: &Kernels,
name: &'static str,
size: usize,
input: &Buffer,
input: BufferOffset,
output: &Buffer,
mul: f32,
) -> Result<(), MetalKernelError> {
@ -1022,10 +907,10 @@ pub fn call_powf(
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, input, output));
set_params!(encoder, (size, mul, &input, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -1039,9 +924,8 @@ pub fn call_powf_strided(
kernels: &Kernels,
name: &'static str,
shape: &[usize],
input: &Buffer,
input: BufferOffset,
input_stride: &[usize],
input_offset: usize,
output: &Buffer,
mul: f32,
) -> Result<(), MetalKernelError> {
@ -1053,19 +937,11 @@ pub fn call_powf_strided(
set_params!(
encoder,
(
size,
shape.len(),
shape,
input_stride,
mul,
(input, input_offset),
output
)
(size, shape.len(), shape, input_stride, mul, &input, output)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -1079,7 +955,7 @@ pub fn call_elu(
kernels: &Kernels,
name: &'static str,
size: usize,
input: &Buffer,
input: BufferOffset,
output: &Buffer,
mul: f32,
) -> Result<(), MetalKernelError> {
@ -1088,10 +964,10 @@ pub fn call_elu(
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, input, output));
set_params!(encoder, (size, mul, &input, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -1105,9 +981,8 @@ pub fn call_elu_strided(
kernels: &Kernels,
name: &'static str,
shape: &[usize],
input: &Buffer,
input: BufferOffset,
input_stride: &[usize],
input_offset: usize,
output: &Buffer,
mul: f32,
) -> Result<(), MetalKernelError> {
@ -1119,37 +994,30 @@ pub fn call_elu_strided(
set_params!(
encoder,
(
size,
shape.len(),
shape,
input_stride,
mul,
(input, input_offset),
output
)
(size, shape.len(), shape, input_stride, mul, &input, output)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_where_cond_strided(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
cond: &Buffer,
(cond_stride, cond_offset): (&[usize], usize),
left: &Buffer,
(left_stride, left_offset): (&[usize], usize),
right: &Buffer,
(right_stride, right_offset): (&[usize], usize),
cond: BufferOffset,
cond_stride: &[usize],
left: BufferOffset,
left_stride: &[usize],
right: BufferOffset,
right_stride: &[usize],
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
@ -1169,18 +1037,18 @@ pub fn call_where_cond_strided(
cond_stride,
left_stride,
right_stride,
(cond, cond_offset),
(left, left_offset),
(right, right_offset),
&cond,
&left,
&right,
output
)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.use_resource(cond, metal::MTLResourceUsage::Read);
encoder.use_resource(left, metal::MTLResourceUsage::Read);
encoder.use_resource(right, metal::MTLResourceUsage::Read);
encoder.use_resource(cond.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(left.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -1199,10 +1067,8 @@ pub fn call_index_select(
contiguous: bool,
src_dims: &[usize],
src_strides: &[usize],
input: &Buffer,
src_offset: usize,
ids: &Buffer,
ids_offset: usize,
input: BufferOffset,
ids: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let left_size: usize = shape[..dim].iter().product();
@ -1227,16 +1093,16 @@ pub fn call_index_select(
contiguous,
src_dims,
src_strides,
(input, src_offset),
(ids, ids_offset),
&input,
&ids,
output
)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -1252,10 +1118,8 @@ pub fn call_gather(
shape: &[usize],
ids_size: usize,
dim: usize,
input: &Buffer,
input_offset: usize,
ids: &Buffer,
ids_offset: usize,
input: BufferOffset,
ids: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let left_size: usize = shape[..dim].iter().product();
@ -1277,22 +1141,23 @@ pub fn call_gather(
src_dim_size,
right_size,
ids_size,
(input, input_offset),
(ids, ids_offset),
&input,
&ids,
output
)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_scatter_add(
device: &Device,
command_buffer: &CommandBufferRef,
@ -1301,10 +1166,8 @@ pub fn call_scatter_add(
src_shape: &[usize],
dst_shape: &[usize],
dim: usize,
input: &Buffer,
input_offset: usize,
ids: &Buffer,
ids_offset: usize,
input: BufferOffset,
ids: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let left_size: usize = src_shape[..dim].iter().product();
@ -1327,22 +1190,23 @@ pub fn call_scatter_add(
src_dim_size,
right_size,
dst_dim_size,
(input, input_offset),
(ids, ids_offset),
&input,
&ids,
output
)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_index_add(
device: &Device,
command_buffer: &CommandBufferRef,
@ -1352,10 +1216,8 @@ pub fn call_index_add(
dst_shape: &[usize],
ids_shape: &[usize],
dim: usize,
input: &Buffer,
input_offset: usize,
ids: &Buffer,
ids_offset: usize,
input: BufferOffset,
ids: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let left_size: usize = src_shape[..dim].iter().product();
@ -1379,16 +1241,16 @@ pub fn call_index_add(
right_size,
dst_dim_size,
ids_dim_size,
(input, input_offset),
(ids, ids_offset),
&input,
&ids,
output
)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -1654,8 +1516,7 @@ pub fn call_im2col1d_strided(
shape: &[usize],
strides: &[usize],
(k_size, stride, padding, dilation): (usize, usize, usize, usize),
input: &Buffer,
input_offset: usize,
input: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
@ -1667,20 +1528,9 @@ pub fn call_im2col1d_strided(
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
dst_el,
l_out,
k_size,
stride,
padding,
dilation,
shape,
strides,
(input, input_offset),
output
)
(dst_el, l_out, k_size, stride, padding, dilation, shape, strides, &input, output)
);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -1697,8 +1547,7 @@ pub fn call_im2col_strided(
shape: &[usize],
strides: &[usize],
(h_k, w_k, stride, padding, dilation): (usize, usize, usize, usize, usize),
input: &Buffer,
input_offset: usize,
input: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
@ -1716,21 +1565,11 @@ pub fn call_im2col_strided(
set_params!(
encoder,
(
dst_el,
h_out,
w_out,
h_k,
w_k,
stride,
padding,
dilation,
shape,
strides,
(input, input_offset),
dst_el, h_out, w_out, h_k, w_k, stride, padding, dilation, shape, strides, &input,
output
)
);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -1748,8 +1587,7 @@ pub fn call_upsample_nearest_2d(
strides: &[usize],
out_w: usize,
out_h: usize,
input: &Buffer,
input_offset: usize,
input: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
@ -1761,18 +1599,9 @@ pub fn call_upsample_nearest_2d(
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
out_w,
out_h,
scale_w,
scale_h,
shape,
strides,
(input, input_offset),
output
)
(out_w, out_h, scale_w, scale_h, shape, strides, &input, output)
);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -1869,6 +1698,7 @@ pub enum GgmlDType {
F32,
}
#[allow(clippy::too_many_arguments)]
pub fn call_quantized_matmul_t(
device: &Device,
command_buffer: &CommandBufferRef,
@ -1884,16 +1714,16 @@ pub fn call_quantized_matmul_t(
let ne00 = k as i64;
let ne01 = n as i64;
let ne02 = b as i64;
let ne03 = 1 as i64;
let ne03 = 1i64;
let nb00 = 0i64;
let nb01 = 0 as i64;
let nb02 = 0 as i64;
let nb01 = 0i64;
let nb02 = 0i64;
let ne10 = k as i64;
let ne11 = m as i64;
let ne12 = b as i64;
let ne13 = 1 as i64;
let ne13 = 1i64;
let nb10 = 0i64;
let nb11 = 0i64;
@ -2128,6 +1958,7 @@ pub struct CallConvTranspose2dCfg<'a> {
pub kernel_offset: usize,
}
#[allow(clippy::too_many_arguments)]
pub fn call_conv_transpose2d(
device: &Device,
command_buffer: &CommandBufferRef,

View File

@ -12,7 +12,7 @@ fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
let options = MTLResourceOptions::StorageModeManaged;
let ptr = data.as_ptr() as *const c_void;
let size = (data.len() * std::mem::size_of::<T>()) as u64;
let size = std::mem::size_of_val(data) as u64;
device.new_buffer_with_data(ptr, size, options)
}
@ -41,6 +41,10 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let input = BufferOffset {
buffer: &input,
offset_in_bytes: 0,
};
let output = new_buffer(&device, v);
call_unary_contiguous(
&device,
@ -48,7 +52,7 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
&kernels,
name,
v.len(),
&input,
input,
&output,
)
.unwrap();
@ -72,8 +76,8 @@ fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> V
&kernels,
name,
x.len(),
&left,
&right,
BufferOffset::zero_offset(&left),
BufferOffset::zero_offset(&right),
&output,
)
.unwrap();
@ -93,7 +97,15 @@ fn run_strided<T: Clone>(
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let output = new_buffer(&device, v);
let input = BufferOffset {
buffer: &input,
offset_in_bytes: offset,
};
let output_b = new_buffer(&device, v);
let output = BufferOffset {
buffer: &output_b,
offset_in_bytes: 0,
};
let kernels = Kernels::new();
call_unary_strided(
&device,
@ -101,16 +113,14 @@ fn run_strided<T: Clone>(
&kernels,
kernel,
shape,
&input,
input,
strides,
offset,
&output,
0,
output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
read_to_vec(&output, v.len())
read_to_vec(&output_b, v.len())
}
#[test]
@ -308,8 +318,7 @@ fn run_cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
&kernels,
name,
v.len(),
&input,
0,
BufferOffset::zero_offset(&input),
&output,
)
.unwrap();
@ -521,7 +530,7 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
&kernels,
"affine_f32",
size,
&input,
BufferOffset::zero_offset(&input),
&output,
mul as f32,
add as f32,
@ -554,9 +563,8 @@ fn run_affine_strided<T: Clone>(
&kernels,
"affine_f32_strided",
shape,
&input,
BufferOffset::zero_offset(&input),
strides,
0,
&output,
mul as f32,
add as f32,
@ -633,7 +641,7 @@ fn index_select_strided() {
fn index_select_f16() {
let embedding: Vec<_> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
.into_iter()
.map(|x| f16::from_f32(x))
.map(f16::from_f32)
.collect();
let shape = [5, 2];
let stride = [2, 1];
@ -700,8 +708,8 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let embeddings_buffer = new_buffer(&device, &embeddings);
let ids_buffer = new_buffer(&device, &ids);
let embeddings_buffer = new_buffer(&device, embeddings);
let ids_buffer = new_buffer(&device, ids);
let left_size: usize = shape[..dim].iter().product();
let right_size: usize = shape[dim + 1..].iter().product();
@ -711,7 +719,7 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
let kernels = Kernels::new();
call_index_select(
&device,
&command_buffer,
command_buffer,
&kernels,
name,
shape,
@ -720,10 +728,8 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
true,
shape,
stride,
&embeddings_buffer,
0,
&ids_buffer,
0,
BufferOffset::zero_offset(&embeddings_buffer),
BufferOffset::zero_offset(&ids_buffer),
&dst_buffer,
)
.unwrap();
@ -746,8 +752,8 @@ fn run_index_select_strided<T: Clone, I: Clone + std::fmt::Debug>(
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let embeddings_buffer = new_buffer(&device, &embeddings);
let ids_buffer = new_buffer(&device, &ids);
let embeddings_buffer = new_buffer(&device, embeddings);
let ids_buffer = new_buffer(&device, ids);
let left_size: usize = shape[..dim].iter().product();
let right_size: usize = shape[dim + 1..].iter().product();
@ -757,7 +763,7 @@ fn run_index_select_strided<T: Clone, I: Clone + std::fmt::Debug>(
let kernels = Kernels::new();
call_index_select(
&device,
&command_buffer,
command_buffer,
&kernels,
name,
shape,
@ -766,10 +772,8 @@ fn run_index_select_strided<T: Clone, I: Clone + std::fmt::Debug>(
false,
shape,
stride,
&embeddings_buffer,
0,
&ids_buffer,
0,
BufferOffset::zero_offset(&embeddings_buffer),
BufferOffset::zero_offset(&ids_buffer),
&dst_buffer,
)
.unwrap();
@ -811,8 +815,7 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
&dims,
&strides,
out_length,
&input,
0,
BufferOffset::zero_offset(&input),
&output,
)
.unwrap();
@ -931,6 +934,7 @@ fn softmax() {
);
}
#[allow(clippy::too_many_arguments)]
fn run_where_cond<I: Clone, T: Clone>(
shape: &[usize],
cond: &[I],
@ -965,18 +969,30 @@ fn run_where_cond<I: Clone, T: Clone>(
);
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
let cond = BufferOffset {
buffer: &cond,
offset_in_bytes: cond_offset,
};
let left = BufferOffset {
buffer: &left,
offset_in_bytes: left_offset,
};
let right = BufferOffset {
buffer: &right,
offset_in_bytes: cond_offset,
};
call_where_cond_strided(
&device,
command_buffer,
&kernels,
name,
shape,
&cond,
(&cond_stride, cond_offset),
&left,
(&left_stride, left_offset),
&right,
(&cond_stride, cond_offset),
cond,
&cond_stride,
left,
&left_stride,
right,
&cond_stride,
&output,
)
.unwrap();
@ -1148,7 +1164,7 @@ fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b:
#[test]
fn random() {
fn calc_mean(data: &[f32]) -> f32 {
let sum = data.iter().sum::<f32>() as f32;
let sum = data.iter().sum::<f32>();
let count = data.len();
assert!(count > 0);
sum / count as f32
@ -1162,7 +1178,7 @@ fn random() {
let variance = data
.iter()
.map(|value| {
let diff = mean - (*value as f32);
let diff = mean - *value;
diff * diff
})
.sum::<f32>()
@ -1241,10 +1257,8 @@ fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>(
shape,
shape,
dim,
&input_buffer,
0,
&ids_buffer,
0,
BufferOffset::zero_offset(&input_buffer),
BufferOffset::zero_offset(&ids_buffer),
&output,
)
.unwrap();
@ -1346,10 +1360,8 @@ fn run_index_add<T: Clone, I: Clone + std::fmt::Debug>(
shape,
shape,
dim,
&input_buffer,
0,
&indices_buffer,
0,
BufferOffset::zero_offset(&input_buffer),
BufferOffset::zero_offset(&indices_buffer),
&output,
)
.unwrap();
@ -1787,6 +1799,7 @@ fn avg_pool2d_u32() {
assert_eq!(results, expected);
}
#[allow(clippy::too_many_arguments)]
fn run_conv_transpose1d<T: Clone>(
input: &[T],
input_shape: &[usize],

View File

@ -104,26 +104,18 @@ UNARY(NAME, bfloat, NAME##_bf16, NAME##_bf16_strided);
#define COPY2D(FN_NAME, TYPENAME) \
kernel void FN_NAME( \
constant size_t &d1, \
constant size_t &d2, \
constant size_t &src_s, \
constant size_t &dst_s, \
constant int64_t &d1, \
constant int64_t &d2, \
constant int64_t &src_s, \
constant int64_t &dst_s, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
uint2 idx [[thread_position_in_grid]] \
) { \
tid *= 4; \
if (tid >= d1 * d2) { \
return; \
} \
size_t idx1 = tid / d2; \
size_t idx2 = tid - idx1 * d2; \
size_t src_idx = idx1 * src_s + idx2; \
size_t dst_idx = idx1 * dst_s + idx2; \
if (idx.x >= d1 || idx.y >= d2) return; \
int64_t src_idx = idx.x * src_s + idx.y; \
int64_t dst_idx = idx.x * dst_s + idx.y; \
output[dst_idx] = input[src_idx]; \
output[dst_idx+1] = input[src_idx+1]; \
output[dst_idx+2] = input[src_idx+2]; \
output[dst_idx+3] = input[src_idx+3]; \
}
COPY2D(copy2d_f32, float)
@ -183,5 +175,5 @@ BFLOAT_UNARY_OP(sign)
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
COPY2D(copy2d_bf64, bfloat)
COPY2D(copy2d_bf16, bfloat)
#endif

View File

@ -0,0 +1,162 @@
use metal::{Buffer, ComputeCommandEncoderRef, ComputePipelineState, MTLSize};
use std::ffi::c_void;
/// Most kernels apply similarly across the tensors
/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
/// actual total buffer length).
/// Then kernels can just do their op on their single point in the buffer.
pub(crate) fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
let size = length as u64;
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size);
let count = (size + width - 1) / width;
let thread_group_count = MTLSize {
width: count,
height: 1,
depth: 1,
};
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
(thread_group_count, thread_group_size)
}
// https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96
pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
let mut pows0 = 0u64;
let mut pows1 = 0u64;
let mut pows2 = 0u64;
let mut sum = 0u64;
loop {
let presum = sum;
// Check all the pows
if dim0 >= (1 << (pows0 + 1)) {
pows0 += 1;
sum += 1;
}
if sum == 10 {
break;
}
if dim1 >= (1 << (pows1 + 1)) {
pows1 += 1;
sum += 1;
}
if sum == 10 {
break;
}
if dim2 >= (1 << (pows2 + 1)) {
pows2 += 1;
sum += 1;
}
if sum == presum || sum == 10 {
break;
}
}
MTLSize {
width: 1 << pows0,
height: 1 << pows1,
depth: 1 << pows2,
}
}
pub(crate) fn set_param<P: EncoderParam>(
encoder: &ComputeCommandEncoderRef,
position: u64,
data: P,
) {
<P as EncoderParam>::set_param(encoder, position, data)
}
/// Helper functions to create the various objects on the compute command encoder
/// on a single line.
/// Prevents getting wrong some arguments number and mixing length and size in bytes.
pub(crate) trait EncoderParam {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
}
macro_rules! primitive {
($type:ty) => {
impl EncoderParam for $type {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_bytes(
position,
core::mem::size_of::<$type>() as u64,
&data as *const $type as *const c_void,
);
}
}
};
}
primitive!(bool);
primitive!(usize);
primitive!(i32);
primitive!(i64);
primitive!(u32);
primitive!(u64);
primitive!(f32);
pub struct BufferOffset<'a> {
pub buffer: &'a Buffer,
pub offset_in_bytes: usize,
}
impl<'a> BufferOffset<'a> {
pub fn zero_offset(buffer: &'a Buffer) -> Self {
Self {
buffer,
offset_in_bytes: 0,
}
}
}
impl<T> EncoderParam for &[T] {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_bytes(
position,
core::mem::size_of_val(data) as u64,
data.as_ptr() as *const c_void,
);
}
}
impl EncoderParam for &Buffer {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_buffer(position, Some(data), 0);
}
}
impl EncoderParam for (&Buffer, usize) {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_buffer(position, Some(data.0), data.1 as u64);
}
}
impl<'a> EncoderParam for &BufferOffset<'a> {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_buffer(position, Some(data.buffer), data.offset_in_bytes as u64);
}
}
impl EncoderParam for &mut Buffer {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_buffer(position, Some(data), 0);
}
}
impl EncoderParam for (&mut Buffer, usize) {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_buffer(position, Some(data.0), data.1 as u64);
}
}
#[macro_export]
macro_rules! set_params {
($encoder:ident, ($($param:expr),+)) => (
let mut _index = 0;
$(
$crate::utils::set_param($encoder, _index, $param);
_index += 1;
)*
);
}

View File

@ -498,6 +498,53 @@ impl<'a> VarBuilder<'a> {
let pth = candle::pickle::PthTensors::new(p, None)?;
Ok(Self::from_backend(Box::new(pth), dtype, dev.clone()))
}
/// Gets a VarBuilder that applies some renaming function on tensor it gets queried for before
/// passing the new names to the inner VarBuilder.
///
/// ```rust
/// use candle::{Tensor, DType, Device};
///
/// let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
/// let tensors: std::collections::HashMap<_, _> = [
/// ("foo".to_string(), a),
/// ]
/// .into_iter()
/// .collect();
/// let vb = candle_nn::VarBuilder::from_tensors(tensors, DType::F32, &Device::Cpu);
/// assert!(vb.contains_tensor("foo"));
/// assert!(vb.get((2, 3), "foo").is_ok());
/// assert!(!vb.contains_tensor("bar"));
/// let vb = vb.rename_f(|f: &str| if f == "bar" { "foo".to_string() } else { f.to_string() });
/// assert!(vb.contains_tensor("bar"));
/// assert!(vb.contains_tensor("foo"));
/// assert!(vb.get((2, 3), "bar").is_ok());
/// assert!(vb.get((2, 3), "foo").is_ok());
/// assert!(!vb.contains_tensor("baz"));
/// # Ok::<(), candle::Error>(())
/// ```
pub fn rename_f<F: Fn(&str) -> String + Sync + Send + 'static>(self, f: F) -> Self {
let f: Box<dyn Fn(&str) -> String + Sync + Send + 'static> = Box::new(f);
self.rename(f)
}
pub fn rename<R: Renamer + Send + Sync + 'a>(self, renamer: R) -> Self {
let dtype = self.dtype();
let device = self.device().clone();
let path = self.path.clone();
let backend = Rename::new(self, renamer);
let backend: Box<dyn SimpleBackend + 'a> = Box::new(backend);
let data = TensorData {
backend,
dtype,
device,
};
Self {
data: Arc::new(data),
path,
_phantom: std::marker::PhantomData,
}
}
}
pub struct ShardedSafeTensors(candle::safetensors::MmapedSafetensors);
@ -618,3 +665,49 @@ impl Backend for ShardedSafeTensors {
self.0.get(name).is_ok()
}
}
/// This traits specifies a way to rename the queried names into names that are stored in an inner
/// VarBuilder.
pub trait Renamer {
/// This is applied to the name obtained by a name call and the resulting name is passed to the
/// inner VarBuilder.
fn rename(&self, v: &str) -> std::borrow::Cow<'_, str>;
}
pub struct Rename<'a, R: Renamer> {
inner: VarBuilder<'a>,
renamer: R,
}
impl<'a, R: Renamer + Sync + Send> SimpleBackend for Rename<'a, R> {
fn get(
&self,
s: Shape,
name: &str,
h: crate::Init,
dtype: DType,
dev: &Device,
) -> Result<Tensor> {
let name = self.renamer.rename(name);
self.inner
.get_with_hints_dtype(s, &name, h, dtype)?
.to_device(dev)
}
fn contains_tensor(&self, name: &str) -> bool {
let name = self.renamer.rename(name);
self.inner.contains_tensor(&name)
}
}
impl<'a, R: Renamer> Rename<'a, R> {
pub fn new(inner: VarBuilder<'a>, renamer: R) -> Self {
Self { inner, renamer }
}
}
impl Renamer for Box<dyn Fn(&str) -> String + Sync + Send> {
fn rename(&self, v: &str) -> std::borrow::Cow<'_, str> {
std::borrow::Cow::Owned(self(v))
}
}

View File

@ -2,7 +2,7 @@ use crate::onnx;
use crate::onnx::attribute_proto::AttributeType;
use crate::onnx::tensor_proto::DataType;
use candle::{bail, DType, Device, Result, Tensor};
use std::collections::HashMap;
use std::{collections::HashMap, usize};
pub type Value = Tensor;
@ -508,17 +508,33 @@ pub fn simple_eval(
values.insert(node.output[0].clone(), xs);
}
"Gather" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gather
let xs = get(&node.input[0])?;
let indices = get(&node.input[1])?;
let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
let axis = xs.normalize_axis(axis)?;
// TODO: Provide an op to handle the ONNX generalized gather op ideally in a
// differentiable way.
let xs = if indices.rank() == 0 {
let index = indices.to_vec0::<i64>()? as usize;
xs.narrow(axis, index, 1)?.squeeze(axis)?
} else {
todo!("implement gather for {xs:?} {indices:?} axis {axis}")
// In Pytorch or Numpy this can be done by indexing the xs tensor using the indices
// tensor directly, but candle does not support tensor indexing at the moment, so
// some workarounds must be done.
let xs = match indices.dims() {
[] => {
let index = indices.to_vec0::<i64>()? as usize;
xs.narrow(axis, index, 1)?.squeeze(axis)?
}
[_] => xs.index_select(indices, axis)?,
[first, _] => {
let mut v = Vec::with_capacity(*first);
for i in 0..*first {
v.push(xs.index_select(&indices.get(i)?, axis)?)
}
Tensor::stack(&v, axis)?
}
_ => {
// TODO: Provide an op to handle the ONNX generalized gather op ideally in a
// differentiable way.
todo!("implement gather for {xs:?} {indices:?} axis {axis}")
}
};
values.insert(node.output[0].clone(), xs);
}
@ -781,6 +797,29 @@ pub fn simple_eval(
let input = get(&node.input[0])?;
values.insert(node.output[0].clone(), input.clone());
}
// https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13
// TODO: This version is only compatible with ReduceMean V13 and below.
"ReduceMean" => {
let input = get(&node.input[0])?;
let axes = get_attr_opt::<[i64]>(node, "axes")?;
let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1);
let n_dims = input.dims().len();
let axes: Vec<usize> = if let Some(axes) = axes {
axes.iter()
.map(|e| (if e < &0 { (n_dims as i64) + *e } else { *e }) as usize)
.collect()
} else {
(0..n_dims).collect()
};
let output = if keepdims == 1 {
input.mean_keepdim(axes)?
} else {
input.mean(axes)?
};
values.insert(node.output[0].clone(), output);
}
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
}
}

View File

@ -4,7 +4,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle::{Device, Result, Tensor};
use candle::{Device, NdArray, Result, Tensor};
use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto};
use std::collections::HashMap;
@ -829,7 +829,134 @@ fn test_flatten_operation() -> Result<()> {
// #[test]
// "Gather"
// #[test]
#[test]
fn test_gather_operation() -> Result<()> {
// test taken from https://onnx.ai/onnx/operators/onnx__Gather.html#summary.
test(
&[[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]],
&[[0i64, 1], [1, 2]],
0,
&[[[1.0, 1.2], [2.3, 3.4]], [[2.3, 3.4], [4.5, 5.7]]],
)?;
// test taken from https://onnx.ai/onnx/operators/onnx__Gather.html#summary.
test(
&[[1.0, 1.2, 1.9], [2.3, 3.4, 3.9], [4.5, 5.7, 5.9]],
&[[0i64, 2]],
1,
&[[[1.0, 1.9]], [[2.3, 3.9]], [[4.5, 5.9]]],
)?;
// all the tests below are generated from numpy.take, which works like
// onnx's Gather operation.
test(&[1.0, 2.0, 3.0, 4.0], 3i64, 0, 4.0)?;
test(&[[1.0, 2.0, 3.0, 4.0]], 3i64, 1, &[4.0])?;
test(
&[[1.0], [2.0], [3.0], [4.0]],
&[3i64, 2],
0,
&[[4.0], [3.0]],
)?;
test(
&[
[[1.0, 2.0], [3.0, 4.0]],
[[5.0, 6.0], [7.0, 8.0]],
[[9.0, 10.0], [11.0, 12.0]],
[[13.0, 14.0], [15.0, 16.0]],
],
1i64,
0,
&[[5.0, 6.0], [7.0, 8.0]],
)?;
test(
&[
[[1.0, 2.0], [3.0, 4.0]],
[[5.0, 6.0], [7.0, 8.0]],
[[9.0, 10.0], [11.0, 12.0]],
[[13.0, 14.0], [15.0, 16.0]],
],
&[1i64, 0],
0,
&[[[5.0, 6.0], [7.0, 8.0]], [[1.0, 2.0], [3.0, 4.0]]],
)?;
fn test(
data: impl NdArray,
indices: impl NdArray,
axis: i64,
expected: impl NdArray,
) -> Result<()> {
let att_axis = AttributeProto {
name: "axis".to_string(),
ref_attr_name: "axis".to_string(),
i: axis,
doc_string: "axis".to_string(),
r#type: 2,
f: 0.0,
s: vec![],
t: None,
g: None,
sparse_tensor: None,
tp: None,
floats: vec![],
ints: vec![],
strings: vec![],
tensors: vec![],
graphs: vec![],
sparse_tensors: vec![],
type_protos: vec![],
};
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Gather".to_string(),
domain: "".to_string(),
attribute: vec![att_axis],
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
inputs.insert(INPUT_Y.to_string(), Tensor::new(indices, &Device::Cpu)?);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let expected = Tensor::new(expected, &Device::Cpu)?;
match expected.dims().len() {
0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
_ => unreachable!(),
};
Ok(())
}
Ok(())
}
// "Shape"
#[test]
@ -1335,3 +1462,180 @@ fn test_relu_operation() -> Result<()> {
// "Cast"
// #[test]
// "ReduceMean"
#[test]
fn test_reduce_mean() -> Result<()> {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-120 default_axes_keepdims
test(
&[
[[5., 1.], [20., 2.]],
[[30., 1.], [40., 2.]],
[[55., 1.], [60., 2.]],
],
None,
1,
&[[[18.25]]],
)?;
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-120 do_no_keepdims
test(
&[
[[5., 1.], [20., 2.]],
[[30., 1.], [40., 2.]],
[[55., 1.], [60., 2.]],
],
Some(vec![1]),
0,
&[[12.5, 1.5], [35.0, 1.5], [57.5, 1.5]],
)?;
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-120 keepdims
test(
&[
[[5., 1.], [20., 2.]],
[[30., 1.], [40., 2.]],
[[55., 1.], [60., 2.]],
],
Some(vec![1]),
1,
&[[[12.5, 1.5]], [[35.0, 1.5]], [[57.5, 1.5]]],
)?;
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-120 negative_axes_keepdims
test(
&[
[[5., 1.], [20., 2.]],
[[30., 1.], [40., 2.]],
[[55., 1.], [60., 2.]],
],
Some(vec![-2]),
1,
&[[[12.5, 1.5]], [[35.0, 1.5]], [[57.5, 1.5]]],
)?;
// All the test data below was generated based on numpy's np.mean
test(
&[
[[5., 1.], [20., 2.]],
[[30., 1.], [40., 2.]],
[[55., 1.], [60., 2.]],
],
Some(vec![1, 2]),
0,
&[7.0, 18.25, 29.5],
)?;
test(
&[
[[5., 1.], [20., 2.]],
[[30., 1.], [40., 2.]],
[[55., 1.], [60., 2.]],
],
Some(vec![1, 2]),
1,
&[[[7.0]], [[18.25]], [[29.5]]],
)?;
test(&[1., 2., 3.], None, 1, &[2.0])?;
fn test(
data: impl NdArray,
axes: Option<Vec<i64>>,
keepdims: i64,
expected: impl NdArray,
) -> Result<()> {
let has_axes = axes.is_some();
let att_axes = AttributeProto {
name: "axes".to_string(),
ref_attr_name: "axes".to_string(),
i: 0,
doc_string: "axes".to_string(),
r#type: 7,
f: 0.0,
s: vec![],
t: None,
g: None,
sparse_tensor: None,
tp: None,
floats: vec![],
ints: axes.unwrap_or_default(),
strings: vec![],
tensors: vec![],
graphs: vec![],
sparse_tensors: vec![],
type_protos: vec![],
};
let att_keepdims = AttributeProto {
name: "keepdims".to_string(),
ref_attr_name: "keepdims".to_string(),
i: keepdims,
doc_string: "keepdims".to_string(),
r#type: 2,
f: 0.0,
s: vec![],
t: None,
g: None,
sparse_tensor: None,
tp: None,
floats: vec![],
ints: vec![],
strings: vec![],
tensors: vec![],
graphs: vec![],
sparse_tensors: vec![],
type_protos: vec![],
};
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "ReduceMean".to_string(),
domain: "".to_string(),
attribute: if has_axes {
vec![att_axes, att_keepdims]
} else {
vec![att_keepdims]
},
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let expected = Tensor::new(expected, &Device::Cpu)?;
match expected.dims().len() {
0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
_ => unreachable!(),
};
Ok(())
}
Ok(())
}

View File

@ -120,7 +120,7 @@ fn rotate_half(x: &Tensor) -> Result<Tensor> {
Ok(x21)
}
#[derive(Debug)]
#[derive(Debug, Clone)]
struct FalconRotaryEmbedding {
inv_freq: Tensor,
cache: Option<(usize, Tensor, Tensor)>,
@ -179,12 +179,14 @@ impl FalconRotaryEmbedding {
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
let shape = mask.shape();
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
let on_true = Tensor::new(on_true, on_false.device())?
.to_dtype(on_false.dtype())?
.broadcast_as(shape.dims())?;
let m = mask.where_cond(&on_true, on_false)?;
Ok(m)
}
#[derive(Debug)]
#[derive(Debug, Clone)]
struct FalconAttention {
query_key_value: Linear,
dense: Linear,
@ -313,9 +315,13 @@ impl FalconAttention {
let attn_output = self.dense.forward(&attn_output)?;
Ok(attn_output)
}
fn clear_kv_cache(&mut self) {
self.kv_cache = None
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
struct FalconMlp {
dense_h_to_4h: Linear,
dense_4h_to_h: Linear,
@ -340,7 +346,7 @@ impl FalconMlp {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
struct FalconDecoderLayer {
inp_layernorm: LayerNorm,
self_attention: FalconAttention,
@ -400,9 +406,13 @@ impl FalconDecoderLayer {
let output = (mlp_output + residual)?;
Ok(output)
}
pub fn clear_kv_cache(&mut self) {
self.self_attention.clear_kv_cache()
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct Falcon {
word_embeddings: Embedding,
blocks: Vec<FalconDecoderLayer>,
@ -475,4 +485,10 @@ impl Falcon {
let logits = self.lm_head.forward(&hidden_state)?.squeeze(1)?;
Ok(logits)
}
pub fn clear_kv_cache(&mut self) {
for block in self.blocks.iter_mut() {
block.clear_kv_cache()
}
}
}

View File

@ -1,7 +1,7 @@
use std::sync::Arc;
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{linear_b as linear, Linear, VarBuilder};
use candle_nn::{linear_b as linear, Activation, Linear, VarBuilder};
fn default_max_position_embeddings() -> usize {
4096
@ -11,7 +11,9 @@ fn default_max_position_embeddings() -> usize {
pub struct Config {
pub attention_bias: bool,
pub head_dim: usize,
pub hidden_act: candle_nn::Activation,
// The code gemma configs include both hidden_act and hidden_activation.
pub hidden_act: Option<Activation>,
pub hidden_activation: Option<Activation>,
pub hidden_size: usize,
pub intermediate_size: usize,
pub num_attention_heads: usize,
@ -25,6 +27,16 @@ pub struct Config {
pub max_position_embeddings: usize,
}
impl Config {
fn hidden_act(&self) -> Result<Activation> {
match (self.hidden_act, self.hidden_activation) {
(None, Some(act)) | (Some(act), None) => Ok(act),
(Some(_), Some(_)) => candle::bail!("both hidden_act and hidden_activation are set"),
(None, None) => candle::bail!("none of hidden_act and hidden_activation are set"),
}
}
}
#[derive(Debug, Clone)]
struct RmsNorm {
weight: Tensor,
@ -126,7 +138,7 @@ impl MLP {
gate_proj,
up_proj,
down_proj,
act_fn: cfg.hidden_act,
act_fn: cfg.hidden_act()?,
})
}
}
@ -179,18 +191,6 @@ impl Attention {
})
}
fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
let n_rep = self.num_kv_groups;
if n_rep == 1 {
Ok(xs)
} else {
let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
xs.unsqueeze(2)?
.expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
.reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
}
}
fn forward(
&mut self,
xs: &Tensor,
@ -227,8 +227,8 @@ impl Attention {
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));
let key_states = self.repeat_kv(key_states)?.contiguous()?;
let value_states = self.repeat_kv(value_states)?.contiguous()?;
let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
let attn_output = {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);

View File

@ -16,6 +16,8 @@ pub struct LlamaConfig {
pub rms_norm_eps: f64,
#[serde(default = "default_rope")]
pub rope_theta: f32,
pub bos_token_id: Option<u32>,
pub eos_token_id: Option<u32>,
}
fn default_rope() -> f32 {
@ -34,6 +36,8 @@ impl LlamaConfig {
rms_norm_eps: self.rms_norm_eps,
rope_theta: self.rope_theta,
use_flash_attn,
bos_token_id: self.bos_token_id,
eos_token_id: self.eos_token_id,
}
}
}
@ -49,6 +53,8 @@ pub struct Config {
pub use_flash_attn: bool,
pub rms_norm_eps: f64,
pub rope_theta: f32,
pub bos_token_id: Option<u32>,
pub eos_token_id: Option<u32>,
}
impl Config {
@ -63,6 +69,8 @@ impl Config {
use_flash_attn,
rms_norm_eps: 1e-6,
rope_theta: 10_000.0,
bos_token_id: None,
eos_token_id: None,
}
}
@ -77,6 +85,8 @@ impl Config {
use_flash_attn,
rms_norm_eps: 1e-5,
rope_theta: 10_000.0,
bos_token_id: None,
eos_token_id: None,
}
}
}
@ -106,7 +116,6 @@ impl Cache {
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
// This is different from the paper, see:
// https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
let cos = idx_theta.cos()?.to_dtype(dtype)?;
let sin = idx_theta.sin()?.to_dtype(dtype)?;
Ok(Self {
@ -166,16 +175,10 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Ten
impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> {
let _enter = self.span_rot.enter();
let (b_sz, _, seq_len, hidden_size) = x.dims4()?;
let (_b_sz, _, seq_len, _hidden_size) = x.dims4()?;
let cos = cache.cos.narrow(0, index_pos, seq_len)?;
let sin = cache.sin.narrow(0, index_pos, seq_len)?;
let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?;
let x2 = x.narrow(D::Minus1, hidden_size / 2, hidden_size / 2)?;
let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
Ok(rope)
candle_nn::rotary_emb::rope(x, &cos, &sin)
}
fn forward(
@ -193,10 +196,12 @@ impl CausalSelfAttention {
let q = q
.reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
.transpose(1, 2)?;
.transpose(1, 2)?
.contiguous()?;
let k = k
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
.transpose(1, 2)?;
.transpose(1, 2)?
.contiguous()?;
let mut v = v
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
.transpose(1, 2)?;
@ -256,17 +261,7 @@ impl CausalSelfAttention {
}
fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
let n_rep = self.num_attention_heads / self.num_key_value_heads;
if n_rep == 1 {
Ok(x)
} else {
let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
let x = x
.unsqueeze(2)?
.expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
Ok(x)
}
crate::utils::repeat_kv(x, self.num_attention_heads / self.num_key_value_heads)
}
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {

View File

@ -1,4 +1,3 @@
#![allow(unused)]
/// A fast implementation of mamba for inference only.
/// This is based on: https://github.com/LaurentMazare/mamba.rs
use crate::models::with_tracing::{linear, linear_no_bias, Linear};
@ -38,12 +37,12 @@ pub struct State {
}
impl State {
pub fn new(batch_size: usize, cfg: &Config, device: &Device) -> Result<Self> {
pub fn new(batch_size: usize, cfg: &Config, dtype: DType, device: &Device) -> Result<Self> {
let mut hs = Vec::with_capacity(cfg.n_layer);
let mut prev_xs = Vec::with_capacity(cfg.n_layer);
for _i in 0..cfg.n_layer {
let h = Tensor::zeros((batch_size, cfg.d_inner(), D_STATE), DType::F32, device)?;
let x = Tensor::zeros((batch_size, cfg.d_inner()), DType::F32, device)?;
let h = Tensor::zeros((batch_size, cfg.d_inner(), D_STATE), dtype, device)?;
let x = Tensor::zeros((batch_size, cfg.d_inner()), dtype, device)?;
hs.push(h);
prev_xs.push([x.clone(), x.clone(), x.clone(), x.clone()]);
}
@ -128,8 +127,8 @@ impl MambaBlock {
let delta = delta.apply(&self.dt_proj)?;
// softplus
let delta = (delta.exp()? + 1.)?.log()?;
let a = self.a_log.to_dtype(candle::DType::F32)?.exp()?.neg()?;
let d = self.d.to_dtype(candle::DType::F32)?;
let a = self.a_log.to_dtype(delta.dtype())?.exp()?.neg()?;
let d = self.d.to_dtype(delta.dtype())?;
// Selective scan part
// Eqn (2a), page 3, h_t = Ab h_{t-1} + Bb x_t
@ -178,6 +177,7 @@ pub struct Model {
layers: Vec<ResidualBlock>,
norm_f: RmsNorm,
lm_head: Linear,
dtype: DType,
}
impl Model {
@ -196,6 +196,7 @@ impl Model {
layers,
norm_f,
lm_head,
dtype: vb.dtype(),
})
}
@ -208,4 +209,8 @@ impl Model {
state.pos += 1;
xs.apply(&self.norm_f)?.apply(&self.lm_head)
}
pub fn dtype(&self) -> DType {
self.dtype
}
}

View File

@ -216,18 +216,6 @@ impl Attention {
})
}
fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
let n_rep = self.num_kv_groups;
if n_rep == 1 {
Ok(xs)
} else {
let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
xs.unsqueeze(2)?
.expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
.reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
}
}
fn forward(
&mut self,
xs: &Tensor,
@ -266,8 +254,8 @@ impl Attention {
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));
let key_states = self.repeat_kv(key_states)?;
let value_states = self.repeat_kv(value_states)?;
let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
let attn_output = if self.use_flash_attn {
// flash-attn expects (b_sz, seq_len, nheads, head_dim)

View File

@ -158,18 +158,6 @@ impl Attention {
})
}
fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
let n_rep = self.num_kv_groups;
if n_rep == 1 {
Ok(xs)
} else {
let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
xs.unsqueeze(2)?
.expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
.reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
}
}
fn forward(
&mut self,
xs: &Tensor,
@ -206,8 +194,8 @@ impl Attention {
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));
let key_states = self.repeat_kv(key_states)?;
let value_states = self.repeat_kv(value_states)?;
let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
let attn_output = if self.use_flash_attn {
// flash-attn expects (b_sz, seq_len, nheads, head_dim)

View File

@ -37,12 +37,14 @@ pub mod quantized_mistral;
pub mod quantized_mixformer;
pub mod quantized_moondream;
pub mod quantized_mpt;
pub mod quantized_recurrent_gemma;
pub mod quantized_rwkv_v5;
pub mod quantized_rwkv_v6;
pub mod quantized_stable_lm;
pub mod quantized_t5;
pub mod qwen2;
pub mod qwen2_moe;
pub mod recurrent_gemma;
pub mod repvgg;
pub mod resnet;
pub mod rwkv_v5;

View File

@ -104,8 +104,8 @@ impl GroupedQueryAttention {
};
self.kv_cache = Some((key.clone(), value.clone()));
let query = query.contiguous()?;
let key = repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?;
let value = repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?;
let key = crate::utils::repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?;
let value = crate::utils::repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?;
let attn_weights = (query.matmul(&key)? * self.softmax_scale)?;
let attn_bias = {
let s_q = query.dim(D::Minus2)?;
@ -134,20 +134,6 @@ impl GroupedQueryAttention {
}
}
// This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
// The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
// (batch, num_attention_heads, seqlen, head_dim)
pub(crate) fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
if n_rep == 1 {
Ok(xs)
} else {
let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
xs.unsqueeze(2)?
.expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
.reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
}
}
#[derive(Debug, Clone)]
struct Ffn {
up_proj: Linear,

View File

@ -174,15 +174,7 @@ impl Attention {
}
fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
let n_rep = self.num_heads / self.num_kv_heads;
if n_rep == 1 {
Ok(xs)
} else {
let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
xs.unsqueeze(2)?
.expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
.reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
}
crate::utils::repeat_kv(xs, self.num_heads / self.num_kv_heads)
}
fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {

View File

@ -205,9 +205,9 @@ impl LayerWeights {
};
self.kv_cache = Some((k.clone(), v.clone()));
// Support for MQA, useful for 70B models.
let k = self.repeat_kv(k)?;
let v = self.repeat_kv(v)?;
// Support for MQA, useful for 70B models and mistral.
let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let att = match mask {
@ -224,20 +224,6 @@ impl LayerWeights {
let y = self.attention_wo.forward(&y)?;
Ok(y)
}
fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
let n_rep = self.n_head / self.n_kv_head;
if n_rep == 1 {
Ok(x)
} else {
let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
let x = x
.unsqueeze(2)?
.expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
Ok(x)
}
}
}
#[derive(Debug, Clone)]

View File

@ -235,6 +235,7 @@ pub mod transformer {
xs = layer.forward(&xs, pos, &mask)?
}
xs.narrow(1, seqlen - 1, 1)?
.contiguous()?
.apply(&self.norm)?
.apply(&self.output)
}

View File

@ -122,18 +122,6 @@ impl Attention {
})
}
fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
let n_rep = self.num_kv_groups;
if n_rep == 1 {
Ok(xs)
} else {
let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
xs.unsqueeze(2)?
.expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
.reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
}
}
fn forward(
&mut self,
xs: &Tensor,
@ -172,8 +160,8 @@ impl Attention {
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));
let key_states = self.repeat_kv(key_states)?;
let value_states = self.repeat_kv(value_states)?;
let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
let attn_output = {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);

View File

@ -71,8 +71,8 @@ impl GroupedQueryAttention {
};
self.kv_cache = Some((key.clone(), value.clone()));
let query = query.contiguous()?;
let key = super::mpt::repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?;
let value = super::mpt::repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?;
let key = crate::utils::repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?;
let value = crate::utils::repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?;
let attn_weights = (query.matmul(&key)? * self.softmax_scale)?;
let attn_bias = {
let s_q = query.dim(D::Minus2)?;

View File

@ -0,0 +1,412 @@
use crate::quantized_nn::{linear_b as linear, Embedding, Linear};
pub use crate::quantized_var_builder::VarBuilder;
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use std::sync::Arc;
use crate::models::recurrent_gemma::{Config, Rglru, RmsNorm, RotaryEmbedding, TemporalBlockType};
fn rms_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<RmsNorm> {
let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
Ok(RmsNorm::from_weight(weight, eps))
}
#[derive(Debug, Clone)]
struct Mlp {
gate_proj: Linear,
up_proj: Linear,
down_proj: Linear,
act_fn: candle_nn::Activation,
}
impl Mlp {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let h = cfg.hidden_size;
let intermediate_size = cfg.intermediate_size / 2;
let gate_proj = linear(h, intermediate_size, true, vb.pp("gate_proj"))?;
let up_proj = linear(h, intermediate_size, true, vb.pp("up_proj"))?;
let down_proj = linear(intermediate_size, h, true, vb.pp("down_proj"))?;
Ok(Self {
gate_proj,
up_proj,
down_proj,
act_fn: cfg.hidden_activation,
})
}
}
impl Module for Mlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let gate = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
(gate * xs.apply(&self.up_proj))?.apply(&self.down_proj)
}
}
fn rglru(cfg: &Config, vb: VarBuilder) -> Result<Rglru> {
let h = cfg.hidden_size;
let lru_width = cfg.lru_width.unwrap_or(h);
let n_heads = cfg.num_attention_heads;
let block_width = lru_width / n_heads;
let recurrent_param = vb.get((lru_width,), "recurrent_param")?;
let input_gate_weight = vb.get((n_heads, block_width, block_width), "input_gate_weight")?;
let input_gate_bias = vb.get((n_heads, block_width), "input_gate_bias")?;
let recurrent_gate_weight =
vb.get((n_heads, block_width, block_width), "recurrent_gate_weight")?;
let recurrent_gate_bias = vb.get((n_heads, block_width), "recurrent_gate_bias")?;
Ok(Rglru {
recurrent_param: recurrent_param.dequantize(vb.device())?,
input_gate_bias: input_gate_bias.dequantize(vb.device())?,
input_gate_weight: input_gate_weight.dequantize(vb.device())?,
recurrent_gate_bias: recurrent_gate_bias.dequantize(vb.device())?,
recurrent_gate_weight: recurrent_gate_weight.dequantize(vb.device())?,
block_width,
n_heads,
recurrent_states: None,
})
}
#[derive(Debug, Clone)]
struct RecurrentBlock {
linear_y: Linear,
linear_x: Linear,
linear_out: Linear,
conv_1d: candle_nn::Conv1d,
conv1d_state: Option<Tensor>,
conv1d_width: usize,
rg_lru: Rglru,
act_fn: candle_nn::Activation,
}
impl RecurrentBlock {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let h = cfg.hidden_size;
let lru_width = cfg.lru_width.unwrap_or(h);
let linear_y = linear(h, lru_width, true, vb.pp("linear_y"))?;
let linear_x = linear(h, lru_width, true, vb.pp("linear_x"))?;
let linear_out = linear(lru_width, h, true, vb.pp("linear_out"))?;
let conv_1d = {
let ws = vb
.get((lru_width, 1, cfg.conv1d_width), "conv_1d.weight")?
.dequantize(vb.device())?;
let bs = vb.get(lru_width, "conv_1d.bias")?.dequantize(vb.device())?;
let config = candle_nn::Conv1dConfig {
groups: lru_width,
padding: cfg.conv1d_width - 1,
..Default::default()
};
candle_nn::Conv1d::new(ws, Some(bs), config)
};
let rg_lru = rglru(cfg, vb.pp("rg_lru"))?;
Ok(Self {
linear_y,
linear_x,
linear_out,
conv_1d,
conv1d_state: None,
conv1d_width: cfg.conv1d_width,
rg_lru,
act_fn: cfg.hidden_activation,
})
}
pub fn forward(&mut self, xs: &Tensor, pos: usize) -> Result<Tensor> {
let (_b_sz, seq_len, _) = xs.dims3()?;
let y_branch = xs.apply(&self.linear_y)?.apply(&self.act_fn)?;
let x_branch = xs.apply(&self.linear_x)?.transpose(1, 2)?;
let x_branch = if pos == 0 {
let x_len = x_branch.dim(D::Minus1)?;
let pad = self.conv1d_width as i64 - x_len as i64 - 1;
let padded = match pad.cmp(&0) {
std::cmp::Ordering::Equal => x_branch.clone(),
std::cmp::Ordering::Less => {
let rev_pad = (-pad) as usize;
x_branch.narrow(D::Minus1, rev_pad, x_len - rev_pad)?
}
std::cmp::Ordering::Greater => {
x_branch.pad_with_zeros(D::Minus1, pad as usize, 0)?
}
};
self.conv1d_state = Some(padded);
x_branch
.apply(&self.conv_1d)?
.narrow(D::Minus1, 0, seq_len)?
} else {
let conv_state = match self.conv1d_state.as_ref() {
None => candle::bail!("empty cache despite pos > 0"),
Some(s) => Tensor::cat(&[s, &x_branch], D::Minus1)?,
};
let w = self.conv_1d.weight().i((.., 0, ..))?;
let x_branch = conv_state.broadcast_mul(&w)?.sum(D::Minus1)?;
let x_branch = match self.conv_1d.bias() {
None => x_branch,
Some(b) => x_branch.broadcast_add(b)?,
};
let x_branch = x_branch.unsqueeze(D::Minus1)?;
self.conv1d_state = Some(conv_state.i((.., .., 1..))?);
x_branch
};
let x_branch = x_branch.transpose(1, 2)?;
let x_branch = self.rg_lru.forward(&x_branch, pos)?;
(x_branch * y_branch)?.apply(&self.linear_out)
}
}
#[derive(Debug, Clone)]
struct SdpaAttention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
n_heads: usize,
n_kv_heads: usize,
head_dim: usize,
hidden_size: usize,
kv_cache: Option<(Tensor, Tensor)>,
rotary_emb: Arc<RotaryEmbedding>,
}
impl SdpaAttention {
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let h = cfg.hidden_size;
let n_heads = cfg.num_attention_heads;
let n_kv_heads = cfg.num_key_value_heads;
let hd = cfg.head_dim;
let q_proj = linear(h, n_heads * hd, cfg.attention_bias, vb.pp("q_proj"))?;
let k_proj = linear(h, n_kv_heads * hd, cfg.attention_bias, vb.pp("k_proj"))?;
let v_proj = linear(h, n_kv_heads * hd, cfg.attention_bias, vb.pp("v_proj"))?;
let o_proj = linear(n_heads * hd, h, true, vb.pp("o_proj"))?;
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
n_heads,
n_kv_heads,
head_dim: hd,
hidden_size: h,
kv_cache: None,
rotary_emb,
})
}
fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
let n_rep = self.n_heads / self.n_kv_heads;
crate::utils::repeat_kv(x, n_rep)
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
pos: usize,
) -> Result<Tensor> {
let (bsz, q_len, _) = xs.dims3()?;
let query_states = xs.apply(&self.q_proj)?;
let key_states = xs.apply(&self.k_proj)?;
let value_states = xs.apply(&self.v_proj)?;
let query_states = query_states
.reshape((bsz, q_len, self.n_heads, self.head_dim))?
.transpose(1, 2)?;
let key_states = key_states
.reshape((bsz, q_len, self.n_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let value_states = value_states
.reshape((bsz, q_len, self.n_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let query_states = query_states.chunk(2, D::Minus1)?;
let key_states = key_states.chunk(2, D::Minus1)?;
let (query_rot, key_rot) =
self.rotary_emb
.apply_rotary_emb_qkv(&query_states[0], &key_states[0], pos)?;
let query_states = Tensor::cat(&[&query_rot, &query_states[1]], D::Minus1)?.contiguous()?;
let key_states = Tensor::cat(&[&key_rot, &key_states[1]], D::Minus1)?.contiguous()?;
let (key_states, value_states) = match &self.kv_cache {
None => (key_states, value_states),
Some((prev_k, prev_v)) => {
let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
(key_states, value_states)
}
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));
let key_states = self.repeat_kv(key_states)?;
let value_states = self.repeat_kv(value_states)?;
let xs = {
let att = (query_states.matmul(&key_states.t()?)? / (self.head_dim as f64).sqrt())?;
let att = if q_len == 1 {
att
} else {
match attention_mask {
None => att,
Some(mask) => att.broadcast_add(mask)?,
}
};
let att = candle_nn::ops::softmax_last_dim(&att)?;
att.matmul(&value_states.contiguous()?)?
};
let xs = xs
.transpose(1, 2)?
.reshape((bsz, q_len, self.hidden_size))?;
self.o_proj.forward(&xs)
}
}
#[derive(Debug, Clone)]
enum TemporalBlock {
Recurrent(RecurrentBlock),
Attention(SdpaAttention),
}
impl TemporalBlock {
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
pos: usize,
) -> Result<Tensor> {
match self {
Self::Recurrent(b) => b.forward(xs, pos),
Self::Attention(b) => b.forward(xs, attention_mask, pos),
}
}
}
#[derive(Debug, Clone)]
struct DecoderLayer {
temporal_pre_norm: RmsNorm,
channel_pre_norm: RmsNorm,
temporal_block: TemporalBlock,
mlp_block: Mlp,
}
impl DecoderLayer {
fn new(
block_idx: usize,
rotary_emb: Arc<RotaryEmbedding>,
cfg: &Config,
vb: VarBuilder,
) -> Result<Self> {
let h = cfg.hidden_size;
let temporal_pre_norm = rms_norm(h, cfg.rms_norm_eps, vb.pp("temporal_pre_norm"))?;
let channel_pre_norm = rms_norm(h, cfg.rms_norm_eps, vb.pp("channel_pre_norm"))?;
let temporal_block = match cfg.block_types[block_idx % cfg.block_types.len()] {
TemporalBlockType::Recurrent => {
let block = RecurrentBlock::new(cfg, vb.pp("temporal_block"))?;
TemporalBlock::Recurrent(block)
}
TemporalBlockType::Attention => {
let block = SdpaAttention::new(rotary_emb, cfg, vb.pp("temporal_block"))?;
TemporalBlock::Attention(block)
}
};
let mlp_block = Mlp::new(cfg, vb.pp("mlp_block"))?;
Ok(Self {
temporal_pre_norm,
channel_pre_norm,
temporal_block,
mlp_block,
})
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
pos: usize,
) -> Result<Tensor> {
let residual = xs;
let xs = xs.apply(&self.temporal_pre_norm)?;
let xs = self.temporal_block.forward(&xs, attention_mask, pos)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = xs.apply(&self.channel_pre_norm)?.apply(&self.mlp_block)?;
xs + residual
}
}
#[derive(Debug, Clone)]
pub struct Model {
embed_tokens: Embedding,
layers: Vec<DecoderLayer>,
final_norm: RmsNorm,
lm_head: Linear,
hidden_size: usize,
logits_soft_cap: f64,
device: Device,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let embed_tokens = Embedding::new(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?;
let rotary_emb = Arc::new(RotaryEmbedding::new(DType::F32, cfg, vb.device())?);
let vb_b = vb.pp("layers");
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
for idx in 0..cfg.num_hidden_layers {
let layer = DecoderLayer::new(idx, rotary_emb.clone(), cfg, vb_b.pp(idx))?;
layers.push(layer)
}
let final_norm = rms_norm(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("final_norm"))?;
let lm_head = linear(
cfg.hidden_size,
cfg.vocab_size,
false,
vb.pp("embed_tokens"),
)?;
Ok(Self {
embed_tokens,
layers,
final_norm,
lm_head,
hidden_size: cfg.hidden_size,
logits_soft_cap: cfg.logits_soft_cap,
device: vb.device().clone(),
})
}
fn prepare_decoder_attention_mask(
&self,
b_size: usize,
tgt_len: usize,
seqlen_offset: usize,
) -> Result<Tensor> {
let mask: Vec<_> = (0..tgt_len)
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
.collect();
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
let mask = if seqlen_offset > 0 {
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
};
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
.to_dtype(DType::F32)
}
pub fn forward(&mut self, xs: &Tensor, pos: usize) -> Result<Tensor> {
let (b_size, seq_len) = xs.dims2()?;
let attention_mask = if seq_len <= 1 {
None
} else {
let mask = self.prepare_decoder_attention_mask(b_size, seq_len, pos)?;
Some(mask)
};
let xs = xs.apply(&self.embed_tokens)?;
let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, attention_mask.as_ref(), pos)?;
}
let logits = xs
.narrow(1, seq_len - 1, 1)?
.apply(&self.final_norm)?
.apply(&self.lm_head)?;
let logits = ((logits / self.logits_soft_cap)?.tanh()? * self.logits_soft_cap)?;
Ok(logits)
}
}

View File

@ -94,18 +94,6 @@ impl Attention {
})
}
fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
let n_rep = self.num_kv_groups;
if n_rep == 1 {
Ok(xs)
} else {
let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
xs.unsqueeze(2)?
.expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
.reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
}
}
fn forward(
&mut self,
xs: &Tensor,
@ -152,8 +140,9 @@ impl Attention {
self.kv_cache = Some((key_states.clone(), value_states.clone()));
}
let key_states = self.repeat_kv(key_states)?.contiguous()?;
let value_states = self.repeat_kv(value_states)?.contiguous()?;
let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
let value_states =
crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
let attn_output = {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);

View File

@ -146,18 +146,6 @@ impl Attention {
})
}
fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
let n_rep = self.num_kv_groups;
if n_rep == 1 {
Ok(xs)
} else {
let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
xs.unsqueeze(2)?
.expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
.reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
}
}
fn forward(
&mut self,
xs: &Tensor,
@ -194,8 +182,9 @@ impl Attention {
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));
let key_states = self.repeat_kv(key_states)?.contiguous()?;
let value_states = self.repeat_kv(value_states)?.contiguous()?;
let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
let value_states =
crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
let attn_output = {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);

View File

@ -151,18 +151,6 @@ impl Attention {
})
}
fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
let n_rep = self.num_kv_groups;
if n_rep == 1 {
Ok(xs)
} else {
let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
xs.unsqueeze(2)?
.expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
.reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
}
}
fn forward(
&mut self,
xs: &Tensor,
@ -199,8 +187,9 @@ impl Attention {
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));
let key_states = self.repeat_kv(key_states)?.contiguous()?;
let value_states = self.repeat_kv(value_states)?.contiguous()?;
let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
let value_states =
crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
let attn_output = {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);

View File

@ -0,0 +1,643 @@
// This implementation is based on the python version from huggingface/transformers.
// https://github.com/huggingface/transformers/blob/b109257f4fb8b1166e7c53cc5418632014ed53a5/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py#L2
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{linear_b as linear, Linear, VarBuilder};
use std::sync::Arc;
#[derive(serde::Deserialize, Debug, Clone, Copy)]
#[serde(rename_all = "snake_case")]
pub enum TemporalBlockType {
Attention,
Recurrent,
}
#[derive(serde::Deserialize, Debug, Clone)]
pub struct Config {
pub num_hidden_layers: usize,
pub vocab_size: usize,
pub hidden_size: usize,
pub intermediate_size: usize,
pub num_attention_heads: usize,
pub num_key_value_heads: usize,
pub head_dim: usize,
pub lru_width: Option<usize>,
pub attention_window_size: usize,
pub conv1d_width: usize,
pub logits_soft_cap: f64,
pub hidden_activation: candle_nn::Activation,
pub partial_rotary_factor: f64,
pub rms_norm_eps: f64,
pub rope_theta: f64,
#[serde(alias = "_block_types")]
pub block_types: Vec<TemporalBlockType>,
pub attention_bias: bool,
#[serde(default = "default_max_seq_len")]
pub max_seq_len: usize,
}
fn default_max_seq_len() -> usize {
8192
}
#[derive(Debug, Clone)]
pub(crate) struct RmsNorm {
weight: Tensor,
eps: f64,
}
impl RmsNorm {
pub(crate) fn new(dim: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
let weight = vb.get(dim, "weight")?;
Ok(Self { weight, eps })
}
pub(crate) fn from_weight(weight: Tensor, eps: f64) -> Self {
Self { weight, eps }
}
}
impl Module for RmsNorm {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x_dtype = x.dtype();
let internal_dtype = match x_dtype {
DType::F16 | DType::BF16 => DType::F32,
d => d,
};
let hidden_size = x.dim(D::Minus1)?;
let x = x.to_dtype(internal_dtype)?;
let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
x_normed
.to_dtype(x_dtype)?
.broadcast_mul(&(&self.weight + 1.0)?)
}
}
#[derive(Debug, Clone)]
pub(crate) struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
}
fn rotate_half(xs: &Tensor) -> Result<Tensor> {
let last_dim = xs.dim(D::Minus1)?;
let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
}
impl RotaryEmbedding {
pub(crate) fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
if cfg.partial_rotary_factor != 0.5 {
candle::bail!("partial-rotary-factor {} <> 0.5", cfg.partial_rotary_factor)
}
let dim = cfg.head_dim / 2;
let max_seq_len = cfg.max_seq_len;
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
})
}
pub(crate) fn apply_rotary_emb_qkv(
&self,
q: &Tensor,
k: &Tensor,
seqlen_offset: usize,
) -> Result<(Tensor, Tensor)> {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
Ok((q_embed, k_embed))
}
}
#[derive(Debug, Clone)]
struct Mlp {
gate_proj: Linear,
up_proj: Linear,
down_proj: Linear,
act_fn: candle_nn::Activation,
}
impl Mlp {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let h = cfg.hidden_size;
let intermediate_size = cfg.intermediate_size / 2;
let gate_proj = linear(h, intermediate_size, true, vb.pp("gate_proj"))?;
let up_proj = linear(h, intermediate_size, true, vb.pp("up_proj"))?;
let down_proj = linear(intermediate_size, h, true, vb.pp("down_proj"))?;
Ok(Self {
gate_proj,
up_proj,
down_proj,
act_fn: cfg.hidden_activation,
})
}
}
impl Module for Mlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let gate = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
(gate * xs.apply(&self.up_proj))?.apply(&self.down_proj)
}
}
// Real-Gated Linear Recurrent Unit
#[derive(Debug, Clone)]
pub(crate) struct Rglru {
pub(crate) recurrent_param: Tensor,
pub(crate) input_gate_weight: Tensor,
pub(crate) input_gate_bias: Tensor,
pub(crate) recurrent_gate_weight: Tensor,
pub(crate) recurrent_gate_bias: Tensor,
pub(crate) block_width: usize,
pub(crate) n_heads: usize,
pub(crate) recurrent_states: Option<Tensor>,
}
fn baddbmm(a: &Tensor, b: &Tensor, c: &Tensor) -> Result<Tensor> {
a.broadcast_add(&b.matmul(c)?)
}
fn softplus(xs: &Tensor) -> Result<Tensor> {
(xs.exp()? + 1.0)?.log()
}
impl Rglru {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let h = cfg.hidden_size;
let lru_width = cfg.lru_width.unwrap_or(h);
let n_heads = cfg.num_attention_heads;
let block_width = lru_width / n_heads;
let recurrent_param = vb.get((lru_width,), "recurrent_param")?;
let input_gate_weight = vb.get((n_heads, block_width, block_width), "input_gate_weight")?;
let input_gate_bias = vb.get((n_heads, block_width), "input_gate_bias")?;
let recurrent_gate_weight =
vb.get((n_heads, block_width, block_width), "recurrent_gate_weight")?;
let recurrent_gate_bias = vb.get((n_heads, block_width), "recurrent_gate_bias")?;
Ok(Self {
recurrent_param,
input_gate_bias,
input_gate_weight,
recurrent_gate_bias,
recurrent_gate_weight,
block_width,
n_heads,
recurrent_states: None,
})
}
// https://github.com/huggingface/transformers/blob/0bd58f1ce0573c0e3269de4215a17d318add49b9/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py#L303
pub(crate) fn forward(&mut self, xs: &Tensor, pos: usize) -> Result<Tensor> {
let (b_sz, seq_len, lru_width) = xs.dims3()?;
let pos = Tensor::arange(pos as u32, (pos + seq_len) as u32, xs.device())?;
let reset = pos.eq(0u32)?.unsqueeze(1)?.unsqueeze(0)?;
let reshape_act = xs
.reshape((b_sz * seq_len, self.n_heads, self.block_width))?
.permute((1, 0, 2))?
.contiguous()?;
let res = baddbmm(
&self.input_gate_bias.unsqueeze(1)?,
&reshape_act,
&self.input_gate_weight,
)?;
let input_gate = res.transpose(0, 1)?.reshape((b_sz, seq_len, lru_width))?;
let input_gate = candle_nn::ops::sigmoid(&input_gate)?;
let res = baddbmm(
&self.recurrent_gate_bias.unsqueeze(1)?,
&reshape_act,
&self.recurrent_gate_weight,
)?;
let recurrent_gate = res.transpose(0, 1)?.reshape((b_sz, seq_len, lru_width))?;
let recurrent_gate = candle_nn::ops::sigmoid(&recurrent_gate)?;
let log_recurrent_gate =
(recurrent_gate * (-8.0))?.broadcast_mul(&softplus(&self.recurrent_param)?)?;
let recurrent_gate = log_recurrent_gate.exp()?;
let a_square = (log_recurrent_gate * 2.)?.exp()?;
// Gate the input.
let gated_inputs = (xs * input_gate)?;
let reset = reset.to_dtype(a_square.dtype())?;
let multiplier =
reset.broadcast_add(&((1.0 - &reset)?.broadcast_mul(&(1.0 - a_square)?.sqrt()?))?)?;
let normalized_x = (gated_inputs * multiplier.to_dtype(xs.dtype()))?;
let (hidden_states, recurrent_states) = rnn_scan(
&normalized_x,
&recurrent_gate,
&reset,
self.recurrent_states.as_ref(),
)?;
self.recurrent_states = Some(recurrent_states);
Ok(hidden_states)
}
}
fn rnn_scan(
hidden_states: &Tensor,
recurrent_gate: &Tensor,
reset: &Tensor,
recurrent_states: Option<&Tensor>,
) -> Result<(Tensor, Tensor)> {
let acc_dtype = DType::F32;
let dev = hidden_states.device();
let in_dtype = hidden_states.dtype();
let inv_reset = (1.0 - reset)?.to_dtype(recurrent_gate.dtype())?;
let recurrent_gate = recurrent_gate.broadcast_mul(&inv_reset)?;
let (c, r) = if hidden_states.dim(1)? == 1 {
match recurrent_states {
None => {
let next_state = hidden_states.i((.., 0))?.to_dtype(acc_dtype)?;
(hidden_states.clone(), next_state)
}
Some(recurrent_states) => {
let contextualized_states =
recurrent_gate.to_dtype(acc_dtype)? * recurrent_states.unsqueeze(1)?;
let contextualized_states =
(contextualized_states + hidden_states.to_dtype(acc_dtype)?)?;
let c = contextualized_states.to_dtype(in_dtype)?;
let l = contextualized_states.dim(1)?;
let r = contextualized_states.i((.., l - 1))?;
(c, r)
}
}
} else {
let mut recurrent_states = match recurrent_states {
None => Tensor::zeros(hidden_states.i((.., 0))?.shape(), acc_dtype, dev)?,
Some(r) => r.clone(),
};
let mut contextualized_states = vec![];
for t in 0..hidden_states.dim(1)? {
recurrent_states =
(recurrent_gate.i((.., t))?.to_dtype(acc_dtype)? * recurrent_states)?;
recurrent_states =
(recurrent_states + hidden_states.i((.., t))?.to_dtype(acc_dtype)?)?;
contextualized_states.push(recurrent_states.to_dtype(in_dtype)?)
}
let contextualized_states = Tensor::stack(&contextualized_states, 1)?;
(contextualized_states, recurrent_states)
};
Ok((c, r))
}
#[derive(Debug, Clone)]
struct RecurrentBlock {
linear_y: Linear,
linear_x: Linear,
linear_out: Linear,
conv_1d: candle_nn::Conv1d,
conv1d_state: Option<Tensor>,
conv1d_width: usize,
rg_lru: Rglru,
act_fn: candle_nn::Activation,
}
impl RecurrentBlock {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let h = cfg.hidden_size;
let lru_width = cfg.lru_width.unwrap_or(h);
let linear_y = linear(h, lru_width, true, vb.pp("linear_y"))?;
let linear_x = linear(h, lru_width, true, vb.pp("linear_x"))?;
let linear_out = linear(lru_width, h, true, vb.pp("linear_out"))?;
let conv_1d = candle_nn::conv1d(
lru_width,
lru_width,
cfg.conv1d_width,
candle_nn::Conv1dConfig {
groups: lru_width,
padding: cfg.conv1d_width - 1,
..Default::default()
},
vb.pp("conv_1d"),
)?;
let rg_lru = Rglru::new(cfg, vb.pp("rg_lru"))?;
Ok(Self {
linear_y,
linear_x,
linear_out,
conv_1d,
conv1d_state: None,
conv1d_width: cfg.conv1d_width,
rg_lru,
act_fn: cfg.hidden_activation,
})
}
pub fn forward(&mut self, xs: &Tensor, pos: usize) -> Result<Tensor> {
let (_b_sz, seq_len, _) = xs.dims3()?;
let y_branch = xs.apply(&self.linear_y)?.apply(&self.act_fn)?;
let x_branch = xs.apply(&self.linear_x)?.transpose(1, 2)?;
let x_branch = if pos == 0 {
let x_len = x_branch.dim(D::Minus1)?;
let pad = self.conv1d_width as i64 - x_len as i64 - 1;
let padded = match pad.cmp(&0) {
std::cmp::Ordering::Equal => x_branch.clone(),
std::cmp::Ordering::Less => {
let rev_pad = (-pad) as usize;
x_branch.narrow(D::Minus1, rev_pad, x_len - rev_pad)?
}
std::cmp::Ordering::Greater => {
x_branch.pad_with_zeros(D::Minus1, pad as usize, 0)?
}
};
self.conv1d_state = Some(padded);
x_branch
.apply(&self.conv_1d)?
.narrow(D::Minus1, 0, seq_len)?
} else {
let conv_state = match self.conv1d_state.as_ref() {
None => candle::bail!("empty cache despite pos > 0"),
Some(s) => Tensor::cat(&[s, &x_branch], D::Minus1)?,
};
let w = self.conv_1d.weight().i((.., 0, ..))?;
let x_branch = conv_state.broadcast_mul(&w)?.sum(D::Minus1)?;
let x_branch = match self.conv_1d.bias() {
None => x_branch,
Some(b) => x_branch.broadcast_add(b)?,
};
let x_branch = x_branch.unsqueeze(D::Minus1)?;
self.conv1d_state = Some(conv_state.i((.., .., 1..))?);
x_branch
};
let x_branch = x_branch.transpose(1, 2)?;
let x_branch = self.rg_lru.forward(&x_branch, pos)?;
(x_branch * y_branch)?.apply(&self.linear_out)
}
}
#[derive(Debug, Clone)]
struct SdpaAttention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
n_heads: usize,
n_kv_heads: usize,
head_dim: usize,
hidden_size: usize,
kv_cache: Option<(Tensor, Tensor)>,
rotary_emb: Arc<RotaryEmbedding>,
}
impl SdpaAttention {
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let h = cfg.hidden_size;
let n_heads = cfg.num_attention_heads;
let n_kv_heads = cfg.num_key_value_heads;
let hd = cfg.head_dim;
let q_proj = linear(h, n_heads * hd, cfg.attention_bias, vb.pp("q_proj"))?;
let k_proj = linear(h, n_kv_heads * hd, cfg.attention_bias, vb.pp("k_proj"))?;
let v_proj = linear(h, n_kv_heads * hd, cfg.attention_bias, vb.pp("v_proj"))?;
let o_proj = linear(n_heads * hd, h, true, vb.pp("o_proj"))?;
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
n_heads,
n_kv_heads,
head_dim: hd,
hidden_size: h,
kv_cache: None,
rotary_emb,
})
}
fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
let n_rep = self.n_heads / self.n_kv_heads;
crate::utils::repeat_kv(x, n_rep)
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
pos: usize,
) -> Result<Tensor> {
let (bsz, q_len, _) = xs.dims3()?;
let query_states = xs.apply(&self.q_proj)?;
let key_states = xs.apply(&self.k_proj)?;
let value_states = xs.apply(&self.v_proj)?;
let query_states = query_states
.reshape((bsz, q_len, self.n_heads, self.head_dim))?
.transpose(1, 2)?;
let key_states = key_states
.reshape((bsz, q_len, self.n_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let value_states = value_states
.reshape((bsz, q_len, self.n_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let query_states = query_states.chunk(2, D::Minus1)?;
let key_states = key_states.chunk(2, D::Minus1)?;
let (query_rot, key_rot) =
self.rotary_emb
.apply_rotary_emb_qkv(&query_states[0], &key_states[0], pos)?;
let query_states = Tensor::cat(&[&query_rot, &query_states[1]], D::Minus1)?.contiguous()?;
let key_states = Tensor::cat(&[&key_rot, &key_states[1]], D::Minus1)?.contiguous()?;
let (key_states, value_states) = match &self.kv_cache {
None => (key_states, value_states),
Some((prev_k, prev_v)) => {
let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
(key_states, value_states)
}
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));
let key_states = self.repeat_kv(key_states)?;
let value_states = self.repeat_kv(value_states)?;
let xs = {
let att = (query_states.matmul(&key_states.t()?)? / (self.head_dim as f64).sqrt())?;
let att = if q_len == 1 {
att
} else {
match attention_mask {
None => att,
Some(mask) => att.broadcast_add(mask)?,
}
};
let att = candle_nn::ops::softmax_last_dim(&att)?;
att.matmul(&value_states.contiguous()?)?
};
let xs = xs
.transpose(1, 2)?
.reshape((bsz, q_len, self.hidden_size))?;
self.o_proj.forward(&xs)
}
}
#[derive(Debug, Clone)]
enum TemporalBlock {
Recurrent(RecurrentBlock),
Attention(SdpaAttention),
}
impl TemporalBlock {
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
pos: usize,
) -> Result<Tensor> {
match self {
Self::Recurrent(b) => b.forward(xs, pos),
Self::Attention(b) => b.forward(xs, attention_mask, pos),
}
}
}
#[derive(Debug, Clone)]
struct DecoderLayer {
temporal_pre_norm: RmsNorm,
channel_pre_norm: RmsNorm,
temporal_block: TemporalBlock,
mlp_block: Mlp,
}
impl DecoderLayer {
fn new(
block_idx: usize,
rotary_emb: Arc<RotaryEmbedding>,
cfg: &Config,
vb: VarBuilder,
) -> Result<Self> {
let h = cfg.hidden_size;
let temporal_pre_norm = RmsNorm::new(h, cfg.rms_norm_eps, vb.pp("temporal_pre_norm"))?;
let channel_pre_norm = RmsNorm::new(h, cfg.rms_norm_eps, vb.pp("channel_pre_norm"))?;
let temporal_block = match cfg.block_types[block_idx % cfg.block_types.len()] {
TemporalBlockType::Recurrent => {
let block = RecurrentBlock::new(cfg, vb.pp("temporal_block"))?;
TemporalBlock::Recurrent(block)
}
TemporalBlockType::Attention => {
let block = SdpaAttention::new(rotary_emb, cfg, vb.pp("temporal_block"))?;
TemporalBlock::Attention(block)
}
};
let mlp_block = Mlp::new(cfg, vb.pp("mlp_block"))?;
Ok(Self {
temporal_pre_norm,
channel_pre_norm,
temporal_block,
mlp_block,
})
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
pos: usize,
) -> Result<Tensor> {
let residual = xs;
let xs = xs.apply(&self.temporal_pre_norm)?;
let xs = self.temporal_block.forward(&xs, attention_mask, pos)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = xs.apply(&self.channel_pre_norm)?.apply(&self.mlp_block)?;
xs + residual
}
}
#[derive(Debug, Clone)]
pub struct Model {
embed_tokens: candle_nn::Embedding,
layers: Vec<DecoderLayer>,
final_norm: RmsNorm,
lm_head: Linear,
hidden_size: usize,
logits_soft_cap: f64,
dtype: DType,
device: Device,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let embed_tokens =
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?;
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?);
let vb_b = vb.pp("layers");
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
for idx in 0..cfg.num_hidden_layers {
let layer = DecoderLayer::new(idx, rotary_emb.clone(), cfg, vb_b.pp(idx))?;
layers.push(layer)
}
let final_norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("final_norm"))?;
let lm_head = Linear::new(embed_tokens.embeddings().clone(), None);
Ok(Self {
embed_tokens,
layers,
final_norm,
lm_head,
hidden_size: cfg.hidden_size,
logits_soft_cap: cfg.logits_soft_cap,
dtype: vb.dtype(),
device: vb.device().clone(),
})
}
fn prepare_decoder_attention_mask(
&self,
b_size: usize,
tgt_len: usize,
seqlen_offset: usize,
) -> Result<Tensor> {
let mask: Vec<_> = (0..tgt_len)
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
.collect();
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
let mask = if seqlen_offset > 0 {
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
};
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
.to_dtype(self.dtype)
}
pub fn forward(&mut self, xs: &Tensor, pos: usize) -> Result<Tensor> {
let (b_size, seq_len) = xs.dims2()?;
let attention_mask = if seq_len <= 1 {
None
} else {
let mask = self.prepare_decoder_attention_mask(b_size, seq_len, pos)?;
Some(mask)
};
let xs = xs.apply(&self.embed_tokens)?;
let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, attention_mask.as_ref(), pos)?;
}
let logits = xs
.narrow(1, seq_len - 1, 1)?
.apply(&self.final_norm)?
.apply(&self.lm_head)?;
let logits = ((logits / self.logits_soft_cap)?.tanh()? * self.logits_soft_cap)?;
Ok(logits)
}
}

View File

@ -217,18 +217,6 @@ impl Attention {
})
}
fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
let n_rep = self.num_kv_groups;
if n_rep == 1 {
Ok(xs)
} else {
let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
xs.unsqueeze(2)?
.expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
.reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
}
}
fn forward(
&mut self,
xs: &Tensor,
@ -275,8 +263,9 @@ impl Attention {
self.kv_cache = Some((key_states.clone(), value_states.clone()));
}
let key_states = self.repeat_kv(key_states)?.contiguous()?;
let value_states = self.repeat_kv(value_states)?.contiguous()?;
let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
let value_states =
crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
let attn_output = if self.use_flash_attn {
// flash-attn expects (b_sz, seq_len, nheads, head_dim)

View File

@ -139,18 +139,6 @@ impl Attention {
})
}
fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
let n_rep = self.num_kv_groups;
if n_rep == 1 {
Ok(xs)
} else {
let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
xs.unsqueeze(2)?
.expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
.reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
}
}
fn forward(
&mut self,
xs: &Tensor,
@ -187,8 +175,8 @@ impl Attention {
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));
let key_states = self.repeat_kv(key_states)?;
let value_states = self.repeat_kv(value_states)?;
let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;

View File

@ -183,7 +183,7 @@ impl Module for T5LayerNorm {
let xs_f32 = xs.to_dtype(DType::F32)?;
// variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
let xs = xs.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
let xs = xs_f32.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
let xs = xs.to_dtype(dtype)?;
let xs = xs.broadcast_mul(&self.weight)?;
Ok(xs)
@ -709,8 +709,10 @@ impl T5EncoderModel {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let shared_vb = if vb.contains_tensor("shared.weight") {
vb.pp("shared")
} else {
} else if vb.contains_tensor("decoder.embed_tokens") {
vb.pp("decoder").pp("embed_tokens")
} else {
vb.pp("encoder").pp("embed_tokens")
};
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
let shared = Arc::new(shared);

View File

@ -175,18 +175,6 @@ impl Attention {
})
}
fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
let n_rep = self.num_kv_groups;
if n_rep == 1 {
Ok(xs)
} else {
let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
xs.unsqueeze(2)?
.expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
.reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
}
}
fn forward(
&mut self,
xs: &Tensor,
@ -223,8 +211,8 @@ impl Attention {
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));
let key_states = self.repeat_kv(key_states)?;
let value_states = self.repeat_kv(value_states)?;
let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
let attn_output = {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);

View File

@ -63,7 +63,7 @@ impl VarBuilder {
let path = self.path(name);
match self.data.get(&path) {
None => {
candle::bail!("cannot find tensor {name}")
candle::bail!("cannot find tensor {path}")
}
Some(qtensor) => {
let shape = s.into();

View File

@ -2,7 +2,7 @@ use candle::{Result, Tensor};
pub fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> Result<Tensor> {
let device = logits.device();
let mut logits = logits.to_vec1::<f32>()?;
let mut logits = logits.to_dtype(candle::DType::F32)?.to_vec1::<f32>()?;
let mut already_seen = std::collections::HashSet::new();
for token_id in context {
if already_seen.contains(token_id) {
@ -20,3 +20,17 @@ pub fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> R
let logits_len = logits.len();
Tensor::from_vec(logits, logits_len, device)
}
/// Repeats a key or value tensor for grouped query attention
/// The input tensor should have a shape `(batch, num_kv_heads, seq_len, head_dim)`,
pub fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
if n_rep == 1 {
Ok(xs)
} else {
let (b_sz, n_kv_head, seq_len, head_dim) = xs.dims4()?;
// Using cat is faster than a broadcast as it avoids going through a potentially
// strided copy.
// https://github.com/huggingface/candle/pull/2043
Tensor::cat(&vec![&xs; n_rep], 2)?.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))
}
}