Compare commits

..

8 Commits

Author SHA1 Message Date
9a27f11c3f Adding tons of profiling and removing the metal allocation (still slow). 2023-11-02 17:48:07 +01:00
7161002a34 Finished scaffolding, lots of TODOs
- Most kernels just copy themselves to get the shapes correct
- Matmul works only in 1 case and simply allocates an empty output otherwise
- Logits are randomized to make the demo finish on its own.

Performance is quite bad (30ms/token), but there are lots of prints and allocations, and some actual dispatching to Metal.

Couldn't speed it up much even by removing the obvious blockers (println + the actual matmul runs).

Allocations take between 1µs and 100µs and seem very stable. Maybe Metal doesn't really have a smart allocator and we'll need to own it (see the timing sketch after the commit list).
2023-11-02 15:32:28 +01:00
82cce52e73 Rename candle-metal -> candle-metal-kernels 2023-11-02 09:53:29 +01:00
71fcb31873 Owned command buffer now. 2023-11-01 18:03:53 +01:00
198009453a Matmul (no batch, no strided, f32, f32 only) sort of done. 2023-11-01 17:36:51 +01:00
492d164235 More scaffolding, now need to implement matmul (for precompute_cos_sin to work). 2023-11-01 16:54:09 +01:00
2d84c16fed First pass (Quantized scaffolding work done + quantized example scaffolding). 2023-11-01 15:10:11 +01:00
4525b7b52a Initial setup 2023-10-31 18:09:10 +01:00
104 changed files with 879 additions and 8446 deletions
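The profiling commit at the top of the list reports per-allocation times between 1µs and 100µs. A rough sketch of how that kind of measurement could be taken, assuming the `metal` crate (metal-rs) that the workspace in this diff depends on; the buffer size and iteration count are illustrative and not taken from the actual profiling code:

```rust
use std::time::{Duration, Instant};

fn main() {
    // Grab the default Metal device (the Apple GPU).
    let device = metal::Device::system_default().expect("no Metal device found");
    const ITERS: u32 = 1_000;
    let bytes: u64 = 4096 * 4; // e.g. 4096 f32 values, purely illustrative

    let mut total = Duration::ZERO;
    for _ in 0..ITERS {
        let start = Instant::now();
        // This is the call whose cost the commit message describes.
        let _buf = device.new_buffer(bytes, metal::MTLResourceOptions::StorageModeShared);
        total += start.elapsed();
    }
    println!("avg new_buffer time: {:?}", total / ITERS);
}
```

If allocation really is this stable and this costly, reusing buffers from a small owned pool, as the commit message suggests with "we'll need to own it", would amortize most of that cost.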

View File

@ -59,7 +59,7 @@ jobs:
- name: Install Rust Stable - name: Install Rust Stable
run: curl https://sh.rustup.rs -sSf | sh -s -- -y run: curl https://sh.rustup.rs -sSf | sh -s -- -y
- uses: Swatinem/rust-cache@v2 - uses: Swatinem/rust-cache@v2
- run: apt-get update -y && apt-get install libssl-dev protobuf-compiler -y - run: apt-get update -y && apt-get install libssl-dev -y
- name: Test (cuda) - name: Test (cuda)
run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda
stop-runner: stop-runner:

Binary file not shown.

View File

@ -39,12 +39,6 @@ jobs:
path: ~/.cargo/registry path: ~/.cargo/registry
key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }} key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }}
- name: Install Protoc
uses: arduino/setup-protoc@v2
with:
version: "25.0"
repo-token: ${{ secrets.GITHUB_TOKEN }}
- name: Install - name: Install
working-directory: ./candle-pyo3 working-directory: ./candle-pyo3
run: | run: |
@ -52,7 +46,7 @@ jobs:
source .env/bin/activate source .env/bin/activate
pip install -U pip pip install -U pip
pip install pytest maturin black pip install pytest maturin black
python -m maturin develop -r --features onnx python -m maturin develop -r
- name: Check style - name: Check style
working-directory: ./candle-pyo3 working-directory: ./candle-pyo3

View File

@ -10,12 +10,7 @@ members = [
"candle-wasm-examples/*", "candle-wasm-examples/*",
"candle-wasm-tests", "candle-wasm-tests",
] ]
exclude = [ exclude = ["candle-flash-attn", "candle-kernels"]
"candle-flash-attn",
"candle-kernels",
"candle-metal-kernels",
"candle-onnx",
]
resolver = "2" resolver = "2"
[workspace.package] [workspace.package]
@ -51,7 +46,6 @@ rayon = "1.7.0"
rusttype = { version = "0.9", default-features = false } rusttype = { version = "0.9", default-features = false }
safetensors = "0.3.1" safetensors = "0.3.1"
serde = { version = "1.0.171", features = ["derive"] } serde = { version = "1.0.171", features = ["derive"] }
serde_plain = "1.0.2"
serde_json = "1.0.99" serde_json = "1.0.99"
thiserror = "1" thiserror = "1"
tokenizers = { version = "0.13.4", default-features = false } tokenizers = { version = "0.13.4", default-features = false }
@ -61,7 +55,8 @@ tracing-subscriber = "0.3.7"
wav = "1.0.0" wav = "1.0.0"
yoke = { version = "0.7.2", features = ["derive"] } yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "0.6.6", default-features = false } zip = { version = "0.6.6", default-features = false }
metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] } # metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
metal = { path = "../metal-rs", features = ["mps"] }
[profile.release-with-debug] [profile.release-with-debug]
inherits = "release" inherits = "release"

View File

@ -51,7 +51,7 @@ For more advanced examples, please have a look at the following section.
These online demos run entirely in your browser: These online demos run entirely in your browser:
- [yolo](https://huggingface.co/spaces/lmz/candle-yolo): pose estimation and - [yolo](https://huggingface.co/spaces/lmz/candle-yolo): pose estimation and
object recognition. object recognition.
- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): speech recognition. - [whisper](https://huggingface.co/spaces/lmz/candle-whisper): text to speech.
- [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2): text generation. - [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2): text generation.
- [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm): text generation. - [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm): text generation.
- [Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation. - [Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation.
@ -69,8 +69,6 @@ We also provide a some command line based examples using state of the art models
performance larger than all publicly available 13b models as of 2023-09-28. performance larger than all publicly available 13b models as of 2023-09-28.
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation. - [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion. - [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
(English/Chinese) general LLMs with 6b and 34b parameters.
- [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of - [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of
the LLaMA model using the same quantization techniques as the LLaMA model using the same quantization techniques as
[llama.cpp](https://github.com/ggerganov/llama.cpp). [llama.cpp](https://github.com/ggerganov/llama.cpp).
@ -145,11 +143,6 @@ And then head over to
including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop. including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): a LoRA implementation - [`candle-lora`](https://github.com/EricLBuehler/candle-lora): a LoRA implementation
that conforms to the official `peft` implementation. that conforms to the official `peft` implementation.
- [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
serving local LLMs including an OpenAI compatible API server.
- [`candle-ext`](https://github.com/mokeyish/candle-ext): an extension library to Candle that provides PyTorch functions not currently available in Candle.
- [`kalosm`](https://github.com/floneum/floneum/tree/master/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
If you have an addition to this list, please submit a pull request. If you have an addition to this list, please submit a pull request.
@ -175,17 +168,16 @@ If you have an addition to this list, please submit a pull request.
- Mistral 7b v0.1. - Mistral 7b v0.1.
- StableLM-3B-4E1T. - StableLM-3B-4E1T.
- Replit-code-v1.5-3B. - Replit-code-v1.5-3B.
- T5.
- Bert. - Bert.
- Yi-6B and Yi-34B.
- Text to text.
- T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
- Marian MT (Machine Translation).
- Whisper (multi-lingual support). - Whisper (multi-lingual support).
- Text to image. - Text to image.
- Stable Diffusion v1.5, v2.1, XL v1.0. - Stable Diffusion v1.5, v2.1, XL v1.0.
- Wurstchen v2. - Wurstchen v2.
- Image to text. - Image to text.
- BLIP. - BLIP.
- Text to text.
- Marian MT (Machine Translation).
- Computer Vision Models. - Computer Vision Models.
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT. - DINOv2, ConvMixer, EfficientNet, ResNet, ViT.
- yolo-v3, yolo-v8. - yolo-v3, yolo-v8.
@ -226,7 +218,6 @@ Cheatsheet:
- [candle-datasets](./candle-datasets/): Datasets and data loaders. - [candle-datasets](./candle-datasets/): Datasets and data loaders.
- [candle-transformers](./candle-transformers): transformers-related utilities. - [candle-transformers](./candle-transformers): transformers-related utilities.
- [candle-flash-attn](./candle-flash-attn): Flash attention v2 layer. - [candle-flash-attn](./candle-flash-attn): Flash attention v2 layer.
- [candle-onnx](./candle-onnx/): ONNX model evaluation.
## FAQ ## FAQ

View File

@ -30,6 +30,7 @@ safetensors = { workspace = true }
thiserror = { workspace = true } thiserror = { workspace = true }
yoke = { workspace = true } yoke = { workspace = true }
zip = { workspace = true } zip = { workspace = true }
tracing = { workspace = true }
[dev-dependencies] [dev-dependencies]
anyhow = { workspace = true } anyhow = { workspace = true }
@ -41,4 +42,4 @@ cuda = ["cudarc", "dep:candle-kernels"]
cudnn = ["cuda", "cudarc/cudnn"] cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"] mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"] accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal", "dep:candle-metal-kernels"] metal = ["dep:candle-metal-kernels", "dep:metal"]

View File

@ -39,14 +39,6 @@ pub trait BackendStorage: Sized {
_params: &crate::conv::ParamsConv1D, _params: &crate::conv::ParamsConv1D,
) -> Result<Self>; ) -> Result<Self>;
fn conv_transpose1d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self>;
fn conv2d( fn conv2d(
&self, &self,
_l: &Layout, _l: &Layout,

View File

@ -15,17 +15,6 @@ fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result
} }
} }
thread_local! {
static CANDLE_GRAD_DO_NOT_DETACH: bool = {
match std::env::var("CANDLE_GRAD_DO_NOT_DETACH") {
Ok(s) => {
!s.is_empty() && s != "0"
},
Err(_) => false,
}
}
}
impl Tensor { impl Tensor {
/// Return all the nodes that lead to this value in a topologically sorted vec, the first /// Return all the nodes that lead to this value in a topologically sorted vec, the first
/// elements having dependencies on the latter ones, e.g. the first element if any is the /// elements having dependencies on the latter ones, e.g. the first element if any is the
@ -68,11 +57,6 @@ impl Tensor {
kernel: rhs, kernel: rhs,
.. ..
} }
| Op::ConvTranspose1D {
arg: lhs,
kernel: rhs,
..
}
| Op::Conv2D { | Op::Conv2D {
arg: lhs, arg: lhs,
kernel: rhs, kernel: rhs,
@ -166,16 +150,10 @@ impl Tensor {
if node.is_variable() { if node.is_variable() {
continue; continue;
} }
let grad = grads let grad = grads.remove(node).unwrap();
.remove(node) // TODO: We should perform all these operations in place (or at least not track the
.expect("candle internal error - grad not populated"); // whole graph). The only drawback would be if we wanted to support grad of grad but
// https://github.com/huggingface/candle/issues/1241 // this is out of scope.
// Ideally, we would make these operations in place where possible to ensure that we
// do not have to allocate too often. Here we just call `.detach` to avoid computing
// the backprop graph of the backprop itself. This would be an issue for second order
// derivatives but these are out of scope at the moment.
let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
let grad = if do_not_detach { grad } else { grad.detach()? };
if let Some(op) = node.op() { if let Some(op) = node.op() {
match op { match op {
Op::Binary(lhs, rhs, BinaryOp::Add) => { Op::Binary(lhs, rhs, BinaryOp::Add) => {
@ -230,44 +208,7 @@ impl Tensor {
let f_grad = pred.where_cond(&zeros, &grad)?; let f_grad = pred.where_cond(&zeros, &grad)?;
*f_sum_grad = f_sum_grad.add(&f_grad)?; *f_sum_grad = f_sum_grad.add(&f_grad)?;
} }
Op::Conv1D { Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
arg,
kernel,
padding,
stride,
dilation,
} => {
// The output height for conv_transpose1d is:
// (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1
let grad_l_in = grad.dim(2)?;
let k_size = kernel.dim(2)?;
let out_size =
(grad_l_in - 1) * stride + dilation * (k_size - 1) + 1 - 2 * padding;
let out_padding = arg.dim(2)? - out_size;
let grad_arg = grad.conv_transpose1d(
kernel,
*padding,
out_padding,
*stride,
*dilation,
)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
let grad_kernel = arg
.transpose(0, 1)?
.conv1d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
.transpose(0, 1)?;
let sum_grad = grads.or_insert(kernel)?;
let (_, _, k0) = kernel.dims3()?;
let (_, _, g_k0) = grad_kernel.dims3()?;
let grad_kernel = if g_k0 != k0 {
grad_kernel.narrow(2, 0, k0)?
} else {
grad_kernel
};
*sum_grad = sum_grad.add(&grad_kernel)?;
}
Op::Conv2D { Op::Conv2D {
arg, arg,
kernel, kernel,
@ -306,9 +247,6 @@ impl Tensor {
}; };
*sum_grad = sum_grad.add(&grad_kernel)?; *sum_grad = sum_grad.add(&grad_kernel)?;
} }
Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported {
op: "conv-transpose1d",
})?,
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported { Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
op: "conv-transpose2d", op: "conv-transpose2d",
})?, })?,
@ -549,38 +487,16 @@ impl Tensor {
+ 0.5)?; + 0.5)?;
*sum_grad = sum_grad.add(&(&grad * gelu_grad)?)? *sum_grad = sum_grad.add(&(&grad * gelu_grad)?)?
} }
Op::Unary(arg, UnaryOp::Erf) => { Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
let sum_grad = grads.or_insert(arg)?; Op::Unary(_, UnaryOp::GeluErf) => {
// d/dx erf(x) = 2/sqrt(pi) * e^(-x^2) Err(Error::BackwardNotSupported { op: "gelu-erf" })?
let erf_grad =
(2. / std::f64::consts::PI.sqrt()) * (arg.sqr()?.neg()?).exp()?;
*sum_grad = sum_grad.add(&(&grad * erf_grad)?)?
}
Op::Unary(arg, UnaryOp::GeluErf) => {
let sum_grad = grads.or_insert(arg)?;
// d/dx gelu_erf(x) = 0.5 + 0.398942 e^(-x^2/2) x + 0.5 erf(x/sqrt(2))
let neg_half_square = (arg.sqr()?.neg()? / 2.)?;
let scaled_exp_arg = (0.398942 * neg_half_square.exp()? * arg)?;
let arg_scaled_sqrt = (arg / 2f64.sqrt())?;
let erf_scaled_sqrt = (0.5 * arg_scaled_sqrt.erf()?)?;
let gelu_erf_grad = (0.5 + scaled_exp_arg + erf_scaled_sqrt)?;
*sum_grad = sum_grad.add(&(&grad * gelu_erf_grad)?)?;
} }
Op::Unary(arg, UnaryOp::Relu) => { Op::Unary(arg, UnaryOp::Relu) => {
let sum_grad = grads.or_insert(arg)?; let sum_grad = grads.or_insert(arg)?;
let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?; let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
*sum_grad = sum_grad.add(&(&grad * relu_grad)?)? *sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
} }
Op::Elu(arg, alpha) => { Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
// d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
let sum_grad = grads.or_insert(arg)?;
let zeros = arg.zeros_like()?;
let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
let negative_exp_mask = ((negative_mask * arg.exp())? * *alpha)?;
let combined_mask = (positive_mask + negative_exp_mask)?;
*sum_grad = sum_grad.add(&(grad * combined_mask)?)?
}
Op::Powf(arg, e) => { Op::Powf(arg, e) => {
let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?; let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?;
let sum_grad = grads.or_insert(arg)?; let sum_grad = grads.or_insert(arg)?;
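
The erf backward pass removed in this hunk uses d/dx erf(x) = 2/sqrt(pi) * exp(-x^2), and the `unary_grad` test removed further down expects gradients of [0.0001, 0.4151, 0.0, 1.1033] for x = [3.0, 1.0, 4.0, 0.15]. A small standalone check of that formula (plain Rust, no candle dependency; the function name is just for illustration):

```rust
// d/dx erf(x) = 2 / sqrt(pi) * exp(-x^2), as in the removed backward pass.
fn erf_grad(x: f64) -> f64 {
    (2.0 / std::f64::consts::PI.sqrt()) * (-x * x).exp()
}

fn main() {
    // Same inputs as the removed unary_grad test.
    for x in [3.0, 1.0, 4.0, 0.15] {
        println!("erf'({x}) = {:.4}", erf_grad(x));
    }
    // Prints 0.0001, 0.4151, 0.0000 and 1.1033 respectively.
}
```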

View File

@ -25,33 +25,6 @@ impl ParamsConv1D {
} }
} }
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConvTranspose1D {
pub(crate) b_size: usize,
pub(crate) l_in: usize,
pub(crate) c_out: usize,
pub(crate) c_in: usize,
pub(crate) k_size: usize,
pub(crate) padding: usize,
pub(crate) output_padding: usize,
pub(crate) stride: usize,
pub(crate) dilation: usize,
}
impl ParamsConvTranspose1D {
pub(crate) fn l_out(&self) -> usize {
(self.l_in - 1) * self.stride - 2 * self.padding
+ self.dilation * (self.k_size - 1)
+ self.output_padding
+ 1
}
pub(crate) fn out_dims(&self) -> Vec<usize> {
let l_out = self.l_out();
vec![self.b_size, self.c_out, l_out]
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CudnnFwdAlgo { pub enum CudnnFwdAlgo {
ImplicitGemm, ImplicitGemm,
@ -187,49 +160,6 @@ impl Tensor {
} }
} }
/// Applies a 1D transposed convolution over the input tensor.
pub fn conv_transpose1d(
&self,
kernel: &Self,
padding: usize,
output_padding: usize,
stride: usize,
dilation: usize,
) -> Result<Self> {
let (b_size, c_in, l_in) = self.dims3()?;
let (c_in_k, c_out, k_size) = kernel.dims3()?;
if c_in != c_in_k {
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
}
let params = ParamsConvTranspose1D {
b_size,
l_in,
k_size,
c_out,
c_in,
padding,
output_padding,
stride,
dilation,
};
let storage = self.storage().conv_transpose1d(
self.layout(),
&kernel.storage(),
kernel.layout(),
&params,
)?;
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D {
arg,
kernel,
padding: params.padding,
output_padding: params.output_padding,
stride: params.stride,
dilation: params.dilation,
});
let out_dims = params.out_dims();
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
}
fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> { fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
let storage = let storage =
self.storage() self.storage()
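
The removed `ParamsConvTranspose1D::l_out` above encodes the standard output-length formula for a 1D transposed convolution. A quick sanity check of that arithmetic against the conv1d test removed later in this diff (input length 5, kernel size 3, stride 1, no padding, no output padding, dilation 1, expected output dims [1, 2, 7]), written as a standalone sketch:

```rust
// l_out = (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + output_padding + 1
fn l_out(l_in: usize, k_size: usize, stride: usize, padding: usize,
         output_padding: usize, dilation: usize) -> usize {
    (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + output_padding + 1
}

fn main() {
    // Shapes from the removed conv1d test:
    // t.conv_transpose1d(&w_t, 0, 0, 1, 1) -> dims [1, 2, 7].
    assert_eq!(l_out(5, 3, 1, 0, 0, 1), 7);
    println!("l_out = {}", l_out(5, 3, 1, 0, 0, 1));
}
```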

View File

@ -1256,74 +1256,6 @@ impl Map1 for Im2Col {
} }
} }
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
impl<'a> Map2 for ConvTranspose1D<'a> {
const OP: &'static str = "conv_transpose1d";
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0;
let inp = &inp[inp_l.start_offset()..];
let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
let l_out = p.l_out();
// Output shape: [b_size, c_out, l_out].
let dst_elems = p.c_out * l_out * p.b_size;
let dst = vec![T::zero(); dst_elems];
let dst_s0 = p.c_out * l_out;
let dst_s1 = l_out;
let dst_s2 = 1;
// TODO: Avoid making this copy if `inp` already has the appropriate layout.
let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
let cont_s0 = p.l_in * p.c_in;
let cont_s1 = p.c_in;
for b_idx in 0..p.b_size {
for l_idx in 0..p.l_in {
for c_idx in 0..p.c_in {
let src_idx = b_idx * inp_s0 + c_idx * inp_s1 + l_idx * inp_s2;
let dst_idx = b_idx * cont_s0 + l_idx * cont_s1 + c_idx;
inp_cont[dst_idx] = inp[src_idx]
}
}
}
for k_idx in 0..p.k_size {
(0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
let k_cont = (0..p.c_in)
.map(|c_in_idx| k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_idx * k_s2])
.collect::<Vec<_>>();
for b_idx in 0..p.b_size {
for l_idx in 0..p.l_in {
let out_idx = l_idx * p.stride + k_idx * p.dilation;
if out_idx < p.padding {
continue;
}
let out_idx = out_idx - p.padding;
if out_idx < l_out {
let inp_cont = &inp_cont[b_idx * cont_s0 + l_idx * cont_s1..];
let dst_idx = b_idx * dst_s0 + out_idx * dst_s2 + dst_c_idx * dst_s1;
let mut d = T::zero();
unsafe {
T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
}
let dst_p = dst.as_ptr();
// Safety: dst_idx are uniques per dst_c_idx which is used to
// parallelise the different tasks so no two threads can try to
// write at the same location.
unsafe {
let ptr = dst_p.add(dst_idx) as *mut T;
*ptr += d
}
}
}
}
})
}
Ok(dst)
}
}
struct Conv2D<'a>(&'a crate::conv::ParamsConv2D); struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
impl<'a> Map2 for Conv2D<'a> { impl<'a> Map2 for Conv2D<'a> {
@ -2503,16 +2435,6 @@ impl BackendStorage for CpuStorage {
Ok(res_t) Ok(res_t)
} }
fn conv_transpose1d(
&self,
l: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
ConvTranspose1D(params).map(self, l, kernel, kernel_l)
}
fn conv2d( fn conv2d(
&self, &self,
l: &Layout, l: &Layout,

View File

@ -1808,16 +1808,6 @@ impl BackendStorage for CudaStorage {
Ok(res_t) Ok(res_t)
} }
fn conv_transpose1d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
todo!()
}
#[cfg(not(feature = "cudnn"))] #[cfg(not(feature = "cudnn"))]
fn conv2d( fn conv2d(
&self, &self,

View File

@ -1,6 +1,6 @@
use crate::backend::BackendDevice; use crate::backend::BackendDevice;
use crate::cpu_backend::CpuDevice; use crate::cpu_backend::CpuDevice;
use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType}; use crate::{bail, CpuStorage, DType, Result, Shape, Storage, WithDType};
/// A `DeviceLocation` represents a physical device whereas multiple `Device` /// A `DeviceLocation` represents a physical device whereas multiple `Device`
/// can live on the same location (typically for cuda devices). /// can live on the same location (typically for cuda devices).
@ -8,7 +8,7 @@ use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
pub enum DeviceLocation { pub enum DeviceLocation {
Cpu, Cpu,
Cuda { gpu_id: usize }, Cuda { gpu_id: usize },
Metal { gpu_id: usize }, Metal,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -105,14 +105,14 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4:
impl<S: NdArray> NdArray for Vec<S> { impl<S: NdArray> NdArray for Vec<S> {
fn shape(&self) -> Result<Shape> { fn shape(&self) -> Result<Shape> {
if self.is_empty() { if self.is_empty() {
crate::bail!("empty array") bail!("empty array")
} }
let shape0 = self[0].shape()?; let shape0 = self[0].shape()?;
let n = self.len(); let n = self.len();
for v in self.iter() { for v in self.iter() {
let shape = v.shape()?; let shape = v.shape()?;
if shape != shape0 { if shape != shape0 {
crate::bail!("two elements have different shapes {shape:?} {shape0:?}") bail!("two elements have different shapes {shape:?} {shape0:?}")
} }
} }
Ok(Shape::from([[n].as_slice(), shape0.dims()].concat())) Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
@ -146,7 +146,6 @@ impl Device {
match (self, rhs) { match (self, rhs) {
(Self::Cpu, Self::Cpu) => true, (Self::Cpu, Self::Cpu) => true,
(Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs), (Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
(Self::Metal(lhs), Self::Metal(rhs)) => lhs.same_device(rhs),
_ => false, _ => false,
} }
} }
@ -167,10 +166,6 @@ impl Device {
matches!(self, Self::Cuda(_)) matches!(self, Self::Cuda(_))
} }
pub fn is_metal(&self) -> bool {
matches!(self, Self::Metal(_))
}
pub fn cuda_if_available(ordinal: usize) -> Result<Self> { pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
if crate::utils::cuda_is_available() { if crate::utils::cuda_is_available() {
Self::new_cuda(ordinal) Self::new_cuda(ordinal)
@ -192,19 +187,13 @@ impl Device {
Ok(Storage::Cpu(storage)) Ok(Storage::Cpu(storage))
} }
Device::Cuda(device) => { Device::Cuda(device) => {
// TODO: Remove the special case if we start supporting generating f16/bf16 directly. let storage = device.rand_uniform(shape, dtype, lo, up)?;
if dtype == DType::F16 || dtype == DType::BF16 { Ok(Storage::Cuda(storage))
let storage = device.rand_uniform(shape, DType::F32, lo, up)?;
Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
} else {
let storage = device.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Cuda(storage))
}
} }
Device::Metal(_device) => { Device::Metal(_device) => {
// let storage = device.rand_uniform(shape, dtype, lo, up)?; // let storage = device.rand_uniform(shape, dtype, lo, up)?;
// Ok(Storage::Metal(storage)) // Ok(Storage::Metal(storage))
crate::bail!("Metal rand_uniform not implemented") bail!("Metal rand_uniform not implemented")
} }
} }
} }
@ -231,14 +220,8 @@ impl Device {
Ok(Storage::Cpu(storage)) Ok(Storage::Cpu(storage))
} }
Device::Cuda(device) => { Device::Cuda(device) => {
// TODO: Remove the special case if we start supporting generating f16/bf16 directly. let storage = device.rand_normal(shape, dtype, mean, std)?;
if dtype == DType::F16 || dtype == DType::BF16 { Ok(Storage::Cuda(storage))
let storage = device.rand_normal(shape, DType::F32, mean, std)?;
Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
} else {
let storage = device.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Cuda(storage))
}
} }
Device::Metal(device) => { Device::Metal(device) => {
let storage = device.rand_normal(shape, dtype, mean, std)?; let storage = device.rand_normal(shape, dtype, mean, std)?;

View File

@ -14,9 +14,7 @@ impl Tensor {
crate::DeviceLocation::Cuda { gpu_id } => { crate::DeviceLocation::Cuda { gpu_id } => {
format!(", cuda:{}", gpu_id) format!(", cuda:{}", gpu_id)
} }
crate::DeviceLocation::Metal { gpu_id } => { _ => todo!(),
format!(", metal:{}", gpu_id)
}
}; };
write!(f, "Tensor[")?; write!(f, "Tensor[")?;
@ -479,9 +477,7 @@ impl std::fmt::Display for Tensor {
crate::DeviceLocation::Cuda { gpu_id } => { crate::DeviceLocation::Cuda { gpu_id } => {
format!(", cuda:{}", gpu_id) format!(", cuda:{}", gpu_id)
} }
crate::DeviceLocation::Metal { gpu_id } => { crate::DeviceLocation::Metal => todo!(),
format!(", metal:{}", gpu_id)
}
}; };
write!( write!(

View File

@ -79,16 +79,6 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport) Err(Error::NotCompiledWithCudaSupport)
} }
fn conv_transpose1d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn conv2d( fn conv2d(
&self, &self,
_: &Layout, _: &Layout,

View File

@ -8,18 +8,6 @@ pub struct MetalDevice;
#[derive(Debug)] #[derive(Debug)]
pub struct MetalStorage; pub struct MetalStorage;
#[derive(thiserror::Error, Debug)]
pub enum MetalError {
#[error("{0}")]
Message(String),
}
impl From<String> for MetalError {
fn from(e: String) -> Self {
MetalError::Message(e)
}
}
macro_rules! fail { macro_rules! fail {
() => { () => {
unimplemented!("metal support has not been enabled, add `metal` feature to enable.") unimplemented!("metal support has not been enabled, add `metal` feature to enable.")
@ -91,16 +79,6 @@ impl crate::backend::BackendStorage for MetalStorage {
Err(Error::NotCompiledWithMetalSupport) Err(Error::NotCompiledWithMetalSupport)
} }
fn conv_transpose1d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn conv2d( fn conv2d(
&self, &self,
_: &Layout, _: &Layout,

View File

@ -1,4 +1,4 @@
use crate::{DType, DeviceLocation, Layout, MetalError, Shape}; use crate::{DType, DeviceLocation, Layout, Shape};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct MatMulUnexpectedStriding { pub struct MatMulUnexpectedStriding {
@ -163,7 +163,7 @@ pub enum Error {
Cuda(Box<dyn std::error::Error + Send + Sync>), Cuda(Box<dyn std::error::Error + Send + Sync>),
#[error("Metal error {0}")] #[error("Metal error {0}")]
Metal(#[from] MetalError), Metal(String),
#[error(transparent)] #[error(transparent)]
TryFromIntError(#[from] core::num::TryFromIntError), TryFromIntError(#[from] core::num::TryFromIntError),

View File

@ -49,12 +49,13 @@ mod device;
pub mod display; pub mod display;
mod dtype; mod dtype;
mod dummy_cuda_backend; mod dummy_cuda_backend;
mod dummy_metal_backend;
pub mod error; pub mod error;
mod indexer; mod indexer;
pub mod layout; pub mod layout;
#[cfg(feature = "metal")] #[cfg(feature = "metal")]
pub mod metal_backend; pub mod metal_backend;
#[cfg(feature = "accelerate")]
mod metal_backend;
#[cfg(feature = "mkl")] #[cfg(feature = "mkl")]
mod mkl; mod mkl;
pub mod npy; pub mod npy;
@ -91,10 +92,10 @@ pub use cuda_backend::{CudaDevice, CudaStorage};
pub use dummy_cuda_backend::{CudaDevice, CudaStorage}; pub use dummy_cuda_backend::{CudaDevice, CudaStorage};
#[cfg(feature = "metal")] #[cfg(feature = "metal")]
pub use metal_backend::{MetalDevice, MetalError, MetalStorage}; pub use metal_backend::{MetalDevice, MetalStorage};
#[cfg(not(feature = "metal"))] #[cfg(not(feature = "metal"))]
pub use dummy_metal_backend::{MetalDevice, MetalError, MetalStorage}; pub use dummy_metal_backend::{MetalDevice, MetalStorage};
#[cfg(feature = "mkl")] #[cfg(feature = "mkl")]
extern crate intel_mkl_src; extern crate intel_mkl_src;

File diff suppressed because it is too large.

View File

@ -90,16 +90,6 @@ pub enum Op {
dilation: usize, dilation: usize,
}, },
#[allow(dead_code)]
ConvTranspose1D {
arg: Tensor,
kernel: Tensor,
padding: usize,
output_padding: usize,
stride: usize,
dilation: usize,
},
#[allow(dead_code)] #[allow(dead_code)]
Conv2D { Conv2D {
arg: Tensor, arg: Tensor,
@ -593,8 +583,7 @@ unary_op!(Recip, "recip", v, v.recip());
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr); unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt); unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
/// Tanh based approximation of the `gelu` operation /// `gelu` operation
/// GeluErf is the more precise one.
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions> /// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
impl UnaryOpT for Gelu { impl UnaryOpT for Gelu {
const NAME: &'static str = "gelu"; const NAME: &'static str = "gelu";
@ -684,8 +673,6 @@ impl UnaryOpT for Gelu {
} }
} }
/// `erf` operation
/// <https://en.wikipedia.org/wiki/Error_function>
impl UnaryOpT for Erf { impl UnaryOpT for Erf {
const NAME: &'static str = "erf"; const NAME: &'static str = "erf";
const KERNEL: &'static str = "uerf"; const KERNEL: &'static str = "uerf";
@ -975,10 +962,6 @@ impl BackpropOp {
}; };
Self(op) Self(op)
} }
pub(crate) fn is_none(&self) -> bool {
self.0.is_none()
}
} }
impl std::ops::Deref for BackpropOp { impl std::ops::Deref for BackpropOp {

View File

@ -1,7 +1,7 @@
//! Support for the GGML file format. //! Support for the GGML file format.
use super::{k_quants, GgmlDType}; use super::{k_quants, GgmlDType};
use crate::Result; use crate::{Device, Result};
use byteorder::{LittleEndian, ReadBytesExt}; use byteorder::{LittleEndian, ReadBytesExt};
use std::collections::HashMap; use std::collections::HashMap;
@ -121,11 +121,12 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
raw_data: &[u8], raw_data: &[u8],
size_in_bytes: usize, size_in_bytes: usize,
dims: Vec<usize>, dims: Vec<usize>,
device: &Device,
) -> Result<super::QTensor> { ) -> Result<super::QTensor> {
let raw_data_ptr = raw_data.as_ptr(); let raw_data_ptr = raw_data.as_ptr();
let n_blocks = size_in_bytes / std::mem::size_of::<T>(); let n_blocks = size_in_bytes / std::mem::size_of::<T>();
let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) }; let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
super::QTensor::new(data.to_vec(), dims) super::QTensor::new(data.to_vec(), dims, device)
} }
/// Creates a [Tensor] from a raw GGML tensor. /// Creates a [Tensor] from a raw GGML tensor.
@ -133,6 +134,7 @@ pub fn qtensor_from_ggml(
ggml_dtype: GgmlDType, ggml_dtype: GgmlDType,
raw_data: &[u8], raw_data: &[u8],
dims: Vec<usize>, dims: Vec<usize>,
device: &Device,
) -> Result<super::QTensor> { ) -> Result<super::QTensor> {
let tensor_elems = dims.iter().product::<usize>(); let tensor_elems = dims.iter().product::<usize>();
let blck_size = ggml_dtype.blck_size(); let blck_size = ggml_dtype.blck_size();
@ -144,18 +146,38 @@ pub fn qtensor_from_ggml(
let size_in_bytes = tensor_elems / blck_size * ggml_dtype.type_size(); let size_in_bytes = tensor_elems / blck_size * ggml_dtype.type_size();
match ggml_dtype { match ggml_dtype {
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims), GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims, device),
GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims), GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims, device),
GgmlDType::Q4_0 => from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims), GgmlDType::Q4_0 => {
GgmlDType::Q4_1 => from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims), from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device)
GgmlDType::Q5_0 => from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims), }
GgmlDType::Q5_1 => from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims), GgmlDType::Q4_1 => {
GgmlDType::Q8_0 => from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims), from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims, device)
GgmlDType::Q2K => from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims), }
GgmlDType::Q3K => from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims), GgmlDType::Q5_0 => {
GgmlDType::Q4K => from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims), from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims, device)
GgmlDType::Q5K => from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims), }
GgmlDType::Q6K => from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims), GgmlDType::Q5_1 => {
from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q8_0 => {
from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q2K => {
from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q3K => {
from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q4K => {
from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q5K => {
from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q6K => {
from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims, device)
}
_ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"), _ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
} }
} }
@ -163,6 +185,7 @@ pub fn qtensor_from_ggml(
fn read_one_tensor<R: std::io::Seek + std::io::Read>( fn read_one_tensor<R: std::io::Seek + std::io::Read>(
reader: &mut R, reader: &mut R,
magic: VersionedMagic, magic: VersionedMagic,
device: &Device,
) -> Result<(String, super::QTensor)> { ) -> Result<(String, super::QTensor)> {
let n_dims = reader.read_u32::<LittleEndian>()?; let n_dims = reader.read_u32::<LittleEndian>()?;
let name_len = reader.read_u32::<LittleEndian>()?; let name_len = reader.read_u32::<LittleEndian>()?;
@ -187,7 +210,7 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
// TODO: Mmap version to avoid copying the data around? // TODO: Mmap version to avoid copying the data around?
let mut raw_data = vec![0u8; size_in_bytes]; let mut raw_data = vec![0u8; size_in_bytes];
reader.read_exact(&mut raw_data)?; reader.read_exact(&mut raw_data)?;
match qtensor_from_ggml(ggml_dtype, &raw_data, dims) { match qtensor_from_ggml(ggml_dtype, &raw_data, dims, device) {
Ok(tensor) => Ok((name, tensor)), Ok(tensor) => Ok((name, tensor)),
Err(e) => crate::bail!("Error creating tensor {name}: {e}"), Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
} }
@ -201,7 +224,10 @@ pub struct Content {
} }
impl Content { impl Content {
pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Content> { pub fn read<R: std::io::Seek + std::io::Read>(
reader: &mut R,
device: &Device,
) -> Result<Content> {
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505 // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
let last_position = reader.seek(std::io::SeekFrom::End(0))?; let last_position = reader.seek(std::io::SeekFrom::End(0))?;
reader.seek(std::io::SeekFrom::Start(0))?; reader.seek(std::io::SeekFrom::Start(0))?;
@ -211,7 +237,7 @@ impl Content {
let mut tensors = HashMap::new(); let mut tensors = HashMap::new();
while reader.stream_position()? != last_position { while reader.stream_position()? != last_position {
let (name, tensor) = read_one_tensor(reader, magic)?; let (name, tensor) = read_one_tensor(reader, magic, device)?;
tensors.insert(name, tensor); tensors.insert(name, tensor);
} }
Ok(Self { Ok(Self {

View File

@ -3,7 +3,7 @@
//! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md //! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md
use super::{GgmlDType, QTensor}; use super::{GgmlDType, QTensor};
use crate::Result; use crate::{Device, Result};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::collections::HashMap; use std::collections::HashMap;
@ -29,7 +29,6 @@ impl TryFrom<u32> for Magic {
pub enum VersionedMagic { pub enum VersionedMagic {
GgufV1, GgufV1,
GgufV2, GgufV2,
GgufV3,
} }
impl VersionedMagic { impl VersionedMagic {
@ -40,7 +39,6 @@ impl VersionedMagic {
let versioned_magic = match (magic, version) { let versioned_magic = match (magic, version) {
(Magic::Gguf, 1) => Self::GgufV1, (Magic::Gguf, 1) => Self::GgufV1,
(Magic::Gguf, 2) => Self::GgufV2, (Magic::Gguf, 2) => Self::GgufV2,
(Magic::Gguf, 3) => Self::GgufV3,
_ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"), _ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"),
}; };
Ok(versioned_magic) Ok(versioned_magic)
@ -59,6 +57,7 @@ impl TensorInfo {
&self, &self,
reader: &mut R, reader: &mut R,
tensor_data_offset: u64, tensor_data_offset: u64,
device: &Device,
) -> Result<QTensor> { ) -> Result<QTensor> {
let tensor_elems = self.shape.elem_count(); let tensor_elems = self.shape.elem_count();
let blck_size = self.ggml_dtype.blck_size(); let blck_size = self.ggml_dtype.blck_size();
@ -71,7 +70,12 @@ impl TensorInfo {
let mut raw_data = vec![0u8; size_in_bytes]; let mut raw_data = vec![0u8; size_in_bytes];
reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?; reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
reader.read_exact(&mut raw_data)?; reader.read_exact(&mut raw_data)?;
super::ggml_file::qtensor_from_ggml(self.ggml_dtype, &raw_data, self.shape.dims().to_vec()) super::ggml_file::qtensor_from_ggml(
self.ggml_dtype,
&raw_data,
self.shape.dims().to_vec(),
device,
)
} }
} }
@ -86,9 +90,7 @@ pub struct Content {
fn read_string<R: std::io::Read>(reader: &mut R, magic: &VersionedMagic) -> Result<String> { fn read_string<R: std::io::Read>(reader: &mut R, magic: &VersionedMagic) -> Result<String> {
let len = match magic { let len = match magic {
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize, VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => { VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
reader.read_u64::<LittleEndian>()? as usize
}
}; };
let mut v = vec![0u8; len]; let mut v = vec![0u8; len];
reader.read_exact(&mut v)?; reader.read_exact(&mut v)?;
@ -288,9 +290,7 @@ impl Value {
let value_type = ValueType::from_u32(value_type)?; let value_type = ValueType::from_u32(value_type)?;
let len = match magic { let len = match magic {
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize, VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => { VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
reader.read_u64::<LittleEndian>()? as usize
}
}; };
let mut vs = Vec::with_capacity(len); let mut vs = Vec::with_capacity(len);
for _ in 0..len { for _ in 0..len {
@ -387,15 +387,11 @@ impl Content {
let tensor_count = match magic { let tensor_count = match magic {
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize, VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => { VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
reader.read_u64::<LittleEndian>()? as usize
}
}; };
let metadata_kv_count = match magic { let metadata_kv_count = match magic {
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize, VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => { VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
reader.read_u64::<LittleEndian>()? as usize
}
}; };
let mut metadata = HashMap::new(); let mut metadata = HashMap::new();
@ -417,7 +413,7 @@ impl Content {
reader.read_u32_into::<LittleEndian>(&mut dimensions)?; reader.read_u32_into::<LittleEndian>(&mut dimensions)?;
dimensions.into_iter().map(|c| c as usize).collect() dimensions.into_iter().map(|c| c as usize).collect()
} }
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => { VersionedMagic::GgufV2 => {
let mut dimensions = vec![0; n_dimensions as usize]; let mut dimensions = vec![0; n_dimensions as usize];
reader.read_u64_into::<LittleEndian>(&mut dimensions)?; reader.read_u64_into::<LittleEndian>(&mut dimensions)?;
dimensions.into_iter().map(|c| c as usize).collect() dimensions.into_iter().map(|c| c as usize).collect()
@ -460,12 +456,13 @@ impl Content {
&self, &self,
reader: &mut R, reader: &mut R,
name: &str, name: &str,
device: &Device,
) -> Result<QTensor> { ) -> Result<QTensor> {
let tensor_info = match self.tensor_infos.get(name) { let tensor_info = match self.tensor_infos.get(name) {
Some(tensor_info) => tensor_info, Some(tensor_info) => tensor_info,
None => crate::bail!("cannot find tensor-infor for {name}"), None => crate::bail!("cannot find tensor-infor for {name}"),
}; };
tensor_info.read(reader, self.tensor_data_offset) tensor_info.read(reader, self.tensor_data_offset, device)
} }
} }

View File

@ -14,6 +14,7 @@ pub mod utils;
pub use k_quants::GgmlType; pub use k_quants::GgmlType;
pub struct QTensor { pub struct QTensor {
device: Device,
data: Box<dyn QuantizedType>, data: Box<dyn QuantizedType>,
shape: Shape, shape: Shape,
} }
@ -170,17 +171,20 @@ impl QTensor {
pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>( pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
data: Vec<T>, data: Vec<T>,
shape: S, shape: S,
device: &Device,
) -> Result<Self> { ) -> Result<Self> {
let shape = shape.into(); let shape = shape.into();
check_shape::<T>(&shape)?; check_shape::<T>(&shape)?;
Ok(Self { Ok(Self {
data: Box::new(data), data: Box::new(data),
shape, shape,
device: device.clone(),
}) })
} }
pub fn quantize<T: k_quants::GgmlType + Send + Sync + 'static>(src: &Tensor) -> Result<Self> { pub fn quantize<T: k_quants::GgmlType + Send + Sync + 'static>(src: &Tensor) -> Result<Self> {
let shape = src.shape(); let shape = src.shape();
let device = src.device();
check_shape::<T>(shape)?; check_shape::<T>(shape)?;
let src = src let src = src
.to_dtype(crate::DType::F32)? .to_dtype(crate::DType::F32)?
@ -197,6 +201,7 @@ impl QTensor {
Ok(Self { Ok(Self {
data: Box::new(data), data: Box::new(data),
shape: shape.clone(), shape: shape.clone(),
device: device.clone(),
}) })
} }
@ -212,7 +217,12 @@ impl QTensor {
&self.shape &self.shape
} }
pub fn device(&self) -> &Device {
&self.device
}
pub fn dequantize(&self, device: &Device) -> Result<Tensor> { pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
// TODO Skip the CPU part on metal
let mut f32_data = vec![0f32; self.shape.elem_count()]; let mut f32_data = vec![0f32; self.shape.elem_count()];
self.data.to_float(&mut f32_data)?; self.data.to_float(&mut f32_data)?;
Tensor::from_vec(f32_data, &self.shape, device) Tensor::from_vec(f32_data, &self.shape, device)
@ -305,6 +315,49 @@ impl crate::CustomOp1 for QTensor {
)?; )?;
Ok((crate::CpuStorage::F32(dst_storage), dst_shape)) Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
} }
fn metal_fwd(
&self,
storage: &crate::MetalStorage,
layout: &crate::Layout,
) -> Result<(crate::MetalStorage, Shape)> {
println!("TODO qmatmul");
if !layout.is_contiguous() {
crate::bail!("input tensor is not contiguous {layout:?}")
}
let src_shape = layout.shape();
// self is transposed so n is first then k.
let (n, k) = self.shape.dims2()?;
if src_shape.rank() < 2 {
crate::bail!("input tensor has only one dimension {layout:?}")
}
let mut dst_shape = src_shape.dims().to_vec();
let last_k = dst_shape.pop().unwrap();
if last_k != k {
crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
}
dst_shape.push(n);
let dst_shape = Shape::from(dst_shape);
// let storage = storage.as_slice::<f32>()?;
// let storage =
// &storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
let dst_storage = vec![0f32; dst_shape.elem_count()];
// self.matmul_t(
// (dst_shape.elem_count() / n, k, n),
// storage,
// &mut dst_storage,
// )?;
let cpu_storage = crate::CpuStorage::F32(dst_storage);
use crate::backend::{BackendDevice, BackendStorage};
if let Device::Metal(device) = &self.device{
Ok((
device.storage_from_cpu_storage(&cpu_storage)?,
dst_shape,
))
}else{
crate::bail!("qtensor not on metal device")
}
}
} }
impl QMatMul { impl QMatMul {

View File

@ -334,33 +334,6 @@ impl Storage {
} }
} }
pub(crate) fn conv_transpose1d(
&self,
l: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
self.same_device(kernel, "conv-transpose1d")?;
self.same_dtype(kernel, "conv-transpose1d")?;
match (self, &kernel) {
(Storage::Cpu(inp), Storage::Cpu(kernel)) => {
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
Ok(Self::Cpu(s))
}
(Storage::Cuda(inp), Storage::Cuda(kernel)) => {
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
Ok(Self::Cuda(s))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
op: "conv-transpose1d",
}
.bt()),
}
}
pub(crate) fn conv2d( pub(crate) fn conv2d(
&self, &self,
l: &Layout, l: &Layout,

View File

@ -477,12 +477,6 @@ impl Tensor {
broadcast_binary_op!(broadcast_div, div); broadcast_binary_op!(broadcast_div, div);
broadcast_binary_op!(broadcast_maximum, maximum); broadcast_binary_op!(broadcast_maximum, maximum);
broadcast_binary_op!(broadcast_minimum, minimum); broadcast_binary_op!(broadcast_minimum, minimum);
broadcast_binary_op!(broadcast_eq, eq);
broadcast_binary_op!(broadcast_ne, ne);
broadcast_binary_op!(broadcast_lt, lt);
broadcast_binary_op!(broadcast_le, le);
broadcast_binary_op!(broadcast_gt, gt);
broadcast_binary_op!(broadcast_ge, ge);
unary_op!(recip, Recip); unary_op!(recip, Recip);
unary_op!(neg, Neg); unary_op!(neg, Neg);
@ -856,20 +850,6 @@ impl Tensor {
self.sum_impl(mean_dims, false)? * scale self.sum_impl(mean_dims, false)? * scale
} }
/// Returns the unbiased variance over the selected dimension.
pub fn var_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "var")?;
let mean = self.mean_keepdim(dim)?;
let squares = self.broadcast_sub(&mean)?.sqr()?;
squares.sum_impl(dim, true)? / (self.dim(dim)? - 1) as f64
}
/// Returns the unbiased variance over the selected dimension.
pub fn var<D: Dim>(&self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "var")?;
self.var_keepdim(dim)?.squeeze(dim)
}
/// Gathers the maximum value across the selected dimension. The resulting shape has the same /// Gathers the maximum value across the selected dimension. The resulting shape has the same
/// number of dimensions as the original tensor and the select dimension has a single element. /// number of dimensions as the original tensor and the select dimension has a single element.
pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> { pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
@ -1831,23 +1811,17 @@ impl Tensor {
/// Returns a new tensor detached from the current graph, gradient are not propagated through /// Returns a new tensor detached from the current graph, gradient are not propagated through
/// this new node. The storage of this tensor is shared with the initial tensor. /// this new node. The storage of this tensor is shared with the initial tensor.
///
/// If the tensor is already detached from the computation graph, the same tensor is returned.
pub fn detach(&self) -> Result<Tensor> { pub fn detach(&self) -> Result<Tensor> {
if self.op.is_none() && !self.is_variable { let tensor_ = Tensor_ {
Ok(self.clone()) id: TensorId::new(),
} else { storage: self.storage.clone(),
let tensor_ = Tensor_ { layout: self.layout.clone(),
id: TensorId::new(), op: BackpropOp::none(),
storage: self.storage.clone(), is_variable: false,
layout: self.layout.clone(), dtype: self.dtype,
op: BackpropOp::none(), device: self.device.clone(),
is_variable: false, };
dtype: self.dtype, Ok(Tensor(Arc::new(tensor_)))
device: self.device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
}
} }
/// If the target device is the same as the tensor device, only a shallow copy is performed. /// If the target device is the same as the tensor device, only a shallow copy is performed.
@ -1859,14 +1833,7 @@ impl Tensor {
(Storage::Cpu(storage), Device::Cuda(cuda)) => { (Storage::Cpu(storage), Device::Cuda(cuda)) => {
Storage::Cuda(cuda.storage_from_cpu_storage(storage)?) Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
} }
(Storage::Cpu(storage), Device::Metal(metal)) => {
Storage::Metal(metal.storage_from_cpu_storage(storage)?)
}
(Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?), (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
(Storage::Metal(storage), Device::Cpu) => {
println!("{storage:?} - {:?}", storage.to_cpu_storage()?);
Storage::Cpu(storage.to_cpu_storage()?)
}
(Storage::Cuda(storage), Device::Cuda(cuda)) => { (Storage::Cuda(storage), Device::Cuda(cuda)) => {
// TODO: Avoid passing through the cpu storage here, especially if the gpu ids // TODO: Avoid passing through the cpu storage here, especially if the gpu ids
// are the same. // are the same.
@ -2440,23 +2407,6 @@ impl Tensor {
) -> Result<Self> { ) -> Result<Self> {
self.apply_op3_arc(t2, t3, Arc::new(Box::new(c))) self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
} }
/// Normalize a 'relative' axis value: positive values are kept, negative
/// values means counting the dimensions from the back.
pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
let rank = self.rank() as i64;
if rank <= axis {
crate::bail!("axis {axis} is too large, tensor rank {rank}")
} else if 0 <= axis {
Ok(axis as usize)
} else {
let naxis = rank + axis;
if naxis < 0 {
crate::bail!("axis {axis} is too small, tensor rank {rank}")
}
Ok(naxis as usize)
}
}
} }
macro_rules! bin_trait { macro_rules! bin_trait {
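
The removed `var_keepdim` computes the unbiased (sample) variance, sum((x - mean)^2) / (n - 1), and the removed `var` test in the tensor tests further down checks it against PyTorch values. A minimal standalone version of the same computation, checked against the first row of that test data:

```rust
// Unbiased (sample) variance of a slice: sum((x - mean)^2) / (n - 1),
// matching the removed Tensor::var_keepdim.
fn var_unbiased(xs: &[f32]) -> f32 {
    let n = xs.len() as f32;
    let mean = xs.iter().sum::<f32>() / n;
    xs.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / (n - 1.0)
}

fn main() {
    // First row of the data in the removed `var` test; PyTorch reports 1.0631.
    let row = [0.2035f32, 1.2959, 1.8101, -0.4644];
    println!("{:.4}", var_unbiased(&row)); // ~1.0631
}
```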

View File

@ -4,7 +4,7 @@ use crate::{Result, Tensor};
macro_rules! test_device { macro_rules! test_device {
// TODO: Switch to generating the two last arguments automatically once concat_idents is // TODO: Switch to generating the two last arguments automatically once concat_idents is
// stable. https://github.com/rust-lang/rust/issues/29599 // stable. https://github.com/rust-lang/rust/issues/29599
($fn_name: ident, $test_cpu: ident, $test_cuda: ident, $test_metal: ident) => { ($fn_name: ident, $test_cpu: ident, $test_cuda: ident) => {
#[test] #[test]
fn $test_cpu() -> Result<()> { fn $test_cpu() -> Result<()> {
$fn_name(&Device::Cpu) $fn_name(&Device::Cpu)
@ -15,12 +15,6 @@ macro_rules! test_device {
fn $test_cuda() -> Result<()> { fn $test_cuda() -> Result<()> {
$fn_name(&Device::new_cuda(0)?) $fn_name(&Device::new_cuda(0)?)
} }
#[cfg(feature = "metal")]
#[test]
fn $test_metal() -> Result<()> {
$fn_name(&Device::new_metal(0)?)
}
}; };
} }

View File

@ -13,11 +13,6 @@ res = torch.nn.functional.conv1d(t, w)
print(res.flatten()) print(res.flatten())
res = torch.nn.functional.conv1d(t, w, padding=1) res = torch.nn.functional.conv1d(t, w, padding=1)
print(res.flatten()) print(res.flatten())
w_t = w.transpose(0, 1)
res = torch.nn.functional.conv_transpose1d(t, w_t)
print(res.shape)
print(res)
*/ */
fn conv1d(dev: &Device) -> Result<()> { fn conv1d(dev: &Device) -> Result<()> {
let t = Tensor::new( let t = Tensor::new(
@ -50,17 +45,6 @@ fn conv1d(dev: &Device) -> Result<()> {
test_utils::to_vec1_round(&res.flatten_all()?, 4)?, test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352] [2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
); );
if dev.is_cpu() {
let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 7]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[
0.0699, -1.2899, 8.3018, 5.5873, 2.4572, -2.6143, -0.0706, 1.8765, 4.8318, 1.1538,
4.7076, -5.9745, -0.8276, 1.621
],
);
}
Ok(()) Ok(())
} }
@ -563,35 +547,14 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
Ok(()) Ok(())
} }
test_device!(conv1d, conv1d_cpu, conv1d_gpu, conv1d_metal); test_device!(conv1d, conv1d_cpu, conv1d_gpu);
test_device!( test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu);
conv1d_small, test_device!(conv2d, conv2d_cpu, conv2d_gpu);
conv1d_small_cpu,
conv1d_small_gpu,
conv1d_small_metal
);
test_device!(conv2d, conv2d_cpu, conv2d_gpu, conv2d_metal);
test_device!( test_device!(
conv2d_non_square, conv2d_non_square,
conv2d_non_square_cpu, conv2d_non_square_cpu,
conv2d_non_square_gpu, conv2d_non_square_gpu
conv2d_non_square_metal
);
test_device!(
conv2d_small,
conv2d_small_cpu,
conv2d_small_gpu,
conv2d_small_metal
);
test_device!(
conv2d_smaller,
conv2d_smaller_cpu,
conv2d_smaller_gpu,
conv2d_smaller_metal
);
test_device!(
conv2d_grad,
conv2d_grad_cpu,
conv2d_grad_gpu,
conv2_grad_metal
); );
test_device!(conv2d_small, conv2d_small_cpu, conv2d_small_gpu);
test_device!(conv2d_smaller, conv2d_smaller_cpu, conv2d_smaller_gpu);
test_device!(conv2d_grad, conv2d_grad_cpu, conv2d_grad_gpu);

View File

@ -205,71 +205,6 @@ fn unary_grad(device: &Device) -> Result<()> {
test_utils::to_vec1_round(grad_x, 4)?, test_utils::to_vec1_round(grad_x, 4)?,
[1.0116, 1.0830, 1.0003, 0.6188], [1.0116, 1.0830, 1.0003, 0.6188],
); );
// Testing compared to pytorch torch.erf
//
// import torch
// x = torch.tensor([3.0, 1.0, 4.0, 0.15], requires_grad=True)
// y = x.erf()
// print(y)
// loss = y.sum()
// loss.backward()
// print(x.grad)
let y = x.erf()?;
let grads = y.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(test_utils::to_vec1_round(&y, 4)?, [1.0, 0.8427, 1.0, 0.168]);
assert_eq!(
test_utils::to_vec1_round(grad_x, 4)?,
[0.0001, 0.4151, 0.0, 1.1033],
);
// Testing compared to pytorch nn.GELU(approximate = 'none')
//
// import torch
// import torch.nn.functional as F
// x = torch.tensor([3.0, 1.0, 4.0, 0.15], requires_grad=True)
// y = F.gelu(x, approximate='none')
// print(y)
// loss = y.sum()
// loss.backward()
// print(x.grad)
let y = x.gelu_erf()?;
let grads = y.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec1_round(&y, 4)?,
[2.9960, 0.8413, 3.9999, 0.0839]
);
assert_eq!(
test_utils::to_vec1_round(grad_x, 4)?,
[1.0119, 1.0833, 1.0005, 0.6188],
);
// Testing compared to pytorch elu
//
// import torch
// import torch.nn.functional as F
// x = torch.tensor([-1.0, 0.0, -2.0, 3.0], requires_grad=True)
// y = F.elu(x, alpha=2.0)
// print(y)
// loss = y.min
// loss = y.sum()
// loss.backward()
// print(x.grad)
let elu_x = Var::new(&[-1.0f32, 0., -2., 3.], device)?;
let y = elu_x.elu(2.)?;
let grads = y.backward()?;
let grad_x = grads.get(&elu_x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec1_round(&y, 4)?,
[-1.2642, 0.0000, -1.7293, 3.0000]
);
assert_eq!(
test_utils::to_vec1_round(grad_x, 4)?,
[0.7358, 2.0000, 0.2707, 1.0000]
);
Ok(()) Ok(())
} }
@ -315,29 +250,9 @@ fn binary_grad(device: &Device) -> Result<()> {
Ok(()) Ok(())
} }
test_device!( test_device!(simple_grad, simple_grad_cpu, simple_grad_gpu);
simple_grad, test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu);
simple_grad_cpu, test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu);
simple_grad_gpu, test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu);
simple_grad_metal test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu);
); test_device!(binary_grad, binary_grad_cpu, binary_grad_gpu);
test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu, sum_grad_metal);
test_device!(
matmul_grad,
matmul_grad_cpu,
matmul_grad_gpu,
matmul_grad_metal
);
test_device!(
grad_descent,
grad_descent_cpu,
grad_descent_gpu,
grad_descent_metal
);
test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu, unary_grad_metal);
test_device!(
binary_grad,
binary_grad_cpu,
binary_grad_gpu,
binary_grad_metal
);

View File

@ -49,7 +49,7 @@ fn contiguous(device: &Device) -> Result<()> {
Ok(()) Ok(())
} }
test_device!(contiguous, contiguous_cpu, contiguous_gpu, contiguous_metal); test_device!(contiguous, contiguous_cpu, contiguous_gpu);
#[test] #[test]
fn strided_blocks() -> Result<()> { fn strided_blocks() -> Result<()> {

View File

@ -98,17 +98,15 @@ fn upsample_nearest2d(dev: &Device) -> Result<()> {
Ok(()) Ok(())
} }
test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu, avg_pool2d_metal); test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu);
test_device!( test_device!(
avg_pool2d_pytorch, avg_pool2d_pytorch,
avg_pool2d_pytorch_cpu, avg_pool2d_pytorch_cpu,
avg_pool2d_pytorch_gpu, avg_pool2d_pytorch_gpu
avg_pool2d_pytorch_metal
); );
test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu, max_pool2d_metal); test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu);
test_device!( test_device!(
upsample_nearest2d, upsample_nearest2d,
upsample_nearest2d_cpu, upsample_nearest2d_cpu,
upsample_nearest2d_gpu, upsample_nearest2d_gpu
upsample_nearest2d_metal
); );

View File

@ -180,22 +180,6 @@ fn transpose(device: &Device) -> Result<()> {
Ok(()) Ok(())
} }
fn var(device: &Device) -> Result<()> {
// Values taken from https://pytorch.org/docs/stable/generated/torch.var.html
let data = &[
[0.2035f32, 1.2959, 1.8101, -0.4644],
[1.5027, -0.3270, 0.5905, 0.6538],
[-1.5745, 1.3330, -0.5596, -0.6548],
[0.1264, -0.5080, 1.6420, 0.1992],
];
let tensor = Tensor::new(data, device)?;
assert_eq!(
test_utils::to_vec2_round(&tensor.var_keepdim(1)?, 4)?,
&[[1.0631], [0.559], [1.4893], [0.8258]]
);
Ok(())
}
fn sum(device: &Device) -> Result<()> { fn sum(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]]; let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, device)?; let tensor = Tensor::new(data, device)?;
@ -1070,60 +1054,34 @@ fn randn(device: &Device) -> Result<()> {
Ok(()) Ok(())
} }
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal); test_device!(zeros, zeros_cpu, zeros_gpu);
test_device!(ones, ones_cpu, ones_gpu, ones_metal); test_device!(ones, ones_cpu, ones_gpu);
test_device!(arange, arange_cpu, arange_gpu, arange_metal); test_device!(arange, arange_cpu, arange_gpu);
test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal); test_device!(add_mul, add_mul_cpu, add_mul_gpu);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal); test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal); test_device!(narrow, narrow_cpu, narrow_gpu);
test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal); test_device!(broadcast, broadcast_cpu, broadcast_gpu);
test_device!(cat, cat_cpu, cat_gpu, cat_metal); test_device!(cat, cat_cpu, cat_gpu);
test_device!(sum, sum_cpu, sum_gpu, sum_metal); test_device!(sum, sum_cpu, sum_gpu);
test_device!(min, min_cpu, min_gpu, min_metal); test_device!(min, min_cpu, min_gpu);
test_device!(max, max_cpu, max_gpu, max_metal); test_device!(max, max_cpu, max_gpu);
test_device!(argmax, argmax_cpu, argmax_gpu, argmax_metal); test_device!(argmax, argmax_cpu, argmax_gpu);
test_device!(argmin, argmin_cpu, argmin_gpu, argmin_metal); test_device!(argmin, argmin_cpu, argmin_gpu);
test_device!(transpose, transpose_cpu, transpose_gpu, transpose_metal); test_device!(transpose, transpose_cpu, transpose_gpu);
test_device!(unary_op, unary_op_cpu, unary_op_gpu, unary_op_metal); test_device!(unary_op, unary_op_cpu, unary_op_gpu);
test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal); test_device!(binary_op, binary_op_cpu, binary_op_gpu);
test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal); test_device!(embeddings, embeddings_cpu, embeddings_gpu);
test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal); test_device!(cmp, cmp_cpu, cmp_gpu);
test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal); test_device!(matmul, matmul_cpu, matmul_gpu);
test_device!( test_device!(broadcast_matmul, broadcast_matmul_cpu, broadcast_matmul_gpu);
broadcast_matmul, test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu);
broadcast_matmul_cpu, test_device!(index_select, index_select_cpu, index_select_gpu);
broadcast_matmul_gpu, test_device!(index_add, index_add_cpu, index_add_gpu);
broadcast_matmul_metal test_device!(gather, gather_cpu, gather_gpu);
); test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
test_device!( test_device!(slice_scatter, slice_scatter_cpu, slice_scatter_gpu);
broadcasting, test_device!(randn, randn_cpu, randn_gpu);
broadcasting_cpu, test_device!(clamp, clamp_cpu, clamp_gpu);
broadcasting_gpu,
broadcasting_metal
);
test_device!(
index_select,
index_select_cpu,
index_select_gpu,
index_select_metal
);
test_device!(index_add, index_add_cpu, index_add_gpu, index_add_metal);
test_device!(gather, gather_cpu, gather_gpu, gather_metal);
test_device!(
scatter_add,
scatter_add_cpu,
scatter_add_gpu,
scatter_add_metal
);
test_device!(
slice_scatter,
slice_scatter_cpu,
slice_scatter_gpu,
slice_scatter_metal
);
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
test_device!(var, var_cpu, var_gpu, var_metal);
// There was originally a bug on the CPU implementation for randn // There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381 // https://github.com/huggingface/candle/issues/381

View File

@ -4,9 +4,7 @@
//! <https://www.cs.toronto.edu/~kriz/cifar.html> //! <https://www.cs.toronto.edu/~kriz/cifar.html>
//! The binary version of the dataset is used. //! The binary version of the dataset is used.
use crate::vision::Dataset; use crate::vision::Dataset;
use candle::{DType, Device, Error, Result, Tensor}; use candle::{DType, Device, Result, Tensor};
use hf_hub::{api::sync::Api, Repo, RepoType};
use parquet::file::reader::{FileReader, SerializedFileReader};
use std::fs::File; use std::fs::File;
use std::io::{BufReader, Read}; use std::io::{BufReader, Read};
@ -62,58 +60,3 @@ pub fn load_dir<T: AsRef<std::path::Path>>(dir: T) -> Result<Dataset> {
labels: 10, labels: 10,
}) })
} }
fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor, Tensor)> {
let samples = parquet.metadata().file_metadata().num_rows() as usize;
let mut buffer_images: Vec<u8> = Vec::with_capacity(samples * 1_024);
let mut buffer_labels: Vec<u8> = Vec::with_capacity(samples);
for row in parquet.into_iter().flatten() {
for (_name, field) in row.get_column_iter() {
if let parquet::record::Field::Group(subrow) = field {
for (_name, field) in subrow.get_column_iter() {
if let parquet::record::Field::Bytes(value) = field {
let image = image::load_from_memory(value.data()).unwrap();
buffer_images.extend(image.to_rgb8().as_raw());
}
}
} else if let parquet::record::Field::Long(label) = field {
buffer_labels.push(*label as u8);
}
}
}
let images = (Tensor::from_vec(buffer_images, (samples, 3, 32, 32), &Device::Cpu)?
.to_dtype(DType::U8)?
/ 255.)?;
let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
Ok((images, labels))
}
pub fn load() -> Result<Dataset> {
let api = Api::new().map_err(|e| Error::Msg(format!("Api error: {e}")))?;
let dataset_id = "cifar10".to_string();
let repo = Repo::with_revision(
dataset_id,
RepoType::Dataset,
"refs/convert/parquet".to_string(),
);
let repo = api.repo(repo);
let test_parquet_filename = repo
.get("plain_text/test/0000.parquet")
.map_err(|e| Error::Msg(format!("Api error: {e}")))?;
let train_parquet_filename = repo
.get("plain_text/train/0000.parquet")
.map_err(|e| Error::Msg(format!("Api error: {e}")))?;
let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)
.map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)
.map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
let (test_images, test_labels) = load_parquet(test_parquet)?;
let (train_images, train_labels) = load_parquet(train_parquet)?;
Ok(crate::vision::Dataset {
train_images,
train_labels,
test_images,
test_labels,
labels: 10,
})
}

View File

@ -16,7 +16,6 @@ candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
candle-nn = { path = "../candle-nn", version = "0.3.0" } candle-nn = { path = "../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../candle-transformers", version = "0.3.0" } candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true } candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
candle-onnx = { path = "../candle-onnx", version = "0.3.0", optional = true }
cudarc = { workspace = true, optional = true } cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true } half = { workspace = true, optional = true }
image = { workspace = true } image = { workspace = true }
@ -52,12 +51,11 @@ anyhow = { workspace = true }
default = [] default = []
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"] accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"] cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
metal = ["candle/metal", "candle-nn/metal", "candle-transformers/metal"]
cudnn = ["candle/cudnn"] cudnn = ["candle/cudnn"]
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"] flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"] mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
nccl = ["cuda", "cudarc/nccl", "dep:half"] nccl = ["cuda", "cudarc/nccl", "dep:half"]
onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"]
[[example]] [[example]]
name = "llama_multiprocess" name = "llama_multiprocess"
@ -66,11 +64,3 @@ required-features = ["cuda", "nccl", "flash-attn"]
[[example]] [[example]]
name = "reinforcement-learning" name = "reinforcement-learning"
required-features = ["pyo3"] required-features = ["pyo3"]
[[example]]
name = "onnx"
required-features = ["onnx"]
[[example]]
name = "onnx_basics"
required-features = ["onnx"]

View File

@ -8,7 +8,6 @@ use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder};
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
enum NormType { enum NormType {
WeightNorm, WeightNorm,
TimeGroupNorm,
None, None,
} }
@ -269,7 +268,6 @@ impl Module for EncodecConvTranspose1d {
struct EncodecConv1d { struct EncodecConv1d {
causal: bool, causal: bool,
conv: Conv1d, conv: Conv1d,
norm: Option<candle_nn::GroupNorm>,
} }
impl EncodecConv1d { impl EncodecConv1d {
@ -294,7 +292,7 @@ impl EncodecConv1d {
}, },
vb.pp("conv"), vb.pp("conv"),
)?, )?,
NormType::None | NormType::TimeGroupNorm => conv1d( NormType::None => conv1d(
in_c, in_c,
out_c, out_c,
kernel_size, kernel_size,
@ -307,17 +305,9 @@ impl EncodecConv1d {
vb.pp("conv"), vb.pp("conv"),
)?, )?,
}; };
let norm = match cfg.norm_type {
NormType::None | NormType::WeightNorm => None,
NormType::TimeGroupNorm => {
let gn = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
Some(gn)
}
};
Ok(Self { Ok(Self {
causal: cfg.use_causal_conv, causal: cfg.use_causal_conv,
conv, conv,
norm,
}) })
} }
} }
@ -326,10 +316,8 @@ impl Module for EncodecConv1d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> { fn forward(&self, xs: &Tensor) -> Result<Tensor> {
// TODO: padding, depending on causal. // TODO: padding, depending on causal.
let xs = self.conv.forward(xs)?; let xs = self.conv.forward(xs)?;
match &self.norm { // If we add support for NormType "time_group_norm", we should add some normalization here.
None => Ok(xs), Ok(xs)
Some(norm) => xs.apply(norm),
}
} }
} }

View File

@ -1,10 +0,0 @@
## Using ONNX models in Candle
This example demonstrates how to run ONNX-based models in Candle; the model
used here is a small SqueezeNet variant.
You can run the example with the following command:
```bash
cargo run --example squeezenet-onnx --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
```

View File

@ -1,78 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle::{IndexOp, D};
use clap::{Parser, ValueEnum};
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
SqueezeNet,
EfficientNet,
}
#[derive(Parser)]
struct Args {
#[arg(long)]
image: String,
#[arg(long)]
model: Option<String>,
/// The model to be used.
#[arg(value_enum, long, default_value_t = Which::SqueezeNet)]
which: Which,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let image = candle_examples::imagenet::load_image224(args.image)?;
let image = match args.which {
Which::SqueezeNet => image,
Which::EfficientNet => image.permute((1, 2, 0))?,
};
println!("loaded image {image:?}");
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
None => match args.which {
Which::SqueezeNet => hf_hub::api::sync::Api::new()?
.model("lmz/candle-onnx".into())
.get("squeezenet1.1-7.onnx")?,
Which::EfficientNet => hf_hub::api::sync::Api::new()?
.model("onnx/EfficientNet-Lite4".into())
.get("efficientnet-lite4-11.onnx")?,
},
};
let model = candle_onnx::read_file(model)?;
let graph = model.graph.as_ref().unwrap();
let mut inputs = std::collections::HashMap::new();
inputs.insert(graph.input[0].name.to_string(), image.unsqueeze(0)?);
let mut outputs = candle_onnx::simple_eval(&model, inputs)?;
let output = outputs.remove(&graph.output[0].name).unwrap();
let prs = match args.which {
Which::SqueezeNet => candle_nn::ops::softmax(&output, D::Minus1)?,
Which::EfficientNet => output,
};
let prs = prs.i(0)?.to_vec1::<f32>()?;
// Sort the predictions and take the top 5
let mut top: Vec<_> = prs.iter().enumerate().collect();
top.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
let top = top.into_iter().take(5).collect::<Vec<_>>();
// Print the top predictions
for &(i, p) in &top {
println!(
"{:50}: {:.2}%",
candle_examples::imagenet::CLASSES[i],
p * 100.0
);
}
Ok(())
}

View File

@ -1,87 +0,0 @@
use anyhow::Result;
use candle::{Device, Tensor};
use clap::{Parser, Subcommand};
#[derive(Subcommand, Debug, Clone)]
enum Command {
Print {
#[arg(long)]
file: String,
},
SimpleEval {
#[arg(long)]
file: String,
},
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
#[command(subcommand)]
command: Command,
}
pub fn main() -> Result<()> {
let args = Args::parse();
match args.command {
Command::Print { file } => {
let model = candle_onnx::read_file(file)?;
println!("{model:?}");
let graph = model.graph.unwrap();
for node in graph.node.iter() {
println!("{node:?}");
}
}
Command::SimpleEval { file } => {
let model = candle_onnx::read_file(file)?;
let graph = model.graph.as_ref().unwrap();
let constants: std::collections::HashSet<_> =
graph.initializer.iter().map(|i| i.name.as_str()).collect();
let mut inputs = std::collections::HashMap::new();
for input in graph.input.iter() {
use candle_onnx::onnx::tensor_proto::DataType;
if constants.contains(input.name.as_str()) {
continue;
}
let type_ = input.r#type.as_ref().expect("no type for input");
let type_ = type_.value.as_ref().expect("no type.value for input");
let value = match type_ {
candle_onnx::onnx::type_proto::Value::TensorType(tt) => {
let dt = match DataType::try_from(tt.elem_type) {
Ok(dt) => match candle_onnx::dtype(dt) {
Some(dt) => dt,
None => {
anyhow::bail!(
"unsupported 'value' data-type {dt:?} for {}",
input.name
)
}
},
type_ => anyhow::bail!("unsupported input type {type_:?}"),
};
let shape = tt.shape.as_ref().expect("no tensortype.shape for input");
let dims = shape
.dim
.iter()
.map(|dim| match dim.value.as_ref().expect("no dim value") {
candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize),
candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimParam(_) => Ok(42),
})
.collect::<Result<Vec<usize>>>()?;
Tensor::zeros(dims, dt, &Device::Cpu)?
}
type_ => anyhow::bail!("unsupported input type {type_:?}"),
};
println!("input {}: {value:?}", input.name);
inputs.insert(input.name.clone(), value);
}
let outputs = candle_onnx::simple_eval(&model, inputs)?;
for (name, value) in outputs.iter() {
println!("output {name}: {value:?}")
}
}
}
Ok(())
}

View File

@ -1,7 +1,5 @@
# candle-quantized-t5 # candle-quantized-t5
## Seq2Seq example
This example uses a quantized version of the t5 model. This example uses a quantized version of the t5 model.
```bash ```bash
@ -10,8 +8,6 @@ $ cargo run --example quantized-t5 --release -- --prompt "translate to German: A
Eine schöne Kerze. Eine schöne Kerze.
``` ```
## Generating Quantized weight files
The weight file is automatically retrieved from the hub. It is also possible to The weight file is automatically retrieved from the hub. It is also possible to
generate quantized weight files from the original safetensors file by using the generate quantized weight files from the original safetensors file by using the
`tensor-tools` command line utility via: `tensor-tools` command line utility via:
@ -20,11 +16,8 @@ generate quantized weight files from the original safetensors file by using the
$ cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf $ cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
``` ```
## Using custom models To use a different model, specify the `model-id`. For example, you can use
quantized [CoEdit models](https://huggingface.co/jbochi/candle-coedit-quantized).
To use a different model, specify the `model-id`.
For example, for text editing, you can use quantized [CoEdit models](https://huggingface.co/jbochi/candle-coedit-quantized).
```bash ```bash
$ cargo run --example quantized-t5 --release -- \ $ cargo run --example quantized-t5 --release -- \
@ -33,7 +26,6 @@ $ cargo run --example quantized-t5 --release -- \
--temperature 0 --temperature 0
... ...
Although their flight is weak, they run quickly through the tree canopy. Although their flight is weak, they run quickly through the tree canopy.
```
By default, it will look for `model.gguf` and `config.json`, but you can specify By default, it will look for `model.gguf` and `config.json`, but you can specify
custom local or remote `weight-file` and `config-file`s: custom local or remote `weight-file` and `config-file`s:
@ -48,16 +40,3 @@ cargo run --example quantized-t5 --release -- \
... ...
Note that a storm surge is what forecasters consider a hurricane's most dangerous part. Note that a storm surge is what forecasters consider a hurricane's most dangerous part.
``` ```
### [MADLAD-400](https://arxiv.org/abs/2309.04662)
MADLAD-400 is a series of multilingual machine-translation T5 models trained on 250 billion tokens of publicly available data covering over 450 languages. They are competitive with significantly larger models.
```bash
cargo run --example quantized-t5 --release -- \
--model-id "jbochi/madlad400-3b-mt" --weight-file "model-q4k.gguf" \
--prompt "<2de> How are you, my friend?" \
--temperature 0
...
Wie geht es dir, mein Freund?
```

View File

@ -173,11 +173,7 @@ fn main() -> Result<()> {
.to_vec(); .to_vec();
let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?; let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
let mut model = builder.build_model()?; let mut model = builder.build_model()?;
let mut output_token_ids = [builder let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
.config
.decoder_start_token_id
.unwrap_or(builder.config.pad_token_id) as u32]
.to_vec();
let temperature = if args.temperature <= 0. { let temperature = if args.temperature <= 0. {
None None
} else { } else {

View File

@ -9,10 +9,9 @@ use std::io::Write;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use candle::quantized::{ggml_file, gguf_file}; use candle::quantized::{ggml_file, gguf_file};
use candle::{Device, Tensor}; use candle::{Tensor};
use candle_transformers::generation::LogitsProcessor; use candle_transformers::generation::LogitsProcessor;
use candle_examples::token_output_stream::TokenOutputStream;
use candle_transformers::models::quantized_llama as model; use candle_transformers::models::quantized_llama as model;
use model::ModelWeights; use model::ModelWeights;
@ -25,7 +24,7 @@ enum Prompt {
One(String), One(String),
} }
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)] #[derive(Clone, Debug, Copy, ValueEnum)]
enum Which { enum Which {
#[value(name = "7b")] #[value(name = "7b")]
L7b, L7b,
@ -49,10 +48,8 @@ enum Which {
Mistral7b, Mistral7b,
#[value(name = "7b-mistral-instruct")] #[value(name = "7b-mistral-instruct")]
Mistral7bInstruct, Mistral7bInstruct,
#[value(name = "7b-zephyr-a")] #[value(name = "7b-zephyr")]
Zephyr7bAlpha, Zephyr7b,
#[value(name = "7b-zephyr-b")]
Zephyr7bBeta,
} }
impl Which { impl Which {
@ -67,28 +64,7 @@ impl Which {
| Self::L7bCode | Self::L7bCode
| Self::L13bCode | Self::L13bCode
| Self::L34bCode => false, | Self::L34bCode => false,
// Zephyr is a fine tuned version of mistral and should be treated in the same way. Self::Mistral7b | Self::Mistral7bInstruct | Self::Zephyr7b => true,
Self::Zephyr7bAlpha
| Self::Zephyr7bBeta
| Self::Mistral7b
| Self::Mistral7bInstruct => true,
}
}
fn is_zephyr(&self) -> bool {
match self {
Self::L7b
| Self::L13b
| Self::L70b
| Self::L7bChat
| Self::L13bChat
| Self::L70bChat
| Self::L7bCode
| Self::L13bCode
| Self::L34bCode
| Self::Mistral7b
| Self::Mistral7bInstruct => false,
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
} }
} }
} }
@ -107,7 +83,7 @@ struct Args {
prompt: Option<String>, prompt: Option<String>,
/// The length of the sample to generate (in tokens). /// The length of the sample to generate (in tokens).
#[arg(short = 'n', long, default_value_t = 1000)] #[arg(short = 'n', long, default_value_t = 100)]
sample_len: usize, sample_len: usize,
/// The tokenizer config in json format. /// The tokenizer config in json format.
@ -200,13 +176,10 @@ impl Args {
"TheBloke/Mistral-7B-Instruct-v0.1-GGUF", "TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
"mistral-7b-instruct-v0.1.Q4_K_S.gguf", "mistral-7b-instruct-v0.1.Q4_K_S.gguf",
), ),
Which::Zephyr7bAlpha => ( Which::Zephyr7b => (
"TheBloke/zephyr-7B-alpha-GGUF", "TheBloke/zephyr-7B-alpha-GGUF",
"zephyr-7b-alpha.Q4_K_M.gguf", "zephyr-7b-alpha.Q4_K_M.gguf",
), ),
Which::Zephyr7bBeta => {
("TheBloke/zephyr-7B-beta-GGUF", "zephyr-7b-beta.Q4_K_M.gguf")
}
}; };
let api = hf_hub::api::sync::Api::new()?; let api = hf_hub::api::sync::Api::new()?;
let api = api.model(repo.to_string()); let api = api.model(repo.to_string());
@ -217,6 +190,31 @@ impl Args {
} }
} }
fn print_token(next_token: u32, tokenizer: &Tokenizer) {
// Extracting the last token as a string is complicated, here we just apply some simple
// heuristics as it seems to work well enough for this example. See the following for more
// details:
// https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
if let Some(text) = tokenizer.id_to_token(next_token) {
let text = text.replace('▁', " ");
let ascii = text
.strip_prefix("<0x")
.and_then(|t| t.strip_suffix('>'))
.and_then(|t| u8::from_str_radix(t, 16).ok());
match ascii {
None => print!("{text}"),
Some(ascii) => {
if let Some(chr) = char::from_u32(ascii as u32) {
if chr.is_ascii() {
print!("{chr}")
}
}
}
}
let _ = std::io::stdout().flush();
}
}
fn format_size(size_in_bytes: usize) -> String { fn format_size(size_in_bytes: usize) -> String {
if size_in_bytes < 1_000 { if size_in_bytes < 1_000 {
format!("{}B", size_in_bytes) format!("{}B", size_in_bytes)
@ -234,6 +232,7 @@ fn main() -> anyhow::Result<()> {
use tracing_subscriber::prelude::*; use tracing_subscriber::prelude::*;
let args = Args::parse(); let args = Args::parse();
let mut device = candle_examples::device(false)?;
let temperature = if args.temperature == 0. { let temperature = if args.temperature == 0. {
None None
} else { } else {
@ -278,10 +277,10 @@ fn main() -> anyhow::Result<()> {
&format_size(total_size_in_bytes), &format_size(total_size_in_bytes),
start.elapsed().as_secs_f32(), start.elapsed().as_secs_f32(),
); );
ModelWeights::from_gguf(model, &mut file)? ModelWeights::from_gguf(model, &mut file, &device)?
} }
Some("ggml" | "bin") | Some(_) | None => { Some("ggml" | "bin") | Some(_) | None => {
let model = ggml_file::Content::read(&mut file)?; let model = ggml_file::Content::read(&mut file, &device)?;
let mut total_size_in_bytes = 0; let mut total_size_in_bytes = 0;
for (_, tensor) in model.tensors.iter() { for (_, tensor) in model.tensors.iter() {
let elem_count = tensor.shape().elem_count(); let elem_count = tensor.shape().elem_count();
@ -305,18 +304,16 @@ fn main() -> anyhow::Result<()> {
| Which::L34bCode => 1, | Which::L34bCode => 1,
Which::Mistral7b Which::Mistral7b
| Which::Mistral7bInstruct | Which::Mistral7bInstruct
| Which::Zephyr7bAlpha | Which::Zephyr7b
| Which::Zephyr7bBeta
| Which::L70b | Which::L70b
| Which::L70bChat => 8, | Which::L70bChat => 8,
}; };
ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))? ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa), &device)?
} }
}; };
println!("model built"); println!("model built");
let tokenizer = args.tokenizer()?; let tokenizer = args.tokenizer()?;
let mut tos = TokenOutputStream::new(tokenizer);
let prompt = match args.prompt.as_deref() { let prompt = match args.prompt.as_deref() {
Some("chat") => Prompt::Chat, Some("chat") => Prompt::Chat,
Some("interactive") => Prompt::Interactive, Some("interactive") => Prompt::Interactive,
@ -325,11 +322,10 @@ fn main() -> anyhow::Result<()> {
}; };
let mut pre_prompt_tokens = vec![]; let mut pre_prompt_tokens = vec![];
for prompt_index in 0.. { loop {
let prompt_str = match &prompt { let prompt_str = match &prompt {
Prompt::One(prompt) => prompt.clone(), Prompt::One(prompt) => prompt.clone(),
Prompt::Interactive | Prompt::Chat => { Prompt::Interactive | Prompt::Chat => {
let is_interactive = matches!(prompt, Prompt::Interactive);
print!("> "); print!("> ");
std::io::stdout().flush()?; std::io::stdout().flush()?;
let mut prompt = String::new(); let mut prompt = String::new();
@ -340,13 +336,7 @@ fn main() -> anyhow::Result<()> {
prompt.pop(); prompt.pop();
} }
} }
if args.which.is_zephyr() { if args.which.is_mistral() {
if prompt_index == 0 || is_interactive {
format!("<|system|>\n</s>\n<|user|>\n{prompt}</s>\n<|assistant|>",)
} else {
format!("<|user|>\n{prompt}</s>\n<|assistant|>")
}
} else if args.which.is_mistral() {
format!("[INST] {prompt} [/INST]") format!("[INST] {prompt} [/INST]")
} else { } else {
prompt prompt
@ -354,8 +344,7 @@ fn main() -> anyhow::Result<()> {
} }
}; };
print!("{}", &prompt_str); print!("{}", &prompt_str);
let tokens = tos let tokens = tokenizer
.tokenizer()
.encode(prompt_str, true) .encode(prompt_str, true)
.map_err(anyhow::Error::msg)?; .map_err(anyhow::Error::msg)?;
if args.verbose_prompt { if args.verbose_prompt {
@ -378,51 +367,46 @@ fn main() -> anyhow::Result<()> {
let start_prompt_processing = std::time::Instant::now(); let start_prompt_processing = std::time::Instant::now();
let mut next_token = { let mut next_token = {
let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?; let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
let logits = model.forward(&input, 0)?; let logits = model.forward(&input, 0)?;
let logits = logits.squeeze(0)?; let logits = logits.squeeze(0)?;
// TODO Remove this once implementation is finished.
let logits = logits.ones_like()?;
logits_processor.sample(&logits)? logits_processor.sample(&logits)?
}; };
let prompt_dt = start_prompt_processing.elapsed(); let prompt_dt = start_prompt_processing.elapsed();
all_tokens.push(next_token); all_tokens.push(next_token);
if let Some(t) = tos.next_token(next_token)? { print_token(next_token, &tokenizer);
print!("{t}");
std::io::stdout().flush()?;
}
let eos_token = *tos.tokenizer().get_vocab(true).get("</s>").unwrap(); let eos_token = *tokenizer.get_vocab(true).get("</s>").unwrap();
let start_post_prompt = std::time::Instant::now(); let start_post_prompt = std::time::Instant::now();
let mut sampled = 0;
for index in 0..to_sample { for index in 0..to_sample {
let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?; let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
let logits = model.forward(&input, prompt_tokens.len() + index)?; let logits = model.forward(&input, prompt_tokens.len() + index)?;
if let candle::Device::Metal(device) = &mut device{
device.flush()
}
let logits = logits.squeeze(0)?; let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. { // let logits = if args.repeat_penalty == 1. {
logits // logits
} else { // } else {
let start_at = all_tokens.len().saturating_sub(args.repeat_last_n); // let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty( // candle_transformers::utils::apply_repeat_penalty(
&logits, // &logits,
args.repeat_penalty, // args.repeat_penalty,
&all_tokens[start_at..], // &all_tokens[start_at..],
)? // )?
}; // };
// TODO Remove this once implementation is finished.
let logits = logits.ones_like()?;
next_token = logits_processor.sample(&logits)?; next_token = logits_processor.sample(&logits)?;
all_tokens.push(next_token); all_tokens.push(next_token);
if let Some(t) = tos.next_token(next_token)? { print_token(next_token, &tokenizer);
print!("{t}");
std::io::stdout().flush()?;
}
sampled += 1;
if next_token == eos_token { if next_token == eos_token {
break; break;
}; };
} }
if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
let dt = start_post_prompt.elapsed(); let dt = start_post_prompt.elapsed();
println!( println!(
"\n\n{:4} prompt tokens processed: {:.2} token/s", "\n\n{:4} prompt tokens processed: {:.2} token/s",
@ -430,8 +414,9 @@ fn main() -> anyhow::Result<()> {
prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(), prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
); );
println!( println!(
"{sampled:4} tokens generated: {:.2} token/s", "{:4} tokens generated: {:.2} token/s",
sampled as f64 / dt.as_secs_f64(), to_sample,
to_sample as f64 / dt.as_secs_f64(),
); );
match prompt { match prompt {

View File

@ -5,26 +5,12 @@
```bash ```bash
$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "translate to German: A beautiful candle." --decode $ cargo run --example t5 --release -- --model-id "t5-small" --prompt "translate to German: A beautiful candle." --decode
... ...
Running on CPU, to run on GPU, build this example with `--features cuda`
Eine schöne Kerze. Eine schöne Kerze.
9 tokens generated (2.42 token/s) 9 tokens generated (2.42 token/s)
``` ```
Variants such as [flan-t5](https://huggingface.co/google/flan-t5-small), [flan-ul2](https://huggingface.co/google/flan-ul2) (with `--revision "refs/pr/25"`), and [Co-EdIT](https://huggingface.co/grammarly/coedit-large) are also supported. ## Sentence embedding example:
## Translation with [MADLAD-400](https://arxiv.org/abs/2309.04662)
MADLAD-400 is a series of multilingual machine-translation T5 models trained on 250 billion tokens of publicly available data covering over 450 languages. They are competitive with significantly larger models.
```bash
cargo run --example t5 --release -- \
--model-id "jbochi/madlad400-3b-mt" \
--prompt "<2de> How are you, my friend?" \
--decode --temperature 0
...
Wie geht es dir, mein Freund?
```
## Sentence embedding example
```bash ```bash
$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "A beautiful candle." $ cargo run --example t5 --release -- --model-id "t5-small" --prompt "A beautiful candle."

View File

@ -104,17 +104,6 @@ impl T5ModelBuilder {
api.get("model-00004-of-00005.safetensors")?, api.get("model-00004-of-00005.safetensors")?,
api.get("model-00005-of-00005.safetensors")?, api.get("model-00005-of-00005.safetensors")?,
] ]
} else if model_id == "google/flan-ul2" {
vec![
api.get("model-00001-of-00008.safetensors")?,
api.get("model-00002-of-00008.safetensors")?,
api.get("model-00003-of-00008.safetensors")?,
api.get("model-00004-of-00008.safetensors")?,
api.get("model-00005-of-00008.safetensors")?,
api.get("model-00006-of-00008.safetensors")?,
api.get("model-00007-of-00008.safetensors")?,
api.get("model-00008-of-00008.safetensors")?,
]
} else { } else {
vec![api.get("model.safetensors")?] vec![api.get("model.safetensors")?]
}; };
@ -183,12 +172,7 @@ fn main() -> Result<()> {
println!("Took {:?}", start.elapsed()); println!("Took {:?}", start.elapsed());
} else { } else {
let mut model = builder.build_conditional_generation()?; let mut model = builder.build_conditional_generation()?;
let mut output_token_ids = [builder let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
.config
.decoder_start_token_id
.unwrap_or(builder.config.pad_token_id)
as u32]
.to_vec();
if let Some(decoder_prompt) = &args.decoder_prompt { if let Some(decoder_prompt) = &args.decoder_prompt {
print!("{decoder_prompt}"); print!("{decoder_prompt}");
output_token_ids.extend( output_token_ids.extend(

Binary file not shown.

Before: image asset, 36 KiB (removed)

View File

@ -1,154 +0,0 @@
use image::{DynamicImage, ImageBuffer};
use serde::Deserialize;
use std::collections::HashMap;
use candle::{DType, Device, Result, Tensor};
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct ProcessorConfig {
do_resize: bool,
height: u32,
width: u32,
do_rescale: bool,
do_normalize: bool,
image_mean: Vec<f32>,
image_std: Vec<f32>,
}
impl Default for ProcessorConfig {
fn default() -> Self {
Self {
do_resize: true,
height: 384,
width: 384,
do_rescale: true,
do_normalize: true,
image_mean: vec![0.5, 0.5, 0.5],
image_std: vec![0.5, 0.5, 0.5],
}
}
}
pub struct ViTImageProcessor {
do_resize: bool,
height: u32,
width: u32,
do_normalize: bool,
image_mean: Vec<f32>,
image_std: Vec<f32>,
}
impl ViTImageProcessor {
pub fn new(config: &ProcessorConfig) -> Self {
Self {
do_resize: config.do_resize,
height: config.height,
width: config.width,
do_normalize: config.do_normalize,
image_mean: config.image_mean.clone(),
image_std: config.image_std.clone(),
}
}
pub fn preprocess(&self, images: Vec<&str>) -> Result<Tensor> {
let height = self.height as usize;
let width = self.width as usize;
let channels = 3;
let images = self.load_images(images)?;
let resized_images: Vec<DynamicImage> = if self.do_resize {
images
.iter()
.map(|image| self.resize(image.clone(), None).unwrap())
.collect()
} else {
images
};
let normalized_images: Vec<Tensor> = if self.do_normalize {
resized_images
.iter()
.map(|image| self.normalize(image.clone(), None, None).unwrap())
.collect()
} else {
let resized_images: Vec<ImageBuffer<image::Rgb<u8>, Vec<u8>>> =
resized_images.iter().map(|image| image.to_rgb8()).collect();
let data = resized_images
.into_iter()
.map(|image| image.into_raw())
.collect::<Vec<Vec<u8>>>();
data.iter()
.map(|image| {
Tensor::from_vec(image.clone(), (height, width, channels), &Device::Cpu)
.unwrap()
.permute((2, 0, 1))
.unwrap()
})
.collect::<Vec<Tensor>>()
};
Tensor::stack(&normalized_images, 0)
}
fn resize(
&self,
image: image::DynamicImage,
size: Option<HashMap<String, u32>>,
) -> Result<image::DynamicImage> {
let (height, width) = match &size {
Some(size) => (size.get("height").unwrap(), size.get("width").unwrap()),
None => (&self.height, &self.width),
};
let resized_image =
image.resize_exact(*width, *height, image::imageops::FilterType::Triangle);
Ok(resized_image)
}
fn normalize(
&self,
image: image::DynamicImage,
mean: Option<Vec<f32>>,
std: Option<Vec<f32>>,
) -> Result<Tensor> {
let mean = match mean {
Some(mean) => mean,
None => self.image_mean.clone(),
};
let std = match std {
Some(std) => std,
None => self.image_std.clone(),
};
let mean = Tensor::from_vec(mean, (3, 1, 1), &Device::Cpu)?;
let std = Tensor::from_vec(std, (3, 1, 1), &Device::Cpu)?;
let image = image.to_rgb8();
let data = image.into_raw();
let height = self.height as usize;
let width = self.width as usize;
let channels = 3;
let data =
Tensor::from_vec(data, &[height, width, channels], &Device::Cpu)?.permute((2, 0, 1))?;
(data.to_dtype(DType::F32)? / 255.)?
.broadcast_sub(&mean)?
.broadcast_div(&std)
}
pub fn load_images(&self, image_path: Vec<&str>) -> Result<Vec<image::DynamicImage>> {
let mut images: Vec<image::DynamicImage> = Vec::new();
for path in image_path {
let img = image::io::Reader::open(path)?.decode().unwrap();
images.push(img);
}
Ok(images)
}
}
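The preprocessing removed above boils down to resizing plus a per-pixel affine normalization. A minimal sketch of that arithmetic, assuming the default mean/std of 0.5 from `ProcessorConfig` (illustration only, not part of the example):

```rust
// Maps a u8 pixel into [-1, 1]: rescale to [0, 1], then apply (x - mean) / std.
fn normalize_pixel(p: u8, mean: f32, std: f32) -> f32 {
    (p as f32 / 255.0 - mean) / std
}

fn main() {
    assert_eq!(normalize_pixel(255, 0.5, 0.5), 1.0);
    assert_eq!(normalize_pixel(0, 0.5, 0.5), -1.0);
}
```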

View File

@ -1,132 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Error as E;
use clap::{Parser, ValueEnum};
use candle::{DType, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::models::trocr;
use tokenizers::Tokenizer;
mod image_processor;
#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
Base,
Large,
}
#[derive(Parser, Debug)]
struct Args {
#[arg(long)]
model: Option<String>,
/// Choose the variant of the model to run.
#[arg(long, default_value = "base")]
which: Which,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// The image to be transcribed
#[arg(long)]
image: String,
}
pub fn main() -> anyhow::Result<()> {
use hf_hub::api::sync::Api;
let args = Args::parse();
let tokenizer_dec = {
let tokenizer = Api::new()?
.model(String::from("ToluClassics/candle-trocr-tokenizer"))
.get("tokenizer.json")?;
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
};
let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);
let device = candle_examples::device(args.cpu)?;
let vb = {
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
None => match args.which {
Which::Base => Api::new()?
.repo(hf_hub::Repo::with_revision(
"microsoft/trocr-base-handwritten".to_string(),
hf_hub::RepoType::Model,
"refs/pr/3".to_string(),
))
.get("model.safetensors")?,
Which::Large => Api::new()?
.repo(hf_hub::Repo::with_revision(
"microsoft/trocr-large-handwritten".to_string(),
hf_hub::RepoType::Model,
"refs/pr/6".to_string(),
))
.get("model.safetensors")?,
},
};
println!("model: {:?}", model);
unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? }
};
let encoder_config = match args.which {
Which::Base => candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten(),
Which::Large => {
candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten()
}
};
let decoder_config = trocr::TrOCRConfig::default();
let mut model = trocr::TrOCRModel::new(&encoder_config, &decoder_config, vb)?;
let config = image_processor::ProcessorConfig::default();
let processor = image_processor::ViTImageProcessor::new(&config);
let image = vec![args.image.as_str()];
let image = processor.preprocess(image)?;
let encoder_xs = model.encoder().forward(&image)?;
let mut logits_processor =
candle_transformers::generation::LogitsProcessor::new(1337, None, None);
let mut token_ids: Vec<u32> = vec![decoder_config.decoder_start_token_id];
for index in 0..1000 {
let context_size = if index >= 1 { 1 } else { token_ids.len() };
let start_pos = token_ids.len().saturating_sub(context_size);
let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
let logits = model.decode(&input_ids, &encoder_xs, start_pos)?;
let logits = logits.squeeze(0)?;
let logits = logits.get(logits.dim(0)? - 1)?;
let token = logits_processor.sample(&logits)?;
token_ids.push(token);
if let Some(t) = tokenizer_dec.next_token(token)? {
use std::io::Write;
print!("{t}");
std::io::stdout().flush()?;
}
if token == decoder_config.eos_token_id {
break;
}
}
if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
println!();
Ok(())
}

View File

@ -1,16 +0,0 @@
# candle-trocr
`TrOCR` is a transformer-based OCR model. In this example it is used to
transcribe image text. See the associated [model
card](https://huggingface.co/microsoft/trocr-base-printed) for details on
the model itself.
## Running an example
```bash
cargo run --example trocr --release -- --which base --cpu --image assets/trocr.png
```
```
<s> industry , Mr. Brown commented icily . " Let us have a</s>
```

View File

@ -345,7 +345,7 @@ enum Task {
Translate, Translate,
} }
#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)] #[derive(Clone, Copy, Debug, ValueEnum)]
enum WhichModel { enum WhichModel {
Tiny, Tiny,
#[value(name = "tiny.en")] #[value(name = "tiny.en")]
@ -361,27 +361,15 @@ enum WhichModel {
MediumEn, MediumEn,
Large, Large,
LargeV2, LargeV2,
LargeV3,
#[value(name = "distil-medium.en")]
DistilMediumEn,
#[value(name = "distil-large-v2")]
DistilLargeV2,
} }
impl WhichModel { impl WhichModel {
fn is_multilingual(&self) -> bool { fn is_multilingual(&self) -> bool {
match self { match self {
Self::Tiny Self::Tiny | Self::Base | Self::Small | Self::Medium | Self::Large | Self::LargeV2 => {
| Self::Base true
| Self::Small
| Self::Medium
| Self::Large
| Self::LargeV2
| Self::LargeV3
| Self::DistilLargeV2 => true,
Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn | Self::DistilMediumEn => {
false
} }
Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn => false,
} }
} }
@ -397,9 +385,6 @@ impl WhichModel {
Self::MediumEn => ("openai/whisper-medium.en", "main"), Self::MediumEn => ("openai/whisper-medium.en", "main"),
Self::Large => ("openai/whisper-large", "refs/pr/36"), Self::Large => ("openai/whisper-large", "refs/pr/36"),
Self::LargeV2 => ("openai/whisper-large-v2", "refs/pr/57"), Self::LargeV2 => ("openai/whisper-large-v2", "refs/pr/57"),
Self::LargeV3 => ("openai/whisper-large-v3", "main"),
Self::DistilMediumEn => ("distil-whisper/distil-medium.en", "main"),
Self::DistilLargeV2 => ("distil-whisper/distil-large-v2", "main"),
} }
} }
} }
@ -511,25 +496,17 @@ fn main() -> Result<()> {
repo.get(&format!("model-{ext}-q80.gguf"))?, repo.get(&format!("model-{ext}-q80.gguf"))?,
) )
} else { } else {
let config = repo.get("config.json")?; (
let tokenizer = if args.model == WhichModel::LargeV3 { repo.get("config.json")?,
panic!("openai/whisper-large-v3 does not provide a compatible tokenizer.json config at the moment") repo.get("tokenizer.json")?,
} else { repo.get("model.safetensors")?,
repo.get("tokenizer.json")? )
};
let model = repo.get("model.safetensors")?;
(config, tokenizer, model)
}; };
(config, tokenizer, model, sample) (config, tokenizer, model, sample)
}; };
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let mel_bytes = match config.num_mel_bins { let mel_bytes = include_bytes!("melfilters.bytes");
80 => include_bytes!("melfilters.bytes").as_slice(),
128 => include_bytes!("melfilters128.bytes").as_slice(),
nmel => anyhow::bail!("unexpected num_mel_bins {nmel}"),
};
let mut mel_filters = vec![0f32; mel_bytes.len() / 4]; let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
<byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters); <byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
@ -545,15 +522,12 @@ fn main() -> Result<()> {
.map(|v| *v as f32 / 32768.) .map(|v| *v as f32 / 32768.)
.collect(); .collect();
println!("pcm data loaded {}", pcm_data.len()); println!("pcm data loaded {}", pcm_data.len());
let mel = audio::pcm_to_mel(&config, &pcm_data, &mel_filters); let mel = audio::pcm_to_mel(&pcm_data, &mel_filters);
let mel_len = mel.len(); let mel_len = mel.len();
let mel = Tensor::from_vec( let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
mel,
(1, config.num_mel_bins, mel_len / config.num_mel_bins),
&device,
)?;
println!("loaded mel: {:?}", mel.dims()); println!("loaded mel: {:?}", mel.dims());
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
let mut model = if args.quantized { let mut model = if args.quantized {
let vb = let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&weights_filename)?; candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&weights_filename)?;

View File

@ -1,268 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};
use candle_transformers::models::yi::{Config, Model};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
#[value(name = "6b")]
L6b,
#[value(name = "34b")]
L34b,
}
struct TextGeneration {
model: Model,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("</s>") {
Some(token) => token,
None => anyhow::bail!("cannot find the </s> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize,
#[arg(long, default_value = "01-ai/Yi-6B")]
model_id: String,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
/// The model size to use.
#[arg(long, default_value = "6b")]
which: Which,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
args.model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => match args.which {
Which::L6b => vec![
repo.get("model-00001-of-00002.safetensors")?,
repo.get("model-00002-of-00002.safetensors")?,
],
Which::L34b => vec![
repo.get("model-00001-of-00007.safetensors")?,
repo.get("model-00002-of-00007.safetensors")?,
repo.get("model-00003-of-00007.safetensors")?,
repo.get("model-00004-of-00007.safetensors")?,
repo.get("model-00005-of-00007.safetensors")?,
repo.get("model-00006-of-00007.safetensors")?,
repo.get("model-00007-of-00007.safetensors")?,
],
},
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = match args.which {
Which::L6b => Config::config_6b(),
Which::L34b => Config::config_34b(),
};
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -8,22 +8,24 @@ use candle::{Device, Result, Tensor};
pub fn device(cpu: bool) -> Result<Device> { pub fn device(cpu: bool) -> Result<Device> {
if cpu { if cpu {
Ok(Device::Cpu) Ok(Device::Cpu)
} else if cuda_is_available() {
Ok(Device::new_cuda(0)?)
} else if metal_is_available() {
Ok(Device::new_metal(0)?)
} else { } else {
#[cfg(all(target_os = "macos", target_arch = "aarch64"))] if cuda_is_available() {
{ Ok(Device::new_cuda(0)?)
println!( } else if metal_is_available() {
"Running on CPU, to run on GPU(metal), build this example with `--features metal`" Ok(Device::new_metal(0)?)
); } else {
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
{
println!("Running on CPU, to run on GPU(metal), build this example with `--features metal`");
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
{
println!(
"Running on CPU, to run on GPU, build this example with `--features cuda`"
);
}
Ok(Device::Cpu)
} }
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
{
println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
}
Ok(Device::Cpu)
} }
} }
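A hypothetical call site for this helper, just to show the intended usage (sketch only; the tensor created here is arbitrary):

```rust
use candle::{DType, Result, Tensor};

fn run(cpu: bool) -> Result<()> {
    // Picks CUDA, then Metal, then falls back to CPU, as implemented above.
    let device = candle_examples::device(cpu)?;
    let t = Tensor::zeros((2, 3), DType::F32, &device)?;
    println!("{t}");
    Ok(())
}
```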

View File

@ -233,8 +233,8 @@ impl FlashAttnVarLen {
let (seqlens_q, seqlens_q_layout) = self.seqlens_q.storage_and_layout(); let (seqlens_q, seqlens_q_layout) = self.seqlens_q.storage_and_layout();
let seqlens_q = match &*seqlens_q { let seqlens_q = match &*seqlens_q {
candle::Storage::Cpu(_) => candle::bail!("seqlens_q must be a cuda tensor"),
candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?, // Should be i32! candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?, // Should be i32!
_ => candle::bail!("seqlens_q must be a cuda tensor"),
}; };
let seqlens_q = match seqlens_q_layout.contiguous_offsets() { let seqlens_q = match seqlens_q_layout.contiguous_offsets() {
Some((o1, o2)) => seqlens_q.slice(o1..o2), Some((o1, o2)) => seqlens_q.slice(o1..o2),
@ -243,8 +243,8 @@ impl FlashAttnVarLen {
let (seqlens_k, seqlens_k_layout) = self.seqlens_k.storage_and_layout(); let (seqlens_k, seqlens_k_layout) = self.seqlens_k.storage_and_layout();
let seqlens_k = match &*seqlens_k { let seqlens_k = match &*seqlens_k {
candle::Storage::Cpu(_) => candle::bail!("seqlens_k must be a cuda tensor"),
candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?, // Should be i32! candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?, // Should be i32!
_ => candle::bail!("seqlens_k must be a cuda tensor"),
}; };
let seqlens_k = match seqlens_k_layout.contiguous_offsets() { let seqlens_k = match seqlens_k_layout.contiguous_offsets() {
Some((o1, o2)) => seqlens_k.slice(o1..o2), Some((o1, o2)) => seqlens_k.slice(o1..o2),

View File

@ -1,20 +1,12 @@
[package] [package]
name = "candle-metal-kernels" name = "candle-metal-kernels"
version = "0.3.0" version.workspace = true
edition = "2021" edition.workspace = true
description.workspace = true
description = "CUDA kernels for Candle" repository.workspace = true
repository = "https://github.com/huggingface/candle" keywords.workspace = true
keywords = ["blas", "tensor", "machine-learning"] categories.workspace = true
categories = ["science"] license.workspace = true
license = "MIT OR Apache-2.0"
[dependencies] [dependencies]
metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] } metal = { workspace = true }
once_cell = "1.18.0"
thiserror = "1"
tracing = "0.1.37"
[dev-dependencies]
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
rand = "0.8.5"

View File

@ -1,61 +0,0 @@
#include <metal_stdlib>
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
using namespace metal;
#define AFFINE(FN_NAME, TYPENAME) \
kernel void FN_NAME( \
constant size_t &dim, \
constant float &mul, \
constant float &add, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint id [[ thread_position_in_grid ]] \
) { \
if (id >= dim) { \
return; \
} \
const TYPENAME m = TYPENAME(mul); \
const TYPENAME a = TYPENAME(add); \
output[id] = input[id] * m + a; \
} \
kernel void FN_NAME##_strided( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant float &mul, \
constant float &add, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint id [[ thread_position_in_grid ]] \
) { \
if (id >= dim) { \
return; \
} \
const TYPENAME m = TYPENAME(mul); \
const TYPENAME a = TYPENAME(add); \
output[id] = input[get_strided_index(id, num_dims, dims, strides)] * m + a; \
} \
AFFINE(affine_float, float)
AFFINE(affine_half, half)
#if __METAL_VERSION__ >= 310
AFFINE(affine_bfloat, bfloat);
#endif
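These kernels compute `output = input * mul + add`; the `_strided` variant resolves each flat thread index through `get_strided_index`. A plain-Rust sketch of the same arithmetic, with an illustrative 2x3 transposed layout in `main` (not part of the kernels):

```rust
// Same index arithmetic as get_strided_index above: walk dims from innermost
// to outermost, accumulating stride offsets for a row-major view.
fn get_strided_index(mut idx: usize, dims: &[usize], strides: &[usize]) -> usize {
    let mut strided_i = 0;
    for d in (0..dims.len()).rev() {
        strided_i += (idx % dims[d]) * strides[d];
        idx /= dims[d];
    }
    strided_i
}

// CPU analogue of affine_*_strided: out[i] = in[strided(i)] * mul + add.
fn affine_strided(input: &[f32], dims: &[usize], strides: &[usize], mul: f32, add: f32) -> Vec<f32> {
    let n: usize = dims.iter().product();
    (0..n)
        .map(|i| input[get_strided_index(i, dims, strides)] * mul + add)
        .collect()
}

fn main() {
    // A 2x3 view over transposed storage: element (r, c) lives at offset c * 2 + r.
    let input = [0.0f32, 3.0, 1.0, 4.0, 2.0, 5.0];
    let out = affine_strided(&input, &[2, 3], &[1, 2], 2.0, 1.0);
    assert_eq!(out, vec![1.0, 3.0, 5.0, 7.0, 9.0, 11.0]);
}
```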

View File

@ -1,72 +0,0 @@
#include <metal_stdlib>
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
using namespace metal;
#define BINARY(FN, TYPENAME, OUT_TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
constant size_t &dim, \
device const TYPENAME *left, \
device const TYPENAME *right, \
device TYPENAME *output, \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
if (thread_position_in_grid >= dim) { \
return; \
} \
TYPENAME x = left[thread_position_in_grid]; \
TYPENAME y = right[thread_position_in_grid]; \
output[thread_position_in_grid] = OUT_TYPENAME(FN); \
}\
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *left_strides, \
constant size_t *right_strides, \
device const TYPENAME *left, \
device const TYPENAME *right, \
device TYPENAME *output, \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
if (thread_position_in_grid >= dim) { \
return; \
} \
TYPENAME x = left[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \
TYPENAME y = right[get_strided_index(thread_position_in_grid, num_dims, dims, right_strides)]; \
output[thread_position_in_grid] = OUT_TYPENAME(FN); \
}
#define BINARY_OP(FN, NAME) \
BINARY(FN, float, float, NAME##_float, NAME##_float_strided); \
BINARY(FN, half, half, NAME##_half, NAME##_half_strided);
#define BFLOAT_BINARY_OP(FN, NAME) \
BINARY(FN, bfloat, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
BINARY_OP(x + y, add)
BINARY_OP(x - y, sub)
BINARY_OP(x * y, mul)
BINARY_OP(x / y, div)
#if __METAL_VERSION__ >= 310
BFLOAT_BINARY_OP(x + y, add)
BFLOAT_BINARY_OP(x - y, sub)
BFLOAT_BINARY_OP(x * y, mul)
BFLOAT_BINARY_OP(x / y, div)
#endif

View File

@ -1,53 +0,0 @@
#include <metal_stdlib>
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
using namespace metal;
#define CAST(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME) \
kernel void FN_NAME( \
constant size_t &dim, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
if (thread_position_in_grid >= dim) { \
return; \
} \
output[thread_position_in_grid] = RIGHT_TYPENAME(input[thread_position_in_grid]); \
} \
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
uint i [[ thread_position_in_grid ]] \
) { \
if (i >= dim) { \
return; \
} \
output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \
} \
CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float)
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
#if __METAL_VERSION__ >= 310
#endif

View File

@ -1,103 +0,0 @@
#include <metal_stdlib>
using namespace metal;
# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
constant size_t &dst_size, \
constant size_t &left_size, \
constant size_t &src_dim_size, \
constant size_t &right_size, \
constant size_t &ids_size, \
const device TYPENAME *input, \
const device INDEX_TYPENAME *input_ids, \
device TYPENAME *output, \
uint gid [[ thread_position_in_grid ]] \
) { \
if (gid >= dst_size) { \
return; \
} \
const size_t id_i = (gid / right_size) % ids_size; \
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
const size_t right_rank_i = gid % right_size; \
const size_t left_rank_i = gid / right_size / ids_size; \
/* \
// Force prevent out of bounds indexing \
// since there doesn't seem to be a good way to force crash \
// No need to check for zero we're only allowing unsized. \
*/ \
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; \
output[gid] = input[src_i]; \
}
template <typename T, typename I>
void index_add(
device I *ids [[buffer(0)]],
device T *inp [[buffer(1)]],
device T *out [[buffer(2)]],
constant uint &ids_dim_size,
constant uint &left_size,
constant uint &dst_dim_size,
constant uint &right_size,
uint gid [[ thread_position_in_grid ]] \
) {
if (gid >= left_size * right_size) {
return;
}
const uint i = gid;
const uint pre = i / right_size;
const uint post = i % right_size;
for (uint j = 0; j < ids_dim_size; j++) {
const uint idx = ids[j];
const uint src_i = (pre * ids_dim_size + j) * right_size + post;
const uint dst_i = (pre * dst_dim_size + idx) * right_size + post;
out[dst_i] += inp[src_i];
}
}
#define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
kernel void FN_NAME( \
device INDEX_TYPENAME *ids [[buffer(0)]], \
device TYPENAME *inp [[buffer(1)]], \
device TYPENAME *out [[buffer(2)]], \
constant uint &ids_dim_size, \
constant uint &left_size, \
constant uint &dst_dim_size, \
constant uint &right_size, \
uint gid [[ thread_position_in_grid ]] \
) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, gid); } \
INDEX_OP(is_u32_f32, uint, float)
INDEX_OP(is_u32_f16, uint, half)
#if __METAL_VERSION__ >= 310
IA_OP(bfloat, int64_t, ia_i64_bf16)
IA_OP(bfloat, uint32_t, ia_u32_bf16)
IA_OP(bfloat, uint8_t, ia_u8_bf16)
#endif
IA_OP(half, uint32_t, ia_u32_f16)
IA_OP(half, uint8_t, ia_u8_f16)
IA_OP(float, int64_t, ia_i64_f32)
IA_OP(uint8_t, int64_t, ia_i64_u8)
IA_OP(int64_t, int64_t, ia_i64_i64)
IA_OP(uint32_t, int64_t, ia_i64_u32)
IA_OP(float, uint32_t, ia_u32_f32)
IA_OP(uint8_t, uint32_t, ia_u32_u8)
IA_OP(int64_t, uint32_t, ia_u32_i64)
IA_OP(uint32_t, uint32_t, ia_u32_u32)
IA_OP(float, uint8_t, ia_u8_f32)
IA_OP(uint8_t, uint8_t, ia_u8_u8)
IA_OP(uint32_t, uint8_t, ia_u8_u32)
IA_OP(int64_t, uint8_t, ia_u8_i64)
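The `index_add` template above flattens a (left, ids, right) layout into linear offsets before doing a scatter-add. A CPU sketch in Rust of the same addressing, with a hypothetical function name and toy sizes, may help when reading the kernel:

```rust
/// CPU reference for the `index_add` addressing: for every (pre, post)
/// position, source row `j` is accumulated into destination row `ids[j]`.
fn index_add_ref(
    ids: &[usize],
    inp: &[f32],
    out: &mut [f32],
    ids_dim_size: usize,
    left_size: usize,
    dst_dim_size: usize,
    right_size: usize,
) {
    for gid in 0..left_size * right_size {
        let pre = gid / right_size;
        let post = gid % right_size;
        for j in 0..ids_dim_size {
            let idx = ids[j];
            let src_i = (pre * ids_dim_size + j) * right_size + post;
            let dst_i = (pre * dst_dim_size + idx) * right_size + post;
            out[dst_i] += inp[src_i];
        }
    }
}

fn main() {
    // Scatter-add two source rows of width 2 into a three-row destination
    // (left_size = 1, right_size = 2, dst_dim_size = 3).
    let ids = [2usize, 0];
    let inp = [1.0f32, 2.0, 3.0, 4.0];
    let mut out = [0.0f32; 6];
    index_add_ref(&ids, &inp, &mut out, 2, 1, 3, 2);
    assert_eq!(out, [3.0f32, 4.0, 0.0, 0.0, 1.0, 2.0]);
}
```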

File diff suppressed because it is too large.

View File

@ -1,139 +0,0 @@
#include <metal_stdlib>
using namespace metal;
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
constant int THREADGROUP_SIZE = 1024;
# define REDUCE(FN, NAME, TYPENAME) \
kernel void NAME( \
constant size_t &src_numel, \
constant size_t &el_to_sum_per_block, \
device const TYPENAME *src, \
device TYPENAME *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint blockDim [[ threads_per_threadgroup ]] \
) { \
\
threadgroup float shared_memory[THREADGROUP_SIZE]; \
\
shared_memory[tid] = 0; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
size_t idx = start_idx + tid; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
// size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
*/ \
TYPENAME x = shared_memory[tid]; \
TYPENAME y = src[idx]; \
shared_memory[tid] = FN; \
idx += blockDim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
\
/* \
// reduction in shared memory \
*/ \
for (uint s = blockDim / 2; s > 0; s >>= 1) { \
if (tid < s) { \
TYPENAME x = shared_memory[tid]; \
TYPENAME y = shared_memory[tid + s]; \
shared_memory[tid] = FN; \
} \
threadgroup_barrier(mem_flags::mem_none); \
} \
\
dst[dst_id] = shared_memory[0]; \
} \
kernel void softmax_float(
constant size_t &src_numel,
constant size_t &el_to_sum_per_block,
device const float *src,
device float *dst,
uint id [[ thread_position_in_grid ]],
uint tid [[ thread_index_in_threadgroup ]],
uint dst_id [[ threadgroup_position_in_grid ]],
uint blockDim [[ threads_per_threadgroup ]]
) {
threadgroup float shared_memory[THREADGROUP_SIZE];
shared_memory[tid] = -INFINITY;
// Elements summed in this block range from dst_id * el_to_sum_per_block
// to (dst_id + 1) * el_to_sum_per_block.
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
size_t idx = start_idx + tid;
while (idx < stop_idx) {
// TODO: Fast version for the contiguous case.
shared_memory[tid] = max(shared_memory[tid], src[idx]);
idx += blockDim;
}
threadgroup_barrier(mem_flags::mem_none);
// reduction in shared memory
for (uint s = blockDim / 2; s > 0; s >>= 1) {
if (tid < s) {
shared_memory[tid] = max(shared_memory[tid], shared_memory[tid + s]);
}
threadgroup_barrier(mem_flags::mem_none);
}
float max = shared_memory[0];
shared_memory[tid] = 0;
// Restart
idx = start_idx + tid;
while (idx < stop_idx) {
// TODO: Fast version for the contiguous case.
const float val = exp(src[idx] - max);
dst[idx] = val;
shared_memory[tid] += val;
idx += blockDim;
}
// reduction in shared memory
for (uint s = blockDim / 2; s > 0; s >>= 1) {
if (tid < s) {
shared_memory[tid] += shared_memory[tid + s];
}
threadgroup_barrier(mem_flags::mem_none);
}
const float inv_acc = 1/shared_memory[0];
idx = start_idx + tid;
while (idx < stop_idx) {
dst[idx] *= inv_acc;
idx += blockDim;
}
}
REDUCE(x + y, fast_sum_float, float)
REDUCE(x * y, fast_mul_float, float)
REDUCE(max(x, y), fast_max_float, float)
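For reference, the three loops of `softmax_float` above implement the standard numerically stable softmax over each block of `el_to_sum_per_block` elements; subtracting the per-block maximum keeps the exponentials in range without changing the result:

```latex
m = \max_j x_j, \qquad y_i = e^{x_i - m}, \qquad \operatorname{softmax}(x)_i = \frac{y_i}{\sum_k y_k}
```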

View File

@ -1,57 +0,0 @@
#include <metal_stdlib>
#
using namespace metal;
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
#define WHERE_OP(TYPENAME, ID_TYPENAME, FN_NAME) \
kernel void FN_NAME( \
constant size_t &numel, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t *strides_t, \
constant size_t *strides_f, \
device const ID_TYPENAME *ids, \
device const TYPENAME *t, \
device const TYPENAME *f, \
device TYPENAME *out ,\
uint i [[ thread_position_in_grid ]] \
) { \
uint strided_i = get_strided_index(i, num_dims, dims, strides); \
uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \
uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \
out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f]; \
} \
// WHERE_OP(float, int64_t, where_i64_f32)
// WHERE_OP(double, int64_t, where_i64_f64)
// WHERE_OP(uint8_t, int64_t, where_i64_u8)
// WHERE_OP(uint32_t, int64_t, where_i64_u32)
// WHERE_OP(int64_t, int64_t, where_i64_i64)
//
// WHERE_OP(float, uint32_t, where_u32_f32)
// WHERE_OP(double, uint32_t, where_u32_f64)
// WHERE_OP(uint8_t, uint32_t, where_u32_u8)
// WHERE_OP(uint32_t, uint32_t, where_u32_u32)
// WHERE_OP(int64_t, uint32_t, where_u32_i64)
WHERE_OP(float, uint8_t, where_u8_f32)
// WHERE_OP(double, uint8_t, where_u8_f64)
// WHERE_OP(uint8_t, uint8_t, where_u8_u8)
// WHERE_OP(uint32_t, uint8_t, where_u8_u32)
// WHERE_OP(int64_t, uint8_t, where_u8_i64)

View File

@ -1,126 +0,0 @@
#include <metal_stdlib>
#include <metal_math>
#
using namespace metal;
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
template <typename T> METAL_FUNC T neg(T in){ return -in; }
template <typename T> METAL_FUNC T erf(T in){
float x = (float) in;
// constants
float a1 = 0.254829592;
float a2 = -0.284496736;
float a3 = 1.421413741;
float a4 = -1.453152027;
float a5 = 1.061405429;
float p = 0.3275911;
// Save the sign of x
int sign = 1;
if (x < 0)
sign = -1;
x = fabs(x);
// A&S formula 7.1.26
float t = 1.0/(1.0 + p*x);
float y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x);
return T(sign*y);
}
template <typename T> METAL_FUNC T id(T in){ return in; }
template <typename T> METAL_FUNC T gelu_erf(T x){ return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2); }
template <typename T> METAL_FUNC T gelu(T x){
T x_sq = x * x;
T x_cube = x_sq * x;
T alpha = x + static_cast<T>(0.044715) * x_cube;
T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
}
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
constant size_t &dim, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
if (thread_position_in_grid >= dim) { \
return; \
} \
output[thread_position_in_grid] = TYPENAME(FN(input[thread_position_in_grid])); \
}\
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
if (thread_position_in_grid >= dim) { \
return; \
} \
output[thread_position_in_grid] = TYPENAME(FN(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)])); \
}
#define UNARY_OP(NAME) \
UNARY(NAME, float, NAME##_float, NAME##_float_strided); \
UNARY(NAME, half, NAME##_half, NAME##_half_strided);
#define BFLOAT_UNARY_OP(NAME) \
UNARY(NAME, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
UNARY_OP(cos)
UNARY_OP(sin)
UNARY_OP(sqr)
UNARY_OP(sqrt)
UNARY_OP(neg)
UNARY_OP(exp)
UNARY_OP(log)
UNARY_OP(gelu)
UNARY_OP(ceil)
UNARY_OP(floor)
UNARY_OP(round)
UNARY_OP(gelu_erf)
UNARY_OP(erf)
UNARY(id, float, copy_float, copy_float_strided)
UNARY(id, half, copy_half, copy_half_strided)
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
UNARY(id, uint32_t, copy_u32, copy_u32_strided)
#if __METAL_VERSION__ >= 310
BFLOAT_UNARY_OP(cos)
BFLOAT_UNARY_OP(sin)
BFLOAT_UNARY_OP(sqr)
BFLOAT_UNARY_OP(sqrt)
BFLOAT_UNARY_OP(neg)
BFLOAT_UNARY_OP(exp)
BFLOAT_UNARY_OP(log)
BFLOAT_UNARY_OP(gelu)
BFLOAT_UNARY_OP(ceil)
BFLOAT_UNARY_OP(floor)
BFLOAT_UNARY_OP(round)
BFLOAT_UNARY_OP(gelu_erf)
BFLOAT_UNARY_OP(erf)
UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)
#endif
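The `gelu` and `erf` templates above encode, respectively, the tanh approximation of GELU (note that `M_2_SQRTPI_F * M_SQRT1_2_F` is exactly sqrt(2/pi)) and the Abramowitz and Stegun 7.1.26 polynomial approximation of the error function, with the `a1..a5` and `p` constants listed in the kernel:

```latex
\operatorname{gelu}(x) \approx \tfrac{1}{2}\,x\left(1 + \tanh\!\left(\sqrt{2/\pi}\,\bigl(x + 0.044715\,x^{3}\bigr)\right)\right)
```

```latex
\operatorname{erf}(x) \approx \operatorname{sign}(x)\left[1 - \bigl(a_{1}t + a_{2}t^{2} + a_{3}t^{3} + a_{4}t^{4} + a_{5}t^{5}\bigr)e^{-x^{2}}\right],
\qquad t = \frac{1}{1 + p\,\lvert x \rvert}
```

`gelu_erf` then uses the same `erf` to evaluate the exact form x(1 + erf(x / sqrt(2))) / 2.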

View File

@ -1,76 +0,0 @@
use candle_metal_kernels::{call_affine, Kernels};
use metal::objc::rc::autoreleasepool;
use metal::{Device, MTLResourceOptions};
use rand;
use std::any::type_name;
use std::time::Instant;
fn main() {
let device = Device::system_default().unwrap();
let kernels = Kernels::new();
let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
let f32_10k = (0..10000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
let f32_100k = (0..100000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
"dtype", "kernel", "size", "runs", "total time", "avg time"
);
// f32
run_affine_bench(&device, &kernels, &f32_1k);
run_affine_bench(&device, &kernels, &f32_10k);
run_affine_bench(&device, &kernels, &f32_100k);
}
fn run_affine_bench<T: Clone>(device: &Device, kernels: &Kernels, v: &[T]) {
let command_queue = device.new_command_queue();
let options = MTLResourceOptions::StorageModeManaged;
let iterations = 10000;
let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void,
core::mem::size_of_val(v) as u64,
options,
);
let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
let mul: f32 = 1.2345;
let add: f32 = 2.3456;
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
for _ in 0..iterations {
call_affine(
&device,
command_buffer,
&kernels,
"affine_float",
v.len(),
&input,
&mut output,
mul,
add,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
start.elapsed()
});
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
"affine",
v.len(),
iterations,
total_time,
total_time / iterations
);
}

View File

@ -1,182 +0,0 @@
use candle_metal_kernels::{binary, call_binary_contiguous, call_binary_strided, Kernels};
use half::{bf16, f16};
use metal::objc::rc::autoreleasepool;
use metal::{Device, MTLResourceOptions};
use rand;
use std::any::type_name;
use std::time::Instant;
fn main() {
let device = Device::system_default().unwrap();
let kernels = Kernels::new();
let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
let f32_10k = (0..10000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
let f32_100k = (0..100000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
let f16_map = |v: &[f32]| v.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
let f16_1k = f16_map(&f32_1k);
let f16_10k = f16_map(&f32_10k);
let f16_100k = f16_map(&f32_100k);
let bf16_map = |v: &[f32]| v.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
let bf16_1k = bf16_map(&f32_1k);
let bf16_10k = bf16_map(&f32_10k);
let bf16_100k = bf16_map(&f32_100k);
let f32_ckernels = [
binary::contiguous::add::FLOAT,
binary::contiguous::sub::FLOAT,
binary::contiguous::mul::FLOAT,
binary::contiguous::div::FLOAT,
];
let f32_skernels = [
binary::strided::add::FLOAT,
binary::strided::sub::FLOAT,
binary::strided::mul::FLOAT,
binary::strided::div::FLOAT,
];
let f16_ckernels = [
binary::contiguous::add::HALF,
binary::contiguous::sub::HALF,
binary::contiguous::mul::HALF,
binary::contiguous::div::HALF,
];
let f16_skernels = [
binary::strided::add::HALF,
binary::strided::sub::HALF,
binary::strided::mul::HALF,
binary::strided::div::HALF,
];
let bf16_ckernels = [
binary::contiguous::add::BFLOAT,
binary::contiguous::sub::BFLOAT,
binary::contiguous::mul::BFLOAT,
binary::contiguous::div::BFLOAT,
];
let bf16_skernels = [
binary::strided::add::BFLOAT,
binary::strided::sub::BFLOAT,
binary::strided::mul::BFLOAT,
binary::strided::div::BFLOAT,
];
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
"dtype", "kernel", "size", "runs", "total time", "avg time"
);
// f32
run_binary_bench(&device, &kernels, &f32_1k, f32_ckernels, f32_skernels);
run_binary_bench(&device, &kernels, &f32_10k, f32_ckernels, f32_skernels);
run_binary_bench(&device, &kernels, &f32_100k, f32_ckernels, f32_skernels);
// f16
run_binary_bench(&device, &kernels, &f16_1k, f16_ckernels, f16_skernels);
run_binary_bench(&device, &kernels, &f16_10k, f16_ckernels, f16_skernels);
run_binary_bench(&device, &kernels, &f16_100k, f16_ckernels, f16_skernels);
// bf16
run_binary_bench(&device, &kernels, &bf16_1k, bf16_ckernels, bf16_skernels);
run_binary_bench(&device, &kernels, &bf16_10k, bf16_ckernels, bf16_skernels);
run_binary_bench(&device, &kernels, &bf16_100k, bf16_ckernels, bf16_skernels);
}
fn run_binary_bench<T: Clone>(
device: &Device,
kernels: &Kernels,
v: &[T],
contiguous: [binary::contiguous::Kernel; 4],
strided: [binary::strided::Kernel; 4],
) {
let command_queue = device.new_command_queue();
let options = MTLResourceOptions::StorageModeManaged;
let iterations = 1000;
let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void,
core::mem::size_of_val(v) as u64,
options,
);
let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
// Contiguous
for kernel_name in contiguous {
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
for _ in 0..iterations {
call_binary_contiguous(
device,
&command_buffer,
kernels,
kernel_name,
v.len(),
&input,
&input,
&mut output,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
start.elapsed()
});
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
kernel_name.to_string(),
v.len(),
iterations,
total_time,
total_time / iterations
);
}
// Strided
let shape = vec![2, 5_000];
let strides = vec![2, 1];
let offset = 0;
for kernel_name in strided {
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
for _ in 0..iterations {
call_binary_strided(
device,
command_buffer,
&kernels,
kernel_name,
&shape,
&input,
&strides,
offset,
&input,
&strides,
offset,
&mut output,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
start.elapsed()
});
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
kernel_name.to_string(),
v.len(),
iterations,
total_time,
total_time / iterations
);
}
}

View File

@ -1,84 +0,0 @@
use candle_metal_kernels::{call_cast_contiguous, Kernels};
use metal::objc::rc::autoreleasepool;
use metal::{Device, MTLResourceOptions};
use rand;
use std::any::type_name;
use std::time::Instant;
fn main() {
let device = Device::system_default().unwrap();
let kernels = Kernels::new();
let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
let f32_10k = (0..10000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
let f32_100k = (0..100000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
let contiguous_kernels = ["cast_u32_f32"];
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
"dtype", "kernel", "size", "runs", "total time", "avg time"
);
// f32
run_cast_bench(&device, &kernels, &f32_1k, &contiguous_kernels);
run_cast_bench(&device, &kernels, &f32_10k, &contiguous_kernels);
run_cast_bench(&device, &kernels, &f32_100k, &contiguous_kernels);
}
fn run_cast_bench<T: Clone>(
device: &Device,
kernels: &Kernels,
v: &[T],
contiguous: &[&'static str],
) {
let command_queue = device.new_command_queue();
let options = MTLResourceOptions::StorageModeManaged;
let iterations = 1000;
let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void,
core::mem::size_of_val(v) as u64,
options,
);
let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
// Contiguous
for kernel_name in contiguous {
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
for _ in 0..iterations {
call_cast_contiguous(
device,
&command_buffer,
kernels,
kernel_name,
v.len(),
&input,
&mut output,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
start.elapsed()
});
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
kernel_name.to_string(),
v.len(),
iterations,
total_time,
total_time / iterations
);
}
// Strided?
}

View File

@ -1,197 +0,0 @@
use candle_metal_kernels::{call_unary_contiguous, call_unary_strided, unary, Kernels};
use half::{bf16, f16};
use metal::objc::rc::autoreleasepool;
use metal::{Device, MTLResourceOptions};
use rand;
use std::any::type_name;
use std::time::Instant;
fn main() {
let device = Device::system_default().unwrap();
let kernels = Kernels::new();
let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
let f32_10k = (0..10000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
let f32_100k = (0..100000)
.map(|_| rand::random::<f32>())
.collect::<Vec<_>>();
let f16_map = |v: &[f32]| v.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
let f16_1k = f16_map(&f32_1k);
let f16_10k = f16_map(&f32_10k);
let f16_100k = f16_map(&f32_100k);
let bf16_map = |v: &[f32]| v.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
let bf16_1k = bf16_map(&f32_1k);
let bf16_10k = bf16_map(&f32_10k);
let bf16_100k = bf16_map(&f32_100k);
let f32_ckernels = [
unary::contiguous::sin::FLOAT,
unary::contiguous::cos::FLOAT,
unary::contiguous::exp::FLOAT,
unary::contiguous::sqr::FLOAT,
unary::contiguous::sqrt::FLOAT,
unary::contiguous::neg::FLOAT,
unary::contiguous::copy::FLOAT,
];
let f32_skernels = [
unary::strided::sin::FLOAT,
unary::strided::cos::FLOAT,
unary::strided::exp::FLOAT,
unary::strided::sqr::FLOAT,
unary::strided::sqrt::FLOAT,
unary::strided::neg::FLOAT,
unary::strided::copy::FLOAT,
];
let f16_ckernels = [
unary::contiguous::sin::HALF,
unary::contiguous::cos::HALF,
unary::contiguous::exp::HALF,
unary::contiguous::sqr::HALF,
unary::contiguous::sqrt::HALF,
unary::contiguous::neg::HALF,
unary::contiguous::copy::HALF,
];
let f16_skernels = [
unary::strided::sin::HALF,
unary::strided::cos::HALF,
unary::strided::exp::HALF,
unary::strided::sqr::HALF,
unary::strided::sqrt::HALF,
unary::strided::neg::HALF,
unary::strided::copy::HALF,
];
let bf16_ckernels = [
unary::contiguous::sin::BFLOAT,
unary::contiguous::cos::BFLOAT,
unary::contiguous::exp::BFLOAT,
unary::contiguous::sqr::BFLOAT,
unary::contiguous::sqrt::BFLOAT,
unary::contiguous::neg::BFLOAT,
unary::contiguous::copy::BFLOAT,
];
let bf16_skernels = [
unary::strided::sin::BFLOAT,
unary::strided::cos::BFLOAT,
unary::strided::exp::BFLOAT,
unary::strided::sqr::BFLOAT,
unary::strided::sqrt::BFLOAT,
unary::strided::neg::BFLOAT,
unary::strided::copy::BFLOAT,
];
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
"dtype", "kernel", "size", "runs", "total time", "avg time"
);
// f32
run_unary_bench(&device, &kernels, &f32_1k, f32_ckernels, f32_skernels);
run_unary_bench(&device, &kernels, &f32_10k, f32_ckernels, f32_skernels);
run_unary_bench(&device, &kernels, &f32_100k, f32_ckernels, f32_skernels);
// f16
run_unary_bench(&device, &kernels, &f16_1k, f16_ckernels, f16_skernels);
run_unary_bench(&device, &kernels, &f16_10k, f16_ckernels, f16_skernels);
run_unary_bench(&device, &kernels, &f16_100k, f16_ckernels, f16_skernels);
// bf16
run_unary_bench(&device, &kernels, &bf16_1k, bf16_ckernels, bf16_skernels);
run_unary_bench(&device, &kernels, &bf16_10k, bf16_ckernels, bf16_skernels);
run_unary_bench(&device, &kernels, &bf16_100k, bf16_ckernels, bf16_skernels);
}
fn run_unary_bench<T: Clone>(
device: &Device,
kernels: &Kernels,
v: &[T],
contiguous: [unary::contiguous::Kernel; 7],
strided: [unary::strided::Kernel; 7],
) {
let command_queue = device.new_command_queue();
let options = MTLResourceOptions::StorageModeManaged;
let iterations = 10000;
let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void,
core::mem::size_of_val(v) as u64,
options,
);
let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
// Contiguous
for kernel_name in contiguous {
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
for _ in 0..iterations {
call_unary_contiguous(
device,
&command_buffer,
kernels,
kernel_name,
v.len(),
&input,
&mut output,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
start.elapsed()
});
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
kernel_name.0,
v.len(),
iterations,
total_time,
total_time / iterations
);
}
// Strided
let shape = vec![2, 5_000];
let strides = vec![2, 1];
let offset = 0;
for kernel_name in &strided {
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
for _ in 0..iterations {
call_unary_strided(
device,
command_buffer,
&kernels,
kernel_name,
&shape,
&input,
&strides,
offset,
&mut output,
0,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
start.elapsed()
});
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
kernel_name.0,
v.len(),
iterations,
total_time,
total_time / iterations
);
}
}

View File

@ -19,7 +19,6 @@ num-traits = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
-candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }

[dev-dependencies]
anyhow = { workspace = true }
@ -29,5 +28,5 @@ clap = { workspace = true }
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate"]
cuda = ["candle/cuda"]
+metal = ["candle/metal"]
mkl = ["dep:intel-mkl-src", "candle/mkl"]
-metal = ["candle/metal", "dep:candle-metal-kernels"]

View File

@ -6,16 +6,14 @@ use serde::Deserialize;
pub enum Activation {
#[default]
Gelu,
+#[serde(rename = "gated-gelu")]
NewGelu,
Relu,
Relu2,
Relu6,
Silu,
Sigmoid,
-HardSigmoid,
-Swiglu,
Swish,
-HardSwish,
Elu(f64),
LeakyRelu(f64),
}
@ -31,10 +29,7 @@ impl super::Module for Activation {
Self::Relu6 => xs.clamp(0f32, 6f32),
Self::Silu => crate::ops::silu(xs),
Self::Sigmoid => crate::ops::sigmoid(xs),
-Self::HardSigmoid => crate::ops::hard_sigmoid(xs),
-Self::Swiglu => crate::ops::swiglu(xs),
Self::Swish => xs * crate::ops::sigmoid(xs)?,
-Self::HardSwish => xs * crate::ops::hard_sigmoid(xs)?,
&Self::Elu(alpha) => xs.elu(alpha),
&Self::LeakyRelu(negative_slope) => crate::ops::leaky_relu(xs, negative_slope),
}

View File

@ -70,67 +70,6 @@ impl crate::Module for Conv1d {
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ConvTranspose1dConfig {
pub padding: usize,
pub output_padding: usize,
pub stride: usize,
pub dilation: usize,
// TODO: support groups.
}
impl Default for ConvTranspose1dConfig {
fn default() -> Self {
Self {
padding: 0,
output_padding: 0,
stride: 1,
dilation: 1,
}
}
}
#[derive(Clone, Debug)]
pub struct ConvTranspose1d {
weight: Tensor,
bias: Option<Tensor>,
config: ConvTranspose1dConfig,
}
impl ConvTranspose1d {
pub fn new(weight: Tensor, bias: Option<Tensor>, config: ConvTranspose1dConfig) -> Self {
Self {
weight,
bias,
config,
}
}
pub fn config(&self) -> &ConvTranspose1dConfig {
&self.config
}
}
impl crate::Module for ConvTranspose1d {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = x.conv_transpose1d(
&self.weight,
self.config.padding,
self.config.output_padding,
self.config.stride,
self.config.dilation,
)?;
match &self.bias {
None => Ok(x),
Some(bias) => {
let b = bias.dims1()?;
let bias = bias.reshape((1, b, 1, 1))?;
Ok(x.broadcast_add(&bias)?)
}
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Conv2dConfig {
pub padding: usize,
@ -302,39 +241,6 @@ pub fn conv1d(
Ok(Conv1d::new(ws, Some(bs), cfg))
}
pub fn conv_transpose1d(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
cfg: ConvTranspose1dConfig,
vb: crate::VarBuilder,
) -> Result<ConvTranspose1d> {
let bound = 1. / (out_channels as f64 * kernel_size as f64).sqrt();
let init = crate::Init::Uniform {
lo: -bound,
up: bound,
};
let ws = vb.get_with_hints((in_channels, out_channels, kernel_size), "weight", init)?;
let bs = vb.get_with_hints(out_channels, "bias", init)?;
Ok(ConvTranspose1d::new(ws, Some(bs), cfg))
}
pub fn conv_transpose1d_no_bias(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
cfg: ConvTranspose1dConfig,
vb: crate::VarBuilder,
) -> Result<ConvTranspose1d> {
let bound = 1. / (out_channels as f64 * kernel_size as f64).sqrt();
let init = crate::Init::Uniform {
lo: -bound,
up: bound,
};
let ws = vb.get_with_hints((in_channels, out_channels, kernel_size), "weight", init)?;
Ok(ConvTranspose1d::new(ws, None, cfg))
}
pub fn conv2d(
in_channels: usize,
out_channels: usize,

View File

@ -95,14 +95,6 @@ impl LayerNorm {
eps,
}
}
pub fn weight(&self) -> &Tensor {
&self.weight
}
pub fn bias(&self) -> Option<&Tensor> {
self.bias.as_ref()
}
}
impl crate::Module for LayerNorm {

View File

@ -39,21 +39,11 @@ pub fn silu(xs: &Tensor) -> Result<Tensor> {
xs / (xs.neg()?.exp()? + 1.0)?
}
-pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
-let xs = xs.chunk(2, candle::D::Minus1)?;
-crate::ops::silu(&xs[0])? * &xs[1]
-}
pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
// TODO: Should we have a specialized op for this?
(xs.neg()?.exp()? + 1.0)?.recip()
}
-pub fn hard_sigmoid(xs: &Tensor) -> Result<Tensor> {
-// TODO: Should we have a specialized op for this?
-((xs + 3.0)? / 6.0)?.clamp(0f32, 1f32)
-}
pub fn leaky_relu(xs: &Tensor, negative_slope: f64) -> Result<Tensor> {
let zeros = xs.zeros_like()?;
xs.maximum(&zeros)? + xs.minimum(&zeros)? * negative_slope
@ -200,7 +190,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
device: dev.clone(),
};
Ok((dst, layout.shape().clone()))
}
#[cfg(feature = "metal")]
fn metal_fwd(
@ -208,29 +198,8 @@ impl candle::CustomOp1 for SoftmaxLastDim {
storage: &candle::MetalStorage,
layout: &Layout,
) -> Result<(candle::MetalStorage, Shape)> {
-use candle::backend::{BackendStorage};
-let device = storage.device();
-let command_buffer = device.command_buffer();
-let kernels = device.kernels();
-let name = "softmax_float";
-assert!(layout.is_contiguous());
-assert!(layout.start_offset() == 0);
-let last_dim = layout.dims()[layout.shape().rank() - 1];
-let elem_count = layout.shape().elem_count();
-let mut output = device.new_buffer(elem_count, storage.dtype());
-candle_metal_kernels::call_last_softmax(
-device.metal_device(),
-&command_buffer,
-&kernels,
-name,
-elem_count,
-last_dim,
-storage.buffer(),
-&mut output,
-)
-.unwrap();
-let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
-Ok((newstorage, layout.shape().clone()))
+println!("TODO softmax-last-dim");
+Ok((storage.clone(), layout.shape().clone()))
}
}

View File

@ -1,23 +0,0 @@
[package]
name = "candle-onnx"
version = "0.3.0"
edition = "2021"
description = "ONNX support for Candle"
repository = "https://github.com/huggingface/candle"
keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
prost = "0.12.1"
[build-dependencies]
prost-build = "0.12.1"
[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
clap = { version = "4.2.4", features = ["derive"] }

View File

@ -1,21 +0,0 @@
# candle-onnx
This crate adds ONNX support to candle
## FAQ
#### Missing protoc installation when compiling candle-onnx
The candle-onnx dependency prost-build no longer comes bundled with prost
binaries. This could cause the following error when attempting to compile
candle-onnx:
```
error: failed to run custom build command for `candle-onnx`
Caused by: // (...)
Could not find `protoc` installation and this build crate cannot proceed without this knowledge.
```
To fix this issue install protoc on your system and make it available in your
system `PATH`. See the [protoc
documentation](https://grpc.io/docs/protoc-installation/) for more information.

View File

@ -1,6 +0,0 @@
use std::io::Result;
fn main() -> Result<()> {
prost_build::compile_protos(&["src/onnx.proto3"], &["src/"])?;
Ok(())
}
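prost-build does not ship a bundled protoc binary, so the build script above relies on a `protoc` found on the PATH (or pointed to by the PROTOC environment variable), which is what the README FAQ entry is about. A sketch of a build script that surfaces that requirement a little earlier, not the crate's actual build.rs:

```rust
use std::io::Result;

fn main() -> Result<()> {
    // prost-build shells out to `protoc`; warn early if the PROTOC override is
    // unset so a missing system installation is easier to diagnose.
    if std::env::var_os("PROTOC").is_none() {
        println!("cargo:warning=PROTOC is not set; prost-build will look for `protoc` on PATH");
    }
    prost_build::compile_protos(&["src/onnx.proto3"], &["src/"])?;
    Ok(())
}
```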

View File

@ -1,755 +0,0 @@
use crate::onnx;
use crate::onnx::attribute_proto::AttributeType;
use crate::onnx::tensor_proto::DataType;
use candle::{bail, DType, Device, Result, Tensor};
use std::collections::HashMap;
pub type Value = Tensor;
pub fn dtype(dt: DataType) -> Option<DType> {
match dt {
DataType::Uint8 => Some(DType::U8),
DataType::Uint32 => Some(DType::U32),
DataType::Int64 => Some(DType::I64),
DataType::Float16 => Some(DType::F16),
DataType::Float => Some(DType::F32),
DataType::Double => Some(DType::F64),
_ => None,
}
}
trait Attr {
const TYPE: AttributeType;
fn get(attr: &onnx::AttributeProto) -> Result<&Self>;
}
impl Attr for i64 {
const TYPE: AttributeType = AttributeType::Int;
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
Ok(&attr.i)
}
}
impl Attr for f32 {
const TYPE: AttributeType = AttributeType::Float;
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
Ok(&attr.f)
}
}
impl Attr for [i64] {
const TYPE: AttributeType = AttributeType::Ints;
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
Ok(attr.ints.as_slice())
}
}
impl Attr for str {
const TYPE: AttributeType = AttributeType::String;
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
std::str::from_utf8(&attr.s).map_err(candle::Error::wrap)
}
}
fn get_attr_<'a>(node: &'a onnx::NodeProto, name: &str) -> Result<&'a onnx::AttributeProto> {
match node.attribute.iter().find(|attr| attr.name == name) {
None => {
bail!(
"cannot find the '{name}' attribute in '{}' for {}",
node.op_type,
node.name
)
}
Some(dt) => Ok(dt),
}
}
fn get_attr<'a, T: Attr + ?Sized>(node: &'a onnx::NodeProto, name: &str) -> Result<&'a T> {
let attr = get_attr_(node, name)?;
if attr.r#type() != T::TYPE {
bail!(
"unsupported type {:?} for '{name}' attribute in '{}' for {}",
attr.r#type,
node.op_type,
node.name
)
}
T::get(attr)
}
fn get_attr_opt<'a, T: Attr + ?Sized>(
node: &'a onnx::NodeProto,
name: &str,
) -> Result<Option<&'a T>> {
match node.attribute.iter().find(|attr| attr.name == name) {
None => Ok(None),
Some(attr) => {
if attr.r#type() != T::TYPE {
bail!(
"unsupported type {:?} for '{name}' attribute in '{}' for {}",
attr.r#type,
node.op_type,
node.name
)
}
let val = T::get(attr)?;
Ok(Some(val))
}
}
}
pub fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
match DataType::try_from(t.data_type) {
Ok(DataType::Int32) => {
if t.int32_data.is_empty() {
let len = t.raw_data.len() / 4;
let data: &[i32] =
unsafe { std::slice::from_raw_parts(t.raw_data.as_ptr() as *const i32, len) };
let data = data.iter().map(|v| *v as i64).collect::<Vec<_>>();
Tensor::from_vec(data, len, &Device::Cpu)
} else {
let data = t.int32_data.iter().map(|v| *v as i64).collect::<Vec<_>>();
Tensor::from_vec(data, t.int32_data.len(), &Device::Cpu)
}
}
Ok(dt) => match dtype(dt) {
Some(dt) => {
if dt == DType::F32 && !t.float_data.is_empty() {
Tensor::from_slice(&t.float_data, dims.as_slice(), &Device::Cpu)
} else if dt == DType::F64 && !t.double_data.is_empty() {
Tensor::from_slice(&t.double_data, dims.as_slice(), &Device::Cpu)
} else if dt == DType::I64 && !t.int64_data.is_empty() {
Tensor::from_slice(&t.int64_data, dims.as_slice(), &Device::Cpu)
} else {
Tensor::from_raw_buffer(
t.raw_data.as_slice(),
dt,
dims.as_slice(),
&Device::Cpu,
)
}
}
None => {
bail!("unsupported 'value' data-type {dt:?} for {name}")
}
},
Err(_) => {
bail!("unsupported 'value' data-type {} for {name}", t.data_type,)
}
}
}
// This function provides a direct evaluation of the proto.
// Longer-term, we should first convert the proto to an intermediate representation of the compute
// graph so as to make multiple evaluations more efficient.
// An example upside of this would be to remove intermediary values when they are not needed
// anymore.
pub fn simple_eval(
model: &onnx::ModelProto,
inputs: HashMap<String, Value>,
) -> Result<HashMap<String, Value>> {
let graph = match &model.graph {
None => bail!("no graph defined in proto"),
Some(graph) => graph,
};
let mut values = inputs;
for t in graph.initializer.iter() {
let tensor = get_tensor(t, t.name.as_str())?;
values.insert(t.name.to_string(), tensor);
}
for input in graph.input.iter() {
let input_type = match &input.r#type {
Some(input_type) => input_type,
None => continue,
};
let input_type = match &input_type.value {
Some(input_type) => input_type,
None => continue,
};
let tensor_type = match input_type {
onnx::type_proto::Value::TensorType(tt) => tt,
_ => continue,
};
let tensor = match values.get(&input.name) {
None => bail!("missing input {}", input.name),
Some(tensor) => tensor,
};
let dt = match DataType::try_from(tensor_type.elem_type) {
Ok(dt) => match dtype(dt) {
Some(dt) => dt,
None => {
bail!("unsupported 'value' data-type {dt:?} for {}", input.name)
}
},
type_ => bail!("unsupported input type {type_:?}"),
};
match &tensor_type.shape {
None => continue,
Some(shape) => {
if shape.dim.len() != tensor.rank() {
bail!(
"unexpected rank for {}, got {:?}, expected {:?}",
input.name,
shape.dim,
tensor.shape()
)
}
for (idx, (d, &dim)) in shape.dim.iter().zip(tensor.dims().iter()).enumerate() {
match &d.value {
Some(onnx::tensor_shape_proto::dimension::Value::DimValue(v)) => {
if *v as usize != dim {
bail!(
"unexpected dim {idx} for {}, got {:?}, expected {:?}",
input.name,
shape.dim,
tensor.shape()
)
}
}
// We do not check equality constraints for the DimParam dimensions for now.
Some(onnx::tensor_shape_proto::dimension::Value::DimParam(_)) | None => (),
}
}
}
};
if dt != tensor.dtype() {
bail!(
"unexpected dtype for {}, got {:?}, expected {dt:?}",
input.name,
tensor.dtype()
)
}
}
// The nodes are topologically sorted so we can just process them in order.
for node in graph.node.iter() {
let get = |input_name: &str| match values.get(input_name) {
Some(value) => Ok(value),
None => bail!("cannot find {input_name} for op {}", node.name),
};
// TODO: Validate node.input for each operator.
match node.op_type.as_str() {
"Add" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
let output = input0.broadcast_add(input1)?;
values.insert(node.output[0].clone(), output);
}
"Sub" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
let output = input0.broadcast_sub(input1)?;
values.insert(node.output[0].clone(), output);
}
"Mul" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
let output = input0.broadcast_mul(input1)?;
values.insert(node.output[0].clone(), output);
}
"Div" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
let output = input0.broadcast_div(input1)?;
values.insert(node.output[0].clone(), output);
}
"Equal" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
let output = input0.broadcast_eq(input1)?;
values.insert(node.output[0].clone(), output);
}
"Not" => {
let xs = get(&node.input[0])?;
let xs = xs.eq(&xs.zeros_like()?)?;
values.insert(node.output[0].clone(), xs);
}
"MatMul" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
let output = input0.broadcast_matmul(input1)?;
values.insert(node.output[0].clone(), output);
}
"Reshape" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?.to_vec1::<i64>()?;
// TODO: Check that there is at most a single -1 or 0, handle other neg values.
let mut other_than_minus1 = 1usize;
for &v in input1.iter() {
if v != -1 && v != 0 {
other_than_minus1 *= v as usize
}
}
let input1 = input1
.iter()
.enumerate()
.map(|(idx, &v)| match v {
-1 => Ok(input0.elem_count() / other_than_minus1),
0 => input0.dim(idx),
_ => Ok(v as usize),
})
.collect::<Result<Vec<usize>>>()?;
let output = input0.reshape(input1)?;
values.insert(node.output[0].clone(), output);
}
"LogSoftmax" => {
let input = get(&node.input[0])?;
let output = match get_attr_opt::<i64>(node, "axis")? {
None => candle_nn::ops::softmax_last_dim(input)?,
Some(&axis) => {
let axis = input.normalize_axis(axis)?;
candle_nn::ops::log_softmax(input, axis)?
}
};
values.insert(node.output[0].clone(), output);
}
"Softmax" => {
let input = get(&node.input[0])?;
let output = match get_attr_opt::<i64>(node, "axis")? {
None => candle_nn::ops::softmax_last_dim(input)?,
Some(&axis) => {
let axis = input.normalize_axis(axis)?;
candle_nn::ops::softmax(input, axis)?
}
};
values.insert(node.output[0].clone(), output);
}
"Transpose" => {
let input = get(&node.input[0])?;
let output = match get_attr_opt::<[i64]>(node, "perm")? {
None => input.t()?,
Some(perm) => {
let perm = perm.iter().map(|&v| v as usize).collect::<Vec<_>>();
input.permute(perm)?
}
};
values.insert(node.output[0].clone(), output);
}
"Dropout" => {
let input = get(&node.input[0])?;
// Do not apply dropout at the moment, consider that we're only doing inference.
values.insert(node.output[0].clone(), input.clone());
}
"MaxPool" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#MaxPool
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
let kernel_shape = get_attr::<[i64]>(node, "kernel_shape")?;
let pads = get_attr_opt::<[i64]>(node, "pads")?;
let strides = get_attr_opt::<[i64]>(node, "strides")?;
let auto_pad = get_attr_opt::<str>(node, "auto_pad")?;
match auto_pad {
None | Some("NOTSET") => (),
Some(s) => bail!("unsupported auto_pad {s}"),
};
if let Some(d) = dilations {
if d.iter().any(|&v| v != 1) {
bail!("MaxPool with dilation != 1, {dilations:?}")
}
}
if let Some(d) = pads {
if d.iter().any(|&v| v != 0) {
bail!("MaxPool with pads != 0, {pads:?}")
}
}
let xs = get(&node.input[0])?;
let (k1, k2) = match kernel_shape {
[k1, k2] => (*k1 as usize, *k2 as usize),
_ => bail!("only 2d MaxPool is supported, kernel shape {kernel_shape:?}"),
};
let ys = match strides {
None => xs.max_pool2d((k1, k2))?,
Some([s1, s2]) => {
xs.max_pool2d_with_stride((k1, k2), (*s1 as usize, *s2 as usize))?
}
Some(strides) => bail!("only 2d MaxPool is supported, strides {strides:?}"),
};
values.insert(node.output[0].clone(), ys);
}
"AveragePool" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#AveragePool
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
let kernel_shape = get_attr::<[i64]>(node, "kernel_shape")?;
let pads = get_attr_opt::<[i64]>(node, "pads")?;
let strides = get_attr_opt::<[i64]>(node, "strides")?;
let auto_pad = get_attr_opt::<str>(node, "auto_pad")?;
match auto_pad {
None | Some("NOTSET") => (),
Some(s) => bail!("unsupported auto_pad {s}"),
};
if let Some(d) = dilations {
if d.iter().any(|&v| v != 1) {
bail!("AvgPool with dilation != 1, {dilations:?}")
}
}
if let Some(d) = pads {
if d.iter().any(|&v| v != 0) {
bail!("AvgPool with pads != 0, {pads:?}")
}
}
let xs = get(&node.input[0])?;
let (k1, k2) = match kernel_shape {
[k1, k2] => (*k1 as usize, *k2 as usize),
_ => bail!("only 2d AvgPool is supported, kernel shape {kernel_shape:?}"),
};
let ys = match strides {
None => xs.avg_pool2d((k1, k2))?,
Some([s1, s2]) => {
xs.avg_pool2d_with_stride((k1, k2), (*s1 as usize, *s2 as usize))?
}
Some(strides) => bail!("only 2d AvgPool is supported, strides {strides:?}"),
};
values.insert(node.output[0].clone(), ys);
}
"BatchNormalization" => {
let training_mode = get_attr_opt::<i64>(node, "training_mode")?;
if training_mode.copied().unwrap_or(0) != 0 {
bail!("training mode is not supported for BatchNorm")
}
let eps = get_attr_opt::<f32>(node, "epsilon")?
.copied()
.unwrap_or(1e-5);
let xs = get(&node.input[0])?;
let weight = get(&node.input[1])?;
let bias = get(&node.input[2])?;
let running_mean = get(&node.input[3])?;
let running_var = get(&node.input[4])?;
let target_shape: Vec<usize> = xs
.dims()
.iter()
.enumerate()
.map(|(idx, v)| if idx == 1 { *v } else { 1 })
.collect();
let target_shape = target_shape.as_slice();
let xs = xs
.broadcast_sub(&running_mean.reshape(target_shape)?)?
.broadcast_div(&(running_var.reshape(target_shape)? + eps as f64)?.sqrt()?)?;
let weight = weight.reshape(target_shape)?;
let bias = bias.reshape(target_shape)?;
let xs = xs.broadcast_mul(&weight)?.broadcast_add(&bias)?;
values.insert(node.output[0].clone(), xs);
}
"Squeeze" => {
let xs = get(&node.input[0])?;
let mut axes = if node.input.len() <= 1 {
// contract all the dimensions with size 1 except the batch dim.
xs.dims()
.iter()
.enumerate()
.flat_map(|(idx, &s)| if s == 1 && idx > 0 { Some(idx) } else { None })
.collect()
} else {
get(&node.input[1])?
.to_vec1::<i64>()?
.iter()
.map(|&i| xs.normalize_axis(i))
.collect::<Result<Vec<_>>>()?
};
axes.sort();
let mut xs = xs.clone();
for &axis in axes.iter().rev() {
xs = xs.squeeze(axis)?
}
values.insert(node.output[0].clone(), xs);
}
"ConstantOfShape" => {
let dims = get(&node.input[0])?;
let shape = dims
.to_vec1::<i64>()?
.into_iter()
.map(|v| v as usize)
.collect::<Vec<_>>();
let xs = Tensor::zeros(shape, DType::F32, dims.device())?;
values.insert(node.output[0].clone(), xs);
}
"Unsqueeze" => {
let xs = get(&node.input[0])?;
let axes = match get_attr_opt::<[i64]>(node, "axes")? {
Some(axis) => axis.to_vec(),
None => get(&node.input[1])?.to_vec1::<i64>()?,
};
let mut axes = axes
.iter()
.map(|&i| {
if i == xs.rank() as i64 {
Ok(xs.rank())
} else {
xs.normalize_axis(i)
}
})
.collect::<Result<Vec<_>>>()?;
axes.sort();
let mut xs = xs.clone();
for &axis in axes.iter().rev() {
xs = xs.unsqueeze(axis)?
}
values.insert(node.output[0].clone(), xs);
}
"Clip" => {
let xs = get(&node.input[0])?;
let xs = if node.input.len() >= 2 {
let mins = get(&node.input[1])?;
xs.broadcast_maximum(mins)?
} else {
xs.clone()
};
let xs = if node.input.len() >= 3 {
let maxs = get(&node.input[2])?;
xs.broadcast_minimum(maxs)?
} else {
xs.clone()
};
values.insert(node.output[0].clone(), xs);
}
"Gather" => {
let xs = get(&node.input[0])?;
let indices = get(&node.input[1])?;
let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
let axis = xs.normalize_axis(axis)?;
// TODO: Provide an op to handle the ONNX generalized gather op ideally in a
// differentiable way.
let xs = if indices.rank() == 0 {
let index = indices.to_vec0::<i64>()? as usize;
xs.narrow(axis, index, 1)?.squeeze(axis)?
} else {
todo!("implement gather for {xs:?} {indices:?} axis {axis}")
};
values.insert(node.output[0].clone(), xs);
}
"Shape" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape
let xs = get(&node.input[0])?;
let start = get_attr_opt::<i64>(node, "start")?.copied().unwrap_or(0);
let end = get_attr_opt::<i64>(node, "end")?.copied().unwrap_or(-1);
let start = xs.normalize_axis(start)?;
let end = xs.normalize_axis(end)?;
let mut dims = vec![];
for idx in start..=end {
dims.push(xs.dim(idx)? as i64)
}
let dims = Tensor::from_vec(dims, xs.rank(), xs.device())?;
values.insert(node.output[0].clone(), dims);
}
"Conv" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
let groups = get_attr_opt::<i64>(node, "group")?.copied().unwrap_or(1);
let _kernel_shape = get_attr_opt::<[i64]>(node, "kernel_shape")?;
let pads = get_attr_opt::<[i64]>(node, "pads")?;
let strides = get_attr_opt::<[i64]>(node, "strides")?;
let auto_pad = get_attr_opt::<str>(node, "auto_pad")?;
match auto_pad {
None | Some("NOTSET") => (),
Some(s) => bail!("unsupported auto_pad {s}"),
};
let xs = get(&node.input[0])?;
let ws = get(&node.input[1])?;
let ys = match ws.rank() {
3 => {
let (pads, xs) = match pads {
None => (0, xs.clone()),
Some([p]) => (*p as usize, xs.clone()),
Some([p1, p2]) => {
if p1 != p2 {
(0usize, xs.pad_with_zeros(2, *p1 as usize, *p2 as usize)?)
} else {
(*p1 as usize, xs.clone())
}
}
Some(pads) => {
bail!("more pads than expected in conv1d {pads:?} {}", node.name)
}
};
let strides = match strides {
None => 1,
Some([p]) => *p as usize,
Some(s) => {
bail!("more strides than expected in conv1d {s:?} {}", node.name)
}
};
let dilations = match dilations {
None => 1,
Some([p]) => *p as usize,
Some(s) => {
bail!("more dilations than expected in conv1d {s:?} {}", node.name)
}
};
xs.conv1d(ws, pads, strides, dilations, groups as usize)?
}
4 => {
let (pads, xs) = match pads {
None => (0, xs.clone()),
Some([p]) => (*p as usize, xs.clone()),
Some(&[p1, p2, p3, p4]) => {
let p1 = p1 as usize;
let p2 = p2 as usize;
let p3 = p3 as usize;
let p4 = p4 as usize;
if p1 != p2 || p1 != p3 || p1 != p4 {
(0, xs.pad_with_zeros(2, p1, p3)?.pad_with_zeros(3, p2, p4)?)
} else {
(p1, xs.clone())
}
}
Some(pads) => {
bail!("more pads than expected in conv2d {pads:?} {}", node.name)
}
};
let strides = match strides {
None => 1,
Some([p]) => *p as usize,
Some([p1, p2]) => {
if p1 != p2 {
bail!(
"strides have to be the same on both axis {pads:?} {}",
node.name
)
}
*p1 as usize
}
Some(s) => {
bail!("more strides than expected in conv2d {s:?} {}", node.name)
}
};
let dilations = match dilations {
None => 1,
Some([p]) => *p as usize,
Some([p1, p2]) => {
if p1 != p2 {
bail!(
"dilations have to be the same on both axis {pads:?} {}",
node.name
)
}
*p1 as usize
}
Some(s) => {
bail!("more dilations than expected in conv2d {s:?} {}", node.name)
}
};
xs.conv2d(ws, pads, strides, dilations, groups as usize)?
}
rank => bail!(
"unsupported rank for weight matrix {rank} in conv {}",
node.name
),
};
let ys = if node.input.len() > 2 {
let bs = get(&node.input[2])?;
let mut bs_shape = vec![1; ys.rank()];
bs_shape[1] = bs.elem_count();
ys.broadcast_add(&bs.reshape(bs_shape)?)?
} else {
ys
};
values.insert(node.output[0].clone(), ys);
}
"Concat" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Concat
let inputs = node
.input
.iter()
.map(|n| Ok(get(n.as_str())?.clone()))
.collect::<Result<Vec<Value>>>()?;
let axis: i64 = *get_attr(node, "axis")?;
if inputs.is_empty() {
bail!("empty concat")
};
let axis = inputs[0].normalize_axis(axis)?;
let output = Tensor::cat(&inputs, axis)?;
values.insert(node.output[0].clone(), output);
}
"Abs" => {
let input = get(&node.input[0])?;
let output = input.abs()?;
values.insert(node.output[0].clone(), output);
}
"Cos" => {
let input = get(&node.input[0])?;
let output = input.cos()?;
values.insert(node.output[0].clone(), output);
}
"Sin" => {
let input = get(&node.input[0])?;
let output = input.sin()?;
values.insert(node.output[0].clone(), output);
}
"Neg" => {
let input = get(&node.input[0])?;
let output = input.neg()?;
values.insert(node.output[0].clone(), output);
}
"Erf" => {
let input = get(&node.input[0])?;
let output = input.erf()?;
values.insert(node.output[0].clone(), output);
}
"Tanh" => {
let input = get(&node.input[0])?;
let output = input.tanh()?;
values.insert(node.output[0].clone(), output);
}
"Sigmoid" => {
let input = get(&node.input[0])?;
let output = candle_nn::ops::sigmoid(input)?;
values.insert(node.output[0].clone(), output);
}
"Gelu" => {
let input = get(&node.input[0])?;
let output = input.gelu_erf()?;
values.insert(node.output[0].clone(), output);
}
"Relu" => {
let input = get(&node.input[0])?;
let output = input.relu()?;
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Constant
"Constant" => {
let value = match node.attribute.iter().find(|attr| attr.name == "value") {
None => {
// TODO: support sparse_value etc.
bail!("cannot find 'value' attr in 'Constant' for {}", node.name)
}
Some(value) => value,
};
let output = match value.r#type() {
AttributeType::Tensor => {
let t = value.t.as_ref().unwrap();
get_tensor(t, &node.name)?
}
rtype => bail!("unsupported 'value' type {rtype:?} for {}", node.name),
};
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast
"Cast" => {
let input = get(&node.input[0])?;
let dt: i64 = *get_attr(node, "to")?;
let dtype = match DataType::try_from(dt as i32) {
Ok(DataType::Int32) => DType::I64,
Ok(dt) => match dtype(dt) {
Some(dt) => dt,
None => {
bail!("unsupported 'to' value {dt:?} for cast {}", node.name)
}
},
Err(_) => {
bail!("unsupported 'to' value {dt:?} for cast {}", node.name)
}
};
let output = input.to_dtype(dtype)?;
values.insert(node.output[0].clone(), output);
}
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
}
}
graph
.output
.iter()
.map(|output| match values.remove(&output.name) {
None => bail!("cannot find output {}", output.name),
Some(value) => Ok((output.name.clone(), value)),
})
.collect()
}

View File

@ -1,14 +0,0 @@
use candle::Result;
use prost::Message;
pub mod onnx {
include!(concat!(env!("OUT_DIR"), "/onnx.rs"));
}
pub mod eval;
pub use eval::{dtype, simple_eval};
pub fn read_file<P: AsRef<std::path::Path>>(p: P) -> Result<onnx::ModelProto> {
let buf = std::fs::read(p)?;
onnx::ModelProto::decode(buf.as_slice()).map_err(candle::Error::wrap)
}
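Taken together, the removed `lib.rs` and `eval.rs` expose a small API: `read_file` decodes a `ModelProto` and `simple_eval` walks its topologically sorted graph directly. A minimal usage sketch, assuming a model file at a placeholder path whose single graph input is named "input":

```rust
use std::collections::HashMap;

use candle::{Device, Tensor};

fn main() -> anyhow::Result<()> {
    // Placeholder path and input name; they have to match the actual graph.
    let model = candle_onnx::read_file("model.onnx")?;
    let input = Tensor::from_vec(vec![1f32, 2., 3., 4.], (1, 4), &Device::Cpu)?;
    let inputs = HashMap::from([("input".to_string(), input)]);
    // Initializers come from the proto itself; only the graph inputs are provided here.
    let outputs = candle_onnx::simple_eval(&model, inputs)?;
    for (name, value) in outputs.iter() {
        println!("{name}: {value}");
    }
    Ok(())
}
```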

View File

@ -1,836 +0,0 @@
//
// WARNING: This file is automatically generated! Please edit onnx.in.proto.
//
// SPDX-License-Identifier: Apache-2.0
syntax = "proto3";
package onnx;
// Overview
//
// ONNX is an open specification that is comprised of the following components:
//
// 1) A definition of an extensible computation graph model.
// 2) Definitions of standard data types.
// 3) Definitions of built-in operators.
//
// This document describes the syntax of models and their computation graphs,
// as well as the standard data types. Together, they are referred to as the ONNX
// Intermediate Representation, or 'IR' for short.
//
// The normative semantic specification of the ONNX IR is found in docs/IR.md.
// Definitions of the built-in neural network operators may be found in docs/Operators.md.
// Notes
//
// Protobuf compatibility
//
// To simplify framework compatibility, ONNX is defined using the subset of protobuf
// that is compatible with both protobuf v2 and v3. This means that we do not use any
// protobuf features that are only available in one of the two versions.
//
// Here are the most notable contortions we have to carry out to work around
// these limitations:
//
// - No 'map' (added protobuf 3.0). We instead represent mappings as lists
// of key-value pairs, where order does not matter and duplicates
// are not allowed.
// Versioning
//
// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md
//
// To be compatible with both proto2 and proto3, we will use a version number
// that is not defined by the default value but an explicit enum number.
enum Version {
// proto3 requires the first enum value to be zero.
// We add this just to appease the compiler.
_START_VERSION = 0;
// The version field is always serialized and we will use it to store the
// version that the graph is generated from. This helps us set up version
// control.
// For the IR, we are using simple numbers starting with 0x00000001,
// which was the version we published on Oct 10, 2017.
IR_VERSION_2017_10_10 = 0x0000000000000001;
// IR_VERSION 2 published on Oct 30, 2017
// - Added type discriminator to AttributeProto to support proto3 users
IR_VERSION_2017_10_30 = 0x0000000000000002;
// IR VERSION 3 published on Nov 3, 2017
// - For operator versioning:
// - Added new message OperatorSetIdProto
// - Added opset_import in ModelProto
// - For vendor extensions, added domain in NodeProto
IR_VERSION_2017_11_3 = 0x0000000000000003;
// IR VERSION 4 published on Jan 22, 2019
// - Relax constraint that initializers should be a subset of graph inputs
// - Add type BFLOAT16
IR_VERSION_2019_1_22 = 0x0000000000000004;
// IR VERSION 5 published on March 18, 2019
// - Add message TensorAnnotation.
// - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters.
IR_VERSION_2019_3_18 = 0x0000000000000005;
// IR VERSION 6 published on Sep 19, 2019
// - Add support for sparse tensor constants stored in model.
// - Add message SparseTensorProto
// - Add sparse initializers
IR_VERSION_2019_9_19 = 0x0000000000000006;
// IR VERSION 7 published on May 8, 2020
// - Add support to allow function body graph to rely on multiple external operator sets.
// - Add a list to promote inference graph's initializers to global and
// mutable variables. Global variables are visible in all graphs of the
// stored models.
// - Add message TrainingInfoProto to store initialization
// method and training algorithm. The execution of TrainingInfoProto
// can modify the values of mutable variables.
// - Implicitly add inference graph into each TrainingInfoProto's algorithm.
IR_VERSION_2020_5_8 = 0x0000000000000007;
// IR VERSION 8 published on July 30, 2021
// Introduce TypeProto.SparseTensor
// Introduce TypeProto.Optional
// Added a list of FunctionProtos local to the model
// Deprecated since_version and operator status from FunctionProto
IR_VERSION_2021_7_30 = 0x0000000000000008;
// IR VERSION 9 published on May 5, 2023
// Added AttributeProto to FunctionProto so that default attribute values can be set.
// Added FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ.
IR_VERSION = 0x0000000000000009;
}
// Attributes
//
// A named attribute containing either singular float, integer, string, graph,
// and tensor values, or repeated float, integer, string, graph, and tensor values.
// An AttributeProto MUST contain the name field, and *only one* of the
// following content fields, effectively enforcing a C/C++ union equivalent.
message AttributeProto {
reserved 12, 16 to 19;
reserved "v";
// Note: this enum is structurally identical to the OpSchema::AttrType
// enum defined in schema.h. If you rev one, you likely need to rev the other.
enum AttributeType {
UNDEFINED = 0;
FLOAT = 1;
INT = 2;
STRING = 3;
TENSOR = 4;
GRAPH = 5;
SPARSE_TENSOR = 11;
TYPE_PROTO = 13;
FLOATS = 6;
INTS = 7;
STRINGS = 8;
TENSORS = 9;
GRAPHS = 10;
SPARSE_TENSORS = 12;
TYPE_PROTOS = 14;
}
// The name field MUST be present for this version of the IR.
string name = 1; // namespace Attribute
// if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function.
// In this case, this AttributeProto does not contain data, and it's a reference of attribute
// in parent scope.
// NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph.
string ref_attr_name = 21;
// A human-readable documentation for this attribute. Markdown is allowed.
string doc_string = 13;
// The type field MUST be present for this version of the IR.
// For 0.0.1 versions of the IR, this field was not defined, and
// implementations needed to use has_field heuristics to determine
// which value field was in use. For IR_VERSION 0.0.2 or later, this
// field MUST be set and match the f|i|s|t|... field in use. This
// change was made to accommodate proto3 implementations.
AttributeType type = 20; // discriminator that indicates which field below is in use
// Exactly ONE of the following fields must be present for this version of the IR
float f = 2; // float
int64 i = 3; // int
bytes s = 4; // UTF-8 string
TensorProto t = 5; // tensor value
GraphProto g = 6; // graph
SparseTensorProto sparse_tensor = 22; // sparse tensor value
// Do not use field below, it's deprecated.
// optional ValueProto v = 12; // value - subsumes everything but graph
TypeProto tp = 14; // type proto
repeated float floats = 7; // list of floats
repeated int64 ints = 8; // list of ints
repeated bytes strings = 9; // list of UTF-8 strings
repeated TensorProto tensors = 10; // list of tensors
repeated GraphProto graphs = 11; // list of graph
repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors
repeated TypeProto type_protos = 15;// list of type protos
}
// Defines information on value, including the name, the type, and
// the shape of the value.
message ValueInfoProto {
// This field MUST be present in this version of the IR.
string name = 1; // namespace Value
// This field MUST be present in this version of the IR for
// inputs and outputs of the top-level graph.
TypeProto type = 2;
// A human-readable documentation for this value. Markdown is allowed.
string doc_string = 3;
}
// Nodes
//
// Computation graphs are made up of a DAG of nodes, which represent what is
// commonly called a "layer" or "pipeline stage" in machine learning frameworks.
//
// For example, it can be a node of type "Conv" that takes in an image, a filter
// tensor and a bias tensor, and produces the convolved output.
message NodeProto {
repeated string input = 1; // namespace Value
repeated string output = 2; // namespace Value
// An optional identifier for this node in a graph.
// This field MAY be absent in this version of the IR.
string name = 3; // namespace Node
// The symbolic identifier of the Operator to execute.
string op_type = 4; // namespace Operator
// The domain of the OperatorSet that specifies the operator named by op_type.
string domain = 7; // namespace Domain
// Additional named attributes.
repeated AttributeProto attribute = 5;
// A human-readable documentation for this node. Markdown is allowed.
string doc_string = 6;
}
// Training information
// TrainingInfoProto stores information for training a model.
// In particular, this defines two functionalities: an initialization-step
// and a training-algorithm-step. Initialization resets the model
// back to its original state as if no training has been performed.
// Training algorithm improves the model based on input data.
//
// The semantics of the initialization-step is that the initializers
// in ModelProto.graph and in TrainingInfoProto.algorithm are first
// initialized as specified by the initializers in the graph, and then
// updated by the "initialization_binding" in every instance in
// ModelProto.training_info.
//
// The field "algorithm" defines a computation graph which represents a
// training algorithm's step. After the execution of a
// TrainingInfoProto.algorithm, the initializers specified by "update_binding"
// may be immediately updated. If the targeted training algorithm contains
// consecutive update steps (such as block coordinate descent methods),
// the user needs to create a TrainingInfoProto for each step.
message TrainingInfoProto {
// This field describes a graph to compute the initial tensors
// upon starting the training process. Initialization graph has no input
// and can have multiple outputs. Usually, trainable tensors in neural
// networks are randomly initialized. To achieve that, for each tensor,
// the user can put a random number operator such as RandomNormal or
// RandomUniform in TrainingInfoProto.initialization.node and assign its
// random output to the specific tensor using "initialization_binding".
// This graph can also set the initializers in "algorithm" in the same
// TrainingInfoProto; a use case is resetting the number of training
// iterations to zero.
//
// By default, this field is an empty graph and its evaluation does not
// produce any output. Thus, no initializer would be changed by default.
GraphProto initialization = 1;
// This field represents a training algorithm step. Given required inputs,
// it computes outputs to update initializers in its own or inference graph's
// initializer lists. In general, this field contains loss node, gradient node,
// optimizer node, increment of iteration count.
//
// An execution of the training algorithm step is performed by executing the
// graph obtained by combining the inference graph (namely "ModelProto.graph")
// and the "algorithm" graph. That is, the actual
// input/initializer/output/node/value_info/sparse_initializer list of
// the training graph is the concatenation of
// "ModelProto.graph.input/initializer/output/node/value_info/sparse_initializer"
// and "algorithm.input/initializer/output/node/value_info/sparse_initializer"
// in that order. This combined graph must satisfy the normal ONNX conditions.
// Now, let's provide a visualization of graph combination for clarity.
// Let the inference graph (i.e., "ModelProto.graph") be
// tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d
// and the "algorithm" graph be
// tensor_d -> Add -> tensor_e
// The combination process results in
// tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d -> Add -> tensor_e
//
// Notice that an input of a node in the "algorithm" graph may reference the
// output of a node in the inference graph (but not the other way round). Also, inference
// nodes cannot reference inputs of "algorithm". With these restrictions, inference graph
// can always be run independently without training information.
//
// By default, this field is an empty graph and its evaluation does not
// produce any output. Evaluating the default training step never
// updates any initializers.
GraphProto algorithm = 2;
// This field specifies the bindings from the outputs of "initialization" to
// some initializers in "ModelProto.graph.initializer" and
// the "algorithm.initializer" in the same TrainingInfoProto.
// See "update_binding" below for details.
//
// By default, this field is empty and no initializer would be changed
// by the execution of "initialization".
repeated StringStringEntryProto initialization_binding = 3;
// Gradient-based training is usually an iterative procedure. In one gradient
// descent iteration, we apply
//
// x = x - r * g
//
// where "x" is the optimized tensor, "r" stands for learning rate, and "g" is
// gradient of "x" with respect to a chosen loss. To avoid adding assignments
// into the training graph, we split the update equation into
//
// y = x - r * g
// x = y
//
// The user needs to save "y = x - r * g" into TrainingInfoProto.algorithm. To
// tell that "y" should be assigned to "x", the field "update_binding" may
// contain a key-value pair of strings, "x" (key of StringStringEntryProto)
// and "y" (value of StringStringEntryProto).
// For a neural network with multiple trainable (mutable) tensors, there can
// be multiple key-value pairs in "update_binding".
//
// The initializers appearing as keys in "update_binding" are considered
// mutable variables. This implies some behaviors
// as described below.
//
// 1. We have only unique keys in all "update_binding"s so that two
// variables may not have the same name. This ensures that one
// variable is assigned up to once.
// 2. The keys must appear in names of "ModelProto.graph.initializer" or
// "TrainingInfoProto.algorithm.initializer".
// 3. The values must be output names of "algorithm" or "ModelProto.graph.output".
// 4. Mutable variables are initialized to the value specified by the
// corresponding initializer, and then potentially updated by
// "initializer_binding"s and "update_binding"s in "TrainingInfoProto"s.
//
// This field usually contains names of trainable tensors
// (in ModelProto.graph), optimizer states such as momentums in advanced
// stochastic gradient methods (in TrainingInfoProto.graph),
// and number of training iterations (in TrainingInfoProto.graph).
//
// By default, this field is empty and no initializer would be changed
// by the execution of "algorithm".
repeated StringStringEntryProto update_binding = 4;
}
// Models
//
// ModelProto is a top-level file/container format for bundling a ML model and
// associating its computation graph with metadata.
//
// The semantics of the model are described by the associated GraphProto's.
message ModelProto {
// The version of the IR this model targets. See Version enum above.
// This field MUST be present.
int64 ir_version = 1;
// The OperatorSets this model relies on.
// All ModelProtos MUST have at least one entry that
// specifies which version of the ONNX OperatorSet is
// being imported.
//
// All nodes in the ModelProto's graph will bind against the operator
// with the same-domain/same-op_type operator with the HIGHEST version
// in the referenced operator sets.
repeated OperatorSetIdProto opset_import = 8;
// The name of the framework or tool used to generate this model.
// This field SHOULD be present to indicate which implementation/tool/framework
// emitted the model.
string producer_name = 2;
// The version of the framework or tool used to generate this model.
// This field SHOULD be present to indicate which implementation/tool/framework
// emitted the model.
string producer_version = 3;
// Domain name of the model.
// We use reverse domain names as name space indicators. For example:
// `com.facebook.fair` or `com.microsoft.cognitiveservices`
//
// Together with `model_version` and GraphProto.name, this forms the unique identity of
// the graph.
string domain = 4;
// The version of the graph encoded. See Version enum below.
int64 model_version = 5;
// A human-readable documentation for this model. Markdown is allowed.
string doc_string = 6;
// The parameterized graph that is evaluated to execute the model.
GraphProto graph = 7;
// Named metadata values; keys should be distinct.
repeated StringStringEntryProto metadata_props = 14;
// Training-specific information. Sequentially executing all stored
// `TrainingInfoProto.algorithm`s and assigning their outputs following
// the corresponding `TrainingInfoProto.update_binding`s is one training
// iteration. Similarly, to initialize the model
// (as if training hasn't happened), the user should sequentially execute
// all stored `TrainingInfoProto.initialization`s and assigns their outputs
// using `TrainingInfoProto.initialization_binding`s.
//
// If this field is empty, the training behavior of the model is undefined.
repeated TrainingInfoProto training_info = 20;
// A list of function protos local to the model.
//
// Name of the function "FunctionProto.name" should be unique within the domain "FunctionProto.domain".
// In case of any conflicts the behavior (whether the model local functions are given higher priority,
// or standard operator sets are given higher priority or this is treated as an error) is defined by
// the runtimes.
//
// The operator sets imported by FunctionProto should be compatible with the ones
// imported by ModelProto and other model local FunctionProtos.
// Example, if same operator set say 'A' is imported by a FunctionProto and ModelProto
// or by 2 FunctionProtos then versions for the operator set may be different but,
// the operator schema returned for op_type, domain, version combination
// for both the versions should be the same for every node in the function body.
//
// One FunctionProto can reference other FunctionProto in the model, however, recursive reference
// is not allowed.
repeated FunctionProto functions = 25;
};
// StringStringEntryProto follows the pattern for cross-proto-version maps.
// See https://developers.google.com/protocol-buffers/docs/proto3#maps
message StringStringEntryProto {
string key = 1;
string value = 2;
};
message TensorAnnotation {
string tensor_name = 1;
// <key, value> pairs to annotate tensor specified by <tensor_name> above.
// The keys used in the mapping below must be pre-defined in ONNX spec.
// For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as
// quantization parameter keys.
repeated StringStringEntryProto quant_parameter_tensor_names = 2;
}
// Graphs
//
// A graph defines the computational logic of a model and is comprised of a parameterized
// list of nodes that form a directed acyclic graph based on their inputs and outputs.
// This is the equivalent of the "network" or "graph" in many deep learning
// frameworks.
message GraphProto {
// The nodes in the graph, sorted topologically.
repeated NodeProto node = 1;
// The name of the graph.
string name = 2; // namespace Graph
// A list of named tensor values, used to specify constant inputs of the graph.
// Each initializer (both TensorProto as well SparseTensorProto) MUST have a name.
// The name MUST be unique across both initializer and sparse_initializer,
// but the name MAY also appear in the input list.
repeated TensorProto initializer = 5;
// Initializers (see above) stored in sparse format.
repeated SparseTensorProto sparse_initializer = 15;
// A human-readable documentation for this graph. Markdown is allowed.
string doc_string = 10;
// The inputs and outputs of the graph.
repeated ValueInfoProto input = 11;
repeated ValueInfoProto output = 12;
// Information for the values in the graph. The ValueInfoProto.name's
// must be distinct. It is optional for a value to appear in value_info list.
repeated ValueInfoProto value_info = 13;
// This field carries information to indicate the mapping among a tensor and its
// quantization parameter tensors. For example:
// For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated,
// which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model.
repeated TensorAnnotation quantization_annotation = 14;
reserved 3, 4, 6 to 9;
reserved "ir_version", "producer_version", "producer_tag", "domain";
}
// Tensors
//
// A serialized tensor value.
message TensorProto {
enum DataType {
UNDEFINED = 0;
// Basic types.
FLOAT = 1; // float
UINT8 = 2; // uint8_t
INT8 = 3; // int8_t
UINT16 = 4; // uint16_t
INT16 = 5; // int16_t
INT32 = 6; // int32_t
INT64 = 7; // int64_t
STRING = 8; // string
BOOL = 9; // bool
// IEEE754 half-precision floating-point format (16 bits wide).
// This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
FLOAT16 = 10;
DOUBLE = 11;
UINT32 = 12;
UINT64 = 13;
COMPLEX64 = 14; // complex with float32 real and imaginary components
COMPLEX128 = 15; // complex with float64 real and imaginary components
// Non-IEEE floating-point format based on IEEE754 single-precision
// floating-point number truncated to 16 bits.
// This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
BFLOAT16 = 16;
// Non-IEEE floating-point format based on papers
// FP8 Formats for Deep Learning, https://arxiv.org/abs/2209.05433,
// 8-bit Numerical Formats For Deep Neural Networks, https://arxiv.org/pdf/2206.02915.pdf.
// Operators supported FP8 are Cast, CastLike, QuantizeLinear, DequantizeLinear.
// The computation usually happens inside a block quantize / dequantize
// fused by the runtime.
FLOAT8E4M3FN = 17; // float 8, mostly used for coefficients, supports nan, not inf
FLOAT8E4M3FNUZ = 18; // float 8, mostly used for coefficients, supports nan, not inf, no negative zero
FLOAT8E5M2 = 19; // follows IEEE 754, supports nan, inf, mostly used for gradients
FLOAT8E5M2FNUZ = 20; // follows IEEE 754, supports nan, inf, mostly used for gradients, no negative zero
// Future extensions go here.
}
// The shape of the tensor.
repeated int64 dims = 1;
// The data type of the tensor.
// This field MUST have a valid TensorProto.DataType value
int32 data_type = 2;
// For very large tensors, we may want to store them in chunks, in which
// case the following fields will specify the segment that is stored in
// the current TensorProto.
message Segment {
int64 begin = 1;
int64 end = 2;
}
Segment segment = 3;
// Tensor content must be organized in row-major order.
//
// Depending on the data_type field, exactly one of the fields below with
// name ending in _data is used to store the elements of the tensor.
// For float and complex64 values
// Complex64 tensors are encoded as a single array of floats,
// with the real components appearing in odd numbered positions,
// and the corresponding imaginary component appearing in the
// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
// is encoded as [1.0, 2.0 ,3.0 ,4.0]
// When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
repeated float float_data = 4 [packed = true];
// For int32, uint8, int8, uint16, int16, bool, float8, and float16 values
// float16 and float8 values must be bit-wise converted to an uint16_t prior
// to writing to the buffer.
// When this field is present, the data_type field MUST be
// INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ
repeated int32 int32_data = 5 [packed = true];
// For strings.
// Each element of string_data is a UTF-8 encoded Unicode
// string. No trailing null, no leading BOM. The protobuf "string"
// scalar type is not used to match ML community conventions.
// When this field is present, the data_type field MUST be STRING
repeated bytes string_data = 6;
// For int64.
// When this field is present, the data_type field MUST be INT64
repeated int64 int64_data = 7 [packed = true];
// Optionally, a name for the tensor.
string name = 8; // namespace Value
// A human-readable documentation for this tensor. Markdown is allowed.
string doc_string = 12;
// Serializations can either use one of the fields above, or use this
// raw bytes field. The only exception is the string case, where one is
// required to store the content in the repeated bytes string_data field.
//
// When this raw_data field is used to store tensor value, elements MUST
// be stored in as fixed-width, little-endian order.
// Floating-point data types MUST be stored in IEEE 754 format.
// Complex64 elements must be written as two consecutive FLOAT values, real component first.
// Complex128 elements must be written as two consecutive DOUBLE values, real component first.
// Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
//
// Note: the advantage of specific field rather than the raw_data field is
// that in some cases (e.g. int data), protobuf does a better packing via
// variable length storage, and may lead to smaller binary footprint.
// When this field is present, the data_type field MUST NOT be STRING or UNDEFINED
bytes raw_data = 9;
// Data can be stored inside the protobuf file using type-specific fields or raw_data.
// Alternatively, raw bytes data can be stored in an external file, using the external_data field.
// external_data stores key-value pairs describing data location. Recognized keys are:
// - "location" (required) - POSIX filesystem path relative to the directory where the ONNX
// protobuf model was stored
// - "offset" (optional) - position of byte at which stored data begins. Integer stored as string.
// Offset values SHOULD be multiples of 4096 (page size) to enable mmap support.
// - "length" (optional) - number of bytes containing data. Integer stored as string.
// - "checksum" (optional) - SHA1 digest of file specified in under 'location' key.
repeated StringStringEntryProto external_data = 13;
// Location of the data for this tensor. MUST be one of:
// - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field.
// - EXTERNAL - data stored in an external location as described by external_data field.
enum DataLocation {
DEFAULT = 0;
EXTERNAL = 1;
}
// If value not set, data is stored in raw_data (if set) otherwise in type-specified field.
DataLocation data_location = 14;
// For double
// Complex128 tensors are encoded as a single array of doubles,
// with the real components appearing in odd numbered positions,
// and the corresponding imaginary component appearing in the
// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
// is encoded as [1.0, 2.0, 3.0, 4.0])
// When this field is present, the data_type field MUST be DOUBLE or COMPLEX128
repeated double double_data = 10 [packed = true];
// For uint64 and uint32 values
// When this field is present, the data_type field MUST be
// UINT32 or UINT64
repeated uint64 uint64_data = 11 [packed = true];
}
// A serialized sparse-tensor value
message SparseTensorProto {
// The sequence of non-default values are encoded as a tensor of shape [NNZ].
// The default-value is zero for numeric tensors, and empty-string for string tensors.
// values must have a non-empty name present which serves as a name for SparseTensorProto
// when used in sparse_initializer list.
TensorProto values = 1;
// The indices of the non-default values, which may be stored in one of two formats.
// (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value
// corresponding to the j-th index of the i-th value (in the values tensor).
// (b) Indices can be a tensor of shape [NNZ], in which case the i-th value
// must be the linearized-index of the i-th value (in the values tensor).
// The linearized-index can be converted into an index tuple (k_1,...,k_rank)
// using the shape provided below.
// The indices must appear in ascending order without duplication.
// In the first format, the ordering is lexicographic-ordering:
// e.g., index-value [1,4] must appear before [2,1]
TensorProto indices = 2;
// The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank]
repeated int64 dims = 3;
}
// Defines a tensor shape. A dimension can be either an integer value
// or a symbolic variable. A symbolic variable represents an unknown
// dimension.
message TensorShapeProto {
message Dimension {
oneof value {
int64 dim_value = 1;
string dim_param = 2; // namespace Shape
};
// Standard denotation can optionally be used to denote tensor
// dimensions with standard semantic descriptions to ensure
// that operations are applied to the correct axis of a tensor.
// Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition
// for pre-defined dimension denotations.
string denotation = 3;
};
repeated Dimension dim = 1;
}
// Types
//
// The standard ONNX data types.
message TypeProto {
message Tensor {
// This field MUST NOT have the value of UNDEFINED
// This field MUST have a valid TensorProto.DataType value
// This field MUST be present for this version of the IR.
int32 elem_type = 1;
TensorShapeProto shape = 2;
}
// repeated T
message Sequence {
// The type and optional shape of each element of the sequence.
// This field MUST be present for this version of the IR.
TypeProto elem_type = 1;
};
// map<K,V>
message Map {
// This field MUST have a valid TensorProto.DataType value
// This field MUST be present for this version of the IR.
// This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING
int32 key_type = 1;
// This field MUST be present for this version of the IR.
TypeProto value_type = 2;
};
// wrapper for Tensor, Sequence, or Map
message Optional {
// The type and optional shape of the element wrapped.
// This field MUST be present for this version of the IR.
// Possible values correspond to OptionalProto.DataType enum
TypeProto elem_type = 1;
};
message SparseTensor {
// This field MUST NOT have the value of UNDEFINED
// This field MUST have a valid TensorProto.DataType value
// This field MUST be present for this version of the IR.
int32 elem_type = 1;
TensorShapeProto shape = 2;
}
oneof value {
// The type of a tensor.
Tensor tensor_type = 1;
// NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values
// as input and output to graphs and nodes. These types are needed to naturally
// support classical ML operators. DNN operators SHOULD restrict their input
// and output types to tensors.
// The type of a sequence.
Sequence sequence_type = 4;
// The type of a map.
Map map_type = 5;
// The type of an optional.
Optional optional_type = 9;
// Type of the sparse tensor
SparseTensor sparse_tensor_type = 8;
}
// An optional denotation can be used to denote the whole
// type with a standard semantic description as to what is
// stored inside. Refer to https://github.com/onnx/onnx/blob/main/docs/TypeDenotation.md#type-denotation-definition
// for pre-defined type denotations.
string denotation = 6;
}
// Operator Sets
//
// OperatorSets are uniquely identified by a (domain, opset_version) pair.
message OperatorSetIdProto {
// The domain of the operator set being identified.
// The empty string ("") or absence of this field implies the operator
// set that is defined as part of the ONNX specification.
// This field MUST be present in this version of the IR when referring to any other operator set.
string domain = 1;
// The version of the operator set being identified.
// This field MUST be present in this version of the IR.
int64 version = 2;
}
// Operator/function status.
enum OperatorStatus {
EXPERIMENTAL = 0;
STABLE = 1;
}
message FunctionProto {
// The name of the function, similar usage of op_type in OperatorProto.
// Combined with FunctionProto.domain, this forms the unique identity of
// the FunctionProto.
string name = 1;
// Deprecated since IR Version 8
// optional int64 since_version = 2;
reserved 2;
reserved "since_version";
// Deprecated since IR Version 8
// optional OperatorStatus status = 3;
reserved 3;
reserved "status";
// The inputs and outputs of the function.
repeated string input = 4;
repeated string output = 5;
// The attribute parameters of the function.
// It is for function parameters without default values.
repeated string attribute = 6;
// The attribute protos of the function.
// It is for function attributes with default values.
// A function attribute shall be represented either as
// a string attribute or an AttributeProto, not both.
repeated AttributeProto attribute_proto = 11;
// The nodes in the function.
repeated NodeProto node = 7;
// A human-readable documentation for this function. Markdown is allowed.
string doc_string = 8;
// The OperatorSets this function body (graph) relies on.
//
// All nodes in the function body (graph) will bind against the operator
// with the same-domain/same-op_type operator with the HIGHEST version
// in the referenced operator sets. This means at most one version can be relied upon
// for one domain.
//
// The operator sets imported by FunctionProto should be compatible with the ones
// imported by ModelProto. Example, if same operator set say 'A' is imported by FunctionProto
// and ModelProto then versions for the operator set may be different but,
// the operator schema returned for op_type, domain, version combination
// for both the versions should be the same.
repeated OperatorSetIdProto opset_import = 9;
// The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of
// the FunctionProto.
string domain = 10;
}
// For using protobuf-lite
option optimize_for = LITE_RUNTIME;
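The include!(concat!(env!("OUT_DIR"), "/onnx.rs")) in the removed lib.rs implies the Rust types for this schema were generated at build time. A minimal build.rs sketch doing that with prost-build (the proto path is an assumption, not taken from this diff):
// build.rs (sketch): compile the schema above into OUT_DIR/onnx.rs, which lib.rs then include!s.
fn main() {
    prost_build::compile_protos(&["src/onnx.proto"], &["src/"]).unwrap();
}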

View File

@ -17,7 +17,6 @@ crate-type = ["cdylib"]
accelerate-src = { workspace = true, optional = true } accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" } candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.0" } candle-nn = { path = "../candle-nn", version = "0.3.0" }
candle-onnx = {path= "../candle-onnx", version = "0.3.0", optional = true}
half = { workspace = true } half = { workspace = true }
intel-mkl-src = { workspace = true, optional = true } intel-mkl-src = { workspace = true, optional = true }
pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] } pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] }
@ -30,5 +29,3 @@ default = []
accelerate = ["dep:accelerate-src", "candle/accelerate"] accelerate = ["dep:accelerate-src", "candle/accelerate"]
cuda = ["candle/cuda"] cuda = ["candle/cuda"]
mkl = ["dep:intel-mkl-src","candle/mkl"] mkl = ["dep:intel-mkl-src","candle/mkl"]
onnx = ["dep:candle-onnx"]

View File

@ -1,5 +0,0 @@
# Generated content DO NOT EDIT
from .. import onnx
ONNXModel = onnx.ONNXModel
ONNXTensorDescription = onnx.ONNXTensorDescription

View File

@ -1,89 +0,0 @@
# Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
from candle.typing import _ArrayLike, Device, Scalar, Index, Shape
from candle import Tensor, DType, QTensor
class ONNXModel:
"""
A wrapper around an ONNX model.
"""
def __init__(self, path: str):
pass
@property
def doc_string(self) -> str:
"""
The doc string of the model.
"""
pass
@property
def domain(self) -> str:
"""
The domain of the operator set of the model.
"""
pass
def initializers(self) -> Dict[str, Tensor]:
"""
Get the weights of the model.
"""
pass
@property
def inputs(self) -> Optional[Dict[str, ONNXTensorDescription]]:
"""
The inputs of the model.
"""
pass
@property
def ir_version(self) -> int:
"""
The version of the IR this model targets.
"""
pass
@property
def model_version(self) -> int:
"""
The version of the model.
"""
pass
@property
def outputs(self) -> Optional[Dict[str, ONNXTensorDescription]]:
"""
The outputs of the model.
"""
pass
@property
def producer_name(self) -> str:
"""
The producer of the model.
"""
pass
@property
def producer_version(self) -> str:
"""
The version of the producer of the model.
"""
pass
def run(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""
Run the model on the given inputs.
"""
pass
class ONNXTensorDescription:
"""
A wrapper around an ONNX tensor description.
"""
@property
def dtype(self) -> DType:
"""
The data type of the tensor.
"""
pass
@property
def shape(self) -> Tuple[Union[int, str, Any]]:
"""
The shape of the tensor.
"""
pass

View File

@ -19,14 +19,12 @@ extern crate accelerate_src;
use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType}; use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};
mod utils;
use utils::wrap_err;
mod shape; mod shape;
use shape::{PyShape, PyShapeWithHole}; use shape::{PyShape, PyShapeWithHole};
#[cfg(feature = "onnx")] pub fn wrap_err(err: ::candle::Error) -> PyErr {
mod onnx; PyErr::new::<PyValueError, _>(format!("{err:?}"))
}
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
#[pyclass(name = "Tensor")] #[pyclass(name = "Tensor")]
@ -71,13 +69,11 @@ impl PyDType {
} }
static CUDA_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None); static CUDA_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);
static METAL_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);
#[derive(Clone, Copy, Debug, PartialEq, Eq)] #[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum PyDevice { enum PyDevice {
Cpu, Cpu,
Cuda, Cuda,
Metal,
} }
impl PyDevice { impl PyDevice {
@ -85,7 +81,7 @@ impl PyDevice {
match device { match device {
Device::Cpu => Self::Cpu, Device::Cpu => Self::Cpu,
Device::Cuda(_) => Self::Cuda, Device::Cuda(_) => Self::Cuda,
Device::Metal(_) => Self::Metal, Device::Metal(_) => unimplemented!(),
} }
} }
@ -101,15 +97,6 @@ impl PyDevice {
*device = Some(d.clone()); *device = Some(d.clone());
Ok(d) Ok(d)
} }
Self::Metal => {
let mut device = METAL_DEVICE.lock().unwrap();
if let Some(device) = device.as_ref() {
return Ok(device.clone());
};
let d = Device::new_metal(0).map_err(wrap_err)?;
*device = Some(d.clone());
Ok(d)
}
} }
} }
} }
@ -131,7 +118,6 @@ impl ToPyObject for PyDevice {
let str = match self { let str = match self {
PyDevice::Cpu => "cpu", PyDevice::Cpu => "cpu",
PyDevice::Cuda => "cuda", PyDevice::Cuda => "cuda",
PyDevice::Metal => "metal",
}; };
str.to_object(py) str.to_object(py)
} }
@ -1574,14 +1560,6 @@ fn candle_functional_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
Ok(()) Ok(())
} }
#[cfg(feature = "onnx")]
fn candle_onnx_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
use onnx::{PyONNXModel, PyONNXTensorDescriptor};
m.add_class::<PyONNXModel>()?;
m.add_class::<PyONNXTensorDescriptor>()?;
Ok(())
}
#[pymodule] #[pymodule]
fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> { fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
let utils = PyModule::new(py, "utils")?; let utils = PyModule::new(py, "utils")?;
@ -1590,12 +1568,6 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
let nn = PyModule::new(py, "functional")?; let nn = PyModule::new(py, "functional")?;
candle_functional_m(py, nn)?; candle_functional_m(py, nn)?;
m.add_submodule(nn)?; m.add_submodule(nn)?;
#[cfg(feature = "onnx")]
{
let onnx = PyModule::new(py, "onnx")?;
candle_onnx_m(py, onnx)?;
m.add_submodule(onnx)?;
}
m.add_class::<PyTensor>()?; m.add_class::<PyTensor>()?;
m.add_class::<PyQTensor>()?; m.add_class::<PyQTensor>()?;
m.add_class::<PyDType>()?; m.add_class::<PyDType>()?;

View File

@ -1,212 +0,0 @@
use std::collections::HashMap;
use crate::utils::wrap_err;
use crate::{PyDType, PyTensor};
use candle_onnx::eval::{dtype, get_tensor, simple_eval};
use candle_onnx::onnx::tensor_proto::DataType;
use candle_onnx::onnx::tensor_shape_proto::dimension::Value;
use candle_onnx::onnx::type_proto::{Tensor as ONNXTensor, Value as ONNXValue};
use candle_onnx::onnx::{ModelProto, ValueInfoProto};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::{PyList, PyTuple};
#[derive(Clone, Debug)]
#[pyclass(name = "ONNXTensorDescription")]
/// A wrapper around an ONNX tensor description.
pub struct PyONNXTensorDescriptor(ONNXTensor);
#[pymethods]
impl PyONNXTensorDescriptor {
#[getter]
/// The data type of the tensor.
/// &RETURNS&: DType
fn dtype(&self) -> PyResult<PyDType> {
match DataType::try_from(self.0.elem_type) {
Ok(dt) => match dtype(dt) {
Some(dt) => Ok(PyDType(dt)),
None => Err(PyValueError::new_err(format!(
"unsupported 'value' data-type {dt:?}"
))),
},
type_ => Err(PyValueError::new_err(format!(
"unsupported input type {type_:?}"
))),
}
}
#[getter]
/// The shape of the tensor.
/// &RETURNS&: Tuple[Union[int,str,Any]]
fn shape(&self, py: Python) -> PyResult<Py<PyTuple>> {
let shape = PyList::empty(py);
if let Some(d) = &self.0.shape {
for dim in d.dim.iter() {
if let Some(value) = &dim.value {
match value {
Value::DimValue(v) => shape.append(*v)?,
Value::DimParam(s) => shape.append(s.clone())?,
};
} else {
return Err(PyValueError::new_err("None value in shape"));
}
}
}
Ok(shape.to_tuple().into())
}
fn __repr__(&self, py: Python) -> String {
match (self.shape(py), self.dtype()) {
(Ok(shape), Ok(dtype)) => format!(
"TensorDescriptor[shape: {:?}, dtype: {:?}]",
shape.to_string(),
dtype.__str__()
),
(Err(_), Err(_)) => "TensorDescriptor[shape: unknown, dtype: unknown]".to_string(),
(Err(_), Ok(dtype)) => format!(
"TensorDescriptor[shape: unknown, dtype: {:?}]",
dtype.__str__()
),
(Ok(shape), Err(_)) => format!(
"TensorDescriptor[shape: {:?}, dtype: unknown]",
shape.to_string()
),
}
}
fn __str__(&self, py: Python) -> String {
self.__repr__(py)
}
}
#[derive(Clone, Debug)]
#[pyclass(name = "ONNXModel")]
/// A wrapper around an ONNX model.
pub struct PyONNXModel(ModelProto);
fn extract_tensor_descriptions(
value_infos: &[ValueInfoProto],
) -> HashMap<String, PyONNXTensorDescriptor> {
let mut map = HashMap::new();
for value_info in value_infos.iter() {
let input_type = match &value_info.r#type {
Some(input_type) => input_type,
None => continue,
};
let input_type = match &input_type.value {
Some(input_type) => input_type,
None => continue,
};
let tensor_type: &ONNXTensor = match input_type {
ONNXValue::TensorType(tt) => tt,
_ => continue,
};
map.insert(
value_info.name.to_string(),
PyONNXTensorDescriptor(tensor_type.clone()),
);
}
map
}
#[pymethods]
impl PyONNXModel {
#[new]
#[pyo3(text_signature = "(self, path:str)")]
/// Load an ONNX model from the given path.
fn new(path: String) -> PyResult<Self> {
let model: ModelProto = candle_onnx::read_file(path).map_err(wrap_err)?;
Ok(PyONNXModel(model))
}
#[getter]
/// The version of the IR this model targets.
/// &RETURNS&: int
fn ir_version(&self) -> i64 {
self.0.ir_version
}
#[getter]
/// The producer of the model.
/// &RETURNS&: str
fn producer_name(&self) -> String {
self.0.producer_name.clone()
}
#[getter]
/// The version of the producer of the model.
/// &RETURNS&: str
fn producer_version(&self) -> String {
self.0.producer_version.clone()
}
#[getter]
/// The domain of the operator set of the model.
/// &RETURNS&: str
fn domain(&self) -> String {
self.0.domain.clone()
}
#[getter]
/// The version of the model.
/// &RETURNS&: int
fn model_version(&self) -> i64 {
self.0.model_version
}
#[getter]
/// The doc string of the model.
/// &RETURNS&: str
fn doc_string(&self) -> String {
self.0.doc_string.clone()
}
/// Get the weights of the model.
/// &RETURNS&: Dict[str, Tensor]
fn initializers(&self) -> PyResult<HashMap<String, PyTensor>> {
let mut map = HashMap::new();
if let Some(graph) = self.0.graph.as_ref() {
for tensor_description in graph.initializer.iter() {
let tensor = get_tensor(tensor_description, tensor_description.name.as_str())
.map_err(wrap_err)?;
map.insert(tensor_description.name.to_string(), PyTensor(tensor));
}
}
Ok(map)
}
#[getter]
/// The inputs of the model.
/// &RETURNS&: Optional[Dict[str, ONNXTensorDescription]]
fn inputs(&self) -> Option<HashMap<String, PyONNXTensorDescriptor>> {
if let Some(graph) = self.0.graph.as_ref() {
return Some(extract_tensor_descriptions(&graph.input));
}
None
}
#[getter]
/// The outputs of the model.
/// &RETURNS&: Optional[Dict[str, ONNXTensorDescription]]
fn outputs(&self) -> Option<HashMap<String, PyONNXTensorDescriptor>> {
if let Some(graph) = self.0.graph.as_ref() {
return Some(extract_tensor_descriptions(&graph.output));
}
None
}
#[pyo3(text_signature = "(self, inputs:Dict[str,Tensor])")]
/// Run the model on the given inputs.
/// &RETURNS&: Dict[str,Tensor]
fn run(&self, inputs: HashMap<String, PyTensor>) -> PyResult<HashMap<String, PyTensor>> {
let unwrapped_tensors = inputs.into_iter().map(|(k, v)| (k.clone(), v.0)).collect();
let result = simple_eval(&self.0, unwrapped_tensors).map_err(wrap_err)?;
Ok(result
.into_iter()
.map(|(k, v)| (k.clone(), PyTensor(v)))
.collect())
}
}

View File

@ -1,6 +0,0 @@
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
pub fn wrap_err(err: ::candle::Error) -> PyErr {
PyErr::new::<PyValueError, _>(format!("{err:?}"))
}

View File

@ -21,7 +21,6 @@ rand = { workspace = true }
rayon = { workspace = true } rayon = { workspace = true }
serde = { workspace = true } serde = { workspace = true }
serde_json = { workspace = true } serde_json = { workspace = true }
serde_plain = { workspace = true }
tracing = { workspace = true } tracing = { workspace = true }
wav = { workspace = true } wav = { workspace = true }
@ -29,5 +28,6 @@ wav = { workspace = true }
default = [] default = []
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"] accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda"] cuda = ["candle/cuda", "candle-nn/cuda"]
metal = ["candle/metal", "candle-nn/metal"]
flash-attn = ["cuda", "dep:candle-flash-attn"] flash-attn = ["cuda", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"] mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"]

View File

@ -1,4 +1,3 @@
use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
use candle::{DType, Device, Result, Tensor}; use candle::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module, VarBuilder}; use candle_nn::{Embedding, Module, VarBuilder};
use serde::Deserialize; use serde::Deserialize;
@ -33,6 +32,76 @@ impl HiddenActLayer {
} }
} }
#[derive(Debug)]
pub struct Linear {
weight: Tensor,
bias: Option<Tensor>,
span: tracing::Span,
}
impl Linear {
pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
let span = tracing::span!(tracing::Level::TRACE, "linear");
Self { weight, bias, span }
}
}
impl Module for Linear {
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
let _enter = self.span.enter();
let w = match x.dims() {
&[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
_ => self.weight.t()?,
};
let x = x.matmul(&w)?;
match &self.bias {
None => Ok(x),
Some(bias) => x.broadcast_add(bias),
}
}
}
#[derive(Debug)]
pub struct LayerNorm {
weight: Tensor,
bias: Tensor,
eps: f64,
span: tracing::Span,
}
impl LayerNorm {
pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
Self {
weight,
bias,
eps,
span,
}
}
}
impl Module for LayerNorm {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let x_dtype = x.dtype();
let internal_dtype = match x_dtype {
DType::F16 | DType::BF16 => DType::F32,
d => d,
};
let (_bsize, _seq_len, hidden_size) = x.dims3()?;
let x = x.to_dtype(internal_dtype)?;
let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
let x = x.broadcast_sub(&mean_x)?;
let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
let x = x_normed
.to_dtype(x_dtype)?
.broadcast_mul(&self.weight)?
.broadcast_add(&self.bias)?;
Ok(x)
}
}
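// In the usual notation, the forward pass above computes, per position over the hidden
// size H (with the mean and variance taken in f32 for f16/bf16 inputs):
//   mu      = (1/H) * sum_i x_i
//   sigma^2 = (1/H) * sum_i (x_i - mu)^2
//   y_i     = (x_i - mu) / sqrt(sigma^2 + eps) * w_i + b_i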
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
enum PositionEmbeddingType { enum PositionEmbeddingType {
@ -115,6 +184,12 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Em
Ok(Embedding::new(embeddings, hidden_size)) Ok(Embedding::new(embeddings, hidden_size))
} }
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?;
let bias = vb.get(size2, "bias")?;
Ok(Linear::new(weight, Some(bias)))
}
struct Dropout { struct Dropout {
#[allow(dead_code)] #[allow(dead_code)]
pr: f64, pr: f64,
@ -133,6 +208,20 @@ impl Module for Dropout {
} }
} }
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
(Ok(weight), Ok(bias)) => (weight, bias),
(Err(err), _) | (_, Err(err)) => {
if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
(weight, bias)
} else {
return Err(err);
}
}
};
Ok(LayerNorm::new(weight, bias, eps))
}
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180
struct BertEmbeddings { struct BertEmbeddings {
word_embeddings: Embedding, word_embeddings: Embedding,

View File

@ -1,4 +1,3 @@
use super::with_tracing::{linear_no_bias as linear, Linear};
use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, Module, VarBuilder}; use candle_nn::{Embedding, Module, VarBuilder};
use serde::Deserialize; use serde::Deserialize;
@ -82,6 +81,21 @@ impl Config {
} }
} }
// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
// model.
#[derive(Debug)]
pub struct Linear {
inner: candle_nn::Linear,
span: tracing::Span,
}
impl Linear {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
}
}
#[derive(Clone)] #[derive(Clone)]
pub struct Cache { pub struct Cache {
masks: Arc<Mutex<HashMap<usize, Tensor>>>, masks: Arc<Mutex<HashMap<usize, Tensor>>>,
@ -136,6 +150,12 @@ impl Cache {
} }
} }
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let span = tracing::span!(tracing::Level::TRACE, "linear");
let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
Ok(Linear { inner, span })
}
fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> { fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?; let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
Ok(Embedding::new(embeddings, cfg.hidden_size)) Ok(Embedding::new(embeddings, cfg.hidden_size))

View File

@ -1,5 +1,6 @@
use super::with_tracing::{linear, Embedding, Linear}; #![allow(unused)]
use candle::{Result, Tensor}; use super::with_tracing::{linear, linear_no_bias, Embedding, Linear};
use candle::{Module, Result, Tensor};
use candle_nn::{layer_norm, LayerNorm, VarBuilder}; use candle_nn::{layer_norm, LayerNorm, VarBuilder};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -169,6 +170,7 @@ impl Attention {
kv_states: Option<&Tensor>, kv_states: Option<&Tensor>,
attn_mask: Option<&Tensor>, attn_mask: Option<&Tensor>,
) -> Result<Tensor> { ) -> Result<Tensor> {
let is_cross_attn = kv_states.is_some();
let (b_sz, tgt_len, _) = xs.dims3()?; let (b_sz, tgt_len, _) = xs.dims3()?;
let query_states = (xs.apply(&self.q_proj)? * self.scaling)?; let query_states = (xs.apply(&self.q_proj)? * self.scaling)?;
let (key_states, value_states) = match kv_states { let (key_states, value_states) = match kv_states {
@ -257,10 +259,6 @@ impl EncoderLayer {
.apply(&self.fc2)?; .apply(&self.fc2)?;
(xs + residual)?.apply(&self.final_layer_norm) (xs + residual)?.apply(&self.final_layer_norm)
} }
fn reset_kv_cache(&mut self) {
self.self_attn.reset_kv_cache()
}
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -322,11 +320,6 @@ impl DecoderLayer {
let xs = (xs + residual)?.apply(&self.final_layer_norm)?; let xs = (xs + residual)?.apply(&self.final_layer_norm)?;
Ok(xs) Ok(xs)
} }
fn reset_kv_cache(&mut self) {
self.self_attn.reset_kv_cache();
self.encoder_attn.reset_kv_cache()
}
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -375,12 +368,6 @@ impl Encoder {
} }
Ok(xs) Ok(xs)
} }
pub fn reset_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {
layer.reset_kv_cache()
}
}
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -435,12 +422,6 @@ impl Decoder {
} }
Ok(xs) Ok(xs)
} }
pub fn reset_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {
layer.reset_kv_cache()
}
}
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -461,11 +442,6 @@ impl Model {
decoder, decoder,
}) })
} }
fn reset_kv_cache(&mut self) {
self.encoder.reset_kv_cache();
self.decoder.reset_kv_cache();
}
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -513,8 +489,4 @@ impl MTModel {
.apply(&self.lm_head)? .apply(&self.lm_head)?
.broadcast_add(&self.final_logits_bias) .broadcast_add(&self.final_logits_bias)
} }
pub fn reset_kv_cache(&mut self) {
self.model.reset_kv_cache();
}
} }

View File

@ -29,10 +29,8 @@ pub mod segment_anything;
pub mod stable_diffusion; pub mod stable_diffusion;
pub mod stable_lm; pub mod stable_lm;
pub mod t5; pub mod t5;
pub mod trocr;
pub mod vgg; pub mod vgg;
pub mod vit; pub mod vit;
pub mod whisper; pub mod whisper;
pub mod with_tracing; pub mod with_tracing;
pub mod wuerstchen; pub mod wuerstchen;
pub mod yi;

View File

@ -2,7 +2,7 @@ use std::collections::HashMap;
use candle::quantized::QTensor; use candle::quantized::QTensor;
use candle::quantized::{ggml_file, gguf_file}; use candle::quantized::{ggml_file, gguf_file};
use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle::{Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, Module}; use candle_nn::{Embedding, Module};
pub const MAX_SEQ_LEN: usize = 4096; pub const MAX_SEQ_LEN: usize = 4096;
@ -16,7 +16,7 @@ struct RmsNorm {
impl RmsNorm { impl RmsNorm {
fn new(scale: QTensor, eps: f32) -> Result<Self> { fn new(scale: QTensor, eps: f32) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
let scale = scale.dequantize(&Device::Cpu)?; let scale = scale.dequantize(scale.device())?;
let inner = candle_nn::LayerNorm::rms_norm(scale, eps as f64); let inner = candle_nn::LayerNorm::rms_norm(scale, eps as f64);
Ok(Self { inner, span }) Ok(Self { inner, span })
} }
@ -79,6 +79,8 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
impl LayerWeights { impl LayerWeights {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> { fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let _enter = self.span_rot.enter(); let _enter = self.span_rot.enter();
let span = tracing::span!(tracing::Level::TRACE, "attn-rot-cos");
let _enter = span.enter();
let (b_sz, n_head, seq_len, n_embd) = x.dims4()?; let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
let cos = self let cos = self
.cos .cos
@ -88,21 +90,37 @@ impl LayerWeights {
.sin .sin
.narrow(0, index_pos, seq_len)? .narrow(0, index_pos, seq_len)?
.reshape((seq_len, n_embd / 2, 1))?; .reshape((seq_len, n_embd / 2, 1))?;
drop(_enter);
let span = tracing::span!(tracing::Level::TRACE, "attn-rot-broad");
let _enter = span.enter();
let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?; let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?; let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
drop(_enter);
// This mimics the llama.cpp behavior. // This mimics the llama.cpp behavior.
// https://github.com/ggerganov/llama.cpp/blob/1f0bccb27929e261744c979bc75114955da49e98/ggml.c#L12104-L12105 // https://github.com/ggerganov/llama.cpp/blob/1f0bccb27929e261744c979bc75114955da49e98/ggml.c#L12104-L12105
// The x0 and x1 value are interleaved on the n_embd (= head_dim) dimension. // The x0 and x1 value are interleaved on the n_embd (= head_dim) dimension.
// The resulting y0 and y1 are also interleaved with: // The resulting y0 and y1 are also interleaved with:
// y0 = x0*cos - x1*sin // y0 = x0*cos - x1*sin
// y1 = x0*sin + x1*cos // y1 = x0*sin + x1*cos
let span = tracing::span!(tracing::Level::TRACE, "attn-rot-reshape");
let _enter = span.enter();
let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?; let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
let x0 = x.narrow(D::Minus1, 0, 1)?; let x0 = x.narrow(D::Minus1, 0, 1)?;
let x1 = x.narrow(D::Minus1, 1, 1)?; let x1 = x.narrow(D::Minus1, 1, 1)?;
drop(_enter);
let span = tracing::span!(tracing::Level::TRACE, "attn-rot-broad-mul");
let _enter = span.enter();
let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?; let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?; let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
drop(_enter);
let span = tracing::span!(tracing::Level::TRACE, "attn-rot-cat");
let _enter = span.enter();
let rope = Tensor::cat(&[y0, y1], D::Minus1)?; let rope = Tensor::cat(&[y0, y1], D::Minus1)?;
drop(_enter);
let span = tracing::span!(tracing::Level::TRACE, "attn-rot-flatten");
let _enter = span.enter();
let rope = rope.flatten_from(D::Minus2)?; let rope = rope.flatten_from(D::Minus2)?;
drop(_enter);
Ok(rope) Ok(rope)
} }
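
Editor's note: the comment inside `apply_rotary_emb` above spells out the interleaved rotation rule borrowed from llama.cpp. A standalone sketch of that rule on plain floats (illustrative only, not part of this diff; `x0`/`x1` are one interleaved pair from the head dimension, `cos`/`sin` come from the precomputed tables narrowed to the current position):

```rust
// Interleaved RoPE rule from the comment above, applied to a single pair.
fn rotate_pair(x0: f32, x1: f32, cos: f32, sin: f32) -> (f32, f32) {
    let y0 = x0 * cos - x1 * sin;
    let y1 = x0 * sin + x1 * cos;
    (y0, y1)
}

// rotate_pair(1.0, 0.0, 0.0, 1.0) == (0.0, 1.0): a quarter turn moves x0 into x1.
```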
@ -112,6 +130,7 @@ impl LayerWeights {
let q = self.attention_wq.forward(x)?; let q = self.attention_wq.forward(x)?;
let k = self.attention_wk.forward(x)?; let k = self.attention_wk.forward(x)?;
let v = self.attention_wv.forward(x)?; let v = self.attention_wv.forward(x)?;
// println!("Q {:?} K {:?} V {:?}", q.dtype(), k.dtype(), v.dtype());
let q = q let q = q
.reshape((b_sz, seq_len, self.n_head, self.head_dim))? .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
@ -145,9 +164,12 @@ impl LayerWeights {
let v = self.repeat_kv(v)?; let v = self.repeat_kv(v)?;
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
// println!("att {:?}", att.dtype());
let mask = mask.broadcast_as(att.shape())?; let mask = mask.broadcast_as(att.shape())?;
// println!("mask {:?}", mask.dtype());
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?; let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
let att = candle_nn::ops::softmax_last_dim(&att)?; let att = candle_nn::ops::softmax_last_dim(&att)?;
// println!("att {:?} v {:?}", att.dtype(), v.dtype());
// Convert to contiguous as matmul doesn't support strided vs for now. // Convert to contiguous as matmul doesn't support strided vs for now.
let y = att.matmul(&v.contiguous()?)?; let y = att.matmul(&v.contiguous()?)?;
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?; let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
@ -181,28 +203,37 @@ pub struct ModelWeights {
span_output: tracing::Span, span_output: tracing::Span,
} }
fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tensor)> { fn precomput_freqs_cis(
head_dim: usize,
freq_base: f32,
device: &Device,
) -> Result<(Tensor, Tensor)> {
let theta: Vec<_> = (0..head_dim) let theta: Vec<_> = (0..head_dim)
.step_by(2) .step_by(2)
.map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32)) .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
.collect(); .collect();
let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?; let theta = Tensor::new(theta.as_slice(), device)?;
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)? let range: Vec<f32> = (0..MAX_SEQ_LEN).map(|r| r as f32).collect();
.to_dtype(DType::F32)? let idx_theta = Tensor::new(range.as_slice(), device)?
.reshape((MAX_SEQ_LEN, 1))? .reshape((MAX_SEQ_LEN, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?; .matmul(&theta.reshape((1, theta.elem_count()))?)?;
// TODO This change avoids allocating on Metal and then casting since allocating directly on
// CPU as f32 seems just as fast
// let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
// .to_dtype(DType::F32)?
// .reshape((MAX_SEQ_LEN, 1))?
// .matmul(&theta.reshape((1, theta.elem_count()))?)?;
let cos = idx_theta.cos()?; let cos = idx_theta.cos()?;
let sin = idx_theta.sin()?; let sin = idx_theta.sin()?;
Ok((cos, sin)) Ok((cos, sin))
} }
impl ModelWeights { impl ModelWeights {
pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> { pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize, device: &Device) -> Result<Self> {
let cpu = &Device::Cpu;
let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize; let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
let (cos, sin) = precomput_freqs_cis(head_dim, 10000.)?; let (cos, sin) = precomput_freqs_cis(head_dim, 10000., device)?;
let tok_embeddings = ct.remove("tok_embeddings.weight")?; let tok_embeddings = ct.remove("tok_embeddings.weight")?;
let tok_embeddings = tok_embeddings.dequantize(cpu)?; let tok_embeddings = tok_embeddings.dequantize(device)?;
let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?; let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
let output = ct.remove("output.weight")?; let output = ct.remove("output.weight")?;
let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize); let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
@ -257,8 +288,8 @@ impl ModelWeights {
pub fn from_gguf<R: std::io::Seek + std::io::Read>( pub fn from_gguf<R: std::io::Seek + std::io::Read>(
ct: gguf_file::Content, ct: gguf_file::Content,
reader: &mut R, reader: &mut R,
device: &Device,
) -> Result<Self> { ) -> Result<Self> {
let cpu = &Device::Cpu;
let md_get = |s: &str| match ct.metadata.get(s) { let md_get = |s: &str| match ct.metadata.get(s) {
None => candle::bail!("cannot find {s} in metadata"), None => candle::bail!("cannot find {s} in metadata"),
Some(v) => Ok(v), Some(v) => Ok(v),
@ -276,24 +307,31 @@ impl ModelWeights {
let rope_freq_base = md_get("llama.rope.freq_base") let rope_freq_base = md_get("llama.rope.freq_base")
.and_then(|m| m.to_f32()) .and_then(|m| m.to_f32())
.unwrap_or(10000f32); .unwrap_or(10000f32);
let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?; let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?;
let tok_embeddings = ct.tensor(reader, "token_embd.weight")?; let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
let tok_embeddings = tok_embeddings.dequantize(cpu)?; let tok_embeddings = tok_embeddings.dequantize(device)?;
let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?, rms_norm_eps)?; let norm = RmsNorm::new(
let output = ct.tensor(reader, "output.weight")?; ct.tensor(reader, "output_norm.weight", device)?,
rms_norm_eps,
)?;
let output = ct.tensor(reader, "output.weight", device)?;
let mut layers = Vec::with_capacity(block_count); let mut layers = Vec::with_capacity(block_count);
for layer_idx in 0..block_count { for layer_idx in 0..block_count {
let prefix = format!("blk.{layer_idx}"); let prefix = format!("blk.{layer_idx}");
let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"))?; let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?; let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?; let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?; let attention_wo =
let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?; ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?; let feed_forward_w1 =
let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?; ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?; let feed_forward_w2 =
let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?; ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
let attention_norm =
ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?;
let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?;
let span_attn = tracing::span!(tracing::Level::TRACE, "attn"); let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp"); let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
@ -331,14 +369,14 @@ impl ModelWeights {
}) })
} }
fn mask(&mut self, t: usize) -> Result<Tensor> { fn mask(&mut self, t: usize, device: &Device) -> Result<Tensor> {
if let Some(mask) = self.masks.get(&t) { if let Some(mask) = self.masks.get(&t) {
Ok(mask.clone()) Ok(mask.clone())
} else { } else {
let mask: Vec<_> = (0..t) let mask: Vec<_> = (0..t)
.flat_map(|i| (0..t).map(move |j| u8::from(j > i))) .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
.collect(); .collect();
let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?; let mask = Tensor::from_slice(&mask, (t, t), device)?;
self.masks.insert(t, mask.clone()); self.masks.insert(t, mask.clone());
Ok(mask) Ok(mask)
} }
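
Editor's note: the cached mask built above stores 1 wherever a position is not allowed to attend (`j > i`). A small illustration of the pattern it produces, using the same expression as the diff (sketch only, no candle types involved):

```rust
// Same j > i rule as the `mask` helper above, for an arbitrary sequence length.
fn causal_mask(t: usize) -> Vec<u8> {
    (0..t)
        .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
        .collect()
}

// causal_mask(3) == [0, 1, 1,
//                    0, 0, 1,
//                    0, 0, 0]
```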
@ -346,7 +384,7 @@ impl ModelWeights {
pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> { pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (_b_sz, seq_len) = x.dims2()?; let (_b_sz, seq_len) = x.dims2()?;
let mask = self.mask(seq_len)?; let mask = self.mask(seq_len, x.device())?;
let _enter = self.span.enter(); let _enter = self.span.enter();
let mut layer_in = self.tok_embeddings.forward(x)?; let mut layer_in = self.tok_embeddings.forward(x)?;
for layer in self.layers.iter_mut() { for layer in self.layers.iter_mut() {
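
Editor's note: with the device-aware signatures introduced in this file (`precomput_freqs_cis`, `from_ggml`, `from_gguf` and `mask` all take a `&Device`), a loader can place the quantized tensors directly on the target device. A minimal sketch assuming the new `from_gguf(content, reader, device)` signature; the file path, device choice and module path are placeholders:

```rust
// Minimal sketch, assuming the new `from_gguf(ct, reader, device)` signature above.
// The path and the CPU device are placeholders; a Metal device would be passed the same way.
use candle::quantized::gguf_file;
use candle::Device;
// Module path assumed; `ModelWeights` is the struct defined in this file.
use candle_transformers::models::quantized_llama::ModelWeights;

fn load_quantized(path: &str, device: &Device) -> candle::Result<ModelWeights> {
    let mut file = std::fs::File::open(path)?;
    let content = gguf_file::Content::read(&mut file)?;
    // Tensors are now created directly on `device` instead of always on the CPU.
    ModelWeights::from_gguf(content, &mut file, device)
}
```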

View File

@ -1,7 +1,6 @@
// T5 Text Model, quantized version // T5 Text Model, quantized version
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py // https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
use crate::models::t5::{deserialize_feed_forward_proj_activation, ActivationWithOptionalGating};
use crate::models::with_tracing::QMatMul; use crate::models::with_tracing::QMatMul;
use crate::quantized_nn::Embedding; use crate::quantized_nn::Embedding;
pub use crate::quantized_var_builder::VarBuilder; pub use crate::quantized_var_builder::VarBuilder;
@ -55,8 +54,8 @@ pub struct Config {
dropout_rate: f64, dropout_rate: f64,
layer_norm_epsilon: f64, layer_norm_epsilon: f64,
initializer_factor: f64, initializer_factor: f64,
#[serde(default, deserialize_with = "deserialize_feed_forward_proj_activation")] #[serde(default)]
pub feed_forward_proj: ActivationWithOptionalGating, feed_forward_proj: Activation,
#[serde(default = "default_tie_word_embeddings")] #[serde(default = "default_tie_word_embeddings")]
tie_word_embeddings: bool, tie_word_embeddings: bool,
#[serde(default = "default_is_decoder")] #[serde(default = "default_is_decoder")]
@ -66,7 +65,6 @@ pub struct Config {
pub use_cache: bool, pub use_cache: bool,
pub pad_token_id: usize, pub pad_token_id: usize,
pub eos_token_id: usize, pub eos_token_id: usize,
pub decoder_start_token_id: Option<usize>,
} }
impl Default for Config { impl Default for Config {
@ -84,17 +82,13 @@ impl Default for Config {
dropout_rate: 0.1, dropout_rate: 0.1,
layer_norm_epsilon: 1e-6, layer_norm_epsilon: 1e-6,
initializer_factor: 1.0, initializer_factor: 1.0,
feed_forward_proj: ActivationWithOptionalGating { feed_forward_proj: Activation::Relu,
gated: false,
activation: Activation::Relu,
},
tie_word_embeddings: true, tie_word_embeddings: true,
is_decoder: false, is_decoder: false,
is_encoder_decoder: true, is_encoder_decoder: true,
use_cache: true, use_cache: true,
pad_token_id: 0, pad_token_id: 0,
eos_token_id: 1, eos_token_id: 1,
decoder_start_token_id: Some(0),
} }
} }
} }
@ -180,7 +174,7 @@ impl T5DenseGatedActDense {
wi_0, wi_0,
wi_1, wi_1,
wo, wo,
act: cfg.feed_forward_proj.activation, act: Activation::NewGelu,
span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"), span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"),
}) })
} }
@ -209,7 +203,7 @@ impl T5LayerFF {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let layer_norm = let layer_norm =
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?; T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
let (dense_act, gated_dense_act) = if cfg.feed_forward_proj.gated { let (dense_act, gated_dense_act) = if cfg.feed_forward_proj == Activation::NewGelu {
( (
None, None,
Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?), Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?),
@ -648,12 +642,7 @@ pub struct T5EncoderModel {
impl T5EncoderModel { impl T5EncoderModel {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let shared_vb = if vb.contains_key("shared.weight") { let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
vb.pp("shared")
} else {
vb.pp("decoder").pp("embed_tokens")
};
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
let shared = Arc::new(shared); let shared = Arc::new(shared);
let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?; let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
Ok(Self { Ok(Self {
@ -694,12 +683,7 @@ impl T5ForConditionalGeneration {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
assert!(cfg.is_encoder_decoder); assert!(cfg.is_encoder_decoder);
let d_model = cfg.d_model; let d_model = cfg.d_model;
let shared_vb = if vb.contains_key("shared.weight") { let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
vb.pp("shared")
} else {
vb.pp("decoder").pp("embed_tokens")
};
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
let shared = Arc::new(shared); let shared = Arc::new(shared);
let mut encoder_cfg = cfg.clone(); let mut encoder_cfg = cfg.clone();

View File

@ -1,4 +1,3 @@
pub use crate::models::with_tracing::Linear;
use candle::{Result, Tensor}; use candle::{Result, Tensor};
use candle_nn::{Module, VarBuilder}; use candle_nn::{Module, VarBuilder};
@ -10,11 +9,13 @@ pub mod tiny_vit;
pub mod transformer; pub mod transformer;
pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> { pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
if bias { let inner = if bias {
crate::models::with_tracing::linear(in_dim, out_dim, vb) candle_nn::linear(in_dim, out_dim, vb)?
} else { } else {
crate::models::with_tracing::linear_no_bias(in_dim, out_dim, vb) candle_nn::linear_no_bias(in_dim, out_dim, vb)?
} };
let span = tracing::span!(tracing::Level::TRACE, "linear");
Ok(Linear { inner, span })
} }
#[derive(Debug)] #[derive(Debug)]
@ -84,3 +85,16 @@ impl Module for MlpBlock {
.apply(&self.lin2) .apply(&self.lin2)
} }
} }
#[derive(Debug)]
pub struct Linear {
inner: candle_nn::Linear,
span: tracing::Span,
}
impl Module for Linear {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
}
}
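
Editor's note: this hunk (like several others in the compare) wraps `candle_nn::Linear` in a TRACE span so individual ops show up when profiling. A hedged sketch of how such spans are typically collected, using the `tracing-chrome` and `tracing-subscriber` crates as in candle's examples; crate versions and setup details are assumptions:

```rust
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // Keep the guard alive for the whole run so the trace file is flushed on drop.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();

    // ... run the model here; every `span!(Level::TRACE, "linear")` enter/exit
    // recorded by the wrappers above becomes an interval in the chrome trace.
}
```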

View File

@ -102,14 +102,6 @@ impl Config {
} }
} }
pub fn ssd1b() -> Self {
Self::sdxl()
}
pub fn ssd1b2() -> Self {
Self::sdxl2()
}
// https://huggingface.co/warp-ai/wuerstchen/blob/main/text_encoder/config.json // https://huggingface.co/warp-ai/wuerstchen/blob/main/text_encoder/config.json
pub fn wuerstchen() -> Self { pub fn wuerstchen() -> Self {
Self { Self {

View File

@ -249,71 +249,6 @@ impl StableDiffusionConfig {
) )
} }
pub fn ssd1b(
sliced_attention_size: Option<usize>,
height: Option<usize>,
width: Option<usize>,
) -> Self {
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
out_channels,
use_cross_attn,
attention_head_dim,
};
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/unet/config.json
let unet = unet_2d::UNet2DConditionModelConfig {
blocks: vec![
bc(320, None, 5),
bc(640, Some(2), 10),
bc(1280, Some(10), 20),
],
center_input_sample: false,
cross_attention_dim: 2048,
downsample_padding: 1,
flip_sin_to_cos: true,
freq_shift: 0.,
layers_per_block: 2,
mid_block_scale_factor: 1.,
norm_eps: 1e-5,
norm_num_groups: 32,
sliced_attention_size,
use_linear_projection: true,
};
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/vae/config.json
let autoencoder = vae::AutoEncoderKLConfig {
block_out_channels: vec![128, 256, 512, 512],
layers_per_block: 2,
latent_channels: 4,
norm_num_groups: 32,
};
let scheduler = ddim::DDIMSchedulerConfig {
..Default::default()
};
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "height has to be divisible by 8");
height
} else {
1024
};
let width = if let Some(width) = width {
assert_eq!(width % 8, 0, "width has to be divisible by 8");
width
} else {
1024
};
Self {
width,
height,
clip: clip::Config::ssd1b(),
clip2: Some(clip::Config::ssd1b2()),
autoencoder,
scheduler,
unet,
}
}
pub fn build_vae<P: AsRef<std::path::Path>>( pub fn build_vae<P: AsRef<std::path::Path>>(
&self, &self,
vae_weights: P, vae_weights: P,

View File

@ -37,37 +37,6 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
Ok(m) Ok(m)
} }
#[derive(Debug, Deserialize, Default, Clone, PartialEq)]
pub struct ActivationWithOptionalGating {
pub gated: bool,
pub activation: candle_nn::Activation,
}
pub fn deserialize_feed_forward_proj_activation<'de, D>(
deserializer: D,
) -> std::result::Result<ActivationWithOptionalGating, D::Error>
where
D: serde::de::Deserializer<'de>,
{
match String::deserialize(deserializer)?.as_str() {
"gated-gelu" => Ok(ActivationWithOptionalGating {
gated: true,
activation: candle_nn::Activation::NewGelu,
}),
"gated-silu" => Ok(ActivationWithOptionalGating {
gated: true,
activation: candle_nn::Activation::Silu,
}),
buf => {
let activation = serde_plain::from_str(buf).map_err(serde::de::Error::custom)?;
Ok(ActivationWithOptionalGating {
gated: false,
activation,
})
}
}
}
#[derive(Debug, Clone, PartialEq, Deserialize)] #[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config { pub struct Config {
vocab_size: usize, vocab_size: usize,
@ -83,8 +52,8 @@ pub struct Config {
dropout_rate: f64, dropout_rate: f64,
layer_norm_epsilon: f64, layer_norm_epsilon: f64,
initializer_factor: f64, initializer_factor: f64,
#[serde(default, deserialize_with = "deserialize_feed_forward_proj_activation")] #[serde(default)]
feed_forward_proj: ActivationWithOptionalGating, feed_forward_proj: Activation,
#[serde(default = "default_tie_word_embeddings")] #[serde(default = "default_tie_word_embeddings")]
tie_word_embeddings: bool, tie_word_embeddings: bool,
#[serde(default = "default_is_decoder")] #[serde(default = "default_is_decoder")]
@ -94,7 +63,6 @@ pub struct Config {
pub use_cache: bool, pub use_cache: bool,
pub pad_token_id: usize, pub pad_token_id: usize,
pub eos_token_id: usize, pub eos_token_id: usize,
pub decoder_start_token_id: Option<usize>,
} }
impl Default for Config { impl Default for Config {
@ -112,17 +80,13 @@ impl Default for Config {
dropout_rate: 0.1, dropout_rate: 0.1,
layer_norm_epsilon: 1e-6, layer_norm_epsilon: 1e-6,
initializer_factor: 1.0, initializer_factor: 1.0,
feed_forward_proj: ActivationWithOptionalGating { feed_forward_proj: Activation::Relu,
gated: false,
activation: Activation::Relu,
},
tie_word_embeddings: true, tie_word_embeddings: true,
is_decoder: false, is_decoder: false,
is_encoder_decoder: true, is_encoder_decoder: true,
use_cache: true, use_cache: true,
pad_token_id: 0, pad_token_id: 0,
eos_token_id: 1, eos_token_id: 1,
decoder_start_token_id: Some(0),
} }
} }
} }
@ -136,10 +100,7 @@ impl Config {
d_model: 768, d_model: 768,
dropout_rate: 0.1, dropout_rate: 0.1,
eos_token_id: 1, eos_token_id: 1,
feed_forward_proj: ActivationWithOptionalGating { feed_forward_proj: Activation::Relu,
gated: false,
activation: Activation::Relu,
},
tie_word_embeddings: true, tie_word_embeddings: true,
initializer_factor: 1.0, initializer_factor: 1.0,
is_decoder: false, is_decoder: false,
@ -149,7 +110,6 @@ impl Config {
num_heads: 12, num_heads: 12,
num_layers: 12, num_layers: 12,
pad_token_id: 0, pad_token_id: 0,
decoder_start_token_id: Some(0),
relative_attention_max_distance: 128, relative_attention_max_distance: 128,
relative_attention_num_buckets: 32, relative_attention_num_buckets: 32,
use_cache: true, use_cache: true,
@ -239,7 +199,7 @@ impl T5DenseGatedActDense {
wi_0, wi_0,
wi_1, wi_1,
wo, wo,
act: cfg.feed_forward_proj.activation, act: Activation::NewGelu,
span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"), span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"),
}) })
} }
@ -268,7 +228,7 @@ impl T5LayerFF {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let layer_norm = let layer_norm =
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?; T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
let (dense_act, gated_dense_act) = if cfg.feed_forward_proj.gated { let (dense_act, gated_dense_act) = if cfg.feed_forward_proj == Activation::NewGelu {
( (
None, None,
Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?), Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?),
@ -462,7 +422,7 @@ impl T5Attention {
self.relative_attention_max_distance as f32 self.relative_attention_max_distance as f32
/ max_exact as f32, / max_exact as f32,
) * (num_buckets - max_exact) as f32; ) * (num_buckets - max_exact) as f32;
u32::min(max_exact + b as u32, num_buckets - 1) max_exact + b as u32
} }
}) })
.collect::<Vec<u32>>() .collect::<Vec<u32>>()
@ -707,12 +667,7 @@ pub struct T5EncoderModel {
impl T5EncoderModel { impl T5EncoderModel {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let shared_vb = if vb.contains_tensor("shared.weight") { let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
vb.pp("shared")
} else {
vb.pp("decoder").pp("embed_tokens")
};
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
let shared = Arc::new(shared); let shared = Arc::new(shared);
let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?; let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
Ok(Self { Ok(Self {
@ -753,12 +708,7 @@ impl T5ForConditionalGeneration {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
assert!(cfg.is_encoder_decoder); assert!(cfg.is_encoder_decoder);
let d_model = cfg.d_model; let d_model = cfg.d_model;
let shared_vb = if vb.contains_tensor("shared.weight") { let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
vb.pp("shared")
} else {
vb.pp("decoder").pp("embed_tokens")
};
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
let shared = Arc::new(shared); let shared = Arc::new(shared);
let mut encoder_cfg = cfg.clone(); let mut encoder_cfg = cfg.clone();

View File

@ -1,434 +0,0 @@
use crate::models::vit::{Config, Embeddings, Encoder};
use candle::{Result, Tensor};
use candle_nn::{
embedding, layer_norm, linear_no_bias, Embedding, LayerNorm, Linear, Module, VarBuilder,
};
use serde::Deserialize;
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct TrOCRConfig {
pub vocab_size: usize,
pub d_model: usize,
pub hidden_size: usize,
pub decoder_layers: usize,
pub decoder_attention_heads: usize,
pub decoder_ffn_dim: usize,
pub activation_function: candle_nn::Activation,
pub max_position_embeddings: usize,
pub dropout: f64,
pub attention_dropout: f64,
pub activation_dropout: f64,
pub decoder_start_token_id: u32,
pub init_std: f64,
pub decoder_layerdrop: f64,
pub use_cache: bool,
pub scale_embedding: bool,
pub use_learned_position_embeddings: bool,
pub layernorm_embedding: bool,
pub pad_token_id: usize,
pub bos_token_id: usize,
pub eos_token_id: u32,
pub num_attention_heads: usize,
pub decoder_vocab_size: Option<usize>,
}
impl Default for TrOCRConfig {
fn default() -> Self {
Self {
vocab_size: 50265,
d_model: 1024,
hidden_size: 768,
decoder_layers: 12,
decoder_attention_heads: 16,
decoder_ffn_dim: 4096,
activation_function: candle_nn::Activation::Gelu,
max_position_embeddings: 512,
dropout: 0.1,
attention_dropout: 0.0,
activation_dropout: 0.0,
decoder_start_token_id: 2,
init_std: 0.02,
decoder_layerdrop: 0.0,
use_cache: true,
scale_embedding: false,
use_learned_position_embeddings: true,
layernorm_embedding: true,
pad_token_id: 1,
bos_token_id: 0,
eos_token_id: 2,
num_attention_heads: 12,
decoder_vocab_size: Some(50265),
}
}
}
#[derive(Debug, Clone)]
struct TrOCRLearnedPositionalEmbedding {
offset: usize,
weights: Embedding,
}
impl TrOCRLearnedPositionalEmbedding {
fn load(vb: VarBuilder, cfg: &TrOCRConfig) -> Result<Self> {
let offset: usize = 2;
let num_embeddings = cfg.max_position_embeddings;
let embedding_dim = cfg.d_model;
let weights = embedding(num_embeddings + offset, embedding_dim, vb)?;
Ok(Self { offset, weights })
}
fn forward(&mut self, input_ids: &Tensor, past_key_values_length: u32) -> Result<Tensor> {
let (b_sz, seq_len) = input_ids.dims2()?;
let mut positions = Tensor::arange(
past_key_values_length,
seq_len as u32 + past_key_values_length,
input_ids.device(),
)?
.expand((b_sz, seq_len))?;
positions =
positions.broadcast_add(&Tensor::new(self.offset as u32, input_ids.device())?)?;
self.weights.forward(&positions)
}
}
#[derive(Debug, Clone)]
struct TrOCRAttention {
head_dim: usize,
num_heads: usize,
is_decoder: bool,
scaling: f64,
k_proj: Linear,
v_proj: Linear,
q_proj: Linear,
out_proj: Linear,
kv_cache: Option<(Tensor, Tensor)>,
}
impl TrOCRAttention {
fn load(
vb: VarBuilder,
cfg: &TrOCRConfig,
kdim: Option<usize>,
vdim: Option<usize>,
) -> Result<Self> {
let embed_dim = cfg.d_model;
let num_heads = cfg.decoder_attention_heads;
let head_dim = embed_dim / num_heads;
let kdim = kdim.unwrap_or(embed_dim);
let vdim = vdim.unwrap_or(embed_dim);
let k_proj = linear_no_bias(kdim, embed_dim, vb.pp("k_proj"))?;
let v_proj = linear_no_bias(vdim, embed_dim, vb.pp("v_proj"))?;
let q_proj = linear_no_bias(embed_dim, embed_dim, vb.pp("q_proj"))?;
let out_proj = linear_no_bias(embed_dim, embed_dim, vb.pp("out_proj"))?;
Ok(Self {
head_dim,
num_heads,
is_decoder: true,
scaling: 1. / (head_dim as f64).sqrt(),
k_proj,
v_proj,
q_proj,
out_proj,
kv_cache: None,
})
}
fn _shape(&self, tensor: &Tensor, bsz: usize) -> Result<Tensor> {
tensor
.reshape((bsz, (), self.num_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()
}
fn forward(
&mut self,
xs: &Tensor,
kv_states: Option<&Tensor>,
attn_mask: Option<&Tensor>,
) -> Result<Tensor> {
let (b_sz, tgt_len, _) = xs.dims3()?;
let query_states = (xs.apply(&self.q_proj)? * self.scaling)?;
let (key_states, value_states) = match kv_states {
None => {
let key_states = self._shape(&xs.apply(&self.k_proj)?, b_sz)?;
let value_states = self._shape(&xs.apply(&self.v_proj)?, b_sz)?;
if self.is_decoder {
let kv_states = match &self.kv_cache {
None => (key_states, value_states),
Some((p_key_states, p_value_states)) => {
let key_states = Tensor::cat(&[p_key_states, &key_states], 2)?;
let value_states = Tensor::cat(&[p_value_states, &value_states], 2)?;
(key_states, value_states)
}
};
self.kv_cache = Some(kv_states.clone());
kv_states
} else {
(key_states, value_states)
}
}
Some(kv_states) => {
let key_states = self._shape(&kv_states.apply(&self.k_proj)?, b_sz)?;
let value_states = self._shape(&kv_states.apply(&self.v_proj)?, b_sz)?;
(key_states, value_states)
}
};
let proj_shape = (b_sz * self.num_heads, (), self.head_dim);
let query_states = self._shape(&query_states, b_sz)?.reshape(proj_shape)?;
let key_states = key_states.reshape(proj_shape)?;
let value_states = value_states.reshape(proj_shape)?;
let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
let attn_weights = match attn_mask {
None => attn_weights,
Some(attn_mask) => attn_weights.broadcast_add(attn_mask)?,
};
let attn_probs = candle_nn::ops::softmax_last_dim(&attn_weights)?;
let attn_output = attn_probs.matmul(&value_states)?;
attn_output
.reshape((b_sz, self.num_heads, tgt_len, self.head_dim))?
.transpose(1, 2)?
.reshape((b_sz, tgt_len, self.head_dim * self.num_heads))?
.apply(&self.out_proj)
}
}
#[derive(Debug, Clone)]
struct TrOCRDecoderLayer {
self_attn: TrOCRAttention,
activation_fn: candle_nn::Activation,
self_attn_layer_norm: LayerNorm,
encoder_attn: TrOCRAttention,
encoder_attn_layer_norm: LayerNorm,
fc1: Linear,
fc2: Linear,
final_layer_norm: LayerNorm,
}
impl TrOCRDecoderLayer {
fn load(vb: VarBuilder, cfg: &TrOCRConfig) -> Result<Self> {
let embed_dim = cfg.d_model;
let self_attn = TrOCRAttention::load(vb.pp("self_attn"), cfg, None, None)?;
let self_attn_layer_norm = layer_norm(embed_dim, 1e-5, vb.pp("self_attn_layer_norm"))?;
let encoder_attn = TrOCRAttention::load(
vb.pp("encoder_attn"),
cfg,
Some(cfg.hidden_size),
Some(cfg.hidden_size),
)?;
let encoder_attn_layer_norm =
layer_norm(embed_dim, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
let fc1 = linear_no_bias(embed_dim, cfg.decoder_ffn_dim, vb.pp("fc1"))?;
let fc2 = linear_no_bias(cfg.decoder_ffn_dim, embed_dim, vb.pp("fc2"))?;
let final_layer_norm = layer_norm(embed_dim, 1e-5, vb.pp("final_layer_norm"))?;
let activation_fn = candle_nn::Activation::Gelu;
Ok(Self {
self_attn,
activation_fn,
self_attn_layer_norm,
encoder_attn,
encoder_attn_layer_norm,
fc1,
fc2,
final_layer_norm,
})
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: &Tensor,
encoder_hidden_states: Option<&Tensor>,
) -> Result<Tensor> {
let residual = xs.clone();
let xs = self.self_attn.forward(xs, None, Some(attention_mask))?;
let xs = (xs + residual)?;
let mut xs = self.self_attn_layer_norm.forward(&xs)?;
if let Some(encoder_hidden_states) = &encoder_hidden_states {
let residual = xs.clone();
let encoder_attention_mask = attention_mask.clone(); // TODO
xs = self.encoder_attn.forward(
&xs,
Some(encoder_hidden_states),
Some(&encoder_attention_mask),
)?;
xs = (xs + residual)?;
xs = self.encoder_attn_layer_norm.forward(&xs)?
}
let residual = xs.clone();
let xs = self.fc1.forward(&xs)?;
let xs = self.activation_fn.forward(&xs)?;
let xs = self.fc2.forward(&xs)?;
let xs = (xs + residual)?;
let xs = self.final_layer_norm.forward(&xs)?;
Ok(xs)
}
}
#[derive(Debug, Clone)]
pub struct TrOCRDecoder {
layers: Vec<TrOCRDecoderLayer>,
embed_scale: Option<f64>,
embed_tokens: Embedding,
embed_positions: TrOCRLearnedPositionalEmbedding,
}
impl TrOCRDecoder {
fn new(cfg: &TrOCRConfig, vb: VarBuilder) -> Result<Self> {
let vb = vb.pp("decoder.model.decoder");
let embed_tokens = embedding(cfg.vocab_size, cfg.d_model, vb.pp("embed_tokens"))?;
let embed_positions = TrOCRLearnedPositionalEmbedding::load(vb.pp("embed_positions"), cfg)?;
let mut layers = Vec::with_capacity(cfg.decoder_layers);
let vb_l = vb.pp("layers");
for idx in 0..cfg.decoder_layers {
let layer = TrOCRDecoderLayer::load(vb_l.pp(idx), cfg)?;
layers.push(layer)
}
let embed_scale = if cfg.scale_embedding {
Some((cfg.d_model as f64).sqrt())
} else {
None
};
Ok(Self {
layers,
embed_scale,
embed_tokens,
embed_positions,
})
}
pub fn forward(
&mut self,
xs: &Tensor,
encoder_xs: Option<&Tensor>,
past_kv_len: usize,
attn_mask: &Tensor,
) -> Result<Tensor> {
let embed_pos = self.embed_positions.forward(xs, past_kv_len as u32)?;
let xs = xs.apply(&self.embed_tokens)?;
let xs = match self.embed_scale {
None => xs,
Some(scale) => (xs * scale)?,
};
let mut xs = xs.broadcast_add(&embed_pos)?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, attn_mask, encoder_xs)?;
}
Ok(xs)
}
}
#[derive(Debug, Clone)]
pub struct TrOCREncoder {
embeddings: Embeddings,
encoder: Encoder,
layernorm: LayerNorm,
}
impl TrOCREncoder {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb_v = vb.pp("encoder");
let embeddings = Embeddings::new(cfg, false, vb_v.pp("embeddings"))?;
let encoder = Encoder::new(cfg, vb_v.pp("encoder"))?;
let layernorm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb_v.pp("layernorm"))?;
Ok(Self {
embeddings,
encoder,
layernorm,
})
}
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let embedding_output = self.embeddings.forward(xs, None, false)?;
let encoder_outputs = self.encoder.forward(&embedding_output)?;
self.layernorm.forward(&encoder_outputs)
}
}
#[derive(Debug, Clone)]
pub struct TrOCRForCausalLM {
decoder: TrOCRDecoder,
output_projection: Linear,
}
impl TrOCRForCausalLM {
pub fn new(decoder_cfg: &TrOCRConfig, vb: VarBuilder) -> Result<Self> {
let decoder = TrOCRDecoder::new(decoder_cfg, vb.clone())?;
let output_projection =
candle_nn::Linear::new(decoder.embed_tokens.embeddings().clone(), None);
Ok(Self {
decoder,
output_projection,
})
}
pub fn forward(
&mut self,
xs: &Tensor,
encoder_xs: Option<&Tensor>,
past_kv_len: usize,
attn_mask: &Tensor,
) -> Result<Tensor> {
let xs = self
.decoder
.forward(xs, encoder_xs, past_kv_len, attn_mask)?;
let xs = xs.apply(&self.output_projection)?;
Ok(xs)
}
}
#[derive(Debug, Clone)]
pub struct TrOCRModel {
encoder: TrOCREncoder,
decoder: TrOCRForCausalLM,
}
impl TrOCRModel {
pub fn new(encoder_cfg: &Config, decoder_cfg: &TrOCRConfig, vb: VarBuilder) -> Result<Self> {
let encoder = TrOCREncoder::new(encoder_cfg, vb.clone())?;
let decoder = TrOCRForCausalLM::new(decoder_cfg, vb)?;
Ok(Self { encoder, decoder })
}
pub fn encoder(&mut self) -> &mut TrOCREncoder {
&mut self.encoder
}
pub fn decoder(&mut self) -> &mut TrOCRForCausalLM {
&mut self.decoder
}
pub fn decode(
&mut self,
xs: &Tensor,
encoder_xs: &Tensor,
past_kv_len: usize,
) -> Result<Tensor> {
let seq_len = xs.dim(1)?;
let mask: Vec<_> = (0..seq_len)
.flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
.collect();
let mask = Tensor::from_vec(mask, (seq_len, seq_len), xs.device())?;
self.decoder
.forward(xs, Some(encoder_xs), past_kv_len, &mask)
}
}

View File

@ -6,16 +6,16 @@ use candle_nn::{layer_norm, LayerNorm, VarBuilder};
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/configuration_vit.py // https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/configuration_vit.py
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Config { pub struct Config {
pub hidden_size: usize, hidden_size: usize,
pub num_hidden_layers: usize, num_hidden_layers: usize,
pub num_attention_heads: usize, num_attention_heads: usize,
pub intermediate_size: usize, intermediate_size: usize,
pub hidden_act: candle_nn::Activation, hidden_act: candle_nn::Activation,
pub layer_norm_eps: f64, layer_norm_eps: f64,
pub image_size: usize, image_size: usize,
pub patch_size: usize, patch_size: usize,
pub num_channels: usize, num_channels: usize,
pub qkv_bias: bool, qkv_bias: bool,
} }
impl Config { impl Config {
@ -34,21 +34,6 @@ impl Config {
qkv_bias: true, qkv_bias: true,
} }
} }
pub fn microsoft_trocr_base_handwritten() -> Self {
Self {
hidden_size: 768,
num_hidden_layers: 12,
num_attention_heads: 12,
intermediate_size: 3072,
hidden_act: candle_nn::Activation::Gelu,
layer_norm_eps: 1e-12,
image_size: 384,
patch_size: 16,
num_channels: 3,
qkv_bias: false,
}
}
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -91,7 +76,7 @@ impl Module for PatchEmbeddings {
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Embeddings { struct Embeddings {
cls_token: Tensor, cls_token: Tensor,
mask_token: Option<Tensor>, mask_token: Option<Tensor>,
patch_embeddings: PatchEmbeddings, patch_embeddings: PatchEmbeddings,
@ -100,7 +85,7 @@ pub struct Embeddings {
} }
impl Embeddings { impl Embeddings {
pub fn new(cfg: &Config, use_mask_token: bool, vb: VarBuilder) -> Result<Self> { fn new(cfg: &Config, use_mask_token: bool, vb: VarBuilder) -> Result<Self> {
let hidden_size = cfg.hidden_size; let hidden_size = cfg.hidden_size;
let cls_token = vb.get((1, 1, hidden_size), "cls_token")?; let cls_token = vb.get((1, 1, hidden_size), "cls_token")?;
let mask_token = if use_mask_token { let mask_token = if use_mask_token {
@ -130,7 +115,7 @@ impl Embeddings {
todo!() todo!()
} }
pub fn forward( fn forward(
&self, &self,
pixel_values: &Tensor, pixel_values: &Tensor,
bool_masked_pos: Option<&Tensor>, bool_masked_pos: Option<&Tensor>,
@ -339,12 +324,12 @@ impl Module for Layer {
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Encoder { struct Encoder {
layers: Vec<Layer>, layers: Vec<Layer>,
} }
impl Encoder { impl Encoder {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb = vb.pp("layer"); let vb = vb.pp("layer");
let mut layers = Vec::with_capacity(cfg.num_hidden_layers); let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
for i in 0..cfg.num_hidden_layers { for i in 0..cfg.num_hidden_layers {

View File

@ -198,17 +198,13 @@ fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
mel mel
} }
pub fn pcm_to_mel<T: Float + std::fmt::Display>( pub fn pcm_to_mel<T: Float + std::fmt::Display>(samples: &[T], filters: &[T]) -> Vec<T> {
cfg: &super::Config,
samples: &[T],
filters: &[T],
) -> Vec<T> {
log_mel_spectrogram_( log_mel_spectrogram_(
samples, samples,
filters, filters,
super::N_FFT, super::N_FFT,
super::HOP_LENGTH, super::HOP_LENGTH,
cfg.num_mel_bins, super::N_MELS,
false, false,
) )
} }
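
Editor's note: a sketch of calling the two-argument `pcm_to_mel(samples, filters)` form shown above. The module paths and the filter-bank layout (`N_MELS` rows of `N_FFT / 2 + 1` coefficients, flattened) are assumptions, not taken from the diff:

```rust
use candle_transformers::models::whisper::{audio::pcm_to_mel, N_FFT, N_MELS, SAMPLE_RATE};

fn main() {
    let filters = vec![0f32; N_MELS * (N_FFT / 2 + 1)]; // placeholder filter bank
    let pcm = vec![0f32; SAMPLE_RATE]; // one second of silence
    let mel = pcm_to_mel(&pcm, &filters);
    println!("mel buffer has {} values ({} bins)", mel.len(), N_MELS);
}
```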

View File

@ -18,7 +18,6 @@ pub struct Config {
// pub n_text_state: usize, // pub n_text_state: usize,
pub decoder_attention_heads: usize, // n_text_head pub decoder_attention_heads: usize, // n_text_head
pub decoder_layers: usize, // n_text_layer pub decoder_layers: usize, // n_text_layer
#[serde(default)]
pub suppress_tokens: Vec<u32>, pub suppress_tokens: Vec<u32>,
} }
@ -27,6 +26,7 @@ pub const DTYPE: candle::DType = candle::DType::F32;
// Audio parameters. // Audio parameters.
pub const SAMPLE_RATE: usize = 16000; pub const SAMPLE_RATE: usize = 16000;
pub const N_FFT: usize = 400; pub const N_FFT: usize = 400;
pub const N_MELS: usize = 80;
pub const HOP_LENGTH: usize = 160; pub const HOP_LENGTH: usize = 160;
pub const CHUNK_LENGTH: usize = 30; pub const CHUNK_LENGTH: usize = 30;
pub const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk pub const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk

View File

@ -1,5 +1,4 @@
use super::Config; use super::Config;
use crate::models::with_tracing::{linear, linear_no_bias, Linear};
use candle::{Device, IndexOp, Result, Tensor, D}; use candle::{Device, IndexOp, Result, Tensor, D};
use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder}; use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
@ -7,6 +6,33 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Em
let embeddings = vb.get((vocab_size, hidden_size), "weight")?; let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
Ok(Embedding::new(embeddings, hidden_size)) Ok(Embedding::new(embeddings, hidden_size))
} }
//
// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
// model.
#[derive(Debug, Clone)]
pub struct Linear {
inner: candle_nn::Linear,
span: tracing::Span,
}
impl Linear {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
}
}
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let span = tracing::span!(tracing::Level::TRACE, "linear");
let inner = candle_nn::linear(size1, size2, vb)?;
Ok(Linear { inner, span })
}
fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let span = tracing::span!(tracing::Level::TRACE, "linear");
let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
Ok(Linear { inner, span })
}
fn conv1d( fn conv1d(
in_channels: usize, in_channels: usize,

View File

@ -124,34 +124,3 @@ impl std::fmt::Debug for QMatMul {
write!(f, "QMatMul") write!(f, "QMatMul")
} }
} }
#[derive(Clone, Debug)]
pub struct LayerNorm {
inner: candle_nn::LayerNorm,
span: tracing::Span,
}
impl LayerNorm {
pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
let inner = candle_nn::LayerNorm::new(weight, bias, eps);
let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
Self { inner, span }
}
}
impl Module for LayerNorm {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(xs)
}
}
pub fn layer_norm<C: Into<candle_nn::LayerNormConfig>>(
size: usize,
c: C,
vb: VarBuilder,
) -> Result<LayerNorm> {
let inner = candle_nn::layer_norm(size, c, vb)?;
let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
Ok(LayerNorm { inner, span })
}

View File

@ -1,377 +0,0 @@
/// https://huggingface.co/01-ai/Yi-6B/blob/main/modeling_yi.py
use crate::models::with_tracing::{linear_no_bias, Linear};
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use std::sync::Arc;
#[derive(Debug, Clone, PartialEq)]
pub struct Config {
pub(crate) vocab_size: usize,
pub(crate) hidden_size: usize,
pub(crate) intermediate_size: usize,
pub(crate) num_hidden_layers: usize,
pub(crate) num_attention_heads: usize,
pub(crate) num_key_value_heads: usize,
pub(crate) hidden_act: Activation,
pub(crate) max_position_embeddings: usize,
pub(crate) rms_norm_eps: f64,
pub(crate) rope_theta: f64,
}
impl Config {
pub fn config_6b() -> Self {
Self {
vocab_size: 64000,
hidden_size: 4096,
intermediate_size: 11008,
num_hidden_layers: 32,
num_attention_heads: 32,
num_key_value_heads: 4,
hidden_act: Activation::Silu,
max_position_embeddings: 4096,
rms_norm_eps: 1e-5,
rope_theta: 5_000_000.,
}
}
pub fn config_34b() -> Self {
Self {
vocab_size: 64000,
hidden_size: 7168,
intermediate_size: 20480,
num_hidden_layers: 60,
num_attention_heads: 56,
num_key_value_heads: 8,
hidden_act: Activation::Silu,
max_position_embeddings: 4096,
rms_norm_eps: 1e-5,
rope_theta: 5_000_000.,
}
}
}
#[derive(Debug, Clone)]
struct RmsNorm {
inner: candle_nn::RmsNorm,
span: tracing::Span,
}
impl RmsNorm {
fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
let inner = candle_nn::rms_norm(size, eps, vb)?;
Ok(Self { inner, span })
}
}
impl Module for RmsNorm {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
}
}
#[derive(Debug, Clone)]
struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
}
fn rotate_half(xs: &Tensor) -> Result<Tensor> {
let last_dim = xs.dim(D::Minus1)?;
let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
}
impl RotaryEmbedding {
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
let dim = cfg.hidden_size / cfg.num_attention_heads;
let max_seq_len = cfg.max_position_embeddings;
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
})
}
fn apply_rotary_emb_qkv(
&self,
q: &Tensor,
k: &Tensor,
seqlen_offset: usize,
) -> Result<(Tensor, Tensor)> {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
Ok((q_embed, k_embed))
}
}
#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
gate_proj: Linear,
up_proj: Linear,
down_proj: Linear,
act_fn: Activation,
}
impl MLP {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let hidden_sz = cfg.hidden_size;
let intermediate_sz = cfg.intermediate_size;
let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
Ok(Self {
gate_proj,
up_proj,
down_proj,
act_fn: cfg.hidden_act,
})
}
}
impl Module for MLP {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
let rhs = xs.apply(&self.up_proj)?;
(lhs * rhs)?.apply(&self.down_proj)
}
}
#[derive(Debug, Clone)]
struct Attention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
num_heads: usize,
num_kv_heads: usize,
num_kv_groups: usize,
head_dim: usize,
hidden_size: usize,
rotary_emb: Arc<RotaryEmbedding>,
kv_cache: Option<(Tensor, Tensor)>,
}
impl Attention {
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let hidden_sz = cfg.hidden_size;
let num_heads = cfg.num_attention_heads;
let num_kv_heads = cfg.num_key_value_heads;
let num_kv_groups = num_heads / num_kv_heads;
let head_dim = hidden_sz / num_heads;
let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
num_heads,
num_kv_heads,
num_kv_groups,
head_dim,
hidden_size: hidden_sz,
rotary_emb,
kv_cache: None,
})
}
fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
let n_rep = self.num_kv_groups;
if n_rep == 1 {
Ok(xs)
} else {
let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
xs.unsqueeze(2)?
.expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
.reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
}
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let (b_sz, q_len, _) = xs.dims3()?;
let query_states = self.q_proj.forward(xs)?;
let key_states = self.k_proj.forward(xs)?;
let value_states = self.v_proj.forward(xs)?;
let query_states = query_states
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
.transpose(1, 2)?;
let key_states = key_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let value_states = value_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let (query_states, key_states) =
self.rotary_emb
.apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
let (key_states, value_states) = match &self.kv_cache {
None => (key_states, value_states),
Some((prev_k, prev_v)) => {
let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
(key_states, value_states)
}
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));
let key_states = self.repeat_kv(key_states)?;
let value_states = self.repeat_kv(value_states)?;
let attn_output = {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
let attn_weights = match attention_mask {
None => attn_weights,
Some(mask) => attn_weights.broadcast_add(mask)?,
};
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
attn_weights.matmul(&value_states)?
};
attn_output
.transpose(1, 2)?
.reshape((b_sz, q_len, self.hidden_size))?
.apply(&self.o_proj)
}
}
#[derive(Debug, Clone)]
struct DecoderLayer {
self_attn: Attention,
mlp: MLP,
ln1: RmsNorm,
ln2: RmsNorm,
}
impl DecoderLayer {
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("ln1"))?;
let ln2 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("ln2"))?;
Ok(Self {
self_attn,
mlp,
ln1,
ln2,
})
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let residual = xs;
let xs = self.ln1.forward(xs)?;
let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = xs.apply(&self.ln2)?.apply(&self.mlp)?;
residual + xs
}
}
#[derive(Debug, Clone)]
pub struct Model {
embed_tokens: candle_nn::Embedding,
layers: Vec<DecoderLayer>,
norm: RmsNorm,
lm_head: Linear,
device: Device,
dtype: DType,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb_m = vb.pp("model");
let embed_tokens =
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb_l = vb_m.pp("layers");
for layer_idx in 0..cfg.num_hidden_layers {
let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
layers.push(layer)
}
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
Ok(Self {
embed_tokens,
layers,
norm,
lm_head,
device: vb.device().clone(),
dtype: vb.dtype(),
})
}
fn prepare_decoder_attention_mask(
&self,
b_size: usize,
tgt_len: usize,
seqlen_offset: usize,
) -> Result<Tensor> {
// Sliding window mask?
let mask: Vec<_> = (0..tgt_len)
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
.collect();
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
let mask = if seqlen_offset > 0 {
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
};
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
.to_dtype(self.dtype)
}
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let (b_size, seq_len) = input_ids.dims2()?;
let attention_mask = if seq_len <= 1 {
None
} else {
let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
Some(mask)
};
let mut xs = self.embed_tokens.forward(input_ids)?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
}
xs.narrow(1, seq_len - 1, 1)?
.apply(&self.norm)?
.apply(&self.lm_head)
}
}

View File

@ -10,12 +10,12 @@ pub struct VarBuilder {
} }
impl VarBuilder { impl VarBuilder {
pub fn from_gguf<P: AsRef<std::path::Path>>(p: P) -> Result<Self> { pub fn from_gguf<P: AsRef<std::path::Path>>(p: P, device: &Device) -> Result<Self> {
let mut file = std::fs::File::open(p)?; let mut file = std::fs::File::open(p)?;
let content = candle::quantized::gguf_file::Content::read(&mut file)?; let content = candle::quantized::gguf_file::Content::read(&mut file)?;
let mut data = std::collections::HashMap::new(); let mut data = std::collections::HashMap::new();
for tensor_name in content.tensor_infos.keys() { for tensor_name in content.tensor_infos.keys() {
let tensor = content.tensor(&mut file, tensor_name)?; let tensor = content.tensor(&mut file, tensor_name, device)?;
data.insert(tensor_name.to_string(), Arc::new(tensor)); data.insert(tensor_name.to_string(), Arc::new(tensor));
} }
Ok(Self { Ok(Self {
@ -25,12 +25,12 @@ impl VarBuilder {
}) })
} }
pub fn from_gguf_buffer(buffer: &[u8]) -> Result<Self> { pub fn from_gguf_buffer(buffer: &[u8], device: &Device) -> Result<Self> {
let mut cursor = std::io::Cursor::new(buffer); let mut cursor = std::io::Cursor::new(buffer);
let content = candle::quantized::gguf_file::Content::read(&mut cursor)?; let content = candle::quantized::gguf_file::Content::read(&mut cursor)?;
let mut data = std::collections::HashMap::new(); let mut data = std::collections::HashMap::new();
for tensor_name in content.tensor_infos.keys() { for tensor_name in content.tensor_infos.keys() {
let tensor = content.tensor(&mut cursor, tensor_name)?; let tensor = content.tensor(&mut cursor, tensor_name, device)?;
data.insert(tensor_name.to_string(), Arc::new(tensor)); data.insert(tensor_name.to_string(), Arc::new(tensor));
} }
Ok(Self { Ok(Self {
@ -90,8 +90,4 @@ impl VarBuilder {
pub fn device(&self) -> &Device { pub fn device(&self) -> &Device {
&self.device &self.device
} }
pub fn contains_key(&self, key: &str) -> bool {
self.data.contains_key(key)
}
} }
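
Editor's note: the quantized `VarBuilder` now takes a device at construction time. A short sketch of the updated call; the path, device choice, tensor name and shape are placeholders, and `get` is assumed to keep its existing shape/name signature:

```rust
use candle::Device;
use candle_transformers::quantized_var_builder::VarBuilder;

fn main() -> candle::Result<()> {
    let device = Device::Cpu; // a Metal device would be passed here instead
    let vb = VarBuilder::from_gguf("model-q4_0.gguf", &device)?;
    let _weight = vb.get((4096, 4096), "model.layers.0.attn_q.weight")?; // hypothetical name/shape
    Ok(())
}
```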

View File

@ -1,7 +1,5 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{ use candle_nn::{embedding, linear, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
embedding, linear_no_bias as linear, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder,
};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};

Some files were not shown because too many files have changed in this diff.