Compare commits

..

28 Commits

Author SHA1 Message Date
eb24875856 Reworked affine and it works ? No idea how it's different. 2023-11-08 02:37:20 +01:00
3f662e54cd Reworked affine and it works ? No idea how it's different. 2023-11-08 02:34:08 +01:00
480a3e22e6 Adding cast + binary kernels. 2023-11-07 23:45:53 +01:00
0c24a885a6 Updated everything and output a trace. 2023-11-07 21:12:42 +01:00
76d3116f5d Broken metal ? 2023-11-07 14:20:13 +01:00
1367e0278b pesky bfloat type 2023-11-07 10:26:59 +01:00
7ff17d92b3 Finished the unary
- Added proper kernel type check (through modules + macro)
- split contiguous and strided into 2 different kernels
- Verified on long range + strided values.
2023-11-06 23:12:12 +01:00
cd68c96803 Going overbounds will break other kernels running from other threads. 2023-11-06 17:29:58 +01:00
4d87305c48 Float -> half / bfloat conversion in unary 2023-11-06 17:09:39 +01:00
677495f9b8 Working but failing tests because of threadgroup. 2023-11-06 17:04:47 +01:00
dedc8c3656 Writing unary as macro instead, protecting bfloat type with proper metal version. 2023-11-06 15:36:48 +01:00
63cce76b84 Improve metal kernel loading and associated errors 2023-11-06 09:48:18 +01:00
634a4e7168 BlitEncoder added to affine for copying buffer contents quickly. 2023-11-06 08:23:36 +01:00
8124d1003f Affine metal kernel works. Need to extract buffer contents based on layout offset (like CudaSlice.slice) for candle intergration 2023-11-06 04:46:56 +01:00
6d4c8c0707 Use metal encode_gemm 2023-11-06 03:27:22 +01:00
e6d33a8efb Remove unused utils.metal 2023-11-06 03:26:21 +01:00
c921cc3784 Add Arc to metalstorage buffer for quick cloning 2023-11-04 09:03:23 +01:00
d4d6850c78 Impl index_add via template for all types 2023-11-04 08:46:08 +01:00
e708d35e7f index_add works 2023-11-03 21:12:52 +01:00
0794e70a19 Debugging index_add. 2023-11-03 12:09:05 +01:00
f57e3164ae Implemented cos for now. 2023-11-03 01:24:51 +01:00
7161002a34 Finished scaffolding, lots of TODOs
- Most kernels just copy themselfs to get the shapes correct
- Matmul works only in 1 case and simply empty allocates otherwise
- Logits and randomized to make the demo finish itself.

Performance is quite bad (30ms/token), but lot's of prints and allocs and some actual sending to metal.

Couln't get it super high by removing the obvious blockers (println + the actual running matmuls).

Allocations takes between 1us and 100us and seems very stable, Maybe metal doesn't really have a smart allocator and we'll need to own it.
2023-11-02 15:32:28 +01:00
82cce52e73 Rename candle-metal -> candle-metal-kernels 2023-11-02 09:53:29 +01:00
71fcb31873 Owned command buffer now. 2023-11-01 18:03:53 +01:00
198009453a Matmul (no batch, no strided, f32, f32 only) sort of done. 2023-11-01 17:36:51 +01:00
492d164235 More scaffolding, now need to implement matmul (for precompute_cos_sin to work). 2023-11-01 16:54:09 +01:00
2d84c16fed First pass (Quantized scaffolding work done + quantized example scaffolding). 2023-11-01 15:10:11 +01:00
4525b7b52a Initial setup 2023-10-31 18:09:10 +01:00
104 changed files with 1315 additions and 7366 deletions

View File

@ -59,7 +59,7 @@ jobs:
- name: Install Rust Stable
run: curl https://sh.rustup.rs -sSf | sh -s -- -y
- uses: Swatinem/rust-cache@v2
- run: apt-get update -y && apt-get install libssl-dev protobuf-compiler -y
- run: apt-get update -y && apt-get install libssl-dev -y
- name: Test (cuda)
run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda
stop-runner:

Binary file not shown.

View File

@ -39,12 +39,6 @@ jobs:
path: ~/.cargo/registry
key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }}
- name: Install Protoc
uses: arduino/setup-protoc@v2
with:
version: "25.0"
repo-token: ${{ secrets.GITHUB_TOKEN }}
- name: Install
working-directory: ./candle-pyo3
run: |
@ -52,7 +46,7 @@ jobs:
source .env/bin/activate
pip install -U pip
pip install pytest maturin black
python -m maturin develop -r --features onnx
python -m maturin develop -r
- name: Check style
working-directory: ./candle-pyo3

View File

@ -10,12 +10,7 @@ members = [
"candle-wasm-examples/*",
"candle-wasm-tests",
]
exclude = [
"candle-flash-attn",
"candle-kernels",
"candle-metal-kernels",
"candle-onnx",
]
exclude = ["candle-flash-attn", "candle-kernels"]
resolver = "2"
[workspace.package]
@ -51,7 +46,6 @@ rayon = "1.7.0"
rusttype = { version = "0.9", default-features = false }
safetensors = "0.3.1"
serde = { version = "1.0.171", features = ["derive"] }
serde_plain = "1.0.2"
serde_json = "1.0.99"
thiserror = "1"
tokenizers = { version = "0.13.4", default-features = false }

View File

@ -51,7 +51,7 @@ For more advanced examples, please have a look at the following section.
These online demos run entirely in your browser:
- [yolo](https://huggingface.co/spaces/lmz/candle-yolo): pose estimation and
object recognition.
- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): speech recognition.
- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): text to speech.
- [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2): text generation.
- [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm): text generation.
- [Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation.
@ -69,8 +69,6 @@ We also provide a some command line based examples using state of the art models
performance larger than all publicly available 13b models as of 2023-09-28.
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
(English/Chinese) general LLMs with 6b and 34b parameters.
- [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of
the LLaMA model using the same quantization techniques as
[llama.cpp](https://github.com/ggerganov/llama.cpp).
@ -145,11 +143,6 @@ And then head over to
including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): a LoRA implementation
that conforms to the official `peft` implementation.
- [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
serving local LLMs including an OpenAI compatible API server.
- [`candle-ext`](https://github.com/mokeyish/candle-ext): an extension library to Candle that provides PyTorch functions not currently available in Candle.
- [`kalosm`](https://github.com/floneum/floneum/tree/master/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
If you have an addition to this list, please submit a pull request.
@ -175,17 +168,16 @@ If you have an addition to this list, please submit a pull request.
- Mistral 7b v0.1.
- StableLM-3B-4E1T.
- Replit-code-v1.5-3B.
- T5.
- Bert.
- Yi-6B and Yi-34B.
- Text to text.
- T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
- Marian MT (Machine Translation).
- Whisper (multi-lingual support).
- Text to image.
- Stable Diffusion v1.5, v2.1, XL v1.0.
- Wurstchen v2.
- Image to text.
- BLIP.
- Text to text.
- Marian MT (Machine Translation).
- Computer Vision Models.
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT.
- yolo-v3, yolo-v8.
@ -226,7 +218,6 @@ Cheatsheet:
- [candle-datasets](./candle-datasets/): Datasets and data loaders.
- [candle-transformers](./candle-transformers): transformers-related utilities.
- [candle-flash-attn](./candle-flash-attn): Flash attention v2 layer.
- [candle-onnx](./candle-onnx/): ONNX model evaluation.
## FAQ

View File

@ -12,6 +12,7 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
tracing = { workspace = true }
candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
metal = { workspace = true, optional = true}
@ -41,4 +42,4 @@ cuda = ["cudarc", "dep:candle-kernels"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal", "dep:candle-metal-kernels"]
metal = ["dep:candle-metal-kernels", "dep:metal"]

View File

@ -39,14 +39,6 @@ pub trait BackendStorage: Sized {
_params: &crate::conv::ParamsConv1D,
) -> Result<Self>;
fn conv_transpose1d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self>;
fn conv2d(
&self,
_l: &Layout,

View File

@ -15,17 +15,6 @@ fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result
}
}
thread_local! {
static CANDLE_GRAD_DO_NOT_DETACH: bool = {
match std::env::var("CANDLE_GRAD_DO_NOT_DETACH") {
Ok(s) => {
!s.is_empty() && s != "0"
},
Err(_) => false,
}
}
}
impl Tensor {
/// Return all the nodes that lead to this value in a topologically sorted vec, the first
/// elements having dependencies on the latter ones, e.g. the first element if any is the
@ -68,11 +57,6 @@ impl Tensor {
kernel: rhs,
..
}
| Op::ConvTranspose1D {
arg: lhs,
kernel: rhs,
..
}
| Op::Conv2D {
arg: lhs,
kernel: rhs,
@ -166,16 +150,10 @@ impl Tensor {
if node.is_variable() {
continue;
}
let grad = grads
.remove(node)
.expect("candle internal error - grad not populated");
// https://github.com/huggingface/candle/issues/1241
// Ideally, we would make these operations in place where possible to ensure that we
// do not have to allocate too often. Here we just call `.detach` to avoid computing
// the backprop graph of the backprop itself. This would be an issue for second order
// derivatives but these are out of scope at the moment.
let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
let grad = if do_not_detach { grad } else { grad.detach()? };
let grad = grads.remove(node).unwrap();
// TODO: We should perform all these operations in place (or at least not track the
// whole graph). The only drawback would be if we wanted to support grad of grad but
// this is out of scope.
if let Some(op) = node.op() {
match op {
Op::Binary(lhs, rhs, BinaryOp::Add) => {
@ -230,44 +208,7 @@ impl Tensor {
let f_grad = pred.where_cond(&zeros, &grad)?;
*f_sum_grad = f_sum_grad.add(&f_grad)?;
}
Op::Conv1D {
arg,
kernel,
padding,
stride,
dilation,
} => {
// The output height for conv_transpose1d is:
// (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1
let grad_l_in = grad.dim(2)?;
let k_size = kernel.dim(2)?;
let out_size =
(grad_l_in - 1) * stride + dilation * (k_size - 1) + 1 - 2 * padding;
let out_padding = arg.dim(2)? - out_size;
let grad_arg = grad.conv_transpose1d(
kernel,
*padding,
out_padding,
*stride,
*dilation,
)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
let grad_kernel = arg
.transpose(0, 1)?
.conv1d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
.transpose(0, 1)?;
let sum_grad = grads.or_insert(kernel)?;
let (_, _, k0) = kernel.dims3()?;
let (_, _, g_k0) = grad_kernel.dims3()?;
let grad_kernel = if g_k0 != k0 {
grad_kernel.narrow(2, 0, k0)?
} else {
grad_kernel
};
*sum_grad = sum_grad.add(&grad_kernel)?;
}
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
Op::Conv2D {
arg,
kernel,
@ -306,9 +247,6 @@ impl Tensor {
};
*sum_grad = sum_grad.add(&grad_kernel)?;
}
Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported {
op: "conv-transpose1d",
})?,
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
op: "conv-transpose2d",
})?,
@ -549,38 +487,16 @@ impl Tensor {
+ 0.5)?;
*sum_grad = sum_grad.add(&(&grad * gelu_grad)?)?
}
Op::Unary(arg, UnaryOp::Erf) => {
let sum_grad = grads.or_insert(arg)?;
// d/dx erf(x) = 2/sqrt(pi) * e^(-x^2)
let erf_grad =
(2. / std::f64::consts::PI.sqrt()) * (arg.sqr()?.neg()?).exp()?;
*sum_grad = sum_grad.add(&(&grad * erf_grad)?)?
}
Op::Unary(arg, UnaryOp::GeluErf) => {
let sum_grad = grads.or_insert(arg)?;
// d/dx gelu_erf(x) = 0.5 + 0.398942 e^(-x^2/2) x + 0.5 erf(x/sqrt(2))
let neg_half_square = (arg.sqr()?.neg()? / 2.)?;
let scaled_exp_arg = (0.398942 * neg_half_square.exp()? * arg)?;
let arg_scaled_sqrt = (arg / 2f64.sqrt())?;
let erf_scaled_sqrt = (0.5 * arg_scaled_sqrt.erf()?)?;
let gelu_erf_grad = (0.5 + scaled_exp_arg + erf_scaled_sqrt)?;
*sum_grad = sum_grad.add(&(&grad * gelu_erf_grad)?)?;
Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
Op::Unary(_, UnaryOp::GeluErf) => {
Err(Error::BackwardNotSupported { op: "gelu-erf" })?
}
Op::Unary(arg, UnaryOp::Relu) => {
let sum_grad = grads.or_insert(arg)?;
let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
*sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
}
Op::Elu(arg, alpha) => {
// d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
let sum_grad = grads.or_insert(arg)?;
let zeros = arg.zeros_like()?;
let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
let negative_exp_mask = ((negative_mask * arg.exp())? * *alpha)?;
let combined_mask = (positive_mask + negative_exp_mask)?;
*sum_grad = sum_grad.add(&(grad * combined_mask)?)?
}
Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
Op::Powf(arg, e) => {
let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?;
let sum_grad = grads.or_insert(arg)?;

View File

@ -25,33 +25,6 @@ impl ParamsConv1D {
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConvTranspose1D {
pub(crate) b_size: usize,
pub(crate) l_in: usize,
pub(crate) c_out: usize,
pub(crate) c_in: usize,
pub(crate) k_size: usize,
pub(crate) padding: usize,
pub(crate) output_padding: usize,
pub(crate) stride: usize,
pub(crate) dilation: usize,
}
impl ParamsConvTranspose1D {
pub(crate) fn l_out(&self) -> usize {
(self.l_in - 1) * self.stride - 2 * self.padding
+ self.dilation * (self.k_size - 1)
+ self.output_padding
+ 1
}
pub(crate) fn out_dims(&self) -> Vec<usize> {
let l_out = self.l_out();
vec![self.b_size, self.c_out, l_out]
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CudnnFwdAlgo {
ImplicitGemm,
@ -187,49 +160,6 @@ impl Tensor {
}
}
/// Applies a 1D transposed convolution over the input tensor.
pub fn conv_transpose1d(
&self,
kernel: &Self,
padding: usize,
output_padding: usize,
stride: usize,
dilation: usize,
) -> Result<Self> {
let (b_size, c_in, l_in) = self.dims3()?;
let (c_in_k, c_out, k_size) = kernel.dims3()?;
if c_in != c_in_k {
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
}
let params = ParamsConvTranspose1D {
b_size,
l_in,
k_size,
c_out,
c_in,
padding,
output_padding,
stride,
dilation,
};
let storage = self.storage().conv_transpose1d(
self.layout(),
&kernel.storage(),
kernel.layout(),
&params,
)?;
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D {
arg,
kernel,
padding: params.padding,
output_padding: params.output_padding,
stride: params.stride,
dilation: params.dilation,
});
let out_dims = params.out_dims();
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
}
fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
let storage =
self.storage()

View File

@ -1256,74 +1256,6 @@ impl Map1 for Im2Col {
}
}
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
impl<'a> Map2 for ConvTranspose1D<'a> {
const OP: &'static str = "conv_transpose1d";
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0;
let inp = &inp[inp_l.start_offset()..];
let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
let l_out = p.l_out();
// Output shape: [b_size, c_out, l_out].
let dst_elems = p.c_out * l_out * p.b_size;
let dst = vec![T::zero(); dst_elems];
let dst_s0 = p.c_out * l_out;
let dst_s1 = l_out;
let dst_s2 = 1;
// TODO: Avoid making this copy if `inp` already has the appropriate layout.
let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
let cont_s0 = p.l_in * p.c_in;
let cont_s1 = p.c_in;
for b_idx in 0..p.b_size {
for l_idx in 0..p.l_in {
for c_idx in 0..p.c_in {
let src_idx = b_idx * inp_s0 + c_idx * inp_s1 + l_idx * inp_s2;
let dst_idx = b_idx * cont_s0 + l_idx * cont_s1 + c_idx;
inp_cont[dst_idx] = inp[src_idx]
}
}
}
for k_idx in 0..p.k_size {
(0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
let k_cont = (0..p.c_in)
.map(|c_in_idx| k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_idx * k_s2])
.collect::<Vec<_>>();
for b_idx in 0..p.b_size {
for l_idx in 0..p.l_in {
let out_idx = l_idx * p.stride + k_idx * p.dilation;
if out_idx < p.padding {
continue;
}
let out_idx = out_idx - p.padding;
if out_idx < l_out {
let inp_cont = &inp_cont[b_idx * cont_s0 + l_idx * cont_s1..];
let dst_idx = b_idx * dst_s0 + out_idx * dst_s2 + dst_c_idx * dst_s1;
let mut d = T::zero();
unsafe {
T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
}
let dst_p = dst.as_ptr();
// Safety: dst_idx are uniques per dst_c_idx which is used to
// parallelise the different tasks so no two threads can try to
// write at the same location.
unsafe {
let ptr = dst_p.add(dst_idx) as *mut T;
*ptr += d
}
}
}
}
})
}
Ok(dst)
}
}
struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
impl<'a> Map2 for Conv2D<'a> {
@ -2503,16 +2435,6 @@ impl BackendStorage for CpuStorage {
Ok(res_t)
}
fn conv_transpose1d(
&self,
l: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
ConvTranspose1D(params).map(self, l, kernel, kernel_l)
}
fn conv2d(
&self,
l: &Layout,

View File

@ -1808,16 +1808,6 @@ impl BackendStorage for CudaStorage {
Ok(res_t)
}
fn conv_transpose1d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
todo!()
}
#[cfg(not(feature = "cudnn"))]
fn conv2d(
&self,

View File

@ -1,6 +1,6 @@
use crate::backend::BackendDevice;
use crate::cpu_backend::CpuDevice;
use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
use crate::{bail, CpuStorage, DType, Result, Shape, Storage, WithDType};
/// A `DeviceLocation` represents a physical device whereas multiple `Device`
/// can live on the same location (typically for cuda devices).
@ -8,7 +8,7 @@ use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
pub enum DeviceLocation {
Cpu,
Cuda { gpu_id: usize },
Metal { gpu_id: usize },
Metal,
}
#[derive(Debug, Clone)]
@ -105,14 +105,14 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4:
impl<S: NdArray> NdArray for Vec<S> {
fn shape(&self) -> Result<Shape> {
if self.is_empty() {
crate::bail!("empty array")
bail!("empty array")
}
let shape0 = self[0].shape()?;
let n = self.len();
for v in self.iter() {
let shape = v.shape()?;
if shape != shape0 {
crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
bail!("two elements have different shapes {shape:?} {shape0:?}")
}
}
Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
@ -146,7 +146,6 @@ impl Device {
match (self, rhs) {
(Self::Cpu, Self::Cpu) => true,
(Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
(Self::Metal(lhs), Self::Metal(rhs)) => lhs.same_device(rhs),
_ => false,
}
}
@ -167,10 +166,6 @@ impl Device {
matches!(self, Self::Cuda(_))
}
pub fn is_metal(&self) -> bool {
matches!(self, Self::Metal(_))
}
pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
if crate::utils::cuda_is_available() {
Self::new_cuda(ordinal)
@ -192,19 +187,13 @@ impl Device {
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
// TODO: Remove the special case if we start supporting generating f16/bf16 directly.
if dtype == DType::F16 || dtype == DType::BF16 {
let storage = device.rand_uniform(shape, DType::F32, lo, up)?;
Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
} else {
let storage = device.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Cuda(storage))
}
let storage = device.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(_device) => {
// let storage = device.rand_uniform(shape, dtype, lo, up)?;
// Ok(Storage::Metal(storage))
crate::bail!("Metal rand_uniform not implemented")
bail!("Metal rand_uniform not implemented")
}
}
}
@ -231,14 +220,8 @@ impl Device {
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
// TODO: Remove the special case if we start supporting generating f16/bf16 directly.
if dtype == DType::F16 || dtype == DType::BF16 {
let storage = device.rand_normal(shape, DType::F32, mean, std)?;
Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
} else {
let storage = device.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Cuda(storage))
}
let storage = device.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = device.rand_normal(shape, dtype, mean, std)?;

View File

@ -14,9 +14,7 @@ impl Tensor {
crate::DeviceLocation::Cuda { gpu_id } => {
format!(", cuda:{}", gpu_id)
}
crate::DeviceLocation::Metal { gpu_id } => {
format!(", metal:{}", gpu_id)
}
_ => todo!(),
};
write!(f, "Tensor[")?;
@ -479,9 +477,7 @@ impl std::fmt::Display for Tensor {
crate::DeviceLocation::Cuda { gpu_id } => {
format!(", cuda:{}", gpu_id)
}
crate::DeviceLocation::Metal { gpu_id } => {
format!(", metal:{}", gpu_id)
}
crate::DeviceLocation::Metal => todo!(),
};
write!(

View File

@ -79,16 +79,6 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
fn conv_transpose1d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn conv2d(
&self,
_: &Layout,

View File

@ -8,18 +8,6 @@ pub struct MetalDevice;
#[derive(Debug)]
pub struct MetalStorage;
#[derive(thiserror::Error, Debug)]
pub enum MetalError {
#[error("{0}")]
Message(String),
}
impl From<String> for MetalError {
fn from(e: String) -> Self {
MetalError::Message(e)
}
}
macro_rules! fail {
() => {
unimplemented!("metal support has not been enabled, add `metal` feature to enable.")
@ -91,16 +79,6 @@ impl crate::backend::BackendStorage for MetalStorage {
Err(Error::NotCompiledWithMetalSupport)
}
fn conv_transpose1d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn conv2d(
&self,
_: &Layout,

View File

@ -1,4 +1,4 @@
use crate::{DType, DeviceLocation, Layout, MetalError, Shape};
use crate::{metal_backend, DType, DeviceLocation, Layout, Shape};
#[derive(Debug, Clone)]
pub struct MatMulUnexpectedStriding {
@ -163,7 +163,7 @@ pub enum Error {
Cuda(Box<dyn std::error::Error + Send + Sync>),
#[error("Metal error {0}")]
Metal(#[from] MetalError),
Metal(#[from] metal_backend::MetalError),
#[error(transparent)]
TryFromIntError(#[from] core::num::TryFromIntError),

View File

@ -49,12 +49,13 @@ mod device;
pub mod display;
mod dtype;
mod dummy_cuda_backend;
mod dummy_metal_backend;
pub mod error;
mod indexer;
pub mod layout;
#[cfg(feature = "metal")]
pub mod metal_backend;
#[cfg(feature = "accelerate")]
mod metal_backend;
#[cfg(feature = "mkl")]
mod mkl;
pub mod npy;
@ -91,10 +92,10 @@ pub use cuda_backend::{CudaDevice, CudaStorage};
pub use dummy_cuda_backend::{CudaDevice, CudaStorage};
#[cfg(feature = "metal")]
pub use metal_backend::{MetalDevice, MetalError, MetalStorage};
pub use metal_backend::{MetalDevice, MetalStorage};
#[cfg(not(feature = "metal"))]
pub use dummy_metal_backend::{MetalDevice, MetalError, MetalStorage};
pub use dummy_metal_backend::{MetalDevice, MetalStorage};
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

File diff suppressed because it is too large Load Diff

View File

@ -90,16 +90,6 @@ pub enum Op {
dilation: usize,
},
#[allow(dead_code)]
ConvTranspose1D {
arg: Tensor,
kernel: Tensor,
padding: usize,
output_padding: usize,
stride: usize,
dilation: usize,
},
#[allow(dead_code)]
Conv2D {
arg: Tensor,
@ -593,8 +583,7 @@ unary_op!(Recip, "recip", v, v.recip());
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
/// Tanh based approximation of the `gelu` operation
/// GeluErf is the more precise one.
/// `gelu` operation
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
impl UnaryOpT for Gelu {
const NAME: &'static str = "gelu";
@ -684,8 +673,6 @@ impl UnaryOpT for Gelu {
}
}
/// `erf` operation
/// <https://en.wikipedia.org/wiki/Error_function>
impl UnaryOpT for Erf {
const NAME: &'static str = "erf";
const KERNEL: &'static str = "uerf";
@ -975,10 +962,6 @@ impl BackpropOp {
};
Self(op)
}
pub(crate) fn is_none(&self) -> bool {
self.0.is_none()
}
}
impl std::ops::Deref for BackpropOp {

View File

@ -1,7 +1,7 @@
//! Support for the GGML file format.
use super::{k_quants, GgmlDType};
use crate::Result;
use crate::{Device, Result};
use byteorder::{LittleEndian, ReadBytesExt};
use std::collections::HashMap;
@ -121,11 +121,12 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
raw_data: &[u8],
size_in_bytes: usize,
dims: Vec<usize>,
device: &Device,
) -> Result<super::QTensor> {
let raw_data_ptr = raw_data.as_ptr();
let n_blocks = size_in_bytes / std::mem::size_of::<T>();
let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
super::QTensor::new(data.to_vec(), dims)
super::QTensor::new(data.to_vec(), dims, device)
}
/// Creates a [Tensor] from a raw GGML tensor.
@ -133,6 +134,7 @@ pub fn qtensor_from_ggml(
ggml_dtype: GgmlDType,
raw_data: &[u8],
dims: Vec<usize>,
device: &Device,
) -> Result<super::QTensor> {
let tensor_elems = dims.iter().product::<usize>();
let blck_size = ggml_dtype.blck_size();
@ -144,18 +146,38 @@ pub fn qtensor_from_ggml(
let size_in_bytes = tensor_elems / blck_size * ggml_dtype.type_size();
match ggml_dtype {
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),
GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims),
GgmlDType::Q4_0 => from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims),
GgmlDType::Q4_1 => from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims),
GgmlDType::Q5_0 => from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims),
GgmlDType::Q5_1 => from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims),
GgmlDType::Q8_0 => from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims),
GgmlDType::Q2K => from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims),
GgmlDType::Q3K => from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims),
GgmlDType::Q4K => from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims),
GgmlDType::Q5K => from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims),
GgmlDType::Q6K => from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims),
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims, device),
GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims, device),
GgmlDType::Q4_0 => {
from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q4_1 => {
from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q5_0 => {
from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q5_1 => {
from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q8_0 => {
from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q2K => {
from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q3K => {
from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q4K => {
from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q5K => {
from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q6K => {
from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims, device)
}
_ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
}
}
@ -163,6 +185,7 @@ pub fn qtensor_from_ggml(
fn read_one_tensor<R: std::io::Seek + std::io::Read>(
reader: &mut R,
magic: VersionedMagic,
device: &Device,
) -> Result<(String, super::QTensor)> {
let n_dims = reader.read_u32::<LittleEndian>()?;
let name_len = reader.read_u32::<LittleEndian>()?;
@ -187,7 +210,7 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
// TODO: Mmap version to avoid copying the data around?
let mut raw_data = vec![0u8; size_in_bytes];
reader.read_exact(&mut raw_data)?;
match qtensor_from_ggml(ggml_dtype, &raw_data, dims) {
match qtensor_from_ggml(ggml_dtype, &raw_data, dims, device) {
Ok(tensor) => Ok((name, tensor)),
Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
}
@ -201,7 +224,10 @@ pub struct Content {
}
impl Content {
pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Content> {
pub fn read<R: std::io::Seek + std::io::Read>(
reader: &mut R,
device: &Device,
) -> Result<Content> {
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
let last_position = reader.seek(std::io::SeekFrom::End(0))?;
reader.seek(std::io::SeekFrom::Start(0))?;
@ -211,7 +237,7 @@ impl Content {
let mut tensors = HashMap::new();
while reader.stream_position()? != last_position {
let (name, tensor) = read_one_tensor(reader, magic)?;
let (name, tensor) = read_one_tensor(reader, magic, device)?;
tensors.insert(name, tensor);
}
Ok(Self {

View File

@ -3,7 +3,7 @@
//! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md
use super::{GgmlDType, QTensor};
use crate::Result;
use crate::{Device, Result};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::collections::HashMap;
@ -29,7 +29,6 @@ impl TryFrom<u32> for Magic {
pub enum VersionedMagic {
GgufV1,
GgufV2,
GgufV3,
}
impl VersionedMagic {
@ -40,7 +39,6 @@ impl VersionedMagic {
let versioned_magic = match (magic, version) {
(Magic::Gguf, 1) => Self::GgufV1,
(Magic::Gguf, 2) => Self::GgufV2,
(Magic::Gguf, 3) => Self::GgufV3,
_ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"),
};
Ok(versioned_magic)
@ -59,6 +57,7 @@ impl TensorInfo {
&self,
reader: &mut R,
tensor_data_offset: u64,
device: &Device,
) -> Result<QTensor> {
let tensor_elems = self.shape.elem_count();
let blck_size = self.ggml_dtype.blck_size();
@ -71,7 +70,12 @@ impl TensorInfo {
let mut raw_data = vec![0u8; size_in_bytes];
reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
reader.read_exact(&mut raw_data)?;
super::ggml_file::qtensor_from_ggml(self.ggml_dtype, &raw_data, self.shape.dims().to_vec())
super::ggml_file::qtensor_from_ggml(
self.ggml_dtype,
&raw_data,
self.shape.dims().to_vec(),
device,
)
}
}
@ -86,9 +90,7 @@ pub struct Content {
fn read_string<R: std::io::Read>(reader: &mut R, magic: &VersionedMagic) -> Result<String> {
let len = match magic {
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
reader.read_u64::<LittleEndian>()? as usize
}
VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
};
let mut v = vec![0u8; len];
reader.read_exact(&mut v)?;
@ -288,9 +290,7 @@ impl Value {
let value_type = ValueType::from_u32(value_type)?;
let len = match magic {
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
reader.read_u64::<LittleEndian>()? as usize
}
VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
};
let mut vs = Vec::with_capacity(len);
for _ in 0..len {
@ -387,15 +387,11 @@ impl Content {
let tensor_count = match magic {
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
reader.read_u64::<LittleEndian>()? as usize
}
VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
};
let metadata_kv_count = match magic {
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
reader.read_u64::<LittleEndian>()? as usize
}
VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
};
let mut metadata = HashMap::new();
@ -417,7 +413,7 @@ impl Content {
reader.read_u32_into::<LittleEndian>(&mut dimensions)?;
dimensions.into_iter().map(|c| c as usize).collect()
}
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
VersionedMagic::GgufV2 => {
let mut dimensions = vec![0; n_dimensions as usize];
reader.read_u64_into::<LittleEndian>(&mut dimensions)?;
dimensions.into_iter().map(|c| c as usize).collect()
@ -460,12 +456,13 @@ impl Content {
&self,
reader: &mut R,
name: &str,
device: &Device,
) -> Result<QTensor> {
let tensor_info = match self.tensor_infos.get(name) {
Some(tensor_info) => tensor_info,
None => crate::bail!("cannot find tensor-infor for {name}"),
};
tensor_info.read(reader, self.tensor_data_offset)
tensor_info.read(reader, self.tensor_data_offset, device)
}
}

View File

@ -1,4 +1,5 @@
use crate::{Device, Result, Shape, Tensor};
use tracing::debug;
#[cfg(target_feature = "avx")]
pub mod avx;
@ -14,6 +15,7 @@ pub mod utils;
pub use k_quants::GgmlType;
pub struct QTensor {
device: Device,
data: Box<dyn QuantizedType>,
shape: Shape,
}
@ -170,17 +172,20 @@ impl QTensor {
pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
data: Vec<T>,
shape: S,
device: &Device,
) -> Result<Self> {
let shape = shape.into();
check_shape::<T>(&shape)?;
Ok(Self {
data: Box::new(data),
shape,
device: device.clone(),
})
}
pub fn quantize<T: k_quants::GgmlType + Send + Sync + 'static>(src: &Tensor) -> Result<Self> {
let shape = src.shape();
let device = src.device();
check_shape::<T>(shape)?;
let src = src
.to_dtype(crate::DType::F32)?
@ -197,6 +202,7 @@ impl QTensor {
Ok(Self {
data: Box::new(data),
shape: shape.clone(),
device: device.clone(),
})
}
@ -212,7 +218,12 @@ impl QTensor {
&self.shape
}
pub fn device(&self) -> &Device {
&self.device
}
pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
// TODO Skip the CPU part on metal
let mut f32_data = vec![0f32; self.shape.elem_count()];
self.data.to_float(&mut f32_data)?;
Tensor::from_vec(f32_data, &self.shape, device)
@ -305,6 +316,46 @@ impl crate::CustomOp1 for QTensor {
)?;
Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
}
fn metal_fwd(
&self,
storage: &crate::MetalStorage,
layout: &crate::Layout,
) -> Result<(crate::MetalStorage, Shape)> {
debug!("TODO qmatmul");
if !layout.is_contiguous() {
crate::bail!("input tensor is not contiguous {layout:?}")
}
let src_shape = layout.shape();
// self is transposed so n is first then k.
let (n, k) = self.shape.dims2()?;
if src_shape.rank() < 2 {
crate::bail!("input tensor has only one dimension {layout:?}")
}
let mut dst_shape = src_shape.dims().to_vec();
let last_k = dst_shape.pop().unwrap();
if last_k != k {
crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
}
dst_shape.push(n);
let dst_shape = Shape::from(dst_shape);
// let storage = storage.as_slice::<f32>()?;
// let storage =
// &storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
let dst_storage = vec![0f32; dst_shape.elem_count()];
// self.matmul_t(
// (dst_shape.elem_count() / n, k, n),
// storage,
// &mut dst_storage,
// )?;
let cpu_storage = crate::CpuStorage::F32(dst_storage);
use crate::backend::{BackendDevice, BackendStorage};
if let Device::Metal(device) = &self.device {
Ok((device.storage_from_cpu_storage(&cpu_storage)?, dst_shape))
} else {
crate::bail!("qtensor not on metal device")
}
}
}
impl QMatMul {

View File

@ -334,33 +334,6 @@ impl Storage {
}
}
pub(crate) fn conv_transpose1d(
&self,
l: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
self.same_device(kernel, "conv-transpose1d")?;
self.same_dtype(kernel, "conv-transpose1d")?;
match (self, &kernel) {
(Storage::Cpu(inp), Storage::Cpu(kernel)) => {
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
Ok(Self::Cpu(s))
}
(Storage::Cuda(inp), Storage::Cuda(kernel)) => {
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
Ok(Self::Cuda(s))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
op: "conv-transpose1d",
}
.bt()),
}
}
pub(crate) fn conv2d(
&self,
l: &Layout,

View File

@ -477,12 +477,6 @@ impl Tensor {
broadcast_binary_op!(broadcast_div, div);
broadcast_binary_op!(broadcast_maximum, maximum);
broadcast_binary_op!(broadcast_minimum, minimum);
broadcast_binary_op!(broadcast_eq, eq);
broadcast_binary_op!(broadcast_ne, ne);
broadcast_binary_op!(broadcast_lt, lt);
broadcast_binary_op!(broadcast_le, le);
broadcast_binary_op!(broadcast_gt, gt);
broadcast_binary_op!(broadcast_ge, ge);
unary_op!(recip, Recip);
unary_op!(neg, Neg);
@ -856,20 +850,6 @@ impl Tensor {
self.sum_impl(mean_dims, false)? * scale
}
/// Returns the unbiased variance over the selected dimension.
pub fn var_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "var")?;
let mean = self.mean_keepdim(dim)?;
let squares = self.broadcast_sub(&mean)?.sqr()?;
squares.sum_impl(dim, true)? / (self.dim(dim)? - 1) as f64
}
/// Returns the unbiased variance over the selected dimension.
pub fn var<D: Dim>(&self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "var")?;
self.var_keepdim(dim)?.squeeze(dim)
}
/// Gathers the maximum value across the selected dimension. The resulting shape has the same
/// number of dimensions as the original tensor and the select dimension has a single element.
pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
@ -1831,23 +1811,17 @@ impl Tensor {
/// Returns a new tensor detached from the current graph, gradient are not propagated through
/// this new node. The storage of this tensor is shared with the initial tensor.
///
/// If the tensor is already detached from the computation graph, the same tensor is returned.
pub fn detach(&self) -> Result<Tensor> {
if self.op.is_none() && !self.is_variable {
Ok(self.clone())
} else {
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
layout: self.layout.clone(),
op: BackpropOp::none(),
is_variable: false,
dtype: self.dtype,
device: self.device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
}
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
layout: self.layout.clone(),
op: BackpropOp::none(),
is_variable: false,
dtype: self.dtype,
device: self.device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
}
/// If the target device is the same as the tensor device, only a shallow copy is performed.
@ -1859,14 +1833,7 @@ impl Tensor {
(Storage::Cpu(storage), Device::Cuda(cuda)) => {
Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
}
(Storage::Cpu(storage), Device::Metal(metal)) => {
Storage::Metal(metal.storage_from_cpu_storage(storage)?)
}
(Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
(Storage::Metal(storage), Device::Cpu) => {
println!("{storage:?} - {:?}", storage.to_cpu_storage()?);
Storage::Cpu(storage.to_cpu_storage()?)
}
(Storage::Cuda(storage), Device::Cuda(cuda)) => {
// TODO: Avoid passing through the cpu storage here, especially if the gpu ids
// are the same.
@ -2440,23 +2407,6 @@ impl Tensor {
) -> Result<Self> {
self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
}
/// Normalize a 'relative' axis value: positive values are kept, negative
/// values means counting the dimensions from the back.
pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
let rank = self.rank() as i64;
if rank <= axis {
crate::bail!("axis {axis} is too large, tensor rank {rank}")
} else if 0 <= axis {
Ok(axis as usize)
} else {
let naxis = rank + axis;
if naxis < 0 {
crate::bail!("axis {axis} is too small, tensor rank {rank}")
}
Ok(naxis as usize)
}
}
}
macro_rules! bin_trait {

View File

@ -4,7 +4,7 @@ use crate::{Result, Tensor};
macro_rules! test_device {
// TODO: Switch to generating the two last arguments automatically once concat_idents is
// stable. https://github.com/rust-lang/rust/issues/29599
($fn_name: ident, $test_cpu: ident, $test_cuda: ident, $test_metal: ident) => {
($fn_name: ident, $test_cpu: ident, $test_cuda: ident) => {
#[test]
fn $test_cpu() -> Result<()> {
$fn_name(&Device::Cpu)
@ -15,12 +15,6 @@ macro_rules! test_device {
fn $test_cuda() -> Result<()> {
$fn_name(&Device::new_cuda(0)?)
}
#[cfg(feature = "metal")]
#[test]
fn $test_metal() -> Result<()> {
$fn_name(&Device::new_metal(0)?)
}
};
}

View File

@ -13,11 +13,6 @@ res = torch.nn.functional.conv1d(t, w)
print(res.flatten())
res = torch.nn.functional.conv1d(t, w, padding=1)
print(res.flatten())
w_t = w.transpose(0, 1)
res = torch.nn.functional.conv_transpose1d(t, w_t)
print(res.shape)
print(res)
*/
fn conv1d(dev: &Device) -> Result<()> {
let t = Tensor::new(
@ -50,17 +45,6 @@ fn conv1d(dev: &Device) -> Result<()> {
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
);
if dev.is_cpu() {
let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 7]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[
0.0699, -1.2899, 8.3018, 5.5873, 2.4572, -2.6143, -0.0706, 1.8765, 4.8318, 1.1538,
4.7076, -5.9745, -0.8276, 1.621
],
);
}
Ok(())
}
@ -563,35 +547,14 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
Ok(())
}
test_device!(conv1d, conv1d_cpu, conv1d_gpu, conv1d_metal);
test_device!(
conv1d_small,
conv1d_small_cpu,
conv1d_small_gpu,
conv1d_small_metal
);
test_device!(conv2d, conv2d_cpu, conv2d_gpu, conv2d_metal);
test_device!(conv1d, conv1d_cpu, conv1d_gpu);
test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu);
test_device!(conv2d, conv2d_cpu, conv2d_gpu);
test_device!(
conv2d_non_square,
conv2d_non_square_cpu,
conv2d_non_square_gpu,
conv2d_non_square_metal
);
test_device!(
conv2d_small,
conv2d_small_cpu,
conv2d_small_gpu,
conv2d_small_metal
);
test_device!(
conv2d_smaller,
conv2d_smaller_cpu,
conv2d_smaller_gpu,
conv2d_smaller_metal
);
test_device!(
conv2d_grad,
conv2d_grad_cpu,
conv2d_grad_gpu,
conv2_grad_metal
conv2d_non_square_gpu
);
test_device!(conv2d_small, conv2d_small_cpu, conv2d_small_gpu);
test_device!(conv2d_smaller, conv2d_smaller_cpu, conv2d_smaller_gpu);
test_device!(conv2d_grad, conv2d_grad_cpu, conv2d_grad_gpu);

View File

@ -205,71 +205,6 @@ fn unary_grad(device: &Device) -> Result<()> {
test_utils::to_vec1_round(grad_x, 4)?,
[1.0116, 1.0830, 1.0003, 0.6188],
);
// Testing compared to pytorch torch.erf
//
// import torch
// x = torch.tensor([3.0, 1.0, 4.0, 0.15], requires_grad=True)
// y = x.erf()
// print(y)
// loss = y.sum()
// loss.backward()
// print(x.grad)
let y = x.erf()?;
let grads = y.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(test_utils::to_vec1_round(&y, 4)?, [1.0, 0.8427, 1.0, 0.168]);
assert_eq!(
test_utils::to_vec1_round(grad_x, 4)?,
[0.0001, 0.4151, 0.0, 1.1033],
);
// Testing compared to pytorch nn.GELU(approximate = 'none')
//
// import torch
// import torch.nn.functional as F
// x = torch.tensor([3.0, 1.0, 4.0, 0.15], requires_grad=True)
// y = F.gelu(x, approximate='none')
// print(y)
// loss = y.sum()
// loss.backward()
// print(x.grad)
let y = x.gelu_erf()?;
let grads = y.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec1_round(&y, 4)?,
[2.9960, 0.8413, 3.9999, 0.0839]
);
assert_eq!(
test_utils::to_vec1_round(grad_x, 4)?,
[1.0119, 1.0833, 1.0005, 0.6188],
);
// Testing compared to pytorch elu
//
// import torch
// import torch.nn.functional as F
// x = torch.tensor([-1.0, 0.0, -2.0, 3.0], requires_grad=True)
// y = F.elu(x, alpha=2.0)
// print(y)
// loss = y.min
// loss = y.sum()
// loss.backward()
// print(x.grad)
let elu_x = Var::new(&[-1.0f32, 0., -2., 3.], device)?;
let y = elu_x.elu(2.)?;
let grads = y.backward()?;
let grad_x = grads.get(&elu_x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec1_round(&y, 4)?,
[-1.2642, 0.0000, -1.7293, 3.0000]
);
assert_eq!(
test_utils::to_vec1_round(grad_x, 4)?,
[0.7358, 2.0000, 0.2707, 1.0000]
);
Ok(())
}
@ -315,29 +250,9 @@ fn binary_grad(device: &Device) -> Result<()> {
Ok(())
}
test_device!(
simple_grad,
simple_grad_cpu,
simple_grad_gpu,
simple_grad_metal
);
test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu, sum_grad_metal);
test_device!(
matmul_grad,
matmul_grad_cpu,
matmul_grad_gpu,
matmul_grad_metal
);
test_device!(
grad_descent,
grad_descent_cpu,
grad_descent_gpu,
grad_descent_metal
);
test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu, unary_grad_metal);
test_device!(
binary_grad,
binary_grad_cpu,
binary_grad_gpu,
binary_grad_metal
);
test_device!(simple_grad, simple_grad_cpu, simple_grad_gpu);
test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu);
test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu);
test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu);
test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu);
test_device!(binary_grad, binary_grad_cpu, binary_grad_gpu);

View File

@ -49,7 +49,7 @@ fn contiguous(device: &Device) -> Result<()> {
Ok(())
}
test_device!(contiguous, contiguous_cpu, contiguous_gpu, contiguous_metal);
test_device!(contiguous, contiguous_cpu, contiguous_gpu);
#[test]
fn strided_blocks() -> Result<()> {

View File

@ -98,17 +98,15 @@ fn upsample_nearest2d(dev: &Device) -> Result<()> {
Ok(())
}
test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu, avg_pool2d_metal);
test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu);
test_device!(
avg_pool2d_pytorch,
avg_pool2d_pytorch_cpu,
avg_pool2d_pytorch_gpu,
avg_pool2d_pytorch_metal
avg_pool2d_pytorch_gpu
);
test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu, max_pool2d_metal);
test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu);
test_device!(
upsample_nearest2d,
upsample_nearest2d_cpu,
upsample_nearest2d_gpu,
upsample_nearest2d_metal
upsample_nearest2d_gpu
);

View File

@ -180,22 +180,6 @@ fn transpose(device: &Device) -> Result<()> {
Ok(())
}
fn var(device: &Device) -> Result<()> {
// Values taken from https://pytorch.org/docs/stable/generated/torch.var.html
let data = &[
[0.2035f32, 1.2959, 1.8101, -0.4644],
[1.5027, -0.3270, 0.5905, 0.6538],
[-1.5745, 1.3330, -0.5596, -0.6548],
[0.1264, -0.5080, 1.6420, 0.1992],
];
let tensor = Tensor::new(data, device)?;
assert_eq!(
test_utils::to_vec2_round(&tensor.var_keepdim(1)?, 4)?,
&[[1.0631], [0.559], [1.4893], [0.8258]]
);
Ok(())
}
fn sum(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, device)?;
@ -1070,60 +1054,34 @@ fn randn(device: &Device) -> Result<()> {
Ok(())
}
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
test_device!(arange, arange_cpu, arange_gpu, arange_metal);
test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal);
test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal);
test_device!(cat, cat_cpu, cat_gpu, cat_metal);
test_device!(sum, sum_cpu, sum_gpu, sum_metal);
test_device!(min, min_cpu, min_gpu, min_metal);
test_device!(max, max_cpu, max_gpu, max_metal);
test_device!(argmax, argmax_cpu, argmax_gpu, argmax_metal);
test_device!(argmin, argmin_cpu, argmin_gpu, argmin_metal);
test_device!(transpose, transpose_cpu, transpose_gpu, transpose_metal);
test_device!(unary_op, unary_op_cpu, unary_op_gpu, unary_op_metal);
test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal);
test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal);
test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal);
test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
test_device!(
broadcast_matmul,
broadcast_matmul_cpu,
broadcast_matmul_gpu,
broadcast_matmul_metal
);
test_device!(
broadcasting,
broadcasting_cpu,
broadcasting_gpu,
broadcasting_metal
);
test_device!(
index_select,
index_select_cpu,
index_select_gpu,
index_select_metal
);
test_device!(index_add, index_add_cpu, index_add_gpu, index_add_metal);
test_device!(gather, gather_cpu, gather_gpu, gather_metal);
test_device!(
scatter_add,
scatter_add_cpu,
scatter_add_gpu,
scatter_add_metal
);
test_device!(
slice_scatter,
slice_scatter_cpu,
slice_scatter_gpu,
slice_scatter_metal
);
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
test_device!(var, var_cpu, var_gpu, var_metal);
test_device!(zeros, zeros_cpu, zeros_gpu);
test_device!(ones, ones_cpu, ones_gpu);
test_device!(arange, arange_cpu, arange_gpu);
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
test_device!(narrow, narrow_cpu, narrow_gpu);
test_device!(broadcast, broadcast_cpu, broadcast_gpu);
test_device!(cat, cat_cpu, cat_gpu);
test_device!(sum, sum_cpu, sum_gpu);
test_device!(min, min_cpu, min_gpu);
test_device!(max, max_cpu, max_gpu);
test_device!(argmax, argmax_cpu, argmax_gpu);
test_device!(argmin, argmin_cpu, argmin_gpu);
test_device!(transpose, transpose_cpu, transpose_gpu);
test_device!(unary_op, unary_op_cpu, unary_op_gpu);
test_device!(binary_op, binary_op_cpu, binary_op_gpu);
test_device!(embeddings, embeddings_cpu, embeddings_gpu);
test_device!(cmp, cmp_cpu, cmp_gpu);
test_device!(matmul, matmul_cpu, matmul_gpu);
test_device!(broadcast_matmul, broadcast_matmul_cpu, broadcast_matmul_gpu);
test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu);
test_device!(index_select, index_select_cpu, index_select_gpu);
test_device!(index_add, index_add_cpu, index_add_gpu);
test_device!(gather, gather_cpu, gather_gpu);
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
test_device!(slice_scatter, slice_scatter_cpu, slice_scatter_gpu);
test_device!(randn, randn_cpu, randn_gpu);
test_device!(clamp, clamp_cpu, clamp_gpu);
// There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381

View File

@ -4,9 +4,7 @@
//! <https://www.cs.toronto.edu/~kriz/cifar.html>
//! The binary version of the dataset is used.
use crate::vision::Dataset;
use candle::{DType, Device, Error, Result, Tensor};
use hf_hub::{api::sync::Api, Repo, RepoType};
use parquet::file::reader::{FileReader, SerializedFileReader};
use candle::{DType, Device, Result, Tensor};
use std::fs::File;
use std::io::{BufReader, Read};
@ -62,58 +60,3 @@ pub fn load_dir<T: AsRef<std::path::Path>>(dir: T) -> Result<Dataset> {
labels: 10,
})
}
fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor, Tensor)> {
let samples = parquet.metadata().file_metadata().num_rows() as usize;
let mut buffer_images: Vec<u8> = Vec::with_capacity(samples * 1_024);
let mut buffer_labels: Vec<u8> = Vec::with_capacity(samples);
for row in parquet.into_iter().flatten() {
for (_name, field) in row.get_column_iter() {
if let parquet::record::Field::Group(subrow) = field {
for (_name, field) in subrow.get_column_iter() {
if let parquet::record::Field::Bytes(value) = field {
let image = image::load_from_memory(value.data()).unwrap();
buffer_images.extend(image.to_rgb8().as_raw());
}
}
} else if let parquet::record::Field::Long(label) = field {
buffer_labels.push(*label as u8);
}
}
}
let images = (Tensor::from_vec(buffer_images, (samples, 3, 32, 32), &Device::Cpu)?
.to_dtype(DType::U8)?
/ 255.)?;
let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
Ok((images, labels))
}
pub fn load() -> Result<Dataset> {
let api = Api::new().map_err(|e| Error::Msg(format!("Api error: {e}")))?;
let dataset_id = "cifar10".to_string();
let repo = Repo::with_revision(
dataset_id,
RepoType::Dataset,
"refs/convert/parquet".to_string(),
);
let repo = api.repo(repo);
let test_parquet_filename = repo
.get("plain_text/test/0000.parquet")
.map_err(|e| Error::Msg(format!("Api error: {e}")))?;
let train_parquet_filename = repo
.get("plain_text/train/0000.parquet")
.map_err(|e| Error::Msg(format!("Api error: {e}")))?;
let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)
.map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)
.map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
let (test_images, test_labels) = load_parquet(test_parquet)?;
let (train_images, train_labels) = load_parquet(train_parquet)?;
Ok(crate::vision::Dataset {
train_images,
train_labels,
test_images,
test_labels,
labels: 10,
})
}

View File

@ -16,7 +16,6 @@ candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
candle-onnx = { path = "../candle-onnx", version = "0.3.0", optional = true }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
image = { workspace = true }
@ -52,12 +51,11 @@ anyhow = { workspace = true }
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
metal = ["candle/metal", "candle-nn/metal", "candle-transformers/metal"]
cudnn = ["candle/cudnn"]
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
nccl = ["cuda", "cudarc/nccl", "dep:half"]
onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"]
[[example]]
name = "llama_multiprocess"
@ -66,11 +64,3 @@ required-features = ["cuda", "nccl", "flash-attn"]
[[example]]
name = "reinforcement-learning"
required-features = ["pyo3"]
[[example]]
name = "onnx"
required-features = ["onnx"]
[[example]]
name = "onnx_basics"
required-features = ["onnx"]

View File

@ -8,7 +8,6 @@ use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder};
#[derive(Debug, Clone, PartialEq)]
enum NormType {
WeightNorm,
TimeGroupNorm,
None,
}
@ -269,7 +268,6 @@ impl Module for EncodecConvTranspose1d {
struct EncodecConv1d {
causal: bool,
conv: Conv1d,
norm: Option<candle_nn::GroupNorm>,
}
impl EncodecConv1d {
@ -294,7 +292,7 @@ impl EncodecConv1d {
},
vb.pp("conv"),
)?,
NormType::None | NormType::TimeGroupNorm => conv1d(
NormType::None => conv1d(
in_c,
out_c,
kernel_size,
@ -307,17 +305,9 @@ impl EncodecConv1d {
vb.pp("conv"),
)?,
};
let norm = match cfg.norm_type {
NormType::None | NormType::WeightNorm => None,
NormType::TimeGroupNorm => {
let gn = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
Some(gn)
}
};
Ok(Self {
causal: cfg.use_causal_conv,
conv,
norm,
})
}
}
@ -326,10 +316,8 @@ impl Module for EncodecConv1d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
// TODO: padding, depending on causal.
let xs = self.conv.forward(xs)?;
match &self.norm {
None => Ok(xs),
Some(norm) => xs.apply(norm),
}
// If we add support for NormType "time_group_norm", we should add some normalization here.
Ok(xs)
}
}

View File

@ -1,10 +0,0 @@
## Using ONNX models in Candle
This example demonstrates how to run ONNX based models in Candle, the model
being used here is a small sequeezenet variant.
You can run the example with the following command:
```bash
cargo run --example squeezenet-onnx --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
```

View File

@ -1,78 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle::{IndexOp, D};
use clap::{Parser, ValueEnum};
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
SqueezeNet,
EfficientNet,
}
#[derive(Parser)]
struct Args {
#[arg(long)]
image: String,
#[arg(long)]
model: Option<String>,
/// The model to be used.
#[arg(value_enum, long, default_value_t = Which::SqueezeNet)]
which: Which,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let image = candle_examples::imagenet::load_image224(args.image)?;
let image = match args.which {
Which::SqueezeNet => image,
Which::EfficientNet => image.permute((1, 2, 0))?,
};
println!("loaded image {image:?}");
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
None => match args.which {
Which::SqueezeNet => hf_hub::api::sync::Api::new()?
.model("lmz/candle-onnx".into())
.get("squeezenet1.1-7.onnx")?,
Which::EfficientNet => hf_hub::api::sync::Api::new()?
.model("onnx/EfficientNet-Lite4".into())
.get("efficientnet-lite4-11.onnx")?,
},
};
let model = candle_onnx::read_file(model)?;
let graph = model.graph.as_ref().unwrap();
let mut inputs = std::collections::HashMap::new();
inputs.insert(graph.input[0].name.to_string(), image.unsqueeze(0)?);
let mut outputs = candle_onnx::simple_eval(&model, inputs)?;
let output = outputs.remove(&graph.output[0].name).unwrap();
let prs = match args.which {
Which::SqueezeNet => candle_nn::ops::softmax(&output, D::Minus1)?,
Which::EfficientNet => output,
};
let prs = prs.i(0)?.to_vec1::<f32>()?;
// Sort the predictions and take the top 5
let mut top: Vec<_> = prs.iter().enumerate().collect();
top.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
let top = top.into_iter().take(5).collect::<Vec<_>>();
// Print the top predictions
for &(i, p) in &top {
println!(
"{:50}: {:.2}%",
candle_examples::imagenet::CLASSES[i],
p * 100.0
);
}
Ok(())
}

View File

@ -1,87 +0,0 @@
use anyhow::Result;
use candle::{Device, Tensor};
use clap::{Parser, Subcommand};
#[derive(Subcommand, Debug, Clone)]
enum Command {
Print {
#[arg(long)]
file: String,
},
SimpleEval {
#[arg(long)]
file: String,
},
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
#[command(subcommand)]
command: Command,
}
pub fn main() -> Result<()> {
let args = Args::parse();
match args.command {
Command::Print { file } => {
let model = candle_onnx::read_file(file)?;
println!("{model:?}");
let graph = model.graph.unwrap();
for node in graph.node.iter() {
println!("{node:?}");
}
}
Command::SimpleEval { file } => {
let model = candle_onnx::read_file(file)?;
let graph = model.graph.as_ref().unwrap();
let constants: std::collections::HashSet<_> =
graph.initializer.iter().map(|i| i.name.as_str()).collect();
let mut inputs = std::collections::HashMap::new();
for input in graph.input.iter() {
use candle_onnx::onnx::tensor_proto::DataType;
if constants.contains(input.name.as_str()) {
continue;
}
let type_ = input.r#type.as_ref().expect("no type for input");
let type_ = type_.value.as_ref().expect("no type.value for input");
let value = match type_ {
candle_onnx::onnx::type_proto::Value::TensorType(tt) => {
let dt = match DataType::try_from(tt.elem_type) {
Ok(dt) => match candle_onnx::dtype(dt) {
Some(dt) => dt,
None => {
anyhow::bail!(
"unsupported 'value' data-type {dt:?} for {}",
input.name
)
}
},
type_ => anyhow::bail!("unsupported input type {type_:?}"),
};
let shape = tt.shape.as_ref().expect("no tensortype.shape for input");
let dims = shape
.dim
.iter()
.map(|dim| match dim.value.as_ref().expect("no dim value") {
candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize),
candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimParam(_) => Ok(42),
})
.collect::<Result<Vec<usize>>>()?;
Tensor::zeros(dims, dt, &Device::Cpu)?
}
type_ => anyhow::bail!("unsupported input type {type_:?}"),
};
println!("input {}: {value:?}", input.name);
inputs.insert(input.name.clone(), value);
}
let outputs = candle_onnx::simple_eval(&model, inputs)?;
for (name, value) in outputs.iter() {
println!("output {name}: {value:?}")
}
}
}
Ok(())
}

View File

@ -1,7 +1,5 @@
# candle-quantized-t5
## Seq2Seq example
This example uses a quantized version of the t5 model.
```bash
@ -10,8 +8,6 @@ $ cargo run --example quantized-t5 --release -- --prompt "translate to German: A
Eine schöne Kerze.
```
## Generating Quantized weight files
The weight file is automatically retrieved from the hub. It is also possible to
generate quantized weight files from the original safetensors file by using the
`tensor-tools` command line utility via:
@ -20,11 +16,8 @@ generate quantized weight files from the original safetensors file by using the
$ cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
```
## Using custom models
To use a different model, specify the `model-id`.
For example, for text editing, you can use quantized [CoEdit models](https://huggingface.co/jbochi/candle-coedit-quantized).
To use a different model, specify the `model-id`. For example, you can use
quantized [CoEdit models](https://huggingface.co/jbochi/candle-coedit-quantized).
```bash
$ cargo run --example quantized-t5 --release -- \
@ -33,7 +26,6 @@ $ cargo run --example quantized-t5 --release -- \
--temperature 0
...
Although their flight is weak, they run quickly through the tree canopy.
```
By default, it will look for `model.gguf` and `config.json`, but you can specify
custom local or remote `weight-file` and `config-file`s:
@ -48,16 +40,3 @@ cargo run --example quantized-t5 --release -- \
...
Note that a storm surge is what forecasters consider a hurricane's most dangerous part.
```
### [MADLAD-400](https://arxiv.org/abs/2309.04662)
MADLAD-400 is a series of multilingual machine translation T5 models trained on 250 billion tokens covering over 450 languages using publicly available data. These models are competitive with significantly larger models.
```bash
cargo run --example quantized-t5 --release -- \
--model-id "jbochi/madlad400-3b-mt" --weight-file "model-q4k.gguf" \
--prompt "<2de> How are you, my friend?" \
--temperature 0
...
Wie geht es dir, mein Freund?
```

View File

@ -173,11 +173,7 @@ fn main() -> Result<()> {
.to_vec();
let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
let mut model = builder.build_model()?;
let mut output_token_ids = [builder
.config
.decoder_start_token_id
.unwrap_or(builder.config.pad_token_id) as u32]
.to_vec();
let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
let temperature = if args.temperature <= 0. {
None
} else {

View File

@ -9,10 +9,9 @@ use std::io::Write;
use tokenizers::Tokenizer;
use candle::quantized::{ggml_file, gguf_file};
use candle::{Device, Tensor};
use candle::Tensor;
use candle_transformers::generation::LogitsProcessor;
use candle_examples::token_output_stream::TokenOutputStream;
use candle_transformers::models::quantized_llama as model;
use model::ModelWeights;
@ -25,7 +24,7 @@ enum Prompt {
One(String),
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
#[value(name = "7b")]
L7b,
@ -49,10 +48,8 @@ enum Which {
Mistral7b,
#[value(name = "7b-mistral-instruct")]
Mistral7bInstruct,
#[value(name = "7b-zephyr-a")]
Zephyr7bAlpha,
#[value(name = "7b-zephyr-b")]
Zephyr7bBeta,
#[value(name = "7b-zephyr")]
Zephyr7b,
}
impl Which {
@ -67,28 +64,7 @@ impl Which {
| Self::L7bCode
| Self::L13bCode
| Self::L34bCode => false,
// Zephyr is a fine tuned version of mistral and should be treated in the same way.
Self::Zephyr7bAlpha
| Self::Zephyr7bBeta
| Self::Mistral7b
| Self::Mistral7bInstruct => true,
}
}
fn is_zephyr(&self) -> bool {
match self {
Self::L7b
| Self::L13b
| Self::L70b
| Self::L7bChat
| Self::L13bChat
| Self::L70bChat
| Self::L7bCode
| Self::L13bCode
| Self::L34bCode
| Self::Mistral7b
| Self::Mistral7bInstruct => false,
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
Self::Mistral7b | Self::Mistral7bInstruct | Self::Zephyr7b => true,
}
}
}
@ -107,7 +83,7 @@ struct Args {
prompt: Option<String>,
/// The length of the sample to generate (in tokens).
#[arg(short = 'n', long, default_value_t = 1000)]
#[arg(short = 'n', long, default_value_t = 100)]
sample_len: usize,
/// The tokenizer config in json format.
@ -200,13 +176,10 @@ impl Args {
"TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
"mistral-7b-instruct-v0.1.Q4_K_S.gguf",
),
Which::Zephyr7bAlpha => (
Which::Zephyr7b => (
"TheBloke/zephyr-7B-alpha-GGUF",
"zephyr-7b-alpha.Q4_K_M.gguf",
),
Which::Zephyr7bBeta => {
("TheBloke/zephyr-7B-beta-GGUF", "zephyr-7b-beta.Q4_K_M.gguf")
}
};
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(repo.to_string());
@ -217,6 +190,31 @@ impl Args {
}
}
fn print_token(next_token: u32, tokenizer: &Tokenizer) {
// Extracting the last token as a string is complicated, here we just apply some simple
// heuristics as it seems to work well enough for this example. See the following for more
// details:
// https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
if let Some(text) = tokenizer.id_to_token(next_token) {
let text = text.replace('▁', " ");
let ascii = text
.strip_prefix("<0x")
.and_then(|t| t.strip_suffix('>'))
.and_then(|t| u8::from_str_radix(t, 16).ok());
match ascii {
None => print!("{text}"),
Some(ascii) => {
if let Some(chr) = char::from_u32(ascii as u32) {
if chr.is_ascii() {
print!("{chr}")
}
}
}
}
let _ = std::io::stdout().flush();
}
}
fn format_size(size_in_bytes: usize) -> String {
if size_in_bytes < 1_000 {
format!("{}B", size_in_bytes)
@ -234,11 +232,13 @@ fn main() -> anyhow::Result<()> {
use tracing_subscriber::prelude::*;
let args = Args::parse();
let device = candle_examples::device(false)?;
let temperature = if args.temperature == 0. {
None
} else {
Some(args.temperature)
};
tracing_subscriber::fmt::init();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
@ -278,10 +278,10 @@ fn main() -> anyhow::Result<()> {
&format_size(total_size_in_bytes),
start.elapsed().as_secs_f32(),
);
ModelWeights::from_gguf(model, &mut file)?
ModelWeights::from_gguf(model, &mut file, &device)?
}
Some("ggml" | "bin") | Some(_) | None => {
let model = ggml_file::Content::read(&mut file)?;
let model = ggml_file::Content::read(&mut file, &device)?;
let mut total_size_in_bytes = 0;
for (_, tensor) in model.tensors.iter() {
let elem_count = tensor.shape().elem_count();
@ -305,18 +305,16 @@ fn main() -> anyhow::Result<()> {
| Which::L34bCode => 1,
Which::Mistral7b
| Which::Mistral7bInstruct
| Which::Zephyr7bAlpha
| Which::Zephyr7bBeta
| Which::Zephyr7b
| Which::L70b
| Which::L70bChat => 8,
};
ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa), &device)?
}
};
println!("model built");
let tokenizer = args.tokenizer()?;
let mut tos = TokenOutputStream::new(tokenizer);
let prompt = match args.prompt.as_deref() {
Some("chat") => Prompt::Chat,
Some("interactive") => Prompt::Interactive,
@ -325,11 +323,10 @@ fn main() -> anyhow::Result<()> {
};
let mut pre_prompt_tokens = vec![];
for prompt_index in 0.. {
loop {
let prompt_str = match &prompt {
Prompt::One(prompt) => prompt.clone(),
Prompt::Interactive | Prompt::Chat => {
let is_interactive = matches!(prompt, Prompt::Interactive);
print!("> ");
std::io::stdout().flush()?;
let mut prompt = String::new();
@ -340,13 +337,7 @@ fn main() -> anyhow::Result<()> {
prompt.pop();
}
}
if args.which.is_zephyr() {
if prompt_index == 0 || is_interactive {
format!("<|system|>\n</s>\n<|user|>\n{prompt}</s>\n<|assistant|>",)
} else {
format!("<|user|>\n{prompt}</s>\n<|assistant|>")
}
} else if args.which.is_mistral() {
if args.which.is_mistral() {
format!("[INST] {prompt} [/INST]")
} else {
prompt
@ -354,8 +345,7 @@ fn main() -> anyhow::Result<()> {
}
};
print!("{}", &prompt_str);
let tokens = tos
.tokenizer()
let tokens = tokenizer
.encode(prompt_str, true)
.map_err(anyhow::Error::msg)?;
if args.verbose_prompt {
@ -378,24 +368,23 @@ fn main() -> anyhow::Result<()> {
let start_prompt_processing = std::time::Instant::now();
let mut next_token = {
let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
let logits = model.forward(&input, 0)?;
let logits = logits.squeeze(0)?;
logits_processor.sample(&logits)?
// TODO Remove this once implementation is finished.
let logits = logits.ones_like()?;
// logits_processor.sample(&logits)?
15043
};
let prompt_dt = start_prompt_processing.elapsed();
all_tokens.push(next_token);
if let Some(t) = tos.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
print_token(next_token, &tokenizer);
let eos_token = *tos.tokenizer().get_vocab(true).get("</s>").unwrap();
let eos_token = *tokenizer.get_vocab(true).get("</s>").unwrap();
let start_post_prompt = std::time::Instant::now();
let mut sampled = 0;
for index in 0..to_sample {
let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
let logits = model.forward(&input, prompt_tokens.len() + index)?;
let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
@ -408,21 +397,16 @@ fn main() -> anyhow::Result<()> {
&all_tokens[start_at..],
)?
};
next_token = logits_processor.sample(&logits)?;
// TODO Remove this once implementation is finished.
// let logits = logits.ones_like()?;
// next_token = logits_processor.sample(&logits)?;
let next_token = 15043;
all_tokens.push(next_token);
if let Some(t) = tos.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
sampled += 1;
print_token(next_token, &tokenizer);
if next_token == eos_token {
break;
};
}
if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
let dt = start_post_prompt.elapsed();
println!(
"\n\n{:4} prompt tokens processed: {:.2} token/s",
@ -430,8 +414,9 @@ fn main() -> anyhow::Result<()> {
prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
);
println!(
"{sampled:4} tokens generated: {:.2} token/s",
sampled as f64 / dt.as_secs_f64(),
"{:4} tokens generated: {:.2} token/s",
to_sample,
to_sample as f64 / dt.as_secs_f64(),
);
match prompt {

View File

@ -5,26 +5,12 @@
```bash
$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "translate to German: A beautiful candle." --decode
...
Running on CPU, to run on GPU, build this example with `--features cuda`
Eine schöne Kerze.
9 tokens generated (2.42 token/s)
```
Variants such as [flan-t5](https://huggingface.co/google/flan-t5-small), [flan-ul2](https://huggingface.co/google/flan-ul2) (with `--revision "refs/pr/25"`), and [Co-EdIT](https://huggingface.co/grammarly/coedit-large) are also supported.
## Translation with [MADLAD-400](https://arxiv.org/abs/2309.04662)
MADLAD-400 is a series of multilingual machine translation T5 models trained on 250 billion tokens covering over 450 languages using publicly available data. These models are competitive with significantly larger models.
```bash
cargo run --example t5 --release -- \
--model-id "jbochi/madlad400-3b-mt" \
--prompt "<2de> How are you, my friend?" \
--decode --temperature 0
...
Wie geht es dir, mein Freund?
```
## Sentence embedding example
## Sentence embedding example:
```bash
$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "A beautiful candle."

View File

@ -104,17 +104,6 @@ impl T5ModelBuilder {
api.get("model-00004-of-00005.safetensors")?,
api.get("model-00005-of-00005.safetensors")?,
]
} else if model_id == "google/flan-ul2" {
vec![
api.get("model-00001-of-00008.safetensors")?,
api.get("model-00002-of-00008.safetensors")?,
api.get("model-00003-of-00008.safetensors")?,
api.get("model-00004-of-00008.safetensors")?,
api.get("model-00005-of-00008.safetensors")?,
api.get("model-00006-of-00008.safetensors")?,
api.get("model-00007-of-00008.safetensors")?,
api.get("model-00008-of-00008.safetensors")?,
]
} else {
vec![api.get("model.safetensors")?]
};
@ -183,12 +172,7 @@ fn main() -> Result<()> {
println!("Took {:?}", start.elapsed());
} else {
let mut model = builder.build_conditional_generation()?;
let mut output_token_ids = [builder
.config
.decoder_start_token_id
.unwrap_or(builder.config.pad_token_id)
as u32]
.to_vec();
let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
if let Some(decoder_prompt) = &args.decoder_prompt {
print!("{decoder_prompt}");
output_token_ids.extend(

Binary file not shown.

Before

Width:  |  Height:  |  Size: 36 KiB

View File

@ -1,154 +0,0 @@
use image::{DynamicImage, ImageBuffer};
use serde::Deserialize;
use std::collections::HashMap;
use candle::{DType, Device, Result, Tensor};
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct ProcessorConfig {
do_resize: bool,
height: u32,
width: u32,
do_rescale: bool,
do_normalize: bool,
image_mean: Vec<f32>,
image_std: Vec<f32>,
}
impl Default for ProcessorConfig {
fn default() -> Self {
Self {
do_resize: true,
height: 384,
width: 384,
do_rescale: true,
do_normalize: true,
image_mean: vec![0.5, 0.5, 0.5],
image_std: vec![0.5, 0.5, 0.5],
}
}
}
pub struct ViTImageProcessor {
do_resize: bool,
height: u32,
width: u32,
do_normalize: bool,
image_mean: Vec<f32>,
image_std: Vec<f32>,
}
impl ViTImageProcessor {
pub fn new(config: &ProcessorConfig) -> Self {
Self {
do_resize: config.do_resize,
height: config.height,
width: config.width,
do_normalize: config.do_normalize,
image_mean: config.image_mean.clone(),
image_std: config.image_std.clone(),
}
}
pub fn preprocess(&self, images: Vec<&str>) -> Result<Tensor> {
let height = self.height as usize;
let width = self.width as usize;
let channels = 3;
let images = self.load_images(images)?;
let resized_images: Vec<DynamicImage> = if self.do_resize {
images
.iter()
.map(|image| self.resize(image.clone(), None).unwrap())
.collect()
} else {
images
};
let normalized_images: Vec<Tensor> = if self.do_normalize {
resized_images
.iter()
.map(|image| self.normalize(image.clone(), None, None).unwrap())
.collect()
} else {
let resized_images: Vec<ImageBuffer<image::Rgb<u8>, Vec<u8>>> =
resized_images.iter().map(|image| image.to_rgb8()).collect();
let data = resized_images
.into_iter()
.map(|image| image.into_raw())
.collect::<Vec<Vec<u8>>>();
data.iter()
.map(|image| {
Tensor::from_vec(image.clone(), (height, width, channels), &Device::Cpu)
.unwrap()
.permute((2, 0, 1))
.unwrap()
})
.collect::<Vec<Tensor>>()
};
Tensor::stack(&normalized_images, 0)
}
fn resize(
&self,
image: image::DynamicImage,
size: Option<HashMap<String, u32>>,
) -> Result<image::DynamicImage> {
let (height, width) = match &size {
Some(size) => (size.get("height").unwrap(), size.get("width").unwrap()),
None => (&self.height, &self.width),
};
let resized_image =
image.resize_exact(*width, *height, image::imageops::FilterType::Triangle);
Ok(resized_image)
}
fn normalize(
&self,
image: image::DynamicImage,
mean: Option<Vec<f32>>,
std: Option<Vec<f32>>,
) -> Result<Tensor> {
let mean = match mean {
Some(mean) => mean,
None => self.image_mean.clone(),
};
let std = match std {
Some(std) => std,
None => self.image_std.clone(),
};
let mean = Tensor::from_vec(mean, (3, 1, 1), &Device::Cpu)?;
let std = Tensor::from_vec(std, (3, 1, 1), &Device::Cpu)?;
let image = image.to_rgb8();
let data = image.into_raw();
let height = self.height as usize;
let width = self.width as usize;
let channels = 3;
let data =
Tensor::from_vec(data, &[height, width, channels], &Device::Cpu)?.permute((2, 0, 1))?;
(data.to_dtype(DType::F32)? / 255.)?
.broadcast_sub(&mean)?
.broadcast_div(&std)
}
pub fn load_images(&self, image_path: Vec<&str>) -> Result<Vec<image::DynamicImage>> {
let mut images: Vec<image::DynamicImage> = Vec::new();
for path in image_path {
let img = image::io::Reader::open(path)?.decode().unwrap();
images.push(img);
}
Ok(images)
}
}

View File

@ -1,132 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Error as E;
use clap::{Parser, ValueEnum};
use candle::{DType, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::models::trocr;
use tokenizers::Tokenizer;
mod image_processor;
#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
Base,
Large,
}
#[derive(Parser, Debug)]
struct Args {
#[arg(long)]
model: Option<String>,
/// Choose the variant of the model to run.
#[arg(long, default_value = "base")]
which: Which,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Text to be translated
#[arg(long)]
image: String,
}
pub fn main() -> anyhow::Result<()> {
use hf_hub::api::sync::Api;
let args = Args::parse();
let tokenizer_dec = {
let tokenizer = Api::new()?
.model(String::from("ToluClassics/candle-trocr-tokenizer"))
.get("tokenizer.json")?;
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
};
let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);
let device = candle_examples::device(args.cpu)?;
let vb = {
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
None => match args.which {
Which::Base => Api::new()?
.repo(hf_hub::Repo::with_revision(
"microsoft/trocr-base-handwritten".to_string(),
hf_hub::RepoType::Model,
"refs/pr/3".to_string(),
))
.get("model.safetensors")?,
Which::Large => Api::new()?
.repo(hf_hub::Repo::with_revision(
"microsoft/trocr-large-handwritten".to_string(),
hf_hub::RepoType::Model,
"refs/pr/6".to_string(),
))
.get("model.safetensors")?,
},
};
println!("model: {:?}", model);
unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? }
};
let encoder_config = match args.which {
Which::Base => candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten(),
Which::Large => {
candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten()
}
};
let decoder_config = trocr::TrOCRConfig::default();
let mut model = trocr::TrOCRModel::new(&encoder_config, &decoder_config, vb)?;
let config = image_processor::ProcessorConfig::default();
let processor = image_processor::ViTImageProcessor::new(&config);
let image = vec![args.image.as_str()];
let image = processor.preprocess(image)?;
let encoder_xs = model.encoder().forward(&image)?;
let mut logits_processor =
candle_transformers::generation::LogitsProcessor::new(1337, None, None);
let mut token_ids: Vec<u32> = vec![decoder_config.decoder_start_token_id];
for index in 0..1000 {
let context_size = if index >= 1 { 1 } else { token_ids.len() };
let start_pos = token_ids.len().saturating_sub(context_size);
let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
let logits = model.decode(&input_ids, &encoder_xs, start_pos)?;
let logits = logits.squeeze(0)?;
let logits = logits.get(logits.dim(0)? - 1)?;
let token = logits_processor.sample(&logits)?;
token_ids.push(token);
if let Some(t) = tokenizer_dec.next_token(token)? {
use std::io::Write;
print!("{t}");
std::io::stdout().flush()?;
}
if token == decoder_config.eos_token_id {
break;
}
}
if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
println!();
Ok(())
}

View File

@ -1,16 +0,0 @@
# candle-trocr
`TrOCR` is a transformer OCR Model. In this example it is used to
transcribe image text. See the associated [model
card](https://huggingface.co/microsoft/trocr-base-printed) for details on
the model itself.
## Running an example
```bash
cargo run --example trocr --release -- --which base --cpu --image assets/trocr.png
```
```
<s> industry , Mr. Brown commented icily . " Let us have a</s>
```

View File

@ -345,7 +345,7 @@ enum Task {
Translate,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)]
#[derive(Clone, Copy, Debug, ValueEnum)]
enum WhichModel {
Tiny,
#[value(name = "tiny.en")]
@ -361,27 +361,15 @@ enum WhichModel {
MediumEn,
Large,
LargeV2,
LargeV3,
#[value(name = "distil-medium.en")]
DistilMediumEn,
#[value(name = "distil-large-v2")]
DistilLargeV2,
}
impl WhichModel {
fn is_multilingual(&self) -> bool {
match self {
Self::Tiny
| Self::Base
| Self::Small
| Self::Medium
| Self::Large
| Self::LargeV2
| Self::LargeV3
| Self::DistilLargeV2 => true,
Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn | Self::DistilMediumEn => {
false
Self::Tiny | Self::Base | Self::Small | Self::Medium | Self::Large | Self::LargeV2 => {
true
}
Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn => false,
}
}
@ -397,9 +385,6 @@ impl WhichModel {
Self::MediumEn => ("openai/whisper-medium.en", "main"),
Self::Large => ("openai/whisper-large", "refs/pr/36"),
Self::LargeV2 => ("openai/whisper-large-v2", "refs/pr/57"),
Self::LargeV3 => ("openai/whisper-large-v3", "main"),
Self::DistilMediumEn => ("distil-whisper/distil-medium.en", "main"),
Self::DistilLargeV2 => ("distil-whisper/distil-large-v2", "main"),
}
}
}
@ -511,25 +496,17 @@ fn main() -> Result<()> {
repo.get(&format!("model-{ext}-q80.gguf"))?,
)
} else {
let config = repo.get("config.json")?;
let tokenizer = if args.model == WhichModel::LargeV3 {
panic!("openai/whisper-large-v3 does not provide a compatible tokenizer.json config at the moment")
} else {
repo.get("tokenizer.json")?
};
let model = repo.get("model.safetensors")?;
(config, tokenizer, model)
(
repo.get("config.json")?,
repo.get("tokenizer.json")?,
repo.get("model.safetensors")?,
)
};
(config, tokenizer, model, sample)
};
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let mel_bytes = match config.num_mel_bins {
80 => include_bytes!("melfilters.bytes").as_slice(),
128 => include_bytes!("melfilters128.bytes").as_slice(),
nmel => anyhow::bail!("unexpected num_mel_bins {nmel}"),
};
let mel_bytes = include_bytes!("melfilters.bytes");
let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
<byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
@ -545,15 +522,12 @@ fn main() -> Result<()> {
.map(|v| *v as f32 / 32768.)
.collect();
println!("pcm data loaded {}", pcm_data.len());
let mel = audio::pcm_to_mel(&config, &pcm_data, &mel_filters);
let mel = audio::pcm_to_mel(&pcm_data, &mel_filters);
let mel_len = mel.len();
let mel = Tensor::from_vec(
mel,
(1, config.num_mel_bins, mel_len / config.num_mel_bins),
&device,
)?;
let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
println!("loaded mel: {:?}", mel.dims());
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
let mut model = if args.quantized {
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&weights_filename)?;

View File

@ -1,268 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};
use candle_transformers::models::yi::{Config, Model};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
#[value(name = "6b")]
L6b,
#[value(name = "34b")]
L34b,
}
struct TextGeneration {
model: Model,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("</s>") {
Some(token) => token,
None => anyhow::bail!("cannot find the </s> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize,
#[arg(long, default_value = "01-ai/Yi-6B")]
model_id: String,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
/// The model size to use.
#[arg(long, default_value = "6b")]
which: Which,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
args.model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => match args.which {
Which::L6b => vec![
repo.get("model-00001-of-00002.safetensors")?,
repo.get("model-00002-of-00002.safetensors")?,
],
Which::L34b => vec![
repo.get("model-00001-of-00007.safetensors")?,
repo.get("model-00002-of-00007.safetensors")?,
repo.get("model-00003-of-00007.safetensors")?,
repo.get("model-00004-of-00007.safetensors")?,
repo.get("model-00005-of-00007.safetensors")?,
repo.get("model-00006-of-00007.safetensors")?,
repo.get("model-00007-of-00007.safetensors")?,
],
},
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = match args.which {
Which::L6b => Config::config_6b(),
Which::L34b => Config::config_34b(),
};
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -8,22 +8,24 @@ use candle::{Device, Result, Tensor};
pub fn device(cpu: bool) -> Result<Device> {
if cpu {
Ok(Device::Cpu)
} else if cuda_is_available() {
Ok(Device::new_cuda(0)?)
} else if metal_is_available() {
Ok(Device::new_metal(0)?)
} else {
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
{
println!(
"Running on CPU, to run on GPU(metal), build this example with `--features metal`"
);
if cuda_is_available() {
Ok(Device::new_cuda(0)?)
} else if metal_is_available() {
Ok(Device::new_metal(0)?)
} else {
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
{
println!("Running on CPU, to run on GPU(metal), build this example with `--features metal`");
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
{
println!(
"Running on CPU, to run on GPU, build this example with `--features cuda`"
);
}
Ok(Device::Cpu)
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
{
println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
}
Ok(Device::Cpu)
}
}

View File

@ -233,8 +233,8 @@ impl FlashAttnVarLen {
let (seqlens_q, seqlens_q_layout) = self.seqlens_q.storage_and_layout();
let seqlens_q = match &*seqlens_q {
candle::Storage::Cpu(_) => candle::bail!("seqlens_q must be a cuda tensor"),
candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?, // Should be i32!
_ => candle::bail!("seqlens_q must be a cuda tensor"),
};
let seqlens_q = match seqlens_q_layout.contiguous_offsets() {
Some((o1, o2)) => seqlens_q.slice(o1..o2),
@ -243,8 +243,8 @@ impl FlashAttnVarLen {
let (seqlens_k, seqlens_k_layout) = self.seqlens_k.storage_and_layout();
let seqlens_k = match &*seqlens_k {
candle::Storage::Cpu(_) => candle::bail!("seqlens_k must be a cuda tensor"),
candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?, // Should be i32!
_ => candle::bail!("seqlens_k must be a cuda tensor"),
};
let seqlens_k = match seqlens_k_layout.contiguous_offsets() {
Some((o1, o2)) => seqlens_k.slice(o1..o2),

View File

@ -1,20 +1,17 @@
[package]
name = "candle-metal-kernels"
version = "0.3.0"
edition = "2021"
description = "CUDA kernels for Candle"
repository = "https://github.com/huggingface/candle"
keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT OR Apache-2.0"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true
[dependencies]
metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
metal = { workspace = true }
once_cell = "1.18.0"
thiserror = "1"
tracing = "0.1.37"
thiserror = { workspace = true }
[dev-dependencies]
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
rand = "0.8.5"
half = { workspace = true }

View File

@ -24,32 +24,15 @@ kernel void FN_NAME( \
constant float &add, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint id [[ thread_position_in_grid ]] \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
if (id >= dim) { \
return; \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
output[i] = input[i] * mul + add; \
} \
const TYPENAME m = TYPENAME(mul); \
const TYPENAME a = TYPENAME(add); \
output[id] = input[id] * m + a; \
} \
kernel void FN_NAME##_strided( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant float &mul, \
constant float &add, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint id [[ thread_position_in_grid ]] \
) { \
if (id >= dim) { \
return; \
} \
const TYPENAME m = TYPENAME(mul); \
const TYPENAME a = TYPENAME(add); \
output[id] = input[get_strided_index(id, num_dims, dims, strides)] * m + a; \
} \
AFFINE(affine_float, float)

View File

@ -23,14 +23,17 @@ kernel void FN_NAME( \
device const TYPENAME *left, \
device const TYPENAME *right, \
device TYPENAME *output, \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
if (thread_position_in_grid >= dim) { \
return; \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
TYPENAME x = left[i]; \
TYPENAME y = right[i]; \
output[i] = OUT_TYPENAME(FN); \
} \
TYPENAME x = left[thread_position_in_grid]; \
TYPENAME y = right[thread_position_in_grid]; \
output[thread_position_in_grid] = OUT_TYPENAME(FN); \
}\
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
@ -41,14 +44,17 @@ kernel void FN_NAME_STRIDED( \
device const TYPENAME *left, \
device const TYPENAME *right, \
device TYPENAME *output, \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
if (thread_position_in_grid >= dim) { \
return; \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
TYPENAME x = left[get_strided_index(i, num_dims, dims, left_strides)]; \
TYPENAME y = left[get_strided_index(i, num_dims, dims, right_strides)]; \
output[i] = OUT_TYPENAME(FN); \
} \
TYPENAME x = left[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \
TYPENAME y = right[get_strided_index(thread_position_in_grid, num_dims, dims, right_strides)]; \
output[thread_position_in_grid] = OUT_TYPENAME(FN); \
}
#define BINARY_OP(FN, NAME) \

View File

@ -23,12 +23,15 @@ kernel void FN_NAME( \
constant size_t &dim, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
if (thread_position_in_grid >= dim) { \
return; \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
output[i] = RIGHT_TYPENAME(input[i]); \
} \
output[thread_position_in_grid] = RIGHT_TYPENAME(input[thread_position_in_grid]); \
} \
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
@ -37,17 +40,19 @@ kernel void FN_NAME_STRIDED( \
constant size_t *strides, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
uint i [[ thread_position_in_grid ]] \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
if (i >= dim) { \
return; \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \
} \
output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \
} \
}
CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float)
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
#if __METAL_VERSION__ >= 310
#endif

View File

@ -1,36 +1,6 @@
#include <metal_stdlib>
using namespace metal;
# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
constant size_t &dst_size, \
constant size_t &left_size, \
constant size_t &src_dim_size, \
constant size_t &right_size, \
constant size_t &ids_size, \
const device TYPENAME *input, \
const device INDEX_TYPENAME *input_ids, \
device TYPENAME *output, \
uint gid [[ thread_position_in_grid ]] \
) { \
if (gid >= dst_size) { \
return; \
} \
const size_t id_i = (gid / right_size) % ids_size; \
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
const size_t right_rank_i = gid % right_size; \
const size_t left_rank_i = gid / right_size / ids_size; \
/* \
// Force prevent out of bounds indexing \
// since there doesn't seem to be a good way to force crash \
// No need to check for zero we're only allowing unsized. \
*/ \
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; \
output[gid] = input[src_i]; \
}
template <typename T, typename I>
void index_add(
device I *ids [[buffer(0)]],
@ -42,9 +12,12 @@ void index_add(
constant uint &dst_dim_size,
constant uint &right_size,
uint gid [[ thread_position_in_grid ]] \
uint threadgroup_size [[threads_per_threadgroup]],
uint threadgroup_position_in_grid [[threadgroup_position_in_grid]],
uint thread_index [[thread_index_in_threadgroup]]
) {
const uint gid = thread_index + (threadgroup_position_in_grid * threadgroup_size);
if (gid >= left_size * right_size) {
return;
}
@ -70,13 +43,12 @@ kernel void FN_NAME( \
constant uint &left_size, \
constant uint &dst_dim_size, \
constant uint &right_size, \
uint gid [[ thread_position_in_grid ]] \
) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, gid); } \
uint threadgroup_size [[threads_per_threadgroup]], \
uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, threadgroup_size, threadgroup_position_in_grid, thread_index); } \
INDEX_OP(is_u32_f32, uint, float)
INDEX_OP(is_u32_f16, uint, half)
#if __METAL_VERSION__ >= 310
IA_OP(bfloat, int64_t, ia_i64_bf16)

File diff suppressed because it is too large Load Diff

View File

@ -1,139 +0,0 @@
#include <metal_stdlib>
using namespace metal;
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
constant int THREADGROUP_SIZE = 1024;
# define REDUCE(FN, NAME, TYPENAME) \
kernel void NAME( \
constant size_t &src_numel, \
constant size_t &el_to_sum_per_block, \
device const TYPENAME *src, \
device TYPENAME *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint blockDim [[ threads_per_threadgroup ]] \
) { \
\
threadgroup float shared_memory[THREADGROUP_SIZE]; \
\
shared_memory[tid] = 0; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
size_t idx = start_idx + tid; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
// size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
*/ \
TYPENAME x = shared_memory[tid]; \
TYPENAME y = src[idx]; \
shared_memory[tid] = FN; \
idx += blockDim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
\
/* \
// reduction in shared memory \
*/ \
for (uint s = blockDim / 2; s > 0; s >>= 1) { \
if (tid < s) { \
TYPENAME x = shared_memory[tid]; \
TYPENAME y = shared_memory[tid + s]; \
shared_memory[tid] = FN; \
} \
threadgroup_barrier(mem_flags::mem_none); \
} \
\
dst[dst_id] = shared_memory[0]; \
} \
kernel void softmax_float(
constant size_t &src_numel,
constant size_t &el_to_sum_per_block,
device const float *src,
device float *dst,
uint id [[ thread_position_in_grid ]],
uint tid [[ thread_index_in_threadgroup ]],
uint dst_id [[ threadgroup_position_in_grid ]],
uint blockDim [[ threads_per_threadgroup ]]
) {
threadgroup float shared_memory[THREADGROUP_SIZE];
shared_memory[tid] = -INFINITY;
// Elements summed in this block range from dst_id * el_to_sum_per_block
// to (dst_id + 1) * el_to_sum_per_block.
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
size_t idx = start_idx + tid;
while (idx < stop_idx) {
// TODO: Fast version for the contiguous case.
shared_memory[tid] = max(shared_memory[tid], src[idx]);
idx += blockDim;
}
threadgroup_barrier(mem_flags::mem_none);
// reduction in shared memory
for (uint s = blockDim / 2; s > 0; s >>= 1) {
if (tid < s) {
shared_memory[tid] = max(shared_memory[tid], shared_memory[tid + s]);
}
threadgroup_barrier(mem_flags::mem_none);
}
float max = shared_memory[0];
shared_memory[tid] = 0;
// Restart
idx = start_idx + tid;
while (idx < stop_idx) {
// TODO: Fast version for the contiguous case.
const float val = exp(src[idx] - max);
dst[idx] = val;
shared_memory[tid] += val;
idx += blockDim;
}
// reduction in shared memory
for (uint s = blockDim / 2; s > 0; s >>= 1) {
if (tid < s) {
shared_memory[tid] += shared_memory[tid + s];
}
threadgroup_barrier(mem_flags::mem_none);
}
const float inv_acc = 1/shared_memory[0];
idx = start_idx + tid;
while (idx < stop_idx) {
dst[idx] *= inv_acc;
idx += blockDim;
}
}
REDUCE(x + y, fast_sum_float, float)
REDUCE(x * y, fast_mul_float, float)
REDUCE(max(x, y), fast_max_float, float)

View File

@ -1,57 +0,0 @@
#include <metal_stdlib>
#
using namespace metal;
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
#define WHERE_OP(TYPENAME, ID_TYPENAME, FN_NAME) \
kernel void FN_NAME( \
constant size_t &numel, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t *strides_t, \
constant size_t *strides_f, \
device const ID_TYPENAME *ids, \
device const TYPENAME *t, \
device const TYPENAME *f, \
device TYPENAME *out ,\
uint i [[ thread_position_in_grid ]] \
) { \
uint strided_i = get_strided_index(i, num_dims, dims, strides); \
uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \
uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \
out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f]; \
} \
// WHERE_OP(float, int64_t, where_i64_f32)
// WHERE_OP(double, int64_t, where_i64_f64)
// WHERE_OP(uint8_t, int64_t, where_i64_u8)
// WHERE_OP(uint32_t, int64_t, where_i64_u32)
// WHERE_OP(int64_t, int64_t, where_i64_i64)
//
// WHERE_OP(float, uint32_t, where_u32_f32)
// WHERE_OP(double, uint32_t, where_u32_f64)
// WHERE_OP(uint8_t, uint32_t, where_u32_u8)
// WHERE_OP(uint32_t, uint32_t, where_u32_u32)
// WHERE_OP(int64_t, uint32_t, where_u32_i64)
WHERE_OP(float, uint8_t, where_u8_f32)
// WHERE_OP(double, uint8_t, where_u8_f64)
// WHERE_OP(uint8_t, uint8_t, where_u8_u8)
// WHERE_OP(uint32_t, uint8_t, where_u8_u32)
// WHERE_OP(int64_t, uint8_t, where_u8_i64)

View File

@ -1,7 +1,4 @@
#include <metal_stdlib>
#include <metal_math>
#
using namespace metal;
METAL_FUNC uint get_strided_index(
uint idx,
@ -20,51 +17,25 @@ METAL_FUNC uint get_strided_index(
template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
template <typename T> METAL_FUNC T neg(T in){ return -in; }
template <typename T> METAL_FUNC T erf(T in){
float x = (float) in;
// constants
float a1 = 0.254829592;
float a2 = -0.284496736;
float a3 = 1.421413741;
float a4 = -1.453152027;
float a5 = 1.061405429;
float p = 0.3275911;
// Save the sign of x
int sign = 1;
if (x < 0)
sign = -1;
x = fabs(x);
// A&S formula 7.1.26
float t = 1.0/(1.0 + p*x);
float y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x);
return T(sign*y);
}
template <typename T> METAL_FUNC T id(T in){ return in; }
template <typename T> METAL_FUNC T gelu_erf(T x){ return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2); }
template <typename T> METAL_FUNC T gelu(T x){
T x_sq = x * x;
T x_cube = x_sq * x;
T alpha = x + static_cast<T>(0.044715) * x_cube;
T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
}
using namespace metal;
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
constant size_t &dim, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
if (thread_position_in_grid >= dim) { \
return; \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
output[i] = TYPENAME(FN(input[i])); \
} \
output[thread_position_in_grid] = TYPENAME(FN(input[thread_position_in_grid])); \
}\
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
@ -73,12 +44,15 @@ kernel void FN_NAME_STRIDED( \
constant size_t *strides, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
if (thread_position_in_grid >= dim) { \
return; \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
output[i] = TYPENAME(FN(input[get_strided_index(i, num_dims, dims, strides)])); \
} \
output[thread_position_in_grid] = TYPENAME(FN(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)])); \
}
#define UNARY_OP(NAME) \
@ -95,17 +69,8 @@ UNARY_OP(sqr)
UNARY_OP(sqrt)
UNARY_OP(neg)
UNARY_OP(exp)
UNARY_OP(log)
UNARY_OP(gelu)
UNARY_OP(ceil)
UNARY_OP(floor)
UNARY_OP(round)
UNARY_OP(gelu_erf)
UNARY_OP(erf)
UNARY(id, float, copy_float, copy_float_strided)
UNARY(id, half, copy_half, copy_half_strided)
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
UNARY(id, uint32_t, copy_u32, copy_u32_strided)
#if __METAL_VERSION__ >= 310
BFLOAT_UNARY_OP(cos)
@ -114,13 +79,4 @@ BFLOAT_UNARY_OP(sqr)
BFLOAT_UNARY_OP(sqrt)
BFLOAT_UNARY_OP(neg)
BFLOAT_UNARY_OP(exp)
BFLOAT_UNARY_OP(log)
BFLOAT_UNARY_OP(gelu)
BFLOAT_UNARY_OP(ceil)
BFLOAT_UNARY_OP(floor)
BFLOAT_UNARY_OP(round)
BFLOAT_UNARY_OP(gelu_erf)
BFLOAT_UNARY_OP(erf)
UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)
#endif

View File

@ -1,76 +0,0 @@
use candle_metal_kernels::{call_affine, Kernels};
use metal::objc::rc::autoreleasepool;
use metal::{Device, MTLResourceOptions};
use rand;
use std::any::type_name;
use std::time::Instant;
fn main() {
let device = Device::system_default().unwrap();
let kernels = Kernels::new();
let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
let f32_10k = (0..10000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
let f32_100k = (0..100000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
"dtype", "kernel", "size", "runs", "total time", "avg time"
);
// f32
run_affine_bench(&device, &kernels, &f32_1k);
run_affine_bench(&device, &kernels, &f32_10k);
run_affine_bench(&device, &kernels, &f32_100k);
}
fn run_affine_bench<T: Clone>(device: &Device, kernels: &Kernels, v: &[T]) {
let command_queue = device.new_command_queue();
let options = MTLResourceOptions::StorageModeManaged;
let iterations = 10000;
let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void,
core::mem::size_of_val(v) as u64,
options,
);
let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
let mul: f32 = 1.2345;
let add: f32 = 2.3456;
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
for _ in 0..iterations {
call_affine(
&device,
command_buffer,
&kernels,
"affine_float",
v.len(),
&input,
&mut output,
mul,
add,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
start.elapsed()
});
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
"affine",
v.len(),
iterations,
total_time,
total_time / iterations
);
}

View File

@ -1,182 +0,0 @@
use candle_metal_kernels::{binary, call_binary_contiguous, call_binary_strided, Kernels};
use half::{bf16, f16};
use metal::objc::rc::autoreleasepool;
use metal::{Device, MTLResourceOptions};
use rand;
use std::any::type_name;
use std::time::Instant;
fn main() {
let device = Device::system_default().unwrap();
let kernels = Kernels::new();
let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
let f32_10k = (0..10000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
let f32_100k = (0..100000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
let f16_map = |v: &[f32]| v.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
let f16_1k = f16_map(&f32_1k);
let f16_10k = f16_map(&f32_10k);
let f16_100k = f16_map(&f32_100k);
let bf16_map = |v: &[f32]| v.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
let bf16_1k = bf16_map(&f32_1k);
let bf16_10k = bf16_map(&f32_10k);
let bf16_100k = bf16_map(&f32_100k);
let f32_ckernels = [
binary::contiguous::add::FLOAT,
binary::contiguous::sub::FLOAT,
binary::contiguous::mul::FLOAT,
binary::contiguous::div::FLOAT,
];
let f32_skernels = [
binary::strided::add::FLOAT,
binary::strided::sub::FLOAT,
binary::strided::mul::FLOAT,
binary::strided::div::FLOAT,
];
let f16_ckernels = [
binary::contiguous::add::HALF,
binary::contiguous::sub::HALF,
binary::contiguous::mul::HALF,
binary::contiguous::div::HALF,
];
let f16_skernels = [
binary::strided::add::HALF,
binary::strided::sub::HALF,
binary::strided::mul::HALF,
binary::strided::div::HALF,
];
let bf16_ckernels = [
binary::contiguous::add::BFLOAT,
binary::contiguous::sub::BFLOAT,
binary::contiguous::mul::BFLOAT,
binary::contiguous::div::BFLOAT,
];
let bf16_skernels = [
binary::strided::add::BFLOAT,
binary::strided::sub::BFLOAT,
binary::strided::mul::BFLOAT,
binary::strided::div::BFLOAT,
];
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
"dtype", "kernel", "size", "runs", "total time", "avg time"
);
// f32
run_binary_bench(&device, &kernels, &f32_1k, f32_ckernels, f32_skernels);
run_binary_bench(&device, &kernels, &f32_10k, f32_ckernels, f32_skernels);
run_binary_bench(&device, &kernels, &f32_100k, f32_ckernels, f32_skernels);
// f16
run_binary_bench(&device, &kernels, &f16_1k, f16_ckernels, f16_skernels);
run_binary_bench(&device, &kernels, &f16_10k, f16_ckernels, f16_skernels);
run_binary_bench(&device, &kernels, &f16_100k, f16_ckernels, f16_skernels);
// bf16
run_binary_bench(&device, &kernels, &bf16_1k, bf16_ckernels, bf16_skernels);
run_binary_bench(&device, &kernels, &bf16_10k, bf16_ckernels, bf16_skernels);
run_binary_bench(&device, &kernels, &bf16_100k, bf16_ckernels, bf16_skernels);
}
fn run_binary_bench<T: Clone>(
device: &Device,
kernels: &Kernels,
v: &[T],
contiguous: [binary::contiguous::Kernel; 4],
strided: [binary::strided::Kernel; 4],
) {
let command_queue = device.new_command_queue();
let options = MTLResourceOptions::StorageModeManaged;
let iterations = 1000;
let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void,
core::mem::size_of_val(v) as u64,
options,
);
let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
// Contiguous
for kernel_name in contiguous {
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
for _ in 0..iterations {
call_binary_contiguous(
device,
&command_buffer,
kernels,
kernel_name,
v.len(),
&input,
&input,
&mut output,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
start.elapsed()
});
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
kernel_name.to_string(),
v.len(),
iterations,
total_time,
total_time / iterations
);
}
// Strided
let shape = vec![2, 5_000];
let strides = vec![2, 1];
let offset = 0;
for kernel_name in strided {
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
for _ in 0..iterations {
call_binary_strided(
device,
command_buffer,
&kernels,
kernel_name,
&shape,
&input,
&strides,
offset,
&input,
&strides,
offset,
&mut output,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
start.elapsed()
});
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
kernel_name.to_string(),
v.len(),
iterations,
total_time,
total_time / iterations
);
}
}

View File

@ -1,84 +0,0 @@
use candle_metal_kernels::{call_cast_contiguous, Kernels};
use metal::objc::rc::autoreleasepool;
use metal::{Device, MTLResourceOptions};
use rand;
use std::any::type_name;
use std::time::Instant;
fn main() {
let device = Device::system_default().unwrap();
let kernels = Kernels::new();
let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
let f32_10k = (0..10000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
let f32_100k = (0..100000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
let contiguous_kernels = ["cast_u32_f32"];
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
"dtype", "kernel", "size", "runs", "total time", "avg time"
);
// f32
run_cast_bench(&device, &kernels, &f32_1k, &contiguous_kernels);
run_cast_bench(&device, &kernels, &f32_10k, &contiguous_kernels);
run_cast_bench(&device, &kernels, &f32_100k, &contiguous_kernels);
}
fn run_cast_bench<T: Clone>(
device: &Device,
kernels: &Kernels,
v: &[T],
contiguous: &[&'static str],
) {
let command_queue = device.new_command_queue();
let options = MTLResourceOptions::StorageModeManaged;
let iterations = 1000;
let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void,
core::mem::size_of_val(v) as u64,
options,
);
let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
// Contiguous
for kernel_name in contiguous {
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
for _ in 0..iterations {
call_cast_contiguous(
device,
&command_buffer,
kernels,
kernel_name,
v.len(),
&input,
&mut output,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
start.elapsed()
});
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
kernel_name.to_string(),
v.len(),
iterations,
total_time,
total_time / iterations
);
}
// Strided?
}

View File

@ -1,197 +0,0 @@
use candle_metal_kernels::{call_unary_contiguous, call_unary_strided, unary, Kernels};
use half::{bf16, f16};
use metal::objc::rc::autoreleasepool;
use metal::{Device, MTLResourceOptions};
use rand;
use std::any::type_name;
use std::time::Instant;
fn main() {
let device = Device::system_default().unwrap();
let kernels = Kernels::new();
let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
let f32_10k = (0..10000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
let f32_100k = (0..100000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
let f16_map = |v: &[f32]| v.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
let f16_1k = f16_map(&f32_1k);
let f16_10k = f16_map(&f32_10k);
let f16_100k = f16_map(&f32_100k);
let bf16_map = |v: &[f32]| v.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
let bf16_1k = bf16_map(&f32_1k);
let bf16_10k = bf16_map(&f32_10k);
let bf16_100k = bf16_map(&f32_100k);
let f32_ckernels = [
unary::contiguous::sin::FLOAT,
unary::contiguous::cos::FLOAT,
unary::contiguous::exp::FLOAT,
unary::contiguous::sqr::FLOAT,
unary::contiguous::sqrt::FLOAT,
unary::contiguous::neg::FLOAT,
unary::contiguous::copy::FLOAT,
];
let f32_skernels = [
unary::strided::sin::FLOAT,
unary::strided::cos::FLOAT,
unary::strided::exp::FLOAT,
unary::strided::sqr::FLOAT,
unary::strided::sqrt::FLOAT,
unary::strided::neg::FLOAT,
unary::strided::copy::FLOAT,
];
let f16_ckernels = [
unary::contiguous::sin::HALF,
unary::contiguous::cos::HALF,
unary::contiguous::exp::HALF,
unary::contiguous::sqr::HALF,
unary::contiguous::sqrt::HALF,
unary::contiguous::neg::HALF,
unary::contiguous::copy::HALF,
];
let f16_skernels = [
unary::strided::sin::HALF,
unary::strided::cos::HALF,
unary::strided::exp::HALF,
unary::strided::sqr::HALF,
unary::strided::sqrt::HALF,
unary::strided::neg::HALF,
unary::strided::copy::HALF,
];
let bf16_ckernels = [
unary::contiguous::sin::BFLOAT,
unary::contiguous::cos::BFLOAT,
unary::contiguous::exp::BFLOAT,
unary::contiguous::sqr::BFLOAT,
unary::contiguous::sqrt::BFLOAT,
unary::contiguous::neg::BFLOAT,
unary::contiguous::copy::BFLOAT,
];
let bf16_skernels = [
unary::strided::sin::BFLOAT,
unary::strided::cos::BFLOAT,
unary::strided::exp::BFLOAT,
unary::strided::sqr::BFLOAT,
unary::strided::sqrt::BFLOAT,
unary::strided::neg::BFLOAT,
unary::strided::copy::BFLOAT,
];
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
"dtype", "kernel", "size", "runs", "total time", "avg time"
);
// f32
run_unary_bench(&device, &kernels, &f32_1k, f32_ckernels, f32_skernels);
run_unary_bench(&device, &kernels, &f32_10k, f32_ckernels, f32_skernels);
run_unary_bench(&device, &kernels, &f32_100k, f32_ckernels, f32_skernels);
// f16
run_unary_bench(&device, &kernels, &f16_1k, f16_ckernels, f16_skernels);
run_unary_bench(&device, &kernels, &f16_10k, f16_ckernels, f16_skernels);
run_unary_bench(&device, &kernels, &f16_100k, f16_ckernels, f16_skernels);
// bf16
run_unary_bench(&device, &kernels, &bf16_1k, bf16_ckernels, bf16_skernels);
run_unary_bench(&device, &kernels, &bf16_10k, bf16_ckernels, bf16_skernels);
run_unary_bench(&device, &kernels, &bf16_100k, bf16_ckernels, bf16_skernels);
}
fn run_unary_bench<T: Clone>(
device: &Device,
kernels: &Kernels,
v: &[T],
contiguous: [unary::contiguous::Kernel; 7],
strided: [unary::strided::Kernel; 7],
) {
let command_queue = device.new_command_queue();
let options = MTLResourceOptions::StorageModeManaged;
let iterations = 10000;
let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void,
core::mem::size_of_val(v) as u64,
options,
);
let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
// Contiguous
for kernel_name in contiguous {
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
for _ in 0..iterations {
call_unary_contiguous(
device,
&command_buffer,
kernels,
kernel_name,
v.len(),
&input,
&mut output,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
start.elapsed()
});
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
kernel_name.0,
v.len(),
iterations,
total_time,
total_time / iterations
);
}
// Strided
let shape = vec![2, 5_000];
let strides = vec![2, 1];
let offset = 0;
for kernel_name in &strided {
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
for _ in 0..iterations {
call_unary_strided(
device,
command_buffer,
&kernels,
kernel_name,
&shape,
&input,
&strides,
offset,
&mut output,
0,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
start.elapsed()
});
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
kernel_name.0,
v.len(),
iterations,
total_time,
total_time / iterations
);
}
}

View File

@ -14,12 +14,12 @@ accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
half = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
[dev-dependencies]
anyhow = { workspace = true }
@ -29,5 +29,5 @@ clap = { workspace = true }
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate"]
cuda = ["candle/cuda"]
metal = ["candle/metal"]
mkl = ["dep:intel-mkl-src", "candle/mkl"]
metal = ["candle/metal", "dep:candle-metal-kernels"]

View File

@ -6,16 +6,14 @@ use serde::Deserialize;
pub enum Activation {
#[default]
Gelu,
#[serde(rename = "gated-gelu")]
NewGelu,
Relu,
Relu2,
Relu6,
Silu,
Sigmoid,
HardSigmoid,
Swiglu,
Swish,
HardSwish,
Elu(f64),
LeakyRelu(f64),
}
@ -31,10 +29,7 @@ impl super::Module for Activation {
Self::Relu6 => xs.clamp(0f32, 6f32),
Self::Silu => crate::ops::silu(xs),
Self::Sigmoid => crate::ops::sigmoid(xs),
Self::HardSigmoid => crate::ops::hard_sigmoid(xs),
Self::Swiglu => crate::ops::swiglu(xs),
Self::Swish => xs * crate::ops::sigmoid(xs)?,
Self::HardSwish => xs * crate::ops::hard_sigmoid(xs)?,
&Self::Elu(alpha) => xs.elu(alpha),
&Self::LeakyRelu(negative_slope) => crate::ops::leaky_relu(xs, negative_slope),
}

View File

@ -70,67 +70,6 @@ impl crate::Module for Conv1d {
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ConvTranspose1dConfig {
pub padding: usize,
pub output_padding: usize,
pub stride: usize,
pub dilation: usize,
// TODO: support groups.
}
impl Default for ConvTranspose1dConfig {
fn default() -> Self {
Self {
padding: 0,
output_padding: 0,
stride: 1,
dilation: 1,
}
}
}
#[derive(Clone, Debug)]
pub struct ConvTranspose1d {
weight: Tensor,
bias: Option<Tensor>,
config: ConvTranspose1dConfig,
}
impl ConvTranspose1d {
pub fn new(weight: Tensor, bias: Option<Tensor>, config: ConvTranspose1dConfig) -> Self {
Self {
weight,
bias,
config,
}
}
pub fn config(&self) -> &ConvTranspose1dConfig {
&self.config
}
}
impl crate::Module for ConvTranspose1d {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = x.conv_transpose1d(
&self.weight,
self.config.padding,
self.config.output_padding,
self.config.stride,
self.config.dilation,
)?;
match &self.bias {
None => Ok(x),
Some(bias) => {
let b = bias.dims1()?;
let bias = bias.reshape((1, b, 1, 1))?;
Ok(x.broadcast_add(&bias)?)
}
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Conv2dConfig {
pub padding: usize,
@ -302,39 +241,6 @@ pub fn conv1d(
Ok(Conv1d::new(ws, Some(bs), cfg))
}
pub fn conv_transpose1d(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
cfg: ConvTranspose1dConfig,
vb: crate::VarBuilder,
) -> Result<ConvTranspose1d> {
let bound = 1. / (out_channels as f64 * kernel_size as f64).sqrt();
let init = crate::Init::Uniform {
lo: -bound,
up: bound,
};
let ws = vb.get_with_hints((in_channels, out_channels, kernel_size), "weight", init)?;
let bs = vb.get_with_hints(out_channels, "bias", init)?;
Ok(ConvTranspose1d::new(ws, Some(bs), cfg))
}
pub fn conv_transpose1d_no_bias(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
cfg: ConvTranspose1dConfig,
vb: crate::VarBuilder,
) -> Result<ConvTranspose1d> {
let bound = 1. / (out_channels as f64 * kernel_size as f64).sqrt();
let init = crate::Init::Uniform {
lo: -bound,
up: bound,
};
let ws = vb.get_with_hints((in_channels, out_channels, kernel_size), "weight", init)?;
Ok(ConvTranspose1d::new(ws, None, cfg))
}
pub fn conv2d(
in_channels: usize,
out_channels: usize,

View File

@ -95,14 +95,6 @@ impl LayerNorm {
eps,
}
}
pub fn weight(&self) -> &Tensor {
&self.weight
}
pub fn bias(&self) -> Option<&Tensor> {
self.bias.as_ref()
}
}
impl crate::Module for LayerNorm {

View File

@ -1,5 +1,6 @@
use candle::{CpuStorage, Layout, Result, Shape, Tensor};
use rayon::prelude::*;
use tracing::debug;
/// Applies the softmax function to the input tensor, rescaling the element so that elements on
/// a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1.
@ -39,21 +40,11 @@ pub fn silu(xs: &Tensor) -> Result<Tensor> {
xs / (xs.neg()?.exp()? + 1.0)?
}
pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
let xs = xs.chunk(2, candle::D::Minus1)?;
crate::ops::silu(&xs[0])? * &xs[1]
}
pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
// TODO: Should we have a specialized op for this?
(xs.neg()?.exp()? + 1.0)?.recip()
}
pub fn hard_sigmoid(xs: &Tensor) -> Result<Tensor> {
// TODO: Should we have a specialized op for this?
((xs + 3.0)? / 6.0)?.clamp(0f32, 1f32)
}
pub fn leaky_relu(xs: &Tensor, negative_slope: f64) -> Result<Tensor> {
let zeros = xs.zeros_like()?;
xs.maximum(&zeros)? + xs.minimum(&zeros)? * negative_slope
@ -208,29 +199,8 @@ impl candle::CustomOp1 for SoftmaxLastDim {
storage: &candle::MetalStorage,
layout: &Layout,
) -> Result<(candle::MetalStorage, Shape)> {
use candle::backend::{BackendStorage};
let device = storage.device();
let command_buffer = device.command_buffer();
let kernels = device.kernels();
let name = "softmax_float";
assert!(layout.is_contiguous());
assert!(layout.start_offset() == 0);
let last_dim = layout.dims()[layout.shape().rank() - 1];
let elem_count = layout.shape().elem_count();
let mut output = device.new_buffer(elem_count, storage.dtype());
candle_metal_kernels::call_last_softmax(
device.metal_device(),
&command_buffer,
&kernels,
name,
elem_count,
last_dim,
storage.buffer(),
&mut output,
)
.unwrap();
let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
Ok((newstorage, layout.shape().clone()))
debug!("TODO softmax-last-dim");
Ok((storage.clone(), layout.shape().clone()))
}
}

View File

@ -1,23 +0,0 @@
[package]
name = "candle-onnx"
version = "0.3.0"
edition = "2021"
description = "ONNX support for Candle"
repository = "https://github.com/huggingface/candle"
keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
prost = "0.12.1"
[build-dependencies]
prost-build = "0.12.1"
[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
clap = { version = "4.2.4", features = ["derive"] }

View File

@ -1,21 +0,0 @@
# candle-onnx
This crate adds ONNX support to candle
## FAQ
#### Missing protoc installation when compiling candle-onnx
The candle-onnx dependency prost-build no longer comes bundled with prost
binaries. This could cause the following error when attempting to compile
candle-onnx:
```
error: failed to run custom build command for `candle-onnx`
Caused by: // (...)
Could not find `protoc` installation and this build crate cannot proceed without this knowledge.
```
To fix this issue install protoc on your system and make it available in your
system `PATH`. See the [protoc
documentation](https://grpc.io/docs/protoc-installation/) for more information.

View File

@ -1,6 +0,0 @@
use std::io::Result;
fn main() -> Result<()> {
prost_build::compile_protos(&["src/onnx.proto3"], &["src/"])?;
Ok(())
}

View File

@ -1,755 +0,0 @@
use crate::onnx;
use crate::onnx::attribute_proto::AttributeType;
use crate::onnx::tensor_proto::DataType;
use candle::{bail, DType, Device, Result, Tensor};
use std::collections::HashMap;
pub type Value = Tensor;
pub fn dtype(dt: DataType) -> Option<DType> {
match dt {
DataType::Uint8 => Some(DType::U8),
DataType::Uint32 => Some(DType::U32),
DataType::Int64 => Some(DType::I64),
DataType::Float16 => Some(DType::F16),
DataType::Float => Some(DType::F32),
DataType::Double => Some(DType::F64),
_ => None,
}
}
trait Attr {
const TYPE: AttributeType;
fn get(attr: &onnx::AttributeProto) -> Result<&Self>;
}
impl Attr for i64 {
const TYPE: AttributeType = AttributeType::Int;
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
Ok(&attr.i)
}
}
impl Attr for f32 {
const TYPE: AttributeType = AttributeType::Float;
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
Ok(&attr.f)
}
}
impl Attr for [i64] {
const TYPE: AttributeType = AttributeType::Ints;
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
Ok(attr.ints.as_slice())
}
}
impl Attr for str {
const TYPE: AttributeType = AttributeType::String;
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
std::str::from_utf8(&attr.s).map_err(candle::Error::wrap)
}
}
fn get_attr_<'a>(node: &'a onnx::NodeProto, name: &str) -> Result<&'a onnx::AttributeProto> {
match node.attribute.iter().find(|attr| attr.name == name) {
None => {
bail!(
"cannot find the '{name}' attribute in '{}' for {}",
node.op_type,
node.name
)
}
Some(dt) => Ok(dt),
}
}
fn get_attr<'a, T: Attr + ?Sized>(node: &'a onnx::NodeProto, name: &str) -> Result<&'a T> {
let attr = get_attr_(node, name)?;
if attr.r#type() != T::TYPE {
bail!(
"unsupported type {:?} for '{name}' attribute in '{}' for {}",
attr.r#type,
node.op_type,
node.name
)
}
T::get(attr)
}
fn get_attr_opt<'a, T: Attr + ?Sized>(
node: &'a onnx::NodeProto,
name: &str,
) -> Result<Option<&'a T>> {
match node.attribute.iter().find(|attr| attr.name == name) {
None => Ok(None),
Some(attr) => {
if attr.r#type() != T::TYPE {
bail!(
"unsupported type {:?} for '{name}' attribute in '{}' for {}",
attr.r#type,
node.op_type,
node.name
)
}
let val = T::get(attr)?;
Ok(Some(val))
}
}
}
pub fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
match DataType::try_from(t.data_type) {
Ok(DataType::Int32) => {
if t.int32_data.is_empty() {
let len = t.raw_data.len() / 4;
let data: &[i32] =
unsafe { std::slice::from_raw_parts(t.raw_data.as_ptr() as *const i32, len) };
let data = data.iter().map(|v| *v as i64).collect::<Vec<_>>();
Tensor::from_vec(data, len, &Device::Cpu)
} else {
let data = t.int32_data.iter().map(|v| *v as i64).collect::<Vec<_>>();
Tensor::from_vec(data, t.int32_data.len(), &Device::Cpu)
}
}
Ok(dt) => match dtype(dt) {
Some(dt) => {
if dt == DType::F32 && !t.float_data.is_empty() {
Tensor::from_slice(&t.float_data, dims.as_slice(), &Device::Cpu)
} else if dt == DType::F64 && !t.double_data.is_empty() {
Tensor::from_slice(&t.double_data, dims.as_slice(), &Device::Cpu)
} else if dt == DType::I64 && !t.int64_data.is_empty() {
Tensor::from_slice(&t.int64_data, dims.as_slice(), &Device::Cpu)
} else {
Tensor::from_raw_buffer(
t.raw_data.as_slice(),
dt,
dims.as_slice(),
&Device::Cpu,
)
}
}
None => {
bail!("unsupported 'value' data-type {dt:?} for {name}")
}
},
Err(_) => {
bail!("unsupported 'value' data-type {} for {name}", t.data_type,)
}
}
}
// This function provides a direct evaluation of the proto.
// Longer-term, we should first convert the proto to an intermediate representation of the compute
// graph so as to make multiple evaluations more efficient.
// An example upside of this would be to remove intermediary values when they are not needed
// anymore.
pub fn simple_eval(
model: &onnx::ModelProto,
inputs: HashMap<String, Value>,
) -> Result<HashMap<String, Value>> {
let graph = match &model.graph {
None => bail!("no graph defined in proto"),
Some(graph) => graph,
};
let mut values = inputs;
for t in graph.initializer.iter() {
let tensor = get_tensor(t, t.name.as_str())?;
values.insert(t.name.to_string(), tensor);
}
for input in graph.input.iter() {
let input_type = match &input.r#type {
Some(input_type) => input_type,
None => continue,
};
let input_type = match &input_type.value {
Some(input_type) => input_type,
None => continue,
};
let tensor_type = match input_type {
onnx::type_proto::Value::TensorType(tt) => tt,
_ => continue,
};
let tensor = match values.get(&input.name) {
None => bail!("missing input {}", input.name),
Some(tensor) => tensor,
};
let dt = match DataType::try_from(tensor_type.elem_type) {
Ok(dt) => match dtype(dt) {
Some(dt) => dt,
None => {
bail!("unsupported 'value' data-type {dt:?} for {}", input.name)
}
},
type_ => bail!("unsupported input type {type_:?}"),
};
match &tensor_type.shape {
None => continue,
Some(shape) => {
if shape.dim.len() != tensor.rank() {
bail!(
"unexpected rank for {}, got {:?}, expected {:?}",
input.name,
shape.dim,
tensor.shape()
)
}
for (idx, (d, &dim)) in shape.dim.iter().zip(tensor.dims().iter()).enumerate() {
match &d.value {
Some(onnx::tensor_shape_proto::dimension::Value::DimValue(v)) => {
if *v as usize != dim {
bail!(
"unexpected dim {idx} for {}, got {:?}, expected {:?}",
input.name,
shape.dim,
tensor.shape()
)
}
}
// We do not check equality constraints for the DimParam dimensions for now.
Some(onnx::tensor_shape_proto::dimension::Value::DimParam(_)) | None => (),
}
}
}
};
if dt != tensor.dtype() {
bail!(
"unexpected dtype for {}, got {:?}, expected {dt:?}",
input.name,
tensor.dtype()
)
}
}
// The nodes are topologically sorted so we can just process them in order.
for node in graph.node.iter() {
let get = |input_name: &str| match values.get(input_name) {
Some(value) => Ok(value),
None => bail!("cannot find {input_name} for op {}", node.name),
};
// TODO: Validate node.input for each operator.
match node.op_type.as_str() {
"Add" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
let output = input0.broadcast_add(input1)?;
values.insert(node.output[0].clone(), output);
}
"Sub" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
let output = input0.broadcast_sub(input1)?;
values.insert(node.output[0].clone(), output);
}
"Mul" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
let output = input0.broadcast_mul(input1)?;
values.insert(node.output[0].clone(), output);
}
"Div" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
let output = input0.broadcast_div(input1)?;
values.insert(node.output[0].clone(), output);
}
"Equal" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
let output = input0.broadcast_eq(input1)?;
values.insert(node.output[0].clone(), output);
}
"Not" => {
let xs = get(&node.input[0])?;
let xs = xs.eq(&xs.zeros_like()?)?;
values.insert(node.output[0].clone(), xs);
}
"MatMul" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
let output = input0.broadcast_matmul(input1)?;
values.insert(node.output[0].clone(), output);
}
"Reshape" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?.to_vec1::<i64>()?;
// TODO: Check that there is at most a single -1 or 0, handle other neg values.
let mut other_than_minus1 = 1usize;
for &v in input1.iter() {
if v != -1 && v != 0 {
other_than_minus1 *= v as usize
}
}
let input1 = input1
.iter()
.enumerate()
.map(|(idx, &v)| match v {
-1 => Ok(input0.elem_count() / other_than_minus1),
0 => input0.dim(idx),
_ => Ok(v as usize),
})
.collect::<Result<Vec<usize>>>()?;
let output = input0.reshape(input1)?;
values.insert(node.output[0].clone(), output);
}
"LogSoftmax" => {
let input = get(&node.input[0])?;
let output = match get_attr_opt::<i64>(node, "axis")? {
None => candle_nn::ops::softmax_last_dim(input)?,
Some(&axis) => {
let axis = input.normalize_axis(axis)?;
candle_nn::ops::log_softmax(input, axis)?
}
};
values.insert(node.output[0].clone(), output);
}
"Softmax" => {
let input = get(&node.input[0])?;
let output = match get_attr_opt::<i64>(node, "axis")? {
None => candle_nn::ops::softmax_last_dim(input)?,
Some(&axis) => {
let axis = input.normalize_axis(axis)?;
candle_nn::ops::softmax(input, axis)?
}
};
values.insert(node.output[0].clone(), output);
}
"Transpose" => {
let input = get(&node.input[0])?;
let output = match get_attr_opt::<[i64]>(node, "perm")? {
None => input.t()?,
Some(perm) => {
let perm = perm.iter().map(|&v| v as usize).collect::<Vec<_>>();
input.permute(perm)?
}
};
values.insert(node.output[0].clone(), output);
}
"Dropout" => {
let input = get(&node.input[0])?;
// Do not apply dropout at the moment, consider that we're only doing inference.
values.insert(node.output[0].clone(), input.clone());
}
"MaxPool" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#MaxPool
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
let kernel_shape = get_attr::<[i64]>(node, "kernel_shape")?;
let pads = get_attr_opt::<[i64]>(node, "pads")?;
let strides = get_attr_opt::<[i64]>(node, "strides")?;
let auto_pad = get_attr_opt::<str>(node, "auto_pad")?;
match auto_pad {
None | Some("NOTSET") => (),
Some(s) => bail!("unsupported auto_pad {s}"),
};
if let Some(d) = dilations {
if d.iter().any(|&v| v != 1) {
bail!("MaxPool with dilation != 1, {dilations:?}")
}
}
if let Some(d) = pads {
if d.iter().any(|&v| v != 0) {
bail!("MaxPool with pads != 0, {pads:?}")
}
}
let xs = get(&node.input[0])?;
let (k1, k2) = match kernel_shape {
[k1, k2] => (*k1 as usize, *k2 as usize),
_ => bail!("only 2d MaxPool is supported, kernel shape {kernel_shape:?}"),
};
let ys = match strides {
None => xs.max_pool2d((k1, k2))?,
Some([s1, s2]) => {
xs.max_pool2d_with_stride((k1, k2), (*s1 as usize, *s2 as usize))?
}
Some(strides) => bail!("only 2d MaxPool is supported, strides {strides:?}"),
};
values.insert(node.output[0].clone(), ys);
}
"AveragePool" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#AveragePool
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
let kernel_shape = get_attr::<[i64]>(node, "kernel_shape")?;
let pads = get_attr_opt::<[i64]>(node, "pads")?;
let strides = get_attr_opt::<[i64]>(node, "strides")?;
let auto_pad = get_attr_opt::<str>(node, "auto_pad")?;
match auto_pad {
None | Some("NOTSET") => (),
Some(s) => bail!("unsupported auto_pad {s}"),
};
if let Some(d) = dilations {
if d.iter().any(|&v| v != 1) {
bail!("AvgPool with dilation != 1, {dilations:?}")
}
}
if let Some(d) = pads {
if d.iter().any(|&v| v != 0) {
bail!("AvgPool with pads != 0, {pads:?}")
}
}
let xs = get(&node.input[0])?;
let (k1, k2) = match kernel_shape {
[k1, k2] => (*k1 as usize, *k2 as usize),
_ => bail!("only 2d AvgPool is supported, kernel shape {kernel_shape:?}"),
};
let ys = match strides {
None => xs.avg_pool2d((k1, k2))?,
Some([s1, s2]) => {
xs.avg_pool2d_with_stride((k1, k2), (*s1 as usize, *s2 as usize))?
}
Some(strides) => bail!("only 2d AvgPool is supported, strides {strides:?}"),
};
values.insert(node.output[0].clone(), ys);
}
"BatchNormalization" => {
let training_mode = get_attr_opt::<i64>(node, "training_mode")?;
if training_mode.copied().unwrap_or(0) != 0 {
bail!("training mode is not supported for BatchNorm")
}
let eps = get_attr_opt::<f32>(node, "epsilon")?
.copied()
.unwrap_or(1e-5);
let xs = get(&node.input[0])?;
let weight = get(&node.input[1])?;
let bias = get(&node.input[2])?;
let running_mean = get(&node.input[3])?;
let running_var = get(&node.input[4])?;
let target_shape: Vec<usize> = xs
.dims()
.iter()
.enumerate()
.map(|(idx, v)| if idx == 1 { *v } else { 1 })
.collect();
let target_shape = target_shape.as_slice();
let xs = xs
.broadcast_sub(&running_mean.reshape(target_shape)?)?
.broadcast_div(&(running_var.reshape(target_shape)? + eps as f64)?.sqrt()?)?;
let weight = weight.reshape(target_shape)?;
let bias = bias.reshape(target_shape)?;
let xs = xs.broadcast_mul(&weight)?.broadcast_add(&bias)?;
values.insert(node.output[0].clone(), xs);
}
"Squeeze" => {
let xs = get(&node.input[0])?;
let mut axes = if node.input.len() <= 1 {
// contract all the dimensions with size 1 except the batch dim.
xs.dims()
.iter()
.enumerate()
.flat_map(|(idx, &s)| if s == 1 && idx > 0 { Some(idx) } else { None })
.collect()
} else {
get(&node.input[1])?
.to_vec1::<i64>()?
.iter()
.map(|&i| xs.normalize_axis(i))
.collect::<Result<Vec<_>>>()?
};
axes.sort();
let mut xs = xs.clone();
for &axis in axes.iter().rev() {
xs = xs.squeeze(axis)?
}
values.insert(node.output[0].clone(), xs);
}
"ConstantOfShape" => {
let dims = get(&node.input[0])?;
let shape = dims
.to_vec1::<i64>()?
.into_iter()
.map(|v| v as usize)
.collect::<Vec<_>>();
let xs = Tensor::zeros(shape, DType::F32, dims.device())?;
values.insert(node.output[0].clone(), xs);
}
"Unsqueeze" => {
let xs = get(&node.input[0])?;
let axes = match get_attr_opt::<[i64]>(node, "axes")? {
Some(axis) => axis.to_vec(),
None => get(&node.input[1])?.to_vec1::<i64>()?,
};
let mut axes = axes
.iter()
.map(|&i| {
if i == xs.rank() as i64 {
Ok(xs.rank())
} else {
xs.normalize_axis(i)
}
})
.collect::<Result<Vec<_>>>()?;
axes.sort();
let mut xs = xs.clone();
for &axis in axes.iter().rev() {
xs = xs.unsqueeze(axis)?
}
values.insert(node.output[0].clone(), xs);
}
"Clip" => {
let xs = get(&node.input[0])?;
let xs = if node.input.len() >= 2 {
let mins = get(&node.input[1])?;
xs.broadcast_maximum(mins)?
} else {
xs.clone()
};
let xs = if node.input.len() >= 3 {
let maxs = get(&node.input[2])?;
xs.broadcast_minimum(maxs)?
} else {
xs.clone()
};
values.insert(node.output[0].clone(), xs);
}
"Gather" => {
let xs = get(&node.input[0])?;
let indices = get(&node.input[1])?;
let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
let axis = xs.normalize_axis(axis)?;
// TODO: Provide an op to handle the ONNX generalized gather op ideally in a
// differentiable way.
let xs = if indices.rank() == 0 {
let index = indices.to_vec0::<i64>()? as usize;
xs.narrow(axis, index, 1)?.squeeze(axis)?
} else {
todo!("implement gather for {xs:?} {indices:?} axis {axis}")
};
values.insert(node.output[0].clone(), xs);
}
"Shape" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape
let xs = get(&node.input[0])?;
let start = get_attr_opt::<i64>(node, "start")?.copied().unwrap_or(0);
let end = get_attr_opt::<i64>(node, "end")?.copied().unwrap_or(-1);
let start = xs.normalize_axis(start)?;
let end = xs.normalize_axis(end)?;
let mut dims = vec![];
for idx in start..=end {
dims.push(xs.dim(idx)? as i64)
}
let dims = Tensor::from_vec(dims, xs.rank(), xs.device())?;
values.insert(node.output[0].clone(), dims);
}
"Conv" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
let groups = get_attr_opt::<i64>(node, "group")?.copied().unwrap_or(1);
let _kernel_shape = get_attr_opt::<[i64]>(node, "kernel_shape")?;
let pads = get_attr_opt::<[i64]>(node, "pads")?;
let strides = get_attr_opt::<[i64]>(node, "strides")?;
let auto_pad = get_attr_opt::<str>(node, "auto_pad")?;
match auto_pad {
None | Some("NOTSET") => (),
Some(s) => bail!("unsupported auto_pad {s}"),
};
let xs = get(&node.input[0])?;
let ws = get(&node.input[1])?;
let ys = match ws.rank() {
3 => {
let (pads, xs) = match pads {
None => (0, xs.clone()),
Some([p]) => (*p as usize, xs.clone()),
Some([p1, p2]) => {
if p1 != p2 {
(0usize, xs.pad_with_zeros(2, *p1 as usize, *p2 as usize)?)
} else {
(*p1 as usize, xs.clone())
}
}
Some(pads) => {
bail!("more pads than expected in conv1d {pads:?} {}", node.name)
}
};
let strides = match strides {
None => 1,
Some([p]) => *p as usize,
Some(s) => {
bail!("more strides than expected in conv1d {s:?} {}", node.name)
}
};
let dilations = match dilations {
None => 1,
Some([p]) => *p as usize,
Some(s) => {
bail!("more dilations than expected in conv1d {s:?} {}", node.name)
}
};
xs.conv1d(ws, pads, strides, dilations, groups as usize)?
}
4 => {
let (pads, xs) = match pads {
None => (0, xs.clone()),
Some([p]) => (*p as usize, xs.clone()),
Some(&[p1, p2, p3, p4]) => {
let p1 = p1 as usize;
let p2 = p2 as usize;
let p3 = p3 as usize;
let p4 = p4 as usize;
if p1 != p2 || p1 != p3 || p1 != p4 {
(0, xs.pad_with_zeros(2, p1, p3)?.pad_with_zeros(3, p2, p4)?)
} else {
(p1, xs.clone())
}
}
Some(pads) => {
bail!("more pads than expected in conv2d {pads:?} {}", node.name)
}
};
let strides = match strides {
None => 1,
Some([p]) => *p as usize,
Some([p1, p2]) => {
if p1 != p2 {
bail!(
"strides have to be the same on both axis {pads:?} {}",
node.name
)
}
*p1 as usize
}
Some(s) => {
bail!("more strides than expected in conv2d {s:?} {}", node.name)
}
};
let dilations = match dilations {
None => 1,
Some([p]) => *p as usize,
Some([p1, p2]) => {
if p1 != p2 {
bail!(
"dilations have to be the same on both axis {pads:?} {}",
node.name
)
}
*p1 as usize
}
Some(s) => {
bail!("more dilations than expected in conv2d {s:?} {}", node.name)
}
};
xs.conv2d(ws, pads, strides, dilations, groups as usize)?
}
rank => bail!(
"unsupported rank for weight matrix {rank} in conv {}",
node.name
),
};
let ys = if node.input.len() > 2 {
let bs = get(&node.input[2])?;
let mut bs_shape = vec![1; ys.rank()];
bs_shape[1] = bs.elem_count();
ys.broadcast_add(&bs.reshape(bs_shape)?)?
} else {
ys
};
values.insert(node.output[0].clone(), ys);
}
"Concat" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Concat
let inputs = node
.input
.iter()
.map(|n| Ok(get(n.as_str())?.clone()))
.collect::<Result<Vec<Value>>>()?;
let axis: i64 = *get_attr(node, "axis")?;
if inputs.is_empty() {
bail!("empty concat")
};
let axis = inputs[0].normalize_axis(axis)?;
let output = Tensor::cat(&inputs, axis)?;
values.insert(node.output[0].clone(), output);
}
"Abs" => {
let input = get(&node.input[0])?;
let output = input.abs()?;
values.insert(node.output[0].clone(), output);
}
"Cos" => {
let input = get(&node.input[0])?;
let output = input.cos()?;
values.insert(node.output[0].clone(), output);
}
"Sin" => {
let input = get(&node.input[0])?;
let output = input.sin()?;
values.insert(node.output[0].clone(), output);
}
"Neg" => {
let input = get(&node.input[0])?;
let output = input.neg()?;
values.insert(node.output[0].clone(), output);
}
"Erf" => {
let input = get(&node.input[0])?;
let output = input.erf()?;
values.insert(node.output[0].clone(), output);
}
"Tanh" => {
let input = get(&node.input[0])?;
let output = input.tanh()?;
values.insert(node.output[0].clone(), output);
}
"Sigmoid" => {
let input = get(&node.input[0])?;
let output = candle_nn::ops::sigmoid(input)?;
values.insert(node.output[0].clone(), output);
}
"Gelu" => {
let input = get(&node.input[0])?;
let output = input.gelu_erf()?;
values.insert(node.output[0].clone(), output);
}
"Relu" => {
let input = get(&node.input[0])?;
let output = input.relu()?;
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Constant
"Constant" => {
let value = match node.attribute.iter().find(|attr| attr.name == "value") {
None => {
// TODO: support sparse_value etc.
bail!("cannot find 'value' attr in 'Constant' for {}", node.name)
}
Some(value) => value,
};
let output = match value.r#type() {
AttributeType::Tensor => {
let t = value.t.as_ref().unwrap();
get_tensor(t, &node.name)?
}
rtype => bail!("unsupported 'value' type {rtype:?} for {}", node.name),
};
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast
"Cast" => {
let input = get(&node.input[0])?;
let dt: i64 = *get_attr(node, "to")?;
let dtype = match DataType::try_from(dt as i32) {
Ok(DataType::Int32) => DType::I64,
Ok(dt) => match dtype(dt) {
Some(dt) => dt,
None => {
bail!("unsupported 'to' value {dt:?} for cast {}", node.name)
}
},
Err(_) => {
bail!("unsupported 'to' value {dt:?} for cast {}", node.name)
}
};
let output = input.to_dtype(dtype)?;
values.insert(node.output[0].clone(), output);
}
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
}
}
graph
.output
.iter()
.map(|output| match values.remove(&output.name) {
None => bail!("cannot find output {}", output.name),
Some(value) => Ok((output.name.clone(), value)),
})
.collect()
}

View File

@ -1,14 +0,0 @@
use candle::Result;
use prost::Message;
pub mod onnx {
include!(concat!(env!("OUT_DIR"), "/onnx.rs"));
}
pub mod eval;
pub use eval::{dtype, simple_eval};
pub fn read_file<P: AsRef<std::path::Path>>(p: P) -> Result<onnx::ModelProto> {
let buf = std::fs::read(p)?;
onnx::ModelProto::decode(buf.as_slice()).map_err(candle::Error::wrap)
}

View File

@ -1,836 +0,0 @@
//
// WARNING: This file is automatically generated! Please edit onnx.in.proto.
//
// SPDX-License-Identifier: Apache-2.0
syntax = "proto3";
package onnx;
// Overview
//
// ONNX is an open specification that is comprised of the following components:
//
// 1) A definition of an extensible computation graph model.
// 2) Definitions of standard data types.
// 3) Definitions of built-in operators.
//
// This document describes the syntax of models and their computation graphs,
// as well as the standard data types. Together, they are referred to as the ONNX
// Intermediate Representation, or 'IR' for short.
//
// The normative semantic specification of the ONNX IR is found in docs/IR.md.
// Definitions of the built-in neural network operators may be found in docs/Operators.md.
// Notes
//
// Protobuf compatibility
//
// To simplify framework compatibility, ONNX is defined using the subset of protobuf
// that is compatible with both protobuf v2 and v3. This means that we do not use any
// protobuf features that are only available in one of the two versions.
//
// Here are the most notable contortions we have to carry out to work around
// these limitations:
//
// - No 'map' (added protobuf 3.0). We instead represent mappings as lists
// of key-value pairs, where order does not matter and duplicates
// are not allowed.
// Versioning
//
// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md
//
// To be compatible with both proto2 and proto3, we will use a version number
// that is not defined by the default value but an explicit enum number.
enum Version {
// proto3 requires the first enum value to be zero.
// We add this just to appease the compiler.
_START_VERSION = 0;
// The version field is always serialized and we will use it to store the
// version that the graph is generated from. This helps us set up version
// control.
// For the IR, we are using simple numbers starting with 0x00000001,
// which was the version we published on Oct 10, 2017.
IR_VERSION_2017_10_10 = 0x0000000000000001;
// IR_VERSION 2 published on Oct 30, 2017
// - Added type discriminator to AttributeProto to support proto3 users
IR_VERSION_2017_10_30 = 0x0000000000000002;
// IR VERSION 3 published on Nov 3, 2017
// - For operator versioning:
// - Added new message OperatorSetIdProto
// - Added opset_import in ModelProto
// - For vendor extensions, added domain in NodeProto
IR_VERSION_2017_11_3 = 0x0000000000000003;
// IR VERSION 4 published on Jan 22, 2019
// - Relax constraint that initializers should be a subset of graph inputs
// - Add type BFLOAT16
IR_VERSION_2019_1_22 = 0x0000000000000004;
// IR VERSION 5 published on March 18, 2019
// - Add message TensorAnnotation.
// - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters.
IR_VERSION_2019_3_18 = 0x0000000000000005;
// IR VERSION 6 published on Sep 19, 2019
// - Add support for sparse tensor constants stored in model.
// - Add message SparseTensorProto
// - Add sparse initializers
IR_VERSION_2019_9_19 = 0x0000000000000006;
// IR VERSION 7 published on May 8, 2020
// - Add support to allow function body graph to rely on multiple external opreator sets.
// - Add a list to promote inference graph's initializers to global and
// mutable variables. Global variables are visible in all graphs of the
// stored models.
// - Add message TrainingInfoProto to store initialization
// method and training algorithm. The execution of TrainingInfoProto
// can modify the values of mutable variables.
// - Implicitly add inference graph into each TrainingInfoProto's algorithm.
IR_VERSION_2020_5_8 = 0x0000000000000007;
// IR VERSION 8 published on July 30, 2021
// Introduce TypeProto.SparseTensor
// Introduce TypeProto.Optional
// Added a list of FunctionProtos local to the model
// Deprecated since_version and operator status from FunctionProto
IR_VERSION_2021_7_30 = 0x0000000000000008;
// IR VERSION 9 published on May 5, 2023
// Added AttributeProto to FunctionProto so that default attribute values can be set.
// Added FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ.
IR_VERSION = 0x0000000000000009;
}
// Attributes
//
// A named attribute containing either singular float, integer, string, graph,
// and tensor values, or repeated float, integer, string, graph, and tensor values.
// An AttributeProto MUST contain the name field, and *only one* of the
// following content fields, effectively enforcing a C/C++ union equivalent.
message AttributeProto {
reserved 12, 16 to 19;
reserved "v";
// Note: this enum is structurally identical to the OpSchema::AttrType
// enum defined in schema.h. If you rev one, you likely need to rev the other.
enum AttributeType {
UNDEFINED = 0;
FLOAT = 1;
INT = 2;
STRING = 3;
TENSOR = 4;
GRAPH = 5;
SPARSE_TENSOR = 11;
TYPE_PROTO = 13;
FLOATS = 6;
INTS = 7;
STRINGS = 8;
TENSORS = 9;
GRAPHS = 10;
SPARSE_TENSORS = 12;
TYPE_PROTOS = 14;
}
// The name field MUST be present for this version of the IR.
string name = 1; // namespace Attribute
// if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function.
// In this case, this AttributeProto does not contain data, and it's a reference of attribute
// in parent scope.
// NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph.
string ref_attr_name = 21;
// A human-readable documentation for this attribute. Markdown is allowed.
string doc_string = 13;
// The type field MUST be present for this version of the IR.
// For 0.0.1 versions of the IR, this field was not defined, and
// implementations needed to use has_field heuristics to determine
// which value field was in use. For IR_VERSION 0.0.2 or later, this
// field MUST be set and match the f|i|s|t|... field in use. This
// change was made to accommodate proto3 implementations.
AttributeType type = 20; // discriminator that indicates which field below is in use
// Exactly ONE of the following fields must be present for this version of the IR
float f = 2; // float
int64 i = 3; // int
bytes s = 4; // UTF-8 string
TensorProto t = 5; // tensor value
GraphProto g = 6; // graph
SparseTensorProto sparse_tensor = 22; // sparse tensor value
// Do not use field below, it's deprecated.
// optional ValueProto v = 12; // value - subsumes everything but graph
TypeProto tp = 14; // type proto
repeated float floats = 7; // list of floats
repeated int64 ints = 8; // list of ints
repeated bytes strings = 9; // list of UTF-8 strings
repeated TensorProto tensors = 10; // list of tensors
repeated GraphProto graphs = 11; // list of graph
repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors
repeated TypeProto type_protos = 15;// list of type protos
}
// Defines information on value, including the name, the type, and
// the shape of the value.
message ValueInfoProto {
// This field MUST be present in this version of the IR.
string name = 1; // namespace Value
// This field MUST be present in this version of the IR for
// inputs and outputs of the top-level graph.
TypeProto type = 2;
// A human-readable documentation for this value. Markdown is allowed.
string doc_string = 3;
}
// Nodes
//
// Computation graphs are made up of a DAG of nodes, which represent what is
// commonly called a "layer" or "pipeline stage" in machine learning frameworks.
//
// For example, it can be a node of type "Conv" that takes in an image, a filter
// tensor and a bias tensor, and produces the convolved output.
message NodeProto {
repeated string input = 1; // namespace Value
repeated string output = 2; // namespace Value
// An optional identifier for this node in a graph.
// This field MAY be absent in ths version of the IR.
string name = 3; // namespace Node
// The symbolic identifier of the Operator to execute.
string op_type = 4; // namespace Operator
// The domain of the OperatorSet that specifies the operator named by op_type.
string domain = 7; // namespace Domain
// Additional named attributes.
repeated AttributeProto attribute = 5;
// A human-readable documentation for this node. Markdown is allowed.
string doc_string = 6;
}
// Training information
// TrainingInfoProto stores information for training a model.
// In particular, this defines two functionalities: an initialization-step
// and a training-algorithm-step. Initialization resets the model
// back to its original state as if no training has been performed.
// Training algorithm improves the model based on input data.
//
// The semantics of the initialization-step is that the initializers
// in ModelProto.graph and in TrainingInfoProto.algorithm are first
// initialized as specified by the initializers in the graph, and then
// updated by the "initialization_binding" in every instance in
// ModelProto.training_info.
//
// The field "algorithm" defines a computation graph which represents a
// training algorithm's step. After the execution of a
// TrainingInfoProto.algorithm, the initializers specified by "update_binding"
// may be immediately updated. If the targeted training algorithm contains
// consecutive update steps (such as block coordinate descent methods),
// the user needs to create a TrainingInfoProto for each step.
message TrainingInfoProto {
// This field describes a graph to compute the initial tensors
// upon starting the training process. Initialization graph has no input
// and can have multiple outputs. Usually, trainable tensors in neural
// networks are randomly initialized. To achieve that, for each tensor,
// the user can put a random number operator such as RandomNormal or
// RandomUniform in TrainingInfoProto.initialization.node and assign its
// random output to the specific tensor using "initialization_binding".
// This graph can also set the initializers in "algorithm" in the same
// TrainingInfoProto; a use case is resetting the number of training
// iteration to zero.
//
// By default, this field is an empty graph and its evaluation does not
// produce any output. Thus, no initializer would be changed by default.
GraphProto initialization = 1;
// This field represents a training algorithm step. Given required inputs,
// it computes outputs to update initializers in its own or inference graph's
// initializer lists. In general, this field contains loss node, gradient node,
// optimizer node, increment of iteration count.
//
// An execution of the training algorithm step is performed by executing the
// graph obtained by combining the inference graph (namely "ModelProto.graph")
// and the "algorithm" graph. That is, the actual
// input/initializer/output/node/value_info/sparse_initializer list of
// the training graph is the concatenation of
// "ModelProto.graph.input/initializer/output/node/value_info/sparse_initializer"
// and "algorithm.input/initializer/output/node/value_info/sparse_initializer"
// in that order. This combined graph must satisfy the normal ONNX conditions.
// Now, let's provide a visualization of graph combination for clarity.
// Let the inference graph (i.e., "ModelProto.graph") be
// tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d
// and the "algorithm" graph be
// tensor_d -> Add -> tensor_e
// The combination process results
// tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d -> Add -> tensor_e
//
// Notice that an input of a node in the "algorithm" graph may reference the
// output of a node in the inference graph (but not the other way round). Also, inference
// node cannot reference inputs of "algorithm". With these restrictions, inference graph
// can always be run independently without training information.
//
// By default, this field is an empty graph and its evaluation does not
// produce any output. Evaluating the default training step never
// update any initializers.
GraphProto algorithm = 2;
// This field specifies the bindings from the outputs of "initialization" to
// some initializers in "ModelProto.graph.initializer" and
// the "algorithm.initializer" in the same TrainingInfoProto.
// See "update_binding" below for details.
//
// By default, this field is empty and no initializer would be changed
// by the execution of "initialization".
repeated StringStringEntryProto initialization_binding = 3;
// Gradient-based training is usually an iterative procedure. In one gradient
// descent iteration, we apply
//
// x = x - r * g
//
// where "x" is the optimized tensor, "r" stands for learning rate, and "g" is
// gradient of "x" with respect to a chosen loss. To avoid adding assignments
// into the training graph, we split the update equation into
//
// y = x - r * g
// x = y
//
// The user needs to save "y = x - r * g" into TrainingInfoProto.algorithm. To
// tell that "y" should be assigned to "x", the field "update_binding" may
// contain a key-value pair of strings, "x" (key of StringStringEntryProto)
// and "y" (value of StringStringEntryProto).
// For a neural network with multiple trainable (mutable) tensors, there can
// be multiple key-value pairs in "update_binding".
//
// The initializers appears as keys in "update_binding" are considered
// mutable variables. This implies some behaviors
// as described below.
//
// 1. We have only unique keys in all "update_binding"s so that two
// variables may not have the same name. This ensures that one
// variable is assigned up to once.
// 2. The keys must appear in names of "ModelProto.graph.initializer" or
// "TrainingInfoProto.algorithm.initializer".
// 3. The values must be output names of "algorithm" or "ModelProto.graph.output".
// 4. Mutable variables are initialized to the value specified by the
// corresponding initializer, and then potentially updated by
// "initializer_binding"s and "update_binding"s in "TrainingInfoProto"s.
//
// This field usually contains names of trainable tensors
// (in ModelProto.graph), optimizer states such as momentums in advanced
// stochastic gradient methods (in TrainingInfoProto.graph),
// and number of training iterations (in TrainingInfoProto.graph).
//
// By default, this field is empty and no initializer would be changed
// by the execution of "algorithm".
repeated StringStringEntryProto update_binding = 4;
}
// Models
//
// ModelProto is a top-level file/container format for bundling a ML model and
// associating its computation graph with metadata.
//
// The semantics of the model are described by the associated GraphProto's.
message ModelProto {
// The version of the IR this model targets. See Version enum above.
// This field MUST be present.
int64 ir_version = 1;
// The OperatorSets this model relies on.
// All ModelProtos MUST have at least one entry that
// specifies which version of the ONNX OperatorSet is
// being imported.
//
// All nodes in the ModelProto's graph will bind against the operator
// with the same-domain/same-op_type operator with the HIGHEST version
// in the referenced operator sets.
repeated OperatorSetIdProto opset_import = 8;
// The name of the framework or tool used to generate this model.
// This field SHOULD be present to indicate which implementation/tool/framework
// emitted the model.
string producer_name = 2;
// The version of the framework or tool used to generate this model.
// This field SHOULD be present to indicate which implementation/tool/framework
// emitted the model.
string producer_version = 3;
// Domain name of the model.
// We use reverse domain names as name space indicators. For example:
// `com.facebook.fair` or `com.microsoft.cognitiveservices`
//
// Together with `model_version` and GraphProto.name, this forms the unique identity of
// the graph.
string domain = 4;
// The version of the graph encoded. See Version enum below.
int64 model_version = 5;
// A human-readable documentation for this model. Markdown is allowed.
string doc_string = 6;
// The parameterized graph that is evaluated to execute the model.
GraphProto graph = 7;
// Named metadata values; keys should be distinct.
repeated StringStringEntryProto metadata_props = 14;
// Training-specific information. Sequentially executing all stored
// `TrainingInfoProto.algorithm`s and assigning their outputs following
// the corresponding `TrainingInfoProto.update_binding`s is one training
// iteration. Similarly, to initialize the model
// (as if training hasn't happened), the user should sequentially execute
// all stored `TrainingInfoProto.initialization`s and assigns their outputs
// using `TrainingInfoProto.initialization_binding`s.
//
// If this field is empty, the training behavior of the model is undefined.
repeated TrainingInfoProto training_info = 20;
// A list of function protos local to the model.
//
// Name of the function "FunctionProto.name" should be unique within the domain "FunctionProto.domain".
// In case of any conflicts the behavior (whether the model local functions are given higher priority,
// or standard operator sets are given higher priotity or this is treated as error) is defined by
// the runtimes.
//
// The operator sets imported by FunctionProto should be compatible with the ones
// imported by ModelProto and other model local FunctionProtos.
// Example, if same operator set say 'A' is imported by a FunctionProto and ModelProto
// or by 2 FunctionProtos then versions for the operator set may be different but,
// the operator schema returned for op_type, domain, version combination
// for both the versions should be same for every node in the function body.
//
// One FunctionProto can reference other FunctionProto in the model, however, recursive reference
// is not allowed.
repeated FunctionProto functions = 25;
};
// StringStringEntryProto follows the pattern for cross-proto-version maps.
// See https://developers.google.com/protocol-buffers/docs/proto3#maps
message StringStringEntryProto {
string key = 1;
string value = 2;
};
message TensorAnnotation {
string tensor_name = 1;
// <key, value> pairs to annotate tensor specified by <tensor_name> above.
// The keys used in the mapping below must be pre-defined in ONNX spec.
// For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as
// quantization parameter keys.
repeated StringStringEntryProto quant_parameter_tensor_names = 2;
}
// Graphs
//
// A graph defines the computational logic of a model and is comprised of a parameterized
// list of nodes that form a directed acyclic graph based on their inputs and outputs.
// This is the equivalent of the "network" or "graph" in many deep learning
// frameworks.
message GraphProto {
// The nodes in the graph, sorted topologically.
repeated NodeProto node = 1;
// The name of the graph.
string name = 2; // namespace Graph
// A list of named tensor values, used to specify constant inputs of the graph.
// Each initializer (both TensorProto as well SparseTensorProto) MUST have a name.
// The name MUST be unique across both initializer and sparse_initializer,
// but the name MAY also appear in the input list.
repeated TensorProto initializer = 5;
// Initializers (see above) stored in sparse format.
repeated SparseTensorProto sparse_initializer = 15;
// A human-readable documentation for this graph. Markdown is allowed.
string doc_string = 10;
// The inputs and outputs of the graph.
repeated ValueInfoProto input = 11;
repeated ValueInfoProto output = 12;
// Information for the values in the graph. The ValueInfoProto.name's
// must be distinct. It is optional for a value to appear in value_info list.
repeated ValueInfoProto value_info = 13;
// This field carries information to indicate the mapping among a tensor and its
// quantization parameter tensors. For example:
// For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated,
// which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model.
repeated TensorAnnotation quantization_annotation = 14;
reserved 3, 4, 6 to 9;
reserved "ir_version", "producer_version", "producer_tag", "domain";
}
// Tensors
//
// A serialized tensor value.
message TensorProto {
enum DataType {
UNDEFINED = 0;
// Basic types.
FLOAT = 1; // float
UINT8 = 2; // uint8_t
INT8 = 3; // int8_t
UINT16 = 4; // uint16_t
INT16 = 5; // int16_t
INT32 = 6; // int32_t
INT64 = 7; // int64_t
STRING = 8; // string
BOOL = 9; // bool
// IEEE754 half-precision floating-point format (16 bits wide).
// This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
FLOAT16 = 10;
DOUBLE = 11;
UINT32 = 12;
UINT64 = 13;
COMPLEX64 = 14; // complex with float32 real and imaginary components
COMPLEX128 = 15; // complex with float64 real and imaginary components
// Non-IEEE floating-point format based on IEEE754 single-precision
// floating-point number truncated to 16 bits.
// This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
BFLOAT16 = 16;
// Non-IEEE floating-point format based on papers
// FP8 Formats for Deep Learning, https://arxiv.org/abs/2209.05433,
// 8-bit Numerical Formats For Deep Neural Networks, https://arxiv.org/pdf/2206.02915.pdf.
// Operators supported FP8 are Cast, CastLike, QuantizeLinear, DequantizeLinear.
// The computation usually happens inside a block quantize / dequantize
// fused by the runtime.
FLOAT8E4M3FN = 17; // float 8, mostly used for coefficients, supports nan, not inf
FLOAT8E4M3FNUZ = 18; // float 8, mostly used for coefficients, supports nan, not inf, no negative zero
FLOAT8E5M2 = 19; // follows IEEE 754, supports nan, inf, mostly used for gradients
FLOAT8E5M2FNUZ = 20; // follows IEEE 754, supports nan, inf, mostly used for gradients, no negative zero
// Future extensions go here.
}
// The shape of the tensor.
repeated int64 dims = 1;
// The data type of the tensor.
// This field MUST have a valid TensorProto.DataType value
int32 data_type = 2;
// For very large tensors, we may want to store them in chunks, in which
// case the following fields will specify the segment that is stored in
// the current TensorProto.
message Segment {
int64 begin = 1;
int64 end = 2;
}
Segment segment = 3;
// Tensor content must be organized in row-major order.
//
// Depending on the data_type field, exactly one of the fields below with
// name ending in _data is used to store the elements of the tensor.
// For float and complex64 values
// Complex64 tensors are encoded as a single array of floats,
// with the real components appearing in odd numbered positions,
// and the corresponding imaginary component appearing in the
// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
// is encoded as [1.0, 2.0 ,3.0 ,4.0]
// When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
repeated float float_data = 4 [packed = true];
// For int32, uint8, int8, uint16, int16, bool, float8, and float16 values
// float16 and float8 values must be bit-wise converted to an uint16_t prior
// to writing to the buffer.
// When this field is present, the data_type field MUST be
// INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ
repeated int32 int32_data = 5 [packed = true];
// For strings.
// Each element of string_data is a UTF-8 encoded Unicode
// string. No trailing null, no leading BOM. The protobuf "string"
// scalar type is not used to match ML community conventions.
// When this field is present, the data_type field MUST be STRING
repeated bytes string_data = 6;
// For int64.
// When this field is present, the data_type field MUST be INT64
repeated int64 int64_data = 7 [packed = true];
// Optionally, a name for the tensor.
string name = 8; // namespace Value
// A human-readable documentation for this tensor. Markdown is allowed.
string doc_string = 12;
// Serializations can either use one of the fields above, or use this
// raw bytes field. The only exception is the string case, where one is
// required to store the content in the repeated bytes string_data field.
//
// When this raw_data field is used to store tensor value, elements MUST
// be stored in as fixed-width, little-endian order.
// Floating-point data types MUST be stored in IEEE 754 format.
// Complex64 elements must be written as two consecutive FLOAT values, real component first.
// Complex128 elements must be written as two consecutive DOUBLE values, real component first.
// Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
//
// Note: the advantage of specific field rather than the raw_data field is
// that in some cases (e.g. int data), protobuf does a better packing via
// variable length storage, and may lead to smaller binary footprint.
// When this field is present, the data_type field MUST NOT be STRING or UNDEFINED
bytes raw_data = 9;
// Data can be stored inside the protobuf file using type-specific fields or raw_data.
// Alternatively, raw bytes data can be stored in an external file, using the external_data field.
// external_data stores key-value pairs describing data location. Recognized keys are:
// - "location" (required) - POSIX filesystem path relative to the directory where the ONNX
// protobuf model was stored
// - "offset" (optional) - position of byte at which stored data begins. Integer stored as string.
// Offset values SHOULD be multiples 4096 (page size) to enable mmap support.
// - "length" (optional) - number of bytes containing data. Integer stored as string.
// - "checksum" (optional) - SHA1 digest of file specified in under 'location' key.
repeated StringStringEntryProto external_data = 13;
// Location of the data for this tensor. MUST be one of:
// - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field.
// - EXTERNAL - data stored in an external location as described by external_data field.
enum DataLocation {
DEFAULT = 0;
EXTERNAL = 1;
}
// If value not set, data is stored in raw_data (if set) otherwise in type-specified field.
DataLocation data_location = 14;
// For double
// Complex128 tensors are encoded as a single array of doubles,
// with the real components appearing in odd numbered positions,
// and the corresponding imaginary component appearing in the
// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
// is encoded as [1.0, 2.0 ,3.0 ,4.0]
// When this field is present, the data_type field MUST be DOUBLE or COMPLEX128
repeated double double_data = 10 [packed = true];
// For uint64 and uint32 values
// When this field is present, the data_type field MUST be
// UINT32 or UINT64
repeated uint64 uint64_data = 11 [packed = true];
}
// A serialized sparse-tensor value
message SparseTensorProto {
// The sequence of non-default values are encoded as a tensor of shape [NNZ].
// The default-value is zero for numeric tensors, and empty-string for string tensors.
// values must have a non-empty name present which serves as a name for SparseTensorProto
// when used in sparse_initializer list.
TensorProto values = 1;
// The indices of the non-default values, which may be stored in one of two formats.
// (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value
// corresponding to the j-th index of the i-th value (in the values tensor).
// (b) Indices can be a tensor of shape [NNZ], in which case the i-th value
// must be the linearized-index of the i-th value (in the values tensor).
// The linearized-index can be converted into an index tuple (k_1,...,k_rank)
// using the shape provided below.
// The indices must appear in ascending order without duplication.
// In the first format, the ordering is lexicographic-ordering:
// e.g., index-value [1,4] must appear before [2,1]
TensorProto indices = 2;
// The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank]
repeated int64 dims = 3;
}
// Defines a tensor shape. A dimension can be either an integer value
// or a symbolic variable. A symbolic variable represents an unknown
// dimension.
message TensorShapeProto {
message Dimension {
oneof value {
int64 dim_value = 1;
string dim_param = 2; // namespace Shape
};
// Standard denotation can optionally be used to denote tensor
// dimensions with standard semantic descriptions to ensure
// that operations are applied to the correct axis of a tensor.
// Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition
// for pre-defined dimension denotations.
string denotation = 3;
};
repeated Dimension dim = 1;
}
// Types
//
// The standard ONNX data types.
message TypeProto {
message Tensor {
// This field MUST NOT have the value of UNDEFINED
// This field MUST have a valid TensorProto.DataType value
// This field MUST be present for this version of the IR.
int32 elem_type = 1;
TensorShapeProto shape = 2;
}
// repeated T
message Sequence {
// The type and optional shape of each element of the sequence.
// This field MUST be present for this version of the IR.
TypeProto elem_type = 1;
};
// map<K,V>
message Map {
// This field MUST have a valid TensorProto.DataType value
// This field MUST be present for this version of the IR.
// This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING
int32 key_type = 1;
// This field MUST be present for this version of the IR.
TypeProto value_type = 2;
};
// wrapper for Tensor, Sequence, or Map
message Optional {
// The type and optional shape of the element wrapped.
// This field MUST be present for this version of the IR.
// Possible values correspond to OptionalProto.DataType enum
TypeProto elem_type = 1;
};
message SparseTensor {
// This field MUST NOT have the value of UNDEFINED
// This field MUST have a valid TensorProto.DataType value
// This field MUST be present for this version of the IR.
int32 elem_type = 1;
TensorShapeProto shape = 2;
}
oneof value {
// The type of a tensor.
Tensor tensor_type = 1;
// NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values
// as input and output to graphs and nodes. These types are needed to naturally
// support classical ML operators. DNN operators SHOULD restrict their input
// and output types to tensors.
// The type of a sequence.
Sequence sequence_type = 4;
// The type of a map.
Map map_type = 5;
// The type of an optional.
Optional optional_type = 9;
// Type of the sparse tensor
SparseTensor sparse_tensor_type = 8;
}
// An optional denotation can be used to denote the whole
// type with a standard semantic description as to what is
// stored inside. Refer to https://github.com/onnx/onnx/blob/main/docs/TypeDenotation.md#type-denotation-definition
// for pre-defined type denotations.
string denotation = 6;
}
// Operator Sets
//
// OperatorSets are uniquely identified by a (domain, opset_version) pair.
message OperatorSetIdProto {
// The domain of the operator set being identified.
// The empty string ("") or absence of this field implies the operator
// set that is defined as part of the ONNX specification.
// This field MUST be present in this version of the IR when referring to any other operator set.
string domain = 1;
// The version of the operator set being identified.
// This field MUST be present in this version of the IR.
int64 version = 2;
}
// Operator/function status.
enum OperatorStatus {
EXPERIMENTAL = 0;
STABLE = 1;
}
message FunctionProto {
// The name of the function, similar usage of op_type in OperatorProto.
// Combined with FunctionProto.domain, this forms the unique identity of
// the FunctionProto.
string name = 1;
// Deprecated since IR Version 8
// optional int64 since_version = 2;
reserved 2;
reserved "since_version";
// Deprecated since IR Version 8
// optional OperatorStatus status = 3;
reserved 3;
reserved "status";
// The inputs and outputs of the function.
repeated string input = 4;
repeated string output = 5;
// The attribute parameters of the function.
// It is for function parameters without default values.
repeated string attribute = 6;
// The attribute protos of the function.
// It is for function attributes with default values.
// A function attribute shall be represented either as
// a string attribute or an AttributeProto, not both.
repeated AttributeProto attribute_proto = 11;
// The nodes in the function.
repeated NodeProto node = 7;
// A human-readable documentation for this function. Markdown is allowed.
string doc_string = 8;
// The OperatorSets this function body (graph) relies on.
//
// All nodes in the function body (graph) will bind against the operator
// with the same-domain/same-op_type operator with the HIGHEST version
// in the referenced operator sets. This means at most one version can be relied
// for one domain.
//
// The operator sets imported by FunctionProto should be compatible with the ones
// imported by ModelProto. Example, if same operator set say 'A' is imported by FunctionProto
// and ModelProto then versions for the operator set may be different but,
// the operator schema returned for op_type, domain, version combination
// for both the versions should be same.
repeated OperatorSetIdProto opset_import = 9;
// The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of
// the FunctionProto.
string domain = 10;
}
// For using protobuf-lite
option optimize_for = LITE_RUNTIME;

View File

@ -17,7 +17,6 @@ crate-type = ["cdylib"]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
candle-onnx = {path= "../candle-onnx", version = "0.3.0", optional = true}
half = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] }
@ -30,5 +29,3 @@ default = []
accelerate = ["dep:accelerate-src", "candle/accelerate"]
cuda = ["candle/cuda"]
mkl = ["dep:intel-mkl-src","candle/mkl"]
onnx = ["dep:candle-onnx"]

View File

@ -1,5 +0,0 @@
# Generated content DO NOT EDIT
from .. import onnx
ONNXModel = onnx.ONNXModel
ONNXTensorDescription = onnx.ONNXTensorDescription

View File

@ -1,89 +0,0 @@
# Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
from candle.typing import _ArrayLike, Device, Scalar, Index, Shape
from candle import Tensor, DType, QTensor
class ONNXModel:
"""
A wrapper around an ONNX model.
"""
def __init__(self, path: str):
pass
@property
def doc_string(self) -> str:
"""
The doc string of the model.
"""
pass
@property
def domain(self) -> str:
"""
The domain of the operator set of the model.
"""
pass
def initializers(self) -> Dict[str, Tensor]:
"""
Get the weights of the model.
"""
pass
@property
def inputs(self) -> Optional[Dict[str, ONNXTensorDescription]]:
"""
The inputs of the model.
"""
pass
@property
def ir_version(self) -> int:
"""
The version of the IR this model targets.
"""
pass
@property
def model_version(self) -> int:
"""
The version of the model.
"""
pass
@property
def outputs(self) -> Optional[Dict[str, ONNXTensorDescription]]:
"""
The outputs of the model.
"""
pass
@property
def producer_name(self) -> str:
"""
The producer of the model.
"""
pass
@property
def producer_version(self) -> str:
"""
The version of the producer of the model.
"""
pass
def run(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""
Run the model on the given inputs.
"""
pass
class ONNXTensorDescription:
"""
A wrapper around an ONNX tensor description.
"""
@property
def dtype(self) -> DType:
"""
The data type of the tensor.
"""
pass
@property
def shape(self) -> Tuple[Union[int, str, Any]]:
"""
The shape of the tensor.
"""
pass

View File

@ -19,14 +19,12 @@ extern crate accelerate_src;
use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};
mod utils;
use utils::wrap_err;
mod shape;
use shape::{PyShape, PyShapeWithHole};
#[cfg(feature = "onnx")]
mod onnx;
pub fn wrap_err(err: ::candle::Error) -> PyErr {
PyErr::new::<PyValueError, _>(format!("{err:?}"))
}
#[derive(Clone, Debug)]
#[pyclass(name = "Tensor")]
@ -71,13 +69,11 @@ impl PyDType {
}
static CUDA_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);
static METAL_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum PyDevice {
Cpu,
Cuda,
Metal,
}
impl PyDevice {
@ -85,7 +81,7 @@ impl PyDevice {
match device {
Device::Cpu => Self::Cpu,
Device::Cuda(_) => Self::Cuda,
Device::Metal(_) => Self::Metal,
Device::Metal(_) => unimplemented!(),
}
}
@ -101,15 +97,6 @@ impl PyDevice {
*device = Some(d.clone());
Ok(d)
}
Self::Metal => {
let mut device = METAL_DEVICE.lock().unwrap();
if let Some(device) = device.as_ref() {
return Ok(device.clone());
};
let d = Device::new_metal(0).map_err(wrap_err)?;
*device = Some(d.clone());
Ok(d)
}
}
}
}
@ -131,7 +118,6 @@ impl ToPyObject for PyDevice {
let str = match self {
PyDevice::Cpu => "cpu",
PyDevice::Cuda => "cuda",
PyDevice::Metal => "metal",
};
str.to_object(py)
}
@ -1574,14 +1560,6 @@ fn candle_functional_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
Ok(())
}
#[cfg(feature = "onnx")]
fn candle_onnx_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
use onnx::{PyONNXModel, PyONNXTensorDescriptor};
m.add_class::<PyONNXModel>()?;
m.add_class::<PyONNXTensorDescriptor>()?;
Ok(())
}
#[pymodule]
fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
let utils = PyModule::new(py, "utils")?;
@ -1590,12 +1568,6 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
let nn = PyModule::new(py, "functional")?;
candle_functional_m(py, nn)?;
m.add_submodule(nn)?;
#[cfg(feature = "onnx")]
{
let onnx = PyModule::new(py, "onnx")?;
candle_onnx_m(py, onnx)?;
m.add_submodule(onnx)?;
}
m.add_class::<PyTensor>()?;
m.add_class::<PyQTensor>()?;
m.add_class::<PyDType>()?;

View File

@ -1,212 +0,0 @@
use std::collections::HashMap;
use crate::utils::wrap_err;
use crate::{PyDType, PyTensor};
use candle_onnx::eval::{dtype, get_tensor, simple_eval};
use candle_onnx::onnx::tensor_proto::DataType;
use candle_onnx::onnx::tensor_shape_proto::dimension::Value;
use candle_onnx::onnx::type_proto::{Tensor as ONNXTensor, Value as ONNXValue};
use candle_onnx::onnx::{ModelProto, ValueInfoProto};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::{PyList, PyTuple};
#[derive(Clone, Debug)]
#[pyclass(name = "ONNXTensorDescription")]
/// A wrapper around an ONNX tensor description.
pub struct PyONNXTensorDescriptor(ONNXTensor);
#[pymethods]
impl PyONNXTensorDescriptor {
#[getter]
/// The data type of the tensor.
/// &RETURNS&: DType
fn dtype(&self) -> PyResult<PyDType> {
match DataType::try_from(self.0.elem_type) {
Ok(dt) => match dtype(dt) {
Some(dt) => Ok(PyDType(dt)),
None => Err(PyValueError::new_err(format!(
"unsupported 'value' data-type {dt:?}"
))),
},
type_ => Err(PyValueError::new_err(format!(
"unsupported input type {type_:?}"
))),
}
}
#[getter]
/// The shape of the tensor.
/// &RETURNS&: Tuple[Union[int,str,Any]]
fn shape(&self, py: Python) -> PyResult<Py<PyTuple>> {
let shape = PyList::empty(py);
if let Some(d) = &self.0.shape {
for dim in d.dim.iter() {
if let Some(value) = &dim.value {
match value {
Value::DimValue(v) => shape.append(*v)?,
Value::DimParam(s) => shape.append(s.clone())?,
};
} else {
return Err(PyValueError::new_err("None value in shape"));
}
}
}
Ok(shape.to_tuple().into())
}
fn __repr__(&self, py: Python) -> String {
match (self.shape(py), self.dtype()) {
(Ok(shape), Ok(dtype)) => format!(
"TensorDescriptor[shape: {:?}, dtype: {:?}]",
shape.to_string(),
dtype.__str__()
),
(Err(_), Err(_)) => "TensorDescriptor[shape: unknown, dtype: unknown]".to_string(),
(Err(_), Ok(dtype)) => format!(
"TensorDescriptor[shape: unknown, dtype: {:?}]",
dtype.__str__()
),
(Ok(shape), Err(_)) => format!(
"TensorDescriptor[shape: {:?}, dtype: unknown]",
shape.to_string()
),
}
}
fn __str__(&self, py: Python) -> String {
self.__repr__(py)
}
}
#[derive(Clone, Debug)]
#[pyclass(name = "ONNXModel")]
/// A wrapper around an ONNX model.
pub struct PyONNXModel(ModelProto);
fn extract_tensor_descriptions(
value_infos: &[ValueInfoProto],
) -> HashMap<String, PyONNXTensorDescriptor> {
let mut map = HashMap::new();
for value_info in value_infos.iter() {
let input_type = match &value_info.r#type {
Some(input_type) => input_type,
None => continue,
};
let input_type = match &input_type.value {
Some(input_type) => input_type,
None => continue,
};
let tensor_type: &ONNXTensor = match input_type {
ONNXValue::TensorType(tt) => tt,
_ => continue,
};
map.insert(
value_info.name.to_string(),
PyONNXTensorDescriptor(tensor_type.clone()),
);
}
map
}
#[pymethods]
impl PyONNXModel {
#[new]
#[pyo3(text_signature = "(self, path:str)")]
/// Load an ONNX model from the given path.
fn new(path: String) -> PyResult<Self> {
let model: ModelProto = candle_onnx::read_file(path).map_err(wrap_err)?;
Ok(PyONNXModel(model))
}
#[getter]
/// The version of the IR this model targets.
/// &RETURNS&: int
fn ir_version(&self) -> i64 {
self.0.ir_version
}
#[getter]
/// The producer of the model.
/// &RETURNS&: str
fn producer_name(&self) -> String {
self.0.producer_name.clone()
}
#[getter]
/// The version of the producer of the model.
/// &RETURNS&: str
fn producer_version(&self) -> String {
self.0.producer_version.clone()
}
#[getter]
/// The domain of the operator set of the model.
/// &RETURNS&: str
fn domain(&self) -> String {
self.0.domain.clone()
}
#[getter]
/// The version of the model.
/// &RETURNS&: int
fn model_version(&self) -> i64 {
self.0.model_version
}
#[getter]
/// The doc string of the model.
/// &RETURNS&: str
fn doc_string(&self) -> String {
self.0.doc_string.clone()
}
/// Get the weights of the model.
/// &RETURNS&: Dict[str, Tensor]
fn initializers(&self) -> PyResult<HashMap<String, PyTensor>> {
let mut map = HashMap::new();
if let Some(graph) = self.0.graph.as_ref() {
for tensor_description in graph.initializer.iter() {
let tensor = get_tensor(tensor_description, tensor_description.name.as_str())
.map_err(wrap_err)?;
map.insert(tensor_description.name.to_string(), PyTensor(tensor));
}
}
Ok(map)
}
#[getter]
/// The inputs of the model.
/// &RETURNS&: Optional[Dict[str, ONNXTensorDescription]]
fn inputs(&self) -> Option<HashMap<String, PyONNXTensorDescriptor>> {
if let Some(graph) = self.0.graph.as_ref() {
return Some(extract_tensor_descriptions(&graph.input));
}
None
}
#[getter]
/// The outputs of the model.
/// &RETURNS&: Optional[Dict[str, ONNXTensorDescription]]
fn outputs(&self) -> Option<HashMap<String, PyONNXTensorDescriptor>> {
if let Some(graph) = self.0.graph.as_ref() {
return Some(extract_tensor_descriptions(&graph.output));
}
None
}
#[pyo3(text_signature = "(self, inputs:Dict[str,Tensor])")]
/// Run the model on the given inputs.
/// &RETURNS&: Dict[str,Tensor]
fn run(&self, inputs: HashMap<String, PyTensor>) -> PyResult<HashMap<String, PyTensor>> {
let unwrapped_tensors = inputs.into_iter().map(|(k, v)| (k.clone(), v.0)).collect();
let result = simple_eval(&self.0, unwrapped_tensors).map_err(wrap_err)?;
Ok(result
.into_iter()
.map(|(k, v)| (k.clone(), PyTensor(v)))
.collect())
}
}

View File

@ -1,6 +0,0 @@
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
pub fn wrap_err(err: ::candle::Error) -> PyErr {
PyErr::new::<PyValueError, _>(format!("{err:?}"))
}

View File

@ -21,7 +21,6 @@ rand = { workspace = true }
rayon = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
serde_plain = { workspace = true }
tracing = { workspace = true }
wav = { workspace = true }
@ -29,5 +28,6 @@ wav = { workspace = true }
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda"]
metal = ["candle/metal", "candle-nn/metal"]
flash-attn = ["cuda", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"]

View File

@ -1,4 +1,3 @@
use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
use candle::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module, VarBuilder};
use serde::Deserialize;
@ -33,6 +32,76 @@ impl HiddenActLayer {
}
}
#[derive(Debug)]
pub struct Linear {
weight: Tensor,
bias: Option<Tensor>,
span: tracing::Span,
}
impl Linear {
pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
let span = tracing::span!(tracing::Level::TRACE, "linear");
Self { weight, bias, span }
}
}
impl Module for Linear {
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
let _enter = self.span.enter();
let w = match x.dims() {
&[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
_ => self.weight.t()?,
};
let x = x.matmul(&w)?;
match &self.bias {
None => Ok(x),
Some(bias) => x.broadcast_add(bias),
}
}
}
#[derive(Debug)]
pub struct LayerNorm {
weight: Tensor,
bias: Tensor,
eps: f64,
span: tracing::Span,
}
impl LayerNorm {
pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
Self {
weight,
bias,
eps,
span,
}
}
}
impl Module for LayerNorm {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let x_dtype = x.dtype();
let internal_dtype = match x_dtype {
DType::F16 | DType::BF16 => DType::F32,
d => d,
};
let (_bsize, _seq_len, hidden_size) = x.dims3()?;
let x = x.to_dtype(internal_dtype)?;
let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
let x = x.broadcast_sub(&mean_x)?;
let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
let x = x_normed
.to_dtype(x_dtype)?
.broadcast_mul(&self.weight)?
.broadcast_add(&self.bias)?;
Ok(x)
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
enum PositionEmbeddingType {
@ -115,6 +184,12 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Em
Ok(Embedding::new(embeddings, hidden_size))
}
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?;
let bias = vb.get(size2, "bias")?;
Ok(Linear::new(weight, Some(bias)))
}
struct Dropout {
#[allow(dead_code)]
pr: f64,
@ -133,6 +208,20 @@ impl Module for Dropout {
}
}
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
(Ok(weight), Ok(bias)) => (weight, bias),
(Err(err), _) | (_, Err(err)) => {
if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
(weight, bias)
} else {
return Err(err);
}
}
};
Ok(LayerNorm::new(weight, bias, eps))
}
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180
struct BertEmbeddings {
word_embeddings: Embedding,

View File

@ -1,4 +1,3 @@
use super::with_tracing::{linear_no_bias as linear, Linear};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, Module, VarBuilder};
use serde::Deserialize;
@ -82,6 +81,21 @@ impl Config {
}
}
// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
// model.
#[derive(Debug)]
pub struct Linear {
inner: candle_nn::Linear,
span: tracing::Span,
}
impl Linear {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
}
}
#[derive(Clone)]
pub struct Cache {
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
@ -136,6 +150,12 @@ impl Cache {
}
}
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let span = tracing::span!(tracing::Level::TRACE, "linear");
let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
Ok(Linear { inner, span })
}
fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
Ok(Embedding::new(embeddings, cfg.hidden_size))

View File

@ -1,5 +1,6 @@
use super::with_tracing::{linear, Embedding, Linear};
use candle::{Result, Tensor};
#![allow(unused)]
use super::with_tracing::{linear, linear_no_bias, Embedding, Linear};
use candle::{Module, Result, Tensor};
use candle_nn::{layer_norm, LayerNorm, VarBuilder};
#[derive(Debug, Clone)]
@ -169,6 +170,7 @@ impl Attention {
kv_states: Option<&Tensor>,
attn_mask: Option<&Tensor>,
) -> Result<Tensor> {
let is_cross_attn = kv_states.is_some();
let (b_sz, tgt_len, _) = xs.dims3()?;
let query_states = (xs.apply(&self.q_proj)? * self.scaling)?;
let (key_states, value_states) = match kv_states {
@ -257,10 +259,6 @@ impl EncoderLayer {
.apply(&self.fc2)?;
(xs + residual)?.apply(&self.final_layer_norm)
}
fn reset_kv_cache(&mut self) {
self.self_attn.reset_kv_cache()
}
}
#[derive(Debug, Clone)]
@ -322,11 +320,6 @@ impl DecoderLayer {
let xs = (xs + residual)?.apply(&self.final_layer_norm)?;
Ok(xs)
}
fn reset_kv_cache(&mut self) {
self.self_attn.reset_kv_cache();
self.encoder_attn.reset_kv_cache()
}
}
#[derive(Debug, Clone)]
@ -375,12 +368,6 @@ impl Encoder {
}
Ok(xs)
}
pub fn reset_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {
layer.reset_kv_cache()
}
}
}
#[derive(Debug, Clone)]
@ -435,12 +422,6 @@ impl Decoder {
}
Ok(xs)
}
pub fn reset_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {
layer.reset_kv_cache()
}
}
}
#[derive(Debug, Clone)]
@ -461,11 +442,6 @@ impl Model {
decoder,
})
}
fn reset_kv_cache(&mut self) {
self.encoder.reset_kv_cache();
self.decoder.reset_kv_cache();
}
}
#[derive(Debug, Clone)]
@ -513,8 +489,4 @@ impl MTModel {
.apply(&self.lm_head)?
.broadcast_add(&self.final_logits_bias)
}
pub fn reset_kv_cache(&mut self) {
self.model.reset_kv_cache();
}
}

View File

@ -29,10 +29,8 @@ pub mod segment_anything;
pub mod stable_diffusion;
pub mod stable_lm;
pub mod t5;
pub mod trocr;
pub mod vgg;
pub mod vit;
pub mod whisper;
pub mod with_tracing;
pub mod wuerstchen;
pub mod yi;

View File

@ -16,7 +16,7 @@ struct RmsNorm {
impl RmsNorm {
fn new(scale: QTensor, eps: f32) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
let scale = scale.dequantize(&Device::Cpu)?;
let scale = scale.dequantize(scale.device())?;
let inner = candle_nn::LayerNorm::rms_norm(scale, eps as f64);
Ok(Self { inner, span })
}
@ -112,6 +112,7 @@ impl LayerWeights {
let q = self.attention_wq.forward(x)?;
let k = self.attention_wk.forward(x)?;
let v = self.attention_wv.forward(x)?;
// println!("Q {:?} K {:?} V {:?}", q.dtype(), k.dtype(), v.dtype());
let q = q
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
@ -145,9 +146,12 @@ impl LayerWeights {
let v = self.repeat_kv(v)?;
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
// println!("att {:?}", att.dtype());
let mask = mask.broadcast_as(att.shape())?;
// println!("mask {:?}", mask.dtype());
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
let att = candle_nn::ops::softmax_last_dim(&att)?;
// println!("att {:?} v {:?}", att.dtype(), v.dtype());
// Convert to contiguous as matmul doesn't support strided vs for now.
let y = att.matmul(&v.contiguous()?)?;
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
@ -181,13 +185,23 @@ pub struct ModelWeights {
span_output: tracing::Span,
}
fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tensor)> {
fn precomput_freqs_cis(
head_dim: usize,
freq_base: f32,
device: &Device,
) -> Result<(Tensor, Tensor)> {
let theta: Vec<_> = (0..head_dim)
.step_by(2)
.map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
.collect();
let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?;
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)?
let theta = Tensor::new(theta.as_slice(), device)?;
let range: Vec<f32> = (0..MAX_SEQ_LEN).map(|r| r as f32).collect();
// let idx_theta = Tensor::new(range.as_slice(), device)?
// .reshape((MAX_SEQ_LEN, 1))?
// .matmul(&theta.reshape((1, theta.elem_count()))?)?;
// TODO This change avoids allocating on Metal and then casting since allocating directly on
// CPU as f32 seems just as fast
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
.to_dtype(DType::F32)?
.reshape((MAX_SEQ_LEN, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
@ -197,12 +211,11 @@ fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tenso
}
impl ModelWeights {
pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
let cpu = &Device::Cpu;
pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize, device: &Device) -> Result<Self> {
let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
let (cos, sin) = precomput_freqs_cis(head_dim, 10000.)?;
let (cos, sin) = precomput_freqs_cis(head_dim, 10000., device)?;
let tok_embeddings = ct.remove("tok_embeddings.weight")?;
let tok_embeddings = tok_embeddings.dequantize(cpu)?;
let tok_embeddings = tok_embeddings.dequantize(device)?;
let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
let output = ct.remove("output.weight")?;
let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
@ -257,8 +270,8 @@ impl ModelWeights {
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
ct: gguf_file::Content,
reader: &mut R,
device: &Device,
) -> Result<Self> {
let cpu = &Device::Cpu;
let md_get = |s: &str| match ct.metadata.get(s) {
None => candle::bail!("cannot find {s} in metadata"),
Some(v) => Ok(v),
@ -276,24 +289,31 @@ impl ModelWeights {
let rope_freq_base = md_get("llama.rope.freq_base")
.and_then(|m| m.to_f32())
.unwrap_or(10000f32);
let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?;
let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?;
let tok_embeddings = ct.tensor(reader, "token_embd.weight")?;
let tok_embeddings = tok_embeddings.dequantize(cpu)?;
let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?, rms_norm_eps)?;
let output = ct.tensor(reader, "output.weight")?;
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
let tok_embeddings = tok_embeddings.dequantize(device)?;
let norm = RmsNorm::new(
ct.tensor(reader, "output_norm.weight", device)?,
rms_norm_eps,
)?;
let output = ct.tensor(reader, "output.weight", device)?;
let mut layers = Vec::with_capacity(block_count);
for layer_idx in 0..block_count {
let prefix = format!("blk.{layer_idx}");
let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"))?;
let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?;
let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?;
let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?;
let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?;
let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?;
let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?;
let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?;
let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?;
let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
let attention_wo =
ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
let feed_forward_w1 =
ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
let feed_forward_w2 =
ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
let attention_norm =
ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?;
let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?;
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
@ -331,14 +351,14 @@ impl ModelWeights {
})
}
fn mask(&mut self, t: usize) -> Result<Tensor> {
fn mask(&mut self, t: usize, device: &Device) -> Result<Tensor> {
if let Some(mask) = self.masks.get(&t) {
Ok(mask.clone())
} else {
let mask: Vec<_> = (0..t)
.flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
.collect();
let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
let mask = Tensor::from_slice(&mask, (t, t), device)?;
self.masks.insert(t, mask.clone());
Ok(mask)
}
@ -346,7 +366,7 @@ impl ModelWeights {
pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (_b_sz, seq_len) = x.dims2()?;
let mask = self.mask(seq_len)?;
let mask = self.mask(seq_len, x.device())?;
let _enter = self.span.enter();
let mut layer_in = self.tok_embeddings.forward(x)?;
for layer in self.layers.iter_mut() {

View File

@ -1,7 +1,6 @@
// T5 Text Model, quantized version
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
use crate::models::t5::{deserialize_feed_forward_proj_activation, ActivationWithOptionalGating};
use crate::models::with_tracing::QMatMul;
use crate::quantized_nn::Embedding;
pub use crate::quantized_var_builder::VarBuilder;
@ -55,8 +54,8 @@ pub struct Config {
dropout_rate: f64,
layer_norm_epsilon: f64,
initializer_factor: f64,
#[serde(default, deserialize_with = "deserialize_feed_forward_proj_activation")]
pub feed_forward_proj: ActivationWithOptionalGating,
#[serde(default)]
feed_forward_proj: Activation,
#[serde(default = "default_tie_word_embeddings")]
tie_word_embeddings: bool,
#[serde(default = "default_is_decoder")]
@ -66,7 +65,6 @@ pub struct Config {
pub use_cache: bool,
pub pad_token_id: usize,
pub eos_token_id: usize,
pub decoder_start_token_id: Option<usize>,
}
impl Default for Config {
@ -84,17 +82,13 @@ impl Default for Config {
dropout_rate: 0.1,
layer_norm_epsilon: 1e-6,
initializer_factor: 1.0,
feed_forward_proj: ActivationWithOptionalGating {
gated: false,
activation: Activation::Relu,
},
feed_forward_proj: Activation::Relu,
tie_word_embeddings: true,
is_decoder: false,
is_encoder_decoder: true,
use_cache: true,
pad_token_id: 0,
eos_token_id: 1,
decoder_start_token_id: Some(0),
}
}
}
@ -180,7 +174,7 @@ impl T5DenseGatedActDense {
wi_0,
wi_1,
wo,
act: cfg.feed_forward_proj.activation,
act: Activation::NewGelu,
span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"),
})
}
@ -209,7 +203,7 @@ impl T5LayerFF {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let layer_norm =
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
let (dense_act, gated_dense_act) = if cfg.feed_forward_proj.gated {
let (dense_act, gated_dense_act) = if cfg.feed_forward_proj == Activation::NewGelu {
(
None,
Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?),
@ -648,12 +642,7 @@ pub struct T5EncoderModel {
impl T5EncoderModel {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let shared_vb = if vb.contains_key("shared.weight") {
vb.pp("shared")
} else {
vb.pp("decoder").pp("embed_tokens")
};
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
let shared = Arc::new(shared);
let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
Ok(Self {
@ -694,12 +683,7 @@ impl T5ForConditionalGeneration {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
assert!(cfg.is_encoder_decoder);
let d_model = cfg.d_model;
let shared_vb = if vb.contains_key("shared.weight") {
vb.pp("shared")
} else {
vb.pp("decoder").pp("embed_tokens")
};
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
let shared = Arc::new(shared);
let mut encoder_cfg = cfg.clone();

View File

@ -1,4 +1,3 @@
pub use crate::models::with_tracing::Linear;
use candle::{Result, Tensor};
use candle_nn::{Module, VarBuilder};
@ -10,11 +9,13 @@ pub mod tiny_vit;
pub mod transformer;
pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
if bias {
crate::models::with_tracing::linear(in_dim, out_dim, vb)
let inner = if bias {
candle_nn::linear(in_dim, out_dim, vb)?
} else {
crate::models::with_tracing::linear_no_bias(in_dim, out_dim, vb)
}
candle_nn::linear_no_bias(in_dim, out_dim, vb)?
};
let span = tracing::span!(tracing::Level::TRACE, "linear");
Ok(Linear { inner, span })
}
#[derive(Debug)]
@ -84,3 +85,16 @@ impl Module for MlpBlock {
.apply(&self.lin2)
}
}
#[derive(Debug)]
pub struct Linear {
inner: candle_nn::Linear,
span: tracing::Span,
}
impl Module for Linear {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
}
}

View File

@ -102,14 +102,6 @@ impl Config {
}
}
pub fn ssd1b() -> Self {
Self::sdxl()
}
pub fn ssd1b2() -> Self {
Self::sdxl2()
}
// https://huggingface.co/warp-ai/wuerstchen/blob/main/text_encoder/config.json
pub fn wuerstchen() -> Self {
Self {

View File

@ -249,71 +249,6 @@ impl StableDiffusionConfig {
)
}
pub fn ssd1b(
sliced_attention_size: Option<usize>,
height: Option<usize>,
width: Option<usize>,
) -> Self {
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
out_channels,
use_cross_attn,
attention_head_dim,
};
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/unet/config.json
let unet = unet_2d::UNet2DConditionModelConfig {
blocks: vec![
bc(320, None, 5),
bc(640, Some(2), 10),
bc(1280, Some(10), 20),
],
center_input_sample: false,
cross_attention_dim: 2048,
downsample_padding: 1,
flip_sin_to_cos: true,
freq_shift: 0.,
layers_per_block: 2,
mid_block_scale_factor: 1.,
norm_eps: 1e-5,
norm_num_groups: 32,
sliced_attention_size,
use_linear_projection: true,
};
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/vae/config.json
let autoencoder = vae::AutoEncoderKLConfig {
block_out_channels: vec![128, 256, 512, 512],
layers_per_block: 2,
latent_channels: 4,
norm_num_groups: 32,
};
let scheduler = ddim::DDIMSchedulerConfig {
..Default::default()
};
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "height has to be divisible by 8");
height
} else {
1024
};
let width = if let Some(width) = width {
assert_eq!(width % 8, 0, "width has to be divisible by 8");
width
} else {
1024
};
Self {
width,
height,
clip: clip::Config::ssd1b(),
clip2: Some(clip::Config::ssd1b2()),
autoencoder,
scheduler,
unet,
}
}
pub fn build_vae<P: AsRef<std::path::Path>>(
&self,
vae_weights: P,

View File

@ -37,37 +37,6 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
Ok(m)
}
#[derive(Debug, Deserialize, Default, Clone, PartialEq)]
pub struct ActivationWithOptionalGating {
pub gated: bool,
pub activation: candle_nn::Activation,
}
pub fn deserialize_feed_forward_proj_activation<'de, D>(
deserializer: D,
) -> std::result::Result<ActivationWithOptionalGating, D::Error>
where
D: serde::de::Deserializer<'de>,
{
match String::deserialize(deserializer)?.as_str() {
"gated-gelu" => Ok(ActivationWithOptionalGating {
gated: true,
activation: candle_nn::Activation::NewGelu,
}),
"gated-silu" => Ok(ActivationWithOptionalGating {
gated: true,
activation: candle_nn::Activation::Silu,
}),
buf => {
let activation = serde_plain::from_str(buf).map_err(serde::de::Error::custom)?;
Ok(ActivationWithOptionalGating {
gated: false,
activation,
})
}
}
}
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
vocab_size: usize,
@ -83,8 +52,8 @@ pub struct Config {
dropout_rate: f64,
layer_norm_epsilon: f64,
initializer_factor: f64,
#[serde(default, deserialize_with = "deserialize_feed_forward_proj_activation")]
feed_forward_proj: ActivationWithOptionalGating,
#[serde(default)]
feed_forward_proj: Activation,
#[serde(default = "default_tie_word_embeddings")]
tie_word_embeddings: bool,
#[serde(default = "default_is_decoder")]
@ -94,7 +63,6 @@ pub struct Config {
pub use_cache: bool,
pub pad_token_id: usize,
pub eos_token_id: usize,
pub decoder_start_token_id: Option<usize>,
}
impl Default for Config {
@ -112,17 +80,13 @@ impl Default for Config {
dropout_rate: 0.1,
layer_norm_epsilon: 1e-6,
initializer_factor: 1.0,
feed_forward_proj: ActivationWithOptionalGating {
gated: false,
activation: Activation::Relu,
},
feed_forward_proj: Activation::Relu,
tie_word_embeddings: true,
is_decoder: false,
is_encoder_decoder: true,
use_cache: true,
pad_token_id: 0,
eos_token_id: 1,
decoder_start_token_id: Some(0),
}
}
}
@ -136,10 +100,7 @@ impl Config {
d_model: 768,
dropout_rate: 0.1,
eos_token_id: 1,
feed_forward_proj: ActivationWithOptionalGating {
gated: false,
activation: Activation::Relu,
},
feed_forward_proj: Activation::Relu,
tie_word_embeddings: true,
initializer_factor: 1.0,
is_decoder: false,
@ -149,7 +110,6 @@ impl Config {
num_heads: 12,
num_layers: 12,
pad_token_id: 0,
decoder_start_token_id: Some(0),
relative_attention_max_distance: 128,
relative_attention_num_buckets: 32,
use_cache: true,
@ -239,7 +199,7 @@ impl T5DenseGatedActDense {
wi_0,
wi_1,
wo,
act: cfg.feed_forward_proj.activation,
act: Activation::NewGelu,
span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"),
})
}
@ -268,7 +228,7 @@ impl T5LayerFF {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let layer_norm =
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
let (dense_act, gated_dense_act) = if cfg.feed_forward_proj.gated {
let (dense_act, gated_dense_act) = if cfg.feed_forward_proj == Activation::NewGelu {
(
None,
Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?),
@ -462,7 +422,7 @@ impl T5Attention {
self.relative_attention_max_distance as f32
/ max_exact as f32,
) * (num_buckets - max_exact) as f32;
u32::min(max_exact + b as u32, num_buckets - 1)
max_exact + b as u32
}
})
.collect::<Vec<u32>>()
@ -707,12 +667,7 @@ pub struct T5EncoderModel {
impl T5EncoderModel {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let shared_vb = if vb.contains_tensor("shared.weight") {
vb.pp("shared")
} else {
vb.pp("decoder").pp("embed_tokens")
};
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
let shared = Arc::new(shared);
let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
Ok(Self {
@ -753,12 +708,7 @@ impl T5ForConditionalGeneration {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
assert!(cfg.is_encoder_decoder);
let d_model = cfg.d_model;
let shared_vb = if vb.contains_tensor("shared.weight") {
vb.pp("shared")
} else {
vb.pp("decoder").pp("embed_tokens")
};
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
let shared = Arc::new(shared);
let mut encoder_cfg = cfg.clone();

View File

@ -1,434 +0,0 @@
use crate::models::vit::{Config, Embeddings, Encoder};
use candle::{Result, Tensor};
use candle_nn::{
embedding, layer_norm, linear_no_bias, Embedding, LayerNorm, Linear, Module, VarBuilder,
};
use serde::Deserialize;
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct TrOCRConfig {
pub vocab_size: usize,
pub d_model: usize,
pub hidden_size: usize,
pub decoder_layers: usize,
pub decoder_attention_heads: usize,
pub decoder_ffn_dim: usize,
pub activation_function: candle_nn::Activation,
pub max_position_embeddings: usize,
pub dropout: f64,
pub attention_dropout: f64,
pub activation_dropout: f64,
pub decoder_start_token_id: u32,
pub init_std: f64,
pub decoder_layerdrop: f64,
pub use_cache: bool,
pub scale_embedding: bool,
pub use_learned_position_embeddings: bool,
pub layernorm_embedding: bool,
pub pad_token_id: usize,
pub bos_token_id: usize,
pub eos_token_id: u32,
pub num_attention_heads: usize,
pub decoder_vocab_size: Option<usize>,
}
impl Default for TrOCRConfig {
fn default() -> Self {
Self {
vocab_size: 50265,
d_model: 1024,
hidden_size: 768,
decoder_layers: 12,
decoder_attention_heads: 16,
decoder_ffn_dim: 4096,
activation_function: candle_nn::Activation::Gelu,
max_position_embeddings: 512,
dropout: 0.1,
attention_dropout: 0.0,
activation_dropout: 0.0,
decoder_start_token_id: 2,
init_std: 0.02,
decoder_layerdrop: 0.0,
use_cache: true,
scale_embedding: false,
use_learned_position_embeddings: true,
layernorm_embedding: true,
pad_token_id: 1,
bos_token_id: 0,
eos_token_id: 2,
num_attention_heads: 12,
decoder_vocab_size: Some(50265),
}
}
}
#[derive(Debug, Clone)]
struct TrOCRLearnedPositionalEmbedding {
offset: usize,
weights: Embedding,
}
impl TrOCRLearnedPositionalEmbedding {
fn load(vb: VarBuilder, cfg: &TrOCRConfig) -> Result<Self> {
let offset: usize = 2;
let num_embeddings = cfg.max_position_embeddings;
let embedding_dim = cfg.d_model;
let weights = embedding(num_embeddings + offset, embedding_dim, vb)?;
Ok(Self { offset, weights })
}
fn forward(&mut self, input_ids: &Tensor, past_key_values_length: u32) -> Result<Tensor> {
let (b_sz, seq_len) = input_ids.dims2()?;
let mut positions = Tensor::arange(
past_key_values_length,
seq_len as u32 + past_key_values_length,
input_ids.device(),
)?
.expand((b_sz, seq_len))?;
positions =
positions.broadcast_add(&Tensor::new(self.offset as u32, input_ids.device())?)?;
self.weights.forward(&positions)
}
}
#[derive(Debug, Clone)]
struct TrOCRAttention {
head_dim: usize,
num_heads: usize,
is_decoder: bool,
scaling: f64,
k_proj: Linear,
v_proj: Linear,
q_proj: Linear,
out_proj: Linear,
kv_cache: Option<(Tensor, Tensor)>,
}
impl TrOCRAttention {
fn load(
vb: VarBuilder,
cfg: &TrOCRConfig,
kdim: Option<usize>,
vdim: Option<usize>,
) -> Result<Self> {
let embed_dim = cfg.d_model;
let num_heads = cfg.decoder_attention_heads;
let head_dim = embed_dim / num_heads;
let kdim = kdim.unwrap_or(embed_dim);
let vdim = vdim.unwrap_or(embed_dim);
let k_proj = linear_no_bias(kdim, embed_dim, vb.pp("k_proj"))?;
let v_proj = linear_no_bias(vdim, embed_dim, vb.pp("v_proj"))?;
let q_proj = linear_no_bias(embed_dim, embed_dim, vb.pp("q_proj"))?;
let out_proj = linear_no_bias(embed_dim, embed_dim, vb.pp("out_proj"))?;
Ok(Self {
head_dim,
num_heads,
is_decoder: true,
scaling: 1. / (head_dim as f64).sqrt(),
k_proj,
v_proj,
q_proj,
out_proj,
kv_cache: None,
})
}
fn _shape(&self, tensor: &Tensor, bsz: usize) -> Result<Tensor> {
tensor
.reshape((bsz, (), self.num_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()
}
fn forward(
&mut self,
xs: &Tensor,
kv_states: Option<&Tensor>,
attn_mask: Option<&Tensor>,
) -> Result<Tensor> {
let (b_sz, tgt_len, _) = xs.dims3()?;
let query_states = (xs.apply(&self.q_proj)? * self.scaling)?;
let (key_states, value_states) = match kv_states {
None => {
let key_states = self._shape(&xs.apply(&self.k_proj)?, b_sz)?;
let value_states = self._shape(&xs.apply(&self.v_proj)?, b_sz)?;
if self.is_decoder {
let kv_states = match &self.kv_cache {
None => (key_states, value_states),
Some((p_key_states, p_value_states)) => {
let key_states = Tensor::cat(&[p_key_states, &key_states], 2)?;
let value_states = Tensor::cat(&[p_value_states, &value_states], 2)?;
(key_states, value_states)
}
};
self.kv_cache = Some(kv_states.clone());
kv_states
} else {
(key_states, value_states)
}
}
Some(kv_states) => {
let key_states = self._shape(&kv_states.apply(&self.k_proj)?, b_sz)?;
let value_states = self._shape(&kv_states.apply(&self.v_proj)?, b_sz)?;
(key_states, value_states)
}
};
let proj_shape = (b_sz * self.num_heads, (), self.head_dim);
let query_states = self._shape(&query_states, b_sz)?.reshape(proj_shape)?;
let key_states = key_states.reshape(proj_shape)?;
let value_states = value_states.reshape(proj_shape)?;
let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
let attn_weights = match attn_mask {
None => attn_weights,
Some(attn_mask) => attn_weights.broadcast_add(attn_mask)?,
};
let attn_probs = candle_nn::ops::softmax_last_dim(&attn_weights)?;
let attn_output = attn_probs.matmul(&value_states)?;
attn_output
.reshape((b_sz, self.num_heads, tgt_len, self.head_dim))?
.transpose(1, 2)?
.reshape((b_sz, tgt_len, self.head_dim * self.num_heads))?
.apply(&self.out_proj)
}
}
#[derive(Debug, Clone)]
struct TrOCRDecoderLayer {
self_attn: TrOCRAttention,
activation_fn: candle_nn::Activation,
self_attn_layer_norm: LayerNorm,
encoder_attn: TrOCRAttention,
encoder_attn_layer_norm: LayerNorm,
fc1: Linear,
fc2: Linear,
final_layer_norm: LayerNorm,
}
impl TrOCRDecoderLayer {
fn load(vb: VarBuilder, cfg: &TrOCRConfig) -> Result<Self> {
let embed_dim = cfg.d_model;
let self_attn = TrOCRAttention::load(vb.pp("self_attn"), cfg, None, None)?;
let self_attn_layer_norm = layer_norm(embed_dim, 1e-5, vb.pp("self_attn_layer_norm"))?;
let encoder_attn = TrOCRAttention::load(
vb.pp("encoder_attn"),
cfg,
Some(cfg.hidden_size),
Some(cfg.hidden_size),
)?;
let encoder_attn_layer_norm =
layer_norm(embed_dim, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
let fc1 = linear_no_bias(embed_dim, cfg.decoder_ffn_dim, vb.pp("fc1"))?;
let fc2 = linear_no_bias(cfg.decoder_ffn_dim, embed_dim, vb.pp("fc2"))?;
let final_layer_norm = layer_norm(embed_dim, 1e-5, vb.pp("final_layer_norm"))?;
let activation_fn = candle_nn::Activation::Gelu;
Ok(Self {
self_attn,
activation_fn,
self_attn_layer_norm,
encoder_attn,
encoder_attn_layer_norm,
fc1,
fc2,
final_layer_norm,
})
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: &Tensor,
encoder_hidden_states: Option<&Tensor>,
) -> Result<Tensor> {
let residual = xs.clone();
let xs = self.self_attn.forward(xs, None, Some(attention_mask))?;
let xs = (xs + residual)?;
let mut xs = self.self_attn_layer_norm.forward(&xs)?;
if let Some(encoder_hidden_states) = &encoder_hidden_states {
let residual = xs.clone();
let encoder_attention_mask = attention_mask.clone(); // TODO
xs = self.encoder_attn.forward(
&xs,
Some(encoder_hidden_states),
Some(&encoder_attention_mask),
)?;
xs = (xs + residual)?;
xs = self.encoder_attn_layer_norm.forward(&xs)?
}
let residual = xs.clone();
let xs = self.fc1.forward(&xs)?;
let xs = self.activation_fn.forward(&xs)?;
let xs = self.fc2.forward(&xs)?;
let xs = (xs + residual)?;
let xs = self.final_layer_norm.forward(&xs)?;
Ok(xs)
}
}
#[derive(Debug, Clone)]
pub struct TrOCRDecoder {
layers: Vec<TrOCRDecoderLayer>,
embed_scale: Option<f64>,
embed_tokens: Embedding,
embed_positions: TrOCRLearnedPositionalEmbedding,
}
impl TrOCRDecoder {
fn new(cfg: &TrOCRConfig, vb: VarBuilder) -> Result<Self> {
let vb = vb.pp("decoder.model.decoder");
let embed_tokens = embedding(cfg.vocab_size, cfg.d_model, vb.pp("embed_tokens"))?;
let embed_positions = TrOCRLearnedPositionalEmbedding::load(vb.pp("embed_positions"), cfg)?;
let mut layers = Vec::with_capacity(cfg.decoder_layers);
let vb_l = vb.pp("layers");
for idx in 0..cfg.decoder_layers {
let layer = TrOCRDecoderLayer::load(vb_l.pp(idx), cfg)?;
layers.push(layer)
}
let embed_scale = if cfg.scale_embedding {
Some((cfg.d_model as f64).sqrt())
} else {
None
};
Ok(Self {
layers,
embed_scale,
embed_tokens,
embed_positions,
})
}
pub fn forward(
&mut self,
xs: &Tensor,
encoder_xs: Option<&Tensor>,
past_kv_len: usize,
attn_mask: &Tensor,
) -> Result<Tensor> {
let embed_pos = self.embed_positions.forward(xs, past_kv_len as u32)?;
let xs = xs.apply(&self.embed_tokens)?;
let xs = match self.embed_scale {
None => xs,
Some(scale) => (xs * scale)?,
};
let mut xs = xs.broadcast_add(&embed_pos)?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, attn_mask, encoder_xs)?;
}
Ok(xs)
}
}
#[derive(Debug, Clone)]
pub struct TrOCREncoder {
embeddings: Embeddings,
encoder: Encoder,
layernorm: LayerNorm,
}
impl TrOCREncoder {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb_v = vb.pp("encoder");
let embeddings = Embeddings::new(cfg, false, vb_v.pp("embeddings"))?;
let encoder = Encoder::new(cfg, vb_v.pp("encoder"))?;
let layernorm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb_v.pp("layernorm"))?;
Ok(Self {
embeddings,
encoder,
layernorm,
})
}
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let embedding_output = self.embeddings.forward(xs, None, false)?;
let encoder_outputs = self.encoder.forward(&embedding_output)?;
self.layernorm.forward(&encoder_outputs)
}
}
#[derive(Debug, Clone)]
pub struct TrOCRForCausalLM {
decoder: TrOCRDecoder,
output_projection: Linear,
}
impl TrOCRForCausalLM {
pub fn new(decoder_cfg: &TrOCRConfig, vb: VarBuilder) -> Result<Self> {
let decoder = TrOCRDecoder::new(decoder_cfg, vb.clone())?;
let output_projection =
candle_nn::Linear::new(decoder.embed_tokens.embeddings().clone(), None);
Ok(Self {
decoder,
output_projection,
})
}
pub fn forward(
&mut self,
xs: &Tensor,
encoder_xs: Option<&Tensor>,
past_kv_len: usize,
attn_mask: &Tensor,
) -> Result<Tensor> {
let xs = self
.decoder
.forward(xs, encoder_xs, past_kv_len, attn_mask)?;
let xs = xs.apply(&self.output_projection)?;
Ok(xs)
}
}
#[derive(Debug, Clone)]
pub struct TrOCRModel {
encoder: TrOCREncoder,
decoder: TrOCRForCausalLM,
}
impl TrOCRModel {
pub fn new(encoder_cfg: &Config, decoder_cfg: &TrOCRConfig, vb: VarBuilder) -> Result<Self> {
let encoder = TrOCREncoder::new(encoder_cfg, vb.clone())?;
let decoder = TrOCRForCausalLM::new(decoder_cfg, vb)?;
Ok(Self { encoder, decoder })
}
pub fn encoder(&mut self) -> &mut TrOCREncoder {
&mut self.encoder
}
pub fn decoder(&mut self) -> &mut TrOCRForCausalLM {
&mut self.decoder
}
pub fn decode(
&mut self,
xs: &Tensor,
encoder_xs: &Tensor,
past_kv_len: usize,
) -> Result<Tensor> {
let seq_len = xs.dim(1)?;
let mask: Vec<_> = (0..seq_len)
.flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
.collect();
let mask = Tensor::from_vec(mask, (seq_len, seq_len), xs.device())?;
self.decoder
.forward(xs, Some(encoder_xs), past_kv_len, &mask)
}
}

View File

@ -6,16 +6,16 @@ use candle_nn::{layer_norm, LayerNorm, VarBuilder};
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/configuration_vit.py
#[derive(Debug, Clone)]
pub struct Config {
pub hidden_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub intermediate_size: usize,
pub hidden_act: candle_nn::Activation,
pub layer_norm_eps: f64,
pub image_size: usize,
pub patch_size: usize,
pub num_channels: usize,
pub qkv_bias: bool,
hidden_size: usize,
num_hidden_layers: usize,
num_attention_heads: usize,
intermediate_size: usize,
hidden_act: candle_nn::Activation,
layer_norm_eps: f64,
image_size: usize,
patch_size: usize,
num_channels: usize,
qkv_bias: bool,
}
impl Config {
@ -34,21 +34,6 @@ impl Config {
qkv_bias: true,
}
}
pub fn microsoft_trocr_base_handwritten() -> Self {
Self {
hidden_size: 768,
num_hidden_layers: 12,
num_attention_heads: 12,
intermediate_size: 3072,
hidden_act: candle_nn::Activation::Gelu,
layer_norm_eps: 1e-12,
image_size: 384,
patch_size: 16,
num_channels: 3,
qkv_bias: false,
}
}
}
#[derive(Debug, Clone)]
@ -91,7 +76,7 @@ impl Module for PatchEmbeddings {
}
#[derive(Debug, Clone)]
pub struct Embeddings {
struct Embeddings {
cls_token: Tensor,
mask_token: Option<Tensor>,
patch_embeddings: PatchEmbeddings,
@ -100,7 +85,7 @@ pub struct Embeddings {
}
impl Embeddings {
pub fn new(cfg: &Config, use_mask_token: bool, vb: VarBuilder) -> Result<Self> {
fn new(cfg: &Config, use_mask_token: bool, vb: VarBuilder) -> Result<Self> {
let hidden_size = cfg.hidden_size;
let cls_token = vb.get((1, 1, hidden_size), "cls_token")?;
let mask_token = if use_mask_token {
@ -130,7 +115,7 @@ impl Embeddings {
todo!()
}
pub fn forward(
fn forward(
&self,
pixel_values: &Tensor,
bool_masked_pos: Option<&Tensor>,
@ -339,12 +324,12 @@ impl Module for Layer {
}
#[derive(Debug, Clone)]
pub struct Encoder {
struct Encoder {
layers: Vec<Layer>,
}
impl Encoder {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb = vb.pp("layer");
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
for i in 0..cfg.num_hidden_layers {

View File

@ -198,17 +198,13 @@ fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
mel
}
pub fn pcm_to_mel<T: Float + std::fmt::Display>(
cfg: &super::Config,
samples: &[T],
filters: &[T],
) -> Vec<T> {
pub fn pcm_to_mel<T: Float + std::fmt::Display>(samples: &[T], filters: &[T]) -> Vec<T> {
log_mel_spectrogram_(
samples,
filters,
super::N_FFT,
super::HOP_LENGTH,
cfg.num_mel_bins,
super::N_MELS,
false,
)
}

View File

@ -18,7 +18,6 @@ pub struct Config {
// pub n_text_state: usize,
pub decoder_attention_heads: usize, // n_text_head
pub decoder_layers: usize, // n_text_layer
#[serde(default)]
pub suppress_tokens: Vec<u32>,
}
@ -27,6 +26,7 @@ pub const DTYPE: candle::DType = candle::DType::F32;
// Audio parameters.
pub const SAMPLE_RATE: usize = 16000;
pub const N_FFT: usize = 400;
pub const N_MELS: usize = 80;
pub const HOP_LENGTH: usize = 160;
pub const CHUNK_LENGTH: usize = 30;
pub const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk

View File

@ -1,5 +1,4 @@
use super::Config;
use crate::models::with_tracing::{linear, linear_no_bias, Linear};
use candle::{Device, IndexOp, Result, Tensor, D};
use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
@ -7,6 +6,33 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Em
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
Ok(Embedding::new(embeddings, hidden_size))
}
//
// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
// model.
#[derive(Debug, Clone)]
pub struct Linear {
inner: candle_nn::Linear,
span: tracing::Span,
}
impl Linear {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
}
}
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let span = tracing::span!(tracing::Level::TRACE, "linear");
let inner = candle_nn::linear(size1, size2, vb)?;
Ok(Linear { inner, span })
}
fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let span = tracing::span!(tracing::Level::TRACE, "linear");
let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
Ok(Linear { inner, span })
}
fn conv1d(
in_channels: usize,

View File

@ -124,34 +124,3 @@ impl std::fmt::Debug for QMatMul {
write!(f, "QMatMul")
}
}
#[derive(Clone, Debug)]
pub struct LayerNorm {
inner: candle_nn::LayerNorm,
span: tracing::Span,
}
impl LayerNorm {
pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
let inner = candle_nn::LayerNorm::new(weight, bias, eps);
let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
Self { inner, span }
}
}
impl Module for LayerNorm {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(xs)
}
}
pub fn layer_norm<C: Into<candle_nn::LayerNormConfig>>(
size: usize,
c: C,
vb: VarBuilder,
) -> Result<LayerNorm> {
let inner = candle_nn::layer_norm(size, c, vb)?;
let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
Ok(LayerNorm { inner, span })
}

View File

@ -1,377 +0,0 @@
/// https://huggingface.co/01-ai/Yi-6B/blob/main/modeling_yi.py
use crate::models::with_tracing::{linear_no_bias, Linear};
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use std::sync::Arc;
#[derive(Debug, Clone, PartialEq)]
pub struct Config {
pub(crate) vocab_size: usize,
pub(crate) hidden_size: usize,
pub(crate) intermediate_size: usize,
pub(crate) num_hidden_layers: usize,
pub(crate) num_attention_heads: usize,
pub(crate) num_key_value_heads: usize,
pub(crate) hidden_act: Activation,
pub(crate) max_position_embeddings: usize,
pub(crate) rms_norm_eps: f64,
pub(crate) rope_theta: f64,
}
impl Config {
pub fn config_6b() -> Self {
Self {
vocab_size: 64000,
hidden_size: 4096,
intermediate_size: 11008,
num_hidden_layers: 32,
num_attention_heads: 32,
num_key_value_heads: 4,
hidden_act: Activation::Silu,
max_position_embeddings: 4096,
rms_norm_eps: 1e-5,
rope_theta: 5_000_000.,
}
}
pub fn config_34b() -> Self {
Self {
vocab_size: 64000,
hidden_size: 7168,
intermediate_size: 20480,
num_hidden_layers: 60,
num_attention_heads: 56,
num_key_value_heads: 8,
hidden_act: Activation::Silu,
max_position_embeddings: 4096,
rms_norm_eps: 1e-5,
rope_theta: 5_000_000.,
}
}
}
#[derive(Debug, Clone)]
struct RmsNorm {
inner: candle_nn::RmsNorm,
span: tracing::Span,
}
impl RmsNorm {
fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
let inner = candle_nn::rms_norm(size, eps, vb)?;
Ok(Self { inner, span })
}
}
impl Module for RmsNorm {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
}
}
#[derive(Debug, Clone)]
struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
}
fn rotate_half(xs: &Tensor) -> Result<Tensor> {
let last_dim = xs.dim(D::Minus1)?;
let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
}
impl RotaryEmbedding {
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
let dim = cfg.hidden_size / cfg.num_attention_heads;
let max_seq_len = cfg.max_position_embeddings;
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
})
}
fn apply_rotary_emb_qkv(
&self,
q: &Tensor,
k: &Tensor,
seqlen_offset: usize,
) -> Result<(Tensor, Tensor)> {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
Ok((q_embed, k_embed))
}
}
#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
gate_proj: Linear,
up_proj: Linear,
down_proj: Linear,
act_fn: Activation,
}
impl MLP {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let hidden_sz = cfg.hidden_size;
let intermediate_sz = cfg.intermediate_size;
let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
Ok(Self {
gate_proj,
up_proj,
down_proj,
act_fn: cfg.hidden_act,
})
}
}
impl Module for MLP {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
let rhs = xs.apply(&self.up_proj)?;
(lhs * rhs)?.apply(&self.down_proj)
}
}
#[derive(Debug, Clone)]
struct Attention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
num_heads: usize,
num_kv_heads: usize,
num_kv_groups: usize,
head_dim: usize,
hidden_size: usize,
rotary_emb: Arc<RotaryEmbedding>,
kv_cache: Option<(Tensor, Tensor)>,
}
impl Attention {
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let hidden_sz = cfg.hidden_size;
let num_heads = cfg.num_attention_heads;
let num_kv_heads = cfg.num_key_value_heads;
let num_kv_groups = num_heads / num_kv_heads;
let head_dim = hidden_sz / num_heads;
let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
num_heads,
num_kv_heads,
num_kv_groups,
head_dim,
hidden_size: hidden_sz,
rotary_emb,
kv_cache: None,
})
}
fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
let n_rep = self.num_kv_groups;
if n_rep == 1 {
Ok(xs)
} else {
let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
xs.unsqueeze(2)?
.expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
.reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
}
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let (b_sz, q_len, _) = xs.dims3()?;
let query_states = self.q_proj.forward(xs)?;
let key_states = self.k_proj.forward(xs)?;
let value_states = self.v_proj.forward(xs)?;
let query_states = query_states
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
.transpose(1, 2)?;
let key_states = key_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let value_states = value_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let (query_states, key_states) =
self.rotary_emb
.apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
let (key_states, value_states) = match &self.kv_cache {
None => (key_states, value_states),
Some((prev_k, prev_v)) => {
let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
(key_states, value_states)
}
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));
let key_states = self.repeat_kv(key_states)?;
let value_states = self.repeat_kv(value_states)?;
let attn_output = {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
let attn_weights = match attention_mask {
None => attn_weights,
Some(mask) => attn_weights.broadcast_add(mask)?,
};
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
attn_weights.matmul(&value_states)?
};
attn_output
.transpose(1, 2)?
.reshape((b_sz, q_len, self.hidden_size))?
.apply(&self.o_proj)
}
}
#[derive(Debug, Clone)]
struct DecoderLayer {
self_attn: Attention,
mlp: MLP,
ln1: RmsNorm,
ln2: RmsNorm,
}
impl DecoderLayer {
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("ln1"))?;
let ln2 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("ln2"))?;
Ok(Self {
self_attn,
mlp,
ln1,
ln2,
})
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let residual = xs;
let xs = self.ln1.forward(xs)?;
let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = xs.apply(&self.ln2)?.apply(&self.mlp)?;
residual + xs
}
}
#[derive(Debug, Clone)]
pub struct Model {
embed_tokens: candle_nn::Embedding,
layers: Vec<DecoderLayer>,
norm: RmsNorm,
lm_head: Linear,
device: Device,
dtype: DType,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb_m = vb.pp("model");
let embed_tokens =
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb_l = vb_m.pp("layers");
for layer_idx in 0..cfg.num_hidden_layers {
let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
layers.push(layer)
}
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
Ok(Self {
embed_tokens,
layers,
norm,
lm_head,
device: vb.device().clone(),
dtype: vb.dtype(),
})
}
fn prepare_decoder_attention_mask(
&self,
b_size: usize,
tgt_len: usize,
seqlen_offset: usize,
) -> Result<Tensor> {
// Sliding window mask?
let mask: Vec<_> = (0..tgt_len)
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
.collect();
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
let mask = if seqlen_offset > 0 {
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
};
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
.to_dtype(self.dtype)
}
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let (b_size, seq_len) = input_ids.dims2()?;
let attention_mask = if seq_len <= 1 {
None
} else {
let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
Some(mask)
};
let mut xs = self.embed_tokens.forward(input_ids)?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
}
xs.narrow(1, seq_len - 1, 1)?
.apply(&self.norm)?
.apply(&self.lm_head)
}
}

View File

@ -10,12 +10,12 @@ pub struct VarBuilder {
}
impl VarBuilder {
pub fn from_gguf<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
pub fn from_gguf<P: AsRef<std::path::Path>>(p: P, device: &Device) -> Result<Self> {
let mut file = std::fs::File::open(p)?;
let content = candle::quantized::gguf_file::Content::read(&mut file)?;
let mut data = std::collections::HashMap::new();
for tensor_name in content.tensor_infos.keys() {
let tensor = content.tensor(&mut file, tensor_name)?;
let tensor = content.tensor(&mut file, tensor_name, device)?;
data.insert(tensor_name.to_string(), Arc::new(tensor));
}
Ok(Self {
@ -25,12 +25,12 @@ impl VarBuilder {
})
}
pub fn from_gguf_buffer(buffer: &[u8]) -> Result<Self> {
pub fn from_gguf_buffer(buffer: &[u8], device: &Device) -> Result<Self> {
let mut cursor = std::io::Cursor::new(buffer);
let content = candle::quantized::gguf_file::Content::read(&mut cursor)?;
let mut data = std::collections::HashMap::new();
for tensor_name in content.tensor_infos.keys() {
let tensor = content.tensor(&mut cursor, tensor_name)?;
let tensor = content.tensor(&mut cursor, tensor_name, device)?;
data.insert(tensor_name.to_string(), Arc::new(tensor));
}
Ok(Self {
@ -90,8 +90,4 @@ impl VarBuilder {
pub fn device(&self) -> &Device {
&self.device
}
pub fn contains_key(&self, key: &str) -> bool {
self.data.contains_key(key)
}
}

View File

@ -1,7 +1,5 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{
embedding, linear_no_bias as linear, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder,
};
use candle_nn::{embedding, linear, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

Some files were not shown because too many files have changed in this diff Show More