mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Compare commits
57 Commits
metal2-tmp
...
tmp4
Author | SHA1 | Date | |
---|---|---|---|
a0282751d5 | |||
da0af3cb3e | |||
803ac8405b | |||
6e25822d4f | |||
2ca086939f | |||
4349ff1fc2 | |||
7c3cfd1086 | |||
e2eb6590ed | |||
481c45d78d | |||
14a2bdc062 | |||
bfa7c8fc01 | |||
762e996ce6 | |||
ca19a9af62 | |||
ec23427d60 | |||
f83e14f68d | |||
c7e613ab5e | |||
8f63f68289 | |||
1edc3ddf24 | |||
b380657bfe | |||
60f624a902 | |||
8d6c6de8e0 | |||
7ec345c2eb | |||
671fc29b36 | |||
dc64adb8e4 | |||
c66e5d4716 | |||
bd3b243725 | |||
2813fb5dbc | |||
7cfffcac10 | |||
38de52bc4b | |||
d46670f7c0 | |||
f710fab02e | |||
f82bf2d915 | |||
df6814f34e | |||
39406a6721 | |||
976ad9f9c2 | |||
a4c4a56429 | |||
f49bf6a81d | |||
992a788da1 | |||
8d8f48c60c | |||
d31f11035f | |||
9ab3f9729f | |||
a1f41ab37b | |||
92a05b51cf | |||
c6763e3b41 | |||
347e31c9ff | |||
f4fcf60900 | |||
12561b31d3 | |||
a209ce8ceb | |||
f1e678b39c | |||
a007f8fdb4 | |||
2341aa079e | |||
9e666d4229 | |||
1b12142a02 | |||
d2c3f14773 | |||
26c4e5bf1d | |||
18d30005c5 | |||
6958384327 |
@ -19,7 +19,7 @@ exclude = [
|
|||||||
resolver = "2"
|
resolver = "2"
|
||||||
|
|
||||||
[workspace.package]
|
[workspace.package]
|
||||||
version = "0.3.0"
|
version = "0.3.1"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
description = "Minimalist ML framework."
|
description = "Minimalist ML framework."
|
||||||
repository = "https://github.com/huggingface/candle"
|
repository = "https://github.com/huggingface/candle"
|
||||||
@ -51,6 +51,7 @@ rayon = "1.7.0"
|
|||||||
rusttype = { version = "0.9", default-features = false }
|
rusttype = { version = "0.9", default-features = false }
|
||||||
safetensors = "0.3.1"
|
safetensors = "0.3.1"
|
||||||
serde = { version = "1.0.171", features = ["derive"] }
|
serde = { version = "1.0.171", features = ["derive"] }
|
||||||
|
serde_plain = "1.0.2"
|
||||||
serde_json = "1.0.99"
|
serde_json = "1.0.99"
|
||||||
thiserror = "1"
|
thiserror = "1"
|
||||||
tokenizers = { version = "0.13.4", default-features = false }
|
tokenizers = { version = "0.13.4", default-features = false }
|
||||||
@ -60,8 +61,7 @@ tracing-subscriber = "0.3.7"
|
|||||||
wav = "1.0.0"
|
wav = "1.0.0"
|
||||||
yoke = { version = "0.7.2", features = ["derive"] }
|
yoke = { version = "0.7.2", features = ["derive"] }
|
||||||
zip = { version = "0.6.6", default-features = false }
|
zip = { version = "0.6.6", default-features = false }
|
||||||
# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
|
metal = { version = "0.27.1", features = ["mps"], package="candle-metal" }
|
||||||
metal = { path = "../metal-rs", features = ["mps"] }
|
|
||||||
|
|
||||||
[profile.release-with-debug]
|
[profile.release-with-debug]
|
||||||
inherits = "release"
|
inherits = "release"
|
||||||
|
22
README.md
22
README.md
@ -69,6 +69,8 @@ We also provide some command line based examples using state of the art models
|
|||||||
performance larger than all publicly available 13b models as of 2023-09-28.
|
performance larger than all publicly available 13b models as of 2023-09-28.
|
||||||
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
|
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
|
||||||
- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
|
- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
|
||||||
|
- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
|
||||||
|
(English/Chinese) general LLMs with 6b and 34b parameters.
|
||||||
- [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of
|
- [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of
|
||||||
the LLaMA model using the same quantization techniques as
|
the LLaMA model using the same quantization techniques as
|
||||||
[llama.cpp](https://github.com/ggerganov/llama.cpp).
|
[llama.cpp](https://github.com/ggerganov/llama.cpp).
|
||||||
@ -137,16 +139,16 @@ And then head over to
|
|||||||
<!--- ANCHOR: useful_libraries --->
|
<!--- ANCHOR: useful_libraries --->
|
||||||
|
|
||||||
## Useful External Resources
|
## Useful External Resources
|
||||||
- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): a
|
- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): A
|
||||||
very detailed tutorial showing how to convert a PyTorch model to Candle.
|
very detailed tutorial showing how to convert a PyTorch model to Candle.
|
||||||
- [`optimisers`](https://github.com/KGrewal1/optimisers): a collection of optimisers
|
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): Efficient and ergonomic LoRA implementation for Candle. `candle-lora` has
|
||||||
|
out-of-the-box LoRA support for many models from Candle, which can be found [here](https://github.com/EricLBuehler/candle-lora/tree/master/candle-lora-transformers/examples).
|
||||||
|
- [`optimisers`](https://github.com/KGrewal1/optimisers): A collection of optimisers
|
||||||
including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
|
including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
|
||||||
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): a LoRA implementation
|
|
||||||
that conforms to the official `peft` implementation.
|
|
||||||
- [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
|
- [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
|
||||||
serving local LLMs including an OpenAI compatible API server.
|
serving local LLMs including an OpenAI compatible API server.
|
||||||
- [`candle-ext`](https://github.com/mokeyish/candle-ext): an extension library to Candle that provides PyTorch functions not currently available in Candle.
|
- [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle.
|
||||||
- [`kalosm`](https://github.com/floneum/floneum/tree/master/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
|
- [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
|
||||||
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
|
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
|
||||||
|
|
||||||
If you have an addition to this list, please submit a pull request.
|
If you have an addition to this list, please submit a pull request.
|
||||||
@ -174,8 +176,14 @@ If you have an addition to this list, please submit a pull request.
|
|||||||
- StableLM-3B-4E1T.
|
- StableLM-3B-4E1T.
|
||||||
- Replit-code-v1.5-3B.
|
- Replit-code-v1.5-3B.
|
||||||
- Bert.
|
- Bert.
|
||||||
|
- Yi-6B and Yi-34B.
|
||||||
|
- Quantized LLMs.
|
||||||
|
- Llama 7b, 13b, 70b, as well as the chat and code variants.
|
||||||
|
- Mistral 7b, and 7b instruct.
|
||||||
|
- Zephyr 7b a and b (Mistral based).
|
||||||
|
- OpenChat 3.5 (Mistral based).
|
||||||
- Text to text.
|
- Text to text.
|
||||||
- T5 and its variants: FlanT5, MADLAD400 (translation), CoEdit (Grammar correction).
|
- T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
|
||||||
- Marian MT (Machine Translation).
|
- Marian MT (Machine Translation).
|
||||||
- Whisper (multi-lingual support).
|
- Whisper (multi-lingual support).
|
||||||
- Text to image.
|
- Text to image.
|
||||||
|
@ -11,11 +11,11 @@ readme = "README.md"
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
accelerate-src = { workspace = true, optional = true }
|
accelerate-src = { workspace = true, optional = true }
|
||||||
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
|
||||||
candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
|
candle-datasets = { path = "../candle-datasets", version = "0.3.1" }
|
||||||
candle-nn = { path = "../candle-nn", version = "0.3.0" }
|
candle-nn = { path = "../candle-nn", version = "0.3.1" }
|
||||||
candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
|
candle-transformers = { path = "../candle-transformers", version = "0.3.1" }
|
||||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
|
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.1", optional = true }
|
||||||
safetensors = { workspace = true }
|
safetensors = { workspace = true }
|
||||||
serde = { workspace = true }
|
serde = { workspace = true }
|
||||||
serde_json = { workspace = true }
|
serde_json = { workspace = true }
|
||||||
|
@ -12,8 +12,8 @@ readme = "README.md"
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
accelerate-src = { workspace = true, optional = true }
|
accelerate-src = { workspace = true, optional = true }
|
||||||
byteorder = { workspace = true }
|
byteorder = { workspace = true }
|
||||||
candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
|
candle-kernels = { path = "../candle-kernels", version = "0.3.1", optional = true }
|
||||||
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
|
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.1", optional = true }
|
||||||
metal = { workspace = true, optional = true}
|
metal = { workspace = true, optional = true}
|
||||||
cudarc = { workspace = true, optional = true }
|
cudarc = { workspace = true, optional = true }
|
||||||
gemm = { workspace = true }
|
gemm = { workspace = true }
|
||||||
|
@ -8,11 +8,10 @@ use anyhow::Result;
|
|||||||
use candle_core::{Device, Tensor};
|
use candle_core::{Device, Tensor};
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
|
let a = Tensor::new(&[[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]], &Device::Cpu)?;
|
||||||
let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
|
let b = Tensor::new(&[[88.0f32, 99.0]], &Device::Cpu)?;
|
||||||
let start = std::time::Instant::now();
|
let new_a = a.slice_scatter(&b, 1, 2)?;
|
||||||
let res = inp.conv2d(&w, 0, 1, 1, 1)?;
|
assert_eq!(a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||||
println!("{:?}", start.elapsed());
|
assert_eq!(new_a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||||
println!("{res:?}");
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -104,36 +104,30 @@ impl From<&Tensor> for TensorIndexer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
macro_rules! impl_from_range {
|
trait RB: RangeBounds<usize> {}
|
||||||
($range_type:ty) => {
|
impl RB for Range<usize> {}
|
||||||
impl From<$range_type> for TensorIndexer {
|
impl RB for RangeFrom<usize> {}
|
||||||
fn from(range: $range_type) -> Self {
|
impl RB for RangeFull {}
|
||||||
use std::ops::Bound::*;
|
impl RB for RangeInclusive<usize> {}
|
||||||
|
impl RB for RangeTo<usize> {}
|
||||||
|
impl RB for RangeToInclusive<usize> {}
|
||||||
|
|
||||||
|
impl<T: RB> From<T> for TensorIndexer {
|
||||||
|
fn from(range: T) -> Self {
|
||||||
|
use std::ops::Bound::*;
|
||||||
let start = match range.start_bound() {
|
let start = match range.start_bound() {
|
||||||
Included(idx) => Included(*idx),
|
Included(idx) => Included(*idx),
|
||||||
Excluded(idx) => Excluded(*idx),
|
Excluded(idx) => Excluded(*idx),
|
||||||
Unbounded => Unbounded,
|
Unbounded => Unbounded,
|
||||||
};
|
};
|
||||||
|
|
||||||
let end = match range.end_bound() {
|
let end = match range.end_bound() {
|
||||||
Included(idx) => Included(*idx),
|
Included(idx) => Included(*idx),
|
||||||
Excluded(idx) => Excluded(*idx),
|
Excluded(idx) => Excluded(*idx),
|
||||||
Unbounded => Unbounded,
|
Unbounded => Unbounded,
|
||||||
};
|
};
|
||||||
|
|
||||||
TensorIndexer::Narrow(start, end)
|
TensorIndexer::Narrow(start, end)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
impl_from_range!(Range<usize>);
|
|
||||||
impl_from_range!(RangeFrom<usize>);
|
|
||||||
impl_from_range!(RangeFull);
|
|
||||||
impl_from_range!(RangeInclusive<usize>);
|
|
||||||
impl_from_range!(RangeTo<usize>);
|
|
||||||
impl_from_range!(RangeToInclusive<usize>);
|
|
||||||
|
|
||||||
/// Trait used to implement multiple signatures for ease of use of the slicing
|
/// Trait used to implement multiple signatures for ease of use of the slicing
|
||||||
/// of a tensor
|
/// of a tensor
|
||||||
|
@ -123,12 +123,6 @@ pub trait Module {
|
|||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
|
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Module for quantized::QMatMul {
|
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
|
||||||
self.forward(xs)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
|
impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
self(xs)
|
self(xs)
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -593,7 +593,8 @@ unary_op!(Recip, "recip", v, v.recip());
|
|||||||
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
|
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
|
||||||
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
|
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
|
||||||
|
|
||||||
/// `gelu` operation
|
/// Tanh based approximation of the `gelu` operation
|
||||||
|
/// GeluErf is the more precise one.
|
||||||
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
|
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
|
||||||
impl UnaryOpT for Gelu {
|
impl UnaryOpT for Gelu {
|
||||||
const NAME: &'static str = "gelu";
|
const NAME: &'static str = "gelu";
|
||||||
|
@ -307,8 +307,8 @@ impl crate::CustomOp1 for QTensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl QMatMul {
|
impl crate::Module for QMatMul {
|
||||||
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
match self {
|
match self {
|
||||||
Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
|
Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
|
||||||
Self::Tensor(w) => {
|
Self::Tensor(w) => {
|
||||||
|
@ -157,8 +157,6 @@ pub(crate) fn from_storage<S: Into<Shape>>(
|
|||||||
) -> Tensor {
|
) -> Tensor {
|
||||||
let dtype = storage.dtype();
|
let dtype = storage.dtype();
|
||||||
let device = storage.device();
|
let device = storage.device();
|
||||||
let shape = shape.into();
|
|
||||||
// println!("{:?} {storage:?}", shape);
|
|
||||||
let tensor_ = Tensor_ {
|
let tensor_ = Tensor_ {
|
||||||
id: TensorId::new(),
|
id: TensorId::new(),
|
||||||
storage: Arc::new(RwLock::new(storage)),
|
storage: Arc::new(RwLock::new(storage)),
|
||||||
@ -168,11 +166,7 @@ pub(crate) fn from_storage<S: Into<Shape>>(
|
|||||||
dtype,
|
dtype,
|
||||||
device,
|
device,
|
||||||
};
|
};
|
||||||
let result = Tensor(Arc::new(tensor_));
|
Tensor(Arc::new(tensor_))
|
||||||
// todo!(" from_storage");
|
|
||||||
// let result = result.to_device(&Device::Cpu).unwrap();
|
|
||||||
// todo!(" {result}");
|
|
||||||
result
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Tensor {
|
impl Tensor {
|
||||||
@ -862,6 +856,20 @@ impl Tensor {
|
|||||||
self.sum_impl(mean_dims, false)? * scale
|
self.sum_impl(mean_dims, false)? * scale
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the unbiased variance over the selected dimension.
|
||||||
|
pub fn var_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
|
||||||
|
let dim = dim.to_index(self.shape(), "var")?;
|
||||||
|
let mean = self.mean_keepdim(dim)?;
|
||||||
|
let squares = self.broadcast_sub(&mean)?.sqr()?;
|
||||||
|
squares.sum_impl(dim, true)? / (self.dim(dim)? - 1) as f64
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the unbiased variance over the selected dimension.
|
||||||
|
pub fn var<D: Dim>(&self, dim: D) -> Result<Self> {
|
||||||
|
let dim = dim.to_index(self.shape(), "var")?;
|
||||||
|
self.var_keepdim(dim)?.squeeze(dim)
|
||||||
|
}
|
||||||
|
|
||||||
/// Gathers the maximum value across the selected dimension. The resulting shape has the same
|
/// Gathers the maximum value across the selected dimension. The resulting shape has the same
|
||||||
/// number of dimensions as the original tensor and the select dimension has a single element.
|
/// number of dimensions as the original tensor and the select dimension has a single element.
|
||||||
pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
|
pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
|
||||||
@ -1855,7 +1863,10 @@ impl Tensor {
|
|||||||
Storage::Metal(metal.storage_from_cpu_storage(storage)?)
|
Storage::Metal(metal.storage_from_cpu_storage(storage)?)
|
||||||
}
|
}
|
||||||
(Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
|
(Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
|
||||||
(Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
|
(Storage::Metal(storage), Device::Cpu) => {
|
||||||
|
println!("{storage:?} - {:?}", storage.to_cpu_storage()?);
|
||||||
|
Storage::Cpu(storage.to_cpu_storage()?)
|
||||||
|
}
|
||||||
(Storage::Cuda(storage), Device::Cuda(cuda)) => {
|
(Storage::Cuda(storage), Device::Cuda(cuda)) => {
|
||||||
// TODO: Avoid passing through the cpu storage here, especially if the gpu ids
|
// TODO: Avoid passing through the cpu storage here, especially if the gpu ids
|
||||||
// are the same.
|
// are the same.
|
||||||
@ -2446,6 +2457,110 @@ impl Tensor {
|
|||||||
Ok(naxis as usize)
|
Ok(naxis as usize)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns a lower triangular matrix of ones of size n by n.
|
||||||
|
pub fn tril2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
|
||||||
|
let t = Tensor::arange(0u32, n as u32, device)?;
|
||||||
|
let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
|
||||||
|
let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
|
||||||
|
t1.le(&t2)?.to_dtype(dtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns an upper triangular matrix of ones of size n by n.
|
||||||
|
pub fn triu2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
|
||||||
|
let t = Tensor::arange(0u32, n as u32, device)?;
|
||||||
|
let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
|
||||||
|
let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
|
||||||
|
t1.ge(&t2)?.to_dtype(dtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a matrix with a diagonal of ones of size n by n.
|
||||||
|
pub fn eye(n: usize, dtype: DType, device: &Device) -> Result<Self> {
|
||||||
|
let t = Tensor::arange(0u32, n as u32, device)?;
|
||||||
|
let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
|
||||||
|
let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
|
||||||
|
t1.eq(&t2)?.to_dtype(dtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the cumulative sum of elements of the input tensor summed over the specified
|
||||||
|
/// dimension.
|
||||||
|
///
|
||||||
|
/// This operation is most efficient when dim is the last dimension of the tensor.
|
||||||
|
pub fn cumsum<D: Dim>(&self, dim: D) -> Result<Self> {
|
||||||
|
let dim = dim.to_index(self.shape(), "cumsum")?;
|
||||||
|
let rank = self.rank();
|
||||||
|
if rank == 0 {
|
||||||
|
return Ok(self.clone());
|
||||||
|
}
|
||||||
|
let n_axis = self.dim(dim)?;
|
||||||
|
let triu = Tensor::triu2(n_axis, self.dtype(), self.device())?;
|
||||||
|
if rank == 1 {
|
||||||
|
self.unsqueeze(0)?.matmul(&triu)?.squeeze(0)
|
||||||
|
} else {
|
||||||
|
let last = rank - 1;
|
||||||
|
let t = self.transpose(dim, last)?;
|
||||||
|
let t = t.broadcast_matmul(&triu)?;
|
||||||
|
t.transpose(dim, last)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a copy of `self` where the values within `ranges` have been replaced with the
|
||||||
|
/// content of `src`.
|
||||||
|
pub fn slice_assign<D: std::ops::RangeBounds<usize>>(
|
||||||
|
&self,
|
||||||
|
ranges: &[D],
|
||||||
|
src: &Tensor,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let src_dims = src.dims();
|
||||||
|
let self_dims = self.dims();
|
||||||
|
if self_dims.len() != src_dims.len() {
|
||||||
|
crate::bail!(
|
||||||
|
"slice-assign requires input with the same rank {} <> {}",
|
||||||
|
self_dims.len(),
|
||||||
|
src_dims.len()
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if self_dims.len() != ranges.len() {
|
||||||
|
crate::bail!(
|
||||||
|
"slice-assign requires input with the same rank as there are ranges {} <> {}",
|
||||||
|
self_dims.len(),
|
||||||
|
ranges.len()
|
||||||
|
)
|
||||||
|
}
|
||||||
|
let mut src = src.clone();
|
||||||
|
let mut mask = Self::ones(src.shape(), DType::U8, src.device())?;
|
||||||
|
for (i, range) in ranges.iter().enumerate() {
|
||||||
|
let start_included = match range.start_bound() {
|
||||||
|
std::ops::Bound::Unbounded => 0,
|
||||||
|
std::ops::Bound::Included(v) => *v,
|
||||||
|
std::ops::Bound::Excluded(v) => *v + 1,
|
||||||
|
};
|
||||||
|
let end_excluded = match range.end_bound() {
|
||||||
|
std::ops::Bound::Unbounded => self_dims[i],
|
||||||
|
std::ops::Bound::Included(v) => *v + 1,
|
||||||
|
std::ops::Bound::Excluded(v) => *v,
|
||||||
|
};
|
||||||
|
if end_excluded <= start_included {
|
||||||
|
crate::bail!(
|
||||||
|
"slice-assign: empty range for dim {i}, {start_included} {end_excluded}"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if self_dims[i] < end_excluded {
|
||||||
|
crate::bail!(
|
||||||
|
"slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}",
|
||||||
|
self_dims[i]
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if end_excluded - start_included != src_dims[i] {
|
||||||
|
crate::bail!(
|
||||||
|
"slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i]
|
||||||
|
)
|
||||||
|
}
|
||||||
|
src = src.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?;
|
||||||
|
mask = mask.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?
|
||||||
|
}
|
||||||
|
mask.where_cond(/* on_true= */ &src, /* on_false= */ self)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
macro_rules! bin_trait {
|
macro_rules! bin_trait {
|
||||||
|
@ -4,7 +4,7 @@ use crate::{Result, Tensor};
|
|||||||
macro_rules! test_device {
|
macro_rules! test_device {
|
||||||
// TODO: Switch to generating the two last arguments automatically once concat_idents is
|
// TODO: Switch to generating the two last arguments automatically once concat_idents is
|
||||||
// stable. https://github.com/rust-lang/rust/issues/29599
|
// stable. https://github.com/rust-lang/rust/issues/29599
|
||||||
($fn_name: ident, $test_cpu: ident, $test_cuda: ident) => {
|
($fn_name: ident, $test_cpu: ident, $test_cuda: ident, $test_metal: ident) => {
|
||||||
#[test]
|
#[test]
|
||||||
fn $test_cpu() -> Result<()> {
|
fn $test_cpu() -> Result<()> {
|
||||||
$fn_name(&Device::Cpu)
|
$fn_name(&Device::Cpu)
|
||||||
@ -15,6 +15,12 @@ macro_rules! test_device {
|
|||||||
fn $test_cuda() -> Result<()> {
|
fn $test_cuda() -> Result<()> {
|
||||||
$fn_name(&Device::new_cuda(0)?)
|
$fn_name(&Device::new_cuda(0)?)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "metal")]
|
||||||
|
#[test]
|
||||||
|
fn $test_metal() -> Result<()> {
|
||||||
|
$fn_name(&Device::new_metal(0)?)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -563,14 +563,35 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
test_device!(conv1d, conv1d_cpu, conv1d_gpu);
|
test_device!(conv1d, conv1d_cpu, conv1d_gpu, conv1d_metal);
|
||||||
test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu);
|
test_device!(
|
||||||
test_device!(conv2d, conv2d_cpu, conv2d_gpu);
|
conv1d_small,
|
||||||
|
conv1d_small_cpu,
|
||||||
|
conv1d_small_gpu,
|
||||||
|
conv1d_small_metal
|
||||||
|
);
|
||||||
|
test_device!(conv2d, conv2d_cpu, conv2d_gpu, conv2d_metal);
|
||||||
test_device!(
|
test_device!(
|
||||||
conv2d_non_square,
|
conv2d_non_square,
|
||||||
conv2d_non_square_cpu,
|
conv2d_non_square_cpu,
|
||||||
conv2d_non_square_gpu
|
conv2d_non_square_gpu,
|
||||||
|
conv2d_non_square_metal
|
||||||
|
);
|
||||||
|
test_device!(
|
||||||
|
conv2d_small,
|
||||||
|
conv2d_small_cpu,
|
||||||
|
conv2d_small_gpu,
|
||||||
|
conv2d_small_metal
|
||||||
|
);
|
||||||
|
test_device!(
|
||||||
|
conv2d_smaller,
|
||||||
|
conv2d_smaller_cpu,
|
||||||
|
conv2d_smaller_gpu,
|
||||||
|
conv2d_smaller_metal
|
||||||
|
);
|
||||||
|
test_device!(
|
||||||
|
conv2d_grad,
|
||||||
|
conv2d_grad_cpu,
|
||||||
|
conv2d_grad_gpu,
|
||||||
|
conv2_grad_metal
|
||||||
);
|
);
|
||||||
test_device!(conv2d_small, conv2d_small_cpu, conv2d_small_gpu);
|
|
||||||
test_device!(conv2d_smaller, conv2d_smaller_cpu, conv2d_smaller_gpu);
|
|
||||||
test_device!(conv2d_grad, conv2d_grad_cpu, conv2d_grad_gpu);
|
|
||||||
|
@ -315,9 +315,29 @@ fn binary_grad(device: &Device) -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
test_device!(simple_grad, simple_grad_cpu, simple_grad_gpu);
|
test_device!(
|
||||||
test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu);
|
simple_grad,
|
||||||
test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu);
|
simple_grad_cpu,
|
||||||
test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu);
|
simple_grad_gpu,
|
||||||
test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu);
|
simple_grad_metal
|
||||||
test_device!(binary_grad, binary_grad_cpu, binary_grad_gpu);
|
);
|
||||||
|
test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu, sum_grad_metal);
|
||||||
|
test_device!(
|
||||||
|
matmul_grad,
|
||||||
|
matmul_grad_cpu,
|
||||||
|
matmul_grad_gpu,
|
||||||
|
matmul_grad_metal
|
||||||
|
);
|
||||||
|
test_device!(
|
||||||
|
grad_descent,
|
||||||
|
grad_descent_cpu,
|
||||||
|
grad_descent_gpu,
|
||||||
|
grad_descent_metal
|
||||||
|
);
|
||||||
|
test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu, unary_grad_metal);
|
||||||
|
test_device!(
|
||||||
|
binary_grad,
|
||||||
|
binary_grad_cpu,
|
||||||
|
binary_grad_gpu,
|
||||||
|
binary_grad_metal
|
||||||
|
);
|
||||||
|
@ -91,3 +91,32 @@ fn index_3d() -> Result<()> {
|
|||||||
assert_eq!(tensor.i((1, .., 3))?.to_vec1::<u32>()?, &[15, 19, 23]);
|
assert_eq!(tensor.i((1, .., 3))?.to_vec1::<u32>()?, &[15, 19, 23]);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn slice_assign() -> Result<()> {
|
||||||
|
let dev = Device::Cpu;
|
||||||
|
|
||||||
|
let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?;
|
||||||
|
let src = Tensor::arange(0u32, 2 * 3, &dev)?.reshape((3, 2))?;
|
||||||
|
let out = tensor.slice_assign(&[1..4, 3..5], &src)?;
|
||||||
|
assert_eq!(
|
||||||
|
out.to_vec2::<u32>()?,
|
||||||
|
&[
|
||||||
|
[0, 1, 2, 3, 4],
|
||||||
|
[5, 6, 7, 0, 1],
|
||||||
|
[10, 11, 12, 2, 3],
|
||||||
|
[15, 16, 17, 4, 5]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
let out = tensor.slice_assign(&[0..3, 0..2], &src)?;
|
||||||
|
assert_eq!(
|
||||||
|
out.to_vec2::<u32>()?,
|
||||||
|
&[
|
||||||
|
[0, 1, 2, 3, 4],
|
||||||
|
[2, 3, 7, 8, 9],
|
||||||
|
[4, 5, 12, 13, 14],
|
||||||
|
[15, 16, 17, 18, 19]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
@ -49,7 +49,7 @@ fn contiguous(device: &Device) -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
test_device!(contiguous, contiguous_cpu, contiguous_gpu);
|
test_device!(contiguous, contiguous_cpu, contiguous_gpu, contiguous_metal);
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn strided_blocks() -> Result<()> {
|
fn strided_blocks() -> Result<()> {
|
||||||
|
@ -98,15 +98,17 @@ fn upsample_nearest2d(dev: &Device) -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu);
|
test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu, avg_pool2d_metal);
|
||||||
test_device!(
|
test_device!(
|
||||||
avg_pool2d_pytorch,
|
avg_pool2d_pytorch,
|
||||||
avg_pool2d_pytorch_cpu,
|
avg_pool2d_pytorch_cpu,
|
||||||
avg_pool2d_pytorch_gpu
|
avg_pool2d_pytorch_gpu,
|
||||||
|
avg_pool2d_pytorch_metal
|
||||||
);
|
);
|
||||||
test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu);
|
test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu, max_pool2d_metal);
|
||||||
test_device!(
|
test_device!(
|
||||||
upsample_nearest2d,
|
upsample_nearest2d,
|
||||||
upsample_nearest2d_cpu,
|
upsample_nearest2d_cpu,
|
||||||
upsample_nearest2d_gpu
|
upsample_nearest2d_gpu,
|
||||||
|
upsample_nearest2d_metal
|
||||||
);
|
);
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use candle_core::{
|
use candle_core::{
|
||||||
quantized::{self, GgmlDType},
|
quantized::{self, GgmlDType},
|
||||||
test_utils::to_vec2_round,
|
test_utils::to_vec2_round,
|
||||||
Device, Result, Tensor,
|
Device, Module, Result, Tensor,
|
||||||
};
|
};
|
||||||
use quantized::{k_quants, GgmlType};
|
use quantized::{k_quants, GgmlType};
|
||||||
use rand::prelude::*;
|
use rand::prelude::*;
|
||||||
|
@ -180,6 +180,22 @@ fn transpose(device: &Device) -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn var(device: &Device) -> Result<()> {
|
||||||
|
// Values taken from https://pytorch.org/docs/stable/generated/torch.var.html
|
||||||
|
let data = &[
|
||||||
|
[0.2035f32, 1.2959, 1.8101, -0.4644],
|
||||||
|
[1.5027, -0.3270, 0.5905, 0.6538],
|
||||||
|
[-1.5745, 1.3330, -0.5596, -0.6548],
|
||||||
|
[0.1264, -0.5080, 1.6420, 0.1992],
|
||||||
|
];
|
||||||
|
let tensor = Tensor::new(data, device)?;
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec2_round(&tensor.var_keepdim(1)?, 4)?,
|
||||||
|
&[[1.0631], [0.559], [1.4893], [0.8258]]
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
fn sum(device: &Device) -> Result<()> {
|
fn sum(device: &Device) -> Result<()> {
|
||||||
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
|
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
|
||||||
let tensor = Tensor::new(data, device)?;
|
let tensor = Tensor::new(data, device)?;
|
||||||
@ -1054,34 +1070,60 @@ fn randn(device: &Device) -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
test_device!(zeros, zeros_cpu, zeros_gpu);
|
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
|
||||||
test_device!(ones, ones_cpu, ones_gpu);
|
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
|
||||||
test_device!(arange, arange_cpu, arange_gpu);
|
test_device!(arange, arange_cpu, arange_gpu, arange_metal);
|
||||||
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
|
test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
|
||||||
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
|
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
|
||||||
test_device!(narrow, narrow_cpu, narrow_gpu);
|
test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal);
|
||||||
test_device!(broadcast, broadcast_cpu, broadcast_gpu);
|
test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal);
|
||||||
test_device!(cat, cat_cpu, cat_gpu);
|
test_device!(cat, cat_cpu, cat_gpu, cat_metal);
|
||||||
test_device!(sum, sum_cpu, sum_gpu);
|
test_device!(sum, sum_cpu, sum_gpu, sum_metal);
|
||||||
test_device!(min, min_cpu, min_gpu);
|
test_device!(min, min_cpu, min_gpu, min_metal);
|
||||||
test_device!(max, max_cpu, max_gpu);
|
test_device!(max, max_cpu, max_gpu, max_metal);
|
||||||
test_device!(argmax, argmax_cpu, argmax_gpu);
|
test_device!(argmax, argmax_cpu, argmax_gpu, argmax_metal);
|
||||||
test_device!(argmin, argmin_cpu, argmin_gpu);
|
test_device!(argmin, argmin_cpu, argmin_gpu, argmin_metal);
|
||||||
test_device!(transpose, transpose_cpu, transpose_gpu);
|
test_device!(transpose, transpose_cpu, transpose_gpu, transpose_metal);
|
||||||
test_device!(unary_op, unary_op_cpu, unary_op_gpu);
|
test_device!(unary_op, unary_op_cpu, unary_op_gpu, unary_op_metal);
|
||||||
test_device!(binary_op, binary_op_cpu, binary_op_gpu);
|
test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal);
|
||||||
test_device!(embeddings, embeddings_cpu, embeddings_gpu);
|
test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal);
|
||||||
test_device!(cmp, cmp_cpu, cmp_gpu);
|
test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal);
|
||||||
test_device!(matmul, matmul_cpu, matmul_gpu);
|
test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
|
||||||
test_device!(broadcast_matmul, broadcast_matmul_cpu, broadcast_matmul_gpu);
|
test_device!(
|
||||||
test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu);
|
broadcast_matmul,
|
||||||
test_device!(index_select, index_select_cpu, index_select_gpu);
|
broadcast_matmul_cpu,
|
||||||
test_device!(index_add, index_add_cpu, index_add_gpu);
|
broadcast_matmul_gpu,
|
||||||
test_device!(gather, gather_cpu, gather_gpu);
|
broadcast_matmul_metal
|
||||||
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
|
);
|
||||||
test_device!(slice_scatter, slice_scatter_cpu, slice_scatter_gpu);
|
test_device!(
|
||||||
test_device!(randn, randn_cpu, randn_gpu);
|
broadcasting,
|
||||||
test_device!(clamp, clamp_cpu, clamp_gpu);
|
broadcasting_cpu,
|
||||||
|
broadcasting_gpu,
|
||||||
|
broadcasting_metal
|
||||||
|
);
|
||||||
|
test_device!(
|
||||||
|
index_select,
|
||||||
|
index_select_cpu,
|
||||||
|
index_select_gpu,
|
||||||
|
index_select_metal
|
||||||
|
);
|
||||||
|
test_device!(index_add, index_add_cpu, index_add_gpu, index_add_metal);
|
||||||
|
test_device!(gather, gather_cpu, gather_gpu, gather_metal);
|
||||||
|
test_device!(
|
||||||
|
scatter_add,
|
||||||
|
scatter_add_cpu,
|
||||||
|
scatter_add_gpu,
|
||||||
|
scatter_add_metal
|
||||||
|
);
|
||||||
|
test_device!(
|
||||||
|
slice_scatter,
|
||||||
|
slice_scatter_cpu,
|
||||||
|
slice_scatter_gpu,
|
||||||
|
slice_scatter_metal
|
||||||
|
);
|
||||||
|
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
|
||||||
|
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
|
||||||
|
test_device!(var, var_cpu, var_gpu, var_metal);
|
||||||
|
|
||||||
// There was originally a bug on the CPU implementation for randn
|
// There was originally a bug on the CPU implementation for randn
|
||||||
// https://github.com/huggingface/candle/issues/381
|
// https://github.com/huggingface/candle/issues/381
|
||||||
@ -1117,3 +1159,65 @@ fn i64_abs() -> Result<()> {
|
|||||||
assert_eq!(t.to_vec1::<i64>()?, [42, 1337]);
|
assert_eq!(t.to_vec1::<i64>()?, [42, 1337]);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn tril_triu_eye() -> Result<()> {
|
||||||
|
let t = Tensor::tril2(4, DType::F32, &Device::Cpu)?;
|
||||||
|
assert_eq!(
|
||||||
|
t.to_vec2::<f32>()?,
|
||||||
|
[
|
||||||
|
[1.0, 0.0, 0.0, 0.0],
|
||||||
|
[1.0, 1.0, 0.0, 0.0],
|
||||||
|
[1.0, 1.0, 1.0, 0.0],
|
||||||
|
[1.0, 1.0, 1.0, 1.0]
|
||||||
|
],
|
||||||
|
);
|
||||||
|
let t = Tensor::triu2(4, DType::F32, &Device::Cpu)?;
|
||||||
|
assert_eq!(
|
||||||
|
t.to_vec2::<f32>()?,
|
||||||
|
[
|
||||||
|
[1.0, 1.0, 1.0, 1.0],
|
||||||
|
[0.0, 1.0, 1.0, 1.0],
|
||||||
|
[0.0, 0.0, 1.0, 1.0],
|
||||||
|
[0.0, 0.0, 0.0, 1.0]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
let t = Tensor::eye(4, DType::F32, &Device::Cpu)?;
|
||||||
|
assert_eq!(
|
||||||
|
t.to_vec2::<f32>()?,
|
||||||
|
[
|
||||||
|
[1.0, 0.0, 0.0, 0.0],
|
||||||
|
[0.0, 1.0, 0.0, 0.0],
|
||||||
|
[0.0, 0.0, 1.0, 0.0],
|
||||||
|
[0.0, 0.0, 0.0, 1.0]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cumsum() -> Result<()> {
|
||||||
|
let t = &[3f32, 1., 4., 1., 5.];
|
||||||
|
let t = Tensor::new(t, &Device::Cpu)?;
|
||||||
|
assert_eq!(t.cumsum(0)?.to_vec1::<f32>()?, [3., 4., 8., 9., 14.]);
|
||||||
|
let t = t.unsqueeze(1)?;
|
||||||
|
assert_eq!(
|
||||||
|
t.cumsum(0)?.to_vec2::<f32>()?,
|
||||||
|
[[3.0], [4.0], [8.0], [9.0], [14.0]]
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
t.cumsum(1)?.to_vec2::<f32>()?,
|
||||||
|
[[3.0], [1.0], [4.0], [1.0], [5.0]]
|
||||||
|
);
|
||||||
|
let t = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
|
||||||
|
let t = Tensor::new(t, &Device::Cpu)?;
|
||||||
|
assert_eq!(
|
||||||
|
t.cumsum(1)?.to_vec2::<f32>()?,
|
||||||
|
[[3.0, 4.0, 8.0, 9.0, 14.0], [2.0, 3.0, 10.0, 18.0, 20.0]],
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
t.cumsum(0)?.to_vec2::<f32>()?,
|
||||||
|
[[3.0, 1.0, 4.0, 1.0, 5.0], [5.0, 2.0, 11.0, 9.0, 7.0]]
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
@ -11,8 +11,8 @@ readme = "README.md"
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
byteorder = { workspace = true }
|
byteorder = { workspace = true }
|
||||||
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
|
||||||
candle-nn = { path = "../candle-nn", version = "0.3.0" }
|
candle-nn = { path = "../candle-nn", version = "0.3.1" }
|
||||||
hf-hub = { workspace = true}
|
hf-hub = { workspace = true}
|
||||||
intel-mkl-src = { workspace = true, optional = true }
|
intel-mkl-src = { workspace = true, optional = true }
|
||||||
memmap2 = { workspace = true }
|
memmap2 = { workspace = true }
|
||||||
|
@ -11,12 +11,12 @@ readme = "README.md"
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
accelerate-src = { workspace = true, optional = true }
|
accelerate-src = { workspace = true, optional = true }
|
||||||
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
|
||||||
candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
|
candle-datasets = { path = "../candle-datasets", version = "0.3.1" }
|
||||||
candle-nn = { path = "../candle-nn", version = "0.3.0" }
|
candle-nn = { path = "../candle-nn", version = "0.3.1" }
|
||||||
candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
|
candle-transformers = { path = "../candle-transformers", version = "0.3.1" }
|
||||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
|
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.1", optional = true }
|
||||||
candle-onnx = { path = "../candle-onnx", version = "0.3.0", optional = true }
|
candle-onnx = { path = "../candle-onnx", version = "0.3.1", optional = true }
|
||||||
cudarc = { workspace = true, optional = true }
|
cudarc = { workspace = true, optional = true }
|
||||||
half = { workspace = true, optional = true }
|
half = { workspace = true, optional = true }
|
||||||
image = { workspace = true }
|
image = { workspace = true }
|
||||||
@ -57,6 +57,7 @@ flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
|
|||||||
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
|
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
|
||||||
nccl = ["cuda", "cudarc/nccl", "dep:half"]
|
nccl = ["cuda", "cudarc/nccl", "dep:half"]
|
||||||
onnx = ["candle-onnx"]
|
onnx = ["candle-onnx"]
|
||||||
|
metal = ["candle/metal", "candle-nn/metal"]
|
||||||
|
|
||||||
[[example]]
|
[[example]]
|
||||||
name = "llama_multiprocess"
|
name = "llama_multiprocess"
|
||||||
|
22
candle-examples/examples/distilbert/README.md
Normal file
22
candle-examples/examples/distilbert/README.md
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
# candle-distilbert
|
||||||
|
|
||||||
|
DistilBert is a distilled version of the Bert model.
|
||||||
|
|
||||||
|
## Sentence embeddings
|
||||||
|
|
||||||
|
DistilBert is used to compute the sentence embeddings for a prompt. The model weights
|
||||||
|
are downloaded from the hub on the first run.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo run --example distilbert --release -- --prompt "Here is a test sentence"
|
||||||
|
|
||||||
|
> [[[ 0.5109, 0.1280, -0.2635, ..., 0.3462, -1.0434, 0.1441],
|
||||||
|
> [ 0.1735, 0.0818, -0.5549, ..., 0.3472, -0.8264, -0.0244],
|
||||||
|
> [ 0.0702, -0.1311, -0.4914, ..., 0.3483, -0.6194, 0.1829],
|
||||||
|
> ...
|
||||||
|
> [ 0.2993, -0.0106, -0.4640, ..., 0.2844, -0.6732, 0.0042],
|
||||||
|
> [ 0.1066, -0.0081, -0.4299, ..., 0.3435, -0.7729, 0.0190],
|
||||||
|
> [ 0.8903, 0.2055, -0.2541, ..., 0.3208, -0.6585, 0.0586]]]
|
||||||
|
> Tensor[[1, 7, 768], f32]
|
||||||
|
|
||||||
|
```
|
135
candle-examples/examples/distilbert/main.rs
Normal file
135
candle-examples/examples/distilbert/main.rs
Normal file
@ -0,0 +1,135 @@
|
|||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
use candle_transformers::models::distilbert::{Config, DistilBertModel, DTYPE};
|
||||||
|
|
||||||
|
use anyhow::{Error as E, Result};
|
||||||
|
use candle::{Device, Tensor};
|
||||||
|
use candle_nn::VarBuilder;
|
||||||
|
use clap::Parser;
|
||||||
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(author, version, about, long_about = None)]
|
||||||
|
struct Args {
|
||||||
|
/// Run on CPU rather than on GPU.
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
|
||||||
|
/// Enable tracing (generates a trace-timestamp.json file).
|
||||||
|
#[arg(long)]
|
||||||
|
tracing: bool,
|
||||||
|
|
||||||
|
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
|
||||||
|
#[arg(long)]
|
||||||
|
model_id: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
revision: Option<String>,
|
||||||
|
|
||||||
|
/// The prompt to compute embeddings for (required).
|
||||||
|
#[arg(long)]
|
||||||
|
prompt: String,
|
||||||
|
|
||||||
|
/// Use the pytorch weights rather than the safetensors ones
|
||||||
|
#[arg(long)]
|
||||||
|
use_pth: bool,
|
||||||
|
|
||||||
|
/// The number of times to run the prompt.
|
||||||
|
#[arg(long, default_value = "1")]
|
||||||
|
n: usize,
|
||||||
|
|
||||||
|
/// L2 normalization for embeddings.
|
||||||
|
#[arg(long, default_value = "true")]
|
||||||
|
normalize_embeddings: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Args {
|
||||||
|
fn build_model_and_tokenizer(&self) -> Result<(DistilBertModel, Tokenizer)> {
|
||||||
|
let device = candle_examples::device(self.cpu)?;
|
||||||
|
let default_model = "distilbert-base-uncased".to_string();
|
||||||
|
let default_revision = "main".to_string();
|
||||||
|
let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
|
||||||
|
(Some(model_id), Some(revision)) => (model_id, revision),
|
||||||
|
(Some(model_id), None) => (model_id, "main".to_string()),
|
||||||
|
(None, Some(revision)) => (default_model, revision),
|
||||||
|
(None, None) => (default_model, default_revision),
|
||||||
|
};
|
||||||
|
|
||||||
|
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
|
||||||
|
let (config_filename, tokenizer_filename, weights_filename) = {
|
||||||
|
let api = Api::new()?;
|
||||||
|
let api = api.repo(repo);
|
||||||
|
let config = api.get("config.json")?;
|
||||||
|
let tokenizer = api.get("tokenizer.json")?;
|
||||||
|
let weights = if self.use_pth {
|
||||||
|
api.get("pytorch_model.bin")?
|
||||||
|
} else {
|
||||||
|
api.get("model.safetensors")?
|
||||||
|
};
|
||||||
|
(config, tokenizer, weights)
|
||||||
|
};
|
||||||
|
let config = std::fs::read_to_string(config_filename)?;
|
||||||
|
let config: Config = serde_json::from_str(&config)?;
|
||||||
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
|
||||||
|
let vb = if self.use_pth {
|
||||||
|
VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
|
||||||
|
} else {
|
||||||
|
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
|
||||||
|
};
|
||||||
|
let model = DistilBertModel::load(vb, &config)?;
|
||||||
|
Ok((model, tokenizer))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_mask(size: usize, device: &Device) -> Tensor {
|
||||||
|
let mask: Vec<_> = (0..size)
|
||||||
|
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
|
||||||
|
.collect();
|
||||||
|
Tensor::from_slice(&mask, (size, size), device).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> Result<()> {
|
||||||
|
use tracing_chrome::ChromeLayerBuilder;
|
||||||
|
use tracing_subscriber::prelude::*;
|
||||||
|
|
||||||
|
let args = Args::parse();
|
||||||
|
let _guard = if args.tracing {
|
||||||
|
println!("tracing...");
|
||||||
|
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||||
|
tracing_subscriber::registry().with(chrome_layer).init();
|
||||||
|
Some(guard)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
|
||||||
|
let device = &model.device;
|
||||||
|
|
||||||
|
let tokenizer = tokenizer
|
||||||
|
.with_padding(None)
|
||||||
|
.with_truncation(None)
|
||||||
|
.map_err(E::msg)?;
|
||||||
|
let tokens = tokenizer
|
||||||
|
.encode(args.prompt, true)
|
||||||
|
.map_err(E::msg)?
|
||||||
|
.get_ids()
|
||||||
|
.to_vec();
|
||||||
|
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
|
||||||
|
let mask = get_mask(tokens.len(), device);
|
||||||
|
|
||||||
|
println!("token_ids: {:?}", token_ids.to_vec2::<u32>());
|
||||||
|
println!("mask: {:?}", mask.to_vec2::<u8>());
|
||||||
|
|
||||||
|
let ys = model.forward(&token_ids, &mask)?;
|
||||||
|
println!("{ys}");
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
|
||||||
|
Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
|
||||||
|
}
|
@ -329,18 +329,14 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
|||||||
.get_ids()
|
.get_ids()
|
||||||
.to_vec();
|
.to_vec();
|
||||||
|
|
||||||
println!("{tokens:?}");
|
|
||||||
|
|
||||||
let start_gen = std::time::Instant::now();
|
let start_gen = std::time::Instant::now();
|
||||||
for index in 0..1 {
|
for index in 0.. {
|
||||||
if tokens.len() >= config.seq_len {
|
if tokens.len() >= config.seq_len {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||||
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
|
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
|
||||||
// println!("Input {}", input);
|
|
||||||
// println!("Input {}", input.to_device(&candle::Device::Cpu)?);
|
|
||||||
let logits = model.forward(&input, index_pos)?;
|
let logits = model.forward(&input, index_pos)?;
|
||||||
let logits = logits.i((0, logits.dim(1)? - 1))?;
|
let logits = logits.i((0, logits.dim(1)? - 1))?;
|
||||||
let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {
|
let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {
|
||||||
|
@ -53,6 +53,8 @@ enum Which {
|
|||||||
Zephyr7bAlpha,
|
Zephyr7bAlpha,
|
||||||
#[value(name = "7b-zephyr-b")]
|
#[value(name = "7b-zephyr-b")]
|
||||||
Zephyr7bBeta,
|
Zephyr7bBeta,
|
||||||
|
#[value(name = "7b-open-chat-3.5")]
|
||||||
|
OpenChat35,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Which {
|
impl Which {
|
||||||
@ -67,8 +69,10 @@ impl Which {
|
|||||||
| Self::L7bCode
|
| Self::L7bCode
|
||||||
| Self::L13bCode
|
| Self::L13bCode
|
||||||
| Self::L34bCode => false,
|
| Self::L34bCode => false,
|
||||||
// Zephyr is a fine tuned version of mistral and should be treated in the same way.
|
// Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
|
||||||
Self::Zephyr7bAlpha
|
// same way.
|
||||||
|
Self::OpenChat35
|
||||||
|
| Self::Zephyr7bAlpha
|
||||||
| Self::Zephyr7bBeta
|
| Self::Zephyr7bBeta
|
||||||
| Self::Mistral7b
|
| Self::Mistral7b
|
||||||
| Self::Mistral7bInstruct => true,
|
| Self::Mistral7bInstruct => true,
|
||||||
@ -87,10 +91,30 @@ impl Which {
|
|||||||
| Self::L13bCode
|
| Self::L13bCode
|
||||||
| Self::L34bCode
|
| Self::L34bCode
|
||||||
| Self::Mistral7b
|
| Self::Mistral7b
|
||||||
| Self::Mistral7bInstruct => false,
|
| Self::Mistral7bInstruct
|
||||||
|
| Self::OpenChat35 => false,
|
||||||
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
|
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn is_open_chat(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
Which::L7b
|
||||||
|
| Which::L13b
|
||||||
|
| Which::L70b
|
||||||
|
| Which::L7bChat
|
||||||
|
| Which::L13bChat
|
||||||
|
| Which::L70bChat
|
||||||
|
| Which::L7bCode
|
||||||
|
| Which::L13bCode
|
||||||
|
| Which::L34bCode
|
||||||
|
| Which::Mistral7b
|
||||||
|
| Which::Mistral7bInstruct
|
||||||
|
| Which::Zephyr7bAlpha
|
||||||
|
| Which::Zephyr7bBeta => false,
|
||||||
|
Which::OpenChat35 => true,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
@ -157,7 +181,9 @@ impl Args {
|
|||||||
Some(config) => std::path::PathBuf::from(config),
|
Some(config) => std::path::PathBuf::from(config),
|
||||||
None => {
|
None => {
|
||||||
let api = hf_hub::api::sync::Api::new()?;
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
let repo = if self.which.is_mistral() {
|
let repo = if self.which.is_open_chat() {
|
||||||
|
"openchat/openchat_3.5"
|
||||||
|
} else if self.which.is_mistral() {
|
||||||
"mistralai/Mistral-7B-v0.1"
|
"mistralai/Mistral-7B-v0.1"
|
||||||
} else {
|
} else {
|
||||||
"hf-internal-testing/llama-tokenizer"
|
"hf-internal-testing/llama-tokenizer"
|
||||||
@ -207,6 +233,7 @@ impl Args {
|
|||||||
Which::Zephyr7bBeta => {
|
Which::Zephyr7bBeta => {
|
||||||
("TheBloke/zephyr-7B-beta-GGUF", "zephyr-7b-beta.Q4_K_M.gguf")
|
("TheBloke/zephyr-7B-beta-GGUF", "zephyr-7b-beta.Q4_K_M.gguf")
|
||||||
}
|
}
|
||||||
|
Which::OpenChat35 => ("TheBloke/openchat_3.5-GGUF", "openchat_3.5.Q4_K_M.gguf"),
|
||||||
};
|
};
|
||||||
let api = hf_hub::api::sync::Api::new()?;
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
let api = api.model(repo.to_string());
|
let api = api.model(repo.to_string());
|
||||||
@ -308,7 +335,8 @@ fn main() -> anyhow::Result<()> {
|
|||||||
| Which::Zephyr7bAlpha
|
| Which::Zephyr7bAlpha
|
||||||
| Which::Zephyr7bBeta
|
| Which::Zephyr7bBeta
|
||||||
| Which::L70b
|
| Which::L70b
|
||||||
| Which::L70bChat => 8,
|
| Which::L70bChat
|
||||||
|
| Which::OpenChat35 => 8,
|
||||||
};
|
};
|
||||||
ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
|
ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
|
||||||
}
|
}
|
||||||
@ -325,10 +353,11 @@ fn main() -> anyhow::Result<()> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let mut pre_prompt_tokens = vec![];
|
let mut pre_prompt_tokens = vec![];
|
||||||
loop {
|
for prompt_index in 0.. {
|
||||||
let prompt_str = match &prompt {
|
let prompt_str = match &prompt {
|
||||||
Prompt::One(prompt) => prompt.clone(),
|
Prompt::One(prompt) => prompt.clone(),
|
||||||
Prompt::Interactive | Prompt::Chat => {
|
Prompt::Interactive | Prompt::Chat => {
|
||||||
|
let is_interactive = matches!(prompt, Prompt::Interactive);
|
||||||
print!("> ");
|
print!("> ");
|
||||||
std::io::stdout().flush()?;
|
std::io::stdout().flush()?;
|
||||||
let mut prompt = String::new();
|
let mut prompt = String::new();
|
||||||
@ -339,8 +368,14 @@ fn main() -> anyhow::Result<()> {
|
|||||||
prompt.pop();
|
prompt.pop();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if args.which.is_zephyr() {
|
if args.which.is_open_chat() {
|
||||||
format!("<|system|>\n</s>\n<|user|>\n{prompt}</s>\n<|assistant|>")
|
format!("User: {prompt}<|end_of_turn|>Assistant: ")
|
||||||
|
} else if args.which.is_zephyr() {
|
||||||
|
if prompt_index == 0 || is_interactive {
|
||||||
|
format!("<|system|>\n</s>\n<|user|>\n{prompt}</s>\n<|assistant|>",)
|
||||||
|
} else {
|
||||||
|
format!("<|user|>\n{prompt}</s>\n<|assistant|>")
|
||||||
|
}
|
||||||
} else if args.which.is_mistral() {
|
} else if args.which.is_mistral() {
|
||||||
format!("[INST] {prompt} [/INST]")
|
format!("[INST] {prompt} [/INST]")
|
||||||
} else {
|
} else {
|
||||||
@ -385,8 +420,12 @@ fn main() -> anyhow::Result<()> {
|
|||||||
std::io::stdout().flush()?;
|
std::io::stdout().flush()?;
|
||||||
}
|
}
|
||||||
|
|
||||||
let eos_token = *tos.tokenizer().get_vocab(true).get("</s>").unwrap();
|
let eos_token = if args.which.is_open_chat() {
|
||||||
|
"<|end_of_turn|>"
|
||||||
|
} else {
|
||||||
|
"</s>"
|
||||||
|
};
|
||||||
|
let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap();
|
||||||
let start_post_prompt = std::time::Instant::now();
|
let start_post_prompt = std::time::Instant::now();
|
||||||
let mut sampled = 0;
|
let mut sampled = 0;
|
||||||
for index in 0..to_sample {
|
for index in 0..to_sample {
|
||||||
|
@ -416,7 +416,7 @@ fn run(args: Args) -> Result<()> {
|
|||||||
|
|
||||||
println!("Building the autoencoder.");
|
println!("Building the autoencoder.");
|
||||||
let vae_weights = ModelFile::Vae.get(vae_weights, sd_version, use_f16)?;
|
let vae_weights = ModelFile::Vae.get(vae_weights, sd_version, use_f16)?;
|
||||||
let vae = sd_config.build_vae(&vae_weights, &device, dtype)?;
|
let vae = sd_config.build_vae(vae_weights, &device, dtype)?;
|
||||||
let init_latent_dist = match &img2img {
|
let init_latent_dist = match &img2img {
|
||||||
None => None,
|
None => None,
|
||||||
Some(image) => {
|
Some(image) => {
|
||||||
@ -426,7 +426,7 @@ fn run(args: Args) -> Result<()> {
|
|||||||
};
|
};
|
||||||
println!("Building the unet.");
|
println!("Building the unet.");
|
||||||
let unet_weights = ModelFile::Unet.get(unet_weights, sd_version, use_f16)?;
|
let unet_weights = ModelFile::Unet.get(unet_weights, sd_version, use_f16)?;
|
||||||
let unet = sd_config.build_unet(&unet_weights, &device, 4, use_flash_attn, dtype)?;
|
let unet = sd_config.build_unet(unet_weights, &device, 4, use_flash_attn, dtype)?;
|
||||||
|
|
||||||
let t_start = if img2img.is_some() {
|
let t_start = if img2img.is_some() {
|
||||||
n_steps - (n_steps as f64 * img2img_strength) as usize
|
n_steps - (n_steps as f64 * img2img_strength) as usize
|
||||||
|
@ -9,6 +9,8 @@ $ cargo run --example t5 --release -- --model-id "t5-small" --prompt "translate
|
|||||||
9 tokens generated (2.42 token/s)
|
9 tokens generated (2.42 token/s)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Variants such as [flan-t5](https://huggingface.co/google/flan-t5-small), [flan-ul2](https://huggingface.co/google/flan-ul2) (with `--revision "refs/pr/25"`), and [Co-EdIT](https://huggingface.co/grammarly/coedit-large) are also supported.
|
||||||
|
|
||||||
## Translation with [MADLAD-400](https://arxiv.org/abs/2309.04662)
|
## Translation with [MADLAD-400](https://arxiv.org/abs/2309.04662)
|
||||||
|
|
||||||
MADLAD-400 is a series of multilingual machine translation T5 models trained on 250 billion tokens covering over 450 languages using publicly available data. These models are competitive with significantly larger models.
|
MADLAD-400 is a series of multilingual machine translation T5 models trained on 250 billion tokens covering over 450 languages using publicly available data. These models are competitive with significantly larger models.
|
||||||
@ -22,7 +24,7 @@ cargo run --example t5 --release -- \
|
|||||||
Wie geht es dir, mein Freund?
|
Wie geht es dir, mein Freund?
|
||||||
```
|
```
|
||||||
|
|
||||||
## Sentence embedding example:
|
## Sentence embedding example
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "A beautiful candle."
|
$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "A beautiful candle."
|
||||||
|
@ -104,6 +104,17 @@ impl T5ModelBuilder {
|
|||||||
api.get("model-00004-of-00005.safetensors")?,
|
api.get("model-00004-of-00005.safetensors")?,
|
||||||
api.get("model-00005-of-00005.safetensors")?,
|
api.get("model-00005-of-00005.safetensors")?,
|
||||||
]
|
]
|
||||||
|
} else if model_id == "google/flan-ul2" {
|
||||||
|
vec![
|
||||||
|
api.get("model-00001-of-00008.safetensors")?,
|
||||||
|
api.get("model-00002-of-00008.safetensors")?,
|
||||||
|
api.get("model-00003-of-00008.safetensors")?,
|
||||||
|
api.get("model-00004-of-00008.safetensors")?,
|
||||||
|
api.get("model-00005-of-00008.safetensors")?,
|
||||||
|
api.get("model-00006-of-00008.safetensors")?,
|
||||||
|
api.get("model-00007-of-00008.safetensors")?,
|
||||||
|
api.get("model-00008-of-00008.safetensors")?,
|
||||||
|
]
|
||||||
} else {
|
} else {
|
||||||
vec![api.get("model.safetensors")?]
|
vec![api.get("model.safetensors")?]
|
||||||
};
|
};
|
||||||
|
BIN
candle-examples/examples/trocr/assets/trocr.png
Normal file
BIN
candle-examples/examples/trocr/assets/trocr.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 36 KiB |
154
candle-examples/examples/trocr/image_processor.rs
Normal file
154
candle-examples/examples/trocr/image_processor.rs
Normal file
@ -0,0 +1,154 @@
|
|||||||
|
use image::{DynamicImage, ImageBuffer};
|
||||||
|
use serde::Deserialize;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use candle::{DType, Device, Result, Tensor};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||||
|
pub struct ProcessorConfig {
|
||||||
|
do_resize: bool,
|
||||||
|
height: u32,
|
||||||
|
width: u32,
|
||||||
|
do_rescale: bool,
|
||||||
|
do_normalize: bool,
|
||||||
|
image_mean: Vec<f32>,
|
||||||
|
image_std: Vec<f32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ProcessorConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
do_resize: true,
|
||||||
|
height: 384,
|
||||||
|
width: 384,
|
||||||
|
do_rescale: true,
|
||||||
|
do_normalize: true,
|
||||||
|
image_mean: vec![0.5, 0.5, 0.5],
|
||||||
|
image_std: vec![0.5, 0.5, 0.5],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ViTImageProcessor {
|
||||||
|
do_resize: bool,
|
||||||
|
height: u32,
|
||||||
|
width: u32,
|
||||||
|
do_normalize: bool,
|
||||||
|
image_mean: Vec<f32>,
|
||||||
|
image_std: Vec<f32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ViTImageProcessor {
|
||||||
|
pub fn new(config: &ProcessorConfig) -> Self {
|
||||||
|
Self {
|
||||||
|
do_resize: config.do_resize,
|
||||||
|
height: config.height,
|
||||||
|
width: config.width,
|
||||||
|
do_normalize: config.do_normalize,
|
||||||
|
image_mean: config.image_mean.clone(),
|
||||||
|
image_std: config.image_std.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn preprocess(&self, images: Vec<&str>) -> Result<Tensor> {
|
||||||
|
let height = self.height as usize;
|
||||||
|
let width = self.width as usize;
|
||||||
|
let channels = 3;
|
||||||
|
|
||||||
|
let images = self.load_images(images)?;
|
||||||
|
|
||||||
|
let resized_images: Vec<DynamicImage> = if self.do_resize {
|
||||||
|
images
|
||||||
|
.iter()
|
||||||
|
.map(|image| self.resize(image.clone(), None).unwrap())
|
||||||
|
.collect()
|
||||||
|
} else {
|
||||||
|
images
|
||||||
|
};
|
||||||
|
|
||||||
|
let normalized_images: Vec<Tensor> = if self.do_normalize {
|
||||||
|
resized_images
|
||||||
|
.iter()
|
||||||
|
.map(|image| self.normalize(image.clone(), None, None).unwrap())
|
||||||
|
.collect()
|
||||||
|
} else {
|
||||||
|
let resized_images: Vec<ImageBuffer<image::Rgb<u8>, Vec<u8>>> =
|
||||||
|
resized_images.iter().map(|image| image.to_rgb8()).collect();
|
||||||
|
let data = resized_images
|
||||||
|
.into_iter()
|
||||||
|
.map(|image| image.into_raw())
|
||||||
|
.collect::<Vec<Vec<u8>>>();
|
||||||
|
|
||||||
|
data.iter()
|
||||||
|
.map(|image| {
|
||||||
|
Tensor::from_vec(image.clone(), (height, width, channels), &Device::Cpu)
|
||||||
|
.unwrap()
|
||||||
|
.permute((2, 0, 1))
|
||||||
|
.unwrap()
|
||||||
|
})
|
||||||
|
.collect::<Vec<Tensor>>()
|
||||||
|
};
|
||||||
|
|
||||||
|
Tensor::stack(&normalized_images, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn resize(
|
||||||
|
&self,
|
||||||
|
image: image::DynamicImage,
|
||||||
|
size: Option<HashMap<String, u32>>,
|
||||||
|
) -> Result<image::DynamicImage> {
|
||||||
|
let (height, width) = match &size {
|
||||||
|
Some(size) => (size.get("height").unwrap(), size.get("width").unwrap()),
|
||||||
|
None => (&self.height, &self.width),
|
||||||
|
};
|
||||||
|
|
||||||
|
let resized_image =
|
||||||
|
image.resize_exact(*width, *height, image::imageops::FilterType::Triangle);
|
||||||
|
|
||||||
|
Ok(resized_image)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn normalize(
|
||||||
|
&self,
|
||||||
|
image: image::DynamicImage,
|
||||||
|
mean: Option<Vec<f32>>,
|
||||||
|
std: Option<Vec<f32>>,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let mean = match mean {
|
||||||
|
Some(mean) => mean,
|
||||||
|
None => self.image_mean.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let std = match std {
|
||||||
|
Some(std) => std,
|
||||||
|
None => self.image_std.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mean = Tensor::from_vec(mean, (3, 1, 1), &Device::Cpu)?;
|
||||||
|
let std = Tensor::from_vec(std, (3, 1, 1), &Device::Cpu)?;
|
||||||
|
|
||||||
|
let image = image.to_rgb8();
|
||||||
|
let data = image.into_raw();
|
||||||
|
|
||||||
|
let height = self.height as usize;
|
||||||
|
let width = self.width as usize;
|
||||||
|
let channels = 3;
|
||||||
|
|
||||||
|
let data =
|
||||||
|
Tensor::from_vec(data, &[height, width, channels], &Device::Cpu)?.permute((2, 0, 1))?;
|
||||||
|
|
||||||
|
(data.to_dtype(DType::F32)? / 255.)?
|
||||||
|
.broadcast_sub(&mean)?
|
||||||
|
.broadcast_div(&std)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_images(&self, image_path: Vec<&str>) -> Result<Vec<image::DynamicImage>> {
|
||||||
|
let mut images: Vec<image::DynamicImage> = Vec::new();
|
||||||
|
for path in image_path {
|
||||||
|
let img = image::io::Reader::open(path)?.decode().unwrap();
|
||||||
|
images.push(img);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(images)
|
||||||
|
}
|
||||||
|
}
|
132
candle-examples/examples/trocr/main.rs
Normal file
132
candle-examples/examples/trocr/main.rs
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
|
||||||
|
use anyhow::Error as E;
|
||||||
|
use clap::{Parser, ValueEnum};
|
||||||
|
|
||||||
|
use candle::{DType, Tensor};
|
||||||
|
use candle_examples::token_output_stream::TokenOutputStream;
|
||||||
|
use candle_nn::VarBuilder;
|
||||||
|
use candle_transformers::models::trocr;
|
||||||
|
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
mod image_processor;
|
||||||
|
|
||||||
|
// Variant of the TrOCR checkpoint to run; parsed from the command line via
// clap's `ValueEnum` derive.
#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
    // `main` maps this to the "microsoft/trocr-base-handwritten" weights.
    Base,
    // `main` maps this to the "microsoft/trocr-large-handwritten" weights.
    Large,
}
|
||||||
|
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
struct Args {
|
||||||
|
#[arg(long)]
|
||||||
|
model: Option<String>,
|
||||||
|
|
||||||
|
/// Choose the variant of the model to run.
|
||||||
|
#[arg(long, default_value = "base")]
|
||||||
|
which: Which,
|
||||||
|
|
||||||
|
/// Run on CPU rather than on GPU.
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
|
||||||
|
/// Text to be translated
|
||||||
|
#[arg(long)]
|
||||||
|
image: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn main() -> anyhow::Result<()> {
|
||||||
|
use hf_hub::api::sync::Api;
|
||||||
|
let args = Args::parse();
|
||||||
|
|
||||||
|
let tokenizer_dec = {
|
||||||
|
let tokenizer = Api::new()?
|
||||||
|
.model(String::from("ToluClassics/candle-trocr-tokenizer"))
|
||||||
|
.get("tokenizer.json")?;
|
||||||
|
|
||||||
|
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);
|
||||||
|
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
|
||||||
|
let vb = {
|
||||||
|
let model = match args.model {
|
||||||
|
Some(model) => std::path::PathBuf::from(model),
|
||||||
|
None => match args.which {
|
||||||
|
Which::Base => Api::new()?
|
||||||
|
.repo(hf_hub::Repo::with_revision(
|
||||||
|
"microsoft/trocr-base-handwritten".to_string(),
|
||||||
|
hf_hub::RepoType::Model,
|
||||||
|
"refs/pr/3".to_string(),
|
||||||
|
))
|
||||||
|
.get("model.safetensors")?,
|
||||||
|
Which::Large => Api::new()?
|
||||||
|
.repo(hf_hub::Repo::with_revision(
|
||||||
|
"microsoft/trocr-large-handwritten".to_string(),
|
||||||
|
hf_hub::RepoType::Model,
|
||||||
|
"refs/pr/6".to_string(),
|
||||||
|
))
|
||||||
|
.get("model.safetensors")?,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
println!("model: {:?}", model);
|
||||||
|
unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? }
|
||||||
|
};
|
||||||
|
|
||||||
|
let encoder_config = match args.which {
|
||||||
|
Which::Base => candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten(),
|
||||||
|
Which::Large => {
|
||||||
|
candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let decoder_config = trocr::TrOCRConfig::default();
|
||||||
|
let mut model = trocr::TrOCRModel::new(&encoder_config, &decoder_config, vb)?;
|
||||||
|
|
||||||
|
let config = image_processor::ProcessorConfig::default();
|
||||||
|
let processor = image_processor::ViTImageProcessor::new(&config);
|
||||||
|
|
||||||
|
let image = vec![args.image.as_str()];
|
||||||
|
let image = processor.preprocess(image)?;
|
||||||
|
|
||||||
|
let encoder_xs = model.encoder().forward(&image)?;
|
||||||
|
|
||||||
|
let mut logits_processor =
|
||||||
|
candle_transformers::generation::LogitsProcessor::new(1337, None, None);
|
||||||
|
|
||||||
|
let mut token_ids: Vec<u32> = vec![decoder_config.decoder_start_token_id];
|
||||||
|
for index in 0..1000 {
|
||||||
|
let context_size = if index >= 1 { 1 } else { token_ids.len() };
|
||||||
|
let start_pos = token_ids.len().saturating_sub(context_size);
|
||||||
|
let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
|
||||||
|
|
||||||
|
let logits = model.decode(&input_ids, &encoder_xs, start_pos)?;
|
||||||
|
|
||||||
|
let logits = logits.squeeze(0)?;
|
||||||
|
let logits = logits.get(logits.dim(0)? - 1)?;
|
||||||
|
let token = logits_processor.sample(&logits)?;
|
||||||
|
token_ids.push(token);
|
||||||
|
|
||||||
|
if let Some(t) = tokenizer_dec.next_token(token)? {
|
||||||
|
use std::io::Write;
|
||||||
|
print!("{t}");
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
}
|
||||||
|
if token == decoder_config.eos_token_id {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
|
||||||
|
print!("{rest}");
|
||||||
|
}
|
||||||
|
println!();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
16
candle-examples/examples/trocr/readme.md
Normal file
16
candle-examples/examples/trocr/readme.md
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
# candle-trocr
|
||||||
|
|
||||||
|
`TrOCR` is a transformer-based OCR model. In this example it is used to
|
||||||
|
transcribe image text. See the associated [model
|
||||||
|
card](https://huggingface.co/microsoft/trocr-base-handwritten) for details on
|
||||||
|
the model itself.
|
||||||
|
|
||||||
|
## Running an example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo run --example trocr --release -- --which base --cpu --image candle-examples/examples/trocr/assets/trocr.png
|
||||||
|
```
|
||||||
|
|
||||||
|
```
|
||||||
|
<s> industry , Mr. Brown commented icily . " Let us have a</s>
|
||||||
|
```
|
@ -128,7 +128,13 @@ impl Decoder {
|
|||||||
let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
|
let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
|
||||||
let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
|
let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
|
||||||
let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
|
let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
|
||||||
let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?;
|
let no_speech_token = m::NO_SPEECH_TOKENS
|
||||||
|
.iter()
|
||||||
|
.find_map(|token| token_id(&tokenizer, token).ok());
|
||||||
|
let no_speech_token = match no_speech_token {
|
||||||
|
None => anyhow::bail!("unable to find any non-speech token"),
|
||||||
|
Some(n) => n,
|
||||||
|
};
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
model,
|
model,
|
||||||
rng: rand::rngs::StdRng::seed_from_u64(seed),
|
rng: rand::rngs::StdRng::seed_from_u64(seed),
|
||||||
@ -512,11 +518,7 @@ fn main() -> Result<()> {
|
|||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
let config = repo.get("config.json")?;
|
let config = repo.get("config.json")?;
|
||||||
let tokenizer = if args.model == WhichModel::LargeV3 {
|
let tokenizer = repo.get("tokenizer.json")?;
|
||||||
panic!("openai/whisper-large-v3 does not provide a compatible tokenizer.json config at the moment")
|
|
||||||
} else {
|
|
||||||
repo.get("tokenizer.json")?
|
|
||||||
};
|
|
||||||
let model = repo.get("model.safetensors")?;
|
let model = repo.get("model.safetensors")?;
|
||||||
(config, tokenizer, model)
|
(config, tokenizer, model)
|
||||||
};
|
};
|
||||||
|
268
candle-examples/examples/yi/main.rs
Normal file
268
candle-examples/examples/yi/main.rs
Normal file
@ -0,0 +1,268 @@
|
|||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
|
||||||
|
use anyhow::{Error as E, Result};
|
||||||
|
use clap::{Parser, ValueEnum};
|
||||||
|
|
||||||
|
use candle_transformers::models::yi::{Config, Model};
|
||||||
|
|
||||||
|
use candle::{DType, Device, Tensor};
|
||||||
|
use candle_examples::token_output_stream::TokenOutputStream;
|
||||||
|
use candle_nn::VarBuilder;
|
||||||
|
use candle_transformers::generation::LogitsProcessor;
|
||||||
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
// Size of the Yi model to run; parsed from the command line via clap's
// `ValueEnum` derive using the names given in the `value` attributes.
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
    // Selected with `--which 6b`; `main` loads two weight shards for it.
    #[value(name = "6b")]
    L6b,
    // Selected with `--which 34b`; `main` loads seven weight shards for it.
    #[value(name = "34b")]
    L34b,
}
|
||||||
|
|
||||||
|
/// State needed to generate text from a prompt with a Yi model.
struct TextGeneration {
    /// The model whose `forward` produces the next-token logits.
    model: Model,
    /// Device on which the input token tensors are created.
    device: Device,
    /// Tokenizer wrapped in a stream so that text can be printed token by token.
    tokenizer: TokenOutputStream,
    /// Samples the next token from the logits.
    logits_processor: LogitsProcessor,
    /// Penalty applied to recently generated tokens; `1.` disables it.
    repeat_penalty: f32,
    /// Number of trailing tokens the repeat penalty looks at.
    repeat_last_n: usize,
}
|
||||||
|
|
||||||
|
impl TextGeneration {
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
fn new(
|
||||||
|
model: Model,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
seed: u64,
|
||||||
|
temp: Option<f64>,
|
||||||
|
top_p: Option<f64>,
|
||||||
|
repeat_penalty: f32,
|
||||||
|
repeat_last_n: usize,
|
||||||
|
device: &Device,
|
||||||
|
) -> Self {
|
||||||
|
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||||
|
Self {
|
||||||
|
model,
|
||||||
|
tokenizer: TokenOutputStream::new(tokenizer),
|
||||||
|
logits_processor,
|
||||||
|
repeat_penalty,
|
||||||
|
repeat_last_n,
|
||||||
|
device: device.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||||
|
use std::io::Write;
|
||||||
|
self.tokenizer.clear();
|
||||||
|
let mut tokens = self
|
||||||
|
.tokenizer
|
||||||
|
.tokenizer()
|
||||||
|
.encode(prompt, true)
|
||||||
|
.map_err(E::msg)?
|
||||||
|
.get_ids()
|
||||||
|
.to_vec();
|
||||||
|
for &t in tokens.iter() {
|
||||||
|
if let Some(t) = self.tokenizer.next_token(t)? {
|
||||||
|
print!("{t}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
|
||||||
|
let mut generated_tokens = 0usize;
|
||||||
|
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
|
||||||
|
Some(token) => token,
|
||||||
|
None => anyhow::bail!("cannot find the <|endoftext|> token"),
|
||||||
|
};
|
||||||
|
let start_gen = std::time::Instant::now();
|
||||||
|
for index in 0..sample_len {
|
||||||
|
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||||
|
let start_pos = tokens.len().saturating_sub(context_size);
|
||||||
|
let ctxt = &tokens[start_pos..];
|
||||||
|
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||||
|
let logits = self.model.forward(&input, start_pos)?;
|
||||||
|
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||||
|
let logits = if self.repeat_penalty == 1. {
|
||||||
|
logits
|
||||||
|
} else {
|
||||||
|
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||||
|
candle_transformers::utils::apply_repeat_penalty(
|
||||||
|
&logits,
|
||||||
|
self.repeat_penalty,
|
||||||
|
&tokens[start_at..],
|
||||||
|
)?
|
||||||
|
};
|
||||||
|
|
||||||
|
let next_token = self.logits_processor.sample(&logits)?;
|
||||||
|
tokens.push(next_token);
|
||||||
|
generated_tokens += 1;
|
||||||
|
if next_token == eos_token {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||||
|
print!("{t}");
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let dt = start_gen.elapsed();
|
||||||
|
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
||||||
|
print!("{rest}");
|
||||||
|
}
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
println!(
|
||||||
|
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||||
|
generated_tokens as f64 / dt.as_secs_f64(),
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(author, version, about, long_about = None)]
|
||||||
|
struct Args {
|
||||||
|
/// Run on CPU rather than on GPU.
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
|
||||||
|
/// Enable tracing (generates a trace-timestamp.json file).
|
||||||
|
#[arg(long)]
|
||||||
|
tracing: bool,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
prompt: String,
|
||||||
|
|
||||||
|
/// The temperature used to generate samples.
|
||||||
|
#[arg(long)]
|
||||||
|
temperature: Option<f64>,
|
||||||
|
|
||||||
|
/// Nucleus sampling probability cutoff.
|
||||||
|
#[arg(long)]
|
||||||
|
top_p: Option<f64>,
|
||||||
|
|
||||||
|
/// The seed to use when generating random samples.
|
||||||
|
#[arg(long, default_value_t = 299792458)]
|
||||||
|
seed: u64,
|
||||||
|
|
||||||
|
/// The length of the sample to generate (in tokens).
|
||||||
|
#[arg(long, short = 'n', default_value_t = 100)]
|
||||||
|
sample_len: usize,
|
||||||
|
|
||||||
|
#[arg(long, default_value = "01-ai/Yi-6B")]
|
||||||
|
model_id: String,
|
||||||
|
|
||||||
|
#[arg(long, default_value = "main")]
|
||||||
|
revision: String,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
tokenizer_file: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
weight_files: Option<String>,
|
||||||
|
|
||||||
|
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||||
|
#[arg(long, default_value_t = 1.1)]
|
||||||
|
repeat_penalty: f32,
|
||||||
|
|
||||||
|
/// The context size to consider for the repeat penalty.
|
||||||
|
#[arg(long, default_value_t = 64)]
|
||||||
|
repeat_last_n: usize,
|
||||||
|
|
||||||
|
/// The model size to use.
|
||||||
|
#[arg(long, default_value = "6b")]
|
||||||
|
which: Which,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> Result<()> {
|
||||||
|
use tracing_chrome::ChromeLayerBuilder;
|
||||||
|
use tracing_subscriber::prelude::*;
|
||||||
|
|
||||||
|
let args = Args::parse();
|
||||||
|
let _guard = if args.tracing {
|
||||||
|
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||||
|
tracing_subscriber::registry().with(chrome_layer).init();
|
||||||
|
Some(guard)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
println!(
|
||||||
|
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||||
|
candle::utils::with_avx(),
|
||||||
|
candle::utils::with_neon(),
|
||||||
|
candle::utils::with_simd128(),
|
||||||
|
candle::utils::with_f16c()
|
||||||
|
);
|
||||||
|
println!(
|
||||||
|
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||||
|
args.temperature.unwrap_or(0.),
|
||||||
|
args.repeat_penalty,
|
||||||
|
args.repeat_last_n
|
||||||
|
);
|
||||||
|
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let api = Api::new()?;
|
||||||
|
let repo = api.repo(Repo::with_revision(
|
||||||
|
args.model_id,
|
||||||
|
RepoType::Model,
|
||||||
|
args.revision,
|
||||||
|
));
|
||||||
|
let tokenizer_filename = match args.tokenizer_file {
|
||||||
|
Some(file) => std::path::PathBuf::from(file),
|
||||||
|
None => repo.get("tokenizer.json")?,
|
||||||
|
};
|
||||||
|
let filenames = match args.weight_files {
|
||||||
|
Some(files) => files
|
||||||
|
.split(',')
|
||||||
|
.map(std::path::PathBuf::from)
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
None => match args.which {
|
||||||
|
Which::L6b => vec![
|
||||||
|
repo.get("model-00001-of-00002.safetensors")?,
|
||||||
|
repo.get("model-00002-of-00002.safetensors")?,
|
||||||
|
],
|
||||||
|
Which::L34b => vec![
|
||||||
|
repo.get("model-00001-of-00007.safetensors")?,
|
||||||
|
repo.get("model-00002-of-00007.safetensors")?,
|
||||||
|
repo.get("model-00003-of-00007.safetensors")?,
|
||||||
|
repo.get("model-00004-of-00007.safetensors")?,
|
||||||
|
repo.get("model-00005-of-00007.safetensors")?,
|
||||||
|
repo.get("model-00006-of-00007.safetensors")?,
|
||||||
|
repo.get("model-00007-of-00007.safetensors")?,
|
||||||
|
],
|
||||||
|
},
|
||||||
|
};
|
||||||
|
println!("retrieved the files in {:?}", start.elapsed());
|
||||||
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let config = match args.which {
|
||||||
|
Which::L6b => Config::config_6b(),
|
||||||
|
Which::L34b => Config::config_34b(),
|
||||||
|
};
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
let dtype = if device.is_cuda() {
|
||||||
|
DType::BF16
|
||||||
|
} else {
|
||||||
|
DType::F32
|
||||||
|
};
|
||||||
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||||
|
let model = Model::new(&config, vb)?;
|
||||||
|
|
||||||
|
println!("loaded the model in {:?}", start.elapsed());
|
||||||
|
|
||||||
|
let mut pipeline = TextGeneration::new(
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
args.seed,
|
||||||
|
args.temperature,
|
||||||
|
args.top_p,
|
||||||
|
args.repeat_penalty,
|
||||||
|
args.repeat_last_n,
|
||||||
|
&device,
|
||||||
|
);
|
||||||
|
pipeline.run(&args.prompt, args.sample_len)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
@ -43,6 +43,7 @@ pub fn report(
|
|||||||
confidence_threshold: f32,
|
confidence_threshold: f32,
|
||||||
nms_threshold: f32,
|
nms_threshold: f32,
|
||||||
) -> Result<DynamicImage> {
|
) -> Result<DynamicImage> {
|
||||||
|
let pred = pred.to_device(&Device::Cpu)?;
|
||||||
let (npreds, pred_size) = pred.dims2()?;
|
let (npreds, pred_size) = pred.dims2()?;
|
||||||
let nclasses = pred_size - 5;
|
let nclasses = pred_size - 5;
|
||||||
// The bounding boxes grouped by (maximum) class index.
|
// The bounding boxes grouped by (maximum) class index.
|
||||||
|
@ -32,7 +32,7 @@ Image source:
|
|||||||
### Pose Estimation
|
### Pose Estimation
|
||||||
```bash
|
```bash
|
||||||
cargo run --example yolo-v8 --release -- \
|
cargo run --example yolo-v8 --release -- \
|
||||||
candle-examples/examples/yolo-v8/assets/peoples.jpeg --task pose
|
candle-examples/examples/yolo-v8/assets/bike.jpg --task pose
|
||||||
```
|
```
|
||||||
|
|
||||||

|

|
||||||
|
@ -7,7 +7,7 @@ extern crate accelerate_src;
|
|||||||
mod model;
|
mod model;
|
||||||
use model::{Multiples, YoloV8, YoloV8Pose};
|
use model::{Multiples, YoloV8, YoloV8Pose};
|
||||||
|
|
||||||
use candle::{DType, IndexOp, Result, Tensor};
|
use candle::{DType, Device, IndexOp, Result, Tensor};
|
||||||
use candle_nn::{Module, VarBuilder};
|
use candle_nn::{Module, VarBuilder};
|
||||||
use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
|
use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
|
||||||
use clap::{Parser, ValueEnum};
|
use clap::{Parser, ValueEnum};
|
||||||
@ -61,6 +61,7 @@ pub fn report_detect(
|
|||||||
nms_threshold: f32,
|
nms_threshold: f32,
|
||||||
legend_size: u32,
|
legend_size: u32,
|
||||||
) -> Result<DynamicImage> {
|
) -> Result<DynamicImage> {
|
||||||
|
let pred = pred.to_device(&Device::Cpu)?;
|
||||||
let (pred_size, npreds) = pred.dims2()?;
|
let (pred_size, npreds) = pred.dims2()?;
|
||||||
let nclasses = pred_size - 4;
|
let nclasses = pred_size - 4;
|
||||||
// The bounding boxes grouped by (maximum) class index.
|
// The bounding boxes grouped by (maximum) class index.
|
||||||
@ -153,6 +154,7 @@ pub fn report_pose(
|
|||||||
confidence_threshold: f32,
|
confidence_threshold: f32,
|
||||||
nms_threshold: f32,
|
nms_threshold: f32,
|
||||||
) -> Result<DynamicImage> {
|
) -> Result<DynamicImage> {
|
||||||
|
let pred = pred.to_device(&Device::Cpu)?;
|
||||||
let (pred_size, npreds) = pred.dims2()?;
|
let (pred_size, npreds) = pred.dims2()?;
|
||||||
if pred_size != 17 * 3 + 4 + 1 {
|
if pred_size != 17 * 3 + 4 + 1 {
|
||||||
candle::bail!("unexpected pred-size {pred_size}");
|
candle::bail!("unexpected pred-size {pred_size}");
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-flash-attn"
|
name = "candle-flash-attn"
|
||||||
version = "0.3.0"
|
version = "0.3.1"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "Flash attention layer for the candle ML framework."
|
description = "Flash attention layer for the candle ML framework."
|
||||||
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
|
|||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../candle-core", features = ["cuda"], version = "0.3.0", package = "candle-core" }
|
candle = { path = "../candle-core", features = ["cuda"], version = "0.3.1", package = "candle-core" }
|
||||||
half = { version = "2.3.1", features = ["num-traits"] }
|
half = { version = "2.3.1", features = ["num-traits"] }
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
@ -21,4 +21,4 @@ rayon = "1.7.0"
|
|||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
anyhow = { version = "1", features = ["backtrace"] }
|
anyhow = { version = "1", features = ["backtrace"] }
|
||||||
candle-nn = { path = "../candle-nn", version = "0.3.0", features = ["cuda"] }
|
candle-nn = { path = "../candle-nn", version = "0.3.1", features = ["cuda"] }
|
||||||
|
@ -233,8 +233,8 @@ impl FlashAttnVarLen {
|
|||||||
|
|
||||||
let (seqlens_q, seqlens_q_layout) = self.seqlens_q.storage_and_layout();
|
let (seqlens_q, seqlens_q_layout) = self.seqlens_q.storage_and_layout();
|
||||||
let seqlens_q = match &*seqlens_q {
|
let seqlens_q = match &*seqlens_q {
|
||||||
candle::Storage::Cpu(_) => candle::bail!("seqlens_q must be a cuda tensor"),
|
|
||||||
candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?, // Should be i32!
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?, // Should be i32!
|
||||||
|
_ => candle::bail!("seqlens_q must be a cuda tensor"),
|
||||||
};
|
};
|
||||||
let seqlens_q = match seqlens_q_layout.contiguous_offsets() {
|
let seqlens_q = match seqlens_q_layout.contiguous_offsets() {
|
||||||
Some((o1, o2)) => seqlens_q.slice(o1..o2),
|
Some((o1, o2)) => seqlens_q.slice(o1..o2),
|
||||||
@ -243,8 +243,8 @@ impl FlashAttnVarLen {
|
|||||||
|
|
||||||
let (seqlens_k, seqlens_k_layout) = self.seqlens_k.storage_and_layout();
|
let (seqlens_k, seqlens_k_layout) = self.seqlens_k.storage_and_layout();
|
||||||
let seqlens_k = match &*seqlens_k {
|
let seqlens_k = match &*seqlens_k {
|
||||||
candle::Storage::Cpu(_) => candle::bail!("seqlens_k must be a cuda tensor"),
|
|
||||||
candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?, // Should be i32!
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?, // Should be i32!
|
||||||
|
_ => candle::bail!("seqlens_k must be a cuda tensor"),
|
||||||
};
|
};
|
||||||
let seqlens_k = match seqlens_k_layout.contiguous_offsets() {
|
let seqlens_k = match seqlens_k_layout.contiguous_offsets() {
|
||||||
Some((o1, o2)) => seqlens_k.slice(o1..o2),
|
Some((o1, o2)) => seqlens_k.slice(o1..o2),
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-kernels"
|
name = "candle-kernels"
|
||||||
version = "0.3.0"
|
version = "0.3.1"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "CUDA kernels for Candle"
|
description = "CUDA kernels for Candle"
|
||||||
|
@ -1,17 +1,16 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-metal-kernels"
|
name = "candle-metal-kernels"
|
||||||
version = "0.3.0"
|
version = "0.3.1"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "CUDA kernels for Candle"
|
description = "Metal kernels for Candle"
|
||||||
repository = "https://github.com/huggingface/candle"
|
repository = "https://github.com/huggingface/candle"
|
||||||
keywords = ["blas", "tensor", "machine-learning"]
|
keywords = ["blas", "tensor", "machine-learning"]
|
||||||
categories = ["science"]
|
categories = ["science"]
|
||||||
license = "MIT OR Apache-2.0"
|
license = "MIT OR Apache-2.0"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
|
metal = { version = "0.27.1", features = ["mps"], package="candle-metal" }
|
||||||
metal = { path = "../../metal-rs", features = ["mps"] }
|
|
||||||
once_cell = "1.18.0"
|
once_cell = "1.18.0"
|
||||||
thiserror = "1"
|
thiserror = "1"
|
||||||
tracing = "0.1.37"
|
tracing = "0.1.37"
|
||||||
|
@ -33,6 +33,24 @@ kernel void FN_NAME( \
|
|||||||
const TYPENAME a = TYPENAME(add); \
|
const TYPENAME a = TYPENAME(add); \
|
||||||
output[id] = input[id] * m + a; \
|
output[id] = input[id] * m + a; \
|
||||||
} \
|
} \
|
||||||
|
kernel void FN_NAME##_strided( \
|
||||||
|
constant size_t &dim, \
|
||||||
|
constant size_t &num_dims, \
|
||||||
|
constant size_t *dims, \
|
||||||
|
constant size_t *strides, \
|
||||||
|
constant float &mul, \
|
||||||
|
constant float &add, \
|
||||||
|
device const TYPENAME *input, \
|
||||||
|
device TYPENAME *output, \
|
||||||
|
uint id [[ thread_position_in_grid ]] \
|
||||||
|
) { \
|
||||||
|
if (id >= dim) { \
|
||||||
|
return; \
|
||||||
|
} \
|
||||||
|
const TYPENAME m = TYPENAME(mul); \
|
||||||
|
const TYPENAME a = TYPENAME(add); \
|
||||||
|
output[id] = input[get_strided_index(id, num_dims, dims, strides)] * m + a; \
|
||||||
|
} \
|
||||||
|
|
||||||
AFFINE(affine_float, float)
|
AFFINE(affine_float, float)
|
||||||
AFFINE(affine_half, half)
|
AFFINE(affine_half, half)
|
||||||
|
@ -23,12 +23,12 @@ kernel void FN_NAME( \
|
|||||||
constant size_t &dim, \
|
constant size_t &dim, \
|
||||||
device const LEFT_TYPENAME *input, \
|
device const LEFT_TYPENAME *input, \
|
||||||
device RIGHT_TYPENAME *output, \
|
device RIGHT_TYPENAME *output, \
|
||||||
uint thread_position_in_grid [[ thread_position_in_grid ]] \
|
uint tid [[ thread_position_in_grid ]] \
|
||||||
) { \
|
) { \
|
||||||
if (thread_position_in_grid >= dim) { \
|
if (tid >= dim) { \
|
||||||
return; \
|
return; \
|
||||||
} \
|
} \
|
||||||
output[thread_position_in_grid] = RIGHT_TYPENAME(input[thread_position_in_grid]); \
|
output[tid] = RIGHT_TYPENAME(input[tid]); \
|
||||||
} \
|
} \
|
||||||
kernel void FN_NAME_STRIDED( \
|
kernel void FN_NAME_STRIDED( \
|
||||||
constant size_t &dim, \
|
constant size_t &dim, \
|
||||||
@ -37,15 +37,19 @@ kernel void FN_NAME_STRIDED( \
|
|||||||
constant size_t *strides, \
|
constant size_t *strides, \
|
||||||
device const LEFT_TYPENAME *input, \
|
device const LEFT_TYPENAME *input, \
|
||||||
device RIGHT_TYPENAME *output, \
|
device RIGHT_TYPENAME *output, \
|
||||||
uint i [[ thread_position_in_grid ]] \
|
uint tid [[ thread_position_in_grid ]] \
|
||||||
) { \
|
) { \
|
||||||
if (i >= dim) { \
|
if (tid >= dim) { \
|
||||||
return; \
|
return; \
|
||||||
} \
|
} \
|
||||||
output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \
|
output[tid] = RIGHT_TYPENAME(input[get_strided_index(tid, num_dims, dims, strides)]); \
|
||||||
} \
|
} \
|
||||||
|
|
||||||
CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float)
|
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
|
||||||
|
CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
|
||||||
|
CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
|
||||||
|
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
|
||||||
|
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
|
||||||
|
|
||||||
#if __METAL_VERSION__ >= 310
|
#if __METAL_VERSION__ >= 310
|
||||||
#endif
|
#endif
|
||||||
|
@ -16,16 +16,16 @@ kernel void NAME( \
|
|||||||
if (gid >= dst_size) { \
|
if (gid >= dst_size) { \
|
||||||
return; \
|
return; \
|
||||||
} \
|
} \
|
||||||
const size_t id_i = gid / right_size / left_size; \
|
const size_t id_i = (gid / right_size) % ids_size; \
|
||||||
|
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
|
||||||
const size_t right_rank_i = gid % right_size; \
|
const size_t right_rank_i = gid % right_size; \
|
||||||
const size_t left_rank_i = gid % left_size; \
|
const size_t left_rank_i = gid / right_size / ids_size; \
|
||||||
/* \
|
/* \
|
||||||
// Force prevent out of bounds indexing \
|
// Force prevent out of bounds indexing \
|
||||||
// since there doesn't seem to be a good way to force crash \
|
// since there doesn't seem to be a good way to force crash \
|
||||||
// No need to check for zero we're only allowing unsized. \
|
// No need to check for zero we're only allowing unsized. \
|
||||||
*/ \
|
*/ \
|
||||||
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
|
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; \
|
||||||
const size_t src_i = ((input_i * right_size) + right_rank_i) * left_size + left_rank_i; \
|
|
||||||
output[gid] = input[src_i]; \
|
output[gid] = input[src_i]; \
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -75,6 +75,7 @@ kernel void FN_NAME( \
|
|||||||
|
|
||||||
|
|
||||||
INDEX_OP(is_u32_f32, uint, float)
|
INDEX_OP(is_u32_f32, uint, float)
|
||||||
|
INDEX_OP(is_u32_f16, uint, half)
|
||||||
|
|
||||||
|
|
||||||
#if __METAL_VERSION__ >= 310
|
#if __METAL_VERSION__ >= 310
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -1,6 +1,8 @@
|
|||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
|
#define MAX(x, y) ((x) > (y) ? (x) : (y))
|
||||||
|
|
||||||
METAL_FUNC uint get_strided_index(
|
METAL_FUNC uint get_strided_index(
|
||||||
uint idx,
|
uint idx,
|
||||||
constant size_t &num_dims,
|
constant size_t &num_dims,
|
||||||
@ -16,18 +18,18 @@ METAL_FUNC uint get_strided_index(
|
|||||||
return strided_i;
|
return strided_i;
|
||||||
}
|
}
|
||||||
|
|
||||||
constant int THREADGROUP_SIZE = 256;
|
constant int THREADGROUP_SIZE = 1024;
|
||||||
|
|
||||||
# define REDUCE(FN, NAME, TYPENAME) \
|
# define REDUCE(FN, NAME, T) \
|
||||||
kernel void NAME( \
|
kernel void NAME( \
|
||||||
constant size_t &src_numel, \
|
constant size_t &src_numel, \
|
||||||
constant size_t &el_to_sum_per_block, \
|
constant size_t &el_to_sum_per_block, \
|
||||||
device const TYPENAME *src, \
|
device const T *src, \
|
||||||
device TYPENAME *dst, \
|
device T *dst, \
|
||||||
uint id [[ thread_position_in_grid ]], \
|
uint id [[ thread_position_in_grid ]], \
|
||||||
uint tid [[ thread_index_in_threadgroup ]], \
|
uint tid [[ thread_index_in_threadgroup ]], \
|
||||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||||
uint blockDim [[ threads_per_threadgroup ]] \
|
uint block_dim [[ threads_per_threadgroup ]] \
|
||||||
) { \
|
) { \
|
||||||
\
|
\
|
||||||
threadgroup float shared_memory[THREADGROUP_SIZE]; \
|
threadgroup float shared_memory[THREADGROUP_SIZE]; \
|
||||||
@ -45,10 +47,10 @@ kernel void NAME( \
|
|||||||
// TODO: Fast version for the contiguous case. \
|
// TODO: Fast version for the contiguous case. \
|
||||||
// size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
// size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||||
*/ \
|
*/ \
|
||||||
TYPENAME x = shared_memory[tid]; \
|
T x = shared_memory[tid]; \
|
||||||
TYPENAME y = src[idx]; \
|
T y = src[idx]; \
|
||||||
shared_memory[tid] = FN; \
|
shared_memory[tid] = FN; \
|
||||||
idx += blockDim; \
|
idx += block_dim; \
|
||||||
} \
|
} \
|
||||||
\
|
\
|
||||||
threadgroup_barrier(mem_flags::mem_none); \
|
threadgroup_barrier(mem_flags::mem_none); \
|
||||||
@ -56,10 +58,10 @@ kernel void NAME( \
|
|||||||
/* \
|
/* \
|
||||||
// reduction in shared memory \
|
// reduction in shared memory \
|
||||||
*/ \
|
*/ \
|
||||||
for (uint s = blockDim / 2; s > 0; s >>= 1) { \
|
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||||
if (tid < s) { \
|
if (tid < s) { \
|
||||||
TYPENAME x = shared_memory[tid]; \
|
T x = shared_memory[tid]; \
|
||||||
TYPENAME y = shared_memory[tid + s]; \
|
T y = shared_memory[tid + s]; \
|
||||||
shared_memory[tid] = FN; \
|
shared_memory[tid] = FN; \
|
||||||
} \
|
} \
|
||||||
threadgroup_barrier(mem_flags::mem_none); \
|
threadgroup_barrier(mem_flags::mem_none); \
|
||||||
@ -68,72 +70,74 @@ kernel void NAME( \
|
|||||||
dst[dst_id] = shared_memory[0]; \
|
dst[dst_id] = shared_memory[0]; \
|
||||||
} \
|
} \
|
||||||
|
|
||||||
// Softmax over consecutive chunks of `el_to_sum_per_block` elements.
//
// Each threadgroup handles the chunk starting at `dst_id * el_to_sum_per_block`
// (clamped to `src_numel`) in three passes:
//   1. reduce the maximum of the chunk in threadgroup memory,
//   2. write exp(x - max) to `dst` while accumulating the sum,
//   3. scale `dst` by the reciprocal of the sum.
//
// All barriers use `mem_flags::mem_threadgroup`: `shared_memory` lives in
// threadgroup memory, so its writes must be fenced before another thread
// reads them.
kernel void softmax_float(
    constant size_t &src_numel,
    constant size_t &el_to_sum_per_block,
    device const float *src,
    device float *dst,
    uint id [[ thread_position_in_grid ]],
    uint tid [[ thread_index_in_threadgroup ]],
    uint dst_id [[ threadgroup_position_in_grid ]],
    uint blockDim [[ threads_per_threadgroup ]]
) {
    threadgroup float shared_memory[THREADGROUP_SIZE];

    shared_memory[tid] = -INFINITY;
    // Elements handled by this block range from dst_id * el_to_sum_per_block
    // to (dst_id + 1) * el_to_sum_per_block.
    size_t start_idx = dst_id * el_to_sum_per_block;
    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
    size_t idx = start_idx + tid;

    // Pass 1a: per-thread running maximum.
    while (idx < stop_idx) {
        // TODO: Fast version for the contiguous case.
        shared_memory[tid] = max(shared_memory[tid], src[idx]);
        idx += blockDim;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Pass 1b: tree reduction of the per-thread maxima.
    for (uint s = blockDim / 2; s > 0; s >>= 1) {
        if (tid < s) {
            shared_memory[tid] = max(shared_memory[tid], shared_memory[tid + s]);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Named `row_max` rather than `max` so the builtin max() is not shadowed.
    const float row_max = shared_memory[0];

    // Every thread must have read slot 0 before thread 0 resets it below.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    shared_memory[tid] = 0;

    // Pass 2a: exponentiate and accumulate per-thread partial sums.
    idx = start_idx + tid;
    while (idx < stop_idx) {
        // TODO: Fast version for the contiguous case.
        const float val = exp(src[idx] - row_max);
        dst[idx] = val;
        shared_memory[tid] += val;
        idx += blockDim;
    }

    // Partial sums must be visible before they are combined.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Pass 2b: tree reduction of the partial sums.
    for (uint s = blockDim / 2; s > 0; s >>= 1) {
        if (tid < s) {
            shared_memory[tid] += shared_memory[tid + s];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Pass 3: normalise. Each thread rescales only the entries it wrote itself.
    const float inv_acc = 1/shared_memory[0];
    idx = start_idx + tid;
    while (idx < stop_idx) {
        dst[idx] *= inv_acc;
        idx += blockDim;
    }
}
|
|
||||||
|
|
||||||
|
|
||||||
REDUCE(x + y, fast_sum_float, float)
|
REDUCE(x + y, fast_sum_float, float)
|
||||||
REDUCE(x * y, fast_mul_float, float)
|
REDUCE(x * y, fast_mul_float, float)
|
||||||
REDUCE(max(x, y), fast_max_float, float)
|
REDUCE(max(x, y), fast_max_float, float)
|
||||||
|
|
||||||
|
// SOFTMAX(NAME, T) defines a kernel computing softmax over consecutive chunks
// of `el_to_sum_per_block` elements of type T.
//
// Each threadgroup handles the chunk starting at `dst_id * el_to_sum_per_block`
// (clamped to `src_numel`). The max and the sum are reduced in float
// threadgroup memory; exp values are rounded to T before being accumulated.
//
// Barrier placement matters: every step of a tree reduction needs a barrier,
// slot 0 may only be reset once all threads have read the maximum from it, and
// the partial sums must be fenced before they are combined.
#define SOFTMAX(NAME, T) \
kernel void NAME( \
    constant size_t &src_numel, \
    constant size_t &el_to_sum_per_block, \
    device const T *src, \
    device T *dst, \
    \
    uint id [[ thread_position_in_grid ]], \
    uint tid [[ thread_index_in_threadgroup ]], \
    uint dst_id [[ threadgroup_position_in_grid ]], \
    uint block_dim [[ threads_per_threadgroup ]] \
) { \
    threadgroup float shared_memory[THREADGROUP_SIZE]; \
    shared_memory[tid] = -INFINITY; \
    size_t start_idx = dst_id * el_to_sum_per_block; \
    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
    size_t idx = start_idx + tid; \
    \
    threadgroup_barrier(mem_flags::mem_threadgroup); \
    \
    /* per-thread running maximum */ \
    while (idx < stop_idx) { \
        shared_memory[tid] = MAX(shared_memory[tid], src[idx]); \
        idx += block_dim; \
    } \
    \
    threadgroup_barrier(mem_flags::mem_threadgroup); \
    \
    /* tree reduction of the maxima; barrier after every step */ \
    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
        if (tid < s) { \
            shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
        } \
        threadgroup_barrier(mem_flags::mem_threadgroup); \
    } \
    \
    float _max = shared_memory[0]; \
    \
    /* all threads must have read slot 0 before thread 0 resets it */ \
    threadgroup_barrier(mem_flags::mem_threadgroup); \
    \
    shared_memory[tid] = 0; \
    \
    /* exponentiate and accumulate per-thread partial sums */ \
    idx = start_idx + tid; \
    while (idx < stop_idx) { \
        const T val = T(exp(src[idx] - _max)); \
        dst[idx] = val; \
        shared_memory[tid] += val; \
        idx += block_dim; \
    } \
    \
    /* partial sums must be visible before they are combined */ \
    threadgroup_barrier(mem_flags::mem_threadgroup); \
    \
    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
        if (tid < s) { \
            shared_memory[tid] += shared_memory[tid + s]; \
        } \
        threadgroup_barrier(mem_flags::mem_threadgroup); \
    } \
    \
    /* normalise; each thread rescales only the entries it wrote itself */ \
    const T inv_acc = T(1/shared_memory[0]); \
    idx = start_idx + tid; \
    while (idx < stop_idx) { \
        dst[idx] *= inv_acc; \
        idx += block_dim; \
    } \
}

SOFTMAX(softmax_float, float)
SOFTMAX(softmax_half, half)
// The bfloat instantiation is only compiled for Metal language version 3.1 and later.
#if __METAL_VERSION__ >= 310
SOFTMAX(softmax_bfloat, bfloat)
#endif
|
||||||
|
@ -32,6 +32,9 @@ kernel void FN_NAME( \
|
|||||||
device TYPENAME *out ,\
|
device TYPENAME *out ,\
|
||||||
uint i [[ thread_position_in_grid ]] \
|
uint i [[ thread_position_in_grid ]] \
|
||||||
) { \
|
) { \
|
||||||
|
if (i >= numel){ \
|
||||||
|
return; \
|
||||||
|
} \
|
||||||
uint strided_i = get_strided_index(i, num_dims, dims, strides); \
|
uint strided_i = get_strided_index(i, num_dims, dims, strides); \
|
||||||
uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \
|
uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \
|
||||||
uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \
|
uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \
|
||||||
|
746
candle-metal-kernels/src/tests.rs
Normal file
746
candle-metal-kernels/src/tests.rs
Normal file
@ -0,0 +1,746 @@
|
|||||||
|
use super::*;
|
||||||
|
use half::{bf16, f16};
|
||||||
|
use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger};
|
||||||
|
|
||||||
|
fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
|
||||||
|
let options = MTLResourceOptions::StorageModeManaged;
|
||||||
|
let ptr = data.as_ptr() as *const core::ffi::c_void;
|
||||||
|
let size = (data.len() * std::mem::size_of::<T>()) as u64;
|
||||||
|
device.new_buffer_with_data(ptr, size, options)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn device() -> Device {
|
||||||
|
Device::system_default().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Rounds every value in `v` to `digits` decimal places.
///
/// Each value is scaled by `10^digits`, rounded to the nearest integer and
/// scaled back, so results can be compared against short literals.
fn approx(v: Vec<f32>, digits: i32) -> Vec<f32> {
    let scale = 10f32.powi(digits);
    v.into_iter().map(|x| (x * scale).round() / scale).collect()
}
|
||||||
|
|
||||||
|
fn approx_f16(v: Vec<f16>, digits: i32) -> Vec<f32> {
|
||||||
|
let b = 10f32.powi(digits);
|
||||||
|
v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn approx_bf16(v: Vec<bf16>, digits: i32) -> Vec<f32> {
|
||||||
|
let b = 10f32.powi(digits);
|
||||||
|
v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
|
||||||
|
let device = device();
|
||||||
|
let kernels = Kernels::new();
|
||||||
|
let command_queue = device.new_command_queue();
|
||||||
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
|
let input = new_buffer(&device, v);
|
||||||
|
let output = new_buffer(&device, v);
|
||||||
|
call_unary_contiguous(
|
||||||
|
&device,
|
||||||
|
command_buffer,
|
||||||
|
&kernels,
|
||||||
|
name,
|
||||||
|
v.len(),
|
||||||
|
&input,
|
||||||
|
&output,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
output.read_to_vec::<T>(v.len())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> {
|
||||||
|
let device = device();
|
||||||
|
let kernels = Kernels::new();
|
||||||
|
let command_queue = device.new_command_queue();
|
||||||
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
|
let options = MTLResourceOptions::StorageModeManaged;
|
||||||
|
let left = new_buffer(&device, x);
|
||||||
|
let right = new_buffer(&device, y);
|
||||||
|
let output = device.new_buffer(std::mem::size_of_val(x) as u64, options);
|
||||||
|
call_binary_contiguous(
|
||||||
|
&device,
|
||||||
|
command_buffer,
|
||||||
|
&kernels,
|
||||||
|
name,
|
||||||
|
x.len(),
|
||||||
|
&left,
|
||||||
|
&right,
|
||||||
|
&output,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
output.read_to_vec::<T>(x.len())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_strided<T: Clone>(
|
||||||
|
v: &[T],
|
||||||
|
kernel: unary::strided::Kernel,
|
||||||
|
shape: &[usize],
|
||||||
|
strides: &[usize],
|
||||||
|
offset: usize,
|
||||||
|
) -> Vec<T> {
|
||||||
|
let device = device();
|
||||||
|
let command_queue = device.new_command_queue();
|
||||||
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
|
let input = new_buffer(&device, v);
|
||||||
|
let output = new_buffer(&device, v);
|
||||||
|
let kernels = Kernels::new();
|
||||||
|
call_unary_strided(
|
||||||
|
&device,
|
||||||
|
command_buffer,
|
||||||
|
&kernels,
|
||||||
|
kernel,
|
||||||
|
shape,
|
||||||
|
&input,
|
||||||
|
strides,
|
||||||
|
offset,
|
||||||
|
&output,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
output.read_to_vec::<T>(v.len())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Checks the contiguous f32 cosine kernel, and the CPU reference, against
/// values rounded to four decimal places.
#[test]
fn cos_f32() {
    // Each case pairs an input with its expected rounded cosines.
    let cases = [
        (vec![1.0f32, 2.0, 3.0], vec![0.5403f32, -0.4161, -0.99]),
        (vec![1.0f32; 10_000], vec![0.5403f32; 10_000]),
    ];
    for (input, want) in cases {
        let gpu = run(&input, unary::contiguous::cos::FLOAT);
        let cpu: Vec<f32> = input.iter().map(|x| x.cos()).collect();
        assert_eq!(approx(gpu, 4), want);
        assert_eq!(approx(cpu, 4), want);
    }
}
|
||||||
|
|
||||||
|
/// Checks the strided f32 cosine kernel on 1-D, contiguous 2-D, transposed 2-D
/// and large inputs.
#[test]
fn cos_f32_strided() {
    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let cos_in_order = vec![0.5403f32, -0.4161, -0.99, -0.6536, 0.2837, 0.9602];

    // CPU reference for the values used by every small case below.
    let expected: Vec<_> = v.iter().map(|x| x.cos()).collect();
    assert_eq!(approx(expected, 4), cos_in_order);

    // Flat
    let results = run_strided(&v, unary::strided::cos::FLOAT, &[6], &[1], 0);
    assert_eq!(approx(results, 4), cos_in_order);

    // Contiguous
    let results = run_strided(&v, unary::strided::cos::FLOAT, &[3, 2], &[2, 1], 0);
    assert_eq!(approx(results, 4), cos_in_order);

    // Transposed: output element (r, c) reads input[r + 3 * c].
    let results = run_strided(&v, unary::strided::cos::FLOAT, &[3, 2], &[1, 3], 0);
    assert_eq!(
        approx(results, 4),
        vec![0.5403, -0.6536, -0.4161, 0.2837, -0.99, 0.9602]
    );

    // Very large. A contiguous [2, 5000] layout has strides [5000, 1]; the
    // previous [2, 1] only passed because every input value is identical.
    let v = vec![1.0f32; 10_000];
    let results = run_strided(
        &v,
        unary::strided::cos::FLOAT,
        &[2, 5_000],
        &[5_000, 1],
        0,
    );
    let expected: Vec<_> = v.iter().map(|x| x.cos()).collect();
    assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
    assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
}
|
||||||
|
|
||||||
|
/// Spot-checks the strided cosine kernel on random data read through a
/// transposed `[5000, 2]` view (strides `[1, 5000]`).
#[test]
fn cos_strided_random() {
    let v: Vec<f32> = (0..10_000).map(|_| rand::random::<f32>()).collect();
    let results = run_strided(&v, unary::strided::cos::FLOAT, &[5_000, 2], &[1, 5_000], 0);
    let expected: Vec<f32> = v.iter().map(|x| x.cos()).collect();

    // (index in the kernel output, index in the CPU reference)
    let spot_checks = [(0, 0), (1, 5_000), (2, 1), (3, 5_001), (5_000, 2_500)];
    for (out_idx, ref_idx) in spot_checks {
        assert_eq!(
            approx(vec![results[out_idx]], 4),
            approx(vec![expected[ref_idx]], 4)
        );
    }
}
|
||||||
|
|
||||||
|
/// Checks the contiguous f16 GELU kernel against values rounded to two decimals.
#[test]
fn gelu_f16() {
    let inputs = [-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0];
    let v: Vec<f16> = inputs.iter().copied().map(f16::from_f32).collect();

    let results = run(&v, unary::contiguous::gelu::HALF);

    let expected: Vec<f32> = vec![-0.0, -0.16, 0.0, 0.84, 1.96, 3.0, 10.0, 20.0];
    assert_eq!(approx_f16(results, 2), expected);
}
|
||||||
|
|
||||||
|
/// Checks the contiguous f32 GELU kernel against values rounded to three decimals.
#[test]
fn gelu_f32() {
    let v: Vec<f32> = vec![-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0];

    let results = run(&v, unary::contiguous::gelu::FLOAT);

    let expected: Vec<f32> = vec![-0.0, -0.159, 0.0, 0.841, 1.955, 2.996, 10.0, 20.0];
    assert_eq!(approx(results, 3), expected);
}
|
||||||
|
|
||||||
|
/// Checks the contiguous f32 addition kernel, and the CPU reference, against
/// fixed sums.
#[test]
fn binary_add_f32() {
    let left = vec![1.0f32, 2.0, 3.0];
    let right = vec![2.0f32, 3.1, 4.2];
    let want = vec![3.0f32, 5.1, 7.2];

    let gpu = run_binary(&left, &right, binary::contiguous::add::FLOAT);
    let cpu: Vec<f32> = left.iter().zip(&right).map(|(a, b)| a + b).collect();

    assert_eq!(approx(gpu, 4), want);
    assert_eq!(approx(cpu, 4), want);
}
|
||||||
|
|
||||||
|
fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
|
||||||
|
let device = device();
|
||||||
|
let kernels = Kernels::new();
|
||||||
|
let command_queue = device.new_command_queue();
|
||||||
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
|
let input = new_buffer(&device, v);
|
||||||
|
let options = MTLResourceOptions::StorageModeManaged;
|
||||||
|
let size = (v.len() * std::mem::size_of::<U>()) as u64;
|
||||||
|
let output = device.new_buffer(size, options);
|
||||||
|
|
||||||
|
call_cast_contiguous(
|
||||||
|
&device,
|
||||||
|
command_buffer,
|
||||||
|
&kernels,
|
||||||
|
name,
|
||||||
|
v.len(),
|
||||||
|
&input,
|
||||||
|
0,
|
||||||
|
&output,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
output.read_to_vec::<U>(v.len())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Checks the u32 -> f32 and f16 -> f32 cast kernels on small and large inputs.
#[test]
fn cast_u32_f32() {
    // u32 -> f32
    let ints = vec![1u32, 2, 3];
    let results = cast(&ints, "cast_u32_f32");
    let expected: Vec<f32> = ints.iter().map(|&x| x as f32).collect();
    assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]);
    assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]);

    // f16 -> f32, small input
    let halves: Vec<f16> = [1.0f32, 2.0, 3.0]
        .iter()
        .copied()
        .map(f16::from_f32)
        .collect();
    let results: Vec<f32> = cast(&halves, "cast_f16_f32");
    assert_eq!(results, vec![1.0f32, 2.0, 3.0]);

    // f16 -> f32, large input
    let halves = vec![f16::from_f32(1.0); 10_000];
    let results: Vec<f32> = cast(&halves, "cast_f16_f32");
    assert_eq!(results.len(), 10_000);
    assert_eq!(&results[..10], vec![1.0f32; 10]);
    assert_eq!(results, vec![1.0f32; 10_000]);
}
|
||||||
|
|
||||||
|
fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
|
||||||
|
let device = device();
|
||||||
|
let kernels = Kernels::new();
|
||||||
|
let command_queue = device.new_command_queue();
|
||||||
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
|
|
||||||
|
let input = new_buffer(&device, v);
|
||||||
|
let output = new_buffer(&device, v);
|
||||||
|
|
||||||
|
let size = v.len();
|
||||||
|
|
||||||
|
call_affine(
|
||||||
|
&device,
|
||||||
|
command_buffer,
|
||||||
|
&kernels,
|
||||||
|
"affine_float",
|
||||||
|
size,
|
||||||
|
&input,
|
||||||
|
&output,
|
||||||
|
mul as f32,
|
||||||
|
add as f32,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
|
||||||
|
output.read_to_vec::<T>(v.len())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_affine_strided<T: Clone>(
|
||||||
|
v: &[T],
|
||||||
|
shape: &[usize],
|
||||||
|
strides: &[usize],
|
||||||
|
mul: f64,
|
||||||
|
add: f64,
|
||||||
|
) -> Vec<T> {
|
||||||
|
let device = device();
|
||||||
|
let kernels = Kernels::new();
|
||||||
|
let command_queue = device.new_command_queue();
|
||||||
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
|
|
||||||
|
let input = new_buffer(&device, v);
|
||||||
|
let output = new_buffer(&device, v);
|
||||||
|
|
||||||
|
call_affine_strided(
|
||||||
|
&device,
|
||||||
|
command_buffer,
|
||||||
|
&kernels,
|
||||||
|
"affine_float_strided",
|
||||||
|
shape,
|
||||||
|
&input,
|
||||||
|
strides,
|
||||||
|
0,
|
||||||
|
&output,
|
||||||
|
mul as f32,
|
||||||
|
add as f32,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
|
||||||
|
let len: usize = shape.iter().product();
|
||||||
|
output.read_to_vec::<T>(len)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Checks `x * 1.5 + 1.1` on a short input and on 40 000 ones.
#[test]
fn affine() {
    let (mul, add) = (1.5, 1.1);

    let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let result = run_affine(&input, mul, add);
    assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]);

    let input = [1.0f32; 40_000];
    let result = run_affine(&input, mul, add);
    assert_eq!(result, vec![2.6; 40_000]);
}
|
||||||
|
|
||||||
|
/// Checks the strided affine kernel when reading every second element.
#[test]
fn affine_strided() {
    let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let (mul, add) = (1.5, 1.1);

    // Shape [4] with stride [2]: takes elements 0, 2, 4 and 6.
    let result = run_affine_strided(&input, &[4], &[2], mul, add);
    assert_eq!(result, vec![2.6, 5.6, 8.6, 11.6]);
}
|
||||||
|
|
||||||
|
/// Checks index-select along dimension 0 for a `[5, 2]` and a `[2, 5]` table.
#[test]
fn index_select() {
    let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];

    // Rows 0, 4 and 2 of a 5x2 table.
    let result = run_index_select(&embedding, &[5, 2], &[0u32, 4, 2], 0);
    assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]);

    // Rows 0, 1 and 0 of a 2x5 table.
    let result = run_index_select(&embedding, &[2, 5], &[0u32, 1, 0], 0);
    assert_eq!(
        result,
        vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
    );
}
|
||||||
|
|
||||||
|
/// Checks index-select along dimension 0 on an f16 `[5, 2]` table.
#[test]
fn index_select_f16() {
    let embedding: Vec<f16> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
        .iter()
        .copied()
        .map(f16::from_f32)
        .collect();

    let result = run_index_select(&embedding, &[5, 2], &[0u32, 4, 2], 0);

    assert_eq!(
        approx_f16(result, 4),
        vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
    );
}
|
||||||
|
|
||||||
|
/// Checks index-select along dimension 1 of a `[5, 2]` table.
#[test]
fn index_select_dim1() {
    let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];

    // Columns 0, 1 and 0 of every row.
    let result = run_index_select(&embedding, &[5, 2], &[0u32, 1, 0], 1);

    assert_eq!(
        result,
        vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0]
    );
}
|
||||||
|
|
||||||
|
fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
|
||||||
|
embeddings: &[T],
|
||||||
|
shape: &[usize],
|
||||||
|
ids: &[I],
|
||||||
|
dim: usize,
|
||||||
|
) -> Vec<T> {
|
||||||
|
let device = Device::system_default().expect("no device found");
|
||||||
|
|
||||||
|
let command_queue = device.new_command_queue();
|
||||||
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
|
let embeddings_buffer = new_buffer(&device, &embeddings);
|
||||||
|
let ids_buffer = new_buffer(&device, &ids);
|
||||||
|
|
||||||
|
let left_size: usize = shape[..dim].iter().product();
|
||||||
|
let right_size: usize = shape[dim + 1..].iter().product();
|
||||||
|
let dst_el = ids.len() * left_size * right_size;
|
||||||
|
let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
|
||||||
|
|
||||||
|
let name = match core::mem::size_of::<T>() {
|
||||||
|
4 => "is_u32_f32",
|
||||||
|
2 => "is_u32_f16",
|
||||||
|
_ => unimplemented!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let kernels = Kernels::new();
|
||||||
|
call_index_select(
|
||||||
|
&device,
|
||||||
|
&command_buffer,
|
||||||
|
&kernels,
|
||||||
|
name,
|
||||||
|
shape,
|
||||||
|
ids.len(),
|
||||||
|
dim,
|
||||||
|
&embeddings_buffer,
|
||||||
|
&ids_buffer,
|
||||||
|
&dst_buffer,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
|
||||||
|
dst_buffer.read_to_vec::<T>(dst_el)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Drives the `ia_u32_f32` kernel directly, compiled from the `INDEXING`
/// source, and checks the accumulated output.
#[test]
fn index_add() {
    let device = Device::system_default().expect("no device found");

    let library = device
        .new_library_with_source(INDEXING, &CompileOptions::new())
        .unwrap();

    let src = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
    let dst = [1.0f32; 15];
    let ids = [0u32, 4, 2];

    // Scalar kernel parameters, all passed as u32.
    let ids_dim_size = ids.len() as u32;
    let dst_dim_size: u32 = 15;
    let left_size: u32 = 3;
    let right_size: u32 = 3;

    let function = library.get_function("ia_u32_f32", None).unwrap();
    let pipeline = device
        .new_compute_pipeline_state_with_function(&function)
        .unwrap();

    let queue = device.new_command_queue();
    let cmd = queue.new_command_buffer();
    let encoder = cmd.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline);

    let ids_buffer = new_buffer(&device, &ids);
    let src_buffer = new_buffer(&device, &src);
    let dst_buffer = new_buffer(&device, &dst);

    set_params!(
        encoder,
        (
            &ids_buffer,
            &src_buffer,
            &dst_buffer,
            ids_dim_size,
            left_size,
            dst_dim_size,
            right_size
        )
    );

    // One threadgroup per output element, each as wide as the pipeline allows.
    let threadgroups = MTLSize {
        width: dst.len() as NSUInteger,
        height: 1,
        depth: 1,
    };
    let threads_per_group = MTLSize {
        width: pipeline.max_total_threads_per_threadgroup(),
        height: 1,
        depth: 1,
    };
    encoder.dispatch_thread_groups(threadgroups, threads_per_group);
    encoder.end_encoding();
    cmd.commit();
    cmd.wait_until_completed();

    let result = dst_buffer.read_to_vec::<f32>(dst.len());
    let expected = vec![
        2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
    ];
    assert_eq!(result, expected);
}
|
||||||
|
|
||||||
|
/// Checks the contiguous f16 cosine kernel, and an f16 CPU reference, against
/// values rounded to two decimals.
#[test]
fn cos_f16() {
    let v: Vec<f16> = [1.0f32, 2.0, 3.0]
        .iter()
        .copied()
        .map(f16::from_f32)
        .collect();
    let want = vec![0.54f32, -0.42, -0.99];

    let gpu = run(&v, unary::contiguous::cos::HALF);
    let cpu: Vec<f16> = v.iter().map(|x| f16::from_f32(x.to_f32().cos())).collect();

    assert_eq!(approx_f16(gpu, 2), want);
    assert_eq!(approx_f16(cpu, 2), want);
}
|
||||||
|
|
||||||
|
fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
|
||||||
|
let device = device();
|
||||||
|
let kernels = Kernels::new();
|
||||||
|
let command_queue = device.new_command_queue();
|
||||||
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
|
let input = new_buffer(&device, v);
|
||||||
|
|
||||||
|
let options = MTLResourceOptions::StorageModeManaged;
|
||||||
|
let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
|
||||||
|
call_reduce_contiguous(
|
||||||
|
&device,
|
||||||
|
command_buffer,
|
||||||
|
&kernels,
|
||||||
|
name,
|
||||||
|
v.len(),
|
||||||
|
out_length,
|
||||||
|
&input,
|
||||||
|
0,
|
||||||
|
&output,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
|
||||||
|
output.read_to_vec::<T>(out_length)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'static str) -> Vec<T> {
|
||||||
|
let device = device();
|
||||||
|
let kernels = Kernels::new();
|
||||||
|
let command_queue = device.new_command_queue();
|
||||||
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
|
let input = new_buffer(&device, v);
|
||||||
|
let output = new_buffer(&device, v);
|
||||||
|
call_last_softmax(
|
||||||
|
&device,
|
||||||
|
command_buffer,
|
||||||
|
&kernels,
|
||||||
|
name,
|
||||||
|
v.len(),
|
||||||
|
last_dim,
|
||||||
|
&input,
|
||||||
|
&output,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
|
||||||
|
output.read_to_vec::<T>(v.len())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sums six values into a single output.
#[test]
fn reduce_sum() {
    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let results = run_reduce(&v, 1, "fast_sum_float");
    assert_eq!(approx(results, 4), vec![21.0]);
}
|
||||||
|
|
||||||
|
/// Sums six values into two outputs of three inputs each.
#[test]
fn reduce_sum2() {
    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let results = run_reduce(&v, 2, "fast_sum_float");
    assert_eq!(approx(results, 4), vec![6.0, 15.0]);
}
|
||||||
|
|
||||||
|
/// Checks the softmax kernels for f32 (full row, shifted input, two rows),
/// f16 and bf16.
#[test]
fn softmax() {
    let one_to_six = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];

    // One row of six.
    let results = run_softmax(&one_to_six, 6, "softmax_float");
    assert_eq!(
        approx(results, 4),
        vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
    );

    // Shifting every input by a constant leaves the softmax unchanged.
    let zero_to_five = [0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0];
    let results = run_softmax(&zero_to_five, 6, "softmax_float");
    assert_eq!(
        approx(results, 4),
        vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
    );

    // Two rows of three.
    let results = run_softmax(&one_to_six, 3, "softmax_float");
    assert_eq!(
        approx(results, 4),
        vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652]
    );

    // Half precision.
    let halves: Vec<f16> = one_to_six.iter().copied().map(f16::from_f32).collect();
    let results = run_softmax(&halves, 6, "softmax_half");
    assert_eq!(
        approx_f16(results, 4),
        vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338]
    );

    // bfloat16.
    let bfloats: Vec<bf16> = one_to_six.iter().copied().map(bf16::from_f32).collect();
    let results = run_softmax(&bfloats, 6, "softmax_bfloat");
    assert_eq!(
        approx_bf16(results, 4),
        vec![0.0043, 0.0116, 0.0315, 0.0859, 0.2324, 0.6328]
    );
}
|
||||||
|
|
||||||
|
fn run_where_cond<I: Clone, T: Clone>(
|
||||||
|
shape: &[usize],
|
||||||
|
cond: &[I],
|
||||||
|
(cond_stride, cond_offset): (Vec<usize>, usize),
|
||||||
|
left_true: &[T],
|
||||||
|
(left_stride, left_offset): (Vec<usize>, usize),
|
||||||
|
right_false: &[T],
|
||||||
|
(_right_stride, _right_offset): (Vec<usize>, usize),
|
||||||
|
name: &'static str,
|
||||||
|
) -> Vec<T> {
|
||||||
|
let device = device();
|
||||||
|
let kernels = Kernels::new();
|
||||||
|
let command_queue = device.new_command_queue();
|
||||||
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
|
let options = MTLResourceOptions::StorageModeManaged;
|
||||||
|
|
||||||
|
let length = cond.len();
|
||||||
|
let cond = device.new_buffer_with_data(
|
||||||
|
cond.as_ptr() as *const core::ffi::c_void,
|
||||||
|
std::mem::size_of_val(cond) as u64,
|
||||||
|
options,
|
||||||
|
);
|
||||||
|
let left = device.new_buffer_with_data(
|
||||||
|
left_true.as_ptr() as *const core::ffi::c_void,
|
||||||
|
(length * core::mem::size_of::<T>()) as u64,
|
||||||
|
options,
|
||||||
|
);
|
||||||
|
let right = device.new_buffer_with_data(
|
||||||
|
right_false.as_ptr() as *const core::ffi::c_void,
|
||||||
|
(length * core::mem::size_of::<T>()) as u64,
|
||||||
|
options,
|
||||||
|
);
|
||||||
|
|
||||||
|
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
|
||||||
|
call_where_cond_strided(
|
||||||
|
&device,
|
||||||
|
command_buffer,
|
||||||
|
&kernels,
|
||||||
|
name,
|
||||||
|
shape,
|
||||||
|
&cond,
|
||||||
|
(&cond_stride, cond_offset),
|
||||||
|
&left,
|
||||||
|
(&left_stride, left_offset),
|
||||||
|
&right,
|
||||||
|
(&cond_stride, cond_offset),
|
||||||
|
&output,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
|
||||||
|
output.read_to_vec::<T>(length)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Checks `where_u8_f32` on six contiguous elements: a non-zero condition
/// selects from the "true" operand, zero from the "false" operand.
#[test]
fn where_cond() {
    let cond = vec![0u8, 1, 0, 0, 1, 1];
    let on_true = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let on_false = vec![-1.0f32, -2.0, -3.0, -4.0, -5.0, -6.0];

    // All three operands are flat and contiguous with no offset.
    let contiguous = || (vec![1usize], 0usize);

    let results = run_where_cond(
        &[6],
        &cond,
        contiguous(),
        &on_true,
        contiguous(),
        &on_false,
        contiguous(),
        "where_u8_f32",
    );

    assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
}
|
@ -1,4 +1,7 @@
|
|||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
|
#include <metal_math>
|
||||||
|
#
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
METAL_FUNC uint get_strided_index(
|
METAL_FUNC uint get_strided_index(
|
||||||
uint idx,
|
uint idx,
|
||||||
@ -17,10 +20,44 @@ METAL_FUNC uint get_strided_index(
|
|||||||
|
|
||||||
template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
|
template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
|
||||||
template <typename T> METAL_FUNC T neg(T in){ return -in; }
|
template <typename T> METAL_FUNC T neg(T in){ return -in; }
|
||||||
|
// Polynomial approximation of the error function (Abramowitz & Stegun formula 7.1.26).
// The input is converted to float, evaluated on its absolute value, and the sign is
// re-applied before converting back to T.
template <typename T> METAL_FUNC T erf(T in){
    // Coefficients of the approximation.
    const float a1 =  0.254829592;
    const float a2 = -0.284496736;
    const float a3 =  1.421413741;
    const float a4 = -1.453152027;
    const float a5 =  1.061405429;
    const float p  =  0.3275911;

    const float value = (float) in;
    // Remember the sign, then work with the magnitude only.
    const int sign = (value < 0) ? -1 : 1;
    const float x = fabs(value);

    const float t = 1.0/(1.0 + p*x);
    const float poly = (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t;
    const float y = 1.0 - poly*exp(-x*x);

    return T(sign*y);
}
|
||||||
template <typename T> METAL_FUNC T id(T in) { return in; }
|
template <typename T> METAL_FUNC T id(T in) { return in; }
|
||||||
|
// erf-based GELU: x * (1 + erf(x / sqrt(2))) / 2, converted back to T.
template <typename T> METAL_FUNC T gelu_erf(T x) {
    auto scaled = x * M_SQRT1_2_F;
    auto one_plus_erf = 1 + erf(scaled);
    return T(x * one_plus_erf / 2);
}
|
||||||
|
// tanh-based GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
// Inputs greater than 5 are returned unchanged without evaluating tanh.
template <typename T> METAL_FUNC T gelu(T x) {
    if (x > 5) {
        return x;
    }
    const T cubed = (x * x) * x;
    const T inner = x + static_cast<T>(0.044715) * cubed;
    // M_2_SQRTPI_F * M_SQRT1_2_F == sqrt(2 / pi)
    const T scaled = static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * inner;
    const T tanh_term = T(tanh(scaled));
    return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + tanh_term);
}
|
||||||
|
|
||||||
|
|
||||||
using namespace metal;
|
|
||||||
|
|
||||||
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
|
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
|
||||||
kernel void FN_NAME( \
|
kernel void FN_NAME( \
|
||||||
@ -63,8 +100,17 @@ UNARY_OP(sqr)
|
|||||||
UNARY_OP(sqrt)
|
UNARY_OP(sqrt)
|
||||||
UNARY_OP(neg)
|
UNARY_OP(neg)
|
||||||
UNARY_OP(exp)
|
UNARY_OP(exp)
|
||||||
|
UNARY_OP(log)
|
||||||
|
UNARY_OP(gelu)
|
||||||
|
UNARY_OP(ceil)
|
||||||
|
UNARY_OP(floor)
|
||||||
|
UNARY_OP(round)
|
||||||
|
UNARY_OP(gelu_erf)
|
||||||
|
UNARY_OP(erf)
|
||||||
UNARY(id, float, copy_float, copy_float_strided)
|
UNARY(id, float, copy_float, copy_float_strided)
|
||||||
UNARY(id, half, copy_half, copy_half_strided)
|
UNARY(id, half, copy_half, copy_half_strided)
|
||||||
|
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
|
||||||
|
UNARY(id, uint32_t, copy_u32, copy_u32_strided)
|
||||||
|
|
||||||
#if __METAL_VERSION__ >= 310
|
#if __METAL_VERSION__ >= 310
|
||||||
BFLOAT_UNARY_OP(cos)
|
BFLOAT_UNARY_OP(cos)
|
||||||
@ -73,6 +119,13 @@ BFLOAT_UNARY_OP(sqr)
|
|||||||
BFLOAT_UNARY_OP(sqrt)
|
BFLOAT_UNARY_OP(sqrt)
|
||||||
BFLOAT_UNARY_OP(neg)
|
BFLOAT_UNARY_OP(neg)
|
||||||
BFLOAT_UNARY_OP(exp)
|
BFLOAT_UNARY_OP(exp)
|
||||||
|
BFLOAT_UNARY_OP(log)
|
||||||
|
BFLOAT_UNARY_OP(gelu)
|
||||||
|
BFLOAT_UNARY_OP(ceil)
|
||||||
|
BFLOAT_UNARY_OP(floor)
|
||||||
|
BFLOAT_UNARY_OP(round)
|
||||||
|
BFLOAT_UNARY_OP(gelu_erf)
|
||||||
|
BFLOAT_UNARY_OP(erf)
|
||||||
|
|
||||||
UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)
|
UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)
|
||||||
#endif
|
#endif
|
||||||
|
@ -50,6 +50,7 @@ fn run_affine_bench<T: Clone>(device: &Device, kernels: &Kernels, v: &[T]) {
|
|||||||
&device,
|
&device,
|
||||||
command_buffer,
|
command_buffer,
|
||||||
&kernels,
|
&kernels,
|
||||||
|
"affine_float",
|
||||||
v.len(),
|
v.len(),
|
||||||
&input,
|
&input,
|
||||||
&mut output,
|
&mut output,
|
@ -147,7 +147,7 @@ fn run_unary_bench<T: Clone>(
|
|||||||
println!(
|
println!(
|
||||||
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
|
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
|
||||||
type_name::<T>().split("::").last().unwrap(),
|
type_name::<T>().split("::").last().unwrap(),
|
||||||
kernel_name.to_string(),
|
kernel_name.0,
|
||||||
v.len(),
|
v.len(),
|
||||||
iterations,
|
iterations,
|
||||||
total_time,
|
total_time,
|
||||||
@ -159,7 +159,7 @@ fn run_unary_bench<T: Clone>(
|
|||||||
let shape = vec![2, 5_000];
|
let shape = vec![2, 5_000];
|
||||||
let strides = vec![2, 1];
|
let strides = vec![2, 1];
|
||||||
let offset = 0;
|
let offset = 0;
|
||||||
for kernel_name in strided {
|
for kernel_name in &strided {
|
||||||
let total_time = autoreleasepool(|| {
|
let total_time = autoreleasepool(|| {
|
||||||
let command_buffer = command_queue.new_command_buffer();
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
@ -187,7 +187,7 @@ fn run_unary_bench<T: Clone>(
|
|||||||
println!(
|
println!(
|
||||||
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
|
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
|
||||||
type_name::<T>().split("::").last().unwrap(),
|
type_name::<T>().split("::").last().unwrap(),
|
||||||
kernel_name.to_string(),
|
kernel_name.0,
|
||||||
v.len(),
|
v.len(),
|
||||||
iterations,
|
iterations,
|
||||||
total_time,
|
total_time,
|
@ -11,7 +11,7 @@ readme = "README.md"
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
accelerate-src = { workspace = true, optional = true }
|
accelerate-src = { workspace = true, optional = true }
|
||||||
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
|
||||||
half = { workspace = true }
|
half = { workspace = true }
|
||||||
thiserror = { workspace = true }
|
thiserror = { workspace = true }
|
||||||
intel-mkl-src = { workspace = true, optional = true }
|
intel-mkl-src = { workspace = true, optional = true }
|
||||||
@ -19,6 +19,7 @@ num-traits = { workspace = true }
|
|||||||
rayon = { workspace = true }
|
rayon = { workspace = true }
|
||||||
safetensors = { workspace = true }
|
safetensors = { workspace = true }
|
||||||
serde = { workspace = true }
|
serde = { workspace = true }
|
||||||
|
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
anyhow = { workspace = true }
|
anyhow = { workspace = true }
|
||||||
@ -29,3 +30,4 @@ default = []
|
|||||||
accelerate = ["dep:accelerate-src", "candle/accelerate"]
|
accelerate = ["dep:accelerate-src", "candle/accelerate"]
|
||||||
cuda = ["candle/cuda"]
|
cuda = ["candle/cuda"]
|
||||||
mkl = ["dep:intel-mkl-src", "candle/mkl"]
|
mkl = ["dep:intel-mkl-src", "candle/mkl"]
|
||||||
|
metal = ["candle/metal", "dep:candle-metal-kernels"]
|
||||||
|
@ -6,7 +6,7 @@ extern crate intel_mkl_src;
|
|||||||
extern crate accelerate_src;
|
extern crate accelerate_src;
|
||||||
|
|
||||||
use candle::quantized::GgmlType;
|
use candle::quantized::GgmlType;
|
||||||
use candle::{CpuStorage, Device, Layout, Result, Shape, Tensor, D};
|
use candle::{CpuStorage, Device, Layout, Module, Result, Shape, Tensor, D};
|
||||||
use clap::{Parser, Subcommand};
|
use clap::{Parser, Subcommand};
|
||||||
|
|
||||||
const CHECK_CONV2D: bool = false;
|
const CHECK_CONV2D: bool = false;
|
||||||
|
@ -6,7 +6,6 @@ use serde::Deserialize;
|
|||||||
pub enum Activation {
|
pub enum Activation {
|
||||||
#[default]
|
#[default]
|
||||||
Gelu,
|
Gelu,
|
||||||
#[serde(rename = "gated-gelu")]
|
|
||||||
NewGelu,
|
NewGelu,
|
||||||
Relu,
|
Relu,
|
||||||
Relu2,
|
Relu2,
|
||||||
|
@ -9,7 +9,6 @@ pub struct Embedding {
|
|||||||
|
|
||||||
impl Embedding {
|
impl Embedding {
|
||||||
pub fn new(embeddings: Tensor, hidden_size: usize) -> Self {
|
pub fn new(embeddings: Tensor, hidden_size: usize) -> Self {
|
||||||
// todo!("Embedding {embeddings}");
|
|
||||||
Self {
|
Self {
|
||||||
embeddings,
|
embeddings,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
|
@ -201,6 +201,46 @@ impl candle::CustomOp1 for SoftmaxLastDim {
|
|||||||
};
|
};
|
||||||
Ok((dst, layout.shape().clone()))
|
Ok((dst, layout.shape().clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Metal forward pass for softmax over the last dimension.
///
/// Selects the kernel from the storage dtype (`F32`, `F16` or `BF16`), runs it into a
/// newly allocated buffer and returns that buffer together with the input shape.
///
/// # Errors
/// Returns an error when the dtype is unsupported, the input has no dimensions, the
/// layout is not contiguous with a zero start offset, or the kernel call fails.
#[cfg(feature = "metal")]
fn metal_fwd(
    &self,
    storage: &candle::MetalStorage,
    layout: &Layout,
) -> Result<(candle::MetalStorage, Shape)> {
    use candle::{backend::BackendStorage, DType};
    let device = storage.device();
    let command_buffer = device.command_buffer();
    let kernels = device.kernels();
    let name = match storage.dtype() {
        DType::F32 => "softmax_float",
        DType::F16 => "softmax_half",
        DType::BF16 => "softmax_bfloat",
        dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"),
    };

    // The kernel is handed the raw buffer plus an element count and no stride or
    // offset information, so only fully contiguous, zero-offset layouts are accepted.
    if !(layout.is_contiguous() && layout.start_offset() == 0) {
        candle::bail!("Non contiguous softmax-last-dim is not implemented");
    }

    // A rank-0 input has no last dimension; report it instead of underflowing an index.
    let last_dim = match layout.dims().last() {
        Some(&dim) => dim,
        None => candle::bail!("softmax-last-dim requires at least one dimension"),
    };
    let elem_count = layout.shape().elem_count();
    let mut output = device.new_buffer(elem_count, storage.dtype());
    // Surface kernel failures as errors rather than panicking in library code.
    if let Err(err) = candle_metal_kernels::call_last_softmax(
        device.metal_device(),
        &command_buffer,
        &kernels,
        name,
        elem_count,
        last_dim,
        storage.buffer(),
        &mut output,
    ) {
        candle::bail!("softmax-last-dim metal kernel failed: {err}");
    }
    let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
    Ok((newstorage, layout.shape().clone()))
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
|
pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-onnx"
|
name = "candle-onnx"
|
||||||
version = "0.3.0"
|
version = "0.3.1"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "ONNX support for Candle"
|
description = "ONNX support for Candle"
|
||||||
@ -10,8 +10,8 @@ categories = ["science"]
|
|||||||
license = "MIT OR Apache-2.0"
|
license = "MIT OR Apache-2.0"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
|
||||||
candle-nn = { path = "../candle-nn", version = "0.3.0" }
|
candle-nn = { path = "../candle-nn", version = "0.3.1" }
|
||||||
prost = "0.12.1"
|
prost = "0.12.1"
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
|
@ -741,6 +741,25 @@ pub fn simple_eval(
|
|||||||
let output = input.to_dtype(dtype)?;
|
let output = input.to_dtype(dtype)?;
|
||||||
values.insert(node.output[0].clone(), output);
|
values.insert(node.output[0].clone(), output);
|
||||||
}
|
}
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#CumSum
|
||||||
|
"CumSum" => {
|
||||||
|
let exclusive = get_attr_opt::<i64>(node, "exclusive")?
|
||||||
|
.copied()
|
||||||
|
.unwrap_or(0);
|
||||||
|
let reverse = get_attr_opt::<i64>(node, "reverse")?.copied().unwrap_or(0);
|
||||||
|
if exclusive != 0 {
|
||||||
|
bail!("only exclusive == 0 is supported in CumSum")
|
||||||
|
}
|
||||||
|
if reverse != 0 {
|
||||||
|
bail!("only reverse == 0 is supported in CumSum")
|
||||||
|
}
|
||||||
|
let input = get(&node.input[0])?;
|
||||||
|
let axis = get(&node.input[1])?
|
||||||
|
.to_dtype(DType::U32)?
|
||||||
|
.to_vec0::<u32>()?;
|
||||||
|
let output = input.cumsum(axis as usize)?;
|
||||||
|
values.insert(node.output[0].clone(), output);
|
||||||
|
}
|
||||||
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
746
candle-onnx/tests/ops.rs
Normal file
746
candle-onnx/tests/ops.rs
Normal file
@ -0,0 +1,746 @@
|
|||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
|
||||||
|
use candle::{Device, Result, Tensor};
|
||||||
|
use candle_onnx::onnx::{GraphProto, ModelProto, NodeProto, ValueInfoProto};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
const INPUT_X: &str = "x";
|
||||||
|
const INPUT_Y: &str = "y";
|
||||||
|
const OUTPUT_Z: &str = "z";
|
||||||
|
|
||||||
|
fn create_model_proto_with_graph(graph: Option<GraphProto>) -> ModelProto {
|
||||||
|
ModelProto {
|
||||||
|
metadata_props: vec![],
|
||||||
|
training_info: vec![],
|
||||||
|
functions: vec![],
|
||||||
|
ir_version: 0,
|
||||||
|
opset_import: vec![],
|
||||||
|
producer_name: "".to_string(),
|
||||||
|
producer_version: "".to_string(),
|
||||||
|
domain: "".to_string(),
|
||||||
|
model_version: 0,
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
graph,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
fn test_evaluation_fails_without_defined_graph() -> Result<()> {
    // A model without a graph must be rejected with a specific error message.
    let model_without_graph = create_model_proto_with_graph(None);
    let no_inputs = HashMap::<String, Tensor>::new();

    let outcome = candle_onnx::simple_eval(&model_without_graph, no_inputs);
    if let Err(err) = outcome {
        assert_eq!(err.to_string(), "no graph defined in proto");
    } else {
        panic!("Expected an error due to undefined graph");
    }

    Ok(())
}
|
||||||
|
|
||||||
|
// "Add"
|
||||||
|
#[test]
fn test_add_operation() -> Result<()> {
    // Single-node graph: z = x + y.
    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: "Add".to_string(),
            input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            ..Default::default()
        }],
        output: vec![ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            ..Default::default()
        }],
        ..Default::default()
    }));

    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?);
    inputs.insert(INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
    // Read the first element directly; no intermediate copy of the vector is needed.
    let first = z
        .to_vec1::<f64>()?
        .first()
        .copied()
        .expect("Failed to get first element");
    assert_eq!(first, 4.0f64);

    Ok(())
}
|
||||||
|
|
||||||
|
// "Sub"
|
||||||
|
#[test]
fn test_sub_operation() -> Result<()> {
    // Single-node graph: z = x - y.
    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: "Sub".to_string(),
            input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            ..Default::default()
        }],
        output: vec![ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            ..Default::default()
        }],
        ..Default::default()
    }));

    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?);
    inputs.insert(INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
    // Read the first element directly; no intermediate copy of the vector is needed.
    let first = z
        .to_vec1::<f64>()?
        .first()
        .copied()
        .expect("Failed to get first element");
    assert_eq!(first, 0.0f64);

    Ok(())
}
|
||||||
|
|
||||||
|
// "Mul"
|
||||||
|
#[test]
fn test_mul_operation() -> Result<()> {
    // Single-node graph: z = x * y.
    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: "Mul".to_string(),
            input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            ..Default::default()
        }],
        output: vec![ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            ..Default::default()
        }],
        ..Default::default()
    }));

    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?);
    inputs.insert(INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
    // Read the first element directly; no intermediate copy of the vector is needed.
    let first = z
        .to_vec1::<f64>()?
        .first()
        .copied()
        .expect("Failed to get first element");
    assert_eq!(first, 4.0f64);

    Ok(())
}
|
||||||
|
|
||||||
|
// "Div"
|
||||||
|
#[test]
fn test_div_operation() -> Result<()> {
    // Single-node graph: z = x / y.
    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: "Div".to_string(),
            input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            ..Default::default()
        }],
        output: vec![ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            ..Default::default()
        }],
        ..Default::default()
    }));

    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?);
    inputs.insert(INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
    // Read the first element directly; no intermediate copy of the vector is needed.
    let first = z
        .to_vec1::<f64>()?
        .first()
        .copied()
        .expect("Failed to get first element");

    assert_eq!(first, 1.0f64);

    Ok(())
}
|
||||||
|
|
||||||
|
// "Equal"
|
||||||
|
#[test]
fn test_equal_operation() -> Result<()> {
    // Single-node graph: z = (x == y).
    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: "Equal".to_string(),
            input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            ..Default::default()
        }],
        output: vec![ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            ..Default::default()
        }],
        ..Default::default()
    }));

    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?);
    inputs.insert(INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
    // Index the vector returned by `to_vec1` directly instead of copying it first.
    let first = z.to_dtype(candle::DType::U8)?.to_vec1::<u8>()?[0];
    assert_eq!(first, 1);

    Ok(())
}
|
||||||
|
|
||||||
|
// "Not"
|
||||||
|
#[test]
fn test_not_operation() -> Result<()> {
    // Single-node graph: z = !x, fed with a zero input.
    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: "Not".to_string(),
            input: vec![INPUT_X.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            ..Default::default()
        }],
        output: vec![ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            ..Default::default()
        }],
        ..Default::default()
    }));

    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), Tensor::new(&[0.], &Device::Cpu)?);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
    // Index the vector returned by `to_vec1` directly instead of copying it first.
    let first = z.to_dtype(candle::DType::U8)?.to_vec1::<u8>()?[0];
    assert_eq!(first, 1);

    Ok(())
}
|
||||||
|
|
||||||
|
// "MatMul"
|
||||||
|
#[test]
fn test_matmul_operation() -> Result<()> {
    // Single-node graph: z = x @ y on two 2x2 f32 matrices.
    let matmul_node = NodeProto {
        op_type: "MatMul".to_string(),
        input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
        output: vec![OUTPUT_Z.to_string()],
        ..Default::default()
    };
    let graph = GraphProto {
        node: vec![matmul_node],
        output: vec![ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            ..Default::default()
        }],
        ..Default::default()
    };
    let model = create_model_proto_with_graph(Some(graph));

    let lhs = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu)?;
    let rhs = Tensor::from_vec(vec![5.0f32, 6.0f32, 7.0f32, 8.0f32], &[2, 2], &Device::Cpu)?;
    let inputs: HashMap<String, Tensor> =
        HashMap::from([(INPUT_X.to_string(), lhs), (INPUT_Y.to_string(), rhs)]);

    let outputs = candle_onnx::simple_eval(&model, inputs)?;
    assert_eq!(outputs.len(), 1);

    let product = outputs
        .get(OUTPUT_Z)
        .expect("Output 'z' not found")
        .to_vec2::<f32>()?;
    assert_eq!(product, vec![vec![19.0, 22.0], vec![43.0, 50.0]]);

    Ok(())
}
|
||||||
|
|
||||||
|
// "Reshape"
|
||||||
|
#[test]
fn test_reshape_operation() -> Result<()> {
    // Single-node graph: z = reshape(x, y), flattening a 2x2 matrix to 4 elements.
    let value_info = |name: &str| ValueInfoProto {
        name: name.to_string(),
        ..Default::default()
    };
    let reshape_node = NodeProto {
        op_type: "Reshape".to_string(),
        input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
        output: vec![OUTPUT_Z.to_string()],
        ..Default::default()
    };
    let model = create_model_proto_with_graph(Some(GraphProto {
        node: vec![reshape_node],
        input: vec![value_info(INPUT_X), value_info(INPUT_Y)],
        output: vec![value_info(OUTPUT_Z)],
        ..Default::default()
    }));

    let data = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu)?;
    // The target shape is supplied as an i64 tensor.
    let target_shape = Tensor::from_vec(vec![4i64], &[1], &Device::Cpu)?;
    let inputs: HashMap<String, Tensor> = HashMap::from([
        (INPUT_X.to_string(), data),
        (INPUT_Y.to_string(), target_shape),
    ]);

    let outputs = candle_onnx::simple_eval(&model, inputs)?;
    assert_eq!(outputs.len(), 1);

    let flattened = outputs
        .get(OUTPUT_Z)
        .expect("Output 'z' not found")
        .to_vec1::<f32>()?;
    assert_eq!(flattened, vec![1.0, 2.0, 3.0, 4.0]);

    Ok(())
}
|
||||||
|
|
||||||
|
// "LogSoftmax"
|
||||||
|
#[test]
fn test_logsoftmax_operation() -> Result<()> {
    // Single-node LogSoftmax graph with no `axis` attribute, applied to a 2x2 matrix.
    let value_info = |name: &str| ValueInfoProto {
        name: name.to_string(),
        ..Default::default()
    };
    let node = NodeProto {
        op_type: "LogSoftmax".to_string(),
        input: vec![INPUT_X.to_string()],
        output: vec![OUTPUT_Z.to_string()],
        ..Default::default()
    };
    let model = create_model_proto_with_graph(Some(GraphProto {
        node: vec![node],
        input: vec![value_info(INPUT_X), value_info(INPUT_Y)],
        output: vec![value_info(OUTPUT_Z)],
        ..Default::default()
    }));

    let x = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu)?;
    let inputs: HashMap<String, Tensor> = HashMap::from([(INPUT_X.to_string(), x)]);

    let outputs = candle_onnx::simple_eval(&model, inputs)?;
    assert_eq!(outputs.len(), 1);

    let rows = outputs
        .get(OUTPUT_Z)
        .expect("Output 'z' not found")
        .to_vec2::<f32>()?;

    // NOTE(review): these expected values are identical to the ones asserted by the
    // Softmax test (each row sums to 1) rather than their logarithms; confirm whether
    // LogSoftmax without an `axis` attribute is meant to evaluate to plain softmax.
    assert_eq!(
        rows,
        vec![vec![0.26894143, 0.7310586], vec![0.26894143, 0.7310586]]
    );

    Ok(())
}
|
||||||
|
|
||||||
|
// "Softmax"
|
||||||
|
#[test]
fn test_softmax_operation() -> Result<()> {
    // Single-node Softmax graph with no `axis` attribute, applied to a 2x2 matrix.
    let value_info = |name: &str| ValueInfoProto {
        name: name.to_string(),
        ..Default::default()
    };
    let node = NodeProto {
        op_type: "Softmax".to_string(),
        input: vec![INPUT_X.to_string()],
        output: vec![OUTPUT_Z.to_string()],
        ..Default::default()
    };
    let model = create_model_proto_with_graph(Some(GraphProto {
        node: vec![node],
        input: vec![value_info(INPUT_X), value_info(INPUT_Y)],
        output: vec![value_info(OUTPUT_Z)],
        ..Default::default()
    }));

    let x = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu)?;
    let inputs: HashMap<String, Tensor> = HashMap::from([(INPUT_X.to_string(), x)]);

    let outputs = candle_onnx::simple_eval(&model, inputs)?;
    assert_eq!(outputs.len(), 1);

    let rows = outputs
        .get(OUTPUT_Z)
        .expect("Output 'z' not found")
        .to_vec2::<f32>()?;

    // Both rows differ internally by exactly 1, so they share the same softmax values.
    assert_eq!(
        rows,
        vec![vec![0.26894143, 0.7310586], vec![0.26894143, 0.7310586]]
    );

    Ok(())
}
|
||||||
|
|
||||||
|
// "Transpose"
|
||||||
|
// Evaluates a one-node "Transpose" graph on a 2x2 input and checks that rows and columns swap.
#[test]
fn test_transpose_operation() -> Result<()> {
    // Every value-info entry in this graph only carries a name.
    let value_info = |name: &str| ValueInfoProto {
        name: name.to_string(),
        doc_string: "".to_string(),
        r#type: None,
    };

    // The single node under test: Z = Transpose(X), with no attributes set.
    let transpose_node = NodeProto {
        op_type: "Transpose".to_string(),
        domain: "".to_string(),
        attribute: vec![],
        input: vec![INPUT_X.to_string()],
        output: vec![OUTPUT_Z.to_string()],
        name: "".to_string(),
        doc_string: "".to_string(),
    };

    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![transpose_node],
        name: "".to_string(),
        initializer: vec![],
        // INPUT_Y is declared on the graph but never fed nor consumed by the node.
        input: vec![value_info(INPUT_X), value_info(INPUT_Y)],
        output: vec![value_info(OUTPUT_Z)],
        value_info: vec![],
        doc_string: "".to_string(),
        sparse_initializer: vec![],
        quantization_annotation: vec![],
    }));

    let x = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu)?;
    let inputs: HashMap<String, Tensor> = HashMap::from([(INPUT_X.to_string(), x)]);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
    let results = z.to_vec2::<f32>()?;

    // [[1, 2], [3, 4]] transposed.
    assert_eq!(results, vec![vec![1.0, 3.0], vec![2.0, 4.0]]);

    Ok(())
}
|
||||||
|
|
||||||
|
// "Dropout"
|
||||||
|
// Evaluates a one-node "Dropout" graph and checks that the output equals the input.
#[test]
fn test_dropout_operation() -> Result<()> {
    // Every value-info entry in this graph only carries a name.
    let value_info = |name: &str| ValueInfoProto {
        name: name.to_string(),
        doc_string: "".to_string(),
        r#type: None,
    };

    // The single node under test: Z = Dropout(X), with no attributes set.
    let dropout_node = NodeProto {
        op_type: "Dropout".to_string(),
        domain: "".to_string(),
        attribute: vec![],
        input: vec![INPUT_X.to_string()],
        output: vec![OUTPUT_Z.to_string()],
        name: "".to_string(),
        doc_string: "".to_string(),
    };

    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![dropout_node],
        name: "".to_string(),
        initializer: vec![],
        // INPUT_Y is declared on the graph but never fed nor consumed by the node.
        input: vec![value_info(INPUT_X), value_info(INPUT_Y)],
        output: vec![value_info(OUTPUT_Z)],
        value_info: vec![],
        doc_string: "".to_string(),
        sparse_initializer: vec![],
        quantization_annotation: vec![],
    }));

    let x = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu)?;
    let inputs: HashMap<String, Tensor> = HashMap::from([(INPUT_X.to_string(), x)]);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
    let results = z.to_vec2::<f32>()?;

    // The evaluated output is expected to be the unchanged input.
    assert_eq!(results, vec![vec![1.0, 2.0], vec![3.0, 4.0]]);

    Ok(())
}
|
||||||
|
|
||||||
|
// Below are ops that are implemented but not tested yet
|
||||||
|
|
||||||
|
// "MaxPool"
|
||||||
|
// #[test]
|
||||||
|
|
||||||
|
// "AveragePool"
|
||||||
|
// #[test]
|
||||||
|
|
||||||
|
// "BatchNormalization"
|
||||||
|
// #[test]
|
||||||
|
|
||||||
|
// "Squeeze"
|
||||||
|
// #[test]
|
||||||
|
|
||||||
|
// "ConstantOfShape"
|
||||||
|
// #[test]
|
||||||
|
|
||||||
|
// "Unsqueeze"
|
||||||
|
// #[test]
|
||||||
|
|
||||||
|
// "Clip"
|
||||||
|
// #[test]
|
||||||
|
|
||||||
|
// "Gather"
|
||||||
|
// #[test]
|
||||||
|
|
||||||
|
// "Shape"
|
||||||
|
// #[test]
|
||||||
|
|
||||||
|
// "Conv"
|
||||||
|
// #[test]
|
||||||
|
|
||||||
|
// "Concat"
|
||||||
|
// #[test]
|
||||||
|
|
||||||
|
// "Abs"
|
||||||
|
// #[test]
|
||||||
|
|
||||||
|
// "Cos"
|
||||||
|
// #[test]
|
||||||
|
|
||||||
|
// "Sin"
|
||||||
|
// #[test]
|
||||||
|
|
||||||
|
// "Neg"
|
||||||
|
// #[test]
|
||||||
|
|
||||||
|
// "Erf"
|
||||||
|
// #[test]
|
||||||
|
|
||||||
|
// "Tanh"
|
||||||
|
// #[test]
|
||||||
|
|
||||||
|
// "Sigmoid"
|
||||||
|
// #[test]
|
||||||
|
|
||||||
|
// "Gelu"
|
||||||
|
// #[test]
|
||||||
|
|
||||||
|
// "Relu"
|
||||||
|
// #[test]
|
||||||
|
|
||||||
|
// "Constant"
|
||||||
|
// #[test]
|
||||||
|
|
||||||
|
// "Cast"
|
||||||
|
// #[test]
|
@ -15,9 +15,9 @@ crate-type = ["cdylib"]
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
accelerate-src = { workspace = true, optional = true }
|
accelerate-src = { workspace = true, optional = true }
|
||||||
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
|
||||||
candle-nn = { path = "../candle-nn", version = "0.3.0" }
|
candle-nn = { path = "../candle-nn", version = "0.3.1" }
|
||||||
candle-onnx = {path= "../candle-onnx", version = "0.3.0", optional = true}
|
candle-onnx = {path= "../candle-onnx", version = "0.3.1", optional = true}
|
||||||
half = { workspace = true }
|
half = { workspace = true }
|
||||||
intel-mkl-src = { workspace = true, optional = true }
|
intel-mkl-src = { workspace = true, optional = true }
|
||||||
pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] }
|
pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] }
|
||||||
|
@ -17,7 +17,7 @@ extern crate intel_mkl_src;
|
|||||||
#[cfg(feature = "accelerate")]
|
#[cfg(feature = "accelerate")]
|
||||||
extern crate accelerate_src;
|
extern crate accelerate_src;
|
||||||
|
|
||||||
use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};
|
use ::candle::{quantized::QTensor, DType, Device, Module, Tensor, WithDType};
|
||||||
|
|
||||||
mod utils;
|
mod utils;
|
||||||
use utils::wrap_err;
|
use utils::wrap_err;
|
||||||
|
@ -12,15 +12,16 @@ readme = "README.md"
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
accelerate-src = { workspace = true, optional = true }
|
accelerate-src = { workspace = true, optional = true }
|
||||||
byteorder = { workspace = true }
|
byteorder = { workspace = true }
|
||||||
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
|
||||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
|
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.1", optional = true }
|
||||||
candle-nn = { path = "../candle-nn", version = "0.3.0" }
|
candle-nn = { path = "../candle-nn", version = "0.3.1" }
|
||||||
intel-mkl-src = { workspace = true, optional = true }
|
intel-mkl-src = { workspace = true, optional = true }
|
||||||
num-traits = { workspace = true }
|
num-traits = { workspace = true }
|
||||||
rand = { workspace = true }
|
rand = { workspace = true }
|
||||||
rayon = { workspace = true }
|
rayon = { workspace = true }
|
||||||
serde = { workspace = true }
|
serde = { workspace = true }
|
||||||
serde_json = { workspace = true }
|
serde_json = { workspace = true }
|
||||||
|
serde_plain = { workspace = true }
|
||||||
tracing = { workspace = true }
|
tracing = { workspace = true }
|
||||||
wav = { workspace = true }
|
wav = { workspace = true }
|
||||||
|
|
||||||
|
342
candle-transformers/src/models/distilbert.rs
Normal file
342
candle-transformers/src/models/distilbert.rs
Normal file
@ -0,0 +1,342 @@
|
|||||||
|
use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
|
||||||
|
use candle::{DType, Device, Result, Tensor};
|
||||||
|
use candle_nn::{Embedding, Module, VarBuilder};
|
||||||
|
use serde::Deserialize;
|
||||||
|
|
||||||
|
pub const DTYPE: DType = DType::F32;
|
||||||
|
|
||||||
|
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
|
||||||
|
let shape = mask.shape();
|
||||||
|
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
|
||||||
|
let m = mask.where_cond(&on_true, on_false)?;
|
||||||
|
Ok(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
enum HiddenAct {
|
||||||
|
Gelu,
|
||||||
|
Relu,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct HiddenActLayer {
|
||||||
|
act: HiddenAct,
|
||||||
|
span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HiddenActLayer {
|
||||||
|
fn new(act: HiddenAct) -> Self {
|
||||||
|
let span = tracing::span!(tracing::Level::TRACE, "hidden-act");
|
||||||
|
Self { act, span }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for HiddenActLayer {
|
||||||
|
fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
match self.act {
|
||||||
|
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
|
||||||
|
HiddenAct::Gelu => xs.gelu(),
|
||||||
|
HiddenAct::Relu => xs.relu(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
enum PositionEmbeddingType {
|
||||||
|
#[default]
|
||||||
|
Absolute,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||||
|
pub struct Config {
|
||||||
|
vocab_size: usize,
|
||||||
|
dim: usize,
|
||||||
|
n_layers: usize,
|
||||||
|
n_heads: usize,
|
||||||
|
hidden_dim: usize,
|
||||||
|
activation: HiddenAct,
|
||||||
|
max_position_embeddings: usize,
|
||||||
|
initializer_range: f64,
|
||||||
|
pad_token_id: usize,
|
||||||
|
#[serde(default)]
|
||||||
|
position_embedding_type: PositionEmbeddingType,
|
||||||
|
#[serde(default)]
|
||||||
|
use_cache: bool,
|
||||||
|
model_type: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for Config {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
vocab_size: 30522,
|
||||||
|
dim: 768,
|
||||||
|
n_layers: 12,
|
||||||
|
n_heads: 12,
|
||||||
|
hidden_dim: 3072,
|
||||||
|
activation: HiddenAct::Gelu,
|
||||||
|
max_position_embeddings: 512,
|
||||||
|
initializer_range: 0.02,
|
||||||
|
pad_token_id: 0,
|
||||||
|
position_embedding_type: PositionEmbeddingType::Absolute,
|
||||||
|
use_cache: true,
|
||||||
|
model_type: Some("distilbert".to_string()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Embeddings {
|
||||||
|
word_embeddings: Embedding,
|
||||||
|
position_embeddings: Embedding,
|
||||||
|
layer_norm: LayerNorm,
|
||||||
|
span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Embeddings {
|
||||||
|
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
|
let word_embeddings =
|
||||||
|
candle_nn::embedding(config.vocab_size, config.dim, vb.pp("word_embeddings"))?;
|
||||||
|
let position_embeddings = candle_nn::embedding(
|
||||||
|
config.max_position_embeddings,
|
||||||
|
config.dim,
|
||||||
|
vb.pp("position_embeddings"),
|
||||||
|
)?;
|
||||||
|
let layer_norm = layer_norm(config.dim, 1e-12, vb.pp("LayerNorm"))?;
|
||||||
|
Ok(Self {
|
||||||
|
word_embeddings,
|
||||||
|
position_embeddings,
|
||||||
|
layer_norm,
|
||||||
|
span: tracing::span!(tracing::Level::TRACE, "embeddings"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
let (_bsize, seq_len) = input_ids.dims2()?;
|
||||||
|
let input_embeddings = self.word_embeddings.forward(input_ids)?;
|
||||||
|
let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
|
||||||
|
let position_ids = Tensor::new(&position_ids[..], input_ids.device())?;
|
||||||
|
let embeddings =
|
||||||
|
input_embeddings.broadcast_add(&self.position_embeddings.forward(&position_ids)?)?;
|
||||||
|
|
||||||
|
let embeddings = self.layer_norm.forward(&embeddings)?;
|
||||||
|
Ok(embeddings)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct MultiHeadSelfAttention {
|
||||||
|
q_lin: Linear,
|
||||||
|
k_lin: Linear,
|
||||||
|
v_lin: Linear,
|
||||||
|
out_lin: Linear,
|
||||||
|
n_heads: usize,
|
||||||
|
attention_head_size: usize,
|
||||||
|
span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MultiHeadSelfAttention {
|
||||||
|
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
|
let attention_head_size = config.dim / config.n_heads;
|
||||||
|
let all_head_size = config.n_heads * attention_head_size;
|
||||||
|
let dim = config.dim;
|
||||||
|
let q_lin = linear(dim, all_head_size, vb.pp("q_lin"))?;
|
||||||
|
let v_lin = linear(dim, all_head_size, vb.pp("v_lin"))?;
|
||||||
|
let k_lin = linear(dim, all_head_size, vb.pp("k_lin"))?;
|
||||||
|
let out_lin = linear(all_head_size, dim, vb.pp("out_lin"))?;
|
||||||
|
Ok(Self {
|
||||||
|
q_lin,
|
||||||
|
k_lin,
|
||||||
|
v_lin,
|
||||||
|
out_lin,
|
||||||
|
n_heads: config.n_heads,
|
||||||
|
attention_head_size,
|
||||||
|
span: tracing::span!(tracing::Level::TRACE, "attention"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MultiHeadSelfAttention {
|
||||||
|
fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
let (bs, q_length, _dim) = hidden_states.dims3()?;
|
||||||
|
|
||||||
|
let dim_per_head = self.attention_head_size;
|
||||||
|
let q = self.q_lin.forward(hidden_states)?;
|
||||||
|
let k = self.k_lin.forward(hidden_states)?;
|
||||||
|
let v = self.v_lin.forward(hidden_states)?;
|
||||||
|
|
||||||
|
let q = q
|
||||||
|
.reshape((bs, q_length, self.n_heads, dim_per_head))?
|
||||||
|
.transpose(1, 2)?;
|
||||||
|
let k = k
|
||||||
|
.reshape((bs, q_length, self.n_heads, dim_per_head))?
|
||||||
|
.transpose(1, 2)?;
|
||||||
|
let v = v
|
||||||
|
.reshape((bs, q_length, self.n_heads, dim_per_head))?
|
||||||
|
.transpose(1, 2)?;
|
||||||
|
|
||||||
|
let q: Tensor = (q / (dim_per_head as f64).sqrt())?;
|
||||||
|
let scores = q.matmul(&k.transpose(2, 3)?.contiguous()?)?;
|
||||||
|
let mask = attention_mask.broadcast_as(scores.shape())?;
|
||||||
|
|
||||||
|
let scores = masked_fill(&scores.to_dtype(DType::F32)?, &mask, f32::NEG_INFINITY)?;
|
||||||
|
let weights = candle_nn::ops::softmax(&scores, candle::D::Minus1)?;
|
||||||
|
|
||||||
|
let context = weights.matmul(&v.contiguous()?)?;
|
||||||
|
let context = context
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.reshape((bs, q_length, self.n_heads * dim_per_head))?
|
||||||
|
.contiguous()?;
|
||||||
|
let context = self.out_lin.forward(&context)?;
|
||||||
|
|
||||||
|
Ok(context)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::upper_case_acronyms)]
|
||||||
|
struct FFN {
|
||||||
|
lin1: Linear,
|
||||||
|
lin2: Linear,
|
||||||
|
activation: HiddenActLayer,
|
||||||
|
span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FFN {
|
||||||
|
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
|
let lin1 = linear(config.dim, config.hidden_dim, vb.pp("lin1"))?;
|
||||||
|
let lin2 = linear(config.hidden_dim, config.dim, vb.pp("lin2"))?;
|
||||||
|
Ok(Self {
|
||||||
|
lin1,
|
||||||
|
lin2,
|
||||||
|
activation: HiddenActLayer::new(config.activation),
|
||||||
|
span: tracing::span!(tracing::Level::TRACE, "ffn"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for FFN {
|
||||||
|
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
hidden_states
|
||||||
|
.apply(&self.lin1)?
|
||||||
|
.apply(&self.activation)?
|
||||||
|
.apply(&self.lin2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct TransformerBlock {
|
||||||
|
attention: MultiHeadSelfAttention,
|
||||||
|
sa_layer_norm: LayerNorm,
|
||||||
|
ffn: FFN,
|
||||||
|
output_layer_norm: LayerNorm,
|
||||||
|
span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TransformerBlock {
|
||||||
|
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
|
let attention = MultiHeadSelfAttention::load(vb.pp("attention"), config)?;
|
||||||
|
let sa_layer_norm = layer_norm(config.dim, 1e-12, vb.pp("sa_layer_norm"))?;
|
||||||
|
let ffn = FFN::load(vb.pp("ffn"), config)?;
|
||||||
|
let output_layer_norm = layer_norm(config.dim, 1e-12, vb.pp("output_layer_norm"))?;
|
||||||
|
Ok(Self {
|
||||||
|
attention,
|
||||||
|
sa_layer_norm,
|
||||||
|
ffn,
|
||||||
|
output_layer_norm,
|
||||||
|
span: tracing::span!(tracing::Level::TRACE, "layer"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TransformerBlock {
|
||||||
|
fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
let sa_output = self.attention.forward(hidden_states, attention_mask)?;
|
||||||
|
// TODO: Support cross-attention?
|
||||||
|
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
|
||||||
|
// TODO: Support something similar to `apply_chunking_to_forward`?
|
||||||
|
let sa_output = sa_output.broadcast_add(hidden_states)?;
|
||||||
|
let sa_output = self.sa_layer_norm.forward(&sa_output)?;
|
||||||
|
|
||||||
|
let ffn_output = self.ffn.forward(&sa_output)?;
|
||||||
|
let ffn_output = (&ffn_output + sa_output)?;
|
||||||
|
let output = self.output_layer_norm.forward(&ffn_output)?;
|
||||||
|
Ok(output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
|
||||||
|
struct Transformer {
|
||||||
|
layers: Vec<TransformerBlock>,
|
||||||
|
span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Transformer {
|
||||||
|
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
|
let layers = (0..config.n_layers)
|
||||||
|
.map(|index| TransformerBlock::load(vb.pp(&format!("layer.{index}")), config))
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
let span = tracing::span!(tracing::Level::TRACE, "encoder");
|
||||||
|
Ok(Transformer { layers, span })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Transformer {
|
||||||
|
fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
let mut hidden_states = hidden_states.clone();
|
||||||
|
// Use a loop rather than a fold as it's easier to modify when adding debug/...
|
||||||
|
for layer in self.layers.iter() {
|
||||||
|
hidden_states = layer.forward(&hidden_states, attention_mask)?;
|
||||||
|
}
|
||||||
|
Ok(hidden_states)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct DistilBertModel {
|
||||||
|
embeddings: Embeddings,
|
||||||
|
transformer: Transformer,
|
||||||
|
pub device: Device,
|
||||||
|
span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DistilBertModel {
|
||||||
|
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
|
let (embeddings, transformer) = match (
|
||||||
|
Embeddings::load(vb.pp("embeddings"), config),
|
||||||
|
Transformer::load(vb.pp("transformer"), config),
|
||||||
|
) {
|
||||||
|
(Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
|
||||||
|
(Err(err), _) | (_, Err(err)) => {
|
||||||
|
if let Some(model_type) = &config.model_type {
|
||||||
|
if let (Ok(embeddings), Ok(encoder)) = (
|
||||||
|
Embeddings::load(vb.pp(&format!("{model_type}.embeddings")), config),
|
||||||
|
Transformer::load(vb.pp(&format!("{model_type}.transformer")), config),
|
||||||
|
) {
|
||||||
|
(embeddings, encoder)
|
||||||
|
} else {
|
||||||
|
return Err(err);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return Err(err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(Self {
|
||||||
|
embeddings,
|
||||||
|
transformer,
|
||||||
|
device: vb.device().clone(),
|
||||||
|
span: tracing::span!(tracing::Level::TRACE, "model"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&self, input_ids: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
let embedding_output = self.embeddings.forward(input_ids)?;
|
||||||
|
let sequence_output = self
|
||||||
|
.transformer
|
||||||
|
.forward(&embedding_output, attention_mask)?;
|
||||||
|
Ok(sequence_output)
|
||||||
|
}
|
||||||
|
}
|
@ -156,7 +156,6 @@ impl CausalSelfAttention {
|
|||||||
let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
|
let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
|
||||||
let x0 = x.narrow(D::Minus1, 0, 1)?;
|
let x0 = x.narrow(D::Minus1, 0, 1)?;
|
||||||
let x1 = x.narrow(D::Minus1, 1, 1)?;
|
let x1 = x.narrow(D::Minus1, 1, 1)?;
|
||||||
todo!("X {x1}");
|
|
||||||
let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
|
let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
|
||||||
let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
|
let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
|
||||||
let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, h, n_embd))?;
|
let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, h, n_embd))?;
|
||||||
@ -174,7 +173,6 @@ impl CausalSelfAttention {
|
|||||||
let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
|
let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
|
||||||
|
|
||||||
let q = self.apply_rotary_emb(&q, index_pos)?;
|
let q = self.apply_rotary_emb(&q, index_pos)?;
|
||||||
todo!("X {q}");
|
|
||||||
let mut k = self.apply_rotary_emb(&k, index_pos)?;
|
let mut k = self.apply_rotary_emb(&k, index_pos)?;
|
||||||
|
|
||||||
if self.cache.use_kv_cache {
|
if self.cache.use_kv_cache {
|
||||||
@ -297,7 +295,6 @@ impl Block {
|
|||||||
let residual = x;
|
let residual = x;
|
||||||
let x = self.rms_1.forward(x)?;
|
let x = self.rms_1.forward(x)?;
|
||||||
let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
|
let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
|
||||||
todo!("---X {}", x);
|
|
||||||
let residual = &x;
|
let residual = &x;
|
||||||
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
|
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
|
||||||
Ok(x)
|
Ok(x)
|
||||||
@ -330,7 +327,6 @@ impl Llama {
|
|||||||
pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||||
let (_b_sz, _seq_len) = x.dims2()?;
|
let (_b_sz, _seq_len) = x.dims2()?;
|
||||||
let mut x = self.wte.forward(x)?;
|
let mut x = self.wte.forward(x)?;
|
||||||
//println!("Embeddings {}", self.wte.embeddings());
|
|
||||||
for (block_idx, block) in self.blocks.iter().enumerate() {
|
for (block_idx, block) in self.blocks.iter().enumerate() {
|
||||||
x = block.forward(&x, index_pos, block_idx)?;
|
x = block.forward(&x, index_pos, block_idx)?;
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,7 @@ pub mod blip;
|
|||||||
pub mod blip_text;
|
pub mod blip_text;
|
||||||
pub mod convmixer;
|
pub mod convmixer;
|
||||||
pub mod dinov2;
|
pub mod dinov2;
|
||||||
|
pub mod distilbert;
|
||||||
pub mod efficientnet;
|
pub mod efficientnet;
|
||||||
pub mod falcon;
|
pub mod falcon;
|
||||||
pub mod jina_bert;
|
pub mod jina_bert;
|
||||||
@ -29,8 +30,10 @@ pub mod segment_anything;
|
|||||||
pub mod stable_diffusion;
|
pub mod stable_diffusion;
|
||||||
pub mod stable_lm;
|
pub mod stable_lm;
|
||||||
pub mod t5;
|
pub mod t5;
|
||||||
|
pub mod trocr;
|
||||||
pub mod vgg;
|
pub mod vgg;
|
||||||
pub mod vit;
|
pub mod vit;
|
||||||
pub mod whisper;
|
pub mod whisper;
|
||||||
pub mod with_tracing;
|
pub mod with_tracing;
|
||||||
pub mod wuerstchen;
|
pub mod wuerstchen;
|
||||||
|
pub mod yi;
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
// T5 Text Model, quantized version
|
// T5 Text Model, quantized version
|
||||||
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
|
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
|
||||||
|
|
||||||
|
use crate::models::t5::{deserialize_feed_forward_proj_activation, ActivationWithOptionalGating};
|
||||||
use crate::models::with_tracing::QMatMul;
|
use crate::models::with_tracing::QMatMul;
|
||||||
use crate::quantized_nn::Embedding;
|
use crate::quantized_nn::Embedding;
|
||||||
pub use crate::quantized_var_builder::VarBuilder;
|
pub use crate::quantized_var_builder::VarBuilder;
|
||||||
@ -54,8 +55,8 @@ pub struct Config {
|
|||||||
dropout_rate: f64,
|
dropout_rate: f64,
|
||||||
layer_norm_epsilon: f64,
|
layer_norm_epsilon: f64,
|
||||||
initializer_factor: f64,
|
initializer_factor: f64,
|
||||||
#[serde(default)]
|
#[serde(default, deserialize_with = "deserialize_feed_forward_proj_activation")]
|
||||||
feed_forward_proj: Activation,
|
pub feed_forward_proj: ActivationWithOptionalGating,
|
||||||
#[serde(default = "default_tie_word_embeddings")]
|
#[serde(default = "default_tie_word_embeddings")]
|
||||||
tie_word_embeddings: bool,
|
tie_word_embeddings: bool,
|
||||||
#[serde(default = "default_is_decoder")]
|
#[serde(default = "default_is_decoder")]
|
||||||
@ -83,7 +84,10 @@ impl Default for Config {
|
|||||||
dropout_rate: 0.1,
|
dropout_rate: 0.1,
|
||||||
layer_norm_epsilon: 1e-6,
|
layer_norm_epsilon: 1e-6,
|
||||||
initializer_factor: 1.0,
|
initializer_factor: 1.0,
|
||||||
feed_forward_proj: Activation::Relu,
|
feed_forward_proj: ActivationWithOptionalGating {
|
||||||
|
gated: false,
|
||||||
|
activation: Activation::Relu,
|
||||||
|
},
|
||||||
tie_word_embeddings: true,
|
tie_word_embeddings: true,
|
||||||
is_decoder: false,
|
is_decoder: false,
|
||||||
is_encoder_decoder: true,
|
is_encoder_decoder: true,
|
||||||
@ -176,7 +180,7 @@ impl T5DenseGatedActDense {
|
|||||||
wi_0,
|
wi_0,
|
||||||
wi_1,
|
wi_1,
|
||||||
wo,
|
wo,
|
||||||
act: Activation::NewGelu,
|
act: cfg.feed_forward_proj.activation,
|
||||||
span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"),
|
span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -205,7 +209,7 @@ impl T5LayerFF {
|
|||||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let layer_norm =
|
let layer_norm =
|
||||||
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
|
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
|
||||||
let (dense_act, gated_dense_act) = if cfg.feed_forward_proj == Activation::NewGelu {
|
let (dense_act, gated_dense_act) = if cfg.feed_forward_proj.gated {
|
||||||
(
|
(
|
||||||
None,
|
None,
|
||||||
Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?),
|
Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?),
|
||||||
|
@ -1,12 +1,15 @@
|
|||||||
use candle::{Device, Result, Tensor};
|
use candle::{Device, Result, Tensor};
|
||||||
|
|
||||||
pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
|
pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
|
||||||
if steps < 1 {
|
if steps == 0 {
|
||||||
candle::bail!("cannot use linspace with steps {steps} <= 1")
|
Tensor::from_vec(Vec::<f64>::new(), steps, &Device::Cpu)
|
||||||
}
|
} else if steps == 1 {
|
||||||
|
Tensor::from_vec(vec![start], steps, &Device::Cpu)
|
||||||
|
} else {
|
||||||
let delta = (stop - start) / (steps - 1) as f64;
|
let delta = (stop - start) / (steps - 1) as f64;
|
||||||
let vs = (0..steps)
|
let vs = (0..steps)
|
||||||
.map(|step| start + step as f64 * delta)
|
.map(|step| start + step as f64 * delta)
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
Tensor::from_vec(vs, steps, &Device::Cpu)
|
Tensor::from_vec(vs, steps, &Device::Cpu)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
@ -37,6 +37,37 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
|
|||||||
Ok(m)
|
Ok(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize, Default, Clone, PartialEq)]
|
||||||
|
pub struct ActivationWithOptionalGating {
|
||||||
|
pub gated: bool,
|
||||||
|
pub activation: candle_nn::Activation,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn deserialize_feed_forward_proj_activation<'de, D>(
|
||||||
|
deserializer: D,
|
||||||
|
) -> std::result::Result<ActivationWithOptionalGating, D::Error>
|
||||||
|
where
|
||||||
|
D: serde::de::Deserializer<'de>,
|
||||||
|
{
|
||||||
|
match String::deserialize(deserializer)?.as_str() {
|
||||||
|
"gated-gelu" => Ok(ActivationWithOptionalGating {
|
||||||
|
gated: true,
|
||||||
|
activation: candle_nn::Activation::NewGelu,
|
||||||
|
}),
|
||||||
|
"gated-silu" => Ok(ActivationWithOptionalGating {
|
||||||
|
gated: true,
|
||||||
|
activation: candle_nn::Activation::Silu,
|
||||||
|
}),
|
||||||
|
buf => {
|
||||||
|
let activation = serde_plain::from_str(buf).map_err(serde::de::Error::custom)?;
|
||||||
|
Ok(ActivationWithOptionalGating {
|
||||||
|
gated: false,
|
||||||
|
activation,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||||
pub struct Config {
|
pub struct Config {
|
||||||
vocab_size: usize,
|
vocab_size: usize,
|
||||||
@ -52,8 +83,8 @@ pub struct Config {
|
|||||||
dropout_rate: f64,
|
dropout_rate: f64,
|
||||||
layer_norm_epsilon: f64,
|
layer_norm_epsilon: f64,
|
||||||
initializer_factor: f64,
|
initializer_factor: f64,
|
||||||
#[serde(default)]
|
#[serde(default, deserialize_with = "deserialize_feed_forward_proj_activation")]
|
||||||
feed_forward_proj: Activation,
|
feed_forward_proj: ActivationWithOptionalGating,
|
||||||
#[serde(default = "default_tie_word_embeddings")]
|
#[serde(default = "default_tie_word_embeddings")]
|
||||||
tie_word_embeddings: bool,
|
tie_word_embeddings: bool,
|
||||||
#[serde(default = "default_is_decoder")]
|
#[serde(default = "default_is_decoder")]
|
||||||
@ -81,7 +112,10 @@ impl Default for Config {
|
|||||||
dropout_rate: 0.1,
|
dropout_rate: 0.1,
|
||||||
layer_norm_epsilon: 1e-6,
|
layer_norm_epsilon: 1e-6,
|
||||||
initializer_factor: 1.0,
|
initializer_factor: 1.0,
|
||||||
feed_forward_proj: Activation::Relu,
|
feed_forward_proj: ActivationWithOptionalGating {
|
||||||
|
gated: false,
|
||||||
|
activation: Activation::Relu,
|
||||||
|
},
|
||||||
tie_word_embeddings: true,
|
tie_word_embeddings: true,
|
||||||
is_decoder: false,
|
is_decoder: false,
|
||||||
is_encoder_decoder: true,
|
is_encoder_decoder: true,
|
||||||
@ -102,7 +136,10 @@ impl Config {
|
|||||||
d_model: 768,
|
d_model: 768,
|
||||||
dropout_rate: 0.1,
|
dropout_rate: 0.1,
|
||||||
eos_token_id: 1,
|
eos_token_id: 1,
|
||||||
feed_forward_proj: Activation::Relu,
|
feed_forward_proj: ActivationWithOptionalGating {
|
||||||
|
gated: false,
|
||||||
|
activation: Activation::Relu,
|
||||||
|
},
|
||||||
tie_word_embeddings: true,
|
tie_word_embeddings: true,
|
||||||
initializer_factor: 1.0,
|
initializer_factor: 1.0,
|
||||||
is_decoder: false,
|
is_decoder: false,
|
||||||
@ -202,7 +239,7 @@ impl T5DenseGatedActDense {
|
|||||||
wi_0,
|
wi_0,
|
||||||
wi_1,
|
wi_1,
|
||||||
wo,
|
wo,
|
||||||
act: Activation::NewGelu,
|
act: cfg.feed_forward_proj.activation,
|
||||||
span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"),
|
span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -231,7 +268,7 @@ impl T5LayerFF {
|
|||||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let layer_norm =
|
let layer_norm =
|
||||||
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
|
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
|
||||||
let (dense_act, gated_dense_act) = if cfg.feed_forward_proj == Activation::NewGelu {
|
let (dense_act, gated_dense_act) = if cfg.feed_forward_proj.gated {
|
||||||
(
|
(
|
||||||
None,
|
None,
|
||||||
Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?),
|
Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?),
|
||||||
@ -425,7 +462,7 @@ impl T5Attention {
|
|||||||
self.relative_attention_max_distance as f32
|
self.relative_attention_max_distance as f32
|
||||||
/ max_exact as f32,
|
/ max_exact as f32,
|
||||||
) * (num_buckets - max_exact) as f32;
|
) * (num_buckets - max_exact) as f32;
|
||||||
max_exact + b as u32
|
u32::min(max_exact + b as u32, num_buckets - 1)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.collect::<Vec<u32>>()
|
.collect::<Vec<u32>>()
|
||||||
|
454
candle-transformers/src/models/trocr.rs
Normal file
454
candle-transformers/src/models/trocr.rs
Normal file
@ -0,0 +1,454 @@
|
|||||||
|
use crate::models::vit::{Config, Embeddings, Encoder};
|
||||||
|
use candle::{Result, Tensor};
|
||||||
|
use candle_nn::{
|
||||||
|
embedding, layer_norm, linear_no_bias, Embedding, LayerNorm, Linear, Module, VarBuilder,
|
||||||
|
};
|
||||||
|
use serde::Deserialize;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||||
|
pub struct TrOCRConfig {
|
||||||
|
pub vocab_size: usize,
|
||||||
|
pub d_model: usize,
|
||||||
|
pub hidden_size: usize,
|
||||||
|
pub decoder_layers: usize,
|
||||||
|
pub decoder_attention_heads: usize,
|
||||||
|
pub decoder_ffn_dim: usize,
|
||||||
|
pub activation_function: candle_nn::Activation,
|
||||||
|
pub max_position_embeddings: usize,
|
||||||
|
pub dropout: f64,
|
||||||
|
pub attention_dropout: f64,
|
||||||
|
pub activation_dropout: f64,
|
||||||
|
pub decoder_start_token_id: u32,
|
||||||
|
pub init_std: f64,
|
||||||
|
pub decoder_layerdrop: f64,
|
||||||
|
pub use_cache: bool,
|
||||||
|
pub scale_embedding: bool,
|
||||||
|
pub use_learned_position_embeddings: bool,
|
||||||
|
pub layernorm_embedding: bool,
|
||||||
|
pub pad_token_id: usize,
|
||||||
|
pub bos_token_id: usize,
|
||||||
|
pub eos_token_id: u32,
|
||||||
|
pub num_attention_heads: usize,
|
||||||
|
pub decoder_vocab_size: Option<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for TrOCRConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
vocab_size: 50265,
|
||||||
|
d_model: 1024,
|
||||||
|
hidden_size: 768,
|
||||||
|
decoder_layers: 12,
|
||||||
|
decoder_attention_heads: 16,
|
||||||
|
decoder_ffn_dim: 4096,
|
||||||
|
activation_function: candle_nn::Activation::Gelu,
|
||||||
|
max_position_embeddings: 512,
|
||||||
|
dropout: 0.1,
|
||||||
|
attention_dropout: 0.0,
|
||||||
|
activation_dropout: 0.0,
|
||||||
|
decoder_start_token_id: 2,
|
||||||
|
init_std: 0.02,
|
||||||
|
decoder_layerdrop: 0.0,
|
||||||
|
use_cache: true,
|
||||||
|
scale_embedding: false,
|
||||||
|
use_learned_position_embeddings: true,
|
||||||
|
layernorm_embedding: true,
|
||||||
|
pad_token_id: 1,
|
||||||
|
bos_token_id: 0,
|
||||||
|
eos_token_id: 2,
|
||||||
|
num_attention_heads: 12,
|
||||||
|
decoder_vocab_size: Some(50265),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct TrOCRLearnedPositionalEmbedding {
|
||||||
|
offset: usize,
|
||||||
|
weights: Embedding,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TrOCRLearnedPositionalEmbedding {
|
||||||
|
fn load(vb: VarBuilder, cfg: &TrOCRConfig) -> Result<Self> {
|
||||||
|
let offset: usize = 2;
|
||||||
|
let num_embeddings = cfg.max_position_embeddings;
|
||||||
|
let embedding_dim = cfg.d_model;
|
||||||
|
let weights = embedding(num_embeddings + offset, embedding_dim, vb)?;
|
||||||
|
|
||||||
|
Ok(Self { offset, weights })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&mut self, input_ids: &Tensor, past_key_values_length: u32) -> Result<Tensor> {
|
||||||
|
let (b_sz, seq_len) = input_ids.dims2()?;
|
||||||
|
|
||||||
|
let mut positions = Tensor::arange(
|
||||||
|
past_key_values_length,
|
||||||
|
seq_len as u32 + past_key_values_length,
|
||||||
|
input_ids.device(),
|
||||||
|
)?
|
||||||
|
.expand((b_sz, seq_len))?;
|
||||||
|
|
||||||
|
positions =
|
||||||
|
positions.broadcast_add(&Tensor::new(self.offset as u32, input_ids.device())?)?;
|
||||||
|
self.weights.forward(&positions)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct TrOCRAttention {
|
||||||
|
head_dim: usize,
|
||||||
|
num_heads: usize,
|
||||||
|
is_decoder: bool,
|
||||||
|
scaling: f64,
|
||||||
|
k_proj: Linear,
|
||||||
|
v_proj: Linear,
|
||||||
|
q_proj: Linear,
|
||||||
|
out_proj: Linear,
|
||||||
|
kv_cache: Option<(Tensor, Tensor)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TrOCRAttention {
|
||||||
|
fn load(
|
||||||
|
vb: VarBuilder,
|
||||||
|
cfg: &TrOCRConfig,
|
||||||
|
kdim: Option<usize>,
|
||||||
|
vdim: Option<usize>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let embed_dim = cfg.d_model;
|
||||||
|
let num_heads = cfg.decoder_attention_heads;
|
||||||
|
let head_dim = embed_dim / num_heads;
|
||||||
|
let kdim = kdim.unwrap_or(embed_dim);
|
||||||
|
let vdim = vdim.unwrap_or(embed_dim);
|
||||||
|
|
||||||
|
let k_proj = linear_no_bias(kdim, embed_dim, vb.pp("k_proj"))?;
|
||||||
|
let v_proj = linear_no_bias(vdim, embed_dim, vb.pp("v_proj"))?;
|
||||||
|
let q_proj = linear_no_bias(embed_dim, embed_dim, vb.pp("q_proj"))?;
|
||||||
|
|
||||||
|
let out_proj = linear_no_bias(embed_dim, embed_dim, vb.pp("out_proj"))?;
|
||||||
|
Ok(Self {
|
||||||
|
head_dim,
|
||||||
|
num_heads,
|
||||||
|
is_decoder: true,
|
||||||
|
scaling: 1. / (head_dim as f64).sqrt(),
|
||||||
|
k_proj,
|
||||||
|
v_proj,
|
||||||
|
q_proj,
|
||||||
|
out_proj,
|
||||||
|
kv_cache: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn reset_kv_cache(&mut self) {
|
||||||
|
self.kv_cache = None
|
||||||
|
}
|
||||||
|
|
||||||
|
fn _shape(&self, tensor: &Tensor, bsz: usize) -> Result<Tensor> {
|
||||||
|
tensor
|
||||||
|
.reshape((bsz, (), self.num_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.contiguous()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(
|
||||||
|
&mut self,
|
||||||
|
xs: &Tensor,
|
||||||
|
kv_states: Option<&Tensor>,
|
||||||
|
attn_mask: Option<&Tensor>,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let (b_sz, tgt_len, _) = xs.dims3()?;
|
||||||
|
let query_states = (xs.apply(&self.q_proj)? * self.scaling)?;
|
||||||
|
let (key_states, value_states) = match kv_states {
|
||||||
|
None => {
|
||||||
|
let key_states = self._shape(&xs.apply(&self.k_proj)?, b_sz)?;
|
||||||
|
let value_states = self._shape(&xs.apply(&self.v_proj)?, b_sz)?;
|
||||||
|
if self.is_decoder {
|
||||||
|
let kv_states = match &self.kv_cache {
|
||||||
|
None => (key_states, value_states),
|
||||||
|
Some((p_key_states, p_value_states)) => {
|
||||||
|
let key_states = Tensor::cat(&[p_key_states, &key_states], 2)?;
|
||||||
|
let value_states = Tensor::cat(&[p_value_states, &value_states], 2)?;
|
||||||
|
(key_states, value_states)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
self.kv_cache = Some(kv_states.clone());
|
||||||
|
kv_states
|
||||||
|
} else {
|
||||||
|
(key_states, value_states)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Some(kv_states) => {
|
||||||
|
let key_states = self._shape(&kv_states.apply(&self.k_proj)?, b_sz)?;
|
||||||
|
let value_states = self._shape(&kv_states.apply(&self.v_proj)?, b_sz)?;
|
||||||
|
(key_states, value_states)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let proj_shape = (b_sz * self.num_heads, (), self.head_dim);
|
||||||
|
let query_states = self._shape(&query_states, b_sz)?.reshape(proj_shape)?;
|
||||||
|
let key_states = key_states.reshape(proj_shape)?;
|
||||||
|
let value_states = value_states.reshape(proj_shape)?;
|
||||||
|
let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
|
||||||
|
let attn_weights = match attn_mask {
|
||||||
|
None => attn_weights,
|
||||||
|
Some(attn_mask) => attn_weights.broadcast_add(attn_mask)?,
|
||||||
|
};
|
||||||
|
let attn_probs = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
||||||
|
let attn_output = attn_probs.matmul(&value_states)?;
|
||||||
|
attn_output
|
||||||
|
.reshape((b_sz, self.num_heads, tgt_len, self.head_dim))?
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.reshape((b_sz, tgt_len, self.head_dim * self.num_heads))?
|
||||||
|
.apply(&self.out_proj)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct TrOCRDecoderLayer {
|
||||||
|
self_attn: TrOCRAttention,
|
||||||
|
activation_fn: candle_nn::Activation,
|
||||||
|
self_attn_layer_norm: LayerNorm,
|
||||||
|
encoder_attn: TrOCRAttention,
|
||||||
|
encoder_attn_layer_norm: LayerNorm,
|
||||||
|
fc1: Linear,
|
||||||
|
fc2: Linear,
|
||||||
|
final_layer_norm: LayerNorm,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TrOCRDecoderLayer {
|
||||||
|
fn load(vb: VarBuilder, cfg: &TrOCRConfig) -> Result<Self> {
|
||||||
|
let embed_dim = cfg.d_model;
|
||||||
|
let self_attn = TrOCRAttention::load(vb.pp("self_attn"), cfg, None, None)?;
|
||||||
|
let self_attn_layer_norm = layer_norm(embed_dim, 1e-5, vb.pp("self_attn_layer_norm"))?;
|
||||||
|
let encoder_attn = TrOCRAttention::load(
|
||||||
|
vb.pp("encoder_attn"),
|
||||||
|
cfg,
|
||||||
|
Some(cfg.hidden_size),
|
||||||
|
Some(cfg.hidden_size),
|
||||||
|
)?;
|
||||||
|
let encoder_attn_layer_norm =
|
||||||
|
layer_norm(embed_dim, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
|
||||||
|
let fc1 = linear_no_bias(embed_dim, cfg.decoder_ffn_dim, vb.pp("fc1"))?;
|
||||||
|
let fc2 = linear_no_bias(cfg.decoder_ffn_dim, embed_dim, vb.pp("fc2"))?;
|
||||||
|
let final_layer_norm = layer_norm(embed_dim, 1e-5, vb.pp("final_layer_norm"))?;
|
||||||
|
let activation_fn = candle_nn::Activation::Gelu;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
self_attn,
|
||||||
|
activation_fn,
|
||||||
|
self_attn_layer_norm,
|
||||||
|
encoder_attn,
|
||||||
|
encoder_attn_layer_norm,
|
||||||
|
fc1,
|
||||||
|
fc2,
|
||||||
|
final_layer_norm,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn reset_kv_cache(&mut self) {
|
||||||
|
self.self_attn.reset_kv_cache();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(
|
||||||
|
&mut self,
|
||||||
|
xs: &Tensor,
|
||||||
|
attention_mask: &Tensor,
|
||||||
|
encoder_hidden_states: Option<&Tensor>,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let residual = xs.clone();
|
||||||
|
let xs = self.self_attn.forward(xs, None, Some(attention_mask))?;
|
||||||
|
let xs = (xs + residual)?;
|
||||||
|
let mut xs = self.self_attn_layer_norm.forward(&xs)?;
|
||||||
|
|
||||||
|
if let Some(encoder_hidden_states) = &encoder_hidden_states {
|
||||||
|
let residual = xs.clone();
|
||||||
|
let encoder_attention_mask = attention_mask.clone(); // TODO
|
||||||
|
xs = self.encoder_attn.forward(
|
||||||
|
&xs,
|
||||||
|
Some(encoder_hidden_states),
|
||||||
|
Some(&encoder_attention_mask),
|
||||||
|
)?;
|
||||||
|
xs = (xs + residual)?;
|
||||||
|
xs = self.encoder_attn_layer_norm.forward(&xs)?
|
||||||
|
}
|
||||||
|
|
||||||
|
let residual = xs.clone();
|
||||||
|
let xs = self.fc1.forward(&xs)?;
|
||||||
|
let xs = self.activation_fn.forward(&xs)?;
|
||||||
|
let xs = self.fc2.forward(&xs)?;
|
||||||
|
let xs = (xs + residual)?;
|
||||||
|
let xs = self.final_layer_norm.forward(&xs)?;
|
||||||
|
|
||||||
|
Ok(xs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct TrOCRDecoder {
|
||||||
|
layers: Vec<TrOCRDecoderLayer>,
|
||||||
|
embed_scale: Option<f64>,
|
||||||
|
embed_tokens: Embedding,
|
||||||
|
embed_positions: TrOCRLearnedPositionalEmbedding,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TrOCRDecoder {
|
||||||
|
fn new(cfg: &TrOCRConfig, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let vb = vb.pp("decoder.model.decoder");
|
||||||
|
|
||||||
|
let embed_tokens = embedding(cfg.vocab_size, cfg.d_model, vb.pp("embed_tokens"))?;
|
||||||
|
let embed_positions = TrOCRLearnedPositionalEmbedding::load(vb.pp("embed_positions"), cfg)?;
|
||||||
|
let mut layers = Vec::with_capacity(cfg.decoder_layers);
|
||||||
|
let vb_l = vb.pp("layers");
|
||||||
|
for idx in 0..cfg.decoder_layers {
|
||||||
|
let layer = TrOCRDecoderLayer::load(vb_l.pp(idx), cfg)?;
|
||||||
|
layers.push(layer)
|
||||||
|
}
|
||||||
|
let embed_scale = if cfg.scale_embedding {
|
||||||
|
Some((cfg.d_model as f64).sqrt())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
layers,
|
||||||
|
embed_scale,
|
||||||
|
embed_tokens,
|
||||||
|
embed_positions,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn reset_kv_cache(&mut self) {
|
||||||
|
self.layers.iter_mut().for_each(|l| l.reset_kv_cache())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(
|
||||||
|
&mut self,
|
||||||
|
xs: &Tensor,
|
||||||
|
encoder_xs: Option<&Tensor>,
|
||||||
|
past_kv_len: usize,
|
||||||
|
attn_mask: &Tensor,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let embed_pos = self.embed_positions.forward(xs, past_kv_len as u32)?;
|
||||||
|
let xs = xs.apply(&self.embed_tokens)?;
|
||||||
|
|
||||||
|
let xs = match self.embed_scale {
|
||||||
|
None => xs,
|
||||||
|
Some(scale) => (xs * scale)?,
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut xs = xs.broadcast_add(&embed_pos)?;
|
||||||
|
|
||||||
|
for layer in self.layers.iter_mut() {
|
||||||
|
xs = layer.forward(&xs, attn_mask, encoder_xs)?;
|
||||||
|
}
|
||||||
|
Ok(xs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct TrOCREncoder {
|
||||||
|
embeddings: Embeddings,
|
||||||
|
encoder: Encoder,
|
||||||
|
layernorm: LayerNorm,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TrOCREncoder {
|
||||||
|
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let vb_v = vb.pp("encoder");
|
||||||
|
|
||||||
|
let embeddings = Embeddings::new(cfg, false, vb_v.pp("embeddings"))?;
|
||||||
|
|
||||||
|
let encoder = Encoder::new(cfg, vb_v.pp("encoder"))?;
|
||||||
|
let layernorm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb_v.pp("layernorm"))?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
embeddings,
|
||||||
|
encoder,
|
||||||
|
layernorm,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let embedding_output = self.embeddings.forward(xs, None, false)?;
|
||||||
|
let encoder_outputs = self.encoder.forward(&embedding_output)?;
|
||||||
|
|
||||||
|
self.layernorm.forward(&encoder_outputs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct TrOCRForCausalLM {
|
||||||
|
decoder: TrOCRDecoder,
|
||||||
|
output_projection: Linear,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TrOCRForCausalLM {
|
||||||
|
pub fn new(decoder_cfg: &TrOCRConfig, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let decoder = TrOCRDecoder::new(decoder_cfg, vb.clone())?;
|
||||||
|
let output_projection =
|
||||||
|
candle_nn::Linear::new(decoder.embed_tokens.embeddings().clone(), None);
|
||||||
|
Ok(Self {
|
||||||
|
decoder,
|
||||||
|
output_projection,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(
|
||||||
|
&mut self,
|
||||||
|
xs: &Tensor,
|
||||||
|
encoder_xs: Option<&Tensor>,
|
||||||
|
past_kv_len: usize,
|
||||||
|
attn_mask: &Tensor,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let xs = self
|
||||||
|
.decoder
|
||||||
|
.forward(xs, encoder_xs, past_kv_len, attn_mask)?;
|
||||||
|
let xs = xs.apply(&self.output_projection)?;
|
||||||
|
|
||||||
|
Ok(xs)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn reset_kv_cache(&mut self) {
|
||||||
|
self.decoder.reset_kv_cache();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct TrOCRModel {
|
||||||
|
encoder: TrOCREncoder,
|
||||||
|
decoder: TrOCRForCausalLM,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TrOCRModel {
|
||||||
|
pub fn new(encoder_cfg: &Config, decoder_cfg: &TrOCRConfig, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let encoder = TrOCREncoder::new(encoder_cfg, vb.clone())?;
|
||||||
|
let decoder = TrOCRForCausalLM::new(decoder_cfg, vb)?;
|
||||||
|
Ok(Self { encoder, decoder })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn encoder(&mut self) -> &mut TrOCREncoder {
|
||||||
|
&mut self.encoder
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn decoder(&mut self) -> &mut TrOCRForCausalLM {
|
||||||
|
&mut self.decoder
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn decode(
|
||||||
|
&mut self,
|
||||||
|
xs: &Tensor,
|
||||||
|
encoder_xs: &Tensor,
|
||||||
|
past_kv_len: usize,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let seq_len = xs.dim(1)?;
|
||||||
|
let mask: Vec<_> = (0..seq_len)
|
||||||
|
.flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
|
||||||
|
.collect();
|
||||||
|
let mask = Tensor::from_vec(mask, (seq_len, seq_len), xs.device())?;
|
||||||
|
|
||||||
|
self.decoder
|
||||||
|
.forward(xs, Some(encoder_xs), past_kv_len, &mask)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn reset_kv_cache(&mut self) {
|
||||||
|
self.decoder.reset_kv_cache();
|
||||||
|
}
|
||||||
|
}
|
@ -6,16 +6,16 @@ use candle_nn::{layer_norm, LayerNorm, VarBuilder};
|
|||||||
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/configuration_vit.py
|
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/configuration_vit.py
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Config {
|
pub struct Config {
|
||||||
hidden_size: usize,
|
pub hidden_size: usize,
|
||||||
num_hidden_layers: usize,
|
pub num_hidden_layers: usize,
|
||||||
num_attention_heads: usize,
|
pub num_attention_heads: usize,
|
||||||
intermediate_size: usize,
|
pub intermediate_size: usize,
|
||||||
hidden_act: candle_nn::Activation,
|
pub hidden_act: candle_nn::Activation,
|
||||||
layer_norm_eps: f64,
|
pub layer_norm_eps: f64,
|
||||||
image_size: usize,
|
pub image_size: usize,
|
||||||
patch_size: usize,
|
pub patch_size: usize,
|
||||||
num_channels: usize,
|
pub num_channels: usize,
|
||||||
qkv_bias: bool,
|
pub qkv_bias: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
@ -34,6 +34,21 @@ impl Config {
|
|||||||
qkv_bias: true,
|
qkv_bias: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn microsoft_trocr_base_handwritten() -> Self {
|
||||||
|
Self {
|
||||||
|
hidden_size: 768,
|
||||||
|
num_hidden_layers: 12,
|
||||||
|
num_attention_heads: 12,
|
||||||
|
intermediate_size: 3072,
|
||||||
|
hidden_act: candle_nn::Activation::Gelu,
|
||||||
|
layer_norm_eps: 1e-12,
|
||||||
|
image_size: 384,
|
||||||
|
patch_size: 16,
|
||||||
|
num_channels: 3,
|
||||||
|
qkv_bias: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@ -76,7 +91,7 @@ impl Module for PatchEmbeddings {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
struct Embeddings {
|
pub struct Embeddings {
|
||||||
cls_token: Tensor,
|
cls_token: Tensor,
|
||||||
mask_token: Option<Tensor>,
|
mask_token: Option<Tensor>,
|
||||||
patch_embeddings: PatchEmbeddings,
|
patch_embeddings: PatchEmbeddings,
|
||||||
@ -85,7 +100,7 @@ struct Embeddings {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Embeddings {
|
impl Embeddings {
|
||||||
fn new(cfg: &Config, use_mask_token: bool, vb: VarBuilder) -> Result<Self> {
|
pub fn new(cfg: &Config, use_mask_token: bool, vb: VarBuilder) -> Result<Self> {
|
||||||
let hidden_size = cfg.hidden_size;
|
let hidden_size = cfg.hidden_size;
|
||||||
let cls_token = vb.get((1, 1, hidden_size), "cls_token")?;
|
let cls_token = vb.get((1, 1, hidden_size), "cls_token")?;
|
||||||
let mask_token = if use_mask_token {
|
let mask_token = if use_mask_token {
|
||||||
@ -115,7 +130,7 @@ impl Embeddings {
|
|||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(
|
pub fn forward(
|
||||||
&self,
|
&self,
|
||||||
pixel_values: &Tensor,
|
pixel_values: &Tensor,
|
||||||
bool_masked_pos: Option<&Tensor>,
|
bool_masked_pos: Option<&Tensor>,
|
||||||
@ -324,12 +339,12 @@ impl Module for Layer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
struct Encoder {
|
pub struct Encoder {
|
||||||
layers: Vec<Layer>,
|
layers: Vec<Layer>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Encoder {
|
impl Encoder {
|
||||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
let vb = vb.pp("layer");
|
let vb = vb.pp("layer");
|
||||||
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||||
for i in 0..cfg.num_hidden_layers {
|
for i in 0..cfg.num_hidden_layers {
|
||||||
|
@ -58,8 +58,7 @@ fn dft<T: Float>(inp: &[T]) -> Vec<T> {
|
|||||||
let n = inp.len();
|
let n = inp.len();
|
||||||
let two_pi = T::PI() + T::PI();
|
let two_pi = T::PI() + T::PI();
|
||||||
|
|
||||||
let mut out = Vec::new();
|
let mut out = Vec::with_capacity(2 * n);
|
||||||
out.reserve(2 * n);
|
|
||||||
let n_t = T::from(n).unwrap();
|
let n_t = T::from(n).unwrap();
|
||||||
for k in 0..n {
|
for k in 0..n {
|
||||||
let k_t = T::from(k).unwrap();
|
let k_t = T::from(k).unwrap();
|
||||||
|
@ -43,4 +43,4 @@ pub const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
|
|||||||
pub const TRANSLATE_TOKEN: &str = "<|translate|>";
|
pub const TRANSLATE_TOKEN: &str = "<|translate|>";
|
||||||
pub const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
|
pub const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
|
||||||
pub const EOT_TOKEN: &str = "<|endoftext|>";
|
pub const EOT_TOKEN: &str = "<|endoftext|>";
|
||||||
pub const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
|
pub const NO_SPEECH_TOKENS: [&str; 2] = ["<|nocaptions|>", "<|nospeech|>"];
|
||||||
|
381
candle-transformers/src/models/yi.rs
Normal file
381
candle-transformers/src/models/yi.rs
Normal file
@ -0,0 +1,381 @@
|
|||||||
|
/// https://huggingface.co/01-ai/Yi-6B/blob/main/modeling_yi.py
|
||||||
|
use crate::models::with_tracing::{linear_no_bias, Linear};
|
||||||
|
use candle::{DType, Device, Module, Result, Tensor, D};
|
||||||
|
use candle_nn::{Activation, VarBuilder};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
|
pub struct Config {
|
||||||
|
pub(crate) vocab_size: usize,
|
||||||
|
pub(crate) hidden_size: usize,
|
||||||
|
pub(crate) intermediate_size: usize,
|
||||||
|
pub(crate) num_hidden_layers: usize,
|
||||||
|
pub(crate) num_attention_heads: usize,
|
||||||
|
pub(crate) num_key_value_heads: usize,
|
||||||
|
pub(crate) hidden_act: Activation,
|
||||||
|
pub(crate) max_position_embeddings: usize,
|
||||||
|
pub(crate) rms_norm_eps: f64,
|
||||||
|
pub(crate) rope_theta: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Config {
|
||||||
|
pub fn config_6b() -> Self {
|
||||||
|
Self {
|
||||||
|
vocab_size: 64000,
|
||||||
|
hidden_size: 4096,
|
||||||
|
intermediate_size: 11008,
|
||||||
|
num_hidden_layers: 32,
|
||||||
|
num_attention_heads: 32,
|
||||||
|
num_key_value_heads: 4,
|
||||||
|
hidden_act: Activation::Silu,
|
||||||
|
max_position_embeddings: 4096,
|
||||||
|
rms_norm_eps: 1e-5,
|
||||||
|
rope_theta: 5_000_000.,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn config_34b() -> Self {
|
||||||
|
Self {
|
||||||
|
vocab_size: 64000,
|
||||||
|
hidden_size: 7168,
|
||||||
|
intermediate_size: 20480,
|
||||||
|
num_hidden_layers: 60,
|
||||||
|
num_attention_heads: 56,
|
||||||
|
num_key_value_heads: 8,
|
||||||
|
hidden_act: Activation::Silu,
|
||||||
|
max_position_embeddings: 4096,
|
||||||
|
rms_norm_eps: 1e-5,
|
||||||
|
rope_theta: 5_000_000.,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct RmsNorm {
|
||||||
|
inner: candle_nn::RmsNorm,
|
||||||
|
span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RmsNorm {
|
||||||
|
fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
|
||||||
|
let inner = candle_nn::rms_norm(size, eps, vb)?;
|
||||||
|
Ok(Self { inner, span })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for RmsNorm {
|
||||||
|
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
self.inner.forward(x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct RotaryEmbedding {
|
||||||
|
sin: Tensor,
|
||||||
|
cos: Tensor,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rotate_half(xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let last_dim = xs.dim(D::Minus1)?;
|
||||||
|
let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
|
||||||
|
let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
|
||||||
|
Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RotaryEmbedding {
|
||||||
|
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
|
||||||
|
let dim = cfg.hidden_size / cfg.num_attention_heads;
|
||||||
|
let max_seq_len = cfg.max_position_embeddings;
|
||||||
|
let inv_freq: Vec<_> = (0..dim)
|
||||||
|
.step_by(2)
|
||||||
|
.map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
|
||||||
|
.collect();
|
||||||
|
let inv_freq_len = inv_freq.len();
|
||||||
|
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
|
||||||
|
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
||||||
|
.to_dtype(dtype)?
|
||||||
|
.reshape((max_seq_len, 1))?;
|
||||||
|
let freqs = t.matmul(&inv_freq)?;
|
||||||
|
let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
|
||||||
|
Ok(Self {
|
||||||
|
sin: freqs.sin()?,
|
||||||
|
cos: freqs.cos()?,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn apply_rotary_emb_qkv(
|
||||||
|
&self,
|
||||||
|
q: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
seqlen_offset: usize,
|
||||||
|
) -> Result<(Tensor, Tensor)> {
|
||||||
|
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
||||||
|
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
||||||
|
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
||||||
|
let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
|
||||||
|
let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
|
||||||
|
let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
|
||||||
|
let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
|
||||||
|
Ok((q_embed, k_embed))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
#[allow(clippy::upper_case_acronyms)]
|
||||||
|
struct MLP {
|
||||||
|
gate_proj: Linear,
|
||||||
|
up_proj: Linear,
|
||||||
|
down_proj: Linear,
|
||||||
|
act_fn: Activation,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MLP {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let hidden_sz = cfg.hidden_size;
|
||||||
|
let intermediate_sz = cfg.intermediate_size;
|
||||||
|
let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
|
||||||
|
let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
|
||||||
|
let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
|
||||||
|
Ok(Self {
|
||||||
|
gate_proj,
|
||||||
|
up_proj,
|
||||||
|
down_proj,
|
||||||
|
act_fn: cfg.hidden_act,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for MLP {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
|
||||||
|
let rhs = xs.apply(&self.up_proj)?;
|
||||||
|
(lhs * rhs)?.apply(&self.down_proj)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct Attention {
|
||||||
|
q_proj: Linear,
|
||||||
|
k_proj: Linear,
|
||||||
|
v_proj: Linear,
|
||||||
|
o_proj: Linear,
|
||||||
|
num_heads: usize,
|
||||||
|
num_kv_heads: usize,
|
||||||
|
num_kv_groups: usize,
|
||||||
|
head_dim: usize,
|
||||||
|
hidden_size: usize,
|
||||||
|
rotary_emb: Arc<RotaryEmbedding>,
|
||||||
|
kv_cache: Option<(Tensor, Tensor)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Attention {
|
||||||
|
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let hidden_sz = cfg.hidden_size;
|
||||||
|
let num_heads = cfg.num_attention_heads;
|
||||||
|
let num_kv_heads = cfg.num_key_value_heads;
|
||||||
|
let num_kv_groups = num_heads / num_kv_heads;
|
||||||
|
let head_dim = hidden_sz / num_heads;
|
||||||
|
let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
|
||||||
|
let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
|
||||||
|
let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
|
||||||
|
let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
|
||||||
|
Ok(Self {
|
||||||
|
q_proj,
|
||||||
|
k_proj,
|
||||||
|
v_proj,
|
||||||
|
o_proj,
|
||||||
|
num_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
num_kv_groups,
|
||||||
|
head_dim,
|
||||||
|
hidden_size: hidden_sz,
|
||||||
|
rotary_emb,
|
||||||
|
kv_cache: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
|
||||||
|
let n_rep = self.num_kv_groups;
|
||||||
|
if n_rep == 1 {
|
||||||
|
Ok(xs)
|
||||||
|
} else {
|
||||||
|
let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
|
||||||
|
xs.unsqueeze(2)?
|
||||||
|
.expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
|
||||||
|
.reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(
|
||||||
|
&mut self,
|
||||||
|
xs: &Tensor,
|
||||||
|
attention_mask: Option<&Tensor>,
|
||||||
|
seqlen_offset: usize,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let (b_sz, q_len, _) = xs.dims3()?;
|
||||||
|
|
||||||
|
let query_states = self.q_proj.forward(xs)?;
|
||||||
|
let key_states = self.k_proj.forward(xs)?;
|
||||||
|
let value_states = self.v_proj.forward(xs)?;
|
||||||
|
|
||||||
|
let query_states = query_states
|
||||||
|
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?;
|
||||||
|
let key_states = key_states
|
||||||
|
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?;
|
||||||
|
let value_states = value_states
|
||||||
|
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?;
|
||||||
|
|
||||||
|
let (query_states, key_states) =
|
||||||
|
self.rotary_emb
|
||||||
|
.apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
|
||||||
|
|
||||||
|
let (key_states, value_states) = match &self.kv_cache {
|
||||||
|
None => (key_states, value_states),
|
||||||
|
Some((prev_k, prev_v)) => {
|
||||||
|
let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
|
||||||
|
let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
|
||||||
|
(key_states, value_states)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
self.kv_cache = Some((key_states.clone(), value_states.clone()));
|
||||||
|
|
||||||
|
let key_states = self.repeat_kv(key_states)?;
|
||||||
|
let value_states = self.repeat_kv(value_states)?;
|
||||||
|
|
||||||
|
let attn_output = {
|
||||||
|
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
|
||||||
|
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
|
||||||
|
|
||||||
|
let attn_weights = match attention_mask {
|
||||||
|
None => attn_weights,
|
||||||
|
Some(mask) => attn_weights.broadcast_add(mask)?,
|
||||||
|
};
|
||||||
|
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
||||||
|
attn_weights.matmul(&value_states)?
|
||||||
|
};
|
||||||
|
attn_output
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.reshape((b_sz, q_len, self.hidden_size))?
|
||||||
|
.apply(&self.o_proj)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct DecoderLayer {
|
||||||
|
self_attn: Attention,
|
||||||
|
mlp: MLP,
|
||||||
|
ln1: RmsNorm,
|
||||||
|
ln2: RmsNorm,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DecoderLayer {
|
||||||
|
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
|
||||||
|
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
|
||||||
|
let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
|
||||||
|
let ln2 = RmsNorm::new(
|
||||||
|
cfg.hidden_size,
|
||||||
|
cfg.rms_norm_eps,
|
||||||
|
vb.pp("post_attention_layernorm"),
|
||||||
|
)?;
|
||||||
|
Ok(Self {
|
||||||
|
self_attn,
|
||||||
|
mlp,
|
||||||
|
ln1,
|
||||||
|
ln2,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(
|
||||||
|
&mut self,
|
||||||
|
xs: &Tensor,
|
||||||
|
attention_mask: Option<&Tensor>,
|
||||||
|
seqlen_offset: usize,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let residual = xs;
|
||||||
|
let xs = self.ln1.forward(xs)?;
|
||||||
|
let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
|
||||||
|
let xs = (xs + residual)?;
|
||||||
|
let residual = &xs;
|
||||||
|
let xs = xs.apply(&self.ln2)?.apply(&self.mlp)?;
|
||||||
|
residual + xs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Model {
|
||||||
|
embed_tokens: candle_nn::Embedding,
|
||||||
|
layers: Vec<DecoderLayer>,
|
||||||
|
norm: RmsNorm,
|
||||||
|
lm_head: Linear,
|
||||||
|
device: Device,
|
||||||
|
dtype: DType,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Model {
|
||||||
|
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let vb_m = vb.pp("model");
|
||||||
|
let embed_tokens =
|
||||||
|
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
|
||||||
|
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
|
||||||
|
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||||
|
let vb_l = vb_m.pp("layers");
|
||||||
|
for layer_idx in 0..cfg.num_hidden_layers {
|
||||||
|
let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
|
||||||
|
layers.push(layer)
|
||||||
|
}
|
||||||
|
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
|
||||||
|
let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
|
||||||
|
Ok(Self {
|
||||||
|
embed_tokens,
|
||||||
|
layers,
|
||||||
|
norm,
|
||||||
|
lm_head,
|
||||||
|
device: vb.device().clone(),
|
||||||
|
dtype: vb.dtype(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn prepare_decoder_attention_mask(
|
||||||
|
&self,
|
||||||
|
b_size: usize,
|
||||||
|
tgt_len: usize,
|
||||||
|
seqlen_offset: usize,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
// Sliding window mask?
|
||||||
|
let mask: Vec<_> = (0..tgt_len)
|
||||||
|
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
|
||||||
|
.collect();
|
||||||
|
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
|
||||||
|
let mask = if seqlen_offset > 0 {
|
||||||
|
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
|
||||||
|
Tensor::cat(&[&mask0, &mask], D::Minus1)?
|
||||||
|
} else {
|
||||||
|
mask
|
||||||
|
};
|
||||||
|
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
|
||||||
|
.to_dtype(self.dtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
||||||
|
let (b_size, seq_len) = input_ids.dims2()?;
|
||||||
|
let attention_mask = if seq_len <= 1 {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
|
||||||
|
Some(mask)
|
||||||
|
};
|
||||||
|
let mut xs = self.embed_tokens.forward(input_ids)?;
|
||||||
|
for layer in self.layers.iter_mut() {
|
||||||
|
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
|
||||||
|
}
|
||||||
|
xs.narrow(1, seq_len - 1, 1)?
|
||||||
|
.apply(&self.norm)?
|
||||||
|
.apply(&self.lm_head)
|
||||||
|
}
|
||||||
|
}
|
@ -9,9 +9,9 @@ categories.workspace = true
|
|||||||
license.workspace = true
|
license.workspace = true
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
|
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
|
||||||
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
|
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
|
||||||
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
|
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
|
||||||
num-traits = { workspace = true }
|
num-traits = { workspace = true }
|
||||||
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
||||||
|
|
||||||
|
@ -9,9 +9,9 @@ categories.workspace = true
|
|||||||
license.workspace = true
|
license.workspace = true
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
|
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
|
||||||
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
|
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
|
||||||
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
|
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
|
||||||
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
||||||
num-traits = { workspace = true }
|
num-traits = { workspace = true }
|
||||||
|
|
||||||
|
@ -9,9 +9,9 @@ categories.workspace = true
|
|||||||
license.workspace = true
|
license.workspace = true
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
|
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
|
||||||
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
|
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
|
||||||
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
|
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
|
||||||
num-traits = { workspace = true }
|
num-traits = { workspace = true }
|
||||||
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
||||||
|
|
||||||
|
@ -9,9 +9,9 @@ categories.workspace = true
|
|||||||
license.workspace = true
|
license.workspace = true
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
|
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
|
||||||
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
|
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
|
||||||
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
|
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
|
||||||
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
||||||
num-traits = { workspace = true }
|
num-traits = { workspace = true }
|
||||||
|
|
||||||
|
@ -9,9 +9,9 @@ categories.workspace = true
|
|||||||
license.workspace = true
|
license.workspace = true
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
|
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
|
||||||
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
|
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
|
||||||
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
|
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
|
||||||
num-traits = { workspace = true }
|
num-traits = { workspace = true }
|
||||||
|
|
||||||
# App crates.
|
# App crates.
|
||||||
|
@ -9,9 +9,9 @@ categories.workspace = true
|
|||||||
license.workspace = true
|
license.workspace = true
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
|
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
|
||||||
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
|
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
|
||||||
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
|
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
|
||||||
num-traits = { workspace = true }
|
num-traits = { workspace = true }
|
||||||
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
||||||
|
|
||||||
|
@ -9,9 +9,9 @@ categories.workspace = true
|
|||||||
license.workspace = true
|
license.workspace = true
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
|
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
|
||||||
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
|
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
|
||||||
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
|
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
|
||||||
num-traits = { workspace = true }
|
num-traits = { workspace = true }
|
||||||
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
||||||
|
|
||||||
|
@ -59,8 +59,7 @@ fn dft<T: Float>(inp: &[T]) -> Vec<T> {
|
|||||||
let n = inp.len();
|
let n = inp.len();
|
||||||
let two_pi = T::PI() + T::PI();
|
let two_pi = T::PI() + T::PI();
|
||||||
|
|
||||||
let mut out = Vec::new();
|
let mut out = Vec::with_capacity(2 * n);
|
||||||
out.reserve(2 * n);
|
|
||||||
let n_t = T::from(n).unwrap();
|
let n_t = T::from(n).unwrap();
|
||||||
for k in 0..n {
|
for k in 0..n {
|
||||||
let k_t = T::from(k).unwrap();
|
let k_t = T::from(k).unwrap();
|
||||||
|
@ -129,7 +129,13 @@ impl Decoder {
|
|||||||
let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
|
let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
|
||||||
let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
|
let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
|
||||||
let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
|
let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
|
||||||
let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?;
|
let no_speech_token = m::NO_SPEECH_TOKENS
|
||||||
|
.iter()
|
||||||
|
.find_map(|token| token_id(&tokenizer, token).ok());
|
||||||
|
let no_speech_token = match no_speech_token {
|
||||||
|
None => anyhow::bail!("unable to find any non-speech token"),
|
||||||
|
Some(n) => n,
|
||||||
|
};
|
||||||
let seed = 299792458;
|
let seed = 299792458;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
model,
|
model,
|
||||||
|
@ -9,8 +9,8 @@ categories.workspace = true
|
|||||||
license.workspace = true
|
license.workspace = true
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
|
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
|
||||||
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
|
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
|
||||||
num-traits = { workspace = true }
|
num-traits = { workspace = true }
|
||||||
serde = { workspace = true }
|
serde = { workspace = true }
|
||||||
serde_json = { workspace = true }
|
serde_json = { workspace = true }
|
||||||
|
@ -7,7 +7,7 @@ keywords.workspace = true
|
|||||||
categories.workspace = true
|
categories.workspace = true
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
|
||||||
rand = { workspace = true }
|
rand = { workspace = true }
|
||||||
getrandom = { version = "0.2", features = ["js"] }
|
getrandom = { version = "0.2", features = ["js"] }
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use candle::{
|
use candle::{
|
||||||
quantized::{self, k_quants, GgmlDType, GgmlType},
|
quantized::{self, k_quants, GgmlDType, GgmlType},
|
||||||
test_utils::to_vec2_round,
|
test_utils::to_vec2_round,
|
||||||
Device, Result, Tensor,
|
Device, Module, Result, Tensor,
|
||||||
};
|
};
|
||||||
|
|
||||||
use wasm_bindgen_test::*;
|
use wasm_bindgen_test::*;
|
||||||
|
Reference in New Issue
Block a user