Compare commits

..

1 Commits

Author SHA1 Message Date
0bb344f798 [RFC] Start removal of VarBuilder.
- Uses `Initializer` trait instead.
- Allows more decoupling between init and load, which are very different
  ops.
- Allows more decoupling between backends (safetensors, npy, ggml,
  etc...)

This is a minimum viable change.

There are 3 kind of objects with various relations.

The `Model`:
    This is `Llama`, `Linear`, `Rms` ...
    They contain tensors (and possibly other things).  and are used to
    call `forward` basically.
    They should have no ownership of any internals like Rng state or
    actual shapes of the tensors (the tensors already own those)

The `Initializer`:
    This is a struct containing necessary information to generate new
    random tensors. Typically they should own a random generator, and
    generate different kind of random tensors based on what kind of
    `Model` they are initializing.
    This do not own any information about the `Model` itself.
    Default init stores the `Vec<Var>` for now, in order to send to the
    optimizer.

Ths `Config`:
    This is the necessary information to link between the `Model` and
    the `Initializer`. This is another struct which is a companion of
    the implementation of the initalization.
    Typical information is the shape of the tensors for simple `Model`,
    the `eps` for RMS, the `use_bias` boolean to know whether we should
    have a bias in the linear layer.

This should remove all need for `VarBuilder` during intialization, and
allow removing every initialization bit within `VarBuilder`.

Modifying `llama2-c` to follow that initialization is left on purpose
for a follow-up to keep the current PR rather small.
2023-08-16 14:39:36 +02:00
206 changed files with 2450 additions and 23600 deletions

View File

@ -1,8 +1,8 @@
[build]
[target.x86_64-unknown-linux-gnu]
rustflags = ["-C", "target-cpu=native"]
[target.aarch64-apple-darwin]
rustflags = ["-C", "target-cpu=native"]
[target.wasm32-unknown-unknown]
rustflags = ["-C", "target-feature=+simd128"]
[target.x86_64-apple-darwin]
rustflags = ["-C", "target-feature=-avx,-avx2"]

View File

@ -59,7 +59,7 @@ jobs:
- name: Install Rust Stable
run: curl https://sh.rustup.rs -sSf | sh -s -- -y
- uses: Swatinem/rust-cache@v2
- run: apt-get update -y && apt-get install libssl-dev -y
- run: apt update -y && apt install libssl-dev -y
- name: Test (cuda)
run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda
stop-runner:

5
.gitignore vendored
View File

@ -25,12 +25,7 @@ flamegraph.svg
*.swp
trace-*.json
candle-wasm-examples/*/build
candle-wasm-examples/*/*.bin
candle-wasm-examples/*/*.jpeg
candle-wasm-examples/*/*.wav
candle-wasm-examples/*/*.safetensors
candle-wasm-examples/*/package-lock.json
.DS_Store
.idea/*

View File

@ -1,42 +0,0 @@
# Changelog
This documents the main changes to the `candle` crate.
## v0.2.1 - Unreleased
### Added
### Modified
- Dilations are now supported in conv-transpose2d.
[671](https://github.com/huggingface/candle/pull/671).
## v0.2.0 - 2023-08-30
### Added
- Add the powf op
[664](https://github.com/huggingface/candle/pull/664).
- Stable Diffusion XL support
[647](https://github.com/huggingface/candle/pull/647).
- Add the conv-transpose2d op
[635](https://github.com/huggingface/candle/pull/635).
- Refactor the VarBuilder api
[627](https://github.com/huggingface/candle/pull/627).
- Add some quantization command
[625](https://github.com/huggingface/candle/pull/625).
- Support more quantized types, e.g. Q2K, Q4K, Q5K...
[586](https://github.com/huggingface/candle/pull/586).
- Add pose estimation to the yolo example
[589](https://github.com/huggingface/candle/pull/589).
- Api to write GGUF files
[585](https://github.com/huggingface/candle/pull/585).
- Support more quantization types
[580](https://github.com/huggingface/candle/pull/580).
- Add EfficientNet as an example Computer Vision model
[572](https://github.com/huggingface/candle/pull/572).
- Add a group parameter to convolutions
[566](https://github.com/huggingface/candle/pull/566).
- New dtype: int64
[563](https://github.com/huggingface/candle/pull/563).
- Handling of the GGUF file format.
[559](https://github.com/huggingface/candle/pull/559).
## v0.1.2 - 2023-08-21

View File

@ -3,22 +3,19 @@ members = [
"candle-core",
"candle-datasets",
"candle-examples",
"candle-book",
"candle-nn",
"candle-pyo3",
"candle-transformers",
"candle-wasm-examples/llama2-c",
"candle-wasm-examples/whisper",
"candle-wasm-examples/yolo",
]
exclude = [
"candle-flash-attn",
"candle-kernels",
]
resolver = "2"
[workspace.package]
version = "0.2.1"
version = "0.1.1"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@ -34,10 +31,9 @@ clap = { version = "4.2.4", features = ["derive"] }
cudarc = { version = "0.9.14", features = ["f16"] }
# TODO: Switch back to the official gemm implementation once it has caught up.
gemm = { version = "0.15.6", package = "candle-gemm" }
hf-hub = "0.3.0"
hf-hub = "0.2.0"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
imageproc = { version = "0.23.0", default-features = false }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
libc = { version = "0.2.147" }
log = "0.4"
@ -47,7 +43,6 @@ num-traits = "0.2.15"
rand = "0.8.5"
rand_distr = "0.4.3"
rayon = "1.7.0"
rusttype = { version = "0.9", default-features = false }
safetensors = "0.3.1"
serde = { version = "1.0.171", features = ["derive"] }
serde_json = "1.0.99"
@ -58,7 +53,6 @@ tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"
wav = "1.0.0"
zip = { version = "0.6.6", default-features = false }
parquet = { version = "45.0.0" }
[profile.release-with-debug]
inherits = "release"

View File

@ -1,5 +1,3 @@
.PHONY: clean-ptx clean test
clean-ptx:
find target -name "*.ptx" -type f -delete
echo "" > candle-kernels/src/lib.rs
@ -13,4 +11,8 @@ clean:
test:
cargo test
pyo3-test:
cargo build --profile=release-with-debug --package candle-pyo3
python3 candle-pyo3/test.py
all: test

135
README.md
View File

@ -1,5 +1,5 @@
# candle
[![discord server](https://dcbadge.vercel.app/api/server/hugging-face-879548962464493619)](https://discord.gg/hugging-face-879548962464493619)
[![discord server](https://dcbadge.vercel.app/api/server/hugging-face-879548962464493619)](https://discord.com/channels/879548962464493619/1136218819447238726)
[![Latest version](https://img.shields.io/crates/v/candle-core.svg)](https://crates.io/crates/candle-core)
[![Documentation](https://docs.rs/candle-core/badge.svg)](https://docs.rs/candle-core)
![License](https://img.shields.io/crates/l/candle-core.svg)
@ -7,65 +7,29 @@
Candle is a minimalist ML framework for Rust with a focus on performance (including GPU support)
and ease of use. Try our online demos:
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
[LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2),
[yolo](https://huggingface.co/spaces/lmz/candle-yolo).
[llama2](https://huggingface.co/spaces/lmz/candle-llama2).
## Get started
Make sure that you have [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) correctly installed as described in [**Installation**](https://huggingface.github.io/candle/guide/installation.html).
Let's see how to run a simple matrix multiplication.
Write the following to your `myapp/src/main.rs` file:
```rust
use candle_core::{Device, Tensor};
let a = Tensor::randn(0f32, 1., (2, 3), &Device::Cpu)?;
let b = Tensor::randn(0f32, 1., (3, 4), &Device::Cpu)?;
fn main() -> Result<(), Box<dyn std::error::Error>> {
let device = Device::Cpu;
let a = Tensor::randn(0f32, 1., (2, 3), &device)?;
let b = Tensor::randn(0f32, 1., (3, 4), &device)?;
let c = a.matmul(&b)?;
println!("{c}");
Ok(())
}
let c = a.matmul(&b)?;
println!("{c}");
```
`cargo run` should display a tensor of shape `Tensor[[2, 4], f32]`.
Having installed `candle` with Cuda support, simply define the `device` to be on GPU:
```diff
- let device = Device::Cpu;
+ let device = Device::new_cuda(0)?;
```
For more advanced examples, please have a look at the following section.
## Check out our examples
Check out our [examples](./candle-examples/examples/):
- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM.
- [Llama and Llama-v2](./candle-examples/examples/llama/): general LLM.
- [Falcon](./candle-examples/examples/falcon/): general LLM.
- [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code
generation.
- [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
image generative model, support for the 1.5, 2.1, and SDXL 1.0 versions.
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
using self-supervision (can be used for imagenet classification, depth
evaluation, segmentation).
- [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of
the LLaMA model using the same quantization techniques as
[llama.cpp](https://github.com/ggerganov/llama.cpp).
- [yolo-v3](./candle-examples/examples/yolo-v3/) and
[yolo-v8](./candle-examples/examples/yolo-v8/): object detection and pose
estimation models.
[segment-anything](./candle-examples/examples/segment-anything/): image
segmentation model with prompt.
image generative model, yet to be optimized.
Run them using the following commands:
```
cargo run --example whisper --release
@ -73,16 +37,10 @@ cargo run --example llama --release
cargo run --example falcon --release
cargo run --example bert --release
cargo run --example bigcode --release
cargo run --example stable-diffusion --release -- --prompt "a rusty robot holding a fire torch"
cargo run --example dinov2 --release -- --image path/to/myinput.jpg
cargo run --example quantized --release
cargo run --example yolo-v3 --release -- myimage.jpg
cargo run --example yolo-v8 --release -- myimage.jpg # for pose estimation, add --task pose
cargo run --example segment-anything --release -- --image myimage.jpg
cargo run --example stable-diffusion --release --features image -- --prompt "a rusty robot holding a fire torch"
```
In order to use **CUDA** add `--features cuda` to the example command line. If
you have cuDNN installed, use `--features cudnn` for even more speedups.
In order to use **CUDA** add `--features cuda` to the example command line.
There are also some wasm examples for whisper and
[llama2.c](https://github.com/karpathy/llama2.c). You can either build them with
@ -90,16 +48,16 @@ There are also some wasm examples for whisper and
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
[llama2](https://huggingface.co/spaces/lmz/candle-llama2).
For LLaMA2, run the following command to retrieve the weight files and start a
For llama2, run the following command to retrieve the weight files and start a
test server:
```bash
cd candle-wasm-examples/llama2-c
wget https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/model.bin
wget https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/tokenizer.json
trunk serve --release --port 8081
trunk serve --release --public-url /candle-llama2/ --port 8081
```
And then head over to
[http://localhost:8081/](http://localhost:8081/).
[http://localhost:8081/candle-llama2](http://localhost:8081/candle-llama2).
<!--- ANCHOR: features --->
@ -112,14 +70,11 @@ And then head over to
- Optimized CPU backend with optional MKL support for x86 and Accelerate for macs.
- CUDA backend for efficiently running on GPUs, multiple GPU distribution via NCCL.
- WASM support, run your models in a browser.
- Included models.
- LLMs: LLaMA v1 and v2, Falcon, StarCoder.
- Whisper (multi-lingual support).
- Model support out of the box.
- LLMs: Llama v1 and v2, Falcon, StarCoder.
- Whisper.
- Stable Diffusion.
- Computer Vision: DINOv2, EfficientNet, yolo-v3, yolo-v8.
- File formats: load models from safetensors, npz, ggml, or PyTorch files.
- Serverless (on CPU), small and fast deployments.
- Quantization support using the llama.cpp quantized types.
<!--- ANCHOR_END: features --->
@ -136,7 +91,7 @@ Cheatsheet:
| Operations | `tensor.view((2, 2))` | `tensor.reshape((2, 2))?` |
| Operations | `a.matmul(b)` | `a.matmul(&b)?` |
| Arithmetic | `a + b` | `&a + &b` |
| Device | `tensor.to(device="cuda")` | `tensor.to_device(&Device::new_cuda(0)?)?` |
| Device | `tensor.to(device="cuda")` | `tensor.to_device(&Device::Cuda(0))?` |
| Dtype | `tensor.to(dtype=torch.float16)` | `tensor.to_dtype(&DType::F16)?` |
| Saving | `torch.save({"A": A}, "model.bin")` | `candle::safetensors::save(&HashMap::from([("A", A)]), "model.safetensors")?` |
| Loading | `weights = torch.load("model.bin")` | `candle::safetensors::load("model.safetensors", &device)` |
@ -189,74 +144,34 @@ Finally, Rust is cool! A lot of the HF ecosystem already has Rust crates, like [
#### Missing symbols when compiling with the mkl feature.
If you get some missing symbols when compiling binaries/tests using the mkl
or accelerate features, e.g. for mkl you get:
features, e.g.:
```
= note: /usr/bin/ld: (....o): in function `blas::sgemm':
.../blas-0.22.0/src/lib.rs:1944: undefined reference to `sgemm_' collect2: error: ld returned 1 exit status
= note: some `extern` functions couldn't be found; some native libraries may need to be installed or have their path specified
= note: use the `-l` flag to specify native libraries to link
= note: use the `cargo:rustc-link-lib` directive to specify the native libraries to link with Cargo
```
or for accelerate:
```
Undefined symbols for architecture arm64:
"_dgemm_", referenced from:
candle_core::accelerate::dgemm::h1b71a038552bcabe in libcandle_core...
"_sgemm_", referenced from:
candle_core::accelerate::sgemm::h2cf21c592cba3c47 in libcandle_core...
ld: symbol(s) not found for architecture arm64
= note: use the `cargo:rustc-link-lib` directive to specify the native libraries to link with Cargo (see https://doc.rust-lang.org/cargo/reference/build-scripts.html#cargorustc-link-libkindname)
```
This is likely due to a missing linker flag that was needed to enable the mkl library. You
can try adding the following for mkl at the top of your binary:
```rust
can try adding the following at the top of your binary:
```
extern crate intel_mkl_src;
```
or for accelerate:
```rust
extern crate accelerate_src;
```
#### Cannot run the LLaMA examples: access to source requires login credentials
#### Cannot run llama example : access to source requires login credentials
```
Error: request error: https://huggingface.co/meta-llama/Llama-2-7b-hf/resolve/main/tokenizer.json: status code 401
```
This is likely because you're not permissioned for the LLaMA-v2 model. To fix
this, you have to register on the huggingface-hub, accept the [LLaMA-v2 model
This is likely because you're not permissioned for the llama-v2 model. To fix
this, you have to register on the huggingface-hub, accept the [llama-v2 model
conditions](https://huggingface.co/meta-llama/Llama-2-7b-hf), and set up your
authentication token. See issue
[#350](https://github.com/huggingface/candle/issues/350) for more details.
#### Missing cute/cutlass headers when compiling flash-attn
```
In file included from kernels/flash_fwd_launch_template.h:11:0,
from kernels/flash_fwd_hdim224_fp16_sm80.cu:5:
kernels/flash_fwd_kernel.h:8:10: fatal error: cute/algorithm/copy.hpp: No such file or directory
#include <cute/algorithm/copy.hpp>
^~~~~~~~~~~~~~~~~~~~~~~~~
compilation terminated.
Error: nvcc error while compiling:
```
[cutlass](https://github.com/NVIDIA/cutlass) is provided as a git submodule so you may want to run the following command to check it in properly.
```bash
git submodule update --init
```
#### Compiling with flash-attention fails
```
/usr/include/c++/11/bits/std_function.h:530:146: error: parameter packs not expanded with ...:
```
This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the CANDLE_NVCC_CCBIN environment variable.
```
env CANDLE_NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
```
#### Tracking down errors
You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle

View File

@ -1,49 +0,0 @@
[package]
name = "candle-book"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true
readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.2.1" }
candle-nn = { path = "../candle-nn", version = "0.2.1" }
candle-transformers = { path = "../candle-transformers", version = "0.2.1" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.1", optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
num-traits = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
image = { workspace = true, optional = true }
[dev-dependencies]
anyhow = { workspace = true }
byteorder = { workspace = true }
hf-hub = { workspace = true, features=["tokio"]}
clap = { workspace = true }
memmap2 = { workspace = true }
rand = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
tracing = { workspace = true }
tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
wav = { workspace = true }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
tokio = "1.29.1"
parquet = { workspace = true }
image = { workspace = true }
[build-dependencies]
anyhow = { workspace = true }
[features]
default = []

View File

@ -12,11 +12,7 @@
- [Running a model](inference/README.md)
- [Using the hub](inference/hub.md)
- [Error management](error_manage.md)
- [Training](training/README.md)
- [MNIST](training/mnist.md)
- [Fine-tuning]()
- [Serialization]()
- [Error management]()
- [Advanced Cuda usage]()
- [Writing a custom kernel]()
- [Porting a custom kernel]()
@ -25,3 +21,7 @@
- [Creating a WASM app]()
- [Creating a REST api webserver]()
- [Creating a desktop Tauri app]()
- [Training]()
- [MNIST]()
- [Fine-tuning]()
- [Serialization]()

View File

@ -147,7 +147,7 @@ And rewrite our examples using it
# extern crate candle_core;
# extern crate candle_nn;
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Linear, Module};
use candle_nn::Linear;
struct Model {
first: Linear,

View File

@ -1,43 +1,6 @@
# Installation
**With Cuda support**:
1. First, make sure that Cuda is correctly installed.
- `nvcc --version` should print information about your Cuda compiler driver.
- `nvidia-smi --query-gpu=compute_cap --format=csv` should print your GPUs compute capability, e.g. something
like:
```bash
compute_cap
8.9
```
If any of the above commands errors out, please make sure to update your Cuda version.
2. Create a new app and add [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) with Cuda support.
Start by creating a new cargo:
```bash
cargo new myapp
cd myapp
```
Make sure to add the `candle-core` crate with the cuda feature:
```bash
cargo add --git https://github.com/huggingface/candle.git candle-core --features "cuda"
```
Run `cargo build` to make sure everything can be correctly built.
```bash
cargo build
```
**Without Cuda support**:
Create a new app and add [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) as follows:
Start by creating a new app:
```bash
cargo new myapp
@ -45,12 +8,17 @@ cd myapp
cargo add --git https://github.com/huggingface/candle.git candle-core
```
Finally, run `cargo build` to make sure everything can be correctly built.
At this point, candle will be built **without** CUDA support.
To get CUDA support use the `cuda` feature
```bash
cargo add --git https://github.com/huggingface/candle.git candle-core --features cuda
```
You can check everything works properly:
```bash
cargo build
```
**With mkl support**
You can also see the `mkl` feature which could be interesting to get faster inference on CPU. [Using mkl](./advanced/mkl.md)

View File

@ -39,7 +39,7 @@ cargo add hf-hub --features tokio
```rust,ignore
# This is tested directly in examples crate because it needs external dependencies unfortunately:
# See [this](https://github.com/rust-lang/mdBook/issues/706)
{{#include ../lib.rs:book_hub_1}}
{{#include ../../../candle-examples/src/lib.rs:book_hub_1}}
```
@ -58,7 +58,7 @@ Now that we have our weights, we can use them in our bert architecture:
#
# let weights = repo.get("model.safetensors").unwrap();
use candle_core::{Device, Tensor, DType};
use candle_nn::{Linear, Module};
use candle_nn::Linear;
let weights = candle_core::safetensors::load(weights, &Device::Cpu).unwrap();
@ -81,7 +81,7 @@ For more efficient loading, instead of reading the file, you could use [`memmap2
and will definitely be slower on network mounted disk, because it will issue more read calls.
```rust,ignore
{{#include ../lib.rs:book_hub_2}}
{{#include ../../../candle-examples/src/lib.rs:book_hub_2}}
```
**Note**: This operation is **unsafe**. [See the safety notice](https://docs.rs/memmap2/latest/memmap2/struct.Mmap.html#safety).
@ -100,5 +100,5 @@ cargo add safetensors
```rust,ignore
{{#include ../lib.rs:book_hub_3}}
{{#include ../../../candle-examples/src/lib.rs:book_hub_3}}
```

View File

@ -1,193 +0,0 @@
#[cfg(test)]
mod tests {
use anyhow::Result;
use candle::{DType, Device, Tensor};
use parquet::file::reader::SerializedFileReader;
// NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
#[rustfmt::skip]
#[tokio::test]
async fn book_hub_1() {
// ANCHOR: book_hub_1
use candle::Device;
use hf_hub::api::tokio::Api;
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights_filename = repo.get("model.safetensors").await.unwrap();
let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap();
// ANCHOR_END: book_hub_1
assert_eq!(weights.len(), 206);
}
#[rustfmt::skip]
#[test]
fn book_hub_2() {
// ANCHOR: book_hub_2
use candle::Device;
use hf_hub::api::sync::Api;
use memmap2::Mmap;
use std::fs;
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights_filename = repo.get("model.safetensors").unwrap();
let file = fs::File::open(weights_filename).unwrap();
let mmap = unsafe { Mmap::map(&file).unwrap() };
let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap();
// ANCHOR_END: book_hub_2
assert_eq!(weights.len(), 206);
}
#[rustfmt::skip]
#[test]
fn book_hub_3() {
// ANCHOR: book_hub_3
use candle::{DType, Device, Tensor};
use hf_hub::api::sync::Api;
use memmap2::Mmap;
use safetensors::slice::IndexOp;
use safetensors::SafeTensors;
use std::fs;
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights_filename = repo.get("model.safetensors").unwrap();
let file = fs::File::open(weights_filename).unwrap();
let mmap = unsafe { Mmap::map(&file).unwrap() };
// Use safetensors directly
let tensors = SafeTensors::deserialize(&mmap[..]).unwrap();
let view = tensors
.tensor("bert.encoder.layer.0.attention.self.query.weight")
.unwrap();
// We're going to load shard with rank 1, within a world_size of 4
// We're going to split along dimension 0 doing VIEW[start..stop, :]
let rank = 1;
let world_size = 4;
let dim = 0;
let dtype = view.dtype();
let mut tp_shape = view.shape().to_vec();
let size = tp_shape[0];
if size % world_size != 0 {
panic!("The dimension is not divisble by `world_size`");
}
let block_size = size / world_size;
let start = rank * block_size;
let stop = (rank + 1) * block_size;
// Everything is expressed in tensor dimension
// bytes offsets is handled automatically for safetensors.
let iterator = view.slice(start..stop).unwrap();
tp_shape[dim] = block_size;
// Convert safetensors Dtype to candle DType
let dtype: DType = dtype.try_into().unwrap();
// TODO: Implement from_buffer_iterator so we can skip the extra CPU alloc.
let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap();
// ANCHOR_END: book_hub_3
assert_eq!(view.shape(), &[768, 768]);
assert_eq!(tp_tensor.dims(), &[192, 768]);
}
#[rustfmt::skip]
#[test]
fn book_training_1() -> Result<()>{
// ANCHOR: book_training_1
use hf_hub::{api::sync::Api, Repo, RepoType};
let dataset_id = "mnist".to_string();
let api = Api::new()?;
let repo = Repo::with_revision(
dataset_id,
RepoType::Dataset,
"refs/convert/parquet".to_string(),
);
let repo = api.repo(repo);
let test_parquet_filename = repo.get("mnist/test/0000.parquet")?;
let train_parquet_filename = repo.get("mnist/train/0000.parquet")?;
let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
// ANCHOR_END: book_training_1
// Ignore unused
let _train = train_parquet;
// ANCHOR: book_training_2
for row in test_parquet {
for (idx, (name, field)) in row?.get_column_iter().enumerate() {
println!("Column id {idx}, name {name}, value {field}");
}
}
// ANCHOR_END: book_training_2
let test_parquet_filename = repo.get("mnist/test/0000.parquet")?;
let train_parquet_filename = repo.get("mnist/train/0000.parquet")?;
let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
// ANCHOR: book_training_3
let test_samples = 10_000;
let mut test_buffer_images: Vec<u8> = Vec::with_capacity(test_samples * 784);
let mut test_buffer_labels: Vec<u8> = Vec::with_capacity(test_samples);
for row in test_parquet{
for (_name, field) in row?.get_column_iter() {
if let parquet::record::Field::Group(subrow) = field {
for (_name, field) in subrow.get_column_iter() {
if let parquet::record::Field::Bytes(value) = field {
let image = image::load_from_memory(value.data()).unwrap();
test_buffer_images.extend(image.to_luma8().as_raw());
}
}
}else if let parquet::record::Field::Long(label) = field {
test_buffer_labels.push(*label as u8);
}
}
}
let test_images = (Tensor::from_vec(test_buffer_images, (test_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
let test_labels = Tensor::from_vec(test_buffer_labels, (test_samples, ), &Device::Cpu)?;
let train_samples = 60_000;
let mut train_buffer_images: Vec<u8> = Vec::with_capacity(train_samples * 784);
let mut train_buffer_labels: Vec<u8> = Vec::with_capacity(train_samples);
for row in train_parquet{
for (_name, field) in row?.get_column_iter() {
if let parquet::record::Field::Group(subrow) = field {
for (_name, field) in subrow.get_column_iter() {
if let parquet::record::Field::Bytes(value) = field {
let image = image::load_from_memory(value.data()).unwrap();
train_buffer_images.extend(image.to_luma8().as_raw());
}
}
}else if let parquet::record::Field::Long(label) = field {
train_buffer_labels.push(*label as u8);
}
}
}
let train_images = (Tensor::from_vec(train_buffer_images, (train_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
let train_labels = Tensor::from_vec(train_buffer_labels, (train_samples, ), &Device::Cpu)?;
let mnist = candle_datasets::vision::Dataset {
train_images,
train_labels,
test_images,
test_labels,
labels: 10,
};
// ANCHOR_END: book_training_3
assert_eq!(mnist.test_images.dims(), &[10_000, 784]);
assert_eq!(mnist.test_labels.dims(), &[10_000]);
assert_eq!(mnist.train_images.dims(), &[60_000, 784]);
assert_eq!(mnist.train_labels.dims(), &[60_000]);
Ok(())
}
}

View File

@ -1,39 +1 @@
# Training
Training starts with data. We're going to use the huggingface hub and
start with the Hello world dataset of machine learning, MNIST.
Let's start with downloading `MNIST` from [huggingface](https://huggingface.co/datasets/mnist).
This requires [`hf-hub`](https://github.com/huggingface/hf-hub).
```bash
cargo add hf-hub
```
This is going to be very hands-on for now.
```rust,ignore
{{#include ../../../candle-examples/src/lib.rs:book_training_1}}
```
This uses the standardized `parquet` files from the `refs/convert/parquet` branch on every dataset.
Our handles are now [`parquet::file::serialized_reader::SerializedFileReader`].
We can inspect the content of the files with:
```rust,ignore
{{#include ../../../candle-examples/src/lib.rs:book_training_2}}
```
You should see something like:
```bash
Column id 1, name label, value 6
Column id 0, name image, value {bytes: [137, ....]
Column id 1, name label, value 8
Column id 0, name image, value {bytes: [137, ....]
```
So each row contains 2 columns (image, label) with image being saved as bytes.
Let's put them into a useful struct.

View File

@ -1,10 +1 @@
# MNIST
So we now have downloaded the MNIST parquet files, let's put them in a simple struct.
```rust,ignore
{{#include ../lib.rs:book_training_3}}
```
The parsing of the file and putting it into single tensors requires the dataset to fit the entire memory.
It is quite rudimentary, but simple enough for a small dataset like MNIST.

View File

@ -12,7 +12,7 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
candle-kernels = { path = "../candle-kernels", version = "0.2.1", optional = true }
candle-kernels = { path = "../candle-kernels", version = "0.1.1", optional = true }
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
half = { workspace = true }

View File

@ -11,7 +11,7 @@ fn main() -> Result<()> {
let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
let start = std::time::Instant::now();
let res = inp.conv2d(&w, 0, 1, 1, 1)?;
let res = inp.conv2d(&w, 0, 1);
println!("{:?}", start.elapsed());
println!("{res:?}");
Ok(())

View File

@ -5,10 +5,18 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle::quantized::GgmlType;
use candle::{Device, Result, Tensor, D};
use candle_core::{Device, Result, Tensor, D};
use clap::{Parser, Subcommand};
fn softmax<D: candle_core::shape::Dim>(xs: &Tensor, dim: D) -> Result<Tensor> {
let dim = dim.to_index(xs.shape(), "softmax")?;
let max = xs.max_keepdim(dim)?;
let diff = xs.broadcast_sub(&max)?;
let num = diff.exp()?;
let den = num.sum_keepdim(dim)?;
num.broadcast_div(&den)
}
trait Benchmark {
type PreProcessData;
type RunResult;
@ -31,7 +39,7 @@ impl Benchmark for Conv1d {
}
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
d.0.conv1d(&d.1, 0, 1, 1, 1)
d.0.conv1d(&d.1, 0, 1)
}
const ITERS: usize = 5;
@ -50,7 +58,7 @@ impl Benchmark for Conv2d {
}
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
d.0.conv2d(&d.1, 0, 1, 1, 1)
d.0.conv2d(&d.1, 0, 1)
}
const ITERS: usize = 1;
@ -73,27 +81,6 @@ impl Benchmark for Matmul {
const ITERS: usize = 100;
}
// This benchmark is similar to:
// https://github.com/ggerganov/llama.cpp/blob/master/examples/benchmark/benchmark-matmult.cpp
struct QMatMul;
impl Benchmark for QMatMul {
type PreProcessData = (candle::quantized::QMatMul, Tensor);
type RunResult = Tensor;
fn preprocess() -> Result<Self::PreProcessData> {
let zeros = vec![candle::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32];
let mm = candle::quantized::QTensor::new(zeros, (4096, 11008))?;
let mm = candle::quantized::QMatMul::from_qtensor(mm);
let arg = Tensor::randn(0f32, 1., (128, 11008), &Device::Cpu)?;
Ok((mm, arg))
}
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
d.0.forward(&d.1)
}
const ITERS: usize = 100;
}
struct Softmax;
impl Benchmark for Softmax {
type PreProcessData = Tensor;
@ -105,24 +92,7 @@ impl Benchmark for Softmax {
}
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
candle_nn::ops::softmax(d, D::Minus1)
}
const ITERS: usize = 100;
}
struct SoftmaxLastDim;
impl Benchmark for SoftmaxLastDim {
type PreProcessData = Tensor;
type RunResult = Tensor;
fn preprocess() -> Result<Self::PreProcessData> {
// Typical whisper tiny size.
let x = Tensor::randn(0f32, 1., (1, 6, 200, 1500), &Device::Cpu)?;
Ok(x)
}
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
candle_nn::ops::softmax_last_dim(d)
softmax(d, D::Minus1)
}
const ITERS: usize = 100;
@ -146,9 +116,7 @@ enum Task {
Conv1d,
Conv2d,
Matmul,
Qmatmul,
Softmax,
SoftmaxLastDim,
}
#[derive(Parser, Debug)]
@ -169,8 +137,6 @@ fn main() -> Result<()> {
Task::Conv2d => run::<Conv2d>(args.iters)?,
Task::Matmul => run::<Matmul>(args.iters)?,
Task::Softmax => run::<Softmax>(args.iters)?,
Task::SoftmaxLastDim => run::<SoftmaxLastDim>(args.iters)?,
Task::Qmatmul => run::<QMatMul>(args.iters)?,
}
Ok(())
}

View File

@ -9,21 +9,9 @@ use candle_core::{Device, Tensor};
fn main() -> Result<()> {
let device = Device::new_cuda(0)?;
let in_t = Tensor::rand(-1f32, 1f32, (1, 3, 12, 7), &device)?;
let k_t = Tensor::rand(-1f32, 1f32, (6, 3, 1, 1), &device)?;
let out_t = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
println!("{out_t}");
let in_t = in_t.to_device(&Device::Cpu)?;
let k_t = k_t.to_device(&Device::Cpu)?;
let out_t2 = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
let diff = (out_t.to_device(&Device::Cpu)? - out_t2)?
.sqr()?
.sum_all()?;
println!("{diff}");
let t = Tensor::randn(0f32, 1f32, (2, 4, 96, 96), &device)?;
let w = Tensor::randn(0f32, 1f32, (320, 4, 3, 3), &device)?;
let res = t.conv2d(&w, 1, 1, 1, 1)?;
let res = t.conv2d(&w, 1, 1)?;
println!("{res:?}");
Ok(())
}

View File

@ -1,299 +0,0 @@
use candle_core::quantized::{gguf_file, k_quants, QTensor};
use candle_core::{Device, Result, Tensor};
use clap::{Parser, Subcommand, ValueEnum};
use rayon::prelude::*;
#[derive(ValueEnum, Debug, Clone)]
enum QuantizationMode {
/// The default quantization includes all 2d tensors, except the output tensor which always
/// uses Q6_K.
Llama,
}
impl QuantizationMode {
fn quantize(
&self,
name: &str,
tensor: QTensor,
default: fn(&Tensor) -> Result<QTensor>,
) -> Result<QTensor> {
match self {
Self::Llama => {
// Same behavior as the llama.cpp quantization.
let should_quantize = name.ends_with(".weight") && tensor.rank() == 2;
if should_quantize {
let tensor = tensor.dequantize(&Device::Cpu)?;
if name == "output.weight" {
QTensor::quantize::<k_quants::BlockQ6K>(&tensor)
} else {
default(&tensor)
}
} else {
Ok(tensor)
}
}
}
}
}
#[derive(ValueEnum, Debug, Clone)]
enum Quantization {
#[value(name = "q4_0")]
Q4_0,
#[value(name = "q4_1")]
Q4_1,
#[value(name = "q5_0")]
Q5_0,
#[value(name = "q5_1")]
Q5_1,
#[value(name = "q8_0")]
Q8_0,
#[value(name = "q8_1")]
Q8_1,
Q2k,
Q3k,
Q4k,
Q5k,
Q6k,
Q8k,
F16,
F32,
}
#[derive(ValueEnum, Debug, Clone)]
enum Format {
Safetensors,
Npz,
Ggml,
Gguf,
Pth,
Pickle,
}
impl Format {
fn infer<P: AsRef<std::path::Path>>(p: P) -> Option<Self> {
p.as_ref()
.extension()
.and_then(|e| e.to_str())
.and_then(|e| match e {
// We don't infer any format for .bin as it can be used for ggml/gguf or pytorch.
"safetensors" | "safetensor" => Some(Self::Safetensors),
"npz" => Some(Self::Npz),
"pth" | "pt" => Some(Self::Pth),
"ggml" => Some(Self::Ggml),
"gguf" => Some(Self::Gguf),
_ => None,
})
}
}
#[derive(Subcommand, Debug, Clone)]
enum Command {
Ls {
files: Vec<std::path::PathBuf>,
/// The file format to use, if unspecified infer from the file extension.
#[arg(long, value_enum)]
format: Option<Format>,
/// Enable verbose mode.
#[arg(short, long)]
verbose: bool,
},
Quantize {
/// The input file, in gguf format.
in_file: std::path::PathBuf,
/// The output file, in gguf format.
out_file: std::path::PathBuf,
/// The quantization schema to apply.
#[arg(long, value_enum)]
quantization: Quantization,
/// Which tensor to quantize.
#[arg(long, value_enum, default_value_t = QuantizationMode::Llama)]
mode: QuantizationMode,
},
}
#[derive(Parser, Debug, Clone)]
struct Args {
#[command(subcommand)]
command: Command,
}
fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> Result<()> {
let format = match format {
Some(format) => format,
None => match Format::infer(file) {
Some(format) => format,
None => {
println!(
"{file:?}: cannot infer format from file extension, use the --format flag"
);
return Ok(());
}
},
};
match format {
Format::Npz => {
let tensors = candle_core::npy::NpzTensors::new(file)?;
let mut names = tensors.names();
names.sort();
for name in names {
let shape_dtype = match tensors.get_shape_and_dtype(name) {
Ok((shape, dtype)) => format!("[{shape:?}; {dtype:?}]"),
Err(err) => err.to_string(),
};
println!("{name}: {shape_dtype}")
}
}
Format::Safetensors => {
let tensors = unsafe { candle_core::safetensors::MmapedFile::new(file)? };
let tensors = tensors.deserialize()?;
let mut tensors = tensors.tensors();
tensors.sort_by(|a, b| a.0.cmp(&b.0));
for (name, view) in tensors.iter() {
let dtype = view.dtype();
let dtype = match candle_core::DType::try_from(dtype) {
Ok(dtype) => format!("{dtype:?}"),
Err(_) => format!("{dtype:?}"),
};
let shape = view.shape();
println!("{name}: [{shape:?}; {dtype}]")
}
}
Format::Pth => {
let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose)?;
tensors.sort_by(|a, b| a.name.cmp(&b.name));
for tensor_info in tensors.iter() {
println!(
"{}: [{:?}; {:?}]",
tensor_info.name,
tensor_info.layout.shape(),
tensor_info.dtype,
);
if verbose {
println!(" {:?}", tensor_info);
}
}
}
Format::Pickle => {
let file = std::fs::File::open(file)?;
let mut reader = std::io::BufReader::new(file);
let mut stack = candle_core::pickle::Stack::empty();
stack.read_loop(&mut reader)?;
for (i, obj) in stack.stack().iter().enumerate() {
println!("{i} {obj:?}");
}
}
Format::Ggml => {
let mut file = std::fs::File::open(file)?;
let content = candle_core::quantized::ggml_file::Content::read(&mut file)?;
let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
tensors.sort_by(|a, b| a.0.cmp(&b.0));
for (name, qtensor) in tensors.iter() {
println!("{name}: [{:?}; {:?}]", qtensor.shape(), qtensor.dtype());
}
}
Format::Gguf => {
let mut file = std::fs::File::open(file)?;
let content = gguf_file::Content::read(&mut file)?;
if verbose {
let mut metadata = content.metadata.into_iter().collect::<Vec<_>>();
metadata.sort_by(|a, b| a.0.cmp(&b.0));
println!("metadata entries ({})", metadata.len());
for (key, value) in metadata.iter() {
println!(" {key}: {value:?}");
}
}
let mut tensors = content.tensor_infos.into_iter().collect::<Vec<_>>();
tensors.sort_by(|a, b| a.0.cmp(&b.0));
for (name, info) in tensors.iter() {
println!("{name}: [{:?}; {:?}]", info.shape, info.ggml_dtype);
}
}
}
Ok(())
}
fn run_quantize(
in_file: std::path::PathBuf,
out_file: std::path::PathBuf,
q: Quantization,
qmode: QuantizationMode,
) -> Result<()> {
// Open the out file early so as to fail directly on missing directories etc.
let mut out_file = std::fs::File::create(out_file)?;
let mut in_ = std::fs::File::open(&in_file)?;
let content = gguf_file::Content::read(&mut in_)?;
println!("tensors: {}", content.tensor_infos.len());
let quantize_fn = match q {
Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
Quantization::F16 => QTensor::quantize::<half::f16>,
Quantization::F32 => QTensor::quantize::<f32>,
};
let qtensors = content
.tensor_infos
.par_iter()
.map(|(name, _)| {
println!(" quantizing {name}");
let mut in_file = std::fs::File::open(&in_file)?;
let tensor = content.tensor(&mut in_file, name)?;
let tensor = qmode.quantize(name, tensor, quantize_fn)?;
Ok((name, tensor))
})
.collect::<Result<Vec<_>>>()?;
let qtensors = qtensors
.iter()
.map(|(k, v)| (k.as_str(), v))
.collect::<Vec<_>>();
let metadata = content
.metadata
.iter()
.map(|(k, v)| (k.as_str(), v))
.collect::<Vec<_>>();
gguf_file::write(&mut out_file, metadata.as_slice(), &qtensors)?;
Ok(())
}
fn main() -> anyhow::Result<()> {
let args = Args::parse();
match args.command {
Command::Ls {
files,
format,
verbose,
} => {
let multiple_files = files.len() > 1;
for file in files.iter() {
if multiple_files {
println!("--- {file:?} ---");
}
run_ls(file, format.clone(), verbose)?
}
}
Command::Quantize {
in_file,
out_file,
quantization,
mode,
} => run_quantize(in_file, out_file, quantization, mode)?,
}
Ok(())
}

View File

@ -50,8 +50,6 @@ mod ffi {
pub fn vvcos(dst: *mut c_double, src: *const c_double, len: *const c_int);
pub fn vvlogf(dst: *mut c_float, src: *const c_float, len: *const c_int);
pub fn vvlog(dst: *mut c_double, src: *const c_double, len: *const c_int);
pub fn vvtanhf(dst: *mut c_float, src: *const c_float, len: *const c_int);
pub fn vvtanh(dst: *mut c_double, src: *const c_double, len: *const c_int);
pub fn vDSP_vaddD(
_: *const c_double,
@ -125,42 +123,6 @@ mod ffi {
_: c_long,
_: c_ulong,
);
pub fn vDSP_vminD(
_: *const c_double,
_: c_long,
_: *const c_double,
_: c_long,
_: *mut c_double,
_: c_long,
_: c_ulong,
);
pub fn vDSP_vmin(
_: *const c_float,
_: c_long,
_: *const c_float,
_: c_long,
_: *mut c_float,
_: c_long,
_: c_ulong,
);
pub fn vDSP_vmaxD(
_: *const c_double,
_: c_long,
_: *const c_double,
_: c_long,
_: *mut c_double,
_: c_long,
_: c_ulong,
);
pub fn vDSP_vmax(
_: *const c_float,
_: c_long,
_: *const c_float,
_: c_long,
_: *mut c_float,
_: c_long,
_: c_ulong,
);
}
}
@ -310,26 +272,6 @@ pub fn vd_cos(a: &[f64], y: &mut [f64]) {
}
unsafe { ffi::vvcos(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}
#[inline]
pub fn vs_tanh(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vvtanhf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}
#[inline]
pub fn vd_tanh(a: &[f64], y: &mut [f64]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vvtanh(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}
#[inline]
pub fn vs_ln(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
@ -406,7 +348,3 @@ binary_op!(vs_mul, f32, vDSP_vmul);
binary_op!(vd_mul, f64, vDSP_vmulD);
binary_op!(vs_div, f32, vDSP_vdiv);
binary_op!(vd_div, f64, vDSP_vdivD);
binary_op!(vs_max, f32, vDSP_vmax);
binary_op!(vd_max, f64, vDSP_vmaxD);
binary_op!(vs_min, f32, vDSP_vmin);
binary_op!(vd_min, f64, vDSP_vminD);

View File

@ -15,8 +15,6 @@ pub trait BackendStorage: Sized {
fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self>;
fn powf(&self, _: &Layout, _: f64) -> Result<Self>;
fn elu(&self, _: &Layout, _: f64) -> Result<Self>;
fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self>;
@ -47,14 +45,6 @@ pub trait BackendStorage: Sized {
_params: &crate::conv::ParamsConv2D,
) -> Result<Self>;
fn conv_transpose2d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose2D,
) -> Result<Self>;
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;

View File

@ -60,11 +60,6 @@ impl Tensor {
kernel: rhs,
..
}
| Op::ConvTranspose2D {
arg: lhs,
kernel: rhs,
..
}
| Op::CustomOp2(lhs, rhs, _)
| Op::Binary(lhs, rhs, _)
| Op::Gather(lhs, rhs, _)
@ -101,11 +96,9 @@ impl Tensor {
| Op::ToDType(node)
| Op::ToDevice(node)
| Op::Transpose(node, _, _)
| Op::Permute(node, _)
| Op::Narrow(node, _, _, _)
| Op::Unary(node, _)
| Op::Elu(node, _)
| Op::Powf(node, _)
| Op::CustomOp1(node, _) => {
let (tg, nodes) = walk(node, nodes, already_seen);
track_grad |= tg;
@ -168,21 +161,6 @@ impl Tensor {
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.sub(&rhs_grad)?;
}
Op::Binary(lhs, rhs, BinaryOp::Minimum)
| Op::Binary(lhs, rhs, BinaryOp::Maximum) => {
let mask_lhs = node.eq(lhs)?.to_dtype(grad.dtype())?;
let mask_rhs = node.eq(rhs)?.to_dtype(grad.dtype())?;
// If both masks are 1 one the same point, we want to scale the
// gradient by 0.5 rather than 1.
let lhs_grad = mask_lhs.mul(&grad)?.div(&(&mask_rhs + 1.)?)?;
let lhs_sum_grad = grads.or_insert(lhs)?;
*lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
let rhs_grad = mask_rhs.mul(&grad)?.div(&(&mask_lhs + 1.)?)?;
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
}
Op::WhereCond(pred, t, f) => {
let zeros = grad.zeros_like()?;
let t_sum_grad = grads.or_insert(t)?;
@ -193,75 +171,9 @@ impl Tensor {
*f_sum_grad = f_sum_grad.add(&f_grad)?;
}
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
Op::Conv2D {
arg,
kernel,
padding,
stride,
dilation,
} => {
// The output height for conv_transpose2d is:
// (i_h - 1) * stride - 2 * padding + dilation * (k_h - 1) + out_padding + 1
let grad_h = grad.dim(2)?;
let k_h = kernel.dim(2)?;
let out_size =
(grad_h - 1) * stride + dilation * (k_h - 1) + 1 - 2 * padding;
let out_padding = arg.dim(2)? - out_size;
let grad_arg = grad.conv_transpose2d(
kernel,
*padding,
out_padding,
*stride,
*dilation,
)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
let grad_kernel = arg
.transpose(0, 1)?
.conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
.transpose(0, 1)?;
let sum_grad = grads.or_insert(kernel)?;
*sum_grad = sum_grad.add(&grad_kernel)?;
}
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
op: "conv-transpose2d",
})?,
Op::AvgPool2D {
arg,
kernel_size,
stride,
} => {
if kernel_size != stride {
crate::bail!("backward not supported for avgpool2d if ksize {kernel_size:?} != stride {stride:?}")
}
let (_n, _c, h, w) = arg.dims4()?;
let grad_arg = grad.upsample_nearest2d(h, w)?;
let grad_arg =
(grad_arg * (1f64 / (kernel_size.0 * kernel_size.1) as f64))?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
}
Op::MaxPool2D {
arg,
kernel_size,
stride,
} => {
if kernel_size != stride {
crate::bail!("backward not supported for maxpool2d if ksize {kernel_size:?} != stride {stride:?}")
}
let (_n, _c, h, w) = arg.dims4()?;
// For computing the max-pool gradient, we compute a mask where a 1 means
// that the element is the maximum, then we apply this mask to the
// upsampled gradient (taking into account that multiple max may exist so
// we scale the gradient for this case).
let node_upsampled = node.upsample_nearest2d(h, w)?;
let mask = arg.eq(&node_upsampled)?.to_dtype(arg.dtype())?;
let avg = mask.avg_pool2d_with_stride(*kernel_size, *stride)?;
let grad_arg = ((grad * avg)?.upsample_nearest2d(h, w)? * mask)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
}
Op::Conv2D { .. } => Err(Error::BackwardNotSupported { op: "conv2d" })?,
Op::AvgPool2D { .. } => Err(Error::BackwardNotSupported { op: "avg-pool2d" })?,
Op::MaxPool2D { .. } => Err(Error::BackwardNotSupported { op: "max-pool2d" })?,
Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
op: "upsample-nearest2d",
})?,
@ -379,11 +291,6 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
}
Op::Unary(arg, UnaryOp::Tanh) => {
let sum_grad = grads.or_insert(arg)?;
let minus_dtanh = (node.sqr()? - 1.)?;
*sum_grad = sum_grad.sub(&(&grad * &minus_dtanh)?)?
}
Op::Unary(arg, UnaryOp::Abs) => {
let sum_grad = grads.or_insert(arg)?;
let ones = arg.ones_like()?;
@ -443,11 +350,6 @@ impl Tensor {
*sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
}
Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
Op::Powf(arg, e) => {
let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::CustomOp1(arg, c) => {
if let Some(arg_grad) = c.bwd(arg, node, &grad)? {
let sum_grad = grads.or_insert(arg)?;
@ -501,15 +403,6 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Permute(arg, dims) => {
let mut inv_dims = vec![0; dims.len()];
for (i, &dim_idx) in dims.iter().enumerate() {
inv_dims[dim_idx] = i
}
let arg_grad = grad.permute(inv_dims)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
};
}
}

View File

@ -1,5 +1,3 @@
use crate::{op::BackpropOp, op::Op, Error, Result, Tensor};
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv1D {
pub(crate) b_size: usize,
@ -11,12 +9,12 @@ pub struct ParamsConv1D {
pub(crate) k_size: usize,
pub(crate) padding: usize,
pub(crate) stride: usize,
pub(crate) dilation: usize,
}
impl ParamsConv1D {
pub(crate) fn l_out(&self) -> usize {
(self.l_in + 2 * self.padding - self.dilation * (self.k_size - 1) - 1) / self.stride + 1
let dilation = 1;
(self.l_in + 2 * self.padding - dilation * (self.k_size - 1) - 1) / self.stride + 1
}
pub(crate) fn out_dims(&self) -> Vec<usize> {
@ -36,215 +34,20 @@ pub struct ParamsConv2D {
pub(crate) c_in: usize,
pub(crate) padding: usize,
pub(crate) stride: usize,
pub(crate) dilation: usize,
}
impl ParamsConv2D {
pub(crate) fn out_h(&self) -> usize {
(self.i_h + 2 * self.padding - self.dilation * (self.k_h - 1) - 1) / self.stride + 1
let dilation = 1;
(self.i_h + 2 * self.padding - dilation * (self.k_h - 1) - 1) / self.stride + 1
}
pub(crate) fn out_w(&self) -> usize {
(self.i_w + 2 * self.padding - self.dilation * (self.k_w - 1) - 1) / self.stride + 1
let dilation = 1;
(self.i_w + 2 * self.padding - dilation * (self.k_w - 1) - 1) / self.stride + 1
}
pub(crate) fn out_dims(&self) -> Vec<usize> {
vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConvTranspose2D {
pub(crate) b_size: usize,
pub(crate) i_h: usize,
pub(crate) i_w: usize,
pub(crate) k_h: usize,
pub(crate) k_w: usize,
pub(crate) c_out: usize,
pub(crate) c_in: usize,
pub(crate) padding: usize,
pub(crate) output_padding: usize,
pub(crate) stride: usize,
pub(crate) dilation: usize,
}
impl ParamsConvTranspose2D {
pub(crate) fn out_h(&self) -> usize {
(self.i_h - 1) * self.stride + self.dilation * (self.k_h - 1) + self.output_padding + 1
- 2 * self.padding
}
pub(crate) fn out_w(&self) -> usize {
(self.i_w - 1) * self.stride + self.dilation * (self.k_w - 1) + self.output_padding + 1
- 2 * self.padding
}
pub(crate) fn out_dims(&self) -> Vec<usize> {
vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
}
}
impl Tensor {
fn conv1d_single_group(&self, kernel: &Self, params: &ParamsConv1D) -> Result<Self> {
let storage =
self.storage()
.conv1d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D {
arg,
kernel,
padding: params.padding,
stride: params.stride,
dilation: params.dilation,
});
let out_dims = params.out_dims();
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
}
/// Applies a 1D convolution over the input tensor.
pub fn conv1d(
&self,
kernel: &Self,
padding: usize,
stride: usize,
dilation: usize,
groups: usize,
) -> Result<Self> {
let (c_out, c_in_k, k_size) = kernel.dims3()?;
let (b_size, c_in, l_in) = self.dims3()?;
if c_in != c_in_k * groups {
Err(Error::Conv1dInvalidArgs {
inp_shape: self.shape().clone(),
k_shape: kernel.shape().clone(),
padding,
stride,
msg: "the number of in-channels on the input doesn't match the kernel size",
}
.bt())?
}
let params = ParamsConv1D {
b_size,
l_in,
c_out: c_out / groups,
c_in: c_in / groups,
k_size,
padding,
stride,
dilation,
};
if groups == 1 {
self.conv1d_single_group(kernel, &params)
} else {
let blocks = self.chunk(groups, 1)?;
let kernel = kernel.chunk(groups, 0)?;
let blocks = blocks
.iter()
.zip(&kernel)
.map(|(block, kernel)| block.conv1d_single_group(kernel, &params))
.collect::<Result<Vec<_>>>()?;
Tensor::cat(&blocks, 1)
}
}
fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
let storage =
self.storage()
.conv2d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D {
arg,
kernel,
padding: params.padding,
stride: params.stride,
dilation: params.dilation,
});
let out_dims = params.out_dims();
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
}
/// Applies a 2D convolution over the input tensor.
pub fn conv2d(
&self,
kernel: &Self,
padding: usize,
stride: usize,
dilation: usize,
groups: usize,
) -> Result<Self> {
let (b_size, c_in, i_h, i_w) = self.dims4()?;
let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
if c_in != c_in_k * groups {
crate::bail!(
"in_channel mismatch between input ({c_in}, groups {groups}) and kernel ({c_in_k})"
)
}
let params = ParamsConv2D {
b_size,
i_h,
i_w,
k_h,
k_w,
c_out: c_out / groups,
c_in: c_in / groups,
padding,
stride,
dilation,
};
if groups == 1 {
self.conv2d_single_group(kernel, &params)
} else {
let blocks = self.chunk(groups, 1)?;
let kernel = kernel.chunk(groups, 0)?;
let blocks = blocks
.iter()
.zip(&kernel)
.map(|(block, kernel)| block.conv2d_single_group(kernel, &params))
.collect::<Result<Vec<_>>>()?;
Tensor::cat(&blocks, 1)
}
}
/// Applies a 2D transposed convolution over the input tensor.
pub fn conv_transpose2d(
&self,
kernel: &Self,
padding: usize,
output_padding: usize,
stride: usize,
dilation: usize,
) -> Result<Self> {
let (b_size, c_in, i_h, i_w) = self.dims4()?;
let (c_in_k, c_out, k_h, k_w) = kernel.dims4()?;
if c_in != c_in_k {
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
}
let params = ParamsConvTranspose2D {
b_size,
i_h,
i_w,
k_h,
k_w,
c_out,
c_in,
padding,
output_padding,
stride,
dilation,
};
let storage = self.storage().conv_transpose2d(
self.layout(),
&kernel.storage(),
kernel.layout(),
&params,
)?;
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose2D {
arg,
kernel,
padding: params.padding,
output_padding: params.output_padding,
stride: params.stride,
dilation: params.dilation,
});
let out_dims = params.out_dims();
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
}
}

View File

@ -92,7 +92,6 @@ from_tensor!(f64);
from_tensor!(f32);
from_tensor!(f16);
from_tensor!(bf16);
from_tensor!(i64);
from_tensor!(u32);
from_tensor!(u8);
@ -130,11 +129,6 @@ impl Tensor {
f.write_u32::<LittleEndian>(v)?
}
}
DType::I64 => {
for v in vs.to_vec1::<i64>()? {
f.write_i64::<LittleEndian>(v)?
}
}
DType::U8 => {
let vs = vs.to_vec1::<u8>()?;
f.write_all(&vs)?;

View File

@ -103,7 +103,7 @@ impl CpuF16<ARR> for CurrentCpuF16 {
for i in 0..8 {
tmp[i] = (*mem_addr.add(i)).to_f32();
}
_mm256_loadu_ps(tmp.as_ptr())
_mm_loadu_ps(tmp.as_ptr())
}
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {

View File

@ -1,7 +1,4 @@
pub trait VecOps: num_traits::NumAssign + Copy {
fn min(self, rhs: Self) -> Self;
fn max(self, rhs: Self) -> Self;
/// Dot-product of two vectors.
///
/// # Safety
@ -29,47 +26,9 @@ pub trait VecOps: num_traits::NumAssign + Copy {
*res += *xs.add(i)
}
}
/// Maximum element in a non-empty vector.
///
/// # Safety
///
/// The length of `xs` must be at least `len` and positive. `res` has to point to a valid
/// element.
#[inline(always)]
unsafe fn vec_reduce_max(xs: *const Self, res: *mut Self, len: usize) {
*res = *xs;
for i in 1..len {
*res = (*res).max(*xs.add(i))
}
}
/// Minimum element in a non-empty vector.
///
/// # Safety
///
/// The length of `xs` must be at least `len` and positive. `res` has to point to a valid
/// element.
#[inline(always)]
unsafe fn vec_reduce_min(xs: *const Self, res: *mut Self, len: usize) {
*res = *xs;
for i in 1..len {
*res = (*res).min(*xs.add(i))
}
}
}
impl VecOps for f32 {
#[inline(always)]
fn min(self, other: Self) -> Self {
Self::min(self, other)
}
#[inline(always)]
fn max(self, other: Self) -> Self {
Self::max(self, other)
}
#[inline(always)]
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
super::vec_dot_f32(lhs, rhs, res, len)
@ -82,16 +41,6 @@ impl VecOps for f32 {
}
impl VecOps for half::f16 {
#[inline(always)]
fn min(self, other: Self) -> Self {
Self::min(self, other)
}
#[inline(always)]
fn max(self, other: Self) -> Self {
Self::max(self, other)
}
#[inline(always)]
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
let mut res_f32 = 0f32;
@ -100,61 +49,10 @@ impl VecOps for half::f16 {
}
}
impl VecOps for f64 {
#[inline(always)]
fn min(self, other: Self) -> Self {
Self::min(self, other)
}
#[inline(always)]
fn max(self, other: Self) -> Self {
Self::max(self, other)
}
}
impl VecOps for half::bf16 {
#[inline(always)]
fn min(self, other: Self) -> Self {
Self::min(self, other)
}
#[inline(always)]
fn max(self, other: Self) -> Self {
Self::max(self, other)
}
}
impl VecOps for u8 {
#[inline(always)]
fn min(self, other: Self) -> Self {
<Self as Ord>::min(self, other)
}
#[inline(always)]
fn max(self, other: Self) -> Self {
<Self as Ord>::max(self, other)
}
}
impl VecOps for u32 {
#[inline(always)]
fn min(self, other: Self) -> Self {
<Self as Ord>::min(self, other)
}
#[inline(always)]
fn max(self, other: Self) -> Self {
<Self as Ord>::max(self, other)
}
}
impl VecOps for i64 {
#[inline(always)]
fn min(self, other: Self) -> Self {
<Self as Ord>::min(self, other)
}
#[inline(always)]
fn max(self, other: Self) -> Self {
<Self as Ord>::max(self, other)
}
}
impl VecOps for f64 {}
impl VecOps for half::bf16 {}
impl VecOps for u8 {}
impl VecOps for u32 {}
#[inline(always)]
pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {

View File

@ -2,7 +2,6 @@ use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
use half::{bf16, f16};
use rayon::prelude::*;
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
// intercept the oom errors to avoid panicking and provide a proper error.
@ -10,7 +9,6 @@ use rayon::prelude::*;
pub enum CpuStorage {
U8(Vec<u8>),
U32(Vec<u32>),
I64(Vec<i64>),
BF16(Vec<bf16>),
F16(Vec<f16>),
F32(Vec<f32>),
@ -27,7 +25,6 @@ pub trait Map1 {
match vs {
CpuStorage::U8(vs) => Ok(CpuStorage::U8(self.f(vs, layout)?)),
CpuStorage::U32(vs) => Ok(CpuStorage::U32(self.f(vs, layout)?)),
CpuStorage::I64(vs) => Ok(CpuStorage::I64(self.f(vs, layout)?)),
CpuStorage::BF16(vs) => Ok(CpuStorage::BF16(self.f(vs, layout)?)),
CpuStorage::F16(vs) => Ok(CpuStorage::F16(self.f(vs, layout)?)),
CpuStorage::F32(vs) => Ok(CpuStorage::F32(self.f(vs, layout)?)),
@ -48,7 +45,6 @@ pub trait Map1Any {
match vs {
CpuStorage::U8(vs) => Ok(self.f(vs, layout, CpuStorage::U8)?),
CpuStorage::U32(vs) => Ok(self.f(vs, layout, CpuStorage::U32)?),
CpuStorage::I64(vs) => Ok(self.f(vs, layout, CpuStorage::I64)?),
CpuStorage::BF16(vs) => Ok(self.f(vs, layout, CpuStorage::BF16)?),
CpuStorage::F16(vs) => Ok(self.f(vs, layout, CpuStorage::F16)?),
CpuStorage::F32(vs) => Ok(self.f(vs, layout, CpuStorage::F32)?),
@ -72,7 +68,6 @@ pub trait Map2 {
match (v1, v2) {
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
(C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)),
(C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
(C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
(C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
@ -101,7 +96,6 @@ pub trait Map2U8 {
match (v1, v2) {
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
@ -292,9 +286,10 @@ struct ReduceSum<'a> {
impl<'a> ReduceSum<'a> {
#[inline(always)]
fn fold_impl<T>(&self, src: &[T], src_l: &Layout, start_elt: T) -> Result<Vec<T>>
fn fold_impl<T, F>(&self, src: &[T], src_l: &Layout, start_elt: T, f: F) -> Result<Vec<T>>
where
T: WithDType,
F: Fn(T, T) -> T,
{
let mut dst = vec![start_elt; self.dst_shape.elem_count()];
match src_l.contiguous_offsets() {
@ -335,7 +330,7 @@ impl<'a> ReduceSum<'a> {
let (pre, post) = (dst_index / stride, dst_index % stride);
dst_index = (pre / dim) * stride + post;
}
dst[dst_index] += src;
dst[dst_index] = f(dst[dst_index], src);
}
}
None => {
@ -347,7 +342,7 @@ impl<'a> ReduceSum<'a> {
let (pre, post) = (dst_index / stride, dst_index % stride);
dst_index = (pre / dim) * stride + post;
}
dst[dst_index] += src[src_index];
dst[dst_index] = f(dst[dst_index], src[src_index]);
}
}
}
@ -358,7 +353,7 @@ impl<'a> ReduceSum<'a> {
impl<'a> Map1 for ReduceSum<'a> {
#[inline(always)]
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
self.fold_impl(src, src_l, T::zero())
self.fold_impl(src, src_l, T::zero(), |x, y| x + y)
}
}
@ -446,7 +441,7 @@ pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U
}
// This function maps over two strided index sequences.
pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
lhs_l: &Layout,
rhs_l: &Layout,
lhs: &[T],
@ -526,7 +521,7 @@ pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
}
// Similar to binary_map but with vectorized variants.
pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
lhs_l: &Layout,
rhs_l: &Layout,
lhs: &[T],
@ -1053,8 +1048,10 @@ impl<'a> Map2 for Conv1D<'a> {
}
}
let num_threads = crate::utils::get_num_threads();
for offset in 0..p.k_size {
(0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
let dst_idx = dst_c_idx * l_out;
let k_cont = (0..p.c_in)
.map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2])
@ -1063,7 +1060,7 @@ impl<'a> Map2 for Conv1D<'a> {
let dst_idx = dst_idx + b_idx * p.c_out * l_out;
for dst_l in 0..l_out {
let dst_idx = dst_idx + dst_l;
let src_l = p.stride * dst_l + offset * p.dilation;
let src_l = p.stride * dst_l + offset;
if src_l < p.padding || src_l >= p.padding + p.l_in {
continue;
}
@ -1122,9 +1119,11 @@ impl<'a> Map2 for Conv2D<'a> {
}
}
let num_threads = crate::utils::get_num_threads();
for offset_h in 0..p.k_h {
for offset_w in 0..p.k_w {
(0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
let dst_idx = dst_c_idx * out_w * out_h;
let k_cont = (0..p.c_in)
.map(|c_in_idx| {
@ -1138,14 +1137,14 @@ impl<'a> Map2 for Conv2D<'a> {
let dst_idx = dst_idx + b_idx * p.c_out * out_h * out_w;
for dst_h in 0..out_h {
let dst_idx = dst_idx + dst_h * out_w;
let src_h = p.stride * dst_h + offset_h * p.dilation;
let src_h = p.stride * dst_h + offset_h;
if src_h < p.padding || src_h >= p.i_h + p.padding {
continue;
}
let src_h = src_h - p.padding;
for dst_w in 0..out_w {
let dst_idx = dst_idx + dst_w;
let src_w = p.stride * dst_w + offset_w * p.dilation;
let src_w = p.stride * dst_w + offset_w;
if src_w < p.padding || src_w >= p.i_w + p.padding {
continue;
}
@ -1177,96 +1176,6 @@ impl<'a> Map2 for Conv2D<'a> {
}
}
struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
impl<'a> Map2 for ConvTranspose2D<'a> {
const OP: &'static str = "conv_transpose2d";
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0;
let inp = &inp[inp_l.start_offset()..];
let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
let k = &k[k_l.start_offset()..];
let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?;
let (out_h, out_w) = (p.out_h(), p.out_w());
// Output shape: [b_size, c_out, out_h, out_w].
let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
let dst_s0 = p.c_out * out_h * out_w;
let dst_s1 = out_h * out_w;
let dst_s2 = out_w;
let dst_s3 = 1;
// TODO: Avoid making this copy if `inp` already has the appropriate layout.
let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w];
let cont_s0 = p.i_h * p.i_w * p.c_in;
let cont_s1 = p.i_w * p.c_in;
let cont_s2 = p.c_in;
for b_idx in 0..p.b_size {
for h_idx in 0..p.i_h {
for w_idx in 0..p.i_w {
for c_idx in 0..p.c_in {
let src_idx =
b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3;
let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx;
inp_cont[dst_idx] = inp[src_idx]
}
}
}
}
for k_y in 0..p.k_h {
for k_x in 0..p.k_w {
(0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
let k_cont = (0..p.c_in)
.map(|c_in_idx| {
k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_y * k_s2 + k_x * k_s3]
})
.collect::<Vec<_>>();
for b_idx in 0..p.b_size {
for inp_y in 0..p.i_h {
for inp_x in 0..p.i_w {
let out_x = inp_x * p.stride + k_x * p.dilation;
let out_y = inp_y * p.stride + k_y * p.dilation;
if out_x < p.padding || out_y < p.padding {
continue;
}
let out_x = out_x - p.padding;
let out_y = out_y - p.padding;
if out_x < out_w && out_y < out_h {
let inp_cont = &inp_cont
[b_idx * cont_s0 + inp_y * cont_s1 + inp_x * cont_s2..];
let dst_idx = b_idx * dst_s0
+ out_y * dst_s2
+ out_x * dst_s3
+ dst_c_idx * dst_s1;
let mut d = T::zero();
unsafe {
T::vec_dot(
inp_cont.as_ptr(),
k_cont.as_ptr(),
&mut d,
p.c_in,
)
}
let dst_p = dst.as_ptr();
// Safety: dst_idx are uniques per dst_c_idx which is used to
// parallelise the different tasks so no two threads can try to
// write at the same location.
unsafe {
let ptr = dst_p.add(dst_idx) as *mut T;
*ptr += d
}
}
}
}
}
})
}
}
Ok(dst)
}
}
struct MatMul((usize, usize, usize, usize));
impl MatMul {
@ -1293,11 +1202,6 @@ impl Map2 for MatMul {
rhs_l: &Layout,
) -> Result<Vec<T>> {
use gemm::{gemm, Parallelism};
if T::DTYPE == DType::BF16 {
return Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?;
}
let (b, m, n, k) = self.0;
let lhs = &lhs[lhs_l.start_offset()..];
let rhs = &rhs[rhs_l.start_offset()..];
@ -1614,90 +1518,6 @@ impl CpuStorage {
pub fn as_slice<D: WithDType>(&self) -> Result<&[D]> {
D::cpu_storage_as_slice(self)
}
pub fn concat(storages: &[CpuStorage]) -> Result<CpuStorage> {
let storage0 = &storages[0];
let s = match storage0 {
Self::U8(_) => {
let storages = storages
.iter()
.map(|s| match s {
Self::U8(s) => Ok(s.as_slice()),
_ => crate::bail!("dtype mismatch"),
})
.collect::<Result<Vec<_>>>()?
.concat();
Self::U8(storages)
}
Self::U32(_) => {
let storages = storages
.iter()
.map(|s| match s {
Self::U32(s) => Ok(s.as_slice()),
_ => crate::bail!("dtype mismatch"),
})
.collect::<Result<Vec<_>>>()?
.concat();
Self::U32(storages)
}
Self::I64(_) => {
let storages = storages
.iter()
.map(|s| match s {
Self::I64(s) => Ok(s.as_slice()),
_ => crate::bail!("dtype mismatch"),
})
.collect::<Result<Vec<_>>>()?
.concat();
Self::I64(storages)
}
Self::BF16(_) => {
let storages = storages
.iter()
.map(|s| match s {
Self::BF16(s) => Ok(s.as_slice()),
_ => crate::bail!("dtype mismatch"),
})
.collect::<Result<Vec<_>>>()?
.concat();
Self::BF16(storages)
}
Self::F16(_) => {
let storages = storages
.iter()
.map(|s| match s {
Self::F16(s) => Ok(s.as_slice()),
_ => crate::bail!("dtype mismatch"),
})
.collect::<Result<Vec<_>>>()?
.concat();
Self::F16(storages)
}
Self::F32(_) => {
let storages = storages
.iter()
.map(|s| match s {
Self::F32(s) => Ok(s.as_slice()),
_ => crate::bail!("dtype mismatch"),
})
.collect::<Result<Vec<_>>>()?
.concat();
Self::F32(storages)
}
Self::F64(_) => {
let storages = storages
.iter()
.map(|s| match s {
Self::F64(s) => Ok(s.as_slice()),
_ => crate::bail!("dtype mismatch"),
})
.collect::<Result<Vec<_>>>()?
.concat();
Self::F64(storages)
}
};
Ok(s)
}
}
impl BackendStorage for CpuStorage {
@ -1707,7 +1527,6 @@ impl BackendStorage for CpuStorage {
match self {
Self::U8(_) => DType::U8,
Self::U32(_) => DType::U32,
Self::I64(_) => DType::I64,
Self::BF16(_) => DType::BF16,
Self::F16(_) => DType::F16,
Self::F32(_) => DType::F32,
@ -1726,10 +1545,6 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
Ok(Self::BF16(data))
}
(Self::I64(storage), DType::BF16) => {
let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
Ok(Self::BF16(data))
}
(Self::BF16(storage), DType::BF16) => {
let data = unary_map(storage, layout, |v| v);
Ok(Self::BF16(data))
@ -1754,10 +1569,6 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
Ok(Self::F16(data))
}
(Self::I64(storage), DType::F16) => {
let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
Ok(Self::F16(data))
}
(Self::BF16(storage), DType::F16) => {
let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32()));
Ok(Self::F16(data))
@ -1782,10 +1593,6 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as f32);
Ok(Self::F32(data))
}
(Self::I64(storage), DType::F32) => {
let data = unary_map(storage, layout, |v| v as f32);
Ok(Self::F32(data))
}
(Self::BF16(storage), DType::F32) => {
let data = unary_map(storage, layout, |v| v.to_f32());
Ok(Self::F32(data))
@ -1822,26 +1629,18 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as u8);
Ok(Self::U8(data))
}
(Self::U32(storage), DType::U8) => {
let data = unary_map(storage, layout, |v| v as u8);
Ok(Self::U8(data))
}
(Self::I64(storage), DType::U8) => {
let data = unary_map(storage, layout, |v| v as u8);
Ok(Self::U8(data))
}
(Self::U8(storage), DType::U32) => {
let data = unary_map(storage, layout, |v| v as u32);
Ok(Self::U32(data))
}
(Self::U32(storage), DType::U8) => {
let data = unary_map(storage, layout, |v| v as u8);
Ok(Self::U8(data))
}
(Self::U32(storage), DType::U32) => {
let data = unary_map(storage, layout, |v| v);
Ok(Self::U32(data))
}
(Self::I64(storage), DType::U32) => {
let data = unary_map(storage, layout, |v| v as u32);
Ok(Self::U32(data))
}
(Self::BF16(storage), DType::U32) => {
let data = unary_map(storage, layout, |v| v.to_f32() as u32);
Ok(Self::U32(data))
@ -1858,34 +1657,6 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as u32);
Ok(Self::U32(data))
}
(Self::U8(storage), DType::I64) => {
let data = unary_map(storage, layout, |v| v as i64);
Ok(Self::I64(data))
}
(Self::U32(storage), DType::I64) => {
let data = unary_map(storage, layout, |v| v as i64);
Ok(Self::I64(data))
}
(Self::I64(storage), DType::I64) => {
let data = unary_map(storage, layout, |v| v);
Ok(Self::I64(data))
}
(Self::BF16(storage), DType::I64) => {
let data = unary_map(storage, layout, |v| v.to_f32() as i64);
Ok(Self::I64(data))
}
(Self::F16(storage), DType::I64) => {
let data = unary_map(storage, layout, |v| v.to_f32() as i64);
Ok(Self::I64(data))
}
(Self::F32(storage), DType::I64) => {
let data = unary_map(storage, layout, |v| v as i64);
Ok(Self::I64(data))
}
(Self::F64(storage), DType::I64) => {
let data = unary_map(storage, layout, |v| v as i64);
Ok(Self::I64(data))
}
(Self::U8(storage), DType::F64) => {
let data = unary_map(storage, layout, |v| v as f64);
Ok(Self::F64(data))
@ -1894,10 +1665,6 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as f64);
Ok(Self::F64(data))
}
(Self::I64(storage), DType::F64) => {
let data = unary_map(storage, layout, |v| v as f64);
Ok(Self::F64(data))
}
(Self::BF16(storage), DType::F64) => {
let data = unary_map(storage, layout, |v| v.to_f64());
Ok(Self::F64(data))
@ -2003,32 +1770,6 @@ impl BackendStorage for CpuStorage {
UpsampleNearest2D(h, w).map(self, layout)
}
fn powf(&self, layout: &Layout, e: f64) -> Result<Self> {
use num_traits::Float;
// TODO: Have some generic map for functions that apply on num_traits::Float elements.
match self {
Self::BF16(storage) => {
let data = unary_map(storage, layout, |v| v.powf(bf16::from_f64(e)));
Ok(Self::BF16(data))
}
Self::F16(storage) => {
let data = unary_map(storage, layout, |v| v.powf(f16::from_f64(e)));
Ok(Self::F16(data))
}
Self::F32(storage) => {
let data = unary_map(storage, layout, |v| v.powf(e as f32));
Ok(Self::F32(data))
}
Self::F64(storage) => {
let data = unary_map(storage, layout, |v| v.powf(e));
Ok(Self::F64(data))
}
Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
}
}
fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
// TODO: Have some generic map for functions that apply on num_traits::Float elements.
match self {
@ -2050,7 +1791,6 @@ impl BackendStorage for CpuStorage {
}
Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
}
}
@ -2100,10 +1840,6 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, B::u32);
Ok(Self::U32(data))
}
Self::I64(storage) => {
let data = unary_map(storage, layout, B::i64);
Ok(Self::I64(data))
}
}
}
@ -2154,14 +1890,6 @@ impl BackendStorage for CpuStorage {
};
Ok(Self::U32(data))
}
(Self::I64(lhs), Self::I64(rhs)) => {
let data = if B::I64_VEC {
binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i64, B::i64_vec)
} else {
binary_map(lhs_l, rhs_l, lhs, rhs, B::i64)
};
Ok(Self::I64(data))
}
(Self::U8(lhs), Self::U8(rhs)) => {
let data = if B::U8_VEC {
binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::u8, B::u8_vec)
@ -2186,7 +1914,6 @@ impl BackendStorage for CpuStorage {
match (self, dst) {
(Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::I64(src), Self::I64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::F32(src), Self::F32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
@ -2215,7 +1942,6 @@ impl BackendStorage for CpuStorage {
match self {
Self::U8(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
Self::U32(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
Self::I64(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "where-cond")),
}
}
@ -2240,21 +1966,10 @@ impl BackendStorage for CpuStorage {
Conv2D(params).map(self, l, kernel, kernel_l)
}
fn conv_transpose2d(
&self,
l: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &crate::conv::ParamsConvTranspose2D,
) -> Result<Self> {
ConvTranspose2D(params).map(self, l, kernel, kernel_l)
}
fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
match ids {
Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select")),
}
}
@ -2263,7 +1978,6 @@ impl BackendStorage for CpuStorage {
match ids {
Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l),
Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l),
Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather")),
}
}
@ -2280,7 +1994,6 @@ impl BackendStorage for CpuStorage {
match ids {
Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add")),
}
}
@ -2309,13 +2022,6 @@ impl BackendStorage for CpuStorage {
};
IndexAdd { ids, dim }.map(self, l, src, src_l)
}
Self::I64(ids) => {
let ids = match ids_l.contiguous_offsets() {
Some((a, b)) => &ids[a..b],
None => Err(Error::RequiresContiguous { op: "index-add" })?,
};
IndexAdd { ids, dim }.map(self, l, src, src_l)
}
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add")),
}
}
@ -2368,9 +2074,7 @@ impl BackendDevice for CpuDevice {
let elem_count = shape.elem_count();
let mut rng = rand::thread_rng();
match dtype {
DType::U8 | DType::U32 | DType::I64 => {
Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt())
}
DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt()),
DType::BF16 => {
let mut data = Vec::with_capacity(elem_count);
let uniform =
@ -2414,9 +2118,7 @@ impl BackendDevice for CpuDevice {
let elem_count = shape.elem_count();
let mut rng = rand::thread_rng();
match dtype {
DType::U8 | DType::U32 | DType::I64 => {
Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt())
}
DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()),
DType::BF16 => {
let mut data = Vec::with_capacity(elem_count);
let normal = rand_distr::Normal::new(bf16::from_f64(mean), bf16::from_f64(std))
@ -2460,7 +2162,6 @@ impl BackendDevice for CpuDevice {
let storage = match dtype {
DType::U8 => CpuStorage::U8(vec![1u8; elem_count]),
DType::U32 => CpuStorage::U32(vec![1u32; elem_count]),
DType::I64 => CpuStorage::I64(vec![1i64; elem_count]),
DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]),
DType::F16 => CpuStorage::F16(vec![f16::ONE; elem_count]),
DType::F32 => CpuStorage::F32(vec![1f32; elem_count]),
@ -2474,7 +2175,6 @@ impl BackendDevice for CpuDevice {
let storage = match dtype {
DType::U8 => CpuStorage::U8(vec![0u8; elem_count]),
DType::U32 => CpuStorage::U32(vec![0u32; elem_count]),
DType::I64 => CpuStorage::I64(vec![0i64; elem_count]),
DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]),
DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]),
DType::F32 => CpuStorage::F32(vec![0f32; elem_count]),

View File

@ -1,7 +1,7 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType};
pub use candle_kernels as kernels;
use candle_kernels as kernels;
pub use cudarc;
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
use cudarc::driver::{
@ -139,14 +139,6 @@ impl CudaDevice {
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::U32(data)
}
DType::I64 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
let params = (&data, v as i64, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
@ -244,10 +236,6 @@ impl BackendDevice for CudaDevice {
let data = self.alloc_zeros::<u32>(elem_count).w()?;
CudaStorageSlice::U32(data)
}
DType::I64 => {
let data = self.alloc_zeros::<i64>(elem_count).w()?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
let data = self.alloc_zeros::<bf16>(elem_count).w()?;
CudaStorageSlice::BF16(data)
@ -277,13 +265,11 @@ impl BackendDevice for CudaDevice {
let slice = match dtype {
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
// cudarc changes.
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_uniform",
})
.w()?
}
DType::U8 | DType::U32 | DType::F16 | DType::BF16 => Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_uniform",
})
.w()?,
DType::F32 => {
let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
curand.0.fill_with_uniform(&mut data).w()?;
@ -295,12 +281,10 @@ impl BackendDevice for CudaDevice {
CudaStorageSlice::F64(data)
}
};
let slice = if lo == 0. && up == 1.0 {
slice
} else {
if lo != 0.0 || up != 1.0 {
let layout = Layout::contiguous(shape);
Affine(up - lo, lo).map(&slice, self, &layout)?
};
Affine(up - lo, lo).map(&slice, self, &layout)?;
}
Ok(CudaStorage {
slice,
device: self.clone(),
@ -313,13 +297,11 @@ impl BackendDevice for CudaDevice {
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
let slice = match dtype {
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_normal",
})
.w()?
}
DType::U8 | DType::U32 | DType::F16 | DType::BF16 => Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_normal",
})
.w()?,
DType::F32 => {
let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
curand
@ -354,10 +336,6 @@ impl BackendDevice for CudaDevice {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::U32(data)
}
CpuStorage::I64(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::I64(data)
}
CpuStorage::BF16(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::BF16(data)
@ -383,10 +361,9 @@ impl BackendDevice for CudaDevice {
}
#[derive(Debug)]
pub enum CudaStorageSlice {
enum CudaStorageSlice {
U8(CudaSlice<u8>),
U32(CudaSlice<u32>),
I64(CudaSlice<i64>),
BF16(CudaSlice<bf16>),
F16(CudaSlice<f16>),
F32(CudaSlice<f32>),
@ -394,7 +371,7 @@ pub enum CudaStorageSlice {
}
type S = CudaStorageSlice;
pub trait Map1 {
trait Map1 {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src: &CudaSlice<T>,
@ -406,7 +383,6 @@ pub trait Map1 {
let out = match s {
S::U8(s) => S::U8(self.f(s, d, l)?),
S::U32(s) => S::U32(self.f(s, d, l)?),
S::I64(s) => S::I64(self.f(s, d, l)?),
S::BF16(s) => S::BF16(self.f(s, d, l)?),
S::F16(s) => S::F16(self.f(s, d, l)?),
S::F32(s) => S::F32(self.f(s, d, l)?),
@ -416,7 +392,7 @@ pub trait Map1 {
}
}
pub trait Map2 {
trait Map2 {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src1: &CudaSlice<T>,
@ -430,7 +406,6 @@ pub trait Map2 {
let out = match (s1, s2) {
(S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?),
(S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
(S::I64(s1), S::I64(s2)) => S::I64(self.f(s1, l1, s2, l2, d)?),
(S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
(S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
(S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
@ -441,7 +416,7 @@ pub trait Map2 {
}
}
pub trait Map2InPlace {
trait Map2InPlace {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
dst: &mut CudaSlice<T>,
@ -462,7 +437,6 @@ pub trait Map2InPlace {
match (dst, src) {
(S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
(S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
(S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d),
(S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
@ -472,7 +446,7 @@ pub trait Map2InPlace {
}
}
pub trait Map1Any {
trait Map1Any {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
&self,
src: &CudaSlice<T>,
@ -485,7 +459,6 @@ pub trait Map1Any {
let out = match s {
S::U8(s) => self.f(s, d, l, S::U8)?,
S::U32(s) => self.f(s, d, l, S::U32)?,
S::I64(s) => self.f(s, d, l, S::I64)?,
S::BF16(s) => self.f(s, d, l, S::BF16)?,
S::F16(s) => self.f(s, d, l, S::F16)?,
S::F32(s) => self.f(s, d, l, S::F32)?,
@ -495,7 +468,7 @@ pub trait Map1Any {
}
}
pub trait Map2Any {
trait Map2Any {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src1: &CudaSlice<T>,
@ -509,7 +482,6 @@ pub trait Map2Any {
let out = match (s1, s2) {
(S::U8(s1), S::U8(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::U32(s1), S::U32(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::I64(s1), S::I64(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::BF16(s1), S::BF16(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?,
@ -532,7 +504,7 @@ impl Map1 for Clone {
}
}
pub fn kernel_name<T: WithDType>(root: &str) -> String {
fn kernel_name<T: WithDType>(root: &str) -> String {
let dtype = T::DTYPE.as_str();
format!("{root}_{dtype}")
}
@ -593,30 +565,6 @@ impl Map1 for Elu {
}
}
struct Powf(f64);
impl Map1 for Powf {
fn f<T: DeviceRepr + WithDType>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>> {
let shape = layout.shape();
let dims = shape.dims();
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("upowf"), kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (el, dims.len(), &ds, T::from_f64(self.0), src, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
}
}
struct Sum<'a>(&'a [usize]);
impl<'a> Map1 for Sum<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@ -766,9 +714,6 @@ impl<'a> Map1 for IndexSelect<'a> {
CudaStorageSlice::U8(slice) => {
("is_u8", *slice.slice(ids_l.start_offset()..).device_ptr())
}
CudaStorageSlice::I64(slice) => {
("is_i64", *slice.slice(ids_l.start_offset()..).device_ptr())
}
_ => Err(CudaError::UnexpectedDType {
msg: "index_select ids should be u8 or u32",
expected: DType::U32,
@ -828,11 +773,8 @@ impl<'a> Map1 for Gather<'a> {
("gather_u32", *slice.slice(ids_o1..ids_o2).device_ptr())
}
CudaStorageSlice::U8(slice) => ("gather_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::I64(slice) => {
("gather_i64", *slice.slice(ids_o1..ids_o2).device_ptr())
}
_ => Err(CudaError::UnexpectedDType {
msg: "gather ids should be u8/u32/i64",
msg: "gather ids should be u8 or u32",
expected: DType::U32,
got: ids.dtype(),
})?,
@ -878,10 +820,9 @@ impl<'a> Map2InPlace for IndexAdd<'a> {
};
let (name, ids) = match &ids.slice {
CudaStorageSlice::U32(slice) => ("ia_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::I64(slice) => ("ia_i64", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::U8(slice) => ("ia_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
_ => Err(CudaError::UnexpectedDType {
msg: "index-add ids should be u8/u32/i64",
msg: "index-add ids should be u8 or u32",
expected: DType::U32,
got: ids.dtype(),
})?,
@ -926,10 +867,9 @@ impl<'a> Map2InPlace for ScatterAdd<'a> {
};
let (name, ids) = match &ids.slice {
CudaStorageSlice::U32(slice) => ("sa_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::I64(slice) => ("sa_i64", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::U8(slice) => ("sa_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
_ => Err(CudaError::UnexpectedDType {
msg: "scatter-add ids should be u8/u32/i64",
msg: "scatter-add ids should be u8 or u32",
expected: DType::U32,
got: ids.dtype(),
})?,
@ -981,12 +921,10 @@ impl<'a> Map2 for Conv1D<'a> {
} else if dims.len() == 2 {
[&[1], dims, &[1], inp_l.stride(), k_l.dims(), k_l.stride()].concat()
} else {
crate::bail!("unexpected input shape for conv1d {dims:?}")
panic!("unexpected input shape for conv1d {dims:?}")
};
let ds = dev.htod_copy(ds).w()?;
let params = (
el, l_out, p.stride, p.padding, p.dilation, &ds, inp, k, &out,
);
let params = (el, l_out, p.stride, p.padding, &ds, inp, k, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
@ -1003,8 +941,8 @@ impl<'a> Map2 for Conv2D<'a> {
k_l: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>> {
// Kernel shape: (c_out, c_in_k, h_k, w_k)
// Input shape: (b_size, c_in, h_in, w_in)
// Kernel shape: (c_out, c_in_k, w_k, h_k)
// Input shape: (b_size, c_in, w_in, c_in)
let p = &self.0;
let (out_w, out_h) = (p.out_w(), p.out_h());
let dst_el = p.c_out * out_w * out_h * p.b_size;
@ -1021,62 +959,10 @@ impl<'a> Map2 for Conv2D<'a> {
let ds = if dims.len() == 4 {
[dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
} else {
crate::bail!("unexpected input shape for conv2d {dims:?}")
panic!("unexpected input shape for conv1d {dims:?}")
};
let ds = dev.htod_copy(ds).w()?;
let params = (
el, out_w, out_h, p.stride, p.padding, p.dilation, &ds, inp, k, &out,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
}
}
struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
impl<'a> Map2 for ConvTranspose2D<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
inp: &CudaSlice<T>,
inp_l: &Layout,
k: &CudaSlice<T>,
k_l: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>> {
// Kernel shape: (c_in_k, c_out, h_k, w_k)
// Input shape: (b_size, c_in, h_in, w_in)
let p = &self.0;
let (out_w, out_h) = (p.out_w(), p.out_h());
let dst_el = p.c_out * out_w * out_h * p.b_size;
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(k_l.start_offset()..);
let shape = inp_l.shape();
let dims = shape.dims();
let el = shape.elem_count();
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose2d"), kernels::CONV)?;
let ds = if dims.len() == 4 {
[dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
} else {
crate::bail!("unexpected input shape for conv_transpose2d {dims:?}")
};
let ds = dev.htod_copy(ds).w()?;
let params = (
el,
out_w,
out_h,
p.stride,
p.padding,
p.output_padding,
p.dilation,
&ds,
inp,
k,
&out,
);
let params = (el, out_w, out_h, p.stride, p.padding, &ds, inp, k, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
@ -1110,7 +996,7 @@ impl Map1 for Pool2D {
let ds = if dims.len() == 4 {
[dims, inp_l.stride()].concat()
} else {
crate::bail!("unexpected input shape for pool {dims:?}")
panic!("unexpected input shape for conv1d {dims:?}")
};
let el = shape.elem_count();
let out_w = (dims[2] - self.w_k) / self.w_stride + 1;
@ -1156,7 +1042,7 @@ impl Map1 for UpsampleNearest2D {
let ds = if dims.len() == 4 {
[dims, inp_l.stride()].concat()
} else {
crate::bail!("unexpected input shape for upsample {dims:?}")
panic!("unexpected input shape for conv1d {dims:?}")
};
let (out_w, out_h) = (self.0, self.1);
let dst_el = out_w * out_h * dims[0] * dims[1];
@ -1194,12 +1080,8 @@ impl<'a> Map2 for WhereCond<'a> {
let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
(ptr, "where_u32")
}
CudaStorageSlice::I64(slice) => {
let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
(ptr, "where_i64")
}
_ => Err(CudaError::UnexpectedDType {
msg: "where conditions should be u8/u32/i64",
msg: "where conditions should be u8 or u32",
expected: DType::U32,
got: self.0.dtype(),
})
@ -1310,8 +1192,8 @@ fn slice_src_and_dst<'a, T>(
#[derive(Debug)]
pub struct CudaStorage {
pub slice: CudaStorageSlice,
pub device: CudaDevice,
slice: CudaStorageSlice,
device: CudaDevice,
}
pub trait CudaDType: Sized {
@ -1343,7 +1225,6 @@ macro_rules! cuda_dtype {
}
cuda_dtype!(u8, U8);
cuda_dtype!(u32, U32);
cuda_dtype!(i64, I64);
cuda_dtype!(f16, F16);
cuda_dtype!(bf16, BF16);
cuda_dtype!(f32, F32);
@ -1457,7 +1338,6 @@ impl BackendStorage for CudaStorage {
match self.slice {
CudaStorageSlice::U8(_) => DType::U8,
CudaStorageSlice::U32(_) => DType::U32,
CudaStorageSlice::I64(_) => DType::I64,
CudaStorageSlice::BF16(_) => DType::BF16,
CudaStorageSlice::F16(_) => DType::F16,
CudaStorageSlice::F32(_) => DType::F32,
@ -1483,7 +1363,6 @@ impl BackendStorage for CudaStorage {
let inp = match &self.slice {
CudaStorageSlice::U8(inp) => *inp.slice(start_o..).device_ptr(),
CudaStorageSlice::U32(inp) => *inp.slice(start_o..).device_ptr(),
CudaStorageSlice::I64(inp) => *inp.slice(start_o..).device_ptr(),
CudaStorageSlice::BF16(inp) => *inp.slice(start_o..).device_ptr(),
CudaStorageSlice::F16(inp) => *inp.slice(start_o..).device_ptr(),
CudaStorageSlice::F32(inp) => *inp.slice(start_o..).device_ptr(),
@ -1506,12 +1385,6 @@ impl BackendStorage for CudaStorage {
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::U32(out)
}
DType::I64 => {
let out = unsafe { dev.alloc::<i64>(el) }.w()?;
let params = (el, dims.len(), &ds, *inp, &out);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::I64(out)
}
DType::BF16 => {
let out = unsafe { dev.alloc::<bf16>(el) }.w()?;
let params = (el, dims.len(), &ds, *inp, &out);
@ -1549,12 +1422,6 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device })
}
fn powf(&self, layout: &Layout, e: f64) -> Result<Self> {
let device = self.device().clone();
let slice = Powf(e).map(&self.slice, &device, layout)?;
Ok(Self { slice, device })
}
fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
let device = self.device().clone();
let slice = Elu(alpha).map(&self.slice, &device, layout)?;
@ -1602,11 +1469,6 @@ impl BackendStorage for CudaStorage {
let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
Ok(CpuStorage::U32(cpu_storage))
}
CudaStorageSlice::I64(slice) => {
let dev = slice.device();
let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
Ok(CpuStorage::I64(cpu_storage))
}
CudaStorageSlice::BF16(slice) => {
let dev = slice.device();
let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
@ -1708,6 +1570,7 @@ impl BackendStorage for CudaStorage {
.map_err(crate::Error::wrap)?;
S::F16(out)
}
(S::F32(inp), S::F32(k)) => {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
@ -1725,25 +1588,11 @@ impl BackendStorage for CudaStorage {
S::F64(out)
}
(S::U32(_), S::U32(_)) => Err(CudaError::InternalError("conv2d does not support u32"))?,
(S::I64(_), S::I64(_)) => Err(CudaError::InternalError("conv2d does not support i64"))?,
_ => Err(CudaError::InternalError("dtype mismatch in conv2d"))?,
};
Ok(Self { slice, device })
}
fn conv_transpose2d(
&self,
l: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &crate::conv::ParamsConvTranspose2D,
) -> Result<Self> {
let device = self.device().clone();
let slice =
ConvTranspose2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
Ok(Self { slice, device })
}
fn avg_pool2d(&self, l: &Layout, k: (usize, usize), stride: (usize, usize)) -> Result<Self> {
let device = self.device().clone();
let slice = Pool2D {
@ -1889,9 +1738,6 @@ impl BackendStorage for CudaStorage {
let src_shape = src_l.shape();
let dims = src_shape.dims();
let el_count = src_shape.elem_count();
if el_count == 0 {
return Ok(());
}
let cfg = LaunchConfig::for_num_elems(el_count as u32);
let dev = &self.device;
let ds = dev.htod_copy([dims, src_l.stride()].concat()).w()?;
@ -1956,18 +1802,6 @@ impl BackendStorage for CudaStorage {
unsafe { func.launch(cfg, params) }.w()?
}
}
(CudaStorageSlice::I64(src), CudaStorageSlice::I64(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
if src_l.is_contiguous() {
dev.dtod_copy(&src, &mut dst).w()?
} else {
let func = dev.get_or_load_func("ucopy_i64", kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let params = (el_count, dims.len(), &ds, &src, &mut dst);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?
}
}
(CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
if src_l.is_contiguous() {

View File

@ -48,14 +48,14 @@ pub(crate) fn launch_conv2d<
let conv = cudnn.create_conv2d::<T>(
/* pad */ [params.padding as i32, params.padding as i32],
/* stride */ [params.stride as i32, params.stride as i32],
/* dilation */ [params.dilation as i32, params.dilation as i32],
/* dilation */ [1, 1],
cudarc::cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
)?;
let x_shape = [
params.b_size as i32,
params.c_in as i32,
params.i_h as i32,
params.i_w as i32,
params.i_h as i32,
];
// Note that `src` already starts at the proper offset.
let x = if src_l.is_contiguous() {
@ -75,14 +75,14 @@ pub(crate) fn launch_conv2d<
[
params.c_out as i32,
params.c_in as i32,
params.k_h as i32,
params.k_w as i32,
params.k_h as i32,
],
)?;
let (w_out, h_out) = (params.out_w() as i32, params.out_h() as i32);
let y = cudnn.create_4d_tensor(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
[params.b_size as i32, params.c_out as i32, h_out, w_out],
[params.b_size as i32, params.c_out as i32, w_out, h_out],
)?;
let conv2d = Conv2dForward {
conv: &conv,

View File

@ -16,6 +16,7 @@ pub enum Device {
Cuda(crate::CudaDevice),
}
// TODO: Should we back the cpu implementation using the NdArray crate or similar?
pub trait NdArray {
fn shape(&self) -> Result<Shape>;
@ -80,49 +81,6 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
}
}
impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4: usize> NdArray
for &[[[[S; N4]; N3]; N2]; N1]
{
fn shape(&self) -> Result<Shape> {
Ok(Shape::from((N1, N2, N3, N4)))
}
fn to_cpu_storage(&self) -> CpuStorage {
let mut vec = Vec::with_capacity(N1 * N2 * N3 * N4);
for i1 in 0..N1 {
for i2 in 0..N2 {
for i3 in 0..N3 {
vec.extend(self[i1][i2][i3])
}
}
}
S::to_cpu_storage_owned(vec)
}
}
impl<S: NdArray> NdArray for Vec<S> {
fn shape(&self) -> Result<Shape> {
if self.is_empty() {
crate::bail!("empty array")
}
let shape0 = self[0].shape()?;
let n = self.len();
for v in self.iter() {
let shape = v.shape()?;
if shape != shape0 {
crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
}
}
Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
}
fn to_cpu_storage(&self) -> CpuStorage {
// This allocates intermediary memory and shouldn't be necessary.
let storages = self.iter().map(|v| v.to_cpu_storage()).collect::<Vec<_>>();
CpuStorage::concat(storages.as_slice()).unwrap()
}
}
impl Device {
pub fn new_cuda(ordinal: usize) -> Result<Self> {
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))

View File

@ -9,14 +9,11 @@ impl Tensor {
&self,
f: &mut std::fmt::Formatter,
) -> std::fmt::Result {
let device_str = match self.device().location() {
crate::DeviceLocation::Cpu => "".to_owned(),
crate::DeviceLocation::Cuda { gpu_id } => {
format!(", cuda:{}", gpu_id)
}
let prefix = match self.device() {
crate::Device::Cpu => "Cpu",
crate::Device::Cuda(_) => "Cuda",
};
write!(f, "Tensor[")?;
write!(f, "{prefix}Tensor[")?;
match self.dims() {
[] => {
if let Ok(v) = self.to_scalar::<T>() {
@ -43,7 +40,7 @@ impl Tensor {
}
}
}
write!(f, "; {}{}]", self.dtype().as_str(), device_str)
write!(f, "; {}]", self.dtype().as_str())
}
}
@ -52,7 +49,6 @@ impl std::fmt::Debug for Tensor {
match self.dtype() {
DType::U8 => self.fmt_dt::<u8>(f),
DType::U32 => self.fmt_dt::<u32>(f),
DType::I64 => self.fmt_dt::<i64>(f),
DType::BF16 => self.fmt_dt::<bf16>(f),
DType::F16 => self.fmt_dt::<f16>(f),
DType::F32 => self.fmt_dt::<f32>(f),
@ -435,12 +431,6 @@ impl std::fmt::Display for Tensor {
tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
writeln!(f)?;
}
DType::I64 => {
let tf: IntFormatter<i64> = IntFormatter::new();
let max_w = tf.max_width(&to_display);
tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
writeln!(f)?;
}
DType::BF16 => {
if let Ok(tf) = FloatFormatter::<bf16>::new(&to_display, &po) {
let max_w = tf.max_width(&to_display);
@ -470,20 +460,6 @@ impl std::fmt::Display for Tensor {
}
}
};
let device_str = match self.device().location() {
crate::DeviceLocation::Cpu => "".to_owned(),
crate::DeviceLocation::Cuda { gpu_id } => {
format!(", cuda:{}", gpu_id)
}
};
write!(
f,
"Tensor[{:?}, {}{}]",
self.dims(),
self.dtype().as_str(),
device_str
)
write!(f, "Tensor[{:?}, {}]", self.dims(), self.dtype().as_str())
}
}

View File

@ -1,24 +1,13 @@
//! Types for elements that can be stored and manipulated using tensors.
#![allow(clippy::redundant_closure_call)]
use crate::backend::BackendStorage;
use crate::{CpuStorage, Error, Result};
/// The different types of elements allowed in tensors.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DType {
// Unsigned 8 bits integer.
U8,
// Unsigned 32 bits integer.
U32,
// Signed 64 bits integer.
I64,
// Brain floating-point using half precision (16 bits).
BF16,
// Floating-point using half precision (16 bits).
F16,
// Floating-point using single precision (32 bits).
F32,
// Floating-point using double precision (64 bits).
F64,
}
@ -31,7 +20,6 @@ impl std::str::FromStr for DType {
match s {
"u8" => Ok(Self::U8),
"u32" => Ok(Self::U32),
"i64" => Ok(Self::I64),
"bf16" => Ok(Self::BF16),
"f16" => Ok(Self::F16),
"f32" => Ok(Self::F32),
@ -42,12 +30,10 @@ impl std::str::FromStr for DType {
}
impl DType {
/// String representation for dtypes.
pub fn as_str(&self) -> &'static str {
match self {
Self::U8 => "u8",
Self::U32 => "u32",
Self::I64 => "i64",
Self::BF16 => "bf16",
Self::F16 => "f16",
Self::F32 => "f32",
@ -55,12 +41,10 @@ impl DType {
}
}
/// The size used by each element in bytes, i.e. 1 for `U8`, 4 for `F32`.
pub fn size_in_bytes(&self) -> usize {
match self {
Self::U8 => 1,
Self::U32 => 4,
Self::I64 => 8,
Self::BF16 => 2,
Self::F16 => 2,
Self::F32 => 4,
@ -141,7 +125,6 @@ use half::{bf16, f16};
with_dtype!(u8, U8, |v: f64| v as u8, |v: u8| v as f64);
with_dtype!(u32, U32, |v: f64| v as u32, |v: u32| v as f64);
with_dtype!(i64, I64, |v: f64| v as i64, |v: i64| v as f64);
with_dtype!(f16, F16, f16::from_f64, f16::to_f64);
with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64);
with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64);
@ -152,15 +135,6 @@ pub trait IntDType: WithDType {
fn as_usize(&self) -> usize;
}
impl IntDType for i64 {
fn is_true(&self) -> bool {
*self != 0
}
fn as_usize(&self) -> usize {
*self as usize
}
}
impl IntDType for u32 {
fn is_true(&self) -> bool {
*self != 0

View File

@ -37,10 +37,6 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
@ -89,16 +85,6 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
fn conv_transpose2d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose2D,
) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}

View File

@ -30,7 +30,7 @@ pub enum Error {
UnsupportedDTypeForOp(DType, &'static str),
// === Dimension Index Errors ===
#[error("{op}: dimension index {dim} out of range for shape {shape:?}")]
#[error("{op}: dimension index {dim} out of range for {shape:?}")]
DimOutOfRange {
shape: Shape,
dim: i32,
@ -207,11 +207,7 @@ pub type Result<T> = std::result::Result<T, Error>;
impl Error {
pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self {
Self::Wrapped(Box::new(err)).bt()
}
pub fn msg(err: impl std::error::Error + Send + Sync + 'static) -> Self {
Self::Msg(err.to_string()).bt()
Self::Wrapped(Box::new(err))
}
pub fn bt(self) -> Self {

View File

@ -9,14 +9,6 @@ pub struct Layout {
}
impl Layout {
pub fn new(shape: Shape, stride: Vec<usize>, start_offset: usize) -> Self {
Self {
shape,
stride,
start_offset,
}
}
pub fn contiguous_with_offset<S: Into<Shape>>(shape: S, start_offset: usize) -> Self {
let shape = shape.into();
let stride = shape.stride_contiguous();
@ -120,31 +112,6 @@ impl Layout {
})
}
pub(crate) fn permute(&self, idxs: &[usize]) -> Result<Self> {
let is_permutation =
idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i));
if !is_permutation {
crate::bail!(
"dimension mismatch in permute, tensor {:?}, dims: {:?}",
self.dims(),
idxs
)
}
let stride = self.stride();
let dims = self.shape().dims();
let mut perm_stride = stride.to_vec();
let mut perm_dims = dims.to_vec();
for (i, &idx) in idxs.iter().enumerate() {
perm_stride[i] = stride[idx];
perm_dims[i] = dims[idx];
}
Ok(Self {
shape: Shape::from(perm_dims),
stride: perm_stride,
start_offset: self.start_offset,
})
}
pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
let shape = shape.into();
if shape.rank() < self.shape().rank() {

View File

@ -56,15 +56,12 @@ pub mod layout;
mod mkl;
pub mod npy;
mod op;
pub mod pickle;
pub mod quantized;
pub mod safetensors;
pub mod scalar;
pub mod shape;
mod storage;
mod strided_index;
mod tensor;
pub mod test_utils;
pub mod utils;
mod variable;
@ -89,39 +86,3 @@ pub use dummy_cuda_backend::{CudaDevice, CudaStorage};
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
pub trait ToUsize2 {
fn to_usize2(self) -> (usize, usize);
}
impl ToUsize2 for usize {
fn to_usize2(self) -> (usize, usize) {
(self, self)
}
}
impl ToUsize2 for (usize, usize) {
fn to_usize2(self) -> (usize, usize) {
self
}
}
// A simple trait defining a module with forward method using a single argument.
pub trait Module: std::fmt::Debug {
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
/// Change the module to use training mode vs eval mode.
///
/// The default implementation does nothing as this is only used for a couple modules such as
/// dropout or batch-normalization.
fn set_training(&mut self, _training: bool) {}
}
impl Module for quantized::QMatMul {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
self.forward(xs)
}
}

View File

@ -25,10 +25,6 @@ mod ffi {
pub fn vdMul(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
pub fn vsDiv(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
pub fn vdDiv(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
pub fn vsFmax(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
pub fn vdFmax(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
pub fn vsFmin(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
pub fn vdFmin(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
pub fn sgemm_(
transa: *const c_char,
@ -301,7 +297,7 @@ pub fn vd_sqr(a: &[f64], y: &mut [f64]) {
}
#[inline]
pub fn vs_tanh(a: &[f32], y: &mut [f32]) {
fn vs_tanh(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
@ -311,7 +307,7 @@ pub fn vs_tanh(a: &[f32], y: &mut [f32]) {
}
#[inline]
pub fn vd_tanh(a: &[f64], y: &mut [f64]) {
fn vd_tanh(a: &[f64], y: &mut [f64]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
@ -380,7 +376,3 @@ binary_op!(vs_mul, f32, vsMul);
binary_op!(vd_mul, f64, vdMul);
binary_op!(vs_div, f32, vsDiv);
binary_op!(vd_div, f64, vdDiv);
binary_op!(vs_max, f32, vsFmax);
binary_op!(vd_max, f64, vdFmax);
binary_op!(vs_min, f32, vsFmin);
binary_op!(vd_min, f64, vdFmin);

View File

@ -85,7 +85,6 @@ impl Header {
DType::F16 => "f2",
DType::F32 => "f4",
DType::F64 => "f8",
DType::I64 => "i8",
DType::U32 => "u4",
DType::U8 => "u1",
};
@ -161,7 +160,7 @@ impl Header {
"f" | "f4" => DType::F32,
"d" | "f8" => DType::F64,
// "i" | "i4" => DType::S32,
"q" | "i8" => DType::I64,
// "q" | "i8" => DType::S64,
// "h" | "i2" => DType::S16,
// "b" | "i1" => DType::S8,
"B" | "u1" => DType::U8,
@ -197,11 +196,7 @@ impl Header {
impl Tensor {
// TODO: Add the possibility to read directly to a device?
pub(crate) fn from_reader<R: std::io::Read>(
shape: Shape,
dtype: DType,
reader: &mut R,
) -> Result<Self> {
fn from_reader<R: std::io::Read>(shape: Shape, dtype: DType, reader: &mut R) -> Result<Self> {
let elem_count = shape.elem_count();
match dtype {
DType::BF16 => {
@ -234,11 +229,6 @@ impl Tensor {
reader.read_u32_into::<LittleEndian>(&mut data_t)?;
Tensor::from_vec(data_t, shape, &Device::Cpu)
}
DType::I64 => {
let mut data_t = vec![0i64; elem_count];
reader.read_i64_into::<LittleEndian>(&mut data_t)?;
Tensor::from_vec(data_t, shape, &Device::Cpu)
}
}
}
@ -371,25 +361,6 @@ impl NpzTensors {
})
}
pub fn names(&self) -> Vec<&String> {
self.index_per_name.keys().collect()
}
/// This only returns the shape and dtype for a named tensor. Compared to `get`, this avoids
/// reading the whole tensor data.
pub fn get_shape_and_dtype(&self, name: &str) -> Result<(Shape, DType)> {
let index = match self.index_per_name.get(name) {
None => crate::bail!("cannot find tensor {name}"),
Some(index) => *index,
};
let zip_reader = BufReader::new(File::open(&self.path)?);
let mut zip = zip::ZipArchive::new(zip_reader)?;
let mut reader = zip.by_index(index)?;
let header = read_header(&mut reader)?;
let header = Header::parse(&header)?;
Ok((header.shape(), header.descr))
}
pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
let index = match self.index_per_name.get(name) {
None => return Ok(None),

View File

@ -1,4 +1,3 @@
#![allow(clippy::redundant_closure_call)]
use crate::{CpuStorage, CudaStorage, Layout, Result, Shape, Tensor};
use half::{bf16, f16};
use num_traits::float::Float;
@ -41,8 +40,6 @@ pub enum BinaryOp {
Mul,
Sub,
Div,
Maximum,
Minimum,
}
// Unary ops with no argument
@ -59,7 +56,6 @@ pub enum UnaryOp {
Sqrt,
Gelu,
Relu,
Tanh,
}
#[derive(Clone)]
@ -82,7 +78,6 @@ pub enum Op {
kernel: Tensor,
padding: usize,
stride: usize,
dilation: usize,
},
#[allow(dead_code)]
@ -91,17 +86,6 @@ pub enum Op {
kernel: Tensor,
padding: usize,
stride: usize,
dilation: usize,
},
#[allow(dead_code)]
ConvTranspose2D {
arg: Tensor,
kernel: Tensor,
padding: usize,
output_padding: usize,
stride: usize,
dilation: usize,
},
AvgPool2D {
@ -133,25 +117,14 @@ pub enum Op {
Reshape(Tensor),
ToDevice(Tensor),
Transpose(Tensor, usize, usize),
Permute(Tensor, Vec<usize>),
Elu(Tensor, f64),
Powf(Tensor, f64),
CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1 + Send + Sync>>),
CustomOp2(
Tensor,
Tensor,
std::sync::Arc<Box<dyn CustomOp2 + Send + Sync>>,
),
CustomOp3(
Tensor,
Tensor,
Tensor,
std::sync::Arc<Box<dyn CustomOp3 + Send + Sync>>,
),
CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1>>),
CustomOp2(Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp2>>),
CustomOp3(Tensor, Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp3>>),
}
/// Unary ops that can be defined in user-land.
pub trait CustomOp1 {
pub trait CustomOp1: Send + Sync {
// Box<dyn> does not support const yet, so use a function to get the name.
fn name(&self) -> &'static str;
@ -175,7 +148,7 @@ pub trait CustomOp1 {
}
}
pub trait CustomOp2 {
pub trait CustomOp2: Send + Sync {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
@ -213,7 +186,7 @@ pub trait CustomOp2 {
}
}
pub trait CustomOp3 {
pub trait CustomOp3: Send + Sync {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
@ -266,7 +239,6 @@ pub trait UnaryOpT {
fn f64(v1: f64) -> f64;
fn u8(v1: u8) -> u8;
fn u32(v1: u32) -> u32;
fn i64(v1: i64) -> i64;
// There is no very good way to represent optional function in traits so we go for an explicit
// boolean flag to mark the function as existing.
@ -290,7 +262,6 @@ pub trait BinaryOpT {
fn f64(v1: f64, v2: f64) -> f64;
fn u8(v1: u8, v2: u8) -> u8;
fn u32(v1: u32, v2: u32) -> u32;
fn i64(v1: i64, v2: i64) -> i64;
const BF16_VEC: bool = false;
fn bf16_vec(_xs1: &[bf16], _xs2: &[bf16], _ys: &mut [bf16]) {}
@ -304,16 +275,12 @@ pub trait BinaryOpT {
fn u8_vec(_xs1: &[u8], _xs2: &[u8], _ys: &mut [u8]) {}
const U32_VEC: bool = false;
fn u32_vec(_xs1: &[u32], _xs2: &[u32], _ys: &mut [u32]) {}
const I64_VEC: bool = false;
fn i64_vec(_xs1: &[i64], _xs2: &[i64], _ys: &mut [i64]) {}
}
pub(crate) struct Add;
pub(crate) struct Div;
pub(crate) struct Mul;
pub(crate) struct Sub;
pub(crate) struct Maximum;
pub(crate) struct Minimum;
pub(crate) struct Exp;
pub(crate) struct Log;
pub(crate) struct Sin;
@ -325,7 +292,6 @@ pub(crate) struct Sqr;
pub(crate) struct Sqrt;
pub(crate) struct Gelu;
pub(crate) struct Relu;
pub(crate) struct Tanh;
macro_rules! bin_op {
($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
@ -357,10 +323,6 @@ macro_rules! bin_op {
fn u32(v1: u32, v2: u32) -> u32 {
$e(v1, v2)
}
#[inline(always)]
fn i64(v1: i64, v2: i64) -> i64 {
$e(v1, v2)
}
#[cfg(feature = "mkl")]
const F32_VEC: bool = true;
@ -399,22 +361,7 @@ bin_op!(Add, "add", |v1, v2| v1 + v2, vs_add, vd_add);
bin_op!(Sub, "sub", |v1, v2| v1 - v2, vs_sub, vd_sub);
bin_op!(Mul, "mul", |v1, v2| v1 * v2, vs_mul, vd_mul);
bin_op!(Div, "div", |v1, v2| v1 / v2, vs_div, vd_div);
bin_op!(
Minimum,
"minimum",
|v1, v2| if v1 > v2 { v2 } else { v1 },
vs_min,
vd_min
);
bin_op!(
Maximum,
"maximum",
|v1, v2| if v1 < v2 { v2 } else { v1 },
vs_max,
vd_max
);
#[allow(clippy::redundant_closure_call)]
macro_rules! unary_op {
($op: ident, $name: literal, $a: ident, $e: expr) => {
impl UnaryOpT for $op {
@ -445,10 +392,6 @@ macro_rules! unary_op {
fn u32(_: u32) -> u32 {
todo!("no unary function for u32")
}
#[inline(always)]
fn i64(_: i64) -> i64 {
todo!("no unary function for i64")
}
}
};
@ -481,10 +424,6 @@ macro_rules! unary_op {
fn u32(_: u32) -> u32 {
todo!("no unary function for u32")
}
#[inline(always)]
fn i64(_: i64) -> i64 {
todo!("no unary function for i64")
}
#[cfg(feature = "mkl")]
const F32_VEC: bool = true;
@ -523,7 +462,6 @@ unary_op!(Exp, "exp", v, v.exp(), vs_exp, vd_exp);
unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln);
unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh);
unary_op!(Abs, "abs", v, v.abs());
unary_op!(Neg, "neg", v, -v);
unary_op!(Recip, "recip", v, v.recip());
@ -577,10 +515,6 @@ impl UnaryOpT for Gelu {
fn u32(_: u32) -> u32 {
0
}
#[inline(always)]
fn i64(_: i64) -> i64 {
0
}
const KERNEL: &'static str = "ugelu";
#[cfg(feature = "mkl")]
@ -630,10 +564,6 @@ impl UnaryOpT for Relu {
fn u32(v: u32) -> u32 {
v
}
#[inline(always)]
fn i64(v: i64) -> i64 {
v
}
}
/// `BackpropOp` is a wrapper around `Option<Op>`. The main goal is to ensure that dependencies are

View File

@ -1,725 +0,0 @@
// Just enough pickle support to be able to read PyTorch checkpoints.
// This hardcodes objects that are required for tensor reading, we may want to make this a bit more
// composable/tensor agnostic at some point.
use crate::{DType, Error as E, Layout, Result, Tensor};
use byteorder::{LittleEndian, ReadBytesExt};
use std::collections::HashMap;
use std::io::BufRead;
const VERBOSE: bool = false;
// https://docs.juliahub.com/Pickle/LAUNc/0.1.0/opcode/
#[repr(u8)]
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum OpCode {
// https://github.com/python/cpython/blob/ed25f097160b5cbb0c9a1f9a746d2f1bbc96515a/Lib/pickletools.py#L2123
Proto = 0x80,
Global = b'c',
BinPut = b'q',
LongBinPut = b'r',
EmptyTuple = b')',
Reduce = b'R',
Mark = b'(',
BinUnicode = b'X',
BinInt = b'J',
Tuple = b't',
BinPersId = b'Q',
BinInt1 = b'K',
BinInt2 = b'M',
Tuple1 = 0x85,
Tuple2 = 0x86,
Tuple3 = 0x87,
NewTrue = 0x88,
NewFalse = 0x89,
None = b'N',
BinGet = b'h',
LongBinGet = b'j',
SetItem = b's',
SetItems = b'u',
EmptyDict = b'}',
Dict = b'd',
Build = b'b',
Stop = b'.',
NewObj = 0x81,
EmptyList = b']',
BinFloat = b'g',
Append = b'a',
Appends = b'e',
}
// Avoid using FromPrimitive so as not to drag another dependency.
impl TryFrom<u8> for OpCode {
type Error = u8;
fn try_from(value: u8) -> std::result::Result<Self, Self::Error> {
match value {
0x80 => Ok(Self::Proto),
b'c' => Ok(Self::Global),
b'q' => Ok(Self::BinPut),
b'r' => Ok(Self::LongBinPut),
b')' => Ok(Self::EmptyTuple),
b'R' => Ok(Self::Reduce),
b'(' => Ok(Self::Mark),
b'X' => Ok(Self::BinUnicode),
b'J' => Ok(Self::BinInt),
b't' => Ok(Self::Tuple),
b'Q' => Ok(Self::BinPersId),
b'K' => Ok(Self::BinInt1),
b'M' => Ok(Self::BinInt2),
b'N' => Ok(Self::None),
0x85 => Ok(Self::Tuple1),
0x86 => Ok(Self::Tuple2),
0x87 => Ok(Self::Tuple3),
0x88 => Ok(Self::NewTrue),
0x89 => Ok(Self::NewFalse),
b'h' => Ok(Self::BinGet),
b'j' => Ok(Self::LongBinGet),
b's' => Ok(Self::SetItem),
b'u' => Ok(Self::SetItems),
b'}' => Ok(Self::EmptyDict),
b'd' => Ok(Self::EmptyDict),
b'b' => Ok(Self::Build),
b'.' => Ok(Self::Stop),
0x81 => Ok(Self::NewObj),
b']' => Ok(Self::EmptyList),
b'G' => Ok(Self::BinFloat),
b'a' => Ok(Self::Append),
b'e' => Ok(Self::Appends),
value => Err(value),
}
}
}
fn read_to_newline<R: BufRead>(r: &mut R) -> Result<Vec<u8>> {
let mut data: Vec<u8> = Vec::with_capacity(32);
r.read_until(b'\n', &mut data)?;
data.pop();
if data.last() == Some(&b'\r') {
data.pop();
}
Ok(data)
}
#[derive(Debug, Clone, PartialEq)]
pub enum Object {
Class {
module_name: String,
class_name: String,
},
Int(i32),
Float(f64),
Unicode(String),
Bool(bool),
None,
Tuple(Vec<Object>),
List(Vec<Object>),
Mark,
Dict(Vec<(Object, Object)>),
Reduce {
callable: Box<Object>,
args: Box<Object>,
},
Build {
callable: Box<Object>,
args: Box<Object>,
},
PersistentLoad(Box<Object>),
}
type OResult<T> = std::result::Result<T, Object>;
impl Object {
pub fn unicode(self) -> OResult<String> {
match self {
Self::Unicode(t) => Ok(t),
_ => Err(self),
}
}
pub fn reduce(self) -> OResult<(Self, Self)> {
match self {
Self::Reduce { callable, args } => Ok((*callable, *args)),
_ => Err(self),
}
}
pub fn none(self) -> OResult<()> {
match self {
Self::None => Ok(()),
_ => Err(self),
}
}
pub fn persistent_load(self) -> OResult<Self> {
match self {
Self::PersistentLoad(t) => Ok(*t),
_ => Err(self),
}
}
pub fn bool(self) -> OResult<bool> {
match self {
Self::Bool(t) => Ok(t),
_ => Err(self),
}
}
pub fn int(self) -> OResult<i32> {
match self {
Self::Int(t) => Ok(t),
_ => Err(self),
}
}
pub fn tuple(self) -> OResult<Vec<Self>> {
match self {
Self::Tuple(t) => Ok(t),
_ => Err(self),
}
}
pub fn dict(self) -> OResult<Vec<(Self, Self)>> {
match self {
Self::Dict(t) => Ok(t),
_ => Err(self),
}
}
pub fn class(self) -> OResult<(String, String)> {
match self {
Self::Class {
module_name,
class_name,
} => Ok((module_name, class_name)),
_ => Err(self),
}
}
}
impl TryFrom<Object> for String {
type Error = Object;
fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
match value {
Object::Unicode(s) => Ok(s),
other => Err(other),
}
}
}
impl TryFrom<Object> for usize {
type Error = Object;
fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
match value {
Object::Int(s) if s >= 0 => Ok(s as usize),
other => Err(other),
}
}
}
impl<T: TryFrom<Object, Error = Object>> TryFrom<Object> for Vec<T> {
type Error = Object;
fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
match value {
Object::Tuple(values) => {
// This does not return the appropriate value in the error case but instead return
// the object related to the first error.
values
.into_iter()
.map(|v| T::try_from(v))
.collect::<std::result::Result<Vec<T>, Self::Error>>()
}
other => Err(other),
}
}
}
#[derive(Debug)]
pub struct Stack {
stack: Vec<Object>,
memo: HashMap<u32, Object>,
}
impl Stack {
pub fn empty() -> Self {
Self {
stack: Vec::with_capacity(512),
memo: HashMap::new(),
}
}
pub fn stack(&self) -> &[Object] {
self.stack.as_slice()
}
pub fn read_loop<R: BufRead>(&mut self, r: &mut R) -> Result<()> {
loop {
if self.read(r)? {
break;
}
}
Ok(())
}
pub fn finalize(mut self) -> Result<Object> {
self.pop()
}
fn push(&mut self, obj: Object) {
self.stack.push(obj)
}
fn pop(&mut self) -> Result<Object> {
match self.stack.pop() {
None => crate::bail!("unexpected empty stack"),
Some(obj) => Ok(obj),
}
}
// https://docs.juliahub.com/Pickle/LAUNc/0.1.0/opcode/#Pickle.OpCodes.BUILD
fn build(&mut self) -> Result<()> {
let args = self.pop()?;
let obj = self.pop()?;
let obj = match (obj, args) {
(Object::Dict(mut obj), Object::Dict(mut args)) => {
obj.append(&mut args);
Object::Dict(obj)
}
(obj, args) => Object::Build {
callable: Box::new(obj),
args: Box::new(args),
},
};
self.push(obj);
Ok(())
}
fn reduce(&mut self) -> Result<()> {
let args = self.pop()?;
let callable = self.pop()?;
#[allow(clippy::single_match)]
let reduced = match &callable {
Object::Class {
module_name,
class_name,
} => {
if module_name == "collections" && class_name == "OrderedDict" {
// TODO: have a separate ordered dict.
Some(Object::Dict(vec![]))
} else {
None
}
}
_ => None,
};
let reduced = reduced.unwrap_or_else(|| Object::Reduce {
callable: Box::new(callable),
args: Box::new(args),
});
self.push(reduced);
Ok(())
}
fn last(&mut self) -> Result<&mut Object> {
match self.stack.last_mut() {
None => crate::bail!("unexpected empty stack"),
Some(obj) => Ok(obj),
}
}
fn memo_get(&self, id: u32) -> Result<Object> {
match self.memo.get(&id) {
None => crate::bail!("missing object in memo {id}"),
Some(obj) => {
// Maybe we should use refcounting rather than doing potential large clones here.
Ok(obj.clone())
}
}
}
fn memo_put(&mut self, id: u32) -> Result<()> {
let obj = self.last()?.clone();
self.memo.insert(id, obj);
Ok(())
}
fn persistent_load(&self, id: Object) -> Result<Object> {
Ok(Object::PersistentLoad(Box::new(id)))
}
fn new_obj(&self, class: Object, args: Object) -> Result<Object> {
Ok(Object::Reduce {
callable: Box::new(class),
args: Box::new(args),
})
}
fn pop_to_marker(&mut self) -> Result<Vec<Object>> {
let mut mark_idx = None;
for (idx, obj) in self.stack.iter().enumerate().rev() {
if obj == &Object::Mark {
mark_idx = Some(idx);
break;
}
}
match mark_idx {
Some(mark_idx) => {
let objs = self.stack.split_off(mark_idx + 1);
self.stack.pop();
Ok(objs)
}
None => {
crate::bail!("marker object not found")
}
}
}
pub fn read<R: BufRead>(&mut self, r: &mut R) -> Result<bool> {
let op_code = match OpCode::try_from(r.read_u8()?) {
Ok(op_code) => op_code,
Err(op_code) => {
crate::bail!("unknown op-code {op_code}")
}
};
// println!("op: {op_code:?}");
// println!("{:?}", self.stack);
match op_code {
OpCode::Proto => {
let version = r.read_u8()?;
if VERBOSE {
println!("proto {version}");
}
}
OpCode::Global => {
let module_name = read_to_newline(r)?;
let class_name = read_to_newline(r)?;
let module_name = String::from_utf8_lossy(&module_name).to_string();
let class_name = String::from_utf8_lossy(&class_name).to_string();
self.push(Object::Class {
module_name,
class_name,
})
}
OpCode::BinInt1 => {
let arg = r.read_u8()?;
self.push(Object::Int(arg as i32))
}
OpCode::BinInt2 => {
let arg = r.read_u16::<LittleEndian>()?;
self.push(Object::Int(arg as i32))
}
OpCode::BinInt => {
let arg = r.read_i32::<LittleEndian>()?;
self.push(Object::Int(arg))
}
OpCode::BinFloat => {
let arg = r.read_f64::<LittleEndian>()?;
self.push(Object::Float(arg))
}
OpCode::BinUnicode => {
let len = r.read_u32::<LittleEndian>()?;
let mut data = vec![0u8; len as usize];
r.read_exact(&mut data)?;
let data = String::from_utf8(data).map_err(E::wrap)?;
self.push(Object::Unicode(data))
}
OpCode::BinPersId => {
let id = self.pop()?;
let obj = self.persistent_load(id)?;
self.push(obj)
}
OpCode::Tuple => {
let objs = self.pop_to_marker()?;
self.push(Object::Tuple(objs))
}
OpCode::Tuple1 => {
let obj = self.pop()?;
self.push(Object::Tuple(vec![obj]))
}
OpCode::Tuple2 => {
let obj2 = self.pop()?;
let obj1 = self.pop()?;
self.push(Object::Tuple(vec![obj1, obj2]))
}
OpCode::Tuple3 => {
let obj3 = self.pop()?;
let obj2 = self.pop()?;
let obj1 = self.pop()?;
self.push(Object::Tuple(vec![obj1, obj2, obj3]))
}
OpCode::NewTrue => self.push(Object::Bool(true)),
OpCode::NewFalse => self.push(Object::Bool(false)),
OpCode::Append => {
let value = self.pop()?;
let pylist = self.last()?;
if let Object::List(d) = pylist {
d.push(value)
} else {
crate::bail!("expected a list, got {pylist:?}")
}
}
OpCode::Appends => {
let objs = self.pop_to_marker()?;
let pylist = self.last()?;
if let Object::List(d) = pylist {
d.extend(objs)
} else {
crate::bail!("expected a list, got {pylist:?}")
}
}
OpCode::SetItem => {
let value = self.pop()?;
let key = self.pop()?;
let pydict = self.last()?;
if let Object::Dict(d) = pydict {
d.push((key, value))
} else {
crate::bail!("expected a dict, got {pydict:?}")
}
}
OpCode::SetItems => {
let mut objs = self.pop_to_marker()?;
let pydict = self.last()?;
if let Object::Dict(d) = pydict {
if objs.len() % 2 != 0 {
crate::bail!("setitems: not an even number of objects")
}
while let Some(value) = objs.pop() {
let key = objs.pop().unwrap();
d.push((key, value))
}
} else {
crate::bail!("expected a dict, got {pydict:?}")
}
}
OpCode::None => self.push(Object::None),
OpCode::Stop => {
return Ok(true);
}
OpCode::Build => self.build()?,
OpCode::EmptyDict => self.push(Object::Dict(vec![])),
OpCode::Dict => {
let mut objs = self.pop_to_marker()?;
let mut pydict = vec![];
if objs.len() % 2 != 0 {
crate::bail!("setitems: not an even number of objects")
}
while let Some(value) = objs.pop() {
let key = objs.pop().unwrap();
pydict.push((key, value))
}
self.push(Object::Dict(pydict))
}
OpCode::Mark => self.push(Object::Mark),
OpCode::Reduce => self.reduce()?,
OpCode::EmptyTuple => self.push(Object::Tuple(vec![])),
OpCode::EmptyList => self.push(Object::List(vec![])),
OpCode::BinGet => {
let arg = r.read_u8()?;
let obj = self.memo_get(arg as u32)?;
self.push(obj)
}
OpCode::LongBinGet => {
let arg = r.read_u32::<LittleEndian>()?;
let obj = self.memo_get(arg)?;
self.push(obj)
}
OpCode::BinPut => {
let arg = r.read_u8()?;
self.memo_put(arg as u32)?
}
OpCode::LongBinPut => {
let arg = r.read_u32::<LittleEndian>()?;
self.memo_put(arg)?
}
OpCode::NewObj => {
let args = self.pop()?;
let class = self.pop()?;
let obj = self.new_obj(class, args)?;
self.push(obj)
}
}
Ok(false)
}
}
impl From<Object> for E {
fn from(value: Object) -> Self {
E::Msg(format!("conversion error on {value:?}"))
}
}
// https://github.com/pytorch/pytorch/blob/4eac43d046ded0f0a5a5fa8db03eb40f45bf656e/torch/_utils.py#L198
// Arguments: storage, storage_offset, size, stride, requires_grad, backward_hooks
fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
let mut args = args.tuple()?;
let stride = Vec::<usize>::try_from(args.remove(3))?;
let size = Vec::<usize>::try_from(args.remove(2))?;
let offset = args.remove(1).int()? as usize;
let storage = args.remove(0).persistent_load()?;
let mut storage = storage.tuple()?;
let storage_size = storage.remove(4).int()? as usize;
let path = storage.remove(2).unicode()?;
let (_module_name, class_name) = storage.remove(1).class()?;
let dtype = match class_name.as_str() {
"FloatStorage" => DType::F32,
"DoubleStorage" => DType::F64,
"HalfStorage" => DType::F16,
"BFloat16Storage" => DType::BF16,
"ByteStorage" => DType::U8,
other => {
crate::bail!("unsupported storage type {other}")
}
};
let layout = Layout::new(crate::Shape::from(size), stride, offset);
Ok((layout, dtype, path, storage_size))
}
#[derive(Debug, Clone)]
pub struct TensorInfo {
pub name: String,
pub dtype: DType,
pub layout: Layout,
pub path: String,
pub storage_size: usize,
}
pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
file: P,
verbose: bool,
) -> Result<Vec<TensorInfo>> {
let file = std::fs::File::open(file)?;
let zip_reader = std::io::BufReader::new(file);
let mut zip = zip::ZipArchive::new(zip_reader)?;
let zip_file_names = zip
.file_names()
.map(|f| f.to_string())
.collect::<Vec<String>>();
let mut tensor_infos = vec![];
for file_name in zip_file_names.iter() {
if !file_name.ends_with("data.pkl") {
continue;
}
let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").unwrap());
let reader = zip.by_name(file_name)?;
let mut reader = std::io::BufReader::new(reader);
let mut stack = Stack::empty();
stack.read_loop(&mut reader)?;
let obj = stack.finalize()?;
if VERBOSE || verbose {
println!("{obj:?}");
}
let obj = match obj {
Object::Build { callable, args } => match *callable {
Object::Reduce { callable, args: _ } => match *callable {
Object::Class {
module_name,
class_name,
} if module_name == "__torch__" && class_name == "Module" => *args,
_ => continue,
},
_ => continue,
},
obj => obj,
};
if let Object::Dict(key_values) = obj {
for (name, value) in key_values.into_iter() {
let name = match name.unicode() {
Ok(name) => name,
Err(_) => continue,
};
let (callable, args) = match value.reduce() {
Ok(callable_args) => callable_args,
_ => continue,
};
let (callable, args) = match callable {
Object::Class {
module_name,
class_name,
} if module_name == "torch._tensor"
&& class_name == "_rebuild_from_type_v2" =>
{
let mut args = args.tuple()?;
let callable = args.remove(0);
let args = args.remove(1);
(callable, args)
}
_ => (callable, args),
};
match callable {
Object::Class {
module_name,
class_name,
} if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {}
_ => continue,
};
match rebuild_args(args) {
Ok((layout, dtype, file_path, storage_size)) => {
let mut path = dir_name.clone();
path.push(file_path);
tensor_infos.push(TensorInfo {
name,
dtype,
layout,
path: path.to_string_lossy().into_owned(),
storage_size,
})
}
Err(err) => {
eprintln!("skipping {name}: {err:?}")
}
}
}
}
}
Ok(tensor_infos)
}
/// Lazy tensor loader.
pub struct PthTensors {
tensor_infos: HashMap<String, TensorInfo>,
path: std::path::PathBuf,
// We do not store a zip reader as it needs mutable access to extract data. Instead we
// re-create a zip reader for each tensor.
}
impl PthTensors {
pub fn new<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
let tensor_infos = read_pth_tensor_info(path.as_ref(), false)?;
let tensor_infos = tensor_infos
.into_iter()
.map(|ti| (ti.name.to_string(), ti))
.collect();
let path = path.as_ref().to_owned();
Ok(Self { tensor_infos, path })
}
pub fn tensor_infos(&self) -> &HashMap<String, TensorInfo> {
&self.tensor_infos
}
pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
let tensor_info = match self.tensor_infos.get(name) {
None => return Ok(None),
Some(tensor_info) => tensor_info,
};
// We hope that the file has not changed since first reading it.
let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?);
let mut zip = zip::ZipArchive::new(zip_reader)?;
let mut reader = zip.by_name(&tensor_info.path)?;
// Reading the data is a bit tricky as it can be strided, use an offset, etc.
// For now only support the basic case.
if tensor_info.layout.start_offset() != 0 || !tensor_info.layout.is_contiguous() {
crate::bail!(
"cannot retrieve non-contiguous tensors {:?}",
tensor_info.layout
)
}
let tensor = Tensor::from_reader(
tensor_info.layout.shape().clone(),
tensor_info.dtype,
&mut reader,
)?;
Ok(Some(tensor))
}
}

View File

@ -1,640 +0,0 @@
use super::k_quants::{
BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K,
};
use crate::Result;
use byteorder::{ByteOrder, LittleEndian};
use half::f16;
#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
#[inline(always)]
pub(crate) unsafe fn sum_i16_pairs_float(x: __m256i) -> __m256 {
let ones = _mm256_set1_epi16(1);
let summed_pairs = _mm256_madd_epi16(ones, x);
_mm256_cvtepi32_ps(summed_pairs)
}
#[inline(always)]
pub(crate) unsafe fn mul_sum_us8_pairs_float(ax: __m256i, sy: __m256i) -> __m256 {
let dot = _mm256_maddubs_epi16(ax, sy);
sum_i16_pairs_float(dot)
}
#[inline(always)]
pub(crate) unsafe fn hsum_float_8(x: __m256) -> f32 {
let res = _mm256_extractf128_ps(x, 1);
let res = _mm_add_ps(res, _mm256_castps256_ps128(x));
let res = _mm_add_ps(res, _mm_movehl_ps(res, res));
let res = _mm_add_ss(res, _mm_movehdup_ps(res));
_mm_cvtss_f32(res)
}
#[inline(always)]
pub(crate) unsafe fn bytes_from_nibbles_32(rsi: *const u8) -> __m256i {
let tmp = _mm_loadu_si128(rsi as *const __m128i);
let bytes = _mm256_insertf128_si256::<1>(_mm256_castsi128_si256(tmp), _mm_srli_epi16(tmp, 4));
let low_mask = _mm256_set1_epi8(0xF);
_mm256_and_si256(low_mask, bytes)
}
#[inline(always)]
pub(crate) unsafe fn mul_sum_i8_pairs_float(x: __m256i, y: __m256i) -> __m256 {
let ax = _mm256_sign_epi8(x, x);
let sy = _mm256_sign_epi8(y, x);
mul_sum_us8_pairs_float(ax, sy)
}
#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0;
let nb = n / qk;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
}
if nb % 2 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
}
unsafe {
let mut acc = _mm256_setzero_ps();
for (x, y) in xs.iter().zip(ys.iter()) {
let d = _mm256_set1_ps(f16::to_f32(x.d) * f16::to_f32(y.d));
let bx = bytes_from_nibbles_32(x.qs.as_ptr());
let off = _mm256_set1_epi8(8);
let bx = _mm256_sub_epi8(bx, off);
let by = _mm256_loadu_si256(y.qs.as_ptr() as *const __m256i);
let q = mul_sum_i8_pairs_float(bx, by);
acc = _mm256_fmadd_ps(d, q, acc);
}
Ok(hsum_float_8(acc))
}
}
#[inline(always)]
pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
}
unsafe {
let mut acc = _mm256_setzero_ps();
for (x, y) in xs.iter().zip(ys.iter()) {
let d = _mm256_set1_ps(f16::to_f32(x.d) * f16::to_f32(y.d));
let bx = _mm256_loadu_si256(x.qs.as_ptr() as *const __m256i);
let by = _mm256_loadu_si256(y.qs.as_ptr() as *const __m256i);
let q = mul_sum_i8_pairs_float(bx, by);
acc = _mm256_fmadd_ps(d, q, acc);
}
Ok(hsum_float_8(acc))
}
}
#[inline(always)]
unsafe fn get_scale_shuffle(i: usize) -> __m128i {
const K_SHUFFLE: [u8; 128] = [
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3,
3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7,
7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10,
11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13,
13, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15, 15,
];
_mm_loadu_si128((K_SHUFFLE.as_ptr() as *const __m128i).add(i))
}
#[inline(always)]
unsafe fn get_scale_shuffle_k4(i: usize) -> __m256i {
const K_SHUFFLE: [u8; 256] = [
0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
2, 3, 2, 3, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
6, 7, 6, 7, 6, 7, 6, 7, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10,
11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 12, 13, 12, 13, 12, 13,
12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12,
13, 12, 13, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
];
_mm256_loadu_si256((K_SHUFFLE.as_ptr() as *const __m256i).add(i))
}
#[inline(always)]
unsafe fn get_scale_shuffle_q3k(i: usize) -> __m256i {
const K_SHUFFLE: [u8; 128] = [
0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
2, 3, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
6, 7, 6, 7, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10, 11, 10, 11, 10, 11, 10, 11,
10, 11, 10, 11, 10, 11, 10, 11, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12,
13, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
];
_mm256_loadu_si256((K_SHUFFLE.as_ptr() as *const __m256i).add(i))
}
#[inline(always)]
pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
let qk = QK_K;
if n % qk != 0 {
crate::bail!("vec_dot_q6k_8k: {n} is not divisible by {qk}")
}
unsafe {
let m4 = _mm256_set1_epi8(0xF);
let m2 = _mm256_set1_epi8(3);
let m32s = _mm256_set1_epi8(32);
let mut acc = _mm256_setzero_ps();
for (x, y) in xs.iter().zip(ys.iter()) {
let d = y.d * x.d.to_f32();
let mut q4 = x.ql.as_ptr();
let mut qh = x.qh.as_ptr();
let mut q8 = y.qs.as_ptr();
let scales = _mm_loadu_si128(x.scales.as_ptr() as *const __m128i);
let mut sumi = _mm256_setzero_si256();
for j in 0..QK_K / 128 {
let is = j * 4;
let scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is));
let scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
let scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
let scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
let q4bits1 = _mm256_loadu_si256(q4 as *const __m256i);
q4 = q4.add(32);
let q4bits2 = _mm256_loadu_si256(q4 as *const __m256i);
q4 = q4.add(32);
let q4bits_h = _mm256_loadu_si256(qh as *const __m256i);
qh = qh.add(32);
let q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bits_h, m2), 4);
let q4h_1 =
_mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bits_h, 2), m2), 4);
let q4h_2 =
_mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bits_h, 4), m2), 4);
let q4h_3 =
_mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bits_h, 6), m2), 4);
let q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
let q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1);
let q4_2 =
_mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2);
let q4_3 =
_mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3);
let q8_0 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_1 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_2 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_3 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
let q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
let q8s_2 = _mm256_maddubs_epi16(m32s, q8_2);
let q8s_3 = _mm256_maddubs_epi16(m32s, q8_3);
let p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
let p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
let p16_2 = _mm256_maddubs_epi16(q4_2, q8_2);
let p16_3 = _mm256_maddubs_epi16(q4_3, q8_3);
let p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
let p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
let p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
let p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
let p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
let p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
let p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2);
let p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3);
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3));
}
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
}
Ok(hsum_float_8(acc))
}
}
#[inline(always)]
unsafe fn mm256_set_m128i(a: __m128i, b: __m128i) -> __m256i {
_mm256_insertf128_si256(_mm256_castsi128_si256(b), a, 1)
}
#[inline(always)]
pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
}
unsafe {
let m3 = _mm256_set1_epi8(3);
let m4 = _mm_set1_epi8(0xF);
let mut acc = _mm256_setzero_ps();
for (x, y) in xs.iter().zip(ys.iter()) {
let d = y.d * x.d.to_f32();
let dmin = -y.d * x.dmin.to_f32();
let mut q2 = x.qs.as_ptr();
let mut q8 = y.qs.as_ptr();
let mins_and_scales = _mm_loadu_si128(x.scales.as_ptr() as *const __m128i);
let scales8 = _mm_and_si128(mins_and_scales, m4);
let mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
let mins = _mm256_cvtepi8_epi16(mins8);
let prod =
_mm256_madd_epi16(mins, _mm256_loadu_si256(y.bsums.as_ptr() as *const __m256i));
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc);
let all_scales = _mm256_cvtepi8_epi16(scales8);
let l_scales = _mm256_extracti128_si256(all_scales, 0);
let h_scales = _mm256_extracti128_si256(all_scales, 1);
let scales = [
mm256_set_m128i(l_scales, l_scales),
mm256_set_m128i(h_scales, h_scales),
];
let mut sumi = _mm256_setzero_si256();
for scale in scales {
let q2bits = _mm256_loadu_si256(q2 as *const __m256i);
q2 = q2.add(32);
let q8_0 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_1 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_2 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_3 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q2_0 = _mm256_and_si256(q2bits, m3);
let q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3);
let q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3);
let q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3);
let p0 = _mm256_maddubs_epi16(q2_0, q8_0);
let p1 = _mm256_maddubs_epi16(q2_1, q8_1);
let p2 = _mm256_maddubs_epi16(q2_2, q8_2);
let p3 = _mm256_maddubs_epi16(q2_3, q8_3);
let p0 =
_mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(0)), p0);
let p1 =
_mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(1)), p1);
let p2 =
_mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(2)), p2);
let p3 =
_mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(3)), p3);
let p0 = _mm256_add_epi32(p0, p1);
let p2 = _mm256_add_epi32(p2, p3);
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2));
}
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
}
Ok(hsum_float_8(acc))
}
}
#[inline(always)]
pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}")
}
const KMASK1: u32 = 0x03030303;
const KMASK2: u32 = 0x0f0f0f0f;
let mut aux = [0u32; 3];
unsafe {
let m3 = _mm256_set1_epi8(3);
let mone = _mm256_set1_epi8(1);
let m32 = _mm_set1_epi8(32);
let mut acc = _mm256_setzero_ps();
for (x, y) in xs.iter().zip(ys.iter()) {
let d = y.d * x.d.to_f32();
let mut q3 = x.qs.as_ptr();
let mut q8 = y.qs.as_ptr();
LittleEndian::read_u32_into(&x.scales, &mut aux);
let scales128 = _mm_set_epi32(
(((aux[1] >> 4) & KMASK2) | (((aux[2] >> 6) & KMASK1) << 4)) as i32,
(((aux[0] >> 4) & KMASK2) | (((aux[2] >> 4) & KMASK1) << 4)) as i32,
((aux[1] & KMASK2) | (((aux[2] >> 2) & KMASK1) << 4)) as i32,
((aux[0] & KMASK2) | (((aux[2]) & KMASK1) << 4)) as i32,
);
let scales128 = _mm_sub_epi8(scales128, m32);
let all_scales = _mm256_cvtepi8_epi16(scales128);
let l_scales = _mm256_extracti128_si256(all_scales, 0);
let h_scales = _mm256_extracti128_si256(all_scales, 1);
let scales = [
mm256_set_m128i(l_scales, l_scales),
mm256_set_m128i(h_scales, h_scales),
];
// high bit
let hbits = _mm256_loadu_si256(x.hmask.as_ptr() as *const __m256i);
let mut sumi = _mm256_setzero_si256();
for (j, scale) in scales.iter().enumerate() {
// load low 2 bits
let q3bits = _mm256_loadu_si256(q3 as *const __m256i);
q3 = q3.add(32);
// Prepare low and high bits
// We hardcode the shifts here to avoid loading them into a seperate register
let q3l_0 = _mm256_and_si256(q3bits, m3);
let q3h_0 = if j == 0 {
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 0)), 0)
} else {
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 4)), 4)
};
let q3h_0 = _mm256_slli_epi16(q3h_0, 2);
let q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3);
let q3h_1 = if j == 0 {
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 1)), 1)
} else {
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 5)), 5)
};
let q3h_1 = _mm256_slli_epi16(q3h_1, 2);
let q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3);
let q3h_2 = if j == 0 {
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 2)), 2)
} else {
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 6)), 6)
};
let q3h_2 = _mm256_slli_epi16(q3h_2, 2);
let q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3);
let q3h_3 = if j == 0 {
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 3)), 3)
} else {
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 7)), 7)
};
let q3h_3 = _mm256_slli_epi16(q3h_3, 2);
// load Q8 quants
let q8_0 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_1 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_2 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_3 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
// Dot product: we multiply the 2 low bits and 1 high bit part separately, so we
// can use _mm256_maddubs_epi16, and then subtract. The high bit part has the 2
// already subtracted (and so, it is zero if the high bit was not set, and 2 if the
// high bit was set)
let q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
let q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
let q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2);
let q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3);
let p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
let p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
let p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2);
let p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3);
let p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
let p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
let p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
let p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
// multiply with scales
let p16_0 =
_mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(0)), p16_0);
let p16_1 =
_mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(1)), p16_1);
let p16_2 =
_mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(2)), p16_2);
let p16_3 =
_mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(3)), p16_3);
// accumulate
let p16_0 = _mm256_add_epi32(p16_0, p16_1);
let p16_2 = _mm256_add_epi32(p16_2, p16_3);
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2));
}
// multiply with block scale and accumulate
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
}
Ok(hsum_float_8(acc))
}
}
#[inline(always)]
pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
}
let mut utmp = [0u32; 4];
const KMASK1: u32 = 0x3f3f3f3f;
const KMASK2: u32 = 0x0f0f0f0f;
const KMASK3: u32 = 0x03030303;
unsafe {
let m4 = _mm256_set1_epi8(0xF);
let mut acc = _mm256_setzero_ps();
let mut acc_m = _mm_setzero_ps();
for (x, y) in xs.iter().zip(ys.iter()) {
let d = y.d * x.d.to_f32();
let dmin = -y.d * x.dmin.to_f32();
LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
let uaux = utmp[1] & KMASK1;
utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
utmp[2] = uaux;
utmp[0] &= KMASK1;
let mut q4 = x.qs.as_ptr();
let mut q8 = y.qs.as_ptr();
let mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(
utmp[3] as i32,
utmp[2] as i32,
utmp[1] as i32,
utmp[0] as i32,
));
let q8sums = _mm256_loadu_si256(y.bsums.as_ptr() as *const __m256i);
let q8s = _mm_hadd_epi16(
_mm256_extracti128_si256(q8sums, 0),
_mm256_extracti128_si256(q8sums, 1),
);
let prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);
let sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
let scales = mm256_set_m128i(sc128, sc128);
let mut sumi = _mm256_setzero_si256();
for j in 0..QK_K / 64 {
let scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j));
let scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j + 1));
let q4bits = _mm256_loadu_si256(q4 as *const __m256i);
q4 = q4.add(32);
let q4l = _mm256_and_si256(q4bits, m4);
let q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
let q8l = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let p16l = _mm256_maddubs_epi16(q4l, q8l);
let p16l = _mm256_madd_epi16(scale_l, p16l);
sumi = _mm256_add_epi32(sumi, p16l);
let q8h = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let p16h = _mm256_maddubs_epi16(q4h, q8h);
let p16h = _mm256_madd_epi16(scale_h, p16h);
sumi = _mm256_add_epi32(sumi, p16h);
}
let vd = _mm256_set1_ps(d);
acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
}
let acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
let acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
Ok(hsum_float_8(acc) + _mm_cvtss_f32(acc_m))
}
}
#[inline(always)]
pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}")
}
let mut utmp = [0u32; 4];
const KMASK1: u32 = 0x3f3f3f3f;
const KMASK2: u32 = 0x0f0f0f0f;
const KMASK3: u32 = 0x03030303;
unsafe {
let m4 = _mm256_set1_epi8(0xF);
let mzero = _mm_setzero_si128();
let mone = _mm256_set1_epi8(1);
let mut acc = _mm256_setzero_ps();
let mut summs = 0.0;
for (x, y) in xs.iter().zip(ys.iter()) {
let d = y.d * x.d.to_f32();
let dmin = -y.d * x.dmin.to_f32();
LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
let uaux = utmp[1] & KMASK1;
utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
utmp[2] = uaux;
utmp[0] &= KMASK1;
let mut q5 = x.qs.as_ptr();
let mut q8 = y.qs.as_ptr();
let mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(
utmp[3] as i32,
utmp[2] as i32,
utmp[1] as i32,
utmp[0] as i32,
));
let q8sums = _mm256_loadu_si256(y.bsums.as_ptr() as *const __m256i);
let q8s = _mm_hadd_epi16(
_mm256_extracti128_si256(q8sums, 0),
_mm256_extracti128_si256(q8sums, 1),
);
let prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
let hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
summs += dmin * _mm_extract_epi32(hsum, 0) as f32;
let sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
let scales = mm256_set_m128i(sc128, sc128);
let hbits = _mm256_loadu_si256(x.qh.as_ptr() as *const __m256i);
let mut hmask = mone;
let mut sumi = _mm256_setzero_si256();
for j in 0..QK_K / 64 {
let scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j));
let scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j + 1));
let q5bits = _mm256_loadu_si256(q5 as *const __m256i);
q5 = q5.add(32);
//Similar to q3k we hardcode the shifts here to avoid loading them into a seperate register
let q5l_0 = _mm256_and_si256(q5bits, m4);
let q5l_0_shift_input = _mm256_and_si256(hbits, hmask);
let q5l_0_right_shift = match j {
0 => _mm256_srli_epi16(q5l_0_shift_input, 0),
1 => _mm256_srli_epi16(q5l_0_shift_input, 2),
2 => _mm256_srli_epi16(q5l_0_shift_input, 4),
3 => _mm256_srli_epi16(q5l_0_shift_input, 6),
_ => unreachable!(),
};
let q5h_0 = _mm256_slli_epi16(q5l_0_right_shift, 4);
let q5_0 = _mm256_add_epi8(q5l_0, q5h_0);
hmask = _mm256_slli_epi16(hmask, 1);
let q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
let q5l_1_shift_input = _mm256_and_si256(hbits, hmask);
let q5l_1_right_shift = match j {
0 => _mm256_srli_epi16(q5l_1_shift_input, 1),
1 => _mm256_srli_epi16(q5l_1_shift_input, 3),
2 => _mm256_srli_epi16(q5l_1_shift_input, 5),
3 => _mm256_srli_epi16(q5l_1_shift_input, 7),
_ => unreachable!(),
};
let q5h_1 = _mm256_slli_epi16(q5l_1_right_shift, 4);
let q5_1 = _mm256_add_epi8(q5l_1, q5h_1);
hmask = _mm256_slli_epi16(hmask, 1);
let q8_0 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_1 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let p16_0 = _mm256_maddubs_epi16(q5_0, q8_0);
let p16_1 = _mm256_maddubs_epi16(q5_1, q8_1);
let p16_0 = _mm256_madd_epi16(scale_0, p16_0);
let p16_1 = _mm256_madd_epi16(scale_1, p16_1);
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
}
let vd = _mm256_set1_ps(d);
acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
}
Ok(hsum_float_8(acc) + summs)
}
}

View File

@ -3,7 +3,6 @@
use super::{k_quants, GgmlDType};
use crate::Result;
use byteorder::{LittleEndian, ReadBytesExt};
use std::collections::HashMap;
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.h#L37
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@ -125,7 +124,7 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
let raw_data_ptr = raw_data.as_ptr();
let n_blocks = size_in_bytes / std::mem::size_of::<T>();
let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
super::QTensor::new(data.to_vec(), dims)
Ok(super::QTensor::new(data.to_vec(), dims))
}
/// Creates a [Tensor] from a raw GGML tensor.
@ -164,9 +163,6 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
let ggml_dtype = GgmlDType::from_u32(ggml_dtype)?;
let mut dims = vec![0u32; n_dims as usize];
reader.read_u32_into::<LittleEndian>(&mut dims)?;
// The dimensions are stored in reverse order, see for example:
// https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/convert.py#L969
dims.reverse();
let mut name = vec![0u8; name_len as usize];
reader.read_exact(&mut name)?;
let name = String::from_utf8_lossy(&name).into_owned();
@ -178,6 +174,7 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>();
let tensor_elems = dims.iter().product::<usize>();
let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
println!("{name} {ggml_dtype:?} {dims:?}");
// TODO: Mmap version to avoid copying the data around?
let mut raw_data = vec![0u8; size_in_bytes];
reader.read_exact(&mut raw_data)?;
@ -191,7 +188,7 @@ pub struct Content {
pub magic: VersionedMagic,
pub hparams: HParams,
pub vocab: Vocab,
pub tensors: HashMap<String, super::QTensor>,
pub tensors: Vec<(String, super::QTensor)>,
}
impl Content {
@ -202,11 +199,11 @@ impl Content {
let magic = VersionedMagic::read(reader)?;
let hparams = HParams::read(reader)?;
let vocab = Vocab::read(reader, hparams.n_vocab as usize)?;
let mut tensors = HashMap::new();
let mut tensors = vec![];
while reader.stream_position()? != last_position {
let (name, tensor) = read_one_tensor(reader, magic)?;
tensors.insert(name, tensor);
tensors.push((name, tensor))
}
Ok(Self {
magic,
@ -215,11 +212,4 @@ impl Content {
tensors,
})
}
pub fn remove(&mut self, name: &str) -> Result<super::QTensor> {
match self.tensors.remove(name) {
None => crate::bail!("cannot find tensor with name '{name}'"),
Some(tensor) => Ok(tensor),
}
}
}

View File

@ -1,513 +0,0 @@
//! Support for the GGUF file format.
//!
//! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md
use super::{GgmlDType, QTensor};
use crate::Result;
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::collections::HashMap;
pub const DEFAULT_ALIGNMENT: u64 = 32;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Magic {
Gguf,
}
impl TryFrom<u32> for Magic {
type Error = crate::Error;
fn try_from(value: u32) -> Result<Self> {
let magic = match value {
0x46554747 | 0x47475546 => Self::Gguf,
_ => crate::bail!("unknown magic 0x{value:08x}"),
};
Ok(magic)
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VersionedMagic {
GgufV1,
GgufV2,
}
impl VersionedMagic {
fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> {
let magic = reader.read_u32::<LittleEndian>()?;
let magic = Magic::try_from(magic)?;
let version = reader.read_u32::<LittleEndian>()?;
let versioned_magic = match (magic, version) {
(Magic::Gguf, 1) => Self::GgufV1,
(Magic::Gguf, 2) => Self::GgufV2,
_ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"),
};
Ok(versioned_magic)
}
}
#[derive(Debug)]
pub struct TensorInfo {
pub ggml_dtype: GgmlDType,
pub shape: crate::Shape,
pub offset: u64,
}
impl TensorInfo {
pub fn read<R: std::io::Seek + std::io::Read>(
&self,
reader: &mut R,
tensor_data_offset: u64,
) -> Result<QTensor> {
let tensor_elems = self.shape.elem_count();
let size_in_bytes =
tensor_elems * self.ggml_dtype.type_size() / self.ggml_dtype.blck_size();
let mut raw_data = vec![0u8; size_in_bytes];
reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
reader.read_exact(&mut raw_data)?;
super::ggml_file::qtensor_from_ggml(self.ggml_dtype, &raw_data, self.shape.dims().to_vec())
}
}
#[derive(Debug)]
pub struct Content {
pub magic: VersionedMagic,
pub metadata: HashMap<String, Value>,
pub tensor_infos: HashMap<String, TensorInfo>,
pub tensor_data_offset: u64,
}
fn read_string<R: std::io::Read>(reader: &mut R, magic: &VersionedMagic) -> Result<String> {
let len = match magic {
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
};
let mut v = vec![0u8; len];
reader.read_exact(&mut v)?;
// GGUF strings are supposed to be non-null terminated but in practice this happens.
while let Some(0) = v.last() {
v.pop();
}
// GGUF strings are utf8 encoded but there are cases that don't seem to be valid.
Ok(String::from_utf8_lossy(&v).into_owned())
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ValueType {
// The value is a 8-bit unsigned integer.
U8,
// The value is a 8-bit signed integer.
I8,
// The value is a 16-bit unsigned little-endian integer.
U16,
// The value is a 16-bit signed little-endian integer.
I16,
// The value is a 32-bit unsigned little-endian integer.
U32,
// The value is a 32-bit signed little-endian integer.
I32,
// The value is a 64-bit unsigned little-endian integer.
U64,
// The value is a 64-bit signed little-endian integer.
I64,
// The value is a 32-bit IEEE754 floating point number.
F32,
// The value is a 64-bit IEEE754 floating point number.
F64,
// The value is a boolean.
// 1-byte value where 0 is false and 1 is true.
// Anything else is invalid, and should be treated as either the model being invalid or the reader being buggy.
Bool,
// The value is a UTF-8 non-null-terminated string, with length prepended.
String,
// The value is an array of other values, with the length and type prepended.
///
// Arrays can be nested, and the length of the array is the number of elements in the array, not the number of bytes.
Array,
}
#[derive(Debug, Clone)]
pub enum Value {
U8(u8),
I8(i8),
U16(u16),
I16(i16),
U32(u32),
I32(i32),
U64(u64),
I64(i64),
F32(f32),
F64(f64),
Bool(bool),
String(String),
Array(Vec<Value>),
}
impl Value {
pub fn value_type(&self) -> ValueType {
match self {
Self::U8(_) => ValueType::U8,
Self::I8(_) => ValueType::I8,
Self::U16(_) => ValueType::U16,
Self::I16(_) => ValueType::I16,
Self::U32(_) => ValueType::U32,
Self::I32(_) => ValueType::I32,
Self::U64(_) => ValueType::U64,
Self::I64(_) => ValueType::I64,
Self::F32(_) => ValueType::F32,
Self::F64(_) => ValueType::F64,
Self::Bool(_) => ValueType::Bool,
Self::String(_) => ValueType::String,
Self::Array(_) => ValueType::Array,
}
}
pub fn to_u8(&self) -> Result<u8> {
match self {
Self::U8(v) => Ok(*v),
v => crate::bail!("not a u8 {v:?}"),
}
}
pub fn to_i8(&self) -> Result<i8> {
match self {
Self::I8(v) => Ok(*v),
v => crate::bail!("not a i8 {v:?}"),
}
}
pub fn to_u16(&self) -> Result<u16> {
match self {
Self::U16(v) => Ok(*v),
v => crate::bail!("not a u16 {v:?}"),
}
}
pub fn to_i16(&self) -> Result<i16> {
match self {
Self::I16(v) => Ok(*v),
v => crate::bail!("not a i16 {v:?}"),
}
}
pub fn to_u32(&self) -> Result<u32> {
match self {
Self::U32(v) => Ok(*v),
v => crate::bail!("not a u32 {v:?}"),
}
}
pub fn to_i32(&self) -> Result<i32> {
match self {
Self::I32(v) => Ok(*v),
v => crate::bail!("not a i32 {v:?}"),
}
}
pub fn to_u64(&self) -> Result<u64> {
match self {
Self::U64(v) => Ok(*v),
v => crate::bail!("not a u64 {v:?}"),
}
}
pub fn to_i64(&self) -> Result<i64> {
match self {
Self::I64(v) => Ok(*v),
v => crate::bail!("not a i64 {v:?}"),
}
}
pub fn to_f32(&self) -> Result<f32> {
match self {
Self::F32(v) => Ok(*v),
v => crate::bail!("not a f32 {v:?}"),
}
}
pub fn to_f64(&self) -> Result<f64> {
match self {
Self::F64(v) => Ok(*v),
v => crate::bail!("not a f64 {v:?}"),
}
}
pub fn to_bool(&self) -> Result<bool> {
match self {
Self::Bool(v) => Ok(*v),
v => crate::bail!("not a bool {v:?}"),
}
}
pub fn to_vec(&self) -> Result<&Vec<Value>> {
match self {
Self::Array(v) => Ok(v),
v => crate::bail!("not a vec {v:?}"),
}
}
pub fn to_string(&self) -> Result<&String> {
match self {
Self::String(v) => Ok(v),
v => crate::bail!("not a string {v:?}"),
}
}
fn read<R: std::io::Read>(
reader: &mut R,
value_type: ValueType,
magic: &VersionedMagic,
) -> Result<Self> {
let v = match value_type {
ValueType::U8 => Self::U8(reader.read_u8()?),
ValueType::I8 => Self::I8(reader.read_i8()?),
ValueType::U16 => Self::U16(reader.read_u16::<LittleEndian>()?),
ValueType::I16 => Self::I16(reader.read_i16::<LittleEndian>()?),
ValueType::U32 => Self::U32(reader.read_u32::<LittleEndian>()?),
ValueType::I32 => Self::I32(reader.read_i32::<LittleEndian>()?),
ValueType::U64 => Self::U64(reader.read_u64::<LittleEndian>()?),
ValueType::I64 => Self::I64(reader.read_i64::<LittleEndian>()?),
ValueType::F32 => Self::F32(reader.read_f32::<LittleEndian>()?),
ValueType::F64 => Self::F64(reader.read_f64::<LittleEndian>()?),
ValueType::Bool => match reader.read_u8()? {
0 => Self::Bool(false),
1 => Self::Bool(true),
b => crate::bail!("unexpected bool value {b}"),
},
ValueType::String => Self::String(read_string(reader, magic)?),
ValueType::Array => {
let value_type = reader.read_u32::<LittleEndian>()?;
let value_type = ValueType::from_u32(value_type)?;
let len = match magic {
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
};
let mut vs = Vec::with_capacity(len);
for _ in 0..len {
vs.push(Value::read(reader, value_type, magic)?)
}
Self::Array(vs)
}
};
Ok(v)
}
fn write<W: std::io::Write>(&self, w: &mut W) -> Result<()> {
match self {
&Self::U8(v) => w.write_u8(v)?,
&Self::I8(v) => w.write_i8(v)?,
&Self::U16(v) => w.write_u16::<LittleEndian>(v)?,
&Self::I16(v) => w.write_i16::<LittleEndian>(v)?,
&Self::U32(v) => w.write_u32::<LittleEndian>(v)?,
&Self::I32(v) => w.write_i32::<LittleEndian>(v)?,
&Self::U64(v) => w.write_u64::<LittleEndian>(v)?,
&Self::I64(v) => w.write_i64::<LittleEndian>(v)?,
&Self::F32(v) => w.write_f32::<LittleEndian>(v)?,
&Self::F64(v) => w.write_f64::<LittleEndian>(v)?,
&Self::Bool(v) => w.write_u8(u8::from(v))?,
Self::String(v) => write_string(w, v.as_str())?,
Self::Array(v) => {
// The `Value` type does not enforce that all the values in an Array have the same
// type.
let value_type = if v.is_empty() {
// Doesn't matter, the array is empty.
ValueType::U32
} else {
let value_type: std::collections::HashSet<_> =
v.iter().map(|elem| elem.value_type()).collect();
if value_type.len() != 1 {
crate::bail!("multiple value-types in the same array {value_type:?}")
}
value_type.into_iter().next().unwrap()
};
w.write_u32::<LittleEndian>(value_type.to_u32())?;
w.write_u64::<LittleEndian>(v.len() as u64)?;
for elem in v.iter() {
elem.write(w)?
}
}
}
Ok(())
}
}
impl ValueType {
fn from_u32(v: u32) -> Result<Self> {
let v = match v {
0 => Self::U8,
1 => Self::I8,
2 => Self::U16,
3 => Self::I16,
4 => Self::U32,
5 => Self::I32,
6 => Self::F32,
7 => Self::Bool,
8 => Self::String,
9 => Self::Array,
10 => Self::U64,
11 => Self::I64,
12 => Self::F64,
v => crate::bail!("unrecognized value-type {v:#08x}"),
};
Ok(v)
}
fn to_u32(self) -> u32 {
match self {
Self::U8 => 0,
Self::I8 => 1,
Self::U16 => 2,
Self::I16 => 3,
Self::U32 => 4,
Self::I32 => 5,
Self::F32 => 6,
Self::Bool => 7,
Self::String => 8,
Self::Array => 9,
Self::U64 => 10,
Self::I64 => 11,
Self::F64 => 12,
}
}
}
impl Content {
pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Self> {
let magic = VersionedMagic::read(reader)?;
let tensor_count = match magic {
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
};
let metadata_kv_count = match magic {
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
};
let mut metadata = HashMap::new();
for _idx in 0..metadata_kv_count {
let key = read_string(reader, &magic)?;
let value_type = reader.read_u32::<LittleEndian>()?;
let value_type = ValueType::from_u32(value_type)?;
let value = Value::read(reader, value_type, &magic)?;
metadata.insert(key, value);
}
let mut tensor_infos = HashMap::new();
for _idx in 0..tensor_count {
let tensor_name = read_string(reader, &magic)?;
let n_dimensions = reader.read_u32::<LittleEndian>()?;
let mut dimensions: Vec<usize> = match magic {
VersionedMagic::GgufV1 => {
let mut dimensions = vec![0; n_dimensions as usize];
reader.read_u32_into::<LittleEndian>(&mut dimensions)?;
dimensions.into_iter().map(|c| c as usize).collect()
}
VersionedMagic::GgufV2 => {
let mut dimensions = vec![0; n_dimensions as usize];
reader.read_u64_into::<LittleEndian>(&mut dimensions)?;
dimensions.into_iter().map(|c| c as usize).collect()
}
};
dimensions.reverse();
let ggml_dtype = reader.read_u32::<LittleEndian>()?;
let ggml_dtype = GgmlDType::from_u32(ggml_dtype)?;
let offset = reader.read_u64::<LittleEndian>()?;
tensor_infos.insert(
tensor_name,
TensorInfo {
shape: crate::Shape::from(dimensions),
offset,
ggml_dtype,
},
);
}
let position = reader.stream_position()?;
let alignment = match metadata.get("general.alignment") {
Some(Value::U8(v)) => *v as u64,
Some(Value::U16(v)) => *v as u64,
Some(Value::U32(v)) => *v as u64,
Some(Value::I8(v)) if *v >= 0 => *v as u64,
Some(Value::I16(v)) if *v >= 0 => *v as u64,
Some(Value::I32(v)) if *v >= 0 => *v as u64,
_ => DEFAULT_ALIGNMENT,
};
let tensor_data_offset = (position + alignment - 1) / alignment * alignment;
Ok(Self {
magic,
metadata,
tensor_infos,
tensor_data_offset,
})
}
pub fn tensor<R: std::io::Seek + std::io::Read>(
&self,
reader: &mut R,
name: &str,
) -> Result<QTensor> {
let tensor_info = match self.tensor_infos.get(name) {
Some(tensor_info) => tensor_info,
None => crate::bail!("cannot find tensor-infor for {name}"),
};
tensor_info.read(reader, self.tensor_data_offset)
}
}
fn write_string<W: std::io::Write>(w: &mut W, str: &str) -> Result<()> {
let bytes = str.as_bytes();
w.write_u64::<LittleEndian>(bytes.len() as u64)?;
w.write_all(bytes)?;
Ok(())
}
pub fn write<W: std::io::Seek + std::io::Write>(
w: &mut W,
metadata: &[(&str, &Value)],
tensors: &[(&str, &QTensor)],
) -> Result<()> {
w.write_u32::<LittleEndian>(0x46554747)?;
w.write_u32::<LittleEndian>(2)?; // version 2.
w.write_u64::<LittleEndian>(tensors.len() as u64)?;
w.write_u64::<LittleEndian>(metadata.len() as u64)?;
for (name, value) in metadata.iter() {
write_string(w, name)?;
w.write_u32::<LittleEndian>(value.value_type().to_u32())?;
value.write(w)?;
}
let mut offset = 0usize;
let mut offsets = Vec::with_capacity(tensors.len());
for (name, tensor) in tensors.iter() {
write_string(w, name)?;
let dims = tensor.shape().dims();
w.write_u32::<LittleEndian>(dims.len() as u32)?;
for &dim in dims.iter().rev() {
w.write_u64::<LittleEndian>(dim as u64)?;
}
w.write_u32::<LittleEndian>(tensor.dtype().to_u32())?;
w.write_u64::<LittleEndian>(offset as u64)?;
offsets.push(offset);
let size_in_bytes = tensor.storage_size_in_bytes();
let padding = 31 - (31 + size_in_bytes) % 32;
offset += size_in_bytes + padding;
}
let pos = w.stream_position()? as usize;
let padding = 31 - (31 + pos) % 32;
w.write_all(&vec![0u8; padding])?;
let tensor_start_pos = w.stream_position()? as usize;
for (offset, (_name, tensor)) in offsets.iter().zip(tensors.iter()) {
let pos = w.stream_position()? as usize;
if tensor_start_pos + offset != pos {
crate::bail!(
"internal error, unexpected current position {tensor_start_pos} {offset} {pos}"
)
}
let data_ptr = tensor.as_ptr();
let size_in_bytes = tensor.storage_size_in_bytes();
let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
w.write_all(data)?;
let padding = 31 - (31 + size_in_bytes) % 32;
w.write_all(&vec![0u8; padding])?;
}
Ok(())
}

File diff suppressed because it is too large Load Diff

View File

@ -1,13 +1,7 @@
use crate::{Device, Result, Shape, Tensor};
#[cfg(target_feature = "avx")]
pub mod avx;
pub mod ggml_file;
pub mod gguf_file;
pub mod k_quants;
#[cfg(target_feature = "neon")]
pub mod neon;
pub mod utils;
pub use k_quants::GgmlType;
@ -16,7 +10,7 @@ pub struct QTensor {
shape: Shape,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GgmlDType {
F32,
F16,
@ -56,27 +50,7 @@ impl GgmlDType {
Ok(dtype)
}
pub(crate) fn to_u32(self) -> u32 {
match self {
Self::F32 => 0,
Self::F16 => 1,
Self::Q4_0 => 2,
Self::Q4_1 => 3,
Self::Q5_0 => 6,
Self::Q5_1 => 7,
Self::Q8_0 => 8,
Self::Q8_1 => 9,
Self::Q2K => 10,
Self::Q3K => 11,
Self::Q4K => 12,
Self::Q5K => 13,
Self::Q6K => 14,
Self::Q8K => 15,
}
}
/// The type size for blocks in bytes.
pub fn type_size(&self) -> usize {
fn type_size(&self) -> usize {
use k_quants::*;
match self {
Self::F32 => 4,
@ -97,8 +71,7 @@ impl GgmlDType {
}
}
/// The block size, i.e. the number of elements stored in each block.
pub fn blck_size(&self) -> usize {
fn blck_size(&self) -> usize {
match self {
Self::F32 => 1,
Self::F16 => 1,
@ -118,8 +91,6 @@ pub trait QuantizedType: Send + Sync {
fn dtype(&self) -> GgmlDType;
fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
fn to_float(&self, ys: &mut [f32]) -> Result<()>;
fn storage_size_in_bytes(&self) -> usize;
fn as_ptr(&self) -> *const u8;
}
impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
@ -134,14 +105,6 @@ impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
fn to_float(&self, ys: &mut [f32]) -> Result<()> {
T::to_float(self.as_slice(), ys)
}
fn storage_size_in_bytes(&self) -> usize {
self.len() * std::mem::size_of::<T>()
}
fn as_ptr(&self) -> *const u8 {
self.as_ptr() as *const u8
}
}
impl std::fmt::Debug for QTensor {
@ -150,62 +113,21 @@ impl std::fmt::Debug for QTensor {
}
}
fn check_shape<T: k_quants::GgmlType>(shape: &Shape) -> Result<()> {
let dims = shape.dims();
if dims.is_empty() {
crate::bail!("scalar tensor cannot be quantized {shape:?}")
}
if dims[dims.len() - 1] % T::BLCK_SIZE != 0 {
crate::bail!(
"quantized tensor must have their last dim divisible by block size {shape:?} {}",
T::BLCK_SIZE
)
}
Ok(())
}
impl QTensor {
pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
data: Vec<T>,
shape: S,
) -> Result<Self> {
let shape = shape.into();
check_shape::<T>(&shape)?;
Ok(Self {
) -> Self {
Self {
data: Box::new(data),
shape,
})
}
pub fn quantize<T: k_quants::GgmlType + Send + Sync + 'static>(src: &Tensor) -> Result<Self> {
let shape = src.shape();
check_shape::<T>(shape)?;
let src = src
.to_dtype(crate::DType::F32)?
.flatten_all()?
.to_vec1::<f32>()?;
if src.len() % T::BLCK_SIZE != 0 {
crate::bail!(
"tensor size ({shape:?}) is not divisible by block size {}",
T::BLCK_SIZE
)
shape: shape.into(),
}
let mut data = vec![T::zeros(); src.len() / T::BLCK_SIZE];
T::from_float(&src, &mut data)?;
Ok(Self {
data: Box::new(data),
shape: shape.clone(),
})
}
pub fn dtype(&self) -> GgmlDType {
self.data.dtype()
}
pub fn rank(&self) -> usize {
self.shape.rank()
}
pub fn shape(&self) -> &Shape {
&self.shape
}
@ -219,34 +141,18 @@ impl QTensor {
pub fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
self.data.matmul_t(mkn, lhs, dst)
}
pub fn storage_size_in_bytes(&self) -> usize {
self.data.storage_size_in_bytes()
}
pub fn as_ptr(&self) -> *const u8 {
self.data.as_ptr()
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct QMatMul(std::sync::Arc<QTensor>);
impl QMatMul {
pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Self {
pub fn new(qtensor: std::sync::Arc<QTensor>) -> Self {
Self(qtensor)
}
pub fn from_qtensor(qtensor: QTensor) -> Self {
Self(std::sync::Arc::new(qtensor))
}
pub fn inner(&self) -> &std::sync::Arc<QTensor> {
&self.0
}
}
impl crate::CustomOp1 for QTensor {
impl crate::CustomOp1 for QMatMul {
fn name(&self) -> &'static str {
"qmatmul"
}
@ -260,15 +166,17 @@ impl crate::CustomOp1 for QTensor {
crate::bail!("input tensor is not contiguous {layout:?}")
}
let src_shape = layout.shape();
// self is transposed so n is first then k.
let (n, k) = self.shape.dims2()?;
let (k, n) = self.0.shape.dims2()?;
if src_shape.rank() < 2 {
crate::bail!("input tensor has only one dimension {layout:?}")
}
let mut dst_shape = src_shape.dims().to_vec();
let last_k = dst_shape.pop().unwrap();
if last_k != k {
crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
crate::bail!(
"input tensor {layout:?} incompatible with {:?}",
self.0.shape
)
}
dst_shape.push(n);
let dst_shape = Shape::from(dst_shape);
@ -276,7 +184,7 @@ impl crate::CustomOp1 for QTensor {
let storage =
&storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
let mut dst_storage = vec![0f32; dst_shape.elem_count()];
self.matmul_t(
self.0.matmul_t(
(dst_shape.elem_count() / n, k, n),
storage,
&mut dst_storage,
@ -284,9 +192,3 @@ impl crate::CustomOp1 for QTensor {
Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
}
}
impl QMatMul {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply_op1_no_bwd(self.0.as_ref())
}
}

View File

@ -1,727 +0,0 @@
use super::k_quants::{
BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K,
};
use crate::Result;
use byteorder::{ByteOrder, LittleEndian};
#[allow(unused_imports)]
#[cfg(target_arch = "arm")]
use core::arch::arm::*;
#[allow(unused_imports)]
#[cfg(target_arch = "aarch64")]
use core::arch::aarch64::*;
#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0;
let nb = n / qk;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
}
if nb % 2 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
}
unsafe {
let mut sumv0 = vdupq_n_f32(0.0f32);
let mut sumv1 = vdupq_n_f32(0.0f32);
for i in (0..nb).step_by(2) {
let x0 = &xs[i];
let x1 = &xs[i + 1];
let y0 = &ys[i];
let y1 = &ys[i + 1];
let m4b = vdupq_n_u8(0x0F);
let s8b = vdupq_n_s8(0x8);
let v0_0 = vld1q_u8(x0.qs.as_ptr());
let v0_1 = vld1q_u8(x1.qs.as_ptr());
// 4-bit -> 8-bit
let v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
let v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
let v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
let v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
// sub 8
let v0_0ls = vsubq_s8(v0_0l, s8b);
let v0_0hs = vsubq_s8(v0_0h, s8b);
let v0_1ls = vsubq_s8(v0_1l, s8b);
let v0_1hs = vsubq_s8(v0_1h, s8b);
// load y
let v1_0l = vld1q_s8(y0.qs.as_ptr());
let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));
let v1_1l = vld1q_s8(y1.qs.as_ptr());
let v1_1h = vld1q_s8(y1.qs.as_ptr().add(16));
// TODO: Support dotprod when it's available outside of nightly.
let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l));
let pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h));
let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
let pl1l = vmull_s8(vget_low_s8(v0_1ls), vget_low_s8(v1_1l));
let pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
let ph1l = vmull_s8(vget_low_s8(v0_1hs), vget_low_s8(v1_1h));
let ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
let pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
let ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
sumv0 = vmlaq_n_f32(
sumv0,
vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
x0.d.to_f32() * y0.d.to_f32(),
);
sumv1 = vmlaq_n_f32(
sumv1,
vcvtq_f32_s32(vaddq_s32(pl1, ph1)),
x1.d.to_f32() * y1.d.to_f32(),
);
}
Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1))
}
}
#[inline(always)]
pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
}
let nb = n / QK8_0;
if nb % 2 != 0 {
crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even")
}
unsafe {
let mut sumv0 = vdupq_n_f32(0.0f32);
let mut sumv1 = vdupq_n_f32(0.0f32);
for i in (0..nb).step_by(2) {
let x0 = &xs[i];
let x1 = &xs[i + 1];
let y0 = &ys[i];
let y1 = &ys[i + 1];
let x0_0 = vld1q_s8(x0.qs.as_ptr());
let x0_1 = vld1q_s8(x0.qs.as_ptr().add(16));
let x1_0 = vld1q_s8(x1.qs.as_ptr());
let x1_1 = vld1q_s8(x1.qs.as_ptr().add(16));
// load y
let y0_0 = vld1q_s8(y0.qs.as_ptr());
let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));
let y1_0 = vld1q_s8(y1.qs.as_ptr());
let y1_1 = vld1q_s8(y1.qs.as_ptr().add(16));
// TODO dotprod once this is the intrinsics are.
let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0));
let p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1));
let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
let p1_0 = vmull_s8(vget_low_s8(x1_0), vget_low_s8(y1_0));
let p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
let p1_2 = vmull_s8(vget_low_s8(x1_1), vget_low_s8(y1_1));
let p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
let p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
let p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
sumv0 = vmlaq_n_f32(
sumv0,
vcvtq_f32_s32(vaddq_s32(p0, p1)),
x0.d.to_f32() * y0.d.to_f32(),
);
sumv1 = vmlaq_n_f32(
sumv1,
vcvtq_f32_s32(vaddq_s32(p2, p3)),
x1.d.to_f32() * y1.d.to_f32(),
);
}
Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1))
}
}
#[inline(always)]
pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}")
}
let mut sum = 0f32;
unsafe {
let m4b = vdupq_n_u8(0xF);
let mone = vdupq_n_u8(3);
for (x, y) in xs.iter().zip(ys.iter()) {
let d_all = x.d.to_f32();
let mut q6 = x.ql.as_ptr();
let mut qh = x.qh.as_ptr();
let mut q8 = y.qs.as_ptr();
let mut scale = x.scales.as_ptr();
let q8sums = vld1q_s16_x2(y.bsums.as_ptr());
let scales = vld1q_s8(scale);
let q6scales = int16x8x2_t(
vmovl_s8(vget_low_s8(scales)),
vmovl_s8(vget_high_s8(scales)),
);
let prod = vaddq_s32(
vaddq_s32(
vmull_s16(vget_low_s16(q8sums.0), vget_low_s16(q6scales.0)),
vmull_s16(vget_high_s16(q8sums.0), vget_high_s16(q6scales.0)),
),
vaddq_s32(
vmull_s16(vget_low_s16(q8sums.1), vget_low_s16(q6scales.1)),
vmull_s16(vget_high_s16(q8sums.1), vget_high_s16(q6scales.1)),
),
);
let isum_mins = vaddvq_s32(prod);
let mut isum = 0i32;
for _j in 0..QK_K / 128 {
let qhbits = vld1q_u8_x2(qh);
qh = qh.add(32);
let q6bits = vld1q_u8_x4(q6);
q6 = q6.add(64);
let q8bytes = vld1q_s8_x4(q8);
q8 = q8.add(64);
let q6h_0 = vshlq_n_u8(vandq_u8(mone, qhbits.0), 4);
let q6h_1 = vshlq_n_u8(vandq_u8(mone, qhbits.1), 4);
let shifted = vshrq_n_u8(qhbits.0, 2);
let q6h_2 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
let shifted = vshrq_n_u8(qhbits.1, 2);
let q6h_3 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
let q6bytes_0 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.0, m4b), q6h_0));
let q6bytes_1 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.1, m4b), q6h_1));
let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.2, m4b), q6h_2));
let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.3, m4b), q6h_3));
// TODO: dotprod
let p0 = vaddq_s16(
vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
);
let p1 = vaddq_s16(
vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
);
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
scale = scale.add(2);
let p2 = vaddq_s16(
vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
);
let p3 = vaddq_s16(
vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
);
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
scale = scale.add(2);
let q8bytes = vld1q_s8_x4(q8);
q8 = q8.add(64);
let shifted = vshrq_n_u8(qhbits.0, 4);
let q6h_0 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
let shifted = vshrq_n_u8(qhbits.1, 4);
let q6h_1 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
let shifted = vshrq_n_u8(qhbits.0, 6);
let q6h_2 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
let shifted = vshrq_n_u8(qhbits.1, 6);
let q6h_3 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
let q6bytes_0 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.0, 4), q6h_0));
let q6bytes_1 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.1, 4), q6h_1));
let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.2, 4), q6h_2));
let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.3, 4), q6h_3));
// TODO: dotprod case.
let p0 = vaddq_s16(
vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
);
let p1 = vaddq_s16(
vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
);
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
scale = scale.add(2);
let p2 = vaddq_s16(
vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
);
let p3 = vaddq_s16(
vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
);
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
scale = scale.add(2);
}
sum += d_all * y.d * ((isum - 32 * isum_mins) as f32);
}
}
Ok(sum)
}
#[inline(always)]
pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}")
}
let mut sumf = 0f32;
let mut utmp = [0u32; 4];
const KMASK1: u32 = 0x3f3f3f3f;
const KMASK2: u32 = 0x0f0f0f0f;
const KMASK3: u32 = 0x03030303;
unsafe {
let m4b = vdupq_n_u8(0xF);
let mone = vdupq_n_u8(1);
let mtwo = vdupq_n_u8(2);
for (x, y) in xs.iter().zip(ys.iter()) {
let d = y.d * x.d.to_f32();
let dmin = y.d * x.dmin.to_f32();
let q8sums = vpaddq_s16(
vld1q_s16(y.bsums.as_ptr()),
vld1q_s16(y.bsums.as_ptr().add(8)),
);
LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
let uaux = utmp[1] & KMASK1;
utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
utmp[2] = uaux;
utmp[0] &= KMASK1;
let mins8 = vld1_u8((utmp.as_ptr() as *const u8).add(8));
let mins = vreinterpretq_s16_u16(vmovl_u8(mins8));
let prod = vaddq_s32(
vmull_s16(vget_low_s16(q8sums), vget_low_s16(mins)),
vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)),
);
let sumi_mins = vaddvq_s32(prod);
let mut scales = utmp.as_ptr() as *const u8;
let mut q5 = x.qs.as_ptr();
let mut q8 = y.qs.as_ptr();
let mut qhbits = vld1q_u8_x2(x.qh.as_ptr());
let mut sumi = 0i32;
for _j in 0..QK_K / 64 {
let q5bits = vld1q_u8_x2(q5);
q5 = q5.add(32);
let q8bytes = vld1q_s8_x4(q8);
q8 = q8.add(64);
let q5h_0 = vshlq_n_u8(vandq_u8(mone, qhbits.0), 4);
let q5h_1 = vshlq_n_u8(vandq_u8(mone, qhbits.1), 4);
let q5h_2 = vshlq_n_u8(vandq_u8(mtwo, qhbits.0), 3);
let q5h_3 = vshlq_n_u8(vandq_u8(mtwo, qhbits.1), 3);
qhbits.0 = vshrq_n_u8(qhbits.0, 2);
qhbits.1 = vshrq_n_u8(qhbits.1, 2);
let q5bytes_0 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.0, m4b), q5h_0));
let q5bytes_1 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.1, m4b), q5h_1));
let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2));
let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3));
// TODO: dotprod
let p0 = vaddq_s16(
vmull_s8(vget_low_s8(q5bytes_0), vget_low_s8(q8bytes.0)),
vmull_s8(vget_high_s8(q5bytes_0), vget_high_s8(q8bytes.0)),
);
let p1 = vaddq_s16(
vmull_s8(vget_low_s8(q5bytes_1), vget_low_s8(q8bytes.1)),
vmull_s8(vget_high_s8(q5bytes_1), vget_high_s8(q8bytes.1)),
);
sumi += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * *scales as i32;
scales = scales.add(1);
let p2 = vaddq_s16(
vmull_s8(vget_low_s8(q5bytes_2), vget_low_s8(q8bytes.2)),
vmull_s8(vget_high_s8(q5bytes_2), vget_high_s8(q8bytes.2)),
);
let p3 = vaddq_s16(
vmull_s8(vget_low_s8(q5bytes_3), vget_low_s8(q8bytes.3)),
vmull_s8(vget_high_s8(q5bytes_3), vget_high_s8(q8bytes.3)),
);
sumi += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * *scales as i32;
scales = scales.add(1);
}
sumf += d * sumi as f32 - dmin * sumi_mins as f32;
}
}
Ok(sumf)
}
#[inline(always)]
pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
}
let mut sumf = 0f32;
let mut utmp = [0u32; 4];
let mut scales = [0u8; 16];
const KMASK1: u32 = 0x3f3f3f3f;
const KMASK2: u32 = 0x0f0f0f0f;
const KMASK3: u32 = 0x03030303;
unsafe {
let m4b = vdupq_n_u8(0xF);
for (x, y) in xs.iter().zip(ys.iter()) {
let d = y.d * x.d.to_f32();
let dmin = y.d * x.dmin.to_f32();
let q8sums = vpaddq_s16(
vld1q_s16(y.bsums.as_ptr()),
vld1q_s16(y.bsums.as_ptr().add(8)),
);
LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
let mins8 = vld1_u32(
[
utmp[1] & KMASK1,
((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4),
]
.as_ptr(),
);
utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
utmp[0] &= KMASK1;
let mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
let prod = vaddq_s32(
vmull_s16(vget_low_s16(q8sums), vget_low_s16(mins)),
vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)),
);
sumf -= dmin * vaddvq_s32(prod) as f32;
LittleEndian::write_u32_into(&utmp, &mut scales);
let mut q4 = x.qs.as_ptr();
let mut q8 = y.qs.as_ptr();
let mut sumi1 = 0i32;
let mut sumi2 = 0i32;
for j in 0..QK_K / 64 {
let q4bits = vld1q_u8_x2(q4);
q4 = q4.add(32);
// TODO: dotprod
let q8bytes = vld1q_s8_x2(q8);
q8 = q8.add(32);
let q4bytes = int8x16x2_t(
vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)),
vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)),
);
let p0 = vaddq_s16(
vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
);
let p1 = vaddq_s16(
vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
);
sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * scales[2 * j] as i32;
let q8bytes = vld1q_s8_x2(q8);
q8 = q8.add(32);
let q4bytes = int8x16x2_t(
vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)),
vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)),
);
let p2 = vaddq_s16(
vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
);
let p3 = vaddq_s16(
vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
);
sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * scales[2 * j + 1] as i32;
}
sumf += d * (sumi1 + sumi2) as f32;
}
}
Ok(sumf)
}
#[inline(always)]
pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}")
}
let mut sumf = 0f32;
let mut utmp = [0u32; 4];
let mut aux = [0u32; 3];
const KMASK1: u32 = 0x03030303;
const KMASK2: u32 = 0x0f0f0f0f;
unsafe {
let m3b = vdupq_n_u8(0x3);
let m0 = vdupq_n_u8(1);
let m1 = vshlq_n_u8(m0, 1);
let m2 = vshlq_n_u8(m0, 2);
let m3 = vshlq_n_u8(m0, 3);
for (x, y) in xs.iter().zip(ys.iter()) {
let d = y.d * x.d.to_f32();
let mut q3 = x.qs.as_ptr();
let qh = x.hmask.as_ptr();
let mut q8 = y.qs.as_ptr();
let mut qhbits = vld1q_u8_x2(qh);
let mut isum = 0i32;
// Set up scales
LittleEndian::read_u32_into(&x.scales, &mut aux);
utmp[3] = ((aux[1] >> 4) & KMASK2) | (((aux[2] >> 6) & KMASK1) << 4);
utmp[2] = ((aux[0] >> 4) & KMASK2) | (((aux[2] >> 4) & KMASK1) << 4);
utmp[1] = (aux[1] & KMASK2) | (((aux[2] >> 2) & KMASK1) << 4);
utmp[0] = (aux[0] & KMASK2) | ((aux[2] & KMASK1) << 4);
let mut scale = utmp.as_mut_ptr() as *mut i8;
for j in 0..16 {
*scale.add(j) -= 32i8
}
for j in 0..QK_K / 128 {
let q3bits = vld1q_u8_x2(q3);
q3 = q3.add(32);
let q8bytes_1 = vld1q_s8_x4(q8);
q8 = q8.add(64);
let q8bytes_2 = vld1q_s8_x4(q8);
q8 = q8.add(64);
let q3h_0 = vshlq_n_u8(vbicq_u8(m0, qhbits.0), 2);
let q3h_1 = vshlq_n_u8(vbicq_u8(m0, qhbits.1), 2);
let q3h_2 = vshlq_n_u8(vbicq_u8(m1, qhbits.0), 1);
let q3h_3 = vshlq_n_u8(vbicq_u8(m1, qhbits.1), 1);
let q3bytes_0 = vsubq_s8(
vreinterpretq_s8_u8(vandq_u8(q3bits.0, m3b)),
vreinterpretq_s8_u8(q3h_0),
);
let q3bytes_1 = vsubq_s8(
vreinterpretq_s8_u8(vandq_u8(q3bits.1, m3b)),
vreinterpretq_s8_u8(q3h_1),
);
let q3bytes_2 = vsubq_s8(
vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 2), m3b)),
vreinterpretq_s8_u8(q3h_2),
);
let q3bytes_3 = vsubq_s8(
vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 2), m3b)),
vreinterpretq_s8_u8(q3h_3),
);
// TODO: dotprod
let p0 = vaddq_s16(
vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_1.0)),
vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_1.0)),
);
let p1 = vaddq_s16(
vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_1.1)),
vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_1.1)),
);
let p2 = vaddq_s16(
vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_1.2)),
vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_1.2)),
);
let p3 = vaddq_s16(
vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_1.3)),
vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_1.3)),
);
isum += vaddvq_s16(p0) as i32 * *scale as i32
+ vaddvq_s16(p1) as i32 * *scale.add(1) as i32
+ vaddvq_s16(p2) as i32 * *scale.add(2) as i32
+ vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
scale = scale.add(4);
let q3h_0 = vbicq_u8(m2, qhbits.0);
let q3h_1 = vbicq_u8(m2, qhbits.1);
let q3h_2 = vshrq_n_u8(vbicq_u8(m3, qhbits.0), 1);
let q3h_3 = vshrq_n_u8(vbicq_u8(m3, qhbits.1), 1);
let q3bytes_0 = vsubq_s8(
vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 4), m3b)),
vreinterpretq_s8_u8(q3h_0),
);
let q3bytes_1 = vsubq_s8(
vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 4), m3b)),
vreinterpretq_s8_u8(q3h_1),
);
let q3bytes_2 = vsubq_s8(
vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 6), m3b)),
vreinterpretq_s8_u8(q3h_2),
);
let q3bytes_3 = vsubq_s8(
vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 6), m3b)),
vreinterpretq_s8_u8(q3h_3),
);
// TODO: dotprod
let p0 = vaddq_s16(
vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_2.0)),
vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_2.0)),
);
let p1 = vaddq_s16(
vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_2.1)),
vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_2.1)),
);
let p2 = vaddq_s16(
vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_2.2)),
vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_2.2)),
);
let p3 = vaddq_s16(
vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_2.3)),
vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_2.3)),
);
isum += vaddvq_s16(p0) as i32 * *scale as i32
+ vaddvq_s16(p1) as i32 * *scale.add(1) as i32
+ vaddvq_s16(p2) as i32 * *scale.add(2) as i32
+ vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
scale = scale.add(4);
if j == 0 {
qhbits.0 = vshrq_n_u8(qhbits.0, 4);
qhbits.1 = vshrq_n_u8(qhbits.1, 4);
}
}
sumf += d * isum as f32;
}
}
Ok(sumf)
}
#[inline(always)]
pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
}
let mut sumf = 0f32;
let mut aux = [0u8; 16];
unsafe {
let m3 = vdupq_n_u8(0x3);
let m4 = vdupq_n_u8(0xF);
for (x, y) in xs.iter().zip(ys.iter()) {
let d = y.d * x.d.to_f32();
let dmin = -y.d * x.dmin.to_f32();
let mut q2 = x.qs.as_ptr();
let mut q8 = y.qs.as_ptr();
let sc = x.scales.as_ptr();
let mins_and_scales = vld1q_u8(sc);
let scales = vandq_u8(mins_and_scales, m4);
vst1q_u8(aux.as_mut_ptr(), scales);
let mins = vshrq_n_u8(mins_and_scales, 4);
let q8sums = vld1q_s16_x2(y.bsums.as_ptr());
let mins16 = int16x8x2_t(
vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))),
vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins))),
);
let s0 = vaddq_s32(
vmull_s16(vget_low_s16(mins16.0), vget_low_s16(q8sums.0)),
vmull_s16(vget_high_s16(mins16.0), vget_high_s16(q8sums.0)),
);
let s1 = vaddq_s32(
vmull_s16(vget_low_s16(mins16.1), vget_low_s16(q8sums.1)),
vmull_s16(vget_high_s16(mins16.1), vget_high_s16(q8sums.1)),
);
sumf += dmin * vaddvq_s32(vaddq_s32(s0, s1)) as f32;
let mut isum = 0i32;
let mut is = 0usize;
// TODO: dotprod
for _j in 0..QK_K / 128 {
let q2bits = vld1q_u8_x2(q2);
q2 = q2.add(32);
let q8bytes = vld1q_s8_x2(q8);
q8 = q8.add(32);
let mut q2bytes = int8x16x2_t(
vreinterpretq_s8_u8(vandq_u8(q2bits.0, m3)),
vreinterpretq_s8_u8(vandq_u8(q2bits.1, m3)),
);
isum += multiply_accum_with_scale(&aux, is, 0, q2bytes, q8bytes);
let q8bytes = vld1q_s8_x2(q8);
q8 = q8.add(32);
q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 2), m3));
q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 2), m3));
isum += multiply_accum_with_scale(&aux, is, 2, q2bytes, q8bytes);
let q8bytes = vld1q_s8_x2(q8);
q8 = q8.add(32);
q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 4), m3));
q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 4), m3));
isum += multiply_accum_with_scale(&aux, is, 4, q2bytes, q8bytes);
let q8bytes = vld1q_s8_x2(q8);
q8 = q8.add(32);
q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 6), m3));
q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 6), m3));
isum += multiply_accum_with_scale(&aux, is, 6, q2bytes, q8bytes);
is += 8;
}
sumf += d * isum as f32;
}
}
Ok(sumf)
}
#[inline(always)]
unsafe fn multiply_accum_with_scale(
aux: &[u8; 16],
is: usize,
index: usize,
q2bytes: int8x16x2_t,
q8bytes: int8x16x2_t,
) -> i32 {
let p1 = vaddq_s16(
vmull_s8(vget_low_s8(q2bytes.0), vget_low_s8(q8bytes.0)),
vmull_s8(vget_high_s8(q2bytes.0), vget_high_s8(q8bytes.0)),
);
let p2 = vaddq_s16(
vmull_s8(vget_low_s8(q2bytes.1), vget_low_s8(q8bytes.1)),
vmull_s8(vget_high_s8(q2bytes.1), vget_high_s8(q8bytes.1)),
);
vaddvq_s16(p1) as i32 * aux[is + index] as i32
+ vaddvq_s16(p2) as i32 * aux[is + 1 + index] as i32
}

View File

@ -1,326 +0,0 @@
use crate::Result;
pub(super) fn nearest_int(v: f32) -> i32 {
v.round() as i32
}
/// Validates that the input and output are the right size and returns an iterator which maps each
/// input region `xs` to its corresponding output block in `ys`. Each output region is guaranteed
/// to be `T::BLCK_SIZE` long.
pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>(
xs: &'b [f32],
ys: &'a mut [T],
) -> Result<Vec<(&'a mut T, &'b [f32])>> {
let block_size = T::BLCK_SIZE;
let dtype = T::DTYPE;
let expected_blocks = xs.len() / block_size;
let actual_blocks = ys.len();
//validate that the input is the right size
if expected_blocks != actual_blocks {
crate::bail!("quantize {dtype:?}: expected {expected_blocks} blocks but only {actual_blocks} were provided!")
}
Ok(ys.iter_mut().zip(xs.chunks_exact(block_size)).collect())
}
/// Validates that the input and output are the right size and returns an iterator which maps each
/// input block `xs` to its corresponding output region in `ys`. Each output region is guaranteed
/// to be `T::BLCK_SIZE` long.
pub(super) fn group_for_dequantization<'a, 'b, T: super::k_quants::GgmlType>(
xs: &'a [T],
ys: &'b mut [f32],
) -> Result<Vec<(&'a T, &'b mut [f32])>> {
let block_size = T::BLCK_SIZE;
let dtype = T::DTYPE;
let actual_output_len = ys.len();
let expected_output_len = xs.len() * block_size;
//validate that the output is the right size
if expected_output_len != actual_output_len {
crate::bail!("dequantize {dtype:?}: ys (len = {actual_output_len}) does not match the expected length of {expected_output_len}!")
}
//zip the blocks and outputs together
Ok(xs.iter().zip(ys.chunks_exact_mut(block_size)).collect())
}
pub(super) fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) {
if j < 4 {
let d = q[j] & 63;
let m = q[j + 4] & 63;
(d, m)
} else {
let d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
let m = (q[j + 4] >> 4) | ((q[j] >> 6) << 4);
(d, m)
}
}
pub(super) unsafe fn make_qx_quants(
n: usize,
nmax: i32,
x: *const f32,
ls: *mut i8,
rmse_type: i32,
) -> f32 {
let mut max = 0f32;
let mut amax = 0f32;
for i in 0..n {
let x = *x.add(i);
let ax = x.abs();
if ax > amax {
amax = ax;
max = x;
}
}
if amax == 0. {
// all zero
for i in 0..n {
*ls.add(i) = 0;
}
return 0.;
}
let mut iscale = -(nmax as f32) / max;
if rmse_type == 0 {
for i in 0..n {
let x = *x.add(i);
let l = nearest_int(iscale * x);
*ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8;
}
return 1.0 / iscale;
}
let weight_type = rmse_type % 2;
let mut sumlx = 0f32;
let mut suml2 = 0f32;
for i in 0..n {
let x = *x.add(i);
let l = nearest_int(iscale * x);
let l = l.clamp(-nmax, nmax - 1);
*ls.add(i) = (l + nmax) as i8;
let w = if weight_type == 1 { x * x } else { 1.0 };
let l = l as f32;
sumlx += w * x * l;
suml2 += w * l * l;
}
let mut scale = sumlx / suml2;
let mut best = scale * sumlx;
for _itry in 0..3 {
let iscale = 1.0 / scale;
let mut slx = 0f32;
let mut sl2 = 0f32;
let mut changed = false;
for i in 0..n {
let x = *x.add(i);
let l = nearest_int(iscale * x);
let l = l.clamp(-nmax, nmax - 1);
if l + nmax != *ls.add(i) as i32 {
changed = true;
}
let w = if weight_type == 1 { x * x } else { 1f32 };
let l = l as f32;
slx += w * x * l;
sl2 += w * l * l;
}
if !changed || sl2 == 0.0 || slx * slx <= best * sl2 {
break;
}
for i in 0..n {
let x = *x.add(i);
let l = nearest_int(iscale * x);
*ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8;
}
sumlx = slx;
suml2 = sl2;
scale = sumlx / suml2;
best = scale * sumlx;
}
for _itry in 0..5 {
let mut n_changed = 0;
for i in 0..n {
let x = *x.add(i);
let w = if weight_type == 1 { x * x } else { 1. };
let l = *ls.add(i) as i32 - nmax;
let mut slx = sumlx - w * x * l as f32;
if slx > 0. {
let mut sl2 = suml2 - w * l as f32 * l as f32;
let new_l = nearest_int(x * sl2 / slx);
let new_l = new_l.clamp(-nmax, nmax - 1);
if new_l != l {
slx += w * x * new_l as f32;
sl2 += w * new_l as f32 * new_l as f32;
if sl2 > 0. && slx * slx * suml2 > sumlx * sumlx * sl2 {
*ls.add(i) = (nmax + new_l) as i8;
sumlx = slx;
suml2 = sl2;
scale = sumlx / suml2;
best = scale * sumlx;
n_changed += 1;
}
}
}
}
if n_changed == 0 {
break;
}
}
if rmse_type < 3 {
return scale;
}
for is in -4..4 {
if is == 0 {
continue;
}
iscale = -(nmax as f32 + 0.1f32 * is as f32) / max;
let mut sumlx = 0.;
let mut suml2 = 0.;
for i in 0..n {
let x = *x.add(i);
let l = nearest_int(iscale * x);
let l = l.clamp(-nmax, nmax - 1);
let w = if weight_type == 1 { x * x } else { 1. };
let l = l as f32;
sumlx += w * x * l;
suml2 += w * l * l;
}
if suml2 > 0. && sumlx * sumlx > best * suml2 {
for i in 0..n {
let x = *x.add(i);
let l = nearest_int(iscale * x);
*ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8;
}
scale = sumlx / suml2;
best = scale * sumlx;
}
}
scale
}
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L224
pub(super) fn make_qkx1_quants(nmax: i32, ntry: usize, x: &[f32]) -> (f32, f32) {
let n = x.len();
let mut l = vec![0; n];
// Get min/max
let min = *x
.iter()
.take(n)
.min_by(|a, b| a.total_cmp(b))
.unwrap_or(&x[0]);
let max = *x.iter().max_by(|a, b| a.total_cmp(b)).unwrap_or(&x[0]);
// If min == max, all values are the same => nothing to do here
if max == min {
return (0.0, 0.0);
}
// Ensure min <= 0.0
let mut min = min.min(0.);
// Compute scale and inverse scale
let mut iscale = nmax as f32 / (max - min);
let mut scale = 1.0 / iscale;
for _ in 0..ntry {
let mut sumlx = 0.0;
let mut suml2 = 0;
let mut did_change = false;
for (i, value) in x.iter().enumerate().take(n) {
let li = nearest_int(iscale * (value - min)).clamp(0, nmax);
let clamped_li = li as u8;
if clamped_li != l[i] {
l[i] = clamped_li;
did_change = true;
}
sumlx += (value - min) * li as f32;
suml2 += li * li;
}
scale = sumlx / suml2 as f32;
let sum: f32 = x
.iter()
.take(n)
.zip(l.iter().take(n))
.map(|(xi, &li)| xi - scale * li as f32)
.sum();
min = sum / n as f32;
if min > 0.0 {
min = 0.0;
}
iscale = 1.0 / scale;
if !did_change {
break;
}
}
(scale, -min)
}
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L165
pub(super) fn make_q3_quants(x: &[f32], nmax: i32, do_rmse: bool) -> f32 {
let n = x.len();
let mut l = vec![0i8; n];
let mut max = 0.0;
let mut amax = 0.0;
for &xi in x.iter().take(n) {
let ax = xi.abs();
if ax > amax {
amax = ax;
max = xi;
}
}
if amax == 0.0 {
return 0.0;
}
let iscale = -(nmax as f32) / max;
if do_rmse {
let mut sumlx = 0.0;
let mut suml2 = 0.0;
for i in 0..n {
let li = (iscale * x[i]).round() as i32;
let li = li.clamp(-nmax, nmax - 1);
l[i] = li as i8;
let w = x[i] * x[i];
sumlx += w * x[i] * li as f32;
suml2 += w * (li * li) as f32;
}
for _ in 0..5 {
let mut n_changed = 0;
for i in 0..n {
let w = x[i] * x[i];
let mut slx = sumlx - w * x[i] * l[i] as f32;
if slx > 0.0 {
let mut sl2 = suml2 - w * (l[i] as i32 * l[i] as i32) as f32;
let mut new_l = (x[i] * sl2 / slx).round() as i32;
new_l = new_l.clamp(-nmax, nmax - 1);
if new_l != l[i] as i32 {
slx += w * x[i] * new_l as f32;
sl2 += w * (new_l * new_l) as f32;
if sl2 > 0.0 && slx * slx * suml2 > sumlx * sumlx * sl2 {
l[i] = new_l as i8;
sumlx = slx;
suml2 = sl2;
n_changed += 1;
}
}
}
}
if n_changed == 0 {
break;
}
}
for li in l.iter_mut() {
*li += nmax as i8;
}
return sumlx / suml2;
}
for i in 0..n {
let li = (iscale * x[i]).round() as i32;
l[i] = (li.clamp(-nmax, nmax - 1) + nmax) as i8;
}
1.0 / iscale
}

View File

@ -10,7 +10,6 @@ impl From<DType> for st::Dtype {
match value {
DType::U8 => st::Dtype::U8,
DType::U32 => st::Dtype::U32,
DType::I64 => st::Dtype::I64,
DType::BF16 => st::Dtype::BF16,
DType::F16 => st::Dtype::F16,
DType::F32 => st::Dtype::F32,
@ -25,7 +24,6 @@ impl TryFrom<st::Dtype> for DType {
match value {
st::Dtype::U8 => Ok(DType::U8),
st::Dtype::U32 => Ok(DType::U32),
st::Dtype::I64 => Ok(DType::I64),
st::Dtype::BF16 => Ok(DType::BF16),
st::Dtype::F16 => Ok(DType::F16),
st::Dtype::F32 => Ok(DType::F32),
@ -191,7 +189,6 @@ impl Tensor {
match dtype {
DType::U8 => convert_slice::<u8>(data, shape, device),
DType::U32 => convert_slice::<u32>(data, shape, device),
DType::I64 => convert_slice::<i64>(data, shape, device),
DType::BF16 => convert_slice::<half::bf16>(data, shape, device),
DType::F16 => convert_slice::<half::f16>(data, shape, device),
DType::F32 => convert_slice::<f32>(data, shape, device),
@ -208,15 +205,24 @@ fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
convert_with_cast_::<u16, u32, _>(view, device, conv)
}
st::Dtype::U32 => convert_::<u32>(view, device),
st::Dtype::I32 => {
let conv = |x| Ok(i64::from(x));
convert_with_cast_::<i32, i64, _>(view, device, conv)
}
st::Dtype::I64 => convert_::<i64>(view, device),
st::Dtype::BF16 => convert_::<half::bf16>(view, device),
st::Dtype::F16 => convert_::<half::f16>(view, device),
st::Dtype::F32 => convert_::<f32>(view, device),
st::Dtype::F64 => convert_::<f64>(view, device),
st::Dtype::I32 => {
let conv = |x| {
u32::try_from(x)
.map_err(|_| Error::Msg(format!("out of bounds value for u32: {x}")))
};
convert_with_cast_::<i32, u32, _>(view, device, conv)
}
st::Dtype::I64 => {
let conv = |x| {
u32::try_from(x)
.map_err(|_| Error::Msg(format!("out of bounds value for u32: {x}")))
};
convert_with_cast_::<i64, u32, _>(view, device, conv)
}
dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
}
}
@ -227,7 +233,6 @@ fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
match tensor.dtype() {
DType::U8 => Ok(convert_back_::<u8>(tensor.to_vec1()?)),
DType::U32 => Ok(convert_back_::<u32>(tensor.to_vec1()?)),
DType::I64 => Ok(convert_back_::<i64>(tensor.to_vec1()?)),
DType::F16 => Ok(convert_back_::<half::f16>(tensor.to_vec1()?)),
DType::BF16 => Ok(convert_back_::<half::bf16>(tensor.to_vec1()?)),
DType::F32 => Ok(convert_back_::<f32>(tensor.to_vec1()?)),

View File

@ -1,23 +0,0 @@
use crate::{Result, Tensor, WithDType};
pub enum TensorScalar {
Tensor(Tensor),
Scalar(Tensor),
}
pub trait TensorOrScalar {
fn to_tensor_scalar(self) -> Result<TensorScalar>;
}
impl TensorOrScalar for &Tensor {
fn to_tensor_scalar(self) -> Result<TensorScalar> {
Ok(TensorScalar::Tensor(self.clone()))
}
}
impl<T: WithDType> TensorOrScalar for T {
fn to_tensor_scalar(self) -> Result<TensorScalar> {
let scalar = Tensor::new(self, &crate::Device::Cpu)?;
Ok(TensorScalar::Scalar(scalar))
}
}

View File

@ -1,5 +1,3 @@
//! The shape of a tensor is a tuple with the size of each of its dimensions.
#![allow(clippy::redundant_closure_call)]
use crate::{Error, Result};
#[derive(Clone, PartialEq, Eq)]
@ -73,14 +71,6 @@ impl From<(usize, usize, usize, usize, usize)> for Shape {
}
}
impl From<(usize, usize, usize, usize, usize, usize)> for Shape {
fn from(d123456: (usize, usize, usize, usize, usize, usize)) -> Self {
Self(vec![
d123456.0, d123456.1, d123456.2, d123456.3, d123456.4, d123456.5,
])
}
}
impl From<Vec<usize>> for Shape {
fn from(dims: Vec<usize>) -> Self {
Self(dims)
@ -128,7 +118,6 @@ impl Shape {
Self(dims.to_vec())
}
/// The rank is the number of dimensions, 0 for a scalar value, 1 for a vector, etc.
pub fn rank(&self) -> usize {
self.0.len()
}
@ -137,12 +126,10 @@ impl Shape {
self.0
}
/// The dimensions as a slice of `usize`.
pub fn dims(&self) -> &[usize] {
&self.0
}
/// The total number of elements, this is the product of all dimension sizes.
pub fn elem_count(&self) -> usize {
self.0.iter().product()
}
@ -194,75 +181,10 @@ impl Shape {
true
}
/// Modifies the shape by adding a list of additional dimensions at the end of the existing
/// dimensions.
pub fn extend(mut self, additional_dims: &[usize]) -> Self {
self.0.extend(additional_dims);
self
}
/// Check whether the two shapes are compatible for broadcast, and if it is the case return the
/// broadcasted shape. This is to be used for binary pointwise ops.
pub(crate) fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
let lhs = self;
let lhs_dims = lhs.dims();
let rhs_dims = rhs.dims();
let lhs_ndims = lhs_dims.len();
let rhs_ndims = rhs_dims.len();
let bcast_ndims = usize::max(lhs_ndims, rhs_ndims);
let mut bcast_dims = vec![0; bcast_ndims];
for (idx, bcast_value) in bcast_dims.iter_mut().enumerate() {
let rev_idx = bcast_ndims - idx;
let l_value = if lhs_ndims < rev_idx {
1
} else {
lhs_dims[lhs_ndims - rev_idx]
};
let r_value = if rhs_ndims < rev_idx {
1
} else {
rhs_dims[rhs_ndims - rev_idx]
};
*bcast_value = if l_value == r_value {
l_value
} else if l_value == 1 {
r_value
} else if r_value == 1 {
l_value
} else {
Err(Error::ShapeMismatchBinaryOp {
lhs: lhs.clone(),
rhs: rhs.clone(),
op,
}
.bt())?
}
}
Ok(Shape::from(bcast_dims))
}
pub(crate) fn broadcast_shape_matmul(&self, rhs: &Self) -> Result<(Shape, Shape)> {
let lhs = self;
let lhs_dims = lhs.dims();
let rhs_dims = rhs.dims();
if lhs_dims.len() < 2 || rhs_dims.len() < 2 {
crate::bail!("only 2d matrixes are supported {lhs:?} {rhs:?}")
}
let (m, lhs_k) = (lhs_dims[lhs_dims.len() - 2], lhs_dims[lhs_dims.len() - 1]);
let (rhs_k, n) = (rhs_dims[rhs_dims.len() - 2], rhs_dims[rhs_dims.len() - 1]);
if lhs_k != rhs_k {
crate::bail!("different inner dimensions in broadcast matmul {lhs:?} {rhs:?}")
}
let lhs_b = Self::from(&lhs_dims[..lhs_dims.len() - 2]);
let rhs_b = Self::from(&rhs_dims[..rhs_dims.len() - 2]);
let bcast = lhs_b.broadcast_shape_binary_op(&rhs_b, "broadcast_matmul")?;
let bcast_dims = bcast.dims();
let bcast_lhs = [bcast_dims, &[m, lhs_k]].concat();
let bcast_rhs = [bcast_dims, &[rhs_k, n]].concat();
Ok((Shape::from(bcast_lhs), Shape::from(bcast_rhs)))
}
}
pub trait Dim {
@ -423,27 +345,6 @@ impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
}
}
impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim> Dims for (D1, D2, D3, D4) {
fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
let d0 = self.0.to_index(shape, op)?;
let d1 = self.1.to_index(shape, op)?;
let d2 = self.2.to_index(shape, op)?;
let d3 = self.3.to_index(shape, op)?;
Ok(vec![d0, d1, d2, d3])
}
}
impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim> Dims for (D1, D2, D3, D4, D5) {
fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
let d0 = self.0.to_index(shape, op)?;
let d1 = self.1.to_index(shape, op)?;
let d2 = self.2.to_index(shape, op)?;
let d3 = self.3.to_index(shape, op)?;
let d4 = self.4.to_index(shape, op)?;
Ok(vec![d0, d1, d2, d3, d4])
}
}
extract_dims!(dims0, 0, |_: &[usize]| (), ());
extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
@ -482,171 +383,3 @@ mod tests {
assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
}
}
pub trait ShapeWithOneHole {
fn into_shape(self, el_count: usize) -> Result<Shape>;
}
impl<S: Into<Shape>> ShapeWithOneHole for S {
fn into_shape(self, _el_count: usize) -> Result<Shape> {
Ok(self.into())
}
}
impl ShapeWithOneHole for ((),) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
Ok(el_count.into())
}
}
impl ShapeWithOneHole for ((), usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let ((), d1) = self;
if el_count % d1 != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
}
Ok((el_count / d1, d1).into())
}
}
impl ShapeWithOneHole for (usize, ()) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, ()) = self;
if el_count % d1 != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
}
Ok((d1, el_count / d1).into())
}
}
impl ShapeWithOneHole for ((), usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let ((), d1, d2) = self;
let d = d1 * d2;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((el_count / d, d1, d2).into())
}
}
impl ShapeWithOneHole for (usize, (), usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, (), d2) = self;
let d = d1 * d2;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, el_count / d, d2).into())
}
}
impl ShapeWithOneHole for (usize, usize, ()) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, ()) = self;
let d = d1 * d2;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, el_count / d).into())
}
}
impl ShapeWithOneHole for ((), usize, usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let ((), d1, d2, d3) = self;
let d = d1 * d2 * d3;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((el_count / d, d1, d2, d3).into())
}
}
impl ShapeWithOneHole for (usize, (), usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, (), d2, d3) = self;
let d = d1 * d2 * d3;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, el_count / d, d2, d3).into())
}
}
impl ShapeWithOneHole for (usize, usize, (), usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, (), d3) = self;
let d = d1 * d2 * d3;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, el_count / d, d3).into())
}
}
impl ShapeWithOneHole for (usize, usize, usize, ()) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, d3, ()) = self;
let d = d1 * d2 * d3;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, d3, el_count / d).into())
}
}
impl ShapeWithOneHole for ((), usize, usize, usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let ((), d1, d2, d3, d4) = self;
let d = d1 * d2 * d3 * d4;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((el_count / d, d1, d2, d3, d4).into())
}
}
impl ShapeWithOneHole for (usize, (), usize, usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, (), d2, d3, d4) = self;
let d = d1 * d2 * d3 * d4;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, el_count / d, d2, d3, d4).into())
}
}
impl ShapeWithOneHole for (usize, usize, (), usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, (), d3, d4) = self;
let d = d1 * d2 * d3 * d4;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, el_count / d, d3, d4).into())
}
}
impl ShapeWithOneHole for (usize, usize, usize, (), usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, d3, (), d4) = self;
let d = d1 * d2 * d3 * d4;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, d3, el_count / d, d4).into())
}
}
impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, d3, d4, ()) = self;
let d = d1 * d2 * d3 * d4;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, d3, d4, el_count / d).into())
}
}

View File

@ -68,19 +68,6 @@ impl Storage {
}
}
pub(crate) fn powf(&self, layout: &Layout, alpha: f64) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
let storage = storage.powf(layout, alpha)?;
Ok(Self::Cpu(storage))
}
Self::Cuda(storage) => {
let storage = storage.powf(layout, alpha)?;
Ok(Self::Cuda(storage))
}
}
}
pub(crate) fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
@ -151,7 +138,7 @@ impl Storage {
}
}
pub(crate) fn apply_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> {
pub(crate) fn custom_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> {
match self {
Self::Cpu(storage) => {
let (storage, shape) = c.cpu_fwd(storage, l)?;
@ -164,7 +151,7 @@ impl Storage {
}
}
pub(crate) fn apply_op2(
pub(crate) fn custom_op2(
&self,
l1: &Layout,
t2: &Self,
@ -185,7 +172,7 @@ impl Storage {
}
}
pub(crate) fn apply_op3(
pub(crate) fn custom_op3(
&self,
l1: &Layout,
t2: &Self,
@ -306,33 +293,6 @@ impl Storage {
}
}
pub(crate) fn conv_transpose2d(
&self,
l: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &crate::conv::ParamsConvTranspose2D,
) -> Result<Self> {
self.same_device(kernel, "conv_transpose2d")?;
self.same_dtype(kernel, "conv_transpose2d")?;
match (self, &kernel) {
(Storage::Cpu(inp), Storage::Cpu(kernel)) => {
let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
Ok(Self::Cpu(s))
}
(Storage::Cuda(inp), Storage::Cuda(kernel)) => {
let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
Ok(Self::Cuda(s))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
op: "conv_transpose2d",
}
.bt()),
}
}
pub(crate) fn avg_pool2d(
&self,
layout: &Layout,

View File

@ -1,10 +1,7 @@
//! Tensors are N-dimenional matrixes of elements using a single data type.
#![allow(clippy::redundant_closure_call)]
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{
BackpropOp, BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp,
};
use crate::scalar::TensorOrScalar;
use crate::shape::{Dim, Dims};
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::{Arc, RwLock};
@ -109,9 +106,7 @@ macro_rules! broadcast_binary_op {
($fn_name:ident, $inner_fn_name:ident) => {
pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
let lhs = self;
let shape = lhs
.shape()
.broadcast_shape_binary_op(rhs.shape(), stringify!($fn_name))?;
let shape = lhs.broadcast_shape_binary_op(rhs, stringify!($fn_name))?;
let l_broadcast = shape != *lhs.shape();
let r_broadcast = shape != *rhs.shape();
match (l_broadcast, r_broadcast) {
@ -127,7 +122,7 @@ macro_rules! broadcast_binary_op {
}
/// Creates a fresh tensor structure based on a storage and a shape, this uses contiguous strides.
pub(crate) fn from_storage<S: Into<Shape>>(
fn from_storage<S: Into<Shape>>(
storage: Storage,
shape: S,
op: BackpropOp,
@ -420,6 +415,48 @@ impl Tensor {
Self::new_impl(array, shape.into(), device, false)
}
pub(crate) fn broadcast_shape_binary_op<'a>(
&'a self,
rhs: &'a Self,
op: &'static str,
) -> Result<Shape> {
let lhs = self;
let lhs_dims = lhs.shape().dims();
let rhs_dims = rhs.shape().dims();
let lhs_ndims = lhs_dims.len();
let rhs_ndims = rhs_dims.len();
let bcast_ndims = usize::max(lhs_ndims, rhs_ndims);
let mut bcast_dims = vec![0; bcast_ndims];
for (idx, bcast_value) in bcast_dims.iter_mut().enumerate() {
let rev_idx = bcast_ndims - idx;
let l_value = if lhs_ndims < rev_idx {
1
} else {
lhs_dims[lhs_ndims - rev_idx]
};
let r_value = if rhs_ndims < rev_idx {
1
} else {
rhs_dims[rhs_ndims - rev_idx]
};
*bcast_value = if l_value == r_value {
l_value
} else if l_value == 1 {
r_value
} else if r_value == 1 {
l_value
} else {
Err(Error::ShapeMismatchBinaryOp {
lhs: self.shape().clone(),
rhs: rhs.shape().clone(),
op,
}
.bt())?
}
}
Ok(Shape::from(bcast_dims))
}
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
let lhs = self.shape();
let rhs = rhs.shape();
@ -447,14 +484,10 @@ impl Tensor {
binary_op!(mul, Mul);
binary_op!(sub, Sub);
binary_op!(div, Div);
binary_op!(maximum, Maximum);
binary_op!(minimum, Minimum);
broadcast_binary_op!(broadcast_add, add);
broadcast_binary_op!(broadcast_mul, mul);
broadcast_binary_op!(broadcast_sub, sub);
broadcast_binary_op!(broadcast_div, div);
broadcast_binary_op!(broadcast_maximum, maximum);
broadcast_binary_op!(broadcast_minimum, minimum);
unary_op!(recip, Recip);
unary_op!(neg, Neg);
@ -462,7 +495,6 @@ impl Tensor {
unary_op!(log, Log);
unary_op!(sin, Sin);
unary_op!(cos, Cos);
unary_op!(tanh, Tanh);
unary_op!(abs, Abs);
unary_op!(sqr, Sqr);
unary_op!(sqrt, Sqrt);
@ -495,25 +527,6 @@ impl Tensor {
self.to_scalar::<S>()
}
/// Repeat this tensor along the specified dimensions.
pub fn repeat<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
// Similar to PyTorch, we extend the number of dimensions of self if needed.
let repeats = shape.into();
let repeats = repeats.dims();
let mut inp = if self.rank() < repeats.len() {
let shape = [vec![1; repeats.len() - self.rank()], self.dims().to_vec()].concat();
self.reshape(shape)?
} else {
self.clone()
};
for (idx, &repeat) in repeats.iter().enumerate() {
if repeat > 1 {
inp = Tensor::cat(&vec![&inp; repeat], idx)?
}
}
Ok(inp)
}
/// This operation multiplies the input tensor by `mul` then adds `add` and return the result.
/// The input values `mul` and `add` are casted to the appropriate type so some rounding might
/// be performed.
@ -538,13 +551,6 @@ impl Tensor {
Ok(from_storage(storage, self.shape(), op, false))
}
/// Raise the tensor to some float exponent `e`.
pub fn powf(&self, e: f64) -> Result<Self> {
let storage = self.storage().powf(self.layout(), e)?;
let op = BackpropOp::new1(self, |t| Op::Powf(t, e));
Ok(from_storage(storage, self.shape(), op, false))
}
fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
if dim >= self.dims().len() {
Err(Error::DimOutOfRange {
@ -699,58 +705,18 @@ impl Tensor {
self.sum_impl(sum_dims, false)
}
/// Returns the mean of all elements in the input tensor. The mean is performed over all the
/// input dimensions.
///
/// The resulting tensor has a shape that is similar to the shape of the input tensor, except
/// that the number of elements for each dimension index in `mean_dims` is 1.
///
/// ```rust
/// use candle_core::{Tensor, Device};
/// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
/// let s = a.mean_keepdim(0)?;
/// assert_eq!(s.to_vec2::<f32>()?, &[[1., 2.]]);
/// let s = a.mean_keepdim(1)?;
/// assert_eq!(s.to_vec2::<f32>()?, &[[0.5], [2.5]]);
/// let s = a.mean_keepdim((0, 1))?;
/// assert_eq!(s.to_vec2::<f32>()?, &[[1.5]]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn mean_keepdim<D: Dims>(&self, mean_dims: D) -> Result<Self> {
let mean_dims = mean_dims.to_indexes(self.shape(), "mean-keepdim")?;
let reduced_dim: usize = mean_dims.iter().map(|i| self.dims()[*i]).product();
let scale = 1f64 / (reduced_dim as f64);
self.sum_impl(mean_dims, true)? * scale
}
/// Returns the mean of all elements in the input tensor. The mean is performed over all the
/// input dimensions and compared to `mean_keepdim` these dimensions are squeezed rather than
/// kept.
pub fn mean<D: Dims>(&self, mean_dims: D) -> Result<Self> {
let mean_dims = mean_dims.to_indexes(self.shape(), "mean")?;
let reduced_dim: usize = mean_dims.iter().map(|i| self.dims()[*i]).product();
let scale = 1f64 / (reduced_dim as f64);
self.sum_impl(mean_dims, false)? * scale
}
/// Gathers the maximum value across the selected dimension. The resulting shape has the same
/// number of dimensions as the original tensor and the select dimension has a single element.
pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
self.reduce_impl(dim, true, ReduceOp::Max)
}
/// Similar to `max_keepdim` but the target dimension is squeezed.
pub fn max<D: Dim>(&self, dim: D) -> Result<Self> {
self.reduce_impl(dim, false, ReduceOp::Max)
}
/// Gathers the minimum value across the selected dimension. The resulting shape has the same
/// number of dimensions as the original tensor and the select dimension has a single element.
pub fn min_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
self.reduce_impl(dim, true, ReduceOp::Min)
}
/// Similar to `min_keepdim` but the target dimension is squeezed.
pub fn min<D: Dim>(&self, dim: D) -> Result<Self> {
self.reduce_impl(dim, false, ReduceOp::Min)
}
@ -759,7 +725,6 @@ impl Tensor {
self.reduce_impl(dim, true, ReduceOp::ArgMax)
}
/// Similar to `argmax_keepdim` but the target dimension is squeezed.
pub fn argmax<D: Dim>(&self, dim: D) -> Result<Self> {
self.reduce_impl(dim, false, ReduceOp::ArgMax)
}
@ -768,24 +733,12 @@ impl Tensor {
self.reduce_impl(dim, true, ReduceOp::ArgMin)
}
/// Similar to `argmin_keepdim` but the target dimension is squeezed.
pub fn argmin<D: Dim>(&self, dim: D) -> Result<Self> {
self.reduce_impl(dim, false, ReduceOp::ArgMin)
}
/// Element-wise comparison between two tensors, e.g. equality, greater than, ... The actual
/// comparison operation is specified by the `op` argument.
///
/// The returned tensor has the same shape as the original tensors and uses `u8` elements.
pub fn cmp<T: TensorOrScalar>(&self, rhs: T, op: CmpOp) -> Result<Self> {
let rhs = match rhs.to_tensor_scalar()? {
crate::scalar::TensorScalar::Tensor(rhs) => rhs,
crate::scalar::TensorScalar::Scalar(rhs) => rhs
.to_dtype(self.dtype())?
.to_device(self.device())?
.broadcast_as(self.shape())?,
};
let shape = self.same_shape_binary_op(&rhs, "cmp")?;
pub fn cmp(&self, rhs: &Self, op: CmpOp) -> Result<Self> {
let shape = self.same_shape_binary_op(rhs, "cmp")?;
let storage = self
.storage()
.cmp(op, &rhs.storage(), self.layout(), rhs.layout())?;
@ -793,45 +746,96 @@ impl Tensor {
Ok(from_storage(storage, shape.dims(), op, false))
}
/// Element-wise equality.
pub fn eq<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
pub fn eq(&self, rhs: &Self) -> Result<Self> {
self.cmp(rhs, CmpOp::Eq)
}
/// Element-wise non-equality.
pub fn ne<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
pub fn ne(&self, rhs: &Self) -> Result<Self> {
self.cmp(rhs, CmpOp::Ne)
}
/// Element-wise comparison with lower-than, the returned tensor uses value 1 where `self <
/// rhs` and 0 otherwise.
pub fn lt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
pub fn lt(&self, rhs: &Self) -> Result<Self> {
self.cmp(rhs, CmpOp::Lt)
}
/// Element-wise comparison with greater-than, the returned tensor uses value 1 where `self >
/// rhs` and 0 otherwise.
pub fn gt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
pub fn gt(&self, rhs: &Self) -> Result<Self> {
self.cmp(rhs, CmpOp::Gt)
}
/// Element-wise comparison with greater-equal, the returned tensor uses value 1 where `self >=
/// rhs` and 0 otherwise.
pub fn ge<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
pub fn ge(&self, rhs: &Self) -> Result<Self> {
self.cmp(rhs, CmpOp::Ge)
}
/// Element-wise comparison with lower-equal, the returned tensor uses value 1 where `self <=
/// rhs` and 0 otherwise.
pub fn le<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
pub fn le(&self, rhs: &Self) -> Result<Self> {
self.cmp(rhs, CmpOp::Le)
}
/// Upsample the input tensor to the `(target_h, target_w)` size, taking the value of the
/// nearest element.
///
/// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
/// tensor also has four dimensions, `(batch, channels, target_h, target_w)`.
/// Applies a 1D convolution over the input tensor.
pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
let (c_out, c_in_k, k_size) = kernel.dims3()?;
let (b_size, c_in, l_in) = self.dims3()?;
if c_in != c_in_k {
Err(Error::Conv1dInvalidArgs {
inp_shape: self.shape().clone(),
k_shape: kernel.shape().clone(),
padding,
stride,
msg: "the number of in-channels on the input doesn't match the kernel size",
}
.bt())?
}
let params = crate::conv::ParamsConv1D {
b_size,
l_in,
c_out,
c_in,
k_size,
padding,
stride,
};
let storage =
self.storage()
.conv1d(self.layout(), &kernel.storage(), kernel.layout(), &params)?;
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D {
arg,
kernel,
padding,
stride,
});
let out_dims = params.out_dims();
Ok(from_storage(storage, out_dims, op, false))
}
pub fn conv2d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
let (b_size, c_in, i_h, i_w) = self.dims4()?;
let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
if c_in != c_in_k {
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
}
let params = crate::conv::ParamsConv2D {
b_size,
i_h,
i_w,
k_h,
k_w,
c_out,
c_in,
padding,
stride,
};
let storage =
self.storage()
.conv2d(self.layout(), &kernel.storage(), kernel.layout(), &params)?;
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D {
arg,
kernel,
padding,
stride,
});
let out_dims = params.out_dims();
Ok(from_storage(storage, out_dims, op, false))
}
pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
let (n, c, _h, _w) = self.dims4()?;
let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
@ -841,26 +845,7 @@ impl Tensor {
Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
}
/// 2D average pooling over an input tensor with multiple channels.
///
/// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
/// tensor also has four dimensions, `(batch, channels, h', w')`. The pooling is performed on
/// the two last dimensions using a kernel of size `sz`. The returned element is the average
/// value over the kernel window.
pub fn avg_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
let sz = sz.to_usize2();
self.avg_pool2d_with_stride(sz, sz)
}
/// Same as `avg_pool2d` but with a `stride` that can be set to a value different from the
/// kernel size.
pub fn avg_pool2d_with_stride<T: crate::ToUsize2>(
&self,
kernel_size: T,
stride: T,
) -> Result<Self> {
let kernel_size = kernel_size.to_usize2();
let stride = stride.to_usize2();
pub fn avg_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> {
let (n, c, h, w) = self.dims4()?;
// https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d
let h_out = (h - kernel_size.0) / stride.0 + 1;
@ -876,26 +861,7 @@ impl Tensor {
Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
}
/// 2D max pooling over an input tensor with multiple channels.
///
/// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
/// tensor also has four dimensions, `(batch, channels, h', w')`. The pooling is performed on
/// the two last dimensions using a kernel of size `sz`, the returned element is the maximum
/// value over the kernel window.
pub fn max_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
let sz = sz.to_usize2();
self.max_pool2d_with_stride(sz, sz)
}
/// Same as `max_pool2d` but with a `stride` that can be set to a value different from the
/// kernel size.
pub fn max_pool2d_with_stride<T: crate::ToUsize2>(
&self,
kernel_size: T,
stride: T,
) -> Result<Self> {
let kernel_size = kernel_size.to_usize2();
let stride = stride.to_usize2();
pub fn max_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> {
let (n, c, h, w) = self.dims4()?;
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
let h_out = (h - kernel_size.0) / stride.0 + 1;
@ -961,28 +927,6 @@ impl Tensor {
Ok(from_storage(storage, c_shape, op, false))
}
/// Matrix-multiplication with broadcasting support.
///
/// Compared to `matmul` the two matrixes are allowed to have different dimensions as long as
/// they are compatible for broadcast. E.g. if `self` has shape `(j, 1, n, k)` and `rhs` has
/// shape `(l, k, m)`, the output will have shape `(j, l, n, m)`.
pub fn broadcast_matmul(&self, rhs: &Self) -> Result<Self> {
let lhs = self;
let (l_shape, r_shape) = lhs.shape().broadcast_shape_matmul(rhs.shape())?;
let l_broadcast = l_shape != *lhs.shape();
let r_broadcast = r_shape != *rhs.shape();
// TODO: Avoid concretising the broadcasted matrixes via contiguous.
match (l_broadcast, r_broadcast) {
(true, true) => lhs
.broadcast_as(&l_shape)?
.contiguous()?
.matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?),
(false, true) => lhs.matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?),
(true, false) => lhs.broadcast_as(&l_shape)?.contiguous()?.matmul(rhs),
(false, false) => lhs.matmul(rhs),
}
}
/// Returns a tensor with the same shape as the input tensor, the values are taken from
/// `on_true` if the input tensor value is not zero, and `on_false` at the positions where the
/// input tensor is equal to zero.
@ -1075,7 +1019,6 @@ impl Tensor {
Ok(from_storage(storage, self.shape(), op, false))
}
/// Accumulate element from `source` at indexes `indexes` and add them to `self`.
pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "index-add")?;
let source_dims = source.dims();
@ -1124,17 +1067,6 @@ impl Tensor {
Ok(from_storage(storage, self.shape(), op, false))
}
/// Gather values across the target dimension.
///
/// # Arguments
///
/// * `self` - The input tensor.
/// * `indexes` - The indices of elements to gather, this should have the same shape as `self`
/// but can have a different number of elements on the target dimension.
/// * `dim` - the target dimension.
///
/// The resulting tensor has the same shape as `indexes` and use values from `self` indexed on
/// dimension `dim` by the values in `indexes`.
pub fn gather<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "gather")?;
let self_dims = self.dims();
@ -1165,13 +1097,6 @@ impl Tensor {
Ok(from_storage(storage, indexes.shape(), op, false))
}
/// Select values for the input tensor at the target indexes across the specified dimension.
///
/// The `indexes` is argument is an int tensor with a single dimension.
/// The output has the same number of dimension as the `self` input. The target dimension of
/// the output has length the length of `indexes` and the values are taken from `self` using
/// the index from `indexes`. Other dimensions have the same number of elements as the input
/// tensor.
pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "index-select")?;
let indexes_len = match indexes.dims() {
@ -1379,10 +1304,6 @@ impl Tensor {
self.sum(dims)
}
pub fn mean_all(&self) -> Result<Tensor> {
self.sum_all()? / self.elem_count() as f64
}
fn flatten_<D1: Dim, D2: Dim>(
&self,
start_dim: Option<D1>,
@ -1504,42 +1425,6 @@ impl Tensor {
Ok(Tensor(Arc::new(tensor_)))
}
/// Returns a tensor with the same data as the input where the dimensions have been permuted.
/// dims must be a permutation, i.e. include each dimension index exactly once.
///
/// ```rust
/// use candle_core::{Tensor, Device};
/// let tensor = Tensor::arange(0u32, 120u32, &Device::Cpu)?.reshape((2, 3, 4, 5))?;
/// assert_eq!(tensor.dims(), &[2, 3, 4, 5]);
/// let tensor = tensor.permute((2, 3, 1, 0))?;
/// assert_eq!(tensor.dims(), &[4, 5, 3, 2]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn permute<D: Dims>(&self, dims: D) -> Result<Tensor> {
let dims = dims.to_indexes(self.shape(), "permute")?;
// O(n^2) permutation check but these arrays are small.
let is_permutation =
dims.len() == self.rank() && (0..dims.len()).all(|i| dims.contains(&i));
if !is_permutation {
crate::bail!(
"dimension mismatch in permute, tensor {:?}, dims: {:?}",
self.dims(),
dims
)
}
let op = BackpropOp::new1(self, |t| Op::Permute(t, dims.clone()));
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
layout: self.layout.permute(&dims)?,
op,
is_variable: false,
dtype: self.dtype,
device: self.device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
}
/// Returns true if the data is stored in a C contiguous (aka row major) way.
pub fn is_contiguous(&self) -> bool {
self.layout.is_contiguous()
@ -1693,15 +1578,12 @@ impl Tensor {
Ok(from_storage(storage, shape, BackpropOp::none(), true))
}
// TODO: Do we want to allow target shape using -1 on some dimensions?
/// Reshape returns a tensor with the target shape provided that the number of elements of the
/// original tensor is the same.
/// If the input tensor is contiguous, this is a view on the original data. Otherwise this uses
/// a new storage and copies the data over, the returned tensor is always contiguous.
///
/// The shape can be specified using a tuple of `usize` and at most one `()` in which case
/// the behavior is the same as when using `-1` in PyTorch: this dimension size is adjusted so
/// as to match the number of elements in the tensor.
///
/// ```rust
/// # use candle_core::{Tensor, DType, Device, D};
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
@ -1711,14 +1593,10 @@ impl Tensor {
///
/// let c = a.reshape((3, 2))?;
/// assert_eq!(c.shape().dims(), &[3, 2]);
///
/// let c = a.reshape((2, (), 1))?;
/// assert_eq!(c.shape().dims(), &[2, 3, 1]);
///
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn reshape<S: crate::shape::ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
let shape = s.into_shape(self.elem_count())?;
pub fn reshape<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
let shape = shape.into();
if shape.elem_count() != self.elem_count() {
return Err(Error::ShapeMismatchBinaryOp {
lhs: self.shape().clone(),
@ -1941,8 +1819,6 @@ impl Tensor {
Ok(from_storage(storage, shape, op, false))
}
/// Pad the input tensor using 0s along dimension `dim`. This adds `left` elements before the
/// input tensor values and `right` elements after.
pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
if left == 0 && right == 0 {
Ok(self.clone())
@ -1969,12 +1845,7 @@ impl Tensor {
}
}
/// Run the `forward` method of `m` on `self`.
pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
m.forward(self)
}
pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
self.storage.read().unwrap()
}
@ -1999,53 +1870,22 @@ impl Tensor {
std::ptr::eq(lhs, rhs)
}
/// Applies a unary custom op without backward support
pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self> {
let (storage, shape) = self.storage().apply_op1(self.layout(), c)?;
Ok(from_storage(storage, shape, BackpropOp::none(), false))
}
/// Applies a binary custom op without backward support
pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self> {
let (storage, shape) =
self.storage()
.apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?;
Ok(from_storage(storage, shape, BackpropOp::none(), false))
}
/// Applies a ternary custom op without backward support
pub fn apply_op3_no_bwd<C: CustomOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<Self> {
let (storage, shape) = self.storage().apply_op3(
self.layout(),
&t2.storage(),
t2.layout(),
&t3.storage(),
t3.layout(),
c,
)?;
Ok(from_storage(storage, shape, BackpropOp::none(), false))
}
/// Applies a unary custom op.
pub fn apply_op1_arc(&self, c: Arc<Box<dyn CustomOp1 + Send + Sync>>) -> Result<Self> {
pub fn custom_op1_arc(&self, c: Arc<Box<dyn CustomOp1>>) -> Result<Self> {
let (storage, shape) = self
.storage()
.apply_op1(self.layout(), c.as_ref().as_ref())?;
.custom_op1(self.layout(), c.as_ref().as_ref())?;
let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
Ok(from_storage(storage, shape, op, false))
}
pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>(&self, c: C) -> Result<Self> {
self.apply_op1_arc(Arc::new(Box::new(c)))
pub fn custom_op1<C: 'static + CustomOp1>(&self, c: C) -> Result<Self> {
self.custom_op1_arc(Arc::new(Box::new(c)))
}
/// Applies a binary custom op.
pub fn apply_op2_arc(
&self,
rhs: &Self,
c: Arc<Box<dyn CustomOp2 + Send + Sync>>,
) -> Result<Self> {
let (storage, shape) = self.storage().apply_op2(
pub fn custom_op2_arc(&self, rhs: &Self, c: Arc<Box<dyn CustomOp2>>) -> Result<Self> {
let (storage, shape) = self.storage().custom_op2(
self.layout(),
&rhs.storage(),
rhs.layout(),
@ -2055,18 +1895,13 @@ impl Tensor {
Ok(from_storage(storage, shape, op, false))
}
pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>(&self, r: &Self, c: C) -> Result<Self> {
self.apply_op2_arc(r, Arc::new(Box::new(c)))
pub fn custom_op2<C: 'static + CustomOp2>(&self, r: &Self, c: C) -> Result<Self> {
self.custom_op2_arc(r, Arc::new(Box::new(c)))
}
/// Applies a ternary custom op.
pub fn apply_op3_arc(
&self,
t2: &Self,
t3: &Self,
c: Arc<Box<dyn CustomOp3 + Send + Sync>>,
) -> Result<Self> {
let (storage, shape) = self.storage().apply_op3(
pub fn custom_op3_arc(&self, t2: &Self, t3: &Self, c: Arc<Box<dyn CustomOp3>>) -> Result<Self> {
let (storage, shape) = self.storage().custom_op3(
self.layout(),
&t2.storage(),
t2.layout(),
@ -2080,13 +1915,8 @@ impl Tensor {
Ok(from_storage(storage, shape, op, false))
}
pub fn apply_op3<C: 'static + CustomOp3 + Send + Sync>(
&self,
t2: &Self,
t3: &Self,
c: C,
) -> Result<Self> {
self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
pub fn custom_op3<C: 'static + CustomOp3>(&self, t2: &Self, t3: &Self, c: C) -> Result<Self> {
self.custom_op3_arc(t2, t3, Arc::new(Box::new(c)))
}
}
@ -2108,22 +1938,6 @@ macro_rules! bin_trait {
}
}
impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<Tensor> for Result<B> {
type Output = Result<Tensor>;
fn $fn1(self, rhs: Tensor) -> Self::Output {
Tensor::$fn1(self?.borrow(), &rhs)
}
}
impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<&Tensor> for Result<B> {
type Output = Result<Tensor>;
fn $fn1(self, rhs: &Tensor) -> Self::Output {
Tensor::$fn1(self?.borrow(), rhs)
}
}
impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<Result<B>> for Tensor {
type Output = Result<Tensor>;
@ -2162,69 +1976,3 @@ bin_trait!(Add, add, |_| 1., |v| v);
bin_trait!(Sub, sub, |_| 1., |v: f64| -v);
bin_trait!(Mul, mul, |v| v, |_| 0.);
bin_trait!(Div, div, |v| 1. / v, |_| 0.);
impl std::ops::Add<Tensor> for f64 {
type Output = Result<Tensor>;
fn add(self, rhs: Tensor) -> Self::Output {
rhs + self
}
}
impl std::ops::Add<&Tensor> for f64 {
type Output = Result<Tensor>;
fn add(self, rhs: &Tensor) -> Self::Output {
rhs + self
}
}
impl std::ops::Mul<Tensor> for f64 {
type Output = Result<Tensor>;
fn mul(self, rhs: Tensor) -> Self::Output {
rhs * self
}
}
impl std::ops::Mul<&Tensor> for f64 {
type Output = Result<Tensor>;
fn mul(self, rhs: &Tensor) -> Self::Output {
rhs * self
}
}
impl std::ops::Sub<Tensor> for f64 {
type Output = Result<Tensor>;
fn sub(self, rhs: Tensor) -> Self::Output {
rhs.affine(-1., self)
}
}
impl std::ops::Sub<&Tensor> for f64 {
type Output = Result<Tensor>;
fn sub(self, rhs: &Tensor) -> Self::Output {
rhs.affine(-1., self)
}
}
impl std::ops::Div<Tensor> for f64 {
type Output = Result<Tensor>;
#[allow(clippy::suspicious_arithmetic_impl)]
fn div(self, rhs: Tensor) -> Self::Output {
rhs.recip()? * self
}
}
impl std::ops::Div<&Tensor> for f64 {
type Output = Result<Tensor>;
#[allow(clippy::suspicious_arithmetic_impl)]
fn div(self, rhs: &Tensor) -> Self::Output {
rhs.recip()? * self
}
}

View File

@ -22,19 +22,3 @@ pub fn has_mkl() -> bool {
pub fn cuda_is_available() -> bool {
cfg!(feature = "cuda")
}
pub fn with_avx() -> bool {
cfg!(target_feature = "avx")
}
pub fn with_neon() -> bool {
cfg!(target_feature = "neon")
}
pub fn with_simd128() -> bool {
cfg!(target_feature = "simd128")
}
pub fn with_f16c() -> bool {
cfg!(target_feature = "f16c")
}

View File

@ -1,5 +1,6 @@
mod test_utils;
use anyhow::Result;
use candle_core::{test_device, test_utils, Device, IndexOp, Tensor};
use candle_core::{Device, Tensor};
/* This test is based on the following script.
import torch
@ -32,13 +33,13 @@ fn conv1d(dev: &Device) -> Result<()> {
dev,
)?
.reshape((2, 4, 3))?;
let res = t.conv1d(&w, 0, 1, 1, 1)?;
let res = t.conv1d(&w, 0, 1)?;
assert_eq!(res.dims(), [1, 2, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.6357, -1.3336, 4.1393, -1.1784, 3.5675, 0.5069]
);
let res = t.conv1d(&w, /*padding*/ 1, 1, 1, 1)?;
let res = t.conv1d(&w, /*padding*/ 1, 1)?;
assert_eq!(res.dims(), [1, 2, 5]);
// Same as pytorch default padding: use zeros.
assert_eq!(
@ -51,13 +52,13 @@ fn conv1d(dev: &Device) -> Result<()> {
fn conv1d_small(dev: &Device) -> Result<()> {
let t = Tensor::new(&[0.4056f32, -0.8689, -0.0773, -1.5630], dev)?.reshape((1, 1, 4))?;
let w = Tensor::new(&[1f32, 0., 0.], dev)?.reshape((1, 1, 3))?;
let res = t.conv1d(&w, 0, 1, 1, 1)?;
let res = t.conv1d(&w, 0, 1)?;
assert_eq!(res.dims(), [1, 1, 2]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[0.4056, -0.8689]
);
let res = t.conv1d(&w, /*padding*/ 1, 1, 1, 1)?;
let res = t.conv1d(&w, /*padding*/ 1, 1)?;
assert_eq!(res.dims(), [1, 1, 4]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@ -76,19 +77,6 @@ print(t.flatten())
print(w.flatten())
res = torch.nn.functional.conv2d(t, w)
print(res.flatten())
w_t = w.transpose(0, 1)
res = torch.nn.functional.conv_transpose2d(t, w_t)
print(res.shape)
print(res)
res = torch.nn.functional.conv2d(t, w, dilation=2)
print(res.shape)
print(res[0])
res = torch.nn.functional.conv_transpose2d(t, w_t, dilation=2)
print(res.shape)
print(res)
*/
fn conv2d(dev: &Device) -> Result<()> {
let t = Tensor::new(
@ -121,7 +109,7 @@ fn conv2d(dev: &Device) -> Result<()> {
)?;
let t = t.reshape((1, 4, 5, 5))?;
let w = w.reshape((2, 4, 3, 3))?;
let res = t.conv2d(&w, 0, 1, 1, 1)?;
let res = t.conv2d(&w, 0, 1)?;
assert_eq!(res.dims(), [1, 2, 3, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@ -130,69 +118,6 @@ fn conv2d(dev: &Device) -> Result<()> {
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
]
);
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 7, 7]);
assert_eq!(
test_utils::to_vec3_round(&res.i(0)?, 4)?,
[
[
[-1.9918, 2.6797, -0.4599, -1.6037, 1.4131, -2.4012, 2.9277],
[1.8016, -3.5361, 1.0757, 3.5395, -8.2168, -3.2023, 0.5375],
[0.8243, 1.8675, 7.8929, -4.0746, -6.4415, 5.1139, 1.6889],
[0.2722, 8.9679, 3.3477, 1.8514, -4.2896, -3.8228, -7.5632],
[-8.5412, -5.8142, -7.1587, -1.6095, 0.4651, 0.2748, -2.0985],
[2.0833, -0.6482, -12.1692, -4.1284, -2.9765, -0.0656, -4.5114],
[5.307, 2.6957, 2.3087, 1.0478, 0.7808, -1.1519, -0.9579]
],
[
[1.089, 0.1872, -0.6408, -0.9897, 0.8503, 1.1019, -0.9211],
[-0.1741, -0.2915, 4.2472, 1.9417, 1.65, 0.6303, -4.7131],
[1.6555, 2.4026, -2.9293, 2.9953, 0.5328, 3.5873, -0.9621],
[-1.4289, -3.2787, 4.1747, -6.0341, -4.6341, -5.7945, 4.142],
[7.5973, 6.4431, 5.9872, 2.1639, -8.6566, 3.3143, -3.4059],
[-0.8775, -3.048, 11.6543, 0.6442, 2.3218, -0.4765, 1.1516],
[-5.5423, -2.5188, 1.0754, -0.0563, -2.9386, -1.1504, 1.0171]
]
]
);
// Dilations.
let res = t.conv2d(&w, 0, 1, 2, 1)?;
assert_eq!(res.dims(), [1, 2, 1, 1]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.45, -2.3504],
);
// Transpose and dilations.
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2)?;
assert_eq!(res.dims(), [1, 2, 9, 9]);
assert_eq!(
test_utils::to_vec3_round(&res.i(0)?, 4)?,
[
[
[-1.9918, 3.1652, -0.6778, -4.3442, 4.4351, 0.6652, -3.0124, -0.6031, 2.9277],
[2.7036, -1.7156, -0.3969, 1.0516, 1.6381, -2.8886, -0.205, 2.4682, -1.0499],
[-0.9459, 3.1631, 3.707, -4.8369, -8.5166, -1.4496, -2.7559, -3.2698, 1.4376],
[-0.2157, 3.7786, -2.0252, -4.2633, 3.6731, -1.5142, 5.9391, -0.2622, -0.141],
[-6.8121, -3.1744, 1.5945, 3.0637, -9.6088, 1.4446, 2.9489, -3.0082, -7.3822],
[0.2371, 3.3303, 0.3861, 2.2646, -4.6784, 4.1235, -0.0109, 0.3176, -0.03],
[-2.5339, -2.9564, -3.4518, -4.4594, -9.1873, -1.9709, -0.4676, 0.51, -3.5024],
[4.007, 0.3067, -2.2954, 1.1105, -0.1992, 1.6372, -2.9268, 0.2807, -1.2787],
[5.307, 1.1317, 1.3518, 0.9049, 3.8116, -0.4075, -0.8874, -0.2241, -0.9579]
],
[
[1.089, -0.6483, 0.0726, -0.4752, -1.3283, 1.7103, 1.0703, 0.1076, -0.9211],
[-0.8629, 0.1376, 0.3202, 2.0955, 0.9696, 2.8988, -1.0012, 1.5049, -0.1278],
[1.9286, -1.5255, -2.9563, 2.4589, 3.3611, -0.6951, 0.3525, -1.7724, -5.9861],
[1.1226, 2.1561, 3.6417, 4.7546, -0.692, 4.4126, -5.1902, 6.0805, 2.3185],
[1.0111, 0.3604, 0.6432, -3.6605, 7.9517, -9.2955, -5.2988, -3.7803, -2.0642],
[3.3172, -1.7967, -3.6576, -2.0942, 1.3158, 0.112, -1.7405, 2.9167, 0.7957],
[5.1001, 1.8995, -1.8639, 1.1262, 9.9629, 2.683, -3.6319, -1.1607, 0.5856],
[-4.8445, -0.5642, 4.2317, 0.0856, 1.2267, -0.5712, 1.736, 1.0997, 0.6908],
[-5.5423, -1.1831, -1.2176, 0.0843, 0.0446, -0.7545, -2.4798, -0.0827, 1.0171]
]
]
);
Ok(())
}
@ -206,16 +131,6 @@ print(t.flatten())
print(w.flatten())
res = torch.nn.functional.conv2d(t, w)
print(res.flatten())
w_t = w.transpose(0, 1)
res = torch.nn.functional.conv_transpose2d(t, w_t)
print(res.shape)
print(res.flatten())
t_t = w.transpose(0, 1)
res = torch.nn.functional.conv_transpose2d(t_t, w)
print(res.shape)
print(res.flatten())
*/
fn conv2d_small(dev: &Device) -> Result<()> {
let t = Tensor::new(
@ -228,41 +143,12 @@ fn conv2d_small(dev: &Device) -> Result<()> {
let w = Tensor::new(&[-0.9259f32, 1.3017], dev)?;
let t = t.reshape((1, 2, 3, 3))?;
let w = w.reshape((1, 2, 1, 1))?;
let res = t.conv2d(&w, 0, 1, 1, 1)?;
let res = t.conv2d(&w, 0, 1)?;
assert_eq!(res.dims(), [1, 1, 3, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539]
);
let res = t.conv2d(&w, 2, 1, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 7, 7]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1640, -0.0111, -0.1742, 0.0000, 0.0000,
0.0000, 0.0000, 2.6437, -2.0268, 1.1823, 0.0000, 0.0000, 0.0000, 0.0000, 3.2855,
-1.0324, 0.2539, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000
]
);
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 3, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539],
);
let res = t.transpose(0, 1)?.conv_transpose2d(&w, 0, 0, 1, 1)?;
assert_eq!(res.dims(), [2, 2, 3, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[
-0.3755, 0.8045, -0.6336, -0.2218, -1.1369, 0.8599, 1.5768, -0.1268, -0.1728, 0.528,
-1.131, 0.8908, 0.3118, 1.5984, -1.2089, -2.2168, 0.1783, 0.2429, -0.3838, 0.5802,
-0.3268, -2.0382, 0.6329, -0.2293, -1.2154, 0.6441, -0.3035, 0.5396, -0.8156, 0.4594,
2.8654, -0.8898, 0.3224, 1.7087, -0.9056, 0.4267
]
);
Ok(())
}
@ -276,7 +162,7 @@ fn conv2d_smaller(dev: &Device) -> Result<()> {
let w = Tensor::new(&[1f32, 1., 1., 1., 1., 1., 1., 1., 1.], dev)?;
let t = t.reshape((1, 1, 3, 3))?;
let w = w.reshape((1, 1, 3, 3))?;
let res = t.conv2d(&w, 0, 1, 1, 1)?;
let res = t.conv2d(&w, 0, 1)?;
assert_eq!(res.dims(), [1, 1, 1, 1]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@ -285,211 +171,8 @@ fn conv2d_smaller(dev: &Device) -> Result<()> {
Ok(())
}
/* This test is based on the following script.
import torch
torch.manual_seed(4242)
t = torch.randn((1, 2, 4, 2))
w = torch.randn((1, 2, 1, 1))
print(t.flatten())
print(w.flatten())
res = torch.nn.functional.conv2d(t, w)
print(res.flatten())
*/
fn conv2d_non_square(dev: &Device) -> Result<()> {
let t = Tensor::new(
&[
0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699,
],
dev,
)?;
let w = Tensor::new(&[-1.1351f32, 1.3841], dev)?;
let t = t.reshape((1, 2, 4, 2))?;
let w = w.reshape((1, 2, 1, 1))?;
let res = t.conv2d(&w, 0, 1, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 4, 2]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[0.2312, 5.2238, 2.3772, 1.9076, 2.0256, -0.5776, -1.6028, -1.467]
);
Ok(())
}
/*
import torch
torch.manual_seed(4242)
t = torch.randn((1, 4, 5, 5), requires_grad=True)
w = torch.randn((2, 4, 3, 3), requires_grad=True)
print(t.flatten())
print(w.flatten())
res = torch.nn.functional.conv2d(t, w)
print(res.flatten())
loss = (res ** 2).sum()
print(loss)
loss.backward()
print(t.grad.shape)
print(t.grad.flatten())
print(w.grad.shape)
print(w.grad.flatten())
t.grad.zero_()
w.grad.zero_()
res = torch.nn.functional.conv2d(t, w, stride=2)
print(res.flatten())
loss = (res ** 2).sum()
print(loss)
loss.backward()
print(t.grad.shape)
print(t.grad[0])
print(w.grad.shape)
print(w.grad[0])
*/
fn conv2d_grad(dev: &Device) -> Result<()> {
use candle_core::Var;
let t = Var::from_slice(
&[
0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843, 0.2395,
1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013, -0.6836,
0.2477, 1.3127, -0.2260, 0.2622, -1.2974, -0.8140, -0.8404, -0.3490, 0.0130, 1.3123,
1.7569, -0.3956, -1.8255, 0.1727, -0.3538, 2.6941, 1.0529, 0.4219, -0.2071, 1.1586,
0.4717, 0.3865, -0.5690, -0.5010, -0.1310, 0.7796, 0.6630, -0.2021, 2.6090, 0.2049,
0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
-0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
-0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
],
(1, 4, 5, 5),
dev,
)?;
let w = Var::from_slice(
&[
-0.9325f32, 0.6451, -0.8537, 0.2378, 0.8764, -0.1832, 0.2987, -0.6488, -0.2273,
-2.4184, -0.1192, -0.4821, -0.5079, -0.5766, -2.4729, 1.6734, 0.4558, 0.2851, 1.1514,
-0.9013, 1.0662, -0.1817, -0.0259, 0.1709, 0.5367, 0.7513, 0.8086, -2.2586, -0.5027,
0.9141, -1.3086, -1.3343, -1.5669, -0.1657, 0.7958, 0.1432, 0.3896, -0.4501, 0.1667,
0.0714, -0.0952, 1.2970, -0.1674, -0.3178, 1.0677, 0.3060, 0.7080, 0.1914, 1.1679,
-0.3602, 1.9265, -1.8626, -0.5112, -0.0982, 0.2621, 0.6565, 0.5908, 1.0089, -0.1646,
1.8032, -0.6286, 0.2016, -0.3370, 1.2555, 0.8009, -0.6488, -0.4652, -1.5685, 1.5860,
0.5583, 0.4623, 0.6026,
],
(2, 4, 3, 3),
dev,
)?;
let res = t.conv2d(&w, 0, 1, 1, 1)?;
let loss = res.sqr()?.sum_all()?;
assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 741.12f32);
let grads = loss.backward()?;
let grad_t = grads.get(&t).unwrap();
let grad_w = grads.get(&w).unwrap();
assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
assert_eq!(
test_utils::to_vec1_round(&grad_t.flatten_all()?, 2)?,
[
9.29, -2.84, -5.71, 3.38, -7.71, -19.15, 7.02, 29.1, 9.34, 34.73, -22.87, 24.35,
-39.88, -14.01, 21.08, 9.94, 13.63, -34.68, 11.21, -6.26, 7.72, -6.32, -16.64, -1.08,
-20.22, 21.73, -0.37, -4.06, 5.82, -3.65, -30.73, 14.55, 87.7, 31.6, 4.53, -89.78,
-75.37, -57.43, -7.56, 92.96, 18.79, -4.63, -159.75, -42.47, -47.26, 52.88, 37.32,
49.0, 12.82, 2.01, -8.98, 20.18, 16.62, 12.06, 15.38, 20.0, 2.57, -15.22, 72.62,
-10.75, 2.25, -31.2, 3.75, -0.2, 9.76, -0.68, 5.21, -40.44, -22.59, -61.61, 17.28,
20.41, 37.55, 5.23, 6.81, 23.54, 23.62, -9.99, -9.13, 4.87, -35.06, -26.1, 63.48,
25.81, -39.21, -70.68, -46.96, 2.33, 41.81, 82.42, -28.63, -11.78, -35.33, -10.28,
-28.57, -9.13, 7.21, -9.05, -9.62, -11.25
]
);
assert_eq!(
test_utils::to_vec1_round(&grad_w.flatten_all()?, 2)?,
[
-28.92, -22.88, -141.23, 73.35, 61.07, 47.81, -20.0, -73.71, -41.82, -13.59, 21.5,
28.72, 28.57, -46.85, -90.19, 143.61, 16.68, 7.43, 18.88, -90.81, -20.29, 54.79, 82.63,
22.94, 77.81, -16.39, -13.2, 9.34, -40.39, -26.62, 5.33, -60.91, 9.09, -59.37, 7.08,
58.64, 5.55, 20.52, 2.5, -17.25, -6.8, 22.21, 30.15, -7.52, -37.46, 5.67, 22.58, 9.03,
47.05, 17.61, 37.31, -98.13, -14.61, -4.8, -6.36, 44.69, 23.34, 8.37, -13.52, 80.05,
-34.24, -16.36, -12.31, 1.92, -33.62, -14.1, -49.23, -7.39, 11.5, -9.98, 9.66, 29.6
]
);
// Same as before but with stride.
let res = t.conv2d(&w, 0, 2, 1, 1)?;
let loss = res.sqr()?.sum_all()?;
assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 277.16f32);
let grads = loss.backward()?;
let grad_t = grads.get(&t).unwrap();
let grad_w = grads.get(&w).unwrap();
assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
assert_eq!(
test_utils::to_vec3_round(&grad_t.i(0)?, 2)?,
[
[
[9.29, -7.03, 0.94, 3.49, -7.71],
[-1.8, -7.82, 8.9, 8.46, 7.43],
[-25.84, 22.09, -19.27, -0.22, 1.69],
[4.02, 18.53, -18.37, 2.3, -24.51],
[7.72, -9.68, -12.34, 5.6, -20.22]
],
[
[21.73, 3.39, -18.27, 3.86, -3.65],
[8.25, 3.73, 30.73, -8.61, -11.93],
[-72.15, -15.36, -17.53, -12.32, -1.61],
[-22.32, -7.79, -91.82, 6.44, -37.69],
[52.88, 14.44, 42.75, 9.88, 2.01]
],
[
[-8.98, 9.91, 6.75, -4.68, 15.38],
[4.93, -0.33, 9.94, -1.46, 14.78],
[13.62, -30.63, 3.96, -3.58, -4.48],
[-14.13, 1.19, -34.43, 3.08, -33.83],
[17.28, 12.94, 31.83, -3.35, 6.81]
],
[
[23.54, 6.98, -24.52, 0.52, 4.87],
[9.65, 6.18, 1.71, -25.23, -4.93],
[-54.99, -23.66, 3.19, -3.73, 18.58],
[-21.35, -10.39, -39.88, 28.73, -30.76],
[-9.13, 11.12, -14.0, -8.23, -11.25]
]
]
);
assert_eq!(
test_utils::to_vec3_round(&grad_w.i(0)?, 2)?,
[
[
[28.34, -7.91, -45.75],
[21.03, 3.86, 29.86],
[0.72, -36.58, -35.28]
],
[
[-16.04, 11.53, -16.38],
[29.62, -16.32, -48.35],
[57.5, 28.29, 25.81]
],
[
[2.93, -19.6, 1.57],
[27.15, 53.88, -24.64],
[12.74, -22.6, -26.2]
],
[
[-0.18, -14.86, -6.82],
[-19.55, -2.72, 45.9],
[-2.54, 36.97, 27.11]
]
]
);
Ok(())
}
test_device!(conv1d, conv1d_cpu, conv1d_gpu);
test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu);
test_device!(conv2d, conv2d_cpu, conv2d_gpu);
test_device!(
conv2d_non_square,
conv2d_non_square_cpu,
conv2d_non_square_gpu
);
test_device!(conv2d_small, conv2d_small_cpu, conv2d_small_gpu);
test_device!(conv2d_smaller, conv2d_smaller_cpu, conv2d_smaller_gpu);
test_device!(conv2d_grad, conv2d_grad_cpu, conv2d_grad_gpu);

View File

@ -1,8 +1,10 @@
use candle_core::backend::BackendStorage;
use candle_core::cpu_backend;
use candle_core::test_utils::to_vec1_round;
use candle_core::{CpuStorage, CustomOp1, DType, Device, Error, Layout, Result, Shape, Tensor};
mod test_utils;
use test_utils::to_vec1_round;
fn fwd<T: num_traits::Float>(v: T, alpha: f64) -> T {
if v.is_sign_positive() {
v
@ -37,7 +39,7 @@ fn custom_op1_no_backward() -> Result<()> {
let cpu = &Device::Cpu;
let t = Tensor::arange(0u32, 12u32, cpu)?.to_dtype(DType::F32)?;
let t = (t - 5.)?;
let elu_t = t.apply_op1_no_bwd(&Elu { alpha: 1. })?;
let elu_t = t.custom_op1(Elu { alpha: 1. })?;
assert_eq!(
to_vec1_round(&elu_t, 4)?,
&[-0.9933, -0.9817, -0.9502, -0.8647, -0.6321, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
@ -94,7 +96,7 @@ impl CustomOp1 for EluWithBackward {
fn bwd(&self, arg: &Tensor, _res: &Tensor, grad_res: &Tensor) -> Result<Option<Tensor>> {
let alpha = self.0.alpha;
let bwd = arg.apply_op1(EluBackward { alpha })?;
let bwd = arg.custom_op1(EluBackward { alpha })?;
Ok(Some(grad_res.mul(&bwd)?))
}
}
@ -103,7 +105,7 @@ impl CustomOp1 for EluWithBackward {
fn custom_op1_with_backward() -> Result<()> {
let cpu = &Device::Cpu;
let t = candle_core::Var::new(&[-2f32, 0f32, 2f32], cpu)?;
let elu_t = t.apply_op1(EluWithBackward::new(2.))?;
let elu_t = t.custom_op1(EluWithBackward::new(2.))?;
assert_eq!(to_vec1_round(&elu_t, 4)?, &[-1.7293, 0.0, 2.0]);
let grads = elu_t.backward()?;

View File

@ -1,5 +1,6 @@
use anyhow::{Context, Result};
use candle_core::{test_device, test_utils, Device, Shape, Tensor, Var};
use candle_core::{Device, Shape, Tensor, Var};
mod test_utils;
fn simple_grad(device: &Device) -> Result<()> {
let x = Var::new(&[3f32, 1., 4.], device)?;
@ -173,51 +174,6 @@ fn unary_grad(device: &Device) -> Result<()> {
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(y.to_vec1::<f32>()?, [6., 2., 8., 0.3]);
assert_eq!(grad_x.to_vec1::<f32>()?, [2., 2., 2., 2.]);
let x = Var::new(&[3f32, 1., 4., 0.15], device)?;
let y = x.powf(2.5)?;
let grads = y.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(test_utils::to_vec1_round(&y, 2)?, [15.59, 1.0, 32.0, 0.01]);
assert_eq!(
test_utils::to_vec1_round(grad_x, 2)?,
[12.99, 2.5, 20.0, 0.15]
);
let y = x.tanh()?;
let grads = y.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(test_utils::to_vec1_round(&y, 2)?, [1.0, 0.76, 1.0, 0.15]);
assert_eq!(
test_utils::to_vec1_round(grad_x, 2)?,
[0.01, 0.42, 0.0, 0.98],
);
Ok(())
}
fn binary_grad(device: &Device) -> Result<()> {
let x = Var::new(&[3f32, 1., -4., -1.], device)?;
let x = x.as_tensor();
// leaky relu
let y = x.maximum(&(x * 0.1)?)?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(x.to_vec1::<f32>()?, [3., 1., -4., -1.]);
assert_eq!(y.to_vec1::<f32>()?, [3., 1., -0.4, -0.1]);
assert_eq!(grad_x.to_vec1::<f32>()?, [1., 1., 0.1, 0.1]);
let y = x.minimum(&(x * 0.1)?)?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(y.to_vec1::<f32>()?, [0.3, 0.1, -4., -1.]);
assert_eq!(grad_x.to_vec1::<f32>()?, [0.1, 0.1, 1., 1.]);
// This one is easy to mess up, we want the gradient to be one as it is the identity function.
let y = x.minimum(x)?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(y.to_vec1::<f32>()?, [3., 1., -4., -1.]);
assert_eq!(grad_x.to_vec1::<f32>()?, [1., 1., 1., 1.]);
Ok(())
}
@ -226,4 +182,3 @@ test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu);
test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu);
test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu);
test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu);
test_device!(binary_grad, binary_grad_cpu, binary_grad_gpu);

View File

@ -1,6 +1,8 @@
use anyhow::Result;
use candle_core::{Device, IndexOp, Tensor};
mod test_utils;
#[test]
fn integer_index() -> Result<()> {
let dev = Device::Cpu;

View File

@ -1,4 +1,5 @@
use candle::{test_device, Device, IndexOp, Result, Tensor};
mod test_utils;
use candle::{Device, IndexOp, Result, Tensor};
use candle_core as candle;
fn contiguous(device: &Device) -> Result<()> {

View File

@ -1,4 +1,5 @@
use candle_core::{test_device, test_utils, Device, IndexOp, Result, Tensor};
mod test_utils;
use candle_core::{Device, IndexOp, Result, Tensor};
// https://github.com/huggingface/candle/issues/364
fn avg_pool2d(dev: &Device) -> Result<()> {
@ -6,15 +7,8 @@ fn avg_pool2d(dev: &Device) -> Result<()> {
1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
];
let t = Tensor::from_vec(data, (1, 1, 4, 4), dev)?;
let pool = t.avg_pool2d(2)?.squeeze(0)?.squeeze(0)?;
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
assert_eq!(pool.to_vec2::<f32>()?, [[0.5f32, 1.], [1., 1.]]);
let data: Vec<f32> = vec![
1., 2., 1., 3., 0., 0., 1., 1., 1., 1., 1., 1., 5., 1., 1., 1.,
];
let t = Tensor::from_vec(data, (1, 1, 2, 8), dev)?;
let pool = t.avg_pool2d(2)?.squeeze(0)?.squeeze(0)?;
assert_eq!(pool.to_vec2::<f32>()?, [[5. / 4., 6. / 4., 6. / 4., 1.]]);
Ok(())
}
@ -24,12 +18,8 @@ fn max_pool2d(dev: &Device) -> Result<()> {
];
let t = Tensor::from_vec(data, (1, 1, 4, 4), dev)?;
let pool = t.max_pool2d(2)?.squeeze(0)?.squeeze(0)?;
let pool = t.max_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
assert_eq!(pool.to_vec2::<f32>()?, [[2f32, 3.], [5., 1.]]);
let t = t.reshape((1, 1, 2, 8))?;
let pool = t.max_pool2d(2)?.squeeze(0)?.squeeze(0)?;
assert_eq!(pool.to_vec2::<f32>()?, [[2.0, 3.0, 5.0, 1.0]]);
Ok(())
}
@ -53,29 +43,16 @@ fn avg_pool2d_pytorch(dev: &Device) -> Result<()> {
dev,
)?
.reshape((1, 2, 4, 4))?;
let pool = t.avg_pool2d(2)?.squeeze(0)?;
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?;
assert_eq!(
test_utils::to_vec3_round(&pool, 4)?,
test_utils::to_vec3_round(pool, 4)?,
[
[[-1.1926, -0.0395], [0.2688, 0.1871]],
[[0.1835, -0.1606], [0.6249, 0.3217]]
]
);
let pool = t.avg_pool2d(3)?.squeeze(0)?;
assert_eq!(
test_utils::to_vec3_round(&pool, 4)?,
[[[0.085]], [[0.0078]]]
);
let t = t.reshape((1, 1, 4, 8))?;
let pool = t.avg_pool2d(2)?.squeeze(0)?.squeeze(0)?;
assert_eq!(
test_utils::to_vec2_round(&pool, 4)?,
[
[0.7745, 0.0276, -1.6983, 0.12],
[0.3542, 0.1625, 0.4542, -0.0014]
]
);
let pool = t.avg_pool2d((3, 3), (3, 3))?.squeeze(0)?;
assert_eq!(test_utils::to_vec3_round(pool, 4)?, [[[0.085]], [[0.0078]]]);
Ok(())
}

View File

@ -1,17 +1,5 @@
use candle_core::{
quantized::{self, GgmlDType},
test_utils::to_vec2_round,
Device, Result, Tensor,
};
use candle_core::{quantized, Device, Result, Tensor};
use quantized::{k_quants, GgmlType};
use rand::prelude::*;
const GGML_TEST_SIZE: usize = 32 * 128;
const GGML_MAX_QUANTIZATION_TOTAL_ERROR: f32 = 0.002;
const GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS: f32 = 0.0075;
const GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS: f32 = 0.0040;
const GGML_MAX_DOT_PRODUCT_ERROR: f32 = 0.02;
#[test]
fn quantized_matmul() -> Result<()> {
@ -26,10 +14,10 @@ fn quantized_matmul() -> Result<()> {
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
assert_eq!(
dst.iter().map(|x| x.round()).collect::<Vec<_>>(),
dst,
&[
85120.0, 214562.0, 345455.0, 474748.0, 213475.0, 604465.0, 1000686.0, 1388317.0,
341876.0, 994283.0, 1655709.0, 2301518.0
85120.43, 214561.61, 345454.9, 474748.1, 213474.94, 604465.25, 1000686.4, 1388317.3,
341875.88, 994283.0, 1655708.8, 2301518.3
]
);
let mm = tensor_lhs.matmul(&tensor_rhs)?;
@ -42,648 +30,17 @@ fn quantized_matmul() -> Result<()> {
]
);
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
let matmul = quantized::QMatMul::from_qtensor(qtensor);
let res = matmul.forward(&tensor_lhs)?;
let qtensor = quantized::QTensor::new(rhs_t, (64, 4));
let op = quantized::QMatMul::new(std::sync::Arc::new(qtensor));
let res = tensor_lhs.custom_op1(op)?;
assert_eq!(
to_vec2_round(&res, 0)?,
res.to_vec2::<f32>()?,
&[
[85120.0, 214562.0, 345455.0, 474748.0],
[213475.0, 604465.0, 1000686.0, 1388317.0],
[341876.0, 994283.0, 1655709.0, 2301518.0]
[85120.43, 214561.61, 345454.9, 474748.1],
[213474.94, 604465.25, 1000686.4, 1388317.3],
[341875.88, 994283.0, 1655708.8, 2301518.3]
]
);
Ok(())
}
#[test]
fn quantized_matmul_neg() -> Result<()> {
let cpu = &Device::Cpu;
let (m, k, n) = (3, 64, 4);
let lhs = (0..(m * k))
.map(|v| v as f32 - (m * k) as f32 / 2.0)
.collect::<Vec<_>>();
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
let mut dst = vec![42.; 3 * 4];
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
let rhs = (0..k * n)
.map(|v| v as f32 - (k * n) as f32 / 3.0)
.collect::<Vec<_>>();
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
assert_eq!(
dst.iter().map(|x| x.round()).collect::<Vec<_>>(),
&[
243524.0, -19596.0, -285051.0, -549815.0, 23777.0, 21651.0, 19398.0, 18367.0,
-196472.0, 63012.0, 324585.0, 587902.0
]
);
let mm = tensor_lhs.matmul(&tensor_rhs)?;
assert_eq!(
to_vec2_round(&mm, 0)?,
&[
[244064.0, -20128.0, -284320.0, -548512.0],
[23563.0, 21515.0, 19467.0, 17419.0],
[-196939.0, 63157.0, 323253.0, 583349.0]
]
);
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
let matmul = quantized::QMatMul::from_qtensor(qtensor);
let res = matmul.forward(&tensor_lhs)?;
assert_eq!(
to_vec2_round(&res, 0)?,
&[
[243524.0, -19596.0, -285051.0, -549815.0],
[23777.0, 21651.0, 19398.0, 18367.0],
[-196472.0, 63012.0, 324585.0, 587902.0]
]
);
Ok(())
}
#[test]
fn quantize_q4_0() -> Result<()> {
use k_quants::BlockQ4_0;
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
let mut dst = vec![0f32; 32 * 4];
let mut quant = vec![BlockQ4_0::zeros(); 4];
BlockQ4_0::from_float(&src, &mut quant)?;
BlockQ4_0::to_float(&quant, dst.as_mut_slice())?;
assert_eq!(
dst,
&[
-0.0, -0.0, 3.875, 3.875, 3.875, 3.875, 7.75, 7.75, 7.75, 7.75, 11.625, 11.625, 11.625,
11.625, 15.5, 15.5, 15.5, 15.5, 19.375, 19.375, 19.375, 19.375, 23.25, 23.25, 23.25,
23.25, 27.125, 27.125, 27.125, 27.125, 31.0, 31.0, 31.5, 31.5, 31.5, 31.5, 39.375,
39.375, 39.375, 39.375, 39.375, 39.375, 39.375, 39.375, 47.25, 47.25, 47.25, 47.25,
47.25, 47.25, 47.25, 47.25, 55.125, 55.125, 55.125, 55.125, 55.125, 55.125, 55.125,
55.125, 63.0, 63.0, 63.0, 63.0, 59.375, 59.375, 71.25, 71.25, 71.25, 71.25, 71.25,
71.25, 71.25, 71.25, 71.25, 71.25, 71.25, 71.25, 83.125, 83.125, 83.125, 83.125,
83.125, 83.125, 83.125, 83.125, 83.125, 83.125, 83.125, 83.125, 95.0, 95.0, 95.0, 95.0,
95.0, 95.0, 95.25, 95.25, 95.25, 95.25, 95.25, 95.25, 95.25, 95.25, 111.125, 111.125,
111.125, 111.125, 111.125, 111.125, 111.125, 111.125, 111.125, 111.125, 111.125,
111.125, 111.125, 111.125, 111.125, 111.125, 127.0, 127.0, 127.0, 127.0, 127.0, 127.0,
127.0, 127.0
]
);
ggml_quantization_error_test::<BlockQ4_0>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(())
}
#[test]
fn quantize_q4_1() -> Result<()> {
use k_quants::BlockQ4_1;
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
let mut dst = vec![0f32; 32 * 4];
let mut quant = vec![BlockQ4_1::zeros(); 4];
BlockQ4_1::from_float(&src, &mut quant)?;
BlockQ4_1::to_float(&quant, dst.as_mut_slice())?;
assert_eq!(
round_vector(&dst),
&[
0.0, 0.0, 2.066, 2.066, 4.133, 4.133, 6.199, 6.199, 8.266, 8.266, 10.332, 10.332,
12.398, 12.398, 14.465, 14.465, 16.531, 16.531, 18.598, 18.598, 20.664, 20.664, 22.73,
22.73, 24.797, 24.797, 26.863, 26.863, 28.93, 28.93, 30.996, 30.996, 32.0, 32.0,
34.066, 34.066, 36.133, 36.133, 38.199, 38.199, 40.266, 40.266, 42.332, 42.332, 44.398,
44.398, 46.465, 46.465, 48.531, 48.531, 50.598, 50.598, 52.664, 52.664, 54.73, 54.73,
56.797, 56.797, 58.863, 58.863, 60.93, 60.93, 62.996, 62.996, 64.0, 64.0, 66.066,
66.066, 68.133, 68.133, 70.199, 70.199, 72.266, 72.266, 74.332, 74.332, 76.398, 76.398,
78.465, 78.465, 80.531, 80.531, 82.598, 82.598, 84.664, 84.664, 86.73, 86.73, 88.797,
88.797, 90.863, 90.863, 92.93, 92.93, 94.996, 94.996, 96.0, 96.0, 98.066, 98.066,
100.133, 100.133, 102.199, 102.199, 104.266, 104.266, 106.332, 106.332, 108.398,
108.398, 110.465, 110.465, 112.531, 112.531, 114.598, 114.598, 116.664, 116.664,
118.73, 118.73, 120.797, 120.797, 122.863, 122.863, 124.93, 124.93, 126.996, 126.996
]
);
ggml_quantization_error_test::<BlockQ4_1>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(())
}
#[test]
fn quantize_q5_0() -> Result<()> {
use k_quants::BlockQ5_0;
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
let mut dst = vec![0f32; 32 * 4];
let mut quant = vec![BlockQ5_0::zeros(); 4];
BlockQ5_0::from_float(&src, &mut quant)?;
BlockQ5_0::to_float(&quant, dst.as_mut_slice())?;
assert_eq!(
round_vector(&dst),
&[
-0.0, 1.938, 1.938, 3.875, 3.875, 5.813, 5.813, 7.75, 7.75, 9.688, 9.688, 11.625,
11.625, 13.563, 13.563, 15.5, 15.5, 17.438, 17.438, 19.375, 19.375, 21.313, 21.313,
23.25, 23.25, 25.188, 25.188, 27.125, 27.125, 29.063, 29.063, 31.0, 31.5, 31.5, 35.438,
35.438, 35.438, 35.438, 39.375, 39.375, 39.375, 39.375, 43.313, 43.313, 43.313, 43.313,
47.25, 47.25, 47.25, 47.25, 51.188, 51.188, 51.188, 51.188, 55.125, 55.125, 55.125,
55.125, 59.063, 59.063, 59.063, 59.063, 63.0, 63.0, 65.313, 65.313, 65.313, 65.313,
65.313, 71.25, 71.25, 71.25, 71.25, 71.25, 71.25, 77.188, 77.188, 77.188, 77.188,
77.188, 77.188, 83.125, 83.125, 83.125, 83.125, 83.125, 83.125, 89.063, 89.063, 89.063,
89.063, 89.063, 89.063, 95.0, 95.0, 95.0, 95.25, 95.25, 95.25, 95.25, 103.188, 103.188,
103.188, 103.188, 103.188, 103.188, 103.188, 103.188, 111.125, 111.125, 111.125,
111.125, 111.125, 111.125, 111.125, 111.125, 119.063, 119.063, 119.063, 119.063,
119.063, 119.063, 119.063, 119.063, 127.0, 127.0, 127.0, 127.0
]
);
ggml_quantization_error_test::<BlockQ5_0>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(())
}
#[test]
fn quantize_q5_1() -> Result<()> {
use k_quants::BlockQ5_1;
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
let mut dst = vec![0f32; 32 * 4];
let mut quant = vec![BlockQ5_1::zeros(); 4];
BlockQ5_1::from_float(&src, &mut quant)?;
BlockQ5_1::to_float(&quant, dst.as_mut_slice())?;
assert_eq!(
dst,
&[
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0,
30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0,
44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0,
58.0, 59.0, 60.0, 61.0, 62.0, 63.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0,
72.0, 73.0, 74.0, 75.0, 76.0, 77.0, 78.0, 79.0, 80.0, 81.0, 82.0, 83.0, 84.0, 85.0,
86.0, 87.0, 88.0, 89.0, 90.0, 91.0, 92.0, 93.0, 94.0, 95.0, 96.0, 97.0, 98.0, 99.0,
100.0, 101.0, 102.0, 103.0, 104.0, 105.0, 106.0, 107.0, 108.0, 109.0, 110.0, 111.0,
112.0, 113.0, 114.0, 115.0, 116.0, 117.0, 118.0, 119.0, 120.0, 121.0, 122.0, 123.0,
124.0, 125.0, 126.0, 127.0
]
);
ggml_quantization_error_test::<BlockQ5_1>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(())
}
/// Generates a small test vector ranging from -`bound` to `bound` with `size` steps
fn get_test_vector(bound: f32, size: usize) -> (Vec<f32>, Vec<f32>) {
assert!(
size % crate::quantized::k_quants::QK_K == 0,
"size must be a multiple of {}",
crate::quantized::k_quants::QK_K
);
let src = (0..size)
.map(|v| (v as f32 - size as f32 / 2.) * bound / (size as f32 / 2.))
.collect::<Vec<_>>();
let dst = vec![0f32; size];
assert_eq!([src[0], src[size / 2]], [-bound, 0.0]);
(src, dst)
}
/// Round a vector
fn round_vector(values: &[f32]) -> Vec<f32> {
values
.iter()
.map(|x| (1000. * x).round() / 1000.)
.collect::<Vec<_>>()
}
fn compare_with_error(values: &[f32], expected: &[f32], tolerance: f32) {
for (i, (value, expected_value)) in values.iter().zip(expected.iter()).enumerate() {
let difference = (value - expected_value).abs();
assert!(
difference < tolerance,
"Error at index {}: value = {}, expected = {}. Difference = {} exceeds tolerance = {}.",
i,
value,
expected_value,
difference,
tolerance
);
}
}
/// Creates a vector simillarly to the one used in GGML unit tests: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
fn create_ggml_like_vector(offset: f32) -> Vec<f32> {
(0..GGML_TEST_SIZE)
.map(|i| 0.1 + 2.0 * (i as f32 + offset).cos())
.collect()
}
/// Calculates the root mean square error between two vectors
fn calculate_rmse(a: &[f32], b: &[f32]) -> f32 {
assert_eq!(a.len(), b.len());
let sum = a
.iter()
.zip(b)
.map(|(a, b)| (a - b).powi(2))
.sum::<f32>()
.sqrt();
sum / a.len() as f32
}
/// Mirrores the GGML quanitzation unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
fn ggml_quantization_error_test<T: GgmlType>(max_error: f32) -> Result<()> {
let src = create_ggml_like_vector(0.0);
let mut dst = vec![0.0; GGML_TEST_SIZE];
let _quant = quantize_roundtrip::<T>(src.as_slice(), dst.as_mut_slice())?;
let error = calculate_rmse(src.as_slice(), dst.as_slice());
if error > max_error {
candle_core::bail!(
"Quantization error {} exceeds max error {}",
error,
max_error
);
}
Ok(())
}
fn quantize_roundtrip<T: GgmlType>(src: &[f32], dst: &mut [f32]) -> Result<Vec<T>> {
let mut quant = vec![T::zeros(); src.len() / T::BLCK_SIZE];
T::from_float(src, &mut quant)?;
T::to_float(&quant, dst)?;
Ok(quant)
}
#[test]
fn quantize_q2k() -> Result<()> {
use k_quants::BlockQ2K;
let (src, mut dst) = get_test_vector(0.5, 1024);
let _quant = quantize_roundtrip::<BlockQ2K>(src.as_slice(), dst.as_mut_slice())?;
compare_with_error(dst.as_slice(), src.as_slice(), 0.1);
// Test some specific values
assert_eq!(
[src[0], src[128], src[256], src[512], src[800], src[1023]],
[-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344]
);
let dst = round_vector(&dst);
assert_eq!(
[dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]],
[-0.499, -0.366, -0.249, 0.0, 0.295, 0.492]
);
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
let _quant_big = quantize_roundtrip::<BlockQ2K>(src_big.as_slice(), dst_big.as_mut_slice())?;
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 6.0);
ggml_quantization_error_test::<BlockQ2K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS)?;
Ok(())
}
#[test]
fn quantize_q3k() -> Result<()> {
use k_quants::BlockQ3K;
let (src, mut dst) = get_test_vector(0.5, 1024);
let _quant = quantize_roundtrip::<BlockQ3K>(src.as_slice(), dst.as_mut_slice())?;
compare_with_error(dst.as_slice(), src.as_slice(), 0.03);
// Test some specific values
assert_eq!(
[src[0], src[128], src[256], src[512], src[800], src[1023]],
[-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344]
);
let dst = round_vector(&dst);
assert_eq!(
[dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]],
[-0.493, -0.37, -0.243, -0.0, 0.292, 0.492]
);
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
let _quant_big = quantize_roundtrip::<BlockQ3K>(src_big.as_slice(), dst_big.as_mut_slice())?;
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 3.5);
ggml_quantization_error_test::<BlockQ3K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS)?;
Ok(())
}
#[test]
fn quantize_q4k() -> Result<()> {
use k_quants::BlockQ4K;
let (src, mut dst) = get_test_vector(0.5, 1024);
let _quant = quantize_roundtrip::<BlockQ4K>(src.as_slice(), dst.as_mut_slice())?;
compare_with_error(dst.as_slice(), src.as_slice(), 0.017);
// Test some specific values
assert_eq!(
[src[0], src[128], src[256], src[512], src[800], src[1023]],
[-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344]
);
let dst = round_vector(&dst);
assert_eq!(
[dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]],
[-0.5, -0.373, -0.25, 0.0, 0.288, 0.498]
);
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
let _quant_big = quantize_roundtrip::<BlockQ4K>(src_big.as_slice(), dst_big.as_mut_slice())?;
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 4.5);
ggml_quantization_error_test::<BlockQ4K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(())
}
#[test]
fn quantize_q5k() -> Result<()> {
use k_quants::BlockQ5K;
let (src, mut dst) = get_test_vector(0.5, 1024);
let _quant = quantize_roundtrip::<BlockQ5K>(src.as_slice(), dst.as_mut_slice())?;
compare_with_error(dst.as_slice(), src.as_slice(), 0.008);
// Test some specific values
assert_eq!(
[src[0], src[128], src[256], src[512], src[800], src[1023]],
[-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344]
);
let dst = round_vector(&dst);
assert_eq!(
[dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]],
[-0.499, -0.372, -0.249, 0.001, 0.279, 0.499]
);
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
let _quant_big = quantize_roundtrip::<BlockQ5K>(src_big.as_slice(), dst_big.as_mut_slice())?;
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.5);
ggml_quantization_error_test::<BlockQ5K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(())
}
#[test]
fn quantize_q6k() -> Result<()> {
use k_quants::BlockQ6K;
let (src, mut dst) = get_test_vector(0.5, 1024);
let _quant = quantize_roundtrip::<BlockQ6K>(src.as_slice(), dst.as_mut_slice())?;
compare_with_error(dst.as_slice(), src.as_slice(), 0.008);
// Test some specific values
assert_eq!(
[src[0], src[128], src[256], src[512], src[800], src[1023]],
[-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344]
);
let dst = round_vector(&dst);
assert_eq!(
[dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]],
[-0.497, -0.372, -0.25, -0.0, 0.284, 0.5]
);
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
let _quant_big = quantize_roundtrip::<BlockQ6K>(src_big.as_slice(), dst_big.as_mut_slice())?;
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.0);
ggml_quantization_error_test::<BlockQ6K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(())
}
#[test]
fn quantize_q8k() -> Result<()> {
use k_quants::BlockQ8K;
let (src, mut dst) = get_test_vector(0.5, 1024);
let _quant = quantize_roundtrip::<BlockQ8K>(src.as_slice(), dst.as_mut_slice())?;
compare_with_error(dst.as_slice(), src.as_slice(), 0.003);
// Test some specific values
assert_eq!(
[src[0], src[128], src[256], src[512], src[800], src[1023]],
[-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344]
);
let dst = round_vector(&dst);
assert_eq!(
[dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]],
[-0.5, -0.375, -0.25, -0.0, 0.281, 0.499]
);
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
let _quant_big = quantize_roundtrip::<BlockQ8K>(src_big.as_slice(), dst_big.as_mut_slice())?;
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 0.6);
ggml_quantization_error_test::<BlockQ8K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(())
}
/// Very simple dot product implementation
fn vec_dot_reference(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b).map(|(a, b)| a * b).sum()
}
/// Returns the error achieved by the GGML matmul unit test.
fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result<f32> {
let err = match dtype {
GgmlDType::F16 => 0.000010,
GgmlDType::Q2K => 0.004086,
GgmlDType::Q3K => 0.016148,
GgmlDType::Q4K => 0.002425,
GgmlDType::Q5K => 0.000740,
GgmlDType::Q6K => 0.000952,
GgmlDType::Q4_0 => 0.001143,
GgmlDType::Q4_1 => 0.007784,
GgmlDType::Q5_0 => 0.001353,
GgmlDType::Q5_1 => 0.001363,
GgmlDType::Q8_0 => 0.000092,
_ => candle_core::bail!("No GGML results for quantization type {dtype:?}",),
};
Ok(err)
}
/// Mirrores the GGML matmul unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
let a = create_ggml_like_vector(0.0);
let b = create_ggml_like_vector(1.0);
let length = a.len();
let mut a_quant = vec![T::zeros(); length / T::BLCK_SIZE];
let mut b_quant = vec![T::VecDotType::zeros(); length / T::VecDotType::BLCK_SIZE];
T::from_float(&a, &mut a_quant)?;
T::VecDotType::from_float(&b, &mut b_quant)?;
let result = T::vec_dot(length, &a_quant, &b_quant)?;
let reference_result = vec_dot_reference(&a, &b);
let error = (result - reference_result).abs() / length as f32;
let ggml_error = ggml_reference_matmul_error(T::DTYPE)?;
if error > GGML_MAX_DOT_PRODUCT_ERROR {
candle_core::bail!(
"Dot product error {} exceeds max error {}",
error,
GGML_MAX_DOT_PRODUCT_ERROR
);
}
// We diverge slightly due to different rounding behavior / f16 to f32 conversions in GGML
// => we use a slightly higher error threshold
const ERROR_LENIENCY: f32 = 0.00001;
if error - ERROR_LENIENCY > ggml_error {
candle_core::bail!(
"Dot product error {} exceeds ggml reference error {}",
error,
ggml_error
);
}
Ok(())
}
/// generates random tensors of size `m x k` and `n x k` and calculates their expected matrix multiplication result.
fn get_random_tensors(
m: usize,
k: usize,
n: usize,
device: &Device,
) -> Result<(Tensor, Tensor, Tensor)> {
let mut rng = StdRng::seed_from_u64(314159265358979);
let lhs = (0..m * k)
.map(|_| rng.gen::<f32>() - 0.5)
.collect::<Vec<_>>();
let rhs = (0..n * k)
.map(|_| rng.gen::<f32>() - 0.5)
.collect::<Vec<_>>();
let lhs = Tensor::from_vec(lhs, (m, k), device)?;
let rhs = Tensor::from_vec(rhs, (n, k), device)?;
let mm = lhs.matmul(&rhs.t()?)?;
Ok((lhs, rhs, mm))
}
#[test]
fn quantized_matmul_q2k() -> Result<()> {
use k_quants::BlockQ2K;
let cpu = &Device::Cpu;
let (m, k, n) = (11, 512, 21);
let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?;
assert_eq!(mm.dims(), [m, n]);
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize::<BlockQ2K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs);
let mm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]);
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [0.916, 0.422, 0.215, 1.668]);
ggml_matmul_error_test::<BlockQ2K>()?;
Ok(())
}
#[test]
fn quantized_matmul_q3k() -> Result<()> {
use k_quants::BlockQ3K;
let cpu = &Device::Cpu;
let (m, k, n) = (11, 512, 21);
let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?;
assert_eq!(mm.dims(), [m, n]);
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize::<BlockQ3K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs);
let mm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]);
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.029, 1.418, -0.314, 1.495]);
ggml_matmul_error_test::<BlockQ3K>()?;
Ok(())
}
#[test]
fn quantized_matmul_q4k() -> Result<()> {
use k_quants::BlockQ4K;
let cpu = &Device::Cpu;
let (m, k, n) = (11, 512, 21);
let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?;
assert_eq!(mm.dims(), [m, n]);
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize::<BlockQ4K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs);
let mm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]);
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.125, 1.435, -0.201, 1.589]);
ggml_matmul_error_test::<BlockQ4K>()?;
Ok(())
}
#[test]
fn quantized_matmul_q5k() -> Result<()> {
use k_quants::BlockQ5K;
let cpu = &Device::Cpu;
let (m, k, n) = (11, 512, 21);
let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?;
assert_eq!(mm.dims(), [m, n]);
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize::<BlockQ5K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs);
let mm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]);
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.192, 1.491, -0.18, 1.743]);
//Expected: 0.000740408897
ggml_matmul_error_test::<BlockQ5K>()?;
Ok(())
}
#[test]
fn quantized_matmul_q6k() -> Result<()> {
use k_quants::BlockQ6K;
let cpu = &Device::Cpu;
let (m, k, n) = (11, 512, 21);
let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?;
assert_eq!(mm.dims(), [m, n]);
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize::<BlockQ6K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs);
let mm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]);
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.324, 1.49, -0.164, 1.741]);
ggml_matmul_error_test::<BlockQ6K>()?;
Ok(())
}

View File

@ -1,4 +1,5 @@
use candle_core::{test_device, DType, Device, IndexOp, Result, Tensor};
mod test_utils;
use candle_core::{DType, Device, IndexOp, Result, Tensor};
fn zeros(device: &Device) -> Result<()> {
let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
@ -35,10 +36,10 @@ fn tensor_2d(device: &Device) -> Result<()> {
fn binary_op(device: &Device) -> Result<()> {
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let tensor1 = Tensor::new(data, device)?;
let tensor = Tensor::new(data, device)?;
let data2 = &[[5f32, 5., 5., 5., 5.], [2., 1., 7., 8., 2.]];
let tensor2 = Tensor::new(data2, device)?;
let tensor = (&tensor1 + (&tensor1 * &tensor1)? / (&tensor1 + &tensor2))?;
let tensor = (&tensor + (&tensor * &tensor)? / (&tensor + &tensor2))?;
let dims = tensor.dims2()?;
assert_eq!(dims, (2, 5));
let content: Vec<Vec<f32>> = tensor.to_vec2()?;
@ -48,17 +49,6 @@ fn binary_op(device: &Device) -> Result<()> {
let tensor = (&tensor - &tensor)?;
let content: Vec<Vec<f32>> = tensor.to_vec2()?;
assert_eq!(content[0], [0., 0., 0., 0., 0.]);
let min = tensor1.minimum(&(&tensor2 * 0.5)?)?;
let max = tensor1.maximum(&(&tensor2 * 0.5)?)?;
assert_eq!(
min.to_vec2::<f32>()?,
[[2.5, 1.0, 2.5, 1.0, 2.5], [1.0, 0.5, 3.5, 4.0, 1.0]],
);
assert_eq!(
max.to_vec2::<f32>()?,
[[3.0, 2.5, 4.0, 2.5, 5.0], [2.0, 1.0, 7.0, 8.0, 2.0]]
);
Ok(())
}
@ -757,25 +747,6 @@ fn matmul(device: &Device) -> Result<()> {
Ok(())
}
fn broadcast_matmul(device: &Device) -> Result<()> {
let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?;
let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?;
let out = lhs.broadcast_matmul(&rhs)?;
assert_eq!(out.dims(), &[3, 6, 4, 2]);
for idx1 in 0..3 {
for idx2 in 0..6 {
let out = out.i((idx1, idx2))?;
let lhs = lhs.i((idx1, 0))?;
let rhs = rhs.i(idx2)?;
let out2 = lhs.matmul(&rhs);
let sum_diff2 = (out - out2)?.sqr()?.sum_all()?;
// With cuda, we see errors of up to ~1e-12.
assert!(sum_diff2.to_vec0::<f32>()? < 1e-6)
}
}
Ok(())
}
fn broadcasting(device: &Device) -> Result<()> {
let t1 = Tensor::arange(0f32, 24f32, device)?.reshape((4, 2, 3))?;
let t2 = Tensor::new(&[100f32, 200f32], device)?;
@ -893,7 +864,6 @@ test_device!(binary_op, binary_op_cpu, binary_op_gpu);
test_device!(embeddings, embeddings_cpu, embeddings_gpu);
test_device!(cmp, cmp_cpu, cmp_gpu);
test_device!(matmul, matmul_cpu, matmul_gpu);
test_device!(broadcast_matmul, broadcast_matmul_cpu, broadcast_matmul_gpu);
test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu);
test_device!(index_select, index_select_cpu, index_select_gpu);
test_device!(index_add, index_add_cpu, index_add_gpu);

View File

@ -1,4 +1,9 @@
use crate::{Result, Tensor};
#![allow(dead_code)]
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle_core::{Result, Tensor};
#[macro_export]
macro_rules! test_device {
@ -18,12 +23,6 @@ macro_rules! test_device {
};
}
pub fn to_vec0_round(t: &Tensor, digits: i32) -> Result<f32> {
let b = 10f32.powi(digits);
let t = t.to_vec0::<f32>()?;
Ok(f32::round(t * b) / b)
}
pub fn to_vec1_round(t: &Tensor, digits: i32) -> Result<Vec<f32>> {
let b = 10f32.powi(digits);
let t = t.to_vec1::<f32>()?;
@ -41,7 +40,7 @@ pub fn to_vec2_round(t: &Tensor, digits: i32) -> Result<Vec<Vec<f32>>> {
Ok(t)
}
pub fn to_vec3_round(t: &Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
pub fn to_vec3_round(t: Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
let b = 10f32.powi(digits);
let t = t.to_vec3::<f32>()?;
let t = t

View File

@ -11,13 +11,10 @@ readme = "README.md"
[dependencies]
byteorder = { workspace = true }
candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.2.1" }
candle = { path = "../candle-core", version = "0.1.1", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.1.1" }
hf-hub = { workspace = true}
intel-mkl-src = { workspace = true, optional = true }
memmap2 = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
rand = { workspace = true }
thiserror = { workspace = true }
parquet = { workspace = true}
image = { workspace = true }

View File

@ -1,73 +0,0 @@
use hf_hub::{
api::sync::{Api, ApiRepo},
Repo, RepoType,
};
use parquet::file::reader::SerializedFileReader;
use std::fs::File;
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("ApiError : {0}")]
ApiError(#[from] hf_hub::api::sync::ApiError),
#[error("IoError : {0}")]
IoError(#[from] std::io::Error),
#[error("ParquetError : {0}")]
ParquetError(#[from] parquet::errors::ParquetError),
}
fn sibling_to_parquet(
rfilename: &str,
repo: &ApiRepo,
) -> Result<SerializedFileReader<File>, Error> {
let local = repo.get(rfilename)?;
let file = File::open(local)?;
let reader = SerializedFileReader::new(file)?;
Ok(reader)
}
pub fn from_hub(api: &Api, dataset_id: String) -> Result<Vec<SerializedFileReader<File>>, Error> {
let repo = Repo::with_revision(
dataset_id,
RepoType::Dataset,
"refs/convert/parquet".to_string(),
);
let repo = api.repo(repo);
let info = repo.info()?;
let files: Result<Vec<_>, _> = info
.siblings
.into_iter()
.filter_map(|s| -> Option<Result<_, _>> {
let filename = s.rfilename;
if filename.ends_with(".parquet") {
let reader_result = sibling_to_parquet(&filename, &repo);
Some(reader_result)
} else {
None
}
})
.collect();
let files = files?;
Ok(files)
}
#[cfg(test)]
mod tests {
use super::*;
use parquet::file::reader::FileReader;
#[test]
fn test_dataset() {
let api = Api::new().unwrap();
let files = from_hub(
&api,
"hf-internal-testing/dummy_image_text_data".to_string(),
)
.unwrap();
assert_eq!(files.len(), 1);
assert_eq!(files[0].metadata().file_metadata().num_rows(), 20);
}
}

View File

@ -1,6 +1,5 @@
//! Datasets & Dataloaders for Candle
pub mod batcher;
pub mod hub;
pub mod nlp;
pub mod vision;

View File

@ -2,9 +2,7 @@
//!
//! The files can be obtained from the following link:
//! <http://yann.lecun.com/exdb/mnist/>
use candle::{DType, Device, Error, Result, Tensor};
use hf_hub::{api::sync::Api, Repo, RepoType};
use parquet::file::reader::{FileReader, SerializedFileReader};
use candle::{DType, Device, Result, Tensor};
use std::fs::File;
use std::io::{self, BufReader, Read};
@ -65,58 +63,3 @@ pub fn load_dir<T: AsRef<std::path::Path>>(dir: T) -> Result<crate::vision::Data
labels: 10,
})
}
fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor, Tensor)> {
let samples = parquet.metadata().file_metadata().num_rows() as usize;
let mut buffer_images: Vec<u8> = Vec::with_capacity(samples * 784);
let mut buffer_labels: Vec<u8> = Vec::with_capacity(samples);
for row in parquet.into_iter().flatten() {
for (_name, field) in row.get_column_iter() {
if let parquet::record::Field::Group(subrow) = field {
for (_name, field) in subrow.get_column_iter() {
if let parquet::record::Field::Bytes(value) = field {
let image = image::load_from_memory(value.data()).unwrap();
buffer_images.extend(image.to_luma8().as_raw());
}
}
} else if let parquet::record::Field::Long(label) = field {
buffer_labels.push(*label as u8);
}
}
}
let images = (Tensor::from_vec(buffer_images, (samples, 784), &Device::Cpu)?
.to_dtype(DType::F32)?
/ 255.)?;
let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
Ok((images, labels))
}
pub fn load() -> Result<crate::vision::Dataset> {
let api = Api::new().map_err(|e| Error::Msg(format!("Api error: {e}")))?;
let dataset_id = "mnist".to_string();
let repo = Repo::with_revision(
dataset_id,
RepoType::Dataset,
"refs/convert/parquet".to_string(),
);
let repo = api.repo(repo);
let test_parquet_filename = repo
.get("mnist/test/0000.parquet")
.map_err(|e| Error::Msg(format!("Api error: {e}")))?;
let train_parquet_filename = repo
.get("mnist/train/0000.parquet")
.map_err(|e| Error::Msg(format!("Api error: {e}")))?;
let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)
.map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)
.map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
let (test_images, test_labels) = load_parquet(test_parquet)?;
let (train_images, train_labels) = load_parquet(train_parquet)?;
Ok(crate::vision::Dataset {
train_images,
train_labels,
test_images,
test_labels,
labels: 10,
})
}

View File

@ -11,11 +11,11 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.2.1" }
candle-nn = { path = "../candle-nn", version = "0.2.1" }
candle-transformers = { path = "../candle-transformers", version = "0.2.1" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.1", optional = true }
candle = { path = "../candle-core", version = "0.1.1", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.1.1" }
candle-nn = { path = "../candle-nn", version = "0.1.1" }
candle-transformers = { path = "../candle-transformers", version = "0.1.1" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.1.1", optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
@ -23,17 +23,15 @@ num-traits = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
image = { workspace = true }
image = { workspace = true, optional = true }
[dev-dependencies]
anyhow = { workspace = true }
byteorder = { workspace = true }
clap = { workspace = true }
hf-hub = { workspace = true, features=["tokio"]}
imageproc = { workspace = true }
clap = { workspace = true }
memmap2 = { workspace = true }
rand = { workspace = true }
rusttype = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
tracing = { workspace = true }
tracing-chrome = { workspace = true }
@ -57,3 +55,7 @@ nccl = ["cuda", "cudarc/nccl", "dep:half"]
[[example]]
name = "llama_multiprocess"
required-features = ["cuda", "nccl", "flash-attn"]
[[example]]
name = "stable-diffusion"
required-features = ["image"]

View File

@ -1,8 +1,5 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
mod model;
use anyhow::{anyhow, Error as E, Result};
@ -62,16 +59,16 @@ impl Args {
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
let cache = Cache::default().repo(repo);
let cache = Cache::default();
(
cache
.get("config.json")
.get(&repo, "config.json")
.ok_or(anyhow!("Missing config file in cache"))?,
cache
.get("tokenizer.json")
.get(&repo, "tokenizer.json")
.ok_or(anyhow!("Missing tokenizer file in cache"))?,
cache
.get("model.safetensors")
.get(&repo, "model.safetensors")
.ok_or(anyhow!("Missing weights file in cache"))?,
)
} else {

View File

@ -1,5 +1,5 @@
use candle::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module, VarBuilder};
use candle_nn::{Embedding, VarBuilder};
use serde::Deserialize;
pub const DTYPE: DType = DType::F32;

View File

@ -1,9 +1,6 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;

View File

@ -1,5 +1,5 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
use candle_nn::{Embedding, LayerNorm, Linear, VarBuilder};
fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?;

View File

@ -2,16 +2,19 @@
// own forward pass (CPU and GPU versions) as well as their backward pass.
//
// In this example we add the RMS normalization operation and implement it for f32.
#![allow(dead_code)]
#![allow(unused)]
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[allow(unused)]
mod cuda_kernels;
use clap::Parser;
use candle::{CpuStorage, CustomOp1, Layout, Result, Shape, Tensor};
use candle::backend::BackendStorage;
use candle::cpu_backend;
use candle::{CpuStorage, CustomOp1, DType, Device, Layout, Result, Shape, Tensor};
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
@ -54,9 +57,8 @@ impl CustomOp1 for LayerNorm {
storage: &candle::CudaStorage,
layout: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::backend::BackendStorage;
use candle::cuda_backend::cudarc::driver::{LaunchAsync, LaunchConfig};
use candle::cuda_backend::WrapErr;
use candle::cuda_backend::{cudarc, WrapErr};
use cudarc::driver::{LaunchAsync, LaunchConfig};
let (d1, d2) = layout.shape().dims2()?;
let d1 = d1 as u32;
let d2 = d2 as u32;
@ -87,7 +89,7 @@ fn main() -> anyhow::Result<()> {
let device = candle_examples::device(args.cpu)?;
let t = Tensor::arange(0f32, 14f32, &device)?.reshape((2, 7))?;
println!("{t}");
let t = t.apply_op1(LayerNorm { eps: 1e-5 })?;
let t = t.custom_op1(LayerNorm { eps: 1e-5 })?;
println!("{t}");
Ok(())
}

View File

@ -1,339 +0,0 @@
//! DINOv2: Learning Robust Visual Features without Supervision
//! https://github.com/facebookresearch/dinov2
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::Parser;
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
const IMG_SIZE: usize = 518;
const PATCH_SIZE: usize = 14;
const NUM_CLASSES: usize = 1000;
fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
if bias {
candle_nn::linear(in_dim, out_dim, vb)
} else {
candle_nn::linear_no_bias(in_dim, out_dim, vb)
}
}
#[derive(Debug)]
struct Attention {
qkv: Linear,
proj: Linear,
num_heads: usize,
scale: f64,
}
impl Attention {
fn new(
vb: VarBuilder,
dim: usize,
num_heads: usize,
qkv_bias: bool,
proj_bias: bool,
) -> Result<Self> {
let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?;
let scale = 1. / ((dim / num_heads) as f64).sqrt();
Ok(Self {
qkv,
proj,
num_heads,
scale,
})
}
}
impl Module for Attention {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (b, n, c) = xs.dims3()?;
let qkv = self
.qkv
.forward(xs)?
.reshape((b, n, 3, self.num_heads, c / self.num_heads))?
.transpose(1, 2)? // 02134
.transpose(0, 1)? // 20134
.transpose(2, 3)?; // 20314
let q = (qkv.i(0)? * self.scale)?;
let k = qkv.i(1)?;
let v = qkv.i(2)?;
let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?;
let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?;
self.proj.forward(&attn)
}
}
#[derive(Debug)]
struct LayerScale {
gamma: Tensor,
}
impl LayerScale {
fn new(vb: VarBuilder, dim: usize) -> Result<Self> {
let gamma = vb.get(dim, "gamma")?;
Ok(Self { gamma })
}
}
impl Module for LayerScale {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.broadcast_mul(&self.gamma)
}
}
#[derive(Debug)]
struct Mlp {
fc1: Linear,
fc2: Linear,
}
impl Mlp {
fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result<Self> {
let out_features = in_features;
let fc1 = linear(vb.pp("fc1"), in_features, hidden_features, bias)?;
let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?;
Ok(Self { fc1, fc2 })
}
}
impl Module for Mlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.fc1.forward(xs)?.gelu()?;
self.fc2.forward(&xs)
}
}
#[derive(Debug)]
struct Block {
norm1: LayerNorm,
attn: Attention,
ls1: LayerScale,
norm2: LayerNorm,
mlp: Mlp,
ls2: LayerScale,
}
impl Block {
fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> {
let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?;
let ls1 = LayerScale::new(vb.pp("ls1"), dim)?;
let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?;
let ls2 = LayerScale::new(vb.pp("ls2"), dim)?;
Ok(Self {
norm1,
attn,
ls1,
norm2,
mlp,
ls2,
})
}
}
impl Module for Block {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let residual = xs;
let xs = self
.ls1
.forward(&self.attn.forward(&self.norm1.forward(xs)?)?)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = self
.ls2
.forward(&self.mlp.forward(&self.norm2.forward(&xs)?)?)?;
xs + residual
}
}
#[derive(Debug)]
struct PatchEmbed {
proj: candle_nn::Conv2d,
patch_size: (usize, usize),
num_patches: usize,
}
impl PatchEmbed {
fn new(
vb: VarBuilder,
img_size: usize,
patch_size: usize,
in_chans: usize,
embed_dim: usize,
) -> Result<Self> {
let config = candle_nn::Conv2dConfig {
stride: patch_size,
..Default::default()
};
let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?;
let num_patches = (img_size / patch_size) * (img_size / patch_size);
Ok(Self {
proj,
patch_size: (patch_size, patch_size),
num_patches,
})
}
}
impl Module for PatchEmbed {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (_b, _c, h, w) = xs.dims4()?;
let (patch_h, patch_w) = self.patch_size;
if (h % patch_h) != 0 {
candle::bail!("image height {h} is not a multiple of patch height {patch_h}")
}
if (w % patch_w) != 0 {
candle::bail!("image width {w} is not a multiple of patch width {patch_w}")
}
let xs = self.proj.forward(xs)?;
let (b, c, h, w) = xs.dims4()?;
// flatten embeddings.
xs.reshape((b, c, h * w))?.transpose(1, 2)
}
}
#[derive(Debug)]
pub struct DinoVisionTransformer {
patch_embed: PatchEmbed,
cls_token: Tensor,
pos_embed: Tensor,
blocks: Vec<Block>,
norm: LayerNorm,
head: Linear,
}
impl DinoVisionTransformer {
pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result<Self> {
let patch_embed =
PatchEmbed::new(vb.pp("patch_embed"), IMG_SIZE, PATCH_SIZE, 3, embed_dim)?;
let cls_token = vb.get((1, 1, embed_dim), "cls_token")?;
let num_tokens = 1;
let pos_embed = vb.get(
(1, patch_embed.num_patches + num_tokens, embed_dim),
"pos_embed",
)?;
let head = linear(vb.pp("head"), 2 * embed_dim, NUM_CLASSES, true)?;
let norm = layer_norm(embed_dim, 1e-5, vb.pp("norm"))?;
let vb_b = vb.pp("blocks");
let blocks = (0..depth)
.map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
.collect::<Result<Vec<_>>>()?;
Ok(Self {
patch_embed,
cls_token,
pos_embed,
blocks,
norm,
head,
})
}
fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> {
let npatch = xs.dim(1)? - 1;
let n = self.pos_embed.dim(1)? - 1;
let sqrt_n = (n as f64).sqrt();
if npatch == n && w == h {
return Ok(xs.clone());
}
let class_pos_embed = self.pos_embed.i((.., ..1))?;
let patch_pos_embed = self.pos_embed.i((.., 1..))?;
let dim = xs.dim(D::Minus1)?;
let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1);
let patch_pos_embed = patch_pos_embed
.reshape((1, sqrt_n as usize, sqrt_n as usize, dim))?
.transpose(2, 3)?
.transpose(1, 2)?;
// This uses bicubic interpolation in the original implementation.
let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?;
let el_count = patch_pos_embed.shape().elem_count();
let patch_pos_embed =
patch_pos_embed
.transpose(1, 2)?
.transpose(2, 3)?
.reshape((1, el_count / dim, dim))?;
Tensor::cat(&[&class_pos_embed, &patch_pos_embed], 1)
}
fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> {
let (_b, _nc, w, h) = xs.dims4()?;
let xs = self.patch_embed.forward(xs)?;
let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
&xs + &self.interpolate_pos_encoding(&xs, w, h)?
}
}
impl Module for DinoVisionTransformer {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = self.prepare_tokens_with_mask(xs)?;
for blk in self.blocks.iter() {
xs = blk.forward(&xs)?
}
let xs = self.norm.forward(&xs)?;
let xs_norm_clstoken = xs.i((.., 0))?;
let xs_norm_patchtokens = xs.i((.., 1..))?.mean(1)?;
let xs = Tensor::cat(&[xs_norm_clstoken, xs_norm_patchtokens], D::Minus1)?;
self.head.forward(&xs)
}
}
pub fn vit_small(vb: VarBuilder) -> Result<DinoVisionTransformer> {
DinoVisionTransformer::new(vb, 12, 384, 6)
}
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?;
println!("loaded image {image:?}");
let model_file = match args.model {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("lmz/candle-dino-v2".into());
api.get("dinov2_vits14.safetensors")?
}
Some(model) => model.into(),
};
let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
let model = vit_small(vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for &(category_idx, pr) in prs.iter().take(5) {
println!(
"{:24}: {:.2}%",
candle_examples::imagenet::CLASSES[category_idx],
100. * pr
);
}
Ok(())
}

View File

@ -1,429 +0,0 @@
//! EfficientNet implementation.
//!
//! https://arxiv.org/abs/1905.11946
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn as nn;
use nn::{Module, VarBuilder};
// Based on the Python version from torchvision.
// https://github.com/pytorch/vision/blob/0d75d9e5516f446c9c0ef93bd4ed9fea13992d06/torchvision/models/efficientnet.py#L47
#[derive(Debug, Clone, Copy)]
pub struct MBConvConfig {
expand_ratio: f64,
kernel: usize,
stride: usize,
input_channels: usize,
out_channels: usize,
num_layers: usize,
}
fn make_divisible(v: f64, divisor: usize) -> usize {
let min_value = divisor;
let new_v = usize::max(
min_value,
(v + divisor as f64 * 0.5) as usize / divisor * divisor,
);
if (new_v as f64) < 0.9 * v {
new_v + divisor
} else {
new_v
}
}
fn bneck_confs(width_mult: f64, depth_mult: f64) -> Vec<MBConvConfig> {
let bneck_conf = |e, k, s, i, o, n| {
let input_channels = make_divisible(i as f64 * width_mult, 8);
let out_channels = make_divisible(o as f64 * width_mult, 8);
let num_layers = (n as f64 * depth_mult).ceil() as usize;
MBConvConfig {
expand_ratio: e,
kernel: k,
stride: s,
input_channels,
out_channels,
num_layers,
}
};
vec![
bneck_conf(1., 3, 1, 32, 16, 1),
bneck_conf(6., 3, 2, 16, 24, 2),
bneck_conf(6., 5, 2, 24, 40, 2),
bneck_conf(6., 3, 2, 40, 80, 3),
bneck_conf(6., 5, 1, 80, 112, 3),
bneck_conf(6., 5, 2, 112, 192, 4),
bneck_conf(6., 3, 1, 192, 320, 1),
]
}
impl MBConvConfig {
fn b0() -> Vec<Self> {
bneck_confs(1.0, 1.0)
}
fn b1() -> Vec<Self> {
bneck_confs(1.0, 1.1)
}
fn b2() -> Vec<Self> {
bneck_confs(1.1, 1.2)
}
fn b3() -> Vec<Self> {
bneck_confs(1.2, 1.4)
}
fn b4() -> Vec<Self> {
bneck_confs(1.4, 1.8)
}
fn b5() -> Vec<Self> {
bneck_confs(1.6, 2.2)
}
fn b6() -> Vec<Self> {
bneck_confs(1.8, 2.6)
}
fn b7() -> Vec<Self> {
bneck_confs(2.0, 3.1)
}
}
/// Conv2D with same padding.
#[derive(Debug)]
struct Conv2DSame {
conv2d: nn::Conv2d,
s: usize,
k: usize,
}
impl Conv2DSame {
fn new(
vb: VarBuilder,
i: usize,
o: usize,
k: usize,
stride: usize,
groups: usize,
bias: bool,
) -> Result<Self> {
let conv_config = nn::Conv2dConfig {
stride,
groups,
..Default::default()
};
let conv2d = if bias {
nn::conv2d(i, o, k, conv_config, vb)?
} else {
nn::conv2d_no_bias(i, o, k, conv_config, vb)?
};
Ok(Self {
conv2d,
s: stride,
k,
})
}
}
impl Module for Conv2DSame {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let s = self.s;
let k = self.k;
let (_, _, ih, iw) = xs.dims4()?;
let oh = (ih + s - 1) / s;
let ow = (iw + s - 1) / s;
let pad_h = usize::max((oh - 1) * s + k - ih, 0);
let pad_w = usize::max((ow - 1) * s + k - iw, 0);
if pad_h > 0 || pad_w > 0 {
let xs = xs.pad_with_zeros(2, pad_h / 2, pad_h - pad_h / 2)?;
let xs = xs.pad_with_zeros(3, pad_w / 2, pad_w - pad_w / 2)?;
self.conv2d.forward(&xs)
} else {
self.conv2d.forward(xs)
}
}
}
#[derive(Debug)]
struct ConvNormActivation {
conv2d: Conv2DSame,
bn2d: nn::BatchNorm,
activation: bool,
}
impl ConvNormActivation {
fn new(
vb: VarBuilder,
i: usize,
o: usize,
k: usize,
stride: usize,
groups: usize,
) -> Result<Self> {
let conv2d = Conv2DSame::new(vb.pp("0"), i, o, k, stride, groups, false)?;
let bn2d = nn::batch_norm(o, 1e-3, vb.pp("1"))?;
Ok(Self {
conv2d,
bn2d,
activation: true,
})
}
fn no_activation(self) -> Self {
Self {
activation: false,
..self
}
}
}
impl Module for ConvNormActivation {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.conv2d.forward(xs)?;
let xs = self.bn2d.forward(&xs)?;
if self.activation {
swish(&xs)
} else {
Ok(xs)
}
}
}
#[derive(Debug)]
struct SqueezeExcitation {
fc1: Conv2DSame,
fc2: Conv2DSame,
}
impl SqueezeExcitation {
fn new(vb: VarBuilder, in_channels: usize, squeeze_channels: usize) -> Result<Self> {
let fc1 = Conv2DSame::new(vb.pp("fc1"), in_channels, squeeze_channels, 1, 1, 1, true)?;
let fc2 = Conv2DSame::new(vb.pp("fc2"), squeeze_channels, in_channels, 1, 1, 1, true)?;
Ok(Self { fc1, fc2 })
}
}
impl Module for SqueezeExcitation {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let residual = xs;
// equivalent to adaptive_avg_pool2d([1, 1])
let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
let xs = self.fc1.forward(&xs)?;
let xs = swish(&xs)?;
let xs = self.fc2.forward(&xs)?;
let xs = nn::ops::sigmoid(&xs)?;
residual.broadcast_mul(&xs)
}
}
#[derive(Debug)]
struct MBConv {
expand_cna: Option<ConvNormActivation>,
depthwise_cna: ConvNormActivation,
squeeze_excitation: SqueezeExcitation,
project_cna: ConvNormActivation,
config: MBConvConfig,
}
impl MBConv {
fn new(vb: VarBuilder, c: MBConvConfig) -> Result<Self> {
let vb = vb.pp("block");
let exp = make_divisible(c.input_channels as f64 * c.expand_ratio, 8);
let expand_cna = if exp != c.input_channels {
Some(ConvNormActivation::new(
vb.pp("0"),
c.input_channels,
exp,
1,
1,
1,
)?)
} else {
None
};
let start_index = if expand_cna.is_some() { 1 } else { 0 };
let depthwise_cna =
ConvNormActivation::new(vb.pp(start_index), exp, exp, c.kernel, c.stride, exp)?;
let squeeze_channels = usize::max(1, c.input_channels / 4);
let squeeze_excitation =
SqueezeExcitation::new(vb.pp(start_index + 1), exp, squeeze_channels)?;
let project_cna =
ConvNormActivation::new(vb.pp(start_index + 2), exp, c.out_channels, 1, 1, 1)?
.no_activation();
Ok(Self {
expand_cna,
depthwise_cna,
squeeze_excitation,
project_cna,
config: c,
})
}
}
impl Module for MBConv {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let use_res_connect =
self.config.stride == 1 && self.config.input_channels == self.config.out_channels;
let ys = match &self.expand_cna {
Some(expand_cna) => expand_cna.forward(xs)?,
None => xs.clone(),
};
let ys = self.depthwise_cna.forward(&ys)?;
let ys = self.squeeze_excitation.forward(&ys)?;
let ys = self.project_cna.forward(&ys)?;
if use_res_connect {
ys + xs
} else {
Ok(ys)
}
}
}
fn swish(s: &Tensor) -> Result<Tensor> {
s * nn::ops::sigmoid(s)?
}
#[derive(Debug)]
struct EfficientNet {
init_cna: ConvNormActivation,
blocks: Vec<MBConv>,
final_cna: ConvNormActivation,
classifier: nn::Linear,
}
impl EfficientNet {
fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> {
let f_p = p.pp("features");
let first_in_c = configs[0].input_channels;
let last_out_c = configs.last().unwrap().out_channels;
let final_out_c = 4 * last_out_c;
let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?;
let nconfigs = configs.len();
let mut blocks = vec![];
for (index, cnf) in configs.into_iter().enumerate() {
let f_p = f_p.pp(index + 1);
for r_index in 0..cnf.num_layers {
let cnf = if r_index == 0 {
cnf
} else {
MBConvConfig {
input_channels: cnf.out_channels,
stride: 1,
..cnf
}
};
blocks.push(MBConv::new(f_p.pp(r_index), cnf)?)
}
}
let final_cna =
ConvNormActivation::new(f_p.pp(nconfigs + 1), last_out_c, final_out_c, 1, 1, 1)?;
let classifier = nn::linear(final_out_c, nclasses, p.pp("classifier.1"))?;
Ok(Self {
init_cna,
blocks,
final_cna,
classifier,
})
}
}
impl Module for EfficientNet {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = self.init_cna.forward(xs)?;
for block in self.blocks.iter() {
xs = block.forward(&xs)?
}
let xs = self.final_cna.forward(&xs)?;
// Equivalent to adaptive_avg_pool2d([1, 1]) -> squeeze(-1) -> squeeze(-1)
let xs = xs.mean(D::Minus1)?.mean(D::Minus1)?;
self.classifier.forward(&xs)
}
}
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
B0,
B1,
B2,
B3,
B4,
B5,
B6,
B7,
}
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Variant of the model to use.
#[arg(value_enum, long, default_value_t = Which::B2)]
which: Which,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?;
println!("loaded image {image:?}");
let model_file = match args.model {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("lmz/candle-efficientnet".into());
let filename = match args.which {
Which::B0 => "efficientnet-b0.safetensors",
Which::B1 => "efficientnet-b1.safetensors",
Which::B2 => "efficientnet-b2.safetensors",
Which::B3 => "efficientnet-b3.safetensors",
Which::B4 => "efficientnet-b4.safetensors",
Which::B5 => "efficientnet-b5.safetensors",
Which::B6 => "efficientnet-b6.safetensors",
Which::B7 => "efficientnet-b7.safetensors",
};
api.get(filename)?
}
Some(model) => model.into(),
};
let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
let cfg = match args.which {
Which::B0 => MBConvConfig::b0(),
Which::B1 => MBConvConfig::b1(),
Which::B2 => MBConvConfig::b2(),
Which::B3 => MBConvConfig::b3(),
Which::B4 => MBConvConfig::b4(),
Which::B5 => MBConvConfig::b5(),
Which::B6 => MBConvConfig::b6(),
Which::B7 => MBConvConfig::b7(),
};
let model = EfficientNet::new(vb, cfg, candle_examples::imagenet::CLASS_COUNT as usize)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for &(category_idx, pr) in prs.iter().take(5) {
println!(
"{:24}: {:.2}%",
candle_examples::imagenet::CLASSES[category_idx],
100. * pr
);
}
Ok(())
}

View File

@ -22,8 +22,6 @@ struct TextGeneration {
device: Device,
tokenizer: Tokenizer,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
@ -33,8 +31,6 @@ impl TextGeneration {
seed: u64,
temp: Option<f64>,
device: &Device,
repeat_penalty: f32,
repeat_last_n: usize,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp);
Self {
@ -42,8 +38,6 @@ impl TextGeneration {
tokenizer,
logits_processor,
device: device.clone(),
repeat_penalty,
repeat_last_n,
}
}
@ -69,16 +63,6 @@ impl TextGeneration {
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input)?;
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
@ -132,14 +116,6 @@ struct Args {
#[arg(long, default_value = "refs/pr/43")]
revision: String,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.0)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> Result<()> {
@ -186,15 +162,7 @@ fn main() -> Result<()> {
let model = Falcon::load(vb, config)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
&device,
args.repeat_penalty,
args.repeat_last_n,
);
let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, args.temperature, &device);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -1,6 +1,6 @@
use anyhow::Result;
use candle::{DType, Device, Tensor, D};
use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
use candle_nn::{Embedding, LayerNorm, Linear, VarBuilder};
const MAX_SEQ_LEN: usize = 5000;

View File

@ -0,0 +1,28 @@
use anyhow::Result;
use clap::Parser;
use std::fs::File;
use candle::quantized::ggml_file::Content;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// GGML file to load, typically a .bin file generated by the quantize command from llama.cpp
#[arg(long)]
model: String,
}
fn main() -> Result<()> {
let args = Args::parse();
let mut file = File::open(args.model)?;
let start = std::time::Instant::now();
let model = Content::read(&mut file)?;
println!(
"Loaded {:?} tensors in {:?}",
model.tensors.len(),
start.elapsed()
);
Ok(())
}

View File

@ -12,17 +12,17 @@ extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use anyhow::{bail, Error as E, Result};
use anyhow::{Error as E, Result};
use clap::Parser;
use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use hf_hub::api::sync::Api;
use std::io::Write;
mod model;
use model::{Config, Llama, LlamaConfig};
use model::{Config, Llama};
const EOS_TOKEN: &str = "</s>";
const MAX_SEQ_LEN: usize = 4096;
@ -59,9 +59,9 @@ struct Args {
#[arg(long)]
prompt: Option<String>,
/// Use different dtype than f16
/// Use f32 computations rather than f16.
#[arg(long)]
dtype: Option<String>,
use_f32: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
@ -70,9 +70,6 @@ struct Args {
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
v1: bool,
@ -83,14 +80,6 @@ struct Args {
/// (same structure as huggingface online)
#[arg(long)]
local_weights: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.0)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> Result<()> {
@ -100,6 +89,7 @@ fn main() -> Result<()> {
let args = Args::parse();
let _guard = if args.tracing {
println!("tracing...");
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
@ -108,24 +98,18 @@ fn main() -> Result<()> {
};
let device = candle_examples::device(args.cpu)?;
let dtype = match args.dtype.as_deref() {
Some("f16") => DType::F16,
Some("bf16") => DType::BF16,
Some("f32") => DType::F32,
Some(dtype) => bail!("Unsupported dtype {dtype}"),
None => DType::F16,
let config = if args.v1 {
Config::config_7b_v1(args.use_flash_attn)
} else {
Config::config_7b_v2(args.use_flash_attn)
};
let (llama, tokenizer_filename, cache) = match args.npy {
let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
let (llama, tokenizer_filename) = match args.npy {
Some(filename) => {
let config = if args.v1 {
Config::config_7b_v1(args.use_flash_attn)
} else {
Config::config_7b_v2(args.use_flash_attn)
};
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
let vb = VarBuilder::from_npz(filename, dtype, &device)?;
let tokenizer = std::path::PathBuf::from("llama-tokenizer.json");
(Llama::load(vb, &cache, &config)?, tokenizer, cache)
(Llama::load(vb, &cache, &config)?, tokenizer)
}
None => {
let api = Api::new()?;
@ -137,21 +121,13 @@ fn main() -> Result<()> {
}
});
println!("loading the model weights from {model_id}");
let revision = args.revision.unwrap_or("main".to_string());
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let api = api.model(model_id);
let tokenizer_filename = match &args.local_weights {
Some(path) => (path.to_owned() + "tokenizer.json").into(),
_ => api.get("tokenizer.json")?,
};
let config_filename = match &args.local_weights {
Some(path) => (path.to_owned() + "config.json").into(),
_ => api.get("config.json")?,
};
let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let config = config.into_config(args.use_flash_attn);
let mut filenames = vec![];
for rfilename in [
"model-00001-of-00002.safetensors",
@ -177,10 +153,9 @@ fn main() -> Result<()> {
.iter()
.map(|h| Ok(h.deserialize()?))
.collect::<Result<Vec<_>>>()?;
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
let vb = VarBuilder::from_safetensors(tensors, dtype, &device);
(Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
(Llama::load(vb, &cache, &config)?, tokenizer_filename)
}
};
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
@ -208,16 +183,6 @@ fn main() -> Result<()> {
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = llama.forward(&input, index_pos)?;
let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(args.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
args.repeat_penalty,
&tokens[start_at..],
)?
};
index_pos += ctxt.len();
let next_token = logits_processor.sample(&logits)?;

View File

@ -1,54 +1,20 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, Module, VarBuilder};
use serde::Deserialize;
use candle_nn::{Embedding, VarBuilder};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use super::MAX_SEQ_LEN;
#[derive(Deserialize)]
pub struct LlamaConfig {
pub hidden_size: usize,
pub intermediate_size: usize,
pub vocab_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub num_key_value_heads: Option<usize>,
pub rms_norm_eps: f64,
#[serde(default = "default_rope")]
pub rope_theta: f32,
}
fn default_rope() -> f32 {
10_000.0
}
impl LlamaConfig {
pub fn into_config(self, use_flash_attn: bool) -> Config {
Config {
hidden_size: self.hidden_size,
intermediate_size: self.intermediate_size,
vocab_size: self.vocab_size,
num_hidden_layers: self.num_hidden_layers,
num_attention_heads: self.num_attention_heads,
num_key_value_heads: self.num_key_value_heads.unwrap_or(self.num_attention_heads),
rms_norm_eps: self.rms_norm_eps,
rope_theta: self.rope_theta,
use_flash_attn,
}
}
}
pub struct Config {
pub hidden_size: usize,
pub intermediate_size: usize,
pub vocab_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub num_key_value_heads: usize,
pub n_layer: usize,
pub n_head: usize,
pub n_embd: usize,
pub n_key_value_head: usize,
pub use_flash_attn: bool,
pub rms_norm_eps: f64,
pub rope_theta: f32,
}
impl Config {
@ -57,12 +23,12 @@ impl Config {
hidden_size: 4096,
intermediate_size: 11008,
vocab_size: 32000,
num_hidden_layers: 32,
num_attention_heads: 32,
num_key_value_heads: 32,
n_layer: 32,
n_head: 32,
n_embd: 4096,
n_key_value_head: 32,
use_flash_attn,
rms_norm_eps: 1e-6,
rope_theta: 10_000.0,
}
}
@ -71,12 +37,12 @@ impl Config {
hidden_size: 4096,
intermediate_size: 11008,
vocab_size: 32000,
num_hidden_layers: 32,
num_attention_heads: 32,
num_key_value_heads: 32,
n_layer: 32,
n_head: 32,
n_embd: 4096,
n_key_value_head: 32,
use_flash_attn,
rms_norm_eps: 1e-5,
rope_theta: 10_000.0,
}
}
}
@ -110,10 +76,10 @@ pub struct Cache {
impl Cache {
pub fn new(use_kv_cache: bool, dtype: DType, config: &Config, device: &Device) -> Result<Self> {
// precompute freqs_cis
let n_elem = config.hidden_size / config.num_attention_heads;
let n_elem = config.n_embd / config.n_head;
let theta: Vec<_> = (0..n_elem)
.step_by(2)
.map(|i| 1f32 / config.rope_theta.powf(i as f32 / n_elem as f32))
.map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
.collect();
let theta = Tensor::new(theta.as_slice(), device)?;
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
@ -128,7 +94,7 @@ impl Cache {
Ok(Self {
masks: Arc::new(Mutex::new(HashMap::new())),
use_kv_cache,
kvs: Arc::new(Mutex::new(vec![None; config.num_hidden_layers])),
kvs: Arc::new(Mutex::new(vec![None; config.n_layer])),
device: device.clone(),
cos,
sin,
@ -150,6 +116,10 @@ impl Cache {
}
}
fn silu(xs: &Tensor) -> Result<Tensor> {
xs / (xs.neg()?.exp()? + 1.0)?
}
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let span = tracing::span!(tracing::Level::TRACE, "linear");
let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
@ -162,20 +132,35 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
}
struct RmsNorm {
inner: candle_nn::RmsNorm,
scale: Tensor,
eps: f64,
span: tracing::Span,
}
impl RmsNorm {
fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
let inner = candle_nn::rms_norm(size, eps, vb)?;
Ok(Self { inner, span })
let scale = vb.get(size, "weight")?;
Ok(Self { scale, eps, span })
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
let in_dtype = x.dtype();
// This is a no-op if x's dtype is already f32.
let x = x.to_dtype(DType::F32)?;
let (b_sz, seq_len, hidden_size) = x.dims3()?;
let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
let size = self.scale.dims1()?;
let scale = self
.scale
.to_dtype(DType::F32)?
.broadcast_as((b_sz, seq_len, size))?;
let x = (scale * x_normed)?;
let x = x.to_dtype(in_dtype)?;
Ok(x)
}
}
@ -184,8 +169,8 @@ struct CausalSelfAttention {
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
num_attention_heads: usize,
num_key_value_heads: usize,
n_head: usize,
n_key_value_head: usize,
head_dim: usize,
cache: Cache,
use_flash_attn: bool,
@ -212,13 +197,13 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Ten
impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let _enter = self.span_rot.enter();
let (b_sz, _, seq_len, hidden_size) = x.dims4()?;
let (b_sz, _, seq_len, n_embd) = x.dims4()?;
let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?;
let x2 = x.narrow(D::Minus1, hidden_size / 2, hidden_size / 2)?;
let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd))?;
let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd))?;
let x1 = x.narrow(D::Minus1, 0, n_embd / 2)?;
let x2 = x.narrow(D::Minus1, n_embd / 2, n_embd / 2)?;
let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
Ok(rope)
@ -226,19 +211,19 @@ impl CausalSelfAttention {
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
let _enter = self.span.enter();
let (b_sz, seq_len, hidden_size) = x.dims3()?;
let (b_sz, seq_len, n_embd) = x.dims3()?;
let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?;
let v = self.v_proj.forward(x)?;
let q = q
.reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
.transpose(1, 2)?;
let k = k
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
.transpose(1, 2)?;
let mut v = v
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
.transpose(1, 2)?;
let q = self.apply_rotary_emb(&q, index_pos)?;
@ -287,13 +272,13 @@ impl CausalSelfAttention {
// Convert to contiguous as matmul doesn't support strided vs for now.
att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?
};
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?;
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
let y = self.o_proj.forward(&y)?;
Ok(y)
}
fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
let n_rep = self.num_attention_heads / self.num_key_value_heads;
let n_rep = self.n_head / self.n_key_value_head;
if n_rep == 1 {
Ok(x)
} else {
@ -301,7 +286,7 @@ impl CausalSelfAttention {
let x = x
.unsqueeze(2)?
.expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
.reshape((b_sz, n_kv_head, n_rep, seq_len, head_dim))?;
Ok(x)
}
}
@ -310,8 +295,8 @@ impl CausalSelfAttention {
let span = tracing::span!(tracing::Level::TRACE, "attn");
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
let size_in = cfg.hidden_size;
let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
let size_q = (cfg.hidden_size / cfg.n_head) * cfg.n_head;
let size_kv = (cfg.hidden_size / cfg.n_head) * cfg.n_key_value_head;
let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?;
let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?;
let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?;
@ -321,9 +306,9 @@ impl CausalSelfAttention {
k_proj,
v_proj,
o_proj,
num_attention_heads: cfg.num_attention_heads,
num_key_value_heads: cfg.num_key_value_heads,
head_dim: cfg.hidden_size / cfg.num_attention_heads,
n_head: cfg.n_head,
n_key_value_head: cfg.n_key_value_head,
head_dim: cfg.hidden_size / cfg.n_head,
cache: cache.clone(),
use_flash_attn: cfg.use_flash_attn,
span,
@ -349,7 +334,7 @@ struct Mlp {
impl Mlp {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let x = (candle_nn::ops::silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
self.c_proj.forward(&x)
}
@ -432,7 +417,7 @@ impl Llama {
let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.num_hidden_layers)
let blocks: Vec<_> = (0..cfg.n_layer)
.map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap())
.collect();

View File

@ -103,14 +103,6 @@ pub struct Args {
/// Tokenizer config file.
#[arg(long)]
tokenizer: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
impl Args {
@ -276,16 +268,6 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = model.forward(&input, index_pos)?;
let logits = logits.i((0, logits.dim(1)? - 1))?;
let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {
logits
} else {
let start_at = tokens.len().saturating_sub(common_args.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
common_args.repeat_penalty,
&tokens[start_at..],
)?
};
index_pos += ctxt.len();
let next_token = logits_processor.sample(&logits)?;

View File

@ -1,6 +1,6 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::linear_no_bias as linear;
use candle_nn::{embedding, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
use candle_nn::{embedding, Embedding, Linear, VarBuilder};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
@ -94,6 +94,32 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
xs / (xs.neg()?.exp()? + 1.0)?
}
struct RmsNorm {
scale: Tensor,
eps: f64,
}
impl RmsNorm {
fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
let scale = vb.get_or_init(size, "weight", candle_nn::Init::Const(1.))?;
Ok(Self { scale, eps })
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let (b_sz, seq_len, hidden_size) = x.dims3()?;
let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
let size = self.scale.dims1()?;
let scale = self
.scale
.to_dtype(DType::F32)?
.broadcast_as((b_sz, seq_len, size))?;
let x = (scale * x_normed)?;
Ok(x)
}
}
struct CausalSelfAttention {
q_proj: Linear,
k_proj: Linear,
@ -264,9 +290,9 @@ impl Block {
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
let input_layernorm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
let input_layernorm = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
let post_attention_layernorm =
rms_norm(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
Ok(Self::new(
input_layernorm,
attn,
@ -299,7 +325,7 @@ impl Llama {
pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
let ln_f = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
let ln_f = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.n_layers)
.map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, &cfg).unwrap())
.collect();

View File

@ -1,7 +1,6 @@
use crate::model::{Cache, Config, Llama};
use candle::{DType, Device, Result};
use candle_datasets::nlp::tinystories::{Dataset, DatasetRandomIter};
use candle_nn::Optimizer;
fn valid_loss(
dataset: &Dataset,

View File

@ -9,14 +9,15 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use anyhow::{bail, Error as E, Result};
use anyhow::{Error as E, Result};
use clap::Parser;
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use cudarc::driver::safe::CudaDevice;
use cudarc::nccl::safe::{Comm, Id};
use hf_hub::{api::sync::Api, Repo, RepoType};
use hf_hub::api::sync::Api;
use std::io::Write;
use std::rc::Rc;
@ -107,12 +108,6 @@ struct Args {
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
dtype: Option<String>,
}
fn main() -> Result<()> {
@ -120,13 +115,8 @@ fn main() -> Result<()> {
let args = Args::parse();
let dtype = match args.dtype.as_deref() {
Some("f16") => DType::F16,
Some("bf16") => DType::BF16,
Some("f32") => DType::F32,
Some(dtype) => bail!("Unsupported dtype {dtype}"),
None => DType::F16,
};
let config = Config::config_7b();
let dtype = DType::F16;
let api = Api::new()?;
@ -134,10 +124,7 @@ fn main() -> Result<()> {
.model_id
.unwrap_or_else(|| "meta-llama/Llama-2-7b-hf".to_string());
println!("loading the model weights from {model_id}");
let revision = args.revision.unwrap_or("main".to_string());
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let config_filename = api.get("config.json")?;
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let api = api.model(model_id);
let tokenizer_filename = api.get("tokenizer.json")?;
let mut filenames = vec![];
for rfilename in [
@ -198,7 +185,7 @@ fn main() -> Result<()> {
println!("Rank {rank:?} spawned");
let device = Device::new_cuda(i)?;
let cache = model::Cache::new(dtype, &config, &device)?;
let cache = model::Cache::new(&config, &device)?;
println!("building the model");
let handles = filenames
@ -210,7 +197,7 @@ fn main() -> Result<()> {
.map(|h| Ok(h.deserialize()?))
.collect::<Result<Vec<_>>>()?;
let vb = candle_nn::var_builder::ShardedSafeTensors::var_builder(tensors, dtype, &device);
let vb = VarBuilder::from_safetensors(tensors, dtype, &device);
let llama = Llama::load(vb, &cache, &config, comm)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
@ -244,7 +231,7 @@ fn main() -> Result<()> {
"{} token: {} '{}'",
index + 1,
next_token,
tokenizer.decode(&[next_token], true).map_err(E::msg)?
tokenizer.decode(vec![next_token], true).map_err(E::msg)?
);
}
}
@ -254,9 +241,7 @@ fn main() -> Result<()> {
"{} tokens generated ({} token/s)\n----\n{}\n----",
args.sample_len,
args.sample_len as f64 / dt.as_secs_f64(),
tokenizer
.decode(new_tokens.as_slice(), true)
.map_err(E::msg)?
tokenizer.decode(new_tokens, true).map_err(E::msg)?
);
}
Ok(())

View File

@ -1,16 +1,13 @@
use candle::backend::BackendStorage;
use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
use candle_nn::{Embedding, Linear, Module, RmsNorm};
use candle_nn::{Embedding, Linear, VarBuilder};
use cudarc::nccl::safe::{Comm, ReduceOp};
use half::f16;
use serde::Deserialize;
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use super::MAX_SEQ_LEN;
use candle_nn::var_builder::ShardedVarBuilder as VarBuilder;
struct TensorParallelColumnLinear {
linear: Linear,
}
@ -71,7 +68,7 @@ impl CustomOp1 for AllReduce {
}
fn all_reduce_sum(x: &Tensor, comm: &Rc<Comm>) -> Result<Tensor> {
x.apply_op1(AllReduce { comm: comm.clone() })
x.custom_op1(AllReduce { comm: comm.clone() })
}
impl TensorParallelRowLinear {
@ -84,19 +81,11 @@ impl TensorParallelRowLinear {
}
}
fn shard(dim: usize, rank: usize, world_size: usize) -> candle_nn::var_builder::Shard {
candle_nn::var_builder::Shard {
dim,
rank,
world_size,
}
}
impl TensorParallelColumnLinear {
fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
let rank = comm.rank();
let size = comm.world_size();
let weight = vb.get_with_hints((), "weight", shard(0, rank, size))?;
let weight = vb.get_sharded("weight", 0, rank, size)?;
Ok(Self::new(Linear::new(weight, None)))
}
@ -105,8 +94,8 @@ impl TensorParallelColumnLinear {
let size = comm.world_size();
let weights: Vec<_> = prefixes
.iter()
.map(|p| vb.pp(p).get_with_hints((), "weight", shard(0, rank, size)))
.collect::<Result<Vec<_>>>()?;
.map(|p| vb.pp(p).get_sharded("weight", 0, rank, size).unwrap())
.collect();
let weight = Tensor::cat(&weights, 0)?;
Ok(Self::new(Linear::new(weight, None)))
}
@ -116,26 +105,33 @@ impl TensorParallelRowLinear {
fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
let rank = comm.rank();
let size = comm.world_size();
let weight = vb.get_with_hints((), "weight", shard(1, rank, size))?;
let weight = vb.get_sharded("weight", 1, rank, size)?;
Ok(Self::new(Linear::new(weight, None), comm))
}
}
#[derive(Deserialize)]
pub struct Config {
pub hidden_size: usize,
pub intermediate_size: usize,
pub vocab_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub num_key_value_heads: usize,
pub rms_norm_eps: f64,
#[serde(default = "default_rope")]
pub rope_theta: f32,
pub n_layer: usize,
pub n_head: usize,
pub n_embd: usize,
pub n_key_value_head: usize,
}
fn default_rope() -> f32 {
10_000.0
impl Config {
pub fn config_7b() -> Self {
Self {
hidden_size: 4096,
intermediate_size: 11008,
vocab_size: 32000,
n_layer: 32,
n_head: 32,
n_embd: 4096,
n_key_value_head: 32,
}
}
}
#[derive(Clone)]
@ -147,12 +143,12 @@ pub struct Cache {
}
impl Cache {
pub fn new(dtype: DType, config: &Config, device: &Device) -> Result<Self> {
pub fn new(config: &Config, device: &Device) -> Result<Self> {
// precompute freqs_cis
let n_elem = config.hidden_size / config.num_attention_heads;
let n_elem = config.n_embd / config.n_head;
let theta: Vec<_> = (0..n_elem)
.step_by(2)
.map(|i| 1f32 / config.rope_theta.powf(i as f32 / n_elem as f32))
.map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
.collect();
let theta = Tensor::new(theta.as_slice(), device)?;
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
@ -162,10 +158,10 @@ impl Cache {
// This is different from the paper, see:
// https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
let cos = idx_theta.cos()?.to_dtype(dtype)?;
let sin = idx_theta.sin()?.to_dtype(dtype)?;
let cos = idx_theta.cos()?.to_dtype(DType::F16)?;
let sin = idx_theta.sin()?.to_dtype(DType::F16)?;
Ok(Self {
kvs: Arc::new(Mutex::new(vec![None; config.num_hidden_layers])),
kvs: Arc::new(Mutex::new(vec![None; config.n_layer])),
cos,
sin,
})
@ -186,24 +182,57 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
Ok(Embedding::new(embeddings, cfg.hidden_size))
}
struct RmsNorm {
scale: Tensor,
}
impl RmsNorm {
fn load(size: usize, vb: VarBuilder) -> Result<Self> {
let scale = vb.get(size, "weight")?;
Ok(Self::new(scale))
}
fn new(scale: Tensor) -> Self {
Self { scale }
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let in_dtype = x.dtype();
// This is a no-op if x's dtype is already f32.
let x = x.to_dtype(DType::F32)?;
let (b_sz, seq_len, hidden_size) = x.shape().dims3()?;
let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
let size = self.scale.shape().dims1()?;
let scale = self
.scale
.to_dtype(DType::F32)?
.broadcast_as((b_sz, seq_len, size))?;
let x = (scale * x_normed)?;
let x = x.to_dtype(in_dtype)?;
Ok(x)
}
}
struct CausalSelfAttention {
qkv_proj: TensorParallelColumnLinear,
o_proj: TensorParallelRowLinear,
num_attention_heads: usize,
num_key_value_heads: usize,
n_head: usize,
n_key_value_head: usize,
head_dim: usize,
cache: Cache,
}
impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (b_sz, _, seq_len, hidden_size) = x.shape().dims4()?;
let (b_sz, _, seq_len, n_embd) = x.shape().dims4()?;
let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?;
let x2 = x.narrow(D::Minus1, hidden_size / 2, hidden_size / 2)?;
let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd))?;
let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd))?;
let x1 = x.narrow(D::Minus1, 0, n_embd / 2)?;
let x2 = x.narrow(D::Minus1, n_embd / 2, n_embd / 2)?;
let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
Ok(rope)
@ -213,31 +242,30 @@ impl CausalSelfAttention {
let (b_sz, seq_len, _) = x.shape().dims3()?;
let qkv = self.qkv_proj.forward(x)?;
let hidden_size = self.num_attention_heads * self.head_dim;
let n_embd = self.n_head * self.head_dim;
let q = qkv.i((.., .., ..self.num_attention_heads * self.head_dim))?;
let q = qkv.i((.., .., ..self.n_head * self.head_dim))?;
let k = qkv.i((
..,
..,
self.num_attention_heads * self.head_dim
..self.num_attention_heads * self.head_dim
+ self.num_key_value_heads * self.head_dim,
self.n_head * self.head_dim
..self.n_head * self.head_dim + self.n_key_value_head * self.head_dim,
))?;
let v = qkv.i((
..,
..,
self.num_attention_heads * self.head_dim + self.num_key_value_heads * self.head_dim..,
self.n_head * self.head_dim + self.n_key_value_head * self.head_dim..,
))?;
// todo!("Q {:?} K {:?} V {:?} - x {:?}", q.shape(), k.shape(), v.shape(), x.shape());
let q = q
.reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
.transpose(1, 2)?;
let k = k
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
.transpose(1, 2)?;
let mut v = v
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
.transpose(1, 2)?;
let q = self.apply_rotary_emb(&q, index_pos)?;
@ -271,13 +299,13 @@ impl CausalSelfAttention {
let y = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?
.transpose(1, 2)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?;
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
let y = self.o_proj.forward(&y)?;
Ok(y)
}
fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
let n_rep = self.num_attention_heads / self.num_key_value_heads;
let n_rep = self.n_head / self.n_key_value_head;
if n_rep == 1 {
Ok(x)
} else {
@ -300,9 +328,9 @@ impl CausalSelfAttention {
Ok(Self {
qkv_proj,
o_proj,
num_attention_heads: cfg.num_attention_heads / comm.world_size(),
num_key_value_heads: cfg.num_key_value_heads / comm.world_size(),
head_dim: cfg.hidden_size / cfg.num_attention_heads,
n_head: cfg.n_head / comm.world_size(),
n_key_value_head: cfg.n_key_value_head / comm.world_size(),
head_dim: cfg.hidden_size / cfg.n_head,
cache: cache.clone(),
})
}
@ -347,11 +375,6 @@ struct Block {
mlp: Mlp,
}
fn rms_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<RmsNorm> {
let weight = vb.get_with_hints(size, "weight", shard(0, 0, 1))?;
Ok(RmsNorm::new(weight, eps))
}
impl Block {
fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
Self {
@ -374,9 +397,9 @@ impl Block {
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg, comm.clone())?;
let mlp = Mlp::load(vb.pp("mlp"), cfg, comm)?;
let input_layernorm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("input_layernorm"))?;
let input_layernorm = RmsNorm::load(cfg.hidden_size, vb.pp("input_layernorm"))?;
let post_attention_layernorm =
rms_norm(cfg.hidden_size, 1e-5, vb.pp("post_attention_layernorm"))?;
RmsNorm::load(cfg.hidden_size, vb.pp("post_attention_layernorm"))?;
Ok(Self::new(
input_layernorm,
attn,
@ -418,8 +441,8 @@ impl Llama {
pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
let norm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.num_hidden_layers)
let norm = RmsNorm::load(cfg.hidden_size, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.n_layer)
.map(|i| {
Block::load(
vb.pp(&format!("model.layers.{i}")),

View File

@ -2,21 +2,17 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use rand::prelude::*;
use candle::{DType, Result, Tensor, D};
use candle_nn::{loss, ops, Conv2d, Linear, Module, Optimizer, VarBuilder, VarMap};
use candle_nn::{loss, ops, Linear, VarBuilder, VarMap};
const IMAGE_DIM: usize = 784;
const LABELS: usize = 10;
fn linear_z(in_dim: usize, out_dim: usize, vs: VarBuilder) -> Result<Linear> {
let ws = vs.get_with_hints((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
let bs = vs.get_with_hints(out_dim, "bias", candle_nn::init::ZERO)?;
let ws = vs.get_or_init((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
let bs = vs.get_or_init(out_dim, "bias", candle_nn::init::ZERO)?;
Ok(Linear::new(ws, Some(bs)))
}
@ -59,46 +55,6 @@ impl Model for Mlp {
}
}
#[derive(Debug)]
struct ConvNet {
conv1: Conv2d,
conv2: Conv2d,
fc1: Linear,
fc2: Linear,
dropout: candle_nn::Dropout,
}
impl ConvNet {
fn new(vs: VarBuilder) -> Result<Self> {
let conv1 = candle_nn::conv2d(1, 32, 5, Default::default(), vs.pp("c1"))?;
let conv2 = candle_nn::conv2d(32, 64, 5, Default::default(), vs.pp("c2"))?;
let fc1 = candle_nn::linear(1024, 1024, vs.pp("fc1"))?;
let fc2 = candle_nn::linear(1024, LABELS, vs.pp("fc2"))?;
let dropout = candle_nn::Dropout::new(0.5);
Ok(Self {
conv1,
conv2,
fc1,
fc2,
dropout,
})
}
fn forward(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
let (b_sz, _img_dim) = xs.dims2()?;
let xs = xs
.reshape((b_sz, 1, 28, 28))?
.apply(&self.conv1)?
.max_pool2d(2)?
.apply(&self.conv2)?
.max_pool2d(2)?
.flatten_from(1)?
.apply(&self.fc1)?
.relu()?;
self.dropout.forward(&xs, train)?.apply(&self.fc2)
}
}
struct TrainingArgs {
learning_rate: f64,
load: Option<String>,
@ -106,71 +62,6 @@ struct TrainingArgs {
epochs: usize,
}
fn training_loop_cnn(
m: candle_datasets::vision::Dataset,
args: &TrainingArgs,
) -> anyhow::Result<()> {
const BSIZE: usize = 64;
let dev = candle::Device::cuda_if_available(0)?;
let train_labels = m.train_labels;
let train_images = m.train_images.to_device(&dev)?;
let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
let mut varmap = VarMap::new();
let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
let model = ConvNet::new(vs.clone())?;
if let Some(load) = &args.load {
println!("loading weights from {load}");
varmap.load(load)?
}
let adamw_params = candle_nn::ParamsAdamW {
lr: args.learning_rate,
..Default::default()
};
let mut opt = candle_nn::AdamW::new(varmap.all_vars(), adamw_params)?;
let test_images = m.test_images.to_device(&dev)?;
let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
let n_batches = train_images.dim(0)? / BSIZE;
let mut batch_idxs = (0..n_batches).collect::<Vec<usize>>();
for epoch in 1..args.epochs {
let mut sum_loss = 0f32;
batch_idxs.shuffle(&mut thread_rng());
for batch_idx in batch_idxs.iter() {
let train_images = train_images.narrow(0, batch_idx * BSIZE, BSIZE)?;
let train_labels = train_labels.narrow(0, batch_idx * BSIZE, BSIZE)?;
let logits = model.forward(&train_images, true)?;
let log_sm = ops::log_softmax(&logits, D::Minus1)?;
let loss = loss::nll(&log_sm, &train_labels)?;
opt.backward_step(&loss)?;
sum_loss += loss.to_vec0::<f32>()?;
}
let avg_loss = sum_loss / n_batches as f32;
let test_logits = model.forward(&test_images, false)?;
let sum_ok = test_logits
.argmax(D::Minus1)?
.eq(&test_labels)?
.to_dtype(DType::F32)?
.sum_all()?
.to_scalar::<f32>()?;
let test_accuracy = sum_ok / test_labels.dims1()? as f32;
println!(
"{epoch:4} train loss {:8.5} test acc: {:5.2}%",
avg_loss,
100. * test_accuracy
);
}
if let Some(save) = &args.save {
println!("saving trained weights in {save}");
varmap.save(save)?
}
Ok(())
}
fn training_loop<M: Model>(
m: candle_datasets::vision::Dataset,
args: &TrainingArgs,
@ -190,7 +81,7 @@ fn training_loop<M: Model>(
varmap.load(load)?
}
let mut sgd = candle_nn::SGD::new(varmap.all_vars(), args.learning_rate)?;
let sgd = candle_nn::SGD::new(varmap.all_vars(), args.learning_rate);
let test_images = m.test_images.to_device(&dev)?;
let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
for epoch in 1..args.epochs {
@ -224,7 +115,6 @@ fn training_loop<M: Model>(
enum WhichModel {
Linear,
Mlp,
Cnn,
}
#[derive(Parser)]
@ -245,20 +135,12 @@ struct Args {
/// The file where to load the trained weights from, in safetensors format.
#[arg(long)]
load: Option<String>,
/// The directory where to load the dataset from, in ubyte format.
#[arg(long)]
local_mnist: Option<String>,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
// Load the dataset
let m = if let Some(directory) = args.local_mnist {
candle_datasets::vision::mnist::load_dir(directory)?
} else {
candle_datasets::vision::mnist::load()?
};
let m = candle_datasets::vision::mnist::load_dir("data")?;
println!("train-images: {:?}", m.train_images.shape());
println!("train-labels: {:?}", m.train_labels.shape());
println!("test-images: {:?}", m.test_images.shape());
@ -267,7 +149,6 @@ pub fn main() -> anyhow::Result<()> {
let default_learning_rate = match args.model {
WhichModel::Linear => 1.,
WhichModel::Mlp => 0.05,
WhichModel::Cnn => 0.001,
};
let training_args = TrainingArgs {
epochs: args.epochs,
@ -278,6 +159,5 @@ pub fn main() -> anyhow::Result<()> {
match args.model {
WhichModel::Linear => training_loop::<LinearModel>(m, &training_args),
WhichModel::Mlp => training_loop::<Mlp>(m, &training_args),
WhichModel::Cnn => training_loop_cnn(m, &training_args),
}
}

View File

@ -1,6 +1,6 @@
use crate::nn::conv1d_weight_norm;
use candle::{DType, IndexOp, Result, Tensor};
use candle_nn::{conv1d, Conv1d, Conv1dConfig, Module, VarBuilder};
use crate::nn::{conv1d, conv1d_weight_norm, Conv1d, Conv1dConfig, VarBuilder};
use anyhow::Result;
use candle::{DType, IndexOp, Tensor};
// Encodec Model
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py
@ -182,7 +182,7 @@ impl EncodecResidualVectorQuantizer {
fn decode(&self, codes: &Tensor) -> Result<Tensor> {
let mut quantized_out = Tensor::zeros((), DType::F32, codes.device())?;
if codes.dim(0)? != self.layers.len() {
candle::bail!(
anyhow::bail!(
"codes shape {:?} does not match the number of quantization layers {}",
codes.shape(),
self.layers.len()
@ -273,24 +273,14 @@ impl EncodecConv1d {
in_c,
out_c,
kernel_size,
Conv1dConfig {
padding: 0,
stride,
groups: 1,
dilation: 1,
},
Conv1dConfig { padding: 0, stride },
vb.pp("conv"),
)?,
NormType::None => conv1d(
in_c,
out_c,
kernel_size,
Conv1dConfig {
padding: 0,
stride,
groups: 1,
dilation: 1,
},
Conv1dConfig { padding: 0, stride },
vb.pp("conv"),
)?,
};
@ -320,7 +310,7 @@ impl EncodecResnetBlock {
let h = dim / cfg.compress;
let mut layer = Layer::new(vb.pp("block"));
if dilations.len() != 2 {
candle::bail!("expected dilations of size 2")
anyhow::bail!("expected dilations of size 2")
}
// TODO: Apply dilations!
layer.inc();
@ -369,7 +359,7 @@ impl<'a> Layer<'a> {
self.cnt += 1;
}
fn next(&mut self) -> VarBuilder {
fn next(&mut self) -> VarBuilder<'a> {
let vb = self.vb.pp(&self.cnt.to_string());
self.cnt += 1;
vb

View File

@ -7,21 +7,17 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
mod encodec_model;
mod musicgen_model;
mod nn;
mod t5_model;
use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
use nn::VarBuilder;
use anyhow::{Error as E, Result};
use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use candle::DType;
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
const DTYPE: DType = DType::F32;
@ -34,17 +30,11 @@ struct Args {
/// The model weight file, in safetensor format.
#[arg(long)]
model: Option<String>,
model: String,
/// The tokenizer config.
#[arg(long)]
tokenizer: Option<String>,
#[arg(
long,
default_value = "90s rock song with loud guitars and heavy drums"
)]
prompt: String,
tokenizer: String,
}
fn main() -> Result<()> {
@ -52,44 +42,13 @@ fn main() -> Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let tokenizer = match args.tokenizer {
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
None => Api::new()?
.model("facebook/musicgen-small".to_string())
.get("tokenizer.json")?,
};
let mut tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let tokenizer = tokenizer
.with_padding(None)
.with_truncation(None)
.map_err(E::msg)?;
let mut tokenizer = Tokenizer::from_file(args.tokenizer).map_err(E::msg)?;
let _tokenizer = tokenizer.with_padding(None).with_truncation(None);
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
None => Api::new()?
.repo(Repo::with_revision(
"facebook/musicgen-small".to_string(),
RepoType::Model,
"refs/pr/13".to_string(),
))
.get("model.safetensors")?,
};
let model = unsafe { candle::safetensors::MmapedFile::new(model)? };
let model = unsafe { candle::safetensors::MmapedFile::new(args.model)? };
let model = model.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device);
let config = GenConfig::small();
let model = MusicgenForConditionalGeneration::load(vb, config)?;
let tokens = tokenizer
.encode(args.prompt.as_str(), true)
.map_err(E::msg)?
.get_ids()
.to_vec();
println!("tokens: {tokens:?}");
let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
println!("{tokens:?}");
let embeds = model.text_encoder.forward(&tokens)?;
println!("{embeds}");
let _model = MusicgenForConditionalGeneration::load(vb, config)?;
Ok(())
}

View File

@ -1,9 +1,9 @@
use crate::{encodec_model, t5_model};
use candle::{DType, Device, Result, Tensor, D};
use candle_nn::{
embedding, layer_norm, linear_no_bias, Activation, Embedding, LayerNorm, Linear, Module,
VarBuilder,
use crate::nn::{
embedding, layer_norm, linear, Embedding, HiddenAct, LayerNorm, Linear, VarBuilder,
};
use crate::{encodec_model, t5_model};
use anyhow::Result;
use candle::{DType, Device, Tensor, D};
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83
#[derive(Debug, Clone, PartialEq)]
@ -15,7 +15,7 @@ pub struct Config {
num_attention_heads: usize,
layerdrop: f64,
use_cache: bool,
activation_function: Activation,
activation_function: HiddenAct,
hidden_size: usize,
dropout: f64,
attention_dropout: f64,
@ -39,7 +39,7 @@ impl Default for Config {
num_attention_heads: 16,
layerdrop: 0.0,
use_cache: true,
activation_function: Activation::Gelu, // TODO: Handle old style gelu.
activation_function: HiddenAct::Gelu, // TODO: Handle old style gelu.
hidden_size: 1024,
dropout: 0.1,
attention_dropout: 0.0,
@ -65,7 +65,7 @@ impl Config {
num_attention_heads: 16,
layerdrop: 0.0,
use_cache: true,
activation_function: Activation::Gelu, // TODO: Handle old style gelu.
activation_function: HiddenAct::Gelu, // TODO: Handle old style gelu.
hidden_size: 1024,
dropout: 0.1,
attention_dropout: 0.0,
@ -127,7 +127,7 @@ impl MusicgenSinusoidalPositionalEmbedding {
if seq_len > self.weights.dim(0)? {
self.weights = get_embedding(seq_len, self.embedding_dim)?
}
self.weights.narrow(0, 0, seq_len)
Ok(self.weights.narrow(0, 0, seq_len)?)
}
}
@ -148,10 +148,10 @@ impl MusicgenAttention {
let h = cfg.hidden_size;
let num_heads = cfg.num_attention_heads;
let head_dim = h / num_heads;
let k_proj = linear_no_bias(h, h, vb.pp("k_proj"))?;
let v_proj = linear_no_bias(h, h, vb.pp("v_proj"))?;
let q_proj = linear_no_bias(h, h, vb.pp("q_proj"))?;
let out_proj = linear_no_bias(h, h, vb.pp("out_proj"))?;
let k_proj = linear(h, h, false, vb.pp("k_proj"))?;
let v_proj = linear(h, h, false, vb.pp("v_proj"))?;
let q_proj = linear(h, h, false, vb.pp("q_proj"))?;
let out_proj = linear(h, h, false, vb.pp("out_proj"))?;
Ok(Self {
scaling: 1. / (head_dim as f64).sqrt(),
is_decoder: true,
@ -208,7 +208,7 @@ struct MusicgenDecoderLayer {
fc1: Linear,
fc2: Linear,
final_layer_norm: LayerNorm,
activation_fn: Activation,
activation_fn: HiddenAct,
}
impl MusicgenDecoderLayer {
@ -218,8 +218,8 @@ impl MusicgenDecoderLayer {
let self_attn_layer_norm = layer_norm(h, 1e-5, vb.pp("self_attn_layer_norm"))?;
let encoder_attn = MusicgenAttention::load(vb.pp("encoder_attn"), cfg)?;
let encoder_attn_layer_norm = layer_norm(h, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
let fc1 = linear_no_bias(h, cfg.ffn_dim, vb.pp("fc1"))?;
let fc2 = linear_no_bias(cfg.ffn_dim, h, vb.pp("fc2"))?;
let fc1 = linear(h, cfg.ffn_dim, false, vb.pp("fc1"))?;
let fc2 = linear(cfg.ffn_dim, h, false, vb.pp("fc2"))?;
let final_layer_norm = layer_norm(h, 1e-5, vb.pp("final_layer_norm"))?;
Ok(Self {
self_attn,
@ -341,7 +341,7 @@ impl MusicgenForCausalLM {
let h = cfg.hidden_size;
let decoder = MusicgenDecoder::load(vb.pp("model.decoder"), cfg)?;
let lm_heads = (0..cfg.num_codebooks)
.map(|i| linear_no_bias(h, cfg.vocab_size, vb.pp(&format!("lm_heads.{i}"))))
.map(|i| linear(h, cfg.vocab_size, false, vb.pp(&format!("lm_heads.{i}"))))
.collect::<Result<Vec<_>>>()?;
Ok(Self {
decoder,
@ -357,7 +357,7 @@ impl MusicgenForCausalLM {
let lm_logits = self
.lm_heads
.iter()
.map(|h| h.forward(&hidden_states))
.map(|h| Ok(h.forward(&hidden_states)?))
.collect::<Result<Vec<_>>>()?;
let lm_logits = Tensor::stack(&lm_logits, 1)?.reshape((
b_sz * self.num_codebooks,
@ -370,9 +370,9 @@ impl MusicgenForCausalLM {
#[derive(Debug)]
pub struct MusicgenForConditionalGeneration {
pub text_encoder: crate::t5_model::T5EncoderModel,
pub audio_encoder: crate::encodec_model::EncodecModel,
pub decoder: MusicgenForCausalLM,
text_encoder: crate::t5_model::T5EncoderModel,
audio_encoder: crate::encodec_model::EncodecModel,
decoder: MusicgenForCausalLM,
cfg: GenConfig,
}

View File

@ -1,5 +1,62 @@
use candle::Result;
use candle_nn::{Conv1d, Conv1dConfig, VarBuilder};
use anyhow::Result;
use candle::Tensor;
const MAX_SEQ_LEN: usize = 5000;
pub type VarBuilder<'a> = candle_nn::VarBuilder<'a>;
pub type Linear = candle_nn::Linear;
pub fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?;
let bias = if bias {
Some(vb.get(size2, "bias")?)
} else {
None
};
Ok(Linear::new(weight, bias))
}
pub type LayerNorm = candle_nn::LayerNorm;
pub fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
(Ok(weight), Ok(bias)) => (weight, bias),
(Err(err), _) | (_, Err(err)) => {
if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
(weight, bias)
} else {
return Err(err.into());
}
}
};
Ok(LayerNorm::new(weight, bias, eps))
}
#[derive(Debug)]
pub struct Dropout {
pr: f64,
}
impl Dropout {
pub fn new(pr: f64) -> Self {
Self { pr }
}
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
// TODO
Ok(x.clone())
}
}
pub type Embedding = candle_nn::Embedding;
pub fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
Ok(Embedding::new(embeddings, hidden_size))
}
pub type Conv1d = candle_nn::Conv1d;
pub type Conv1dConfig = candle_nn::Conv1dConfig;
// Applies weight norm for inference by recomputing the weight tensor. This
// does not apply to training.
@ -18,3 +75,17 @@ pub fn conv1d_weight_norm(
let bias = vb.get(out_c, "bias")?;
Ok(Conv1d::new(weight, Some(bias), config))
}
pub fn conv1d(
in_c: usize,
out_c: usize,
kernel_size: usize,
config: Conv1dConfig,
vb: VarBuilder,
) -> Result<Conv1d> {
let weight = vb.get((out_c, in_c, kernel_size), "weight")?;
let bias = vb.get(out_c, "bias")?;
Ok(Conv1d::new(weight, Some(bias), config))
}
pub type HiddenAct = candle_nn::Activation;

View File

@ -1,8 +1,9 @@
// T5 Text Encoder
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
use candle::{DType, Result, Tensor, D};
use candle_nn::{embedding, linear_no_bias, Activation, Embedding, Linear, Module, VarBuilder};
use crate::nn::{embedding, linear, Dropout, Embedding, HiddenAct, Linear, VarBuilder};
use anyhow::Result;
use candle::{DType, Tensor, D};
use std::sync::Arc;
#[derive(Debug, Clone, PartialEq)]
@ -19,7 +20,7 @@ pub struct Config {
dropout_rate: f64,
layer_norm_epsilon: f64,
initializer_factor: f64,
feed_forward_proj: Activation,
feed_forward_proj: HiddenAct,
is_decoder: bool,
is_encoder_decoder: bool,
use_cache: bool,
@ -42,7 +43,7 @@ impl Default for Config {
dropout_rate: 0.1,
layer_norm_epsilon: 1e-6,
initializer_factor: 1.0,
feed_forward_proj: Activation::Relu,
feed_forward_proj: HiddenAct::Relu,
is_decoder: false,
is_encoder_decoder: true,
use_cache: true,
@ -61,7 +62,7 @@ impl Config {
d_model: 768,
dropout_rate: 0.1,
eos_token_id: 1,
feed_forward_proj: Activation::Relu,
feed_forward_proj: HiddenAct::Relu,
initializer_factor: 1.0,
is_decoder: false,
is_encoder_decoder: true,
@ -96,9 +97,10 @@ impl T5LayerNorm {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let dtype = xs.dtype();
let xs_f32 = xs.to_dtype(DType::F32)?;
// variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
let xs = xs.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
let xs2_f32 = (&xs_f32 * &xs_f32)?;
let sum_xs2_f32 = xs2_f32.sum_keepdim(D::Minus1)?;
let variance = xs2_f32.broadcast_div(&sum_xs2_f32)?;
let xs = (xs / (variance + self.variance_epsilon)?.sqrt()?)?;
let xs = xs.to_dtype(dtype)?;
let xs = xs.broadcast_mul(&self.weight)?;
Ok(xs)
@ -109,23 +111,27 @@ impl T5LayerNorm {
struct T5DenseActDense {
wi: Linear,
wo: Linear,
act: Activation,
dropout: Dropout,
act: HiddenAct,
}
impl T5DenseActDense {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let wi = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi"))?;
let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
let wi = linear(cfg.d_model, cfg.d_ff, false, vb.pp("wi"))?;
let wo = linear(cfg.d_ff, cfg.d_model, false, vb.pp("wo"))?;
let dropout = Dropout::new(cfg.dropout_rate);
Ok(Self {
wi,
wo,
act: Activation::Relu,
dropout,
act: HiddenAct::Relu,
})
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.wi.forward(xs)?;
let xs = self.act.forward(&xs)?;
let xs = self.dropout.forward(&xs)?;
let xs = self.wo.forward(&xs)?;
Ok(xs)
}
@ -135,6 +141,7 @@ impl T5DenseActDense {
struct T5LayerFF {
dense_relu_dense: T5DenseActDense,
layer_norm: T5LayerNorm,
dropout: Dropout,
}
impl T5LayerFF {
@ -143,16 +150,18 @@ impl T5LayerFF {
let dense_relu_dense = T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?;
let layer_norm =
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
let dropout = Dropout::new(cfg.dropout_rate);
Ok(Self {
dense_relu_dense,
layer_norm,
dropout,
})
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let ys = self.layer_norm.forward(xs)?;
let ys = self.dense_relu_dense.forward(&ys)?;
let xs = (xs + ys)?;
let xs = (xs + self.dropout.forward(&ys)?)?;
Ok(xs)
}
}
@ -166,18 +175,15 @@ struct T5Attention {
n_heads: usize,
d_kv: usize,
relative_attention_bias: Option<Embedding>,
relative_attention_num_buckets: usize,
relative_attention_max_distance: usize,
inner_dim: usize,
}
impl T5Attention {
fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
let inner_dim = cfg.num_heads * cfg.d_kv;
let q = linear_no_bias(cfg.d_model, inner_dim, vb.pp("q"))?;
let k = linear_no_bias(cfg.d_model, inner_dim, vb.pp("k"))?;
let v = linear_no_bias(cfg.d_model, inner_dim, vb.pp("v"))?;
let o = linear_no_bias(inner_dim, cfg.d_model, vb.pp("o"))?;
let q = linear(cfg.d_model, inner_dim, false, vb.pp("q"))?;
let k = linear(cfg.d_model, inner_dim, false, vb.pp("k"))?;
let v = linear(cfg.d_model, inner_dim, false, vb.pp("v"))?;
let o = linear(inner_dim, cfg.d_model, false, vb.pp("o"))?;
let relative_attention_bias = if h {
let emb = embedding(
cfg.relative_attention_num_buckets,
@ -196,17 +202,10 @@ impl T5Attention {
n_heads: cfg.num_heads,
d_kv: cfg.d_kv,
relative_attention_bias,
relative_attention_num_buckets: cfg.relative_attention_num_buckets,
relative_attention_max_distance: cfg.relative_attention_max_distance,
inner_dim,
})
}
fn forward(
&self,
xs: &Tensor,
position_bias: Option<&Tensor>,
) -> Result<(Tensor, Option<Tensor>)> {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
// TODO: Apply the mask(s)?
// TODO: kv caching.
let (b_sz, seq_len) = (xs.dim(0)?, xs.dim(1)?);
@ -215,78 +214,19 @@ impl T5Attention {
let v = self.v.forward(xs)?;
let q = q
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
.transpose(1, 2)?
.contiguous()?;
.transpose(1, 2)?;
let k = k
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
.transpose(1, 2)?
.contiguous()?;
.transpose(1, 2)?;
let v = v
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
.transpose(1, 2)?
.contiguous()?;
.transpose(1, 2)?;
let scores = q.matmul(&k.t()?)?;
let (scores, position_bias) = match position_bias {
Some(position_bias) => ((scores + position_bias)?, Some(position_bias.clone())),
None => match &self.relative_attention_bias {
None => (scores, None),
Some(relative_attention_bias) => {
let query_length = seq_len;
let key_length = seq_len;
// This only handles the bidirectional case.
let num_buckets = self.relative_attention_num_buckets as u32 / 2;
let max_exact = num_buckets / 2;
let relative_position = (0..query_length as u32)
.map(|i| {
(0..key_length as u32)
.map(|j| {
if i < j {
if j - i < max_exact {
j - i + num_buckets
} else {
let b = f32::log(
(j - i) as f32 / max_exact as f32,
self.relative_attention_max_distance as f32
/ max_exact as f32,
) * (num_buckets - max_exact) as f32;
u32::min(
max_exact + num_buckets + b as u32,
self.relative_attention_num_buckets as u32 - 1,
)
}
} else if i - j < max_exact {
i - j
} else {
let b = f32::log(
(i - j) as f32 / max_exact as f32,
self.relative_attention_max_distance as f32
/ max_exact as f32,
) * (num_buckets - max_exact) as f32;
max_exact + b as u32
}
})
.collect::<Vec<u32>>()
})
.collect::<Vec<Vec<_>>>();
let relative_buckets = Tensor::new(relative_position, q.device())?;
let position_bias = relative_attention_bias
.forward(&relative_buckets)?
.permute((2, 0, 1))?
.unsqueeze(0)?;
((scores + &position_bias)?, Some(position_bias))
// TODO: position_bias_masked?
}
},
};
// TODO: position_bias_masked
let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1)?;
let attn_output = attn_weights.matmul(&v)?;
let attn_output = attn_output
.transpose(1, 2)?
.reshape((b_sz, seq_len, self.inner_dim))?;
let attn_output = self.o.forward(&attn_output)?;
Ok((attn_output, position_bias))
Ok(attn_output)
}
}
@ -294,6 +234,7 @@ impl T5Attention {
struct T5LayerSelfAttention {
self_attention: T5Attention,
layer_norm: T5LayerNorm,
dropout: Dropout,
}
impl T5LayerSelfAttention {
@ -301,21 +242,19 @@ impl T5LayerSelfAttention {
let self_attention = T5Attention::load(h, vb.pp("SelfAttention"), cfg)?;
let layer_norm =
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
let dropout = Dropout::new(cfg.dropout_rate);
Ok(Self {
self_attention,
layer_norm,
dropout,
})
}
fn forward(
&self,
xs: &Tensor,
position_bias: Option<&Tensor>,
) -> Result<(Tensor, Option<Tensor>)> {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let normed_xs = self.layer_norm.forward(xs)?;
let (ys, position_bias) = self.self_attention.forward(&normed_xs, position_bias)?;
let ys = self.self_attention.forward(&normed_xs)?;
let ys = (xs + ys)?;
Ok((ys, position_bias))
Ok(ys)
}
}
@ -357,12 +296,8 @@ impl T5Block {
})
}
fn forward(
&self,
xs: &Tensor,
position_bias: Option<&Tensor>,
) -> Result<(Tensor, Option<Tensor>)> {
let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias)?;
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = self.self_attn.forward(xs)?;
// TODO: clamp for f16?
if let Some(cross_attn) = &self.cross_attn {
xs = cross_attn.forward(&xs)?;
@ -370,7 +305,7 @@ impl T5Block {
}
let xs = self.ff.forward(&xs)?;
// TODO: clamp for f16?
Ok((xs, position_bias))
Ok(xs)
}
}
@ -379,6 +314,7 @@ struct T5Stack {
block: Vec<T5Block>,
shared: Arc<Embedding>,
final_layer_norm: T5LayerNorm,
dropout: Dropout,
}
impl T5Stack {
@ -391,24 +327,25 @@ impl T5Stack {
cfg.layer_norm_epsilon,
vb.pp("final_layer_norm"),
)?;
let dropout = Dropout::new(cfg.dropout_rate);
Ok(Self {
block,
shared: shared.clone(),
final_layer_norm,
dropout,
})
}
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
let input_embeds = self.shared.as_ref().forward(input_ids)?;
let (_b_sz, _seq_len) = (input_embeds.dim(0)?, input_embeds.dim(1)?);
let (_b_sz, _seq_len) = input_embeds.dims2()?;
let mut hidden_states = input_embeds;
let mut position_bias = None;
let mut hidden_states = self.dropout.forward(&input_embeds)?;
for block in self.block.iter() {
(hidden_states, position_bias) =
block.forward(&hidden_states, position_bias.as_ref())?
hidden_states = block.forward(&hidden_states)?
}
let hidden_states = self.final_layer_norm.forward(&hidden_states)?;
let hidden_states = self.dropout.forward(&hidden_states)?;
Ok(hidden_states)
}
}

View File

@ -1,367 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use std::io::Write;
use tokenizers::Tokenizer;
use candle::quantized::{ggml_file, gguf_file};
use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
mod model;
use model::ModelWeights;
const DEFAULT_PROMPT: &str = "My favorite theorem is ";
#[derive(Debug)]
enum Prompt {
Interactive,
Chat,
One(String),
}
#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
#[value(name = "7b")]
L7b,
#[value(name = "13b")]
L13b,
#[value(name = "70b")]
L70b,
#[value(name = "7b-chat")]
L7bChat,
#[value(name = "13b-chat")]
L13bChat,
#[value(name = "70b-chat")]
L70bChat,
#[value(name = "7b-code")]
L7bCode,
#[value(name = "13b-code")]
L13bCode,
#[value(name = "32b-code")]
L34bCode,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// GGML file to load, typically a .bin file generated by the quantize command from llama.cpp
#[arg(long)]
model: Option<String>,
/// The initial prompt, use 'interactive' for entering multiple prompts in an interactive way
/// and 'chat' for an interactive model where history of previous prompts and generated tokens
/// is preserved.
#[arg(long)]
prompt: Option<String>,
/// The length of the sample to generate (in tokens).
#[arg(short = 'n', long, default_value_t = 100)]
sample_len: usize,
/// The tokenizer config in json format.
#[arg(long)]
tokenizer: Option<String>,
/// The temperature used to generate samples, use 0 for greedy sampling.
#[arg(long, default_value_t = 0.8)]
temperature: f64,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Display the token for the specified prompt.
#[arg(long)]
verbose_prompt: bool,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
/// The model size to use.
#[arg(long, default_value = "7b")]
which: Which,
/// Group-Query Attention, use 8 for the 70B version of LLaMAv2.
#[arg(long)]
gqa: Option<usize>,
}
impl Args {
fn tokenizer(&self) -> anyhow::Result<Tokenizer> {
let tokenizer_path = match &self.tokenizer {
Some(config) => std::path::PathBuf::from(config),
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("hf-internal-testing/llama-tokenizer".to_string());
api.get("tokenizer.json")?
}
};
Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)
}
fn model(&self) -> anyhow::Result<std::path::PathBuf> {
let model_path = match &self.model {
Some(config) => std::path::PathBuf::from(config),
None => {
let (repo, filename) = match self.which {
Which::L7b => ("TheBloke/Llama-2-7B-GGML", "llama-2-7b.ggmlv3.q4_0.bin"),
Which::L13b => ("TheBloke/Llama-2-13B-GGML", "llama-2-13b.ggmlv3.q4_0.bin"),
Which::L70b => ("TheBloke/Llama-2-70B-GGML", "llama-2-70b.ggmlv3.q4_0.bin"),
Which::L7bChat => (
"TheBloke/Llama-2-7B-Chat-GGML",
"llama-2-7b-chat.ggmlv3.q4_0.bin",
),
Which::L13bChat => (
"TheBloke/Llama-2-13B-Chat-GGML",
"llama-2-13b-chat.ggmlv3.q4_0.bin",
),
Which::L70bChat => (
"TheBloke/Llama-2-70B-Chat-GGML",
"llama-2-70b-chat.ggmlv3.q4_0.bin",
),
Which::L7bCode => ("TheBloke/CodeLlama-7B-GGUF", "codellama-7b.Q8_0.gguf"),
Which::L13bCode => ("TheBloke/CodeLlama-13B-GGUF", "codellama-13b.Q8_0.gguf"),
Which::L34bCode => ("TheBloke/CodeLlama-34B-GGUF", "codellama-34b.Q8_0.gguf"),
};
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(repo.to_string());
api.get(filename)?
}
};
Ok(model_path)
}
}
fn print_token(next_token: u32, tokenizer: &Tokenizer) {
// Extracting the last token as a string is complicated, here we just apply some simple
// heuristics as it seems to work well enough for this example. See the following for more
// details:
// https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
if let Some(text) = tokenizer.id_to_token(next_token) {
let text = text.replace('▁', " ");
let ascii = text
.strip_prefix("<0x")
.and_then(|t| t.strip_suffix('>'))
.and_then(|t| u8::from_str_radix(t, 16).ok());
match ascii {
None => print!("{text}"),
Some(ascii) => {
if let Some(chr) = char::from_u32(ascii as u32) {
if chr.is_ascii() {
print!("{chr}")
}
}
}
}
let _ = std::io::stdout().flush();
}
}
fn format_size(size_in_bytes: usize) -> String {
if size_in_bytes < 1_000 {
format!("{}B", size_in_bytes)
} else if size_in_bytes < 1_000_000 {
format!("{:.2}KB", size_in_bytes as f64 / 1e3)
} else if size_in_bytes < 1_000_000_000 {
format!("{:.2}MB", size_in_bytes as f64 / 1e6)
} else {
format!("{:.2}GB", size_in_bytes as f64 / 1e9)
}
}
fn main() -> anyhow::Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let temperature = if args.temperature == 0. {
None
} else {
Some(args.temperature)
};
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature, args.repeat_penalty, args.repeat_last_n
);
let model_path = args.model()?;
let mut file = std::fs::File::open(&model_path)?;
let start = std::time::Instant::now();
let mut model = match model_path.extension().and_then(|v| v.to_str()) {
Some("gguf") => {
let model = gguf_file::Content::read(&mut file)?;
let mut total_size_in_bytes = 0;
for (_, tensor) in model.tensor_infos.iter() {
let elem_count = tensor.shape.elem_count();
total_size_in_bytes +=
elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size();
}
println!(
"loaded {:?} tensors ({}) in {:.2}s",
model.tensor_infos.len(),
&format_size(total_size_in_bytes),
start.elapsed().as_secs_f32(),
);
ModelWeights::from_gguf(model, &mut file)?
}
Some("ggml" | "bin") | Some(_) | None => {
let model = ggml_file::Content::read(&mut file)?;
let mut total_size_in_bytes = 0;
for (_, tensor) in model.tensors.iter() {
let elem_count = tensor.shape().elem_count();
total_size_in_bytes +=
elem_count * tensor.dtype().type_size() / tensor.dtype().blck_size();
}
println!(
"loaded {:?} tensors ({}) in {:.2}s",
model.tensors.len(),
&format_size(total_size_in_bytes),
start.elapsed().as_secs_f32(),
);
println!("params: {:?}", model.hparams);
let default_gqa = match args.which {
Which::L7b
| Which::L13b
| Which::L7bChat
| Which::L13bChat
| Which::L7bCode
| Which::L13bCode
| Which::L34bCode => 1,
Which::L70b | Which::L70bChat => 8,
};
ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
}
};
println!("model built");
let tokenizer = args.tokenizer()?;
let prompt = match args.prompt.as_deref() {
Some("chat") => Prompt::Chat,
Some("interactive") => Prompt::Interactive,
Some(s) => Prompt::One(s.to_string()),
None => Prompt::One(DEFAULT_PROMPT.to_string()),
};
let mut pre_prompt_tokens = vec![];
loop {
let prompt_str = match &prompt {
Prompt::One(prompt) => prompt.clone(),
Prompt::Interactive | Prompt::Chat => {
print!("> ");
std::io::stdout().flush()?;
let mut prompt = String::new();
std::io::stdin().read_line(&mut prompt)?;
if prompt.ends_with('\n') {
prompt.pop();
if prompt.ends_with('\r') {
prompt.pop();
}
}
prompt
}
};
print!("{}", &prompt_str);
let tokens = tokenizer
.encode(prompt_str, true)
.map_err(anyhow::Error::msg)?;
if args.verbose_prompt {
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
println!("{id:7} -> '{token}'");
}
}
let prompt_tokens = [&pre_prompt_tokens, tokens.get_ids()].concat();
let to_sample = args.sample_len.saturating_sub(1);
let prompt_tokens = if prompt_tokens.len() + to_sample > model::MAX_SEQ_LEN - 10 {
let to_remove = prompt_tokens.len() + to_sample + 10 - model::MAX_SEQ_LEN;
prompt_tokens[prompt_tokens.len().saturating_sub(to_remove)..].to_vec()
} else {
prompt_tokens
};
let mut all_tokens = vec![];
let mut logits_processor = LogitsProcessor::new(args.seed, temperature);
let start_prompt_processing = std::time::Instant::now();
let mut next_token = {
let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
let logits = model.forward(&input, 0)?;
let logits = logits.squeeze(0)?;
logits_processor.sample(&logits)?
};
let prompt_dt = start_prompt_processing.elapsed();
all_tokens.push(next_token);
print_token(next_token, &tokenizer);
let start_post_prompt = std::time::Instant::now();
for index in 0..to_sample {
let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
let logits = model.forward(&input, prompt_tokens.len() + index)?;
let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
logits
} else {
let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
args.repeat_penalty,
&all_tokens[start_at..],
)?
};
next_token = logits_processor.sample(&logits)?;
all_tokens.push(next_token);
print_token(next_token, &tokenizer);
}
let dt = start_post_prompt.elapsed();
println!(
"\n\n{:4} prompt tokens processed: {:.2} token/s",
prompt_tokens.len(),
prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
);
println!(
"{:4} tokens generated: {:.2} token/s",
to_sample,
to_sample as f64 / dt.as_secs_f64(),
);
match prompt {
Prompt::One(_) => break,
Prompt::Interactive => {}
Prompt::Chat => {
pre_prompt_tokens = [prompt_tokens.as_slice(), all_tokens.as_slice()].concat()
}
}
}
Ok(())
}

View File

@ -1,371 +0,0 @@
use std::collections::HashMap;
use candle::quantized::QTensor;
use candle::quantized::{ggml_file, gguf_file};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, Module};
pub const MAX_SEQ_LEN: usize = 4096;
struct RmsNorm {
inner: candle_nn::LayerNorm,
span: tracing::Span,
}
impl RmsNorm {
fn new(scale: QTensor, eps: f32) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
let scale = scale.dequantize(&Device::Cpu)?;
let inner = candle_nn::LayerNorm::rms_norm(scale, eps as f64);
Ok(Self { inner, span })
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
}
}
// QMatMul wrapper adding some tracing.
struct QMatMul {
inner: candle::quantized::QMatMul,
span: tracing::Span,
}
impl QMatMul {
fn from_qtensor(qtensor: QTensor) -> Self {
let inner = candle::quantized::QMatMul::from_qtensor(qtensor);
let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
Self { inner, span }
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(xs)
}
}
struct LayerWeights {
attention_wq: QMatMul,
attention_wk: QMatMul,
attention_wv: QMatMul,
attention_wo: QMatMul,
attention_norm: RmsNorm,
feed_forward_w1: QMatMul,
feed_forward_w2: QMatMul,
feed_forward_w3: QMatMul,
ffn_norm: RmsNorm,
n_head: usize,
n_kv_head: usize,
head_dim: usize,
cos: Tensor,
sin: Tensor,
kv_cache: Option<(Tensor, Tensor)>,
span_attn: tracing::Span,
span_rot: tracing::Span,
span_mlp: tracing::Span,
}
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
let shape = mask.shape();
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
let m = mask.where_cond(&on_true, on_false)?;
Ok(m)
}
impl LayerWeights {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let _enter = self.span_rot.enter();
let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
let cos = self
.cos
.narrow(0, index_pos, seq_len)?
.reshape((seq_len, n_embd / 2, 1))?;
let sin = self
.sin
.narrow(0, index_pos, seq_len)?
.reshape((seq_len, n_embd / 2, 1))?;
let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
// This mimics the llama.cpp behavior.
// https://github.com/ggerganov/llama.cpp/blob/1f0bccb27929e261744c979bc75114955da49e98/ggml.c#L12104-L12105
// The x0 and x1 value are interleaved on the n_embd (= head_dim) dimension.
// The resulting y0 and y1 are also interleaved with:
// y0 = x0*cos - x1*sin
// y1 = x0*sin + x1*cos
let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
let x0 = x.narrow(D::Minus1, 0, 1)?;
let x1 = x.narrow(D::Minus1, 1, 1)?;
let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
let rope = Tensor::cat(&[y0, y1], D::Minus1)?;
let rope = rope.flatten_from(D::Minus2)?;
Ok(rope)
}
fn forward_attn(&mut self, x: &Tensor, mask: &Tensor, index_pos: usize) -> Result<Tensor> {
let _enter = self.span_attn.enter();
let (b_sz, seq_len, n_embd) = x.dims3()?;
let q = self.attention_wq.forward(x)?;
let k = self.attention_wk.forward(x)?;
let v = self.attention_wv.forward(x)?;
let q = q
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
.transpose(1, 2)?;
let k = k
.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
.transpose(1, 2)?;
let v = v
.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
.transpose(1, 2)?;
let q = self.apply_rotary_emb(&q, index_pos)?;
let k = self.apply_rotary_emb(&k, index_pos)?;
let (k, v) = match &self.kv_cache {
None => (k, v),
Some((k_cache, v_cache)) => {
if index_pos == 0 {
(k, v)
} else {
let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?;
let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?;
(k, v)
}
}
};
self.kv_cache = Some((k.clone(), v.clone()));
// Support for MQA, useful for 70B models.
let k = self.repeat_kv(k)?;
let v = self.repeat_kv(v)?;
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let mask = mask.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
let y = att.matmul(&v.contiguous()?)?;
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
let y = self.attention_wo.forward(&y)?;
Ok(y)
}
fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
let n_rep = self.n_head / self.n_kv_head;
if n_rep == 1 {
Ok(x)
} else {
let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
let x = x
.unsqueeze(2)?
.expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
Ok(x)
}
}
}
pub struct ModelWeights {
tok_embeddings: Embedding,
layers: Vec<LayerWeights>,
norm: RmsNorm,
output: QMatMul,
masks: HashMap<usize, Tensor>,
span: tracing::Span,
span_output: tracing::Span,
}
fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tensor)> {
let theta: Vec<_> = (0..head_dim)
.step_by(2)
.map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
.collect();
let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?;
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)?
.to_dtype(DType::F32)?
.reshape((MAX_SEQ_LEN, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
let cos = idx_theta.cos()?;
let sin = idx_theta.sin()?;
Ok((cos, sin))
}
impl ModelWeights {
pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
let cpu = &Device::Cpu;
let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
let (cos, sin) = precomput_freqs_cis(head_dim, 10000.)?;
let tok_embeddings = ct.remove("tok_embeddings.weight")?;
let tok_embeddings = tok_embeddings.dequantize(cpu)?;
let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
let output = ct.remove("output.weight")?;
let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
for layer_idx in 0..ct.hparams.n_layer {
let prefix = format!("layers.{layer_idx}");
let attention_wq = ct.remove(&format!("{prefix}.attention.wq.weight"))?;
let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?;
let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?;
let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?;
let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?;
let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?;
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
layers.push(LayerWeights {
attention_wq: QMatMul::from_qtensor(attention_wq),
attention_wk: QMatMul::from_qtensor(attention_wk),
attention_wv: QMatMul::from_qtensor(attention_wv),
attention_wo: QMatMul::from_qtensor(attention_wo),
attention_norm: RmsNorm::new(attention_norm, 1e-5)?,
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1),
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2),
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3),
ffn_norm: RmsNorm::new(ffn_norm, 1e-5)?,
n_head: ct.hparams.n_head as usize,
n_kv_head: ct.hparams.n_head as usize / gqa,
head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
cos: cos.clone(),
sin: sin.clone(),
kv_cache: None,
span_attn,
span_rot,
span_mlp,
})
}
let span = tracing::span!(tracing::Level::TRACE, "model");
let span_output = tracing::span!(tracing::Level::TRACE, "output");
Ok(Self {
tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize),
layers,
norm,
output: QMatMul::from_qtensor(output),
masks: HashMap::new(),
span,
span_output,
})
}
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
ct: gguf_file::Content,
reader: &mut R,
) -> Result<Self> {
let cpu = &Device::Cpu;
let md_get = |s: &str| match ct.metadata.get(s) {
None => candle::bail!("cannot find {s} in metadata"),
Some(v) => Ok(v),
};
// Parameter extraction from metadata.
let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
let block_count = md_get("llama.block_count")?.to_u32()? as usize;
let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize;
let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize;
// Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default.
let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()?;
let rope_freq_base = md_get("llama.rope.freq_base")
.and_then(|m| m.to_f32())
.unwrap_or(10000f32);
let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?;
let tok_embeddings = ct.tensor(reader, "token_embd.weight")?;
let tok_embeddings = tok_embeddings.dequantize(cpu)?;
let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?, rms_norm_eps)?;
let output = ct.tensor(reader, "output.weight")?;
let mut layers = Vec::with_capacity(block_count);
for layer_idx in 0..block_count {
let prefix = format!("blk.{layer_idx}");
let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"))?;
let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?;
let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?;
let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?;
let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?;
let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?;
let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?;
let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?;
let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?;
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
layers.push(LayerWeights {
attention_wq: QMatMul::from_qtensor(attention_wq),
attention_wk: QMatMul::from_qtensor(attention_wk),
attention_wv: QMatMul::from_qtensor(attention_wv),
attention_wo: QMatMul::from_qtensor(attention_wo),
attention_norm: RmsNorm::new(attention_norm, rms_norm_eps)?,
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1),
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2),
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3),
ffn_norm: RmsNorm::new(ffn_norm, rms_norm_eps)?,
n_head: head_count,
n_kv_head: head_count_kv,
head_dim: embedding_length / head_count,
cos: cos.clone(),
sin: sin.clone(),
kv_cache: None,
span_attn,
span_rot,
span_mlp,
})
}
let span = tracing::span!(tracing::Level::TRACE, "model");
let span_output = tracing::span!(tracing::Level::TRACE, "output");
Ok(Self {
tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
layers,
norm,
output: QMatMul::from_qtensor(output),
masks: HashMap::new(),
span,
span_output,
})
}
fn mask(&mut self, t: usize) -> Result<Tensor> {
if let Some(mask) = self.masks.get(&t) {
Ok(mask.clone())
} else {
let mask: Vec<_> = (0..t)
.flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
.collect();
let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
self.masks.insert(t, mask.clone());
Ok(mask)
}
}
pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (_b_sz, seq_len) = x.dims2()?;
let mask = self.mask(seq_len)?;
let _enter = self.span.enter();
let mut layer_in = self.tok_embeddings.forward(x)?;
for layer in self.layers.iter_mut() {
let x = layer_in;
let residual = &x;
let x = layer.attention_norm.forward(&x)?;
let attn = layer.forward_attn(&x, &mask, index_pos)?;
let x = (attn + residual)?;
// MLP
let _enter = layer.span_mlp.enter();
let residual = &x;
let x = layer.ffn_norm.forward(&x)?;
let w1 = layer.feed_forward_w1.forward(&x)?;
let w3 = layer.feed_forward_w3.forward(&x)?;
let mlp = layer
.feed_forward_w2
.forward(&(candle_nn::ops::silu(&w1)? * w3)?)?;
layer_in = (mlp + residual)?;
}
let x = self.norm.forward(&layer_in)?;
let x = x.i((.., seq_len - 1, ..))?;
let _enter = self.span_output.enter();
self.output.forward(&x)
}
}

View File

@ -1,273 +0,0 @@
//! SAM: Segment Anything Model
//! https://github.com/facebookresearch/segment-anything
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
pub mod model_image_encoder;
pub mod model_mask_decoder;
pub mod model_prompt_encoder;
pub mod model_sam;
pub mod model_transformer;
use candle::{DType, Result, Tensor};
use candle_nn::{Module, VarBuilder};
use clap::Parser;
pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
let inner = if bias {
candle_nn::linear(in_dim, out_dim, vb)?
} else {
candle_nn::linear_no_bias(in_dim, out_dim, vb)?
};
let span = tracing::span!(tracing::Level::TRACE, "linear");
Ok(Linear { inner, span })
}
#[derive(Debug)]
pub struct LayerNorm2d {
weight: Tensor,
bias: Tensor,
num_channels: usize,
eps: f64,
}
impl LayerNorm2d {
pub fn new(num_channels: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
let weight = vb.get(num_channels, "weight")?;
let bias = vb.get(num_channels, "bias")?;
Ok(Self {
weight,
bias,
num_channels,
eps,
})
}
}
impl Module for LayerNorm2d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let u = xs.mean_keepdim(1)?;
let xs = xs.broadcast_sub(&u)?;
let s = xs.sqr()?.mean_keepdim(1)?;
let xs = xs.broadcast_div(&(s + self.eps)?.sqrt()?)?;
xs.broadcast_mul(&self.weight.reshape((1, self.num_channels, 1, 1))?)?
.broadcast_add(&self.bias.reshape((1, self.num_channels, 1, 1))?)
}
}
#[derive(Debug)]
pub struct MlpBlock {
lin1: Linear,
lin2: Linear,
activation: candle_nn::Activation,
span: tracing::Span,
}
impl MlpBlock {
pub fn new(
embedding_dim: usize,
mlp_dim: usize,
activation: candle_nn::Activation,
vb: VarBuilder,
) -> Result<Self> {
let lin1 = linear(vb.pp("lin1"), embedding_dim, mlp_dim, true)?;
let lin2 = linear(vb.pp("lin2"), mlp_dim, embedding_dim, true)?;
let span = tracing::span!(tracing::Level::TRACE, "mlp-block");
Ok(Self {
lin1,
lin2,
activation,
span,
})
}
}
impl Module for MlpBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
xs.apply(&self.lin1)?
.apply(&self.activation)?
.apply(&self.lin2)
}
}
#[derive(Debug)]
pub struct Linear {
inner: candle_nn::Linear,
span: tracing::Span,
}
impl Module for Linear {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
}
}
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
#[arg(long)]
generate_masks: bool,
#[arg(long, default_value_t = 0.5)]
point_x: f64,
#[arg(long, default_value_t = 0.5)]
point_y: f64,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
}
pub fn main() -> anyhow::Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let device = candle_examples::device(args.cpu)?;
let (image, initial_h, initial_w) = if args.image.ends_with(".safetensors") {
let mut tensors = candle::safetensors::load(&args.image, &device)?;
let image = match tensors.remove("image") {
Some(image) => image,
None => {
if tensors.len() != 1 {
anyhow::bail!("multiple tensors in '{}'", args.image)
}
tensors.into_values().next().unwrap()
}
};
let image = if image.rank() == 4 {
image.get(0)?
} else {
image
};
let (_c, h, w) = image.dims3()?;
(image, h, w)
} else {
let (image, h, w) = candle_examples::load_image(&args.image, Some(model_sam::IMAGE_SIZE))?;
(image.to_device(&device)?, h, w)
};
println!("loaded image {image:?}");
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("lmz/candle-sam".to_string());
api.get("sam_vit_b_01ec64.safetensors")?
}
};
let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
let sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b
if args.generate_masks {
// Default options similar to the Python version.
let bboxes = sam.generate_masks(
&image,
/* points_per_side */ 32,
/* crop_n_layer */ 0,
/* crop_overlap_ratio */ 512. / 1500.,
/* crop_n_points_downscale_factor */ 1,
)?;
for (idx, bbox) in bboxes.iter().enumerate() {
println!("{bbox:?}");
let mask = (&bbox.data.to_dtype(DType::U8)? * 255.)?;
let (h, w) = mask.dims2()?;
let mask = mask.broadcast_as((3, h, w))?;
candle_examples::save_image_resize(
&mask,
format!("sam_mask{idx}.png"),
initial_h,
initial_w,
)?;
}
} else {
let point = Some((args.point_x, args.point_y));
let start_time = std::time::Instant::now();
let (mask, iou_predictions) = sam.forward(&image, point, false)?;
println!(
"mask generated in {:.2}s",
start_time.elapsed().as_secs_f32()
);
println!("mask:\n{mask}");
println!("iou_predictions: {iou_predictions:?}");
// Save the mask as an image.
let mask = (mask.ge(0f32)? * 255.)?;
let (_one, h, w) = mask.dims3()?;
let mask = mask.expand((3, h, w))?;
candle_examples::save_image_resize(&mask, "sam_mask.png", initial_h, initial_w)?;
if !args.image.ends_with(".safetensors") {
let mut img = image::io::Reader::open(&args.image)?
.decode()
.map_err(candle::Error::wrap)?;
let mask_pixels = mask.permute((1, 2, 0))?.flatten_all()?.to_vec1::<u8>()?;
let mask_img: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
match image::ImageBuffer::from_raw(w as u32, h as u32, mask_pixels) {
Some(image) => image,
None => anyhow::bail!("error saving merged image"),
};
let mask_img = image::DynamicImage::from(mask_img).resize_to_fill(
img.width(),
img.height(),
image::imageops::FilterType::CatmullRom,
);
for x in 0..img.width() {
for y in 0..img.height() {
let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_img, x, y);
if mask_p.0[0] > 100 {
let mut img_p = imageproc::drawing::Canvas::get_pixel(&img, x, y);
img_p.0[2] = 255 - (255 - img_p.0[2]) / 2;
img_p.0[1] /= 2;
img_p.0[0] /= 2;
imageproc::drawing::Canvas::draw_pixel(&mut img, x, y, img_p)
}
}
}
match point {
Some((x, y)) => {
let (x, y) = (
(x * img.width() as f64) as i32,
(y * img.height() as f64) as i32,
);
imageproc::drawing::draw_filled_circle(
&img,
(x, y),
3,
image::Rgba([255, 0, 0, 200]),
)
.save("sam_merged.jpg")?
}
None => img.save("sam_merged.jpg")?,
};
}
}
Ok(())
}

View File

@ -1,516 +0,0 @@
use candle::{DType, IndexOp, Result, Tensor};
use candle_nn::{layer_norm, LayerNorm, Module, VarBuilder};
#[derive(Debug)]
struct PatchEmbed {
proj: candle_nn::Conv2d,
span: tracing::Span,
}
impl PatchEmbed {
fn new(
in_chans: usize,
embed_dim: usize,
k_size: usize,
stride: usize,
padding: usize,
vb: VarBuilder,
) -> Result<Self> {
let cfg = candle_nn::Conv2dConfig {
stride,
padding,
..Default::default()
};
let proj = candle_nn::conv2d(in_chans, embed_dim, k_size, cfg, vb.pp("proj"))?;
let span = tracing::span!(tracing::Level::TRACE, "patch-embed");
Ok(Self { proj, span })
}
}
impl Module for PatchEmbed {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
xs.apply(&self.proj)?.permute((0, 2, 3, 1))
}
}
#[derive(Debug)]
struct Attention {
qkv: crate::Linear,
proj: crate::Linear,
num_heads: usize,
scale: f64,
rel_pos_hw: Option<(Tensor, Tensor)>,
span: tracing::Span,
span_rel_pos: tracing::Span,
span_softmax: tracing::Span,
}
// rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
struct EinSum1;
impl candle::CustomOp2 for EinSum1 {
fn name(&self) -> &'static str {
"einsum1"
}
fn cpu_fwd(
&self,
s1: &candle::CpuStorage,
l1: &candle::Layout,
s2: &candle::CpuStorage,
l2: &candle::Layout,
) -> Result<(candle::CpuStorage, candle::Shape)> {
use candle::cpu::kernels::VecOps;
let (b, h, w, c) = l1.shape().dims4()?;
let (h2, k, c2) = l2.shape().dims3()?;
if c != c2 || h != h2 {
candle::bail!("shape mismatch {l1:?} {l2:?}")
}
let s1 = s1.as_slice::<f32>()?;
let s1 = match l1.contiguous_offsets() {
None => candle::bail!("input1 has to be contiguous"),
Some((o1, o2)) => &s1[o1..o2],
};
let s2 = s2.as_slice::<f32>()?;
let s2 = match l2.contiguous_offsets() {
None => candle::bail!("input2 has to be contiguous"),
Some((o1, o2)) => &s2[o1..o2],
};
let mut dst = vec![0f32; b * h * w * k];
for b_idx in 0..b {
let lhs_idx = b_idx * h * w * c;
let dst_idx = b_idx * h * w * k;
for h_idx in 0..h {
let lhs_idx = lhs_idx + h_idx * w * c;
let rhs_idx = h_idx * k * c;
let dst_idx = dst_idx + h_idx * w * k;
for w_idx in 0..w {
let lhs_idx = lhs_idx + w_idx * c;
let dst_idx = dst_idx + w_idx * k;
let lhs = &s1[lhs_idx..lhs_idx + c];
for k_idx in 0..k {
let rhs_idx = rhs_idx + k_idx * c;
let rhs = &s2[rhs_idx..rhs_idx + c];
let mut d = 0f32;
unsafe { f32::vec_dot(lhs.as_ptr(), rhs.as_ptr(), &mut d, c) };
dst[dst_idx + k_idx] += d;
}
}
}
}
let storage = candle::WithDType::to_cpu_storage_owned(dst);
Ok((storage, (b, h, w, k).into()))
}
}
// rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
struct EinSum2;
impl candle::CustomOp2 for EinSum2 {
fn name(&self) -> &'static str {
"einsum2"
}
fn cpu_fwd(
&self,
s1: &candle::CpuStorage,
l1: &candle::Layout,
s2: &candle::CpuStorage,
l2: &candle::Layout,
) -> Result<(candle::CpuStorage, candle::Shape)> {
use candle::cpu::kernels::VecOps;
let (b, h, w, c) = l1.shape().dims4()?;
let (w2, k, c2) = l2.shape().dims3()?;
if c != c2 || w != w2 {
candle::bail!("shape mismatch {l1:?} {l2:?}")
}
let s1 = s1.as_slice::<f32>()?;
let s1 = match l1.contiguous_offsets() {
None => candle::bail!("input1 has to be contiguous"),
Some((o1, o2)) => &s1[o1..o2],
};
let s2 = s2.as_slice::<f32>()?;
let s2 = match l2.contiguous_offsets() {
None => candle::bail!("input2 has to be contiguous"),
Some((o1, o2)) => &s2[o1..o2],
};
let mut dst = vec![0f32; b * h * w * k];
for b_idx in 0..b {
let lhs_idx = b_idx * h * w * c;
let dst_idx = b_idx * h * w * k;
for h_idx in 0..h {
let lhs_idx = lhs_idx + h_idx * w * c;
let dst_idx = dst_idx + h_idx * w * k;
for w_idx in 0..w {
let lhs_idx = lhs_idx + w_idx * c;
let rhs_idx = w_idx * k * c;
let dst_idx = dst_idx + w_idx * k;
let lhs = &s1[lhs_idx..lhs_idx + c];
for k_idx in 0..k {
let rhs_idx = rhs_idx + k_idx * c;
let rhs = &s2[rhs_idx..rhs_idx + c];
let mut d = 0f32;
unsafe { f32::vec_dot(lhs.as_ptr(), rhs.as_ptr(), &mut d, c) };
dst[dst_idx + k_idx] += d;
}
}
}
}
let storage = candle::WithDType::to_cpu_storage_owned(dst);
Ok((storage, (b, h, w, k).into()))
}
}
impl Attention {
fn new(
dim: usize,
num_heads: usize,
qkv_bias: bool,
use_rel_pos: bool,
input_size: (usize, usize),
vb: VarBuilder,
) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "attention");
let span_rel_pos = tracing::span!(tracing::Level::TRACE, "attn-rel-pos");
let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");
let qkv = crate::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
let proj = crate::linear(vb.pp("proj"), dim, dim, true)?;
let head_dim = dim / num_heads;
let scale = 1. / (head_dim as f64).sqrt();
let rel_pos_hw = if use_rel_pos {
let h = vb.get((2 * input_size.0 - 1, head_dim), "rel_pos_h")?;
let w = vb.get((2 * input_size.1 - 1, head_dim), "rel_pos_w")?;
Some((h, w))
} else {
None
};
Ok(Self {
qkv,
proj,
num_heads,
scale,
rel_pos_hw,
span,
span_rel_pos,
span_softmax,
})
}
fn add_decomposed_rel_pos(
&self,
attn: Tensor,
q: &Tensor,
(q_h, q_w): (usize, usize),
(k_h, k_w): (usize, usize),
) -> Result<Tensor> {
match &self.rel_pos_hw {
Some((rel_pos_h, rel_pos_w)) => {
println!("{:?} {:?}", attn.layout(), q.layout());
let r_h = get_rel_pos(q_h, k_h, rel_pos_h)?;
let r_w = get_rel_pos(q_w, k_w, rel_pos_w)?;
let (b, _, dim) = q.dims3()?;
let r_q = q.reshape((b, q_h, q_w, dim))?;
// rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
let rel_h = r_q.apply_op2_no_bwd(&r_h, &EinSum1)?;
// rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
let rel_w = r_q.apply_op2_no_bwd(&r_w, &EinSum2)?;
(attn.reshape((b, q_h, q_w, k_h, k_w))?
+ rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
.reshape((b, q_h * q_w, k_h * k_w))
}
None => Ok(attn),
}
}
}
fn get_rel_pos(q_size: usize, k_size: usize, rel_pos: &Tensor) -> Result<Tensor> {
let max_rel_dist = 2 * usize::max(q_size, k_size) - 1;
let dev = rel_pos.device();
let rel_pos_resized = if rel_pos.dim(0)? != max_rel_dist {
todo!("interpolation")
} else {
rel_pos
};
let q_coords = Tensor::arange(0u32, q_size as u32, dev)?
.reshape((q_size, 1))?
.to_dtype(DType::F32)?;
let k_coords = Tensor::arange(0u32, k_size as u32, dev)?
.reshape((1, k_size))?
.to_dtype(DType::F32)?;
let q_coords = (q_coords * f64::max(1f64, k_size as f64 / q_size as f64))?;
let k_coords = (k_coords * f64::max(1f64, q_size as f64 / k_size as f64))?;
let relative_coords = (q_coords.broadcast_sub(&k_coords)?
+ (k_size as f64 - 1.) * f64::max(1f64, q_size as f64 / k_size as f64))?;
let (d1, d2) = relative_coords.dims2()?;
let relative_coords = relative_coords.to_dtype(DType::U32)?;
rel_pos_resized
.index_select(&relative_coords.reshape(d1 * d2)?, 0)?
.reshape((d1, d2, ()))
}
impl Module for Attention {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (b, h, w, c) = xs.dims4()?;
let qkv = self
.qkv
.forward(&xs.flatten_to(1)?)?
.reshape((b, h * w, 3, self.num_heads, c / self.num_heads))?
.permute((2, 0, 3, 1, 4))?
.reshape((3, b * self.num_heads, h * w, c / self.num_heads))?;
let q = qkv.i(0)?;
let k = qkv.i(1)?;
let v = qkv.i(2)?;
let attn = (&q * self.scale)?.matmul(&k.t()?)?;
let attn = {
let _enter = self.span_rel_pos.enter();
self.add_decomposed_rel_pos(attn, &q, (h, w), (h, w))?
};
let attn = {
let _enter = self.span_softmax.enter();
candle_nn::ops::softmax_last_dim(&attn)?
};
let attn = attn.matmul(&v)?;
let attn = attn
.reshape((b, self.num_heads, h, w, c / self.num_heads))?
.permute((0, 2, 3, 1, 4))?
.reshape((b, h * w, c))?;
self.proj.forward(&attn)?.reshape((b, h, w, c))
}
}
#[derive(Debug)]
struct Block {
norm1: LayerNorm,
attn: Attention,
norm2: LayerNorm,
mlp: crate::MlpBlock,
window_size: usize,
span: tracing::Span,
}
impl Block {
fn new(
dim: usize,
num_heads: usize,
qkv_bias: bool,
use_rel_pos: bool,
window_size: usize,
input_size: (usize, usize),
vb: VarBuilder,
) -> Result<Self> {
let norm1 = layer_norm(dim, 1e-6, vb.pp("norm1"))?;
let norm2 = layer_norm(dim, 1e-6, vb.pp("norm2"))?;
let input_size_attn = if window_size == 0 {
input_size
} else {
(window_size, window_size)
};
let attn = Attention::new(
dim,
num_heads,
qkv_bias,
use_rel_pos,
input_size_attn,
vb.pp("attn"),
)?;
let mlp = crate::MlpBlock::new(dim, dim * 4, candle_nn::Activation::Gelu, vb.pp("mlp"))?;
let span = tracing::span!(tracing::Level::TRACE, "ie-block");
Ok(Self {
norm1,
attn,
norm2,
mlp,
window_size,
span,
})
}
}
fn window_partition(xs: Tensor, window_size: usize) -> Result<(Tensor, (usize, usize))> {
let (b, h, w, c) = xs.dims4()?;
let pad_h = (window_size - h % window_size) % window_size;
let pad_w = (window_size - w % window_size) % window_size;
let xs = if pad_h > 0 {
xs.pad_with_zeros(1, 0, pad_h)?
} else {
xs
};
let xs = if pad_w > 0 {
xs.pad_with_zeros(2, 0, pad_w)?
} else {
xs
};
let (h_p, w_p) = (h + pad_h, w + pad_w);
let windows = xs
.reshape((
b,
h_p / window_size,
window_size,
w_p / window_size,
window_size,
c,
))?
.transpose(2, 3)?
.contiguous()?
.flatten_to(2)?;
Ok((windows, (h_p, w_p)))
}
fn window_unpartition(
windows: Tensor,
window_size: usize,
(h_p, w_p): (usize, usize),
(h, w): (usize, usize),
) -> Result<Tensor> {
let b = windows.dim(0)? / (h_p * w_p / window_size / window_size);
let xs = windows
.reshape((
b,
h_p / window_size,
w_p / window_size,
window_size,
window_size,
windows.elem_count() / b / h_p / w_p,
))?
.transpose(2, 3)?
.contiguous()?
.reshape((b, h_p, w_p, ()))?;
let xs = if h_p > h { xs.narrow(1, 0, h)? } else { xs };
let xs = if w_p > w { xs.narrow(2, 0, w)? } else { xs };
Ok(xs)
}
impl Module for Block {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let shortcut = xs;
let xs = self.norm1.forward(xs)?;
let hw = (xs.dim(1)?, xs.dim(2)?);
let (xs, pad_hw) = if self.window_size > 0 {
window_partition(xs, self.window_size)?
} else {
(xs, (0, 0))
};
let xs = self.attn.forward(&xs)?;
let xs = if self.window_size > 0 {
window_unpartition(xs, self.window_size, pad_hw, hw)?
} else {
xs
};
let xs = (xs + shortcut)?;
&xs + xs.apply(&self.norm2)?.apply(&self.mlp)?
}
}
#[derive(Debug)]
pub struct ImageEncoderViT {
patch_embed: PatchEmbed,
blocks: Vec<Block>,
neck_conv1: candle_nn::Conv2d,
neck_ln1: crate::LayerNorm2d,
neck_conv2: candle_nn::Conv2d,
neck_ln2: crate::LayerNorm2d,
pos_embed: Option<Tensor>,
span: tracing::Span,
}
impl ImageEncoderViT {
#[allow(clippy::too_many_arguments)]
pub fn new(
img_size: usize,
patch_size: usize,
in_chans: usize,
embed_dim: usize,
depth: usize,
num_heads: usize,
out_chans: usize,
qkv_bias: bool,
use_rel_pos: bool,
use_abs_pos: bool,
window_size: usize,
global_attn_indexes: &[usize],
vb: VarBuilder,
) -> Result<Self> {
let patch_embed = PatchEmbed::new(
in_chans,
embed_dim,
patch_size,
patch_size,
0,
vb.pp("patch_embed"),
)?;
let mut blocks = Vec::with_capacity(depth);
let vb_b = vb.pp("blocks");
for i in 0..depth {
let window_size = if global_attn_indexes.contains(&i) {
0
} else {
window_size
};
let block = Block::new(
embed_dim,
num_heads,
qkv_bias,
use_rel_pos,
window_size,
(img_size / patch_size, img_size / patch_size),
vb_b.pp(i),
)?;
blocks.push(block)
}
let neck_conv1 = candle_nn::conv2d_no_bias(
embed_dim,
out_chans,
1,
Default::default(),
vb.pp("neck.0"),
)?;
let neck_ln1 = crate::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.1"))?;
let cfg = candle_nn::Conv2dConfig {
padding: 1,
..Default::default()
};
let neck_conv2 = candle_nn::conv2d_no_bias(out_chans, out_chans, 3, cfg, vb.pp("neck.2"))?;
let neck_ln2 = crate::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.3"))?;
let pos_embed = if use_abs_pos {
let p = vb.get(
(1, img_size / patch_size, img_size / patch_size, embed_dim),
"pos_embed",
)?;
Some(p)
} else {
None
};
let span = tracing::span!(tracing::Level::TRACE, "image-encoder-vit");
Ok(Self {
patch_embed,
blocks,
neck_conv1,
neck_ln1,
neck_conv2,
neck_ln2,
pos_embed,
span,
})
}
}
impl Module for ImageEncoderViT {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let xs = self.patch_embed.forward(xs)?;
let mut xs = match &self.pos_embed {
Some(pos_embed) => (xs + pos_embed)?,
None => xs,
};
for block in self.blocks.iter() {
xs = block.forward(&xs)?
}
xs.permute((0, 3, 1, 2))?
.apply(&self.neck_conv1)?
.apply(&self.neck_ln1)?
.apply(&self.neck_conv2)?
.apply(&self.neck_ln2)
}
}

View File

@ -1,239 +0,0 @@
use candle::{IndexOp, Result, Tensor};
use candle_nn::{Module, VarBuilder};
use crate::model_transformer::TwoWayTransformer;
#[derive(Debug)]
struct MlpMaskDecoder {
layers: Vec<crate::Linear>,
sigmoid_output: bool,
span: tracing::Span,
}
impl MlpMaskDecoder {
fn new(
input_dim: usize,
hidden_dim: usize,
output_dim: usize,
num_layers: usize,
sigmoid_output: bool,
vb: VarBuilder,
) -> Result<Self> {
let mut layers = Vec::with_capacity(num_layers);
let vb = vb.pp("layers");
for i in 0..num_layers {
let in_dim = if i == 0 { input_dim } else { hidden_dim };
let out_dim = if i + 1 == num_layers {
output_dim
} else {
hidden_dim
};
let layer = crate::linear(vb.pp(i), in_dim, out_dim, true)?;
layers.push(layer)
}
let span = tracing::span!(tracing::Level::TRACE, "mlp-mask-decoder");
Ok(Self {
layers,
sigmoid_output,
span,
})
}
}
impl Module for MlpMaskDecoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let mut xs = xs.clone();
for (i, layer) in self.layers.iter().enumerate() {
xs = layer.forward(&xs)?;
if i + 1 < self.layers.len() {
xs = xs.relu()?
}
}
if self.sigmoid_output {
candle_nn::ops::sigmoid(&xs)
} else {
Ok(xs)
}
}
}
#[derive(Debug)]
pub struct MaskDecoder {
iou_token: candle_nn::Embedding,
mask_tokens: candle_nn::Embedding,
iou_prediction_head: MlpMaskDecoder,
output_upscaling_conv1: candle_nn::ConvTranspose2d,
output_upscaling_ln: crate::LayerNorm2d,
output_upscaling_conv2: candle_nn::ConvTranspose2d,
num_mask_tokens: usize,
output_hypernetworks_mlps: Vec<MlpMaskDecoder>,
transformer: TwoWayTransformer,
span: tracing::Span,
}
impl MaskDecoder {
pub fn new(
transformer_dim: usize,
num_multimask_outputs: usize,
iou_head_depth: usize,
iou_head_hidden_dim: usize,
vb: VarBuilder,
) -> Result<Self> {
let num_mask_tokens = num_multimask_outputs + 1;
let iou_prediction_head = MlpMaskDecoder::new(
transformer_dim,
iou_head_hidden_dim,
num_mask_tokens,
iou_head_depth,
false,
vb.pp("iou_prediction_head"),
)?;
let iou_token = candle_nn::embedding(1, transformer_dim, vb.pp("iou_token"))?;
let mask_tokens =
candle_nn::embedding(num_mask_tokens, transformer_dim, vb.pp("mask_tokens"))?;
let cfg = candle_nn::ConvTranspose2dConfig {
stride: 2,
..Default::default()
};
let output_upscaling_conv1 = candle_nn::conv_transpose2d(
transformer_dim,
transformer_dim / 4,
2,
cfg,
vb.pp("output_upscaling.0"),
)?;
let output_upscaling_ln =
crate::LayerNorm2d::new(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?;
let output_upscaling_conv2 = candle_nn::conv_transpose2d(
transformer_dim / 4,
transformer_dim / 8,
2,
cfg,
vb.pp("output_upscaling.3"),
)?;
let mut output_hypernetworks_mlps = Vec::with_capacity(num_mask_tokens);
let vb_o = vb.pp("output_hypernetworks_mlps");
for i in 0..num_mask_tokens {
let mlp = MlpMaskDecoder::new(
transformer_dim,
transformer_dim,
transformer_dim / 8,
3,
false,
vb_o.pp(i),
)?;
output_hypernetworks_mlps.push(mlp)
}
let transformer = TwoWayTransformer::new(
/* depth */ 2,
/* embedding_dim */ transformer_dim,
/* num_heads */ 8,
/* mlp_dim */ 2048,
vb.pp("transformer"),
)?;
let span = tracing::span!(tracing::Level::TRACE, "mask-decoder");
Ok(Self {
iou_token,
mask_tokens,
iou_prediction_head,
output_upscaling_conv1,
output_upscaling_ln,
output_upscaling_conv2,
num_mask_tokens,
output_hypernetworks_mlps,
transformer,
span,
})
}
pub fn forward(
&self,
image_embeddings: &Tensor,
image_pe: &Tensor,
sparse_prompt_embeddings: &Tensor,
dense_prompt_embeddings: &Tensor,
multimask_output: bool,
) -> Result<(Tensor, Tensor)> {
let _enter = self.span.enter();
let (masks, iou_pred) = self.predict_masks(
image_embeddings,
image_pe,
sparse_prompt_embeddings,
dense_prompt_embeddings,
)?;
let masks = if multimask_output {
masks.i((.., 1..))?
} else {
masks.i((.., 0..1))?
};
let iou_pred = if multimask_output {
iou_pred.i((.., 1..))?
} else {
iou_pred.i((.., 0..1))?
};
Ok((masks, iou_pred))
}
fn predict_masks(
&self,
image_embeddings: &Tensor,
image_pe: &Tensor,
sparse_prompt_embeddings: &Tensor,
dense_prompt_embeddings: &Tensor,
) -> Result<(Tensor, Tensor)> {
// Concatenate ouput tokens.
let output_tokens = Tensor::cat(
&[self.iou_token.embeddings(), self.mask_tokens.embeddings()],
0,
)?;
let (d1, d2) = output_tokens.dims2()?;
let output_tokens =
output_tokens
.unsqueeze(0)?
.expand((sparse_prompt_embeddings.dim(0)?, d1, d2))?;
let tokens = Tensor::cat(&[&output_tokens, sparse_prompt_embeddings], 1)?;
// Expand per-image data in batch direction to be per mask
let src = repeat_interleave(image_embeddings, tokens.dim(0)?, 0)?;
let src = src.broadcast_add(dense_prompt_embeddings)?;
let pos_src = repeat_interleave(image_pe, tokens.dim(0)?, 0)?;
let (b, c, h, w) = src.dims4()?;
// Run the transformer
let (hs, src) = self.transformer.forward(&src, &pos_src, &tokens)?;
let iou_token_out = hs.i((.., 0))?;
let mask_tokens_out = hs.i((.., 1..1 + self.num_mask_tokens))?;
// Upscale mask embeddings and predict masks using the masks tokens.
let src = src.transpose(1, 2)?.reshape((b, c, h, w))?;
let upscaled_embedding = self
.output_upscaling_conv1
.forward(&src)?
.apply(&self.output_upscaling_ln)?
.gelu()?
.apply(&self.output_upscaling_conv2)?
.gelu()?;
let mut hyper_in_list = Vec::with_capacity(self.num_mask_tokens);
for (i, mlp) in self.output_hypernetworks_mlps.iter().enumerate() {
let h = mlp.forward(&mask_tokens_out.i((.., i))?)?;
hyper_in_list.push(h)
}
let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?.contiguous()?;
let (b, c, h, w) = upscaled_embedding.dims4()?;
let masks = hyper_in.matmul(&upscaled_embedding.reshape((b, c, h * w))?)?;
let masks = masks.reshape((b, (), h, w))?;
// Generate mask quality predictions.
let iou_pred = self.iou_prediction_head.forward(&iou_token_out)?;
Ok((masks, iou_pred))
}
}
// Equivalent to torch.repeat_interleave
fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> {
let img = img.unsqueeze(dim + 1)?;
let mut dims = img.dims().to_vec();
dims[dim + 1] = repeats;
img.broadcast_as(dims)?.flatten(dim, dim + 1)
}

View File

@ -1,239 +0,0 @@
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::VarBuilder;
#[derive(Debug)]
struct PostionEmbeddingRandom {
positional_encoding_gaussian_matrix: Tensor,
}
impl PostionEmbeddingRandom {
fn new(num_pos_feats: usize, vb: VarBuilder) -> Result<Self> {
let positional_encoding_gaussian_matrix =
vb.get((2, num_pos_feats), "positional_encoding_gaussian_matrix")?;
Ok(Self {
positional_encoding_gaussian_matrix,
})
}
fn pe_encoding(&self, coords: &Tensor) -> Result<Tensor> {
let coords = coords.affine(2., -1.)?;
let coords = coords.broadcast_matmul(&self.positional_encoding_gaussian_matrix)?;
let coords = (coords * (2. * std::f64::consts::PI))?;
Tensor::cat(&[coords.sin()?, coords.cos()?], D::Minus1)
}
fn forward(&self, h: usize, w: usize) -> Result<Tensor> {
let device = self.positional_encoding_gaussian_matrix.device();
let x_embed = (Tensor::arange(0u32, w as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
let y_embed = (Tensor::arange(0u32, h as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
let x_embed = (x_embed / w as f64)?
.reshape((1, ()))?
.broadcast_as((h, w))?;
let y_embed = (y_embed / h as f64)?
.reshape(((), 1))?
.broadcast_as((h, w))?;
let coords = Tensor::stack(&[&x_embed, &y_embed], D::Minus1)?;
self.pe_encoding(&coords)?.permute((2, 0, 1))
}
fn forward_with_coords(
&self,
coords_input: &Tensor,
image_size: (usize, usize),
) -> Result<Tensor> {
let coords0 = (coords_input.narrow(D::Minus1, 0, 1)? / image_size.1 as f64)?;
let coords1 = (coords_input.narrow(D::Minus1, 1, 1)? / image_size.0 as f64)?;
let c = coords_input.dim(D::Minus1)?;
let coords_rest = coords_input.narrow(D::Minus1, 2, c - 2)?;
let coords = Tensor::cat(&[&coords0, &coords1, &coords_rest], D::Minus1)?;
self.pe_encoding(&coords)
}
}
#[derive(Debug)]
pub struct PromptEncoder {
pe_layer: PostionEmbeddingRandom,
point_embeddings: Vec<candle_nn::Embedding>,
not_a_point_embed: candle_nn::Embedding,
mask_downscaling_conv1: candle_nn::Conv2d,
mask_downscaling_ln1: crate::LayerNorm2d,
mask_downscaling_conv2: candle_nn::Conv2d,
mask_downscaling_ln2: crate::LayerNorm2d,
mask_downscaling_conv3: candle_nn::Conv2d,
no_mask_embed: candle_nn::Embedding,
image_embedding_size: (usize, usize),
input_image_size: (usize, usize),
embed_dim: usize,
span: tracing::Span,
}
impl PromptEncoder {
pub fn new(
embed_dim: usize,
image_embedding_size: (usize, usize),
input_image_size: (usize, usize),
mask_in_chans: usize,
vb: VarBuilder,
) -> Result<Self> {
let num_points_embeddings = 4;
let pe_layer = PostionEmbeddingRandom::new(embed_dim / 2, vb.pp("pe_layer"))?;
let not_a_point_embed = candle_nn::embedding(1, embed_dim, vb.pp("not_a_point_embed"))?;
let no_mask_embed = candle_nn::embedding(1, embed_dim, vb.pp("no_mask_embed"))?;
let cfg = candle_nn::Conv2dConfig {
stride: 2,
..Default::default()
};
let mask_downscaling_conv1 =
candle_nn::conv2d(1, mask_in_chans / 4, 2, cfg, vb.pp("mask_downscaling.0"))?;
let mask_downscaling_conv2 = candle_nn::conv2d(
mask_in_chans / 4,
mask_in_chans,
2,
cfg,
vb.pp("mask_downscaling.3"),
)?;
let mask_downscaling_conv3 = candle_nn::conv2d(
mask_in_chans,
embed_dim,
1,
Default::default(),
vb.pp("mask_downscaling.6"),
)?;
let mask_downscaling_ln1 =
crate::LayerNorm2d::new(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?;
let mask_downscaling_ln2 =
crate::LayerNorm2d::new(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?;
let mut point_embeddings = Vec::with_capacity(num_points_embeddings);
let vb_e = vb.pp("point_embeddings");
for i in 0..num_points_embeddings {
let emb = candle_nn::embedding(1, embed_dim, vb_e.pp(i))?;
point_embeddings.push(emb)
}
let span = tracing::span!(tracing::Level::TRACE, "prompt-encoder");
Ok(Self {
pe_layer,
point_embeddings,
not_a_point_embed,
mask_downscaling_conv1,
mask_downscaling_ln1,
mask_downscaling_conv2,
mask_downscaling_ln2,
mask_downscaling_conv3,
no_mask_embed,
image_embedding_size,
input_image_size,
embed_dim,
span,
})
}
pub fn get_dense_pe(&self) -> Result<Tensor> {
self.pe_layer
.forward(self.image_embedding_size.0, self.image_embedding_size.1)?
.unsqueeze(0)
}
fn embed_masks(&self, masks: &Tensor) -> Result<Tensor> {
masks
.apply(&self.mask_downscaling_conv1)?
.apply(&self.mask_downscaling_ln1)?
.gelu()?
.apply(&self.mask_downscaling_conv2)?
.apply(&self.mask_downscaling_ln2)?
.gelu()?
.apply(&self.mask_downscaling_conv3)
}
fn embed_points(&self, points: &Tensor, labels: &Tensor, pad: bool) -> Result<Tensor> {
let points = (points + 0.5)?;
let dev = points.device();
let (points, labels) = if pad {
let padding_point = Tensor::zeros((points.dim(0)?, 1, 2), DType::F32, dev)?;
let padding_label = (Tensor::ones((labels.dim(0)?, 1), DType::F32, dev)? * (-1f64))?;
let points = Tensor::cat(&[&points, &padding_point], 1)?;
let labels = Tensor::cat(&[labels, &padding_label], 1)?;
(points, labels)
} else {
(points, labels.clone())
};
let point_embedding = self
.pe_layer
.forward_with_coords(&points, self.input_image_size)?;
let labels = labels.unsqueeze(2)?.broadcast_as(point_embedding.shape())?;
let zeros = point_embedding.zeros_like()?;
let point_embedding = labels.lt(0f32)?.where_cond(
&self
.not_a_point_embed
.embeddings()
.broadcast_as(zeros.shape())?,
&point_embedding,
)?;
let labels0 = labels.eq(0f32)?.where_cond(
&self.point_embeddings[0]
.embeddings()
.broadcast_as(zeros.shape())?,
&zeros,
)?;
let point_embedding = (point_embedding + labels0)?;
let labels1 = labels.eq(1f32)?.where_cond(
&self.point_embeddings[1]
.embeddings()
.broadcast_as(zeros.shape())?,
&zeros,
)?;
let point_embedding = (point_embedding + labels1)?;
Ok(point_embedding)
}
fn embed_boxes(&self, boxes: &Tensor) -> Result<Tensor> {
let boxes = (boxes + 0.5)?;
let coords = boxes.reshape(((), 2, 2))?;
let corner_embedding = self
.pe_layer
.forward_with_coords(&coords, self.input_image_size)?;
let ce1 = corner_embedding.i((.., 0))?;
let ce2 = corner_embedding.i((.., 1))?;
let ce1 = (ce1 + self.point_embeddings[2].embeddings())?;
let ce2 = (ce2 + self.point_embeddings[3].embeddings())?;
Tensor::cat(&[&ce1, &ce2], 1)
}
pub fn forward(
&self,
points: Option<(&Tensor, &Tensor)>,
boxes: Option<&Tensor>,
masks: Option<&Tensor>,
) -> Result<(Tensor, Tensor)> {
let _enter = self.span.enter();
let se_points = match points {
Some((coords, labels)) => Some(self.embed_points(coords, labels, boxes.is_none())?),
None => None,
};
let se_boxes = match boxes {
Some(boxes) => Some(self.embed_boxes(boxes)?),
None => None,
};
let sparse_embeddings = match (se_points, se_boxes) {
(Some(se_points), Some(se_boxes)) => Tensor::cat(&[se_points, se_boxes], 1)?,
(Some(se_points), None) => se_points,
(None, Some(se_boxes)) => se_boxes,
(None, None) => {
Tensor::zeros((1, 0, self.embed_dim), DType::F32, &candle::Device::Cpu)?
}
};
let dense_embeddings = match masks {
None => {
let emb = self.no_mask_embed.embeddings();
emb.reshape((1, (), 1, 1))?.expand((
1,
emb.elem_count(),
self.image_embedding_size.0,
self.image_embedding_size.1,
))?
}
Some(masks) => self.embed_masks(masks)?,
};
Ok((sparse_embeddings, dense_embeddings))
}
}

View File

@ -1,364 +0,0 @@
use candle::{DType, IndexOp, Result, Tensor};
use candle_nn::{Module, VarBuilder};
use crate::model_image_encoder::ImageEncoderViT;
use crate::model_mask_decoder::MaskDecoder;
use crate::model_prompt_encoder::PromptEncoder;
const PROMPT_EMBED_DIM: usize = 256;
pub const IMAGE_SIZE: usize = 1024;
const VIT_PATCH_SIZE: usize = 16;
const PRED_IOU_THRESH: f32 = 0.88;
const STABILITY_SCORE_OFFSET: f32 = 1.0;
const STABILITY_SCORE_THRESHOLD: f32 = 0.95;
const MODEL_MASK_THRESHOLD: f32 = 0.0;
const CROP_NMS_THRESH: f32 = 0.7;
#[derive(Debug)]
pub struct Sam {
image_encoder: ImageEncoderViT,
prompt_encoder: PromptEncoder,
mask_decoder: MaskDecoder,
pixel_mean: Tensor,
pixel_std: Tensor,
}
impl Sam {
pub fn new(
encoder_embed_dim: usize,
encoder_depth: usize,
encoder_num_heads: usize,
encoder_global_attn_indexes: &[usize],
vb: VarBuilder,
) -> Result<Self> {
let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE;
let image_encoder = ImageEncoderViT::new(
IMAGE_SIZE,
VIT_PATCH_SIZE,
3,
encoder_embed_dim,
encoder_depth,
encoder_num_heads,
PROMPT_EMBED_DIM,
/* qkv_bias */ true,
/* use_rel_pos */ true,
/* use_abs_pos */ true,
/* window_size */ 14,
/* global_attn_indexes */ encoder_global_attn_indexes,
vb.pp("image_encoder"),
)?;
let prompt_encoder = PromptEncoder::new(
PROMPT_EMBED_DIM,
(image_embedding_size, image_embedding_size),
(IMAGE_SIZE, IMAGE_SIZE),
16,
vb.pp("prompt_encoder"),
)?;
let mask_decoder = MaskDecoder::new(
PROMPT_EMBED_DIM,
/* num_multitask_outputs */ 3,
/* iou_head_depth */ 3,
/* iou_head_hidden_dim */ 256,
vb.pp("mask_decoder"),
)?;
let pixel_mean =
Tensor::new(&[123.675f32, 116.28, 103.53], vb.device())?.reshape((3, 1, 1))?;
let pixel_std =
Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
Ok(Self {
image_encoder,
prompt_encoder,
mask_decoder,
pixel_std,
pixel_mean,
})
}
pub fn forward(
&self,
img: &Tensor,
point: Option<(f64, f64)>,
multimask_output: bool,
) -> Result<(Tensor, Tensor)> {
let (_c, original_h, original_w) = img.dims3()?;
let img = self.preprocess(img)?.unsqueeze(0)?;
let img_embeddings = self.image_encoder.forward(&img)?;
let image_pe = self.prompt_encoder.get_dense_pe()?;
let points = match point {
None => None,
Some((x, y)) => {
let points = Tensor::new(
&[[[x as f32 * original_w as f32, y as f32 * original_h as f32]]],
img.device(),
)?;
let labels = Tensor::ones((1, 1), DType::F32, img.device())?;
Some((points, labels))
}
};
let points = points.as_ref().map(|(x, y)| (x, y));
let (sparse_prompt_embeddings, dense_prompt_embeddings) =
self.prompt_encoder.forward(points, None, None)?;
let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
&img_embeddings,
&image_pe,
&sparse_prompt_embeddings,
&dense_prompt_embeddings,
multimask_output,
)?;
let mask = low_res_mask
.upsample_nearest2d(IMAGE_SIZE, IMAGE_SIZE)?
.get(0)?
.i((.., ..original_h, ..original_w))?;
Ok((mask, iou_predictions))
}
pub fn unpreprocess(&self, img: &Tensor) -> Result<Tensor> {
let img = img
.broadcast_mul(&self.pixel_std)?
.broadcast_add(&self.pixel_mean)?;
img.maximum(&img.zeros_like()?)?
.minimum(&(img.ones_like()? * 255.)?)
}
pub fn preprocess(&self, img: &Tensor) -> Result<Tensor> {
let (_c, h, w) = img.dims3()?;
let img = img
.to_dtype(DType::F32)?
.broadcast_sub(&self.pixel_mean)?
.broadcast_div(&self.pixel_std)?;
if h > IMAGE_SIZE || w > IMAGE_SIZE {
candle::bail!("image is too large ({w}, {h}), maximum size {IMAGE_SIZE}")
}
let img = img.pad_with_zeros(1, 0, IMAGE_SIZE - h)?;
img.pad_with_zeros(2, 0, IMAGE_SIZE - w)
}
fn process_crop(
&self,
img: &Tensor,
cb: CropBox,
point_grids: &[(f64, f64)],
) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> {
// Crop the image and calculate embeddings.
let img = img.i((.., cb.y0..cb.y1, cb.x0..cb.x1))?;
let img = self.preprocess(&img)?.unsqueeze(0)?;
let img_embeddings = self.image_encoder.forward(&img)?;
let crop_w = cb.x1 - cb.x0;
let crop_h = cb.y1 - cb.y0;
// Generate masks for this crop.
let image_pe = self.prompt_encoder.get_dense_pe()?;
let points = point_grids
.iter()
.map(|&(x, y)| vec![x as f32 * crop_w as f32, y as f32 * crop_h as f32])
.collect::<Vec<_>>();
let mut bboxes = Vec::new();
for points in points.chunks(64) {
// Run the model on this batch.
let points_len = points.len();
let in_points = Tensor::new(points.to_vec(), img.device())?.unsqueeze(1)?;
let in_labels = Tensor::ones((points_len, 1), DType::F32, img.device())?;
let (sparse_prompt_embeddings, dense_prompt_embeddings) =
self.prompt_encoder
.forward(Some((&in_points, &in_labels)), None, None)?;
let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
&img_embeddings,
&image_pe,
&sparse_prompt_embeddings,
&dense_prompt_embeddings,
/* multimask_output */ true,
)?;
let low_res_mask = low_res_mask.flatten(0, 1)?;
let iou_predictions = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?;
let dev = low_res_mask.device();
for (i, iou) in iou_predictions.iter().enumerate() {
// Filter by predicted IoU.
if *iou < PRED_IOU_THRESH {
continue;
}
let low_res_mask = low_res_mask.get(i)?;
// Calculate stability score.
let bound = Tensor::new(MODEL_MASK_THRESHOLD + STABILITY_SCORE_OFFSET, dev)?
.broadcast_as(low_res_mask.shape())?;
let intersections = low_res_mask
.ge(&bound)?
.to_dtype(DType::F32)?
.sum_all()?
.to_vec0::<f32>()?;
let bound = Tensor::new(MODEL_MASK_THRESHOLD - STABILITY_SCORE_OFFSET, dev)?
.broadcast_as(low_res_mask.shape())?;
let unions = low_res_mask
.ge(&bound)?
.to_dtype(DType::F32)?
.sum_all()?
.to_vec0::<f32>()?;
let stability_score = intersections / unions;
if stability_score < STABILITY_SCORE_THRESHOLD {
continue;
}
// Threshold masks and calculate boxes.
let low_res_mask = low_res_mask
.ge(&Tensor::new(0f32, dev)?.broadcast_as(low_res_mask.shape())?)?
.to_dtype(DType::U32)?;
let low_res_mask_per_x = low_res_mask.sum(0)?.to_vec1::<u32>()?;
let low_res_mask_per_y = low_res_mask.sum(1)?.to_vec1::<u32>()?;
let min_max_x = min_max_indexes(&low_res_mask_per_x);
let min_max_y = min_max_indexes(&low_res_mask_per_y);
if let Some(((x0, x1), (y0, y1))) = min_max_x.zip(min_max_y) {
let bbox = candle_examples::object_detection::Bbox {
xmin: x0 as f32,
ymin: y0 as f32,
xmax: x1 as f32,
ymax: y1 as f32,
confidence: *iou,
data: low_res_mask,
};
bboxes.push(bbox);
}
// TODO:
// Filter boxes that touch crop boundaries
// Compress to RLE.
}
}
let mut bboxes = vec![bboxes];
// Remove duplicates within this crop.
candle_examples::object_detection::non_maximum_suppression(&mut bboxes, CROP_NMS_THRESH);
// TODO: Return to the original image frame.
Ok(bboxes.remove(0))
}
pub fn generate_masks(
&self,
img: &Tensor,
points_per_side: usize,
crop_n_layer: usize,
crop_overlap_ratio: f64,
crop_n_points_downscale_factor: usize,
) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> {
let (_c, h, w) = img.dims3()?;
let point_grids = build_all_layer_point_grids(
points_per_side,
crop_n_layer,
crop_n_points_downscale_factor,
);
let crop_boxes = generate_crop_boxes((h, w), crop_n_layer, crop_overlap_ratio);
let mut bboxes = Vec::new();
for crop_box in crop_boxes.into_iter() {
let layer_idx = crop_box.layer_idx;
let b = self.process_crop(img, crop_box, &point_grids[layer_idx])?;
bboxes.extend(b)
}
// TODO: remove duplicates
Ok(bboxes)
}
}
// Return the first and last indexes i for which values[i] > 0
fn min_max_indexes(values: &[u32]) -> Option<(usize, usize)> {
let (mut min_i, mut max_i) = (usize::MAX, usize::MIN);
for (i, &s) in values.iter().enumerate() {
if s == 0 {
continue;
}
min_i = usize::min(i, min_i);
max_i = usize::max(i, max_i);
}
if max_i < min_i {
None
} else {
Some((min_i, max_i))
}
}
#[derive(Debug)]
struct CropBox {
x0: usize,
y0: usize,
x1: usize,
y1: usize,
layer_idx: usize,
}
impl CropBox {
fn new(x0: usize, y0: usize, x1: usize, y1: usize, layer_idx: usize) -> Self {
Self {
x0,
y0,
x1,
y1,
layer_idx,
}
}
}
fn generate_crop_boxes(
(im_h, im_w): (usize, usize),
n_layers: usize,
overlap_ratio: f64,
) -> Vec<CropBox> {
fn crop_len(orig_len: usize, n_crops: usize, overlap: usize) -> usize {
f64::ceil((overlap * (n_crops - 1) + orig_len) as f64 / n_crops as f64) as usize
}
let short_side = usize::min(im_h, im_w);
let mut crop_boxes = Vec::new();
// Original image.
crop_boxes.push(CropBox::new(0, 0, im_w, im_h, 0));
for layer_idx in 1..=n_layers {
let n_crops_per_side = 1 << layer_idx;
let overlap = (overlap_ratio * short_side as f64 * 2. / n_crops_per_side as f64) as usize;
let crop_w = crop_len(im_w, n_crops_per_side, overlap);
let crop_h = crop_len(im_w, n_crops_per_side, overlap);
for i_x in 0..n_crops_per_side {
let x0 = (crop_w - overlap) * i_x;
for i_y in 0..n_crops_per_side {
let y0 = (crop_h - overlap) * i_y;
let x1 = usize::min(im_w, x0 + crop_w);
let y1 = usize::min(im_h, y0 + crop_h);
crop_boxes.push(CropBox::new(x0, y0, x1, y1, layer_idx));
}
}
}
crop_boxes
}
// Generates a 2D grid of points evenly spaced in [0,1]x[0,1].
fn build_point_grid(n_per_side: usize) -> Vec<(f64, f64)> {
let offset = 1f64 / (2 * n_per_side) as f64;
let mut points = Vec::with_capacity(n_per_side * n_per_side);
for i_x in 0..n_per_side {
let x = offset + i_x as f64 / n_per_side as f64;
for i_y in 0..n_per_side {
let y = offset + i_y as f64 / n_per_side as f64;
points.push((x, y))
}
}
points
}
fn build_all_layer_point_grids(
n_per_side: usize,
n_layers: usize,
scale_per_layer: usize,
) -> Vec<Vec<(f64, f64)>> {
let mut points_by_layer = Vec::with_capacity(n_layers + 1);
for i in 0..=n_layers {
let n_points = n_per_side / scale_per_layer.pow(i as u32);
points_by_layer.push(build_point_grid(n_points))
}
points_by_layer
}

View File

@ -1,221 +0,0 @@
use candle::{Result, Tensor};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
#[derive(Debug)]
struct Attention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
out_proj: Linear,
num_heads: usize,
}
impl Attention {
fn new(
embedding_dim: usize,
num_heads: usize,
downsample_rate: usize,
vb: VarBuilder,
) -> Result<Self> {
let internal_dim = embedding_dim / downsample_rate;
let q_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("q_proj"))?;
let k_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("k_proj"))?;
let v_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("v_proj"))?;
let out_proj = candle_nn::linear(internal_dim, embedding_dim, vb.pp("out_proj"))?;
Ok(Self {
q_proj,
k_proj,
v_proj,
out_proj,
num_heads,
})
}
fn separate_heads(&self, x: &Tensor) -> Result<Tensor> {
let (b, n, c) = x.dims3()?;
x.reshape((b, n, self.num_heads, c / self.num_heads))?
.transpose(1, 2)?
.contiguous()
}
fn recombine_heads(&self, x: &Tensor) -> Result<Tensor> {
let (b, n_heads, n_tokens, c_per_head) = x.dims4()?;
x.transpose(1, 2)?
.reshape((b, n_tokens, n_heads * c_per_head))
}
fn forward(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
let q = self.q_proj.forward(&q.contiguous()?)?;
let k = self.k_proj.forward(&k.contiguous()?)?;
let v = self.v_proj.forward(&v.contiguous()?)?;
let q = self.separate_heads(&q)?;
let k = self.separate_heads(&k)?;
let v = self.separate_heads(&v)?;
let (_, _, _, c_per_head) = q.dims4()?;
let attn = (q.matmul(&k.t()?)? / (c_per_head as f64).sqrt())?;
let attn = candle_nn::ops::softmax_last_dim(&attn)?;
let out = attn.matmul(&v)?;
self.recombine_heads(&out)?.apply(&self.out_proj)
}
}
#[derive(Debug)]
struct TwoWayAttentionBlock {
self_attn: Attention,
norm1: LayerNorm,
cross_attn_token_to_image: Attention,
norm2: LayerNorm,
mlp: crate::MlpBlock,
norm3: LayerNorm,
norm4: LayerNorm,
cross_attn_image_to_token: Attention,
skip_first_layer_pe: bool,
}
impl TwoWayAttentionBlock {
fn new(
embedding_dim: usize,
num_heads: usize,
mlp_dim: usize,
skip_first_layer_pe: bool,
vb: VarBuilder,
) -> Result<Self> {
let norm1 = layer_norm(embedding_dim, 1e-5, vb.pp("norm1"))?;
let norm2 = layer_norm(embedding_dim, 1e-5, vb.pp("norm2"))?;
let norm3 = layer_norm(embedding_dim, 1e-5, vb.pp("norm3"))?;
let norm4 = layer_norm(embedding_dim, 1e-5, vb.pp("norm4"))?;
let self_attn = Attention::new(embedding_dim, num_heads, 1, vb.pp("self_attn"))?;
let cross_attn_token_to_image = Attention::new(
embedding_dim,
num_heads,
2,
vb.pp("cross_attn_token_to_image"),
)?;
let cross_attn_image_to_token = Attention::new(
embedding_dim,
num_heads,
2,
vb.pp("cross_attn_image_to_token"),
)?;
let mlp = crate::MlpBlock::new(
embedding_dim,
mlp_dim,
candle_nn::Activation::Relu,
vb.pp("mlp"),
)?;
Ok(Self {
self_attn,
norm1,
cross_attn_image_to_token,
norm2,
mlp,
norm3,
norm4,
cross_attn_token_to_image,
skip_first_layer_pe,
})
}
fn forward(
&self,
queries: &Tensor,
keys: &Tensor,
query_pe: &Tensor,
key_pe: &Tensor,
) -> Result<(Tensor, Tensor)> {
// Self attention block
let queries = if self.skip_first_layer_pe {
self.self_attn.forward(queries, queries, queries)?
} else {
let q = (queries + query_pe)?;
let attn_out = self.self_attn.forward(&q, &q, queries)?;
(queries + attn_out)?
};
let queries = self.norm1.forward(&queries)?;
// Cross attention block, tokens attending to image embedding
let q = (&queries + query_pe)?;
let k = (keys + key_pe)?;
let attn_out = self.cross_attn_token_to_image.forward(&q, &k, keys)?;
let queries = (&queries + attn_out)?;
let queries = self.norm2.forward(&queries)?;
// MLP block
let mlp_out = self.mlp.forward(&queries);
let queries = (queries + mlp_out)?;
let queries = self.norm3.forward(&queries)?;
// Cross attention block, image embedding attending to tokens
let q = (&queries + query_pe)?;
let k = (keys + key_pe)?;
let attn_out = self.cross_attn_image_to_token.forward(&k, &q, &queries)?;
let keys = (keys + attn_out)?;
let keys = self.norm4.forward(&keys)?;
Ok((queries, keys))
}
}
#[derive(Debug)]
pub struct TwoWayTransformer {
layers: Vec<TwoWayAttentionBlock>,
final_attn_token_to_image: Attention,
norm_final_attn: LayerNorm,
}
impl TwoWayTransformer {
pub fn new(
depth: usize,
embedding_dim: usize,
num_heads: usize,
mlp_dim: usize,
vb: VarBuilder,
) -> Result<Self> {
let vb_l = vb.pp("layers");
let mut layers = Vec::with_capacity(depth);
for i in 0..depth {
let layer =
TwoWayAttentionBlock::new(embedding_dim, num_heads, mlp_dim, i == 0, vb_l.pp(i))?;
layers.push(layer)
}
let final_attn_token_to_image = Attention::new(
embedding_dim,
num_heads,
2,
vb.pp("final_attn_token_to_image"),
)?;
let norm_final_attn = layer_norm(embedding_dim, 1e-5, vb.pp("norm_final_attn"))?;
Ok(Self {
layers,
final_attn_token_to_image,
norm_final_attn,
})
}
pub fn forward(
&self,
image_embedding: &Tensor,
image_pe: &Tensor,
point_embedding: &Tensor,
) -> Result<(Tensor, Tensor)> {
let image_embedding = image_embedding.flatten_from(2)?.permute((0, 2, 1))?;
let image_pe = image_pe.flatten_from(2)?.permute((0, 2, 1))?;
let mut queries = point_embedding.clone();
let mut keys = image_embedding;
for layer in self.layers.iter() {
(queries, keys) = layer.forward(&queries, &keys, point_embedding, &image_pe)?
}
let q = (&queries + point_embedding)?;
let k = (&keys + image_pe)?;
let attn_out = self.final_attn_token_to_image.forward(&q, &k, &keys)?;
let queries = (queries + attn_out)?.apply(&self.norm_final_attn)?;
Ok((queries, keys))
}
}

View File

@ -1,7 +1,7 @@
#![allow(dead_code)]
//! Attention Based Building Blocks
use candle::{DType, IndexOp, Result, Tensor, D};
use candle::{IndexOp, Result, Tensor, D};
use candle_nn as nn;
use candle_nn::Module;
#[derive(Debug)]
struct GeGlu {
@ -61,22 +61,6 @@ impl FeedForward {
}
}
#[cfg(feature = "flash-attn")]
fn flash_attn(
q: &Tensor,
k: &Tensor,
v: &Tensor,
softmax_scale: f32,
causal: bool,
) -> Result<Tensor> {
candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
}
#[cfg(not(feature = "flash-attn"))]
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
unimplemented!("compile with '--features flash-attn'")
}
#[derive(Debug)]
struct CrossAttention {
to_q: nn::Linear,
@ -88,8 +72,6 @@ struct CrossAttention {
slice_size: Option<usize>,
span: tracing::Span,
span_attn: tracing::Span,
span_softmax: tracing::Span,
use_flash_attn: bool,
}
impl CrossAttention {
@ -101,7 +83,6 @@ impl CrossAttention {
heads: usize,
dim_head: usize,
slice_size: Option<usize>,
use_flash_attn: bool,
) -> Result<Self> {
let inner_dim = dim_head * heads;
let context_dim = context_dim.unwrap_or(query_dim);
@ -112,7 +93,6 @@ impl CrossAttention {
let to_out = nn::linear(inner_dim, query_dim, vs.pp("to_out.0"))?;
let span = tracing::span!(tracing::Level::TRACE, "xa");
let span_attn = tracing::span!(tracing::Level::TRACE, "xa-attn");
let span_softmax = tracing::span!(tracing::Level::TRACE, "xa-softmax");
Ok(Self {
to_q,
to_k,
@ -123,8 +103,6 @@ impl CrossAttention {
slice_size,
span,
span_attn,
span_softmax,
use_flash_attn,
})
}
@ -151,10 +129,6 @@ impl CrossAttention {
) -> Result<Tensor> {
let batch_size_attention = query.dim(0)?;
let mut hidden_states = Vec::with_capacity(batch_size_attention / slice_size);
let in_dtype = query.dtype();
let query = query.to_dtype(DType::F32)?;
let key = key.to_dtype(DType::F32)?;
let value = value.to_dtype(DType::F32)?;
for i in 0..batch_size_attention / slice_size {
let start_idx = i * slice_size;
@ -166,65 +140,35 @@ impl CrossAttention {
let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(&value.i(start_idx..end_idx)?)?;
hidden_states.push(xs)
}
let hidden_states = Tensor::stack(&hidden_states, 0)?.to_dtype(in_dtype)?;
let hidden_states = Tensor::stack(&hidden_states, 0)?;
self.reshape_batch_dim_to_heads(&hidden_states)
}
fn attention(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result<Tensor> {
let _enter = self.span_attn.enter();
let xs = if self.use_flash_attn {
let init_dtype = query.dtype();
let q = query
.to_dtype(candle::DType::F16)?
.unsqueeze(0)?
.transpose(1, 2)?;
let k = key
.to_dtype(candle::DType::F16)?
.unsqueeze(0)?
.transpose(1, 2)?;
let v = value
.to_dtype(candle::DType::F16)?
.unsqueeze(0)?
.transpose(1, 2)?;
flash_attn(&q, &k, &v, self.scale as f32, false)?
.transpose(1, 2)?
.squeeze(0)?
.to_dtype(init_dtype)?
} else {
let in_dtype = query.dtype();
let query = query.to_dtype(DType::F32)?;
let key = key.to_dtype(DType::F32)?;
let value = value.to_dtype(DType::F32)?;
let xs = query.matmul(&(key.t()? * self.scale)?)?;
let xs = {
let _enter = self.span_softmax.enter();
nn::ops::softmax_last_dim(&xs)?
};
xs.matmul(&value)?.to_dtype(in_dtype)?
};
let xs = query.matmul(&(key.transpose(D::Minus1, D::Minus2)? * self.scale)?)?;
let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(value)?;
self.reshape_batch_dim_to_heads(&xs)
}
fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
let _enter = self.span.enter();
let query = self.to_q.forward(xs)?;
let context = context.unwrap_or(xs).contiguous()?;
let key = self.to_k.forward(&context)?;
let value = self.to_v.forward(&context)?;
let context = context.unwrap_or(xs);
let key = self.to_k.forward(context)?;
let value = self.to_v.forward(context)?;
let query = self.reshape_heads_to_batch_dim(&query)?;
let key = self.reshape_heads_to_batch_dim(&key)?;
let value = self.reshape_heads_to_batch_dim(&value)?;
let dim0 = query.dim(0)?;
let slice_size = self.slice_size.and_then(|slice_size| {
if dim0 < slice_size {
None
} else {
Some(slice_size)
}
});
let xs = match slice_size {
let xs = match self.slice_size {
None => self.attention(&query, &key, &value)?,
Some(slice_size) => self.sliced_attention(&query, &key, &value, slice_size)?,
Some(slice_size) => {
if query.dim(0)? / slice_size <= 1 {
self.attention(&query, &key, &value)?
} else {
self.sliced_attention(&query, &key, &value, slice_size)?
}
}
};
self.to_out.forward(&xs)
}
@ -250,7 +194,6 @@ impl BasicTransformerBlock {
d_head: usize,
context_dim: Option<usize>,
sliced_attention_size: Option<usize>,
use_flash_attn: bool,
) -> Result<Self> {
let attn1 = CrossAttention::new(
vs.pp("attn1"),
@ -259,7 +202,6 @@ impl BasicTransformerBlock {
n_heads,
d_head,
sliced_attention_size,
use_flash_attn,
)?;
let ff = FeedForward::new(vs.pp("ff"), dim, None, 4)?;
let attn2 = CrossAttention::new(
@ -269,7 +211,6 @@ impl BasicTransformerBlock {
n_heads,
d_head,
sliced_attention_size,
use_flash_attn,
)?;
let norm1 = nn::layer_norm(dim, 1e-5, vs.pp("norm1"))?;
let norm2 = nn::layer_norm(dim, 1e-5, vs.pp("norm2"))?;
@ -338,7 +279,6 @@ impl SpatialTransformer {
in_channels: usize,
n_heads: usize,
d_head: usize,
use_flash_attn: bool,
config: SpatialTransformerConfig,
) -> Result<Self> {
let inner_dim = n_heads * d_head;
@ -364,7 +304,6 @@ impl SpatialTransformer {
d_head,
config.context_dim,
config.sliced_attention_size,
use_flash_attn,
)?;
transformer_blocks.push(tb)
}
@ -473,15 +412,10 @@ impl AttentionBlock {
let num_heads = channels / num_head_channels;
let group_norm =
nn::group_norm(config.num_groups, channels, config.eps, vs.pp("group_norm"))?;
let (q_path, k_path, v_path, out_path) = if vs.contains_tensor("to_q.weight") {
("to_q", "to_k", "to_v", "to_out.0")
} else {
("query", "key", "value", "proj_attn")
};
let query = nn::linear(channels, channels, vs.pp(q_path))?;
let key = nn::linear(channels, channels, vs.pp(k_path))?;
let value = nn::linear(channels, channels, vs.pp(v_path))?;
let proj_attn = nn::linear(channels, channels, vs.pp(out_path))?;
let query = nn::linear(channels, channels, vs.pp("query"))?;
let key = nn::linear(channels, channels, vs.pp("key"))?;
let value = nn::linear(channels, channels, vs.pp("value"))?;
let proj_attn = nn::linear(channels, channels, vs.pp("proj_attn"))?;
let span = tracing::span!(tracing::Level::TRACE, "attn-block");
Ok(Self {
group_norm,
@ -504,7 +438,6 @@ impl AttentionBlock {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let in_dtype = xs.dtype();
let residual = xs;
let (batch, channel, height, width) = xs.dims4()?;
let xs = self
@ -517,13 +450,9 @@ impl AttentionBlock {
let key_proj = self.key.forward(&xs)?;
let value_proj = self.value.forward(&xs)?;
let query_states = self
.transpose_for_scores(query_proj)?
.to_dtype(DType::F32)?;
let key_states = self.transpose_for_scores(key_proj)?.to_dtype(DType::F32)?;
let value_states = self
.transpose_for_scores(value_proj)?
.to_dtype(DType::F32)?;
let query_states = self.transpose_for_scores(query_proj)?;
let key_states = self.transpose_for_scores(key_proj)?;
let value_states = self.transpose_for_scores(value_proj)?;
let scale = f64::powf((self.channels as f64) / (self.num_heads as f64), -0.25);
let attention_scores =
@ -532,7 +461,6 @@ impl AttentionBlock {
let attention_probs = nn::ops::softmax(&attention_scores, D::Minus1)?;
let xs = attention_probs.matmul(&value_states.contiguous()?)?;
let xs = xs.to_dtype(in_dtype)?;
let xs = xs.transpose(1, 2)?.contiguous()?;
let xs = xs.flatten_from(D::Minus2)?;
let xs = self

Some files were not shown because too many files have changed in this diff Show More