Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 11:08:52 +00:00)

Compare commits: metal2-tmp...metal4.6 (58 commits)
Commits:

a9d0657432, 87dc559817, da0af3cb3e, 803ac8405b, 6e25822d4f, 2ca086939f,
4349ff1fc2, 7c3cfd1086, e2eb6590ed, 481c45d78d, 14a2bdc062, bfa7c8fc01,
762e996ce6, ca19a9af62, ec23427d60, f83e14f68d, c7e613ab5e, 8f63f68289,
1edc3ddf24, b380657bfe, 60f624a902, 8d6c6de8e0, 7ec345c2eb, 671fc29b36,
dc64adb8e4, c66e5d4716, bd3b243725, 2813fb5dbc, 7cfffcac10, 38de52bc4b,
d46670f7c0, f710fab02e, f82bf2d915, df6814f34e, 39406a6721, 976ad9f9c2,
a4c4a56429, f49bf6a81d, 992a788da1, 8d8f48c60c, d31f11035f, 9ab3f9729f,
a1f41ab37b, 92a05b51cf, c6763e3b41, 347e31c9ff, f4fcf60900, 12561b31d3,
a209ce8ceb, f1e678b39c, a007f8fdb4, 2341aa079e, 9e666d4229, 1b12142a02,
d2c3f14773, 26c4e5bf1d, 18d30005c5, 6958384327
Cargo.toml (workspace root)

@@ -19,7 +19,7 @@ exclude = [
 resolver = "2"
 
 [workspace.package]
-version = "0.3.0"
+version = "0.3.1"
 edition = "2021"
 description = "Minimalist ML framework."
 repository = "https://github.com/huggingface/candle"
@@ -51,6 +51,7 @@ rayon = "1.7.0"
 rusttype = { version = "0.9", default-features = false }
 safetensors = "0.3.1"
 serde = { version = "1.0.171", features = ["derive"] }
+serde_plain = "1.0.2"
 serde_json = "1.0.99"
 thiserror = "1"
 tokenizers = { version = "0.13.4", default-features = false }
@@ -60,8 +61,7 @@ tracing-subscriber = "0.3.7"
 wav = "1.0.0"
 yoke = { version = "0.7.2", features = ["derive"] }
 zip = { version = "0.6.6", default-features = false }
-# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
-metal = { path = "../metal-rs", features = ["mps"] }
+metal = { version = "0.27.1", features = ["mps"], package="candle-metal" }
 
 [profile.release-with-debug]
 inherits = "release"
README.md (22 changed lines)

@@ -69,6 +69,8 @@ We also provide a some command line based examples using state of the art models
 performance larger than all publicly available 13b models as of 2023-09-28.
 - [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
 - [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
+- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
+  (English/Chinese) general LLMs with 6b and 34b parameters.
 - [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of
   the LLaMA model using the same quantization techniques as
   [llama.cpp](https://github.com/ggerganov/llama.cpp).
@@ -137,16 +139,16 @@ And then head over to
 <!--- ANCHOR: useful_libraries --->
 
 ## Useful External Resources
-- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): a
+- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): A
   very detailed tutorial showing how to convert a PyTorch model to Candle.
-- [`optimisers`](https://github.com/KGrewal1/optimisers): a collection of optimisers
+- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): Efficient and ergonomic LoRA implementation for Candle. `candle-lora` has
+  out-of-the-box LoRA support for many models from Candle, which can be found [here](https://github.com/EricLBuehler/candle-lora/tree/master/candle-lora-transformers/examples).
+- [`optimisers`](https://github.com/KGrewal1/optimisers): A collection of optimisers
   including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
-- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): a LoRA implementation
-  that conforms to the official `peft` implementation.
 - [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
   serving local LLMs including an OpenAI compatible API server.
-- [`candle-ext`](https://github.com/mokeyish/candle-ext): an extension library to Candle that provides PyTorch functions not currently available in Candle.
-- [`kalosm`](https://github.com/floneum/floneum/tree/master/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
+- [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle.
+- [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
 - [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
 
 If you have an addition to this list, please submit a pull request.
@@ -174,8 +176,14 @@ If you have an addition to this list, please submit a pull request.
 - StableLM-3B-4E1T.
+- Replit-code-v1.5-3B.
 - Bert.
+- Yi-6B and Yi-34B.
 - Quantized LLMs.
   - Llama 7b, 13b, 70b, as well as the chat and code variants.
   - Mistral 7b, and 7b instruct.
+  - Zephyr 7b a and b (Mistral based).
+  - OpenChat 3.5 (Mistral based).
 - Text to text.
-  - T5 and its variants: FlanT5, MADLAD400 (translation), CoEdit (Grammar correction).
+  - T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
   - Marian MT (Machine Translation).
 - Whisper (multi-lingual support).
 - Text to image.
@@ -11,11 +11,11 @@ readme = "README.md"
 
 [dependencies]
 accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
-candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
-candle-nn = { path = "../candle-nn", version = "0.3.0" }
-candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
-candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
+candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
+candle-datasets = { path = "../candle-datasets", version = "0.3.1" }
+candle-nn = { path = "../candle-nn", version = "0.3.1" }
+candle-transformers = { path = "../candle-transformers", version = "0.3.1" }
+candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.1", optional = true }
 safetensors = { workspace = true }
 serde = { workspace = true }
 serde_json = { workspace = true }
@@ -12,8 +12,8 @@ readme = "README.md"
 [dependencies]
 accelerate-src = { workspace = true, optional = true }
 byteorder = { workspace = true }
-candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
-candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
+candle-kernels = { path = "../candle-kernels", version = "0.3.1", optional = true }
+candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.1", optional = true }
 metal = { workspace = true, optional = true}
 cudarc = { workspace = true, optional = true }
 gemm = { workspace = true }
@@ -8,11 +8,10 @@ use anyhow::Result;
 use candle_core::{Device, Tensor};
 
 fn main() -> Result<()> {
-    let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
-    let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
-    let start = std::time::Instant::now();
-    let res = inp.conv2d(&w, 0, 1, 1, 1)?;
-    println!("{:?}", start.elapsed());
-    println!("{res:?}");
+    let a = Tensor::new(&[[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]], &Device::Cpu)?;
+    let b = Tensor::new(&[[88.0f32, 99.0]], &Device::Cpu)?;
+    let new_a = a.slice_scatter(&b, 1, 2)?;
+    assert_eq!(a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
+    assert_eq!(new_a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
     Ok(())
 }
@@ -104,37 +104,31 @@ impl From<&Tensor> for TensorIndexer {
     }
 }
 
-macro_rules! impl_from_range {
-    ($range_type:ty) => {
-        impl From<$range_type> for TensorIndexer {
-            fn from(range: $range_type) -> Self {
-                use std::ops::Bound::*;
-
-                let start = match range.start_bound() {
-                    Included(idx) => Included(*idx),
-                    Excluded(idx) => Excluded(*idx),
-                    Unbounded => Unbounded,
-                };
-
-                let end = match range.end_bound() {
-                    Included(idx) => Included(*idx),
-                    Excluded(idx) => Excluded(*idx),
-                    Unbounded => Unbounded,
-                };
-
-                TensorIndexer::Narrow(start, end)
-            }
-        }
-    };
-}
-
-impl_from_range!(Range<usize>);
-impl_from_range!(RangeFrom<usize>);
-impl_from_range!(RangeFull);
-impl_from_range!(RangeInclusive<usize>);
-impl_from_range!(RangeTo<usize>);
-impl_from_range!(RangeToInclusive<usize>);
+trait RB: RangeBounds<usize> {}
+impl RB for Range<usize> {}
+impl RB for RangeFrom<usize> {}
+impl RB for RangeFull {}
+impl RB for RangeInclusive<usize> {}
+impl RB for RangeTo<usize> {}
+impl RB for RangeToInclusive<usize> {}
+
+impl<T: RB> From<T> for TensorIndexer {
+    fn from(range: T) -> Self {
+        use std::ops::Bound::*;
+        let start = match range.start_bound() {
+            Included(idx) => Included(*idx),
+            Excluded(idx) => Excluded(*idx),
+            Unbounded => Unbounded,
+        };
+        let end = match range.end_bound() {
+            Included(idx) => Included(*idx),
+            Excluded(idx) => Excluded(*idx),
+            Unbounded => Unbounded,
+        };
+        TensorIndexer::Narrow(start, end)
+    }
+}
 
 /// Trait used to implement multiple signatures for ease of use of the slicing
 /// of a tensor
 pub trait IndexOp<T> {
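The hunk above swaps the per-type `impl_from_range!` macro for one blanket `impl<T: RB> From<T>` over a private marker trait, so every std range form converts to a `TensorIndexer` through a single code path. A minimal usage sketch (not part of the diff, assuming the public `IndexOp::i` API):

```rust
use candle_core::{Device, IndexOp, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::arange(0u32, 12, &Device::Cpu)?.reshape((3, 4))?;
    // Each std range type goes through the same blanket From<T: RB> impl.
    let _rows = t.i(0..2)?; // Range
    let _tail = t.i(1..)?; // RangeFrom
    let _all = t.i(..)?; // RangeFull
    let _head = t.i(..=1)?; // RangeToInclusive
    Ok(())
}
```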
@@ -123,12 +123,6 @@ pub trait Module {
     fn forward(&self, xs: &Tensor) -> Result<Tensor>;
 }
 
-impl Module for quantized::QMatMul {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        self.forward(xs)
-    }
-}
-
 impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         self(xs)
File diff suppressed because it is too large.
@@ -593,7 +593,8 @@ unary_op!(Recip, "recip", v, v.recip());
 unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
 unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
 
-/// `gelu` operation
+/// Tanh based approximation of the `gelu` operation
+/// GeluErf is the more precise one.
 /// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
 impl UnaryOpT for Gelu {
     const NAME: &'static str = "gelu";
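For reference, the tanh-based approximation this doc comment points at is gelu(x) ≈ 0.5·x·(1 + tanh(√(2/π)·(x + 0.044715·x³))). A standalone sketch of that formula (plain Rust, independent of candle's kernels):

```rust
// Tanh-based gelu approximation; GeluErf computes the exact erf form instead.
fn gelu_tanh(x: f64) -> f64 {
    let c = (2.0 / std::f64::consts::PI).sqrt();
    0.5 * x * (1.0 + (c * (x + 0.044715 * x.powi(3))).tanh())
}

fn main() {
    // At x = 1.0 this yields about 0.8412 versus 0.8413 for the exact erf form.
    println!("{:.4}", gelu_tanh(1.0));
}
```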
@@ -307,8 +307,8 @@ impl crate::CustomOp1 for QTensor {
     }
 }
 
-impl QMatMul {
-    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+impl crate::Module for QMatMul {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         match self {
             Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
             Self::Tensor(w) => {
@@ -157,8 +157,6 @@ pub(crate) fn from_storage<S: Into<Shape>>(
 ) -> Tensor {
     let dtype = storage.dtype();
     let device = storage.device();
     let shape = shape.into();
-    // println!("{:?} {storage:?}", shape);
     let tensor_ = Tensor_ {
         id: TensorId::new(),
         storage: Arc::new(RwLock::new(storage)),
@@ -168,11 +166,7 @@ pub(crate) fn from_storage<S: Into<Shape>>(
         dtype,
         device,
     };
-    let result = Tensor(Arc::new(tensor_));
-    // todo!(" from_storage");
-    // let result = result.to_device(&Device::Cpu).unwrap();
-    // todo!(" {result}");
-    result
+    Tensor(Arc::new(tensor_))
 }
 
 impl Tensor {
@@ -862,6 +856,20 @@ impl Tensor {
         self.sum_impl(mean_dims, false)? * scale
     }
 
+    /// Returns the unbiased variance over the selected dimension.
+    pub fn var_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
+        let dim = dim.to_index(self.shape(), "var")?;
+        let mean = self.mean_keepdim(dim)?;
+        let squares = self.broadcast_sub(&mean)?.sqr()?;
+        squares.sum_impl(dim, true)? / (self.dim(dim)? - 1) as f64
+    }
+
+    /// Returns the unbiased variance over the selected dimension.
+    pub fn var<D: Dim>(&self, dim: D) -> Result<Self> {
+        let dim = dim.to_index(self.shape(), "var")?;
+        self.var_keepdim(dim)?.squeeze(dim)
+    }
+
     /// Gathers the maximum value across the selected dimension. The resulting shape has the same
     /// number of dimensions as the original tensor and the select dimension has a single element.
     pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
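The new `var_keepdim` implements the unbiased estimator Var(x) = Σᵢ(xᵢ − x̄)² / (n − 1), dividing by n − 1 rather than n. A tiny self-contained check of that arithmetic (not part of the diff):

```rust
// Unbiased variance as var_keepdim computes it: sum of squared deviations
// from the mean, divided by n - 1.
fn unbiased_var(xs: &[f64]) -> f64 {
    let n = xs.len() as f64;
    let mean = xs.iter().sum::<f64>() / n;
    xs.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (n - 1.0)
}

fn main() {
    // [1, 2, 3, 4]: mean 2.5, squared deviations sum to 5.0, so 5.0 / 3.
    assert!((unbiased_var(&[1.0, 2.0, 3.0, 4.0]) - 5.0 / 3.0).abs() < 1e-12);
}
```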
@@ -1855,7 +1863,10 @@ impl Tensor {
                 Storage::Metal(metal.storage_from_cpu_storage(storage)?)
             }
             (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
-            (Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
+            (Storage::Metal(storage), Device::Cpu) => {
+                // println!("{storage:?} - {:?}", storage.to_cpu_storage()?);
+                Storage::Cpu(storage.to_cpu_storage()?)
+            }
             (Storage::Cuda(storage), Device::Cuda(cuda)) => {
                 // TODO: Avoid passing through the cpu storage here, especially if the gpu ids
                 // are the same.
@@ -2446,6 +2457,110 @@ impl Tensor {
             Ok(naxis as usize)
         }
     }
 
+    /// Returns a lower triangular matrix of ones of size n by n.
+    pub fn tril2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
+        let t = Tensor::arange(0u32, n as u32, device)?;
+        let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
+        let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
+        t1.le(&t2)?.to_dtype(dtype)
+    }
+
+    /// Returns an upper triangular matrix of ones of size n by n.
+    pub fn triu2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
+        let t = Tensor::arange(0u32, n as u32, device)?;
+        let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
+        let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
+        t1.ge(&t2)?.to_dtype(dtype)
+    }
+
+    /// Returns a matrix with a diagonal of ones of size n by n.
+    pub fn eye(n: usize, dtype: DType, device: &Device) -> Result<Self> {
+        let t = Tensor::arange(0u32, n as u32, device)?;
+        let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
+        let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
+        t1.eq(&t2)?.to_dtype(dtype)
+    }
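These three helpers share one trick: broadcast a column-index grid `t1` (where `t1[i][j] = j`) against a row-index grid `t2` (where `t2[i][j] = i`) and compare elementwise, so for `tril2` the entry at (i, j) is 1 exactly when j ≤ i. A short check using the API added above (not part of the diff):

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    // t1.le(&t2) is 1 where j <= i: the lower triangle, diagonal included.
    let t = Tensor::tril2(3, DType::F32, &Device::Cpu)?;
    assert_eq!(
        t.to_vec2::<f32>()?,
        [[1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 1.0]]
    );
    Ok(())
}
```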
+
+    /// Returns the cumulative sum of elements of the input tensor summed over the specified
+    /// dimension.
+    ///
+    /// This operation is most efficient when dim is the last dimension of the tensor.
+    pub fn cumsum<D: Dim>(&self, dim: D) -> Result<Self> {
+        let dim = dim.to_index(self.shape(), "cumsum")?;
+        let rank = self.rank();
+        if rank == 0 {
+            return Ok(self.clone());
+        }
+        let n_axis = self.dim(dim)?;
+        let triu = Tensor::triu2(n_axis, self.dtype(), self.device())?;
+        if rank == 1 {
+            self.unsqueeze(0)?.matmul(&triu)?.squeeze(0)
+        } else {
+            let last = rank - 1;
+            let t = self.transpose(dim, last)?;
+            let t = t.broadcast_matmul(&triu)?;
+            t.transpose(dim, last)
+        }
+    }
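`cumsum` reduces to a matrix product: with U the upper-triangular ones matrix from `triu2`, column j of U has ones in rows 0..=j, so (x·U)ⱼ = Σ_{i≤j} xᵢ. For example [3, 1, 4]·U = [3, 4, 8]. A quick check against the new API (not part of the diff):

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    // Multiplying by an upper-triangular ones matrix sums every prefix:
    // [3, 1, 4] becomes [3, 3 + 1, 3 + 1 + 4] = [3, 4, 8].
    let x = Tensor::new(&[3f32, 1., 4.], &Device::Cpu)?;
    assert_eq!(x.cumsum(0)?.to_vec1::<f32>()?, [3., 4., 8.]);
    Ok(())
}
```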
+
+    /// Returns a copy of `self` where the values within `ranges` have been replaced with the
+    /// content of `src`.
+    pub fn slice_assign<D: std::ops::RangeBounds<usize>>(
+        &self,
+        ranges: &[D],
+        src: &Tensor,
+    ) -> Result<Self> {
+        let src_dims = src.dims();
+        let self_dims = self.dims();
+        if self_dims.len() != src_dims.len() {
+            crate::bail!(
+                "slice-assign requires input with the same rank {} <> {}",
+                self_dims.len(),
+                src_dims.len()
+            )
+        }
+        if self_dims.len() != ranges.len() {
+            crate::bail!(
+                "slice-assign requires input with the same rank as there are ranges {} <> {}",
+                self_dims.len(),
+                ranges.len()
+            )
+        }
+        let mut src = src.clone();
+        let mut mask = Self::ones(src.shape(), DType::U8, src.device())?;
+        for (i, range) in ranges.iter().enumerate() {
+            let start_included = match range.start_bound() {
+                std::ops::Bound::Unbounded => 0,
+                std::ops::Bound::Included(v) => *v,
+                std::ops::Bound::Excluded(v) => *v + 1,
+            };
+            let end_excluded = match range.end_bound() {
+                std::ops::Bound::Unbounded => self_dims[i],
+                std::ops::Bound::Included(v) => *v + 1,
+                std::ops::Bound::Excluded(v) => *v,
+            };
+            if end_excluded <= start_included {
+                crate::bail!(
+                    "slice-assign: empty range for dim {i}, {start_included} {end_excluded}"
+                )
+            }
+            if self_dims[i] < end_excluded {
+                crate::bail!(
+                    "slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}",
+                    self_dims[i]
+                )
+            }
+            if end_excluded - start_included != src_dims[i] {
+                crate::bail!(
+                    "slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i]
+                )
+            }
+            src = src.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?;
+            mask = mask.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?
+        }
+        mask.where_cond(/* on_true= */ &src, /* on_false= */ self)
+    }
 }
 
 macro_rules! bin_trait {
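`slice_assign` avoids in-place mutation: it zero-pads `src` and a same-shaped ones mask out to `self`'s full shape, then a single `where_cond` picks the padded `src` where the mask is 1 and `self` everywhere else. A minimal sketch of the resulting behaviour (not part of the diff):

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    // Write a 1x2 block of ones into the middle row of a 3x3 zero tensor.
    let t = Tensor::zeros((3, 3), DType::F32, &Device::Cpu)?;
    let src = Tensor::ones((1, 2), DType::F32, &Device::Cpu)?;
    let out = t.slice_assign(&[1..2, 0..2], &src)?;
    assert_eq!(
        out.to_vec2::<f32>()?,
        [[0., 0., 0.], [1., 1., 0.], [0., 0., 0.]]
    );
    Ok(())
}
```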
@@ -4,7 +4,7 @@ use crate::{Result, Tensor};
 macro_rules! test_device {
     // TODO: Switch to generating the two last arguments automatically once concat_idents is
     // stable. https://github.com/rust-lang/rust/issues/29599
-    ($fn_name: ident, $test_cpu: ident, $test_cuda: ident) => {
+    ($fn_name: ident, $test_cpu: ident, $test_cuda: ident, $test_metal: ident) => {
         #[test]
         fn $test_cpu() -> Result<()> {
             $fn_name(&Device::Cpu)
@@ -15,6 +15,12 @@ macro_rules! test_device {
         fn $test_cuda() -> Result<()> {
             $fn_name(&Device::new_cuda(0)?)
         }
+
+        #[cfg(feature = "metal")]
+        #[test]
+        fn $test_metal() -> Result<()> {
+            $fn_name(&Device::new_metal(0)?)
+        }
     };
 }
@@ -563,14 +563,35 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
     Ok(())
 }
 
-test_device!(conv1d, conv1d_cpu, conv1d_gpu);
-test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu);
-test_device!(conv2d, conv2d_cpu, conv2d_gpu);
+test_device!(conv1d, conv1d_cpu, conv1d_gpu, conv1d_metal);
+test_device!(
+    conv1d_small,
+    conv1d_small_cpu,
+    conv1d_small_gpu,
+    conv1d_small_metal
+);
+test_device!(conv2d, conv2d_cpu, conv2d_gpu, conv2d_metal);
 test_device!(
     conv2d_non_square,
     conv2d_non_square_cpu,
-    conv2d_non_square_gpu
+    conv2d_non_square_gpu,
+    conv2d_non_square_metal
 );
-test_device!(conv2d_small, conv2d_small_cpu, conv2d_small_gpu);
-test_device!(conv2d_smaller, conv2d_smaller_cpu, conv2d_smaller_gpu);
-test_device!(conv2d_grad, conv2d_grad_cpu, conv2d_grad_gpu);
+test_device!(
+    conv2d_small,
+    conv2d_small_cpu,
+    conv2d_small_gpu,
+    conv2d_small_metal
+);
+test_device!(
+    conv2d_smaller,
+    conv2d_smaller_cpu,
+    conv2d_smaller_gpu,
+    conv2d_smaller_metal
+);
+test_device!(
+    conv2d_grad,
+    conv2d_grad_cpu,
+    conv2d_grad_gpu,
+    conv2_grad_metal
+);
@@ -315,9 +315,29 @@ fn binary_grad(device: &Device) -> Result<()> {
     Ok(())
 }
 
-test_device!(simple_grad, simple_grad_cpu, simple_grad_gpu);
-test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu);
-test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu);
-test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu);
-test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu);
-test_device!(binary_grad, binary_grad_cpu, binary_grad_gpu);
+test_device!(
+    simple_grad,
+    simple_grad_cpu,
+    simple_grad_gpu,
+    simple_grad_metal
+);
+test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu, sum_grad_metal);
+test_device!(
+    matmul_grad,
+    matmul_grad_cpu,
+    matmul_grad_gpu,
+    matmul_grad_metal
+);
+test_device!(
+    grad_descent,
+    grad_descent_cpu,
+    grad_descent_gpu,
+    grad_descent_metal
+);
+test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu, unary_grad_metal);
+test_device!(
+    binary_grad,
+    binary_grad_cpu,
+    binary_grad_gpu,
+    binary_grad_metal
+);
@@ -91,3 +91,32 @@ fn index_3d() -> Result<()> {
     assert_eq!(tensor.i((1, .., 3))?.to_vec1::<u32>()?, &[15, 19, 23]);
     Ok(())
 }
+
+#[test]
+fn slice_assign() -> Result<()> {
+    let dev = Device::Cpu;
+
+    let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?;
+    let src = Tensor::arange(0u32, 2 * 3, &dev)?.reshape((3, 2))?;
+    let out = tensor.slice_assign(&[1..4, 3..5], &src)?;
+    assert_eq!(
+        out.to_vec2::<u32>()?,
+        &[
+            [0, 1, 2, 3, 4],
+            [5, 6, 7, 0, 1],
+            [10, 11, 12, 2, 3],
+            [15, 16, 17, 4, 5]
+        ]
+    );
+    let out = tensor.slice_assign(&[0..3, 0..2], &src)?;
+    assert_eq!(
+        out.to_vec2::<u32>()?,
+        &[
+            [0, 1, 2, 3, 4],
+            [2, 3, 7, 8, 9],
+            [4, 5, 12, 13, 14],
+            [15, 16, 17, 18, 19]
+        ]
+    );
+    Ok(())
+}
@@ -49,7 +49,7 @@ fn contiguous(device: &Device) -> Result<()> {
     Ok(())
 }
 
-test_device!(contiguous, contiguous_cpu, contiguous_gpu);
+test_device!(contiguous, contiguous_cpu, contiguous_gpu, contiguous_metal);
 
 #[test]
 fn strided_blocks() -> Result<()> {
@@ -98,15 +98,17 @@ fn upsample_nearest2d(dev: &Device) -> Result<()> {
     Ok(())
 }
 
-test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu);
+test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu, avg_pool2d_metal);
 test_device!(
     avg_pool2d_pytorch,
     avg_pool2d_pytorch_cpu,
-    avg_pool2d_pytorch_gpu
+    avg_pool2d_pytorch_gpu,
+    avg_pool2d_pytorch_metal
 );
-test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu);
+test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu, max_pool2d_metal);
 test_device!(
     upsample_nearest2d,
     upsample_nearest2d_cpu,
-    upsample_nearest2d_gpu
+    upsample_nearest2d_gpu,
+    upsample_nearest2d_metal
 );
@@ -1,7 +1,7 @@
 use candle_core::{
     quantized::{self, GgmlDType},
     test_utils::to_vec2_round,
-    Device, Result, Tensor,
+    Device, Module, Result, Tensor,
 };
 use quantized::{k_quants, GgmlType};
 use rand::prelude::*;
@@ -180,6 +180,22 @@ fn transpose(device: &Device) -> Result<()> {
     Ok(())
 }
 
+fn var(device: &Device) -> Result<()> {
+    // Values taken from https://pytorch.org/docs/stable/generated/torch.var.html
+    let data = &[
+        [0.2035f32, 1.2959, 1.8101, -0.4644],
+        [1.5027, -0.3270, 0.5905, 0.6538],
+        [-1.5745, 1.3330, -0.5596, -0.6548],
+        [0.1264, -0.5080, 1.6420, 0.1992],
+    ];
+    let tensor = Tensor::new(data, device)?;
+    assert_eq!(
+        test_utils::to_vec2_round(&tensor.var_keepdim(1)?, 4)?,
+        &[[1.0631], [0.559], [1.4893], [0.8258]]
+    );
+    Ok(())
+}
+
 fn sum(device: &Device) -> Result<()> {
     let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
     let tensor = Tensor::new(data, device)?;
@@ -1054,34 +1070,60 @@ fn randn(device: &Device) -> Result<()> {
     Ok(())
 }
 
-test_device!(zeros, zeros_cpu, zeros_gpu);
-test_device!(ones, ones_cpu, ones_gpu);
-test_device!(arange, arange_cpu, arange_gpu);
-test_device!(add_mul, add_mul_cpu, add_mul_gpu);
-test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
-test_device!(narrow, narrow_cpu, narrow_gpu);
-test_device!(broadcast, broadcast_cpu, broadcast_gpu);
-test_device!(cat, cat_cpu, cat_gpu);
-test_device!(sum, sum_cpu, sum_gpu);
-test_device!(min, min_cpu, min_gpu);
-test_device!(max, max_cpu, max_gpu);
-test_device!(argmax, argmax_cpu, argmax_gpu);
-test_device!(argmin, argmin_cpu, argmin_gpu);
-test_device!(transpose, transpose_cpu, transpose_gpu);
-test_device!(unary_op, unary_op_cpu, unary_op_gpu);
-test_device!(binary_op, binary_op_cpu, binary_op_gpu);
-test_device!(embeddings, embeddings_cpu, embeddings_gpu);
-test_device!(cmp, cmp_cpu, cmp_gpu);
-test_device!(matmul, matmul_cpu, matmul_gpu);
-test_device!(broadcast_matmul, broadcast_matmul_cpu, broadcast_matmul_gpu);
-test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu);
-test_device!(index_select, index_select_cpu, index_select_gpu);
-test_device!(index_add, index_add_cpu, index_add_gpu);
-test_device!(gather, gather_cpu, gather_gpu);
-test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
-test_device!(slice_scatter, slice_scatter_cpu, slice_scatter_gpu);
-test_device!(randn, randn_cpu, randn_gpu);
-test_device!(clamp, clamp_cpu, clamp_gpu);
+test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
+test_device!(ones, ones_cpu, ones_gpu, ones_metal);
+test_device!(arange, arange_cpu, arange_gpu, arange_metal);
+test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
+test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
+test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal);
+test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal);
+test_device!(cat, cat_cpu, cat_gpu, cat_metal);
+test_device!(sum, sum_cpu, sum_gpu, sum_metal);
+test_device!(min, min_cpu, min_gpu, min_metal);
+test_device!(max, max_cpu, max_gpu, max_metal);
+test_device!(argmax, argmax_cpu, argmax_gpu, argmax_metal);
+test_device!(argmin, argmin_cpu, argmin_gpu, argmin_metal);
+test_device!(transpose, transpose_cpu, transpose_gpu, transpose_metal);
+test_device!(unary_op, unary_op_cpu, unary_op_gpu, unary_op_metal);
+test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal);
+test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal);
+test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal);
+test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
+test_device!(
+    broadcast_matmul,
+    broadcast_matmul_cpu,
+    broadcast_matmul_gpu,
+    broadcast_matmul_metal
+);
+test_device!(
+    broadcasting,
+    broadcasting_cpu,
+    broadcasting_gpu,
+    broadcasting_metal
+);
+test_device!(
+    index_select,
+    index_select_cpu,
+    index_select_gpu,
+    index_select_metal
+);
+test_device!(index_add, index_add_cpu, index_add_gpu, index_add_metal);
+test_device!(gather, gather_cpu, gather_gpu, gather_metal);
+test_device!(
+    scatter_add,
+    scatter_add_cpu,
+    scatter_add_gpu,
+    scatter_add_metal
+);
+test_device!(
+    slice_scatter,
+    slice_scatter_cpu,
+    slice_scatter_gpu,
+    slice_scatter_metal
+);
+test_device!(randn, randn_cpu, randn_gpu, randn_metal);
+test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
+test_device!(var, var_cpu, var_gpu, var_metal);
 
 // There was originally a bug on the CPU implementation for randn
 // https://github.com/huggingface/candle/issues/381
@@ -1117,3 +1159,65 @@ fn i64_abs() -> Result<()> {
     assert_eq!(t.to_vec1::<i64>()?, [42, 1337]);
     Ok(())
 }
+
+#[test]
+fn tril_triu_eye() -> Result<()> {
+    let t = Tensor::tril2(4, DType::F32, &Device::Cpu)?;
+    assert_eq!(
+        t.to_vec2::<f32>()?,
+        [
+            [1.0, 0.0, 0.0, 0.0],
+            [1.0, 1.0, 0.0, 0.0],
+            [1.0, 1.0, 1.0, 0.0],
+            [1.0, 1.0, 1.0, 1.0]
+        ],
+    );
+    let t = Tensor::triu2(4, DType::F32, &Device::Cpu)?;
+    assert_eq!(
+        t.to_vec2::<f32>()?,
+        [
+            [1.0, 1.0, 1.0, 1.0],
+            [0.0, 1.0, 1.0, 1.0],
+            [0.0, 0.0, 1.0, 1.0],
+            [0.0, 0.0, 0.0, 1.0]
+        ]
+    );
+    let t = Tensor::eye(4, DType::F32, &Device::Cpu)?;
+    assert_eq!(
+        t.to_vec2::<f32>()?,
+        [
+            [1.0, 0.0, 0.0, 0.0],
+            [0.0, 1.0, 0.0, 0.0],
+            [0.0, 0.0, 1.0, 0.0],
+            [0.0, 0.0, 0.0, 1.0]
+        ]
+    );
+    Ok(())
+}
+
+#[test]
+fn cumsum() -> Result<()> {
+    let t = &[3f32, 1., 4., 1., 5.];
+    let t = Tensor::new(t, &Device::Cpu)?;
+    assert_eq!(t.cumsum(0)?.to_vec1::<f32>()?, [3., 4., 8., 9., 14.]);
+    let t = t.unsqueeze(1)?;
+    assert_eq!(
+        t.cumsum(0)?.to_vec2::<f32>()?,
+        [[3.0], [4.0], [8.0], [9.0], [14.0]]
+    );
+    assert_eq!(
+        t.cumsum(1)?.to_vec2::<f32>()?,
+        [[3.0], [1.0], [4.0], [1.0], [5.0]]
+    );
+    let t = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
+    let t = Tensor::new(t, &Device::Cpu)?;
+    assert_eq!(
+        t.cumsum(1)?.to_vec2::<f32>()?,
+        [[3.0, 4.0, 8.0, 9.0, 14.0], [2.0, 3.0, 10.0, 18.0, 20.0]],
+    );
+    assert_eq!(
+        t.cumsum(0)?.to_vec2::<f32>()?,
+        [[3.0, 1.0, 4.0, 1.0, 5.0], [5.0, 2.0, 11.0, 9.0, 7.0]]
+    );
+    Ok(())
+}
@@ -11,8 +11,8 @@ readme = "README.md"
 
 [dependencies]
 byteorder = { workspace = true }
-candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
-candle-nn = { path = "../candle-nn", version = "0.3.0" }
+candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
+candle-nn = { path = "../candle-nn", version = "0.3.1" }
 hf-hub = { workspace = true}
 intel-mkl-src = { workspace = true, optional = true }
 memmap2 = { workspace = true }
@@ -11,12 +11,12 @@ readme = "README.md"
 
 [dependencies]
 accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
-candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
-candle-nn = { path = "../candle-nn", version = "0.3.0" }
-candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
-candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
-candle-onnx = { path = "../candle-onnx", version = "0.3.0", optional = true }
+candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
+candle-datasets = { path = "../candle-datasets", version = "0.3.1" }
+candle-nn = { path = "../candle-nn", version = "0.3.1" }
+candle-transformers = { path = "../candle-transformers", version = "0.3.1" }
+candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.1", optional = true }
+candle-onnx = { path = "../candle-onnx", version = "0.3.1", optional = true }
 cudarc = { workspace = true, optional = true }
 half = { workspace = true, optional = true }
 image = { workspace = true }
@@ -57,6 +57,7 @@ flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
 mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
 nccl = ["cuda", "cudarc/nccl", "dep:half"]
 onnx = ["candle-onnx"]
+metal = ["candle/metal", "candle-nn/metal"]
 
 [[example]]
 name = "llama_multiprocess"
candle-examples/examples/distilbert/README.md (new file, 22 lines)

@@ -0,0 +1,22 @@
# candle-distilbert

DistilBert is a distilled version of the Bert model.

## Sentence embeddings

DistilBert is used to compute the sentence embeddings for a prompt. The model weights
are downloaded from the hub on the first run.

```bash
cargo run --example distilbert --release -- --prompt "Here is a test sentence"

> [[[ 0.5109,  0.1280, -0.2635, ...,  0.3462, -1.0434,  0.1441],
>   [ 0.1735,  0.0818, -0.5549, ...,  0.3472, -0.8264, -0.0244],
>   [ 0.0702, -0.1311, -0.4914, ...,  0.3483, -0.6194,  0.1829],
>   ...
>   [ 0.2993, -0.0106, -0.4640, ...,  0.2844, -0.6732,  0.0042],
>   [ 0.1066, -0.0081, -0.4299, ...,  0.3435, -0.7729,  0.0190],
>   [ 0.8903,  0.2055, -0.2541, ...,  0.3208, -0.6585,  0.0586]]]
> Tensor[[1, 7, 768], f32]

```
candle-examples/examples/distilbert/main.rs (new file, 135 lines)

@@ -0,0 +1,135 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle_transformers::models::distilbert::{Config, DistilBertModel, DTYPE};

use anyhow::{Error as E, Result};
use candle::{Device, Tensor};
use candle_nn::VarBuilder;
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
    #[arg(long)]
    model_id: Option<String>,

    #[arg(long)]
    revision: Option<String>,

    /// When set, compute embeddings for this prompt.
    #[arg(long)]
    prompt: String,

    /// Use the pytorch weights rather than the safetensors ones
    #[arg(long)]
    use_pth: bool,

    /// The number of times to run the prompt.
    #[arg(long, default_value = "1")]
    n: usize,

    /// L2 normalization for embeddings.
    #[arg(long, default_value = "true")]
    normalize_embeddings: bool,
}

impl Args {
    fn build_model_and_tokenizer(&self) -> Result<(DistilBertModel, Tokenizer)> {
        let device = candle_examples::device(self.cpu)?;
        let default_model = "distilbert-base-uncased".to_string();
        let default_revision = "main".to_string();
        let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
            (Some(model_id), Some(revision)) => (model_id, revision),
            (Some(model_id), None) => (model_id, "main".to_string()),
            (None, Some(revision)) => (default_model, revision),
            (None, None) => (default_model, default_revision),
        };

        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
        let (config_filename, tokenizer_filename, weights_filename) = {
            let api = Api::new()?;
            let api = api.repo(repo);
            let config = api.get("config.json")?;
            let tokenizer = api.get("tokenizer.json")?;
            let weights = if self.use_pth {
                api.get("pytorch_model.bin")?
            } else {
                api.get("model.safetensors")?
            };
            (config, tokenizer, weights)
        };
        let config = std::fs::read_to_string(config_filename)?;
        let config: Config = serde_json::from_str(&config)?;
        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

        let vb = if self.use_pth {
            VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
        } else {
            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
        };
        let model = DistilBertModel::load(vb, &config)?;
        Ok((model, tokenizer))
    }
}

fn get_mask(size: usize, device: &Device) -> Tensor {
    let mask: Vec<_> = (0..size)
        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
        .collect();
    Tensor::from_slice(&mask, (size, size), device).unwrap()
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        println!("tracing...");
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
    let device = &model.device;

    let tokenizer = tokenizer
        .with_padding(None)
        .with_truncation(None)
        .map_err(E::msg)?;
    let tokens = tokenizer
        .encode(args.prompt, true)
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
    let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
    let mask = get_mask(tokens.len(), device);

    println!("token_ids: {:?}", token_ids.to_vec2::<u32>());
    println!("mask: {:?}", mask.to_vec2::<u8>());

    let ys = model.forward(&token_ids, &mask)?;
    println!("{ys}");

    Ok(())
}

pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
    Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
}
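Two details in the example above are worth spelling out. `get_mask` builds a (size, size) matrix whose (i, j) entry is `u8::from(j > i)`, i.e. ones strictly above the diagonal, and `normalize_l2` divides each row by its L2 norm. A standalone sketch of the normalization arithmetic (not part of the diff):

```rust
fn main() {
    // L2 normalization as in normalize_l2: divide by sqrt(sum of squares).
    // For [3, 4] the norm is 5, giving [0.6, 0.8].
    let v = [3.0f64, 4.0];
    let norm = v.iter().map(|x| x * x).sum::<f64>().sqrt();
    let unit: Vec<f64> = v.iter().map(|x| x / norm).collect();
    assert_eq!(unit, vec![0.6, 0.8]);
}
```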
@@ -329,18 +329,14 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
         .get_ids()
         .to_vec();
 
-    println!("{tokens:?}");
-
     let start_gen = std::time::Instant::now();
-    for index in 0..1 {
+    for index in 0.. {
         if tokens.len() >= config.seq_len {
             break;
         }
         let context_size = if index > 0 { 1 } else { tokens.len() };
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
         let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
-        // println!("Input {}", input);
-        // println!("Input {}", input.to_device(&candle::Device::Cpu)?);
         let logits = model.forward(&input, index_pos)?;
         let logits = logits.i((0, logits.dim(1)? - 1))?;
         let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {
@@ -53,6 +53,8 @@ enum Which {
     Zephyr7bAlpha,
     #[value(name = "7b-zephyr-b")]
     Zephyr7bBeta,
+    #[value(name = "7b-open-chat-3.5")]
+    OpenChat35,
 }
 
 impl Which {
@@ -67,8 +69,10 @@ impl Which {
             | Self::L7bCode
             | Self::L13bCode
             | Self::L34bCode => false,
-            // Zephyr is a fine tuned version of mistral and should be treated in the same way.
-            Self::Zephyr7bAlpha
+            // Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
+            // same way.
+            Self::OpenChat35
+            | Self::Zephyr7bAlpha
             | Self::Zephyr7bBeta
             | Self::Mistral7b
             | Self::Mistral7bInstruct => true,
@@ -87,10 +91,30 @@ impl Which {
             | Self::L13bCode
             | Self::L34bCode
             | Self::Mistral7b
-            | Self::Mistral7bInstruct => false,
+            | Self::Mistral7bInstruct
+            | Self::OpenChat35 => false,
             Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
         }
     }
+
+    fn is_open_chat(&self) -> bool {
+        match self {
+            Which::L7b
+            | Which::L13b
+            | Which::L70b
+            | Which::L7bChat
+            | Which::L13bChat
+            | Which::L70bChat
+            | Which::L7bCode
+            | Which::L13bCode
+            | Which::L34bCode
+            | Which::Mistral7b
+            | Which::Mistral7bInstruct
+            | Which::Zephyr7bAlpha
+            | Which::Zephyr7bBeta => false,
+            Which::OpenChat35 => true,
+        }
+    }
 }
 
 #[derive(Parser, Debug)]
@@ -157,7 +181,9 @@ impl Args {
             Some(config) => std::path::PathBuf::from(config),
             None => {
                 let api = hf_hub::api::sync::Api::new()?;
-                let repo = if self.which.is_mistral() {
+                let repo = if self.which.is_open_chat() {
+                    "openchat/openchat_3.5"
+                } else if self.which.is_mistral() {
                     "mistralai/Mistral-7B-v0.1"
                 } else {
                     "hf-internal-testing/llama-tokenizer"
@@ -207,6 +233,7 @@ impl Args {
             Which::Zephyr7bBeta => {
                 ("TheBloke/zephyr-7B-beta-GGUF", "zephyr-7b-beta.Q4_K_M.gguf")
             }
+            Which::OpenChat35 => ("TheBloke/openchat_3.5-GGUF", "openchat_3.5.Q4_K_M.gguf"),
         };
         let api = hf_hub::api::sync::Api::new()?;
         let api = api.model(repo.to_string());
@@ -308,7 +335,8 @@ fn main() -> anyhow::Result<()> {
             | Which::Zephyr7bAlpha
             | Which::Zephyr7bBeta
             | Which::L70b
-            | Which::L70bChat => 8,
+            | Which::L70bChat
+            | Which::OpenChat35 => 8,
         };
         ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
     }
@@ -325,10 +353,11 @@ fn main() -> anyhow::Result<()> {
     };
 
     let mut pre_prompt_tokens = vec![];
-    loop {
+    for prompt_index in 0.. {
         let prompt_str = match &prompt {
            Prompt::One(prompt) => prompt.clone(),
            Prompt::Interactive | Prompt::Chat => {
+                let is_interactive = matches!(prompt, Prompt::Interactive);
                print!("> ");
                std::io::stdout().flush()?;
                let mut prompt = String::new();
@@ -339,8 +368,14 @@ fn main() -> anyhow::Result<()> {
                         prompt.pop();
                     }
                 }
-                if args.which.is_zephyr() {
-                    format!("<|system|>\n</s>\n<|user|>\n{prompt}</s>\n<|assistant|>")
+                if args.which.is_open_chat() {
+                    format!("User: {prompt}<|end_of_turn|>Assistant: ")
+                } else if args.which.is_zephyr() {
+                    if prompt_index == 0 || is_interactive {
+                        format!("<|system|>\n</s>\n<|user|>\n{prompt}</s>\n<|assistant|>",)
+                    } else {
+                        format!("<|user|>\n{prompt}</s>\n<|assistant|>")
+                    }
                 } else if args.which.is_mistral() {
                     format!("[INST] {prompt} [/INST]")
                 } else {
@@ -385,8 +420,12 @@ fn main() -> anyhow::Result<()> {
             std::io::stdout().flush()?;
         }
 
-        let eos_token = *tos.tokenizer().get_vocab(true).get("</s>").unwrap();
-
+        let eos_token = if args.which.is_open_chat() {
+            "<|end_of_turn|>"
+        } else {
+            "</s>"
+        };
+        let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap();
         let start_post_prompt = std::time::Instant::now();
         let mut sampled = 0;
         for index in 0..to_sample {
@@ -416,7 +416,7 @@ fn run(args: Args) -> Result<()> {
 
     println!("Building the autoencoder.");
     let vae_weights = ModelFile::Vae.get(vae_weights, sd_version, use_f16)?;
-    let vae = sd_config.build_vae(&vae_weights, &device, dtype)?;
+    let vae = sd_config.build_vae(vae_weights, &device, dtype)?;
     let init_latent_dist = match &img2img {
         None => None,
         Some(image) => {
@@ -426,7 +426,7 @@ fn run(args: Args) -> Result<()> {
     };
     println!("Building the unet.");
     let unet_weights = ModelFile::Unet.get(unet_weights, sd_version, use_f16)?;
-    let unet = sd_config.build_unet(&unet_weights, &device, 4, use_flash_attn, dtype)?;
+    let unet = sd_config.build_unet(unet_weights, &device, 4, use_flash_attn, dtype)?;
 
     let t_start = if img2img.is_some() {
         n_steps - (n_steps as f64 * img2img_strength) as usize
@@ -9,6 +9,8 @@ $ cargo run --example t5 --release -- --model-id "t5-small" --prompt "translate
 9 tokens generated (2.42 token/s)
 ```
 
+Variants such as [flan-t5](https://huggingface.co/google/flan-t5-small), [flan-ul2](https://huggingface.co/google/flan-ul2) (with `--revision "refs/pr/25"`), and [Co-EdIT](https://huggingface.co/grammarly/coedit-large) are also supported.
+
 ## Translation with [MADLAD-400](https://arxiv.org/abs/2309.04662)
 
 MADLAD-400 is a series of multilingual machine translation T5 models trained on 250 billion tokens covering over 450 languages using publicly available data. These models are competitive with significantly larger models.
@@ -22,7 +24,7 @@ cargo run --example t5 --release -- \
 Wie geht es dir, mein Freund?
 ```
 
-## Sentence embedding example:
+## Sentence embedding example
 
 ```bash
 $ cargo run --example t5 --release -- --model-id "t5-small" --prompt "A beautiful candle."
@@ -104,6 +104,17 @@ impl T5ModelBuilder {
                 api.get("model-00004-of-00005.safetensors")?,
                 api.get("model-00005-of-00005.safetensors")?,
             ]
+        } else if model_id == "google/flan-ul2" {
+            vec![
+                api.get("model-00001-of-00008.safetensors")?,
+                api.get("model-00002-of-00008.safetensors")?,
+                api.get("model-00003-of-00008.safetensors")?,
+                api.get("model-00004-of-00008.safetensors")?,
+                api.get("model-00005-of-00008.safetensors")?,
+                api.get("model-00006-of-00008.safetensors")?,
+                api.get("model-00007-of-00008.safetensors")?,
+                api.get("model-00008-of-00008.safetensors")?,
+            ]
         } else {
             vec![api.get("model.safetensors")?]
         };
candle-examples/examples/trocr/assets/trocr.png (new binary file, 36 KiB, not shown)
candle-examples/examples/trocr/image_processor.rs (new file, 154 lines)

@@ -0,0 +1,154 @@
use image::{DynamicImage, ImageBuffer};
use serde::Deserialize;
use std::collections::HashMap;

use candle::{DType, Device, Result, Tensor};

#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct ProcessorConfig {
    do_resize: bool,
    height: u32,
    width: u32,
    do_rescale: bool,
    do_normalize: bool,
    image_mean: Vec<f32>,
    image_std: Vec<f32>,
}

impl Default for ProcessorConfig {
    fn default() -> Self {
        Self {
            do_resize: true,
            height: 384,
            width: 384,
            do_rescale: true,
            do_normalize: true,
            image_mean: vec![0.5, 0.5, 0.5],
            image_std: vec![0.5, 0.5, 0.5],
        }
    }
}

pub struct ViTImageProcessor {
    do_resize: bool,
    height: u32,
    width: u32,
    do_normalize: bool,
    image_mean: Vec<f32>,
    image_std: Vec<f32>,
}

impl ViTImageProcessor {
    pub fn new(config: &ProcessorConfig) -> Self {
        Self {
            do_resize: config.do_resize,
            height: config.height,
            width: config.width,
            do_normalize: config.do_normalize,
            image_mean: config.image_mean.clone(),
            image_std: config.image_std.clone(),
        }
    }

    pub fn preprocess(&self, images: Vec<&str>) -> Result<Tensor> {
        let height = self.height as usize;
        let width = self.width as usize;
        let channels = 3;

        let images = self.load_images(images)?;

        let resized_images: Vec<DynamicImage> = if self.do_resize {
            images
                .iter()
                .map(|image| self.resize(image.clone(), None).unwrap())
                .collect()
        } else {
            images
        };

        let normalized_images: Vec<Tensor> = if self.do_normalize {
            resized_images
                .iter()
                .map(|image| self.normalize(image.clone(), None, None).unwrap())
                .collect()
        } else {
            let resized_images: Vec<ImageBuffer<image::Rgb<u8>, Vec<u8>>> =
                resized_images.iter().map(|image| image.to_rgb8()).collect();
            let data = resized_images
                .into_iter()
                .map(|image| image.into_raw())
                .collect::<Vec<Vec<u8>>>();

            data.iter()
                .map(|image| {
                    Tensor::from_vec(image.clone(), (height, width, channels), &Device::Cpu)
                        .unwrap()
                        .permute((2, 0, 1))
                        .unwrap()
                })
                .collect::<Vec<Tensor>>()
        };

        Tensor::stack(&normalized_images, 0)
    }

    fn resize(
        &self,
        image: image::DynamicImage,
        size: Option<HashMap<String, u32>>,
    ) -> Result<image::DynamicImage> {
        let (height, width) = match &size {
            Some(size) => (size.get("height").unwrap(), size.get("width").unwrap()),
            None => (&self.height, &self.width),
        };

        let resized_image =
            image.resize_exact(*width, *height, image::imageops::FilterType::Triangle);

        Ok(resized_image)
    }

    fn normalize(
        &self,
        image: image::DynamicImage,
        mean: Option<Vec<f32>>,
        std: Option<Vec<f32>>,
    ) -> Result<Tensor> {
        let mean = match mean {
            Some(mean) => mean,
            None => self.image_mean.clone(),
        };

        let std = match std {
            Some(std) => std,
            None => self.image_std.clone(),
        };

        let mean = Tensor::from_vec(mean, (3, 1, 1), &Device::Cpu)?;
        let std = Tensor::from_vec(std, (3, 1, 1), &Device::Cpu)?;

        let image = image.to_rgb8();
        let data = image.into_raw();

        let height = self.height as usize;
        let width = self.width as usize;
        let channels = 3;

        let data =
            Tensor::from_vec(data, &[height, width, channels], &Device::Cpu)?.permute((2, 0, 1))?;

        (data.to_dtype(DType::F32)? / 255.)?
            .broadcast_sub(&mean)?
            .broadcast_div(&std)
    }

    pub fn load_images(&self, image_path: Vec<&str>) -> Result<Vec<image::DynamicImage>> {
        let mut images: Vec<image::DynamicImage> = Vec::new();
        for path in image_path {
            let img = image::io::Reader::open(path)?.decode().unwrap();
            images.push(img);
        }

        Ok(images)
    }
}
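The `normalize` step above maps each u8 pixel p to (p/255 − mean)/std per channel; with mean = std = 0.5 that rescales [0, 255] to [−1, 1]. A small check of that mapping (not part of the diff):

```rust
fn main() {
    // (p / 255 - 0.5) / 0.5 sends 0 to -1.0, 255 to 1.0, and mid-gray near 0.
    let norm = |p: u8| (p as f32 / 255.0 - 0.5) / 0.5;
    assert_eq!(norm(0), -1.0);
    assert_eq!(norm(255), 1.0);
    assert!(norm(128).abs() < 0.01);
}
```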
candle-examples/examples/trocr/main.rs (new file, 132 lines)

@@ -0,0 +1,132 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Error as E;
use clap::{Parser, ValueEnum};

use candle::{DType, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::models::trocr;

use tokenizers::Tokenizer;
mod image_processor;

#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
    Base,
    Large,
}

#[derive(Parser, Debug)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    /// Choose the variant of the model to run.
    #[arg(long, default_value = "base")]
    which: Which,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Text to be translated
    #[arg(long)]
    image: String,
}

pub fn main() -> anyhow::Result<()> {
    use hf_hub::api::sync::Api;
    let args = Args::parse();

    let tokenizer_dec = {
        let tokenizer = Api::new()?
            .model(String::from("ToluClassics/candle-trocr-tokenizer"))
            .get("tokenizer.json")?;

        Tokenizer::from_file(&tokenizer).map_err(E::msg)?
    };

    let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);

    let device = candle_examples::device(args.cpu)?;

    let vb = {
        let model = match args.model {
            Some(model) => std::path::PathBuf::from(model),
            None => match args.which {
                Which::Base => Api::new()?
                    .repo(hf_hub::Repo::with_revision(
                        "microsoft/trocr-base-handwritten".to_string(),
                        hf_hub::RepoType::Model,
                        "refs/pr/3".to_string(),
                    ))
                    .get("model.safetensors")?,
                Which::Large => Api::new()?
                    .repo(hf_hub::Repo::with_revision(
                        "microsoft/trocr-large-handwritten".to_string(),
                        hf_hub::RepoType::Model,
                        "refs/pr/6".to_string(),
                    ))
                    .get("model.safetensors")?,
            },
        };
        println!("model: {:?}", model);
        unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? }
    };

    let encoder_config = match args.which {
        Which::Base => candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten(),
        Which::Large => {
            candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten()
        }
    };

    let decoder_config = trocr::TrOCRConfig::default();
    let mut model = trocr::TrOCRModel::new(&encoder_config, &decoder_config, vb)?;

    let config = image_processor::ProcessorConfig::default();
    let processor = image_processor::ViTImageProcessor::new(&config);

    let image = vec![args.image.as_str()];
    let image = processor.preprocess(image)?;

    let encoder_xs = model.encoder().forward(&image)?;

    let mut logits_processor =
        candle_transformers::generation::LogitsProcessor::new(1337, None, None);

    let mut token_ids: Vec<u32> = vec![decoder_config.decoder_start_token_id];
    for index in 0..1000 {
        let context_size = if index >= 1 { 1 } else { token_ids.len() };
        let start_pos = token_ids.len().saturating_sub(context_size);
        let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;

        let logits = model.decode(&input_ids, &encoder_xs, start_pos)?;

        let logits = logits.squeeze(0)?;
        let logits = logits.get(logits.dim(0)? - 1)?;
        let token = logits_processor.sample(&logits)?;
        token_ids.push(token);

        if let Some(t) = tokenizer_dec.next_token(token)? {
            use std::io::Write;
            print!("{t}");
            std::io::stdout().flush()?;
        }
        if token == decoder_config.eos_token_id {
            break;
        }
    }

    if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
        print!("{rest}");
    }
    println!();

    Ok(())
}
candle-examples/examples/trocr/readme.md (new file, 16 lines)

@@ -0,0 +1,16 @@
# candle-trocr

`TrOCR` is a transformer OCR model. In this example it is used to
transcribe image text. See the associated [model
card](https://huggingface.co/microsoft/trocr-base-printed) for details on
the model itself.

## Running an example

```bash
cargo run --example trocr --release -- --which base --cpu --image candle-examples/examples/trocr/assets/trocr.png
```

```
<s> industry , Mr. Brown commented icily . " Let us have a</s>
```
@@ -128,7 +128,13 @@ impl Decoder {
         let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
         let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
         let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
-        let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?;
+        let no_speech_token = m::NO_SPEECH_TOKENS
+            .iter()
+            .find_map(|token| token_id(&tokenizer, token).ok());
+        let no_speech_token = match no_speech_token {
+            None => anyhow::bail!("unable to find any non-speech token"),
+            Some(n) => n,
+        };
         Ok(Self {
             model,
             rng: rand::rngs::StdRng::seed_from_u64(seed),
@@ -512,11 +518,7 @@ fn main() -> Result<()> {
         )
     } else {
         let config = repo.get("config.json")?;
-        let tokenizer = if args.model == WhichModel::LargeV3 {
-            panic!("openai/whisper-large-v3 does not provide a compatible tokenizer.json config at the moment")
-        } else {
-            repo.get("tokenizer.json")?
-        };
+        let tokenizer = repo.get("tokenizer.json")?;
         let model = repo.get("model.safetensors")?;
         (config, tokenizer, model)
     };
candle-examples/examples/yi/main.rs (new file, 268 lines)
@ -0,0 +1,268 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};

use candle_transformers::models::yi::{Config, Model};

use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
    #[value(name = "6b")]
    L6b,
    #[value(name = "34b")]
    L34b,
}

struct TextGeneration {
    model: Model,
    device: Device,
    tokenizer: TokenOutputStream,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            tokenizer: TokenOutputStream::new(tokenizer),
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            device: device.clone(),
        }
    }

    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        self.tokenizer.clear();
        let mut tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        for &t in tokens.iter() {
            if let Some(t) = self.tokenizer.next_token(t)? {
                print!("{t}")
            }
        }
        std::io::stdout().flush()?;

        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
            Some(token) => token,
            None => anyhow::bail!("cannot find the <|endoftext|> token"),
        };
        let start_gen = std::time::Instant::now();
        for index in 0..sample_len {
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let start_pos = tokens.len().saturating_sub(context_size);
            let ctxt = &tokens[start_pos..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = self.model.forward(&input, start_pos)?;
            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token {
                break;
            }
            if let Some(t) = self.tokenizer.next_token(next_token)? {
                print!("{t}");
                std::io::stdout().flush()?;
            }
        }
        let dt = start_gen.elapsed();
        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
            print!("{rest}");
        }
        std::io::stdout().flush()?;
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        Ok(())
    }
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 100)]
    sample_len: usize,

    #[arg(long, default_value = "01-ai/Yi-6B")]
    model_id: String,

    #[arg(long, default_value = "main")]
    revision: String,

    #[arg(long)]
    tokenizer_file: Option<String>,

    #[arg(long)]
    weight_files: Option<String>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,

    /// The model size to use.
    #[arg(long, default_value = "6b")]
    which: Which,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let repo = api.repo(Repo::with_revision(
        args.model_id,
        RepoType::Model,
        args.revision,
    ));
    let tokenizer_filename = match args.tokenizer_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("tokenizer.json")?,
    };
    let filenames = match args.weight_files {
        Some(files) => files
            .split(',')
            .map(std::path::PathBuf::from)
            .collect::<Vec<_>>(),
        None => match args.which {
            Which::L6b => vec![
                repo.get("model-00001-of-00002.safetensors")?,
                repo.get("model-00002-of-00002.safetensors")?,
            ],
            Which::L34b => vec![
                repo.get("model-00001-of-00007.safetensors")?,
                repo.get("model-00002-of-00007.safetensors")?,
                repo.get("model-00003-of-00007.safetensors")?,
                repo.get("model-00004-of-00007.safetensors")?,
                repo.get("model-00005-of-00007.safetensors")?,
                repo.get("model-00006-of-00007.safetensors")?,
                repo.get("model-00007-of-00007.safetensors")?,
            ],
        },
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let config = match args.which {
        Which::L6b => Config::config_6b(),
        Which::L34b => Config::config_34b(),
    };
    let device = candle_examples::device(args.cpu)?;
    let dtype = if device.is_cuda() {
        DType::BF16
    } else {
        DType::F32
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
    let model = Model::new(&config, vb)?;

    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        &device,
    );
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}
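The generation loop dampens the logits of recently emitted tokens before sampling, via `candle_transformers::utils::apply_repeat_penalty`. A minimal standalone sketch of the usual llama.cpp-style rule (an assumption about the exact formula, not a copy of candle's implementation):

```rust
/// Sketch: dampen the scores of tokens that already occurred in `context`.
/// With penalty > 1, positive logits shrink and negative logits grow more
/// negative, making repeats less likely to be sampled.
fn repeat_penalty_sketch(logits: &mut [f32], penalty: f32, context: &[u32]) {
    for &token in context {
        if let Some(score) = logits.get_mut(token as usize) {
            if *score >= 0.0 {
                *score /= penalty;
            } else {
                *score *= penalty;
            }
        }
    }
}
```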
@ -43,6 +43,7 @@ pub fn report(
    confidence_threshold: f32,
    nms_threshold: f32,
) -> Result<DynamicImage> {
    let pred = pred.to_device(&Device::Cpu)?;
    let (npreds, pred_size) = pred.dims2()?;
    let nclasses = pred_size - 5;
    // The bounding boxes grouped by (maximum) class index.
@ -32,7 +32,7 @@ Image source:
### Pose Estimation
```bash
cargo run --example yolo-v8 --release -- \
  candle-examples/examples/yolo-v8/assets/peoples.jpeg --task pose
  candle-examples/examples/yolo-v8/assets/bike.jpg --task pose
```

![](./assets/bike.pose.jpg)
@ -7,7 +7,7 @@ extern crate accelerate_src;
mod model;
use model::{Multiples, YoloV8, YoloV8Pose};

use candle::{DType, IndexOp, Result, Tensor};
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::{Module, VarBuilder};
use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
use clap::{Parser, ValueEnum};
@ -61,6 +61,7 @@ pub fn report_detect(
    nms_threshold: f32,
    legend_size: u32,
) -> Result<DynamicImage> {
    let pred = pred.to_device(&Device::Cpu)?;
    let (pred_size, npreds) = pred.dims2()?;
    let nclasses = pred_size - 4;
    // The bounding boxes grouped by (maximum) class index.
@ -153,6 +154,7 @@ pub fn report_pose(
    confidence_threshold: f32,
    nms_threshold: f32,
) -> Result<DynamicImage> {
    let pred = pred.to_device(&Device::Cpu)?;
    let (pred_size, npreds) = pred.dims2()?;
    if pred_size != 17 * 3 + 4 + 1 {
        candle::bail!("unexpected pred-size {pred_size}");
@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.3.0"
version = "0.3.1"
edition = "2021"

description = "Flash attention layer for the candle ML framework."
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"

[dependencies]
candle = { path = "../candle-core", features = ["cuda"], version = "0.3.0", package = "candle-core" }
candle = { path = "../candle-core", features = ["cuda"], version = "0.3.1", package = "candle-core" }
half = { version = "2.3.1", features = ["num-traits"] }

[build-dependencies]
@ -21,4 +21,4 @@ rayon = "1.7.0"

[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
candle-nn = { path = "../candle-nn", version = "0.3.0", features = ["cuda"] }
candle-nn = { path = "../candle-nn", version = "0.3.1", features = ["cuda"] }
@ -233,8 +233,8 @@ impl FlashAttnVarLen {

        let (seqlens_q, seqlens_q_layout) = self.seqlens_q.storage_and_layout();
        let seqlens_q = match &*seqlens_q {
            candle::Storage::Cpu(_) => candle::bail!("seqlens_q must be a cuda tensor"),
            candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?, // Should be i32!
            _ => candle::bail!("seqlens_q must be a cuda tensor"),
        };
        let seqlens_q = match seqlens_q_layout.contiguous_offsets() {
            Some((o1, o2)) => seqlens_q.slice(o1..o2),
@ -243,8 +243,8 @@ impl FlashAttnVarLen {

        let (seqlens_k, seqlens_k_layout) = self.seqlens_k.storage_and_layout();
        let seqlens_k = match &*seqlens_k {
            candle::Storage::Cpu(_) => candle::bail!("seqlens_k must be a cuda tensor"),
            candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?, // Should be i32!
            _ => candle::bail!("seqlens_k must be a cuda tensor"),
        };
        let seqlens_k = match seqlens_k_layout.contiguous_offsets() {
            Some((o1, o2)) => seqlens_k.slice(o1..o2),
@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.3.0"
version = "0.3.1"
edition = "2021"

description = "CUDA kernels for Candle"
@ -14,4 +14,4 @@ license = "MIT OR Apache-2.0"
[build-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
glob = "0.3.1"
rayon = "1.7.0"
rayon = "1.7.0"
@ -1,17 +1,16 @@
[package]
name = "candle-metal-kernels"
version = "0.3.0"
version = "0.3.1"
edition = "2021"

description = "CUDA kernels for Candle"
description = "Metal kernels for Candle"
repository = "https://github.com/huggingface/candle"
keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT OR Apache-2.0"

[dependencies]
# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
metal = { path = "../../metal-rs", features = ["mps"] }
metal = { version = "0.27.1", features = ["mps"], package="candle-metal" }
once_cell = "1.18.0"
thiserror = "1"
tracing = "0.1.37"
@ -29,15 +29,96 @@ kernel void FN_NAME( \
    if (id >= dim) { \
        return; \
    } \
    const TYPENAME m = TYPENAME(mul); \
    const TYPENAME a = TYPENAME(add); \
    output[id] = input[id] * m + a; \
    output[id] = TYPENAME(float(input[id]) * mul + add); \
} \
kernel void FN_NAME##_strided( \
    constant size_t &dim, \
    constant size_t &num_dims, \
    constant size_t *dims, \
    constant size_t *strides, \
    constant float &mul, \
    constant float &add, \
    device const TYPENAME *input, \
    device TYPENAME *output, \
    uint id [[ thread_position_in_grid ]] \
) { \
    if (id >= dim) { \
        return; \
    } \
    output[id] = TYPENAME(float(input[get_strided_index(id, num_dims, dims, strides)]) * mul + add); \
}

#define POWF(FN_NAME, TYPENAME) \
kernel void FN_NAME( \
    constant size_t &dim, \
    constant float &mul, \
    device const TYPENAME *input, \
    device TYPENAME *output, \
    uint id [[ thread_position_in_grid ]] \
) { \
    if (id >= dim) { \
        return; \
    } \
    output[id] = TYPENAME(pow(input[id], TYPENAME(mul))); \
} \
kernel void FN_NAME##_strided( \
    constant size_t &dim, \
    constant size_t &num_dims, \
    constant size_t *dims, \
    constant size_t *strides, \
    constant float &mul, \
    device const TYPENAME *input, \
    device TYPENAME *output, \
    uint id [[ thread_position_in_grid ]] \
) { \
    if (id >= dim) { \
        return; \
    } \
    output[id] = TYPENAME(pow(input[get_strided_index(id, num_dims, dims, strides)], TYPENAME(mul))); \
}

#define ELU(FN_NAME, TYPENAME) \
kernel void FN_NAME( \
    constant size_t &dim, \
    constant float &mul, \
    device const TYPENAME *input, \
    device TYPENAME *output, \
    uint id [[ thread_position_in_grid ]] \
) { \
    if (id >= dim) { \
        return; \
    } \
    const TYPENAME x = input[id]; \
    output[id] = TYPENAME((x > 0)?x: mul * exp(x - 1)); \
} \
kernel void FN_NAME##_strided( \
    constant size_t &dim, \
    constant size_t &num_dims, \
    constant size_t *dims, \
    constant size_t *strides, \
    constant float &mul, \
    device const TYPENAME *input, \
    device TYPENAME *output, \
    uint id [[ thread_position_in_grid ]] \
) { \
    if (id >= dim) { \
        return; \
    } \
    const TYPENAME x = input[get_strided_index(id, num_dims, dims, strides)]; \
    output[id] = TYPENAME((x > 0)?x: mul * exp(x - 1)); \
} \


AFFINE(affine_float, float)
AFFINE(affine_half, half)
POWF(powf_float, float)
POWF(powf_half, half)
ELU(elu_float, float)
ELU(elu_half, half)


#if __METAL_VERSION__ >= 310
AFFINE(affine_bfloat, bfloat);
POWF(powf_bfloat, bfloat);
ELU(elu_bfloat, bfloat);
#endif
|
||||
constant size_t &dim, \
|
||||
device const LEFT_TYPENAME *input, \
|
||||
device RIGHT_TYPENAME *output, \
|
||||
uint thread_position_in_grid [[ thread_position_in_grid ]] \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (thread_position_in_grid >= dim) { \
|
||||
if (tid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[thread_position_in_grid] = RIGHT_TYPENAME(input[thread_position_in_grid]); \
|
||||
output[tid] = RIGHT_TYPENAME(input[tid]); \
|
||||
} \
|
||||
kernel void FN_NAME_STRIDED( \
|
||||
constant size_t &dim, \
|
||||
@ -37,15 +37,19 @@ kernel void FN_NAME_STRIDED( \
|
||||
constant size_t *strides, \
|
||||
device const LEFT_TYPENAME *input, \
|
||||
device RIGHT_TYPENAME *output, \
|
||||
uint i [[ thread_position_in_grid ]] \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (i >= dim) { \
|
||||
if (tid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \
|
||||
output[tid] = RIGHT_TYPENAME(input[get_strided_index(tid, num_dims, dims, strides)]); \
|
||||
} \
|
||||
|
||||
CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float)
|
||||
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
|
||||
CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
|
||||
CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
|
||||
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
|
||||
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
#endif
|
||||
|
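Both the contiguous and strided kernels lean on `get_strided_index`, which turns a flat element index into a storage offset using the tensor's dims and strides, walking dimensions from innermost to outermost. A Rust equivalent of that helper, for reference (a sketch mirroring the Metal code):

```rust
/// Sketch: map a linear element index to a strided storage offset,
/// consuming dimensions from innermost to outermost like the Metal helper.
fn get_strided_index(mut idx: usize, dims: &[usize], strides: &[usize]) -> usize {
    let mut strided_i = 0;
    for (&dim, &stride) in dims.iter().zip(strides.iter()).rev() {
        strided_i += (idx % dim) * stride;
        idx /= dim;
    }
    strided_i
}

// e.g. a transposed 2x3 view with strides [1, 2]:
// get_strided_index(1, &[2, 3], &[1, 2]) == 2
```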
@ -16,16 +16,16 @@ kernel void NAME( \
    if (gid >= dst_size) { \
        return; \
    } \
    const size_t id_i = gid / right_size / left_size; \
    const size_t id_i = (gid / right_size) % ids_size; \
    const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
    const size_t right_rank_i = gid % right_size; \
    const size_t left_rank_i = gid % left_size; \
    const size_t left_rank_i = gid / right_size / ids_size; \
    /* \
    // Force prevent out of bounds indexing \
    // since there doesn't seem to be a good way to force crash \
    // No need to check for zero we're only allowing unsized. \
    */ \
    const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
    const size_t src_i = ((input_i * right_size) + right_rank_i) * left_size + left_rank_i; \
    const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; \
    output[gid] = input[src_i]; \
}

@ -75,6 +75,7 @@ kernel void FN_NAME( \


INDEX_OP(is_u32_f32, uint, float)
INDEX_OP(is_u32_f16, uint, half)


#if __METAL_VERSION__ >= 310
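The corrected index math treats the destination as a `(left, ids, right)` nest: the output position decomposes as `gid = (left_i * ids_size + id_i) * right_size + right_i`, and the source offset replaces `id_i` with the looked-up row. A CPU reference in Rust that mirrors the fixed formula (a sketch for clarity, not candle's implementation):

```rust
/// Sketch of index-select along one dimension, matching the fixed kernel:
/// output[(l, i, r)] = input[(l, ids[i], r)], with the middle axis remapped.
fn index_select(
    input: &[f32],
    ids: &[usize],
    left_size: usize,    // product of dims before the indexed axis
    src_dim_size: usize, // size of the indexed axis in the source
    right_size: usize,   // product of dims after the indexed axis
) -> Vec<f32> {
    let mut output = Vec::with_capacity(left_size * ids.len() * right_size);
    for l in 0..left_size {
        for &idx in ids {
            let start = l * src_dim_size * right_size + idx * right_size;
            output.extend_from_slice(&input[start..start + right_size]);
        }
    }
    output
}
```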
(File diff suppressed because it is too large.)
@ -1,6 +1,8 @@
#include <metal_stdlib>
using namespace metal;

#define MAX(x, y) ((x) > (y) ? (x) : (y))

METAL_FUNC uint get_strided_index(
    uint idx,
    constant size_t &num_dims,
@ -16,18 +18,18 @@ METAL_FUNC uint get_strided_index(
    return strided_i;
}

constant int THREADGROUP_SIZE = 256;
constant int THREADGROUP_SIZE = 2048;

# define REDUCE(FN, NAME, TYPENAME) \
# define REDUCE(FN, NAME, T) \
kernel void NAME( \
    constant size_t &src_numel, \
    constant size_t &el_to_sum_per_block, \
    device const TYPENAME *src, \
    device TYPENAME *dst, \
    device const T *src, \
    device T *dst, \
    uint id [[ thread_position_in_grid ]], \
    uint tid [[ thread_index_in_threadgroup ]], \
    uint dst_id [[ threadgroup_position_in_grid ]], \
    uint blockDim [[ threads_per_threadgroup ]] \
    uint block_dim [[ threads_per_threadgroup ]] \
) { \
    \
    threadgroup float shared_memory[THREADGROUP_SIZE]; \
@ -45,10 +47,10 @@ kernel void NAME( \
    // TODO: Fast version for the contiguous case. \
    // size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
    */ \
    TYPENAME x = shared_memory[tid]; \
    TYPENAME y = src[idx]; \
    T x = shared_memory[tid]; \
    T y = src[idx]; \
    shared_memory[tid] = FN; \
    idx += blockDim; \
    idx += block_dim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
@ -56,10 +58,10 @@ kernel void NAME( \
/* \
// reduction in shared memory \
*/ \
for (uint s = blockDim / 2; s > 0; s >>= 1) { \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
    if (tid < s) { \
        TYPENAME x = shared_memory[tid]; \
        TYPENAME y = shared_memory[tid + s]; \
        T x = shared_memory[tid]; \
        T y = shared_memory[tid + s]; \
        shared_memory[tid] = FN; \
    } \
    threadgroup_barrier(mem_flags::mem_none); \
@ -68,72 +70,74 @@ kernel void NAME( \
    dst[dst_id] = shared_memory[0]; \
} \

kernel void softmax_float(
    constant size_t &src_numel,
    constant size_t &el_to_sum_per_block,
    device const float *src,
    device float *dst,
    uint id [[ thread_position_in_grid ]],
    uint tid [[ thread_index_in_threadgroup ]],
    uint dst_id [[ threadgroup_position_in_grid ]],
    uint blockDim [[ threads_per_threadgroup ]]
) {

    threadgroup float shared_memory[THREADGROUP_SIZE];

    shared_memory[tid] = -INFINITY;
    // Elements summed in this block range from dst_id * el_to_sum_per_block
    // to (dst_id + 1) * el_to_sum_per_block.
    size_t start_idx = dst_id * el_to_sum_per_block;
    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
    size_t idx = start_idx + tid;

    while (idx < stop_idx) {
        // TODO: Fast version for the contiguous case.
        shared_memory[tid] = max(shared_memory[tid], src[idx]);
        idx += blockDim;
    }

    threadgroup_barrier(mem_flags::mem_none);

    // reduction in shared memory
    for (uint s = blockDim / 2; s > 0; s >>= 1) {
        if (tid < s) {
            shared_memory[tid] = max(shared_memory[tid], shared_memory[tid + s]);
        }
        threadgroup_barrier(mem_flags::mem_none);
    }

    float max = shared_memory[0];

    shared_memory[tid] = 0;

    // Restart
    idx = start_idx + tid;
    while (idx < stop_idx) {
        // TODO: Fast version for the contiguous case.
        const float val = exp(src[idx] - max);
        dst[idx] = val;
        shared_memory[tid] += val;
        idx += blockDim;
    }
    // reduction in shared memory
    for (uint s = blockDim / 2; s > 0; s >>= 1) {
        if (tid < s) {
            shared_memory[tid] += shared_memory[tid + s];
        }
        threadgroup_barrier(mem_flags::mem_none);
    }

    const float inv_acc = 1/shared_memory[0];
    idx = start_idx + tid;
    while (idx < stop_idx) {
        dst[idx] *= inv_acc;
        idx += blockDim;
    }
}


REDUCE(x + y, fast_sum_float, float)
REDUCE(x * y, fast_mul_float, float)
REDUCE(max(x, y), fast_max_float, float)

#define SOFTMAX(NAME, T) \
kernel void NAME( \
    constant size_t &src_numel, \
    constant size_t &el_to_sum_per_block, \
    device const T *src, \
    device T *dst, \
    \
    uint id [[ thread_position_in_grid ]], \
    uint tid [[ thread_index_in_threadgroup ]], \
    uint dst_id [[ threadgroup_position_in_grid ]], \
    uint block_dim [[ threads_per_threadgroup ]] \
) { \
    threadgroup float shared_memory[THREADGROUP_SIZE]; \
    shared_memory[tid] = -INFINITY; \
    size_t start_idx = dst_id * el_to_sum_per_block; \
    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
    size_t idx = start_idx + tid; \
    \
    threadgroup_barrier(mem_flags::mem_threadgroup); \
    \
    while (idx < stop_idx) { \
        shared_memory[tid] = MAX(shared_memory[tid], src[idx]); \
        idx += block_dim; \
    } \
    \
    threadgroup_barrier(mem_flags::mem_threadgroup); \
    \
    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
        if (tid < s) { \
            shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
        } \
    } \
    \
    threadgroup_barrier(mem_flags::mem_threadgroup); \
    \
    float _max = shared_memory[0]; \
    \
    shared_memory[tid] = 0; \
    \
    idx = start_idx + tid; \
    while (idx < stop_idx) { \
        const T val = T(exp(src[idx] - _max)); \
        dst[idx] = val; \
        shared_memory[tid] += val; \
        idx += block_dim; \
    } \
    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
        if (tid < s) { \
            shared_memory[tid] += shared_memory[tid + s]; \
        } \
        threadgroup_barrier(mem_flags::mem_threadgroup); \
    } \
    \
    const T inv_acc = T(1/shared_memory[0]); \
    idx = start_idx + tid; \
    while (idx < stop_idx) { \
        dst[idx] *= inv_acc; \
        idx += block_dim; \
    } \
} \

SOFTMAX(softmax_float, float)
SOFTMAX(softmax_half, half)
#if __METAL_VERSION__ >= 310
SOFTMAX(softmax_bfloat, bfloat)
#endif
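The `SOFTMAX` macro computes a numerically stable softmax in three passes over each block: find the maximum, accumulate `exp(x - max)` while writing it out, then scale by the reciprocal of the sum. A single-threaded Rust reference of the same algorithm (a sketch for clarity, not candle's CPU path):

```rust
/// Sketch: numerically stable softmax over one row, mirroring the
/// max / exp-sum / normalize passes of the Metal kernel.
fn softmax_row(row: &mut [f32]) {
    let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut sum = 0.0f32;
    for x in row.iter_mut() {
        *x = (*x - max).exp();
        sum += *x;
    }
    let inv = 1.0 / sum;
    for x in row.iter_mut() {
        *x *= inv;
    }
}
```

Subtracting the row maximum before exponentiating keeps `exp` from overflowing without changing the result, since the factor `exp(-max)` cancels in the normalization.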
@ -32,6 +32,9 @@ kernel void FN_NAME( \
    device TYPENAME *out ,\
    uint i [[ thread_position_in_grid ]] \
) { \
    if (i >= numel){ \
        return; \
    } \
    uint strided_i = get_strided_index(i, num_dims, dims, strides); \
    uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \
    uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \
candle-metal-kernels/src/tests.rs (new file, 746 lines)
@ -0,0 +1,746 @@
use super::*;
use half::{bf16, f16};
use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger};

fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
    let options = MTLResourceOptions::StorageModeManaged;
    let ptr = data.as_ptr() as *const core::ffi::c_void;
    let size = (data.len() * std::mem::size_of::<T>()) as u64;
    device.new_buffer_with_data(ptr, size, options)
}

fn device() -> Device {
    Device::system_default().unwrap()
}

fn approx(v: Vec<f32>, digits: i32) -> Vec<f32> {
    let b = 10f32.powi(digits);
    v.iter().map(|t| f32::round(t * b) / b).collect()
}

fn approx_f16(v: Vec<f16>, digits: i32) -> Vec<f32> {
    let b = 10f32.powi(digits);
    v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
}

fn approx_bf16(v: Vec<bf16>, digits: i32) -> Vec<f32> {
    let b = 10f32.powi(digits);
    v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
}

fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
    let device = device();
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let input = new_buffer(&device, v);
    let output = new_buffer(&device, v);
    call_unary_contiguous(
        &device,
        command_buffer,
        &kernels,
        name,
        v.len(),
        &input,
        &output,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();
    output.read_to_vec::<T>(v.len())
}

fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> {
    let device = device();
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let options = MTLResourceOptions::StorageModeManaged;
    let left = new_buffer(&device, x);
    let right = new_buffer(&device, y);
    let output = device.new_buffer(std::mem::size_of_val(x) as u64, options);
    call_binary_contiguous(
        &device,
        command_buffer,
        &kernels,
        name,
        x.len(),
        &left,
        &right,
        &output,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();
    output.read_to_vec::<T>(x.len())
}

fn run_strided<T: Clone>(
    v: &[T],
    kernel: unary::strided::Kernel,
    shape: &[usize],
    strides: &[usize],
    offset: usize,
) -> Vec<T> {
    let device = device();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let input = new_buffer(&device, v);
    let output = new_buffer(&device, v);
    let kernels = Kernels::new();
    call_unary_strided(
        &device,
        command_buffer,
        &kernels,
        kernel,
        shape,
        &input,
        strides,
        offset,
        &output,
        0,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();
    output.read_to_vec::<T>(v.len())
}

#[test]
fn cos_f32() {
    let v = vec![1.0f32, 2.0, 3.0];
    let results = run(&v, unary::contiguous::cos::FLOAT);
    let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
    assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99]);
    assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]);

    let v = vec![1.0f32; 10_000];
    let results = run(&v, unary::contiguous::cos::FLOAT);
    let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
    assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
    assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
}

#[test]
fn cos_f32_strided() {
    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let shape = vec![6];
    let strides = vec![1];
    let offset = 0;
    let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
    let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
    assert_eq!(
        approx(results, 4),
        vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
    );
    assert_eq!(
        approx(expected, 4),
        vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
    );

    // Contiguous
    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let shape = vec![3, 2];
    let strides = vec![2, 1];
    let offset = 0;
    let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
    let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
    assert_eq!(
        approx(results, 4),
        vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
    );
    assert_eq!(
        approx(expected, 4),
        vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
    );

    // Transposed
    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let shape = vec![3, 2];
    let strides = vec![1, 3];
    let offset = 0;
    let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
    let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
    assert_eq!(
        approx(results, 4),
        vec![0.5403, -0.6536, -0.4161, 0.2837, -0.99, 0.9602]
    );
    assert_eq!(
        approx(expected, 4),
        vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
    );

    // Very large
    let v = vec![1.0f32; 10_000];
    let shape = vec![2, 5_000];
    let strides = vec![2, 1];
    let offset = 0;
    let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
    let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
    assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
    assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
}

#[test]
fn cos_strided_random() {
    let v: Vec<_> = (0..10_000).map(|_| rand::random::<f32>()).collect();
    let shape = vec![5_000, 2];
    let strides = vec![1, 5_000];
    let offset = 0;
    let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
    let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
    assert_eq!(approx(vec![results[0]], 4), approx(vec![expected[0]], 4));
    assert_eq!(
        approx(vec![results[1]], 4),
        approx(vec![expected[5_000]], 4)
    );
    assert_eq!(approx(vec![results[2]], 4), approx(vec![expected[1]], 4));
    assert_eq!(
        approx(vec![results[3]], 4),
        approx(vec![expected[5_001]], 4)
    );
    assert_eq!(
        approx(vec![results[5_000]], 4),
        approx(vec![expected[2_500]], 4)
    );
}

#[test]
fn gelu_f16() {
    let v: Vec<f16> = [-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0]
        .iter()
        .map(|v| f16::from_f32(*v))
        .collect();
    let expected: Vec<f32> = vec![-0.0, -0.16, 0.0, 0.84, 1.96, 3.0, 10.0, 20.0];
    let results = run(&v, unary::contiguous::gelu::HALF);
    assert_eq!(approx_f16(results, 2), expected);
}

#[test]
fn gelu_f32() {
    let v: Vec<f32> = vec![-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0];
    let expected: Vec<f32> = vec![-0.0, -0.159, 0.0, 0.841, 1.955, 2.996, 10.0, 20.0];
    let results = run(&v, unary::contiguous::gelu::FLOAT);
    assert_eq!(approx(results, 3), expected);
}

#[test]
fn binary_add_f32() {
    let left = vec![1.0f32, 2.0, 3.0];
    let right = vec![2.0f32, 3.1, 4.2];
    let results = run_binary(&left, &right, binary::contiguous::add::FLOAT);
    let expected: Vec<_> = left
        .iter()
        .zip(right.iter())
        .map(|(&x, &y)| x + y)
        .collect();
    assert_eq!(approx(results, 4), vec![3.0f32, 5.1, 7.2]);
    assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]);
}

fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
    let device = device();
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let input = new_buffer(&device, v);
    let options = MTLResourceOptions::StorageModeManaged;
    let size = (v.len() * std::mem::size_of::<U>()) as u64;
    let output = device.new_buffer(size, options);

    call_cast_contiguous(
        &device,
        command_buffer,
        &kernels,
        name,
        v.len(),
        &input,
        0,
        &output,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();
    output.read_to_vec::<U>(v.len())
}

#[test]
fn cast_u32_f32() {
    let v = vec![1u32, 2, 3];
    let results = cast(&v, "cast_u32_f32");
    let expected: Vec<_> = v.iter().map(|&v| v as f32).collect();
    assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]);
    assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]);

    let v = vec![1.0f32, 2.0, 3.0];
    let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
    let results: Vec<f32> = cast(&input, "cast_f16_f32");
    assert_eq!(results, vec![1.0f32, 2.0, 3.0]);

    let v = vec![1.0f32; 10_000];
    let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
    let results: Vec<f32> = cast(&input, "cast_f16_f32");
    assert_eq!(results.len(), 10_000);
    assert_eq!(&results[..10], vec![1.0f32; 10]);
    assert_eq!(results, vec![1.0f32; 10_000]);
}

fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
    let device = device();
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();

    let input = new_buffer(&device, v);
    let output = new_buffer(&device, v);

    let size = v.len();

    call_affine(
        &device,
        command_buffer,
        &kernels,
        "affine_float",
        size,
        &input,
        &output,
        mul as f32,
        add as f32,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();

    output.read_to_vec::<T>(v.len())
}

fn run_affine_strided<T: Clone>(
    v: &[T],
    shape: &[usize],
    strides: &[usize],
    mul: f64,
    add: f64,
) -> Vec<T> {
    let device = device();
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();

    let input = new_buffer(&device, v);
    let output = new_buffer(&device, v);

    call_affine_strided(
        &device,
        command_buffer,
        &kernels,
        "affine_float_strided",
        shape,
        &input,
        strides,
        0,
        &output,
        mul as f32,
        add as f32,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();

    let len: usize = shape.iter().product();
    output.read_to_vec::<T>(len)
}

#[test]
fn affine() {
    let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let mul = 1.5;
    let add = 1.1;
    let result = run_affine(&input, mul, add);
    assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]);

    let input = [1.0f32; 40_000];
    let mul = 1.5;
    let add = 1.1;
    let result = run_affine(&input, mul, add);
    assert_eq!(result, vec![2.6; 40_000]);
}

#[test]
fn affine_strided() {
    let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let mul = 1.5;
    let add = 1.1;
    let shape = [4];
    let strides = [2];
    let result = run_affine_strided(&input, &shape, &strides, mul, add);
    // 1 on 2
    assert_eq!(result, vec![2.6, 5.6, 8.6, 11.6]);
}

#[test]
fn index_select() {
    let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
    let shape = [5, 2];
    let ids = [0u32, 4, 2];
    let dim = 0;
    let result = run_index_select(&embedding, &shape, &ids, dim);
    assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]);

    let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
    let shape = [2, 5];
    let ids = [0u32, 1, 0];
    let dim = 0;
    let result = run_index_select(&embedding, &shape, &ids, dim);
    assert_eq!(
        result,
        vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
    );
}

#[test]
fn index_select_f16() {
    let embedding: Vec<_> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
        .into_iter()
        .map(|x| f16::from_f32(x))
        .collect();
    let shape = [5, 2];
    let ids = [0u32, 4, 2];
    let dim = 0;
    let result = run_index_select(&embedding, &shape, &ids, dim);
    assert_eq!(
        approx_f16(result, 4),
        vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
    );
}

#[test]
fn index_select_dim1() {
    let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
    let shape = [5, 2];
    let ids = [0u32, 1, 0];
    let dim = 1;
    let result = run_index_select(&embedding, &shape, &ids, dim);
    assert_eq!(
        result,
        vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0]
    );
}

fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
    embeddings: &[T],
    shape: &[usize],
    ids: &[I],
    dim: usize,
) -> Vec<T> {
    let device = Device::system_default().expect("no device found");

    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let embeddings_buffer = new_buffer(&device, &embeddings);
    let ids_buffer = new_buffer(&device, &ids);

    let left_size: usize = shape[..dim].iter().product();
    let right_size: usize = shape[dim + 1..].iter().product();
    let dst_el = ids.len() * left_size * right_size;
    let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);

    let name = match core::mem::size_of::<T>() {
        4 => "is_u32_f32",
        2 => "is_u32_f16",
        _ => unimplemented!(),
    };

    let kernels = Kernels::new();
    call_index_select(
        &device,
        &command_buffer,
        &kernels,
        name,
        shape,
        ids.len(),
        dim,
        &embeddings_buffer,
        &ids_buffer,
        &dst_buffer,
    )
    .unwrap();

    command_buffer.commit();
    command_buffer.wait_until_completed();

    dst_buffer.read_to_vec::<T>(dst_el)
}

#[test]
fn index_add() {
    let device = Device::system_default().expect("no device found");

    let options = CompileOptions::new();
    let library = device.new_library_with_source(INDEXING, &options).unwrap();

    let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
    let right = [1.0f32; 15];
    let index = [0u32, 4, 2];
    let ids_dim_size = index.len() as u32;
    let dst_dim_size: u32 = 15;
    let left_size: u32 = 3;
    let right_size: u32 = 3;

    let function = library.get_function("ia_u32_f32", None).unwrap();
    let pipeline = device
        .new_compute_pipeline_state_with_function(&function)
        .unwrap();

    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let encoder = command_buffer.new_compute_command_encoder();

    encoder.set_compute_pipeline_state(&pipeline);

    let index_buffer = new_buffer(&device, &index);
    let inputs_buffer = new_buffer(&device, &left);
    let outputs_buffer = new_buffer(&device, &right);

    set_params!(
        encoder,
        (
            &index_buffer,
            &inputs_buffer,
            &outputs_buffer,
            ids_dim_size,
            left_size,
            dst_dim_size,
            right_size
        )
    );

    let grid_size = MTLSize {
        width: right.len() as NSUInteger,
        height: 1,
        depth: 1,
    };

    let thread_group_size = MTLSize {
        width: pipeline.max_total_threads_per_threadgroup(),
        height: 1,
        depth: 1,
    };

    encoder.dispatch_thread_groups(grid_size, thread_group_size);
    encoder.end_encoding();
    command_buffer.commit();
    command_buffer.wait_until_completed();

    let expected = vec![
        2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
    ];
    let result = outputs_buffer.read_to_vec::<f32>(right.len());
    assert_eq!(result, expected);
}

#[test]
fn cos_f16() {
    let v: Vec<f16> = [1.0f32, 2.0, 3.0]
        .iter()
        .map(|v| f16::from_f32(*v))
        .collect();
    let results = run(&v, unary::contiguous::cos::HALF);
    let expected: Vec<f16> = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect();
    assert_eq!(approx_f16(results, 2), vec![0.54, -0.42, -0.99]);
    assert_eq!(approx_f16(expected, 2), vec![0.54, -0.42, -0.99]);
}

fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
    let device = device();
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let input = new_buffer(&device, v);

    let options = MTLResourceOptions::StorageModeManaged;
    let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
    call_reduce_contiguous(
        &device,
        command_buffer,
        &kernels,
        name,
        v.len(),
        out_length,
        &input,
        0,
        &output,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();

    output.read_to_vec::<T>(out_length)
}

fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'static str) -> Vec<T> {
    let device = device();
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let input = new_buffer(&device, v);
    let output = new_buffer(&device, v);
    call_last_softmax(
        &device,
        command_buffer,
        &kernels,
        name,
        v.len(),
        last_dim,
        &input,
        &output,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();

    output.read_to_vec::<T>(v.len())
}

#[test]
fn reduce_sum() {
    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let out_length = 1;

    let results = run_reduce(&v, out_length, "fast_sum_float");
    assert_eq!(approx(results, 4), vec![21.0]);
}

#[test]
fn reduce_sum2() {
    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let out_length = 2;

    let results = run_reduce(&v, out_length, "fast_sum_float");
    assert_eq!(approx(results, 4), vec![6.0, 15.0]);
}

#[test]
fn softmax() {
    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let last_dim = 6;
    let results = run_softmax(&v, last_dim, "softmax_float");
    assert_eq!(
        approx(results, 4),
        vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
    );

    let v = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0];
    let last_dim = 6;
    let results = run_softmax(&v, last_dim, "softmax_float");
    assert_eq!(
        approx(results, 4),
        vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
    );

    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let last_dim = 3;
    let results = run_softmax(&v, last_dim, "softmax_float");
    assert_eq!(
        approx(results, 4),
        vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652]
    );

    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
        .iter()
        .map(|v| f16::from_f32(*v))
        .collect::<Vec<_>>();
    let last_dim = 6;
    let results = run_softmax(&v, last_dim, "softmax_half");
    assert_eq!(
        approx_f16(results, 4),
        vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338]
    );

    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
        .iter()
        .map(|v| bf16::from_f32(*v))
        .collect::<Vec<_>>();
    let last_dim = 6;
    let results = run_softmax(&v, last_dim, "softmax_bfloat");
    assert_eq!(
        approx_bf16(results, 4),
        vec![0.0043, 0.0116, 0.0315, 0.0859, 0.2324, 0.6328]
    );
}

fn run_where_cond<I: Clone, T: Clone>(
    shape: &[usize],
    cond: &[I],
    (cond_stride, cond_offset): (Vec<usize>, usize),
    left_true: &[T],
    (left_stride, left_offset): (Vec<usize>, usize),
    right_false: &[T],
    (_right_stride, _right_offset): (Vec<usize>, usize),
    name: &'static str,
) -> Vec<T> {
    let device = device();
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let options = MTLResourceOptions::StorageModeManaged;

    let length = cond.len();
    let cond = device.new_buffer_with_data(
        cond.as_ptr() as *const core::ffi::c_void,
        std::mem::size_of_val(cond) as u64,
        options,
    );
    let left = device.new_buffer_with_data(
        left_true.as_ptr() as *const core::ffi::c_void,
        (length * core::mem::size_of::<T>()) as u64,
        options,
    );
    let right = device.new_buffer_with_data(
        right_false.as_ptr() as *const core::ffi::c_void,
        (length * core::mem::size_of::<T>()) as u64,
        options,
    );

    let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
    call_where_cond_strided(
        &device,
        command_buffer,
        &kernels,
        name,
        shape,
        &cond,
        (&cond_stride, cond_offset),
        &left,
        (&left_stride, left_offset),
        &right,
        (&cond_stride, cond_offset),
        &output,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();

    output.read_to_vec::<T>(length)
}

#[test]
fn where_cond() {
    let shape = vec![6];
    let cond = vec![0u8, 1, 0, 0, 1, 1];
    let cond_l = (vec![1], 0);
    let left_true = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let left_l = (vec![1], 0);
    let right_false = vec![-1.0f32, -2.0, -3.0, -4.0, -5.0, -6.0];
    let right_l = (vec![1], 0);
    let results = run_where_cond(
        &shape,
        &cond,
        cond_l,
        &left_true,
        left_l,
        &right_false,
        right_l,
        "where_u8_f32",
    );
    assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
}
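The `index_add` test exercises scatter-add semantics: each source row `i` is added into destination row `ids[i]`. A CPU sketch of the operation the `ia_u32_f32` kernel performs (an illustration, not the kernel itself):

```rust
/// Sketch: dst[ids[i], j] += src[i, j] for every source row i,
/// with `row` elements per row.
fn index_add(src: &[f32], ids: &[usize], dst: &mut [f32], row: usize) {
    for (i, &id) in ids.iter().enumerate() {
        for j in 0..row {
            dst[id * row + j] += src[i * row + j];
        }
    }
}
```

With `src` = rows `[1,2,3]`, `[4,5,6]`, `[7,8,9]`, `ids = [0, 4, 2]`, and `dst` full of ones, rows 0, 4, and 2 become `[2,3,4]`, `[5,6,7]`, and `[8,9,10]`, which is exactly the expected vector in the test above.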
@ -1,4 +1,7 @@
#include <metal_stdlib>
#include <metal_math>
#
using namespace metal;

METAL_FUNC uint get_strided_index(
    uint idx,
@ -17,10 +20,44 @@ METAL_FUNC uint get_strided_index(

template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
template <typename T> METAL_FUNC T neg(T in){ return -in; }
template <typename T> METAL_FUNC T id(T in){ return in; }
template <typename T> METAL_FUNC T erf(T in){
    float x = (float) in;
    // constants
    float a1 = 0.254829592;
    float a2 = -0.284496736;
    float a3 = 1.421413741;
    float a4 = -1.453152027;
    float a5 = 1.061405429;
    float p = 0.3275911;

    // Save the sign of x
    int sign = 1;
    if (x < 0)
        sign = -1;
    x = fabs(x);

    // A&S formula 7.1.26
    float t = 1.0/(1.0 + p*x);
    float y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x);

    return T(sign*y);
}
template <typename T> METAL_FUNC T id(T in) { return in; }
template <typename T> METAL_FUNC T gelu_erf(T x) {
    return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2);
}
template <typename T> METAL_FUNC T gelu(T x) {
    if (x > 5) {
        return x;
    }
    T x_sq = x * x;
    T x_cube = x_sq * x;
    T alpha = x + static_cast<T>(0.044715) * x_cube;
    T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
    return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
}


using namespace metal;

#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
@ -32,7 +69,7 @@ kernel void FN_NAME( \
    if (thread_position_in_grid >= dim) { \
        return; \
    } \
    output[thread_position_in_grid] = TYPENAME(FN(input[thread_position_in_grid])); \
    output[thread_position_in_grid] = TYPENAME(FN(float(input[thread_position_in_grid]))); \
}\
kernel void FN_NAME_STRIDED( \
    constant size_t &dim, \
@ -46,7 +83,7 @@ kernel void FN_NAME_STRIDED( \
    if (thread_position_in_grid >= dim) { \
        return; \
    } \
    output[thread_position_in_grid] = TYPENAME(FN(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)])); \
    output[thread_position_in_grid] = TYPENAME(FN(float(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)]))); \
}

#define UNARY_OP(NAME) \
@ -63,8 +100,18 @@ UNARY_OP(sqr)
UNARY_OP(sqrt)
UNARY_OP(neg)
UNARY_OP(exp)
UNARY_OP(log)
UNARY_OP(gelu)
UNARY_OP(ceil)
UNARY_OP(floor)
UNARY_OP(round)
UNARY_OP(gelu_erf)
UNARY_OP(erf)
UNARY_OP(tanh)
UNARY(id, float, copy_float, copy_float_strided)
UNARY(id, half, copy_half, copy_half_strided)
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
UNARY(id, uint32_t, copy_u32, copy_u32_strided)

#if __METAL_VERSION__ >= 310
BFLOAT_UNARY_OP(cos)
@ -73,6 +120,14 @@ BFLOAT_UNARY_OP(sqr)
BFLOAT_UNARY_OP(sqrt)
BFLOAT_UNARY_OP(neg)
BFLOAT_UNARY_OP(exp)
BFLOAT_UNARY_OP(log)
BFLOAT_UNARY_OP(gelu)
BFLOAT_UNARY_OP(ceil)
BFLOAT_UNARY_OP(floor)
BFLOAT_UNARY_OP(round)
BFLOAT_UNARY_OP(gelu_erf)
BFLOAT_UNARY_OP(erf)
BFLOAT_UNARY_OP(tanh)

UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)
#endif
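The `erf` template ports the Abramowitz & Stegun formula 7.1.26 polynomial approximation, whose maximum absolute error is about 1.5e-7. The same approximation in Rust, for reference (a sketch mirroring the constants used in the kernel):

```rust
/// Sketch: Abramowitz & Stegun 7.1.26 approximation of erf(x),
/// matching the constants in the Metal kernel above.
fn erf_approx(x: f32) -> f32 {
    let (a1, a2, a3, a4, a5) =
        (0.254829592, -0.284496736, 1.421413741, -1.453152027, 1.061405429);
    let p = 0.3275911;
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();
    let t = 1.0 / (1.0 + p * x);
    let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();
    sign * y
}
```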
@ -50,6 +50,7 @@ fn run_affine_bench<T: Clone>(device: &Device, kernels: &Kernels, v: &[T]) {
        &device,
        command_buffer,
        &kernels,
        "affine_float",
        v.len(),
        &input,
        &mut output,
@ -147,7 +147,7 @@ fn run_unary_bench<T: Clone>(
    println!(
        "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
        type_name::<T>().split("::").last().unwrap(),
        kernel_name.to_string(),
        kernel_name.0,
        v.len(),
        iterations,
        total_time,
@ -159,7 +159,7 @@ fn run_unary_bench<T: Clone>(
    let shape = vec![2, 5_000];
    let strides = vec![2, 1];
    let offset = 0;
    for kernel_name in strided {
    for kernel_name in &strided {
        let total_time = autoreleasepool(|| {
            let command_buffer = command_queue.new_command_buffer();
            let start = Instant::now();
@ -187,7 +187,7 @@ fn run_unary_bench<T: Clone>(
    println!(
        "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
        type_name::<T>().split("::").last().unwrap(),
        kernel_name.to_string(),
        kernel_name.0,
        v.len(),
        iterations,
        total_time,
@ -11,7 +11,7 @@ readme = "README.md"

[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
half = { workspace = true }
thiserror = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
@ -19,6 +19,8 @@ num-traits = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
metal = { workspace = true, optional = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }

[dev-dependencies]
anyhow = { workspace = true }
@ -29,3 +31,4 @@ default = []
accelerate = ["dep:accelerate-src", "candle/accelerate"]
cuda = ["candle/cuda"]
mkl = ["dep:intel-mkl-src", "candle/mkl"]
metal = ["candle/metal", "dep:candle-metal-kernels", "dep:metal"]
@ -6,7 +6,7 @@ extern crate intel_mkl_src;
extern crate accelerate_src;

use candle::quantized::GgmlType;
use candle::{CpuStorage, Device, Layout, Result, Shape, Tensor, D};
use candle::{CpuStorage, Device, Layout, Module, Result, Shape, Tensor, D};
use clap::{Parser, Subcommand};

const CHECK_CONV2D: bool = false;
@ -6,7 +6,6 @@ use serde::Deserialize;
pub enum Activation {
    #[default]
    Gelu,
    #[serde(rename = "gated-gelu")]
    NewGelu,
    Relu,
    Relu2,
@ -9,7 +9,6 @@ pub struct Embedding {

impl Embedding {
    pub fn new(embeddings: Tensor, hidden_size: usize) -> Self {
        // todo!("Embedding {embeddings}");
        Self {
            embeddings,
            hidden_size,
@ -201,6 +201,48 @@ impl candle::CustomOp1 for SoftmaxLastDim {
        };
        Ok((dst, layout.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        storage: &candle::MetalStorage,
        layout: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::{backend::BackendStorage, DType};
        let device = storage.device();
        let command_buffer = device.command_buffer();
        let kernels = device.kernels();
        let name = match storage.dtype() {
            DType::F32 => "softmax_float",
            DType::F16 => "softmax_half",
            DType::BF16 => "softmax_bfloat",
            dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"),
        };

        let n = layout.stride().len();
        if !(layout.stride()[n - 1] == 1 && layout.start_offset() == 0) {
            candle::bail!("Non contiguous softmax-last-dim is not implemented");
        }

        let last_dim = layout.dims()[layout.shape().rank() - 1];
        let elem_count = layout.shape().elem_count();
        let mut output = device.new_buffer(elem_count, storage.dtype(), "softmax");
        candle_metal_kernels::call_last_softmax(
            device.metal_device(),
            &command_buffer,
            &kernels,
            name,
            elem_count,
            last_dim,
            storage.buffer(),
            &mut output,
        )
        .unwrap();
        command_buffer.commit();
        output.did_modify_range(metal::NSRange::new(0, output.length()));
        let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
        Ok((newstorage, layout.shape().clone()))
    }
}

pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
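For reference, a minimal CPU-side sketch of exercising the op that this metal_fwd backs; softmax_last_dim is the public entry point defined just above:

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let xs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
    // Each row of the result sums to 1.
    let ys = candle_nn::ops::softmax_last_dim(&xs)?;
    println!("{ys}");
    Ok(())
}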
@ -1,6 +1,6 @@
[package]
name = "candle-onnx"
version = "0.3.0"
version = "0.3.1"
edition = "2021"

description = "ONNX support for Candle"
@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0"

[dependencies]
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.1" }
prost = "0.12.1"

[build-dependencies]
@ -741,6 +741,25 @@ pub fn simple_eval(
            let output = input.to_dtype(dtype)?;
            values.insert(node.output[0].clone(), output);
        }
        // https://github.com/onnx/onnx/blob/main/docs/Operators.md#CumSum
        "CumSum" => {
            let exclusive = get_attr_opt::<i64>(node, "exclusive")?
                .copied()
                .unwrap_or(0);
            let reverse = get_attr_opt::<i64>(node, "reverse")?.copied().unwrap_or(0);
            if exclusive != 0 {
                bail!("only exclusive == 0 is supported in CumSum")
            }
            if reverse != 0 {
                bail!("only reverse == 0 is supported in CumSum")
            }
            let input = get(&node.input[0])?;
            let axis = get(&node.input[1])?
                .to_dtype(DType::U32)?
                .to_vec0::<u32>()?;
            let output = input.cumsum(axis as usize)?;
            values.insert(node.output[0].clone(), output);
        }
        op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
    }
}
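The CumSum branch above bottoms out in Tensor::cumsum; a minimal sketch of the underlying op (inclusive scan, matching the supported exclusive == 0 / reverse == 0 case):

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let xs = Tensor::new(&[1f32, 2., 3., 4.], &Device::Cpu)?;
    // Inclusive cumulative sum along axis 0: [1, 3, 6, 10].
    let ys = xs.cumsum(0)?;
    println!("{ys}");
    Ok(())
}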
746
candle-onnx/tests/ops.rs
Normal file
@ -0,0 +1,746 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle::{Device, Result, Tensor};
use candle_onnx::onnx::{GraphProto, ModelProto, NodeProto, ValueInfoProto};
use std::collections::HashMap;

const INPUT_X: &str = "x";
const INPUT_Y: &str = "y";
const OUTPUT_Z: &str = "z";

fn create_model_proto_with_graph(graph: Option<GraphProto>) -> ModelProto {
    ModelProto {
        metadata_props: vec![],
        training_info: vec![],
        functions: vec![],
        ir_version: 0,
        opset_import: vec![],
        producer_name: "".to_string(),
        producer_version: "".to_string(),
        domain: "".to_string(),
        model_version: 0,
        doc_string: "".to_string(),
        graph,
    }
}

#[test]
fn test_evaluation_fails_without_defined_graph() -> Result<()> {
    let manual_graph = create_model_proto_with_graph(None);

    let inputs: HashMap<String, Tensor> = HashMap::new();

    match candle_onnx::simple_eval(&manual_graph, inputs) {
        Err(err) => assert_eq!(err.to_string(), "no graph defined in proto"),
        Ok(_) => panic!("Expected an error due to undefined graph"),
    }

    Ok(())
}

// "Add"
#[test]
fn test_add_operation() -> Result<()> {
    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: "Add".to_string(),
            domain: "".to_string(),
            attribute: vec![],
            input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            name: "".to_string(),
            doc_string: "".to_string(),
        }],
        name: "".to_string(),
        initializer: vec![],
        input: vec![],
        output: vec![ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            doc_string: "".to_string(),
            r#type: None,
        }],
        value_info: vec![],
        doc_string: "".to_string(),
        sparse_initializer: vec![],
        quantization_annotation: vec![],
    }));

    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?);
    inputs.insert(INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
    let first = z
        .to_vec1::<f64>()?
        .to_vec()
        .get(0)
        .expect("Failed to get first element")
        .clone();
    assert_eq!(first, 4.0f64);

    Ok(())
}

// "Sub"
#[test]
fn test_sub_operation() -> Result<()> {
    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: "Sub".to_string(),
            domain: "".to_string(),
            attribute: vec![],
            input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            name: "".to_string(),
            doc_string: "".to_string(),
        }],
        name: "".to_string(),
        initializer: vec![],
        input: vec![],
        output: vec![ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            doc_string: "".to_string(),
            r#type: None,
        }],
        value_info: vec![],
        doc_string: "".to_string(),
        sparse_initializer: vec![],
        quantization_annotation: vec![],
    }));

    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?);
    inputs.insert(INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
    let first = z
        .to_vec1::<f64>()?
        .to_vec()
        .get(0)
        .expect("Failed to get first element")
        .clone();
    assert_eq!(first, 0.0f64);

    Ok(())
}

// "Mul"
#[test]
fn test_mul_operation() -> Result<()> {
    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: "Mul".to_string(),
            domain: "".to_string(),
            attribute: vec![],
            input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            name: "".to_string(),
            doc_string: "".to_string(),
        }],
        name: "".to_string(),
        initializer: vec![],
        input: vec![],
        output: vec![ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            doc_string: "".to_string(),
            r#type: None,
        }],
        value_info: vec![],
        doc_string: "".to_string(),
        sparse_initializer: vec![],
        quantization_annotation: vec![],
    }));

    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?);
    inputs.insert(INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
    let first = z
        .to_vec1::<f64>()?
        .to_vec()
        .get(0)
        .expect("Failed to get first element")
        .clone();
    assert_eq!(first, 4.0f64);

    Ok(())
}

// "Div"
#[test]
fn test_div_operation() -> Result<()> {
    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: "Div".to_string(),
            domain: "".to_string(),
            attribute: vec![],
            input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            name: "".to_string(),
            doc_string: "".to_string(),
        }],
        name: "".to_string(),
        initializer: vec![],
        input: vec![],
        output: vec![ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            doc_string: "".to_string(),
            r#type: None,
        }],
        value_info: vec![],
        doc_string: "".to_string(),
        sparse_initializer: vec![],
        quantization_annotation: vec![],
    }));

    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?);
    inputs.insert(INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
    let first = z
        .to_vec1::<f64>()?
        .to_vec()
        .get(0)
        .expect("Failed to get first element")
        .clone();

    assert_eq!(first, 1.0f64);

    Ok(())
}

// "Equal"
#[test]
fn test_equal_operation() -> Result<()> {
    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: "Equal".to_string(),
            domain: "".to_string(),
            attribute: vec![],
            input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            name: "".to_string(),
            doc_string: "".to_string(),
        }],
        name: "".to_string(),
        initializer: vec![],
        input: vec![],
        output: vec![ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            doc_string: "".to_string(),
            r#type: None,
        }],
        value_info: vec![],
        doc_string: "".to_string(),
        sparse_initializer: vec![],
        quantization_annotation: vec![],
    }));

    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?);
    inputs.insert(INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
    let first = z.to_dtype(candle::DType::U8)?.to_vec1::<u8>()?.to_vec()[0];
    assert_eq!(first, 1);

    Ok(())
}

// "Not"
#[test]
fn test_not_operation() -> Result<()> {
    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: "Not".to_string(),
            domain: "".to_string(),
            attribute: vec![],
            input: vec![INPUT_X.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            name: "".to_string(),
            doc_string: "".to_string(),
        }],
        name: "".to_string(),
        initializer: vec![],
        input: vec![],
        output: vec![ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            doc_string: "".to_string(),
            r#type: None,
        }],
        value_info: vec![],
        doc_string: "".to_string(),
        sparse_initializer: vec![],
        quantization_annotation: vec![],
    }));

    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), Tensor::new(&[0.], &Device::Cpu)?);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
    let first = z.to_dtype(candle::DType::U8)?.to_vec1::<u8>()?.to_vec()[0];
    assert_eq!(first, 1);

    Ok(())
}

// "MatMul"
#[test]
fn test_matmul_operation() -> Result<()> {
    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: "MatMul".to_string(),
            domain: "".to_string(),
            attribute: vec![],
            input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            name: "".to_string(),
            doc_string: "".to_string(),
        }],
        name: "".to_string(),
        initializer: vec![],
        input: vec![],
        output: vec![ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            doc_string: "".to_string(),
            r#type: None,
        }],
        value_info: vec![],
        doc_string: "".to_string(),
        sparse_initializer: vec![],
        quantization_annotation: vec![],
    }));

    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(
        INPUT_X.to_string(),
        Tensor::from_vec(
            //
            vec![1.0f32, 2.0f32, 3.0f32, 4.0f32],
            &[2, 2],
            &Device::Cpu,
        )?,
    );
    inputs.insert(
        INPUT_Y.to_string(),
        Tensor::from_vec(
            //
            vec![5.0f32, 6.0f32, 7.0f32, 8.0f32],
            &[2, 2],
            &Device::Cpu,
        )?,
    );

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
    let results = z.to_vec2::<f32>()?;
    assert_eq!(results, vec![vec![19.0, 22.0], vec![43.0, 50.0]]);

    Ok(())
}

// "Reshape"
#[test]
fn test_reshape_operation() -> Result<()> {
    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: "Reshape".to_string(),
            domain: "".to_string(),
            attribute: vec![],
            input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            name: "".to_string(),
            doc_string: "".to_string(),
        }],
        name: "".to_string(),
        initializer: vec![],
        input: vec![
            ValueInfoProto {
                name: INPUT_X.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            },
            ValueInfoProto {
                name: INPUT_Y.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            },
        ],
        output: vec![ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            doc_string: "".to_string(),
            r#type: None,
        }],
        value_info: vec![],
        doc_string: "".to_string(),
        sparse_initializer: vec![],
        quantization_annotation: vec![],
    }));

    let x = Tensor::from_vec(
        //
        vec![1.0f32, 2.0f32, 3.0f32, 4.0f32],
        &[2, 2],
        &Device::Cpu,
    )?;
    let y = Tensor::from_vec(
        //
        vec![4i64],
        &[1],
        &Device::Cpu,
    )?;

    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), x);
    inputs.insert(INPUT_Y.to_string(), y);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");

    let results = z.to_vec1::<f32>()?;

    assert_eq!(results, vec![1.0, 2.0, 3.0, 4.0]);

    Ok(())
}

// "LogSoftmax"
#[test]
fn test_logsoftmax_operation() -> Result<()> {
    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: "LogSoftmax".to_string(),
            domain: "".to_string(),
            attribute: vec![],
            input: vec![INPUT_X.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            name: "".to_string(),
            doc_string: "".to_string(),
        }],
        name: "".to_string(),
        initializer: vec![],
        input: vec![
            ValueInfoProto {
                name: INPUT_X.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            },
            ValueInfoProto {
                name: INPUT_Y.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            },
        ],
        output: vec![ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            doc_string: "".to_string(),
            r#type: None,
        }],
        value_info: vec![],
        doc_string: "".to_string(),
        sparse_initializer: vec![],
        quantization_annotation: vec![],
    }));

    let x = Tensor::from_vec(
        //
        vec![1.0f32, 2.0f32, 3.0f32, 4.0f32],
        &[2, 2],
        &Device::Cpu,
    )?;

    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), x);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");

    let results = z.to_vec2::<f32>()?;

    assert_eq!(
        results,
        vec![vec![0.26894143, 0.7310586], vec![0.26894143, 0.7310586]]
    );

    Ok(())
}

// "Softmax"
#[test]
fn test_softmax_operation() -> Result<()> {
    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: "Softmax".to_string(),
            domain: "".to_string(),
            attribute: vec![],
            input: vec![INPUT_X.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            name: "".to_string(),
            doc_string: "".to_string(),
        }],
        name: "".to_string(),
        initializer: vec![],
        input: vec![
            ValueInfoProto {
                name: INPUT_X.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            },
            ValueInfoProto {
                name: INPUT_Y.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            },
        ],
        output: vec![ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            doc_string: "".to_string(),
            r#type: None,
        }],
        value_info: vec![],
        doc_string: "".to_string(),
        sparse_initializer: vec![],
        quantization_annotation: vec![],
    }));

    let x = Tensor::from_vec(
        //
        vec![1.0f32, 2.0f32, 3.0f32, 4.0f32],
        &[2, 2],
        &Device::Cpu,
    )?;

    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), x);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");

    let results = z.to_vec2::<f32>()?;

    assert_eq!(
        results,
        vec![vec![0.26894143, 0.7310586], vec![0.26894143, 0.7310586]]
    );

    Ok(())
}

// "Transpose"
#[test]
fn test_transpose_operation() -> Result<()> {
    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: "Transpose".to_string(),
            domain: "".to_string(),
            attribute: vec![],
            input: vec![INPUT_X.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            name: "".to_string(),
            doc_string: "".to_string(),
        }],
        name: "".to_string(),
        initializer: vec![],
        input: vec![
            ValueInfoProto {
                name: INPUT_X.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            },
            ValueInfoProto {
                name: INPUT_Y.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            },
        ],
        output: vec![ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            doc_string: "".to_string(),
            r#type: None,
        }],
        value_info: vec![],
        doc_string: "".to_string(),
        sparse_initializer: vec![],
        quantization_annotation: vec![],
    }));

    let x = Tensor::from_vec(
        //
        vec![1.0f32, 2.0f32, 3.0f32, 4.0f32],
        &[2, 2],
        &Device::Cpu,
    )?;

    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), x);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");

    let results = z.to_vec2::<f32>()?;

    assert_eq!(results, vec![vec![1.0, 3.0], vec![2.0, 4.0]]);

    Ok(())
}

// "Dropout"
#[test]
fn test_dropout_operation() -> Result<()> {
    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: "Dropout".to_string(),
            domain: "".to_string(),
            attribute: vec![],
            input: vec![INPUT_X.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            name: "".to_string(),
            doc_string: "".to_string(),
        }],
        name: "".to_string(),
        initializer: vec![],
        input: vec![
            ValueInfoProto {
                name: INPUT_X.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            },
            ValueInfoProto {
                name: INPUT_Y.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            },
        ],
        output: vec![ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            doc_string: "".to_string(),
            r#type: None,
        }],
        value_info: vec![],
        doc_string: "".to_string(),
        sparse_initializer: vec![],
        quantization_annotation: vec![],
    }));
    let x = Tensor::from_vec(
        //
        vec![1.0f32, 2.0f32, 3.0f32, 4.0f32],
        &[2, 2],
        &Device::Cpu,
    )?;

    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), x);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");

    let results = z.to_vec2::<f32>()?;

    assert_eq!(results, vec![vec![1.0, 2.0], vec![3.0, 4.0]]);

    Ok(())
}

// Below are ops that are implemented but not tested yet

// "MaxPool"
// #[test]

// "AveragePool"
// #[test]

// "BatchNormalization"
// #[test]

// "Squeeze"
// #[test]

// "ConstantOfShape"
// #[test]

// "Unsqueeze"
// #[test]

// "Clip"
// #[test]

// "Gather"
// #[test]

// "Shape"
// #[test]

// "Conv"
// #[test]

// "Concat"
// #[test]

// "Abs"
// #[test]

// "Cos"
// #[test]

// "Sin"
// #[test]

// "Neg"
// #[test]

// "Erf"
// #[test]

// "Tanh"
// #[test]

// "Sigmoid"
// #[test]

// "Gelu"
// #[test]

// "Relu"
// #[test]

// "Constant"
// #[test]

// "Cast"
// #[test]
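A test for one of the ops listed above could follow the same pattern as the earlier tests; here is a hedged sketch for "Abs" (a hypothetical addition, not part of the committed file):

// Sketch only: mirrors the helpers defined earlier in this file.
#[test]
fn test_abs_operation() -> Result<()> {
    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: "Abs".to_string(),
            domain: "".to_string(),
            attribute: vec![],
            input: vec![INPUT_X.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            name: "".to_string(),
            doc_string: "".to_string(),
        }],
        name: "".to_string(),
        initializer: vec![],
        input: vec![],
        output: vec![ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            doc_string: "".to_string(),
            r#type: None,
        }],
        value_info: vec![],
        doc_string: "".to_string(),
        sparse_initializer: vec![],
        quantization_annotation: vec![],
    }));

    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), Tensor::new(&[-2., 3.], &Device::Cpu)?);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
    assert_eq!(z.to_vec1::<f64>()?, vec![2.0, 3.0]);
    Ok(())
}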
@ -15,9 +15,9 @@ crate-type = ["cdylib"]

[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
candle-onnx = {path= "../candle-onnx", version = "0.3.0", optional = true}
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.1" }
candle-onnx = {path= "../candle-onnx", version = "0.3.1", optional = true}
half = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] }
@ -17,7 +17,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};
use ::candle::{quantized::QTensor, DType, Device, Module, Tensor, WithDType};

mod utils;
use utils::wrap_err;
@ -12,15 +12,16 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.1", optional = true }
candle-nn = { path = "../candle-nn", version = "0.3.1" }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
rand = { workspace = true }
rayon = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
serde_plain = { workspace = true }
tracing = { workspace = true }
wav = { workspace = true }

@ -30,3 +31,4 @@ accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda"]
flash-attn = ["cuda", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"]
metal = ["candle/metal", "candle-nn/metal"]

342
candle-transformers/src/models/distilbert.rs
Normal file
@ -0,0 +1,342 @@
use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
use candle::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module, VarBuilder};
use serde::Deserialize;

pub const DTYPE: DType = DType::F32;

fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
    let shape = mask.shape();
    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
    let m = mask.where_cond(&on_true, on_false)?;
    Ok(m)
}

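masked_fill is the usual where_cond trick; as a standalone sketch, this is how an attention mask knocks scores out before the softmax:

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let scores = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    // A u8 mask: entries equal to 1 are replaced by -inf.
    let mask = Tensor::new(&[[0u8, 1], [0, 0]], &Device::Cpu)?;
    let on_true =
        Tensor::new(f32::NEG_INFINITY, &Device::Cpu)?.broadcast_as(scores.shape().dims())?;
    let masked = mask.where_cond(&on_true, &scores)?;
    println!("{masked}");
    Ok(())
}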
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")]
enum HiddenAct {
    Gelu,
    Relu,
}

struct HiddenActLayer {
    act: HiddenAct,
    span: tracing::Span,
}

impl HiddenActLayer {
    fn new(act: HiddenAct) -> Self {
        let span = tracing::span!(tracing::Level::TRACE, "hidden-act");
        Self { act, span }
    }
}

impl Module for HiddenActLayer {
    fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
        let _enter = self.span.enter();
        match self.act {
            // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
            HiddenAct::Gelu => xs.gelu(),
            HiddenAct::Relu => xs.relu(),
        }
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
enum PositionEmbeddingType {
    #[default]
    Absolute,
}

#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    vocab_size: usize,
    dim: usize,
    n_layers: usize,
    n_heads: usize,
    hidden_dim: usize,
    activation: HiddenAct,
    max_position_embeddings: usize,
    initializer_range: f64,
    pad_token_id: usize,
    #[serde(default)]
    position_embedding_type: PositionEmbeddingType,
    #[serde(default)]
    use_cache: bool,
    model_type: Option<String>,
}

impl Default for Config {
    fn default() -> Self {
        Self {
            vocab_size: 30522,
            dim: 768,
            n_layers: 12,
            n_heads: 12,
            hidden_dim: 3072,
            activation: HiddenAct::Gelu,
            max_position_embeddings: 512,
            initializer_range: 0.02,
            pad_token_id: 0,
            position_embedding_type: PositionEmbeddingType::Absolute,
            use_cache: true,
            model_type: Some("distilbert".to_string()),
        }
    }
}

struct Embeddings {
    word_embeddings: Embedding,
    position_embeddings: Embedding,
    layer_norm: LayerNorm,
    span: tracing::Span,
}

impl Embeddings {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let word_embeddings =
            candle_nn::embedding(config.vocab_size, config.dim, vb.pp("word_embeddings"))?;
        let position_embeddings = candle_nn::embedding(
            config.max_position_embeddings,
            config.dim,
            vb.pp("position_embeddings"),
        )?;
        let layer_norm = layer_norm(config.dim, 1e-12, vb.pp("LayerNorm"))?;
        Ok(Self {
            word_embeddings,
            position_embeddings,
            layer_norm,
            span: tracing::span!(tracing::Level::TRACE, "embeddings"),
        })
    }

    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let (_bsize, seq_len) = input_ids.dims2()?;
        let input_embeddings = self.word_embeddings.forward(input_ids)?;
        let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
        let position_ids = Tensor::new(&position_ids[..], input_ids.device())?;
        let embeddings =
            input_embeddings.broadcast_add(&self.position_embeddings.forward(&position_ids)?)?;

        let embeddings = self.layer_norm.forward(&embeddings)?;
        Ok(embeddings)
    }
}

struct MultiHeadSelfAttention {
    q_lin: Linear,
    k_lin: Linear,
    v_lin: Linear,
    out_lin: Linear,
    n_heads: usize,
    attention_head_size: usize,
    span: tracing::Span,
}

impl MultiHeadSelfAttention {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let attention_head_size = config.dim / config.n_heads;
        let all_head_size = config.n_heads * attention_head_size;
        let dim = config.dim;
        let q_lin = linear(dim, all_head_size, vb.pp("q_lin"))?;
        let v_lin = linear(dim, all_head_size, vb.pp("v_lin"))?;
        let k_lin = linear(dim, all_head_size, vb.pp("k_lin"))?;
        let out_lin = linear(all_head_size, dim, vb.pp("out_lin"))?;
        Ok(Self {
            q_lin,
            k_lin,
            v_lin,
            out_lin,
            n_heads: config.n_heads,
            attention_head_size,
            span: tracing::span!(tracing::Level::TRACE, "attention"),
        })
    }
}

impl MultiHeadSelfAttention {
    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let (bs, q_length, _dim) = hidden_states.dims3()?;

        let dim_per_head = self.attention_head_size;
        let q = self.q_lin.forward(hidden_states)?;
        let k = self.k_lin.forward(hidden_states)?;
        let v = self.v_lin.forward(hidden_states)?;

        let q = q
            .reshape((bs, q_length, self.n_heads, dim_per_head))?
            .transpose(1, 2)?;
        let k = k
            .reshape((bs, q_length, self.n_heads, dim_per_head))?
            .transpose(1, 2)?;
        let v = v
            .reshape((bs, q_length, self.n_heads, dim_per_head))?
            .transpose(1, 2)?;

        let q: Tensor = (q / (dim_per_head as f64).sqrt())?;
        let scores = q.matmul(&k.transpose(2, 3)?.contiguous()?)?;
        let mask = attention_mask.broadcast_as(scores.shape())?;

        let scores = masked_fill(&scores.to_dtype(DType::F32)?, &mask, f32::NEG_INFINITY)?;
        let weights = candle_nn::ops::softmax(&scores, candle::D::Minus1)?;

        let context = weights.matmul(&v.contiguous()?)?;
        let context = context
            .transpose(1, 2)?
            .reshape((bs, q_length, self.n_heads * dim_per_head))?
            .contiguous()?;
        let context = self.out_lin.forward(&context)?;

        Ok(context)
    }
}

#[allow(clippy::upper_case_acronyms)]
struct FFN {
    lin1: Linear,
    lin2: Linear,
    activation: HiddenActLayer,
    span: tracing::Span,
}

impl FFN {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let lin1 = linear(config.dim, config.hidden_dim, vb.pp("lin1"))?;
        let lin2 = linear(config.hidden_dim, config.dim, vb.pp("lin2"))?;
        Ok(Self {
            lin1,
            lin2,
            activation: HiddenActLayer::new(config.activation),
            span: tracing::span!(tracing::Level::TRACE, "ffn"),
        })
    }
}

impl Module for FFN {
    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        hidden_states
            .apply(&self.lin1)?
            .apply(&self.activation)?
            .apply(&self.lin2)
    }
}

struct TransformerBlock {
    attention: MultiHeadSelfAttention,
    sa_layer_norm: LayerNorm,
    ffn: FFN,
    output_layer_norm: LayerNorm,
    span: tracing::Span,
}

impl TransformerBlock {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let attention = MultiHeadSelfAttention::load(vb.pp("attention"), config)?;
        let sa_layer_norm = layer_norm(config.dim, 1e-12, vb.pp("sa_layer_norm"))?;
        let ffn = FFN::load(vb.pp("ffn"), config)?;
        let output_layer_norm = layer_norm(config.dim, 1e-12, vb.pp("output_layer_norm"))?;
        Ok(Self {
            attention,
            sa_layer_norm,
            ffn,
            output_layer_norm,
            span: tracing::span!(tracing::Level::TRACE, "layer"),
        })
    }
}

impl TransformerBlock {
    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let sa_output = self.attention.forward(hidden_states, attention_mask)?;
        // TODO: Support cross-attention?
        // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
        // TODO: Support something similar to `apply_chunking_to_forward`?
        let sa_output = sa_output.broadcast_add(hidden_states)?;
        let sa_output = self.sa_layer_norm.forward(&sa_output)?;

        let ffn_output = self.ffn.forward(&sa_output)?;
        let ffn_output = (&ffn_output + sa_output)?;
        let output = self.output_layer_norm.forward(&ffn_output)?;
        Ok(output)
    }
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
struct Transformer {
    layers: Vec<TransformerBlock>,
    span: tracing::Span,
}

impl Transformer {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let layers = (0..config.n_layers)
            .map(|index| TransformerBlock::load(vb.pp(&format!("layer.{index}")), config))
            .collect::<Result<Vec<_>>>()?;
        let span = tracing::span!(tracing::Level::TRACE, "encoder");
        Ok(Transformer { layers, span })
    }
}

impl Transformer {
    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let mut hidden_states = hidden_states.clone();
        // Use a loop rather than a fold as it's easier to modify when adding debug/...
        for layer in self.layers.iter() {
            hidden_states = layer.forward(&hidden_states, attention_mask)?;
        }
        Ok(hidden_states)
    }
}

pub struct DistilBertModel {
    embeddings: Embeddings,
    transformer: Transformer,
    pub device: Device,
    span: tracing::Span,
}

impl DistilBertModel {
    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let (embeddings, transformer) = match (
            Embeddings::load(vb.pp("embeddings"), config),
            Transformer::load(vb.pp("transformer"), config),
        ) {
            (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
            (Err(err), _) | (_, Err(err)) => {
                if let Some(model_type) = &config.model_type {
                    if let (Ok(embeddings), Ok(encoder)) = (
                        Embeddings::load(vb.pp(&format!("{model_type}.embeddings")), config),
                        Transformer::load(vb.pp(&format!("{model_type}.transformer")), config),
                    ) {
                        (embeddings, encoder)
                    } else {
                        return Err(err);
                    }
                } else {
                    return Err(err);
                }
            }
        };
        Ok(Self {
            embeddings,
            transformer,
            device: vb.device().clone(),
            span: tracing::span!(tracing::Level::TRACE, "model"),
        })
    }

    pub fn forward(&self, input_ids: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let embedding_output = self.embeddings.forward(input_ids)?;
        let sequence_output = self
            .transformer
            .forward(&embedding_output, attention_mask)?;
        Ok(sequence_output)
    }
}
@ -156,7 +156,6 @@ impl CausalSelfAttention {
        let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
        let x0 = x.narrow(D::Minus1, 0, 1)?;
        let x1 = x.narrow(D::Minus1, 1, 1)?;
        todo!("X {x1}");
        let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
        let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
        let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, h, n_embd))?;
@ -174,7 +173,6 @@ impl CausalSelfAttention {
        let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;

        let q = self.apply_rotary_emb(&q, index_pos)?;
        todo!("X {q}");
        let mut k = self.apply_rotary_emb(&k, index_pos)?;

        if self.cache.use_kv_cache {
@ -297,7 +295,6 @@ impl Block {
        let residual = x;
        let x = self.rms_1.forward(x)?;
        let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
        todo!("---X {}", x);
        let residual = &x;
        let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
        Ok(x)
@ -330,7 +327,6 @@ impl Llama {
    pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
        let (_b_sz, _seq_len) = x.dims2()?;
        let mut x = self.wte.forward(x)?;
        //println!("Embeddings {}", self.wte.embeddings());
        for (block_idx, block) in self.blocks.iter().enumerate() {
            x = block.forward(&x, index_pos, block_idx)?;
        }
@ -142,10 +142,10 @@ impl RotaryEmbedding {
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        Ok(Self {
            sin: freqs.sin()?,
            cos: freqs.cos()?,
        })
        let sin = freqs.sin()?;
        let cos = freqs.cos()?;
        // todo!("{}", sin);
        Ok(Self { sin, cos })
    }

    fn apply_rotary_emb_qkv(
@ -273,6 +273,10 @@ impl MHA {
    }

    fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
        // let view = xs.to_string();
        // if view.contains("NaN") {
        //     panic!("NaN");
        // }
        let _enter = self.span.enter();
        let (b_size, seq_len, _n_embd) = xs.dims3()?;
        let qkv = self
@ -408,3 +412,38 @@ impl MixFormerSequentialForCausalLM {
        self.blocks.iter_mut().for_each(|b| b.clear_kv_cache())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_rotary() {
        let dev = Device::new_metal(0).unwrap();
        for i in 0..10000 {
            let dim = 8;
            let max_seq_len = 12;
            let inv_freq: Vec<_> = (0..dim)
                .step_by(2)
                .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
                .collect();
            let inv_freq_len = inv_freq.len();
            let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), &dev).unwrap();
            let t = Tensor::arange(0u32, max_seq_len as u32, &dev)
                .unwrap()
                .to_dtype(DType::F32)
                .unwrap()
                .reshape((max_seq_len, 1))
                .unwrap();
            let x: f32 = t.i((1, 0)).unwrap().to_scalar().unwrap();
            assert_eq!(x, 1.0);
            let x: f32 = inv_freq.i((0, 1)).unwrap().to_scalar().unwrap();
            assert_eq!(x, 0.1);
            let freqs = t.matmul(&inv_freq).unwrap();
            let x: f32 = freqs.i((1, 1)).unwrap().to_scalar().unwrap();
            assert_eq!(x, 0.1);
            let sin = freqs.sin().unwrap().contiguous().unwrap();
            let x: f32 = sin.i((1, 1)).unwrap().to_scalar().unwrap();
            assert_eq!(x, 0.099833414);
        }
    }
}
@ -4,6 +4,7 @@ pub mod blip;
pub mod blip_text;
pub mod convmixer;
pub mod dinov2;
pub mod distilbert;
pub mod efficientnet;
pub mod falcon;
pub mod jina_bert;
@ -29,8 +30,10 @@ pub mod segment_anything;
pub mod stable_diffusion;
pub mod stable_lm;
pub mod t5;
pub mod trocr;
pub mod vgg;
pub mod vit;
pub mod whisper;
pub mod with_tracing;
pub mod wuerstchen;
pub mod yi;
@ -1,6 +1,7 @@
// T5 Text Model, quantized version
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py

use crate::models::t5::{deserialize_feed_forward_proj_activation, ActivationWithOptionalGating};
use crate::models::with_tracing::QMatMul;
use crate::quantized_nn::Embedding;
pub use crate::quantized_var_builder::VarBuilder;
@ -54,8 +55,8 @@ pub struct Config {
    dropout_rate: f64,
    layer_norm_epsilon: f64,
    initializer_factor: f64,
    #[serde(default)]
    feed_forward_proj: Activation,
    #[serde(default, deserialize_with = "deserialize_feed_forward_proj_activation")]
    pub feed_forward_proj: ActivationWithOptionalGating,
    #[serde(default = "default_tie_word_embeddings")]
    tie_word_embeddings: bool,
    #[serde(default = "default_is_decoder")]
@ -83,7 +84,10 @@ impl Default for Config {
            dropout_rate: 0.1,
            layer_norm_epsilon: 1e-6,
            initializer_factor: 1.0,
            feed_forward_proj: Activation::Relu,
            feed_forward_proj: ActivationWithOptionalGating {
                gated: false,
                activation: Activation::Relu,
            },
            tie_word_embeddings: true,
            is_decoder: false,
            is_encoder_decoder: true,
@ -176,7 +180,7 @@ impl T5DenseGatedActDense {
            wi_0,
            wi_1,
            wo,
            act: Activation::NewGelu,
            act: cfg.feed_forward_proj.activation,
            span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"),
        })
    }
@ -205,7 +209,7 @@ impl T5LayerFF {
    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let layer_norm =
            T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
        let (dense_act, gated_dense_act) = if cfg.feed_forward_proj == Activation::NewGelu {
        let (dense_act, gated_dense_act) = if cfg.feed_forward_proj.gated {
            (
                None,
                Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?),
@ -1,12 +1,15 @@
use candle::{Device, Result, Tensor};

pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
    if steps < 1 {
        candle::bail!("cannot use linspace with steps {steps} <= 1")
    if steps == 0 {
        Tensor::from_vec(Vec::<f64>::new(), steps, &Device::Cpu)
    } else if steps == 1 {
        Tensor::from_vec(vec![start], steps, &Device::Cpu)
    } else {
        let delta = (stop - start) / (steps - 1) as f64;
        let vs = (0..steps)
            .map(|step| start + step as f64 * delta)
            .collect::<Vec<_>>();
        Tensor::from_vec(vs, steps, &Device::Cpu)
    }
    let delta = (stop - start) / (steps - 1) as f64;
    let vs = (0..steps)
        .map(|step| start + step as f64 * delta)
        .collect::<Vec<_>>();
    Tensor::from_vec(vs, steps, &Device::Cpu)
}
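The rewrite makes the endpoint behavior explicit; a quick worked sketch of the steps >= 2 branch (pure std, candle only for the final tensor):

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    // linspace(0., 10., 5): delta = 10 / 4 = 2.5, so both endpoints are hit:
    // [0.0, 2.5, 5.0, 7.5, 10.0]
    let (start, stop, steps) = (0f64, 10f64, 5usize);
    let delta = (stop - start) / (steps - 1) as f64;
    let vs = (0..steps)
        .map(|step| start + step as f64 * delta)
        .collect::<Vec<_>>();
    let t = Tensor::from_vec(vs, steps, &Device::Cpu)?;
    println!("{t}");
    Ok(())
}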
@ -37,6 +37,37 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
    Ok(m)
}

#[derive(Debug, Deserialize, Default, Clone, PartialEq)]
pub struct ActivationWithOptionalGating {
    pub gated: bool,
    pub activation: candle_nn::Activation,
}

pub fn deserialize_feed_forward_proj_activation<'de, D>(
    deserializer: D,
) -> std::result::Result<ActivationWithOptionalGating, D::Error>
where
    D: serde::de::Deserializer<'de>,
{
    match String::deserialize(deserializer)?.as_str() {
        "gated-gelu" => Ok(ActivationWithOptionalGating {
            gated: true,
            activation: candle_nn::Activation::NewGelu,
        }),
        "gated-silu" => Ok(ActivationWithOptionalGating {
            gated: true,
            activation: candle_nn::Activation::Silu,
        }),
        buf => {
            let activation = serde_plain::from_str(buf).map_err(serde::de::Error::custom)?;
            Ok(ActivationWithOptionalGating {
                gated: false,
                activation,
            })
        }
    }
}
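A standalone sketch of what this deserializer buys (a simplified analog, not the file's actual types): the two gated spellings are special-cased, everything else falls through to serde_plain with gated = false:

use serde::Deserialize;

#[derive(Debug, Deserialize, PartialEq, Clone, Copy)]
#[serde(rename_all = "lowercase")]
enum Activation {
    Relu,
    Silu,
    NewGelu,
}

#[derive(Debug, PartialEq)]
struct ActivationWithOptionalGating {
    gated: bool,
    activation: Activation,
}

fn parse(s: &str) -> ActivationWithOptionalGating {
    match s {
        "gated-gelu" => ActivationWithOptionalGating { gated: true, activation: Activation::NewGelu },
        "gated-silu" => ActivationWithOptionalGating { gated: true, activation: Activation::Silu },
        other => ActivationWithOptionalGating {
            gated: false,
            // serde_plain parses a bare string the same way serde would.
            activation: serde_plain::from_str(other).unwrap(),
        },
    }
}

fn main() {
    assert!(parse("gated-gelu").gated);
    assert_eq!(parse("relu").activation, Activation::Relu);
}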

#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    vocab_size: usize,
@ -52,8 +83,8 @@ pub struct Config {
    dropout_rate: f64,
    layer_norm_epsilon: f64,
    initializer_factor: f64,
    #[serde(default)]
    feed_forward_proj: Activation,
    #[serde(default, deserialize_with = "deserialize_feed_forward_proj_activation")]
    feed_forward_proj: ActivationWithOptionalGating,
    #[serde(default = "default_tie_word_embeddings")]
    tie_word_embeddings: bool,
    #[serde(default = "default_is_decoder")]
@ -81,7 +112,10 @@ impl Default for Config {
            dropout_rate: 0.1,
            layer_norm_epsilon: 1e-6,
            initializer_factor: 1.0,
            feed_forward_proj: Activation::Relu,
            feed_forward_proj: ActivationWithOptionalGating {
                gated: false,
                activation: Activation::Relu,
            },
            tie_word_embeddings: true,
            is_decoder: false,
            is_encoder_decoder: true,
@ -102,7 +136,10 @@ impl Config {
            d_model: 768,
            dropout_rate: 0.1,
            eos_token_id: 1,
            feed_forward_proj: Activation::Relu,
            feed_forward_proj: ActivationWithOptionalGating {
                gated: false,
                activation: Activation::Relu,
            },
            tie_word_embeddings: true,
            initializer_factor: 1.0,
            is_decoder: false,
@ -202,7 +239,7 @@ impl T5DenseGatedActDense {
            wi_0,
            wi_1,
            wo,
            act: Activation::NewGelu,
            act: cfg.feed_forward_proj.activation,
            span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"),
        })
    }
@ -231,7 +268,7 @@ impl T5LayerFF {
    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let layer_norm =
            T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
        let (dense_act, gated_dense_act) = if cfg.feed_forward_proj == Activation::NewGelu {
        let (dense_act, gated_dense_act) = if cfg.feed_forward_proj.gated {
            (
                None,
                Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?),
@ -425,7 +462,7 @@ impl T5Attention {
                        self.relative_attention_max_distance as f32
                            / max_exact as f32,
                    ) * (num_buckets - max_exact) as f32;
                    max_exact + b as u32
                    u32::min(max_exact + b as u32, num_buckets - 1)
                }
            })
            .collect::<Vec<u32>>()
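The u32::min change above clamps the log-bucket branch; a small standalone sketch (illustrative signature and constants, not the file's actual helper) of why that matters:

fn bucket(relative_position: u32, num_buckets: u32, max_distance: u32) -> u32 {
    let max_exact = num_buckets / 2;
    if relative_position < max_exact {
        relative_position
    } else {
        // Same shape as the expression above: a log-scaled index past max_exact.
        let b = f32::ln(relative_position as f32 / max_exact as f32)
            / f32::ln(max_distance as f32 / max_exact as f32)
            * (num_buckets - max_exact) as f32;
        u32::min(max_exact + b as u32, num_buckets - 1)
    }
}

fn main() {
    // With 32 buckets and max_distance 128, a distance of 10_000 maps to
    // 16 + 49 = 65 before the clamp; the min keeps it inside [0, 31].
    assert_eq!(bucket(10_000, 32, 128), 31);
}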
454
candle-transformers/src/models/trocr.rs
Normal file
@ -0,0 +1,454 @@
use crate::models::vit::{Config, Embeddings, Encoder};
use candle::{Result, Tensor};
use candle_nn::{
    embedding, layer_norm, linear_no_bias, Embedding, LayerNorm, Linear, Module, VarBuilder,
};
use serde::Deserialize;

#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct TrOCRConfig {
    pub vocab_size: usize,
    pub d_model: usize,
    pub hidden_size: usize,
    pub decoder_layers: usize,
    pub decoder_attention_heads: usize,
    pub decoder_ffn_dim: usize,
    pub activation_function: candle_nn::Activation,
    pub max_position_embeddings: usize,
    pub dropout: f64,
    pub attention_dropout: f64,
    pub activation_dropout: f64,
    pub decoder_start_token_id: u32,
    pub init_std: f64,
    pub decoder_layerdrop: f64,
    pub use_cache: bool,
    pub scale_embedding: bool,
    pub use_learned_position_embeddings: bool,
    pub layernorm_embedding: bool,
    pub pad_token_id: usize,
    pub bos_token_id: usize,
    pub eos_token_id: u32,
    pub num_attention_heads: usize,
    pub decoder_vocab_size: Option<usize>,
}

impl Default for TrOCRConfig {
    fn default() -> Self {
        Self {
            vocab_size: 50265,
            d_model: 1024,
            hidden_size: 768,
            decoder_layers: 12,
            decoder_attention_heads: 16,
            decoder_ffn_dim: 4096,
            activation_function: candle_nn::Activation::Gelu,
            max_position_embeddings: 512,
            dropout: 0.1,
            attention_dropout: 0.0,
            activation_dropout: 0.0,
            decoder_start_token_id: 2,
            init_std: 0.02,
            decoder_layerdrop: 0.0,
            use_cache: true,
            scale_embedding: false,
            use_learned_position_embeddings: true,
            layernorm_embedding: true,
            pad_token_id: 1,
            bos_token_id: 0,
            eos_token_id: 2,
            num_attention_heads: 12,
            decoder_vocab_size: Some(50265),
        }
    }
}

#[derive(Debug, Clone)]
struct TrOCRLearnedPositionalEmbedding {
    offset: usize,
    weights: Embedding,
}

impl TrOCRLearnedPositionalEmbedding {
    fn load(vb: VarBuilder, cfg: &TrOCRConfig) -> Result<Self> {
        let offset: usize = 2;
        let num_embeddings = cfg.max_position_embeddings;
        let embedding_dim = cfg.d_model;
        let weights = embedding(num_embeddings + offset, embedding_dim, vb)?;

        Ok(Self { offset, weights })
    }

    fn forward(&mut self, input_ids: &Tensor, past_key_values_length: u32) -> Result<Tensor> {
        let (b_sz, seq_len) = input_ids.dims2()?;

        let mut positions = Tensor::arange(
            past_key_values_length,
            seq_len as u32 + past_key_values_length,
            input_ids.device(),
        )?
        .expand((b_sz, seq_len))?;

        positions =
            positions.broadcast_add(&Tensor::new(self.offset as u32, input_ids.device())?)?;
        self.weights.forward(&positions)
    }
}

#[derive(Debug, Clone)]
struct TrOCRAttention {
    head_dim: usize,
    num_heads: usize,
    is_decoder: bool,
    scaling: f64,
    k_proj: Linear,
    v_proj: Linear,
    q_proj: Linear,
    out_proj: Linear,
    kv_cache: Option<(Tensor, Tensor)>,
}

impl TrOCRAttention {
    fn load(
        vb: VarBuilder,
        cfg: &TrOCRConfig,
        kdim: Option<usize>,
        vdim: Option<usize>,
    ) -> Result<Self> {
        let embed_dim = cfg.d_model;
        let num_heads = cfg.decoder_attention_heads;
        let head_dim = embed_dim / num_heads;
        let kdim = kdim.unwrap_or(embed_dim);
        let vdim = vdim.unwrap_or(embed_dim);

        let k_proj = linear_no_bias(kdim, embed_dim, vb.pp("k_proj"))?;
        let v_proj = linear_no_bias(vdim, embed_dim, vb.pp("v_proj"))?;
        let q_proj = linear_no_bias(embed_dim, embed_dim, vb.pp("q_proj"))?;

        let out_proj = linear_no_bias(embed_dim, embed_dim, vb.pp("out_proj"))?;
        Ok(Self {
            head_dim,
            num_heads,
            is_decoder: true,
            scaling: 1. / (head_dim as f64).sqrt(),
            k_proj,
            v_proj,
            q_proj,
            out_proj,
            kv_cache: None,
        })
    }

    fn reset_kv_cache(&mut self) {
        self.kv_cache = None
    }

    fn _shape(&self, tensor: &Tensor, bsz: usize) -> Result<Tensor> {
        tensor
            .reshape((bsz, (), self.num_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        kv_states: Option<&Tensor>,
        attn_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let (b_sz, tgt_len, _) = xs.dims3()?;
        let query_states = (xs.apply(&self.q_proj)? * self.scaling)?;
        let (key_states, value_states) = match kv_states {
            None => {
                let key_states = self._shape(&xs.apply(&self.k_proj)?, b_sz)?;
                let value_states = self._shape(&xs.apply(&self.v_proj)?, b_sz)?;
                if self.is_decoder {
                    let kv_states = match &self.kv_cache {
                        None => (key_states, value_states),
                        Some((p_key_states, p_value_states)) => {
                            let key_states = Tensor::cat(&[p_key_states, &key_states], 2)?;
                            let value_states = Tensor::cat(&[p_value_states, &value_states], 2)?;
                            (key_states, value_states)
                        }
                    };
                    self.kv_cache = Some(kv_states.clone());
                    kv_states
                } else {
                    (key_states, value_states)
                }
            }
            Some(kv_states) => {
                let key_states = self._shape(&kv_states.apply(&self.k_proj)?, b_sz)?;
                let value_states = self._shape(&kv_states.apply(&self.v_proj)?, b_sz)?;
                (key_states, value_states)
            }
        };
        let proj_shape = (b_sz * self.num_heads, (), self.head_dim);
        let query_states = self._shape(&query_states, b_sz)?.reshape(proj_shape)?;
        let key_states = key_states.reshape(proj_shape)?;
        let value_states = value_states.reshape(proj_shape)?;
        let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
        let attn_weights = match attn_mask {
            None => attn_weights,
            Some(attn_mask) => attn_weights.broadcast_add(attn_mask)?,
        };
        let attn_probs = candle_nn::ops::softmax_last_dim(&attn_weights)?;
        let attn_output = attn_probs.matmul(&value_states)?;
        attn_output
            .reshape((b_sz, self.num_heads, tgt_len, self.head_dim))?
            .transpose(1, 2)?
            .reshape((b_sz, tgt_len, self.head_dim * self.num_heads))?
            .apply(&self.out_proj)
    }
}

#[derive(Debug, Clone)]
struct TrOCRDecoderLayer {
    self_attn: TrOCRAttention,
    activation_fn: candle_nn::Activation,
    self_attn_layer_norm: LayerNorm,
    encoder_attn: TrOCRAttention,
    encoder_attn_layer_norm: LayerNorm,
    fc1: Linear,
    fc2: Linear,
    final_layer_norm: LayerNorm,
}

impl TrOCRDecoderLayer {
    fn load(vb: VarBuilder, cfg: &TrOCRConfig) -> Result<Self> {
        let embed_dim = cfg.d_model;
        let self_attn = TrOCRAttention::load(vb.pp("self_attn"), cfg, None, None)?;
        let self_attn_layer_norm = layer_norm(embed_dim, 1e-5, vb.pp("self_attn_layer_norm"))?;
        let encoder_attn = TrOCRAttention::load(
            vb.pp("encoder_attn"),
            cfg,
            Some(cfg.hidden_size),
            Some(cfg.hidden_size),
        )?;
        let encoder_attn_layer_norm =
            layer_norm(embed_dim, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
        let fc1 = linear_no_bias(embed_dim, cfg.decoder_ffn_dim, vb.pp("fc1"))?;
        let fc2 = linear_no_bias(cfg.decoder_ffn_dim, embed_dim, vb.pp("fc2"))?;
        let final_layer_norm = layer_norm(embed_dim, 1e-5, vb.pp("final_layer_norm"))?;
        let activation_fn = candle_nn::Activation::Gelu;

        Ok(Self {
            self_attn,
            activation_fn,
            self_attn_layer_norm,
            encoder_attn,
            encoder_attn_layer_norm,
            fc1,
            fc2,
            final_layer_norm,
        })
    }

    fn reset_kv_cache(&mut self) {
        self.self_attn.reset_kv_cache();
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: &Tensor,
        encoder_hidden_states: Option<&Tensor>,
    ) -> Result<Tensor> {
        let residual = xs.clone();
        let xs = self.self_attn.forward(xs, None, Some(attention_mask))?;
        let xs = (xs + residual)?;
        let mut xs = self.self_attn_layer_norm.forward(&xs)?;

        if let Some(encoder_hidden_states) = &encoder_hidden_states {
            let residual = xs.clone();
            let encoder_attention_mask = attention_mask.clone(); // TODO
            xs = self.encoder_attn.forward(
                &xs,
                Some(encoder_hidden_states),
                Some(&encoder_attention_mask),
            )?;
            xs = (xs + residual)?;
            xs = self.encoder_attn_layer_norm.forward(&xs)?
        }

        let residual = xs.clone();
        let xs = self.fc1.forward(&xs)?;
        let xs = self.activation_fn.forward(&xs)?;
        let xs = self.fc2.forward(&xs)?;
        let xs = (xs + residual)?;
        let xs = self.final_layer_norm.forward(&xs)?;

        Ok(xs)
    }
}

#[derive(Debug, Clone)]
pub struct TrOCRDecoder {
    layers: Vec<TrOCRDecoderLayer>,
    embed_scale: Option<f64>,
    embed_tokens: Embedding,
    embed_positions: TrOCRLearnedPositionalEmbedding,
}

impl TrOCRDecoder {
    fn new(cfg: &TrOCRConfig, vb: VarBuilder) -> Result<Self> {
        let vb = vb.pp("decoder.model.decoder");

        let embed_tokens = embedding(cfg.vocab_size, cfg.d_model, vb.pp("embed_tokens"))?;
        let embed_positions = TrOCRLearnedPositionalEmbedding::load(vb.pp("embed_positions"), cfg)?;
        let mut layers = Vec::with_capacity(cfg.decoder_layers);
        let vb_l = vb.pp("layers");
        for idx in 0..cfg.decoder_layers {
            let layer = TrOCRDecoderLayer::load(vb_l.pp(idx), cfg)?;
            layers.push(layer)
        }
        let embed_scale = if cfg.scale_embedding {
|
||||
Some((cfg.d_model as f64).sqrt())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
layers,
|
||||
embed_scale,
|
||||
embed_tokens,
|
||||
embed_positions,
|
||||
})
|
||||
}
|
||||
|
||||
fn reset_kv_cache(&mut self) {
|
||||
self.layers.iter_mut().for_each(|l| l.reset_kv_cache())
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
encoder_xs: Option<&Tensor>,
|
||||
past_kv_len: usize,
|
||||
attn_mask: &Tensor,
|
||||
) -> Result<Tensor> {
|
||||
let embed_pos = self.embed_positions.forward(xs, past_kv_len as u32)?;
|
||||
let xs = xs.apply(&self.embed_tokens)?;
|
||||
|
||||
let xs = match self.embed_scale {
|
||||
None => xs,
|
||||
Some(scale) => (xs * scale)?,
|
||||
};
|
||||
|
||||
let mut xs = xs.broadcast_add(&embed_pos)?;
|
||||
|
||||
for layer in self.layers.iter_mut() {
|
||||
xs = layer.forward(&xs, attn_mask, encoder_xs)?;
|
||||
}
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TrOCREncoder {
|
||||
embeddings: Embeddings,
|
||||
encoder: Encoder,
|
||||
layernorm: LayerNorm,
|
||||
}
|
||||
|
||||
impl TrOCREncoder {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let vb_v = vb.pp("encoder");
|
||||
|
||||
let embeddings = Embeddings::new(cfg, false, vb_v.pp("embeddings"))?;
|
||||
|
||||
let encoder = Encoder::new(cfg, vb_v.pp("encoder"))?;
|
||||
let layernorm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb_v.pp("layernorm"))?;
|
||||
|
||||
Ok(Self {
|
||||
embeddings,
|
||||
encoder,
|
||||
layernorm,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let embedding_output = self.embeddings.forward(xs, None, false)?;
|
||||
let encoder_outputs = self.encoder.forward(&embedding_output)?;
|
||||
|
||||
self.layernorm.forward(&encoder_outputs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TrOCRForCausalLM {
|
||||
decoder: TrOCRDecoder,
|
||||
output_projection: Linear,
|
||||
}
|
||||
|
||||
impl TrOCRForCausalLM {
|
||||
pub fn new(decoder_cfg: &TrOCRConfig, vb: VarBuilder) -> Result<Self> {
|
||||
let decoder = TrOCRDecoder::new(decoder_cfg, vb.clone())?;
|
||||
let output_projection =
|
||||
candle_nn::Linear::new(decoder.embed_tokens.embeddings().clone(), None);
|
||||
Ok(Self {
|
||||
decoder,
|
||||
output_projection,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
encoder_xs: Option<&Tensor>,
|
||||
past_kv_len: usize,
|
||||
attn_mask: &Tensor,
|
||||
) -> Result<Tensor> {
|
||||
let xs = self
|
||||
.decoder
|
||||
.forward(xs, encoder_xs, past_kv_len, attn_mask)?;
|
||||
let xs = xs.apply(&self.output_projection)?;
|
||||
|
||||
Ok(xs)
|
||||
}
|
||||
|
||||
fn reset_kv_cache(&mut self) {
|
||||
self.decoder.reset_kv_cache();
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TrOCRModel {
|
||||
encoder: TrOCREncoder,
|
||||
decoder: TrOCRForCausalLM,
|
||||
}
|
||||
|
||||
impl TrOCRModel {
|
||||
pub fn new(encoder_cfg: &Config, decoder_cfg: &TrOCRConfig, vb: VarBuilder) -> Result<Self> {
|
||||
let encoder = TrOCREncoder::new(encoder_cfg, vb.clone())?;
|
||||
let decoder = TrOCRForCausalLM::new(decoder_cfg, vb)?;
|
||||
Ok(Self { encoder, decoder })
|
||||
}
|
||||
|
||||
pub fn encoder(&mut self) -> &mut TrOCREncoder {
|
||||
&mut self.encoder
|
||||
}
|
||||
|
||||
pub fn decoder(&mut self) -> &mut TrOCRForCausalLM {
|
||||
&mut self.decoder
|
||||
}
|
||||
|
||||
pub fn decode(
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
encoder_xs: &Tensor,
|
||||
past_kv_len: usize,
|
||||
) -> Result<Tensor> {
|
||||
let seq_len = xs.dim(1)?;
|
||||
let mask: Vec<_> = (0..seq_len)
|
||||
.flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
|
||||
.collect();
|
||||
let mask = Tensor::from_vec(mask, (seq_len, seq_len), xs.device())?;
|
||||
|
||||
self.decoder
|
||||
.forward(xs, Some(encoder_xs), past_kv_len, &mask)
|
||||
}
|
||||
|
||||
pub fn reset_kv_cache(&mut self) {
|
||||
self.decoder.reset_kv_cache();
|
||||
}
|
||||
}
|
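Taken together, a minimal greedy decoding loop over these pieces would look roughly like the sketch below. This is illustrative, not part of the diff: `image` is assumed to be a preprocessed (1, 3, 384, 384) pixel tensor, and the bos/eos token ids are assumed to come from the tokenizer config.

```rust
use candle::{Device, IndexOp, Result, Tensor};

fn greedy_decode(
    model: &mut TrOCRModel,
    image: &Tensor,
    bos_token_id: u32,
    eos_token_id: u32,
    max_len: usize,
    device: &Device,
) -> Result<Vec<u32>> {
    // Run the ViT encoder once, then decode token by token.
    let encoder_xs = model.encoder().forward(image)?;
    model.reset_kv_cache();
    let mut tokens = vec![bos_token_id];
    for step in 0..max_len {
        // After the first step the kv-cache holds the past, so only the
        // most recent token needs to be fed in.
        let start = if step == 0 { 0 } else { tokens.len() - 1 };
        let input = Tensor::new(&tokens[start..], device)?.unsqueeze(0)?;
        let logits = model.decode(&input, &encoder_xs, start)?;
        let last = logits.i((0, logits.dim(1)? - 1))?;
        let next = last.argmax(0)?.to_scalar::<u32>()?;
        tokens.push(next);
        if next == eos_token_id {
            break;
        }
    }
    Ok(tokens)
}
```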
@ -6,16 +6,16 @@ use candle_nn::{layer_norm, LayerNorm, VarBuilder};
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/configuration_vit.py
#[derive(Debug, Clone)]
pub struct Config {
    hidden_size: usize,
    num_hidden_layers: usize,
    num_attention_heads: usize,
    intermediate_size: usize,
    hidden_act: candle_nn::Activation,
    layer_norm_eps: f64,
    image_size: usize,
    patch_size: usize,
    num_channels: usize,
    qkv_bias: bool,
    pub hidden_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub intermediate_size: usize,
    pub hidden_act: candle_nn::Activation,
    pub layer_norm_eps: f64,
    pub image_size: usize,
    pub patch_size: usize,
    pub num_channels: usize,
    pub qkv_bias: bool,
}

impl Config {
@ -34,6 +34,21 @@ impl Config {
            qkv_bias: true,
        }
    }

    pub fn microsoft_trocr_base_handwritten() -> Self {
        Self {
            hidden_size: 768,
            num_hidden_layers: 12,
            num_attention_heads: 12,
            intermediate_size: 3072,
            hidden_act: candle_nn::Activation::Gelu,
            layer_norm_eps: 1e-12,
            image_size: 384,
            patch_size: 16,
            num_channels: 3,
            qkv_bias: false,
        }
    }
}

#[derive(Debug, Clone)]
@ -76,7 +91,7 @@ impl Module for PatchEmbeddings {
}

#[derive(Debug, Clone)]
struct Embeddings {
pub struct Embeddings {
    cls_token: Tensor,
    mask_token: Option<Tensor>,
    patch_embeddings: PatchEmbeddings,
@ -85,7 +100,7 @@ struct Embeddings {
}

impl Embeddings {
    fn new(cfg: &Config, use_mask_token: bool, vb: VarBuilder) -> Result<Self> {
    pub fn new(cfg: &Config, use_mask_token: bool, vb: VarBuilder) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let cls_token = vb.get((1, 1, hidden_size), "cls_token")?;
        let mask_token = if use_mask_token {
@ -115,7 +130,7 @@ impl Embeddings {
        todo!()
    }

    fn forward(
    pub fn forward(
        &self,
        pixel_values: &Tensor,
        bool_masked_pos: Option<&Tensor>,
@ -324,12 +339,12 @@ impl Module for Layer {
}

#[derive(Debug, Clone)]
struct Encoder {
pub struct Encoder {
    layers: Vec<Layer>,
}

impl Encoder {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb = vb.pp("layer");
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        for i in 0..cfg.num_hidden_layers {
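The config fields and the `Embeddings`/`Encoder` types go `pub` because the new TrOCR encoder above reads them across modules (e.g. `cfg.hidden_size`, `cfg.layer_norm_eps`). A quick illustrative use of the newly added preset (not part of the diff):

```rust
// The 384x384 input with 16x16 patches yields a 24x24 grid of patch tokens
// for the handwritten-base encoder config added above.
let cfg = Config::microsoft_trocr_base_handwritten();
assert_eq!(cfg.image_size / cfg.patch_size, 24);
```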
@ -58,8 +58,7 @@ fn dft<T: Float>(inp: &[T]) -> Vec<T> {
    let n = inp.len();
    let two_pi = T::PI() + T::PI();

    let mut out = Vec::new();
    out.reserve(2 * n);
    let mut out = Vec::with_capacity(2 * n);
    let n_t = T::from(n).unwrap();
    for k in 0..n {
        let k_t = T::from(k).unwrap();
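For context, the complete naive DFT this hunk touches looks roughly like the sketch below (assuming the file's `Float` trait bundles `num_traits::Float + num_traits::FloatConst`); the output interleaves real and imaginary parts, hence the `2 * n` capacity that the fix pre-allocates in one step:

```rust
fn dft<T: Float>(inp: &[T]) -> Vec<T> {
    let n = inp.len();
    let two_pi = T::PI() + T::PI();

    let mut out = Vec::with_capacity(2 * n);
    let n_t = T::from(n).unwrap();
    for k in 0..n {
        let k_t = T::from(k).unwrap();
        let mut re = T::zero();
        let mut im = T::zero();
        for (j, &x) in inp.iter().enumerate() {
            // angle = 2 * pi * k * j / n
            let angle = two_pi * k_t * T::from(j).unwrap() / n_t;
            re = re + x * angle.cos();
            im = im - x * angle.sin();
        }
        out.push(re);
        out.push(im);
    }
    out
}
```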
@ -43,4 +43,4 @@ pub const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
pub const TRANSLATE_TOKEN: &str = "<|translate|>";
pub const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
pub const EOT_TOKEN: &str = "<|endoftext|>";
pub const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
pub const NO_SPEECH_TOKENS: [&str; 2] = ["<|nocaptions|>", "<|nospeech|>"];
candle-transformers/src/models/yi.rs (new file, 381 lines)
@ -0,0 +1,381 @@
/// https://huggingface.co/01-ai/Yi-6B/blob/main/modeling_yi.py
use crate::models::with_tracing::{linear_no_bias, Linear};
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use std::sync::Arc;

#[derive(Debug, Clone, PartialEq)]
pub struct Config {
    pub(crate) vocab_size: usize,
    pub(crate) hidden_size: usize,
    pub(crate) intermediate_size: usize,
    pub(crate) num_hidden_layers: usize,
    pub(crate) num_attention_heads: usize,
    pub(crate) num_key_value_heads: usize,
    pub(crate) hidden_act: Activation,
    pub(crate) max_position_embeddings: usize,
    pub(crate) rms_norm_eps: f64,
    pub(crate) rope_theta: f64,
}

impl Config {
    pub fn config_6b() -> Self {
        Self {
            vocab_size: 64000,
            hidden_size: 4096,
            intermediate_size: 11008,
            num_hidden_layers: 32,
            num_attention_heads: 32,
            num_key_value_heads: 4,
            hidden_act: Activation::Silu,
            max_position_embeddings: 4096,
            rms_norm_eps: 1e-5,
            rope_theta: 5_000_000.,
        }
    }

    pub fn config_34b() -> Self {
        Self {
            vocab_size: 64000,
            hidden_size: 7168,
            intermediate_size: 20480,
            num_hidden_layers: 60,
            num_attention_heads: 56,
            num_key_value_heads: 8,
            hidden_act: Activation::Silu,
            max_position_embeddings: 4096,
            rms_norm_eps: 1e-5,
            rope_theta: 5_000_000.,
        }
    }
}
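Both presets use grouped-query attention: several query heads share each key/value head (see `repeat_kv` further down). An illustrative check, not part of the diff, usable from within the crate since the fields are `pub(crate)`:

```rust
let cfg6 = Config::config_6b();
assert_eq!(cfg6.num_attention_heads / cfg6.num_key_value_heads, 8); // 32 / 4
let cfg34 = Config::config_34b();
assert_eq!(cfg34.num_attention_heads / cfg34.num_key_value_heads, 7); // 56 / 8
```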
#[derive(Debug, Clone)]
struct RmsNorm {
    inner: candle_nn::RmsNorm,
    span: tracing::Span,
}

impl RmsNorm {
    fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
        let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
        let inner = candle_nn::rms_norm(size, eps, vb)?;
        Ok(Self { inner, span })
    }
}

impl Module for RmsNorm {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        self.inner.forward(x)
    }
}

#[derive(Debug, Clone)]
struct RotaryEmbedding {
    sin: Tensor,
    cos: Tensor,
}

fn rotate_half(xs: &Tensor) -> Result<Tensor> {
    let last_dim = xs.dim(D::Minus1)?;
    let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
    let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
    Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
}

impl RotaryEmbedding {
    fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
        let dim = cfg.hidden_size / cfg.num_attention_heads;
        let max_seq_len = cfg.max_position_embeddings;
        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(dtype)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
        Ok(Self {
            sin: freqs.sin()?,
            cos: freqs.cos()?,
        })
    }

    fn apply_rotary_emb_qkv(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offset: usize,
    ) -> Result<(Tensor, Tensor)> {
        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
        let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
        let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
        let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
        let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
        let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
        let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
        Ok((q_embed, k_embed))
    }
}
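Note that the inverse-frequency computation above appears to hard-code a 10000 base even though both configs set `rope_theta: 5_000_000.`, which is worth keeping in mind when comparing against the reference Python implementation. On the rotation itself, `rotate_half` pairs each dimension with the one half-way across the head dim; a tiny illustrative check (not part of the diff):

```rust
use candle::{Device, Tensor};

let xs = Tensor::new(&[1f32, 2., 3., 4.], &Device::Cpu)?;
// [x1, x2, x3, x4] -> [-x3, -x4, x1, x2]
assert_eq!(rotate_half(&xs)?.to_vec1::<f32>()?, [-3., -4., 1., 2.]);
```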
#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    gate_proj: Linear,
    up_proj: Linear,
    down_proj: Linear,
    act_fn: Activation,
}

impl MLP {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let intermediate_sz = cfg.intermediate_size;
        let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
        let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
        let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
            act_fn: cfg.hidden_act,
        })
    }
}

impl Module for MLP {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
        let rhs = xs.apply(&self.up_proj)?;
        (lhs * rhs)?.apply(&self.down_proj)
    }
}
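Spelled out, this block is the LLaMA-style SwiGLU feed-forward (illustrative restatement, not part of the diff):

```rust
// MLP(x) = down_proj(silu(gate_proj(x)) * up_proj(x))
// where `*` is elementwise and silu(x) = x * sigmoid(x) (Activation::Silu).
```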
#[derive(Debug, Clone)]
struct Attention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    num_heads: usize,
    num_kv_heads: usize,
    num_kv_groups: usize,
    head_dim: usize,
    hidden_size: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    kv_cache: Option<(Tensor, Tensor)>,
}

impl Attention {
    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let num_kv_groups = num_heads / num_kv_heads;
        let head_dim = hidden_sz / num_heads;
        let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
        let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
        let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
        let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads,
            num_kv_heads,
            num_kv_groups,
            head_dim,
            hidden_size: hidden_sz,
            rotary_emb,
            kv_cache: None,
        })
    }

    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
        let n_rep = self.num_kv_groups;
        if n_rep == 1 {
            Ok(xs)
        } else {
            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
            xs.unsqueeze(2)?
                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
        }
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let query_states = self.q_proj.forward(xs)?;
        let key_states = self.k_proj.forward(xs)?;
        let value_states = self.v_proj.forward(xs)?;

        let query_states = query_states
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        let key_states = key_states
            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;
        let value_states = value_states
            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;

        let (query_states, key_states) =
            self.rotary_emb
                .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;

        let (key_states, value_states) = match &self.kv_cache {
            None => (key_states, value_states),
            Some((prev_k, prev_v)) => {
                let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
                let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
                (key_states, value_states)
            }
        };
        self.kv_cache = Some((key_states.clone(), value_states.clone()));

        let key_states = self.repeat_kv(key_states)?;
        let value_states = self.repeat_kv(value_states)?;

        let attn_output = {
            let scale = 1f64 / f64::sqrt(self.head_dim as f64);
            let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;

            let attn_weights = match attention_mask {
                None => attn_weights,
                Some(mask) => attn_weights.broadcast_add(mask)?,
            };
            let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
            attn_weights.matmul(&value_states)?
        };
        attn_output
            .transpose(1, 2)?
            .reshape((b_sz, q_len, self.hidden_size))?
            .apply(&self.o_proj)
    }
}
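An illustrative shape trace for `repeat_kv` under the 6b config, where `num_kv_groups = 32 / 4 = 8` (not part of the diff):

```rust
// (b, kv_heads, seq, head_dim) = (1, 4, 16, 128)
//   .unsqueeze(2)              -> (1, 4, 1, 16, 128)
//   .expand(.., n_rep = 8, ..) -> (1, 4, 8, 16, 128)
//   .reshape(..)               -> (1, 32, 16, 128)  // matches the 32 query heads
```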
#[derive(Debug, Clone)]
struct DecoderLayer {
    self_attn: Attention,
    mlp: MLP,
    ln1: RmsNorm,
    ln2: RmsNorm,
}

impl DecoderLayer {
    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
        let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
        let ln2 = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            vb.pp("post_attention_layernorm"),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            ln1,
            ln2,
        })
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.ln1.forward(xs)?;
        let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = xs.apply(&self.ln2)?.apply(&self.mlp)?;
        residual + xs
    }
}

#[derive(Debug, Clone)]
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Linear,
    device: Device,
    dtype: DType,
}

impl Model {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb_m = vb.pp("model");
        let embed_tokens =
            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
        let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        for layer_idx in 0..cfg.num_hidden_layers {
            let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
            layers.push(layer)
        }
        let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
        let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            device: vb.device().clone(),
            dtype: vb.dtype(),
        })
    }

    fn prepare_decoder_attention_mask(
        &self,
        b_size: usize,
        tgt_len: usize,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        // Sliding window mask?
        let mask: Vec<_> = (0..tgt_len)
            .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
            .collect();
        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
        let mask = if seqlen_offset > 0 {
            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
            Tensor::cat(&[&mask0, &mask], D::Minus1)?
        } else {
            mask
        };
        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
            .to_dtype(self.dtype)
    }
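For concreteness, the mask built above lets every query attend to all cached positions and causally to the new ones; e.g. with `tgt_len = 3` and `seqlen_offset = 2` (illustrative, not part of the diff):

```rust
// key position:  cache0 cache1 new0  new1  new2
// query new0:      0      0     0   -inf  -inf
// query new1:      0      0     0     0   -inf
// query new2:      0      0     0     0     0
```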
    pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
        let (b_size, seq_len) = input_ids.dims2()?;
        let attention_mask = if seq_len <= 1 {
            None
        } else {
            let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
            Some(mask)
        };
        let mut xs = self.embed_tokens.forward(input_ids)?;
        for layer in self.layers.iter_mut() {
            xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
        }
        xs.narrow(1, seq_len - 1, 1)?
            .apply(&self.norm)?
            .apply(&self.lm_head)
    }
}
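A minimal sketch of how this forward pass drives token-by-token generation with the kv-cache (illustrative, not part of the diff: `model: Model`, `device`, `max_new_tokens`, and the tokenized `prompt_tokens: Vec<u32>` are assumed to be set up elsewhere, and real examples sample from the logits instead of taking the argmax):

```rust
let mut tokens = prompt_tokens;
for index in 0..max_new_tokens {
    // Feed the full prompt once, then only the newest token: the kv-cache
    // inside each Attention layer keeps the earlier keys/values around.
    let context_size = if index > 0 { 1 } else { tokens.len() };
    let start = tokens.len() - context_size;
    let input = Tensor::new(&tokens[start..], &device)?.unsqueeze(0)?;
    let logits = model.forward(&input, start)?; // (1, 1, vocab_size)
    let next = logits.squeeze(0)?.squeeze(0)?.argmax(0)?.to_scalar::<u32>()?;
    tokens.push(next);
}
```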
@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true

[dependencies]
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true

[dependencies]
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
num-traits = { workspace = true }

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true

[dependencies]
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true

[dependencies]
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
num-traits = { workspace = true }

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true

[dependencies]
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
num-traits = { workspace = true }

# App crates.

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true

[dependencies]
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true

[dependencies]
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
@ -59,8 +59,7 @@ fn dft<T: Float>(inp: &[T]) -> Vec<T> {
    let n = inp.len();
    let two_pi = T::PI() + T::PI();

    let mut out = Vec::new();
    out.reserve(2 * n);
    let mut out = Vec::with_capacity(2 * n);
    let n_t = T::from(n).unwrap();
    for k in 0..n {
        let k_t = T::from(k).unwrap();
@ -129,7 +129,13 @@ impl Decoder {
        let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
        let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
        let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
        let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?;
        let no_speech_token = m::NO_SPEECH_TOKENS
            .iter()
            .find_map(|token| token_id(&tokenizer, token).ok());
        let no_speech_token = match no_speech_token {
            None => anyhow::bail!("unable to find any non-speech token"),
            Some(n) => n,
        };
        let seed = 299792458;
        Ok(Self {
            model,
@ -9,8 +9,8 @@ categories.workspace = true
license.workspace = true

[dependencies]
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
num-traits = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
@ -7,7 +7,7 @@ keywords.workspace = true
categories.workspace = true

[dependencies]
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
rand = { workspace = true }
getrandom = { version = "0.2", features = ["js"] }
@ -1,7 +1,7 @@
use candle::{
    quantized::{self, k_quants, GgmlDType, GgmlType},
    test_utils::to_vec2_round,
    Device, Result, Tensor,
    Device, Module, Result, Tensor,
};

use wasm_bindgen_test::*;
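An assumption about the import change above: `Module` is brought into scope so the tests can call `.forward(...)` on layers that implement the trait, e.g.:

```rust
// Assumed usage pattern, not shown in the diff:
// let ys = qmatmul.forward(&xs)?;
```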