Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 11:08:52 +00:00)

Compare commits (27 commits)
| SHA1 |
|---|
| 7e49e0af96 |
| 181d2299b2 |
| 2801541e5f |
| 4289984d32 |
| 1471f98f0b |
| dd4a40f1c0 |
| 79845bd93b |
| 6071797450 |
| b58b247323 |
| 3900091e75 |
| 54355ff997 |
| e02f1912bb |
| a52b71686b |
| 7adfb70dff |
| 3ad02147e4 |
| 4f39695465 |
| 4cf4844c9d |
| d840838e95 |
| 61a070fdd1 |
| e35669647d |
| 53e8b7ee3e |
| cc26cce23c |
| 02c2ec2c71 |
| 9a2784b8ab |
| 0f652f0e3d |
| ddee9dc1dd |
| fc9bb7784a |
@@ -19,7 +19,7 @@ exclude = [
resolver = "2"

[workspace.package]
version = "0.3.1"
version = "0.3.0"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@@ -61,7 +61,10 @@ tracing-subscriber = "0.3.7"
wav = "1.0.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "0.6.6", default-features = false }
metal = { version = "0.27.1", features = ["mps"], package="candle-metal" }
#metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
metal = { path = "../metal-rs", features = ["mps"] }
dispatch = "0.2.0"
rustc-hash = "1.1"

[profile.release-with-debug]
inherits = "release"
README.md (17 lines changed)
@@ -139,16 +139,16 @@ And then head over to
<!--- ANCHOR: useful_libraries --->

## Useful External Resources
- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): A
- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): a
  very detailed tutorial showing how to convert a PyTorch model to Candle.
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): Efficient and ergonomic LoRA implementation for Candle. `candle-lora` has
  out-of-the-box LoRA support for many models from Candle, which can be found [here](https://github.com/EricLBuehler/candle-lora/tree/master/candle-lora-transformers/examples).
- [`optimisers`](https://github.com/KGrewal1/optimisers): A collection of optimisers
- [`optimisers`](https://github.com/KGrewal1/optimisers): a collection of optimisers
  including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): a LoRA implementation
  that conforms to the official `peft` implementation.
- [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
  serving local LLMs including an OpenAI compatible API server.
- [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle.
- [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
- [`candle-ext`](https://github.com/mokeyish/candle-ext): an extension library to Candle that provides PyTorch functions not currently available in Candle.
- [`kalosm`](https://github.com/floneum/floneum/tree/master/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.

If you have an addition to this list, please submit a pull request.
@@ -177,11 +177,6 @@ If you have an addition to this list, please submit a pull request.
- Replit-code-v1.5-3B.
- Bert.
- Yi-6B and Yi-34B.
- Quantized LLMs.
    - Llama 7b, 13b, 70b, as well as the chat and code variants.
    - Mistral 7b, and 7b instruct.
    - Zephyr 7b a and b (Mistral based).
    - OpenChat 3.5 (Mistral based).
- Text to text.
    - T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
    - Marian MT (Machine Translation).
@@ -11,11 +11,11 @@ readme = "README.md"

[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.3.1" }
candle-nn = { path = "../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../candle-transformers", version = "0.3.1" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.1", optional = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
@@ -12,8 +12,8 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
candle-kernels = { path = "../candle-kernels", version = "0.3.1", optional = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.1", optional = true }
candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
metal = { workspace = true, optional = true}
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
@@ -30,6 +30,8 @@ safetensors = { workspace = true }
thiserror = { workspace = true }
yoke = { workspace = true }
zip = { workspace = true }
dispatch = { workspace = true, optional = true }
rustc-hash = { workspace = true }

[dev-dependencies]
anyhow = { workspace = true }
@@ -41,4 +43,4 @@ cuda = ["cudarc", "dep:candle-kernels"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal", "dep:candle-metal-kernels"]
metal = ["dep:metal", "dep:candle-metal-kernels", "dep:dispatch"]
@@ -8,10 +8,11 @@ use anyhow::Result;
use candle_core::{Device, Tensor};

fn main() -> Result<()> {
    let a = Tensor::new(&[[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]], &Device::Cpu)?;
    let b = Tensor::new(&[[88.0f32, 99.0]], &Device::Cpu)?;
    let new_a = a.slice_scatter(&b, 1, 2)?;
    assert_eq!(a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
    assert_eq!(new_a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
    let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
    let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
    let start = std::time::Instant::now();
    let res = inp.conv2d(&w, 0, 1, 1, 1)?;
    println!("{:?}", start.elapsed());
    println!("{res:?}");
    Ok(())
}
@@ -104,31 +104,37 @@ impl From<&Tensor> for TensorIndexer {
    }
}

trait RB: RangeBounds<usize> {}
impl RB for Range<usize> {}
impl RB for RangeFrom<usize> {}
impl RB for RangeFull {}
impl RB for RangeInclusive<usize> {}
impl RB for RangeTo<usize> {}
impl RB for RangeToInclusive<usize> {}
macro_rules! impl_from_range {
    ($range_type:ty) => {
        impl From<$range_type> for TensorIndexer {
            fn from(range: $range_type) -> Self {
                use std::ops::Bound::*;

impl<T: RB> From<T> for TensorIndexer {
    fn from(range: T) -> Self {
        use std::ops::Bound::*;
        let start = match range.start_bound() {
            Included(idx) => Included(*idx),
            Excluded(idx) => Excluded(*idx),
            Unbounded => Unbounded,
        };
        let end = match range.end_bound() {
            Included(idx) => Included(*idx),
            Excluded(idx) => Excluded(*idx),
            Unbounded => Unbounded,
        };
        TensorIndexer::Narrow(start, end)
    }
                let start = match range.start_bound() {
                    Included(idx) => Included(*idx),
                    Excluded(idx) => Excluded(*idx),
                    Unbounded => Unbounded,
                };

                let end = match range.end_bound() {
                    Included(idx) => Included(*idx),
                    Excluded(idx) => Excluded(*idx),
                    Unbounded => Unbounded,
                };

                TensorIndexer::Narrow(start, end)
            }
        }
    };
}

impl_from_range!(Range<usize>);
impl_from_range!(RangeFrom<usize>);
impl_from_range!(RangeFull);
impl_from_range!(RangeInclusive<usize>);
impl_from_range!(RangeTo<usize>);
impl_from_range!(RangeToInclusive<usize>);

/// Trait used to implement multiple signatures for ease of use of the slicing
/// of a tensor
pub trait IndexOp<T> {
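As a side note on the hunk above: the single generic `impl<T: RB> From<T> for TensorIndexer` replaces the six macro expansions. A minimal sketch of what this enables through `IndexOp` (my own example, assuming the candle-core API shown elsewhere in this compare):

```rust
use candle_core::{Device, IndexOp, Result, Tensor};

fn range_indexing() -> Result<()> {
    let t = Tensor::arange(0u32, 12, &Device::Cpu)?.reshape((3, 4))?;
    // Every standard range type goes through the same blanket impl
    // and becomes a TensorIndexer::Narrow.
    let _rows = t.i(1..)?;          // RangeFrom<usize>
    let _mixed = t.i((.., 1..=2))?; // RangeFull and RangeInclusive<usize>
    Ok(())
}
```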
@@ -123,6 +123,12 @@ pub trait Module {
    fn forward(&self, xs: &Tensor) -> Result<Tensor>;
}

impl Module for quantized::QMatMul {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        self.forward(xs)
    }
}

impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        self(xs)
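A short sketch of what the blanket impl over closures buys (my example, not part of the diff): any `Fn(&Tensor) -> Result<Tensor>` can now be passed wherever a `Module` is expected.

```rust
use candle_core::{Device, Module, Result, Tensor};

fn main() -> Result<()> {
    // The closure satisfies `Module` via the blanket impl above.
    let act = |xs: &Tensor| xs.relu();
    let t = Tensor::new(&[-1.0f32, 0.5, 2.0], &Device::Cpu)?;
    let y = act.forward(&t)?;
    println!("{y}");
    Ok(())
}
```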
File diff suppressed because it is too large.
@@ -307,8 +307,8 @@ impl crate::CustomOp1 for QTensor {
    }
}

impl crate::Module for QMatMul {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
impl QMatMul {
    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
            Self::Tensor(w) => {
@@ -2457,110 +2457,6 @@ impl Tensor {
        Ok(naxis as usize)
    }
}

    /// Returns a lower triangular matrix of ones of size n by n.
    pub fn tril2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
        let t = Tensor::arange(0u32, n as u32, device)?;
        let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
        let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
        t1.le(&t2)?.to_dtype(dtype)
    }

    /// Returns an upper triangular matrix of ones of size n by n.
    pub fn triu2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
        let t = Tensor::arange(0u32, n as u32, device)?;
        let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
        let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
        t1.ge(&t2)?.to_dtype(dtype)
    }

    /// Returns a matrix with a diagonal of ones of size n by n.
    pub fn eye(n: usize, dtype: DType, device: &Device) -> Result<Self> {
        let t = Tensor::arange(0u32, n as u32, device)?;
        let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
        let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
        t1.eq(&t2)?.to_dtype(dtype)
    }
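As a reading aid (my annotation, not part of the diff): with `t1[i][j] = j` and `t2[i][j] = i` from the two broadcasts, the three helpers reduce to elementwise index comparisons:

```latex
\mathrm{tril2}:\ M_{ij} = [\,j \le i\,], \qquad
\mathrm{triu2}:\ M_{ij} = [\,j \ge i\,], \qquad
\mathrm{eye}:\ M_{ij} = [\,j = i\,].
```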
    /// Returns the cumulative sum of elements of the input tensor summed over the specified
    /// dimension.
    ///
    /// This operation is most efficient when dim is the last dimension of the tensor.
    pub fn cumsum<D: Dim>(&self, dim: D) -> Result<Self> {
        let dim = dim.to_index(self.shape(), "cumsum")?;
        let rank = self.rank();
        if rank == 0 {
            return Ok(self.clone());
        }
        let n_axis = self.dim(dim)?;
        let triu = Tensor::triu2(n_axis, self.dtype(), self.device())?;
        if rank == 1 {
            self.unsqueeze(0)?.matmul(&triu)?.squeeze(0)
        } else {
            let last = rank - 1;
            let t = self.transpose(dim, last)?;
            let t = t.broadcast_matmul(&triu)?;
            t.transpose(dim, last)
        }
    }
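The matmul trick in `cumsum` can be read as follows (my gloss, using the `triu2` definition above): multiplying by the upper-triangular ones matrix sums every prefix.

```latex
U_{ij} = [\,i \le j\,] \quad\Longrightarrow\quad
(x^\top U)_j = \sum_i x_i\,U_{ij} = \sum_{i \le j} x_i = \mathrm{cumsum}(x)_j.
```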
    /// Returns a copy of `self` where the values within `ranges` have been replaced with the
    /// content of `src`.
    pub fn slice_assign<D: std::ops::RangeBounds<usize>>(
        &self,
        ranges: &[D],
        src: &Tensor,
    ) -> Result<Self> {
        let src_dims = src.dims();
        let self_dims = self.dims();
        if self_dims.len() != src_dims.len() {
            crate::bail!(
                "slice-assign requires input with the same rank {} <> {}",
                self_dims.len(),
                src_dims.len()
            )
        }
        if self_dims.len() != ranges.len() {
            crate::bail!(
                "slice-assign requires input with the same rank as there are ranges {} <> {}",
                self_dims.len(),
                ranges.len()
            )
        }
        let mut src = src.clone();
        let mut mask = Self::ones(src.shape(), DType::U8, src.device())?;
        for (i, range) in ranges.iter().enumerate() {
            let start_included = match range.start_bound() {
                std::ops::Bound::Unbounded => 0,
                std::ops::Bound::Included(v) => *v,
                std::ops::Bound::Excluded(v) => *v + 1,
            };
            let end_excluded = match range.end_bound() {
                std::ops::Bound::Unbounded => self_dims[i],
                std::ops::Bound::Included(v) => *v + 1,
                std::ops::Bound::Excluded(v) => *v,
            };
            if end_excluded <= start_included {
                crate::bail!(
                    "slice-assign: empty range for dim {i}, {start_included} {end_excluded}"
                )
            }
            if self_dims[i] < end_excluded {
                crate::bail!(
                    "slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}",
                    self_dims[i]
                )
            }
            if end_excluded - start_included != src_dims[i] {
                crate::bail!(
                    "slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i]
                )
            }
            src = src.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?;
            mask = mask.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?
        }
        mask.where_cond(/* on_true= */ &src, /* on_false= */ self)
    }
}
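One way to summarize the implementation above (my annotation): `src` and a ones mask are zero-padded up to the full shape of `self`, and `where_cond` then selects per element:

```latex
\mathrm{out} = m \odot \mathrm{pad}_0(\mathit{src}) + (1 - m) \odot \mathit{self},
\qquad m = \mathrm{pad}_0(\mathbf{1}_{\mathit{src}}).
```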
macro_rules! bin_trait {
@@ -91,32 +91,3 @@ fn index_3d() -> Result<()> {
    assert_eq!(tensor.i((1, .., 3))?.to_vec1::<u32>()?, &[15, 19, 23]);
    Ok(())
}

#[test]
fn slice_assign() -> Result<()> {
    let dev = Device::Cpu;

    let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?;
    let src = Tensor::arange(0u32, 2 * 3, &dev)?.reshape((3, 2))?;
    let out = tensor.slice_assign(&[1..4, 3..5], &src)?;
    assert_eq!(
        out.to_vec2::<u32>()?,
        &[
            [0, 1, 2, 3, 4],
            [5, 6, 7, 0, 1],
            [10, 11, 12, 2, 3],
            [15, 16, 17, 4, 5]
        ]
    );
    let out = tensor.slice_assign(&[0..3, 0..2], &src)?;
    assert_eq!(
        out.to_vec2::<u32>()?,
        &[
            [0, 1, 2, 3, 4],
            [2, 3, 7, 8, 9],
            [4, 5, 12, 13, 14],
            [15, 16, 17, 18, 19]
        ]
    );
    Ok(())
}
@@ -1,7 +1,7 @@
use candle_core::{
    quantized::{self, GgmlDType},
    test_utils::to_vec2_round,
    Device, Module, Result, Tensor,
    Device, Result, Tensor,
};
use quantized::{k_quants, GgmlType};
use rand::prelude::*;
@@ -1159,65 +1159,3 @@ fn i64_abs() -> Result<()> {
    assert_eq!(t.to_vec1::<i64>()?, [42, 1337]);
    Ok(())
}

#[test]
fn tril_triu_eye() -> Result<()> {
    let t = Tensor::tril2(4, DType::F32, &Device::Cpu)?;
    assert_eq!(
        t.to_vec2::<f32>()?,
        [
            [1.0, 0.0, 0.0, 0.0],
            [1.0, 1.0, 0.0, 0.0],
            [1.0, 1.0, 1.0, 0.0],
            [1.0, 1.0, 1.0, 1.0]
        ],
    );
    let t = Tensor::triu2(4, DType::F32, &Device::Cpu)?;
    assert_eq!(
        t.to_vec2::<f32>()?,
        [
            [1.0, 1.0, 1.0, 1.0],
            [0.0, 1.0, 1.0, 1.0],
            [0.0, 0.0, 1.0, 1.0],
            [0.0, 0.0, 0.0, 1.0]
        ]
    );
    let t = Tensor::eye(4, DType::F32, &Device::Cpu)?;
    assert_eq!(
        t.to_vec2::<f32>()?,
        [
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 1.0]
        ]
    );
    Ok(())
}

#[test]
fn cumsum() -> Result<()> {
    let t = &[3f32, 1., 4., 1., 5.];
    let t = Tensor::new(t, &Device::Cpu)?;
    assert_eq!(t.cumsum(0)?.to_vec1::<f32>()?, [3., 4., 8., 9., 14.]);
    let t = t.unsqueeze(1)?;
    assert_eq!(
        t.cumsum(0)?.to_vec2::<f32>()?,
        [[3.0], [4.0], [8.0], [9.0], [14.0]]
    );
    assert_eq!(
        t.cumsum(1)?.to_vec2::<f32>()?,
        [[3.0], [1.0], [4.0], [1.0], [5.0]]
    );
    let t = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
    let t = Tensor::new(t, &Device::Cpu)?;
    assert_eq!(
        t.cumsum(1)?.to_vec2::<f32>()?,
        [[3.0, 4.0, 8.0, 9.0, 14.0], [2.0, 3.0, 10.0, 18.0, 20.0]],
    );
    assert_eq!(
        t.cumsum(0)?.to_vec2::<f32>()?,
        [[3.0, 1.0, 4.0, 1.0, 5.0], [5.0, 2.0, 11.0, 9.0, 7.0]]
    );
    Ok(())
}
@@ -11,8 +11,8 @@ readme = "README.md"

[dependencies]
byteorder = { workspace = true }
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.1" }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
hf-hub = { workspace = true}
intel-mkl-src = { workspace = true, optional = true }
memmap2 = { workspace = true }
@@ -11,12 +11,12 @@ readme = "README.md"

[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.3.1" }
candle-nn = { path = "../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../candle-transformers", version = "0.3.1" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.1", optional = true }
candle-onnx = { path = "../candle-onnx", version = "0.3.1", optional = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
candle-onnx = { path = "../candle-onnx", version = "0.3.0", optional = true }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
image = { workspace = true }
@@ -57,7 +57,6 @@ flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
nccl = ["cuda", "cudarc/nccl", "dep:half"]
onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"]

[[example]]
name = "llama_multiprocess"
@@ -1,22 +0,0 @@
# candle-distilbert

DistilBert is a distilled version of the Bert model.

## Sentence embeddings

DistilBert is used to compute the sentence embeddings for a prompt. The model weights
are downloaded from the hub on the first run.

```bash
cargo run --example distilbert --release -- --prompt "Here is a test sentence"

> [[[ 0.5109,  0.1280, -0.2635, ...,  0.3462, -1.0434,  0.1441],
>   [ 0.1735,  0.0818, -0.5549, ...,  0.3472, -0.8264, -0.0244],
>   [ 0.0702, -0.1311, -0.4914, ...,  0.3483, -0.6194,  0.1829],
>   ...
>   [ 0.2993, -0.0106, -0.4640, ...,  0.2844, -0.6732,  0.0042],
>   [ 0.1066, -0.0081, -0.4299, ...,  0.3435, -0.7729,  0.0190],
>   [ 0.8903,  0.2055, -0.2541, ...,  0.3208, -0.6585,  0.0586]]]
> Tensor[[1, 7, 768], f32]
```
@@ -1,135 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle_transformers::models::distilbert::{Config, DistilBertModel, DTYPE};

use anyhow::{Error as E, Result};
use candle::{Device, Tensor};
use candle_nn::VarBuilder;
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
    #[arg(long)]
    model_id: Option<String>,

    #[arg(long)]
    revision: Option<String>,

    /// When set, compute embeddings for this prompt.
    #[arg(long)]
    prompt: String,

    /// Use the pytorch weights rather than the safetensors ones.
    #[arg(long)]
    use_pth: bool,

    /// The number of times to run the prompt.
    #[arg(long, default_value = "1")]
    n: usize,

    /// L2 normalization for embeddings.
    #[arg(long, default_value = "true")]
    normalize_embeddings: bool,
}

impl Args {
    fn build_model_and_tokenizer(&self) -> Result<(DistilBertModel, Tokenizer)> {
        let device = candle_examples::device(self.cpu)?;
        let default_model = "distilbert-base-uncased".to_string();
        let default_revision = "main".to_string();
        let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
            (Some(model_id), Some(revision)) => (model_id, revision),
            (Some(model_id), None) => (model_id, "main".to_string()),
            (None, Some(revision)) => (default_model, revision),
            (None, None) => (default_model, default_revision),
        };

        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
        let (config_filename, tokenizer_filename, weights_filename) = {
            let api = Api::new()?;
            let api = api.repo(repo);
            let config = api.get("config.json")?;
            let tokenizer = api.get("tokenizer.json")?;
            let weights = if self.use_pth {
                api.get("pytorch_model.bin")?
            } else {
                api.get("model.safetensors")?
            };
            (config, tokenizer, weights)
        };
        let config = std::fs::read_to_string(config_filename)?;
        let config: Config = serde_json::from_str(&config)?;
        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

        let vb = if self.use_pth {
            VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
        } else {
            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
        };
        let model = DistilBertModel::load(vb, &config)?;
        Ok((model, tokenizer))
    }
}

fn get_mask(size: usize, device: &Device) -> Tensor {
    let mask: Vec<_> = (0..size)
        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
        .collect();
    Tensor::from_slice(&mask, (size, size), device).unwrap()
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        println!("tracing...");
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
    let device = &model.device;

    let tokenizer = tokenizer
        .with_padding(None)
        .with_truncation(None)
        .map_err(E::msg)?;
    let tokens = tokenizer
        .encode(args.prompt, true)
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
    let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
    let mask = get_mask(tokens.len(), device);

    println!("token_ids: {:?}", token_ids.to_vec2::<u32>());
    println!("mask: {:?}", mask.to_vec2::<u8>());

    let ys = model.forward(&token_ids, &mask)?;
    println!("{ys}");

    Ok(())
}

pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
    Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
}
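For reference (my annotation), `normalize_l2` divides each row of the embedding tensor by its Euclidean norm along dimension 1:

```latex
\mathrm{normalize\_l2}(v)_{i,:} = \frac{v_{i,:}}{\sqrt{\sum_j v_{i,j}^2}}.
```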
@@ -53,8 +53,6 @@ enum Which {
    Zephyr7bAlpha,
    #[value(name = "7b-zephyr-b")]
    Zephyr7bBeta,
    #[value(name = "7b-open-chat-3.5")]
    OpenChat35,
}

impl Which {
@@ -69,10 +67,8 @@ impl Which {
            | Self::L7bCode
            | Self::L13bCode
            | Self::L34bCode => false,
            // Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
            // same way.
            Self::OpenChat35
            | Self::Zephyr7bAlpha
            // Zephyr is a fine tuned version of mistral and should be treated in the same way.
            Self::Zephyr7bAlpha
            | Self::Zephyr7bBeta
            | Self::Mistral7b
            | Self::Mistral7bInstruct => true,
@@ -91,30 +87,10 @@ impl Which {
            | Self::L13bCode
            | Self::L34bCode
            | Self::Mistral7b
            | Self::Mistral7bInstruct
            | Self::OpenChat35 => false,
            | Self::Mistral7bInstruct => false,
            Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
        }
    }

    fn is_open_chat(&self) -> bool {
        match self {
            Which::L7b
            | Which::L13b
            | Which::L70b
            | Which::L7bChat
            | Which::L13bChat
            | Which::L70bChat
            | Which::L7bCode
            | Which::L13bCode
            | Which::L34bCode
            | Which::Mistral7b
            | Which::Mistral7bInstruct
            | Which::Zephyr7bAlpha
            | Which::Zephyr7bBeta => false,
            Which::OpenChat35 => true,
        }
    }
}

#[derive(Parser, Debug)]
@@ -181,9 +157,7 @@ impl Args {
            Some(config) => std::path::PathBuf::from(config),
            None => {
                let api = hf_hub::api::sync::Api::new()?;
                let repo = if self.which.is_open_chat() {
                    "openchat/openchat_3.5"
                } else if self.which.is_mistral() {
                let repo = if self.which.is_mistral() {
                    "mistralai/Mistral-7B-v0.1"
                } else {
                    "hf-internal-testing/llama-tokenizer"
@@ -233,7 +207,6 @@ impl Args {
                    Which::Zephyr7bBeta => {
                        ("TheBloke/zephyr-7B-beta-GGUF", "zephyr-7b-beta.Q4_K_M.gguf")
                    }
                    Which::OpenChat35 => ("TheBloke/openchat_3.5-GGUF", "openchat_3.5.Q4_K_M.gguf"),
                };
                let api = hf_hub::api::sync::Api::new()?;
                let api = api.model(repo.to_string());
@@ -335,8 +308,7 @@ fn main() -> anyhow::Result<()> {
                | Which::Zephyr7bAlpha
                | Which::Zephyr7bBeta
                | Which::L70b
                | Which::L70bChat
                | Which::OpenChat35 => 8,
                | Which::L70bChat => 8,
            };
            ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
        }
@@ -368,9 +340,7 @@ fn main() -> anyhow::Result<()> {
                prompt.pop();
            }
        }
        if args.which.is_open_chat() {
            format!("User: {prompt}<|end_of_turn|>Assistant: ")
        } else if args.which.is_zephyr() {
        if args.which.is_zephyr() {
            if prompt_index == 0 || is_interactive {
                format!("<|system|>\n</s>\n<|user|>\n{prompt}</s>\n<|assistant|>",)
            } else {
@@ -420,12 +390,8 @@ fn main() -> anyhow::Result<()> {
        std::io::stdout().flush()?;
    }

    let eos_token = if args.which.is_open_chat() {
        "<|end_of_turn|>"
    } else {
        "</s>"
    };
    let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap();
    let eos_token = *tos.tokenizer().get_vocab(true).get("</s>").unwrap();

    let start_post_prompt = std::time::Instant::now();
    let mut sampled = 0;
    for index in 0..to_sample {
@@ -416,7 +416,7 @@ fn run(args: Args) -> Result<()> {

    println!("Building the autoencoder.");
    let vae_weights = ModelFile::Vae.get(vae_weights, sd_version, use_f16)?;
    let vae = sd_config.build_vae(vae_weights, &device, dtype)?;
    let vae = sd_config.build_vae(&vae_weights, &device, dtype)?;
    let init_latent_dist = match &img2img {
        None => None,
        Some(image) => {
@@ -426,7 +426,7 @@ fn run(args: Args) -> Result<()> {
    };
    println!("Building the unet.");
    let unet_weights = ModelFile::Unet.get(unet_weights, sd_version, use_f16)?;
    let unet = sd_config.build_unet(unet_weights, &device, 4, use_flash_attn, dtype)?;
    let unet = sd_config.build_unet(&unet_weights, &device, 4, use_flash_attn, dtype)?;

    let t_start = if img2img.is_some() {
        n_steps - (n_steps as f64 * img2img_strength) as usize
@@ -8,7 +8,7 @@ the model itself.
## Running an example

```bash
cargo run --example trocr --release -- --which base --cpu --image candle-examples/examples/trocr/assets/trocr.png
cargo run --example trocr --release -- --which base --cpu --image assets/trocr.png
```

```
@@ -128,13 +128,7 @@ impl Decoder {
        let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
        let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
        let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
        let no_speech_token = m::NO_SPEECH_TOKENS
            .iter()
            .find_map(|token| token_id(&tokenizer, token).ok());
        let no_speech_token = match no_speech_token {
            None => anyhow::bail!("unable to find any non-speech token"),
            Some(n) => n,
        };
        let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?;
        Ok(Self {
            model,
            rng: rand::rngs::StdRng::seed_from_u64(seed),
@@ -518,7 +512,11 @@ fn main() -> Result<()> {
        )
    } else {
        let config = repo.get("config.json")?;
        let tokenizer = repo.get("tokenizer.json")?;
        let tokenizer = if args.model == WhichModel::LargeV3 {
            panic!("openai/whisper-large-v3 does not provide a compatible tokenizer.json config at the moment")
        } else {
            repo.get("tokenizer.json")?
        };
        let model = repo.get("model.safetensors")?;
        (config, tokenizer, model)
    };
@@ -74,9 +74,9 @@ impl TextGeneration {
        std::io::stdout().flush()?;

        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
        let eos_token = match self.tokenizer.get_token("</s>") {
            Some(token) => token,
            None => anyhow::bail!("cannot find the <|endoftext|> token"),
            None => anyhow::bail!("cannot find the </s> token"),
        };
        let start_gen = std::time::Instant::now();
        for index in 0..sample_len {
@@ -43,7 +43,6 @@ pub fn report(
    confidence_threshold: f32,
    nms_threshold: f32,
) -> Result<DynamicImage> {
    let pred = pred.to_device(&Device::Cpu)?;
    let (npreds, pred_size) = pred.dims2()?;
    let nclasses = pred_size - 5;
    // The bounding boxes grouped by (maximum) class index.
@@ -32,7 +32,7 @@ Image source:
### Pose Estimation
```bash
cargo run --example yolo-v8 --release -- \
  candle-examples/examples/yolo-v8/assets/bike.jpg --task pose
  candle-examples/examples/yolo-v8/assets/peoples.jpeg --task pose
```
@@ -7,7 +7,7 @@ extern crate accelerate_src;
mod model;
use model::{Multiples, YoloV8, YoloV8Pose};

use candle::{DType, Device, IndexOp, Result, Tensor};
use candle::{DType, IndexOp, Result, Tensor};
use candle_nn::{Module, VarBuilder};
use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
use clap::{Parser, ValueEnum};
@@ -61,7 +61,6 @@ pub fn report_detect(
    nms_threshold: f32,
    legend_size: u32,
) -> Result<DynamicImage> {
    let pred = pred.to_device(&Device::Cpu)?;
    let (pred_size, npreds) = pred.dims2()?;
    let nclasses = pred_size - 4;
    // The bounding boxes grouped by (maximum) class index.
@@ -154,7 +153,6 @@ pub fn report_pose(
    confidence_threshold: f32,
    nms_threshold: f32,
) -> Result<DynamicImage> {
    let pred = pred.to_device(&Device::Cpu)?;
    let (pred_size, npreds) = pred.dims2()?;
    if pred_size != 17 * 3 + 4 + 1 {
        candle::bail!("unexpected pred-size {pred_size}");
@@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.3.1"
version = "0.3.0"
edition = "2021"

description = "Flash attention layer for the candle ML framework."
@@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"

[dependencies]
candle = { path = "../candle-core", features = ["cuda"], version = "0.3.1", package = "candle-core" }
candle = { path = "../candle-core", features = ["cuda"], version = "0.3.0", package = "candle-core" }
half = { version = "2.3.1", features = ["num-traits"] }

[build-dependencies]
@@ -21,4 +21,4 @@ rayon = "1.7.0"

[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
candle-nn = { path = "../candle-nn", version = "0.3.1", features = ["cuda"] }
candle-nn = { path = "../candle-nn", version = "0.3.0", features = ["cuda"] }
@@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.3.1"
version = "0.3.0"
edition = "2021"

description = "CUDA kernels for Candle"
@@ -14,4 +14,4 @@ license = "MIT OR Apache-2.0"
[build-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
glob = "0.3.1"
rayon = "1.7.0"
rayon = "1.7.0"
@@ -1,16 +1,17 @@
[package]
name = "candle-metal-kernels"
version = "0.3.1"
version = "0.3.0"
edition = "2021"

description = "Metal kernels for Candle"
description = "CUDA kernels for Candle"
repository = "https://github.com/huggingface/candle"
keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT OR Apache-2.0"

[dependencies]
metal = { version = "0.27.1", features = ["mps"], package="candle-metal" }
# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
metal = { path = "../../metal-rs", features = ["mps"] }
once_cell = "1.18.0"
thiserror = "1"
tracing = "0.1.37"
@@ -23,12 +23,12 @@ kernel void FN_NAME( \
    constant size_t &dim, \
    device const LEFT_TYPENAME *input, \
    device RIGHT_TYPENAME *output, \
    uint tid [[ thread_position_in_grid ]] \
    uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
    if (tid >= dim) { \
    if (thread_position_in_grid >= dim) { \
        return; \
    } \
    output[tid] = RIGHT_TYPENAME(input[tid]); \
    output[thread_position_in_grid] = RIGHT_TYPENAME(input[thread_position_in_grid]); \
} \
kernel void FN_NAME_STRIDED( \
    constant size_t &dim, \
@@ -37,17 +37,15 @@ kernel void FN_NAME_STRIDED( \
    constant size_t *strides, \
    device const LEFT_TYPENAME *input, \
    device RIGHT_TYPENAME *output, \
    uint tid [[ thread_position_in_grid ]] \
    uint i [[ thread_position_in_grid ]] \
) { \
    if (tid >= dim) { \
    if (i >= dim) { \
        return; \
    } \
    output[tid] = RIGHT_TYPENAME(input[get_strided_index(tid, num_dims, dims, strides)]); \
    output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \
} \

CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float)
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
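The `_strided` cast variants rely on `get_strided_index` (its body appears in the reduce kernels below). In effect it decomposes a linear index over a row-major shape into per-dimension coordinates and re-weights them by the strides (my gloss):

```latex
i \;\mapsto\; \sum_{d=0}^{D-1} c_d\,s_d,
\qquad (c_0, \dots, c_{D-1}) = \text{coordinates of } i \text{ in shape } (n_0, \dots, n_{D-1}).
```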
File diff suppressed because it is too large.
Binary file not shown.
@@ -1,8 +1,6 @@
#include <metal_stdlib>
using namespace metal;

#define MAX(x, y) ((x) > (y) ? (x) : (y))

METAL_FUNC uint get_strided_index(
    uint idx,
    constant size_t &num_dims,
@@ -18,18 +16,18 @@ METAL_FUNC uint get_strided_index(
    return strided_i;
}

constant int THREADGROUP_SIZE = 1024;
constant int THREADGROUP_SIZE = 256;

#define REDUCE(FN, NAME, T) \
#define REDUCE(FN, NAME, TYPENAME) \
kernel void NAME( \
    constant size_t &src_numel, \
    constant size_t &el_to_sum_per_block, \
    device const T *src, \
    device T *dst, \
    device const TYPENAME *src, \
    device TYPENAME *dst, \
    uint id [[ thread_position_in_grid ]], \
    uint tid [[ thread_index_in_threadgroup ]], \
    uint dst_id [[ threadgroup_position_in_grid ]], \
    uint block_dim [[ threads_per_threadgroup ]] \
    uint blockDim [[ threads_per_threadgroup ]] \
) { \
    \
    threadgroup float shared_memory[THREADGROUP_SIZE]; \
@@ -47,10 +45,10 @@ kernel void NAME( \
        // TODO: Fast version for the contiguous case. \
        // size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
        */ \
        T x = shared_memory[tid]; \
        T y = src[idx]; \
        TYPENAME x = shared_memory[tid]; \
        TYPENAME y = src[idx]; \
        shared_memory[tid] = FN; \
        idx += block_dim; \
        idx += blockDim; \
    } \
    \
    threadgroup_barrier(mem_flags::mem_none); \
@@ -58,10 +56,10 @@ kernel void NAME( \
    /* \
    // reduction in shared memory \
    */ \
    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
    for (uint s = blockDim / 2; s > 0; s >>= 1) { \
        if (tid < s) { \
            T x = shared_memory[tid]; \
            T y = shared_memory[tid + s]; \
            TYPENAME x = shared_memory[tid]; \
            TYPENAME y = shared_memory[tid + s]; \
            shared_memory[tid] = FN; \
        } \
        threadgroup_barrier(mem_flags::mem_none); \
@@ -70,74 +68,72 @@ kernel void NAME( \
    dst[dst_id] = shared_memory[0]; \
} \

kernel void softmax_float(
    constant size_t &src_numel,
    constant size_t &el_to_sum_per_block,
    device const float *src,
    device float *dst,
    uint id [[ thread_position_in_grid ]],
    uint tid [[ thread_index_in_threadgroup ]],
    uint dst_id [[ threadgroup_position_in_grid ]],
    uint blockDim [[ threads_per_threadgroup ]]
) {

    threadgroup float shared_memory[THREADGROUP_SIZE];

    shared_memory[tid] = -INFINITY;
    // Elements summed in this block range from dst_id * el_to_sum_per_block
    // to (dst_id + 1) * el_to_sum_per_block.
    size_t start_idx = dst_id * el_to_sum_per_block;
    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
    size_t idx = start_idx + tid;

    while (idx < stop_idx) {
        // TODO: Fast version for the contiguous case.
        shared_memory[tid] = max(shared_memory[tid], src[idx]);
        idx += blockDim;
    }

    threadgroup_barrier(mem_flags::mem_none);

    // reduction in shared memory
    for (uint s = blockDim / 2; s > 0; s >>= 1) {
        if (tid < s) {
            shared_memory[tid] = max(shared_memory[tid], shared_memory[tid + s]);
        }
        threadgroup_barrier(mem_flags::mem_none);
    }

    float max = shared_memory[0];

    shared_memory[tid] = 0;

    // Restart
    idx = start_idx + tid;
    while (idx < stop_idx) {
        // TODO: Fast version for the contiguous case.
        const float val = exp(src[idx] - max);
        dst[idx] = val;
        shared_memory[tid] += val;
        idx += blockDim;
    }
    // reduction in shared memory
    for (uint s = blockDim / 2; s > 0; s >>= 1) {
        if (tid < s) {
            shared_memory[tid] += shared_memory[tid + s];
        }
        threadgroup_barrier(mem_flags::mem_none);
    }

    const float inv_acc = 1/shared_memory[0];
    idx = start_idx + tid;
    while (idx < stop_idx) {
        dst[idx] *= inv_acc;
        idx += blockDim;
    }
}

REDUCE(x + y, fast_sum_float, float)
REDUCE(x * y, fast_mul_float, float)
REDUCE(max(x, y), fast_max_float, float)

#define SOFTMAX(NAME, T) \
kernel void NAME( \
    constant size_t &src_numel, \
    constant size_t &el_to_sum_per_block, \
    device const T *src, \
    device T *dst, \
    \
    uint id [[ thread_position_in_grid ]], \
    uint tid [[ thread_index_in_threadgroup ]], \
    uint dst_id [[ threadgroup_position_in_grid ]], \
    uint block_dim [[ threads_per_threadgroup ]] \
) { \
    threadgroup float shared_memory[THREADGROUP_SIZE]; \
    shared_memory[tid] = -INFINITY; \
    size_t start_idx = dst_id * el_to_sum_per_block; \
    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
    size_t idx = start_idx + tid; \
    \
    threadgroup_barrier(mem_flags::mem_threadgroup); \
    \
    while (idx < stop_idx) { \
        shared_memory[tid] = MAX(shared_memory[tid], src[idx]); \
        idx += block_dim; \
    } \
    \
    threadgroup_barrier(mem_flags::mem_threadgroup); \
    \
    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
        if (tid < s) { \
            shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
        } \
    } \
    \
    threadgroup_barrier(mem_flags::mem_threadgroup); \
    \
    float _max = shared_memory[0]; \
    \
    shared_memory[tid] = 0; \
    \
    idx = start_idx + tid; \
    while (idx < stop_idx) { \
        const T val = T(exp(src[idx] - _max)); \
        dst[idx] = val; \
        shared_memory[tid] += val; \
        idx += block_dim; \
    } \
    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
        if (tid < s) { \
            shared_memory[tid] += shared_memory[tid + s]; \
        } \
        threadgroup_barrier(mem_flags::mem_threadgroup); \
    } \
    \
    const T inv_acc = T(1/shared_memory[0]); \
    idx = start_idx + tid; \
    while (idx < stop_idx) { \
        dst[idx] *= inv_acc; \
        idx += block_dim; \
    } \
} \

SOFTMAX(softmax_float, float)
SOFTMAX(softmax_half, half)
#if __METAL_VERSION__ >= 310
SOFTMAX(softmax_bfloat, bfloat)
#endif
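Both the standalone `softmax_float` kernel and the `SOFTMAX` macro implement the numerically stable formulation (my annotation): a max-reduction first, exponentials of shifted values second, then normalization by the reduced sum.

```latex
\mathrm{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_k e^{x_k - m}},
\qquad m = \max_j x_j.
```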
@@ -32,9 +32,6 @@ kernel void FN_NAME( \
    device TYPENAME *out, \
    uint i [[ thread_position_in_grid ]] \
) { \
    if (i >= numel){ \
        return; \
    } \
    uint strided_i = get_strided_index(i, num_dims, dims, strides); \
    uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \
    uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \
@@ -1,211 +0,0 @@
import Metal
import MetalPerformanceShadersGraph

let type = MTLDataType.float;
let dataType = type;
var B = 2;
var M = 2;
var N = 4;
var K = 3;
var A_trans = false;
var B_trans = false;
var D_trans = false;
var alpha = Float(1.0);
var beta = Float(0.0);
var batched = B > 1;
var fused_activation = false;
var fused_bias = false;
let constants = MTLFunctionConstantValues()
constants.setConstantValue(&M, type: .uint, index: 0)
constants.setConstantValue(&N, type: .uint, index: 1)
constants.setConstantValue(&K, type: .uint, index: 2)
constants.setConstantValue(&A_trans, type: .bool, index: 10)
constants.setConstantValue(&B_trans, type: .bool, index: 11)
constants.setConstantValue(&D_trans, type: .bool, index: 13)
constants.setConstantValue(&alpha, type: .float, index: 20)
constants.setConstantValue(&beta, type: .float, index: 21)
constants.setConstantValue(&batched, type: .bool, index: 100)
constants.setConstantValue(&fused_activation, type: .bool, index: 101)
constants.setConstantValue(&fused_bias, type: .bool, index: 50001)

var M_simd = UInt16(16)
var N_simd = UInt16(16)
var K_simd = UInt16(32)
var M_splits = UInt16(2)
var N_splits = UInt16(2)
constants.setConstantValue(&M_simd, type: .ushort, index: 200)
constants.setConstantValue(&N_simd, type: .ushort, index: 201)
constants.setConstantValue(&K_simd, type: .ushort, index: 202)
constants.setConstantValue(&M_splits, type: .ushort, index: 210)
constants.setConstantValue(&N_splits, type: .ushort, index: 211)

let M_group = M_simd * M_splits
let N_group = N_simd * N_splits

// Satisfy Metal API validation.
#if DEBUG
do {
    var garbage: SIMD4<UInt64> = .zero
    constants.setConstantValue(&garbage, type: .bool, index: 102)
    constants.setConstantValue(&garbage, type: .bool, index: 103)
    constants.setConstantValue(&garbage, type: .bool, index: 113)
    constants.setConstantValue(&garbage, type: .bool, index: 50000)
}
#endif
print(constants)

let device = MTLCopyAllDevices().first!
device.shouldMaximizeConcurrentCompilation = true

var libraryURL = URL.init(string: "/Users/nicolas/src/candle/candle-metal-kernels/")!;
libraryURL.append(component: "src")
libraryURL.append(component: "libMetalFlashAttention.metallib")
let library = try! device.makeLibrary(URL: libraryURL)

var name: String
switch dataType {
case .half: name = "hgemm"
case .float: name = "sgemm"
default: fatalError()
}
let function = try! library.makeFunction(
    name: name, constantValues: constants)

let A_block_length = M_group * K_simd
let B_block_length = K_simd * N_group

var blockElements = A_block_length + B_block_length;
if (M % 8 != 0) && (N % 8 != 0) {
    let C_block_length = M_group * N_group;
    blockElements = max(C_block_length, blockElements)
}
if fused_bias {
    if D_trans {
        blockElements = max(blockElements, M_group)
    } else {
        blockElements = max(blockElements, N_group)
    }
}
// let blockBytes = blockElements * UInt16(dataType.size)
let elementSize = 4
let blockBytes = blockElements * UInt16(elementSize)

func ceilDivide(target: Int, granularity: UInt16) -> Int {
    (target + Int(granularity) - 1) / Int(granularity)
}
var gridSize = MTLSize(
    width: ceilDivide(target: N, granularity: N_group),
    height: ceilDivide(target: M, granularity: M_group),
    depth: 1)
let groupSize = MTLSize(
    width: Int(32 * M_splits * N_splits),
    height: 1,
    depth: 1)

let commandQueue = device.makeCommandQueue()!
let commandBuffer = commandQueue.makeCommandBuffer()!
let encoder = commandBuffer.makeComputeCommandEncoder(dispatchType: MTLDispatchType.serial)!
let pipeline = try device.makeComputePipelineState(function: function)

let threadgroupMemoryLength = blockBytes;
print(threadgroupMemoryLength)
encoder.setComputePipelineState(pipeline)
encoder.setThreadgroupMemoryLength(Int(threadgroupMemoryLength), index: 0)

let rowsA = M;
let columnsA = K;
let rowsB = K;
let columnsB = N;
let rowsC = M;
let columnsC = N;
var arrayA = [Float](repeating: 0, count: B * rowsA * columnsA)

var arrayB = [Float](repeating: 0, count: B * rowsB * columnsB)

var arrayC = [Float](repeating: 0, count: B * rowsC * columnsC)
for i in 0..<arrayA.count {
    arrayA[i] = Float(i)
}

for i in 0..<arrayB.count {
    arrayB[i] = Float(i)
}

let bufferA = device.makeBuffer(bytes: arrayA, length: B * rowsA * columnsA * MemoryLayout<Float>.stride, options: [])

let bufferB = device.makeBuffer(bytes: arrayB, length: B * rowsB * columnsB * MemoryLayout<Float>.stride, options: [])

let bufferC = device.makeBuffer(length: B * rowsC * columnsC * MemoryLayout<Float>.stride, options: [])

print(arrayA)
print(arrayB)

encoder.setBuffer(bufferA, offset: 0, index: 0)
encoder.setBuffer(bufferB, offset: 0, index: 1)
encoder.setBuffer(bufferC, offset: 0, index: 2)
var gridZ: Int = B
if batched {
    func byteStride(shape: [Int]) -> Int {
        let rank = shape.count
        var output = elementSize * shape[rank - 2] * shape[rank - 1]
        if shape.dropLast(2).reduce(1, *) == 1 {
            output = 0
        }
        return output
    }
    let byteStrideA = M*K*elementSize
    let byteStrideB = N*K*elementSize
    let byteStrideC = M*N*elementSize

    let byteStrideD = 0
    // if let shapeD = tensors.d?.shape {
    //     let rank = shapeD.count
    //     byteStrideD = elementSize * shapeD[rank - 1]
    //     if shapeD.dropLast(1).reduce(1, *) == 1 {
    //         byteStrideD = 0
    //     }
    // }
    withUnsafeTemporaryAllocation(
        of: SIMD4<UInt64>.self, capacity: gridZ
    ) { buffer in
        for i in 0..<buffer.count {
            buffer[i] = SIMD4(
                UInt64(truncatingIfNeeded: i * byteStrideA),
                UInt64(truncatingIfNeeded: i * byteStrideB),
                UInt64(truncatingIfNeeded: i * byteStrideC),
                UInt64(truncatingIfNeeded: i * byteStrideD))
        }

        let bufferLength = buffer.count * MemoryLayout<SIMD4<UInt64>>.stride
        assert(MemoryLayout<SIMD4<UInt64>>.stride == 8 * 4)
        encoder.setBytes(buffer.baseAddress!, length: bufferLength, index: 10)
        print("BATCHED")
        print(buffer)
    }
}
gridSize.depth = gridZ

print(gridSize, groupSize)
encoder.dispatchThreadgroups(
    gridSize, threadsPerThreadgroup: groupSize
)
encoder.endEncoding()
commandBuffer.commit()

commandBuffer.waitUntilCompleted()
var contents = bufferC!.contents();

var count = B * rowsA * columnsB;

var typedPointer = contents.bindMemory(to: Float.self, capacity: count)

var bufferedPointer = UnsafeBufferPointer(start: typedPointer, count: count)

print(Array(bufferedPointer))
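The dispatch geometry in the Swift harness above uses the standard ceiling-division identity (my annotation):

```latex
\mathrm{ceilDivide}(t, g) = \left\lceil \tfrac{t}{g} \right\rceil
  = \left\lfloor \tfrac{t + g - 1}{g} \right\rfloor,
\qquad \mathrm{grid} = \bigl(\lceil N/N_{\mathrm{group}} \rceil,\ \lceil M/M_{\mathrm{group}} \rceil,\ B\bigr).
```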
@ -1,800 +0,0 @@
|
||||
use super::*;
|
||||
use half::{bf16, f16};
|
||||
use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger};
|
||||
|
||||
fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
let ptr = data.as_ptr() as *const core::ffi::c_void;
|
||||
let size = (data.len() * std::mem::size_of::<T>()) as u64;
|
||||
device.new_buffer_with_data(ptr, size, options)
|
||||
}
|
||||
|
||||
fn device() -> Device {
|
||||
Device::system_default().unwrap()
|
||||
}
|
||||
|
||||
fn approx(v: Vec<f32>, digits: i32) -> Vec<f32> {
|
||||
let b = 10f32.powi(digits);
|
||||
v.iter().map(|t| f32::round(t * b) / b).collect()
|
||||
}
|
||||
|
||||
fn approx_f16(v: Vec<f16>, digits: i32) -> Vec<f32> {
|
||||
let b = 10f32.powi(digits);
|
||||
v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
|
||||
}
|
||||
|
||||
fn approx_bf16(v: Vec<bf16>, digits: i32) -> Vec<f32> {
|
||||
let b = 10f32.powi(digits);
|
||||
v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
|
||||
}
|
||||
|
||||
fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
|
||||
let device = device();
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let input = new_buffer(&device, v);
|
||||
let output = new_buffer(&device, v);
|
||||
call_unary_contiguous(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
name,
|
||||
v.len(),
|
||||
&input,
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
output.read_to_vec::<T>(v.len())
|
||||
}
|
||||
|
||||
fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> {
|
||||
let device = device();
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
let left = new_buffer(&device, x);
|
||||
let right = new_buffer(&device, y);
|
||||
let output = device.new_buffer(std::mem::size_of_val(x) as u64, options);
|
||||
call_binary_contiguous(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
name,
|
||||
x.len(),
|
||||
&left,
|
||||
&right,
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
output.read_to_vec::<T>(x.len())
|
||||
}
|
||||
|
||||
fn run_strided<T: Clone>(
|
||||
v: &[T],
|
||||
kernel: unary::strided::Kernel,
|
||||
shape: &[usize],
|
||||
strides: &[usize],
|
||||
offset: usize,
|
||||
) -> Vec<T> {
|
||||
let device = device();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let input = new_buffer(&device, v);
|
||||
let output = new_buffer(&device, v);
|
||||
let kernels = Kernels::new();
|
||||
call_unary_strided(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
kernel,
|
||||
shape,
|
||||
&input,
|
||||
strides,
|
||||
offset,
|
||||
&output,
|
||||
0,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
output.read_to_vec::<T>(v.len())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cos_f32() {
|
||||
let v = vec![1.0f32, 2.0, 3.0];
|
||||
let results = run(&v, unary::contiguous::cos::FLOAT);
|
||||
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
|
||||
assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99]);
|
||||
assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]);
|
||||
|
||||
let v = vec![1.0f32; 10_000];
|
||||
let results = run(&v, unary::contiguous::cos::FLOAT);
|
||||
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
|
||||
assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
|
||||
assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
|
||||
}
|
||||
|
||||
#[test]
fn cos_f32_strided() {
    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let shape = vec![6];
    let strides = vec![1];
    let offset = 0;
    let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
    let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
    assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]);
    assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]);

    // Contiguous
    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let shape = vec![3, 2];
    let strides = vec![2, 1];
    let offset = 0;
    let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
    let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
    assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]);
    assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]);

    // Transposed
    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let shape = vec![3, 2];
    let strides = vec![1, 3];
    let offset = 0;
    let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
    let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
    assert_eq!(approx(results, 4), vec![0.5403, -0.6536, -0.4161, 0.2837, -0.99, 0.9602]);
    assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]);

    // Very large
    let v = vec![1.0f32; 10_000];
    let shape = vec![2, 5_000];
    let strides = vec![2, 1];
    let offset = 0;
    let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
    let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
    assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
    assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
}

#[test]
fn cos_strided_random() {
    let v: Vec<_> = (0..10_000).map(|_| rand::random::<f32>()).collect();
    let shape = vec![5_000, 2];
    let strides = vec![1, 5_000];
    let offset = 0;
    let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
    let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
    assert_eq!(approx(vec![results[0]], 4), approx(vec![expected[0]], 4));
    assert_eq!(approx(vec![results[1]], 4), approx(vec![expected[5_000]], 4));
    assert_eq!(approx(vec![results[2]], 4), approx(vec![expected[1]], 4));
    assert_eq!(approx(vec![results[3]], 4), approx(vec![expected[5_001]], 4));
    assert_eq!(approx(vec![results[5_000]], 4), approx(vec![expected[2_500]], 4));
}

#[test]
fn binary_add_f32() {
    let left = vec![1.0f32, 2.0, 3.0];
    let right = vec![2.0f32, 3.1, 4.2];
    let results = run_binary(&left, &right, binary::contiguous::add::FLOAT);
    let expected: Vec<_> = left.iter().zip(right.iter()).map(|(&x, &y)| x + y).collect();
    assert_eq!(approx(results, 4), vec![3.0f32, 5.1, 7.2]);
    assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]);
}

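// Hypothetical companion to `binary_add_f32` (not in the original tests),
// assuming a `binary::contiguous::mul::FLOAT` kernel is generated with the
// same calling convention as `add`.
#[test]
fn binary_mul_f32_sketch() {
    let left = vec![1.0f32, 2.0, 3.0];
    let right = vec![2.0f32, 3.0, 4.0];
    let results = run_binary(&left, &right, binary::contiguous::mul::FLOAT);
    assert_eq!(approx(results, 4), vec![2.0f32, 6.0, 12.0]);
}
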
fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
    let device = device();
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let input = new_buffer(&device, v);
    let options = MTLResourceOptions::StorageModeManaged;
    let size = (v.len() * std::mem::size_of::<U>()) as u64;
    let output = device.new_buffer(size, options);

    call_cast_contiguous(&device, command_buffer, &kernels, name, v.len(), &input, 0, &output)
        .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();
    output.read_to_vec::<U>(v.len())
}

#[test]
fn cast_u32_f32() {
    let v = vec![1u32, 2, 3];
    let results = cast(&v, "cast_u32_f32");
    let expected: Vec<_> = v.iter().map(|&v| v as f32).collect();
    assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]);
    assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]);

    let v = vec![1.0f32, 2.0, 3.0];
    let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
    let results: Vec<f32> = cast(&input, "cast_f16_f32");
    assert_eq!(results, vec![1.0f32, 2.0, 3.0]);

    let v = vec![1.0f32; 10_000];
    let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
    let results: Vec<f32> = cast(&input, "cast_f16_f32");
    assert_eq!(results.len(), 10_000);
    assert_eq!(&results[..10], vec![1.0f32; 10]);
    assert_eq!(results, vec![1.0f32; 10_000]);
}

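// Hypothetical reverse cast (not in the original tests), assuming a
// "cast_f32_f16" kernel exists alongside the "cast_u32_f32" and
// "cast_f16_f32" kernels used above.
#[test]
fn cast_f32_f16_sketch() {
    let v = vec![1.0f32, 2.0, 3.0];
    let results: Vec<f16> = cast(&v, "cast_f32_f16");
    assert_eq!(approx_f16(results, 4), vec![1.0f32, 2.0, 3.0]);
}
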
fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
    let device = device();
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();

    let input = new_buffer(&device, v);
    let output = new_buffer(&device, v);

    let size = v.len();

    call_affine(
        &device, command_buffer, &kernels, "affine_float", size, &input, &output, mul as f32,
        add as f32,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();

    output.read_to_vec::<T>(v.len())
}

fn run_affine_strided<T: Clone>(
    v: &[T],
    shape: &[usize],
    strides: &[usize],
    mul: f64,
    add: f64,
) -> Vec<T> {
    let device = device();
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();

    let input = new_buffer(&device, v);
    let output = new_buffer(&device, v);

    call_affine_strided(
        &device, command_buffer, &kernels, "affine_float_strided", shape, &input, strides, 0,
        &output, mul as f32, add as f32,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();

    let len: usize = shape.iter().product();
    output.read_to_vec::<T>(len)
}

#[test]
fn affine() {
    let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let mul = 1.5;
    let add = 1.1;
    let result = run_affine(&input, mul, add);
    assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]);

    let input = [1.0f32; 40_000];
    let mul = 1.5;
    let add = 1.1;
    let result = run_affine(&input, mul, add);
    assert_eq!(result, vec![2.6; 40_000]);
}

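// CPU reference (not in the original tests) for the affine kernel checked
// above: every element maps to `x * mul + add`, e.g. 1.0 * 1.5 + 1.1 == 2.6.
fn affine_cpu(v: &[f32], mul: f32, add: f32) -> Vec<f32> {
    v.iter().map(|&x| x * mul + add).collect()
}
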
#[test]
fn affine_strided() {
    let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let mul = 1.5;
    let add = 1.1;
    let shape = [4];
    let strides = [2];
    let result = run_affine_strided(&input, &shape, &strides, mul, add);
    // every other input element (stride 2)
    assert_eq!(result, vec![2.6, 5.6, 8.6, 11.6]);
}

#[test]
fn index_select() {
    let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
    let shape = [5, 2];
    let ids = [0u32, 4, 2];
    let dim = 0;
    let result = run_index_select(&embedding, &shape, &ids, dim);
    assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]);

    let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
    let shape = [2, 5];
    let ids = [0u32, 1, 0];
    let dim = 0;
    let result = run_index_select(&embedding, &shape, &ids, dim);
    assert_eq!(
        result,
        vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0, 2.0, 3.0, 4.0, 5.0]
    );
}

#[test]
fn index_select_f16() {
    let embedding: Vec<_> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
        .into_iter()
        .map(f16::from_f32)
        .collect();
    let shape = [5, 2];
    let ids = [0u32, 4, 2];
    let dim = 0;
    let result = run_index_select(&embedding, &shape, &ids, dim);
    assert_eq!(approx_f16(result, 4), vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]);
}

#[test]
fn index_select_dim1() {
    let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
    let shape = [5, 2];
    let ids = [0u32, 1, 0];
    let dim = 1;
    let result = run_index_select(&embedding, &shape, &ids, dim);
    assert_eq!(
        result,
        vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0, 7.0, 9.0, 10.0, 9.0]
    );
}

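// CPU reference (not in the original tests) for `index_select` on dim 0 of a
// [rows, cols] table: each id picks out one full row.
fn index_select_dim0_cpu(embedding: &[f32], cols: usize, ids: &[u32]) -> Vec<f32> {
    ids.iter()
        .flat_map(|&id| embedding[id as usize * cols..(id as usize + 1) * cols].to_vec())
        .collect()
}
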
fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
    embeddings: &[T],
    shape: &[usize],
    ids: &[I],
    dim: usize,
) -> Vec<T> {
    let device = Device::system_default().expect("no device found");

    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let embeddings_buffer = new_buffer(&device, embeddings);
    let ids_buffer = new_buffer(&device, ids);

    let left_size: usize = shape[..dim].iter().product();
    let right_size: usize = shape[dim + 1..].iter().product();
    let dst_el = ids.len() * left_size * right_size;
    let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);

    let name = match core::mem::size_of::<T>() {
        4 => "is_u32_f32",
        2 => "is_u32_f16",
        _ => unimplemented!(),
    };

    let kernels = Kernels::new();
    call_index_select(
        &device, &command_buffer, &kernels, name, shape, ids.len(), dim, &embeddings_buffer,
        &ids_buffer, &dst_buffer,
    )
    .unwrap();

    command_buffer.commit();
    command_buffer.wait_until_completed();

    dst_buffer.read_to_vec::<T>(dst_el)
}

#[test]
fn index_add() {
    let device = Device::system_default().expect("no device found");

    let options = CompileOptions::new();
    let library = device.new_library_with_source(INDEXING, &options).unwrap();

    let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
    let right = [1.0f32; 15];
    let index = [0u32, 4, 2];
    let ids_dim_size = index.len() as u32;
    let dst_dim_size: u32 = 15;
    let left_size: u32 = 3;
    let right_size: u32 = 3;

    let function = library.get_function("ia_u32_f32", None).unwrap();
    let pipeline = device
        .new_compute_pipeline_state_with_function(&function)
        .unwrap();

    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let encoder = command_buffer.new_compute_command_encoder();

    encoder.set_compute_pipeline_state(&pipeline);

    let index_buffer = new_buffer(&device, &index);
    let inputs_buffer = new_buffer(&device, &left);
    let outputs_buffer = new_buffer(&device, &right);

    set_params!(
        encoder,
        (
            &index_buffer, &inputs_buffer, &outputs_buffer, ids_dim_size, left_size,
            dst_dim_size, right_size
        )
    );

    let grid_size = MTLSize {
        width: right.len() as NSUInteger,
        height: 1,
        depth: 1,
    };

    let thread_group_size = MTLSize {
        width: pipeline.max_total_threads_per_threadgroup(),
        height: 1,
        depth: 1,
    };

    encoder.dispatch_thread_groups(grid_size, thread_group_size);
    encoder.end_encoding();
    command_buffer.commit();
    command_buffer.wait_until_completed();

    let expected = vec![
        2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
    ];
    let result = outputs_buffer.read_to_vec::<f32>(right.len());
    assert_eq!(result, expected);
}

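// CPU sketch (not in the original tests) of the `ia_u32_f32` semantics
// verified above: source row j is added into destination row ids[j].
fn index_add_dim0_cpu(dst: &mut [f32], src: &[f32], ids: &[u32], row_size: usize) {
    for (j, &id) in ids.iter().enumerate() {
        for c in 0..row_size {
            dst[id as usize * row_size + c] += src[j * row_size + c];
        }
    }
}
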
#[test]
fn cos_f16() {
    let v: Vec<f16> = [1.0f32, 2.0, 3.0].iter().map(|v| f16::from_f32(*v)).collect();
    let results = run(&v, unary::contiguous::cos::HALF);
    let expected: Vec<f16> = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect();
    assert_eq!(approx_f16(results, 4), vec![0.5405, -0.4163, -0.9902]);
    assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]);
}

fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
    let device = device();
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let input = new_buffer(&device, v);

    let options = MTLResourceOptions::StorageModeManaged;
    let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
    call_reduce_contiguous(
        &device, command_buffer, &kernels, name, v.len(), out_length, &input, 0, &output,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();

    output.read_to_vec::<T>(out_length)
}

fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'static str) -> Vec<T> {
    let device = device();
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let input = new_buffer(&device, v);
    let output = new_buffer(&device, v);
    call_last_softmax(
        &device, command_buffer, &kernels, name, v.len(), last_dim, &input, &output,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();

    output.read_to_vec::<T>(v.len())
}

#[test]
fn reduce_sum() {
    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let out_length = 1;

    let results = run_reduce(&v, out_length, "fast_sum_float");
    assert_eq!(approx(results, 4), vec![21.0]);
}

#[test]
fn reduce_sum2() {
    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let out_length = 2;

    let results = run_reduce(&v, out_length, "fast_sum_float");
    assert_eq!(approx(results, 4), vec![6.0, 15.0]);
}

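// CPU reference (not in the original tests): with `out_length == n`, the
// reduction above behaves like summing n equally sized contiguous chunks.
fn chunked_sum_cpu(v: &[f32], out_length: usize) -> Vec<f32> {
    v.chunks(v.len() / out_length)
        .map(|c| c.iter().sum::<f32>())
        .collect()
}
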
#[test]
fn softmax() {
    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let last_dim = 6;
    let results = run_softmax(&v, last_dim, "softmax_float");
    assert_eq!(approx(results, 4), vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]);

    let v = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0];
    let last_dim = 6;
    let results = run_softmax(&v, last_dim, "softmax_float");
    assert_eq!(approx(results, 4), vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]);

    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let last_dim = 3;
    let results = run_softmax(&v, last_dim, "softmax_float");
    assert_eq!(approx(results, 4), vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652]);

    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
        .iter()
        .map(|v| f16::from_f32(*v))
        .collect::<Vec<_>>();
    let last_dim = 6;
    let results = run_softmax(&v, last_dim, "softmax_half");
    assert_eq!(approx_f16(results, 4), vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338]);

    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
        .iter()
        .map(|v| bf16::from_f32(*v))
        .collect::<Vec<_>>();
    let last_dim = 6;
    let results = run_softmax(&v, last_dim, "softmax_bfloat");
    assert_eq!(approx_bf16(results, 4), vec![0.0043, 0.0116, 0.0315, 0.0859, 0.2324, 0.6328]);
}

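// CPU reference (not in the original tests) for the last-dim softmax checked
// above, with the usual max subtraction for numerical stability.
fn softmax_cpu(v: &[f32], last_dim: usize) -> Vec<f32> {
    v.chunks(last_dim)
        .flat_map(|row| {
            let max = row.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
            let exp: Vec<f32> = row.iter().map(|&x| (x - max).exp()).collect();
            let sum: f32 = exp.iter().sum();
            exp.into_iter().map(move |e| e / sum).collect::<Vec<_>>()
        })
        .collect()
}
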
fn run_where_cond<I: Clone, T: Clone>(
    shape: &[usize],
    cond: &[I],
    (cond_stride, cond_offset): (Vec<usize>, usize),
    left_true: &[T],
    (left_stride, left_offset): (Vec<usize>, usize),
    right_false: &[T],
    (right_stride, right_offset): (Vec<usize>, usize),
    name: &'static str,
) -> Vec<T> {
    let device = device();
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let options = MTLResourceOptions::StorageModeManaged;

    let length = cond.len();
    let cond = device.new_buffer_with_data(
        cond.as_ptr() as *const core::ffi::c_void,
        std::mem::size_of_val(cond) as u64,
        options,
    );
    let left = device.new_buffer_with_data(
        left_true.as_ptr() as *const core::ffi::c_void,
        (length * core::mem::size_of::<T>()) as u64,
        options,
    );
    let right = device.new_buffer_with_data(
        right_false.as_ptr() as *const core::ffi::c_void,
        (length * core::mem::size_of::<T>()) as u64,
        options,
    );

    let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
    call_where_cond_strided(
        &device,
        command_buffer,
        &kernels,
        name,
        shape,
        &cond,
        (&cond_stride, cond_offset),
        &left,
        (&left_stride, left_offset),
        &right,
        // bug fix: use the right operand's own layout rather than the
        // condition's (the two coincide in the test below, which hid this)
        (&right_stride, right_offset),
        &output,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();

    output.read_to_vec::<T>(length)
}

#[test]
fn where_cond() {
    let shape = vec![6];
    let cond = vec![0u8, 1, 0, 0, 1, 1];
    let cond_l = (vec![1], 0);
    let left_true = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let left_l = (vec![1], 0);
    let right_false = vec![-1.0f32, -2.0, -3.0, -4.0, -5.0, -6.0];
    let right_l = (vec![1], 0);
    let results = run_where_cond(
        &shape, &cond, cond_l, &left_true, left_l, &right_false, right_l, "where_u8_f32",
    );
    assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
}

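// CPU reference (not in the original tests) for the contiguous case checked
// above: take from `t` where the condition is non-zero, else from `f`.
fn where_cond_cpu(cond: &[u8], t: &[f32], f: &[f32]) -> Vec<f32> {
    cond.iter()
        .zip(t.iter().zip(f.iter()))
        .map(|(&c, (&tv, &fv))| if c != 0 { tv } else { fv })
        .collect()
}
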
fn run_gemm<T: Clone>(
    (b, m, n, k): (usize, usize, usize, usize),
    lhs: &[T],
    lhs_stride: Vec<usize>,
    rhs: &[T],
    rhs_stride: Vec<usize>,
) -> Vec<T> {
    let device = device();
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let options = MTLResourceOptions::StorageModeManaged;

    let lhs = device.new_buffer_with_data(
        lhs.as_ptr() as *const core::ffi::c_void,
        std::mem::size_of_val(lhs) as u64,
        options,
    );
    let rhs = device.new_buffer_with_data(
        rhs.as_ptr() as *const core::ffi::c_void,
        std::mem::size_of_val(rhs) as u64,
        options,
    );
    let length = b * m * n;
    let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
    call_gemm(
        &device, command_buffer, &kernels, "sgemm", (b, m, n, k), &lhs_stride, 0, &lhs,
        &rhs_stride, 0, &rhs, &output,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();

    output.read_to_vec::<T>(length)
}

#[test]
fn gemm() {
    let (b, m, n, k) = (1, 2, 4, 3);
    let lhs_stride = vec![m * k, k, 1];
    let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
    let rhs_stride = vec![n * k, n, 1];
    let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
    let results = run_gemm((b, m, n, k), &lhs, lhs_stride, &rhs, rhs_stride);
    assert_eq!(approx(results, 4), vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]);

    let (b, m, n, k) = (2, 2, 4, 3);
    let lhs_stride = vec![m * k, k, 1];
    let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
    let rhs_stride = vec![n * k, n, 1];
    let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
    let results = run_gemm((b, m, n, k), &lhs, lhs_stride, &rhs, rhs_stride);
    assert_eq!(
        approx(results, 4),
        vec![
            20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0, 344.0, 365.0, 386.0, 407.0, 488.0,
            518.0, 548.0, 578.0
        ]
    );
}

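// Naive CPU reference (not in the original tests) matching the row-major,
// contiguous strides used above: out[b, i, j] = sum_k lhs[b, i, k] * rhs[b, k, j].
fn gemm_cpu(b: usize, m: usize, n: usize, k: usize, lhs: &[f32], rhs: &[f32]) -> Vec<f32> {
    let mut out = vec![0.0f32; b * m * n];
    for bi in 0..b {
        for i in 0..m {
            for j in 0..n {
                let mut acc = 0.0f32;
                for ki in 0..k {
                    acc += lhs[bi * m * k + i * k + ki] * rhs[bi * k * n + ki * n + j];
                }
                out[bi * m * n + i * n + j] = acc;
            }
        }
    }
    out
}
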
@@ -11,7 +11,7 @@ readme = "README.md"

[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
half = { workspace = true }
thiserror = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
@@ -19,7 +19,6 @@ num-traits = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }

[dev-dependencies]
anyhow = { workspace = true }
@@ -30,4 +29,3 @@ default = []
accelerate = ["dep:accelerate-src", "candle/accelerate"]
cuda = ["candle/cuda"]
mkl = ["dep:intel-mkl-src", "candle/mkl"]
metal = ["candle/metal", "dep:candle-metal-kernels"]

@@ -6,7 +6,7 @@ extern crate intel_mkl_src;
extern crate accelerate_src;

use candle::quantized::GgmlType;
use candle::{CpuStorage, Device, Layout, Module, Result, Shape, Tensor, D};
use candle::{CpuStorage, Device, Layout, Result, Shape, Tensor, D};
use clap::{Parser, Subcommand};

const CHECK_CONV2D: bool = false;

@@ -201,46 +201,6 @@ impl candle::CustomOp1 for SoftmaxLastDim {
        };
        Ok((dst, layout.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        storage: &candle::MetalStorage,
        layout: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::{backend::BackendStorage, DType};
        let device = storage.device();
        let command_buffer = device.command_buffer();
        let kernels = device.kernels();
        let name = match storage.dtype() {
            DType::F32 => "softmax_float",
            DType::F16 => "softmax_half",
            DType::BF16 => "softmax_bfloat",
            dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"),
        };

        let n = layout.stride().len();
        if !(layout.stride()[n - 1] == 1 && layout.start_offset() == 0) {
            candle::bail!("Non contiguous softmax-last-dim is not implemented");
        }

        let last_dim = layout.dims()[layout.shape().rank() - 1];
        let elem_count = layout.shape().elem_count();
        let mut output = device.new_buffer(elem_count, storage.dtype());
        candle_metal_kernels::call_last_softmax(
            device.metal_device(), &command_buffer, &kernels, name, elem_count, last_dim,
            storage.buffer(), &mut output,
        )
        .unwrap();
        let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
        Ok((newstorage, layout.shape().clone()))
    }
}

pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {

@@ -1,6 +1,6 @@
[package]
name = "candle-onnx"
version = "0.3.1"
version = "0.3.0"
edition = "2021"

description = "ONNX support for Candle"
@@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0"

[dependencies]
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.1" }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
prost = "0.12.1"

[build-dependencies]

@@ -741,25 +741,6 @@ pub fn simple_eval(
                let output = input.to_dtype(dtype)?;
                values.insert(node.output[0].clone(), output);
            }
            // https://github.com/onnx/onnx/blob/main/docs/Operators.md#CumSum
            "CumSum" => {
                let exclusive = get_attr_opt::<i64>(node, "exclusive")?.copied().unwrap_or(0);
                let reverse = get_attr_opt::<i64>(node, "reverse")?.copied().unwrap_or(0);
                if exclusive != 0 {
                    bail!("only exclusive == 0 is supported in CumSum")
                }
                if reverse != 0 {
                    bail!("only reverse == 0 is supported in CumSum")
                }
                let input = get(&node.input[0])?;
                let axis = get(&node.input[1])?.to_dtype(DType::U32)?.to_vec0::<u32>()?;
                let output = input.cumsum(axis as usize)?;
                values.insert(node.output[0].clone(), output);
            }
            op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
        }
    }

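// Sketch (not in the original file) of the inclusive, forward CumSum
// semantics supported above (exclusive == 0, reverse == 0), on a 1-D slice:
// out[i] = x[0] + ... + x[i].
fn cumsum_1d(xs: &[f32]) -> Vec<f32> {
    xs.iter()
        .scan(0.0f32, |acc, &x| {
            *acc += x;
            Some(*acc)
        })
        .collect()
}
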
@@ -1,746 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle::{Device, Result, Tensor};
use candle_onnx::onnx::{GraphProto, ModelProto, NodeProto, ValueInfoProto};
use std::collections::HashMap;

const INPUT_X: &str = "x";
const INPUT_Y: &str = "y";
const OUTPUT_Z: &str = "z";

fn create_model_proto_with_graph(graph: Option<GraphProto>) -> ModelProto {
    ModelProto {
        metadata_props: vec![],
        training_info: vec![],
        functions: vec![],
        ir_version: 0,
        opset_import: vec![],
        producer_name: "".to_string(),
        producer_version: "".to_string(),
        domain: "".to_string(),
        model_version: 0,
        doc_string: "".to_string(),
        graph,
    }
}

// Shared scaffolding for the single-node graphs used by all the tests below:
// builds `z = op(node_inputs)` with the given graph-level input protos.
fn create_single_node_graph(
    op_type: &str,
    node_inputs: &[&str],
    graph_inputs: &[&str],
) -> ModelProto {
    let value_info = |name: &str| ValueInfoProto {
        name: name.to_string(),
        doc_string: "".to_string(),
        r#type: None,
    };
    create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: op_type.to_string(),
            domain: "".to_string(),
            attribute: vec![],
            input: node_inputs.iter().map(|s| s.to_string()).collect(),
            output: vec![OUTPUT_Z.to_string()],
            name: "".to_string(),
            doc_string: "".to_string(),
        }],
        name: "".to_string(),
        initializer: vec![],
        input: graph_inputs.iter().map(|s| value_info(s)).collect(),
        output: vec![value_info(OUTPUT_Z)],
        value_info: vec![],
        doc_string: "".to_string(),
        sparse_initializer: vec![],
        quantization_annotation: vec![],
    }))
}

#[test]
fn test_evaluation_fails_without_defined_graph() -> Result<()> {
    let manual_graph = create_model_proto_with_graph(None);
    let inputs: HashMap<String, Tensor> = HashMap::new();
    match candle_onnx::simple_eval(&manual_graph, inputs) {
        Err(err) => assert_eq!(err.to_string(), "no graph defined in proto"),
        Ok(_) => panic!("Expected an error due to undefined graph"),
    }
    Ok(())
}

// Shared body for the scalar binary ops: evaluates `z = op([2.], [2.])` and
// checks the first element of the result.
fn check_binary_op(op_type: &str, expected: f64) -> Result<()> {
    let manual_graph = create_single_node_graph(op_type, &[INPUT_X, INPUT_Y], &[]);
    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?);
    inputs.insert(INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);
    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
    let first = z.to_vec1::<f64>()?[0];
    assert_eq!(first, expected);
    Ok(())
}

// "Add"
#[test]
fn test_add_operation() -> Result<()> {
    check_binary_op("Add", 4.0f64)
}

// "Sub"
#[test]
fn test_sub_operation() -> Result<()> {
    check_binary_op("Sub", 0.0f64)
}

// "Mul"
#[test]
fn test_mul_operation() -> Result<()> {
    check_binary_op("Mul", 4.0f64)
}

// "Div"
#[test]
fn test_div_operation() -> Result<()> {
    check_binary_op("Div", 1.0f64)
}

// "Equal"
#[test]
fn test_equal_operation() -> Result<()> {
    let manual_graph = create_single_node_graph("Equal", &[INPUT_X, INPUT_Y], &[]);
    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?);
    inputs.insert(INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);
    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
    let first = z.to_dtype(candle::DType::U8)?.to_vec1::<u8>()?[0];
    assert_eq!(first, 1);
    Ok(())
}

// "Not"
#[test]
fn test_not_operation() -> Result<()> {
    let manual_graph = create_single_node_graph("Not", &[INPUT_X], &[]);
    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), Tensor::new(&[0.], &Device::Cpu)?);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);
    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
    let first = z.to_dtype(candle::DType::U8)?.to_vec1::<u8>()?[0];
    assert_eq!(first, 1);
    Ok(())
}

// "MatMul"
#[test]
fn test_matmul_operation() -> Result<()> {
    let manual_graph = create_single_node_graph("MatMul", &[INPUT_X, INPUT_Y], &[]);
    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(
        INPUT_X.to_string(),
        Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu)?,
    );
    inputs.insert(
        INPUT_Y.to_string(),
        Tensor::from_vec(vec![5.0f32, 6.0f32, 7.0f32, 8.0f32], &[2, 2], &Device::Cpu)?,
    );

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);
    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
    let results = z.to_vec2::<f32>()?;
    assert_eq!(results, vec![vec![19.0, 22.0], vec![43.0, 50.0]]);
    Ok(())
}

// "Reshape"
#[test]
fn test_reshape_operation() -> Result<()> {
    let manual_graph =
        create_single_node_graph("Reshape", &[INPUT_X, INPUT_Y], &[INPUT_X, INPUT_Y]);
    let x = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu)?;
    let y = Tensor::from_vec(vec![4i64], &[1], &Device::Cpu)?;

    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), x);
    inputs.insert(INPUT_Y.to_string(), y);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);
    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
    let results = z.to_vec1::<f32>()?;
    assert_eq!(results, vec![1.0, 2.0, 3.0, 4.0]);
    Ok(())
}

// "LogSoftmax"
#[test]
fn test_logsoftmax_operation() -> Result<()> {
    let manual_graph = create_single_node_graph("LogSoftmax", &[INPUT_X], &[INPUT_X, INPUT_Y]);
    let x = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu)?;

    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), x);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);
    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
    let results = z.to_vec2::<f32>()?;
    // NOTE: these expected values are softmax probabilities rather than
    // log-probabilities; they mirror what the current implementation returns.
    assert_eq!(results, vec![vec![0.26894143, 0.7310586], vec![0.26894143, 0.7310586]]);
    Ok(())
}

// "Softmax"
#[test]
fn test_softmax_operation() -> Result<()> {
    let manual_graph = create_single_node_graph("Softmax", &[INPUT_X], &[INPUT_X, INPUT_Y]);
    let x = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu)?;

    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), x);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);
    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
    let results = z.to_vec2::<f32>()?;
    assert_eq!(results, vec![vec![0.26894143, 0.7310586], vec![0.26894143, 0.7310586]]);
    Ok(())
}

// "Transpose"
#[test]
fn test_transpose_operation() -> Result<()> {
    let manual_graph = create_single_node_graph("Transpose", &[INPUT_X], &[INPUT_X, INPUT_Y]);
    let x = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu)?;

    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), x);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);
    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
    let results = z.to_vec2::<f32>()?;
    assert_eq!(results, vec![vec![1.0, 3.0], vec![2.0, 4.0]]);
    Ok(())
}

// "Dropout"
#[test]
fn test_dropout_operation() -> Result<()> {
    let manual_graph = create_single_node_graph("Dropout", &[INPUT_X], &[INPUT_X, INPUT_Y]);
    let x = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu)?;

    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), x);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);
    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
    let results = z.to_vec2::<f32>()?;
    assert_eq!(results, vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
    Ok(())
}

// Below are ops that are implemented but not tested yet:
// MaxPool, AveragePool, BatchNormalization, Squeeze, ConstantOfShape,
// Unsqueeze, Clip, Gather, Shape, Conv, Concat, Abs, Cos, Sin, Neg, Erf,
// Tanh, Sigmoid, Gelu, Relu, Constant, Cast

@@ -15,9 +15,9 @@ crate-type = ["cdylib"]

[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.1" }
candle-onnx = { path = "../candle-onnx", version = "0.3.1", optional = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
candle-onnx = { path = "../candle-onnx", version = "0.3.0", optional = true }
half = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] }

@@ -17,7 +17,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use ::candle::{quantized::QTensor, DType, Device, Module, Tensor, WithDType};
use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};

mod utils;
use utils::wrap_err;

@@ -12,9 +12,9 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.1", optional = true }
candle-nn = { path = "../candle-nn", version = "0.3.1" }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
rand = { workspace = true }

@ -1,342 +0,0 @@
|
||||
use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
|
||||
use candle::{DType, Device, Result, Tensor};
|
||||
use candle_nn::{Embedding, Module, VarBuilder};
|
||||
use serde::Deserialize;
|
||||
|
||||
pub const DTYPE: DType = DType::F32;
|
||||
|
||||
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
|
||||
let shape = mask.shape();
|
||||
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
|
||||
let m = mask.where_cond(&on_true, on_false)?;
|
||||
Ok(m)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
enum HiddenAct {
|
||||
Gelu,
|
||||
Relu,
|
||||
}
|
||||
|
||||
struct HiddenActLayer {
|
||||
act: HiddenAct,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl HiddenActLayer {
|
||||
fn new(act: HiddenAct) -> Self {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "hidden-act");
|
||||
Self { act, span }
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for HiddenActLayer {
|
||||
fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
match self.act {
|
||||
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
|
||||
HiddenAct::Gelu => xs.gelu(),
|
||||
HiddenAct::Relu => xs.relu(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
enum PositionEmbeddingType {
|
||||
#[default]
|
||||
Absolute,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||
pub struct Config {
|
||||
vocab_size: usize,
|
||||
dim: usize,
|
||||
n_layers: usize,
|
||||
n_heads: usize,
|
||||
hidden_dim: usize,
|
||||
activation: HiddenAct,
|
||||
max_position_embeddings: usize,
|
||||
initializer_range: f64,
|
||||
pad_token_id: usize,
|
||||
#[serde(default)]
|
||||
position_embedding_type: PositionEmbeddingType,
|
||||
#[serde(default)]
|
||||
use_cache: bool,
|
||||
model_type: Option<String>,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
vocab_size: 30522,
|
||||
dim: 768,
|
||||
n_layers: 12,
|
||||
n_heads: 12,
|
||||
hidden_dim: 3072,
|
||||
activation: HiddenAct::Gelu,
|
||||
max_position_embeddings: 512,
|
||||
initializer_range: 0.02,
|
||||
pad_token_id: 0,
|
||||
position_embedding_type: PositionEmbeddingType::Absolute,
|
||||
use_cache: true,
|
||||
model_type: Some("distilbert".to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct Embeddings {
|
||||
word_embeddings: Embedding,
|
||||
position_embeddings: Embedding,
|
||||
layer_norm: LayerNorm,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Embeddings {
|
||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let word_embeddings =
|
||||
candle_nn::embedding(config.vocab_size, config.dim, vb.pp("word_embeddings"))?;
|
||||
let position_embeddings = candle_nn::embedding(
|
||||
config.max_position_embeddings,
|
||||
config.dim,
|
||||
vb.pp("position_embeddings"),
|
||||
)?;
|
||||
let layer_norm = layer_norm(config.dim, 1e-12, vb.pp("LayerNorm"))?;
|
||||
Ok(Self {
|
||||
word_embeddings,
|
||||
position_embeddings,
|
||||
layer_norm,
|
||||
span: tracing::span!(tracing::Level::TRACE, "embeddings"),
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let (_bsize, seq_len) = input_ids.dims2()?;
|
||||
let input_embeddings = self.word_embeddings.forward(input_ids)?;
|
||||
let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
|
||||
let position_ids = Tensor::new(&position_ids[..], input_ids.device())?;
|
||||
let embeddings =
|
||||
input_embeddings.broadcast_add(&self.position_embeddings.forward(&position_ids)?)?;
|
||||
|
||||
let embeddings = self.layer_norm.forward(&embeddings)?;
|
||||
Ok(embeddings)
|
||||
}
|
||||
}
|
||||
|
||||
struct MultiHeadSelfAttention {
|
||||
q_lin: Linear,
|
||||
k_lin: Linear,
|
||||
v_lin: Linear,
|
||||
out_lin: Linear,
|
||||
n_heads: usize,
|
||||
attention_head_size: usize,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl MultiHeadSelfAttention {
|
||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let attention_head_size = config.dim / config.n_heads;
|
||||
let all_head_size = config.n_heads * attention_head_size;
|
||||
let dim = config.dim;
|
||||
let q_lin = linear(dim, all_head_size, vb.pp("q_lin"))?;
|
||||
let v_lin = linear(dim, all_head_size, vb.pp("v_lin"))?;
|
||||
let k_lin = linear(dim, all_head_size, vb.pp("k_lin"))?;
|
||||
let out_lin = linear(all_head_size, dim, vb.pp("out_lin"))?;
|
||||
Ok(Self {
|
||||
q_lin,
|
||||
k_lin,
|
||||
v_lin,
|
||||
out_lin,
|
||||
n_heads: config.n_heads,
|
||||
attention_head_size,
|
||||
span: tracing::span!(tracing::Level::TRACE, "attention"),
|
||||
})
|
||||
}
|
||||
}
|
impl MultiHeadSelfAttention {
    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let (bs, q_length, _dim) = hidden_states.dims3()?;

        let dim_per_head = self.attention_head_size;
        let q = self.q_lin.forward(hidden_states)?;
        let k = self.k_lin.forward(hidden_states)?;
        let v = self.v_lin.forward(hidden_states)?;

        let q = q
            .reshape((bs, q_length, self.n_heads, dim_per_head))?
            .transpose(1, 2)?;
        let k = k
            .reshape((bs, q_length, self.n_heads, dim_per_head))?
            .transpose(1, 2)?;
        let v = v
            .reshape((bs, q_length, self.n_heads, dim_per_head))?
            .transpose(1, 2)?;

        let q: Tensor = (q / (dim_per_head as f64).sqrt())?;
        let scores = q.matmul(&k.transpose(2, 3)?.contiguous()?)?;
        let mask = attention_mask.broadcast_as(scores.shape())?;

        let scores = masked_fill(&scores.to_dtype(DType::F32)?, &mask, f32::NEG_INFINITY)?;
        let weights = candle_nn::ops::softmax(&scores, candle::D::Minus1)?;

        let context = weights.matmul(&v.contiguous()?)?;
        let context = context
            .transpose(1, 2)?
            .reshape((bs, q_length, self.n_heads * dim_per_head))?
            .contiguous()?;
        let context = self.out_lin.forward(&context)?;

        Ok(context)
    }
}

#[allow(clippy::upper_case_acronyms)]
struct FFN {
    lin1: Linear,
    lin2: Linear,
    activation: HiddenActLayer,
    span: tracing::Span,
}

impl FFN {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let lin1 = linear(config.dim, config.hidden_dim, vb.pp("lin1"))?;
        let lin2 = linear(config.hidden_dim, config.dim, vb.pp("lin2"))?;
        Ok(Self {
            lin1,
            lin2,
            activation: HiddenActLayer::new(config.activation),
            span: tracing::span!(tracing::Level::TRACE, "ffn"),
        })
    }
}

impl Module for FFN {
    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        hidden_states
            .apply(&self.lin1)?
            .apply(&self.activation)?
            .apply(&self.lin2)
    }
}

struct TransformerBlock {
    attention: MultiHeadSelfAttention,
    sa_layer_norm: LayerNorm,
    ffn: FFN,
    output_layer_norm: LayerNorm,
    span: tracing::Span,
}

impl TransformerBlock {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let attention = MultiHeadSelfAttention::load(vb.pp("attention"), config)?;
        let sa_layer_norm = layer_norm(config.dim, 1e-12, vb.pp("sa_layer_norm"))?;
        let ffn = FFN::load(vb.pp("ffn"), config)?;
        let output_layer_norm = layer_norm(config.dim, 1e-12, vb.pp("output_layer_norm"))?;
        Ok(Self {
            attention,
            sa_layer_norm,
            ffn,
            output_layer_norm,
            span: tracing::span!(tracing::Level::TRACE, "layer"),
        })
    }
}

impl TransformerBlock {
    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let sa_output = self.attention.forward(hidden_states, attention_mask)?;
        // TODO: Support cross-attention?
        // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
        // TODO: Support something similar to `apply_chunking_to_forward`?
        let sa_output = sa_output.broadcast_add(hidden_states)?;
        let sa_output = self.sa_layer_norm.forward(&sa_output)?;

        let ffn_output = self.ffn.forward(&sa_output)?;
        let ffn_output = (&ffn_output + sa_output)?;
        let output = self.output_layer_norm.forward(&ffn_output)?;
        Ok(output)
    }
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
struct Transformer {
    layers: Vec<TransformerBlock>,
    span: tracing::Span,
}

impl Transformer {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let layers = (0..config.n_layers)
            .map(|index| TransformerBlock::load(vb.pp(&format!("layer.{index}")), config))
            .collect::<Result<Vec<_>>>()?;
        let span = tracing::span!(tracing::Level::TRACE, "encoder");
        Ok(Transformer { layers, span })
    }
}

impl Transformer {
    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let mut hidden_states = hidden_states.clone();
        // Use a loop rather than a fold as it's easier to modify when adding debug/...
        for layer in self.layers.iter() {
            hidden_states = layer.forward(&hidden_states, attention_mask)?;
        }
        Ok(hidden_states)
    }
}

pub struct DistilBertModel {
    embeddings: Embeddings,
    transformer: Transformer,
    pub device: Device,
    span: tracing::Span,
}

impl DistilBertModel {
    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let (embeddings, transformer) = match (
            Embeddings::load(vb.pp("embeddings"), config),
            Transformer::load(vb.pp("transformer"), config),
        ) {
            (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
            (Err(err), _) | (_, Err(err)) => {
                if let Some(model_type) = &config.model_type {
                    if let (Ok(embeddings), Ok(encoder)) = (
                        Embeddings::load(vb.pp(&format!("{model_type}.embeddings")), config),
                        Transformer::load(vb.pp(&format!("{model_type}.transformer")), config),
                    ) {
                        (embeddings, encoder)
                    } else {
                        return Err(err);
                    }
                } else {
                    return Err(err);
                }
            }
        };
        Ok(Self {
            embeddings,
            transformer,
            device: vb.device().clone(),
            span: tracing::span!(tracing::Level::TRACE, "model"),
        })
    }

    pub fn forward(&self, input_ids: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let embedding_output = self.embeddings.forward(input_ids)?;
        let sequence_output = self
            .transformer
            .forward(&embedding_output, attention_mask)?;
        Ok(sequence_output)
    }
}
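
As context for the DistilBERT code above, here is a minimal usage sketch, not part of the diff: it assumes a downloaded `config.json`, `tokenizer.json` and `model.safetensors`, the `anyhow`, `serde_json` and `tokenizers` crates, that `Config` implements `Deserialize`, and an all-zero `u8` attention mask (in `masked_fill` above, a 1 marks a position to overwrite with `-inf`).

```rust
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::distilbert::{Config, DistilBertModel};

fn embed(text: &str) -> anyhow::Result<Tensor> {
    let device = Device::Cpu;
    let config: Config = serde_json::from_str(&std::fs::read_to_string("config.json")?)?;
    let tokenizer =
        tokenizers::Tokenizer::from_file("tokenizer.json").map_err(anyhow::Error::msg)?;
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
    };
    let model = DistilBertModel::load(vb, &config)?;

    let ids = tokenizer
        .encode(text, true)
        .map_err(anyhow::Error::msg)?
        .get_ids()
        .to_vec();
    let input_ids = Tensor::new(ids.as_slice(), &device)?.unsqueeze(0)?; // (1, seq)
    // All zeros: nothing is masked out for a single, unpadded sequence.
    let mask = Tensor::zeros((1, input_ids.dim(1)?), DType::U8, &device)?;
    Ok(model.forward(&input_ids, &mask)?) // (1, seq, dim) hidden states
}
```
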
@ -4,7 +4,6 @@ pub mod blip;
pub mod blip_text;
pub mod convmixer;
pub mod dinov2;
pub mod distilbert;
pub mod efficientnet;
pub mod falcon;
pub mod jina_bert;

@ -1,15 +1,12 @@
use candle::{Device, Result, Tensor};

pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
    if steps == 0 {
        Tensor::from_vec(Vec::<f64>::new(), steps, &Device::Cpu)
    } else if steps == 1 {
        Tensor::from_vec(vec![start], steps, &Device::Cpu)
    } else {
        let delta = (stop - start) / (steps - 1) as f64;
        let vs = (0..steps)
            .map(|step| start + step as f64 * delta)
            .collect::<Vec<_>>();
        Tensor::from_vec(vs, steps, &Device::Cpu)
    if steps < 1 {
        candle::bail!("cannot use linspace with steps {steps} <= 1")
    }
    let delta = (stop - start) / (steps - 1) as f64;
    let vs = (0..steps)
        .map(|step| start + step as f64 * delta)
        .collect::<Vec<_>>();
    Tensor::from_vec(vs, steps, &Device::Cpu)
}
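
A quick worked check of the formula: with `steps` points there are `steps - 1` intervals, so both endpoints are hit exactly (a hedged test-style fragment; the `?` assumes a `Result`-returning context):

```rust
// 5 steps => 4 intervals of (1.0 - 0.0) / 4 = 0.25, endpoints included.
// Equality is exact here because 0.25 is representable in f64.
let t = linspace(0.0, 1.0, 5)?;
assert_eq!(t.to_vec1::<f64>()?, vec![0.0, 0.25, 0.5, 0.75, 1.0]);
```
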
@ -138,10 +138,6 @@ impl TrOCRAttention {
        })
    }

    fn reset_kv_cache(&mut self) {
        self.kv_cache = None
    }

    fn _shape(&self, tensor: &Tensor, bsz: usize) -> Result<Tensor> {
        tensor
            .reshape((bsz, (), self.num_heads, self.head_dim))?
@ -243,10 +239,6 @@ impl TrOCRDecoderLayer {
        })
    }

    fn reset_kv_cache(&mut self) {
        self.self_attn.reset_kv_cache();
    }

    fn forward(
        &mut self,
        xs: &Tensor,
@ -315,10 +307,6 @@ impl TrOCRDecoder {
        })
    }

    fn reset_kv_cache(&mut self) {
        self.layers.iter_mut().for_each(|l| l.reset_kv_cache())
    }

    pub fn forward(
        &mut self,
        xs: &Tensor,
@ -405,10 +393,6 @@ impl TrOCRForCausalLM {

        Ok(xs)
    }

    fn reset_kv_cache(&mut self) {
        self.decoder.reset_kv_cache();
    }
}

#[derive(Debug, Clone)]
@ -447,8 +431,4 @@ impl TrOCRModel {
        self.decoder
            .forward(xs, Some(encoder_xs), past_kv_len, &mask)
    }

    pub fn reset_kv_cache(&mut self) {
        self.decoder.reset_kv_cache();
    }
}
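
The `reset_kv_cache` methods in these hunks cascade top-down (`TrOCRModel` → decoder → layers → self-attention), so a single call clears every cached key/value tensor. A minimal sketch of where that matters; `model`, `device`, `encoded_images`, `bos_token_id` and the `decode` wrapper are assumptions for illustration, not taken from the diff:

```rust
// Hedged sketch: decoding several images with one model instance. Without the
// reset, the second run would append to the first run's KV cache.
for encoder_xs in encoded_images.iter() {
    model.reset_kv_cache(); // drop state cached by the previous sequence
    let mut tokens: Vec<u32> = vec![bos_token_id];
    loop {
        // Feed only the newest token; earlier ones are covered by the cache.
        let input = Tensor::new(&tokens[tokens.len() - 1..], &device)?.unsqueeze(0)?;
        let logits = model.decode(&input, encoder_xs, tokens.len() - 1)?;
        // ...argmax/sample the next token, push it, break on EOS...
        break; // placeholder so the sketch terminates
    }
}
```
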
@ -58,7 +58,8 @@ fn dft<T: Float>(inp: &[T]) -> Vec<T> {
    let n = inp.len();
    let two_pi = T::PI() + T::PI();

    let mut out = Vec::with_capacity(2 * n);
    let mut out = Vec::new();
    out.reserve(2 * n);
    let n_t = T::from(n).unwrap();
    for k in 0..n {
        let k_t = T::from(k).unwrap();
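
The two allocation styles in this hunk behave the same: both produce an empty vector whose capacity is at least `2 * n`; `Vec::with_capacity` is simply the more idiomatic single-call spelling. A standalone check:

```rust
fn main() {
    let n = 400;
    let a: Vec<f32> = Vec::with_capacity(2 * n);
    let mut b: Vec<f32> = Vec::new();
    b.reserve(2 * n);
    // Both are empty and both have reserved room for at least 2 * n elements.
    assert!(a.is_empty() && b.is_empty());
    assert!(a.capacity() >= 2 * n && b.capacity() >= 2 * n);
}
```
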
@ -43,4 +43,4 @@ pub const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
pub const TRANSLATE_TOKEN: &str = "<|translate|>";
pub const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
pub const EOT_TOKEN: &str = "<|endoftext|>";
pub const NO_SPEECH_TOKENS: [&str; 2] = ["<|nocaptions|>", "<|nospeech|>"];
pub const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";

@ -277,12 +277,8 @@ impl DecoderLayer {
    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
        let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
        let ln2 = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            vb.pp("post_attention_layernorm"),
        )?;
        let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("ln1"))?;
        let ln2 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("ln2"))?;
        Ok(Self {
            self_attn,
            mlp,
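
Only the `vb.pp(...)` arguments differ between the two versions of this hunk: `VarBuilder::pp` appends a dot-separated segment to the prefix under which tensors are looked up, so the change is a pure checkpoint-key rename; no shapes or math change. A small sketch with hypothetical prefixes:

```rust
// Each `pp` call appends a segment, so the two versions only differ in which
// checkpoint keys they read.
let vb_layer = vb.pp("model").pp("layers.0");
let ln1_new = vb_layer.pp("input_layernorm"); // reads "model.layers.0.input_layernorm.weight"
let ln1_old = vb_layer.pp("ln1");             // reads "model.layers.0.ln1.weight"
```
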
@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true

[dependencies]
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true

[dependencies]
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
num-traits = { workspace = true }

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true

[dependencies]
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true

[dependencies]
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
num-traits = { workspace = true }

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true

[dependencies]
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
num-traits = { workspace = true }

# App crates.

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true

[dependencies]
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true

[dependencies]
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }

@ -59,7 +59,8 @@ fn dft<T: Float>(inp: &[T]) -> Vec<T> {
    let n = inp.len();
    let two_pi = T::PI() + T::PI();

    let mut out = Vec::with_capacity(2 * n);
    let mut out = Vec::new();
    out.reserve(2 * n);
    let n_t = T::from(n).unwrap();
    for k in 0..n {
        let k_t = T::from(k).unwrap();

@ -129,13 +129,7 @@ impl Decoder {
        let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
        let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
        let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
        let no_speech_token = m::NO_SPEECH_TOKENS
            .iter()
            .find_map(|token| token_id(&tokenizer, token).ok());
        let no_speech_token = match no_speech_token {
            None => anyhow::bail!("unable to find any non-speech token"),
            Some(n) => n,
        };
        let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?;
        let seed = 299792458;
        Ok(Self {
            model,

@ -9,8 +9,8 @@ categories.workspace = true
license.workspace = true

[dependencies]
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
num-traits = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }

@ -7,7 +7,7 @@ keywords.workspace = true
categories.workspace = true

[dependencies]
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
rand = { workspace = true }
getrandom = { version = "0.2", features = ["js"] }

@ -1,7 +1,7 @@
use candle::{
    quantized::{self, k_quants, GgmlDType, GgmlType},
    test_utils::to_vec2_round,
    Device, Module, Result, Tensor,
    Device, Result, Tensor,
};

use wasm_bindgen_test::*;
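
The only change in this hunk is whether `Module` is imported. `forward` is a trait method, so it resolves only when the `Module` trait is in scope or named in a generic bound. A minimal illustration, independent of the quantized types in this test:

```rust
use candle::{Module, Result, Tensor};

// `forward` comes from the `Module` trait; without the trait in scope (or a
// bound like the one below), `m.forward(xs)` would not compile.
fn run<M: Module>(m: &M, xs: &Tensor) -> Result<Tensor> {
    m.forward(xs)
}
```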