mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Compare commits
56 Commits
tmp-metal-
...
metal2-tmp
Author | SHA1 | Date | |
---|---|---|---|
d9c1f7e201 | |||
315ba4cf0c | |||
915f0e5b69 | |||
9975f2b239 | |||
d7cc660c68 | |||
c54ed0ab48 | |||
af5e77f409 | |||
8cf39d27ce | |||
e6697471bb | |||
73d02f4f57 | |||
f772213e84 | |||
2feb0b054f | |||
2d28497197 | |||
f3a4f3db76 | |||
7920b45c8a | |||
d4a45c936a | |||
c912d24570 | |||
d5c2a7b64b | |||
508f811b93 | |||
a773a4b22b | |||
5a363dbc26 | |||
abc4f698c5 | |||
a923e8b53a | |||
2a45bcf943 | |||
47f4ddb011 | |||
f365a075e5 | |||
60fdab4e17 | |||
928a9d906e | |||
d1d89bac1f | |||
39ad840a90 | |||
b5e4f84bed | |||
7051fb8098 | |||
dc68c130e4 | |||
bc9a1bf239 | |||
f7c957d64f | |||
8cbb9d0e6c | |||
bfe95115c6 | |||
6fa3151820 | |||
0a58886ccb | |||
3173b1ce3b | |||
ad63f20781 | |||
1cfc5d6d0c | |||
b07b2350b6 | |||
1b5063f3ca | |||
3b0d1e7d03 | |||
be4555c5a5 | |||
6975c65112 | |||
a2a20aeecc | |||
e08fbb6543 | |||
d39d0c40fd | |||
b97463098c | |||
fbd69f952c | |||
6c990a33ea | |||
1704f1b3ae | |||
693fad511c | |||
36fb84f038 |
2
.github/workflows/ci_cuda.yaml
vendored
2
.github/workflows/ci_cuda.yaml
vendored
@ -59,7 +59,7 @@ jobs:
|
||||
- name: Install Rust Stable
|
||||
run: curl https://sh.rustup.rs -sSf | sh -s -- -y
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- run: apt-get update -y && apt-get install libssl-dev -y
|
||||
- run: apt-get update -y && apt-get install libssl-dev protobuf-compiler -y
|
||||
- name: Test (cuda)
|
||||
run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda
|
||||
stop-runner:
|
||||
|
BIN
.github/workflows/maturin.yml
vendored
BIN
.github/workflows/maturin.yml
vendored
Binary file not shown.
8
.github/workflows/python.yml
vendored
8
.github/workflows/python.yml
vendored
@ -39,6 +39,12 @@ jobs:
|
||||
path: ~/.cargo/registry
|
||||
key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }}
|
||||
|
||||
- name: Install Protoc
|
||||
uses: arduino/setup-protoc@v2
|
||||
with:
|
||||
version: "25.0"
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Install
|
||||
working-directory: ./candle-pyo3
|
||||
run: |
|
||||
@ -46,7 +52,7 @@ jobs:
|
||||
source .env/bin/activate
|
||||
pip install -U pip
|
||||
pip install pytest maturin black
|
||||
python -m maturin develop -r
|
||||
python -m maturin develop -r --features onnx
|
||||
|
||||
- name: Check style
|
||||
working-directory: ./candle-pyo3
|
||||
|
@ -10,7 +10,12 @@ members = [
|
||||
"candle-wasm-examples/*",
|
||||
"candle-wasm-tests",
|
||||
]
|
||||
exclude = ["candle-flash-attn", "candle-kernels"]
|
||||
exclude = [
|
||||
"candle-flash-attn",
|
||||
"candle-kernels",
|
||||
"candle-metal-kernels",
|
||||
"candle-onnx",
|
||||
]
|
||||
resolver = "2"
|
||||
|
||||
[workspace.package]
|
||||
|
14
README.md
14
README.md
@ -51,7 +51,7 @@ For more advanced examples, please have a look at the following section.
|
||||
These online demos run entirely in your browser:
|
||||
- [yolo](https://huggingface.co/spaces/lmz/candle-yolo): pose estimation and
|
||||
object recognition.
|
||||
- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): text to speech.
|
||||
- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): speech recognition.
|
||||
- [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2): text generation.
|
||||
- [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm): text generation.
|
||||
- [Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation.
|
||||
@ -143,6 +143,11 @@ And then head over to
|
||||
including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
|
||||
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): a LoRA implementation
|
||||
that conforms to the official `peft` implementation.
|
||||
- [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
|
||||
serving local LLMs including an OpenAI compatible API server.
|
||||
- [`candle-ext`](https://github.com/mokeyish/candle-ext): an extension library to Candle that provides PyTorch functions not currently available in Candle.
|
||||
- [`kalosm`](https://github.com/floneum/floneum/tree/master/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
|
||||
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
|
||||
|
||||
If you have an addition to this list, please submit a pull request.
|
||||
|
||||
@ -168,16 +173,16 @@ If you have an addition to this list, please submit a pull request.
|
||||
- Mistral 7b v0.1.
|
||||
- StableLM-3B-4E1T.
|
||||
- Replit-code-v1.5-3B.
|
||||
- T5.
|
||||
- Bert.
|
||||
- Text to text.
|
||||
- T5 and its variants: FlanT5, MADLAD400 (translation), CoEdit (Grammar correction).
|
||||
- Marian MT (Machine Translation).
|
||||
- Whisper (multi-lingual support).
|
||||
- Text to image.
|
||||
- Stable Diffusion v1.5, v2.1, XL v1.0.
|
||||
- Wurstchen v2.
|
||||
- Image to text.
|
||||
- BLIP.
|
||||
- Text to text.
|
||||
- Marian MT (Machine Translation).
|
||||
- Computer Vision Models.
|
||||
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT.
|
||||
- yolo-v3, yolo-v8.
|
||||
@ -218,6 +223,7 @@ Cheatsheet:
|
||||
- [candle-datasets](./candle-datasets/): Datasets and data loaders.
|
||||
- [candle-transformers](./candle-transformers): transformers-related utilities.
|
||||
- [candle-flash-attn](./candle-flash-attn): Flash attention v2 layer.
|
||||
- [candle-onnx](./candle-onnx/): ONNX model evaluation.
|
||||
|
||||
## FAQ
|
||||
|
||||
|
@ -30,7 +30,6 @@ safetensors = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
yoke = { workspace = true }
|
||||
zip = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
@ -42,4 +41,4 @@ cuda = ["cudarc", "dep:candle-kernels"]
|
||||
cudnn = ["cuda", "cudarc/cudnn"]
|
||||
mkl = ["dep:libc", "dep:intel-mkl-src"]
|
||||
accelerate = ["dep:libc", "dep:accelerate-src"]
|
||||
metal = ["dep:candle-metal-kernels", "dep:metal"]
|
||||
metal = ["dep:metal", "dep:candle-metal-kernels"]
|
||||
|
@ -39,6 +39,14 @@ pub trait BackendStorage: Sized {
|
||||
_params: &crate::conv::ParamsConv1D,
|
||||
) -> Result<Self>;
|
||||
|
||||
fn conv_transpose1d(
|
||||
&self,
|
||||
_l: &Layout,
|
||||
_kernel: &Self,
|
||||
_kernel_l: &Layout,
|
||||
_params: &crate::conv::ParamsConvTranspose1D,
|
||||
) -> Result<Self>;
|
||||
|
||||
fn conv2d(
|
||||
&self,
|
||||
_l: &Layout,
|
||||
|
@ -15,6 +15,17 @@ fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result
|
||||
}
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
static CANDLE_GRAD_DO_NOT_DETACH: bool = {
|
||||
match std::env::var("CANDLE_GRAD_DO_NOT_DETACH") {
|
||||
Ok(s) => {
|
||||
!s.is_empty() && s != "0"
|
||||
},
|
||||
Err(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Tensor {
|
||||
/// Return all the nodes that lead to this value in a topologically sorted vec, the first
|
||||
/// elements having dependencies on the latter ones, e.g. the first element if any is the
|
||||
@ -57,6 +68,11 @@ impl Tensor {
|
||||
kernel: rhs,
|
||||
..
|
||||
}
|
||||
| Op::ConvTranspose1D {
|
||||
arg: lhs,
|
||||
kernel: rhs,
|
||||
..
|
||||
}
|
||||
| Op::Conv2D {
|
||||
arg: lhs,
|
||||
kernel: rhs,
|
||||
@ -150,10 +166,16 @@ impl Tensor {
|
||||
if node.is_variable() {
|
||||
continue;
|
||||
}
|
||||
let grad = grads.remove(node).unwrap();
|
||||
// TODO: We should perform all these operations in place (or at least not track the
|
||||
// whole graph). The only drawback would be if we wanted to support grad of grad but
|
||||
// this is out of scope.
|
||||
let grad = grads
|
||||
.remove(node)
|
||||
.expect("candle internal error - grad not populated");
|
||||
// https://github.com/huggingface/candle/issues/1241
|
||||
// Ideally, we would make these operations in place where possible to ensure that we
|
||||
// do not have to allocate too often. Here we just call `.detach` to avoid computing
|
||||
// the backprop graph of the backprop itself. This would be an issue for second order
|
||||
// derivatives but these are out of scope at the moment.
|
||||
let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
|
||||
let grad = if do_not_detach { grad } else { grad.detach()? };
|
||||
if let Some(op) = node.op() {
|
||||
match op {
|
||||
Op::Binary(lhs, rhs, BinaryOp::Add) => {
|
||||
@ -208,7 +230,44 @@ impl Tensor {
|
||||
let f_grad = pred.where_cond(&zeros, &grad)?;
|
||||
*f_sum_grad = f_sum_grad.add(&f_grad)?;
|
||||
}
|
||||
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
|
||||
Op::Conv1D {
|
||||
arg,
|
||||
kernel,
|
||||
padding,
|
||||
stride,
|
||||
dilation,
|
||||
} => {
|
||||
// The output height for conv_transpose1d is:
|
||||
// (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1
|
||||
let grad_l_in = grad.dim(2)?;
|
||||
let k_size = kernel.dim(2)?;
|
||||
let out_size =
|
||||
(grad_l_in - 1) * stride + dilation * (k_size - 1) + 1 - 2 * padding;
|
||||
let out_padding = arg.dim(2)? - out_size;
|
||||
let grad_arg = grad.conv_transpose1d(
|
||||
kernel,
|
||||
*padding,
|
||||
out_padding,
|
||||
*stride,
|
||||
*dilation,
|
||||
)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&grad_arg)?;
|
||||
|
||||
let grad_kernel = arg
|
||||
.transpose(0, 1)?
|
||||
.conv1d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
|
||||
.transpose(0, 1)?;
|
||||
let sum_grad = grads.or_insert(kernel)?;
|
||||
let (_, _, k0) = kernel.dims3()?;
|
||||
let (_, _, g_k0) = grad_kernel.dims3()?;
|
||||
let grad_kernel = if g_k0 != k0 {
|
||||
grad_kernel.narrow(2, 0, k0)?
|
||||
} else {
|
||||
grad_kernel
|
||||
};
|
||||
*sum_grad = sum_grad.add(&grad_kernel)?;
|
||||
}
|
||||
Op::Conv2D {
|
||||
arg,
|
||||
kernel,
|
||||
@ -247,6 +306,9 @@ impl Tensor {
|
||||
};
|
||||
*sum_grad = sum_grad.add(&grad_kernel)?;
|
||||
}
|
||||
Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported {
|
||||
op: "conv-transpose1d",
|
||||
})?,
|
||||
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
|
||||
op: "conv-transpose2d",
|
||||
})?,
|
||||
@ -487,16 +549,38 @@ impl Tensor {
|
||||
+ 0.5)?;
|
||||
*sum_grad = sum_grad.add(&(&grad * gelu_grad)?)?
|
||||
}
|
||||
Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
|
||||
Op::Unary(_, UnaryOp::GeluErf) => {
|
||||
Err(Error::BackwardNotSupported { op: "gelu-erf" })?
|
||||
Op::Unary(arg, UnaryOp::Erf) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
// d/dx erf(x) = 2/sqrt(pi) * e^(-x^2)
|
||||
let erf_grad =
|
||||
(2. / std::f64::consts::PI.sqrt()) * (arg.sqr()?.neg()?).exp()?;
|
||||
*sum_grad = sum_grad.add(&(&grad * erf_grad)?)?
|
||||
}
|
||||
Op::Unary(arg, UnaryOp::GeluErf) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
// d/dx gelu_erf(x) = 0.5 + 0.398942 e^(-x^2/2) x + 0.5 erf(x/sqrt(2))
|
||||
let neg_half_square = (arg.sqr()?.neg()? / 2.)?;
|
||||
let scaled_exp_arg = (0.398942 * neg_half_square.exp()? * arg)?;
|
||||
let arg_scaled_sqrt = (arg / 2f64.sqrt())?;
|
||||
let erf_scaled_sqrt = (0.5 * arg_scaled_sqrt.erf()?)?;
|
||||
let gelu_erf_grad = (0.5 + scaled_exp_arg + erf_scaled_sqrt)?;
|
||||
*sum_grad = sum_grad.add(&(&grad * gelu_erf_grad)?)?;
|
||||
}
|
||||
Op::Unary(arg, UnaryOp::Relu) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
|
||||
*sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
|
||||
}
|
||||
Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
|
||||
Op::Elu(arg, alpha) => {
|
||||
// d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
let zeros = arg.zeros_like()?;
|
||||
let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
|
||||
let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
|
||||
let negative_exp_mask = ((negative_mask * arg.exp())? * *alpha)?;
|
||||
let combined_mask = (positive_mask + negative_exp_mask)?;
|
||||
*sum_grad = sum_grad.add(&(grad * combined_mask)?)?
|
||||
}
|
||||
Op::Powf(arg, e) => {
|
||||
let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
|
@ -25,6 +25,33 @@ impl ParamsConv1D {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ParamsConvTranspose1D {
|
||||
pub(crate) b_size: usize,
|
||||
pub(crate) l_in: usize,
|
||||
pub(crate) c_out: usize,
|
||||
pub(crate) c_in: usize,
|
||||
pub(crate) k_size: usize,
|
||||
pub(crate) padding: usize,
|
||||
pub(crate) output_padding: usize,
|
||||
pub(crate) stride: usize,
|
||||
pub(crate) dilation: usize,
|
||||
}
|
||||
|
||||
impl ParamsConvTranspose1D {
|
||||
pub(crate) fn l_out(&self) -> usize {
|
||||
(self.l_in - 1) * self.stride - 2 * self.padding
|
||||
+ self.dilation * (self.k_size - 1)
|
||||
+ self.output_padding
|
||||
+ 1
|
||||
}
|
||||
|
||||
pub(crate) fn out_dims(&self) -> Vec<usize> {
|
||||
let l_out = self.l_out();
|
||||
vec![self.b_size, self.c_out, l_out]
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum CudnnFwdAlgo {
|
||||
ImplicitGemm,
|
||||
@ -160,6 +187,49 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Applies a 1D transposed convolution over the input tensor.
|
||||
pub fn conv_transpose1d(
|
||||
&self,
|
||||
kernel: &Self,
|
||||
padding: usize,
|
||||
output_padding: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
) -> Result<Self> {
|
||||
let (b_size, c_in, l_in) = self.dims3()?;
|
||||
let (c_in_k, c_out, k_size) = kernel.dims3()?;
|
||||
if c_in != c_in_k {
|
||||
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
|
||||
}
|
||||
let params = ParamsConvTranspose1D {
|
||||
b_size,
|
||||
l_in,
|
||||
k_size,
|
||||
c_out,
|
||||
c_in,
|
||||
padding,
|
||||
output_padding,
|
||||
stride,
|
||||
dilation,
|
||||
};
|
||||
let storage = self.storage().conv_transpose1d(
|
||||
self.layout(),
|
||||
&kernel.storage(),
|
||||
kernel.layout(),
|
||||
&params,
|
||||
)?;
|
||||
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D {
|
||||
arg,
|
||||
kernel,
|
||||
padding: params.padding,
|
||||
output_padding: params.output_padding,
|
||||
stride: params.stride,
|
||||
dilation: params.dilation,
|
||||
});
|
||||
let out_dims = params.out_dims();
|
||||
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
|
||||
}
|
||||
|
||||
fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
|
||||
let storage =
|
||||
self.storage()
|
||||
|
@ -1256,6 +1256,74 @@ impl Map1 for Im2Col {
|
||||
}
|
||||
}
|
||||
|
||||
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
|
||||
|
||||
impl<'a> Map2 for ConvTranspose1D<'a> {
|
||||
const OP: &'static str = "conv_transpose1d";
|
||||
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
|
||||
let p = self.0;
|
||||
let inp = &inp[inp_l.start_offset()..];
|
||||
let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
|
||||
let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
|
||||
let l_out = p.l_out();
|
||||
|
||||
// Output shape: [b_size, c_out, l_out].
|
||||
let dst_elems = p.c_out * l_out * p.b_size;
|
||||
let dst = vec![T::zero(); dst_elems];
|
||||
let dst_s0 = p.c_out * l_out;
|
||||
let dst_s1 = l_out;
|
||||
let dst_s2 = 1;
|
||||
|
||||
// TODO: Avoid making this copy if `inp` already has the appropriate layout.
|
||||
let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
|
||||
let cont_s0 = p.l_in * p.c_in;
|
||||
let cont_s1 = p.c_in;
|
||||
for b_idx in 0..p.b_size {
|
||||
for l_idx in 0..p.l_in {
|
||||
for c_idx in 0..p.c_in {
|
||||
let src_idx = b_idx * inp_s0 + c_idx * inp_s1 + l_idx * inp_s2;
|
||||
let dst_idx = b_idx * cont_s0 + l_idx * cont_s1 + c_idx;
|
||||
inp_cont[dst_idx] = inp[src_idx]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for k_idx in 0..p.k_size {
|
||||
(0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
|
||||
let k_cont = (0..p.c_in)
|
||||
.map(|c_in_idx| k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_idx * k_s2])
|
||||
.collect::<Vec<_>>();
|
||||
for b_idx in 0..p.b_size {
|
||||
for l_idx in 0..p.l_in {
|
||||
let out_idx = l_idx * p.stride + k_idx * p.dilation;
|
||||
if out_idx < p.padding {
|
||||
continue;
|
||||
}
|
||||
let out_idx = out_idx - p.padding;
|
||||
if out_idx < l_out {
|
||||
let inp_cont = &inp_cont[b_idx * cont_s0 + l_idx * cont_s1..];
|
||||
let dst_idx = b_idx * dst_s0 + out_idx * dst_s2 + dst_c_idx * dst_s1;
|
||||
let mut d = T::zero();
|
||||
unsafe {
|
||||
T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
|
||||
}
|
||||
let dst_p = dst.as_ptr();
|
||||
// Safety: dst_idx are uniques per dst_c_idx which is used to
|
||||
// parallelise the different tasks so no two threads can try to
|
||||
// write at the same location.
|
||||
unsafe {
|
||||
let ptr = dst_p.add(dst_idx) as *mut T;
|
||||
*ptr += d
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
|
||||
struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
|
||||
|
||||
impl<'a> Map2 for Conv2D<'a> {
|
||||
@ -2435,6 +2503,16 @@ impl BackendStorage for CpuStorage {
|
||||
Ok(res_t)
|
||||
}
|
||||
|
||||
fn conv_transpose1d(
|
||||
&self,
|
||||
l: &Layout,
|
||||
kernel: &Self,
|
||||
kernel_l: &Layout,
|
||||
params: &crate::conv::ParamsConvTranspose1D,
|
||||
) -> Result<Self> {
|
||||
ConvTranspose1D(params).map(self, l, kernel, kernel_l)
|
||||
}
|
||||
|
||||
fn conv2d(
|
||||
&self,
|
||||
l: &Layout,
|
||||
|
@ -1808,6 +1808,16 @@ impl BackendStorage for CudaStorage {
|
||||
Ok(res_t)
|
||||
}
|
||||
|
||||
fn conv_transpose1d(
|
||||
&self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: &crate::conv::ParamsConvTranspose1D,
|
||||
) -> Result<Self> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cudnn"))]
|
||||
fn conv2d(
|
||||
&self,
|
||||
|
@ -1,6 +1,6 @@
|
||||
use crate::backend::BackendDevice;
|
||||
use crate::cpu_backend::CpuDevice;
|
||||
use crate::{bail, CpuStorage, DType, Result, Shape, Storage, WithDType};
|
||||
use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
|
||||
|
||||
/// A `DeviceLocation` represents a physical device whereas multiple `Device`
|
||||
/// can live on the same location (typically for cuda devices).
|
||||
@ -8,7 +8,7 @@ use crate::{bail, CpuStorage, DType, Result, Shape, Storage, WithDType};
|
||||
pub enum DeviceLocation {
|
||||
Cpu,
|
||||
Cuda { gpu_id: usize },
|
||||
Metal,
|
||||
Metal { gpu_id: usize },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -105,14 +105,14 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4:
|
||||
impl<S: NdArray> NdArray for Vec<S> {
|
||||
fn shape(&self) -> Result<Shape> {
|
||||
if self.is_empty() {
|
||||
bail!("empty array")
|
||||
crate::bail!("empty array")
|
||||
}
|
||||
let shape0 = self[0].shape()?;
|
||||
let n = self.len();
|
||||
for v in self.iter() {
|
||||
let shape = v.shape()?;
|
||||
if shape != shape0 {
|
||||
bail!("two elements have different shapes {shape:?} {shape0:?}")
|
||||
crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
|
||||
}
|
||||
}
|
||||
Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
|
||||
@ -146,6 +146,7 @@ impl Device {
|
||||
match (self, rhs) {
|
||||
(Self::Cpu, Self::Cpu) => true,
|
||||
(Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
|
||||
(Self::Metal(lhs), Self::Metal(rhs)) => lhs.same_device(rhs),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
@ -166,6 +167,10 @@ impl Device {
|
||||
matches!(self, Self::Cuda(_))
|
||||
}
|
||||
|
||||
pub fn is_metal(&self) -> bool {
|
||||
matches!(self, Self::Metal(_))
|
||||
}
|
||||
|
||||
pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
|
||||
if crate::utils::cuda_is_available() {
|
||||
Self::new_cuda(ordinal)
|
||||
@ -187,13 +192,19 @@ impl Device {
|
||||
Ok(Storage::Cpu(storage))
|
||||
}
|
||||
Device::Cuda(device) => {
|
||||
// TODO: Remove the special case if we start supporting generating f16/bf16 directly.
|
||||
if dtype == DType::F16 || dtype == DType::BF16 {
|
||||
let storage = device.rand_uniform(shape, DType::F32, lo, up)?;
|
||||
Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
|
||||
} else {
|
||||
let storage = device.rand_uniform(shape, dtype, lo, up)?;
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
}
|
||||
Device::Metal(_device) => {
|
||||
// let storage = device.rand_uniform(shape, dtype, lo, up)?;
|
||||
// Ok(Storage::Metal(storage))
|
||||
bail!("Metal rand_uniform not implemented")
|
||||
crate::bail!("Metal rand_uniform not implemented")
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -220,9 +231,15 @@ impl Device {
|
||||
Ok(Storage::Cpu(storage))
|
||||
}
|
||||
Device::Cuda(device) => {
|
||||
// TODO: Remove the special case if we start supporting generating f16/bf16 directly.
|
||||
if dtype == DType::F16 || dtype == DType::BF16 {
|
||||
let storage = device.rand_normal(shape, DType::F32, mean, std)?;
|
||||
Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
|
||||
} else {
|
||||
let storage = device.rand_normal(shape, dtype, mean, std)?;
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
}
|
||||
Device::Metal(device) => {
|
||||
let storage = device.rand_normal(shape, dtype, mean, std)?;
|
||||
Ok(Storage::Metal(storage))
|
||||
|
@ -14,7 +14,9 @@ impl Tensor {
|
||||
crate::DeviceLocation::Cuda { gpu_id } => {
|
||||
format!(", cuda:{}", gpu_id)
|
||||
}
|
||||
_ => todo!(),
|
||||
crate::DeviceLocation::Metal { gpu_id } => {
|
||||
format!(", metal:{}", gpu_id)
|
||||
}
|
||||
};
|
||||
|
||||
write!(f, "Tensor[")?;
|
||||
@ -477,7 +479,9 @@ impl std::fmt::Display for Tensor {
|
||||
crate::DeviceLocation::Cuda { gpu_id } => {
|
||||
format!(", cuda:{}", gpu_id)
|
||||
}
|
||||
crate::DeviceLocation::Metal => todo!(),
|
||||
crate::DeviceLocation::Metal { gpu_id } => {
|
||||
format!(", metal:{}", gpu_id)
|
||||
}
|
||||
};
|
||||
|
||||
write!(
|
||||
|
@ -79,6 +79,16 @@ impl crate::backend::BackendStorage for CudaStorage {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn conv_transpose1d(
|
||||
&self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: &crate::conv::ParamsConvTranspose1D,
|
||||
) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn conv2d(
|
||||
&self,
|
||||
_: &Layout,
|
||||
|
@ -8,6 +8,18 @@ pub struct MetalDevice;
|
||||
#[derive(Debug)]
|
||||
pub struct MetalStorage;
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum MetalError {
|
||||
#[error("{0}")]
|
||||
Message(String),
|
||||
}
|
||||
|
||||
impl From<String> for MetalError {
|
||||
fn from(e: String) -> Self {
|
||||
MetalError::Message(e)
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! fail {
|
||||
() => {
|
||||
unimplemented!("metal support has not been enabled, add `metal` feature to enable.")
|
||||
@ -79,6 +91,16 @@ impl crate::backend::BackendStorage for MetalStorage {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn conv_transpose1d(
|
||||
&self,
|
||||
_l: &Layout,
|
||||
_kernel: &Self,
|
||||
_kernel_l: &Layout,
|
||||
_params: &crate::conv::ParamsConvTranspose1D,
|
||||
) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn conv2d(
|
||||
&self,
|
||||
_: &Layout,
|
||||
|
@ -1,4 +1,4 @@
|
||||
use crate::{DType, DeviceLocation, Layout, Shape};
|
||||
use crate::{DType, DeviceLocation, Layout, MetalError, Shape};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MatMulUnexpectedStriding {
|
||||
@ -163,7 +163,7 @@ pub enum Error {
|
||||
Cuda(Box<dyn std::error::Error + Send + Sync>),
|
||||
|
||||
#[error("Metal error {0}")]
|
||||
Metal(String),
|
||||
Metal(#[from] MetalError),
|
||||
|
||||
#[error(transparent)]
|
||||
TryFromIntError(#[from] core::num::TryFromIntError),
|
||||
|
@ -49,13 +49,12 @@ mod device;
|
||||
pub mod display;
|
||||
mod dtype;
|
||||
mod dummy_cuda_backend;
|
||||
mod dummy_metal_backend;
|
||||
pub mod error;
|
||||
mod indexer;
|
||||
pub mod layout;
|
||||
#[cfg(feature = "metal")]
|
||||
pub mod metal_backend;
|
||||
#[cfg(feature = "accelerate")]
|
||||
mod metal_backend;
|
||||
#[cfg(feature = "mkl")]
|
||||
mod mkl;
|
||||
pub mod npy;
|
||||
@ -92,10 +91,10 @@ pub use cuda_backend::{CudaDevice, CudaStorage};
|
||||
pub use dummy_cuda_backend::{CudaDevice, CudaStorage};
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
pub use metal_backend::{MetalDevice, MetalStorage};
|
||||
pub use metal_backend::{MetalDevice, MetalError, MetalStorage};
|
||||
|
||||
#[cfg(not(feature = "metal"))]
|
||||
pub use dummy_metal_backend::{MetalDevice, MetalStorage};
|
||||
pub use dummy_metal_backend::{MetalDevice, MetalError, MetalStorage};
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
@ -1,28 +1,44 @@
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::bail;
|
||||
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose2D};
|
||||
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
|
||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
||||
use candle_metal_kernels;
|
||||
use candle_metal_kernels::Kernels;
|
||||
use core::mem;
|
||||
use half::{bf16, f16};
|
||||
use metal;
|
||||
use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
|
||||
use metal::mps::{Float32, MPSDataType};
|
||||
use metal::MTLResourceOptions;
|
||||
use metal::mps::matrix::encode_gemm;
|
||||
use metal::mps::Float32;
|
||||
use metal::{Buffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Metal related errors
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum MetalError {
|
||||
#[error("metal error")]
|
||||
Metal,
|
||||
#[error("{0}")]
|
||||
Message(String),
|
||||
#[error(transparent)]
|
||||
KernelError(#[from] candle_metal_kernels::MetalKernelError),
|
||||
|
||||
#[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
|
||||
MatMulNonContiguous {
|
||||
lhs_stride: Vec<usize>,
|
||||
rhs_stride: Vec<usize>,
|
||||
mnk: (usize, usize, usize),
|
||||
},
|
||||
}
|
||||
|
||||
impl From<String> for MetalError {
|
||||
fn from(e: String) -> Self {
|
||||
MetalError::Message(e)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct MetalDevice {
|
||||
device: metal::Device,
|
||||
_command_queue: metal::CommandQueue,
|
||||
command_buffer: metal::CommandBuffer,
|
||||
command_queue: metal::CommandQueue,
|
||||
kernels: Arc<candle_metal_kernels::Kernels>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for MetalDevice {
|
||||
@ -40,12 +56,31 @@ impl std::ops::Deref for MetalDevice {
|
||||
}
|
||||
|
||||
impl MetalDevice {
|
||||
pub fn metal_device(&self) -> &metal::DeviceRef {
|
||||
self.device.as_ref()
|
||||
// pub fn metal_device(&self) -> &metal::DeviceRef {
|
||||
// self.device.as_ref()
|
||||
// }
|
||||
|
||||
pub fn id(&self) -> NSUInteger {
|
||||
self.registry_id()
|
||||
}
|
||||
|
||||
pub fn id(&self) -> u64 {
|
||||
self.registry_id()
|
||||
pub fn command_queue(&self) -> &CommandQueue {
|
||||
&self.command_queue
|
||||
}
|
||||
|
||||
pub fn kernels(&self) -> &Kernels {
|
||||
&self.kernels
|
||||
}
|
||||
|
||||
pub fn device(&self) -> &metal::Device {
|
||||
&self.device
|
||||
}
|
||||
|
||||
pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
|
||||
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
|
||||
// debug!("Allocate 1 - buffer size {size}");
|
||||
self.device
|
||||
.new_buffer(size, MTLResourceOptions::StorageModeManaged)
|
||||
}
|
||||
}
|
||||
|
||||
@ -72,20 +107,63 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||
match self.dtype{
|
||||
DType::F32 => {
|
||||
// self.buffer.read_to_vec(self.buffer.length() as usize / 4);
|
||||
let mut buffer = vec![0.0; 32000];
|
||||
buffer[0] = 1.0;
|
||||
Ok(CpuStorage::F32(buffer))},
|
||||
dtype => todo!("Unsupported dtype {dtype:?}")
|
||||
// TODO Is this necessary
|
||||
// self.buffer.synchronize();
|
||||
match self.dtype {
|
||||
DType::U8 => Ok(CpuStorage::U8(
|
||||
self.buffer.read_to_vec(self.buffer.length() as usize / 1),
|
||||
)),
|
||||
DType::U32 => Ok(CpuStorage::U32(
|
||||
self.buffer.read_to_vec(self.buffer.length() as usize / 4),
|
||||
)),
|
||||
DType::I64 => Ok(CpuStorage::I64(
|
||||
self.buffer.read_to_vec(self.buffer.length() as usize / 8),
|
||||
)),
|
||||
DType::F16 => Ok(CpuStorage::F16(
|
||||
self.buffer.read_to_vec(self.buffer.length() as usize / 2),
|
||||
)),
|
||||
DType::BF16 => Ok(CpuStorage::BF16(
|
||||
self.buffer.read_to_vec(self.buffer.length() as usize / 2),
|
||||
)),
|
||||
DType::F32 => Ok(CpuStorage::F32(
|
||||
self.buffer.read_to_vec(self.buffer.length() as usize / 4),
|
||||
)),
|
||||
DType::F64 => Ok(CpuStorage::F64(
|
||||
self.buffer.read_to_vec(self.buffer.length() as usize / 8),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
|
||||
println!("TODO Affine");
|
||||
Ok(self.clone())
|
||||
// todo!()
|
||||
fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
|
||||
let shape = layout.shape();
|
||||
let el = shape.elem_count();
|
||||
let dtype = self.dtype;
|
||||
|
||||
assert!(layout.is_contiguous());
|
||||
assert_eq!(dtype, DType::F32);
|
||||
|
||||
let mut buffer = device.new_buffer(el, self.dtype);
|
||||
let command_buffer = self.device.command_queue.new_command_buffer();
|
||||
candle_metal_kernels::call_affine(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
el,
|
||||
&self.buffer,
|
||||
&mut buffer,
|
||||
mul as f32,
|
||||
add as f32,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
return Ok(Self {
|
||||
buffer,
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
});
|
||||
}
|
||||
|
||||
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
|
||||
@ -96,10 +174,66 @@ buffer[0] = 1.0;
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
|
||||
println!("TODO reduce_op");
|
||||
Ok(self.clone())
|
||||
// todo!()
|
||||
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
|
||||
// debug!("TODO reduce_op {op:?} {sum_dims:?}");
|
||||
assert!(sum_dims.len() == 1);
|
||||
assert!(sum_dims[0] == layout.shape().rank() - 1);
|
||||
assert!(layout.is_contiguous());
|
||||
let device = self.device.clone();
|
||||
let src_stride = layout.stride();
|
||||
let src_dims = layout.shape().dims();
|
||||
let src_el: usize = src_dims.iter().product();
|
||||
// Source dims and strides with the sum dims at the end.
|
||||
let mut dims = vec![];
|
||||
let mut stride = vec![];
|
||||
let mut dst_el: usize = 1;
|
||||
for (dim_idx, &d) in src_dims.iter().enumerate() {
|
||||
if !sum_dims.contains(&dim_idx) {
|
||||
dst_el *= d;
|
||||
dims.push(d);
|
||||
stride.push(src_stride[dim_idx]);
|
||||
}
|
||||
}
|
||||
for &dim_idx in sum_dims.iter() {
|
||||
dims.push(src_dims[dim_idx]);
|
||||
stride.push(src_stride[dim_idx]);
|
||||
}
|
||||
|
||||
// The reduction loop requires the shared array to be properly initialized and for
|
||||
// this we want the number of threads to be a power of two.
|
||||
let (name, check_empty, return_index) = match (op, self.dtype) {
|
||||
(ReduceOp::Sum, DType::F32) => ("fast_sum_float", false, false),
|
||||
(ReduceOp::Min, DType::F32) => ("fast_min_float", true, false),
|
||||
(ReduceOp::Max, DType::F32) => ("fast_max_float", true, false),
|
||||
(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_float", true, true),
|
||||
(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_float", true, true),
|
||||
_ => todo!("Reduce op for non float"),
|
||||
};
|
||||
if check_empty && layout.shape().elem_count() == 0 {
|
||||
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
|
||||
}
|
||||
let dtype = if return_index { DType::U32 } else { self.dtype };
|
||||
let mut buffer = device.new_buffer(dst_el, dtype);
|
||||
let command_buffer = self.device.command_queue.new_command_buffer();
|
||||
candle_metal_kernels::call_reduce_contiguous(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
name,
|
||||
src_el,
|
||||
dst_el,
|
||||
&self.buffer,
|
||||
&mut buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device,
|
||||
dtype,
|
||||
})
|
||||
}
|
||||
|
||||
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
|
||||
@ -107,26 +241,202 @@ buffer[0] = 1.0;
|
||||
}
|
||||
|
||||
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
|
||||
todo!("Implement {:?} {layout:?} - {dtype:?}", self.dtype)
|
||||
let device = self.device();
|
||||
let shape = layout.shape();
|
||||
let el_count = shape.elem_count();
|
||||
let mut buffer = device.new_buffer(el_count, dtype);
|
||||
let command_buffer = device.command_queue.new_command_buffer();
|
||||
if layout.is_contiguous() {
|
||||
let kernel_name = match (self.dtype, dtype) {
|
||||
(DType::U32, DType::F32) => "cast_u32_f32",
|
||||
(left, right) => todo!("to dtype {left:?} - {right:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_cast_contiguous(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
el_count,
|
||||
&self.buffer,
|
||||
&mut buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
} else {
|
||||
todo!(
|
||||
"TODO Implement the kernel calling cast {:?}-{:?}",
|
||||
self.dtype,
|
||||
dtype
|
||||
);
|
||||
}
|
||||
|
||||
fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
|
||||
// todo!()
|
||||
// TODO
|
||||
println!("TODO {:?}", B::NAME);
|
||||
Ok(self.clone())
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
// command_buffer.wait_until_scheduled();
|
||||
// debug!(
|
||||
// "cast {:?} - {:?} - {:?}",
|
||||
// dtype,
|
||||
// self.buffer.length(),
|
||||
// buffer.length()
|
||||
// );
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
})
|
||||
}
|
||||
|
||||
fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
|
||||
println!("TODO Binary {:?}", B::NAME);
|
||||
Ok(self.clone())
|
||||
// todo!()
|
||||
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
|
||||
let device = self.device();
|
||||
let dtype = self.dtype;
|
||||
let shape = layout.shape();
|
||||
let el_count = shape.elem_count();
|
||||
let mut buffer = device.new_buffer(el_count, dtype);
|
||||
let command_buffer = device.command_queue.new_command_buffer();
|
||||
if layout.is_contiguous() {
|
||||
use candle_metal_kernels::unary::contiguous;
|
||||
|
||||
let kernel_name = match (B::KERNEL, dtype) {
|
||||
("ucos", DType::F32) => contiguous::cos::FLOAT,
|
||||
("usin", DType::F32) => contiguous::sin::FLOAT,
|
||||
("usqr", DType::F32) => contiguous::sqr::FLOAT,
|
||||
("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
|
||||
("uneg", DType::F32) => contiguous::neg::FLOAT,
|
||||
("uexp", DType::F32) => contiguous::exp::FLOAT,
|
||||
(name, dtype) => todo!("Match {name} - {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_unary_contiguous(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
el_count,
|
||||
&self.buffer,
|
||||
&mut buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
} else {
|
||||
todo!("TODO Implement the kernel calling {}", B::KERNEL);
|
||||
}
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
})
|
||||
}
|
||||
|
||||
fn where_cond(&self, _: &Layout, rhs: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
|
||||
println!("TODO where_cond");
|
||||
Ok(rhs.clone())
|
||||
// todo!()
|
||||
fn binary_impl<B: BinaryOpT>(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
lhs_l: &Layout,
|
||||
rhs_l: &Layout,
|
||||
) -> Result<Self> {
|
||||
let device = self.device();
|
||||
let dtype = self.dtype;
|
||||
let shape = lhs_l.shape();
|
||||
let el_count = shape.elem_count();
|
||||
let mut buffer = device.new_buffer(el_count, dtype);
|
||||
let command_buffer = device.command_queue.new_command_buffer();
|
||||
if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
|
||||
use candle_metal_kernels::binary::contiguous;
|
||||
|
||||
let kernel_name = match (B::KERNEL, dtype) {
|
||||
("add", DType::F32) => contiguous::add::FLOAT,
|
||||
("badd", DType::F32) => contiguous::add::FLOAT,
|
||||
("sub", DType::F32) => contiguous::sub::FLOAT,
|
||||
("bsub", DType::F32) => contiguous::sub::FLOAT,
|
||||
("mul", DType::F32) => contiguous::mul::FLOAT,
|
||||
("bmul", DType::F32) => contiguous::mul::FLOAT,
|
||||
("div", DType::F32) => contiguous::div::FLOAT,
|
||||
("bdiv", DType::F32) => contiguous::div::FLOAT,
|
||||
(name, dtype) => todo!("Match {name} - {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_binary_contiguous(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
el_count,
|
||||
&self.buffer,
|
||||
&rhs.buffer,
|
||||
&mut buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
} else {
|
||||
use candle_metal_kernels::binary::strided;
|
||||
|
||||
let kernel_name = match (B::KERNEL, dtype) {
|
||||
("badd", DType::F32) => strided::add::FLOAT,
|
||||
("bsub", DType::F32) => strided::sub::FLOAT,
|
||||
("bmul", DType::F32) => strided::mul::FLOAT,
|
||||
("bdiv", DType::F32) => strided::div::FLOAT,
|
||||
(name, dtype) => todo!("Match {name} - {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_binary_strided(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
lhs_l.dims(),
|
||||
&self.buffer,
|
||||
&lhs_l.stride(),
|
||||
lhs_l.start_offset(),
|
||||
&rhs.buffer,
|
||||
&rhs_l.stride(),
|
||||
rhs_l.start_offset(),
|
||||
&mut buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
})
|
||||
}
|
||||
|
||||
fn where_cond(
|
||||
&self,
|
||||
layout: &Layout,
|
||||
t: &Self,
|
||||
t_l: &Layout,
|
||||
f: &Self,
|
||||
f_l: &Layout,
|
||||
) -> Result<Self> {
|
||||
let device = self.device.clone();
|
||||
let shape = t_l.shape();
|
||||
let dims = shape.dims();
|
||||
let el = shape.elem_count();
|
||||
let dtype = t.dtype;
|
||||
let mut buffer = self.device.new_buffer(el, dtype);
|
||||
let command_buffer = self.device.command_queue.new_command_buffer();
|
||||
candle_metal_kernels::call_where_cond_strided(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
"where_u8_f32",
|
||||
&dims,
|
||||
&self.buffer,
|
||||
(layout.stride(), layout.start_offset()),
|
||||
&t.buffer,
|
||||
(&t_l.stride(), t_l.start_offset()),
|
||||
&f.buffer,
|
||||
(&f_l.stride(), f_l.start_offset()),
|
||||
&mut buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device,
|
||||
dtype,
|
||||
})
|
||||
}
|
||||
|
||||
fn conv1d(
|
||||
@ -139,6 +449,16 @@ buffer[0] = 1.0;
|
||||
todo!()
|
||||
}
|
||||
|
||||
/// Transposed 1D convolution.
///
/// Not implemented yet: every argument is ignored and the call panics via `todo!()`.
fn conv_transpose1d(
|
||||
&self,
|
||||
_l: &Layout,
|
||||
_kernel: &Self,
|
||||
_kernel_l: &Layout,
|
||||
_params: &ParamsConvTranspose1D,
|
||||
) -> Result<Self> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn conv2d(
|
||||
&self,
|
||||
_l: &Layout,
|
||||
@ -191,10 +511,43 @@ buffer[0] = 1.0;
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
|
||||
println!("TODO Index select");
|
||||
Ok(self.clone())
|
||||
// todo!()
|
||||
fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||
assert!(src_l.is_contiguous());
|
||||
assert!(ids_l.is_contiguous());
|
||||
let left_size: usize = src_l.dims()[..dim].iter().product();
|
||||
let right_size: usize = src_l.dims()[dim + 1..].iter().product();
|
||||
let ids_el = ids_l.shape().elem_count();
|
||||
let dst_el = ids_el * left_size * right_size;
|
||||
let dtype = self.dtype;
|
||||
let device = self.device();
|
||||
let mut buffer = device.new_buffer(dst_el, dtype);
|
||||
let out = self.to_cpu_storage().unwrap();
|
||||
let name = match (ids.dtype, self.dtype) {
|
||||
(DType::U32, DType::F32) => "is_u32_f32",
|
||||
(left, right) => todo!("index select metal {left:?} {right:?}"),
|
||||
};
|
||||
let command_buffer = self.device.command_queue.new_command_buffer();
|
||||
// println!("INDEX SELECT");
|
||||
candle_metal_kernels::call_index_select(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
name,
|
||||
src_l.dims(),
|
||||
ids_el,
|
||||
dim,
|
||||
&self.buffer,
|
||||
&ids.buffer,
|
||||
&mut buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
})
|
||||
}
|
||||
|
||||
fn index_add(
|
||||
@ -216,122 +569,79 @@ buffer[0] = 1.0;
|
||||
lhs_l: &Layout,
|
||||
rhs_l: &Layout,
|
||||
) -> Result<Self> {
|
||||
let transpose_left = false;
|
||||
let transpose_right = false;
|
||||
let alpha = 1.0;
|
||||
let beta = 0.0;
|
||||
self.matmul_generic(
|
||||
rhs,
|
||||
(b, m, n, k),
|
||||
lhs_l,
|
||||
rhs_l,
|
||||
transpose_left,
|
||||
transpose_right,
|
||||
alpha,
|
||||
beta,
|
||||
)
|
||||
}
|
||||
|
||||
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
|
||||
println!("TODO Copy strided");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl MetalStorage {
|
||||
pub(crate) fn matmul_t(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
lhs_l: &Layout,
|
||||
rhs_l: &Layout,
|
||||
) -> Result<Self> {
|
||||
let transpose_left = false;
|
||||
let transpose_right = true;
|
||||
let alpha = 1.0;
|
||||
let beta = 0.0;
|
||||
self.matmul_generic(
|
||||
rhs,
|
||||
(b, m, n, k),
|
||||
lhs_l,
|
||||
rhs_l,
|
||||
transpose_left,
|
||||
transpose_right,
|
||||
alpha,
|
||||
beta,
|
||||
)
|
||||
}
|
||||
pub(crate) fn matmul_generic(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
lhs_l: &Layout,
|
||||
rhs_l: &Layout,
|
||||
transpose_left: bool,
|
||||
transpose_right: bool,
|
||||
alpha: f64,
|
||||
beta: f64,
|
||||
) -> Result<Self> {
|
||||
let elem_count = b * m * n;
|
||||
match (self.dtype, rhs.dtype) {
|
||||
(DType::F32, DType::F32) => {
|
||||
let span= tracing::span!(tracing::Level::TRACE, "metal alloc matmul");
|
||||
let _enter = span.enter();
|
||||
|
||||
let out_buffer = self.device.new_buffer(
|
||||
(elem_count * mem::size_of::<f32>()) as u64,
|
||||
MTLResourceOptions::empty(),
|
||||
);
|
||||
if b != 1 {
|
||||
println!("TODO implement batched matmul for B={b}");
|
||||
// bail!("Didn't implemented strided matmul yet");
|
||||
return Ok(Self {
|
||||
buffer: out_buffer,
|
||||
device: self.device.clone(),
|
||||
dtype: self.dtype(),
|
||||
});
|
||||
}
|
||||
if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() {
|
||||
println!("Didn't implemented non contiguous matmul yet {:?} {:?}", lhs_l.is_contiguous(), rhs_l.is_contiguous());
|
||||
return Ok(Self {
|
||||
buffer: out_buffer,
|
||||
device: self.device.clone(),
|
||||
dtype: self.dtype(),
|
||||
});
|
||||
}
|
||||
return Ok(Self {
|
||||
buffer: out_buffer,
|
||||
device: self.device.clone(),
|
||||
dtype: self.dtype(),
|
||||
});
|
||||
let m: u64 = m.try_into().expect("usize should fit u64");
|
||||
let n: u64 = n.try_into().expect("usize should fit u64");
|
||||
let k: u64 = k.try_into().expect("usize should fit u64");
|
||||
// Create descriptors
|
||||
let left_descriptor =
|
||||
MatrixDescriptor::init_single(m, k, k * Float32::SIZE, Float32::TYPE_ID);
|
||||
let right_descriptor =
|
||||
MatrixDescriptor::init_single(k, n, n * Float32::SIZE, Float32::TYPE_ID);
|
||||
let result_descriptor =
|
||||
MatrixDescriptor::init_single(m, n, n * Float32::SIZE, Float32::TYPE_ID);
|
||||
use metal::mps::matrix::*;
|
||||
let type_id = metal::mps::MPS_FLOATBIT_ENCODING | 32;
|
||||
let size = core::mem::size_of::<f32>() as NSUInteger;
|
||||
|
||||
let elem_count = b * m * n;
|
||||
|
||||
let lhs_stride = lhs_l.stride();
|
||||
let rhs_stride = rhs_l.stride();
|
||||
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
|
||||
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
|
||||
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
||||
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
||||
// The a tensor has dims batching, k, n (rhs)
|
||||
let transpose_left = if lhs_m1 == 1 && lhs_m2 == k {
|
||||
false
|
||||
} else if lhs_m1 == m && lhs_m2 == 1 {
|
||||
true
|
||||
} else {
|
||||
Err(MetalError::MatMulNonContiguous {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
rhs_stride: rhs_stride.to_vec(),
|
||||
mnk: (m, n, k),
|
||||
})?
|
||||
};
|
||||
let transpose_right = if rhs_m1 == 1 && rhs_m2 == n {
|
||||
false
|
||||
} else if rhs_m1 == k && rhs_m2 == 1 {
|
||||
true
|
||||
} else {
|
||||
Err(MetalError::MatMulNonContiguous {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
rhs_stride: rhs_stride.to_vec(),
|
||||
mnk: (m, n, k),
|
||||
})?
|
||||
};
|
||||
// println!("{transpose_left} {transpose_right}");
|
||||
|
||||
let b = b as NSUInteger;
|
||||
let m = m as NSUInteger;
|
||||
let n = n as NSUInteger;
|
||||
let k = k as NSUInteger;
|
||||
|
||||
let left_descriptor = if transpose_left {
|
||||
MatrixDescriptor::init_single(k, m, m * size, type_id)
|
||||
} else {
|
||||
MatrixDescriptor::init_single(m, k, k * size, type_id)
|
||||
};
|
||||
let right_descriptor = if transpose_right {
|
||||
MatrixDescriptor::init_single(n, k, k * size, type_id)
|
||||
} else {
|
||||
MatrixDescriptor::init_single(k, n, n * size, type_id)
|
||||
};
|
||||
let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id);
|
||||
|
||||
println!("lhs {:?} {m} {k}", self.buffer.length());
|
||||
println!("rhs {:?} {k} {n}", rhs.buffer.length());
|
||||
println!("out {:?} {m} {n}", out_buffer.length());
|
||||
// Create matrix objects
|
||||
let left_matrix =
|
||||
Matrix::init_with_buffer_descriptor(&self.buffer, &left_descriptor)
|
||||
.expect("Failed to create left matrix");
|
||||
let right_matrix =
|
||||
Matrix::init_with_buffer_descriptor(&rhs.buffer, &right_descriptor)
|
||||
.expect("Failed to create left matrix");
|
||||
let left_matrix = Matrix::init_with_buffer_descriptor(&self.buffer, &left_descriptor)
|
||||
.ok_or_else(|| {
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
let right_matrix = Matrix::init_with_buffer_descriptor(&rhs.buffer, &right_descriptor)
|
||||
.ok_or_else(|| {
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
|
||||
let result_matrix =
|
||||
Matrix::init_with_buffer_descriptor(&out_buffer, &result_descriptor)
|
||||
.expect("Failed to create left matrix");
|
||||
|
||||
println!("lhs {:?}", lhs_l.shape());
|
||||
let out_buffer = self.device.new_buffer(elem_count, self.dtype);
|
||||
let result_matrix = Matrix::init_with_buffer_descriptor(&out_buffer, &result_descriptor)
|
||||
.ok_or_else(|| {
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
|
||||
let alpha = 1.0f64;
|
||||
let beta = 0.0f64;
|
||||
// Create kernel
|
||||
let matrix_multiplication = MatrixMultiplication::init(
|
||||
&self.device,
|
||||
@ -343,33 +653,81 @@ impl MetalStorage {
|
||||
alpha,
|
||||
beta,
|
||||
)
|
||||
.expect("Failed to create matrix multiplication kernel");
|
||||
.ok_or_else(|| {
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
|
||||
matrix_multiplication.set_batch_size(b);
|
||||
|
||||
// Encode kernel to command buffer
|
||||
let command_buffer = self.device.command_queue.new_command_buffer();
|
||||
matrix_multiplication.encode_to_command_buffer(
|
||||
&self.device.command_buffer,
|
||||
command_buffer,
|
||||
&left_matrix,
|
||||
&right_matrix,
|
||||
&result_matrix,
|
||||
);
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
// let left = self.buffer.read_to_vec::<f32>(10);
|
||||
// let right = rhs.buffer.read_to_vec::<f32>(10);
|
||||
// let out = out_buffer.read_to_vec::<f32>(40);
|
||||
// todo!("Out {left:?} {right:?} {out:?}");
|
||||
|
||||
Ok(Self {
|
||||
buffer: out_buffer,
|
||||
device: self.device.clone(),
|
||||
dtype: self.dtype(),
|
||||
})
|
||||
}
|
||||
_ => todo!("Unimplemented matmul for this pair"),
|
||||
|
||||
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
|
||||
let src_shape = src_l.shape();
|
||||
let el_count = src_shape.elem_count();
|
||||
if el_count == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
let command_buffer = self.device.command_queue.new_command_buffer();
|
||||
let kernel_name = match self.dtype {
|
||||
DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
|
||||
DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
|
||||
DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
|
||||
dtype => todo!("copy_strided not implemented for {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_unary_strided(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
kernel_name,
|
||||
src_l.dims(),
|
||||
&self.buffer,
|
||||
&src_l.stride(),
|
||||
src_l.start_offset(),
|
||||
&mut dst.buffer,
|
||||
dst_offset,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
// todo!("Output {:?}", dst.buffer.read_to_vec::<f32>(10));
|
||||
// }
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl MetalDevice{
|
||||
pub fn flush(&mut self){
|
||||
self.command_buffer.commit();
|
||||
self.command_buffer.wait_until_completed();
|
||||
self.command_buffer = self._command_queue.new_owned_command_buffer();
|
||||
/// Construction and buffer access for `MetalStorage`.
impl MetalStorage {
|
||||
/// Builds a storage from an existing buffer, the device it is paired with and its dtype.
///
/// The three values are stored as given; nothing is validated or copied here.
pub fn new(buffer: Buffer, device: MetalDevice, dtype: DType) -> Self {
|
||||
Self {
|
||||
buffer,
|
||||
device,
|
||||
dtype,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a reference to the underlying buffer.
pub fn buffer(&self) -> &Buffer {
|
||||
&self.buffer
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendDevice for MetalDevice {
|
||||
@ -377,12 +735,26 @@ impl BackendDevice for MetalDevice {
|
||||
|
||||
fn new(ordinal: usize) -> Result<Self> {
|
||||
let device = metal::Device::all().swap_remove(ordinal);
|
||||
let _command_queue = device.new_command_queue();
|
||||
let command_buffer = _command_queue.new_owned_command_buffer();
|
||||
|
||||
// let capture = metal::CaptureManager::shared();
|
||||
// let descriptor = metal::CaptureDescriptor::new();
|
||||
// descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
|
||||
// descriptor.set_capture_device(&device);
|
||||
// let mut dir = std::env::current_dir()?;
|
||||
// dir.push("out.gputrace");
|
||||
// descriptor.set_output_url(dir);
|
||||
|
||||
// capture
|
||||
// .start_capture(&descriptor)
|
||||
// .map_err(MetalError::from)?;
|
||||
let command_queue = device.new_command_queue();
|
||||
// let command_buffer = _command_queue.new_owned_command_buffer();
|
||||
let kernels = Arc::new(Kernels::new());
|
||||
Ok(Self {
|
||||
device,
|
||||
_command_queue,
|
||||
command_buffer,
|
||||
command_queue,
|
||||
// command_buffer,
|
||||
kernels,
|
||||
})
|
||||
}
|
||||
|
||||
@ -391,7 +763,9 @@ impl BackendDevice for MetalDevice {
|
||||
}
|
||||
|
||||
fn location(&self) -> crate::DeviceLocation {
|
||||
crate::DeviceLocation::Metal
|
||||
crate::DeviceLocation::Metal {
|
||||
gpu_id: self.registry_id() as usize,
|
||||
}
|
||||
}
|
||||
|
||||
fn same_device(&self, rhs: &Self) -> bool {
|
||||
@ -411,48 +785,47 @@ impl BackendDevice for MetalDevice {
|
||||
}
|
||||
|
||||
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
|
||||
let option = metal::MTLResourceOptions::CPUCacheModeDefaultCache;
|
||||
let span= tracing::span!(tracing::Level::TRACE, "metal alloc");
|
||||
let _enter = span.enter();
|
||||
|
||||
let buffer = self.device.new_buffer(4, option);
|
||||
// let buffer = match storage {
|
||||
// CpuStorage::U8(storage) => self.device.new_buffer_with_data(
|
||||
// storage.as_ptr() as *const core::ffi::c_void,
|
||||
// (storage.len() * mem::size_of::<u8>()) as u64,
|
||||
// option,
|
||||
// ),
|
||||
// CpuStorage::U32(storage) => self.device.new_buffer_with_data(
|
||||
// storage.as_ptr() as *const core::ffi::c_void,
|
||||
// (storage.len() * mem::size_of::<u32>()) as u64,
|
||||
// option,
|
||||
// ),
|
||||
// CpuStorage::I64(storage) => self.device.new_buffer_with_data(
|
||||
// storage.as_ptr() as *const core::ffi::c_void,
|
||||
// (storage.len() * mem::size_of::<i64>()) as u64,
|
||||
// option,
|
||||
// ),
|
||||
// CpuStorage::BF16(storage) => self.device.new_buffer_with_data(
|
||||
// storage.as_ptr() as *const core::ffi::c_void,
|
||||
// (storage.len() * mem::size_of::<bf16>()) as u64,
|
||||
// option,
|
||||
// ),
|
||||
// CpuStorage::F16(storage) => self.device.new_buffer_with_data(
|
||||
// storage.as_ptr() as *const core::ffi::c_void,
|
||||
// (storage.len() * mem::size_of::<f16>()) as u64,
|
||||
// option,
|
||||
// ),
|
||||
// CpuStorage::F32(storage) => self.device.new_buffer_with_data(
|
||||
// storage.as_ptr() as *const core::ffi::c_void,
|
||||
// (storage.len() * mem::size_of::<f32>()) as u64,
|
||||
// option,
|
||||
// ),
|
||||
// CpuStorage::F64(storage) => self.device.new_buffer_with_data(
|
||||
// storage.as_ptr() as *const core::ffi::c_void,
|
||||
// (storage.len() * mem::size_of::<f64>()) as u64,
|
||||
// option,
|
||||
// ),
|
||||
// };
|
||||
let option = metal::MTLResourceOptions::StorageModeManaged;
|
||||
let buffer = match storage {
|
||||
CpuStorage::U8(storage) => self.device.new_buffer_with_data(
|
||||
storage.as_ptr() as *const core::ffi::c_void,
|
||||
(storage.len() * mem::size_of::<u8>()) as NSUInteger,
|
||||
option,
|
||||
),
|
||||
CpuStorage::U32(storage) => self.device.new_buffer_with_data(
|
||||
storage.as_ptr() as *const core::ffi::c_void,
|
||||
(storage.len() * mem::size_of::<u32>()) as NSUInteger,
|
||||
option,
|
||||
),
|
||||
CpuStorage::I64(storage) => self.device.new_buffer_with_data(
|
||||
storage.as_ptr() as *const core::ffi::c_void,
|
||||
(storage.len() * mem::size_of::<i64>()) as NSUInteger,
|
||||
option,
|
||||
),
|
||||
CpuStorage::BF16(storage) => self.device.new_buffer_with_data(
|
||||
storage.as_ptr() as *const core::ffi::c_void,
|
||||
(storage.len() * mem::size_of::<bf16>()) as NSUInteger,
|
||||
option,
|
||||
),
|
||||
CpuStorage::F16(storage) => self.device.new_buffer_with_data(
|
||||
storage.as_ptr() as *const core::ffi::c_void,
|
||||
(storage.len() * mem::size_of::<f16>()) as NSUInteger,
|
||||
option,
|
||||
),
|
||||
CpuStorage::F32(storage) => self.device.new_buffer_with_data(
|
||||
storage.as_ptr() as *const core::ffi::c_void,
|
||||
(storage.len() * mem::size_of::<f32>()) as NSUInteger,
|
||||
option,
|
||||
),
|
||||
CpuStorage::F64(storage) => self.device.new_buffer_with_data(
|
||||
storage.as_ptr() as *const core::ffi::c_void,
|
||||
(storage.len() * mem::size_of::<f64>()) as NSUInteger,
|
||||
option,
|
||||
),
|
||||
};
|
||||
// TODO is that necessary ?
|
||||
// buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
|
||||
// debug!("Allocate 2 - buffer size {}", buffer.length());
|
||||
Ok(Self::Storage {
|
||||
buffer,
|
||||
device: self.clone(),
|
||||
@ -460,13 +833,25 @@ impl BackendDevice for MetalDevice {
|
||||
})
|
||||
}
|
||||
|
||||
fn rand_uniform(&self, shape: &Shape, dtype: DType, mean: f64, stddev: f64) -> Result<Self::Storage> {
|
||||
/// Creates uniformly distributed random storage on this device.
///
/// The values are generated by `CpuDevice::rand_uniform` and then converted with
/// `storage_from_cpu_storage`; errors from either step are returned to the caller.
///
/// NOTE(review): the last two arguments are named `mean` and `stddev` but are forwarded
/// positionally to the CPU `rand_uniform`; confirm which bounds that function expects.
fn rand_uniform(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
dtype: DType,
|
||||
mean: f64,
|
||||
stddev: f64,
|
||||
) -> Result<Self::Storage> {
|
||||
// TODO is there a better way ?
|
||||
let cpu_storage = crate::cpu_backend::CpuDevice.rand_uniform(shape, dtype, mean, stddev)?;
|
||||
self.storage_from_cpu_storage(&cpu_storage)
|
||||
}
|
||||
|
||||
fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, stddev: f64) -> Result<Self::Storage> {
|
||||
fn rand_normal(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
dtype: DType,
|
||||
mean: f64,
|
||||
stddev: f64,
|
||||
) -> Result<Self::Storage> {
|
||||
// TODO is there a better way ?
|
||||
let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?;
|
||||
self.storage_from_cpu_storage(&cpu_storage)
|
||||
|
@ -90,6 +90,16 @@ pub enum Op {
|
||||
dilation: usize,
|
||||
},
|
||||
|
||||
#[allow(dead_code)]
|
||||
ConvTranspose1D {
|
||||
arg: Tensor,
|
||||
kernel: Tensor,
|
||||
padding: usize,
|
||||
output_padding: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
},
|
||||
|
||||
#[allow(dead_code)]
|
||||
Conv2D {
|
||||
arg: Tensor,
|
||||
@ -673,6 +683,8 @@ impl UnaryOpT for Gelu {
|
||||
}
|
||||
}
|
||||
|
||||
/// `erf` operation
|
||||
/// <https://en.wikipedia.org/wiki/Error_function>
|
||||
impl UnaryOpT for Erf {
|
||||
const NAME: &'static str = "erf";
|
||||
const KERNEL: &'static str = "uerf";
|
||||
@ -962,6 +974,10 @@ impl BackpropOp {
|
||||
};
|
||||
Self(op)
|
||||
}
|
||||
|
||||
/// Returns `true` when the wrapped value (`self.0`) is `None`.
pub(crate) fn is_none(&self) -> bool {
|
||||
self.0.is_none()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for BackpropOp {
|
||||
|
@ -1,7 +1,7 @@
|
||||
//! Support for the GGML file format.
|
||||
|
||||
use super::{k_quants, GgmlDType};
|
||||
use crate::{Device, Result};
|
||||
use crate::Result;
|
||||
use byteorder::{LittleEndian, ReadBytesExt};
|
||||
use std::collections::HashMap;
|
||||
|
||||
@ -121,12 +121,11 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
|
||||
raw_data: &[u8],
|
||||
size_in_bytes: usize,
|
||||
dims: Vec<usize>,
|
||||
device: &Device,
|
||||
) -> Result<super::QTensor> {
|
||||
let raw_data_ptr = raw_data.as_ptr();
|
||||
let n_blocks = size_in_bytes / std::mem::size_of::<T>();
|
||||
let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
|
||||
super::QTensor::new(data.to_vec(), dims, device)
|
||||
super::QTensor::new(data.to_vec(), dims)
|
||||
}
|
||||
|
||||
/// Creates a [Tensor] from a raw GGML tensor.
|
||||
@ -134,7 +133,6 @@ pub fn qtensor_from_ggml(
|
||||
ggml_dtype: GgmlDType,
|
||||
raw_data: &[u8],
|
||||
dims: Vec<usize>,
|
||||
device: &Device,
|
||||
) -> Result<super::QTensor> {
|
||||
let tensor_elems = dims.iter().product::<usize>();
|
||||
let blck_size = ggml_dtype.blck_size();
|
||||
@ -146,38 +144,18 @@ pub fn qtensor_from_ggml(
|
||||
let size_in_bytes = tensor_elems / blck_size * ggml_dtype.type_size();
|
||||
|
||||
match ggml_dtype {
|
||||
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims, device),
|
||||
GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims, device),
|
||||
GgmlDType::Q4_0 => {
|
||||
from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q4_1 => {
|
||||
from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q5_0 => {
|
||||
from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q5_1 => {
|
||||
from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q8_0 => {
|
||||
from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q2K => {
|
||||
from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q3K => {
|
||||
from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q4K => {
|
||||
from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q5K => {
|
||||
from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q6K => {
|
||||
from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q4_0 => from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q4_1 => from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q5_0 => from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q5_1 => from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q8_0 => from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q2K => from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q3K => from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q4K => from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q5K => from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q6K => from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims),
|
||||
_ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
|
||||
}
|
||||
}
|
||||
@ -185,7 +163,6 @@ pub fn qtensor_from_ggml(
|
||||
fn read_one_tensor<R: std::io::Seek + std::io::Read>(
|
||||
reader: &mut R,
|
||||
magic: VersionedMagic,
|
||||
device: &Device,
|
||||
) -> Result<(String, super::QTensor)> {
|
||||
let n_dims = reader.read_u32::<LittleEndian>()?;
|
||||
let name_len = reader.read_u32::<LittleEndian>()?;
|
||||
@ -210,7 +187,7 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
|
||||
// TODO: Mmap version to avoid copying the data around?
|
||||
let mut raw_data = vec![0u8; size_in_bytes];
|
||||
reader.read_exact(&mut raw_data)?;
|
||||
match qtensor_from_ggml(ggml_dtype, &raw_data, dims, device) {
|
||||
match qtensor_from_ggml(ggml_dtype, &raw_data, dims) {
|
||||
Ok(tensor) => Ok((name, tensor)),
|
||||
Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
|
||||
}
|
||||
@ -224,10 +201,7 @@ pub struct Content {
|
||||
}
|
||||
|
||||
impl Content {
|
||||
pub fn read<R: std::io::Seek + std::io::Read>(
|
||||
reader: &mut R,
|
||||
device: &Device,
|
||||
) -> Result<Content> {
|
||||
pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Content> {
|
||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
|
||||
let last_position = reader.seek(std::io::SeekFrom::End(0))?;
|
||||
reader.seek(std::io::SeekFrom::Start(0))?;
|
||||
@ -237,7 +211,7 @@ impl Content {
|
||||
let mut tensors = HashMap::new();
|
||||
|
||||
while reader.stream_position()? != last_position {
|
||||
let (name, tensor) = read_one_tensor(reader, magic, device)?;
|
||||
let (name, tensor) = read_one_tensor(reader, magic)?;
|
||||
tensors.insert(name, tensor);
|
||||
}
|
||||
Ok(Self {
|
||||
|
@ -3,7 +3,7 @@
|
||||
//! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md
|
||||
|
||||
use super::{GgmlDType, QTensor};
|
||||
use crate::{Device, Result};
|
||||
use crate::Result;
|
||||
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
||||
use std::collections::HashMap;
|
||||
|
||||
@ -29,6 +29,7 @@ impl TryFrom<u32> for Magic {
|
||||
pub enum VersionedMagic {
|
||||
GgufV1,
|
||||
GgufV2,
|
||||
GgufV3,
|
||||
}
|
||||
|
||||
impl VersionedMagic {
|
||||
@ -39,6 +40,7 @@ impl VersionedMagic {
|
||||
let versioned_magic = match (magic, version) {
|
||||
(Magic::Gguf, 1) => Self::GgufV1,
|
||||
(Magic::Gguf, 2) => Self::GgufV2,
|
||||
(Magic::Gguf, 3) => Self::GgufV3,
|
||||
_ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"),
|
||||
};
|
||||
Ok(versioned_magic)
|
||||
@ -57,7 +59,6 @@ impl TensorInfo {
|
||||
&self,
|
||||
reader: &mut R,
|
||||
tensor_data_offset: u64,
|
||||
device: &Device,
|
||||
) -> Result<QTensor> {
|
||||
let tensor_elems = self.shape.elem_count();
|
||||
let blck_size = self.ggml_dtype.blck_size();
|
||||
@ -70,12 +71,7 @@ impl TensorInfo {
|
||||
let mut raw_data = vec![0u8; size_in_bytes];
|
||||
reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
|
||||
reader.read_exact(&mut raw_data)?;
|
||||
super::ggml_file::qtensor_from_ggml(
|
||||
self.ggml_dtype,
|
||||
&raw_data,
|
||||
self.shape.dims().to_vec(),
|
||||
device,
|
||||
)
|
||||
super::ggml_file::qtensor_from_ggml(self.ggml_dtype, &raw_data, self.shape.dims().to_vec())
|
||||
}
|
||||
}
|
||||
|
||||
@ -90,7 +86,9 @@ pub struct Content {
|
||||
fn read_string<R: std::io::Read>(reader: &mut R, magic: &VersionedMagic) -> Result<String> {
|
||||
let len = match magic {
|
||||
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
|
||||
VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
|
||||
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
|
||||
reader.read_u64::<LittleEndian>()? as usize
|
||||
}
|
||||
};
|
||||
let mut v = vec![0u8; len];
|
||||
reader.read_exact(&mut v)?;
|
||||
@ -290,7 +288,9 @@ impl Value {
|
||||
let value_type = ValueType::from_u32(value_type)?;
|
||||
let len = match magic {
|
||||
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
|
||||
VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
|
||||
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
|
||||
reader.read_u64::<LittleEndian>()? as usize
|
||||
}
|
||||
};
|
||||
let mut vs = Vec::with_capacity(len);
|
||||
for _ in 0..len {
|
||||
@ -387,11 +387,15 @@ impl Content {
|
||||
|
||||
let tensor_count = match magic {
|
||||
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
|
||||
VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
|
||||
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
|
||||
reader.read_u64::<LittleEndian>()? as usize
|
||||
}
|
||||
};
|
||||
let metadata_kv_count = match magic {
|
||||
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
|
||||
VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
|
||||
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
|
||||
reader.read_u64::<LittleEndian>()? as usize
|
||||
}
|
||||
};
|
||||
|
||||
let mut metadata = HashMap::new();
|
||||
@ -413,7 +417,7 @@ impl Content {
|
||||
reader.read_u32_into::<LittleEndian>(&mut dimensions)?;
|
||||
dimensions.into_iter().map(|c| c as usize).collect()
|
||||
}
|
||||
VersionedMagic::GgufV2 => {
|
||||
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
|
||||
let mut dimensions = vec![0; n_dimensions as usize];
|
||||
reader.read_u64_into::<LittleEndian>(&mut dimensions)?;
|
||||
dimensions.into_iter().map(|c| c as usize).collect()
|
||||
@ -456,13 +460,12 @@ impl Content {
|
||||
&self,
|
||||
reader: &mut R,
|
||||
name: &str,
|
||||
device: &Device,
|
||||
) -> Result<QTensor> {
|
||||
let tensor_info = match self.tensor_infos.get(name) {
|
||||
Some(tensor_info) => tensor_info,
|
||||
None => crate::bail!("cannot find tensor-infor for {name}"),
|
||||
};
|
||||
tensor_info.read(reader, self.tensor_data_offset, device)
|
||||
tensor_info.read(reader, self.tensor_data_offset)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -14,7 +14,6 @@ pub mod utils;
|
||||
pub use k_quants::GgmlType;
|
||||
|
||||
pub struct QTensor {
|
||||
device: Device,
|
||||
data: Box<dyn QuantizedType>,
|
||||
shape: Shape,
|
||||
}
|
||||
@ -171,20 +170,17 @@ impl QTensor {
|
||||
pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
|
||||
data: Vec<T>,
|
||||
shape: S,
|
||||
device: &Device,
|
||||
) -> Result<Self> {
|
||||
let shape = shape.into();
|
||||
check_shape::<T>(&shape)?;
|
||||
Ok(Self {
|
||||
data: Box::new(data),
|
||||
shape,
|
||||
device: device.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn quantize<T: k_quants::GgmlType + Send + Sync + 'static>(src: &Tensor) -> Result<Self> {
|
||||
let shape = src.shape();
|
||||
let device = src.device();
|
||||
check_shape::<T>(shape)?;
|
||||
let src = src
|
||||
.to_dtype(crate::DType::F32)?
|
||||
@ -201,7 +197,6 @@ impl QTensor {
|
||||
Ok(Self {
|
||||
data: Box::new(data),
|
||||
shape: shape.clone(),
|
||||
device: device.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
@ -217,12 +212,7 @@ impl QTensor {
|
||||
&self.shape
|
||||
}
|
||||
|
||||
pub fn device(&self) -> &Device {
|
||||
&self.device
|
||||
}
|
||||
|
||||
pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
|
||||
// TODO Skip the CPU part on metal
|
||||
let mut f32_data = vec![0f32; self.shape.elem_count()];
|
||||
self.data.to_float(&mut f32_data)?;
|
||||
Tensor::from_vec(f32_data, &self.shape, device)
|
||||
@ -315,49 +305,6 @@ impl crate::CustomOp1 for QTensor {
|
||||
)?;
|
||||
Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
|
||||
}
|
||||
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
storage: &crate::MetalStorage,
|
||||
layout: &crate::Layout,
|
||||
) -> Result<(crate::MetalStorage, Shape)> {
|
||||
println!("TODO qmatmul");
|
||||
if !layout.is_contiguous() {
|
||||
crate::bail!("input tensor is not contiguous {layout:?}")
|
||||
}
|
||||
let src_shape = layout.shape();
|
||||
// self is transposed so n is first then k.
|
||||
let (n, k) = self.shape.dims2()?;
|
||||
if src_shape.rank() < 2 {
|
||||
crate::bail!("input tensor has only one dimension {layout:?}")
|
||||
}
|
||||
let mut dst_shape = src_shape.dims().to_vec();
|
||||
let last_k = dst_shape.pop().unwrap();
|
||||
if last_k != k {
|
||||
crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
|
||||
}
|
||||
dst_shape.push(n);
|
||||
let dst_shape = Shape::from(dst_shape);
|
||||
// let storage = storage.as_slice::<f32>()?;
|
||||
// let storage =
|
||||
// &storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
|
||||
let dst_storage = vec![0f32; dst_shape.elem_count()];
|
||||
// self.matmul_t(
|
||||
// (dst_shape.elem_count() / n, k, n),
|
||||
// storage,
|
||||
// &mut dst_storage,
|
||||
// )?;
|
||||
let cpu_storage = crate::CpuStorage::F32(dst_storage);
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
if let Device::Metal(device) = &self.device{
|
||||
Ok((
|
||||
device.storage_from_cpu_storage(&cpu_storage)?,
|
||||
dst_shape,
|
||||
))
|
||||
}else{
|
||||
crate::bail!("qtensor not on metal device")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl QMatMul {
|
||||
|
@ -334,6 +334,33 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn conv_transpose1d(
|
||||
&self,
|
||||
l: &Layout,
|
||||
kernel: &Self,
|
||||
kernel_l: &Layout,
|
||||
params: &crate::conv::ParamsConvTranspose1D,
|
||||
) -> Result<Self> {
|
||||
self.same_device(kernel, "conv-transpose1d")?;
|
||||
self.same_dtype(kernel, "conv-transpose1d")?;
|
||||
match (self, &kernel) {
|
||||
(Storage::Cpu(inp), Storage::Cpu(kernel)) => {
|
||||
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
|
||||
Ok(Self::Cpu(s))
|
||||
}
|
||||
(Storage::Cuda(inp), Storage::Cuda(kernel)) => {
|
||||
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
|
||||
Ok(Self::Cuda(s))
|
||||
}
|
||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
op: "conv-transpose1d",
|
||||
}
|
||||
.bt()),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn conv2d(
|
||||
&self,
|
||||
l: &Layout,
|
||||
|
@ -157,6 +157,8 @@ pub(crate) fn from_storage<S: Into<Shape>>(
|
||||
) -> Tensor {
|
||||
let dtype = storage.dtype();
|
||||
let device = storage.device();
|
||||
let shape = shape.into();
|
||||
// println!("{:?} {storage:?}", shape);
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage: Arc::new(RwLock::new(storage)),
|
||||
@ -166,7 +168,11 @@ pub(crate) fn from_storage<S: Into<Shape>>(
|
||||
dtype,
|
||||
device,
|
||||
};
|
||||
Tensor(Arc::new(tensor_))
|
||||
let result = Tensor(Arc::new(tensor_));
|
||||
// todo!(" from_storage");
|
||||
// let result = result.to_device(&Device::Cpu).unwrap();
|
||||
// todo!(" {result}");
|
||||
result
|
||||
}
|
||||
|
||||
impl Tensor {
|
||||
@ -477,6 +483,12 @@ impl Tensor {
|
||||
broadcast_binary_op!(broadcast_div, div);
|
||||
broadcast_binary_op!(broadcast_maximum, maximum);
|
||||
broadcast_binary_op!(broadcast_minimum, minimum);
|
||||
broadcast_binary_op!(broadcast_eq, eq);
|
||||
broadcast_binary_op!(broadcast_ne, ne);
|
||||
broadcast_binary_op!(broadcast_lt, lt);
|
||||
broadcast_binary_op!(broadcast_le, le);
|
||||
broadcast_binary_op!(broadcast_gt, gt);
|
||||
broadcast_binary_op!(broadcast_ge, ge);
|
||||
|
||||
unary_op!(recip, Recip);
|
||||
unary_op!(neg, Neg);
|
||||
@ -1811,7 +1823,12 @@ impl Tensor {
|
||||
|
||||
/// Returns a new tensor detached from the current graph, gradient are not propagated through
|
||||
/// this new node. The storage of this tensor is shared with the initial tensor.
|
||||
///
|
||||
/// If the tensor is already detached from the computation graph, the same tensor is returned.
|
||||
pub fn detach(&self) -> Result<Tensor> {
|
||||
if self.op.is_none() && !self.is_variable {
|
||||
Ok(self.clone())
|
||||
} else {
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage: self.storage.clone(),
|
||||
@ -1823,6 +1840,7 @@ impl Tensor {
|
||||
};
|
||||
Ok(Tensor(Arc::new(tensor_)))
|
||||
}
|
||||
}
|
||||
|
||||
/// If the target device is the same as the tensor device, only a shallow copy is performed.
|
||||
pub fn to_device(&self, device: &Device) -> Result<Tensor> {
|
||||
@ -1833,7 +1851,11 @@ impl Tensor {
|
||||
(Storage::Cpu(storage), Device::Cuda(cuda)) => {
|
||||
Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
|
||||
}
|
||||
(Storage::Cpu(storage), Device::Metal(metal)) => {
|
||||
Storage::Metal(metal.storage_from_cpu_storage(storage)?)
|
||||
}
|
||||
(Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
|
||||
(Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
|
||||
(Storage::Cuda(storage), Device::Cuda(cuda)) => {
|
||||
// TODO: Avoid passing through the cpu storage here, especially if the gpu ids
|
||||
// are the same.
|
||||
@ -2407,6 +2429,23 @@ impl Tensor {
|
||||
) -> Result<Self> {
|
||||
self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
|
||||
}
|
||||
|
||||
/// Normalize a 'relative' axis value: positive values are kept, negative
|
||||
/// values means counting the dimensions from the back.
|
||||
pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
|
||||
let rank = self.rank() as i64;
|
||||
if rank <= axis {
|
||||
crate::bail!("axis {axis} is too large, tensor rank {rank}")
|
||||
} else if 0 <= axis {
|
||||
Ok(axis as usize)
|
||||
} else {
|
||||
let naxis = rank + axis;
|
||||
if naxis < 0 {
|
||||
crate::bail!("axis {axis} is too small, tensor rank {rank}")
|
||||
}
|
||||
Ok(naxis as usize)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! bin_trait {
|
||||
|
@ -13,6 +13,11 @@ res = torch.nn.functional.conv1d(t, w)
|
||||
print(res.flatten())
|
||||
res = torch.nn.functional.conv1d(t, w, padding=1)
|
||||
print(res.flatten())
|
||||
|
||||
w_t = w.transpose(0, 1)
|
||||
res = torch.nn.functional.conv_transpose1d(t, w_t)
|
||||
print(res.shape)
|
||||
print(res)
|
||||
*/
|
||||
fn conv1d(dev: &Device) -> Result<()> {
|
||||
let t = Tensor::new(
|
||||
@ -45,6 +50,17 @@ fn conv1d(dev: &Device) -> Result<()> {
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
|
||||
);
|
||||
if dev.is_cpu() {
|
||||
let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
||||
assert_eq!(res.dims(), [1, 2, 7]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
[
|
||||
0.0699, -1.2899, 8.3018, 5.5873, 2.4572, -2.6143, -0.0706, 1.8765, 4.8318, 1.1538,
|
||||
4.7076, -5.9745, -0.8276, 1.621
|
||||
],
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -205,6 +205,71 @@ fn unary_grad(device: &Device) -> Result<()> {
|
||||
test_utils::to_vec1_round(grad_x, 4)?,
|
||||
[1.0116, 1.0830, 1.0003, 0.6188],
|
||||
);
|
||||
|
||||
// Testing compared to pytorch torch.erf
|
||||
//
|
||||
// import torch
|
||||
// x = torch.tensor([3.0, 1.0, 4.0, 0.15], requires_grad=True)
|
||||
// y = x.erf()
|
||||
// print(y)
|
||||
// loss = y.sum()
|
||||
// loss.backward()
|
||||
// print(x.grad)
|
||||
let y = x.erf()?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||
assert_eq!(test_utils::to_vec1_round(&y, 4)?, [1.0, 0.8427, 1.0, 0.168]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(grad_x, 4)?,
|
||||
[0.0001, 0.4151, 0.0, 1.1033],
|
||||
);
|
||||
|
||||
// Testing compared to pytorch nn.GELU(approximate = 'none')
|
||||
//
|
||||
// import torch
|
||||
// import torch.nn.functional as F
|
||||
// x = torch.tensor([3.0, 1.0, 4.0, 0.15], requires_grad=True)
|
||||
// y = F.gelu(x, approximate='none')
|
||||
// print(y)
|
||||
// loss = y.sum()
|
||||
// loss.backward()
|
||||
// print(x.grad)
|
||||
let y = x.gelu_erf()?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&y, 4)?,
|
||||
[2.9960, 0.8413, 3.9999, 0.0839]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(grad_x, 4)?,
|
||||
[1.0119, 1.0833, 1.0005, 0.6188],
|
||||
);
|
||||
|
||||
// Testing compared to pytorch elu
|
||||
//
|
||||
// import torch
|
||||
// import torch.nn.functional as F
|
||||
// x = torch.tensor([-1.0, 0.0, -2.0, 3.0], requires_grad=True)
|
||||
// y = F.elu(x, alpha=2.0)
|
||||
// print(y)
|
||||
// loss = y.min
|
||||
// loss = y.sum()
|
||||
// loss.backward()
|
||||
// print(x.grad)
|
||||
let elu_x = Var::new(&[-1.0f32, 0., -2., 3.], device)?;
|
||||
let y = elu_x.elu(2.)?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(&elu_x).context("no grad for x")?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&y, 4)?,
|
||||
[-1.2642, 0.0000, -1.7293, 3.0000]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(grad_x, 4)?,
|
||||
[0.7358, 2.0000, 0.2707, 1.0000]
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -4,7 +4,9 @@
|
||||
//! <https://www.cs.toronto.edu/~kriz/cifar.html>
|
||||
//! The binary version of the dataset is used.
|
||||
use crate::vision::Dataset;
|
||||
use candle::{DType, Device, Result, Tensor};
|
||||
use candle::{DType, Device, Error, Result, Tensor};
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use parquet::file::reader::{FileReader, SerializedFileReader};
|
||||
use std::fs::File;
|
||||
use std::io::{BufReader, Read};
|
||||
|
||||
@ -60,3 +62,58 @@ pub fn load_dir<T: AsRef<std::path::Path>>(dir: T) -> Result<Dataset> {
|
||||
labels: 10,
|
||||
})
|
||||
}
|
||||
|
||||
fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor, Tensor)> {
|
||||
let samples = parquet.metadata().file_metadata().num_rows() as usize;
|
||||
let mut buffer_images: Vec<u8> = Vec::with_capacity(samples * 1_024);
|
||||
let mut buffer_labels: Vec<u8> = Vec::with_capacity(samples);
|
||||
for row in parquet.into_iter().flatten() {
|
||||
for (_name, field) in row.get_column_iter() {
|
||||
if let parquet::record::Field::Group(subrow) = field {
|
||||
for (_name, field) in subrow.get_column_iter() {
|
||||
if let parquet::record::Field::Bytes(value) = field {
|
||||
let image = image::load_from_memory(value.data()).unwrap();
|
||||
buffer_images.extend(image.to_rgb8().as_raw());
|
||||
}
|
||||
}
|
||||
} else if let parquet::record::Field::Long(label) = field {
|
||||
buffer_labels.push(*label as u8);
|
||||
}
|
||||
}
|
||||
}
|
||||
let images = (Tensor::from_vec(buffer_images, (samples, 3, 32, 32), &Device::Cpu)?
|
||||
.to_dtype(DType::U8)?
|
||||
/ 255.)?;
|
||||
let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
|
||||
Ok((images, labels))
|
||||
}
|
||||
|
||||
pub fn load() -> Result<Dataset> {
|
||||
let api = Api::new().map_err(|e| Error::Msg(format!("Api error: {e}")))?;
|
||||
let dataset_id = "cifar10".to_string();
|
||||
let repo = Repo::with_revision(
|
||||
dataset_id,
|
||||
RepoType::Dataset,
|
||||
"refs/convert/parquet".to_string(),
|
||||
);
|
||||
let repo = api.repo(repo);
|
||||
let test_parquet_filename = repo
|
||||
.get("plain_text/test/0000.parquet")
|
||||
.map_err(|e| Error::Msg(format!("Api error: {e}")))?;
|
||||
let train_parquet_filename = repo
|
||||
.get("plain_text/train/0000.parquet")
|
||||
.map_err(|e| Error::Msg(format!("Api error: {e}")))?;
|
||||
let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)
|
||||
.map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
|
||||
let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)
|
||||
.map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
|
||||
let (test_images, test_labels) = load_parquet(test_parquet)?;
|
||||
let (train_images, train_labels) = load_parquet(train_parquet)?;
|
||||
Ok(crate::vision::Dataset {
|
||||
train_images,
|
||||
train_labels,
|
||||
test_images,
|
||||
test_labels,
|
||||
labels: 10,
|
||||
})
|
||||
}
|
||||
|
@ -16,6 +16,7 @@ candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.0" }
|
||||
candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
|
||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
|
||||
candle-onnx = { path = "../candle-onnx", version = "0.3.0", optional = true }
|
||||
cudarc = { workspace = true, optional = true }
|
||||
half = { workspace = true, optional = true }
|
||||
image = { workspace = true }
|
||||
@ -51,11 +52,11 @@ anyhow = { workspace = true }
|
||||
default = []
|
||||
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
|
||||
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
|
||||
metal = ["candle/metal", "candle-nn/metal", "candle-transformers/metal"]
|
||||
cudnn = ["candle/cudnn"]
|
||||
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
|
||||
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
|
||||
nccl = ["cuda", "cudarc/nccl", "dep:half"]
|
||||
onnx = ["candle-onnx"]
|
||||
|
||||
[[example]]
|
||||
name = "llama_multiprocess"
|
||||
@ -64,3 +65,11 @@ required-features = ["cuda", "nccl", "flash-attn"]
|
||||
[[example]]
|
||||
name = "reinforcement-learning"
|
||||
required-features = ["pyo3"]
|
||||
|
||||
[[example]]
|
||||
name = "onnx"
|
||||
required-features = ["onnx"]
|
||||
|
||||
[[example]]
|
||||
name = "onnx_basics"
|
||||
required-features = ["onnx"]
|
||||
|
@ -329,14 +329,18 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
println!("{tokens:?}");
|
||||
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0.. {
|
||||
for index in 0..1 {
|
||||
if tokens.len() >= config.seq_len {
|
||||
break;
|
||||
}
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
|
||||
// println!("Input {}", input);
|
||||
// println!("Input {}", input.to_device(&candle::Device::Cpu)?);
|
||||
let logits = model.forward(&input, index_pos)?;
|
||||
let logits = logits.i((0, logits.dim(1)? - 1))?;
|
||||
let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {
|
||||
|
@ -8,6 +8,7 @@ use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder};
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
enum NormType {
|
||||
WeightNorm,
|
||||
TimeGroupNorm,
|
||||
None,
|
||||
}
|
||||
|
||||
@ -268,6 +269,7 @@ impl Module for EncodecConvTranspose1d {
|
||||
struct EncodecConv1d {
|
||||
causal: bool,
|
||||
conv: Conv1d,
|
||||
norm: Option<candle_nn::GroupNorm>,
|
||||
}
|
||||
|
||||
impl EncodecConv1d {
|
||||
@ -292,7 +294,7 @@ impl EncodecConv1d {
|
||||
},
|
||||
vb.pp("conv"),
|
||||
)?,
|
||||
NormType::None => conv1d(
|
||||
NormType::None | NormType::TimeGroupNorm => conv1d(
|
||||
in_c,
|
||||
out_c,
|
||||
kernel_size,
|
||||
@ -305,9 +307,17 @@ impl EncodecConv1d {
|
||||
vb.pp("conv"),
|
||||
)?,
|
||||
};
|
||||
let norm = match cfg.norm_type {
|
||||
NormType::None | NormType::WeightNorm => None,
|
||||
NormType::TimeGroupNorm => {
|
||||
let gn = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
|
||||
Some(gn)
|
||||
}
|
||||
};
|
||||
Ok(Self {
|
||||
causal: cfg.use_causal_conv,
|
||||
conv,
|
||||
norm,
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -316,8 +326,10 @@ impl Module for EncodecConv1d {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
// TODO: padding, depending on causal.
|
||||
let xs = self.conv.forward(xs)?;
|
||||
// If we add support for NormType "time_group_norm", we should add some normalization here.
|
||||
Ok(xs)
|
||||
match &self.norm {
|
||||
None => Ok(xs),
|
||||
Some(norm) => xs.apply(norm),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
10
candle-examples/examples/onnx/README.md
Normal file
10
candle-examples/examples/onnx/README.md
Normal file
@ -0,0 +1,10 @@
|
||||
## Using ONNX models in Candle
|
||||
|
||||
This example demonstrates how to run ONNX-based models in Candle. The model
|
||||
being used here is a small SqueezeNet variant.
|
||||
|
||||
You can run the example with the following command:
|
||||
|
||||
```bash
|
||||
cargo run --example onnx --features=onnx --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||
```
|
78
candle-examples/examples/onnx/main.rs
Normal file
78
candle-examples/examples/onnx/main.rs
Normal file
@ -0,0 +1,78 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use candle::{IndexOp, D};
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum Which {
|
||||
SqueezeNet,
|
||||
EfficientNet,
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
#[arg(long)]
|
||||
image: String,
|
||||
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
/// The model to be used.
|
||||
#[arg(value_enum, long, default_value_t = Which::SqueezeNet)]
|
||||
which: Which,
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?;
|
||||
let image = match args.which {
|
||||
Which::SqueezeNet => image,
|
||||
Which::EfficientNet => image.permute((1, 2, 0))?,
|
||||
};
|
||||
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let model = match args.model {
|
||||
Some(model) => std::path::PathBuf::from(model),
|
||||
None => match args.which {
|
||||
Which::SqueezeNet => hf_hub::api::sync::Api::new()?
|
||||
.model("lmz/candle-onnx".into())
|
||||
.get("squeezenet1.1-7.onnx")?,
|
||||
Which::EfficientNet => hf_hub::api::sync::Api::new()?
|
||||
.model("onnx/EfficientNet-Lite4".into())
|
||||
.get("efficientnet-lite4-11.onnx")?,
|
||||
},
|
||||
};
|
||||
|
||||
let model = candle_onnx::read_file(model)?;
|
||||
let graph = model.graph.as_ref().unwrap();
|
||||
let mut inputs = std::collections::HashMap::new();
|
||||
inputs.insert(graph.input[0].name.to_string(), image.unsqueeze(0)?);
|
||||
let mut outputs = candle_onnx::simple_eval(&model, inputs)?;
|
||||
let output = outputs.remove(&graph.output[0].name).unwrap();
|
||||
let prs = match args.which {
|
||||
Which::SqueezeNet => candle_nn::ops::softmax(&output, D::Minus1)?,
|
||||
Which::EfficientNet => output,
|
||||
};
|
||||
let prs = prs.i(0)?.to_vec1::<f32>()?;
|
||||
|
||||
// Sort the predictions and take the top 5
|
||||
let mut top: Vec<_> = prs.iter().enumerate().collect();
|
||||
top.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
|
||||
let top = top.into_iter().take(5).collect::<Vec<_>>();
|
||||
|
||||
// Print the top predictions
|
||||
for &(i, p) in &top {
|
||||
println!(
|
||||
"{:50}: {:.2}%",
|
||||
candle_examples::imagenet::CLASSES[i],
|
||||
p * 100.0
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
87
candle-examples/examples/onnx_basics.rs
Normal file
87
candle-examples/examples/onnx_basics.rs
Normal file
@ -0,0 +1,87 @@
|
||||
use anyhow::Result;
|
||||
use candle::{Device, Tensor};
|
||||
|
||||
use clap::{Parser, Subcommand};
|
||||
|
||||
#[derive(Subcommand, Debug, Clone)]
|
||||
enum Command {
|
||||
Print {
|
||||
#[arg(long)]
|
||||
file: String,
|
||||
},
|
||||
SimpleEval {
|
||||
#[arg(long)]
|
||||
file: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
pub struct Args {
|
||||
#[command(subcommand)]
|
||||
command: Command,
|
||||
}
|
||||
|
||||
pub fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
match args.command {
|
||||
Command::Print { file } => {
|
||||
let model = candle_onnx::read_file(file)?;
|
||||
println!("{model:?}");
|
||||
let graph = model.graph.unwrap();
|
||||
for node in graph.node.iter() {
|
||||
println!("{node:?}");
|
||||
}
|
||||
}
|
||||
Command::SimpleEval { file } => {
|
||||
let model = candle_onnx::read_file(file)?;
|
||||
let graph = model.graph.as_ref().unwrap();
|
||||
let constants: std::collections::HashSet<_> =
|
||||
graph.initializer.iter().map(|i| i.name.as_str()).collect();
|
||||
let mut inputs = std::collections::HashMap::new();
|
||||
for input in graph.input.iter() {
|
||||
use candle_onnx::onnx::tensor_proto::DataType;
|
||||
if constants.contains(input.name.as_str()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let type_ = input.r#type.as_ref().expect("no type for input");
|
||||
let type_ = type_.value.as_ref().expect("no type.value for input");
|
||||
let value = match type_ {
|
||||
candle_onnx::onnx::type_proto::Value::TensorType(tt) => {
|
||||
let dt = match DataType::try_from(tt.elem_type) {
|
||||
Ok(dt) => match candle_onnx::dtype(dt) {
|
||||
Some(dt) => dt,
|
||||
None => {
|
||||
anyhow::bail!(
|
||||
"unsupported 'value' data-type {dt:?} for {}",
|
||||
input.name
|
||||
)
|
||||
}
|
||||
},
|
||||
type_ => anyhow::bail!("unsupported input type {type_:?}"),
|
||||
};
|
||||
let shape = tt.shape.as_ref().expect("no tensortype.shape for input");
|
||||
let dims = shape
|
||||
.dim
|
||||
.iter()
|
||||
.map(|dim| match dim.value.as_ref().expect("no dim value") {
|
||||
candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize),
|
||||
candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimParam(_) => Ok(42),
|
||||
})
|
||||
.collect::<Result<Vec<usize>>>()?;
|
||||
Tensor::zeros(dims, dt, &Device::Cpu)?
|
||||
}
|
||||
type_ => anyhow::bail!("unsupported input type {type_:?}"),
|
||||
};
|
||||
println!("input {}: {value:?}", input.name);
|
||||
inputs.insert(input.name.clone(), value);
|
||||
}
|
||||
let outputs = candle_onnx::simple_eval(&model, inputs)?;
|
||||
for (name, value) in outputs.iter() {
|
||||
println!("output {name}: {value:?}")
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -1,5 +1,7 @@
|
||||
# candle-quantized-t5
|
||||
|
||||
## Seq2Seq example
|
||||
|
||||
This example uses a quantized version of the t5 model.
|
||||
|
||||
```bash
|
||||
@ -8,6 +10,8 @@ $ cargo run --example quantized-t5 --release -- --prompt "translate to German: A
|
||||
Eine schöne Kerze.
|
||||
```
|
||||
|
||||
## Generating Quantized weight files
|
||||
|
||||
The weight file is automatically retrieved from the hub. It is also possible to
|
||||
generate quantized weight files from the original safetensors file by using the
|
||||
`tensor-tools` command line utility via:
|
||||
@ -16,8 +20,11 @@ generate quantized weight files from the original safetensors file by using the
|
||||
$ cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
|
||||
```
|
||||
|
||||
To use a different model, specify the `model-id`. For example, you can use
|
||||
quantized [CoEdit models](https://huggingface.co/jbochi/candle-coedit-quantized).
|
||||
## Using custom models
|
||||
|
||||
To use a different model, specify the `model-id`.
|
||||
|
||||
For example, for text editing, you can use quantized [CoEdit models](https://huggingface.co/jbochi/candle-coedit-quantized).
|
||||
|
||||
```bash
|
||||
$ cargo run --example quantized-t5 --release -- \
|
||||
@ -26,6 +33,7 @@ $ cargo run --example quantized-t5 --release -- \
|
||||
--temperature 0
|
||||
...
|
||||
Although their flight is weak, they run quickly through the tree canopy.
|
||||
```
|
||||
|
||||
By default, it will look for `model.gguf` and `config.json`, but you can specify
|
||||
custom local or remote `weight-file` and `config-file`s:
|
||||
@ -40,3 +48,16 @@ cargo run --example quantized-t5 --release -- \
|
||||
...
|
||||
Note that a storm surge is what forecasters consider a hurricane's most dangerous part.
|
||||
```
|
||||
|
||||
### [MADLAD-400](https://arxiv.org/abs/2309.04662)
|
||||
|
||||
MADLAD-400 is a series of multilingual machine translation T5 models trained on 250 billion tokens covering over 450 languages using publicly available data. These models are competitive with significantly larger models.
|
||||
|
||||
```bash
|
||||
cargo run --example quantized-t5 --release -- \
|
||||
--model-id "jbochi/madlad400-3b-mt" --weight-file "model-q4k.gguf" \
|
||||
--prompt "<2de> How are you, my friend?" \
|
||||
--temperature 0
|
||||
...
|
||||
Wie geht es dir, mein Freund?
|
||||
```
|
||||
|
@ -173,7 +173,11 @@ fn main() -> Result<()> {
|
||||
.to_vec();
|
||||
let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
|
||||
let mut model = builder.build_model()?;
|
||||
let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
|
||||
let mut output_token_ids = [builder
|
||||
.config
|
||||
.decoder_start_token_id
|
||||
.unwrap_or(builder.config.pad_token_id) as u32]
|
||||
.to_vec();
|
||||
let temperature = if args.temperature <= 0. {
|
||||
None
|
||||
} else {
|
||||
|
@ -9,9 +9,10 @@ use std::io::Write;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
use candle::quantized::{ggml_file, gguf_file};
|
||||
use candle::{Tensor};
|
||||
use candle::{Device, Tensor};
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
use candle_transformers::models::quantized_llama as model;
|
||||
use model::ModelWeights;
|
||||
|
||||
@ -24,7 +25,7 @@ enum Prompt {
|
||||
One(String),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Copy, ValueEnum)]
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
||||
enum Which {
|
||||
#[value(name = "7b")]
|
||||
L7b,
|
||||
@ -48,8 +49,10 @@ enum Which {
|
||||
Mistral7b,
|
||||
#[value(name = "7b-mistral-instruct")]
|
||||
Mistral7bInstruct,
|
||||
#[value(name = "7b-zephyr")]
|
||||
Zephyr7b,
|
||||
#[value(name = "7b-zephyr-a")]
|
||||
Zephyr7bAlpha,
|
||||
#[value(name = "7b-zephyr-b")]
|
||||
Zephyr7bBeta,
|
||||
}
|
||||
|
||||
impl Which {
|
||||
@ -64,7 +67,28 @@ impl Which {
|
||||
| Self::L7bCode
|
||||
| Self::L13bCode
|
||||
| Self::L34bCode => false,
|
||||
Self::Mistral7b | Self::Mistral7bInstruct | Self::Zephyr7b => true,
|
||||
// Zephyr is a fine tuned version of mistral and should be treated in the same way.
|
||||
Self::Zephyr7bAlpha
|
||||
| Self::Zephyr7bBeta
|
||||
| Self::Mistral7b
|
||||
| Self::Mistral7bInstruct => true,
|
||||
}
|
||||
}
|
||||
|
||||
fn is_zephyr(&self) -> bool {
|
||||
match self {
|
||||
Self::L7b
|
||||
| Self::L13b
|
||||
| Self::L70b
|
||||
| Self::L7bChat
|
||||
| Self::L13bChat
|
||||
| Self::L70bChat
|
||||
| Self::L7bCode
|
||||
| Self::L13bCode
|
||||
| Self::L34bCode
|
||||
| Self::Mistral7b
|
||||
| Self::Mistral7bInstruct => false,
|
||||
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -83,7 +107,7 @@ struct Args {
|
||||
prompt: Option<String>,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(short = 'n', long, default_value_t = 100)]
|
||||
#[arg(short = 'n', long, default_value_t = 1000)]
|
||||
sample_len: usize,
|
||||
|
||||
/// The tokenizer config in json format.
|
||||
@ -176,10 +200,13 @@ impl Args {
|
||||
"TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
|
||||
"mistral-7b-instruct-v0.1.Q4_K_S.gguf",
|
||||
),
|
||||
Which::Zephyr7b => (
|
||||
Which::Zephyr7bAlpha => (
|
||||
"TheBloke/zephyr-7B-alpha-GGUF",
|
||||
"zephyr-7b-alpha.Q4_K_M.gguf",
|
||||
),
|
||||
Which::Zephyr7bBeta => {
|
||||
("TheBloke/zephyr-7B-beta-GGUF", "zephyr-7b-beta.Q4_K_M.gguf")
|
||||
}
|
||||
};
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model(repo.to_string());
|
||||
@ -190,31 +217,6 @@ impl Args {
|
||||
}
|
||||
}
|
||||
|
||||
fn print_token(next_token: u32, tokenizer: &Tokenizer) {
|
||||
// Extracting the last token as a string is complicated, here we just apply some simple
|
||||
// heuristics as it seems to work well enough for this example. See the following for more
|
||||
// details:
|
||||
// https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
|
||||
if let Some(text) = tokenizer.id_to_token(next_token) {
|
||||
let text = text.replace('▁', " ");
|
||||
let ascii = text
|
||||
.strip_prefix("<0x")
|
||||
.and_then(|t| t.strip_suffix('>'))
|
||||
.and_then(|t| u8::from_str_radix(t, 16).ok());
|
||||
match ascii {
|
||||
None => print!("{text}"),
|
||||
Some(ascii) => {
|
||||
if let Some(chr) = char::from_u32(ascii as u32) {
|
||||
if chr.is_ascii() {
|
||||
print!("{chr}")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
let _ = std::io::stdout().flush();
|
||||
}
|
||||
}
|
||||
|
||||
fn format_size(size_in_bytes: usize) -> String {
|
||||
if size_in_bytes < 1_000 {
|
||||
format!("{}B", size_in_bytes)
|
||||
@ -232,7 +234,6 @@ fn main() -> anyhow::Result<()> {
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let mut device = candle_examples::device(false)?;
|
||||
let temperature = if args.temperature == 0. {
|
||||
None
|
||||
} else {
|
||||
@ -277,10 +278,10 @@ fn main() -> anyhow::Result<()> {
|
||||
&format_size(total_size_in_bytes),
|
||||
start.elapsed().as_secs_f32(),
|
||||
);
|
||||
ModelWeights::from_gguf(model, &mut file, &device)?
|
||||
ModelWeights::from_gguf(model, &mut file)?
|
||||
}
|
||||
Some("ggml" | "bin") | Some(_) | None => {
|
||||
let model = ggml_file::Content::read(&mut file, &device)?;
|
||||
let model = ggml_file::Content::read(&mut file)?;
|
||||
let mut total_size_in_bytes = 0;
|
||||
for (_, tensor) in model.tensors.iter() {
|
||||
let elem_count = tensor.shape().elem_count();
|
||||
@ -304,16 +305,18 @@ fn main() -> anyhow::Result<()> {
|
||||
| Which::L34bCode => 1,
|
||||
Which::Mistral7b
|
||||
| Which::Mistral7bInstruct
|
||||
| Which::Zephyr7b
|
||||
| Which::Zephyr7bAlpha
|
||||
| Which::Zephyr7bBeta
|
||||
| Which::L70b
|
||||
| Which::L70bChat => 8,
|
||||
};
|
||||
ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa), &device)?
|
||||
ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
|
||||
}
|
||||
};
|
||||
println!("model built");
|
||||
|
||||
let tokenizer = args.tokenizer()?;
|
||||
let mut tos = TokenOutputStream::new(tokenizer);
|
||||
let prompt = match args.prompt.as_deref() {
|
||||
Some("chat") => Prompt::Chat,
|
||||
Some("interactive") => Prompt::Interactive,
|
||||
@ -336,7 +339,9 @@ fn main() -> anyhow::Result<()> {
|
||||
prompt.pop();
|
||||
}
|
||||
}
|
||||
if args.which.is_mistral() {
|
||||
if args.which.is_zephyr() {
|
||||
format!("<|system|>\n</s>\n<|user|>\n{prompt}</s>\n<|assistant|>")
|
||||
} else if args.which.is_mistral() {
|
||||
format!("[INST] {prompt} [/INST]")
|
||||
} else {
|
||||
prompt
|
||||
@ -344,7 +349,8 @@ fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
};
|
||||
print!("{}", &prompt_str);
|
||||
let tokens = tokenizer
|
||||
let tokens = tos
|
||||
.tokenizer()
|
||||
.encode(prompt_str, true)
|
||||
.map_err(anyhow::Error::msg)?;
|
||||
if args.verbose_prompt {
|
||||
@ -367,46 +373,51 @@ fn main() -> anyhow::Result<()> {
|
||||
|
||||
let start_prompt_processing = std::time::Instant::now();
|
||||
let mut next_token = {
|
||||
let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
|
||||
let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, 0)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
// TODO Remove this once implementation is finished.
|
||||
let logits = logits.ones_like()?;
|
||||
logits_processor.sample(&logits)?
|
||||
};
|
||||
let prompt_dt = start_prompt_processing.elapsed();
|
||||
all_tokens.push(next_token);
|
||||
print_token(next_token, &tokenizer);
|
||||
if let Some(t) = tos.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
|
||||
let eos_token = *tokenizer.get_vocab(true).get("</s>").unwrap();
|
||||
let eos_token = *tos.tokenizer().get_vocab(true).get("</s>").unwrap();
|
||||
|
||||
let start_post_prompt = std::time::Instant::now();
|
||||
let mut sampled = 0;
|
||||
for index in 0..to_sample {
|
||||
let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
|
||||
let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, prompt_tokens.len() + index)?;
|
||||
if let candle::Device::Metal(device) = &mut device{
|
||||
device.flush()
|
||||
}
|
||||
let logits = logits.squeeze(0)?;
|
||||
// let logits = if args.repeat_penalty == 1. {
|
||||
// logits
|
||||
// } else {
|
||||
// let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
|
||||
// candle_transformers::utils::apply_repeat_penalty(
|
||||
// &logits,
|
||||
// args.repeat_penalty,
|
||||
// &all_tokens[start_at..],
|
||||
// )?
|
||||
// };
|
||||
// TODO Remove this once implementation is finished.
|
||||
let logits = logits.ones_like()?;
|
||||
let logits = if args.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
args.repeat_penalty,
|
||||
&all_tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
next_token = logits_processor.sample(&logits)?;
|
||||
all_tokens.push(next_token);
|
||||
print_token(next_token, &tokenizer);
|
||||
if let Some(t) = tos.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
sampled += 1;
|
||||
if next_token == eos_token {
|
||||
break;
|
||||
};
|
||||
}
|
||||
if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
|
||||
print!("{rest}");
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
let dt = start_post_prompt.elapsed();
|
||||
println!(
|
||||
"\n\n{:4} prompt tokens processed: {:.2} token/s",
|
||||
@ -414,9 +425,8 @@ fn main() -> anyhow::Result<()> {
|
||||
prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
|
||||
);
|
||||
println!(
|
||||
"{:4} tokens generated: {:.2} token/s",
|
||||
to_sample,
|
||||
to_sample as f64 / dt.as_secs_f64(),
|
||||
"{sampled:4} tokens generated: {:.2} token/s",
|
||||
sampled as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
|
||||
match prompt {
|
||||
|
@ -5,11 +5,23 @@
|
||||
```bash
|
||||
$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "translate to German: A beautiful candle." --decode
|
||||
...
|
||||
Running on CPU, to run on GPU, build this example with `--features cuda`
|
||||
Eine schöne Kerze.
|
||||
9 tokens generated (2.42 token/s)
|
||||
```
|
||||
|
||||
## Translation with [MADLAD-400](https://arxiv.org/abs/2309.04662)
|
||||
|
||||
MADLAD-400 is a series of multilingual machine translation T5 models trained on 250 billion tokens covering over 450 languages using publicly available data. These models are competitive with significantly larger models.
|
||||
|
||||
```bash
|
||||
cargo run --example t5 --release -- \
|
||||
--model-id "jbochi/madlad400-3b-mt" \
|
||||
--prompt "<2de> How are you, my friend?" \
|
||||
--decode --temperature 0
|
||||
...
|
||||
Wie geht es dir, mein Freund?
|
||||
```
|
||||
|
||||
## Sentence embedding example:
|
||||
|
||||
```bash
|
||||
|
@ -172,7 +172,12 @@ fn main() -> Result<()> {
|
||||
println!("Took {:?}", start.elapsed());
|
||||
} else {
|
||||
let mut model = builder.build_conditional_generation()?;
|
||||
let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
|
||||
let mut output_token_ids = [builder
|
||||
.config
|
||||
.decoder_start_token_id
|
||||
.unwrap_or(builder.config.pad_token_id)
|
||||
as u32]
|
||||
.to_vec();
|
||||
if let Some(decoder_prompt) = &args.decoder_prompt {
|
||||
print!("{decoder_prompt}");
|
||||
output_token_ids.extend(
|
||||
|
@ -345,7 +345,7 @@ enum Task {
|
||||
Translate,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)]
|
||||
enum WhichModel {
|
||||
Tiny,
|
||||
#[value(name = "tiny.en")]
|
||||
@ -361,15 +361,27 @@ enum WhichModel {
|
||||
MediumEn,
|
||||
Large,
|
||||
LargeV2,
|
||||
LargeV3,
|
||||
#[value(name = "distil-medium.en")]
|
||||
DistilMediumEn,
|
||||
#[value(name = "distil-large-v2")]
|
||||
DistilLargeV2,
|
||||
}
|
||||
|
||||
impl WhichModel {
|
||||
fn is_multilingual(&self) -> bool {
|
||||
match self {
|
||||
Self::Tiny | Self::Base | Self::Small | Self::Medium | Self::Large | Self::LargeV2 => {
|
||||
true
|
||||
Self::Tiny
|
||||
| Self::Base
|
||||
| Self::Small
|
||||
| Self::Medium
|
||||
| Self::Large
|
||||
| Self::LargeV2
|
||||
| Self::LargeV3
|
||||
| Self::DistilLargeV2 => true,
|
||||
Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn | Self::DistilMediumEn => {
|
||||
false
|
||||
}
|
||||
Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn => false,
|
||||
}
|
||||
}
|
||||
|
||||
@ -385,6 +397,9 @@ impl WhichModel {
|
||||
Self::MediumEn => ("openai/whisper-medium.en", "main"),
|
||||
Self::Large => ("openai/whisper-large", "refs/pr/36"),
|
||||
Self::LargeV2 => ("openai/whisper-large-v2", "refs/pr/57"),
|
||||
Self::LargeV3 => ("openai/whisper-large-v3", "main"),
|
||||
Self::DistilMediumEn => ("distil-whisper/distil-medium.en", "main"),
|
||||
Self::DistilLargeV2 => ("distil-whisper/distil-large-v2", "main"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -496,17 +511,25 @@ fn main() -> Result<()> {
|
||||
repo.get(&format!("model-{ext}-q80.gguf"))?,
|
||||
)
|
||||
} else {
|
||||
(
|
||||
repo.get("config.json")?,
|
||||
repo.get("tokenizer.json")?,
|
||||
repo.get("model.safetensors")?,
|
||||
)
|
||||
let config = repo.get("config.json")?;
|
||||
let tokenizer = if args.model == WhichModel::LargeV3 {
|
||||
panic!("openai/whisper-large-v3 does not provide a compatible tokenizer.json config at the moment")
|
||||
} else {
|
||||
repo.get("tokenizer.json")?
|
||||
};
|
||||
let model = repo.get("model.safetensors")?;
|
||||
(config, tokenizer, model)
|
||||
};
|
||||
(config, tokenizer, model, sample)
|
||||
};
|
||||
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let mel_bytes = include_bytes!("melfilters.bytes");
|
||||
let mel_bytes = match config.num_mel_bins {
|
||||
80 => include_bytes!("melfilters.bytes").as_slice(),
|
||||
128 => include_bytes!("melfilters128.bytes").as_slice(),
|
||||
nmel => anyhow::bail!("unexpected num_mel_bins {nmel}"),
|
||||
};
|
||||
let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
|
||||
<byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
|
||||
|
||||
@ -522,12 +545,15 @@ fn main() -> Result<()> {
|
||||
.map(|v| *v as f32 / 32768.)
|
||||
.collect();
|
||||
println!("pcm data loaded {}", pcm_data.len());
|
||||
let mel = audio::pcm_to_mel(&pcm_data, &mel_filters);
|
||||
let mel = audio::pcm_to_mel(&config, &pcm_data, &mel_filters);
|
||||
let mel_len = mel.len();
|
||||
let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
|
||||
let mel = Tensor::from_vec(
|
||||
mel,
|
||||
(1, config.num_mel_bins, mel_len / config.num_mel_bins),
|
||||
&device,
|
||||
)?;
|
||||
println!("loaded mel: {:?}", mel.dims());
|
||||
|
||||
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
|
||||
let mut model = if args.quantized {
|
||||
let vb =
|
||||
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&weights_filename)?;
|
||||
|
BIN
candle-examples/examples/whisper/melfilters128.bytes
Normal file
BIN
candle-examples/examples/whisper/melfilters128.bytes
Normal file
Binary file not shown.
@ -8,25 +8,23 @@ use candle::{Device, Result, Tensor};
|
||||
pub fn device(cpu: bool) -> Result<Device> {
|
||||
if cpu {
|
||||
Ok(Device::Cpu)
|
||||
} else {
|
||||
if cuda_is_available() {
|
||||
} else if cuda_is_available() {
|
||||
Ok(Device::new_cuda(0)?)
|
||||
} else if metal_is_available() {
|
||||
Ok(Device::new_metal(0)?)
|
||||
} else {
|
||||
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
|
||||
{
|
||||
println!("Running on CPU, to run on GPU(metal), build this example with `--features metal`");
|
||||
println!(
|
||||
"Running on CPU, to run on GPU(metal), build this example with `--features metal`"
|
||||
);
|
||||
}
|
||||
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
|
||||
{
|
||||
println!(
|
||||
"Running on CPU, to run on GPU, build this example with `--features cuda`"
|
||||
);
|
||||
println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
|
||||
}
|
||||
Ok(Device::Cpu)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_image<P: AsRef<std::path::Path>>(
|
||||
|
@ -1,12 +1,21 @@
|
||||
[package]
|
||||
name = "candle-metal-kernels"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
description.workspace = true
|
||||
repository.workspace = true
|
||||
keywords.workspace = true
|
||||
categories.workspace = true
|
||||
license.workspace = true
|
||||
version = "0.3.0"
|
||||
edition = "2021"
|
||||
|
||||
description = "Metal kernels for Candle"
|
||||
repository = "https://github.com/huggingface/candle"
|
||||
keywords = ["blas", "tensor", "machine-learning"]
|
||||
categories = ["science"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
metal = { workspace = true }
|
||||
# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
|
||||
metal = { path = "../../metal-rs", features = ["mps"] }
|
||||
once_cell = "1.18.0"
|
||||
thiserror = "1"
|
||||
tracing = "0.1.37"
|
||||
|
||||
[dev-dependencies]
|
||||
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
|
||||
rand = "0.8.5"
|
||||
|
75
candle-metal-kernels/examples/affine.rs
Normal file
75
candle-metal-kernels/examples/affine.rs
Normal file
@ -0,0 +1,75 @@
|
||||
use candle_metal_kernels::{call_affine, Kernels};
|
||||
use metal::objc::rc::autoreleasepool;
|
||||
use metal::{Device, MTLResourceOptions};
|
||||
use rand;
|
||||
use std::any::type_name;
|
||||
use std::time::Instant;
|
||||
|
||||
fn main() {
|
||||
let device = Device::system_default().unwrap();
|
||||
let kernels = Kernels::new();
|
||||
|
||||
let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
|
||||
let f32_10k = (0..10000)
|
||||
.map(|_| rand::random::<f32>())
|
||||
.collect::<Vec<_>>();
|
||||
let f32_100k = (0..100000)
|
||||
.map(|_| rand::random::<f32>())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
println!(
|
||||
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
|
||||
"dtype", "kernel", "size", "runs", "total time", "avg time"
|
||||
);
|
||||
|
||||
// f32
|
||||
run_affine_bench(&device, &kernels, &f32_1k);
|
||||
run_affine_bench(&device, &kernels, &f32_10k);
|
||||
run_affine_bench(&device, &kernels, &f32_100k);
|
||||
}
|
||||
|
||||
fn run_affine_bench<T: Clone>(device: &Device, kernels: &Kernels, v: &[T]) {
|
||||
let command_queue = device.new_command_queue();
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
|
||||
let iterations = 10000;
|
||||
let input = device.new_buffer_with_data(
|
||||
v.as_ptr() as *const core::ffi::c_void,
|
||||
core::mem::size_of_val(v) as u64,
|
||||
options,
|
||||
);
|
||||
let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
|
||||
|
||||
let mul: f32 = 1.2345;
|
||||
let add: f32 = 2.3456;
|
||||
let total_time = autoreleasepool(|| {
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let start = Instant::now();
|
||||
for _ in 0..iterations {
|
||||
call_affine(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
v.len(),
|
||||
&input,
|
||||
&mut output,
|
||||
mul,
|
||||
add,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
start.elapsed()
|
||||
});
|
||||
println!(
|
||||
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
|
||||
type_name::<T>().split("::").last().unwrap(),
|
||||
"affine",
|
||||
v.len(),
|
||||
iterations,
|
||||
total_time,
|
||||
total_time / iterations
|
||||
);
|
||||
}
|
182
candle-metal-kernels/examples/binary.rs
Normal file
182
candle-metal-kernels/examples/binary.rs
Normal file
@ -0,0 +1,182 @@
|
||||
use candle_metal_kernels::{binary, call_binary_contiguous, call_binary_strided, Kernels};
|
||||
use half::{bf16, f16};
|
||||
use metal::objc::rc::autoreleasepool;
|
||||
use metal::{Device, MTLResourceOptions};
|
||||
use rand;
|
||||
use std::any::type_name;
|
||||
use std::time::Instant;
|
||||
|
||||
fn main() {
|
||||
let device = Device::system_default().unwrap();
|
||||
let kernels = Kernels::new();
|
||||
|
||||
let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
|
||||
let f32_10k = (0..10000)
|
||||
.map(|_| rand::random::<f32>())
|
||||
.collect::<Vec<_>>();
|
||||
let f32_100k = (0..100000)
|
||||
.map(|_| rand::random::<f32>())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let f16_map = |v: &[f32]| v.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
|
||||
let f16_1k = f16_map(&f32_1k);
|
||||
let f16_10k = f16_map(&f32_10k);
|
||||
let f16_100k = f16_map(&f32_100k);
|
||||
|
||||
let bf16_map = |v: &[f32]| v.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
|
||||
let bf16_1k = bf16_map(&f32_1k);
|
||||
let bf16_10k = bf16_map(&f32_10k);
|
||||
let bf16_100k = bf16_map(&f32_100k);
|
||||
|
||||
let f32_ckernels = [
|
||||
binary::contiguous::add::FLOAT,
|
||||
binary::contiguous::sub::FLOAT,
|
||||
binary::contiguous::mul::FLOAT,
|
||||
binary::contiguous::div::FLOAT,
|
||||
];
|
||||
let f32_skernels = [
|
||||
binary::strided::add::FLOAT,
|
||||
binary::strided::sub::FLOAT,
|
||||
binary::strided::mul::FLOAT,
|
||||
binary::strided::div::FLOAT,
|
||||
];
|
||||
let f16_ckernels = [
|
||||
binary::contiguous::add::HALF,
|
||||
binary::contiguous::sub::HALF,
|
||||
binary::contiguous::mul::HALF,
|
||||
binary::contiguous::div::HALF,
|
||||
];
|
||||
let f16_skernels = [
|
||||
binary::strided::add::HALF,
|
||||
binary::strided::sub::HALF,
|
||||
binary::strided::mul::HALF,
|
||||
binary::strided::div::HALF,
|
||||
];
|
||||
let bf16_ckernels = [
|
||||
binary::contiguous::add::BFLOAT,
|
||||
binary::contiguous::sub::BFLOAT,
|
||||
binary::contiguous::mul::BFLOAT,
|
||||
binary::contiguous::div::BFLOAT,
|
||||
];
|
||||
let bf16_skernels = [
|
||||
binary::strided::add::BFLOAT,
|
||||
binary::strided::sub::BFLOAT,
|
||||
binary::strided::mul::BFLOAT,
|
||||
binary::strided::div::BFLOAT,
|
||||
];
|
||||
|
||||
println!(
|
||||
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
|
||||
"dtype", "kernel", "size", "runs", "total time", "avg time"
|
||||
);
|
||||
|
||||
// f32
|
||||
run_binary_bench(&device, &kernels, &f32_1k, f32_ckernels, f32_skernels);
|
||||
run_binary_bench(&device, &kernels, &f32_10k, f32_ckernels, f32_skernels);
|
||||
run_binary_bench(&device, &kernels, &f32_100k, f32_ckernels, f32_skernels);
|
||||
|
||||
// f16
|
||||
run_binary_bench(&device, &kernels, &f16_1k, f16_ckernels, f16_skernels);
|
||||
run_binary_bench(&device, &kernels, &f16_10k, f16_ckernels, f16_skernels);
|
||||
run_binary_bench(&device, &kernels, &f16_100k, f16_ckernels, f16_skernels);
|
||||
|
||||
// bf16
|
||||
run_binary_bench(&device, &kernels, &bf16_1k, bf16_ckernels, bf16_skernels);
|
||||
run_binary_bench(&device, &kernels, &bf16_10k, bf16_ckernels, bf16_skernels);
|
||||
run_binary_bench(&device, &kernels, &bf16_100k, bf16_ckernels, bf16_skernels);
|
||||
}
|
||||
|
||||
fn run_binary_bench<T: Clone>(
|
||||
device: &Device,
|
||||
kernels: &Kernels,
|
||||
v: &[T],
|
||||
contiguous: [binary::contiguous::Kernel; 4],
|
||||
strided: [binary::strided::Kernel; 4],
|
||||
) {
|
||||
let command_queue = device.new_command_queue();
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
|
||||
let iterations = 1000;
|
||||
let input = device.new_buffer_with_data(
|
||||
v.as_ptr() as *const core::ffi::c_void,
|
||||
core::mem::size_of_val(v) as u64,
|
||||
options,
|
||||
);
|
||||
let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
|
||||
|
||||
// Contiguous
|
||||
for kernel_name in contiguous {
|
||||
let total_time = autoreleasepool(|| {
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let start = Instant::now();
|
||||
for _ in 0..iterations {
|
||||
call_binary_contiguous(
|
||||
device,
|
||||
&command_buffer,
|
||||
kernels,
|
||||
kernel_name,
|
||||
v.len(),
|
||||
&input,
|
||||
&input,
|
||||
&mut output,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
start.elapsed()
|
||||
});
|
||||
println!(
|
||||
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
|
||||
type_name::<T>().split("::").last().unwrap(),
|
||||
kernel_name.to_string(),
|
||||
v.len(),
|
||||
iterations,
|
||||
total_time,
|
||||
total_time / iterations
|
||||
);
|
||||
}
|
||||
|
||||
// Strided
|
||||
let shape = vec![2, 5_000];
|
||||
let strides = vec![2, 1];
|
||||
let offset = 0;
|
||||
for kernel_name in strided {
|
||||
let total_time = autoreleasepool(|| {
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let start = Instant::now();
|
||||
for _ in 0..iterations {
|
||||
call_binary_strided(
|
||||
device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
kernel_name,
|
||||
&shape,
|
||||
&input,
|
||||
&strides,
|
||||
offset,
|
||||
&input,
|
||||
&strides,
|
||||
offset,
|
||||
&mut output,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
start.elapsed()
|
||||
});
|
||||
|
||||
println!(
|
||||
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
|
||||
type_name::<T>().split("::").last().unwrap(),
|
||||
kernel_name.to_string(),
|
||||
v.len(),
|
||||
iterations,
|
||||
total_time,
|
||||
total_time / iterations
|
||||
);
|
||||
}
|
||||
}
|
84
candle-metal-kernels/examples/cast.rs
Normal file
84
candle-metal-kernels/examples/cast.rs
Normal file
@ -0,0 +1,84 @@
|
||||
use candle_metal_kernels::{call_cast_contiguous, Kernels};
|
||||
use metal::objc::rc::autoreleasepool;
|
||||
use metal::{Device, MTLResourceOptions};
|
||||
use rand;
|
||||
use std::any::type_name;
|
||||
use std::time::Instant;
|
||||
|
||||
fn main() {
|
||||
let device = Device::system_default().unwrap();
|
||||
let kernels = Kernels::new();
|
||||
|
||||
let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
|
||||
let f32_10k = (0..10000)
|
||||
.map(|_| rand::random::<f32>())
|
||||
.collect::<Vec<_>>();
|
||||
let f32_100k = (0..100000)
|
||||
.map(|_| rand::random::<f32>())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let contiguous_kernels = ["cast_u32_f32"];
|
||||
|
||||
println!(
|
||||
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
|
||||
"dtype", "kernel", "size", "runs", "total time", "avg time"
|
||||
);
|
||||
|
||||
// f32
|
||||
run_cast_bench(&device, &kernels, &f32_1k, &contiguous_kernels);
|
||||
run_cast_bench(&device, &kernels, &f32_10k, &contiguous_kernels);
|
||||
run_cast_bench(&device, &kernels, &f32_100k, &contiguous_kernels);
|
||||
}
|
||||
|
||||
fn run_cast_bench<T: Clone>(
|
||||
device: &Device,
|
||||
kernels: &Kernels,
|
||||
v: &[T],
|
||||
contiguous: &[&'static str],
|
||||
) {
|
||||
let command_queue = device.new_command_queue();
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
|
||||
let iterations = 1000;
|
||||
let input = device.new_buffer_with_data(
|
||||
v.as_ptr() as *const core::ffi::c_void,
|
||||
core::mem::size_of_val(v) as u64,
|
||||
options,
|
||||
);
|
||||
let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
|
||||
|
||||
// Contiguous
|
||||
for kernel_name in contiguous {
|
||||
let total_time = autoreleasepool(|| {
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let start = Instant::now();
|
||||
for _ in 0..iterations {
|
||||
call_cast_contiguous(
|
||||
device,
|
||||
&command_buffer,
|
||||
kernels,
|
||||
kernel_name,
|
||||
v.len(),
|
||||
&input,
|
||||
&mut output,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
start.elapsed()
|
||||
});
|
||||
println!(
|
||||
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
|
||||
type_name::<T>().split("::").last().unwrap(),
|
||||
kernel_name.to_string(),
|
||||
v.len(),
|
||||
iterations,
|
||||
total_time,
|
||||
total_time / iterations
|
||||
);
|
||||
}
|
||||
|
||||
// Strided?
|
||||
}
|
197
candle-metal-kernels/examples/unary.rs
Normal file
197
candle-metal-kernels/examples/unary.rs
Normal file
@ -0,0 +1,197 @@
|
||||
use candle_metal_kernels::{call_unary_contiguous, call_unary_strided, unary, Kernels};
|
||||
use half::{bf16, f16};
|
||||
use metal::objc::rc::autoreleasepool;
|
||||
use metal::{Device, MTLResourceOptions};
|
||||
use rand;
|
||||
use std::any::type_name;
|
||||
use std::time::Instant;
|
||||
|
||||
fn main() {
|
||||
let device = Device::system_default().unwrap();
|
||||
let kernels = Kernels::new();
|
||||
|
||||
let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
|
||||
let f32_10k = (0..10000)
|
||||
.map(|_| rand::random::<f32>())
|
||||
.collect::<Vec<_>>();
|
||||
let f32_100k = (0..100000)
|
||||
.map(|_| rand::random::<f32>())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let f16_map = |v: &[f32]| v.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
|
||||
let f16_1k = f16_map(&f32_1k);
|
||||
let f16_10k = f16_map(&f32_10k);
|
||||
let f16_100k = f16_map(&f32_100k);
|
||||
|
||||
let bf16_map = |v: &[f32]| v.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
|
||||
let bf16_1k = bf16_map(&f32_1k);
|
||||
let bf16_10k = bf16_map(&f32_10k);
|
||||
let bf16_100k = bf16_map(&f32_100k);
|
||||
|
||||
let f32_ckernels = [
|
||||
unary::contiguous::sin::FLOAT,
|
||||
unary::contiguous::cos::FLOAT,
|
||||
unary::contiguous::exp::FLOAT,
|
||||
unary::contiguous::sqr::FLOAT,
|
||||
unary::contiguous::sqrt::FLOAT,
|
||||
unary::contiguous::neg::FLOAT,
|
||||
unary::contiguous::copy::FLOAT,
|
||||
];
|
||||
let f32_skernels = [
|
||||
unary::strided::sin::FLOAT,
|
||||
unary::strided::cos::FLOAT,
|
||||
unary::strided::exp::FLOAT,
|
||||
unary::strided::sqr::FLOAT,
|
||||
unary::strided::sqrt::FLOAT,
|
||||
unary::strided::neg::FLOAT,
|
||||
unary::strided::copy::FLOAT,
|
||||
];
|
||||
let f16_ckernels = [
|
||||
unary::contiguous::sin::HALF,
|
||||
unary::contiguous::cos::HALF,
|
||||
unary::contiguous::exp::HALF,
|
||||
unary::contiguous::sqr::HALF,
|
||||
unary::contiguous::sqrt::HALF,
|
||||
unary::contiguous::neg::HALF,
|
||||
unary::contiguous::copy::HALF,
|
||||
];
|
||||
let f16_skernels = [
|
||||
unary::strided::sin::HALF,
|
||||
unary::strided::cos::HALF,
|
||||
unary::strided::exp::HALF,
|
||||
unary::strided::sqr::HALF,
|
||||
unary::strided::sqrt::HALF,
|
||||
unary::strided::neg::HALF,
|
||||
unary::strided::copy::HALF,
|
||||
];
|
||||
let bf16_ckernels = [
|
||||
unary::contiguous::sin::BFLOAT,
|
||||
unary::contiguous::cos::BFLOAT,
|
||||
unary::contiguous::exp::BFLOAT,
|
||||
unary::contiguous::sqr::BFLOAT,
|
||||
unary::contiguous::sqrt::BFLOAT,
|
||||
unary::contiguous::neg::BFLOAT,
|
||||
unary::contiguous::copy::BFLOAT,
|
||||
];
|
||||
let bf16_skernels = [
|
||||
unary::strided::sin::BFLOAT,
|
||||
unary::strided::cos::BFLOAT,
|
||||
unary::strided::exp::BFLOAT,
|
||||
unary::strided::sqr::BFLOAT,
|
||||
unary::strided::sqrt::BFLOAT,
|
||||
unary::strided::neg::BFLOAT,
|
||||
unary::strided::copy::BFLOAT,
|
||||
];
|
||||
|
||||
println!(
|
||||
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
|
||||
"dtype", "kernel", "size", "runs", "total time", "avg time"
|
||||
);
|
||||
|
||||
// f32
|
||||
run_unary_bench(&device, &kernels, &f32_1k, f32_ckernels, f32_skernels);
|
||||
run_unary_bench(&device, &kernels, &f32_10k, f32_ckernels, f32_skernels);
|
||||
run_unary_bench(&device, &kernels, &f32_100k, f32_ckernels, f32_skernels);
|
||||
|
||||
// f16
|
||||
run_unary_bench(&device, &kernels, &f16_1k, f16_ckernels, f16_skernels);
|
||||
run_unary_bench(&device, &kernels, &f16_10k, f16_ckernels, f16_skernels);
|
||||
run_unary_bench(&device, &kernels, &f16_100k, f16_ckernels, f16_skernels);
|
||||
|
||||
// bf16
|
||||
run_unary_bench(&device, &kernels, &bf16_1k, bf16_ckernels, bf16_skernels);
|
||||
run_unary_bench(&device, &kernels, &bf16_10k, bf16_ckernels, bf16_skernels);
|
||||
run_unary_bench(&device, &kernels, &bf16_100k, bf16_ckernels, bf16_skernels);
|
||||
}
|
||||
|
||||
fn run_unary_bench<T: Clone>(
|
||||
device: &Device,
|
||||
kernels: &Kernels,
|
||||
v: &[T],
|
||||
contiguous: [unary::contiguous::Kernel; 7],
|
||||
strided: [unary::strided::Kernel; 7],
|
||||
) {
|
||||
let command_queue = device.new_command_queue();
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
|
||||
let iterations = 10000;
|
||||
let input = device.new_buffer_with_data(
|
||||
v.as_ptr() as *const core::ffi::c_void,
|
||||
core::mem::size_of_val(v) as u64,
|
||||
options,
|
||||
);
|
||||
let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
|
||||
|
||||
// Contiguous
|
||||
for kernel_name in contiguous {
|
||||
let total_time = autoreleasepool(|| {
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let start = Instant::now();
|
||||
for _ in 0..iterations {
|
||||
call_unary_contiguous(
|
||||
device,
|
||||
&command_buffer,
|
||||
kernels,
|
||||
kernel_name,
|
||||
v.len(),
|
||||
&input,
|
||||
&mut output,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
start.elapsed()
|
||||
});
|
||||
println!(
|
||||
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
|
||||
type_name::<T>().split("::").last().unwrap(),
|
||||
kernel_name.to_string(),
|
||||
v.len(),
|
||||
iterations,
|
||||
total_time,
|
||||
total_time / iterations
|
||||
);
|
||||
}
|
||||
|
||||
// Strided
|
||||
let shape = vec![2, 5_000];
|
||||
let strides = vec![2, 1];
|
||||
let offset = 0;
|
||||
for kernel_name in strided {
|
||||
let total_time = autoreleasepool(|| {
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let start = Instant::now();
|
||||
for _ in 0..iterations {
|
||||
call_unary_strided(
|
||||
device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
kernel_name,
|
||||
&shape,
|
||||
&input,
|
||||
&strides,
|
||||
offset,
|
||||
&mut output,
|
||||
0,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
start.elapsed()
|
||||
});
|
||||
|
||||
println!(
|
||||
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
|
||||
type_name::<T>().split("::").last().unwrap(),
|
||||
kernel_name.to_string(),
|
||||
v.len(),
|
||||
iterations,
|
||||
total_time,
|
||||
total_time / iterations
|
||||
);
|
||||
}
|
||||
}
|
43
candle-metal-kernels/src/affine.metal
Normal file
43
candle-metal-kernels/src/affine.metal
Normal file
@ -0,0 +1,43 @@
|
||||
#include <metal_stdlib>
|
||||
|
||||
METAL_FUNC uint get_strided_index(
|
||||
uint idx,
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides
|
||||
) {
|
||||
uint strided_i = 0;
|
||||
for (uint d = 0; d < num_dims; d++) {
|
||||
uint dim_idx = num_dims - 1 - d;
|
||||
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
|
||||
idx /= dims[dim_idx];
|
||||
}
|
||||
return strided_i;
|
||||
}
|
||||
|
||||
using namespace metal;
|
||||
|
||||
#define AFFINE(FN_NAME, TYPENAME) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &dim, \
|
||||
constant float &mul, \
|
||||
constant float &add, \
|
||||
device const TYPENAME *input, \
|
||||
device TYPENAME *output, \
|
||||
uint id [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (id >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
const TYPENAME m = TYPENAME(mul); \
|
||||
const TYPENAME a = TYPENAME(add); \
|
||||
output[id] = input[id] * m + a; \
|
||||
} \
|
||||
|
||||
AFFINE(affine_float, float)
|
||||
AFFINE(affine_half, half)
|
||||
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
AFFINE(affine_bfloat, bfloat);
|
||||
#endif
|
72
candle-metal-kernels/src/binary.metal
Normal file
72
candle-metal-kernels/src/binary.metal
Normal file
@ -0,0 +1,72 @@
|
||||
#include <metal_stdlib>
|
||||
|
||||
METAL_FUNC uint get_strided_index(
|
||||
uint idx,
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides
|
||||
) {
|
||||
uint strided_i = 0;
|
||||
for (uint d = 0; d < num_dims; d++) {
|
||||
uint dim_idx = num_dims - 1 - d;
|
||||
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
|
||||
idx /= dims[dim_idx];
|
||||
}
|
||||
return strided_i;
|
||||
}
|
||||
|
||||
using namespace metal;
|
||||
|
||||
// Generates a contiguous kernel (FN_NAME) and a strided kernel (FN_NAME_STRIDED)
// for an element-wise binary operation.
//
// FN is an expression over the locals `x` (left operand) and `y` (right operand),
// both of type TYPENAME. The result is converted to OUT_TYPENAME, which is also
// the element type of the output buffer.
#define BINARY(FN, TYPENAME, OUT_TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
    constant size_t &dim, \
    device const TYPENAME *left, \
    device const TYPENAME *right, \
    device OUT_TYPENAME *output, \
    uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
    /* Threads past the element count have nothing to do. */ \
    if (thread_position_in_grid >= dim) { \
        return; \
    } \
    TYPENAME x = left[thread_position_in_grid]; \
    TYPENAME y = right[thread_position_in_grid]; \
    output[thread_position_in_grid] = OUT_TYPENAME(FN); \
} \
kernel void FN_NAME_STRIDED( \
    constant size_t &dim, \
    constant size_t &num_dims, \
    constant size_t *dims, \
    constant size_t *left_strides, \
    constant size_t *right_strides, \
    device const TYPENAME *left, \
    device const TYPENAME *right, \
    device OUT_TYPENAME *output, \
    uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
    if (thread_position_in_grid >= dim) { \
        return; \
    } \
    /* Each operand has its own strides; the output is written contiguously. */ \
    TYPENAME x = left[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \
    TYPENAME y = right[get_strided_index(thread_position_in_grid, num_dims, dims, right_strides)]; \
    output[thread_position_in_grid] = OUT_TYPENAME(FN); \
}
|
||||
|
||||
#define BINARY_OP(FN, NAME) \
|
||||
BINARY(FN, float, float, NAME##_float, NAME##_float_strided); \
|
||||
BINARY(FN, half, half, NAME##_half, NAME##_half_strided);
|
||||
|
||||
#define BFLOAT_BINARY_OP(FN, NAME) \
|
||||
BINARY(FN, bfloat, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
|
||||
|
||||
|
||||
BINARY_OP(x + y, add)
|
||||
BINARY_OP(x - y, sub)
|
||||
BINARY_OP(x * y, mul)
|
||||
BINARY_OP(x / y, div)
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
BFLOAT_BINARY_OP(x + y, add)
|
||||
BFLOAT_BINARY_OP(x - y, sub)
|
||||
BFLOAT_BINARY_OP(x * y, mul)
|
||||
BFLOAT_BINARY_OP(x / y, div)
|
||||
#endif
|
51
candle-metal-kernels/src/cast.metal
Normal file
51
candle-metal-kernels/src/cast.metal
Normal file
@ -0,0 +1,51 @@
|
||||
#include <metal_stdlib>
|
||||
|
||||
METAL_FUNC uint get_strided_index(
|
||||
uint idx,
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides
|
||||
) {
|
||||
uint strided_i = 0;
|
||||
for (uint d = 0; d < num_dims; d++) {
|
||||
uint dim_idx = num_dims - 1 - d;
|
||||
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
|
||||
idx /= dims[dim_idx];
|
||||
}
|
||||
return strided_i;
|
||||
}
|
||||
|
||||
|
||||
using namespace metal;
|
||||
|
||||
// Generates a contiguous kernel (FN_NAME) and a strided kernel (FN_NAME_STRIDED)
// that convert every element of `input` from LEFT_TYPENAME to RIGHT_TYPENAME.
#define CAST(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME) \
kernel void FN_NAME( \
    constant size_t &dim, \
    device const LEFT_TYPENAME *input, \
    device RIGHT_TYPENAME *output, \
    uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
    /* Threads past the element count have nothing to do. */ \
    if (thread_position_in_grid >= dim) { \
        return; \
    } \
    output[thread_position_in_grid] = RIGHT_TYPENAME(input[thread_position_in_grid]); \
} \
kernel void FN_NAME_STRIDED( \
    constant size_t &dim, \
    constant size_t &num_dims, \
    constant size_t *dims, \
    constant size_t *strides, \
    device const LEFT_TYPENAME *input, \
    device RIGHT_TYPENAME *output, \
    uint i [[ thread_position_in_grid ]] \
) { \
    if (i >= dim) { \
        return; \
    } \
    /* Strided read, contiguous write. */ \
    output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \
}

// The source is unsigned: reading it as int32_t would turn every value
// >= 2^31 into a negative float.
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)

#if __METAL_VERSION__ >= 310
#endif
|
102
candle-metal-kernels/src/indexing.metal
Normal file
102
candle-metal-kernels/src/indexing.metal
Normal file
@ -0,0 +1,102 @@
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
|
||||
// Generates an index-select kernel: every output element copies the input
// element selected by `input_ids` along the indexed dimension.
//
// NOTE(review): the decomposition below assumes the output is contiguous with
// layout (left, ids, right) and the input with layout (left, src_dim, right);
// confirm against the host-side dispatch code.
#define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
    constant size_t &dst_size, \
    constant size_t &left_size, \
    constant size_t &src_dim_size, \
    constant size_t &right_size, \
    constant size_t &ids_size, \
    const device TYPENAME *input, \
    const device INDEX_TYPENAME *input_ids, \
    device TYPENAME *output, \
    uint gid [[ thread_position_in_grid ]] \
) { \
    if (gid >= dst_size) { \
        return; \
    } \
    /* Split the flat output index into (left, id, right) coordinates. */ \
    const size_t right_rank_i = gid % right_size; \
    const size_t id_i = (gid / right_size) % ids_size; \
    const size_t left_rank_i = gid / right_size / ids_size; \
    /* \
    Clamp the index to prevent out of bounds reads, since there doesn't seem \
    to be a good way to force a crash. \
    No need to check for negative values: only unsigned index types are allowed. \
    */ \
    const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
    const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i; \
    output[gid] = input[src_i]; \
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename I>
|
||||
void index_add(
|
||||
device I *ids [[buffer(0)]],
|
||||
device T *inp [[buffer(1)]],
|
||||
device T *out [[buffer(2)]],
|
||||
|
||||
constant uint &ids_dim_size,
|
||||
constant uint &left_size,
|
||||
constant uint &dst_dim_size,
|
||||
constant uint &right_size,
|
||||
|
||||
uint gid [[ thread_position_in_grid ]] \
|
||||
) {
|
||||
|
||||
if (gid >= left_size * right_size) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint i = gid;
|
||||
const uint pre = i / right_size;
|
||||
const uint post = i % right_size;
|
||||
|
||||
for (uint j = 0; j < ids_dim_size; j++) {
|
||||
const uint idx = ids[j];
|
||||
const uint src_i = (pre * ids_dim_size + j) * right_size + post;
|
||||
const uint dst_i = (pre * dst_dim_size + idx) * right_size + post;
|
||||
out[dst_i] += inp[src_i];
|
||||
}
|
||||
}
|
||||
|
||||
#define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
|
||||
kernel void FN_NAME( \
|
||||
device INDEX_TYPENAME *ids [[buffer(0)]], \
|
||||
device TYPENAME *inp [[buffer(1)]], \
|
||||
device TYPENAME *out [[buffer(2)]], \
|
||||
constant uint &ids_dim_size, \
|
||||
constant uint &left_size, \
|
||||
constant uint &dst_dim_size, \
|
||||
constant uint &right_size, \
|
||||
uint gid [[ thread_position_in_grid ]] \
|
||||
) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, gid); } \
|
||||
|
||||
|
||||
INDEX_OP(is_u32_f32, uint, float)
|
||||
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
IA_OP(bfloat, int64_t, ia_i64_bf16)
|
||||
IA_OP(bfloat, uint32_t, ia_u32_bf16)
|
||||
IA_OP(bfloat, uint8_t, ia_u8_bf16)
|
||||
#endif
|
||||
|
||||
IA_OP(half, uint32_t, ia_u32_f16)
|
||||
IA_OP(half, uint8_t, ia_u8_f16)
|
||||
|
||||
IA_OP(float, int64_t, ia_i64_f32)
|
||||
IA_OP(uint8_t, int64_t, ia_i64_u8)
|
||||
IA_OP(int64_t, int64_t, ia_i64_i64)
|
||||
IA_OP(uint32_t, int64_t, ia_i64_u32)
|
||||
|
||||
IA_OP(float, uint32_t, ia_u32_f32)
|
||||
IA_OP(uint8_t, uint32_t, ia_u32_u8)
|
||||
IA_OP(int64_t, uint32_t, ia_u32_i64)
|
||||
IA_OP(uint32_t, uint32_t, ia_u32_u32)
|
||||
|
||||
IA_OP(float, uint8_t, ia_u8_f32)
|
||||
IA_OP(uint8_t, uint8_t, ia_u8_u8)
|
||||
IA_OP(uint32_t, uint8_t, ia_u8_u32)
|
||||
IA_OP(int64_t, uint8_t, ia_u8_i64)
|
File diff suppressed because it is too large
Load Diff
139
candle-metal-kernels/src/reduce.metal
Normal file
139
candle-metal-kernels/src/reduce.metal
Normal file
@ -0,0 +1,139 @@
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
|
||||
METAL_FUNC uint get_strided_index(
|
||||
uint idx,
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides
|
||||
) {
|
||||
uint strided_i = 0;
|
||||
for (uint d = 0; d < num_dims; d++) {
|
||||
uint dim_idx = num_dims - 1 - d;
|
||||
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
|
||||
idx /= dims[dim_idx];
|
||||
}
|
||||
return strided_i;
|
||||
}
|
||||
|
||||
constant int THREADGROUP_SIZE = 256;
|
||||
|
||||
// Generates a block-wise reduction kernel.
//
// FN is an expression over the locals `x` (running accumulator) and `y` (next
// value). IDENTITY is the neutral element of FN, used to seed every thread's
// accumulator. Each threadgroup reduces `el_to_sum_per_block` consecutive source
// elements into dst[dst_id].
//
// The tree reduction halves the active thread count each step, so it only
// combines every partial result when the threadgroup size is a power of two.
#define REDUCE(FN, NAME, TYPENAME, IDENTITY) \
kernel void NAME( \
    constant size_t &src_numel, \
    constant size_t &el_to_sum_per_block, \
    device const TYPENAME *src, \
    device TYPENAME *dst, \
    uint id [[ thread_position_in_grid ]], \
    uint tid [[ thread_index_in_threadgroup ]], \
    uint dst_id [[ threadgroup_position_in_grid ]], \
    uint blockDim [[ threads_per_threadgroup ]] \
) { \
    threadgroup float shared_memory[THREADGROUP_SIZE]; \
    \
    /* Seed with the identity: 0 would be wrong for product and max. */ \
    shared_memory[tid] = IDENTITY; \
    /* \
    Elements reduced in this block range from dst_id * el_to_sum_per_block \
    to (dst_id + 1) * el_to_sum_per_block. \
    */ \
    size_t start_idx = dst_id * el_to_sum_per_block; \
    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
    size_t idx = start_idx + tid; \
    while (idx < stop_idx) { \
        /* \
        TODO: Fast version for the contiguous case. \
        size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
        */ \
        TYPENAME x = shared_memory[tid]; \
        TYPENAME y = src[idx]; \
        shared_memory[tid] = FN; \
        idx += blockDim; \
    } \
    \
    /* mem_threadgroup makes the per-thread partials visible to the other threads. */ \
    threadgroup_barrier(mem_flags::mem_threadgroup); \
    \
    /* Tree reduction in threadgroup memory. */ \
    for (uint s = blockDim / 2; s > 0; s >>= 1) { \
        if (tid < s) { \
            TYPENAME x = shared_memory[tid]; \
            TYPENAME y = shared_memory[tid + s]; \
            shared_memory[tid] = FN; \
        } \
        threadgroup_barrier(mem_flags::mem_threadgroup); \
    } \
    \
    dst[dst_id] = shared_memory[0]; \
}

// Softmax over blocks of `el_to_sum_per_block` consecutive elements.
//
// Three passes per block: (1) block maximum, (2) exp(x - max) written to dst
// while accumulating the block sum, (3) scaling dst by 1 / sum.
kernel void softmax_float(
    constant size_t &src_numel,
    constant size_t &el_to_sum_per_block,
    device const float *src,
    device float *dst,
    uint id [[ thread_position_in_grid ]],
    uint tid [[ thread_index_in_threadgroup ]],
    uint dst_id [[ threadgroup_position_in_grid ]],
    uint blockDim [[ threads_per_threadgroup ]]
) {
    threadgroup float shared_memory[THREADGROUP_SIZE];

    shared_memory[tid] = -INFINITY;
    // Elements handled by this block range from dst_id * el_to_sum_per_block
    // to (dst_id + 1) * el_to_sum_per_block.
    size_t start_idx = dst_id * el_to_sum_per_block;
    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
    size_t idx = start_idx + tid;

    // Pass 1: per-thread maximum.
    while (idx < stop_idx) {
        // TODO: Fast version for the contiguous case.
        shared_memory[tid] = max(shared_memory[tid], src[idx]);
        idx += blockDim;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Tree reduction of the maxima in threadgroup memory.
    for (uint s = blockDim / 2; s > 0; s >>= 1) {
        if (tid < s) {
            shared_memory[tid] = max(shared_memory[tid], shared_memory[tid + s]);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    const float max_val = shared_memory[0];

    // Every thread must have read the maximum before the scratch space is
    // reused for the sum, otherwise thread 0 could overwrite it too early.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    shared_memory[tid] = 0;

    // Pass 2: exponentials and per-thread partial sums.
    idx = start_idx + tid;
    while (idx < stop_idx) {
        // TODO: Fast version for the contiguous case.
        const float val = exp(src[idx] - max_val);
        dst[idx] = val;
        shared_memory[tid] += val;
        idx += blockDim;
    }

    // Partial sums must be visible before they are combined.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Tree reduction of the sums in threadgroup memory.
    for (uint s = blockDim / 2; s > 0; s >>= 1) {
        if (tid < s) {
            shared_memory[tid] += shared_memory[tid + s];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Pass 3: normalize. Each thread rescales only the elements it wrote itself.
    const float inv_acc = 1 / shared_memory[0];
    idx = start_idx + tid;
    while (idx < stop_idx) {
        dst[idx] *= inv_acc;
        idx += blockDim;
    }
}

REDUCE(x + y, fast_sum_float, float, 0)
REDUCE(x * y, fast_mul_float, float, 1)
REDUCE(max(x, y), fast_max_float, float, -INFINITY)
|
57
candle-metal-kernels/src/ternary.metal
Normal file
57
candle-metal-kernels/src/ternary.metal
Normal file
@ -0,0 +1,57 @@
|
||||
#include <metal_stdlib>
|
||||
#
|
||||
using namespace metal;
|
||||
|
||||
METAL_FUNC uint get_strided_index(
|
||||
uint idx,
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides
|
||||
) {
|
||||
uint strided_i = 0;
|
||||
for (uint d = 0; d < num_dims; d++) {
|
||||
uint dim_idx = num_dims - 1 - d;
|
||||
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
|
||||
idx /= dims[dim_idx];
|
||||
}
|
||||
return strided_i;
|
||||
}
|
||||
|
||||
|
||||
// Generates an element-wise select kernel: out[i] takes the value from `t`
// when the matching entry of `ids` is non-zero and from `f` otherwise.
// `ids`, `t` and `f` are each read through their own strides; `out` is written
// contiguously.
#define WHERE_OP(TYPENAME, ID_TYPENAME, FN_NAME) \
kernel void FN_NAME( \
    constant size_t &numel, \
    constant size_t &num_dims, \
    constant size_t *dims, \
    constant size_t *strides, \
    constant size_t *strides_t, \
    constant size_t *strides_f, \
    device const ID_TYPENAME *ids, \
    device const TYPENAME *t, \
    device const TYPENAME *f, \
    device TYPENAME *out, \
    uint i [[ thread_position_in_grid ]] \
) { \
    /* The grid may be padded past the element count: skip the extra threads. */ \
    if (i >= numel) { \
        return; \
    } \
    uint strided_i = get_strided_index(i, num_dims, dims, strides); \
    uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \
    uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \
    out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f]; \
}
|
||||
|
||||
// WHERE_OP(float, int64_t, where_i64_f32)
|
||||
// WHERE_OP(double, int64_t, where_i64_f64)
|
||||
// WHERE_OP(uint8_t, int64_t, where_i64_u8)
|
||||
// WHERE_OP(uint32_t, int64_t, where_i64_u32)
|
||||
// WHERE_OP(int64_t, int64_t, where_i64_i64)
|
||||
//
|
||||
// WHERE_OP(float, uint32_t, where_u32_f32)
|
||||
// WHERE_OP(double, uint32_t, where_u32_f64)
|
||||
// WHERE_OP(uint8_t, uint32_t, where_u32_u8)
|
||||
// WHERE_OP(uint32_t, uint32_t, where_u32_u32)
|
||||
// WHERE_OP(int64_t, uint32_t, where_u32_i64)
|
||||
|
||||
WHERE_OP(float, uint8_t, where_u8_f32)
|
||||
// WHERE_OP(double, uint8_t, where_u8_f64)
|
||||
// WHERE_OP(uint8_t, uint8_t, where_u8_u8)
|
||||
// WHERE_OP(uint32_t, uint8_t, where_u8_u32)
|
||||
// WHERE_OP(int64_t, uint8_t, where_u8_i64)
|
78
candle-metal-kernels/src/unary.metal
Normal file
78
candle-metal-kernels/src/unary.metal
Normal file
@ -0,0 +1,78 @@
|
||||
#include <metal_stdlib>
|
||||
|
||||
METAL_FUNC uint get_strided_index(
|
||||
uint idx,
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides
|
||||
) {
|
||||
uint strided_i = 0;
|
||||
for (uint d = 0; d < num_dims; d++) {
|
||||
uint dim_idx = num_dims - 1 - d;
|
||||
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
|
||||
idx /= dims[dim_idx];
|
||||
}
|
||||
return strided_i;
|
||||
}
|
||||
|
||||
template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
|
||||
template <typename T> METAL_FUNC T neg(T in){ return -in; }
|
||||
template <typename T> METAL_FUNC T id(T in){ return in; }
|
||||
|
||||
|
||||
using namespace metal;
|
||||
|
||||
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &dim, \
|
||||
device const TYPENAME *input, \
|
||||
device TYPENAME *output, \
|
||||
uint thread_position_in_grid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (thread_position_in_grid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[thread_position_in_grid] = TYPENAME(FN(input[thread_position_in_grid])); \
|
||||
}\
|
||||
kernel void FN_NAME_STRIDED( \
|
||||
constant size_t &dim, \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
device const TYPENAME *input, \
|
||||
device TYPENAME *output, \
|
||||
uint thread_position_in_grid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (thread_position_in_grid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[thread_position_in_grid] = TYPENAME(FN(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)])); \
|
||||
}
|
||||
|
||||
#define UNARY_OP(NAME) \
|
||||
UNARY(NAME, float, NAME##_float, NAME##_float_strided); \
|
||||
UNARY(NAME, half, NAME##_half, NAME##_half_strided);
|
||||
|
||||
#define BFLOAT_UNARY_OP(NAME) \
|
||||
UNARY(NAME, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
|
||||
|
||||
|
||||
UNARY_OP(cos)
|
||||
UNARY_OP(sin)
|
||||
UNARY_OP(sqr)
|
||||
UNARY_OP(sqrt)
|
||||
UNARY_OP(neg)
|
||||
UNARY_OP(exp)
|
||||
UNARY(id, float, copy_float, copy_float_strided)
|
||||
UNARY(id, half, copy_half, copy_half_strided)
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
BFLOAT_UNARY_OP(cos)
|
||||
BFLOAT_UNARY_OP(sin)
|
||||
BFLOAT_UNARY_OP(sqr)
|
||||
BFLOAT_UNARY_OP(sqrt)
|
||||
BFLOAT_UNARY_OP(neg)
|
||||
BFLOAT_UNARY_OP(exp)
|
||||
|
||||
UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)
|
||||
#endif
|
@ -28,5 +28,4 @@ clap = { workspace = true }
|
||||
default = []
|
||||
accelerate = ["dep:accelerate-src", "candle/accelerate"]
|
||||
cuda = ["candle/cuda"]
|
||||
metal = ["candle/metal"]
|
||||
mkl = ["dep:intel-mkl-src", "candle/mkl"]
|
||||
|
@ -13,7 +13,10 @@ pub enum Activation {
|
||||
Relu6,
|
||||
Silu,
|
||||
Sigmoid,
|
||||
HardSigmoid,
|
||||
Swiglu,
|
||||
Swish,
|
||||
HardSwish,
|
||||
Elu(f64),
|
||||
LeakyRelu(f64),
|
||||
}
|
||||
@ -29,7 +32,10 @@ impl super::Module for Activation {
|
||||
Self::Relu6 => xs.clamp(0f32, 6f32),
|
||||
Self::Silu => crate::ops::silu(xs),
|
||||
Self::Sigmoid => crate::ops::sigmoid(xs),
|
||||
Self::HardSigmoid => crate::ops::hard_sigmoid(xs),
|
||||
Self::Swiglu => crate::ops::swiglu(xs),
|
||||
Self::Swish => xs * crate::ops::sigmoid(xs)?,
|
||||
Self::HardSwish => xs * crate::ops::hard_sigmoid(xs)?,
|
||||
&Self::Elu(alpha) => xs.elu(alpha),
|
||||
&Self::LeakyRelu(negative_slope) => crate::ops::leaky_relu(xs, negative_slope),
|
||||
}
|
||||
|
@ -70,6 +70,67 @@ impl crate::Module for Conv1d {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct ConvTranspose1dConfig {
|
||||
pub padding: usize,
|
||||
pub output_padding: usize,
|
||||
pub stride: usize,
|
||||
pub dilation: usize,
|
||||
// TODO: support groups.
|
||||
}
|
||||
|
||||
impl Default for ConvTranspose1dConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
padding: 0,
|
||||
output_padding: 0,
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ConvTranspose1d {
|
||||
weight: Tensor,
|
||||
bias: Option<Tensor>,
|
||||
config: ConvTranspose1dConfig,
|
||||
}
|
||||
|
||||
impl ConvTranspose1d {
|
||||
pub fn new(weight: Tensor, bias: Option<Tensor>, config: ConvTranspose1dConfig) -> Self {
|
||||
Self {
|
||||
weight,
|
||||
bias,
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn config(&self) -> &ConvTranspose1dConfig {
|
||||
&self.config
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::Module for ConvTranspose1d {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x = x.conv_transpose1d(
|
||||
&self.weight,
|
||||
self.config.padding,
|
||||
self.config.output_padding,
|
||||
self.config.stride,
|
||||
self.config.dilation,
|
||||
)?;
|
||||
match &self.bias {
|
||||
None => Ok(x),
|
||||
Some(bias) => {
|
||||
let b = bias.dims1()?;
|
||||
let bias = bias.reshape((1, b, 1, 1))?;
|
||||
Ok(x.broadcast_add(&bias)?)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct Conv2dConfig {
|
||||
pub padding: usize,
|
||||
@ -241,6 +302,39 @@ pub fn conv1d(
|
||||
Ok(Conv1d::new(ws, Some(bs), cfg))
|
||||
}
|
||||
|
||||
pub fn conv_transpose1d(
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
kernel_size: usize,
|
||||
cfg: ConvTranspose1dConfig,
|
||||
vb: crate::VarBuilder,
|
||||
) -> Result<ConvTranspose1d> {
|
||||
let bound = 1. / (out_channels as f64 * kernel_size as f64).sqrt();
|
||||
let init = crate::Init::Uniform {
|
||||
lo: -bound,
|
||||
up: bound,
|
||||
};
|
||||
let ws = vb.get_with_hints((in_channels, out_channels, kernel_size), "weight", init)?;
|
||||
let bs = vb.get_with_hints(out_channels, "bias", init)?;
|
||||
Ok(ConvTranspose1d::new(ws, Some(bs), cfg))
|
||||
}
|
||||
|
||||
pub fn conv_transpose1d_no_bias(
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
kernel_size: usize,
|
||||
cfg: ConvTranspose1dConfig,
|
||||
vb: crate::VarBuilder,
|
||||
) -> Result<ConvTranspose1d> {
|
||||
let bound = 1. / (out_channels as f64 * kernel_size as f64).sqrt();
|
||||
let init = crate::Init::Uniform {
|
||||
lo: -bound,
|
||||
up: bound,
|
||||
};
|
||||
let ws = vb.get_with_hints((in_channels, out_channels, kernel_size), "weight", init)?;
|
||||
Ok(ConvTranspose1d::new(ws, None, cfg))
|
||||
}
|
||||
|
||||
pub fn conv2d(
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
|
@ -9,6 +9,7 @@ pub struct Embedding {
|
||||
|
||||
impl Embedding {
|
||||
pub fn new(embeddings: Tensor, hidden_size: usize) -> Self {
|
||||
// todo!("Embedding {embeddings}");
|
||||
Self {
|
||||
embeddings,
|
||||
hidden_size,
|
||||
|
@ -95,6 +95,14 @@ impl LayerNorm {
|
||||
eps,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn weight(&self) -> &Tensor {
|
||||
&self.weight
|
||||
}
|
||||
|
||||
pub fn bias(&self) -> Option<&Tensor> {
|
||||
self.bias.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::Module for LayerNorm {
|
||||
|
@ -39,11 +39,21 @@ pub fn silu(xs: &Tensor) -> Result<Tensor> {
|
||||
xs / (xs.neg()?.exp()? + 1.0)?
|
||||
}
|
||||
|
||||
pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = xs.chunk(2, candle::D::Minus1)?;
|
||||
crate::ops::silu(&xs[0])? * &xs[1]
|
||||
}
|
||||
|
||||
pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
|
||||
// TODO: Should we have a specialized op for this?
|
||||
(xs.neg()?.exp()? + 1.0)?.recip()
|
||||
}
|
||||
|
||||
pub fn hard_sigmoid(xs: &Tensor) -> Result<Tensor> {
|
||||
// TODO: Should we have a specialized op for this?
|
||||
((xs + 3.0)? / 6.0)?.clamp(0f32, 1f32)
|
||||
}
|
||||
|
||||
pub fn leaky_relu(xs: &Tensor, negative_slope: f64) -> Result<Tensor> {
|
||||
let zeros = xs.zeros_like()?;
|
||||
xs.maximum(&zeros)? + xs.minimum(&zeros)? * negative_slope
|
||||
@ -191,16 +201,6 @@ impl candle::CustomOp1 for SoftmaxLastDim {
|
||||
};
|
||||
Ok((dst, layout.shape().clone()))
|
||||
}
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
storage: &candle::MetalStorage,
|
||||
layout: &Layout,
|
||||
) -> Result<(candle::MetalStorage, Shape)> {
|
||||
println!("TODO softmax-last-dim");
|
||||
Ok((storage.clone(), layout.shape().clone()))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
|
||||
|
23
candle-onnx/Cargo.toml
Normal file
23
candle-onnx/Cargo.toml
Normal file
@ -0,0 +1,23 @@
|
||||
[package]
|
||||
name = "candle-onnx"
|
||||
version = "0.3.0"
|
||||
edition = "2021"
|
||||
|
||||
description = "ONNX support for Candle"
|
||||
repository = "https://github.com/huggingface/candle"
|
||||
keywords = ["blas", "tensor", "machine-learning"]
|
||||
categories = ["science"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.0" }
|
||||
prost = "0.12.1"
|
||||
|
||||
[build-dependencies]
|
||||
prost-build = "0.12.1"
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
clap = { version = "4.2.4", features = ["derive"] }
|
||||
|
21
candle-onnx/README.md
Normal file
21
candle-onnx/README.md
Normal file
@ -0,0 +1,21 @@
|
||||
# candle-onnx
|
||||
|
||||
This crate adds ONNX support to candle
|
||||
|
||||
## FAQ
|
||||
|
||||
#### Missing protoc installation when compiling candle-onnx
|
||||
|
||||
The candle-onnx dependency prost-build no longer comes bundled with `protoc`
binaries. This could cause the following error when attempting to compile
candle-onnx:
|
||||
|
||||
```
|
||||
error: failed to run custom build command for `candle-onnx`
|
||||
Caused by: // (...)
|
||||
Could not find `protoc` installation and this build crate cannot proceed without this knowledge.
|
||||
```
|
||||
|
||||
To fix this issue install protoc on your system and make it available in your
|
||||
system `PATH`. See the [protoc
|
||||
documentation](https://grpc.io/docs/protoc-installation/) for more information.
|
6
candle-onnx/build.rs
Normal file
6
candle-onnx/build.rs
Normal file
@ -0,0 +1,6 @@
|
||||
use std::io::Result;
|
||||
|
||||
fn main() -> Result<()> {
|
||||
prost_build::compile_protos(&["src/onnx.proto3"], &["src/"])?;
|
||||
Ok(())
|
||||
}
|
755
candle-onnx/src/eval.rs
Normal file
755
candle-onnx/src/eval.rs
Normal file
@ -0,0 +1,755 @@
|
||||
use crate::onnx;
|
||||
use crate::onnx::attribute_proto::AttributeType;
|
||||
use crate::onnx::tensor_proto::DataType;
|
||||
use candle::{bail, DType, Device, Result, Tensor};
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub type Value = Tensor;
|
||||
|
||||
pub fn dtype(dt: DataType) -> Option<DType> {
|
||||
match dt {
|
||||
DataType::Uint8 => Some(DType::U8),
|
||||
DataType::Uint32 => Some(DType::U32),
|
||||
DataType::Int64 => Some(DType::I64),
|
||||
DataType::Float16 => Some(DType::F16),
|
||||
DataType::Float => Some(DType::F32),
|
||||
DataType::Double => Some(DType::F64),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
trait Attr {
|
||||
const TYPE: AttributeType;
|
||||
fn get(attr: &onnx::AttributeProto) -> Result<&Self>;
|
||||
}
|
||||
|
||||
impl Attr for i64 {
|
||||
const TYPE: AttributeType = AttributeType::Int;
|
||||
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
|
||||
Ok(&attr.i)
|
||||
}
|
||||
}
|
||||
|
||||
impl Attr for f32 {
|
||||
const TYPE: AttributeType = AttributeType::Float;
|
||||
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
|
||||
Ok(&attr.f)
|
||||
}
|
||||
}
|
||||
|
||||
impl Attr for [i64] {
|
||||
const TYPE: AttributeType = AttributeType::Ints;
|
||||
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
|
||||
Ok(attr.ints.as_slice())
|
||||
}
|
||||
}
|
||||
|
||||
impl Attr for str {
|
||||
const TYPE: AttributeType = AttributeType::String;
|
||||
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
|
||||
std::str::from_utf8(&attr.s).map_err(candle::Error::wrap)
|
||||
}
|
||||
}
|
||||
|
||||
fn get_attr_<'a>(node: &'a onnx::NodeProto, name: &str) -> Result<&'a onnx::AttributeProto> {
|
||||
match node.attribute.iter().find(|attr| attr.name == name) {
|
||||
None => {
|
||||
bail!(
|
||||
"cannot find the '{name}' attribute in '{}' for {}",
|
||||
node.op_type,
|
||||
node.name
|
||||
)
|
||||
}
|
||||
Some(dt) => Ok(dt),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_attr<'a, T: Attr + ?Sized>(node: &'a onnx::NodeProto, name: &str) -> Result<&'a T> {
|
||||
let attr = get_attr_(node, name)?;
|
||||
if attr.r#type() != T::TYPE {
|
||||
bail!(
|
||||
"unsupported type {:?} for '{name}' attribute in '{}' for {}",
|
||||
attr.r#type,
|
||||
node.op_type,
|
||||
node.name
|
||||
)
|
||||
}
|
||||
T::get(attr)
|
||||
}
|
||||
|
||||
fn get_attr_opt<'a, T: Attr + ?Sized>(
|
||||
node: &'a onnx::NodeProto,
|
||||
name: &str,
|
||||
) -> Result<Option<&'a T>> {
|
||||
match node.attribute.iter().find(|attr| attr.name == name) {
|
||||
None => Ok(None),
|
||||
Some(attr) => {
|
||||
if attr.r#type() != T::TYPE {
|
||||
bail!(
|
||||
"unsupported type {:?} for '{name}' attribute in '{}' for {}",
|
||||
attr.r#type,
|
||||
node.op_type,
|
||||
node.name
|
||||
)
|
||||
}
|
||||
let val = T::get(attr)?;
|
||||
Ok(Some(val))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
|
||||
let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
|
||||
match DataType::try_from(t.data_type) {
|
||||
Ok(DataType::Int32) => {
|
||||
if t.int32_data.is_empty() {
|
||||
let len = t.raw_data.len() / 4;
|
||||
let data: &[i32] =
|
||||
unsafe { std::slice::from_raw_parts(t.raw_data.as_ptr() as *const i32, len) };
|
||||
let data = data.iter().map(|v| *v as i64).collect::<Vec<_>>();
|
||||
Tensor::from_vec(data, len, &Device::Cpu)
|
||||
} else {
|
||||
let data = t.int32_data.iter().map(|v| *v as i64).collect::<Vec<_>>();
|
||||
Tensor::from_vec(data, t.int32_data.len(), &Device::Cpu)
|
||||
}
|
||||
}
|
||||
Ok(dt) => match dtype(dt) {
|
||||
Some(dt) => {
|
||||
if dt == DType::F32 && !t.float_data.is_empty() {
|
||||
Tensor::from_slice(&t.float_data, dims.as_slice(), &Device::Cpu)
|
||||
} else if dt == DType::F64 && !t.double_data.is_empty() {
|
||||
Tensor::from_slice(&t.double_data, dims.as_slice(), &Device::Cpu)
|
||||
} else if dt == DType::I64 && !t.int64_data.is_empty() {
|
||||
Tensor::from_slice(&t.int64_data, dims.as_slice(), &Device::Cpu)
|
||||
} else {
|
||||
Tensor::from_raw_buffer(
|
||||
t.raw_data.as_slice(),
|
||||
dt,
|
||||
dims.as_slice(),
|
||||
&Device::Cpu,
|
||||
)
|
||||
}
|
||||
}
|
||||
None => {
|
||||
bail!("unsupported 'value' data-type {dt:?} for {name}")
|
||||
}
|
||||
},
|
||||
Err(_) => {
|
||||
bail!("unsupported 'value' data-type {} for {name}", t.data_type,)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This function provides a direct evaluation of the proto.
|
||||
// Longer-term, we should first convert the proto to an intermediate representation of the compute
|
||||
// graph so as to make multiple evaluations more efficient.
|
||||
// An example upside of this would be to remove intermediary values when they are not needed
|
||||
// anymore.
|
||||
pub fn simple_eval(
|
||||
model: &onnx::ModelProto,
|
||||
inputs: HashMap<String, Value>,
|
||||
) -> Result<HashMap<String, Value>> {
|
||||
let graph = match &model.graph {
|
||||
None => bail!("no graph defined in proto"),
|
||||
Some(graph) => graph,
|
||||
};
|
||||
let mut values = inputs;
|
||||
for t in graph.initializer.iter() {
|
||||
let tensor = get_tensor(t, t.name.as_str())?;
|
||||
values.insert(t.name.to_string(), tensor);
|
||||
}
|
||||
for input in graph.input.iter() {
|
||||
let input_type = match &input.r#type {
|
||||
Some(input_type) => input_type,
|
||||
None => continue,
|
||||
};
|
||||
let input_type = match &input_type.value {
|
||||
Some(input_type) => input_type,
|
||||
None => continue,
|
||||
};
|
||||
let tensor_type = match input_type {
|
||||
onnx::type_proto::Value::TensorType(tt) => tt,
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
let tensor = match values.get(&input.name) {
|
||||
None => bail!("missing input {}", input.name),
|
||||
Some(tensor) => tensor,
|
||||
};
|
||||
let dt = match DataType::try_from(tensor_type.elem_type) {
|
||||
Ok(dt) => match dtype(dt) {
|
||||
Some(dt) => dt,
|
||||
None => {
|
||||
bail!("unsupported 'value' data-type {dt:?} for {}", input.name)
|
||||
}
|
||||
},
|
||||
type_ => bail!("unsupported input type {type_:?}"),
|
||||
};
|
||||
match &tensor_type.shape {
|
||||
None => continue,
|
||||
Some(shape) => {
|
||||
if shape.dim.len() != tensor.rank() {
|
||||
bail!(
|
||||
"unexpected rank for {}, got {:?}, expected {:?}",
|
||||
input.name,
|
||||
shape.dim,
|
||||
tensor.shape()
|
||||
)
|
||||
}
|
||||
for (idx, (d, &dim)) in shape.dim.iter().zip(tensor.dims().iter()).enumerate() {
|
||||
match &d.value {
|
||||
Some(onnx::tensor_shape_proto::dimension::Value::DimValue(v)) => {
|
||||
if *v as usize != dim {
|
||||
bail!(
|
||||
"unexpected dim {idx} for {}, got {:?}, expected {:?}",
|
||||
input.name,
|
||||
shape.dim,
|
||||
tensor.shape()
|
||||
)
|
||||
}
|
||||
}
|
||||
// We do not check equality constraints for the DimParam dimensions for now.
|
||||
Some(onnx::tensor_shape_proto::dimension::Value::DimParam(_)) | None => (),
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
if dt != tensor.dtype() {
|
||||
bail!(
|
||||
"unexpected dtype for {}, got {:?}, expected {dt:?}",
|
||||
input.name,
|
||||
tensor.dtype()
|
||||
)
|
||||
}
|
||||
}
|
||||
// The nodes are topologically sorted so we can just process them in order.
|
||||
for node in graph.node.iter() {
|
||||
let get = |input_name: &str| match values.get(input_name) {
|
||||
Some(value) => Ok(value),
|
||||
None => bail!("cannot find {input_name} for op {}", node.name),
|
||||
};
|
||||
// TODO: Validate node.input for each operator.
|
||||
match node.op_type.as_str() {
|
||||
"Add" => {
|
||||
let input0 = get(&node.input[0])?;
|
||||
let input1 = get(&node.input[1])?;
|
||||
let output = input0.broadcast_add(input1)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Sub" => {
|
||||
let input0 = get(&node.input[0])?;
|
||||
let input1 = get(&node.input[1])?;
|
||||
let output = input0.broadcast_sub(input1)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Mul" => {
|
||||
let input0 = get(&node.input[0])?;
|
||||
let input1 = get(&node.input[1])?;
|
||||
let output = input0.broadcast_mul(input1)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Div" => {
|
||||
let input0 = get(&node.input[0])?;
|
||||
let input1 = get(&node.input[1])?;
|
||||
let output = input0.broadcast_div(input1)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Equal" => {
|
||||
let input0 = get(&node.input[0])?;
|
||||
let input1 = get(&node.input[1])?;
|
||||
let output = input0.broadcast_eq(input1)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Not" => {
|
||||
let xs = get(&node.input[0])?;
|
||||
let xs = xs.eq(&xs.zeros_like()?)?;
|
||||
values.insert(node.output[0].clone(), xs);
|
||||
}
|
||||
"MatMul" => {
|
||||
let input0 = get(&node.input[0])?;
|
||||
let input1 = get(&node.input[1])?;
|
||||
let output = input0.broadcast_matmul(input1)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Reshape" => {
|
||||
let input0 = get(&node.input[0])?;
|
||||
let input1 = get(&node.input[1])?.to_vec1::<i64>()?;
|
||||
// TODO: Check that there is at most a single -1 or 0, handle other neg values.
|
||||
let mut other_than_minus1 = 1usize;
|
||||
for &v in input1.iter() {
|
||||
if v != -1 && v != 0 {
|
||||
other_than_minus1 *= v as usize
|
||||
}
|
||||
}
|
||||
let input1 = input1
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, &v)| match v {
|
||||
-1 => Ok(input0.elem_count() / other_than_minus1),
|
||||
0 => input0.dim(idx),
|
||||
_ => Ok(v as usize),
|
||||
})
|
||||
.collect::<Result<Vec<usize>>>()?;
|
||||
let output = input0.reshape(input1)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"LogSoftmax" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let output = match get_attr_opt::<i64>(node, "axis")? {
|
||||
None => candle_nn::ops::softmax_last_dim(input)?,
|
||||
Some(&axis) => {
|
||||
let axis = input.normalize_axis(axis)?;
|
||||
candle_nn::ops::log_softmax(input, axis)?
|
||||
}
|
||||
};
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Softmax" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let output = match get_attr_opt::<i64>(node, "axis")? {
|
||||
None => candle_nn::ops::softmax_last_dim(input)?,
|
||||
Some(&axis) => {
|
||||
let axis = input.normalize_axis(axis)?;
|
||||
candle_nn::ops::softmax(input, axis)?
|
||||
}
|
||||
};
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Transpose" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let output = match get_attr_opt::<[i64]>(node, "perm")? {
|
||||
None => input.t()?,
|
||||
Some(perm) => {
|
||||
let perm = perm.iter().map(|&v| v as usize).collect::<Vec<_>>();
|
||||
input.permute(perm)?
|
||||
}
|
||||
};
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Dropout" => {
|
||||
let input = get(&node.input[0])?;
|
||||
// Do not apply dropout at the moment, consider that we're only doing inference.
|
||||
values.insert(node.output[0].clone(), input.clone());
|
||||
}
|
||||
"MaxPool" => {
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#MaxPool
|
||||
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
|
||||
let kernel_shape = get_attr::<[i64]>(node, "kernel_shape")?;
|
||||
let pads = get_attr_opt::<[i64]>(node, "pads")?;
|
||||
let strides = get_attr_opt::<[i64]>(node, "strides")?;
|
||||
let auto_pad = get_attr_opt::<str>(node, "auto_pad")?;
|
||||
match auto_pad {
|
||||
None | Some("NOTSET") => (),
|
||||
Some(s) => bail!("unsupported auto_pad {s}"),
|
||||
};
|
||||
if let Some(d) = dilations {
|
||||
if d.iter().any(|&v| v != 1) {
|
||||
bail!("MaxPool with dilation != 1, {dilations:?}")
|
||||
}
|
||||
}
|
||||
if let Some(d) = pads {
|
||||
if d.iter().any(|&v| v != 0) {
|
||||
bail!("MaxPool with pads != 0, {pads:?}")
|
||||
}
|
||||
}
|
||||
let xs = get(&node.input[0])?;
|
||||
let (k1, k2) = match kernel_shape {
|
||||
[k1, k2] => (*k1 as usize, *k2 as usize),
|
||||
_ => bail!("only 2d MaxPool is supported, kernel shape {kernel_shape:?}"),
|
||||
};
|
||||
let ys = match strides {
|
||||
None => xs.max_pool2d((k1, k2))?,
|
||||
Some([s1, s2]) => {
|
||||
xs.max_pool2d_with_stride((k1, k2), (*s1 as usize, *s2 as usize))?
|
||||
}
|
||||
Some(strides) => bail!("only 2d MaxPool is supported, strides {strides:?}"),
|
||||
};
|
||||
values.insert(node.output[0].clone(), ys);
|
||||
}
|
||||
"AveragePool" => {
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#AveragePool
|
||||
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
|
||||
let kernel_shape = get_attr::<[i64]>(node, "kernel_shape")?;
|
||||
let pads = get_attr_opt::<[i64]>(node, "pads")?;
|
||||
let strides = get_attr_opt::<[i64]>(node, "strides")?;
|
||||
let auto_pad = get_attr_opt::<str>(node, "auto_pad")?;
|
||||
match auto_pad {
|
||||
None | Some("NOTSET") => (),
|
||||
Some(s) => bail!("unsupported auto_pad {s}"),
|
||||
};
|
||||
if let Some(d) = dilations {
|
||||
if d.iter().any(|&v| v != 1) {
|
||||
bail!("AvgPool with dilation != 1, {dilations:?}")
|
||||
}
|
||||
}
|
||||
if let Some(d) = pads {
|
||||
if d.iter().any(|&v| v != 0) {
|
||||
bail!("AvgPool with pads != 0, {pads:?}")
|
||||
}
|
||||
}
|
||||
let xs = get(&node.input[0])?;
|
||||
let (k1, k2) = match kernel_shape {
|
||||
[k1, k2] => (*k1 as usize, *k2 as usize),
|
||||
_ => bail!("only 2d AvgPool is supported, kernel shape {kernel_shape:?}"),
|
||||
};
|
||||
let ys = match strides {
|
||||
None => xs.avg_pool2d((k1, k2))?,
|
||||
Some([s1, s2]) => {
|
||||
xs.avg_pool2d_with_stride((k1, k2), (*s1 as usize, *s2 as usize))?
|
||||
}
|
||||
Some(strides) => bail!("only 2d AvgPool is supported, strides {strides:?}"),
|
||||
};
|
||||
values.insert(node.output[0].clone(), ys);
|
||||
}
|
||||
"BatchNormalization" => {
|
||||
let training_mode = get_attr_opt::<i64>(node, "training_mode")?;
|
||||
if training_mode.copied().unwrap_or(0) != 0 {
|
||||
bail!("training mode is not supported for BatchNorm")
|
||||
}
|
||||
let eps = get_attr_opt::<f32>(node, "epsilon")?
|
||||
.copied()
|
||||
.unwrap_or(1e-5);
|
||||
let xs = get(&node.input[0])?;
|
||||
let weight = get(&node.input[1])?;
|
||||
let bias = get(&node.input[2])?;
|
||||
let running_mean = get(&node.input[3])?;
|
||||
let running_var = get(&node.input[4])?;
|
||||
let target_shape: Vec<usize> = xs
|
||||
.dims()
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, v)| if idx == 1 { *v } else { 1 })
|
||||
.collect();
|
||||
let target_shape = target_shape.as_slice();
|
||||
let xs = xs
|
||||
.broadcast_sub(&running_mean.reshape(target_shape)?)?
|
||||
.broadcast_div(&(running_var.reshape(target_shape)? + eps as f64)?.sqrt()?)?;
|
||||
let weight = weight.reshape(target_shape)?;
|
||||
let bias = bias.reshape(target_shape)?;
|
||||
let xs = xs.broadcast_mul(&weight)?.broadcast_add(&bias)?;
|
||||
values.insert(node.output[0].clone(), xs);
|
||||
}
|
||||
"Squeeze" => {
|
||||
let xs = get(&node.input[0])?;
|
||||
let mut axes = if node.input.len() <= 1 {
|
||||
// contract all the dimensions with size 1 except the batch dim.
|
||||
xs.dims()
|
||||
.iter()
|
||||
.enumerate()
|
||||
.flat_map(|(idx, &s)| if s == 1 && idx > 0 { Some(idx) } else { None })
|
||||
.collect()
|
||||
} else {
|
||||
get(&node.input[1])?
|
||||
.to_vec1::<i64>()?
|
||||
.iter()
|
||||
.map(|&i| xs.normalize_axis(i))
|
||||
.collect::<Result<Vec<_>>>()?
|
||||
};
|
||||
axes.sort();
|
||||
let mut xs = xs.clone();
|
||||
for &axis in axes.iter().rev() {
|
||||
xs = xs.squeeze(axis)?
|
||||
}
|
||||
values.insert(node.output[0].clone(), xs);
|
||||
}
|
||||
"ConstantOfShape" => {
|
||||
let dims = get(&node.input[0])?;
|
||||
let shape = dims
|
||||
.to_vec1::<i64>()?
|
||||
.into_iter()
|
||||
.map(|v| v as usize)
|
||||
.collect::<Vec<_>>();
|
||||
let xs = Tensor::zeros(shape, DType::F32, dims.device())?;
|
||||
values.insert(node.output[0].clone(), xs);
|
||||
}
|
||||
"Unsqueeze" => {
|
||||
let xs = get(&node.input[0])?;
|
||||
let axes = match get_attr_opt::<[i64]>(node, "axes")? {
|
||||
Some(axis) => axis.to_vec(),
|
||||
None => get(&node.input[1])?.to_vec1::<i64>()?,
|
||||
};
|
||||
let mut axes = axes
|
||||
.iter()
|
||||
.map(|&i| {
|
||||
if i == xs.rank() as i64 {
|
||||
Ok(xs.rank())
|
||||
} else {
|
||||
xs.normalize_axis(i)
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
axes.sort();
|
||||
let mut xs = xs.clone();
|
||||
for &axis in axes.iter().rev() {
|
||||
xs = xs.unsqueeze(axis)?
|
||||
}
|
||||
values.insert(node.output[0].clone(), xs);
|
||||
}
|
||||
"Clip" => {
|
||||
let xs = get(&node.input[0])?;
|
||||
let xs = if node.input.len() >= 2 {
|
||||
let mins = get(&node.input[1])?;
|
||||
xs.broadcast_maximum(mins)?
|
||||
} else {
|
||||
xs.clone()
|
||||
};
|
||||
let xs = if node.input.len() >= 3 {
|
||||
let maxs = get(&node.input[2])?;
|
||||
xs.broadcast_minimum(maxs)?
|
||||
} else {
|
||||
xs.clone()
|
||||
};
|
||||
values.insert(node.output[0].clone(), xs);
|
||||
}
|
||||
"Gather" => {
|
||||
let xs = get(&node.input[0])?;
|
||||
let indices = get(&node.input[1])?;
|
||||
let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
|
||||
let axis = xs.normalize_axis(axis)?;
|
||||
// TODO: Provide an op to handle the ONNX generalized gather op ideally in a
|
||||
// differentiable way.
|
||||
let xs = if indices.rank() == 0 {
|
||||
let index = indices.to_vec0::<i64>()? as usize;
|
||||
xs.narrow(axis, index, 1)?.squeeze(axis)?
|
||||
} else {
|
||||
todo!("implement gather for {xs:?} {indices:?} axis {axis}")
|
||||
};
|
||||
values.insert(node.output[0].clone(), xs);
|
||||
}
|
||||
"Shape" => {
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape
|
||||
let xs = get(&node.input[0])?;
|
||||
let start = get_attr_opt::<i64>(node, "start")?.copied().unwrap_or(0);
|
||||
let end = get_attr_opt::<i64>(node, "end")?.copied().unwrap_or(-1);
|
||||
let start = xs.normalize_axis(start)?;
|
||||
let end = xs.normalize_axis(end)?;
|
||||
let mut dims = vec![];
|
||||
for idx in start..=end {
|
||||
dims.push(xs.dim(idx)? as i64)
|
||||
}
|
||||
let dims = Tensor::from_vec(dims, xs.rank(), xs.device())?;
|
||||
values.insert(node.output[0].clone(), dims);
|
||||
}
|
||||
"Conv" => {
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
|
||||
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
|
||||
let groups = get_attr_opt::<i64>(node, "group")?.copied().unwrap_or(1);
|
||||
let _kernel_shape = get_attr_opt::<[i64]>(node, "kernel_shape")?;
|
||||
let pads = get_attr_opt::<[i64]>(node, "pads")?;
|
||||
let strides = get_attr_opt::<[i64]>(node, "strides")?;
|
||||
let auto_pad = get_attr_opt::<str>(node, "auto_pad")?;
|
||||
match auto_pad {
|
||||
None | Some("NOTSET") => (),
|
||||
Some(s) => bail!("unsupported auto_pad {s}"),
|
||||
};
|
||||
let xs = get(&node.input[0])?;
|
||||
let ws = get(&node.input[1])?;
|
||||
let ys = match ws.rank() {
|
||||
3 => {
|
||||
let (pads, xs) = match pads {
|
||||
None => (0, xs.clone()),
|
||||
Some([p]) => (*p as usize, xs.clone()),
|
||||
Some([p1, p2]) => {
|
||||
if p1 != p2 {
|
||||
(0usize, xs.pad_with_zeros(2, *p1 as usize, *p2 as usize)?)
|
||||
} else {
|
||||
(*p1 as usize, xs.clone())
|
||||
}
|
||||
}
|
||||
Some(pads) => {
|
||||
bail!("more pads than expected in conv1d {pads:?} {}", node.name)
|
||||
}
|
||||
};
|
||||
let strides = match strides {
|
||||
None => 1,
|
||||
Some([p]) => *p as usize,
|
||||
Some(s) => {
|
||||
bail!("more strides than expected in conv1d {s:?} {}", node.name)
|
||||
}
|
||||
};
|
||||
let dilations = match dilations {
|
||||
None => 1,
|
||||
Some([p]) => *p as usize,
|
||||
Some(s) => {
|
||||
bail!("more dilations than expected in conv1d {s:?} {}", node.name)
|
||||
}
|
||||
};
|
||||
xs.conv1d(ws, pads, strides, dilations, groups as usize)?
|
||||
}
|
||||
4 => {
|
||||
let (pads, xs) = match pads {
|
||||
None => (0, xs.clone()),
|
||||
Some([p]) => (*p as usize, xs.clone()),
|
||||
Some(&[p1, p2, p3, p4]) => {
|
||||
let p1 = p1 as usize;
|
||||
let p2 = p2 as usize;
|
||||
let p3 = p3 as usize;
|
||||
let p4 = p4 as usize;
|
||||
if p1 != p2 || p1 != p3 || p1 != p4 {
|
||||
(0, xs.pad_with_zeros(2, p1, p3)?.pad_with_zeros(3, p2, p4)?)
|
||||
} else {
|
||||
(p1, xs.clone())
|
||||
}
|
||||
}
|
||||
Some(pads) => {
|
||||
bail!("more pads than expected in conv2d {pads:?} {}", node.name)
|
||||
}
|
||||
};
|
||||
let strides = match strides {
|
||||
None => 1,
|
||||
Some([p]) => *p as usize,
|
||||
Some([p1, p2]) => {
|
||||
if p1 != p2 {
|
||||
bail!(
|
||||
"strides have to be the same on both axis {pads:?} {}",
|
||||
node.name
|
||||
)
|
||||
}
|
||||
*p1 as usize
|
||||
}
|
||||
Some(s) => {
|
||||
bail!("more strides than expected in conv2d {s:?} {}", node.name)
|
||||
}
|
||||
};
|
||||
let dilations = match dilations {
|
||||
None => 1,
|
||||
Some([p]) => *p as usize,
|
||||
Some([p1, p2]) => {
|
||||
if p1 != p2 {
|
||||
bail!(
|
||||
"dilations have to be the same on both axis {pads:?} {}",
|
||||
node.name
|
||||
)
|
||||
}
|
||||
*p1 as usize
|
||||
}
|
||||
Some(s) => {
|
||||
bail!("more dilations than expected in conv2d {s:?} {}", node.name)
|
||||
}
|
||||
};
|
||||
xs.conv2d(ws, pads, strides, dilations, groups as usize)?
|
||||
}
|
||||
rank => bail!(
|
||||
"unsupported rank for weight matrix {rank} in conv {}",
|
||||
node.name
|
||||
),
|
||||
};
|
||||
let ys = if node.input.len() > 2 {
|
||||
let bs = get(&node.input[2])?;
|
||||
let mut bs_shape = vec![1; ys.rank()];
|
||||
bs_shape[1] = bs.elem_count();
|
||||
ys.broadcast_add(&bs.reshape(bs_shape)?)?
|
||||
} else {
|
||||
ys
|
||||
};
|
||||
values.insert(node.output[0].clone(), ys);
|
||||
}
|
||||
"Concat" => {
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Concat
|
||||
let inputs = node
|
||||
.input
|
||||
.iter()
|
||||
.map(|n| Ok(get(n.as_str())?.clone()))
|
||||
.collect::<Result<Vec<Value>>>()?;
|
||||
let axis: i64 = *get_attr(node, "axis")?;
|
||||
if inputs.is_empty() {
|
||||
bail!("empty concat")
|
||||
};
|
||||
let axis = inputs[0].normalize_axis(axis)?;
|
||||
let output = Tensor::cat(&inputs, axis)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Abs" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let output = input.abs()?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Cos" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let output = input.cos()?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Sin" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let output = input.sin()?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Neg" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let output = input.neg()?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Erf" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let output = input.erf()?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Tanh" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let output = input.tanh()?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Sigmoid" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let output = candle_nn::ops::sigmoid(input)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Gelu" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let output = input.gelu_erf()?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Relu" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let output = input.relu()?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Constant
|
||||
"Constant" => {
|
||||
let value = match node.attribute.iter().find(|attr| attr.name == "value") {
|
||||
None => {
|
||||
// TODO: support sparse_value etc.
|
||||
bail!("cannot find 'value' attr in 'Constant' for {}", node.name)
|
||||
}
|
||||
Some(value) => value,
|
||||
};
|
||||
let output = match value.r#type() {
|
||||
AttributeType::Tensor => {
|
||||
let t = value.t.as_ref().unwrap();
|
||||
get_tensor(t, &node.name)?
|
||||
}
|
||||
rtype => bail!("unsupported 'value' type {rtype:?} for {}", node.name),
|
||||
};
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast
|
||||
"Cast" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let dt: i64 = *get_attr(node, "to")?;
|
||||
let dtype = match DataType::try_from(dt as i32) {
|
||||
Ok(DataType::Int32) => DType::I64,
|
||||
Ok(dt) => match dtype(dt) {
|
||||
Some(dt) => dt,
|
||||
None => {
|
||||
bail!("unsupported 'to' value {dt:?} for cast {}", node.name)
|
||||
}
|
||||
},
|
||||
Err(_) => {
|
||||
bail!("unsupported 'to' value {dt:?} for cast {}", node.name)
|
||||
}
|
||||
};
|
||||
let output = input.to_dtype(dtype)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
||||
}
|
||||
}
|
||||
graph
|
||||
.output
|
||||
.iter()
|
||||
.map(|output| match values.remove(&output.name) {
|
||||
None => bail!("cannot find output {}", output.name),
|
||||
Some(value) => Ok((output.name.clone(), value)),
|
||||
})
|
||||
.collect()
|
||||
}
|
14
candle-onnx/src/lib.rs
Normal file
14
candle-onnx/src/lib.rs
Normal file
@ -0,0 +1,14 @@
|
||||
use candle::Result;
|
||||
use prost::Message;
|
||||
|
||||
pub mod onnx {
|
||||
include!(concat!(env!("OUT_DIR"), "/onnx.rs"));
|
||||
}
|
||||
|
||||
pub mod eval;
|
||||
pub use eval::{dtype, simple_eval};
|
||||
|
||||
pub fn read_file<P: AsRef<std::path::Path>>(p: P) -> Result<onnx::ModelProto> {
|
||||
let buf = std::fs::read(p)?;
|
||||
onnx::ModelProto::decode(buf.as_slice()).map_err(candle::Error::wrap)
|
||||
}
|
836
candle-onnx/src/onnx.proto3
Normal file
836
candle-onnx/src/onnx.proto3
Normal file
@ -0,0 +1,836 @@
|
||||
//
|
||||
// WARNING: This file is automatically generated! Please edit onnx.in.proto.
|
||||
//
|
||||
|
||||
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
package onnx;
|
||||
|
||||
// Overview
|
||||
//
|
||||
// ONNX is an open specification that is comprised of the following components:
|
||||
//
|
||||
// 1) A definition of an extensible computation graph model.
|
||||
// 2) Definitions of standard data types.
|
||||
// 3) Definitions of built-in operators.
|
||||
//
|
||||
// This document describes the syntax of models and their computation graphs,
|
||||
// as well as the standard data types. Together, they are referred to as the ONNX
|
||||
// Intermediate Representation, or 'IR' for short.
|
||||
//
|
||||
// The normative semantic specification of the ONNX IR is found in docs/IR.md.
|
||||
// Definitions of the built-in neural network operators may be found in docs/Operators.md.
|
||||
|
||||
// Notes
|
||||
//
|
||||
// Protobuf compatibility
|
||||
//
|
||||
// To simplify framework compatibility, ONNX is defined using the subset of protobuf
|
||||
// that is compatible with both protobuf v2 and v3. This means that we do not use any
|
||||
// protobuf features that are only available in one of the two versions.
|
||||
//
|
||||
// Here are the most notable contortions we have to carry out to work around
|
||||
// these limitations:
|
||||
//
|
||||
// - No 'map' (added protobuf 3.0). We instead represent mappings as lists
|
||||
// of key-value pairs, where order does not matter and duplicates
|
||||
// are not allowed.
|
||||
|
||||
|
||||
// Versioning
|
||||
//
|
||||
// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md
|
||||
//
|
||||
// To be compatible with both proto2 and proto3, we will use a version number
|
||||
// that is not defined by the default value but an explicit enum number.
|
||||
enum Version {
|
||||
// proto3 requires the first enum value to be zero.
|
||||
// We add this just to appease the compiler.
|
||||
_START_VERSION = 0;
|
||||
// The version field is always serialized and we will use it to store the
|
||||
// version that the graph is generated from. This helps us set up version
|
||||
// control.
|
||||
// For the IR, we are using simple numbers starting with 0x00000001,
|
||||
// which was the version we published on Oct 10, 2017.
|
||||
IR_VERSION_2017_10_10 = 0x0000000000000001;
|
||||
|
||||
// IR_VERSION 2 published on Oct 30, 2017
|
||||
// - Added type discriminator to AttributeProto to support proto3 users
|
||||
IR_VERSION_2017_10_30 = 0x0000000000000002;
|
||||
|
||||
// IR VERSION 3 published on Nov 3, 2017
|
||||
// - For operator versioning:
|
||||
// - Added new message OperatorSetIdProto
|
||||
// - Added opset_import in ModelProto
|
||||
// - For vendor extensions, added domain in NodeProto
|
||||
IR_VERSION_2017_11_3 = 0x0000000000000003;
|
||||
|
||||
// IR VERSION 4 published on Jan 22, 2019
|
||||
// - Relax constraint that initializers should be a subset of graph inputs
|
||||
// - Add type BFLOAT16
|
||||
IR_VERSION_2019_1_22 = 0x0000000000000004;
|
||||
|
||||
// IR VERSION 5 published on March 18, 2019
|
||||
// - Add message TensorAnnotation.
|
||||
// - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters.
|
||||
IR_VERSION_2019_3_18 = 0x0000000000000005;
|
||||
|
||||
// IR VERSION 6 published on Sep 19, 2019
|
||||
// - Add support for sparse tensor constants stored in model.
|
||||
// - Add message SparseTensorProto
|
||||
// - Add sparse initializers
|
||||
IR_VERSION_2019_9_19 = 0x0000000000000006;
|
||||
|
||||
// IR VERSION 7 published on May 8, 2020
|
||||
// - Add support to allow function body graph to rely on multiple external operator sets.
|
||||
// - Add a list to promote inference graph's initializers to global and
|
||||
// mutable variables. Global variables are visible in all graphs of the
|
||||
// stored models.
|
||||
// - Add message TrainingInfoProto to store initialization
|
||||
// method and training algorithm. The execution of TrainingInfoProto
|
||||
// can modify the values of mutable variables.
|
||||
// - Implicitly add inference graph into each TrainingInfoProto's algorithm.
|
||||
IR_VERSION_2020_5_8 = 0x0000000000000007;
|
||||
|
||||
// IR VERSION 8 published on July 30, 2021
|
||||
// Introduce TypeProto.SparseTensor
|
||||
// Introduce TypeProto.Optional
|
||||
// Added a list of FunctionProtos local to the model
|
||||
// Deprecated since_version and operator status from FunctionProto
|
||||
IR_VERSION_2021_7_30 = 0x0000000000000008;
|
||||
|
||||
// IR VERSION 9 published on May 5, 2023
|
||||
// Added AttributeProto to FunctionProto so that default attribute values can be set.
|
||||
// Added FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ.
|
||||
IR_VERSION = 0x0000000000000009;
|
||||
}
|
||||
|
||||
// Attributes
|
||||
//
|
||||
// A named attribute containing either singular float, integer, string, graph,
|
||||
// and tensor values, or repeated float, integer, string, graph, and tensor values.
|
||||
// An AttributeProto MUST contain the name field, and *only one* of the
|
||||
// following content fields, effectively enforcing a C/C++ union equivalent.
|
||||
message AttributeProto {
|
||||
reserved 12, 16 to 19;
|
||||
reserved "v";
|
||||
|
||||
// Note: this enum is structurally identical to the OpSchema::AttrType
|
||||
// enum defined in schema.h. If you rev one, you likely need to rev the other.
|
||||
enum AttributeType {
|
||||
UNDEFINED = 0;
|
||||
FLOAT = 1;
|
||||
INT = 2;
|
||||
STRING = 3;
|
||||
TENSOR = 4;
|
||||
GRAPH = 5;
|
||||
SPARSE_TENSOR = 11;
|
||||
TYPE_PROTO = 13;
|
||||
|
||||
FLOATS = 6;
|
||||
INTS = 7;
|
||||
STRINGS = 8;
|
||||
TENSORS = 9;
|
||||
GRAPHS = 10;
|
||||
SPARSE_TENSORS = 12;
|
||||
TYPE_PROTOS = 14;
|
||||
}
|
||||
|
||||
// The name field MUST be present for this version of the IR.
|
||||
string name = 1; // namespace Attribute
|
||||
|
||||
// if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function.
|
||||
// In this case, this AttributeProto does not contain data, and it's a reference of attribute
|
||||
// in parent scope.
|
||||
// NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph.
|
||||
string ref_attr_name = 21;
|
||||
|
||||
// A human-readable documentation for this attribute. Markdown is allowed.
|
||||
string doc_string = 13;
|
||||
|
||||
// The type field MUST be present for this version of the IR.
|
||||
// For 0.0.1 versions of the IR, this field was not defined, and
|
||||
// implementations needed to use has_field heuristics to determine
|
||||
// which value field was in use. For IR_VERSION 0.0.2 or later, this
|
||||
// field MUST be set and match the f|i|s|t|... field in use. This
|
||||
// change was made to accommodate proto3 implementations.
|
||||
AttributeType type = 20; // discriminator that indicates which field below is in use
|
||||
|
||||
// Exactly ONE of the following fields must be present for this version of the IR
|
||||
float f = 2; // float
|
||||
int64 i = 3; // int
|
||||
bytes s = 4; // UTF-8 string
|
||||
TensorProto t = 5; // tensor value
|
||||
GraphProto g = 6; // graph
|
||||
SparseTensorProto sparse_tensor = 22; // sparse tensor value
|
||||
// Do not use field below, it's deprecated.
|
||||
// optional ValueProto v = 12; // value - subsumes everything but graph
|
||||
TypeProto tp = 14; // type proto
|
||||
|
||||
repeated float floats = 7; // list of floats
|
||||
repeated int64 ints = 8; // list of ints
|
||||
repeated bytes strings = 9; // list of UTF-8 strings
|
||||
repeated TensorProto tensors = 10; // list of tensors
|
||||
repeated GraphProto graphs = 11; // list of graph
|
||||
repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors
|
||||
repeated TypeProto type_protos = 15;// list of type protos
|
||||
}
|
||||
|
||||
// Defines information on value, including the name, the type, and
|
||||
// the shape of the value.
|
||||
message ValueInfoProto {
|
||||
// This field MUST be present in this version of the IR.
|
||||
string name = 1; // namespace Value
|
||||
// This field MUST be present in this version of the IR for
|
||||
// inputs and outputs of the top-level graph.
|
||||
TypeProto type = 2;
|
||||
// A human-readable documentation for this value. Markdown is allowed.
|
||||
string doc_string = 3;
|
||||
}
|
||||
|
||||
// Nodes
|
||||
//
|
||||
// Computation graphs are made up of a DAG of nodes, which represent what is
|
||||
// commonly called a "layer" or "pipeline stage" in machine learning frameworks.
|
||||
//
|
||||
// For example, it can be a node of type "Conv" that takes in an image, a filter
|
||||
// tensor and a bias tensor, and produces the convolved output.
|
||||
message NodeProto {
|
||||
repeated string input = 1; // namespace Value
|
||||
repeated string output = 2; // namespace Value
|
||||
|
||||
// An optional identifier for this node in a graph.
|
||||
// This field MAY be absent in this version of the IR.
|
||||
string name = 3; // namespace Node
|
||||
|
||||
// The symbolic identifier of the Operator to execute.
|
||||
string op_type = 4; // namespace Operator
|
||||
// The domain of the OperatorSet that specifies the operator named by op_type.
|
||||
string domain = 7; // namespace Domain
|
||||
|
||||
// Additional named attributes.
|
||||
repeated AttributeProto attribute = 5;
|
||||
|
||||
// A human-readable documentation for this node. Markdown is allowed.
|
||||
string doc_string = 6;
|
||||
}
|
||||
|
||||
// Training information
|
||||
// TrainingInfoProto stores information for training a model.
|
||||
// In particular, this defines two functionalities: an initialization-step
|
||||
// and a training-algorithm-step. Initialization resets the model
|
||||
// back to its original state as if no training has been performed.
|
||||
// Training algorithm improves the model based on input data.
|
||||
//
|
||||
// The semantics of the initialization-step is that the initializers
|
||||
// in ModelProto.graph and in TrainingInfoProto.algorithm are first
|
||||
// initialized as specified by the initializers in the graph, and then
|
||||
// updated by the "initialization_binding" in every instance in
|
||||
// ModelProto.training_info.
|
||||
//
|
||||
// The field "algorithm" defines a computation graph which represents a
|
||||
// training algorithm's step. After the execution of a
|
||||
// TrainingInfoProto.algorithm, the initializers specified by "update_binding"
|
||||
// may be immediately updated. If the targeted training algorithm contains
|
||||
// consecutive update steps (such as block coordinate descent methods),
|
||||
// the user needs to create a TrainingInfoProto for each step.
|
||||
message TrainingInfoProto {
|
||||
// This field describes a graph to compute the initial tensors
|
||||
// upon starting the training process. Initialization graph has no input
|
||||
// and can have multiple outputs. Usually, trainable tensors in neural
|
||||
// networks are randomly initialized. To achieve that, for each tensor,
|
||||
// the user can put a random number operator such as RandomNormal or
|
||||
// RandomUniform in TrainingInfoProto.initialization.node and assign its
|
||||
// random output to the specific tensor using "initialization_binding".
|
||||
// This graph can also set the initializers in "algorithm" in the same
|
||||
// TrainingInfoProto; a use case is resetting the number of training
|
||||
// iterations to zero.
|
||||
//
|
||||
// By default, this field is an empty graph and its evaluation does not
|
||||
// produce any output. Thus, no initializer would be changed by default.
|
||||
GraphProto initialization = 1;
|
||||
|
||||
// This field represents a training algorithm step. Given required inputs,
|
||||
// it computes outputs to update initializers in its own or inference graph's
|
||||
// initializer lists. In general, this field contains loss node, gradient node,
|
||||
// optimizer node, increment of iteration count.
|
||||
//
|
||||
// An execution of the training algorithm step is performed by executing the
|
||||
// graph obtained by combining the inference graph (namely "ModelProto.graph")
|
||||
// and the "algorithm" graph. That is, the actual
|
||||
// input/initializer/output/node/value_info/sparse_initializer list of
|
||||
// the training graph is the concatenation of
|
||||
// "ModelProto.graph.input/initializer/output/node/value_info/sparse_initializer"
|
||||
// and "algorithm.input/initializer/output/node/value_info/sparse_initializer"
|
||||
// in that order. This combined graph must satisfy the normal ONNX conditions.
|
||||
// Now, let's provide a visualization of graph combination for clarity.
|
||||
// Let the inference graph (i.e., "ModelProto.graph") be
|
||||
// tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d
|
||||
// and the "algorithm" graph be
|
||||
// tensor_d -> Add -> tensor_e
|
||||
// The combination process results in
|
||||
// tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d -> Add -> tensor_e
|
||||
//
|
||||
// Notice that an input of a node in the "algorithm" graph may reference the
|
||||
// output of a node in the inference graph (but not the other way round). Also, inference
|
||||
// node cannot reference inputs of "algorithm". With these restrictions, inference graph
|
||||
// can always be run independently without training information.
|
||||
//
|
||||
// By default, this field is an empty graph and its evaluation does not
|
||||
// produce any output. Evaluating the default training step never
|
||||
// updates any initializers.
|
||||
GraphProto algorithm = 2;
|
||||
|
||||
// This field specifies the bindings from the outputs of "initialization" to
|
||||
// some initializers in "ModelProto.graph.initializer" and
|
||||
// the "algorithm.initializer" in the same TrainingInfoProto.
|
||||
// See "update_binding" below for details.
|
||||
//
|
||||
// By default, this field is empty and no initializer would be changed
|
||||
// by the execution of "initialization".
|
||||
repeated StringStringEntryProto initialization_binding = 3;
|
||||
|
||||
// Gradient-based training is usually an iterative procedure. In one gradient
|
||||
// descent iteration, we apply
|
||||
//
|
||||
// x = x - r * g
|
||||
//
|
||||
// where "x" is the optimized tensor, "r" stands for learning rate, and "g" is
|
||||
// gradient of "x" with respect to a chosen loss. To avoid adding assignments
|
||||
// into the training graph, we split the update equation into
|
||||
//
|
||||
// y = x - r * g
|
||||
// x = y
|
||||
//
|
||||
// The user needs to save "y = x - r * g" into TrainingInfoProto.algorithm. To
|
||||
// tell that "y" should be assigned to "x", the field "update_binding" may
|
||||
// contain a key-value pair of strings, "x" (key of StringStringEntryProto)
|
||||
// and "y" (value of StringStringEntryProto).
|
||||
// For a neural network with multiple trainable (mutable) tensors, there can
|
||||
// be multiple key-value pairs in "update_binding".
|
||||
//
|
||||
// The initializers that appear as keys in "update_binding" are considered
|
||||
// mutable variables. This implies some behaviors
|
||||
// as described below.
|
||||
//
|
||||
// 1. We have only unique keys in all "update_binding"s so that two
|
||||
// variables may not have the same name. This ensures that one
|
||||
// variable is assigned up to once.
|
||||
// 2. The keys must appear in names of "ModelProto.graph.initializer" or
|
||||
// "TrainingInfoProto.algorithm.initializer".
|
||||
// 3. The values must be output names of "algorithm" or "ModelProto.graph.output".
|
||||
// 4. Mutable variables are initialized to the value specified by the
|
||||
// corresponding initializer, and then potentially updated by
|
||||
// "initialization_binding"s and "update_binding"s in "TrainingInfoProto"s.
|
||||
//
|
||||
// This field usually contains names of trainable tensors
|
||||
// (in ModelProto.graph), optimizer states such as momentums in advanced
|
||||
// stochastic gradient methods (in TrainingInfoProto.graph),
|
||||
// and number of training iterations (in TrainingInfoProto.graph).
|
||||
//
|
||||
// By default, this field is empty and no initializer would be changed
|
||||
// by the execution of "algorithm".
|
||||
repeated StringStringEntryProto update_binding = 4;
|
||||
}
|
||||
|
||||
// Models
|
||||
//
|
||||
// ModelProto is a top-level file/container format for bundling a ML model and
|
||||
// associating its computation graph with metadata.
|
||||
//
|
||||
// The semantics of the model are described by the associated GraphProto's.
|
||||
message ModelProto {
|
||||
// The version of the IR this model targets. See Version enum above.
|
||||
// This field MUST be present.
|
||||
int64 ir_version = 1;
|
||||
|
||||
// The OperatorSets this model relies on.
|
||||
// All ModelProtos MUST have at least one entry that
|
||||
// specifies which version of the ONNX OperatorSet is
|
||||
// being imported.
|
||||
//
|
||||
// All nodes in the ModelProto's graph will bind against the operator
|
||||
// with the same-domain/same-op_type operator with the HIGHEST version
|
||||
// in the referenced operator sets.
|
||||
repeated OperatorSetIdProto opset_import = 8;
|
||||
|
||||
// The name of the framework or tool used to generate this model.
|
||||
// This field SHOULD be present to indicate which implementation/tool/framework
|
||||
// emitted the model.
|
||||
string producer_name = 2;
|
||||
|
||||
// The version of the framework or tool used to generate this model.
|
||||
// This field SHOULD be present to indicate which implementation/tool/framework
|
||||
// emitted the model.
|
||||
string producer_version = 3;
|
||||
|
||||
// Domain name of the model.
|
||||
// We use reverse domain names as name space indicators. For example:
|
||||
// `com.facebook.fair` or `com.microsoft.cognitiveservices`
|
||||
//
|
||||
// Together with `model_version` and GraphProto.name, this forms the unique identity of
|
||||
// the graph.
|
||||
string domain = 4;
|
||||
|
||||
// The version of the graph encoded. See Version enum below.
|
||||
int64 model_version = 5;
|
||||
|
||||
// A human-readable documentation for this model. Markdown is allowed.
|
||||
string doc_string = 6;
|
||||
|
||||
// The parameterized graph that is evaluated to execute the model.
|
||||
GraphProto graph = 7;
|
||||
|
||||
// Named metadata values; keys should be distinct.
|
||||
repeated StringStringEntryProto metadata_props = 14;
|
||||
|
||||
// Training-specific information. Sequentially executing all stored
|
||||
// `TrainingInfoProto.algorithm`s and assigning their outputs following
|
||||
// the corresponding `TrainingInfoProto.update_binding`s is one training
|
||||
// iteration. Similarly, to initialize the model
|
||||
// (as if training hasn't happened), the user should sequentially execute
|
||||
// all stored `TrainingInfoProto.initialization`s and assign their outputs
|
||||
// using `TrainingInfoProto.initialization_binding`s.
|
||||
//
|
||||
// If this field is empty, the training behavior of the model is undefined.
|
||||
repeated TrainingInfoProto training_info = 20;
|
||||
|
||||
// A list of function protos local to the model.
|
||||
//
|
||||
// Name of the function "FunctionProto.name" should be unique within the domain "FunctionProto.domain".
|
||||
// In case of any conflicts the behavior (whether the model local functions are given higher priority,
|
||||
// or standard operator sets are given higher priority or this is treated as error) is defined by
|
||||
// the runtimes.
|
||||
//
|
||||
// The operator sets imported by FunctionProto should be compatible with the ones
|
||||
// imported by ModelProto and other model local FunctionProtos.
|
||||
// Example, if same operator set say 'A' is imported by a FunctionProto and ModelProto
|
||||
// or by 2 FunctionProtos then versions for the operator set may be different but,
|
||||
// the operator schema returned for op_type, domain, version combination
|
||||
// for both the versions should be same for every node in the function body.
|
||||
//
|
||||
// One FunctionProto can reference other FunctionProto in the model, however, recursive reference
|
||||
// is not allowed.
|
||||
repeated FunctionProto functions = 25;
|
||||
};
|
||||
|
||||
// StringStringEntryProto follows the pattern for cross-proto-version maps.
|
||||
// See https://developers.google.com/protocol-buffers/docs/proto3#maps
|
||||
message StringStringEntryProto {
|
||||
string key = 1;
|
||||
string value = 2;
|
||||
};
|
||||
|
||||
message TensorAnnotation {
|
||||
string tensor_name = 1;
|
||||
// <key, value> pairs to annotate tensor specified by <tensor_name> above.
|
||||
// The keys used in the mapping below must be pre-defined in ONNX spec.
|
||||
// For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as
|
||||
// quantization parameter keys.
|
||||
repeated StringStringEntryProto quant_parameter_tensor_names = 2;
|
||||
}
|
||||
|
||||
|
||||
|
||||
// Graphs
|
||||
//
|
||||
// A graph defines the computational logic of a model and is comprised of a parameterized
|
||||
// list of nodes that form a directed acyclic graph based on their inputs and outputs.
|
||||
// This is the equivalent of the "network" or "graph" in many deep learning
|
||||
// frameworks.
|
||||
message GraphProto {
|
||||
// The nodes in the graph, sorted topologically.
|
||||
repeated NodeProto node = 1;
|
||||
|
||||
// The name of the graph.
|
||||
string name = 2; // namespace Graph
|
||||
|
||||
// A list of named tensor values, used to specify constant inputs of the graph.
|
||||
// Each initializer (both TensorProto as well SparseTensorProto) MUST have a name.
|
||||
// The name MUST be unique across both initializer and sparse_initializer,
|
||||
// but the name MAY also appear in the input list.
|
||||
repeated TensorProto initializer = 5;
|
||||
|
||||
// Initializers (see above) stored in sparse format.
|
||||
repeated SparseTensorProto sparse_initializer = 15;
|
||||
|
||||
// A human-readable documentation for this graph. Markdown is allowed.
|
||||
string doc_string = 10;
|
||||
|
||||
// The inputs and outputs of the graph.
|
||||
repeated ValueInfoProto input = 11;
|
||||
repeated ValueInfoProto output = 12;
|
||||
|
||||
// Information for the values in the graph. The ValueInfoProto.name's
|
||||
// must be distinct. It is optional for a value to appear in value_info list.
|
||||
repeated ValueInfoProto value_info = 13;
|
||||
|
||||
// This field carries information to indicate the mapping among a tensor and its
|
||||
// quantization parameter tensors. For example:
|
||||
// For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated,
|
||||
// which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model.
|
||||
repeated TensorAnnotation quantization_annotation = 14;
|
||||
|
||||
reserved 3, 4, 6 to 9;
|
||||
reserved "ir_version", "producer_version", "producer_tag", "domain";
|
||||
}
|
||||
|
||||
// Tensors
|
||||
//
|
||||
// A serialized tensor value.
|
||||
message TensorProto {
|
||||
enum DataType {
|
||||
UNDEFINED = 0;
|
||||
// Basic types.
|
||||
FLOAT = 1; // float
|
||||
UINT8 = 2; // uint8_t
|
||||
INT8 = 3; // int8_t
|
||||
UINT16 = 4; // uint16_t
|
||||
INT16 = 5; // int16_t
|
||||
INT32 = 6; // int32_t
|
||||
INT64 = 7; // int64_t
|
||||
STRING = 8; // string
|
||||
BOOL = 9; // bool
|
||||
|
||||
// IEEE754 half-precision floating-point format (16 bits wide).
|
||||
// This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
|
||||
FLOAT16 = 10;
|
||||
|
||||
DOUBLE = 11;
|
||||
UINT32 = 12;
|
||||
UINT64 = 13;
|
||||
COMPLEX64 = 14; // complex with float32 real and imaginary components
|
||||
COMPLEX128 = 15; // complex with float64 real and imaginary components
|
||||
|
||||
// Non-IEEE floating-point format based on IEEE754 single-precision
|
||||
// floating-point number truncated to 16 bits.
|
||||
// This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
|
||||
BFLOAT16 = 16;
|
||||
|
||||
// Non-IEEE floating-point format based on papers
|
||||
// FP8 Formats for Deep Learning, https://arxiv.org/abs/2209.05433,
|
||||
// 8-bit Numerical Formats For Deep Neural Networks, https://arxiv.org/pdf/2206.02915.pdf.
|
||||
// Operators supporting FP8 are Cast, CastLike, QuantizeLinear, DequantizeLinear.
|
||||
// The computation usually happens inside a block quantize / dequantize
|
||||
// fused by the runtime.
|
||||
FLOAT8E4M3FN = 17; // float 8, mostly used for coefficients, supports nan, not inf
|
||||
FLOAT8E4M3FNUZ = 18; // float 8, mostly used for coefficients, supports nan, not inf, no negative zero
|
||||
FLOAT8E5M2 = 19; // follows IEEE 754, supports nan, inf, mostly used for gradients
|
||||
FLOAT8E5M2FNUZ = 20; // follows IEEE 754, supports nan, inf, mostly used for gradients, no negative zero
|
||||
|
||||
// Future extensions go here.
|
||||
}
|
||||
|
||||
// The shape of the tensor.
|
||||
repeated int64 dims = 1;
|
||||
|
||||
// The data type of the tensor.
|
||||
// This field MUST have a valid TensorProto.DataType value
|
||||
int32 data_type = 2;
|
||||
|
||||
// For very large tensors, we may want to store them in chunks, in which
|
||||
// case the following fields will specify the segment that is stored in
|
||||
// the current TensorProto.
|
||||
message Segment {
|
||||
int64 begin = 1;
|
||||
int64 end = 2;
|
||||
}
|
||||
Segment segment = 3;
|
||||
|
||||
// Tensor content must be organized in row-major order.
|
||||
//
|
||||
// Depending on the data_type field, exactly one of the fields below with
|
||||
// name ending in _data is used to store the elements of the tensor.
|
||||
|
||||
// For float and complex64 values
|
||||
// Complex64 tensors are encoded as a single array of floats,
|
||||
// with the real components appearing in odd numbered positions,
|
||||
// and the corresponding imaginary component appearing in the
|
||||
// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
|
||||
// is encoded as [1.0, 2.0, 3.0, 4.0])
|
||||
// When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
|
||||
repeated float float_data = 4 [packed = true];
|
||||
|
||||
// For int32, uint8, int8, uint16, int16, bool, float8, and float16 values
|
||||
// float16 and float8 values must be bit-wise converted to an uint16_t prior
|
||||
// to writing to the buffer.
|
||||
// When this field is present, the data_type field MUST be
|
||||
// INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ
|
||||
repeated int32 int32_data = 5 [packed = true];
|
||||
|
||||
// For strings.
|
||||
// Each element of string_data is a UTF-8 encoded Unicode
|
||||
// string. No trailing null, no leading BOM. The protobuf "string"
|
||||
// scalar type is not used to match ML community conventions.
|
||||
// When this field is present, the data_type field MUST be STRING
|
||||
repeated bytes string_data = 6;
|
||||
|
||||
// For int64.
|
||||
// When this field is present, the data_type field MUST be INT64
|
||||
repeated int64 int64_data = 7 [packed = true];
|
||||
|
||||
// Optionally, a name for the tensor.
|
||||
string name = 8; // namespace Value
|
||||
|
||||
// A human-readable documentation for this tensor. Markdown is allowed.
|
||||
string doc_string = 12;
|
||||
|
||||
// Serializations can either use one of the fields above, or use this
|
||||
// raw bytes field. The only exception is the string case, where one is
|
||||
// required to store the content in the repeated bytes string_data field.
|
||||
//
|
||||
// When this raw_data field is used to store tensor value, elements MUST
|
||||
// be stored in fixed-width, little-endian order.
|
||||
// Floating-point data types MUST be stored in IEEE 754 format.
|
||||
// Complex64 elements must be written as two consecutive FLOAT values, real component first.
|
||||
// Complex128 elements must be written as two consecutive DOUBLE values, real component first.
|
||||
// Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
|
||||
//
|
||||
// Note: the advantage of specific field rather than the raw_data field is
|
||||
// that in some cases (e.g. int data), protobuf does a better packing via
|
||||
// variable length storage, and may lead to smaller binary footprint.
|
||||
// When this field is present, the data_type field MUST NOT be STRING or UNDEFINED
|
||||
bytes raw_data = 9;
|
||||
|
||||
// Data can be stored inside the protobuf file using type-specific fields or raw_data.
|
||||
// Alternatively, raw bytes data can be stored in an external file, using the external_data field.
|
||||
// external_data stores key-value pairs describing data location. Recognized keys are:
|
||||
// - "location" (required) - POSIX filesystem path relative to the directory where the ONNX
|
||||
// protobuf model was stored
|
||||
// - "offset" (optional) - position of byte at which stored data begins. Integer stored as string.
|
||||
// Offset values SHOULD be multiples of 4096 (page size) to enable mmap support.
|
||||
// - "length" (optional) - number of bytes containing data. Integer stored as string.
|
||||
// - "checksum" (optional) - SHA1 digest of the file specified under the 'location' key.
|
||||
repeated StringStringEntryProto external_data = 13;
|
||||
|
||||
// Location of the data for this tensor. MUST be one of:
|
||||
// - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field.
|
||||
// - EXTERNAL - data stored in an external location as described by external_data field.
|
||||
enum DataLocation {
|
||||
DEFAULT = 0;
|
||||
EXTERNAL = 1;
|
||||
}
|
||||
|
||||
// If value not set, data is stored in raw_data (if set) otherwise in type-specified field.
|
||||
DataLocation data_location = 14;
|
||||
|
||||
// For double
|
||||
// Complex128 tensors are encoded as a single array of doubles,
|
||||
// with the real components appearing in odd numbered positions,
|
||||
// and the corresponding imaginary component appearing in the
|
||||
// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
|
||||
// is encoded as [1.0, 2.0, 3.0, 4.0])
|
||||
// When this field is present, the data_type field MUST be DOUBLE or COMPLEX128
|
||||
repeated double double_data = 10 [packed = true];
|
||||
|
||||
// For uint64 and uint32 values
|
||||
// When this field is present, the data_type field MUST be
|
||||
// UINT32 or UINT64
|
||||
repeated uint64 uint64_data = 11 [packed = true];
|
||||
}
|
||||
|
||||
// A serialized sparse-tensor value
|
||||
message SparseTensorProto {
|
||||
// The sequence of non-default values is encoded as a tensor of shape [NNZ].
|
||||
// The default-value is zero for numeric tensors, and empty-string for string tensors.
|
||||
// values must have a non-empty name present which serves as a name for SparseTensorProto
|
||||
// when used in sparse_initializer list.
|
||||
TensorProto values = 1;
|
||||
|
||||
// The indices of the non-default values, which may be stored in one of two formats.
|
||||
// (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value
|
||||
// corresponding to the j-th index of the i-th value (in the values tensor).
|
||||
// (b) Indices can be a tensor of shape [NNZ], in which case the i-th value
|
||||
// must be the linearized-index of the i-th value (in the values tensor).
|
||||
// The linearized-index can be converted into an index tuple (k_1,...,k_rank)
|
||||
// using the shape provided below.
|
||||
// The indices must appear in ascending order without duplication.
|
||||
// In the first format, the ordering is lexicographic-ordering:
|
||||
// e.g., index-value [1,4] must appear before [2,1]
|
||||
TensorProto indices = 2;
|
||||
|
||||
// The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank]
|
||||
repeated int64 dims = 3;
|
||||
}
|
||||
|
||||
// Defines a tensor shape. A dimension can be either an integer value
|
||||
// or a symbolic variable. A symbolic variable represents an unknown
|
||||
// dimension.
|
||||
message TensorShapeProto {
|
||||
message Dimension {
|
||||
oneof value {
|
||||
int64 dim_value = 1;
|
||||
string dim_param = 2; // namespace Shape
|
||||
};
|
||||
// Standard denotation can optionally be used to denote tensor
|
||||
// dimensions with standard semantic descriptions to ensure
|
||||
// that operations are applied to the correct axis of a tensor.
|
||||
// Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition
|
||||
// for pre-defined dimension denotations.
|
||||
string denotation = 3;
|
||||
};
|
||||
repeated Dimension dim = 1;
|
||||
}
|
||||
|
||||
// Types
|
||||
//
|
||||
// The standard ONNX data types.
|
||||
message TypeProto {
|
||||
|
||||
message Tensor {
|
||||
// This field MUST NOT have the value of UNDEFINED
|
||||
// This field MUST have a valid TensorProto.DataType value
|
||||
// This field MUST be present for this version of the IR.
|
||||
int32 elem_type = 1;
|
||||
TensorShapeProto shape = 2;
|
||||
}
|
||||
|
||||
// repeated T
|
||||
message Sequence {
|
||||
// The type and optional shape of each element of the sequence.
|
||||
// This field MUST be present for this version of the IR.
|
||||
TypeProto elem_type = 1;
|
||||
};
|
||||
|
||||
// map<K,V>
|
||||
message Map {
|
||||
// This field MUST have a valid TensorProto.DataType value
|
||||
// This field MUST be present for this version of the IR.
|
||||
// This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING
|
||||
int32 key_type = 1;
|
||||
// This field MUST be present for this version of the IR.
|
||||
TypeProto value_type = 2;
|
||||
};
|
||||
|
||||
// wrapper for Tensor, Sequence, or Map
|
||||
message Optional {
|
||||
// The type and optional shape of the element wrapped.
|
||||
// This field MUST be present for this version of the IR.
|
||||
// Possible values correspond to OptionalProto.DataType enum
|
||||
TypeProto elem_type = 1;
|
||||
};
|
||||
|
||||
|
||||
message SparseTensor {
|
||||
// This field MUST NOT have the value of UNDEFINED
|
||||
// This field MUST have a valid TensorProto.DataType value
|
||||
// This field MUST be present for this version of the IR.
|
||||
int32 elem_type = 1;
|
||||
TensorShapeProto shape = 2;
|
||||
}
|
||||
|
||||
|
||||
oneof value {
|
||||
// The type of a tensor.
|
||||
Tensor tensor_type = 1;
|
||||
|
||||
// NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values
|
||||
// as input and output to graphs and nodes. These types are needed to naturally
|
||||
// support classical ML operators. DNN operators SHOULD restrict their input
|
||||
// and output types to tensors.
|
||||
|
||||
// The type of a sequence.
|
||||
Sequence sequence_type = 4;
|
||||
|
||||
// The type of a map.
|
||||
Map map_type = 5;
|
||||
|
||||
// The type of an optional.
|
||||
Optional optional_type = 9;
|
||||
|
||||
|
||||
// Type of the sparse tensor
|
||||
SparseTensor sparse_tensor_type = 8;
|
||||
|
||||
}
|
||||
|
||||
// An optional denotation can be used to denote the whole
|
||||
// type with a standard semantic description as to what is
|
||||
// stored inside. Refer to https://github.com/onnx/onnx/blob/main/docs/TypeDenotation.md#type-denotation-definition
|
||||
// for pre-defined type denotations.
|
||||
string denotation = 6;
|
||||
}
|
||||
|
||||
// Operator Sets
|
||||
//
|
||||
// OperatorSets are uniquely identified by a (domain, opset_version) pair.
|
||||
message OperatorSetIdProto {
|
||||
// The domain of the operator set being identified.
|
||||
// The empty string ("") or absence of this field implies the operator
|
||||
// set that is defined as part of the ONNX specification.
|
||||
// This field MUST be present in this version of the IR when referring to any other operator set.
|
||||
string domain = 1;
|
||||
|
||||
// The version of the operator set being identified.
|
||||
// This field MUST be present in this version of the IR.
|
||||
int64 version = 2;
|
||||
}
|
||||
|
||||
// Operator/function status.
|
||||
enum OperatorStatus {
|
||||
EXPERIMENTAL = 0;
|
||||
STABLE = 1;
|
||||
}
|
||||
|
||||
message FunctionProto {
|
||||
// The name of the function, similar usage of op_type in OperatorProto.
|
||||
// Combined with FunctionProto.domain, this forms the unique identity of
|
||||
// the FunctionProto.
|
||||
string name = 1;
|
||||
|
||||
// Deprecated since IR Version 8
|
||||
// optional int64 since_version = 2;
|
||||
reserved 2;
|
||||
reserved "since_version";
|
||||
|
||||
// Deprecated since IR Version 8
|
||||
// optional OperatorStatus status = 3;
|
||||
reserved 3;
|
||||
reserved "status";
|
||||
|
||||
// The inputs and outputs of the function.
|
||||
repeated string input = 4;
|
||||
repeated string output = 5;
|
||||
|
||||
// The attribute parameters of the function.
|
||||
// It is for function parameters without default values.
|
||||
repeated string attribute = 6;
|
||||
|
||||
// The attribute protos of the function.
|
||||
// It is for function attributes with default values.
|
||||
// A function attribute shall be represented either as
|
||||
// a string attribute or an AttributeProto, not both.
|
||||
repeated AttributeProto attribute_proto = 11;
|
||||
|
||||
// The nodes in the function.
|
||||
repeated NodeProto node = 7;
|
||||
// A human-readable documentation for this function. Markdown is allowed.
|
||||
string doc_string = 8;
|
||||
|
||||
// The OperatorSets this function body (graph) relies on.
|
||||
//
|
||||
// All nodes in the function body (graph) will bind against the operator
|
||||
// with the same-domain/same-op_type operator with the HIGHEST version
|
||||
// in the referenced operator sets. This means at most one version can be relied
|
||||
// on for one domain.
|
||||
//
|
||||
// The operator sets imported by FunctionProto should be compatible with the ones
|
||||
// imported by ModelProto. Example, if same operator set say 'A' is imported by FunctionProto
|
||||
// and ModelProto then versions for the operator set may be different but,
|
||||
// the operator schema returned for op_type, domain, version combination
|
||||
// for both the versions should be same.
|
||||
|
||||
repeated OperatorSetIdProto opset_import = 9;
|
||||
|
||||
// The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of
|
||||
// the FunctionProto.
|
||||
string domain = 10;
|
||||
}
|
||||
|
||||
// For using protobuf-lite
|
||||
option optimize_for = LITE_RUNTIME;
|
||||
|
@ -17,6 +17,7 @@ crate-type = ["cdylib"]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.0" }
|
||||
candle-onnx = {path= "../candle-onnx", version = "0.3.0", optional = true}
|
||||
half = { workspace = true }
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] }
|
||||
@ -29,3 +30,5 @@ default = []
|
||||
accelerate = ["dep:accelerate-src", "candle/accelerate"]
|
||||
cuda = ["candle/cuda"]
|
||||
mkl = ["dep:intel-mkl-src","candle/mkl"]
|
||||
onnx = ["dep:candle-onnx"]
|
||||
|
||||
|
5
candle-pyo3/py_src/candle/onnx/__init__.py
Normal file
5
candle-pyo3/py_src/candle/onnx/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
# Generated content DO NOT EDIT
|
||||
from .. import onnx
|
||||
|
||||
ONNXModel = onnx.ONNXModel
|
||||
ONNXTensorDescription = onnx.ONNXTensorDescription
|
89
candle-pyo3/py_src/candle/onnx/__init__.pyi
Normal file
89
candle-pyo3/py_src/candle/onnx/__init__.pyi
Normal file
@ -0,0 +1,89 @@
|
||||
# Generated content DO NOT EDIT
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
|
||||
from os import PathLike
|
||||
from candle.typing import _ArrayLike, Device, Scalar, Index, Shape
|
||||
from candle import Tensor, DType, QTensor
|
||||
|
||||
class ONNXModel:
|
||||
"""
|
||||
A wrapper around an ONNX model.
|
||||
"""
|
||||
|
||||
def __init__(self, path: str):
|
||||
pass
|
||||
@property
|
||||
def doc_string(self) -> str:
|
||||
"""
|
||||
The doc string of the model.
|
||||
"""
|
||||
pass
|
||||
@property
|
||||
def domain(self) -> str:
|
||||
"""
|
||||
The domain of the operator set of the model.
|
||||
"""
|
||||
pass
|
||||
def initializers(self) -> Dict[str, Tensor]:
|
||||
"""
|
||||
Get the weights of the model.
|
||||
"""
|
||||
pass
|
||||
@property
|
||||
def inputs(self) -> Optional[Dict[str, ONNXTensorDescription]]:
|
||||
"""
|
||||
The inputs of the model.
|
||||
"""
|
||||
pass
|
||||
@property
|
||||
def ir_version(self) -> int:
|
||||
"""
|
||||
The version of the IR this model targets.
|
||||
"""
|
||||
pass
|
||||
@property
|
||||
def model_version(self) -> int:
|
||||
"""
|
||||
The version of the model.
|
||||
"""
|
||||
pass
|
||||
@property
|
||||
def outputs(self) -> Optional[Dict[str, ONNXTensorDescription]]:
|
||||
"""
|
||||
The outputs of the model.
|
||||
"""
|
||||
pass
|
||||
@property
|
||||
def producer_name(self) -> str:
|
||||
"""
|
||||
The producer of the model.
|
||||
"""
|
||||
pass
|
||||
@property
|
||||
def producer_version(self) -> str:
|
||||
"""
|
||||
The version of the producer of the model.
|
||||
"""
|
||||
pass
|
||||
def run(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
|
||||
"""
|
||||
Run the model on the given inputs.
|
||||
"""
|
||||
pass
|
||||
|
||||
class ONNXTensorDescription:
|
||||
"""
|
||||
A wrapper around an ONNX tensor description.
|
||||
"""
|
||||
|
||||
@property
|
||||
def dtype(self) -> DType:
|
||||
"""
|
||||
The data type of the tensor.
|
||||
"""
|
||||
pass
|
||||
@property
|
||||
def shape(self) -> Tuple[Union[int, str, Any]]:
|
||||
"""
|
||||
The shape of the tensor.
|
||||
"""
|
||||
pass
|
@ -19,12 +19,14 @@ extern crate accelerate_src;
|
||||
|
||||
use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};
|
||||
|
||||
mod utils;
|
||||
use utils::wrap_err;
|
||||
|
||||
mod shape;
|
||||
use shape::{PyShape, PyShapeWithHole};
|
||||
|
||||
pub fn wrap_err(err: ::candle::Error) -> PyErr {
|
||||
PyErr::new::<PyValueError, _>(format!("{err:?}"))
|
||||
}
|
||||
#[cfg(feature = "onnx")]
|
||||
mod onnx;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[pyclass(name = "Tensor")]
|
||||
@ -69,11 +71,13 @@ impl PyDType {
|
||||
}
|
||||
|
||||
static CUDA_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);
|
||||
static METAL_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
enum PyDevice {
|
||||
Cpu,
|
||||
Cuda,
|
||||
Metal,
|
||||
}
|
||||
|
||||
impl PyDevice {
|
||||
@ -81,7 +85,7 @@ impl PyDevice {
|
||||
match device {
|
||||
Device::Cpu => Self::Cpu,
|
||||
Device::Cuda(_) => Self::Cuda,
|
||||
Device::Metal(_) => unimplemented!(),
|
||||
Device::Metal(_) => Self::Metal,
|
||||
}
|
||||
}
|
||||
|
||||
@ -97,6 +101,15 @@ impl PyDevice {
|
||||
*device = Some(d.clone());
|
||||
Ok(d)
|
||||
}
|
||||
Self::Metal => {
|
||||
let mut device = METAL_DEVICE.lock().unwrap();
|
||||
if let Some(device) = device.as_ref() {
|
||||
return Ok(device.clone());
|
||||
};
|
||||
let d = Device::new_metal(0).map_err(wrap_err)?;
|
||||
*device = Some(d.clone());
|
||||
Ok(d)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -118,6 +131,7 @@ impl ToPyObject for PyDevice {
|
||||
let str = match self {
|
||||
PyDevice::Cpu => "cpu",
|
||||
PyDevice::Cuda => "cuda",
|
||||
PyDevice::Metal => "metal",
|
||||
};
|
||||
str.to_object(py)
|
||||
}
|
||||
@ -1560,6 +1574,14 @@ fn candle_functional_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "onnx")]
|
||||
fn candle_onnx_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
use onnx::{PyONNXModel, PyONNXTensorDescriptor};
|
||||
m.add_class::<PyONNXModel>()?;
|
||||
m.add_class::<PyONNXTensorDescriptor>()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[pymodule]
|
||||
fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
let utils = PyModule::new(py, "utils")?;
|
||||
@ -1568,6 +1590,12 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
let nn = PyModule::new(py, "functional")?;
|
||||
candle_functional_m(py, nn)?;
|
||||
m.add_submodule(nn)?;
|
||||
#[cfg(feature = "onnx")]
|
||||
{
|
||||
let onnx = PyModule::new(py, "onnx")?;
|
||||
candle_onnx_m(py, onnx)?;
|
||||
m.add_submodule(onnx)?;
|
||||
}
|
||||
m.add_class::<PyTensor>()?;
|
||||
m.add_class::<PyQTensor>()?;
|
||||
m.add_class::<PyDType>()?;
|
||||
|
212
candle-pyo3/src/onnx.rs
Normal file
212
candle-pyo3/src/onnx.rs
Normal file
@ -0,0 +1,212 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::utils::wrap_err;
|
||||
use crate::{PyDType, PyTensor};
|
||||
use candle_onnx::eval::{dtype, get_tensor, simple_eval};
|
||||
use candle_onnx::onnx::tensor_proto::DataType;
|
||||
use candle_onnx::onnx::tensor_shape_proto::dimension::Value;
|
||||
use candle_onnx::onnx::type_proto::{Tensor as ONNXTensor, Value as ONNXValue};
|
||||
use candle_onnx::onnx::{ModelProto, ValueInfoProto};
|
||||
use pyo3::exceptions::PyValueError;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::{PyList, PyTuple};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[pyclass(name = "ONNXTensorDescription")]
|
||||
/// A wrapper around an ONNX tensor description.
|
||||
pub struct PyONNXTensorDescriptor(ONNXTensor);
|
||||
|
||||
#[pymethods]
|
||||
impl PyONNXTensorDescriptor {
|
||||
#[getter]
|
||||
/// The data type of the tensor.
|
||||
/// &RETURNS&: DType
|
||||
fn dtype(&self) -> PyResult<PyDType> {
|
||||
match DataType::try_from(self.0.elem_type) {
|
||||
Ok(dt) => match dtype(dt) {
|
||||
Some(dt) => Ok(PyDType(dt)),
|
||||
None => Err(PyValueError::new_err(format!(
|
||||
"unsupported 'value' data-type {dt:?}"
|
||||
))),
|
||||
},
|
||||
type_ => Err(PyValueError::new_err(format!(
|
||||
"unsupported input type {type_:?}"
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
#[getter]
|
||||
/// The shape of the tensor.
|
||||
/// &RETURNS&: Tuple[Union[int,str,Any]]
|
||||
fn shape(&self, py: Python) -> PyResult<Py<PyTuple>> {
|
||||
let shape = PyList::empty(py);
|
||||
if let Some(d) = &self.0.shape {
|
||||
for dim in d.dim.iter() {
|
||||
if let Some(value) = &dim.value {
|
||||
match value {
|
||||
Value::DimValue(v) => shape.append(*v)?,
|
||||
Value::DimParam(s) => shape.append(s.clone())?,
|
||||
};
|
||||
} else {
|
||||
return Err(PyValueError::new_err("None value in shape"));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(shape.to_tuple().into())
|
||||
}
|
||||
|
||||
fn __repr__(&self, py: Python) -> String {
|
||||
match (self.shape(py), self.dtype()) {
|
||||
(Ok(shape), Ok(dtype)) => format!(
|
||||
"TensorDescriptor[shape: {:?}, dtype: {:?}]",
|
||||
shape.to_string(),
|
||||
dtype.__str__()
|
||||
),
|
||||
(Err(_), Err(_)) => "TensorDescriptor[shape: unknown, dtype: unknown]".to_string(),
|
||||
(Err(_), Ok(dtype)) => format!(
|
||||
"TensorDescriptor[shape: unknown, dtype: {:?}]",
|
||||
dtype.__str__()
|
||||
),
|
||||
(Ok(shape), Err(_)) => format!(
|
||||
"TensorDescriptor[shape: {:?}, dtype: unknown]",
|
||||
shape.to_string()
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
fn __str__(&self, py: Python) -> String {
|
||||
self.__repr__(py)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[pyclass(name = "ONNXModel")]
|
||||
/// A wrapper around an ONNX model.
|
||||
pub struct PyONNXModel(ModelProto);
|
||||
|
||||
fn extract_tensor_descriptions(
|
||||
value_infos: &[ValueInfoProto],
|
||||
) -> HashMap<String, PyONNXTensorDescriptor> {
|
||||
let mut map = HashMap::new();
|
||||
for value_info in value_infos.iter() {
|
||||
let input_type = match &value_info.r#type {
|
||||
Some(input_type) => input_type,
|
||||
None => continue,
|
||||
};
|
||||
let input_type = match &input_type.value {
|
||||
Some(input_type) => input_type,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let tensor_type: &ONNXTensor = match input_type {
|
||||
ONNXValue::TensorType(tt) => tt,
|
||||
_ => continue,
|
||||
};
|
||||
map.insert(
|
||||
value_info.name.to_string(),
|
||||
PyONNXTensorDescriptor(tensor_type.clone()),
|
||||
);
|
||||
}
|
||||
map
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl PyONNXModel {
|
||||
#[new]
|
||||
#[pyo3(text_signature = "(self, path:str)")]
|
||||
/// Load an ONNX model from the given path.
|
||||
fn new(path: String) -> PyResult<Self> {
|
||||
let model: ModelProto = candle_onnx::read_file(path).map_err(wrap_err)?;
|
||||
Ok(PyONNXModel(model))
|
||||
}
|
||||
|
||||
#[getter]
|
||||
/// The version of the IR this model targets.
|
||||
/// &RETURNS&: int
|
||||
fn ir_version(&self) -> i64 {
|
||||
self.0.ir_version
|
||||
}
|
||||
|
||||
#[getter]
|
||||
/// The producer of the model.
|
||||
/// &RETURNS&: str
|
||||
fn producer_name(&self) -> String {
|
||||
self.0.producer_name.clone()
|
||||
}
|
||||
|
||||
#[getter]
|
||||
/// The version of the producer of the model.
|
||||
/// &RETURNS&: str
|
||||
fn producer_version(&self) -> String {
|
||||
self.0.producer_version.clone()
|
||||
}
|
||||
|
||||
#[getter]
|
||||
/// The domain of the operator set of the model.
|
||||
/// &RETURNS&: str
|
||||
fn domain(&self) -> String {
|
||||
self.0.domain.clone()
|
||||
}
|
||||
|
||||
#[getter]
|
||||
/// The version of the model.
|
||||
/// &RETURNS&: int
|
||||
fn model_version(&self) -> i64 {
|
||||
self.0.model_version
|
||||
}
|
||||
|
||||
#[getter]
|
||||
/// The doc string of the model.
|
||||
/// &RETURNS&: str
|
||||
fn doc_string(&self) -> String {
|
||||
self.0.doc_string.clone()
|
||||
}
|
||||
|
||||
/// Get the weights of the model.
|
||||
/// &RETURNS&: Dict[str, Tensor]
|
||||
fn initializers(&self) -> PyResult<HashMap<String, PyTensor>> {
|
||||
let mut map = HashMap::new();
|
||||
if let Some(graph) = self.0.graph.as_ref() {
|
||||
for tensor_description in graph.initializer.iter() {
|
||||
let tensor = get_tensor(tensor_description, tensor_description.name.as_str())
|
||||
.map_err(wrap_err)?;
|
||||
map.insert(tensor_description.name.to_string(), PyTensor(tensor));
|
||||
}
|
||||
}
|
||||
Ok(map)
|
||||
}
|
||||
|
||||
#[getter]
|
||||
/// The inputs of the model.
|
||||
/// &RETURNS&: Optional[Dict[str, ONNXTensorDescription]]
|
||||
fn inputs(&self) -> Option<HashMap<String, PyONNXTensorDescriptor>> {
|
||||
if let Some(graph) = self.0.graph.as_ref() {
|
||||
return Some(extract_tensor_descriptions(&graph.input));
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
#[getter]
|
||||
/// The outputs of the model.
|
||||
/// &RETURNS&: Optional[Dict[str, ONNXTensorDescription]]
|
||||
fn outputs(&self) -> Option<HashMap<String, PyONNXTensorDescriptor>> {
|
||||
if let Some(graph) = self.0.graph.as_ref() {
|
||||
return Some(extract_tensor_descriptions(&graph.output));
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
#[pyo3(text_signature = "(self, inputs:Dict[str,Tensor])")]
|
||||
/// Run the model on the given inputs.
|
||||
/// &RETURNS&: Dict[str,Tensor]
|
||||
fn run(&self, inputs: HashMap<String, PyTensor>) -> PyResult<HashMap<String, PyTensor>> {
|
||||
let unwrapped_tensors = inputs.into_iter().map(|(k, v)| (k.clone(), v.0)).collect();
|
||||
|
||||
let result = simple_eval(&self.0, unwrapped_tensors).map_err(wrap_err)?;
|
||||
|
||||
Ok(result
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k.clone(), PyTensor(v)))
|
||||
.collect())
|
||||
}
|
||||
}
|
6
candle-pyo3/src/utils.rs
Normal file
6
candle-pyo3/src/utils.rs
Normal file
@ -0,0 +1,6 @@
|
||||
use pyo3::exceptions::PyValueError;
|
||||
use pyo3::prelude::*;
|
||||
|
||||
pub fn wrap_err(err: ::candle::Error) -> PyErr {
|
||||
PyErr::new::<PyValueError, _>(format!("{err:?}"))
|
||||
}
|
@ -28,6 +28,5 @@ wav = { workspace = true }
|
||||
default = []
|
||||
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"]
|
||||
cuda = ["candle/cuda", "candle-nn/cuda"]
|
||||
metal = ["candle/metal", "candle-nn/metal"]
|
||||
flash-attn = ["cuda", "dep:candle-flash-attn"]
|
||||
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"]
|
||||
|
@ -1,3 +1,4 @@
|
||||
use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
|
||||
use candle::{DType, Device, Result, Tensor};
|
||||
use candle_nn::{Embedding, Module, VarBuilder};
|
||||
use serde::Deserialize;
|
||||
@ -32,76 +33,6 @@ impl HiddenActLayer {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Linear {
|
||||
weight: Tensor,
|
||||
bias: Option<Tensor>,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Linear {
|
||||
pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "linear");
|
||||
Self { weight, bias, span }
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Linear {
|
||||
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let w = match x.dims() {
|
||||
&[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
|
||||
_ => self.weight.t()?,
|
||||
};
|
||||
let x = x.matmul(&w)?;
|
||||
match &self.bias {
|
||||
None => Ok(x),
|
||||
Some(bias) => x.broadcast_add(bias),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct LayerNorm {
|
||||
weight: Tensor,
|
||||
bias: Tensor,
|
||||
eps: f64,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl LayerNorm {
|
||||
pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
|
||||
Self {
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
span,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for LayerNorm {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let x_dtype = x.dtype();
|
||||
let internal_dtype = match x_dtype {
|
||||
DType::F16 | DType::BF16 => DType::F32,
|
||||
d => d,
|
||||
};
|
||||
let (_bsize, _seq_len, hidden_size) = x.dims3()?;
|
||||
let x = x.to_dtype(internal_dtype)?;
|
||||
let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
|
||||
let x = x.broadcast_sub(&mean_x)?;
|
||||
let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
|
||||
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
|
||||
let x = x_normed
|
||||
.to_dtype(x_dtype)?
|
||||
.broadcast_mul(&self.weight)?
|
||||
.broadcast_add(&self.bias)?;
|
||||
Ok(x)
|
||||
}
|
||||
}
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
enum PositionEmbeddingType {
|
||||
@ -184,12 +115,6 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Em
|
||||
Ok(Embedding::new(embeddings, hidden_size))
|
||||
}
|
||||
|
||||
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
|
||||
let weight = vb.get((size2, size1), "weight")?;
|
||||
let bias = vb.get(size2, "bias")?;
|
||||
Ok(Linear::new(weight, Some(bias)))
|
||||
}
|
||||
|
||||
struct Dropout {
|
||||
#[allow(dead_code)]
|
||||
pr: f64,
|
||||
@ -208,20 +133,6 @@ impl Module for Dropout {
|
||||
}
|
||||
}
|
||||
|
||||
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
|
||||
let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
|
||||
(Ok(weight), Ok(bias)) => (weight, bias),
|
||||
(Err(err), _) | (_, Err(err)) => {
|
||||
if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
|
||||
(weight, bias)
|
||||
} else {
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
};
|
||||
Ok(LayerNorm::new(weight, bias, eps))
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180
|
||||
struct BertEmbeddings {
|
||||
word_embeddings: Embedding,
|
||||
|
@ -1,3 +1,4 @@
|
||||
use super::with_tracing::{linear_no_bias as linear, Linear};
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{Embedding, Module, VarBuilder};
|
||||
use serde::Deserialize;
|
||||
@ -81,21 +82,6 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
|
||||
// model.
|
||||
#[derive(Debug)]
|
||||
pub struct Linear {
|
||||
inner: candle_nn::Linear,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Linear {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
self.inner.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Cache {
|
||||
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
|
||||
@ -150,12 +136,6 @@ impl Cache {
|
||||
}
|
||||
}
|
||||
|
||||
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "linear");
|
||||
let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
|
||||
Ok(Linear { inner, span })
|
||||
}
|
||||
|
||||
fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
|
||||
let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
|
||||
Ok(Embedding::new(embeddings, cfg.hidden_size))
|
||||
|
@ -156,6 +156,7 @@ impl CausalSelfAttention {
|
||||
let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
|
||||
let x0 = x.narrow(D::Minus1, 0, 1)?;
|
||||
let x1 = x.narrow(D::Minus1, 1, 1)?;
|
||||
todo!("X {x1}");
|
||||
let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
|
||||
let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
|
||||
let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, h, n_embd))?;
|
||||
@ -173,6 +174,7 @@ impl CausalSelfAttention {
|
||||
let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
|
||||
|
||||
let q = self.apply_rotary_emb(&q, index_pos)?;
|
||||
todo!("X {q}");
|
||||
let mut k = self.apply_rotary_emb(&k, index_pos)?;
|
||||
|
||||
if self.cache.use_kv_cache {
|
||||
@ -295,6 +297,7 @@ impl Block {
|
||||
let residual = x;
|
||||
let x = self.rms_1.forward(x)?;
|
||||
let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
|
||||
todo!("---X {}", x);
|
||||
let residual = &x;
|
||||
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
|
||||
Ok(x)
|
||||
@ -327,6 +330,7 @@ impl Llama {
|
||||
pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||
let (_b_sz, _seq_len) = x.dims2()?;
|
||||
let mut x = self.wte.forward(x)?;
|
||||
//println!("Embeddings {}", self.wte.embeddings());
|
||||
for (block_idx, block) in self.blocks.iter().enumerate() {
|
||||
x = block.forward(&x, index_pos, block_idx)?;
|
||||
}
|
||||
|
@ -1,6 +1,5 @@
|
||||
#![allow(unused)]
|
||||
use super::with_tracing::{linear, linear_no_bias, Embedding, Linear};
|
||||
use candle::{Module, Result, Tensor};
|
||||
use super::with_tracing::{linear, Embedding, Linear};
|
||||
use candle::{Result, Tensor};
|
||||
use candle_nn::{layer_norm, LayerNorm, VarBuilder};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -170,7 +169,6 @@ impl Attention {
|
||||
kv_states: Option<&Tensor>,
|
||||
attn_mask: Option<&Tensor>,
|
||||
) -> Result<Tensor> {
|
||||
let is_cross_attn = kv_states.is_some();
|
||||
let (b_sz, tgt_len, _) = xs.dims3()?;
|
||||
let query_states = (xs.apply(&self.q_proj)? * self.scaling)?;
|
||||
let (key_states, value_states) = match kv_states {
|
||||
@ -259,6 +257,10 @@ impl EncoderLayer {
|
||||
.apply(&self.fc2)?;
|
||||
(xs + residual)?.apply(&self.final_layer_norm)
|
||||
}
|
||||
|
||||
fn reset_kv_cache(&mut self) {
|
||||
self.self_attn.reset_kv_cache()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -320,6 +322,11 @@ impl DecoderLayer {
|
||||
let xs = (xs + residual)?.apply(&self.final_layer_norm)?;
|
||||
Ok(xs)
|
||||
}
|
||||
|
||||
fn reset_kv_cache(&mut self) {
|
||||
self.self_attn.reset_kv_cache();
|
||||
self.encoder_attn.reset_kv_cache()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -368,6 +375,12 @@ impl Encoder {
|
||||
}
|
||||
Ok(xs)
|
||||
}
|
||||
|
||||
pub fn reset_kv_cache(&mut self) {
|
||||
for layer in self.layers.iter_mut() {
|
||||
layer.reset_kv_cache()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -422,6 +435,12 @@ impl Decoder {
|
||||
}
|
||||
Ok(xs)
|
||||
}
|
||||
|
||||
pub fn reset_kv_cache(&mut self) {
|
||||
for layer in self.layers.iter_mut() {
|
||||
layer.reset_kv_cache()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -442,6 +461,11 @@ impl Model {
|
||||
decoder,
|
||||
})
|
||||
}
|
||||
|
||||
fn reset_kv_cache(&mut self) {
|
||||
self.encoder.reset_kv_cache();
|
||||
self.decoder.reset_kv_cache();
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -489,4 +513,8 @@ impl MTModel {
|
||||
.apply(&self.lm_head)?
|
||||
.broadcast_add(&self.final_logits_bias)
|
||||
}
|
||||
|
||||
pub fn reset_kv_cache(&mut self) {
|
||||
self.model.reset_kv_cache();
|
||||
}
|
||||
}
|
||||
|
@ -2,7 +2,7 @@ use std::collections::HashMap;
|
||||
|
||||
use candle::quantized::QTensor;
|
||||
use candle::quantized::{ggml_file, gguf_file};
|
||||
use candle::{Device, IndexOp, Result, Tensor, D};
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{Embedding, Module};
|
||||
|
||||
pub const MAX_SEQ_LEN: usize = 4096;
|
||||
@ -16,7 +16,7 @@ struct RmsNorm {
|
||||
impl RmsNorm {
|
||||
fn new(scale: QTensor, eps: f32) -> Result<Self> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
|
||||
let scale = scale.dequantize(scale.device())?;
|
||||
let scale = scale.dequantize(&Device::Cpu)?;
|
||||
let inner = candle_nn::LayerNorm::rms_norm(scale, eps as f64);
|
||||
Ok(Self { inner, span })
|
||||
}
|
||||
@ -79,8 +79,6 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
|
||||
impl LayerWeights {
|
||||
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||
let _enter = self.span_rot.enter();
|
||||
let span = tracing::span!(tracing::Level::TRACE, "attn-rot-cos");
|
||||
let _enter = span.enter();
|
||||
let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
|
||||
let cos = self
|
||||
.cos
|
||||
@ -90,37 +88,21 @@ impl LayerWeights {
|
||||
.sin
|
||||
.narrow(0, index_pos, seq_len)?
|
||||
.reshape((seq_len, n_embd / 2, 1))?;
|
||||
drop(_enter);
|
||||
let span = tracing::span!(tracing::Level::TRACE, "attn-rot-broad");
|
||||
let _enter = span.enter();
|
||||
let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
|
||||
let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
|
||||
drop(_enter);
|
||||
// This mimics the llama.cpp behavior.
|
||||
// https://github.com/ggerganov/llama.cpp/blob/1f0bccb27929e261744c979bc75114955da49e98/ggml.c#L12104-L12105
|
||||
// The x0 and x1 value are interleaved on the n_embd (= head_dim) dimension.
|
||||
// The resulting y0 and y1 are also interleaved with:
|
||||
// y0 = x0*cos - x1*sin
|
||||
// y1 = x0*sin + x1*cos
|
||||
let span = tracing::span!(tracing::Level::TRACE, "attn-rot-reshape");
|
||||
let _enter = span.enter();
|
||||
let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
|
||||
let x0 = x.narrow(D::Minus1, 0, 1)?;
|
||||
let x1 = x.narrow(D::Minus1, 1, 1)?;
|
||||
drop(_enter);
|
||||
let span = tracing::span!(tracing::Level::TRACE, "attn-rot-broad-mul");
|
||||
let _enter = span.enter();
|
||||
let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
|
||||
let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
|
||||
drop(_enter);
|
||||
let span = tracing::span!(tracing::Level::TRACE, "attn-rot-cat");
|
||||
let _enter = span.enter();
|
||||
let rope = Tensor::cat(&[y0, y1], D::Minus1)?;
|
||||
drop(_enter);
|
||||
let span = tracing::span!(tracing::Level::TRACE, "attn-rot-flatten");
|
||||
let _enter = span.enter();
|
||||
let rope = rope.flatten_from(D::Minus2)?;
|
||||
drop(_enter);
|
||||
Ok(rope)
|
||||
}
|
||||
|
||||
@ -130,7 +112,6 @@ impl LayerWeights {
|
||||
let q = self.attention_wq.forward(x)?;
|
||||
let k = self.attention_wk.forward(x)?;
|
||||
let v = self.attention_wv.forward(x)?;
|
||||
// println!("Q {:?} K {:?} V {:?}", q.dtype(), k.dtype(), v.dtype());
|
||||
|
||||
let q = q
|
||||
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
|
||||
@ -164,12 +145,9 @@ impl LayerWeights {
|
||||
let v = self.repeat_kv(v)?;
|
||||
|
||||
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
|
||||
// println!("att {:?}", att.dtype());
|
||||
let mask = mask.broadcast_as(att.shape())?;
|
||||
// println!("mask {:?}", mask.dtype());
|
||||
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
|
||||
let att = candle_nn::ops::softmax_last_dim(&att)?;
|
||||
// println!("att {:?} v {:?}", att.dtype(), v.dtype());
|
||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||
let y = att.matmul(&v.contiguous()?)?;
|
||||
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
|
||||
@ -203,37 +181,28 @@ pub struct ModelWeights {
|
||||
span_output: tracing::Span,
|
||||
}
|
||||
|
||||
fn precomput_freqs_cis(
|
||||
head_dim: usize,
|
||||
freq_base: f32,
|
||||
device: &Device,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tensor)> {
|
||||
let theta: Vec<_> = (0..head_dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
|
||||
.collect();
|
||||
let theta = Tensor::new(theta.as_slice(), device)?;
|
||||
let range: Vec<f32> = (0..MAX_SEQ_LEN).map(|r| r as f32).collect();
|
||||
let idx_theta = Tensor::new(range.as_slice(), device)?
|
||||
let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?;
|
||||
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)?
|
||||
.to_dtype(DType::F32)?
|
||||
.reshape((MAX_SEQ_LEN, 1))?
|
||||
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||
// TODO This change avoids allocating on Metal and then casting since allocating directly on
|
||||
// CPU as f32 seems just as fast
|
||||
// let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
|
||||
// .to_dtype(DType::F32)?
|
||||
// .reshape((MAX_SEQ_LEN, 1))?
|
||||
// .matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||
let cos = idx_theta.cos()?;
|
||||
let sin = idx_theta.sin()?;
|
||||
Ok((cos, sin))
|
||||
}
|
||||
|
||||
impl ModelWeights {
|
||||
pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize, device: &Device) -> Result<Self> {
|
||||
pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
|
||||
let cpu = &Device::Cpu;
|
||||
let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
|
||||
let (cos, sin) = precomput_freqs_cis(head_dim, 10000., device)?;
|
||||
let (cos, sin) = precomput_freqs_cis(head_dim, 10000.)?;
|
||||
let tok_embeddings = ct.remove("tok_embeddings.weight")?;
|
||||
let tok_embeddings = tok_embeddings.dequantize(device)?;
|
||||
let tok_embeddings = tok_embeddings.dequantize(cpu)?;
|
||||
let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
|
||||
let output = ct.remove("output.weight")?;
|
||||
let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
|
||||
@ -288,8 +257,8 @@ impl ModelWeights {
|
||||
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
|
||||
ct: gguf_file::Content,
|
||||
reader: &mut R,
|
||||
device: &Device,
|
||||
) -> Result<Self> {
|
||||
let cpu = &Device::Cpu;
|
||||
let md_get = |s: &str| match ct.metadata.get(s) {
|
||||
None => candle::bail!("cannot find {s} in metadata"),
|
||||
Some(v) => Ok(v),
|
||||
@ -307,31 +276,24 @@ impl ModelWeights {
|
||||
let rope_freq_base = md_get("llama.rope.freq_base")
|
||||
.and_then(|m| m.to_f32())
|
||||
.unwrap_or(10000f32);
|
||||
let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?;
|
||||
let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?;
|
||||
|
||||
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
|
||||
let tok_embeddings = tok_embeddings.dequantize(device)?;
|
||||
let norm = RmsNorm::new(
|
||||
ct.tensor(reader, "output_norm.weight", device)?,
|
||||
rms_norm_eps,
|
||||
)?;
|
||||
let output = ct.tensor(reader, "output.weight", device)?;
|
||||
let tok_embeddings = ct.tensor(reader, "token_embd.weight")?;
|
||||
let tok_embeddings = tok_embeddings.dequantize(cpu)?;
|
||||
let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?, rms_norm_eps)?;
|
||||
let output = ct.tensor(reader, "output.weight")?;
|
||||
let mut layers = Vec::with_capacity(block_count);
|
||||
for layer_idx in 0..block_count {
|
||||
let prefix = format!("blk.{layer_idx}");
|
||||
let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
|
||||
let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
|
||||
let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
|
||||
let attention_wo =
|
||||
ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
|
||||
let feed_forward_w1 =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
|
||||
let feed_forward_w2 =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
|
||||
let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
|
||||
let attention_norm =
|
||||
ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?;
|
||||
let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?;
|
||||
let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"))?;
|
||||
let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?;
|
||||
let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?;
|
||||
let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?;
|
||||
let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?;
|
||||
let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?;
|
||||
let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?;
|
||||
let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?;
|
||||
let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?;
|
||||
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
|
||||
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
|
||||
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
|
||||
@ -369,14 +331,14 @@ impl ModelWeights {
|
||||
})
|
||||
}
|
||||
|
||||
fn mask(&mut self, t: usize, device: &Device) -> Result<Tensor> {
|
||||
fn mask(&mut self, t: usize) -> Result<Tensor> {
|
||||
if let Some(mask) = self.masks.get(&t) {
|
||||
Ok(mask.clone())
|
||||
} else {
|
||||
let mask: Vec<_> = (0..t)
|
||||
.flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
|
||||
.collect();
|
||||
let mask = Tensor::from_slice(&mask, (t, t), device)?;
|
||||
let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
|
||||
self.masks.insert(t, mask.clone());
|
||||
Ok(mask)
|
||||
}
|
||||
@ -384,7 +346,7 @@ impl ModelWeights {
|
||||
|
||||
pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||
let (_b_sz, seq_len) = x.dims2()?;
|
||||
let mask = self.mask(seq_len, x.device())?;
|
||||
let mask = self.mask(seq_len)?;
|
||||
let _enter = self.span.enter();
|
||||
let mut layer_in = self.tok_embeddings.forward(x)?;
|
||||
for layer in self.layers.iter_mut() {
|
||||
|
@ -65,6 +65,7 @@ pub struct Config {
|
||||
pub use_cache: bool,
|
||||
pub pad_token_id: usize,
|
||||
pub eos_token_id: usize,
|
||||
pub decoder_start_token_id: Option<usize>,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
@ -89,6 +90,7 @@ impl Default for Config {
|
||||
use_cache: true,
|
||||
pad_token_id: 0,
|
||||
eos_token_id: 1,
|
||||
decoder_start_token_id: Some(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -642,7 +644,12 @@ pub struct T5EncoderModel {
|
||||
|
||||
impl T5EncoderModel {
|
||||
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
|
||||
let shared_vb = if vb.contains_key("shared.weight") {
|
||||
vb.pp("shared")
|
||||
} else {
|
||||
vb.pp("decoder").pp("embed_tokens")
|
||||
};
|
||||
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
|
||||
let shared = Arc::new(shared);
|
||||
let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
|
||||
Ok(Self {
|
||||
@ -683,7 +690,12 @@ impl T5ForConditionalGeneration {
|
||||
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
assert!(cfg.is_encoder_decoder);
|
||||
let d_model = cfg.d_model;
|
||||
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
|
||||
let shared_vb = if vb.contains_key("shared.weight") {
|
||||
vb.pp("shared")
|
||||
} else {
|
||||
vb.pp("decoder").pp("embed_tokens")
|
||||
};
|
||||
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
|
||||
let shared = Arc::new(shared);
|
||||
|
||||
let mut encoder_cfg = cfg.clone();
|
||||
|
@ -1,3 +1,4 @@
|
||||
pub use crate::models::with_tracing::Linear;
|
||||
use candle::{Result, Tensor};
|
||||
use candle_nn::{Module, VarBuilder};
|
||||
|
||||
@ -9,13 +10,11 @@ pub mod tiny_vit;
|
||||
pub mod transformer;
|
||||
|
||||
pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
|
||||
let inner = if bias {
|
||||
candle_nn::linear(in_dim, out_dim, vb)?
|
||||
if bias {
|
||||
crate::models::with_tracing::linear(in_dim, out_dim, vb)
|
||||
} else {
|
||||
candle_nn::linear_no_bias(in_dim, out_dim, vb)?
|
||||
};
|
||||
let span = tracing::span!(tracing::Level::TRACE, "linear");
|
||||
Ok(Linear { inner, span })
|
||||
crate::models::with_tracing::linear_no_bias(in_dim, out_dim, vb)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@ -85,16 +84,3 @@ impl Module for MlpBlock {
|
||||
.apply(&self.lin2)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Linear {
|
||||
inner: candle_nn::Linear,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Module for Linear {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
self.inner.forward(x)
|
||||
}
|
||||
}
|
||||
|
@ -102,6 +102,14 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ssd1b() -> Self {
|
||||
Self::sdxl()
|
||||
}
|
||||
|
||||
pub fn ssd1b2() -> Self {
|
||||
Self::sdxl2()
|
||||
}
|
||||
|
||||
// https://huggingface.co/warp-ai/wuerstchen/blob/main/text_encoder/config.json
|
||||
pub fn wuerstchen() -> Self {
|
||||
Self {
|
||||
|
@ -249,6 +249,71 @@ impl StableDiffusionConfig {
|
||||
)
|
||||
}
|
||||
|
||||
pub fn ssd1b(
|
||||
sliced_attention_size: Option<usize>,
|
||||
height: Option<usize>,
|
||||
width: Option<usize>,
|
||||
) -> Self {
|
||||
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
|
||||
out_channels,
|
||||
use_cross_attn,
|
||||
attention_head_dim,
|
||||
};
|
||||
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/unet/config.json
|
||||
let unet = unet_2d::UNet2DConditionModelConfig {
|
||||
blocks: vec![
|
||||
bc(320, None, 5),
|
||||
bc(640, Some(2), 10),
|
||||
bc(1280, Some(10), 20),
|
||||
],
|
||||
center_input_sample: false,
|
||||
cross_attention_dim: 2048,
|
||||
downsample_padding: 1,
|
||||
flip_sin_to_cos: true,
|
||||
freq_shift: 0.,
|
||||
layers_per_block: 2,
|
||||
mid_block_scale_factor: 1.,
|
||||
norm_eps: 1e-5,
|
||||
norm_num_groups: 32,
|
||||
sliced_attention_size,
|
||||
use_linear_projection: true,
|
||||
};
|
||||
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/vae/config.json
|
||||
let autoencoder = vae::AutoEncoderKLConfig {
|
||||
block_out_channels: vec![128, 256, 512, 512],
|
||||
layers_per_block: 2,
|
||||
latent_channels: 4,
|
||||
norm_num_groups: 32,
|
||||
};
|
||||
let scheduler = ddim::DDIMSchedulerConfig {
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let height = if let Some(height) = height {
|
||||
assert_eq!(height % 8, 0, "height has to be divisible by 8");
|
||||
height
|
||||
} else {
|
||||
1024
|
||||
};
|
||||
|
||||
let width = if let Some(width) = width {
|
||||
assert_eq!(width % 8, 0, "width has to be divisible by 8");
|
||||
width
|
||||
} else {
|
||||
1024
|
||||
};
|
||||
|
||||
Self {
|
||||
width,
|
||||
height,
|
||||
clip: clip::Config::ssd1b(),
|
||||
clip2: Some(clip::Config::ssd1b2()),
|
||||
autoencoder,
|
||||
scheduler,
|
||||
unet,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn build_vae<P: AsRef<std::path::Path>>(
|
||||
&self,
|
||||
vae_weights: P,
|
||||
|
@ -63,6 +63,7 @@ pub struct Config {
|
||||
pub use_cache: bool,
|
||||
pub pad_token_id: usize,
|
||||
pub eos_token_id: usize,
|
||||
pub decoder_start_token_id: Option<usize>,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
@ -87,6 +88,7 @@ impl Default for Config {
|
||||
use_cache: true,
|
||||
pad_token_id: 0,
|
||||
eos_token_id: 1,
|
||||
decoder_start_token_id: Some(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -110,6 +112,7 @@ impl Config {
|
||||
num_heads: 12,
|
||||
num_layers: 12,
|
||||
pad_token_id: 0,
|
||||
decoder_start_token_id: Some(0),
|
||||
relative_attention_max_distance: 128,
|
||||
relative_attention_num_buckets: 32,
|
||||
use_cache: true,
|
||||
@ -667,7 +670,12 @@ pub struct T5EncoderModel {
|
||||
|
||||
impl T5EncoderModel {
|
||||
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
|
||||
let shared_vb = if vb.contains_tensor("shared.weight") {
|
||||
vb.pp("shared")
|
||||
} else {
|
||||
vb.pp("decoder").pp("embed_tokens")
|
||||
};
|
||||
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
|
||||
let shared = Arc::new(shared);
|
||||
let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
|
||||
Ok(Self {
|
||||
@ -708,7 +716,12 @@ impl T5ForConditionalGeneration {
|
||||
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
assert!(cfg.is_encoder_decoder);
|
||||
let d_model = cfg.d_model;
|
||||
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
|
||||
let shared_vb = if vb.contains_tensor("shared.weight") {
|
||||
vb.pp("shared")
|
||||
} else {
|
||||
vb.pp("decoder").pp("embed_tokens")
|
||||
};
|
||||
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
|
||||
let shared = Arc::new(shared);
|
||||
|
||||
let mut encoder_cfg = cfg.clone();
|
||||
|
@ -198,13 +198,17 @@ fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
|
||||
mel
|
||||
}
|
||||
|
||||
pub fn pcm_to_mel<T: Float + std::fmt::Display>(samples: &[T], filters: &[T]) -> Vec<T> {
|
||||
pub fn pcm_to_mel<T: Float + std::fmt::Display>(
|
||||
cfg: &super::Config,
|
||||
samples: &[T],
|
||||
filters: &[T],
|
||||
) -> Vec<T> {
|
||||
log_mel_spectrogram_(
|
||||
samples,
|
||||
filters,
|
||||
super::N_FFT,
|
||||
super::HOP_LENGTH,
|
||||
super::N_MELS,
|
||||
cfg.num_mel_bins,
|
||||
false,
|
||||
)
|
||||
}
|
||||
|
@ -18,6 +18,7 @@ pub struct Config {
|
||||
// pub n_text_state: usize,
|
||||
pub decoder_attention_heads: usize, // n_text_head
|
||||
pub decoder_layers: usize, // n_text_layer
|
||||
#[serde(default)]
|
||||
pub suppress_tokens: Vec<u32>,
|
||||
}
|
||||
|
||||
@ -26,7 +27,6 @@ pub const DTYPE: candle::DType = candle::DType::F32;
|
||||
// Audio parameters.
|
||||
pub const SAMPLE_RATE: usize = 16000;
|
||||
pub const N_FFT: usize = 400;
|
||||
pub const N_MELS: usize = 80;
|
||||
pub const HOP_LENGTH: usize = 160;
|
||||
pub const CHUNK_LENGTH: usize = 30;
|
||||
pub const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
|
||||
|
@ -1,4 +1,5 @@
|
||||
use super::Config;
|
||||
use crate::models::with_tracing::{linear, linear_no_bias, Linear};
|
||||
use candle::{Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
|
||||
|
||||
@ -6,33 +7,6 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Em
|
||||
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
||||
Ok(Embedding::new(embeddings, hidden_size))
|
||||
}
|
||||
//
|
||||
// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
|
||||
// model.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Linear {
|
||||
inner: candle_nn::Linear,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Linear {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
self.inner.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "linear");
|
||||
let inner = candle_nn::linear(size1, size2, vb)?;
|
||||
Ok(Linear { inner, span })
|
||||
}
|
||||
|
||||
fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "linear");
|
||||
let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
|
||||
Ok(Linear { inner, span })
|
||||
}
|
||||
|
||||
fn conv1d(
|
||||
in_channels: usize,
|
||||
|
@ -124,3 +124,34 @@ impl std::fmt::Debug for QMatMul {
|
||||
write!(f, "QMatMul")
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct LayerNorm {
|
||||
inner: candle_nn::LayerNorm,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl LayerNorm {
|
||||
pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
|
||||
let inner = candle_nn::LayerNorm::new(weight, bias, eps);
|
||||
let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
|
||||
Self { inner, span }
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for LayerNorm {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
self.inner.forward(xs)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn layer_norm<C: Into<candle_nn::LayerNormConfig>>(
|
||||
size: usize,
|
||||
c: C,
|
||||
vb: VarBuilder,
|
||||
) -> Result<LayerNorm> {
|
||||
let inner = candle_nn::layer_norm(size, c, vb)?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
|
||||
Ok(LayerNorm { inner, span })
|
||||
}
|
||||
|
@ -10,12 +10,12 @@ pub struct VarBuilder {
|
||||
}
|
||||
|
||||
impl VarBuilder {
|
||||
pub fn from_gguf<P: AsRef<std::path::Path>>(p: P, device: &Device) -> Result<Self> {
|
||||
pub fn from_gguf<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
|
||||
let mut file = std::fs::File::open(p)?;
|
||||
let content = candle::quantized::gguf_file::Content::read(&mut file)?;
|
||||
let mut data = std::collections::HashMap::new();
|
||||
for tensor_name in content.tensor_infos.keys() {
|
||||
let tensor = content.tensor(&mut file, tensor_name, device)?;
|
||||
let tensor = content.tensor(&mut file, tensor_name)?;
|
||||
data.insert(tensor_name.to_string(), Arc::new(tensor));
|
||||
}
|
||||
Ok(Self {
|
||||
@ -25,12 +25,12 @@ impl VarBuilder {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn from_gguf_buffer(buffer: &[u8], device: &Device) -> Result<Self> {
|
||||
pub fn from_gguf_buffer(buffer: &[u8]) -> Result<Self> {
|
||||
let mut cursor = std::io::Cursor::new(buffer);
|
||||
let content = candle::quantized::gguf_file::Content::read(&mut cursor)?;
|
||||
let mut data = std::collections::HashMap::new();
|
||||
for tensor_name in content.tensor_infos.keys() {
|
||||
let tensor = content.tensor(&mut cursor, tensor_name, device)?;
|
||||
let tensor = content.tensor(&mut cursor, tensor_name)?;
|
||||
data.insert(tensor_name.to_string(), Arc::new(tensor));
|
||||
}
|
||||
Ok(Self {
|
||||
@ -90,4 +90,8 @@ impl VarBuilder {
|
||||
pub fn device(&self) -> &Device {
|
||||
&self.device
|
||||
}
|
||||
|
||||
pub fn contains_key(&self, key: &str) -> bool {
|
||||
self.data.contains_key(key)
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,7 @@
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{embedding, linear, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
|
||||
use candle_nn::{
|
||||
embedding, linear_no_bias as linear, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
|
@ -31,6 +31,7 @@
|
||||
model: "model.safetensors",
|
||||
tokenizer: "tokenizer.json",
|
||||
config: "config.json",
|
||||
size: "151 MB",
|
||||
},
|
||||
tiny_en: {
|
||||
base_url:
|
||||
@ -38,20 +39,41 @@
|
||||
model: "model.safetensors",
|
||||
tokenizer: "tokenizer.json",
|
||||
config: "config.json",
|
||||
size: "151 MB",
|
||||
},
|
||||
tiny_quantized_multilingual_q80: {
|
||||
base_url: "https://huggingface.co/lmz/candle-whisper/resolve/main/",
|
||||
model: "model-tiny-q80.gguf",
|
||||
tokenizer: "tokenizer-tiny.json",
|
||||
config: "config-tiny.json",
|
||||
size: "41.5 MB",
|
||||
},
|
||||
tiny_en_quantized_q80: {
|
||||
base_url: "https://huggingface.co/lmz/candle-whisper/resolve/main/",
|
||||
model: "model-tiny-q80.gguf",
|
||||
tokenizer: "tokenizer-tiny-en.json",
|
||||
config: "config-tiny-en.json",
|
||||
size: "41.8 MB",
|
||||
},
|
||||
distil_medium_en: {
|
||||
base_url:
|
||||
"https://huggingface.co/distil-whisper/distil-medium.en/resolve/main/",
|
||||
model: "model.safetensors",
|
||||
tokenizer: "tokenizer.json",
|
||||
config: "config.json",
|
||||
size: "789 MB",
|
||||
},
|
||||
};
|
||||
|
||||
const modelEl = document.querySelector("#model");
|
||||
|
||||
Object.keys(MODELS).forEach((modelID) => {
|
||||
const model = MODELS[modelID];
|
||||
const option = document.createElement("option");
|
||||
option.value = modelID;
|
||||
option.textContent = `${modelID} (${model.size})`;
|
||||
modelEl.appendChild(option);
|
||||
});
|
||||
const whisperWorker = new Worker("./whisperWorker.js", {
|
||||
type: "module",
|
||||
});
|
||||
@ -150,7 +172,7 @@
|
||||
if (audioURL === null) {
|
||||
return;
|
||||
}
|
||||
const modelID = document.querySelector("#model").value;
|
||||
const modelID = modelEl.value;
|
||||
const model = MODELS[modelID];
|
||||
const modelURL = model.base_url + model.model;
|
||||
const tokenizerURL = model.base_url + model.tokenizer;
|
||||
@ -222,14 +244,6 @@
|
||||
<select
|
||||
id="model"
|
||||
class="border-2 border-gray-500 rounded-md font-light">
|
||||
<option value="tiny_multilingual" selected>tiny (151 MB)</option>
|
||||
<option value="tiny_en" selected>tiny.en (151 MB)</option>
|
||||
<option value="tiny_quantized_multilingual_q80">
|
||||
tiny quantized q80 (41.5 MB)
|
||||
</option>
|
||||
<option value="tiny_en_quantized_q80">
|
||||
tiny.en quantized q80 (41.8 MB)
|
||||
</option>
|
||||
</select>
|
||||
</div>
|
||||
<!-- drag and drop area -->
|
||||
|
@ -200,6 +200,7 @@ fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
|
||||
}
|
||||
|
||||
pub fn pcm_to_mel<T: Float + std::fmt::Display>(
|
||||
cfg: &worker::m::Config,
|
||||
samples: &[T],
|
||||
filters: &[T],
|
||||
) -> anyhow::Result<Vec<T>> {
|
||||
@ -208,7 +209,7 @@ pub fn pcm_to_mel<T: Float + std::fmt::Display>(
|
||||
filters,
|
||||
worker::m::N_FFT,
|
||||
worker::m::HOP_LENGTH,
|
||||
worker::m::N_MELS,
|
||||
cfg.num_mel_bins,
|
||||
false,
|
||||
);
|
||||
Ok(mel)
|
||||
|
@ -349,9 +349,10 @@ impl Decoder {
|
||||
.map(|v| *v as f32 / 32768.)
|
||||
.collect();
|
||||
console_log!("pcm data loaded {}", pcm_data.len());
|
||||
let mel = crate::audio::pcm_to_mel(&pcm_data, &self.mel_filters)?;
|
||||
let mel = crate::audio::pcm_to_mel(self.model.config(), &pcm_data, &self.mel_filters)?;
|
||||
let mel_len = mel.len();
|
||||
let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
|
||||
let n_mels = self.model.config().num_mel_bins;
|
||||
let mel = Tensor::from_vec(mel, (1, n_mels, mel_len / n_mels), &device)?;
|
||||
console_log!("loaded mel: {:?}", mel.dims());
|
||||
let segments = self.run(&mel)?;
|
||||
Ok(segments)
|
||||
|
Reference in New Issue
Block a user