Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 19:18:50 +00:00)

Compare commits: metal4_arc ... tmp-metal- (8 commits)

Commits: 9a27f11c3f, 7161002a34, 82cce52e73, 71fcb31873, 198009453a, 492d164235, 2d84c16fed, 4525b7b52a

.github/workflows/ci_cuda.yaml (vendored, 2 changes)

@@ -59,7 +59,7 @@ jobs:
       - name: Install Rust Stable
         run: curl https://sh.rustup.rs -sSf | sh -s -- -y
       - uses: Swatinem/rust-cache@v2
-      - run: apt-get update -y && apt-get install libssl-dev protobuf-compiler -y
+      - run: apt-get update -y && apt-get install libssl-dev -y
       - name: Test (cuda)
         run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda
   stop-runner:

.github/workflows/maturin.yml (vendored): binary file not shown.

.github/workflows/python.yml (vendored, 8 changes)

@@ -39,12 +39,6 @@ jobs:
           path: ~/.cargo/registry
           key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }}

-      - name: Install Protoc
-        uses: arduino/setup-protoc@v2
-        with:
-          version: "25.0"
-          repo-token: ${{ secrets.GITHUB_TOKEN }}
-
       - name: Install
         working-directory: ./candle-pyo3
         run: |
@@ -52,7 +46,7 @@ jobs:
           source .env/bin/activate
           pip install -U pip
           pip install pytest maturin black
-          python -m maturin develop -r --features onnx
+          python -m maturin develop -r

       - name: Check style
         working-directory: ./candle-pyo3

Cargo.toml (11 changes)

@@ -10,12 +10,7 @@ members = [
     "candle-wasm-examples/*",
     "candle-wasm-tests",
 ]
-exclude = [
-    "candle-flash-attn",
-    "candle-kernels",
-    "candle-metal-kernels",
-    "candle-onnx",
-]
+exclude = ["candle-flash-attn", "candle-kernels"]
 resolver = "2"

 [workspace.package]
@@ -51,7 +46,6 @@ rayon = "1.7.0"
 rusttype = { version = "0.9", default-features = false }
 safetensors = "0.3.1"
 serde = { version = "1.0.171", features = ["derive"] }
-serde_plain = "1.0.2"
 serde_json = "1.0.99"
 thiserror = "1"
 tokenizers = { version = "0.13.4", default-features = false }
@@ -61,7 +55,8 @@ tracing-subscriber = "0.3.7"
 wav = "1.0.0"
 yoke = { version = "0.7.2", features = ["derive"] }
 zip = { version = "0.6.6", default-features = false }
-metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
+# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
+metal = { path = "../metal-rs", features = ["mps"] }

 [profile.release-with-debug]
 inherits = "release"

README.md (17 changes)

@@ -51,7 +51,7 @@ For more advanced examples, please have a look at the following section.
 These online demos run entirely in your browser:
 - [yolo](https://huggingface.co/spaces/lmz/candle-yolo): pose estimation and
   object recognition.
-- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): speech recognition.
+- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): text to speech.
 - [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2): text generation.
 - [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm): text generation.
 - [Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation.
@@ -69,8 +69,6 @@ We also provide a some command line based examples using state of the art models
   performance larger than all publicly available 13b models as of 2023-09-28.
 - [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
 - [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
-- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
-  (English/Chinese) general LLMs with 6b and 34b parameters.
 - [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of
   the LLaMA model using the same quantization techniques as
   [llama.cpp](https://github.com/ggerganov/llama.cpp).
@@ -145,11 +143,6 @@ And then head over to
   including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
 - [`candle-lora`](https://github.com/EricLBuehler/candle-lora): a LoRA implementation
   that conforms to the official `peft` implementation.
-- [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
-  serving local LLMs including an OpenAI compatible API server.
-- [`candle-ext`](https://github.com/mokeyish/candle-ext): an extension library to Candle that provides PyTorch functions not currently available in Candle.
-- [`kalosm`](https://github.com/floneum/floneum/tree/master/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
-- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.

 If you have an addition to this list, please submit a pull request.

@@ -175,17 +168,16 @@ If you have an addition to this list, please submit a pull request.
   - Mistral 7b v0.1.
   - StableLM-3B-4E1T.
   - Replit-code-v1.5-3B.
+  - T5.
   - Bert.
-  - Yi-6B and Yi-34B.
-- Text to text.
-  - T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
-  - Marian MT (Machine Translation).
 - Whisper (multi-lingual support).
 - Text to image.
   - Stable Diffusion v1.5, v2.1, XL v1.0.
   - Wurstchen v2.
 - Image to text.
   - BLIP.
+- Text to text.
+  - Marian MT (Machine Translation).
 - Computer Vision Models.
   - DINOv2, ConvMixer, EfficientNet, ResNet, ViT.
   - yolo-v3, yolo-v8.
@@ -226,7 +218,6 @@ Cheatsheet:
 - [candle-datasets](./candle-datasets/): Datasets and data loaders.
 - [candle-transformers](./candle-transformers): transformers-related utilities.
 - [candle-flash-attn](./candle-flash-attn): Flash attention v2 layer.
-- [candle-onnx](./candle-onnx/): ONNX model evaluation.

 ## FAQ

@@ -30,6 +30,7 @@ safetensors = { workspace = true }
 thiserror = { workspace = true }
 yoke = { workspace = true }
 zip = { workspace = true }
+tracing = { workspace = true }

 [dev-dependencies]
 anyhow = { workspace = true }
@@ -41,4 +42,4 @@ cuda = ["cudarc", "dep:candle-kernels"]
 cudnn = ["cuda", "cudarc/cudnn"]
 mkl = ["dep:libc", "dep:intel-mkl-src"]
 accelerate = ["dep:libc", "dep:accelerate-src"]
-metal = ["dep:metal", "dep:candle-metal-kernels"]
+metal = ["dep:candle-metal-kernels", "dep:metal"]

@@ -39,14 +39,6 @@ pub trait BackendStorage: Sized {
         _params: &crate::conv::ParamsConv1D,
     ) -> Result<Self>;

-    fn conv_transpose1d(
-        &self,
-        _l: &Layout,
-        _kernel: &Self,
-        _kernel_l: &Layout,
-        _params: &crate::conv::ParamsConvTranspose1D,
-    ) -> Result<Self>;
-
     fn conv2d(
         &self,
         _l: &Layout,

@@ -15,17 +15,6 @@ fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result
     }
 }

-thread_local! {
-    static CANDLE_GRAD_DO_NOT_DETACH: bool = {
-        match std::env::var("CANDLE_GRAD_DO_NOT_DETACH") {
-            Ok(s) => {
-                !s.is_empty() && s != "0"
-            },
-            Err(_) => false,
-        }
-    }
-}
-
 impl Tensor {
     /// Return all the nodes that lead to this value in a topologically sorted vec, the first
     /// elements having dependencies on the latter ones, e.g. the first element if any is the
@@ -68,11 +57,6 @@ impl Tensor {
                 kernel: rhs,
                 ..
             }
-            | Op::ConvTranspose1D {
-                arg: lhs,
-                kernel: rhs,
-                ..
-            }
             | Op::Conv2D {
                 arg: lhs,
                 kernel: rhs,
@@ -166,16 +150,10 @@ impl Tensor {
             if node.is_variable() {
                 continue;
             }
-            let grad = grads
-                .remove(node)
-                .expect("candle internal error - grad not populated");
-            // https://github.com/huggingface/candle/issues/1241
-            // Ideally, we would make these operations in place where possible to ensure that we
-            // do not have to allocate too often. Here we just call `.detach` to avoid computing
-            // the backprop graph of the backprop itself. This would be an issue for second order
-            // derivatives but these are out of scope at the moment.
-            let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
-            let grad = if do_not_detach { grad } else { grad.detach()? };
+            let grad = grads.remove(node).unwrap();
+            // TODO: We should perform all these operations in place (or at least not track the
+            // whole graph). The only drawback would be if we wanted to support grad of grad but
+            // this is out of scope.
             if let Some(op) = node.op() {
                 match op {
                     Op::Binary(lhs, rhs, BinaryOp::Add) => {
@@ -230,44 +208,7 @@ impl Tensor {
                         let f_grad = pred.where_cond(&zeros, &grad)?;
                         *f_sum_grad = f_sum_grad.add(&f_grad)?;
                     }
-                    Op::Conv1D {
-                        arg,
-                        kernel,
-                        padding,
-                        stride,
-                        dilation,
-                    } => {
-                        // The output height for conv_transpose1d is:
-                        // (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1
-                        let grad_l_in = grad.dim(2)?;
-                        let k_size = kernel.dim(2)?;
-                        let out_size =
-                            (grad_l_in - 1) * stride + dilation * (k_size - 1) + 1 - 2 * padding;
-                        let out_padding = arg.dim(2)? - out_size;
-                        let grad_arg = grad.conv_transpose1d(
-                            kernel,
-                            *padding,
-                            out_padding,
-                            *stride,
-                            *dilation,
-                        )?;
-                        let sum_grad = grads.or_insert(arg)?;
-                        *sum_grad = sum_grad.add(&grad_arg)?;
-
-                        let grad_kernel = arg
-                            .transpose(0, 1)?
-                            .conv1d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
-                            .transpose(0, 1)?;
-                        let sum_grad = grads.or_insert(kernel)?;
-                        let (_, _, k0) = kernel.dims3()?;
-                        let (_, _, g_k0) = grad_kernel.dims3()?;
-                        let grad_kernel = if g_k0 != k0 {
-                            grad_kernel.narrow(2, 0, k0)?
-                        } else {
-                            grad_kernel
-                        };
-                        *sum_grad = sum_grad.add(&grad_kernel)?;
-                    }
+                    Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
                     Op::Conv2D {
                         arg,
                         kernel,
@@ -306,9 +247,6 @@ impl Tensor {
                         };
                         *sum_grad = sum_grad.add(&grad_kernel)?;
                     }
-                    Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported {
-                        op: "conv-transpose1d",
-                    })?,
                     Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
                         op: "conv-transpose2d",
                     })?,
@@ -549,38 +487,16 @@ impl Tensor {
                             + 0.5)?;
                         *sum_grad = sum_grad.add(&(&grad * gelu_grad)?)?
                     }
-                    Op::Unary(arg, UnaryOp::Erf) => {
-                        let sum_grad = grads.or_insert(arg)?;
-                        // d/dx erf(x) = 2/sqrt(pi) * e^(-x^2)
-                        let erf_grad =
-                            (2. / std::f64::consts::PI.sqrt()) * (arg.sqr()?.neg()?).exp()?;
-                        *sum_grad = sum_grad.add(&(&grad * erf_grad)?)?
-                    }
-                    Op::Unary(arg, UnaryOp::GeluErf) => {
-                        let sum_grad = grads.or_insert(arg)?;
-                        // d/dx gelu_erf(x) = 0.5 + 0.398942 e^(-x^2/2) x + 0.5 erf(x/sqrt(2))
-                        let neg_half_square = (arg.sqr()?.neg()? / 2.)?;
-                        let scaled_exp_arg = (0.398942 * neg_half_square.exp()? * arg)?;
-                        let arg_scaled_sqrt = (arg / 2f64.sqrt())?;
-                        let erf_scaled_sqrt = (0.5 * arg_scaled_sqrt.erf()?)?;
-                        let gelu_erf_grad = (0.5 + scaled_exp_arg + erf_scaled_sqrt)?;
-                        *sum_grad = sum_grad.add(&(&grad * gelu_erf_grad)?)?;
-                    }
+                    Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
+                    Op::Unary(_, UnaryOp::GeluErf) => {
+                        Err(Error::BackwardNotSupported { op: "gelu-erf" })?
+                    }
                     Op::Unary(arg, UnaryOp::Relu) => {
                         let sum_grad = grads.or_insert(arg)?;
                         let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
                         *sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
                     }
-                    Op::Elu(arg, alpha) => {
-                        // d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
-                        let sum_grad = grads.or_insert(arg)?;
-                        let zeros = arg.zeros_like()?;
-                        let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
-                        let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
-                        let negative_exp_mask = ((negative_mask * arg.exp())? * *alpha)?;
-                        let combined_mask = (positive_mask + negative_exp_mask)?;
-                        *sum_grad = sum_grad.add(&(grad * combined_mask)?)?
-                    }
+                    Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
                     Op::Powf(arg, e) => {
                         let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?;
                         let sum_grad = grads.or_insert(arg)?;
@@ -25,33 +25,6 @@ impl ParamsConv1D {
     }
 }

-#[derive(Debug, Clone, PartialEq, Eq)]
-pub struct ParamsConvTranspose1D {
-    pub(crate) b_size: usize,
-    pub(crate) l_in: usize,
-    pub(crate) c_out: usize,
-    pub(crate) c_in: usize,
-    pub(crate) k_size: usize,
-    pub(crate) padding: usize,
-    pub(crate) output_padding: usize,
-    pub(crate) stride: usize,
-    pub(crate) dilation: usize,
-}
-
-impl ParamsConvTranspose1D {
-    pub(crate) fn l_out(&self) -> usize {
-        (self.l_in - 1) * self.stride - 2 * self.padding
-            + self.dilation * (self.k_size - 1)
-            + self.output_padding
-            + 1
-    }
-
-    pub(crate) fn out_dims(&self) -> Vec<usize> {
-        let l_out = self.l_out();
-        vec![self.b_size, self.c_out, l_out]
-    }
-}
-
 #[derive(Debug, Clone, PartialEq, Eq, Hash)]
 pub enum CudnnFwdAlgo {
     ImplicitGemm,
@@ -187,49 +160,6 @@ impl Tensor {
         }
     }

-    /// Applies a 1D transposed convolution over the input tensor.
-    pub fn conv_transpose1d(
-        &self,
-        kernel: &Self,
-        padding: usize,
-        output_padding: usize,
-        stride: usize,
-        dilation: usize,
-    ) -> Result<Self> {
-        let (b_size, c_in, l_in) = self.dims3()?;
-        let (c_in_k, c_out, k_size) = kernel.dims3()?;
-        if c_in != c_in_k {
-            crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
-        }
-        let params = ParamsConvTranspose1D {
-            b_size,
-            l_in,
-            k_size,
-            c_out,
-            c_in,
-            padding,
-            output_padding,
-            stride,
-            dilation,
-        };
-        let storage = self.storage().conv_transpose1d(
-            self.layout(),
-            &kernel.storage(),
-            kernel.layout(),
-            &params,
-        )?;
-        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D {
-            arg,
-            kernel,
-            padding: params.padding,
-            output_padding: params.output_padding,
-            stride: params.stride,
-            dilation: params.dilation,
-        });
-        let out_dims = params.out_dims();
-        Ok(crate::tensor::from_storage(storage, out_dims, op, false))
-    }
-
     fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
         let storage =
             self.storage()

@@ -1256,74 +1256,6 @@ impl Map1 for Im2Col {
     }
 }

-struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
-
-impl<'a> Map2 for ConvTranspose1D<'a> {
-    const OP: &'static str = "conv_transpose1d";
-    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
-        let p = self.0;
-        let inp = &inp[inp_l.start_offset()..];
-        let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
-        let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
-        let l_out = p.l_out();
-
-        // Output shape: [b_size, c_out, l_out].
-        let dst_elems = p.c_out * l_out * p.b_size;
-        let dst = vec![T::zero(); dst_elems];
-        let dst_s0 = p.c_out * l_out;
-        let dst_s1 = l_out;
-        let dst_s2 = 1;
-
-        // TODO: Avoid making this copy if `inp` already has the appropriate layout.
-        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
-        let cont_s0 = p.l_in * p.c_in;
-        let cont_s1 = p.c_in;
-        for b_idx in 0..p.b_size {
-            for l_idx in 0..p.l_in {
-                for c_idx in 0..p.c_in {
-                    let src_idx = b_idx * inp_s0 + c_idx * inp_s1 + l_idx * inp_s2;
-                    let dst_idx = b_idx * cont_s0 + l_idx * cont_s1 + c_idx;
-                    inp_cont[dst_idx] = inp[src_idx]
-                }
-            }
-        }
-
-        for k_idx in 0..p.k_size {
-            (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
-                let k_cont = (0..p.c_in)
-                    .map(|c_in_idx| k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_idx * k_s2])
-                    .collect::<Vec<_>>();
-                for b_idx in 0..p.b_size {
-                    for l_idx in 0..p.l_in {
-                        let out_idx = l_idx * p.stride + k_idx * p.dilation;
-                        if out_idx < p.padding {
-                            continue;
-                        }
-                        let out_idx = out_idx - p.padding;
-                        if out_idx < l_out {
-                            let inp_cont = &inp_cont[b_idx * cont_s0 + l_idx * cont_s1..];
-                            let dst_idx = b_idx * dst_s0 + out_idx * dst_s2 + dst_c_idx * dst_s1;
-                            let mut d = T::zero();
-                            unsafe {
-                                T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
-                            }
-                            let dst_p = dst.as_ptr();
-                            // Safety: dst_idx are uniques per dst_c_idx which is used to
-                            // parallelise the different tasks so no two threads can try to
-                            // write at the same location.
-                            unsafe {
-                                let ptr = dst_p.add(dst_idx) as *mut T;
-                                *ptr += d
-                            }
-                        }
-                    }
-                }
-            })
-        }
-        Ok(dst)
-    }
-}
-
 struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);

 impl<'a> Map2 for Conv2D<'a> {
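
The removed kernel is a scatter formulation: input position l_idx and kernel tap k_idx contribute to output index l_idx * stride + k_idx * dilation - padding, and the rayon par_iter over c_out keeps writers disjoint. A deliberately naive, single-batch, single-channel sketch of the same scatter (a hypothetical helper, not candle API):

```rust
// Naive transposed 1D convolution over plain slices, mirroring the
// `l_idx * p.stride + k_idx * p.dilation` indexing of the removed kernel.
fn conv_transpose1d_naive(
    inp: &[f32],
    kernel: &[f32],
    stride: usize,
    padding: usize,
    dilation: usize,
    output_padding: usize,
) -> Vec<f32> {
    let (l_in, k_size) = (inp.len(), kernel.len());
    let l_out = (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + output_padding + 1;
    let mut out = vec![0f32; l_out];
    for (l_idx, &x) in inp.iter().enumerate() {
        for (k_idx, &w) in kernel.iter().enumerate() {
            let out_idx = l_idx * stride + k_idx * dilation;
            // Skip positions that fall inside the cropped padding region.
            if out_idx >= padding && out_idx - padding < l_out {
                out[out_idx - padding] += x * w;
            }
        }
    }
    out
}
```
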
@@ -2503,16 +2435,6 @@ impl BackendStorage for CpuStorage {
         Ok(res_t)
     }

-    fn conv_transpose1d(
-        &self,
-        l: &Layout,
-        kernel: &Self,
-        kernel_l: &Layout,
-        params: &crate::conv::ParamsConvTranspose1D,
-    ) -> Result<Self> {
-        ConvTranspose1D(params).map(self, l, kernel, kernel_l)
-    }
-
     fn conv2d(
         &self,
         l: &Layout,

@@ -1808,16 +1808,6 @@ impl BackendStorage for CudaStorage {
         Ok(res_t)
     }

-    fn conv_transpose1d(
-        &self,
-        _: &Layout,
-        _: &Self,
-        _: &Layout,
-        _: &crate::conv::ParamsConvTranspose1D,
-    ) -> Result<Self> {
-        todo!()
-    }
-
     #[cfg(not(feature = "cudnn"))]
     fn conv2d(
         &self,

@@ -1,6 +1,6 @@
 use crate::backend::BackendDevice;
 use crate::cpu_backend::CpuDevice;
-use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
+use crate::{bail, CpuStorage, DType, Result, Shape, Storage, WithDType};

 /// A `DeviceLocation` represents a physical device whereas multiple `Device`
 /// can live on the same location (typically for cuda devices).
@@ -8,7 +8,7 @@ use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
 pub enum DeviceLocation {
     Cpu,
     Cuda { gpu_id: usize },
-    Metal { gpu_id: usize },
+    Metal,
 }

 #[derive(Debug, Clone)]
@@ -105,14 +105,14 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4:
 impl<S: NdArray> NdArray for Vec<S> {
     fn shape(&self) -> Result<Shape> {
         if self.is_empty() {
-            crate::bail!("empty array")
+            bail!("empty array")
         }
         let shape0 = self[0].shape()?;
         let n = self.len();
         for v in self.iter() {
             let shape = v.shape()?;
             if shape != shape0 {
-                crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
+                bail!("two elements have different shapes {shape:?} {shape0:?}")
             }
         }
         Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
@@ -146,7 +146,6 @@ impl Device {
         match (self, rhs) {
             (Self::Cpu, Self::Cpu) => true,
             (Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
-            (Self::Metal(lhs), Self::Metal(rhs)) => lhs.same_device(rhs),
             _ => false,
         }
     }
@@ -167,10 +166,6 @@ impl Device {
         matches!(self, Self::Cuda(_))
     }

-    pub fn is_metal(&self) -> bool {
-        matches!(self, Self::Metal(_))
-    }
-
     pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
         if crate::utils::cuda_is_available() {
             Self::new_cuda(ordinal)
@@ -192,19 +187,13 @@ impl Device {
                 Ok(Storage::Cpu(storage))
             }
             Device::Cuda(device) => {
-                // TODO: Remove the special case if we start supporting generating f16/bf16 directly.
-                if dtype == DType::F16 || dtype == DType::BF16 {
-                    let storage = device.rand_uniform(shape, DType::F32, lo, up)?;
-                    Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
-                } else {
-                    let storage = device.rand_uniform(shape, dtype, lo, up)?;
-                    Ok(Storage::Cuda(storage))
-                }
+                let storage = device.rand_uniform(shape, dtype, lo, up)?;
+                Ok(Storage::Cuda(storage))
             }
             Device::Metal(_device) => {
                 // let storage = device.rand_uniform(shape, dtype, lo, up)?;
                 // Ok(Storage::Metal(storage))
-                crate::bail!("Metal rand_uniform not implemented")
+                bail!("Metal rand_uniform not implemented")
             }
         }
     }
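
The branch removed above works around RNG kernels that only generate f32: sample in f32, then cast to the requested half dtype. At the user level the same pattern looks roughly like this (a sketch assuming the published candle_core crate name and the public Tensor::rand / to_dtype APIs):

```rust
use candle_core::{DType, Device, Result, Tensor};

// Sample uniformly in f32, then cast down to f16, mirroring the removed
// special case for DType::F16 / DType::BF16.
fn rand_f16(shape: (usize, usize), device: &Device) -> Result<Tensor> {
    Tensor::rand(0f32, 1f32, shape, device)?.to_dtype(DType::F16)
}
```
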
@@ -231,14 +220,8 @@ impl Device {
                 Ok(Storage::Cpu(storage))
             }
             Device::Cuda(device) => {
-                // TODO: Remove the special case if we start supporting generating f16/bf16 directly.
-                if dtype == DType::F16 || dtype == DType::BF16 {
-                    let storage = device.rand_normal(shape, DType::F32, mean, std)?;
-                    Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
-                } else {
-                    let storage = device.rand_normal(shape, dtype, mean, std)?;
-                    Ok(Storage::Cuda(storage))
-                }
+                let storage = device.rand_normal(shape, dtype, mean, std)?;
+                Ok(Storage::Cuda(storage))
             }
             Device::Metal(device) => {
                 let storage = device.rand_normal(shape, dtype, mean, std)?;

@@ -14,9 +14,7 @@ impl Tensor {
             crate::DeviceLocation::Cuda { gpu_id } => {
                 format!(", cuda:{}", gpu_id)
             }
-            crate::DeviceLocation::Metal { gpu_id } => {
-                format!(", metal:{}", gpu_id)
-            }
+            _ => todo!(),
         };

         write!(f, "Tensor[")?;
@@ -479,9 +477,7 @@ impl std::fmt::Display for Tensor {
             crate::DeviceLocation::Cuda { gpu_id } => {
                 format!(", cuda:{}", gpu_id)
             }
-            crate::DeviceLocation::Metal { gpu_id } => {
-                format!(", metal:{}", gpu_id)
-            }
+            crate::DeviceLocation::Metal => todo!(),
         };

         write!(

@@ -79,16 +79,6 @@ impl crate::backend::BackendStorage for CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }

-    fn conv_transpose1d(
-        &self,
-        _: &Layout,
-        _: &Self,
-        _: &Layout,
-        _: &crate::conv::ParamsConvTranspose1D,
-    ) -> Result<Self> {
-        Err(Error::NotCompiledWithCudaSupport)
-    }
-
     fn conv2d(
         &self,
         _: &Layout,

@@ -8,18 +8,6 @@ pub struct MetalDevice;
 #[derive(Debug)]
 pub struct MetalStorage;

-#[derive(thiserror::Error, Debug)]
-pub enum MetalError {
-    #[error("{0}")]
-    Message(String),
-}
-
-impl From<String> for MetalError {
-    fn from(e: String) -> Self {
-        MetalError::Message(e)
-    }
-}
-
 macro_rules! fail {
     () => {
         unimplemented!("metal support has not been enabled, add `metal` feature to enable.")
@@ -91,16 +79,6 @@ impl crate::backend::BackendStorage for MetalStorage {
         Err(Error::NotCompiledWithMetalSupport)
     }

-    fn conv_transpose1d(
-        &self,
-        _l: &Layout,
-        _kernel: &Self,
-        _kernel_l: &Layout,
-        _params: &crate::conv::ParamsConvTranspose1D,
-    ) -> Result<Self> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
     fn conv2d(
         &self,
         _: &Layout,

@@ -1,4 +1,4 @@
-use crate::{DType, DeviceLocation, Layout, MetalError, Shape};
+use crate::{DType, DeviceLocation, Layout, Shape};

 #[derive(Debug, Clone)]
 pub struct MatMulUnexpectedStriding {
@@ -163,7 +163,7 @@ pub enum Error {
     Cuda(Box<dyn std::error::Error + Send + Sync>),

     #[error("Metal error {0}")]
-    Metal(#[from] MetalError),
+    Metal(String),

     #[error(transparent)]
     TryFromIntError(#[from] core::num::TryFromIntError),

@@ -49,12 +49,13 @@ mod device;
 pub mod display;
 mod dtype;
 mod dummy_cuda_backend;
-mod dummy_metal_backend;
 pub mod error;
 mod indexer;
 pub mod layout;
 #[cfg(feature = "metal")]
 pub mod metal_backend;
+#[cfg(feature = "accelerate")]
+mod metal_backend;
 #[cfg(feature = "mkl")]
 mod mkl;
 pub mod npy;
@@ -91,10 +92,10 @@ pub use cuda_backend::{CudaDevice, CudaStorage};
 pub use dummy_cuda_backend::{CudaDevice, CudaStorage};

 #[cfg(feature = "metal")]
-pub use metal_backend::{MetalDevice, MetalError, MetalStorage};
+pub use metal_backend::{MetalDevice, MetalStorage};

 #[cfg(not(feature = "metal"))]
-pub use dummy_metal_backend::{MetalDevice, MetalError, MetalStorage};
+pub use dummy_metal_backend::{MetalDevice, MetalStorage};

 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;

File diff suppressed because it is too large.

@@ -90,16 +90,6 @@ pub enum Op {
         dilation: usize,
     },

-    #[allow(dead_code)]
-    ConvTranspose1D {
-        arg: Tensor,
-        kernel: Tensor,
-        padding: usize,
-        output_padding: usize,
-        stride: usize,
-        dilation: usize,
-    },
-
     #[allow(dead_code)]
     Conv2D {
         arg: Tensor,
@@ -593,8 +583,7 @@ unary_op!(Recip, "recip", v, v.recip());
 unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
 unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);

-/// Tanh based approximation of the `gelu` operation
-/// GeluErf is the more precise one.
+/// `gelu` operation
 /// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
 impl UnaryOpT for Gelu {
     const NAME: &'static str = "gelu";
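
The doc-comment change drops a useful distinction: Gelu here is the tanh-based approximation, while GeluErf is the exact form. For reference, the standard formulas (general background, not candle-specific):

```latex
\operatorname{gelu\_erf}(x) = \frac{x}{2}\left(1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right),
\qquad
\operatorname{gelu}(x) \approx \frac{x}{2}\left(1 + \tanh\!\left(\sqrt{\frac{2}{\pi}}\left(x + 0.044715\,x^{3}\right)\right)\right)
```
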
@@ -684,8 +673,6 @@ impl UnaryOpT for Gelu {
     }
 }

-/// `erf` operation
-/// <https://en.wikipedia.org/wiki/Error_function>
 impl UnaryOpT for Erf {
     const NAME: &'static str = "erf";
     const KERNEL: &'static str = "uerf";
@@ -975,10 +962,6 @@ impl BackpropOp {
         };
         Self(op)
     }
-
-    pub(crate) fn is_none(&self) -> bool {
-        self.0.is_none()
-    }
 }

 impl std::ops::Deref for BackpropOp {

|
|||||||
//! Support for the GGML file format.
|
//! Support for the GGML file format.
|
||||||
|
|
||||||
use super::{k_quants, GgmlDType};
|
use super::{k_quants, GgmlDType};
|
||||||
use crate::Result;
|
use crate::{Device, Result};
|
||||||
use byteorder::{LittleEndian, ReadBytesExt};
|
use byteorder::{LittleEndian, ReadBytesExt};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
@@ -121,11 +121,12 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
     raw_data: &[u8],
     size_in_bytes: usize,
     dims: Vec<usize>,
+    device: &Device,
 ) -> Result<super::QTensor> {
     let raw_data_ptr = raw_data.as_ptr();
     let n_blocks = size_in_bytes / std::mem::size_of::<T>();
     let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
-    super::QTensor::new(data.to_vec(), dims)
+    super::QTensor::new(data.to_vec(), dims, device)
 }

 /// Creates a [Tensor] from a raw GGML tensor.
@@ -133,6 +134,7 @@ pub fn qtensor_from_ggml(
     ggml_dtype: GgmlDType,
     raw_data: &[u8],
     dims: Vec<usize>,
+    device: &Device,
 ) -> Result<super::QTensor> {
     let tensor_elems = dims.iter().product::<usize>();
     let blck_size = ggml_dtype.blck_size();
@@ -144,18 +146,38 @@ pub fn qtensor_from_ggml(
     let size_in_bytes = tensor_elems / blck_size * ggml_dtype.type_size();

     match ggml_dtype {
-        GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),
-        GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q4_0 => from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q4_1 => from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q5_0 => from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q5_1 => from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q8_0 => from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q2K => from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q3K => from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q4K => from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q5K => from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q6K => from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims),
+        GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims, device),
+        GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims, device),
+        GgmlDType::Q4_0 => {
+            from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q4_1 => {
+            from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q5_0 => {
+            from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q5_1 => {
+            from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q8_0 => {
+            from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q2K => {
+            from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q3K => {
+            from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q4K => {
+            from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q5K => {
+            from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q6K => {
+            from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims, device)
+        }
         _ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
     }
 }
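
For the size_in_bytes context line above, a worked example using Q4_0 constants (32 elements per block packed into 18 bytes: one f16 scale plus 16 bytes of 4-bit quants, per the GGML format; treat the exact constants as an assumption of this sketch):

```rust
fn main() {
    // size_in_bytes = tensor_elems / blck_size * type_size
    let tensor_elems = 4096usize;
    let blck_size = 32usize; // elements per Q4_0 block
    let type_size = 18usize; // bytes per Q4_0 block
    assert_eq!(tensor_elems / blck_size * type_size, 2304);
}
```
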
@@ -163,6 +185,7 @@ pub fn qtensor_from_ggml(
 fn read_one_tensor<R: std::io::Seek + std::io::Read>(
     reader: &mut R,
     magic: VersionedMagic,
+    device: &Device,
 ) -> Result<(String, super::QTensor)> {
     let n_dims = reader.read_u32::<LittleEndian>()?;
     let name_len = reader.read_u32::<LittleEndian>()?;
@@ -187,7 +210,7 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
     // TODO: Mmap version to avoid copying the data around?
     let mut raw_data = vec![0u8; size_in_bytes];
     reader.read_exact(&mut raw_data)?;
-    match qtensor_from_ggml(ggml_dtype, &raw_data, dims) {
+    match qtensor_from_ggml(ggml_dtype, &raw_data, dims, device) {
         Ok(tensor) => Ok((name, tensor)),
         Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
     }
@@ -201,7 +224,10 @@ pub struct Content {
 }

 impl Content {
-    pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Content> {
+    pub fn read<R: std::io::Seek + std::io::Read>(
+        reader: &mut R,
+        device: &Device,
+    ) -> Result<Content> {
         // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
         let last_position = reader.seek(std::io::SeekFrom::End(0))?;
         reader.seek(std::io::SeekFrom::Start(0))?;
@@ -211,7 +237,7 @@ impl Content {
         let mut tensors = HashMap::new();

         while reader.stream_position()? != last_position {
-            let (name, tensor) = read_one_tensor(reader, magic)?;
+            let (name, tensor) = read_one_tensor(reader, magic, device)?;
             tensors.insert(name, tensor);
         }
         Ok(Self {

@@ -3,7 +3,7 @@
 //! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md

 use super::{GgmlDType, QTensor};
-use crate::Result;
+use crate::{Device, Result};
 use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
 use std::collections::HashMap;

@@ -29,7 +29,6 @@ impl TryFrom<u32> for Magic {
 pub enum VersionedMagic {
     GgufV1,
     GgufV2,
-    GgufV3,
 }

 impl VersionedMagic {
@@ -40,7 +39,6 @@ impl VersionedMagic {
         let versioned_magic = match (magic, version) {
             (Magic::Gguf, 1) => Self::GgufV1,
             (Magic::Gguf, 2) => Self::GgufV2,
-            (Magic::Gguf, 3) => Self::GgufV3,
             _ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"),
         };
         Ok(versioned_magic)
@@ -59,6 +57,7 @@ impl TensorInfo {
         &self,
         reader: &mut R,
         tensor_data_offset: u64,
+        device: &Device,
     ) -> Result<QTensor> {
         let tensor_elems = self.shape.elem_count();
         let blck_size = self.ggml_dtype.blck_size();
@@ -71,7 +70,12 @@ impl TensorInfo {
         let mut raw_data = vec![0u8; size_in_bytes];
         reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
         reader.read_exact(&mut raw_data)?;
-        super::ggml_file::qtensor_from_ggml(self.ggml_dtype, &raw_data, self.shape.dims().to_vec())
+        super::ggml_file::qtensor_from_ggml(
+            self.ggml_dtype,
+            &raw_data,
+            self.shape.dims().to_vec(),
+            device,
+        )
     }
 }

@@ -86,9 +90,7 @@ pub struct Content {
 fn read_string<R: std::io::Read>(reader: &mut R, magic: &VersionedMagic) -> Result<String> {
     let len = match magic {
         VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
-        VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
-            reader.read_u64::<LittleEndian>()? as usize
-        }
+        VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
     };
     let mut v = vec![0u8; len];
     reader.read_exact(&mut v)?;
@@ -288,9 +290,7 @@ impl Value {
         let value_type = ValueType::from_u32(value_type)?;
         let len = match magic {
             VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
-            VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
-                reader.read_u64::<LittleEndian>()? as usize
-            }
+            VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
         };
         let mut vs = Vec::with_capacity(len);
         for _ in 0..len {
@@ -387,15 +387,11 @@ impl Content {

         let tensor_count = match magic {
             VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
-            VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
-                reader.read_u64::<LittleEndian>()? as usize
-            }
+            VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
         };
         let metadata_kv_count = match magic {
             VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
-            VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
-                reader.read_u64::<LittleEndian>()? as usize
-            }
+            VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
         };

         let mut metadata = HashMap::new();
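
This u32-versus-u64 length dispatch repeats in read_string, Value::read, and here in Content::read. A hypothetical helper that would factor it out (read_len is not part of the file; it only illustrates the pattern, reusing the file's VersionedMagic and Result types):

```rust
use byteorder::{LittleEndian, ReadBytesExt};

// GGUF v1 stores lengths and counts as u32, v2 as u64.
fn read_len<R: std::io::Read>(reader: &mut R, magic: &VersionedMagic) -> Result<usize> {
    let len = match magic {
        VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
        VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
    };
    Ok(len)
}
```
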
@@ -417,7 +413,7 @@ impl Content {
                 reader.read_u32_into::<LittleEndian>(&mut dimensions)?;
                 dimensions.into_iter().map(|c| c as usize).collect()
             }
-            VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
+            VersionedMagic::GgufV2 => {
                 let mut dimensions = vec![0; n_dimensions as usize];
                 reader.read_u64_into::<LittleEndian>(&mut dimensions)?;
                 dimensions.into_iter().map(|c| c as usize).collect()
@@ -460,12 +456,13 @@ impl Content {
         &self,
         reader: &mut R,
         name: &str,
+        device: &Device,
     ) -> Result<QTensor> {
         let tensor_info = match self.tensor_infos.get(name) {
             Some(tensor_info) => tensor_info,
             None => crate::bail!("cannot find tensor-infor for {name}"),
         };
-        tensor_info.read(reader, self.tensor_data_offset)
+        tensor_info.read(reader, self.tensor_data_offset, device)
     }
 }

@@ -14,6 +14,7 @@ pub mod utils;
 pub use k_quants::GgmlType;

 pub struct QTensor {
+    device: Device,
     data: Box<dyn QuantizedType>,
     shape: Shape,
 }
@@ -170,17 +171,20 @@ impl QTensor {
     pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
         data: Vec<T>,
         shape: S,
+        device: &Device,
     ) -> Result<Self> {
         let shape = shape.into();
         check_shape::<T>(&shape)?;
         Ok(Self {
             data: Box::new(data),
             shape,
+            device: device.clone(),
         })
     }

     pub fn quantize<T: k_quants::GgmlType + Send + Sync + 'static>(src: &Tensor) -> Result<Self> {
         let shape = src.shape();
+        let device = src.device();
         check_shape::<T>(shape)?;
         let src = src
             .to_dtype(crate::DType::F32)?
|
|||||||
Ok(Self {
|
Ok(Self {
|
||||||
data: Box::new(data),
|
data: Box::new(data),
|
||||||
shape: shape.clone(),
|
shape: shape.clone(),
|
||||||
|
device: device.clone(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -212,7 +217,12 @@ impl QTensor {
         &self.shape
     }

+    pub fn device(&self) -> &Device {
+        &self.device
+    }
+
     pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
+        // TODO Skip the CPU part on metal
         let mut f32_data = vec![0f32; self.shape.elem_count()];
         self.data.to_float(&mut f32_data)?;
         Tensor::from_vec(f32_data, &self.shape, device)
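
A usage sketch against the signatures as changed above (f32 implements GgmlType with a block size of 1, so any shape passes check_shape; assumes the published candle_core crate name):

```rust
use candle_core::quantized::QTensor;
use candle_core::{Device, Result, Tensor};

// Build a QTensor with the new device-aware constructor, then dequantize it
// back into a regular tensor on the same device.
fn roundtrip() -> Result<Tensor> {
    let device = Device::Cpu;
    let qt = QTensor::new(vec![0f32; 64], (8, 8), &device)?;
    qt.dequantize(&device)
}
```
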
@@ -305,6 +315,49 @@ impl crate::CustomOp1 for QTensor {
         )?;
         Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
     }
+
+    fn metal_fwd(
+        &self,
+        storage: &crate::MetalStorage,
+        layout: &crate::Layout,
+    ) -> Result<(crate::MetalStorage, Shape)> {
+        println!("TODO qmatmul");
+        if !layout.is_contiguous() {
+            crate::bail!("input tensor is not contiguous {layout:?}")
+        }
+        let src_shape = layout.shape();
+        // self is transposed so n is first then k.
+        let (n, k) = self.shape.dims2()?;
+        if src_shape.rank() < 2 {
+            crate::bail!("input tensor has only one dimension {layout:?}")
+        }
+        let mut dst_shape = src_shape.dims().to_vec();
+        let last_k = dst_shape.pop().unwrap();
+        if last_k != k {
+            crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
+        }
+        dst_shape.push(n);
+        let dst_shape = Shape::from(dst_shape);
+        // let storage = storage.as_slice::<f32>()?;
+        // let storage =
+        //     &storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
+        let dst_storage = vec![0f32; dst_shape.elem_count()];
+        // self.matmul_t(
+        //     (dst_shape.elem_count() / n, k, n),
+        //     storage,
+        //     &mut dst_storage,
+        // )?;
+        let cpu_storage = crate::CpuStorage::F32(dst_storage);
+        use crate::backend::{BackendDevice, BackendStorage};
+        if let Device::Metal(device) = &self.device{
+            Ok((
+                device.storage_from_cpu_storage(&cpu_storage)?,
+                dst_shape,
+            ))
+        }else{
+            crate::bail!("qtensor not on metal device")
+        }
+    }
 }

 impl QMatMul {

@@ -334,33 +334,6 @@ impl Storage {
        }
    }

-   pub(crate) fn conv_transpose1d(
-       &self,
-       l: &Layout,
-       kernel: &Self,
-       kernel_l: &Layout,
-       params: &crate::conv::ParamsConvTranspose1D,
-   ) -> Result<Self> {
-       self.same_device(kernel, "conv-transpose1d")?;
-       self.same_dtype(kernel, "conv-transpose1d")?;
-       match (self, &kernel) {
-           (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
-               let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
-               Ok(Self::Cpu(s))
-           }
-           (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
-               let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
-               Ok(Self::Cuda(s))
-           }
-           (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
-               lhs: lhs.device().location(),
-               rhs: rhs.device().location(),
-               op: "conv-transpose1d",
-           }
-           .bt()),
-       }
-   }
-
    pub(crate) fn conv2d(
        &self,
        l: &Layout,
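The removed `conv_transpose1d` dispatch follows the usual transposed-convolution length rule. For reference, a hedged sketch of that arithmetic (the formula follows the PyTorch convention; it is not taken from the diff):

```rust
// l_out = (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1)
//         + output_padding + 1
fn conv_transpose1d_out_len(
    l_in: usize,
    k_size: usize,
    padding: usize,
    output_padding: usize,
    stride: usize,
    dilation: usize,
) -> usize {
    (l_in - 1) * stride + dilation * (k_size - 1) + output_padding + 1 - 2 * padding
}
```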
@@ -477,12 +477,6 @@ impl Tensor {
    broadcast_binary_op!(broadcast_div, div);
    broadcast_binary_op!(broadcast_maximum, maximum);
    broadcast_binary_op!(broadcast_minimum, minimum);
-   broadcast_binary_op!(broadcast_eq, eq);
-   broadcast_binary_op!(broadcast_ne, ne);
-   broadcast_binary_op!(broadcast_lt, lt);
-   broadcast_binary_op!(broadcast_le, le);
-   broadcast_binary_op!(broadcast_gt, gt);
-   broadcast_binary_op!(broadcast_ge, ge);

    unary_op!(recip, Recip);
    unary_op!(neg, Neg);
@@ -856,20 +850,6 @@ impl Tensor {
        self.sum_impl(mean_dims, false)? * scale
    }

-   /// Returns the unbiased variance over the selected dimension.
-   pub fn var_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
-       let dim = dim.to_index(self.shape(), "var")?;
-       let mean = self.mean_keepdim(dim)?;
-       let squares = self.broadcast_sub(&mean)?.sqr()?;
-       squares.sum_impl(dim, true)? / (self.dim(dim)? - 1) as f64
-   }
-
-   /// Returns the unbiased variance over the selected dimension.
-   pub fn var<D: Dim>(&self, dim: D) -> Result<Self> {
-       let dim = dim.to_index(self.shape(), "var")?;
-       self.var_keepdim(dim)?.squeeze(dim)
-   }
-
    /// Gathers the maximum value across the selected dimension. The resulting shape has the same
    /// number of dimensions as the original tensor and the select dimension has a single element.
    pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
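The removed `var_keepdim` computes the unbiased sample variance, `sum((x - mean)^2) / (n - 1)`. A scalar sketch of the same quantity for reference:

```rust
// Unbiased sample variance of a slice, the per-dimension quantity the
// removed `var_keepdim` produced: sum((x - mean)^2) / (n - 1).
fn var_unbiased(xs: &[f32]) -> f32 {
    let n = xs.len() as f32;
    let mean = xs.iter().sum::<f32>() / n;
    xs.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / (n - 1.0)
}
```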
@@ -1831,23 +1811,17 @@ impl Tensor {

    /// Returns a new tensor detached from the current graph, gradient are not propagated through
    /// this new node. The storage of this tensor is shared with the initial tensor.
-   ///
-   /// If the tensor is already detached from the computation graph, the same tensor is returned.
    pub fn detach(&self) -> Result<Tensor> {
-       if self.op.is_none() && !self.is_variable {
-           Ok(self.clone())
-       } else {
-           let tensor_ = Tensor_ {
-               id: TensorId::new(),
-               storage: self.storage.clone(),
-               layout: self.layout.clone(),
-               op: BackpropOp::none(),
-               is_variable: false,
-               dtype: self.dtype,
-               device: self.device.clone(),
-           };
-           Ok(Tensor(Arc::new(tensor_)))
-       }
+       let tensor_ = Tensor_ {
+           id: TensorId::new(),
+           storage: self.storage.clone(),
+           layout: self.layout.clone(),
+           op: BackpropOp::none(),
+           is_variable: false,
+           dtype: self.dtype,
+           device: self.device.clone(),
+       };
+       Ok(Tensor(Arc::new(tensor_)))
    }

    /// If the target device is the same as the tensor device, only a shallow copy is performed.
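After this change `detach` always builds a fresh node with `BackpropOp::none()` while sharing storage, instead of short-circuiting for already-detached tensors. A small sketch of the observable behavior, assuming the public `Var`/`backward` API used in the tests elsewhere in this diff:

```rust
use candle::{Device, Result, Var};

fn detach_demo() -> Result<()> {
    let v = Var::new(&[1f32, 2., 3.], &Device::Cpu)?;
    // `detach` shares storage but drops the backprop op, so gradients no
    // longer flow through `d` back to `v`.
    let d = v.as_tensor().detach()?;
    let loss = (d * 2.)?.sum_all()?;
    let grads = loss.backward()?;
    assert!(grads.get(&v).is_none());
    Ok(())
}
```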
@@ -1859,14 +1833,7 @@ impl Tensor {
            (Storage::Cpu(storage), Device::Cuda(cuda)) => {
                Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
            }
-           (Storage::Cpu(storage), Device::Metal(metal)) => {
-               Storage::Metal(metal.storage_from_cpu_storage(storage)?)
-           }
            (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
-           (Storage::Metal(storage), Device::Cpu) => {
-               println!("{storage:?} - {:?}", storage.to_cpu_storage()?);
-               Storage::Cpu(storage.to_cpu_storage()?)
-           }
            (Storage::Cuda(storage), Device::Cuda(cuda)) => {
                // TODO: Avoid passing through the cpu storage here, especially if the gpu ids
                // are the same.
@@ -2440,23 +2407,6 @@ impl Tensor {
    ) -> Result<Self> {
        self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
    }

-   /// Normalize a 'relative' axis value: positive values are kept, negative
-   /// values means counting the dimensions from the back.
-   pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
-       let rank = self.rank() as i64;
-       if rank <= axis {
-           crate::bail!("axis {axis} is too large, tensor rank {rank}")
-       } else if 0 <= axis {
-           Ok(axis as usize)
-       } else {
-           let naxis = rank + axis;
-           if naxis < 0 {
-               crate::bail!("axis {axis} is too small, tensor rank {rank}")
-           }
-           Ok(naxis as usize)
-       }
-   }
}

macro_rules! bin_trait {
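For reference, the removed axis normalization as a standalone function with the same semantics (`-1` maps to the last dimension, `-2` to the one before it, and so on):

```rust
// Standalone version of the removed `normalize_axis` logic.
fn normalize_axis(axis: i64, rank: usize) -> Result<usize, String> {
    let rank = rank as i64;
    if rank <= axis {
        Err(format!("axis {axis} is too large, tensor rank {rank}"))
    } else if 0 <= axis {
        Ok(axis as usize)
    } else {
        let naxis = rank + axis;
        if naxis < 0 {
            return Err(format!("axis {axis} is too small, tensor rank {rank}"));
        }
        Ok(naxis as usize)
    }
}
```

For a rank-3 tensor, `normalize_axis(-1, 3)` returns `Ok(2)` while `normalize_axis(3, 3)` is an error.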
@@ -4,7 +4,7 @@ use crate::{Result, Tensor};
macro_rules! test_device {
    // TODO: Switch to generating the two last arguments automatically once concat_idents is
    // stable. https://github.com/rust-lang/rust/issues/29599
-   ($fn_name: ident, $test_cpu: ident, $test_cuda: ident, $test_metal: ident) => {
+   ($fn_name: ident, $test_cpu: ident, $test_cuda: ident) => {
        #[test]
        fn $test_cpu() -> Result<()> {
            $fn_name(&Device::Cpu)
@@ -15,12 +15,6 @@ macro_rules! test_device {
        fn $test_cuda() -> Result<()> {
            $fn_name(&Device::new_cuda(0)?)
        }
-
-       #[cfg(feature = "metal")]
-       #[test]
-       fn $test_metal() -> Result<()> {
-           $fn_name(&Device::new_metal(0)?)
-       }
    };
}
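With the metal arm gone, an invocation such as `test_device!(sum, sum_cpu, sum_gpu);` expands to roughly the following (a sketch; the exact attributes come from the macro body above):

```rust
// Approximate expansion of `test_device!(sum, sum_cpu, sum_gpu)`.
#[test]
fn sum_cpu() -> Result<()> {
    sum(&Device::Cpu)
}

#[cfg(feature = "cuda")]
#[test]
fn sum_gpu() -> Result<()> {
    sum(&Device::new_cuda(0)?)
}
```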
@@ -13,11 +13,6 @@ res = torch.nn.functional.conv1d(t, w)
print(res.flatten())
res = torch.nn.functional.conv1d(t, w, padding=1)
print(res.flatten())
-
-w_t = w.transpose(0, 1)
-res = torch.nn.functional.conv_transpose1d(t, w_t)
-print(res.shape)
-print(res)
*/
fn conv1d(dev: &Device) -> Result<()> {
    let t = Tensor::new(
@@ -50,17 +45,6 @@ fn conv1d(dev: &Device) -> Result<()> {
        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
        [2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
    );
-   if dev.is_cpu() {
-       let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
-       assert_eq!(res.dims(), [1, 2, 7]);
-       assert_eq!(
-           test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
-           [
-               0.0699, -1.2899, 8.3018, 5.5873, 2.4572, -2.6143, -0.0706, 1.8765, 4.8318, 1.1538,
-               4.7076, -5.9745, -0.8276, 1.621
-           ],
-       );
-   }
    Ok(())
}
@@ -563,35 +547,14 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
    Ok(())
}

-test_device!(conv1d, conv1d_cpu, conv1d_gpu, conv1d_metal);
-test_device!(
-   conv1d_small,
-   conv1d_small_cpu,
-   conv1d_small_gpu,
-   conv1d_small_metal
-);
-test_device!(conv2d, conv2d_cpu, conv2d_gpu, conv2d_metal);
+test_device!(conv1d, conv1d_cpu, conv1d_gpu);
+test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu);
+test_device!(conv2d, conv2d_cpu, conv2d_gpu);
test_device!(
    conv2d_non_square,
    conv2d_non_square_cpu,
-   conv2d_non_square_gpu,
-   conv2d_non_square_metal
+   conv2d_non_square_gpu
);
-test_device!(
-   conv2d_small,
-   conv2d_small_cpu,
-   conv2d_small_gpu,
-   conv2d_small_metal
-);
-test_device!(
-   conv2d_smaller,
-   conv2d_smaller_cpu,
-   conv2d_smaller_gpu,
-   conv2d_smaller_metal
-);
-test_device!(
-   conv2d_grad,
-   conv2d_grad_cpu,
-   conv2d_grad_gpu,
-   conv2_grad_metal
-);
+test_device!(conv2d_small, conv2d_small_cpu, conv2d_small_gpu);
+test_device!(conv2d_smaller, conv2d_smaller_cpu, conv2d_smaller_gpu);
+test_device!(conv2d_grad, conv2d_grad_cpu, conv2d_grad_gpu);
@@ -205,71 +205,6 @@ fn unary_grad(device: &Device) -> Result<()> {
        test_utils::to_vec1_round(grad_x, 4)?,
        [1.0116, 1.0830, 1.0003, 0.6188],
    );
-
-   // Testing compared to pytorch torch.erf
-   //
-   // import torch
-   // x = torch.tensor([3.0, 1.0, 4.0, 0.15], requires_grad=True)
-   // y = x.erf()
-   // print(y)
-   // loss = y.sum()
-   // loss.backward()
-   // print(x.grad)
-   let y = x.erf()?;
-   let grads = y.backward()?;
-   let grad_x = grads.get(&x).context("no grad for x")?;
-   assert_eq!(test_utils::to_vec1_round(&y, 4)?, [1.0, 0.8427, 1.0, 0.168]);
-   assert_eq!(
-       test_utils::to_vec1_round(grad_x, 4)?,
-       [0.0001, 0.4151, 0.0, 1.1033],
-   );
-
-   // Testing compared to pytorch nn.GELU(approximate = 'none')
-   //
-   // import torch
-   // import torch.nn.functional as F
-   // x = torch.tensor([3.0, 1.0, 4.0, 0.15], requires_grad=True)
-   // y = F.gelu(x, approximate='none')
-   // print(y)
-   // loss = y.sum()
-   // loss.backward()
-   // print(x.grad)
-   let y = x.gelu_erf()?;
-   let grads = y.backward()?;
-   let grad_x = grads.get(&x).context("no grad for x")?;
-   assert_eq!(
-       test_utils::to_vec1_round(&y, 4)?,
-       [2.9960, 0.8413, 3.9999, 0.0839]
-   );
-   assert_eq!(
-       test_utils::to_vec1_round(grad_x, 4)?,
-       [1.0119, 1.0833, 1.0005, 0.6188],
-   );
-
-   // Testing compared to pytorch elu
-   //
-   // import torch
-   // import torch.nn.functional as F
-   // x = torch.tensor([-1.0, 0.0, -2.0, 3.0], requires_grad=True)
-   // y = F.elu(x, alpha=2.0)
-   // print(y)
-   // loss = y.min
-   // loss = y.sum()
-   // loss.backward()
-   // print(x.grad)
-   let elu_x = Var::new(&[-1.0f32, 0., -2., 3.], device)?;
-   let y = elu_x.elu(2.)?;
-   let grads = y.backward()?;
-   let grad_x = grads.get(&elu_x).context("no grad for x")?;
-   assert_eq!(
-       test_utils::to_vec1_round(&y, 4)?,
-       [-1.2642, 0.0000, -1.7293, 3.0000]
-   );
-   assert_eq!(
-       test_utils::to_vec1_round(grad_x, 4)?,
-       [0.7358, 2.0000, 0.2707, 1.0000]
-   );
-
    Ok(())
}
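The removed erf test checks the analytic gradient `d/dx erf(x) = 2/sqrt(pi) * exp(-x^2)`. A standalone check of the numbers that were asserted above:

```rust
// d/dx erf(x) = 2/sqrt(pi) * exp(-x^2); at x = 1.0 this is ~0.4151 and at
// x = 0.15 it is ~1.1033, matching entries of the removed assertion.
fn erf_grad(x: f64) -> f64 {
    2.0 / std::f64::consts::PI.sqrt() * (-x * x).exp()
}

#[test]
fn erf_grad_matches_removed_test() {
    assert!((erf_grad(1.0) - 0.4151).abs() < 1e-4);
    assert!((erf_grad(0.15) - 1.1033).abs() < 1e-4);
}
```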
@@ -315,29 +250,9 @@ fn binary_grad(device: &Device) -> Result<()> {
    Ok(())
}

-test_device!(
-   simple_grad,
-   simple_grad_cpu,
-   simple_grad_gpu,
-   simple_grad_metal
-);
-test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu, sum_grad_metal);
-test_device!(
-   matmul_grad,
-   matmul_grad_cpu,
-   matmul_grad_gpu,
-   matmul_grad_metal
-);
-test_device!(
-   grad_descent,
-   grad_descent_cpu,
-   grad_descent_gpu,
-   grad_descent_metal
-);
-test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu, unary_grad_metal);
-test_device!(
-   binary_grad,
-   binary_grad_cpu,
-   binary_grad_gpu,
-   binary_grad_metal
-);
+test_device!(simple_grad, simple_grad_cpu, simple_grad_gpu);
+test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu);
+test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu);
+test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu);
+test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu);
+test_device!(binary_grad, binary_grad_cpu, binary_grad_gpu);
@@ -49,7 +49,7 @@ fn contiguous(device: &Device) -> Result<()> {
    Ok(())
}

-test_device!(contiguous, contiguous_cpu, contiguous_gpu, contiguous_metal);
+test_device!(contiguous, contiguous_cpu, contiguous_gpu);

#[test]
fn strided_blocks() -> Result<()> {
@@ -98,17 +98,15 @@ fn upsample_nearest2d(dev: &Device) -> Result<()> {
    Ok(())
}

-test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu, avg_pool2d_metal);
+test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu);
test_device!(
    avg_pool2d_pytorch,
    avg_pool2d_pytorch_cpu,
-   avg_pool2d_pytorch_gpu,
-   avg_pool2d_pytorch_metal
+   avg_pool2d_pytorch_gpu
);
-test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu, max_pool2d_metal);
+test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu);
test_device!(
    upsample_nearest2d,
    upsample_nearest2d_cpu,
-   upsample_nearest2d_gpu,
-   upsample_nearest2d_metal
+   upsample_nearest2d_gpu
);
@@ -180,22 +180,6 @@ fn transpose(device: &Device) -> Result<()> {
    Ok(())
}

-fn var(device: &Device) -> Result<()> {
-   // Values taken from https://pytorch.org/docs/stable/generated/torch.var.html
-   let data = &[
-       [0.2035f32, 1.2959, 1.8101, -0.4644],
-       [1.5027, -0.3270, 0.5905, 0.6538],
-       [-1.5745, 1.3330, -0.5596, -0.6548],
-       [0.1264, -0.5080, 1.6420, 0.1992],
-   ];
-   let tensor = Tensor::new(data, device)?;
-   assert_eq!(
-       test_utils::to_vec2_round(&tensor.var_keepdim(1)?, 4)?,
-       &[[1.0631], [0.559], [1.4893], [0.8258]]
-   );
-   Ok(())
-}
-
fn sum(device: &Device) -> Result<()> {
    let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
    let tensor = Tensor::new(data, device)?;
@@ -1070,60 +1054,34 @@ fn randn(device: &Device) -> Result<()> {
    Ok(())
}

-test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
-test_device!(ones, ones_cpu, ones_gpu, ones_metal);
-test_device!(arange, arange_cpu, arange_gpu, arange_metal);
-test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
-test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
-test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal);
-test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal);
-test_device!(cat, cat_cpu, cat_gpu, cat_metal);
-test_device!(sum, sum_cpu, sum_gpu, sum_metal);
-test_device!(min, min_cpu, min_gpu, min_metal);
-test_device!(max, max_cpu, max_gpu, max_metal);
-test_device!(argmax, argmax_cpu, argmax_gpu, argmax_metal);
-test_device!(argmin, argmin_cpu, argmin_gpu, argmin_metal);
-test_device!(transpose, transpose_cpu, transpose_gpu, transpose_metal);
-test_device!(unary_op, unary_op_cpu, unary_op_gpu, unary_op_metal);
-test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal);
-test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal);
-test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal);
-test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
-test_device!(
-   broadcast_matmul,
-   broadcast_matmul_cpu,
-   broadcast_matmul_gpu,
-   broadcast_matmul_metal
-);
-test_device!(
-   broadcasting,
-   broadcasting_cpu,
-   broadcasting_gpu,
-   broadcasting_metal
-);
-test_device!(
-   index_select,
-   index_select_cpu,
-   index_select_gpu,
-   index_select_metal
-);
-test_device!(index_add, index_add_cpu, index_add_gpu, index_add_metal);
-test_device!(gather, gather_cpu, gather_gpu, gather_metal);
-test_device!(
-   scatter_add,
-   scatter_add_cpu,
-   scatter_add_gpu,
-   scatter_add_metal
-);
-test_device!(
-   slice_scatter,
-   slice_scatter_cpu,
-   slice_scatter_gpu,
-   slice_scatter_metal
-);
-test_device!(randn, randn_cpu, randn_gpu, randn_metal);
-test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
-test_device!(var, var_cpu, var_gpu, var_metal);
+test_device!(zeros, zeros_cpu, zeros_gpu);
+test_device!(ones, ones_cpu, ones_gpu);
+test_device!(arange, arange_cpu, arange_gpu);
+test_device!(add_mul, add_mul_cpu, add_mul_gpu);
+test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
+test_device!(narrow, narrow_cpu, narrow_gpu);
+test_device!(broadcast, broadcast_cpu, broadcast_gpu);
+test_device!(cat, cat_cpu, cat_gpu);
+test_device!(sum, sum_cpu, sum_gpu);
+test_device!(min, min_cpu, min_gpu);
+test_device!(max, max_cpu, max_gpu);
+test_device!(argmax, argmax_cpu, argmax_gpu);
+test_device!(argmin, argmin_cpu, argmin_gpu);
+test_device!(transpose, transpose_cpu, transpose_gpu);
+test_device!(unary_op, unary_op_cpu, unary_op_gpu);
+test_device!(binary_op, binary_op_cpu, binary_op_gpu);
+test_device!(embeddings, embeddings_cpu, embeddings_gpu);
+test_device!(cmp, cmp_cpu, cmp_gpu);
+test_device!(matmul, matmul_cpu, matmul_gpu);
+test_device!(broadcast_matmul, broadcast_matmul_cpu, broadcast_matmul_gpu);
+test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu);
+test_device!(index_select, index_select_cpu, index_select_gpu);
+test_device!(index_add, index_add_cpu, index_add_gpu);
+test_device!(gather, gather_cpu, gather_gpu);
+test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
+test_device!(slice_scatter, slice_scatter_cpu, slice_scatter_gpu);
+test_device!(randn, randn_cpu, randn_gpu);
+test_device!(clamp, clamp_cpu, clamp_gpu);

// There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381
@@ -4,9 +4,7 @@
//! <https://www.cs.toronto.edu/~kriz/cifar.html>
//! The binary version of the dataset is used.
use crate::vision::Dataset;
-use candle::{DType, Device, Error, Result, Tensor};
-use hf_hub::{api::sync::Api, Repo, RepoType};
-use parquet::file::reader::{FileReader, SerializedFileReader};
+use candle::{DType, Device, Result, Tensor};
use std::fs::File;
use std::io::{BufReader, Read};
@@ -62,58 +60,3 @@ pub fn load_dir<T: AsRef<std::path::Path>>(dir: T) -> Result<Dataset> {
        labels: 10,
    })
}
-
-fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor, Tensor)> {
-   let samples = parquet.metadata().file_metadata().num_rows() as usize;
-   let mut buffer_images: Vec<u8> = Vec::with_capacity(samples * 1_024);
-   let mut buffer_labels: Vec<u8> = Vec::with_capacity(samples);
-   for row in parquet.into_iter().flatten() {
-       for (_name, field) in row.get_column_iter() {
-           if let parquet::record::Field::Group(subrow) = field {
-               for (_name, field) in subrow.get_column_iter() {
-                   if let parquet::record::Field::Bytes(value) = field {
-                       let image = image::load_from_memory(value.data()).unwrap();
-                       buffer_images.extend(image.to_rgb8().as_raw());
-                   }
-               }
-           } else if let parquet::record::Field::Long(label) = field {
-               buffer_labels.push(*label as u8);
-           }
-       }
-   }
-   let images = (Tensor::from_vec(buffer_images, (samples, 3, 32, 32), &Device::Cpu)?
-       .to_dtype(DType::U8)?
-       / 255.)?;
-   let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
-   Ok((images, labels))
-}
-
-pub fn load() -> Result<Dataset> {
-   let api = Api::new().map_err(|e| Error::Msg(format!("Api error: {e}")))?;
-   let dataset_id = "cifar10".to_string();
-   let repo = Repo::with_revision(
-       dataset_id,
-       RepoType::Dataset,
-       "refs/convert/parquet".to_string(),
-   );
-   let repo = api.repo(repo);
-   let test_parquet_filename = repo
-       .get("plain_text/test/0000.parquet")
-       .map_err(|e| Error::Msg(format!("Api error: {e}")))?;
-   let train_parquet_filename = repo
-       .get("plain_text/train/0000.parquet")
-       .map_err(|e| Error::Msg(format!("Api error: {e}")))?;
-   let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)
-       .map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
-   let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)
-       .map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
-   let (test_images, test_labels) = load_parquet(test_parquet)?;
-   let (train_images, train_labels) = load_parquet(train_parquet)?;
-   Ok(crate::vision::Dataset {
-       train_images,
-       train_labels,
-       test_images,
-       test_labels,
-       labels: 10,
-   })
-}
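With the hub-based parquet loader gone, only the local `load_dir` path remains. A hedged consumption sketch (the module path `candle_datasets::vision::cifar` is an assumption based on this crate's layout; the `Dataset` fields are the ones shown above):

```rust
// Hypothetical caller of the remaining loader; adjust the path to wherever
// the CIFAR-10 binary files live.
fn show_dataset_shapes() -> candle::Result<()> {
    let ds = candle_datasets::vision::cifar::load_dir("data/cifar-10-batches-bin")?;
    println!("train images: {:?}", ds.train_images.shape());
    println!("train labels: {:?}", ds.train_labels.shape());
    println!("classes: {}", ds.labels);
    Ok(())
}
```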
@@ -16,7 +16,6 @@ candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
-candle-onnx = { path = "../candle-onnx", version = "0.3.0", optional = true }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
image = { workspace = true }
@@ -52,12 +51,11 @@ anyhow = { workspace = true }
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
+metal = ["candle/metal", "candle-nn/metal", "candle-transformers/metal"]
cudnn = ["candle/cudnn"]
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
nccl = ["cuda", "cudarc/nccl", "dep:half"]
-onnx = ["candle-onnx"]
-metal = ["candle/metal", "candle-nn/metal"]

[[example]]
name = "llama_multiprocess"
@@ -66,11 +64,3 @@ required-features = ["cuda", "nccl", "flash-attn"]
[[example]]
name = "reinforcement-learning"
required-features = ["pyo3"]
-
-[[example]]
-name = "onnx"
-required-features = ["onnx"]
-
-[[example]]
-name = "onnx_basics"
-required-features = ["onnx"]
@@ -8,7 +8,6 @@ use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder};
#[derive(Debug, Clone, PartialEq)]
enum NormType {
    WeightNorm,
-   TimeGroupNorm,
    None,
}

@@ -269,7 +268,6 @@ impl Module for EncodecConvTranspose1d {
struct EncodecConv1d {
    causal: bool,
    conv: Conv1d,
-   norm: Option<candle_nn::GroupNorm>,
}

impl EncodecConv1d {
@@ -294,7 +292,7 @@ impl EncodecConv1d {
                },
                vb.pp("conv"),
            )?,
-           NormType::None | NormType::TimeGroupNorm => conv1d(
+           NormType::None => conv1d(
                in_c,
                out_c,
                kernel_size,
@@ -307,17 +305,9 @@ impl EncodecConv1d {
                vb.pp("conv"),
            )?,
        };
-       let norm = match cfg.norm_type {
-           NormType::None | NormType::WeightNorm => None,
-           NormType::TimeGroupNorm => {
-               let gn = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
-               Some(gn)
-           }
-       };
        Ok(Self {
            causal: cfg.use_causal_conv,
            conv,
-           norm,
        })
    }
}
@@ -326,10 +316,8 @@ impl Module for EncodecConv1d {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // TODO: padding, depending on causal.
        let xs = self.conv.forward(xs)?;
-       match &self.norm {
-           None => Ok(xs),
-           Some(norm) => xs.apply(norm),
-       }
+       // If we add support for NormType "time_group_norm", we should add some normalization here.
+       Ok(xs)
    }
}
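The removed `TimeGroupNorm` arm built a single-group `GroupNorm` over the output channels. A minimal sketch of that normalization in isolation (the `VarMap`-based builder here is an assumption for creating fresh weights; the encodec code loads them from a checkpoint instead):

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::{GroupNorm, Module, VarBuilder, VarMap};

// One group over `out_c` channels with eps 1e-5, as in the removed arm.
fn time_group_norm_demo(out_c: usize) -> Result<Tensor> {
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
    let gn: GroupNorm = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
    let xs = Tensor::rand(0f32, 1f32, (1, out_c, 16), &Device::Cpu)?;
    gn.forward(&xs)
}
```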
@@ -1,10 +0,0 @@
-## Using ONNX models in Candle
-
-This example demonstrates how to run ONNX based models in Candle, the model
-being used here is a small squeezenet variant.
-
-You can run the example with the following command:
-
-```bash
-cargo run --example squeezenet-onnx --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
-```

@@ -1,78 +0,0 @@
-#[cfg(feature = "mkl")]
-extern crate intel_mkl_src;
-
-#[cfg(feature = "accelerate")]
-extern crate accelerate_src;
-
-use candle::{IndexOp, D};
-use clap::{Parser, ValueEnum};
-
-#[derive(Clone, Copy, Debug, ValueEnum)]
-enum Which {
-   SqueezeNet,
-   EfficientNet,
-}
-
-#[derive(Parser)]
-struct Args {
-   #[arg(long)]
-   image: String,
-
-   #[arg(long)]
-   model: Option<String>,
-
-   /// The model to be used.
-   #[arg(value_enum, long, default_value_t = Which::SqueezeNet)]
-   which: Which,
-}
-
-pub fn main() -> anyhow::Result<()> {
-   let args = Args::parse();
-   let image = candle_examples::imagenet::load_image224(args.image)?;
-   let image = match args.which {
-       Which::SqueezeNet => image,
-       Which::EfficientNet => image.permute((1, 2, 0))?,
-   };
-
-   println!("loaded image {image:?}");
-
-   let model = match args.model {
-       Some(model) => std::path::PathBuf::from(model),
-       None => match args.which {
-           Which::SqueezeNet => hf_hub::api::sync::Api::new()?
-               .model("lmz/candle-onnx".into())
-               .get("squeezenet1.1-7.onnx")?,
-           Which::EfficientNet => hf_hub::api::sync::Api::new()?
-               .model("onnx/EfficientNet-Lite4".into())
-               .get("efficientnet-lite4-11.onnx")?,
-       },
-   };
-
-   let model = candle_onnx::read_file(model)?;
-   let graph = model.graph.as_ref().unwrap();
-   let mut inputs = std::collections::HashMap::new();
-   inputs.insert(graph.input[0].name.to_string(), image.unsqueeze(0)?);
-   let mut outputs = candle_onnx::simple_eval(&model, inputs)?;
-   let output = outputs.remove(&graph.output[0].name).unwrap();
-   let prs = match args.which {
-       Which::SqueezeNet => candle_nn::ops::softmax(&output, D::Minus1)?,
-       Which::EfficientNet => output,
-   };
-   let prs = prs.i(0)?.to_vec1::<f32>()?;
-
-   // Sort the predictions and take the top 5
-   let mut top: Vec<_> = prs.iter().enumerate().collect();
-   top.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
-   let top = top.into_iter().take(5).collect::<Vec<_>>();
-
-   // Print the top predictions
-   for &(i, p) in &top {
-       println!(
-           "{:50}: {:.2}%",
-           candle_examples::imagenet::CLASSES[i],
-           p * 100.0
-       );
-   }
-
-   Ok(())
-}
@@ -1,87 +0,0 @@
-use anyhow::Result;
-use candle::{Device, Tensor};
-
-use clap::{Parser, Subcommand};
-
-#[derive(Subcommand, Debug, Clone)]
-enum Command {
-   Print {
-       #[arg(long)]
-       file: String,
-   },
-   SimpleEval {
-       #[arg(long)]
-       file: String,
-   },
-}
-
-#[derive(Parser, Debug)]
-#[command(author, version, about, long_about = None)]
-pub struct Args {
-   #[command(subcommand)]
-   command: Command,
-}
-
-pub fn main() -> Result<()> {
-   let args = Args::parse();
-   match args.command {
-       Command::Print { file } => {
-           let model = candle_onnx::read_file(file)?;
-           println!("{model:?}");
-           let graph = model.graph.unwrap();
-           for node in graph.node.iter() {
-               println!("{node:?}");
-           }
-       }
-       Command::SimpleEval { file } => {
-           let model = candle_onnx::read_file(file)?;
-           let graph = model.graph.as_ref().unwrap();
-           let constants: std::collections::HashSet<_> =
-               graph.initializer.iter().map(|i| i.name.as_str()).collect();
-           let mut inputs = std::collections::HashMap::new();
-           for input in graph.input.iter() {
-               use candle_onnx::onnx::tensor_proto::DataType;
-               if constants.contains(input.name.as_str()) {
-                   continue;
-               }
-
-               let type_ = input.r#type.as_ref().expect("no type for input");
-               let type_ = type_.value.as_ref().expect("no type.value for input");
-               let value = match type_ {
-                   candle_onnx::onnx::type_proto::Value::TensorType(tt) => {
-                       let dt = match DataType::try_from(tt.elem_type) {
-                           Ok(dt) => match candle_onnx::dtype(dt) {
-                               Some(dt) => dt,
-                               None => {
-                                   anyhow::bail!(
-                                       "unsupported 'value' data-type {dt:?} for {}",
-                                       input.name
-                                   )
-                               }
-                           },
-                           type_ => anyhow::bail!("unsupported input type {type_:?}"),
-                       };
-                       let shape = tt.shape.as_ref().expect("no tensortype.shape for input");
-                       let dims = shape
-                           .dim
-                           .iter()
-                           .map(|dim| match dim.value.as_ref().expect("no dim value") {
-                               candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize),
-                               candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimParam(_) => Ok(42),
-                           })
-                           .collect::<Result<Vec<usize>>>()?;
-                       Tensor::zeros(dims, dt, &Device::Cpu)?
-                   }
-                   type_ => anyhow::bail!("unsupported input type {type_:?}"),
-               };
-               println!("input {}: {value:?}", input.name);
-               inputs.insert(input.name.clone(), value);
-           }
-           let outputs = candle_onnx::simple_eval(&model, inputs)?;
-           for (name, value) in outputs.iter() {
-               println!("output {name}: {value:?}")
-           }
-       }
-   }
-   Ok(())
-}
@@ -1,7 +1,5 @@
# candle-quantized-t5
-
-## Seq2Seq example

This example uses a quantized version of the t5 model.

```bash
@@ -10,8 +8,6 @@ $ cargo run --example quantized-t5 --release -- --prompt "translate to German: A
Eine schöne Kerze.
```

-## Generating Quantized weight files
-
The weight file is automatically retrieved from the hub. It is also possible to
generate quantized weight files from the original safetensors file by using the
`tensor-tools` command line utility via:
@@ -20,11 +16,8 @@ generate quantized weight files from the original safetensors file by using the
$ cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
```

-## Using custom models
-
-To use a different model, specify the `model-id`.
-
-For example, for text editing, you can use quantized [CoEdit models](https://huggingface.co/jbochi/candle-coedit-quantized).
+To use a different model, specify the `model-id`. For example, you can use
+quantized [CoEdit models](https://huggingface.co/jbochi/candle-coedit-quantized).

```bash
$ cargo run --example quantized-t5 --release -- \
@@ -33,7 +26,6 @@ $ cargo run --example quantized-t5 --release -- \
  --temperature 0
...
Although their flight is weak, they run quickly through the tree canopy.
-```

By default, it will look for `model.gguf` and `config.json`, but you can specify
custom local or remote `weight-file` and `config-file`s:
@@ -48,16 +40,3 @@ cargo run --example quantized-t5 --release -- \
...
Note that a storm surge is what forecasters consider a hurricane's most dangerous part.
```
-
-### [MADLAD-400](https://arxiv.org/abs/2309.04662)
-
-MADLAD-400 is a series of multilingual machine translation T5 models trained on 250 billion tokens covering over 450 languages using publicly available data. These models are competitive with significantly larger models.
-
-```bash
-cargo run --example quantized-t5 --release -- \
-  --model-id "jbochi/madlad400-3b-mt" --weight-file "model-q4k.gguf" \
-  --prompt "<2de> How are you, my friend?" \
-  --temperature 0
-...
-Wie geht es dir, mein Freund?
-```
@@ -173,11 +173,7 @@ fn main() -> Result<()> {
        .to_vec();
    let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
    let mut model = builder.build_model()?;
-   let mut output_token_ids = [builder
-       .config
-       .decoder_start_token_id
-       .unwrap_or(builder.config.pad_token_id) as u32]
-   .to_vec();
+   let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
    let temperature = if args.temperature <= 0. {
        None
    } else {
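The change above drops the `decoder_start_token_id` fallback and always seeds decoding with the pad token. For reference, a sketch of the removed selection logic (the field types are approximations of the T5 config):

```rust
// Removed behavior: prefer the configured decoder start token, falling back
// to the pad token; the new code uses the pad token unconditionally.
fn decoder_start(decoder_start_token_id: Option<usize>, pad_token_id: usize) -> u32 {
    decoder_start_token_id.unwrap_or(pad_token_id) as u32
}
```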
|
@ -9,10 +9,9 @@ use std::io::Write;
|
|||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
use candle::quantized::{ggml_file, gguf_file};
|
use candle::quantized::{ggml_file, gguf_file};
|
||||||
use candle::{Device, Tensor};
|
use candle::{Tensor};
|
||||||
use candle_transformers::generation::LogitsProcessor;
|
use candle_transformers::generation::LogitsProcessor;
|
||||||
|
|
||||||
use candle_examples::token_output_stream::TokenOutputStream;
|
|
||||||
use candle_transformers::models::quantized_llama as model;
|
use candle_transformers::models::quantized_llama as model;
|
||||||
use model::ModelWeights;
|
use model::ModelWeights;
|
||||||
|
|
||||||
@ -25,7 +24,7 @@ enum Prompt {
|
|||||||
One(String),
|
One(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
#[derive(Clone, Debug, Copy, ValueEnum)]
|
||||||
enum Which {
|
enum Which {
|
||||||
#[value(name = "7b")]
|
#[value(name = "7b")]
|
||||||
L7b,
|
L7b,
|
||||||
@ -49,10 +48,8 @@ enum Which {
|
|||||||
Mistral7b,
|
Mistral7b,
|
||||||
#[value(name = "7b-mistral-instruct")]
|
#[value(name = "7b-mistral-instruct")]
|
||||||
Mistral7bInstruct,
|
Mistral7bInstruct,
|
||||||
#[value(name = "7b-zephyr-a")]
|
#[value(name = "7b-zephyr")]
|
||||||
Zephyr7bAlpha,
|
Zephyr7b,
|
||||||
#[value(name = "7b-zephyr-b")]
|
|
||||||
Zephyr7bBeta,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Which {
|
impl Which {
|
||||||
@ -67,28 +64,7 @@ impl Which {
|
|||||||
| Self::L7bCode
|
| Self::L7bCode
|
||||||
| Self::L13bCode
|
| Self::L13bCode
|
||||||
| Self::L34bCode => false,
|
| Self::L34bCode => false,
|
||||||
// Zephyr is a fine tuned version of mistral and should be treated in the same way.
|
Self::Mistral7b | Self::Mistral7bInstruct | Self::Zephyr7b => true,
|
||||||
Self::Zephyr7bAlpha
|
|
||||||
| Self::Zephyr7bBeta
|
|
||||||
| Self::Mistral7b
|
|
||||||
| Self::Mistral7bInstruct => true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_zephyr(&self) -> bool {
|
|
||||||
match self {
|
|
||||||
Self::L7b
|
|
||||||
| Self::L13b
|
|
||||||
| Self::L70b
|
|
||||||
| Self::L7bChat
|
|
||||||
| Self::L13bChat
|
|
||||||
| Self::L70bChat
|
|
||||||
| Self::L7bCode
|
|
||||||
| Self::L13bCode
|
|
||||||
| Self::L34bCode
|
|
||||||
| Self::Mistral7b
|
|
||||||
| Self::Mistral7bInstruct => false,
|
|
||||||
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -107,7 +83,7 @@ struct Args {
|
|||||||
prompt: Option<String>,
|
prompt: Option<String>,
|
||||||
|
|
||||||
/// The length of the sample to generate (in tokens).
|
/// The length of the sample to generate (in tokens).
|
||||||
#[arg(short = 'n', long, default_value_t = 1000)]
|
#[arg(short = 'n', long, default_value_t = 100)]
|
||||||
sample_len: usize,
|
sample_len: usize,
|
||||||
|
|
||||||
/// The tokenizer config in json format.
|
/// The tokenizer config in json format.
|
||||||
@ -200,13 +176,10 @@ impl Args {
|
|||||||
"TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
|
"TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
|
||||||
"mistral-7b-instruct-v0.1.Q4_K_S.gguf",
|
"mistral-7b-instruct-v0.1.Q4_K_S.gguf",
|
||||||
),
|
),
|
||||||
Which::Zephyr7bAlpha => (
|
Which::Zephyr7b => (
|
||||||
"TheBloke/zephyr-7B-alpha-GGUF",
|
"TheBloke/zephyr-7B-alpha-GGUF",
|
||||||
"zephyr-7b-alpha.Q4_K_M.gguf",
|
"zephyr-7b-alpha.Q4_K_M.gguf",
|
||||||
),
|
),
|
||||||
Which::Zephyr7bBeta => {
|
|
||||||
("TheBloke/zephyr-7B-beta-GGUF", "zephyr-7b-beta.Q4_K_M.gguf")
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
let api = hf_hub::api::sync::Api::new()?;
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
let api = api.model(repo.to_string());
|
let api = api.model(repo.to_string());
|
||||||
@ -217,6 +190,31 @@ impl Args {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn print_token(next_token: u32, tokenizer: &Tokenizer) {
|
||||||
|
// Extracting the last token as a string is complicated, here we just apply some simple
|
||||||
|
// heuristics as it seems to work well enough for this example. See the following for more
|
||||||
|
// details:
|
||||||
|
// https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
|
||||||
|
if let Some(text) = tokenizer.id_to_token(next_token) {
|
||||||
|
let text = text.replace('▁', " ");
|
||||||
|
let ascii = text
|
||||||
|
.strip_prefix("<0x")
|
||||||
|
.and_then(|t| t.strip_suffix('>'))
|
||||||
|
.and_then(|t| u8::from_str_radix(t, 16).ok());
|
||||||
|
match ascii {
|
||||||
|
None => print!("{text}"),
|
||||||
|
Some(ascii) => {
|
||||||
|
if let Some(chr) = char::from_u32(ascii as u32) {
|
||||||
|
if chr.is_ascii() {
|
||||||
|
print!("{chr}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let _ = std::io::stdout().flush();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn format_size(size_in_bytes: usize) -> String {
|
fn format_size(size_in_bytes: usize) -> String {
|
||||||
if size_in_bytes < 1_000 {
|
if size_in_bytes < 1_000 {
|
||||||
format!("{}B", size_in_bytes)
|
format!("{}B", size_in_bytes)
|
||||||
@ -234,6 +232,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
use tracing_subscriber::prelude::*;
|
use tracing_subscriber::prelude::*;
|
||||||
|
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
|
let mut device = candle_examples::device(false)?;
|
||||||
let temperature = if args.temperature == 0. {
|
let temperature = if args.temperature == 0. {
|
||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
@ -278,10 +277,10 @@ fn main() -> anyhow::Result<()> {
|
|||||||
&format_size(total_size_in_bytes),
|
&format_size(total_size_in_bytes),
|
||||||
start.elapsed().as_secs_f32(),
|
start.elapsed().as_secs_f32(),
|
||||||
);
|
);
|
||||||
ModelWeights::from_gguf(model, &mut file)?
|
ModelWeights::from_gguf(model, &mut file, &device)?
|
||||||
}
|
}
|
||||||
Some("ggml" | "bin") | Some(_) | None => {
|
Some("ggml" | "bin") | Some(_) | None => {
|
||||||
let model = ggml_file::Content::read(&mut file)?;
|
let model = ggml_file::Content::read(&mut file, &device)?;
|
||||||
let mut total_size_in_bytes = 0;
|
let mut total_size_in_bytes = 0;
|
||||||
for (_, tensor) in model.tensors.iter() {
|
for (_, tensor) in model.tensors.iter() {
|
||||||
let elem_count = tensor.shape().elem_count();
|
let elem_count = tensor.shape().elem_count();
|
||||||
@ -305,18 +304,16 @@ fn main() -> anyhow::Result<()> {
|
|||||||
| Which::L34bCode => 1,
|
| Which::L34bCode => 1,
|
||||||
Which::Mistral7b
|
Which::Mistral7b
|
||||||
| Which::Mistral7bInstruct
|
| Which::Mistral7bInstruct
|
||||||
| Which::Zephyr7bAlpha
|
| Which::Zephyr7b
|
||||||
| Which::Zephyr7bBeta
|
|
||||||
| Which::L70b
|
| Which::L70b
|
||||||
| Which::L70bChat => 8,
|
| Which::L70bChat => 8,
|
||||||
};
|
};
|
||||||
ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
|
ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa), &device)?
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
println!("model built");
|
println!("model built");
|
||||||
|
|
||||||
let tokenizer = args.tokenizer()?;
|
let tokenizer = args.tokenizer()?;
|
||||||
let mut tos = TokenOutputStream::new(tokenizer);
|
|
||||||
let prompt = match args.prompt.as_deref() {
|
let prompt = match args.prompt.as_deref() {
|
||||||
Some("chat") => Prompt::Chat,
|
Some("chat") => Prompt::Chat,
|
||||||
Some("interactive") => Prompt::Interactive,
|
Some("interactive") => Prompt::Interactive,
|
||||||
@ -325,11 +322,10 @@ fn main() -> anyhow::Result<()> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let mut pre_prompt_tokens = vec![];
|
let mut pre_prompt_tokens = vec![];
|
||||||
for prompt_index in 0.. {
|
loop {
|
||||||
let prompt_str = match &prompt {
|
let prompt_str = match &prompt {
|
||||||
Prompt::One(prompt) => prompt.clone(),
|
Prompt::One(prompt) => prompt.clone(),
|
||||||
Prompt::Interactive | Prompt::Chat => {
|
Prompt::Interactive | Prompt::Chat => {
|
||||||
let is_interactive = matches!(prompt, Prompt::Interactive);
|
|
||||||
print!("> ");
|
print!("> ");
|
||||||
std::io::stdout().flush()?;
|
std::io::stdout().flush()?;
|
||||||
let mut prompt = String::new();
|
let mut prompt = String::new();
|
||||||
@ -340,13 +336,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
prompt.pop();
|
prompt.pop();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if args.which.is_zephyr() {
|
if args.which.is_mistral() {
|
||||||
if prompt_index == 0 || is_interactive {
|
|
||||||
format!("<|system|>\n</s>\n<|user|>\n{prompt}</s>\n<|assistant|>",)
|
|
||||||
} else {
|
|
||||||
format!("<|user|>\n{prompt}</s>\n<|assistant|>")
|
|
||||||
}
|
|
||||||
} else if args.which.is_mistral() {
|
|
||||||
format!("[INST] {prompt} [/INST]")
|
format!("[INST] {prompt} [/INST]")
|
||||||
} else {
|
} else {
|
||||||
prompt
|
prompt
|
||||||
@ -354,8 +344,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
print!("{}", &prompt_str);
|
print!("{}", &prompt_str);
|
||||||
let tokens = tos
|
let tokens = tokenizer
|
||||||
.tokenizer()
|
|
||||||
.encode(prompt_str, true)
|
.encode(prompt_str, true)
|
||||||
.map_err(anyhow::Error::msg)?;
|
.map_err(anyhow::Error::msg)?;
|
||||||
if args.verbose_prompt {
|
if args.verbose_prompt {
|
||||||
@ -378,51 +367,46 @@ fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
let start_prompt_processing = std::time::Instant::now();
|
let start_prompt_processing = std::time::Instant::now();
|
||||||
let mut next_token = {
|
let mut next_token = {
|
||||||
let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
|
let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
|
||||||
let logits = model.forward(&input, 0)?;
|
let logits = model.forward(&input, 0)?;
|
||||||
let logits = logits.squeeze(0)?;
|
let logits = logits.squeeze(0)?;
|
||||||
|
// TODO Remove this once implementation is finished.
|
||||||
|
let logits = logits.ones_like()?;
|
||||||
logits_processor.sample(&logits)?
|
logits_processor.sample(&logits)?
|
||||||
};
|
};
|
||||||
let prompt_dt = start_prompt_processing.elapsed();
|
let prompt_dt = start_prompt_processing.elapsed();
|
||||||
all_tokens.push(next_token);
|
all_tokens.push(next_token);
|
||||||
if let Some(t) = tos.next_token(next_token)? {
|
print_token(next_token, &tokenizer);
|
||||||
print!("{t}");
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
}
|
|
||||||
|
|
||||||
let eos_token = *tos.tokenizer().get_vocab(true).get("</s>").unwrap();
|
let eos_token = *tokenizer.get_vocab(true).get("</s>").unwrap();
|
||||||
|
|
||||||
let start_post_prompt = std::time::Instant::now();
|
let start_post_prompt = std::time::Instant::now();
|
||||||
let mut sampled = 0;
|
|
||||||
for index in 0..to_sample {
|
for index in 0..to_sample {
|
||||||
let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
|
let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
|
||||||
let logits = model.forward(&input, prompt_tokens.len() + index)?;
|
let logits = model.forward(&input, prompt_tokens.len() + index)?;
|
||||||
|
if let candle::Device::Metal(device) = &mut device{
|
||||||
|
device.flush()
|
||||||
|
}
|
||||||
let logits = logits.squeeze(0)?;
|
let logits = logits.squeeze(0)?;
|
||||||
let logits = if args.repeat_penalty == 1. {
|
// let logits = if args.repeat_penalty == 1. {
|
||||||
logits
|
// logits
|
||||||
} else {
|
// } else {
|
||||||
let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
|
// let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
|
||||||
candle_transformers::utils::apply_repeat_penalty(
|
// candle_transformers::utils::apply_repeat_penalty(
|
||||||
&logits,
|
// &logits,
|
||||||
args.repeat_penalty,
|
// args.repeat_penalty,
|
||||||
&all_tokens[start_at..],
|
// &all_tokens[start_at..],
|
||||||
)?
|
// )?
|
||||||
};
|
// };
|
||||||
|
// TODO Remove this once implementation is finished.
|
||||||
|
let logits = logits.ones_like()?;
|
||||||
next_token = logits_processor.sample(&logits)?;
|
next_token = logits_processor.sample(&logits)?;
|
||||||
all_tokens.push(next_token);
|
all_tokens.push(next_token);
|
||||||
if let Some(t) = tos.next_token(next_token)? {
|
print_token(next_token, &tokenizer);
|
||||||
print!("{t}");
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
}
|
|
||||||
sampled += 1;
|
|
||||||
if next_token == eos_token {
|
if next_token == eos_token {
|
||||||
break;
|
break;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
|
|
||||||
print!("{rest}");
|
|
||||||
}
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
let dt = start_post_prompt.elapsed();
|
let dt = start_post_prompt.elapsed();
|
||||||
println!(
|
println!(
|
||||||
"\n\n{:4} prompt tokens processed: {:.2} token/s",
|
"\n\n{:4} prompt tokens processed: {:.2} token/s",
|
||||||
@ -430,8 +414,9 @@ fn main() -> anyhow::Result<()> {
|
|||||||
prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
|
prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
|
||||||
);
|
);
|
||||||
println!(
|
println!(
|
||||||
"{sampled:4} tokens generated: {:.2} token/s",
|
"{:4} tokens generated: {:.2} token/s",
|
||||||
sampled as f64 / dt.as_secs_f64(),
|
to_sample,
|
||||||
|
to_sample as f64 / dt.as_secs_f64(),
|
||||||
);
|
);
|
||||||
|
|
||||||
match prompt {
|
match prompt {
|
||||||
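The `print_token` helper reintroduced above hinges on SentencePiece byte-fallback tokens, which render as `<0xNN>`. The parsing step in isolation, with a quick check:

```rust
// Parse SentencePiece byte-fallback tokens such as "<0x0A>" into the raw byte.
fn byte_fallback(text: &str) -> Option<u8> {
    text.strip_prefix("<0x")
        .and_then(|t| t.strip_suffix('>'))
        .and_then(|t| u8::from_str_radix(t, 16).ok())
}

#[test]
fn byte_fallback_works() {
    assert_eq!(byte_fallback("<0x0A>"), Some(0x0A));
    assert_eq!(byte_fallback("hello"), None);
}
```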
|
@ -5,26 +5,12 @@
|
|||||||
```bash
|
```bash
|
||||||
$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "translate to German: A beautiful candle." --decode
|
$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "translate to German: A beautiful candle." --decode
|
||||||
...
|
...
|
||||||
|
Running on CPU, to run on GPU, build this example with `--features cuda`
|
||||||
Eine schöne Kerze.
|
Eine schöne Kerze.
|
||||||
9 tokens generated (2.42 token/s)
|
9 tokens generated (2.42 token/s)
|
||||||
```
|
```
|
||||||
|
|
||||||
Variants such as [flan-t5](https://huggingface.co/google/flan-t5-small), [flan-ul2](https://huggingface.co/google/flan-ul2) (with `--revision "refs/pr/25"`), and [Co-EdIT](https://huggingface.co/grammarly/coedit-large) are also supported.
|
## Sentence embedding example:
|
||||||
|
|
||||||
## Translation with [MADLAD-400](https://arxiv.org/abs/2309.04662)
|
|
||||||
|
|
||||||
MADLAD-400 is a series of multilingual machine translation T5 models trained on 250 billion tokens covering over 450 languages using publicly available data. These models are competitive with significantly larger models.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cargo run --example t5 --release -- \
|
|
||||||
--model-id "jbochi/madlad400-3b-mt" \
|
|
||||||
--prompt "<2de> How are you, my friend?" \
|
|
||||||
--decode --temperature 0
|
|
||||||
...
|
|
||||||
Wie geht es dir, mein Freund?
|
|
||||||
```
|
|
||||||
|
|
||||||
## Sentence embedding example
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "A beautiful candle."
|
$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "A beautiful candle."
|
||||||
|
@ -104,17 +104,6 @@ impl T5ModelBuilder {
|
|||||||
api.get("model-00004-of-00005.safetensors")?,
|
api.get("model-00004-of-00005.safetensors")?,
|
||||||
api.get("model-00005-of-00005.safetensors")?,
|
api.get("model-00005-of-00005.safetensors")?,
|
||||||
]
|
]
|
||||||
} else if model_id == "google/flan-ul2" {
|
|
||||||
vec![
|
|
||||||
api.get("model-00001-of-00008.safetensors")?,
|
|
||||||
api.get("model-00002-of-00008.safetensors")?,
|
|
||||||
api.get("model-00003-of-00008.safetensors")?,
|
|
||||||
api.get("model-00004-of-00008.safetensors")?,
|
|
||||||
api.get("model-00005-of-00008.safetensors")?,
|
|
||||||
api.get("model-00006-of-00008.safetensors")?,
|
|
||||||
api.get("model-00007-of-00008.safetensors")?,
|
|
||||||
api.get("model-00008-of-00008.safetensors")?,
|
|
||||||
]
|
|
||||||
} else {
|
} else {
|
||||||
vec![api.get("model.safetensors")?]
|
vec![api.get("model.safetensors")?]
|
||||||
};
|
};
|
||||||
@ -183,12 +172,7 @@ fn main() -> Result<()> {
|
|||||||
println!("Took {:?}", start.elapsed());
|
println!("Took {:?}", start.elapsed());
|
||||||
} else {
|
} else {
|
||||||
let mut model = builder.build_conditional_generation()?;
|
let mut model = builder.build_conditional_generation()?;
|
||||||
let mut output_token_ids = [builder
|
let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
|
||||||
.config
|
|
||||||
.decoder_start_token_id
|
|
||||||
.unwrap_or(builder.config.pad_token_id)
|
|
||||||
as u32]
|
|
||||||
.to_vec();
|
|
||||||
if let Some(decoder_prompt) = &args.decoder_prompt {
|
if let Some(decoder_prompt) = &args.decoder_prompt {
|
||||||
print!("{decoder_prompt}");
|
print!("{decoder_prompt}");
|
||||||
output_token_ids.extend(
|
output_token_ids.extend(
|
||||||
|
Binary file not shown.
Before Width: | Height: | Size: 36 KiB |
@@ -1,154 +0,0 @@
-use image::{DynamicImage, ImageBuffer};
-use serde::Deserialize;
-use std::collections::HashMap;
-
-use candle::{DType, Device, Result, Tensor};
-
-#[derive(Debug, Clone, PartialEq, Deserialize)]
-pub struct ProcessorConfig {
-    do_resize: bool,
-    height: u32,
-    width: u32,
-    do_rescale: bool,
-    do_normalize: bool,
-    image_mean: Vec<f32>,
-    image_std: Vec<f32>,
-}
-
-impl Default for ProcessorConfig {
-    fn default() -> Self {
-        Self {
-            do_resize: true,
-            height: 384,
-            width: 384,
-            do_rescale: true,
-            do_normalize: true,
-            image_mean: vec![0.5, 0.5, 0.5],
-            image_std: vec![0.5, 0.5, 0.5],
-        }
-    }
-}
-
-pub struct ViTImageProcessor {
-    do_resize: bool,
-    height: u32,
-    width: u32,
-    do_normalize: bool,
-    image_mean: Vec<f32>,
-    image_std: Vec<f32>,
-}
-
-impl ViTImageProcessor {
-    pub fn new(config: &ProcessorConfig) -> Self {
-        Self {
-            do_resize: config.do_resize,
-            height: config.height,
-            width: config.width,
-            do_normalize: config.do_normalize,
-            image_mean: config.image_mean.clone(),
-            image_std: config.image_std.clone(),
-        }
-    }
-
-    pub fn preprocess(&self, images: Vec<&str>) -> Result<Tensor> {
-        let height = self.height as usize;
-        let width = self.width as usize;
-        let channels = 3;
-
-        let images = self.load_images(images)?;
-
-        let resized_images: Vec<DynamicImage> = if self.do_resize {
-            images
-                .iter()
-                .map(|image| self.resize(image.clone(), None).unwrap())
-                .collect()
-        } else {
-            images
-        };
-
-        let normalized_images: Vec<Tensor> = if self.do_normalize {
-            resized_images
-                .iter()
-                .map(|image| self.normalize(image.clone(), None, None).unwrap())
-                .collect()
-        } else {
-            let resized_images: Vec<ImageBuffer<image::Rgb<u8>, Vec<u8>>> =
-                resized_images.iter().map(|image| image.to_rgb8()).collect();
-            let data = resized_images
-                .into_iter()
-                .map(|image| image.into_raw())
-                .collect::<Vec<Vec<u8>>>();
-
-            data.iter()
-                .map(|image| {
-                    Tensor::from_vec(image.clone(), (height, width, channels), &Device::Cpu)
-                        .unwrap()
-                        .permute((2, 0, 1))
-                        .unwrap()
-                })
-                .collect::<Vec<Tensor>>()
-        };
-
-        Tensor::stack(&normalized_images, 0)
-    }
-
-    fn resize(
-        &self,
-        image: image::DynamicImage,
-        size: Option<HashMap<String, u32>>,
-    ) -> Result<image::DynamicImage> {
-        let (height, width) = match &size {
-            Some(size) => (size.get("height").unwrap(), size.get("width").unwrap()),
-            None => (&self.height, &self.width),
-        };
-
-        let resized_image =
-            image.resize_exact(*width, *height, image::imageops::FilterType::Triangle);
-
-        Ok(resized_image)
-    }
-
-    fn normalize(
-        &self,
-        image: image::DynamicImage,
-        mean: Option<Vec<f32>>,
-        std: Option<Vec<f32>>,
-    ) -> Result<Tensor> {
-        let mean = match mean {
-            Some(mean) => mean,
-            None => self.image_mean.clone(),
-        };
-
-        let std = match std {
-            Some(std) => std,
-            None => self.image_std.clone(),
-        };
-
-        let mean = Tensor::from_vec(mean, (3, 1, 1), &Device::Cpu)?;
-        let std = Tensor::from_vec(std, (3, 1, 1), &Device::Cpu)?;
-
-        let image = image.to_rgb8();
-        let data = image.into_raw();
-
-        let height = self.height as usize;
-        let width = self.width as usize;
-        let channels = 3;
-
-        let data =
-            Tensor::from_vec(data, &[height, width, channels], &Device::Cpu)?.permute((2, 0, 1))?;
-
-        (data.to_dtype(DType::F32)? / 255.)?
-            .broadcast_sub(&mean)?
-            .broadcast_div(&std)
-    }
-
-    pub fn load_images(&self, image_path: Vec<&str>) -> Result<Vec<image::DynamicImage>> {
-        let mut images: Vec<image::DynamicImage> = Vec::new();
-        for path in image_path {
-            let img = image::io::Reader::open(path)?.decode().unwrap();
-            images.push(img);
-        }
-
-        Ok(images)
-    }
-}
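For reference, the arithmetic behind `normalize` in the deleted preprocessor is the standard per-channel transform: scale the `u8` pixel into `[0, 1]`, subtract the channel mean, divide by the channel std. A scalar sketch of what the broadcast tensor ops compute, using the defaults from `ProcessorConfig::default()`:

```rust
// Scalar version of the per-channel normalization done by the deleted
// preprocessor; mean and std both default to 0.5.
fn normalize_pixel(v: u8, mean: f32, std: f32) -> f32 {
    (v as f32 / 255.0 - mean) / std
}

fn main() {
    assert_eq!(normalize_pixel(255, 0.5, 0.5), 1.0); // white maps to +1
    assert_eq!(normalize_pixel(0, 0.5, 0.5), -1.0); // black maps to -1
}
```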
@@ -1,132 +0,0 @@
-#[cfg(feature = "mkl")]
-extern crate intel_mkl_src;
-
-#[cfg(feature = "accelerate")]
-extern crate accelerate_src;
-
-use anyhow::Error as E;
-use clap::{Parser, ValueEnum};
-
-use candle::{DType, Tensor};
-use candle_examples::token_output_stream::TokenOutputStream;
-use candle_nn::VarBuilder;
-use candle_transformers::models::trocr;
-
-use tokenizers::Tokenizer;
-mod image_processor;
-
-#[derive(Clone, Debug, Copy, ValueEnum)]
-enum Which {
-    Base,
-    Large,
-}
-
-#[derive(Parser, Debug)]
-struct Args {
-    #[arg(long)]
-    model: Option<String>,
-
-    /// Choose the variant of the model to run.
-    #[arg(long, default_value = "base")]
-    which: Which,
-
-    /// Run on CPU rather than on GPU.
-    #[arg(long)]
-    cpu: bool,
-
-    /// Text to be translated
-    #[arg(long)]
-    image: String,
-}
-
-pub fn main() -> anyhow::Result<()> {
-    use hf_hub::api::sync::Api;
-    let args = Args::parse();
-
-    let tokenizer_dec = {
-        let tokenizer = Api::new()?
-            .model(String::from("ToluClassics/candle-trocr-tokenizer"))
-            .get("tokenizer.json")?;
-
-        Tokenizer::from_file(&tokenizer).map_err(E::msg)?
-    };
-
-    let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);
-
-    let device = candle_examples::device(args.cpu)?;
-
-    let vb = {
-        let model = match args.model {
-            Some(model) => std::path::PathBuf::from(model),
-            None => match args.which {
-                Which::Base => Api::new()?
-                    .repo(hf_hub::Repo::with_revision(
-                        "microsoft/trocr-base-handwritten".to_string(),
-                        hf_hub::RepoType::Model,
-                        "refs/pr/3".to_string(),
-                    ))
-                    .get("model.safetensors")?,
-                Which::Large => Api::new()?
-                    .repo(hf_hub::Repo::with_revision(
-                        "microsoft/trocr-large-handwritten".to_string(),
-                        hf_hub::RepoType::Model,
-                        "refs/pr/6".to_string(),
-                    ))
-                    .get("model.safetensors")?,
-            },
-        };
-        println!("model: {:?}", model);
-        unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? }
-    };
-
-    let encoder_config = match args.which {
-        Which::Base => candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten(),
-        Which::Large => {
-            candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten()
-        }
-    };
-
-    let decoder_config = trocr::TrOCRConfig::default();
-    let mut model = trocr::TrOCRModel::new(&encoder_config, &decoder_config, vb)?;
-
-    let config = image_processor::ProcessorConfig::default();
-    let processor = image_processor::ViTImageProcessor::new(&config);
-
-    let image = vec![args.image.as_str()];
-    let image = processor.preprocess(image)?;
-
-    let encoder_xs = model.encoder().forward(&image)?;
-
-    let mut logits_processor =
-        candle_transformers::generation::LogitsProcessor::new(1337, None, None);
-
-    let mut token_ids: Vec<u32> = vec![decoder_config.decoder_start_token_id];
-    for index in 0..1000 {
-        let context_size = if index >= 1 { 1 } else { token_ids.len() };
-        let start_pos = token_ids.len().saturating_sub(context_size);
-        let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
-
-        let logits = model.decode(&input_ids, &encoder_xs, start_pos)?;
-
-        let logits = logits.squeeze(0)?;
-        let logits = logits.get(logits.dim(0)? - 1)?;
-        let token = logits_processor.sample(&logits)?;
-        token_ids.push(token);
-
-        if let Some(t) = tokenizer_dec.next_token(token)? {
-            use std::io::Write;
-            print!("{t}");
-            std::io::stdout().flush()?;
-        }
-        if token == decoder_config.eos_token_id {
-            break;
-        }
-    }
-
-    if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
-        print!("{rest}");
-    }
-    println!();
-
-    Ok(())
-}
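One detail of the deleted decoding loop worth calling out: `context_size` covers the full sequence only on the first step; afterwards just the newest token is fed, because the model keeps a key/value cache indexed by `start_pos`. A standalone sketch of that windowing:

```rust
// Slice bounds of the tokens fed to the decoder at a given step, mirroring
// the context_size/start_pos logic of the deleted loop.
fn context_window(index: usize, token_ids: &[u32]) -> (usize, usize) {
    let context_size = if index >= 1 { 1 } else { token_ids.len() };
    let start_pos = token_ids.len().saturating_sub(context_size);
    (start_pos, token_ids.len())
}

fn main() {
    // First step: the whole (single start token) sequence is fed.
    assert_eq!(context_window(0, &[0]), (0, 1));
    // Later steps: only the most recent token.
    assert_eq!(context_window(2, &[0, 5, 9]), (2, 3));
}
```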
@@ -1,16 +0,0 @@
-# candle-trocr
-
-`TrOCR` is a transformer OCR Model. In this example it is used to
-transcribe image text. See the associated [model
-card](https://huggingface.co/microsoft/trocr-base-printed) for details on
-the model itself.
-
-## Running an example
-
-```bash
-cargo run --example trocr --release -- --which base --cpu --image assets/trocr.png
-```
-
-```
-<s> industry , Mr. Brown commented icily . " Let us have a</s>
-```
@@ -345,7 +345,7 @@ enum Task {
     Translate,
 }
 
-#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)]
+#[derive(Clone, Copy, Debug, ValueEnum)]
 enum WhichModel {
     Tiny,
     #[value(name = "tiny.en")]
@@ -361,27 +361,15 @@ enum WhichModel {
     MediumEn,
     Large,
     LargeV2,
-    LargeV3,
-    #[value(name = "distil-medium.en")]
-    DistilMediumEn,
-    #[value(name = "distil-large-v2")]
-    DistilLargeV2,
 }
 
 impl WhichModel {
     fn is_multilingual(&self) -> bool {
         match self {
-            Self::Tiny
-            | Self::Base
-            | Self::Small
-            | Self::Medium
-            | Self::Large
-            | Self::LargeV2
-            | Self::LargeV3
-            | Self::DistilLargeV2 => true,
-            Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn | Self::DistilMediumEn => {
-                false
+            Self::Tiny | Self::Base | Self::Small | Self::Medium | Self::Large | Self::LargeV2 => {
+                true
             }
+            Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn => false,
         }
     }
 
@@ -397,9 +385,6 @@ impl WhichModel {
             Self::MediumEn => ("openai/whisper-medium.en", "main"),
             Self::Large => ("openai/whisper-large", "refs/pr/36"),
             Self::LargeV2 => ("openai/whisper-large-v2", "refs/pr/57"),
-            Self::LargeV3 => ("openai/whisper-large-v3", "main"),
-            Self::DistilMediumEn => ("distil-whisper/distil-medium.en", "main"),
-            Self::DistilLargeV2 => ("distil-whisper/distil-large-v2", "main"),
         }
     }
 }
@@ -511,25 +496,17 @@ fn main() -> Result<()> {
                 repo.get(&format!("model-{ext}-q80.gguf"))?,
             )
         } else {
-            let config = repo.get("config.json")?;
-            let tokenizer = if args.model == WhichModel::LargeV3 {
-                panic!("openai/whisper-large-v3 does not provide a compatible tokenizer.json config at the moment")
-            } else {
-                repo.get("tokenizer.json")?
-            };
-            let model = repo.get("model.safetensors")?;
-            (config, tokenizer, model)
+            (
+                repo.get("config.json")?,
+                repo.get("tokenizer.json")?,
+                repo.get("model.safetensors")?,
+            )
         };
         (config, tokenizer, model, sample)
     };
-    let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
 
-    let mel_bytes = match config.num_mel_bins {
-        80 => include_bytes!("melfilters.bytes").as_slice(),
-        128 => include_bytes!("melfilters128.bytes").as_slice(),
-        nmel => anyhow::bail!("unexpected num_mel_bins {nmel}"),
-    };
+    let mel_bytes = include_bytes!("melfilters.bytes");
     let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
     <byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
 
@@ -545,15 +522,12 @@ fn main() -> Result<()> {
         .map(|v| *v as f32 / 32768.)
         .collect();
     println!("pcm data loaded {}", pcm_data.len());
-    let mel = audio::pcm_to_mel(&config, &pcm_data, &mel_filters);
+    let mel = audio::pcm_to_mel(&pcm_data, &mel_filters);
     let mel_len = mel.len();
-    let mel = Tensor::from_vec(
-        mel,
-        (1, config.num_mel_bins, mel_len / config.num_mel_bins),
-        &device,
-    )?;
+    let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
     println!("loaded mel: {:?}", mel.dims());
 
+    let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
     let mut model = if args.quantized {
         let vb =
             candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&weights_filename)?;
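The reshape in this hunk views the flat mel vector as `(batch, num_mel_bins, frames)`; with a fixed bin count the frame count is simply `mel_len / num_mel_bins`. A sketch of the shape computation (the concrete numbers below are illustrative, not taken from the source):

```rust
// Compute the (batch, mel bins, frames) dims for a flat mel buffer.
fn mel_dims(mel_len: usize, num_mel_bins: usize) -> (usize, usize, usize) {
    (1, num_mel_bins, mel_len / num_mel_bins)
}

fn main() {
    // E.g. 80 mel bins over 3000 frames of audio (illustrative values).
    assert_eq!(mel_dims(80 * 3000, 80), (1, 80, 3000));
}
```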
(Binary file removed; diff not shown.)
@@ -1,268 +0,0 @@
-#[cfg(feature = "mkl")]
-extern crate intel_mkl_src;
-
-#[cfg(feature = "accelerate")]
-extern crate accelerate_src;
-
-use anyhow::{Error as E, Result};
-use clap::{Parser, ValueEnum};
-
-use candle_transformers::models::yi::{Config, Model};
-
-use candle::{DType, Device, Tensor};
-use candle_examples::token_output_stream::TokenOutputStream;
-use candle_nn::VarBuilder;
-use candle_transformers::generation::LogitsProcessor;
-use hf_hub::{api::sync::Api, Repo, RepoType};
-use tokenizers::Tokenizer;
-
-#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
-enum Which {
-    #[value(name = "6b")]
-    L6b,
-    #[value(name = "34b")]
-    L34b,
-}
-
-struct TextGeneration {
-    model: Model,
-    device: Device,
-    tokenizer: TokenOutputStream,
-    logits_processor: LogitsProcessor,
-    repeat_penalty: f32,
-    repeat_last_n: usize,
-}
-
-impl TextGeneration {
-    #[allow(clippy::too_many_arguments)]
-    fn new(
-        model: Model,
-        tokenizer: Tokenizer,
-        seed: u64,
-        temp: Option<f64>,
-        top_p: Option<f64>,
-        repeat_penalty: f32,
-        repeat_last_n: usize,
-        device: &Device,
-    ) -> Self {
-        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
-        Self {
-            model,
-            tokenizer: TokenOutputStream::new(tokenizer),
-            logits_processor,
-            repeat_penalty,
-            repeat_last_n,
-            device: device.clone(),
-        }
-    }
-
-    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
-        use std::io::Write;
-        self.tokenizer.clear();
-        let mut tokens = self
-            .tokenizer
-            .tokenizer()
-            .encode(prompt, true)
-            .map_err(E::msg)?
-            .get_ids()
-            .to_vec();
-        for &t in tokens.iter() {
-            if let Some(t) = self.tokenizer.next_token(t)? {
-                print!("{t}")
-            }
-        }
-        std::io::stdout().flush()?;
-
-        let mut generated_tokens = 0usize;
-        let eos_token = match self.tokenizer.get_token("</s>") {
-            Some(token) => token,
-            None => anyhow::bail!("cannot find the </s> token"),
-        };
-        let start_gen = std::time::Instant::now();
-        for index in 0..sample_len {
-            let context_size = if index > 0 { 1 } else { tokens.len() };
-            let start_pos = tokens.len().saturating_sub(context_size);
-            let ctxt = &tokens[start_pos..];
-            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
-            let logits = self.model.forward(&input, start_pos)?;
-            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
-            let logits = if self.repeat_penalty == 1. {
-                logits
-            } else {
-                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
-                candle_transformers::utils::apply_repeat_penalty(
-                    &logits,
-                    self.repeat_penalty,
-                    &tokens[start_at..],
-                )?
-            };
-
-            let next_token = self.logits_processor.sample(&logits)?;
-            tokens.push(next_token);
-            generated_tokens += 1;
-            if next_token == eos_token {
-                break;
-            }
-            if let Some(t) = self.tokenizer.next_token(next_token)? {
-                print!("{t}");
-                std::io::stdout().flush()?;
-            }
-        }
-        let dt = start_gen.elapsed();
-        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
-            print!("{rest}");
-        }
-        std::io::stdout().flush()?;
-        println!(
-            "\n{generated_tokens} tokens generated ({:.2} token/s)",
-            generated_tokens as f64 / dt.as_secs_f64(),
-        );
-        Ok(())
-    }
-}
-
-#[derive(Parser, Debug)]
-#[command(author, version, about, long_about = None)]
-struct Args {
-    /// Run on CPU rather than on GPU.
-    #[arg(long)]
-    cpu: bool,
-
-    /// Enable tracing (generates a trace-timestamp.json file).
-    #[arg(long)]
-    tracing: bool,
-
-    #[arg(long)]
-    prompt: String,
-
-    /// The temperature used to generate samples.
-    #[arg(long)]
-    temperature: Option<f64>,
-
-    /// Nucleus sampling probability cutoff.
-    #[arg(long)]
-    top_p: Option<f64>,
-
-    /// The seed to use when generating random samples.
-    #[arg(long, default_value_t = 299792458)]
-    seed: u64,
-
-    /// The length of the sample to generate (in tokens).
-    #[arg(long, short = 'n', default_value_t = 100)]
-    sample_len: usize,
-
-    #[arg(long, default_value = "01-ai/Yi-6B")]
-    model_id: String,
-
-    #[arg(long, default_value = "main")]
-    revision: String,
-
-    #[arg(long)]
-    tokenizer_file: Option<String>,
-
-    #[arg(long)]
-    weight_files: Option<String>,
-
-    /// Penalty to be applied for repeating tokens, 1. means no penalty.
-    #[arg(long, default_value_t = 1.1)]
-    repeat_penalty: f32,
-
-    /// The context size to consider for the repeat penalty.
-    #[arg(long, default_value_t = 64)]
-    repeat_last_n: usize,
-
-    /// The model size to use.
-    #[arg(long, default_value = "6b")]
-    which: Which,
-}
-
-fn main() -> Result<()> {
-    use tracing_chrome::ChromeLayerBuilder;
-    use tracing_subscriber::prelude::*;
-
-    let args = Args::parse();
-    let _guard = if args.tracing {
-        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
-        tracing_subscriber::registry().with(chrome_layer).init();
-        Some(guard)
-    } else {
-        None
-    };
-    println!(
-        "avx: {}, neon: {}, simd128: {}, f16c: {}",
-        candle::utils::with_avx(),
-        candle::utils::with_neon(),
-        candle::utils::with_simd128(),
-        candle::utils::with_f16c()
-    );
-    println!(
-        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
-        args.temperature.unwrap_or(0.),
-        args.repeat_penalty,
-        args.repeat_last_n
-    );
-
-    let start = std::time::Instant::now();
-    let api = Api::new()?;
-    let repo = api.repo(Repo::with_revision(
-        args.model_id,
-        RepoType::Model,
-        args.revision,
-    ));
-    let tokenizer_filename = match args.tokenizer_file {
-        Some(file) => std::path::PathBuf::from(file),
-        None => repo.get("tokenizer.json")?,
-    };
-    let filenames = match args.weight_files {
-        Some(files) => files
-            .split(',')
-            .map(std::path::PathBuf::from)
-            .collect::<Vec<_>>(),
-        None => match args.which {
-            Which::L6b => vec![
-                repo.get("model-00001-of-00002.safetensors")?,
-                repo.get("model-00002-of-00002.safetensors")?,
-            ],
-            Which::L34b => vec![
-                repo.get("model-00001-of-00007.safetensors")?,
-                repo.get("model-00002-of-00007.safetensors")?,
-                repo.get("model-00003-of-00007.safetensors")?,
-                repo.get("model-00004-of-00007.safetensors")?,
-                repo.get("model-00005-of-00007.safetensors")?,
-                repo.get("model-00006-of-00007.safetensors")?,
-                repo.get("model-00007-of-00007.safetensors")?,
-            ],
-        },
-    };
-    println!("retrieved the files in {:?}", start.elapsed());
-    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
-
-    let start = std::time::Instant::now();
-    let config = match args.which {
-        Which::L6b => Config::config_6b(),
-        Which::L34b => Config::config_34b(),
-    };
-    let device = candle_examples::device(args.cpu)?;
-    let dtype = if device.is_cuda() {
-        DType::BF16
-    } else {
-        DType::F32
-    };
-    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
-    let model = Model::new(&config, vb)?;
-
-    println!("loaded the model in {:?}", start.elapsed());
-
-    let mut pipeline = TextGeneration::new(
-        model,
-        tokenizer,
-        args.seed,
-        args.temperature,
-        args.top_p,
-        args.repeat_penalty,
-        args.repeat_last_n,
-        &device,
-    );
-    pipeline.run(&args.prompt, args.sample_len)?;
-    Ok(())
-}
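The deleted `run` loop defers to `candle_transformers::utils::apply_repeat_penalty` for repetition control. Conceptually, that kind of penalty rescales the logits of tokens seen in the recent context so they are less likely to be sampled again; a hedged sketch of the idea (not the library's implementation):

```rust
// Dampen the logits of recently generated tokens: divide positive logits by
// the penalty and multiply negative ones, so both move toward "less likely".
fn apply_repeat_penalty(logits: &mut [f32], penalty: f32, context: &[u32]) {
    for &tok in context {
        let l = &mut logits[tok as usize];
        *l = if *l >= 0.0 { *l / penalty } else { *l * penalty };
    }
}

fn main() {
    let mut logits = vec![2.0, -2.0, 0.5];
    apply_repeat_penalty(&mut logits, 1.1, &[0, 1]);
    assert!(logits[0] < 2.0 && logits[1] < -2.0);
}
```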
@@ -8,22 +8,24 @@ use candle::{Device, Result, Tensor};
 pub fn device(cpu: bool) -> Result<Device> {
     if cpu {
         Ok(Device::Cpu)
-    } else if cuda_is_available() {
-        Ok(Device::new_cuda(0)?)
-    } else if metal_is_available() {
-        Ok(Device::new_metal(0)?)
     } else {
-        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
-        {
-            println!(
-                "Running on CPU, to run on GPU(metal), build this example with `--features metal`"
-            );
+        if cuda_is_available() {
+            Ok(Device::new_cuda(0)?)
+        } else if metal_is_available() {
+            Ok(Device::new_metal(0)?)
+        } else {
+            #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
+            {
+                println!("Running on CPU, to run on GPU(metal), build this example with `--features metal`");
+            }
+            #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
+            {
+                println!(
+                    "Running on CPU, to run on GPU, build this example with `--features cuda`"
+                );
+            }
+            Ok(Device::Cpu)
         }
-        #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
-        {
-            println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
-        }
-        Ok(Device::Cpu)
     }
 }
 
@@ -233,8 +233,8 @@ impl FlashAttnVarLen {
 
         let (seqlens_q, seqlens_q_layout) = self.seqlens_q.storage_and_layout();
         let seqlens_q = match &*seqlens_q {
+            candle::Storage::Cpu(_) => candle::bail!("seqlens_q must be a cuda tensor"),
             candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?, // Should be i32!
-            _ => candle::bail!("seqlens_q must be a cuda tensor"),
         };
         let seqlens_q = match seqlens_q_layout.contiguous_offsets() {
             Some((o1, o2)) => seqlens_q.slice(o1..o2),
@@ -243,8 +243,8 @@ impl FlashAttnVarLen {
 
         let (seqlens_k, seqlens_k_layout) = self.seqlens_k.storage_and_layout();
         let seqlens_k = match &*seqlens_k {
+            candle::Storage::Cpu(_) => candle::bail!("seqlens_k must be a cuda tensor"),
             candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?, // Should be i32!
-            _ => candle::bail!("seqlens_k must be a cuda tensor"),
        };
         let seqlens_k = match seqlens_k_layout.contiguous_offsets() {
             Some((o1, o2)) => seqlens_k.slice(o1..o2),
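Both hunks swap a trailing `_` catch-all for an explicit `Storage::Cpu(_)` arm ahead of the CUDA arm. A minimal model of the pattern, using a hypothetical `Storage` enum rather than candle's actual type:

```rust
// Hypothetical stand-in for a device-tagged storage enum.
enum Storage {
    Cpu(Vec<f32>),
    Cuda(u32), // placeholder for a device slice handle
}

// Bail out explicitly for CPU storage instead of via a `_` catch-all, so the
// match stays exhaustive and every variant is handled deliberately.
fn as_cuda(s: &Storage) -> Result<u32, String> {
    match s {
        Storage::Cpu(_) => Err("seqlens_q must be a cuda tensor".to_string()),
        Storage::Cuda(handle) => Ok(*handle),
    }
}

fn main() {
    assert!(as_cuda(&Storage::Cpu(vec![])).is_err());
    assert_eq!(as_cuda(&Storage::Cuda(7)), Ok(7));
}
```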
@@ -1,20 +1,12 @@
 [package]
 name = "candle-metal-kernels"
-version = "0.3.0"
-edition = "2021"
-
-description = "CUDA kernels for Candle"
-repository = "https://github.com/huggingface/candle"
-keywords = ["blas", "tensor", "machine-learning"]
-categories = ["science"]
-license = "MIT OR Apache-2.0"
+version.workspace = true
+edition.workspace = true
+description.workspace = true
+repository.workspace = true
+keywords.workspace = true
+categories.workspace = true
+license.workspace = true
 
 [dependencies]
-metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
-once_cell = "1.18.0"
-thiserror = "1"
-tracing = "0.1.37"
-
-[dev-dependencies]
-half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
-rand = "0.8.5"
+metal = { workspace = true }
@@ -1,61 +0,0 @@
-#include <metal_stdlib>
-
-METAL_FUNC uint get_strided_index(
-    uint idx,
-    constant size_t &num_dims,
-    constant size_t *dims,
-    constant size_t *strides
-) {
-    uint strided_i = 0;
-    for (uint d = 0; d < num_dims; d++) {
-        uint dim_idx = num_dims - 1 - d;
-        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
-        idx /= dims[dim_idx];
-    }
-    return strided_i;
-}
-
-using namespace metal;
-
-#define AFFINE(FN_NAME, TYPENAME) \
-kernel void FN_NAME( \
-    constant size_t &dim, \
-    constant float &mul, \
-    constant float &add, \
-    device const TYPENAME *input, \
-    device TYPENAME *output, \
-    uint id [[ thread_position_in_grid ]] \
-) { \
-    if (id >= dim) { \
-        return; \
-    } \
-    const TYPENAME m = TYPENAME(mul); \
-    const TYPENAME a = TYPENAME(add); \
-    output[id] = input[id] * m + a; \
-} \
-kernel void FN_NAME##_strided( \
-    constant size_t &dim, \
-    constant size_t &num_dims, \
-    constant size_t *dims, \
-    constant size_t *strides, \
-    constant float &mul, \
-    constant float &add, \
-    device const TYPENAME *input, \
-    device TYPENAME *output, \
-    uint id [[ thread_position_in_grid ]] \
-) { \
-    if (id >= dim) { \
-        return; \
-    } \
-    const TYPENAME m = TYPENAME(mul); \
-    const TYPENAME a = TYPENAME(add); \
-    output[id] = input[get_strided_index(id, num_dims, dims, strides)] * m + a; \
-} \
-
-AFFINE(affine_float, float)
-AFFINE(affine_half, half)
-
-
-#if __METAL_VERSION__ >= 310
-AFFINE(affine_bfloat, bfloat);
-#endif
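Each of the deleted kernel files repeats the same `get_strided_index` helper: it converts a flat element index into a memory offset by peeling off one coordinate per dimension (innermost first) and accumulating coordinate times stride. The same arithmetic in safe Rust, as a sketch:

```rust
// Map a flat row-major index to a strided memory offset.
fn get_strided_index(mut idx: usize, dims: &[usize], strides: &[usize]) -> usize {
    let mut strided_i = 0;
    for d in (0..dims.len()).rev() {
        strided_i += (idx % dims[d]) * strides[d]; // coordinate along dim d
        idx /= dims[d];
    }
    strided_i
}

fn main() {
    // A 2x3 view with transposed strides [1, 2]:
    // flat index 4 -> coords (1, 1) -> offset 1 * 1 + 1 * 2 = 3.
    assert_eq!(get_strided_index(4, &[2, 3], &[1, 2]), 3);
}
```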
@@ -1,72 +0,0 @@
-#include <metal_stdlib>
-
-METAL_FUNC uint get_strided_index(
-    uint idx,
-    constant size_t &num_dims,
-    constant size_t *dims,
-    constant size_t *strides
-) {
-    uint strided_i = 0;
-    for (uint d = 0; d < num_dims; d++) {
-        uint dim_idx = num_dims - 1 - d;
-        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
-        idx /= dims[dim_idx];
-    }
-    return strided_i;
-}
-
-using namespace metal;
-
-#define BINARY(FN, TYPENAME, OUT_TYPENAME, FN_NAME, FN_NAME_STRIDED) \
-kernel void FN_NAME( \
-    constant size_t &dim, \
-    device const TYPENAME *left, \
-    device const TYPENAME *right, \
-    device TYPENAME *output, \
-    uint thread_position_in_grid [[ thread_position_in_grid ]] \
-) { \
-    if (thread_position_in_grid >= dim) { \
-        return; \
-    } \
-    TYPENAME x = left[thread_position_in_grid]; \
-    TYPENAME y = right[thread_position_in_grid]; \
-    output[thread_position_in_grid] = OUT_TYPENAME(FN); \
-}\
-kernel void FN_NAME_STRIDED( \
-    constant size_t &dim, \
-    constant size_t &num_dims, \
-    constant size_t *dims, \
-    constant size_t *left_strides, \
-    constant size_t *right_strides, \
-    device const TYPENAME *left, \
-    device const TYPENAME *right, \
-    device TYPENAME *output, \
-    uint thread_position_in_grid [[ thread_position_in_grid ]] \
-) { \
-    if (thread_position_in_grid >= dim) { \
-        return; \
-    } \
-    TYPENAME x = left[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \
-    TYPENAME y = right[get_strided_index(thread_position_in_grid, num_dims, dims, right_strides)]; \
-    output[thread_position_in_grid] = OUT_TYPENAME(FN); \
-}
-
-#define BINARY_OP(FN, NAME) \
-BINARY(FN, float, float, NAME##_float, NAME##_float_strided); \
-BINARY(FN, half, half, NAME##_half, NAME##_half_strided);
-
-#define BFLOAT_BINARY_OP(FN, NAME) \
-BINARY(FN, bfloat, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
-
-
-BINARY_OP(x + y, add)
-BINARY_OP(x - y, sub)
-BINARY_OP(x * y, mul)
-BINARY_OP(x / y, div)
-
-#if __METAL_VERSION__ >= 310
-BFLOAT_BINARY_OP(x + y, add)
-BFLOAT_BINARY_OP(x - y, sub)
-BFLOAT_BINARY_OP(x * y, mul)
-BFLOAT_BINARY_OP(x / y, div)
-#endif
@@ -1,53 +0,0 @@
-#include <metal_stdlib>
-
-METAL_FUNC uint get_strided_index(
-    uint idx,
-    constant size_t &num_dims,
-    constant size_t *dims,
-    constant size_t *strides
-) {
-    uint strided_i = 0;
-    for (uint d = 0; d < num_dims; d++) {
-        uint dim_idx = num_dims - 1 - d;
-        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
-        idx /= dims[dim_idx];
-    }
-    return strided_i;
-}
-
-
-using namespace metal;
-
-#define CAST(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME) \
-kernel void FN_NAME( \
-    constant size_t &dim, \
-    device const LEFT_TYPENAME *input, \
-    device RIGHT_TYPENAME *output, \
-    uint thread_position_in_grid [[ thread_position_in_grid ]] \
-) { \
-    if (thread_position_in_grid >= dim) { \
-        return; \
-    } \
-    output[thread_position_in_grid] = RIGHT_TYPENAME(input[thread_position_in_grid]); \
-} \
-kernel void FN_NAME_STRIDED( \
-    constant size_t &dim, \
-    constant size_t &num_dims, \
-    constant size_t *dims, \
-    constant size_t *strides, \
-    device const LEFT_TYPENAME *input, \
-    device RIGHT_TYPENAME *output, \
-    uint i [[ thread_position_in_grid ]] \
-) { \
-    if (i >= dim) { \
-        return; \
-    } \
-    output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \
-} \
-
-CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float)
-CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
-CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
-
-#if __METAL_VERSION__ >= 310
-#endif
@@ -1,103 +0,0 @@
-#include <metal_stdlib>
-using namespace metal;
-
-# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
-kernel void NAME( \
-    constant size_t &dst_size, \
-    constant size_t &left_size, \
-    constant size_t &src_dim_size, \
-    constant size_t &right_size, \
-    constant size_t &ids_size, \
-    const device TYPENAME *input, \
-    const device INDEX_TYPENAME *input_ids, \
-    device TYPENAME *output, \
-    uint gid [[ thread_position_in_grid ]] \
-) { \
-    if (gid >= dst_size) { \
-        return; \
-    } \
-    const size_t id_i = (gid / right_size) % ids_size; \
-    const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
-    const size_t right_rank_i = gid % right_size; \
-    const size_t left_rank_i = gid / right_size / ids_size; \
-    /* \
-    // Force prevent out of bounds indexing \
-    // since there doesn't seem to be a good way to force crash \
-    // No need to check for zero we're only allowing unsized. \
-    */ \
-    const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; \
-    output[gid] = input[src_i]; \
-}
-
-
-template <typename T, typename I>
-void index_add(
-    device I *ids [[buffer(0)]],
-    device T *inp [[buffer(1)]],
-    device T *out [[buffer(2)]],
-
-    constant uint &ids_dim_size,
-    constant uint &left_size,
-    constant uint &dst_dim_size,
-    constant uint &right_size,
-
-    uint gid [[ thread_position_in_grid ]] \
-) {
-
-    if (gid >= left_size * right_size) {
-        return;
-    }
-
-    const uint i = gid;
-    const uint pre = i / right_size;
-    const uint post = i % right_size;
-
-    for (uint j = 0; j < ids_dim_size; j++) {
-        const uint idx = ids[j];
-        const uint src_i = (pre * ids_dim_size + j) * right_size + post;
-        const uint dst_i = (pre * dst_dim_size + idx) * right_size + post;
-        out[dst_i] += inp[src_i];
-    }
-}
-
-#define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
-kernel void FN_NAME( \
-    device INDEX_TYPENAME *ids [[buffer(0)]], \
-    device TYPENAME *inp [[buffer(1)]], \
-    device TYPENAME *out [[buffer(2)]], \
-    constant uint &ids_dim_size, \
-    constant uint &left_size, \
-    constant uint &dst_dim_size, \
-    constant uint &right_size, \
-    uint gid [[ thread_position_in_grid ]] \
-) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, gid); } \
-
-
-INDEX_OP(is_u32_f32, uint, float)
-INDEX_OP(is_u32_f16, uint, half)
-
-
-#if __METAL_VERSION__ >= 310
-IA_OP(bfloat, int64_t, ia_i64_bf16)
-IA_OP(bfloat, uint32_t, ia_u32_bf16)
-IA_OP(bfloat, uint8_t, ia_u8_bf16)
-#endif
-
-IA_OP(half, uint32_t, ia_u32_f16)
-IA_OP(half, uint8_t, ia_u8_f16)
-
-IA_OP(float, int64_t, ia_i64_f32)
-IA_OP(uint8_t, int64_t, ia_i64_u8)
-IA_OP(int64_t, int64_t, ia_i64_i64)
-IA_OP(uint32_t, int64_t, ia_i64_u32)
-
-IA_OP(float, uint32_t, ia_u32_f32)
-IA_OP(uint8_t, uint32_t, ia_u32_u8)
-IA_OP(int64_t, uint32_t, ia_u32_i64)
-IA_OP(uint32_t, uint32_t, ia_u32_u32)
-
-IA_OP(float, uint8_t, ia_u8_f32)
-IA_OP(uint8_t, uint8_t, ia_u8_u8)
-IA_OP(uint32_t, uint8_t, ia_u8_u32)
-IA_OP(int64_t, uint8_t, ia_u8_i64)
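The `index_add` template above is a scatter-add: every `(pre, post)` position walks the id list and accumulates a source row into the destination row selected by `ids[j]`. A CPU sketch of the same semantics:

```rust
// Scatter-add rows of `inp` into `out` at positions chosen by `ids`,
// mirroring the index layout of the deleted Metal template.
fn index_add(
    ids: &[usize],
    inp: &[f32],
    out: &mut [f32],
    left_size: usize,
    dst_dim_size: usize,
    right_size: usize,
) {
    for pre in 0..left_size {
        for post in 0..right_size {
            for (j, &idx) in ids.iter().enumerate() {
                let src_i = (pre * ids.len() + j) * right_size + post;
                let dst_i = (pre * dst_dim_size + idx) * right_size + post;
                out[dst_i] += inp[src_i];
            }
        }
    }
}

fn main() {
    // Two source rows, both accumulated into destination row 1.
    let inp = [1.0, 2.0]; // shape (1, 2, 1)
    let mut out = [0.0, 0.0]; // shape (1, 2, 1)
    index_add(&[1, 1], &inp, &mut out, 1, 2, 1);
    assert_eq!(out, [0.0, 3.0]);
}
```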
(File diff suppressed because it is too large.)
@@ -1,139 +0,0 @@
-#include <metal_stdlib>
-using namespace metal;
-
-METAL_FUNC uint get_strided_index(
-    uint idx,
-    constant size_t &num_dims,
-    constant size_t *dims,
-    constant size_t *strides
-) {
-    uint strided_i = 0;
-    for (uint d = 0; d < num_dims; d++) {
-        uint dim_idx = num_dims - 1 - d;
-        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
-        idx /= dims[dim_idx];
-    }
-    return strided_i;
-}
-
-constant int THREADGROUP_SIZE = 1024;
-
-# define REDUCE(FN, NAME, TYPENAME) \
-kernel void NAME( \
-    constant size_t &src_numel, \
-    constant size_t &el_to_sum_per_block, \
-    device const TYPENAME *src, \
-    device TYPENAME *dst, \
-    uint id [[ thread_position_in_grid ]], \
-    uint tid [[ thread_index_in_threadgroup ]], \
-    uint dst_id [[ threadgroup_position_in_grid ]], \
-    uint blockDim [[ threads_per_threadgroup ]] \
-) { \
-    \
-    threadgroup float shared_memory[THREADGROUP_SIZE]; \
-    \
-    shared_memory[tid] = 0; \
-    /* \
-    // Elements summed in this block range from dst_id * el_to_sum_per_block \
-    // to (dst_id + 1) * el_to_sum_per_block. \
-    */ \
-    size_t start_idx = dst_id * el_to_sum_per_block; \
-    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
-    size_t idx = start_idx + tid; \
-    while (idx < stop_idx) { \
-        /* \
-        // TODO: Fast version for the contiguous case. \
-        // size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
-        */ \
-        TYPENAME x = shared_memory[tid]; \
-        TYPENAME y = src[idx]; \
-        shared_memory[tid] = FN; \
-        idx += blockDim; \
-    } \
-    \
-    threadgroup_barrier(mem_flags::mem_none); \
-    \
-    /* \
-    // reduction in shared memory \
-    */ \
-    for (uint s = blockDim / 2; s > 0; s >>= 1) { \
-        if (tid < s) { \
-            TYPENAME x = shared_memory[tid]; \
-            TYPENAME y = shared_memory[tid + s]; \
-            shared_memory[tid] = FN; \
-        } \
-        threadgroup_barrier(mem_flags::mem_none); \
-    } \
-    \
-    dst[dst_id] = shared_memory[0]; \
-} \
-
-kernel void softmax_float(
-    constant size_t &src_numel,
-    constant size_t &el_to_sum_per_block,
-    device const float *src,
-    device float *dst,
-    uint id [[ thread_position_in_grid ]],
-    uint tid [[ thread_index_in_threadgroup ]],
-    uint dst_id [[ threadgroup_position_in_grid ]],
-    uint blockDim [[ threads_per_threadgroup ]]
-) {
-
-    threadgroup float shared_memory[THREADGROUP_SIZE];
-
-    shared_memory[tid] = -INFINITY;
-    // Elements summed in this block range from dst_id * el_to_sum_per_block
-    // to (dst_id + 1) * el_to_sum_per_block.
-    size_t start_idx = dst_id * el_to_sum_per_block;
-    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
-    size_t idx = start_idx + tid;
-
-    while (idx < stop_idx) {
-        // TODO: Fast version for the contiguous case.
-        shared_memory[tid] = max(shared_memory[tid], src[idx]);
-        idx += blockDim;
-    }
-
-    threadgroup_barrier(mem_flags::mem_none);
-
-    // reduction in shared memory
-    for (uint s = blockDim / 2; s > 0; s >>= 1) {
-        if (tid < s) {
-            shared_memory[tid] = max(shared_memory[tid], shared_memory[tid + s]);
-        }
-        threadgroup_barrier(mem_flags::mem_none);
-    }
-
-    float max = shared_memory[0];
-
-    shared_memory[tid] = 0;
-
-    // Restart
-    idx = start_idx + tid;
-    while (idx < stop_idx) {
-        // TODO: Fast version for the contiguous case.
-        const float val = exp(src[idx] - max);
-        dst[idx] = val;
-        shared_memory[tid] += val;
-        idx += blockDim;
-    }
-    // reduction in shared memory
-    for (uint s = blockDim / 2; s > 0; s >>= 1) {
-        if (tid < s) {
-            shared_memory[tid] += shared_memory[tid + s];
-        }
-        threadgroup_barrier(mem_flags::mem_none);
-    }
-
-    const float inv_acc = 1/shared_memory[0];
-    idx = start_idx + tid;
-    while (idx < stop_idx) {
-        dst[idx] *= inv_acc;
-        idx += blockDim;
-    }
-}
-
-
-REDUCE(x + y, fast_sum_float, float)
-REDUCE(x * y, fast_mul_float, float)
-REDUCE(max(x, y), fast_max_float, float)
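`softmax_float` above makes three passes per block: a max reduction, an exp-and-sum pass that subtracts that max for numerical stability, and a final rescale by the reciprocal of the sum. The same algorithm single-threaded in Rust, as a sketch:

```rust
// Numerically stable softmax: subtract the max before exponentiating, then
// normalize by the sum of exponentials (the kernel's inv_acc).
fn softmax(src: &[f32]) -> Vec<f32> {
    let max = src.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut dst: Vec<f32> = src.iter().map(|&x| (x - max).exp()).collect();
    let inv_acc = 1.0 / dst.iter().sum::<f32>();
    dst.iter_mut().for_each(|v| *v *= inv_acc);
    dst
}

fn main() {
    let out = softmax(&[1.0, 2.0, 3.0]);
    assert!((out.iter().sum::<f32>() - 1.0).abs() < 1e-6);
    assert!(out[2] > out[1] && out[1] > out[0]);
}
```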
@@ -1,57 +0,0 @@
-#include <metal_stdlib>
-#
-using namespace metal;
-
-METAL_FUNC uint get_strided_index(
-    uint idx,
-    constant size_t &num_dims,
-    constant size_t *dims,
-    constant size_t *strides
-) {
-    uint strided_i = 0;
-    for (uint d = 0; d < num_dims; d++) {
-        uint dim_idx = num_dims - 1 - d;
-        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
-        idx /= dims[dim_idx];
-    }
-    return strided_i;
-}
-
-
-#define WHERE_OP(TYPENAME, ID_TYPENAME, FN_NAME) \
-kernel void FN_NAME( \
-    constant size_t &numel, \
-    constant size_t &num_dims, \
-    constant size_t *dims, \
-    constant size_t *strides, \
-    constant size_t *strides_t, \
-    constant size_t *strides_f, \
-    device const ID_TYPENAME *ids, \
-    device const TYPENAME *t, \
-    device const TYPENAME *f, \
-    device TYPENAME *out ,\
-    uint i [[ thread_position_in_grid ]] \
-) { \
-    uint strided_i = get_strided_index(i, num_dims, dims, strides); \
-    uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \
-    uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \
-    out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f]; \
-} \
-
-// WHERE_OP(float, int64_t, where_i64_f32)
-// WHERE_OP(double, int64_t, where_i64_f64)
-// WHERE_OP(uint8_t, int64_t, where_i64_u8)
-// WHERE_OP(uint32_t, int64_t, where_i64_u32)
-// WHERE_OP(int64_t, int64_t, where_i64_i64)
-//
-// WHERE_OP(float, uint32_t, where_u32_f32)
-// WHERE_OP(double, uint32_t, where_u32_f64)
-// WHERE_OP(uint8_t, uint32_t, where_u32_u8)
-// WHERE_OP(uint32_t, uint32_t, where_u32_u32)
-// WHERE_OP(int64_t, uint32_t, where_u32_i64)
-
-WHERE_OP(float, uint8_t, where_u8_f32)
-// WHERE_OP(double, uint8_t, where_u8_f64)
-// WHERE_OP(uint8_t, uint8_t, where_u8_u8)
-// WHERE_OP(uint32_t, uint8_t, where_u8_u32)
-// WHERE_OP(int64_t, uint8_t, where_u8_i64)
@@ -1,126 +0,0 @@
-#include <metal_stdlib>
-#include <metal_math>
-#
-using namespace metal;
-
-METAL_FUNC uint get_strided_index(
-    uint idx,
-    constant size_t &num_dims,
-    constant size_t *dims,
-    constant size_t *strides
-) {
-    uint strided_i = 0;
-    for (uint d = 0; d < num_dims; d++) {
-        uint dim_idx = num_dims - 1 - d;
-        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
-        idx /= dims[dim_idx];
-    }
-    return strided_i;
-}
-
-template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
-template <typename T> METAL_FUNC T neg(T in){ return -in; }
-template <typename T> METAL_FUNC T erf(T in){
-    float x = (float) in;
-    // constants
-    float a1 = 0.254829592;
-    float a2 = -0.284496736;
-    float a3 = 1.421413741;
-    float a4 = -1.453152027;
-    float a5 = 1.061405429;
-    float p = 0.3275911;
-
-    // Save the sign of x
-    int sign = 1;
-    if (x < 0)
-        sign = -1;
-    x = fabs(x);
-
-    // A&S formula 7.1.26
-    float t = 1.0/(1.0 + p*x);
-    float y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x);
-
-    return T(sign*y);
-}
-template <typename T> METAL_FUNC T id(T in){ return in; }
-template <typename T> METAL_FUNC T gelu_erf(T x){ return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2); }
-template <typename T> METAL_FUNC T gelu(T x){
-    T x_sq = x * x;
-    T x_cube = x_sq * x;
-    T alpha = x + static_cast<T>(0.044715) * x_cube;
-    T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
-    return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
-}
-
-
-
-#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
-kernel void FN_NAME( \
-    constant size_t &dim, \
-    device const TYPENAME *input, \
-    device TYPENAME *output, \
-    uint thread_position_in_grid [[ thread_position_in_grid ]] \
-) { \
-    if (thread_position_in_grid >= dim) { \
-        return; \
-    } \
-    output[thread_position_in_grid] = TYPENAME(FN(input[thread_position_in_grid])); \
-}\
-kernel void FN_NAME_STRIDED( \
-    constant size_t &dim, \
-    constant size_t &num_dims, \
-    constant size_t *dims, \
-    constant size_t *strides, \
-    device const TYPENAME *input, \
-    device TYPENAME *output, \
-    uint thread_position_in_grid [[ thread_position_in_grid ]] \
-) { \
-    if (thread_position_in_grid >= dim) { \
-        return; \
-    } \
-    output[thread_position_in_grid] = TYPENAME(FN(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)])); \
-}
-
-#define UNARY_OP(NAME) \
-UNARY(NAME, float, NAME##_float, NAME##_float_strided); \
-UNARY(NAME, half, NAME##_half, NAME##_half_strided);
-
-#define BFLOAT_UNARY_OP(NAME) \
-UNARY(NAME, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
-
-
-UNARY_OP(cos)
-UNARY_OP(sin)
-UNARY_OP(sqr)
-UNARY_OP(sqrt)
-UNARY_OP(neg)
-UNARY_OP(exp)
-UNARY_OP(log)
-UNARY_OP(gelu)
-UNARY_OP(ceil)
-UNARY_OP(floor)
-UNARY_OP(round)
-UNARY_OP(gelu_erf)
-UNARY_OP(erf)
-UNARY(id, float, copy_float, copy_float_strided)
-UNARY(id, half, copy_half, copy_half_strided)
-UNARY(id, uint8_t, copy_u8, copy_u8_strided)
-UNARY(id, uint32_t, copy_u32, copy_u32_strided)
-
-#if __METAL_VERSION__ >= 310
-BFLOAT_UNARY_OP(cos)
-BFLOAT_UNARY_OP(sin)
-BFLOAT_UNARY_OP(sqr)
-BFLOAT_UNARY_OP(sqrt)
-BFLOAT_UNARY_OP(neg)
-BFLOAT_UNARY_OP(exp)
-BFLOAT_UNARY_OP(log)
-BFLOAT_UNARY_OP(gelu)
-BFLOAT_UNARY_OP(ceil)
-BFLOAT_UNARY_OP(floor)
-BFLOAT_UNARY_OP(round)
-BFLOAT_UNARY_OP(gelu_erf)
-BFLOAT_UNARY_OP(erf)
-
-UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)
-#endif
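The `erf` template in the deleted file is the Abramowitz & Stegun formula 7.1.26 polynomial approximation (maximum absolute error around 1.5e-7). A direct Rust transcription with a sanity check:

```rust
// A&S 7.1.26 approximation of the error function, as used by the deleted
// Metal source: a degree-5 polynomial in t = 1/(1 + p*x), damped by exp(-x^2).
fn erf(x: f32) -> f32 {
    let (a1, a2, a3, a4, a5) =
        (0.254829592f32, -0.284496736, 1.421413741, -1.453152027, 1.061405429);
    let p = 0.3275911f32;

    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();

    let t = 1.0 / (1.0 + p * x);
    let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();
    sign * y
}

fn main() {
    assert!((erf(1.0) - 0.842_700_8).abs() < 1e-4); // erf(1) ≈ 0.8427008
    assert!((erf(-1.0) + erf(1.0)).abs() < 1e-6); // erf is odd
}
```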
@@ -1,76 +0,0 @@
-use candle_metal_kernels::{call_affine, Kernels};
-use metal::objc::rc::autoreleasepool;
-use metal::{Device, MTLResourceOptions};
-use rand;
-use std::any::type_name;
-use std::time::Instant;
-
-fn main() {
-    let device = Device::system_default().unwrap();
-    let kernels = Kernels::new();
-
-    let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
-    let f32_10k = (0..10000)
-        .map(|_| rand::random::<f32>())
-        .collect::<Vec<_>>();
-    let f32_100k = (0..100000)
-        .map(|_| rand::random::<f32>())
-        .collect::<Vec<_>>();
-
-    println!(
-        "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
-        "dtype", "kernel", "size", "runs", "total time", "avg time"
-    );
-
-    // f32
-    run_affine_bench(&device, &kernels, &f32_1k);
-    run_affine_bench(&device, &kernels, &f32_10k);
-    run_affine_bench(&device, &kernels, &f32_100k);
-}
-
-fn run_affine_bench<T: Clone>(device: &Device, kernels: &Kernels, v: &[T]) {
-    let command_queue = device.new_command_queue();
-    let options = MTLResourceOptions::StorageModeManaged;
-
-    let iterations = 10000;
-    let input = device.new_buffer_with_data(
-        v.as_ptr() as *const core::ffi::c_void,
-        core::mem::size_of_val(v) as u64,
-        options,
-    );
-    let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);
-
-    let mul: f32 = 1.2345;
-    let add: f32 = 2.3456;
-    let total_time = autoreleasepool(|| {
-        let command_buffer = command_queue.new_command_buffer();
-        let start = Instant::now();
-        for _ in 0..iterations {
-            call_affine(
-                &device,
-                command_buffer,
-                &kernels,
-                "affine_float",
-                v.len(),
-                &input,
-                &mut output,
-                mul,
-                add,
-            )
-            .unwrap();
-        }
-        command_buffer.commit();
-        command_buffer.wait_until_completed();
-
-        start.elapsed()
-    });
-    println!(
-        "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
-        type_name::<T>().split("::").last().unwrap(),
-        "affine",
-        v.len(),
-        iterations,
-        total_time,
-        total_time / iterations
-    );
-}
@ -1,182 +0,0 @@
use candle_metal_kernels::{binary, call_binary_contiguous, call_binary_strided, Kernels};
use half::{bf16, f16};
use metal::objc::rc::autoreleasepool;
use metal::{Device, MTLResourceOptions};
use rand;
use std::any::type_name;
use std::time::Instant;

fn main() {
    let device = Device::system_default().unwrap();
    let kernels = Kernels::new();

    let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
    let f32_10k = (0..10000)
        .map(|_| rand::random::<f32>())
        .collect::<Vec<_>>();
    let f32_100k = (0..100000)
        .map(|_| rand::random::<f32>())
        .collect::<Vec<_>>();

    let f16_map = |v: &[f32]| v.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
    let f16_1k = f16_map(&f32_1k);
    let f16_10k = f16_map(&f32_10k);
    let f16_100k = f16_map(&f32_100k);

    let bf16_map = |v: &[f32]| v.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
    let bf16_1k = bf16_map(&f32_1k);
    let bf16_10k = bf16_map(&f32_10k);
    let bf16_100k = bf16_map(&f32_100k);

    let f32_ckernels = [
        binary::contiguous::add::FLOAT,
        binary::contiguous::sub::FLOAT,
        binary::contiguous::mul::FLOAT,
        binary::contiguous::div::FLOAT,
    ];
    let f32_skernels = [
        binary::strided::add::FLOAT,
        binary::strided::sub::FLOAT,
        binary::strided::mul::FLOAT,
        binary::strided::div::FLOAT,
    ];
    let f16_ckernels = [
        binary::contiguous::add::HALF,
        binary::contiguous::sub::HALF,
        binary::contiguous::mul::HALF,
        binary::contiguous::div::HALF,
    ];
    let f16_skernels = [
        binary::strided::add::HALF,
        binary::strided::sub::HALF,
        binary::strided::mul::HALF,
        binary::strided::div::HALF,
    ];
    let bf16_ckernels = [
        binary::contiguous::add::BFLOAT,
        binary::contiguous::sub::BFLOAT,
        binary::contiguous::mul::BFLOAT,
        binary::contiguous::div::BFLOAT,
    ];
    let bf16_skernels = [
        binary::strided::add::BFLOAT,
        binary::strided::sub::BFLOAT,
        binary::strided::mul::BFLOAT,
        binary::strided::div::BFLOAT,
    ];

    println!(
        "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
        "dtype", "kernel", "size", "runs", "total time", "avg time"
    );

    // f32
    run_binary_bench(&device, &kernels, &f32_1k, f32_ckernels, f32_skernels);
    run_binary_bench(&device, &kernels, &f32_10k, f32_ckernels, f32_skernels);
    run_binary_bench(&device, &kernels, &f32_100k, f32_ckernels, f32_skernels);

    // f16
    run_binary_bench(&device, &kernels, &f16_1k, f16_ckernels, f16_skernels);
    run_binary_bench(&device, &kernels, &f16_10k, f16_ckernels, f16_skernels);
    run_binary_bench(&device, &kernels, &f16_100k, f16_ckernels, f16_skernels);

    // bf16
    run_binary_bench(&device, &kernels, &bf16_1k, bf16_ckernels, bf16_skernels);
    run_binary_bench(&device, &kernels, &bf16_10k, bf16_ckernels, bf16_skernels);
    run_binary_bench(&device, &kernels, &bf16_100k, bf16_ckernels, bf16_skernels);
}

fn run_binary_bench<T: Clone>(
    device: &Device,
    kernels: &Kernels,
    v: &[T],
    contiguous: [binary::contiguous::Kernel; 4],
    strided: [binary::strided::Kernel; 4],
) {
    let command_queue = device.new_command_queue();
    let options = MTLResourceOptions::StorageModeManaged;

    let iterations = 1000;
    let input = device.new_buffer_with_data(
        v.as_ptr() as *const core::ffi::c_void,
        core::mem::size_of_val(v) as u64,
        options,
    );
    let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);

    // Contiguous
    for kernel_name in contiguous {
        let total_time = autoreleasepool(|| {
            let command_buffer = command_queue.new_command_buffer();
            let start = Instant::now();
            for _ in 0..iterations {
                call_binary_contiguous(
                    device,
                    &command_buffer,
                    kernels,
                    kernel_name,
                    v.len(),
                    &input,
                    &input,
                    &mut output,
                )
                .unwrap();
            }
            command_buffer.commit();
            command_buffer.wait_until_completed();

            start.elapsed()
        });
        println!(
            "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
            type_name::<T>().split("::").last().unwrap(),
            kernel_name.to_string(),
            v.len(),
            iterations,
            total_time,
            total_time / iterations
        );
    }

    // Strided
    let shape = vec![2, 5_000];
    let strides = vec![2, 1];
    let offset = 0;
    for kernel_name in strided {
        let total_time = autoreleasepool(|| {
            let command_buffer = command_queue.new_command_buffer();
            let start = Instant::now();
            for _ in 0..iterations {
                call_binary_strided(
                    device,
                    command_buffer,
                    &kernels,
                    kernel_name,
                    &shape,
                    &input,
                    &strides,
                    offset,
                    &input,
                    &strides,
                    offset,
                    &mut output,
                )
                .unwrap();
            }
            command_buffer.commit();
            command_buffer.wait_until_completed();

            start.elapsed()
        });

        println!(
            "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
            type_name::<T>().split("::").last().unwrap(),
            kernel_name.to_string(),
            v.len(),
            iterations,
            total_time,
            total_time / iterations
        );
    }
}
@ -1,84 +0,0 @@
use candle_metal_kernels::{call_cast_contiguous, Kernels};
use metal::objc::rc::autoreleasepool;
use metal::{Device, MTLResourceOptions};
use rand;
use std::any::type_name;
use std::time::Instant;

fn main() {
    let device = Device::system_default().unwrap();
    let kernels = Kernels::new();

    let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
    let f32_10k = (0..10000)
        .map(|_| rand::random::<f32>())
        .collect::<Vec<_>>();
    let f32_100k = (0..100000)
        .map(|_| rand::random::<f32>())
        .collect::<Vec<_>>();

    let contiguous_kernels = ["cast_u32_f32"];

    println!(
        "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
        "dtype", "kernel", "size", "runs", "total time", "avg time"
    );

    // f32
    run_cast_bench(&device, &kernels, &f32_1k, &contiguous_kernels);
    run_cast_bench(&device, &kernels, &f32_10k, &contiguous_kernels);
    run_cast_bench(&device, &kernels, &f32_100k, &contiguous_kernels);
}

fn run_cast_bench<T: Clone>(
    device: &Device,
    kernels: &Kernels,
    v: &[T],
    contiguous: &[&'static str],
) {
    let command_queue = device.new_command_queue();
    let options = MTLResourceOptions::StorageModeManaged;

    let iterations = 1000;
    let input = device.new_buffer_with_data(
        v.as_ptr() as *const core::ffi::c_void,
        core::mem::size_of_val(v) as u64,
        options,
    );
    let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);

    // Contiguous
    for kernel_name in contiguous {
        let total_time = autoreleasepool(|| {
            let command_buffer = command_queue.new_command_buffer();
            let start = Instant::now();
            for _ in 0..iterations {
                call_cast_contiguous(
                    device,
                    &command_buffer,
                    kernels,
                    kernel_name,
                    v.len(),
                    &input,
                    &mut output,
                )
                .unwrap();
            }
            command_buffer.commit();
            command_buffer.wait_until_completed();

            start.elapsed()
        });
        println!(
            "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
            type_name::<T>().split("::").last().unwrap(),
            kernel_name.to_string(),
            v.len(),
            iterations,
            total_time,
            total_time / iterations
        );
    }

    // Strided?
}
@ -1,197 +0,0 @@
use candle_metal_kernels::{call_unary_contiguous, call_unary_strided, unary, Kernels};
use half::{bf16, f16};
use metal::objc::rc::autoreleasepool;
use metal::{Device, MTLResourceOptions};
use rand;
use std::any::type_name;
use std::time::Instant;

fn main() {
    let device = Device::system_default().unwrap();
    let kernels = Kernels::new();

    let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
    let f32_10k = (0..10000)
        .map(|_| rand::random::<f32>())
        .collect::<Vec<_>>();
    let f32_100k = (0..100000)
        .map(|_| rand::random::<f32>())
        .collect::<Vec<_>>();

    let f16_map = |v: &[f32]| v.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
    let f16_1k = f16_map(&f32_1k);
    let f16_10k = f16_map(&f32_10k);
    let f16_100k = f16_map(&f32_100k);

    let bf16_map = |v: &[f32]| v.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
    let bf16_1k = bf16_map(&f32_1k);
    let bf16_10k = bf16_map(&f32_10k);
    let bf16_100k = bf16_map(&f32_100k);

    let f32_ckernels = [
        unary::contiguous::sin::FLOAT,
        unary::contiguous::cos::FLOAT,
        unary::contiguous::exp::FLOAT,
        unary::contiguous::sqr::FLOAT,
        unary::contiguous::sqrt::FLOAT,
        unary::contiguous::neg::FLOAT,
        unary::contiguous::copy::FLOAT,
    ];
    let f32_skernels = [
        unary::strided::sin::FLOAT,
        unary::strided::cos::FLOAT,
        unary::strided::exp::FLOAT,
        unary::strided::sqr::FLOAT,
        unary::strided::sqrt::FLOAT,
        unary::strided::neg::FLOAT,
        unary::strided::copy::FLOAT,
    ];
    let f16_ckernels = [
        unary::contiguous::sin::HALF,
        unary::contiguous::cos::HALF,
        unary::contiguous::exp::HALF,
        unary::contiguous::sqr::HALF,
        unary::contiguous::sqrt::HALF,
        unary::contiguous::neg::HALF,
        unary::contiguous::copy::HALF,
    ];
    let f16_skernels = [
        unary::strided::sin::HALF,
        unary::strided::cos::HALF,
        unary::strided::exp::HALF,
        unary::strided::sqr::HALF,
        unary::strided::sqrt::HALF,
        unary::strided::neg::HALF,
        unary::strided::copy::HALF,
    ];
    let bf16_ckernels = [
        unary::contiguous::sin::BFLOAT,
        unary::contiguous::cos::BFLOAT,
        unary::contiguous::exp::BFLOAT,
        unary::contiguous::sqr::BFLOAT,
        unary::contiguous::sqrt::BFLOAT,
        unary::contiguous::neg::BFLOAT,
        unary::contiguous::copy::BFLOAT,
    ];
    let bf16_skernels = [
        unary::strided::sin::BFLOAT,
        unary::strided::cos::BFLOAT,
        unary::strided::exp::BFLOAT,
        unary::strided::sqr::BFLOAT,
        unary::strided::sqrt::BFLOAT,
        unary::strided::neg::BFLOAT,
        unary::strided::copy::BFLOAT,
    ];

    println!(
        "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
        "dtype", "kernel", "size", "runs", "total time", "avg time"
    );

    // f32
    run_unary_bench(&device, &kernels, &f32_1k, f32_ckernels, f32_skernels);
    run_unary_bench(&device, &kernels, &f32_10k, f32_ckernels, f32_skernels);
    run_unary_bench(&device, &kernels, &f32_100k, f32_ckernels, f32_skernels);

    // f16
    run_unary_bench(&device, &kernels, &f16_1k, f16_ckernels, f16_skernels);
    run_unary_bench(&device, &kernels, &f16_10k, f16_ckernels, f16_skernels);
    run_unary_bench(&device, &kernels, &f16_100k, f16_ckernels, f16_skernels);

    // bf16
    run_unary_bench(&device, &kernels, &bf16_1k, bf16_ckernels, bf16_skernels);
    run_unary_bench(&device, &kernels, &bf16_10k, bf16_ckernels, bf16_skernels);
    run_unary_bench(&device, &kernels, &bf16_100k, bf16_ckernels, bf16_skernels);
}

fn run_unary_bench<T: Clone>(
    device: &Device,
    kernels: &Kernels,
    v: &[T],
    contiguous: [unary::contiguous::Kernel; 7],
    strided: [unary::strided::Kernel; 7],
) {
    let command_queue = device.new_command_queue();
    let options = MTLResourceOptions::StorageModeManaged;

    let iterations = 10000;
    let input = device.new_buffer_with_data(
        v.as_ptr() as *const core::ffi::c_void,
        core::mem::size_of_val(v) as u64,
        options,
    );
    let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);

    // Contiguous
    for kernel_name in contiguous {
        let total_time = autoreleasepool(|| {
            let command_buffer = command_queue.new_command_buffer();
            let start = Instant::now();
            for _ in 0..iterations {
                call_unary_contiguous(
                    device,
                    &command_buffer,
                    kernels,
                    kernel_name,
                    v.len(),
                    &input,
                    &mut output,
                )
                .unwrap();
            }
            command_buffer.commit();
            command_buffer.wait_until_completed();

            start.elapsed()
        });
        println!(
            "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
            type_name::<T>().split("::").last().unwrap(),
            kernel_name.0,
            v.len(),
            iterations,
            total_time,
            total_time / iterations
        );
    }

    // Strided
    let shape = vec![2, 5_000];
    let strides = vec![2, 1];
    let offset = 0;
    for kernel_name in &strided {
        let total_time = autoreleasepool(|| {
            let command_buffer = command_queue.new_command_buffer();
            let start = Instant::now();
            for _ in 0..iterations {
                call_unary_strided(
                    device,
                    command_buffer,
                    &kernels,
                    kernel_name,
                    &shape,
                    &input,
                    &strides,
                    offset,
                    &mut output,
                    0,
                )
                .unwrap();
            }
            command_buffer.commit();
            command_buffer.wait_until_completed();

            start.elapsed()
        });

        println!(
            "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
            type_name::<T>().split("::").last().unwrap(),
            kernel_name.0,
            v.len(),
            iterations,
            total_time,
            total_time / iterations
        );
    }
}
@ -19,7 +19,6 @@ num-traits = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }

[dev-dependencies]
anyhow = { workspace = true }
@ -29,5 +28,5 @@ clap = { workspace = true }
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate"]
cuda = ["candle/cuda"]
metal = ["candle/metal"]
mkl = ["dep:intel-mkl-src", "candle/mkl"]
metal = ["candle/metal", "dep:candle-metal-kernels"]
@ -6,16 +6,14 @@ use serde::Deserialize;
pub enum Activation {
    #[default]
    Gelu,
    #[serde(rename = "gated-gelu")]
    NewGelu,
    Relu,
    Relu2,
    Relu6,
    Silu,
    Sigmoid,
    HardSigmoid,
    Swiglu,
    Swish,
    HardSwish,
    Elu(f64),
    LeakyRelu(f64),
}
@ -31,10 +29,7 @@ impl super::Module for Activation {
            Self::Relu6 => xs.clamp(0f32, 6f32),
            Self::Silu => crate::ops::silu(xs),
            Self::Sigmoid => crate::ops::sigmoid(xs),
            Self::HardSigmoid => crate::ops::hard_sigmoid(xs),
            Self::Swiglu => crate::ops::swiglu(xs),
            Self::Swish => xs * crate::ops::sigmoid(xs)?,
            Self::HardSwish => xs * crate::ops::hard_sigmoid(xs)?,
            &Self::Elu(alpha) => xs.elu(alpha),
            &Self::LeakyRelu(negative_slope) => crate::ops::leaky_relu(xs, negative_slope),
        }
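On the side of the diff that still carries `HardSigmoid`, `Swiglu`, and `HardSwish`, the enum is driven through the `Module` impl shown in the second hunk. A minimal sketch, assuming the `candle`/`candle-nn` crates from this tree and made-up shapes:

```rust
use candle::{Device, Result, Tensor};
use candle_nn::{Activation, Module};

fn main() -> Result<()> {
    // The last dimension must be even for Swiglu since it chunks it in two.
    let xs = Tensor::randn(0f32, 1f32, (2, 8), &Device::Cpu)?;
    // HardSwish computes xs * hard_sigmoid(xs), as in the match arm above.
    let ys = Activation::HardSwish.forward(&xs)?;
    // Swiglu halves the last dimension: (2, 8) -> (2, 4).
    let zs = Activation::Swiglu.forward(&xs)?;
    println!("{:?} {:?}", ys.shape(), zs.shape());
    Ok(())
}
```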
@ -70,67 +70,6 @@ impl crate::Module for Conv1d {
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ConvTranspose1dConfig {
    pub padding: usize,
    pub output_padding: usize,
    pub stride: usize,
    pub dilation: usize,
    // TODO: support groups.
}

impl Default for ConvTranspose1dConfig {
    fn default() -> Self {
        Self {
            padding: 0,
            output_padding: 0,
            stride: 1,
            dilation: 1,
        }
    }
}

#[derive(Clone, Debug)]
pub struct ConvTranspose1d {
    weight: Tensor,
    bias: Option<Tensor>,
    config: ConvTranspose1dConfig,
}

impl ConvTranspose1d {
    pub fn new(weight: Tensor, bias: Option<Tensor>, config: ConvTranspose1dConfig) -> Self {
        Self {
            weight,
            bias,
            config,
        }
    }

    pub fn config(&self) -> &ConvTranspose1dConfig {
        &self.config
    }
}

impl crate::Module for ConvTranspose1d {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x = x.conv_transpose1d(
            &self.weight,
            self.config.padding,
            self.config.output_padding,
            self.config.stride,
            self.config.dilation,
        )?;
        match &self.bias {
            None => Ok(x),
            Some(bias) => {
                let b = bias.dims1()?;
                // The output of conv_transpose1d is 3d (batch, channels, length),
                // so the bias has to broadcast as (1, b, 1).
                let bias = bias.reshape((1, b, 1))?;
                Ok(x.broadcast_add(&bias)?)
            }
        }
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Conv2dConfig {
    pub padding: usize,
@ -302,39 +241,6 @@ pub fn conv1d(
    Ok(Conv1d::new(ws, Some(bs), cfg))
}

pub fn conv_transpose1d(
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    cfg: ConvTranspose1dConfig,
    vb: crate::VarBuilder,
) -> Result<ConvTranspose1d> {
    let bound = 1. / (out_channels as f64 * kernel_size as f64).sqrt();
    let init = crate::Init::Uniform {
        lo: -bound,
        up: bound,
    };
    let ws = vb.get_with_hints((in_channels, out_channels, kernel_size), "weight", init)?;
    let bs = vb.get_with_hints(out_channels, "bias", init)?;
    Ok(ConvTranspose1d::new(ws, Some(bs), cfg))
}

pub fn conv_transpose1d_no_bias(
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    cfg: ConvTranspose1dConfig,
    vb: crate::VarBuilder,
) -> Result<ConvTranspose1d> {
    let bound = 1. / (out_channels as f64 * kernel_size as f64).sqrt();
    let init = crate::Init::Uniform {
        lo: -bound,
        up: bound,
    };
    let ws = vb.get_with_hints((in_channels, out_channels, kernel_size), "weight", init)?;
    Ok(ConvTranspose1d::new(ws, None, cfg))
}

pub fn conv2d(
    in_channels: usize,
    out_channels: usize,
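For reference, a sketch of how the removed module and builders fit together. This assumes `ConvTranspose1d` and its config are re-exported at the `candle_nn` root, and the shapes are made up; the kernel is laid out as `(in_channels, out_channels, kernel_size)`, matching the `get_with_hints` call above.

```rust
use candle::{Device, Result, Tensor};
use candle_nn::{ConvTranspose1d, ConvTranspose1dConfig, Module};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // (in_channels, out_channels, kernel_size) = (4, 2, 3).
    let weight = Tensor::randn(0f32, 1f32, (4, 2, 3), &dev)?;
    let conv = ConvTranspose1d::new(weight, None, ConvTranspose1dConfig::default());
    let xs = Tensor::randn(0f32, 1f32, (1, 4, 10), &dev)?;
    // With stride 1 and no padding, the length grows to 10 - 1 + 3 = 12.
    let ys = conv.forward(&xs)?;
    assert_eq!(ys.dims(), &[1, 2, 12]);
    Ok(())
}
```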
@ -95,14 +95,6 @@ impl LayerNorm {
            eps,
        }
    }

    pub fn weight(&self) -> &Tensor {
        &self.weight
    }

    pub fn bias(&self) -> Option<&Tensor> {
        self.bias.as_ref()
    }
}

impl crate::Module for LayerNorm {
@ -39,21 +39,11 @@ pub fn silu(xs: &Tensor) -> Result<Tensor> {
    xs / (xs.neg()?.exp()? + 1.0)?
}

pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
    let xs = xs.chunk(2, candle::D::Minus1)?;
    crate::ops::silu(&xs[0])? * &xs[1]
}

pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
    // TODO: Should we have a specialized op for this?
    (xs.neg()?.exp()? + 1.0)?.recip()
}

pub fn hard_sigmoid(xs: &Tensor) -> Result<Tensor> {
    // TODO: Should we have a specialized op for this?
    ((xs + 3.0)? / 6.0)?.clamp(0f32, 1f32)
}

pub fn leaky_relu(xs: &Tensor, negative_slope: f64) -> Result<Tensor> {
    let zeros = xs.zeros_like()?;
    xs.maximum(&zeros)? + xs.minimum(&zeros)? * negative_slope
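The removed helpers are plain tensor formulas: `swiglu` splits the last dimension in two and gates one half with `silu` of the other, and `hard_sigmoid` is `clamp((x + 3) / 6, 0, 1)`. A small sketch, valid on the side of the diff that still exports them, spot-checking the hard sigmoid at its breakpoints:

```rust
use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let xs = Tensor::new(&[-3f32, 0., 3.], &Device::Cpu)?;
    // clamp((x + 3) / 6, 0, 1): -3 -> 0, 0 -> 0.5, 3 -> 1.
    let ys = candle_nn::ops::hard_sigmoid(&xs)?;
    assert_eq!(ys.to_vec1::<f32>()?, vec![0.0, 0.5, 1.0]);
    Ok(())
}
```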
@ -200,7 +190,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
            device: dev.clone(),
        };
        Ok((dst, layout.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
@ -208,29 +198,8 @@ impl candle::CustomOp1 for SoftmaxLastDim {
        storage: &candle::MetalStorage,
        layout: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        println!("TODO softmax-last-dim");
        let device = storage.device();
        Ok((storage.clone(), layout.shape().clone()))
        let command_buffer = device.command_buffer();
        let kernels = device.kernels();
        let name = "softmax_float";
        assert!(layout.is_contiguous());
        assert!(layout.start_offset() == 0);
        let last_dim = layout.dims()[layout.shape().rank() - 1];
        let elem_count = layout.shape().elem_count();
        let mut output = device.new_buffer(elem_count, storage.dtype());
        candle_metal_kernels::call_last_softmax(
            device.metal_device(),
            &command_buffer,
            &kernels,
            name,
            elem_count,
            last_dim,
            storage.buffer(),
            &mut output,
        )
        .unwrap();
        let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
        Ok((newstorage, layout.shape().clone()))
    }
}
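This hunk trades the `call_last_softmax` kernel dispatch for a `println!` stub on one side of the compare. Either way, the user-facing entry point is unchanged; a minimal sketch on CPU, with made-up shapes:

```rust
use candle::{D, Device, Result, Tensor};

fn main() -> Result<()> {
    let xs = Tensor::randn(0f32, 1f32, (2, 5), &Device::Cpu)?;
    // Normalizes over the last axis; each row of the result sums to ~1.
    let probs = candle_nn::ops::softmax_last_dim(&xs)?;
    println!("{:?}", probs.sum(D::Minus1)?.to_vec1::<f32>()?);
    Ok(())
}
```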
@ -1,23 +0,0 @@
[package]
name = "candle-onnx"
version = "0.3.0"
edition = "2021"

description = "ONNX support for Candle"
repository = "https://github.com/huggingface/candle"
keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT OR Apache-2.0"

[dependencies]
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
prost = "0.12.1"

[build-dependencies]
prost-build = "0.12.1"

[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
clap = { version = "4.2.4", features = ["derive"] }
@ -1,21 +0,0 @@
# candle-onnx

This crate adds ONNX support to candle.

## FAQ

#### Missing protoc installation when compiling candle-onnx

The candle-onnx dependency prost-build no longer comes bundled with `protoc`
binaries. This can cause the following error when attempting to compile
candle-onnx:

```
error: failed to run custom build command for `candle-onnx`
Caused by: // (...)
Could not find `protoc` installation and this build crate cannot proceed without this knowledge.
```

To fix this issue, install protoc on your system and make it available in your
system `PATH`. See the [protoc documentation](https://grpc.io/docs/protoc-installation/)
for more information.
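On Debian or Ubuntu, for example, installing the `protobuf-compiler` package typically provides a suitable `protoc` (an assumption about your platform; see the linked documentation for other systems).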
@ -1,6 +0,0 @@
use std::io::Result;

fn main() -> Result<()> {
    prost_build::compile_protos(&["src/onnx.proto3"], &["src/"])?;
    Ok(())
}
@ -1,755 +0,0 @@
use crate::onnx;
use crate::onnx::attribute_proto::AttributeType;
use crate::onnx::tensor_proto::DataType;
use candle::{bail, DType, Device, Result, Tensor};
use std::collections::HashMap;

pub type Value = Tensor;

pub fn dtype(dt: DataType) -> Option<DType> {
    match dt {
        DataType::Uint8 => Some(DType::U8),
        DataType::Uint32 => Some(DType::U32),
        DataType::Int64 => Some(DType::I64),
        DataType::Float16 => Some(DType::F16),
        DataType::Float => Some(DType::F32),
        DataType::Double => Some(DType::F64),
        _ => None,
    }
}

trait Attr {
    const TYPE: AttributeType;
    fn get(attr: &onnx::AttributeProto) -> Result<&Self>;
}

impl Attr for i64 {
    const TYPE: AttributeType = AttributeType::Int;
    fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
        Ok(&attr.i)
    }
}

impl Attr for f32 {
    const TYPE: AttributeType = AttributeType::Float;
    fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
        Ok(&attr.f)
    }
}

impl Attr for [i64] {
    const TYPE: AttributeType = AttributeType::Ints;
    fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
        Ok(attr.ints.as_slice())
    }
}

impl Attr for str {
    const TYPE: AttributeType = AttributeType::String;
    fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
        std::str::from_utf8(&attr.s).map_err(candle::Error::wrap)
    }
}

fn get_attr_<'a>(node: &'a onnx::NodeProto, name: &str) -> Result<&'a onnx::AttributeProto> {
    match node.attribute.iter().find(|attr| attr.name == name) {
        None => {
            bail!(
                "cannot find the '{name}' attribute in '{}' for {}",
                node.op_type,
                node.name
            )
        }
        Some(dt) => Ok(dt),
    }
}

fn get_attr<'a, T: Attr + ?Sized>(node: &'a onnx::NodeProto, name: &str) -> Result<&'a T> {
    let attr = get_attr_(node, name)?;
    if attr.r#type() != T::TYPE {
        bail!(
            "unsupported type {:?} for '{name}' attribute in '{}' for {}",
            attr.r#type,
            node.op_type,
            node.name
        )
    }
    T::get(attr)
}

fn get_attr_opt<'a, T: Attr + ?Sized>(
    node: &'a onnx::NodeProto,
    name: &str,
) -> Result<Option<&'a T>> {
    match node.attribute.iter().find(|attr| attr.name == name) {
        None => Ok(None),
        Some(attr) => {
            if attr.r#type() != T::TYPE {
                bail!(
                    "unsupported type {:?} for '{name}' attribute in '{}' for {}",
                    attr.r#type,
                    node.op_type,
                    node.name
                )
            }
            let val = T::get(attr)?;
            Ok(Some(val))
        }
    }
}

pub fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
    let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
    match DataType::try_from(t.data_type) {
        Ok(DataType::Int32) => {
            if t.int32_data.is_empty() {
                let len = t.raw_data.len() / 4;
                let data: &[i32] =
                    unsafe { std::slice::from_raw_parts(t.raw_data.as_ptr() as *const i32, len) };
                let data = data.iter().map(|v| *v as i64).collect::<Vec<_>>();
                Tensor::from_vec(data, len, &Device::Cpu)
            } else {
                let data = t.int32_data.iter().map(|v| *v as i64).collect::<Vec<_>>();
                Tensor::from_vec(data, t.int32_data.len(), &Device::Cpu)
            }
        }
        Ok(dt) => match dtype(dt) {
            Some(dt) => {
                if dt == DType::F32 && !t.float_data.is_empty() {
                    Tensor::from_slice(&t.float_data, dims.as_slice(), &Device::Cpu)
                } else if dt == DType::F64 && !t.double_data.is_empty() {
                    Tensor::from_slice(&t.double_data, dims.as_slice(), &Device::Cpu)
                } else if dt == DType::I64 && !t.int64_data.is_empty() {
                    Tensor::from_slice(&t.int64_data, dims.as_slice(), &Device::Cpu)
                } else {
                    Tensor::from_raw_buffer(
                        t.raw_data.as_slice(),
                        dt,
                        dims.as_slice(),
                        &Device::Cpu,
                    )
                }
            }
            None => {
                bail!("unsupported 'value' data-type {dt:?} for {name}")
            }
        },
        Err(_) => {
            bail!("unsupported 'value' data-type {} for {name}", t.data_type,)
        }
    }
}

// This function provides a direct evaluation of the proto.
// Longer-term, we should first convert the proto to an intermediate representation of the compute
// graph so as to make multiple evaluations more efficient.
// An example upside of this would be to remove intermediary values when they are not needed
// anymore.
pub fn simple_eval(
    model: &onnx::ModelProto,
    inputs: HashMap<String, Value>,
) -> Result<HashMap<String, Value>> {
    let graph = match &model.graph {
        None => bail!("no graph defined in proto"),
        Some(graph) => graph,
    };
    let mut values = inputs;
    for t in graph.initializer.iter() {
        let tensor = get_tensor(t, t.name.as_str())?;
        values.insert(t.name.to_string(), tensor);
    }
    for input in graph.input.iter() {
        let input_type = match &input.r#type {
            Some(input_type) => input_type,
            None => continue,
        };
        let input_type = match &input_type.value {
            Some(input_type) => input_type,
            None => continue,
        };
        let tensor_type = match input_type {
            onnx::type_proto::Value::TensorType(tt) => tt,
            _ => continue,
        };

        let tensor = match values.get(&input.name) {
            None => bail!("missing input {}", input.name),
            Some(tensor) => tensor,
        };
        let dt = match DataType::try_from(tensor_type.elem_type) {
            Ok(dt) => match dtype(dt) {
                Some(dt) => dt,
                None => {
                    bail!("unsupported 'value' data-type {dt:?} for {}", input.name)
                }
            },
            type_ => bail!("unsupported input type {type_:?}"),
        };
        match &tensor_type.shape {
            None => continue,
            Some(shape) => {
                if shape.dim.len() != tensor.rank() {
                    bail!(
                        "unexpected rank for {}, got {:?}, expected {:?}",
                        input.name,
                        shape.dim,
                        tensor.shape()
                    )
                }
                for (idx, (d, &dim)) in shape.dim.iter().zip(tensor.dims().iter()).enumerate() {
                    match &d.value {
                        Some(onnx::tensor_shape_proto::dimension::Value::DimValue(v)) => {
                            if *v as usize != dim {
                                bail!(
                                    "unexpected dim {idx} for {}, got {:?}, expected {:?}",
                                    input.name,
                                    shape.dim,
                                    tensor.shape()
                                )
                            }
                        }
                        // We do not check equality constraints for the DimParam dimensions for now.
                        Some(onnx::tensor_shape_proto::dimension::Value::DimParam(_)) | None => (),
                    }
                }
            }
        };
        if dt != tensor.dtype() {
            bail!(
                "unexpected dtype for {}, got {:?}, expected {dt:?}",
                input.name,
                tensor.dtype()
            )
        }
    }
    // The nodes are topologically sorted so we can just process them in order.
    for node in graph.node.iter() {
        let get = |input_name: &str| match values.get(input_name) {
            Some(value) => Ok(value),
            None => bail!("cannot find {input_name} for op {}", node.name),
        };
        // TODO: Validate node.input for each operator.
        match node.op_type.as_str() {
            "Add" => {
                let input0 = get(&node.input[0])?;
                let input1 = get(&node.input[1])?;
                let output = input0.broadcast_add(input1)?;
                values.insert(node.output[0].clone(), output);
            }
            "Sub" => {
                let input0 = get(&node.input[0])?;
                let input1 = get(&node.input[1])?;
                let output = input0.broadcast_sub(input1)?;
                values.insert(node.output[0].clone(), output);
            }
            "Mul" => {
                let input0 = get(&node.input[0])?;
                let input1 = get(&node.input[1])?;
                let output = input0.broadcast_mul(input1)?;
                values.insert(node.output[0].clone(), output);
            }
            "Div" => {
                let input0 = get(&node.input[0])?;
                let input1 = get(&node.input[1])?;
                let output = input0.broadcast_div(input1)?;
                values.insert(node.output[0].clone(), output);
            }
            "Equal" => {
                let input0 = get(&node.input[0])?;
                let input1 = get(&node.input[1])?;
                let output = input0.broadcast_eq(input1)?;
                values.insert(node.output[0].clone(), output);
            }
            "Not" => {
                let xs = get(&node.input[0])?;
                let xs = xs.eq(&xs.zeros_like()?)?;
                values.insert(node.output[0].clone(), xs);
            }
            "MatMul" => {
                let input0 = get(&node.input[0])?;
                let input1 = get(&node.input[1])?;
                let output = input0.broadcast_matmul(input1)?;
                values.insert(node.output[0].clone(), output);
            }
            "Reshape" => {
                let input0 = get(&node.input[0])?;
                let input1 = get(&node.input[1])?.to_vec1::<i64>()?;
                // TODO: Check that there is at most a single -1 or 0, handle other neg values.
                let mut other_than_minus1 = 1usize;
                for &v in input1.iter() {
                    if v != -1 && v != 0 {
                        other_than_minus1 *= v as usize
                    }
                }
                let input1 = input1
                    .iter()
                    .enumerate()
                    .map(|(idx, &v)| match v {
                        -1 => Ok(input0.elem_count() / other_than_minus1),
                        0 => input0.dim(idx),
                        _ => Ok(v as usize),
                    })
                    .collect::<Result<Vec<usize>>>()?;
                let output = input0.reshape(input1)?;
                values.insert(node.output[0].clone(), output);
            }
            "LogSoftmax" => {
                let input = get(&node.input[0])?;
                let output = match get_attr_opt::<i64>(node, "axis")? {
                    None => candle_nn::ops::softmax_last_dim(input)?,
                    Some(&axis) => {
                        let axis = input.normalize_axis(axis)?;
                        candle_nn::ops::log_softmax(input, axis)?
                    }
                };
                values.insert(node.output[0].clone(), output);
            }
            "Softmax" => {
                let input = get(&node.input[0])?;
                let output = match get_attr_opt::<i64>(node, "axis")? {
                    None => candle_nn::ops::softmax_last_dim(input)?,
                    Some(&axis) => {
                        let axis = input.normalize_axis(axis)?;
                        candle_nn::ops::softmax(input, axis)?
                    }
                };
                values.insert(node.output[0].clone(), output);
            }
            "Transpose" => {
                let input = get(&node.input[0])?;
                let output = match get_attr_opt::<[i64]>(node, "perm")? {
                    None => input.t()?,
                    Some(perm) => {
                        let perm = perm.iter().map(|&v| v as usize).collect::<Vec<_>>();
                        input.permute(perm)?
                    }
                };
                values.insert(node.output[0].clone(), output);
            }
            "Dropout" => {
                let input = get(&node.input[0])?;
                // Do not apply dropout at the moment, consider that we're only doing inference.
                values.insert(node.output[0].clone(), input.clone());
            }
            "MaxPool" => {
                // https://github.com/onnx/onnx/blob/main/docs/Operators.md#MaxPool
                let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
                let kernel_shape = get_attr::<[i64]>(node, "kernel_shape")?;
                let pads = get_attr_opt::<[i64]>(node, "pads")?;
                let strides = get_attr_opt::<[i64]>(node, "strides")?;
                let auto_pad = get_attr_opt::<str>(node, "auto_pad")?;
                match auto_pad {
                    None | Some("NOTSET") => (),
                    Some(s) => bail!("unsupported auto_pad {s}"),
                };
                if let Some(d) = dilations {
                    if d.iter().any(|&v| v != 1) {
                        bail!("MaxPool with dilation != 1, {dilations:?}")
                    }
                }
                if let Some(d) = pads {
                    if d.iter().any(|&v| v != 0) {
                        bail!("MaxPool with pads != 0, {pads:?}")
                    }
                }
                let xs = get(&node.input[0])?;
                let (k1, k2) = match kernel_shape {
                    [k1, k2] => (*k1 as usize, *k2 as usize),
                    _ => bail!("only 2d MaxPool is supported, kernel shape {kernel_shape:?}"),
                };
                let ys = match strides {
                    None => xs.max_pool2d((k1, k2))?,
                    Some([s1, s2]) => {
                        xs.max_pool2d_with_stride((k1, k2), (*s1 as usize, *s2 as usize))?
                    }
                    Some(strides) => bail!("only 2d MaxPool is supported, strides {strides:?}"),
                };
                values.insert(node.output[0].clone(), ys);
            }
            "AveragePool" => {
                // https://github.com/onnx/onnx/blob/main/docs/Operators.md#AveragePool
                let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
                let kernel_shape = get_attr::<[i64]>(node, "kernel_shape")?;
                let pads = get_attr_opt::<[i64]>(node, "pads")?;
                let strides = get_attr_opt::<[i64]>(node, "strides")?;
                let auto_pad = get_attr_opt::<str>(node, "auto_pad")?;
                match auto_pad {
                    None | Some("NOTSET") => (),
                    Some(s) => bail!("unsupported auto_pad {s}"),
                };
                if let Some(d) = dilations {
                    if d.iter().any(|&v| v != 1) {
                        bail!("AvgPool with dilation != 1, {dilations:?}")
                    }
                }
                if let Some(d) = pads {
                    if d.iter().any(|&v| v != 0) {
                        bail!("AvgPool with pads != 0, {pads:?}")
                    }
                }
                let xs = get(&node.input[0])?;
                let (k1, k2) = match kernel_shape {
                    [k1, k2] => (*k1 as usize, *k2 as usize),
                    _ => bail!("only 2d AvgPool is supported, kernel shape {kernel_shape:?}"),
                };
                let ys = match strides {
                    None => xs.avg_pool2d((k1, k2))?,
                    Some([s1, s2]) => {
                        xs.avg_pool2d_with_stride((k1, k2), (*s1 as usize, *s2 as usize))?
                    }
                    Some(strides) => bail!("only 2d AvgPool is supported, strides {strides:?}"),
                };
                values.insert(node.output[0].clone(), ys);
            }
            "BatchNormalization" => {
                let training_mode = get_attr_opt::<i64>(node, "training_mode")?;
                if training_mode.copied().unwrap_or(0) != 0 {
                    bail!("training mode is not supported for BatchNorm")
                }
                let eps = get_attr_opt::<f32>(node, "epsilon")?
                    .copied()
                    .unwrap_or(1e-5);
                let xs = get(&node.input[0])?;
                let weight = get(&node.input[1])?;
                let bias = get(&node.input[2])?;
                let running_mean = get(&node.input[3])?;
                let running_var = get(&node.input[4])?;
                let target_shape: Vec<usize> = xs
                    .dims()
                    .iter()
                    .enumerate()
                    .map(|(idx, v)| if idx == 1 { *v } else { 1 })
                    .collect();
                let target_shape = target_shape.as_slice();
                let xs = xs
                    .broadcast_sub(&running_mean.reshape(target_shape)?)?
                    .broadcast_div(&(running_var.reshape(target_shape)? + eps as f64)?.sqrt()?)?;
                let weight = weight.reshape(target_shape)?;
                let bias = bias.reshape(target_shape)?;
                let xs = xs.broadcast_mul(&weight)?.broadcast_add(&bias)?;
                values.insert(node.output[0].clone(), xs);
            }
            "Squeeze" => {
                let xs = get(&node.input[0])?;
                let mut axes = if node.input.len() <= 1 {
                    // contract all the dimensions with size 1 except the batch dim.
                    xs.dims()
                        .iter()
                        .enumerate()
                        .flat_map(|(idx, &s)| if s == 1 && idx > 0 { Some(idx) } else { None })
                        .collect()
                } else {
                    get(&node.input[1])?
                        .to_vec1::<i64>()?
                        .iter()
                        .map(|&i| xs.normalize_axis(i))
                        .collect::<Result<Vec<_>>>()?
                };
                axes.sort();
                let mut xs = xs.clone();
                for &axis in axes.iter().rev() {
                    xs = xs.squeeze(axis)?
                }
                values.insert(node.output[0].clone(), xs);
            }
            "ConstantOfShape" => {
                let dims = get(&node.input[0])?;
                let shape = dims
                    .to_vec1::<i64>()?
                    .into_iter()
                    .map(|v| v as usize)
                    .collect::<Vec<_>>();
                let xs = Tensor::zeros(shape, DType::F32, dims.device())?;
                values.insert(node.output[0].clone(), xs);
            }
            "Unsqueeze" => {
                let xs = get(&node.input[0])?;
                let axes = match get_attr_opt::<[i64]>(node, "axes")? {
                    Some(axis) => axis.to_vec(),
                    None => get(&node.input[1])?.to_vec1::<i64>()?,
                };
                let mut axes = axes
                    .iter()
                    .map(|&i| {
                        if i == xs.rank() as i64 {
                            Ok(xs.rank())
                        } else {
                            xs.normalize_axis(i)
                        }
                    })
                    .collect::<Result<Vec<_>>>()?;
                axes.sort();
                let mut xs = xs.clone();
                for &axis in axes.iter().rev() {
                    xs = xs.unsqueeze(axis)?
                }
                values.insert(node.output[0].clone(), xs);
            }
            "Clip" => {
                let xs = get(&node.input[0])?;
                let xs = if node.input.len() >= 2 {
                    let mins = get(&node.input[1])?;
                    xs.broadcast_maximum(mins)?
                } else {
                    xs.clone()
                };
                let xs = if node.input.len() >= 3 {
                    let maxs = get(&node.input[2])?;
                    xs.broadcast_minimum(maxs)?
                } else {
                    xs.clone()
                };
                values.insert(node.output[0].clone(), xs);
            }
            "Gather" => {
                let xs = get(&node.input[0])?;
                let indices = get(&node.input[1])?;
                let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
                let axis = xs.normalize_axis(axis)?;
                // TODO: Provide an op to handle the ONNX generalized gather op ideally in a
                // differentiable way.
                let xs = if indices.rank() == 0 {
                    let index = indices.to_vec0::<i64>()? as usize;
                    xs.narrow(axis, index, 1)?.squeeze(axis)?
                } else {
                    todo!("implement gather for {xs:?} {indices:?} axis {axis}")
                };
                values.insert(node.output[0].clone(), xs);
            }
            "Shape" => {
                // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape
                let xs = get(&node.input[0])?;
                let start = get_attr_opt::<i64>(node, "start")?.copied().unwrap_or(0);
                let end = get_attr_opt::<i64>(node, "end")?.copied().unwrap_or(-1);
                let start = xs.normalize_axis(start)?;
                let end = xs.normalize_axis(end)?;
                let mut dims = vec![];
                for idx in start..=end {
                    dims.push(xs.dim(idx)? as i64)
                }
                let dims = Tensor::from_vec(dims, xs.rank(), xs.device())?;
                values.insert(node.output[0].clone(), dims);
            }
            "Conv" => {
                // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
                let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
                let groups = get_attr_opt::<i64>(node, "group")?.copied().unwrap_or(1);
                let _kernel_shape = get_attr_opt::<[i64]>(node, "kernel_shape")?;
                let pads = get_attr_opt::<[i64]>(node, "pads")?;
                let strides = get_attr_opt::<[i64]>(node, "strides")?;
                let auto_pad = get_attr_opt::<str>(node, "auto_pad")?;
                match auto_pad {
                    None | Some("NOTSET") => (),
                    Some(s) => bail!("unsupported auto_pad {s}"),
                };
                let xs = get(&node.input[0])?;
                let ws = get(&node.input[1])?;
                let ys = match ws.rank() {
                    3 => {
                        let (pads, xs) = match pads {
                            None => (0, xs.clone()),
                            Some([p]) => (*p as usize, xs.clone()),
                            Some([p1, p2]) => {
                                if p1 != p2 {
                                    (0usize, xs.pad_with_zeros(2, *p1 as usize, *p2 as usize)?)
                                } else {
                                    (*p1 as usize, xs.clone())
                                }
                            }
                            Some(pads) => {
                                bail!("more pads than expected in conv1d {pads:?} {}", node.name)
                            }
                        };
                        let strides = match strides {
                            None => 1,
                            Some([p]) => *p as usize,
                            Some(s) => {
                                bail!("more strides than expected in conv1d {s:?} {}", node.name)
                            }
                        };
                        let dilations = match dilations {
                            None => 1,
                            Some([p]) => *p as usize,
                            Some(s) => {
                                bail!("more dilations than expected in conv1d {s:?} {}", node.name)
                            }
                        };
                        xs.conv1d(ws, pads, strides, dilations, groups as usize)?
                    }
                    4 => {
                        let (pads, xs) = match pads {
                            None => (0, xs.clone()),
                            Some([p]) => (*p as usize, xs.clone()),
                            Some(&[p1, p2, p3, p4]) => {
                                let p1 = p1 as usize;
                                let p2 = p2 as usize;
                                let p3 = p3 as usize;
                                let p4 = p4 as usize;
                                if p1 != p2 || p1 != p3 || p1 != p4 {
                                    (0, xs.pad_with_zeros(2, p1, p3)?.pad_with_zeros(3, p2, p4)?)
                                } else {
                                    (p1, xs.clone())
                                }
                            }
                            Some(pads) => {
                                bail!("more pads than expected in conv2d {pads:?} {}", node.name)
                            }
                        };
                        let strides = match strides {
                            None => 1,
                            Some([p]) => *p as usize,
                            Some([p1, p2]) => {
                                if p1 != p2 {
                                    bail!(
                                        "strides have to be the same on both axis {pads:?} {}",
                                        node.name
                                    )
                                }
                                *p1 as usize
                            }
                            Some(s) => {
                                bail!("more strides than expected in conv2d {s:?} {}", node.name)
                            }
                        };
                        let dilations = match dilations {
                            None => 1,
                            Some([p]) => *p as usize,
                            Some([p1, p2]) => {
                                if p1 != p2 {
                                    bail!(
                                        "dilations have to be the same on both axis {pads:?} {}",
                                        node.name
                                    )
                                }
                                *p1 as usize
                            }
                            Some(s) => {
                                bail!("more dilations than expected in conv2d {s:?} {}", node.name)
                            }
                        };
                        xs.conv2d(ws, pads, strides, dilations, groups as usize)?
                    }
                    rank => bail!(
                        "unsupported rank for weight matrix {rank} in conv {}",
                        node.name
                    ),
                };
                let ys = if node.input.len() > 2 {
                    let bs = get(&node.input[2])?;
                    let mut bs_shape = vec![1; ys.rank()];
                    bs_shape[1] = bs.elem_count();
                    ys.broadcast_add(&bs.reshape(bs_shape)?)?
                } else {
                    ys
                };
                values.insert(node.output[0].clone(), ys);
            }
            "Concat" => {
                // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Concat
                let inputs = node
                    .input
                    .iter()
                    .map(|n| Ok(get(n.as_str())?.clone()))
                    .collect::<Result<Vec<Value>>>()?;
                let axis: i64 = *get_attr(node, "axis")?;
                if inputs.is_empty() {
                    bail!("empty concat")
                };
                let axis = inputs[0].normalize_axis(axis)?;
                let output = Tensor::cat(&inputs, axis)?;
                values.insert(node.output[0].clone(), output);
            }
            "Abs" => {
                let input = get(&node.input[0])?;
                let output = input.abs()?;
                values.insert(node.output[0].clone(), output);
            }
            "Cos" => {
                let input = get(&node.input[0])?;
                let output = input.cos()?;
                values.insert(node.output[0].clone(), output);
            }
            "Sin" => {
                let input = get(&node.input[0])?;
                let output = input.sin()?;
                values.insert(node.output[0].clone(), output);
            }
            "Neg" => {
                let input = get(&node.input[0])?;
                let output = input.neg()?;
                values.insert(node.output[0].clone(), output);
            }
            "Erf" => {
                let input = get(&node.input[0])?;
                let output = input.erf()?;
                values.insert(node.output[0].clone(), output);
            }
            "Tanh" => {
                let input = get(&node.input[0])?;
                let output = input.tanh()?;
                values.insert(node.output[0].clone(), output);
            }
            "Sigmoid" => {
                let input = get(&node.input[0])?;
                let output = candle_nn::ops::sigmoid(input)?;
                values.insert(node.output[0].clone(), output);
            }
            "Gelu" => {
                let input = get(&node.input[0])?;
                let output = input.gelu_erf()?;
                values.insert(node.output[0].clone(), output);
            }
            "Relu" => {
                let input = get(&node.input[0])?;
                let output = input.relu()?;
                values.insert(node.output[0].clone(), output);
            }
            // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Constant
            "Constant" => {
                let value = match node.attribute.iter().find(|attr| attr.name == "value") {
                    None => {
                        // TODO: support sparse_value etc.
                        bail!("cannot find 'value' attr in 'Constant' for {}", node.name)
                    }
                    Some(value) => value,
                };
                let output = match value.r#type() {
                    AttributeType::Tensor => {
                        let t = value.t.as_ref().unwrap();
                        get_tensor(t, &node.name)?
                    }
                    rtype => bail!("unsupported 'value' type {rtype:?} for {}", node.name),
                };
                values.insert(node.output[0].clone(), output);
            }
            // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast
            "Cast" => {
                let input = get(&node.input[0])?;
                let dt: i64 = *get_attr(node, "to")?;
                let dtype = match DataType::try_from(dt as i32) {
                    Ok(DataType::Int32) => DType::I64,
                    Ok(dt) => match dtype(dt) {
                        Some(dt) => dt,
                        None => {
                            bail!("unsupported 'to' value {dt:?} for cast {}", node.name)
                        }
                    },
                    Err(_) => {
                        bail!("unsupported 'to' value {dt:?} for cast {}", node.name)
                    }
                };
                let output = input.to_dtype(dtype)?;
                values.insert(node.output[0].clone(), output);
            }
            op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
        }
    }
    graph
        .output
        .iter()
        .map(|output| match values.remove(&output.name) {
            None => bail!("cannot find output {}", output.name),
            Some(value) => Ok((output.name.clone(), value)),
        })
        .collect()
}
@ -1,14 +0,0 @@
use candle::Result;
use prost::Message;

pub mod onnx {
    include!(concat!(env!("OUT_DIR"), "/onnx.rs"));
}

pub mod eval;
pub use eval::{dtype, simple_eval};

pub fn read_file<P: AsRef<std::path::Path>>(p: P) -> Result<onnx::ModelProto> {
    let buf = std::fs::read(p)?;
    onnx::ModelProto::decode(buf.as_slice()).map_err(candle::Error::wrap)
}
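Putting the two public entry points together: a hypothetical driver (the model path and the `data` input name are made up) that parses a model with `read_file` and runs it through `simple_eval`, which takes and returns tensors keyed by the graph's input and output names. This applies to the branch that still ships `candle-onnx`.

```rust
use candle::{DType, Device, Result, Tensor};
use std::collections::HashMap;

fn main() -> Result<()> {
    let model = candle_onnx::read_file("model.onnx")?;
    // Input names and shapes come from the graph; these are placeholders.
    let input = Tensor::zeros((1, 3, 224, 224), DType::F32, &Device::Cpu)?;
    let inputs = HashMap::from([("data".to_string(), input)]);
    let outputs = candle_onnx::simple_eval(&model, inputs)?;
    for (name, value) in outputs.iter() {
        println!("{name}: {value:?}");
    }
    Ok(())
}
```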
@ -1,836 +0,0 @@
//
// WARNING: This file is automatically generated! Please edit onnx.in.proto.
//

// SPDX-License-Identifier: Apache-2.0

syntax = "proto3";

package onnx;

// Overview
//
// ONNX is an open specification that is comprised of the following components:
//
// 1) A definition of an extensible computation graph model.
// 2) Definitions of standard data types.
// 3) Definitions of built-in operators.
//
// This document describes the syntax of models and their computation graphs,
// as well as the standard data types. Together, they are referred to as the ONNX
// Intermediate Representation, or 'IR' for short.
//
// The normative semantic specification of the ONNX IR is found in docs/IR.md.
// Definitions of the built-in neural network operators may be found in docs/Operators.md.

// Notes
//
// Protobuf compatibility
//
// To simplify framework compatibility, ONNX is defined using the subset of protobuf
// that is compatible with both protobuf v2 and v3. This means that we do not use any
// protobuf features that are only available in one of the two versions.
//
// Here are the most notable contortions we have to carry out to work around
// these limitations:
//
//   - No 'map' (added protobuf 3.0). We instead represent mappings as lists
//     of key-value pairs, where order does not matter and duplicates
//     are not allowed.

// Versioning
//
// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md
//
// To be compatible with both proto2 and proto3, we will use a version number
// that is not defined by the default value but an explicit enum number.
enum Version {
  // proto3 requires the first enum value to be zero.
  // We add this just to appease the compiler.
  _START_VERSION = 0;
  // The version field is always serialized and we will use it to store the
  // version that the graph is generated from. This helps us set up version
  // control.
  // For the IR, we are using simple numbers starting with 0x00000001,
  // which was the version we published on Oct 10, 2017.
  IR_VERSION_2017_10_10 = 0x0000000000000001;

  // IR_VERSION 2 published on Oct 30, 2017
  // - Added type discriminator to AttributeProto to support proto3 users
  IR_VERSION_2017_10_30 = 0x0000000000000002;

  // IR VERSION 3 published on Nov 3, 2017
  // - For operator versioning:
  //    - Added new message OperatorSetIdProto
  //    - Added opset_import in ModelProto
  // - For vendor extensions, added domain in NodeProto
  IR_VERSION_2017_11_3 = 0x0000000000000003;

  // IR VERSION 4 published on Jan 22, 2019
  // - Relax constraint that initializers should be a subset of graph inputs
  // - Add type BFLOAT16
  IR_VERSION_2019_1_22 = 0x0000000000000004;

  // IR VERSION 5 published on March 18, 2019
  // - Add message TensorAnnotation.
  // - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters.
  IR_VERSION_2019_3_18 = 0x0000000000000005;

  // IR VERSION 6 published on Sep 19, 2019
  // - Add support for sparse tensor constants stored in model.
  //   - Add message SparseTensorProto
  //   - Add sparse initializers
  IR_VERSION_2019_9_19 = 0x0000000000000006;

  // IR VERSION 7 published on May 8, 2020
  // - Add support to allow function body graph to rely on multiple external operator sets.
  // - Add a list to promote inference graph's initializers to global and
  //   mutable variables. Global variables are visible in all graphs of the
  //   stored models.
  // - Add message TrainingInfoProto to store initialization
  //   method and training algorithm. The execution of TrainingInfoProto
  //   can modify the values of mutable variables.
  // - Implicitly add inference graph into each TrainingInfoProto's algorithm.
  IR_VERSION_2020_5_8 = 0x0000000000000007;

  // IR VERSION 8 published on July 30, 2021
  // Introduce TypeProto.SparseTensor
  // Introduce TypeProto.Optional
  // Added a list of FunctionProtos local to the model
  // Deprecated since_version and operator status from FunctionProto
  IR_VERSION_2021_7_30 = 0x0000000000000008;

  // IR VERSION 9 published on May 5, 2023
  // Added AttributeProto to FunctionProto so that default attribute values can be set.
  // Added FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ.
  IR_VERSION = 0x0000000000000009;
}

// Attributes
//
// A named attribute containing either singular float, integer, string, graph,
// and tensor values, or repeated float, integer, string, graph, and tensor values.
// An AttributeProto MUST contain the name field, and *only one* of the
// following content fields, effectively enforcing a C/C++ union equivalent.
message AttributeProto {
  reserved 12, 16 to 19;
  reserved "v";

  // Note: this enum is structurally identical to the OpSchema::AttrType
  // enum defined in schema.h. If you rev one, you likely need to rev the other.
  enum AttributeType {
    UNDEFINED = 0;
    FLOAT = 1;
    INT = 2;
    STRING = 3;
    TENSOR = 4;
    GRAPH = 5;
    SPARSE_TENSOR = 11;
    TYPE_PROTO = 13;

    FLOATS = 6;
    INTS = 7;
    STRINGS = 8;
    TENSORS = 9;
    GRAPHS = 10;
    SPARSE_TENSORS = 12;
    TYPE_PROTOS = 14;
  }

  // The name field MUST be present for this version of the IR.
  string name = 1;  // namespace Attribute

  // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function.
  // In this case, this AttributeProto does not contain data, and it's a reference of attribute
  // in parent scope.
  // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph.
  string ref_attr_name = 21;

  // A human-readable documentation for this attribute. Markdown is allowed.
  string doc_string = 13;

  // The type field MUST be present for this version of the IR.
  // For 0.0.1 versions of the IR, this field was not defined, and
  // implementations needed to use has_field heuristics to determine
  // which value field was in use. For IR_VERSION 0.0.2 or later, this
  // field MUST be set and match the f|i|s|t|... field in use. This
  // change was made to accommodate proto3 implementations.
  AttributeType type = 20;  // discriminator that indicates which field below is in use

  // Exactly ONE of the following fields must be present for this version of the IR
  float f = 2;                // float
  int64 i = 3;                // int
  bytes s = 4;                // UTF-8 string
  TensorProto t = 5;          // tensor value
  GraphProto g = 6;           // graph
  SparseTensorProto sparse_tensor = 22;  // sparse tensor value
  // Do not use field below, it's deprecated.
  // optional ValueProto v = 12;         // value - subsumes everything but graph
  TypeProto tp = 14;          // type proto

  repeated float floats = 7;          // list of floats
  repeated int64 ints = 8;            // list of ints
  repeated bytes strings = 9;         // list of UTF-8 strings
  repeated TensorProto tensors = 10;  // list of tensors
  repeated GraphProto graphs = 11;    // list of graph
  repeated SparseTensorProto sparse_tensors = 23;  // list of sparse tensors
  repeated TypeProto type_protos = 15;  // list of type protos
}
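The `type` field above is a hard requirement: readers must check the discriminator before trusting any of the value fields. From the prost-generated Rust side (assuming prost's usual module layout for nested enums), that contract looks roughly as follows; `get_int_attr` is a hypothetical helper, while the field names (`name`, `type`, `i`) follow the message definition above:

use candle_onnx::onnx::attribute_proto::AttributeType;
use candle_onnx::onnx::NodeProto;

// Hypothetical helper: read an integer attribute by name, trusting the
// `i` value field only when the mandatory discriminator says INT is in use.
fn get_int_attr(node: &NodeProto, name: &str) -> Option<i64> {
    node.attribute
        .iter()
        .find(|a| a.name == name)
        .filter(|a| a.r#type == AttributeType::Int as i32)
        .map(|a| a.i)
}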
// Defines information on value, including the name, the type, and
// the shape of the value.
message ValueInfoProto {
  // This field MUST be present in this version of the IR.
  string name = 1;  // namespace Value
  // This field MUST be present in this version of the IR for
  // inputs and outputs of the top-level graph.
  TypeProto type = 2;
  // A human-readable documentation for this value. Markdown is allowed.
  string doc_string = 3;
}

// Nodes
//
// Computation graphs are made up of a DAG of nodes, which represent what is
// commonly called a "layer" or "pipeline stage" in machine learning frameworks.
//
// For example, it can be a node of type "Conv" that takes in an image, a filter
// tensor and a bias tensor, and produces the convolved output.
message NodeProto {
  repeated string input = 1;   // namespace Value
  repeated string output = 2;  // namespace Value

  // An optional identifier for this node in a graph.
  // This field MAY be absent in this version of the IR.
  string name = 3;  // namespace Node

  // The symbolic identifier of the Operator to execute.
  string op_type = 4;  // namespace Operator
  // The domain of the OperatorSet that specifies the operator named by op_type.
  string domain = 7;  // namespace Domain

  // Additional named attributes.
  repeated AttributeProto attribute = 5;

  // A human-readable documentation for this node. Markdown is allowed.
  string doc_string = 6;
}

// Training information
// TrainingInfoProto stores information for training a model.
// In particular, this defines two functionalities: an initialization-step
// and a training-algorithm-step. Initialization resets the model
// back to its original state as if no training has been performed.
// Training algorithm improves the model based on input data.
//
// The semantics of the initialization-step is that the initializers
// in ModelProto.graph and in TrainingInfoProto.algorithm are first
// initialized as specified by the initializers in the graph, and then
// updated by the "initialization_binding" in every instance in
// ModelProto.training_info.
//
// The field "algorithm" defines a computation graph which represents a
// training algorithm's step. After the execution of a
// TrainingInfoProto.algorithm, the initializers specified by "update_binding"
// may be immediately updated. If the targeted training algorithm contains
// consecutive update steps (such as block coordinate descent methods),
// the user needs to create a TrainingInfoProto for each step.
message TrainingInfoProto {
  // This field describes a graph to compute the initial tensors
  // upon starting the training process. Initialization graph has no input
  // and can have multiple outputs. Usually, trainable tensors in neural
  // networks are randomly initialized. To achieve that, for each tensor,
  // the user can put a random number operator such as RandomNormal or
  // RandomUniform in TrainingInfoProto.initialization.node and assign its
  // random output to the specific tensor using "initialization_binding".
  // This graph can also set the initializers in "algorithm" in the same
  // TrainingInfoProto; a use case is resetting the number of training
  // iteration to zero.
  //
  // By default, this field is an empty graph and its evaluation does not
  // produce any output. Thus, no initializer would be changed by default.
  GraphProto initialization = 1;

  // This field represents a training algorithm step. Given required inputs,
  // it computes outputs to update initializers in its own or inference graph's
  // initializer lists. In general, this field contains loss node, gradient node,
  // optimizer node, increment of iteration count.
  //
  // An execution of the training algorithm step is performed by executing the
  // graph obtained by combining the inference graph (namely "ModelProto.graph")
  // and the "algorithm" graph. That is, the actual
  // input/initializer/output/node/value_info/sparse_initializer list of
  // the training graph is the concatenation of
  // "ModelProto.graph.input/initializer/output/node/value_info/sparse_initializer"
  // and "algorithm.input/initializer/output/node/value_info/sparse_initializer"
  // in that order. This combined graph must satisfy the normal ONNX conditions.
  // Now, let's provide a visualization of graph combination for clarity.
  // Let the inference graph (i.e., "ModelProto.graph") be
  //    tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d
  // and the "algorithm" graph be
  //    tensor_d -> Add -> tensor_e
  // The combination process results in
  //    tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d -> Add -> tensor_e
  //
  // Notice that an input of a node in the "algorithm" graph may reference the
  // output of a node in the inference graph (but not the other way round). Also, inference
  // node cannot reference inputs of "algorithm". With these restrictions, inference graph
  // can always be run independently without training information.
  //
  // By default, this field is an empty graph and its evaluation does not
  // produce any output. Evaluating the default training step never
  // updates any initializers.
  GraphProto algorithm = 2;

  // This field specifies the bindings from the outputs of "initialization" to
  // some initializers in "ModelProto.graph.initializer" and
  // the "algorithm.initializer" in the same TrainingInfoProto.
  // See "update_binding" below for details.
  //
  // By default, this field is empty and no initializer would be changed
  // by the execution of "initialization".
  repeated StringStringEntryProto initialization_binding = 3;

  // Gradient-based training is usually an iterative procedure. In one gradient
  // descent iteration, we apply
  //
  //   x = x - r * g
  //
  // where "x" is the optimized tensor, "r" stands for learning rate, and "g" is
  // gradient of "x" with respect to a chosen loss. To avoid adding assignments
  // into the training graph, we split the update equation into
  //
  //   y = x - r * g
  //   x = y
  //
  // The user needs to save "y = x - r * g" into TrainingInfoProto.algorithm. To
  // tell that "y" should be assigned to "x", the field "update_binding" may
  // contain a key-value pair of strings, "x" (key of StringStringEntryProto)
  // and "y" (value of StringStringEntryProto).
  // For a neural network with multiple trainable (mutable) tensors, there can
  // be multiple key-value pairs in "update_binding".
  //
  // The initializers appearing as keys in "update_binding" are considered
  // mutable variables. This implies some behaviors
  // as described below.
  //
  //  1. We have only unique keys in all "update_binding"s so that two
  //     variables may not have the same name. This ensures that one
  //     variable is assigned up to once.
  //  2. The keys must appear in names of "ModelProto.graph.initializer" or
  //     "TrainingInfoProto.algorithm.initializer".
  //  3. The values must be output names of "algorithm" or "ModelProto.graph.output".
  //  4. Mutable variables are initialized to the value specified by the
  //     corresponding initializer, and then potentially updated by
  //     "initializer_binding"s and "update_binding"s in "TrainingInfoProto"s.
  //
  // This field usually contains names of trainable tensors
  // (in ModelProto.graph), optimizer states such as momentums in advanced
  // stochastic gradient methods (in TrainingInfoProto.graph),
  // and number of training iterations (in TrainingInfoProto.graph).
  //
  // By default, this field is empty and no initializer would be changed
  // by the execution of "algorithm".
  repeated StringStringEntryProto update_binding = 4;
}

// Models
//
// ModelProto is a top-level file/container format for bundling a ML model and
// associating its computation graph with metadata.
//
// The semantics of the model are described by the associated GraphProto's.
message ModelProto {
  // The version of the IR this model targets. See Version enum above.
  // This field MUST be present.
  int64 ir_version = 1;

  // The OperatorSets this model relies on.
  // All ModelProtos MUST have at least one entry that
  // specifies which version of the ONNX OperatorSet is
  // being imported.
  //
  // All nodes in the ModelProto's graph will bind against the operator
  // with the same-domain/same-op_type operator with the HIGHEST version
  // in the referenced operator sets.
  repeated OperatorSetIdProto opset_import = 8;

  // The name of the framework or tool used to generate this model.
  // This field SHOULD be present to indicate which implementation/tool/framework
  // emitted the model.
  string producer_name = 2;

  // The version of the framework or tool used to generate this model.
  // This field SHOULD be present to indicate which implementation/tool/framework
  // emitted the model.
  string producer_version = 3;

  // Domain name of the model.
  // We use reverse domain names as name space indicators. For example:
  // `com.facebook.fair` or `com.microsoft.cognitiveservices`
  //
  // Together with `model_version` and GraphProto.name, this forms the unique identity of
  // the graph.
  string domain = 4;

  // The version of the graph encoded. See Version enum below.
  int64 model_version = 5;

  // A human-readable documentation for this model. Markdown is allowed.
  string doc_string = 6;

  // The parameterized graph that is evaluated to execute the model.
  GraphProto graph = 7;

  // Named metadata values; keys should be distinct.
  repeated StringStringEntryProto metadata_props = 14;

  // Training-specific information. Sequentially executing all stored
  // `TrainingInfoProto.algorithm`s and assigning their outputs following
  // the corresponding `TrainingInfoProto.update_binding`s is one training
  // iteration. Similarly, to initialize the model
  // (as if training hasn't happened), the user should sequentially execute
  // all stored `TrainingInfoProto.initialization`s and assign their outputs
  // using `TrainingInfoProto.initialization_binding`s.
  //
  // If this field is empty, the training behavior of the model is undefined.
  repeated TrainingInfoProto training_info = 20;

  // A list of function protos local to the model.
  //
  // The name of a function, "FunctionProto.name", should be unique within the domain "FunctionProto.domain".
  // In case of any conflicts the behavior (whether the model local functions are given higher priority,
  // or standard operator sets are given higher priority, or this is treated as an error) is defined by
  // the runtimes.
  //
  // The operator sets imported by FunctionProto should be compatible with the ones
  // imported by ModelProto and other model local FunctionProtos.
  // For example, if the same operator set, say 'A', is imported by a FunctionProto and ModelProto
  // or by 2 FunctionProtos, then the versions for the operator set may be different, but
  // the operator schema returned for an op_type, domain, version combination
  // for both the versions should be the same for every node in the function body.
  //
  // One FunctionProto can reference other FunctionProto in the model, however, recursive reference
  // is not allowed.
  repeated FunctionProto functions = 25;
};

// StringStringEntryProto follows the pattern for cross-proto-version maps.
// See https://developers.google.com/protocol-buffers/docs/proto3#maps
message StringStringEntryProto {
  string key = 1;
  string value = 2;
};

message TensorAnnotation {
  string tensor_name = 1;
  // <key, value> pairs to annotate tensor specified by <tensor_name> above.
  // The keys used in the mapping below must be pre-defined in ONNX spec.
  // For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as
  // quantization parameter keys.
  repeated StringStringEntryProto quant_parameter_tensor_names = 2;
}

// Graphs
//
// A graph defines the computational logic of a model and is comprised of a parameterized
// list of nodes that form a directed acyclic graph based on their inputs and outputs.
// This is the equivalent of the "network" or "graph" in many deep learning
// frameworks.
message GraphProto {
  // The nodes in the graph, sorted topologically.
  repeated NodeProto node = 1;

  // The name of the graph.
  string name = 2;  // namespace Graph

  // A list of named tensor values, used to specify constant inputs of the graph.
  // Each initializer (both TensorProto as well as SparseTensorProto) MUST have a name.
  // The name MUST be unique across both initializer and sparse_initializer,
  // but the name MAY also appear in the input list.
  repeated TensorProto initializer = 5;

  // Initializers (see above) stored in sparse format.
  repeated SparseTensorProto sparse_initializer = 15;

  // A human-readable documentation for this graph. Markdown is allowed.
  string doc_string = 10;

  // The inputs and outputs of the graph.
  repeated ValueInfoProto input = 11;
  repeated ValueInfoProto output = 12;

  // Information for the values in the graph. The ValueInfoProto.name's
  // must be distinct. It is optional for a value to appear in value_info list.
  repeated ValueInfoProto value_info = 13;

  // This field carries information to indicate the mapping among a tensor and its
  // quantization parameter tensors. For example:
  // For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated,
  // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model.
  repeated TensorAnnotation quantization_annotation = 14;

  reserved 3, 4, 6 to 9;
  reserved "ir_version", "producer_version", "producer_tag", "domain";
}
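Because prost maps these fields one-to-one onto a generated struct, inspecting a decoded model is plain field access. A small sketch, assuming the generated `candle_onnx::onnx` module from earlier in this diff; `describe` is a hypothetical helper:

use candle_onnx::onnx::ModelProto;

// Print the identity/metadata fields defined by ModelProto above.
fn describe(model: &ModelProto) {
    println!("ir_version: {}", model.ir_version);
    println!("producer: {} {}", model.producer_name, model.producer_version);
    for opset in model.opset_import.iter() {
        // An empty domain means the default ONNX operator set.
        println!("opset: '{}' v{}", opset.domain, opset.version);
    }
    if let Some(graph) = model.graph.as_ref() {
        println!("graph '{}' with {} nodes", graph.name, graph.node.len());
    }
}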
// Tensors
//
// A serialized tensor value.
message TensorProto {
  enum DataType {
    UNDEFINED = 0;
    // Basic types.
    FLOAT = 1;   // float
    UINT8 = 2;   // uint8_t
    INT8 = 3;    // int8_t
    UINT16 = 4;  // uint16_t
    INT16 = 5;   // int16_t
    INT32 = 6;   // int32_t
    INT64 = 7;   // int64_t
    STRING = 8;  // string
    BOOL = 9;    // bool

    // IEEE754 half-precision floating-point format (16 bits wide).
    // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
    FLOAT16 = 10;

    DOUBLE = 11;
    UINT32 = 12;
    UINT64 = 13;
    COMPLEX64 = 14;   // complex with float32 real and imaginary components
    COMPLEX128 = 15;  // complex with float64 real and imaginary components

    // Non-IEEE floating-point format based on IEEE754 single-precision
    // floating-point number truncated to 16 bits.
    // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
    BFLOAT16 = 16;

    // Non-IEEE floating-point format based on papers
    // FP8 Formats for Deep Learning, https://arxiv.org/abs/2209.05433,
    // 8-bit Numerical Formats For Deep Neural Networks, https://arxiv.org/pdf/2206.02915.pdf.
    // Operators supporting FP8 are Cast, CastLike, QuantizeLinear, DequantizeLinear.
    // The computation usually happens inside a block quantize / dequantize
    // fused by the runtime.
    FLOAT8E4M3FN = 17;    // float 8, mostly used for coefficients, supports nan, not inf
    FLOAT8E4M3FNUZ = 18;  // float 8, mostly used for coefficients, supports nan, not inf, no negative zero
    FLOAT8E5M2 = 19;      // follows IEEE 754, supports nan, inf, mostly used for gradients
    FLOAT8E5M2FNUZ = 20;  // follows IEEE 754, supports nan, inf, mostly used for gradients, no negative zero

    // Future extensions go here.
  }

  // The shape of the tensor.
  repeated int64 dims = 1;

  // The data type of the tensor.
  // This field MUST have a valid TensorProto.DataType value
  int32 data_type = 2;

  // For very large tensors, we may want to store them in chunks, in which
  // case the following fields will specify the segment that is stored in
  // the current TensorProto.
  message Segment {
    int64 begin = 1;
    int64 end = 2;
  }
  Segment segment = 3;

  // Tensor content must be organized in row-major order.
  //
  // Depending on the data_type field, exactly one of the fields below with
  // name ending in _data is used to store the elements of the tensor.

  // For float and complex64 values
  // Complex64 tensors are encoded as a single array of floats,
  // with the real components appearing in odd numbered positions,
  // and the corresponding imaginary component appearing in the
  // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
  // is encoded as [1.0, 2.0, 3.0, 4.0])
  // When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
  repeated float float_data = 4 [packed = true];

  // For int32, uint8, int8, uint16, int16, bool, float8, and float16 values
  // float16 and float8 values must be bit-wise converted to an uint16_t prior
  // to writing to the buffer.
  // When this field is present, the data_type field MUST be
  // INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ
  repeated int32 int32_data = 5 [packed = true];

  // For strings.
  // Each element of string_data is a UTF-8 encoded Unicode
  // string. No trailing null, no leading BOM. The protobuf "string"
  // scalar type is not used to match ML community conventions.
  // When this field is present, the data_type field MUST be STRING
  repeated bytes string_data = 6;

  // For int64.
  // When this field is present, the data_type field MUST be INT64
  repeated int64 int64_data = 7 [packed = true];

  // Optionally, a name for the tensor.
  string name = 8;  // namespace Value

  // A human-readable documentation for this tensor. Markdown is allowed.
  string doc_string = 12;

  // Serializations can either use one of the fields above, or use this
  // raw bytes field. The only exception is the string case, where one is
  // required to store the content in the repeated bytes string_data field.
  //
  // When this raw_data field is used to store tensor value, elements MUST
  // be stored in fixed-width, little-endian order.
  // Floating-point data types MUST be stored in IEEE 754 format.
  // Complex64 elements must be written as two consecutive FLOAT values, real component first.
  // Complex128 elements must be written as two consecutive DOUBLE values, real component first.
  // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
  //
  // Note: the advantage of specific field rather than the raw_data field is
  // that in some cases (e.g. int data), protobuf does a better packing via
  // variable length storage, and may lead to smaller binary footprint.
  // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED
  bytes raw_data = 9;

  // Data can be stored inside the protobuf file using type-specific fields or raw_data.
  // Alternatively, raw bytes data can be stored in an external file, using the external_data field.
  // external_data stores key-value pairs describing data location. Recognized keys are:
  // - "location" (required) - POSIX filesystem path relative to the directory where the ONNX
  //                           protobuf model was stored
  // - "offset" (optional) - position of byte at which stored data begins. Integer stored as string.
  //                         Offset values SHOULD be multiples of 4096 (page size) to enable mmap support.
  // - "length" (optional) - number of bytes containing data. Integer stored as string.
  // - "checksum" (optional) - SHA1 digest of the file specified under the 'location' key.
  repeated StringStringEntryProto external_data = 13;

  // Location of the data for this tensor. MUST be one of:
  // - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field.
  // - EXTERNAL - data stored in an external location as described by external_data field.
  enum DataLocation {
    DEFAULT = 0;
    EXTERNAL = 1;
  }

  // If value not set, data is stored in raw_data (if set) otherwise in type-specified field.
  DataLocation data_location = 14;

  // For double
  // Complex128 tensors are encoded as a single array of doubles,
  // with the real components appearing in odd numbered positions,
  // and the corresponding imaginary component appearing in the
  // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
  // is encoded as [1.0, 2.0, 3.0, 4.0])
  // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128
  repeated double double_data = 10 [packed = true];

  // For uint64 and uint32 values
  // When this field is present, the data_type field MUST be
  // UINT32 or UINT64
  repeated uint64 uint64_data = 11 [packed = true];
}
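The raw_data rules above (fixed-width, little-endian, row-major) are easy to get wrong in practice. A sketch of decoding a FLOAT payload under those rules, assuming `data_type` has already been checked and the data is stored inline rather than externally:

// Decode a FLOAT TensorProto's raw_data bytes: fixed-width,
// little-endian f32 values laid out in row-major order.
fn decode_f32_raw(raw_data: &[u8]) -> Vec<f32> {
    raw_data
        .chunks_exact(4)
        .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
        .collect()
}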
// A serialized sparse-tensor value
message SparseTensorProto {
  // The sequence of non-default values is encoded as a tensor of shape [NNZ].
  // The default-value is zero for numeric tensors, and empty-string for string tensors.
  // values must have a non-empty name present which serves as a name for SparseTensorProto
  // when used in sparse_initializer list.
  TensorProto values = 1;

  // The indices of the non-default values, which may be stored in one of two formats.
  // (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value
  //     corresponding to the j-th index of the i-th value (in the values tensor).
  // (b) Indices can be a tensor of shape [NNZ], in which case the i-th value
  //     must be the linearized-index of the i-th value (in the values tensor).
  //     The linearized-index can be converted into an index tuple (k_1,...,k_rank)
  //     using the shape provided below.
  // The indices must appear in ascending order without duplication.
  // In the first format, the ordering is lexicographic-ordering:
  // e.g., index-value [1,4] must appear before [2,1]
  TensorProto indices = 2;

  // The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank]
  repeated int64 dims = 3;
}
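To make format (b) concrete: the linearized index is a row-major offset into the dense shape, so recovering the coordinate tuple is repeated division by the trailing dimensions. `unlinearize` below is a hypothetical helper:

// Convert a linearized index into per-dimension coordinates for the
// dense shape `dims`, e.g. index 5 in shape [2, 3] -> [1, 2].
fn unlinearize(mut index: i64, dims: &[i64]) -> Vec<i64> {
    let mut coords = vec![0; dims.len()];
    for (i, &d) in dims.iter().enumerate().rev() {
        coords[i] = index % d;
        index /= d;
    }
    coords
}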
// Defines a tensor shape. A dimension can be either an integer value
// or a symbolic variable. A symbolic variable represents an unknown
// dimension.
message TensorShapeProto {
  message Dimension {
    oneof value {
      int64 dim_value = 1;
      string dim_param = 2;  // namespace Shape
    };
    // Standard denotation can optionally be used to denote tensor
    // dimensions with standard semantic descriptions to ensure
    // that operations are applied to the correct axis of a tensor.
    // Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition
    // for pre-defined dimension denotations.
    string denotation = 3;
  };
  repeated Dimension dim = 1;
}

// Types
//
// The standard ONNX data types.
message TypeProto {

  message Tensor {
    // This field MUST NOT have the value of UNDEFINED
    // This field MUST have a valid TensorProto.DataType value
    // This field MUST be present for this version of the IR.
    int32 elem_type = 1;
    TensorShapeProto shape = 2;
  }

  // repeated T
  message Sequence {
    // The type and optional shape of each element of the sequence.
    // This field MUST be present for this version of the IR.
    TypeProto elem_type = 1;
  };

  // map<K,V>
  message Map {
    // This field MUST have a valid TensorProto.DataType value
    // This field MUST be present for this version of the IR.
    // This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING
    int32 key_type = 1;
    // This field MUST be present for this version of the IR.
    TypeProto value_type = 2;
  };

  // wrapper for Tensor, Sequence, or Map
  message Optional {
    // The type and optional shape of the element wrapped.
    // This field MUST be present for this version of the IR.
    // Possible values correspond to OptionalProto.DataType enum
    TypeProto elem_type = 1;
  };

  message SparseTensor {
    // This field MUST NOT have the value of UNDEFINED
    // This field MUST have a valid TensorProto.DataType value
    // This field MUST be present for this version of the IR.
    int32 elem_type = 1;
    TensorShapeProto shape = 2;
  }

  oneof value {
    // The type of a tensor.
    Tensor tensor_type = 1;

    // NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values
    //       as input and output to graphs and nodes. These types are needed to naturally
    //       support classical ML operators. DNN operators SHOULD restrict their input
    //       and output types to tensors.

    // The type of a sequence.
    Sequence sequence_type = 4;

    // The type of a map.
    Map map_type = 5;

    // The type of an optional.
    Optional optional_type = 9;

    // Type of the sparse tensor
    SparseTensor sparse_tensor_type = 8;
  }

  // An optional denotation can be used to denote the whole
  // type with a standard semantic description as to what is
  // stored inside. Refer to https://github.com/onnx/onnx/blob/main/docs/TypeDenotation.md#type-denotation-definition
  // for pre-defined type denotations.
  string denotation = 6;
}

// Operator Sets
//
// OperatorSets are uniquely identified by a (domain, opset_version) pair.
message OperatorSetIdProto {
  // The domain of the operator set being identified.
  // The empty string ("") or absence of this field implies the operator
  // set that is defined as part of the ONNX specification.
  // This field MUST be present in this version of the IR when referring to any other operator set.
  string domain = 1;

  // The version of the operator set being identified.
  // This field MUST be present in this version of the IR.
  int64 version = 2;
}

// Operator/function status.
enum OperatorStatus {
  EXPERIMENTAL = 0;
  STABLE = 1;
}

message FunctionProto {
  // The name of the function, similar usage of op_type in OperatorProto.
  // Combined with FunctionProto.domain, this forms the unique identity of
  // the FunctionProto.
  string name = 1;

  // Deprecated since IR Version 8
  // optional int64 since_version = 2;
  reserved 2;
  reserved "since_version";

  // Deprecated since IR Version 8
  // optional OperatorStatus status = 3;
  reserved 3;
  reserved "status";

  // The inputs and outputs of the function.
  repeated string input = 4;
  repeated string output = 5;

  // The attribute parameters of the function.
  // It is for function parameters without default values.
  repeated string attribute = 6;

  // The attribute protos of the function.
  // It is for function attributes with default values.
  // A function attribute shall be represented either as
  // a string attribute or an AttributeProto, not both.
  repeated AttributeProto attribute_proto = 11;

  // The nodes in the function.
  repeated NodeProto node = 7;
  // A human-readable documentation for this function. Markdown is allowed.
  string doc_string = 8;

  // The OperatorSets this function body (graph) relies on.
  //
  // All nodes in the function body (graph) will bind against the operator
  // with the same-domain/same-op_type operator with the HIGHEST version
  // in the referenced operator sets. This means at most one version can be relied
  // on for one domain.
  //
  // The operator sets imported by FunctionProto should be compatible with the ones
  // imported by ModelProto. For example, if the same operator set, say 'A', is imported by a FunctionProto
  // and ModelProto, then the versions for the operator set may be different, but
  // the operator schema returned for an op_type, domain, version combination
  // for both the versions should be the same.

  repeated OperatorSetIdProto opset_import = 9;

  // The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of
  // the FunctionProto.
  string domain = 10;
}

// For using protobuf-lite
option optimize_for = LITE_RUNTIME;
@ -17,7 +17,6 @@ crate-type = ["cdylib"]
 accelerate-src = { workspace = true, optional = true }
 candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
 candle-nn = { path = "../candle-nn", version = "0.3.0" }
-candle-onnx = {path= "../candle-onnx", version = "0.3.0", optional = true}
 half = { workspace = true }
 intel-mkl-src = { workspace = true, optional = true }
 pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] }
@ -30,5 +29,3 @@ default = []
 accelerate = ["dep:accelerate-src", "candle/accelerate"]
 cuda = ["candle/cuda"]
 mkl = ["dep:intel-mkl-src","candle/mkl"]
-onnx = ["dep:candle-onnx"]
@ -1,5 +0,0 @@
# Generated content DO NOT EDIT
from .. import onnx

ONNXModel = onnx.ONNXModel
ONNXTensorDescription = onnx.ONNXTensorDescription
@ -1,89 +0,0 @@
# Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
from candle.typing import _ArrayLike, Device, Scalar, Index, Shape
from candle import Tensor, DType, QTensor

class ONNXModel:
    """
    A wrapper around an ONNX model.
    """

    def __init__(self, path: str):
        pass
    @property
    def doc_string(self) -> str:
        """
        The doc string of the model.
        """
        pass
    @property
    def domain(self) -> str:
        """
        The domain of the operator set of the model.
        """
        pass
    def initializers(self) -> Dict[str, Tensor]:
        """
        Get the weights of the model.
        """
        pass
    @property
    def inputs(self) -> Optional[Dict[str, ONNXTensorDescription]]:
        """
        The inputs of the model.
        """
        pass
    @property
    def ir_version(self) -> int:
        """
        The version of the IR this model targets.
        """
        pass
    @property
    def model_version(self) -> int:
        """
        The version of the model.
        """
        pass
    @property
    def outputs(self) -> Optional[Dict[str, ONNXTensorDescription]]:
        """
        The outputs of the model.
        """
        pass
    @property
    def producer_name(self) -> str:
        """
        The producer of the model.
        """
        pass
    @property
    def producer_version(self) -> str:
        """
        The version of the producer of the model.
        """
        pass
    def run(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """
        Run the model on the given inputs.
        """
        pass

class ONNXTensorDescription:
    """
    A wrapper around an ONNX tensor description.
    """

    @property
    def dtype(self) -> DType:
        """
        The data type of the tensor.
        """
        pass
    @property
    def shape(self) -> Tuple[Union[int, str, Any]]:
        """
        The shape of the tensor.
        """
        pass
@ -19,14 +19,12 @@ extern crate accelerate_src;
 use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};

-mod utils;
-use utils::wrap_err;
-
 mod shape;
 use shape::{PyShape, PyShapeWithHole};

-#[cfg(feature = "onnx")]
-mod onnx;
+pub fn wrap_err(err: ::candle::Error) -> PyErr {
+    PyErr::new::<PyValueError, _>(format!("{err:?}"))
+}

 #[derive(Clone, Debug)]
 #[pyclass(name = "Tensor")]
@ -71,13 +69,11 @@ impl PyDType {
 }

 static CUDA_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);
-static METAL_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);

 #[derive(Clone, Copy, Debug, PartialEq, Eq)]
 enum PyDevice {
     Cpu,
     Cuda,
-    Metal,
 }

 impl PyDevice {
@ -85,7 +81,7 @@ impl PyDevice {
         match device {
             Device::Cpu => Self::Cpu,
             Device::Cuda(_) => Self::Cuda,
-            Device::Metal(_) => Self::Metal,
+            Device::Metal(_) => unimplemented!(),
         }
     }
@ -101,15 +97,6 @@ impl PyDevice {
                 *device = Some(d.clone());
                 Ok(d)
             }
-            Self::Metal => {
-                let mut device = METAL_DEVICE.lock().unwrap();
-                if let Some(device) = device.as_ref() {
-                    return Ok(device.clone());
-                };
-                let d = Device::new_metal(0).map_err(wrap_err)?;
-                *device = Some(d.clone());
-                Ok(d)
-            }
         }
     }
 }
@ -131,7 +118,6 @@ impl ToPyObject for PyDevice {
         let str = match self {
             PyDevice::Cpu => "cpu",
             PyDevice::Cuda => "cuda",
-            PyDevice::Metal => "metal",
         };
         str.to_object(py)
     }
@ -1574,14 +1560,6 @@ fn candle_functional_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
     Ok(())
 }

-#[cfg(feature = "onnx")]
-fn candle_onnx_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
-    use onnx::{PyONNXModel, PyONNXTensorDescriptor};
-    m.add_class::<PyONNXModel>()?;
-    m.add_class::<PyONNXTensorDescriptor>()?;
-    Ok(())
-}
-
 #[pymodule]
 fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
     let utils = PyModule::new(py, "utils")?;
@ -1590,12 +1568,6 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
     let nn = PyModule::new(py, "functional")?;
     candle_functional_m(py, nn)?;
     m.add_submodule(nn)?;
-    #[cfg(feature = "onnx")]
-    {
-        let onnx = PyModule::new(py, "onnx")?;
-        candle_onnx_m(py, onnx)?;
-        m.add_submodule(onnx)?;
-    }
     m.add_class::<PyTensor>()?;
     m.add_class::<PyQTensor>()?;
     m.add_class::<PyDType>()?;
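The deleted `Self::Metal` arm above mirrors the surviving CUDA arm: both are instances of a mutex-guarded memoized constructor, so a device is created once and cloned thereafter. Stripped of the candle specifics, the pattern is roughly:

use std::sync::Mutex;

// Memoize an expensive constructor behind a global mutex, as the
// CUDA_DEVICE / METAL_DEVICE statics above do for candle Devices.
static CACHE: Mutex<Option<String>> = Mutex::new(None);

fn get_or_init() -> String {
    let mut slot = CACHE.lock().unwrap();
    if let Some(v) = slot.as_ref() {
        return v.clone();
    }
    let v = String::from("expensive-to-create resource"); // stand-in value
    *slot = Some(v.clone());
    v
}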
@ -1,212 +0,0 @@
use std::collections::HashMap;

use crate::utils::wrap_err;
use crate::{PyDType, PyTensor};
use candle_onnx::eval::{dtype, get_tensor, simple_eval};
use candle_onnx::onnx::tensor_proto::DataType;
use candle_onnx::onnx::tensor_shape_proto::dimension::Value;
use candle_onnx::onnx::type_proto::{Tensor as ONNXTensor, Value as ONNXValue};
use candle_onnx::onnx::{ModelProto, ValueInfoProto};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::{PyList, PyTuple};

#[derive(Clone, Debug)]
#[pyclass(name = "ONNXTensorDescription")]
/// A wrapper around an ONNX tensor description.
pub struct PyONNXTensorDescriptor(ONNXTensor);

#[pymethods]
impl PyONNXTensorDescriptor {
    #[getter]
    /// The data type of the tensor.
    /// &RETURNS&: DType
    fn dtype(&self) -> PyResult<PyDType> {
        match DataType::try_from(self.0.elem_type) {
            Ok(dt) => match dtype(dt) {
                Some(dt) => Ok(PyDType(dt)),
                None => Err(PyValueError::new_err(format!(
                    "unsupported 'value' data-type {dt:?}"
                ))),
            },
            type_ => Err(PyValueError::new_err(format!(
                "unsupported input type {type_:?}"
            ))),
        }
    }

    #[getter]
    /// The shape of the tensor.
    /// &RETURNS&: Tuple[Union[int,str,Any]]
    fn shape(&self, py: Python) -> PyResult<Py<PyTuple>> {
        let shape = PyList::empty(py);
        if let Some(d) = &self.0.shape {
            for dim in d.dim.iter() {
                if let Some(value) = &dim.value {
                    match value {
                        Value::DimValue(v) => shape.append(*v)?,
                        Value::DimParam(s) => shape.append(s.clone())?,
                    };
                } else {
                    return Err(PyValueError::new_err("None value in shape"));
                }
            }
        }
        Ok(shape.to_tuple().into())
    }

    fn __repr__(&self, py: Python) -> String {
        match (self.shape(py), self.dtype()) {
            (Ok(shape), Ok(dtype)) => format!(
                "TensorDescriptor[shape: {:?}, dtype: {:?}]",
                shape.to_string(),
                dtype.__str__()
            ),
            (Err(_), Err(_)) => "TensorDescriptor[shape: unknown, dtype: unknown]".to_string(),
            (Err(_), Ok(dtype)) => format!(
                "TensorDescriptor[shape: unknown, dtype: {:?}]",
                dtype.__str__()
            ),
            (Ok(shape), Err(_)) => format!(
                "TensorDescriptor[shape: {:?}, dtype: unknown]",
                shape.to_string()
            ),
        }
    }

    fn __str__(&self, py: Python) -> String {
        self.__repr__(py)
    }
}

#[derive(Clone, Debug)]
#[pyclass(name = "ONNXModel")]
/// A wrapper around an ONNX model.
pub struct PyONNXModel(ModelProto);

fn extract_tensor_descriptions(
    value_infos: &[ValueInfoProto],
) -> HashMap<String, PyONNXTensorDescriptor> {
    let mut map = HashMap::new();
    for value_info in value_infos.iter() {
        let input_type = match &value_info.r#type {
            Some(input_type) => input_type,
            None => continue,
        };
        let input_type = match &input_type.value {
            Some(input_type) => input_type,
            None => continue,
        };

        let tensor_type: &ONNXTensor = match input_type {
            ONNXValue::TensorType(tt) => tt,
            _ => continue,
        };
        map.insert(
            value_info.name.to_string(),
            PyONNXTensorDescriptor(tensor_type.clone()),
        );
    }
    map
}

#[pymethods]
impl PyONNXModel {
    #[new]
    #[pyo3(text_signature = "(self, path:str)")]
    /// Load an ONNX model from the given path.
    fn new(path: String) -> PyResult<Self> {
        let model: ModelProto = candle_onnx::read_file(path).map_err(wrap_err)?;
        Ok(PyONNXModel(model))
    }

    #[getter]
    /// The version of the IR this model targets.
    /// &RETURNS&: int
    fn ir_version(&self) -> i64 {
        self.0.ir_version
    }

    #[getter]
    /// The producer of the model.
    /// &RETURNS&: str
    fn producer_name(&self) -> String {
        self.0.producer_name.clone()
    }

    #[getter]
    /// The version of the producer of the model.
    /// &RETURNS&: str
    fn producer_version(&self) -> String {
        self.0.producer_version.clone()
    }

    #[getter]
    /// The domain of the operator set of the model.
    /// &RETURNS&: str
    fn domain(&self) -> String {
        self.0.domain.clone()
    }

    #[getter]
    /// The version of the model.
    /// &RETURNS&: int
    fn model_version(&self) -> i64 {
        self.0.model_version
    }

    #[getter]
    /// The doc string of the model.
    /// &RETURNS&: str
    fn doc_string(&self) -> String {
        self.0.doc_string.clone()
    }

    /// Get the weights of the model.
    /// &RETURNS&: Dict[str, Tensor]
    fn initializers(&self) -> PyResult<HashMap<String, PyTensor>> {
        let mut map = HashMap::new();
        if let Some(graph) = self.0.graph.as_ref() {
            for tensor_description in graph.initializer.iter() {
                let tensor = get_tensor(tensor_description, tensor_description.name.as_str())
                    .map_err(wrap_err)?;
                map.insert(tensor_description.name.to_string(), PyTensor(tensor));
            }
        }
        Ok(map)
    }

    #[getter]
    /// The inputs of the model.
    /// &RETURNS&: Optional[Dict[str, ONNXTensorDescription]]
    fn inputs(&self) -> Option<HashMap<String, PyONNXTensorDescriptor>> {
        if let Some(graph) = self.0.graph.as_ref() {
            return Some(extract_tensor_descriptions(&graph.input));
        }
        None
    }

    #[getter]
    /// The outputs of the model.
    /// &RETURNS&: Optional[Dict[str, ONNXTensorDescription]]
    fn outputs(&self) -> Option<HashMap<String, PyONNXTensorDescriptor>> {
        if let Some(graph) = self.0.graph.as_ref() {
            return Some(extract_tensor_descriptions(&graph.output));
        }
        None
    }

    #[pyo3(text_signature = "(self, inputs:Dict[str,Tensor])")]
    /// Run the model on the given inputs.
    /// &RETURNS&: Dict[str,Tensor]
    fn run(&self, inputs: HashMap<String, PyTensor>) -> PyResult<HashMap<String, PyTensor>> {
        let unwrapped_tensors = inputs.into_iter().map(|(k, v)| (k.clone(), v.0)).collect();

        let result = simple_eval(&self.0, unwrapped_tensors).map_err(wrap_err)?;

        Ok(result
            .into_iter()
            .map(|(k, v)| (k.clone(), PyTensor(v)))
            .collect())
    }
}
@ -1,6 +0,0 @@
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;

pub fn wrap_err(err: ::candle::Error) -> PyErr {
    PyErr::new::<PyValueError, _>(format!("{err:?}"))
}
@ -21,7 +21,6 @@ rand = { workspace = true }
 rayon = { workspace = true }
 serde = { workspace = true }
 serde_json = { workspace = true }
-serde_plain = { workspace = true }
 tracing = { workspace = true }
 wav = { workspace = true }

@ -29,5 +28,6 @@ wav = { workspace = true }
 default = []
 accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"]
 cuda = ["candle/cuda", "candle-nn/cuda"]
+metal = ["candle/metal", "candle-nn/metal"]
 flash-attn = ["cuda", "dep:candle-flash-attn"]
 mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"]
@@ -1,4 +1,3 @@
-use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
 use candle::{DType, Device, Result, Tensor};
 use candle_nn::{Embedding, Module, VarBuilder};
 use serde::Deserialize;
@@ -33,6 +32,76 @@ impl HiddenActLayer {
     }
 }
 
+#[derive(Debug)]
+pub struct Linear {
+    weight: Tensor,
+    bias: Option<Tensor>,
+    span: tracing::Span,
+}
+
+impl Linear {
+    pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
+        let span = tracing::span!(tracing::Level::TRACE, "linear");
+        Self { weight, bias, span }
+    }
+}
+
+impl Module for Linear {
+    fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
+        let _enter = self.span.enter();
+        let w = match x.dims() {
+            &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
+            _ => self.weight.t()?,
+        };
+        let x = x.matmul(&w)?;
+        match &self.bias {
+            None => Ok(x),
+            Some(bias) => x.broadcast_add(bias),
+        }
+    }
+}
+
+#[derive(Debug)]
+pub struct LayerNorm {
+    weight: Tensor,
+    bias: Tensor,
+    eps: f64,
+    span: tracing::Span,
+}
+
+impl LayerNorm {
+    pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
+        let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
+        Self {
+            weight,
+            bias,
+            eps,
+            span,
+        }
+    }
+}
+
+impl Module for LayerNorm {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let x_dtype = x.dtype();
+        let internal_dtype = match x_dtype {
+            DType::F16 | DType::BF16 => DType::F32,
+            d => d,
+        };
+        let (_bsize, _seq_len, hidden_size) = x.dims3()?;
+        let x = x.to_dtype(internal_dtype)?;
+        let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
+        let x = x.broadcast_sub(&mean_x)?;
+        let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
+        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
+        let x = x_normed
+            .to_dtype(x_dtype)?
+            .broadcast_mul(&self.weight)?
+            .broadcast_add(&self.bias)?;
+        Ok(x)
+    }
+}
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
 #[serde(rename_all = "lowercase")]
 enum PositionEmbeddingType {
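The added `LayerNorm::forward` above is a direct implementation of layer normalization over the hidden dimension. As a reading aid, a sketch of the formula it computes (γ = `weight`, β = `bias`, ε = `eps`, with both statistics taken over the `hidden_size` axis via `sum_keepdim(2)`, and f16/bf16 inputs upcast to f32 first):

```latex
\mu = \frac{1}{H}\sum_{h=1}^{H} x_h, \qquad
\sigma^2 = \frac{1}{H}\sum_{h=1}^{H} (x_h - \mu)^2, \qquad
y = \frac{x - \mu}{\sqrt{\sigma^2 + \varepsilon}} \odot \gamma + \beta
```

Note that `norm_x` is the biased variance (division by H, not H − 1), matching the usual layer-norm convention.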
@@ -115,6 +184,12 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Em
     Ok(Embedding::new(embeddings, hidden_size))
 }
 
+fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
+    let weight = vb.get((size2, size1), "weight")?;
+    let bias = vb.get(size2, "bias")?;
+    Ok(Linear::new(weight, Some(bias)))
+}
+
 struct Dropout {
     #[allow(dead_code)]
     pr: f64,
@@ -133,6 +208,20 @@ impl Module for Dropout {
     }
 }
 
+fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
+    let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
+        (Ok(weight), Ok(bias)) => (weight, bias),
+        (Err(err), _) | (_, Err(err)) => {
+            if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
+                (weight, bias)
+            } else {
+                return Err(err);
+            }
+        }
+    };
+    Ok(LayerNorm::new(weight, bias, eps))
+}
+
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180
 struct BertEmbeddings {
     word_embeddings: Embedding,
@@ -1,4 +1,3 @@
-use super::with_tracing::{linear_no_bias as linear, Linear};
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{Embedding, Module, VarBuilder};
 use serde::Deserialize;
@@ -82,6 +81,21 @@ impl Config {
     }
 }
 
+// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
+// model.
+#[derive(Debug)]
+pub struct Linear {
+    inner: candle_nn::Linear,
+    span: tracing::Span,
+}
+
+impl Linear {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.inner.forward(x)
+    }
+}
+
 #[derive(Clone)]
 pub struct Cache {
     masks: Arc<Mutex<HashMap<usize, Tensor>>>,
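The comment in the hunk above explains the motivation: wrapping `Linear` in a struct that owns a `tracing::Span` makes each forward pass show up in a profile. A minimal, self-contained sketch of the same pattern, assuming only the `tracing` and `tracing-subscriber` crates the workspace already depends on (the subscriber setup is illustrative and not part of the diff):

```rust
use tracing::{span, Level};

fn main() {
    // Install a subscriber so TRACE-level spans are actually recorded.
    tracing_subscriber::fmt()
        .with_max_level(Level::TRACE)
        .init();

    // Same shape as the wrapper above: create the span once, enter it for
    // the duration of the work being profiled.
    let linear_span = span!(Level::TRACE, "linear");
    let _enter = linear_span.enter();
    // ... the wrapped forward pass would run here, attributed to "linear" ...
}
```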
@@ -136,6 +150,12 @@ impl Cache {
     }
 }
 
+fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
+    let span = tracing::span!(tracing::Level::TRACE, "linear");
+    let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
+    Ok(Linear { inner, span })
+}
+
 fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
     let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
     Ok(Embedding::new(embeddings, cfg.hidden_size))
@@ -1,5 +1,6 @@
-use super::with_tracing::{linear, Embedding, Linear};
-use candle::{Result, Tensor};
+#![allow(unused)]
+use super::with_tracing::{linear, linear_no_bias, Embedding, Linear};
+use candle::{Module, Result, Tensor};
 use candle_nn::{layer_norm, LayerNorm, VarBuilder};
 
 #[derive(Debug, Clone)]
@@ -169,6 +170,7 @@ impl Attention {
         kv_states: Option<&Tensor>,
         attn_mask: Option<&Tensor>,
     ) -> Result<Tensor> {
+        let is_cross_attn = kv_states.is_some();
         let (b_sz, tgt_len, _) = xs.dims3()?;
         let query_states = (xs.apply(&self.q_proj)? * self.scaling)?;
         let (key_states, value_states) = match kv_states {
@@ -257,10 +259,6 @@ impl EncoderLayer {
             .apply(&self.fc2)?;
         (xs + residual)?.apply(&self.final_layer_norm)
     }
-
-    fn reset_kv_cache(&mut self) {
-        self.self_attn.reset_kv_cache()
-    }
 }
 
 #[derive(Debug, Clone)]
@@ -322,11 +320,6 @@ impl DecoderLayer {
         let xs = (xs + residual)?.apply(&self.final_layer_norm)?;
         Ok(xs)
     }
-
-    fn reset_kv_cache(&mut self) {
-        self.self_attn.reset_kv_cache();
-        self.encoder_attn.reset_kv_cache()
-    }
 }
 
 #[derive(Debug, Clone)]
|
|||||||
}
|
}
|
||||||
Ok(xs)
|
Ok(xs)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn reset_kv_cache(&mut self) {
|
|
||||||
for layer in self.layers.iter_mut() {
|
|
||||||
layer.reset_kv_cache()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@ -435,12 +422,6 @@ impl Decoder {
|
|||||||
}
|
}
|
||||||
Ok(xs)
|
Ok(xs)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn reset_kv_cache(&mut self) {
|
|
||||||
for layer in self.layers.iter_mut() {
|
|
||||||
layer.reset_kv_cache()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@@ -461,11 +442,6 @@ impl Model {
             decoder,
         })
     }
-
-    fn reset_kv_cache(&mut self) {
-        self.encoder.reset_kv_cache();
-        self.decoder.reset_kv_cache();
-    }
 }
 
 #[derive(Debug, Clone)]
@@ -513,8 +489,4 @@ impl MTModel {
             .apply(&self.lm_head)?
             .broadcast_add(&self.final_logits_bias)
     }
-
-    pub fn reset_kv_cache(&mut self) {
-        self.model.reset_kv_cache();
-    }
 }
@@ -29,10 +29,8 @@ pub mod segment_anything;
 pub mod stable_diffusion;
 pub mod stable_lm;
 pub mod t5;
-pub mod trocr;
 pub mod vgg;
 pub mod vit;
 pub mod whisper;
 pub mod with_tracing;
 pub mod wuerstchen;
-pub mod yi;
@@ -2,7 +2,7 @@ use std::collections::HashMap;
 
 use candle::quantized::QTensor;
 use candle::quantized::{ggml_file, gguf_file};
-use candle::{DType, Device, IndexOp, Result, Tensor, D};
+use candle::{Device, IndexOp, Result, Tensor, D};
 use candle_nn::{Embedding, Module};
 
 pub const MAX_SEQ_LEN: usize = 4096;
@@ -16,7 +16,7 @@ struct RmsNorm {
 impl RmsNorm {
     fn new(scale: QTensor, eps: f32) -> Result<Self> {
         let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
-        let scale = scale.dequantize(&Device::Cpu)?;
+        let scale = scale.dequantize(scale.device())?;
         let inner = candle_nn::LayerNorm::rms_norm(scale, eps as f64);
         Ok(Self { inner, span })
     }
@@ -79,6 +79,8 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
 impl LayerWeights {
     fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
         let _enter = self.span_rot.enter();
+        let span = tracing::span!(tracing::Level::TRACE, "attn-rot-cos");
+        let _enter = span.enter();
         let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
         let cos = self
             .cos
@@ -88,21 +90,37 @@ impl LayerWeights {
             .sin
             .narrow(0, index_pos, seq_len)?
             .reshape((seq_len, n_embd / 2, 1))?;
+        drop(_enter);
+        let span = tracing::span!(tracing::Level::TRACE, "attn-rot-broad");
+        let _enter = span.enter();
         let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
         let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
+        drop(_enter);
         // This mimics the llama.cpp behavior.
         // https://github.com/ggerganov/llama.cpp/blob/1f0bccb27929e261744c979bc75114955da49e98/ggml.c#L12104-L12105
         // The x0 and x1 value are interleaved on the n_embd (= head_dim) dimension.
        // The resulting y0 and y1 are also interleaved with:
         //   y0 = x0*cos - x1*sin
         //   y1 = x0*sin + x1*cos
+        let span = tracing::span!(tracing::Level::TRACE, "attn-rot-reshape");
+        let _enter = span.enter();
         let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
         let x0 = x.narrow(D::Minus1, 0, 1)?;
         let x1 = x.narrow(D::Minus1, 1, 1)?;
+        drop(_enter);
+        let span = tracing::span!(tracing::Level::TRACE, "attn-rot-broad-mul");
+        let _enter = span.enter();
         let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
        let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
+        drop(_enter);
+        let span = tracing::span!(tracing::Level::TRACE, "attn-rot-cat");
+        let _enter = span.enter();
         let rope = Tensor::cat(&[y0, y1], D::Minus1)?;
+        drop(_enter);
+        let span = tracing::span!(tracing::Level::TRACE, "attn-rot-flatten");
+        let _enter = span.enter();
         let rope = rope.flatten_from(D::Minus2)?;
+        drop(_enter);
         Ok(rope)
     }
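As a reading aid for the rotary-embedding comments in this hunk: for every interleaved pair (x_{2j}, x_{2j+1}) at sequence position m, the code applies a plain 2-D rotation by the angle m·θ_j, which is exactly the `y0 = x0*cos - x1*sin` / `y1 = x0*sin + x1*cos` pair the comment states:

```latex
\begin{pmatrix} y_{2j} \\ y_{2j+1} \end{pmatrix}
=
\begin{pmatrix} \cos m\theta_j & -\sin m\theta_j \\ \sin m\theta_j & \cos m\theta_j \end{pmatrix}
\begin{pmatrix} x_{2j} \\ x_{2j+1} \end{pmatrix}
```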
@@ -112,6 +130,7 @@ impl LayerWeights {
         let q = self.attention_wq.forward(x)?;
         let k = self.attention_wk.forward(x)?;
         let v = self.attention_wv.forward(x)?;
+        // println!("Q {:?} K {:?} V {:?}", q.dtype(), k.dtype(), v.dtype());
 
         let q = q
             .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
@@ -145,9 +164,12 @@ impl LayerWeights {
         let v = self.repeat_kv(v)?;
 
         let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
+        // println!("att {:?}", att.dtype());
         let mask = mask.broadcast_as(att.shape())?;
+        // println!("mask {:?}", mask.dtype());
         let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
         let att = candle_nn::ops::softmax_last_dim(&att)?;
+        // println!("att {:?} v {:?}", att.dtype(), v.dtype());
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = att.matmul(&v.contiguous()?)?;
         let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
@@ -181,28 +203,37 @@ pub struct ModelWeights {
     span_output: tracing::Span,
 }
 
-fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tensor)> {
+fn precomput_freqs_cis(
+    head_dim: usize,
+    freq_base: f32,
+    device: &Device,
+) -> Result<(Tensor, Tensor)> {
     let theta: Vec<_> = (0..head_dim)
         .step_by(2)
         .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
         .collect();
-    let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?;
-    let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)?
-        .to_dtype(DType::F32)?
+    let theta = Tensor::new(theta.as_slice(), device)?;
+    let range: Vec<f32> = (0..MAX_SEQ_LEN).map(|r| r as f32).collect();
+    let idx_theta = Tensor::new(range.as_slice(), device)?
         .reshape((MAX_SEQ_LEN, 1))?
         .matmul(&theta.reshape((1, theta.elem_count()))?)?;
+    // TODO This change avoids allocating on Metal and then casting since allocating directly on
+    // CPU as f32 seems just as fast
+    // let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
+    //     .to_dtype(DType::F32)?
+    //     .reshape((MAX_SEQ_LEN, 1))?
+    //     .matmul(&theta.reshape((1, theta.elem_count()))?)?;
     let cos = idx_theta.cos()?;
     let sin = idx_theta.sin()?;
     Ok((cos, sin))
 }
 
 impl ModelWeights {
-    pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
-        let cpu = &Device::Cpu;
+    pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize, device: &Device) -> Result<Self> {
         let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
-        let (cos, sin) = precomput_freqs_cis(head_dim, 10000.)?;
+        let (cos, sin) = precomput_freqs_cis(head_dim, 10000., device)?;
         let tok_embeddings = ct.remove("tok_embeddings.weight")?;
-        let tok_embeddings = tok_embeddings.dequantize(cpu)?;
+        let tok_embeddings = tok_embeddings.dequantize(device)?;
         let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
         let output = ct.remove("output.weight")?;
         let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
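For reference, the table that the reworked `precomput_freqs_cis` builds (now directly on the target device) is the usual RoPE frequency grid; a sketch of the formula, with d the head dimension and base the `freq_base` argument (10000 by default):

```latex
\theta_j = \mathrm{base}^{-2j/d}, \quad j = 0, \dots, \tfrac{d}{2} - 1; \qquad
\cos[m][j] = \cos(m\,\theta_j), \quad \sin[m][j] = \sin(m\,\theta_j), \quad 0 \le m < \texttt{MAX\_SEQ\_LEN}
```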
@@ -257,8 +288,8 @@ impl ModelWeights {
     pub fn from_gguf<R: std::io::Seek + std::io::Read>(
         ct: gguf_file::Content,
         reader: &mut R,
+        device: &Device,
     ) -> Result<Self> {
-        let cpu = &Device::Cpu;
         let md_get = |s: &str| match ct.metadata.get(s) {
             None => candle::bail!("cannot find {s} in metadata"),
             Some(v) => Ok(v),
@@ -276,24 +307,31 @@ impl ModelWeights {
         let rope_freq_base = md_get("llama.rope.freq_base")
             .and_then(|m| m.to_f32())
             .unwrap_or(10000f32);
-        let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?;
+        let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?;
 
-        let tok_embeddings = ct.tensor(reader, "token_embd.weight")?;
-        let tok_embeddings = tok_embeddings.dequantize(cpu)?;
-        let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?, rms_norm_eps)?;
-        let output = ct.tensor(reader, "output.weight")?;
+        let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
+        let tok_embeddings = tok_embeddings.dequantize(device)?;
+        let norm = RmsNorm::new(
+            ct.tensor(reader, "output_norm.weight", device)?,
+            rms_norm_eps,
+        )?;
+        let output = ct.tensor(reader, "output.weight", device)?;
         let mut layers = Vec::with_capacity(block_count);
         for layer_idx in 0..block_count {
             let prefix = format!("blk.{layer_idx}");
-            let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"))?;
-            let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?;
-            let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?;
-            let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?;
-            let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?;
-            let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?;
-            let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?;
-            let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?;
-            let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?;
+            let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
+            let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
+            let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
+            let attention_wo =
+                ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
+            let feed_forward_w1 =
+                ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
+            let feed_forward_w2 =
+                ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
+            let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
+            let attention_norm =
+                ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?;
+            let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?;
             let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
             let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
             let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
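Taken together, these hunks thread a `device` argument through the whole quantized-llama loading path so the weights land on Metal/CUDA directly instead of always on the CPU. A usage sketch under stated assumptions (the path handling is a placeholder; `Content::read` and the new three-argument `from_gguf` are the entry points shown in this diff):

```rust
use candle::quantized::gguf_file;
use candle::Device;
use candle_transformers::models::quantized_llama::ModelWeights;

fn load_gguf(path: &str, device: &Device) -> candle::Result<ModelWeights> {
    let mut file = std::fs::File::open(path)?;
    // Parse the gguf metadata and tensor index from the file.
    let content = gguf_file::Content::read(&mut file)?;
    // With this change the tensors are created on `device`
    // (e.g. Device::new_metal(0)?) rather than on Device::Cpu.
    ModelWeights::from_gguf(content, &mut file, device)
}
```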
@@ -331,14 +369,14 @@ impl ModelWeights {
         })
     }
 
-    fn mask(&mut self, t: usize) -> Result<Tensor> {
+    fn mask(&mut self, t: usize, device: &Device) -> Result<Tensor> {
         if let Some(mask) = self.masks.get(&t) {
             Ok(mask.clone())
         } else {
             let mask: Vec<_> = (0..t)
                 .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
                 .collect();
-            let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
+            let mask = Tensor::from_slice(&mask, (t, t), device)?;
             self.masks.insert(t, mask.clone());
             Ok(mask)
         }
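The mask expression above is easy to check by hand; a tiny sketch evaluating the same closure for t = 3 (a 1 marks a future position, which `masked_fill` later replaces with negative infinity):

```rust
fn causal_mask(t: usize) -> Vec<u8> {
    // Same expression as in `mask` above: entry (i, j) is 1 iff j > i.
    (0..t)
        .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
        .collect()
}

fn main() {
    // Row-major 3x3:
    // [0, 1, 1]
    // [0, 0, 1]
    // [0, 0, 0]
    assert_eq!(causal_mask(3), vec![0, 1, 1, 0, 0, 1, 0, 0, 0]);
}
```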
@@ -346,7 +384,7 @@ impl ModelWeights {
 
     pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
         let (_b_sz, seq_len) = x.dims2()?;
-        let mask = self.mask(seq_len)?;
+        let mask = self.mask(seq_len, x.device())?;
         let _enter = self.span.enter();
         let mut layer_in = self.tok_embeddings.forward(x)?;
         for layer in self.layers.iter_mut() {
@@ -1,7 +1,6 @@
 // T5 Text Model, quantized version
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
 
-use crate::models::t5::{deserialize_feed_forward_proj_activation, ActivationWithOptionalGating};
 use crate::models::with_tracing::QMatMul;
 use crate::quantized_nn::Embedding;
 pub use crate::quantized_var_builder::VarBuilder;
@@ -55,8 +54,8 @@ pub struct Config {
     dropout_rate: f64,
     layer_norm_epsilon: f64,
     initializer_factor: f64,
-    #[serde(default, deserialize_with = "deserialize_feed_forward_proj_activation")]
-    pub feed_forward_proj: ActivationWithOptionalGating,
+    #[serde(default)]
+    feed_forward_proj: Activation,
     #[serde(default = "default_tie_word_embeddings")]
     tie_word_embeddings: bool,
     #[serde(default = "default_is_decoder")]
@@ -66,7 +65,6 @@ pub struct Config {
     pub use_cache: bool,
     pub pad_token_id: usize,
     pub eos_token_id: usize,
-    pub decoder_start_token_id: Option<usize>,
 }
 
 impl Default for Config {
@@ -84,17 +82,13 @@ impl Default for Config {
             dropout_rate: 0.1,
             layer_norm_epsilon: 1e-6,
             initializer_factor: 1.0,
-            feed_forward_proj: ActivationWithOptionalGating {
-                gated: false,
-                activation: Activation::Relu,
-            },
+            feed_forward_proj: Activation::Relu,
             tie_word_embeddings: true,
             is_decoder: false,
             is_encoder_decoder: true,
             use_cache: true,
             pad_token_id: 0,
             eos_token_id: 1,
-            decoder_start_token_id: Some(0),
         }
     }
 }
@@ -180,7 +174,7 @@ impl T5DenseGatedActDense {
             wi_0,
             wi_1,
             wo,
-            act: cfg.feed_forward_proj.activation,
+            act: Activation::NewGelu,
             span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"),
         })
     }
@@ -209,7 +203,7 @@ impl T5LayerFF {
     fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let layer_norm =
             T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
-        let (dense_act, gated_dense_act) = if cfg.feed_forward_proj.gated {
+        let (dense_act, gated_dense_act) = if cfg.feed_forward_proj == Activation::NewGelu {
             (
                 None,
                 Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?),
@@ -648,12 +642,7 @@ pub struct T5EncoderModel {
 
 impl T5EncoderModel {
     pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
-        let shared_vb = if vb.contains_key("shared.weight") {
-            vb.pp("shared")
-        } else {
-            vb.pp("decoder").pp("embed_tokens")
-        };
-        let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
+        let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
         let shared = Arc::new(shared);
         let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
         Ok(Self {
@@ -694,12 +683,7 @@ impl T5ForConditionalGeneration {
     pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         assert!(cfg.is_encoder_decoder);
         let d_model = cfg.d_model;
-        let shared_vb = if vb.contains_key("shared.weight") {
-            vb.pp("shared")
-        } else {
-            vb.pp("decoder").pp("embed_tokens")
-        };
-        let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
+        let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
         let shared = Arc::new(shared);
 
         let mut encoder_cfg = cfg.clone();
@@ -1,4 +1,3 @@
-pub use crate::models::with_tracing::Linear;
 use candle::{Result, Tensor};
 use candle_nn::{Module, VarBuilder};
 
@@ -10,11 +9,13 @@ pub mod tiny_vit;
 pub mod transformer;
 
 pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
-    if bias {
-        crate::models::with_tracing::linear(in_dim, out_dim, vb)
+    let inner = if bias {
+        candle_nn::linear(in_dim, out_dim, vb)?
     } else {
-        crate::models::with_tracing::linear_no_bias(in_dim, out_dim, vb)
-    }
+        candle_nn::linear_no_bias(in_dim, out_dim, vb)?
+    };
+    let span = tracing::span!(tracing::Level::TRACE, "linear");
+    Ok(Linear { inner, span })
 }
 
 #[derive(Debug)]
@@ -84,3 +85,16 @@ impl Module for MlpBlock {
             .apply(&self.lin2)
     }
 }
+
+#[derive(Debug)]
+pub struct Linear {
+    inner: candle_nn::Linear,
+    span: tracing::Span,
+}
+
+impl Module for Linear {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.inner.forward(x)
+    }
+}
|
@ -102,14 +102,6 @@ impl Config {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn ssd1b() -> Self {
|
|
||||||
Self::sdxl()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn ssd1b2() -> Self {
|
|
||||||
Self::sdxl2()
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://huggingface.co/warp-ai/wuerstchen/blob/main/text_encoder/config.json
|
// https://huggingface.co/warp-ai/wuerstchen/blob/main/text_encoder/config.json
|
||||||
pub fn wuerstchen() -> Self {
|
pub fn wuerstchen() -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
@ -249,71 +249,6 @@ impl StableDiffusionConfig {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn ssd1b(
|
|
||||||
sliced_attention_size: Option<usize>,
|
|
||||||
height: Option<usize>,
|
|
||||||
width: Option<usize>,
|
|
||||||
) -> Self {
|
|
||||||
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
|
|
||||||
out_channels,
|
|
||||||
use_cross_attn,
|
|
||||||
attention_head_dim,
|
|
||||||
};
|
|
||||||
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/unet/config.json
|
|
||||||
let unet = unet_2d::UNet2DConditionModelConfig {
|
|
||||||
blocks: vec![
|
|
||||||
bc(320, None, 5),
|
|
||||||
bc(640, Some(2), 10),
|
|
||||||
bc(1280, Some(10), 20),
|
|
||||||
],
|
|
||||||
center_input_sample: false,
|
|
||||||
cross_attention_dim: 2048,
|
|
||||||
downsample_padding: 1,
|
|
||||||
flip_sin_to_cos: true,
|
|
||||||
freq_shift: 0.,
|
|
||||||
layers_per_block: 2,
|
|
||||||
mid_block_scale_factor: 1.,
|
|
||||||
norm_eps: 1e-5,
|
|
||||||
norm_num_groups: 32,
|
|
||||||
sliced_attention_size,
|
|
||||||
use_linear_projection: true,
|
|
||||||
};
|
|
||||||
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/vae/config.json
|
|
||||||
let autoencoder = vae::AutoEncoderKLConfig {
|
|
||||||
block_out_channels: vec![128, 256, 512, 512],
|
|
||||||
layers_per_block: 2,
|
|
||||||
latent_channels: 4,
|
|
||||||
norm_num_groups: 32,
|
|
||||||
};
|
|
||||||
let scheduler = ddim::DDIMSchedulerConfig {
|
|
||||||
..Default::default()
|
|
||||||
};
|
|
||||||
|
|
||||||
let height = if let Some(height) = height {
|
|
||||||
assert_eq!(height % 8, 0, "height has to be divisible by 8");
|
|
||||||
height
|
|
||||||
} else {
|
|
||||||
1024
|
|
||||||
};
|
|
||||||
|
|
||||||
let width = if let Some(width) = width {
|
|
||||||
assert_eq!(width % 8, 0, "width has to be divisible by 8");
|
|
||||||
width
|
|
||||||
} else {
|
|
||||||
1024
|
|
||||||
};
|
|
||||||
|
|
||||||
Self {
|
|
||||||
width,
|
|
||||||
height,
|
|
||||||
clip: clip::Config::ssd1b(),
|
|
||||||
clip2: Some(clip::Config::ssd1b2()),
|
|
||||||
autoencoder,
|
|
||||||
scheduler,
|
|
||||||
unet,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn build_vae<P: AsRef<std::path::Path>>(
|
pub fn build_vae<P: AsRef<std::path::Path>>(
|
||||||
&self,
|
&self,
|
||||||
vae_weights: P,
|
vae_weights: P,
|
||||||
|
@ -37,37 +37,6 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
|
|||||||
Ok(m)
|
Ok(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, Default, Clone, PartialEq)]
|
|
||||||
pub struct ActivationWithOptionalGating {
|
|
||||||
pub gated: bool,
|
|
||||||
pub activation: candle_nn::Activation,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn deserialize_feed_forward_proj_activation<'de, D>(
|
|
||||||
deserializer: D,
|
|
||||||
) -> std::result::Result<ActivationWithOptionalGating, D::Error>
|
|
||||||
where
|
|
||||||
D: serde::de::Deserializer<'de>,
|
|
||||||
{
|
|
||||||
match String::deserialize(deserializer)?.as_str() {
|
|
||||||
"gated-gelu" => Ok(ActivationWithOptionalGating {
|
|
||||||
gated: true,
|
|
||||||
activation: candle_nn::Activation::NewGelu,
|
|
||||||
}),
|
|
||||||
"gated-silu" => Ok(ActivationWithOptionalGating {
|
|
||||||
gated: true,
|
|
||||||
activation: candle_nn::Activation::Silu,
|
|
||||||
}),
|
|
||||||
buf => {
|
|
||||||
let activation = serde_plain::from_str(buf).map_err(serde::de::Error::custom)?;
|
|
||||||
Ok(ActivationWithOptionalGating {
|
|
||||||
gated: false,
|
|
||||||
activation,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||||
pub struct Config {
|
pub struct Config {
|
||||||
vocab_size: usize,
|
vocab_size: usize,
|
||||||
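The helper removed above decided, from the raw `feed_forward_proj` string in a T5 checkpoint config, whether the feed-forward block is gated and which activation it uses. A reduced sketch of that dispatch, using plain strings instead of the `Activation` enum purely for illustration:

```rust
/// Maps a T5 `feed_forward_proj` config value to (gated, activation name).
fn parse_feed_forward_proj(s: &str) -> (bool, String) {
    match s {
        "gated-gelu" => (true, "new_gelu".to_string()),
        "gated-silu" => (true, "silu".to_string()),
        // Anything else is treated as a plain, non-gated activation name
        // (the removed code parsed it with serde_plain).
        other => (false, other.to_string()),
    }
}
```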
@@ -83,8 +52,8 @@ pub struct Config {
     dropout_rate: f64,
     layer_norm_epsilon: f64,
     initializer_factor: f64,
-    #[serde(default, deserialize_with = "deserialize_feed_forward_proj_activation")]
-    feed_forward_proj: ActivationWithOptionalGating,
+    #[serde(default)]
+    feed_forward_proj: Activation,
     #[serde(default = "default_tie_word_embeddings")]
     tie_word_embeddings: bool,
     #[serde(default = "default_is_decoder")]
@@ -94,7 +63,6 @@ pub struct Config {
     pub use_cache: bool,
     pub pad_token_id: usize,
     pub eos_token_id: usize,
-    pub decoder_start_token_id: Option<usize>,
 }
 
 impl Default for Config {
@@ -112,17 +80,13 @@ impl Default for Config {
             dropout_rate: 0.1,
             layer_norm_epsilon: 1e-6,
             initializer_factor: 1.0,
-            feed_forward_proj: ActivationWithOptionalGating {
-                gated: false,
-                activation: Activation::Relu,
-            },
+            feed_forward_proj: Activation::Relu,
             tie_word_embeddings: true,
             is_decoder: false,
             is_encoder_decoder: true,
             use_cache: true,
             pad_token_id: 0,
             eos_token_id: 1,
-            decoder_start_token_id: Some(0),
         }
     }
 }
@@ -136,10 +100,7 @@ impl Config {
             d_model: 768,
             dropout_rate: 0.1,
             eos_token_id: 1,
-            feed_forward_proj: ActivationWithOptionalGating {
-                gated: false,
-                activation: Activation::Relu,
-            },
+            feed_forward_proj: Activation::Relu,
             tie_word_embeddings: true,
             initializer_factor: 1.0,
             is_decoder: false,
@@ -149,7 +110,6 @@ impl Config {
             num_heads: 12,
             num_layers: 12,
             pad_token_id: 0,
-            decoder_start_token_id: Some(0),
             relative_attention_max_distance: 128,
             relative_attention_num_buckets: 32,
             use_cache: true,
@@ -239,7 +199,7 @@ impl T5DenseGatedActDense {
             wi_0,
             wi_1,
             wo,
-            act: cfg.feed_forward_proj.activation,
+            act: Activation::NewGelu,
             span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"),
         })
     }
@@ -268,7 +228,7 @@ impl T5LayerFF {
     fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let layer_norm =
             T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
-        let (dense_act, gated_dense_act) = if cfg.feed_forward_proj.gated {
+        let (dense_act, gated_dense_act) = if cfg.feed_forward_proj == Activation::NewGelu {
             (
                 None,
                 Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?),
@@ -462,7 +422,7 @@ impl T5Attention {
                             self.relative_attention_max_distance as f32
                                 / max_exact as f32,
                         ) * (num_buckets - max_exact) as f32;
-                    u32::min(max_exact + b as u32, num_buckets - 1)
+                    max_exact + b as u32
                 }
             })
             .collect::<Vec<u32>>()
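For context on the clamp that differs in this hunk: in T5's relative-position bias, distances beyond `max_exact` are bucketed logarithmically. A sketch of the quantity the surrounding code computes, with r the relative distance, D the `relative_attention_max_distance`, and B the bucket count:

```latex
\mathrm{bucket}(r) = \text{max\_exact} +
\left\lfloor \frac{\log(r / \text{max\_exact})}{\log(D / \text{max\_exact})} \,(B - \text{max\_exact}) \right\rfloor
```

The `u32::min(..., num_buckets - 1)` on the left-hand side clamps this to the last bucket so that very distant positions cannot index out of range; the right-hand side lacks that guard.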
@@ -707,12 +667,7 @@ pub struct T5EncoderModel {
 
 impl T5EncoderModel {
     pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
-        let shared_vb = if vb.contains_tensor("shared.weight") {
-            vb.pp("shared")
-        } else {
-            vb.pp("decoder").pp("embed_tokens")
-        };
-        let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
+        let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
         let shared = Arc::new(shared);
         let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
         Ok(Self {
@@ -753,12 +708,7 @@ impl T5ForConditionalGeneration {
     pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         assert!(cfg.is_encoder_decoder);
         let d_model = cfg.d_model;
-        let shared_vb = if vb.contains_tensor("shared.weight") {
-            vb.pp("shared")
-        } else {
-            vb.pp("decoder").pp("embed_tokens")
-        };
-        let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
+        let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
         let shared = Arc::new(shared);
 
         let mut encoder_cfg = cfg.clone();
@ -1,434 +0,0 @@
|
|||||||
use crate::models::vit::{Config, Embeddings, Encoder};
|
|
||||||
use candle::{Result, Tensor};
|
|
||||||
use candle_nn::{
|
|
||||||
embedding, layer_norm, linear_no_bias, Embedding, LayerNorm, Linear, Module, VarBuilder,
|
|
||||||
};
|
|
||||||
use serde::Deserialize;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
|
||||||
pub struct TrOCRConfig {
|
|
||||||
pub vocab_size: usize,
|
|
||||||
pub d_model: usize,
|
|
||||||
pub hidden_size: usize,
|
|
||||||
pub decoder_layers: usize,
|
|
||||||
pub decoder_attention_heads: usize,
|
|
||||||
pub decoder_ffn_dim: usize,
|
|
||||||
pub activation_function: candle_nn::Activation,
|
|
||||||
pub max_position_embeddings: usize,
|
|
||||||
pub dropout: f64,
|
|
||||||
pub attention_dropout: f64,
|
|
||||||
pub activation_dropout: f64,
|
|
||||||
pub decoder_start_token_id: u32,
|
|
||||||
pub init_std: f64,
|
|
||||||
pub decoder_layerdrop: f64,
|
|
||||||
pub use_cache: bool,
|
|
||||||
pub scale_embedding: bool,
|
|
||||||
pub use_learned_position_embeddings: bool,
|
|
||||||
pub layernorm_embedding: bool,
|
|
||||||
pub pad_token_id: usize,
|
|
||||||
pub bos_token_id: usize,
|
|
||||||
pub eos_token_id: u32,
|
|
||||||
pub num_attention_heads: usize,
|
|
||||||
pub decoder_vocab_size: Option<usize>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for TrOCRConfig {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self {
|
|
||||||
vocab_size: 50265,
|
|
||||||
d_model: 1024,
|
|
||||||
hidden_size: 768,
|
|
||||||
decoder_layers: 12,
|
|
||||||
decoder_attention_heads: 16,
|
|
||||||
decoder_ffn_dim: 4096,
|
|
||||||
activation_function: candle_nn::Activation::Gelu,
|
|
||||||
max_position_embeddings: 512,
|
|
||||||
dropout: 0.1,
|
|
||||||
attention_dropout: 0.0,
|
|
||||||
activation_dropout: 0.0,
|
|
||||||
decoder_start_token_id: 2,
|
|
||||||
init_std: 0.02,
|
|
||||||
decoder_layerdrop: 0.0,
|
|
||||||
use_cache: true,
|
|
||||||
scale_embedding: false,
|
|
||||||
use_learned_position_embeddings: true,
|
|
||||||
layernorm_embedding: true,
|
|
||||||
pad_token_id: 1,
|
|
||||||
bos_token_id: 0,
|
|
||||||
eos_token_id: 2,
|
|
||||||
num_attention_heads: 12,
|
|
||||||
decoder_vocab_size: Some(50265),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
struct TrOCRLearnedPositionalEmbedding {
|
|
||||||
offset: usize,
|
|
||||||
weights: Embedding,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TrOCRLearnedPositionalEmbedding {
|
|
||||||
fn load(vb: VarBuilder, cfg: &TrOCRConfig) -> Result<Self> {
|
|
||||||
let offset: usize = 2;
|
|
||||||
let num_embeddings = cfg.max_position_embeddings;
|
|
||||||
let embedding_dim = cfg.d_model;
|
|
||||||
let weights = embedding(num_embeddings + offset, embedding_dim, vb)?;
|
|
||||||
|
|
||||||
Ok(Self { offset, weights })
|
|
||||||
}
|
|
||||||
|
|
||||||
fn forward(&mut self, input_ids: &Tensor, past_key_values_length: u32) -> Result<Tensor> {
|
|
||||||
let (b_sz, seq_len) = input_ids.dims2()?;
|
|
||||||
|
|
||||||
let mut positions = Tensor::arange(
|
|
||||||
past_key_values_length,
|
|
||||||
seq_len as u32 + past_key_values_length,
|
|
||||||
input_ids.device(),
|
|
||||||
)?
|
|
||||||
.expand((b_sz, seq_len))?;
|
|
||||||
|
|
||||||
positions =
|
|
||||||
positions.broadcast_add(&Tensor::new(self.offset as u32, input_ids.device())?)?;
|
|
||||||
self.weights.forward(&positions)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
struct TrOCRAttention {
|
|
||||||
head_dim: usize,
|
|
||||||
num_heads: usize,
|
|
||||||
is_decoder: bool,
|
|
||||||
scaling: f64,
|
|
||||||
k_proj: Linear,
|
|
||||||
v_proj: Linear,
|
|
||||||
q_proj: Linear,
|
|
||||||
out_proj: Linear,
|
|
||||||
kv_cache: Option<(Tensor, Tensor)>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TrOCRAttention {
|
|
||||||
fn load(
|
|
||||||
vb: VarBuilder,
|
|
||||||
cfg: &TrOCRConfig,
|
|
||||||
kdim: Option<usize>,
|
|
||||||
vdim: Option<usize>,
|
|
||||||
) -> Result<Self> {
|
|
||||||
let embed_dim = cfg.d_model;
|
|
||||||
let num_heads = cfg.decoder_attention_heads;
|
|
||||||
let head_dim = embed_dim / num_heads;
|
|
||||||
let kdim = kdim.unwrap_or(embed_dim);
|
|
||||||
let vdim = vdim.unwrap_or(embed_dim);
|
|
||||||
|
|
||||||
let k_proj = linear_no_bias(kdim, embed_dim, vb.pp("k_proj"))?;
|
|
||||||
let v_proj = linear_no_bias(vdim, embed_dim, vb.pp("v_proj"))?;
|
|
||||||
let q_proj = linear_no_bias(embed_dim, embed_dim, vb.pp("q_proj"))?;
|
|
||||||
|
|
||||||
let out_proj = linear_no_bias(embed_dim, embed_dim, vb.pp("out_proj"))?;
|
|
||||||
Ok(Self {
|
|
||||||
head_dim,
|
|
||||||
num_heads,
|
|
||||||
is_decoder: true,
|
|
||||||
scaling: 1. / (head_dim as f64).sqrt(),
|
|
||||||
k_proj,
|
|
||||||
v_proj,
|
|
||||||
q_proj,
|
|
||||||
out_proj,
|
|
||||||
kv_cache: None,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn _shape(&self, tensor: &Tensor, bsz: usize) -> Result<Tensor> {
|
|
||||||
tensor
|
|
||||||
.reshape((bsz, (), self.num_heads, self.head_dim))?
|
|
||||||
.transpose(1, 2)?
|
|
||||||
.contiguous()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn forward(
|
|
||||||
&mut self,
|
|
||||||
xs: &Tensor,
|
|
||||||
kv_states: Option<&Tensor>,
|
|
||||||
attn_mask: Option<&Tensor>,
|
|
||||||
) -> Result<Tensor> {
|
|
||||||
let (b_sz, tgt_len, _) = xs.dims3()?;
|
|
||||||
let query_states = (xs.apply(&self.q_proj)? * self.scaling)?;
|
|
||||||
let (key_states, value_states) = match kv_states {
|
|
||||||
None => {
|
|
||||||
let key_states = self._shape(&xs.apply(&self.k_proj)?, b_sz)?;
|
|
||||||
let value_states = self._shape(&xs.apply(&self.v_proj)?, b_sz)?;
|
|
||||||
if self.is_decoder {
|
|
||||||
let kv_states = match &self.kv_cache {
|
|
||||||
None => (key_states, value_states),
|
|
||||||
Some((p_key_states, p_value_states)) => {
|
|
||||||
let key_states = Tensor::cat(&[p_key_states, &key_states], 2)?;
|
|
||||||
let value_states = Tensor::cat(&[p_value_states, &value_states], 2)?;
|
|
||||||
(key_states, value_states)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
self.kv_cache = Some(kv_states.clone());
|
|
||||||
kv_states
|
|
||||||
} else {
|
|
||||||
(key_states, value_states)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Some(kv_states) => {
|
|
||||||
let key_states = self._shape(&kv_states.apply(&self.k_proj)?, b_sz)?;
|
|
||||||
let value_states = self._shape(&kv_states.apply(&self.v_proj)?, b_sz)?;
|
|
||||||
(key_states, value_states)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let proj_shape = (b_sz * self.num_heads, (), self.head_dim);
|
|
||||||
let query_states = self._shape(&query_states, b_sz)?.reshape(proj_shape)?;
|
|
||||||
let key_states = key_states.reshape(proj_shape)?;
|
|
||||||
let value_states = value_states.reshape(proj_shape)?;
|
|
||||||
let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
|
|
||||||
let attn_weights = match attn_mask {
|
|
||||||
None => attn_weights,
|
|
||||||
Some(attn_mask) => attn_weights.broadcast_add(attn_mask)?,
|
|
||||||
};
|
|
||||||
let attn_probs = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
|
||||||
let attn_output = attn_probs.matmul(&value_states)?;
|
|
||||||
attn_output
|
|
||||||
.reshape((b_sz, self.num_heads, tgt_len, self.head_dim))?
|
|
||||||
.transpose(1, 2)?
|
|
||||||
.reshape((b_sz, tgt_len, self.head_dim * self.num_heads))?
|
|
||||||
.apply(&self.out_proj)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
struct TrOCRDecoderLayer {
|
|
||||||
self_attn: TrOCRAttention,
|
|
||||||
activation_fn: candle_nn::Activation,
|
|
||||||
self_attn_layer_norm: LayerNorm,
|
|
||||||
encoder_attn: TrOCRAttention,
|
|
||||||
encoder_attn_layer_norm: LayerNorm,
|
|
||||||
fc1: Linear,
|
|
||||||
fc2: Linear,
|
|
||||||
final_layer_norm: LayerNorm,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TrOCRDecoderLayer {
|
|
||||||
fn load(vb: VarBuilder, cfg: &TrOCRConfig) -> Result<Self> {
|
|
||||||
let embed_dim = cfg.d_model;
|
|
||||||
let self_attn = TrOCRAttention::load(vb.pp("self_attn"), cfg, None, None)?;
|
|
||||||
let self_attn_layer_norm = layer_norm(embed_dim, 1e-5, vb.pp("self_attn_layer_norm"))?;
|
|
||||||
let encoder_attn = TrOCRAttention::load(
|
|
||||||
vb.pp("encoder_attn"),
|
|
||||||
cfg,
|
|
||||||
Some(cfg.hidden_size),
|
|
||||||
Some(cfg.hidden_size),
|
|
||||||
)?;
|
|
||||||
let encoder_attn_layer_norm =
|
|
||||||
layer_norm(embed_dim, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
|
|
||||||
let fc1 = linear_no_bias(embed_dim, cfg.decoder_ffn_dim, vb.pp("fc1"))?;
|
|
||||||
let fc2 = linear_no_bias(cfg.decoder_ffn_dim, embed_dim, vb.pp("fc2"))?;
|
|
||||||
let final_layer_norm = layer_norm(embed_dim, 1e-5, vb.pp("final_layer_norm"))?;
|
|
||||||
let activation_fn = candle_nn::Activation::Gelu;
|
|
||||||
|
|
||||||
Ok(Self {
|
|
||||||
self_attn,
|
|
||||||
activation_fn,
|
|
||||||
self_attn_layer_norm,
|
|
||||||
encoder_attn,
|
|
||||||
encoder_attn_layer_norm,
|
|
||||||
fc1,
|
|
||||||
fc2,
|
|
||||||
final_layer_norm,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn forward(
|
|
||||||
&mut self,
|
|
||||||
xs: &Tensor,
|
|
||||||
attention_mask: &Tensor,
|
|
||||||
encoder_hidden_states: Option<&Tensor>,
|
|
||||||
) -> Result<Tensor> {
|
|
||||||
let residual = xs.clone();
|
|
||||||
let xs = self.self_attn.forward(xs, None, Some(attention_mask))?;
|
|
||||||
let xs = (xs + residual)?;
|
|
||||||
let mut xs = self.self_attn_layer_norm.forward(&xs)?;
|
|
||||||
|
|
||||||
if let Some(encoder_hidden_states) = &encoder_hidden_states {
|
|
||||||
let residual = xs.clone();
|
|
||||||
let encoder_attention_mask = attention_mask.clone(); // TODO
|
|
||||||
xs = self.encoder_attn.forward(
|
|
||||||
&xs,
|
|
||||||
Some(encoder_hidden_states),
|
|
||||||
Some(&encoder_attention_mask),
|
|
||||||
)?;
|
|
||||||
xs = (xs + residual)?;
|
|
||||||
xs = self.encoder_attn_layer_norm.forward(&xs)?
|
|
||||||
}
|
|
||||||
|
|
||||||
let residual = xs.clone();
|
|
||||||
let xs = self.fc1.forward(&xs)?;
|
|
||||||
let xs = self.activation_fn.forward(&xs)?;
|
|
||||||
let xs = self.fc2.forward(&xs)?;
|
|
||||||
let xs = (xs + residual)?;
|
|
||||||
let xs = self.final_layer_norm.forward(&xs)?;
|
|
||||||
|
|
||||||
Ok(xs)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct TrOCRDecoder {
|
|
||||||
layers: Vec<TrOCRDecoderLayer>,
|
|
||||||
embed_scale: Option<f64>,
|
|
||||||
embed_tokens: Embedding,
|
|
||||||
embed_positions: TrOCRLearnedPositionalEmbedding,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TrOCRDecoder {
|
|
||||||
    fn new(cfg: &TrOCRConfig, vb: VarBuilder) -> Result<Self> {
        let vb = vb.pp("decoder.model.decoder");
        let embed_tokens = embedding(cfg.vocab_size, cfg.d_model, vb.pp("embed_tokens"))?;
        let embed_positions = TrOCRLearnedPositionalEmbedding::load(vb.pp("embed_positions"), cfg)?;
        let mut layers = Vec::with_capacity(cfg.decoder_layers);
        let vb_l = vb.pp("layers");
        for idx in 0..cfg.decoder_layers {
            let layer = TrOCRDecoderLayer::load(vb_l.pp(idx), cfg)?;
            layers.push(layer)
        }
        let embed_scale = if cfg.scale_embedding {
            Some((cfg.d_model as f64).sqrt())
        } else {
            None
        };
        Ok(Self {
            layers,
            embed_scale,
            embed_tokens,
            embed_positions,
        })
    }

    pub fn forward(
        &mut self,
        xs: &Tensor,
        encoder_xs: Option<&Tensor>,
        past_kv_len: usize,
        attn_mask: &Tensor,
    ) -> Result<Tensor> {
        let embed_pos = self.embed_positions.forward(xs, past_kv_len as u32)?;
        let xs = xs.apply(&self.embed_tokens)?;
        let xs = match self.embed_scale {
            None => xs,
            Some(scale) => (xs * scale)?,
        };
        let mut xs = xs.broadcast_add(&embed_pos)?;
        for layer in self.layers.iter_mut() {
            xs = layer.forward(&xs, attn_mask, encoder_xs)?;
        }
        Ok(xs)
    }
}

#[derive(Debug, Clone)]
pub struct TrOCREncoder {
    embeddings: Embeddings,
    encoder: Encoder,
    layernorm: LayerNorm,
}

impl TrOCREncoder {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb_v = vb.pp("encoder");
        let embeddings = Embeddings::new(cfg, false, vb_v.pp("embeddings"))?;
        let encoder = Encoder::new(cfg, vb_v.pp("encoder"))?;
        let layernorm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb_v.pp("layernorm"))?;
        Ok(Self {
            embeddings,
            encoder,
            layernorm,
        })
    }

    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let embedding_output = self.embeddings.forward(xs, None, false)?;
        let encoder_outputs = self.encoder.forward(&embedding_output)?;
        self.layernorm.forward(&encoder_outputs)
    }
}

#[derive(Debug, Clone)]
pub struct TrOCRForCausalLM {
    decoder: TrOCRDecoder,
    output_projection: Linear,
}

impl TrOCRForCausalLM {
    pub fn new(decoder_cfg: &TrOCRConfig, vb: VarBuilder) -> Result<Self> {
        let decoder = TrOCRDecoder::new(decoder_cfg, vb.clone())?;
        let output_projection =
            candle_nn::Linear::new(decoder.embed_tokens.embeddings().clone(), None);
        Ok(Self {
            decoder,
            output_projection,
        })
    }

    pub fn forward(
        &mut self,
        xs: &Tensor,
        encoder_xs: Option<&Tensor>,
        past_kv_len: usize,
        attn_mask: &Tensor,
    ) -> Result<Tensor> {
        let xs = self
            .decoder
            .forward(xs, encoder_xs, past_kv_len, attn_mask)?;
        let xs = xs.apply(&self.output_projection)?;
        Ok(xs)
    }
}

#[derive(Debug, Clone)]
pub struct TrOCRModel {
    encoder: TrOCREncoder,
    decoder: TrOCRForCausalLM,
}

impl TrOCRModel {
    pub fn new(encoder_cfg: &Config, decoder_cfg: &TrOCRConfig, vb: VarBuilder) -> Result<Self> {
        let encoder = TrOCREncoder::new(encoder_cfg, vb.clone())?;
        let decoder = TrOCRForCausalLM::new(decoder_cfg, vb)?;
        Ok(Self { encoder, decoder })
    }

    pub fn encoder(&mut self) -> &mut TrOCREncoder {
        &mut self.encoder
    }

    pub fn decoder(&mut self) -> &mut TrOCRForCausalLM {
        &mut self.decoder
    }

    pub fn decode(
        &mut self,
        xs: &Tensor,
        encoder_xs: &Tensor,
        past_kv_len: usize,
    ) -> Result<Tensor> {
        let seq_len = xs.dim(1)?;
        let mask: Vec<_> = (0..seq_len)
            .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
            .collect();
        let mask = Tensor::from_vec(mask, (seq_len, seq_len), xs.device())?;
        self.decoder
            .forward(xs, Some(encoder_xs), past_kv_len, &mask)
    }
}
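For orientation, decoding with the types above follows an encode-once, decode-incrementally pattern. A minimal sketch, assuming a `VarBuilder` over pretrained weights, a preprocessed `image` tensor, a `Default` impl on `TrOCRConfig`, and hypothetical `bos`/`eos` token ids:

// Sketch only: greedy decoding with the encoder run once and the kv-cache reused.
use candle::{D, IndexOp, Result, Tensor};

fn run_trocr(vb: VarBuilder, image: &Tensor) -> Result<Vec<u32>> {
    let encoder_cfg = Config::microsoft_trocr_base_handwritten();
    let decoder_cfg = TrOCRConfig::default(); // assumed Default impl
    let mut model = TrOCRModel::new(&encoder_cfg, &decoder_cfg, vb)?;
    let encoder_xs = model.encoder().forward(image)?; // run the ViT encoder once
    let (bos, eos) = (0u32, 2u32); // hypothetical special-token ids
    let mut tokens = vec![bos];
    for step in 0..256 {
        // After the first step, feed only the newest token; past_kv_len tells
        // the learned positional embedding and the cached decoder layers where
        // the sequence currently ends.
        let (ctx, past_kv_len) = if step == 0 {
            (tokens.clone(), 0)
        } else {
            (vec![*tokens.last().unwrap()], tokens.len() - 1)
        };
        let input = Tensor::new(ctx.as_slice(), image.device())?.unsqueeze(0)?;
        let logits = model.decode(&input, &encoder_xs, past_kv_len)?;
        let next = logits
            .i((0, logits.dim(1)? - 1))?
            .argmax(D::Minus1)?
            .to_scalar::<u32>()?;
        tokens.push(next);
        if next == eos {
            break;
        }
    }
    Ok(tokens)
}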
@@ -6,16 +6,16 @@ use candle_nn::{layer_norm, LayerNorm, VarBuilder};
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/configuration_vit.py
 #[derive(Debug, Clone)]
 pub struct Config {
-    pub hidden_size: usize,
-    pub num_hidden_layers: usize,
-    pub num_attention_heads: usize,
-    pub intermediate_size: usize,
-    pub hidden_act: candle_nn::Activation,
-    pub layer_norm_eps: f64,
-    pub image_size: usize,
-    pub patch_size: usize,
-    pub num_channels: usize,
-    pub qkv_bias: bool,
+    hidden_size: usize,
+    num_hidden_layers: usize,
+    num_attention_heads: usize,
+    intermediate_size: usize,
+    hidden_act: candle_nn::Activation,
+    layer_norm_eps: f64,
+    image_size: usize,
+    patch_size: usize,
+    num_channels: usize,
+    qkv_bias: bool,
 }

 impl Config {
@@ -34,21 +34,6 @@ impl Config {
             qkv_bias: true,
         }
     }
-
-    pub fn microsoft_trocr_base_handwritten() -> Self {
-        Self {
-            hidden_size: 768,
-            num_hidden_layers: 12,
-            num_attention_heads: 12,
-            intermediate_size: 3072,
-            hidden_act: candle_nn::Activation::Gelu,
-            layer_norm_eps: 1e-12,
-            image_size: 384,
-            patch_size: 16,
-            num_channels: 3,
-            qkv_bias: false,
-        }
-    }
 }

 #[derive(Debug, Clone)]
@@ -91,7 +76,7 @@ impl Module for PatchEmbeddings {
 }

 #[derive(Debug, Clone)]
-pub struct Embeddings {
+struct Embeddings {
     cls_token: Tensor,
     mask_token: Option<Tensor>,
     patch_embeddings: PatchEmbeddings,
@@ -100,7 +85,7 @@ pub struct Embeddings {
 }

 impl Embeddings {
-    pub fn new(cfg: &Config, use_mask_token: bool, vb: VarBuilder) -> Result<Self> {
+    fn new(cfg: &Config, use_mask_token: bool, vb: VarBuilder) -> Result<Self> {
         let hidden_size = cfg.hidden_size;
         let cls_token = vb.get((1, 1, hidden_size), "cls_token")?;
         let mask_token = if use_mask_token {
@@ -130,7 +115,7 @@ impl Embeddings {
         todo!()
     }

-    pub fn forward(
+    fn forward(
         &self,
         pixel_values: &Tensor,
         bool_masked_pos: Option<&Tensor>,
@@ -339,12 +324,12 @@ impl Module for Layer {
 }

 #[derive(Debug, Clone)]
-pub struct Encoder {
+struct Encoder {
     layers: Vec<Layer>,
 }

 impl Encoder {
-    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let vb = vb.pp("layer");
         let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
         for i in 0..cfg.num_hidden_layers {
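On the `+` side the `Config` fields and the `Embeddings`/`Encoder` types become private to the module, so external code reaches the ViT stack only through its public constructors and wrapper types. A one-line sketch (the constructor name is an assumption; only its tail, `qkv_bias: true`, is visible in the hunk above):

// Assumed constructor; the hunk shows only the closing lines of its body.
let cfg = Config::vit_base_patch16_224();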
@@ -198,17 +198,13 @@ fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
     mel
 }

-pub fn pcm_to_mel<T: Float + std::fmt::Display>(
-    cfg: &super::Config,
-    samples: &[T],
-    filters: &[T],
-) -> Vec<T> {
+pub fn pcm_to_mel<T: Float + std::fmt::Display>(samples: &[T], filters: &[T]) -> Vec<T> {
     log_mel_spectrogram_(
         samples,
         filters,
         super::N_FFT,
         super::HOP_LENGTH,
-        cfg.num_mel_bins,
+        super::N_MELS,
         false,
     )
 }
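The two signatures differ only in where the mel-bin count comes from. A hedged sketch of the call shape on each side of the hunk (`config`, `samples`, and `mel_filters` are placeholder bindings):

// `-` side: the bin count is read from the parsed model config.
let mel = pcm_to_mel(&config, &samples, &mel_filters);
// `+` side: the bin count is the fixed super::N_MELS (80); no config needed.
let mel = pcm_to_mel(&samples, &mel_filters);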
@@ -18,7 +18,6 @@ pub struct Config {
     // pub n_text_state: usize,
     pub decoder_attention_heads: usize, // n_text_head
     pub decoder_layers: usize,          // n_text_layer
-    #[serde(default)]
     pub suppress_tokens: Vec<u32>,
 }

@@ -27,6 +26,7 @@ pub const DTYPE: candle::DType = candle::DType::F32;
 // Audio parameters.
 pub const SAMPLE_RATE: usize = 16000;
 pub const N_FFT: usize = 400;
+pub const N_MELS: usize = 80;
 pub const HOP_LENGTH: usize = 160;
 pub const CHUNK_LENGTH: usize = 30;
 pub const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
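For scale, these constants pin down the mel frontend: a chunk holds CHUNK_LENGTH * SAMPLE_RATE = 30 * 16000 = 480,000 samples, which at HOP_LENGTH = 160 comes out to 480,000 / 160 = 3,000 mel frames of N_MELS = 80 bins each.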
@@ -1,5 +1,4 @@
 use super::Config;
-use crate::models::with_tracing::{linear, linear_no_bias, Linear};
 use candle::{Device, IndexOp, Result, Tensor, D};
 use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};

@@ -7,6 +6,33 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
     let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
     Ok(Embedding::new(embeddings, hidden_size))
 }
+//
+// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
+// model.
+#[derive(Debug, Clone)]
+pub struct Linear {
+    inner: candle_nn::Linear,
+    span: tracing::Span,
+}
+
+impl Linear {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.inner.forward(x)
+    }
+}
+
+fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
+    let span = tracing::span!(tracing::Level::TRACE, "linear");
+    let inner = candle_nn::linear(size1, size2, vb)?;
+    Ok(Linear { inner, span })
+}
+
+fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
+    let span = tracing::span!(tracing::Level::TRACE, "linear");
+    let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
+    Ok(Linear { inner, span })
+}
+
 fn conv1d(
     in_channels: usize,
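The wrapper exists so every matmul shows up as a named span once a `tracing` subscriber is installed. A minimal sketch of collecting those spans with the `tracing-chrome` crate (this setup is an assumption for illustration, not part of the diff):

use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // Keep the guard alive for the program's lifetime so the trace is flushed.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();
    // Any forward pass through the wrapped layers now emits "linear" spans
    // that can be inspected in chrome://tracing or Perfetto.
}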
@@ -124,34 +124,3 @@ impl std::fmt::Debug for QMatMul {
         write!(f, "QMatMul")
     }
 }
-
-#[derive(Clone, Debug)]
-pub struct LayerNorm {
-    inner: candle_nn::LayerNorm,
-    span: tracing::Span,
-}
-
-impl LayerNorm {
-    pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
-        let inner = candle_nn::LayerNorm::new(weight, bias, eps);
-        let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
-        Self { inner, span }
-    }
-}
-
-impl Module for LayerNorm {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        self.inner.forward(xs)
-    }
-}
-
-pub fn layer_norm<C: Into<candle_nn::LayerNormConfig>>(
-    size: usize,
-    c: C,
-    vb: VarBuilder,
-) -> Result<LayerNorm> {
-    let inner = candle_nn::layer_norm(size, c, vb)?;
-    let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
-    Ok(LayerNorm { inner, span })
-}
@@ -1,377 +0,0 @@
/// https://huggingface.co/01-ai/Yi-6B/blob/main/modeling_yi.py
use crate::models::with_tracing::{linear_no_bias, Linear};
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use std::sync::Arc;

#[derive(Debug, Clone, PartialEq)]
pub struct Config {
    pub(crate) vocab_size: usize,
    pub(crate) hidden_size: usize,
    pub(crate) intermediate_size: usize,
    pub(crate) num_hidden_layers: usize,
    pub(crate) num_attention_heads: usize,
    pub(crate) num_key_value_heads: usize,
    pub(crate) hidden_act: Activation,
    pub(crate) max_position_embeddings: usize,
    pub(crate) rms_norm_eps: f64,
    pub(crate) rope_theta: f64,
}

impl Config {
    pub fn config_6b() -> Self {
        Self {
            vocab_size: 64000,
            hidden_size: 4096,
            intermediate_size: 11008,
            num_hidden_layers: 32,
            num_attention_heads: 32,
            num_key_value_heads: 4,
            hidden_act: Activation::Silu,
            max_position_embeddings: 4096,
            rms_norm_eps: 1e-5,
            rope_theta: 5_000_000.,
        }
    }

    pub fn config_34b() -> Self {
        Self {
            vocab_size: 64000,
            hidden_size: 7168,
            intermediate_size: 20480,
            num_hidden_layers: 60,
            num_attention_heads: 56,
            num_key_value_heads: 8,
            hidden_act: Activation::Silu,
            max_position_embeddings: 4096,
            rms_norm_eps: 1e-5,
            rope_theta: 5_000_000.,
        }
    }
}

#[derive(Debug, Clone)]
struct RmsNorm {
    inner: candle_nn::RmsNorm,
    span: tracing::Span,
}

impl RmsNorm {
    fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
        let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
        let inner = candle_nn::rms_norm(size, eps, vb)?;
        Ok(Self { inner, span })
    }
}

impl Module for RmsNorm {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        self.inner.forward(x)
    }
}

#[derive(Debug, Clone)]
struct RotaryEmbedding {
    sin: Tensor,
    cos: Tensor,
}

fn rotate_half(xs: &Tensor) -> Result<Tensor> {
    let last_dim = xs.dim(D::Minus1)?;
    let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
    let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
    Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
}

impl RotaryEmbedding {
    fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
        let dim = cfg.hidden_size / cfg.num_attention_heads;
        let max_seq_len = cfg.max_position_embeddings;
        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(dtype)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
        Ok(Self {
            sin: freqs.sin()?,
            cos: freqs.cos()?,
        })
    }

    fn apply_rotary_emb_qkv(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offset: usize,
    ) -> Result<(Tensor, Tensor)> {
        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
        let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
        let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
        let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
        let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
        let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
        let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
        Ok((q_embed, k_embed))
    }
}
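// The code above implements rotate-half RoPE (as in GPT-NeoX/LLaMA): with
// per-pair frequencies theta_i = base^(-2i/d) and position p, the update is
//     x' = x * cos(p * theta) + rotate_half(x) * sin(p * theta),
// which rotates each (x_i, x_{i + d/2}) pair by the angle p * theta_i. Note
// that `inv_freq` hard-codes base = 10000 and never reads cfg.rope_theta
// (5_000_000 in both configs above).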

#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    gate_proj: Linear,
    up_proj: Linear,
    down_proj: Linear,
    act_fn: Activation,
}

impl MLP {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let intermediate_sz = cfg.intermediate_size;
        let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
        let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
        let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
            act_fn: cfg.hidden_act,
        })
    }
}

impl Module for MLP {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
        let rhs = xs.apply(&self.up_proj)?;
        (lhs * rhs)?.apply(&self.down_proj)
    }
}
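// The MLP above is the SwiGLU feed-forward used by LLaMA-family models:
// down_proj(act_fn(gate_proj(x)) * up_proj(x)), with act_fn = silu in both
// configs above.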

#[derive(Debug, Clone)]
struct Attention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    num_heads: usize,
    num_kv_heads: usize,
    num_kv_groups: usize,
    head_dim: usize,
    hidden_size: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    kv_cache: Option<(Tensor, Tensor)>,
}

impl Attention {
    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let num_kv_groups = num_heads / num_kv_heads;
        let head_dim = hidden_sz / num_heads;
        let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
        let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
        let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
        let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads,
            num_kv_heads,
            num_kv_groups,
            head_dim,
            hidden_size: hidden_sz,
            rotary_emb,
            kv_cache: None,
        })
    }

    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
        let n_rep = self.num_kv_groups;
        if n_rep == 1 {
            Ok(xs)
        } else {
            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
            xs.unsqueeze(2)?
                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
        }
    }
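    // Grouped-query attention bookkeeping: for config_6b (32 query heads, 4 kv
    // heads) num_kv_groups = 8, and repeat_kv expands k/v from (b, 4, t, head_dim)
    // to (b, 32, t, head_dim) so the per-head matmuls in forward() line up with
    // the query heads.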
    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let query_states = self.q_proj.forward(xs)?;
        let key_states = self.k_proj.forward(xs)?;
        let value_states = self.v_proj.forward(xs)?;

        let query_states = query_states
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        let key_states = key_states
            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;
        let value_states = value_states
            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;

        let (query_states, key_states) =
            self.rotary_emb
                .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;

        let (key_states, value_states) = match &self.kv_cache {
            None => (key_states, value_states),
            Some((prev_k, prev_v)) => {
                let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
                let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
                (key_states, value_states)
            }
        };
        self.kv_cache = Some((key_states.clone(), value_states.clone()));

        let key_states = self.repeat_kv(key_states)?;
        let value_states = self.repeat_kv(value_states)?;

        let attn_output = {
            let scale = 1f64 / f64::sqrt(self.head_dim as f64);
            let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;

            let attn_weights = match attention_mask {
                None => attn_weights,
                Some(mask) => attn_weights.broadcast_add(mask)?,
            };
            let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
            attn_weights.matmul(&value_states)?
        };
        attn_output
            .transpose(1, 2)?
            .reshape((b_sz, q_len, self.hidden_size))?
            .apply(&self.o_proj)
    }
}
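// Standard scaled dot-product attention: scores are scaled by 1/sqrt(head_dim)
// (head_dim = 4096 / 32 = 128 for config_6b), shifted by the additive mask,
// softmaxed over the key axis, then recombined and projected through o_proj.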
#[derive(Debug, Clone)]
struct DecoderLayer {
    self_attn: Attention,
    mlp: MLP,
    ln1: RmsNorm,
    ln2: RmsNorm,
}

impl DecoderLayer {
    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
        let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("ln1"))?;
        let ln2 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("ln2"))?;
        Ok(Self {
            self_attn,
            mlp,
            ln1,
            ln2,
        })
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.ln1.forward(xs)?;
        let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = xs.apply(&self.ln2)?.apply(&self.mlp)?;
        residual + xs
    }
}
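// Pre-norm residual wiring: x = x + attn(ln1(x)) followed by x = x + mlp(ln2(x)),
// with both norms being RmsNorm rather than LayerNorm.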

#[derive(Debug, Clone)]
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Linear,
    device: Device,
    dtype: DType,
}

impl Model {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb_m = vb.pp("model");
        let embed_tokens =
            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
        let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        for layer_idx in 0..cfg.num_hidden_layers {
            let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
            layers.push(layer)
        }
        let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
        let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            device: vb.device().clone(),
            dtype: vb.dtype(),
        })
    }

    fn prepare_decoder_attention_mask(
        &self,
        b_size: usize,
        tgt_len: usize,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        // Sliding window mask?
        let mask: Vec<_> = (0..tgt_len)
            .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
            .collect();
        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
        let mask = if seqlen_offset > 0 {
            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
            Tensor::cat(&[&mask0, &mask], D::Minus1)?
        } else {
            mask
        };
        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
            .to_dtype(self.dtype)
    }

    pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
        let (b_size, seq_len) = input_ids.dims2()?;
        let attention_mask = if seq_len <= 1 {
            None
        } else {
            let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
            Some(mask)
        };
        let mut xs = self.embed_tokens.forward(input_ids)?;
        for layer in self.layers.iter_mut() {
            xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
        }
        xs.narrow(1, seq_len - 1, 1)?
            .apply(&self.norm)?
            .apply(&self.lm_head)
    }
}
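// Worked example for prepare_decoder_attention_mask with tgt_len = 3 and
// seqlen_offset = 2 (three new tokens attending over two cached positions):
//     0 0 |    0 -inf -inf
//     0 0 |    0    0 -inf
//     0 0 |    0    0    0
// The zero block on the left keeps all cached positions visible; the triangle
// on the right is the causal mask; the result is broadcast to (b, 1, 3, 5).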
@@ -10,12 +10,12 @@ pub struct VarBuilder {
 }

 impl VarBuilder {
-    pub fn from_gguf<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
+    pub fn from_gguf<P: AsRef<std::path::Path>>(p: P, device: &Device) -> Result<Self> {
         let mut file = std::fs::File::open(p)?;
         let content = candle::quantized::gguf_file::Content::read(&mut file)?;
         let mut data = std::collections::HashMap::new();
         for tensor_name in content.tensor_infos.keys() {
-            let tensor = content.tensor(&mut file, tensor_name)?;
+            let tensor = content.tensor(&mut file, tensor_name, device)?;
             data.insert(tensor_name.to_string(), Arc::new(tensor));
         }
         Ok(Self {
@@ -25,12 +25,12 @@ impl VarBuilder {
         })
     }

-    pub fn from_gguf_buffer(buffer: &[u8]) -> Result<Self> {
+    pub fn from_gguf_buffer(buffer: &[u8], device: &Device) -> Result<Self> {
         let mut cursor = std::io::Cursor::new(buffer);
         let content = candle::quantized::gguf_file::Content::read(&mut cursor)?;
         let mut data = std::collections::HashMap::new();
         for tensor_name in content.tensor_infos.keys() {
-            let tensor = content.tensor(&mut cursor, tensor_name)?;
+            let tensor = content.tensor(&mut cursor, tensor_name, device)?;
             data.insert(tensor_name.to_string(), Arc::new(tensor));
         }
         Ok(Self {
@@ -90,8 +90,4 @@ impl VarBuilder {
     pub fn device(&self) -> &Device {
         &self.device
     }
-
-    pub fn contains_key(&self, key: &str) -> bool {
-        self.data.contains_key(key)
-    }
 }
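On the `+` side the GGUF loading path is device-aware: every tensor is materialized on the device passed in, instead of implicitly on the CPU. A short usage sketch ("model.gguf" is a placeholder path; the Metal constructor assumes the `metal` feature is enabled):

// Sketch: load a quantized GGUF checkpoint onto an explicit device,
// falling back to the CPU when no Metal device is available.
let device = Device::new_metal(0).unwrap_or(Device::Cpu);
let vb = VarBuilder::from_gguf("model.gguf", &device)?;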
@@ -1,7 +1,5 @@
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{
-    embedding, linear_no_bias as linear, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder,
-};
+use candle_nn::{embedding, linear, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
Some files were not shown because too many files have changed in this diff.