Compare commits

..

47 Commits

Author SHA1 Message Date
1f23cea90c MFA 2023-12-13 16:09:20 +01:00
ce33d6ad2a Tmp. 2023-12-11 11:10:48 +01:00
3d0ade406a Tmp. 2023-12-11 09:38:25 +01:00
2ca086939f Put back affine strided tests 2023-11-30 11:40:39 +01:00
4349ff1fc2 Starting to fix some tests.
Few fixes.

Going back on remote metal-rs.

Reusing a single buffer (for now) to speed things up.

Adding some half kernels.

All tests are panicking instead of random failure.

Putting back f16 index select.

Add erf.

Working version for llama2-c.

Fixes + cache compute_pipeline_state.

BF16 metal fix.

Remove some prints.

new_owned -> new()..to_owned().

Better batched matmul.

Metal operational.

Reuse buffers on our own reference counts.

Tmp gemm.

Revert "Tmp gemm."

This reverts commit c65f68e988.

Interleave committing.

Speeding up copies using blit.

Fmt.

Fmt.

Remove the assert!

Fmt all.

Fixes after big rebase.

Add softmax for half and bfloat + tests

Fixing Llama example + accumulate softmax in float.
2023-11-30 11:30:31 +01:00
7c3cfd1086 Use the llama weight names for the Yi example. (#1381) 2023-11-27 20:42:52 +00:00
e2eb6590ed Merge pull request #1323 from huggingface/metal3
Adding the test scaffolding.
2023-11-27 13:06:01 +01:00
481c45d78d Add a basic implementation for slice-assign. (#1377) 2023-11-26 17:31:22 +00:00
14a2bdc062 Small tweak: remove the macro usage for the range indexing trait. (#1376) 2023-11-26 16:30:59 +00:00
bfa7c8fc01 Implement the module trait directly for QMatMul. (#1372) 2023-11-25 10:09:45 +00:00
762e996ce6 Distibert (#1366)
* add bce with logit loss

* add bce with logit loss

* remove imports

* fix tiny bug

* add test documentation and refactor function

* fix test cases and formatting

* distilbet files

* Apply various cleanups.

* More cleanups.

* More polish.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2023-11-24 15:09:14 +00:00
ca19a9af62 Fix linspace implementation (#1358)
* Fix linspace implementation

`steps` should be strictly greater than 1 to make it consistent with the context.

* Handle steps == 0 and steps == 1.

* Fix rustfmt.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2023-11-23 07:35:13 +00:00
ec23427d60 Ensure to copy data to cpu before iterating. (#1360) 2023-11-23 07:24:25 +00:00
f83e14f68d Add candle-lora transformers to readme? (#1356)
* Demonstrate lora transformers in readme

* Shorten readme
2023-11-21 17:54:24 +00:00
c7e613ab5e Update the readme. (#1354) 2023-11-21 09:38:27 +00:00
8f63f68289 Fix the kalosm link (#1353) 2023-11-21 06:18:14 +01:00
1edc3ddf24 Allowing feature metal to compile. 2023-11-20 20:17:16 +01:00
b380657bfe Merge pull request #1309 from huggingface/metal2
Adding the actual backend
2023-11-20 17:24:01 +01:00
60f624a902 Moving tests around. 2023-11-20 16:17:19 +01:00
8d6c6de8e0 Missing new test. 2023-11-20 14:38:35 +01:00
7ec345c2eb Adding the test scaffolding. 2023-11-20 14:38:35 +01:00
671fc29b36 Fmt. 2023-11-20 14:38:20 +01:00
dc64adb8e4 Fixing cos_f16 test. 2023-11-20 14:17:07 +01:00
c66e5d4716 Fix comments. 2023-11-20 14:13:44 +01:00
bd3b243725 Update candle-metal-kernels/Cargo.toml 2023-11-20 14:12:57 +01:00
2813fb5dbc Cleanup fixed a few ops removed debugging scaffolding. 2023-11-20 14:12:57 +01:00
7cfffcac10 Debugging rope. 2023-11-20 14:12:57 +01:00
38de52bc4b Fixed matmul (display still broken without casting back to CPU first? ) 2023-11-20 14:12:57 +01:00
d46670f7c0 Tmp state. 2023-11-20 14:12:57 +01:00
f710fab02e Fixing the kernels + launches to make them faster.
Cool work by @ivarflakstad

Co-authored-by: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
2023-11-20 14:12:57 +01:00
f82bf2d915 Adding indexing.
Co-authored-by: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
2023-11-20 14:12:57 +01:00
df6814f34e Refactor to simplify our lives for settings the params in the encoder. 2023-11-20 14:12:57 +01:00
39406a6721 Adding the actual backend 2023-11-20 14:12:56 +01:00
976ad9f9c2 Remove tracing. 2023-11-20 14:12:29 +01:00
a4c4a56429 Metal part 1 - Scaffolding for metal. 2023-11-20 14:12:05 +01:00
f49bf6a81d Fix OpenChat 3.5 tokenizer (#1347) 2023-11-19 18:48:04 +00:00
992a788da1 Add OpenChat 3.5 to quantized examples (#1346)
* Add OpenChat to quantized examples

* Add chat prompt

* Make the openchat example more in line with the other models.

* Fix a typo.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2023-11-19 18:28:52 +00:00
8d8f48c60c feat: add test for individual onnx ops (#1332)
* feat: add test for individual onnx ops

* fix: prefer consts when possible

* feat: add move op tests
2023-11-19 08:17:09 +01:00
d31f11035f Support for CumSum in ONNX models. (#1340) 2023-11-17 22:03:40 +00:00
9ab3f9729f Use the whisper-v3 tokenizer now that it has been added. (#1337)
* Use the whisper-v3 tokenizer now that it has been added.

* Use the appropriate nospeech token.
2023-11-16 22:10:31 +00:00
a1f41ab37b feat: adds reset_kv_cache (#1335) 2023-11-16 21:17:42 +00:00
92a05b51cf fix: address clippy 0.1.74 issues (#1336)
- clippy::needless-borrows-for-generic-args
- clippy::reserve-after-initialization
2023-11-16 21:15:22 +00:00
c6763e3b41 Add a simple implementation of cumsum. (#1334)
* Add a simple implementation of cumsum.

* Add another test.
2023-11-15 21:11:15 +00:00
347e31c9ff Add the tril/triu/eye ops. (#1333)
* Add tril/triu/eye.

* Revert the metal crate tweak.
2023-11-15 20:34:37 +00:00
f4fcf60900 Update readme.md (#1322)
Updating the readme to coincide with other examples. If you try to run it as previously written, you will get a "cannot find the path specified" error.
2023-11-12 09:46:19 +00:00
12561b31d3 Fix pose estimation image path (#1326) 2023-11-12 09:45:26 +00:00
a209ce8ceb Update for 0.3.1. (#1324) 2023-11-11 18:48:52 +00:00
63 changed files with 3483 additions and 1334 deletions

View File

@ -19,7 +19,7 @@ exclude = [
resolver = "2"
[workspace.package]
version = "0.3.0"
version = "0.3.1"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@ -61,10 +61,7 @@ tracing-subscriber = "0.3.7"
wav = "1.0.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "0.6.6", default-features = false }
#metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
metal = { path = "../metal-rs", features = ["mps"] }
dispatch = "0.2.0"
rustc-hash = "1.1"
metal = { version = "0.27.1", features = ["mps"], package="candle-metal" }
[profile.release-with-debug]
inherits = "release"

View File

@ -139,16 +139,16 @@ And then head over to
<!--- ANCHOR: useful_libraries --->
## Useful External Resources
- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): a
- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): A
very detailed tutorial showing how to convert a PyTorch model to Candle.
- [`optimisers`](https://github.com/KGrewal1/optimisers): a collection of optimisers
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): Efficient and ergonomic LoRA implemenation for Candle. `candle-lora` has
out-of-the-box LoRA support for many models from Candle, which can be found [here](https://github.com/EricLBuehler/candle-lora/tree/master/candle-lora-transformers/examples).
- [`optimisers`](https://github.com/KGrewal1/optimisers): A collection of optimisers
including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): a LoRA implementation
that conforms to the official `peft` implementation.
- [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
serving local LLMs including an OpenAI compatible API server.
- [`candle-ext`](https://github.com/mokeyish/candle-ext): an extension library to Candle that provides PyTorch functions not currently available in Candle.
- [`kalosm`](https://github.com/floneum/floneum/tree/master/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
- [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle.
- [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
If you have an addition to this list, please submit a pull request.
@ -177,6 +177,11 @@ If you have an addition to this list, please submit a pull request.
- Replit-code-v1.5-3B.
- Bert.
- Yi-6B and Yi-34B.
- Quantized LLMs.
- Llama 7b, 13b, 70b, as well as the chat and code variants.
- Mistral 7b, and 7b instruct.
- Zephyr 7b a and b (Mistral based).
- OpenChat 3.5 (Mistral based).
- Text to text.
- T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
- Marian MT (Machine Translation).

View File

@ -11,11 +11,11 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.3.1" }
candle-nn = { path = "../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../candle-transformers", version = "0.3.1" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.1", optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }

View File

@ -12,8 +12,8 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
candle-kernels = { path = "../candle-kernels", version = "0.3.1", optional = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.1", optional = true }
metal = { workspace = true, optional = true}
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
@ -30,8 +30,6 @@ safetensors = { workspace = true }
thiserror = { workspace = true }
yoke = { workspace = true }
zip = { workspace = true }
dispatch = { workspace = true, optional = true }
rustc-hash = { workspace = true }
[dev-dependencies]
anyhow = { workspace = true }
@ -43,4 +41,4 @@ cuda = ["cudarc", "dep:candle-kernels"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal", "dep:candle-metal-kernels", "dep:dispatch"]
metal = ["dep:metal", "dep:candle-metal-kernels"]

View File

@ -8,11 +8,10 @@ use anyhow::Result;
use candle_core::{Device, Tensor};
fn main() -> Result<()> {
let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
let start = std::time::Instant::now();
let res = inp.conv2d(&w, 0, 1, 1, 1)?;
println!("{:?}", start.elapsed());
println!("{res:?}");
let a = Tensor::new(&[[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]], &Device::Cpu)?;
let b = Tensor::new(&[[88.0f32, 99.0]], &Device::Cpu)?;
let new_a = a.slice_scatter(&b, 1, 2)?;
assert_eq!(a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
assert_eq!(new_a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
Ok(())
}

View File

@ -104,37 +104,31 @@ impl From<&Tensor> for TensorIndexer {
}
}
macro_rules! impl_from_range {
($range_type:ty) => {
impl From<$range_type> for TensorIndexer {
fn from(range: $range_type) -> Self {
use std::ops::Bound::*;
trait RB: RangeBounds<usize> {}
impl RB for Range<usize> {}
impl RB for RangeFrom<usize> {}
impl RB for RangeFull {}
impl RB for RangeInclusive<usize> {}
impl RB for RangeTo<usize> {}
impl RB for RangeToInclusive<usize> {}
let start = match range.start_bound() {
Included(idx) => Included(*idx),
Excluded(idx) => Excluded(*idx),
Unbounded => Unbounded,
};
let end = match range.end_bound() {
Included(idx) => Included(*idx),
Excluded(idx) => Excluded(*idx),
Unbounded => Unbounded,
};
TensorIndexer::Narrow(start, end)
}
}
};
impl<T: RB> From<T> for TensorIndexer {
fn from(range: T) -> Self {
use std::ops::Bound::*;
let start = match range.start_bound() {
Included(idx) => Included(*idx),
Excluded(idx) => Excluded(*idx),
Unbounded => Unbounded,
};
let end = match range.end_bound() {
Included(idx) => Included(*idx),
Excluded(idx) => Excluded(*idx),
Unbounded => Unbounded,
};
TensorIndexer::Narrow(start, end)
}
}
impl_from_range!(Range<usize>);
impl_from_range!(RangeFrom<usize>);
impl_from_range!(RangeFull);
impl_from_range!(RangeInclusive<usize>);
impl_from_range!(RangeTo<usize>);
impl_from_range!(RangeToInclusive<usize>);
/// Trait used to implement multiple signatures for ease of use of the slicing
/// of a tensor
pub trait IndexOp<T> {

View File

@ -123,12 +123,6 @@ pub trait Module {
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
}
impl Module for quantized::QMatMul {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
self.forward(xs)
}
}
impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
self(xs)

File diff suppressed because it is too large Load Diff

View File

@ -307,8 +307,8 @@ impl crate::CustomOp1 for QTensor {
}
}
impl QMatMul {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
impl crate::Module for QMatMul {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
Self::Tensor(w) => {

View File

@ -2457,6 +2457,110 @@ impl Tensor {
Ok(naxis as usize)
}
}
/// Returns a lower triangular matrix of ones of size n by n.
pub fn tril2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
let t = Tensor::arange(0u32, n as u32, device)?;
let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
t1.le(&t2)?.to_dtype(dtype)
}
/// Returns an upper triangular matrix of ones of size n by n.
pub fn triu2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
let t = Tensor::arange(0u32, n as u32, device)?;
let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
t1.ge(&t2)?.to_dtype(dtype)
}
/// Returns a matrix with a diagonal of ones of size n by n.
pub fn eye(n: usize, dtype: DType, device: &Device) -> Result<Self> {
let t = Tensor::arange(0u32, n as u32, device)?;
let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
t1.eq(&t2)?.to_dtype(dtype)
}
/// Returns the cumulative sum of elements of the input tensor summed over the specified
/// dimension.
///
/// This operation is most efficient when dim is the last dimension of the tensor.
pub fn cumsum<D: Dim>(&self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "cumsum")?;
let rank = self.rank();
if rank == 0 {
return Ok(self.clone());
}
let n_axis = self.dim(dim)?;
let triu = Tensor::triu2(n_axis, self.dtype(), self.device())?;
if rank == 1 {
self.unsqueeze(0)?.matmul(&triu)?.squeeze(0)
} else {
let last = rank - 1;
let t = self.transpose(dim, last)?;
let t = t.broadcast_matmul(&triu)?;
t.transpose(dim, last)
}
}
/// Returns a copy of `self` where the values within `ranges` have been replaced with the
/// content of `src`.
pub fn slice_assign<D: std::ops::RangeBounds<usize>>(
&self,
ranges: &[D],
src: &Tensor,
) -> Result<Self> {
let src_dims = src.dims();
let self_dims = self.dims();
if self_dims.len() != src_dims.len() {
crate::bail!(
"slice-assign requires input with the same rank {} <> {}",
self_dims.len(),
src_dims.len()
)
}
if self_dims.len() != ranges.len() {
crate::bail!(
"slice-assign requires input with the same rank as there are ranges {} <> {}",
self_dims.len(),
ranges.len()
)
}
let mut src = src.clone();
let mut mask = Self::ones(src.shape(), DType::U8, src.device())?;
for (i, range) in ranges.iter().enumerate() {
let start_included = match range.start_bound() {
std::ops::Bound::Unbounded => 0,
std::ops::Bound::Included(v) => *v,
std::ops::Bound::Excluded(v) => *v + 1,
};
let end_excluded = match range.end_bound() {
std::ops::Bound::Unbounded => self_dims[i],
std::ops::Bound::Included(v) => *v + 1,
std::ops::Bound::Excluded(v) => *v,
};
if end_excluded <= start_included {
crate::bail!(
"slice-assign: empty range for dim {i}, {start_included} {end_excluded}"
)
}
if self_dims[i] < end_excluded {
crate::bail!(
"slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}",
self_dims[i]
)
}
if end_excluded - start_included != src_dims[i] {
crate::bail!(
"slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i]
)
}
src = src.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?;
mask = mask.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?
}
mask.where_cond(/* on_true= */ &src, /* on_false= */ self)
}
}
macro_rules! bin_trait {

View File

@ -91,3 +91,32 @@ fn index_3d() -> Result<()> {
assert_eq!(tensor.i((1, .., 3))?.to_vec1::<u32>()?, &[15, 19, 23]);
Ok(())
}
#[test]
fn slice_assign() -> Result<()> {
let dev = Device::Cpu;
let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?;
let src = Tensor::arange(0u32, 2 * 3, &dev)?.reshape((3, 2))?;
let out = tensor.slice_assign(&[1..4, 3..5], &src)?;
assert_eq!(
out.to_vec2::<u32>()?,
&[
[0, 1, 2, 3, 4],
[5, 6, 7, 0, 1],
[10, 11, 12, 2, 3],
[15, 16, 17, 4, 5]
]
);
let out = tensor.slice_assign(&[0..3, 0..2], &src)?;
assert_eq!(
out.to_vec2::<u32>()?,
&[
[0, 1, 2, 3, 4],
[2, 3, 7, 8, 9],
[4, 5, 12, 13, 14],
[15, 16, 17, 18, 19]
]
);
Ok(())
}

View File

@ -1,7 +1,7 @@
use candle_core::{
quantized::{self, GgmlDType},
test_utils::to_vec2_round,
Device, Result, Tensor,
Device, Module, Result, Tensor,
};
use quantized::{k_quants, GgmlType};
use rand::prelude::*;

View File

@ -1159,3 +1159,65 @@ fn i64_abs() -> Result<()> {
assert_eq!(t.to_vec1::<i64>()?, [42, 1337]);
Ok(())
}
#[test]
fn tril_triu_eye() -> Result<()> {
let t = Tensor::tril2(4, DType::F32, &Device::Cpu)?;
assert_eq!(
t.to_vec2::<f32>()?,
[
[1.0, 0.0, 0.0, 0.0],
[1.0, 1.0, 0.0, 0.0],
[1.0, 1.0, 1.0, 0.0],
[1.0, 1.0, 1.0, 1.0]
],
);
let t = Tensor::triu2(4, DType::F32, &Device::Cpu)?;
assert_eq!(
t.to_vec2::<f32>()?,
[
[1.0, 1.0, 1.0, 1.0],
[0.0, 1.0, 1.0, 1.0],
[0.0, 0.0, 1.0, 1.0],
[0.0, 0.0, 0.0, 1.0]
]
);
let t = Tensor::eye(4, DType::F32, &Device::Cpu)?;
assert_eq!(
t.to_vec2::<f32>()?,
[
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0]
]
);
Ok(())
}
#[test]
fn cumsum() -> Result<()> {
let t = &[3f32, 1., 4., 1., 5.];
let t = Tensor::new(t, &Device::Cpu)?;
assert_eq!(t.cumsum(0)?.to_vec1::<f32>()?, [3., 4., 8., 9., 14.]);
let t = t.unsqueeze(1)?;
assert_eq!(
t.cumsum(0)?.to_vec2::<f32>()?,
[[3.0], [4.0], [8.0], [9.0], [14.0]]
);
assert_eq!(
t.cumsum(1)?.to_vec2::<f32>()?,
[[3.0], [1.0], [4.0], [1.0], [5.0]]
);
let t = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let t = Tensor::new(t, &Device::Cpu)?;
assert_eq!(
t.cumsum(1)?.to_vec2::<f32>()?,
[[3.0, 4.0, 8.0, 9.0, 14.0], [2.0, 3.0, 10.0, 18.0, 20.0]],
);
assert_eq!(
t.cumsum(0)?.to_vec2::<f32>()?,
[[3.0, 1.0, 4.0, 1.0, 5.0], [5.0, 2.0, 11.0, 9.0, 7.0]]
);
Ok(())
}

View File

@ -11,8 +11,8 @@ readme = "README.md"
[dependencies]
byteorder = { workspace = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.1" }
hf-hub = { workspace = true}
intel-mkl-src = { workspace = true, optional = true }
memmap2 = { workspace = true }

View File

@ -11,12 +11,12 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
candle-onnx = { path = "../candle-onnx", version = "0.3.0", optional = true }
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.3.1" }
candle-nn = { path = "../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../candle-transformers", version = "0.3.1" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.1", optional = true }
candle-onnx = { path = "../candle-onnx", version = "0.3.1", optional = true }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
image = { workspace = true }
@ -57,6 +57,7 @@ flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
nccl = ["cuda", "cudarc/nccl", "dep:half"]
onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"]
[[example]]
name = "llama_multiprocess"

View File

@ -0,0 +1,22 @@
# candle-distilbert
DistilBert is a distiled version of the Bert model.
## Sentence embeddings
DistilBert is used to compute the sentence embeddings for a prompt. The model weights
are downloaded from the hub on the first run.
```bash
cargo run --example distilbert --release -- --prompt "Here is a test sentence"
> [[[ 0.5109, 0.1280, -0.2635, ..., 0.3462, -1.0434, 0.1441],
> [ 0.1735, 0.0818, -0.5549, ..., 0.3472, -0.8264, -0.0244],
> [ 0.0702, -0.1311, -0.4914, ..., 0.3483, -0.6194, 0.1829],
> ...
> [ 0.2993, -0.0106, -0.4640, ..., 0.2844, -0.6732, 0.0042],
> [ 0.1066, -0.0081, -0.4299, ..., 0.3435, -0.7729, 0.0190],
> [ 0.8903, 0.2055, -0.2541, ..., 0.3208, -0.6585, 0.0586]]]
> Tensor[[1, 7, 768], f32]
```

View File

@ -0,0 +1,135 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle_transformers::models::distilbert::{Config, DistilBertModel, DTYPE};
use anyhow::{Error as E, Result};
use candle::{Device, Tensor};
use candle_nn::VarBuilder;
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
/// When set, compute embeddings for this prompt.
#[arg(long)]
prompt: String,
/// Use the pytorch weights rather than the safetensors ones
#[arg(long)]
use_pth: bool,
/// The number of times to run the prompt.
#[arg(long, default_value = "1")]
n: usize,
/// L2 normalization for embeddings.
#[arg(long, default_value = "true")]
normalize_embeddings: bool,
}
impl Args {
fn build_model_and_tokenizer(&self) -> Result<(DistilBertModel, Tokenizer)> {
let device = candle_examples::device(self.cpu)?;
let default_model = "distilbert-base-uncased".to_string();
let default_revision = "main".to_string();
let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
(Some(model_id), Some(revision)) => (model_id, revision),
(Some(model_id), None) => (model_id, "main".to_string()),
(None, Some(revision)) => (default_model, revision),
(None, None) => (default_model, default_revision),
};
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
let (config_filename, tokenizer_filename, weights_filename) = {
let api = Api::new()?;
let api = api.repo(repo);
let config = api.get("config.json")?;
let tokenizer = api.get("tokenizer.json")?;
let weights = if self.use_pth {
api.get("pytorch_model.bin")?
} else {
api.get("model.safetensors")?
};
(config, tokenizer, weights)
};
let config = std::fs::read_to_string(config_filename)?;
let config: Config = serde_json::from_str(&config)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let vb = if self.use_pth {
VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
} else {
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
};
let model = DistilBertModel::load(vb, &config)?;
Ok((model, tokenizer))
}
}
fn get_mask(size: usize, device: &Device) -> Tensor {
let mask: Vec<_> = (0..size)
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
.collect();
Tensor::from_slice(&mask, (size, size), device).unwrap()
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
println!("tracing...");
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
let device = &model.device;
let tokenizer = tokenizer
.with_padding(None)
.with_truncation(None)
.map_err(E::msg)?;
let tokens = tokenizer
.encode(args.prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
let mask = get_mask(tokens.len(), device);
println!("token_ids: {:?}", token_ids.to_vec2::<u32>());
println!("mask: {:?}", mask.to_vec2::<u8>());
let ys = model.forward(&token_ids, &mask)?;
println!("{ys}");
Ok(())
}
pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
}

View File

@ -53,6 +53,8 @@ enum Which {
Zephyr7bAlpha,
#[value(name = "7b-zephyr-b")]
Zephyr7bBeta,
#[value(name = "7b-open-chat-3.5")]
OpenChat35,
}
impl Which {
@ -67,8 +69,10 @@ impl Which {
| Self::L7bCode
| Self::L13bCode
| Self::L34bCode => false,
// Zephyr is a fine tuned version of mistral and should be treated in the same way.
Self::Zephyr7bAlpha
// Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
// same way.
Self::OpenChat35
| Self::Zephyr7bAlpha
| Self::Zephyr7bBeta
| Self::Mistral7b
| Self::Mistral7bInstruct => true,
@ -87,10 +91,30 @@ impl Which {
| Self::L13bCode
| Self::L34bCode
| Self::Mistral7b
| Self::Mistral7bInstruct => false,
| Self::Mistral7bInstruct
| Self::OpenChat35 => false,
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
}
}
fn is_open_chat(&self) -> bool {
match self {
Which::L7b
| Which::L13b
| Which::L70b
| Which::L7bChat
| Which::L13bChat
| Which::L70bChat
| Which::L7bCode
| Which::L13bCode
| Which::L34bCode
| Which::Mistral7b
| Which::Mistral7bInstruct
| Which::Zephyr7bAlpha
| Which::Zephyr7bBeta => false,
Which::OpenChat35 => true,
}
}
}
#[derive(Parser, Debug)]
@ -157,7 +181,9 @@ impl Args {
Some(config) => std::path::PathBuf::from(config),
None => {
let api = hf_hub::api::sync::Api::new()?;
let repo = if self.which.is_mistral() {
let repo = if self.which.is_open_chat() {
"openchat/openchat_3.5"
} else if self.which.is_mistral() {
"mistralai/Mistral-7B-v0.1"
} else {
"hf-internal-testing/llama-tokenizer"
@ -207,6 +233,7 @@ impl Args {
Which::Zephyr7bBeta => {
("TheBloke/zephyr-7B-beta-GGUF", "zephyr-7b-beta.Q4_K_M.gguf")
}
Which::OpenChat35 => ("TheBloke/openchat_3.5-GGUF", "openchat_3.5.Q4_K_M.gguf"),
};
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(repo.to_string());
@ -308,7 +335,8 @@ fn main() -> anyhow::Result<()> {
| Which::Zephyr7bAlpha
| Which::Zephyr7bBeta
| Which::L70b
| Which::L70bChat => 8,
| Which::L70bChat
| Which::OpenChat35 => 8,
};
ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
}
@ -340,7 +368,9 @@ fn main() -> anyhow::Result<()> {
prompt.pop();
}
}
if args.which.is_zephyr() {
if args.which.is_open_chat() {
format!("User: {prompt}<|end_of_turn|>Assistant: ")
} else if args.which.is_zephyr() {
if prompt_index == 0 || is_interactive {
format!("<|system|>\n</s>\n<|user|>\n{prompt}</s>\n<|assistant|>",)
} else {
@ -390,8 +420,12 @@ fn main() -> anyhow::Result<()> {
std::io::stdout().flush()?;
}
let eos_token = *tos.tokenizer().get_vocab(true).get("</s>").unwrap();
let eos_token = if args.which.is_open_chat() {
"<|end_of_turn|>"
} else {
"</s>"
};
let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap();
let start_post_prompt = std::time::Instant::now();
let mut sampled = 0;
for index in 0..to_sample {

View File

@ -416,7 +416,7 @@ fn run(args: Args) -> Result<()> {
println!("Building the autoencoder.");
let vae_weights = ModelFile::Vae.get(vae_weights, sd_version, use_f16)?;
let vae = sd_config.build_vae(&vae_weights, &device, dtype)?;
let vae = sd_config.build_vae(vae_weights, &device, dtype)?;
let init_latent_dist = match &img2img {
None => None,
Some(image) => {
@ -426,7 +426,7 @@ fn run(args: Args) -> Result<()> {
};
println!("Building the unet.");
let unet_weights = ModelFile::Unet.get(unet_weights, sd_version, use_f16)?;
let unet = sd_config.build_unet(&unet_weights, &device, 4, use_flash_attn, dtype)?;
let unet = sd_config.build_unet(unet_weights, &device, 4, use_flash_attn, dtype)?;
let t_start = if img2img.is_some() {
n_steps - (n_steps as f64 * img2img_strength) as usize

View File

@ -8,7 +8,7 @@ the model itself.
## Running an example
```bash
cargo run --example trocr --release -- --which base --cpu --image assets/trocr.png
cargo run --example trocr --release -- --which base --cpu --image candle-examples/examples/trocr/assets/trocr.png
```
```

View File

@ -128,7 +128,13 @@ impl Decoder {
let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?;
let no_speech_token = m::NO_SPEECH_TOKENS
.iter()
.find_map(|token| token_id(&tokenizer, token).ok());
let no_speech_token = match no_speech_token {
None => anyhow::bail!("unable to find any non-speech token"),
Some(n) => n,
};
Ok(Self {
model,
rng: rand::rngs::StdRng::seed_from_u64(seed),
@ -512,11 +518,7 @@ fn main() -> Result<()> {
)
} else {
let config = repo.get("config.json")?;
let tokenizer = if args.model == WhichModel::LargeV3 {
panic!("openai/whisper-large-v3 does not provide a compatible tokenizer.json config at the moment")
} else {
repo.get("tokenizer.json")?
};
let tokenizer = repo.get("tokenizer.json")?;
let model = repo.get("model.safetensors")?;
(config, tokenizer, model)
};

View File

@ -74,9 +74,9 @@ impl TextGeneration {
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("</s>") {
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
Some(token) => token,
None => anyhow::bail!("cannot find the </s> token"),
None => anyhow::bail!("cannot find the <|endoftext|> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {

View File

@ -43,6 +43,7 @@ pub fn report(
confidence_threshold: f32,
nms_threshold: f32,
) -> Result<DynamicImage> {
let pred = pred.to_device(&Device::Cpu)?;
let (npreds, pred_size) = pred.dims2()?;
let nclasses = pred_size - 5;
// The bounding boxes grouped by (maximum) class index.

View File

@ -32,7 +32,7 @@ Image source:
### Pose Estimation
```bash
cargo run --example yolo-v8 --release -- \
candle-examples/examples/yolo-v8/assets/peoples.jpeg --task pose
candle-examples/examples/yolo-v8/assets/bike.jpg --task pose
```
![Leading group, Giro d'Italia 2021](./assets/bike.pose.jpg)

View File

@ -7,7 +7,7 @@ extern crate accelerate_src;
mod model;
use model::{Multiples, YoloV8, YoloV8Pose};
use candle::{DType, IndexOp, Result, Tensor};
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::{Module, VarBuilder};
use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
use clap::{Parser, ValueEnum};
@ -61,6 +61,7 @@ pub fn report_detect(
nms_threshold: f32,
legend_size: u32,
) -> Result<DynamicImage> {
let pred = pred.to_device(&Device::Cpu)?;
let (pred_size, npreds) = pred.dims2()?;
let nclasses = pred_size - 4;
// The bounding boxes grouped by (maximum) class index.
@ -153,6 +154,7 @@ pub fn report_pose(
confidence_threshold: f32,
nms_threshold: f32,
) -> Result<DynamicImage> {
let pred = pred.to_device(&Device::Cpu)?;
let (pred_size, npreds) = pred.dims2()?;
if pred_size != 17 * 3 + 4 + 1 {
candle::bail!("unexpected pred-size {pred_size}");

View File

@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.3.0"
version = "0.3.1"
edition = "2021"
description = "Flash attention layer for the candle ML framework."
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"
[dependencies]
candle = { path = "../candle-core", features = ["cuda"], version = "0.3.0", package = "candle-core" }
candle = { path = "../candle-core", features = ["cuda"], version = "0.3.1", package = "candle-core" }
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]
@ -21,4 +21,4 @@ rayon = "1.7.0"
[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
candle-nn = { path = "../candle-nn", version = "0.3.0", features = ["cuda"] }
candle-nn = { path = "../candle-nn", version = "0.3.1", features = ["cuda"] }

View File

@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.3.0"
version = "0.3.1"
edition = "2021"
description = "CUDA kernels for Candle"
@ -14,4 +14,4 @@ license = "MIT OR Apache-2.0"
[build-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
glob = "0.3.1"
rayon = "1.7.0"
rayon = "1.7.0"

View File

@ -1,17 +1,16 @@
[package]
name = "candle-metal-kernels"
version = "0.3.0"
version = "0.3.1"
edition = "2021"
description = "CUDA kernels for Candle"
description = "Metal kernels for Candle"
repository = "https://github.com/huggingface/candle"
keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
metal = { path = "../../metal-rs", features = ["mps"] }
metal = { version = "0.27.1", features = ["mps"], package="candle-metal" }
once_cell = "1.18.0"
thiserror = "1"
tracing = "0.1.37"

View File

@ -23,12 +23,12 @@ kernel void FN_NAME( \
constant size_t &dim, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
uint tid [[ thread_position_in_grid ]] \
) { \
if (thread_position_in_grid >= dim) { \
if (tid >= dim) { \
return; \
} \
output[thread_position_in_grid] = RIGHT_TYPENAME(input[thread_position_in_grid]); \
output[tid] = RIGHT_TYPENAME(input[tid]); \
} \
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
@ -37,15 +37,17 @@ kernel void FN_NAME_STRIDED( \
constant size_t *strides, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
uint i [[ thread_position_in_grid ]] \
uint tid [[ thread_position_in_grid ]] \
) { \
if (i >= dim) { \
if (tid >= dim) { \
return; \
} \
output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \
output[tid] = RIGHT_TYPENAME(input[get_strided_index(tid, num_dims, dims, strides)]); \
} \
CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float)
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,8 @@
#include <metal_stdlib>
using namespace metal;
#define MAX(x, y) ((x) > (y) ? (x) : (y))
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
@ -16,18 +18,18 @@ METAL_FUNC uint get_strided_index(
return strided_i;
}
constant int THREADGROUP_SIZE = 256;
constant int THREADGROUP_SIZE = 1024;
# define REDUCE(FN, NAME, TYPENAME) \
# define REDUCE(FN, NAME, T) \
kernel void NAME( \
constant size_t &src_numel, \
constant size_t &el_to_sum_per_block, \
device const TYPENAME *src, \
device TYPENAME *dst, \
device const T *src, \
device T *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint blockDim [[ threads_per_threadgroup ]] \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
\
threadgroup float shared_memory[THREADGROUP_SIZE]; \
@ -45,10 +47,10 @@ kernel void NAME( \
// TODO: Fast version for the contiguous case. \
// size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
*/ \
TYPENAME x = shared_memory[tid]; \
TYPENAME y = src[idx]; \
T x = shared_memory[tid]; \
T y = src[idx]; \
shared_memory[tid] = FN; \
idx += blockDim; \
idx += block_dim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
@ -56,10 +58,10 @@ kernel void NAME( \
/* \
// reduction in shared memory \
*/ \
for (uint s = blockDim / 2; s > 0; s >>= 1) { \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s) { \
TYPENAME x = shared_memory[tid]; \
TYPENAME y = shared_memory[tid + s]; \
T x = shared_memory[tid]; \
T y = shared_memory[tid + s]; \
shared_memory[tid] = FN; \
} \
threadgroup_barrier(mem_flags::mem_none); \
@ -68,72 +70,74 @@ kernel void NAME( \
dst[dst_id] = shared_memory[0]; \
} \
kernel void softmax_float(
constant size_t &src_numel,
constant size_t &el_to_sum_per_block,
device const float *src,
device float *dst,
uint id [[ thread_position_in_grid ]],
uint tid [[ thread_index_in_threadgroup ]],
uint dst_id [[ threadgroup_position_in_grid ]],
uint blockDim [[ threads_per_threadgroup ]]
) {
threadgroup float shared_memory[THREADGROUP_SIZE];
shared_memory[tid] = -INFINITY;
// Elements summed in this block range from dst_id * el_to_sum_per_block
// to (dst_id + 1) * el_to_sum_per_block.
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
size_t idx = start_idx + tid;
while (idx < stop_idx) {
// TODO: Fast version for the contiguous case.
shared_memory[tid] = max(shared_memory[tid], src[idx]);
idx += blockDim;
}
threadgroup_barrier(mem_flags::mem_none);
// reduction in shared memory
for (uint s = blockDim / 2; s > 0; s >>= 1) {
if (tid < s) {
shared_memory[tid] = max(shared_memory[tid], shared_memory[tid + s]);
}
threadgroup_barrier(mem_flags::mem_none);
}
float max = shared_memory[0];
shared_memory[tid] = 0;
// Restart
idx = start_idx + tid;
while (idx < stop_idx) {
// TODO: Fast version for the contiguous case.
const float val = exp(src[idx] - max);
dst[idx] = val;
shared_memory[tid] += val;
idx += blockDim;
}
// reduction in shared memory
for (uint s = blockDim / 2; s > 0; s >>= 1) {
if (tid < s) {
shared_memory[tid] += shared_memory[tid + s];
}
threadgroup_barrier(mem_flags::mem_none);
}
const float inv_acc = 1/shared_memory[0];
idx = start_idx + tid;
while (idx < stop_idx) {
dst[idx] *= inv_acc;
idx += blockDim;
}
}
REDUCE(x + y, fast_sum_float, float)
REDUCE(x * y, fast_mul_float, float)
REDUCE(max(x, y), fast_max_float, float)
#define SOFTMAX(NAME, T) \
kernel void NAME( \
constant size_t &src_numel, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device T *dst, \
\
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
threadgroup float shared_memory[THREADGROUP_SIZE]; \
shared_memory[tid] = -INFINITY; \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
size_t idx = start_idx + tid; \
\
threadgroup_barrier(mem_flags::mem_threadgroup); \
\
while (idx < stop_idx) { \
shared_memory[tid] = MAX(shared_memory[tid], src[idx]); \
idx += block_dim; \
} \
\
threadgroup_barrier(mem_flags::mem_threadgroup); \
\
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s) { \
shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
} \
} \
\
threadgroup_barrier(mem_flags::mem_threadgroup); \
\
float _max = shared_memory[0]; \
\
shared_memory[tid] = 0; \
\
idx = start_idx + tid; \
while (idx < stop_idx) { \
const T val = T(exp(src[idx] - _max)); \
dst[idx] = val; \
shared_memory[tid] += val; \
idx += block_dim; \
} \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s) { \
shared_memory[tid] += shared_memory[tid + s]; \
} \
threadgroup_barrier(mem_flags::mem_threadgroup); \
} \
\
const T inv_acc = T(1/shared_memory[0]); \
idx = start_idx + tid; \
while (idx < stop_idx) { \
dst[idx] *= inv_acc; \
idx += block_dim; \
} \
} \
SOFTMAX(softmax_float, float)
SOFTMAX(softmax_half, half)
#if __METAL_VERSION__ >= 310
SOFTMAX(softmax_bfloat, bfloat)
#endif

View File

@ -32,6 +32,9 @@ kernel void FN_NAME( \
device TYPENAME *out ,\
uint i [[ thread_position_in_grid ]] \
) { \
if (i >= numel){ \
return; \
} \
uint strided_i = get_strided_index(i, num_dims, dims, strides); \
uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \
uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \

View File

@ -0,0 +1,211 @@
import Metal
import MetalPerformanceShadersGraph
let type = MTLDataType.float;
let dataType = type;
var B = 2;
var M = 2;
var N = 4;
var K = 3;
var A_trans = false;
var B_trans = false;
var D_trans = false;
var alpha = Float(1.0);
var beta = Float(0.0);
var batched = B > 1;
var fused_activation = false;
var fused_bias = false;
let constants = MTLFunctionConstantValues()
constants.setConstantValue(&M, type: .uint, index: 0)
constants.setConstantValue(&N, type: .uint, index: 1)
constants.setConstantValue(&K, type: .uint, index: 2)
constants.setConstantValue(&A_trans, type: .bool, index: 10)
constants.setConstantValue(&B_trans, type: .bool, index: 11)
constants.setConstantValue(&D_trans, type: .bool, index: 13)
constants.setConstantValue(&alpha, type: .float, index: 20)
constants.setConstantValue(&beta, type: .float, index: 21)
constants.setConstantValue(&batched, type: .bool, index: 100)
constants.setConstantValue(&fused_activation, type: .bool, index: 101)
constants.setConstantValue(&fused_bias, type: .bool, index: 50001)
var M_simd = UInt16(16)
var N_simd = UInt16(16)
var K_simd = UInt16(32)
var M_splits = UInt16(2)
var N_splits = UInt16(2)
constants.setConstantValue(&M_simd, type: .ushort, index: 200)
constants.setConstantValue(&N_simd, type: .ushort, index: 201)
constants.setConstantValue(&K_simd, type: .ushort, index: 202)
constants.setConstantValue(&M_splits, type: .ushort, index: 210)
constants.setConstantValue(&N_splits, type: .ushort, index: 211)
let M_group = M_simd * M_splits
let N_group = N_simd * N_splits
// Satisfy Metal API validation.
#if DEBUG
do {
var garbage: SIMD4<UInt64> = .zero
constants.setConstantValue(&garbage, type: .bool, index: 102)
constants.setConstantValue(&garbage, type: .bool, index: 103)
constants.setConstantValue(&garbage, type: .bool, index: 113)
constants.setConstantValue(&garbage, type: .bool, index: 50000)
}
#endif
print(constants)
let device = MTLCopyAllDevices().first!
device.shouldMaximizeConcurrentCompilation = true
var libraryURL = URL.init(string: "/Users/nicolas/src/candle/candle-metal-kernels/")!;
libraryURL.append(component: "src")
libraryURL.append(component: "libMetalFlashAttention.metallib")
let library = try! device.makeLibrary(URL: libraryURL)
var name: String
switch dataType {
case .half: name = "hgemm"
case .float: name = "sgemm"
default: fatalError()
}
let function = try! library.makeFunction(
name: name, constantValues: constants)
let A_block_length = M_group * K_simd
let B_block_length = K_simd * N_group
var blockElements = A_block_length + B_block_length;
if (M % 8 != 0) && (N % 8 != 0) {
let C_block_length = M_group * N_group;
blockElements = max(C_block_length, blockElements)
}
if fused_bias {
if D_trans {
blockElements = max(blockElements, M_group)
} else {
blockElements = max(blockElements, N_group)
}
}
// let blockBytes = blockElements * UInt16(dataType.size)
let elementSize = 4
let blockBytes = blockElements * UInt16(elementSize)
func ceilDivide(target: Int, granularity: UInt16) -> Int {
(target + Int(granularity) - 1) / Int(granularity)
}
var gridSize = MTLSize(
width: ceilDivide(target: N, granularity: N_group),
height: ceilDivide(target: M, granularity: M_group),
depth: 1)
let groupSize = MTLSize(
width: Int(32 * M_splits * N_splits),
height: 1,
depth: 1)
let commandQueue = device.makeCommandQueue()!
let commandBuffer = commandQueue.makeCommandBuffer()!
let encoder = commandBuffer.makeComputeCommandEncoder(dispatchType: MTLDispatchType.serial)!
let pipeline = try device.makeComputePipelineState(function: function)
let threadgroupMemoryLength = blockBytes;
print(threadgroupMemoryLength)
encoder.setComputePipelineState(pipeline)
encoder.setThreadgroupMemoryLength(Int(threadgroupMemoryLength), index: 0)
let rowsA = M;
let columnsA = K;
let rowsB = K;
let columnsB = N;
let rowsC = M;
let columnsC = N;
var arrayA = [Float](repeating: 0, count: B * rowsA * columnsA)
var arrayB = [Float](repeating: 0, count: B * rowsB * columnsB)
var arrayC = [Float](repeating: 0, count: B * rowsC * columnsC)
for i in 0..<arrayA.count {
arrayA[i] = Float(i)
}
for i in 0..<arrayB.count {
arrayB[i] = Float(i)
}
let bufferA = device.makeBuffer(bytes: arrayA, length: B * rowsA * columnsA * MemoryLayout<Float>.stride, options: [])
let bufferB = device.makeBuffer(bytes: arrayB, length: B * rowsB * columnsB * MemoryLayout<Float>.stride, options: [])
let bufferC = device.makeBuffer(length: B * rowsC * columnsC * MemoryLayout<Float>.stride, options: [])
print(arrayA)
print(arrayB)
encoder.setBuffer(bufferA, offset: 0, index: 0)
encoder.setBuffer(bufferB, offset: 0, index: 1)
encoder.setBuffer(bufferC, offset: 0, index: 2)
var gridZ: Int = B
if batched{
func byteStride(shape: [Int]) -> Int {
let rank = shape.count
var output = elementSize * shape[rank - 2] * shape[rank - 1]
if shape.dropLast(2).reduce(1, *) == 1 {
output = 0
}
return output
}
let byteStrideA = M*K*elementSize
let byteStrideB = N*K*elementSize
let byteStrideC = M*N*elementSize
let byteStrideD = 0
// if let shapeD = tensors.d?.shape {
// let rank = shapeD.count
// byteStrideD = elementSize * shapeD[rank - 1]
// if shapeD.dropLast(1).reduce(1, *) == 1 {
// byteStrideD = 0
// }
// }
withUnsafeTemporaryAllocation(
of: SIMD4<UInt64>.self, capacity: gridZ
) { buffer in
for i in 0..<buffer.count {
buffer[i] = SIMD4(
UInt64(truncatingIfNeeded: i * byteStrideA),
UInt64(truncatingIfNeeded: i * byteStrideB),
UInt64(truncatingIfNeeded: i * byteStrideC),
UInt64(truncatingIfNeeded: i * byteStrideD))
}
let bufferLength = buffer.count * MemoryLayout<SIMD4<UInt64>>.stride
assert(MemoryLayout<SIMD4<UInt64>>.stride == 8 * 4)
encoder.setBytes(buffer.baseAddress!, length: bufferLength, index: 10)
print("BATCHED")
print(buffer)
}
}
gridSize.depth = gridZ
print(gridSize, groupSize)
encoder.dispatchThreadgroups(
gridSize, threadsPerThreadgroup: groupSize
)
encoder.endEncoding()
commandBuffer.commit()
commandBuffer.waitUntilCompleted()
var contents = bufferC!.contents();
var count = B * rowsA * columnsB;
var typedPointer = contents.bindMemory(to: Float.self, capacity: count)
var bufferedPointer = UnsafeBufferPointer(start: typedPointer, count: count)
print(Array(bufferedPointer))

View File

@ -0,0 +1,800 @@
use super::*;
use half::{bf16, f16};
use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger};
fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
let options = MTLResourceOptions::StorageModeManaged;
let ptr = data.as_ptr() as *const core::ffi::c_void;
let size = (data.len() * std::mem::size_of::<T>()) as u64;
device.new_buffer_with_data(ptr, size, options)
}
fn device() -> Device {
Device::system_default().unwrap()
}
fn approx(v: Vec<f32>, digits: i32) -> Vec<f32> {
let b = 10f32.powi(digits);
v.iter().map(|t| f32::round(t * b) / b).collect()
}
fn approx_f16(v: Vec<f16>, digits: i32) -> Vec<f32> {
let b = 10f32.powi(digits);
v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
}
fn approx_bf16(v: Vec<bf16>, digits: i32) -> Vec<f32> {
let b = 10f32.powi(digits);
v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
}
fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let output = new_buffer(&device, v);
call_unary_contiguous(
&device,
command_buffer,
&kernels,
name,
v.len(),
&input,
&output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<T>(v.len())
}
fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
let left = new_buffer(&device, x);
let right = new_buffer(&device, y);
let output = device.new_buffer(std::mem::size_of_val(x) as u64, options);
call_binary_contiguous(
&device,
command_buffer,
&kernels,
name,
x.len(),
&left,
&right,
&output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<T>(x.len())
}
fn run_strided<T: Clone>(
v: &[T],
kernel: unary::strided::Kernel,
shape: &[usize],
strides: &[usize],
offset: usize,
) -> Vec<T> {
let device = device();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let output = new_buffer(&device, v);
let kernels = Kernels::new();
call_unary_strided(
&device,
command_buffer,
&kernels,
kernel,
shape,
&input,
strides,
offset,
&output,
0,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<T>(v.len())
}
#[test]
fn cos_f32() {
let v = vec![1.0f32, 2.0, 3.0];
let results = run(&v, unary::contiguous::cos::FLOAT);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99]);
assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]);
let v = vec![1.0f32; 10_000];
let results = run(&v, unary::contiguous::cos::FLOAT);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
}
#[test]
fn cos_f32_strided() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let shape = vec![6];
let strides = vec![1];
let offset = 0;
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(
approx(results, 4),
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
);
assert_eq!(
approx(expected, 4),
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
);
// Contiguous
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let shape = vec![3, 2];
let strides = vec![2, 1];
let offset = 0;
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(
approx(results, 4),
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
);
assert_eq!(
approx(expected, 4),
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
);
// Transposed
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let shape = vec![3, 2];
let strides = vec![1, 3];
let offset = 0;
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(
approx(results, 4),
vec![0.5403, -0.6536, -0.4161, 0.2837, -0.99, 0.9602]
);
assert_eq!(
approx(expected, 4),
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
);
// Very large
let v = vec![1.0f32; 10_000];
let shape = vec![2, 5_000];
let strides = vec![2, 1];
let offset = 0;
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
}
#[test]
fn cos_strided_random() {
let v: Vec<_> = (0..10_000).map(|_| rand::random::<f32>()).collect();
let shape = vec![5_000, 2];
let strides = vec![1, 5_000];
let offset = 0;
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(approx(vec![results[0]], 4), approx(vec![expected[0]], 4));
assert_eq!(
approx(vec![results[1]], 4),
approx(vec![expected[5_000]], 4)
);
assert_eq!(approx(vec![results[2]], 4), approx(vec![expected[1]], 4));
assert_eq!(
approx(vec![results[3]], 4),
approx(vec![expected[5_001]], 4)
);
assert_eq!(
approx(vec![results[5_000]], 4),
approx(vec![expected[2_500]], 4)
);
}
#[test]
fn binary_add_f32() {
let left = vec![1.0f32, 2.0, 3.0];
let right = vec![2.0f32, 3.1, 4.2];
let results = run_binary(&left, &right, binary::contiguous::add::FLOAT);
let expected: Vec<_> = left
.iter()
.zip(right.iter())
.map(|(&x, &y)| x + y)
.collect();
assert_eq!(approx(results, 4), vec![3.0f32, 5.1, 7.2]);
assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]);
}
fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let options = MTLResourceOptions::StorageModeManaged;
let size = (v.len() * std::mem::size_of::<U>()) as u64;
let output = device.new_buffer(size, options);
call_cast_contiguous(
&device,
command_buffer,
&kernels,
name,
v.len(),
&input,
0,
&output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<U>(v.len())
}
#[test]
fn cast_u32_f32() {
let v = vec![1u32, 2, 3];
let results = cast(&v, "cast_u32_f32");
let expected: Vec<_> = v.iter().map(|&v| v as f32).collect();
assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]);
assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]);
let v = vec![1.0f32, 2.0, 3.0];
let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
let results: Vec<f32> = cast(&input, "cast_f16_f32");
assert_eq!(results, vec![1.0f32, 2.0, 3.0]);
let v = vec![1.0f32; 10_000];
let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
let results: Vec<f32> = cast(&input, "cast_f16_f32");
assert_eq!(results.len(), 10_000);
assert_eq!(&results[..10], vec![1.0f32; 10]);
assert_eq!(results, vec![1.0f32; 10_000]);
}
fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let output = new_buffer(&device, v);
let size = v.len();
call_affine(
&device,
command_buffer,
&kernels,
"affine_float",
size,
&input,
&output,
mul as f32,
add as f32,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<T>(v.len())
}
fn run_affine_strided<T: Clone>(
v: &[T],
shape: &[usize],
strides: &[usize],
mul: f64,
add: f64,
) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let output = new_buffer(&device, v);
call_affine_strided(
&device,
command_buffer,
&kernels,
"affine_float_strided",
shape,
&input,
strides,
0,
&output,
mul as f32,
add as f32,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
let len: usize = shape.iter().product();
output.read_to_vec::<T>(len)
}
#[test]
fn affine() {
let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let mul = 1.5;
let add = 1.1;
let result = run_affine(&input, mul, add);
assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]);
let input = [1.0f32; 40_000];
let mul = 1.5;
let add = 1.1;
let result = run_affine(&input, mul, add);
assert_eq!(result, vec![2.6; 40_000]);
}
#[test]
fn affine_strided() {
let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let mul = 1.5;
let add = 1.1;
let shape = [4];
let strides = [2];
let result = run_affine_strided(&input, &shape, &strides, mul, add);
// 1 on 2
assert_eq!(result, vec![2.6, 5.6, 8.6, 11.6]);
}
#[test]
fn index_select() {
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let shape = [5, 2];
let ids = [0u32, 4, 2];
let dim = 0;
let result = run_index_select(&embedding, &shape, &ids, dim);
assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]);
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let shape = [2, 5];
let ids = [0u32, 1, 0];
let dim = 0;
let result = run_index_select(&embedding, &shape, &ids, dim);
assert_eq!(
result,
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
);
}
#[test]
fn index_select_f16() {
let embedding: Vec<_> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
.into_iter()
.map(|x| f16::from_f32(x))
.collect();
let shape = [5, 2];
let ids = [0u32, 4, 2];
let dim = 0;
let result = run_index_select(&embedding, &shape, &ids, dim);
assert_eq!(
approx_f16(result, 4),
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
);
}
#[test]
fn index_select_dim1() {
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let shape = [5, 2];
let ids = [0u32, 1, 0];
let dim = 1;
let result = run_index_select(&embedding, &shape, &ids, dim);
assert_eq!(
result,
vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0]
);
}
fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
embeddings: &[T],
shape: &[usize],
ids: &[I],
dim: usize,
) -> Vec<T> {
let device = Device::system_default().expect("no device found");
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let embeddings_buffer = new_buffer(&device, &embeddings);
let ids_buffer = new_buffer(&device, &ids);
let left_size: usize = shape[..dim].iter().product();
let right_size: usize = shape[dim + 1..].iter().product();
let dst_el = ids.len() * left_size * right_size;
let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
let name = match core::mem::size_of::<T>() {
4 => "is_u32_f32",
2 => "is_u32_f16",
_ => unimplemented!(),
};
let kernels = Kernels::new();
call_index_select(
&device,
&command_buffer,
&kernels,
name,
shape,
ids.len(),
dim,
&embeddings_buffer,
&ids_buffer,
&dst_buffer,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
dst_buffer.read_to_vec::<T>(dst_el)
}
#[test]
fn index_add() {
let device = Device::system_default().expect("no device found");
let options = CompileOptions::new();
let library = device.new_library_with_source(INDEXING, &options).unwrap();
let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
let right = [1.0f32; 15];
let index = [0u32, 4, 2];
let ids_dim_size = index.len() as u32;
let dst_dim_size: u32 = 15;
let left_size: u32 = 3;
let right_size: u32 = 3;
let function = library.get_function("ia_u32_f32", None).unwrap();
let pipeline = device
.new_compute_pipeline_state_with_function(&function)
.unwrap();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
let index_buffer = new_buffer(&device, &index);
let inputs_buffer = new_buffer(&device, &left);
let outputs_buffer = new_buffer(&device, &right);
set_params!(
encoder,
(
&index_buffer,
&inputs_buffer,
&outputs_buffer,
ids_dim_size,
left_size,
dst_dim_size,
right_size
)
);
let grid_size = MTLSize {
width: right.len() as NSUInteger,
height: 1,
depth: 1,
};
let thread_group_size = MTLSize {
width: pipeline.max_total_threads_per_threadgroup(),
height: 1,
depth: 1,
};
encoder.dispatch_thread_groups(grid_size, thread_group_size);
encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();
let expected = vec![
2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
];
let result = outputs_buffer.read_to_vec::<f32>(right.len());
assert_eq!(result, expected);
}
#[test]
fn cos_f16() {
let v: Vec<f16> = [1.0f32, 2.0, 3.0]
.iter()
.map(|v| f16::from_f32(*v))
.collect();
let results = run(&v, unary::contiguous::cos::HALF);
let expected: Vec<f16> = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect();
assert_eq!(approx_f16(results, 4), vec![0.5405, -0.4163, -0.9902]);
assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]);
}
fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let options = MTLResourceOptions::StorageModeManaged;
let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
call_reduce_contiguous(
&device,
command_buffer,
&kernels,
name,
v.len(),
out_length,
&input,
0,
&output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<T>(out_length)
}
fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'static str) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let output = new_buffer(&device, v);
call_last_softmax(
&device,
command_buffer,
&kernels,
name,
v.len(),
last_dim,
&input,
&output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<T>(v.len())
}
#[test]
fn reduce_sum() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let out_length = 1;
let results = run_reduce(&v, out_length, "fast_sum_float");
assert_eq!(approx(results, 4), vec![21.0]);
}
#[test]
fn reduce_sum2() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let out_length = 2;
let results = run_reduce(&v, out_length, "fast_sum_float");
assert_eq!(approx(results, 4), vec![6.0, 15.0]);
}
#[test]
fn softmax() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let last_dim = 6;
let results = run_softmax(&v, last_dim, "softmax_float");
assert_eq!(
approx(results, 4),
vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
);
let v = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0];
let last_dim = 6;
let results = run_softmax(&v, last_dim, "softmax_float");
assert_eq!(
approx(results, 4),
vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
);
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let last_dim = 3;
let results = run_softmax(&v, last_dim, "softmax_float");
assert_eq!(
approx(results, 4),
vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652]
);
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
.iter()
.map(|v| f16::from_f32(*v))
.collect::<Vec<_>>();
let last_dim = 6;
let results = run_softmax(&v, last_dim, "softmax_half");
assert_eq!(
approx_f16(results, 4),
vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338]
);
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
.iter()
.map(|v| bf16::from_f32(*v))
.collect::<Vec<_>>();
let last_dim = 6;
let results = run_softmax(&v, last_dim, "softmax_bfloat");
assert_eq!(
approx_bf16(results, 4),
vec![0.0043, 0.0116, 0.0315, 0.0859, 0.2324, 0.6328]
);
}
fn run_where_cond<I: Clone, T: Clone>(
shape: &[usize],
cond: &[I],
(cond_stride, cond_offset): (Vec<usize>, usize),
left_true: &[T],
(left_stride, left_offset): (Vec<usize>, usize),
right_false: &[T],
(_right_stride, _right_offset): (Vec<usize>, usize),
name: &'static str,
) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
let length = cond.len();
let cond = device.new_buffer_with_data(
cond.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(cond) as u64,
options,
);
let left = device.new_buffer_with_data(
left_true.as_ptr() as *const core::ffi::c_void,
(length * core::mem::size_of::<T>()) as u64,
options,
);
let right = device.new_buffer_with_data(
right_false.as_ptr() as *const core::ffi::c_void,
(length * core::mem::size_of::<T>()) as u64,
options,
);
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
call_where_cond_strided(
&device,
command_buffer,
&kernels,
name,
shape,
&cond,
(&cond_stride, cond_offset),
&left,
(&left_stride, left_offset),
&right,
(&cond_stride, cond_offset),
&output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<T>(length)
}
#[test]
fn where_cond() {
let shape = vec![6];
let cond = vec![0u8, 1, 0, 0, 1, 1];
let cond_l = (vec![1], 0);
let left_true = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let left_l = (vec![1], 0);
let right_false = vec![-1.0f32, -2.0, -3.0, -4.0, -5.0, -6.0];
let right_l = (vec![1], 0);
let results = run_where_cond(
&shape,
&cond,
cond_l,
&left_true,
left_l,
&right_false,
right_l,
"where_u8_f32",
);
assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
}
fn run_gemm<T: Clone>(
(b, m, n, k): (usize, usize, usize, usize),
lhs: &[T],
lhs_stride: Vec<usize>,
rhs: &[T],
rhs_stride: Vec<usize>,
) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
let lhs = device.new_buffer_with_data(
lhs.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(lhs) as u64,
options,
);
let rhs = device.new_buffer_with_data(
rhs.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(rhs) as u64,
options,
);
let length = b * m * n;
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
call_gemm(
&device,
command_buffer,
&kernels,
"sgemm",
(b, m, n, k),
&lhs_stride,
0,
&lhs,
&rhs_stride,
0,
&rhs,
&output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<T>(length)
}
#[test]
fn gemm() {
let (b, m, n, k) = (1, 2, 4, 3);
let lhs_stride = vec![m * k, k, 1];
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
let rhs_stride = vec![n * k, n, 1];
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
let results = run_gemm((b, m, n, k), &lhs, lhs_stride, &rhs, rhs_stride);
assert_eq!(
approx(results, 4),
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
);
let (b, m, n, k) = (2, 2, 4, 3);
let lhs_stride = vec![m * k, k, 1];
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
let rhs_stride = vec![n * k, n, 1];
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
let results = run_gemm((b, m, n, k), &lhs, lhs_stride, &rhs, rhs_stride);
assert_eq!(
approx(results, 4),
vec![
20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0, 344.0, 365.0, 386.0, 407.0, 488.0,
518.0, 548.0, 578.0
]
);
}

View File

@ -11,7 +11,7 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
half = { workspace = true }
thiserror = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
@ -19,6 +19,7 @@ num-traits = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
[dev-dependencies]
anyhow = { workspace = true }
@ -29,3 +30,4 @@ default = []
accelerate = ["dep:accelerate-src", "candle/accelerate"]
cuda = ["candle/cuda"]
mkl = ["dep:intel-mkl-src", "candle/mkl"]
metal = ["candle/metal", "dep:candle-metal-kernels"]

View File

@ -6,7 +6,7 @@ extern crate intel_mkl_src;
extern crate accelerate_src;
use candle::quantized::GgmlType;
use candle::{CpuStorage, Device, Layout, Result, Shape, Tensor, D};
use candle::{CpuStorage, Device, Layout, Module, Result, Shape, Tensor, D};
use clap::{Parser, Subcommand};
const CHECK_CONV2D: bool = false;

View File

@ -201,6 +201,46 @@ impl candle::CustomOp1 for SoftmaxLastDim {
};
Ok((dst, layout.shape().clone()))
}
#[cfg(feature = "metal")]
fn metal_fwd(
&self,
storage: &candle::MetalStorage,
layout: &Layout,
) -> Result<(candle::MetalStorage, Shape)> {
use candle::{backend::BackendStorage, DType};
let device = storage.device();
let command_buffer = device.command_buffer();
let kernels = device.kernels();
let name = match storage.dtype() {
DType::F32 => "softmax_float",
DType::F16 => "softmax_half",
DType::BF16 => "softmax_bfloat",
dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"),
};
let n = layout.stride().len();
if !(layout.stride()[n - 1] == 1 && layout.start_offset() == 0) {
candle::bail!("Non contiguous softmax-last-dim is not implemented");
}
let last_dim = layout.dims()[layout.shape().rank() - 1];
let elem_count = layout.shape().elem_count();
let mut output = device.new_buffer(elem_count, storage.dtype());
candle_metal_kernels::call_last_softmax(
device.metal_device(),
&command_buffer,
&kernels,
name,
elem_count,
last_dim,
storage.buffer(),
&mut output,
)
.unwrap();
let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
Ok((newstorage, layout.shape().clone()))
}
}
pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {

View File

@ -1,6 +1,6 @@
[package]
name = "candle-onnx"
version = "0.3.0"
version = "0.3.1"
edition = "2021"
description = "ONNX support for Candle"
@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.1" }
prost = "0.12.1"
[build-dependencies]

View File

@ -741,6 +741,25 @@ pub fn simple_eval(
let output = input.to_dtype(dtype)?;
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#CumSum
"CumSum" => {
let exclusive = get_attr_opt::<i64>(node, "exclusive")?
.copied()
.unwrap_or(0);
let reverse = get_attr_opt::<i64>(node, "reverse")?.copied().unwrap_or(0);
if exclusive != 0 {
bail!("only exclusive == 0 is supported in CumSum")
}
if reverse != 0 {
bail!("only reverse == 0 is supported in CumSum")
}
let input = get(&node.input[0])?;
let axis = get(&node.input[1])?
.to_dtype(DType::U32)?
.to_vec0::<u32>()?;
let output = input.cumsum(axis as usize)?;
values.insert(node.output[0].clone(), output);
}
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
}
}

746
candle-onnx/tests/ops.rs Normal file
View File

@ -0,0 +1,746 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle::{Device, Result, Tensor};
use candle_onnx::onnx::{GraphProto, ModelProto, NodeProto, ValueInfoProto};
use std::collections::HashMap;
const INPUT_X: &str = "x";
const INPUT_Y: &str = "y";
const OUTPUT_Z: &str = "z";
fn create_model_proto_with_graph(graph: Option<GraphProto>) -> ModelProto {
ModelProto {
metadata_props: vec![],
training_info: vec![],
functions: vec![],
ir_version: 0,
opset_import: vec![],
producer_name: "".to_string(),
producer_version: "".to_string(),
domain: "".to_string(),
model_version: 0,
doc_string: "".to_string(),
graph,
}
}
#[test]
fn test_evaluation_fails_without_defined_graph() -> Result<()> {
let manual_graph = create_model_proto_with_graph(None);
let inputs: HashMap<String, Tensor> = HashMap::new();
match candle_onnx::simple_eval(&manual_graph, inputs) {
Err(err) => assert_eq!(err.to_string(), "no graph defined in proto"),
Ok(_) => panic!("Expected an error due to undefined graph"),
}
Ok(())
}
// "Add"
#[test]
fn test_add_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Add".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?);
inputs.insert(INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let first = z
.to_vec1::<f64>()?
.to_vec()
.get(0)
.expect("Failed to get first element")
.clone();
assert_eq!(first, 4.0f64);
Ok(())
}
// "Sub"
#[test]
fn test_sub_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Sub".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?);
inputs.insert(INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let first = z
.to_vec1::<f64>()?
.to_vec()
.get(0)
.expect("Failed to get first element")
.clone();
assert_eq!(first, 0.0f64);
Ok(())
}
// "Mul"
#[test]
fn test_mul_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Mul".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?);
inputs.insert(INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let first = z
.to_vec1::<f64>()?
.to_vec()
.get(0)
.expect("Failed to get first element")
.clone();
assert_eq!(first, 4.0f64);
Ok(())
}
// "Div"
#[test]
fn test_div_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Div".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?);
inputs.insert(INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let first = z
.to_vec1::<f64>()?
.to_vec()
.get(0)
.expect("Failed to get first element")
.clone();
assert_eq!(first, 1.0f64);
Ok(())
}
// "Equal"
#[test]
fn test_equal_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Equal".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?);
inputs.insert(INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let first = z.to_dtype(candle::DType::U8)?.to_vec1::<u8>()?.to_vec()[0];
assert_eq!(first, 1);
Ok(())
}
// "Not"
#[test]
fn test_not_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Not".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), Tensor::new(&[0.], &Device::Cpu)?);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let first = z.to_dtype(candle::DType::U8)?.to_vec1::<u8>()?.to_vec()[0];
assert_eq!(first, 1);
Ok(())
}
// "MatMul"
#[test]
fn test_matmul_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "MatMul".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(
INPUT_X.to_string(),
Tensor::from_vec(
//
vec![1.0f32, 2.0f32, 3.0f32, 4.0f32],
&[2, 2],
&Device::Cpu,
)?,
);
inputs.insert(
INPUT_Y.to_string(),
Tensor::from_vec(
//
vec![5.0f32, 6.0f32, 7.0f32, 8.0f32],
&[2, 2],
&Device::Cpu,
)?,
);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
assert_eq!(results, vec![vec![19.0, 22.0], vec![43.0, 50.0]]);
Ok(())
}
// "Reshape"
#[test]
fn test_reshape_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Reshape".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
//
vec![1.0f32, 2.0f32, 3.0f32, 4.0f32],
&[2, 2],
&Device::Cpu,
)?;
let y = Tensor::from_vec(
//
vec![4i64],
&[1],
&Device::Cpu,
)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
inputs.insert(INPUT_Y.to_string(), y);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec1::<f32>()?;
assert_eq!(results, vec![1.0, 2.0, 3.0, 4.0]);
Ok(())
}
// "LogSoftmax"
#[test]
fn test_logsoftmax_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "LogSoftmax".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
//
vec![1.0f32, 2.0f32, 3.0f32, 4.0f32],
&[2, 2],
&Device::Cpu,
)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
assert_eq!(
results,
vec![vec![0.26894143, 0.7310586], vec![0.26894143, 0.7310586]]
);
Ok(())
}
// "Softmax"
#[test]
fn test_softmax_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Softmax".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
//
vec![1.0f32, 2.0f32, 3.0f32, 4.0f32],
&[2, 2],
&Device::Cpu,
)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
assert_eq!(
results,
vec![vec![0.26894143, 0.7310586], vec![0.26894143, 0.7310586]]
);
Ok(())
}
// "Transpose"
#[test]
fn test_transpose_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Transpose".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
//
vec![1.0f32, 2.0f32, 3.0f32, 4.0f32],
&[2, 2],
&Device::Cpu,
)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
assert_eq!(results, vec![vec![1.0, 3.0], vec![2.0, 4.0]]);
Ok(())
}
// "Dropout"
#[test]
fn test_dropout_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Dropout".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
//
vec![1.0f32, 2.0f32, 3.0f32, 4.0f32],
&[2, 2],
&Device::Cpu,
)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
assert_eq!(results, vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
Ok(())
}
// Below are ops that are implemented but not tested yet
// "MaxPool"
// #[test]
// "AveragePool"
// #[test]
// "BatchNormalization"
// #[test]
// "Squeeze"
// #[test]
// "ConstantOfShape"
// #[test]
// "Unsqueeze"
// #[test]
// "Clip"
// #[test]
// "Gather"
// #[test]
// "Shape"
// #[test]
// "Conv"
// #[test]
// "Concat"
// #[test]
// "Abs"
// #[test]
// "Cos"
// #[test]
// "Sin"
// #[test]
// "Neg"
// #[test]
// "Erf"
// #[test]
// "Tanh"
// #[test]
// "Sigmoid"
// #[test]
// "Gelu"
// #[test]
// "Relu"
// #[test]
// "Constant"
// #[test]
// "Cast"
// #[test]

View File

@ -15,9 +15,9 @@ crate-type = ["cdylib"]
[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
candle-onnx = {path= "../candle-onnx", version = "0.3.0", optional = true}
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.1" }
candle-onnx = {path= "../candle-onnx", version = "0.3.1", optional = true}
half = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] }

View File

@ -17,7 +17,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};
use ::candle::{quantized::QTensor, DType, Device, Module, Tensor, WithDType};
mod utils;
use utils::wrap_err;

View File

@ -12,9 +12,9 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.1", optional = true }
candle-nn = { path = "../candle-nn", version = "0.3.1" }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
rand = { workspace = true }

View File

@ -0,0 +1,342 @@
use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
use candle::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module, VarBuilder};
use serde::Deserialize;
pub const DTYPE: DType = DType::F32;
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
let shape = mask.shape();
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
let m = mask.where_cond(&on_true, on_false)?;
Ok(m)
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")]
enum HiddenAct {
Gelu,
Relu,
}
struct HiddenActLayer {
act: HiddenAct,
span: tracing::Span,
}
impl HiddenActLayer {
fn new(act: HiddenAct) -> Self {
let span = tracing::span!(tracing::Level::TRACE, "hidden-act");
Self { act, span }
}
}
impl Module for HiddenActLayer {
fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
let _enter = self.span.enter();
match self.act {
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
HiddenAct::Gelu => xs.gelu(),
HiddenAct::Relu => xs.relu(),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
enum PositionEmbeddingType {
#[default]
Absolute,
}
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
vocab_size: usize,
dim: usize,
n_layers: usize,
n_heads: usize,
hidden_dim: usize,
activation: HiddenAct,
max_position_embeddings: usize,
initializer_range: f64,
pad_token_id: usize,
#[serde(default)]
position_embedding_type: PositionEmbeddingType,
#[serde(default)]
use_cache: bool,
model_type: Option<String>,
}
impl Default for Config {
fn default() -> Self {
Self {
vocab_size: 30522,
dim: 768,
n_layers: 12,
n_heads: 12,
hidden_dim: 3072,
activation: HiddenAct::Gelu,
max_position_embeddings: 512,
initializer_range: 0.02,
pad_token_id: 0,
position_embedding_type: PositionEmbeddingType::Absolute,
use_cache: true,
model_type: Some("distilbert".to_string()),
}
}
}
struct Embeddings {
word_embeddings: Embedding,
position_embeddings: Embedding,
layer_norm: LayerNorm,
span: tracing::Span,
}
impl Embeddings {
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let word_embeddings =
candle_nn::embedding(config.vocab_size, config.dim, vb.pp("word_embeddings"))?;
let position_embeddings = candle_nn::embedding(
config.max_position_embeddings,
config.dim,
vb.pp("position_embeddings"),
)?;
let layer_norm = layer_norm(config.dim, 1e-12, vb.pp("LayerNorm"))?;
Ok(Self {
word_embeddings,
position_embeddings,
layer_norm,
span: tracing::span!(tracing::Level::TRACE, "embeddings"),
})
}
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (_bsize, seq_len) = input_ids.dims2()?;
let input_embeddings = self.word_embeddings.forward(input_ids)?;
let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
let position_ids = Tensor::new(&position_ids[..], input_ids.device())?;
let embeddings =
input_embeddings.broadcast_add(&self.position_embeddings.forward(&position_ids)?)?;
let embeddings = self.layer_norm.forward(&embeddings)?;
Ok(embeddings)
}
}
struct MultiHeadSelfAttention {
q_lin: Linear,
k_lin: Linear,
v_lin: Linear,
out_lin: Linear,
n_heads: usize,
attention_head_size: usize,
span: tracing::Span,
}
impl MultiHeadSelfAttention {
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let attention_head_size = config.dim / config.n_heads;
let all_head_size = config.n_heads * attention_head_size;
let dim = config.dim;
let q_lin = linear(dim, all_head_size, vb.pp("q_lin"))?;
let v_lin = linear(dim, all_head_size, vb.pp("v_lin"))?;
let k_lin = linear(dim, all_head_size, vb.pp("k_lin"))?;
let out_lin = linear(all_head_size, dim, vb.pp("out_lin"))?;
Ok(Self {
q_lin,
k_lin,
v_lin,
out_lin,
n_heads: config.n_heads,
attention_head_size,
span: tracing::span!(tracing::Level::TRACE, "attention"),
})
}
}
impl MultiHeadSelfAttention {
fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (bs, q_length, _dim) = hidden_states.dims3()?;
let dim_per_head = self.attention_head_size;
let q = self.q_lin.forward(hidden_states)?;
let k = self.k_lin.forward(hidden_states)?;
let v = self.v_lin.forward(hidden_states)?;
let q = q
.reshape((bs, q_length, self.n_heads, dim_per_head))?
.transpose(1, 2)?;
let k = k
.reshape((bs, q_length, self.n_heads, dim_per_head))?
.transpose(1, 2)?;
let v = v
.reshape((bs, q_length, self.n_heads, dim_per_head))?
.transpose(1, 2)?;
let q: Tensor = (q / (dim_per_head as f64).sqrt())?;
let scores = q.matmul(&k.transpose(2, 3)?.contiguous()?)?;
let mask = attention_mask.broadcast_as(scores.shape())?;
let scores = masked_fill(&scores.to_dtype(DType::F32)?, &mask, f32::NEG_INFINITY)?;
let weights = candle_nn::ops::softmax(&scores, candle::D::Minus1)?;
let context = weights.matmul(&v.contiguous()?)?;
let context = context
.transpose(1, 2)?
.reshape((bs, q_length, self.n_heads * dim_per_head))?
.contiguous()?;
let context = self.out_lin.forward(&context)?;
Ok(context)
}
}
#[allow(clippy::upper_case_acronyms)]
struct FFN {
lin1: Linear,
lin2: Linear,
activation: HiddenActLayer,
span: tracing::Span,
}
impl FFN {
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let lin1 = linear(config.dim, config.hidden_dim, vb.pp("lin1"))?;
let lin2 = linear(config.hidden_dim, config.dim, vb.pp("lin2"))?;
Ok(Self {
lin1,
lin2,
activation: HiddenActLayer::new(config.activation),
span: tracing::span!(tracing::Level::TRACE, "ffn"),
})
}
}
impl Module for FFN {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
hidden_states
.apply(&self.lin1)?
.apply(&self.activation)?
.apply(&self.lin2)
}
}
struct TransformerBlock {
attention: MultiHeadSelfAttention,
sa_layer_norm: LayerNorm,
ffn: FFN,
output_layer_norm: LayerNorm,
span: tracing::Span,
}
impl TransformerBlock {
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let attention = MultiHeadSelfAttention::load(vb.pp("attention"), config)?;
let sa_layer_norm = layer_norm(config.dim, 1e-12, vb.pp("sa_layer_norm"))?;
let ffn = FFN::load(vb.pp("ffn"), config)?;
let output_layer_norm = layer_norm(config.dim, 1e-12, vb.pp("output_layer_norm"))?;
Ok(Self {
attention,
sa_layer_norm,
ffn,
output_layer_norm,
span: tracing::span!(tracing::Level::TRACE, "layer"),
})
}
}
impl TransformerBlock {
fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let sa_output = self.attention.forward(hidden_states, attention_mask)?;
// TODO: Support cross-attention?
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
// TODO: Support something similar to `apply_chunking_to_forward`?
let sa_output = sa_output.broadcast_add(hidden_states)?;
let sa_output = self.sa_layer_norm.forward(&sa_output)?;
let ffn_output = self.ffn.forward(&sa_output)?;
let ffn_output = (&ffn_output + sa_output)?;
let output = self.output_layer_norm.forward(&ffn_output)?;
Ok(output)
}
}
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
struct Transformer {
layers: Vec<TransformerBlock>,
span: tracing::Span,
}
impl Transformer {
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let layers = (0..config.n_layers)
.map(|index| TransformerBlock::load(vb.pp(&format!("layer.{index}")), config))
.collect::<Result<Vec<_>>>()?;
let span = tracing::span!(tracing::Level::TRACE, "encoder");
Ok(Transformer { layers, span })
}
}
impl Transformer {
fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let mut hidden_states = hidden_states.clone();
// Use a loop rather than a fold as it's easier to modify when adding debug/...
for layer in self.layers.iter() {
hidden_states = layer.forward(&hidden_states, attention_mask)?;
}
Ok(hidden_states)
}
}
pub struct DistilBertModel {
embeddings: Embeddings,
transformer: Transformer,
pub device: Device,
span: tracing::Span,
}
impl DistilBertModel {
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let (embeddings, transformer) = match (
Embeddings::load(vb.pp("embeddings"), config),
Transformer::load(vb.pp("transformer"), config),
) {
(Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
(Err(err), _) | (_, Err(err)) => {
if let Some(model_type) = &config.model_type {
if let (Ok(embeddings), Ok(encoder)) = (
Embeddings::load(vb.pp(&format!("{model_type}.embeddings")), config),
Transformer::load(vb.pp(&format!("{model_type}.transformer")), config),
) {
(embeddings, encoder)
} else {
return Err(err);
}
} else {
return Err(err);
}
}
};
Ok(Self {
embeddings,
transformer,
device: vb.device().clone(),
span: tracing::span!(tracing::Level::TRACE, "model"),
})
}
pub fn forward(&self, input_ids: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let embedding_output = self.embeddings.forward(input_ids)?;
let sequence_output = self
.transformer
.forward(&embedding_output, attention_mask)?;
Ok(sequence_output)
}
}

View File

@ -4,6 +4,7 @@ pub mod blip;
pub mod blip_text;
pub mod convmixer;
pub mod dinov2;
pub mod distilbert;
pub mod efficientnet;
pub mod falcon;
pub mod jina_bert;

View File

@ -1,12 +1,15 @@
use candle::{Device, Result, Tensor};
pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
if steps < 1 {
candle::bail!("cannot use linspace with steps {steps} <= 1")
if steps == 0 {
Tensor::from_vec(Vec::<f64>::new(), steps, &Device::Cpu)
} else if steps == 1 {
Tensor::from_vec(vec![start], steps, &Device::Cpu)
} else {
let delta = (stop - start) / (steps - 1) as f64;
let vs = (0..steps)
.map(|step| start + step as f64 * delta)
.collect::<Vec<_>>();
Tensor::from_vec(vs, steps, &Device::Cpu)
}
let delta = (stop - start) / (steps - 1) as f64;
let vs = (0..steps)
.map(|step| start + step as f64 * delta)
.collect::<Vec<_>>();
Tensor::from_vec(vs, steps, &Device::Cpu)
}

View File

@ -138,6 +138,10 @@ impl TrOCRAttention {
})
}
fn reset_kv_cache(&mut self) {
self.kv_cache = None
}
fn _shape(&self, tensor: &Tensor, bsz: usize) -> Result<Tensor> {
tensor
.reshape((bsz, (), self.num_heads, self.head_dim))?
@ -239,6 +243,10 @@ impl TrOCRDecoderLayer {
})
}
fn reset_kv_cache(&mut self) {
self.self_attn.reset_kv_cache();
}
fn forward(
&mut self,
xs: &Tensor,
@ -307,6 +315,10 @@ impl TrOCRDecoder {
})
}
fn reset_kv_cache(&mut self) {
self.layers.iter_mut().for_each(|l| l.reset_kv_cache())
}
pub fn forward(
&mut self,
xs: &Tensor,
@ -393,6 +405,10 @@ impl TrOCRForCausalLM {
Ok(xs)
}
fn reset_kv_cache(&mut self) {
self.decoder.reset_kv_cache();
}
}
#[derive(Debug, Clone)]
@ -431,4 +447,8 @@ impl TrOCRModel {
self.decoder
.forward(xs, Some(encoder_xs), past_kv_len, &mask)
}
pub fn reset_kv_cache(&mut self) {
self.decoder.reset_kv_cache();
}
}

View File

@ -58,8 +58,7 @@ fn dft<T: Float>(inp: &[T]) -> Vec<T> {
let n = inp.len();
let two_pi = T::PI() + T::PI();
let mut out = Vec::new();
out.reserve(2 * n);
let mut out = Vec::with_capacity(2 * n);
let n_t = T::from(n).unwrap();
for k in 0..n {
let k_t = T::from(k).unwrap();

View File

@ -43,4 +43,4 @@ pub const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
pub const TRANSLATE_TOKEN: &str = "<|translate|>";
pub const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
pub const EOT_TOKEN: &str = "<|endoftext|>";
pub const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
pub const NO_SPEECH_TOKENS: [&str; 2] = ["<|nocaptions|>", "<|nospeech|>"];

View File

@ -277,8 +277,12 @@ impl DecoderLayer {
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("ln1"))?;
let ln2 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("ln2"))?;
let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
let ln2 = RmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
vb.pp("post_attention_layernorm"),
)?;
Ok(Self {
self_attn,
mlp,

View File

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }

View File

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
num-traits = { workspace = true }

View File

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }

View File

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
num-traits = { workspace = true }

View File

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
num-traits = { workspace = true }
# App crates.

View File

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }

View File

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }

View File

@ -59,8 +59,7 @@ fn dft<T: Float>(inp: &[T]) -> Vec<T> {
let n = inp.len();
let two_pi = T::PI() + T::PI();
let mut out = Vec::new();
out.reserve(2 * n);
let mut out = Vec::with_capacity(2 * n);
let n_t = T::from(n).unwrap();
for k in 0..n {
let k_t = T::from(k).unwrap();

View File

@ -129,7 +129,13 @@ impl Decoder {
let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?;
let no_speech_token = m::NO_SPEECH_TOKENS
.iter()
.find_map(|token| token_id(&tokenizer, token).ok());
let no_speech_token = match no_speech_token {
None => anyhow::bail!("unable to find any non-speech token"),
Some(n) => n,
};
let seed = 299792458;
Ok(Self {
model,

View File

@ -9,8 +9,8 @@ categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
num-traits = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }

View File

@ -7,7 +7,7 @@ keywords.workspace = true
categories.workspace = true
[dependencies]
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
rand = { workspace = true }
getrandom = { version = "0.2", features = ["js"] }

View File

@ -1,7 +1,7 @@
use candle::{
quantized::{self, k_quants, GgmlDType, GgmlType},
test_utils::to_vec2_round,
Device, Result, Tensor,
Device, Module, Result, Tensor,
};
use wasm_bindgen_test::*;