mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Compare commits
9 Commits
Author | SHA1 | Date | |
---|---|---|---|
cd96fa80da | |||
8a19bb7df2 | |||
38fc86621c | |||
5029ac52bb | |||
de23d34a28 | |||
d4bac37a61 | |||
e98754fc5a | |||
e3db30021f | |||
6e0646c208 |
18
Cargo.toml
18
Cargo.toml
@ -20,7 +20,7 @@ exclude = [
|
||||
resolver = "2"
|
||||
|
||||
[workspace.package]
|
||||
version = "0.9.0"
|
||||
version = "0.9.1"
|
||||
edition = "2021"
|
||||
description = "Minimalist ML framework."
|
||||
repository = "https://github.com/huggingface/candle"
|
||||
@ -33,14 +33,14 @@ ab_glyph = "0.2.23"
|
||||
accelerate-src = { version = "0.3.2" }
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
byteorder = "1.4.3"
|
||||
candle = { path = "./candle-core", package = "candle-core", version = "0.9.0" }
|
||||
candle-datasets = { path = "./candle-datasets", version = "0.9.0" }
|
||||
candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0" }
|
||||
candle-kernels = { path = "./candle-kernels", version = "0.9.0" }
|
||||
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0" }
|
||||
candle-nn = { path = "./candle-nn", version = "0.9.0" }
|
||||
candle-onnx = { path = "./candle-onnx", version = "0.9.0" }
|
||||
candle-transformers = { path = "./candle-transformers", version = "0.9.0" }
|
||||
candle = { path = "./candle-core", package = "candle-core", version = "0.9.1" }
|
||||
candle-datasets = { path = "./candle-datasets", version = "0.9.1" }
|
||||
candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.1" }
|
||||
candle-kernels = { path = "./candle-kernels", version = "0.9.1" }
|
||||
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.1" }
|
||||
candle-nn = { path = "./candle-nn", version = "0.9.1" }
|
||||
candle-onnx = { path = "./candle-onnx", version = "0.9.1" }
|
||||
candle-transformers = { path = "./candle-transformers", version = "0.9.1" }
|
||||
clap = { version = "4.2.4", features = ["derive"] }
|
||||
criterion = { version = "0.5.1", default-features=false }
|
||||
cudarc = { version = "0.16.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
|
||||
|
@ -16,6 +16,7 @@
|
||||
- [Running a model](inference/inference.md)
|
||||
- [Using the hub](inference/hub.md)
|
||||
- [Error management](error_manage.md)
|
||||
- [Tracing](tracing.md)
|
||||
- [Training](training/training.md)
|
||||
- [Simplified](training/simplified.md)
|
||||
- [MNIST](training/mnist.md)
|
||||
|
68
candle-book/src/tracing.md
Normal file
68
candle-book/src/tracing.md
Normal file
@ -0,0 +1,68 @@
|
||||
# Tracing
|
||||
|
||||
Tracing is a powerful tool for identifying performance issues and bottlenecks in code.
|
||||
|
||||
> Profiling on GPUs is trickier due to asynchronous execution, see the [GPU section](#gpu).
|
||||
|
||||
## Overview
|
||||
|
||||
Candle uses the [tracing](https://docs.rs/tracing/latest/tracing/) crate for instrumentation.
|
||||
|
||||
To try it out, run an example in `candle-examples` with the `--tracing` flag.
|
||||
This generates a trace file, typically named `trace-<timestamp>.json`.
|
||||
You can view the trace in Chrome by navigating to `chrome://tracing/`, clicking **Load**, and selecting the generated trace file.
|
||||
|
||||
## Adding Tracing
|
||||
|
||||
Candle includes built-in tracing for many internal operations, using [spans](https://docs.rs/tracing/latest/tracing/struct.Span.html) to mark key points of execution.
|
||||
|
||||
To add custom tracing in your code, you can define a span like this (where `name` is the span's name, e.g. a string literal such as `"my_span"`):
|
||||
|
||||
```rust
|
||||
let span = tracing::span!(tracing::Level::TRACE, name);
|
||||
```
|
||||
|
||||
Then, to record the span during execution, create a guard:
|
||||
|
||||
```rust
|
||||
let _enter = span.enter();
|
||||
```
|
||||
|
||||
This guard records the span's duration, measured from when the guard is created to when the guard is dropped, into a global data structure managed by the tracing crate.
|
||||
|
||||
## Recording and Saving a Trace
|
||||
|
||||
To capture and save trace data, you need to configure the tracing system with an output format. Candle uses the [tracing_subscriber](https://docs.rs/tracing-subscriber/latest/tracing_subscriber/) and [tracing_chrome](https://docs.rs/tracing-chrome/latest/tracing_chrome/) crates.
|
||||
|
||||
The snippet below sets up a Chrome-compatible recorder that logs all tracing activity between the creation and the drop of the guard:
|
||||
|
||||
```rust
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let _guard = {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
guard
|
||||
};
|
||||
```
|
||||
|
||||
## GPU
|
||||
|
||||
When using CUDA, Metal, or other asynchronous GPU backends, tracing may produce misleading timing data because operations are queued rather than executed immediately.
|
||||
|
||||
### CUDA
|
||||
|
||||
For CUDA-specific profiling, you have two options:
|
||||
|
||||
1. Set the environment variable `CUDA_LAUNCH_BLOCKING=1` which forces synchronous execution. This makes trace timings more accurate, at the cost of reduced performance.
|
||||
2. Use [NVIDIA's Nsight Systems](https://developer.nvidia.com/nsight-systems) (`nsys profile` and `nsys-ui`) which are designed specifically for profiling asynchronous CUDA executions.
|
||||
|
||||
We recommend using NVIDIA's Nsight Systems when possible, as it offers accurate performance data without altering typical execution patterns. In contrast, setting the `CUDA_LAUNCH_BLOCKING` environment variable forces synchronous execution, which can significantly alter execution behavior.
|
||||
|
||||
#### Performance Profiling with NVIDIA Nsight Systems
|
||||
|
||||
1. Generate an `.nsys-rep` file containing performance data ([docs](https://docs.nvidia.com/nsight-systems/UserGuide/index.html#example-single-command-lines))
|
||||
- Run `nsys profile --trace cuda,nvtx,osrt --gpu-metrics-device=all --output profile_run ./target/debug/... --prompt "whatever "`
|
||||
1. Open the generated `.nsys-rep` report file in Nsight Systems GUI
|
||||
- File > Open
|
@ -4,11 +4,12 @@ use criterion::criterion_main;
|
||||
|
||||
criterion_main!(
|
||||
benchmarks::affine::benches,
|
||||
benchmarks::copy::benches,
|
||||
benchmarks::conv_transpose2d::benches,
|
||||
benchmarks::matmul::benches,
|
||||
benchmarks::qmatmul::benches,
|
||||
benchmarks::random::benches,
|
||||
benchmarks::reduce::benches,
|
||||
benchmarks::unary::benches,
|
||||
benchmarks::where_cond::benches,
|
||||
benchmarks::conv_transpose2d::benches,
|
||||
benchmarks::qmatmul::benches,
|
||||
benchmarks::unary::benches
|
||||
);
|
||||
|
38
candle-core/benches/benchmarks/copy.rs
Normal file
38
candle-core/benches/benchmarks/copy.rs
Normal file
@ -0,0 +1,38 @@
|
||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||
use candle_core::{Device, Tensor, WithDType};
|
||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||
use std::time::Instant;
|
||||
|
||||
fn run_copy_mask_benchmark<D: WithDType>(c: &mut Criterion, device: &Device, name: &str) {
|
||||
let batch_size = 128;
|
||||
let in_seq_len = 1;
|
||||
let kv_seq_len = 1024;
|
||||
|
||||
let attn_mask = vec![vec![vec![D::zero(); kv_seq_len]; in_seq_len]; batch_size];
|
||||
let size_in_bytes = batch_size * in_seq_len * kv_seq_len * D::DTYPE.size_in_bytes();
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name(name));
|
||||
group.throughput(Throughput::Bytes(size_in_bytes as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
let attn_masks = vec![attn_mask.clone(); iters as usize];
|
||||
let start = Instant::now();
|
||||
for attn_mask in attn_masks.into_iter() {
|
||||
let tensor = Tensor::new(black_box(attn_mask), device).unwrap();
|
||||
black_box(tensor);
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let handler = BenchDeviceHandler::new().unwrap();
|
||||
for device in handler.devices {
|
||||
run_copy_mask_benchmark::<f32>(c, &device, "copy_mask");
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
@ -1,5 +1,6 @@
|
||||
pub(crate) mod affine;
|
||||
pub(crate) mod conv_transpose2d;
|
||||
pub(crate) mod copy;
|
||||
pub(crate) mod matmul;
|
||||
pub(crate) mod qmatmul;
|
||||
pub(crate) mod random;
|
||||
|
@ -103,7 +103,63 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4:
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: NdArray> NdArray for Vec<S> {
|
||||
impl<S: WithDType> NdArray for Vec<S> {
|
||||
fn shape(&self) -> Result<Shape> {
|
||||
Ok(Shape::from(self.len()))
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> CpuStorage {
|
||||
S::to_cpu_storage(self.as_slice())
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: WithDType> NdArray for Vec<&[S]> {
|
||||
fn shape(&self) -> Result<Shape> {
|
||||
if self.is_empty() {
|
||||
crate::bail!("empty array")
|
||||
}
|
||||
let n = self.len();
|
||||
let m = self[0].len();
|
||||
for v in self.iter() {
|
||||
if v.len() != m {
|
||||
crate::bail!("two elements have different len {m} {}", v.len())
|
||||
}
|
||||
}
|
||||
Ok(Shape::from((n, m)))
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> CpuStorage {
|
||||
let data = self.iter().copied().flatten().copied().collect::<Vec<_>>();
|
||||
S::to_cpu_storage_owned(data)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: WithDType> NdArray for Vec<Vec<S>> {
|
||||
fn shape(&self) -> Result<Shape> {
|
||||
if self.is_empty() {
|
||||
crate::bail!("empty array")
|
||||
}
|
||||
let n = self.len();
|
||||
let m = self[0].len();
|
||||
for v in self.iter() {
|
||||
if v.len() != m {
|
||||
crate::bail!("two elements have different len {m} {}", v.len())
|
||||
}
|
||||
}
|
||||
Ok(Shape::from((n, m)))
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> CpuStorage {
|
||||
let len: usize = self.iter().map(|v| v.len()).sum();
|
||||
let mut dst = Vec::with_capacity(len);
|
||||
for v in self.iter() {
|
||||
dst.extend(v.iter().copied());
|
||||
}
|
||||
S::to_cpu_storage_owned(dst)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: WithDType> NdArray for Vec<Vec<Vec<S>>> {
|
||||
fn shape(&self) -> Result<Shape> {
|
||||
if self.is_empty() {
|
||||
crate::bail!("empty array")
|
||||
@ -120,9 +176,57 @@ impl<S: NdArray> NdArray for Vec<S> {
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> CpuStorage {
|
||||
// This allocates intermediary memory and shouldn't be necessary.
|
||||
let storages = self.iter().map(|v| v.to_cpu_storage()).collect::<Vec<_>>();
|
||||
CpuStorage::concat(storages.as_slice()).unwrap()
|
||||
if self.is_empty() {
|
||||
return S::to_cpu_storage_owned(vec![]);
|
||||
}
|
||||
let len: usize = self
|
||||
.iter()
|
||||
.map(|v| v.iter().map(|v| v.len()).sum::<usize>())
|
||||
.sum();
|
||||
let mut dst = Vec::with_capacity(len);
|
||||
for v1 in self.iter() {
|
||||
for v2 in v1.iter() {
|
||||
dst.extend(v2.iter().copied());
|
||||
}
|
||||
}
|
||||
S::to_cpu_storage_owned(dst)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: WithDType> NdArray for Vec<Vec<Vec<Vec<S>>>> {
|
||||
fn shape(&self) -> Result<Shape> {
|
||||
if self.is_empty() {
|
||||
crate::bail!("empty array")
|
||||
}
|
||||
let shape0 = self[0].shape()?;
|
||||
let n = self.len();
|
||||
for v in self.iter() {
|
||||
let shape = v.shape()?;
|
||||
if shape != shape0 {
|
||||
crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
|
||||
}
|
||||
}
|
||||
Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> CpuStorage {
|
||||
let len: usize = self
|
||||
.iter()
|
||||
.map(|v| {
|
||||
v.iter()
|
||||
.map(|v| v.iter().map(|v| v.len()).sum::<usize>())
|
||||
.sum::<usize>()
|
||||
})
|
||||
.sum();
|
||||
let mut dst = Vec::with_capacity(len);
|
||||
for v1 in self.iter() {
|
||||
for v2 in v1.iter() {
|
||||
for v3 in v2.iter() {
|
||||
dst.extend(v3.iter().copied());
|
||||
}
|
||||
}
|
||||
}
|
||||
S::to_cpu_storage_owned(dst)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1655,50 +1655,32 @@ impl BackendStorage for MetalStorage {
|
||||
let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?;
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
command_buffer.set_label("matmul");
|
||||
if self.dtype == DType::BF16 {
|
||||
candle_metal_kernels::call_mlx_gemm(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
candle_metal_kernels::GemmDType::BF16,
|
||||
(b, m, n, k),
|
||||
lhs_l.stride(),
|
||||
lhs_l.start_offset() * self.dtype.size_in_bytes(),
|
||||
&self.buffer,
|
||||
rhs_l.stride(),
|
||||
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
||||
&rhs.buffer,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
} else {
|
||||
let dtype = match self.dtype {
|
||||
DType::F32 => candle_metal_kernels::GemmDType::F32,
|
||||
DType::F16 => candle_metal_kernels::GemmDType::F16,
|
||||
DType::BF16 => candle_metal_kernels::GemmDType::BF16,
|
||||
dtype => {
|
||||
return Err(MetalError::Message(format!(
|
||||
"mlx matmul doesn't support {dtype:?}"
|
||||
))
|
||||
.into())
|
||||
}
|
||||
};
|
||||
candle_metal_kernels::call_mlx_gemm(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
dtype,
|
||||
(b, m, n, k),
|
||||
lhs_l.stride(),
|
||||
lhs_l.start_offset() * self.dtype.size_in_bytes(),
|
||||
&self.buffer,
|
||||
rhs_l.stride(),
|
||||
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
||||
&rhs.buffer,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
let dtype = match self.dtype {
|
||||
DType::F32 => candle_metal_kernels::GemmDType::F32,
|
||||
DType::F16 => candle_metal_kernels::GemmDType::F16,
|
||||
DType::BF16 => candle_metal_kernels::GemmDType::BF16,
|
||||
dtype => {
|
||||
return Err(
|
||||
MetalError::Message(format!("mlx matmul doesn't support {dtype:?}")).into(),
|
||||
)
|
||||
}
|
||||
};
|
||||
candle_metal_kernels::call_mlx_gemm(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
dtype,
|
||||
(b, m, n, k),
|
||||
lhs_l.stride(),
|
||||
lhs_l.start_offset() * self.dtype.size_in_bytes(),
|
||||
&self.buffer,
|
||||
rhs_l.stride(),
|
||||
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
||||
&rhs.buffer,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
|
||||
Ok(Self::new(
|
||||
buffer,
|
||||
self.device.clone(),
|
||||
|
@ -382,8 +382,7 @@ impl Tensor {
|
||||
Self::new_impl(array, shape, device, false)
|
||||
}
|
||||
|
||||
/// Returns a new tensor with all the elements having the same specified value. Note that
|
||||
/// the tensor is not contiguous so you would have to call `.contiguous()` on it if needed.
|
||||
/// Returns a new tensor with all the elements having the same specified value.
|
||||
///```rust
|
||||
/// use candle_core::{Tensor, Device};
|
||||
/// let a = Tensor::full(3.5, (2, 4), &Device::Cpu)?;
|
||||
@ -398,7 +397,12 @@ impl Tensor {
|
||||
shape: S,
|
||||
device: &Device,
|
||||
) -> Result<Self> {
|
||||
Self::from_vec_impl(vec![value], (), device, false)?.broadcast_as(shape)
|
||||
let none = BackpropOp::none();
|
||||
let shape = shape.into();
|
||||
let mut storage = unsafe { device.alloc_uninit(&shape, D::DTYPE)? };
|
||||
let layout = Layout::contiguous(shape.clone());
|
||||
storage.const_set(value.to_scalar(), &layout)?;
|
||||
Ok(from_storage(storage, shape, none, false))
|
||||
}
|
||||
|
||||
/// Creates a new 1D tensor from an iterator.
|
||||
|
@ -1811,3 +1811,26 @@ fn test_flip_3d_channels() -> Result<()> {
|
||||
candle_core::test_utils::assert_tensor_eq(&flipped, &expected)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Checks that `Tensor::new` accepts flat, 2-level and 3-level nested `Vec`s and that the
/// values read back match the values that were passed in.
#[test]
fn tensor_new() -> Result<()> {
    let cpu = &Device::Cpu;

    // Flat vector.
    let flat = vec![1f32, 2.0, 3.0];
    let t1 = Tensor::new(flat, cpu)?;
    assert_eq!(t1.to_vec1::<f32>()?, [1.0, 2.0, 3.0]);

    // Vector of vectors.
    let rows = vec![vec![1f32, 2., 3.], vec![4., 5., 6.]];
    let t2 = Tensor::new(rows, cpu)?;
    assert_eq!(t2.to_vec2::<f32>()?, [[1., 2., 3.], [4., 5., 6.]]);

    // Three levels of nesting.
    let blocks = vec![
        vec![vec![1f32, 2., 3.], vec![4., 5., 6.]],
        vec![vec![3f32, 1., 4.], vec![1., 5., 9.]],
    ];
    let t3 = Tensor::new(blocks, cpu)?;
    let expected = [
        [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
        [[3.0, 1.0, 4.0], [1.0, 5.0, 9.0]],
    ];
    assert_eq!(t3.to_vec3::<f32>()?, expected);
    Ok(())
}
|
||||
|
@ -7,7 +7,10 @@ extern crate accelerate_src;
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
|
||||
use candle_transformers::models::helium::{Config, Model};
|
||||
use candle_transformers::models::helium::{Config as ConfigPreview, Model as ModelPreview};
|
||||
use candle_transformers::models::llama::{
|
||||
Cache as CacheV1, Llama as ModelV1, LlamaConfig as ConfigV1, LlamaEosToks,
|
||||
};
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
@ -16,6 +19,44 @@ use candle_transformers::generation::{LogitsProcessor, Sampling};
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
enum Model {
|
||||
V1 { model: ModelV1, cache: CacheV1 },
|
||||
Preview(ModelPreview),
|
||||
}
|
||||
|
||||
impl Model {
|
||||
fn forward(&mut self, input: &Tensor, start_pos: usize) -> Result<Tensor> {
|
||||
let model = match self {
|
||||
Model::V1 { model, cache } => model.forward(input, start_pos, cache)?,
|
||||
Model::Preview(m) => m.forward(input, start_pos)?,
|
||||
};
|
||||
Ok(model)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
enum Config {
|
||||
V1(ConfigV1),
|
||||
Preview(ConfigPreview),
|
||||
}
|
||||
|
||||
impl Config {
|
||||
fn bos_token_id(&self) -> Option<u32> {
|
||||
match self {
|
||||
Config::V1(c) => c.bos_token_id,
|
||||
Config::Preview(c) => Some(c.bos_token_id),
|
||||
}
|
||||
}
|
||||
|
||||
fn eos_token_id(&self) -> Option<LlamaEosToks> {
|
||||
match self {
|
||||
Config::V1(c) => c.eos_token_id.clone(),
|
||||
Config::Preview(c) => Some(LlamaEosToks::Single(c.eos_token_id)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct TextGeneration {
|
||||
model: Model,
|
||||
device: Device,
|
||||
@ -106,7 +147,15 @@ impl TextGeneration {
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == self.config.bos_token_id || next_token == self.config.eos_token_id {
|
||||
let is_eos = self
|
||||
.config
|
||||
.eos_token_id()
|
||||
.as_ref()
|
||||
.is_some_and(|v| match v {
|
||||
LlamaEosToks::Single(eos) => *eos == next_token,
|
||||
LlamaEosToks::Multiple(eos) => eos.contains(&next_token),
|
||||
});
|
||||
if Some(next_token) == self.config.bos_token_id() || is_eos {
|
||||
break;
|
||||
}
|
||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||
@ -131,6 +180,8 @@ impl TextGeneration {
|
||||
enum Which {
|
||||
#[value(name = "v1-preview")]
|
||||
V1Preview,
|
||||
#[value(name = "v1")]
|
||||
V1,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
@ -144,9 +195,6 @@ struct Args {
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
#[arg(long)]
|
||||
use_flash_attn: bool,
|
||||
|
||||
#[arg(long)]
|
||||
prompt: String,
|
||||
|
||||
@ -171,7 +219,7 @@ struct Args {
|
||||
sample_len: usize,
|
||||
|
||||
/// The model size to use.
|
||||
#[arg(long, default_value = "v1-preview")]
|
||||
#[arg(long, default_value = "v1")]
|
||||
which: Which,
|
||||
|
||||
#[arg(long)]
|
||||
@ -230,6 +278,7 @@ fn main() -> Result<()> {
|
||||
None => {
|
||||
let name = match args.which {
|
||||
Which::V1Preview => "kyutai/helium-1-preview-2b",
|
||||
Which::V1 => "kyutai/helium-1-2b",
|
||||
};
|
||||
name.to_string()
|
||||
}
|
||||
@ -254,18 +303,27 @@ fn main() -> Result<()> {
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let config: Config = match args.config {
|
||||
Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?,
|
||||
None => {
|
||||
let config_file = repo.get("config.json")?;
|
||||
serde_json::from_slice(&std::fs::read(config_file)?)?
|
||||
}
|
||||
let config_file = match args.config {
|
||||
Some(config_file) => std::path::PathBuf::from(config_file),
|
||||
None => repo.get("config.json")?,
|
||||
};
|
||||
let config = match args.which {
|
||||
Which::V1Preview => Config::Preview(serde_json::from_slice(&std::fs::read(config_file)?)?),
|
||||
Which::V1 => Config::V1(serde_json::from_slice(&std::fs::read(config_file)?)?),
|
||||
};
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let (model, device) = {
|
||||
let dtype = device.bf16_default_to_f32();
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
let model = Model::new(&config, vb)?;
|
||||
let model = match &config {
|
||||
Config::V1(c) => {
|
||||
let c = c.clone().into_config(false);
|
||||
let model = ModelV1::load(vb, &c)?;
|
||||
let cache = CacheV1::new(true, dtype, &c, &device)?;
|
||||
Model::V1 { model, cache }
|
||||
}
|
||||
Config::Preview(c) => Model::Preview(ModelPreview::new(c, vb)?),
|
||||
};
|
||||
(model, device)
|
||||
};
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-flash-attn"
|
||||
version = "0.9.0"
|
||||
version = "0.9.1"
|
||||
edition = "2021"
|
||||
|
||||
description = "Flash attention layer for the candle ML framework."
|
||||
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
|
||||
readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0" }
|
||||
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.1" }
|
||||
half = { version = "2.3.1", features = ["num-traits"] }
|
||||
|
||||
[build-dependencies]
|
||||
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-kernels"
|
||||
version = "0.9.0"
|
||||
version = "0.9.1"
|
||||
edition = "2021"
|
||||
|
||||
description = "CUDA kernels for Candle"
|
||||
|
@ -219,11 +219,15 @@ __device__ void softmax(const T * x, T * dst, const int ncols) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td) {
|
||||
__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t stride_b) {
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (2 * idx >= bh * td) return;
|
||||
|
||||
uint32_t rope_idx = idx % (td / 2);
|
||||
if (stride_b > 0) {
|
||||
uint32_t b_idx = (2 * idx) / stride_b;
|
||||
rope_idx += b_idx * (td / 2);
|
||||
}
|
||||
T c = cos[rope_idx];
|
||||
T s = sin[rope_idx];
|
||||
|
||||
@ -232,7 +236,7 @@ __device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, cons
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t d) {
|
||||
__device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t d, const uint32_t stride_b) {
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (2 * idx >= bh * td) return;
|
||||
|
||||
@ -243,6 +247,10 @@ __device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const
|
||||
uint32_t i1 = i_bh * td + i_t * d + i_d;
|
||||
uint32_t i2 = i1 + d / 2;
|
||||
uint32_t i_cs = i_t * (d / 2) + i_d;
|
||||
if (stride_b > 0) {
|
||||
uint32_t b_idx = (2 * idx) / stride_b;
|
||||
i_cs += b_idx * (td / 2);
|
||||
}
|
||||
T c = cos[i_cs];
|
||||
T s = sin[i_cs];
|
||||
|
||||
@ -259,7 +267,8 @@ __device__ void rope_thd(
|
||||
const uint32_t b,
|
||||
const uint32_t t,
|
||||
const uint32_t h,
|
||||
const uint32_t d
|
||||
const uint32_t d,
|
||||
const uint32_t stride_b
|
||||
) {
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (2 * idx >= b * t * h * d) return;
|
||||
@ -270,6 +279,10 @@ __device__ void rope_thd(
|
||||
uint32_t i1 = i_bth * d + i_d;
|
||||
uint32_t i2 = i1 + d / 2;
|
||||
uint32_t i_cs = i_t * (d / 2) + i_d;
|
||||
if (stride_b > 0) {
|
||||
uint32_t b_idx = (2 * idx) / stride_b;
|
||||
i_cs += b_idx * ((t * d) / 2);
|
||||
}
|
||||
T c = cos[i_cs];
|
||||
T s = sin[i_cs];
|
||||
|
||||
@ -546,8 +559,9 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
|
||||
const TYPENAME *sin, \
|
||||
TYPENAME *dst, \
|
||||
const uint32_t bh, \
|
||||
const uint32_t td) { \
|
||||
ropei<TYPENAME>(src, cos, sin, dst, bh, td); \
|
||||
const uint32_t td, \
|
||||
const uint32_t stride_b) { \
|
||||
ropei<TYPENAME>(src, cos, sin, dst, bh, td, stride_b); \
|
||||
} \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const TYPENAME *src, \
|
||||
@ -556,8 +570,9 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
|
||||
TYPENAME *dst, \
|
||||
const uint32_t bh, \
|
||||
const uint32_t td, \
|
||||
const uint32_t d) { \
|
||||
rope<TYPENAME>(src, cos, sin, dst, bh, td, d); \
|
||||
const uint32_t d, \
|
||||
const uint32_t stride_b) { \
|
||||
rope<TYPENAME>(src, cos, sin, dst, bh, td, d, stride_b); \
|
||||
} \
|
||||
extern "C" __global__ void FN_NAME_THD( \
|
||||
const TYPENAME *src, \
|
||||
@ -567,8 +582,9 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
|
||||
const uint32_t b, \
|
||||
const uint32_t t, \
|
||||
const uint32_t h, \
|
||||
const uint32_t d) { \
|
||||
rope_thd<TYPENAME>(src, cos, sin, dst, b, t, h, d); \
|
||||
const uint32_t d, \
|
||||
const uint32_t stride_b) { \
|
||||
rope_thd<TYPENAME>(src, cos, sin, dst, b, t, h, d, stride_b); \
|
||||
} \
|
||||
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-metal-kernels"
|
||||
version = "0.9.0"
|
||||
version = "0.9.1"
|
||||
edition = "2021"
|
||||
|
||||
description = "Metal kernels for Candle"
|
||||
|
@ -991,6 +991,7 @@ pub fn call_rope_i(
|
||||
kernel_name: &'static str,
|
||||
bh: usize,
|
||||
td: usize,
|
||||
stride_b: usize,
|
||||
src: &Buffer,
|
||||
src_offset: usize,
|
||||
cos: &Buffer,
|
||||
@ -1009,6 +1010,7 @@ pub fn call_rope_i(
|
||||
(
|
||||
bh,
|
||||
td,
|
||||
stride_b,
|
||||
(src, src_offset),
|
||||
(cos, cos_offset),
|
||||
(sin, sin_offset),
|
||||
@ -1034,6 +1036,7 @@ pub fn call_rope_thd(
|
||||
t: usize,
|
||||
h: usize,
|
||||
d: usize,
|
||||
stride_b: usize,
|
||||
src: &Buffer,
|
||||
src_offset: usize,
|
||||
cos: &Buffer,
|
||||
@ -1054,6 +1057,7 @@ pub fn call_rope_thd(
|
||||
t,
|
||||
h,
|
||||
d,
|
||||
stride_b,
|
||||
(src, src_offset),
|
||||
(cos, cos_offset),
|
||||
(sin, sin_offset),
|
||||
@ -1078,6 +1082,7 @@ pub fn call_rope(
|
||||
bh: usize,
|
||||
td: usize,
|
||||
d: usize,
|
||||
stride_b: usize,
|
||||
src: &Buffer,
|
||||
src_offset: usize,
|
||||
cos: &Buffer,
|
||||
@ -1097,6 +1102,7 @@ pub fn call_rope(
|
||||
bh,
|
||||
td,
|
||||
d,
|
||||
stride_b,
|
||||
(src, src_offset),
|
||||
(cos, cos_offset),
|
||||
(sin, sin_offset),
|
||||
|
@ -1097,6 +1097,7 @@ template<typename T>
|
||||
METAL_FUNC void ropei(
|
||||
constant size_t &bh,
|
||||
constant size_t &td,
|
||||
constant size_t &stride_b,
|
||||
device const T *src,
|
||||
device const T *cos,
|
||||
device const T *sin,
|
||||
@ -1107,6 +1108,10 @@ METAL_FUNC void ropei(
|
||||
return;
|
||||
}
|
||||
size_t rope_idx = tid % (td / 2);
|
||||
if (stride_b > 0) {
|
||||
size_t b_idx = (2 * tid) / stride_b;
|
||||
rope_idx += b_idx * (td / 2);
|
||||
}
|
||||
T c = cos[rope_idx];
|
||||
T s = sin[rope_idx];
|
||||
dst[2 * tid] = src[2 * tid] * c - src[2 * tid + 1] * s;
|
||||
@ -1118,6 +1123,7 @@ METAL_FUNC void rope(
|
||||
constant size_t &bh,
|
||||
constant size_t &td,
|
||||
constant size_t &d,
|
||||
constant size_t &stride_b,
|
||||
device const T *src,
|
||||
device const T *cos,
|
||||
device const T *sin,
|
||||
@ -1134,6 +1140,10 @@ METAL_FUNC void rope(
|
||||
size_t i1 = i_bh * td + i_t * d + i_d;
|
||||
size_t i2 = i1 + d / 2;
|
||||
size_t i_cs = i_t * (d / 2) + i_d;
|
||||
if (stride_b > 0) {
|
||||
size_t b_idx = (2 * idx) / stride_b;
|
||||
i_cs += b_idx * (td / 2);
|
||||
}
|
||||
T c = cos[i_cs];
|
||||
T s = sin[i_cs];
|
||||
dst[i1] = src[i1] * c - src[i2] * s;
|
||||
@ -1146,6 +1156,7 @@ METAL_FUNC void rope_thd(
|
||||
constant size_t &t,
|
||||
constant size_t &h,
|
||||
constant size_t &d,
|
||||
constant size_t &stride_b,
|
||||
device const T *src,
|
||||
device const T *cos,
|
||||
device const T *sin,
|
||||
@ -1160,8 +1171,12 @@ METAL_FUNC void rope_thd(
|
||||
const size_t i_t = (i_bth / h) % t;
|
||||
const size_t i1 = i_bth * d + i_d;
|
||||
const size_t i2 = i1 + d / 2;
|
||||
const size_t i_cs = i_t * (d / 2) + i_d;
|
||||
T c = cos[i_cs];
|
||||
size_t i_cs = i_t * (d / 2) + i_d;
|
||||
if (stride_b > 0) {
|
||||
const size_t b_idx = (2 * idx) / stride_b;
|
||||
i_cs += b_idx * ((t * d) / 2);
|
||||
}
|
||||
T c = cos[i_cs];
|
||||
T s = sin[i_cs];
|
||||
dst[i1] = src[i1] * c - src[i2] * s;
|
||||
dst[i2] = src[i1] * s + src[i2] * c;
|
||||
@ -1171,38 +1186,41 @@ METAL_FUNC void rope_thd(
|
||||
kernel void FN_NAME_I( \
|
||||
constant size_t &bh, \
|
||||
constant size_t &td, \
|
||||
constant size_t &stride_b, \
|
||||
device const TYPENAME *src, \
|
||||
device const TYPENAME *cos, \
|
||||
device const TYPENAME *sin, \
|
||||
device TYPENAME *dst, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
ropei<TYPENAME>(bh, td, src, cos, sin, dst, tid); \
|
||||
ropei<TYPENAME>(bh, td, stride_b, src, cos, sin, dst, tid); \
|
||||
}\
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &bh, \
|
||||
constant size_t &td, \
|
||||
constant size_t &d, \
|
||||
constant size_t &stride_b, \
|
||||
device const TYPENAME *src, \
|
||||
device const TYPENAME *cos, \
|
||||
device const TYPENAME *sin, \
|
||||
device TYPENAME *dst, \
|
||||
uint idx [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
rope<TYPENAME>(bh, td, d, src, cos, sin, dst, idx); \
|
||||
rope<TYPENAME>(bh, td, d, stride_b, src, cos, sin, dst, idx); \
|
||||
}\
|
||||
kernel void FN_NAME_THD( \
|
||||
constant size_t &b, \
|
||||
constant size_t &t, \
|
||||
constant size_t &h, \
|
||||
constant size_t &d, \
|
||||
constant size_t &stride_b, \
|
||||
device const TYPENAME *src, \
|
||||
device const TYPENAME *cos, \
|
||||
device const TYPENAME *sin, \
|
||||
device TYPENAME *dst, \
|
||||
uint idx [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
rope_thd<TYPENAME>(b, t, h, d, src, cos, sin, dst, idx); \
|
||||
rope_thd<TYPENAME>(b, t, h, d, stride_b, src, cos, sin, dst, idx); \
|
||||
}\
|
||||
|
||||
RMSNORM(rmsnorm_f32, float)
|
||||
|
@ -1584,7 +1584,7 @@ fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>(
|
||||
dim,
|
||||
BufferOffset::zero_offset(&input_buffer),
|
||||
BufferOffset::zero_offset(&ids_buffer),
|
||||
&output,
|
||||
BufferOffset::zero_offset(&output),
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
|
@ -1,6 +1,6 @@
|
||||
//! Cache Implementations
|
||||
//!
|
||||
use candle::{Device, Result, Tensor};
|
||||
use candle::{DType, Device, Result, Tensor};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Cache {
|
||||
@ -399,3 +399,322 @@ impl RotatingKvCache {
|
||||
self.v.reset();
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct IndicesAndMask {
|
||||
indices: Tensor,
|
||||
mask: Tensor,
|
||||
}
|
||||
|
||||
impl IndicesAndMask {
|
||||
pub fn mask(&self) -> &Tensor {
|
||||
&self.mask
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ScatteredKvCache {
|
||||
k: Tensor,
|
||||
v: Tensor,
|
||||
context: usize,
|
||||
}
|
||||
|
||||
impl ScatteredKvCache {
|
||||
pub fn append(
|
||||
&mut self,
|
||||
k: &Tensor,
|
||||
v: &Tensor,
|
||||
iam: &IndicesAndMask,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
if self.context <= k.dim(2)? {
|
||||
return Ok((k.clone(), v.clone()));
|
||||
}
|
||||
let indices = iam.indices.unsqueeze(2)?.unsqueeze(1)?;
|
||||
let indices = indices.broadcast_as(k.shape())?.contiguous()?;
|
||||
self.k.scatter_set(&indices, k, 2)?;
|
||||
self.v.scatter_set(&indices, v, 2)?;
|
||||
Ok((self.k.clone(), self.v.clone()))
|
||||
}
|
||||
|
||||
pub fn k(&self) -> &Tensor {
|
||||
&self.k
|
||||
}
|
||||
|
||||
pub fn v(&self) -> &Tensor {
|
||||
&self.v
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ScatteredCacheBuilder {
|
||||
context: usize,
|
||||
// The current position in the stream, this can be larger than context.
|
||||
positions: Vec<usize>,
|
||||
// The index where the next element will be stored.
|
||||
indices: Vec<usize>,
|
||||
dtype: DType,
|
||||
device: Device,
|
||||
}
|
||||
|
||||
impl ScatteredCacheBuilder {
|
||||
pub fn new(batch_size: usize, context: usize, dtype: DType, device: &Device) -> Result<Self> {
|
||||
let positions = vec![0; batch_size];
|
||||
let indices = vec![0; batch_size];
|
||||
Ok(Self {
|
||||
positions,
|
||||
indices,
|
||||
context,
|
||||
dtype,
|
||||
device: device.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn make_cache(&self, num_heads: usize, head_dim: usize) -> Result<ScatteredKvCache> {
|
||||
let batch_size = self.batch_size();
|
||||
let shape = (batch_size, num_heads, self.context, head_dim);
|
||||
let k = Tensor::zeros(shape, self.dtype, self.device())?;
|
||||
let v = Tensor::zeros(shape, self.dtype, self.device())?;
|
||||
Ok(ScatteredKvCache {
|
||||
k,
|
||||
v,
|
||||
context: self.context,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn positions(&self) -> &[usize] {
|
||||
&self.positions
|
||||
}
|
||||
|
||||
pub fn reset(&mut self) {
|
||||
self.positions.fill(0);
|
||||
self.indices.fill(0);
|
||||
}
|
||||
|
||||
pub fn batch_size(&self) -> usize {
|
||||
self.positions.len()
|
||||
}
|
||||
|
||||
pub fn reset_batch_index(&mut self, batch_index: usize) {
|
||||
self.positions[batch_index] = 0;
|
||||
self.indices[batch_index] = 0;
|
||||
}
|
||||
|
||||
#[allow(clippy::needless_range_loop)]
|
||||
pub fn indices_and_mask(
|
||||
&mut self,
|
||||
seq_len: usize,
|
||||
batch_mask: &[bool],
|
||||
) -> Result<IndicesAndMask> {
|
||||
// mask shape is (b, h, t, k)
|
||||
let context = self.context;
|
||||
if self.context <= seq_len {
|
||||
return self.indices_and_mask_abs(seq_len, batch_mask);
|
||||
}
|
||||
let mut attention_masks = Vec::with_capacity(self.batch_size());
|
||||
let mut cache_indices = Vec::with_capacity(self.batch_size());
|
||||
for (batch_i, &batch_mask) in batch_mask.iter().enumerate() {
|
||||
if !batch_mask {
|
||||
let masks: Vec<Vec<f32>> = vec![vec![0.0; context]; seq_len];
|
||||
let indices = vec![self.indices[batch_i] as u32; seq_len];
|
||||
attention_masks.push(masks);
|
||||
cache_indices.push(indices);
|
||||
} else {
|
||||
let start_index = self.indices[batch_i];
|
||||
let start_pos = self.positions[batch_i];
|
||||
let mut masks: Vec<Vec<f32>> = Vec::with_capacity(seq_len);
|
||||
let mut indices = Vec::with_capacity(seq_len);
|
||||
let mut all_pos = vec![usize::MAX; context];
|
||||
if start_pos < context {
|
||||
for i in 0..start_pos {
|
||||
all_pos[i] = i;
|
||||
}
|
||||
} else {
|
||||
let offset = start_pos - start_index;
|
||||
for i in 0..context {
|
||||
all_pos[i] = if i < start_index {
|
||||
i + offset
|
||||
} else {
|
||||
i + offset - context
|
||||
};
|
||||
}
|
||||
}
|
||||
for seq_i in 0..seq_len {
|
||||
let index = self.indices[batch_i];
|
||||
all_pos[index] = seq_i + start_pos;
|
||||
indices.push(index as u32);
|
||||
self.indices[batch_i] += 1;
|
||||
self.positions[batch_i] += 1;
|
||||
if self.indices[batch_i] >= self.context {
|
||||
self.indices[batch_i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
for seq_i in 0..seq_len {
|
||||
let my_pos = seq_i + start_pos;
|
||||
let mask = all_pos
|
||||
.iter()
|
||||
.map(|&pos| {
|
||||
if pos <= my_pos {
|
||||
0.0
|
||||
} else {
|
||||
f32::NEG_INFINITY
|
||||
}
|
||||
})
|
||||
.collect::<Vec<f32>>();
|
||||
masks.push(mask);
|
||||
}
|
||||
|
||||
attention_masks.push(masks);
|
||||
cache_indices.push(indices);
|
||||
}
|
||||
}
|
||||
// Flattening the attention mask then using Tensor::from_vec rather using Tensor::new ends
|
||||
// up being almost 10x faster with candle 0.9.0. This has been fixed in candle 0.9.1.
|
||||
let attention_masks = attention_masks
|
||||
.into_iter()
|
||||
.flat_map(|m| m.into_iter().flatten())
|
||||
.collect::<Vec<f32>>();
|
||||
let mask = Tensor::from_vec(attention_masks, ((), 1, seq_len, context), self.device())?
|
||||
.to_dtype(self.dtype)?;
|
||||
let indices = Tensor::new(cache_indices, self.device())?;
|
||||
Ok(IndicesAndMask { indices, mask })
|
||||
}
|
||||
|
||||
pub fn device(&self) -> &Device {
|
||||
&self.device
|
||||
}
|
||||
|
||||
#[allow(clippy::needless_range_loop)]
|
||||
fn indices_and_mask_abs(
|
||||
&mut self,
|
||||
seq_len: usize,
|
||||
batch_mask: &[bool],
|
||||
) -> Result<IndicesAndMask> {
|
||||
let mask = self.get_mask_abs(seq_len, seq_len)?;
|
||||
let mut cache_indices = Vec::with_capacity(self.batch_size());
|
||||
for (batch_i, &batch_mask) in batch_mask.iter().enumerate() {
|
||||
if !batch_mask {
|
||||
let indices = vec![self.indices[batch_i] as u32; seq_len];
|
||||
cache_indices.push(indices);
|
||||
} else {
|
||||
let mut indices = Vec::with_capacity(seq_len);
|
||||
for _ in 0..seq_len {
|
||||
let index = self.indices[batch_i];
|
||||
indices.push(index as u32);
|
||||
self.indices[batch_i] += 1;
|
||||
self.positions[batch_i] += 1;
|
||||
if self.indices[batch_i] >= self.context {
|
||||
self.indices[batch_i] = 0;
|
||||
}
|
||||
}
|
||||
cache_indices.push(indices);
|
||||
}
|
||||
}
|
||||
let indices = Tensor::new(cache_indices, self.device())?;
|
||||
Ok(IndicesAndMask { indices, mask })
|
||||
}
|
||||
|
||||
fn get_mask_abs(&self, size1: usize, size2: usize) -> Result<Tensor> {
|
||||
let context = self.context;
|
||||
let mask: Vec<_> = (0..size1)
|
||||
.flat_map(|i| {
|
||||
(0..size2).map(move |j| {
|
||||
if size1 + j > size2 + i || size1 + j + context < size2 + i {
|
||||
f32::NEG_INFINITY
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
Tensor::from_slice(&mask, (size1, size2), self.device())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use candle::IndexOp;

    /// Drives a two-element batch with five cache slots through several steps and checks
    /// the slot indices and masks returned at each step, including after wrap-around.
    #[test]
    fn test_scattered_kv_cache() -> Result<()> {
        let dev = Device::Cpu;
        let mut builder = ScatteredCacheBuilder::new(2, 5, DType::F32, &dev)?;
        let ninf = f32::NEG_INFINITY;

        // One step, only batch element 0 active.
        let step = builder.indices_and_mask(1, &[true, false])?;
        let got_mask = step.mask.i((.., 0))?.to_vec3::<f32>()?;
        assert_eq!(step.indices.to_vec2::<u32>()?, [[0], [0]]);
        assert_eq!(
            got_mask,
            [[[0.0, ninf, ninf, ninf, ninf]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]
        );

        // One more step, still only batch element 0 active.
        let step = builder.indices_and_mask(1, &[true, false])?;
        let got_mask = step.mask.i((.., 0))?.to_vec3::<f32>()?;
        assert_eq!(step.indices.to_vec2::<u32>()?, [[1], [0]]);
        assert_eq!(
            got_mask,
            [[[0.0, 0.0, ninf, ninf, ninf]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]
        );

        // Three steps, only batch element 1 active.
        let step = builder.indices_and_mask(3, &[false, true])?;
        let got_mask = step.mask.i((.., 0))?.to_vec3::<f32>()?;
        assert_eq!(step.indices.to_vec2::<u32>()?, [[2, 2, 2], [0, 1, 2]]);
        assert_eq!(
            got_mask,
            [
                [
                    [0.0, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0, 0.0]
                ],
                [
                    [0.0, ninf, ninf, ninf, ninf],
                    [0.0, 0.0, ninf, ninf, ninf],
                    [0.0, 0.0, 0.0, ninf, ninf]
                ]
            ]
        );

        // Three steps with both elements active; element 1 wraps around to slot 0.
        let step = builder.indices_and_mask(3, &[true, true])?;
        let got_mask = step.mask.i((.., 0))?.to_vec3::<f32>()?;
        assert_eq!(step.indices.to_vec2::<u32>()?, [[2, 3, 4], [3, 4, 0]]);
        assert_eq!(
            got_mask,
            [
                [
                    [0.0, 0.0, 0.0, ninf, ninf],
                    [0.0, 0.0, 0.0, 0.0, ninf],
                    [0.0, 0.0, 0.0, 0.0, 0.0]
                ],
                [
                    [ninf, 0.0, 0.0, 0.0, ninf],
                    [ninf, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0, 0.0]
                ]
            ]
        );

        // One step for element 0, which now starts again at slot 0.
        let step = builder.indices_and_mask(1, &[true, false])?;
        let got_mask = step.mask.i((.., 0))?.to_vec3::<f32>()?;
        assert_eq!(step.indices.to_vec2::<u32>()?, [[0], [1]]);
        assert_eq!(
            got_mask,
            [[[0.0, 0.0, 0.0, 0.0, 0.0]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]
        );

        // Two steps for element 0: in the first row, slot 2 is about to be overwritten by
        // the second step and is therefore masked.
        let step = builder.indices_and_mask(2, &[true, false])?;
        let got_mask = step.mask.i((.., 0))?.to_vec3::<f32>()?;
        assert_eq!(step.indices.to_vec2::<u32>()?, [[1, 2], [1, 1]]);
        assert_eq!(
            got_mask,
            [
                [[0.0, 0.0, ninf, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]],
                [[0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]]
            ]
        );

        Ok(())
    }
}
|
||||
|
@ -46,15 +46,23 @@ impl candle::CustomOp3 for RotaryEmbI {
|
||||
Some((o1, o2)) => &sin[o1..o2],
|
||||
};
|
||||
let (b, h, t, d) = l_src.shape().dims4()?;
|
||||
let unbatched_rope = l_cos.dims().len() == 3 && l_sin.dims().len() == 3;
|
||||
let el_count = b * h * t * d;
|
||||
let mut dst = vec![T::zero(); el_count];
|
||||
src.par_chunks(t * d)
|
||||
.zip(dst.par_chunks_mut(t * d))
|
||||
.for_each(|(src, dst)| {
|
||||
.enumerate()
|
||||
.for_each(|(bh_i, (src, dst))| {
|
||||
for i_over_2 in 0..t * d / 2 {
|
||||
let i = 2 * i_over_2;
|
||||
dst[i] = src[i] * cos[i_over_2] - src[i + 1] * sin[i_over_2];
|
||||
dst[i + 1] = src[i] * sin[i_over_2] + src[i + 1] * cos[i_over_2];
|
||||
let rope_i = if unbatched_rope {
|
||||
let b_i = bh_i / h;
|
||||
i_over_2 + b_i * t * d / 2
|
||||
} else {
|
||||
i_over_2
|
||||
};
|
||||
dst[i] = src[i] * cos[rope_i] - src[i + 1] * sin[rope_i];
|
||||
dst[i + 1] = src[i] * sin[rope_i] + src[i + 1] * cos[rope_i];
|
||||
}
|
||||
});
|
||||
let storage = candle::WithDType::to_cpu_storage_owned(dst);
|
||||
@ -115,6 +123,11 @@ impl candle::CustomOp3 for RotaryEmbI {
|
||||
Some((o1, o2)) => sin.slice(o1..o2),
|
||||
};
|
||||
let (b, h, t, d) = l_src.shape().dims4()?;
|
||||
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
|
||||
(h * t * d) as u32
|
||||
} else {
|
||||
0u32
|
||||
};
|
||||
let el = b * h * t * d;
|
||||
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), &kernels::REDUCE)?;
|
||||
@ -125,7 +138,7 @@ impl candle::CustomOp3 for RotaryEmbI {
|
||||
builder.arg(&cos);
|
||||
builder.arg(&sin);
|
||||
builder.arg(&dst);
|
||||
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32);
|
||||
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, stride_b);
|
||||
// SAFETY: ffi.
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
Ok(dst)
|
||||
@ -182,6 +195,11 @@ impl candle::CustomOp3 for RotaryEmbI {
|
||||
dtype => candle::bail!("rope-i is not implemented for {dtype:?}"),
|
||||
};
|
||||
let (b, h, t, d) = l_src.shape().dims4()?;
|
||||
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
|
||||
h * t * d
|
||||
} else {
|
||||
0usize
|
||||
};
|
||||
let el = b * h * t * d;
|
||||
let output = device.new_buffer(el, src.dtype(), "rope-i")?;
|
||||
candle_metal_kernels::call_rope_i(
|
||||
@ -191,6 +209,7 @@ impl candle::CustomOp3 for RotaryEmbI {
|
||||
name,
|
||||
b * h,
|
||||
t * d,
|
||||
stride_b,
|
||||
src.buffer(),
|
||||
l_src.start_offset() * src.dtype().size_in_bytes(),
|
||||
cos.buffer(),
|
||||
@ -205,10 +224,23 @@ impl candle::CustomOp3 for RotaryEmbI {
|
||||
}
|
||||
}
|
||||
|
||||
fn rope_check_cs(cs: &Tensor, b_sz: usize) -> Result<(usize, usize)> {
|
||||
match *cs.dims() {
|
||||
[t, d] => Ok((t, d)),
|
||||
[b, t, d] => {
|
||||
if b != b_sz {
|
||||
candle::bail!("inconsistent batch size in rope {b_sz} {cs:?}",)
|
||||
}
|
||||
Ok((t, d))
|
||||
}
|
||||
_ => candle::bail!("cos/sin has to be 2D or 3D in rope {b_sz} {cs:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn rope_i(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
|
||||
let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
|
||||
let (cos_seq_len, cos_n_embd) = cos.dims2()?;
|
||||
let (sin_seq_len, sin_n_embd) = cos.dims2()?;
|
||||
let (b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
|
||||
let (cos_seq_len, cos_n_embd) = rope_check_cs(cos, b_sz)?;
|
||||
let (sin_seq_len, sin_n_embd) = rope_check_cs(sin, b_sz)?;
|
||||
if cos_n_embd * 2 != n_embd
|
||||
|| sin_n_embd * 2 != n_embd
|
||||
|| seq_len > cos_seq_len
|
||||
@ -292,16 +324,24 @@ impl candle::CustomOp3 for RotaryEmb {
|
||||
Some((o1, o2)) => &sin[o1..o2],
|
||||
};
|
||||
let (b, h, t, d) = l_src.shape().dims4()?;
|
||||
let unbatched_rope = l_cos.dims().len() == 3 && l_sin.dims().len() == 3;
|
||||
let el_count = b * h * t * d;
|
||||
let mut dst = vec![T::zero(); el_count];
|
||||
src.par_chunks(t * d)
|
||||
.zip(dst.par_chunks_mut(t * d))
|
||||
.for_each(|(src, dst)| {
|
||||
.enumerate()
|
||||
.for_each(|(bh_i, (src, dst))| {
|
||||
for i_t in 0..t {
|
||||
for i_d in 0..d / 2 {
|
||||
let i1 = i_t * d + i_d;
|
||||
let i2 = i1 + d / 2;
|
||||
let i_cs = i_t * (d / 2) + i_d;
|
||||
let i_cs = if unbatched_rope {
|
||||
let b_i = bh_i / h;
|
||||
i_cs + b_i * t * d / 2
|
||||
} else {
|
||||
i_cs
|
||||
};
|
||||
dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs];
|
||||
dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs];
|
||||
}
|
||||
@ -365,6 +405,11 @@ impl candle::CustomOp3 for RotaryEmb {
|
||||
Some((o1, o2)) => sin.slice(o1..o2),
|
||||
};
|
||||
let (b, h, t, d) = l_src.shape().dims4()?;
|
||||
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
|
||||
(h * t * d) as u32
|
||||
} else {
|
||||
0u32
|
||||
};
|
||||
let el = b * h * t * d;
|
||||
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("rope"), &kernels::REDUCE)?;
|
||||
@ -375,7 +420,7 @@ impl candle::CustomOp3 for RotaryEmb {
|
||||
builder.arg(&cos);
|
||||
builder.arg(&sin);
|
||||
builder.arg(&dst);
|
||||
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, d as u32);
|
||||
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, d as u32, stride_b);
|
||||
// SAFETY: ffi.
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
Ok(dst)
|
||||
@ -432,6 +477,11 @@ impl candle::CustomOp3 for RotaryEmb {
|
||||
dtype => candle::bail!("rope is not implemented for {dtype:?}"),
|
||||
};
|
||||
let (b, h, t, d) = l_src.shape().dims4()?;
|
||||
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
|
||||
h * t * d
|
||||
} else {
|
||||
0usize
|
||||
};
|
||||
let el = b * h * t * d;
|
||||
let output = device.new_buffer(el, src.dtype(), "rope-i")?;
|
||||
candle_metal_kernels::call_rope(
|
||||
@ -442,6 +492,7 @@ impl candle::CustomOp3 for RotaryEmb {
|
||||
b * h,
|
||||
t * d,
|
||||
d,
|
||||
stride_b,
|
||||
src.buffer(),
|
||||
l_src.start_offset() * src.dtype().size_in_bytes(),
|
||||
cos.buffer(),
|
||||
@ -457,9 +508,9 @@ impl candle::CustomOp3 for RotaryEmb {
|
||||
}
|
||||
|
||||
pub fn rope(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
|
||||
let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
|
||||
let (cos_seq_len, cos_n_embd) = cos.dims2()?;
|
||||
let (sin_seq_len, sin_n_embd) = sin.dims2()?;
|
||||
let (b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
|
||||
let (cos_seq_len, cos_n_embd) = rope_check_cs(cos, b_sz)?;
|
||||
let (sin_seq_len, sin_n_embd) = rope_check_cs(sin, b_sz)?;
|
||||
if cos_n_embd * 2 != n_embd
|
||||
|| sin_n_embd * 2 != n_embd
|
||||
|| seq_len > cos_seq_len
|
||||
@ -541,14 +592,21 @@ impl candle::CustomOp3 for RotaryEmbThd {
|
||||
Some((o1, o2)) => &sin[o1..o2],
|
||||
};
|
||||
let (b, t, h, d) = l_src.shape().dims4()?;
|
||||
let unbatched_rope = l_cos.dims().len() == 3 && l_sin.dims().len() == 3;
|
||||
let el_count = b * h * t * d;
|
||||
let mut dst = vec![T::zero(); el_count];
|
||||
src.par_chunks(t * h * d)
|
||||
.zip(dst.par_chunks_mut(t * h * d))
|
||||
.for_each(|(src, dst)| {
|
||||
.enumerate()
|
||||
.for_each(|(b_i, (src, dst))| {
|
||||
for i_t in 0..t {
|
||||
for i_d in 0..d / 2 {
|
||||
let i_cs = i_t * (d / 2) + i_d;
|
||||
let i_cs = if unbatched_rope {
|
||||
i_cs + b_i * t * d / 2
|
||||
} else {
|
||||
i_cs
|
||||
};
|
||||
for i_h in 0..h {
|
||||
let i1 = i_t * h * d + i_h * d + i_d;
|
||||
let i2 = i1 + d / 2;
|
||||
@ -616,6 +674,11 @@ impl candle::CustomOp3 for RotaryEmbThd {
|
||||
Some((o1, o2)) => sin.slice(o1..o2),
|
||||
};
|
||||
let (b, t, h, d) = l_src.shape().dims4()?;
|
||||
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
|
||||
(h * t * d) as u32
|
||||
} else {
|
||||
0u32
|
||||
};
|
||||
let el = b * h * t * d;
|
||||
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), &kernels::REDUCE)?;
|
||||
@ -626,7 +689,7 @@ impl candle::CustomOp3 for RotaryEmbThd {
|
||||
builder.arg(&cos);
|
||||
builder.arg(&sin);
|
||||
builder.arg(&dst);
|
||||
candle::builder_arg!(builder, b as u32, t as u32, h as u32, d as u32);
|
||||
candle::builder_arg!(builder, b as u32, t as u32, h as u32, d as u32, stride_b);
|
||||
// SAFETY: ffi.
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
Ok(dst)
|
||||
@ -683,6 +746,11 @@ impl candle::CustomOp3 for RotaryEmbThd {
|
||||
dtype => candle::bail!("rope_thd is not implemented for {dtype:?}"),
|
||||
};
|
||||
let (b, t, h, d) = l_src.shape().dims4()?;
|
||||
let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
|
||||
h * t * d
|
||||
} else {
|
||||
0usize
|
||||
};
|
||||
let el = b * h * t * d;
|
||||
let output = device.new_buffer(el, src.dtype(), "rope-thd")?;
|
||||
candle_metal_kernels::call_rope_thd(
|
||||
@ -694,6 +762,7 @@ impl candle::CustomOp3 for RotaryEmbThd {
|
||||
t,
|
||||
h,
|
||||
d,
|
||||
stride_b,
|
||||
src.buffer(),
|
||||
l_src.start_offset() * src.dtype().size_in_bytes(),
|
||||
cos.buffer(),
|
||||
@ -709,9 +778,9 @@ impl candle::CustomOp3 for RotaryEmbThd {
|
||||
}
|
||||
|
||||
pub fn rope_thd(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
|
||||
let (_b_sz, seq_len, _n_head, n_embd) = xs.dims4()?;
|
||||
let (cos_seq_len, cos_n_embd) = cos.dims2()?;
|
||||
let (sin_seq_len, sin_n_embd) = sin.dims2()?;
|
||||
let (b_sz, seq_len, _n_head, n_embd) = xs.dims4()?;
|
||||
let (cos_seq_len, cos_n_embd) = rope_check_cs(cos, b_sz)?;
|
||||
let (sin_seq_len, sin_n_embd) = rope_check_cs(sin, b_sz)?;
|
||||
if cos_n_embd * 2 != n_embd
|
||||
|| sin_n_embd * 2 != n_embd
|
||||
|| seq_len > cos_seq_len
|
||||
|
@ -8,13 +8,16 @@ pub fn gumbel_softmax<D: candle::shape::Dim>(
|
||||
) -> Result<Tensor> {
|
||||
if temperature <= 0.0 {
|
||||
logits.argmax(dim)
|
||||
} else if temperature == 1.0 {
|
||||
let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
|
||||
let sampled = (logits - minus_g)?.argmax(dim)?;
|
||||
Ok(sampled)
|
||||
} else {
|
||||
// Cast to f32, doing the Gumbel softmax in bf16 is a bit unstable.
|
||||
let logits = logits.to_dtype(candle::DType::F32)?;
|
||||
let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
|
||||
let sampled = (logits + minus_g * (-temperature))?.argmax(dim)?;
|
||||
Ok(sampled)
|
||||
if temperature == 1.0 {
|
||||
let sampled = (logits - minus_g)?.argmax(dim)?;
|
||||
Ok(sampled)
|
||||
} else {
|
||||
let sampled = (logits + minus_g * (-temperature))?.argmax(dim)?;
|
||||
Ok(sampled)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -4,7 +4,7 @@ extern crate intel_mkl_src;
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use candle::{test_device, test_utils::to_vec3_round, Device, Result, Tensor};
|
||||
use candle::{test_device, test_utils::to_vec3_round, Device, IndexOp, Result, Tensor};
|
||||
|
||||
fn softmax(device: &Device) -> Result<()> {
|
||||
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
|
||||
@ -179,6 +179,28 @@ fn ropei(device: &Device) -> Result<()> {
|
||||
} else {
|
||||
assert!(sum_diff < 1e-4);
|
||||
}
|
||||
|
||||
// Test with a 3d cos/sin
|
||||
let cos2: Vec<f32> = (0..seq_len * head_dim / 2)
|
||||
.map(|_| rng.random::<f32>())
|
||||
.collect();
|
||||
let sin2: Vec<f32> = (0..seq_len * head_dim / 2)
|
||||
.map(|_| rng.random::<f32>())
|
||||
.collect();
|
||||
let cos2 = Tensor::from_vec(cos2, (seq_len, head_dim / 2), device)?;
|
||||
let sin2 = Tensor::from_vec(sin2, (seq_len, head_dim / 2), device)?;
|
||||
let rope1 = candle_nn::rotary_emb::rope_i(&src.i(0..1)?, &cos, &sin)?;
|
||||
let rope2 = candle_nn::rotary_emb::rope_i(&src.i(1..2)?, &cos2, &sin2)?;
|
||||
|
||||
let both_cos = Tensor::stack(&[cos, cos2], 0)?;
|
||||
let both_sin = Tensor::stack(&[sin, sin2], 0)?;
|
||||
let both_rope = candle_nn::rotary_emb::rope_i(&src, &both_cos, &both_sin)?;
|
||||
let both_rope2 = Tensor::cat(&[rope1, rope2], 0)?;
|
||||
let sum_diff = (both_rope - both_rope2)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(sum_diff, 0.);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -206,6 +228,28 @@ fn rope(device: &Device) -> Result<()> {
|
||||
} else {
|
||||
assert!(sum_diff < 1e-4);
|
||||
}
|
||||
|
||||
// Test with a 3d cos/sin
|
||||
let cos2: Vec<f32> = (0..seq_len * head_dim / 2)
|
||||
.map(|_| rng.random::<f32>())
|
||||
.collect();
|
||||
let sin2: Vec<f32> = (0..seq_len * head_dim / 2)
|
||||
.map(|_| rng.random::<f32>())
|
||||
.collect();
|
||||
let cos2 = Tensor::from_vec(cos2, (seq_len, head_dim / 2), device)?;
|
||||
let sin2 = Tensor::from_vec(sin2, (seq_len, head_dim / 2), device)?;
|
||||
let rope1 = candle_nn::rotary_emb::rope(&src.i(0..1)?, &cos, &sin)?;
|
||||
let rope2 = candle_nn::rotary_emb::rope(&src.i(1..2)?, &cos2, &sin2)?;
|
||||
|
||||
let both_cos = Tensor::stack(&[cos, cos2], 0)?;
|
||||
let both_sin = Tensor::stack(&[sin, sin2], 0)?;
|
||||
let both_rope = candle_nn::rotary_emb::rope(&src, &both_cos, &both_sin)?;
|
||||
let both_rope2 = Tensor::cat(&[rope1, rope2], 0)?;
|
||||
let sum_diff = (both_rope - both_rope2)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(sum_diff, 0.);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -236,6 +280,37 @@ fn rope_thd(device: &Device) -> Result<()> {
|
||||
} else {
|
||||
assert!(sum_diff < 1e-4);
|
||||
}
|
||||
|
||||
// Test with a 3d cos/sin
|
||||
let cos2: Vec<f32> = (0..seq_len * head_dim / 2)
|
||||
.map(|_| rng.random::<f32>())
|
||||
.collect();
|
||||
let sin2: Vec<f32> = (0..seq_len * head_dim / 2)
|
||||
.map(|_| rng.random::<f32>())
|
||||
.collect();
|
||||
let cos2 = Tensor::from_vec(cos2, (seq_len, head_dim / 2), device)?;
|
||||
let sin2 = Tensor::from_vec(sin2, (seq_len, head_dim / 2), device)?;
|
||||
let rope1 = {
|
||||
let src = src.transpose(1, 2)?.contiguous()?;
|
||||
candle_nn::rotary_emb::rope_thd(&src.i(0..1)?, &cos, &sin)?
|
||||
};
|
||||
let rope2 = {
|
||||
let src = src.transpose(1, 2)?.contiguous()?;
|
||||
candle_nn::rotary_emb::rope_thd(&src.i(1..2)?, &cos2, &sin2)?
|
||||
};
|
||||
|
||||
let both_cos = Tensor::stack(&[cos, cos2], 0)?;
|
||||
let both_sin = Tensor::stack(&[sin, sin2], 0)?;
|
||||
let both_rope = {
|
||||
let src = src.transpose(1, 2)?.contiguous()?;
|
||||
candle_nn::rotary_emb::rope_thd(&src, &both_cos, &both_sin)?
|
||||
};
|
||||
let both_rope2 = Tensor::cat(&[rope1, rope2], 0)?;
|
||||
let sum_diff = (both_rope - both_rope2)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(sum_diff, 0.);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-onnx"
|
||||
version = "0.9.0"
|
||||
version = "0.9.1"
|
||||
edition = "2021"
|
||||
|
||||
description = "ONNX support for Candle"
|
||||
@ -10,8 +10,8 @@ categories = ["science"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../candle-core", package = "candle-core", version = "0.9.0" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.9.0" }
|
||||
candle = { path = "../candle-core", package = "candle-core", version = "0.9.1" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.9.1" }
|
||||
prost = "0.12.1"
|
||||
|
||||
[build-dependencies]
|
||||
|
Reference in New Issue
Block a user