Compare commits

..

28 Commits

Author SHA1 Message Date
eb24875856 Reworked affine and it works? No idea how it's different. 2023-11-08 02:37:20 +01:00
3f662e54cd Reworked affine and it works? No idea how it's different. 2023-11-08 02:34:08 +01:00
480a3e22e6 Adding cast + binary kernels. 2023-11-07 23:45:53 +01:00
0c24a885a6 Updated everything and output a trace. 2023-11-07 21:12:42 +01:00
76d3116f5d Broken metal ? 2023-11-07 14:20:13 +01:00
1367e0278b pesky bfloat type 2023-11-07 10:26:59 +01:00
7ff17d92b3 Finished the unary
- Added proper kernel type check (through modules + macro)
- split contiguous and strided into 2 different kernels
- Verified on long range + strided values.
2023-11-06 23:12:12 +01:00
cd68c96803 Going out of bounds will break other kernels running from other threads. 2023-11-06 17:29:58 +01:00
4d87305c48 Float -> half / bfloat conversion in unary 2023-11-06 17:09:39 +01:00
677495f9b8 Working but failing tests because of threadgroup. 2023-11-06 17:04:47 +01:00
dedc8c3656 Writing unary as macro instead, protecting bfloat type with proper metal version. 2023-11-06 15:36:48 +01:00
63cce76b84 Improve metal kernel loading and associated errors 2023-11-06 09:48:18 +01:00
634a4e7168 BlitEncoder added to affine for copying buffer contents quickly. 2023-11-06 08:23:36 +01:00
8124d1003f Affine metal kernel works. Need to extract buffer contents based on layout offset (like CudaSlice.slice) for candle integration 2023-11-06 04:46:56 +01:00
6d4c8c0707 Use metal encode_gemm 2023-11-06 03:27:22 +01:00
e6d33a8efb Remove unused utils.metal 2023-11-06 03:26:21 +01:00
c921cc3784 Add Arc to metalstorage buffer for quick cloning 2023-11-04 09:03:23 +01:00
d4d6850c78 Impl index_add via template for all types 2023-11-04 08:46:08 +01:00
e708d35e7f index_add works 2023-11-03 21:12:52 +01:00
0794e70a19 Debugging index_add. 2023-11-03 12:09:05 +01:00
f57e3164ae Implemented cos for now. 2023-11-03 01:24:51 +01:00
7161002a34 Finished scaffolding, lots of TODOs
- Most kernels just copy themselves to get the shapes correct
- Matmul works only in 1 case and simply empty allocates otherwise
- Logits are randomized to make the demo finish by itself.

Performance is quite bad (30ms/token), but there are lots of prints and allocs and some actual sending to metal.

Couldn't improve it much by removing the obvious blockers (println + the actual matmul runs).

Allocations take between 1us and 100us and seem very stable. Maybe Metal doesn't really have a smart allocator and we'll need to own it (see the allocation-timing sketch after the commit list).
2023-11-02 15:32:28 +01:00
82cce52e73 Rename candle-metal -> candle-metal-kernels 2023-11-02 09:53:29 +01:00
71fcb31873 Owned command buffer now. 2023-11-01 18:03:53 +01:00
198009453a Matmul (no batch, no strided, f32, f32 only) sort of done. 2023-11-01 17:36:51 +01:00
492d164235 More scaffolding, now need to implement matmul (for precompute_cos_sin to work). 2023-11-01 16:54:09 +01:00
2d84c16fed First pass (Quantized scaffolding work done + quantized example scaffolding). 2023-11-01 15:10:11 +01:00
4525b7b52a Initial setup 2023-10-31 18:09:10 +01:00
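
As a rough illustration of the contiguous/strided split mentioned in commit 7ff17d92b3 above: on the host side, dispatch can reduce to picking one of two kernel entry points from the layout. The kernel names and the contiguity flag below are hypothetical and not the actual candle-metal-kernels API.

/// Hypothetical sketch: pick a Metal kernel entry point for a unary op.
/// Contiguous inputs get a dense kernel with linear indexing; anything else
/// falls back to a strided kernel that also receives dims and strides.
fn unary_kernel_name(op: &str, dtype: &str, contiguous: bool) -> String {
    if contiguous {
        format!("{op}_{dtype}")
    } else {
        format!("{op}_{dtype}_strided")
    }
}

fn main() {
    assert_eq!(unary_kernel_name("cos", "f32", true), "cos_f32");
    assert_eq!(unary_kernel_name("cos", "f32", false), "cos_f32_strided");
}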
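
Commit 634a4e7168 mentions using a BlitEncoder to copy buffer contents quickly. A minimal sketch of that idea using the metal crate (metal-rs) follows; it is an assumption about the general approach rather than the candle code itself, and it only runs on macOS with a Metal device.

// Sketch only: copy one Metal buffer into another with a blit command encoder.
use metal::{Device, MTLResourceOptions};

fn main() {
    let device = Device::system_default().expect("no Metal device");
    let queue = device.new_command_queue();

    // Source buffer initialised from host data; destination left empty.
    let data = [1.0f32, 2.0, 3.0, 4.0];
    let size = (data.len() * std::mem::size_of::<f32>()) as u64;
    let src = device.new_buffer_with_data(
        data.as_ptr() as *const _,
        size,
        MTLResourceOptions::StorageModeShared,
    );
    let dst = device.new_buffer(size, MTLResourceOptions::StorageModeShared);

    // A blit command encoder copies buffer contents without a custom kernel.
    let command_buffer = queue.new_command_buffer();
    let blit = command_buffer.new_blit_command_encoder();
    blit.copy_from_buffer(&src, 0, &dst, 0, size);
    blit.end_encoding();
    command_buffer.commit();
    command_buffer.wait_until_completed();

    let out = unsafe { std::slice::from_raw_parts(dst.contents() as *const f32, data.len()) };
    println!("{out:?}"); // expected: [1.0, 2.0, 3.0, 4.0]
}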
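
Commit 7161002a34 notes that buffer allocations take roughly 1us to 100us. A throwaway measurement along those lines (again assuming the metal crate; actual numbers depend on the machine) could look like this:

// Sketch only: time a few raw Metal buffer allocations of increasing size.
use metal::{Device, MTLResourceOptions};
use std::time::Instant;

fn main() {
    let device = Device::system_default().expect("no Metal device");
    for size in [1u64 << 10, 1 << 20, 1 << 24] {
        let start = Instant::now();
        let _buf = device.new_buffer(size, MTLResourceOptions::StorageModeShared);
        println!("alloc {size:>9} bytes: {:?}", start.elapsed());
    }
}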
134 changed files with 1813 additions and 9852 deletions

View File

@ -59,7 +59,7 @@ jobs:
- name: Install Rust Stable
run: curl https://sh.rustup.rs -sSf | sh -s -- -y
- uses: Swatinem/rust-cache@v2
- run: apt-get update -y && apt-get install libssl-dev protobuf-compiler -y
- run: apt-get update -y && apt-get install libssl-dev -y
- name: Test (cuda)
run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda
stop-runner:

Binary file not shown.

View File

@ -39,12 +39,6 @@ jobs:
path: ~/.cargo/registry
key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }}
- name: Install Protoc
uses: arduino/setup-protoc@v2
with:
version: "25.0"
repo-token: ${{ secrets.GITHUB_TOKEN }}
- name: Install
working-directory: ./candle-pyo3
run: |
@ -52,7 +46,7 @@ jobs:
source .env/bin/activate
pip install -U pip
pip install pytest maturin black
python -m maturin develop -r --features onnx
python -m maturin develop -r
- name: Check style
working-directory: ./candle-pyo3

View File

@ -10,16 +10,11 @@ members = [
"candle-wasm-examples/*",
"candle-wasm-tests",
]
exclude = [
"candle-flash-attn",
"candle-kernels",
"candle-metal-kernels",
"candle-onnx",
]
exclude = ["candle-flash-attn", "candle-kernels"]
resolver = "2"
[workspace.package]
version = "0.3.1"
version = "0.3.0"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@ -51,7 +46,6 @@ rayon = "1.7.0"
rusttype = { version = "0.9", default-features = false }
safetensors = "0.3.1"
serde = { version = "1.0.171", features = ["derive"] }
serde_plain = "1.0.2"
serde_json = "1.0.99"
thiserror = "1"
tokenizers = { version = "0.13.4", default-features = false }
@ -61,7 +55,7 @@ tracing-subscriber = "0.3.7"
wav = "1.0.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "0.6.6", default-features = false }
metal = { version = "0.27.1", features = ["mps"], package="candle-metal" }
metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
[profile.release-with-debug]
inherits = "release"

View File

@ -51,7 +51,7 @@ For more advanced examples, please have a look at the following section.
These online demos run entirely in your browser:
- [yolo](https://huggingface.co/spaces/lmz/candle-yolo): pose estimation and
object recognition.
- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): speech recognition.
- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): text to speech.
- [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2): text generation.
- [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm): text generation.
- [Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation.
@ -69,8 +69,6 @@ We also provide a some command line based examples using state of the art models
performance larger than all publicly available 13b models as of 2023-09-28.
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
(English/Chinese) general LLMs with 6b and 34b parameters.
- [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of
the LLaMA model using the same quantization techniques as
[llama.cpp](https://github.com/ggerganov/llama.cpp).
@ -139,17 +137,12 @@ And then head over to
<!--- ANCHOR: useful_libraries --->
## Useful External Resources
- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): A
- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): a
very detailed tutorial showing how to convert a PyTorch model to Candle.
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): Efficient and ergonomic LoRA implemenation for Candle. `candle-lora` has
out-of-the-box LoRA support for many models from Candle, which can be found [here](https://github.com/EricLBuehler/candle-lora/tree/master/candle-lora-transformers/examples).
- [`optimisers`](https://github.com/KGrewal1/optimisers): A collection of optimisers
- [`optimisers`](https://github.com/KGrewal1/optimisers): a collection of optimisers
including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
- [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
serving local LLMs including an OpenAI compatible API server.
- [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle.
- [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): a LoRA implementation
that conforms to the official `peft` implementation.
If you have an addition to this list, please submit a pull request.
@ -175,22 +168,16 @@ If you have an addition to this list, please submit a pull request.
- Mistral 7b v0.1.
- StableLM-3B-4E1T.
- Replit-code-v1.5-3B.
- T5.
- Bert.
- Yi-6B and Yi-34B.
- Quantized LLMs.
- Llama 7b, 13b, 70b, as well as the chat and code variants.
- Mistral 7b, and 7b instruct.
- Zephyr 7b a and b (Mistral based).
- OpenChat 3.5 (Mistral based).
- Text to text.
- T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
- Marian MT (Machine Translation).
- Whisper (multi-lingual support).
- Text to image.
- Stable Diffusion v1.5, v2.1, XL v1.0.
- Wurstchen v2.
- Image to text.
- BLIP.
- Text to text.
- Marian MT (Machine Translation).
- Computer Vision Models.
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT.
- yolo-v3, yolo-v8.
@ -231,7 +218,6 @@ Cheatsheet:
- [candle-datasets](./candle-datasets/): Datasets and data loaders.
- [candle-transformers](./candle-transformers): transformers-related utilities.
- [candle-flash-attn](./candle-flash-attn): Flash attention v2 layer.
- [candle-onnx](./candle-onnx/): ONNX model evaluation.
## FAQ

View File

@ -11,11 +11,11 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.3.1" }
candle-nn = { path = "../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../candle-transformers", version = "0.3.1" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.1", optional = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }

View File

@ -12,8 +12,9 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
candle-kernels = { path = "../candle-kernels", version = "0.3.1", optional = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.1", optional = true }
tracing = { workspace = true }
candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
metal = { workspace = true, optional = true}
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
@ -41,4 +42,4 @@ cuda = ["cudarc", "dep:candle-kernels"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal", "dep:candle-metal-kernels"]
metal = ["dep:candle-metal-kernels", "dep:metal"]

View File

@ -8,10 +8,11 @@ use anyhow::Result;
use candle_core::{Device, Tensor};
fn main() -> Result<()> {
let a = Tensor::new(&[[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]], &Device::Cpu)?;
let b = Tensor::new(&[[88.0f32, 99.0]], &Device::Cpu)?;
let new_a = a.slice_scatter(&b, 1, 2)?;
assert_eq!(a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
assert_eq!(new_a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
let start = std::time::Instant::now();
let res = inp.conv2d(&w, 0, 1, 1, 1)?;
println!("{:?}", start.elapsed());
println!("{res:?}");
Ok(())
}

View File

@ -39,14 +39,6 @@ pub trait BackendStorage: Sized {
_params: &crate::conv::ParamsConv1D,
) -> Result<Self>;
fn conv_transpose1d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self>;
fn conv2d(
&self,
_l: &Layout,

View File

@ -15,17 +15,6 @@ fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result
}
}
thread_local! {
static CANDLE_GRAD_DO_NOT_DETACH: bool = {
match std::env::var("CANDLE_GRAD_DO_NOT_DETACH") {
Ok(s) => {
!s.is_empty() && s != "0"
},
Err(_) => false,
}
}
}
impl Tensor {
/// Return all the nodes that lead to this value in a topologically sorted vec, the first
/// elements having dependencies on the latter ones, e.g. the first element if any is the
@ -68,11 +57,6 @@ impl Tensor {
kernel: rhs,
..
}
| Op::ConvTranspose1D {
arg: lhs,
kernel: rhs,
..
}
| Op::Conv2D {
arg: lhs,
kernel: rhs,
@ -166,16 +150,10 @@ impl Tensor {
if node.is_variable() {
continue;
}
let grad = grads
.remove(node)
.expect("candle internal error - grad not populated");
// https://github.com/huggingface/candle/issues/1241
// Ideally, we would make these operations in place where possible to ensure that we
// do not have to allocate too often. Here we just call `.detach` to avoid computing
// the backprop graph of the backprop itself. This would be an issue for second order
// derivatives but these are out of scope at the moment.
let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
let grad = if do_not_detach { grad } else { grad.detach()? };
let grad = grads.remove(node).unwrap();
// TODO: We should perform all these operations in place (or at least not track the
// whole graph). The only drawback would be if we wanted to support grad of grad but
// this is out of scope.
if let Some(op) = node.op() {
match op {
Op::Binary(lhs, rhs, BinaryOp::Add) => {
@ -230,44 +208,7 @@ impl Tensor {
let f_grad = pred.where_cond(&zeros, &grad)?;
*f_sum_grad = f_sum_grad.add(&f_grad)?;
}
Op::Conv1D {
arg,
kernel,
padding,
stride,
dilation,
} => {
// The output height for conv_transpose1d is:
// (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1
let grad_l_in = grad.dim(2)?;
let k_size = kernel.dim(2)?;
let out_size =
(grad_l_in - 1) * stride + dilation * (k_size - 1) + 1 - 2 * padding;
let out_padding = arg.dim(2)? - out_size;
let grad_arg = grad.conv_transpose1d(
kernel,
*padding,
out_padding,
*stride,
*dilation,
)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
let grad_kernel = arg
.transpose(0, 1)?
.conv1d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
.transpose(0, 1)?;
let sum_grad = grads.or_insert(kernel)?;
let (_, _, k0) = kernel.dims3()?;
let (_, _, g_k0) = grad_kernel.dims3()?;
let grad_kernel = if g_k0 != k0 {
grad_kernel.narrow(2, 0, k0)?
} else {
grad_kernel
};
*sum_grad = sum_grad.add(&grad_kernel)?;
}
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
Op::Conv2D {
arg,
kernel,
@ -306,9 +247,6 @@ impl Tensor {
};
*sum_grad = sum_grad.add(&grad_kernel)?;
}
Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported {
op: "conv-transpose1d",
})?,
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
op: "conv-transpose2d",
})?,
@ -549,38 +487,16 @@ impl Tensor {
+ 0.5)?;
*sum_grad = sum_grad.add(&(&grad * gelu_grad)?)?
}
Op::Unary(arg, UnaryOp::Erf) => {
let sum_grad = grads.or_insert(arg)?;
// d/dx erf(x) = 2/sqrt(pi) * e^(-x^2)
let erf_grad =
(2. / std::f64::consts::PI.sqrt()) * (arg.sqr()?.neg()?).exp()?;
*sum_grad = sum_grad.add(&(&grad * erf_grad)?)?
}
Op::Unary(arg, UnaryOp::GeluErf) => {
let sum_grad = grads.or_insert(arg)?;
// d/dx gelu_erf(x) = 0.5 + 0.398942 e^(-x^2/2) x + 0.5 erf(x/sqrt(2))
let neg_half_square = (arg.sqr()?.neg()? / 2.)?;
let scaled_exp_arg = (0.398942 * neg_half_square.exp()? * arg)?;
let arg_scaled_sqrt = (arg / 2f64.sqrt())?;
let erf_scaled_sqrt = (0.5 * arg_scaled_sqrt.erf()?)?;
let gelu_erf_grad = (0.5 + scaled_exp_arg + erf_scaled_sqrt)?;
*sum_grad = sum_grad.add(&(&grad * gelu_erf_grad)?)?;
Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
Op::Unary(_, UnaryOp::GeluErf) => {
Err(Error::BackwardNotSupported { op: "gelu-erf" })?
}
Op::Unary(arg, UnaryOp::Relu) => {
let sum_grad = grads.or_insert(arg)?;
let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
*sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
}
Op::Elu(arg, alpha) => {
// d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
let sum_grad = grads.or_insert(arg)?;
let zeros = arg.zeros_like()?;
let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
let negative_exp_mask = ((negative_mask * arg.exp())? * *alpha)?;
let combined_mask = (positive_mask + negative_exp_mask)?;
*sum_grad = sum_grad.add(&(grad * combined_mask)?)?
}
Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
Op::Powf(arg, e) => {
let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?;
let sum_grad = grads.or_insert(arg)?;

View File

@ -25,33 +25,6 @@ impl ParamsConv1D {
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConvTranspose1D {
pub(crate) b_size: usize,
pub(crate) l_in: usize,
pub(crate) c_out: usize,
pub(crate) c_in: usize,
pub(crate) k_size: usize,
pub(crate) padding: usize,
pub(crate) output_padding: usize,
pub(crate) stride: usize,
pub(crate) dilation: usize,
}
impl ParamsConvTranspose1D {
pub(crate) fn l_out(&self) -> usize {
(self.l_in - 1) * self.stride - 2 * self.padding
+ self.dilation * (self.k_size - 1)
+ self.output_padding
+ 1
}
pub(crate) fn out_dims(&self) -> Vec<usize> {
let l_out = self.l_out();
vec![self.b_size, self.c_out, l_out]
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CudnnFwdAlgo {
ImplicitGemm,
@ -187,49 +160,6 @@ impl Tensor {
}
}
/// Applies a 1D transposed convolution over the input tensor.
pub fn conv_transpose1d(
&self,
kernel: &Self,
padding: usize,
output_padding: usize,
stride: usize,
dilation: usize,
) -> Result<Self> {
let (b_size, c_in, l_in) = self.dims3()?;
let (c_in_k, c_out, k_size) = kernel.dims3()?;
if c_in != c_in_k {
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
}
let params = ParamsConvTranspose1D {
b_size,
l_in,
k_size,
c_out,
c_in,
padding,
output_padding,
stride,
dilation,
};
let storage = self.storage().conv_transpose1d(
self.layout(),
&kernel.storage(),
kernel.layout(),
&params,
)?;
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D {
arg,
kernel,
padding: params.padding,
output_padding: params.output_padding,
stride: params.stride,
dilation: params.dilation,
});
let out_dims = params.out_dims();
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
}
fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
let storage =
self.storage()

View File

@ -1256,74 +1256,6 @@ impl Map1 for Im2Col {
}
}
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
impl<'a> Map2 for ConvTranspose1D<'a> {
const OP: &'static str = "conv_transpose1d";
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0;
let inp = &inp[inp_l.start_offset()..];
let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
let l_out = p.l_out();
// Output shape: [b_size, c_out, l_out].
let dst_elems = p.c_out * l_out * p.b_size;
let dst = vec![T::zero(); dst_elems];
let dst_s0 = p.c_out * l_out;
let dst_s1 = l_out;
let dst_s2 = 1;
// TODO: Avoid making this copy if `inp` already has the appropriate layout.
let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
let cont_s0 = p.l_in * p.c_in;
let cont_s1 = p.c_in;
for b_idx in 0..p.b_size {
for l_idx in 0..p.l_in {
for c_idx in 0..p.c_in {
let src_idx = b_idx * inp_s0 + c_idx * inp_s1 + l_idx * inp_s2;
let dst_idx = b_idx * cont_s0 + l_idx * cont_s1 + c_idx;
inp_cont[dst_idx] = inp[src_idx]
}
}
}
for k_idx in 0..p.k_size {
(0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
let k_cont = (0..p.c_in)
.map(|c_in_idx| k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_idx * k_s2])
.collect::<Vec<_>>();
for b_idx in 0..p.b_size {
for l_idx in 0..p.l_in {
let out_idx = l_idx * p.stride + k_idx * p.dilation;
if out_idx < p.padding {
continue;
}
let out_idx = out_idx - p.padding;
if out_idx < l_out {
let inp_cont = &inp_cont[b_idx * cont_s0 + l_idx * cont_s1..];
let dst_idx = b_idx * dst_s0 + out_idx * dst_s2 + dst_c_idx * dst_s1;
let mut d = T::zero();
unsafe {
T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
}
let dst_p = dst.as_ptr();
// Safety: dst_idx are uniques per dst_c_idx which is used to
// parallelise the different tasks so no two threads can try to
// write at the same location.
unsafe {
let ptr = dst_p.add(dst_idx) as *mut T;
*ptr += d
}
}
}
}
})
}
Ok(dst)
}
}
struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
impl<'a> Map2 for Conv2D<'a> {
@ -2503,16 +2435,6 @@ impl BackendStorage for CpuStorage {
Ok(res_t)
}
fn conv_transpose1d(
&self,
l: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
ConvTranspose1D(params).map(self, l, kernel, kernel_l)
}
fn conv2d(
&self,
l: &Layout,

View File

@ -1808,16 +1808,6 @@ impl BackendStorage for CudaStorage {
Ok(res_t)
}
fn conv_transpose1d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
todo!()
}
#[cfg(not(feature = "cudnn"))]
fn conv2d(
&self,

View File

@ -1,6 +1,6 @@
use crate::backend::BackendDevice;
use crate::cpu_backend::CpuDevice;
use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
use crate::{bail, CpuStorage, DType, Result, Shape, Storage, WithDType};
/// A `DeviceLocation` represents a physical device whereas multiple `Device`
/// can live on the same location (typically for cuda devices).
@ -8,7 +8,7 @@ use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
pub enum DeviceLocation {
Cpu,
Cuda { gpu_id: usize },
Metal { gpu_id: usize },
Metal,
}
#[derive(Debug, Clone)]
@ -105,14 +105,14 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4:
impl<S: NdArray> NdArray for Vec<S> {
fn shape(&self) -> Result<Shape> {
if self.is_empty() {
crate::bail!("empty array")
bail!("empty array")
}
let shape0 = self[0].shape()?;
let n = self.len();
for v in self.iter() {
let shape = v.shape()?;
if shape != shape0 {
crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
bail!("two elements have different shapes {shape:?} {shape0:?}")
}
}
Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
@ -146,7 +146,6 @@ impl Device {
match (self, rhs) {
(Self::Cpu, Self::Cpu) => true,
(Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
(Self::Metal(lhs), Self::Metal(rhs)) => lhs.same_device(rhs),
_ => false,
}
}
@ -167,10 +166,6 @@ impl Device {
matches!(self, Self::Cuda(_))
}
pub fn is_metal(&self) -> bool {
matches!(self, Self::Metal(_))
}
pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
if crate::utils::cuda_is_available() {
Self::new_cuda(ordinal)
@ -192,19 +187,13 @@ impl Device {
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
// TODO: Remove the special case if we start supporting generating f16/bf16 directly.
if dtype == DType::F16 || dtype == DType::BF16 {
let storage = device.rand_uniform(shape, DType::F32, lo, up)?;
Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
} else {
let storage = device.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Cuda(storage))
}
let storage = device.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(_device) => {
// let storage = device.rand_uniform(shape, dtype, lo, up)?;
// Ok(Storage::Metal(storage))
crate::bail!("Metal rand_uniform not implemented")
bail!("Metal rand_uniform not implemented")
}
}
}
@ -231,14 +220,8 @@ impl Device {
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
// TODO: Remove the special case if we start supporting generating f16/bf16 directly.
if dtype == DType::F16 || dtype == DType::BF16 {
let storage = device.rand_normal(shape, DType::F32, mean, std)?;
Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
} else {
let storage = device.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Cuda(storage))
}
let storage = device.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = device.rand_normal(shape, dtype, mean, std)?;

View File

@ -14,9 +14,7 @@ impl Tensor {
crate::DeviceLocation::Cuda { gpu_id } => {
format!(", cuda:{}", gpu_id)
}
crate::DeviceLocation::Metal { gpu_id } => {
format!(", metal:{}", gpu_id)
}
_ => todo!(),
};
write!(f, "Tensor[")?;
@ -479,9 +477,7 @@ impl std::fmt::Display for Tensor {
crate::DeviceLocation::Cuda { gpu_id } => {
format!(", cuda:{}", gpu_id)
}
crate::DeviceLocation::Metal { gpu_id } => {
format!(", metal:{}", gpu_id)
}
crate::DeviceLocation::Metal => todo!(),
};
write!(

View File

@ -79,16 +79,6 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
fn conv_transpose1d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn conv2d(
&self,
_: &Layout,

View File

@ -8,18 +8,6 @@ pub struct MetalDevice;
#[derive(Debug)]
pub struct MetalStorage;
#[derive(thiserror::Error, Debug)]
pub enum MetalError {
#[error("{0}")]
Message(String),
}
impl From<String> for MetalError {
fn from(e: String) -> Self {
MetalError::Message(e)
}
}
macro_rules! fail {
() => {
unimplemented!("metal support has not been enabled, add `metal` feature to enable.")
@ -91,16 +79,6 @@ impl crate::backend::BackendStorage for MetalStorage {
Err(Error::NotCompiledWithMetalSupport)
}
fn conv_transpose1d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn conv2d(
&self,
_: &Layout,

View File

@ -1,4 +1,4 @@
use crate::{DType, DeviceLocation, Layout, MetalError, Shape};
use crate::{metal_backend, DType, DeviceLocation, Layout, Shape};
#[derive(Debug, Clone)]
pub struct MatMulUnexpectedStriding {
@ -163,7 +163,7 @@ pub enum Error {
Cuda(Box<dyn std::error::Error + Send + Sync>),
#[error("Metal error {0}")]
Metal(#[from] MetalError),
Metal(#[from] metal_backend::MetalError),
#[error(transparent)]
TryFromIntError(#[from] core::num::TryFromIntError),

View File

@ -104,31 +104,37 @@ impl From<&Tensor> for TensorIndexer {
}
}
trait RB: RangeBounds<usize> {}
impl RB for Range<usize> {}
impl RB for RangeFrom<usize> {}
impl RB for RangeFull {}
impl RB for RangeInclusive<usize> {}
impl RB for RangeTo<usize> {}
impl RB for RangeToInclusive<usize> {}
macro_rules! impl_from_range {
($range_type:ty) => {
impl From<$range_type> for TensorIndexer {
fn from(range: $range_type) -> Self {
use std::ops::Bound::*;
impl<T: RB> From<T> for TensorIndexer {
fn from(range: T) -> Self {
use std::ops::Bound::*;
let start = match range.start_bound() {
Included(idx) => Included(*idx),
Excluded(idx) => Excluded(*idx),
Unbounded => Unbounded,
};
let end = match range.end_bound() {
Included(idx) => Included(*idx),
Excluded(idx) => Excluded(*idx),
Unbounded => Unbounded,
};
TensorIndexer::Narrow(start, end)
}
let start = match range.start_bound() {
Included(idx) => Included(*idx),
Excluded(idx) => Excluded(*idx),
Unbounded => Unbounded,
};
let end = match range.end_bound() {
Included(idx) => Included(*idx),
Excluded(idx) => Excluded(*idx),
Unbounded => Unbounded,
};
TensorIndexer::Narrow(start, end)
}
}
};
}
impl_from_range!(Range<usize>);
impl_from_range!(RangeFrom<usize>);
impl_from_range!(RangeFull);
impl_from_range!(RangeInclusive<usize>);
impl_from_range!(RangeTo<usize>);
impl_from_range!(RangeToInclusive<usize>);
/// Trait used to implement multiple signatures for ease of use of the slicing
/// of a tensor
pub trait IndexOp<T> {

View File

@ -49,12 +49,13 @@ mod device;
pub mod display;
mod dtype;
mod dummy_cuda_backend;
mod dummy_metal_backend;
pub mod error;
mod indexer;
pub mod layout;
#[cfg(feature = "metal")]
pub mod metal_backend;
#[cfg(feature = "accelerate")]
mod metal_backend;
#[cfg(feature = "mkl")]
mod mkl;
pub mod npy;
@ -91,10 +92,10 @@ pub use cuda_backend::{CudaDevice, CudaStorage};
pub use dummy_cuda_backend::{CudaDevice, CudaStorage};
#[cfg(feature = "metal")]
pub use metal_backend::{MetalDevice, MetalError, MetalStorage};
pub use metal_backend::{MetalDevice, MetalStorage};
#[cfg(not(feature = "metal"))]
pub use dummy_metal_backend::{MetalDevice, MetalError, MetalStorage};
pub use dummy_metal_backend::{MetalDevice, MetalStorage};
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
@ -123,6 +124,12 @@ pub trait Module {
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
}
impl Module for quantized::QMatMul {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
self.forward(xs)
}
}
impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
self(xs)

File diff suppressed because it is too large.

View File

@ -90,16 +90,6 @@ pub enum Op {
dilation: usize,
},
#[allow(dead_code)]
ConvTranspose1D {
arg: Tensor,
kernel: Tensor,
padding: usize,
output_padding: usize,
stride: usize,
dilation: usize,
},
#[allow(dead_code)]
Conv2D {
arg: Tensor,
@ -593,8 +583,7 @@ unary_op!(Recip, "recip", v, v.recip());
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
/// Tanh based approximation of the `gelu` operation
/// GeluErf is the more precise one.
/// `gelu` operation
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
impl UnaryOpT for Gelu {
const NAME: &'static str = "gelu";
@ -684,8 +673,6 @@ impl UnaryOpT for Gelu {
}
}
/// `erf` operation
/// <https://en.wikipedia.org/wiki/Error_function>
impl UnaryOpT for Erf {
const NAME: &'static str = "erf";
const KERNEL: &'static str = "uerf";
@ -975,10 +962,6 @@ impl BackpropOp {
};
Self(op)
}
pub(crate) fn is_none(&self) -> bool {
self.0.is_none()
}
}
impl std::ops::Deref for BackpropOp {

View File

@ -1,7 +1,7 @@
//! Support for the GGML file format.
use super::{k_quants, GgmlDType};
use crate::Result;
use crate::{Device, Result};
use byteorder::{LittleEndian, ReadBytesExt};
use std::collections::HashMap;
@ -121,11 +121,12 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
raw_data: &[u8],
size_in_bytes: usize,
dims: Vec<usize>,
device: &Device,
) -> Result<super::QTensor> {
let raw_data_ptr = raw_data.as_ptr();
let n_blocks = size_in_bytes / std::mem::size_of::<T>();
let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
super::QTensor::new(data.to_vec(), dims)
super::QTensor::new(data.to_vec(), dims, device)
}
/// Creates a [Tensor] from a raw GGML tensor.
@ -133,6 +134,7 @@ pub fn qtensor_from_ggml(
ggml_dtype: GgmlDType,
raw_data: &[u8],
dims: Vec<usize>,
device: &Device,
) -> Result<super::QTensor> {
let tensor_elems = dims.iter().product::<usize>();
let blck_size = ggml_dtype.blck_size();
@ -144,18 +146,38 @@ pub fn qtensor_from_ggml(
let size_in_bytes = tensor_elems / blck_size * ggml_dtype.type_size();
match ggml_dtype {
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),
GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims),
GgmlDType::Q4_0 => from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims),
GgmlDType::Q4_1 => from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims),
GgmlDType::Q5_0 => from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims),
GgmlDType::Q5_1 => from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims),
GgmlDType::Q8_0 => from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims),
GgmlDType::Q2K => from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims),
GgmlDType::Q3K => from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims),
GgmlDType::Q4K => from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims),
GgmlDType::Q5K => from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims),
GgmlDType::Q6K => from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims),
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims, device),
GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims, device),
GgmlDType::Q4_0 => {
from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q4_1 => {
from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q5_0 => {
from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q5_1 => {
from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q8_0 => {
from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q2K => {
from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q3K => {
from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q4K => {
from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q5K => {
from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q6K => {
from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims, device)
}
_ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
}
}
@ -163,6 +185,7 @@ pub fn qtensor_from_ggml(
fn read_one_tensor<R: std::io::Seek + std::io::Read>(
reader: &mut R,
magic: VersionedMagic,
device: &Device,
) -> Result<(String, super::QTensor)> {
let n_dims = reader.read_u32::<LittleEndian>()?;
let name_len = reader.read_u32::<LittleEndian>()?;
@ -187,7 +210,7 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
// TODO: Mmap version to avoid copying the data around?
let mut raw_data = vec![0u8; size_in_bytes];
reader.read_exact(&mut raw_data)?;
match qtensor_from_ggml(ggml_dtype, &raw_data, dims) {
match qtensor_from_ggml(ggml_dtype, &raw_data, dims, device) {
Ok(tensor) => Ok((name, tensor)),
Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
}
@ -201,7 +224,10 @@ pub struct Content {
}
impl Content {
pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Content> {
pub fn read<R: std::io::Seek + std::io::Read>(
reader: &mut R,
device: &Device,
) -> Result<Content> {
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
let last_position = reader.seek(std::io::SeekFrom::End(0))?;
reader.seek(std::io::SeekFrom::Start(0))?;
@ -211,7 +237,7 @@ impl Content {
let mut tensors = HashMap::new();
while reader.stream_position()? != last_position {
let (name, tensor) = read_one_tensor(reader, magic)?;
let (name, tensor) = read_one_tensor(reader, magic, device)?;
tensors.insert(name, tensor);
}
Ok(Self {

View File

@ -3,7 +3,7 @@
//! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md
use super::{GgmlDType, QTensor};
use crate::Result;
use crate::{Device, Result};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::collections::HashMap;
@ -29,7 +29,6 @@ impl TryFrom<u32> for Magic {
pub enum VersionedMagic {
GgufV1,
GgufV2,
GgufV3,
}
impl VersionedMagic {
@ -40,7 +39,6 @@ impl VersionedMagic {
let versioned_magic = match (magic, version) {
(Magic::Gguf, 1) => Self::GgufV1,
(Magic::Gguf, 2) => Self::GgufV2,
(Magic::Gguf, 3) => Self::GgufV3,
_ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"),
};
Ok(versioned_magic)
@ -59,6 +57,7 @@ impl TensorInfo {
&self,
reader: &mut R,
tensor_data_offset: u64,
device: &Device,
) -> Result<QTensor> {
let tensor_elems = self.shape.elem_count();
let blck_size = self.ggml_dtype.blck_size();
@ -71,7 +70,12 @@ impl TensorInfo {
let mut raw_data = vec![0u8; size_in_bytes];
reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
reader.read_exact(&mut raw_data)?;
super::ggml_file::qtensor_from_ggml(self.ggml_dtype, &raw_data, self.shape.dims().to_vec())
super::ggml_file::qtensor_from_ggml(
self.ggml_dtype,
&raw_data,
self.shape.dims().to_vec(),
device,
)
}
}
@ -86,9 +90,7 @@ pub struct Content {
fn read_string<R: std::io::Read>(reader: &mut R, magic: &VersionedMagic) -> Result<String> {
let len = match magic {
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
reader.read_u64::<LittleEndian>()? as usize
}
VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
};
let mut v = vec![0u8; len];
reader.read_exact(&mut v)?;
@ -288,9 +290,7 @@ impl Value {
let value_type = ValueType::from_u32(value_type)?;
let len = match magic {
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
reader.read_u64::<LittleEndian>()? as usize
}
VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
};
let mut vs = Vec::with_capacity(len);
for _ in 0..len {
@ -387,15 +387,11 @@ impl Content {
let tensor_count = match magic {
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
reader.read_u64::<LittleEndian>()? as usize
}
VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
};
let metadata_kv_count = match magic {
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
reader.read_u64::<LittleEndian>()? as usize
}
VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
};
let mut metadata = HashMap::new();
@ -417,7 +413,7 @@ impl Content {
reader.read_u32_into::<LittleEndian>(&mut dimensions)?;
dimensions.into_iter().map(|c| c as usize).collect()
}
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
VersionedMagic::GgufV2 => {
let mut dimensions = vec![0; n_dimensions as usize];
reader.read_u64_into::<LittleEndian>(&mut dimensions)?;
dimensions.into_iter().map(|c| c as usize).collect()
@ -460,12 +456,13 @@ impl Content {
&self,
reader: &mut R,
name: &str,
device: &Device,
) -> Result<QTensor> {
let tensor_info = match self.tensor_infos.get(name) {
Some(tensor_info) => tensor_info,
None => crate::bail!("cannot find tensor-infor for {name}"),
};
tensor_info.read(reader, self.tensor_data_offset)
tensor_info.read(reader, self.tensor_data_offset, device)
}
}

View File

@ -1,4 +1,5 @@
use crate::{Device, Result, Shape, Tensor};
use tracing::debug;
#[cfg(target_feature = "avx")]
pub mod avx;
@ -14,6 +15,7 @@ pub mod utils;
pub use k_quants::GgmlType;
pub struct QTensor {
device: Device,
data: Box<dyn QuantizedType>,
shape: Shape,
}
@ -170,17 +172,20 @@ impl QTensor {
pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
data: Vec<T>,
shape: S,
device: &Device,
) -> Result<Self> {
let shape = shape.into();
check_shape::<T>(&shape)?;
Ok(Self {
data: Box::new(data),
shape,
device: device.clone(),
})
}
pub fn quantize<T: k_quants::GgmlType + Send + Sync + 'static>(src: &Tensor) -> Result<Self> {
let shape = src.shape();
let device = src.device();
check_shape::<T>(shape)?;
let src = src
.to_dtype(crate::DType::F32)?
@ -197,6 +202,7 @@ impl QTensor {
Ok(Self {
data: Box::new(data),
shape: shape.clone(),
device: device.clone(),
})
}
@ -212,7 +218,12 @@ impl QTensor {
&self.shape
}
pub fn device(&self) -> &Device {
&self.device
}
pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
// TODO Skip the CPU part on metal
let mut f32_data = vec![0f32; self.shape.elem_count()];
self.data.to_float(&mut f32_data)?;
Tensor::from_vec(f32_data, &self.shape, device)
@ -305,10 +316,50 @@ impl crate::CustomOp1 for QTensor {
)?;
Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
}
fn metal_fwd(
&self,
storage: &crate::MetalStorage,
layout: &crate::Layout,
) -> Result<(crate::MetalStorage, Shape)> {
debug!("TODO qmatmul");
if !layout.is_contiguous() {
crate::bail!("input tensor is not contiguous {layout:?}")
}
let src_shape = layout.shape();
// self is transposed so n is first then k.
let (n, k) = self.shape.dims2()?;
if src_shape.rank() < 2 {
crate::bail!("input tensor has only one dimension {layout:?}")
}
let mut dst_shape = src_shape.dims().to_vec();
let last_k = dst_shape.pop().unwrap();
if last_k != k {
crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
}
dst_shape.push(n);
let dst_shape = Shape::from(dst_shape);
// let storage = storage.as_slice::<f32>()?;
// let storage =
// &storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
let dst_storage = vec![0f32; dst_shape.elem_count()];
// self.matmul_t(
// (dst_shape.elem_count() / n, k, n),
// storage,
// &mut dst_storage,
// )?;
let cpu_storage = crate::CpuStorage::F32(dst_storage);
use crate::backend::{BackendDevice, BackendStorage};
if let Device::Metal(device) = &self.device {
Ok((device.storage_from_cpu_storage(&cpu_storage)?, dst_shape))
} else {
crate::bail!("qtensor not on metal device")
}
}
}
impl crate::Module for QMatMul {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
impl QMatMul {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
Self::Tensor(w) => {

View File

@ -334,33 +334,6 @@ impl Storage {
}
}
pub(crate) fn conv_transpose1d(
&self,
l: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
self.same_device(kernel, "conv-transpose1d")?;
self.same_dtype(kernel, "conv-transpose1d")?;
match (self, &kernel) {
(Storage::Cpu(inp), Storage::Cpu(kernel)) => {
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
Ok(Self::Cpu(s))
}
(Storage::Cuda(inp), Storage::Cuda(kernel)) => {
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
Ok(Self::Cuda(s))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
op: "conv-transpose1d",
}
.bt()),
}
}
pub(crate) fn conv2d(
&self,
l: &Layout,

View File

@ -477,12 +477,6 @@ impl Tensor {
broadcast_binary_op!(broadcast_div, div);
broadcast_binary_op!(broadcast_maximum, maximum);
broadcast_binary_op!(broadcast_minimum, minimum);
broadcast_binary_op!(broadcast_eq, eq);
broadcast_binary_op!(broadcast_ne, ne);
broadcast_binary_op!(broadcast_lt, lt);
broadcast_binary_op!(broadcast_le, le);
broadcast_binary_op!(broadcast_gt, gt);
broadcast_binary_op!(broadcast_ge, ge);
unary_op!(recip, Recip);
unary_op!(neg, Neg);
@ -856,20 +850,6 @@ impl Tensor {
self.sum_impl(mean_dims, false)? * scale
}
/// Returns the unbiased variance over the selected dimension.
pub fn var_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "var")?;
let mean = self.mean_keepdim(dim)?;
let squares = self.broadcast_sub(&mean)?.sqr()?;
squares.sum_impl(dim, true)? / (self.dim(dim)? - 1) as f64
}
/// Returns the unbiased variance over the selected dimension.
pub fn var<D: Dim>(&self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "var")?;
self.var_keepdim(dim)?.squeeze(dim)
}
/// Gathers the maximum value across the selected dimension. The resulting shape has the same
/// number of dimensions as the original tensor and the select dimension has a single element.
pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
@ -1831,23 +1811,17 @@ impl Tensor {
/// Returns a new tensor detached from the current graph, gradient are not propagated through
/// this new node. The storage of this tensor is shared with the initial tensor.
///
/// If the tensor is already detached from the computation graph, the same tensor is returned.
pub fn detach(&self) -> Result<Tensor> {
if self.op.is_none() && !self.is_variable {
Ok(self.clone())
} else {
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
layout: self.layout.clone(),
op: BackpropOp::none(),
is_variable: false,
dtype: self.dtype,
device: self.device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
}
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
layout: self.layout.clone(),
op: BackpropOp::none(),
is_variable: false,
dtype: self.dtype,
device: self.device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
}
/// If the target device is the same as the tensor device, only a shallow copy is performed.
@ -1859,14 +1833,7 @@ impl Tensor {
(Storage::Cpu(storage), Device::Cuda(cuda)) => {
Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
}
(Storage::Cpu(storage), Device::Metal(metal)) => {
Storage::Metal(metal.storage_from_cpu_storage(storage)?)
}
(Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
(Storage::Metal(storage), Device::Cpu) => {
println!("{storage:?} - {:?}", storage.to_cpu_storage()?);
Storage::Cpu(storage.to_cpu_storage()?)
}
(Storage::Cuda(storage), Device::Cuda(cuda)) => {
// TODO: Avoid passing through the cpu storage here, especially if the gpu ids
// are the same.
@ -2440,127 +2407,6 @@ impl Tensor {
) -> Result<Self> {
self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
}
/// Normalize a 'relative' axis value: positive values are kept, negative
/// values means counting the dimensions from the back.
pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
let rank = self.rank() as i64;
if rank <= axis {
crate::bail!("axis {axis} is too large, tensor rank {rank}")
} else if 0 <= axis {
Ok(axis as usize)
} else {
let naxis = rank + axis;
if naxis < 0 {
crate::bail!("axis {axis} is too small, tensor rank {rank}")
}
Ok(naxis as usize)
}
}
/// Returns a lower triangular matrix of ones of size n by n.
pub fn tril2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
let t = Tensor::arange(0u32, n as u32, device)?;
let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
t1.le(&t2)?.to_dtype(dtype)
}
/// Returns an upper triangular matrix of ones of size n by n.
pub fn triu2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
let t = Tensor::arange(0u32, n as u32, device)?;
let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
t1.ge(&t2)?.to_dtype(dtype)
}
/// Returns a matrix with a diagonal of ones of size n by n.
pub fn eye(n: usize, dtype: DType, device: &Device) -> Result<Self> {
let t = Tensor::arange(0u32, n as u32, device)?;
let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
t1.eq(&t2)?.to_dtype(dtype)
}
/// Returns the cumulative sum of elements of the input tensor summed over the specified
/// dimension.
///
/// This operation is most efficient when dim is the last dimension of the tensor.
pub fn cumsum<D: Dim>(&self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "cumsum")?;
let rank = self.rank();
if rank == 0 {
return Ok(self.clone());
}
let n_axis = self.dim(dim)?;
let triu = Tensor::triu2(n_axis, self.dtype(), self.device())?;
if rank == 1 {
self.unsqueeze(0)?.matmul(&triu)?.squeeze(0)
} else {
let last = rank - 1;
let t = self.transpose(dim, last)?;
let t = t.broadcast_matmul(&triu)?;
t.transpose(dim, last)
}
}
/// Returns a copy of `self` where the values within `ranges` have been replaced with the
/// content of `src`.
pub fn slice_assign<D: std::ops::RangeBounds<usize>>(
&self,
ranges: &[D],
src: &Tensor,
) -> Result<Self> {
let src_dims = src.dims();
let self_dims = self.dims();
if self_dims.len() != src_dims.len() {
crate::bail!(
"slice-assign requires input with the same rank {} <> {}",
self_dims.len(),
src_dims.len()
)
}
if self_dims.len() != ranges.len() {
crate::bail!(
"slice-assign requires input with the same rank as there are ranges {} <> {}",
self_dims.len(),
ranges.len()
)
}
let mut src = src.clone();
let mut mask = Self::ones(src.shape(), DType::U8, src.device())?;
for (i, range) in ranges.iter().enumerate() {
let start_included = match range.start_bound() {
std::ops::Bound::Unbounded => 0,
std::ops::Bound::Included(v) => *v,
std::ops::Bound::Excluded(v) => *v + 1,
};
let end_excluded = match range.end_bound() {
std::ops::Bound::Unbounded => self_dims[i],
std::ops::Bound::Included(v) => *v + 1,
std::ops::Bound::Excluded(v) => *v,
};
if end_excluded <= start_included {
crate::bail!(
"slice-assign: empty range for dim {i}, {start_included} {end_excluded}"
)
}
if self_dims[i] < end_excluded {
crate::bail!(
"slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}",
self_dims[i]
)
}
if end_excluded - start_included != src_dims[i] {
crate::bail!(
"slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i]
)
}
src = src.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?;
mask = mask.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?
}
mask.where_cond(/* on_true= */ &src, /* on_false= */ self)
}
}
macro_rules! bin_trait {

View File

@ -4,7 +4,7 @@ use crate::{Result, Tensor};
macro_rules! test_device {
// TODO: Switch to generating the two last arguments automatically once concat_idents is
// stable. https://github.com/rust-lang/rust/issues/29599
($fn_name: ident, $test_cpu: ident, $test_cuda: ident, $test_metal: ident) => {
($fn_name: ident, $test_cpu: ident, $test_cuda: ident) => {
#[test]
fn $test_cpu() -> Result<()> {
$fn_name(&Device::Cpu)
@ -15,12 +15,6 @@ macro_rules! test_device {
fn $test_cuda() -> Result<()> {
$fn_name(&Device::new_cuda(0)?)
}
#[cfg(feature = "metal")]
#[test]
fn $test_metal() -> Result<()> {
$fn_name(&Device::new_metal(0)?)
}
};
}

View File

@ -13,11 +13,6 @@ res = torch.nn.functional.conv1d(t, w)
print(res.flatten())
res = torch.nn.functional.conv1d(t, w, padding=1)
print(res.flatten())
w_t = w.transpose(0, 1)
res = torch.nn.functional.conv_transpose1d(t, w_t)
print(res.shape)
print(res)
*/
fn conv1d(dev: &Device) -> Result<()> {
let t = Tensor::new(
@ -50,17 +45,6 @@ fn conv1d(dev: &Device) -> Result<()> {
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
);
if dev.is_cpu() {
let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 7]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[
0.0699, -1.2899, 8.3018, 5.5873, 2.4572, -2.6143, -0.0706, 1.8765, 4.8318, 1.1538,
4.7076, -5.9745, -0.8276, 1.621
],
);
}
Ok(())
}
@ -563,35 +547,14 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
Ok(())
}
test_device!(conv1d, conv1d_cpu, conv1d_gpu, conv1d_metal);
test_device!(
conv1d_small,
conv1d_small_cpu,
conv1d_small_gpu,
conv1d_small_metal
);
test_device!(conv2d, conv2d_cpu, conv2d_gpu, conv2d_metal);
test_device!(conv1d, conv1d_cpu, conv1d_gpu);
test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu);
test_device!(conv2d, conv2d_cpu, conv2d_gpu);
test_device!(
conv2d_non_square,
conv2d_non_square_cpu,
conv2d_non_square_gpu,
conv2d_non_square_metal
);
test_device!(
conv2d_small,
conv2d_small_cpu,
conv2d_small_gpu,
conv2d_small_metal
);
test_device!(
conv2d_smaller,
conv2d_smaller_cpu,
conv2d_smaller_gpu,
conv2d_smaller_metal
);
test_device!(
conv2d_grad,
conv2d_grad_cpu,
conv2d_grad_gpu,
conv2_grad_metal
conv2d_non_square_gpu
);
test_device!(conv2d_small, conv2d_small_cpu, conv2d_small_gpu);
test_device!(conv2d_smaller, conv2d_smaller_cpu, conv2d_smaller_gpu);
test_device!(conv2d_grad, conv2d_grad_cpu, conv2d_grad_gpu);

View File

@ -205,71 +205,6 @@ fn unary_grad(device: &Device) -> Result<()> {
test_utils::to_vec1_round(grad_x, 4)?,
[1.0116, 1.0830, 1.0003, 0.6188],
);
// Testing compared to pytorch torch.erf
//
// import torch
// x = torch.tensor([3.0, 1.0, 4.0, 0.15], requires_grad=True)
// y = x.erf()
// print(y)
// loss = y.sum()
// loss.backward()
// print(x.grad)
let y = x.erf()?;
let grads = y.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(test_utils::to_vec1_round(&y, 4)?, [1.0, 0.8427, 1.0, 0.168]);
assert_eq!(
test_utils::to_vec1_round(grad_x, 4)?,
[0.0001, 0.4151, 0.0, 1.1033],
);
// Testing compared to pytorch nn.GELU(approximate = 'none')
//
// import torch
// import torch.nn.functional as F
// x = torch.tensor([3.0, 1.0, 4.0, 0.15], requires_grad=True)
// y = F.gelu(x, approximate='none')
// print(y)
// loss = y.sum()
// loss.backward()
// print(x.grad)
let y = x.gelu_erf()?;
let grads = y.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec1_round(&y, 4)?,
[2.9960, 0.8413, 3.9999, 0.0839]
);
assert_eq!(
test_utils::to_vec1_round(grad_x, 4)?,
[1.0119, 1.0833, 1.0005, 0.6188],
);
// Testing compared to pytorch elu
//
// import torch
// import torch.nn.functional as F
// x = torch.tensor([-1.0, 0.0, -2.0, 3.0], requires_grad=True)
// y = F.elu(x, alpha=2.0)
// print(y)
// loss = y.min
// loss = y.sum()
// loss.backward()
// print(x.grad)
let elu_x = Var::new(&[-1.0f32, 0., -2., 3.], device)?;
let y = elu_x.elu(2.)?;
let grads = y.backward()?;
let grad_x = grads.get(&elu_x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec1_round(&y, 4)?,
[-1.2642, 0.0000, -1.7293, 3.0000]
);
assert_eq!(
test_utils::to_vec1_round(grad_x, 4)?,
[0.7358, 2.0000, 0.2707, 1.0000]
);
Ok(())
}
@ -315,29 +250,9 @@ fn binary_grad(device: &Device) -> Result<()> {
Ok(())
}
test_device!(
simple_grad,
simple_grad_cpu,
simple_grad_gpu,
simple_grad_metal
);
test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu, sum_grad_metal);
test_device!(
matmul_grad,
matmul_grad_cpu,
matmul_grad_gpu,
matmul_grad_metal
);
test_device!(
grad_descent,
grad_descent_cpu,
grad_descent_gpu,
grad_descent_metal
);
test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu, unary_grad_metal);
test_device!(
binary_grad,
binary_grad_cpu,
binary_grad_gpu,
binary_grad_metal
);
test_device!(simple_grad, simple_grad_cpu, simple_grad_gpu);
test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu);
test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu);
test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu);
test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu);
test_device!(binary_grad, binary_grad_cpu, binary_grad_gpu);

View File

@ -91,32 +91,3 @@ fn index_3d() -> Result<()> {
assert_eq!(tensor.i((1, .., 3))?.to_vec1::<u32>()?, &[15, 19, 23]);
Ok(())
}
#[test]
fn slice_assign() -> Result<()> {
let dev = Device::Cpu;
let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?;
let src = Tensor::arange(0u32, 2 * 3, &dev)?.reshape((3, 2))?;
let out = tensor.slice_assign(&[1..4, 3..5], &src)?;
assert_eq!(
out.to_vec2::<u32>()?,
&[
[0, 1, 2, 3, 4],
[5, 6, 7, 0, 1],
[10, 11, 12, 2, 3],
[15, 16, 17, 4, 5]
]
);
let out = tensor.slice_assign(&[0..3, 0..2], &src)?;
assert_eq!(
out.to_vec2::<u32>()?,
&[
[0, 1, 2, 3, 4],
[2, 3, 7, 8, 9],
[4, 5, 12, 13, 14],
[15, 16, 17, 18, 19]
]
);
Ok(())
}

View File

@ -49,7 +49,7 @@ fn contiguous(device: &Device) -> Result<()> {
Ok(())
}
test_device!(contiguous, contiguous_cpu, contiguous_gpu, contiguous_metal);
test_device!(contiguous, contiguous_cpu, contiguous_gpu);
#[test]
fn strided_blocks() -> Result<()> {

View File

@ -98,17 +98,15 @@ fn upsample_nearest2d(dev: &Device) -> Result<()> {
Ok(())
}
test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu, avg_pool2d_metal);
test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu);
test_device!(
avg_pool2d_pytorch,
avg_pool2d_pytorch_cpu,
avg_pool2d_pytorch_gpu,
avg_pool2d_pytorch_metal
avg_pool2d_pytorch_gpu
);
test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu, max_pool2d_metal);
test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu);
test_device!(
upsample_nearest2d,
upsample_nearest2d_cpu,
upsample_nearest2d_gpu,
upsample_nearest2d_metal
upsample_nearest2d_gpu
);

View File

@ -1,7 +1,7 @@
use candle_core::{
quantized::{self, GgmlDType},
test_utils::to_vec2_round,
Device, Module, Result, Tensor,
Device, Result, Tensor,
};
use quantized::{k_quants, GgmlType};
use rand::prelude::*;

View File

@ -180,22 +180,6 @@ fn transpose(device: &Device) -> Result<()> {
Ok(())
}
fn var(device: &Device) -> Result<()> {
// Values taken from https://pytorch.org/docs/stable/generated/torch.var.html
let data = &[
[0.2035f32, 1.2959, 1.8101, -0.4644],
[1.5027, -0.3270, 0.5905, 0.6538],
[-1.5745, 1.3330, -0.5596, -0.6548],
[0.1264, -0.5080, 1.6420, 0.1992],
];
let tensor = Tensor::new(data, device)?;
assert_eq!(
test_utils::to_vec2_round(&tensor.var_keepdim(1)?, 4)?,
&[[1.0631], [0.559], [1.4893], [0.8258]]
);
Ok(())
}
fn sum(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, device)?;
@ -1070,60 +1054,34 @@ fn randn(device: &Device) -> Result<()> {
Ok(())
}
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
test_device!(arange, arange_cpu, arange_gpu, arange_metal);
test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal);
test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal);
test_device!(cat, cat_cpu, cat_gpu, cat_metal);
test_device!(sum, sum_cpu, sum_gpu, sum_metal);
test_device!(min, min_cpu, min_gpu, min_metal);
test_device!(max, max_cpu, max_gpu, max_metal);
test_device!(argmax, argmax_cpu, argmax_gpu, argmax_metal);
test_device!(argmin, argmin_cpu, argmin_gpu, argmin_metal);
test_device!(transpose, transpose_cpu, transpose_gpu, transpose_metal);
test_device!(unary_op, unary_op_cpu, unary_op_gpu, unary_op_metal);
test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal);
test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal);
test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal);
test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
test_device!(
broadcast_matmul,
broadcast_matmul_cpu,
broadcast_matmul_gpu,
broadcast_matmul_metal
);
test_device!(
broadcasting,
broadcasting_cpu,
broadcasting_gpu,
broadcasting_metal
);
test_device!(
index_select,
index_select_cpu,
index_select_gpu,
index_select_metal
);
test_device!(index_add, index_add_cpu, index_add_gpu, index_add_metal);
test_device!(gather, gather_cpu, gather_gpu, gather_metal);
test_device!(
scatter_add,
scatter_add_cpu,
scatter_add_gpu,
scatter_add_metal
);
test_device!(
slice_scatter,
slice_scatter_cpu,
slice_scatter_gpu,
slice_scatter_metal
);
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
test_device!(var, var_cpu, var_gpu, var_metal);
test_device!(zeros, zeros_cpu, zeros_gpu);
test_device!(ones, ones_cpu, ones_gpu);
test_device!(arange, arange_cpu, arange_gpu);
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
test_device!(narrow, narrow_cpu, narrow_gpu);
test_device!(broadcast, broadcast_cpu, broadcast_gpu);
test_device!(cat, cat_cpu, cat_gpu);
test_device!(sum, sum_cpu, sum_gpu);
test_device!(min, min_cpu, min_gpu);
test_device!(max, max_cpu, max_gpu);
test_device!(argmax, argmax_cpu, argmax_gpu);
test_device!(argmin, argmin_cpu, argmin_gpu);
test_device!(transpose, transpose_cpu, transpose_gpu);
test_device!(unary_op, unary_op_cpu, unary_op_gpu);
test_device!(binary_op, binary_op_cpu, binary_op_gpu);
test_device!(embeddings, embeddings_cpu, embeddings_gpu);
test_device!(cmp, cmp_cpu, cmp_gpu);
test_device!(matmul, matmul_cpu, matmul_gpu);
test_device!(broadcast_matmul, broadcast_matmul_cpu, broadcast_matmul_gpu);
test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu);
test_device!(index_select, index_select_cpu, index_select_gpu);
test_device!(index_add, index_add_cpu, index_add_gpu);
test_device!(gather, gather_cpu, gather_gpu);
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
test_device!(slice_scatter, slice_scatter_cpu, slice_scatter_gpu);
test_device!(randn, randn_cpu, randn_gpu);
test_device!(clamp, clamp_cpu, clamp_gpu);
// There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381
@ -1159,65 +1117,3 @@ fn i64_abs() -> Result<()> {
assert_eq!(t.to_vec1::<i64>()?, [42, 1337]);
Ok(())
}
#[test]
fn tril_triu_eye() -> Result<()> {
let t = Tensor::tril2(4, DType::F32, &Device::Cpu)?;
assert_eq!(
t.to_vec2::<f32>()?,
[
[1.0, 0.0, 0.0, 0.0],
[1.0, 1.0, 0.0, 0.0],
[1.0, 1.0, 1.0, 0.0],
[1.0, 1.0, 1.0, 1.0]
],
);
let t = Tensor::triu2(4, DType::F32, &Device::Cpu)?;
assert_eq!(
t.to_vec2::<f32>()?,
[
[1.0, 1.0, 1.0, 1.0],
[0.0, 1.0, 1.0, 1.0],
[0.0, 0.0, 1.0, 1.0],
[0.0, 0.0, 0.0, 1.0]
]
);
let t = Tensor::eye(4, DType::F32, &Device::Cpu)?;
assert_eq!(
t.to_vec2::<f32>()?,
[
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0]
]
);
Ok(())
}
#[test]
fn cumsum() -> Result<()> {
let t = &[3f32, 1., 4., 1., 5.];
let t = Tensor::new(t, &Device::Cpu)?;
assert_eq!(t.cumsum(0)?.to_vec1::<f32>()?, [3., 4., 8., 9., 14.]);
let t = t.unsqueeze(1)?;
assert_eq!(
t.cumsum(0)?.to_vec2::<f32>()?,
[[3.0], [4.0], [8.0], [9.0], [14.0]]
);
assert_eq!(
t.cumsum(1)?.to_vec2::<f32>()?,
[[3.0], [1.0], [4.0], [1.0], [5.0]]
);
let t = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let t = Tensor::new(t, &Device::Cpu)?;
assert_eq!(
t.cumsum(1)?.to_vec2::<f32>()?,
[[3.0, 4.0, 8.0, 9.0, 14.0], [2.0, 3.0, 10.0, 18.0, 20.0]],
);
assert_eq!(
t.cumsum(0)?.to_vec2::<f32>()?,
[[3.0, 1.0, 4.0, 1.0, 5.0], [5.0, 2.0, 11.0, 9.0, 7.0]]
);
Ok(())
}

View File

@ -11,8 +11,8 @@ readme = "README.md"
[dependencies]
byteorder = { workspace = true }
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.1" }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
hf-hub = { workspace = true}
intel-mkl-src = { workspace = true, optional = true }
memmap2 = { workspace = true }

View File

@ -4,9 +4,7 @@
//! <https://www.cs.toronto.edu/~kriz/cifar.html>
//! The binary version of the dataset is used.
use crate::vision::Dataset;
use candle::{DType, Device, Error, Result, Tensor};
use hf_hub::{api::sync::Api, Repo, RepoType};
use parquet::file::reader::{FileReader, SerializedFileReader};
use candle::{DType, Device, Result, Tensor};
use std::fs::File;
use std::io::{BufReader, Read};
@ -62,58 +60,3 @@ pub fn load_dir<T: AsRef<std::path::Path>>(dir: T) -> Result<Dataset> {
labels: 10,
})
}
fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor, Tensor)> {
let samples = parquet.metadata().file_metadata().num_rows() as usize;
let mut buffer_images: Vec<u8> = Vec::with_capacity(samples * 1_024);
let mut buffer_labels: Vec<u8> = Vec::with_capacity(samples);
for row in parquet.into_iter().flatten() {
for (_name, field) in row.get_column_iter() {
if let parquet::record::Field::Group(subrow) = field {
for (_name, field) in subrow.get_column_iter() {
if let parquet::record::Field::Bytes(value) = field {
let image = image::load_from_memory(value.data()).unwrap();
buffer_images.extend(image.to_rgb8().as_raw());
}
}
} else if let parquet::record::Field::Long(label) = field {
buffer_labels.push(*label as u8);
}
}
}
let images = (Tensor::from_vec(buffer_images, (samples, 3, 32, 32), &Device::Cpu)?
.to_dtype(DType::U8)?
/ 255.)?;
let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
Ok((images, labels))
}
pub fn load() -> Result<Dataset> {
let api = Api::new().map_err(|e| Error::Msg(format!("Api error: {e}")))?;
let dataset_id = "cifar10".to_string();
let repo = Repo::with_revision(
dataset_id,
RepoType::Dataset,
"refs/convert/parquet".to_string(),
);
let repo = api.repo(repo);
let test_parquet_filename = repo
.get("plain_text/test/0000.parquet")
.map_err(|e| Error::Msg(format!("Api error: {e}")))?;
let train_parquet_filename = repo
.get("plain_text/train/0000.parquet")
.map_err(|e| Error::Msg(format!("Api error: {e}")))?;
let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)
.map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)
.map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
let (test_images, test_labels) = load_parquet(test_parquet)?;
let (train_images, train_labels) = load_parquet(train_parquet)?;
Ok(crate::vision::Dataset {
train_images,
train_labels,
test_images,
test_labels,
labels: 10,
})
}

View File

@ -11,12 +11,11 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.3.1" }
candle-nn = { path = "../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../candle-transformers", version = "0.3.1" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.1", optional = true }
candle-onnx = { path = "../candle-onnx", version = "0.3.1", optional = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
image = { workspace = true }
@ -52,12 +51,11 @@ anyhow = { workspace = true }
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
metal = ["candle/metal", "candle-nn/metal", "candle-transformers/metal"]
cudnn = ["candle/cudnn"]
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
nccl = ["cuda", "cudarc/nccl", "dep:half"]
onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"]
[[example]]
name = "llama_multiprocess"
@ -66,11 +64,3 @@ required-features = ["cuda", "nccl", "flash-attn"]
[[example]]
name = "reinforcement-learning"
required-features = ["pyo3"]
[[example]]
name = "onnx"
required-features = ["onnx"]
[[example]]
name = "onnx_basics"
required-features = ["onnx"]

View File

@ -1,22 +0,0 @@
# candle-distilbert
DistilBert is a distilled version of the Bert model.
## Sentence embeddings
DistilBert is used to compute the sentence embeddings for a prompt. The model weights
are downloaded from the hub on the first run.
```bash
cargo run --example distilbert --release -- --prompt "Here is a test sentence"
> [[[ 0.5109, 0.1280, -0.2635, ..., 0.3462, -1.0434, 0.1441],
> [ 0.1735, 0.0818, -0.5549, ..., 0.3472, -0.8264, -0.0244],
> [ 0.0702, -0.1311, -0.4914, ..., 0.3483, -0.6194, 0.1829],
> ...
> [ 0.2993, -0.0106, -0.4640, ..., 0.2844, -0.6732, 0.0042],
> [ 0.1066, -0.0081, -0.4299, ..., 0.3435, -0.7729, 0.0190],
> [ 0.8903, 0.2055, -0.2541, ..., 0.3208, -0.6585, 0.0586]]]
> Tensor[[1, 7, 768], f32]
```

View File

@ -1,135 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle_transformers::models::distilbert::{Config, DistilBertModel, DTYPE};
use anyhow::{Error as E, Result};
use candle::{Device, Tensor};
use candle_nn::VarBuilder;
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
/// When set, compute embeddings for this prompt.
#[arg(long)]
prompt: String,
/// Use the pytorch weights rather than the safetensors ones
#[arg(long)]
use_pth: bool,
/// The number of times to run the prompt.
#[arg(long, default_value = "1")]
n: usize,
/// L2 normalization for embeddings.
#[arg(long, default_value = "true")]
normalize_embeddings: bool,
}
impl Args {
fn build_model_and_tokenizer(&self) -> Result<(DistilBertModel, Tokenizer)> {
let device = candle_examples::device(self.cpu)?;
let default_model = "distilbert-base-uncased".to_string();
let default_revision = "main".to_string();
let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
(Some(model_id), Some(revision)) => (model_id, revision),
(Some(model_id), None) => (model_id, "main".to_string()),
(None, Some(revision)) => (default_model, revision),
(None, None) => (default_model, default_revision),
};
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
let (config_filename, tokenizer_filename, weights_filename) = {
let api = Api::new()?;
let api = api.repo(repo);
let config = api.get("config.json")?;
let tokenizer = api.get("tokenizer.json")?;
let weights = if self.use_pth {
api.get("pytorch_model.bin")?
} else {
api.get("model.safetensors")?
};
(config, tokenizer, weights)
};
let config = std::fs::read_to_string(config_filename)?;
let config: Config = serde_json::from_str(&config)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let vb = if self.use_pth {
VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
} else {
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
};
let model = DistilBertModel::load(vb, &config)?;
Ok((model, tokenizer))
}
}
fn get_mask(size: usize, device: &Device) -> Tensor {
let mask: Vec<_> = (0..size)
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
.collect();
Tensor::from_slice(&mask, (size, size), device).unwrap()
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
println!("tracing...");
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
let device = &model.device;
let tokenizer = tokenizer
.with_padding(None)
.with_truncation(None)
.map_err(E::msg)?;
let tokens = tokenizer
.encode(args.prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
let mask = get_mask(tokens.len(), device);
println!("token_ids: {:?}", token_ids.to_vec2::<u32>());
println!("mask: {:?}", mask.to_vec2::<u8>());
let ys = model.forward(&token_ids, &mask)?;
println!("{ys}");
Ok(())
}
pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
}
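
A hedged sketch of how the `[1, 7, 768]` hidden states printed by this example could be reduced to a single sentence vector: mean-pool over the token dimension, then apply the same L2 normalization as the `normalize_l2` helper above. The `sentence_embedding` name and the pooling step are illustrative assumptions, not part of the example.

```rust
use anyhow::Result;
use candle::Tensor;

// Illustrative only: pool the per-token hidden states into one vector per
// prompt, then L2-normalize it, mirroring the `normalize_l2` helper above.
pub fn sentence_embedding(ys: &Tensor) -> Result<Tensor> {
    let pooled = ys.mean(1)?; // [1, seq_len, 768] -> [1, 768]
    let norm = pooled.sqr()?.sum_keepdim(1)?.sqrt()?;
    Ok(pooled.broadcast_div(&norm)?)
}
```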

View File

@ -8,7 +8,6 @@ use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder};
#[derive(Debug, Clone, PartialEq)]
enum NormType {
WeightNorm,
TimeGroupNorm,
None,
}
@ -269,7 +268,6 @@ impl Module for EncodecConvTranspose1d {
struct EncodecConv1d {
causal: bool,
conv: Conv1d,
norm: Option<candle_nn::GroupNorm>,
}
impl EncodecConv1d {
@ -294,7 +292,7 @@ impl EncodecConv1d {
},
vb.pp("conv"),
)?,
NormType::None | NormType::TimeGroupNorm => conv1d(
NormType::None => conv1d(
in_c,
out_c,
kernel_size,
@ -307,17 +305,9 @@ impl EncodecConv1d {
vb.pp("conv"),
)?,
};
let norm = match cfg.norm_type {
NormType::None | NormType::WeightNorm => None,
NormType::TimeGroupNorm => {
let gn = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
Some(gn)
}
};
Ok(Self {
causal: cfg.use_causal_conv,
conv,
norm,
})
}
}
@ -326,10 +316,8 @@ impl Module for EncodecConv1d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
// TODO: padding, depending on causal.
let xs = self.conv.forward(xs)?;
match &self.norm {
None => Ok(xs),
Some(norm) => xs.apply(norm),
}
// If we add support for NormType "time_group_norm", we should add some normalization here.
Ok(xs)
}
}

View File

@ -1,10 +0,0 @@
## Using ONNX models in Candle
This example demonstrates how to run ONNX-based models in Candle; the model
used here is a small SqueezeNet variant.
You can run the example with the following command:
```bash
cargo run --example squeezenet-onnx --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
```

View File

@ -1,78 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle::{IndexOp, D};
use clap::{Parser, ValueEnum};
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
SqueezeNet,
EfficientNet,
}
#[derive(Parser)]
struct Args {
#[arg(long)]
image: String,
#[arg(long)]
model: Option<String>,
/// The model to be used.
#[arg(value_enum, long, default_value_t = Which::SqueezeNet)]
which: Which,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let image = candle_examples::imagenet::load_image224(args.image)?;
let image = match args.which {
Which::SqueezeNet => image,
Which::EfficientNet => image.permute((1, 2, 0))?,
};
println!("loaded image {image:?}");
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
None => match args.which {
Which::SqueezeNet => hf_hub::api::sync::Api::new()?
.model("lmz/candle-onnx".into())
.get("squeezenet1.1-7.onnx")?,
Which::EfficientNet => hf_hub::api::sync::Api::new()?
.model("onnx/EfficientNet-Lite4".into())
.get("efficientnet-lite4-11.onnx")?,
},
};
let model = candle_onnx::read_file(model)?;
let graph = model.graph.as_ref().unwrap();
let mut inputs = std::collections::HashMap::new();
inputs.insert(graph.input[0].name.to_string(), image.unsqueeze(0)?);
let mut outputs = candle_onnx::simple_eval(&model, inputs)?;
let output = outputs.remove(&graph.output[0].name).unwrap();
let prs = match args.which {
Which::SqueezeNet => candle_nn::ops::softmax(&output, D::Minus1)?,
Which::EfficientNet => output,
};
let prs = prs.i(0)?.to_vec1::<f32>()?;
// Sort the predictions and take the top 5
let mut top: Vec<_> = prs.iter().enumerate().collect();
top.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
let top = top.into_iter().take(5).collect::<Vec<_>>();
// Print the top predictions
for &(i, p) in &top {
println!(
"{:50}: {:.2}%",
candle_examples::imagenet::CLASSES[i],
p * 100.0
);
}
Ok(())
}

View File

@ -1,87 +0,0 @@
use anyhow::Result;
use candle::{Device, Tensor};
use clap::{Parser, Subcommand};
#[derive(Subcommand, Debug, Clone)]
enum Command {
Print {
#[arg(long)]
file: String,
},
SimpleEval {
#[arg(long)]
file: String,
},
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
#[command(subcommand)]
command: Command,
}
pub fn main() -> Result<()> {
let args = Args::parse();
match args.command {
Command::Print { file } => {
let model = candle_onnx::read_file(file)?;
println!("{model:?}");
let graph = model.graph.unwrap();
for node in graph.node.iter() {
println!("{node:?}");
}
}
Command::SimpleEval { file } => {
let model = candle_onnx::read_file(file)?;
let graph = model.graph.as_ref().unwrap();
let constants: std::collections::HashSet<_> =
graph.initializer.iter().map(|i| i.name.as_str()).collect();
let mut inputs = std::collections::HashMap::new();
for input in graph.input.iter() {
use candle_onnx::onnx::tensor_proto::DataType;
if constants.contains(input.name.as_str()) {
continue;
}
let type_ = input.r#type.as_ref().expect("no type for input");
let type_ = type_.value.as_ref().expect("no type.value for input");
let value = match type_ {
candle_onnx::onnx::type_proto::Value::TensorType(tt) => {
let dt = match DataType::try_from(tt.elem_type) {
Ok(dt) => match candle_onnx::dtype(dt) {
Some(dt) => dt,
None => {
anyhow::bail!(
"unsupported 'value' data-type {dt:?} for {}",
input.name
)
}
},
type_ => anyhow::bail!("unsupported input type {type_:?}"),
};
let shape = tt.shape.as_ref().expect("no tensortype.shape for input");
let dims = shape
.dim
.iter()
.map(|dim| match dim.value.as_ref().expect("no dim value") {
candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize),
candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimParam(_) => Ok(42),
})
.collect::<Result<Vec<usize>>>()?;
Tensor::zeros(dims, dt, &Device::Cpu)?
}
type_ => anyhow::bail!("unsupported input type {type_:?}"),
};
println!("input {}: {value:?}", input.name);
inputs.insert(input.name.clone(), value);
}
let outputs = candle_onnx::simple_eval(&model, inputs)?;
for (name, value) in outputs.iter() {
println!("output {name}: {value:?}")
}
}
}
Ok(())
}

View File

@ -1,7 +1,5 @@
# candle-quantized-t5
## Seq2Seq example
This example uses a quantized version of the T5 model.
```bash
@ -10,8 +8,6 @@ $ cargo run --example quantized-t5 --release -- --prompt "translate to German: A
Eine schöne Kerze.
```
## Generating Quantized weight files
The weight file is automatically retrieved from the hub. It is also possible to
generate quantized weight files from the original safetensors file by using the
`tensor-tools` command line utility via:
@ -20,11 +16,8 @@ generate quantized weight files from the original safetensors file by using the
$ cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
```
## Using custom models
To use a different model, specify the `model-id`.
For example, for text editing, you can use quantized [CoEdit models](https://huggingface.co/jbochi/candle-coedit-quantized).
To use a different model, specify the `model-id`. For example, you can use
quantized [CoEdit models](https://huggingface.co/jbochi/candle-coedit-quantized).
```bash
$ cargo run --example quantized-t5 --release -- \
@ -33,7 +26,6 @@ $ cargo run --example quantized-t5 --release -- \
--temperature 0
...
Although their flight is weak, they run quickly through the tree canopy.
```
By default, it will look for `model.gguf` and `config.json`, but you can specify
custom local or remote `weight-file` and `config-file`s:
@ -48,16 +40,3 @@ cargo run --example quantized-t5 --release -- \
...
Note that a storm surge is what forecasters consider a hurricane's most dangerous part.
```
### [MADLAD-400](https://arxiv.org/abs/2309.04662)
MADLAD-400 is a series of multilingual machine translation T5 models trained on 250 billion tokens covering over 450 languages using publicly available data. These models are competitive with significantly larger models.
```bash
cargo run --example quantized-t5 --release -- \
--model-id "jbochi/madlad400-3b-mt" --weight-file "model-q4k.gguf" \
--prompt "<2de> How are you, my friend?" \
--temperature 0
...
Wie geht es dir, mein Freund?
```

View File

@ -173,11 +173,7 @@ fn main() -> Result<()> {
.to_vec();
let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
let mut model = builder.build_model()?;
let mut output_token_ids = [builder
.config
.decoder_start_token_id
.unwrap_or(builder.config.pad_token_id) as u32]
.to_vec();
let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
let temperature = if args.temperature <= 0. {
None
} else {

View File

@ -9,10 +9,9 @@ use std::io::Write;
use tokenizers::Tokenizer;
use candle::quantized::{ggml_file, gguf_file};
use candle::{Device, Tensor};
use candle::Tensor;
use candle_transformers::generation::LogitsProcessor;
use candle_examples::token_output_stream::TokenOutputStream;
use candle_transformers::models::quantized_llama as model;
use model::ModelWeights;
@ -25,7 +24,7 @@ enum Prompt {
One(String),
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
#[value(name = "7b")]
L7b,
@ -49,12 +48,8 @@ enum Which {
Mistral7b,
#[value(name = "7b-mistral-instruct")]
Mistral7bInstruct,
#[value(name = "7b-zephyr-a")]
Zephyr7bAlpha,
#[value(name = "7b-zephyr-b")]
Zephyr7bBeta,
#[value(name = "7b-open-chat-3.5")]
OpenChat35,
#[value(name = "7b-zephyr")]
Zephyr7b,
}
impl Which {
@ -69,50 +64,7 @@ impl Which {
| Self::L7bCode
| Self::L13bCode
| Self::L34bCode => false,
// Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
// same way.
Self::OpenChat35
| Self::Zephyr7bAlpha
| Self::Zephyr7bBeta
| Self::Mistral7b
| Self::Mistral7bInstruct => true,
}
}
fn is_zephyr(&self) -> bool {
match self {
Self::L7b
| Self::L13b
| Self::L70b
| Self::L7bChat
| Self::L13bChat
| Self::L70bChat
| Self::L7bCode
| Self::L13bCode
| Self::L34bCode
| Self::Mistral7b
| Self::Mistral7bInstruct
| Self::OpenChat35 => false,
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
}
}
fn is_open_chat(&self) -> bool {
match self {
Which::L7b
| Which::L13b
| Which::L70b
| Which::L7bChat
| Which::L13bChat
| Which::L70bChat
| Which::L7bCode
| Which::L13bCode
| Which::L34bCode
| Which::Mistral7b
| Which::Mistral7bInstruct
| Which::Zephyr7bAlpha
| Which::Zephyr7bBeta => false,
Which::OpenChat35 => true,
Self::Mistral7b | Self::Mistral7bInstruct | Self::Zephyr7b => true,
}
}
}
@ -131,7 +83,7 @@ struct Args {
prompt: Option<String>,
/// The length of the sample to generate (in tokens).
#[arg(short = 'n', long, default_value_t = 1000)]
#[arg(short = 'n', long, default_value_t = 100)]
sample_len: usize,
/// The tokenizer config in json format.
@ -181,9 +133,7 @@ impl Args {
Some(config) => std::path::PathBuf::from(config),
None => {
let api = hf_hub::api::sync::Api::new()?;
let repo = if self.which.is_open_chat() {
"openchat/openchat_3.5"
} else if self.which.is_mistral() {
let repo = if self.which.is_mistral() {
"mistralai/Mistral-7B-v0.1"
} else {
"hf-internal-testing/llama-tokenizer"
@ -226,14 +176,10 @@ impl Args {
"TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
"mistral-7b-instruct-v0.1.Q4_K_S.gguf",
),
Which::Zephyr7bAlpha => (
Which::Zephyr7b => (
"TheBloke/zephyr-7B-alpha-GGUF",
"zephyr-7b-alpha.Q4_K_M.gguf",
),
Which::Zephyr7bBeta => {
("TheBloke/zephyr-7B-beta-GGUF", "zephyr-7b-beta.Q4_K_M.gguf")
}
Which::OpenChat35 => ("TheBloke/openchat_3.5-GGUF", "openchat_3.5.Q4_K_M.gguf"),
};
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(repo.to_string());
@ -244,6 +190,31 @@ impl Args {
}
}
fn print_token(next_token: u32, tokenizer: &Tokenizer) {
// Extracting the last token as a string is complicated, here we just apply some simple
// heuristics as it seems to work well enough for this example. See the following for more
// details:
// https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
if let Some(text) = tokenizer.id_to_token(next_token) {
let text = text.replace('▁', " ");
let ascii = text
.strip_prefix("<0x")
.and_then(|t| t.strip_suffix('>'))
.and_then(|t| u8::from_str_radix(t, 16).ok());
match ascii {
None => print!("{text}"),
Some(ascii) => {
if let Some(chr) = char::from_u32(ascii as u32) {
if chr.is_ascii() {
print!("{chr}")
}
}
}
}
let _ = std::io::stdout().flush();
}
}
fn format_size(size_in_bytes: usize) -> String {
if size_in_bytes < 1_000 {
format!("{}B", size_in_bytes)
@ -261,11 +232,13 @@ fn main() -> anyhow::Result<()> {
use tracing_subscriber::prelude::*;
let args = Args::parse();
let device = candle_examples::device(false)?;
let temperature = if args.temperature == 0. {
None
} else {
Some(args.temperature)
};
tracing_subscriber::fmt::init();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
@ -305,10 +278,10 @@ fn main() -> anyhow::Result<()> {
&format_size(total_size_in_bytes),
start.elapsed().as_secs_f32(),
);
ModelWeights::from_gguf(model, &mut file)?
ModelWeights::from_gguf(model, &mut file, &device)?
}
Some("ggml" | "bin") | Some(_) | None => {
let model = ggml_file::Content::read(&mut file)?;
let model = ggml_file::Content::read(&mut file, &device)?;
let mut total_size_in_bytes = 0;
for (_, tensor) in model.tensors.iter() {
let elem_count = tensor.shape().elem_count();
@ -332,19 +305,16 @@ fn main() -> anyhow::Result<()> {
| Which::L34bCode => 1,
Which::Mistral7b
| Which::Mistral7bInstruct
| Which::Zephyr7bAlpha
| Which::Zephyr7bBeta
| Which::Zephyr7b
| Which::L70b
| Which::L70bChat
| Which::OpenChat35 => 8,
| Which::L70bChat => 8,
};
ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa), &device)?
}
};
println!("model built");
let tokenizer = args.tokenizer()?;
let mut tos = TokenOutputStream::new(tokenizer);
let prompt = match args.prompt.as_deref() {
Some("chat") => Prompt::Chat,
Some("interactive") => Prompt::Interactive,
@ -353,11 +323,10 @@ fn main() -> anyhow::Result<()> {
};
let mut pre_prompt_tokens = vec![];
for prompt_index in 0.. {
loop {
let prompt_str = match &prompt {
Prompt::One(prompt) => prompt.clone(),
Prompt::Interactive | Prompt::Chat => {
let is_interactive = matches!(prompt, Prompt::Interactive);
print!("> ");
std::io::stdout().flush()?;
let mut prompt = String::new();
@ -368,15 +337,7 @@ fn main() -> anyhow::Result<()> {
prompt.pop();
}
}
if args.which.is_open_chat() {
format!("User: {prompt}<|end_of_turn|>Assistant: ")
} else if args.which.is_zephyr() {
if prompt_index == 0 || is_interactive {
format!("<|system|>\n</s>\n<|user|>\n{prompt}</s>\n<|assistant|>",)
} else {
format!("<|user|>\n{prompt}</s>\n<|assistant|>")
}
} else if args.which.is_mistral() {
if args.which.is_mistral() {
format!("[INST] {prompt} [/INST]")
} else {
prompt
@ -384,8 +345,7 @@ fn main() -> anyhow::Result<()> {
}
};
print!("{}", &prompt_str);
let tokens = tos
.tokenizer()
let tokens = tokenizer
.encode(prompt_str, true)
.map_err(anyhow::Error::msg)?;
if args.verbose_prompt {
@ -408,28 +368,23 @@ fn main() -> anyhow::Result<()> {
let start_prompt_processing = std::time::Instant::now();
let mut next_token = {
let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
let logits = model.forward(&input, 0)?;
let logits = logits.squeeze(0)?;
logits_processor.sample(&logits)?
// TODO Remove this once implementation is finished.
let logits = logits.ones_like()?;
// logits_processor.sample(&logits)?
15043
};
let prompt_dt = start_prompt_processing.elapsed();
all_tokens.push(next_token);
if let Some(t) = tos.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
print_token(next_token, &tokenizer);
let eos_token = *tokenizer.get_vocab(true).get("</s>").unwrap();
let eos_token = if args.which.is_open_chat() {
"<|end_of_turn|>"
} else {
"</s>"
};
let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap();
let start_post_prompt = std::time::Instant::now();
let mut sampled = 0;
for index in 0..to_sample {
let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
let logits = model.forward(&input, prompt_tokens.len() + index)?;
let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
@ -442,21 +397,16 @@ fn main() -> anyhow::Result<()> {
&all_tokens[start_at..],
)?
};
next_token = logits_processor.sample(&logits)?;
// TODO Remove this once implementation is finished.
// let logits = logits.ones_like()?;
// next_token = logits_processor.sample(&logits)?;
let next_token = 15043;
all_tokens.push(next_token);
if let Some(t) = tos.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
sampled += 1;
print_token(next_token, &tokenizer);
if next_token == eos_token {
break;
};
}
if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
let dt = start_post_prompt.elapsed();
println!(
"\n\n{:4} prompt tokens processed: {:.2} token/s",
@ -464,8 +414,9 @@ fn main() -> anyhow::Result<()> {
prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
);
println!(
"{sampled:4} tokens generated: {:.2} token/s",
sampled as f64 / dt.as_secs_f64(),
"{:4} tokens generated: {:.2} token/s",
to_sample,
to_sample as f64 / dt.as_secs_f64(),
);
match prompt {

View File

@ -416,7 +416,7 @@ fn run(args: Args) -> Result<()> {
println!("Building the autoencoder.");
let vae_weights = ModelFile::Vae.get(vae_weights, sd_version, use_f16)?;
let vae = sd_config.build_vae(vae_weights, &device, dtype)?;
let vae = sd_config.build_vae(&vae_weights, &device, dtype)?;
let init_latent_dist = match &img2img {
None => None,
Some(image) => {
@ -426,7 +426,7 @@ fn run(args: Args) -> Result<()> {
};
println!("Building the unet.");
let unet_weights = ModelFile::Unet.get(unet_weights, sd_version, use_f16)?;
let unet = sd_config.build_unet(unet_weights, &device, 4, use_flash_attn, dtype)?;
let unet = sd_config.build_unet(&unet_weights, &device, 4, use_flash_attn, dtype)?;
let t_start = if img2img.is_some() {
n_steps - (n_steps as f64 * img2img_strength) as usize

View File

@ -5,26 +5,12 @@
```bash
$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "translate to German: A beautiful candle." --decode
...
Running on CPU, to run on GPU, build this example with `--features cuda`
Eine schöne Kerze.
9 tokens generated (2.42 token/s)
```
Variants such as [flan-t5](https://huggingface.co/google/flan-t5-small), [flan-ul2](https://huggingface.co/google/flan-ul2) (with `--revision "refs/pr/25"`), and [Co-EdIT](https://huggingface.co/grammarly/coedit-large) are also supported.
## Translation with [MADLAD-400](https://arxiv.org/abs/2309.04662)
MADLAD-400 is a series of multilingual machine translation T5 models trained on 250 billion tokens covering over 450 languages using publicly available data. These models are competitive with significantly larger models.
```bash
cargo run --example t5 --release -- \
--model-id "jbochi/madlad400-3b-mt" \
--prompt "<2de> How are you, my friend?" \
--decode --temperature 0
...
Wie geht es dir, mein Freund?
```
## Sentence embedding example
## Sentence embedding example:
```bash
$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "A beautiful candle."

View File

@ -104,17 +104,6 @@ impl T5ModelBuilder {
api.get("model-00004-of-00005.safetensors")?,
api.get("model-00005-of-00005.safetensors")?,
]
} else if model_id == "google/flan-ul2" {
vec![
api.get("model-00001-of-00008.safetensors")?,
api.get("model-00002-of-00008.safetensors")?,
api.get("model-00003-of-00008.safetensors")?,
api.get("model-00004-of-00008.safetensors")?,
api.get("model-00005-of-00008.safetensors")?,
api.get("model-00006-of-00008.safetensors")?,
api.get("model-00007-of-00008.safetensors")?,
api.get("model-00008-of-00008.safetensors")?,
]
} else {
vec![api.get("model.safetensors")?]
};
@ -183,12 +172,7 @@ fn main() -> Result<()> {
println!("Took {:?}", start.elapsed());
} else {
let mut model = builder.build_conditional_generation()?;
let mut output_token_ids = [builder
.config
.decoder_start_token_id
.unwrap_or(builder.config.pad_token_id)
as u32]
.to_vec();
let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
if let Some(decoder_prompt) = &args.decoder_prompt {
print!("{decoder_prompt}");
output_token_ids.extend(

Binary file not shown.


View File

@ -1,154 +0,0 @@
use image::{DynamicImage, ImageBuffer};
use serde::Deserialize;
use std::collections::HashMap;
use candle::{DType, Device, Result, Tensor};
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct ProcessorConfig {
do_resize: bool,
height: u32,
width: u32,
do_rescale: bool,
do_normalize: bool,
image_mean: Vec<f32>,
image_std: Vec<f32>,
}
impl Default for ProcessorConfig {
fn default() -> Self {
Self {
do_resize: true,
height: 384,
width: 384,
do_rescale: true,
do_normalize: true,
image_mean: vec![0.5, 0.5, 0.5],
image_std: vec![0.5, 0.5, 0.5],
}
}
}
pub struct ViTImageProcessor {
do_resize: bool,
height: u32,
width: u32,
do_normalize: bool,
image_mean: Vec<f32>,
image_std: Vec<f32>,
}
impl ViTImageProcessor {
pub fn new(config: &ProcessorConfig) -> Self {
Self {
do_resize: config.do_resize,
height: config.height,
width: config.width,
do_normalize: config.do_normalize,
image_mean: config.image_mean.clone(),
image_std: config.image_std.clone(),
}
}
pub fn preprocess(&self, images: Vec<&str>) -> Result<Tensor> {
let height = self.height as usize;
let width = self.width as usize;
let channels = 3;
let images = self.load_images(images)?;
let resized_images: Vec<DynamicImage> = if self.do_resize {
images
.iter()
.map(|image| self.resize(image.clone(), None).unwrap())
.collect()
} else {
images
};
let normalized_images: Vec<Tensor> = if self.do_normalize {
resized_images
.iter()
.map(|image| self.normalize(image.clone(), None, None).unwrap())
.collect()
} else {
let resized_images: Vec<ImageBuffer<image::Rgb<u8>, Vec<u8>>> =
resized_images.iter().map(|image| image.to_rgb8()).collect();
let data = resized_images
.into_iter()
.map(|image| image.into_raw())
.collect::<Vec<Vec<u8>>>();
data.iter()
.map(|image| {
Tensor::from_vec(image.clone(), (height, width, channels), &Device::Cpu)
.unwrap()
.permute((2, 0, 1))
.unwrap()
})
.collect::<Vec<Tensor>>()
};
Tensor::stack(&normalized_images, 0)
}
fn resize(
&self,
image: image::DynamicImage,
size: Option<HashMap<String, u32>>,
) -> Result<image::DynamicImage> {
let (height, width) = match &size {
Some(size) => (size.get("height").unwrap(), size.get("width").unwrap()),
None => (&self.height, &self.width),
};
let resized_image =
image.resize_exact(*width, *height, image::imageops::FilterType::Triangle);
Ok(resized_image)
}
fn normalize(
&self,
image: image::DynamicImage,
mean: Option<Vec<f32>>,
std: Option<Vec<f32>>,
) -> Result<Tensor> {
let mean = match mean {
Some(mean) => mean,
None => self.image_mean.clone(),
};
let std = match std {
Some(std) => std,
None => self.image_std.clone(),
};
let mean = Tensor::from_vec(mean, (3, 1, 1), &Device::Cpu)?;
let std = Tensor::from_vec(std, (3, 1, 1), &Device::Cpu)?;
let image = image.to_rgb8();
let data = image.into_raw();
let height = self.height as usize;
let width = self.width as usize;
let channels = 3;
let data =
Tensor::from_vec(data, &[height, width, channels], &Device::Cpu)?.permute((2, 0, 1))?;
(data.to_dtype(DType::F32)? / 255.)?
.broadcast_sub(&mean)?
.broadcast_div(&std)
}
pub fn load_images(&self, image_path: Vec<&str>) -> Result<Vec<image::DynamicImage>> {
let mut images: Vec<image::DynamicImage> = Vec::new();
for path in image_path {
let img = image::io::Reader::open(path)?.decode().unwrap();
images.push(img);
}
Ok(images)
}
}

View File

@ -1,132 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Error as E;
use clap::{Parser, ValueEnum};
use candle::{DType, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::models::trocr;
use tokenizers::Tokenizer;
mod image_processor;
#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
Base,
Large,
}
#[derive(Parser, Debug)]
struct Args {
#[arg(long)]
model: Option<String>,
/// Choose the variant of the model to run.
#[arg(long, default_value = "base")]
which: Which,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
    /// The image to be transcribed.
#[arg(long)]
image: String,
}
pub fn main() -> anyhow::Result<()> {
use hf_hub::api::sync::Api;
let args = Args::parse();
let tokenizer_dec = {
let tokenizer = Api::new()?
.model(String::from("ToluClassics/candle-trocr-tokenizer"))
.get("tokenizer.json")?;
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
};
let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);
let device = candle_examples::device(args.cpu)?;
let vb = {
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
None => match args.which {
Which::Base => Api::new()?
.repo(hf_hub::Repo::with_revision(
"microsoft/trocr-base-handwritten".to_string(),
hf_hub::RepoType::Model,
"refs/pr/3".to_string(),
))
.get("model.safetensors")?,
Which::Large => Api::new()?
.repo(hf_hub::Repo::with_revision(
"microsoft/trocr-large-handwritten".to_string(),
hf_hub::RepoType::Model,
"refs/pr/6".to_string(),
))
.get("model.safetensors")?,
},
};
println!("model: {:?}", model);
unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? }
};
let encoder_config = match args.which {
Which::Base => candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten(),
Which::Large => {
candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten()
}
};
let decoder_config = trocr::TrOCRConfig::default();
let mut model = trocr::TrOCRModel::new(&encoder_config, &decoder_config, vb)?;
let config = image_processor::ProcessorConfig::default();
let processor = image_processor::ViTImageProcessor::new(&config);
let image = vec![args.image.as_str()];
let image = processor.preprocess(image)?;
let encoder_xs = model.encoder().forward(&image)?;
let mut logits_processor =
candle_transformers::generation::LogitsProcessor::new(1337, None, None);
let mut token_ids: Vec<u32> = vec![decoder_config.decoder_start_token_id];
for index in 0..1000 {
let context_size = if index >= 1 { 1 } else { token_ids.len() };
let start_pos = token_ids.len().saturating_sub(context_size);
let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
let logits = model.decode(&input_ids, &encoder_xs, start_pos)?;
let logits = logits.squeeze(0)?;
let logits = logits.get(logits.dim(0)? - 1)?;
let token = logits_processor.sample(&logits)?;
token_ids.push(token);
if let Some(t) = tokenizer_dec.next_token(token)? {
use std::io::Write;
print!("{t}");
std::io::stdout().flush()?;
}
if token == decoder_config.eos_token_id {
break;
}
}
if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
println!();
Ok(())
}

View File

@ -1,16 +0,0 @@
# candle-trocr
`TrOCR` is a transformer-based OCR model. In this example it is used to
transcribe image text. See the associated [model
card](https://huggingface.co/microsoft/trocr-base-printed) for details on
the model itself.
## Running an example
```bash
cargo run --example trocr --release -- --which base --cpu --image candle-examples/examples/trocr/assets/trocr.png
```
```
<s> industry , Mr. Brown commented icily . " Let us have a</s>
```

View File

@ -128,13 +128,7 @@ impl Decoder {
let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
let no_speech_token = m::NO_SPEECH_TOKENS
.iter()
.find_map(|token| token_id(&tokenizer, token).ok());
let no_speech_token = match no_speech_token {
None => anyhow::bail!("unable to find any non-speech token"),
Some(n) => n,
};
let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?;
Ok(Self {
model,
rng: rand::rngs::StdRng::seed_from_u64(seed),
@ -351,7 +345,7 @@ enum Task {
Translate,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)]
#[derive(Clone, Copy, Debug, ValueEnum)]
enum WhichModel {
Tiny,
#[value(name = "tiny.en")]
@ -367,27 +361,15 @@ enum WhichModel {
MediumEn,
Large,
LargeV2,
LargeV3,
#[value(name = "distil-medium.en")]
DistilMediumEn,
#[value(name = "distil-large-v2")]
DistilLargeV2,
}
impl WhichModel {
fn is_multilingual(&self) -> bool {
match self {
Self::Tiny
| Self::Base
| Self::Small
| Self::Medium
| Self::Large
| Self::LargeV2
| Self::LargeV3
| Self::DistilLargeV2 => true,
Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn | Self::DistilMediumEn => {
false
Self::Tiny | Self::Base | Self::Small | Self::Medium | Self::Large | Self::LargeV2 => {
true
}
Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn => false,
}
}
@ -403,9 +385,6 @@ impl WhichModel {
Self::MediumEn => ("openai/whisper-medium.en", "main"),
Self::Large => ("openai/whisper-large", "refs/pr/36"),
Self::LargeV2 => ("openai/whisper-large-v2", "refs/pr/57"),
Self::LargeV3 => ("openai/whisper-large-v3", "main"),
Self::DistilMediumEn => ("distil-whisper/distil-medium.en", "main"),
Self::DistilLargeV2 => ("distil-whisper/distil-large-v2", "main"),
}
}
}
@ -517,21 +496,17 @@ fn main() -> Result<()> {
repo.get(&format!("model-{ext}-q80.gguf"))?,
)
} else {
let config = repo.get("config.json")?;
let tokenizer = repo.get("tokenizer.json")?;
let model = repo.get("model.safetensors")?;
(config, tokenizer, model)
(
repo.get("config.json")?,
repo.get("tokenizer.json")?,
repo.get("model.safetensors")?,
)
};
(config, tokenizer, model, sample)
};
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let mel_bytes = match config.num_mel_bins {
80 => include_bytes!("melfilters.bytes").as_slice(),
128 => include_bytes!("melfilters128.bytes").as_slice(),
nmel => anyhow::bail!("unexpected num_mel_bins {nmel}"),
};
let mel_bytes = include_bytes!("melfilters.bytes");
let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
<byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
@ -547,15 +522,12 @@ fn main() -> Result<()> {
.map(|v| *v as f32 / 32768.)
.collect();
println!("pcm data loaded {}", pcm_data.len());
let mel = audio::pcm_to_mel(&config, &pcm_data, &mel_filters);
let mel = audio::pcm_to_mel(&pcm_data, &mel_filters);
let mel_len = mel.len();
let mel = Tensor::from_vec(
mel,
(1, config.num_mel_bins, mel_len / config.num_mel_bins),
&device,
)?;
let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
println!("loaded mel: {:?}", mel.dims());
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
let mut model = if args.quantized {
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&weights_filename)?;

View File

@ -1,268 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};
use candle_transformers::models::yi::{Config, Model};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
#[value(name = "6b")]
L6b,
#[value(name = "34b")]
L34b,
}
struct TextGeneration {
model: Model,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <|endoftext|> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize,
#[arg(long, default_value = "01-ai/Yi-6B")]
model_id: String,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
/// The model size to use.
#[arg(long, default_value = "6b")]
which: Which,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
args.model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => match args.which {
Which::L6b => vec![
repo.get("model-00001-of-00002.safetensors")?,
repo.get("model-00002-of-00002.safetensors")?,
],
Which::L34b => vec![
repo.get("model-00001-of-00007.safetensors")?,
repo.get("model-00002-of-00007.safetensors")?,
repo.get("model-00003-of-00007.safetensors")?,
repo.get("model-00004-of-00007.safetensors")?,
repo.get("model-00005-of-00007.safetensors")?,
repo.get("model-00006-of-00007.safetensors")?,
repo.get("model-00007-of-00007.safetensors")?,
],
},
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = match args.which {
Which::L6b => Config::config_6b(),
Which::L34b => Config::config_34b(),
};
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -43,7 +43,6 @@ pub fn report(
confidence_threshold: f32,
nms_threshold: f32,
) -> Result<DynamicImage> {
let pred = pred.to_device(&Device::Cpu)?;
let (npreds, pred_size) = pred.dims2()?;
let nclasses = pred_size - 5;
// The bounding boxes grouped by (maximum) class index.

View File

@ -32,7 +32,7 @@ Image source:
### Pose Estimation
```bash
cargo run --example yolo-v8 --release -- \
candle-examples/examples/yolo-v8/assets/bike.jpg --task pose
candle-examples/examples/yolo-v8/assets/peoples.jpeg --task pose
```
![Leading group, Giro d'Italia 2021](./assets/bike.pose.jpg)

View File

@ -7,7 +7,7 @@ extern crate accelerate_src;
mod model;
use model::{Multiples, YoloV8, YoloV8Pose};
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle::{DType, IndexOp, Result, Tensor};
use candle_nn::{Module, VarBuilder};
use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
use clap::{Parser, ValueEnum};
@ -61,7 +61,6 @@ pub fn report_detect(
nms_threshold: f32,
legend_size: u32,
) -> Result<DynamicImage> {
let pred = pred.to_device(&Device::Cpu)?;
let (pred_size, npreds) = pred.dims2()?;
let nclasses = pred_size - 4;
// The bounding boxes grouped by (maximum) class index.
@ -154,7 +153,6 @@ pub fn report_pose(
confidence_threshold: f32,
nms_threshold: f32,
) -> Result<DynamicImage> {
let pred = pred.to_device(&Device::Cpu)?;
let (pred_size, npreds) = pred.dims2()?;
if pred_size != 17 * 3 + 4 + 1 {
candle::bail!("unexpected pred-size {pred_size}");

View File

@ -8,22 +8,24 @@ use candle::{Device, Result, Tensor};
pub fn device(cpu: bool) -> Result<Device> {
if cpu {
Ok(Device::Cpu)
} else if cuda_is_available() {
Ok(Device::new_cuda(0)?)
} else if metal_is_available() {
Ok(Device::new_metal(0)?)
} else {
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
{
println!(
"Running on CPU, to run on GPU(metal), build this example with `--features metal`"
);
if cuda_is_available() {
Ok(Device::new_cuda(0)?)
} else if metal_is_available() {
Ok(Device::new_metal(0)?)
} else {
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
{
println!("Running on CPU, to run on GPU(metal), build this example with `--features metal`");
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
{
println!(
"Running on CPU, to run on GPU, build this example with `--features cuda`"
);
}
Ok(Device::Cpu)
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
{
println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
}
Ok(Device::Cpu)
}
}
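
For context, a minimal usage sketch of this helper (assuming a binary that depends on the `candle-examples` crate and exposes a `--cpu` flag, as the examples above do): the device is picked once and then threaded through tensor creation.

```rust
use candle::{Result, Tensor};

// Minimal sketch: pick the device once from a `cpu` flag, then allocate on it.
fn make_input(cpu: bool) -> Result<Tensor> {
    let device = candle_examples::device(cpu)?;
    Tensor::new(&[1u32, 2, 3], &device)
}
```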

View File

@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.3.1"
version = "0.3.0"
edition = "2021"
description = "Flash attention layer for the candle ML framework."
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"
[dependencies]
candle = { path = "../candle-core", features = ["cuda"], version = "0.3.1", package = "candle-core" }
candle = { path = "../candle-core", features = ["cuda"], version = "0.3.0", package = "candle-core" }
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]
@ -21,4 +21,4 @@ rayon = "1.7.0"
[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
candle-nn = { path = "../candle-nn", version = "0.3.1", features = ["cuda"] }
candle-nn = { path = "../candle-nn", version = "0.3.0", features = ["cuda"] }

View File

@ -233,8 +233,8 @@ impl FlashAttnVarLen {
let (seqlens_q, seqlens_q_layout) = self.seqlens_q.storage_and_layout();
let seqlens_q = match &*seqlens_q {
candle::Storage::Cpu(_) => candle::bail!("seqlens_q must be a cuda tensor"),
candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?, // Should be i32!
_ => candle::bail!("seqlens_q must be a cuda tensor"),
};
let seqlens_q = match seqlens_q_layout.contiguous_offsets() {
Some((o1, o2)) => seqlens_q.slice(o1..o2),
@ -243,8 +243,8 @@ impl FlashAttnVarLen {
let (seqlens_k, seqlens_k_layout) = self.seqlens_k.storage_and_layout();
let seqlens_k = match &*seqlens_k {
candle::Storage::Cpu(_) => candle::bail!("seqlens_k must be a cuda tensor"),
candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?, // Should be i32!
_ => candle::bail!("seqlens_k must be a cuda tensor"),
};
let seqlens_k = match seqlens_k_layout.contiguous_offsets() {
Some((o1, o2)) => seqlens_k.slice(o1..o2),

View File

@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.3.1"
version = "0.3.0"
edition = "2021"
description = "CUDA kernels for Candle"
@ -14,4 +14,4 @@ license = "MIT OR Apache-2.0"
[build-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
glob = "0.3.1"
rayon = "1.7.0"
rayon = "1.7.0"

View File

@ -1,21 +1,17 @@
[package]
name = "candle-metal-kernels"
version = "0.3.1"
edition = "2021"
description = "Metal kernels for Candle"
repository = "https://github.com/huggingface/candle"
keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT OR Apache-2.0"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true
[dependencies]
metal = { version = "0.27.1", features = ["mps"], package="candle-metal" }
metal-flash-attention = { path = "../../../metal-flash-attention" }
metal = { workspace = true }
once_cell = "1.18.0"
thiserror = "1"
tracing = "0.1.37"
thiserror = { workspace = true }
[dev-dependencies]
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
rand = "0.8.5"
half = { workspace = true }

View File

@ -24,32 +24,15 @@ kernel void FN_NAME( \
constant float &add, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint id [[ thread_position_in_grid ]] \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
if (id >= dim) { \
return; \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
output[i] = input[i] * mul + add; \
} \
const TYPENAME m = TYPENAME(mul); \
const TYPENAME a = TYPENAME(add); \
output[id] = input[id] * m + a; \
} \
kernel void FN_NAME##_strided( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant float &mul, \
constant float &add, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint id [[ thread_position_in_grid ]] \
) { \
if (id >= dim) { \
return; \
} \
const TYPENAME m = TYPENAME(mul); \
const TYPENAME a = TYPENAME(add); \
output[id] = input[get_strided_index(id, num_dims, dims, strides)] * m + a; \
} \
AFFINE(affine_float, float)
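
The reworked kernel above splits the `dim` elements evenly across the threads of one threadgroup: each thread handles a contiguous chunk of `ceil(dim / threadgroup_size)` elements. Below is a minimal host-side Rust sketch of that partitioning; the function name `affine_chunked` is purely illustrative and not part of the crate.

```
// Illustrative CPU model of the per-thread chunking used by the affine kernel above.
// Each "thread" processes a contiguous slice of ceil(dim / threadgroup_size) elements.
fn affine_chunked(input: &[f32], mul: f32, add: f32, threadgroup_size: usize) -> Vec<f32> {
    let dim = input.len();
    let length = (dim + threadgroup_size - 1) / threadgroup_size; // elements per thread
    let mut output = vec![0.0f32; dim];
    for thread_index in 0..threadgroup_size {
        let start = thread_index * length;
        let stop = (start + length).min(dim);
        for i in start..stop {
            output[i] = input[i] * mul + add;
        }
    }
    output
}

fn main() {
    // With dim = 10 and 4 threads, the chunks are [0..3), [3..6), [6..9), [9..10).
    let v: Vec<f32> = (0..10).map(|x| x as f32).collect();
    println!("{:?}", affine_chunked(&v, 1.5, 1.1, 4));
}
```

Note that this only models the threads of a single threadgroup; how the grid is dispatched is decided on the Rust side when the kernel is encoded.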

View File

@ -23,14 +23,17 @@ kernel void FN_NAME( \
device const TYPENAME *left, \
device const TYPENAME *right, \
device TYPENAME *output, \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
if (thread_position_in_grid >= dim) { \
return; \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
TYPENAME x = left[i]; \
TYPENAME y = right[i]; \
output[i] = OUT_TYPENAME(FN); \
} \
TYPENAME x = left[thread_position_in_grid]; \
TYPENAME y = right[thread_position_in_grid]; \
output[thread_position_in_grid] = OUT_TYPENAME(FN); \
}\
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
@ -41,14 +44,17 @@ kernel void FN_NAME_STRIDED( \
device const TYPENAME *left, \
device const TYPENAME *right, \
device TYPENAME *output, \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
if (thread_position_in_grid >= dim) { \
return; \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
TYPENAME x = left[get_strided_index(i, num_dims, dims, left_strides)]; \
TYPENAME y = right[get_strided_index(i, num_dims, dims, right_strides)]; \
output[i] = OUT_TYPENAME(FN); \
} \
TYPENAME x = left[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \
TYPENAME y = right[get_strided_index(thread_position_in_grid, num_dims, dims, right_strides)]; \
output[thread_position_in_grid] = OUT_TYPENAME(FN); \
}
#define BINARY_OP(FN, NAME) \

View File

@ -23,12 +23,15 @@ kernel void FN_NAME( \
constant size_t &dim, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
if (tid >= dim) { \
return; \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
output[i] = RIGHT_TYPENAME(input[i]); \
} \
output[tid] = RIGHT_TYPENAME(input[tid]); \
} \
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
@ -37,19 +40,19 @@ kernel void FN_NAME_STRIDED( \
constant size_t *strides, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
if (tid >= dim) { \
return; \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \
} \
output[tid] = RIGHT_TYPENAME(input[get_strided_index(tid, num_dims, dims, strides)]); \
} \
}
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
#if __METAL_VERSION__ >= 310
#endif

View File

@ -1,36 +1,6 @@
#include <metal_stdlib>
using namespace metal;
# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
constant size_t &dst_size, \
constant size_t &left_size, \
constant size_t &src_dim_size, \
constant size_t &right_size, \
constant size_t &ids_size, \
const device TYPENAME *input, \
const device INDEX_TYPENAME *input_ids, \
device TYPENAME *output, \
uint gid [[ thread_position_in_grid ]] \
) { \
if (gid >= dst_size) { \
return; \
} \
const size_t id_i = (gid / right_size) % ids_size; \
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
const size_t right_rank_i = gid % right_size; \
const size_t left_rank_i = gid / right_size / ids_size; \
/* \
// Clamp the index to force-prevent out-of-bounds reads, \
// since there doesn't seem to be a good way to force a crash from a kernel. \
// No need to check the lower bound since only unsigned index types are allowed. \
*/ \
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; \
output[gid] = input[src_i]; \
}
template <typename T, typename I>
void index_add(
device I *ids [[buffer(0)]],
@ -42,9 +12,12 @@ void index_add(
constant uint &dst_dim_size,
constant uint &right_size,
uint gid [[ thread_position_in_grid ]] \
uint threadgroup_size [[threads_per_threadgroup]],
uint threadgroup_position_in_grid [[threadgroup_position_in_grid]],
uint thread_index [[thread_index_in_threadgroup]]
) {
const uint gid = thread_index + (threadgroup_position_in_grid * threadgroup_size);
if (gid >= left_size * right_size) {
return;
}
@ -70,13 +43,12 @@ kernel void FN_NAME( \
constant uint &left_size, \
constant uint &dst_dim_size, \
constant uint &right_size, \
uint gid [[ thread_position_in_grid ]] \
) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, gid); } \
uint threadgroup_size [[threads_per_threadgroup]], \
uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, threadgroup_size, threadgroup_position_in_grid, thread_index); } \
INDEX_OP(is_u32_f32, uint, float)
INDEX_OP(is_u32_f16, uint, half)
#if __METAL_VERSION__ >= 310
IA_OP(bfloat, int64_t, ia_i64_bf16)
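
For reference, the index arithmetic in `INDEX_OP` above treats the destination as a `[left_size, ids_size, right_size]` block and maps each output element `gid` back to a source element of shape `[left_size, src_dim_size, right_size]`. A small Rust sketch of that decomposition follows; the `index_select` helper here is an illustrative CPU model, not the crate's API.

```
// Illustrative CPU model of the INDEX_OP gather above.
// dst has shape [left_size, ids_size, right_size]; src has shape
// [left_size, src_dim_size, right_size]; ids selects along the middle dim.
fn index_select(
    src: &[f32],
    ids: &[usize],
    left_size: usize,
    src_dim_size: usize,
    right_size: usize,
) -> Vec<f32> {
    let ids_size = ids.len();
    let dst_size = left_size * ids_size * right_size;
    let mut dst = vec![0.0f32; dst_size];
    for gid in 0..dst_size {
        let id_i = (gid / right_size) % ids_size;
        // Clamp like the kernel does to avoid out-of-bounds reads.
        let input_i = ids[id_i].min(src_dim_size - 1);
        let right_rank_i = gid % right_size;
        let left_rank_i = gid / right_size / ids_size;
        let src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i;
        dst[gid] = src[src_i];
    }
    dst
}

fn main() {
    // Matches the index_select test further down: shape [5, 2], ids [0, 4, 2], dim 0.
    let src = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
    let out = index_select(&src, &[0, 4, 2], 1, 5, 2);
    assert_eq!(out, vec![1.0, 2.0, 9.0, 10.0, 5.0, 6.0]);
}
```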

File diff suppressed because it is too large

View File

@ -1,143 +0,0 @@
#include <metal_stdlib>
using namespace metal;
#define MAX(x, y) ((x) > (y) ? (x) : (y))
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
constant int THREADGROUP_SIZE = 1024;
# define REDUCE(FN, NAME, T) \
kernel void NAME( \
constant size_t &src_numel, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device T *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
\
threadgroup float shared_memory[THREADGROUP_SIZE]; \
\
shared_memory[tid] = 0; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
size_t idx = start_idx + tid; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
// size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
*/ \
T x = shared_memory[tid]; \
T y = src[idx]; \
shared_memory[tid] = FN; \
idx += block_dim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
\
/* \
// reduction in shared memory \
*/ \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s) { \
T x = shared_memory[tid]; \
T y = shared_memory[tid + s]; \
shared_memory[tid] = FN; \
} \
threadgroup_barrier(mem_flags::mem_none); \
} \
\
dst[dst_id] = shared_memory[0]; \
} \
REDUCE(x + y, fast_sum_float, float)
REDUCE(x * y, fast_mul_float, float)
REDUCE(max(x, y), fast_max_float, float)
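
The `REDUCE` macro above works in two phases: each of the block's threads folds a strided subset of its block's elements into shared memory, and the threadgroup then runs a pairwise tree reduction over those partials. A minimal Rust model of the same idea (illustrative only, single block, `block_reduce` is a hypothetical name):

```
// Illustrative CPU model of the two-phase block reduction in REDUCE above.
// Phase 1: thread `tid` folds elements tid, tid + block_dim, tid + 2*block_dim, ...
// Phase 2: pairwise tree reduction over the per-thread partials.
// Like the kernel's dispatch, this assumes block_dim is a power of two.
fn block_reduce(src: &[f32], block_dim: usize, f: impl Fn(f32, f32) -> f32, init: f32) -> f32 {
    let mut shared = vec![init; block_dim];
    for tid in 0..block_dim {
        let mut idx = tid;
        while idx < src.len() {
            shared[tid] = f(shared[tid], src[idx]);
            idx += block_dim;
        }
    }
    let mut s = block_dim / 2;
    while s > 0 {
        for tid in 0..s {
            shared[tid] = f(shared[tid], shared[tid + s]);
        }
        s /= 2;
    }
    shared[0]
}

fn main() {
    let v = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    // Sum over one block of 4 "threads", as in fast_sum_float.
    assert_eq!(block_reduce(&v, 4, |x, y| x + y, 0.0), 21.0);
    // Max reduction, using -inf as the identity element here.
    assert_eq!(block_reduce(&v, 4, f32::max, f32::NEG_INFINITY), 6.0);
}
```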
#define SOFTMAX(NAME, T) \
kernel void NAME( \
constant size_t &src_numel, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device T *dst, \
\
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
threadgroup float shared_memory[THREADGROUP_SIZE]; \
shared_memory[tid] = -INFINITY; \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
size_t idx = start_idx + tid; \
\
threadgroup_barrier(mem_flags::mem_threadgroup); \
\
while (idx < stop_idx) { \
shared_memory[tid] = MAX(shared_memory[tid], src[idx]); \
idx += block_dim; \
} \
\
threadgroup_barrier(mem_flags::mem_threadgroup); \
\
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s) { \
shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
} \
} \
\
threadgroup_barrier(mem_flags::mem_threadgroup); \
\
float _max = shared_memory[0]; \
\
shared_memory[tid] = 0; \
\
idx = start_idx + tid; \
while (idx < stop_idx) { \
const T val = T(exp(src[idx] - _max)); \
dst[idx] = val; \
shared_memory[tid] += val; \
idx += block_dim; \
} \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s) { \
shared_memory[tid] += shared_memory[tid + s]; \
} \
threadgroup_barrier(mem_flags::mem_threadgroup); \
} \
\
const T inv_acc = T(1/shared_memory[0]); \
idx = start_idx + tid; \
while (idx < stop_idx) { \
dst[idx] *= inv_acc; \
idx += block_dim; \
} \
} \
SOFTMAX(softmax_float, float)
SOFTMAX(softmax_half, half)
#if __METAL_VERSION__ >= 310
SOFTMAX(softmax_bfloat, bfloat)
#endif
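
The `SOFTMAX` kernel above follows the usual numerically stable recipe per block: find the row max, exponentiate the shifted values, accumulate their sum, then scale by its reciprocal. For reference, the same computation over one row in plain Rust (an illustrative sketch, not the crate's implementation):

```
// Numerically stable softmax over one row, mirroring the SOFTMAX kernel above:
// max -> exp(x - max) -> sum -> multiply by 1/sum.
fn softmax_row(src: &[f32]) -> Vec<f32> {
    let max = src.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut dst: Vec<f32> = src.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = dst.iter().sum();
    let inv = 1.0 / sum;
    for v in dst.iter_mut() {
        *v *= inv;
    }
    dst
}

fn main() {
    let row = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let out = softmax_row(&row);
    // Approximately [0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337],
    // matching the softmax test expectations further down.
    println!("{:?}", out);
    assert!((out.iter().sum::<f32>() - 1.0).abs() < 1e-6);
}
```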

View File

@ -1,60 +0,0 @@
#include <metal_stdlib>
#
using namespace metal;
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
#define WHERE_OP(TYPENAME, ID_TYPENAME, FN_NAME) \
kernel void FN_NAME( \
constant size_t &numel, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t *strides_t, \
constant size_t *strides_f, \
device const ID_TYPENAME *ids, \
device const TYPENAME *t, \
device const TYPENAME *f, \
device TYPENAME *out ,\
uint i [[ thread_position_in_grid ]] \
) { \
if (i >= numel){ \
return; \
} \
uint strided_i = get_strided_index(i, num_dims, dims, strides); \
uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \
uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \
out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f]; \
} \
// WHERE_OP(float, int64_t, where_i64_f32)
// WHERE_OP(double, int64_t, where_i64_f64)
// WHERE_OP(uint8_t, int64_t, where_i64_u8)
// WHERE_OP(uint32_t, int64_t, where_i64_u32)
// WHERE_OP(int64_t, int64_t, where_i64_i64)
//
// WHERE_OP(float, uint32_t, where_u32_f32)
// WHERE_OP(double, uint32_t, where_u32_f64)
// WHERE_OP(uint8_t, uint32_t, where_u32_u8)
// WHERE_OP(uint32_t, uint32_t, where_u32_u32)
// WHERE_OP(int64_t, uint32_t, where_u32_i64)
WHERE_OP(float, uint8_t, where_u8_f32)
// WHERE_OP(double, uint8_t, where_u8_f64)
// WHERE_OP(uint8_t, uint8_t, where_u8_u8)
// WHERE_OP(uint32_t, uint8_t, where_u8_u32)
// WHERE_OP(int64_t, uint8_t, where_u8_i64)
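
The where kernel above gathers all three operands through `get_strided_index`, which converts a linear element index into a memory offset using row-major dims and per-tensor strides. Below is a small Rust sketch of that conversion together with a worked example (illustrative only):

```
// CPU version of get_strided_index above: walk the dims from innermost to
// outermost, peeling off one coordinate per dimension and applying its stride.
fn get_strided_index(mut idx: usize, dims: &[usize], strides: &[usize]) -> usize {
    let mut strided_i = 0;
    for d in (0..dims.len()).rev() {
        strided_i += (idx % dims[d]) * strides[d];
        idx /= dims[d];
    }
    strided_i
}

fn main() {
    // A [3, 2] view over a transposed buffer: strides [1, 3] instead of [2, 1].
    // Element (row, col) lives at row * 1 + col * 3 in memory.
    let dims = [3, 2];
    let strides = [1, 3];
    let offsets: Vec<usize> = (0..6).map(|i| get_strided_index(i, &dims, &strides)).collect();
    assert_eq!(offsets, vec![0, 3, 1, 4, 2, 5]);
}
```

The same mapping is what the transposed case of the `cos_f32_strided` test further down exercises.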

View File

@ -1,727 +0,0 @@
use super::*;
use half::{bf16, f16};
use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger};
fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
let options = MTLResourceOptions::StorageModeManaged;
let ptr = data.as_ptr() as *const core::ffi::c_void;
let size = (data.len() * std::mem::size_of::<T>()) as u64;
device.new_buffer_with_data(ptr, size, options)
}
fn device() -> Device {
Device::system_default().unwrap()
}
fn approx(v: Vec<f32>, digits: i32) -> Vec<f32> {
let b = 10f32.powi(digits);
v.iter().map(|t| f32::round(t * b) / b).collect()
}
fn approx_f16(v: Vec<f16>, digits: i32) -> Vec<f32> {
let b = 10f32.powi(digits);
v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
}
fn approx_bf16(v: Vec<bf16>, digits: i32) -> Vec<f32> {
let b = 10f32.powi(digits);
v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
}
fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let output = new_buffer(&device, v);
call_unary_contiguous(
&device,
command_buffer,
&kernels,
name,
v.len(),
&input,
&output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<T>(v.len())
}
fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
let left = new_buffer(&device, x);
let right = new_buffer(&device, y);
let output = device.new_buffer(std::mem::size_of_val(x) as u64, options);
call_binary_contiguous(
&device,
command_buffer,
&kernels,
name,
x.len(),
&left,
&right,
&output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<T>(x.len())
}
fn run_strided<T: Clone>(
v: &[T],
kernel: unary::strided::Kernel,
shape: &[usize],
strides: &[usize],
offset: usize,
) -> Vec<T> {
let device = device();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let output = new_buffer(&device, v);
let kernels = Kernels::new();
call_unary_strided(
&device,
command_buffer,
&kernels,
kernel,
shape,
&input,
strides,
offset,
&output,
0,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<T>(v.len())
}
#[test]
fn cos_f32() {
let v = vec![1.0f32, 2.0, 3.0];
let results = run(&v, unary::contiguous::cos::FLOAT);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99]);
assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]);
let v = vec![1.0f32; 10_000];
let results = run(&v, unary::contiguous::cos::FLOAT);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
}
#[test]
fn cos_f32_strided() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let shape = vec![6];
let strides = vec![1];
let offset = 0;
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(
approx(results, 4),
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
);
assert_eq!(
approx(expected, 4),
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
);
// Contiguous
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let shape = vec![3, 2];
let strides = vec![2, 1];
let offset = 0;
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(
approx(results, 4),
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
);
assert_eq!(
approx(expected, 4),
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
);
// Transposed
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let shape = vec![3, 2];
let strides = vec![1, 3];
let offset = 0;
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(
approx(results, 4),
vec![0.5403, -0.6536, -0.4161, 0.2837, -0.99, 0.9602]
);
assert_eq!(
approx(expected, 4),
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
);
// Very large
let v = vec![1.0f32; 10_000];
let shape = vec![2, 5_000];
let strides = vec![2, 1];
let offset = 0;
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
}
#[test]
fn cos_strided_random() {
let v: Vec<_> = (0..10_000).map(|_| rand::random::<f32>()).collect();
let shape = vec![5_000, 2];
let strides = vec![1, 5_000];
let offset = 0;
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(approx(vec![results[0]], 4), approx(vec![expected[0]], 4));
assert_eq!(
approx(vec![results[1]], 4),
approx(vec![expected[5_000]], 4)
);
assert_eq!(approx(vec![results[2]], 4), approx(vec![expected[1]], 4));
assert_eq!(
approx(vec![results[3]], 4),
approx(vec![expected[5_001]], 4)
);
assert_eq!(
approx(vec![results[5_000]], 4),
approx(vec![expected[2_500]], 4)
);
}
#[test]
fn binary_add_f32() {
let left = vec![1.0f32, 2.0, 3.0];
let right = vec![2.0f32, 3.1, 4.2];
let results = run_binary(&left, &right, binary::contiguous::add::FLOAT);
let expected: Vec<_> = left
.iter()
.zip(right.iter())
.map(|(&x, &y)| x + y)
.collect();
assert_eq!(approx(results, 4), vec![3.0f32, 5.1, 7.2]);
assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]);
}
fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let options = MTLResourceOptions::StorageModeManaged;
let size = (v.len() * std::mem::size_of::<U>()) as u64;
let output = device.new_buffer(size, options);
call_cast_contiguous(
&device,
command_buffer,
&kernels,
name,
v.len(),
&input,
0,
&output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<U>(v.len())
}
#[test]
fn cast_u32_f32() {
let v = vec![1u32, 2, 3];
let results = cast(&v, "cast_u32_f32");
let expected: Vec<_> = v.iter().map(|&v| v as f32).collect();
assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]);
assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]);
let v = vec![1.0f32, 2.0, 3.0];
let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
let results: Vec<f32> = cast(&input, "cast_f16_f32");
assert_eq!(results, vec![1.0f32, 2.0, 3.0]);
let v = vec![1.0f32; 10_000];
let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
let results: Vec<f32> = cast(&input, "cast_f16_f32");
assert_eq!(results.len(), 10_000);
assert_eq!(&results[..10], vec![1.0f32; 10]);
assert_eq!(results, vec![1.0f32; 10_000]);
}
fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let output = new_buffer(&device, v);
let size = v.len();
call_affine(
&device,
command_buffer,
&kernels,
"affine_float",
size,
&input,
&output,
mul as f32,
add as f32,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<T>(v.len())
}
fn run_affine_strided<T: Clone>(
v: &[T],
shape: &[usize],
strides: &[usize],
mul: f64,
add: f64,
) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let output = new_buffer(&device, v);
call_affine_strided(
&device,
command_buffer,
&kernels,
"affine_float_strided",
shape,
&input,
strides,
0,
&output,
mul as f32,
add as f32,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
let len: usize = shape.iter().product();
output.read_to_vec::<T>(len)
}
#[test]
fn affine() {
let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let mul = 1.5;
let add = 1.1;
let result = run_affine(&input, mul, add);
assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]);
let input = [1.0f32; 40_000];
let mul = 1.5;
let add = 1.1;
let result = run_affine(&input, mul, add);
assert_eq!(result, vec![2.6; 40_000]);
}
#[test]
fn affine_strided() {
let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let mul = 1.5;
let add = 1.1;
let shape = [4];
let strides = [2];
let result = run_affine_strided(&input, &shape, &strides, mul, add);
// One element out of every two (stride 2).
assert_eq!(result, vec![2.6, 5.6, 8.6, 11.6]);
}
#[test]
fn index_select() {
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let shape = [5, 2];
let ids = [0u32, 4, 2];
let dim = 0;
let result = run_index_select(&embedding, &shape, &ids, dim);
assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]);
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let shape = [2, 5];
let ids = [0u32, 1, 0];
let dim = 0;
let result = run_index_select(&embedding, &shape, &ids, dim);
assert_eq!(
result,
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
);
}
#[test]
fn index_select_f16() {
let embedding: Vec<_> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
.into_iter()
.map(|x| f16::from_f32(x))
.collect();
let shape = [5, 2];
let ids = [0u32, 4, 2];
let dim = 0;
let result = run_index_select(&embedding, &shape, &ids, dim);
assert_eq!(
approx_f16(result, 4),
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
);
}
#[test]
fn index_select_dim1() {
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let shape = [5, 2];
let ids = [0u32, 1, 0];
let dim = 1;
let result = run_index_select(&embedding, &shape, &ids, dim);
assert_eq!(
result,
vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0]
);
}
fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
embeddings: &[T],
shape: &[usize],
ids: &[I],
dim: usize,
) -> Vec<T> {
let device = Device::system_default().expect("no device found");
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let embeddings_buffer = new_buffer(&device, &embeddings);
let ids_buffer = new_buffer(&device, &ids);
let left_size: usize = shape[..dim].iter().product();
let right_size: usize = shape[dim + 1..].iter().product();
let dst_el = ids.len() * left_size * right_size;
let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
let name = match core::mem::size_of::<T>() {
4 => "is_u32_f32",
2 => "is_u32_f16",
_ => unimplemented!(),
};
let kernels = Kernels::new();
call_index_select(
&device,
&command_buffer,
&kernels,
name,
shape,
ids.len(),
dim,
&embeddings_buffer,
&ids_buffer,
&dst_buffer,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
dst_buffer.read_to_vec::<T>(dst_el)
}
#[test]
fn index_add() {
let device = Device::system_default().expect("no device found");
let options = CompileOptions::new();
let library = device.new_library_with_source(INDEXING, &options).unwrap();
let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
let right = [1.0f32; 15];
let index = [0u32, 4, 2];
let ids_dim_size = index.len() as u32;
let dst_dim_size: u32 = 15;
let left_size: u32 = 3;
let right_size: u32 = 3;
let function = library.get_function("ia_u32_f32", None).unwrap();
let pipeline = device
.new_compute_pipeline_state_with_function(&function)
.unwrap();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
let index_buffer = new_buffer(&device, &index);
let inputs_buffer = new_buffer(&device, &left);
let outputs_buffer = new_buffer(&device, &right);
set_params!(
encoder,
(
&index_buffer,
&inputs_buffer,
&outputs_buffer,
ids_dim_size,
left_size,
dst_dim_size,
right_size
)
);
let grid_size = MTLSize {
width: right.len() as NSUInteger,
height: 1,
depth: 1,
};
let thread_group_size = MTLSize {
width: pipeline.max_total_threads_per_threadgroup(),
height: 1,
depth: 1,
};
encoder.dispatch_thread_groups(grid_size, thread_group_size);
encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();
let expected = vec![
2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
];
let result = outputs_buffer.read_to_vec::<f32>(right.len());
assert_eq!(result, expected);
}
#[test]
fn cos_f16() {
let v: Vec<f16> = [1.0f32, 2.0, 3.0]
.iter()
.map(|v| f16::from_f32(*v))
.collect();
let results = run(&v, unary::contiguous::cos::HALF);
let expected: Vec<f16> = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect();
assert_eq!(approx_f16(results, 4), vec![0.5405, -0.4163, -0.9902]);
assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]);
}
fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let options = MTLResourceOptions::StorageModeManaged;
let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
call_reduce_contiguous(
&device,
command_buffer,
&kernels,
name,
v.len(),
out_length,
&input,
0,
&output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<T>(out_length)
}
fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'static str) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let output = new_buffer(&device, v);
call_last_softmax(
&device,
command_buffer,
&kernels,
name,
v.len(),
last_dim,
&input,
&output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<T>(v.len())
}
#[test]
fn reduce_sum() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let out_length = 1;
let results = run_reduce(&v, out_length, "fast_sum_float");
assert_eq!(approx(results, 4), vec![21.0]);
}
#[test]
fn reduce_sum2() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let out_length = 2;
let results = run_reduce(&v, out_length, "fast_sum_float");
assert_eq!(approx(results, 4), vec![6.0, 15.0]);
}
#[test]
fn softmax() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let last_dim = 6;
let results = run_softmax(&v, last_dim, "softmax_float");
assert_eq!(
approx(results, 4),
vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
);
let v = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0];
let last_dim = 6;
let results = run_softmax(&v, last_dim, "softmax_float");
assert_eq!(
approx(results, 4),
vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
);
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let last_dim = 3;
let results = run_softmax(&v, last_dim, "softmax_float");
assert_eq!(
approx(results, 4),
vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652]
);
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
.iter()
.map(|v| f16::from_f32(*v))
.collect::<Vec<_>>();
let last_dim = 6;
let results = run_softmax(&v, last_dim, "softmax_half");
assert_eq!(
approx_f16(results, 4),
vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338]
);
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
.iter()
.map(|v| bf16::from_f32(*v))
.collect::<Vec<_>>();
let last_dim = 6;
let results = run_softmax(&v, last_dim, "softmax_bfloat");
assert_eq!(
approx_bf16(results, 4),
vec![0.0043, 0.0116, 0.0315, 0.0859, 0.2324, 0.6328]
);
}
fn run_where_cond<I: Clone, T: Clone>(
shape: &[usize],
cond: &[I],
(cond_stride, cond_offset): (Vec<usize>, usize),
left_true: &[T],
(left_stride, left_offset): (Vec<usize>, usize),
right_false: &[T],
(_right_stride, _right_offset): (Vec<usize>, usize),
name: &'static str,
) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
let length = cond.len();
let cond = device.new_buffer_with_data(
cond.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(cond) as u64,
options,
);
let left = device.new_buffer_with_data(
left_true.as_ptr() as *const core::ffi::c_void,
(length * core::mem::size_of::<T>()) as u64,
options,
);
let right = device.new_buffer_with_data(
right_false.as_ptr() as *const core::ffi::c_void,
(length * core::mem::size_of::<T>()) as u64,
options,
);
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
call_where_cond_strided(
&device,
command_buffer,
&kernels,
name,
shape,
&cond,
(&cond_stride, cond_offset),
&left,
(&left_stride, left_offset),
&right,
(&cond_stride, cond_offset),
&output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<T>(length)
}
#[test]
fn where_cond() {
let shape = vec![6];
let cond = vec![0u8, 1, 0, 0, 1, 1];
let cond_l = (vec![1], 0);
let left_true = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let left_l = (vec![1], 0);
let right_false = vec![-1.0f32, -2.0, -3.0, -4.0, -5.0, -6.0];
let right_l = (vec![1], 0);
let results = run_where_cond(
&shape,
&cond,
cond_l,
&left_true,
left_l,
&right_false,
right_l,
"where_u8_f32",
);
assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
}

View File

@ -1,7 +1,4 @@
#include <metal_stdlib>
#include <metal_math>
#
using namespace metal;
METAL_FUNC uint get_strided_index(
uint idx,
@ -20,51 +17,25 @@ METAL_FUNC uint get_strided_index(
template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
template <typename T> METAL_FUNC T neg(T in){ return -in; }
template <typename T> METAL_FUNC T erf(T in){
float x = (float) in;
// constants
float a1 = 0.254829592;
float a2 = -0.284496736;
float a3 = 1.421413741;
float a4 = -1.453152027;
float a5 = 1.061405429;
float p = 0.3275911;
// Save the sign of x
int sign = 1;
if (x < 0)
sign = -1;
x = fabs(x);
// A&S formula 7.1.26
float t = 1.0/(1.0 + p*x);
float y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x);
return T(sign*y);
}
template <typename T> METAL_FUNC T id(T in){ return in; }
template <typename T> METAL_FUNC T gelu_erf(T x){ return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2); }
template <typename T> METAL_FUNC T gelu(T x){
T x_sq = x * x;
T x_cube = x_sq * x;
T alpha = x + static_cast<T>(0.044715) * x_cube;
T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
}
using namespace metal;
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
constant size_t &dim, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
if (thread_position_in_grid >= dim) { \
return; \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
output[i] = TYPENAME(FN(input[i])); \
} \
output[thread_position_in_grid] = TYPENAME(FN(input[thread_position_in_grid])); \
}\
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
@ -73,12 +44,15 @@ kernel void FN_NAME_STRIDED( \
constant size_t *strides, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
if (thread_position_in_grid >= dim) { \
return; \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
output[i] = TYPENAME(FN(input[get_strided_index(i, num_dims, dims, strides)])); \
} \
output[thread_position_in_grid] = TYPENAME(FN(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)])); \
}
#define UNARY_OP(NAME) \
@ -95,17 +69,8 @@ UNARY_OP(sqr)
UNARY_OP(sqrt)
UNARY_OP(neg)
UNARY_OP(exp)
UNARY_OP(log)
UNARY_OP(gelu)
UNARY_OP(ceil)
UNARY_OP(floor)
UNARY_OP(round)
UNARY_OP(gelu_erf)
UNARY_OP(erf)
UNARY(id, float, copy_float, copy_float_strided)
UNARY(id, half, copy_half, copy_half_strided)
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
UNARY(id, uint32_t, copy_u32, copy_u32_strided)
#if __METAL_VERSION__ >= 310
BFLOAT_UNARY_OP(cos)
@ -114,13 +79,4 @@ BFLOAT_UNARY_OP(sqr)
BFLOAT_UNARY_OP(sqrt)
BFLOAT_UNARY_OP(neg)
BFLOAT_UNARY_OP(exp)
BFLOAT_UNARY_OP(log)
BFLOAT_UNARY_OP(gelu)
BFLOAT_UNARY_OP(ceil)
BFLOAT_UNARY_OP(floor)
BFLOAT_UNARY_OP(round)
BFLOAT_UNARY_OP(gelu_erf)
BFLOAT_UNARY_OP(erf)
UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)
#endif
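
For reference, the `gelu` helper near the top of this file uses the common tanh approximation, gelu(x) ~ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), with sqrt(2/pi) spelled as M_2_SQRTPI_F * M_SQRT1_2_F. A tiny Rust sketch of the same formula (illustrative, not part of the crate):

```
// tanh-based GELU approximation, as used by the Metal `gelu` template above.
fn gelu(x: f32) -> f32 {
    let sqrt_2_over_pi = (2.0f32 / std::f32::consts::PI).sqrt();
    0.5 * x * (1.0 + (sqrt_2_over_pi * (x + 0.044715 * x * x * x)).tanh())
}

fn main() {
    // gelu(0) = 0, gelu(x) approaches x for large positive x and 0 for large negative x.
    println!("{} {} {}", gelu(0.0), gelu(3.0), gelu(-3.0));
}
```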

View File

@ -1,76 +0,0 @@
use candle_metal_kernels::{call_affine, Kernels};
use metal::objc::rc::autoreleasepool;
use metal::{Device, MTLResourceOptions};
use rand;
use std::any::type_name;
use std::time::Instant;
fn main() {
let device = Device::system_default().unwrap();
let kernels = Kernels::new();
let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
let f32_10k = (0..10000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
let f32_100k = (0..100000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
"dtype", "kernel", "size", "runs", "total time", "avg time"
);
// f32
run_affine_bench(&device, &kernels, &f32_1k);
run_affine_bench(&device, &kernels, &f32_10k);
run_affine_bench(&device, &kernels, &f32_100k);
}
fn run_affine_bench<T: Clone>(device: &Device, kernels: &Kernels, v: &[T]) {
let command_queue = device.new_command_queue();
let options = MTLResourceOptions::StorageModeManaged;
let iterations = 10000;
let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void,
core::mem::size_of_val(v) as u64,
options,
);
let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
let mul: f32 = 1.2345;
let add: f32 = 2.3456;
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
for _ in 0..iterations {
call_affine(
&device,
command_buffer,
&kernels,
"affine_float",
v.len(),
&input,
&mut output,
mul,
add,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
start.elapsed()
});
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
"affine",
v.len(),
iterations,
total_time,
total_time / iterations
);
}

View File

@ -1,182 +0,0 @@
use candle_metal_kernels::{binary, call_binary_contiguous, call_binary_strided, Kernels};
use half::{bf16, f16};
use metal::objc::rc::autoreleasepool;
use metal::{Device, MTLResourceOptions};
use rand;
use std::any::type_name;
use std::time::Instant;
fn main() {
let device = Device::system_default().unwrap();
let kernels = Kernels::new();
let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
let f32_10k = (0..10000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
let f32_100k = (0..100000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
let f16_map = |v: &[f32]| v.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
let f16_1k = f16_map(&f32_1k);
let f16_10k = f16_map(&f32_10k);
let f16_100k = f16_map(&f32_100k);
let bf16_map = |v: &[f32]| v.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
let bf16_1k = bf16_map(&f32_1k);
let bf16_10k = bf16_map(&f32_10k);
let bf16_100k = bf16_map(&f32_100k);
let f32_ckernels = [
binary::contiguous::add::FLOAT,
binary::contiguous::sub::FLOAT,
binary::contiguous::mul::FLOAT,
binary::contiguous::div::FLOAT,
];
let f32_skernels = [
binary::strided::add::FLOAT,
binary::strided::sub::FLOAT,
binary::strided::mul::FLOAT,
binary::strided::div::FLOAT,
];
let f16_ckernels = [
binary::contiguous::add::HALF,
binary::contiguous::sub::HALF,
binary::contiguous::mul::HALF,
binary::contiguous::div::HALF,
];
let f16_skernels = [
binary::strided::add::HALF,
binary::strided::sub::HALF,
binary::strided::mul::HALF,
binary::strided::div::HALF,
];
let bf16_ckernels = [
binary::contiguous::add::BFLOAT,
binary::contiguous::sub::BFLOAT,
binary::contiguous::mul::BFLOAT,
binary::contiguous::div::BFLOAT,
];
let bf16_skernels = [
binary::strided::add::BFLOAT,
binary::strided::sub::BFLOAT,
binary::strided::mul::BFLOAT,
binary::strided::div::BFLOAT,
];
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
"dtype", "kernel", "size", "runs", "total time", "avg time"
);
// f32
run_binary_bench(&device, &kernels, &f32_1k, f32_ckernels, f32_skernels);
run_binary_bench(&device, &kernels, &f32_10k, f32_ckernels, f32_skernels);
run_binary_bench(&device, &kernels, &f32_100k, f32_ckernels, f32_skernels);
// f16
run_binary_bench(&device, &kernels, &f16_1k, f16_ckernels, f16_skernels);
run_binary_bench(&device, &kernels, &f16_10k, f16_ckernels, f16_skernels);
run_binary_bench(&device, &kernels, &f16_100k, f16_ckernels, f16_skernels);
// bf16
run_binary_bench(&device, &kernels, &bf16_1k, bf16_ckernels, bf16_skernels);
run_binary_bench(&device, &kernels, &bf16_10k, bf16_ckernels, bf16_skernels);
run_binary_bench(&device, &kernels, &bf16_100k, bf16_ckernels, bf16_skernels);
}
fn run_binary_bench<T: Clone>(
device: &Device,
kernels: &Kernels,
v: &[T],
contiguous: [binary::contiguous::Kernel; 4],
strided: [binary::strided::Kernel; 4],
) {
let command_queue = device.new_command_queue();
let options = MTLResourceOptions::StorageModeManaged;
let iterations = 1000;
let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void,
core::mem::size_of_val(v) as u64,
options,
);
let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
// Contiguous
for kernel_name in contiguous {
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
for _ in 0..iterations {
call_binary_contiguous(
device,
&command_buffer,
kernels,
kernel_name,
v.len(),
&input,
&input,
&mut output,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
start.elapsed()
});
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
kernel_name.to_string(),
v.len(),
iterations,
total_time,
total_time / iterations
);
}
// Strided
let shape = vec![2, 5_000];
let strides = vec![2, 1];
let offset = 0;
for kernel_name in strided {
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
for _ in 0..iterations {
call_binary_strided(
device,
command_buffer,
&kernels,
kernel_name,
&shape,
&input,
&strides,
offset,
&input,
&strides,
offset,
&mut output,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
start.elapsed()
});
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
kernel_name.to_string(),
v.len(),
iterations,
total_time,
total_time / iterations
);
}
}

View File

@ -1,84 +0,0 @@
use candle_metal_kernels::{call_cast_contiguous, Kernels};
use metal::objc::rc::autoreleasepool;
use metal::{Device, MTLResourceOptions};
use rand;
use std::any::type_name;
use std::time::Instant;
fn main() {
let device = Device::system_default().unwrap();
let kernels = Kernels::new();
let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
let f32_10k = (0..10000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
let f32_100k = (0..100000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
let contiguous_kernels = ["cast_u32_f32"];
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
"dtype", "kernel", "size", "runs", "total time", "avg time"
);
// f32
run_cast_bench(&device, &kernels, &f32_1k, &contiguous_kernels);
run_cast_bench(&device, &kernels, &f32_10k, &contiguous_kernels);
run_cast_bench(&device, &kernels, &f32_100k, &contiguous_kernels);
}
fn run_cast_bench<T: Clone>(
device: &Device,
kernels: &Kernels,
v: &[T],
contiguous: &[&'static str],
) {
let command_queue = device.new_command_queue();
let options = MTLResourceOptions::StorageModeManaged;
let iterations = 1000;
let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void,
core::mem::size_of_val(v) as u64,
options,
);
let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
// Contiguous
for kernel_name in contiguous {
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
for _ in 0..iterations {
call_cast_contiguous(
device,
&command_buffer,
kernels,
kernel_name,
v.len(),
&input,
&mut output,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
start.elapsed()
});
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
kernel_name.to_string(),
v.len(),
iterations,
total_time,
total_time / iterations
);
}
// Strided?
}

View File

@ -1,197 +0,0 @@
use candle_metal_kernels::{call_unary_contiguous, call_unary_strided, unary, Kernels};
use half::{bf16, f16};
use metal::objc::rc::autoreleasepool;
use metal::{Device, MTLResourceOptions};
use rand;
use std::any::type_name;
use std::time::Instant;
fn main() {
let device = Device::system_default().unwrap();
let kernels = Kernels::new();
let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
let f32_10k = (0..10000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
let f32_100k = (0..100000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
let f16_map = |v: &[f32]| v.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
let f16_1k = f16_map(&f32_1k);
let f16_10k = f16_map(&f32_10k);
let f16_100k = f16_map(&f32_100k);
let bf16_map = |v: &[f32]| v.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
let bf16_1k = bf16_map(&f32_1k);
let bf16_10k = bf16_map(&f32_10k);
let bf16_100k = bf16_map(&f32_100k);
let f32_ckernels = [
unary::contiguous::sin::FLOAT,
unary::contiguous::cos::FLOAT,
unary::contiguous::exp::FLOAT,
unary::contiguous::sqr::FLOAT,
unary::contiguous::sqrt::FLOAT,
unary::contiguous::neg::FLOAT,
unary::contiguous::copy::FLOAT,
];
let f32_skernels = [
unary::strided::sin::FLOAT,
unary::strided::cos::FLOAT,
unary::strided::exp::FLOAT,
unary::strided::sqr::FLOAT,
unary::strided::sqrt::FLOAT,
unary::strided::neg::FLOAT,
unary::strided::copy::FLOAT,
];
let f16_ckernels = [
unary::contiguous::sin::HALF,
unary::contiguous::cos::HALF,
unary::contiguous::exp::HALF,
unary::contiguous::sqr::HALF,
unary::contiguous::sqrt::HALF,
unary::contiguous::neg::HALF,
unary::contiguous::copy::HALF,
];
let f16_skernels = [
unary::strided::sin::HALF,
unary::strided::cos::HALF,
unary::strided::exp::HALF,
unary::strided::sqr::HALF,
unary::strided::sqrt::HALF,
unary::strided::neg::HALF,
unary::strided::copy::HALF,
];
let bf16_ckernels = [
unary::contiguous::sin::BFLOAT,
unary::contiguous::cos::BFLOAT,
unary::contiguous::exp::BFLOAT,
unary::contiguous::sqr::BFLOAT,
unary::contiguous::sqrt::BFLOAT,
unary::contiguous::neg::BFLOAT,
unary::contiguous::copy::BFLOAT,
];
let bf16_skernels = [
unary::strided::sin::BFLOAT,
unary::strided::cos::BFLOAT,
unary::strided::exp::BFLOAT,
unary::strided::sqr::BFLOAT,
unary::strided::sqrt::BFLOAT,
unary::strided::neg::BFLOAT,
unary::strided::copy::BFLOAT,
];
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
"dtype", "kernel", "size", "runs", "total time", "avg time"
);
// f32
run_unary_bench(&device, &kernels, &f32_1k, f32_ckernels, f32_skernels);
run_unary_bench(&device, &kernels, &f32_10k, f32_ckernels, f32_skernels);
run_unary_bench(&device, &kernels, &f32_100k, f32_ckernels, f32_skernels);
// f16
run_unary_bench(&device, &kernels, &f16_1k, f16_ckernels, f16_skernels);
run_unary_bench(&device, &kernels, &f16_10k, f16_ckernels, f16_skernels);
run_unary_bench(&device, &kernels, &f16_100k, f16_ckernels, f16_skernels);
// bf16
run_unary_bench(&device, &kernels, &bf16_1k, bf16_ckernels, bf16_skernels);
run_unary_bench(&device, &kernels, &bf16_10k, bf16_ckernels, bf16_skernels);
run_unary_bench(&device, &kernels, &bf16_100k, bf16_ckernels, bf16_skernels);
}
fn run_unary_bench<T: Clone>(
device: &Device,
kernels: &Kernels,
v: &[T],
contiguous: [unary::contiguous::Kernel; 7],
strided: [unary::strided::Kernel; 7],
) {
let command_queue = device.new_command_queue();
let options = MTLResourceOptions::StorageModeManaged;
let iterations = 10000;
let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void,
core::mem::size_of_val(v) as u64,
options,
);
let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
// Contiguous
for kernel_name in contiguous {
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
for _ in 0..iterations {
call_unary_contiguous(
device,
&command_buffer,
kernels,
kernel_name,
v.len(),
&input,
&mut output,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
start.elapsed()
});
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
kernel_name.0,
v.len(),
iterations,
total_time,
total_time / iterations
);
}
// Strided
let shape = vec![2, 5_000];
let strides = vec![2, 1];
let offset = 0;
for kernel_name in &strided {
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
for _ in 0..iterations {
call_unary_strided(
device,
command_buffer,
&kernels,
kernel_name,
&shape,
&input,
&strides,
offset,
&mut output,
0,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
start.elapsed()
});
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
kernel_name.0,
v.len(),
iterations,
total_time,
total_time / iterations
);
}
}

View File

@ -11,15 +11,15 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
half = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
[dev-dependencies]
anyhow = { workspace = true }
@ -29,5 +29,5 @@ clap = { workspace = true }
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate"]
cuda = ["candle/cuda"]
metal = ["candle/metal"]
mkl = ["dep:intel-mkl-src", "candle/mkl"]
metal = ["candle/metal", "dep:candle-metal-kernels"]

View File

@ -6,7 +6,7 @@ extern crate intel_mkl_src;
extern crate accelerate_src;
use candle::quantized::GgmlType;
use candle::{CpuStorage, Device, Layout, Module, Result, Shape, Tensor, D};
use candle::{CpuStorage, Device, Layout, Result, Shape, Tensor, D};
use clap::{Parser, Subcommand};
const CHECK_CONV2D: bool = false;

View File

@ -6,16 +6,14 @@ use serde::Deserialize;
pub enum Activation {
#[default]
Gelu,
#[serde(rename = "gated-gelu")]
NewGelu,
Relu,
Relu2,
Relu6,
Silu,
Sigmoid,
HardSigmoid,
Swiglu,
Swish,
HardSwish,
Elu(f64),
LeakyRelu(f64),
}
@ -31,10 +29,7 @@ impl super::Module for Activation {
Self::Relu6 => xs.clamp(0f32, 6f32),
Self::Silu => crate::ops::silu(xs),
Self::Sigmoid => crate::ops::sigmoid(xs),
Self::HardSigmoid => crate::ops::hard_sigmoid(xs),
Self::Swiglu => crate::ops::swiglu(xs),
Self::Swish => xs * crate::ops::sigmoid(xs)?,
Self::HardSwish => xs * crate::ops::hard_sigmoid(xs)?,
&Self::Elu(alpha) => xs.elu(alpha),
&Self::LeakyRelu(negative_slope) => crate::ops::leaky_relu(xs, negative_slope),
}

View File

@ -70,67 +70,6 @@ impl crate::Module for Conv1d {
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ConvTranspose1dConfig {
pub padding: usize,
pub output_padding: usize,
pub stride: usize,
pub dilation: usize,
// TODO: support groups.
}
impl Default for ConvTranspose1dConfig {
fn default() -> Self {
Self {
padding: 0,
output_padding: 0,
stride: 1,
dilation: 1,
}
}
}
#[derive(Clone, Debug)]
pub struct ConvTranspose1d {
weight: Tensor,
bias: Option<Tensor>,
config: ConvTranspose1dConfig,
}
impl ConvTranspose1d {
pub fn new(weight: Tensor, bias: Option<Tensor>, config: ConvTranspose1dConfig) -> Self {
Self {
weight,
bias,
config,
}
}
pub fn config(&self) -> &ConvTranspose1dConfig {
&self.config
}
}
impl crate::Module for ConvTranspose1d {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = x.conv_transpose1d(
&self.weight,
self.config.padding,
self.config.output_padding,
self.config.stride,
self.config.dilation,
)?;
match &self.bias {
None => Ok(x),
Some(bias) => {
let b = bias.dims1()?;
let bias = bias.reshape((1, b, 1, 1))?;
Ok(x.broadcast_add(&bias)?)
}
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Conv2dConfig {
pub padding: usize,
@ -302,39 +241,6 @@ pub fn conv1d(
Ok(Conv1d::new(ws, Some(bs), cfg))
}
pub fn conv_transpose1d(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
cfg: ConvTranspose1dConfig,
vb: crate::VarBuilder,
) -> Result<ConvTranspose1d> {
let bound = 1. / (out_channels as f64 * kernel_size as f64).sqrt();
let init = crate::Init::Uniform {
lo: -bound,
up: bound,
};
let ws = vb.get_with_hints((in_channels, out_channels, kernel_size), "weight", init)?;
let bs = vb.get_with_hints(out_channels, "bias", init)?;
Ok(ConvTranspose1d::new(ws, Some(bs), cfg))
}
pub fn conv_transpose1d_no_bias(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
cfg: ConvTranspose1dConfig,
vb: crate::VarBuilder,
) -> Result<ConvTranspose1d> {
let bound = 1. / (out_channels as f64 * kernel_size as f64).sqrt();
let init = crate::Init::Uniform {
lo: -bound,
up: bound,
};
let ws = vb.get_with_hints((in_channels, out_channels, kernel_size), "weight", init)?;
Ok(ConvTranspose1d::new(ws, None, cfg))
}
pub fn conv2d(
in_channels: usize,
out_channels: usize,

View File

@ -95,14 +95,6 @@ impl LayerNorm {
eps,
}
}
pub fn weight(&self) -> &Tensor {
&self.weight
}
pub fn bias(&self) -> Option<&Tensor> {
self.bias.as_ref()
}
}
impl crate::Module for LayerNorm {

View File

@ -1,5 +1,6 @@
use candle::{CpuStorage, Layout, Result, Shape, Tensor};
use rayon::prelude::*;
use tracing::debug;
/// Applies the softmax function to the input tensor, rescaling the element so that elements on
/// a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1.
@ -39,21 +40,11 @@ pub fn silu(xs: &Tensor) -> Result<Tensor> {
xs / (xs.neg()?.exp()? + 1.0)?
}
pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
let xs = xs.chunk(2, candle::D::Minus1)?;
crate::ops::silu(&xs[0])? * &xs[1]
}
pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
// TODO: Should we have a specialized op for this?
(xs.neg()?.exp()? + 1.0)?.recip()
}
pub fn hard_sigmoid(xs: &Tensor) -> Result<Tensor> {
// TODO: Should we have a specialized op for this?
((xs + 3.0)? / 6.0)?.clamp(0f32, 1f32)
}
pub fn leaky_relu(xs: &Tensor, negative_slope: f64) -> Result<Tensor> {
let zeros = xs.zeros_like()?;
xs.maximum(&zeros)? + xs.minimum(&zeros)? * negative_slope
@ -208,38 +199,8 @@ impl candle::CustomOp1 for SoftmaxLastDim {
storage: &candle::MetalStorage,
layout: &Layout,
) -> Result<(candle::MetalStorage, Shape)> {
use candle::{backend::BackendStorage, DType};
let device = storage.device();
let command_buffer = device.command_buffer();
let kernels = device.kernels();
let name = match storage.dtype() {
DType::F32 => "softmax_float",
DType::F16 => "softmax_half",
DType::BF16 => "softmax_bfloat",
dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"),
};
let n = layout.stride().len();
if !(layout.stride()[n - 1] == 1 && layout.start_offset() == 0) {
candle::bail!("Non contiguous softmax-last-dim is not implemented");
}
let last_dim = layout.dims()[layout.shape().rank() - 1];
let elem_count = layout.shape().elem_count();
let mut output = device.new_buffer(elem_count, storage.dtype());
candle_metal_kernels::call_last_softmax(
device.metal_device(),
&command_buffer,
&kernels,
name,
elem_count,
last_dim,
storage.buffer(),
&mut output,
)
.unwrap();
let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
Ok((newstorage, layout.shape().clone()))
debug!("TODO softmax-last-dim");
Ok((storage.clone(), layout.shape().clone()))
}
}
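
For context, the softmax op is exposed from this module as `softmax` / `softmax_last_dim`. A minimal CPU usage sketch, assuming `candle` (candle-core) and `candle-nn` are available as dependencies:

```
// Minimal usage sketch of the softmax op defined in this module, on the CPU device.
use candle::{Device, Result, Tensor, D};

fn main() -> Result<()> {
    let xs = Tensor::new(&[[1.0f32, 2.0, 3.0], [4.0f32, 5.0, 6.0]], &Device::Cpu)?;
    // Rescale along the last dimension so each row sums to 1.
    let ys = candle_nn::ops::softmax(&xs, D::Minus1)?;
    println!("{ys}");
    Ok(())
}
```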

View File

@ -1,23 +0,0 @@
[package]
name = "candle-onnx"
version = "0.3.1"
edition = "2021"
description = "ONNX support for Candle"
repository = "https://github.com/huggingface/candle"
keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.1" }
prost = "0.12.1"
[build-dependencies]
prost-build = "0.12.1"
[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
clap = { version = "4.2.4", features = ["derive"] }

View File

@ -1,21 +0,0 @@
# candle-onnx
This crate adds ONNX support to candle
## FAQ
#### Missing protoc installation when compiling candle-onnx
The candle-onnx dependency `prost-build` no longer bundles the `protoc` binary.
This can cause the following error when attempting to compile candle-onnx:
```
error: failed to run custom build command for `candle-onnx`
Caused by: // (...)
Could not find `protoc` installation and this build crate cannot proceed without this knowledge.
```
To fix this issue install protoc on your system and make it available in your
system `PATH`. See the [protoc
documentation](https://grpc.io/docs/protoc-installation/) for more information.

View File

@ -1,6 +0,0 @@
use std::io::Result;
fn main() -> Result<()> {
prost_build::compile_protos(&["src/onnx.proto3"], &["src/"])?;
Ok(())
}

View File

@ -1,774 +0,0 @@
use crate::onnx;
use crate::onnx::attribute_proto::AttributeType;
use crate::onnx::tensor_proto::DataType;
use candle::{bail, DType, Device, Result, Tensor};
use std::collections::HashMap;
pub type Value = Tensor;
pub fn dtype(dt: DataType) -> Option<DType> {
match dt {
DataType::Uint8 => Some(DType::U8),
DataType::Uint32 => Some(DType::U32),
DataType::Int64 => Some(DType::I64),
DataType::Float16 => Some(DType::F16),
DataType::Float => Some(DType::F32),
DataType::Double => Some(DType::F64),
_ => None,
}
}
trait Attr {
const TYPE: AttributeType;
fn get(attr: &onnx::AttributeProto) -> Result<&Self>;
}
impl Attr for i64 {
const TYPE: AttributeType = AttributeType::Int;
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
Ok(&attr.i)
}
}
impl Attr for f32 {
const TYPE: AttributeType = AttributeType::Float;
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
Ok(&attr.f)
}
}
impl Attr for [i64] {
const TYPE: AttributeType = AttributeType::Ints;
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
Ok(attr.ints.as_slice())
}
}
impl Attr for str {
const TYPE: AttributeType = AttributeType::String;
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
std::str::from_utf8(&attr.s).map_err(candle::Error::wrap)
}
}
fn get_attr_<'a>(node: &'a onnx::NodeProto, name: &str) -> Result<&'a onnx::AttributeProto> {
match node.attribute.iter().find(|attr| attr.name == name) {
None => {
bail!(
"cannot find the '{name}' attribute in '{}' for {}",
node.op_type,
node.name
)
}
Some(dt) => Ok(dt),
}
}
fn get_attr<'a, T: Attr + ?Sized>(node: &'a onnx::NodeProto, name: &str) -> Result<&'a T> {
let attr = get_attr_(node, name)?;
if attr.r#type() != T::TYPE {
bail!(
"unsupported type {:?} for '{name}' attribute in '{}' for {}",
attr.r#type,
node.op_type,
node.name
)
}
T::get(attr)
}
fn get_attr_opt<'a, T: Attr + ?Sized>(
node: &'a onnx::NodeProto,
name: &str,
) -> Result<Option<&'a T>> {
match node.attribute.iter().find(|attr| attr.name == name) {
None => Ok(None),
Some(attr) => {
if attr.r#type() != T::TYPE {
bail!(
"unsupported type {:?} for '{name}' attribute in '{}' for {}",
attr.r#type,
node.op_type,
node.name
)
}
let val = T::get(attr)?;
Ok(Some(val))
}
}
}
pub fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
match DataType::try_from(t.data_type) {
Ok(DataType::Int32) => {
if t.int32_data.is_empty() {
let len = t.raw_data.len() / 4;
let data: &[i32] =
unsafe { std::slice::from_raw_parts(t.raw_data.as_ptr() as *const i32, len) };
let data = data.iter().map(|v| *v as i64).collect::<Vec<_>>();
Tensor::from_vec(data, len, &Device::Cpu)
} else {
let data = t.int32_data.iter().map(|v| *v as i64).collect::<Vec<_>>();
Tensor::from_vec(data, t.int32_data.len(), &Device::Cpu)
}
}
Ok(dt) => match dtype(dt) {
Some(dt) => {
if dt == DType::F32 && !t.float_data.is_empty() {
Tensor::from_slice(&t.float_data, dims.as_slice(), &Device::Cpu)
} else if dt == DType::F64 && !t.double_data.is_empty() {
Tensor::from_slice(&t.double_data, dims.as_slice(), &Device::Cpu)
} else if dt == DType::I64 && !t.int64_data.is_empty() {
Tensor::from_slice(&t.int64_data, dims.as_slice(), &Device::Cpu)
} else {
Tensor::from_raw_buffer(
t.raw_data.as_slice(),
dt,
dims.as_slice(),
&Device::Cpu,
)
}
}
None => {
bail!("unsupported 'value' data-type {dt:?} for {name}")
}
},
Err(_) => {
bail!("unsupported 'value' data-type {} for {name}", t.data_type,)
}
}
}
// This function provides a direct evaluation of the proto.
// Longer-term, we should first convert the proto to an intermediate representation of the compute
// graph so as to make multiple evaluations more efficient.
// An example upside of this would be to remove intermediary values when they are not needed
// anymore.
pub fn simple_eval(
model: &onnx::ModelProto,
inputs: HashMap<String, Value>,
) -> Result<HashMap<String, Value>> {
let graph = match &model.graph {
None => bail!("no graph defined in proto"),
Some(graph) => graph,
};
let mut values = inputs;
for t in graph.initializer.iter() {
let tensor = get_tensor(t, t.name.as_str())?;
values.insert(t.name.to_string(), tensor);
}
for input in graph.input.iter() {
let input_type = match &input.r#type {
Some(input_type) => input_type,
None => continue,
};
let input_type = match &input_type.value {
Some(input_type) => input_type,
None => continue,
};
let tensor_type = match input_type {
onnx::type_proto::Value::TensorType(tt) => tt,
_ => continue,
};
let tensor = match values.get(&input.name) {
None => bail!("missing input {}", input.name),
Some(tensor) => tensor,
};
let dt = match DataType::try_from(tensor_type.elem_type) {
Ok(dt) => match dtype(dt) {
Some(dt) => dt,
None => {
bail!("unsupported 'value' data-type {dt:?} for {}", input.name)
}
},
type_ => bail!("unsupported input type {type_:?}"),
};
match &tensor_type.shape {
None => continue,
Some(shape) => {
if shape.dim.len() != tensor.rank() {
bail!(
"unexpected rank for {}, got {:?}, expected {:?}",
input.name,
shape.dim,
tensor.shape()
)
}
for (idx, (d, &dim)) in shape.dim.iter().zip(tensor.dims().iter()).enumerate() {
match &d.value {
Some(onnx::tensor_shape_proto::dimension::Value::DimValue(v)) => {
if *v as usize != dim {
bail!(
"unexpected dim {idx} for {}, got {:?}, expected {:?}",
input.name,
shape.dim,
tensor.shape()
)
}
}
// We do not check equality constraints for the DimParam dimensions for now.
Some(onnx::tensor_shape_proto::dimension::Value::DimParam(_)) | None => (),
}
}
}
};
if dt != tensor.dtype() {
bail!(
"unexpected dtype for {}, got {:?}, expected {dt:?}",
input.name,
tensor.dtype()
)
}
}
// The nodes are topologically sorted so we can just process them in order.
for node in graph.node.iter() {
let get = |input_name: &str| match values.get(input_name) {
Some(value) => Ok(value),
None => bail!("cannot find {input_name} for op {}", node.name),
};
// TODO: Validate node.input for each operator.
match node.op_type.as_str() {
"Add" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
let output = input0.broadcast_add(input1)?;
values.insert(node.output[0].clone(), output);
}
"Sub" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
let output = input0.broadcast_sub(input1)?;
values.insert(node.output[0].clone(), output);
}
"Mul" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
let output = input0.broadcast_mul(input1)?;
values.insert(node.output[0].clone(), output);
}
"Div" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
let output = input0.broadcast_div(input1)?;
values.insert(node.output[0].clone(), output);
}
"Equal" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
let output = input0.broadcast_eq(input1)?;
values.insert(node.output[0].clone(), output);
}
"Not" => {
let xs = get(&node.input[0])?;
let xs = xs.eq(&xs.zeros_like()?)?;
values.insert(node.output[0].clone(), xs);
}
"MatMul" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
let output = input0.broadcast_matmul(input1)?;
values.insert(node.output[0].clone(), output);
}
"Reshape" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?.to_vec1::<i64>()?;
// TODO: Check that there is at most a single -1 or 0, handle other neg values.
let mut other_than_minus1 = 1usize;
for &v in input1.iter() {
if v != -1 && v != 0 {
other_than_minus1 *= v as usize
}
}
let input1 = input1
.iter()
.enumerate()
.map(|(idx, &v)| match v {
-1 => Ok(input0.elem_count() / other_than_minus1),
0 => input0.dim(idx),
_ => Ok(v as usize),
})
.collect::<Result<Vec<usize>>>()?;
let output = input0.reshape(input1)?;
values.insert(node.output[0].clone(), output);
}
"LogSoftmax" => {
let input = get(&node.input[0])?;
let output = match get_attr_opt::<i64>(node, "axis")? {
None => candle_nn::ops::log_softmax(input, candle::D::Minus1)?,
Some(&axis) => {
let axis = input.normalize_axis(axis)?;
candle_nn::ops::log_softmax(input, axis)?
}
};
values.insert(node.output[0].clone(), output);
}
"Softmax" => {
let input = get(&node.input[0])?;
let output = match get_attr_opt::<i64>(node, "axis")? {
None => candle_nn::ops::softmax_last_dim(input)?,
Some(&axis) => {
let axis = input.normalize_axis(axis)?;
candle_nn::ops::softmax(input, axis)?
}
};
values.insert(node.output[0].clone(), output);
}
"Transpose" => {
let input = get(&node.input[0])?;
let output = match get_attr_opt::<[i64]>(node, "perm")? {
None => input.t()?,
Some(perm) => {
let perm = perm.iter().map(|&v| v as usize).collect::<Vec<_>>();
input.permute(perm)?
}
};
values.insert(node.output[0].clone(), output);
}
"Dropout" => {
let input = get(&node.input[0])?;
// Do not apply dropout at the moment, consider that we're only doing inference.
values.insert(node.output[0].clone(), input.clone());
}
"MaxPool" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#MaxPool
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
let kernel_shape = get_attr::<[i64]>(node, "kernel_shape")?;
let pads = get_attr_opt::<[i64]>(node, "pads")?;
let strides = get_attr_opt::<[i64]>(node, "strides")?;
let auto_pad = get_attr_opt::<str>(node, "auto_pad")?;
match auto_pad {
None | Some("NOTSET") => (),
Some(s) => bail!("unsupported auto_pad {s}"),
};
if let Some(d) = dilations {
if d.iter().any(|&v| v != 1) {
bail!("MaxPool with dilation != 1, {dilations:?}")
}
}
if let Some(d) = pads {
if d.iter().any(|&v| v != 0) {
bail!("MaxPool with pads != 0, {pads:?}")
}
}
let xs = get(&node.input[0])?;
let (k1, k2) = match kernel_shape {
[k1, k2] => (*k1 as usize, *k2 as usize),
_ => bail!("only 2d MaxPool is supported, kernel shape {kernel_shape:?}"),
};
let ys = match strides {
None => xs.max_pool2d((k1, k2))?,
Some([s1, s2]) => {
xs.max_pool2d_with_stride((k1, k2), (*s1 as usize, *s2 as usize))?
}
Some(strides) => bail!("only 2d MaxPool is supported, strides {strides:?}"),
};
values.insert(node.output[0].clone(), ys);
}
"AveragePool" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#AveragePool
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
let kernel_shape = get_attr::<[i64]>(node, "kernel_shape")?;
let pads = get_attr_opt::<[i64]>(node, "pads")?;
let strides = get_attr_opt::<[i64]>(node, "strides")?;
let auto_pad = get_attr_opt::<str>(node, "auto_pad")?;
match auto_pad {
None | Some("NOTSET") => (),
Some(s) => bail!("unsupported auto_pad {s}"),
};
if let Some(d) = dilations {
if d.iter().any(|&v| v != 1) {
bail!("AvgPool with dilation != 1, {dilations:?}")
}
}
if let Some(d) = pads {
if d.iter().any(|&v| v != 0) {
bail!("AvgPool with pads != 0, {pads:?}")
}
}
let xs = get(&node.input[0])?;
let (k1, k2) = match kernel_shape {
[k1, k2] => (*k1 as usize, *k2 as usize),
_ => bail!("only 2d AvgPool is supported, kernel shape {kernel_shape:?}"),
};
let ys = match strides {
None => xs.avg_pool2d((k1, k2))?,
Some([s1, s2]) => {
xs.avg_pool2d_with_stride((k1, k2), (*s1 as usize, *s2 as usize))?
}
Some(strides) => bail!("only 2d AvgPool is supported, strides {strides:?}"),
};
values.insert(node.output[0].clone(), ys);
}
"BatchNormalization" => {
let training_mode = get_attr_opt::<i64>(node, "training_mode")?;
if training_mode.copied().unwrap_or(0) != 0 {
bail!("training mode is not supported for BatchNorm")
}
let eps = get_attr_opt::<f32>(node, "epsilon")?
.copied()
.unwrap_or(1e-5);
let xs = get(&node.input[0])?;
let weight = get(&node.input[1])?;
let bias = get(&node.input[2])?;
let running_mean = get(&node.input[3])?;
let running_var = get(&node.input[4])?;
let target_shape: Vec<usize> = xs
.dims()
.iter()
.enumerate()
.map(|(idx, v)| if idx == 1 { *v } else { 1 })
.collect();
let target_shape = target_shape.as_slice();
let xs = xs
.broadcast_sub(&running_mean.reshape(target_shape)?)?
.broadcast_div(&(running_var.reshape(target_shape)? + eps as f64)?.sqrt()?)?;
let weight = weight.reshape(target_shape)?;
let bias = bias.reshape(target_shape)?;
let xs = xs.broadcast_mul(&weight)?.broadcast_add(&bias)?;
values.insert(node.output[0].clone(), xs);
}
"Squeeze" => {
let xs = get(&node.input[0])?;
let mut axes = if node.input.len() <= 1 {
// contract all the dimensions with size 1 except the batch dim.
xs.dims()
.iter()
.enumerate()
.flat_map(|(idx, &s)| if s == 1 && idx > 0 { Some(idx) } else { None })
.collect()
} else {
get(&node.input[1])?
.to_vec1::<i64>()?
.iter()
.map(|&i| xs.normalize_axis(i))
.collect::<Result<Vec<_>>>()?
};
axes.sort();
let mut xs = xs.clone();
for &axis in axes.iter().rev() {
xs = xs.squeeze(axis)?
}
values.insert(node.output[0].clone(), xs);
}
"ConstantOfShape" => {
let dims = get(&node.input[0])?;
let shape = dims
.to_vec1::<i64>()?
.into_iter()
.map(|v| v as usize)
.collect::<Vec<_>>();
let xs = Tensor::zeros(shape, DType::F32, dims.device())?;
values.insert(node.output[0].clone(), xs);
}
"Unsqueeze" => {
let xs = get(&node.input[0])?;
let axes = match get_attr_opt::<[i64]>(node, "axes")? {
Some(axis) => axis.to_vec(),
None => get(&node.input[1])?.to_vec1::<i64>()?,
};
let mut axes = axes
.iter()
.map(|&i| {
if i == xs.rank() as i64 {
Ok(xs.rank())
} else {
xs.normalize_axis(i)
}
})
.collect::<Result<Vec<_>>>()?;
axes.sort();
let mut xs = xs.clone();
for &axis in axes.iter().rev() {
xs = xs.unsqueeze(axis)?
}
values.insert(node.output[0].clone(), xs);
}
"Clip" => {
let xs = get(&node.input[0])?;
let xs = if node.input.len() >= 2 {
let mins = get(&node.input[1])?;
xs.broadcast_maximum(mins)?
} else {
xs.clone()
};
let xs = if node.input.len() >= 3 {
let maxs = get(&node.input[2])?;
xs.broadcast_minimum(maxs)?
} else {
xs.clone()
};
values.insert(node.output[0].clone(), xs);
}
"Gather" => {
let xs = get(&node.input[0])?;
let indices = get(&node.input[1])?;
let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
let axis = xs.normalize_axis(axis)?;
// TODO: Provide an op to handle the ONNX generalized gather op ideally in a
// differentiable way.
let xs = if indices.rank() == 0 {
let index = indices.to_vec0::<i64>()? as usize;
xs.narrow(axis, index, 1)?.squeeze(axis)?
} else {
todo!("implement gather for {xs:?} {indices:?} axis {axis}")
};
values.insert(node.output[0].clone(), xs);
}
"Shape" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape
let xs = get(&node.input[0])?;
let start = get_attr_opt::<i64>(node, "start")?.copied().unwrap_or(0);
let end = get_attr_opt::<i64>(node, "end")?.copied().unwrap_or(-1);
let start = xs.normalize_axis(start)?;
let end = xs.normalize_axis(end)?;
let mut dims = vec![];
for idx in start..=end {
dims.push(xs.dim(idx)? as i64)
}
let dims = Tensor::from_vec(dims, xs.rank(), xs.device())?;
values.insert(node.output[0].clone(), dims);
}
"Conv" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
let groups = get_attr_opt::<i64>(node, "group")?.copied().unwrap_or(1);
let _kernel_shape = get_attr_opt::<[i64]>(node, "kernel_shape")?;
let pads = get_attr_opt::<[i64]>(node, "pads")?;
let strides = get_attr_opt::<[i64]>(node, "strides")?;
let auto_pad = get_attr_opt::<str>(node, "auto_pad")?;
match auto_pad {
None | Some("NOTSET") => (),
Some(s) => bail!("unsupported auto_pad {s}"),
};
let xs = get(&node.input[0])?;
let ws = get(&node.input[1])?;
let ys = match ws.rank() {
3 => {
let (pads, xs) = match pads {
None => (0, xs.clone()),
Some([p]) => (*p as usize, xs.clone()),
Some([p1, p2]) => {
if p1 != p2 {
(0usize, xs.pad_with_zeros(2, *p1 as usize, *p2 as usize)?)
} else {
(*p1 as usize, xs.clone())
}
}
Some(pads) => {
bail!("more pads than expected in conv1d {pads:?} {}", node.name)
}
};
let strides = match strides {
None => 1,
Some([p]) => *p as usize,
Some(s) => {
bail!("more strides than expected in conv1d {s:?} {}", node.name)
}
};
let dilations = match dilations {
None => 1,
Some([p]) => *p as usize,
Some(s) => {
bail!("more dilations than expected in conv1d {s:?} {}", node.name)
}
};
xs.conv1d(ws, pads, strides, dilations, groups as usize)?
}
4 => {
let (pads, xs) = match pads {
None => (0, xs.clone()),
Some([p]) => (*p as usize, xs.clone()),
Some(&[p1, p2, p3, p4]) => {
let p1 = p1 as usize;
let p2 = p2 as usize;
let p3 = p3 as usize;
let p4 = p4 as usize;
if p1 != p2 || p1 != p3 || p1 != p4 {
(0, xs.pad_with_zeros(2, p1, p3)?.pad_with_zeros(3, p2, p4)?)
} else {
(p1, xs.clone())
}
}
Some(pads) => {
bail!("more pads than expected in conv2d {pads:?} {}", node.name)
}
};
let strides = match strides {
None => 1,
Some([p]) => *p as usize,
Some([p1, p2]) => {
if p1 != p2 {
bail!(
"strides have to be the same on both axis {pads:?} {}",
node.name
)
}
*p1 as usize
}
Some(s) => {
bail!("more strides than expected in conv2d {s:?} {}", node.name)
}
};
let dilations = match dilations {
None => 1,
Some([p]) => *p as usize,
Some([p1, p2]) => {
if p1 != p2 {
bail!(
"dilations have to be the same on both axis {pads:?} {}",
node.name
)
}
*p1 as usize
}
Some(s) => {
bail!("more dilations than expected in conv2d {s:?} {}", node.name)
}
};
xs.conv2d(ws, pads, strides, dilations, groups as usize)?
}
rank => bail!(
"unsupported rank for weight matrix {rank} in conv {}",
node.name
),
};
let ys = if node.input.len() > 2 {
let bs = get(&node.input[2])?;
let mut bs_shape = vec![1; ys.rank()];
bs_shape[1] = bs.elem_count();
ys.broadcast_add(&bs.reshape(bs_shape)?)?
} else {
ys
};
values.insert(node.output[0].clone(), ys);
}
"Concat" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Concat
let inputs = node
.input
.iter()
.map(|n| Ok(get(n.as_str())?.clone()))
.collect::<Result<Vec<Value>>>()?;
let axis: i64 = *get_attr(node, "axis")?;
if inputs.is_empty() {
bail!("empty concat")
};
let axis = inputs[0].normalize_axis(axis)?;
let output = Tensor::cat(&inputs, axis)?;
values.insert(node.output[0].clone(), output);
}
"Abs" => {
let input = get(&node.input[0])?;
let output = input.abs()?;
values.insert(node.output[0].clone(), output);
}
"Cos" => {
let input = get(&node.input[0])?;
let output = input.cos()?;
values.insert(node.output[0].clone(), output);
}
"Sin" => {
let input = get(&node.input[0])?;
let output = input.sin()?;
values.insert(node.output[0].clone(), output);
}
"Neg" => {
let input = get(&node.input[0])?;
let output = input.neg()?;
values.insert(node.output[0].clone(), output);
}
"Erf" => {
let input = get(&node.input[0])?;
let output = input.erf()?;
values.insert(node.output[0].clone(), output);
}
"Tanh" => {
let input = get(&node.input[0])?;
let output = input.tanh()?;
values.insert(node.output[0].clone(), output);
}
"Sigmoid" => {
let input = get(&node.input[0])?;
let output = candle_nn::ops::sigmoid(input)?;
values.insert(node.output[0].clone(), output);
}
"Gelu" => {
let input = get(&node.input[0])?;
let output = input.gelu_erf()?;
values.insert(node.output[0].clone(), output);
}
"Relu" => {
let input = get(&node.input[0])?;
let output = input.relu()?;
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Constant
"Constant" => {
let value = match node.attribute.iter().find(|attr| attr.name == "value") {
None => {
// TODO: support sparse_value etc.
bail!("cannot find 'value' attr in 'Constant' for {}", node.name)
}
Some(value) => value,
};
let output = match value.r#type() {
AttributeType::Tensor => {
let t = value.t.as_ref().unwrap();
get_tensor(t, &node.name)?
}
rtype => bail!("unsupported 'value' type {rtype:?} for {}", node.name),
};
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast
"Cast" => {
let input = get(&node.input[0])?;
let dt: i64 = *get_attr(node, "to")?;
let dtype = match DataType::try_from(dt as i32) {
Ok(DataType::Int32) => DType::I64,
Ok(dt) => match dtype(dt) {
Some(dt) => dt,
None => {
bail!("unsupported 'to' value {dt:?} for cast {}", node.name)
}
},
Err(_) => {
bail!("unsupported 'to' value {dt:?} for cast {}", node.name)
}
};
let output = input.to_dtype(dtype)?;
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#CumSum
"CumSum" => {
let exclusive = get_attr_opt::<i64>(node, "exclusive")?
.copied()
.unwrap_or(0);
let reverse = get_attr_opt::<i64>(node, "reverse")?.copied().unwrap_or(0);
if exclusive != 0 {
bail!("only exclusive == 0 is supported in CumSum")
}
if reverse != 0 {
bail!("only reverse == 0 is supported in CumSum")
}
let input = get(&node.input[0])?;
let axis = get(&node.input[1])?
.to_dtype(DType::U32)?
.to_vec0::<u32>()?;
let output = input.cumsum(axis as usize)?;
values.insert(node.output[0].clone(), output);
}
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
}
}
graph
.output
.iter()
.map(|output| match values.remove(&output.name) {
None => bail!("cannot find output {}", output.name),
Some(value) => Ok((output.name.clone(), value)),
})
.collect()
}

View File

@ -1,14 +0,0 @@
use candle::Result;
use prost::Message;
pub mod onnx {
include!(concat!(env!("OUT_DIR"), "/onnx.rs"));
}
pub mod eval;
pub use eval::{dtype, simple_eval};
pub fn read_file<P: AsRef<std::path::Path>>(p: P) -> Result<onnx::ModelProto> {
let buf = std::fs::read(p)?;
onnx::ModelProto::decode(buf.as_slice()).map_err(candle::Error::wrap)
}
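
To illustrate how these pieces fit together, here is a minimal sketch of loading and evaluating a model with this crate; the file name `model.onnx`, the input name `"input"`, and the input shape are hypothetical and depend on the graph being loaded:

```rust
use std::collections::HashMap;

use candle::{Device, Tensor};

fn main() -> candle::Result<()> {
    // Hypothetical path and input name; adjust to match the actual graph.
    let model = candle_onnx::read_file("model.onnx")?;
    let x = Tensor::randn(0f32, 1f32, (1, 3, 224, 224), &Device::Cpu)?;
    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert("input".to_string(), x);
    // simple_eval walks the topologically sorted nodes and returns the
    // tensors listed as graph outputs, keyed by output name.
    let outputs = candle_onnx::simple_eval(&model, inputs)?;
    for (name, value) in outputs.iter() {
        println!("{name}: {:?}", value.shape());
    }
    Ok(())
}
```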

View File

@ -1,836 +0,0 @@
//
// WARNING: This file is automatically generated! Please edit onnx.in.proto.
//
// SPDX-License-Identifier: Apache-2.0
syntax = "proto3";
package onnx;
// Overview
//
// ONNX is an open specification that is comprised of the following components:
//
// 1) A definition of an extensible computation graph model.
// 2) Definitions of standard data types.
// 3) Definitions of built-in operators.
//
// This document describes the syntax of models and their computation graphs,
// as well as the standard data types. Together, they are referred to as the ONNX
// Intermediate Representation, or 'IR' for short.
//
// The normative semantic specification of the ONNX IR is found in docs/IR.md.
// Definitions of the built-in neural network operators may be found in docs/Operators.md.
// Notes
//
// Protobuf compatibility
//
// To simplify framework compatibility, ONNX is defined using the subset of protobuf
// that is compatible with both protobuf v2 and v3. This means that we do not use any
// protobuf features that are only available in one of the two versions.
//
// Here are the most notable contortions we have to carry out to work around
// these limitations:
//
// - No 'map' (added protobuf 3.0). We instead represent mappings as lists
// of key-value pairs, where order does not matter and duplicates
// are not allowed.
// Versioning
//
// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md
//
// To be compatible with both proto2 and proto3, we will use a version number
// that is not defined by the default value but an explicit enum number.
enum Version {
// proto3 requires the first enum value to be zero.
// We add this just to appease the compiler.
_START_VERSION = 0;
// The version field is always serialized and we will use it to store the
// version that the graph is generated from. This helps us set up version
// control.
// For the IR, we are using simple numbers starting with 0x00000001,
// which was the version we published on Oct 10, 2017.
IR_VERSION_2017_10_10 = 0x0000000000000001;
// IR_VERSION 2 published on Oct 30, 2017
// - Added type discriminator to AttributeProto to support proto3 users
IR_VERSION_2017_10_30 = 0x0000000000000002;
// IR VERSION 3 published on Nov 3, 2017
// - For operator versioning:
// - Added new message OperatorSetIdProto
// - Added opset_import in ModelProto
// - For vendor extensions, added domain in NodeProto
IR_VERSION_2017_11_3 = 0x0000000000000003;
// IR VERSION 4 published on Jan 22, 2019
// - Relax constraint that initializers should be a subset of graph inputs
// - Add type BFLOAT16
IR_VERSION_2019_1_22 = 0x0000000000000004;
// IR VERSION 5 published on March 18, 2019
// - Add message TensorAnnotation.
// - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters.
IR_VERSION_2019_3_18 = 0x0000000000000005;
// IR VERSION 6 published on Sep 19, 2019
// - Add support for sparse tensor constants stored in model.
// - Add message SparseTensorProto
// - Add sparse initializers
IR_VERSION_2019_9_19 = 0x0000000000000006;
// IR VERSION 7 published on May 8, 2020
// - Add support to allow function body graph to rely on multiple external operator sets.
// - Add a list to promote inference graph's initializers to global and
// mutable variables. Global variables are visible in all graphs of the
// stored models.
// - Add message TrainingInfoProto to store initialization
// method and training algorithm. The execution of TrainingInfoProto
// can modify the values of mutable variables.
// - Implicitly add inference graph into each TrainingInfoProto's algorithm.
IR_VERSION_2020_5_8 = 0x0000000000000007;
// IR VERSION 8 published on July 30, 2021
// Introduce TypeProto.SparseTensor
// Introduce TypeProto.Optional
// Added a list of FunctionProtos local to the model
// Deprecated since_version and operator status from FunctionProto
IR_VERSION_2021_7_30 = 0x0000000000000008;
// IR VERSION 9 published on May 5, 2023
// Added AttributeProto to FunctionProto so that default attribute values can be set.
// Added FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ.
IR_VERSION = 0x0000000000000009;
}
// Attributes
//
// A named attribute containing either singular float, integer, string, graph,
// and tensor values, or repeated float, integer, string, graph, and tensor values.
// An AttributeProto MUST contain the name field, and *only one* of the
// following content fields, effectively enforcing a C/C++ union equivalent.
message AttributeProto {
reserved 12, 16 to 19;
reserved "v";
// Note: this enum is structurally identical to the OpSchema::AttrType
// enum defined in schema.h. If you rev one, you likely need to rev the other.
enum AttributeType {
UNDEFINED = 0;
FLOAT = 1;
INT = 2;
STRING = 3;
TENSOR = 4;
GRAPH = 5;
SPARSE_TENSOR = 11;
TYPE_PROTO = 13;
FLOATS = 6;
INTS = 7;
STRINGS = 8;
TENSORS = 9;
GRAPHS = 10;
SPARSE_TENSORS = 12;
TYPE_PROTOS = 14;
}
// The name field MUST be present for this version of the IR.
string name = 1; // namespace Attribute
// if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function.
// In this case, this AttributeProto does not contain data, and it's a reference of attribute
// in parent scope.
// NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph.
string ref_attr_name = 21;
// A human-readable documentation for this attribute. Markdown is allowed.
string doc_string = 13;
// The type field MUST be present for this version of the IR.
// For 0.0.1 versions of the IR, this field was not defined, and
// implementations needed to use has_field heuristics to determine
// which value field was in use. For IR_VERSION 0.0.2 or later, this
// field MUST be set and match the f|i|s|t|... field in use. This
// change was made to accommodate proto3 implementations.
AttributeType type = 20; // discriminator that indicates which field below is in use
// Exactly ONE of the following fields must be present for this version of the IR
float f = 2; // float
int64 i = 3; // int
bytes s = 4; // UTF-8 string
TensorProto t = 5; // tensor value
GraphProto g = 6; // graph
SparseTensorProto sparse_tensor = 22; // sparse tensor value
// Do not use field below, it's deprecated.
// optional ValueProto v = 12; // value - subsumes everything but graph
TypeProto tp = 14; // type proto
repeated float floats = 7; // list of floats
repeated int64 ints = 8; // list of ints
repeated bytes strings = 9; // list of UTF-8 strings
repeated TensorProto tensors = 10; // list of tensors
repeated GraphProto graphs = 11; // list of graph
repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors
repeated TypeProto type_protos = 15;// list of type protos
}
// Defines information on value, including the name, the type, and
// the shape of the value.
message ValueInfoProto {
// This field MUST be present in this version of the IR.
string name = 1; // namespace Value
// This field MUST be present in this version of the IR for
// inputs and outputs of the top-level graph.
TypeProto type = 2;
// A human-readable documentation for this value. Markdown is allowed.
string doc_string = 3;
}
// Nodes
//
// Computation graphs are made up of a DAG of nodes, which represent what is
// commonly called a "layer" or "pipeline stage" in machine learning frameworks.
//
// For example, it can be a node of type "Conv" that takes in an image, a filter
// tensor and a bias tensor, and produces the convolved output.
message NodeProto {
repeated string input = 1; // namespace Value
repeated string output = 2; // namespace Value
// An optional identifier for this node in a graph.
// This field MAY be absent in this version of the IR.
string name = 3; // namespace Node
// The symbolic identifier of the Operator to execute.
string op_type = 4; // namespace Operator
// The domain of the OperatorSet that specifies the operator named by op_type.
string domain = 7; // namespace Domain
// Additional named attributes.
repeated AttributeProto attribute = 5;
// A human-readable documentation for this node. Markdown is allowed.
string doc_string = 6;
}
// Training information
// TrainingInfoProto stores information for training a model.
// In particular, this defines two functionalities: an initialization-step
// and a training-algorithm-step. Initialization resets the model
// back to its original state as if no training has been performed.
// Training algorithm improves the model based on input data.
//
// The semantics of the initialization-step is that the initializers
// in ModelProto.graph and in TrainingInfoProto.algorithm are first
// initialized as specified by the initializers in the graph, and then
// updated by the "initialization_binding" in every instance in
// ModelProto.training_info.
//
// The field "algorithm" defines a computation graph which represents a
// training algorithm's step. After the execution of a
// TrainingInfoProto.algorithm, the initializers specified by "update_binding"
// may be immediately updated. If the targeted training algorithm contains
// consecutive update steps (such as block coordinate descent methods),
// the user needs to create a TrainingInfoProto for each step.
message TrainingInfoProto {
// This field describes a graph to compute the initial tensors
// upon starting the training process. Initialization graph has no input
// and can have multiple outputs. Usually, trainable tensors in neural
// networks are randomly initialized. To achieve that, for each tensor,
// the user can put a random number operator such as RandomNormal or
// RandomUniform in TrainingInfoProto.initialization.node and assign its
// random output to the specific tensor using "initialization_binding".
// This graph can also set the initializers in "algorithm" in the same
// TrainingInfoProto; a use case is resetting the number of training
// iteration to zero.
//
// By default, this field is an empty graph and its evaluation does not
// produce any output. Thus, no initializer would be changed by default.
GraphProto initialization = 1;
// This field represents a training algorithm step. Given required inputs,
// it computes outputs to update initializers in its own or inference graph's
// initializer lists. In general, this field contains loss node, gradient node,
// optimizer node, increment of iteration count.
//
// An execution of the training algorithm step is performed by executing the
// graph obtained by combining the inference graph (namely "ModelProto.graph")
// and the "algorithm" graph. That is, the actual
// input/initializer/output/node/value_info/sparse_initializer list of
// the training graph is the concatenation of
// "ModelProto.graph.input/initializer/output/node/value_info/sparse_initializer"
// and "algorithm.input/initializer/output/node/value_info/sparse_initializer"
// in that order. This combined graph must satisfy the normal ONNX conditions.
// Now, let's provide a visualization of graph combination for clarity.
// Let the inference graph (i.e., "ModelProto.graph") be
// tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d
// and the "algorithm" graph be
// tensor_d -> Add -> tensor_e
// The combination process results
// tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d -> Add -> tensor_e
//
// Notice that an input of a node in the "algorithm" graph may reference the
// output of a node in the inference graph (but not the other way round). Also, inference
// node cannot reference inputs of "algorithm". With these restrictions, inference graph
// can always be run independently without training information.
//
// By default, this field is an empty graph and its evaluation does not
// produce any output. Evaluating the default training step never
// updates any initializers.
GraphProto algorithm = 2;
// This field specifies the bindings from the outputs of "initialization" to
// some initializers in "ModelProto.graph.initializer" and
// the "algorithm.initializer" in the same TrainingInfoProto.
// See "update_binding" below for details.
//
// By default, this field is empty and no initializer would be changed
// by the execution of "initialization".
repeated StringStringEntryProto initialization_binding = 3;
// Gradient-based training is usually an iterative procedure. In one gradient
// descent iteration, we apply
//
// x = x - r * g
//
// where "x" is the optimized tensor, "r" stands for learning rate, and "g" is
// gradient of "x" with respect to a chosen loss. To avoid adding assignments
// into the training graph, we split the update equation into
//
// y = x - r * g
// x = y
//
// The user needs to save "y = x - r * g" into TrainingInfoProto.algorithm. To
// tell that "y" should be assigned to "x", the field "update_binding" may
// contain a key-value pair of strings, "x" (key of StringStringEntryProto)
// and "y" (value of StringStringEntryProto).
// For a neural network with multiple trainable (mutable) tensors, there can
// be multiple key-value pairs in "update_binding".
//
// The initializers that appear as keys in "update_binding" are considered
// mutable variables. This implies some behaviors
// as described below.
//
// 1. We have only unique keys in all "update_binding"s so that two
// variables may not have the same name. This ensures that one
// variable is assigned up to once.
// 2. The keys must appear in names of "ModelProto.graph.initializer" or
// "TrainingInfoProto.algorithm.initializer".
// 3. The values must be output names of "algorithm" or "ModelProto.graph.output".
// 4. Mutable variables are initialized to the value specified by the
// corresponding initializer, and then potentially updated by
// "initializer_binding"s and "update_binding"s in "TrainingInfoProto"s.
//
// This field usually contains names of trainable tensors
// (in ModelProto.graph), optimizer states such as momentums in advanced
// stochastic gradient methods (in TrainingInfoProto.graph),
// and number of training iterations (in TrainingInfoProto.graph).
//
// By default, this field is empty and no initializer would be changed
// by the execution of "algorithm".
repeated StringStringEntryProto update_binding = 4;
}
// Models
//
// ModelProto is a top-level file/container format for bundling a ML model and
// associating its computation graph with metadata.
//
// The semantics of the model are described by the associated GraphProto's.
message ModelProto {
// The version of the IR this model targets. See Version enum above.
// This field MUST be present.
int64 ir_version = 1;
// The OperatorSets this model relies on.
// All ModelProtos MUST have at least one entry that
// specifies which version of the ONNX OperatorSet is
// being imported.
//
// All nodes in the ModelProto's graph will bind against the operator
// with the same-domain/same-op_type operator with the HIGHEST version
// in the referenced operator sets.
repeated OperatorSetIdProto opset_import = 8;
// The name of the framework or tool used to generate this model.
// This field SHOULD be present to indicate which implementation/tool/framework
// emitted the model.
string producer_name = 2;
// The version of the framework or tool used to generate this model.
// This field SHOULD be present to indicate which implementation/tool/framework
// emitted the model.
string producer_version = 3;
// Domain name of the model.
// We use reverse domain names as name space indicators. For example:
// `com.facebook.fair` or `com.microsoft.cognitiveservices`
//
// Together with `model_version` and GraphProto.name, this forms the unique identity of
// the graph.
string domain = 4;
// The version of the graph encoded. See Version enum below.
int64 model_version = 5;
// A human-readable documentation for this model. Markdown is allowed.
string doc_string = 6;
// The parameterized graph that is evaluated to execute the model.
GraphProto graph = 7;
// Named metadata values; keys should be distinct.
repeated StringStringEntryProto metadata_props = 14;
// Training-specific information. Sequentially executing all stored
// `TrainingInfoProto.algorithm`s and assigning their outputs following
// the corresponding `TrainingInfoProto.update_binding`s is one training
// iteration. Similarly, to initialize the model
// (as if training hasn't happened), the user should sequentially execute
// all stored `TrainingInfoProto.initialization`s and assigns their outputs
// using `TrainingInfoProto.initialization_binding`s.
//
// If this field is empty, the training behavior of the model is undefined.
repeated TrainingInfoProto training_info = 20;
// A list of function protos local to the model.
//
// Name of the function "FunctionProto.name" should be unique within the domain "FunctionProto.domain".
// In case of any conflicts the behavior (whether the model local functions are given higher priority,
// or standard operator sets are given higher priority or this is treated as error) is defined by
// the runtimes.
//
// The operator sets imported by FunctionProto should be compatible with the ones
// imported by ModelProto and other model local FunctionProtos.
// Example, if same operator set say 'A' is imported by a FunctionProto and ModelProto
// or by 2 FunctionProtos then versions for the operator set may be different but,
// the operator schema returned for op_type, domain, version combination
// for both the versions should be same for every node in the function body.
//
// One FunctionProto can reference other FunctionProto in the model, however, recursive reference
// is not allowed.
repeated FunctionProto functions = 25;
};
// StringStringEntryProto follows the pattern for cross-proto-version maps.
// See https://developers.google.com/protocol-buffers/docs/proto3#maps
message StringStringEntryProto {
string key = 1;
string value = 2;
};
message TensorAnnotation {
string tensor_name = 1;
// <key, value> pairs to annotate tensor specified by <tensor_name> above.
// The keys used in the mapping below must be pre-defined in ONNX spec.
// For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as
// quantization parameter keys.
repeated StringStringEntryProto quant_parameter_tensor_names = 2;
}
// Graphs
//
// A graph defines the computational logic of a model and is comprised of a parameterized
// list of nodes that form a directed acyclic graph based on their inputs and outputs.
// This is the equivalent of the "network" or "graph" in many deep learning
// frameworks.
message GraphProto {
// The nodes in the graph, sorted topologically.
repeated NodeProto node = 1;
// The name of the graph.
string name = 2; // namespace Graph
// A list of named tensor values, used to specify constant inputs of the graph.
// Each initializer (both TensorProto as well as SparseTensorProto) MUST have a name.
// The name MUST be unique across both initializer and sparse_initializer,
// but the name MAY also appear in the input list.
repeated TensorProto initializer = 5;
// Initializers (see above) stored in sparse format.
repeated SparseTensorProto sparse_initializer = 15;
// A human-readable documentation for this graph. Markdown is allowed.
string doc_string = 10;
// The inputs and outputs of the graph.
repeated ValueInfoProto input = 11;
repeated ValueInfoProto output = 12;
// Information for the values in the graph. The ValueInfoProto.name's
// must be distinct. It is optional for a value to appear in value_info list.
repeated ValueInfoProto value_info = 13;
// This field carries information to indicate the mapping among a tensor and its
// quantization parameter tensors. For example:
// For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated,
// which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model.
repeated TensorAnnotation quantization_annotation = 14;
reserved 3, 4, 6 to 9;
reserved "ir_version", "producer_version", "producer_tag", "domain";
}
// Tensors
//
// A serialized tensor value.
message TensorProto {
enum DataType {
UNDEFINED = 0;
// Basic types.
FLOAT = 1; // float
UINT8 = 2; // uint8_t
INT8 = 3; // int8_t
UINT16 = 4; // uint16_t
INT16 = 5; // int16_t
INT32 = 6; // int32_t
INT64 = 7; // int64_t
STRING = 8; // string
BOOL = 9; // bool
// IEEE754 half-precision floating-point format (16 bits wide).
// This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
FLOAT16 = 10;
DOUBLE = 11;
UINT32 = 12;
UINT64 = 13;
COMPLEX64 = 14; // complex with float32 real and imaginary components
COMPLEX128 = 15; // complex with float64 real and imaginary components
// Non-IEEE floating-point format based on IEEE754 single-precision
// floating-point number truncated to 16 bits.
// This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
BFLOAT16 = 16;
// Non-IEEE floating-point format based on papers
// FP8 Formats for Deep Learning, https://arxiv.org/abs/2209.05433,
// 8-bit Numerical Formats For Deep Neural Networks, https://arxiv.org/pdf/2206.02915.pdf.
// Operators supported FP8 are Cast, CastLike, QuantizeLinear, DequantizeLinear.
// The computation usually happens inside a block quantize / dequantize
// fused by the runtime.
FLOAT8E4M3FN = 17; // float 8, mostly used for coefficients, supports nan, not inf
FLOAT8E4M3FNUZ = 18; // float 8, mostly used for coefficients, supports nan, not inf, no negative zero
FLOAT8E5M2 = 19; // follows IEEE 754, supports nan, inf, mostly used for gradients
FLOAT8E5M2FNUZ = 20; // follows IEEE 754, supports nan, inf, mostly used for gradients, no negative zero
// Future extensions go here.
}
// The shape of the tensor.
repeated int64 dims = 1;
// The data type of the tensor.
// This field MUST have a valid TensorProto.DataType value
int32 data_type = 2;
// For very large tensors, we may want to store them in chunks, in which
// case the following fields will specify the segment that is stored in
// the current TensorProto.
message Segment {
int64 begin = 1;
int64 end = 2;
}
Segment segment = 3;
// Tensor content must be organized in row-major order.
//
// Depending on the data_type field, exactly one of the fields below with
// name ending in _data is used to store the elements of the tensor.
// For float and complex64 values
// Complex64 tensors are encoded as a single array of floats,
// with the real components appearing in odd numbered positions,
// and the corresponding imaginary component appearing in the
// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
// is encoded as [1.0, 2.0 ,3.0 ,4.0]
// When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
repeated float float_data = 4 [packed = true];
// For int32, uint8, int8, uint16, int16, bool, float8, and float16 values
// float16 and float8 values must be bit-wise converted to an uint16_t prior
// to writing to the buffer.
// When this field is present, the data_type field MUST be
// INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ
repeated int32 int32_data = 5 [packed = true];
// For strings.
// Each element of string_data is a UTF-8 encoded Unicode
// string. No trailing null, no leading BOM. The protobuf "string"
// scalar type is not used to match ML community conventions.
// When this field is present, the data_type field MUST be STRING
repeated bytes string_data = 6;
// For int64.
// When this field is present, the data_type field MUST be INT64
repeated int64 int64_data = 7 [packed = true];
// Optionally, a name for the tensor.
string name = 8; // namespace Value
// A human-readable documentation for this tensor. Markdown is allowed.
string doc_string = 12;
// Serializations can either use one of the fields above, or use this
// raw bytes field. The only exception is the string case, where one is
// required to store the content in the repeated bytes string_data field.
//
// When this raw_data field is used to store tensor value, elements MUST
// be stored in as fixed-width, little-endian order.
// Floating-point data types MUST be stored in IEEE 754 format.
// Complex64 elements must be written as two consecutive FLOAT values, real component first.
// Complex128 elements must be written as two consecutive DOUBLE values, real component first.
// Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
//
// Note: the advantage of specific field rather than the raw_data field is
// that in some cases (e.g. int data), protobuf does a better packing via
// variable length storage, and may lead to smaller binary footprint.
// When this field is present, the data_type field MUST NOT be STRING or UNDEFINED
bytes raw_data = 9;
// Data can be stored inside the protobuf file using type-specific fields or raw_data.
// Alternatively, raw bytes data can be stored in an external file, using the external_data field.
// external_data stores key-value pairs describing data location. Recognized keys are:
// - "location" (required) - POSIX filesystem path relative to the directory where the ONNX
// protobuf model was stored
// - "offset" (optional) - position of byte at which stored data begins. Integer stored as string.
// Offset values SHOULD be multiples of 4096 (page size) to enable mmap support.
// - "length" (optional) - number of bytes containing data. Integer stored as string.
// - "checksum" (optional) - SHA1 digest of file specified in under 'location' key.
repeated StringStringEntryProto external_data = 13;
// Location of the data for this tensor. MUST be one of:
// - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field.
// - EXTERNAL - data stored in an external location as described by external_data field.
enum DataLocation {
DEFAULT = 0;
EXTERNAL = 1;
}
// If value not set, data is stored in raw_data (if set) otherwise in type-specified field.
DataLocation data_location = 14;
// For double
// Complex128 tensors are encoded as a single array of doubles,
// with the real components appearing in odd numbered positions,
// and the corresponding imaginary component appearing in the
// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
// is encoded as [1.0, 2.0 ,3.0 ,4.0]
// When this field is present, the data_type field MUST be DOUBLE or COMPLEX128
repeated double double_data = 10 [packed = true];
// For uint64 and uint32 values
// When this field is present, the data_type field MUST be
// UINT32 or UINT64
repeated uint64 uint64_data = 11 [packed = true];
}
// A serialized sparse-tensor value
message SparseTensorProto {
// The sequence of non-default values are encoded as a tensor of shape [NNZ].
// The default-value is zero for numeric tensors, and empty-string for string tensors.
// values must have a non-empty name present which serves as a name for SparseTensorProto
// when used in sparse_initializer list.
TensorProto values = 1;
// The indices of the non-default values, which may be stored in one of two formats.
// (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value
// corresponding to the j-th index of the i-th value (in the values tensor).
// (b) Indices can be a tensor of shape [NNZ], in which case the i-th value
// must be the linearized-index of the i-th value (in the values tensor).
// The linearized-index can be converted into an index tuple (k_1,...,k_rank)
// using the shape provided below.
// The indices must appear in ascending order without duplication.
// In the first format, the ordering is lexicographic-ordering:
// e.g., index-value [1,4] must appear before [2,1]
TensorProto indices = 2;
// The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank]
repeated int64 dims = 3;
}
// Defines a tensor shape. A dimension can be either an integer value
// or a symbolic variable. A symbolic variable represents an unknown
// dimension.
message TensorShapeProto {
message Dimension {
oneof value {
int64 dim_value = 1;
string dim_param = 2; // namespace Shape
};
// Standard denotation can optionally be used to denote tensor
// dimensions with standard semantic descriptions to ensure
// that operations are applied to the correct axis of a tensor.
// Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition
// for pre-defined dimension denotations.
string denotation = 3;
};
repeated Dimension dim = 1;
}
// Types
//
// The standard ONNX data types.
message TypeProto {
message Tensor {
// This field MUST NOT have the value of UNDEFINED
// This field MUST have a valid TensorProto.DataType value
// This field MUST be present for this version of the IR.
int32 elem_type = 1;
TensorShapeProto shape = 2;
}
// repeated T
message Sequence {
// The type and optional shape of each element of the sequence.
// This field MUST be present for this version of the IR.
TypeProto elem_type = 1;
};
// map<K,V>
message Map {
// This field MUST have a valid TensorProto.DataType value
// This field MUST be present for this version of the IR.
// This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING
int32 key_type = 1;
// This field MUST be present for this version of the IR.
TypeProto value_type = 2;
};
// wrapper for Tensor, Sequence, or Map
message Optional {
// The type and optional shape of the element wrapped.
// This field MUST be present for this version of the IR.
// Possible values correspond to OptionalProto.DataType enum
TypeProto elem_type = 1;
};
message SparseTensor {
// This field MUST NOT have the value of UNDEFINED
// This field MUST have a valid TensorProto.DataType value
// This field MUST be present for this version of the IR.
int32 elem_type = 1;
TensorShapeProto shape = 2;
}
oneof value {
// The type of a tensor.
Tensor tensor_type = 1;
// NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values
// as input and output to graphs and nodes. These types are needed to naturally
// support classical ML operators. DNN operators SHOULD restrict their input
// and output types to tensors.
// The type of a sequence.
Sequence sequence_type = 4;
// The type of a map.
Map map_type = 5;
// The type of an optional.
Optional optional_type = 9;
// Type of the sparse tensor
SparseTensor sparse_tensor_type = 8;
}
// An optional denotation can be used to denote the whole
// type with a standard semantic description as to what is
// stored inside. Refer to https://github.com/onnx/onnx/blob/main/docs/TypeDenotation.md#type-denotation-definition
// for pre-defined type denotations.
string denotation = 6;
}
// Operator Sets
//
// OperatorSets are uniquely identified by a (domain, opset_version) pair.
message OperatorSetIdProto {
// The domain of the operator set being identified.
// The empty string ("") or absence of this field implies the operator
// set that is defined as part of the ONNX specification.
// This field MUST be present in this version of the IR when referring to any other operator set.
string domain = 1;
// The version of the operator set being identified.
// This field MUST be present in this version of the IR.
int64 version = 2;
}
// Operator/function status.
enum OperatorStatus {
EXPERIMENTAL = 0;
STABLE = 1;
}
message FunctionProto {
// The name of the function, similar usage of op_type in OperatorProto.
// Combined with FunctionProto.domain, this forms the unique identity of
// the FunctionProto.
string name = 1;
// Deprecated since IR Version 8
// optional int64 since_version = 2;
reserved 2;
reserved "since_version";
// Deprecated since IR Version 8
// optional OperatorStatus status = 3;
reserved 3;
reserved "status";
// The inputs and outputs of the function.
repeated string input = 4;
repeated string output = 5;
// The attribute parameters of the function.
// It is for function parameters without default values.
repeated string attribute = 6;
// The attribute protos of the function.
// It is for function attributes with default values.
// A function attribute shall be represented either as
// a string attribute or an AttributeProto, not both.
repeated AttributeProto attribute_proto = 11;
// The nodes in the function.
repeated NodeProto node = 7;
// A human-readable documentation for this function. Markdown is allowed.
string doc_string = 8;
// The OperatorSets this function body (graph) relies on.
//
// All nodes in the function body (graph) will bind against the operator
// with the same-domain/same-op_type operator with the HIGHEST version
// in the referenced operator sets. This means at most one version can be relied
// for one domain.
//
// The operator sets imported by FunctionProto should be compatible with the ones
// imported by ModelProto. Example, if same operator set say 'A' is imported by FunctionProto
// and ModelProto then versions for the operator set may be different but,
// the operator schema returned for op_type, domain, version combination
// for both the versions should be same.
repeated OperatorSetIdProto opset_import = 9;
// The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of
// the FunctionProto.
string domain = 10;
}
// For using protobuf-lite
option optimize_for = LITE_RUNTIME;

View File

@ -1,746 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle::{Device, Result, Tensor};
use candle_onnx::onnx::{GraphProto, ModelProto, NodeProto, ValueInfoProto};
use std::collections::HashMap;
const INPUT_X: &str = "x";
const INPUT_Y: &str = "y";
const OUTPUT_Z: &str = "z";
fn create_model_proto_with_graph(graph: Option<GraphProto>) -> ModelProto {
ModelProto {
metadata_props: vec![],
training_info: vec![],
functions: vec![],
ir_version: 0,
opset_import: vec![],
producer_name: "".to_string(),
producer_version: "".to_string(),
domain: "".to_string(),
model_version: 0,
doc_string: "".to_string(),
graph,
}
}
#[test]
fn test_evaluation_fails_without_defined_graph() -> Result<()> {
let manual_graph = create_model_proto_with_graph(None);
let inputs: HashMap<String, Tensor> = HashMap::new();
match candle_onnx::simple_eval(&manual_graph, inputs) {
Err(err) => assert_eq!(err.to_string(), "no graph defined in proto"),
Ok(_) => panic!("Expected an error due to undefined graph"),
}
Ok(())
}
// "Add"
#[test]
fn test_add_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Add".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?);
inputs.insert(INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let first = z
.to_vec1::<f64>()?
.to_vec()
.get(0)
.expect("Failed to get first element")
.clone();
assert_eq!(first, 4.0f64);
Ok(())
}
// "Sub"
#[test]
fn test_sub_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Sub".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?);
inputs.insert(INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let first = z
.to_vec1::<f64>()?
.to_vec()
.get(0)
.expect("Failed to get first element")
.clone();
assert_eq!(first, 0.0f64);
Ok(())
}
// "Mul"
#[test]
fn test_mul_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Mul".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?);
inputs.insert(INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let first = z
.to_vec1::<f64>()?
.to_vec()
.get(0)
.expect("Failed to get first element")
.clone();
assert_eq!(first, 4.0f64);
Ok(())
}
// "Div"
#[test]
fn test_div_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Div".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?);
inputs.insert(INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let first = z
.to_vec1::<f64>()?
.to_vec()
.get(0)
.expect("Failed to get first element")
.clone();
assert_eq!(first, 1.0f64);
Ok(())
}
// "Equal"
#[test]
fn test_equal_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Equal".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?);
inputs.insert(INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let first = z.to_dtype(candle::DType::U8)?.to_vec1::<u8>()?.to_vec()[0];
assert_eq!(first, 1);
Ok(())
}
// "Not"
#[test]
fn test_not_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Not".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), Tensor::new(&[0.], &Device::Cpu)?);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let first = z.to_dtype(candle::DType::U8)?.to_vec1::<u8>()?.to_vec()[0];
assert_eq!(first, 1);
Ok(())
}
// "MatMul"
#[test]
fn test_matmul_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "MatMul".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(
INPUT_X.to_string(),
Tensor::from_vec(
//
vec![1.0f32, 2.0f32, 3.0f32, 4.0f32],
&[2, 2],
&Device::Cpu,
)?,
);
inputs.insert(
INPUT_Y.to_string(),
Tensor::from_vec(
//
vec![5.0f32, 6.0f32, 7.0f32, 8.0f32],
&[2, 2],
&Device::Cpu,
)?,
);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
assert_eq!(results, vec![vec![19.0, 22.0], vec![43.0, 50.0]]);
Ok(())
}
// "Reshape"
#[test]
fn test_reshape_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Reshape".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
//
vec![1.0f32, 2.0f32, 3.0f32, 4.0f32],
&[2, 2],
&Device::Cpu,
)?;
let y = Tensor::from_vec(
//
vec![4i64],
&[1],
&Device::Cpu,
)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
inputs.insert(INPUT_Y.to_string(), y);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec1::<f32>()?;
assert_eq!(results, vec![1.0, 2.0, 3.0, 4.0]);
Ok(())
}
// "LogSoftmax"
#[test]
fn test_logsoftmax_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "LogSoftmax".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
//
vec![1.0f32, 2.0f32, 3.0f32, 4.0f32],
&[2, 2],
&Device::Cpu,
)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
assert_eq!(
results,
vec![vec![0.26894143, 0.7310586], vec![0.26894143, 0.7310586]]
);
Ok(())
}
// "Softmax"
#[test]
fn test_softmax_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Softmax".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
//
vec![1.0f32, 2.0f32, 3.0f32, 4.0f32],
&[2, 2],
&Device::Cpu,
)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
assert_eq!(
results,
vec![vec![0.26894143, 0.7310586], vec![0.26894143, 0.7310586]]
);
Ok(())
}
// "Transpose"
#[test]
fn test_transpose_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Transpose".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
//
vec![1.0f32, 2.0f32, 3.0f32, 4.0f32],
&[2, 2],
&Device::Cpu,
)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
assert_eq!(results, vec![vec![1.0, 3.0], vec![2.0, 4.0]]);
Ok(())
}
// "Dropout"
#[test]
fn test_dropout_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Dropout".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
//
vec![1.0f32, 2.0f32, 3.0f32, 4.0f32],
&[2, 2],
&Device::Cpu,
)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
assert_eq!(results, vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
Ok(())
}
// Below are ops that are implemented but not tested yet (see the sketch after this list)
// "MaxPool"
// #[test]
// "AveragePool"
// #[test]
// "BatchNormalization"
// #[test]
// "Squeeze"
// #[test]
// "ConstantOfShape"
// #[test]
// "Unsqueeze"
// #[test]
// "Clip"
// #[test]
// "Gather"
// #[test]
// "Shape"
// #[test]
// "Conv"
// #[test]
// "Concat"
// #[test]
// "Abs"
// #[test]
// "Cos"
// #[test]
// "Sin"
// #[test]
// "Neg"
// #[test]
// "Erf"
// #[test]
// "Tanh"
// #[test]
// "Sigmoid"
// #[test]
// "Gelu"
// #[test]
// "Relu"
// #[test]
// "Constant"
// #[test]
// "Cast"
// #[test]
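// A minimal sketch of what a test for one of the ops listed above could look like, following the
// same pattern as the tests in this file. It assumes "Abs" is handled by simple_eval, as the note
// above states; the expected values are simply |x| applied element-wise.
#[test]
fn test_abs_operation() -> Result<()> {
    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: "Abs".to_string(),
            domain: "".to_string(),
            attribute: vec![],
            input: vec![INPUT_X.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            name: "".to_string(),
            doc_string: "".to_string(),
        }],
        name: "".to_string(),
        initializer: vec![],
        input: vec![],
        output: vec![ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            doc_string: "".to_string(),
            r#type: None,
        }],
        value_info: vec![],
        doc_string: "".to_string(),
        sparse_initializer: vec![],
        quantization_annotation: vec![],
    }));
    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), Tensor::new(&[-1., 2., -3., 4.], &Device::Cpu)?);
    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);
    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
    assert_eq!(z.to_vec1::<f64>()?, vec![1.0, 2.0, 3.0, 4.0]);
    Ok(())
}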


@@ -15,9 +15,8 @@ crate-type = ["cdylib"]
[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.1" }
candle-onnx = {path= "../candle-onnx", version = "0.3.1", optional = true}
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
half = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] }
@@ -30,5 +29,3 @@ default = []
accelerate = ["dep:accelerate-src", "candle/accelerate"]
cuda = ["candle/cuda"]
mkl = ["dep:intel-mkl-src","candle/mkl"]
onnx = ["dep:candle-onnx"]


@@ -1,5 +0,0 @@
# Generated content DO NOT EDIT
from .. import onnx
ONNXModel = onnx.ONNXModel
ONNXTensorDescription = onnx.ONNXTensorDescription


@@ -1,89 +0,0 @@
# Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
from candle.typing import _ArrayLike, Device, Scalar, Index, Shape
from candle import Tensor, DType, QTensor
class ONNXModel:
"""
A wrapper around an ONNX model.
"""
def __init__(self, path: str):
pass
@property
def doc_string(self) -> str:
"""
The doc string of the model.
"""
pass
@property
def domain(self) -> str:
"""
The domain of the operator set of the model.
"""
pass
def initializers(self) -> Dict[str, Tensor]:
"""
Get the weights of the model.
"""
pass
@property
def inputs(self) -> Optional[Dict[str, ONNXTensorDescription]]:
"""
The inputs of the model.
"""
pass
@property
def ir_version(self) -> int:
"""
The version of the IR this model targets.
"""
pass
@property
def model_version(self) -> int:
"""
The version of the model.
"""
pass
@property
def outputs(self) -> Optional[Dict[str, ONNXTensorDescription]]:
"""
The outputs of the model.
"""
pass
@property
def producer_name(self) -> str:
"""
The producer of the model.
"""
pass
@property
def producer_version(self) -> str:
"""
The version of the producer of the model.
"""
pass
def run(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""
Run the model on the given inputs.
"""
pass
class ONNXTensorDescription:
"""
A wrapper around an ONNX tensor description.
"""
@property
def dtype(self) -> DType:
"""
The data type of the tensor.
"""
pass
@property
def shape(self) -> Tuple[Union[int, str, Any]]:
"""
The shape of the tensor.
"""
pass


@@ -17,16 +17,14 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use ::candle::{quantized::QTensor, DType, Device, Module, Tensor, WithDType};
mod utils;
use utils::wrap_err;
use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};
mod shape;
use shape::{PyShape, PyShapeWithHole};
#[cfg(feature = "onnx")]
mod onnx;
pub fn wrap_err(err: ::candle::Error) -> PyErr {
PyErr::new::<PyValueError, _>(format!("{err:?}"))
}
#[derive(Clone, Debug)]
#[pyclass(name = "Tensor")]
@@ -71,13 +69,11 @@ impl PyDType {
}
static CUDA_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);
static METAL_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum PyDevice {
Cpu,
Cuda,
Metal,
}
impl PyDevice {
@@ -85,7 +81,7 @@ impl PyDevice {
match device {
Device::Cpu => Self::Cpu,
Device::Cuda(_) => Self::Cuda,
Device::Metal(_) => Self::Metal,
Device::Metal(_) => unimplemented!(),
}
}
@@ -101,15 +97,6 @@ impl PyDevice {
*device = Some(d.clone());
Ok(d)
}
Self::Metal => {
let mut device = METAL_DEVICE.lock().unwrap();
if let Some(device) = device.as_ref() {
return Ok(device.clone());
};
let d = Device::new_metal(0).map_err(wrap_err)?;
*device = Some(d.clone());
Ok(d)
}
}
}
}
@@ -131,7 +118,6 @@ impl ToPyObject for PyDevice {
let str = match self {
PyDevice::Cpu => "cpu",
PyDevice::Cuda => "cuda",
PyDevice::Metal => "metal",
};
str.to_object(py)
}
@@ -1574,14 +1560,6 @@ fn candle_functional_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
Ok(())
}
#[cfg(feature = "onnx")]
fn candle_onnx_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
use onnx::{PyONNXModel, PyONNXTensorDescriptor};
m.add_class::<PyONNXModel>()?;
m.add_class::<PyONNXTensorDescriptor>()?;
Ok(())
}
#[pymodule]
fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
let utils = PyModule::new(py, "utils")?;
@@ -1590,12 +1568,6 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
let nn = PyModule::new(py, "functional")?;
candle_functional_m(py, nn)?;
m.add_submodule(nn)?;
#[cfg(feature = "onnx")]
{
let onnx = PyModule::new(py, "onnx")?;
candle_onnx_m(py, onnx)?;
m.add_submodule(onnx)?;
}
m.add_class::<PyTensor>()?;
m.add_class::<PyQTensor>()?;
m.add_class::<PyDType>()?;


@@ -1,212 +0,0 @@
use std::collections::HashMap;
use crate::utils::wrap_err;
use crate::{PyDType, PyTensor};
use candle_onnx::eval::{dtype, get_tensor, simple_eval};
use candle_onnx::onnx::tensor_proto::DataType;
use candle_onnx::onnx::tensor_shape_proto::dimension::Value;
use candle_onnx::onnx::type_proto::{Tensor as ONNXTensor, Value as ONNXValue};
use candle_onnx::onnx::{ModelProto, ValueInfoProto};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::{PyList, PyTuple};
#[derive(Clone, Debug)]
#[pyclass(name = "ONNXTensorDescription")]
/// A wrapper around an ONNX tensor description.
pub struct PyONNXTensorDescriptor(ONNXTensor);
#[pymethods]
impl PyONNXTensorDescriptor {
#[getter]
/// The data type of the tensor.
/// &RETURNS&: DType
fn dtype(&self) -> PyResult<PyDType> {
match DataType::try_from(self.0.elem_type) {
Ok(dt) => match dtype(dt) {
Some(dt) => Ok(PyDType(dt)),
None => Err(PyValueError::new_err(format!(
"unsupported 'value' data-type {dt:?}"
))),
},
type_ => Err(PyValueError::new_err(format!(
"unsupported input type {type_:?}"
))),
}
}
#[getter]
/// The shape of the tensor.
/// &RETURNS&: Tuple[Union[int,str,Any]]
fn shape(&self, py: Python) -> PyResult<Py<PyTuple>> {
let shape = PyList::empty(py);
if let Some(d) = &self.0.shape {
for dim in d.dim.iter() {
if let Some(value) = &dim.value {
match value {
Value::DimValue(v) => shape.append(*v)?,
Value::DimParam(s) => shape.append(s.clone())?,
};
} else {
return Err(PyValueError::new_err("None value in shape"));
}
}
}
Ok(shape.to_tuple().into())
}
fn __repr__(&self, py: Python) -> String {
match (self.shape(py), self.dtype()) {
(Ok(shape), Ok(dtype)) => format!(
"TensorDescriptor[shape: {:?}, dtype: {:?}]",
shape.to_string(),
dtype.__str__()
),
(Err(_), Err(_)) => "TensorDescriptor[shape: unknown, dtype: unknown]".to_string(),
(Err(_), Ok(dtype)) => format!(
"TensorDescriptor[shape: unknown, dtype: {:?}]",
dtype.__str__()
),
(Ok(shape), Err(_)) => format!(
"TensorDescriptor[shape: {:?}, dtype: unknown]",
shape.to_string()
),
}
}
fn __str__(&self, py: Python) -> String {
self.__repr__(py)
}
}
#[derive(Clone, Debug)]
#[pyclass(name = "ONNXModel")]
/// A wrapper around an ONNX model.
pub struct PyONNXModel(ModelProto);
fn extract_tensor_descriptions(
value_infos: &[ValueInfoProto],
) -> HashMap<String, PyONNXTensorDescriptor> {
let mut map = HashMap::new();
for value_info in value_infos.iter() {
let input_type = match &value_info.r#type {
Some(input_type) => input_type,
None => continue,
};
let input_type = match &input_type.value {
Some(input_type) => input_type,
None => continue,
};
let tensor_type: &ONNXTensor = match input_type {
ONNXValue::TensorType(tt) => tt,
_ => continue,
};
map.insert(
value_info.name.to_string(),
PyONNXTensorDescriptor(tensor_type.clone()),
);
}
map
}
#[pymethods]
impl PyONNXModel {
#[new]
#[pyo3(text_signature = "(self, path:str)")]
/// Load an ONNX model from the given path.
fn new(path: String) -> PyResult<Self> {
let model: ModelProto = candle_onnx::read_file(path).map_err(wrap_err)?;
Ok(PyONNXModel(model))
}
#[getter]
/// The version of the IR this model targets.
/// &RETURNS&: int
fn ir_version(&self) -> i64 {
self.0.ir_version
}
#[getter]
/// The producer of the model.
/// &RETURNS&: str
fn producer_name(&self) -> String {
self.0.producer_name.clone()
}
#[getter]
/// The version of the producer of the model.
/// &RETURNS&: str
fn producer_version(&self) -> String {
self.0.producer_version.clone()
}
#[getter]
/// The domain of the operator set of the model.
/// &RETURNS&: str
fn domain(&self) -> String {
self.0.domain.clone()
}
#[getter]
/// The version of the model.
/// &RETURNS&: int
fn model_version(&self) -> i64 {
self.0.model_version
}
#[getter]
/// The doc string of the model.
/// &RETURNS&: str
fn doc_string(&self) -> String {
self.0.doc_string.clone()
}
/// Get the weights of the model.
/// &RETURNS&: Dict[str, Tensor]
fn initializers(&self) -> PyResult<HashMap<String, PyTensor>> {
let mut map = HashMap::new();
if let Some(graph) = self.0.graph.as_ref() {
for tensor_description in graph.initializer.iter() {
let tensor = get_tensor(tensor_description, tensor_description.name.as_str())
.map_err(wrap_err)?;
map.insert(tensor_description.name.to_string(), PyTensor(tensor));
}
}
Ok(map)
}
#[getter]
/// The inputs of the model.
/// &RETURNS&: Optional[Dict[str, ONNXTensorDescription]]
fn inputs(&self) -> Option<HashMap<String, PyONNXTensorDescriptor>> {
if let Some(graph) = self.0.graph.as_ref() {
return Some(extract_tensor_descriptions(&graph.input));
}
None
}
#[getter]
/// The outputs of the model.
/// &RETURNS&: Optional[Dict[str, ONNXTensorDescription]]
fn outputs(&self) -> Option<HashMap<String, PyONNXTensorDescriptor>> {
if let Some(graph) = self.0.graph.as_ref() {
return Some(extract_tensor_descriptions(&graph.output));
}
None
}
#[pyo3(text_signature = "(self, inputs:Dict[str,Tensor])")]
/// Run the model on the given inputs.
/// &RETURNS&: Dict[str,Tensor]
fn run(&self, inputs: HashMap<String, PyTensor>) -> PyResult<HashMap<String, PyTensor>> {
let unwrapped_tensors = inputs.into_iter().map(|(k, v)| (k.clone(), v.0)).collect();
let result = simple_eval(&self.0, unwrapped_tensors).map_err(wrap_err)?;
Ok(result
.into_iter()
.map(|(k, v)| (k.clone(), PyTensor(v)))
.collect())
}
}
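// A minimal, hypothetical sketch (not part of this diff) of the plain-Rust flow that PyONNXModel
// wraps: candle_onnx::read_file loads the ModelProto and simple_eval runs it on a map of named
// input tensors. The path "model.onnx" and the input name "x" are placeholders.
#[allow(dead_code)]
fn run_model_sketch() -> ::candle::Result<()> {
    let model = candle_onnx::read_file("model.onnx")?;
    let mut inputs: HashMap<String, ::candle::Tensor> = HashMap::new();
    inputs.insert(
        "x".to_string(),
        ::candle::Tensor::new(&[1.0f32, 2.0, 3.0], &::candle::Device::Cpu)?,
    );
    let outputs = simple_eval(&model, inputs)?;
    for (name, tensor) in outputs.iter() {
        // Print each named output tensor (shape/dtype via Debug).
        println!("{name}: {tensor:?}");
    }
    Ok(())
}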


@ -1,6 +0,0 @@
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
pub fn wrap_err(err: ::candle::Error) -> PyErr {
PyErr::new::<PyValueError, _>(format!("{err:?}"))
}


@@ -12,16 +12,15 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.1", optional = true }
candle-nn = { path = "../candle-nn", version = "0.3.1" }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
rand = { workspace = true }
rayon = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
serde_plain = { workspace = true }
tracing = { workspace = true }
wav = { workspace = true }
@@ -29,5 +28,6 @@ wav = { workspace = true }
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda"]
metal = ["candle/metal", "candle-nn/metal"]
flash-attn = ["cuda", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"]


@@ -1,4 +1,3 @@
use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
use candle::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module, VarBuilder};
use serde::Deserialize;
@@ -33,6 +32,76 @@ impl HiddenActLayer {
}
}
#[derive(Debug)]
pub struct Linear {
weight: Tensor,
bias: Option<Tensor>,
span: tracing::Span,
}
impl Linear {
pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
let span = tracing::span!(tracing::Level::TRACE, "linear");
Self { weight, bias, span }
}
}
impl Module for Linear {
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
let _enter = self.span.enter();
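        // For a 3-d (batch, seq, in_features) input, broadcast the (out, in) weight across the
        // batch and transpose it so the matmul below yields (batch, seq, out_features); 2-d
        // inputs just use the plain transpose.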
let w = match x.dims() {
&[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
_ => self.weight.t()?,
};
let x = x.matmul(&w)?;
match &self.bias {
None => Ok(x),
Some(bias) => x.broadcast_add(bias),
}
}
}
#[derive(Debug)]
pub struct LayerNorm {
weight: Tensor,
bias: Tensor,
eps: f64,
span: tracing::Span,
}
impl LayerNorm {
pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
Self {
weight,
bias,
eps,
span,
}
}
}
impl Module for LayerNorm {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let x_dtype = x.dtype();
let internal_dtype = match x_dtype {
DType::F16 | DType::BF16 => DType::F32,
d => d,
};
let (_bsize, _seq_len, hidden_size) = x.dims3()?;
let x = x.to_dtype(internal_dtype)?;
let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
let x = x.broadcast_sub(&mean_x)?;
let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
let x = x_normed
.to_dtype(x_dtype)?
.broadcast_mul(&self.weight)?
.broadcast_add(&self.bias)?;
Ok(x)
}
}
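// In math terms, the forward pass above computes, for each hidden vector x of size H:
//   mu      = (1/H) * sum_i x_i
//   sigma^2 = (1/H) * sum_i (x_i - mu)^2
//   y_i     = weight_i * (x_i - mu) / sqrt(sigma^2 + eps) + bias_i
// with the reductions done in f32 when the input is f16/bf16, then cast back to the input dtype.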
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
enum PositionEmbeddingType {
@@ -115,6 +184,12 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Em
Ok(Embedding::new(embeddings, hidden_size))
}
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?;
let bias = vb.get(size2, "bias")?;
Ok(Linear::new(weight, Some(bias)))
}
struct Dropout {
#[allow(dead_code)]
pr: f64,
@@ -133,6 +208,20 @@ impl Module for Dropout {
}
}
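// Some checkpoints store the layer-norm parameters under "gamma"/"beta" rather than
// "weight"/"bias", so fall back to those names when the first lookup fails.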
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
(Ok(weight), Ok(bias)) => (weight, bias),
(Err(err), _) | (_, Err(err)) => {
if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
(weight, bias)
} else {
return Err(err);
}
}
};
Ok(LayerNorm::new(weight, bias, eps))
}
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180
struct BertEmbeddings {
word_embeddings: Embedding,

Some files were not shown because too many files have changed in this diff.