Compare commits


9 Commits
0.9.0 ... 0.9.1

Author SHA1 Message Date
cd96fa80da Add a scattered kv cache. (#2936)
* Add a scattered kv cache.

* Update some comments.
2025-05-01 10:20:48 +02:00
8a19bb7df2 Bump the candle version to 0.9.1. (#2935) 2025-05-01 10:08:16 +02:00
38fc86621c Add support for Helium-v1. (#2932) 2025-04-30 19:38:44 +02:00
5029ac52bb Added tracing page to the candle book. (#2922)
* tracing page

* warned about asynchronous execution

* cleanup

* added Nsight Systems recommendation
2025-04-29 21:35:36 +02:00
de23d34a28 Switch Tensor::full to return a contiguous tensor. (#2929) 2025-04-28 21:36:39 +02:00
d4bac37a61 Fix the gumbel softmax by casting to f32. (#2928) 2025-04-28 19:48:51 +02:00
e98754fc5a Optimize Tensor::new when called on nested Vec<..>. (#2927)
* Optimize Tensor::new when called on nested Vec<..>.

* Improve performance.

* Similar flattening for the 4d case.

* More tweaks.

* Add some dummy test.
2025-04-28 09:19:45 +02:00
e3db30021f Support for "unbatched" rope. (#2926)
* Support for (un)-batched rope.

* Use 3d rope in the rope/ropei/rope_thd functions.

* Get the CPU versions to work.

* Fix the cuda version.

* Adapt the metal side.

* Fix the metal tests.
2025-04-27 15:12:02 +02:00
6e0646c208 Remove redundant mlx gemm dtype check (#2925) 2025-04-27 06:14:57 +02:00
23 changed files with 909 additions and 123 deletions

View File

@ -20,7 +20,7 @@ exclude = [
resolver = "2"
[workspace.package]
version = "0.9.0"
version = "0.9.1"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@ -33,14 +33,14 @@ ab_glyph = "0.2.23"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.9.0" }
candle-datasets = { path = "./candle-datasets", version = "0.9.0" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0" }
candle-kernels = { path = "./candle-kernels", version = "0.9.0" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0" }
candle-nn = { path = "./candle-nn", version = "0.9.0" }
candle-onnx = { path = "./candle-onnx", version = "0.9.0" }
candle-transformers = { path = "./candle-transformers", version = "0.9.0" }
candle = { path = "./candle-core", package = "candle-core", version = "0.9.1" }
candle-datasets = { path = "./candle-datasets", version = "0.9.1" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.1" }
candle-kernels = { path = "./candle-kernels", version = "0.9.1" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.1" }
candle-nn = { path = "./candle-nn", version = "0.9.1" }
candle-onnx = { path = "./candle-onnx", version = "0.9.1" }
candle-transformers = { path = "./candle-transformers", version = "0.9.1" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.16.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }

View File

@ -16,6 +16,7 @@
- [Running a model](inference/inference.md)
- [Using the hub](inference/hub.md)
- [Error management](error_manage.md)
- [Tracing](tracing.md)
- [Training](training/training.md)
- [Simplified](training/simplified.md)
- [MNIST](training/mnist.md)

View File

@ -0,0 +1,68 @@
# Tracing
Tracing is a powerful tool for identifying performance issues and bottlenecks in code.
> Profiling on GPUs is trickier due to asynchronous execution; see the [GPU section](#gpu).
## Overview
Candle uses the [tracing](https://docs.rs/tracing/latest/tracing/) crate for instrumentation.
To try it out, run an example in `candle-examples` with the `--tracing` flag.
This generates a trace file, typically named `trace-<timestamp>.json`.
You can view the trace in Chrome by navigating to `chrome://tracing/`, clicking **Load**, and selecting the generated trace file.
## Adding Tracing
Candle includes built-in tracing for many internal operations, using [spans](https://docs.rs/tracing/latest/tracing/struct.Span.html) to mark key points of execution.
To add custom tracing in your code, you can define a span like this:
```rust
let span = tracing::span!(tracing::Level::TRACE, "name");
```
Then, to record the span during execution, create a guard:
```rust
let _enter = span.enter();
```
This guard will record the span's duration, from when it is created to when it is dropped, into a global data structure managed by the tracing crate.
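For example, instrumenting a function end-to-end combines the two snippets above (a minimal sketch; the `"my-fn"` span name is illustrative):
```rust
fn my_fn() {
    // The span is recorded from the creation of `_enter` until it is
    // dropped at the end of this scope.
    let span = tracing::span!(tracing::Level::TRACE, "my-fn");
    let _enter = span.enter();
    // ... work to be measured goes here ...
}
```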
## Recording and Saving a Trace
To capture and save trace data, you need to configure the tracing system with an output format. Candle uses the [tracing_subscriber](https://docs.rs/tracing-subscriber/latest/tracing_subscriber/) and [tracing_chrome](https://docs.rs/tracing-chrome/latest/tracing_chrome/) crates.
The snippet below sets up a Chrome-compatible recorder that logs all tracing activity between the creation and the drop of the guard:
```rust
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let _guard = {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
guard
};
```
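Putting both pieces together, a minimal traced program could look as follows (a sketch; the `"main-work"` span and the workload are placeholders):
```rust
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // Hold the guard for the whole run; the trace-<timestamp>.json file
    // is flushed when the guard is dropped at the end of main.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();

    let span = tracing::span!(tracing::Level::TRACE, "main-work");
    let _enter = span.enter();
    // ... run the workload to be profiled here ...
}
```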
## GPU
When using CUDA, Metal, or other asynchronous GPU backends, tracing may produce misleading timing data because operations are queued rather than executed immediately.
### CUDA
For CUDA-specific profiling, you have two options:
1. Set the environment variable `CUDA_LAUNCH_BLOCKING=1`, which forces synchronous execution. This makes trace timings more accurate, at the cost of reduced performance.
2. Use [NVIDIA's Nsight Systems](https://developer.nvidia.com/nsight-systems) (`nsys profile` and `nsys-ui`), which is designed specifically for profiling asynchronous CUDA execution.
We recommend NVIDIA's Nsight Systems when possible, as it provides accurate performance data without altering typical execution patterns, whereas `CUDA_LAUNCH_BLOCKING=1` significantly changes how the program executes.
#### Performance Profiling with NVIDIA Nsight Systems
1. Generate an `.nsys-rep` file containing performance data ([docs](https://docs.nvidia.com/nsight-systems/UserGuide/index.html#example-single-command-lines))
- Run `nsys profile --trace cuda,nvtx,osrt --gpu-metrics-device=all --output profile_run ./target/debug/... --prompt "whatever "`
2. Open the generated `.nsys-rep` report file in the Nsight Systems GUI
- File > Open
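Independently of the profiler, it can help to synchronize the device explicitly around the region being measured when timing on an asynchronous backend. A hedged sketch using candle's `Device::synchronize` (the 1024x1024 matmul is an arbitrary workload):
```rust
use candle_core::{Device, Result, Tensor};
use std::time::Instant;

fn timed_matmul(device: &Device) -> Result<std::time::Duration> {
    let a = Tensor::randn(0f32, 1., (1024, 1024), device)?;
    let b = Tensor::randn(0f32, 1., (1024, 1024), device)?;
    // Drain any queued work before starting the clock.
    device.synchronize()?;
    let start = Instant::now();
    let _c = a.matmul(&b)?;
    // Without this second sync, the elapsed time would mostly measure
    // the kernel launch rather than the computation itself.
    device.synchronize()?;
    Ok(start.elapsed())
}
```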

View File

@ -4,11 +4,12 @@ use criterion::criterion_main;
criterion_main!(
benchmarks::affine::benches,
benchmarks::copy::benches,
benchmarks::conv_transpose2d::benches,
benchmarks::matmul::benches,
benchmarks::qmatmul::benches,
benchmarks::random::benches,
benchmarks::reduce::benches,
benchmarks::unary::benches,
benchmarks::where_cond::benches,
benchmarks::conv_transpose2d::benches,
benchmarks::qmatmul::benches,
benchmarks::unary::benches
);

View File

@ -0,0 +1,38 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{Device, Tensor, WithDType};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn run_copy_mask_benchmark<D: WithDType>(c: &mut Criterion, device: &Device, name: &str) {
let batch_size = 128;
let in_seq_len = 1;
let kv_seq_len = 1024;
let attn_mask = vec![vec![vec![D::zero(); kv_seq_len]; in_seq_len]; batch_size];
let size_in_bytes = batch_size * in_seq_len * kv_seq_len * D::DTYPE.size_in_bytes();
let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(size_in_bytes as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let attn_masks = vec![attn_mask.clone(); iters as usize];
let start = Instant::now();
for attn_mask in attn_masks.into_iter() {
let tensor = Tensor::new(black_box(attn_mask), device).unwrap();
black_box(tensor);
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
run_copy_mask_benchmark::<f32>(c, &device, "copy_mask");
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -1,5 +1,6 @@
pub(crate) mod affine;
pub(crate) mod conv_transpose2d;
pub(crate) mod copy;
pub(crate) mod matmul;
pub(crate) mod qmatmul;
pub(crate) mod random;

View File

@ -103,7 +103,63 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4:
}
}
impl<S: NdArray> NdArray for Vec<S> {
impl<S: WithDType> NdArray for Vec<S> {
fn shape(&self) -> Result<Shape> {
Ok(Shape::from(self.len()))
}
fn to_cpu_storage(&self) -> CpuStorage {
S::to_cpu_storage(self.as_slice())
}
}
impl<S: WithDType> NdArray for Vec<&[S]> {
fn shape(&self) -> Result<Shape> {
if self.is_empty() {
crate::bail!("empty array")
}
let n = self.len();
let m = self[0].len();
for v in self.iter() {
if v.len() != m {
crate::bail!("two elements have different len {m} {}", v.len())
}
}
Ok(Shape::from((n, m)))
}
fn to_cpu_storage(&self) -> CpuStorage {
let data = self.iter().copied().flatten().copied().collect::<Vec<_>>();
S::to_cpu_storage_owned(data)
}
}
impl<S: WithDType> NdArray for Vec<Vec<S>> {
fn shape(&self) -> Result<Shape> {
if self.is_empty() {
crate::bail!("empty array")
}
let n = self.len();
let m = self[0].len();
for v in self.iter() {
if v.len() != m {
crate::bail!("two elements have different len {m} {}", v.len())
}
}
Ok(Shape::from((n, m)))
}
fn to_cpu_storage(&self) -> CpuStorage {
let len: usize = self.iter().map(|v| v.len()).sum();
let mut dst = Vec::with_capacity(len);
for v in self.iter() {
dst.extend(v.iter().copied());
}
S::to_cpu_storage_owned(dst)
}
}
impl<S: WithDType> NdArray for Vec<Vec<Vec<S>>> {
fn shape(&self) -> Result<Shape> {
if self.is_empty() {
crate::bail!("empty array")
@ -120,9 +176,57 @@ impl<S: NdArray> NdArray for Vec<S> {
}
fn to_cpu_storage(&self) -> CpuStorage {
// This allocates intermediary memory and shouldn't be necessary.
let storages = self.iter().map(|v| v.to_cpu_storage()).collect::<Vec<_>>();
CpuStorage::concat(storages.as_slice()).unwrap()
if self.is_empty() {
return S::to_cpu_storage_owned(vec![]);
}
let len: usize = self
.iter()
.map(|v| v.iter().map(|v| v.len()).sum::<usize>())
.sum();
let mut dst = Vec::with_capacity(len);
for v1 in self.iter() {
for v2 in v1.iter() {
dst.extend(v2.iter().copied());
}
}
S::to_cpu_storage_owned(dst)
}
}
impl<S: WithDType> NdArray for Vec<Vec<Vec<Vec<S>>>> {
fn shape(&self) -> Result<Shape> {
if self.is_empty() {
crate::bail!("empty array")
}
let shape0 = self[0].shape()?;
let n = self.len();
for v in self.iter() {
let shape = v.shape()?;
if shape != shape0 {
crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
}
}
Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
}
fn to_cpu_storage(&self) -> CpuStorage {
let len: usize = self
.iter()
.map(|v| {
v.iter()
.map(|v| v.iter().map(|v| v.len()).sum::<usize>())
.sum::<usize>()
})
.sum();
let mut dst = Vec::with_capacity(len);
for v1 in self.iter() {
for v2 in v1.iter() {
for v3 in v2.iter() {
dst.extend(v3.iter().copied());
}
}
}
S::to_cpu_storage_owned(dst)
}
}
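The effect of these specialized impls can be observed from the outside by timing `Tensor::new` on a nested `Vec` (an illustrative sketch, separate from the criterion benchmark above; the shape mirrors the copy benchmark's attention mask):
```rust
use candle_core::{Device, Result, Tensor};
use std::time::Instant;

fn time_nested_new() -> Result<()> {
    // A (128, 1, 1024) nested Vec, as in the copy_mask benchmark.
    let data = vec![vec![vec![0f32; 1024]; 1]; 128];
    let start = Instant::now();
    let t = Tensor::new(data, &Device::Cpu)?;
    println!("built {:?} in {:?}", t.shape(), start.elapsed());
    Ok(())
}
```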

View File

@ -1655,50 +1655,32 @@ impl BackendStorage for MetalStorage {
let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?;
let command_buffer = self.device.command_buffer()?;
command_buffer.set_label("matmul");
if self.dtype == DType::BF16 {
candle_metal_kernels::call_mlx_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
candle_metal_kernels::GemmDType::BF16,
(b, m, n, k),
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
} else {
let dtype = match self.dtype {
DType::F32 => candle_metal_kernels::GemmDType::F32,
DType::F16 => candle_metal_kernels::GemmDType::F16,
DType::BF16 => candle_metal_kernels::GemmDType::BF16,
dtype => {
return Err(MetalError::Message(format!(
"mlx matmul doesn't support {dtype:?}"
))
.into())
}
};
candle_metal_kernels::call_mlx_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
dtype,
(b, m, n, k),
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
}
let dtype = match self.dtype {
DType::F32 => candle_metal_kernels::GemmDType::F32,
DType::F16 => candle_metal_kernels::GemmDType::F16,
DType::BF16 => candle_metal_kernels::GemmDType::BF16,
dtype => {
return Err(
MetalError::Message(format!("mlx matmul doesn't support {dtype:?}")).into(),
)
}
};
candle_metal_kernels::call_mlx_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
dtype,
(b, m, n, k),
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::new(
buffer,
self.device.clone(),

View File

@ -382,8 +382,7 @@ impl Tensor {
Self::new_impl(array, shape, device, false)
}
/// Returns a new tensor with all the elements having the same specified value. Note that
/// the tensor is not contiguous so you would have to call `.contiguous()` on it if needed.
/// Returns a new tensor with all the elements having the same specified value.
///```rust
/// use candle_core::{Tensor, Device};
/// let a = Tensor::full(3.5, (2, 4), &Device::Cpu)?;
@ -398,7 +397,12 @@ impl Tensor {
shape: S,
device: &Device,
) -> Result<Self> {
Self::from_vec_impl(vec![value], (), device, false)?.broadcast_as(shape)
let none = BackpropOp::none();
let shape = shape.into();
let mut storage = unsafe { device.alloc_uninit(&shape, D::DTYPE)? };
let layout = Layout::contiguous(shape.clone());
storage.const_set(value.to_scalar(), &layout)?;
Ok(from_storage(storage, shape, none, false))
}
/// Creates a new 1D tensor from an iterator.
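The behavioral change to `Tensor::full` is easy to check (a hedged sketch):
```rust
use candle_core::{Device, Result, Tensor};

fn full_is_contiguous() -> Result<()> {
    let a = Tensor::full(3.5f32, (2, 4), &Device::Cpu)?;
    // Previously this returned a broadcast view; it is now contiguous,
    // so no extra .contiguous() call is needed.
    assert!(a.is_contiguous());
    assert_eq!(a.to_vec2::<f32>()?, vec![vec![3.5; 4]; 2]);
    Ok(())
}
```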

View File

@ -1811,3 +1811,26 @@ fn test_flip_3d_channels() -> Result<()> {
candle_core::test_utils::assert_tensor_eq(&flipped, &expected)?;
Ok(())
}
#[test]
fn tensor_new() -> Result<()> {
let t1 = Tensor::new(vec![1f32, 2.0, 3.0], &Device::Cpu)?;
assert_eq!(t1.to_vec1::<f32>()?, [1.0, 2.0, 3.0]);
let t2 = Tensor::new(vec![vec![1f32, 2., 3.], vec![4., 5., 6.]], &Device::Cpu)?;
assert_eq!(t2.to_vec2::<f32>()?, [[1., 2., 3.], [4., 5., 6.]]);
let t3 = Tensor::new(
vec![
vec![vec![1f32, 2., 3.], vec![4., 5., 6.]],
vec![vec![3f32, 1., 4.], vec![1., 5., 9.]],
],
&Device::Cpu,
)?;
assert_eq!(
t3.to_vec3::<f32>()?,
[
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
[[3.0, 1.0, 4.0], [1.0, 5.0, 9.0]]
]
);
Ok(())
}

View File

@ -7,7 +7,10 @@ extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::helium::{Config, Model};
use candle_transformers::models::helium::{Config as ConfigPreview, Model as ModelPreview};
use candle_transformers::models::llama::{
Cache as CacheV1, Llama as ModelV1, LlamaConfig as ConfigV1, LlamaEosToks,
};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
@ -16,6 +19,44 @@ use candle_transformers::generation::{LogitsProcessor, Sampling};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
#[derive(Debug, Clone)]
enum Model {
V1 { model: ModelV1, cache: CacheV1 },
Preview(ModelPreview),
}
impl Model {
fn forward(&mut self, input: &Tensor, start_pos: usize) -> Result<Tensor> {
let model = match self {
Model::V1 { model, cache } => model.forward(input, start_pos, cache)?,
Model::Preview(m) => m.forward(input, start_pos)?,
};
Ok(model)
}
}
#[derive(Debug, Clone)]
enum Config {
V1(ConfigV1),
Preview(ConfigPreview),
}
impl Config {
fn bos_token_id(&self) -> Option<u32> {
match self {
Config::V1(c) => c.bos_token_id,
Config::Preview(c) => Some(c.bos_token_id),
}
}
fn eos_token_id(&self) -> Option<LlamaEosToks> {
match self {
Config::V1(c) => c.eos_token_id.clone(),
Config::Preview(c) => Some(LlamaEosToks::Single(c.eos_token_id)),
}
}
}
struct TextGeneration {
model: Model,
device: Device,
@ -106,7 +147,15 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == self.config.bos_token_id || next_token == self.config.eos_token_id {
let is_eos = self
.config
.eos_token_id()
.as_ref()
.is_some_and(|v| match v {
LlamaEosToks::Single(eos) => *eos == next_token,
LlamaEosToks::Multiple(eos) => eos.contains(&next_token),
});
if Some(next_token) == self.config.bos_token_id() || is_eos {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
@ -131,6 +180,8 @@ impl TextGeneration {
enum Which {
#[value(name = "v1-preview")]
V1Preview,
#[value(name = "v1")]
V1,
}
#[derive(Parser, Debug)]
@ -144,9 +195,6 @@ struct Args {
#[arg(long)]
tracing: bool,
#[arg(long)]
use_flash_attn: bool,
#[arg(long)]
prompt: String,
@ -171,7 +219,7 @@ struct Args {
sample_len: usize,
/// The model size to use.
#[arg(long, default_value = "v1-preview")]
#[arg(long, default_value = "v1")]
which: Which,
#[arg(long)]
@ -230,6 +278,7 @@ fn main() -> Result<()> {
None => {
let name = match args.which {
Which::V1Preview => "kyutai/helium-1-preview-2b",
Which::V1 => "kyutai/helium-1-2b",
};
name.to_string()
}
@ -254,18 +303,27 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config: Config = match args.config {
Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?,
None => {
let config_file = repo.get("config.json")?;
serde_json::from_slice(&std::fs::read(config_file)?)?
}
let config_file = match args.config {
Some(config_file) => std::path::PathBuf::from(config_file),
None => repo.get("config.json")?,
};
let config = match args.which {
Which::V1Preview => Config::Preview(serde_json::from_slice(&std::fs::read(config_file)?)?),
Which::V1 => Config::V1(serde_json::from_slice(&std::fs::read(config_file)?)?),
};
let device = candle_examples::device(args.cpu)?;
let (model, device) = {
let dtype = device.bf16_default_to_f32();
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
let model = match &config {
Config::V1(c) => {
let c = c.clone().into_config(false);
let model = ModelV1::load(vb, &c)?;
let cache = CacheV1::new(true, dtype, &c, &device)?;
Model::V1 { model, cache }
}
Config::Preview(c) => Model::Preview(ModelPreview::new(c, vb)?),
};
(model, device)
};

View File

@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.9.0"
version = "0.9.1"
edition = "2021"
description = "Flash attention layer for the candle ML framework."
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"
[dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0" }
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.1" }
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]

View File

@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.9.0"
version = "0.9.1"
edition = "2021"
description = "CUDA kernels for Candle"

View File

@ -219,11 +219,15 @@ __device__ void softmax(const T * x, T * dst, const int ncols) {
}
template <typename T>
__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td) {
__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t stride_b) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (2 * idx >= bh * td) return;
uint32_t rope_idx = idx % (td / 2);
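// A stride_b of 0 means cos/sin are shared across the batch; otherwise each
// batch element has its own table of td / 2 entries and we offset into it.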
if (stride_b > 0) {
uint32_t b_idx = (2 * idx) / stride_b;
rope_idx += b_idx * (td / 2);
}
T c = cos[rope_idx];
T s = sin[rope_idx];
@ -232,7 +236,7 @@ __device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, cons
}
template <typename T>
__device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t d) {
__device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t d, const uint32_t stride_b) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (2 * idx >= bh * td) return;
@ -243,6 +247,10 @@ __device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const
uint32_t i1 = i_bh * td + i_t * d + i_d;
uint32_t i2 = i1 + d / 2;
uint32_t i_cs = i_t * (d / 2) + i_d;
if (stride_b > 0) {
uint32_t b_idx = (2 * idx) / stride_b;
i_cs += b_idx * (td / 2);
}
T c = cos[i_cs];
T s = sin[i_cs];
@ -259,7 +267,8 @@ __device__ void rope_thd(
const uint32_t b,
const uint32_t t,
const uint32_t h,
const uint32_t d
const uint32_t d,
const uint32_t stride_b
) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (2 * idx >= b * t * h * d) return;
@ -270,6 +279,10 @@ __device__ void rope_thd(
uint32_t i1 = i_bth * d + i_d;
uint32_t i2 = i1 + d / 2;
uint32_t i_cs = i_t * (d / 2) + i_d;
if (stride_b > 0) {
uint32_t b_idx = (2 * idx) / stride_b;
i_cs += b_idx * ((t * d) / 2);
}
T c = cos[i_cs];
T s = sin[i_cs];
@ -546,8 +559,9 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
const TYPENAME *sin, \
TYPENAME *dst, \
const uint32_t bh, \
const uint32_t td) { \
ropei<TYPENAME>(src, cos, sin, dst, bh, td); \
const uint32_t td, \
const uint32_t stride_b) { \
ropei<TYPENAME>(src, cos, sin, dst, bh, td, stride_b); \
} \
extern "C" __global__ void FN_NAME( \
const TYPENAME *src, \
@ -556,8 +570,9 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
TYPENAME *dst, \
const uint32_t bh, \
const uint32_t td, \
const uint32_t d) { \
rope<TYPENAME>(src, cos, sin, dst, bh, td, d); \
const uint32_t d, \
const uint32_t stride_b) { \
rope<TYPENAME>(src, cos, sin, dst, bh, td, d, stride_b); \
} \
extern "C" __global__ void FN_NAME_THD( \
const TYPENAME *src, \
@ -567,8 +582,9 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
const uint32_t b, \
const uint32_t t, \
const uint32_t h, \
const uint32_t d) { \
rope_thd<TYPENAME>(src, cos, sin, dst, b, t, h, d); \
const uint32_t d, \
const uint32_t stride_b) { \
rope_thd<TYPENAME>(src, cos, sin, dst, b, t, h, d, stride_b); \
} \
#if __CUDA_ARCH__ >= 800

View File

@ -1,6 +1,6 @@
[package]
name = "candle-metal-kernels"
version = "0.9.0"
version = "0.9.1"
edition = "2021"
description = "Metal kernels for Candle"

View File

@ -991,6 +991,7 @@ pub fn call_rope_i(
kernel_name: &'static str,
bh: usize,
td: usize,
stride_b: usize,
src: &Buffer,
src_offset: usize,
cos: &Buffer,
@ -1009,6 +1010,7 @@ pub fn call_rope_i(
(
bh,
td,
stride_b,
(src, src_offset),
(cos, cos_offset),
(sin, sin_offset),
@ -1034,6 +1036,7 @@ pub fn call_rope_thd(
t: usize,
h: usize,
d: usize,
stride_b: usize,
src: &Buffer,
src_offset: usize,
cos: &Buffer,
@ -1054,6 +1057,7 @@ pub fn call_rope_thd(
t,
h,
d,
stride_b,
(src, src_offset),
(cos, cos_offset),
(sin, sin_offset),
@ -1078,6 +1082,7 @@ pub fn call_rope(
bh: usize,
td: usize,
d: usize,
stride_b: usize,
src: &Buffer,
src_offset: usize,
cos: &Buffer,
@ -1097,6 +1102,7 @@ pub fn call_rope(
bh,
td,
d,
stride_b,
(src, src_offset),
(cos, cos_offset),
(sin, sin_offset),

View File

@ -1097,6 +1097,7 @@ template<typename T>
METAL_FUNC void ropei(
constant size_t &bh,
constant size_t &td,
constant size_t &stride_b,
device const T *src,
device const T *cos,
device const T *sin,
@ -1107,6 +1108,10 @@ METAL_FUNC void ropei(
return;
}
size_t rope_idx = tid % (td / 2);
if (stride_b > 0) {
size_t b_idx = (2 * tid) / stride_b;
rope_idx += b_idx * (td / 2);
}
T c = cos[rope_idx];
T s = sin[rope_idx];
dst[2 * tid] = src[2 * tid] * c - src[2 * tid + 1] * s;
@ -1118,6 +1123,7 @@ METAL_FUNC void rope(
constant size_t &bh,
constant size_t &td,
constant size_t &d,
constant size_t &stride_b,
device const T *src,
device const T *cos,
device const T *sin,
@ -1134,6 +1140,10 @@ METAL_FUNC void rope(
size_t i1 = i_bh * td + i_t * d + i_d;
size_t i2 = i1 + d / 2;
size_t i_cs = i_t * (d / 2) + i_d;
if (stride_b > 0) {
size_t b_idx = (2 * idx) / stride_b;
i_cs += b_idx * (td / 2);
}
T c = cos[i_cs];
T s = sin[i_cs];
dst[i1] = src[i1] * c - src[i2] * s;
@ -1146,6 +1156,7 @@ METAL_FUNC void rope_thd(
constant size_t &t,
constant size_t &h,
constant size_t &d,
constant size_t &stride_b,
device const T *src,
device const T *cos,
device const T *sin,
@ -1160,8 +1171,12 @@ METAL_FUNC void rope_thd(
const size_t i_t = (i_bth / h) % t;
const size_t i1 = i_bth * d + i_d;
const size_t i2 = i1 + d / 2;
const size_t i_cs = i_t * (d / 2) + i_d;
T c = cos[i_cs];
size_t i_cs = i_t * (d / 2) + i_d;
if (stride_b > 0) {
const size_t b_idx = (2 * idx) / stride_b;
i_cs += b_idx * ((t * d) / 2);
}
T c = cos[i_cs];
T s = sin[i_cs];
dst[i1] = src[i1] * c - src[i2] * s;
dst[i2] = src[i1] * s + src[i2] * c;
@ -1171,38 +1186,41 @@ METAL_FUNC void rope_thd(
kernel void FN_NAME_I( \
constant size_t &bh, \
constant size_t &td, \
constant size_t &stride_b, \
device const TYPENAME *src, \
device const TYPENAME *cos, \
device const TYPENAME *sin, \
device TYPENAME *dst, \
uint tid [[ thread_position_in_grid ]] \
) { \
ropei<TYPENAME>(bh, td, src, cos, sin, dst, tid); \
ropei<TYPENAME>(bh, td, stride_b, src, cos, sin, dst, tid); \
}\
kernel void FN_NAME( \
constant size_t &bh, \
constant size_t &td, \
constant size_t &d, \
constant size_t &stride_b, \
device const TYPENAME *src, \
device const TYPENAME *cos, \
device const TYPENAME *sin, \
device TYPENAME *dst, \
uint idx [[ thread_position_in_grid ]] \
) { \
rope<TYPENAME>(bh, td, d, src, cos, sin, dst, idx); \
rope<TYPENAME>(bh, td, d, stride_b, src, cos, sin, dst, idx); \
}\
kernel void FN_NAME_THD( \
constant size_t &b, \
constant size_t &t, \
constant size_t &h, \
constant size_t &d, \
constant size_t &stride_b, \
device const TYPENAME *src, \
device const TYPENAME *cos, \
device const TYPENAME *sin, \
device TYPENAME *dst, \
uint idx [[ thread_position_in_grid ]] \
) { \
rope_thd<TYPENAME>(b, t, h, d, src, cos, sin, dst, idx); \
rope_thd<TYPENAME>(b, t, h, d, stride_b, src, cos, sin, dst, idx); \
}\
RMSNORM(rmsnorm_f32, float)

View File

@ -1584,7 +1584,7 @@ fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>(
dim,
BufferOffset::zero_offset(&input_buffer),
BufferOffset::zero_offset(&ids_buffer),
&output,
BufferOffset::zero_offset(&output),
)
.unwrap();
command_buffer.commit();

View File

@ -1,6 +1,6 @@
//! Cache Implementations
//!
use candle::{Device, Result, Tensor};
use candle::{DType, Device, Result, Tensor};
#[derive(Debug, Clone)]
pub struct Cache {
@ -399,3 +399,322 @@ impl RotatingKvCache {
self.v.reset();
}
}
#[derive(Debug, Clone)]
pub struct IndicesAndMask {
indices: Tensor,
mask: Tensor,
}
impl IndicesAndMask {
pub fn mask(&self) -> &Tensor {
&self.mask
}
}
#[derive(Debug, Clone)]
pub struct ScatteredKvCache {
k: Tensor,
v: Tensor,
context: usize,
}
impl ScatteredKvCache {
pub fn append(
&mut self,
k: &Tensor,
v: &Tensor,
iam: &IndicesAndMask,
) -> Result<(Tensor, Tensor)> {
if self.context <= k.dim(2)? {
return Ok((k.clone(), v.clone()));
}
let indices = iam.indices.unsqueeze(2)?.unsqueeze(1)?;
let indices = indices.broadcast_as(k.shape())?.contiguous()?;
self.k.scatter_set(&indices, k, 2)?;
self.v.scatter_set(&indices, v, 2)?;
Ok((self.k.clone(), self.v.clone()))
}
pub fn k(&self) -> &Tensor {
&self.k
}
pub fn v(&self) -> &Tensor {
&self.v
}
}
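For reference, a typical decoding step with the new cache might look like this (a hedged sketch based on the API above; shapes are illustrative and the types are assumed to be reachable from the caller's module):
```rust
use candle::{DType, Device, Result, Tensor};

fn demo() -> Result<()> {
    let device = Device::Cpu;
    // Batch of 2, context window of 128 slots.
    let mut builder = ScatteredCacheBuilder::new(2, 128, DType::F32, &device)?;
    // 8 heads with head_dim 64.
    let mut cache = builder.make_cache(8, 64)?;
    // One decoding step for both batch entries.
    let iam = builder.indices_and_mask(1, &[true, true])?;
    let k = Tensor::zeros((2, 8, 1, 64), DType::F32, &device)?;
    let v = Tensor::zeros((2, 8, 1, 64), DType::F32, &device)?;
    let (k_all, v_all) = cache.append(&k, &v, &iam)?;
    // k_all/v_all cover the full context window; iam.mask() is then
    // applied to the attention scores computed against them.
    assert_eq!(k_all.dims(), &[2, 8, 128, 64]);
    assert_eq!(v_all.dims(), &[2, 8, 128, 64]);
    Ok(())
}
```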
#[derive(Debug, Clone)]
pub struct ScatteredCacheBuilder {
context: usize,
// The current position in the stream; this can be larger than the context.
positions: Vec<usize>,
// The index where the next element will be stored.
indices: Vec<usize>,
dtype: DType,
device: Device,
}
impl ScatteredCacheBuilder {
pub fn new(batch_size: usize, context: usize, dtype: DType, device: &Device) -> Result<Self> {
let positions = vec![0; batch_size];
let indices = vec![0; batch_size];
Ok(Self {
positions,
indices,
context,
dtype,
device: device.clone(),
})
}
pub fn make_cache(&self, num_heads: usize, head_dim: usize) -> Result<ScatteredKvCache> {
let batch_size = self.batch_size();
let shape = (batch_size, num_heads, self.context, head_dim);
let k = Tensor::zeros(shape, self.dtype, self.device())?;
let v = Tensor::zeros(shape, self.dtype, self.device())?;
Ok(ScatteredKvCache {
k,
v,
context: self.context,
})
}
pub fn positions(&self) -> &[usize] {
&self.positions
}
pub fn reset(&mut self) {
self.positions.fill(0);
self.indices.fill(0);
}
pub fn batch_size(&self) -> usize {
self.positions.len()
}
pub fn reset_batch_index(&mut self, batch_index: usize) {
self.positions[batch_index] = 0;
self.indices[batch_index] = 0;
}
#[allow(clippy::needless_range_loop)]
pub fn indices_and_mask(
&mut self,
seq_len: usize,
batch_mask: &[bool],
) -> Result<IndicesAndMask> {
// mask shape is (b, h, t, k)
let context = self.context;
if self.context <= seq_len {
return self.indices_and_mask_abs(seq_len, batch_mask);
}
let mut attention_masks = Vec::with_capacity(self.batch_size());
let mut cache_indices = Vec::with_capacity(self.batch_size());
for (batch_i, &batch_mask) in batch_mask.iter().enumerate() {
if !batch_mask {
let masks: Vec<Vec<f32>> = vec![vec![0.0; context]; seq_len];
let indices = vec![self.indices[batch_i] as u32; seq_len];
attention_masks.push(masks);
cache_indices.push(indices);
} else {
let start_index = self.indices[batch_i];
let start_pos = self.positions[batch_i];
let mut masks: Vec<Vec<f32>> = Vec::with_capacity(seq_len);
let mut indices = Vec::with_capacity(seq_len);
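// all_pos[i] tracks the absolute stream position currently stored in
// cache slot i; usize::MAX marks a slot that has never been written.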
let mut all_pos = vec![usize::MAX; context];
if start_pos < context {
for i in 0..start_pos {
all_pos[i] = i;
}
} else {
let offset = start_pos - start_index;
for i in 0..context {
all_pos[i] = if i < start_index {
i + offset
} else {
i + offset - context
};
}
}
for seq_i in 0..seq_len {
let index = self.indices[batch_i];
all_pos[index] = seq_i + start_pos;
indices.push(index as u32);
self.indices[batch_i] += 1;
self.positions[batch_i] += 1;
if self.indices[batch_i] >= self.context {
self.indices[batch_i] = 0;
}
}
for seq_i in 0..seq_len {
let my_pos = seq_i + start_pos;
let mask = all_pos
.iter()
.map(|&pos| {
if pos <= my_pos {
0.0
} else {
f32::NEG_INFINITY
}
})
.collect::<Vec<f32>>();
masks.push(mask);
}
attention_masks.push(masks);
cache_indices.push(indices);
}
}
// Flattening the attention mask and then using Tensor::from_vec rather than Tensor::new ends
// up being almost 10x faster with candle 0.9.0. This has been fixed in candle 0.9.1.
let attention_masks = attention_masks
.into_iter()
.flat_map(|m| m.into_iter().flatten())
.collect::<Vec<f32>>();
let mask = Tensor::from_vec(attention_masks, ((), 1, seq_len, context), self.device())?
.to_dtype(self.dtype)?;
let indices = Tensor::new(cache_indices, self.device())?;
Ok(IndicesAndMask { indices, mask })
}
pub fn device(&self) -> &Device {
&self.device
}
#[allow(clippy::needless_range_loop)]
fn indices_and_mask_abs(
&mut self,
seq_len: usize,
batch_mask: &[bool],
) -> Result<IndicesAndMask> {
let mask = self.get_mask_abs(seq_len, seq_len)?;
let mut cache_indices = Vec::with_capacity(self.batch_size());
for (batch_i, &batch_mask) in batch_mask.iter().enumerate() {
if !batch_mask {
let indices = vec![self.indices[batch_i] as u32; seq_len];
cache_indices.push(indices);
} else {
let mut indices = Vec::with_capacity(seq_len);
for _ in 0..seq_len {
let index = self.indices[batch_i];
indices.push(index as u32);
self.indices[batch_i] += 1;
self.positions[batch_i] += 1;
if self.indices[batch_i] >= self.context {
self.indices[batch_i] = 0;
}
}
cache_indices.push(indices);
}
}
let indices = Tensor::new(cache_indices, self.device())?;
Ok(IndicesAndMask { indices, mask })
}
fn get_mask_abs(&self, size1: usize, size2: usize) -> Result<Tensor> {
let context = self.context;
let mask: Vec<_> = (0..size1)
.flat_map(|i| {
(0..size2).map(move |j| {
if size1 + j > size2 + i || size1 + j + context < size2 + i {
f32::NEG_INFINITY
} else {
0.0
}
})
})
.collect();
Tensor::from_slice(&mask, (size1, size2), self.device())
}
}
#[cfg(test)]
mod tests {
use super::*;
use candle::IndexOp;
#[test]
fn test_scattered_kv_cache() -> Result<()> {
let device = Device::Cpu;
let mut cache = ScatteredCacheBuilder::new(2, 5, DType::F32, &device)?;
let inf = f32::INFINITY;
let iam = cache.indices_and_mask(1, &[true, false])?;
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
assert_eq!(iam.indices.to_vec2::<u32>()?, [[0], [0]]);
assert_eq!(
mask,
[[[0.0, -inf, -inf, -inf, -inf]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]
);
let iam = cache.indices_and_mask(1, &[true, false])?;
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
assert_eq!(iam.indices.to_vec2::<u32>()?, [[1], [0]]);
assert_eq!(
mask,
[[[0.0, 0.0, -inf, -inf, -inf]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]
);
let iam = cache.indices_and_mask(3, &[false, true])?;
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
assert_eq!(iam.indices.to_vec2::<u32>()?, [[2, 2, 2], [0, 1, 2]]);
assert_eq!(
mask,
[
[
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]
],
[
[0.0, -inf, -inf, -inf, -inf],
[0.0, 0.0, -inf, -inf, -inf],
[0.0, 0.0, 0.0, -inf, -inf]
]
]
);
let iam = cache.indices_and_mask(3, &[true, true])?;
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
assert_eq!(iam.indices.to_vec2::<u32>()?, [[2, 3, 4], [3, 4, 0]]);
assert_eq!(
mask,
[
[
[0.0, 0.0, 0.0, -inf, -inf],
[0.0, 0.0, 0.0, 0.0, -inf],
[0.0, 0.0, 0.0, 0.0, 0.0]
],
[
[-inf, 0.0, 0.0, 0.0, -inf],
[-inf, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]
]
]
);
let iam = cache.indices_and_mask(1, &[true, false])?;
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
assert_eq!(iam.indices.to_vec2::<u32>()?, [[0], [1]]);
assert_eq!(
mask,
[[[0.0, 0.0, 0.0, 0.0, 0.0]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]
);
let iam = cache.indices_and_mask(2, &[true, false])?;
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
assert_eq!(iam.indices.to_vec2::<u32>()?, [[1, 2], [1, 1]]);
assert_eq!(
mask,
[
[[0.0, 0.0, -inf, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]],
[[0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]]
]
);
Ok(())
}
}

View File

@ -46,15 +46,23 @@ impl candle::CustomOp3 for RotaryEmbI {
Some((o1, o2)) => &sin[o1..o2],
};
let (b, h, t, d) = l_src.shape().dims4()?;
let unbatched_rope = l_cos.dims().len() == 3 && l_sin.dims().len() == 3;
let el_count = b * h * t * d;
let mut dst = vec![T::zero(); el_count];
src.par_chunks(t * d)
.zip(dst.par_chunks_mut(t * d))
.for_each(|(src, dst)| {
.enumerate()
.for_each(|(bh_i, (src, dst))| {
for i_over_2 in 0..t * d / 2 {
let i = 2 * i_over_2;
dst[i] = src[i] * cos[i_over_2] - src[i + 1] * sin[i_over_2];
dst[i + 1] = src[i] * sin[i_over_2] + src[i + 1] * cos[i_over_2];
let rope_i = if unbatched_rope {
let b_i = bh_i / h;
i_over_2 + b_i * t * d / 2
} else {
i_over_2
};
dst[i] = src[i] * cos[rope_i] - src[i + 1] * sin[rope_i];
dst[i + 1] = src[i] * sin[rope_i] + src[i + 1] * cos[rope_i];
}
});
let storage = candle::WithDType::to_cpu_storage_owned(dst);
@ -115,6 +123,11 @@ impl candle::CustomOp3 for RotaryEmbI {
Some((o1, o2)) => sin.slice(o1..o2),
};
let (b, h, t, d) = l_src.shape().dims4()?;
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
(h * t * d) as u32
} else {
0u32
};
let el = b * h * t * d;
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), &kernels::REDUCE)?;
@ -125,7 +138,7 @@ impl candle::CustomOp3 for RotaryEmbI {
builder.arg(&cos);
builder.arg(&sin);
builder.arg(&dst);
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32);
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, stride_b);
// SAFETY: ffi.
unsafe { builder.launch(cfg) }.w()?;
Ok(dst)
@ -182,6 +195,11 @@ impl candle::CustomOp3 for RotaryEmbI {
dtype => candle::bail!("rope-i is not implemented for {dtype:?}"),
};
let (b, h, t, d) = l_src.shape().dims4()?;
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
h * t * d
} else {
0usize
};
let el = b * h * t * d;
let output = device.new_buffer(el, src.dtype(), "rope-i")?;
candle_metal_kernels::call_rope_i(
@ -191,6 +209,7 @@ impl candle::CustomOp3 for RotaryEmbI {
name,
b * h,
t * d,
stride_b,
src.buffer(),
l_src.start_offset() * src.dtype().size_in_bytes(),
cos.buffer(),
@ -205,10 +224,23 @@ impl candle::CustomOp3 for RotaryEmbI {
}
}
fn rope_check_cs(cs: &Tensor, b_sz: usize) -> Result<(usize, usize)> {
match *cs.dims() {
[t, d] => Ok((t, d)),
[b, t, d] => {
if b != b_sz {
candle::bail!("inconsistent batch size in rope {b_sz} {cs:?}",)
}
Ok((t, d))
}
_ => candle::bail!("cos/sin has to be 2D or 3D in rope {b_sz} {cs:?}"),
}
}
pub fn rope_i(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = cos.dims2()?;
let (sin_seq_len, sin_n_embd) = cos.dims2()?;
let (b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = rope_check_cs(cos, b_sz)?;
let (sin_seq_len, sin_n_embd) = rope_check_cs(sin, b_sz)?;
if cos_n_embd * 2 != n_embd
|| sin_n_embd * 2 != n_embd
|| seq_len > cos_seq_len
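With `rope_check_cs` in place, cos/sin can either be shared across the batch, with shape `(t, d/2)`, or provided per batch element, with shape `(b, t, d/2)`. A hedged usage sketch of the batched form (random tables just for illustration):
```rust
use candle::{Device, Result, Tensor};

fn batched_rope_i(device: &Device) -> Result<Tensor> {
    let (b, h, t, d) = (2, 4, 8, 16);
    let xs = Tensor::randn(0f32, 1., (b, h, t, d), device)?;
    // One (t, d / 2) cos/sin table per batch element.
    let cos = Tensor::randn(0f32, 1., (b, t, d / 2), device)?;
    let sin = Tensor::randn(0f32, 1., (b, t, d / 2), device)?;
    candle_nn::rotary_emb::rope_i(&xs, &cos, &sin)
}
```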
@ -292,16 +324,24 @@ impl candle::CustomOp3 for RotaryEmb {
Some((o1, o2)) => &sin[o1..o2],
};
let (b, h, t, d) = l_src.shape().dims4()?;
let unbatched_rope = l_cos.dims().len() == 3 && l_sin.dims().len() == 3;
let el_count = b * h * t * d;
let mut dst = vec![T::zero(); el_count];
src.par_chunks(t * d)
.zip(dst.par_chunks_mut(t * d))
.for_each(|(src, dst)| {
.enumerate()
.for_each(|(bh_i, (src, dst))| {
for i_t in 0..t {
for i_d in 0..d / 2 {
let i1 = i_t * d + i_d;
let i2 = i1 + d / 2;
let i_cs = i_t * (d / 2) + i_d;
let i_cs = if unbatched_rope {
let b_i = bh_i / h;
i_cs + b_i * t * d / 2
} else {
i_cs
};
dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs];
dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs];
}
@ -365,6 +405,11 @@ impl candle::CustomOp3 for RotaryEmb {
Some((o1, o2)) => sin.slice(o1..o2),
};
let (b, h, t, d) = l_src.shape().dims4()?;
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
(h * t * d) as u32
} else {
0u32
};
let el = b * h * t * d;
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("rope"), &kernels::REDUCE)?;
@ -375,7 +420,7 @@ impl candle::CustomOp3 for RotaryEmb {
builder.arg(&cos);
builder.arg(&sin);
builder.arg(&dst);
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, d as u32);
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, d as u32, stride_b);
// SAFETY: ffi.
unsafe { builder.launch(cfg) }.w()?;
Ok(dst)
@ -432,6 +477,11 @@ impl candle::CustomOp3 for RotaryEmb {
dtype => candle::bail!("rope is not implemented for {dtype:?}"),
};
let (b, h, t, d) = l_src.shape().dims4()?;
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
h * t * d
} else {
0usize
};
let el = b * h * t * d;
let output = device.new_buffer(el, src.dtype(), "rope-i")?;
candle_metal_kernels::call_rope(
@ -442,6 +492,7 @@ impl candle::CustomOp3 for RotaryEmb {
b * h,
t * d,
d,
stride_b,
src.buffer(),
l_src.start_offset() * src.dtype().size_in_bytes(),
cos.buffer(),
@ -457,9 +508,9 @@ impl candle::CustomOp3 for RotaryEmb {
}
pub fn rope(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = cos.dims2()?;
let (sin_seq_len, sin_n_embd) = sin.dims2()?;
let (b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = rope_check_cs(cos, b_sz)?;
let (sin_seq_len, sin_n_embd) = rope_check_cs(sin, b_sz)?;
if cos_n_embd * 2 != n_embd
|| sin_n_embd * 2 != n_embd
|| seq_len > cos_seq_len
@ -541,14 +592,21 @@ impl candle::CustomOp3 for RotaryEmbThd {
Some((o1, o2)) => &sin[o1..o2],
};
let (b, t, h, d) = l_src.shape().dims4()?;
let unbatched_rope = l_cos.dims().len() == 3 && l_sin.dims().len() == 3;
let el_count = b * h * t * d;
let mut dst = vec![T::zero(); el_count];
src.par_chunks(t * h * d)
.zip(dst.par_chunks_mut(t * h * d))
.for_each(|(src, dst)| {
.enumerate()
.for_each(|(b_i, (src, dst))| {
for i_t in 0..t {
for i_d in 0..d / 2 {
let i_cs = i_t * (d / 2) + i_d;
let i_cs = if unbatched_rope {
i_cs + b_i * t * d / 2
} else {
i_cs
};
for i_h in 0..h {
let i1 = i_t * h * d + i_h * d + i_d;
let i2 = i1 + d / 2;
@ -616,6 +674,11 @@ impl candle::CustomOp3 for RotaryEmbThd {
Some((o1, o2)) => sin.slice(o1..o2),
};
let (b, t, h, d) = l_src.shape().dims4()?;
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
(h * t * d) as u32
} else {
0u32
};
let el = b * h * t * d;
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), &kernels::REDUCE)?;
@ -626,7 +689,7 @@ impl candle::CustomOp3 for RotaryEmbThd {
builder.arg(&cos);
builder.arg(&sin);
builder.arg(&dst);
candle::builder_arg!(builder, b as u32, t as u32, h as u32, d as u32);
candle::builder_arg!(builder, b as u32, t as u32, h as u32, d as u32, stride_b);
// SAFETY: ffi.
unsafe { builder.launch(cfg) }.w()?;
Ok(dst)
@ -683,6 +746,11 @@ impl candle::CustomOp3 for RotaryEmbThd {
dtype => candle::bail!("rope_thd is not implemented for {dtype:?}"),
};
let (b, t, h, d) = l_src.shape().dims4()?;
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
h * t * d
} else {
0usize
};
let el = b * h * t * d;
let output = device.new_buffer(el, src.dtype(), "rope-thd")?;
candle_metal_kernels::call_rope_thd(
@ -694,6 +762,7 @@ impl candle::CustomOp3 for RotaryEmbThd {
t,
h,
d,
stride_b,
src.buffer(),
l_src.start_offset() * src.dtype().size_in_bytes(),
cos.buffer(),
@ -709,9 +778,9 @@ impl candle::CustomOp3 for RotaryEmbThd {
}
pub fn rope_thd(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
let (_b_sz, seq_len, _n_head, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = cos.dims2()?;
let (sin_seq_len, sin_n_embd) = sin.dims2()?;
let (b_sz, seq_len, _n_head, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = rope_check_cs(cos, b_sz)?;
let (sin_seq_len, sin_n_embd) = rope_check_cs(sin, b_sz)?;
if cos_n_embd * 2 != n_embd
|| sin_n_embd * 2 != n_embd
|| seq_len > cos_seq_len

View File

@ -8,13 +8,16 @@ pub fn gumbel_softmax<D: candle::shape::Dim>(
) -> Result<Tensor> {
if temperature <= 0.0 {
logits.argmax(dim)
} else if temperature == 1.0 {
let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
let sampled = (logits - minus_g)?.argmax(dim)?;
Ok(sampled)
} else {
// Cast to f32, doing the Gumbel softmax in bf16 is a bit unstable.
let logits = logits.to_dtype(candle::DType::F32)?;
let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
let sampled = (logits + minus_g * (-temperature))?.argmax(dim)?;
Ok(sampled)
if temperature == 1.0 {
let sampled = (logits - minus_g)?.argmax(dim)?;
Ok(sampled)
} else {
let sampled = (logits + minus_g * (-temperature))?.argmax(dim)?;
Ok(sampled)
}
}
}
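As a usage sketch (hedged; the `candle_nn::sampling` module path is assumed from where this function lives in the tree):
```rust
use candle::{Device, Result, Tensor, D};
use candle_nn::sampling::gumbel_softmax;

fn sample_token(logits: &Tensor) -> Result<u32> {
    // temperature 1.0 takes the fast path above; larger temperatures
    // flatten the distribution before the argmax.
    let token = gumbel_softmax(logits, 1.0, D::Minus1)?;
    token.to_vec0::<u32>()
}

fn main() -> Result<()> {
    let logits = Tensor::new(&[0.1f32, 2.0, 0.5], &Device::Cpu)?;
    println!("sampled token: {}", sample_token(&logits)?);
    Ok(())
}
```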

View File

@ -4,7 +4,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle::{test_device, test_utils::to_vec3_round, Device, Result, Tensor};
use candle::{test_device, test_utils::to_vec3_round, Device, IndexOp, Result, Tensor};
fn softmax(device: &Device) -> Result<()> {
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
@ -179,6 +179,28 @@ fn ropei(device: &Device) -> Result<()> {
} else {
assert!(sum_diff < 1e-4);
}
// Test with a 3d cos/sin
let cos2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let sin2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let cos2 = Tensor::from_vec(cos2, (seq_len, head_dim / 2), device)?;
let sin2 = Tensor::from_vec(sin2, (seq_len, head_dim / 2), device)?;
let rope1 = candle_nn::rotary_emb::rope_i(&src.i(0..1)?, &cos, &sin)?;
let rope2 = candle_nn::rotary_emb::rope_i(&src.i(1..2)?, &cos2, &sin2)?;
let both_cos = Tensor::stack(&[cos, cos2], 0)?;
let both_sin = Tensor::stack(&[sin, sin2], 0)?;
let both_rope = candle_nn::rotary_emb::rope_i(&src, &both_cos, &both_sin)?;
let both_rope2 = Tensor::cat(&[rope1, rope2], 0)?;
let sum_diff = (both_rope - both_rope2)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(sum_diff, 0.);
Ok(())
}
@ -206,6 +228,28 @@ fn rope(device: &Device) -> Result<()> {
} else {
assert!(sum_diff < 1e-4);
}
// Test with a 3d cos/sin
let cos2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let sin2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let cos2 = Tensor::from_vec(cos2, (seq_len, head_dim / 2), device)?;
let sin2 = Tensor::from_vec(sin2, (seq_len, head_dim / 2), device)?;
let rope1 = candle_nn::rotary_emb::rope(&src.i(0..1)?, &cos, &sin)?;
let rope2 = candle_nn::rotary_emb::rope(&src.i(1..2)?, &cos2, &sin2)?;
let both_cos = Tensor::stack(&[cos, cos2], 0)?;
let both_sin = Tensor::stack(&[sin, sin2], 0)?;
let both_rope = candle_nn::rotary_emb::rope(&src, &both_cos, &both_sin)?;
let both_rope2 = Tensor::cat(&[rope1, rope2], 0)?;
let sum_diff = (both_rope - both_rope2)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(sum_diff, 0.);
Ok(())
}
@ -236,6 +280,37 @@ fn rope_thd(device: &Device) -> Result<()> {
} else {
assert!(sum_diff < 1e-4);
}
// Test with a 3d cos/sin
let cos2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let sin2: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.random::<f32>())
.collect();
let cos2 = Tensor::from_vec(cos2, (seq_len, head_dim / 2), device)?;
let sin2 = Tensor::from_vec(sin2, (seq_len, head_dim / 2), device)?;
let rope1 = {
let src = src.transpose(1, 2)?.contiguous()?;
candle_nn::rotary_emb::rope_thd(&src.i(0..1)?, &cos, &sin)?
};
let rope2 = {
let src = src.transpose(1, 2)?.contiguous()?;
candle_nn::rotary_emb::rope_thd(&src.i(1..2)?, &cos2, &sin2)?
};
let both_cos = Tensor::stack(&[cos, cos2], 0)?;
let both_sin = Tensor::stack(&[sin, sin2], 0)?;
let both_rope = {
let src = src.transpose(1, 2)?.contiguous()?;
candle_nn::rotary_emb::rope_thd(&src, &both_cos, &both_sin)?
};
let both_rope2 = Tensor::cat(&[rope1, rope2], 0)?;
let sum_diff = (both_rope - both_rope2)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(sum_diff, 0.);
Ok(())
}

View File

@ -1,6 +1,6 @@
[package]
name = "candle-onnx"
version = "0.9.0"
version = "0.9.1"
edition = "2021"
description = "ONNX support for Candle"
@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
candle = { path = "../candle-core", package = "candle-core", version = "0.9.0" }
candle-nn = { path = "../candle-nn", version = "0.9.0" }
candle = { path = "../candle-core", package = "candle-core", version = "0.9.1" }
candle-nn = { path = "../candle-nn", version = "0.9.1" }
prost = "0.12.1"
[build-dependencies]