mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Compare commits
26 Commits
tmp4
...
metal_heap
Author | SHA1 | Date | |
---|---|---|---|
e8c1c31245 | |||
51f05e997d | |||
4289984d32 | |||
1471f98f0b | |||
dd4a40f1c0 | |||
79845bd93b | |||
6071797450 | |||
b58b247323 | |||
3900091e75 | |||
54355ff997 | |||
e02f1912bb | |||
a52b71686b | |||
7adfb70dff | |||
3ad02147e4 | |||
4f39695465 | |||
4cf4844c9d | |||
d840838e95 | |||
61a070fdd1 | |||
e35669647d | |||
53e8b7ee3e | |||
cc26cce23c | |||
02c2ec2c71 | |||
9a2784b8ab | |||
0f652f0e3d | |||
ddee9dc1dd | |||
fc9bb7784a |
@ -19,7 +19,7 @@ exclude = [
|
|||||||
resolver = "2"
|
resolver = "2"
|
||||||
|
|
||||||
[workspace.package]
|
[workspace.package]
|
||||||
version = "0.3.1"
|
version = "0.3.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
description = "Minimalist ML framework."
|
description = "Minimalist ML framework."
|
||||||
repository = "https://github.com/huggingface/candle"
|
repository = "https://github.com/huggingface/candle"
|
||||||
@ -61,7 +61,8 @@ tracing-subscriber = "0.3.7"
|
|||||||
wav = "1.0.0"
|
wav = "1.0.0"
|
||||||
yoke = { version = "0.7.2", features = ["derive"] }
|
yoke = { version = "0.7.2", features = ["derive"] }
|
||||||
zip = { version = "0.6.6", default-features = false }
|
zip = { version = "0.6.6", default-features = false }
|
||||||
metal = { version = "0.27.1", features = ["mps"], package="candle-metal" }
|
# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
|
||||||
|
metal = { path = "../metal-rs", features = ["mps"] }
|
||||||
|
|
||||||
[profile.release-with-debug]
|
[profile.release-with-debug]
|
||||||
inherits = "release"
|
inherits = "release"
|
||||||
|
17
README.md
17
README.md
@ -139,16 +139,16 @@ And then head over to
|
|||||||
<!--- ANCHOR: useful_libraries --->
|
<!--- ANCHOR: useful_libraries --->
|
||||||
|
|
||||||
## Useful External Resources
|
## Useful External Resources
|
||||||
- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): A
|
- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): a
|
||||||
very detailed tutorial showing how to convert a PyTorch model to Candle.
|
very detailed tutorial showing how to convert a PyTorch model to Candle.
|
||||||
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): Efficient and ergonomic LoRA implementation for Candle. `candle-lora` has
|
- [`optimisers`](https://github.com/KGrewal1/optimisers): a collection of optimisers
|
||||||
out-of-the-box LoRA support for many models from Candle, which can be found [here](https://github.com/EricLBuehler/candle-lora/tree/master/candle-lora-transformers/examples).
|
|
||||||
- [`optimisers`](https://github.com/KGrewal1/optimisers): A collection of optimisers
|
|
||||||
including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
|
including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
|
||||||
|
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): a LoRA implementation
|
||||||
|
that conforms to the official `peft` implementation.
|
||||||
- [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
|
- [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
|
||||||
serving local LLMs including an OpenAI compatible API server.
|
serving local LLMs including an OpenAI compatible API server.
|
||||||
- [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle.
|
- [`candle-ext`](https://github.com/mokeyish/candle-ext): an extension library to Candle that provides PyTorch functions not currently available in Candle.
|
||||||
- [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
|
- [`kalosm`](https://github.com/floneum/floneum/tree/master/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
|
||||||
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
|
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
|
||||||
|
|
||||||
If you have an addition to this list, please submit a pull request.
|
If you have an addition to this list, please submit a pull request.
|
||||||
@ -177,11 +177,6 @@ If you have an addition to this list, please submit a pull request.
|
|||||||
- Replit-code-v1.5-3B.
|
- Replit-code-v1.5-3B.
|
||||||
- Bert.
|
- Bert.
|
||||||
- Yi-6B and Yi-34B.
|
- Yi-6B and Yi-34B.
|
||||||
- Quantized LLMs.
|
|
||||||
- Llama 7b, 13b, 70b, as well as the chat and code variants.
|
|
||||||
- Mistral 7b, and 7b instruct.
|
|
||||||
- Zephyr 7b a and b (Mistral based).
|
|
||||||
- OpenChat 3.5 (Mistral based).
|
|
||||||
- Text to text.
|
- Text to text.
|
||||||
- T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
|
- T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
|
||||||
- Marian MT (Machine Translation).
|
- Marian MT (Machine Translation).
|
||||||
|
@ -11,11 +11,11 @@ readme = "README.md"
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
accelerate-src = { workspace = true, optional = true }
|
accelerate-src = { workspace = true, optional = true }
|
||||||
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
|
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
||||||
candle-datasets = { path = "../candle-datasets", version = "0.3.1" }
|
candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
|
||||||
candle-nn = { path = "../candle-nn", version = "0.3.1" }
|
candle-nn = { path = "../candle-nn", version = "0.3.0" }
|
||||||
candle-transformers = { path = "../candle-transformers", version = "0.3.1" }
|
candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
|
||||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.1", optional = true }
|
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
|
||||||
safetensors = { workspace = true }
|
safetensors = { workspace = true }
|
||||||
serde = { workspace = true }
|
serde = { workspace = true }
|
||||||
serde_json = { workspace = true }
|
serde_json = { workspace = true }
|
||||||
|
@ -12,8 +12,8 @@ readme = "README.md"
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
accelerate-src = { workspace = true, optional = true }
|
accelerate-src = { workspace = true, optional = true }
|
||||||
byteorder = { workspace = true }
|
byteorder = { workspace = true }
|
||||||
candle-kernels = { path = "../candle-kernels", version = "0.3.1", optional = true }
|
candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
|
||||||
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.1", optional = true }
|
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
|
||||||
metal = { workspace = true, optional = true}
|
metal = { workspace = true, optional = true}
|
||||||
cudarc = { workspace = true, optional = true }
|
cudarc = { workspace = true, optional = true }
|
||||||
gemm = { workspace = true }
|
gemm = { workspace = true }
|
||||||
|
@ -8,10 +8,11 @@ use anyhow::Result;
|
|||||||
use candle_core::{Device, Tensor};
|
use candle_core::{Device, Tensor};
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
let a = Tensor::new(&[[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]], &Device::Cpu)?;
|
let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
|
||||||
let b = Tensor::new(&[[88.0f32, 99.0]], &Device::Cpu)?;
|
let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
|
||||||
let new_a = a.slice_scatter(&b, 1, 2)?;
|
let start = std::time::Instant::now();
|
||||||
assert_eq!(a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
let res = inp.conv2d(&w, 0, 1, 1, 1)?;
|
||||||
assert_eq!(new_a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
println!("{:?}", start.elapsed());
|
||||||
|
println!("{res:?}");
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -104,31 +104,37 @@ impl From<&Tensor> for TensorIndexer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
trait RB: RangeBounds<usize> {}
|
macro_rules! impl_from_range {
|
||||||
impl RB for Range<usize> {}
|
($range_type:ty) => {
|
||||||
impl RB for RangeFrom<usize> {}
|
impl From<$range_type> for TensorIndexer {
|
||||||
impl RB for RangeFull {}
|
fn from(range: $range_type) -> Self {
|
||||||
impl RB for RangeInclusive<usize> {}
|
use std::ops::Bound::*;
|
||||||
impl RB for RangeTo<usize> {}
|
|
||||||
impl RB for RangeToInclusive<usize> {}
|
|
||||||
|
|
||||||
impl<T: RB> From<T> for TensorIndexer {
|
let start = match range.start_bound() {
|
||||||
fn from(range: T) -> Self {
|
Included(idx) => Included(*idx),
|
||||||
use std::ops::Bound::*;
|
Excluded(idx) => Excluded(*idx),
|
||||||
let start = match range.start_bound() {
|
Unbounded => Unbounded,
|
||||||
Included(idx) => Included(*idx),
|
};
|
||||||
Excluded(idx) => Excluded(*idx),
|
|
||||||
Unbounded => Unbounded,
|
let end = match range.end_bound() {
|
||||||
};
|
Included(idx) => Included(*idx),
|
||||||
let end = match range.end_bound() {
|
Excluded(idx) => Excluded(*idx),
|
||||||
Included(idx) => Included(*idx),
|
Unbounded => Unbounded,
|
||||||
Excluded(idx) => Excluded(*idx),
|
};
|
||||||
Unbounded => Unbounded,
|
|
||||||
};
|
TensorIndexer::Narrow(start, end)
|
||||||
TensorIndexer::Narrow(start, end)
|
}
|
||||||
}
|
}
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl_from_range!(Range<usize>);
|
||||||
|
impl_from_range!(RangeFrom<usize>);
|
||||||
|
impl_from_range!(RangeFull);
|
||||||
|
impl_from_range!(RangeInclusive<usize>);
|
||||||
|
impl_from_range!(RangeTo<usize>);
|
||||||
|
impl_from_range!(RangeToInclusive<usize>);
|
||||||
|
|
||||||
/// Trait used to implement multiple signatures for ease of use of the slicing
|
/// Trait used to implement multiple signatures for ease of use of the slicing
|
||||||
/// of a tensor
|
/// of a tensor
|
||||||
pub trait IndexOp<T> {
|
pub trait IndexOp<T> {
|
||||||
|
@ -123,6 +123,12 @@ pub trait Module {
|
|||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
|
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Module for quantized::QMatMul {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
self.forward(xs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
|
impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
self(xs)
|
self(xs)
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -307,8 +307,8 @@ impl crate::CustomOp1 for QTensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl crate::Module for QMatMul {
|
impl QMatMul {
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
match self {
|
match self {
|
||||||
Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
|
Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
|
||||||
Self::Tensor(w) => {
|
Self::Tensor(w) => {
|
||||||
|
@ -2457,110 +2457,6 @@ impl Tensor {
|
|||||||
Ok(naxis as usize)
|
Ok(naxis as usize)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a lower triangular matrix of ones of size n by n.
|
|
||||||
pub fn tril2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
|
|
||||||
let t = Tensor::arange(0u32, n as u32, device)?;
|
|
||||||
let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
|
|
||||||
let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
|
|
||||||
t1.le(&t2)?.to_dtype(dtype)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns an upper triangular matrix of ones of size n by n.
|
|
||||||
pub fn triu2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
|
|
||||||
let t = Tensor::arange(0u32, n as u32, device)?;
|
|
||||||
let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
|
|
||||||
let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
|
|
||||||
t1.ge(&t2)?.to_dtype(dtype)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns a matrix with a diagonal of ones of size n by n.
|
|
||||||
pub fn eye(n: usize, dtype: DType, device: &Device) -> Result<Self> {
|
|
||||||
let t = Tensor::arange(0u32, n as u32, device)?;
|
|
||||||
let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
|
|
||||||
let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
|
|
||||||
t1.eq(&t2)?.to_dtype(dtype)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the cumulative sum of elements of the input tensor summed over the specified
|
|
||||||
/// dimension.
|
|
||||||
///
|
|
||||||
/// This operation is most efficient when dim is the last dimension of the tensor.
|
|
||||||
pub fn cumsum<D: Dim>(&self, dim: D) -> Result<Self> {
|
|
||||||
let dim = dim.to_index(self.shape(), "cumsum")?;
|
|
||||||
let rank = self.rank();
|
|
||||||
if rank == 0 {
|
|
||||||
return Ok(self.clone());
|
|
||||||
}
|
|
||||||
let n_axis = self.dim(dim)?;
|
|
||||||
let triu = Tensor::triu2(n_axis, self.dtype(), self.device())?;
|
|
||||||
if rank == 1 {
|
|
||||||
self.unsqueeze(0)?.matmul(&triu)?.squeeze(0)
|
|
||||||
} else {
|
|
||||||
let last = rank - 1;
|
|
||||||
let t = self.transpose(dim, last)?;
|
|
||||||
let t = t.broadcast_matmul(&triu)?;
|
|
||||||
t.transpose(dim, last)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns a copy of `self` where the values within `ranges` have been replaced with the
|
|
||||||
/// content of `src`.
|
|
||||||
pub fn slice_assign<D: std::ops::RangeBounds<usize>>(
|
|
||||||
&self,
|
|
||||||
ranges: &[D],
|
|
||||||
src: &Tensor,
|
|
||||||
) -> Result<Self> {
|
|
||||||
let src_dims = src.dims();
|
|
||||||
let self_dims = self.dims();
|
|
||||||
if self_dims.len() != src_dims.len() {
|
|
||||||
crate::bail!(
|
|
||||||
"slice-assign requires input with the same rank {} <> {}",
|
|
||||||
self_dims.len(),
|
|
||||||
src_dims.len()
|
|
||||||
)
|
|
||||||
}
|
|
||||||
if self_dims.len() != ranges.len() {
|
|
||||||
crate::bail!(
|
|
||||||
"slice-assign requires input with the same rank as there are ranges {} <> {}",
|
|
||||||
self_dims.len(),
|
|
||||||
ranges.len()
|
|
||||||
)
|
|
||||||
}
|
|
||||||
let mut src = src.clone();
|
|
||||||
let mut mask = Self::ones(src.shape(), DType::U8, src.device())?;
|
|
||||||
for (i, range) in ranges.iter().enumerate() {
|
|
||||||
let start_included = match range.start_bound() {
|
|
||||||
std::ops::Bound::Unbounded => 0,
|
|
||||||
std::ops::Bound::Included(v) => *v,
|
|
||||||
std::ops::Bound::Excluded(v) => *v + 1,
|
|
||||||
};
|
|
||||||
let end_excluded = match range.end_bound() {
|
|
||||||
std::ops::Bound::Unbounded => self_dims[i],
|
|
||||||
std::ops::Bound::Included(v) => *v + 1,
|
|
||||||
std::ops::Bound::Excluded(v) => *v,
|
|
||||||
};
|
|
||||||
if end_excluded <= start_included {
|
|
||||||
crate::bail!(
|
|
||||||
"slice-assign: empty range for dim {i}, {start_included} {end_excluded}"
|
|
||||||
)
|
|
||||||
}
|
|
||||||
if self_dims[i] < end_excluded {
|
|
||||||
crate::bail!(
|
|
||||||
"slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}",
|
|
||||||
self_dims[i]
|
|
||||||
)
|
|
||||||
}
|
|
||||||
if end_excluded - start_included != src_dims[i] {
|
|
||||||
crate::bail!(
|
|
||||||
"slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i]
|
|
||||||
)
|
|
||||||
}
|
|
||||||
src = src.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?;
|
|
||||||
mask = mask.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?
|
|
||||||
}
|
|
||||||
mask.where_cond(/* on_true= */ &src, /* on_false= */ self)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
macro_rules! bin_trait {
|
macro_rules! bin_trait {
|
||||||
|
@ -91,32 +91,3 @@ fn index_3d() -> Result<()> {
|
|||||||
assert_eq!(tensor.i((1, .., 3))?.to_vec1::<u32>()?, &[15, 19, 23]);
|
assert_eq!(tensor.i((1, .., 3))?.to_vec1::<u32>()?, &[15, 19, 23]);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn slice_assign() -> Result<()> {
|
|
||||||
let dev = Device::Cpu;
|
|
||||||
|
|
||||||
let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?;
|
|
||||||
let src = Tensor::arange(0u32, 2 * 3, &dev)?.reshape((3, 2))?;
|
|
||||||
let out = tensor.slice_assign(&[1..4, 3..5], &src)?;
|
|
||||||
assert_eq!(
|
|
||||||
out.to_vec2::<u32>()?,
|
|
||||||
&[
|
|
||||||
[0, 1, 2, 3, 4],
|
|
||||||
[5, 6, 7, 0, 1],
|
|
||||||
[10, 11, 12, 2, 3],
|
|
||||||
[15, 16, 17, 4, 5]
|
|
||||||
]
|
|
||||||
);
|
|
||||||
let out = tensor.slice_assign(&[0..3, 0..2], &src)?;
|
|
||||||
assert_eq!(
|
|
||||||
out.to_vec2::<u32>()?,
|
|
||||||
&[
|
|
||||||
[0, 1, 2, 3, 4],
|
|
||||||
[2, 3, 7, 8, 9],
|
|
||||||
[4, 5, 12, 13, 14],
|
|
||||||
[15, 16, 17, 18, 19]
|
|
||||||
]
|
|
||||||
);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use candle_core::{
|
use candle_core::{
|
||||||
quantized::{self, GgmlDType},
|
quantized::{self, GgmlDType},
|
||||||
test_utils::to_vec2_round,
|
test_utils::to_vec2_round,
|
||||||
Device, Module, Result, Tensor,
|
Device, Result, Tensor,
|
||||||
};
|
};
|
||||||
use quantized::{k_quants, GgmlType};
|
use quantized::{k_quants, GgmlType};
|
||||||
use rand::prelude::*;
|
use rand::prelude::*;
|
||||||
|
@ -1159,65 +1159,3 @@ fn i64_abs() -> Result<()> {
|
|||||||
assert_eq!(t.to_vec1::<i64>()?, [42, 1337]);
|
assert_eq!(t.to_vec1::<i64>()?, [42, 1337]);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn tril_triu_eye() -> Result<()> {
|
|
||||||
let t = Tensor::tril2(4, DType::F32, &Device::Cpu)?;
|
|
||||||
assert_eq!(
|
|
||||||
t.to_vec2::<f32>()?,
|
|
||||||
[
|
|
||||||
[1.0, 0.0, 0.0, 0.0],
|
|
||||||
[1.0, 1.0, 0.0, 0.0],
|
|
||||||
[1.0, 1.0, 1.0, 0.0],
|
|
||||||
[1.0, 1.0, 1.0, 1.0]
|
|
||||||
],
|
|
||||||
);
|
|
||||||
let t = Tensor::triu2(4, DType::F32, &Device::Cpu)?;
|
|
||||||
assert_eq!(
|
|
||||||
t.to_vec2::<f32>()?,
|
|
||||||
[
|
|
||||||
[1.0, 1.0, 1.0, 1.0],
|
|
||||||
[0.0, 1.0, 1.0, 1.0],
|
|
||||||
[0.0, 0.0, 1.0, 1.0],
|
|
||||||
[0.0, 0.0, 0.0, 1.0]
|
|
||||||
]
|
|
||||||
);
|
|
||||||
let t = Tensor::eye(4, DType::F32, &Device::Cpu)?;
|
|
||||||
assert_eq!(
|
|
||||||
t.to_vec2::<f32>()?,
|
|
||||||
[
|
|
||||||
[1.0, 0.0, 0.0, 0.0],
|
|
||||||
[0.0, 1.0, 0.0, 0.0],
|
|
||||||
[0.0, 0.0, 1.0, 0.0],
|
|
||||||
[0.0, 0.0, 0.0, 1.0]
|
|
||||||
]
|
|
||||||
);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn cumsum() -> Result<()> {
|
|
||||||
let t = &[3f32, 1., 4., 1., 5.];
|
|
||||||
let t = Tensor::new(t, &Device::Cpu)?;
|
|
||||||
assert_eq!(t.cumsum(0)?.to_vec1::<f32>()?, [3., 4., 8., 9., 14.]);
|
|
||||||
let t = t.unsqueeze(1)?;
|
|
||||||
assert_eq!(
|
|
||||||
t.cumsum(0)?.to_vec2::<f32>()?,
|
|
||||||
[[3.0], [4.0], [8.0], [9.0], [14.0]]
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
t.cumsum(1)?.to_vec2::<f32>()?,
|
|
||||||
[[3.0], [1.0], [4.0], [1.0], [5.0]]
|
|
||||||
);
|
|
||||||
let t = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
|
|
||||||
let t = Tensor::new(t, &Device::Cpu)?;
|
|
||||||
assert_eq!(
|
|
||||||
t.cumsum(1)?.to_vec2::<f32>()?,
|
|
||||||
[[3.0, 4.0, 8.0, 9.0, 14.0], [2.0, 3.0, 10.0, 18.0, 20.0]],
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
t.cumsum(0)?.to_vec2::<f32>()?,
|
|
||||||
[[3.0, 1.0, 4.0, 1.0, 5.0], [5.0, 2.0, 11.0, 9.0, 7.0]]
|
|
||||||
);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
@ -11,8 +11,8 @@ readme = "README.md"
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
byteorder = { workspace = true }
|
byteorder = { workspace = true }
|
||||||
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
|
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
||||||
candle-nn = { path = "../candle-nn", version = "0.3.1" }
|
candle-nn = { path = "../candle-nn", version = "0.3.0" }
|
||||||
hf-hub = { workspace = true}
|
hf-hub = { workspace = true}
|
||||||
intel-mkl-src = { workspace = true, optional = true }
|
intel-mkl-src = { workspace = true, optional = true }
|
||||||
memmap2 = { workspace = true }
|
memmap2 = { workspace = true }
|
||||||
|
@ -11,12 +11,12 @@ readme = "README.md"
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
accelerate-src = { workspace = true, optional = true }
|
accelerate-src = { workspace = true, optional = true }
|
||||||
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
|
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
||||||
candle-datasets = { path = "../candle-datasets", version = "0.3.1" }
|
candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
|
||||||
candle-nn = { path = "../candle-nn", version = "0.3.1" }
|
candle-nn = { path = "../candle-nn", version = "0.3.0" }
|
||||||
candle-transformers = { path = "../candle-transformers", version = "0.3.1" }
|
candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
|
||||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.1", optional = true }
|
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
|
||||||
candle-onnx = { path = "../candle-onnx", version = "0.3.1", optional = true }
|
candle-onnx = { path = "../candle-onnx", version = "0.3.0", optional = true }
|
||||||
cudarc = { workspace = true, optional = true }
|
cudarc = { workspace = true, optional = true }
|
||||||
half = { workspace = true, optional = true }
|
half = { workspace = true, optional = true }
|
||||||
image = { workspace = true }
|
image = { workspace = true }
|
||||||
@ -57,7 +57,6 @@ flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
|
|||||||
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
|
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
|
||||||
nccl = ["cuda", "cudarc/nccl", "dep:half"]
|
nccl = ["cuda", "cudarc/nccl", "dep:half"]
|
||||||
onnx = ["candle-onnx"]
|
onnx = ["candle-onnx"]
|
||||||
metal = ["candle/metal", "candle-nn/metal"]
|
|
||||||
|
|
||||||
[[example]]
|
[[example]]
|
||||||
name = "llama_multiprocess"
|
name = "llama_multiprocess"
|
||||||
|
@ -1,22 +0,0 @@
|
|||||||
# candle-distilbert
|
|
||||||
|
|
||||||
DistilBert is a distilled version of the Bert model.
|
|
||||||
|
|
||||||
## Sentence embeddings
|
|
||||||
|
|
||||||
DistilBert is used to compute the sentence embeddings for a prompt. The model weights
|
|
||||||
are downloaded from the hub on the first run.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cargo run --example distilbert --release -- --prompt "Here is a test sentence"
|
|
||||||
|
|
||||||
> [[[ 0.5109, 0.1280, -0.2635, ..., 0.3462, -1.0434, 0.1441],
|
|
||||||
> [ 0.1735, 0.0818, -0.5549, ..., 0.3472, -0.8264, -0.0244],
|
|
||||||
> [ 0.0702, -0.1311, -0.4914, ..., 0.3483, -0.6194, 0.1829],
|
|
||||||
> ...
|
|
||||||
> [ 0.2993, -0.0106, -0.4640, ..., 0.2844, -0.6732, 0.0042],
|
|
||||||
> [ 0.1066, -0.0081, -0.4299, ..., 0.3435, -0.7729, 0.0190],
|
|
||||||
> [ 0.8903, 0.2055, -0.2541, ..., 0.3208, -0.6585, 0.0586]]]
|
|
||||||
> Tensor[[1, 7, 768], f32]
|
|
||||||
|
|
||||||
```
|
|
@ -1,135 +0,0 @@
|
|||||||
#[cfg(feature = "mkl")]
|
|
||||||
extern crate intel_mkl_src;
|
|
||||||
|
|
||||||
#[cfg(feature = "accelerate")]
|
|
||||||
extern crate accelerate_src;
|
|
||||||
use candle_transformers::models::distilbert::{Config, DistilBertModel, DTYPE};
|
|
||||||
|
|
||||||
use anyhow::{Error as E, Result};
|
|
||||||
use candle::{Device, Tensor};
|
|
||||||
use candle_nn::VarBuilder;
|
|
||||||
use clap::Parser;
|
|
||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
|
||||||
use tokenizers::Tokenizer;
|
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
|
||||||
#[command(author, version, about, long_about = None)]
|
|
||||||
struct Args {
|
|
||||||
/// Run on CPU rather than on GPU.
|
|
||||||
#[arg(long)]
|
|
||||||
cpu: bool,
|
|
||||||
|
|
||||||
/// Enable tracing (generates a trace-timestamp.json file).
|
|
||||||
#[arg(long)]
|
|
||||||
tracing: bool,
|
|
||||||
|
|
||||||
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
|
|
||||||
#[arg(long)]
|
|
||||||
model_id: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
revision: Option<String>,
|
|
||||||
|
|
||||||
/// When set, compute embeddings for this prompt.
|
|
||||||
#[arg(long)]
|
|
||||||
prompt: String,
|
|
||||||
|
|
||||||
/// Use the pytorch weights rather than the safetensors ones
|
|
||||||
#[arg(long)]
|
|
||||||
use_pth: bool,
|
|
||||||
|
|
||||||
/// The number of times to run the prompt.
|
|
||||||
#[arg(long, default_value = "1")]
|
|
||||||
n: usize,
|
|
||||||
|
|
||||||
/// L2 normalization for embeddings.
|
|
||||||
#[arg(long, default_value = "true")]
|
|
||||||
normalize_embeddings: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Args {
|
|
||||||
fn build_model_and_tokenizer(&self) -> Result<(DistilBertModel, Tokenizer)> {
|
|
||||||
let device = candle_examples::device(self.cpu)?;
|
|
||||||
let default_model = "distilbert-base-uncased".to_string();
|
|
||||||
let default_revision = "main".to_string();
|
|
||||||
let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
|
|
||||||
(Some(model_id), Some(revision)) => (model_id, revision),
|
|
||||||
(Some(model_id), None) => (model_id, "main".to_string()),
|
|
||||||
(None, Some(revision)) => (default_model, revision),
|
|
||||||
(None, None) => (default_model, default_revision),
|
|
||||||
};
|
|
||||||
|
|
||||||
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
|
|
||||||
let (config_filename, tokenizer_filename, weights_filename) = {
|
|
||||||
let api = Api::new()?;
|
|
||||||
let api = api.repo(repo);
|
|
||||||
let config = api.get("config.json")?;
|
|
||||||
let tokenizer = api.get("tokenizer.json")?;
|
|
||||||
let weights = if self.use_pth {
|
|
||||||
api.get("pytorch_model.bin")?
|
|
||||||
} else {
|
|
||||||
api.get("model.safetensors")?
|
|
||||||
};
|
|
||||||
(config, tokenizer, weights)
|
|
||||||
};
|
|
||||||
let config = std::fs::read_to_string(config_filename)?;
|
|
||||||
let config: Config = serde_json::from_str(&config)?;
|
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
|
||||||
|
|
||||||
let vb = if self.use_pth {
|
|
||||||
VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
|
|
||||||
} else {
|
|
||||||
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
|
|
||||||
};
|
|
||||||
let model = DistilBertModel::load(vb, &config)?;
|
|
||||||
Ok((model, tokenizer))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_mask(size: usize, device: &Device) -> Tensor {
|
|
||||||
let mask: Vec<_> = (0..size)
|
|
||||||
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
|
|
||||||
.collect();
|
|
||||||
Tensor::from_slice(&mask, (size, size), device).unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
|
||||||
use tracing_chrome::ChromeLayerBuilder;
|
|
||||||
use tracing_subscriber::prelude::*;
|
|
||||||
|
|
||||||
let args = Args::parse();
|
|
||||||
let _guard = if args.tracing {
|
|
||||||
println!("tracing...");
|
|
||||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
|
||||||
tracing_subscriber::registry().with(chrome_layer).init();
|
|
||||||
Some(guard)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
|
|
||||||
let device = &model.device;
|
|
||||||
|
|
||||||
let tokenizer = tokenizer
|
|
||||||
.with_padding(None)
|
|
||||||
.with_truncation(None)
|
|
||||||
.map_err(E::msg)?;
|
|
||||||
let tokens = tokenizer
|
|
||||||
.encode(args.prompt, true)
|
|
||||||
.map_err(E::msg)?
|
|
||||||
.get_ids()
|
|
||||||
.to_vec();
|
|
||||||
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
|
|
||||||
let mask = get_mask(tokens.len(), device);
|
|
||||||
|
|
||||||
println!("token_ids: {:?}", token_ids.to_vec2::<u32>());
|
|
||||||
println!("mask: {:?}", mask.to_vec2::<u8>());
|
|
||||||
|
|
||||||
let ys = model.forward(&token_ids, &mask)?;
|
|
||||||
println!("{ys}");
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
|
|
||||||
Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
|
|
||||||
}
|
|
@ -53,8 +53,6 @@ enum Which {
|
|||||||
Zephyr7bAlpha,
|
Zephyr7bAlpha,
|
||||||
#[value(name = "7b-zephyr-b")]
|
#[value(name = "7b-zephyr-b")]
|
||||||
Zephyr7bBeta,
|
Zephyr7bBeta,
|
||||||
#[value(name = "7b-open-chat-3.5")]
|
|
||||||
OpenChat35,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Which {
|
impl Which {
|
||||||
@ -69,10 +67,8 @@ impl Which {
|
|||||||
| Self::L7bCode
|
| Self::L7bCode
|
||||||
| Self::L13bCode
|
| Self::L13bCode
|
||||||
| Self::L34bCode => false,
|
| Self::L34bCode => false,
|
||||||
// Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
|
// Zephyr is a fine tuned version of mistral and should be treated in the same way.
|
||||||
// same way.
|
Self::Zephyr7bAlpha
|
||||||
Self::OpenChat35
|
|
||||||
| Self::Zephyr7bAlpha
|
|
||||||
| Self::Zephyr7bBeta
|
| Self::Zephyr7bBeta
|
||||||
| Self::Mistral7b
|
| Self::Mistral7b
|
||||||
| Self::Mistral7bInstruct => true,
|
| Self::Mistral7bInstruct => true,
|
||||||
@ -91,30 +87,10 @@ impl Which {
|
|||||||
| Self::L13bCode
|
| Self::L13bCode
|
||||||
| Self::L34bCode
|
| Self::L34bCode
|
||||||
| Self::Mistral7b
|
| Self::Mistral7b
|
||||||
| Self::Mistral7bInstruct
|
| Self::Mistral7bInstruct => false,
|
||||||
| Self::OpenChat35 => false,
|
|
||||||
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
|
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn is_open_chat(&self) -> bool {
|
|
||||||
match self {
|
|
||||||
Which::L7b
|
|
||||||
| Which::L13b
|
|
||||||
| Which::L70b
|
|
||||||
| Which::L7bChat
|
|
||||||
| Which::L13bChat
|
|
||||||
| Which::L70bChat
|
|
||||||
| Which::L7bCode
|
|
||||||
| Which::L13bCode
|
|
||||||
| Which::L34bCode
|
|
||||||
| Which::Mistral7b
|
|
||||||
| Which::Mistral7bInstruct
|
|
||||||
| Which::Zephyr7bAlpha
|
|
||||||
| Which::Zephyr7bBeta => false,
|
|
||||||
Which::OpenChat35 => true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
@ -181,9 +157,7 @@ impl Args {
|
|||||||
Some(config) => std::path::PathBuf::from(config),
|
Some(config) => std::path::PathBuf::from(config),
|
||||||
None => {
|
None => {
|
||||||
let api = hf_hub::api::sync::Api::new()?;
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
let repo = if self.which.is_open_chat() {
|
let repo = if self.which.is_mistral() {
|
||||||
"openchat/openchat_3.5"
|
|
||||||
} else if self.which.is_mistral() {
|
|
||||||
"mistralai/Mistral-7B-v0.1"
|
"mistralai/Mistral-7B-v0.1"
|
||||||
} else {
|
} else {
|
||||||
"hf-internal-testing/llama-tokenizer"
|
"hf-internal-testing/llama-tokenizer"
|
||||||
@ -233,7 +207,6 @@ impl Args {
|
|||||||
Which::Zephyr7bBeta => {
|
Which::Zephyr7bBeta => {
|
||||||
("TheBloke/zephyr-7B-beta-GGUF", "zephyr-7b-beta.Q4_K_M.gguf")
|
("TheBloke/zephyr-7B-beta-GGUF", "zephyr-7b-beta.Q4_K_M.gguf")
|
||||||
}
|
}
|
||||||
Which::OpenChat35 => ("TheBloke/openchat_3.5-GGUF", "openchat_3.5.Q4_K_M.gguf"),
|
|
||||||
};
|
};
|
||||||
let api = hf_hub::api::sync::Api::new()?;
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
let api = api.model(repo.to_string());
|
let api = api.model(repo.to_string());
|
||||||
@ -335,8 +308,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
| Which::Zephyr7bAlpha
|
| Which::Zephyr7bAlpha
|
||||||
| Which::Zephyr7bBeta
|
| Which::Zephyr7bBeta
|
||||||
| Which::L70b
|
| Which::L70b
|
||||||
| Which::L70bChat
|
| Which::L70bChat => 8,
|
||||||
| Which::OpenChat35 => 8,
|
|
||||||
};
|
};
|
||||||
ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
|
ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
|
||||||
}
|
}
|
||||||
@ -368,9 +340,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
prompt.pop();
|
prompt.pop();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if args.which.is_open_chat() {
|
if args.which.is_zephyr() {
|
||||||
format!("User: {prompt}<|end_of_turn|>Assistant: ")
|
|
||||||
} else if args.which.is_zephyr() {
|
|
||||||
if prompt_index == 0 || is_interactive {
|
if prompt_index == 0 || is_interactive {
|
||||||
format!("<|system|>\n</s>\n<|user|>\n{prompt}</s>\n<|assistant|>",)
|
format!("<|system|>\n</s>\n<|user|>\n{prompt}</s>\n<|assistant|>",)
|
||||||
} else {
|
} else {
|
||||||
@ -420,12 +390,8 @@ fn main() -> anyhow::Result<()> {
|
|||||||
std::io::stdout().flush()?;
|
std::io::stdout().flush()?;
|
||||||
}
|
}
|
||||||
|
|
||||||
let eos_token = if args.which.is_open_chat() {
|
let eos_token = *tos.tokenizer().get_vocab(true).get("</s>").unwrap();
|
||||||
"<|end_of_turn|>"
|
|
||||||
} else {
|
|
||||||
"</s>"
|
|
||||||
};
|
|
||||||
let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap();
|
|
||||||
let start_post_prompt = std::time::Instant::now();
|
let start_post_prompt = std::time::Instant::now();
|
||||||
let mut sampled = 0;
|
let mut sampled = 0;
|
||||||
for index in 0..to_sample {
|
for index in 0..to_sample {
|
||||||
|
@ -416,7 +416,7 @@ fn run(args: Args) -> Result<()> {
|
|||||||
|
|
||||||
println!("Building the autoencoder.");
|
println!("Building the autoencoder.");
|
||||||
let vae_weights = ModelFile::Vae.get(vae_weights, sd_version, use_f16)?;
|
let vae_weights = ModelFile::Vae.get(vae_weights, sd_version, use_f16)?;
|
||||||
let vae = sd_config.build_vae(vae_weights, &device, dtype)?;
|
let vae = sd_config.build_vae(&vae_weights, &device, dtype)?;
|
||||||
let init_latent_dist = match &img2img {
|
let init_latent_dist = match &img2img {
|
||||||
None => None,
|
None => None,
|
||||||
Some(image) => {
|
Some(image) => {
|
||||||
@ -426,7 +426,7 @@ fn run(args: Args) -> Result<()> {
|
|||||||
};
|
};
|
||||||
println!("Building the unet.");
|
println!("Building the unet.");
|
||||||
let unet_weights = ModelFile::Unet.get(unet_weights, sd_version, use_f16)?;
|
let unet_weights = ModelFile::Unet.get(unet_weights, sd_version, use_f16)?;
|
||||||
let unet = sd_config.build_unet(unet_weights, &device, 4, use_flash_attn, dtype)?;
|
let unet = sd_config.build_unet(&unet_weights, &device, 4, use_flash_attn, dtype)?;
|
||||||
|
|
||||||
let t_start = if img2img.is_some() {
|
let t_start = if img2img.is_some() {
|
||||||
n_steps - (n_steps as f64 * img2img_strength) as usize
|
n_steps - (n_steps as f64 * img2img_strength) as usize
|
||||||
|
@ -8,7 +8,7 @@ the model itself.
|
|||||||
## Running an example
|
## Running an example
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cargo run --example trocr --release -- --which base --cpu --image candle-examples/examples/trocr/assets/trocr.png
|
cargo run --example trocr --release -- --which base --cpu --image assets/trocr.png
|
||||||
```
|
```
|
||||||
|
|
||||||
```
|
```
|
||||||
|
@ -128,13 +128,7 @@ impl Decoder {
|
|||||||
let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
|
let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
|
||||||
let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
|
let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
|
||||||
let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
|
let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
|
||||||
let no_speech_token = m::NO_SPEECH_TOKENS
|
let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?;
|
||||||
.iter()
|
|
||||||
.find_map(|token| token_id(&tokenizer, token).ok());
|
|
||||||
let no_speech_token = match no_speech_token {
|
|
||||||
None => anyhow::bail!("unable to find any non-speech token"),
|
|
||||||
Some(n) => n,
|
|
||||||
};
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
model,
|
model,
|
||||||
rng: rand::rngs::StdRng::seed_from_u64(seed),
|
rng: rand::rngs::StdRng::seed_from_u64(seed),
|
||||||
@ -518,7 +512,11 @@ fn main() -> Result<()> {
|
|||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
let config = repo.get("config.json")?;
|
let config = repo.get("config.json")?;
|
||||||
let tokenizer = repo.get("tokenizer.json")?;
|
let tokenizer = if args.model == WhichModel::LargeV3 {
|
||||||
|
panic!("openai/whisper-large-v3 does not provide a compatible tokenizer.json config at the moment")
|
||||||
|
} else {
|
||||||
|
repo.get("tokenizer.json")?
|
||||||
|
};
|
||||||
let model = repo.get("model.safetensors")?;
|
let model = repo.get("model.safetensors")?;
|
||||||
(config, tokenizer, model)
|
(config, tokenizer, model)
|
||||||
};
|
};
|
||||||
|
@ -74,9 +74,9 @@ impl TextGeneration {
|
|||||||
std::io::stdout().flush()?;
|
std::io::stdout().flush()?;
|
||||||
|
|
||||||
let mut generated_tokens = 0usize;
|
let mut generated_tokens = 0usize;
|
||||||
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
|
let eos_token = match self.tokenizer.get_token("</s>") {
|
||||||
Some(token) => token,
|
Some(token) => token,
|
||||||
None => anyhow::bail!("cannot find the <|endoftext|> token"),
|
None => anyhow::bail!("cannot find the </s> token"),
|
||||||
};
|
};
|
||||||
let start_gen = std::time::Instant::now();
|
let start_gen = std::time::Instant::now();
|
||||||
for index in 0..sample_len {
|
for index in 0..sample_len {
|
||||||
|
@ -43,7 +43,6 @@ pub fn report(
|
|||||||
confidence_threshold: f32,
|
confidence_threshold: f32,
|
||||||
nms_threshold: f32,
|
nms_threshold: f32,
|
||||||
) -> Result<DynamicImage> {
|
) -> Result<DynamicImage> {
|
||||||
let pred = pred.to_device(&Device::Cpu)?;
|
|
||||||
let (npreds, pred_size) = pred.dims2()?;
|
let (npreds, pred_size) = pred.dims2()?;
|
||||||
let nclasses = pred_size - 5;
|
let nclasses = pred_size - 5;
|
||||||
// The bounding boxes grouped by (maximum) class index.
|
// The bounding boxes grouped by (maximum) class index.
|
||||||
|
@ -32,7 +32,7 @@ Image source:
|
|||||||
### Pose Estimation
|
### Pose Estimation
|
||||||
```bash
|
```bash
|
||||||
cargo run --example yolo-v8 --release -- \
|
cargo run --example yolo-v8 --release -- \
|
||||||
candle-examples/examples/yolo-v8/assets/bike.jpg --task pose
|
candle-examples/examples/yolo-v8/assets/peoples.jpeg --task pose
|
||||||
```
|
```
|
||||||
|
|
||||||

|

|
||||||
|
@ -7,7 +7,7 @@ extern crate accelerate_src;
|
|||||||
mod model;
|
mod model;
|
||||||
use model::{Multiples, YoloV8, YoloV8Pose};
|
use model::{Multiples, YoloV8, YoloV8Pose};
|
||||||
|
|
||||||
use candle::{DType, Device, IndexOp, Result, Tensor};
|
use candle::{DType, IndexOp, Result, Tensor};
|
||||||
use candle_nn::{Module, VarBuilder};
|
use candle_nn::{Module, VarBuilder};
|
||||||
use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
|
use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
|
||||||
use clap::{Parser, ValueEnum};
|
use clap::{Parser, ValueEnum};
|
||||||
@ -61,7 +61,6 @@ pub fn report_detect(
|
|||||||
nms_threshold: f32,
|
nms_threshold: f32,
|
||||||
legend_size: u32,
|
legend_size: u32,
|
||||||
) -> Result<DynamicImage> {
|
) -> Result<DynamicImage> {
|
||||||
let pred = pred.to_device(&Device::Cpu)?;
|
|
||||||
let (pred_size, npreds) = pred.dims2()?;
|
let (pred_size, npreds) = pred.dims2()?;
|
||||||
let nclasses = pred_size - 4;
|
let nclasses = pred_size - 4;
|
||||||
// The bounding boxes grouped by (maximum) class index.
|
// The bounding boxes grouped by (maximum) class index.
|
||||||
@ -154,7 +153,6 @@ pub fn report_pose(
|
|||||||
confidence_threshold: f32,
|
confidence_threshold: f32,
|
||||||
nms_threshold: f32,
|
nms_threshold: f32,
|
||||||
) -> Result<DynamicImage> {
|
) -> Result<DynamicImage> {
|
||||||
let pred = pred.to_device(&Device::Cpu)?;
|
|
||||||
let (pred_size, npreds) = pred.dims2()?;
|
let (pred_size, npreds) = pred.dims2()?;
|
||||||
if pred_size != 17 * 3 + 4 + 1 {
|
if pred_size != 17 * 3 + 4 + 1 {
|
||||||
candle::bail!("unexpected pred-size {pred_size}");
|
candle::bail!("unexpected pred-size {pred_size}");
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-flash-attn"
|
name = "candle-flash-attn"
|
||||||
version = "0.3.1"
|
version = "0.3.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "Flash attention layer for the candle ML framework."
|
description = "Flash attention layer for the candle ML framework."
|
||||||
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
|
|||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../candle-core", features = ["cuda"], version = "0.3.1", package = "candle-core" }
|
candle = { path = "../candle-core", features = ["cuda"], version = "0.3.0", package = "candle-core" }
|
||||||
half = { version = "2.3.1", features = ["num-traits"] }
|
half = { version = "2.3.1", features = ["num-traits"] }
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
@ -21,4 +21,4 @@ rayon = "1.7.0"
|
|||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
anyhow = { version = "1", features = ["backtrace"] }
|
anyhow = { version = "1", features = ["backtrace"] }
|
||||||
candle-nn = { path = "../candle-nn", version = "0.3.1", features = ["cuda"] }
|
candle-nn = { path = "../candle-nn", version = "0.3.0", features = ["cuda"] }
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-kernels"
|
name = "candle-kernels"
|
||||||
version = "0.3.1"
|
version = "0.3.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "CUDA kernels for Candle"
|
description = "CUDA kernels for Candle"
|
||||||
|
@ -1,16 +1,17 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-metal-kernels"
|
name = "candle-metal-kernels"
|
||||||
version = "0.3.1"
|
version = "0.3.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "Metal kernels for Candle"
|
description = "CUDA kernels for Candle"
|
||||||
repository = "https://github.com/huggingface/candle"
|
repository = "https://github.com/huggingface/candle"
|
||||||
keywords = ["blas", "tensor", "machine-learning"]
|
keywords = ["blas", "tensor", "machine-learning"]
|
||||||
categories = ["science"]
|
categories = ["science"]
|
||||||
license = "MIT OR Apache-2.0"
|
license = "MIT OR Apache-2.0"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
metal = { version = "0.27.1", features = ["mps"], package="candle-metal" }
|
# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
|
||||||
|
metal = { path = "../../metal-rs", features = ["mps"] }
|
||||||
once_cell = "1.18.0"
|
once_cell = "1.18.0"
|
||||||
thiserror = "1"
|
thiserror = "1"
|
||||||
tracing = "0.1.37"
|
tracing = "0.1.37"
|
||||||
|
@ -23,12 +23,12 @@ kernel void FN_NAME( \
|
|||||||
constant size_t &dim, \
|
constant size_t &dim, \
|
||||||
device const LEFT_TYPENAME *input, \
|
device const LEFT_TYPENAME *input, \
|
||||||
device RIGHT_TYPENAME *output, \
|
device RIGHT_TYPENAME *output, \
|
||||||
uint tid [[ thread_position_in_grid ]] \
|
uint thread_position_in_grid [[ thread_position_in_grid ]] \
|
||||||
) { \
|
) { \
|
||||||
if (tid >= dim) { \
|
if (thread_position_in_grid >= dim) { \
|
||||||
return; \
|
return; \
|
||||||
} \
|
} \
|
||||||
output[tid] = RIGHT_TYPENAME(input[tid]); \
|
output[thread_position_in_grid] = RIGHT_TYPENAME(input[thread_position_in_grid]); \
|
||||||
} \
|
} \
|
||||||
kernel void FN_NAME_STRIDED( \
|
kernel void FN_NAME_STRIDED( \
|
||||||
constant size_t &dim, \
|
constant size_t &dim, \
|
||||||
@ -37,17 +37,15 @@ kernel void FN_NAME_STRIDED( \
|
|||||||
constant size_t *strides, \
|
constant size_t *strides, \
|
||||||
device const LEFT_TYPENAME *input, \
|
device const LEFT_TYPENAME *input, \
|
||||||
device RIGHT_TYPENAME *output, \
|
device RIGHT_TYPENAME *output, \
|
||||||
uint tid [[ thread_position_in_grid ]] \
|
uint i [[ thread_position_in_grid ]] \
|
||||||
) { \
|
) { \
|
||||||
if (tid >= dim) { \
|
if (i >= dim) { \
|
||||||
return; \
|
return; \
|
||||||
} \
|
} \
|
||||||
output[tid] = RIGHT_TYPENAME(input[get_strided_index(tid, num_dims, dims, strides)]); \
|
output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \
|
||||||
} \
|
} \
|
||||||
|
|
||||||
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
|
CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float)
|
||||||
CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
|
|
||||||
CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
|
|
||||||
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
|
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
|
||||||
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
|
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
|
||||||
|
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
#![allow(clippy::too_many_arguments)]
|
||||||
use metal::{
|
use metal::{
|
||||||
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
|
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
|
||||||
Device, Function, Library, MTLSize,
|
Device, Function, Library, MTLSize,
|
||||||
@ -126,7 +127,6 @@ macro_rules! ops{
|
|||||||
pub const HALF: Kernel = Kernel("copy_half");
|
pub const HALF: Kernel = Kernel("copy_half");
|
||||||
pub const BFLOAT: Kernel = Kernel("copy_bfloat");
|
pub const BFLOAT: Kernel = Kernel("copy_bfloat");
|
||||||
pub const U32: Kernel = Kernel("copy_u32");
|
pub const U32: Kernel = Kernel("copy_u32");
|
||||||
pub const U8: Kernel = Kernel("copy_u8");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -146,7 +146,6 @@ macro_rules! ops{
|
|||||||
pub const HALF: Kernel = Kernel("copy_half_strided");
|
pub const HALF: Kernel = Kernel("copy_half_strided");
|
||||||
pub const BFLOAT: Kernel = Kernel("copy_bfloat_strided");
|
pub const BFLOAT: Kernel = Kernel("copy_bfloat_strided");
|
||||||
pub const U32: Kernel = Kernel("copy_u32_strided");
|
pub const U32: Kernel = Kernel("copy_u32_strided");
|
||||||
pub const U8: Kernel = Kernel("copy_u8_strided");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -159,6 +158,14 @@ pub mod binary {
|
|||||||
ops!(add, sub, mul, div);
|
ops!(add, sub, mul, div);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// static LIBRARY_SOURCES: Lazy<HashMap<&'static str, &'static str>> = Lazy::new(|| {
|
||||||
|
// let mut l = HashMap::new();
|
||||||
|
// l.insert("affine", AFFINE);
|
||||||
|
// l.insert("indexing", INDEXING);
|
||||||
|
// l.insert("unary", UNARY);
|
||||||
|
// l
|
||||||
|
// });
|
||||||
|
//
|
||||||
#[derive(thiserror::Error, Debug)]
|
#[derive(thiserror::Error, Debug)]
|
||||||
pub enum MetalKernelError {
|
pub enum MetalKernelError {
|
||||||
#[error("Could not lock kernel map: {0}")]
|
#[error("Could not lock kernel map: {0}")]
|
||||||
@ -199,7 +206,21 @@ impl Kernels {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// pub fn init(device: &Device) -> Result<Self, MetalKernelError> {
|
||||||
|
// let kernels = Self::new();
|
||||||
|
// kernels.load_libraries(device)?;
|
||||||
|
// Ok(kernels)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// fn load_libraries(&self, device: &Device) -> Result<(), MetalKernelError> {
|
||||||
|
// for name in LIBRARY_SOURCES.keys() {
|
||||||
|
// self.load_library(device, name)?;
|
||||||
|
// }
|
||||||
|
// Ok(())
|
||||||
|
// }
|
||||||
|
|
||||||
fn get_library_source(&self, source: Source) -> &'static str {
|
fn get_library_source(&self, source: Source) -> &'static str {
|
||||||
|
// LIBRARY_SOURCES.get(name).cloned()
|
||||||
match source {
|
match source {
|
||||||
Source::Affine => AFFINE,
|
Source::Affine => AFFINE,
|
||||||
Source::Unary => UNARY,
|
Source::Unary => UNARY,
|
||||||
@ -270,7 +291,6 @@ impl Kernels {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
pub fn call_unary_contiguous(
|
pub fn call_unary_contiguous(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
command_buffer: &CommandBufferRef,
|
||||||
@ -278,7 +298,7 @@ pub fn call_unary_contiguous(
|
|||||||
kernel_name: unary::contiguous::Kernel,
|
kernel_name: unary::contiguous::Kernel,
|
||||||
length: usize,
|
length: usize,
|
||||||
input: &Buffer,
|
input: &Buffer,
|
||||||
output: &Buffer,
|
output: &mut Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
@ -291,8 +311,6 @@ pub fn call_unary_contiguous(
|
|||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
pub fn call_unary_strided(
|
pub fn call_unary_strided(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
command_buffer: &CommandBufferRef,
|
||||||
@ -302,7 +320,7 @@ pub fn call_unary_strided(
|
|||||||
input: &Buffer,
|
input: &Buffer,
|
||||||
strides: &[usize],
|
strides: &[usize],
|
||||||
offset: usize,
|
offset: usize,
|
||||||
output: &Buffer,
|
output: &mut Buffer,
|
||||||
output_offset: usize,
|
output_offset: usize,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
|
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
|
||||||
@ -332,7 +350,6 @@ pub fn call_unary_strided(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
pub fn call_binary_contiguous(
|
pub fn call_binary_contiguous(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
command_buffer: &CommandBufferRef,
|
||||||
@ -341,7 +358,7 @@ pub fn call_binary_contiguous(
|
|||||||
length: usize,
|
length: usize,
|
||||||
left: &Buffer,
|
left: &Buffer,
|
||||||
right: &Buffer,
|
right: &Buffer,
|
||||||
output: &Buffer,
|
output: &mut Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
|
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
|
||||||
|
|
||||||
@ -357,7 +374,6 @@ pub fn call_binary_contiguous(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
pub fn call_binary_strided(
|
pub fn call_binary_strided(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
command_buffer: &CommandBufferRef,
|
||||||
@ -370,7 +386,7 @@ pub fn call_binary_strided(
|
|||||||
right_input: &Buffer,
|
right_input: &Buffer,
|
||||||
right_strides: &[usize],
|
right_strides: &[usize],
|
||||||
right_offset: usize,
|
right_offset: usize,
|
||||||
output: &Buffer,
|
output: &mut Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;
|
let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;
|
||||||
|
|
||||||
@ -402,7 +418,6 @@ pub fn call_binary_strided(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
pub fn call_cast_contiguous(
|
pub fn call_cast_contiguous(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
command_buffer: &CommandBufferRef,
|
||||||
@ -410,23 +425,22 @@ pub fn call_cast_contiguous(
|
|||||||
kernel_name: &'static str,
|
kernel_name: &'static str,
|
||||||
length: usize,
|
length: usize,
|
||||||
input: &Buffer,
|
input: &Buffer,
|
||||||
input_offset: usize,
|
output: &mut Buffer,
|
||||||
output: &Buffer,
|
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(encoder, (length, (input, input_offset), output));
|
set_params!(encoder, (length, input, output));
|
||||||
|
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||||
|
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
pub fn call_cast_strided(
|
pub fn call_cast_strided(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
command_buffer: &CommandBufferRef,
|
||||||
@ -436,8 +450,10 @@ pub fn call_cast_strided(
|
|||||||
input: &Buffer,
|
input: &Buffer,
|
||||||
input_strides: &[usize],
|
input_strides: &[usize],
|
||||||
input_offset: usize,
|
input_offset: usize,
|
||||||
output: &Buffer,
|
output: &mut Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
|
// println!("Kernel {:?}", kernel_name.0);
|
||||||
|
// assert_eq!(input.length(), output.length());
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
@ -447,14 +463,7 @@ pub fn call_cast_strided(
|
|||||||
|
|
||||||
set_params!(
|
set_params!(
|
||||||
encoder,
|
encoder,
|
||||||
(
|
(length, shape, input_strides, (input, input_offset), output)
|
||||||
length,
|
|
||||||
shape.len(),
|
|
||||||
shape,
|
|
||||||
input_strides,
|
|
||||||
(input, input_offset),
|
|
||||||
output
|
|
||||||
)
|
|
||||||
);
|
);
|
||||||
|
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||||
@ -472,8 +481,7 @@ pub fn call_reduce_contiguous(
|
|||||||
length: usize,
|
length: usize,
|
||||||
out_length: usize,
|
out_length: usize,
|
||||||
input: &Buffer,
|
input: &Buffer,
|
||||||
input_offset: usize,
|
output: &mut Buffer,
|
||||||
output: &Buffer,
|
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||||
let elements_to_sum = length / out_length;
|
let elements_to_sum = length / out_length;
|
||||||
@ -481,10 +489,7 @@ pub fn call_reduce_contiguous(
|
|||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(
|
set_params!(encoder, (length, elements_to_sum, input, output));
|
||||||
encoder,
|
|
||||||
(length, elements_to_sum, (input, input_offset), output)
|
|
||||||
);
|
|
||||||
|
|
||||||
let thread_group_count = MTLSize {
|
let thread_group_count = MTLSize {
|
||||||
width: out_length as u64,
|
width: out_length as u64,
|
||||||
@ -509,7 +514,6 @@ pub fn call_reduce_contiguous(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
pub fn call_last_softmax(
|
pub fn call_last_softmax(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
command_buffer: &CommandBufferRef,
|
||||||
@ -518,7 +522,7 @@ pub fn call_last_softmax(
|
|||||||
length: usize,
|
length: usize,
|
||||||
elements_to_sum: usize,
|
elements_to_sum: usize,
|
||||||
input: &Buffer,
|
input: &Buffer,
|
||||||
output: &Buffer,
|
output: &mut Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
@ -536,6 +540,7 @@ pub fn call_last_softmax(
|
|||||||
|
|
||||||
let width = std::cmp::min(
|
let width = std::cmp::min(
|
||||||
pipeline.max_total_threads_per_threadgroup(),
|
pipeline.max_total_threads_per_threadgroup(),
|
||||||
|
// (elements_to_sum as u64 + 2 - 1) / 2,
|
||||||
elements_to_sum as u64,
|
elements_to_sum as u64,
|
||||||
)
|
)
|
||||||
.next_power_of_two();
|
.next_power_of_two();
|
||||||
@ -551,7 +556,6 @@ pub fn call_last_softmax(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
pub fn call_affine(
|
pub fn call_affine(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
command_buffer: &CommandBufferRef,
|
||||||
@ -559,7 +563,7 @@ pub fn call_affine(
|
|||||||
name: &'static str,
|
name: &'static str,
|
||||||
size: usize,
|
size: usize,
|
||||||
input: &Buffer,
|
input: &Buffer,
|
||||||
output: &Buffer,
|
output: &mut Buffer,
|
||||||
mul: f32,
|
mul: f32,
|
||||||
add: f32,
|
add: f32,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
@ -576,7 +580,6 @@ pub fn call_affine(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
pub fn call_affine_strided(
|
pub fn call_affine_strided(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
command_buffer: &CommandBufferRef,
|
||||||
@ -586,7 +589,7 @@ pub fn call_affine_strided(
|
|||||||
input: &Buffer,
|
input: &Buffer,
|
||||||
input_stride: &[usize],
|
input_stride: &[usize],
|
||||||
input_offset: usize,
|
input_offset: usize,
|
||||||
output: &Buffer,
|
output: &mut Buffer,
|
||||||
mul: f32,
|
mul: f32,
|
||||||
add: f32,
|
add: f32,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
@ -628,7 +631,7 @@ pub fn call_where_cond_strided(
|
|||||||
(left_stride, left_offset): (&[usize], usize),
|
(left_stride, left_offset): (&[usize], usize),
|
||||||
right: &Buffer,
|
right: &Buffer,
|
||||||
(right_stride, right_offset): (&[usize], usize),
|
(right_stride, right_offset): (&[usize], usize),
|
||||||
output: &Buffer,
|
output: &mut Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
|
||||||
|
|
||||||
@ -661,7 +664,6 @@ pub fn call_where_cond_strided(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
pub fn call_index_select(
|
pub fn call_index_select(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
command_buffer: &CommandBufferRef,
|
||||||
@ -672,7 +674,7 @@ pub fn call_index_select(
|
|||||||
dim: usize,
|
dim: usize,
|
||||||
input: &Buffer,
|
input: &Buffer,
|
||||||
ids: &Buffer,
|
ids: &Buffer,
|
||||||
output: &Buffer,
|
output: &mut Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let left_size: usize = shape[..dim].iter().product();
|
let left_size: usize = shape[..dim].iter().product();
|
||||||
let right_size: usize = shape[dim + 1..].iter().product();
|
let right_size: usize = shape[dim + 1..].iter().product();
|
||||||
@ -707,4 +709,679 @@ pub fn call_index_select(
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests;
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use half::f16;
|
||||||
|
use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger};
|
||||||
|
|
||||||
|
fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
|
||||||
|
let options = MTLResourceOptions::StorageModeManaged;
|
||||||
|
let ptr = data.as_ptr() as *const core::ffi::c_void;
|
||||||
|
let size = (data.len() * std::mem::size_of::<T>()) as u64;
|
||||||
|
device.new_buffer_with_data(ptr, size, options)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn device() -> Device {
|
||||||
|
Device::system_default().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn approx(v: Vec<f32>, digits: i32) -> Vec<f32> {
|
||||||
|
let b = 10f32.powi(digits);
|
||||||
|
v.iter().map(|t| f32::round(t * b) / b).collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn approx_f16(v: Vec<f16>, digits: i32) -> Vec<f32> {
|
||||||
|
let b = 10f32.powi(digits);
|
||||||
|
v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
|
||||||
|
let device = device();
|
||||||
|
let kernels = Kernels::new();
|
||||||
|
let command_queue = device.new_command_queue();
|
||||||
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
|
let input = new_buffer(&device, v);
|
||||||
|
let mut output = new_buffer(&device, v);
|
||||||
|
call_unary_contiguous(
|
||||||
|
&device,
|
||||||
|
command_buffer,
|
||||||
|
&kernels,
|
||||||
|
name,
|
||||||
|
v.len(),
|
||||||
|
&input,
|
||||||
|
&mut output,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
output.read_to_vec::<T>(v.len())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> {
|
||||||
|
let device = device();
|
||||||
|
let kernels = Kernels::new();
|
||||||
|
let command_queue = device.new_command_queue();
|
||||||
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
|
let options = MTLResourceOptions::StorageModeManaged;
|
||||||
|
let left = new_buffer(&device, x);
|
||||||
|
let right = new_buffer(&device, y);
|
||||||
|
let mut output = device.new_buffer(std::mem::size_of_val(x) as u64, options);
|
||||||
|
call_binary_contiguous(
|
||||||
|
&device,
|
||||||
|
command_buffer,
|
||||||
|
&kernels,
|
||||||
|
name,
|
||||||
|
x.len(),
|
||||||
|
&left,
|
||||||
|
&right,
|
||||||
|
&mut output,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
output.read_to_vec::<T>(x.len())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_strided<T: Clone>(
|
||||||
|
v: &[T],
|
||||||
|
kernel: unary::strided::Kernel,
|
||||||
|
shape: &[usize],
|
||||||
|
strides: &[usize],
|
||||||
|
offset: usize,
|
||||||
|
) -> Vec<T> {
|
||||||
|
let device = device();
|
||||||
|
let command_queue = device.new_command_queue();
|
||||||
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
|
let input = new_buffer(&device, v);
|
||||||
|
let mut output = new_buffer(&device, v);
|
||||||
|
let kernels = Kernels::new();
|
||||||
|
call_unary_strided(
|
||||||
|
&device,
|
||||||
|
command_buffer,
|
||||||
|
&kernels,
|
||||||
|
kernel,
|
||||||
|
shape,
|
||||||
|
&input,
|
||||||
|
strides,
|
||||||
|
offset,
|
||||||
|
&mut output,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
output.read_to_vec::<T>(v.len())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
fn cos_f32() {
    // Three distinct values: device result and CPU reference must both round
    // to the same four-digit values.
    let input = vec![1.0f32, 2.0, 3.0];
    let want = vec![0.5403f32, -0.4161, -0.99];
    let on_device = run(&input, unary::contiguous::cos::FLOAT);
    let on_cpu: Vec<f32> = input.iter().map(|x| x.cos()).collect();
    assert_eq!(approx(on_device, 4), want);
    assert_eq!(approx(on_cpu, 4), want);

    // Same check on a 10_000-element input.
    let input = vec![1.0f32; 10_000];
    let want = vec![0.5403f32; 10_000];
    let on_device = run(&input, unary::contiguous::cos::FLOAT);
    let on_cpu: Vec<f32> = input.iter().map(|x| x.cos()).collect();
    assert_eq!(approx(on_device, 4), want);
    assert_eq!(approx(on_cpu, 4), want);
}
|
||||||
|
|
||||||
|
#[test]
fn cos_f32_strided() {
    /// Strided cosine on the device (input offset 0), rounded to four digits.
    fn cos_strided(data: &[f32], shape: &[usize], strides: &[usize]) -> Vec<f32> {
        approx(
            run_strided(data, unary::strided::cos::FLOAT, shape, strides, 0),
            4,
        )
    }

    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let in_order = vec![0.5403f32, -0.4161, -0.99, -0.6536, 0.2837, 0.9602];

    // CPU reference for the six-element input.
    let expected: Vec<f32> = v.iter().map(|x| x.cos()).collect();
    assert_eq!(approx(expected, 4), in_order);

    // Shape = [6], strides = [1];
    assert_eq!(cos_strided(&v, &[6], &[1]), in_order);

    // Contiguous
    assert_eq!(cos_strided(&v, &[3, 2], &[2, 1]), in_order);

    // Transposed: the output walks the input with strides [1, 3].
    assert_eq!(
        cos_strided(&v, &[3, 2], &[1, 3]),
        vec![0.5403f32, -0.6536, -0.4161, 0.2837, -0.99, 0.9602]
    );

    // Very large. The row stride of a contiguous [2, 5_000] layout is 5_000;
    // the previous value of 2 made the two rows overlap, which went unnoticed
    // only because every input element is identical.
    let v = vec![1.0f32; 10_000];
    let expected: Vec<f32> = v.iter().map(|x| x.cos()).collect();
    assert_eq!(
        cos_strided(&v, &[2, 5_000], &[5_000, 1]),
        vec![0.5403f32; 10_000]
    );
    assert_eq!(approx(expected, 4), vec![0.5403f32; 10_000]);
}
|
||||||
|
|
||||||
|
#[test]
fn cos_strided_random() {
    // The range index is not needed; `_` avoids an unused-variable warning.
    let v: Vec<_> = (0..10_000).map(|_| rand::random::<f32>()).collect();
    let shape = vec![5_000, 2];
    let strides = vec![1, 5_000];
    let offset = 0;
    let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
    let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();

    // (index in the strided output, index of the matching input element):
    // output element (row, col) reads input[row * 1 + col * 5_000].
    let spot_checks = [(0, 0), (1, 5_000), (2, 1), (3, 5_001), (5_000, 2_500)];
    for (out_idx, src_idx) in spot_checks {
        assert_eq!(
            approx(vec![results[out_idx]], 4),
            approx(vec![expected[src_idx]], 4)
        );
    }
}
|
||||||
|
|
||||||
|
#[test]
fn binary_add_f32() {
    let lhs = vec![1.0f32, 2.0, 3.0];
    let rhs = vec![2.0f32, 3.1, 4.2];
    let want = vec![3.0f32, 5.1, 7.2];

    // Device result.
    let on_device = run_binary(&lhs, &rhs, binary::contiguous::add::FLOAT);
    // CPU reference computed element-wise.
    let on_cpu: Vec<f32> = lhs.iter().zip(&rhs).map(|(a, b)| a + b).collect();

    assert_eq!(approx(on_device, 4), want);
    assert_eq!(approx(on_cpu, 4), want);
}
|
||||||
|
|
||||||
|
fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
|
||||||
|
let device = device();
|
||||||
|
let kernels = Kernels::new();
|
||||||
|
let command_queue = device.new_command_queue();
|
||||||
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
|
let input = new_buffer(&device, v);
|
||||||
|
let mut output = new_buffer(&device, v);
|
||||||
|
|
||||||
|
call_cast_contiguous(
|
||||||
|
&device,
|
||||||
|
command_buffer,
|
||||||
|
&kernels,
|
||||||
|
name,
|
||||||
|
v.len(),
|
||||||
|
&input,
|
||||||
|
&mut output,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
output.read_to_vec::<U>(v.len())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
fn cast_u32_f32() {
    let v = vec![1u32, 2, 3];
    let results = cast(&v, "cast_u32_f32");
    let expected: Vec<_> = v.iter().map(|&v| v as f32).collect();
    assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]);
    assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]);

    // Large input. This section used to re-run the cosine kernel (copied from
    // `cos_f32`), which did not exercise the cast at all.
    let v = vec![1u32; 10_000];
    let results = cast(&v, "cast_u32_f32");
    let expected: Vec<_> = v.iter().map(|&v| v as f32).collect();
    assert_eq!(approx(results, 4), vec![1.0f32; 10_000]);
    assert_eq!(approx(expected, 4), vec![1.0f32; 10_000]);
}
|
||||||
|
|
||||||
|
fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
|
||||||
|
let device = device();
|
||||||
|
let kernels = Kernels::new();
|
||||||
|
let command_queue = device.new_command_queue();
|
||||||
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
|
|
||||||
|
let input = new_buffer(&device, v);
|
||||||
|
let mut output = new_buffer(&device, v);
|
||||||
|
|
||||||
|
let size = v.len();
|
||||||
|
|
||||||
|
call_affine(
|
||||||
|
&device,
|
||||||
|
command_buffer,
|
||||||
|
&kernels,
|
||||||
|
"affine_float",
|
||||||
|
size,
|
||||||
|
&input,
|
||||||
|
&mut output,
|
||||||
|
mul as f32,
|
||||||
|
add as f32,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
|
||||||
|
output.read_to_vec::<T>(v.len())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_affine_strided<T: Clone>(
|
||||||
|
v: &[T],
|
||||||
|
shape: &[usize],
|
||||||
|
strides: &[usize],
|
||||||
|
mul: f64,
|
||||||
|
add: f64,
|
||||||
|
) -> Vec<T> {
|
||||||
|
let device = device();
|
||||||
|
let kernels = Kernels::new();
|
||||||
|
let command_queue = device.new_command_queue();
|
||||||
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
|
|
||||||
|
let input = new_buffer(&device, v);
|
||||||
|
let mut output = new_buffer(&device, v);
|
||||||
|
|
||||||
|
let size = v.len();
|
||||||
|
|
||||||
|
call_affine_strided(
|
||||||
|
&device,
|
||||||
|
command_buffer,
|
||||||
|
&kernels,
|
||||||
|
"affine_float",
|
||||||
|
shape,
|
||||||
|
&input,
|
||||||
|
strides,
|
||||||
|
0,
|
||||||
|
&mut output,
|
||||||
|
mul as f32,
|
||||||
|
add as f32,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
|
||||||
|
output.read_to_vec::<T>(v.len())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
fn affine() {
    // Eight distinct values: y = 1.5 * x + 1.1.
    let xs = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let ys = run_affine(&xs, 1.5, 1.1);
    assert_eq!(ys, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]);

    // Same transform on a 40_000-element constant input.
    let xs = [1.0f32; 40_000];
    let ys = run_affine(&xs, 1.5, 1.1);
    assert_eq!(ys, vec![2.6; 40_000]);
}
|
||||||
|
|
||||||
|
// #[test]
|
||||||
|
// fn affine_strided() {
|
||||||
|
// let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
|
||||||
|
// let mul = 1.5;
|
||||||
|
// let add = 1.1;
|
||||||
|
// let result = run_affine_(&input, mul, add);
|
||||||
|
// assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]);
|
||||||
|
|
||||||
|
// }
|
||||||
|
|
||||||
|
#[test]
fn index_select() {
    // Ten values viewed as [5, 2]: select rows 0, 4 and 2 along dim 0.
    let table = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
    let dims = [5, 2];
    let picks = [0u32, 4, 2];
    let selected = run_index_select(&table, &dims, &picks, 0);
    assert_eq!(selected, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]);

    // Same values viewed as [2, 5]: select rows 0, 1 and 0 along dim 0.
    let table = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
    let dims = [2, 5];
    let picks = [0u32, 1, 0];
    let selected = run_index_select(&table, &dims, &picks, 0);
    assert_eq!(
        selected,
        vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
    );
}
|
||||||
|
|
||||||
|
#[test]
fn index_select_dim1() {
    // Ten values viewed as [5, 2]: for every row select columns 0, 1 and 0.
    let table = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
    let dims = [5, 2];
    let picks = [0u32, 1, 0];
    let selected = run_index_select(&table, &dims, &picks, 1);
    assert_eq!(
        selected,
        vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0]
    );
}
|
||||||
|
|
||||||
|
fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
|
||||||
|
embeddings: &[T],
|
||||||
|
shape: &[usize],
|
||||||
|
ids: &[I],
|
||||||
|
dim: usize,
|
||||||
|
) -> Vec<T> {
|
||||||
|
let device = Device::system_default().expect("no device found");
|
||||||
|
|
||||||
|
let command_queue = device.new_command_queue();
|
||||||
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
|
let embeddings_buffer = new_buffer(&device, &embeddings);
|
||||||
|
let ids_buffer = new_buffer(&device, &ids);
|
||||||
|
|
||||||
|
let left_size: usize = shape[..dim].iter().product();
|
||||||
|
let right_size: usize = shape[dim + 1..].iter().product();
|
||||||
|
let dst_el = ids.len() * left_size * right_size;
|
||||||
|
let mut dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
|
||||||
|
|
||||||
|
let kernels = Kernels::new();
|
||||||
|
call_index_select(
|
||||||
|
&device,
|
||||||
|
&command_buffer,
|
||||||
|
&kernels,
|
||||||
|
"is_u32_f32",
|
||||||
|
shape,
|
||||||
|
ids.len(),
|
||||||
|
dim,
|
||||||
|
&embeddings_buffer,
|
||||||
|
&ids_buffer,
|
||||||
|
&mut dst_buffer,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
|
||||||
|
dst_buffer.read_to_vec::<T>(dst_el)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
fn index_add() {
    let device = Device::system_default().expect("no device found");

    // Compile the indexing source with fast-math enabled and build a pipeline
    // for the `ia_u32_f32` function.
    let compile_options = CompileOptions::new();
    compile_options.set_fast_math_enabled(true);
    let library = device
        .new_library_with_source(INDEXING, &compile_options)
        .unwrap();
    let kernel_fn = library.get_function("ia_u32_f32", None).unwrap();
    let pipeline = device
        .new_compute_pipeline_state_with_function(&kernel_fn)
        .unwrap();

    // Host-side data and the scalar arguments handed to the kernel.
    let src = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
    let dst = [1.0f32; 15];
    let ids = [0u32, 4, 2];
    let ids_dim_size = ids.len() as u32;
    let dst_dim_size: u32 = 15;
    let left_size: u32 = 3;
    let right_size: u32 = 3;

    let queue = device.new_command_queue();
    let cmd = queue.new_command_buffer();
    let encoder = cmd.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline);

    let ids_buffer = new_buffer(&device, &ids);
    let src_buffer = new_buffer(&device, &src);
    let dst_buffer = new_buffer(&device, &dst);

    set_params!(
        encoder,
        (
            &ids_buffer,
            &src_buffer,
            &dst_buffer,
            ids_dim_size,
            left_size,
            dst_dim_size,
            right_size
        )
    );

    // One threadgroup per destination element, each with the pipeline's
    // maximum thread count.
    let threadgroup_count = MTLSize {
        width: dst.len() as NSUInteger,
        height: 1,
        depth: 1,
    };
    let threads_per_threadgroup = MTLSize {
        width: pipeline.max_total_threads_per_threadgroup(),
        height: 1,
        depth: 1,
    };
    encoder.dispatch_thread_groups(threadgroup_count, threads_per_threadgroup);
    encoder.end_encoding();
    cmd.commit();
    cmd.wait_until_completed();

    let want = vec![
        2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
    ];
    let got = dst_buffer.read_to_vec::<f32>(dst.len());
    assert_eq!(got, want);
}
|
||||||
|
|
||||||
|
#[test]
fn cos_f16() {
    let input: Vec<f16> = [1.0f32, 2.0, 3.0].iter().map(|x| f16::from_f32(*x)).collect();

    let on_device = run(&input, unary::contiguous::cos::HALF);
    // CPU reference: cosine evaluated in f32, then narrowed back to f16.
    let on_cpu: Vec<f16> = input
        .iter()
        .map(|x| f16::from_f32(x.to_f32().cos()))
        .collect();

    // The two paths are pinned to their own (slightly different) roundings.
    assert_eq!(approx_f16(on_device, 4), vec![0.54, -0.4165, -0.9902]);
    assert_eq!(approx_f16(on_cpu, 4), vec![0.5405, -0.4163, -0.9902]);
}
|
||||||
|
|
||||||
|
fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
|
||||||
|
let device = device();
|
||||||
|
let kernels = Kernels::new();
|
||||||
|
let command_queue = device.new_command_queue();
|
||||||
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
|
let input = new_buffer(&device, v);
|
||||||
|
|
||||||
|
let options = MTLResourceOptions::StorageModeManaged;
|
||||||
|
let mut output =
|
||||||
|
device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
|
||||||
|
call_reduce_contiguous(
|
||||||
|
&device,
|
||||||
|
command_buffer,
|
||||||
|
&kernels,
|
||||||
|
name,
|
||||||
|
v.len(),
|
||||||
|
out_length,
|
||||||
|
&input,
|
||||||
|
&mut output,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
|
||||||
|
output.read_to_vec::<T>(out_length)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_softmax<T: Clone + std::fmt::Debug>(
|
||||||
|
v: &[T],
|
||||||
|
last_dim: usize,
|
||||||
|
name: &'static str,
|
||||||
|
) -> Vec<T> {
|
||||||
|
let device = device();
|
||||||
|
let kernels = Kernels::new();
|
||||||
|
let command_queue = device.new_command_queue();
|
||||||
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
|
let input = new_buffer(&device, v);
|
||||||
|
let mut output = new_buffer(&device, v);
|
||||||
|
call_last_softmax(
|
||||||
|
&device,
|
||||||
|
command_buffer,
|
||||||
|
&kernels,
|
||||||
|
name,
|
||||||
|
v.len(),
|
||||||
|
last_dim,
|
||||||
|
&input,
|
||||||
|
&mut output,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
|
||||||
|
output.read_to_vec::<T>(v.len())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
fn reduce_sum() {
    // All six values collapse into a single sum.
    let input = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let sums = run_reduce(&input, 1, "fast_sum_float");
    assert_eq!(approx(sums, 4), vec![21.0]);
}
|
||||||
|
|
||||||
|
#[test]
fn reduce_sum2() {
    // Two outputs: the sum of the first three and of the last three values.
    let input = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let sums = run_reduce(&input, 2, "fast_sum_float");
    assert_eq!(approx(sums, 4), vec![6.0, 15.0]);
}
|
||||||
|
|
||||||
|
#[test]
fn softmax() {
    // One row of six values.
    let input = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let probs = run_softmax(&input, 6, "softmax_float");
    assert_eq!(
        approx(probs, 4),
        vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
    );

    // The same row shifted down by one is expected to give the same result.
    let input = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0];
    let probs = run_softmax(&input, 6, "softmax_float");
    assert_eq!(
        approx(probs, 4),
        vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
    );

    // Two rows of three values each.
    let input = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let probs = run_softmax(&input, 3, "softmax_float");
    assert_eq!(
        approx(probs, 4),
        vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652]
    );
}
|
||||||
|
|
||||||
|
fn run_where_cond<I: Clone, T: Clone>(
|
||||||
|
shape: &[usize],
|
||||||
|
cond: &[I],
|
||||||
|
(cond_stride, cond_offset): (Vec<usize>, usize),
|
||||||
|
left_true: &[T],
|
||||||
|
(left_stride, left_offset): (Vec<usize>, usize),
|
||||||
|
right_false: &[T],
|
||||||
|
(_right_stride, _right_offset): (Vec<usize>, usize),
|
||||||
|
name: &'static str,
|
||||||
|
) -> Vec<T> {
|
||||||
|
let device = device();
|
||||||
|
let kernels = Kernels::new();
|
||||||
|
let command_queue = device.new_command_queue();
|
||||||
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
|
let options = MTLResourceOptions::StorageModeManaged;
|
||||||
|
|
||||||
|
let length = cond.len();
|
||||||
|
let cond = device.new_buffer_with_data(
|
||||||
|
cond.as_ptr() as *const core::ffi::c_void,
|
||||||
|
std::mem::size_of_val(cond) as u64,
|
||||||
|
options,
|
||||||
|
);
|
||||||
|
let left = device.new_buffer_with_data(
|
||||||
|
left_true.as_ptr() as *const core::ffi::c_void,
|
||||||
|
(length * core::mem::size_of::<T>()) as u64,
|
||||||
|
options,
|
||||||
|
);
|
||||||
|
let right = device.new_buffer_with_data(
|
||||||
|
right_false.as_ptr() as *const core::ffi::c_void,
|
||||||
|
(length * core::mem::size_of::<T>()) as u64,
|
||||||
|
options,
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
|
||||||
|
call_where_cond_strided(
|
||||||
|
&device,
|
||||||
|
command_buffer,
|
||||||
|
&kernels,
|
||||||
|
name,
|
||||||
|
shape,
|
||||||
|
&cond,
|
||||||
|
(&cond_stride, cond_offset),
|
||||||
|
&left,
|
||||||
|
(&left_stride, left_offset),
|
||||||
|
&right,
|
||||||
|
(&cond_stride, cond_offset),
|
||||||
|
&mut output,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
|
||||||
|
output.read_to_vec::<T>(length)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
fn where_cond() {
    // All three operands are one-dimensional, with stride 1 and offset 0.
    let dims = vec![6];
    let mask = vec![0u8, 1, 0, 0, 1, 1];
    let on_true = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let on_false = vec![-1.0f32, -2.0, -3.0, -4.0, -5.0, -6.0];

    let selected = run_where_cond(
        &dims,
        &mask,
        (vec![1], 0),
        &on_true,
        (vec![1], 0),
        &on_false,
        (vec![1], 0),
        "where_u8_f32",
    );

    // A non-zero mask entry picks from `on_true`, zero picks from `on_false`.
    assert_eq!(approx(selected, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
}
|
||||||
|
}
|
||||||
|
@ -1,8 +1,6 @@
|
|||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
#define MAX(x, y) ((x) > (y) ? (x) : (y))
|
|
||||||
|
|
||||||
METAL_FUNC uint get_strided_index(
|
METAL_FUNC uint get_strided_index(
|
||||||
uint idx,
|
uint idx,
|
||||||
constant size_t &num_dims,
|
constant size_t &num_dims,
|
||||||
@ -18,18 +16,18 @@ METAL_FUNC uint get_strided_index(
|
|||||||
return strided_i;
|
return strided_i;
|
||||||
}
|
}
|
||||||
|
|
||||||
constant int THREADGROUP_SIZE = 1024;
|
constant int THREADGROUP_SIZE = 256;
|
||||||
|
|
||||||
# define REDUCE(FN, NAME, T) \
|
# define REDUCE(FN, NAME, TYPENAME) \
|
||||||
kernel void NAME( \
|
kernel void NAME( \
|
||||||
constant size_t &src_numel, \
|
constant size_t &src_numel, \
|
||||||
constant size_t &el_to_sum_per_block, \
|
constant size_t &el_to_sum_per_block, \
|
||||||
device const T *src, \
|
device const TYPENAME *src, \
|
||||||
device T *dst, \
|
device TYPENAME *dst, \
|
||||||
uint id [[ thread_position_in_grid ]], \
|
uint id [[ thread_position_in_grid ]], \
|
||||||
uint tid [[ thread_index_in_threadgroup ]], \
|
uint tid [[ thread_index_in_threadgroup ]], \
|
||||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||||
uint block_dim [[ threads_per_threadgroup ]] \
|
uint blockDim [[ threads_per_threadgroup ]] \
|
||||||
) { \
|
) { \
|
||||||
\
|
\
|
||||||
threadgroup float shared_memory[THREADGROUP_SIZE]; \
|
threadgroup float shared_memory[THREADGROUP_SIZE]; \
|
||||||
@ -47,10 +45,10 @@ kernel void NAME( \
|
|||||||
// TODO: Fast version for the contiguous case. \
|
// TODO: Fast version for the contiguous case. \
|
||||||
// size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
// size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||||
*/ \
|
*/ \
|
||||||
T x = shared_memory[tid]; \
|
TYPENAME x = shared_memory[tid]; \
|
||||||
T y = src[idx]; \
|
TYPENAME y = src[idx]; \
|
||||||
shared_memory[tid] = FN; \
|
shared_memory[tid] = FN; \
|
||||||
idx += block_dim; \
|
idx += blockDim; \
|
||||||
} \
|
} \
|
||||||
\
|
\
|
||||||
threadgroup_barrier(mem_flags::mem_none); \
|
threadgroup_barrier(mem_flags::mem_none); \
|
||||||
@ -58,10 +56,10 @@ kernel void NAME( \
|
|||||||
/* \
|
/* \
|
||||||
// reduction in shared memory \
|
// reduction in shared memory \
|
||||||
*/ \
|
*/ \
|
||||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
for (uint s = blockDim / 2; s > 0; s >>= 1) { \
|
||||||
if (tid < s) { \
|
if (tid < s) { \
|
||||||
T x = shared_memory[tid]; \
|
TYPENAME x = shared_memory[tid]; \
|
||||||
T y = shared_memory[tid + s]; \
|
TYPENAME y = shared_memory[tid + s]; \
|
||||||
shared_memory[tid] = FN; \
|
shared_memory[tid] = FN; \
|
||||||
} \
|
} \
|
||||||
threadgroup_barrier(mem_flags::mem_none); \
|
threadgroup_barrier(mem_flags::mem_none); \
|
||||||
@ -70,74 +68,72 @@ kernel void NAME( \
|
|||||||
dst[dst_id] = shared_memory[0]; \
|
dst[dst_id] = shared_memory[0]; \
|
||||||
} \
|
} \
|
||||||
|
|
||||||
|
// Softmax over consecutive runs of `el_to_sum_per_block` floats. Each
// threadgroup handles the run selected by its position in the grid and uses
// `shared_memory` as scratch space, first for the running maximum and then
// for the running sum of exponentials.
//
// NOTE(review): the tree reductions below index `shared_memory[tid + s]`, so
// they presumably rely on the threadgroup size being a power of two that does
// not exceed THREADGROUP_SIZE — confirm against the dispatch code.
kernel void softmax_float(
    constant size_t &src_numel,
    constant size_t &el_to_sum_per_block,
    device const float *src,
    device float *dst,
    uint id [[ thread_position_in_grid ]],
    uint tid [[ thread_index_in_threadgroup ]],
    uint dst_id [[ threadgroup_position_in_grid ]],
    uint blockDim [[ threads_per_threadgroup ]]
) {

    threadgroup float shared_memory[THREADGROUP_SIZE];

    shared_memory[tid] = -INFINITY;
    // Elements handled by this block range from dst_id * el_to_sum_per_block
    // to (dst_id + 1) * el_to_sum_per_block.
    size_t start_idx = dst_id * el_to_sum_per_block;
    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
    size_t idx = start_idx + tid;

    // Pass 1: per-thread running maximum.
    while (idx < stop_idx) {
        // TODO: Fast version for the contiguous case.
        shared_memory[tid] = max(shared_memory[tid], src[idx]);
        idx += blockDim;
    }

    // The barriers must order threadgroup-memory accesses, so they use
    // mem_threadgroup; mem_none gives no such guarantee.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Tree reduction of the per-thread maxima in shared memory.
    for (uint s = blockDim / 2; s > 0; s >>= 1) {
        if (tid < s) {
            shared_memory[tid] = max(shared_memory[tid], shared_memory[tid + s]);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Named `max_val` so the builtin `max` is not shadowed.
    const float max_val = shared_memory[0];

    // Every thread must have read slot 0 before thread 0 resets it below.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    shared_memory[tid] = 0;

    // Pass 2: write exp(x - max) and accumulate the per-thread sum.
    idx = start_idx + tid;
    while (idx < stop_idx) {
        // TODO: Fast version for the contiguous case.
        const float val = exp(src[idx] - max_val);
        dst[idx] = val;
        shared_memory[tid] += val;
        idx += blockDim;
    }

    // All partial sums must be visible before they are combined.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Tree reduction of the per-thread sums in shared memory.
    for (uint s = blockDim / 2; s > 0; s >>= 1) {
        if (tid < s) {
            shared_memory[tid] += shared_memory[tid + s];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Pass 3: normalise. Each thread rescales only the elements it wrote in
    // pass 2 (same start index and step).
    const float inv_acc = 1/shared_memory[0];
    idx = start_idx + tid;
    while (idx < stop_idx) {
        dst[idx] *= inv_acc;
        idx += blockDim;
    }
}
|
||||||
|
|
||||||
|
|
||||||
REDUCE(x + y, fast_sum_float, float)
|
REDUCE(x + y, fast_sum_float, float)
|
||||||
REDUCE(x * y, fast_mul_float, float)
|
REDUCE(x * y, fast_mul_float, float)
|
||||||
REDUCE(max(x, y), fast_max_float, float)
|
REDUCE(max(x, y), fast_max_float, float)
|
||||||
|
|
||||||
// Generates a softmax kernel NAME for element type T. Each threadgroup handles
// one run of `el_to_sum_per_block` consecutive elements and uses
// `shared_memory` (float) as scratch space, first for the running maximum and
// then for the running sum of exponentials.
//
// NOTE(review): the tree reductions index `shared_memory[tid + s]`, so they
// presumably rely on the threadgroup size being a power of two that does not
// exceed THREADGROUP_SIZE — confirm against the dispatch code.
#define SOFTMAX(NAME, T)                                                          \
kernel void NAME(                                                                 \
    constant size_t &src_numel,                                                   \
    constant size_t &el_to_sum_per_block,                                         \
    device const T *src,                                                          \
    device T *dst,                                                                \
                                                                                  \
    uint id [[ thread_position_in_grid ]],                                        \
    uint tid [[ thread_index_in_threadgroup ]],                                   \
    uint dst_id [[ threadgroup_position_in_grid ]],                               \
    uint block_dim [[ threads_per_threadgroup ]]                                  \
) {                                                                               \
    threadgroup float shared_memory[THREADGROUP_SIZE];                            \
    shared_memory[tid] = -INFINITY;                                               \
    size_t start_idx = dst_id * el_to_sum_per_block;                              \
    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);            \
    size_t idx = start_idx + tid;                                                 \
                                                                                  \
    threadgroup_barrier(mem_flags::mem_threadgroup);                              \
                                                                                  \
    /* Pass 1: per-thread running maximum. */                                     \
    while (idx < stop_idx) {                                                      \
        shared_memory[tid] = MAX(shared_memory[tid], src[idx]);                   \
        idx += block_dim;                                                         \
    }                                                                             \
                                                                                  \
    threadgroup_barrier(mem_flags::mem_threadgroup);                              \
                                                                                  \
    /* Tree reduction of the maxima. The barrier sits inside the loop so each */  \
    /* step sees the results of the previous one.                             */  \
    for (uint s = block_dim / 2; s > 0; s >>= 1) {                                \
        if (tid < s) {                                                            \
            shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
        }                                                                         \
        threadgroup_barrier(mem_flags::mem_threadgroup);                          \
    }                                                                             \
                                                                                  \
    float _max = shared_memory[0];                                                \
                                                                                  \
    /* Every thread must have read slot 0 before thread 0 resets it below. */     \
    threadgroup_barrier(mem_flags::mem_threadgroup);                              \
                                                                                  \
    shared_memory[tid] = 0;                                                       \
                                                                                  \
    /* Pass 2: write exp(x - max) and accumulate the per-thread sum. */           \
    idx = start_idx + tid;                                                        \
    while (idx < stop_idx) {                                                      \
        const T val = T(exp(src[idx] - _max));                                    \
        dst[idx] = val;                                                           \
        shared_memory[tid] += val;                                                \
        idx += block_dim;                                                         \
    }                                                                             \
                                                                                  \
    /* All partial sums must be visible before they are combined. */              \
    threadgroup_barrier(mem_flags::mem_threadgroup);                              \
                                                                                  \
    for (uint s = block_dim / 2; s > 0; s >>= 1) {                                \
        if (tid < s) {                                                            \
            shared_memory[tid] += shared_memory[tid + s];                         \
        }                                                                         \
        threadgroup_barrier(mem_flags::mem_threadgroup);                          \
    }                                                                             \
                                                                                  \
    /* Pass 3: normalise; each thread rescales only the elements it wrote. */     \
    const T inv_acc = T(1/shared_memory[0]);                                      \
    idx = start_idx + tid;                                                        \
    while (idx < stop_idx) {                                                      \
        dst[idx] *= inv_acc;                                                      \
        idx += block_dim;                                                         \
    }                                                                             \
}

SOFTMAX(softmax_float, float)
SOFTMAX(softmax_half, half)
#if __METAL_VERSION__ >= 310
SOFTMAX(softmax_bfloat, bfloat)
#endif
|
|
||||||
|
@ -32,9 +32,6 @@ kernel void FN_NAME( \
|
|||||||
device TYPENAME *out ,\
|
device TYPENAME *out ,\
|
||||||
uint i [[ thread_position_in_grid ]] \
|
uint i [[ thread_position_in_grid ]] \
|
||||||
) { \
|
) { \
|
||||||
if (i >= numel){ \
|
|
||||||
return; \
|
|
||||||
} \
|
|
||||||
uint strided_i = get_strided_index(i, num_dims, dims, strides); \
|
uint strided_i = get_strided_index(i, num_dims, dims, strides); \
|
||||||
uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \
|
uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \
|
||||||
uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \
|
uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \
|
||||||
|
@ -1,746 +0,0 @@
|
|||||||
use super::*;
|
|
||||||
use half::{bf16, f16};
|
|
||||||
use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger};
|
|
||||||
|
|
||||||
/// Creates a Metal buffer on `device` initialised from the contents of `data`.
///
/// The buffer uses `StorageModeManaged` and its byte length is exactly the
/// byte length of the slice.
fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
    let options = MTLResourceOptions::StorageModeManaged;
    let ptr = data.as_ptr() as *const core::ffi::c_void;
    // `size_of_val` yields the byte length of the whole slice directly,
    // instead of multiplying the element count by the element size by hand.
    let size = std::mem::size_of_val(data) as u64;
    device.new_buffer_with_data(ptr, size, options)
}
|
|
||||||
|
|
||||||
fn device() -> Device {
|
|
||||||
Device::system_default().unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Rounds every element of `v` to `digits` decimal places by scaling with
/// `10^digits`, rounding to the nearest integer and scaling back.
fn approx(v: Vec<f32>, digits: i32) -> Vec<f32> {
    let factor = 10f32.powi(digits);
    let mut rounded = Vec::with_capacity(v.len());
    for value in v {
        rounded.push((value * factor).round() / factor);
    }
    rounded
}
|
|
||||||
|
|
||||||
fn approx_f16(v: Vec<f16>, digits: i32) -> Vec<f32> {
|
|
||||||
let b = 10f32.powi(digits);
|
|
||||||
v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn approx_bf16(v: Vec<bf16>, digits: i32) -> Vec<f32> {
|
|
||||||
let b = 10f32.powi(digits);
|
|
||||||
v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
|
|
||||||
let device = device();
|
|
||||||
let kernels = Kernels::new();
|
|
||||||
let command_queue = device.new_command_queue();
|
|
||||||
let command_buffer = command_queue.new_command_buffer();
|
|
||||||
let input = new_buffer(&device, v);
|
|
||||||
let output = new_buffer(&device, v);
|
|
||||||
call_unary_contiguous(
|
|
||||||
&device,
|
|
||||||
command_buffer,
|
|
||||||
&kernels,
|
|
||||||
name,
|
|
||||||
v.len(),
|
|
||||||
&input,
|
|
||||||
&output,
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
command_buffer.commit();
|
|
||||||
command_buffer.wait_until_completed();
|
|
||||||
output.read_to_vec::<T>(v.len())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> {
|
|
||||||
let device = device();
|
|
||||||
let kernels = Kernels::new();
|
|
||||||
let command_queue = device.new_command_queue();
|
|
||||||
let command_buffer = command_queue.new_command_buffer();
|
|
||||||
let options = MTLResourceOptions::StorageModeManaged;
|
|
||||||
let left = new_buffer(&device, x);
|
|
||||||
let right = new_buffer(&device, y);
|
|
||||||
let output = device.new_buffer(std::mem::size_of_val(x) as u64, options);
|
|
||||||
call_binary_contiguous(
|
|
||||||
&device,
|
|
||||||
command_buffer,
|
|
||||||
&kernels,
|
|
||||||
name,
|
|
||||||
x.len(),
|
|
||||||
&left,
|
|
||||||
&right,
|
|
||||||
&output,
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
command_buffer.commit();
|
|
||||||
command_buffer.wait_until_completed();
|
|
||||||
output.read_to_vec::<T>(x.len())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn run_strided<T: Clone>(
|
|
||||||
v: &[T],
|
|
||||||
kernel: unary::strided::Kernel,
|
|
||||||
shape: &[usize],
|
|
||||||
strides: &[usize],
|
|
||||||
offset: usize,
|
|
||||||
) -> Vec<T> {
|
|
||||||
let device = device();
|
|
||||||
let command_queue = device.new_command_queue();
|
|
||||||
let command_buffer = command_queue.new_command_buffer();
|
|
||||||
let input = new_buffer(&device, v);
|
|
||||||
let output = new_buffer(&device, v);
|
|
||||||
let kernels = Kernels::new();
|
|
||||||
call_unary_strided(
|
|
||||||
&device,
|
|
||||||
command_buffer,
|
|
||||||
&kernels,
|
|
||||||
kernel,
|
|
||||||
shape,
|
|
||||||
&input,
|
|
||||||
strides,
|
|
||||||
offset,
|
|
||||||
&output,
|
|
||||||
0,
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
command_buffer.commit();
|
|
||||||
command_buffer.wait_until_completed();
|
|
||||||
output.read_to_vec::<T>(v.len())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn cos_f32() {
|
|
||||||
let v = vec![1.0f32, 2.0, 3.0];
|
|
||||||
let results = run(&v, unary::contiguous::cos::FLOAT);
|
|
||||||
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
|
|
||||||
assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99]);
|
|
||||||
assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]);
|
|
||||||
|
|
||||||
let v = vec![1.0f32; 10_000];
|
|
||||||
let results = run(&v, unary::contiguous::cos::FLOAT);
|
|
||||||
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
|
|
||||||
assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
|
|
||||||
assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn cos_f32_strided() {
|
|
||||||
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
|
||||||
let shape = vec![6];
|
|
||||||
let strides = vec![1];
|
|
||||||
let offset = 0;
|
|
||||||
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
|
|
||||||
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
|
|
||||||
assert_eq!(
|
|
||||||
approx(results, 4),
|
|
||||||
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
approx(expected, 4),
|
|
||||||
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
|
|
||||||
);
|
|
||||||
|
|
||||||
// Contiguous
|
|
||||||
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
|
||||||
let shape = vec![3, 2];
|
|
||||||
let strides = vec![2, 1];
|
|
||||||
let offset = 0;
|
|
||||||
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
|
|
||||||
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
|
|
||||||
assert_eq!(
|
|
||||||
approx(results, 4),
|
|
||||||
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
approx(expected, 4),
|
|
||||||
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
|
|
||||||
);
|
|
||||||
|
|
||||||
// Transposed
|
|
||||||
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
|
||||||
let shape = vec![3, 2];
|
|
||||||
let strides = vec![1, 3];
|
|
||||||
let offset = 0;
|
|
||||||
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
|
|
||||||
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
|
|
||||||
assert_eq!(
|
|
||||||
approx(results, 4),
|
|
||||||
vec![0.5403, -0.6536, -0.4161, 0.2837, -0.99, 0.9602]
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
approx(expected, 4),
|
|
||||||
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
|
|
||||||
);
|
|
||||||
|
|
||||||
// Very large
|
|
||||||
let v = vec![1.0f32; 10_000];
|
|
||||||
let shape = vec![2, 5_000];
|
|
||||||
let strides = vec![2, 1];
|
|
||||||
let offset = 0;
|
|
||||||
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
|
|
||||||
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
|
|
||||||
assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
|
|
||||||
assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn cos_strided_random() {
|
|
||||||
let v: Vec<_> = (0..10_000).map(|_| rand::random::<f32>()).collect();
|
|
||||||
let shape = vec![5_000, 2];
|
|
||||||
let strides = vec![1, 5_000];
|
|
||||||
let offset = 0;
|
|
||||||
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
|
|
||||||
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
|
|
||||||
assert_eq!(approx(vec![results[0]], 4), approx(vec![expected[0]], 4));
|
|
||||||
assert_eq!(
|
|
||||||
approx(vec![results[1]], 4),
|
|
||||||
approx(vec![expected[5_000]], 4)
|
|
||||||
);
|
|
||||||
assert_eq!(approx(vec![results[2]], 4), approx(vec![expected[1]], 4));
|
|
||||||
assert_eq!(
|
|
||||||
approx(vec![results[3]], 4),
|
|
||||||
approx(vec![expected[5_001]], 4)
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
approx(vec![results[5_000]], 4),
|
|
||||||
approx(vec![expected[2_500]], 4)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn gelu_f16() {
|
|
||||||
let v: Vec<f16> = [-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0]
|
|
||||||
.iter()
|
|
||||||
.map(|v| f16::from_f32(*v))
|
|
||||||
.collect();
|
|
||||||
let expected: Vec<f32> = vec![-0.0, -0.16, 0.0, 0.84, 1.96, 3.0, 10.0, 20.0];
|
|
||||||
let results = run(&v, unary::contiguous::gelu::HALF);
|
|
||||||
assert_eq!(approx_f16(results, 2), expected);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn gelu_f32() {
|
|
||||||
let v: Vec<f32> = vec![-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0];
|
|
||||||
let expected: Vec<f32> = vec![-0.0, -0.159, 0.0, 0.841, 1.955, 2.996, 10.0, 20.0];
|
|
||||||
let results = run(&v, unary::contiguous::gelu::FLOAT);
|
|
||||||
assert_eq!(approx(results, 3), expected);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn binary_add_f32() {
|
|
||||||
let left = vec![1.0f32, 2.0, 3.0];
|
|
||||||
let right = vec![2.0f32, 3.1, 4.2];
|
|
||||||
let results = run_binary(&left, &right, binary::contiguous::add::FLOAT);
|
|
||||||
let expected: Vec<_> = left
|
|
||||||
.iter()
|
|
||||||
.zip(right.iter())
|
|
||||||
.map(|(&x, &y)| x + y)
|
|
||||||
.collect();
|
|
||||||
assert_eq!(approx(results, 4), vec![3.0f32, 5.1, 7.2]);
|
|
||||||
assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
|
|
||||||
let device = device();
|
|
||||||
let kernels = Kernels::new();
|
|
||||||
let command_queue = device.new_command_queue();
|
|
||||||
let command_buffer = command_queue.new_command_buffer();
|
|
||||||
let input = new_buffer(&device, v);
|
|
||||||
let options = MTLResourceOptions::StorageModeManaged;
|
|
||||||
let size = (v.len() * std::mem::size_of::<U>()) as u64;
|
|
||||||
let output = device.new_buffer(size, options);
|
|
||||||
|
|
||||||
call_cast_contiguous(
|
|
||||||
&device,
|
|
||||||
command_buffer,
|
|
||||||
&kernels,
|
|
||||||
name,
|
|
||||||
v.len(),
|
|
||||||
&input,
|
|
||||||
0,
|
|
||||||
&output,
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
command_buffer.commit();
|
|
||||||
command_buffer.wait_until_completed();
|
|
||||||
output.read_to_vec::<U>(v.len())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn cast_u32_f32() {
|
|
||||||
let v = vec![1u32, 2, 3];
|
|
||||||
let results = cast(&v, "cast_u32_f32");
|
|
||||||
let expected: Vec<_> = v.iter().map(|&v| v as f32).collect();
|
|
||||||
assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]);
|
|
||||||
assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]);
|
|
||||||
|
|
||||||
let v = vec![1.0f32, 2.0, 3.0];
|
|
||||||
let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
|
|
||||||
let results: Vec<f32> = cast(&input, "cast_f16_f32");
|
|
||||||
assert_eq!(results, vec![1.0f32, 2.0, 3.0]);
|
|
||||||
|
|
||||||
let v = vec![1.0f32; 10_000];
|
|
||||||
let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
|
|
||||||
let results: Vec<f32> = cast(&input, "cast_f16_f32");
|
|
||||||
assert_eq!(results.len(), 10_000);
|
|
||||||
assert_eq!(&results[..10], vec![1.0f32; 10]);
|
|
||||||
assert_eq!(results, vec![1.0f32; 10_000]);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
|
|
||||||
let device = device();
|
|
||||||
let kernels = Kernels::new();
|
|
||||||
let command_queue = device.new_command_queue();
|
|
||||||
let command_buffer = command_queue.new_command_buffer();
|
|
||||||
|
|
||||||
let input = new_buffer(&device, v);
|
|
||||||
let output = new_buffer(&device, v);
|
|
||||||
|
|
||||||
let size = v.len();
|
|
||||||
|
|
||||||
call_affine(
|
|
||||||
&device,
|
|
||||||
command_buffer,
|
|
||||||
&kernels,
|
|
||||||
"affine_float",
|
|
||||||
size,
|
|
||||||
&input,
|
|
||||||
&output,
|
|
||||||
mul as f32,
|
|
||||||
add as f32,
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
command_buffer.commit();
|
|
||||||
command_buffer.wait_until_completed();
|
|
||||||
|
|
||||||
output.read_to_vec::<T>(v.len())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn run_affine_strided<T: Clone>(
|
|
||||||
v: &[T],
|
|
||||||
shape: &[usize],
|
|
||||||
strides: &[usize],
|
|
||||||
mul: f64,
|
|
||||||
add: f64,
|
|
||||||
) -> Vec<T> {
|
|
||||||
let device = device();
|
|
||||||
let kernels = Kernels::new();
|
|
||||||
let command_queue = device.new_command_queue();
|
|
||||||
let command_buffer = command_queue.new_command_buffer();
|
|
||||||
|
|
||||||
let input = new_buffer(&device, v);
|
|
||||||
let output = new_buffer(&device, v);
|
|
||||||
|
|
||||||
call_affine_strided(
|
|
||||||
&device,
|
|
||||||
command_buffer,
|
|
||||||
&kernels,
|
|
||||||
"affine_float_strided",
|
|
||||||
shape,
|
|
||||||
&input,
|
|
||||||
strides,
|
|
||||||
0,
|
|
||||||
&output,
|
|
||||||
mul as f32,
|
|
||||||
add as f32,
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
command_buffer.commit();
|
|
||||||
command_buffer.wait_until_completed();
|
|
||||||
|
|
||||||
let len: usize = shape.iter().product();
|
|
||||||
output.read_to_vec::<T>(len)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn affine() {
|
|
||||||
let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
|
|
||||||
let mul = 1.5;
|
|
||||||
let add = 1.1;
|
|
||||||
let result = run_affine(&input, mul, add);
|
|
||||||
assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]);
|
|
||||||
|
|
||||||
let input = [1.0f32; 40_000];
|
|
||||||
let mul = 1.5;
|
|
||||||
let add = 1.1;
|
|
||||||
let result = run_affine(&input, mul, add);
|
|
||||||
assert_eq!(result, vec![2.6; 40_000]);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn affine_strided() {
|
|
||||||
let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
|
|
||||||
let mul = 1.5;
|
|
||||||
let add = 1.1;
|
|
||||||
let shape = [4];
|
|
||||||
let strides = [2];
|
|
||||||
let result = run_affine_strided(&input, &shape, &strides, mul, add);
|
|
||||||
// 1 on 2
|
|
||||||
assert_eq!(result, vec![2.6, 5.6, 8.6, 11.6]);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn index_select() {
|
|
||||||
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
|
|
||||||
let shape = [5, 2];
|
|
||||||
let ids = [0u32, 4, 2];
|
|
||||||
let dim = 0;
|
|
||||||
let result = run_index_select(&embedding, &shape, &ids, dim);
|
|
||||||
assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]);
|
|
||||||
|
|
||||||
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
|
|
||||||
let shape = [2, 5];
|
|
||||||
let ids = [0u32, 1, 0];
|
|
||||||
let dim = 0;
|
|
||||||
let result = run_index_select(&embedding, &shape, &ids, dim);
|
|
||||||
assert_eq!(
|
|
||||||
result,
|
|
||||||
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn index_select_f16() {
|
|
||||||
let embedding: Vec<_> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
|
|
||||||
.into_iter()
|
|
||||||
.map(|x| f16::from_f32(x))
|
|
||||||
.collect();
|
|
||||||
let shape = [5, 2];
|
|
||||||
let ids = [0u32, 4, 2];
|
|
||||||
let dim = 0;
|
|
||||||
let result = run_index_select(&embedding, &shape, &ids, dim);
|
|
||||||
assert_eq!(
|
|
||||||
approx_f16(result, 4),
|
|
||||||
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn index_select_dim1() {
|
|
||||||
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
|
|
||||||
let shape = [5, 2];
|
|
||||||
let ids = [0u32, 1, 0];
|
|
||||||
let dim = 1;
|
|
||||||
let result = run_index_select(&embedding, &shape, &ids, dim);
|
|
||||||
assert_eq!(
|
|
||||||
result,
|
|
||||||
vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0]
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
|
|
||||||
embeddings: &[T],
|
|
||||||
shape: &[usize],
|
|
||||||
ids: &[I],
|
|
||||||
dim: usize,
|
|
||||||
) -> Vec<T> {
|
|
||||||
let device = Device::system_default().expect("no device found");
|
|
||||||
|
|
||||||
let command_queue = device.new_command_queue();
|
|
||||||
let command_buffer = command_queue.new_command_buffer();
|
|
||||||
let embeddings_buffer = new_buffer(&device, &embeddings);
|
|
||||||
let ids_buffer = new_buffer(&device, &ids);
|
|
||||||
|
|
||||||
let left_size: usize = shape[..dim].iter().product();
|
|
||||||
let right_size: usize = shape[dim + 1..].iter().product();
|
|
||||||
let dst_el = ids.len() * left_size * right_size;
|
|
||||||
let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
|
|
||||||
|
|
||||||
let name = match core::mem::size_of::<T>() {
|
|
||||||
4 => "is_u32_f32",
|
|
||||||
2 => "is_u32_f16",
|
|
||||||
_ => unimplemented!(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let kernels = Kernels::new();
|
|
||||||
call_index_select(
|
|
||||||
&device,
|
|
||||||
&command_buffer,
|
|
||||||
&kernels,
|
|
||||||
name,
|
|
||||||
shape,
|
|
||||||
ids.len(),
|
|
||||||
dim,
|
|
||||||
&embeddings_buffer,
|
|
||||||
&ids_buffer,
|
|
||||||
&dst_buffer,
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
command_buffer.commit();
|
|
||||||
command_buffer.wait_until_completed();
|
|
||||||
|
|
||||||
dst_buffer.read_to_vec::<T>(dst_el)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn index_add() {
|
|
||||||
let device = Device::system_default().expect("no device found");
|
|
||||||
|
|
||||||
let options = CompileOptions::new();
|
|
||||||
let library = device.new_library_with_source(INDEXING, &options).unwrap();
|
|
||||||
|
|
||||||
let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
|
|
||||||
let right = [1.0f32; 15];
|
|
||||||
let index = [0u32, 4, 2];
|
|
||||||
let ids_dim_size = index.len() as u32;
|
|
||||||
let dst_dim_size: u32 = 15;
|
|
||||||
let left_size: u32 = 3;
|
|
||||||
let right_size: u32 = 3;
|
|
||||||
|
|
||||||
let function = library.get_function("ia_u32_f32", None).unwrap();
|
|
||||||
let pipeline = device
|
|
||||||
.new_compute_pipeline_state_with_function(&function)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let command_queue = device.new_command_queue();
|
|
||||||
let command_buffer = command_queue.new_command_buffer();
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
|
||||||
|
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
|
||||||
|
|
||||||
let index_buffer = new_buffer(&device, &index);
|
|
||||||
let inputs_buffer = new_buffer(&device, &left);
|
|
||||||
let outputs_buffer = new_buffer(&device, &right);
|
|
||||||
|
|
||||||
set_params!(
|
|
||||||
encoder,
|
|
||||||
(
|
|
||||||
&index_buffer,
|
|
||||||
&inputs_buffer,
|
|
||||||
&outputs_buffer,
|
|
||||||
ids_dim_size,
|
|
||||||
left_size,
|
|
||||||
dst_dim_size,
|
|
||||||
right_size
|
|
||||||
)
|
|
||||||
);
|
|
||||||
|
|
||||||
let grid_size = MTLSize {
|
|
||||||
width: right.len() as NSUInteger,
|
|
||||||
height: 1,
|
|
||||||
depth: 1,
|
|
||||||
};
|
|
||||||
|
|
||||||
let thread_group_size = MTLSize {
|
|
||||||
width: pipeline.max_total_threads_per_threadgroup(),
|
|
||||||
height: 1,
|
|
||||||
depth: 1,
|
|
||||||
};
|
|
||||||
|
|
||||||
encoder.dispatch_thread_groups(grid_size, thread_group_size);
|
|
||||||
encoder.end_encoding();
|
|
||||||
command_buffer.commit();
|
|
||||||
command_buffer.wait_until_completed();
|
|
||||||
|
|
||||||
let expected = vec![
|
|
||||||
2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
|
|
||||||
];
|
|
||||||
let result = outputs_buffer.read_to_vec::<f32>(right.len());
|
|
||||||
assert_eq!(result, expected);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn cos_f16() {
|
|
||||||
let v: Vec<f16> = [1.0f32, 2.0, 3.0]
|
|
||||||
.iter()
|
|
||||||
.map(|v| f16::from_f32(*v))
|
|
||||||
.collect();
|
|
||||||
let results = run(&v, unary::contiguous::cos::HALF);
|
|
||||||
let expected: Vec<f16> = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect();
|
|
||||||
assert_eq!(approx_f16(results, 2), vec![0.54, -0.42, -0.99]);
|
|
||||||
assert_eq!(approx_f16(expected, 2), vec![0.54, -0.42, -0.99]);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
|
|
||||||
let device = device();
|
|
||||||
let kernels = Kernels::new();
|
|
||||||
let command_queue = device.new_command_queue();
|
|
||||||
let command_buffer = command_queue.new_command_buffer();
|
|
||||||
let input = new_buffer(&device, v);
|
|
||||||
|
|
||||||
let options = MTLResourceOptions::StorageModeManaged;
|
|
||||||
let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
|
|
||||||
call_reduce_contiguous(
|
|
||||||
&device,
|
|
||||||
command_buffer,
|
|
||||||
&kernels,
|
|
||||||
name,
|
|
||||||
v.len(),
|
|
||||||
out_length,
|
|
||||||
&input,
|
|
||||||
0,
|
|
||||||
&output,
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
command_buffer.commit();
|
|
||||||
command_buffer.wait_until_completed();
|
|
||||||
|
|
||||||
output.read_to_vec::<T>(out_length)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'static str) -> Vec<T> {
|
|
||||||
let device = device();
|
|
||||||
let kernels = Kernels::new();
|
|
||||||
let command_queue = device.new_command_queue();
|
|
||||||
let command_buffer = command_queue.new_command_buffer();
|
|
||||||
let input = new_buffer(&device, v);
|
|
||||||
let output = new_buffer(&device, v);
|
|
||||||
call_last_softmax(
|
|
||||||
&device,
|
|
||||||
command_buffer,
|
|
||||||
&kernels,
|
|
||||||
name,
|
|
||||||
v.len(),
|
|
||||||
last_dim,
|
|
||||||
&input,
|
|
||||||
&output,
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
command_buffer.commit();
|
|
||||||
command_buffer.wait_until_completed();
|
|
||||||
|
|
||||||
output.read_to_vec::<T>(v.len())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn reduce_sum() {
|
|
||||||
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
|
||||||
let out_length = 1;
|
|
||||||
|
|
||||||
let results = run_reduce(&v, out_length, "fast_sum_float");
|
|
||||||
assert_eq!(approx(results, 4), vec![21.0]);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn reduce_sum2() {
|
|
||||||
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
|
||||||
let out_length = 2;
|
|
||||||
|
|
||||||
let results = run_reduce(&v, out_length, "fast_sum_float");
|
|
||||||
assert_eq!(approx(results, 4), vec![6.0, 15.0]);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn softmax() {
|
|
||||||
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
|
||||||
let last_dim = 6;
|
|
||||||
let results = run_softmax(&v, last_dim, "softmax_float");
|
|
||||||
assert_eq!(
|
|
||||||
approx(results, 4),
|
|
||||||
vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
|
|
||||||
);
|
|
||||||
|
|
||||||
let v = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0];
|
|
||||||
let last_dim = 6;
|
|
||||||
let results = run_softmax(&v, last_dim, "softmax_float");
|
|
||||||
assert_eq!(
|
|
||||||
approx(results, 4),
|
|
||||||
vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
|
|
||||||
);
|
|
||||||
|
|
||||||
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
|
||||||
let last_dim = 3;
|
|
||||||
let results = run_softmax(&v, last_dim, "softmax_float");
|
|
||||||
assert_eq!(
|
|
||||||
approx(results, 4),
|
|
||||||
vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652]
|
|
||||||
);
|
|
||||||
|
|
||||||
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
|
|
||||||
.iter()
|
|
||||||
.map(|v| f16::from_f32(*v))
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
let last_dim = 6;
|
|
||||||
let results = run_softmax(&v, last_dim, "softmax_half");
|
|
||||||
assert_eq!(
|
|
||||||
approx_f16(results, 4),
|
|
||||||
vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338]
|
|
||||||
);
|
|
||||||
|
|
||||||
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
|
|
||||||
.iter()
|
|
||||||
.map(|v| bf16::from_f32(*v))
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
let last_dim = 6;
|
|
||||||
let results = run_softmax(&v, last_dim, "softmax_bfloat");
|
|
||||||
assert_eq!(
|
|
||||||
approx_bf16(results, 4),
|
|
||||||
vec![0.0043, 0.0116, 0.0315, 0.0859, 0.2324, 0.6328]
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn run_where_cond<I: Clone, T: Clone>(
|
|
||||||
shape: &[usize],
|
|
||||||
cond: &[I],
|
|
||||||
(cond_stride, cond_offset): (Vec<usize>, usize),
|
|
||||||
left_true: &[T],
|
|
||||||
(left_stride, left_offset): (Vec<usize>, usize),
|
|
||||||
right_false: &[T],
|
|
||||||
(_right_stride, _right_offset): (Vec<usize>, usize),
|
|
||||||
name: &'static str,
|
|
||||||
) -> Vec<T> {
|
|
||||||
let device = device();
|
|
||||||
let kernels = Kernels::new();
|
|
||||||
let command_queue = device.new_command_queue();
|
|
||||||
let command_buffer = command_queue.new_command_buffer();
|
|
||||||
let options = MTLResourceOptions::StorageModeManaged;
|
|
||||||
|
|
||||||
let length = cond.len();
|
|
||||||
let cond = device.new_buffer_with_data(
|
|
||||||
cond.as_ptr() as *const core::ffi::c_void,
|
|
||||||
std::mem::size_of_val(cond) as u64,
|
|
||||||
options,
|
|
||||||
);
|
|
||||||
let left = device.new_buffer_with_data(
|
|
||||||
left_true.as_ptr() as *const core::ffi::c_void,
|
|
||||||
(length * core::mem::size_of::<T>()) as u64,
|
|
||||||
options,
|
|
||||||
);
|
|
||||||
let right = device.new_buffer_with_data(
|
|
||||||
right_false.as_ptr() as *const core::ffi::c_void,
|
|
||||||
(length * core::mem::size_of::<T>()) as u64,
|
|
||||||
options,
|
|
||||||
);
|
|
||||||
|
|
||||||
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
|
|
||||||
call_where_cond_strided(
|
|
||||||
&device,
|
|
||||||
command_buffer,
|
|
||||||
&kernels,
|
|
||||||
name,
|
|
||||||
shape,
|
|
||||||
&cond,
|
|
||||||
(&cond_stride, cond_offset),
|
|
||||||
&left,
|
|
||||||
(&left_stride, left_offset),
|
|
||||||
&right,
|
|
||||||
(&cond_stride, cond_offset),
|
|
||||||
&output,
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
command_buffer.commit();
|
|
||||||
command_buffer.wait_until_completed();
|
|
||||||
|
|
||||||
output.read_to_vec::<T>(length)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn where_cond() {
|
|
||||||
let shape = vec![6];
|
|
||||||
let cond = vec![0u8, 1, 0, 0, 1, 1];
|
|
||||||
let cond_l = (vec![1], 0);
|
|
||||||
let left_true = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
|
||||||
let left_l = (vec![1], 0);
|
|
||||||
let right_false = vec![-1.0f32, -2.0, -3.0, -4.0, -5.0, -6.0];
|
|
||||||
let right_l = (vec![1], 0);
|
|
||||||
let results = run_where_cond(
|
|
||||||
&shape,
|
|
||||||
&cond,
|
|
||||||
cond_l,
|
|
||||||
&left_true,
|
|
||||||
left_l,
|
|
||||||
&right_false,
|
|
||||||
right_l,
|
|
||||||
"where_u8_f32",
|
|
||||||
);
|
|
||||||
assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
|
|
||||||
}
|
|
@ -42,14 +42,9 @@ template <typename T> METAL_FUNC T erf(T in){
|
|||||||
|
|
||||||
return T(sign*y);
|
return T(sign*y);
|
||||||
}
|
}
|
||||||
template <typename T> METAL_FUNC T id(T in) { return in; }
|
template <typename T> METAL_FUNC T id(T in){ return in; }
|
||||||
template <typename T> METAL_FUNC T gelu_erf(T x) {
|
template <typename T> METAL_FUNC T gelu_erf(T x){ return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2); }
|
||||||
return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2);
|
template <typename T> METAL_FUNC T gelu(T x){
|
||||||
}
|
|
||||||
template <typename T> METAL_FUNC T gelu(T x) {
|
|
||||||
if (x > 5) {
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
T x_sq = x * x;
|
T x_sq = x * x;
|
||||||
T x_cube = x_sq * x;
|
T x_cube = x_sq * x;
|
||||||
T alpha = x + static_cast<T>(0.044715) * x_cube;
|
T alpha = x + static_cast<T>(0.044715) * x_cube;
|
||||||
|
@ -11,7 +11,7 @@ readme = "README.md"
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
accelerate-src = { workspace = true, optional = true }
|
accelerate-src = { workspace = true, optional = true }
|
||||||
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
|
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
||||||
half = { workspace = true }
|
half = { workspace = true }
|
||||||
thiserror = { workspace = true }
|
thiserror = { workspace = true }
|
||||||
intel-mkl-src = { workspace = true, optional = true }
|
intel-mkl-src = { workspace = true, optional = true }
|
||||||
@ -19,7 +19,6 @@ num-traits = { workspace = true }
|
|||||||
rayon = { workspace = true }
|
rayon = { workspace = true }
|
||||||
safetensors = { workspace = true }
|
safetensors = { workspace = true }
|
||||||
serde = { workspace = true }
|
serde = { workspace = true }
|
||||||
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
|
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
anyhow = { workspace = true }
|
anyhow = { workspace = true }
|
||||||
@ -30,4 +29,3 @@ default = []
|
|||||||
accelerate = ["dep:accelerate-src", "candle/accelerate"]
|
accelerate = ["dep:accelerate-src", "candle/accelerate"]
|
||||||
cuda = ["candle/cuda"]
|
cuda = ["candle/cuda"]
|
||||||
mkl = ["dep:intel-mkl-src", "candle/mkl"]
|
mkl = ["dep:intel-mkl-src", "candle/mkl"]
|
||||||
metal = ["candle/metal", "dep:candle-metal-kernels"]
|
|
||||||
|
@ -6,7 +6,7 @@ extern crate intel_mkl_src;
|
|||||||
extern crate accelerate_src;
|
extern crate accelerate_src;
|
||||||
|
|
||||||
use candle::quantized::GgmlType;
|
use candle::quantized::GgmlType;
|
||||||
use candle::{CpuStorage, Device, Layout, Module, Result, Shape, Tensor, D};
|
use candle::{CpuStorage, Device, Layout, Result, Shape, Tensor, D};
|
||||||
use clap::{Parser, Subcommand};
|
use clap::{Parser, Subcommand};
|
||||||
|
|
||||||
const CHECK_CONV2D: bool = false;
|
const CHECK_CONV2D: bool = false;
|
||||||
|
@ -201,46 +201,6 @@ impl candle::CustomOp1 for SoftmaxLastDim {
|
|||||||
};
|
};
|
||||||
Ok((dst, layout.shape().clone()))
|
Ok((dst, layout.shape().clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "metal")]
|
|
||||||
fn metal_fwd(
|
|
||||||
&self,
|
|
||||||
storage: &candle::MetalStorage,
|
|
||||||
layout: &Layout,
|
|
||||||
) -> Result<(candle::MetalStorage, Shape)> {
|
|
||||||
use candle::{backend::BackendStorage, DType};
|
|
||||||
let device = storage.device();
|
|
||||||
let command_buffer = device.command_buffer();
|
|
||||||
let kernels = device.kernels();
|
|
||||||
let name = match storage.dtype() {
|
|
||||||
DType::F32 => "softmax_float",
|
|
||||||
DType::F16 => "softmax_half",
|
|
||||||
DType::BF16 => "softmax_bfloat",
|
|
||||||
dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"),
|
|
||||||
};
|
|
||||||
|
|
||||||
let n = layout.stride().len();
|
|
||||||
if !(layout.stride()[n - 1] == 1 && layout.start_offset() == 0) {
|
|
||||||
candle::bail!("Non contiguous softmax-last-dim is not implemented");
|
|
||||||
}
|
|
||||||
|
|
||||||
let last_dim = layout.dims()[layout.shape().rank() - 1];
|
|
||||||
let elem_count = layout.shape().elem_count();
|
|
||||||
let mut output = device.new_buffer(elem_count, storage.dtype());
|
|
||||||
candle_metal_kernels::call_last_softmax(
|
|
||||||
device.metal_device(),
|
|
||||||
&command_buffer,
|
|
||||||
&kernels,
|
|
||||||
name,
|
|
||||||
elem_count,
|
|
||||||
last_dim,
|
|
||||||
storage.buffer(),
|
|
||||||
&mut output,
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
|
|
||||||
Ok((newstorage, layout.shape().clone()))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
|
pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-onnx"
|
name = "candle-onnx"
|
||||||
version = "0.3.1"
|
version = "0.3.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "ONNX support for Candle"
|
description = "ONNX support for Candle"
|
||||||
@ -10,8 +10,8 @@ categories = ["science"]
|
|||||||
license = "MIT OR Apache-2.0"
|
license = "MIT OR Apache-2.0"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
|
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
||||||
candle-nn = { path = "../candle-nn", version = "0.3.1" }
|
candle-nn = { path = "../candle-nn", version = "0.3.0" }
|
||||||
prost = "0.12.1"
|
prost = "0.12.1"
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
|
@ -741,25 +741,6 @@ pub fn simple_eval(
|
|||||||
let output = input.to_dtype(dtype)?;
|
let output = input.to_dtype(dtype)?;
|
||||||
values.insert(node.output[0].clone(), output);
|
values.insert(node.output[0].clone(), output);
|
||||||
}
|
}
|
||||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#CumSum
|
|
||||||
"CumSum" => {
|
|
||||||
let exclusive = get_attr_opt::<i64>(node, "exclusive")?
|
|
||||||
.copied()
|
|
||||||
.unwrap_or(0);
|
|
||||||
let reverse = get_attr_opt::<i64>(node, "reverse")?.copied().unwrap_or(0);
|
|
||||||
if exclusive != 0 {
|
|
||||||
bail!("only exclusive == 0 is supported in CumSum")
|
|
||||||
}
|
|
||||||
if reverse != 0 {
|
|
||||||
bail!("only reverse == 0 is supported in CumSum")
|
|
||||||
}
|
|
||||||
let input = get(&node.input[0])?;
|
|
||||||
let axis = get(&node.input[1])?
|
|
||||||
.to_dtype(DType::U32)?
|
|
||||||
.to_vec0::<u32>()?;
|
|
||||||
let output = input.cumsum(axis as usize)?;
|
|
||||||
values.insert(node.output[0].clone(), output);
|
|
||||||
}
|
|
||||||
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,746 +0,0 @@
|
|||||||
#[cfg(feature = "mkl")]
|
|
||||||
extern crate intel_mkl_src;
|
|
||||||
|
|
||||||
#[cfg(feature = "accelerate")]
|
|
||||||
extern crate accelerate_src;
|
|
||||||
|
|
||||||
use candle::{Device, Result, Tensor};
|
|
||||||
use candle_onnx::onnx::{GraphProto, ModelProto, NodeProto, ValueInfoProto};
|
|
||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
const INPUT_X: &str = "x";
|
|
||||||
const INPUT_Y: &str = "y";
|
|
||||||
const OUTPUT_Z: &str = "z";
|
|
||||||
|
|
||||||
fn create_model_proto_with_graph(graph: Option<GraphProto>) -> ModelProto {
|
|
||||||
ModelProto {
|
|
||||||
metadata_props: vec![],
|
|
||||||
training_info: vec![],
|
|
||||||
functions: vec![],
|
|
||||||
ir_version: 0,
|
|
||||||
opset_import: vec![],
|
|
||||||
producer_name: "".to_string(),
|
|
||||||
producer_version: "".to_string(),
|
|
||||||
domain: "".to_string(),
|
|
||||||
model_version: 0,
|
|
||||||
doc_string: "".to_string(),
|
|
||||||
graph,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Evaluating a model that carries no graph must fail with the
/// "no graph defined in proto" error.
#[test]
fn test_evaluation_fails_without_defined_graph() -> Result<()> {
    let model = create_model_proto_with_graph(None);
    let no_inputs: HashMap<String, Tensor> = HashMap::new();

    if let Err(err) = candle_onnx::simple_eval(&model, no_inputs) {
        assert_eq!(err.to_string(), "no graph defined in proto");
    } else {
        panic!("Expected an error due to undefined graph");
    }

    Ok(())
}
|
|
||||||
|
|
||||||
// "Add"
|
|
||||||
/// Evaluates a single-node `Add` graph on two one-element tensors and checks
/// that `2 + 2` produces `4`.
#[test]
fn test_add_operation() -> Result<()> {
    let node = NodeProto {
        op_type: "Add".to_string(),
        domain: String::new(),
        attribute: Vec::new(),
        input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
        output: vec![OUTPUT_Z.to_string()],
        name: String::new(),
        doc_string: String::new(),
    };
    let graph_output = ValueInfoProto {
        name: OUTPUT_Z.to_string(),
        doc_string: String::new(),
        r#type: None,
    };
    let model = create_model_proto_with_graph(Some(GraphProto {
        node: vec![node],
        name: String::new(),
        initializer: Vec::new(),
        input: Vec::new(),
        output: vec![graph_output],
        value_info: Vec::new(),
        doc_string: String::new(),
        sparse_initializer: Vec::new(),
        quantization_annotation: Vec::new(),
    }));

    let inputs: HashMap<String, Tensor> = HashMap::from([
        (INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?),
        (INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?),
    ]);

    let eval = candle_onnx::simple_eval(&model, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
    // `to_vec1` already returns an owned Vec, so read the first element
    // directly instead of copying the Vec and cloning a `Copy` value.
    let values = z.to_vec1::<f64>()?;
    let first = *values.first().expect("Failed to get first element");
    assert_eq!(first, 4.0f64);

    Ok(())
}
|
|
||||||
|
|
||||||
// "Sub"
|
|
||||||
/// Evaluates a single-node `Sub` graph on two one-element tensors and checks
/// that `2 - 2` produces `0`.
#[test]
fn test_sub_operation() -> Result<()> {
    let node = NodeProto {
        op_type: "Sub".to_string(),
        domain: String::new(),
        attribute: Vec::new(),
        input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
        output: vec![OUTPUT_Z.to_string()],
        name: String::new(),
        doc_string: String::new(),
    };
    let graph_output = ValueInfoProto {
        name: OUTPUT_Z.to_string(),
        doc_string: String::new(),
        r#type: None,
    };
    let model = create_model_proto_with_graph(Some(GraphProto {
        node: vec![node],
        name: String::new(),
        initializer: Vec::new(),
        input: Vec::new(),
        output: vec![graph_output],
        value_info: Vec::new(),
        doc_string: String::new(),
        sparse_initializer: Vec::new(),
        quantization_annotation: Vec::new(),
    }));

    let inputs: HashMap<String, Tensor> = HashMap::from([
        (INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?),
        (INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?),
    ]);

    let eval = candle_onnx::simple_eval(&model, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
    // `to_vec1` already returns an owned Vec, so read the first element
    // directly instead of copying the Vec and cloning a `Copy` value.
    let values = z.to_vec1::<f64>()?;
    let first = *values.first().expect("Failed to get first element");
    assert_eq!(first, 0.0f64);

    Ok(())
}
|
|
||||||
|
|
||||||
// "Mul"
|
|
||||||
/// Evaluates a single-node `Mul` graph on two one-element tensors and checks
/// that `2 * 2` produces `4`.
#[test]
fn test_mul_operation() -> Result<()> {
    let node = NodeProto {
        op_type: "Mul".to_string(),
        domain: String::new(),
        attribute: Vec::new(),
        input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
        output: vec![OUTPUT_Z.to_string()],
        name: String::new(),
        doc_string: String::new(),
    };
    let graph_output = ValueInfoProto {
        name: OUTPUT_Z.to_string(),
        doc_string: String::new(),
        r#type: None,
    };
    let model = create_model_proto_with_graph(Some(GraphProto {
        node: vec![node],
        name: String::new(),
        initializer: Vec::new(),
        input: Vec::new(),
        output: vec![graph_output],
        value_info: Vec::new(),
        doc_string: String::new(),
        sparse_initializer: Vec::new(),
        quantization_annotation: Vec::new(),
    }));

    let inputs: HashMap<String, Tensor> = HashMap::from([
        (INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?),
        (INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?),
    ]);

    let eval = candle_onnx::simple_eval(&model, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
    // `to_vec1` already returns an owned Vec, so read the first element
    // directly instead of copying the Vec and cloning a `Copy` value.
    let values = z.to_vec1::<f64>()?;
    let first = *values.first().expect("Failed to get first element");
    assert_eq!(first, 4.0f64);

    Ok(())
}
|
|
||||||
|
|
||||||
// "Div"
|
|
||||||
/// Evaluates a single-node `Div` graph on two one-element tensors and checks
/// that `2 / 2` produces `1`.
#[test]
fn test_div_operation() -> Result<()> {
    let node = NodeProto {
        op_type: "Div".to_string(),
        domain: String::new(),
        attribute: Vec::new(),
        input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
        output: vec![OUTPUT_Z.to_string()],
        name: String::new(),
        doc_string: String::new(),
    };
    let graph_output = ValueInfoProto {
        name: OUTPUT_Z.to_string(),
        doc_string: String::new(),
        r#type: None,
    };
    let model = create_model_proto_with_graph(Some(GraphProto {
        node: vec![node],
        name: String::new(),
        initializer: Vec::new(),
        input: Vec::new(),
        output: vec![graph_output],
        value_info: Vec::new(),
        doc_string: String::new(),
        sparse_initializer: Vec::new(),
        quantization_annotation: Vec::new(),
    }));

    let inputs: HashMap<String, Tensor> = HashMap::from([
        (INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?),
        (INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?),
    ]);

    let eval = candle_onnx::simple_eval(&model, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
    // `to_vec1` already returns an owned Vec, so read the first element
    // directly instead of copying the Vec and cloning a `Copy` value.
    let values = z.to_vec1::<f64>()?;
    let first = *values.first().expect("Failed to get first element");
    assert_eq!(first, 1.0f64);

    Ok(())
}
|
|
||||||
|
|
||||||
// "Equal"
|
|
||||||
/// Evaluates a single-node `Equal` graph on two identical one-element tensors
/// and checks that the result, converted to `u8`, is `1`.
#[test]
fn test_equal_operation() -> Result<()> {
    let node = NodeProto {
        op_type: "Equal".to_string(),
        domain: String::new(),
        attribute: Vec::new(),
        input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
        output: vec![OUTPUT_Z.to_string()],
        name: String::new(),
        doc_string: String::new(),
    };
    let graph_output = ValueInfoProto {
        name: OUTPUT_Z.to_string(),
        doc_string: String::new(),
        r#type: None,
    };
    let model = create_model_proto_with_graph(Some(GraphProto {
        node: vec![node],
        name: String::new(),
        initializer: Vec::new(),
        input: Vec::new(),
        output: vec![graph_output],
        value_info: Vec::new(),
        doc_string: String::new(),
        sparse_initializer: Vec::new(),
        quantization_annotation: Vec::new(),
    }));

    let inputs: HashMap<String, Tensor> = HashMap::from([
        (INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?),
        (INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?),
    ]);

    let eval = candle_onnx::simple_eval(&model, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
    // Index the Vec returned by `to_vec1` directly; copying it first with
    // `to_vec()` only adds an allocation.
    let flags = z.to_dtype(candle::DType::U8)?.to_vec1::<u8>()?;
    let first = flags[0];
    assert_eq!(first, 1);

    Ok(())
}
|
|
||||||
|
|
||||||
// "Not"
|
|
||||||
/// Evaluates a single-node `Not` graph on a one-element tensor holding `0`
/// and checks that the result, converted to `u8`, is `1`.
#[test]
fn test_not_operation() -> Result<()> {
    let node = NodeProto {
        op_type: "Not".to_string(),
        domain: String::new(),
        attribute: Vec::new(),
        input: vec![INPUT_X.to_string()],
        output: vec![OUTPUT_Z.to_string()],
        name: String::new(),
        doc_string: String::new(),
    };
    let graph_output = ValueInfoProto {
        name: OUTPUT_Z.to_string(),
        doc_string: String::new(),
        r#type: None,
    };
    let model = create_model_proto_with_graph(Some(GraphProto {
        node: vec![node],
        name: String::new(),
        initializer: Vec::new(),
        input: Vec::new(),
        output: vec![graph_output],
        value_info: Vec::new(),
        doc_string: String::new(),
        sparse_initializer: Vec::new(),
        quantization_annotation: Vec::new(),
    }));

    let inputs: HashMap<String, Tensor> =
        HashMap::from([(INPUT_X.to_string(), Tensor::new(&[0.], &Device::Cpu)?)]);

    let eval = candle_onnx::simple_eval(&model, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
    // Index the Vec returned by `to_vec1` directly; copying it first with
    // `to_vec()` only adds an allocation.
    let flags = z.to_dtype(candle::DType::U8)?.to_vec1::<u8>()?;
    let first = flags[0];
    assert_eq!(first, 1);

    Ok(())
}
|
|
||||||
|
|
||||||
// "MatMul"
|
|
||||||
/// Evaluates a single-node `MatMul` graph on two 2x2 `f32` matrices and
/// checks the product `[[1, 2], [3, 4]] x [[5, 6], [7, 8]]`.
#[test]
fn test_matmul_operation() -> Result<()> {
    let node = NodeProto {
        op_type: "MatMul".to_string(),
        domain: String::new(),
        attribute: Vec::new(),
        input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
        output: vec![OUTPUT_Z.to_string()],
        name: String::new(),
        doc_string: String::new(),
    };
    let graph_output = ValueInfoProto {
        name: OUTPUT_Z.to_string(),
        doc_string: String::new(),
        r#type: None,
    };
    let model = create_model_proto_with_graph(Some(GraphProto {
        node: vec![node],
        name: String::new(),
        initializer: Vec::new(),
        input: Vec::new(),
        output: vec![graph_output],
        value_info: Vec::new(),
        doc_string: String::new(),
        sparse_initializer: Vec::new(),
        quantization_annotation: Vec::new(),
    }));

    let lhs = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu)?;
    let rhs = Tensor::from_vec(vec![5.0f32, 6.0f32, 7.0f32, 8.0f32], &[2, 2], &Device::Cpu)?;
    let inputs: HashMap<String, Tensor> = HashMap::from([
        (INPUT_X.to_string(), lhs),
        (INPUT_Y.to_string(), rhs),
    ]);

    let eval = candle_onnx::simple_eval(&model, inputs)?;
    assert_eq!(eval.len(), 1);

    let product = eval
        .get(OUTPUT_Z)
        .expect("Output 'z' not found")
        .to_vec2::<f32>()?;
    assert_eq!(product, vec![vec![19.0, 22.0], vec![43.0, 50.0]]);

    Ok(())
}
|
|
||||||
|
|
||||||
// "Reshape"
|
|
||||||
/// Evaluates a single-node `Reshape` graph that flattens a 2x2 `f32` tensor
/// using the target shape `[4]` supplied as the second input.
#[test]
fn test_reshape_operation() -> Result<()> {
    // Untyped value description for a graph input or output.
    let value_info = |name: &str| ValueInfoProto {
        name: name.to_string(),
        doc_string: String::new(),
        r#type: None,
    };
    let node = NodeProto {
        op_type: "Reshape".to_string(),
        domain: String::new(),
        attribute: Vec::new(),
        input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
        output: vec![OUTPUT_Z.to_string()],
        name: String::new(),
        doc_string: String::new(),
    };
    let model = create_model_proto_with_graph(Some(GraphProto {
        node: vec![node],
        name: String::new(),
        initializer: Vec::new(),
        input: vec![value_info(INPUT_X), value_info(INPUT_Y)],
        output: vec![value_info(OUTPUT_Z)],
        value_info: Vec::new(),
        doc_string: String::new(),
        sparse_initializer: Vec::new(),
        quantization_annotation: Vec::new(),
    }));

    let data = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu)?;
    let target_shape = Tensor::from_vec(vec![4i64], &[1], &Device::Cpu)?;
    let inputs: HashMap<String, Tensor> = HashMap::from([
        (INPUT_X.to_string(), data),
        (INPUT_Y.to_string(), target_shape),
    ]);

    let eval = candle_onnx::simple_eval(&model, inputs)?;
    assert_eq!(eval.len(), 1);

    let flattened = eval
        .get(OUTPUT_Z)
        .expect("Output 'z' not found")
        .to_vec1::<f32>()?;
    assert_eq!(flattened, vec![1.0, 2.0, 3.0, 4.0]);

    Ok(())
}
|
|
||||||
|
|
||||||
// "LogSoftmax"
|
|
||||||
/// Evaluates a single-node `LogSoftmax` graph (no `axis` attribute) on a 2x2
/// `f32` tensor and compares the output with fixed reference values.
#[test]
fn test_logsoftmax_operation() -> Result<()> {
    // Untyped value description for a graph input or output.
    let value_info = |name: &str| ValueInfoProto {
        name: name.to_string(),
        doc_string: String::new(),
        r#type: None,
    };
    let node = NodeProto {
        op_type: "LogSoftmax".to_string(),
        domain: String::new(),
        attribute: Vec::new(),
        input: vec![INPUT_X.to_string()],
        output: vec![OUTPUT_Z.to_string()],
        name: String::new(),
        doc_string: String::new(),
    };
    let model = create_model_proto_with_graph(Some(GraphProto {
        node: vec![node],
        name: String::new(),
        initializer: Vec::new(),
        // The graph declares `y` as an input although the node never reads it
        // and the test never supplies it.
        input: vec![value_info(INPUT_X), value_info(INPUT_Y)],
        output: vec![value_info(OUTPUT_Z)],
        value_info: Vec::new(),
        doc_string: String::new(),
        sparse_initializer: Vec::new(),
        quantization_annotation: Vec::new(),
    }));

    let x = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu)?;
    let inputs: HashMap<String, Tensor> = HashMap::from([(INPUT_X.to_string(), x)]);

    let eval = candle_onnx::simple_eval(&model, inputs)?;
    assert_eq!(eval.len(), 1);

    let rows = eval
        .get(OUTPUT_Z)
        .expect("Output 'z' not found")
        .to_vec2::<f32>()?;
    // NOTE(review): these reference values are the softmax of each row, not
    // its logarithm (which would be roughly [-1.3133, -0.3133]). Presumably
    // this pins the evaluator's current behaviour when `axis` is absent;
    // confirm against `simple_eval` before relying on it.
    assert_eq!(
        rows,
        vec![vec![0.26894143, 0.7310586], vec![0.26894143, 0.7310586]]
    );

    Ok(())
}
|
|
||||||
|
|
||||||
// "Softmax"
|
|
||||||
/// Evaluates a single-node `Softmax` graph (no `axis` attribute) on a 2x2
/// `f32` tensor and checks the per-row probabilities.
#[test]
fn test_softmax_operation() -> Result<()> {
    // Untyped value description for a graph input or output.
    let value_info = |name: &str| ValueInfoProto {
        name: name.to_string(),
        doc_string: String::new(),
        r#type: None,
    };
    let node = NodeProto {
        op_type: "Softmax".to_string(),
        domain: String::new(),
        attribute: Vec::new(),
        input: vec![INPUT_X.to_string()],
        output: vec![OUTPUT_Z.to_string()],
        name: String::new(),
        doc_string: String::new(),
    };
    let model = create_model_proto_with_graph(Some(GraphProto {
        node: vec![node],
        name: String::new(),
        initializer: Vec::new(),
        // The graph declares `y` as an input although the node never reads it
        // and the test never supplies it.
        input: vec![value_info(INPUT_X), value_info(INPUT_Y)],
        output: vec![value_info(OUTPUT_Z)],
        value_info: Vec::new(),
        doc_string: String::new(),
        sparse_initializer: Vec::new(),
        quantization_annotation: Vec::new(),
    }));

    let x = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu)?;
    let inputs: HashMap<String, Tensor> = HashMap::from([(INPUT_X.to_string(), x)]);

    let eval = candle_onnx::simple_eval(&model, inputs)?;
    assert_eq!(eval.len(), 1);

    let rows = eval
        .get(OUTPUT_Z)
        .expect("Output 'z' not found")
        .to_vec2::<f32>()?;
    assert_eq!(
        rows,
        vec![vec![0.26894143, 0.7310586], vec![0.26894143, 0.7310586]]
    );

    Ok(())
}
|
|
||||||
|
|
||||||
// "Transpose"
|
|
||||||
/// Evaluates a single-node `Transpose` graph (no `perm` attribute) on a 2x2
/// `f32` tensor and checks that rows and columns are swapped.
#[test]
fn test_transpose_operation() -> Result<()> {
    // Untyped value description for a graph input or output.
    let value_info = |name: &str| ValueInfoProto {
        name: name.to_string(),
        doc_string: String::new(),
        r#type: None,
    };
    let node = NodeProto {
        op_type: "Transpose".to_string(),
        domain: String::new(),
        attribute: Vec::new(),
        input: vec![INPUT_X.to_string()],
        output: vec![OUTPUT_Z.to_string()],
        name: String::new(),
        doc_string: String::new(),
    };
    let model = create_model_proto_with_graph(Some(GraphProto {
        node: vec![node],
        name: String::new(),
        initializer: Vec::new(),
        // The graph declares `y` as an input although the node never reads it
        // and the test never supplies it.
        input: vec![value_info(INPUT_X), value_info(INPUT_Y)],
        output: vec![value_info(OUTPUT_Z)],
        value_info: Vec::new(),
        doc_string: String::new(),
        sparse_initializer: Vec::new(),
        quantization_annotation: Vec::new(),
    }));

    let x = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu)?;
    let inputs: HashMap<String, Tensor> = HashMap::from([(INPUT_X.to_string(), x)]);

    let eval = candle_onnx::simple_eval(&model, inputs)?;
    assert_eq!(eval.len(), 1);

    let rows = eval
        .get(OUTPUT_Z)
        .expect("Output 'z' not found")
        .to_vec2::<f32>()?;
    assert_eq!(rows, vec![vec![1.0, 3.0], vec![2.0, 4.0]]);

    Ok(())
}
|
|
||||||
|
|
||||||
// "Dropout"
|
|
||||||
/// Evaluates a single-node `Dropout` graph on a 2x2 `f32` tensor and checks
/// that the values come back unchanged.
#[test]
fn test_dropout_operation() -> Result<()> {
    // Untyped value description for a graph input or output.
    let value_info = |name: &str| ValueInfoProto {
        name: name.to_string(),
        doc_string: String::new(),
        r#type: None,
    };
    let node = NodeProto {
        op_type: "Dropout".to_string(),
        domain: String::new(),
        attribute: Vec::new(),
        input: vec![INPUT_X.to_string()],
        output: vec![OUTPUT_Z.to_string()],
        name: String::new(),
        doc_string: String::new(),
    };
    let model = create_model_proto_with_graph(Some(GraphProto {
        node: vec![node],
        name: String::new(),
        initializer: Vec::new(),
        // The graph declares `y` as an input although the node never reads it
        // and the test never supplies it.
        input: vec![value_info(INPUT_X), value_info(INPUT_Y)],
        output: vec![value_info(OUTPUT_Z)],
        value_info: Vec::new(),
        doc_string: String::new(),
        sparse_initializer: Vec::new(),
        quantization_annotation: Vec::new(),
    }));

    let x = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu)?;
    let inputs: HashMap<String, Tensor> = HashMap::from([(INPUT_X.to_string(), x)]);

    let eval = candle_onnx::simple_eval(&model, inputs)?;
    assert_eq!(eval.len(), 1);

    let rows = eval
        .get(OUTPUT_Z)
        .expect("Output 'z' not found")
        .to_vec2::<f32>()?;
    assert_eq!(rows, vec![vec![1.0, 2.0], vec![3.0, 4.0]]);

    Ok(())
}
|
|
||||||
|
|
||||||
// Below are ops that are implemented but not tested yet
|
|
||||||
|
|
||||||
// "MaxPool"
|
|
||||||
// #[test]
|
|
||||||
|
|
||||||
// "AveragePool"
|
|
||||||
// #[test]
|
|
||||||
|
|
||||||
// "BatchNormalization"
|
|
||||||
// #[test]
|
|
||||||
|
|
||||||
// "Squeeze"
|
|
||||||
// #[test]
|
|
||||||
|
|
||||||
// "ConstantOfShape"
|
|
||||||
// #[test]
|
|
||||||
|
|
||||||
// "Unsqueeze"
|
|
||||||
// #[test]
|
|
||||||
|
|
||||||
// "Clip"
|
|
||||||
// #[test]
|
|
||||||
|
|
||||||
// "Gather"
|
|
||||||
// #[test]
|
|
||||||
|
|
||||||
// "Shape"
|
|
||||||
// #[test]
|
|
||||||
|
|
||||||
// "Conv"
|
|
||||||
// #[test]
|
|
||||||
|
|
||||||
// "Concat"
|
|
||||||
// #[test]
|
|
||||||
|
|
||||||
// "Abs"
|
|
||||||
// #[test]
|
|
||||||
|
|
||||||
// "Cos"
|
|
||||||
// #[test]
|
|
||||||
|
|
||||||
// "Sin"
|
|
||||||
// #[test]
|
|
||||||
|
|
||||||
// "Neg"
|
|
||||||
// #[test]
|
|
||||||
|
|
||||||
// "Erf"
|
|
||||||
// #[test]
|
|
||||||
|
|
||||||
// "Tanh"
|
|
||||||
// #[test]
|
|
||||||
|
|
||||||
// "Sigmoid"
|
|
||||||
// #[test]
|
|
||||||
|
|
||||||
// "Gelu"
|
|
||||||
// #[test]
|
|
||||||
|
|
||||||
// "Relu"
|
|
||||||
// #[test]
|
|
||||||
|
|
||||||
// "Constant"
|
|
||||||
// #[test]
|
|
||||||
|
|
||||||
// "Cast"
|
|
||||||
// #[test]
|
|
@ -15,9 +15,9 @@ crate-type = ["cdylib"]
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
accelerate-src = { workspace = true, optional = true }
|
accelerate-src = { workspace = true, optional = true }
|
||||||
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
|
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
||||||
candle-nn = { path = "../candle-nn", version = "0.3.1" }
|
candle-nn = { path = "../candle-nn", version = "0.3.0" }
|
||||||
candle-onnx = {path= "../candle-onnx", version = "0.3.1", optional = true}
|
candle-onnx = {path= "../candle-onnx", version = "0.3.0", optional = true}
|
||||||
half = { workspace = true }
|
half = { workspace = true }
|
||||||
intel-mkl-src = { workspace = true, optional = true }
|
intel-mkl-src = { workspace = true, optional = true }
|
||||||
pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] }
|
pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] }
|
||||||
|
@ -17,7 +17,7 @@ extern crate intel_mkl_src;
|
|||||||
#[cfg(feature = "accelerate")]
|
#[cfg(feature = "accelerate")]
|
||||||
extern crate accelerate_src;
|
extern crate accelerate_src;
|
||||||
|
|
||||||
use ::candle::{quantized::QTensor, DType, Device, Module, Tensor, WithDType};
|
use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};
|
||||||
|
|
||||||
mod utils;
|
mod utils;
|
||||||
use utils::wrap_err;
|
use utils::wrap_err;
|
||||||
|
@ -12,9 +12,9 @@ readme = "README.md"
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
accelerate-src = { workspace = true, optional = true }
|
accelerate-src = { workspace = true, optional = true }
|
||||||
byteorder = { workspace = true }
|
byteorder = { workspace = true }
|
||||||
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
|
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
||||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.1", optional = true }
|
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
|
||||||
candle-nn = { path = "../candle-nn", version = "0.3.1" }
|
candle-nn = { path = "../candle-nn", version = "0.3.0" }
|
||||||
intel-mkl-src = { workspace = true, optional = true }
|
intel-mkl-src = { workspace = true, optional = true }
|
||||||
num-traits = { workspace = true }
|
num-traits = { workspace = true }
|
||||||
rand = { workspace = true }
|
rand = { workspace = true }
|
||||||
|
@ -1,342 +0,0 @@
|
|||||||
use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
|
|
||||||
use candle::{DType, Device, Result, Tensor};
|
|
||||||
use candle_nn::{Embedding, Module, VarBuilder};
|
|
||||||
use serde::Deserialize;
|
|
||||||
|
|
||||||
pub const DTYPE: DType = DType::F32;
|
|
||||||
|
|
||||||
/// Selects between a constant and `on_false` according to `mask`.
///
/// The scalar `on_true` is created on `on_false`'s device, broadcast to the
/// shape of `mask`, and passed to `mask.where_cond` as the "true" branch with
/// `on_false` as the "false" branch.
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
    let fill = Tensor::new(on_true, on_false.device())?.broadcast_as(mask.shape().dims())?;
    mask.where_cond(&fill, on_false)
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
|
|
||||||
#[serde(rename_all = "lowercase")]
|
|
||||||
enum HiddenAct {
|
|
||||||
Gelu,
|
|
||||||
Relu,
|
|
||||||
}
|
|
||||||
|
|
||||||
struct HiddenActLayer {
|
|
||||||
act: HiddenAct,
|
|
||||||
span: tracing::Span,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl HiddenActLayer {
|
|
||||||
fn new(act: HiddenAct) -> Self {
|
|
||||||
let span = tracing::span!(tracing::Level::TRACE, "hidden-act");
|
|
||||||
Self { act, span }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Module for HiddenActLayer {
|
|
||||||
fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
|
|
||||||
let _enter = self.span.enter();
|
|
||||||
match self.act {
|
|
||||||
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
|
|
||||||
HiddenAct::Gelu => xs.gelu(),
|
|
||||||
HiddenAct::Relu => xs.relu(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
|
|
||||||
#[serde(rename_all = "lowercase")]
|
|
||||||
enum PositionEmbeddingType {
|
|
||||||
#[default]
|
|
||||||
Absolute,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
|
||||||
pub struct Config {
|
|
||||||
vocab_size: usize,
|
|
||||||
dim: usize,
|
|
||||||
n_layers: usize,
|
|
||||||
n_heads: usize,
|
|
||||||
hidden_dim: usize,
|
|
||||||
activation: HiddenAct,
|
|
||||||
max_position_embeddings: usize,
|
|
||||||
initializer_range: f64,
|
|
||||||
pad_token_id: usize,
|
|
||||||
#[serde(default)]
|
|
||||||
position_embedding_type: PositionEmbeddingType,
|
|
||||||
#[serde(default)]
|
|
||||||
use_cache: bool,
|
|
||||||
model_type: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for Config {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self {
|
|
||||||
vocab_size: 30522,
|
|
||||||
dim: 768,
|
|
||||||
n_layers: 12,
|
|
||||||
n_heads: 12,
|
|
||||||
hidden_dim: 3072,
|
|
||||||
activation: HiddenAct::Gelu,
|
|
||||||
max_position_embeddings: 512,
|
|
||||||
initializer_range: 0.02,
|
|
||||||
pad_token_id: 0,
|
|
||||||
position_embedding_type: PositionEmbeddingType::Absolute,
|
|
||||||
use_cache: true,
|
|
||||||
model_type: Some("distilbert".to_string()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct Embeddings {
|
|
||||||
word_embeddings: Embedding,
|
|
||||||
position_embeddings: Embedding,
|
|
||||||
layer_norm: LayerNorm,
|
|
||||||
span: tracing::Span,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Embeddings {
|
|
||||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
|
||||||
let word_embeddings =
|
|
||||||
candle_nn::embedding(config.vocab_size, config.dim, vb.pp("word_embeddings"))?;
|
|
||||||
let position_embeddings = candle_nn::embedding(
|
|
||||||
config.max_position_embeddings,
|
|
||||||
config.dim,
|
|
||||||
vb.pp("position_embeddings"),
|
|
||||||
)?;
|
|
||||||
let layer_norm = layer_norm(config.dim, 1e-12, vb.pp("LayerNorm"))?;
|
|
||||||
Ok(Self {
|
|
||||||
word_embeddings,
|
|
||||||
position_embeddings,
|
|
||||||
layer_norm,
|
|
||||||
span: tracing::span!(tracing::Level::TRACE, "embeddings"),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
|
|
||||||
let _enter = self.span.enter();
|
|
||||||
let (_bsize, seq_len) = input_ids.dims2()?;
|
|
||||||
let input_embeddings = self.word_embeddings.forward(input_ids)?;
|
|
||||||
let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
|
|
||||||
let position_ids = Tensor::new(&position_ids[..], input_ids.device())?;
|
|
||||||
let embeddings =
|
|
||||||
input_embeddings.broadcast_add(&self.position_embeddings.forward(&position_ids)?)?;
|
|
||||||
|
|
||||||
let embeddings = self.layer_norm.forward(&embeddings)?;
|
|
||||||
Ok(embeddings)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct MultiHeadSelfAttention {
|
|
||||||
q_lin: Linear,
|
|
||||||
k_lin: Linear,
|
|
||||||
v_lin: Linear,
|
|
||||||
out_lin: Linear,
|
|
||||||
n_heads: usize,
|
|
||||||
attention_head_size: usize,
|
|
||||||
span: tracing::Span,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl MultiHeadSelfAttention {
|
|
||||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
|
||||||
let attention_head_size = config.dim / config.n_heads;
|
|
||||||
let all_head_size = config.n_heads * attention_head_size;
|
|
||||||
let dim = config.dim;
|
|
||||||
let q_lin = linear(dim, all_head_size, vb.pp("q_lin"))?;
|
|
||||||
let v_lin = linear(dim, all_head_size, vb.pp("v_lin"))?;
|
|
||||||
let k_lin = linear(dim, all_head_size, vb.pp("k_lin"))?;
|
|
||||||
let out_lin = linear(all_head_size, dim, vb.pp("out_lin"))?;
|
|
||||||
Ok(Self {
|
|
||||||
q_lin,
|
|
||||||
k_lin,
|
|
||||||
v_lin,
|
|
||||||
out_lin,
|
|
||||||
n_heads: config.n_heads,
|
|
||||||
attention_head_size,
|
|
||||||
span: tracing::span!(tracing::Level::TRACE, "attention"),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl MultiHeadSelfAttention {
|
|
||||||
fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
|
|
||||||
let _enter = self.span.enter();
|
|
||||||
let (bs, q_length, _dim) = hidden_states.dims3()?;
|
|
||||||
|
|
||||||
let dim_per_head = self.attention_head_size;
|
|
||||||
let q = self.q_lin.forward(hidden_states)?;
|
|
||||||
let k = self.k_lin.forward(hidden_states)?;
|
|
||||||
let v = self.v_lin.forward(hidden_states)?;
|
|
||||||
|
|
||||||
let q = q
|
|
||||||
.reshape((bs, q_length, self.n_heads, dim_per_head))?
|
|
||||||
.transpose(1, 2)?;
|
|
||||||
let k = k
|
|
||||||
.reshape((bs, q_length, self.n_heads, dim_per_head))?
|
|
||||||
.transpose(1, 2)?;
|
|
||||||
let v = v
|
|
||||||
.reshape((bs, q_length, self.n_heads, dim_per_head))?
|
|
||||||
.transpose(1, 2)?;
|
|
||||||
|
|
||||||
let q: Tensor = (q / (dim_per_head as f64).sqrt())?;
|
|
||||||
let scores = q.matmul(&k.transpose(2, 3)?.contiguous()?)?;
|
|
||||||
let mask = attention_mask.broadcast_as(scores.shape())?;
|
|
||||||
|
|
||||||
let scores = masked_fill(&scores.to_dtype(DType::F32)?, &mask, f32::NEG_INFINITY)?;
|
|
||||||
let weights = candle_nn::ops::softmax(&scores, candle::D::Minus1)?;
|
|
||||||
|
|
||||||
let context = weights.matmul(&v.contiguous()?)?;
|
|
||||||
let context = context
|
|
||||||
.transpose(1, 2)?
|
|
||||||
.reshape((bs, q_length, self.n_heads * dim_per_head))?
|
|
||||||
.contiguous()?;
|
|
||||||
let context = self.out_lin.forward(&context)?;
|
|
||||||
|
|
||||||
Ok(context)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(clippy::upper_case_acronyms)]
|
|
||||||
struct FFN {
|
|
||||||
lin1: Linear,
|
|
||||||
lin2: Linear,
|
|
||||||
activation: HiddenActLayer,
|
|
||||||
span: tracing::Span,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl FFN {
|
|
||||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
|
||||||
let lin1 = linear(config.dim, config.hidden_dim, vb.pp("lin1"))?;
|
|
||||||
let lin2 = linear(config.hidden_dim, config.dim, vb.pp("lin2"))?;
|
|
||||||
Ok(Self {
|
|
||||||
lin1,
|
|
||||||
lin2,
|
|
||||||
activation: HiddenActLayer::new(config.activation),
|
|
||||||
span: tracing::span!(tracing::Level::TRACE, "ffn"),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Module for FFN {
|
|
||||||
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
|
||||||
let _enter = self.span.enter();
|
|
||||||
hidden_states
|
|
||||||
.apply(&self.lin1)?
|
|
||||||
.apply(&self.activation)?
|
|
||||||
.apply(&self.lin2)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct TransformerBlock {
|
|
||||||
attention: MultiHeadSelfAttention,
|
|
||||||
sa_layer_norm: LayerNorm,
|
|
||||||
ffn: FFN,
|
|
||||||
output_layer_norm: LayerNorm,
|
|
||||||
span: tracing::Span,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TransformerBlock {
|
|
||||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
|
||||||
let attention = MultiHeadSelfAttention::load(vb.pp("attention"), config)?;
|
|
||||||
let sa_layer_norm = layer_norm(config.dim, 1e-12, vb.pp("sa_layer_norm"))?;
|
|
||||||
let ffn = FFN::load(vb.pp("ffn"), config)?;
|
|
||||||
let output_layer_norm = layer_norm(config.dim, 1e-12, vb.pp("output_layer_norm"))?;
|
|
||||||
Ok(Self {
|
|
||||||
attention,
|
|
||||||
sa_layer_norm,
|
|
||||||
ffn,
|
|
||||||
output_layer_norm,
|
|
||||||
span: tracing::span!(tracing::Level::TRACE, "layer"),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TransformerBlock {
|
|
||||||
fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
|
|
||||||
let _enter = self.span.enter();
|
|
||||||
let sa_output = self.attention.forward(hidden_states, attention_mask)?;
|
|
||||||
// TODO: Support cross-attention?
|
|
||||||
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
|
|
||||||
// TODO: Support something similar to `apply_chunking_to_forward`?
|
|
||||||
let sa_output = sa_output.broadcast_add(hidden_states)?;
|
|
||||||
let sa_output = self.sa_layer_norm.forward(&sa_output)?;
|
|
||||||
|
|
||||||
let ffn_output = self.ffn.forward(&sa_output)?;
|
|
||||||
let ffn_output = (&ffn_output + sa_output)?;
|
|
||||||
let output = self.output_layer_norm.forward(&ffn_output)?;
|
|
||||||
Ok(output)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
|
|
||||||
struct Transformer {
|
|
||||||
layers: Vec<TransformerBlock>,
|
|
||||||
span: tracing::Span,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Transformer {
|
|
||||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
|
||||||
let layers = (0..config.n_layers)
|
|
||||||
.map(|index| TransformerBlock::load(vb.pp(&format!("layer.{index}")), config))
|
|
||||||
.collect::<Result<Vec<_>>>()?;
|
|
||||||
let span = tracing::span!(tracing::Level::TRACE, "encoder");
|
|
||||||
Ok(Transformer { layers, span })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Transformer {
|
|
||||||
fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
|
|
||||||
let _enter = self.span.enter();
|
|
||||||
let mut hidden_states = hidden_states.clone();
|
|
||||||
// Use a loop rather than a fold as it's easier to modify when adding debug/...
|
|
||||||
for layer in self.layers.iter() {
|
|
||||||
hidden_states = layer.forward(&hidden_states, attention_mask)?;
|
|
||||||
}
|
|
||||||
Ok(hidden_states)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct DistilBertModel {
|
|
||||||
embeddings: Embeddings,
|
|
||||||
transformer: Transformer,
|
|
||||||
pub device: Device,
|
|
||||||
span: tracing::Span,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl DistilBertModel {
|
|
||||||
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
|
||||||
let (embeddings, transformer) = match (
|
|
||||||
Embeddings::load(vb.pp("embeddings"), config),
|
|
||||||
Transformer::load(vb.pp("transformer"), config),
|
|
||||||
) {
|
|
||||||
(Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
|
|
||||||
(Err(err), _) | (_, Err(err)) => {
|
|
||||||
if let Some(model_type) = &config.model_type {
|
|
||||||
if let (Ok(embeddings), Ok(encoder)) = (
|
|
||||||
Embeddings::load(vb.pp(&format!("{model_type}.embeddings")), config),
|
|
||||||
Transformer::load(vb.pp(&format!("{model_type}.transformer")), config),
|
|
||||||
) {
|
|
||||||
(embeddings, encoder)
|
|
||||||
} else {
|
|
||||||
return Err(err);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return Err(err);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
Ok(Self {
|
|
||||||
embeddings,
|
|
||||||
transformer,
|
|
||||||
device: vb.device().clone(),
|
|
||||||
span: tracing::span!(tracing::Level::TRACE, "model"),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn forward(&self, input_ids: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
|
|
||||||
let _enter = self.span.enter();
|
|
||||||
let embedding_output = self.embeddings.forward(input_ids)?;
|
|
||||||
let sequence_output = self
|
|
||||||
.transformer
|
|
||||||
.forward(&embedding_output, attention_mask)?;
|
|
||||||
Ok(sequence_output)
|
|
||||||
}
|
|
||||||
}
|
|
@ -4,7 +4,6 @@ pub mod blip;
|
|||||||
pub mod blip_text;
|
pub mod blip_text;
|
||||||
pub mod convmixer;
|
pub mod convmixer;
|
||||||
pub mod dinov2;
|
pub mod dinov2;
|
||||||
pub mod distilbert;
|
|
||||||
pub mod efficientnet;
|
pub mod efficientnet;
|
||||||
pub mod falcon;
|
pub mod falcon;
|
||||||
pub mod jina_bert;
|
pub mod jina_bert;
|
||||||
|
@ -1,15 +1,12 @@
|
|||||||
use candle::{Device, Result, Tensor};
|
use candle::{Device, Result, Tensor};
|
||||||
|
|
||||||
pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
|
pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
|
||||||
if steps == 0 {
|
if steps < 1 {
|
||||||
Tensor::from_vec(Vec::<f64>::new(), steps, &Device::Cpu)
|
candle::bail!("cannot use linspace with steps {steps} <= 1")
|
||||||
} else if steps == 1 {
|
|
||||||
Tensor::from_vec(vec![start], steps, &Device::Cpu)
|
|
||||||
} else {
|
|
||||||
let delta = (stop - start) / (steps - 1) as f64;
|
|
||||||
let vs = (0..steps)
|
|
||||||
.map(|step| start + step as f64 * delta)
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
Tensor::from_vec(vs, steps, &Device::Cpu)
|
|
||||||
}
|
}
|
||||||
|
let delta = (stop - start) / (steps - 1) as f64;
|
||||||
|
let vs = (0..steps)
|
||||||
|
.map(|step| start + step as f64 * delta)
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
Tensor::from_vec(vs, steps, &Device::Cpu)
|
||||||
}
|
}
|
||||||
|
@ -138,10 +138,6 @@ impl TrOCRAttention {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn reset_kv_cache(&mut self) {
|
|
||||||
self.kv_cache = None
|
|
||||||
}
|
|
||||||
|
|
||||||
fn _shape(&self, tensor: &Tensor, bsz: usize) -> Result<Tensor> {
|
fn _shape(&self, tensor: &Tensor, bsz: usize) -> Result<Tensor> {
|
||||||
tensor
|
tensor
|
||||||
.reshape((bsz, (), self.num_heads, self.head_dim))?
|
.reshape((bsz, (), self.num_heads, self.head_dim))?
|
||||||
@ -243,10 +239,6 @@ impl TrOCRDecoderLayer {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn reset_kv_cache(&mut self) {
|
|
||||||
self.self_attn.reset_kv_cache();
|
|
||||||
}
|
|
||||||
|
|
||||||
fn forward(
|
fn forward(
|
||||||
&mut self,
|
&mut self,
|
||||||
xs: &Tensor,
|
xs: &Tensor,
|
||||||
@ -315,10 +307,6 @@ impl TrOCRDecoder {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn reset_kv_cache(&mut self) {
|
|
||||||
self.layers.iter_mut().for_each(|l| l.reset_kv_cache())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn forward(
|
pub fn forward(
|
||||||
&mut self,
|
&mut self,
|
||||||
xs: &Tensor,
|
xs: &Tensor,
|
||||||
@ -405,10 +393,6 @@ impl TrOCRForCausalLM {
|
|||||||
|
|
||||||
Ok(xs)
|
Ok(xs)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn reset_kv_cache(&mut self) {
|
|
||||||
self.decoder.reset_kv_cache();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@ -447,8 +431,4 @@ impl TrOCRModel {
|
|||||||
self.decoder
|
self.decoder
|
||||||
.forward(xs, Some(encoder_xs), past_kv_len, &mask)
|
.forward(xs, Some(encoder_xs), past_kv_len, &mask)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn reset_kv_cache(&mut self) {
|
|
||||||
self.decoder.reset_kv_cache();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@ -58,7 +58,8 @@ fn dft<T: Float>(inp: &[T]) -> Vec<T> {
|
|||||||
let n = inp.len();
|
let n = inp.len();
|
||||||
let two_pi = T::PI() + T::PI();
|
let two_pi = T::PI() + T::PI();
|
||||||
|
|
||||||
let mut out = Vec::with_capacity(2 * n);
|
let mut out = Vec::new();
|
||||||
|
out.reserve(2 * n);
|
||||||
let n_t = T::from(n).unwrap();
|
let n_t = T::from(n).unwrap();
|
||||||
for k in 0..n {
|
for k in 0..n {
|
||||||
let k_t = T::from(k).unwrap();
|
let k_t = T::from(k).unwrap();
|
||||||
|
@ -43,4 +43,4 @@ pub const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
|
|||||||
pub const TRANSLATE_TOKEN: &str = "<|translate|>";
|
pub const TRANSLATE_TOKEN: &str = "<|translate|>";
|
||||||
pub const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
|
pub const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
|
||||||
pub const EOT_TOKEN: &str = "<|endoftext|>";
|
pub const EOT_TOKEN: &str = "<|endoftext|>";
|
||||||
pub const NO_SPEECH_TOKENS: [&str; 2] = ["<|nocaptions|>", "<|nospeech|>"];
|
pub const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
|
||||||
|
@ -277,12 +277,8 @@ impl DecoderLayer {
|
|||||||
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
|
let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
|
||||||
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
|
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
|
||||||
let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
|
let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("ln1"))?;
|
||||||
let ln2 = RmsNorm::new(
|
let ln2 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("ln2"))?;
|
||||||
cfg.hidden_size,
|
|
||||||
cfg.rms_norm_eps,
|
|
||||||
vb.pp("post_attention_layernorm"),
|
|
||||||
)?;
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
self_attn,
|
self_attn,
|
||||||
mlp,
|
mlp,
|
||||||
|
@ -9,9 +9,9 @@ categories.workspace = true
|
|||||||
license.workspace = true
|
license.workspace = true
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
|
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
|
||||||
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
|
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
|
||||||
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
|
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
|
||||||
num-traits = { workspace = true }
|
num-traits = { workspace = true }
|
||||||
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
||||||
|
|
||||||
|
@ -9,9 +9,9 @@ categories.workspace = true
|
|||||||
license.workspace = true
|
license.workspace = true
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
|
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
|
||||||
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
|
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
|
||||||
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
|
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
|
||||||
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
||||||
num-traits = { workspace = true }
|
num-traits = { workspace = true }
|
||||||
|
|
||||||
|
@ -9,9 +9,9 @@ categories.workspace = true
|
|||||||
license.workspace = true
|
license.workspace = true
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
|
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
|
||||||
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
|
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
|
||||||
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
|
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
|
||||||
num-traits = { workspace = true }
|
num-traits = { workspace = true }
|
||||||
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
||||||
|
|
||||||
|
@ -9,9 +9,9 @@ categories.workspace = true
|
|||||||
license.workspace = true
|
license.workspace = true
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
|
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
|
||||||
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
|
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
|
||||||
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
|
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
|
||||||
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
||||||
num-traits = { workspace = true }
|
num-traits = { workspace = true }
|
||||||
|
|
||||||
|
@ -9,9 +9,9 @@ categories.workspace = true
|
|||||||
license.workspace = true
|
license.workspace = true
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
|
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
|
||||||
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
|
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
|
||||||
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
|
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
|
||||||
num-traits = { workspace = true }
|
num-traits = { workspace = true }
|
||||||
|
|
||||||
# App crates.
|
# App crates.
|
||||||
|
@ -9,9 +9,9 @@ categories.workspace = true
|
|||||||
license.workspace = true
|
license.workspace = true
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
|
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
|
||||||
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
|
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
|
||||||
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
|
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
|
||||||
num-traits = { workspace = true }
|
num-traits = { workspace = true }
|
||||||
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
||||||
|
|
||||||
|
@ -9,9 +9,9 @@ categories.workspace = true
|
|||||||
license.workspace = true
|
license.workspace = true
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
|
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
|
||||||
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
|
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
|
||||||
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
|
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
|
||||||
num-traits = { workspace = true }
|
num-traits = { workspace = true }
|
||||||
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
||||||
|
|
||||||
|
@ -59,7 +59,8 @@ fn dft<T: Float>(inp: &[T]) -> Vec<T> {
|
|||||||
let n = inp.len();
|
let n = inp.len();
|
||||||
let two_pi = T::PI() + T::PI();
|
let two_pi = T::PI() + T::PI();
|
||||||
|
|
||||||
let mut out = Vec::with_capacity(2 * n);
|
let mut out = Vec::new();
|
||||||
|
out.reserve(2 * n);
|
||||||
let n_t = T::from(n).unwrap();
|
let n_t = T::from(n).unwrap();
|
||||||
for k in 0..n {
|
for k in 0..n {
|
||||||
let k_t = T::from(k).unwrap();
|
let k_t = T::from(k).unwrap();
|
||||||
|
@ -129,13 +129,7 @@ impl Decoder {
|
|||||||
let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
|
let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
|
||||||
let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
|
let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
|
||||||
let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
|
let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
|
||||||
let no_speech_token = m::NO_SPEECH_TOKENS
|
let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?;
|
||||||
.iter()
|
|
||||||
.find_map(|token| token_id(&tokenizer, token).ok());
|
|
||||||
let no_speech_token = match no_speech_token {
|
|
||||||
None => anyhow::bail!("unable to find any non-speech token"),
|
|
||||||
Some(n) => n,
|
|
||||||
};
|
|
||||||
let seed = 299792458;
|
let seed = 299792458;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
model,
|
model,
|
||||||
|
@ -9,8 +9,8 @@ categories.workspace = true
|
|||||||
license.workspace = true
|
license.workspace = true
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
|
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
|
||||||
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
|
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
|
||||||
num-traits = { workspace = true }
|
num-traits = { workspace = true }
|
||||||
serde = { workspace = true }
|
serde = { workspace = true }
|
||||||
serde_json = { workspace = true }
|
serde_json = { workspace = true }
|
||||||
|
@ -7,7 +7,7 @@ keywords.workspace = true
|
|||||||
categories.workspace = true
|
categories.workspace = true
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
|
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
||||||
rand = { workspace = true }
|
rand = { workspace = true }
|
||||||
getrandom = { version = "0.2", features = ["js"] }
|
getrandom = { version = "0.2", features = ["js"] }
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use candle::{
|
use candle::{
|
||||||
quantized::{self, k_quants, GgmlDType, GgmlType},
|
quantized::{self, k_quants, GgmlDType, GgmlType},
|
||||||
test_utils::to_vec2_round,
|
test_utils::to_vec2_round,
|
||||||
Device, Module, Result, Tensor,
|
Device, Result, Tensor,
|
||||||
};
|
};
|
||||||
|
|
||||||
use wasm_bindgen_test::*;
|
use wasm_bindgen_test::*;
|
||||||
|
Reference in New Issue
Block a user