mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Compare commits
1 Commits
einsum-cus
...
initialize
Author | SHA1 | Date | |
---|---|---|---|
0bb344f798 |
@ -1,8 +1,8 @@
|
||||
[build]
|
||||
[target.x86_64-unknown-linux-gnu]
|
||||
rustflags = ["-C", "target-cpu=native"]
|
||||
|
||||
[target.aarch64-apple-darwin]
|
||||
rustflags = ["-C", "target-cpu=native"]
|
||||
|
||||
[target.wasm32-unknown-unknown]
|
||||
rustflags = ["-C", "target-feature=+simd128"]
|
||||
|
||||
[target.x86_64-apple-darwin]
|
||||
rustflags = ["-C", "target-feature=-avx,-avx2"]
|
2
.github/workflows/ci_cuda.yaml
vendored
2
.github/workflows/ci_cuda.yaml
vendored
@ -59,7 +59,7 @@ jobs:
|
||||
- name: Install Rust Stable
|
||||
run: curl https://sh.rustup.rs -sSf | sh -s -- -y
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- run: apt-get update -y && apt-get install libssl-dev -y
|
||||
- run: apt update -y && apt install libssl-dev -y
|
||||
- name: Test (cuda)
|
||||
run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda
|
||||
stop-runner:
|
||||
|
5
.gitignore
vendored
5
.gitignore
vendored
@ -25,12 +25,7 @@ flamegraph.svg
|
||||
*.swp
|
||||
trace-*.json
|
||||
|
||||
candle-wasm-examples/*/build
|
||||
candle-wasm-examples/*/*.bin
|
||||
candle-wasm-examples/*/*.jpeg
|
||||
candle-wasm-examples/*/*.wav
|
||||
candle-wasm-examples/*/*.safetensors
|
||||
candle-wasm-examples/*/package-lock.json
|
||||
|
||||
.DS_Store
|
||||
.idea/*
|
||||
|
42
CHANGELOG.md
42
CHANGELOG.md
@ -1,42 +0,0 @@
|
||||
# Changelog
|
||||
This documents the main changes to the `candle` crate.
|
||||
|
||||
## v0.2.1 - Unreleased
|
||||
|
||||
### Added
|
||||
|
||||
### Modified
|
||||
- Dilations are now supported in conv-transpose2d.
|
||||
[671](https://github.com/huggingface/candle/pull/671).
|
||||
|
||||
## v0.2.0 - 2023-08-30
|
||||
|
||||
### Added
|
||||
- Add the powf op
|
||||
[664](https://github.com/huggingface/candle/pull/664).
|
||||
- Stable Diffusion XL support
|
||||
[647](https://github.com/huggingface/candle/pull/647).
|
||||
- Add the conv-transpose2d op
|
||||
[635](https://github.com/huggingface/candle/pull/635).
|
||||
- Refactor the VarBuilder api
|
||||
[627](https://github.com/huggingface/candle/pull/627).
|
||||
- Add some quantization command
|
||||
[625](https://github.com/huggingface/candle/pull/625).
|
||||
- Support more quantized types, e.g. Q2K, Q4K, Q5K...
|
||||
[586](https://github.com/huggingface/candle/pull/586).
|
||||
- Add pose estimation to the yolo example
|
||||
[589](https://github.com/huggingface/candle/pull/589).
|
||||
- Api to write GGUF files
|
||||
[585](https://github.com/huggingface/candle/pull/585).
|
||||
- Support more quantization types
|
||||
[580](https://github.com/huggingface/candle/pull/580).
|
||||
- Add EfficientNet as an example Computer Vision model
|
||||
[572](https://github.com/huggingface/candle/pull/572).
|
||||
- Add a group parameter to convolutions
|
||||
[566](https://github.com/huggingface/candle/pull/566).
|
||||
- New dtype: int64
|
||||
[563](https://github.com/huggingface/candle/pull/563).
|
||||
- Handling of the GGUF file format.
|
||||
[559](https://github.com/huggingface/candle/pull/559).
|
||||
|
||||
## v0.1.2 - 2023-08-21
|
10
Cargo.toml
10
Cargo.toml
@ -3,22 +3,19 @@ members = [
|
||||
"candle-core",
|
||||
"candle-datasets",
|
||||
"candle-examples",
|
||||
"candle-book",
|
||||
"candle-nn",
|
||||
"candle-pyo3",
|
||||
"candle-transformers",
|
||||
"candle-wasm-examples/llama2-c",
|
||||
"candle-wasm-examples/whisper",
|
||||
"candle-wasm-examples/yolo",
|
||||
]
|
||||
exclude = [
|
||||
"candle-flash-attn",
|
||||
"candle-kernels",
|
||||
]
|
||||
resolver = "2"
|
||||
|
||||
[workspace.package]
|
||||
version = "0.2.1"
|
||||
version = "0.1.1"
|
||||
edition = "2021"
|
||||
description = "Minimalist ML framework."
|
||||
repository = "https://github.com/huggingface/candle"
|
||||
@ -34,10 +31,9 @@ clap = { version = "4.2.4", features = ["derive"] }
|
||||
cudarc = { version = "0.9.14", features = ["f16"] }
|
||||
# TODO: Switch back to the official gemm implementation once it has caught up.
|
||||
gemm = { version = "0.15.6", package = "candle-gemm" }
|
||||
hf-hub = "0.3.0"
|
||||
hf-hub = "0.2.0"
|
||||
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
|
||||
image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
|
||||
imageproc = { version = "0.23.0", default-features = false }
|
||||
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
|
||||
libc = { version = "0.2.147" }
|
||||
log = "0.4"
|
||||
@ -47,7 +43,6 @@ num-traits = "0.2.15"
|
||||
rand = "0.8.5"
|
||||
rand_distr = "0.4.3"
|
||||
rayon = "1.7.0"
|
||||
rusttype = { version = "0.9", default-features = false }
|
||||
safetensors = "0.3.1"
|
||||
serde = { version = "1.0.171", features = ["derive"] }
|
||||
serde_json = "1.0.99"
|
||||
@ -58,7 +53,6 @@ tracing-chrome = "0.7.1"
|
||||
tracing-subscriber = "0.3.7"
|
||||
wav = "1.0.0"
|
||||
zip = { version = "0.6.6", default-features = false }
|
||||
parquet = { version = "45.0.0" }
|
||||
|
||||
[profile.release-with-debug]
|
||||
inherits = "release"
|
||||
|
6
Makefile
6
Makefile
@ -1,5 +1,3 @@
|
||||
.PHONY: clean-ptx clean test
|
||||
|
||||
clean-ptx:
|
||||
find target -name "*.ptx" -type f -delete
|
||||
echo "" > candle-kernels/src/lib.rs
|
||||
@ -13,4 +11,8 @@ clean:
|
||||
test:
|
||||
cargo test
|
||||
|
||||
pyo3-test:
|
||||
cargo build --profile=release-with-debug --package candle-pyo3
|
||||
python3 candle-pyo3/test.py
|
||||
|
||||
all: test
|
||||
|
135
README.md
135
README.md
@ -1,5 +1,5 @@
|
||||
# candle
|
||||
[](https://discord.gg/hugging-face-879548962464493619)
|
||||
[](https://discord.com/channels/879548962464493619/1136218819447238726)
|
||||
[](https://crates.io/crates/candle-core)
|
||||
[](https://docs.rs/candle-core)
|
||||

|
||||
@ -7,65 +7,29 @@
|
||||
Candle is a minimalist ML framework for Rust with a focus on performance (including GPU support)
|
||||
and ease of use. Try our online demos:
|
||||
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
|
||||
[LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2),
|
||||
[yolo](https://huggingface.co/spaces/lmz/candle-yolo).
|
||||
[llama2](https://huggingface.co/spaces/lmz/candle-llama2).
|
||||
|
||||
## Get started
|
||||
|
||||
Make sure that you have [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) correctly installed as described in [**Installation**](https://huggingface.github.io/candle/guide/installation.html).
|
||||
|
||||
Let's see how to run a simple matrix multiplication.
|
||||
Write the following to your `myapp/src/main.rs` file:
|
||||
```rust
|
||||
use candle_core::{Device, Tensor};
|
||||
let a = Tensor::randn(0f32, 1., (2, 3), &Device::Cpu)?;
|
||||
let b = Tensor::randn(0f32, 1., (3, 4), &Device::Cpu)?;
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let device = Device::Cpu;
|
||||
|
||||
let a = Tensor::randn(0f32, 1., (2, 3), &device)?;
|
||||
let b = Tensor::randn(0f32, 1., (3, 4), &device)?;
|
||||
|
||||
let c = a.matmul(&b)?;
|
||||
println!("{c}");
|
||||
Ok(())
|
||||
}
|
||||
let c = a.matmul(&b)?;
|
||||
println!("{c}");
|
||||
```
|
||||
|
||||
`cargo run` should display a tensor of shape `Tensor[[2, 4], f32]`.
|
||||
|
||||
|
||||
Having installed `candle` with Cuda support, simply define the `device` to be on GPU:
|
||||
|
||||
```diff
|
||||
- let device = Device::Cpu;
|
||||
+ let device = Device::new_cuda(0)?;
|
||||
```
|
||||
|
||||
For more advanced examples, please have a look at the following section.
|
||||
|
||||
## Check out our examples
|
||||
|
||||
Check out our [examples](./candle-examples/examples/):
|
||||
|
||||
- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
|
||||
- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM.
|
||||
- [Llama and Llama-v2](./candle-examples/examples/llama/): general LLM.
|
||||
- [Falcon](./candle-examples/examples/falcon/): general LLM.
|
||||
- [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
|
||||
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code
|
||||
generation.
|
||||
- [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
|
||||
image generative model, support for the 1.5, 2.1, and SDXL 1.0 versions.
|
||||
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
|
||||
using self-supervision (can be used for imagenet classification, depth
|
||||
evaluation, segmentation).
|
||||
- [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of
|
||||
the LLaMA model using the same quantization techniques as
|
||||
[llama.cpp](https://github.com/ggerganov/llama.cpp).
|
||||
- [yolo-v3](./candle-examples/examples/yolo-v3/) and
|
||||
[yolo-v8](./candle-examples/examples/yolo-v8/): object detection and pose
|
||||
estimation models.
|
||||
[segment-anything](./candle-examples/examples/segment-anything/): image
|
||||
segmentation model with prompt.
|
||||
image generative model, yet to be optimized.
|
||||
|
||||
Run them using the following commands:
|
||||
```
|
||||
cargo run --example whisper --release
|
||||
@ -73,16 +37,10 @@ cargo run --example llama --release
|
||||
cargo run --example falcon --release
|
||||
cargo run --example bert --release
|
||||
cargo run --example bigcode --release
|
||||
cargo run --example stable-diffusion --release -- --prompt "a rusty robot holding a fire torch"
|
||||
cargo run --example dinov2 --release -- --image path/to/myinput.jpg
|
||||
cargo run --example quantized --release
|
||||
cargo run --example yolo-v3 --release -- myimage.jpg
|
||||
cargo run --example yolo-v8 --release -- myimage.jpg # for pose estimation, add --task pose
|
||||
cargo run --example segment-anything --release -- --image myimage.jpg
|
||||
cargo run --example stable-diffusion --release --features image -- --prompt "a rusty robot holding a fire torch"
|
||||
```
|
||||
|
||||
In order to use **CUDA** add `--features cuda` to the example command line. If
|
||||
you have cuDNN installed, use `--features cudnn` for even more speedups.
|
||||
In order to use **CUDA** add `--features cuda` to the example command line.
|
||||
|
||||
There are also some wasm examples for whisper and
|
||||
[llama2.c](https://github.com/karpathy/llama2.c). You can either build them with
|
||||
@ -90,16 +48,16 @@ There are also some wasm examples for whisper and
|
||||
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
|
||||
[llama2](https://huggingface.co/spaces/lmz/candle-llama2).
|
||||
|
||||
For LLaMA2, run the following command to retrieve the weight files and start a
|
||||
For llama2, run the following command to retrieve the weight files and start a
|
||||
test server:
|
||||
```bash
|
||||
cd candle-wasm-examples/llama2-c
|
||||
wget https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/model.bin
|
||||
wget https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/tokenizer.json
|
||||
trunk serve --release --port 8081
|
||||
trunk serve --release --public-url /candle-llama2/ --port 8081
|
||||
```
|
||||
And then head over to
|
||||
[http://localhost:8081/](http://localhost:8081/).
|
||||
[http://localhost:8081/candle-llama2](http://localhost:8081/candle-llama2).
|
||||
|
||||
<!--- ANCHOR: features --->
|
||||
|
||||
@ -112,14 +70,11 @@ And then head over to
|
||||
- Optimized CPU backend with optional MKL support for x86 and Accelerate for macs.
|
||||
- CUDA backend for efficiently running on GPUs, multiple GPU distribution via NCCL.
|
||||
- WASM support, run your models in a browser.
|
||||
- Included models.
|
||||
- LLMs: LLaMA v1 and v2, Falcon, StarCoder.
|
||||
- Whisper (multi-lingual support).
|
||||
- Model support out of the box.
|
||||
- LLMs: Llama v1 and v2, Falcon, StarCoder.
|
||||
- Whisper.
|
||||
- Stable Diffusion.
|
||||
- Computer Vision: DINOv2, EfficientNet, yolo-v3, yolo-v8.
|
||||
- File formats: load models from safetensors, npz, ggml, or PyTorch files.
|
||||
- Serverless (on CPU), small and fast deployments.
|
||||
- Quantization support using the llama.cpp quantized types.
|
||||
|
||||
<!--- ANCHOR_END: features --->
|
||||
|
||||
@ -136,7 +91,7 @@ Cheatsheet:
|
||||
| Operations | `tensor.view((2, 2))` | `tensor.reshape((2, 2))?` |
|
||||
| Operations | `a.matmul(b)` | `a.matmul(&b)?` |
|
||||
| Arithmetic | `a + b` | `&a + &b` |
|
||||
| Device | `tensor.to(device="cuda")` | `tensor.to_device(&Device::new_cuda(0)?)?` |
|
||||
| Device | `tensor.to(device="cuda")` | `tensor.to_device(&Device::Cuda(0))?` |
|
||||
| Dtype | `tensor.to(dtype=torch.float16)` | `tensor.to_dtype(&DType::F16)?` |
|
||||
| Saving | `torch.save({"A": A}, "model.bin")` | `candle::safetensors::save(&HashMap::from([("A", A)]), "model.safetensors")?` |
|
||||
| Loading | `weights = torch.load("model.bin")` | `candle::safetensors::load("model.safetensors", &device)` |
|
||||
@ -189,74 +144,34 @@ Finally, Rust is cool! A lot of the HF ecosystem already has Rust crates, like [
|
||||
#### Missing symbols when compiling with the mkl feature.
|
||||
|
||||
If you get some missing symbols when compiling binaries/tests using the mkl
|
||||
or accelerate features, e.g. for mkl you get:
|
||||
features, e.g.:
|
||||
```
|
||||
= note: /usr/bin/ld: (....o): in function `blas::sgemm':
|
||||
.../blas-0.22.0/src/lib.rs:1944: undefined reference to `sgemm_' collect2: error: ld returned 1 exit status
|
||||
|
||||
= note: some `extern` functions couldn't be found; some native libraries may need to be installed or have their path specified
|
||||
= note: use the `-l` flag to specify native libraries to link
|
||||
= note: use the `cargo:rustc-link-lib` directive to specify the native libraries to link with Cargo
|
||||
```
|
||||
or for accelerate:
|
||||
```
|
||||
Undefined symbols for architecture arm64:
|
||||
"_dgemm_", referenced from:
|
||||
candle_core::accelerate::dgemm::h1b71a038552bcabe in libcandle_core...
|
||||
"_sgemm_", referenced from:
|
||||
candle_core::accelerate::sgemm::h2cf21c592cba3c47 in libcandle_core...
|
||||
ld: symbol(s) not found for architecture arm64
|
||||
= note: use the `cargo:rustc-link-lib` directive to specify the native libraries to link with Cargo (see https://doc.rust-lang.org/cargo/reference/build-scripts.html#cargorustc-link-libkindname)
|
||||
```
|
||||
|
||||
This is likely due to a missing linker flag that was needed to enable the mkl library. You
|
||||
can try adding the following for mkl at the top of your binary:
|
||||
```rust
|
||||
can try adding the following at the top of your binary:
|
||||
```
|
||||
extern crate intel_mkl_src;
|
||||
```
|
||||
or for accelerate:
|
||||
```rust
|
||||
extern crate accelerate_src;
|
||||
```
|
||||
|
||||
#### Cannot run the LLaMA examples: access to source requires login credentials
|
||||
#### Cannot run llama example : access to source requires login credentials
|
||||
|
||||
```
|
||||
Error: request error: https://huggingface.co/meta-llama/Llama-2-7b-hf/resolve/main/tokenizer.json: status code 401
|
||||
```
|
||||
|
||||
This is likely because you're not permissioned for the LLaMA-v2 model. To fix
|
||||
this, you have to register on the huggingface-hub, accept the [LLaMA-v2 model
|
||||
This is likely because you're not permissioned for the llama-v2 model. To fix
|
||||
this, you have to register on the huggingface-hub, accept the [llama-v2 model
|
||||
conditions](https://huggingface.co/meta-llama/Llama-2-7b-hf), and set up your
|
||||
authentication token. See issue
|
||||
[#350](https://github.com/huggingface/candle/issues/350) for more details.
|
||||
|
||||
#### Missing cute/cutlass headers when compiling flash-attn
|
||||
|
||||
```
|
||||
In file included from kernels/flash_fwd_launch_template.h:11:0,
|
||||
from kernels/flash_fwd_hdim224_fp16_sm80.cu:5:
|
||||
kernels/flash_fwd_kernel.h:8:10: fatal error: cute/algorithm/copy.hpp: No such file or directory
|
||||
#include <cute/algorithm/copy.hpp>
|
||||
^~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
compilation terminated.
|
||||
Error: nvcc error while compiling:
|
||||
```
|
||||
[cutlass](https://github.com/NVIDIA/cutlass) is provided as a git submodule so you may want to run the following command to check it in properly.
|
||||
```bash
|
||||
git submodule update --init
|
||||
```
|
||||
|
||||
#### Compiling with flash-attention fails
|
||||
|
||||
```
|
||||
/usr/include/c++/11/bits/std_function.h:530:146: error: parameter packs not expanded with ‘...’:
|
||||
```
|
||||
|
||||
This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the CANDLE_NVCC_CCBIN environment variable.
|
||||
```
|
||||
env CANDLE_NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
|
||||
```
|
||||
|
||||
#### Tracking down errors
|
||||
|
||||
You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle
|
||||
|
@ -1,49 +0,0 @@
|
||||
[package]
|
||||
name = "candle-book"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
description.workspace = true
|
||||
repository.workspace = true
|
||||
keywords.workspace = true
|
||||
categories.workspace = true
|
||||
license.workspace = true
|
||||
readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
|
||||
candle-datasets = { path = "../candle-datasets", version = "0.2.1" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.2.1" }
|
||||
candle-transformers = { path = "../candle-transformers", version = "0.2.1" }
|
||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.1", optional = true }
|
||||
safetensors = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
num-traits = { workspace = true }
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
cudarc = { workspace = true, optional = true }
|
||||
half = { workspace = true, optional = true }
|
||||
image = { workspace = true, optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
byteorder = { workspace = true }
|
||||
hf-hub = { workspace = true, features=["tokio"]}
|
||||
clap = { workspace = true }
|
||||
memmap2 = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
tokenizers = { workspace = true, features = ["onig"] }
|
||||
tracing = { workspace = true }
|
||||
tracing-chrome = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
wav = { workspace = true }
|
||||
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
|
||||
tokio = "1.29.1"
|
||||
parquet = { workspace = true }
|
||||
image = { workspace = true }
|
||||
|
||||
[build-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
|
||||
[features]
|
||||
default = []
|
@ -12,11 +12,7 @@
|
||||
|
||||
- [Running a model](inference/README.md)
|
||||
- [Using the hub](inference/hub.md)
|
||||
- [Error management](error_manage.md)
|
||||
- [Training](training/README.md)
|
||||
- [MNIST](training/mnist.md)
|
||||
- [Fine-tuning]()
|
||||
- [Serialization]()
|
||||
- [Error management]()
|
||||
- [Advanced Cuda usage]()
|
||||
- [Writing a custom kernel]()
|
||||
- [Porting a custom kernel]()
|
||||
@ -25,3 +21,7 @@
|
||||
- [Creating a WASM app]()
|
||||
- [Creating a REST api webserver]()
|
||||
- [Creating a desktop Tauri app]()
|
||||
- [Training]()
|
||||
- [MNIST]()
|
||||
- [Fine-tuning]()
|
||||
- [Serialization]()
|
||||
|
@ -147,7 +147,7 @@ And rewrite our examples using it
|
||||
# extern crate candle_core;
|
||||
# extern crate candle_nn;
|
||||
use candle_core::{DType, Device, Result, Tensor};
|
||||
use candle_nn::{Linear, Module};
|
||||
use candle_nn::Linear;
|
||||
|
||||
struct Model {
|
||||
first: Linear,
|
||||
|
@ -1,43 +1,6 @@
|
||||
# Installation
|
||||
|
||||
**With Cuda support**:
|
||||
|
||||
1. First, make sure that Cuda is correctly installed.
|
||||
- `nvcc --version` should print information about your Cuda compiler driver.
|
||||
- `nvidia-smi --query-gpu=compute_cap --format=csv` should print your GPUs compute capability, e.g. something
|
||||
like:
|
||||
|
||||
```bash
|
||||
compute_cap
|
||||
8.9
|
||||
```
|
||||
|
||||
If any of the above commands errors out, please make sure to update your Cuda version.
|
||||
|
||||
2. Create a new app and add [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) with Cuda support.
|
||||
|
||||
Start by creating a new cargo:
|
||||
|
||||
```bash
|
||||
cargo new myapp
|
||||
cd myapp
|
||||
```
|
||||
|
||||
Make sure to add the `candle-core` crate with the cuda feature:
|
||||
|
||||
```bash
|
||||
cargo add --git https://github.com/huggingface/candle.git candle-core --features "cuda"
|
||||
```
|
||||
|
||||
Run `cargo build` to make sure everything can be correctly built.
|
||||
|
||||
```bash
|
||||
cargo build
|
||||
```
|
||||
|
||||
**Without Cuda support**:
|
||||
|
||||
Create a new app and add [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) as follows:
|
||||
Start by creating a new app:
|
||||
|
||||
```bash
|
||||
cargo new myapp
|
||||
@ -45,12 +8,17 @@ cd myapp
|
||||
cargo add --git https://github.com/huggingface/candle.git candle-core
|
||||
```
|
||||
|
||||
Finally, run `cargo build` to make sure everything can be correctly built.
|
||||
At this point, candle will be built **without** CUDA support.
|
||||
To get CUDA support use the `cuda` feature
|
||||
```bash
|
||||
cargo add --git https://github.com/huggingface/candle.git candle-core --features cuda
|
||||
```
|
||||
|
||||
You can check everything works properly:
|
||||
|
||||
```bash
|
||||
cargo build
|
||||
```
|
||||
|
||||
**With mkl support**
|
||||
|
||||
You can also see the `mkl` feature which could be interesting to get faster inference on CPU. [Using mkl](./advanced/mkl.md)
|
||||
|
@ -39,7 +39,7 @@ cargo add hf-hub --features tokio
|
||||
```rust,ignore
|
||||
# This is tested directly in examples crate because it needs external dependencies unfortunately:
|
||||
# See [this](https://github.com/rust-lang/mdBook/issues/706)
|
||||
{{#include ../lib.rs:book_hub_1}}
|
||||
{{#include ../../../candle-examples/src/lib.rs:book_hub_1}}
|
||||
```
|
||||
|
||||
|
||||
@ -58,7 +58,7 @@ Now that we have our weights, we can use them in our bert architecture:
|
||||
#
|
||||
# let weights = repo.get("model.safetensors").unwrap();
|
||||
use candle_core::{Device, Tensor, DType};
|
||||
use candle_nn::{Linear, Module};
|
||||
use candle_nn::Linear;
|
||||
|
||||
let weights = candle_core::safetensors::load(weights, &Device::Cpu).unwrap();
|
||||
|
||||
@ -81,7 +81,7 @@ For more efficient loading, instead of reading the file, you could use [`memmap2
|
||||
and will definitely be slower on network mounted disk, because it will issue more read calls.
|
||||
|
||||
```rust,ignore
|
||||
{{#include ../lib.rs:book_hub_2}}
|
||||
{{#include ../../../candle-examples/src/lib.rs:book_hub_2}}
|
||||
```
|
||||
|
||||
**Note**: This operation is **unsafe**. [See the safety notice](https://docs.rs/memmap2/latest/memmap2/struct.Mmap.html#safety).
|
||||
@ -100,5 +100,5 @@ cargo add safetensors
|
||||
|
||||
|
||||
```rust,ignore
|
||||
{{#include ../lib.rs:book_hub_3}}
|
||||
{{#include ../../../candle-examples/src/lib.rs:book_hub_3}}
|
||||
```
|
||||
|
@ -1,193 +0,0 @@
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::Result;
|
||||
use candle::{DType, Device, Tensor};
|
||||
use parquet::file::reader::SerializedFileReader;
|
||||
|
||||
// NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
|
||||
#[rustfmt::skip]
|
||||
#[tokio::test]
|
||||
async fn book_hub_1() {
|
||||
// ANCHOR: book_hub_1
|
||||
use candle::Device;
|
||||
use hf_hub::api::tokio::Api;
|
||||
|
||||
let api = Api::new().unwrap();
|
||||
let repo = api.model("bert-base-uncased".to_string());
|
||||
|
||||
let weights_filename = repo.get("model.safetensors").await.unwrap();
|
||||
|
||||
let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap();
|
||||
// ANCHOR_END: book_hub_1
|
||||
assert_eq!(weights.len(), 206);
|
||||
}
|
||||
|
||||
#[rustfmt::skip]
|
||||
#[test]
|
||||
fn book_hub_2() {
|
||||
// ANCHOR: book_hub_2
|
||||
use candle::Device;
|
||||
use hf_hub::api::sync::Api;
|
||||
use memmap2::Mmap;
|
||||
use std::fs;
|
||||
|
||||
let api = Api::new().unwrap();
|
||||
let repo = api.model("bert-base-uncased".to_string());
|
||||
let weights_filename = repo.get("model.safetensors").unwrap();
|
||||
|
||||
let file = fs::File::open(weights_filename).unwrap();
|
||||
let mmap = unsafe { Mmap::map(&file).unwrap() };
|
||||
let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap();
|
||||
// ANCHOR_END: book_hub_2
|
||||
assert_eq!(weights.len(), 206);
|
||||
}
|
||||
|
||||
#[rustfmt::skip]
|
||||
#[test]
|
||||
fn book_hub_3() {
|
||||
// ANCHOR: book_hub_3
|
||||
use candle::{DType, Device, Tensor};
|
||||
use hf_hub::api::sync::Api;
|
||||
use memmap2::Mmap;
|
||||
use safetensors::slice::IndexOp;
|
||||
use safetensors::SafeTensors;
|
||||
use std::fs;
|
||||
|
||||
let api = Api::new().unwrap();
|
||||
let repo = api.model("bert-base-uncased".to_string());
|
||||
let weights_filename = repo.get("model.safetensors").unwrap();
|
||||
|
||||
let file = fs::File::open(weights_filename).unwrap();
|
||||
let mmap = unsafe { Mmap::map(&file).unwrap() };
|
||||
|
||||
// Use safetensors directly
|
||||
let tensors = SafeTensors::deserialize(&mmap[..]).unwrap();
|
||||
let view = tensors
|
||||
.tensor("bert.encoder.layer.0.attention.self.query.weight")
|
||||
.unwrap();
|
||||
|
||||
// We're going to load shard with rank 1, within a world_size of 4
|
||||
// We're going to split along dimension 0 doing VIEW[start..stop, :]
|
||||
let rank = 1;
|
||||
let world_size = 4;
|
||||
let dim = 0;
|
||||
let dtype = view.dtype();
|
||||
let mut tp_shape = view.shape().to_vec();
|
||||
let size = tp_shape[0];
|
||||
|
||||
if size % world_size != 0 {
|
||||
panic!("The dimension is not divisble by `world_size`");
|
||||
}
|
||||
let block_size = size / world_size;
|
||||
let start = rank * block_size;
|
||||
let stop = (rank + 1) * block_size;
|
||||
|
||||
// Everything is expressed in tensor dimension
|
||||
// bytes offsets is handled automatically for safetensors.
|
||||
|
||||
let iterator = view.slice(start..stop).unwrap();
|
||||
|
||||
tp_shape[dim] = block_size;
|
||||
|
||||
// Convert safetensors Dtype to candle DType
|
||||
let dtype: DType = dtype.try_into().unwrap();
|
||||
|
||||
// TODO: Implement from_buffer_iterator so we can skip the extra CPU alloc.
|
||||
let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
|
||||
let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap();
|
||||
// ANCHOR_END: book_hub_3
|
||||
assert_eq!(view.shape(), &[768, 768]);
|
||||
assert_eq!(tp_tensor.dims(), &[192, 768]);
|
||||
}
|
||||
|
||||
#[rustfmt::skip]
|
||||
#[test]
|
||||
fn book_training_1() -> Result<()>{
|
||||
// ANCHOR: book_training_1
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
|
||||
let dataset_id = "mnist".to_string();
|
||||
|
||||
let api = Api::new()?;
|
||||
let repo = Repo::with_revision(
|
||||
dataset_id,
|
||||
RepoType::Dataset,
|
||||
"refs/convert/parquet".to_string(),
|
||||
);
|
||||
let repo = api.repo(repo);
|
||||
let test_parquet_filename = repo.get("mnist/test/0000.parquet")?;
|
||||
let train_parquet_filename = repo.get("mnist/train/0000.parquet")?;
|
||||
let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
|
||||
let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
|
||||
// ANCHOR_END: book_training_1
|
||||
// Ignore unused
|
||||
let _train = train_parquet;
|
||||
// ANCHOR: book_training_2
|
||||
for row in test_parquet {
|
||||
for (idx, (name, field)) in row?.get_column_iter().enumerate() {
|
||||
println!("Column id {idx}, name {name}, value {field}");
|
||||
}
|
||||
}
|
||||
// ANCHOR_END: book_training_2
|
||||
let test_parquet_filename = repo.get("mnist/test/0000.parquet")?;
|
||||
let train_parquet_filename = repo.get("mnist/train/0000.parquet")?;
|
||||
let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
|
||||
let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
|
||||
// ANCHOR: book_training_3
|
||||
|
||||
let test_samples = 10_000;
|
||||
let mut test_buffer_images: Vec<u8> = Vec::with_capacity(test_samples * 784);
|
||||
let mut test_buffer_labels: Vec<u8> = Vec::with_capacity(test_samples);
|
||||
for row in test_parquet{
|
||||
for (_name, field) in row?.get_column_iter() {
|
||||
if let parquet::record::Field::Group(subrow) = field {
|
||||
for (_name, field) in subrow.get_column_iter() {
|
||||
if let parquet::record::Field::Bytes(value) = field {
|
||||
let image = image::load_from_memory(value.data()).unwrap();
|
||||
test_buffer_images.extend(image.to_luma8().as_raw());
|
||||
}
|
||||
}
|
||||
}else if let parquet::record::Field::Long(label) = field {
|
||||
test_buffer_labels.push(*label as u8);
|
||||
}
|
||||
}
|
||||
}
|
||||
let test_images = (Tensor::from_vec(test_buffer_images, (test_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
|
||||
let test_labels = Tensor::from_vec(test_buffer_labels, (test_samples, ), &Device::Cpu)?;
|
||||
|
||||
let train_samples = 60_000;
|
||||
let mut train_buffer_images: Vec<u8> = Vec::with_capacity(train_samples * 784);
|
||||
let mut train_buffer_labels: Vec<u8> = Vec::with_capacity(train_samples);
|
||||
for row in train_parquet{
|
||||
for (_name, field) in row?.get_column_iter() {
|
||||
if let parquet::record::Field::Group(subrow) = field {
|
||||
for (_name, field) in subrow.get_column_iter() {
|
||||
if let parquet::record::Field::Bytes(value) = field {
|
||||
let image = image::load_from_memory(value.data()).unwrap();
|
||||
train_buffer_images.extend(image.to_luma8().as_raw());
|
||||
}
|
||||
}
|
||||
}else if let parquet::record::Field::Long(label) = field {
|
||||
train_buffer_labels.push(*label as u8);
|
||||
}
|
||||
}
|
||||
}
|
||||
let train_images = (Tensor::from_vec(train_buffer_images, (train_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
|
||||
let train_labels = Tensor::from_vec(train_buffer_labels, (train_samples, ), &Device::Cpu)?;
|
||||
|
||||
let mnist = candle_datasets::vision::Dataset {
|
||||
train_images,
|
||||
train_labels,
|
||||
test_images,
|
||||
test_labels,
|
||||
labels: 10,
|
||||
};
|
||||
|
||||
// ANCHOR_END: book_training_3
|
||||
assert_eq!(mnist.test_images.dims(), &[10_000, 784]);
|
||||
assert_eq!(mnist.test_labels.dims(), &[10_000]);
|
||||
assert_eq!(mnist.train_images.dims(), &[60_000, 784]);
|
||||
assert_eq!(mnist.train_labels.dims(), &[60_000]);
|
||||
Ok(())
|
||||
}
|
||||
}
|
@ -1,39 +1 @@
|
||||
# Training
|
||||
|
||||
|
||||
Training starts with data. We're going to use the huggingface hub and
|
||||
start with the Hello world dataset of machine learning, MNIST.
|
||||
|
||||
Let's start with downloading `MNIST` from [huggingface](https://huggingface.co/datasets/mnist).
|
||||
|
||||
This requires [`hf-hub`](https://github.com/huggingface/hf-hub).
|
||||
```bash
|
||||
cargo add hf-hub
|
||||
```
|
||||
|
||||
This is going to be very hands-on for now.
|
||||
|
||||
```rust,ignore
|
||||
{{#include ../../../candle-examples/src/lib.rs:book_training_1}}
|
||||
```
|
||||
|
||||
This uses the standardized `parquet` files from the `refs/convert/parquet` branch on every dataset.
|
||||
Our handles are now [`parquet::file::serialized_reader::SerializedFileReader`].
|
||||
|
||||
We can inspect the content of the files with:
|
||||
|
||||
```rust,ignore
|
||||
{{#include ../../../candle-examples/src/lib.rs:book_training_2}}
|
||||
```
|
||||
|
||||
You should see something like:
|
||||
|
||||
```bash
|
||||
Column id 1, name label, value 6
|
||||
Column id 0, name image, value {bytes: [137, ....]
|
||||
Column id 1, name label, value 8
|
||||
Column id 0, name image, value {bytes: [137, ....]
|
||||
```
|
||||
|
||||
So each row contains 2 columns (image, label) with image being saved as bytes.
|
||||
Let's put them into a useful struct.
|
||||
|
@ -1,10 +1 @@
|
||||
# MNIST
|
||||
|
||||
So we now have downloaded the MNIST parquet files, let's put them in a simple struct.
|
||||
|
||||
```rust,ignore
|
||||
{{#include ../lib.rs:book_training_3}}
|
||||
```
|
||||
|
||||
The parsing of the file and putting it into single tensors requires the dataset to fit the entire memory.
|
||||
It is quite rudimentary, but simple enough for a small dataset like MNIST.
|
||||
|
@ -12,7 +12,7 @@ readme = "README.md"
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
byteorder = { workspace = true }
|
||||
candle-kernels = { path = "../candle-kernels", version = "0.2.1", optional = true }
|
||||
candle-kernels = { path = "../candle-kernels", version = "0.1.1", optional = true }
|
||||
cudarc = { workspace = true, optional = true }
|
||||
gemm = { workspace = true }
|
||||
half = { workspace = true }
|
||||
|
@ -11,7 +11,7 @@ fn main() -> Result<()> {
|
||||
let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
|
||||
let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
|
||||
let start = std::time::Instant::now();
|
||||
let res = inp.conv2d(&w, 0, 1, 1, 1)?;
|
||||
let res = inp.conv2d(&w, 0, 1);
|
||||
println!("{:?}", start.elapsed());
|
||||
println!("{res:?}");
|
||||
Ok(())
|
||||
|
@ -5,10 +5,18 @@ extern crate intel_mkl_src;
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use candle::quantized::GgmlType;
|
||||
use candle::{Device, Result, Tensor, D};
|
||||
use candle_core::{Device, Result, Tensor, D};
|
||||
use clap::{Parser, Subcommand};
|
||||
|
||||
fn softmax<D: candle_core::shape::Dim>(xs: &Tensor, dim: D) -> Result<Tensor> {
|
||||
let dim = dim.to_index(xs.shape(), "softmax")?;
|
||||
let max = xs.max_keepdim(dim)?;
|
||||
let diff = xs.broadcast_sub(&max)?;
|
||||
let num = diff.exp()?;
|
||||
let den = num.sum_keepdim(dim)?;
|
||||
num.broadcast_div(&den)
|
||||
}
|
||||
|
||||
trait Benchmark {
|
||||
type PreProcessData;
|
||||
type RunResult;
|
||||
@ -31,7 +39,7 @@ impl Benchmark for Conv1d {
|
||||
}
|
||||
|
||||
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
|
||||
d.0.conv1d(&d.1, 0, 1, 1, 1)
|
||||
d.0.conv1d(&d.1, 0, 1)
|
||||
}
|
||||
|
||||
const ITERS: usize = 5;
|
||||
@ -50,7 +58,7 @@ impl Benchmark for Conv2d {
|
||||
}
|
||||
|
||||
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
|
||||
d.0.conv2d(&d.1, 0, 1, 1, 1)
|
||||
d.0.conv2d(&d.1, 0, 1)
|
||||
}
|
||||
|
||||
const ITERS: usize = 1;
|
||||
@ -73,27 +81,6 @@ impl Benchmark for Matmul {
|
||||
const ITERS: usize = 100;
|
||||
}
|
||||
|
||||
// This benchmark is similar to:
|
||||
// https://github.com/ggerganov/llama.cpp/blob/master/examples/benchmark/benchmark-matmult.cpp
|
||||
struct QMatMul;
|
||||
impl Benchmark for QMatMul {
|
||||
type PreProcessData = (candle::quantized::QMatMul, Tensor);
|
||||
type RunResult = Tensor;
|
||||
fn preprocess() -> Result<Self::PreProcessData> {
|
||||
let zeros = vec![candle::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32];
|
||||
let mm = candle::quantized::QTensor::new(zeros, (4096, 11008))?;
|
||||
let mm = candle::quantized::QMatMul::from_qtensor(mm);
|
||||
let arg = Tensor::randn(0f32, 1., (128, 11008), &Device::Cpu)?;
|
||||
Ok((mm, arg))
|
||||
}
|
||||
|
||||
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
|
||||
d.0.forward(&d.1)
|
||||
}
|
||||
|
||||
const ITERS: usize = 100;
|
||||
}
|
||||
|
||||
struct Softmax;
|
||||
impl Benchmark for Softmax {
|
||||
type PreProcessData = Tensor;
|
||||
@ -105,24 +92,7 @@ impl Benchmark for Softmax {
|
||||
}
|
||||
|
||||
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
|
||||
candle_nn::ops::softmax(d, D::Minus1)
|
||||
}
|
||||
|
||||
const ITERS: usize = 100;
|
||||
}
|
||||
|
||||
struct SoftmaxLastDim;
|
||||
impl Benchmark for SoftmaxLastDim {
|
||||
type PreProcessData = Tensor;
|
||||
type RunResult = Tensor;
|
||||
fn preprocess() -> Result<Self::PreProcessData> {
|
||||
// Typical whisper tiny size.
|
||||
let x = Tensor::randn(0f32, 1., (1, 6, 200, 1500), &Device::Cpu)?;
|
||||
Ok(x)
|
||||
}
|
||||
|
||||
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
|
||||
candle_nn::ops::softmax_last_dim(d)
|
||||
softmax(d, D::Minus1)
|
||||
}
|
||||
|
||||
const ITERS: usize = 100;
|
||||
@ -146,9 +116,7 @@ enum Task {
|
||||
Conv1d,
|
||||
Conv2d,
|
||||
Matmul,
|
||||
Qmatmul,
|
||||
Softmax,
|
||||
SoftmaxLastDim,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
@ -169,8 +137,6 @@ fn main() -> Result<()> {
|
||||
Task::Conv2d => run::<Conv2d>(args.iters)?,
|
||||
Task::Matmul => run::<Matmul>(args.iters)?,
|
||||
Task::Softmax => run::<Softmax>(args.iters)?,
|
||||
Task::SoftmaxLastDim => run::<SoftmaxLastDim>(args.iters)?,
|
||||
Task::Qmatmul => run::<QMatMul>(args.iters)?,
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -9,21 +9,9 @@ use candle_core::{Device, Tensor};
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let device = Device::new_cuda(0)?;
|
||||
let in_t = Tensor::rand(-1f32, 1f32, (1, 3, 12, 7), &device)?;
|
||||
let k_t = Tensor::rand(-1f32, 1f32, (6, 3, 1, 1), &device)?;
|
||||
let out_t = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
|
||||
println!("{out_t}");
|
||||
let in_t = in_t.to_device(&Device::Cpu)?;
|
||||
let k_t = k_t.to_device(&Device::Cpu)?;
|
||||
let out_t2 = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
|
||||
let diff = (out_t.to_device(&Device::Cpu)? - out_t2)?
|
||||
.sqr()?
|
||||
.sum_all()?;
|
||||
println!("{diff}");
|
||||
|
||||
let t = Tensor::randn(0f32, 1f32, (2, 4, 96, 96), &device)?;
|
||||
let w = Tensor::randn(0f32, 1f32, (320, 4, 3, 3), &device)?;
|
||||
let res = t.conv2d(&w, 1, 1, 1, 1)?;
|
||||
let res = t.conv2d(&w, 1, 1)?;
|
||||
println!("{res:?}");
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1,299 +0,0 @@
|
||||
use candle_core::quantized::{gguf_file, k_quants, QTensor};
|
||||
use candle_core::{Device, Result, Tensor};
|
||||
use clap::{Parser, Subcommand, ValueEnum};
|
||||
use rayon::prelude::*;
|
||||
|
||||
#[derive(ValueEnum, Debug, Clone)]
|
||||
enum QuantizationMode {
|
||||
/// The default quantization includes all 2d tensors, except the output tensor which always
|
||||
/// uses Q6_K.
|
||||
Llama,
|
||||
}
|
||||
|
||||
impl QuantizationMode {
|
||||
fn quantize(
|
||||
&self,
|
||||
name: &str,
|
||||
tensor: QTensor,
|
||||
default: fn(&Tensor) -> Result<QTensor>,
|
||||
) -> Result<QTensor> {
|
||||
match self {
|
||||
Self::Llama => {
|
||||
// Same behavior as the llama.cpp quantization.
|
||||
let should_quantize = name.ends_with(".weight") && tensor.rank() == 2;
|
||||
if should_quantize {
|
||||
let tensor = tensor.dequantize(&Device::Cpu)?;
|
||||
if name == "output.weight" {
|
||||
QTensor::quantize::<k_quants::BlockQ6K>(&tensor)
|
||||
} else {
|
||||
default(&tensor)
|
||||
}
|
||||
} else {
|
||||
Ok(tensor)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(ValueEnum, Debug, Clone)]
|
||||
enum Quantization {
|
||||
#[value(name = "q4_0")]
|
||||
Q4_0,
|
||||
#[value(name = "q4_1")]
|
||||
Q4_1,
|
||||
#[value(name = "q5_0")]
|
||||
Q5_0,
|
||||
#[value(name = "q5_1")]
|
||||
Q5_1,
|
||||
#[value(name = "q8_0")]
|
||||
Q8_0,
|
||||
#[value(name = "q8_1")]
|
||||
Q8_1,
|
||||
Q2k,
|
||||
Q3k,
|
||||
Q4k,
|
||||
Q5k,
|
||||
Q6k,
|
||||
Q8k,
|
||||
F16,
|
||||
F32,
|
||||
}
|
||||
|
||||
#[derive(ValueEnum, Debug, Clone)]
|
||||
enum Format {
|
||||
Safetensors,
|
||||
Npz,
|
||||
Ggml,
|
||||
Gguf,
|
||||
Pth,
|
||||
Pickle,
|
||||
}
|
||||
|
||||
impl Format {
|
||||
fn infer<P: AsRef<std::path::Path>>(p: P) -> Option<Self> {
|
||||
p.as_ref()
|
||||
.extension()
|
||||
.and_then(|e| e.to_str())
|
||||
.and_then(|e| match e {
|
||||
// We don't infer any format for .bin as it can be used for ggml/gguf or pytorch.
|
||||
"safetensors" | "safetensor" => Some(Self::Safetensors),
|
||||
"npz" => Some(Self::Npz),
|
||||
"pth" | "pt" => Some(Self::Pth),
|
||||
"ggml" => Some(Self::Ggml),
|
||||
"gguf" => Some(Self::Gguf),
|
||||
_ => None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug, Clone)]
|
||||
enum Command {
|
||||
Ls {
|
||||
files: Vec<std::path::PathBuf>,
|
||||
|
||||
/// The file format to use, if unspecified infer from the file extension.
|
||||
#[arg(long, value_enum)]
|
||||
format: Option<Format>,
|
||||
|
||||
/// Enable verbose mode.
|
||||
#[arg(short, long)]
|
||||
verbose: bool,
|
||||
},
|
||||
|
||||
Quantize {
|
||||
/// The input file, in gguf format.
|
||||
in_file: std::path::PathBuf,
|
||||
/// The output file, in gguf format.
|
||||
out_file: std::path::PathBuf,
|
||||
|
||||
/// The quantization schema to apply.
|
||||
#[arg(long, value_enum)]
|
||||
quantization: Quantization,
|
||||
|
||||
/// Which tensor to quantize.
|
||||
#[arg(long, value_enum, default_value_t = QuantizationMode::Llama)]
|
||||
mode: QuantizationMode,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug, Clone)]
|
||||
struct Args {
|
||||
#[command(subcommand)]
|
||||
command: Command,
|
||||
}
|
||||
|
||||
fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> Result<()> {
|
||||
let format = match format {
|
||||
Some(format) => format,
|
||||
None => match Format::infer(file) {
|
||||
Some(format) => format,
|
||||
None => {
|
||||
println!(
|
||||
"{file:?}: cannot infer format from file extension, use the --format flag"
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
},
|
||||
};
|
||||
match format {
|
||||
Format::Npz => {
|
||||
let tensors = candle_core::npy::NpzTensors::new(file)?;
|
||||
let mut names = tensors.names();
|
||||
names.sort();
|
||||
for name in names {
|
||||
let shape_dtype = match tensors.get_shape_and_dtype(name) {
|
||||
Ok((shape, dtype)) => format!("[{shape:?}; {dtype:?}]"),
|
||||
Err(err) => err.to_string(),
|
||||
};
|
||||
println!("{name}: {shape_dtype}")
|
||||
}
|
||||
}
|
||||
Format::Safetensors => {
|
||||
let tensors = unsafe { candle_core::safetensors::MmapedFile::new(file)? };
|
||||
let tensors = tensors.deserialize()?;
|
||||
let mut tensors = tensors.tensors();
|
||||
tensors.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
for (name, view) in tensors.iter() {
|
||||
let dtype = view.dtype();
|
||||
let dtype = match candle_core::DType::try_from(dtype) {
|
||||
Ok(dtype) => format!("{dtype:?}"),
|
||||
Err(_) => format!("{dtype:?}"),
|
||||
};
|
||||
let shape = view.shape();
|
||||
println!("{name}: [{shape:?}; {dtype}]")
|
||||
}
|
||||
}
|
||||
Format::Pth => {
|
||||
let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose)?;
|
||||
tensors.sort_by(|a, b| a.name.cmp(&b.name));
|
||||
for tensor_info in tensors.iter() {
|
||||
println!(
|
||||
"{}: [{:?}; {:?}]",
|
||||
tensor_info.name,
|
||||
tensor_info.layout.shape(),
|
||||
tensor_info.dtype,
|
||||
);
|
||||
if verbose {
|
||||
println!(" {:?}", tensor_info);
|
||||
}
|
||||
}
|
||||
}
|
||||
Format::Pickle => {
|
||||
let file = std::fs::File::open(file)?;
|
||||
let mut reader = std::io::BufReader::new(file);
|
||||
let mut stack = candle_core::pickle::Stack::empty();
|
||||
stack.read_loop(&mut reader)?;
|
||||
for (i, obj) in stack.stack().iter().enumerate() {
|
||||
println!("{i} {obj:?}");
|
||||
}
|
||||
}
|
||||
Format::Ggml => {
|
||||
let mut file = std::fs::File::open(file)?;
|
||||
let content = candle_core::quantized::ggml_file::Content::read(&mut file)?;
|
||||
let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
|
||||
tensors.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
for (name, qtensor) in tensors.iter() {
|
||||
println!("{name}: [{:?}; {:?}]", qtensor.shape(), qtensor.dtype());
|
||||
}
|
||||
}
|
||||
Format::Gguf => {
|
||||
let mut file = std::fs::File::open(file)?;
|
||||
let content = gguf_file::Content::read(&mut file)?;
|
||||
if verbose {
|
||||
let mut metadata = content.metadata.into_iter().collect::<Vec<_>>();
|
||||
metadata.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
println!("metadata entries ({})", metadata.len());
|
||||
for (key, value) in metadata.iter() {
|
||||
println!(" {key}: {value:?}");
|
||||
}
|
||||
}
|
||||
let mut tensors = content.tensor_infos.into_iter().collect::<Vec<_>>();
|
||||
tensors.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
for (name, info) in tensors.iter() {
|
||||
println!("{name}: [{:?}; {:?}]", info.shape, info.ggml_dtype);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_quantize(
|
||||
in_file: std::path::PathBuf,
|
||||
out_file: std::path::PathBuf,
|
||||
q: Quantization,
|
||||
qmode: QuantizationMode,
|
||||
) -> Result<()> {
|
||||
// Open the out file early so as to fail directly on missing directories etc.
|
||||
let mut out_file = std::fs::File::create(out_file)?;
|
||||
let mut in_ = std::fs::File::open(&in_file)?;
|
||||
let content = gguf_file::Content::read(&mut in_)?;
|
||||
println!("tensors: {}", content.tensor_infos.len());
|
||||
|
||||
let quantize_fn = match q {
|
||||
Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
|
||||
Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
|
||||
Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
|
||||
Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
|
||||
Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
|
||||
Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
|
||||
Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
|
||||
Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
|
||||
Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
|
||||
Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
|
||||
Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
|
||||
Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
|
||||
Quantization::F16 => QTensor::quantize::<half::f16>,
|
||||
Quantization::F32 => QTensor::quantize::<f32>,
|
||||
};
|
||||
|
||||
let qtensors = content
|
||||
.tensor_infos
|
||||
.par_iter()
|
||||
.map(|(name, _)| {
|
||||
println!(" quantizing {name}");
|
||||
let mut in_file = std::fs::File::open(&in_file)?;
|
||||
let tensor = content.tensor(&mut in_file, name)?;
|
||||
let tensor = qmode.quantize(name, tensor, quantize_fn)?;
|
||||
Ok((name, tensor))
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let qtensors = qtensors
|
||||
.iter()
|
||||
.map(|(k, v)| (k.as_str(), v))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let metadata = content
|
||||
.metadata
|
||||
.iter()
|
||||
.map(|(k, v)| (k.as_str(), v))
|
||||
.collect::<Vec<_>>();
|
||||
gguf_file::write(&mut out_file, metadata.as_slice(), &qtensors)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
match args.command {
|
||||
Command::Ls {
|
||||
files,
|
||||
format,
|
||||
verbose,
|
||||
} => {
|
||||
let multiple_files = files.len() > 1;
|
||||
for file in files.iter() {
|
||||
if multiple_files {
|
||||
println!("--- {file:?} ---");
|
||||
}
|
||||
run_ls(file, format.clone(), verbose)?
|
||||
}
|
||||
}
|
||||
Command::Quantize {
|
||||
in_file,
|
||||
out_file,
|
||||
quantization,
|
||||
mode,
|
||||
} => run_quantize(in_file, out_file, quantization, mode)?,
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -50,8 +50,6 @@ mod ffi {
|
||||
pub fn vvcos(dst: *mut c_double, src: *const c_double, len: *const c_int);
|
||||
pub fn vvlogf(dst: *mut c_float, src: *const c_float, len: *const c_int);
|
||||
pub fn vvlog(dst: *mut c_double, src: *const c_double, len: *const c_int);
|
||||
pub fn vvtanhf(dst: *mut c_float, src: *const c_float, len: *const c_int);
|
||||
pub fn vvtanh(dst: *mut c_double, src: *const c_double, len: *const c_int);
|
||||
|
||||
pub fn vDSP_vaddD(
|
||||
_: *const c_double,
|
||||
@ -125,42 +123,6 @@ mod ffi {
|
||||
_: c_long,
|
||||
_: c_ulong,
|
||||
);
|
||||
pub fn vDSP_vminD(
|
||||
_: *const c_double,
|
||||
_: c_long,
|
||||
_: *const c_double,
|
||||
_: c_long,
|
||||
_: *mut c_double,
|
||||
_: c_long,
|
||||
_: c_ulong,
|
||||
);
|
||||
pub fn vDSP_vmin(
|
||||
_: *const c_float,
|
||||
_: c_long,
|
||||
_: *const c_float,
|
||||
_: c_long,
|
||||
_: *mut c_float,
|
||||
_: c_long,
|
||||
_: c_ulong,
|
||||
);
|
||||
pub fn vDSP_vmaxD(
|
||||
_: *const c_double,
|
||||
_: c_long,
|
||||
_: *const c_double,
|
||||
_: c_long,
|
||||
_: *mut c_double,
|
||||
_: c_long,
|
||||
_: c_ulong,
|
||||
);
|
||||
pub fn vDSP_vmax(
|
||||
_: *const c_float,
|
||||
_: c_long,
|
||||
_: *const c_float,
|
||||
_: c_long,
|
||||
_: *mut c_float,
|
||||
_: c_long,
|
||||
_: c_ulong,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -310,26 +272,6 @@ pub fn vd_cos(a: &[f64], y: &mut [f64]) {
|
||||
}
|
||||
unsafe { ffi::vvcos(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
|
||||
}
|
||||
#[inline]
|
||||
pub fn vs_tanh(a: &[f32], y: &mut [f32]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vvtanhf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_tanh(a: &[f64], y: &mut [f64]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vvtanh(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_ln(a: &[f32], y: &mut [f32]) {
|
||||
let a_len = a.len();
|
||||
@ -406,7 +348,3 @@ binary_op!(vs_mul, f32, vDSP_vmul);
|
||||
binary_op!(vd_mul, f64, vDSP_vmulD);
|
||||
binary_op!(vs_div, f32, vDSP_vdiv);
|
||||
binary_op!(vd_div, f64, vDSP_vdivD);
|
||||
binary_op!(vs_max, f32, vDSP_vmax);
|
||||
binary_op!(vd_max, f64, vDSP_vmaxD);
|
||||
binary_op!(vs_min, f32, vDSP_vmin);
|
||||
binary_op!(vd_min, f64, vDSP_vminD);
|
||||
|
@ -15,8 +15,6 @@ pub trait BackendStorage: Sized {
|
||||
|
||||
fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self>;
|
||||
|
||||
fn powf(&self, _: &Layout, _: f64) -> Result<Self>;
|
||||
|
||||
fn elu(&self, _: &Layout, _: f64) -> Result<Self>;
|
||||
|
||||
fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self>;
|
||||
@ -47,14 +45,6 @@ pub trait BackendStorage: Sized {
|
||||
_params: &crate::conv::ParamsConv2D,
|
||||
) -> Result<Self>;
|
||||
|
||||
fn conv_transpose2d(
|
||||
&self,
|
||||
_l: &Layout,
|
||||
_kernel: &Self,
|
||||
_kernel_l: &Layout,
|
||||
_params: &crate::conv::ParamsConvTranspose2D,
|
||||
) -> Result<Self>;
|
||||
|
||||
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
|
||||
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
|
||||
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;
|
||||
|
@ -60,11 +60,6 @@ impl Tensor {
|
||||
kernel: rhs,
|
||||
..
|
||||
}
|
||||
| Op::ConvTranspose2D {
|
||||
arg: lhs,
|
||||
kernel: rhs,
|
||||
..
|
||||
}
|
||||
| Op::CustomOp2(lhs, rhs, _)
|
||||
| Op::Binary(lhs, rhs, _)
|
||||
| Op::Gather(lhs, rhs, _)
|
||||
@ -101,11 +96,9 @@ impl Tensor {
|
||||
| Op::ToDType(node)
|
||||
| Op::ToDevice(node)
|
||||
| Op::Transpose(node, _, _)
|
||||
| Op::Permute(node, _)
|
||||
| Op::Narrow(node, _, _, _)
|
||||
| Op::Unary(node, _)
|
||||
| Op::Elu(node, _)
|
||||
| Op::Powf(node, _)
|
||||
| Op::CustomOp1(node, _) => {
|
||||
let (tg, nodes) = walk(node, nodes, already_seen);
|
||||
track_grad |= tg;
|
||||
@ -168,21 +161,6 @@ impl Tensor {
|
||||
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||
*rhs_sum_grad = rhs_sum_grad.sub(&rhs_grad)?;
|
||||
}
|
||||
Op::Binary(lhs, rhs, BinaryOp::Minimum)
|
||||
| Op::Binary(lhs, rhs, BinaryOp::Maximum) => {
|
||||
let mask_lhs = node.eq(lhs)?.to_dtype(grad.dtype())?;
|
||||
let mask_rhs = node.eq(rhs)?.to_dtype(grad.dtype())?;
|
||||
|
||||
// If both masks are 1 one the same point, we want to scale the
|
||||
// gradient by 0.5 rather than 1.
|
||||
let lhs_grad = mask_lhs.mul(&grad)?.div(&(&mask_rhs + 1.)?)?;
|
||||
let lhs_sum_grad = grads.or_insert(lhs)?;
|
||||
*lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
|
||||
|
||||
let rhs_grad = mask_rhs.mul(&grad)?.div(&(&mask_lhs + 1.)?)?;
|
||||
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
|
||||
}
|
||||
Op::WhereCond(pred, t, f) => {
|
||||
let zeros = grad.zeros_like()?;
|
||||
let t_sum_grad = grads.or_insert(t)?;
|
||||
@ -193,75 +171,9 @@ impl Tensor {
|
||||
*f_sum_grad = f_sum_grad.add(&f_grad)?;
|
||||
}
|
||||
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
|
||||
Op::Conv2D {
|
||||
arg,
|
||||
kernel,
|
||||
padding,
|
||||
stride,
|
||||
dilation,
|
||||
} => {
|
||||
// The output height for conv_transpose2d is:
|
||||
// (i_h - 1) * stride - 2 * padding + dilation * (k_h - 1) + out_padding + 1
|
||||
let grad_h = grad.dim(2)?;
|
||||
let k_h = kernel.dim(2)?;
|
||||
let out_size =
|
||||
(grad_h - 1) * stride + dilation * (k_h - 1) + 1 - 2 * padding;
|
||||
let out_padding = arg.dim(2)? - out_size;
|
||||
let grad_arg = grad.conv_transpose2d(
|
||||
kernel,
|
||||
*padding,
|
||||
out_padding,
|
||||
*stride,
|
||||
*dilation,
|
||||
)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&grad_arg)?;
|
||||
|
||||
let grad_kernel = arg
|
||||
.transpose(0, 1)?
|
||||
.conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
|
||||
.transpose(0, 1)?;
|
||||
let sum_grad = grads.or_insert(kernel)?;
|
||||
*sum_grad = sum_grad.add(&grad_kernel)?;
|
||||
}
|
||||
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
|
||||
op: "conv-transpose2d",
|
||||
})?,
|
||||
Op::AvgPool2D {
|
||||
arg,
|
||||
kernel_size,
|
||||
stride,
|
||||
} => {
|
||||
if kernel_size != stride {
|
||||
crate::bail!("backward not supported for avgpool2d if ksize {kernel_size:?} != stride {stride:?}")
|
||||
}
|
||||
let (_n, _c, h, w) = arg.dims4()?;
|
||||
let grad_arg = grad.upsample_nearest2d(h, w)?;
|
||||
let grad_arg =
|
||||
(grad_arg * (1f64 / (kernel_size.0 * kernel_size.1) as f64))?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&grad_arg)?;
|
||||
}
|
||||
Op::MaxPool2D {
|
||||
arg,
|
||||
kernel_size,
|
||||
stride,
|
||||
} => {
|
||||
if kernel_size != stride {
|
||||
crate::bail!("backward not supported for maxpool2d if ksize {kernel_size:?} != stride {stride:?}")
|
||||
}
|
||||
let (_n, _c, h, w) = arg.dims4()?;
|
||||
// For computing the max-pool gradient, we compute a mask where a 1 means
|
||||
// that the element is the maximum, then we apply this mask to the
|
||||
// upsampled gradient (taking into account that multiple max may exist so
|
||||
// we scale the gradient for this case).
|
||||
let node_upsampled = node.upsample_nearest2d(h, w)?;
|
||||
let mask = arg.eq(&node_upsampled)?.to_dtype(arg.dtype())?;
|
||||
let avg = mask.avg_pool2d_with_stride(*kernel_size, *stride)?;
|
||||
let grad_arg = ((grad * avg)?.upsample_nearest2d(h, w)? * mask)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&grad_arg)?;
|
||||
}
|
||||
Op::Conv2D { .. } => Err(Error::BackwardNotSupported { op: "conv2d" })?,
|
||||
Op::AvgPool2D { .. } => Err(Error::BackwardNotSupported { op: "avg-pool2d" })?,
|
||||
Op::MaxPool2D { .. } => Err(Error::BackwardNotSupported { op: "max-pool2d" })?,
|
||||
Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
|
||||
op: "upsample-nearest2d",
|
||||
})?,
|
||||
@ -379,11 +291,6 @@ impl Tensor {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
|
||||
}
|
||||
Op::Unary(arg, UnaryOp::Tanh) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
let minus_dtanh = (node.sqr()? - 1.)?;
|
||||
*sum_grad = sum_grad.sub(&(&grad * &minus_dtanh)?)?
|
||||
}
|
||||
Op::Unary(arg, UnaryOp::Abs) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
let ones = arg.ones_like()?;
|
||||
@ -443,11 +350,6 @@ impl Tensor {
|
||||
*sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
|
||||
}
|
||||
Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
|
||||
Op::Powf(arg, e) => {
|
||||
let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
}
|
||||
Op::CustomOp1(arg, c) => {
|
||||
if let Some(arg_grad) = c.bwd(arg, node, &grad)? {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
@ -501,15 +403,6 @@ impl Tensor {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
}
|
||||
Op::Permute(arg, dims) => {
|
||||
let mut inv_dims = vec![0; dims.len()];
|
||||
for (i, &dim_idx) in dims.iter().enumerate() {
|
||||
inv_dims[dim_idx] = i
|
||||
}
|
||||
let arg_grad = grad.permute(inv_dims)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,3 @@
|
||||
use crate::{op::BackpropOp, op::Op, Error, Result, Tensor};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ParamsConv1D {
|
||||
pub(crate) b_size: usize,
|
||||
@ -11,12 +9,12 @@ pub struct ParamsConv1D {
|
||||
pub(crate) k_size: usize,
|
||||
pub(crate) padding: usize,
|
||||
pub(crate) stride: usize,
|
||||
pub(crate) dilation: usize,
|
||||
}
|
||||
|
||||
impl ParamsConv1D {
|
||||
pub(crate) fn l_out(&self) -> usize {
|
||||
(self.l_in + 2 * self.padding - self.dilation * (self.k_size - 1) - 1) / self.stride + 1
|
||||
let dilation = 1;
|
||||
(self.l_in + 2 * self.padding - dilation * (self.k_size - 1) - 1) / self.stride + 1
|
||||
}
|
||||
|
||||
pub(crate) fn out_dims(&self) -> Vec<usize> {
|
||||
@ -36,215 +34,20 @@ pub struct ParamsConv2D {
|
||||
pub(crate) c_in: usize,
|
||||
pub(crate) padding: usize,
|
||||
pub(crate) stride: usize,
|
||||
pub(crate) dilation: usize,
|
||||
}
|
||||
|
||||
impl ParamsConv2D {
|
||||
pub(crate) fn out_h(&self) -> usize {
|
||||
(self.i_h + 2 * self.padding - self.dilation * (self.k_h - 1) - 1) / self.stride + 1
|
||||
let dilation = 1;
|
||||
(self.i_h + 2 * self.padding - dilation * (self.k_h - 1) - 1) / self.stride + 1
|
||||
}
|
||||
|
||||
pub(crate) fn out_w(&self) -> usize {
|
||||
(self.i_w + 2 * self.padding - self.dilation * (self.k_w - 1) - 1) / self.stride + 1
|
||||
let dilation = 1;
|
||||
(self.i_w + 2 * self.padding - dilation * (self.k_w - 1) - 1) / self.stride + 1
|
||||
}
|
||||
|
||||
pub(crate) fn out_dims(&self) -> Vec<usize> {
|
||||
vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ParamsConvTranspose2D {
|
||||
pub(crate) b_size: usize,
|
||||
pub(crate) i_h: usize,
|
||||
pub(crate) i_w: usize,
|
||||
pub(crate) k_h: usize,
|
||||
pub(crate) k_w: usize,
|
||||
pub(crate) c_out: usize,
|
||||
pub(crate) c_in: usize,
|
||||
pub(crate) padding: usize,
|
||||
pub(crate) output_padding: usize,
|
||||
pub(crate) stride: usize,
|
||||
pub(crate) dilation: usize,
|
||||
}
|
||||
|
||||
impl ParamsConvTranspose2D {
|
||||
pub(crate) fn out_h(&self) -> usize {
|
||||
(self.i_h - 1) * self.stride + self.dilation * (self.k_h - 1) + self.output_padding + 1
|
||||
- 2 * self.padding
|
||||
}
|
||||
|
||||
pub(crate) fn out_w(&self) -> usize {
|
||||
(self.i_w - 1) * self.stride + self.dilation * (self.k_w - 1) + self.output_padding + 1
|
||||
- 2 * self.padding
|
||||
}
|
||||
|
||||
pub(crate) fn out_dims(&self) -> Vec<usize> {
|
||||
vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
|
||||
}
|
||||
}
|
||||
|
||||
impl Tensor {
|
||||
fn conv1d_single_group(&self, kernel: &Self, params: &ParamsConv1D) -> Result<Self> {
|
||||
let storage =
|
||||
self.storage()
|
||||
.conv1d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
|
||||
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D {
|
||||
arg,
|
||||
kernel,
|
||||
padding: params.padding,
|
||||
stride: params.stride,
|
||||
dilation: params.dilation,
|
||||
});
|
||||
let out_dims = params.out_dims();
|
||||
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
|
||||
}
|
||||
|
||||
/// Applies a 1D convolution over the input tensor.
|
||||
pub fn conv1d(
|
||||
&self,
|
||||
kernel: &Self,
|
||||
padding: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
groups: usize,
|
||||
) -> Result<Self> {
|
||||
let (c_out, c_in_k, k_size) = kernel.dims3()?;
|
||||
let (b_size, c_in, l_in) = self.dims3()?;
|
||||
if c_in != c_in_k * groups {
|
||||
Err(Error::Conv1dInvalidArgs {
|
||||
inp_shape: self.shape().clone(),
|
||||
k_shape: kernel.shape().clone(),
|
||||
padding,
|
||||
stride,
|
||||
msg: "the number of in-channels on the input doesn't match the kernel size",
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
|
||||
let params = ParamsConv1D {
|
||||
b_size,
|
||||
l_in,
|
||||
c_out: c_out / groups,
|
||||
c_in: c_in / groups,
|
||||
k_size,
|
||||
padding,
|
||||
stride,
|
||||
dilation,
|
||||
};
|
||||
if groups == 1 {
|
||||
self.conv1d_single_group(kernel, ¶ms)
|
||||
} else {
|
||||
let blocks = self.chunk(groups, 1)?;
|
||||
let kernel = kernel.chunk(groups, 0)?;
|
||||
let blocks = blocks
|
||||
.iter()
|
||||
.zip(&kernel)
|
||||
.map(|(block, kernel)| block.conv1d_single_group(kernel, ¶ms))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
Tensor::cat(&blocks, 1)
|
||||
}
|
||||
}
|
||||
|
||||
fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
|
||||
let storage =
|
||||
self.storage()
|
||||
.conv2d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
|
||||
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D {
|
||||
arg,
|
||||
kernel,
|
||||
padding: params.padding,
|
||||
stride: params.stride,
|
||||
dilation: params.dilation,
|
||||
});
|
||||
let out_dims = params.out_dims();
|
||||
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
|
||||
}
|
||||
|
||||
/// Applies a 2D convolution over the input tensor.
|
||||
pub fn conv2d(
|
||||
&self,
|
||||
kernel: &Self,
|
||||
padding: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
groups: usize,
|
||||
) -> Result<Self> {
|
||||
let (b_size, c_in, i_h, i_w) = self.dims4()?;
|
||||
let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
|
||||
if c_in != c_in_k * groups {
|
||||
crate::bail!(
|
||||
"in_channel mismatch between input ({c_in}, groups {groups}) and kernel ({c_in_k})"
|
||||
)
|
||||
}
|
||||
let params = ParamsConv2D {
|
||||
b_size,
|
||||
i_h,
|
||||
i_w,
|
||||
k_h,
|
||||
k_w,
|
||||
c_out: c_out / groups,
|
||||
c_in: c_in / groups,
|
||||
padding,
|
||||
stride,
|
||||
dilation,
|
||||
};
|
||||
if groups == 1 {
|
||||
self.conv2d_single_group(kernel, ¶ms)
|
||||
} else {
|
||||
let blocks = self.chunk(groups, 1)?;
|
||||
let kernel = kernel.chunk(groups, 0)?;
|
||||
let blocks = blocks
|
||||
.iter()
|
||||
.zip(&kernel)
|
||||
.map(|(block, kernel)| block.conv2d_single_group(kernel, ¶ms))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
Tensor::cat(&blocks, 1)
|
||||
}
|
||||
}
|
||||
|
||||
/// Applies a 2D transposed convolution over the input tensor.
|
||||
pub fn conv_transpose2d(
|
||||
&self,
|
||||
kernel: &Self,
|
||||
padding: usize,
|
||||
output_padding: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
) -> Result<Self> {
|
||||
let (b_size, c_in, i_h, i_w) = self.dims4()?;
|
||||
let (c_in_k, c_out, k_h, k_w) = kernel.dims4()?;
|
||||
if c_in != c_in_k {
|
||||
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
|
||||
}
|
||||
let params = ParamsConvTranspose2D {
|
||||
b_size,
|
||||
i_h,
|
||||
i_w,
|
||||
k_h,
|
||||
k_w,
|
||||
c_out,
|
||||
c_in,
|
||||
padding,
|
||||
output_padding,
|
||||
stride,
|
||||
dilation,
|
||||
};
|
||||
let storage = self.storage().conv_transpose2d(
|
||||
self.layout(),
|
||||
&kernel.storage(),
|
||||
kernel.layout(),
|
||||
¶ms,
|
||||
)?;
|
||||
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose2D {
|
||||
arg,
|
||||
kernel,
|
||||
padding: params.padding,
|
||||
output_padding: params.output_padding,
|
||||
stride: params.stride,
|
||||
dilation: params.dilation,
|
||||
});
|
||||
let out_dims = params.out_dims();
|
||||
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
|
||||
}
|
||||
}
|
||||
|
@ -92,7 +92,6 @@ from_tensor!(f64);
|
||||
from_tensor!(f32);
|
||||
from_tensor!(f16);
|
||||
from_tensor!(bf16);
|
||||
from_tensor!(i64);
|
||||
from_tensor!(u32);
|
||||
from_tensor!(u8);
|
||||
|
||||
@ -130,11 +129,6 @@ impl Tensor {
|
||||
f.write_u32::<LittleEndian>(v)?
|
||||
}
|
||||
}
|
||||
DType::I64 => {
|
||||
for v in vs.to_vec1::<i64>()? {
|
||||
f.write_i64::<LittleEndian>(v)?
|
||||
}
|
||||
}
|
||||
DType::U8 => {
|
||||
let vs = vs.to_vec1::<u8>()?;
|
||||
f.write_all(&vs)?;
|
||||
|
@ -103,7 +103,7 @@ impl CpuF16<ARR> for CurrentCpuF16 {
|
||||
for i in 0..8 {
|
||||
tmp[i] = (*mem_addr.add(i)).to_f32();
|
||||
}
|
||||
_mm256_loadu_ps(tmp.as_ptr())
|
||||
_mm_loadu_ps(tmp.as_ptr())
|
||||
}
|
||||
|
||||
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
|
||||
|
@ -1,7 +1,4 @@
|
||||
pub trait VecOps: num_traits::NumAssign + Copy {
|
||||
fn min(self, rhs: Self) -> Self;
|
||||
fn max(self, rhs: Self) -> Self;
|
||||
|
||||
/// Dot-product of two vectors.
|
||||
///
|
||||
/// # Safety
|
||||
@ -29,47 +26,9 @@ pub trait VecOps: num_traits::NumAssign + Copy {
|
||||
*res += *xs.add(i)
|
||||
}
|
||||
}
|
||||
|
||||
/// Maximum element in a non-empty vector.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// The length of `xs` must be at least `len` and positive. `res` has to point to a valid
|
||||
/// element.
|
||||
#[inline(always)]
|
||||
unsafe fn vec_reduce_max(xs: *const Self, res: *mut Self, len: usize) {
|
||||
*res = *xs;
|
||||
for i in 1..len {
|
||||
*res = (*res).max(*xs.add(i))
|
||||
}
|
||||
}
|
||||
|
||||
/// Minimum element in a non-empty vector.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// The length of `xs` must be at least `len` and positive. `res` has to point to a valid
|
||||
/// element.
|
||||
#[inline(always)]
|
||||
unsafe fn vec_reduce_min(xs: *const Self, res: *mut Self, len: usize) {
|
||||
*res = *xs;
|
||||
for i in 1..len {
|
||||
*res = (*res).min(*xs.add(i))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl VecOps for f32 {
|
||||
#[inline(always)]
|
||||
fn min(self, other: Self) -> Self {
|
||||
Self::min(self, other)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn max(self, other: Self) -> Self {
|
||||
Self::max(self, other)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
|
||||
super::vec_dot_f32(lhs, rhs, res, len)
|
||||
@ -82,16 +41,6 @@ impl VecOps for f32 {
|
||||
}
|
||||
|
||||
impl VecOps for half::f16 {
|
||||
#[inline(always)]
|
||||
fn min(self, other: Self) -> Self {
|
||||
Self::min(self, other)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn max(self, other: Self) -> Self {
|
||||
Self::max(self, other)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
|
||||
let mut res_f32 = 0f32;
|
||||
@ -100,61 +49,10 @@ impl VecOps for half::f16 {
|
||||
}
|
||||
}
|
||||
|
||||
impl VecOps for f64 {
|
||||
#[inline(always)]
|
||||
fn min(self, other: Self) -> Self {
|
||||
Self::min(self, other)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn max(self, other: Self) -> Self {
|
||||
Self::max(self, other)
|
||||
}
|
||||
}
|
||||
impl VecOps for half::bf16 {
|
||||
#[inline(always)]
|
||||
fn min(self, other: Self) -> Self {
|
||||
Self::min(self, other)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn max(self, other: Self) -> Self {
|
||||
Self::max(self, other)
|
||||
}
|
||||
}
|
||||
impl VecOps for u8 {
|
||||
#[inline(always)]
|
||||
fn min(self, other: Self) -> Self {
|
||||
<Self as Ord>::min(self, other)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn max(self, other: Self) -> Self {
|
||||
<Self as Ord>::max(self, other)
|
||||
}
|
||||
}
|
||||
impl VecOps for u32 {
|
||||
#[inline(always)]
|
||||
fn min(self, other: Self) -> Self {
|
||||
<Self as Ord>::min(self, other)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn max(self, other: Self) -> Self {
|
||||
<Self as Ord>::max(self, other)
|
||||
}
|
||||
}
|
||||
impl VecOps for i64 {
|
||||
#[inline(always)]
|
||||
fn min(self, other: Self) -> Self {
|
||||
<Self as Ord>::min(self, other)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn max(self, other: Self) -> Self {
|
||||
<Self as Ord>::max(self, other)
|
||||
}
|
||||
}
|
||||
impl VecOps for f64 {}
|
||||
impl VecOps for half::bf16 {}
|
||||
impl VecOps for u8 {}
|
||||
impl VecOps for u32 {}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {
|
||||
|
@ -2,7 +2,6 @@ use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||
use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
|
||||
use half::{bf16, f16};
|
||||
use rayon::prelude::*;
|
||||
|
||||
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
|
||||
// intercept the oom errors to avoid panicking and provide a proper error.
|
||||
@ -10,7 +9,6 @@ use rayon::prelude::*;
|
||||
pub enum CpuStorage {
|
||||
U8(Vec<u8>),
|
||||
U32(Vec<u32>),
|
||||
I64(Vec<i64>),
|
||||
BF16(Vec<bf16>),
|
||||
F16(Vec<f16>),
|
||||
F32(Vec<f32>),
|
||||
@ -27,7 +25,6 @@ pub trait Map1 {
|
||||
match vs {
|
||||
CpuStorage::U8(vs) => Ok(CpuStorage::U8(self.f(vs, layout)?)),
|
||||
CpuStorage::U32(vs) => Ok(CpuStorage::U32(self.f(vs, layout)?)),
|
||||
CpuStorage::I64(vs) => Ok(CpuStorage::I64(self.f(vs, layout)?)),
|
||||
CpuStorage::BF16(vs) => Ok(CpuStorage::BF16(self.f(vs, layout)?)),
|
||||
CpuStorage::F16(vs) => Ok(CpuStorage::F16(self.f(vs, layout)?)),
|
||||
CpuStorage::F32(vs) => Ok(CpuStorage::F32(self.f(vs, layout)?)),
|
||||
@ -48,7 +45,6 @@ pub trait Map1Any {
|
||||
match vs {
|
||||
CpuStorage::U8(vs) => Ok(self.f(vs, layout, CpuStorage::U8)?),
|
||||
CpuStorage::U32(vs) => Ok(self.f(vs, layout, CpuStorage::U32)?),
|
||||
CpuStorage::I64(vs) => Ok(self.f(vs, layout, CpuStorage::I64)?),
|
||||
CpuStorage::BF16(vs) => Ok(self.f(vs, layout, CpuStorage::BF16)?),
|
||||
CpuStorage::F16(vs) => Ok(self.f(vs, layout, CpuStorage::F16)?),
|
||||
CpuStorage::F32(vs) => Ok(self.f(vs, layout, CpuStorage::F32)?),
|
||||
@ -72,7 +68,6 @@ pub trait Map2 {
|
||||
match (v1, v2) {
|
||||
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||
(C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
|
||||
(C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)),
|
||||
(C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
|
||||
(C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
|
||||
(C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
|
||||
@ -101,7 +96,6 @@ pub trait Map2U8 {
|
||||
match (v1, v2) {
|
||||
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||
(C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||
(C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||
(C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||
(C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||
(C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||
@ -292,9 +286,10 @@ struct ReduceSum<'a> {
|
||||
|
||||
impl<'a> ReduceSum<'a> {
|
||||
#[inline(always)]
|
||||
fn fold_impl<T>(&self, src: &[T], src_l: &Layout, start_elt: T) -> Result<Vec<T>>
|
||||
fn fold_impl<T, F>(&self, src: &[T], src_l: &Layout, start_elt: T, f: F) -> Result<Vec<T>>
|
||||
where
|
||||
T: WithDType,
|
||||
F: Fn(T, T) -> T,
|
||||
{
|
||||
let mut dst = vec![start_elt; self.dst_shape.elem_count()];
|
||||
match src_l.contiguous_offsets() {
|
||||
@ -335,7 +330,7 @@ impl<'a> ReduceSum<'a> {
|
||||
let (pre, post) = (dst_index / stride, dst_index % stride);
|
||||
dst_index = (pre / dim) * stride + post;
|
||||
}
|
||||
dst[dst_index] += src;
|
||||
dst[dst_index] = f(dst[dst_index], src);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
@ -347,7 +342,7 @@ impl<'a> ReduceSum<'a> {
|
||||
let (pre, post) = (dst_index / stride, dst_index % stride);
|
||||
dst_index = (pre / dim) * stride + post;
|
||||
}
|
||||
dst[dst_index] += src[src_index];
|
||||
dst[dst_index] = f(dst[dst_index], src[src_index]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -358,7 +353,7 @@ impl<'a> ReduceSum<'a> {
|
||||
impl<'a> Map1 for ReduceSum<'a> {
|
||||
#[inline(always)]
|
||||
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
|
||||
self.fold_impl(src, src_l, T::zero())
|
||||
self.fold_impl(src, src_l, T::zero(), |x, y| x + y)
|
||||
}
|
||||
}
|
||||
|
||||
@ -446,7 +441,7 @@ pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U
|
||||
}
|
||||
|
||||
// This function maps over two strided index sequences.
|
||||
pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
|
||||
fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
|
||||
lhs_l: &Layout,
|
||||
rhs_l: &Layout,
|
||||
lhs: &[T],
|
||||
@ -526,7 +521,7 @@ pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
|
||||
}
|
||||
|
||||
// Similar to binary_map but with vectorized variants.
|
||||
pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
|
||||
fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
|
||||
lhs_l: &Layout,
|
||||
rhs_l: &Layout,
|
||||
lhs: &[T],
|
||||
@ -1053,8 +1048,10 @@ impl<'a> Map2 for Conv1D<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
let num_threads = crate::utils::get_num_threads();
|
||||
|
||||
for offset in 0..p.k_size {
|
||||
(0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
|
||||
crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
|
||||
let dst_idx = dst_c_idx * l_out;
|
||||
let k_cont = (0..p.c_in)
|
||||
.map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2])
|
||||
@ -1063,7 +1060,7 @@ impl<'a> Map2 for Conv1D<'a> {
|
||||
let dst_idx = dst_idx + b_idx * p.c_out * l_out;
|
||||
for dst_l in 0..l_out {
|
||||
let dst_idx = dst_idx + dst_l;
|
||||
let src_l = p.stride * dst_l + offset * p.dilation;
|
||||
let src_l = p.stride * dst_l + offset;
|
||||
if src_l < p.padding || src_l >= p.padding + p.l_in {
|
||||
continue;
|
||||
}
|
||||
@ -1122,9 +1119,11 @@ impl<'a> Map2 for Conv2D<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
let num_threads = crate::utils::get_num_threads();
|
||||
|
||||
for offset_h in 0..p.k_h {
|
||||
for offset_w in 0..p.k_w {
|
||||
(0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
|
||||
crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
|
||||
let dst_idx = dst_c_idx * out_w * out_h;
|
||||
let k_cont = (0..p.c_in)
|
||||
.map(|c_in_idx| {
|
||||
@ -1138,14 +1137,14 @@ impl<'a> Map2 for Conv2D<'a> {
|
||||
let dst_idx = dst_idx + b_idx * p.c_out * out_h * out_w;
|
||||
for dst_h in 0..out_h {
|
||||
let dst_idx = dst_idx + dst_h * out_w;
|
||||
let src_h = p.stride * dst_h + offset_h * p.dilation;
|
||||
let src_h = p.stride * dst_h + offset_h;
|
||||
if src_h < p.padding || src_h >= p.i_h + p.padding {
|
||||
continue;
|
||||
}
|
||||
let src_h = src_h - p.padding;
|
||||
for dst_w in 0..out_w {
|
||||
let dst_idx = dst_idx + dst_w;
|
||||
let src_w = p.stride * dst_w + offset_w * p.dilation;
|
||||
let src_w = p.stride * dst_w + offset_w;
|
||||
if src_w < p.padding || src_w >= p.i_w + p.padding {
|
||||
continue;
|
||||
}
|
||||
@ -1177,96 +1176,6 @@ impl<'a> Map2 for Conv2D<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
|
||||
|
||||
impl<'a> Map2 for ConvTranspose2D<'a> {
|
||||
const OP: &'static str = "conv_transpose2d";
|
||||
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
|
||||
let p = self.0;
|
||||
let inp = &inp[inp_l.start_offset()..];
|
||||
let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
|
||||
let k = &k[k_l.start_offset()..];
|
||||
let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?;
|
||||
let (out_h, out_w) = (p.out_h(), p.out_w());
|
||||
|
||||
// Output shape: [b_size, c_out, out_h, out_w].
|
||||
let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
|
||||
let dst_s0 = p.c_out * out_h * out_w;
|
||||
let dst_s1 = out_h * out_w;
|
||||
let dst_s2 = out_w;
|
||||
let dst_s3 = 1;
|
||||
|
||||
// TODO: Avoid making this copy if `inp` already has the appropriate layout.
|
||||
let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w];
|
||||
let cont_s0 = p.i_h * p.i_w * p.c_in;
|
||||
let cont_s1 = p.i_w * p.c_in;
|
||||
let cont_s2 = p.c_in;
|
||||
for b_idx in 0..p.b_size {
|
||||
for h_idx in 0..p.i_h {
|
||||
for w_idx in 0..p.i_w {
|
||||
for c_idx in 0..p.c_in {
|
||||
let src_idx =
|
||||
b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3;
|
||||
let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx;
|
||||
inp_cont[dst_idx] = inp[src_idx]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for k_y in 0..p.k_h {
|
||||
for k_x in 0..p.k_w {
|
||||
(0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
|
||||
let k_cont = (0..p.c_in)
|
||||
.map(|c_in_idx| {
|
||||
k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_y * k_s2 + k_x * k_s3]
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
for b_idx in 0..p.b_size {
|
||||
for inp_y in 0..p.i_h {
|
||||
for inp_x in 0..p.i_w {
|
||||
let out_x = inp_x * p.stride + k_x * p.dilation;
|
||||
let out_y = inp_y * p.stride + k_y * p.dilation;
|
||||
if out_x < p.padding || out_y < p.padding {
|
||||
continue;
|
||||
}
|
||||
let out_x = out_x - p.padding;
|
||||
let out_y = out_y - p.padding;
|
||||
if out_x < out_w && out_y < out_h {
|
||||
let inp_cont = &inp_cont
|
||||
[b_idx * cont_s0 + inp_y * cont_s1 + inp_x * cont_s2..];
|
||||
let dst_idx = b_idx * dst_s0
|
||||
+ out_y * dst_s2
|
||||
+ out_x * dst_s3
|
||||
+ dst_c_idx * dst_s1;
|
||||
let mut d = T::zero();
|
||||
unsafe {
|
||||
T::vec_dot(
|
||||
inp_cont.as_ptr(),
|
||||
k_cont.as_ptr(),
|
||||
&mut d,
|
||||
p.c_in,
|
||||
)
|
||||
}
|
||||
let dst_p = dst.as_ptr();
|
||||
// Safety: dst_idx are uniques per dst_c_idx which is used to
|
||||
// parallelise the different tasks so no two threads can try to
|
||||
// write at the same location.
|
||||
unsafe {
|
||||
let ptr = dst_p.add(dst_idx) as *mut T;
|
||||
*ptr += d
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
|
||||
struct MatMul((usize, usize, usize, usize));
|
||||
|
||||
impl MatMul {
|
||||
@ -1293,11 +1202,6 @@ impl Map2 for MatMul {
|
||||
rhs_l: &Layout,
|
||||
) -> Result<Vec<T>> {
|
||||
use gemm::{gemm, Parallelism};
|
||||
|
||||
if T::DTYPE == DType::BF16 {
|
||||
return Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?;
|
||||
}
|
||||
|
||||
let (b, m, n, k) = self.0;
|
||||
let lhs = &lhs[lhs_l.start_offset()..];
|
||||
let rhs = &rhs[rhs_l.start_offset()..];
|
||||
@ -1614,90 +1518,6 @@ impl CpuStorage {
|
||||
pub fn as_slice<D: WithDType>(&self) -> Result<&[D]> {
|
||||
D::cpu_storage_as_slice(self)
|
||||
}
|
||||
|
||||
pub fn concat(storages: &[CpuStorage]) -> Result<CpuStorage> {
|
||||
let storage0 = &storages[0];
|
||||
let s = match storage0 {
|
||||
Self::U8(_) => {
|
||||
let storages = storages
|
||||
.iter()
|
||||
.map(|s| match s {
|
||||
Self::U8(s) => Ok(s.as_slice()),
|
||||
_ => crate::bail!("dtype mismatch"),
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?
|
||||
.concat();
|
||||
Self::U8(storages)
|
||||
}
|
||||
Self::U32(_) => {
|
||||
let storages = storages
|
||||
.iter()
|
||||
.map(|s| match s {
|
||||
Self::U32(s) => Ok(s.as_slice()),
|
||||
_ => crate::bail!("dtype mismatch"),
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?
|
||||
.concat();
|
||||
Self::U32(storages)
|
||||
}
|
||||
Self::I64(_) => {
|
||||
let storages = storages
|
||||
.iter()
|
||||
.map(|s| match s {
|
||||
Self::I64(s) => Ok(s.as_slice()),
|
||||
_ => crate::bail!("dtype mismatch"),
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?
|
||||
.concat();
|
||||
Self::I64(storages)
|
||||
}
|
||||
Self::BF16(_) => {
|
||||
let storages = storages
|
||||
.iter()
|
||||
.map(|s| match s {
|
||||
Self::BF16(s) => Ok(s.as_slice()),
|
||||
_ => crate::bail!("dtype mismatch"),
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?
|
||||
.concat();
|
||||
Self::BF16(storages)
|
||||
}
|
||||
Self::F16(_) => {
|
||||
let storages = storages
|
||||
.iter()
|
||||
.map(|s| match s {
|
||||
Self::F16(s) => Ok(s.as_slice()),
|
||||
_ => crate::bail!("dtype mismatch"),
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?
|
||||
.concat();
|
||||
Self::F16(storages)
|
||||
}
|
||||
Self::F32(_) => {
|
||||
let storages = storages
|
||||
.iter()
|
||||
.map(|s| match s {
|
||||
Self::F32(s) => Ok(s.as_slice()),
|
||||
_ => crate::bail!("dtype mismatch"),
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?
|
||||
.concat();
|
||||
Self::F32(storages)
|
||||
}
|
||||
Self::F64(_) => {
|
||||
let storages = storages
|
||||
.iter()
|
||||
.map(|s| match s {
|
||||
Self::F64(s) => Ok(s.as_slice()),
|
||||
_ => crate::bail!("dtype mismatch"),
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?
|
||||
.concat();
|
||||
Self::F64(storages)
|
||||
}
|
||||
};
|
||||
Ok(s)
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendStorage for CpuStorage {
|
||||
@ -1707,7 +1527,6 @@ impl BackendStorage for CpuStorage {
|
||||
match self {
|
||||
Self::U8(_) => DType::U8,
|
||||
Self::U32(_) => DType::U32,
|
||||
Self::I64(_) => DType::I64,
|
||||
Self::BF16(_) => DType::BF16,
|
||||
Self::F16(_) => DType::F16,
|
||||
Self::F32(_) => DType::F32,
|
||||
@ -1726,10 +1545,6 @@ impl BackendStorage for CpuStorage {
|
||||
let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
|
||||
Ok(Self::BF16(data))
|
||||
}
|
||||
(Self::I64(storage), DType::BF16) => {
|
||||
let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
|
||||
Ok(Self::BF16(data))
|
||||
}
|
||||
(Self::BF16(storage), DType::BF16) => {
|
||||
let data = unary_map(storage, layout, |v| v);
|
||||
Ok(Self::BF16(data))
|
||||
@ -1754,10 +1569,6 @@ impl BackendStorage for CpuStorage {
|
||||
let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
|
||||
Ok(Self::F16(data))
|
||||
}
|
||||
(Self::I64(storage), DType::F16) => {
|
||||
let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
|
||||
Ok(Self::F16(data))
|
||||
}
|
||||
(Self::BF16(storage), DType::F16) => {
|
||||
let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32()));
|
||||
Ok(Self::F16(data))
|
||||
@ -1782,10 +1593,6 @@ impl BackendStorage for CpuStorage {
|
||||
let data = unary_map(storage, layout, |v| v as f32);
|
||||
Ok(Self::F32(data))
|
||||
}
|
||||
(Self::I64(storage), DType::F32) => {
|
||||
let data = unary_map(storage, layout, |v| v as f32);
|
||||
Ok(Self::F32(data))
|
||||
}
|
||||
(Self::BF16(storage), DType::F32) => {
|
||||
let data = unary_map(storage, layout, |v| v.to_f32());
|
||||
Ok(Self::F32(data))
|
||||
@ -1822,26 +1629,18 @@ impl BackendStorage for CpuStorage {
|
||||
let data = unary_map(storage, layout, |v| v as u8);
|
||||
Ok(Self::U8(data))
|
||||
}
|
||||
(Self::U32(storage), DType::U8) => {
|
||||
let data = unary_map(storage, layout, |v| v as u8);
|
||||
Ok(Self::U8(data))
|
||||
}
|
||||
(Self::I64(storage), DType::U8) => {
|
||||
let data = unary_map(storage, layout, |v| v as u8);
|
||||
Ok(Self::U8(data))
|
||||
}
|
||||
(Self::U8(storage), DType::U32) => {
|
||||
let data = unary_map(storage, layout, |v| v as u32);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
(Self::U32(storage), DType::U8) => {
|
||||
let data = unary_map(storage, layout, |v| v as u8);
|
||||
Ok(Self::U8(data))
|
||||
}
|
||||
(Self::U32(storage), DType::U32) => {
|
||||
let data = unary_map(storage, layout, |v| v);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
(Self::I64(storage), DType::U32) => {
|
||||
let data = unary_map(storage, layout, |v| v as u32);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
(Self::BF16(storage), DType::U32) => {
|
||||
let data = unary_map(storage, layout, |v| v.to_f32() as u32);
|
||||
Ok(Self::U32(data))
|
||||
@ -1858,34 +1657,6 @@ impl BackendStorage for CpuStorage {
|
||||
let data = unary_map(storage, layout, |v| v as u32);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
(Self::U8(storage), DType::I64) => {
|
||||
let data = unary_map(storage, layout, |v| v as i64);
|
||||
Ok(Self::I64(data))
|
||||
}
|
||||
(Self::U32(storage), DType::I64) => {
|
||||
let data = unary_map(storage, layout, |v| v as i64);
|
||||
Ok(Self::I64(data))
|
||||
}
|
||||
(Self::I64(storage), DType::I64) => {
|
||||
let data = unary_map(storage, layout, |v| v);
|
||||
Ok(Self::I64(data))
|
||||
}
|
||||
(Self::BF16(storage), DType::I64) => {
|
||||
let data = unary_map(storage, layout, |v| v.to_f32() as i64);
|
||||
Ok(Self::I64(data))
|
||||
}
|
||||
(Self::F16(storage), DType::I64) => {
|
||||
let data = unary_map(storage, layout, |v| v.to_f32() as i64);
|
||||
Ok(Self::I64(data))
|
||||
}
|
||||
(Self::F32(storage), DType::I64) => {
|
||||
let data = unary_map(storage, layout, |v| v as i64);
|
||||
Ok(Self::I64(data))
|
||||
}
|
||||
(Self::F64(storage), DType::I64) => {
|
||||
let data = unary_map(storage, layout, |v| v as i64);
|
||||
Ok(Self::I64(data))
|
||||
}
|
||||
(Self::U8(storage), DType::F64) => {
|
||||
let data = unary_map(storage, layout, |v| v as f64);
|
||||
Ok(Self::F64(data))
|
||||
@ -1894,10 +1665,6 @@ impl BackendStorage for CpuStorage {
|
||||
let data = unary_map(storage, layout, |v| v as f64);
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
(Self::I64(storage), DType::F64) => {
|
||||
let data = unary_map(storage, layout, |v| v as f64);
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
(Self::BF16(storage), DType::F64) => {
|
||||
let data = unary_map(storage, layout, |v| v.to_f64());
|
||||
Ok(Self::F64(data))
|
||||
@ -2003,32 +1770,6 @@ impl BackendStorage for CpuStorage {
|
||||
UpsampleNearest2D(h, w).map(self, layout)
|
||||
}
|
||||
|
||||
fn powf(&self, layout: &Layout, e: f64) -> Result<Self> {
|
||||
use num_traits::Float;
|
||||
// TODO: Have some generic map for functions that apply on num_traits::Float elements.
|
||||
match self {
|
||||
Self::BF16(storage) => {
|
||||
let data = unary_map(storage, layout, |v| v.powf(bf16::from_f64(e)));
|
||||
Ok(Self::BF16(data))
|
||||
}
|
||||
Self::F16(storage) => {
|
||||
let data = unary_map(storage, layout, |v| v.powf(f16::from_f64(e)));
|
||||
Ok(Self::F16(data))
|
||||
}
|
||||
Self::F32(storage) => {
|
||||
let data = unary_map(storage, layout, |v| v.powf(e as f32));
|
||||
Ok(Self::F32(data))
|
||||
}
|
||||
Self::F64(storage) => {
|
||||
let data = unary_map(storage, layout, |v| v.powf(e));
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
|
||||
Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
|
||||
Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
|
||||
}
|
||||
}
|
||||
|
||||
fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
|
||||
// TODO: Have some generic map for functions that apply on num_traits::Float elements.
|
||||
match self {
|
||||
@ -2050,7 +1791,6 @@ impl BackendStorage for CpuStorage {
|
||||
}
|
||||
Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
|
||||
Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
|
||||
Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
|
||||
}
|
||||
}
|
||||
|
||||
@ -2100,10 +1840,6 @@ impl BackendStorage for CpuStorage {
|
||||
let data = unary_map(storage, layout, B::u32);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
Self::I64(storage) => {
|
||||
let data = unary_map(storage, layout, B::i64);
|
||||
Ok(Self::I64(data))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -2154,14 +1890,6 @@ impl BackendStorage for CpuStorage {
|
||||
};
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
(Self::I64(lhs), Self::I64(rhs)) => {
|
||||
let data = if B::I64_VEC {
|
||||
binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i64, B::i64_vec)
|
||||
} else {
|
||||
binary_map(lhs_l, rhs_l, lhs, rhs, B::i64)
|
||||
};
|
||||
Ok(Self::I64(data))
|
||||
}
|
||||
(Self::U8(lhs), Self::U8(rhs)) => {
|
||||
let data = if B::U8_VEC {
|
||||
binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::u8, B::u8_vec)
|
||||
@ -2186,7 +1914,6 @@ impl BackendStorage for CpuStorage {
|
||||
match (self, dst) {
|
||||
(Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
|
||||
(Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
|
||||
(Self::I64(src), Self::I64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
|
||||
(Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
|
||||
(Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
|
||||
(Self::F32(src), Self::F32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
|
||||
@ -2215,7 +1942,6 @@ impl BackendStorage for CpuStorage {
|
||||
match self {
|
||||
Self::U8(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
|
||||
Self::U32(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
|
||||
Self::I64(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
|
||||
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "where-cond")),
|
||||
}
|
||||
}
|
||||
@ -2240,21 +1966,10 @@ impl BackendStorage for CpuStorage {
|
||||
Conv2D(params).map(self, l, kernel, kernel_l)
|
||||
}
|
||||
|
||||
fn conv_transpose2d(
|
||||
&self,
|
||||
l: &Layout,
|
||||
kernel: &Self,
|
||||
kernel_l: &Layout,
|
||||
params: &crate::conv::ParamsConvTranspose2D,
|
||||
) -> Result<Self> {
|
||||
ConvTranspose2D(params).map(self, l, kernel, kernel_l)
|
||||
}
|
||||
|
||||
fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||
match ids {
|
||||
Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
|
||||
Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
|
||||
Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
|
||||
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select")),
|
||||
}
|
||||
}
|
||||
@ -2263,7 +1978,6 @@ impl BackendStorage for CpuStorage {
|
||||
match ids {
|
||||
Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l),
|
||||
Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l),
|
||||
Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l),
|
||||
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather")),
|
||||
}
|
||||
}
|
||||
@ -2280,7 +1994,6 @@ impl BackendStorage for CpuStorage {
|
||||
match ids {
|
||||
Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
|
||||
Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
|
||||
Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
|
||||
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add")),
|
||||
}
|
||||
}
|
||||
@ -2309,13 +2022,6 @@ impl BackendStorage for CpuStorage {
|
||||
};
|
||||
IndexAdd { ids, dim }.map(self, l, src, src_l)
|
||||
}
|
||||
Self::I64(ids) => {
|
||||
let ids = match ids_l.contiguous_offsets() {
|
||||
Some((a, b)) => &ids[a..b],
|
||||
None => Err(Error::RequiresContiguous { op: "index-add" })?,
|
||||
};
|
||||
IndexAdd { ids, dim }.map(self, l, src, src_l)
|
||||
}
|
||||
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add")),
|
||||
}
|
||||
}
|
||||
@ -2368,9 +2074,7 @@ impl BackendDevice for CpuDevice {
|
||||
let elem_count = shape.elem_count();
|
||||
let mut rng = rand::thread_rng();
|
||||
match dtype {
|
||||
DType::U8 | DType::U32 | DType::I64 => {
|
||||
Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt())
|
||||
}
|
||||
DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt()),
|
||||
DType::BF16 => {
|
||||
let mut data = Vec::with_capacity(elem_count);
|
||||
let uniform =
|
||||
@ -2414,9 +2118,7 @@ impl BackendDevice for CpuDevice {
|
||||
let elem_count = shape.elem_count();
|
||||
let mut rng = rand::thread_rng();
|
||||
match dtype {
|
||||
DType::U8 | DType::U32 | DType::I64 => {
|
||||
Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt())
|
||||
}
|
||||
DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()),
|
||||
DType::BF16 => {
|
||||
let mut data = Vec::with_capacity(elem_count);
|
||||
let normal = rand_distr::Normal::new(bf16::from_f64(mean), bf16::from_f64(std))
|
||||
@ -2460,7 +2162,6 @@ impl BackendDevice for CpuDevice {
|
||||
let storage = match dtype {
|
||||
DType::U8 => CpuStorage::U8(vec![1u8; elem_count]),
|
||||
DType::U32 => CpuStorage::U32(vec![1u32; elem_count]),
|
||||
DType::I64 => CpuStorage::I64(vec![1i64; elem_count]),
|
||||
DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]),
|
||||
DType::F16 => CpuStorage::F16(vec![f16::ONE; elem_count]),
|
||||
DType::F32 => CpuStorage::F32(vec![1f32; elem_count]),
|
||||
@ -2474,7 +2175,6 @@ impl BackendDevice for CpuDevice {
|
||||
let storage = match dtype {
|
||||
DType::U8 => CpuStorage::U8(vec![0u8; elem_count]),
|
||||
DType::U32 => CpuStorage::U32(vec![0u32; elem_count]),
|
||||
DType::I64 => CpuStorage::I64(vec![0i64; elem_count]),
|
||||
DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]),
|
||||
DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]),
|
||||
DType::F32 => CpuStorage::F32(vec![0f32; elem_count]),
|
||||
|
@ -1,7 +1,7 @@
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||
use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType};
|
||||
pub use candle_kernels as kernels;
|
||||
use candle_kernels as kernels;
|
||||
pub use cudarc;
|
||||
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
|
||||
use cudarc::driver::{
|
||||
@ -139,14 +139,6 @@ impl CudaDevice {
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
DType::I64 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
|
||||
let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
|
||||
let params = (&data, v as i64, elem_count);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
CudaStorageSlice::I64(data)
|
||||
}
|
||||
DType::BF16 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
|
||||
@ -244,10 +236,6 @@ impl BackendDevice for CudaDevice {
|
||||
let data = self.alloc_zeros::<u32>(elem_count).w()?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
DType::I64 => {
|
||||
let data = self.alloc_zeros::<i64>(elem_count).w()?;
|
||||
CudaStorageSlice::I64(data)
|
||||
}
|
||||
DType::BF16 => {
|
||||
let data = self.alloc_zeros::<bf16>(elem_count).w()?;
|
||||
CudaStorageSlice::BF16(data)
|
||||
@ -277,13 +265,11 @@ impl BackendDevice for CudaDevice {
|
||||
let slice = match dtype {
|
||||
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
|
||||
// cudarc changes.
|
||||
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
|
||||
Err(CudaError::UnsupportedDtype {
|
||||
dtype,
|
||||
op: "rand_uniform",
|
||||
})
|
||||
.w()?
|
||||
}
|
||||
DType::U8 | DType::U32 | DType::F16 | DType::BF16 => Err(CudaError::UnsupportedDtype {
|
||||
dtype,
|
||||
op: "rand_uniform",
|
||||
})
|
||||
.w()?,
|
||||
DType::F32 => {
|
||||
let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
|
||||
curand.0.fill_with_uniform(&mut data).w()?;
|
||||
@ -295,12 +281,10 @@ impl BackendDevice for CudaDevice {
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
};
|
||||
let slice = if lo == 0. && up == 1.0 {
|
||||
slice
|
||||
} else {
|
||||
if lo != 0.0 || up != 1.0 {
|
||||
let layout = Layout::contiguous(shape);
|
||||
Affine(up - lo, lo).map(&slice, self, &layout)?
|
||||
};
|
||||
Affine(up - lo, lo).map(&slice, self, &layout)?;
|
||||
}
|
||||
Ok(CudaStorage {
|
||||
slice,
|
||||
device: self.clone(),
|
||||
@ -313,13 +297,11 @@ impl BackendDevice for CudaDevice {
|
||||
let elem_count = shape.elem_count();
|
||||
let curand = self.curand.lock().unwrap();
|
||||
let slice = match dtype {
|
||||
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
|
||||
Err(CudaError::UnsupportedDtype {
|
||||
dtype,
|
||||
op: "rand_normal",
|
||||
})
|
||||
.w()?
|
||||
}
|
||||
DType::U8 | DType::U32 | DType::F16 | DType::BF16 => Err(CudaError::UnsupportedDtype {
|
||||
dtype,
|
||||
op: "rand_normal",
|
||||
})
|
||||
.w()?,
|
||||
DType::F32 => {
|
||||
let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
|
||||
curand
|
||||
@ -354,10 +336,6 @@ impl BackendDevice for CudaDevice {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
CpuStorage::I64(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
CudaStorageSlice::I64(data)
|
||||
}
|
||||
CpuStorage::BF16(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
CudaStorageSlice::BF16(data)
|
||||
@ -383,10 +361,9 @@ impl BackendDevice for CudaDevice {
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum CudaStorageSlice {
|
||||
enum CudaStorageSlice {
|
||||
U8(CudaSlice<u8>),
|
||||
U32(CudaSlice<u32>),
|
||||
I64(CudaSlice<i64>),
|
||||
BF16(CudaSlice<bf16>),
|
||||
F16(CudaSlice<f16>),
|
||||
F32(CudaSlice<f32>),
|
||||
@ -394,7 +371,7 @@ pub enum CudaStorageSlice {
|
||||
}
|
||||
type S = CudaStorageSlice;
|
||||
|
||||
pub trait Map1 {
|
||||
trait Map1 {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
src: &CudaSlice<T>,
|
||||
@ -406,7 +383,6 @@ pub trait Map1 {
|
||||
let out = match s {
|
||||
S::U8(s) => S::U8(self.f(s, d, l)?),
|
||||
S::U32(s) => S::U32(self.f(s, d, l)?),
|
||||
S::I64(s) => S::I64(self.f(s, d, l)?),
|
||||
S::BF16(s) => S::BF16(self.f(s, d, l)?),
|
||||
S::F16(s) => S::F16(self.f(s, d, l)?),
|
||||
S::F32(s) => S::F32(self.f(s, d, l)?),
|
||||
@ -416,7 +392,7 @@ pub trait Map1 {
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Map2 {
|
||||
trait Map2 {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
src1: &CudaSlice<T>,
|
||||
@ -430,7 +406,6 @@ pub trait Map2 {
|
||||
let out = match (s1, s2) {
|
||||
(S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?),
|
||||
(S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
|
||||
(S::I64(s1), S::I64(s2)) => S::I64(self.f(s1, l1, s2, l2, d)?),
|
||||
(S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
|
||||
(S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
|
||||
(S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
|
||||
@ -441,7 +416,7 @@ pub trait Map2 {
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Map2InPlace {
|
||||
trait Map2InPlace {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
dst: &mut CudaSlice<T>,
|
||||
@ -462,7 +437,6 @@ pub trait Map2InPlace {
|
||||
match (dst, src) {
|
||||
(S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
@ -472,7 +446,7 @@ pub trait Map2InPlace {
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Map1Any {
|
||||
trait Map1Any {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
|
||||
&self,
|
||||
src: &CudaSlice<T>,
|
||||
@ -485,7 +459,6 @@ pub trait Map1Any {
|
||||
let out = match s {
|
||||
S::U8(s) => self.f(s, d, l, S::U8)?,
|
||||
S::U32(s) => self.f(s, d, l, S::U32)?,
|
||||
S::I64(s) => self.f(s, d, l, S::I64)?,
|
||||
S::BF16(s) => self.f(s, d, l, S::BF16)?,
|
||||
S::F16(s) => self.f(s, d, l, S::F16)?,
|
||||
S::F32(s) => self.f(s, d, l, S::F32)?,
|
||||
@ -495,7 +468,7 @@ pub trait Map1Any {
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Map2Any {
|
||||
trait Map2Any {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
src1: &CudaSlice<T>,
|
||||
@ -509,7 +482,6 @@ pub trait Map2Any {
|
||||
let out = match (s1, s2) {
|
||||
(S::U8(s1), S::U8(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||
(S::U32(s1), S::U32(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||
(S::I64(s1), S::I64(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||
(S::BF16(s1), S::BF16(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||
(S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||
(S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||
@ -532,7 +504,7 @@ impl Map1 for Clone {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn kernel_name<T: WithDType>(root: &str) -> String {
|
||||
fn kernel_name<T: WithDType>(root: &str) -> String {
|
||||
let dtype = T::DTYPE.as_str();
|
||||
format!("{root}_{dtype}")
|
||||
}
|
||||
@ -593,30 +565,6 @@ impl Map1 for Elu {
|
||||
}
|
||||
}
|
||||
|
||||
struct Powf(f64);
|
||||
impl Map1 for Powf {
|
||||
fn f<T: DeviceRepr + WithDType>(
|
||||
&self,
|
||||
src: &CudaSlice<T>,
|
||||
dev: &CudaDevice,
|
||||
layout: &Layout,
|
||||
) -> Result<CudaSlice<T>> {
|
||||
let shape = layout.shape();
|
||||
let dims = shape.dims();
|
||||
let el = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||
let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
|
||||
let src = &src.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("upowf"), kernels::UNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||
let params = (el, dims.len(), &ds, T::from_f64(self.0), src, &out);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
struct Sum<'a>(&'a [usize]);
|
||||
impl<'a> Map1 for Sum<'a> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
@ -766,9 +714,6 @@ impl<'a> Map1 for IndexSelect<'a> {
|
||||
CudaStorageSlice::U8(slice) => {
|
||||
("is_u8", *slice.slice(ids_l.start_offset()..).device_ptr())
|
||||
}
|
||||
CudaStorageSlice::I64(slice) => {
|
||||
("is_i64", *slice.slice(ids_l.start_offset()..).device_ptr())
|
||||
}
|
||||
_ => Err(CudaError::UnexpectedDType {
|
||||
msg: "index_select ids should be u8 or u32",
|
||||
expected: DType::U32,
|
||||
@ -828,11 +773,8 @@ impl<'a> Map1 for Gather<'a> {
|
||||
("gather_u32", *slice.slice(ids_o1..ids_o2).device_ptr())
|
||||
}
|
||||
CudaStorageSlice::U8(slice) => ("gather_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
|
||||
CudaStorageSlice::I64(slice) => {
|
||||
("gather_i64", *slice.slice(ids_o1..ids_o2).device_ptr())
|
||||
}
|
||||
_ => Err(CudaError::UnexpectedDType {
|
||||
msg: "gather ids should be u8/u32/i64",
|
||||
msg: "gather ids should be u8 or u32",
|
||||
expected: DType::U32,
|
||||
got: ids.dtype(),
|
||||
})?,
|
||||
@ -878,10 +820,9 @@ impl<'a> Map2InPlace for IndexAdd<'a> {
|
||||
};
|
||||
let (name, ids) = match &ids.slice {
|
||||
CudaStorageSlice::U32(slice) => ("ia_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
|
||||
CudaStorageSlice::I64(slice) => ("ia_i64", *slice.slice(ids_o1..ids_o2).device_ptr()),
|
||||
CudaStorageSlice::U8(slice) => ("ia_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
|
||||
_ => Err(CudaError::UnexpectedDType {
|
||||
msg: "index-add ids should be u8/u32/i64",
|
||||
msg: "index-add ids should be u8 or u32",
|
||||
expected: DType::U32,
|
||||
got: ids.dtype(),
|
||||
})?,
|
||||
@ -926,10 +867,9 @@ impl<'a> Map2InPlace for ScatterAdd<'a> {
|
||||
};
|
||||
let (name, ids) = match &ids.slice {
|
||||
CudaStorageSlice::U32(slice) => ("sa_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
|
||||
CudaStorageSlice::I64(slice) => ("sa_i64", *slice.slice(ids_o1..ids_o2).device_ptr()),
|
||||
CudaStorageSlice::U8(slice) => ("sa_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
|
||||
_ => Err(CudaError::UnexpectedDType {
|
||||
msg: "scatter-add ids should be u8/u32/i64",
|
||||
msg: "scatter-add ids should be u8 or u32",
|
||||
expected: DType::U32,
|
||||
got: ids.dtype(),
|
||||
})?,
|
||||
@ -981,12 +921,10 @@ impl<'a> Map2 for Conv1D<'a> {
|
||||
} else if dims.len() == 2 {
|
||||
[&[1], dims, &[1], inp_l.stride(), k_l.dims(), k_l.stride()].concat()
|
||||
} else {
|
||||
crate::bail!("unexpected input shape for conv1d {dims:?}")
|
||||
panic!("unexpected input shape for conv1d {dims:?}")
|
||||
};
|
||||
let ds = dev.htod_copy(ds).w()?;
|
||||
let params = (
|
||||
el, l_out, p.stride, p.padding, p.dilation, &ds, inp, k, &out,
|
||||
);
|
||||
let params = (el, l_out, p.stride, p.padding, &ds, inp, k, &out);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(out)
|
||||
@ -1003,8 +941,8 @@ impl<'a> Map2 for Conv2D<'a> {
|
||||
k_l: &Layout,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaSlice<T>> {
|
||||
// Kernel shape: (c_out, c_in_k, h_k, w_k)
|
||||
// Input shape: (b_size, c_in, h_in, w_in)
|
||||
// Kernel shape: (c_out, c_in_k, w_k, h_k)
|
||||
// Input shape: (b_size, c_in, w_in, c_in)
|
||||
let p = &self.0;
|
||||
let (out_w, out_h) = (p.out_w(), p.out_h());
|
||||
let dst_el = p.c_out * out_w * out_h * p.b_size;
|
||||
@ -1021,62 +959,10 @@ impl<'a> Map2 for Conv2D<'a> {
|
||||
let ds = if dims.len() == 4 {
|
||||
[dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
|
||||
} else {
|
||||
crate::bail!("unexpected input shape for conv2d {dims:?}")
|
||||
panic!("unexpected input shape for conv1d {dims:?}")
|
||||
};
|
||||
let ds = dev.htod_copy(ds).w()?;
|
||||
let params = (
|
||||
el, out_w, out_h, p.stride, p.padding, p.dilation, &ds, inp, k, &out,
|
||||
);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
|
||||
impl<'a> Map2 for ConvTranspose2D<'a> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
inp: &CudaSlice<T>,
|
||||
inp_l: &Layout,
|
||||
k: &CudaSlice<T>,
|
||||
k_l: &Layout,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaSlice<T>> {
|
||||
// Kernel shape: (c_in_k, c_out, h_k, w_k)
|
||||
// Input shape: (b_size, c_in, h_in, w_in)
|
||||
let p = &self.0;
|
||||
let (out_w, out_h) = (p.out_w(), p.out_h());
|
||||
let dst_el = p.c_out * out_w * out_h * p.b_size;
|
||||
let inp = &inp.slice(inp_l.start_offset()..);
|
||||
let k = &k.slice(k_l.start_offset()..);
|
||||
let shape = inp_l.shape();
|
||||
let dims = shape.dims();
|
||||
let el = shape.elem_count();
|
||||
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose2d"), kernels::CONV)?;
|
||||
let ds = if dims.len() == 4 {
|
||||
[dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
|
||||
} else {
|
||||
crate::bail!("unexpected input shape for conv_transpose2d {dims:?}")
|
||||
};
|
||||
let ds = dev.htod_copy(ds).w()?;
|
||||
let params = (
|
||||
el,
|
||||
out_w,
|
||||
out_h,
|
||||
p.stride,
|
||||
p.padding,
|
||||
p.output_padding,
|
||||
p.dilation,
|
||||
&ds,
|
||||
inp,
|
||||
k,
|
||||
&out,
|
||||
);
|
||||
let params = (el, out_w, out_h, p.stride, p.padding, &ds, inp, k, &out);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(out)
|
||||
@ -1110,7 +996,7 @@ impl Map1 for Pool2D {
|
||||
let ds = if dims.len() == 4 {
|
||||
[dims, inp_l.stride()].concat()
|
||||
} else {
|
||||
crate::bail!("unexpected input shape for pool {dims:?}")
|
||||
panic!("unexpected input shape for conv1d {dims:?}")
|
||||
};
|
||||
let el = shape.elem_count();
|
||||
let out_w = (dims[2] - self.w_k) / self.w_stride + 1;
|
||||
@ -1156,7 +1042,7 @@ impl Map1 for UpsampleNearest2D {
|
||||
let ds = if dims.len() == 4 {
|
||||
[dims, inp_l.stride()].concat()
|
||||
} else {
|
||||
crate::bail!("unexpected input shape for upsample {dims:?}")
|
||||
panic!("unexpected input shape for conv1d {dims:?}")
|
||||
};
|
||||
let (out_w, out_h) = (self.0, self.1);
|
||||
let dst_el = out_w * out_h * dims[0] * dims[1];
|
||||
@ -1194,12 +1080,8 @@ impl<'a> Map2 for WhereCond<'a> {
|
||||
let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
|
||||
(ptr, "where_u32")
|
||||
}
|
||||
CudaStorageSlice::I64(slice) => {
|
||||
let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
|
||||
(ptr, "where_i64")
|
||||
}
|
||||
_ => Err(CudaError::UnexpectedDType {
|
||||
msg: "where conditions should be u8/u32/i64",
|
||||
msg: "where conditions should be u8 or u32",
|
||||
expected: DType::U32,
|
||||
got: self.0.dtype(),
|
||||
})
|
||||
@ -1310,8 +1192,8 @@ fn slice_src_and_dst<'a, T>(
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CudaStorage {
|
||||
pub slice: CudaStorageSlice,
|
||||
pub device: CudaDevice,
|
||||
slice: CudaStorageSlice,
|
||||
device: CudaDevice,
|
||||
}
|
||||
|
||||
pub trait CudaDType: Sized {
|
||||
@ -1343,7 +1225,6 @@ macro_rules! cuda_dtype {
|
||||
}
|
||||
cuda_dtype!(u8, U8);
|
||||
cuda_dtype!(u32, U32);
|
||||
cuda_dtype!(i64, I64);
|
||||
cuda_dtype!(f16, F16);
|
||||
cuda_dtype!(bf16, BF16);
|
||||
cuda_dtype!(f32, F32);
|
||||
@ -1457,7 +1338,6 @@ impl BackendStorage for CudaStorage {
|
||||
match self.slice {
|
||||
CudaStorageSlice::U8(_) => DType::U8,
|
||||
CudaStorageSlice::U32(_) => DType::U32,
|
||||
CudaStorageSlice::I64(_) => DType::I64,
|
||||
CudaStorageSlice::BF16(_) => DType::BF16,
|
||||
CudaStorageSlice::F16(_) => DType::F16,
|
||||
CudaStorageSlice::F32(_) => DType::F32,
|
||||
@ -1483,7 +1363,6 @@ impl BackendStorage for CudaStorage {
|
||||
let inp = match &self.slice {
|
||||
CudaStorageSlice::U8(inp) => *inp.slice(start_o..).device_ptr(),
|
||||
CudaStorageSlice::U32(inp) => *inp.slice(start_o..).device_ptr(),
|
||||
CudaStorageSlice::I64(inp) => *inp.slice(start_o..).device_ptr(),
|
||||
CudaStorageSlice::BF16(inp) => *inp.slice(start_o..).device_ptr(),
|
||||
CudaStorageSlice::F16(inp) => *inp.slice(start_o..).device_ptr(),
|
||||
CudaStorageSlice::F32(inp) => *inp.slice(start_o..).device_ptr(),
|
||||
@ -1506,12 +1385,6 @@ impl BackendStorage for CudaStorage {
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
CudaStorageSlice::U32(out)
|
||||
}
|
||||
DType::I64 => {
|
||||
let out = unsafe { dev.alloc::<i64>(el) }.w()?;
|
||||
let params = (el, dims.len(), &ds, *inp, &out);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
CudaStorageSlice::I64(out)
|
||||
}
|
||||
DType::BF16 => {
|
||||
let out = unsafe { dev.alloc::<bf16>(el) }.w()?;
|
||||
let params = (el, dims.len(), &ds, *inp, &out);
|
||||
@ -1549,12 +1422,6 @@ impl BackendStorage for CudaStorage {
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
fn powf(&self, layout: &Layout, e: f64) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
let slice = Powf(e).map(&self.slice, &device, layout)?;
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
let slice = Elu(alpha).map(&self.slice, &device, layout)?;
|
||||
@ -1602,11 +1469,6 @@ impl BackendStorage for CudaStorage {
|
||||
let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
|
||||
Ok(CpuStorage::U32(cpu_storage))
|
||||
}
|
||||
CudaStorageSlice::I64(slice) => {
|
||||
let dev = slice.device();
|
||||
let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
|
||||
Ok(CpuStorage::I64(cpu_storage))
|
||||
}
|
||||
CudaStorageSlice::BF16(slice) => {
|
||||
let dev = slice.device();
|
||||
let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
|
||||
@ -1708,6 +1570,7 @@ impl BackendStorage for CudaStorage {
|
||||
.map_err(crate::Error::wrap)?;
|
||||
S::F16(out)
|
||||
}
|
||||
|
||||
(S::F32(inp), S::F32(k)) => {
|
||||
let inp = &inp.slice(inp_l.start_offset()..);
|
||||
let k = &k.slice(kernel_l.start_offset()..);
|
||||
@ -1725,25 +1588,11 @@ impl BackendStorage for CudaStorage {
|
||||
S::F64(out)
|
||||
}
|
||||
(S::U32(_), S::U32(_)) => Err(CudaError::InternalError("conv2d does not support u32"))?,
|
||||
(S::I64(_), S::I64(_)) => Err(CudaError::InternalError("conv2d does not support i64"))?,
|
||||
_ => Err(CudaError::InternalError("dtype mismatch in conv2d"))?,
|
||||
};
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
fn conv_transpose2d(
|
||||
&self,
|
||||
l: &Layout,
|
||||
kernel: &Self,
|
||||
kernel_l: &Layout,
|
||||
params: &crate::conv::ParamsConvTranspose2D,
|
||||
) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
let slice =
|
||||
ConvTranspose2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
fn avg_pool2d(&self, l: &Layout, k: (usize, usize), stride: (usize, usize)) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
let slice = Pool2D {
|
||||
@ -1889,9 +1738,6 @@ impl BackendStorage for CudaStorage {
|
||||
let src_shape = src_l.shape();
|
||||
let dims = src_shape.dims();
|
||||
let el_count = src_shape.elem_count();
|
||||
if el_count == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
||||
let dev = &self.device;
|
||||
let ds = dev.htod_copy([dims, src_l.stride()].concat()).w()?;
|
||||
@ -1956,18 +1802,6 @@ impl BackendStorage for CudaStorage {
|
||||
unsafe { func.launch(cfg, params) }.w()?
|
||||
}
|
||||
}
|
||||
(CudaStorageSlice::I64(src), CudaStorageSlice::I64(dst)) => {
|
||||
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
||||
if src_l.is_contiguous() {
|
||||
dev.dtod_copy(&src, &mut dst).w()?
|
||||
} else {
|
||||
let func = dev.get_or_load_func("ucopy_i64", kernels::UNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let params = (el_count, dims.len(), &ds, &src, &mut dst);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?
|
||||
}
|
||||
}
|
||||
(CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
|
||||
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
||||
if src_l.is_contiguous() {
|
||||
|
@ -48,14 +48,14 @@ pub(crate) fn launch_conv2d<
|
||||
let conv = cudnn.create_conv2d::<T>(
|
||||
/* pad */ [params.padding as i32, params.padding as i32],
|
||||
/* stride */ [params.stride as i32, params.stride as i32],
|
||||
/* dilation */ [params.dilation as i32, params.dilation as i32],
|
||||
/* dilation */ [1, 1],
|
||||
cudarc::cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
|
||||
)?;
|
||||
let x_shape = [
|
||||
params.b_size as i32,
|
||||
params.c_in as i32,
|
||||
params.i_h as i32,
|
||||
params.i_w as i32,
|
||||
params.i_h as i32,
|
||||
];
|
||||
// Note that `src` already starts at the proper offset.
|
||||
let x = if src_l.is_contiguous() {
|
||||
@ -75,14 +75,14 @@ pub(crate) fn launch_conv2d<
|
||||
[
|
||||
params.c_out as i32,
|
||||
params.c_in as i32,
|
||||
params.k_h as i32,
|
||||
params.k_w as i32,
|
||||
params.k_h as i32,
|
||||
],
|
||||
)?;
|
||||
let (w_out, h_out) = (params.out_w() as i32, params.out_h() as i32);
|
||||
let y = cudnn.create_4d_tensor(
|
||||
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
|
||||
[params.b_size as i32, params.c_out as i32, h_out, w_out],
|
||||
[params.b_size as i32, params.c_out as i32, w_out, h_out],
|
||||
)?;
|
||||
let conv2d = Conv2dForward {
|
||||
conv: &conv,
|
||||
|
@ -16,6 +16,7 @@ pub enum Device {
|
||||
Cuda(crate::CudaDevice),
|
||||
}
|
||||
|
||||
// TODO: Should we back the cpu implementation using the NdArray crate or similar?
|
||||
pub trait NdArray {
|
||||
fn shape(&self) -> Result<Shape>;
|
||||
|
||||
@ -80,49 +81,6 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4: usize> NdArray
|
||||
for &[[[[S; N4]; N3]; N2]; N1]
|
||||
{
|
||||
fn shape(&self) -> Result<Shape> {
|
||||
Ok(Shape::from((N1, N2, N3, N4)))
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> CpuStorage {
|
||||
let mut vec = Vec::with_capacity(N1 * N2 * N3 * N4);
|
||||
for i1 in 0..N1 {
|
||||
for i2 in 0..N2 {
|
||||
for i3 in 0..N3 {
|
||||
vec.extend(self[i1][i2][i3])
|
||||
}
|
||||
}
|
||||
}
|
||||
S::to_cpu_storage_owned(vec)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: NdArray> NdArray for Vec<S> {
|
||||
fn shape(&self) -> Result<Shape> {
|
||||
if self.is_empty() {
|
||||
crate::bail!("empty array")
|
||||
}
|
||||
let shape0 = self[0].shape()?;
|
||||
let n = self.len();
|
||||
for v in self.iter() {
|
||||
let shape = v.shape()?;
|
||||
if shape != shape0 {
|
||||
crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
|
||||
}
|
||||
}
|
||||
Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> CpuStorage {
|
||||
// This allocates intermediary memory and shouldn't be necessary.
|
||||
let storages = self.iter().map(|v| v.to_cpu_storage()).collect::<Vec<_>>();
|
||||
CpuStorage::concat(storages.as_slice()).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl Device {
|
||||
pub fn new_cuda(ordinal: usize) -> Result<Self> {
|
||||
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
|
||||
|
@ -9,14 +9,11 @@ impl Tensor {
|
||||
&self,
|
||||
f: &mut std::fmt::Formatter,
|
||||
) -> std::fmt::Result {
|
||||
let device_str = match self.device().location() {
|
||||
crate::DeviceLocation::Cpu => "".to_owned(),
|
||||
crate::DeviceLocation::Cuda { gpu_id } => {
|
||||
format!(", cuda:{}", gpu_id)
|
||||
}
|
||||
let prefix = match self.device() {
|
||||
crate::Device::Cpu => "Cpu",
|
||||
crate::Device::Cuda(_) => "Cuda",
|
||||
};
|
||||
|
||||
write!(f, "Tensor[")?;
|
||||
write!(f, "{prefix}Tensor[")?;
|
||||
match self.dims() {
|
||||
[] => {
|
||||
if let Ok(v) = self.to_scalar::<T>() {
|
||||
@ -43,7 +40,7 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
}
|
||||
write!(f, "; {}{}]", self.dtype().as_str(), device_str)
|
||||
write!(f, "; {}]", self.dtype().as_str())
|
||||
}
|
||||
}
|
||||
|
||||
@ -52,7 +49,6 @@ impl std::fmt::Debug for Tensor {
|
||||
match self.dtype() {
|
||||
DType::U8 => self.fmt_dt::<u8>(f),
|
||||
DType::U32 => self.fmt_dt::<u32>(f),
|
||||
DType::I64 => self.fmt_dt::<i64>(f),
|
||||
DType::BF16 => self.fmt_dt::<bf16>(f),
|
||||
DType::F16 => self.fmt_dt::<f16>(f),
|
||||
DType::F32 => self.fmt_dt::<f32>(f),
|
||||
@ -435,12 +431,6 @@ impl std::fmt::Display for Tensor {
|
||||
tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
|
||||
writeln!(f)?;
|
||||
}
|
||||
DType::I64 => {
|
||||
let tf: IntFormatter<i64> = IntFormatter::new();
|
||||
let max_w = tf.max_width(&to_display);
|
||||
tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
|
||||
writeln!(f)?;
|
||||
}
|
||||
DType::BF16 => {
|
||||
if let Ok(tf) = FloatFormatter::<bf16>::new(&to_display, &po) {
|
||||
let max_w = tf.max_width(&to_display);
|
||||
@ -470,20 +460,6 @@ impl std::fmt::Display for Tensor {
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let device_str = match self.device().location() {
|
||||
crate::DeviceLocation::Cpu => "".to_owned(),
|
||||
crate::DeviceLocation::Cuda { gpu_id } => {
|
||||
format!(", cuda:{}", gpu_id)
|
||||
}
|
||||
};
|
||||
|
||||
write!(
|
||||
f,
|
||||
"Tensor[{:?}, {}{}]",
|
||||
self.dims(),
|
||||
self.dtype().as_str(),
|
||||
device_str
|
||||
)
|
||||
write!(f, "Tensor[{:?}, {}]", self.dims(), self.dtype().as_str())
|
||||
}
|
||||
}
|
||||
|
@ -1,24 +1,13 @@
|
||||
//! Types for elements that can be stored and manipulated using tensors.
|
||||
#![allow(clippy::redundant_closure_call)]
|
||||
use crate::backend::BackendStorage;
|
||||
use crate::{CpuStorage, Error, Result};
|
||||
|
||||
/// The different types of elements allowed in tensors.
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum DType {
|
||||
// Unsigned 8 bits integer.
|
||||
U8,
|
||||
// Unsigned 32 bits integer.
|
||||
U32,
|
||||
// Signed 64 bits integer.
|
||||
I64,
|
||||
// Brain floating-point using half precision (16 bits).
|
||||
BF16,
|
||||
// Floating-point using half precision (16 bits).
|
||||
F16,
|
||||
// Floating-point using single precision (32 bits).
|
||||
F32,
|
||||
// Floating-point using double precision (64 bits).
|
||||
F64,
|
||||
}
|
||||
|
||||
@ -31,7 +20,6 @@ impl std::str::FromStr for DType {
|
||||
match s {
|
||||
"u8" => Ok(Self::U8),
|
||||
"u32" => Ok(Self::U32),
|
||||
"i64" => Ok(Self::I64),
|
||||
"bf16" => Ok(Self::BF16),
|
||||
"f16" => Ok(Self::F16),
|
||||
"f32" => Ok(Self::F32),
|
||||
@ -42,12 +30,10 @@ impl std::str::FromStr for DType {
|
||||
}
|
||||
|
||||
impl DType {
|
||||
/// String representation for dtypes.
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Self::U8 => "u8",
|
||||
Self::U32 => "u32",
|
||||
Self::I64 => "i64",
|
||||
Self::BF16 => "bf16",
|
||||
Self::F16 => "f16",
|
||||
Self::F32 => "f32",
|
||||
@ -55,12 +41,10 @@ impl DType {
|
||||
}
|
||||
}
|
||||
|
||||
/// The size used by each element in bytes, i.e. 1 for `U8`, 4 for `F32`.
|
||||
pub fn size_in_bytes(&self) -> usize {
|
||||
match self {
|
||||
Self::U8 => 1,
|
||||
Self::U32 => 4,
|
||||
Self::I64 => 8,
|
||||
Self::BF16 => 2,
|
||||
Self::F16 => 2,
|
||||
Self::F32 => 4,
|
||||
@ -141,7 +125,6 @@ use half::{bf16, f16};
|
||||
|
||||
with_dtype!(u8, U8, |v: f64| v as u8, |v: u8| v as f64);
|
||||
with_dtype!(u32, U32, |v: f64| v as u32, |v: u32| v as f64);
|
||||
with_dtype!(i64, I64, |v: f64| v as i64, |v: i64| v as f64);
|
||||
with_dtype!(f16, F16, f16::from_f64, f16::to_f64);
|
||||
with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64);
|
||||
with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64);
|
||||
@ -152,15 +135,6 @@ pub trait IntDType: WithDType {
|
||||
fn as_usize(&self) -> usize;
|
||||
}
|
||||
|
||||
impl IntDType for i64 {
|
||||
fn is_true(&self) -> bool {
|
||||
*self != 0
|
||||
}
|
||||
fn as_usize(&self) -> usize {
|
||||
*self as usize
|
||||
}
|
||||
}
|
||||
|
||||
impl IntDType for u32 {
|
||||
fn is_true(&self) -> bool {
|
||||
*self != 0
|
||||
|
@ -37,10 +37,6 @@ impl crate::backend::BackendStorage for CudaStorage {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
@ -89,16 +85,6 @@ impl crate::backend::BackendStorage for CudaStorage {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn conv_transpose2d(
|
||||
&self,
|
||||
_l: &Layout,
|
||||
_kernel: &Self,
|
||||
_kernel_l: &Layout,
|
||||
_params: &crate::conv::ParamsConvTranspose2D,
|
||||
) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
@ -30,7 +30,7 @@ pub enum Error {
|
||||
UnsupportedDTypeForOp(DType, &'static str),
|
||||
|
||||
// === Dimension Index Errors ===
|
||||
#[error("{op}: dimension index {dim} out of range for shape {shape:?}")]
|
||||
#[error("{op}: dimension index {dim} out of range for {shape:?}")]
|
||||
DimOutOfRange {
|
||||
shape: Shape,
|
||||
dim: i32,
|
||||
@ -207,11 +207,7 @@ pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
impl Error {
|
||||
pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self {
|
||||
Self::Wrapped(Box::new(err)).bt()
|
||||
}
|
||||
|
||||
pub fn msg(err: impl std::error::Error + Send + Sync + 'static) -> Self {
|
||||
Self::Msg(err.to_string()).bt()
|
||||
Self::Wrapped(Box::new(err))
|
||||
}
|
||||
|
||||
pub fn bt(self) -> Self {
|
||||
|
@ -9,14 +9,6 @@ pub struct Layout {
|
||||
}
|
||||
|
||||
impl Layout {
|
||||
pub fn new(shape: Shape, stride: Vec<usize>, start_offset: usize) -> Self {
|
||||
Self {
|
||||
shape,
|
||||
stride,
|
||||
start_offset,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn contiguous_with_offset<S: Into<Shape>>(shape: S, start_offset: usize) -> Self {
|
||||
let shape = shape.into();
|
||||
let stride = shape.stride_contiguous();
|
||||
@ -120,31 +112,6 @@ impl Layout {
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn permute(&self, idxs: &[usize]) -> Result<Self> {
|
||||
let is_permutation =
|
||||
idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i));
|
||||
if !is_permutation {
|
||||
crate::bail!(
|
||||
"dimension mismatch in permute, tensor {:?}, dims: {:?}",
|
||||
self.dims(),
|
||||
idxs
|
||||
)
|
||||
}
|
||||
let stride = self.stride();
|
||||
let dims = self.shape().dims();
|
||||
let mut perm_stride = stride.to_vec();
|
||||
let mut perm_dims = dims.to_vec();
|
||||
for (i, &idx) in idxs.iter().enumerate() {
|
||||
perm_stride[i] = stride[idx];
|
||||
perm_dims[i] = dims[idx];
|
||||
}
|
||||
Ok(Self {
|
||||
shape: Shape::from(perm_dims),
|
||||
stride: perm_stride,
|
||||
start_offset: self.start_offset,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
|
||||
let shape = shape.into();
|
||||
if shape.rank() < self.shape().rank() {
|
||||
|
@ -56,15 +56,12 @@ pub mod layout;
|
||||
mod mkl;
|
||||
pub mod npy;
|
||||
mod op;
|
||||
pub mod pickle;
|
||||
pub mod quantized;
|
||||
pub mod safetensors;
|
||||
pub mod scalar;
|
||||
pub mod shape;
|
||||
mod storage;
|
||||
mod strided_index;
|
||||
mod tensor;
|
||||
pub mod test_utils;
|
||||
pub mod utils;
|
||||
mod variable;
|
||||
|
||||
@ -89,39 +86,3 @@ pub use dummy_cuda_backend::{CudaDevice, CudaStorage};
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
pub trait ToUsize2 {
|
||||
fn to_usize2(self) -> (usize, usize);
|
||||
}
|
||||
|
||||
impl ToUsize2 for usize {
|
||||
fn to_usize2(self) -> (usize, usize) {
|
||||
(self, self)
|
||||
}
|
||||
}
|
||||
|
||||
impl ToUsize2 for (usize, usize) {
|
||||
fn to_usize2(self) -> (usize, usize) {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
// A simple trait defining a module with forward method using a single argument.
|
||||
pub trait Module: std::fmt::Debug {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
|
||||
|
||||
/// Change the module to use training mode vs eval mode.
|
||||
///
|
||||
/// The default implementation does nothing as this is only used for a couple modules such as
|
||||
/// dropout or batch-normalization.
|
||||
fn set_training(&mut self, _training: bool) {}
|
||||
}
|
||||
|
||||
impl Module for quantized::QMatMul {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
self.forward(xs)
|
||||
}
|
||||
}
|
||||
|
@ -25,10 +25,6 @@ mod ffi {
|
||||
pub fn vdMul(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
|
||||
pub fn vsDiv(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
|
||||
pub fn vdDiv(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
|
||||
pub fn vsFmax(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
|
||||
pub fn vdFmax(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
|
||||
pub fn vsFmin(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
|
||||
pub fn vdFmin(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
|
||||
|
||||
pub fn sgemm_(
|
||||
transa: *const c_char,
|
||||
@ -301,7 +297,7 @@ pub fn vd_sqr(a: &[f64], y: &mut [f64]) {
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_tanh(a: &[f32], y: &mut [f32]) {
|
||||
fn vs_tanh(a: &[f32], y: &mut [f32]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
@ -311,7 +307,7 @@ pub fn vs_tanh(a: &[f32], y: &mut [f32]) {
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_tanh(a: &[f64], y: &mut [f64]) {
|
||||
fn vd_tanh(a: &[f64], y: &mut [f64]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
@ -380,7 +376,3 @@ binary_op!(vs_mul, f32, vsMul);
|
||||
binary_op!(vd_mul, f64, vdMul);
|
||||
binary_op!(vs_div, f32, vsDiv);
|
||||
binary_op!(vd_div, f64, vdDiv);
|
||||
binary_op!(vs_max, f32, vsFmax);
|
||||
binary_op!(vd_max, f64, vdFmax);
|
||||
binary_op!(vs_min, f32, vsFmin);
|
||||
binary_op!(vd_min, f64, vdFmin);
|
||||
|
@ -85,7 +85,6 @@ impl Header {
|
||||
DType::F16 => "f2",
|
||||
DType::F32 => "f4",
|
||||
DType::F64 => "f8",
|
||||
DType::I64 => "i8",
|
||||
DType::U32 => "u4",
|
||||
DType::U8 => "u1",
|
||||
};
|
||||
@ -161,7 +160,7 @@ impl Header {
|
||||
"f" | "f4" => DType::F32,
|
||||
"d" | "f8" => DType::F64,
|
||||
// "i" | "i4" => DType::S32,
|
||||
"q" | "i8" => DType::I64,
|
||||
// "q" | "i8" => DType::S64,
|
||||
// "h" | "i2" => DType::S16,
|
||||
// "b" | "i1" => DType::S8,
|
||||
"B" | "u1" => DType::U8,
|
||||
@ -197,11 +196,7 @@ impl Header {
|
||||
|
||||
impl Tensor {
|
||||
// TODO: Add the possibility to read directly to a device?
|
||||
pub(crate) fn from_reader<R: std::io::Read>(
|
||||
shape: Shape,
|
||||
dtype: DType,
|
||||
reader: &mut R,
|
||||
) -> Result<Self> {
|
||||
fn from_reader<R: std::io::Read>(shape: Shape, dtype: DType, reader: &mut R) -> Result<Self> {
|
||||
let elem_count = shape.elem_count();
|
||||
match dtype {
|
||||
DType::BF16 => {
|
||||
@ -234,11 +229,6 @@ impl Tensor {
|
||||
reader.read_u32_into::<LittleEndian>(&mut data_t)?;
|
||||
Tensor::from_vec(data_t, shape, &Device::Cpu)
|
||||
}
|
||||
DType::I64 => {
|
||||
let mut data_t = vec![0i64; elem_count];
|
||||
reader.read_i64_into::<LittleEndian>(&mut data_t)?;
|
||||
Tensor::from_vec(data_t, shape, &Device::Cpu)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -371,25 +361,6 @@ impl NpzTensors {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn names(&self) -> Vec<&String> {
|
||||
self.index_per_name.keys().collect()
|
||||
}
|
||||
|
||||
/// This only returns the shape and dtype for a named tensor. Compared to `get`, this avoids
|
||||
/// reading the whole tensor data.
|
||||
pub fn get_shape_and_dtype(&self, name: &str) -> Result<(Shape, DType)> {
|
||||
let index = match self.index_per_name.get(name) {
|
||||
None => crate::bail!("cannot find tensor {name}"),
|
||||
Some(index) => *index,
|
||||
};
|
||||
let zip_reader = BufReader::new(File::open(&self.path)?);
|
||||
let mut zip = zip::ZipArchive::new(zip_reader)?;
|
||||
let mut reader = zip.by_index(index)?;
|
||||
let header = read_header(&mut reader)?;
|
||||
let header = Header::parse(&header)?;
|
||||
Ok((header.shape(), header.descr))
|
||||
}
|
||||
|
||||
pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
|
||||
let index = match self.index_per_name.get(name) {
|
||||
None => return Ok(None),
|
||||
|
@ -1,4 +1,3 @@
|
||||
#![allow(clippy::redundant_closure_call)]
|
||||
use crate::{CpuStorage, CudaStorage, Layout, Result, Shape, Tensor};
|
||||
use half::{bf16, f16};
|
||||
use num_traits::float::Float;
|
||||
@ -41,8 +40,6 @@ pub enum BinaryOp {
|
||||
Mul,
|
||||
Sub,
|
||||
Div,
|
||||
Maximum,
|
||||
Minimum,
|
||||
}
|
||||
|
||||
// Unary ops with no argument
|
||||
@ -59,7 +56,6 @@ pub enum UnaryOp {
|
||||
Sqrt,
|
||||
Gelu,
|
||||
Relu,
|
||||
Tanh,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@ -82,7 +78,6 @@ pub enum Op {
|
||||
kernel: Tensor,
|
||||
padding: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
},
|
||||
|
||||
#[allow(dead_code)]
|
||||
@ -91,17 +86,6 @@ pub enum Op {
|
||||
kernel: Tensor,
|
||||
padding: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
},
|
||||
|
||||
#[allow(dead_code)]
|
||||
ConvTranspose2D {
|
||||
arg: Tensor,
|
||||
kernel: Tensor,
|
||||
padding: usize,
|
||||
output_padding: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
},
|
||||
|
||||
AvgPool2D {
|
||||
@ -133,25 +117,14 @@ pub enum Op {
|
||||
Reshape(Tensor),
|
||||
ToDevice(Tensor),
|
||||
Transpose(Tensor, usize, usize),
|
||||
Permute(Tensor, Vec<usize>),
|
||||
Elu(Tensor, f64),
|
||||
Powf(Tensor, f64),
|
||||
CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1 + Send + Sync>>),
|
||||
CustomOp2(
|
||||
Tensor,
|
||||
Tensor,
|
||||
std::sync::Arc<Box<dyn CustomOp2 + Send + Sync>>,
|
||||
),
|
||||
CustomOp3(
|
||||
Tensor,
|
||||
Tensor,
|
||||
Tensor,
|
||||
std::sync::Arc<Box<dyn CustomOp3 + Send + Sync>>,
|
||||
),
|
||||
CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1>>),
|
||||
CustomOp2(Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp2>>),
|
||||
CustomOp3(Tensor, Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp3>>),
|
||||
}
|
||||
|
||||
/// Unary ops that can be defined in user-land.
|
||||
pub trait CustomOp1 {
|
||||
pub trait CustomOp1: Send + Sync {
|
||||
// Box<dyn> does not support const yet, so use a function to get the name.
|
||||
fn name(&self) -> &'static str;
|
||||
|
||||
@ -175,7 +148,7 @@ pub trait CustomOp1 {
|
||||
}
|
||||
}
|
||||
|
||||
pub trait CustomOp2 {
|
||||
pub trait CustomOp2: Send + Sync {
|
||||
fn name(&self) -> &'static str;
|
||||
|
||||
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
|
||||
@ -213,7 +186,7 @@ pub trait CustomOp2 {
|
||||
}
|
||||
}
|
||||
|
||||
pub trait CustomOp3 {
|
||||
pub trait CustomOp3: Send + Sync {
|
||||
fn name(&self) -> &'static str;
|
||||
|
||||
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
|
||||
@ -266,7 +239,6 @@ pub trait UnaryOpT {
|
||||
fn f64(v1: f64) -> f64;
|
||||
fn u8(v1: u8) -> u8;
|
||||
fn u32(v1: u32) -> u32;
|
||||
fn i64(v1: i64) -> i64;
|
||||
|
||||
// There is no very good way to represent optional function in traits so we go for an explicit
|
||||
// boolean flag to mark the function as existing.
|
||||
@ -290,7 +262,6 @@ pub trait BinaryOpT {
|
||||
fn f64(v1: f64, v2: f64) -> f64;
|
||||
fn u8(v1: u8, v2: u8) -> u8;
|
||||
fn u32(v1: u32, v2: u32) -> u32;
|
||||
fn i64(v1: i64, v2: i64) -> i64;
|
||||
|
||||
const BF16_VEC: bool = false;
|
||||
fn bf16_vec(_xs1: &[bf16], _xs2: &[bf16], _ys: &mut [bf16]) {}
|
||||
@ -304,16 +275,12 @@ pub trait BinaryOpT {
|
||||
fn u8_vec(_xs1: &[u8], _xs2: &[u8], _ys: &mut [u8]) {}
|
||||
const U32_VEC: bool = false;
|
||||
fn u32_vec(_xs1: &[u32], _xs2: &[u32], _ys: &mut [u32]) {}
|
||||
const I64_VEC: bool = false;
|
||||
fn i64_vec(_xs1: &[i64], _xs2: &[i64], _ys: &mut [i64]) {}
|
||||
}
|
||||
|
||||
pub(crate) struct Add;
|
||||
pub(crate) struct Div;
|
||||
pub(crate) struct Mul;
|
||||
pub(crate) struct Sub;
|
||||
pub(crate) struct Maximum;
|
||||
pub(crate) struct Minimum;
|
||||
pub(crate) struct Exp;
|
||||
pub(crate) struct Log;
|
||||
pub(crate) struct Sin;
|
||||
@ -325,7 +292,6 @@ pub(crate) struct Sqr;
|
||||
pub(crate) struct Sqrt;
|
||||
pub(crate) struct Gelu;
|
||||
pub(crate) struct Relu;
|
||||
pub(crate) struct Tanh;
|
||||
|
||||
macro_rules! bin_op {
|
||||
($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
|
||||
@ -357,10 +323,6 @@ macro_rules! bin_op {
|
||||
fn u32(v1: u32, v2: u32) -> u32 {
|
||||
$e(v1, v2)
|
||||
}
|
||||
#[inline(always)]
|
||||
fn i64(v1: i64, v2: i64) -> i64 {
|
||||
$e(v1, v2)
|
||||
}
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
const F32_VEC: bool = true;
|
||||
@ -399,22 +361,7 @@ bin_op!(Add, "add", |v1, v2| v1 + v2, vs_add, vd_add);
|
||||
bin_op!(Sub, "sub", |v1, v2| v1 - v2, vs_sub, vd_sub);
|
||||
bin_op!(Mul, "mul", |v1, v2| v1 * v2, vs_mul, vd_mul);
|
||||
bin_op!(Div, "div", |v1, v2| v1 / v2, vs_div, vd_div);
|
||||
bin_op!(
|
||||
Minimum,
|
||||
"minimum",
|
||||
|v1, v2| if v1 > v2 { v2 } else { v1 },
|
||||
vs_min,
|
||||
vd_min
|
||||
);
|
||||
bin_op!(
|
||||
Maximum,
|
||||
"maximum",
|
||||
|v1, v2| if v1 < v2 { v2 } else { v1 },
|
||||
vs_max,
|
||||
vd_max
|
||||
);
|
||||
|
||||
#[allow(clippy::redundant_closure_call)]
|
||||
macro_rules! unary_op {
|
||||
($op: ident, $name: literal, $a: ident, $e: expr) => {
|
||||
impl UnaryOpT for $op {
|
||||
@ -445,10 +392,6 @@ macro_rules! unary_op {
|
||||
fn u32(_: u32) -> u32 {
|
||||
todo!("no unary function for u32")
|
||||
}
|
||||
#[inline(always)]
|
||||
fn i64(_: i64) -> i64 {
|
||||
todo!("no unary function for i64")
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@ -481,10 +424,6 @@ macro_rules! unary_op {
|
||||
fn u32(_: u32) -> u32 {
|
||||
todo!("no unary function for u32")
|
||||
}
|
||||
#[inline(always)]
|
||||
fn i64(_: i64) -> i64 {
|
||||
todo!("no unary function for i64")
|
||||
}
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
const F32_VEC: bool = true;
|
||||
@ -523,7 +462,6 @@ unary_op!(Exp, "exp", v, v.exp(), vs_exp, vd_exp);
|
||||
unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln);
|
||||
unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
|
||||
unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
|
||||
unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh);
|
||||
unary_op!(Abs, "abs", v, v.abs());
|
||||
unary_op!(Neg, "neg", v, -v);
|
||||
unary_op!(Recip, "recip", v, v.recip());
|
||||
@ -577,10 +515,6 @@ impl UnaryOpT for Gelu {
|
||||
fn u32(_: u32) -> u32 {
|
||||
0
|
||||
}
|
||||
#[inline(always)]
|
||||
fn i64(_: i64) -> i64 {
|
||||
0
|
||||
}
|
||||
const KERNEL: &'static str = "ugelu";
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
@ -630,10 +564,6 @@ impl UnaryOpT for Relu {
|
||||
fn u32(v: u32) -> u32 {
|
||||
v
|
||||
}
|
||||
#[inline(always)]
|
||||
fn i64(v: i64) -> i64 {
|
||||
v
|
||||
}
|
||||
}
|
||||
|
||||
/// `BackpropOp` is a wrapper around `Option<Op>`. The main goal is to ensure that dependencies are
|
||||
|
@ -1,725 +0,0 @@
|
||||
// Just enough pickle support to be able to read PyTorch checkpoints.
|
||||
// This hardcodes objects that are required for tensor reading, we may want to make this a bit more
|
||||
// composable/tensor agnostic at some point.
|
||||
use crate::{DType, Error as E, Layout, Result, Tensor};
|
||||
use byteorder::{LittleEndian, ReadBytesExt};
|
||||
use std::collections::HashMap;
|
||||
use std::io::BufRead;
|
||||
|
||||
const VERBOSE: bool = false;
|
||||
|
||||
// https://docs.juliahub.com/Pickle/LAUNc/0.1.0/opcode/
|
||||
#[repr(u8)]
|
||||
#[derive(Debug, Eq, PartialEq, Clone)]
|
||||
pub enum OpCode {
|
||||
// https://github.com/python/cpython/blob/ed25f097160b5cbb0c9a1f9a746d2f1bbc96515a/Lib/pickletools.py#L2123
|
||||
Proto = 0x80,
|
||||
Global = b'c',
|
||||
BinPut = b'q',
|
||||
LongBinPut = b'r',
|
||||
EmptyTuple = b')',
|
||||
Reduce = b'R',
|
||||
Mark = b'(',
|
||||
BinUnicode = b'X',
|
||||
BinInt = b'J',
|
||||
Tuple = b't',
|
||||
BinPersId = b'Q',
|
||||
BinInt1 = b'K',
|
||||
BinInt2 = b'M',
|
||||
Tuple1 = 0x85,
|
||||
Tuple2 = 0x86,
|
||||
Tuple3 = 0x87,
|
||||
NewTrue = 0x88,
|
||||
NewFalse = 0x89,
|
||||
None = b'N',
|
||||
BinGet = b'h',
|
||||
LongBinGet = b'j',
|
||||
SetItem = b's',
|
||||
SetItems = b'u',
|
||||
EmptyDict = b'}',
|
||||
Dict = b'd',
|
||||
Build = b'b',
|
||||
Stop = b'.',
|
||||
NewObj = 0x81,
|
||||
EmptyList = b']',
|
||||
BinFloat = b'g',
|
||||
Append = b'a',
|
||||
Appends = b'e',
|
||||
}
|
||||
|
||||
// Avoid using FromPrimitive so as not to drag another dependency.
|
||||
impl TryFrom<u8> for OpCode {
|
||||
type Error = u8;
|
||||
fn try_from(value: u8) -> std::result::Result<Self, Self::Error> {
|
||||
match value {
|
||||
0x80 => Ok(Self::Proto),
|
||||
b'c' => Ok(Self::Global),
|
||||
b'q' => Ok(Self::BinPut),
|
||||
b'r' => Ok(Self::LongBinPut),
|
||||
b')' => Ok(Self::EmptyTuple),
|
||||
b'R' => Ok(Self::Reduce),
|
||||
b'(' => Ok(Self::Mark),
|
||||
b'X' => Ok(Self::BinUnicode),
|
||||
b'J' => Ok(Self::BinInt),
|
||||
b't' => Ok(Self::Tuple),
|
||||
b'Q' => Ok(Self::BinPersId),
|
||||
b'K' => Ok(Self::BinInt1),
|
||||
b'M' => Ok(Self::BinInt2),
|
||||
b'N' => Ok(Self::None),
|
||||
0x85 => Ok(Self::Tuple1),
|
||||
0x86 => Ok(Self::Tuple2),
|
||||
0x87 => Ok(Self::Tuple3),
|
||||
0x88 => Ok(Self::NewTrue),
|
||||
0x89 => Ok(Self::NewFalse),
|
||||
b'h' => Ok(Self::BinGet),
|
||||
b'j' => Ok(Self::LongBinGet),
|
||||
b's' => Ok(Self::SetItem),
|
||||
b'u' => Ok(Self::SetItems),
|
||||
b'}' => Ok(Self::EmptyDict),
|
||||
b'd' => Ok(Self::EmptyDict),
|
||||
b'b' => Ok(Self::Build),
|
||||
b'.' => Ok(Self::Stop),
|
||||
0x81 => Ok(Self::NewObj),
|
||||
b']' => Ok(Self::EmptyList),
|
||||
b'G' => Ok(Self::BinFloat),
|
||||
b'a' => Ok(Self::Append),
|
||||
b'e' => Ok(Self::Appends),
|
||||
value => Err(value),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn read_to_newline<R: BufRead>(r: &mut R) -> Result<Vec<u8>> {
|
||||
let mut data: Vec<u8> = Vec::with_capacity(32);
|
||||
r.read_until(b'\n', &mut data)?;
|
||||
data.pop();
|
||||
if data.last() == Some(&b'\r') {
|
||||
data.pop();
|
||||
}
|
||||
Ok(data)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum Object {
|
||||
Class {
|
||||
module_name: String,
|
||||
class_name: String,
|
||||
},
|
||||
Int(i32),
|
||||
Float(f64),
|
||||
Unicode(String),
|
||||
Bool(bool),
|
||||
None,
|
||||
Tuple(Vec<Object>),
|
||||
List(Vec<Object>),
|
||||
Mark,
|
||||
Dict(Vec<(Object, Object)>),
|
||||
Reduce {
|
||||
callable: Box<Object>,
|
||||
args: Box<Object>,
|
||||
},
|
||||
Build {
|
||||
callable: Box<Object>,
|
||||
args: Box<Object>,
|
||||
},
|
||||
PersistentLoad(Box<Object>),
|
||||
}
|
||||
|
||||
type OResult<T> = std::result::Result<T, Object>;
|
||||
|
||||
impl Object {
|
||||
pub fn unicode(self) -> OResult<String> {
|
||||
match self {
|
||||
Self::Unicode(t) => Ok(t),
|
||||
_ => Err(self),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn reduce(self) -> OResult<(Self, Self)> {
|
||||
match self {
|
||||
Self::Reduce { callable, args } => Ok((*callable, *args)),
|
||||
_ => Err(self),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn none(self) -> OResult<()> {
|
||||
match self {
|
||||
Self::None => Ok(()),
|
||||
_ => Err(self),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn persistent_load(self) -> OResult<Self> {
|
||||
match self {
|
||||
Self::PersistentLoad(t) => Ok(*t),
|
||||
_ => Err(self),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bool(self) -> OResult<bool> {
|
||||
match self {
|
||||
Self::Bool(t) => Ok(t),
|
||||
_ => Err(self),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn int(self) -> OResult<i32> {
|
||||
match self {
|
||||
Self::Int(t) => Ok(t),
|
||||
_ => Err(self),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tuple(self) -> OResult<Vec<Self>> {
|
||||
match self {
|
||||
Self::Tuple(t) => Ok(t),
|
||||
_ => Err(self),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dict(self) -> OResult<Vec<(Self, Self)>> {
|
||||
match self {
|
||||
Self::Dict(t) => Ok(t),
|
||||
_ => Err(self),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn class(self) -> OResult<(String, String)> {
|
||||
match self {
|
||||
Self::Class {
|
||||
module_name,
|
||||
class_name,
|
||||
} => Ok((module_name, class_name)),
|
||||
_ => Err(self),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<Object> for String {
|
||||
type Error = Object;
|
||||
fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
|
||||
match value {
|
||||
Object::Unicode(s) => Ok(s),
|
||||
other => Err(other),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<Object> for usize {
|
||||
type Error = Object;
|
||||
fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
|
||||
match value {
|
||||
Object::Int(s) if s >= 0 => Ok(s as usize),
|
||||
other => Err(other),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: TryFrom<Object, Error = Object>> TryFrom<Object> for Vec<T> {
|
||||
type Error = Object;
|
||||
fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
|
||||
match value {
|
||||
Object::Tuple(values) => {
|
||||
// This does not return the appropriate value in the error case but instead return
|
||||
// the object related to the first error.
|
||||
values
|
||||
.into_iter()
|
||||
.map(|v| T::try_from(v))
|
||||
.collect::<std::result::Result<Vec<T>, Self::Error>>()
|
||||
}
|
||||
other => Err(other),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Stack {
|
||||
stack: Vec<Object>,
|
||||
memo: HashMap<u32, Object>,
|
||||
}
|
||||
|
||||
impl Stack {
|
||||
pub fn empty() -> Self {
|
||||
Self {
|
||||
stack: Vec::with_capacity(512),
|
||||
memo: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn stack(&self) -> &[Object] {
|
||||
self.stack.as_slice()
|
||||
}
|
||||
|
||||
pub fn read_loop<R: BufRead>(&mut self, r: &mut R) -> Result<()> {
|
||||
loop {
|
||||
if self.read(r)? {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn finalize(mut self) -> Result<Object> {
|
||||
self.pop()
|
||||
}
|
||||
|
||||
fn push(&mut self, obj: Object) {
|
||||
self.stack.push(obj)
|
||||
}
|
||||
|
||||
fn pop(&mut self) -> Result<Object> {
|
||||
match self.stack.pop() {
|
||||
None => crate::bail!("unexpected empty stack"),
|
||||
Some(obj) => Ok(obj),
|
||||
}
|
||||
}
|
||||
|
||||
// https://docs.juliahub.com/Pickle/LAUNc/0.1.0/opcode/#Pickle.OpCodes.BUILD
|
||||
fn build(&mut self) -> Result<()> {
|
||||
let args = self.pop()?;
|
||||
let obj = self.pop()?;
|
||||
let obj = match (obj, args) {
|
||||
(Object::Dict(mut obj), Object::Dict(mut args)) => {
|
||||
obj.append(&mut args);
|
||||
Object::Dict(obj)
|
||||
}
|
||||
(obj, args) => Object::Build {
|
||||
callable: Box::new(obj),
|
||||
args: Box::new(args),
|
||||
},
|
||||
};
|
||||
self.push(obj);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn reduce(&mut self) -> Result<()> {
|
||||
let args = self.pop()?;
|
||||
let callable = self.pop()?;
|
||||
#[allow(clippy::single_match)]
|
||||
let reduced = match &callable {
|
||||
Object::Class {
|
||||
module_name,
|
||||
class_name,
|
||||
} => {
|
||||
if module_name == "collections" && class_name == "OrderedDict" {
|
||||
// TODO: have a separate ordered dict.
|
||||
Some(Object::Dict(vec![]))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
let reduced = reduced.unwrap_or_else(|| Object::Reduce {
|
||||
callable: Box::new(callable),
|
||||
args: Box::new(args),
|
||||
});
|
||||
self.push(reduced);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn last(&mut self) -> Result<&mut Object> {
|
||||
match self.stack.last_mut() {
|
||||
None => crate::bail!("unexpected empty stack"),
|
||||
Some(obj) => Ok(obj),
|
||||
}
|
||||
}
|
||||
|
||||
fn memo_get(&self, id: u32) -> Result<Object> {
|
||||
match self.memo.get(&id) {
|
||||
None => crate::bail!("missing object in memo {id}"),
|
||||
Some(obj) => {
|
||||
// Maybe we should use refcounting rather than doing potential large clones here.
|
||||
Ok(obj.clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn memo_put(&mut self, id: u32) -> Result<()> {
|
||||
let obj = self.last()?.clone();
|
||||
self.memo.insert(id, obj);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn persistent_load(&self, id: Object) -> Result<Object> {
|
||||
Ok(Object::PersistentLoad(Box::new(id)))
|
||||
}
|
||||
|
||||
fn new_obj(&self, class: Object, args: Object) -> Result<Object> {
|
||||
Ok(Object::Reduce {
|
||||
callable: Box::new(class),
|
||||
args: Box::new(args),
|
||||
})
|
||||
}
|
||||
|
||||
fn pop_to_marker(&mut self) -> Result<Vec<Object>> {
|
||||
let mut mark_idx = None;
|
||||
for (idx, obj) in self.stack.iter().enumerate().rev() {
|
||||
if obj == &Object::Mark {
|
||||
mark_idx = Some(idx);
|
||||
break;
|
||||
}
|
||||
}
|
||||
match mark_idx {
|
||||
Some(mark_idx) => {
|
||||
let objs = self.stack.split_off(mark_idx + 1);
|
||||
self.stack.pop();
|
||||
Ok(objs)
|
||||
}
|
||||
None => {
|
||||
crate::bail!("marker object not found")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn read<R: BufRead>(&mut self, r: &mut R) -> Result<bool> {
|
||||
let op_code = match OpCode::try_from(r.read_u8()?) {
|
||||
Ok(op_code) => op_code,
|
||||
Err(op_code) => {
|
||||
crate::bail!("unknown op-code {op_code}")
|
||||
}
|
||||
};
|
||||
// println!("op: {op_code:?}");
|
||||
// println!("{:?}", self.stack);
|
||||
match op_code {
|
||||
OpCode::Proto => {
|
||||
let version = r.read_u8()?;
|
||||
if VERBOSE {
|
||||
println!("proto {version}");
|
||||
}
|
||||
}
|
||||
OpCode::Global => {
|
||||
let module_name = read_to_newline(r)?;
|
||||
let class_name = read_to_newline(r)?;
|
||||
let module_name = String::from_utf8_lossy(&module_name).to_string();
|
||||
let class_name = String::from_utf8_lossy(&class_name).to_string();
|
||||
self.push(Object::Class {
|
||||
module_name,
|
||||
class_name,
|
||||
})
|
||||
}
|
||||
OpCode::BinInt1 => {
|
||||
let arg = r.read_u8()?;
|
||||
self.push(Object::Int(arg as i32))
|
||||
}
|
||||
OpCode::BinInt2 => {
|
||||
let arg = r.read_u16::<LittleEndian>()?;
|
||||
self.push(Object::Int(arg as i32))
|
||||
}
|
||||
OpCode::BinInt => {
|
||||
let arg = r.read_i32::<LittleEndian>()?;
|
||||
self.push(Object::Int(arg))
|
||||
}
|
||||
OpCode::BinFloat => {
|
||||
let arg = r.read_f64::<LittleEndian>()?;
|
||||
self.push(Object::Float(arg))
|
||||
}
|
||||
OpCode::BinUnicode => {
|
||||
let len = r.read_u32::<LittleEndian>()?;
|
||||
let mut data = vec![0u8; len as usize];
|
||||
r.read_exact(&mut data)?;
|
||||
let data = String::from_utf8(data).map_err(E::wrap)?;
|
||||
self.push(Object::Unicode(data))
|
||||
}
|
||||
OpCode::BinPersId => {
|
||||
let id = self.pop()?;
|
||||
let obj = self.persistent_load(id)?;
|
||||
self.push(obj)
|
||||
}
|
||||
OpCode::Tuple => {
|
||||
let objs = self.pop_to_marker()?;
|
||||
self.push(Object::Tuple(objs))
|
||||
}
|
||||
OpCode::Tuple1 => {
|
||||
let obj = self.pop()?;
|
||||
self.push(Object::Tuple(vec![obj]))
|
||||
}
|
||||
OpCode::Tuple2 => {
|
||||
let obj2 = self.pop()?;
|
||||
let obj1 = self.pop()?;
|
||||
self.push(Object::Tuple(vec![obj1, obj2]))
|
||||
}
|
||||
OpCode::Tuple3 => {
|
||||
let obj3 = self.pop()?;
|
||||
let obj2 = self.pop()?;
|
||||
let obj1 = self.pop()?;
|
||||
self.push(Object::Tuple(vec![obj1, obj2, obj3]))
|
||||
}
|
||||
OpCode::NewTrue => self.push(Object::Bool(true)),
|
||||
OpCode::NewFalse => self.push(Object::Bool(false)),
|
||||
OpCode::Append => {
|
||||
let value = self.pop()?;
|
||||
let pylist = self.last()?;
|
||||
if let Object::List(d) = pylist {
|
||||
d.push(value)
|
||||
} else {
|
||||
crate::bail!("expected a list, got {pylist:?}")
|
||||
}
|
||||
}
|
||||
OpCode::Appends => {
|
||||
let objs = self.pop_to_marker()?;
|
||||
let pylist = self.last()?;
|
||||
if let Object::List(d) = pylist {
|
||||
d.extend(objs)
|
||||
} else {
|
||||
crate::bail!("expected a list, got {pylist:?}")
|
||||
}
|
||||
}
|
||||
OpCode::SetItem => {
|
||||
let value = self.pop()?;
|
||||
let key = self.pop()?;
|
||||
let pydict = self.last()?;
|
||||
if let Object::Dict(d) = pydict {
|
||||
d.push((key, value))
|
||||
} else {
|
||||
crate::bail!("expected a dict, got {pydict:?}")
|
||||
}
|
||||
}
|
||||
OpCode::SetItems => {
|
||||
let mut objs = self.pop_to_marker()?;
|
||||
let pydict = self.last()?;
|
||||
if let Object::Dict(d) = pydict {
|
||||
if objs.len() % 2 != 0 {
|
||||
crate::bail!("setitems: not an even number of objects")
|
||||
}
|
||||
while let Some(value) = objs.pop() {
|
||||
let key = objs.pop().unwrap();
|
||||
d.push((key, value))
|
||||
}
|
||||
} else {
|
||||
crate::bail!("expected a dict, got {pydict:?}")
|
||||
}
|
||||
}
|
||||
OpCode::None => self.push(Object::None),
|
||||
OpCode::Stop => {
|
||||
return Ok(true);
|
||||
}
|
||||
OpCode::Build => self.build()?,
|
||||
OpCode::EmptyDict => self.push(Object::Dict(vec![])),
|
||||
OpCode::Dict => {
|
||||
let mut objs = self.pop_to_marker()?;
|
||||
let mut pydict = vec![];
|
||||
if objs.len() % 2 != 0 {
|
||||
crate::bail!("setitems: not an even number of objects")
|
||||
}
|
||||
while let Some(value) = objs.pop() {
|
||||
let key = objs.pop().unwrap();
|
||||
pydict.push((key, value))
|
||||
}
|
||||
self.push(Object::Dict(pydict))
|
||||
}
|
||||
OpCode::Mark => self.push(Object::Mark),
|
||||
OpCode::Reduce => self.reduce()?,
|
||||
OpCode::EmptyTuple => self.push(Object::Tuple(vec![])),
|
||||
OpCode::EmptyList => self.push(Object::List(vec![])),
|
||||
OpCode::BinGet => {
|
||||
let arg = r.read_u8()?;
|
||||
let obj = self.memo_get(arg as u32)?;
|
||||
self.push(obj)
|
||||
}
|
||||
OpCode::LongBinGet => {
|
||||
let arg = r.read_u32::<LittleEndian>()?;
|
||||
let obj = self.memo_get(arg)?;
|
||||
self.push(obj)
|
||||
}
|
||||
OpCode::BinPut => {
|
||||
let arg = r.read_u8()?;
|
||||
self.memo_put(arg as u32)?
|
||||
}
|
||||
OpCode::LongBinPut => {
|
||||
let arg = r.read_u32::<LittleEndian>()?;
|
||||
self.memo_put(arg)?
|
||||
}
|
||||
OpCode::NewObj => {
|
||||
let args = self.pop()?;
|
||||
let class = self.pop()?;
|
||||
let obj = self.new_obj(class, args)?;
|
||||
self.push(obj)
|
||||
}
|
||||
}
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Object> for E {
|
||||
fn from(value: Object) -> Self {
|
||||
E::Msg(format!("conversion error on {value:?}"))
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/pytorch/pytorch/blob/4eac43d046ded0f0a5a5fa8db03eb40f45bf656e/torch/_utils.py#L198
|
||||
// Arguments: storage, storage_offset, size, stride, requires_grad, backward_hooks
|
||||
fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
|
||||
let mut args = args.tuple()?;
|
||||
let stride = Vec::<usize>::try_from(args.remove(3))?;
|
||||
let size = Vec::<usize>::try_from(args.remove(2))?;
|
||||
let offset = args.remove(1).int()? as usize;
|
||||
let storage = args.remove(0).persistent_load()?;
|
||||
let mut storage = storage.tuple()?;
|
||||
let storage_size = storage.remove(4).int()? as usize;
|
||||
let path = storage.remove(2).unicode()?;
|
||||
let (_module_name, class_name) = storage.remove(1).class()?;
|
||||
let dtype = match class_name.as_str() {
|
||||
"FloatStorage" => DType::F32,
|
||||
"DoubleStorage" => DType::F64,
|
||||
"HalfStorage" => DType::F16,
|
||||
"BFloat16Storage" => DType::BF16,
|
||||
"ByteStorage" => DType::U8,
|
||||
other => {
|
||||
crate::bail!("unsupported storage type {other}")
|
||||
}
|
||||
};
|
||||
let layout = Layout::new(crate::Shape::from(size), stride, offset);
|
||||
Ok((layout, dtype, path, storage_size))
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TensorInfo {
|
||||
pub name: String,
|
||||
pub dtype: DType,
|
||||
pub layout: Layout,
|
||||
pub path: String,
|
||||
pub storage_size: usize,
|
||||
}
|
||||
|
||||
pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
|
||||
file: P,
|
||||
verbose: bool,
|
||||
) -> Result<Vec<TensorInfo>> {
|
||||
let file = std::fs::File::open(file)?;
|
||||
let zip_reader = std::io::BufReader::new(file);
|
||||
let mut zip = zip::ZipArchive::new(zip_reader)?;
|
||||
let zip_file_names = zip
|
||||
.file_names()
|
||||
.map(|f| f.to_string())
|
||||
.collect::<Vec<String>>();
|
||||
|
||||
let mut tensor_infos = vec![];
|
||||
for file_name in zip_file_names.iter() {
|
||||
if !file_name.ends_with("data.pkl") {
|
||||
continue;
|
||||
}
|
||||
let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").unwrap());
|
||||
let reader = zip.by_name(file_name)?;
|
||||
let mut reader = std::io::BufReader::new(reader);
|
||||
let mut stack = Stack::empty();
|
||||
stack.read_loop(&mut reader)?;
|
||||
let obj = stack.finalize()?;
|
||||
if VERBOSE || verbose {
|
||||
println!("{obj:?}");
|
||||
}
|
||||
let obj = match obj {
|
||||
Object::Build { callable, args } => match *callable {
|
||||
Object::Reduce { callable, args: _ } => match *callable {
|
||||
Object::Class {
|
||||
module_name,
|
||||
class_name,
|
||||
} if module_name == "__torch__" && class_name == "Module" => *args,
|
||||
_ => continue,
|
||||
},
|
||||
_ => continue,
|
||||
},
|
||||
obj => obj,
|
||||
};
|
||||
if let Object::Dict(key_values) = obj {
|
||||
for (name, value) in key_values.into_iter() {
|
||||
let name = match name.unicode() {
|
||||
Ok(name) => name,
|
||||
Err(_) => continue,
|
||||
};
|
||||
let (callable, args) = match value.reduce() {
|
||||
Ok(callable_args) => callable_args,
|
||||
_ => continue,
|
||||
};
|
||||
let (callable, args) = match callable {
|
||||
Object::Class {
|
||||
module_name,
|
||||
class_name,
|
||||
} if module_name == "torch._tensor"
|
||||
&& class_name == "_rebuild_from_type_v2" =>
|
||||
{
|
||||
let mut args = args.tuple()?;
|
||||
let callable = args.remove(0);
|
||||
let args = args.remove(1);
|
||||
(callable, args)
|
||||
}
|
||||
_ => (callable, args),
|
||||
};
|
||||
match callable {
|
||||
Object::Class {
|
||||
module_name,
|
||||
class_name,
|
||||
} if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {}
|
||||
_ => continue,
|
||||
};
|
||||
match rebuild_args(args) {
|
||||
Ok((layout, dtype, file_path, storage_size)) => {
|
||||
let mut path = dir_name.clone();
|
||||
path.push(file_path);
|
||||
tensor_infos.push(TensorInfo {
|
||||
name,
|
||||
dtype,
|
||||
layout,
|
||||
path: path.to_string_lossy().into_owned(),
|
||||
storage_size,
|
||||
})
|
||||
}
|
||||
Err(err) => {
|
||||
eprintln!("skipping {name}: {err:?}")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(tensor_infos)
|
||||
}
|
||||
|
||||
/// Lazy tensor loader.
|
||||
pub struct PthTensors {
|
||||
tensor_infos: HashMap<String, TensorInfo>,
|
||||
path: std::path::PathBuf,
|
||||
// We do not store a zip reader as it needs mutable access to extract data. Instead we
|
||||
// re-create a zip reader for each tensor.
|
||||
}
|
||||
|
||||
impl PthTensors {
|
||||
pub fn new<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
|
||||
let tensor_infos = read_pth_tensor_info(path.as_ref(), false)?;
|
||||
let tensor_infos = tensor_infos
|
||||
.into_iter()
|
||||
.map(|ti| (ti.name.to_string(), ti))
|
||||
.collect();
|
||||
let path = path.as_ref().to_owned();
|
||||
Ok(Self { tensor_infos, path })
|
||||
}
|
||||
|
||||
pub fn tensor_infos(&self) -> &HashMap<String, TensorInfo> {
|
||||
&self.tensor_infos
|
||||
}
|
||||
|
||||
pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
|
||||
let tensor_info = match self.tensor_infos.get(name) {
|
||||
None => return Ok(None),
|
||||
Some(tensor_info) => tensor_info,
|
||||
};
|
||||
// We hope that the file has not changed since first reading it.
|
||||
let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?);
|
||||
let mut zip = zip::ZipArchive::new(zip_reader)?;
|
||||
let mut reader = zip.by_name(&tensor_info.path)?;
|
||||
|
||||
// Reading the data is a bit tricky as it can be strided, use an offset, etc.
|
||||
// For now only support the basic case.
|
||||
if tensor_info.layout.start_offset() != 0 || !tensor_info.layout.is_contiguous() {
|
||||
crate::bail!(
|
||||
"cannot retrieve non-contiguous tensors {:?}",
|
||||
tensor_info.layout
|
||||
)
|
||||
}
|
||||
let tensor = Tensor::from_reader(
|
||||
tensor_info.layout.shape().clone(),
|
||||
tensor_info.dtype,
|
||||
&mut reader,
|
||||
)?;
|
||||
Ok(Some(tensor))
|
||||
}
|
||||
}
|
@ -1,640 +0,0 @@
|
||||
use super::k_quants::{
|
||||
BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K,
|
||||
};
|
||||
use crate::Result;
|
||||
use byteorder::{ByteOrder, LittleEndian};
|
||||
use half::f16;
|
||||
|
||||
#[cfg(target_arch = "x86")]
|
||||
use core::arch::x86::*;
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
use core::arch::x86_64::*;
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) unsafe fn sum_i16_pairs_float(x: __m256i) -> __m256 {
|
||||
let ones = _mm256_set1_epi16(1);
|
||||
let summed_pairs = _mm256_madd_epi16(ones, x);
|
||||
_mm256_cvtepi32_ps(summed_pairs)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) unsafe fn mul_sum_us8_pairs_float(ax: __m256i, sy: __m256i) -> __m256 {
|
||||
let dot = _mm256_maddubs_epi16(ax, sy);
|
||||
sum_i16_pairs_float(dot)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) unsafe fn hsum_float_8(x: __m256) -> f32 {
|
||||
let res = _mm256_extractf128_ps(x, 1);
|
||||
let res = _mm_add_ps(res, _mm256_castps256_ps128(x));
|
||||
let res = _mm_add_ps(res, _mm_movehl_ps(res, res));
|
||||
let res = _mm_add_ss(res, _mm_movehdup_ps(res));
|
||||
_mm_cvtss_f32(res)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) unsafe fn bytes_from_nibbles_32(rsi: *const u8) -> __m256i {
|
||||
let tmp = _mm_loadu_si128(rsi as *const __m128i);
|
||||
let bytes = _mm256_insertf128_si256::<1>(_mm256_castsi128_si256(tmp), _mm_srli_epi16(tmp, 4));
|
||||
let low_mask = _mm256_set1_epi8(0xF);
|
||||
_mm256_and_si256(low_mask, bytes)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) unsafe fn mul_sum_i8_pairs_float(x: __m256i, y: __m256i) -> __m256 {
|
||||
let ax = _mm256_sign_epi8(x, x);
|
||||
let sy = _mm256_sign_epi8(y, x);
|
||||
mul_sum_us8_pairs_float(ax, sy)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
|
||||
let qk = QK8_0;
|
||||
let nb = n / qk;
|
||||
if n % QK8_0 != 0 {
|
||||
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
|
||||
}
|
||||
if nb % 2 != 0 {
|
||||
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
|
||||
}
|
||||
|
||||
unsafe {
|
||||
let mut acc = _mm256_setzero_ps();
|
||||
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||
let d = _mm256_set1_ps(f16::to_f32(x.d) * f16::to_f32(y.d));
|
||||
let bx = bytes_from_nibbles_32(x.qs.as_ptr());
|
||||
let off = _mm256_set1_epi8(8);
|
||||
let bx = _mm256_sub_epi8(bx, off);
|
||||
let by = _mm256_loadu_si256(y.qs.as_ptr() as *const __m256i);
|
||||
let q = mul_sum_i8_pairs_float(bx, by);
|
||||
acc = _mm256_fmadd_ps(d, q, acc);
|
||||
}
|
||||
Ok(hsum_float_8(acc))
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
|
||||
let qk = QK8_0;
|
||||
if n % QK8_0 != 0 {
|
||||
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
|
||||
}
|
||||
unsafe {
|
||||
let mut acc = _mm256_setzero_ps();
|
||||
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||
let d = _mm256_set1_ps(f16::to_f32(x.d) * f16::to_f32(y.d));
|
||||
let bx = _mm256_loadu_si256(x.qs.as_ptr() as *const __m256i);
|
||||
let by = _mm256_loadu_si256(y.qs.as_ptr() as *const __m256i);
|
||||
let q = mul_sum_i8_pairs_float(bx, by);
|
||||
acc = _mm256_fmadd_ps(d, q, acc);
|
||||
}
|
||||
Ok(hsum_float_8(acc))
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn get_scale_shuffle(i: usize) -> __m128i {
|
||||
const K_SHUFFLE: [u8; 128] = [
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3,
|
||||
3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7,
|
||||
7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10,
|
||||
11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13,
|
||||
13, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15, 15,
|
||||
];
|
||||
_mm_loadu_si128((K_SHUFFLE.as_ptr() as *const __m128i).add(i))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn get_scale_shuffle_k4(i: usize) -> __m256i {
|
||||
const K_SHUFFLE: [u8; 256] = [
|
||||
0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
|
||||
0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
|
||||
2, 3, 2, 3, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
|
||||
4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
|
||||
6, 7, 6, 7, 6, 7, 6, 7, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
|
||||
8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10,
|
||||
11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 12, 13, 12, 13, 12, 13,
|
||||
12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12,
|
||||
13, 12, 13, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
|
||||
14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
|
||||
];
|
||||
_mm256_loadu_si256((K_SHUFFLE.as_ptr() as *const __m256i).add(i))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn get_scale_shuffle_q3k(i: usize) -> __m256i {
|
||||
const K_SHUFFLE: [u8; 128] = [
|
||||
0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
|
||||
2, 3, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
|
||||
6, 7, 6, 7, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10, 11, 10, 11, 10, 11, 10, 11,
|
||||
10, 11, 10, 11, 10, 11, 10, 11, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12,
|
||||
13, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
|
||||
];
|
||||
_mm256_loadu_si256((K_SHUFFLE.as_ptr() as *const __m256i).add(i))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
|
||||
let qk = QK_K;
|
||||
if n % qk != 0 {
|
||||
crate::bail!("vec_dot_q6k_8k: {n} is not divisible by {qk}")
|
||||
}
|
||||
|
||||
unsafe {
|
||||
let m4 = _mm256_set1_epi8(0xF);
|
||||
let m2 = _mm256_set1_epi8(3);
|
||||
let m32s = _mm256_set1_epi8(32);
|
||||
let mut acc = _mm256_setzero_ps();
|
||||
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||
let d = y.d * x.d.to_f32();
|
||||
let mut q4 = x.ql.as_ptr();
|
||||
let mut qh = x.qh.as_ptr();
|
||||
let mut q8 = y.qs.as_ptr();
|
||||
|
||||
let scales = _mm_loadu_si128(x.scales.as_ptr() as *const __m128i);
|
||||
let mut sumi = _mm256_setzero_si256();
|
||||
|
||||
for j in 0..QK_K / 128 {
|
||||
let is = j * 4;
|
||||
let scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is));
|
||||
let scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
|
||||
let scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
|
||||
let scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
|
||||
|
||||
let q4bits1 = _mm256_loadu_si256(q4 as *const __m256i);
|
||||
q4 = q4.add(32);
|
||||
let q4bits2 = _mm256_loadu_si256(q4 as *const __m256i);
|
||||
q4 = q4.add(32);
|
||||
let q4bits_h = _mm256_loadu_si256(qh as *const __m256i);
|
||||
qh = qh.add(32);
|
||||
|
||||
let q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bits_h, m2), 4);
|
||||
let q4h_1 =
|
||||
_mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bits_h, 2), m2), 4);
|
||||
let q4h_2 =
|
||||
_mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bits_h, 4), m2), 4);
|
||||
let q4h_3 =
|
||||
_mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bits_h, 6), m2), 4);
|
||||
|
||||
let q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
|
||||
let q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1);
|
||||
let q4_2 =
|
||||
_mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2);
|
||||
let q4_3 =
|
||||
_mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3);
|
||||
|
||||
let q8_0 = _mm256_loadu_si256(q8 as *const __m256i);
|
||||
q8 = q8.add(32);
|
||||
let q8_1 = _mm256_loadu_si256(q8 as *const __m256i);
|
||||
q8 = q8.add(32);
|
||||
let q8_2 = _mm256_loadu_si256(q8 as *const __m256i);
|
||||
q8 = q8.add(32);
|
||||
let q8_3 = _mm256_loadu_si256(q8 as *const __m256i);
|
||||
q8 = q8.add(32);
|
||||
|
||||
let q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
|
||||
let q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
|
||||
let q8s_2 = _mm256_maddubs_epi16(m32s, q8_2);
|
||||
let q8s_3 = _mm256_maddubs_epi16(m32s, q8_3);
|
||||
|
||||
let p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
|
||||
let p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
|
||||
let p16_2 = _mm256_maddubs_epi16(q4_2, q8_2);
|
||||
let p16_3 = _mm256_maddubs_epi16(q4_3, q8_3);
|
||||
|
||||
let p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
|
||||
let p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
|
||||
let p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
|
||||
let p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
|
||||
|
||||
let p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
|
||||
let p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
|
||||
let p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2);
|
||||
let p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3);
|
||||
|
||||
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
|
||||
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3));
|
||||
}
|
||||
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
|
||||
}
|
||||
Ok(hsum_float_8(acc))
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn mm256_set_m128i(a: __m128i, b: __m128i) -> __m256i {
|
||||
_mm256_insertf128_si256(_mm256_castsi128_si256(b), a, 1)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
|
||||
if n % QK_K != 0 {
|
||||
crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
|
||||
}
|
||||
|
||||
unsafe {
|
||||
let m3 = _mm256_set1_epi8(3);
|
||||
let m4 = _mm_set1_epi8(0xF);
|
||||
|
||||
let mut acc = _mm256_setzero_ps();
|
||||
|
||||
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||
let d = y.d * x.d.to_f32();
|
||||
let dmin = -y.d * x.dmin.to_f32();
|
||||
|
||||
let mut q2 = x.qs.as_ptr();
|
||||
let mut q8 = y.qs.as_ptr();
|
||||
|
||||
let mins_and_scales = _mm_loadu_si128(x.scales.as_ptr() as *const __m128i);
|
||||
let scales8 = _mm_and_si128(mins_and_scales, m4);
|
||||
let mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
|
||||
let mins = _mm256_cvtepi8_epi16(mins8);
|
||||
let prod =
|
||||
_mm256_madd_epi16(mins, _mm256_loadu_si256(y.bsums.as_ptr() as *const __m256i));
|
||||
|
||||
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc);
|
||||
|
||||
let all_scales = _mm256_cvtepi8_epi16(scales8);
|
||||
let l_scales = _mm256_extracti128_si256(all_scales, 0);
|
||||
let h_scales = _mm256_extracti128_si256(all_scales, 1);
|
||||
let scales = [
|
||||
mm256_set_m128i(l_scales, l_scales),
|
||||
mm256_set_m128i(h_scales, h_scales),
|
||||
];
|
||||
|
||||
let mut sumi = _mm256_setzero_si256();
|
||||
|
||||
for scale in scales {
|
||||
let q2bits = _mm256_loadu_si256(q2 as *const __m256i);
|
||||
q2 = q2.add(32);
|
||||
|
||||
let q8_0 = _mm256_loadu_si256(q8 as *const __m256i);
|
||||
q8 = q8.add(32);
|
||||
let q8_1 = _mm256_loadu_si256(q8 as *const __m256i);
|
||||
q8 = q8.add(32);
|
||||
let q8_2 = _mm256_loadu_si256(q8 as *const __m256i);
|
||||
q8 = q8.add(32);
|
||||
let q8_3 = _mm256_loadu_si256(q8 as *const __m256i);
|
||||
q8 = q8.add(32);
|
||||
|
||||
let q2_0 = _mm256_and_si256(q2bits, m3);
|
||||
let q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3);
|
||||
let q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3);
|
||||
let q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3);
|
||||
|
||||
let p0 = _mm256_maddubs_epi16(q2_0, q8_0);
|
||||
let p1 = _mm256_maddubs_epi16(q2_1, q8_1);
|
||||
let p2 = _mm256_maddubs_epi16(q2_2, q8_2);
|
||||
let p3 = _mm256_maddubs_epi16(q2_3, q8_3);
|
||||
|
||||
let p0 =
|
||||
_mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(0)), p0);
|
||||
let p1 =
|
||||
_mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(1)), p1);
|
||||
let p2 =
|
||||
_mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(2)), p2);
|
||||
let p3 =
|
||||
_mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(3)), p3);
|
||||
|
||||
let p0 = _mm256_add_epi32(p0, p1);
|
||||
let p2 = _mm256_add_epi32(p2, p3);
|
||||
|
||||
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2));
|
||||
}
|
||||
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
|
||||
}
|
||||
|
||||
Ok(hsum_float_8(acc))
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Result<f32> {
|
||||
if n % QK_K != 0 {
|
||||
crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}")
|
||||
}
|
||||
|
||||
const KMASK1: u32 = 0x03030303;
|
||||
const KMASK2: u32 = 0x0f0f0f0f;
|
||||
|
||||
let mut aux = [0u32; 3];
|
||||
|
||||
unsafe {
|
||||
let m3 = _mm256_set1_epi8(3);
|
||||
let mone = _mm256_set1_epi8(1);
|
||||
let m32 = _mm_set1_epi8(32);
|
||||
|
||||
let mut acc = _mm256_setzero_ps();
|
||||
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||
let d = y.d * x.d.to_f32();
|
||||
|
||||
let mut q3 = x.qs.as_ptr();
|
||||
let mut q8 = y.qs.as_ptr();
|
||||
|
||||
LittleEndian::read_u32_into(&x.scales, &mut aux);
|
||||
let scales128 = _mm_set_epi32(
|
||||
(((aux[1] >> 4) & KMASK2) | (((aux[2] >> 6) & KMASK1) << 4)) as i32,
|
||||
(((aux[0] >> 4) & KMASK2) | (((aux[2] >> 4) & KMASK1) << 4)) as i32,
|
||||
((aux[1] & KMASK2) | (((aux[2] >> 2) & KMASK1) << 4)) as i32,
|
||||
((aux[0] & KMASK2) | (((aux[2]) & KMASK1) << 4)) as i32,
|
||||
);
|
||||
let scales128 = _mm_sub_epi8(scales128, m32);
|
||||
let all_scales = _mm256_cvtepi8_epi16(scales128);
|
||||
let l_scales = _mm256_extracti128_si256(all_scales, 0);
|
||||
let h_scales = _mm256_extracti128_si256(all_scales, 1);
|
||||
let scales = [
|
||||
mm256_set_m128i(l_scales, l_scales),
|
||||
mm256_set_m128i(h_scales, h_scales),
|
||||
];
|
||||
|
||||
// high bit
|
||||
let hbits = _mm256_loadu_si256(x.hmask.as_ptr() as *const __m256i);
|
||||
|
||||
let mut sumi = _mm256_setzero_si256();
|
||||
|
||||
for (j, scale) in scales.iter().enumerate() {
|
||||
// load low 2 bits
|
||||
let q3bits = _mm256_loadu_si256(q3 as *const __m256i);
|
||||
q3 = q3.add(32);
|
||||
|
||||
// Prepare low and high bits
|
||||
// We hardcode the shifts here to avoid loading them into a seperate register
|
||||
let q3l_0 = _mm256_and_si256(q3bits, m3);
|
||||
let q3h_0 = if j == 0 {
|
||||
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 0)), 0)
|
||||
} else {
|
||||
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 4)), 4)
|
||||
};
|
||||
let q3h_0 = _mm256_slli_epi16(q3h_0, 2);
|
||||
|
||||
let q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3);
|
||||
let q3h_1 = if j == 0 {
|
||||
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 1)), 1)
|
||||
} else {
|
||||
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 5)), 5)
|
||||
};
|
||||
let q3h_1 = _mm256_slli_epi16(q3h_1, 2);
|
||||
|
||||
let q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3);
|
||||
let q3h_2 = if j == 0 {
|
||||
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 2)), 2)
|
||||
} else {
|
||||
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 6)), 6)
|
||||
};
|
||||
let q3h_2 = _mm256_slli_epi16(q3h_2, 2);
|
||||
|
||||
let q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3);
|
||||
let q3h_3 = if j == 0 {
|
||||
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 3)), 3)
|
||||
} else {
|
||||
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 7)), 7)
|
||||
};
|
||||
let q3h_3 = _mm256_slli_epi16(q3h_3, 2);
|
||||
|
||||
// load Q8 quants
|
||||
let q8_0 = _mm256_loadu_si256(q8 as *const __m256i);
|
||||
q8 = q8.add(32);
|
||||
let q8_1 = _mm256_loadu_si256(q8 as *const __m256i);
|
||||
q8 = q8.add(32);
|
||||
let q8_2 = _mm256_loadu_si256(q8 as *const __m256i);
|
||||
q8 = q8.add(32);
|
||||
let q8_3 = _mm256_loadu_si256(q8 as *const __m256i);
|
||||
q8 = q8.add(32);
|
||||
|
||||
// Dot product: we multiply the 2 low bits and 1 high bit part separately, so we
|
||||
// can use _mm256_maddubs_epi16, and then subtract. The high bit part has the 2
|
||||
// already subtracted (and so, it is zero if the high bit was not set, and 2 if the
|
||||
// high bit was set)
|
||||
let q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
|
||||
let q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
|
||||
let q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2);
|
||||
let q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3);
|
||||
|
||||
let p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
|
||||
let p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
|
||||
let p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2);
|
||||
let p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3);
|
||||
|
||||
let p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
|
||||
let p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
|
||||
let p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
|
||||
let p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
|
||||
|
||||
// multiply with scales
|
||||
let p16_0 =
|
||||
_mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(0)), p16_0);
|
||||
let p16_1 =
|
||||
_mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(1)), p16_1);
|
||||
let p16_2 =
|
||||
_mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(2)), p16_2);
|
||||
let p16_3 =
|
||||
_mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(3)), p16_3);
|
||||
|
||||
// accumulate
|
||||
let p16_0 = _mm256_add_epi32(p16_0, p16_1);
|
||||
let p16_2 = _mm256_add_epi32(p16_2, p16_3);
|
||||
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2));
|
||||
}
|
||||
|
||||
// multiply with block scale and accumulate
|
||||
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
|
||||
}
|
||||
Ok(hsum_float_8(acc))
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
|
||||
if n % QK_K != 0 {
|
||||
crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
|
||||
}
|
||||
let mut utmp = [0u32; 4];
|
||||
const KMASK1: u32 = 0x3f3f3f3f;
|
||||
const KMASK2: u32 = 0x0f0f0f0f;
|
||||
const KMASK3: u32 = 0x03030303;
|
||||
|
||||
unsafe {
|
||||
let m4 = _mm256_set1_epi8(0xF);
|
||||
|
||||
let mut acc = _mm256_setzero_ps();
|
||||
let mut acc_m = _mm_setzero_ps();
|
||||
|
||||
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||
let d = y.d * x.d.to_f32();
|
||||
let dmin = -y.d * x.dmin.to_f32();
|
||||
|
||||
LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
|
||||
|
||||
utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
|
||||
let uaux = utmp[1] & KMASK1;
|
||||
utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
|
||||
utmp[2] = uaux;
|
||||
utmp[0] &= KMASK1;
|
||||
|
||||
let mut q4 = x.qs.as_ptr();
|
||||
let mut q8 = y.qs.as_ptr();
|
||||
|
||||
let mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(
|
||||
utmp[3] as i32,
|
||||
utmp[2] as i32,
|
||||
utmp[1] as i32,
|
||||
utmp[0] as i32,
|
||||
));
|
||||
|
||||
let q8sums = _mm256_loadu_si256(y.bsums.as_ptr() as *const __m256i);
|
||||
let q8s = _mm_hadd_epi16(
|
||||
_mm256_extracti128_si256(q8sums, 0),
|
||||
_mm256_extracti128_si256(q8sums, 1),
|
||||
);
|
||||
let prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
|
||||
acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);
|
||||
|
||||
let sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
|
||||
let scales = mm256_set_m128i(sc128, sc128);
|
||||
|
||||
let mut sumi = _mm256_setzero_si256();
|
||||
|
||||
for j in 0..QK_K / 64 {
|
||||
let scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j));
|
||||
let scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j + 1));
|
||||
|
||||
let q4bits = _mm256_loadu_si256(q4 as *const __m256i);
|
||||
q4 = q4.add(32);
|
||||
let q4l = _mm256_and_si256(q4bits, m4);
|
||||
let q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
|
||||
|
||||
let q8l = _mm256_loadu_si256(q8 as *const __m256i);
|
||||
q8 = q8.add(32);
|
||||
let p16l = _mm256_maddubs_epi16(q4l, q8l);
|
||||
let p16l = _mm256_madd_epi16(scale_l, p16l);
|
||||
sumi = _mm256_add_epi32(sumi, p16l);
|
||||
|
||||
let q8h = _mm256_loadu_si256(q8 as *const __m256i);
|
||||
q8 = q8.add(32);
|
||||
let p16h = _mm256_maddubs_epi16(q4h, q8h);
|
||||
let p16h = _mm256_madd_epi16(scale_h, p16h);
|
||||
sumi = _mm256_add_epi32(sumi, p16h);
|
||||
}
|
||||
|
||||
let vd = _mm256_set1_ps(d);
|
||||
acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
|
||||
}
|
||||
|
||||
let acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
|
||||
let acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
|
||||
|
||||
Ok(hsum_float_8(acc) + _mm_cvtss_f32(acc_m))
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Result<f32> {
|
||||
if n % QK_K != 0 {
|
||||
crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}")
|
||||
}
|
||||
let mut utmp = [0u32; 4];
|
||||
const KMASK1: u32 = 0x3f3f3f3f;
|
||||
const KMASK2: u32 = 0x0f0f0f0f;
|
||||
const KMASK3: u32 = 0x03030303;
|
||||
|
||||
unsafe {
|
||||
let m4 = _mm256_set1_epi8(0xF);
|
||||
let mzero = _mm_setzero_si128();
|
||||
let mone = _mm256_set1_epi8(1);
|
||||
|
||||
let mut acc = _mm256_setzero_ps();
|
||||
let mut summs = 0.0;
|
||||
|
||||
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||
let d = y.d * x.d.to_f32();
|
||||
let dmin = -y.d * x.dmin.to_f32();
|
||||
|
||||
LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
|
||||
|
||||
utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
|
||||
let uaux = utmp[1] & KMASK1;
|
||||
utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
|
||||
utmp[2] = uaux;
|
||||
utmp[0] &= KMASK1;
|
||||
|
||||
let mut q5 = x.qs.as_ptr();
|
||||
let mut q8 = y.qs.as_ptr();
|
||||
|
||||
let mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(
|
||||
utmp[3] as i32,
|
||||
utmp[2] as i32,
|
||||
utmp[1] as i32,
|
||||
utmp[0] as i32,
|
||||
));
|
||||
|
||||
let q8sums = _mm256_loadu_si256(y.bsums.as_ptr() as *const __m256i);
|
||||
let q8s = _mm_hadd_epi16(
|
||||
_mm256_extracti128_si256(q8sums, 0),
|
||||
_mm256_extracti128_si256(q8sums, 1),
|
||||
);
|
||||
let prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
|
||||
let hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
|
||||
summs += dmin * _mm_extract_epi32(hsum, 0) as f32;
|
||||
|
||||
let sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
|
||||
let scales = mm256_set_m128i(sc128, sc128);
|
||||
|
||||
let hbits = _mm256_loadu_si256(x.qh.as_ptr() as *const __m256i);
|
||||
let mut hmask = mone;
|
||||
|
||||
let mut sumi = _mm256_setzero_si256();
|
||||
|
||||
for j in 0..QK_K / 64 {
|
||||
let scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j));
|
||||
let scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j + 1));
|
||||
|
||||
let q5bits = _mm256_loadu_si256(q5 as *const __m256i);
|
||||
q5 = q5.add(32);
|
||||
|
||||
//Similar to q3k we hardcode the shifts here to avoid loading them into a seperate register
|
||||
let q5l_0 = _mm256_and_si256(q5bits, m4);
|
||||
let q5l_0_shift_input = _mm256_and_si256(hbits, hmask);
|
||||
let q5l_0_right_shift = match j {
|
||||
0 => _mm256_srli_epi16(q5l_0_shift_input, 0),
|
||||
1 => _mm256_srli_epi16(q5l_0_shift_input, 2),
|
||||
2 => _mm256_srli_epi16(q5l_0_shift_input, 4),
|
||||
3 => _mm256_srli_epi16(q5l_0_shift_input, 6),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let q5h_0 = _mm256_slli_epi16(q5l_0_right_shift, 4);
|
||||
let q5_0 = _mm256_add_epi8(q5l_0, q5h_0);
|
||||
hmask = _mm256_slli_epi16(hmask, 1);
|
||||
|
||||
let q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
|
||||
let q5l_1_shift_input = _mm256_and_si256(hbits, hmask);
|
||||
let q5l_1_right_shift = match j {
|
||||
0 => _mm256_srli_epi16(q5l_1_shift_input, 1),
|
||||
1 => _mm256_srli_epi16(q5l_1_shift_input, 3),
|
||||
2 => _mm256_srli_epi16(q5l_1_shift_input, 5),
|
||||
3 => _mm256_srli_epi16(q5l_1_shift_input, 7),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let q5h_1 = _mm256_slli_epi16(q5l_1_right_shift, 4);
|
||||
let q5_1 = _mm256_add_epi8(q5l_1, q5h_1);
|
||||
hmask = _mm256_slli_epi16(hmask, 1);
|
||||
|
||||
let q8_0 = _mm256_loadu_si256(q8 as *const __m256i);
|
||||
q8 = q8.add(32);
|
||||
let q8_1 = _mm256_loadu_si256(q8 as *const __m256i);
|
||||
q8 = q8.add(32);
|
||||
|
||||
let p16_0 = _mm256_maddubs_epi16(q5_0, q8_0);
|
||||
let p16_1 = _mm256_maddubs_epi16(q5_1, q8_1);
|
||||
|
||||
let p16_0 = _mm256_madd_epi16(scale_0, p16_0);
|
||||
let p16_1 = _mm256_madd_epi16(scale_1, p16_1);
|
||||
|
||||
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
|
||||
}
|
||||
let vd = _mm256_set1_ps(d);
|
||||
acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
|
||||
}
|
||||
Ok(hsum_float_8(acc) + summs)
|
||||
}
|
||||
}
|
@ -3,7 +3,6 @@
|
||||
use super::{k_quants, GgmlDType};
|
||||
use crate::Result;
|
||||
use byteorder::{LittleEndian, ReadBytesExt};
|
||||
use std::collections::HashMap;
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.h#L37
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
@ -125,7 +124,7 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
|
||||
let raw_data_ptr = raw_data.as_ptr();
|
||||
let n_blocks = size_in_bytes / std::mem::size_of::<T>();
|
||||
let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
|
||||
super::QTensor::new(data.to_vec(), dims)
|
||||
Ok(super::QTensor::new(data.to_vec(), dims))
|
||||
}
|
||||
|
||||
/// Creates a [Tensor] from a raw GGML tensor.
|
||||
@ -164,9 +163,6 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
|
||||
let ggml_dtype = GgmlDType::from_u32(ggml_dtype)?;
|
||||
let mut dims = vec![0u32; n_dims as usize];
|
||||
reader.read_u32_into::<LittleEndian>(&mut dims)?;
|
||||
// The dimensions are stored in reverse order, see for example:
|
||||
// https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/convert.py#L969
|
||||
dims.reverse();
|
||||
let mut name = vec![0u8; name_len as usize];
|
||||
reader.read_exact(&mut name)?;
|
||||
let name = String::from_utf8_lossy(&name).into_owned();
|
||||
@ -178,6 +174,7 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
|
||||
let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>();
|
||||
let tensor_elems = dims.iter().product::<usize>();
|
||||
let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
|
||||
println!("{name} {ggml_dtype:?} {dims:?}");
|
||||
// TODO: Mmap version to avoid copying the data around?
|
||||
let mut raw_data = vec![0u8; size_in_bytes];
|
||||
reader.read_exact(&mut raw_data)?;
|
||||
@ -191,7 +188,7 @@ pub struct Content {
|
||||
pub magic: VersionedMagic,
|
||||
pub hparams: HParams,
|
||||
pub vocab: Vocab,
|
||||
pub tensors: HashMap<String, super::QTensor>,
|
||||
pub tensors: Vec<(String, super::QTensor)>,
|
||||
}
|
||||
|
||||
impl Content {
|
||||
@ -202,11 +199,11 @@ impl Content {
|
||||
let magic = VersionedMagic::read(reader)?;
|
||||
let hparams = HParams::read(reader)?;
|
||||
let vocab = Vocab::read(reader, hparams.n_vocab as usize)?;
|
||||
let mut tensors = HashMap::new();
|
||||
let mut tensors = vec![];
|
||||
|
||||
while reader.stream_position()? != last_position {
|
||||
let (name, tensor) = read_one_tensor(reader, magic)?;
|
||||
tensors.insert(name, tensor);
|
||||
tensors.push((name, tensor))
|
||||
}
|
||||
Ok(Self {
|
||||
magic,
|
||||
@ -215,11 +212,4 @@ impl Content {
|
||||
tensors,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn remove(&mut self, name: &str) -> Result<super::QTensor> {
|
||||
match self.tensors.remove(name) {
|
||||
None => crate::bail!("cannot find tensor with name '{name}'"),
|
||||
Some(tensor) => Ok(tensor),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,513 +0,0 @@
|
||||
//! Support for the GGUF file format.
|
||||
//!
|
||||
//! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md
|
||||
|
||||
use super::{GgmlDType, QTensor};
|
||||
use crate::Result;
|
||||
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub const DEFAULT_ALIGNMENT: u64 = 32;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum Magic {
|
||||
Gguf,
|
||||
}
|
||||
|
||||
impl TryFrom<u32> for Magic {
|
||||
type Error = crate::Error;
|
||||
fn try_from(value: u32) -> Result<Self> {
|
||||
let magic = match value {
|
||||
0x46554747 | 0x47475546 => Self::Gguf,
|
||||
_ => crate::bail!("unknown magic 0x{value:08x}"),
|
||||
};
|
||||
Ok(magic)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum VersionedMagic {
|
||||
GgufV1,
|
||||
GgufV2,
|
||||
}
|
||||
|
||||
impl VersionedMagic {
|
||||
fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> {
|
||||
let magic = reader.read_u32::<LittleEndian>()?;
|
||||
let magic = Magic::try_from(magic)?;
|
||||
let version = reader.read_u32::<LittleEndian>()?;
|
||||
let versioned_magic = match (magic, version) {
|
||||
(Magic::Gguf, 1) => Self::GgufV1,
|
||||
(Magic::Gguf, 2) => Self::GgufV2,
|
||||
_ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"),
|
||||
};
|
||||
Ok(versioned_magic)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TensorInfo {
|
||||
pub ggml_dtype: GgmlDType,
|
||||
pub shape: crate::Shape,
|
||||
pub offset: u64,
|
||||
}
|
||||
|
||||
impl TensorInfo {
|
||||
pub fn read<R: std::io::Seek + std::io::Read>(
|
||||
&self,
|
||||
reader: &mut R,
|
||||
tensor_data_offset: u64,
|
||||
) -> Result<QTensor> {
|
||||
let tensor_elems = self.shape.elem_count();
|
||||
let size_in_bytes =
|
||||
tensor_elems * self.ggml_dtype.type_size() / self.ggml_dtype.blck_size();
|
||||
let mut raw_data = vec![0u8; size_in_bytes];
|
||||
reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
|
||||
reader.read_exact(&mut raw_data)?;
|
||||
super::ggml_file::qtensor_from_ggml(self.ggml_dtype, &raw_data, self.shape.dims().to_vec())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Content {
|
||||
pub magic: VersionedMagic,
|
||||
pub metadata: HashMap<String, Value>,
|
||||
pub tensor_infos: HashMap<String, TensorInfo>,
|
||||
pub tensor_data_offset: u64,
|
||||
}
|
||||
|
||||
fn read_string<R: std::io::Read>(reader: &mut R, magic: &VersionedMagic) -> Result<String> {
|
||||
let len = match magic {
|
||||
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
|
||||
VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
|
||||
};
|
||||
let mut v = vec![0u8; len];
|
||||
reader.read_exact(&mut v)?;
|
||||
// GGUF strings are supposed to be non-null terminated but in practice this happens.
|
||||
while let Some(0) = v.last() {
|
||||
v.pop();
|
||||
}
|
||||
// GGUF strings are utf8 encoded but there are cases that don't seem to be valid.
|
||||
Ok(String::from_utf8_lossy(&v).into_owned())
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum ValueType {
|
||||
// The value is a 8-bit unsigned integer.
|
||||
U8,
|
||||
// The value is a 8-bit signed integer.
|
||||
I8,
|
||||
// The value is a 16-bit unsigned little-endian integer.
|
||||
U16,
|
||||
// The value is a 16-bit signed little-endian integer.
|
||||
I16,
|
||||
// The value is a 32-bit unsigned little-endian integer.
|
||||
U32,
|
||||
// The value is a 32-bit signed little-endian integer.
|
||||
I32,
|
||||
// The value is a 64-bit unsigned little-endian integer.
|
||||
U64,
|
||||
// The value is a 64-bit signed little-endian integer.
|
||||
I64,
|
||||
// The value is a 32-bit IEEE754 floating point number.
|
||||
F32,
|
||||
// The value is a 64-bit IEEE754 floating point number.
|
||||
F64,
|
||||
// The value is a boolean.
|
||||
// 1-byte value where 0 is false and 1 is true.
|
||||
// Anything else is invalid, and should be treated as either the model being invalid or the reader being buggy.
|
||||
Bool,
|
||||
// The value is a UTF-8 non-null-terminated string, with length prepended.
|
||||
String,
|
||||
// The value is an array of other values, with the length and type prepended.
|
||||
///
|
||||
// Arrays can be nested, and the length of the array is the number of elements in the array, not the number of bytes.
|
||||
Array,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Value {
|
||||
U8(u8),
|
||||
I8(i8),
|
||||
U16(u16),
|
||||
I16(i16),
|
||||
U32(u32),
|
||||
I32(i32),
|
||||
U64(u64),
|
||||
I64(i64),
|
||||
F32(f32),
|
||||
F64(f64),
|
||||
Bool(bool),
|
||||
String(String),
|
||||
Array(Vec<Value>),
|
||||
}
|
||||
|
||||
impl Value {
|
||||
pub fn value_type(&self) -> ValueType {
|
||||
match self {
|
||||
Self::U8(_) => ValueType::U8,
|
||||
Self::I8(_) => ValueType::I8,
|
||||
Self::U16(_) => ValueType::U16,
|
||||
Self::I16(_) => ValueType::I16,
|
||||
Self::U32(_) => ValueType::U32,
|
||||
Self::I32(_) => ValueType::I32,
|
||||
Self::U64(_) => ValueType::U64,
|
||||
Self::I64(_) => ValueType::I64,
|
||||
Self::F32(_) => ValueType::F32,
|
||||
Self::F64(_) => ValueType::F64,
|
||||
Self::Bool(_) => ValueType::Bool,
|
||||
Self::String(_) => ValueType::String,
|
||||
Self::Array(_) => ValueType::Array,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_u8(&self) -> Result<u8> {
|
||||
match self {
|
||||
Self::U8(v) => Ok(*v),
|
||||
v => crate::bail!("not a u8 {v:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_i8(&self) -> Result<i8> {
|
||||
match self {
|
||||
Self::I8(v) => Ok(*v),
|
||||
v => crate::bail!("not a i8 {v:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_u16(&self) -> Result<u16> {
|
||||
match self {
|
||||
Self::U16(v) => Ok(*v),
|
||||
v => crate::bail!("not a u16 {v:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_i16(&self) -> Result<i16> {
|
||||
match self {
|
||||
Self::I16(v) => Ok(*v),
|
||||
v => crate::bail!("not a i16 {v:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_u32(&self) -> Result<u32> {
|
||||
match self {
|
||||
Self::U32(v) => Ok(*v),
|
||||
v => crate::bail!("not a u32 {v:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_i32(&self) -> Result<i32> {
|
||||
match self {
|
||||
Self::I32(v) => Ok(*v),
|
||||
v => crate::bail!("not a i32 {v:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_u64(&self) -> Result<u64> {
|
||||
match self {
|
||||
Self::U64(v) => Ok(*v),
|
||||
v => crate::bail!("not a u64 {v:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_i64(&self) -> Result<i64> {
|
||||
match self {
|
||||
Self::I64(v) => Ok(*v),
|
||||
v => crate::bail!("not a i64 {v:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_f32(&self) -> Result<f32> {
|
||||
match self {
|
||||
Self::F32(v) => Ok(*v),
|
||||
v => crate::bail!("not a f32 {v:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_f64(&self) -> Result<f64> {
|
||||
match self {
|
||||
Self::F64(v) => Ok(*v),
|
||||
v => crate::bail!("not a f64 {v:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_bool(&self) -> Result<bool> {
|
||||
match self {
|
||||
Self::Bool(v) => Ok(*v),
|
||||
v => crate::bail!("not a bool {v:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_vec(&self) -> Result<&Vec<Value>> {
|
||||
match self {
|
||||
Self::Array(v) => Ok(v),
|
||||
v => crate::bail!("not a vec {v:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_string(&self) -> Result<&String> {
|
||||
match self {
|
||||
Self::String(v) => Ok(v),
|
||||
v => crate::bail!("not a string {v:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn read<R: std::io::Read>(
|
||||
reader: &mut R,
|
||||
value_type: ValueType,
|
||||
magic: &VersionedMagic,
|
||||
) -> Result<Self> {
|
||||
let v = match value_type {
|
||||
ValueType::U8 => Self::U8(reader.read_u8()?),
|
||||
ValueType::I8 => Self::I8(reader.read_i8()?),
|
||||
ValueType::U16 => Self::U16(reader.read_u16::<LittleEndian>()?),
|
||||
ValueType::I16 => Self::I16(reader.read_i16::<LittleEndian>()?),
|
||||
ValueType::U32 => Self::U32(reader.read_u32::<LittleEndian>()?),
|
||||
ValueType::I32 => Self::I32(reader.read_i32::<LittleEndian>()?),
|
||||
ValueType::U64 => Self::U64(reader.read_u64::<LittleEndian>()?),
|
||||
ValueType::I64 => Self::I64(reader.read_i64::<LittleEndian>()?),
|
||||
ValueType::F32 => Self::F32(reader.read_f32::<LittleEndian>()?),
|
||||
ValueType::F64 => Self::F64(reader.read_f64::<LittleEndian>()?),
|
||||
ValueType::Bool => match reader.read_u8()? {
|
||||
0 => Self::Bool(false),
|
||||
1 => Self::Bool(true),
|
||||
b => crate::bail!("unexpected bool value {b}"),
|
||||
},
|
||||
ValueType::String => Self::String(read_string(reader, magic)?),
|
||||
ValueType::Array => {
|
||||
let value_type = reader.read_u32::<LittleEndian>()?;
|
||||
let value_type = ValueType::from_u32(value_type)?;
|
||||
let len = match magic {
|
||||
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
|
||||
VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
|
||||
};
|
||||
let mut vs = Vec::with_capacity(len);
|
||||
for _ in 0..len {
|
||||
vs.push(Value::read(reader, value_type, magic)?)
|
||||
}
|
||||
Self::Array(vs)
|
||||
}
|
||||
};
|
||||
Ok(v)
|
||||
}
|
||||
|
||||
fn write<W: std::io::Write>(&self, w: &mut W) -> Result<()> {
|
||||
match self {
|
||||
&Self::U8(v) => w.write_u8(v)?,
|
||||
&Self::I8(v) => w.write_i8(v)?,
|
||||
&Self::U16(v) => w.write_u16::<LittleEndian>(v)?,
|
||||
&Self::I16(v) => w.write_i16::<LittleEndian>(v)?,
|
||||
&Self::U32(v) => w.write_u32::<LittleEndian>(v)?,
|
||||
&Self::I32(v) => w.write_i32::<LittleEndian>(v)?,
|
||||
&Self::U64(v) => w.write_u64::<LittleEndian>(v)?,
|
||||
&Self::I64(v) => w.write_i64::<LittleEndian>(v)?,
|
||||
&Self::F32(v) => w.write_f32::<LittleEndian>(v)?,
|
||||
&Self::F64(v) => w.write_f64::<LittleEndian>(v)?,
|
||||
&Self::Bool(v) => w.write_u8(u8::from(v))?,
|
||||
Self::String(v) => write_string(w, v.as_str())?,
|
||||
Self::Array(v) => {
|
||||
// The `Value` type does not enforce that all the values in an Array have the same
|
||||
// type.
|
||||
let value_type = if v.is_empty() {
|
||||
// Doesn't matter, the array is empty.
|
||||
ValueType::U32
|
||||
} else {
|
||||
let value_type: std::collections::HashSet<_> =
|
||||
v.iter().map(|elem| elem.value_type()).collect();
|
||||
if value_type.len() != 1 {
|
||||
crate::bail!("multiple value-types in the same array {value_type:?}")
|
||||
}
|
||||
value_type.into_iter().next().unwrap()
|
||||
};
|
||||
w.write_u32::<LittleEndian>(value_type.to_u32())?;
|
||||
w.write_u64::<LittleEndian>(v.len() as u64)?;
|
||||
for elem in v.iter() {
|
||||
elem.write(w)?
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl ValueType {
|
||||
fn from_u32(v: u32) -> Result<Self> {
|
||||
let v = match v {
|
||||
0 => Self::U8,
|
||||
1 => Self::I8,
|
||||
2 => Self::U16,
|
||||
3 => Self::I16,
|
||||
4 => Self::U32,
|
||||
5 => Self::I32,
|
||||
6 => Self::F32,
|
||||
7 => Self::Bool,
|
||||
8 => Self::String,
|
||||
9 => Self::Array,
|
||||
10 => Self::U64,
|
||||
11 => Self::I64,
|
||||
12 => Self::F64,
|
||||
v => crate::bail!("unrecognized value-type {v:#08x}"),
|
||||
};
|
||||
Ok(v)
|
||||
}
|
||||
|
||||
fn to_u32(self) -> u32 {
|
||||
match self {
|
||||
Self::U8 => 0,
|
||||
Self::I8 => 1,
|
||||
Self::U16 => 2,
|
||||
Self::I16 => 3,
|
||||
Self::U32 => 4,
|
||||
Self::I32 => 5,
|
||||
Self::F32 => 6,
|
||||
Self::Bool => 7,
|
||||
Self::String => 8,
|
||||
Self::Array => 9,
|
||||
Self::U64 => 10,
|
||||
Self::I64 => 11,
|
||||
Self::F64 => 12,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Content {
|
||||
pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Self> {
|
||||
let magic = VersionedMagic::read(reader)?;
|
||||
|
||||
let tensor_count = match magic {
|
||||
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
|
||||
VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
|
||||
};
|
||||
let metadata_kv_count = match magic {
|
||||
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
|
||||
VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
|
||||
};
|
||||
|
||||
let mut metadata = HashMap::new();
|
||||
for _idx in 0..metadata_kv_count {
|
||||
let key = read_string(reader, &magic)?;
|
||||
let value_type = reader.read_u32::<LittleEndian>()?;
|
||||
let value_type = ValueType::from_u32(value_type)?;
|
||||
let value = Value::read(reader, value_type, &magic)?;
|
||||
metadata.insert(key, value);
|
||||
}
|
||||
let mut tensor_infos = HashMap::new();
|
||||
for _idx in 0..tensor_count {
|
||||
let tensor_name = read_string(reader, &magic)?;
|
||||
let n_dimensions = reader.read_u32::<LittleEndian>()?;
|
||||
|
||||
let mut dimensions: Vec<usize> = match magic {
|
||||
VersionedMagic::GgufV1 => {
|
||||
let mut dimensions = vec![0; n_dimensions as usize];
|
||||
reader.read_u32_into::<LittleEndian>(&mut dimensions)?;
|
||||
dimensions.into_iter().map(|c| c as usize).collect()
|
||||
}
|
||||
VersionedMagic::GgufV2 => {
|
||||
let mut dimensions = vec![0; n_dimensions as usize];
|
||||
reader.read_u64_into::<LittleEndian>(&mut dimensions)?;
|
||||
dimensions.into_iter().map(|c| c as usize).collect()
|
||||
}
|
||||
};
|
||||
|
||||
dimensions.reverse();
|
||||
let ggml_dtype = reader.read_u32::<LittleEndian>()?;
|
||||
let ggml_dtype = GgmlDType::from_u32(ggml_dtype)?;
|
||||
let offset = reader.read_u64::<LittleEndian>()?;
|
||||
tensor_infos.insert(
|
||||
tensor_name,
|
||||
TensorInfo {
|
||||
shape: crate::Shape::from(dimensions),
|
||||
offset,
|
||||
ggml_dtype,
|
||||
},
|
||||
);
|
||||
}
|
||||
let position = reader.stream_position()?;
|
||||
let alignment = match metadata.get("general.alignment") {
|
||||
Some(Value::U8(v)) => *v as u64,
|
||||
Some(Value::U16(v)) => *v as u64,
|
||||
Some(Value::U32(v)) => *v as u64,
|
||||
Some(Value::I8(v)) if *v >= 0 => *v as u64,
|
||||
Some(Value::I16(v)) if *v >= 0 => *v as u64,
|
||||
Some(Value::I32(v)) if *v >= 0 => *v as u64,
|
||||
_ => DEFAULT_ALIGNMENT,
|
||||
};
|
||||
let tensor_data_offset = (position + alignment - 1) / alignment * alignment;
|
||||
Ok(Self {
|
||||
magic,
|
||||
metadata,
|
||||
tensor_infos,
|
||||
tensor_data_offset,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn tensor<R: std::io::Seek + std::io::Read>(
|
||||
&self,
|
||||
reader: &mut R,
|
||||
name: &str,
|
||||
) -> Result<QTensor> {
|
||||
let tensor_info = match self.tensor_infos.get(name) {
|
||||
Some(tensor_info) => tensor_info,
|
||||
None => crate::bail!("cannot find tensor-infor for {name}"),
|
||||
};
|
||||
tensor_info.read(reader, self.tensor_data_offset)
|
||||
}
|
||||
}
|
||||
|
||||
fn write_string<W: std::io::Write>(w: &mut W, str: &str) -> Result<()> {
|
||||
let bytes = str.as_bytes();
|
||||
w.write_u64::<LittleEndian>(bytes.len() as u64)?;
|
||||
w.write_all(bytes)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn write<W: std::io::Seek + std::io::Write>(
|
||||
w: &mut W,
|
||||
metadata: &[(&str, &Value)],
|
||||
tensors: &[(&str, &QTensor)],
|
||||
) -> Result<()> {
|
||||
w.write_u32::<LittleEndian>(0x46554747)?;
|
||||
w.write_u32::<LittleEndian>(2)?; // version 2.
|
||||
w.write_u64::<LittleEndian>(tensors.len() as u64)?;
|
||||
w.write_u64::<LittleEndian>(metadata.len() as u64)?;
|
||||
for (name, value) in metadata.iter() {
|
||||
write_string(w, name)?;
|
||||
w.write_u32::<LittleEndian>(value.value_type().to_u32())?;
|
||||
value.write(w)?;
|
||||
}
|
||||
let mut offset = 0usize;
|
||||
let mut offsets = Vec::with_capacity(tensors.len());
|
||||
for (name, tensor) in tensors.iter() {
|
||||
write_string(w, name)?;
|
||||
let dims = tensor.shape().dims();
|
||||
w.write_u32::<LittleEndian>(dims.len() as u32)?;
|
||||
for &dim in dims.iter().rev() {
|
||||
w.write_u64::<LittleEndian>(dim as u64)?;
|
||||
}
|
||||
w.write_u32::<LittleEndian>(tensor.dtype().to_u32())?;
|
||||
w.write_u64::<LittleEndian>(offset as u64)?;
|
||||
offsets.push(offset);
|
||||
let size_in_bytes = tensor.storage_size_in_bytes();
|
||||
let padding = 31 - (31 + size_in_bytes) % 32;
|
||||
offset += size_in_bytes + padding;
|
||||
}
|
||||
let pos = w.stream_position()? as usize;
|
||||
let padding = 31 - (31 + pos) % 32;
|
||||
w.write_all(&vec![0u8; padding])?;
|
||||
let tensor_start_pos = w.stream_position()? as usize;
|
||||
for (offset, (_name, tensor)) in offsets.iter().zip(tensors.iter()) {
|
||||
let pos = w.stream_position()? as usize;
|
||||
if tensor_start_pos + offset != pos {
|
||||
crate::bail!(
|
||||
"internal error, unexpected current position {tensor_start_pos} {offset} {pos}"
|
||||
)
|
||||
}
|
||||
let data_ptr = tensor.as_ptr();
|
||||
let size_in_bytes = tensor.storage_size_in_bytes();
|
||||
let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
|
||||
w.write_all(data)?;
|
||||
let padding = 31 - (31 + size_in_bytes) % 32;
|
||||
w.write_all(&vec![0u8; padding])?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -1,13 +1,7 @@
|
||||
use crate::{Device, Result, Shape, Tensor};
|
||||
|
||||
#[cfg(target_feature = "avx")]
|
||||
pub mod avx;
|
||||
pub mod ggml_file;
|
||||
pub mod gguf_file;
|
||||
pub mod k_quants;
|
||||
#[cfg(target_feature = "neon")]
|
||||
pub mod neon;
|
||||
pub mod utils;
|
||||
|
||||
pub use k_quants::GgmlType;
|
||||
|
||||
@ -16,7 +10,7 @@ pub struct QTensor {
|
||||
shape: Shape,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum GgmlDType {
|
||||
F32,
|
||||
F16,
|
||||
@ -56,27 +50,7 @@ impl GgmlDType {
|
||||
Ok(dtype)
|
||||
}
|
||||
|
||||
pub(crate) fn to_u32(self) -> u32 {
|
||||
match self {
|
||||
Self::F32 => 0,
|
||||
Self::F16 => 1,
|
||||
Self::Q4_0 => 2,
|
||||
Self::Q4_1 => 3,
|
||||
Self::Q5_0 => 6,
|
||||
Self::Q5_1 => 7,
|
||||
Self::Q8_0 => 8,
|
||||
Self::Q8_1 => 9,
|
||||
Self::Q2K => 10,
|
||||
Self::Q3K => 11,
|
||||
Self::Q4K => 12,
|
||||
Self::Q5K => 13,
|
||||
Self::Q6K => 14,
|
||||
Self::Q8K => 15,
|
||||
}
|
||||
}
|
||||
|
||||
/// The type size for blocks in bytes.
|
||||
pub fn type_size(&self) -> usize {
|
||||
fn type_size(&self) -> usize {
|
||||
use k_quants::*;
|
||||
match self {
|
||||
Self::F32 => 4,
|
||||
@ -97,8 +71,7 @@ impl GgmlDType {
|
||||
}
|
||||
}
|
||||
|
||||
/// The block size, i.e. the number of elements stored in each block.
|
||||
pub fn blck_size(&self) -> usize {
|
||||
fn blck_size(&self) -> usize {
|
||||
match self {
|
||||
Self::F32 => 1,
|
||||
Self::F16 => 1,
|
||||
@ -118,8 +91,6 @@ pub trait QuantizedType: Send + Sync {
|
||||
fn dtype(&self) -> GgmlDType;
|
||||
fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
|
||||
fn to_float(&self, ys: &mut [f32]) -> Result<()>;
|
||||
fn storage_size_in_bytes(&self) -> usize;
|
||||
fn as_ptr(&self) -> *const u8;
|
||||
}
|
||||
|
||||
impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
|
||||
@ -134,14 +105,6 @@ impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
|
||||
fn to_float(&self, ys: &mut [f32]) -> Result<()> {
|
||||
T::to_float(self.as_slice(), ys)
|
||||
}
|
||||
|
||||
fn storage_size_in_bytes(&self) -> usize {
|
||||
self.len() * std::mem::size_of::<T>()
|
||||
}
|
||||
|
||||
fn as_ptr(&self) -> *const u8 {
|
||||
self.as_ptr() as *const u8
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for QTensor {
|
||||
@ -150,62 +113,21 @@ impl std::fmt::Debug for QTensor {
|
||||
}
|
||||
}
|
||||
|
||||
fn check_shape<T: k_quants::GgmlType>(shape: &Shape) -> Result<()> {
|
||||
let dims = shape.dims();
|
||||
if dims.is_empty() {
|
||||
crate::bail!("scalar tensor cannot be quantized {shape:?}")
|
||||
}
|
||||
if dims[dims.len() - 1] % T::BLCK_SIZE != 0 {
|
||||
crate::bail!(
|
||||
"quantized tensor must have their last dim divisible by block size {shape:?} {}",
|
||||
T::BLCK_SIZE
|
||||
)
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl QTensor {
|
||||
pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
|
||||
data: Vec<T>,
|
||||
shape: S,
|
||||
) -> Result<Self> {
|
||||
let shape = shape.into();
|
||||
check_shape::<T>(&shape)?;
|
||||
Ok(Self {
|
||||
) -> Self {
|
||||
Self {
|
||||
data: Box::new(data),
|
||||
shape,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn quantize<T: k_quants::GgmlType + Send + Sync + 'static>(src: &Tensor) -> Result<Self> {
|
||||
let shape = src.shape();
|
||||
check_shape::<T>(shape)?;
|
||||
let src = src
|
||||
.to_dtype(crate::DType::F32)?
|
||||
.flatten_all()?
|
||||
.to_vec1::<f32>()?;
|
||||
if src.len() % T::BLCK_SIZE != 0 {
|
||||
crate::bail!(
|
||||
"tensor size ({shape:?}) is not divisible by block size {}",
|
||||
T::BLCK_SIZE
|
||||
)
|
||||
shape: shape.into(),
|
||||
}
|
||||
let mut data = vec![T::zeros(); src.len() / T::BLCK_SIZE];
|
||||
T::from_float(&src, &mut data)?;
|
||||
Ok(Self {
|
||||
data: Box::new(data),
|
||||
shape: shape.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn dtype(&self) -> GgmlDType {
|
||||
self.data.dtype()
|
||||
}
|
||||
|
||||
pub fn rank(&self) -> usize {
|
||||
self.shape.rank()
|
||||
}
|
||||
|
||||
pub fn shape(&self) -> &Shape {
|
||||
&self.shape
|
||||
}
|
||||
@ -219,34 +141,18 @@ impl QTensor {
|
||||
pub fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
|
||||
self.data.matmul_t(mkn, lhs, dst)
|
||||
}
|
||||
|
||||
pub fn storage_size_in_bytes(&self) -> usize {
|
||||
self.data.storage_size_in_bytes()
|
||||
}
|
||||
|
||||
pub fn as_ptr(&self) -> *const u8 {
|
||||
self.data.as_ptr()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct QMatMul(std::sync::Arc<QTensor>);
|
||||
|
||||
impl QMatMul {
|
||||
pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Self {
|
||||
pub fn new(qtensor: std::sync::Arc<QTensor>) -> Self {
|
||||
Self(qtensor)
|
||||
}
|
||||
|
||||
pub fn from_qtensor(qtensor: QTensor) -> Self {
|
||||
Self(std::sync::Arc::new(qtensor))
|
||||
}
|
||||
|
||||
pub fn inner(&self) -> &std::sync::Arc<QTensor> {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::CustomOp1 for QTensor {
|
||||
impl crate::CustomOp1 for QMatMul {
|
||||
fn name(&self) -> &'static str {
|
||||
"qmatmul"
|
||||
}
|
||||
@ -260,15 +166,17 @@ impl crate::CustomOp1 for QTensor {
|
||||
crate::bail!("input tensor is not contiguous {layout:?}")
|
||||
}
|
||||
let src_shape = layout.shape();
|
||||
// self is transposed so n is first then k.
|
||||
let (n, k) = self.shape.dims2()?;
|
||||
let (k, n) = self.0.shape.dims2()?;
|
||||
if src_shape.rank() < 2 {
|
||||
crate::bail!("input tensor has only one dimension {layout:?}")
|
||||
}
|
||||
let mut dst_shape = src_shape.dims().to_vec();
|
||||
let last_k = dst_shape.pop().unwrap();
|
||||
if last_k != k {
|
||||
crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
|
||||
crate::bail!(
|
||||
"input tensor {layout:?} incompatible with {:?}",
|
||||
self.0.shape
|
||||
)
|
||||
}
|
||||
dst_shape.push(n);
|
||||
let dst_shape = Shape::from(dst_shape);
|
||||
@ -276,7 +184,7 @@ impl crate::CustomOp1 for QTensor {
|
||||
let storage =
|
||||
&storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
|
||||
let mut dst_storage = vec![0f32; dst_shape.elem_count()];
|
||||
self.matmul_t(
|
||||
self.0.matmul_t(
|
||||
(dst_shape.elem_count() / n, k, n),
|
||||
storage,
|
||||
&mut dst_storage,
|
||||
@ -284,9 +192,3 @@ impl crate::CustomOp1 for QTensor {
|
||||
Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
|
||||
}
|
||||
}
|
||||
|
||||
impl QMatMul {
|
||||
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
xs.apply_op1_no_bwd(self.0.as_ref())
|
||||
}
|
||||
}
|
||||
|
@ -1,727 +0,0 @@
|
||||
use super::k_quants::{
|
||||
BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K,
|
||||
};
|
||||
use crate::Result;
|
||||
use byteorder::{ByteOrder, LittleEndian};
|
||||
|
||||
#[allow(unused_imports)]
|
||||
#[cfg(target_arch = "arm")]
|
||||
use core::arch::arm::*;
|
||||
|
||||
#[allow(unused_imports)]
|
||||
#[cfg(target_arch = "aarch64")]
|
||||
use core::arch::aarch64::*;
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
|
||||
let qk = QK8_0;
|
||||
let nb = n / qk;
|
||||
if n % QK8_0 != 0 {
|
||||
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
|
||||
}
|
||||
if nb % 2 != 0 {
|
||||
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
|
||||
}
|
||||
|
||||
unsafe {
|
||||
let mut sumv0 = vdupq_n_f32(0.0f32);
|
||||
let mut sumv1 = vdupq_n_f32(0.0f32);
|
||||
for i in (0..nb).step_by(2) {
|
||||
let x0 = &xs[i];
|
||||
let x1 = &xs[i + 1];
|
||||
let y0 = &ys[i];
|
||||
let y1 = &ys[i + 1];
|
||||
|
||||
let m4b = vdupq_n_u8(0x0F);
|
||||
let s8b = vdupq_n_s8(0x8);
|
||||
|
||||
let v0_0 = vld1q_u8(x0.qs.as_ptr());
|
||||
let v0_1 = vld1q_u8(x1.qs.as_ptr());
|
||||
|
||||
// 4-bit -> 8-bit
|
||||
let v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
|
||||
let v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
|
||||
let v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
|
||||
let v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
|
||||
|
||||
// sub 8
|
||||
let v0_0ls = vsubq_s8(v0_0l, s8b);
|
||||
let v0_0hs = vsubq_s8(v0_0h, s8b);
|
||||
let v0_1ls = vsubq_s8(v0_1l, s8b);
|
||||
let v0_1hs = vsubq_s8(v0_1h, s8b);
|
||||
|
||||
// load y
|
||||
let v1_0l = vld1q_s8(y0.qs.as_ptr());
|
||||
let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));
|
||||
let v1_1l = vld1q_s8(y1.qs.as_ptr());
|
||||
let v1_1h = vld1q_s8(y1.qs.as_ptr().add(16));
|
||||
|
||||
// TODO: Support dotprod when it's available outside of nightly.
|
||||
let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l));
|
||||
let pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
|
||||
let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h));
|
||||
let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
|
||||
|
||||
let pl1l = vmull_s8(vget_low_s8(v0_1ls), vget_low_s8(v1_1l));
|
||||
let pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
|
||||
let ph1l = vmull_s8(vget_low_s8(v0_1hs), vget_low_s8(v1_1h));
|
||||
let ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
|
||||
|
||||
let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
||||
let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
||||
let pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
|
||||
let ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
|
||||
|
||||
sumv0 = vmlaq_n_f32(
|
||||
sumv0,
|
||||
vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
|
||||
x0.d.to_f32() * y0.d.to_f32(),
|
||||
);
|
||||
sumv1 = vmlaq_n_f32(
|
||||
sumv1,
|
||||
vcvtq_f32_s32(vaddq_s32(pl1, ph1)),
|
||||
x1.d.to_f32() * y1.d.to_f32(),
|
||||
);
|
||||
}
|
||||
Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1))
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
|
||||
let qk = QK8_0;
|
||||
if n % QK8_0 != 0 {
|
||||
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
|
||||
}
|
||||
let nb = n / QK8_0;
|
||||
if nb % 2 != 0 {
|
||||
crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even")
|
||||
}
|
||||
unsafe {
|
||||
let mut sumv0 = vdupq_n_f32(0.0f32);
|
||||
let mut sumv1 = vdupq_n_f32(0.0f32);
|
||||
for i in (0..nb).step_by(2) {
|
||||
let x0 = &xs[i];
|
||||
let x1 = &xs[i + 1];
|
||||
let y0 = &ys[i];
|
||||
let y1 = &ys[i + 1];
|
||||
|
||||
let x0_0 = vld1q_s8(x0.qs.as_ptr());
|
||||
let x0_1 = vld1q_s8(x0.qs.as_ptr().add(16));
|
||||
let x1_0 = vld1q_s8(x1.qs.as_ptr());
|
||||
let x1_1 = vld1q_s8(x1.qs.as_ptr().add(16));
|
||||
|
||||
// load y
|
||||
let y0_0 = vld1q_s8(y0.qs.as_ptr());
|
||||
let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));
|
||||
let y1_0 = vld1q_s8(y1.qs.as_ptr());
|
||||
let y1_1 = vld1q_s8(y1.qs.as_ptr().add(16));
|
||||
|
||||
// TODO dotprod once this is the intrinsics are.
|
||||
let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0));
|
||||
let p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
|
||||
let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1));
|
||||
let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
|
||||
|
||||
let p1_0 = vmull_s8(vget_low_s8(x1_0), vget_low_s8(y1_0));
|
||||
let p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
|
||||
let p1_2 = vmull_s8(vget_low_s8(x1_1), vget_low_s8(y1_1));
|
||||
let p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
|
||||
|
||||
let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
|
||||
let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
|
||||
let p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
|
||||
let p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
|
||||
|
||||
sumv0 = vmlaq_n_f32(
|
||||
sumv0,
|
||||
vcvtq_f32_s32(vaddq_s32(p0, p1)),
|
||||
x0.d.to_f32() * y0.d.to_f32(),
|
||||
);
|
||||
sumv1 = vmlaq_n_f32(
|
||||
sumv1,
|
||||
vcvtq_f32_s32(vaddq_s32(p2, p3)),
|
||||
x1.d.to_f32() * y1.d.to_f32(),
|
||||
);
|
||||
}
|
||||
Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1))
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
|
||||
if n % QK_K != 0 {
|
||||
crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}")
|
||||
}
|
||||
let mut sum = 0f32;
|
||||
unsafe {
|
||||
let m4b = vdupq_n_u8(0xF);
|
||||
|
||||
let mone = vdupq_n_u8(3);
|
||||
|
||||
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||
let d_all = x.d.to_f32();
|
||||
|
||||
let mut q6 = x.ql.as_ptr();
|
||||
let mut qh = x.qh.as_ptr();
|
||||
let mut q8 = y.qs.as_ptr();
|
||||
|
||||
let mut scale = x.scales.as_ptr();
|
||||
|
||||
let q8sums = vld1q_s16_x2(y.bsums.as_ptr());
|
||||
let scales = vld1q_s8(scale);
|
||||
let q6scales = int16x8x2_t(
|
||||
vmovl_s8(vget_low_s8(scales)),
|
||||
vmovl_s8(vget_high_s8(scales)),
|
||||
);
|
||||
|
||||
let prod = vaddq_s32(
|
||||
vaddq_s32(
|
||||
vmull_s16(vget_low_s16(q8sums.0), vget_low_s16(q6scales.0)),
|
||||
vmull_s16(vget_high_s16(q8sums.0), vget_high_s16(q6scales.0)),
|
||||
),
|
||||
vaddq_s32(
|
||||
vmull_s16(vget_low_s16(q8sums.1), vget_low_s16(q6scales.1)),
|
||||
vmull_s16(vget_high_s16(q8sums.1), vget_high_s16(q6scales.1)),
|
||||
),
|
||||
);
|
||||
let isum_mins = vaddvq_s32(prod);
|
||||
|
||||
let mut isum = 0i32;
|
||||
|
||||
for _j in 0..QK_K / 128 {
|
||||
let qhbits = vld1q_u8_x2(qh);
|
||||
qh = qh.add(32);
|
||||
let q6bits = vld1q_u8_x4(q6);
|
||||
q6 = q6.add(64);
|
||||
let q8bytes = vld1q_s8_x4(q8);
|
||||
q8 = q8.add(64);
|
||||
|
||||
let q6h_0 = vshlq_n_u8(vandq_u8(mone, qhbits.0), 4);
|
||||
let q6h_1 = vshlq_n_u8(vandq_u8(mone, qhbits.1), 4);
|
||||
let shifted = vshrq_n_u8(qhbits.0, 2);
|
||||
let q6h_2 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
|
||||
let shifted = vshrq_n_u8(qhbits.1, 2);
|
||||
let q6h_3 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
|
||||
|
||||
let q6bytes_0 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.0, m4b), q6h_0));
|
||||
let q6bytes_1 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.1, m4b), q6h_1));
|
||||
let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.2, m4b), q6h_2));
|
||||
let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.3, m4b), q6h_3));
|
||||
|
||||
// TODO: dotprod
|
||||
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
|
||||
isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
|
||||
scale = scale.add(2);
|
||||
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
|
||||
vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
|
||||
vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
|
||||
);
|
||||
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
|
||||
isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
|
||||
scale = scale.add(2);
|
||||
|
||||
let q8bytes = vld1q_s8_x4(q8);
|
||||
q8 = q8.add(64);
|
||||
|
||||
let shifted = vshrq_n_u8(qhbits.0, 4);
|
||||
let q6h_0 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
|
||||
let shifted = vshrq_n_u8(qhbits.1, 4);
|
||||
let q6h_1 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
|
||||
let shifted = vshrq_n_u8(qhbits.0, 6);
|
||||
let q6h_2 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
|
||||
let shifted = vshrq_n_u8(qhbits.1, 6);
|
||||
let q6h_3 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
|
||||
|
||||
let q6bytes_0 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.0, 4), q6h_0));
|
||||
let q6bytes_1 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.1, 4), q6h_1));
|
||||
let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.2, 4), q6h_2));
|
||||
let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.3, 4), q6h_3));
|
||||
|
||||
// TODO: dotprod case.
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
|
||||
isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
|
||||
scale = scale.add(2);
|
||||
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
|
||||
vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
|
||||
vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
|
||||
);
|
||||
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
|
||||
isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
|
||||
scale = scale.add(2);
|
||||
}
|
||||
sum += d_all * y.d * ((isum - 32 * isum_mins) as f32);
|
||||
}
|
||||
}
|
||||
Ok(sum)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Result<f32> {
|
||||
if n % QK_K != 0 {
|
||||
crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}")
|
||||
}
|
||||
let mut sumf = 0f32;
|
||||
let mut utmp = [0u32; 4];
|
||||
const KMASK1: u32 = 0x3f3f3f3f;
|
||||
const KMASK2: u32 = 0x0f0f0f0f;
|
||||
const KMASK3: u32 = 0x03030303;
|
||||
|
||||
unsafe {
|
||||
let m4b = vdupq_n_u8(0xF);
|
||||
let mone = vdupq_n_u8(1);
|
||||
let mtwo = vdupq_n_u8(2);
|
||||
|
||||
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||
let d = y.d * x.d.to_f32();
|
||||
let dmin = y.d * x.dmin.to_f32();
|
||||
|
||||
let q8sums = vpaddq_s16(
|
||||
vld1q_s16(y.bsums.as_ptr()),
|
||||
vld1q_s16(y.bsums.as_ptr().add(8)),
|
||||
);
|
||||
|
||||
LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
|
||||
|
||||
utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
|
||||
let uaux = utmp[1] & KMASK1;
|
||||
utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
|
||||
utmp[2] = uaux;
|
||||
utmp[0] &= KMASK1;
|
||||
|
||||
let mins8 = vld1_u8((utmp.as_ptr() as *const u8).add(8));
|
||||
let mins = vreinterpretq_s16_u16(vmovl_u8(mins8));
|
||||
let prod = vaddq_s32(
|
||||
vmull_s16(vget_low_s16(q8sums), vget_low_s16(mins)),
|
||||
vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)),
|
||||
);
|
||||
let sumi_mins = vaddvq_s32(prod);
|
||||
|
||||
let mut scales = utmp.as_ptr() as *const u8;
|
||||
|
||||
let mut q5 = x.qs.as_ptr();
|
||||
let mut q8 = y.qs.as_ptr();
|
||||
|
||||
let mut qhbits = vld1q_u8_x2(x.qh.as_ptr());
|
||||
|
||||
let mut sumi = 0i32;
|
||||
|
||||
for _j in 0..QK_K / 64 {
|
||||
let q5bits = vld1q_u8_x2(q5);
|
||||
q5 = q5.add(32);
|
||||
let q8bytes = vld1q_s8_x4(q8);
|
||||
q8 = q8.add(64);
|
||||
|
||||
let q5h_0 = vshlq_n_u8(vandq_u8(mone, qhbits.0), 4);
|
||||
let q5h_1 = vshlq_n_u8(vandq_u8(mone, qhbits.1), 4);
|
||||
let q5h_2 = vshlq_n_u8(vandq_u8(mtwo, qhbits.0), 3);
|
||||
let q5h_3 = vshlq_n_u8(vandq_u8(mtwo, qhbits.1), 3);
|
||||
qhbits.0 = vshrq_n_u8(qhbits.0, 2);
|
||||
qhbits.1 = vshrq_n_u8(qhbits.1, 2);
|
||||
|
||||
let q5bytes_0 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.0, m4b), q5h_0));
|
||||
let q5bytes_1 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.1, m4b), q5h_1));
|
||||
let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2));
|
||||
let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3));
|
||||
|
||||
// TODO: dotprod
|
||||
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q5bytes_0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q5bytes_0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q5bytes_1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q5bytes_1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
sumi += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * *scales as i32;
|
||||
scales = scales.add(1);
|
||||
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q5bytes_2), vget_low_s8(q8bytes.2)),
|
||||
vmull_s8(vget_high_s8(q5bytes_2), vget_high_s8(q8bytes.2)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q5bytes_3), vget_low_s8(q8bytes.3)),
|
||||
vmull_s8(vget_high_s8(q5bytes_3), vget_high_s8(q8bytes.3)),
|
||||
);
|
||||
sumi += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * *scales as i32;
|
||||
scales = scales.add(1);
|
||||
}
|
||||
sumf += d * sumi as f32 - dmin * sumi_mins as f32;
|
||||
}
|
||||
}
|
||||
Ok(sumf)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
|
||||
if n % QK_K != 0 {
|
||||
crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
|
||||
}
|
||||
let mut sumf = 0f32;
|
||||
let mut utmp = [0u32; 4];
|
||||
let mut scales = [0u8; 16];
|
||||
const KMASK1: u32 = 0x3f3f3f3f;
|
||||
const KMASK2: u32 = 0x0f0f0f0f;
|
||||
const KMASK3: u32 = 0x03030303;
|
||||
|
||||
unsafe {
|
||||
let m4b = vdupq_n_u8(0xF);
|
||||
|
||||
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||
let d = y.d * x.d.to_f32();
|
||||
let dmin = y.d * x.dmin.to_f32();
|
||||
|
||||
let q8sums = vpaddq_s16(
|
||||
vld1q_s16(y.bsums.as_ptr()),
|
||||
vld1q_s16(y.bsums.as_ptr().add(8)),
|
||||
);
|
||||
|
||||
LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
|
||||
|
||||
let mins8 = vld1_u32(
|
||||
[
|
||||
utmp[1] & KMASK1,
|
||||
((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4),
|
||||
]
|
||||
.as_ptr(),
|
||||
);
|
||||
utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
|
||||
utmp[0] &= KMASK1;
|
||||
|
||||
let mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
|
||||
let prod = vaddq_s32(
|
||||
vmull_s16(vget_low_s16(q8sums), vget_low_s16(mins)),
|
||||
vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)),
|
||||
);
|
||||
sumf -= dmin * vaddvq_s32(prod) as f32;
|
||||
|
||||
LittleEndian::write_u32_into(&utmp, &mut scales);
|
||||
|
||||
let mut q4 = x.qs.as_ptr();
|
||||
let mut q8 = y.qs.as_ptr();
|
||||
|
||||
let mut sumi1 = 0i32;
|
||||
let mut sumi2 = 0i32;
|
||||
|
||||
for j in 0..QK_K / 64 {
|
||||
let q4bits = vld1q_u8_x2(q4);
|
||||
q4 = q4.add(32);
|
||||
// TODO: dotprod
|
||||
let q8bytes = vld1q_s8_x2(q8);
|
||||
q8 = q8.add(32);
|
||||
let q4bytes = int8x16x2_t(
|
||||
vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)),
|
||||
vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)),
|
||||
);
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * scales[2 * j] as i32;
|
||||
|
||||
let q8bytes = vld1q_s8_x2(q8);
|
||||
q8 = q8.add(32);
|
||||
let q4bytes = int8x16x2_t(
|
||||
vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)),
|
||||
vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)),
|
||||
);
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * scales[2 * j + 1] as i32;
|
||||
}
|
||||
sumf += d * (sumi1 + sumi2) as f32;
|
||||
}
|
||||
}
|
||||
Ok(sumf)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Result<f32> {
|
||||
if n % QK_K != 0 {
|
||||
crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}")
|
||||
}
|
||||
let mut sumf = 0f32;
|
||||
let mut utmp = [0u32; 4];
|
||||
let mut aux = [0u32; 3];
|
||||
const KMASK1: u32 = 0x03030303;
|
||||
const KMASK2: u32 = 0x0f0f0f0f;
|
||||
|
||||
unsafe {
|
||||
let m3b = vdupq_n_u8(0x3);
|
||||
let m0 = vdupq_n_u8(1);
|
||||
let m1 = vshlq_n_u8(m0, 1);
|
||||
let m2 = vshlq_n_u8(m0, 2);
|
||||
let m3 = vshlq_n_u8(m0, 3);
|
||||
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||
let d = y.d * x.d.to_f32();
|
||||
let mut q3 = x.qs.as_ptr();
|
||||
let qh = x.hmask.as_ptr();
|
||||
let mut q8 = y.qs.as_ptr();
|
||||
|
||||
let mut qhbits = vld1q_u8_x2(qh);
|
||||
|
||||
let mut isum = 0i32;
|
||||
|
||||
// Set up scales
|
||||
LittleEndian::read_u32_into(&x.scales, &mut aux);
|
||||
|
||||
utmp[3] = ((aux[1] >> 4) & KMASK2) | (((aux[2] >> 6) & KMASK1) << 4);
|
||||
utmp[2] = ((aux[0] >> 4) & KMASK2) | (((aux[2] >> 4) & KMASK1) << 4);
|
||||
utmp[1] = (aux[1] & KMASK2) | (((aux[2] >> 2) & KMASK1) << 4);
|
||||
utmp[0] = (aux[0] & KMASK2) | ((aux[2] & KMASK1) << 4);
|
||||
|
||||
let mut scale = utmp.as_mut_ptr() as *mut i8;
|
||||
for j in 0..16 {
|
||||
*scale.add(j) -= 32i8
|
||||
}
|
||||
|
||||
for j in 0..QK_K / 128 {
|
||||
let q3bits = vld1q_u8_x2(q3);
|
||||
q3 = q3.add(32);
|
||||
let q8bytes_1 = vld1q_s8_x4(q8);
|
||||
q8 = q8.add(64);
|
||||
let q8bytes_2 = vld1q_s8_x4(q8);
|
||||
q8 = q8.add(64);
|
||||
|
||||
let q3h_0 = vshlq_n_u8(vbicq_u8(m0, qhbits.0), 2);
|
||||
let q3h_1 = vshlq_n_u8(vbicq_u8(m0, qhbits.1), 2);
|
||||
let q3h_2 = vshlq_n_u8(vbicq_u8(m1, qhbits.0), 1);
|
||||
let q3h_3 = vshlq_n_u8(vbicq_u8(m1, qhbits.1), 1);
|
||||
|
||||
let q3bytes_0 = vsubq_s8(
|
||||
vreinterpretq_s8_u8(vandq_u8(q3bits.0, m3b)),
|
||||
vreinterpretq_s8_u8(q3h_0),
|
||||
);
|
||||
let q3bytes_1 = vsubq_s8(
|
||||
vreinterpretq_s8_u8(vandq_u8(q3bits.1, m3b)),
|
||||
vreinterpretq_s8_u8(q3h_1),
|
||||
);
|
||||
let q3bytes_2 = vsubq_s8(
|
||||
vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 2), m3b)),
|
||||
vreinterpretq_s8_u8(q3h_2),
|
||||
);
|
||||
let q3bytes_3 = vsubq_s8(
|
||||
vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 2), m3b)),
|
||||
vreinterpretq_s8_u8(q3h_3),
|
||||
);
|
||||
|
||||
// TODO: dotprod
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_1.0)),
|
||||
vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_1.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_1.1)),
|
||||
vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_1.1)),
|
||||
);
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_1.2)),
|
||||
vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_1.2)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_1.3)),
|
||||
vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_1.3)),
|
||||
);
|
||||
isum += vaddvq_s16(p0) as i32 * *scale as i32
|
||||
+ vaddvq_s16(p1) as i32 * *scale.add(1) as i32
|
||||
+ vaddvq_s16(p2) as i32 * *scale.add(2) as i32
|
||||
+ vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
|
||||
scale = scale.add(4);
|
||||
|
||||
let q3h_0 = vbicq_u8(m2, qhbits.0);
|
||||
let q3h_1 = vbicq_u8(m2, qhbits.1);
|
||||
let q3h_2 = vshrq_n_u8(vbicq_u8(m3, qhbits.0), 1);
|
||||
let q3h_3 = vshrq_n_u8(vbicq_u8(m3, qhbits.1), 1);
|
||||
|
||||
let q3bytes_0 = vsubq_s8(
|
||||
vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 4), m3b)),
|
||||
vreinterpretq_s8_u8(q3h_0),
|
||||
);
|
||||
let q3bytes_1 = vsubq_s8(
|
||||
vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 4), m3b)),
|
||||
vreinterpretq_s8_u8(q3h_1),
|
||||
);
|
||||
let q3bytes_2 = vsubq_s8(
|
||||
vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 6), m3b)),
|
||||
vreinterpretq_s8_u8(q3h_2),
|
||||
);
|
||||
let q3bytes_3 = vsubq_s8(
|
||||
vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 6), m3b)),
|
||||
vreinterpretq_s8_u8(q3h_3),
|
||||
);
|
||||
|
||||
// TODO: dotprod
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_2.0)),
|
||||
vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_2.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_2.1)),
|
||||
vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_2.1)),
|
||||
);
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_2.2)),
|
||||
vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_2.2)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_2.3)),
|
||||
vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_2.3)),
|
||||
);
|
||||
isum += vaddvq_s16(p0) as i32 * *scale as i32
|
||||
+ vaddvq_s16(p1) as i32 * *scale.add(1) as i32
|
||||
+ vaddvq_s16(p2) as i32 * *scale.add(2) as i32
|
||||
+ vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
|
||||
scale = scale.add(4);
|
||||
|
||||
if j == 0 {
|
||||
qhbits.0 = vshrq_n_u8(qhbits.0, 4);
|
||||
qhbits.1 = vshrq_n_u8(qhbits.1, 4);
|
||||
}
|
||||
}
|
||||
sumf += d * isum as f32;
|
||||
}
|
||||
}
|
||||
Ok(sumf)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
|
||||
if n % QK_K != 0 {
|
||||
crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
|
||||
}
|
||||
let mut sumf = 0f32;
|
||||
let mut aux = [0u8; 16];
|
||||
|
||||
unsafe {
|
||||
let m3 = vdupq_n_u8(0x3);
|
||||
let m4 = vdupq_n_u8(0xF);
|
||||
|
||||
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||
let d = y.d * x.d.to_f32();
|
||||
let dmin = -y.d * x.dmin.to_f32();
|
||||
|
||||
let mut q2 = x.qs.as_ptr();
|
||||
let mut q8 = y.qs.as_ptr();
|
||||
let sc = x.scales.as_ptr();
|
||||
|
||||
let mins_and_scales = vld1q_u8(sc);
|
||||
let scales = vandq_u8(mins_and_scales, m4);
|
||||
vst1q_u8(aux.as_mut_ptr(), scales);
|
||||
|
||||
let mins = vshrq_n_u8(mins_and_scales, 4);
|
||||
let q8sums = vld1q_s16_x2(y.bsums.as_ptr());
|
||||
let mins16 = int16x8x2_t(
|
||||
vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))),
|
||||
vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins))),
|
||||
);
|
||||
let s0 = vaddq_s32(
|
||||
vmull_s16(vget_low_s16(mins16.0), vget_low_s16(q8sums.0)),
|
||||
vmull_s16(vget_high_s16(mins16.0), vget_high_s16(q8sums.0)),
|
||||
);
|
||||
let s1 = vaddq_s32(
|
||||
vmull_s16(vget_low_s16(mins16.1), vget_low_s16(q8sums.1)),
|
||||
vmull_s16(vget_high_s16(mins16.1), vget_high_s16(q8sums.1)),
|
||||
);
|
||||
sumf += dmin * vaddvq_s32(vaddq_s32(s0, s1)) as f32;
|
||||
|
||||
let mut isum = 0i32;
|
||||
let mut is = 0usize;
|
||||
|
||||
// TODO: dotprod
|
||||
|
||||
for _j in 0..QK_K / 128 {
|
||||
let q2bits = vld1q_u8_x2(q2);
|
||||
q2 = q2.add(32);
|
||||
|
||||
let q8bytes = vld1q_s8_x2(q8);
|
||||
q8 = q8.add(32);
|
||||
let mut q2bytes = int8x16x2_t(
|
||||
vreinterpretq_s8_u8(vandq_u8(q2bits.0, m3)),
|
||||
vreinterpretq_s8_u8(vandq_u8(q2bits.1, m3)),
|
||||
);
|
||||
isum += multiply_accum_with_scale(&aux, is, 0, q2bytes, q8bytes);
|
||||
|
||||
let q8bytes = vld1q_s8_x2(q8);
|
||||
q8 = q8.add(32);
|
||||
q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 2), m3));
|
||||
q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 2), m3));
|
||||
isum += multiply_accum_with_scale(&aux, is, 2, q2bytes, q8bytes);
|
||||
|
||||
let q8bytes = vld1q_s8_x2(q8);
|
||||
q8 = q8.add(32);
|
||||
q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 4), m3));
|
||||
q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 4), m3));
|
||||
isum += multiply_accum_with_scale(&aux, is, 4, q2bytes, q8bytes);
|
||||
|
||||
let q8bytes = vld1q_s8_x2(q8);
|
||||
q8 = q8.add(32);
|
||||
q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 6), m3));
|
||||
q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 6), m3));
|
||||
isum += multiply_accum_with_scale(&aux, is, 6, q2bytes, q8bytes);
|
||||
|
||||
is += 8;
|
||||
}
|
||||
sumf += d * isum as f32;
|
||||
}
|
||||
}
|
||||
Ok(sumf)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn multiply_accum_with_scale(
|
||||
aux: &[u8; 16],
|
||||
is: usize,
|
||||
index: usize,
|
||||
q2bytes: int8x16x2_t,
|
||||
q8bytes: int8x16x2_t,
|
||||
) -> i32 {
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q2bytes.0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q2bytes.0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q2bytes.1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q2bytes.1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
vaddvq_s16(p1) as i32 * aux[is + index] as i32
|
||||
+ vaddvq_s16(p2) as i32 * aux[is + 1 + index] as i32
|
||||
}
|
@ -1,326 +0,0 @@
|
||||
use crate::Result;
|
||||
|
||||
pub(super) fn nearest_int(v: f32) -> i32 {
|
||||
v.round() as i32
|
||||
}
|
||||
|
||||
/// Validates that the input and output are the right size and returns an iterator which maps each
|
||||
/// input region `xs` to its corresponding output block in `ys`. Each output region is guaranteed
|
||||
/// to be `T::BLCK_SIZE` long.
|
||||
pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>(
|
||||
xs: &'b [f32],
|
||||
ys: &'a mut [T],
|
||||
) -> Result<Vec<(&'a mut T, &'b [f32])>> {
|
||||
let block_size = T::BLCK_SIZE;
|
||||
let dtype = T::DTYPE;
|
||||
|
||||
let expected_blocks = xs.len() / block_size;
|
||||
let actual_blocks = ys.len();
|
||||
|
||||
//validate that the input is the right size
|
||||
if expected_blocks != actual_blocks {
|
||||
crate::bail!("quantize {dtype:?}: expected {expected_blocks} blocks but only {actual_blocks} were provided!")
|
||||
}
|
||||
|
||||
Ok(ys.iter_mut().zip(xs.chunks_exact(block_size)).collect())
|
||||
}
|
||||
|
||||
/// Validates that the input and output are the right size and returns an iterator which maps each
|
||||
/// input block `xs` to its corresponding output region in `ys`. Each output region is guaranteed
|
||||
/// to be `T::BLCK_SIZE` long.
|
||||
pub(super) fn group_for_dequantization<'a, 'b, T: super::k_quants::GgmlType>(
|
||||
xs: &'a [T],
|
||||
ys: &'b mut [f32],
|
||||
) -> Result<Vec<(&'a T, &'b mut [f32])>> {
|
||||
let block_size = T::BLCK_SIZE;
|
||||
let dtype = T::DTYPE;
|
||||
|
||||
let actual_output_len = ys.len();
|
||||
let expected_output_len = xs.len() * block_size;
|
||||
//validate that the output is the right size
|
||||
if expected_output_len != actual_output_len {
|
||||
crate::bail!("dequantize {dtype:?}: ys (len = {actual_output_len}) does not match the expected length of {expected_output_len}!")
|
||||
}
|
||||
|
||||
//zip the blocks and outputs together
|
||||
Ok(xs.iter().zip(ys.chunks_exact_mut(block_size)).collect())
|
||||
}
|
||||
|
||||
pub(super) fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) {
|
||||
if j < 4 {
|
||||
let d = q[j] & 63;
|
||||
let m = q[j + 4] & 63;
|
||||
(d, m)
|
||||
} else {
|
||||
let d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
|
||||
let m = (q[j + 4] >> 4) | ((q[j] >> 6) << 4);
|
||||
(d, m)
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) unsafe fn make_qx_quants(
|
||||
n: usize,
|
||||
nmax: i32,
|
||||
x: *const f32,
|
||||
ls: *mut i8,
|
||||
rmse_type: i32,
|
||||
) -> f32 {
|
||||
let mut max = 0f32;
|
||||
let mut amax = 0f32;
|
||||
for i in 0..n {
|
||||
let x = *x.add(i);
|
||||
let ax = x.abs();
|
||||
if ax > amax {
|
||||
amax = ax;
|
||||
max = x;
|
||||
}
|
||||
}
|
||||
if amax == 0. {
|
||||
// all zero
|
||||
for i in 0..n {
|
||||
*ls.add(i) = 0;
|
||||
}
|
||||
return 0.;
|
||||
}
|
||||
let mut iscale = -(nmax as f32) / max;
|
||||
if rmse_type == 0 {
|
||||
for i in 0..n {
|
||||
let x = *x.add(i);
|
||||
let l = nearest_int(iscale * x);
|
||||
*ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8;
|
||||
}
|
||||
return 1.0 / iscale;
|
||||
}
|
||||
let weight_type = rmse_type % 2;
|
||||
let mut sumlx = 0f32;
|
||||
let mut suml2 = 0f32;
|
||||
for i in 0..n {
|
||||
let x = *x.add(i);
|
||||
let l = nearest_int(iscale * x);
|
||||
let l = l.clamp(-nmax, nmax - 1);
|
||||
*ls.add(i) = (l + nmax) as i8;
|
||||
let w = if weight_type == 1 { x * x } else { 1.0 };
|
||||
let l = l as f32;
|
||||
sumlx += w * x * l;
|
||||
suml2 += w * l * l;
|
||||
}
|
||||
let mut scale = sumlx / suml2;
|
||||
let mut best = scale * sumlx;
|
||||
for _itry in 0..3 {
|
||||
let iscale = 1.0 / scale;
|
||||
let mut slx = 0f32;
|
||||
let mut sl2 = 0f32;
|
||||
let mut changed = false;
|
||||
for i in 0..n {
|
||||
let x = *x.add(i);
|
||||
let l = nearest_int(iscale * x);
|
||||
let l = l.clamp(-nmax, nmax - 1);
|
||||
if l + nmax != *ls.add(i) as i32 {
|
||||
changed = true;
|
||||
}
|
||||
let w = if weight_type == 1 { x * x } else { 1f32 };
|
||||
let l = l as f32;
|
||||
slx += w * x * l;
|
||||
sl2 += w * l * l;
|
||||
}
|
||||
if !changed || sl2 == 0.0 || slx * slx <= best * sl2 {
|
||||
break;
|
||||
}
|
||||
for i in 0..n {
|
||||
let x = *x.add(i);
|
||||
let l = nearest_int(iscale * x);
|
||||
*ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8;
|
||||
}
|
||||
sumlx = slx;
|
||||
suml2 = sl2;
|
||||
scale = sumlx / suml2;
|
||||
best = scale * sumlx;
|
||||
}
|
||||
for _itry in 0..5 {
|
||||
let mut n_changed = 0;
|
||||
for i in 0..n {
|
||||
let x = *x.add(i);
|
||||
let w = if weight_type == 1 { x * x } else { 1. };
|
||||
let l = *ls.add(i) as i32 - nmax;
|
||||
let mut slx = sumlx - w * x * l as f32;
|
||||
if slx > 0. {
|
||||
let mut sl2 = suml2 - w * l as f32 * l as f32;
|
||||
let new_l = nearest_int(x * sl2 / slx);
|
||||
let new_l = new_l.clamp(-nmax, nmax - 1);
|
||||
if new_l != l {
|
||||
slx += w * x * new_l as f32;
|
||||
sl2 += w * new_l as f32 * new_l as f32;
|
||||
if sl2 > 0. && slx * slx * suml2 > sumlx * sumlx * sl2 {
|
||||
*ls.add(i) = (nmax + new_l) as i8;
|
||||
sumlx = slx;
|
||||
suml2 = sl2;
|
||||
scale = sumlx / suml2;
|
||||
best = scale * sumlx;
|
||||
n_changed += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if n_changed == 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if rmse_type < 3 {
|
||||
return scale;
|
||||
}
|
||||
for is in -4..4 {
|
||||
if is == 0 {
|
||||
continue;
|
||||
}
|
||||
iscale = -(nmax as f32 + 0.1f32 * is as f32) / max;
|
||||
let mut sumlx = 0.;
|
||||
let mut suml2 = 0.;
|
||||
for i in 0..n {
|
||||
let x = *x.add(i);
|
||||
let l = nearest_int(iscale * x);
|
||||
let l = l.clamp(-nmax, nmax - 1);
|
||||
let w = if weight_type == 1 { x * x } else { 1. };
|
||||
let l = l as f32;
|
||||
sumlx += w * x * l;
|
||||
suml2 += w * l * l;
|
||||
}
|
||||
if suml2 > 0. && sumlx * sumlx > best * suml2 {
|
||||
for i in 0..n {
|
||||
let x = *x.add(i);
|
||||
let l = nearest_int(iscale * x);
|
||||
*ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8;
|
||||
}
|
||||
scale = sumlx / suml2;
|
||||
best = scale * sumlx;
|
||||
}
|
||||
}
|
||||
scale
|
||||
}
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L224
|
||||
pub(super) fn make_qkx1_quants(nmax: i32, ntry: usize, x: &[f32]) -> (f32, f32) {
|
||||
let n = x.len();
|
||||
let mut l = vec![0; n];
|
||||
// Get min/max
|
||||
let min = *x
|
||||
.iter()
|
||||
.take(n)
|
||||
.min_by(|a, b| a.total_cmp(b))
|
||||
.unwrap_or(&x[0]);
|
||||
let max = *x.iter().max_by(|a, b| a.total_cmp(b)).unwrap_or(&x[0]);
|
||||
|
||||
// If min == max, all values are the same => nothing to do here
|
||||
if max == min {
|
||||
return (0.0, 0.0);
|
||||
}
|
||||
|
||||
// Ensure min <= 0.0
|
||||
let mut min = min.min(0.);
|
||||
|
||||
// Compute scale and inverse scale
|
||||
let mut iscale = nmax as f32 / (max - min);
|
||||
let mut scale = 1.0 / iscale;
|
||||
|
||||
for _ in 0..ntry {
|
||||
let mut sumlx = 0.0;
|
||||
let mut suml2 = 0;
|
||||
let mut did_change = false;
|
||||
|
||||
for (i, value) in x.iter().enumerate().take(n) {
|
||||
let li = nearest_int(iscale * (value - min)).clamp(0, nmax);
|
||||
let clamped_li = li as u8;
|
||||
if clamped_li != l[i] {
|
||||
l[i] = clamped_li;
|
||||
did_change = true;
|
||||
}
|
||||
sumlx += (value - min) * li as f32;
|
||||
suml2 += li * li;
|
||||
}
|
||||
scale = sumlx / suml2 as f32;
|
||||
|
||||
let sum: f32 = x
|
||||
.iter()
|
||||
.take(n)
|
||||
.zip(l.iter().take(n))
|
||||
.map(|(xi, &li)| xi - scale * li as f32)
|
||||
.sum();
|
||||
|
||||
min = sum / n as f32;
|
||||
if min > 0.0 {
|
||||
min = 0.0;
|
||||
}
|
||||
iscale = 1.0 / scale;
|
||||
if !did_change {
|
||||
break;
|
||||
}
|
||||
}
|
||||
(scale, -min)
|
||||
}
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L165
|
||||
pub(super) fn make_q3_quants(x: &[f32], nmax: i32, do_rmse: bool) -> f32 {
|
||||
let n = x.len();
|
||||
let mut l = vec![0i8; n];
|
||||
|
||||
let mut max = 0.0;
|
||||
let mut amax = 0.0;
|
||||
for &xi in x.iter().take(n) {
|
||||
let ax = xi.abs();
|
||||
if ax > amax {
|
||||
amax = ax;
|
||||
max = xi;
|
||||
}
|
||||
}
|
||||
|
||||
if amax == 0.0 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let iscale = -(nmax as f32) / max;
|
||||
if do_rmse {
|
||||
let mut sumlx = 0.0;
|
||||
let mut suml2 = 0.0;
|
||||
for i in 0..n {
|
||||
let li = (iscale * x[i]).round() as i32;
|
||||
let li = li.clamp(-nmax, nmax - 1);
|
||||
l[i] = li as i8;
|
||||
let w = x[i] * x[i];
|
||||
sumlx += w * x[i] * li as f32;
|
||||
suml2 += w * (li * li) as f32;
|
||||
}
|
||||
for _ in 0..5 {
|
||||
let mut n_changed = 0;
|
||||
for i in 0..n {
|
||||
let w = x[i] * x[i];
|
||||
let mut slx = sumlx - w * x[i] * l[i] as f32;
|
||||
if slx > 0.0 {
|
||||
let mut sl2 = suml2 - w * (l[i] as i32 * l[i] as i32) as f32;
|
||||
let mut new_l = (x[i] * sl2 / slx).round() as i32;
|
||||
new_l = new_l.clamp(-nmax, nmax - 1);
|
||||
if new_l != l[i] as i32 {
|
||||
slx += w * x[i] * new_l as f32;
|
||||
sl2 += w * (new_l * new_l) as f32;
|
||||
if sl2 > 0.0 && slx * slx * suml2 > sumlx * sumlx * sl2 {
|
||||
l[i] = new_l as i8;
|
||||
sumlx = slx;
|
||||
suml2 = sl2;
|
||||
n_changed += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if n_changed == 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
for li in l.iter_mut() {
|
||||
*li += nmax as i8;
|
||||
}
|
||||
return sumlx / suml2;
|
||||
}
|
||||
for i in 0..n {
|
||||
let li = (iscale * x[i]).round() as i32;
|
||||
l[i] = (li.clamp(-nmax, nmax - 1) + nmax) as i8;
|
||||
}
|
||||
1.0 / iscale
|
||||
}
|
@ -10,7 +10,6 @@ impl From<DType> for st::Dtype {
|
||||
match value {
|
||||
DType::U8 => st::Dtype::U8,
|
||||
DType::U32 => st::Dtype::U32,
|
||||
DType::I64 => st::Dtype::I64,
|
||||
DType::BF16 => st::Dtype::BF16,
|
||||
DType::F16 => st::Dtype::F16,
|
||||
DType::F32 => st::Dtype::F32,
|
||||
@ -25,7 +24,6 @@ impl TryFrom<st::Dtype> for DType {
|
||||
match value {
|
||||
st::Dtype::U8 => Ok(DType::U8),
|
||||
st::Dtype::U32 => Ok(DType::U32),
|
||||
st::Dtype::I64 => Ok(DType::I64),
|
||||
st::Dtype::BF16 => Ok(DType::BF16),
|
||||
st::Dtype::F16 => Ok(DType::F16),
|
||||
st::Dtype::F32 => Ok(DType::F32),
|
||||
@ -191,7 +189,6 @@ impl Tensor {
|
||||
match dtype {
|
||||
DType::U8 => convert_slice::<u8>(data, shape, device),
|
||||
DType::U32 => convert_slice::<u32>(data, shape, device),
|
||||
DType::I64 => convert_slice::<i64>(data, shape, device),
|
||||
DType::BF16 => convert_slice::<half::bf16>(data, shape, device),
|
||||
DType::F16 => convert_slice::<half::f16>(data, shape, device),
|
||||
DType::F32 => convert_slice::<f32>(data, shape, device),
|
||||
@ -208,15 +205,24 @@ fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
|
||||
convert_with_cast_::<u16, u32, _>(view, device, conv)
|
||||
}
|
||||
st::Dtype::U32 => convert_::<u32>(view, device),
|
||||
st::Dtype::I32 => {
|
||||
let conv = |x| Ok(i64::from(x));
|
||||
convert_with_cast_::<i32, i64, _>(view, device, conv)
|
||||
}
|
||||
st::Dtype::I64 => convert_::<i64>(view, device),
|
||||
st::Dtype::BF16 => convert_::<half::bf16>(view, device),
|
||||
st::Dtype::F16 => convert_::<half::f16>(view, device),
|
||||
st::Dtype::F32 => convert_::<f32>(view, device),
|
||||
st::Dtype::F64 => convert_::<f64>(view, device),
|
||||
st::Dtype::I32 => {
|
||||
let conv = |x| {
|
||||
u32::try_from(x)
|
||||
.map_err(|_| Error::Msg(format!("out of bounds value for u32: {x}")))
|
||||
};
|
||||
convert_with_cast_::<i32, u32, _>(view, device, conv)
|
||||
}
|
||||
st::Dtype::I64 => {
|
||||
let conv = |x| {
|
||||
u32::try_from(x)
|
||||
.map_err(|_| Error::Msg(format!("out of bounds value for u32: {x}")))
|
||||
};
|
||||
convert_with_cast_::<i64, u32, _>(view, device, conv)
|
||||
}
|
||||
dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
|
||||
}
|
||||
}
|
||||
@ -227,7 +233,6 @@ fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
|
||||
match tensor.dtype() {
|
||||
DType::U8 => Ok(convert_back_::<u8>(tensor.to_vec1()?)),
|
||||
DType::U32 => Ok(convert_back_::<u32>(tensor.to_vec1()?)),
|
||||
DType::I64 => Ok(convert_back_::<i64>(tensor.to_vec1()?)),
|
||||
DType::F16 => Ok(convert_back_::<half::f16>(tensor.to_vec1()?)),
|
||||
DType::BF16 => Ok(convert_back_::<half::bf16>(tensor.to_vec1()?)),
|
||||
DType::F32 => Ok(convert_back_::<f32>(tensor.to_vec1()?)),
|
||||
|
@ -1,23 +0,0 @@
|
||||
use crate::{Result, Tensor, WithDType};
|
||||
|
||||
pub enum TensorScalar {
|
||||
Tensor(Tensor),
|
||||
Scalar(Tensor),
|
||||
}
|
||||
|
||||
pub trait TensorOrScalar {
|
||||
fn to_tensor_scalar(self) -> Result<TensorScalar>;
|
||||
}
|
||||
|
||||
impl TensorOrScalar for &Tensor {
|
||||
fn to_tensor_scalar(self) -> Result<TensorScalar> {
|
||||
Ok(TensorScalar::Tensor(self.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: WithDType> TensorOrScalar for T {
|
||||
fn to_tensor_scalar(self) -> Result<TensorScalar> {
|
||||
let scalar = Tensor::new(self, &crate::Device::Cpu)?;
|
||||
Ok(TensorScalar::Scalar(scalar))
|
||||
}
|
||||
}
|
@ -1,5 +1,3 @@
|
||||
//! The shape of a tensor is a tuple with the size of each of its dimensions.
|
||||
#![allow(clippy::redundant_closure_call)]
|
||||
use crate::{Error, Result};
|
||||
|
||||
#[derive(Clone, PartialEq, Eq)]
|
||||
@ -73,14 +71,6 @@ impl From<(usize, usize, usize, usize, usize)> for Shape {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<(usize, usize, usize, usize, usize, usize)> for Shape {
|
||||
fn from(d123456: (usize, usize, usize, usize, usize, usize)) -> Self {
|
||||
Self(vec![
|
||||
d123456.0, d123456.1, d123456.2, d123456.3, d123456.4, d123456.5,
|
||||
])
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<usize>> for Shape {
|
||||
fn from(dims: Vec<usize>) -> Self {
|
||||
Self(dims)
|
||||
@ -128,7 +118,6 @@ impl Shape {
|
||||
Self(dims.to_vec())
|
||||
}
|
||||
|
||||
/// The rank is the number of dimensions, 0 for a scalar value, 1 for a vector, etc.
|
||||
pub fn rank(&self) -> usize {
|
||||
self.0.len()
|
||||
}
|
||||
@ -137,12 +126,10 @@ impl Shape {
|
||||
self.0
|
||||
}
|
||||
|
||||
/// The dimensions as a slice of `usize`.
|
||||
pub fn dims(&self) -> &[usize] {
|
||||
&self.0
|
||||
}
|
||||
|
||||
/// The total number of elements, this is the product of all dimension sizes.
|
||||
pub fn elem_count(&self) -> usize {
|
||||
self.0.iter().product()
|
||||
}
|
||||
@ -194,75 +181,10 @@ impl Shape {
|
||||
true
|
||||
}
|
||||
|
||||
/// Modifies the shape by adding a list of additional dimensions at the end of the existing
|
||||
/// dimensions.
|
||||
pub fn extend(mut self, additional_dims: &[usize]) -> Self {
|
||||
self.0.extend(additional_dims);
|
||||
self
|
||||
}
|
||||
|
||||
/// Check whether the two shapes are compatible for broadcast, and if it is the case return the
|
||||
/// broadcasted shape. This is to be used for binary pointwise ops.
|
||||
pub(crate) fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
|
||||
let lhs = self;
|
||||
let lhs_dims = lhs.dims();
|
||||
let rhs_dims = rhs.dims();
|
||||
let lhs_ndims = lhs_dims.len();
|
||||
let rhs_ndims = rhs_dims.len();
|
||||
let bcast_ndims = usize::max(lhs_ndims, rhs_ndims);
|
||||
let mut bcast_dims = vec![0; bcast_ndims];
|
||||
for (idx, bcast_value) in bcast_dims.iter_mut().enumerate() {
|
||||
let rev_idx = bcast_ndims - idx;
|
||||
let l_value = if lhs_ndims < rev_idx {
|
||||
1
|
||||
} else {
|
||||
lhs_dims[lhs_ndims - rev_idx]
|
||||
};
|
||||
let r_value = if rhs_ndims < rev_idx {
|
||||
1
|
||||
} else {
|
||||
rhs_dims[rhs_ndims - rev_idx]
|
||||
};
|
||||
*bcast_value = if l_value == r_value {
|
||||
l_value
|
||||
} else if l_value == 1 {
|
||||
r_value
|
||||
} else if r_value == 1 {
|
||||
l_value
|
||||
} else {
|
||||
Err(Error::ShapeMismatchBinaryOp {
|
||||
lhs: lhs.clone(),
|
||||
rhs: rhs.clone(),
|
||||
op,
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
}
|
||||
Ok(Shape::from(bcast_dims))
|
||||
}
|
||||
|
||||
pub(crate) fn broadcast_shape_matmul(&self, rhs: &Self) -> Result<(Shape, Shape)> {
|
||||
let lhs = self;
|
||||
let lhs_dims = lhs.dims();
|
||||
let rhs_dims = rhs.dims();
|
||||
if lhs_dims.len() < 2 || rhs_dims.len() < 2 {
|
||||
crate::bail!("only 2d matrixes are supported {lhs:?} {rhs:?}")
|
||||
}
|
||||
let (m, lhs_k) = (lhs_dims[lhs_dims.len() - 2], lhs_dims[lhs_dims.len() - 1]);
|
||||
let (rhs_k, n) = (rhs_dims[rhs_dims.len() - 2], rhs_dims[rhs_dims.len() - 1]);
|
||||
if lhs_k != rhs_k {
|
||||
crate::bail!("different inner dimensions in broadcast matmul {lhs:?} {rhs:?}")
|
||||
}
|
||||
|
||||
let lhs_b = Self::from(&lhs_dims[..lhs_dims.len() - 2]);
|
||||
let rhs_b = Self::from(&rhs_dims[..rhs_dims.len() - 2]);
|
||||
let bcast = lhs_b.broadcast_shape_binary_op(&rhs_b, "broadcast_matmul")?;
|
||||
let bcast_dims = bcast.dims();
|
||||
|
||||
let bcast_lhs = [bcast_dims, &[m, lhs_k]].concat();
|
||||
let bcast_rhs = [bcast_dims, &[rhs_k, n]].concat();
|
||||
Ok((Shape::from(bcast_lhs), Shape::from(bcast_rhs)))
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Dim {
|
||||
@ -423,27 +345,6 @@ impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
|
||||
}
|
||||
}
|
||||
|
||||
impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim> Dims for (D1, D2, D3, D4) {
|
||||
fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
|
||||
let d0 = self.0.to_index(shape, op)?;
|
||||
let d1 = self.1.to_index(shape, op)?;
|
||||
let d2 = self.2.to_index(shape, op)?;
|
||||
let d3 = self.3.to_index(shape, op)?;
|
||||
Ok(vec![d0, d1, d2, d3])
|
||||
}
|
||||
}
|
||||
|
||||
impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim> Dims for (D1, D2, D3, D4, D5) {
|
||||
fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
|
||||
let d0 = self.0.to_index(shape, op)?;
|
||||
let d1 = self.1.to_index(shape, op)?;
|
||||
let d2 = self.2.to_index(shape, op)?;
|
||||
let d3 = self.3.to_index(shape, op)?;
|
||||
let d4 = self.4.to_index(shape, op)?;
|
||||
Ok(vec![d0, d1, d2, d3, d4])
|
||||
}
|
||||
}
|
||||
|
||||
extract_dims!(dims0, 0, |_: &[usize]| (), ());
|
||||
extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
|
||||
extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
|
||||
@ -482,171 +383,3 @@ mod tests {
|
||||
assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ShapeWithOneHole {
|
||||
fn into_shape(self, el_count: usize) -> Result<Shape>;
|
||||
}
|
||||
|
||||
impl<S: Into<Shape>> ShapeWithOneHole for S {
|
||||
fn into_shape(self, _el_count: usize) -> Result<Shape> {
|
||||
Ok(self.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl ShapeWithOneHole for ((),) {
|
||||
fn into_shape(self, el_count: usize) -> Result<Shape> {
|
||||
Ok(el_count.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl ShapeWithOneHole for ((), usize) {
|
||||
fn into_shape(self, el_count: usize) -> Result<Shape> {
|
||||
let ((), d1) = self;
|
||||
if el_count % d1 != 0 {
|
||||
crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
|
||||
}
|
||||
Ok((el_count / d1, d1).into())
|
||||
}
|
||||
}
|
||||
|
||||
impl ShapeWithOneHole for (usize, ()) {
|
||||
fn into_shape(self, el_count: usize) -> Result<Shape> {
|
||||
let (d1, ()) = self;
|
||||
if el_count % d1 != 0 {
|
||||
crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
|
||||
}
|
||||
Ok((d1, el_count / d1).into())
|
||||
}
|
||||
}
|
||||
|
||||
impl ShapeWithOneHole for ((), usize, usize) {
|
||||
fn into_shape(self, el_count: usize) -> Result<Shape> {
|
||||
let ((), d1, d2) = self;
|
||||
let d = d1 * d2;
|
||||
if el_count % d != 0 {
|
||||
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
|
||||
}
|
||||
Ok((el_count / d, d1, d2).into())
|
||||
}
|
||||
}
|
||||
|
||||
impl ShapeWithOneHole for (usize, (), usize) {
|
||||
fn into_shape(self, el_count: usize) -> Result<Shape> {
|
||||
let (d1, (), d2) = self;
|
||||
let d = d1 * d2;
|
||||
if el_count % d != 0 {
|
||||
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
|
||||
}
|
||||
Ok((d1, el_count / d, d2).into())
|
||||
}
|
||||
}
|
||||
|
||||
impl ShapeWithOneHole for (usize, usize, ()) {
|
||||
fn into_shape(self, el_count: usize) -> Result<Shape> {
|
||||
let (d1, d2, ()) = self;
|
||||
let d = d1 * d2;
|
||||
if el_count % d != 0 {
|
||||
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
|
||||
}
|
||||
Ok((d1, d2, el_count / d).into())
|
||||
}
|
||||
}
|
||||
|
||||
impl ShapeWithOneHole for ((), usize, usize, usize) {
|
||||
fn into_shape(self, el_count: usize) -> Result<Shape> {
|
||||
let ((), d1, d2, d3) = self;
|
||||
let d = d1 * d2 * d3;
|
||||
if el_count % d != 0 {
|
||||
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
|
||||
}
|
||||
Ok((el_count / d, d1, d2, d3).into())
|
||||
}
|
||||
}
|
||||
|
||||
impl ShapeWithOneHole for (usize, (), usize, usize) {
|
||||
fn into_shape(self, el_count: usize) -> Result<Shape> {
|
||||
let (d1, (), d2, d3) = self;
|
||||
let d = d1 * d2 * d3;
|
||||
if el_count % d != 0 {
|
||||
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
|
||||
}
|
||||
Ok((d1, el_count / d, d2, d3).into())
|
||||
}
|
||||
}
|
||||
|
||||
impl ShapeWithOneHole for (usize, usize, (), usize) {
|
||||
fn into_shape(self, el_count: usize) -> Result<Shape> {
|
||||
let (d1, d2, (), d3) = self;
|
||||
let d = d1 * d2 * d3;
|
||||
if el_count % d != 0 {
|
||||
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
|
||||
}
|
||||
Ok((d1, d2, el_count / d, d3).into())
|
||||
}
|
||||
}
|
||||
|
||||
impl ShapeWithOneHole for (usize, usize, usize, ()) {
|
||||
fn into_shape(self, el_count: usize) -> Result<Shape> {
|
||||
let (d1, d2, d3, ()) = self;
|
||||
let d = d1 * d2 * d3;
|
||||
if el_count % d != 0 {
|
||||
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
|
||||
}
|
||||
Ok((d1, d2, d3, el_count / d).into())
|
||||
}
|
||||
}
|
||||
|
||||
impl ShapeWithOneHole for ((), usize, usize, usize, usize) {
|
||||
fn into_shape(self, el_count: usize) -> Result<Shape> {
|
||||
let ((), d1, d2, d3, d4) = self;
|
||||
let d = d1 * d2 * d3 * d4;
|
||||
if el_count % d != 0 {
|
||||
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
|
||||
}
|
||||
Ok((el_count / d, d1, d2, d3, d4).into())
|
||||
}
|
||||
}
|
||||
|
||||
impl ShapeWithOneHole for (usize, (), usize, usize, usize) {
|
||||
fn into_shape(self, el_count: usize) -> Result<Shape> {
|
||||
let (d1, (), d2, d3, d4) = self;
|
||||
let d = d1 * d2 * d3 * d4;
|
||||
if el_count % d != 0 {
|
||||
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
|
||||
}
|
||||
Ok((d1, el_count / d, d2, d3, d4).into())
|
||||
}
|
||||
}
|
||||
|
||||
impl ShapeWithOneHole for (usize, usize, (), usize, usize) {
|
||||
fn into_shape(self, el_count: usize) -> Result<Shape> {
|
||||
let (d1, d2, (), d3, d4) = self;
|
||||
let d = d1 * d2 * d3 * d4;
|
||||
if el_count % d != 0 {
|
||||
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
|
||||
}
|
||||
Ok((d1, d2, el_count / d, d3, d4).into())
|
||||
}
|
||||
}
|
||||
|
||||
impl ShapeWithOneHole for (usize, usize, usize, (), usize) {
|
||||
fn into_shape(self, el_count: usize) -> Result<Shape> {
|
||||
let (d1, d2, d3, (), d4) = self;
|
||||
let d = d1 * d2 * d3 * d4;
|
||||
if el_count % d != 0 {
|
||||
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
|
||||
}
|
||||
Ok((d1, d2, d3, el_count / d, d4).into())
|
||||
}
|
||||
}
|
||||
|
||||
impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
|
||||
fn into_shape(self, el_count: usize) -> Result<Shape> {
|
||||
let (d1, d2, d3, d4, ()) = self;
|
||||
let d = d1 * d2 * d3 * d4;
|
||||
if el_count % d != 0 {
|
||||
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
|
||||
}
|
||||
Ok((d1, d2, d3, d4, el_count / d).into())
|
||||
}
|
||||
}
|
||||
|
@ -68,19 +68,6 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn powf(&self, layout: &Layout, alpha: f64) -> Result<Self> {
|
||||
match self {
|
||||
Storage::Cpu(storage) => {
|
||||
let storage = storage.powf(layout, alpha)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
Self::Cuda(storage) => {
|
||||
let storage = storage.powf(layout, alpha)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
|
||||
match self {
|
||||
Storage::Cpu(storage) => {
|
||||
@ -151,7 +138,7 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn apply_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> {
|
||||
pub(crate) fn custom_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> {
|
||||
match self {
|
||||
Self::Cpu(storage) => {
|
||||
let (storage, shape) = c.cpu_fwd(storage, l)?;
|
||||
@ -164,7 +151,7 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn apply_op2(
|
||||
pub(crate) fn custom_op2(
|
||||
&self,
|
||||
l1: &Layout,
|
||||
t2: &Self,
|
||||
@ -185,7 +172,7 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn apply_op3(
|
||||
pub(crate) fn custom_op3(
|
||||
&self,
|
||||
l1: &Layout,
|
||||
t2: &Self,
|
||||
@ -306,33 +293,6 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn conv_transpose2d(
|
||||
&self,
|
||||
l: &Layout,
|
||||
kernel: &Self,
|
||||
kernel_l: &Layout,
|
||||
params: &crate::conv::ParamsConvTranspose2D,
|
||||
) -> Result<Self> {
|
||||
self.same_device(kernel, "conv_transpose2d")?;
|
||||
self.same_dtype(kernel, "conv_transpose2d")?;
|
||||
match (self, &kernel) {
|
||||
(Storage::Cpu(inp), Storage::Cpu(kernel)) => {
|
||||
let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
|
||||
Ok(Self::Cpu(s))
|
||||
}
|
||||
(Storage::Cuda(inp), Storage::Cuda(kernel)) => {
|
||||
let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
|
||||
Ok(Self::Cuda(s))
|
||||
}
|
||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
op: "conv_transpose2d",
|
||||
}
|
||||
.bt()),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn avg_pool2d(
|
||||
&self,
|
||||
layout: &Layout,
|
||||
|
@ -1,10 +1,7 @@
|
||||
//! Tensors are N-dimenional matrixes of elements using a single data type.
|
||||
#![allow(clippy::redundant_closure_call)]
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::op::{
|
||||
BackpropOp, BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp,
|
||||
};
|
||||
use crate::scalar::TensorOrScalar;
|
||||
use crate::shape::{Dim, Dims};
|
||||
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
|
||||
use std::sync::{Arc, RwLock};
|
||||
@ -109,9 +106,7 @@ macro_rules! broadcast_binary_op {
|
||||
($fn_name:ident, $inner_fn_name:ident) => {
|
||||
pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
|
||||
let lhs = self;
|
||||
let shape = lhs
|
||||
.shape()
|
||||
.broadcast_shape_binary_op(rhs.shape(), stringify!($fn_name))?;
|
||||
let shape = lhs.broadcast_shape_binary_op(rhs, stringify!($fn_name))?;
|
||||
let l_broadcast = shape != *lhs.shape();
|
||||
let r_broadcast = shape != *rhs.shape();
|
||||
match (l_broadcast, r_broadcast) {
|
||||
@ -127,7 +122,7 @@ macro_rules! broadcast_binary_op {
|
||||
}
|
||||
|
||||
/// Creates a fresh tensor structure based on a storage and a shape, this uses contiguous strides.
|
||||
pub(crate) fn from_storage<S: Into<Shape>>(
|
||||
fn from_storage<S: Into<Shape>>(
|
||||
storage: Storage,
|
||||
shape: S,
|
||||
op: BackpropOp,
|
||||
@ -420,6 +415,48 @@ impl Tensor {
|
||||
Self::new_impl(array, shape.into(), device, false)
|
||||
}
|
||||
|
||||
pub(crate) fn broadcast_shape_binary_op<'a>(
|
||||
&'a self,
|
||||
rhs: &'a Self,
|
||||
op: &'static str,
|
||||
) -> Result<Shape> {
|
||||
let lhs = self;
|
||||
let lhs_dims = lhs.shape().dims();
|
||||
let rhs_dims = rhs.shape().dims();
|
||||
let lhs_ndims = lhs_dims.len();
|
||||
let rhs_ndims = rhs_dims.len();
|
||||
let bcast_ndims = usize::max(lhs_ndims, rhs_ndims);
|
||||
let mut bcast_dims = vec![0; bcast_ndims];
|
||||
for (idx, bcast_value) in bcast_dims.iter_mut().enumerate() {
|
||||
let rev_idx = bcast_ndims - idx;
|
||||
let l_value = if lhs_ndims < rev_idx {
|
||||
1
|
||||
} else {
|
||||
lhs_dims[lhs_ndims - rev_idx]
|
||||
};
|
||||
let r_value = if rhs_ndims < rev_idx {
|
||||
1
|
||||
} else {
|
||||
rhs_dims[rhs_ndims - rev_idx]
|
||||
};
|
||||
*bcast_value = if l_value == r_value {
|
||||
l_value
|
||||
} else if l_value == 1 {
|
||||
r_value
|
||||
} else if r_value == 1 {
|
||||
l_value
|
||||
} else {
|
||||
Err(Error::ShapeMismatchBinaryOp {
|
||||
lhs: self.shape().clone(),
|
||||
rhs: rhs.shape().clone(),
|
||||
op,
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
}
|
||||
Ok(Shape::from(bcast_dims))
|
||||
}
|
||||
|
||||
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
|
||||
let lhs = self.shape();
|
||||
let rhs = rhs.shape();
|
||||
@ -447,14 +484,10 @@ impl Tensor {
|
||||
binary_op!(mul, Mul);
|
||||
binary_op!(sub, Sub);
|
||||
binary_op!(div, Div);
|
||||
binary_op!(maximum, Maximum);
|
||||
binary_op!(minimum, Minimum);
|
||||
broadcast_binary_op!(broadcast_add, add);
|
||||
broadcast_binary_op!(broadcast_mul, mul);
|
||||
broadcast_binary_op!(broadcast_sub, sub);
|
||||
broadcast_binary_op!(broadcast_div, div);
|
||||
broadcast_binary_op!(broadcast_maximum, maximum);
|
||||
broadcast_binary_op!(broadcast_minimum, minimum);
|
||||
|
||||
unary_op!(recip, Recip);
|
||||
unary_op!(neg, Neg);
|
||||
@ -462,7 +495,6 @@ impl Tensor {
|
||||
unary_op!(log, Log);
|
||||
unary_op!(sin, Sin);
|
||||
unary_op!(cos, Cos);
|
||||
unary_op!(tanh, Tanh);
|
||||
unary_op!(abs, Abs);
|
||||
unary_op!(sqr, Sqr);
|
||||
unary_op!(sqrt, Sqrt);
|
||||
@ -495,25 +527,6 @@ impl Tensor {
|
||||
self.to_scalar::<S>()
|
||||
}
|
||||
|
||||
/// Repeat this tensor along the specified dimensions.
|
||||
pub fn repeat<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
|
||||
// Similar to PyTorch, we extend the number of dimensions of self if needed.
|
||||
let repeats = shape.into();
|
||||
let repeats = repeats.dims();
|
||||
let mut inp = if self.rank() < repeats.len() {
|
||||
let shape = [vec![1; repeats.len() - self.rank()], self.dims().to_vec()].concat();
|
||||
self.reshape(shape)?
|
||||
} else {
|
||||
self.clone()
|
||||
};
|
||||
for (idx, &repeat) in repeats.iter().enumerate() {
|
||||
if repeat > 1 {
|
||||
inp = Tensor::cat(&vec![&inp; repeat], idx)?
|
||||
}
|
||||
}
|
||||
Ok(inp)
|
||||
}
|
||||
|
||||
/// This operation multiplies the input tensor by `mul` then adds `add` and return the result.
|
||||
/// The input values `mul` and `add` are casted to the appropriate type so some rounding might
|
||||
/// be performed.
|
||||
@ -538,13 +551,6 @@ impl Tensor {
|
||||
Ok(from_storage(storage, self.shape(), op, false))
|
||||
}
|
||||
|
||||
/// Raise the tensor to some float exponent `e`.
|
||||
pub fn powf(&self, e: f64) -> Result<Self> {
|
||||
let storage = self.storage().powf(self.layout(), e)?;
|
||||
let op = BackpropOp::new1(self, |t| Op::Powf(t, e));
|
||||
Ok(from_storage(storage, self.shape(), op, false))
|
||||
}
|
||||
|
||||
fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
|
||||
if dim >= self.dims().len() {
|
||||
Err(Error::DimOutOfRange {
|
||||
@ -699,58 +705,18 @@ impl Tensor {
|
||||
self.sum_impl(sum_dims, false)
|
||||
}
|
||||
|
||||
/// Returns the mean of all elements in the input tensor. The mean is performed over all the
|
||||
/// input dimensions.
|
||||
///
|
||||
/// The resulting tensor has a shape that is similar to the shape of the input tensor, except
|
||||
/// that the number of elements for each dimension index in `mean_dims` is 1.
|
||||
///
|
||||
/// ```rust
|
||||
/// use candle_core::{Tensor, Device};
|
||||
/// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
|
||||
/// let s = a.mean_keepdim(0)?;
|
||||
/// assert_eq!(s.to_vec2::<f32>()?, &[[1., 2.]]);
|
||||
/// let s = a.mean_keepdim(1)?;
|
||||
/// assert_eq!(s.to_vec2::<f32>()?, &[[0.5], [2.5]]);
|
||||
/// let s = a.mean_keepdim((0, 1))?;
|
||||
/// assert_eq!(s.to_vec2::<f32>()?, &[[1.5]]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn mean_keepdim<D: Dims>(&self, mean_dims: D) -> Result<Self> {
|
||||
let mean_dims = mean_dims.to_indexes(self.shape(), "mean-keepdim")?;
|
||||
let reduced_dim: usize = mean_dims.iter().map(|i| self.dims()[*i]).product();
|
||||
let scale = 1f64 / (reduced_dim as f64);
|
||||
self.sum_impl(mean_dims, true)? * scale
|
||||
}
|
||||
|
||||
/// Returns the mean of all elements in the input tensor. The mean is performed over all the
|
||||
/// input dimensions and compared to `mean_keepdim` these dimensions are squeezed rather than
|
||||
/// kept.
|
||||
pub fn mean<D: Dims>(&self, mean_dims: D) -> Result<Self> {
|
||||
let mean_dims = mean_dims.to_indexes(self.shape(), "mean")?;
|
||||
let reduced_dim: usize = mean_dims.iter().map(|i| self.dims()[*i]).product();
|
||||
let scale = 1f64 / (reduced_dim as f64);
|
||||
self.sum_impl(mean_dims, false)? * scale
|
||||
}
|
||||
|
||||
/// Gathers the maximum value across the selected dimension. The resulting shape has the same
|
||||
/// number of dimensions as the original tensor and the select dimension has a single element.
|
||||
pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
|
||||
self.reduce_impl(dim, true, ReduceOp::Max)
|
||||
}
|
||||
|
||||
/// Similar to `max_keepdim` but the target dimension is squeezed.
|
||||
pub fn max<D: Dim>(&self, dim: D) -> Result<Self> {
|
||||
self.reduce_impl(dim, false, ReduceOp::Max)
|
||||
}
|
||||
|
||||
/// Gathers the minimum value across the selected dimension. The resulting shape has the same
|
||||
/// number of dimensions as the original tensor and the select dimension has a single element.
|
||||
pub fn min_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
|
||||
self.reduce_impl(dim, true, ReduceOp::Min)
|
||||
}
|
||||
|
||||
/// Similar to `min_keepdim` but the target dimension is squeezed.
|
||||
pub fn min<D: Dim>(&self, dim: D) -> Result<Self> {
|
||||
self.reduce_impl(dim, false, ReduceOp::Min)
|
||||
}
|
||||
@ -759,7 +725,6 @@ impl Tensor {
|
||||
self.reduce_impl(dim, true, ReduceOp::ArgMax)
|
||||
}
|
||||
|
||||
/// Similar to `argmax_keepdim` but the target dimension is squeezed.
|
||||
pub fn argmax<D: Dim>(&self, dim: D) -> Result<Self> {
|
||||
self.reduce_impl(dim, false, ReduceOp::ArgMax)
|
||||
}
|
||||
@ -768,24 +733,12 @@ impl Tensor {
|
||||
self.reduce_impl(dim, true, ReduceOp::ArgMin)
|
||||
}
|
||||
|
||||
/// Similar to `argmin_keepdim` but the target dimension is squeezed.
|
||||
pub fn argmin<D: Dim>(&self, dim: D) -> Result<Self> {
|
||||
self.reduce_impl(dim, false, ReduceOp::ArgMin)
|
||||
}
|
||||
|
||||
/// Element-wise comparison between two tensors, e.g. equality, greater than, ... The actual
|
||||
/// comparison operation is specified by the `op` argument.
|
||||
///
|
||||
/// The returned tensor has the same shape as the original tensors and uses `u8` elements.
|
||||
pub fn cmp<T: TensorOrScalar>(&self, rhs: T, op: CmpOp) -> Result<Self> {
|
||||
let rhs = match rhs.to_tensor_scalar()? {
|
||||
crate::scalar::TensorScalar::Tensor(rhs) => rhs,
|
||||
crate::scalar::TensorScalar::Scalar(rhs) => rhs
|
||||
.to_dtype(self.dtype())?
|
||||
.to_device(self.device())?
|
||||
.broadcast_as(self.shape())?,
|
||||
};
|
||||
let shape = self.same_shape_binary_op(&rhs, "cmp")?;
|
||||
pub fn cmp(&self, rhs: &Self, op: CmpOp) -> Result<Self> {
|
||||
let shape = self.same_shape_binary_op(rhs, "cmp")?;
|
||||
let storage = self
|
||||
.storage()
|
||||
.cmp(op, &rhs.storage(), self.layout(), rhs.layout())?;
|
||||
@ -793,45 +746,96 @@ impl Tensor {
|
||||
Ok(from_storage(storage, shape.dims(), op, false))
|
||||
}
|
||||
|
||||
/// Element-wise equality.
|
||||
pub fn eq<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
|
||||
pub fn eq(&self, rhs: &Self) -> Result<Self> {
|
||||
self.cmp(rhs, CmpOp::Eq)
|
||||
}
|
||||
|
||||
/// Element-wise non-equality.
|
||||
pub fn ne<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
|
||||
pub fn ne(&self, rhs: &Self) -> Result<Self> {
|
||||
self.cmp(rhs, CmpOp::Ne)
|
||||
}
|
||||
|
||||
/// Element-wise comparison with lower-than, the returned tensor uses value 1 where `self <
|
||||
/// rhs` and 0 otherwise.
|
||||
pub fn lt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
|
||||
pub fn lt(&self, rhs: &Self) -> Result<Self> {
|
||||
self.cmp(rhs, CmpOp::Lt)
|
||||
}
|
||||
|
||||
/// Element-wise comparison with greater-than, the returned tensor uses value 1 where `self >
|
||||
/// rhs` and 0 otherwise.
|
||||
pub fn gt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
|
||||
pub fn gt(&self, rhs: &Self) -> Result<Self> {
|
||||
self.cmp(rhs, CmpOp::Gt)
|
||||
}
|
||||
|
||||
/// Element-wise comparison with greater-equal, the returned tensor uses value 1 where `self >=
|
||||
/// rhs` and 0 otherwise.
|
||||
pub fn ge<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
|
||||
pub fn ge(&self, rhs: &Self) -> Result<Self> {
|
||||
self.cmp(rhs, CmpOp::Ge)
|
||||
}
|
||||
|
||||
/// Element-wise comparison with lower-equal, the returned tensor uses value 1 where `self <=
|
||||
/// rhs` and 0 otherwise.
|
||||
pub fn le<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
|
||||
pub fn le(&self, rhs: &Self) -> Result<Self> {
|
||||
self.cmp(rhs, CmpOp::Le)
|
||||
}
|
||||
|
||||
/// Upsample the input tensor to the `(target_h, target_w)` size, taking the value of the
|
||||
/// nearest element.
|
||||
///
|
||||
/// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
|
||||
/// tensor also has four dimensions, `(batch, channels, target_h, target_w)`.
|
||||
/// Applies a 1D convolution over the input tensor.
|
||||
pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
|
||||
let (c_out, c_in_k, k_size) = kernel.dims3()?;
|
||||
let (b_size, c_in, l_in) = self.dims3()?;
|
||||
if c_in != c_in_k {
|
||||
Err(Error::Conv1dInvalidArgs {
|
||||
inp_shape: self.shape().clone(),
|
||||
k_shape: kernel.shape().clone(),
|
||||
padding,
|
||||
stride,
|
||||
msg: "the number of in-channels on the input doesn't match the kernel size",
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
let params = crate::conv::ParamsConv1D {
|
||||
b_size,
|
||||
l_in,
|
||||
c_out,
|
||||
c_in,
|
||||
k_size,
|
||||
padding,
|
||||
stride,
|
||||
};
|
||||
let storage =
|
||||
self.storage()
|
||||
.conv1d(self.layout(), &kernel.storage(), kernel.layout(), ¶ms)?;
|
||||
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D {
|
||||
arg,
|
||||
kernel,
|
||||
padding,
|
||||
stride,
|
||||
});
|
||||
let out_dims = params.out_dims();
|
||||
Ok(from_storage(storage, out_dims, op, false))
|
||||
}
|
||||
|
||||
pub fn conv2d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
|
||||
let (b_size, c_in, i_h, i_w) = self.dims4()?;
|
||||
let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
|
||||
if c_in != c_in_k {
|
||||
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
|
||||
}
|
||||
let params = crate::conv::ParamsConv2D {
|
||||
b_size,
|
||||
i_h,
|
||||
i_w,
|
||||
k_h,
|
||||
k_w,
|
||||
c_out,
|
||||
c_in,
|
||||
padding,
|
||||
stride,
|
||||
};
|
||||
let storage =
|
||||
self.storage()
|
||||
.conv2d(self.layout(), &kernel.storage(), kernel.layout(), ¶ms)?;
|
||||
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D {
|
||||
arg,
|
||||
kernel,
|
||||
padding,
|
||||
stride,
|
||||
});
|
||||
let out_dims = params.out_dims();
|
||||
Ok(from_storage(storage, out_dims, op, false))
|
||||
}
|
||||
|
||||
pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
|
||||
let (n, c, _h, _w) = self.dims4()?;
|
||||
let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
|
||||
@ -841,26 +845,7 @@ impl Tensor {
|
||||
Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
|
||||
}
|
||||
|
||||
/// 2D average pooling over an input tensor with multiple channels.
|
||||
///
|
||||
/// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
|
||||
/// tensor also has four dimensions, `(batch, channels, h', w')`. The pooling is performed on
|
||||
/// the two last dimensions using a kernel of size `sz`. The returned element is the average
|
||||
/// value over the kernel window.
|
||||
pub fn avg_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
|
||||
let sz = sz.to_usize2();
|
||||
self.avg_pool2d_with_stride(sz, sz)
|
||||
}
|
||||
|
||||
/// Same as `avg_pool2d` but with a `stride` that can be set to a value different from the
|
||||
/// kernel size.
|
||||
pub fn avg_pool2d_with_stride<T: crate::ToUsize2>(
|
||||
&self,
|
||||
kernel_size: T,
|
||||
stride: T,
|
||||
) -> Result<Self> {
|
||||
let kernel_size = kernel_size.to_usize2();
|
||||
let stride = stride.to_usize2();
|
||||
pub fn avg_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> {
|
||||
let (n, c, h, w) = self.dims4()?;
|
||||
// https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d
|
||||
let h_out = (h - kernel_size.0) / stride.0 + 1;
|
||||
@ -876,26 +861,7 @@ impl Tensor {
|
||||
Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
|
||||
}
|
||||
|
||||
/// 2D max pooling over an input tensor with multiple channels.
|
||||
///
|
||||
/// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
|
||||
/// tensor also has four dimensions, `(batch, channels, h', w')`. The pooling is performed on
|
||||
/// the two last dimensions using a kernel of size `sz`, the returned element is the maximum
|
||||
/// value over the kernel window.
|
||||
pub fn max_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
|
||||
let sz = sz.to_usize2();
|
||||
self.max_pool2d_with_stride(sz, sz)
|
||||
}
|
||||
|
||||
/// Same as `max_pool2d` but with a `stride` that can be set to a value different from the
|
||||
/// kernel size.
|
||||
pub fn max_pool2d_with_stride<T: crate::ToUsize2>(
|
||||
&self,
|
||||
kernel_size: T,
|
||||
stride: T,
|
||||
) -> Result<Self> {
|
||||
let kernel_size = kernel_size.to_usize2();
|
||||
let stride = stride.to_usize2();
|
||||
pub fn max_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> {
|
||||
let (n, c, h, w) = self.dims4()?;
|
||||
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
|
||||
let h_out = (h - kernel_size.0) / stride.0 + 1;
|
||||
@ -961,28 +927,6 @@ impl Tensor {
|
||||
Ok(from_storage(storage, c_shape, op, false))
|
||||
}
|
||||
|
||||
/// Matrix-multiplication with broadcasting support.
|
||||
///
|
||||
/// Compared to `matmul` the two matrixes are allowed to have different dimensions as long as
|
||||
/// they are compatible for broadcast. E.g. if `self` has shape `(j, 1, n, k)` and `rhs` has
|
||||
/// shape `(l, k, m)`, the output will have shape `(j, l, n, m)`.
|
||||
pub fn broadcast_matmul(&self, rhs: &Self) -> Result<Self> {
|
||||
let lhs = self;
|
||||
let (l_shape, r_shape) = lhs.shape().broadcast_shape_matmul(rhs.shape())?;
|
||||
let l_broadcast = l_shape != *lhs.shape();
|
||||
let r_broadcast = r_shape != *rhs.shape();
|
||||
// TODO: Avoid concretising the broadcasted matrixes via contiguous.
|
||||
match (l_broadcast, r_broadcast) {
|
||||
(true, true) => lhs
|
||||
.broadcast_as(&l_shape)?
|
||||
.contiguous()?
|
||||
.matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?),
|
||||
(false, true) => lhs.matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?),
|
||||
(true, false) => lhs.broadcast_as(&l_shape)?.contiguous()?.matmul(rhs),
|
||||
(false, false) => lhs.matmul(rhs),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a tensor with the same shape as the input tensor, the values are taken from
|
||||
/// `on_true` if the input tensor value is not zero, and `on_false` at the positions where the
|
||||
/// input tensor is equal to zero.
|
||||
@ -1075,7 +1019,6 @@ impl Tensor {
|
||||
Ok(from_storage(storage, self.shape(), op, false))
|
||||
}
|
||||
|
||||
/// Accumulate element from `source` at indexes `indexes` and add them to `self`.
|
||||
pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
|
||||
let dim = dim.to_index(self.shape(), "index-add")?;
|
||||
let source_dims = source.dims();
|
||||
@ -1124,17 +1067,6 @@ impl Tensor {
|
||||
Ok(from_storage(storage, self.shape(), op, false))
|
||||
}
|
||||
|
||||
/// Gather values across the target dimension.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `self` - The input tensor.
|
||||
/// * `indexes` - The indices of elements to gather, this should have the same shape as `self`
|
||||
/// but can have a different number of elements on the target dimension.
|
||||
/// * `dim` - the target dimension.
|
||||
///
|
||||
/// The resulting tensor has the same shape as `indexes` and use values from `self` indexed on
|
||||
/// dimension `dim` by the values in `indexes`.
|
||||
pub fn gather<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
|
||||
let dim = dim.to_index(self.shape(), "gather")?;
|
||||
let self_dims = self.dims();
|
||||
@ -1165,13 +1097,6 @@ impl Tensor {
|
||||
Ok(from_storage(storage, indexes.shape(), op, false))
|
||||
}
|
||||
|
||||
/// Select values for the input tensor at the target indexes across the specified dimension.
|
||||
///
|
||||
/// The `indexes` is argument is an int tensor with a single dimension.
|
||||
/// The output has the same number of dimension as the `self` input. The target dimension of
|
||||
/// the output has length the length of `indexes` and the values are taken from `self` using
|
||||
/// the index from `indexes`. Other dimensions have the same number of elements as the input
|
||||
/// tensor.
|
||||
pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
|
||||
let dim = dim.to_index(self.shape(), "index-select")?;
|
||||
let indexes_len = match indexes.dims() {
|
||||
@ -1379,10 +1304,6 @@ impl Tensor {
|
||||
self.sum(dims)
|
||||
}
|
||||
|
||||
pub fn mean_all(&self) -> Result<Tensor> {
|
||||
self.sum_all()? / self.elem_count() as f64
|
||||
}
|
||||
|
||||
fn flatten_<D1: Dim, D2: Dim>(
|
||||
&self,
|
||||
start_dim: Option<D1>,
|
||||
@ -1504,42 +1425,6 @@ impl Tensor {
|
||||
Ok(Tensor(Arc::new(tensor_)))
|
||||
}
|
||||
|
||||
/// Returns a tensor with the same data as the input where the dimensions have been permuted.
|
||||
/// dims must be a permutation, i.e. include each dimension index exactly once.
|
||||
///
|
||||
/// ```rust
|
||||
/// use candle_core::{Tensor, Device};
|
||||
/// let tensor = Tensor::arange(0u32, 120u32, &Device::Cpu)?.reshape((2, 3, 4, 5))?;
|
||||
/// assert_eq!(tensor.dims(), &[2, 3, 4, 5]);
|
||||
/// let tensor = tensor.permute((2, 3, 1, 0))?;
|
||||
/// assert_eq!(tensor.dims(), &[4, 5, 3, 2]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn permute<D: Dims>(&self, dims: D) -> Result<Tensor> {
|
||||
let dims = dims.to_indexes(self.shape(), "permute")?;
|
||||
// O(n^2) permutation check but these arrays are small.
|
||||
let is_permutation =
|
||||
dims.len() == self.rank() && (0..dims.len()).all(|i| dims.contains(&i));
|
||||
if !is_permutation {
|
||||
crate::bail!(
|
||||
"dimension mismatch in permute, tensor {:?}, dims: {:?}",
|
||||
self.dims(),
|
||||
dims
|
||||
)
|
||||
}
|
||||
let op = BackpropOp::new1(self, |t| Op::Permute(t, dims.clone()));
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage: self.storage.clone(),
|
||||
layout: self.layout.permute(&dims)?,
|
||||
op,
|
||||
is_variable: false,
|
||||
dtype: self.dtype,
|
||||
device: self.device.clone(),
|
||||
};
|
||||
Ok(Tensor(Arc::new(tensor_)))
|
||||
}
|
||||
|
||||
/// Returns true if the data is stored in a C contiguous (aka row major) way.
|
||||
pub fn is_contiguous(&self) -> bool {
|
||||
self.layout.is_contiguous()
|
||||
@ -1693,15 +1578,12 @@ impl Tensor {
|
||||
Ok(from_storage(storage, shape, BackpropOp::none(), true))
|
||||
}
|
||||
|
||||
// TODO: Do we want to allow target shape using -1 on some dimensions?
|
||||
/// Reshape returns a tensor with the target shape provided that the number of elements of the
|
||||
/// original tensor is the same.
|
||||
/// If the input tensor is contiguous, this is a view on the original data. Otherwise this uses
|
||||
/// a new storage and copies the data over, the returned tensor is always contiguous.
|
||||
///
|
||||
/// The shape can be specified using a tuple of `usize` and at most one `()` in which case
|
||||
/// the behavior is the same as when using `-1` in PyTorch: this dimension size is adjusted so
|
||||
/// as to match the number of elements in the tensor.
|
||||
///
|
||||
/// ```rust
|
||||
/// # use candle_core::{Tensor, DType, Device, D};
|
||||
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
||||
@ -1711,14 +1593,10 @@ impl Tensor {
|
||||
///
|
||||
/// let c = a.reshape((3, 2))?;
|
||||
/// assert_eq!(c.shape().dims(), &[3, 2]);
|
||||
///
|
||||
/// let c = a.reshape((2, (), 1))?;
|
||||
/// assert_eq!(c.shape().dims(), &[2, 3, 1]);
|
||||
///
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn reshape<S: crate::shape::ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
|
||||
let shape = s.into_shape(self.elem_count())?;
|
||||
pub fn reshape<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
|
||||
let shape = shape.into();
|
||||
if shape.elem_count() != self.elem_count() {
|
||||
return Err(Error::ShapeMismatchBinaryOp {
|
||||
lhs: self.shape().clone(),
|
||||
@ -1941,8 +1819,6 @@ impl Tensor {
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
/// Pad the input tensor using 0s along dimension `dim`. This adds `left` elements before the
|
||||
/// input tensor values and `right` elements after.
|
||||
pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
|
||||
if left == 0 && right == 0 {
|
||||
Ok(self.clone())
|
||||
@ -1969,12 +1845,7 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the `forward` method of `m` on `self`.
|
||||
pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
|
||||
m.forward(self)
|
||||
}
|
||||
|
||||
pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
|
||||
fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
|
||||
self.storage.read().unwrap()
|
||||
}
|
||||
|
||||
@ -1999,53 +1870,22 @@ impl Tensor {
|
||||
std::ptr::eq(lhs, rhs)
|
||||
}
|
||||
|
||||
/// Applies a unary custom op without backward support
|
||||
pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self> {
|
||||
let (storage, shape) = self.storage().apply_op1(self.layout(), c)?;
|
||||
Ok(from_storage(storage, shape, BackpropOp::none(), false))
|
||||
}
|
||||
|
||||
/// Applies a binary custom op without backward support
|
||||
pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self> {
|
||||
let (storage, shape) =
|
||||
self.storage()
|
||||
.apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?;
|
||||
Ok(from_storage(storage, shape, BackpropOp::none(), false))
|
||||
}
|
||||
|
||||
/// Applies a ternary custom op without backward support
|
||||
pub fn apply_op3_no_bwd<C: CustomOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<Self> {
|
||||
let (storage, shape) = self.storage().apply_op3(
|
||||
self.layout(),
|
||||
&t2.storage(),
|
||||
t2.layout(),
|
||||
&t3.storage(),
|
||||
t3.layout(),
|
||||
c,
|
||||
)?;
|
||||
Ok(from_storage(storage, shape, BackpropOp::none(), false))
|
||||
}
|
||||
|
||||
/// Applies a unary custom op.
|
||||
pub fn apply_op1_arc(&self, c: Arc<Box<dyn CustomOp1 + Send + Sync>>) -> Result<Self> {
|
||||
pub fn custom_op1_arc(&self, c: Arc<Box<dyn CustomOp1>>) -> Result<Self> {
|
||||
let (storage, shape) = self
|
||||
.storage()
|
||||
.apply_op1(self.layout(), c.as_ref().as_ref())?;
|
||||
.custom_op1(self.layout(), c.as_ref().as_ref())?;
|
||||
let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>(&self, c: C) -> Result<Self> {
|
||||
self.apply_op1_arc(Arc::new(Box::new(c)))
|
||||
pub fn custom_op1<C: 'static + CustomOp1>(&self, c: C) -> Result<Self> {
|
||||
self.custom_op1_arc(Arc::new(Box::new(c)))
|
||||
}
|
||||
|
||||
/// Applies a binary custom op.
|
||||
pub fn apply_op2_arc(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
c: Arc<Box<dyn CustomOp2 + Send + Sync>>,
|
||||
) -> Result<Self> {
|
||||
let (storage, shape) = self.storage().apply_op2(
|
||||
pub fn custom_op2_arc(&self, rhs: &Self, c: Arc<Box<dyn CustomOp2>>) -> Result<Self> {
|
||||
let (storage, shape) = self.storage().custom_op2(
|
||||
self.layout(),
|
||||
&rhs.storage(),
|
||||
rhs.layout(),
|
||||
@ -2055,18 +1895,13 @@ impl Tensor {
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>(&self, r: &Self, c: C) -> Result<Self> {
|
||||
self.apply_op2_arc(r, Arc::new(Box::new(c)))
|
||||
pub fn custom_op2<C: 'static + CustomOp2>(&self, r: &Self, c: C) -> Result<Self> {
|
||||
self.custom_op2_arc(r, Arc::new(Box::new(c)))
|
||||
}
|
||||
|
||||
/// Applies a ternary custom op.
|
||||
pub fn apply_op3_arc(
|
||||
&self,
|
||||
t2: &Self,
|
||||
t3: &Self,
|
||||
c: Arc<Box<dyn CustomOp3 + Send + Sync>>,
|
||||
) -> Result<Self> {
|
||||
let (storage, shape) = self.storage().apply_op3(
|
||||
pub fn custom_op3_arc(&self, t2: &Self, t3: &Self, c: Arc<Box<dyn CustomOp3>>) -> Result<Self> {
|
||||
let (storage, shape) = self.storage().custom_op3(
|
||||
self.layout(),
|
||||
&t2.storage(),
|
||||
t2.layout(),
|
||||
@ -2080,13 +1915,8 @@ impl Tensor {
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
pub fn apply_op3<C: 'static + CustomOp3 + Send + Sync>(
|
||||
&self,
|
||||
t2: &Self,
|
||||
t3: &Self,
|
||||
c: C,
|
||||
) -> Result<Self> {
|
||||
self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
|
||||
pub fn custom_op3<C: 'static + CustomOp3>(&self, t2: &Self, t3: &Self, c: C) -> Result<Self> {
|
||||
self.custom_op3_arc(t2, t3, Arc::new(Box::new(c)))
|
||||
}
|
||||
}
|
||||
|
||||
@ -2108,22 +1938,6 @@ macro_rules! bin_trait {
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<Tensor> for Result<B> {
|
||||
type Output = Result<Tensor>;
|
||||
|
||||
fn $fn1(self, rhs: Tensor) -> Self::Output {
|
||||
Tensor::$fn1(self?.borrow(), &rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<&Tensor> for Result<B> {
|
||||
type Output = Result<Tensor>;
|
||||
|
||||
fn $fn1(self, rhs: &Tensor) -> Self::Output {
|
||||
Tensor::$fn1(self?.borrow(), rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<Result<B>> for Tensor {
|
||||
type Output = Result<Tensor>;
|
||||
|
||||
@ -2162,69 +1976,3 @@ bin_trait!(Add, add, |_| 1., |v| v);
|
||||
bin_trait!(Sub, sub, |_| 1., |v: f64| -v);
|
||||
bin_trait!(Mul, mul, |v| v, |_| 0.);
|
||||
bin_trait!(Div, div, |v| 1. / v, |_| 0.);
|
||||
|
||||
impl std::ops::Add<Tensor> for f64 {
|
||||
type Output = Result<Tensor>;
|
||||
|
||||
fn add(self, rhs: Tensor) -> Self::Output {
|
||||
rhs + self
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Add<&Tensor> for f64 {
|
||||
type Output = Result<Tensor>;
|
||||
|
||||
fn add(self, rhs: &Tensor) -> Self::Output {
|
||||
rhs + self
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Mul<Tensor> for f64 {
|
||||
type Output = Result<Tensor>;
|
||||
|
||||
fn mul(self, rhs: Tensor) -> Self::Output {
|
||||
rhs * self
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Mul<&Tensor> for f64 {
|
||||
type Output = Result<Tensor>;
|
||||
|
||||
fn mul(self, rhs: &Tensor) -> Self::Output {
|
||||
rhs * self
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Sub<Tensor> for f64 {
|
||||
type Output = Result<Tensor>;
|
||||
|
||||
fn sub(self, rhs: Tensor) -> Self::Output {
|
||||
rhs.affine(-1., self)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Sub<&Tensor> for f64 {
|
||||
type Output = Result<Tensor>;
|
||||
|
||||
fn sub(self, rhs: &Tensor) -> Self::Output {
|
||||
rhs.affine(-1., self)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Div<Tensor> for f64 {
|
||||
type Output = Result<Tensor>;
|
||||
|
||||
#[allow(clippy::suspicious_arithmetic_impl)]
|
||||
fn div(self, rhs: Tensor) -> Self::Output {
|
||||
rhs.recip()? * self
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Div<&Tensor> for f64 {
|
||||
type Output = Result<Tensor>;
|
||||
|
||||
#[allow(clippy::suspicious_arithmetic_impl)]
|
||||
fn div(self, rhs: &Tensor) -> Self::Output {
|
||||
rhs.recip()? * self
|
||||
}
|
||||
}
|
||||
|
@ -22,19 +22,3 @@ pub fn has_mkl() -> bool {
|
||||
pub fn cuda_is_available() -> bool {
|
||||
cfg!(feature = "cuda")
|
||||
}
|
||||
|
||||
pub fn with_avx() -> bool {
|
||||
cfg!(target_feature = "avx")
|
||||
}
|
||||
|
||||
pub fn with_neon() -> bool {
|
||||
cfg!(target_feature = "neon")
|
||||
}
|
||||
|
||||
pub fn with_simd128() -> bool {
|
||||
cfg!(target_feature = "simd128")
|
||||
}
|
||||
|
||||
pub fn with_f16c() -> bool {
|
||||
cfg!(target_feature = "f16c")
|
||||
}
|
||||
|
@ -1,5 +1,6 @@
|
||||
mod test_utils;
|
||||
use anyhow::Result;
|
||||
use candle_core::{test_device, test_utils, Device, IndexOp, Tensor};
|
||||
use candle_core::{Device, Tensor};
|
||||
|
||||
/* This test is based on the following script.
|
||||
import torch
|
||||
@ -32,13 +33,13 @@ fn conv1d(dev: &Device) -> Result<()> {
|
||||
dev,
|
||||
)?
|
||||
.reshape((2, 4, 3))?;
|
||||
let res = t.conv1d(&w, 0, 1, 1, 1)?;
|
||||
let res = t.conv1d(&w, 0, 1)?;
|
||||
assert_eq!(res.dims(), [1, 2, 3]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
[2.6357, -1.3336, 4.1393, -1.1784, 3.5675, 0.5069]
|
||||
);
|
||||
let res = t.conv1d(&w, /*padding*/ 1, 1, 1, 1)?;
|
||||
let res = t.conv1d(&w, /*padding*/ 1, 1)?;
|
||||
assert_eq!(res.dims(), [1, 2, 5]);
|
||||
// Same as pytorch default padding: use zeros.
|
||||
assert_eq!(
|
||||
@ -51,13 +52,13 @@ fn conv1d(dev: &Device) -> Result<()> {
|
||||
fn conv1d_small(dev: &Device) -> Result<()> {
|
||||
let t = Tensor::new(&[0.4056f32, -0.8689, -0.0773, -1.5630], dev)?.reshape((1, 1, 4))?;
|
||||
let w = Tensor::new(&[1f32, 0., 0.], dev)?.reshape((1, 1, 3))?;
|
||||
let res = t.conv1d(&w, 0, 1, 1, 1)?;
|
||||
let res = t.conv1d(&w, 0, 1)?;
|
||||
assert_eq!(res.dims(), [1, 1, 2]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
[0.4056, -0.8689]
|
||||
);
|
||||
let res = t.conv1d(&w, /*padding*/ 1, 1, 1, 1)?;
|
||||
let res = t.conv1d(&w, /*padding*/ 1, 1)?;
|
||||
assert_eq!(res.dims(), [1, 1, 4]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
@ -76,19 +77,6 @@ print(t.flatten())
|
||||
print(w.flatten())
|
||||
res = torch.nn.functional.conv2d(t, w)
|
||||
print(res.flatten())
|
||||
|
||||
w_t = w.transpose(0, 1)
|
||||
res = torch.nn.functional.conv_transpose2d(t, w_t)
|
||||
print(res.shape)
|
||||
print(res)
|
||||
|
||||
res = torch.nn.functional.conv2d(t, w, dilation=2)
|
||||
print(res.shape)
|
||||
print(res[0])
|
||||
|
||||
res = torch.nn.functional.conv_transpose2d(t, w_t, dilation=2)
|
||||
print(res.shape)
|
||||
print(res)
|
||||
*/
|
||||
fn conv2d(dev: &Device) -> Result<()> {
|
||||
let t = Tensor::new(
|
||||
@ -121,7 +109,7 @@ fn conv2d(dev: &Device) -> Result<()> {
|
||||
)?;
|
||||
let t = t.reshape((1, 4, 5, 5))?;
|
||||
let w = w.reshape((2, 4, 3, 3))?;
|
||||
let res = t.conv2d(&w, 0, 1, 1, 1)?;
|
||||
let res = t.conv2d(&w, 0, 1)?;
|
||||
assert_eq!(res.dims(), [1, 2, 3, 3]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
@ -130,69 +118,6 @@ fn conv2d(dev: &Device) -> Result<()> {
|
||||
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
|
||||
]
|
||||
);
|
||||
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
||||
assert_eq!(res.dims(), [1, 2, 7, 7]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec3_round(&res.i(0)?, 4)?,
|
||||
[
|
||||
[
|
||||
[-1.9918, 2.6797, -0.4599, -1.6037, 1.4131, -2.4012, 2.9277],
|
||||
[1.8016, -3.5361, 1.0757, 3.5395, -8.2168, -3.2023, 0.5375],
|
||||
[0.8243, 1.8675, 7.8929, -4.0746, -6.4415, 5.1139, 1.6889],
|
||||
[0.2722, 8.9679, 3.3477, 1.8514, -4.2896, -3.8228, -7.5632],
|
||||
[-8.5412, -5.8142, -7.1587, -1.6095, 0.4651, 0.2748, -2.0985],
|
||||
[2.0833, -0.6482, -12.1692, -4.1284, -2.9765, -0.0656, -4.5114],
|
||||
[5.307, 2.6957, 2.3087, 1.0478, 0.7808, -1.1519, -0.9579]
|
||||
],
|
||||
[
|
||||
[1.089, 0.1872, -0.6408, -0.9897, 0.8503, 1.1019, -0.9211],
|
||||
[-0.1741, -0.2915, 4.2472, 1.9417, 1.65, 0.6303, -4.7131],
|
||||
[1.6555, 2.4026, -2.9293, 2.9953, 0.5328, 3.5873, -0.9621],
|
||||
[-1.4289, -3.2787, 4.1747, -6.0341, -4.6341, -5.7945, 4.142],
|
||||
[7.5973, 6.4431, 5.9872, 2.1639, -8.6566, 3.3143, -3.4059],
|
||||
[-0.8775, -3.048, 11.6543, 0.6442, 2.3218, -0.4765, 1.1516],
|
||||
[-5.5423, -2.5188, 1.0754, -0.0563, -2.9386, -1.1504, 1.0171]
|
||||
]
|
||||
]
|
||||
);
|
||||
// Dilations.
|
||||
let res = t.conv2d(&w, 0, 1, 2, 1)?;
|
||||
assert_eq!(res.dims(), [1, 2, 1, 1]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
[2.45, -2.3504],
|
||||
);
|
||||
|
||||
// Transpose and dilations.
|
||||
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2)?;
|
||||
assert_eq!(res.dims(), [1, 2, 9, 9]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec3_round(&res.i(0)?, 4)?,
|
||||
[
|
||||
[
|
||||
[-1.9918, 3.1652, -0.6778, -4.3442, 4.4351, 0.6652, -3.0124, -0.6031, 2.9277],
|
||||
[2.7036, -1.7156, -0.3969, 1.0516, 1.6381, -2.8886, -0.205, 2.4682, -1.0499],
|
||||
[-0.9459, 3.1631, 3.707, -4.8369, -8.5166, -1.4496, -2.7559, -3.2698, 1.4376],
|
||||
[-0.2157, 3.7786, -2.0252, -4.2633, 3.6731, -1.5142, 5.9391, -0.2622, -0.141],
|
||||
[-6.8121, -3.1744, 1.5945, 3.0637, -9.6088, 1.4446, 2.9489, -3.0082, -7.3822],
|
||||
[0.2371, 3.3303, 0.3861, 2.2646, -4.6784, 4.1235, -0.0109, 0.3176, -0.03],
|
||||
[-2.5339, -2.9564, -3.4518, -4.4594, -9.1873, -1.9709, -0.4676, 0.51, -3.5024],
|
||||
[4.007, 0.3067, -2.2954, 1.1105, -0.1992, 1.6372, -2.9268, 0.2807, -1.2787],
|
||||
[5.307, 1.1317, 1.3518, 0.9049, 3.8116, -0.4075, -0.8874, -0.2241, -0.9579]
|
||||
],
|
||||
[
|
||||
[1.089, -0.6483, 0.0726, -0.4752, -1.3283, 1.7103, 1.0703, 0.1076, -0.9211],
|
||||
[-0.8629, 0.1376, 0.3202, 2.0955, 0.9696, 2.8988, -1.0012, 1.5049, -0.1278],
|
||||
[1.9286, -1.5255, -2.9563, 2.4589, 3.3611, -0.6951, 0.3525, -1.7724, -5.9861],
|
||||
[1.1226, 2.1561, 3.6417, 4.7546, -0.692, 4.4126, -5.1902, 6.0805, 2.3185],
|
||||
[1.0111, 0.3604, 0.6432, -3.6605, 7.9517, -9.2955, -5.2988, -3.7803, -2.0642],
|
||||
[3.3172, -1.7967, -3.6576, -2.0942, 1.3158, 0.112, -1.7405, 2.9167, 0.7957],
|
||||
[5.1001, 1.8995, -1.8639, 1.1262, 9.9629, 2.683, -3.6319, -1.1607, 0.5856],
|
||||
[-4.8445, -0.5642, 4.2317, 0.0856, 1.2267, -0.5712, 1.736, 1.0997, 0.6908],
|
||||
[-5.5423, -1.1831, -1.2176, 0.0843, 0.0446, -0.7545, -2.4798, -0.0827, 1.0171]
|
||||
]
|
||||
]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -206,16 +131,6 @@ print(t.flatten())
|
||||
print(w.flatten())
|
||||
res = torch.nn.functional.conv2d(t, w)
|
||||
print(res.flatten())
|
||||
|
||||
w_t = w.transpose(0, 1)
|
||||
res = torch.nn.functional.conv_transpose2d(t, w_t)
|
||||
print(res.shape)
|
||||
print(res.flatten())
|
||||
|
||||
t_t = w.transpose(0, 1)
|
||||
res = torch.nn.functional.conv_transpose2d(t_t, w)
|
||||
print(res.shape)
|
||||
print(res.flatten())
|
||||
*/
|
||||
fn conv2d_small(dev: &Device) -> Result<()> {
|
||||
let t = Tensor::new(
|
||||
@ -228,41 +143,12 @@ fn conv2d_small(dev: &Device) -> Result<()> {
|
||||
let w = Tensor::new(&[-0.9259f32, 1.3017], dev)?;
|
||||
let t = t.reshape((1, 2, 3, 3))?;
|
||||
let w = w.reshape((1, 2, 1, 1))?;
|
||||
let res = t.conv2d(&w, 0, 1, 1, 1)?;
|
||||
let res = t.conv2d(&w, 0, 1)?;
|
||||
assert_eq!(res.dims(), [1, 1, 3, 3]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
[0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539]
|
||||
);
|
||||
let res = t.conv2d(&w, 2, 1, 1, 1)?;
|
||||
assert_eq!(res.dims(), [1, 1, 7, 7]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
[
|
||||
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
||||
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1640, -0.0111, -0.1742, 0.0000, 0.0000,
|
||||
0.0000, 0.0000, 2.6437, -2.0268, 1.1823, 0.0000, 0.0000, 0.0000, 0.0000, 3.2855,
|
||||
-1.0324, 0.2539, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
||||
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000
|
||||
]
|
||||
);
|
||||
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
||||
assert_eq!(res.dims(), [1, 1, 3, 3]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
[0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539],
|
||||
);
|
||||
let res = t.transpose(0, 1)?.conv_transpose2d(&w, 0, 0, 1, 1)?;
|
||||
assert_eq!(res.dims(), [2, 2, 3, 3]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
[
|
||||
-0.3755, 0.8045, -0.6336, -0.2218, -1.1369, 0.8599, 1.5768, -0.1268, -0.1728, 0.528,
|
||||
-1.131, 0.8908, 0.3118, 1.5984, -1.2089, -2.2168, 0.1783, 0.2429, -0.3838, 0.5802,
|
||||
-0.3268, -2.0382, 0.6329, -0.2293, -1.2154, 0.6441, -0.3035, 0.5396, -0.8156, 0.4594,
|
||||
2.8654, -0.8898, 0.3224, 1.7087, -0.9056, 0.4267
|
||||
]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -276,7 +162,7 @@ fn conv2d_smaller(dev: &Device) -> Result<()> {
|
||||
let w = Tensor::new(&[1f32, 1., 1., 1., 1., 1., 1., 1., 1.], dev)?;
|
||||
let t = t.reshape((1, 1, 3, 3))?;
|
||||
let w = w.reshape((1, 1, 3, 3))?;
|
||||
let res = t.conv2d(&w, 0, 1, 1, 1)?;
|
||||
let res = t.conv2d(&w, 0, 1)?;
|
||||
assert_eq!(res.dims(), [1, 1, 1, 1]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
@ -285,211 +171,8 @@ fn conv2d_smaller(dev: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/* This test is based on the following script.
|
||||
import torch
|
||||
torch.manual_seed(4242)
|
||||
|
||||
t = torch.randn((1, 2, 4, 2))
|
||||
w = torch.randn((1, 2, 1, 1))
|
||||
print(t.flatten())
|
||||
print(w.flatten())
|
||||
res = torch.nn.functional.conv2d(t, w)
|
||||
print(res.flatten())
|
||||
*/
|
||||
fn conv2d_non_square(dev: &Device) -> Result<()> {
|
||||
let t = Tensor::new(
|
||||
&[
|
||||
0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
|
||||
1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699,
|
||||
],
|
||||
dev,
|
||||
)?;
|
||||
let w = Tensor::new(&[-1.1351f32, 1.3841], dev)?;
|
||||
let t = t.reshape((1, 2, 4, 2))?;
|
||||
let w = w.reshape((1, 2, 1, 1))?;
|
||||
let res = t.conv2d(&w, 0, 1, 1, 1)?;
|
||||
assert_eq!(res.dims(), [1, 1, 4, 2]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
[0.2312, 5.2238, 2.3772, 1.9076, 2.0256, -0.5776, -1.6028, -1.467]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/*
|
||||
import torch
|
||||
torch.manual_seed(4242)
|
||||
|
||||
t = torch.randn((1, 4, 5, 5), requires_grad=True)
|
||||
w = torch.randn((2, 4, 3, 3), requires_grad=True)
|
||||
print(t.flatten())
|
||||
print(w.flatten())
|
||||
res = torch.nn.functional.conv2d(t, w)
|
||||
print(res.flatten())
|
||||
loss = (res ** 2).sum()
|
||||
print(loss)
|
||||
loss.backward()
|
||||
print(t.grad.shape)
|
||||
print(t.grad.flatten())
|
||||
print(w.grad.shape)
|
||||
print(w.grad.flatten())
|
||||
|
||||
t.grad.zero_()
|
||||
w.grad.zero_()
|
||||
res = torch.nn.functional.conv2d(t, w, stride=2)
|
||||
print(res.flatten())
|
||||
loss = (res ** 2).sum()
|
||||
print(loss)
|
||||
loss.backward()
|
||||
print(t.grad.shape)
|
||||
print(t.grad[0])
|
||||
print(w.grad.shape)
|
||||
print(w.grad[0])
|
||||
*/
|
||||
fn conv2d_grad(dev: &Device) -> Result<()> {
|
||||
use candle_core::Var;
|
||||
let t = Var::from_slice(
|
||||
&[
|
||||
0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
|
||||
1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843, 0.2395,
|
||||
1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013, -0.6836,
|
||||
0.2477, 1.3127, -0.2260, 0.2622, -1.2974, -0.8140, -0.8404, -0.3490, 0.0130, 1.3123,
|
||||
1.7569, -0.3956, -1.8255, 0.1727, -0.3538, 2.6941, 1.0529, 0.4219, -0.2071, 1.1586,
|
||||
0.4717, 0.3865, -0.5690, -0.5010, -0.1310, 0.7796, 0.6630, -0.2021, 2.6090, 0.2049,
|
||||
0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
|
||||
0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
|
||||
-0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
|
||||
-0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
|
||||
],
|
||||
(1, 4, 5, 5),
|
||||
dev,
|
||||
)?;
|
||||
let w = Var::from_slice(
|
||||
&[
|
||||
-0.9325f32, 0.6451, -0.8537, 0.2378, 0.8764, -0.1832, 0.2987, -0.6488, -0.2273,
|
||||
-2.4184, -0.1192, -0.4821, -0.5079, -0.5766, -2.4729, 1.6734, 0.4558, 0.2851, 1.1514,
|
||||
-0.9013, 1.0662, -0.1817, -0.0259, 0.1709, 0.5367, 0.7513, 0.8086, -2.2586, -0.5027,
|
||||
0.9141, -1.3086, -1.3343, -1.5669, -0.1657, 0.7958, 0.1432, 0.3896, -0.4501, 0.1667,
|
||||
0.0714, -0.0952, 1.2970, -0.1674, -0.3178, 1.0677, 0.3060, 0.7080, 0.1914, 1.1679,
|
||||
-0.3602, 1.9265, -1.8626, -0.5112, -0.0982, 0.2621, 0.6565, 0.5908, 1.0089, -0.1646,
|
||||
1.8032, -0.6286, 0.2016, -0.3370, 1.2555, 0.8009, -0.6488, -0.4652, -1.5685, 1.5860,
|
||||
0.5583, 0.4623, 0.6026,
|
||||
],
|
||||
(2, 4, 3, 3),
|
||||
dev,
|
||||
)?;
|
||||
let res = t.conv2d(&w, 0, 1, 1, 1)?;
|
||||
let loss = res.sqr()?.sum_all()?;
|
||||
assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 741.12f32);
|
||||
let grads = loss.backward()?;
|
||||
let grad_t = grads.get(&t).unwrap();
|
||||
let grad_w = grads.get(&w).unwrap();
|
||||
assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
|
||||
assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&grad_t.flatten_all()?, 2)?,
|
||||
[
|
||||
9.29, -2.84, -5.71, 3.38, -7.71, -19.15, 7.02, 29.1, 9.34, 34.73, -22.87, 24.35,
|
||||
-39.88, -14.01, 21.08, 9.94, 13.63, -34.68, 11.21, -6.26, 7.72, -6.32, -16.64, -1.08,
|
||||
-20.22, 21.73, -0.37, -4.06, 5.82, -3.65, -30.73, 14.55, 87.7, 31.6, 4.53, -89.78,
|
||||
-75.37, -57.43, -7.56, 92.96, 18.79, -4.63, -159.75, -42.47, -47.26, 52.88, 37.32,
|
||||
49.0, 12.82, 2.01, -8.98, 20.18, 16.62, 12.06, 15.38, 20.0, 2.57, -15.22, 72.62,
|
||||
-10.75, 2.25, -31.2, 3.75, -0.2, 9.76, -0.68, 5.21, -40.44, -22.59, -61.61, 17.28,
|
||||
20.41, 37.55, 5.23, 6.81, 23.54, 23.62, -9.99, -9.13, 4.87, -35.06, -26.1, 63.48,
|
||||
25.81, -39.21, -70.68, -46.96, 2.33, 41.81, 82.42, -28.63, -11.78, -35.33, -10.28,
|
||||
-28.57, -9.13, 7.21, -9.05, -9.62, -11.25
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&grad_w.flatten_all()?, 2)?,
|
||||
[
|
||||
-28.92, -22.88, -141.23, 73.35, 61.07, 47.81, -20.0, -73.71, -41.82, -13.59, 21.5,
|
||||
28.72, 28.57, -46.85, -90.19, 143.61, 16.68, 7.43, 18.88, -90.81, -20.29, 54.79, 82.63,
|
||||
22.94, 77.81, -16.39, -13.2, 9.34, -40.39, -26.62, 5.33, -60.91, 9.09, -59.37, 7.08,
|
||||
58.64, 5.55, 20.52, 2.5, -17.25, -6.8, 22.21, 30.15, -7.52, -37.46, 5.67, 22.58, 9.03,
|
||||
47.05, 17.61, 37.31, -98.13, -14.61, -4.8, -6.36, 44.69, 23.34, 8.37, -13.52, 80.05,
|
||||
-34.24, -16.36, -12.31, 1.92, -33.62, -14.1, -49.23, -7.39, 11.5, -9.98, 9.66, 29.6
|
||||
]
|
||||
);
|
||||
|
||||
// Same as before but with stride.
|
||||
let res = t.conv2d(&w, 0, 2, 1, 1)?;
|
||||
let loss = res.sqr()?.sum_all()?;
|
||||
assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 277.16f32);
|
||||
let grads = loss.backward()?;
|
||||
let grad_t = grads.get(&t).unwrap();
|
||||
let grad_w = grads.get(&w).unwrap();
|
||||
assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
|
||||
assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec3_round(&grad_t.i(0)?, 2)?,
|
||||
[
|
||||
[
|
||||
[9.29, -7.03, 0.94, 3.49, -7.71],
|
||||
[-1.8, -7.82, 8.9, 8.46, 7.43],
|
||||
[-25.84, 22.09, -19.27, -0.22, 1.69],
|
||||
[4.02, 18.53, -18.37, 2.3, -24.51],
|
||||
[7.72, -9.68, -12.34, 5.6, -20.22]
|
||||
],
|
||||
[
|
||||
[21.73, 3.39, -18.27, 3.86, -3.65],
|
||||
[8.25, 3.73, 30.73, -8.61, -11.93],
|
||||
[-72.15, -15.36, -17.53, -12.32, -1.61],
|
||||
[-22.32, -7.79, -91.82, 6.44, -37.69],
|
||||
[52.88, 14.44, 42.75, 9.88, 2.01]
|
||||
],
|
||||
[
|
||||
[-8.98, 9.91, 6.75, -4.68, 15.38],
|
||||
[4.93, -0.33, 9.94, -1.46, 14.78],
|
||||
[13.62, -30.63, 3.96, -3.58, -4.48],
|
||||
[-14.13, 1.19, -34.43, 3.08, -33.83],
|
||||
[17.28, 12.94, 31.83, -3.35, 6.81]
|
||||
],
|
||||
[
|
||||
[23.54, 6.98, -24.52, 0.52, 4.87],
|
||||
[9.65, 6.18, 1.71, -25.23, -4.93],
|
||||
[-54.99, -23.66, 3.19, -3.73, 18.58],
|
||||
[-21.35, -10.39, -39.88, 28.73, -30.76],
|
||||
[-9.13, 11.12, -14.0, -8.23, -11.25]
|
||||
]
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec3_round(&grad_w.i(0)?, 2)?,
|
||||
[
|
||||
[
|
||||
[28.34, -7.91, -45.75],
|
||||
[21.03, 3.86, 29.86],
|
||||
[0.72, -36.58, -35.28]
|
||||
],
|
||||
[
|
||||
[-16.04, 11.53, -16.38],
|
||||
[29.62, -16.32, -48.35],
|
||||
[57.5, 28.29, 25.81]
|
||||
],
|
||||
[
|
||||
[2.93, -19.6, 1.57],
|
||||
[27.15, 53.88, -24.64],
|
||||
[12.74, -22.6, -26.2]
|
||||
],
|
||||
[
|
||||
[-0.18, -14.86, -6.82],
|
||||
[-19.55, -2.72, 45.9],
|
||||
[-2.54, 36.97, 27.11]
|
||||
]
|
||||
]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
test_device!(conv1d, conv1d_cpu, conv1d_gpu);
|
||||
test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu);
|
||||
test_device!(conv2d, conv2d_cpu, conv2d_gpu);
|
||||
test_device!(
|
||||
conv2d_non_square,
|
||||
conv2d_non_square_cpu,
|
||||
conv2d_non_square_gpu
|
||||
);
|
||||
test_device!(conv2d_small, conv2d_small_cpu, conv2d_small_gpu);
|
||||
test_device!(conv2d_smaller, conv2d_smaller_cpu, conv2d_smaller_gpu);
|
||||
test_device!(conv2d_grad, conv2d_grad_cpu, conv2d_grad_gpu);
|
||||
|
@ -1,8 +1,10 @@
|
||||
use candle_core::backend::BackendStorage;
|
||||
use candle_core::cpu_backend;
|
||||
use candle_core::test_utils::to_vec1_round;
|
||||
use candle_core::{CpuStorage, CustomOp1, DType, Device, Error, Layout, Result, Shape, Tensor};
|
||||
|
||||
mod test_utils;
|
||||
use test_utils::to_vec1_round;
|
||||
|
||||
fn fwd<T: num_traits::Float>(v: T, alpha: f64) -> T {
|
||||
if v.is_sign_positive() {
|
||||
v
|
||||
@ -37,7 +39,7 @@ fn custom_op1_no_backward() -> Result<()> {
|
||||
let cpu = &Device::Cpu;
|
||||
let t = Tensor::arange(0u32, 12u32, cpu)?.to_dtype(DType::F32)?;
|
||||
let t = (t - 5.)?;
|
||||
let elu_t = t.apply_op1_no_bwd(&Elu { alpha: 1. })?;
|
||||
let elu_t = t.custom_op1(Elu { alpha: 1. })?;
|
||||
assert_eq!(
|
||||
to_vec1_round(&elu_t, 4)?,
|
||||
&[-0.9933, -0.9817, -0.9502, -0.8647, -0.6321, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
|
||||
@ -94,7 +96,7 @@ impl CustomOp1 for EluWithBackward {
|
||||
|
||||
fn bwd(&self, arg: &Tensor, _res: &Tensor, grad_res: &Tensor) -> Result<Option<Tensor>> {
|
||||
let alpha = self.0.alpha;
|
||||
let bwd = arg.apply_op1(EluBackward { alpha })?;
|
||||
let bwd = arg.custom_op1(EluBackward { alpha })?;
|
||||
Ok(Some(grad_res.mul(&bwd)?))
|
||||
}
|
||||
}
|
||||
@ -103,7 +105,7 @@ impl CustomOp1 for EluWithBackward {
|
||||
fn custom_op1_with_backward() -> Result<()> {
|
||||
let cpu = &Device::Cpu;
|
||||
let t = candle_core::Var::new(&[-2f32, 0f32, 2f32], cpu)?;
|
||||
let elu_t = t.apply_op1(EluWithBackward::new(2.))?;
|
||||
let elu_t = t.custom_op1(EluWithBackward::new(2.))?;
|
||||
assert_eq!(to_vec1_round(&elu_t, 4)?, &[-1.7293, 0.0, 2.0]);
|
||||
|
||||
let grads = elu_t.backward()?;
|
||||
|
@ -1,5 +1,6 @@
|
||||
use anyhow::{Context, Result};
|
||||
use candle_core::{test_device, test_utils, Device, Shape, Tensor, Var};
|
||||
use candle_core::{Device, Shape, Tensor, Var};
|
||||
mod test_utils;
|
||||
|
||||
fn simple_grad(device: &Device) -> Result<()> {
|
||||
let x = Var::new(&[3f32, 1., 4.], device)?;
|
||||
@ -173,51 +174,6 @@ fn unary_grad(device: &Device) -> Result<()> {
|
||||
let grad_x = grads.get(x).context("no grad for x")?;
|
||||
assert_eq!(y.to_vec1::<f32>()?, [6., 2., 8., 0.3]);
|
||||
assert_eq!(grad_x.to_vec1::<f32>()?, [2., 2., 2., 2.]);
|
||||
|
||||
let x = Var::new(&[3f32, 1., 4., 0.15], device)?;
|
||||
let y = x.powf(2.5)?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||
assert_eq!(test_utils::to_vec1_round(&y, 2)?, [15.59, 1.0, 32.0, 0.01]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(grad_x, 2)?,
|
||||
[12.99, 2.5, 20.0, 0.15]
|
||||
);
|
||||
|
||||
let y = x.tanh()?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||
assert_eq!(test_utils::to_vec1_round(&y, 2)?, [1.0, 0.76, 1.0, 0.15]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(grad_x, 2)?,
|
||||
[0.01, 0.42, 0.0, 0.98],
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn binary_grad(device: &Device) -> Result<()> {
|
||||
let x = Var::new(&[3f32, 1., -4., -1.], device)?;
|
||||
let x = x.as_tensor();
|
||||
// leaky relu
|
||||
let y = x.maximum(&(x * 0.1)?)?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(x).context("no grad for x")?;
|
||||
assert_eq!(x.to_vec1::<f32>()?, [3., 1., -4., -1.]);
|
||||
assert_eq!(y.to_vec1::<f32>()?, [3., 1., -0.4, -0.1]);
|
||||
assert_eq!(grad_x.to_vec1::<f32>()?, [1., 1., 0.1, 0.1]);
|
||||
|
||||
let y = x.minimum(&(x * 0.1)?)?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(x).context("no grad for x")?;
|
||||
assert_eq!(y.to_vec1::<f32>()?, [0.3, 0.1, -4., -1.]);
|
||||
assert_eq!(grad_x.to_vec1::<f32>()?, [0.1, 0.1, 1., 1.]);
|
||||
|
||||
// This one is easy to mess up, we want the gradient to be one as it is the identity function.
|
||||
let y = x.minimum(x)?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(x).context("no grad for x")?;
|
||||
assert_eq!(y.to_vec1::<f32>()?, [3., 1., -4., -1.]);
|
||||
assert_eq!(grad_x.to_vec1::<f32>()?, [1., 1., 1., 1.]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -226,4 +182,3 @@ test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu);
|
||||
test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu);
|
||||
test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu);
|
||||
test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu);
|
||||
test_device!(binary_grad, binary_grad_cpu, binary_grad_gpu);
|
||||
|
@ -1,6 +1,8 @@
|
||||
use anyhow::Result;
|
||||
use candle_core::{Device, IndexOp, Tensor};
|
||||
|
||||
mod test_utils;
|
||||
|
||||
#[test]
|
||||
fn integer_index() -> Result<()> {
|
||||
let dev = Device::Cpu;
|
||||
|
@ -1,4 +1,5 @@
|
||||
use candle::{test_device, Device, IndexOp, Result, Tensor};
|
||||
mod test_utils;
|
||||
use candle::{Device, IndexOp, Result, Tensor};
|
||||
use candle_core as candle;
|
||||
|
||||
fn contiguous(device: &Device) -> Result<()> {
|
||||
|
@ -1,4 +1,5 @@
|
||||
use candle_core::{test_device, test_utils, Device, IndexOp, Result, Tensor};
|
||||
mod test_utils;
|
||||
use candle_core::{Device, IndexOp, Result, Tensor};
|
||||
|
||||
// https://github.com/huggingface/candle/issues/364
|
||||
fn avg_pool2d(dev: &Device) -> Result<()> {
|
||||
@ -6,15 +7,8 @@ fn avg_pool2d(dev: &Device) -> Result<()> {
|
||||
1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
];
|
||||
let t = Tensor::from_vec(data, (1, 1, 4, 4), dev)?;
|
||||
let pool = t.avg_pool2d(2)?.squeeze(0)?.squeeze(0)?;
|
||||
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
|
||||
assert_eq!(pool.to_vec2::<f32>()?, [[0.5f32, 1.], [1., 1.]]);
|
||||
|
||||
let data: Vec<f32> = vec![
|
||||
1., 2., 1., 3., 0., 0., 1., 1., 1., 1., 1., 1., 5., 1., 1., 1.,
|
||||
];
|
||||
let t = Tensor::from_vec(data, (1, 1, 2, 8), dev)?;
|
||||
let pool = t.avg_pool2d(2)?.squeeze(0)?.squeeze(0)?;
|
||||
assert_eq!(pool.to_vec2::<f32>()?, [[5. / 4., 6. / 4., 6. / 4., 1.]]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -24,12 +18,8 @@ fn max_pool2d(dev: &Device) -> Result<()> {
|
||||
];
|
||||
let t = Tensor::from_vec(data, (1, 1, 4, 4), dev)?;
|
||||
|
||||
let pool = t.max_pool2d(2)?.squeeze(0)?.squeeze(0)?;
|
||||
let pool = t.max_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
|
||||
assert_eq!(pool.to_vec2::<f32>()?, [[2f32, 3.], [5., 1.]]);
|
||||
|
||||
let t = t.reshape((1, 1, 2, 8))?;
|
||||
let pool = t.max_pool2d(2)?.squeeze(0)?.squeeze(0)?;
|
||||
assert_eq!(pool.to_vec2::<f32>()?, [[2.0, 3.0, 5.0, 1.0]]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -53,29 +43,16 @@ fn avg_pool2d_pytorch(dev: &Device) -> Result<()> {
|
||||
dev,
|
||||
)?
|
||||
.reshape((1, 2, 4, 4))?;
|
||||
let pool = t.avg_pool2d(2)?.squeeze(0)?;
|
||||
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec3_round(&pool, 4)?,
|
||||
test_utils::to_vec3_round(pool, 4)?,
|
||||
[
|
||||
[[-1.1926, -0.0395], [0.2688, 0.1871]],
|
||||
[[0.1835, -0.1606], [0.6249, 0.3217]]
|
||||
]
|
||||
);
|
||||
let pool = t.avg_pool2d(3)?.squeeze(0)?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec3_round(&pool, 4)?,
|
||||
[[[0.085]], [[0.0078]]]
|
||||
);
|
||||
|
||||
let t = t.reshape((1, 1, 4, 8))?;
|
||||
let pool = t.avg_pool2d(2)?.squeeze(0)?.squeeze(0)?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&pool, 4)?,
|
||||
[
|
||||
[0.7745, 0.0276, -1.6983, 0.12],
|
||||
[0.3542, 0.1625, 0.4542, -0.0014]
|
||||
]
|
||||
);
|
||||
let pool = t.avg_pool2d((3, 3), (3, 3))?.squeeze(0)?;
|
||||
assert_eq!(test_utils::to_vec3_round(pool, 4)?, [[[0.085]], [[0.0078]]]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -1,17 +1,5 @@
|
||||
use candle_core::{
|
||||
quantized::{self, GgmlDType},
|
||||
test_utils::to_vec2_round,
|
||||
Device, Result, Tensor,
|
||||
};
|
||||
use candle_core::{quantized, Device, Result, Tensor};
|
||||
use quantized::{k_quants, GgmlType};
|
||||
use rand::prelude::*;
|
||||
|
||||
const GGML_TEST_SIZE: usize = 32 * 128;
|
||||
|
||||
const GGML_MAX_QUANTIZATION_TOTAL_ERROR: f32 = 0.002;
|
||||
const GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS: f32 = 0.0075;
|
||||
const GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS: f32 = 0.0040;
|
||||
const GGML_MAX_DOT_PRODUCT_ERROR: f32 = 0.02;
|
||||
|
||||
#[test]
|
||||
fn quantized_matmul() -> Result<()> {
|
||||
@ -26,10 +14,10 @@ fn quantized_matmul() -> Result<()> {
|
||||
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
|
||||
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
|
||||
assert_eq!(
|
||||
dst.iter().map(|x| x.round()).collect::<Vec<_>>(),
|
||||
dst,
|
||||
&[
|
||||
85120.0, 214562.0, 345455.0, 474748.0, 213475.0, 604465.0, 1000686.0, 1388317.0,
|
||||
341876.0, 994283.0, 1655709.0, 2301518.0
|
||||
85120.43, 214561.61, 345454.9, 474748.1, 213474.94, 604465.25, 1000686.4, 1388317.3,
|
||||
341875.88, 994283.0, 1655708.8, 2301518.3
|
||||
]
|
||||
);
|
||||
let mm = tensor_lhs.matmul(&tensor_rhs)?;
|
||||
@ -42,648 +30,17 @@ fn quantized_matmul() -> Result<()> {
|
||||
]
|
||||
);
|
||||
|
||||
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
|
||||
let matmul = quantized::QMatMul::from_qtensor(qtensor);
|
||||
let res = matmul.forward(&tensor_lhs)?;
|
||||
let qtensor = quantized::QTensor::new(rhs_t, (64, 4));
|
||||
let op = quantized::QMatMul::new(std::sync::Arc::new(qtensor));
|
||||
let res = tensor_lhs.custom_op1(op)?;
|
||||
assert_eq!(
|
||||
to_vec2_round(&res, 0)?,
|
||||
res.to_vec2::<f32>()?,
|
||||
&[
|
||||
[85120.0, 214562.0, 345455.0, 474748.0],
|
||||
[213475.0, 604465.0, 1000686.0, 1388317.0],
|
||||
[341876.0, 994283.0, 1655709.0, 2301518.0]
|
||||
[85120.43, 214561.61, 345454.9, 474748.1],
|
||||
[213474.94, 604465.25, 1000686.4, 1388317.3],
|
||||
[341875.88, 994283.0, 1655708.8, 2301518.3]
|
||||
]
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantized_matmul_neg() -> Result<()> {
|
||||
let cpu = &Device::Cpu;
|
||||
let (m, k, n) = (3, 64, 4);
|
||||
let lhs = (0..(m * k))
|
||||
.map(|v| v as f32 - (m * k) as f32 / 2.0)
|
||||
.collect::<Vec<_>>();
|
||||
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
|
||||
let mut dst = vec![42.; 3 * 4];
|
||||
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
|
||||
let rhs = (0..k * n)
|
||||
.map(|v| v as f32 - (k * n) as f32 / 3.0)
|
||||
.collect::<Vec<_>>();
|
||||
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
|
||||
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
|
||||
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
|
||||
assert_eq!(
|
||||
dst.iter().map(|x| x.round()).collect::<Vec<_>>(),
|
||||
&[
|
||||
243524.0, -19596.0, -285051.0, -549815.0, 23777.0, 21651.0, 19398.0, 18367.0,
|
||||
-196472.0, 63012.0, 324585.0, 587902.0
|
||||
]
|
||||
);
|
||||
let mm = tensor_lhs.matmul(&tensor_rhs)?;
|
||||
assert_eq!(
|
||||
to_vec2_round(&mm, 0)?,
|
||||
&[
|
||||
[244064.0, -20128.0, -284320.0, -548512.0],
|
||||
[23563.0, 21515.0, 19467.0, 17419.0],
|
||||
[-196939.0, 63157.0, 323253.0, 583349.0]
|
||||
]
|
||||
);
|
||||
|
||||
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
|
||||
let matmul = quantized::QMatMul::from_qtensor(qtensor);
|
||||
let res = matmul.forward(&tensor_lhs)?;
|
||||
assert_eq!(
|
||||
to_vec2_round(&res, 0)?,
|
||||
&[
|
||||
[243524.0, -19596.0, -285051.0, -549815.0],
|
||||
[23777.0, 21651.0, 19398.0, 18367.0],
|
||||
[-196472.0, 63012.0, 324585.0, 587902.0]
|
||||
]
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantize_q4_0() -> Result<()> {
|
||||
use k_quants::BlockQ4_0;
|
||||
|
||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||
let mut dst = vec![0f32; 32 * 4];
|
||||
let mut quant = vec![BlockQ4_0::zeros(); 4];
|
||||
BlockQ4_0::from_float(&src, &mut quant)?;
|
||||
BlockQ4_0::to_float(&quant, dst.as_mut_slice())?;
|
||||
assert_eq!(
|
||||
dst,
|
||||
&[
|
||||
-0.0, -0.0, 3.875, 3.875, 3.875, 3.875, 7.75, 7.75, 7.75, 7.75, 11.625, 11.625, 11.625,
|
||||
11.625, 15.5, 15.5, 15.5, 15.5, 19.375, 19.375, 19.375, 19.375, 23.25, 23.25, 23.25,
|
||||
23.25, 27.125, 27.125, 27.125, 27.125, 31.0, 31.0, 31.5, 31.5, 31.5, 31.5, 39.375,
|
||||
39.375, 39.375, 39.375, 39.375, 39.375, 39.375, 39.375, 47.25, 47.25, 47.25, 47.25,
|
||||
47.25, 47.25, 47.25, 47.25, 55.125, 55.125, 55.125, 55.125, 55.125, 55.125, 55.125,
|
||||
55.125, 63.0, 63.0, 63.0, 63.0, 59.375, 59.375, 71.25, 71.25, 71.25, 71.25, 71.25,
|
||||
71.25, 71.25, 71.25, 71.25, 71.25, 71.25, 71.25, 83.125, 83.125, 83.125, 83.125,
|
||||
83.125, 83.125, 83.125, 83.125, 83.125, 83.125, 83.125, 83.125, 95.0, 95.0, 95.0, 95.0,
|
||||
95.0, 95.0, 95.25, 95.25, 95.25, 95.25, 95.25, 95.25, 95.25, 95.25, 111.125, 111.125,
|
||||
111.125, 111.125, 111.125, 111.125, 111.125, 111.125, 111.125, 111.125, 111.125,
|
||||
111.125, 111.125, 111.125, 111.125, 111.125, 127.0, 127.0, 127.0, 127.0, 127.0, 127.0,
|
||||
127.0, 127.0
|
||||
]
|
||||
);
|
||||
ggml_quantization_error_test::<BlockQ4_0>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantize_q4_1() -> Result<()> {
|
||||
use k_quants::BlockQ4_1;
|
||||
|
||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||
let mut dst = vec![0f32; 32 * 4];
|
||||
let mut quant = vec![BlockQ4_1::zeros(); 4];
|
||||
BlockQ4_1::from_float(&src, &mut quant)?;
|
||||
BlockQ4_1::to_float(&quant, dst.as_mut_slice())?;
|
||||
assert_eq!(
|
||||
round_vector(&dst),
|
||||
&[
|
||||
0.0, 0.0, 2.066, 2.066, 4.133, 4.133, 6.199, 6.199, 8.266, 8.266, 10.332, 10.332,
|
||||
12.398, 12.398, 14.465, 14.465, 16.531, 16.531, 18.598, 18.598, 20.664, 20.664, 22.73,
|
||||
22.73, 24.797, 24.797, 26.863, 26.863, 28.93, 28.93, 30.996, 30.996, 32.0, 32.0,
|
||||
34.066, 34.066, 36.133, 36.133, 38.199, 38.199, 40.266, 40.266, 42.332, 42.332, 44.398,
|
||||
44.398, 46.465, 46.465, 48.531, 48.531, 50.598, 50.598, 52.664, 52.664, 54.73, 54.73,
|
||||
56.797, 56.797, 58.863, 58.863, 60.93, 60.93, 62.996, 62.996, 64.0, 64.0, 66.066,
|
||||
66.066, 68.133, 68.133, 70.199, 70.199, 72.266, 72.266, 74.332, 74.332, 76.398, 76.398,
|
||||
78.465, 78.465, 80.531, 80.531, 82.598, 82.598, 84.664, 84.664, 86.73, 86.73, 88.797,
|
||||
88.797, 90.863, 90.863, 92.93, 92.93, 94.996, 94.996, 96.0, 96.0, 98.066, 98.066,
|
||||
100.133, 100.133, 102.199, 102.199, 104.266, 104.266, 106.332, 106.332, 108.398,
|
||||
108.398, 110.465, 110.465, 112.531, 112.531, 114.598, 114.598, 116.664, 116.664,
|
||||
118.73, 118.73, 120.797, 120.797, 122.863, 122.863, 124.93, 124.93, 126.996, 126.996
|
||||
]
|
||||
);
|
||||
ggml_quantization_error_test::<BlockQ4_1>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantize_q5_0() -> Result<()> {
|
||||
use k_quants::BlockQ5_0;
|
||||
|
||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||
let mut dst = vec![0f32; 32 * 4];
|
||||
let mut quant = vec![BlockQ5_0::zeros(); 4];
|
||||
BlockQ5_0::from_float(&src, &mut quant)?;
|
||||
BlockQ5_0::to_float(&quant, dst.as_mut_slice())?;
|
||||
assert_eq!(
|
||||
round_vector(&dst),
|
||||
&[
|
||||
-0.0, 1.938, 1.938, 3.875, 3.875, 5.813, 5.813, 7.75, 7.75, 9.688, 9.688, 11.625,
|
||||
11.625, 13.563, 13.563, 15.5, 15.5, 17.438, 17.438, 19.375, 19.375, 21.313, 21.313,
|
||||
23.25, 23.25, 25.188, 25.188, 27.125, 27.125, 29.063, 29.063, 31.0, 31.5, 31.5, 35.438,
|
||||
35.438, 35.438, 35.438, 39.375, 39.375, 39.375, 39.375, 43.313, 43.313, 43.313, 43.313,
|
||||
47.25, 47.25, 47.25, 47.25, 51.188, 51.188, 51.188, 51.188, 55.125, 55.125, 55.125,
|
||||
55.125, 59.063, 59.063, 59.063, 59.063, 63.0, 63.0, 65.313, 65.313, 65.313, 65.313,
|
||||
65.313, 71.25, 71.25, 71.25, 71.25, 71.25, 71.25, 77.188, 77.188, 77.188, 77.188,
|
||||
77.188, 77.188, 83.125, 83.125, 83.125, 83.125, 83.125, 83.125, 89.063, 89.063, 89.063,
|
||||
89.063, 89.063, 89.063, 95.0, 95.0, 95.0, 95.25, 95.25, 95.25, 95.25, 103.188, 103.188,
|
||||
103.188, 103.188, 103.188, 103.188, 103.188, 103.188, 111.125, 111.125, 111.125,
|
||||
111.125, 111.125, 111.125, 111.125, 111.125, 119.063, 119.063, 119.063, 119.063,
|
||||
119.063, 119.063, 119.063, 119.063, 127.0, 127.0, 127.0, 127.0
|
||||
]
|
||||
);
|
||||
ggml_quantization_error_test::<BlockQ5_0>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantize_q5_1() -> Result<()> {
|
||||
use k_quants::BlockQ5_1;
|
||||
|
||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||
let mut dst = vec![0f32; 32 * 4];
|
||||
let mut quant = vec![BlockQ5_1::zeros(); 4];
|
||||
BlockQ5_1::from_float(&src, &mut quant)?;
|
||||
BlockQ5_1::to_float(&quant, dst.as_mut_slice())?;
|
||||
assert_eq!(
|
||||
dst,
|
||||
&[
|
||||
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
|
||||
16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0,
|
||||
30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0,
|
||||
44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0,
|
||||
58.0, 59.0, 60.0, 61.0, 62.0, 63.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0,
|
||||
72.0, 73.0, 74.0, 75.0, 76.0, 77.0, 78.0, 79.0, 80.0, 81.0, 82.0, 83.0, 84.0, 85.0,
|
||||
86.0, 87.0, 88.0, 89.0, 90.0, 91.0, 92.0, 93.0, 94.0, 95.0, 96.0, 97.0, 98.0, 99.0,
|
||||
100.0, 101.0, 102.0, 103.0, 104.0, 105.0, 106.0, 107.0, 108.0, 109.0, 110.0, 111.0,
|
||||
112.0, 113.0, 114.0, 115.0, 116.0, 117.0, 118.0, 119.0, 120.0, 121.0, 122.0, 123.0,
|
||||
124.0, 125.0, 126.0, 127.0
|
||||
]
|
||||
);
|
||||
|
||||
ggml_quantization_error_test::<BlockQ5_1>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Generates a small test vector ranging from -`bound` to `bound` with `size` steps
|
||||
fn get_test_vector(bound: f32, size: usize) -> (Vec<f32>, Vec<f32>) {
|
||||
assert!(
|
||||
size % crate::quantized::k_quants::QK_K == 0,
|
||||
"size must be a multiple of {}",
|
||||
crate::quantized::k_quants::QK_K
|
||||
);
|
||||
|
||||
let src = (0..size)
|
||||
.map(|v| (v as f32 - size as f32 / 2.) * bound / (size as f32 / 2.))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let dst = vec![0f32; size];
|
||||
assert_eq!([src[0], src[size / 2]], [-bound, 0.0]);
|
||||
(src, dst)
|
||||
}
|
||||
|
||||
/// Round a vector
|
||||
fn round_vector(values: &[f32]) -> Vec<f32> {
|
||||
values
|
||||
.iter()
|
||||
.map(|x| (1000. * x).round() / 1000.)
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
fn compare_with_error(values: &[f32], expected: &[f32], tolerance: f32) {
|
||||
for (i, (value, expected_value)) in values.iter().zip(expected.iter()).enumerate() {
|
||||
let difference = (value - expected_value).abs();
|
||||
|
||||
assert!(
|
||||
difference < tolerance,
|
||||
"Error at index {}: value = {}, expected = {}. Difference = {} exceeds tolerance = {}.",
|
||||
i,
|
||||
value,
|
||||
expected_value,
|
||||
difference,
|
||||
tolerance
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a vector simillarly to the one used in GGML unit tests: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
|
||||
fn create_ggml_like_vector(offset: f32) -> Vec<f32> {
|
||||
(0..GGML_TEST_SIZE)
|
||||
.map(|i| 0.1 + 2.0 * (i as f32 + offset).cos())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Calculates the root mean square error between two vectors
|
||||
fn calculate_rmse(a: &[f32], b: &[f32]) -> f32 {
|
||||
assert_eq!(a.len(), b.len());
|
||||
let sum = a
|
||||
.iter()
|
||||
.zip(b)
|
||||
.map(|(a, b)| (a - b).powi(2))
|
||||
.sum::<f32>()
|
||||
.sqrt();
|
||||
sum / a.len() as f32
|
||||
}
|
||||
|
||||
/// Mirrores the GGML quanitzation unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
|
||||
fn ggml_quantization_error_test<T: GgmlType>(max_error: f32) -> Result<()> {
|
||||
let src = create_ggml_like_vector(0.0);
|
||||
let mut dst = vec![0.0; GGML_TEST_SIZE];
|
||||
let _quant = quantize_roundtrip::<T>(src.as_slice(), dst.as_mut_slice())?;
|
||||
let error = calculate_rmse(src.as_slice(), dst.as_slice());
|
||||
if error > max_error {
|
||||
candle_core::bail!(
|
||||
"Quantization error {} exceeds max error {}",
|
||||
error,
|
||||
max_error
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn quantize_roundtrip<T: GgmlType>(src: &[f32], dst: &mut [f32]) -> Result<Vec<T>> {
|
||||
let mut quant = vec![T::zeros(); src.len() / T::BLCK_SIZE];
|
||||
T::from_float(src, &mut quant)?;
|
||||
T::to_float(&quant, dst)?;
|
||||
Ok(quant)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantize_q2k() -> Result<()> {
|
||||
use k_quants::BlockQ2K;
|
||||
|
||||
let (src, mut dst) = get_test_vector(0.5, 1024);
|
||||
let _quant = quantize_roundtrip::<BlockQ2K>(src.as_slice(), dst.as_mut_slice())?;
|
||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.1);
|
||||
|
||||
// Test some specific values
|
||||
assert_eq!(
|
||||
[src[0], src[128], src[256], src[512], src[800], src[1023]],
|
||||
[-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344]
|
||||
);
|
||||
let dst = round_vector(&dst);
|
||||
assert_eq!(
|
||||
[dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]],
|
||||
[-0.499, -0.366, -0.249, 0.0, 0.295, 0.492]
|
||||
);
|
||||
|
||||
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
|
||||
let _quant_big = quantize_roundtrip::<BlockQ2K>(src_big.as_slice(), dst_big.as_mut_slice())?;
|
||||
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 6.0);
|
||||
|
||||
ggml_quantization_error_test::<BlockQ2K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantize_q3k() -> Result<()> {
|
||||
use k_quants::BlockQ3K;
|
||||
|
||||
let (src, mut dst) = get_test_vector(0.5, 1024);
|
||||
let _quant = quantize_roundtrip::<BlockQ3K>(src.as_slice(), dst.as_mut_slice())?;
|
||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.03);
|
||||
|
||||
// Test some specific values
|
||||
assert_eq!(
|
||||
[src[0], src[128], src[256], src[512], src[800], src[1023]],
|
||||
[-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344]
|
||||
);
|
||||
let dst = round_vector(&dst);
|
||||
assert_eq!(
|
||||
[dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]],
|
||||
[-0.493, -0.37, -0.243, -0.0, 0.292, 0.492]
|
||||
);
|
||||
|
||||
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
|
||||
let _quant_big = quantize_roundtrip::<BlockQ3K>(src_big.as_slice(), dst_big.as_mut_slice())?;
|
||||
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 3.5);
|
||||
|
||||
ggml_quantization_error_test::<BlockQ3K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantize_q4k() -> Result<()> {
|
||||
use k_quants::BlockQ4K;
|
||||
|
||||
let (src, mut dst) = get_test_vector(0.5, 1024);
|
||||
let _quant = quantize_roundtrip::<BlockQ4K>(src.as_slice(), dst.as_mut_slice())?;
|
||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.017);
|
||||
|
||||
// Test some specific values
|
||||
assert_eq!(
|
||||
[src[0], src[128], src[256], src[512], src[800], src[1023]],
|
||||
[-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344]
|
||||
);
|
||||
let dst = round_vector(&dst);
|
||||
assert_eq!(
|
||||
[dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]],
|
||||
[-0.5, -0.373, -0.25, 0.0, 0.288, 0.498]
|
||||
);
|
||||
|
||||
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
|
||||
let _quant_big = quantize_roundtrip::<BlockQ4K>(src_big.as_slice(), dst_big.as_mut_slice())?;
|
||||
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 4.5);
|
||||
|
||||
ggml_quantization_error_test::<BlockQ4K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantize_q5k() -> Result<()> {
|
||||
use k_quants::BlockQ5K;
|
||||
|
||||
let (src, mut dst) = get_test_vector(0.5, 1024);
|
||||
let _quant = quantize_roundtrip::<BlockQ5K>(src.as_slice(), dst.as_mut_slice())?;
|
||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.008);
|
||||
|
||||
// Test some specific values
|
||||
assert_eq!(
|
||||
[src[0], src[128], src[256], src[512], src[800], src[1023]],
|
||||
[-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344]
|
||||
);
|
||||
let dst = round_vector(&dst);
|
||||
assert_eq!(
|
||||
[dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]],
|
||||
[-0.499, -0.372, -0.249, 0.001, 0.279, 0.499]
|
||||
);
|
||||
|
||||
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
|
||||
let _quant_big = quantize_roundtrip::<BlockQ5K>(src_big.as_slice(), dst_big.as_mut_slice())?;
|
||||
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.5);
|
||||
|
||||
ggml_quantization_error_test::<BlockQ5K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantize_q6k() -> Result<()> {
|
||||
use k_quants::BlockQ6K;
|
||||
|
||||
let (src, mut dst) = get_test_vector(0.5, 1024);
|
||||
let _quant = quantize_roundtrip::<BlockQ6K>(src.as_slice(), dst.as_mut_slice())?;
|
||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.008);
|
||||
|
||||
// Test some specific values
|
||||
assert_eq!(
|
||||
[src[0], src[128], src[256], src[512], src[800], src[1023]],
|
||||
[-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344]
|
||||
);
|
||||
let dst = round_vector(&dst);
|
||||
assert_eq!(
|
||||
[dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]],
|
||||
[-0.497, -0.372, -0.25, -0.0, 0.284, 0.5]
|
||||
);
|
||||
|
||||
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
|
||||
let _quant_big = quantize_roundtrip::<BlockQ6K>(src_big.as_slice(), dst_big.as_mut_slice())?;
|
||||
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.0);
|
||||
|
||||
ggml_quantization_error_test::<BlockQ6K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantize_q8k() -> Result<()> {
|
||||
use k_quants::BlockQ8K;
|
||||
|
||||
let (src, mut dst) = get_test_vector(0.5, 1024);
|
||||
let _quant = quantize_roundtrip::<BlockQ8K>(src.as_slice(), dst.as_mut_slice())?;
|
||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.003);
|
||||
|
||||
// Test some specific values
|
||||
assert_eq!(
|
||||
[src[0], src[128], src[256], src[512], src[800], src[1023]],
|
||||
[-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344]
|
||||
);
|
||||
let dst = round_vector(&dst);
|
||||
assert_eq!(
|
||||
[dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]],
|
||||
[-0.5, -0.375, -0.25, -0.0, 0.281, 0.499]
|
||||
);
|
||||
|
||||
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
|
||||
let _quant_big = quantize_roundtrip::<BlockQ8K>(src_big.as_slice(), dst_big.as_mut_slice())?;
|
||||
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 0.6);
|
||||
|
||||
ggml_quantization_error_test::<BlockQ8K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Very simple dot product implementation
|
||||
fn vec_dot_reference(a: &[f32], b: &[f32]) -> f32 {
|
||||
a.iter().zip(b).map(|(a, b)| a * b).sum()
|
||||
}
|
||||
|
||||
/// Returns the error achieved by the GGML matmul unit test.
|
||||
fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result<f32> {
|
||||
let err = match dtype {
|
||||
GgmlDType::F16 => 0.000010,
|
||||
GgmlDType::Q2K => 0.004086,
|
||||
GgmlDType::Q3K => 0.016148,
|
||||
GgmlDType::Q4K => 0.002425,
|
||||
GgmlDType::Q5K => 0.000740,
|
||||
GgmlDType::Q6K => 0.000952,
|
||||
GgmlDType::Q4_0 => 0.001143,
|
||||
GgmlDType::Q4_1 => 0.007784,
|
||||
GgmlDType::Q5_0 => 0.001353,
|
||||
GgmlDType::Q5_1 => 0.001363,
|
||||
GgmlDType::Q8_0 => 0.000092,
|
||||
_ => candle_core::bail!("No GGML results for quantization type {dtype:?}",),
|
||||
};
|
||||
Ok(err)
|
||||
}
|
||||
|
||||
/// Mirrores the GGML matmul unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
|
||||
fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
|
||||
let a = create_ggml_like_vector(0.0);
|
||||
let b = create_ggml_like_vector(1.0);
|
||||
let length = a.len();
|
||||
|
||||
let mut a_quant = vec![T::zeros(); length / T::BLCK_SIZE];
|
||||
let mut b_quant = vec![T::VecDotType::zeros(); length / T::VecDotType::BLCK_SIZE];
|
||||
T::from_float(&a, &mut a_quant)?;
|
||||
T::VecDotType::from_float(&b, &mut b_quant)?;
|
||||
|
||||
let result = T::vec_dot(length, &a_quant, &b_quant)?;
|
||||
let reference_result = vec_dot_reference(&a, &b);
|
||||
|
||||
let error = (result - reference_result).abs() / length as f32;
|
||||
|
||||
let ggml_error = ggml_reference_matmul_error(T::DTYPE)?;
|
||||
|
||||
if error > GGML_MAX_DOT_PRODUCT_ERROR {
|
||||
candle_core::bail!(
|
||||
"Dot product error {} exceeds max error {}",
|
||||
error,
|
||||
GGML_MAX_DOT_PRODUCT_ERROR
|
||||
);
|
||||
}
|
||||
|
||||
// We diverge slightly due to different rounding behavior / f16 to f32 conversions in GGML
|
||||
// => we use a slightly higher error threshold
|
||||
const ERROR_LENIENCY: f32 = 0.00001;
|
||||
if error - ERROR_LENIENCY > ggml_error {
|
||||
candle_core::bail!(
|
||||
"Dot product error {} exceeds ggml reference error {}",
|
||||
error,
|
||||
ggml_error
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// generates random tensors of size `m x k` and `n x k` and calculates their expected matrix multiplication result.
|
||||
fn get_random_tensors(
|
||||
m: usize,
|
||||
k: usize,
|
||||
n: usize,
|
||||
device: &Device,
|
||||
) -> Result<(Tensor, Tensor, Tensor)> {
|
||||
let mut rng = StdRng::seed_from_u64(314159265358979);
|
||||
|
||||
let lhs = (0..m * k)
|
||||
.map(|_| rng.gen::<f32>() - 0.5)
|
||||
.collect::<Vec<_>>();
|
||||
let rhs = (0..n * k)
|
||||
.map(|_| rng.gen::<f32>() - 0.5)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let lhs = Tensor::from_vec(lhs, (m, k), device)?;
|
||||
let rhs = Tensor::from_vec(rhs, (n, k), device)?;
|
||||
|
||||
let mm = lhs.matmul(&rhs.t()?)?;
|
||||
Ok((lhs, rhs, mm))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantized_matmul_q2k() -> Result<()> {
|
||||
use k_quants::BlockQ2K;
|
||||
|
||||
let cpu = &Device::Cpu;
|
||||
let (m, k, n) = (11, 512, 21);
|
||||
let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?;
|
||||
assert_eq!(mm.dims(), [m, n]);
|
||||
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
|
||||
let rhs = quantized::QTensor::quantize::<BlockQ2K>(&rhs)?;
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs);
|
||||
let mm = rhs.forward(&lhs)?;
|
||||
|
||||
assert_eq!(mm.dims(), [m, n]);
|
||||
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [0.916, 0.422, 0.215, 1.668]);
|
||||
|
||||
ggml_matmul_error_test::<BlockQ2K>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantized_matmul_q3k() -> Result<()> {
|
||||
use k_quants::BlockQ3K;
|
||||
|
||||
let cpu = &Device::Cpu;
|
||||
let (m, k, n) = (11, 512, 21);
|
||||
let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?;
|
||||
assert_eq!(mm.dims(), [m, n]);
|
||||
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
|
||||
let rhs = quantized::QTensor::quantize::<BlockQ3K>(&rhs)?;
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs);
|
||||
let mm = rhs.forward(&lhs)?;
|
||||
|
||||
assert_eq!(mm.dims(), [m, n]);
|
||||
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.029, 1.418, -0.314, 1.495]);
|
||||
|
||||
ggml_matmul_error_test::<BlockQ3K>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantized_matmul_q4k() -> Result<()> {
|
||||
use k_quants::BlockQ4K;
|
||||
|
||||
let cpu = &Device::Cpu;
|
||||
let (m, k, n) = (11, 512, 21);
|
||||
let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?;
|
||||
assert_eq!(mm.dims(), [m, n]);
|
||||
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
|
||||
let rhs = quantized::QTensor::quantize::<BlockQ4K>(&rhs)?;
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs);
|
||||
let mm = rhs.forward(&lhs)?;
|
||||
|
||||
assert_eq!(mm.dims(), [m, n]);
|
||||
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.125, 1.435, -0.201, 1.589]);
|
||||
|
||||
ggml_matmul_error_test::<BlockQ4K>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantized_matmul_q5k() -> Result<()> {
|
||||
use k_quants::BlockQ5K;
|
||||
|
||||
let cpu = &Device::Cpu;
|
||||
let (m, k, n) = (11, 512, 21);
|
||||
let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?;
|
||||
assert_eq!(mm.dims(), [m, n]);
|
||||
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
|
||||
let rhs = quantized::QTensor::quantize::<BlockQ5K>(&rhs)?;
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs);
|
||||
let mm = rhs.forward(&lhs)?;
|
||||
|
||||
assert_eq!(mm.dims(), [m, n]);
|
||||
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.192, 1.491, -0.18, 1.743]);
|
||||
|
||||
//Expected: 0.000740408897
|
||||
ggml_matmul_error_test::<BlockQ5K>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantized_matmul_q6k() -> Result<()> {
|
||||
use k_quants::BlockQ6K;
|
||||
|
||||
let cpu = &Device::Cpu;
|
||||
let (m, k, n) = (11, 512, 21);
|
||||
let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?;
|
||||
assert_eq!(mm.dims(), [m, n]);
|
||||
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
|
||||
let rhs = quantized::QTensor::quantize::<BlockQ6K>(&rhs)?;
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs);
|
||||
let mm = rhs.forward(&lhs)?;
|
||||
|
||||
assert_eq!(mm.dims(), [m, n]);
|
||||
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.324, 1.49, -0.164, 1.741]);
|
||||
|
||||
ggml_matmul_error_test::<BlockQ6K>()?;
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1,4 +1,5 @@
|
||||
use candle_core::{test_device, DType, Device, IndexOp, Result, Tensor};
|
||||
mod test_utils;
|
||||
use candle_core::{DType, Device, IndexOp, Result, Tensor};
|
||||
|
||||
fn zeros(device: &Device) -> Result<()> {
|
||||
let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
|
||||
@ -35,10 +36,10 @@ fn tensor_2d(device: &Device) -> Result<()> {
|
||||
|
||||
fn binary_op(device: &Device) -> Result<()> {
|
||||
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
|
||||
let tensor1 = Tensor::new(data, device)?;
|
||||
let tensor = Tensor::new(data, device)?;
|
||||
let data2 = &[[5f32, 5., 5., 5., 5.], [2., 1., 7., 8., 2.]];
|
||||
let tensor2 = Tensor::new(data2, device)?;
|
||||
let tensor = (&tensor1 + (&tensor1 * &tensor1)? / (&tensor1 + &tensor2))?;
|
||||
let tensor = (&tensor + (&tensor * &tensor)? / (&tensor + &tensor2))?;
|
||||
let dims = tensor.dims2()?;
|
||||
assert_eq!(dims, (2, 5));
|
||||
let content: Vec<Vec<f32>> = tensor.to_vec2()?;
|
||||
@ -48,17 +49,6 @@ fn binary_op(device: &Device) -> Result<()> {
|
||||
let tensor = (&tensor - &tensor)?;
|
||||
let content: Vec<Vec<f32>> = tensor.to_vec2()?;
|
||||
assert_eq!(content[0], [0., 0., 0., 0., 0.]);
|
||||
|
||||
let min = tensor1.minimum(&(&tensor2 * 0.5)?)?;
|
||||
let max = tensor1.maximum(&(&tensor2 * 0.5)?)?;
|
||||
assert_eq!(
|
||||
min.to_vec2::<f32>()?,
|
||||
[[2.5, 1.0, 2.5, 1.0, 2.5], [1.0, 0.5, 3.5, 4.0, 1.0]],
|
||||
);
|
||||
assert_eq!(
|
||||
max.to_vec2::<f32>()?,
|
||||
[[3.0, 2.5, 4.0, 2.5, 5.0], [2.0, 1.0, 7.0, 8.0, 2.0]]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -757,25 +747,6 @@ fn matmul(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn broadcast_matmul(device: &Device) -> Result<()> {
|
||||
let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?;
|
||||
let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?;
|
||||
let out = lhs.broadcast_matmul(&rhs)?;
|
||||
assert_eq!(out.dims(), &[3, 6, 4, 2]);
|
||||
for idx1 in 0..3 {
|
||||
for idx2 in 0..6 {
|
||||
let out = out.i((idx1, idx2))?;
|
||||
let lhs = lhs.i((idx1, 0))?;
|
||||
let rhs = rhs.i(idx2)?;
|
||||
let out2 = lhs.matmul(&rhs);
|
||||
let sum_diff2 = (out - out2)?.sqr()?.sum_all()?;
|
||||
// With cuda, we see errors of up to ~1e-12.
|
||||
assert!(sum_diff2.to_vec0::<f32>()? < 1e-6)
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn broadcasting(device: &Device) -> Result<()> {
|
||||
let t1 = Tensor::arange(0f32, 24f32, device)?.reshape((4, 2, 3))?;
|
||||
let t2 = Tensor::new(&[100f32, 200f32], device)?;
|
||||
@ -893,7 +864,6 @@ test_device!(binary_op, binary_op_cpu, binary_op_gpu);
|
||||
test_device!(embeddings, embeddings_cpu, embeddings_gpu);
|
||||
test_device!(cmp, cmp_cpu, cmp_gpu);
|
||||
test_device!(matmul, matmul_cpu, matmul_gpu);
|
||||
test_device!(broadcast_matmul, broadcast_matmul_cpu, broadcast_matmul_gpu);
|
||||
test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu);
|
||||
test_device!(index_select, index_select_cpu, index_select_gpu);
|
||||
test_device!(index_add, index_add_cpu, index_add_gpu);
|
||||
|
@ -1,4 +1,9 @@
|
||||
use crate::{Result, Tensor};
|
||||
#![allow(dead_code)]
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use candle_core::{Result, Tensor};
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! test_device {
|
||||
@ -18,12 +23,6 @@ macro_rules! test_device {
|
||||
};
|
||||
}
|
||||
|
||||
pub fn to_vec0_round(t: &Tensor, digits: i32) -> Result<f32> {
|
||||
let b = 10f32.powi(digits);
|
||||
let t = t.to_vec0::<f32>()?;
|
||||
Ok(f32::round(t * b) / b)
|
||||
}
|
||||
|
||||
pub fn to_vec1_round(t: &Tensor, digits: i32) -> Result<Vec<f32>> {
|
||||
let b = 10f32.powi(digits);
|
||||
let t = t.to_vec1::<f32>()?;
|
||||
@ -41,7 +40,7 @@ pub fn to_vec2_round(t: &Tensor, digits: i32) -> Result<Vec<Vec<f32>>> {
|
||||
Ok(t)
|
||||
}
|
||||
|
||||
pub fn to_vec3_round(t: &Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
|
||||
pub fn to_vec3_round(t: Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
|
||||
let b = 10f32.powi(digits);
|
||||
let t = t.to_vec3::<f32>()?;
|
||||
let t = t
|
@ -11,13 +11,10 @@ readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
byteorder = { workspace = true }
|
||||
candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.2.1" }
|
||||
candle = { path = "../candle-core", version = "0.1.1", package = "candle-core" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.1.1" }
|
||||
hf-hub = { workspace = true}
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
memmap2 = { workspace = true }
|
||||
tokenizers = { workspace = true, features = ["onig"] }
|
||||
rand = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
parquet = { workspace = true}
|
||||
image = { workspace = true }
|
||||
|
@ -1,73 +0,0 @@
|
||||
use hf_hub::{
|
||||
api::sync::{Api, ApiRepo},
|
||||
Repo, RepoType,
|
||||
};
|
||||
use parquet::file::reader::SerializedFileReader;
|
||||
use std::fs::File;
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum Error {
|
||||
#[error("ApiError : {0}")]
|
||||
ApiError(#[from] hf_hub::api::sync::ApiError),
|
||||
|
||||
#[error("IoError : {0}")]
|
||||
IoError(#[from] std::io::Error),
|
||||
|
||||
#[error("ParquetError : {0}")]
|
||||
ParquetError(#[from] parquet::errors::ParquetError),
|
||||
}
|
||||
|
||||
fn sibling_to_parquet(
|
||||
rfilename: &str,
|
||||
repo: &ApiRepo,
|
||||
) -> Result<SerializedFileReader<File>, Error> {
|
||||
let local = repo.get(rfilename)?;
|
||||
let file = File::open(local)?;
|
||||
let reader = SerializedFileReader::new(file)?;
|
||||
Ok(reader)
|
||||
}
|
||||
|
||||
pub fn from_hub(api: &Api, dataset_id: String) -> Result<Vec<SerializedFileReader<File>>, Error> {
|
||||
let repo = Repo::with_revision(
|
||||
dataset_id,
|
||||
RepoType::Dataset,
|
||||
"refs/convert/parquet".to_string(),
|
||||
);
|
||||
let repo = api.repo(repo);
|
||||
let info = repo.info()?;
|
||||
|
||||
let files: Result<Vec<_>, _> = info
|
||||
.siblings
|
||||
.into_iter()
|
||||
.filter_map(|s| -> Option<Result<_, _>> {
|
||||
let filename = s.rfilename;
|
||||
if filename.ends_with(".parquet") {
|
||||
let reader_result = sibling_to_parquet(&filename, &repo);
|
||||
Some(reader_result)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let files = files?;
|
||||
|
||||
Ok(files)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use parquet::file::reader::FileReader;
|
||||
|
||||
#[test]
|
||||
fn test_dataset() {
|
||||
let api = Api::new().unwrap();
|
||||
let files = from_hub(
|
||||
&api,
|
||||
"hf-internal-testing/dummy_image_text_data".to_string(),
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(files.len(), 1);
|
||||
assert_eq!(files[0].metadata().file_metadata().num_rows(), 20);
|
||||
}
|
||||
}
|
@ -1,6 +1,5 @@
|
||||
//! Datasets & Dataloaders for Candle
|
||||
pub mod batcher;
|
||||
pub mod hub;
|
||||
pub mod nlp;
|
||||
pub mod vision;
|
||||
|
||||
|
@ -2,9 +2,7 @@
|
||||
//!
|
||||
//! The files can be obtained from the following link:
|
||||
//! <http://yann.lecun.com/exdb/mnist/>
|
||||
use candle::{DType, Device, Error, Result, Tensor};
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use parquet::file::reader::{FileReader, SerializedFileReader};
|
||||
use candle::{DType, Device, Result, Tensor};
|
||||
use std::fs::File;
|
||||
use std::io::{self, BufReader, Read};
|
||||
|
||||
@ -65,58 +63,3 @@ pub fn load_dir<T: AsRef<std::path::Path>>(dir: T) -> Result<crate::vision::Data
|
||||
labels: 10,
|
||||
})
|
||||
}
|
||||
|
||||
fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor, Tensor)> {
|
||||
let samples = parquet.metadata().file_metadata().num_rows() as usize;
|
||||
let mut buffer_images: Vec<u8> = Vec::with_capacity(samples * 784);
|
||||
let mut buffer_labels: Vec<u8> = Vec::with_capacity(samples);
|
||||
for row in parquet.into_iter().flatten() {
|
||||
for (_name, field) in row.get_column_iter() {
|
||||
if let parquet::record::Field::Group(subrow) = field {
|
||||
for (_name, field) in subrow.get_column_iter() {
|
||||
if let parquet::record::Field::Bytes(value) = field {
|
||||
let image = image::load_from_memory(value.data()).unwrap();
|
||||
buffer_images.extend(image.to_luma8().as_raw());
|
||||
}
|
||||
}
|
||||
} else if let parquet::record::Field::Long(label) = field {
|
||||
buffer_labels.push(*label as u8);
|
||||
}
|
||||
}
|
||||
}
|
||||
let images = (Tensor::from_vec(buffer_images, (samples, 784), &Device::Cpu)?
|
||||
.to_dtype(DType::F32)?
|
||||
/ 255.)?;
|
||||
let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
|
||||
Ok((images, labels))
|
||||
}
|
||||
|
||||
pub fn load() -> Result<crate::vision::Dataset> {
|
||||
let api = Api::new().map_err(|e| Error::Msg(format!("Api error: {e}")))?;
|
||||
let dataset_id = "mnist".to_string();
|
||||
let repo = Repo::with_revision(
|
||||
dataset_id,
|
||||
RepoType::Dataset,
|
||||
"refs/convert/parquet".to_string(),
|
||||
);
|
||||
let repo = api.repo(repo);
|
||||
let test_parquet_filename = repo
|
||||
.get("mnist/test/0000.parquet")
|
||||
.map_err(|e| Error::Msg(format!("Api error: {e}")))?;
|
||||
let train_parquet_filename = repo
|
||||
.get("mnist/train/0000.parquet")
|
||||
.map_err(|e| Error::Msg(format!("Api error: {e}")))?;
|
||||
let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)
|
||||
.map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
|
||||
let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)
|
||||
.map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
|
||||
let (test_images, test_labels) = load_parquet(test_parquet)?;
|
||||
let (train_images, train_labels) = load_parquet(train_parquet)?;
|
||||
Ok(crate::vision::Dataset {
|
||||
train_images,
|
||||
train_labels,
|
||||
test_images,
|
||||
test_labels,
|
||||
labels: 10,
|
||||
})
|
||||
}
|
||||
|
@ -11,11 +11,11 @@ readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
|
||||
candle-datasets = { path = "../candle-datasets", version = "0.2.1" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.2.1" }
|
||||
candle-transformers = { path = "../candle-transformers", version = "0.2.1" }
|
||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.1", optional = true }
|
||||
candle = { path = "../candle-core", version = "0.1.1", package = "candle-core" }
|
||||
candle-datasets = { path = "../candle-datasets", version = "0.1.1" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.1.1" }
|
||||
candle-transformers = { path = "../candle-transformers", version = "0.1.1" }
|
||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.1.1", optional = true }
|
||||
safetensors = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
@ -23,17 +23,15 @@ num-traits = { workspace = true }
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
cudarc = { workspace = true, optional = true }
|
||||
half = { workspace = true, optional = true }
|
||||
image = { workspace = true }
|
||||
image = { workspace = true, optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
byteorder = { workspace = true }
|
||||
clap = { workspace = true }
|
||||
hf-hub = { workspace = true, features=["tokio"]}
|
||||
imageproc = { workspace = true }
|
||||
clap = { workspace = true }
|
||||
memmap2 = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
rusttype = { workspace = true }
|
||||
tokenizers = { workspace = true, features = ["onig"] }
|
||||
tracing = { workspace = true }
|
||||
tracing-chrome = { workspace = true }
|
||||
@ -57,3 +55,7 @@ nccl = ["cuda", "cudarc/nccl", "dep:half"]
|
||||
[[example]]
|
||||
name = "llama_multiprocess"
|
||||
required-features = ["cuda", "nccl", "flash-attn"]
|
||||
|
||||
[[example]]
|
||||
name = "stable-diffusion"
|
||||
required-features = ["image"]
|
||||
|
@ -1,8 +1,5 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
mod model;
|
||||
|
||||
use anyhow::{anyhow, Error as E, Result};
|
||||
@ -62,16 +59,16 @@ impl Args {
|
||||
|
||||
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
|
||||
let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
|
||||
let cache = Cache::default().repo(repo);
|
||||
let cache = Cache::default();
|
||||
(
|
||||
cache
|
||||
.get("config.json")
|
||||
.get(&repo, "config.json")
|
||||
.ok_or(anyhow!("Missing config file in cache"))?,
|
||||
cache
|
||||
.get("tokenizer.json")
|
||||
.get(&repo, "tokenizer.json")
|
||||
.ok_or(anyhow!("Missing tokenizer file in cache"))?,
|
||||
cache
|
||||
.get("model.safetensors")
|
||||
.get(&repo, "model.safetensors")
|
||||
.ok_or(anyhow!("Missing weights file in cache"))?,
|
||||
)
|
||||
} else {
|
||||
|
@ -1,5 +1,5 @@
|
||||
use candle::{DType, Device, Result, Tensor};
|
||||
use candle_nn::{Embedding, Module, VarBuilder};
|
||||
use candle_nn::{Embedding, VarBuilder};
|
||||
use serde::Deserialize;
|
||||
|
||||
pub const DTYPE: DType = DType::F32;
|
||||
|
@ -1,9 +1,6 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
|
||||
use candle_nn::{Embedding, LayerNorm, Linear, VarBuilder};
|
||||
|
||||
fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
|
||||
let weight = vb.get((size2, size1), "weight")?;
|
||||
|
@ -2,16 +2,19 @@
|
||||
// own forward pass (CPU and GPU versions) as well as their backward pass.
|
||||
//
|
||||
// In this example we add the RMS normalization operation and implement it for f32.
|
||||
#![allow(dead_code)]
|
||||
#![allow(unused)]
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[allow(unused)]
|
||||
mod cuda_kernels;
|
||||
|
||||
use clap::Parser;
|
||||
|
||||
use candle::{CpuStorage, CustomOp1, Layout, Result, Shape, Tensor};
|
||||
use candle::backend::BackendStorage;
|
||||
use candle::cpu_backend;
|
||||
use candle::{CpuStorage, CustomOp1, DType, Device, Layout, Result, Shape, Tensor};
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
@ -54,9 +57,8 @@ impl CustomOp1 for LayerNorm {
|
||||
storage: &candle::CudaStorage,
|
||||
layout: &Layout,
|
||||
) -> Result<(candle::CudaStorage, Shape)> {
|
||||
use candle::backend::BackendStorage;
|
||||
use candle::cuda_backend::cudarc::driver::{LaunchAsync, LaunchConfig};
|
||||
use candle::cuda_backend::WrapErr;
|
||||
use candle::cuda_backend::{cudarc, WrapErr};
|
||||
use cudarc::driver::{LaunchAsync, LaunchConfig};
|
||||
let (d1, d2) = layout.shape().dims2()?;
|
||||
let d1 = d1 as u32;
|
||||
let d2 = d2 as u32;
|
||||
@ -87,7 +89,7 @@ fn main() -> anyhow::Result<()> {
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let t = Tensor::arange(0f32, 14f32, &device)?.reshape((2, 7))?;
|
||||
println!("{t}");
|
||||
let t = t.apply_op1(LayerNorm { eps: 1e-5 })?;
|
||||
let t = t.custom_op1(LayerNorm { eps: 1e-5 })?;
|
||||
println!("{t}");
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1,339 +0,0 @@
|
||||
//! DINOv2: Learning Robust Visual Features without Supervision
|
||||
//! https://github.com/facebookresearch/dinov2
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use clap::Parser;
|
||||
|
||||
use candle::{DType, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
|
||||
|
||||
const IMG_SIZE: usize = 518;
|
||||
const PATCH_SIZE: usize = 14;
|
||||
const NUM_CLASSES: usize = 1000;
|
||||
|
||||
fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
|
||||
if bias {
|
||||
candle_nn::linear(in_dim, out_dim, vb)
|
||||
} else {
|
||||
candle_nn::linear_no_bias(in_dim, out_dim, vb)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Attention {
|
||||
qkv: Linear,
|
||||
proj: Linear,
|
||||
num_heads: usize,
|
||||
scale: f64,
|
||||
}
|
||||
|
||||
impl Attention {
|
||||
fn new(
|
||||
vb: VarBuilder,
|
||||
dim: usize,
|
||||
num_heads: usize,
|
||||
qkv_bias: bool,
|
||||
proj_bias: bool,
|
||||
) -> Result<Self> {
|
||||
let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
|
||||
let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?;
|
||||
let scale = 1. / ((dim / num_heads) as f64).sqrt();
|
||||
Ok(Self {
|
||||
qkv,
|
||||
proj,
|
||||
num_heads,
|
||||
scale,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Attention {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (b, n, c) = xs.dims3()?;
|
||||
let qkv = self
|
||||
.qkv
|
||||
.forward(xs)?
|
||||
.reshape((b, n, 3, self.num_heads, c / self.num_heads))?
|
||||
.transpose(1, 2)? // 02134
|
||||
.transpose(0, 1)? // 20134
|
||||
.transpose(2, 3)?; // 20314
|
||||
let q = (qkv.i(0)? * self.scale)?;
|
||||
let k = qkv.i(1)?;
|
||||
let v = qkv.i(2)?;
|
||||
let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?;
|
||||
let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?;
|
||||
self.proj.forward(&attn)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct LayerScale {
|
||||
gamma: Tensor,
|
||||
}
|
||||
|
||||
impl LayerScale {
|
||||
fn new(vb: VarBuilder, dim: usize) -> Result<Self> {
|
||||
let gamma = vb.get(dim, "gamma")?;
|
||||
Ok(Self { gamma })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for LayerScale {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
xs.broadcast_mul(&self.gamma)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Mlp {
|
||||
fc1: Linear,
|
||||
fc2: Linear,
|
||||
}
|
||||
|
||||
impl Mlp {
|
||||
fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result<Self> {
|
||||
let out_features = in_features;
|
||||
let fc1 = linear(vb.pp("fc1"), in_features, hidden_features, bias)?;
|
||||
let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?;
|
||||
Ok(Self { fc1, fc2 })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Mlp {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = self.fc1.forward(xs)?.gelu()?;
|
||||
self.fc2.forward(&xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Block {
|
||||
norm1: LayerNorm,
|
||||
attn: Attention,
|
||||
ls1: LayerScale,
|
||||
norm2: LayerNorm,
|
||||
mlp: Mlp,
|
||||
ls2: LayerScale,
|
||||
}
|
||||
|
||||
impl Block {
|
||||
fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> {
|
||||
let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
|
||||
let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?;
|
||||
let ls1 = LayerScale::new(vb.pp("ls1"), dim)?;
|
||||
let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
|
||||
let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?;
|
||||
let ls2 = LayerScale::new(vb.pp("ls2"), dim)?;
|
||||
Ok(Self {
|
||||
norm1,
|
||||
attn,
|
||||
ls1,
|
||||
norm2,
|
||||
mlp,
|
||||
ls2,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Block {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let residual = xs;
|
||||
let xs = self
|
||||
.ls1
|
||||
.forward(&self.attn.forward(&self.norm1.forward(xs)?)?)?;
|
||||
let xs = (xs + residual)?;
|
||||
let residual = &xs;
|
||||
let xs = self
|
||||
.ls2
|
||||
.forward(&self.mlp.forward(&self.norm2.forward(&xs)?)?)?;
|
||||
xs + residual
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct PatchEmbed {
|
||||
proj: candle_nn::Conv2d,
|
||||
patch_size: (usize, usize),
|
||||
num_patches: usize,
|
||||
}
|
||||
|
||||
impl PatchEmbed {
|
||||
fn new(
|
||||
vb: VarBuilder,
|
||||
img_size: usize,
|
||||
patch_size: usize,
|
||||
in_chans: usize,
|
||||
embed_dim: usize,
|
||||
) -> Result<Self> {
|
||||
let config = candle_nn::Conv2dConfig {
|
||||
stride: patch_size,
|
||||
..Default::default()
|
||||
};
|
||||
let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?;
|
||||
let num_patches = (img_size / patch_size) * (img_size / patch_size);
|
||||
Ok(Self {
|
||||
proj,
|
||||
patch_size: (patch_size, patch_size),
|
||||
num_patches,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for PatchEmbed {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (_b, _c, h, w) = xs.dims4()?;
|
||||
let (patch_h, patch_w) = self.patch_size;
|
||||
if (h % patch_h) != 0 {
|
||||
candle::bail!("image height {h} is not a multiple of patch height {patch_h}")
|
||||
}
|
||||
if (w % patch_w) != 0 {
|
||||
candle::bail!("image width {w} is not a multiple of patch width {patch_w}")
|
||||
}
|
||||
let xs = self.proj.forward(xs)?;
|
||||
let (b, c, h, w) = xs.dims4()?;
|
||||
// flatten embeddings.
|
||||
xs.reshape((b, c, h * w))?.transpose(1, 2)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct DinoVisionTransformer {
|
||||
patch_embed: PatchEmbed,
|
||||
cls_token: Tensor,
|
||||
pos_embed: Tensor,
|
||||
blocks: Vec<Block>,
|
||||
norm: LayerNorm,
|
||||
head: Linear,
|
||||
}
|
||||
|
||||
impl DinoVisionTransformer {
|
||||
pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result<Self> {
|
||||
let patch_embed =
|
||||
PatchEmbed::new(vb.pp("patch_embed"), IMG_SIZE, PATCH_SIZE, 3, embed_dim)?;
|
||||
let cls_token = vb.get((1, 1, embed_dim), "cls_token")?;
|
||||
let num_tokens = 1;
|
||||
let pos_embed = vb.get(
|
||||
(1, patch_embed.num_patches + num_tokens, embed_dim),
|
||||
"pos_embed",
|
||||
)?;
|
||||
let head = linear(vb.pp("head"), 2 * embed_dim, NUM_CLASSES, true)?;
|
||||
let norm = layer_norm(embed_dim, 1e-5, vb.pp("norm"))?;
|
||||
let vb_b = vb.pp("blocks");
|
||||
let blocks = (0..depth)
|
||||
.map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
Ok(Self {
|
||||
patch_embed,
|
||||
cls_token,
|
||||
pos_embed,
|
||||
blocks,
|
||||
norm,
|
||||
head,
|
||||
})
|
||||
}
|
||||
|
||||
fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> {
|
||||
let npatch = xs.dim(1)? - 1;
|
||||
let n = self.pos_embed.dim(1)? - 1;
|
||||
let sqrt_n = (n as f64).sqrt();
|
||||
if npatch == n && w == h {
|
||||
return Ok(xs.clone());
|
||||
}
|
||||
let class_pos_embed = self.pos_embed.i((.., ..1))?;
|
||||
let patch_pos_embed = self.pos_embed.i((.., 1..))?;
|
||||
let dim = xs.dim(D::Minus1)?;
|
||||
let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1);
|
||||
let patch_pos_embed = patch_pos_embed
|
||||
.reshape((1, sqrt_n as usize, sqrt_n as usize, dim))?
|
||||
.transpose(2, 3)?
|
||||
.transpose(1, 2)?;
|
||||
// This uses bicubic interpolation in the original implementation.
|
||||
let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?;
|
||||
let el_count = patch_pos_embed.shape().elem_count();
|
||||
let patch_pos_embed =
|
||||
patch_pos_embed
|
||||
.transpose(1, 2)?
|
||||
.transpose(2, 3)?
|
||||
.reshape((1, el_count / dim, dim))?;
|
||||
Tensor::cat(&[&class_pos_embed, &patch_pos_embed], 1)
|
||||
}
|
||||
|
||||
fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (_b, _nc, w, h) = xs.dims4()?;
|
||||
let xs = self.patch_embed.forward(xs)?;
|
||||
let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
|
||||
&xs + &self.interpolate_pos_encoding(&xs, w, h)?
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for DinoVisionTransformer {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = self.prepare_tokens_with_mask(xs)?;
|
||||
for blk in self.blocks.iter() {
|
||||
xs = blk.forward(&xs)?
|
||||
}
|
||||
let xs = self.norm.forward(&xs)?;
|
||||
let xs_norm_clstoken = xs.i((.., 0))?;
|
||||
let xs_norm_patchtokens = xs.i((.., 1..))?.mean(1)?;
|
||||
let xs = Tensor::cat(&[xs_norm_clstoken, xs_norm_patchtokens], D::Minus1)?;
|
||||
self.head.forward(&xs)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vit_small(vb: VarBuilder) -> Result<DinoVisionTransformer> {
|
||||
DinoVisionTransformer::new(vb, 12, 384, 6)
|
||||
}
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
image: String,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let model_file = match args.model {
|
||||
None => {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model("lmz/candle-dino-v2".into());
|
||||
api.get("dinov2_vits14.safetensors")?
|
||||
}
|
||||
Some(model) => model.into(),
|
||||
};
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
|
||||
let model = vit_small(vb)?;
|
||||
println!("model built");
|
||||
let logits = model.forward(&image.unsqueeze(0)?)?;
|
||||
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
|
||||
.i(0)?
|
||||
.to_vec1::<f32>()?;
|
||||
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
|
||||
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
||||
for &(category_idx, pr) in prs.iter().take(5) {
|
||||
println!(
|
||||
"{:24}: {:.2}%",
|
||||
candle_examples::imagenet::CLASSES[category_idx],
|
||||
100. * pr
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -1,429 +0,0 @@
|
||||
//! EfficientNet implementation.
|
||||
//!
|
||||
//! https://arxiv.org/abs/1905.11946
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle::{DType, IndexOp, Result, Tensor, D};
|
||||
use candle_nn as nn;
|
||||
use nn::{Module, VarBuilder};
|
||||
|
||||
// Based on the Python version from torchvision.
|
||||
// https://github.com/pytorch/vision/blob/0d75d9e5516f446c9c0ef93bd4ed9fea13992d06/torchvision/models/efficientnet.py#L47
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct MBConvConfig {
|
||||
expand_ratio: f64,
|
||||
kernel: usize,
|
||||
stride: usize,
|
||||
input_channels: usize,
|
||||
out_channels: usize,
|
||||
num_layers: usize,
|
||||
}
|
||||
|
||||
fn make_divisible(v: f64, divisor: usize) -> usize {
|
||||
let min_value = divisor;
|
||||
let new_v = usize::max(
|
||||
min_value,
|
||||
(v + divisor as f64 * 0.5) as usize / divisor * divisor,
|
||||
);
|
||||
if (new_v as f64) < 0.9 * v {
|
||||
new_v + divisor
|
||||
} else {
|
||||
new_v
|
||||
}
|
||||
}
|
||||
|
||||
fn bneck_confs(width_mult: f64, depth_mult: f64) -> Vec<MBConvConfig> {
|
||||
let bneck_conf = |e, k, s, i, o, n| {
|
||||
let input_channels = make_divisible(i as f64 * width_mult, 8);
|
||||
let out_channels = make_divisible(o as f64 * width_mult, 8);
|
||||
let num_layers = (n as f64 * depth_mult).ceil() as usize;
|
||||
MBConvConfig {
|
||||
expand_ratio: e,
|
||||
kernel: k,
|
||||
stride: s,
|
||||
input_channels,
|
||||
out_channels,
|
||||
num_layers,
|
||||
}
|
||||
};
|
||||
vec![
|
||||
bneck_conf(1., 3, 1, 32, 16, 1),
|
||||
bneck_conf(6., 3, 2, 16, 24, 2),
|
||||
bneck_conf(6., 5, 2, 24, 40, 2),
|
||||
bneck_conf(6., 3, 2, 40, 80, 3),
|
||||
bneck_conf(6., 5, 1, 80, 112, 3),
|
||||
bneck_conf(6., 5, 2, 112, 192, 4),
|
||||
bneck_conf(6., 3, 1, 192, 320, 1),
|
||||
]
|
||||
}
|
||||
|
||||
impl MBConvConfig {
|
||||
fn b0() -> Vec<Self> {
|
||||
bneck_confs(1.0, 1.0)
|
||||
}
|
||||
fn b1() -> Vec<Self> {
|
||||
bneck_confs(1.0, 1.1)
|
||||
}
|
||||
fn b2() -> Vec<Self> {
|
||||
bneck_confs(1.1, 1.2)
|
||||
}
|
||||
fn b3() -> Vec<Self> {
|
||||
bneck_confs(1.2, 1.4)
|
||||
}
|
||||
fn b4() -> Vec<Self> {
|
||||
bneck_confs(1.4, 1.8)
|
||||
}
|
||||
fn b5() -> Vec<Self> {
|
||||
bneck_confs(1.6, 2.2)
|
||||
}
|
||||
fn b6() -> Vec<Self> {
|
||||
bneck_confs(1.8, 2.6)
|
||||
}
|
||||
fn b7() -> Vec<Self> {
|
||||
bneck_confs(2.0, 3.1)
|
||||
}
|
||||
}
|
||||
|
||||
/// Conv2D with same padding.
|
||||
#[derive(Debug)]
|
||||
struct Conv2DSame {
|
||||
conv2d: nn::Conv2d,
|
||||
s: usize,
|
||||
k: usize,
|
||||
}
|
||||
|
||||
impl Conv2DSame {
|
||||
fn new(
|
||||
vb: VarBuilder,
|
||||
i: usize,
|
||||
o: usize,
|
||||
k: usize,
|
||||
stride: usize,
|
||||
groups: usize,
|
||||
bias: bool,
|
||||
) -> Result<Self> {
|
||||
let conv_config = nn::Conv2dConfig {
|
||||
stride,
|
||||
groups,
|
||||
..Default::default()
|
||||
};
|
||||
let conv2d = if bias {
|
||||
nn::conv2d(i, o, k, conv_config, vb)?
|
||||
} else {
|
||||
nn::conv2d_no_bias(i, o, k, conv_config, vb)?
|
||||
};
|
||||
Ok(Self {
|
||||
conv2d,
|
||||
s: stride,
|
||||
k,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Conv2DSame {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let s = self.s;
|
||||
let k = self.k;
|
||||
let (_, _, ih, iw) = xs.dims4()?;
|
||||
let oh = (ih + s - 1) / s;
|
||||
let ow = (iw + s - 1) / s;
|
||||
let pad_h = usize::max((oh - 1) * s + k - ih, 0);
|
||||
let pad_w = usize::max((ow - 1) * s + k - iw, 0);
|
||||
if pad_h > 0 || pad_w > 0 {
|
||||
let xs = xs.pad_with_zeros(2, pad_h / 2, pad_h - pad_h / 2)?;
|
||||
let xs = xs.pad_with_zeros(3, pad_w / 2, pad_w - pad_w / 2)?;
|
||||
self.conv2d.forward(&xs)
|
||||
} else {
|
||||
self.conv2d.forward(xs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ConvNormActivation {
|
||||
conv2d: Conv2DSame,
|
||||
bn2d: nn::BatchNorm,
|
||||
activation: bool,
|
||||
}
|
||||
|
||||
impl ConvNormActivation {
|
||||
fn new(
|
||||
vb: VarBuilder,
|
||||
i: usize,
|
||||
o: usize,
|
||||
k: usize,
|
||||
stride: usize,
|
||||
groups: usize,
|
||||
) -> Result<Self> {
|
||||
let conv2d = Conv2DSame::new(vb.pp("0"), i, o, k, stride, groups, false)?;
|
||||
let bn2d = nn::batch_norm(o, 1e-3, vb.pp("1"))?;
|
||||
Ok(Self {
|
||||
conv2d,
|
||||
bn2d,
|
||||
activation: true,
|
||||
})
|
||||
}
|
||||
|
||||
fn no_activation(self) -> Self {
|
||||
Self {
|
||||
activation: false,
|
||||
..self
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for ConvNormActivation {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = self.conv2d.forward(xs)?;
|
||||
let xs = self.bn2d.forward(&xs)?;
|
||||
if self.activation {
|
||||
swish(&xs)
|
||||
} else {
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct SqueezeExcitation {
|
||||
fc1: Conv2DSame,
|
||||
fc2: Conv2DSame,
|
||||
}
|
||||
|
||||
impl SqueezeExcitation {
|
||||
fn new(vb: VarBuilder, in_channels: usize, squeeze_channels: usize) -> Result<Self> {
|
||||
let fc1 = Conv2DSame::new(vb.pp("fc1"), in_channels, squeeze_channels, 1, 1, 1, true)?;
|
||||
let fc2 = Conv2DSame::new(vb.pp("fc2"), squeeze_channels, in_channels, 1, 1, 1, true)?;
|
||||
Ok(Self { fc1, fc2 })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for SqueezeExcitation {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let residual = xs;
|
||||
// equivalent to adaptive_avg_pool2d([1, 1])
|
||||
let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
|
||||
let xs = self.fc1.forward(&xs)?;
|
||||
let xs = swish(&xs)?;
|
||||
let xs = self.fc2.forward(&xs)?;
|
||||
let xs = nn::ops::sigmoid(&xs)?;
|
||||
residual.broadcast_mul(&xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct MBConv {
|
||||
expand_cna: Option<ConvNormActivation>,
|
||||
depthwise_cna: ConvNormActivation,
|
||||
squeeze_excitation: SqueezeExcitation,
|
||||
project_cna: ConvNormActivation,
|
||||
config: MBConvConfig,
|
||||
}
|
||||
|
||||
impl MBConv {
|
||||
fn new(vb: VarBuilder, c: MBConvConfig) -> Result<Self> {
|
||||
let vb = vb.pp("block");
|
||||
let exp = make_divisible(c.input_channels as f64 * c.expand_ratio, 8);
|
||||
let expand_cna = if exp != c.input_channels {
|
||||
Some(ConvNormActivation::new(
|
||||
vb.pp("0"),
|
||||
c.input_channels,
|
||||
exp,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let start_index = if expand_cna.is_some() { 1 } else { 0 };
|
||||
let depthwise_cna =
|
||||
ConvNormActivation::new(vb.pp(start_index), exp, exp, c.kernel, c.stride, exp)?;
|
||||
let squeeze_channels = usize::max(1, c.input_channels / 4);
|
||||
let squeeze_excitation =
|
||||
SqueezeExcitation::new(vb.pp(start_index + 1), exp, squeeze_channels)?;
|
||||
let project_cna =
|
||||
ConvNormActivation::new(vb.pp(start_index + 2), exp, c.out_channels, 1, 1, 1)?
|
||||
.no_activation();
|
||||
Ok(Self {
|
||||
expand_cna,
|
||||
depthwise_cna,
|
||||
squeeze_excitation,
|
||||
project_cna,
|
||||
config: c,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for MBConv {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let use_res_connect =
|
||||
self.config.stride == 1 && self.config.input_channels == self.config.out_channels;
|
||||
let ys = match &self.expand_cna {
|
||||
Some(expand_cna) => expand_cna.forward(xs)?,
|
||||
None => xs.clone(),
|
||||
};
|
||||
let ys = self.depthwise_cna.forward(&ys)?;
|
||||
let ys = self.squeeze_excitation.forward(&ys)?;
|
||||
let ys = self.project_cna.forward(&ys)?;
|
||||
if use_res_connect {
|
||||
ys + xs
|
||||
} else {
|
||||
Ok(ys)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn swish(s: &Tensor) -> Result<Tensor> {
|
||||
s * nn::ops::sigmoid(s)?
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct EfficientNet {
|
||||
init_cna: ConvNormActivation,
|
||||
blocks: Vec<MBConv>,
|
||||
final_cna: ConvNormActivation,
|
||||
classifier: nn::Linear,
|
||||
}
|
||||
|
||||
impl EfficientNet {
|
||||
fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> {
|
||||
let f_p = p.pp("features");
|
||||
let first_in_c = configs[0].input_channels;
|
||||
let last_out_c = configs.last().unwrap().out_channels;
|
||||
let final_out_c = 4 * last_out_c;
|
||||
let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?;
|
||||
let nconfigs = configs.len();
|
||||
let mut blocks = vec![];
|
||||
for (index, cnf) in configs.into_iter().enumerate() {
|
||||
let f_p = f_p.pp(index + 1);
|
||||
for r_index in 0..cnf.num_layers {
|
||||
let cnf = if r_index == 0 {
|
||||
cnf
|
||||
} else {
|
||||
MBConvConfig {
|
||||
input_channels: cnf.out_channels,
|
||||
stride: 1,
|
||||
..cnf
|
||||
}
|
||||
};
|
||||
blocks.push(MBConv::new(f_p.pp(r_index), cnf)?)
|
||||
}
|
||||
}
|
||||
let final_cna =
|
||||
ConvNormActivation::new(f_p.pp(nconfigs + 1), last_out_c, final_out_c, 1, 1, 1)?;
|
||||
let classifier = nn::linear(final_out_c, nclasses, p.pp("classifier.1"))?;
|
||||
Ok(Self {
|
||||
init_cna,
|
||||
blocks,
|
||||
final_cna,
|
||||
classifier,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for EfficientNet {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = self.init_cna.forward(xs)?;
|
||||
for block in self.blocks.iter() {
|
||||
xs = block.forward(&xs)?
|
||||
}
|
||||
let xs = self.final_cna.forward(&xs)?;
|
||||
// Equivalent to adaptive_avg_pool2d([1, 1]) -> squeeze(-1) -> squeeze(-1)
|
||||
let xs = xs.mean(D::Minus1)?.mean(D::Minus1)?;
|
||||
self.classifier.forward(&xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum Which {
|
||||
B0,
|
||||
B1,
|
||||
B2,
|
||||
B3,
|
||||
B4,
|
||||
B5,
|
||||
B6,
|
||||
B7,
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
image: String,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Variant of the model to use.
|
||||
#[arg(value_enum, long, default_value_t = Which::B2)]
|
||||
which: Which,
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let model_file = match args.model {
|
||||
None => {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model("lmz/candle-efficientnet".into());
|
||||
let filename = match args.which {
|
||||
Which::B0 => "efficientnet-b0.safetensors",
|
||||
Which::B1 => "efficientnet-b1.safetensors",
|
||||
Which::B2 => "efficientnet-b2.safetensors",
|
||||
Which::B3 => "efficientnet-b3.safetensors",
|
||||
Which::B4 => "efficientnet-b4.safetensors",
|
||||
Which::B5 => "efficientnet-b5.safetensors",
|
||||
Which::B6 => "efficientnet-b6.safetensors",
|
||||
Which::B7 => "efficientnet-b7.safetensors",
|
||||
};
|
||||
api.get(filename)?
|
||||
}
|
||||
Some(model) => model.into(),
|
||||
};
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
|
||||
let cfg = match args.which {
|
||||
Which::B0 => MBConvConfig::b0(),
|
||||
Which::B1 => MBConvConfig::b1(),
|
||||
Which::B2 => MBConvConfig::b2(),
|
||||
Which::B3 => MBConvConfig::b3(),
|
||||
Which::B4 => MBConvConfig::b4(),
|
||||
Which::B5 => MBConvConfig::b5(),
|
||||
Which::B6 => MBConvConfig::b6(),
|
||||
Which::B7 => MBConvConfig::b7(),
|
||||
};
|
||||
let model = EfficientNet::new(vb, cfg, candle_examples::imagenet::CLASS_COUNT as usize)?;
|
||||
println!("model built");
|
||||
let logits = model.forward(&image.unsqueeze(0)?)?;
|
||||
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
|
||||
.i(0)?
|
||||
.to_vec1::<f32>()?;
|
||||
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
|
||||
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
||||
for &(category_idx, pr) in prs.iter().take(5) {
|
||||
println!(
|
||||
"{:24}: {:.2}%",
|
||||
candle_examples::imagenet::CLASSES[category_idx],
|
||||
100. * pr
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -22,8 +22,6 @@ struct TextGeneration {
|
||||
device: Device,
|
||||
tokenizer: Tokenizer,
|
||||
logits_processor: LogitsProcessor,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
impl TextGeneration {
|
||||
@ -33,8 +31,6 @@ impl TextGeneration {
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
device: &Device,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp);
|
||||
Self {
|
||||
@ -42,8 +38,6 @@ impl TextGeneration {
|
||||
tokenizer,
|
||||
logits_processor,
|
||||
device: device.clone(),
|
||||
repeat_penalty,
|
||||
repeat_last_n,
|
||||
}
|
||||
}
|
||||
|
||||
@ -69,16 +63,6 @@ impl TextGeneration {
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = self.model.forward(&input)?;
|
||||
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
self.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
@ -132,14 +116,6 @@ struct Args {
|
||||
|
||||
#[arg(long, default_value = "refs/pr/43")]
|
||||
revision: String,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.0)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
@ -186,15 +162,7 @@ fn main() -> Result<()> {
|
||||
let model = Falcon::load(vb, config)?;
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let mut pipeline = TextGeneration::new(
|
||||
model,
|
||||
tokenizer,
|
||||
args.seed,
|
||||
args.temperature,
|
||||
&device,
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n,
|
||||
);
|
||||
let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, args.temperature, &device);
|
||||
pipeline.run(&args.prompt, args.sample_len)?;
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
use anyhow::Result;
|
||||
use candle::{DType, Device, Tensor, D};
|
||||
use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
|
||||
use candle_nn::{Embedding, LayerNorm, Linear, VarBuilder};
|
||||
|
||||
const MAX_SEQ_LEN: usize = 5000;
|
||||
|
||||
|
28
candle-examples/examples/ggml/main.rs
Normal file
28
candle-examples/examples/ggml/main.rs
Normal file
@ -0,0 +1,28 @@
|
||||
use anyhow::Result;
|
||||
use clap::Parser;
|
||||
use std::fs::File;
|
||||
|
||||
use candle::quantized::ggml_file::Content;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// GGML file to load, typically a .bin file generated by the quantize command from llama.cpp
|
||||
#[arg(long)]
|
||||
model: String,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
let mut file = File::open(args.model)?;
|
||||
let start = std::time::Instant::now();
|
||||
let model = Content::read(&mut file)?;
|
||||
|
||||
println!(
|
||||
"Loaded {:?} tensors in {:?}",
|
||||
model.tensors.len(),
|
||||
start.elapsed()
|
||||
);
|
||||
Ok(())
|
||||
}
|
@ -12,17 +12,17 @@ extern crate accelerate_src;
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
use anyhow::{bail, Error as E, Result};
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
|
||||
use candle::{DType, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use hf_hub::api::sync::Api;
|
||||
use std::io::Write;
|
||||
|
||||
mod model;
|
||||
use model::{Config, Llama, LlamaConfig};
|
||||
use model::{Config, Llama};
|
||||
|
||||
const EOS_TOKEN: &str = "</s>";
|
||||
const MAX_SEQ_LEN: usize = 4096;
|
||||
@ -59,9 +59,9 @@ struct Args {
|
||||
#[arg(long)]
|
||||
prompt: Option<String>,
|
||||
|
||||
/// Use different dtype than f16
|
||||
/// Use f32 computations rather than f16.
|
||||
#[arg(long)]
|
||||
dtype: Option<String>,
|
||||
use_f32: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
@ -70,9 +70,6 @@ struct Args {
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
v1: bool,
|
||||
|
||||
@ -83,14 +80,6 @@ struct Args {
|
||||
/// (same structure as huggingface online)
|
||||
#[arg(long)]
|
||||
local_weights: Option<String>,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.0)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
@ -100,6 +89,7 @@ fn main() -> Result<()> {
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
println!("tracing...");
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
@ -108,24 +98,18 @@ fn main() -> Result<()> {
|
||||
};
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let dtype = match args.dtype.as_deref() {
|
||||
Some("f16") => DType::F16,
|
||||
Some("bf16") => DType::BF16,
|
||||
Some("f32") => DType::F32,
|
||||
Some(dtype) => bail!("Unsupported dtype {dtype}"),
|
||||
None => DType::F16,
|
||||
let config = if args.v1 {
|
||||
Config::config_7b_v1(args.use_flash_attn)
|
||||
} else {
|
||||
Config::config_7b_v2(args.use_flash_attn)
|
||||
};
|
||||
let (llama, tokenizer_filename, cache) = match args.npy {
|
||||
let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
|
||||
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
|
||||
let (llama, tokenizer_filename) = match args.npy {
|
||||
Some(filename) => {
|
||||
let config = if args.v1 {
|
||||
Config::config_7b_v1(args.use_flash_attn)
|
||||
} else {
|
||||
Config::config_7b_v2(args.use_flash_attn)
|
||||
};
|
||||
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
|
||||
let vb = VarBuilder::from_npz(filename, dtype, &device)?;
|
||||
let tokenizer = std::path::PathBuf::from("llama-tokenizer.json");
|
||||
(Llama::load(vb, &cache, &config)?, tokenizer, cache)
|
||||
(Llama::load(vb, &cache, &config)?, tokenizer)
|
||||
}
|
||||
None => {
|
||||
let api = Api::new()?;
|
||||
@ -137,21 +121,13 @@ fn main() -> Result<()> {
|
||||
}
|
||||
});
|
||||
println!("loading the model weights from {model_id}");
|
||||
let revision = args.revision.unwrap_or("main".to_string());
|
||||
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
||||
let api = api.model(model_id);
|
||||
|
||||
let tokenizer_filename = match &args.local_weights {
|
||||
Some(path) => (path.to_owned() + "tokenizer.json").into(),
|
||||
_ => api.get("tokenizer.json")?,
|
||||
};
|
||||
|
||||
let config_filename = match &args.local_weights {
|
||||
Some(path) => (path.to_owned() + "config.json").into(),
|
||||
_ => api.get("config.json")?,
|
||||
};
|
||||
let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||
let config = config.into_config(args.use_flash_attn);
|
||||
|
||||
let mut filenames = vec![];
|
||||
for rfilename in [
|
||||
"model-00001-of-00002.safetensors",
|
||||
@ -177,10 +153,9 @@ fn main() -> Result<()> {
|
||||
.iter()
|
||||
.map(|h| Ok(h.deserialize()?))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
|
||||
|
||||
let vb = VarBuilder::from_safetensors(tensors, dtype, &device);
|
||||
(Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
|
||||
(Llama::load(vb, &cache, &config)?, tokenizer_filename)
|
||||
}
|
||||
};
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
@ -208,16 +183,6 @@ fn main() -> Result<()> {
|
||||
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
|
||||
let logits = llama.forward(&input, index_pos)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
let logits = if args.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(args.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
args.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
index_pos += ctxt.len();
|
||||
|
||||
let next_token = logits_processor.sample(&logits)?;
|
||||
|
@ -1,54 +1,20 @@
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{Embedding, Module, VarBuilder};
|
||||
use serde::Deserialize;
|
||||
use candle_nn::{Embedding, VarBuilder};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use super::MAX_SEQ_LEN;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct LlamaConfig {
|
||||
pub hidden_size: usize,
|
||||
pub intermediate_size: usize,
|
||||
pub vocab_size: usize,
|
||||
pub num_hidden_layers: usize,
|
||||
pub num_attention_heads: usize,
|
||||
pub num_key_value_heads: Option<usize>,
|
||||
pub rms_norm_eps: f64,
|
||||
#[serde(default = "default_rope")]
|
||||
pub rope_theta: f32,
|
||||
}
|
||||
|
||||
fn default_rope() -> f32 {
|
||||
10_000.0
|
||||
}
|
||||
|
||||
impl LlamaConfig {
|
||||
pub fn into_config(self, use_flash_attn: bool) -> Config {
|
||||
Config {
|
||||
hidden_size: self.hidden_size,
|
||||
intermediate_size: self.intermediate_size,
|
||||
vocab_size: self.vocab_size,
|
||||
num_hidden_layers: self.num_hidden_layers,
|
||||
num_attention_heads: self.num_attention_heads,
|
||||
num_key_value_heads: self.num_key_value_heads.unwrap_or(self.num_attention_heads),
|
||||
rms_norm_eps: self.rms_norm_eps,
|
||||
rope_theta: self.rope_theta,
|
||||
use_flash_attn,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Config {
|
||||
pub hidden_size: usize,
|
||||
pub intermediate_size: usize,
|
||||
pub vocab_size: usize,
|
||||
pub num_hidden_layers: usize,
|
||||
pub num_attention_heads: usize,
|
||||
pub num_key_value_heads: usize,
|
||||
pub n_layer: usize,
|
||||
pub n_head: usize,
|
||||
pub n_embd: usize,
|
||||
pub n_key_value_head: usize,
|
||||
pub use_flash_attn: bool,
|
||||
pub rms_norm_eps: f64,
|
||||
pub rope_theta: f32,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
@ -57,12 +23,12 @@ impl Config {
|
||||
hidden_size: 4096,
|
||||
intermediate_size: 11008,
|
||||
vocab_size: 32000,
|
||||
num_hidden_layers: 32,
|
||||
num_attention_heads: 32,
|
||||
num_key_value_heads: 32,
|
||||
n_layer: 32,
|
||||
n_head: 32,
|
||||
n_embd: 4096,
|
||||
n_key_value_head: 32,
|
||||
use_flash_attn,
|
||||
rms_norm_eps: 1e-6,
|
||||
rope_theta: 10_000.0,
|
||||
}
|
||||
}
|
||||
|
||||
@ -71,12 +37,12 @@ impl Config {
|
||||
hidden_size: 4096,
|
||||
intermediate_size: 11008,
|
||||
vocab_size: 32000,
|
||||
num_hidden_layers: 32,
|
||||
num_attention_heads: 32,
|
||||
num_key_value_heads: 32,
|
||||
n_layer: 32,
|
||||
n_head: 32,
|
||||
n_embd: 4096,
|
||||
n_key_value_head: 32,
|
||||
use_flash_attn,
|
||||
rms_norm_eps: 1e-5,
|
||||
rope_theta: 10_000.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -110,10 +76,10 @@ pub struct Cache {
|
||||
impl Cache {
|
||||
pub fn new(use_kv_cache: bool, dtype: DType, config: &Config, device: &Device) -> Result<Self> {
|
||||
// precompute freqs_cis
|
||||
let n_elem = config.hidden_size / config.num_attention_heads;
|
||||
let n_elem = config.n_embd / config.n_head;
|
||||
let theta: Vec<_> = (0..n_elem)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / config.rope_theta.powf(i as f32 / n_elem as f32))
|
||||
.map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
|
||||
.collect();
|
||||
let theta = Tensor::new(theta.as_slice(), device)?;
|
||||
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
|
||||
@ -128,7 +94,7 @@ impl Cache {
|
||||
Ok(Self {
|
||||
masks: Arc::new(Mutex::new(HashMap::new())),
|
||||
use_kv_cache,
|
||||
kvs: Arc::new(Mutex::new(vec![None; config.num_hidden_layers])),
|
||||
kvs: Arc::new(Mutex::new(vec![None; config.n_layer])),
|
||||
device: device.clone(),
|
||||
cos,
|
||||
sin,
|
||||
@ -150,6 +116,10 @@ impl Cache {
|
||||
}
|
||||
}
|
||||
|
||||
fn silu(xs: &Tensor) -> Result<Tensor> {
|
||||
xs / (xs.neg()?.exp()? + 1.0)?
|
||||
}
|
||||
|
||||
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "linear");
|
||||
let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
|
||||
@ -162,20 +132,35 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
|
||||
}
|
||||
|
||||
struct RmsNorm {
|
||||
inner: candle_nn::RmsNorm,
|
||||
scale: Tensor,
|
||||
eps: f64,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl RmsNorm {
|
||||
fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
|
||||
let inner = candle_nn::rms_norm(size, eps, vb)?;
|
||||
Ok(Self { inner, span })
|
||||
let scale = vb.get(size, "weight")?;
|
||||
Ok(Self { scale, eps, span })
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
self.inner.forward(x)
|
||||
let in_dtype = x.dtype();
|
||||
// This is a no-op if x's dtype is already f32.
|
||||
let x = x.to_dtype(DType::F32)?;
|
||||
let (b_sz, seq_len, hidden_size) = x.dims3()?;
|
||||
let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
|
||||
let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
|
||||
let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
|
||||
let size = self.scale.dims1()?;
|
||||
let scale = self
|
||||
.scale
|
||||
.to_dtype(DType::F32)?
|
||||
.broadcast_as((b_sz, seq_len, size))?;
|
||||
let x = (scale * x_normed)?;
|
||||
let x = x.to_dtype(in_dtype)?;
|
||||
Ok(x)
|
||||
}
|
||||
}
|
||||
|
||||
@ -184,8 +169,8 @@ struct CausalSelfAttention {
|
||||
k_proj: Linear,
|
||||
v_proj: Linear,
|
||||
o_proj: Linear,
|
||||
num_attention_heads: usize,
|
||||
num_key_value_heads: usize,
|
||||
n_head: usize,
|
||||
n_key_value_head: usize,
|
||||
head_dim: usize,
|
||||
cache: Cache,
|
||||
use_flash_attn: bool,
|
||||
@ -212,13 +197,13 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Ten
|
||||
impl CausalSelfAttention {
|
||||
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||
let _enter = self.span_rot.enter();
|
||||
let (b_sz, _, seq_len, hidden_size) = x.dims4()?;
|
||||
let (b_sz, _, seq_len, n_embd) = x.dims4()?;
|
||||
let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
|
||||
let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
|
||||
let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
|
||||
let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
|
||||
let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?;
|
||||
let x2 = x.narrow(D::Minus1, hidden_size / 2, hidden_size / 2)?;
|
||||
let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd))?;
|
||||
let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd))?;
|
||||
let x1 = x.narrow(D::Minus1, 0, n_embd / 2)?;
|
||||
let x2 = x.narrow(D::Minus1, n_embd / 2, n_embd / 2)?;
|
||||
let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
|
||||
let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
|
||||
Ok(rope)
|
||||
@ -226,19 +211,19 @@ impl CausalSelfAttention {
|
||||
|
||||
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let (b_sz, seq_len, hidden_size) = x.dims3()?;
|
||||
let (b_sz, seq_len, n_embd) = x.dims3()?;
|
||||
let q = self.q_proj.forward(x)?;
|
||||
let k = self.k_proj.forward(x)?;
|
||||
let v = self.v_proj.forward(x)?;
|
||||
|
||||
let q = q
|
||||
.reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
|
||||
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let k = k
|
||||
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
|
||||
.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let mut v = v
|
||||
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
|
||||
.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
|
||||
let q = self.apply_rotary_emb(&q, index_pos)?;
|
||||
@ -287,13 +272,13 @@ impl CausalSelfAttention {
|
||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||
att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?
|
||||
};
|
||||
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?;
|
||||
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
|
||||
let y = self.o_proj.forward(&y)?;
|
||||
Ok(y)
|
||||
}
|
||||
|
||||
fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
|
||||
let n_rep = self.num_attention_heads / self.num_key_value_heads;
|
||||
let n_rep = self.n_head / self.n_key_value_head;
|
||||
if n_rep == 1 {
|
||||
Ok(x)
|
||||
} else {
|
||||
@ -301,7 +286,7 @@ impl CausalSelfAttention {
|
||||
let x = x
|
||||
.unsqueeze(2)?
|
||||
.expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
|
||||
.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
|
||||
.reshape((b_sz, n_kv_head, n_rep, seq_len, head_dim))?;
|
||||
Ok(x)
|
||||
}
|
||||
}
|
||||
@ -310,8 +295,8 @@ impl CausalSelfAttention {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "attn");
|
||||
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
|
||||
let size_in = cfg.hidden_size;
|
||||
let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
|
||||
let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
|
||||
let size_q = (cfg.hidden_size / cfg.n_head) * cfg.n_head;
|
||||
let size_kv = (cfg.hidden_size / cfg.n_head) * cfg.n_key_value_head;
|
||||
let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?;
|
||||
let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?;
|
||||
let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?;
|
||||
@ -321,9 +306,9 @@ impl CausalSelfAttention {
|
||||
k_proj,
|
||||
v_proj,
|
||||
o_proj,
|
||||
num_attention_heads: cfg.num_attention_heads,
|
||||
num_key_value_heads: cfg.num_key_value_heads,
|
||||
head_dim: cfg.hidden_size / cfg.num_attention_heads,
|
||||
n_head: cfg.n_head,
|
||||
n_key_value_head: cfg.n_key_value_head,
|
||||
head_dim: cfg.hidden_size / cfg.n_head,
|
||||
cache: cache.clone(),
|
||||
use_flash_attn: cfg.use_flash_attn,
|
||||
span,
|
||||
@ -349,7 +334,7 @@ struct Mlp {
|
||||
impl Mlp {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let x = (candle_nn::ops::silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
|
||||
let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
|
||||
self.c_proj.forward(&x)
|
||||
}
|
||||
|
||||
@ -432,7 +417,7 @@ impl Llama {
|
||||
let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
|
||||
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
|
||||
let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
|
||||
let blocks: Vec<_> = (0..cfg.num_hidden_layers)
|
||||
let blocks: Vec<_> = (0..cfg.n_layer)
|
||||
.map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap())
|
||||
.collect();
|
||||
|
||||
|
@ -103,14 +103,6 @@ pub struct Args {
|
||||
/// Tokenizer config file.
|
||||
#[arg(long)]
|
||||
tokenizer: Option<String>,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
impl Args {
|
||||
@ -276,16 +268,6 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
||||
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, index_pos)?;
|
||||
let logits = logits.i((0, logits.dim(1)? - 1))?;
|
||||
let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(common_args.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
common_args.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
index_pos += ctxt.len();
|
||||
|
||||
let next_token = logits_processor.sample(&logits)?;
|
||||
|
@ -1,6 +1,6 @@
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::linear_no_bias as linear;
|
||||
use candle_nn::{embedding, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
|
||||
use candle_nn::{embedding, Embedding, Linear, VarBuilder};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
@ -94,6 +94,32 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
|
||||
xs / (xs.neg()?.exp()? + 1.0)?
|
||||
}
|
||||
|
||||
struct RmsNorm {
|
||||
scale: Tensor,
|
||||
eps: f64,
|
||||
}
|
||||
|
||||
impl RmsNorm {
|
||||
fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
|
||||
let scale = vb.get_or_init(size, "weight", candle_nn::Init::Const(1.))?;
|
||||
Ok(Self { scale, eps })
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let (b_sz, seq_len, hidden_size) = x.dims3()?;
|
||||
let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
|
||||
let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
|
||||
let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
|
||||
let size = self.scale.dims1()?;
|
||||
let scale = self
|
||||
.scale
|
||||
.to_dtype(DType::F32)?
|
||||
.broadcast_as((b_sz, seq_len, size))?;
|
||||
let x = (scale * x_normed)?;
|
||||
Ok(x)
|
||||
}
|
||||
}
|
||||
|
||||
struct CausalSelfAttention {
|
||||
q_proj: Linear,
|
||||
k_proj: Linear,
|
||||
@ -264,9 +290,9 @@ impl Block {
|
||||
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
|
||||
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
|
||||
let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
|
||||
let input_layernorm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
|
||||
let input_layernorm = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
|
||||
let post_attention_layernorm =
|
||||
rms_norm(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
|
||||
RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
|
||||
Ok(Self::new(
|
||||
input_layernorm,
|
||||
attn,
|
||||
@ -299,7 +325,7 @@ impl Llama {
|
||||
pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
|
||||
let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
|
||||
let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
|
||||
let ln_f = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
|
||||
let ln_f = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
|
||||
let blocks: Vec<_> = (0..cfg.n_layers)
|
||||
.map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, &cfg).unwrap())
|
||||
.collect();
|
||||
|
@ -1,7 +1,6 @@
|
||||
use crate::model::{Cache, Config, Llama};
|
||||
use candle::{DType, Device, Result};
|
||||
use candle_datasets::nlp::tinystories::{Dataset, DatasetRandomIter};
|
||||
use candle_nn::Optimizer;
|
||||
|
||||
fn valid_loss(
|
||||
dataset: &Dataset,
|
||||
|
@ -9,14 +9,15 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
use anyhow::{bail, Error as E, Result};
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use cudarc::driver::safe::CudaDevice;
|
||||
use cudarc::nccl::safe::{Comm, Id};
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use hf_hub::api::sync::Api;
|
||||
use std::io::Write;
|
||||
use std::rc::Rc;
|
||||
|
||||
@ -107,12 +108,6 @@ struct Args {
|
||||
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
dtype: Option<String>,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
@ -120,13 +115,8 @@ fn main() -> Result<()> {
|
||||
|
||||
let args = Args::parse();
|
||||
|
||||
let dtype = match args.dtype.as_deref() {
|
||||
Some("f16") => DType::F16,
|
||||
Some("bf16") => DType::BF16,
|
||||
Some("f32") => DType::F32,
|
||||
Some(dtype) => bail!("Unsupported dtype {dtype}"),
|
||||
None => DType::F16,
|
||||
};
|
||||
let config = Config::config_7b();
|
||||
let dtype = DType::F16;
|
||||
|
||||
let api = Api::new()?;
|
||||
|
||||
@ -134,10 +124,7 @@ fn main() -> Result<()> {
|
||||
.model_id
|
||||
.unwrap_or_else(|| "meta-llama/Llama-2-7b-hf".to_string());
|
||||
println!("loading the model weights from {model_id}");
|
||||
let revision = args.revision.unwrap_or("main".to_string());
|
||||
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
||||
let config_filename = api.get("config.json")?;
|
||||
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||
let api = api.model(model_id);
|
||||
let tokenizer_filename = api.get("tokenizer.json")?;
|
||||
let mut filenames = vec![];
|
||||
for rfilename in [
|
||||
@ -198,7 +185,7 @@ fn main() -> Result<()> {
|
||||
println!("Rank {rank:?} spawned");
|
||||
|
||||
let device = Device::new_cuda(i)?;
|
||||
let cache = model::Cache::new(dtype, &config, &device)?;
|
||||
let cache = model::Cache::new(&config, &device)?;
|
||||
|
||||
println!("building the model");
|
||||
let handles = filenames
|
||||
@ -210,7 +197,7 @@ fn main() -> Result<()> {
|
||||
.map(|h| Ok(h.deserialize()?))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
let vb = candle_nn::var_builder::ShardedSafeTensors::var_builder(tensors, dtype, &device);
|
||||
let vb = VarBuilder::from_safetensors(tensors, dtype, &device);
|
||||
let llama = Llama::load(vb, &cache, &config, comm)?;
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
@ -244,7 +231,7 @@ fn main() -> Result<()> {
|
||||
"{} token: {} '{}'",
|
||||
index + 1,
|
||||
next_token,
|
||||
tokenizer.decode(&[next_token], true).map_err(E::msg)?
|
||||
tokenizer.decode(vec![next_token], true).map_err(E::msg)?
|
||||
);
|
||||
}
|
||||
}
|
||||
@ -254,9 +241,7 @@ fn main() -> Result<()> {
|
||||
"{} tokens generated ({} token/s)\n----\n{}\n----",
|
||||
args.sample_len,
|
||||
args.sample_len as f64 / dt.as_secs_f64(),
|
||||
tokenizer
|
||||
.decode(new_tokens.as_slice(), true)
|
||||
.map_err(E::msg)?
|
||||
tokenizer.decode(new_tokens, true).map_err(E::msg)?
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
|
@ -1,16 +1,13 @@
|
||||
use candle::backend::BackendStorage;
|
||||
use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
|
||||
use candle_nn::{Embedding, Linear, Module, RmsNorm};
|
||||
use candle_nn::{Embedding, Linear, VarBuilder};
|
||||
use cudarc::nccl::safe::{Comm, ReduceOp};
|
||||
use half::f16;
|
||||
use serde::Deserialize;
|
||||
use std::rc::Rc;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use super::MAX_SEQ_LEN;
|
||||
|
||||
use candle_nn::var_builder::ShardedVarBuilder as VarBuilder;
|
||||
|
||||
struct TensorParallelColumnLinear {
|
||||
linear: Linear,
|
||||
}
|
||||
@ -71,7 +68,7 @@ impl CustomOp1 for AllReduce {
|
||||
}
|
||||
|
||||
fn all_reduce_sum(x: &Tensor, comm: &Rc<Comm>) -> Result<Tensor> {
|
||||
x.apply_op1(AllReduce { comm: comm.clone() })
|
||||
x.custom_op1(AllReduce { comm: comm.clone() })
|
||||
}
|
||||
|
||||
impl TensorParallelRowLinear {
|
||||
@ -84,19 +81,11 @@ impl TensorParallelRowLinear {
|
||||
}
|
||||
}
|
||||
|
||||
fn shard(dim: usize, rank: usize, world_size: usize) -> candle_nn::var_builder::Shard {
|
||||
candle_nn::var_builder::Shard {
|
||||
dim,
|
||||
rank,
|
||||
world_size,
|
||||
}
|
||||
}
|
||||
|
||||
impl TensorParallelColumnLinear {
|
||||
fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
|
||||
let rank = comm.rank();
|
||||
let size = comm.world_size();
|
||||
let weight = vb.get_with_hints((), "weight", shard(0, rank, size))?;
|
||||
let weight = vb.get_sharded("weight", 0, rank, size)?;
|
||||
Ok(Self::new(Linear::new(weight, None)))
|
||||
}
|
||||
|
||||
@ -105,8 +94,8 @@ impl TensorParallelColumnLinear {
|
||||
let size = comm.world_size();
|
||||
let weights: Vec<_> = prefixes
|
||||
.iter()
|
||||
.map(|p| vb.pp(p).get_with_hints((), "weight", shard(0, rank, size)))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
.map(|p| vb.pp(p).get_sharded("weight", 0, rank, size).unwrap())
|
||||
.collect();
|
||||
let weight = Tensor::cat(&weights, 0)?;
|
||||
Ok(Self::new(Linear::new(weight, None)))
|
||||
}
|
||||
@ -116,26 +105,33 @@ impl TensorParallelRowLinear {
|
||||
fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
|
||||
let rank = comm.rank();
|
||||
let size = comm.world_size();
|
||||
let weight = vb.get_with_hints((), "weight", shard(1, rank, size))?;
|
||||
let weight = vb.get_sharded("weight", 1, rank, size)?;
|
||||
Ok(Self::new(Linear::new(weight, None), comm))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct Config {
|
||||
pub hidden_size: usize,
|
||||
pub intermediate_size: usize,
|
||||
pub vocab_size: usize,
|
||||
pub num_hidden_layers: usize,
|
||||
pub num_attention_heads: usize,
|
||||
pub num_key_value_heads: usize,
|
||||
pub rms_norm_eps: f64,
|
||||
#[serde(default = "default_rope")]
|
||||
pub rope_theta: f32,
|
||||
pub n_layer: usize,
|
||||
pub n_head: usize,
|
||||
pub n_embd: usize,
|
||||
pub n_key_value_head: usize,
|
||||
}
|
||||
|
||||
fn default_rope() -> f32 {
|
||||
10_000.0
|
||||
impl Config {
|
||||
pub fn config_7b() -> Self {
|
||||
Self {
|
||||
hidden_size: 4096,
|
||||
intermediate_size: 11008,
|
||||
vocab_size: 32000,
|
||||
n_layer: 32,
|
||||
n_head: 32,
|
||||
n_embd: 4096,
|
||||
n_key_value_head: 32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@ -147,12 +143,12 @@ pub struct Cache {
|
||||
}
|
||||
|
||||
impl Cache {
|
||||
pub fn new(dtype: DType, config: &Config, device: &Device) -> Result<Self> {
|
||||
pub fn new(config: &Config, device: &Device) -> Result<Self> {
|
||||
// precompute freqs_cis
|
||||
let n_elem = config.hidden_size / config.num_attention_heads;
|
||||
let n_elem = config.n_embd / config.n_head;
|
||||
let theta: Vec<_> = (0..n_elem)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / config.rope_theta.powf(i as f32 / n_elem as f32))
|
||||
.map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
|
||||
.collect();
|
||||
let theta = Tensor::new(theta.as_slice(), device)?;
|
||||
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
|
||||
@ -162,10 +158,10 @@ impl Cache {
|
||||
// This is different from the paper, see:
|
||||
// https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
|
||||
let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
|
||||
let cos = idx_theta.cos()?.to_dtype(dtype)?;
|
||||
let sin = idx_theta.sin()?.to_dtype(dtype)?;
|
||||
let cos = idx_theta.cos()?.to_dtype(DType::F16)?;
|
||||
let sin = idx_theta.sin()?.to_dtype(DType::F16)?;
|
||||
Ok(Self {
|
||||
kvs: Arc::new(Mutex::new(vec![None; config.num_hidden_layers])),
|
||||
kvs: Arc::new(Mutex::new(vec![None; config.n_layer])),
|
||||
cos,
|
||||
sin,
|
||||
})
|
||||
@ -186,24 +182,57 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
|
||||
Ok(Embedding::new(embeddings, cfg.hidden_size))
|
||||
}
|
||||
|
||||
struct RmsNorm {
|
||||
scale: Tensor,
|
||||
}
|
||||
|
||||
impl RmsNorm {
|
||||
fn load(size: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let scale = vb.get(size, "weight")?;
|
||||
Ok(Self::new(scale))
|
||||
}
|
||||
|
||||
fn new(scale: Tensor) -> Self {
|
||||
Self { scale }
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let in_dtype = x.dtype();
|
||||
// This is a no-op if x's dtype is already f32.
|
||||
let x = x.to_dtype(DType::F32)?;
|
||||
let (b_sz, seq_len, hidden_size) = x.shape().dims3()?;
|
||||
let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
|
||||
let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
|
||||
let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
|
||||
let size = self.scale.shape().dims1()?;
|
||||
let scale = self
|
||||
.scale
|
||||
.to_dtype(DType::F32)?
|
||||
.broadcast_as((b_sz, seq_len, size))?;
|
||||
let x = (scale * x_normed)?;
|
||||
let x = x.to_dtype(in_dtype)?;
|
||||
Ok(x)
|
||||
}
|
||||
}
|
||||
|
||||
struct CausalSelfAttention {
|
||||
qkv_proj: TensorParallelColumnLinear,
|
||||
o_proj: TensorParallelRowLinear,
|
||||
num_attention_heads: usize,
|
||||
num_key_value_heads: usize,
|
||||
n_head: usize,
|
||||
n_key_value_head: usize,
|
||||
head_dim: usize,
|
||||
cache: Cache,
|
||||
}
|
||||
|
||||
impl CausalSelfAttention {
|
||||
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||
let (b_sz, _, seq_len, hidden_size) = x.shape().dims4()?;
|
||||
let (b_sz, _, seq_len, n_embd) = x.shape().dims4()?;
|
||||
let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
|
||||
let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
|
||||
let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
|
||||
let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
|
||||
let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?;
|
||||
let x2 = x.narrow(D::Minus1, hidden_size / 2, hidden_size / 2)?;
|
||||
let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd))?;
|
||||
let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd))?;
|
||||
let x1 = x.narrow(D::Minus1, 0, n_embd / 2)?;
|
||||
let x2 = x.narrow(D::Minus1, n_embd / 2, n_embd / 2)?;
|
||||
let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
|
||||
let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
|
||||
Ok(rope)
|
||||
@ -213,31 +242,30 @@ impl CausalSelfAttention {
|
||||
let (b_sz, seq_len, _) = x.shape().dims3()?;
|
||||
|
||||
let qkv = self.qkv_proj.forward(x)?;
|
||||
let hidden_size = self.num_attention_heads * self.head_dim;
|
||||
let n_embd = self.n_head * self.head_dim;
|
||||
|
||||
let q = qkv.i((.., .., ..self.num_attention_heads * self.head_dim))?;
|
||||
let q = qkv.i((.., .., ..self.n_head * self.head_dim))?;
|
||||
let k = qkv.i((
|
||||
..,
|
||||
..,
|
||||
self.num_attention_heads * self.head_dim
|
||||
..self.num_attention_heads * self.head_dim
|
||||
+ self.num_key_value_heads * self.head_dim,
|
||||
self.n_head * self.head_dim
|
||||
..self.n_head * self.head_dim + self.n_key_value_head * self.head_dim,
|
||||
))?;
|
||||
let v = qkv.i((
|
||||
..,
|
||||
..,
|
||||
self.num_attention_heads * self.head_dim + self.num_key_value_heads * self.head_dim..,
|
||||
self.n_head * self.head_dim + self.n_key_value_head * self.head_dim..,
|
||||
))?;
|
||||
// todo!("Q {:?} K {:?} V {:?} - x {:?}", q.shape(), k.shape(), v.shape(), x.shape());
|
||||
|
||||
let q = q
|
||||
.reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
|
||||
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let k = k
|
||||
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
|
||||
.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let mut v = v
|
||||
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
|
||||
.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
|
||||
let q = self.apply_rotary_emb(&q, index_pos)?;
|
||||
@ -271,13 +299,13 @@ impl CausalSelfAttention {
|
||||
let y = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?
|
||||
.transpose(1, 2)?;
|
||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?;
|
||||
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
|
||||
let y = self.o_proj.forward(&y)?;
|
||||
Ok(y)
|
||||
}
|
||||
|
||||
fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
|
||||
let n_rep = self.num_attention_heads / self.num_key_value_heads;
|
||||
let n_rep = self.n_head / self.n_key_value_head;
|
||||
if n_rep == 1 {
|
||||
Ok(x)
|
||||
} else {
|
||||
@ -300,9 +328,9 @@ impl CausalSelfAttention {
|
||||
Ok(Self {
|
||||
qkv_proj,
|
||||
o_proj,
|
||||
num_attention_heads: cfg.num_attention_heads / comm.world_size(),
|
||||
num_key_value_heads: cfg.num_key_value_heads / comm.world_size(),
|
||||
head_dim: cfg.hidden_size / cfg.num_attention_heads,
|
||||
n_head: cfg.n_head / comm.world_size(),
|
||||
n_key_value_head: cfg.n_key_value_head / comm.world_size(),
|
||||
head_dim: cfg.hidden_size / cfg.n_head,
|
||||
cache: cache.clone(),
|
||||
})
|
||||
}
|
||||
@ -347,11 +375,6 @@ struct Block {
|
||||
mlp: Mlp,
|
||||
}
|
||||
|
||||
fn rms_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<RmsNorm> {
|
||||
let weight = vb.get_with_hints(size, "weight", shard(0, 0, 1))?;
|
||||
Ok(RmsNorm::new(weight, eps))
|
||||
}
|
||||
|
||||
impl Block {
|
||||
fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
|
||||
Self {
|
||||
@ -374,9 +397,9 @@ impl Block {
|
||||
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
|
||||
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg, comm.clone())?;
|
||||
let mlp = Mlp::load(vb.pp("mlp"), cfg, comm)?;
|
||||
let input_layernorm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("input_layernorm"))?;
|
||||
let input_layernorm = RmsNorm::load(cfg.hidden_size, vb.pp("input_layernorm"))?;
|
||||
let post_attention_layernorm =
|
||||
rms_norm(cfg.hidden_size, 1e-5, vb.pp("post_attention_layernorm"))?;
|
||||
RmsNorm::load(cfg.hidden_size, vb.pp("post_attention_layernorm"))?;
|
||||
Ok(Self::new(
|
||||
input_layernorm,
|
||||
attn,
|
||||
@ -418,8 +441,8 @@ impl Llama {
|
||||
pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
|
||||
let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
|
||||
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
|
||||
let norm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("model.norm"))?;
|
||||
let blocks: Vec<_> = (0..cfg.num_hidden_layers)
|
||||
let norm = RmsNorm::load(cfg.hidden_size, vb.pp("model.norm"))?;
|
||||
let blocks: Vec<_> = (0..cfg.n_layer)
|
||||
.map(|i| {
|
||||
Block::load(
|
||||
vb.pp(&format!("model.layers.{i}")),
|
||||
|
@ -2,21 +2,17 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use clap::{Parser, ValueEnum};
|
||||
use rand::prelude::*;
|
||||
|
||||
use candle::{DType, Result, Tensor, D};
|
||||
use candle_nn::{loss, ops, Conv2d, Linear, Module, Optimizer, VarBuilder, VarMap};
|
||||
use candle_nn::{loss, ops, Linear, VarBuilder, VarMap};
|
||||
|
||||
const IMAGE_DIM: usize = 784;
|
||||
const LABELS: usize = 10;
|
||||
|
||||
fn linear_z(in_dim: usize, out_dim: usize, vs: VarBuilder) -> Result<Linear> {
|
||||
let ws = vs.get_with_hints((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
|
||||
let bs = vs.get_with_hints(out_dim, "bias", candle_nn::init::ZERO)?;
|
||||
let ws = vs.get_or_init((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
|
||||
let bs = vs.get_or_init(out_dim, "bias", candle_nn::init::ZERO)?;
|
||||
Ok(Linear::new(ws, Some(bs)))
|
||||
}
|
||||
|
||||
@ -59,46 +55,6 @@ impl Model for Mlp {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ConvNet {
|
||||
conv1: Conv2d,
|
||||
conv2: Conv2d,
|
||||
fc1: Linear,
|
||||
fc2: Linear,
|
||||
dropout: candle_nn::Dropout,
|
||||
}
|
||||
|
||||
impl ConvNet {
|
||||
fn new(vs: VarBuilder) -> Result<Self> {
|
||||
let conv1 = candle_nn::conv2d(1, 32, 5, Default::default(), vs.pp("c1"))?;
|
||||
let conv2 = candle_nn::conv2d(32, 64, 5, Default::default(), vs.pp("c2"))?;
|
||||
let fc1 = candle_nn::linear(1024, 1024, vs.pp("fc1"))?;
|
||||
let fc2 = candle_nn::linear(1024, LABELS, vs.pp("fc2"))?;
|
||||
let dropout = candle_nn::Dropout::new(0.5);
|
||||
Ok(Self {
|
||||
conv1,
|
||||
conv2,
|
||||
fc1,
|
||||
fc2,
|
||||
dropout,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
|
||||
let (b_sz, _img_dim) = xs.dims2()?;
|
||||
let xs = xs
|
||||
.reshape((b_sz, 1, 28, 28))?
|
||||
.apply(&self.conv1)?
|
||||
.max_pool2d(2)?
|
||||
.apply(&self.conv2)?
|
||||
.max_pool2d(2)?
|
||||
.flatten_from(1)?
|
||||
.apply(&self.fc1)?
|
||||
.relu()?;
|
||||
self.dropout.forward(&xs, train)?.apply(&self.fc2)
|
||||
}
|
||||
}
|
||||
|
||||
struct TrainingArgs {
|
||||
learning_rate: f64,
|
||||
load: Option<String>,
|
||||
@ -106,71 +62,6 @@ struct TrainingArgs {
|
||||
epochs: usize,
|
||||
}
|
||||
|
||||
fn training_loop_cnn(
|
||||
m: candle_datasets::vision::Dataset,
|
||||
args: &TrainingArgs,
|
||||
) -> anyhow::Result<()> {
|
||||
const BSIZE: usize = 64;
|
||||
|
||||
let dev = candle::Device::cuda_if_available(0)?;
|
||||
|
||||
let train_labels = m.train_labels;
|
||||
let train_images = m.train_images.to_device(&dev)?;
|
||||
let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
|
||||
|
||||
let mut varmap = VarMap::new();
|
||||
let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
|
||||
let model = ConvNet::new(vs.clone())?;
|
||||
|
||||
if let Some(load) = &args.load {
|
||||
println!("loading weights from {load}");
|
||||
varmap.load(load)?
|
||||
}
|
||||
|
||||
let adamw_params = candle_nn::ParamsAdamW {
|
||||
lr: args.learning_rate,
|
||||
..Default::default()
|
||||
};
|
||||
let mut opt = candle_nn::AdamW::new(varmap.all_vars(), adamw_params)?;
|
||||
let test_images = m.test_images.to_device(&dev)?;
|
||||
let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
|
||||
let n_batches = train_images.dim(0)? / BSIZE;
|
||||
let mut batch_idxs = (0..n_batches).collect::<Vec<usize>>();
|
||||
for epoch in 1..args.epochs {
|
||||
let mut sum_loss = 0f32;
|
||||
batch_idxs.shuffle(&mut thread_rng());
|
||||
for batch_idx in batch_idxs.iter() {
|
||||
let train_images = train_images.narrow(0, batch_idx * BSIZE, BSIZE)?;
|
||||
let train_labels = train_labels.narrow(0, batch_idx * BSIZE, BSIZE)?;
|
||||
let logits = model.forward(&train_images, true)?;
|
||||
let log_sm = ops::log_softmax(&logits, D::Minus1)?;
|
||||
let loss = loss::nll(&log_sm, &train_labels)?;
|
||||
opt.backward_step(&loss)?;
|
||||
sum_loss += loss.to_vec0::<f32>()?;
|
||||
}
|
||||
let avg_loss = sum_loss / n_batches as f32;
|
||||
|
||||
let test_logits = model.forward(&test_images, false)?;
|
||||
let sum_ok = test_logits
|
||||
.argmax(D::Minus1)?
|
||||
.eq(&test_labels)?
|
||||
.to_dtype(DType::F32)?
|
||||
.sum_all()?
|
||||
.to_scalar::<f32>()?;
|
||||
let test_accuracy = sum_ok / test_labels.dims1()? as f32;
|
||||
println!(
|
||||
"{epoch:4} train loss {:8.5} test acc: {:5.2}%",
|
||||
avg_loss,
|
||||
100. * test_accuracy
|
||||
);
|
||||
}
|
||||
if let Some(save) = &args.save {
|
||||
println!("saving trained weights in {save}");
|
||||
varmap.save(save)?
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn training_loop<M: Model>(
|
||||
m: candle_datasets::vision::Dataset,
|
||||
args: &TrainingArgs,
|
||||
@ -190,7 +81,7 @@ fn training_loop<M: Model>(
|
||||
varmap.load(load)?
|
||||
}
|
||||
|
||||
let mut sgd = candle_nn::SGD::new(varmap.all_vars(), args.learning_rate)?;
|
||||
let sgd = candle_nn::SGD::new(varmap.all_vars(), args.learning_rate);
|
||||
let test_images = m.test_images.to_device(&dev)?;
|
||||
let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
|
||||
for epoch in 1..args.epochs {
|
||||
@ -224,7 +115,6 @@ fn training_loop<M: Model>(
|
||||
enum WhichModel {
|
||||
Linear,
|
||||
Mlp,
|
||||
Cnn,
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
@ -245,20 +135,12 @@ struct Args {
|
||||
/// The file where to load the trained weights from, in safetensors format.
|
||||
#[arg(long)]
|
||||
load: Option<String>,
|
||||
|
||||
/// The directory where to load the dataset from, in ubyte format.
|
||||
#[arg(long)]
|
||||
local_mnist: Option<String>,
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
// Load the dataset
|
||||
let m = if let Some(directory) = args.local_mnist {
|
||||
candle_datasets::vision::mnist::load_dir(directory)?
|
||||
} else {
|
||||
candle_datasets::vision::mnist::load()?
|
||||
};
|
||||
let m = candle_datasets::vision::mnist::load_dir("data")?;
|
||||
println!("train-images: {:?}", m.train_images.shape());
|
||||
println!("train-labels: {:?}", m.train_labels.shape());
|
||||
println!("test-images: {:?}", m.test_images.shape());
|
||||
@ -267,7 +149,6 @@ pub fn main() -> anyhow::Result<()> {
|
||||
let default_learning_rate = match args.model {
|
||||
WhichModel::Linear => 1.,
|
||||
WhichModel::Mlp => 0.05,
|
||||
WhichModel::Cnn => 0.001,
|
||||
};
|
||||
let training_args = TrainingArgs {
|
||||
epochs: args.epochs,
|
||||
@ -278,6 +159,5 @@ pub fn main() -> anyhow::Result<()> {
|
||||
match args.model {
|
||||
WhichModel::Linear => training_loop::<LinearModel>(m, &training_args),
|
||||
WhichModel::Mlp => training_loop::<Mlp>(m, &training_args),
|
||||
WhichModel::Cnn => training_loop_cnn(m, &training_args),
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
use crate::nn::conv1d_weight_norm;
|
||||
use candle::{DType, IndexOp, Result, Tensor};
|
||||
use candle_nn::{conv1d, Conv1d, Conv1dConfig, Module, VarBuilder};
|
||||
use crate::nn::{conv1d, conv1d_weight_norm, Conv1d, Conv1dConfig, VarBuilder};
|
||||
use anyhow::Result;
|
||||
use candle::{DType, IndexOp, Tensor};
|
||||
|
||||
// Encodec Model
|
||||
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py
|
||||
@ -182,7 +182,7 @@ impl EncodecResidualVectorQuantizer {
|
||||
fn decode(&self, codes: &Tensor) -> Result<Tensor> {
|
||||
let mut quantized_out = Tensor::zeros((), DType::F32, codes.device())?;
|
||||
if codes.dim(0)? != self.layers.len() {
|
||||
candle::bail!(
|
||||
anyhow::bail!(
|
||||
"codes shape {:?} does not match the number of quantization layers {}",
|
||||
codes.shape(),
|
||||
self.layers.len()
|
||||
@ -273,24 +273,14 @@ impl EncodecConv1d {
|
||||
in_c,
|
||||
out_c,
|
||||
kernel_size,
|
||||
Conv1dConfig {
|
||||
padding: 0,
|
||||
stride,
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
},
|
||||
Conv1dConfig { padding: 0, stride },
|
||||
vb.pp("conv"),
|
||||
)?,
|
||||
NormType::None => conv1d(
|
||||
in_c,
|
||||
out_c,
|
||||
kernel_size,
|
||||
Conv1dConfig {
|
||||
padding: 0,
|
||||
stride,
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
},
|
||||
Conv1dConfig { padding: 0, stride },
|
||||
vb.pp("conv"),
|
||||
)?,
|
||||
};
|
||||
@ -320,7 +310,7 @@ impl EncodecResnetBlock {
|
||||
let h = dim / cfg.compress;
|
||||
let mut layer = Layer::new(vb.pp("block"));
|
||||
if dilations.len() != 2 {
|
||||
candle::bail!("expected dilations of size 2")
|
||||
anyhow::bail!("expected dilations of size 2")
|
||||
}
|
||||
// TODO: Apply dilations!
|
||||
layer.inc();
|
||||
@ -369,7 +359,7 @@ impl<'a> Layer<'a> {
|
||||
self.cnt += 1;
|
||||
}
|
||||
|
||||
fn next(&mut self) -> VarBuilder {
|
||||
fn next(&mut self) -> VarBuilder<'a> {
|
||||
let vb = self.vb.pp(&self.cnt.to_string());
|
||||
self.cnt += 1;
|
||||
vb
|
||||
|
@ -7,21 +7,17 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
mod encodec_model;
|
||||
mod musicgen_model;
|
||||
mod nn;
|
||||
mod t5_model;
|
||||
|
||||
use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
|
||||
use nn::VarBuilder;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle::{DType, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle::DType;
|
||||
use clap::Parser;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
|
||||
const DTYPE: DType = DType::F32;
|
||||
|
||||
@ -34,17 +30,11 @@ struct Args {
|
||||
|
||||
/// The model weight file, in safetensor format.
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
model: String,
|
||||
|
||||
/// The tokenizer config.
|
||||
#[arg(long)]
|
||||
tokenizer: Option<String>,
|
||||
|
||||
#[arg(
|
||||
long,
|
||||
default_value = "90s rock song with loud guitars and heavy drums"
|
||||
)]
|
||||
prompt: String,
|
||||
tokenizer: String,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
@ -52,44 +42,13 @@ fn main() -> Result<()> {
|
||||
|
||||
let args = Args::parse();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let tokenizer = match args.tokenizer {
|
||||
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
|
||||
None => Api::new()?
|
||||
.model("facebook/musicgen-small".to_string())
|
||||
.get("tokenizer.json")?,
|
||||
};
|
||||
let mut tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
||||
let tokenizer = tokenizer
|
||||
.with_padding(None)
|
||||
.with_truncation(None)
|
||||
.map_err(E::msg)?;
|
||||
let mut tokenizer = Tokenizer::from_file(args.tokenizer).map_err(E::msg)?;
|
||||
let _tokenizer = tokenizer.with_padding(None).with_truncation(None);
|
||||
|
||||
let model = match args.model {
|
||||
Some(model) => std::path::PathBuf::from(model),
|
||||
None => Api::new()?
|
||||
.repo(Repo::with_revision(
|
||||
"facebook/musicgen-small".to_string(),
|
||||
RepoType::Model,
|
||||
"refs/pr/13".to_string(),
|
||||
))
|
||||
.get("model.safetensors")?,
|
||||
};
|
||||
let model = unsafe { candle::safetensors::MmapedFile::new(model)? };
|
||||
let model = unsafe { candle::safetensors::MmapedFile::new(args.model)? };
|
||||
let model = model.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device);
|
||||
let config = GenConfig::small();
|
||||
let model = MusicgenForConditionalGeneration::load(vb, config)?;
|
||||
|
||||
let tokens = tokenizer
|
||||
.encode(args.prompt.as_str(), true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
println!("tokens: {tokens:?}");
|
||||
let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
|
||||
println!("{tokens:?}");
|
||||
let embeds = model.text_encoder.forward(&tokens)?;
|
||||
println!("{embeds}");
|
||||
|
||||
let _model = MusicgenForConditionalGeneration::load(vb, config)?;
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1,9 +1,9 @@
|
||||
use crate::{encodec_model, t5_model};
|
||||
use candle::{DType, Device, Result, Tensor, D};
|
||||
use candle_nn::{
|
||||
embedding, layer_norm, linear_no_bias, Activation, Embedding, LayerNorm, Linear, Module,
|
||||
VarBuilder,
|
||||
use crate::nn::{
|
||||
embedding, layer_norm, linear, Embedding, HiddenAct, LayerNorm, Linear, VarBuilder,
|
||||
};
|
||||
use crate::{encodec_model, t5_model};
|
||||
use anyhow::Result;
|
||||
use candle::{DType, Device, Tensor, D};
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
@ -15,7 +15,7 @@ pub struct Config {
|
||||
num_attention_heads: usize,
|
||||
layerdrop: f64,
|
||||
use_cache: bool,
|
||||
activation_function: Activation,
|
||||
activation_function: HiddenAct,
|
||||
hidden_size: usize,
|
||||
dropout: f64,
|
||||
attention_dropout: f64,
|
||||
@ -39,7 +39,7 @@ impl Default for Config {
|
||||
num_attention_heads: 16,
|
||||
layerdrop: 0.0,
|
||||
use_cache: true,
|
||||
activation_function: Activation::Gelu, // TODO: Handle old style gelu.
|
||||
activation_function: HiddenAct::Gelu, // TODO: Handle old style gelu.
|
||||
hidden_size: 1024,
|
||||
dropout: 0.1,
|
||||
attention_dropout: 0.0,
|
||||
@ -65,7 +65,7 @@ impl Config {
|
||||
num_attention_heads: 16,
|
||||
layerdrop: 0.0,
|
||||
use_cache: true,
|
||||
activation_function: Activation::Gelu, // TODO: Handle old style gelu.
|
||||
activation_function: HiddenAct::Gelu, // TODO: Handle old style gelu.
|
||||
hidden_size: 1024,
|
||||
dropout: 0.1,
|
||||
attention_dropout: 0.0,
|
||||
@ -127,7 +127,7 @@ impl MusicgenSinusoidalPositionalEmbedding {
|
||||
if seq_len > self.weights.dim(0)? {
|
||||
self.weights = get_embedding(seq_len, self.embedding_dim)?
|
||||
}
|
||||
self.weights.narrow(0, 0, seq_len)
|
||||
Ok(self.weights.narrow(0, 0, seq_len)?)
|
||||
}
|
||||
}
|
||||
|
||||
@ -148,10 +148,10 @@ impl MusicgenAttention {
|
||||
let h = cfg.hidden_size;
|
||||
let num_heads = cfg.num_attention_heads;
|
||||
let head_dim = h / num_heads;
|
||||
let k_proj = linear_no_bias(h, h, vb.pp("k_proj"))?;
|
||||
let v_proj = linear_no_bias(h, h, vb.pp("v_proj"))?;
|
||||
let q_proj = linear_no_bias(h, h, vb.pp("q_proj"))?;
|
||||
let out_proj = linear_no_bias(h, h, vb.pp("out_proj"))?;
|
||||
let k_proj = linear(h, h, false, vb.pp("k_proj"))?;
|
||||
let v_proj = linear(h, h, false, vb.pp("v_proj"))?;
|
||||
let q_proj = linear(h, h, false, vb.pp("q_proj"))?;
|
||||
let out_proj = linear(h, h, false, vb.pp("out_proj"))?;
|
||||
Ok(Self {
|
||||
scaling: 1. / (head_dim as f64).sqrt(),
|
||||
is_decoder: true,
|
||||
@ -208,7 +208,7 @@ struct MusicgenDecoderLayer {
|
||||
fc1: Linear,
|
||||
fc2: Linear,
|
||||
final_layer_norm: LayerNorm,
|
||||
activation_fn: Activation,
|
||||
activation_fn: HiddenAct,
|
||||
}
|
||||
|
||||
impl MusicgenDecoderLayer {
|
||||
@ -218,8 +218,8 @@ impl MusicgenDecoderLayer {
|
||||
let self_attn_layer_norm = layer_norm(h, 1e-5, vb.pp("self_attn_layer_norm"))?;
|
||||
let encoder_attn = MusicgenAttention::load(vb.pp("encoder_attn"), cfg)?;
|
||||
let encoder_attn_layer_norm = layer_norm(h, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
|
||||
let fc1 = linear_no_bias(h, cfg.ffn_dim, vb.pp("fc1"))?;
|
||||
let fc2 = linear_no_bias(cfg.ffn_dim, h, vb.pp("fc2"))?;
|
||||
let fc1 = linear(h, cfg.ffn_dim, false, vb.pp("fc1"))?;
|
||||
let fc2 = linear(cfg.ffn_dim, h, false, vb.pp("fc2"))?;
|
||||
let final_layer_norm = layer_norm(h, 1e-5, vb.pp("final_layer_norm"))?;
|
||||
Ok(Self {
|
||||
self_attn,
|
||||
@ -341,7 +341,7 @@ impl MusicgenForCausalLM {
|
||||
let h = cfg.hidden_size;
|
||||
let decoder = MusicgenDecoder::load(vb.pp("model.decoder"), cfg)?;
|
||||
let lm_heads = (0..cfg.num_codebooks)
|
||||
.map(|i| linear_no_bias(h, cfg.vocab_size, vb.pp(&format!("lm_heads.{i}"))))
|
||||
.map(|i| linear(h, cfg.vocab_size, false, vb.pp(&format!("lm_heads.{i}"))))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
Ok(Self {
|
||||
decoder,
|
||||
@ -357,7 +357,7 @@ impl MusicgenForCausalLM {
|
||||
let lm_logits = self
|
||||
.lm_heads
|
||||
.iter()
|
||||
.map(|h| h.forward(&hidden_states))
|
||||
.map(|h| Ok(h.forward(&hidden_states)?))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let lm_logits = Tensor::stack(&lm_logits, 1)?.reshape((
|
||||
b_sz * self.num_codebooks,
|
||||
@ -370,9 +370,9 @@ impl MusicgenForCausalLM {
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MusicgenForConditionalGeneration {
|
||||
pub text_encoder: crate::t5_model::T5EncoderModel,
|
||||
pub audio_encoder: crate::encodec_model::EncodecModel,
|
||||
pub decoder: MusicgenForCausalLM,
|
||||
text_encoder: crate::t5_model::T5EncoderModel,
|
||||
audio_encoder: crate::encodec_model::EncodecModel,
|
||||
decoder: MusicgenForCausalLM,
|
||||
cfg: GenConfig,
|
||||
}
|
||||
|
||||
|
@ -1,5 +1,62 @@
|
||||
use candle::Result;
|
||||
use candle_nn::{Conv1d, Conv1dConfig, VarBuilder};
|
||||
use anyhow::Result;
|
||||
use candle::Tensor;
|
||||
|
||||
const MAX_SEQ_LEN: usize = 5000;
|
||||
|
||||
pub type VarBuilder<'a> = candle_nn::VarBuilder<'a>;
|
||||
pub type Linear = candle_nn::Linear;
|
||||
|
||||
pub fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
|
||||
let weight = vb.get((size2, size1), "weight")?;
|
||||
let bias = if bias {
|
||||
Some(vb.get(size2, "bias")?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(Linear::new(weight, bias))
|
||||
}
|
||||
|
||||
pub type LayerNorm = candle_nn::LayerNorm;
|
||||
|
||||
pub fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
|
||||
let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
|
||||
(Ok(weight), Ok(bias)) => (weight, bias),
|
||||
(Err(err), _) | (_, Err(err)) => {
|
||||
if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
|
||||
(weight, bias)
|
||||
} else {
|
||||
return Err(err.into());
|
||||
}
|
||||
}
|
||||
};
|
||||
Ok(LayerNorm::new(weight, bias, eps))
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Dropout {
|
||||
pr: f64,
|
||||
}
|
||||
|
||||
impl Dropout {
|
||||
pub fn new(pr: f64) -> Self {
|
||||
Self { pr }
|
||||
}
|
||||
|
||||
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
// TODO
|
||||
Ok(x.clone())
|
||||
}
|
||||
}
|
||||
|
||||
pub type Embedding = candle_nn::Embedding;
|
||||
|
||||
pub fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
|
||||
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
||||
Ok(Embedding::new(embeddings, hidden_size))
|
||||
}
|
||||
|
||||
pub type Conv1d = candle_nn::Conv1d;
|
||||
pub type Conv1dConfig = candle_nn::Conv1dConfig;
|
||||
|
||||
// Applies weight norm for inference by recomputing the weight tensor. This
|
||||
// does not apply to training.
|
||||
@ -18,3 +75,17 @@ pub fn conv1d_weight_norm(
|
||||
let bias = vb.get(out_c, "bias")?;
|
||||
Ok(Conv1d::new(weight, Some(bias), config))
|
||||
}
|
||||
|
||||
pub fn conv1d(
|
||||
in_c: usize,
|
||||
out_c: usize,
|
||||
kernel_size: usize,
|
||||
config: Conv1dConfig,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Conv1d> {
|
||||
let weight = vb.get((out_c, in_c, kernel_size), "weight")?;
|
||||
let bias = vb.get(out_c, "bias")?;
|
||||
Ok(Conv1d::new(weight, Some(bias), config))
|
||||
}
|
||||
|
||||
pub type HiddenAct = candle_nn::Activation;
|
||||
|
@ -1,8 +1,9 @@
|
||||
// T5 Text Encoder
|
||||
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
|
||||
|
||||
use candle::{DType, Result, Tensor, D};
|
||||
use candle_nn::{embedding, linear_no_bias, Activation, Embedding, Linear, Module, VarBuilder};
|
||||
use crate::nn::{embedding, linear, Dropout, Embedding, HiddenAct, Linear, VarBuilder};
|
||||
use anyhow::Result;
|
||||
use candle::{DType, Tensor, D};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
@ -19,7 +20,7 @@ pub struct Config {
|
||||
dropout_rate: f64,
|
||||
layer_norm_epsilon: f64,
|
||||
initializer_factor: f64,
|
||||
feed_forward_proj: Activation,
|
||||
feed_forward_proj: HiddenAct,
|
||||
is_decoder: bool,
|
||||
is_encoder_decoder: bool,
|
||||
use_cache: bool,
|
||||
@ -42,7 +43,7 @@ impl Default for Config {
|
||||
dropout_rate: 0.1,
|
||||
layer_norm_epsilon: 1e-6,
|
||||
initializer_factor: 1.0,
|
||||
feed_forward_proj: Activation::Relu,
|
||||
feed_forward_proj: HiddenAct::Relu,
|
||||
is_decoder: false,
|
||||
is_encoder_decoder: true,
|
||||
use_cache: true,
|
||||
@ -61,7 +62,7 @@ impl Config {
|
||||
d_model: 768,
|
||||
dropout_rate: 0.1,
|
||||
eos_token_id: 1,
|
||||
feed_forward_proj: Activation::Relu,
|
||||
feed_forward_proj: HiddenAct::Relu,
|
||||
initializer_factor: 1.0,
|
||||
is_decoder: false,
|
||||
is_encoder_decoder: true,
|
||||
@ -96,9 +97,10 @@ impl T5LayerNorm {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let dtype = xs.dtype();
|
||||
let xs_f32 = xs.to_dtype(DType::F32)?;
|
||||
// variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
|
||||
let xs = xs.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
|
||||
let xs2_f32 = (&xs_f32 * &xs_f32)?;
|
||||
let sum_xs2_f32 = xs2_f32.sum_keepdim(D::Minus1)?;
|
||||
let variance = xs2_f32.broadcast_div(&sum_xs2_f32)?;
|
||||
let xs = (xs / (variance + self.variance_epsilon)?.sqrt()?)?;
|
||||
let xs = xs.to_dtype(dtype)?;
|
||||
let xs = xs.broadcast_mul(&self.weight)?;
|
||||
Ok(xs)
|
||||
@ -109,23 +111,27 @@ impl T5LayerNorm {
|
||||
struct T5DenseActDense {
|
||||
wi: Linear,
|
||||
wo: Linear,
|
||||
act: Activation,
|
||||
dropout: Dropout,
|
||||
act: HiddenAct,
|
||||
}
|
||||
|
||||
impl T5DenseActDense {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let wi = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi"))?;
|
||||
let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
|
||||
let wi = linear(cfg.d_model, cfg.d_ff, false, vb.pp("wi"))?;
|
||||
let wo = linear(cfg.d_ff, cfg.d_model, false, vb.pp("wo"))?;
|
||||
let dropout = Dropout::new(cfg.dropout_rate);
|
||||
Ok(Self {
|
||||
wi,
|
||||
wo,
|
||||
act: Activation::Relu,
|
||||
dropout,
|
||||
act: HiddenAct::Relu,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = self.wi.forward(xs)?;
|
||||
let xs = self.act.forward(&xs)?;
|
||||
let xs = self.dropout.forward(&xs)?;
|
||||
let xs = self.wo.forward(&xs)?;
|
||||
Ok(xs)
|
||||
}
|
||||
@ -135,6 +141,7 @@ impl T5DenseActDense {
|
||||
struct T5LayerFF {
|
||||
dense_relu_dense: T5DenseActDense,
|
||||
layer_norm: T5LayerNorm,
|
||||
dropout: Dropout,
|
||||
}
|
||||
|
||||
impl T5LayerFF {
|
||||
@ -143,16 +150,18 @@ impl T5LayerFF {
|
||||
let dense_relu_dense = T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?;
|
||||
let layer_norm =
|
||||
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
|
||||
let dropout = Dropout::new(cfg.dropout_rate);
|
||||
Ok(Self {
|
||||
dense_relu_dense,
|
||||
layer_norm,
|
||||
dropout,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let ys = self.layer_norm.forward(xs)?;
|
||||
let ys = self.dense_relu_dense.forward(&ys)?;
|
||||
let xs = (xs + ys)?;
|
||||
let xs = (xs + self.dropout.forward(&ys)?)?;
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
@ -166,18 +175,15 @@ struct T5Attention {
|
||||
n_heads: usize,
|
||||
d_kv: usize,
|
||||
relative_attention_bias: Option<Embedding>,
|
||||
relative_attention_num_buckets: usize,
|
||||
relative_attention_max_distance: usize,
|
||||
inner_dim: usize,
|
||||
}
|
||||
|
||||
impl T5Attention {
|
||||
fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let inner_dim = cfg.num_heads * cfg.d_kv;
|
||||
let q = linear_no_bias(cfg.d_model, inner_dim, vb.pp("q"))?;
|
||||
let k = linear_no_bias(cfg.d_model, inner_dim, vb.pp("k"))?;
|
||||
let v = linear_no_bias(cfg.d_model, inner_dim, vb.pp("v"))?;
|
||||
let o = linear_no_bias(inner_dim, cfg.d_model, vb.pp("o"))?;
|
||||
let q = linear(cfg.d_model, inner_dim, false, vb.pp("q"))?;
|
||||
let k = linear(cfg.d_model, inner_dim, false, vb.pp("k"))?;
|
||||
let v = linear(cfg.d_model, inner_dim, false, vb.pp("v"))?;
|
||||
let o = linear(inner_dim, cfg.d_model, false, vb.pp("o"))?;
|
||||
let relative_attention_bias = if h {
|
||||
let emb = embedding(
|
||||
cfg.relative_attention_num_buckets,
|
||||
@ -196,17 +202,10 @@ impl T5Attention {
|
||||
n_heads: cfg.num_heads,
|
||||
d_kv: cfg.d_kv,
|
||||
relative_attention_bias,
|
||||
relative_attention_num_buckets: cfg.relative_attention_num_buckets,
|
||||
relative_attention_max_distance: cfg.relative_attention_max_distance,
|
||||
inner_dim,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
xs: &Tensor,
|
||||
position_bias: Option<&Tensor>,
|
||||
) -> Result<(Tensor, Option<Tensor>)> {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
// TODO: Apply the mask(s)?
|
||||
// TODO: kv caching.
|
||||
let (b_sz, seq_len) = (xs.dim(0)?, xs.dim(1)?);
|
||||
@ -215,78 +214,19 @@ impl T5Attention {
|
||||
let v = self.v.forward(xs)?;
|
||||
let q = q
|
||||
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
.transpose(1, 2)?;
|
||||
let k = k
|
||||
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
.transpose(1, 2)?;
|
||||
let v = v
|
||||
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
.transpose(1, 2)?;
|
||||
let scores = q.matmul(&k.t()?)?;
|
||||
|
||||
let (scores, position_bias) = match position_bias {
|
||||
Some(position_bias) => ((scores + position_bias)?, Some(position_bias.clone())),
|
||||
None => match &self.relative_attention_bias {
|
||||
None => (scores, None),
|
||||
Some(relative_attention_bias) => {
|
||||
let query_length = seq_len;
|
||||
let key_length = seq_len;
|
||||
// This only handles the bidirectional case.
|
||||
let num_buckets = self.relative_attention_num_buckets as u32 / 2;
|
||||
let max_exact = num_buckets / 2;
|
||||
let relative_position = (0..query_length as u32)
|
||||
.map(|i| {
|
||||
(0..key_length as u32)
|
||||
.map(|j| {
|
||||
if i < j {
|
||||
if j - i < max_exact {
|
||||
j - i + num_buckets
|
||||
} else {
|
||||
let b = f32::log(
|
||||
(j - i) as f32 / max_exact as f32,
|
||||
self.relative_attention_max_distance as f32
|
||||
/ max_exact as f32,
|
||||
) * (num_buckets - max_exact) as f32;
|
||||
u32::min(
|
||||
max_exact + num_buckets + b as u32,
|
||||
self.relative_attention_num_buckets as u32 - 1,
|
||||
)
|
||||
}
|
||||
} else if i - j < max_exact {
|
||||
i - j
|
||||
} else {
|
||||
let b = f32::log(
|
||||
(i - j) as f32 / max_exact as f32,
|
||||
self.relative_attention_max_distance as f32
|
||||
/ max_exact as f32,
|
||||
) * (num_buckets - max_exact) as f32;
|
||||
max_exact + b as u32
|
||||
}
|
||||
})
|
||||
.collect::<Vec<u32>>()
|
||||
})
|
||||
.collect::<Vec<Vec<_>>>();
|
||||
let relative_buckets = Tensor::new(relative_position, q.device())?;
|
||||
let position_bias = relative_attention_bias
|
||||
.forward(&relative_buckets)?
|
||||
.permute((2, 0, 1))?
|
||||
.unsqueeze(0)?;
|
||||
((scores + &position_bias)?, Some(position_bias))
|
||||
// TODO: position_bias_masked?
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
// TODO: position_bias_masked
|
||||
let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1)?;
|
||||
let attn_output = attn_weights.matmul(&v)?;
|
||||
let attn_output = attn_output
|
||||
.transpose(1, 2)?
|
||||
.reshape((b_sz, seq_len, self.inner_dim))?;
|
||||
let attn_output = self.o.forward(&attn_output)?;
|
||||
Ok((attn_output, position_bias))
|
||||
Ok(attn_output)
|
||||
}
|
||||
}
|
||||
|
||||
@ -294,6 +234,7 @@ impl T5Attention {
|
||||
struct T5LayerSelfAttention {
|
||||
self_attention: T5Attention,
|
||||
layer_norm: T5LayerNorm,
|
||||
dropout: Dropout,
|
||||
}
|
||||
|
||||
impl T5LayerSelfAttention {
|
||||
@ -301,21 +242,19 @@ impl T5LayerSelfAttention {
|
||||
let self_attention = T5Attention::load(h, vb.pp("SelfAttention"), cfg)?;
|
||||
let layer_norm =
|
||||
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
|
||||
let dropout = Dropout::new(cfg.dropout_rate);
|
||||
Ok(Self {
|
||||
self_attention,
|
||||
layer_norm,
|
||||
dropout,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
xs: &Tensor,
|
||||
position_bias: Option<&Tensor>,
|
||||
) -> Result<(Tensor, Option<Tensor>)> {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let normed_xs = self.layer_norm.forward(xs)?;
|
||||
let (ys, position_bias) = self.self_attention.forward(&normed_xs, position_bias)?;
|
||||
let ys = self.self_attention.forward(&normed_xs)?;
|
||||
let ys = (xs + ys)?;
|
||||
Ok((ys, position_bias))
|
||||
Ok(ys)
|
||||
}
|
||||
}
|
||||
|
||||
@ -357,12 +296,8 @@ impl T5Block {
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
xs: &Tensor,
|
||||
position_bias: Option<&Tensor>,
|
||||
) -> Result<(Tensor, Option<Tensor>)> {
|
||||
let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias)?;
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = self.self_attn.forward(xs)?;
|
||||
// TODO: clamp for f16?
|
||||
if let Some(cross_attn) = &self.cross_attn {
|
||||
xs = cross_attn.forward(&xs)?;
|
||||
@ -370,7 +305,7 @@ impl T5Block {
|
||||
}
|
||||
let xs = self.ff.forward(&xs)?;
|
||||
// TODO: clamp for f16?
|
||||
Ok((xs, position_bias))
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
@ -379,6 +314,7 @@ struct T5Stack {
|
||||
block: Vec<T5Block>,
|
||||
shared: Arc<Embedding>,
|
||||
final_layer_norm: T5LayerNorm,
|
||||
dropout: Dropout,
|
||||
}
|
||||
|
||||
impl T5Stack {
|
||||
@ -391,24 +327,25 @@ impl T5Stack {
|
||||
cfg.layer_norm_epsilon,
|
||||
vb.pp("final_layer_norm"),
|
||||
)?;
|
||||
let dropout = Dropout::new(cfg.dropout_rate);
|
||||
Ok(Self {
|
||||
block,
|
||||
shared: shared.clone(),
|
||||
final_layer_norm,
|
||||
dropout,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
|
||||
let input_embeds = self.shared.as_ref().forward(input_ids)?;
|
||||
let (_b_sz, _seq_len) = (input_embeds.dim(0)?, input_embeds.dim(1)?);
|
||||
let (_b_sz, _seq_len) = input_embeds.dims2()?;
|
||||
|
||||
let mut hidden_states = input_embeds;
|
||||
let mut position_bias = None;
|
||||
let mut hidden_states = self.dropout.forward(&input_embeds)?;
|
||||
for block in self.block.iter() {
|
||||
(hidden_states, position_bias) =
|
||||
block.forward(&hidden_states, position_bias.as_ref())?
|
||||
hidden_states = block.forward(&hidden_states)?
|
||||
}
|
||||
let hidden_states = self.final_layer_norm.forward(&hidden_states)?;
|
||||
let hidden_states = self.dropout.forward(&hidden_states)?;
|
||||
Ok(hidden_states)
|
||||
}
|
||||
}
|
||||
|
@ -1,367 +0,0 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use clap::{Parser, ValueEnum};
|
||||
use std::io::Write;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
use candle::quantized::{ggml_file, gguf_file};
|
||||
use candle::{Device, Tensor};
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
|
||||
mod model;
|
||||
use model::ModelWeights;
|
||||
|
||||
const DEFAULT_PROMPT: &str = "My favorite theorem is ";
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Prompt {
|
||||
Interactive,
|
||||
Chat,
|
||||
One(String),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Copy, ValueEnum)]
|
||||
enum Which {
|
||||
#[value(name = "7b")]
|
||||
L7b,
|
||||
#[value(name = "13b")]
|
||||
L13b,
|
||||
#[value(name = "70b")]
|
||||
L70b,
|
||||
#[value(name = "7b-chat")]
|
||||
L7bChat,
|
||||
#[value(name = "13b-chat")]
|
||||
L13bChat,
|
||||
#[value(name = "70b-chat")]
|
||||
L70bChat,
|
||||
#[value(name = "7b-code")]
|
||||
L7bCode,
|
||||
#[value(name = "13b-code")]
|
||||
L13bCode,
|
||||
#[value(name = "32b-code")]
|
||||
L34bCode,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// GGML file to load, typically a .bin file generated by the quantize command from llama.cpp
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
/// The initial prompt, use 'interactive' for entering multiple prompts in an interactive way
|
||||
/// and 'chat' for an interactive model where history of previous prompts and generated tokens
|
||||
/// is preserved.
|
||||
#[arg(long)]
|
||||
prompt: Option<String>,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(short = 'n', long, default_value_t = 100)]
|
||||
sample_len: usize,
|
||||
|
||||
/// The tokenizer config in json format.
|
||||
#[arg(long)]
|
||||
tokenizer: Option<String>,
|
||||
|
||||
/// The temperature used to generate samples, use 0 for greedy sampling.
|
||||
#[arg(long, default_value_t = 0.8)]
|
||||
temperature: f64,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
/// Display the token for the specified prompt.
|
||||
#[arg(long)]
|
||||
verbose_prompt: bool,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
|
||||
/// The model size to use.
|
||||
#[arg(long, default_value = "7b")]
|
||||
which: Which,
|
||||
|
||||
/// Group-Query Attention, use 8 for the 70B version of LLaMAv2.
|
||||
#[arg(long)]
|
||||
gqa: Option<usize>,
|
||||
}
|
||||
|
||||
impl Args {
|
||||
fn tokenizer(&self) -> anyhow::Result<Tokenizer> {
|
||||
let tokenizer_path = match &self.tokenizer {
|
||||
Some(config) => std::path::PathBuf::from(config),
|
||||
None => {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model("hf-internal-testing/llama-tokenizer".to_string());
|
||||
api.get("tokenizer.json")?
|
||||
}
|
||||
};
|
||||
Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)
|
||||
}
|
||||
|
||||
fn model(&self) -> anyhow::Result<std::path::PathBuf> {
|
||||
let model_path = match &self.model {
|
||||
Some(config) => std::path::PathBuf::from(config),
|
||||
None => {
|
||||
let (repo, filename) = match self.which {
|
||||
Which::L7b => ("TheBloke/Llama-2-7B-GGML", "llama-2-7b.ggmlv3.q4_0.bin"),
|
||||
Which::L13b => ("TheBloke/Llama-2-13B-GGML", "llama-2-13b.ggmlv3.q4_0.bin"),
|
||||
Which::L70b => ("TheBloke/Llama-2-70B-GGML", "llama-2-70b.ggmlv3.q4_0.bin"),
|
||||
Which::L7bChat => (
|
||||
"TheBloke/Llama-2-7B-Chat-GGML",
|
||||
"llama-2-7b-chat.ggmlv3.q4_0.bin",
|
||||
),
|
||||
Which::L13bChat => (
|
||||
"TheBloke/Llama-2-13B-Chat-GGML",
|
||||
"llama-2-13b-chat.ggmlv3.q4_0.bin",
|
||||
),
|
||||
Which::L70bChat => (
|
||||
"TheBloke/Llama-2-70B-Chat-GGML",
|
||||
"llama-2-70b-chat.ggmlv3.q4_0.bin",
|
||||
),
|
||||
Which::L7bCode => ("TheBloke/CodeLlama-7B-GGUF", "codellama-7b.Q8_0.gguf"),
|
||||
Which::L13bCode => ("TheBloke/CodeLlama-13B-GGUF", "codellama-13b.Q8_0.gguf"),
|
||||
Which::L34bCode => ("TheBloke/CodeLlama-34B-GGUF", "codellama-34b.Q8_0.gguf"),
|
||||
};
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model(repo.to_string());
|
||||
api.get(filename)?
|
||||
}
|
||||
};
|
||||
Ok(model_path)
|
||||
}
|
||||
}
|
||||
|
||||
fn print_token(next_token: u32, tokenizer: &Tokenizer) {
|
||||
// Extracting the last token as a string is complicated, here we just apply some simple
|
||||
// heuristics as it seems to work well enough for this example. See the following for more
|
||||
// details:
|
||||
// https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
|
||||
if let Some(text) = tokenizer.id_to_token(next_token) {
|
||||
let text = text.replace('▁', " ");
|
||||
let ascii = text
|
||||
.strip_prefix("<0x")
|
||||
.and_then(|t| t.strip_suffix('>'))
|
||||
.and_then(|t| u8::from_str_radix(t, 16).ok());
|
||||
match ascii {
|
||||
None => print!("{text}"),
|
||||
Some(ascii) => {
|
||||
if let Some(chr) = char::from_u32(ascii as u32) {
|
||||
if chr.is_ascii() {
|
||||
print!("{chr}")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
let _ = std::io::stdout().flush();
|
||||
}
|
||||
}
|
||||
|
||||
fn format_size(size_in_bytes: usize) -> String {
|
||||
if size_in_bytes < 1_000 {
|
||||
format!("{}B", size_in_bytes)
|
||||
} else if size_in_bytes < 1_000_000 {
|
||||
format!("{:.2}KB", size_in_bytes as f64 / 1e3)
|
||||
} else if size_in_bytes < 1_000_000_000 {
|
||||
format!("{:.2}MB", size_in_bytes as f64 / 1e6)
|
||||
} else {
|
||||
format!("{:.2}GB", size_in_bytes as f64 / 1e9)
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let temperature = if args.temperature == 0. {
|
||||
None
|
||||
} else {
|
||||
Some(args.temperature)
|
||||
};
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
println!(
|
||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||
args.temperature, args.repeat_penalty, args.repeat_last_n
|
||||
);
|
||||
|
||||
let model_path = args.model()?;
|
||||
let mut file = std::fs::File::open(&model_path)?;
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let mut model = match model_path.extension().and_then(|v| v.to_str()) {
|
||||
Some("gguf") => {
|
||||
let model = gguf_file::Content::read(&mut file)?;
|
||||
let mut total_size_in_bytes = 0;
|
||||
for (_, tensor) in model.tensor_infos.iter() {
|
||||
let elem_count = tensor.shape.elem_count();
|
||||
total_size_in_bytes +=
|
||||
elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size();
|
||||
}
|
||||
println!(
|
||||
"loaded {:?} tensors ({}) in {:.2}s",
|
||||
model.tensor_infos.len(),
|
||||
&format_size(total_size_in_bytes),
|
||||
start.elapsed().as_secs_f32(),
|
||||
);
|
||||
ModelWeights::from_gguf(model, &mut file)?
|
||||
}
|
||||
Some("ggml" | "bin") | Some(_) | None => {
|
||||
let model = ggml_file::Content::read(&mut file)?;
|
||||
let mut total_size_in_bytes = 0;
|
||||
for (_, tensor) in model.tensors.iter() {
|
||||
let elem_count = tensor.shape().elem_count();
|
||||
total_size_in_bytes +=
|
||||
elem_count * tensor.dtype().type_size() / tensor.dtype().blck_size();
|
||||
}
|
||||
println!(
|
||||
"loaded {:?} tensors ({}) in {:.2}s",
|
||||
model.tensors.len(),
|
||||
&format_size(total_size_in_bytes),
|
||||
start.elapsed().as_secs_f32(),
|
||||
);
|
||||
println!("params: {:?}", model.hparams);
|
||||
let default_gqa = match args.which {
|
||||
Which::L7b
|
||||
| Which::L13b
|
||||
| Which::L7bChat
|
||||
| Which::L13bChat
|
||||
| Which::L7bCode
|
||||
| Which::L13bCode
|
||||
| Which::L34bCode => 1,
|
||||
Which::L70b | Which::L70bChat => 8,
|
||||
};
|
||||
ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
|
||||
}
|
||||
};
|
||||
println!("model built");
|
||||
|
||||
let tokenizer = args.tokenizer()?;
|
||||
let prompt = match args.prompt.as_deref() {
|
||||
Some("chat") => Prompt::Chat,
|
||||
Some("interactive") => Prompt::Interactive,
|
||||
Some(s) => Prompt::One(s.to_string()),
|
||||
None => Prompt::One(DEFAULT_PROMPT.to_string()),
|
||||
};
|
||||
|
||||
let mut pre_prompt_tokens = vec![];
|
||||
loop {
|
||||
let prompt_str = match &prompt {
|
||||
Prompt::One(prompt) => prompt.clone(),
|
||||
Prompt::Interactive | Prompt::Chat => {
|
||||
print!("> ");
|
||||
std::io::stdout().flush()?;
|
||||
let mut prompt = String::new();
|
||||
std::io::stdin().read_line(&mut prompt)?;
|
||||
if prompt.ends_with('\n') {
|
||||
prompt.pop();
|
||||
if prompt.ends_with('\r') {
|
||||
prompt.pop();
|
||||
}
|
||||
}
|
||||
prompt
|
||||
}
|
||||
};
|
||||
print!("{}", &prompt_str);
|
||||
let tokens = tokenizer
|
||||
.encode(prompt_str, true)
|
||||
.map_err(anyhow::Error::msg)?;
|
||||
if args.verbose_prompt {
|
||||
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
|
||||
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
|
||||
println!("{id:7} -> '{token}'");
|
||||
}
|
||||
}
|
||||
|
||||
let prompt_tokens = [&pre_prompt_tokens, tokens.get_ids()].concat();
|
||||
let to_sample = args.sample_len.saturating_sub(1);
|
||||
let prompt_tokens = if prompt_tokens.len() + to_sample > model::MAX_SEQ_LEN - 10 {
|
||||
let to_remove = prompt_tokens.len() + to_sample + 10 - model::MAX_SEQ_LEN;
|
||||
prompt_tokens[prompt_tokens.len().saturating_sub(to_remove)..].to_vec()
|
||||
} else {
|
||||
prompt_tokens
|
||||
};
|
||||
let mut all_tokens = vec![];
|
||||
let mut logits_processor = LogitsProcessor::new(args.seed, temperature);
|
||||
|
||||
let start_prompt_processing = std::time::Instant::now();
|
||||
let mut next_token = {
|
||||
let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, 0)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
logits_processor.sample(&logits)?
|
||||
};
|
||||
let prompt_dt = start_prompt_processing.elapsed();
|
||||
all_tokens.push(next_token);
|
||||
print_token(next_token, &tokenizer);
|
||||
|
||||
let start_post_prompt = std::time::Instant::now();
|
||||
for index in 0..to_sample {
|
||||
let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, prompt_tokens.len() + index)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
let logits = if args.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
args.repeat_penalty,
|
||||
&all_tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
next_token = logits_processor.sample(&logits)?;
|
||||
all_tokens.push(next_token);
|
||||
print_token(next_token, &tokenizer);
|
||||
}
|
||||
let dt = start_post_prompt.elapsed();
|
||||
println!(
|
||||
"\n\n{:4} prompt tokens processed: {:.2} token/s",
|
||||
prompt_tokens.len(),
|
||||
prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
|
||||
);
|
||||
println!(
|
||||
"{:4} tokens generated: {:.2} token/s",
|
||||
to_sample,
|
||||
to_sample as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
|
||||
match prompt {
|
||||
Prompt::One(_) => break,
|
||||
Prompt::Interactive => {}
|
||||
Prompt::Chat => {
|
||||
pre_prompt_tokens = [prompt_tokens.as_slice(), all_tokens.as_slice()].concat()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
@ -1,371 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use candle::quantized::QTensor;
|
||||
use candle::quantized::{ggml_file, gguf_file};
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{Embedding, Module};
|
||||
|
||||
pub const MAX_SEQ_LEN: usize = 4096;
|
||||
|
||||
struct RmsNorm {
|
||||
inner: candle_nn::LayerNorm,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl RmsNorm {
|
||||
fn new(scale: QTensor, eps: f32) -> Result<Self> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
|
||||
let scale = scale.dequantize(&Device::Cpu)?;
|
||||
let inner = candle_nn::LayerNorm::rms_norm(scale, eps as f64);
|
||||
Ok(Self { inner, span })
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
self.inner.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
// QMatMul wrapper adding some tracing.
|
||||
struct QMatMul {
|
||||
inner: candle::quantized::QMatMul,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl QMatMul {
|
||||
fn from_qtensor(qtensor: QTensor) -> Self {
|
||||
let inner = candle::quantized::QMatMul::from_qtensor(qtensor);
|
||||
let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
|
||||
Self { inner, span }
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
self.inner.forward(xs)
|
||||
}
|
||||
}
|
||||
|
||||
struct LayerWeights {
|
||||
attention_wq: QMatMul,
|
||||
attention_wk: QMatMul,
|
||||
attention_wv: QMatMul,
|
||||
attention_wo: QMatMul,
|
||||
attention_norm: RmsNorm,
|
||||
feed_forward_w1: QMatMul,
|
||||
feed_forward_w2: QMatMul,
|
||||
feed_forward_w3: QMatMul,
|
||||
ffn_norm: RmsNorm,
|
||||
n_head: usize,
|
||||
n_kv_head: usize,
|
||||
head_dim: usize,
|
||||
cos: Tensor,
|
||||
sin: Tensor,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
span_attn: tracing::Span,
|
||||
span_rot: tracing::Span,
|
||||
span_mlp: tracing::Span,
|
||||
}
|
||||
|
||||
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
|
||||
let shape = mask.shape();
|
||||
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
|
||||
let m = mask.where_cond(&on_true, on_false)?;
|
||||
Ok(m)
|
||||
}
|
||||
|
||||
impl LayerWeights {
|
||||
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||
let _enter = self.span_rot.enter();
|
||||
let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
|
||||
let cos = self
|
||||
.cos
|
||||
.narrow(0, index_pos, seq_len)?
|
||||
.reshape((seq_len, n_embd / 2, 1))?;
|
||||
let sin = self
|
||||
.sin
|
||||
.narrow(0, index_pos, seq_len)?
|
||||
.reshape((seq_len, n_embd / 2, 1))?;
|
||||
let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
|
||||
let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
|
||||
// This mimics the llama.cpp behavior.
|
||||
// https://github.com/ggerganov/llama.cpp/blob/1f0bccb27929e261744c979bc75114955da49e98/ggml.c#L12104-L12105
|
||||
// The x0 and x1 value are interleaved on the n_embd (= head_dim) dimension.
|
||||
// The resulting y0 and y1 are also interleaved with:
|
||||
// y0 = x0*cos - x1*sin
|
||||
// y1 = x0*sin + x1*cos
|
||||
let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
|
||||
let x0 = x.narrow(D::Minus1, 0, 1)?;
|
||||
let x1 = x.narrow(D::Minus1, 1, 1)?;
|
||||
let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
|
||||
let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
|
||||
let rope = Tensor::cat(&[y0, y1], D::Minus1)?;
|
||||
let rope = rope.flatten_from(D::Minus2)?;
|
||||
Ok(rope)
|
||||
}
|
||||
|
||||
fn forward_attn(&mut self, x: &Tensor, mask: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||
let _enter = self.span_attn.enter();
|
||||
let (b_sz, seq_len, n_embd) = x.dims3()?;
|
||||
let q = self.attention_wq.forward(x)?;
|
||||
let k = self.attention_wk.forward(x)?;
|
||||
let v = self.attention_wv.forward(x)?;
|
||||
|
||||
let q = q
|
||||
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let k = k
|
||||
.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let v = v
|
||||
.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
|
||||
let q = self.apply_rotary_emb(&q, index_pos)?;
|
||||
let k = self.apply_rotary_emb(&k, index_pos)?;
|
||||
|
||||
let (k, v) = match &self.kv_cache {
|
||||
None => (k, v),
|
||||
Some((k_cache, v_cache)) => {
|
||||
if index_pos == 0 {
|
||||
(k, v)
|
||||
} else {
|
||||
let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?;
|
||||
let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?;
|
||||
(k, v)
|
||||
}
|
||||
}
|
||||
};
|
||||
self.kv_cache = Some((k.clone(), v.clone()));
|
||||
|
||||
// Support for MQA, useful for 70B models.
|
||||
let k = self.repeat_kv(k)?;
|
||||
let v = self.repeat_kv(v)?;
|
||||
|
||||
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
|
||||
let mask = mask.broadcast_as(att.shape())?;
|
||||
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
|
||||
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
|
||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||
let y = att.matmul(&v.contiguous()?)?;
|
||||
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
|
||||
let y = self.attention_wo.forward(&y)?;
|
||||
Ok(y)
|
||||
}
|
||||
|
||||
fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
|
||||
let n_rep = self.n_head / self.n_kv_head;
|
||||
if n_rep == 1 {
|
||||
Ok(x)
|
||||
} else {
|
||||
let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
|
||||
let x = x
|
||||
.unsqueeze(2)?
|
||||
.expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
|
||||
.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
|
||||
Ok(x)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ModelWeights {
|
||||
tok_embeddings: Embedding,
|
||||
layers: Vec<LayerWeights>,
|
||||
norm: RmsNorm,
|
||||
output: QMatMul,
|
||||
masks: HashMap<usize, Tensor>,
|
||||
span: tracing::Span,
|
||||
span_output: tracing::Span,
|
||||
}
|
||||
|
||||
fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tensor)> {
|
||||
let theta: Vec<_> = (0..head_dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
|
||||
.collect();
|
||||
let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?;
|
||||
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)?
|
||||
.to_dtype(DType::F32)?
|
||||
.reshape((MAX_SEQ_LEN, 1))?
|
||||
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||
let cos = idx_theta.cos()?;
|
||||
let sin = idx_theta.sin()?;
|
||||
Ok((cos, sin))
|
||||
}
|
||||
|
||||
impl ModelWeights {
|
||||
pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
|
||||
let cpu = &Device::Cpu;
|
||||
let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
|
||||
let (cos, sin) = precomput_freqs_cis(head_dim, 10000.)?;
|
||||
let tok_embeddings = ct.remove("tok_embeddings.weight")?;
|
||||
let tok_embeddings = tok_embeddings.dequantize(cpu)?;
|
||||
let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
|
||||
let output = ct.remove("output.weight")?;
|
||||
let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
|
||||
for layer_idx in 0..ct.hparams.n_layer {
|
||||
let prefix = format!("layers.{layer_idx}");
|
||||
let attention_wq = ct.remove(&format!("{prefix}.attention.wq.weight"))?;
|
||||
let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?;
|
||||
let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?;
|
||||
let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?;
|
||||
let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
|
||||
let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
|
||||
let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
|
||||
let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?;
|
||||
let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?;
|
||||
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
|
||||
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
|
||||
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
|
||||
layers.push(LayerWeights {
|
||||
attention_wq: QMatMul::from_qtensor(attention_wq),
|
||||
attention_wk: QMatMul::from_qtensor(attention_wk),
|
||||
attention_wv: QMatMul::from_qtensor(attention_wv),
|
||||
attention_wo: QMatMul::from_qtensor(attention_wo),
|
||||
attention_norm: RmsNorm::new(attention_norm, 1e-5)?,
|
||||
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1),
|
||||
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2),
|
||||
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3),
|
||||
ffn_norm: RmsNorm::new(ffn_norm, 1e-5)?,
|
||||
n_head: ct.hparams.n_head as usize,
|
||||
n_kv_head: ct.hparams.n_head as usize / gqa,
|
||||
head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
|
||||
cos: cos.clone(),
|
||||
sin: sin.clone(),
|
||||
kv_cache: None,
|
||||
span_attn,
|
||||
span_rot,
|
||||
span_mlp,
|
||||
})
|
||||
}
|
||||
let span = tracing::span!(tracing::Level::TRACE, "model");
|
||||
let span_output = tracing::span!(tracing::Level::TRACE, "output");
|
||||
Ok(Self {
|
||||
tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize),
|
||||
layers,
|
||||
norm,
|
||||
output: QMatMul::from_qtensor(output),
|
||||
masks: HashMap::new(),
|
||||
span,
|
||||
span_output,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
|
||||
ct: gguf_file::Content,
|
||||
reader: &mut R,
|
||||
) -> Result<Self> {
|
||||
let cpu = &Device::Cpu;
|
||||
let md_get = |s: &str| match ct.metadata.get(s) {
|
||||
None => candle::bail!("cannot find {s} in metadata"),
|
||||
Some(v) => Ok(v),
|
||||
};
|
||||
|
||||
// Parameter extraction from metadata.
|
||||
let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
|
||||
let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
|
||||
let block_count = md_get("llama.block_count")?.to_u32()? as usize;
|
||||
let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize;
|
||||
let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize;
|
||||
// Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default.
|
||||
let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()?;
|
||||
|
||||
let rope_freq_base = md_get("llama.rope.freq_base")
|
||||
.and_then(|m| m.to_f32())
|
||||
.unwrap_or(10000f32);
|
||||
let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?;
|
||||
|
||||
let tok_embeddings = ct.tensor(reader, "token_embd.weight")?;
|
||||
let tok_embeddings = tok_embeddings.dequantize(cpu)?;
|
||||
let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?, rms_norm_eps)?;
|
||||
let output = ct.tensor(reader, "output.weight")?;
|
||||
let mut layers = Vec::with_capacity(block_count);
|
||||
for layer_idx in 0..block_count {
|
||||
let prefix = format!("blk.{layer_idx}");
|
||||
let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"))?;
|
||||
let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?;
|
||||
let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?;
|
||||
let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?;
|
||||
let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?;
|
||||
let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?;
|
||||
let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?;
|
||||
let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?;
|
||||
let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?;
|
||||
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
|
||||
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
|
||||
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
|
||||
layers.push(LayerWeights {
|
||||
attention_wq: QMatMul::from_qtensor(attention_wq),
|
||||
attention_wk: QMatMul::from_qtensor(attention_wk),
|
||||
attention_wv: QMatMul::from_qtensor(attention_wv),
|
||||
attention_wo: QMatMul::from_qtensor(attention_wo),
|
||||
attention_norm: RmsNorm::new(attention_norm, rms_norm_eps)?,
|
||||
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1),
|
||||
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2),
|
||||
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3),
|
||||
ffn_norm: RmsNorm::new(ffn_norm, rms_norm_eps)?,
|
||||
n_head: head_count,
|
||||
n_kv_head: head_count_kv,
|
||||
head_dim: embedding_length / head_count,
|
||||
cos: cos.clone(),
|
||||
sin: sin.clone(),
|
||||
kv_cache: None,
|
||||
span_attn,
|
||||
span_rot,
|
||||
span_mlp,
|
||||
})
|
||||
}
|
||||
let span = tracing::span!(tracing::Level::TRACE, "model");
|
||||
let span_output = tracing::span!(tracing::Level::TRACE, "output");
|
||||
Ok(Self {
|
||||
tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
|
||||
layers,
|
||||
norm,
|
||||
output: QMatMul::from_qtensor(output),
|
||||
masks: HashMap::new(),
|
||||
span,
|
||||
span_output,
|
||||
})
|
||||
}
|
||||
|
||||
fn mask(&mut self, t: usize) -> Result<Tensor> {
|
||||
if let Some(mask) = self.masks.get(&t) {
|
||||
Ok(mask.clone())
|
||||
} else {
|
||||
let mask: Vec<_> = (0..t)
|
||||
.flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
|
||||
.collect();
|
||||
let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
|
||||
self.masks.insert(t, mask.clone());
|
||||
Ok(mask)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||
let (_b_sz, seq_len) = x.dims2()?;
|
||||
let mask = self.mask(seq_len)?;
|
||||
let _enter = self.span.enter();
|
||||
let mut layer_in = self.tok_embeddings.forward(x)?;
|
||||
for layer in self.layers.iter_mut() {
|
||||
let x = layer_in;
|
||||
let residual = &x;
|
||||
let x = layer.attention_norm.forward(&x)?;
|
||||
let attn = layer.forward_attn(&x, &mask, index_pos)?;
|
||||
let x = (attn + residual)?;
|
||||
|
||||
// MLP
|
||||
let _enter = layer.span_mlp.enter();
|
||||
let residual = &x;
|
||||
let x = layer.ffn_norm.forward(&x)?;
|
||||
let w1 = layer.feed_forward_w1.forward(&x)?;
|
||||
let w3 = layer.feed_forward_w3.forward(&x)?;
|
||||
let mlp = layer
|
||||
.feed_forward_w2
|
||||
.forward(&(candle_nn::ops::silu(&w1)? * w3)?)?;
|
||||
layer_in = (mlp + residual)?;
|
||||
}
|
||||
let x = self.norm.forward(&layer_in)?;
|
||||
let x = x.i((.., seq_len - 1, ..))?;
|
||||
let _enter = self.span_output.enter();
|
||||
self.output.forward(&x)
|
||||
}
|
||||
}
|
@ -1,273 +0,0 @@
|
||||
//! SAM: Segment Anything Model
|
||||
//! https://github.com/facebookresearch/segment-anything
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
pub mod model_image_encoder;
|
||||
pub mod model_mask_decoder;
|
||||
pub mod model_prompt_encoder;
|
||||
pub mod model_sam;
|
||||
pub mod model_transformer;
|
||||
|
||||
use candle::{DType, Result, Tensor};
|
||||
use candle_nn::{Module, VarBuilder};
|
||||
use clap::Parser;
|
||||
|
||||
pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
|
||||
let inner = if bias {
|
||||
candle_nn::linear(in_dim, out_dim, vb)?
|
||||
} else {
|
||||
candle_nn::linear_no_bias(in_dim, out_dim, vb)?
|
||||
};
|
||||
let span = tracing::span!(tracing::Level::TRACE, "linear");
|
||||
Ok(Linear { inner, span })
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct LayerNorm2d {
|
||||
weight: Tensor,
|
||||
bias: Tensor,
|
||||
num_channels: usize,
|
||||
eps: f64,
|
||||
}
|
||||
|
||||
impl LayerNorm2d {
|
||||
pub fn new(num_channels: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
|
||||
let weight = vb.get(num_channels, "weight")?;
|
||||
let bias = vb.get(num_channels, "bias")?;
|
||||
Ok(Self {
|
||||
weight,
|
||||
bias,
|
||||
num_channels,
|
||||
eps,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for LayerNorm2d {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let u = xs.mean_keepdim(1)?;
|
||||
let xs = xs.broadcast_sub(&u)?;
|
||||
let s = xs.sqr()?.mean_keepdim(1)?;
|
||||
let xs = xs.broadcast_div(&(s + self.eps)?.sqrt()?)?;
|
||||
xs.broadcast_mul(&self.weight.reshape((1, self.num_channels, 1, 1))?)?
|
||||
.broadcast_add(&self.bias.reshape((1, self.num_channels, 1, 1))?)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MlpBlock {
|
||||
lin1: Linear,
|
||||
lin2: Linear,
|
||||
activation: candle_nn::Activation,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl MlpBlock {
|
||||
pub fn new(
|
||||
embedding_dim: usize,
|
||||
mlp_dim: usize,
|
||||
activation: candle_nn::Activation,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let lin1 = linear(vb.pp("lin1"), embedding_dim, mlp_dim, true)?;
|
||||
let lin2 = linear(vb.pp("lin2"), mlp_dim, embedding_dim, true)?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "mlp-block");
|
||||
Ok(Self {
|
||||
lin1,
|
||||
lin2,
|
||||
activation,
|
||||
span,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for MlpBlock {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
xs.apply(&self.lin1)?
|
||||
.apply(&self.activation)?
|
||||
.apply(&self.lin2)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Linear {
|
||||
inner: candle_nn::Linear,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Module for Linear {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
self.inner.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
image: String,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
#[arg(long)]
|
||||
generate_masks: bool,
|
||||
|
||||
#[arg(long, default_value_t = 0.5)]
|
||||
point_x: f64,
|
||||
|
||||
#[arg(long, default_value_t = 0.5)]
|
||||
point_y: f64,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let (image, initial_h, initial_w) = if args.image.ends_with(".safetensors") {
|
||||
let mut tensors = candle::safetensors::load(&args.image, &device)?;
|
||||
let image = match tensors.remove("image") {
|
||||
Some(image) => image,
|
||||
None => {
|
||||
if tensors.len() != 1 {
|
||||
anyhow::bail!("multiple tensors in '{}'", args.image)
|
||||
}
|
||||
tensors.into_values().next().unwrap()
|
||||
}
|
||||
};
|
||||
let image = if image.rank() == 4 {
|
||||
image.get(0)?
|
||||
} else {
|
||||
image
|
||||
};
|
||||
let (_c, h, w) = image.dims3()?;
|
||||
(image, h, w)
|
||||
} else {
|
||||
let (image, h, w) = candle_examples::load_image(&args.image, Some(model_sam::IMAGE_SIZE))?;
|
||||
(image.to_device(&device)?, h, w)
|
||||
};
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let model = match args.model {
|
||||
Some(model) => std::path::PathBuf::from(model),
|
||||
None => {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model("lmz/candle-sam".to_string());
|
||||
api.get("sam_vit_b_01ec64.safetensors")?
|
||||
}
|
||||
};
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
|
||||
let sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b
|
||||
|
||||
if args.generate_masks {
|
||||
// Default options similar to the Python version.
|
||||
let bboxes = sam.generate_masks(
|
||||
&image,
|
||||
/* points_per_side */ 32,
|
||||
/* crop_n_layer */ 0,
|
||||
/* crop_overlap_ratio */ 512. / 1500.,
|
||||
/* crop_n_points_downscale_factor */ 1,
|
||||
)?;
|
||||
for (idx, bbox) in bboxes.iter().enumerate() {
|
||||
println!("{bbox:?}");
|
||||
let mask = (&bbox.data.to_dtype(DType::U8)? * 255.)?;
|
||||
let (h, w) = mask.dims2()?;
|
||||
let mask = mask.broadcast_as((3, h, w))?;
|
||||
candle_examples::save_image_resize(
|
||||
&mask,
|
||||
format!("sam_mask{idx}.png"),
|
||||
initial_h,
|
||||
initial_w,
|
||||
)?;
|
||||
}
|
||||
} else {
|
||||
let point = Some((args.point_x, args.point_y));
|
||||
let start_time = std::time::Instant::now();
|
||||
let (mask, iou_predictions) = sam.forward(&image, point, false)?;
|
||||
println!(
|
||||
"mask generated in {:.2}s",
|
||||
start_time.elapsed().as_secs_f32()
|
||||
);
|
||||
println!("mask:\n{mask}");
|
||||
println!("iou_predictions: {iou_predictions:?}");
|
||||
|
||||
// Save the mask as an image.
|
||||
let mask = (mask.ge(0f32)? * 255.)?;
|
||||
let (_one, h, w) = mask.dims3()?;
|
||||
let mask = mask.expand((3, h, w))?;
|
||||
candle_examples::save_image_resize(&mask, "sam_mask.png", initial_h, initial_w)?;
|
||||
|
||||
if !args.image.ends_with(".safetensors") {
|
||||
let mut img = image::io::Reader::open(&args.image)?
|
||||
.decode()
|
||||
.map_err(candle::Error::wrap)?;
|
||||
let mask_pixels = mask.permute((1, 2, 0))?.flatten_all()?.to_vec1::<u8>()?;
|
||||
let mask_img: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
|
||||
match image::ImageBuffer::from_raw(w as u32, h as u32, mask_pixels) {
|
||||
Some(image) => image,
|
||||
None => anyhow::bail!("error saving merged image"),
|
||||
};
|
||||
let mask_img = image::DynamicImage::from(mask_img).resize_to_fill(
|
||||
img.width(),
|
||||
img.height(),
|
||||
image::imageops::FilterType::CatmullRom,
|
||||
);
|
||||
for x in 0..img.width() {
|
||||
for y in 0..img.height() {
|
||||
let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_img, x, y);
|
||||
if mask_p.0[0] > 100 {
|
||||
let mut img_p = imageproc::drawing::Canvas::get_pixel(&img, x, y);
|
||||
img_p.0[2] = 255 - (255 - img_p.0[2]) / 2;
|
||||
img_p.0[1] /= 2;
|
||||
img_p.0[0] /= 2;
|
||||
imageproc::drawing::Canvas::draw_pixel(&mut img, x, y, img_p)
|
||||
}
|
||||
}
|
||||
}
|
||||
match point {
|
||||
Some((x, y)) => {
|
||||
let (x, y) = (
|
||||
(x * img.width() as f64) as i32,
|
||||
(y * img.height() as f64) as i32,
|
||||
);
|
||||
imageproc::drawing::draw_filled_circle(
|
||||
&img,
|
||||
(x, y),
|
||||
3,
|
||||
image::Rgba([255, 0, 0, 200]),
|
||||
)
|
||||
.save("sam_merged.jpg")?
|
||||
}
|
||||
None => img.save("sam_merged.jpg")?,
|
||||
};
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -1,516 +0,0 @@
|
||||
use candle::{DType, IndexOp, Result, Tensor};
|
||||
use candle_nn::{layer_norm, LayerNorm, Module, VarBuilder};
|
||||
|
||||
#[derive(Debug)]
|
||||
struct PatchEmbed {
|
||||
proj: candle_nn::Conv2d,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl PatchEmbed {
|
||||
fn new(
|
||||
in_chans: usize,
|
||||
embed_dim: usize,
|
||||
k_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let cfg = candle_nn::Conv2dConfig {
|
||||
stride,
|
||||
padding,
|
||||
..Default::default()
|
||||
};
|
||||
let proj = candle_nn::conv2d(in_chans, embed_dim, k_size, cfg, vb.pp("proj"))?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "patch-embed");
|
||||
Ok(Self { proj, span })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for PatchEmbed {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
xs.apply(&self.proj)?.permute((0, 2, 3, 1))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Attention {
|
||||
qkv: crate::Linear,
|
||||
proj: crate::Linear,
|
||||
num_heads: usize,
|
||||
scale: f64,
|
||||
rel_pos_hw: Option<(Tensor, Tensor)>,
|
||||
span: tracing::Span,
|
||||
span_rel_pos: tracing::Span,
|
||||
span_softmax: tracing::Span,
|
||||
}
|
||||
|
||||
// rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
|
||||
struct EinSum1;
|
||||
impl candle::CustomOp2 for EinSum1 {
|
||||
fn name(&self) -> &'static str {
|
||||
"einsum1"
|
||||
}
|
||||
|
||||
fn cpu_fwd(
|
||||
&self,
|
||||
s1: &candle::CpuStorage,
|
||||
l1: &candle::Layout,
|
||||
s2: &candle::CpuStorage,
|
||||
l2: &candle::Layout,
|
||||
) -> Result<(candle::CpuStorage, candle::Shape)> {
|
||||
use candle::cpu::kernels::VecOps;
|
||||
|
||||
let (b, h, w, c) = l1.shape().dims4()?;
|
||||
let (h2, k, c2) = l2.shape().dims3()?;
|
||||
if c != c2 || h != h2 {
|
||||
candle::bail!("shape mismatch {l1:?} {l2:?}")
|
||||
}
|
||||
let s1 = s1.as_slice::<f32>()?;
|
||||
let s1 = match l1.contiguous_offsets() {
|
||||
None => candle::bail!("input1 has to be contiguous"),
|
||||
Some((o1, o2)) => &s1[o1..o2],
|
||||
};
|
||||
let s2 = s2.as_slice::<f32>()?;
|
||||
let s2 = match l2.contiguous_offsets() {
|
||||
None => candle::bail!("input2 has to be contiguous"),
|
||||
Some((o1, o2)) => &s2[o1..o2],
|
||||
};
|
||||
let mut dst = vec![0f32; b * h * w * k];
|
||||
for b_idx in 0..b {
|
||||
let lhs_idx = b_idx * h * w * c;
|
||||
let dst_idx = b_idx * h * w * k;
|
||||
for h_idx in 0..h {
|
||||
let lhs_idx = lhs_idx + h_idx * w * c;
|
||||
let rhs_idx = h_idx * k * c;
|
||||
let dst_idx = dst_idx + h_idx * w * k;
|
||||
for w_idx in 0..w {
|
||||
let lhs_idx = lhs_idx + w_idx * c;
|
||||
let dst_idx = dst_idx + w_idx * k;
|
||||
let lhs = &s1[lhs_idx..lhs_idx + c];
|
||||
for k_idx in 0..k {
|
||||
let rhs_idx = rhs_idx + k_idx * c;
|
||||
let rhs = &s2[rhs_idx..rhs_idx + c];
|
||||
let mut d = 0f32;
|
||||
unsafe { f32::vec_dot(lhs.as_ptr(), rhs.as_ptr(), &mut d, c) };
|
||||
dst[dst_idx + k_idx] += d;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
let storage = candle::WithDType::to_cpu_storage_owned(dst);
|
||||
Ok((storage, (b, h, w, k).into()))
|
||||
}
|
||||
}
|
||||
|
||||
// rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
|
||||
struct EinSum2;
|
||||
impl candle::CustomOp2 for EinSum2 {
|
||||
fn name(&self) -> &'static str {
|
||||
"einsum2"
|
||||
}
|
||||
|
||||
fn cpu_fwd(
|
||||
&self,
|
||||
s1: &candle::CpuStorage,
|
||||
l1: &candle::Layout,
|
||||
s2: &candle::CpuStorage,
|
||||
l2: &candle::Layout,
|
||||
) -> Result<(candle::CpuStorage, candle::Shape)> {
|
||||
use candle::cpu::kernels::VecOps;
|
||||
|
||||
let (b, h, w, c) = l1.shape().dims4()?;
|
||||
let (w2, k, c2) = l2.shape().dims3()?;
|
||||
if c != c2 || w != w2 {
|
||||
candle::bail!("shape mismatch {l1:?} {l2:?}")
|
||||
}
|
||||
let s1 = s1.as_slice::<f32>()?;
|
||||
let s1 = match l1.contiguous_offsets() {
|
||||
None => candle::bail!("input1 has to be contiguous"),
|
||||
Some((o1, o2)) => &s1[o1..o2],
|
||||
};
|
||||
let s2 = s2.as_slice::<f32>()?;
|
||||
let s2 = match l2.contiguous_offsets() {
|
||||
None => candle::bail!("input2 has to be contiguous"),
|
||||
Some((o1, o2)) => &s2[o1..o2],
|
||||
};
|
||||
let mut dst = vec![0f32; b * h * w * k];
|
||||
for b_idx in 0..b {
|
||||
let lhs_idx = b_idx * h * w * c;
|
||||
let dst_idx = b_idx * h * w * k;
|
||||
for h_idx in 0..h {
|
||||
let lhs_idx = lhs_idx + h_idx * w * c;
|
||||
let dst_idx = dst_idx + h_idx * w * k;
|
||||
for w_idx in 0..w {
|
||||
let lhs_idx = lhs_idx + w_idx * c;
|
||||
let rhs_idx = w_idx * k * c;
|
||||
let dst_idx = dst_idx + w_idx * k;
|
||||
let lhs = &s1[lhs_idx..lhs_idx + c];
|
||||
for k_idx in 0..k {
|
||||
let rhs_idx = rhs_idx + k_idx * c;
|
||||
let rhs = &s2[rhs_idx..rhs_idx + c];
|
||||
let mut d = 0f32;
|
||||
unsafe { f32::vec_dot(lhs.as_ptr(), rhs.as_ptr(), &mut d, c) };
|
||||
dst[dst_idx + k_idx] += d;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
let storage = candle::WithDType::to_cpu_storage_owned(dst);
|
||||
Ok((storage, (b, h, w, k).into()))
|
||||
}
|
||||
}
|
||||
impl Attention {
|
||||
fn new(
|
||||
dim: usize,
|
||||
num_heads: usize,
|
||||
qkv_bias: bool,
|
||||
use_rel_pos: bool,
|
||||
input_size: (usize, usize),
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "attention");
|
||||
let span_rel_pos = tracing::span!(tracing::Level::TRACE, "attn-rel-pos");
|
||||
let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");
|
||||
let qkv = crate::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
|
||||
let proj = crate::linear(vb.pp("proj"), dim, dim, true)?;
|
||||
let head_dim = dim / num_heads;
|
||||
let scale = 1. / (head_dim as f64).sqrt();
|
||||
let rel_pos_hw = if use_rel_pos {
|
||||
let h = vb.get((2 * input_size.0 - 1, head_dim), "rel_pos_h")?;
|
||||
let w = vb.get((2 * input_size.1 - 1, head_dim), "rel_pos_w")?;
|
||||
Some((h, w))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(Self {
|
||||
qkv,
|
||||
proj,
|
||||
num_heads,
|
||||
scale,
|
||||
rel_pos_hw,
|
||||
span,
|
||||
span_rel_pos,
|
||||
span_softmax,
|
||||
})
|
||||
}
|
||||
|
||||
fn add_decomposed_rel_pos(
|
||||
&self,
|
||||
attn: Tensor,
|
||||
q: &Tensor,
|
||||
(q_h, q_w): (usize, usize),
|
||||
(k_h, k_w): (usize, usize),
|
||||
) -> Result<Tensor> {
|
||||
match &self.rel_pos_hw {
|
||||
Some((rel_pos_h, rel_pos_w)) => {
|
||||
println!("{:?} {:?}", attn.layout(), q.layout());
|
||||
let r_h = get_rel_pos(q_h, k_h, rel_pos_h)?;
|
||||
let r_w = get_rel_pos(q_w, k_w, rel_pos_w)?;
|
||||
let (b, _, dim) = q.dims3()?;
|
||||
let r_q = q.reshape((b, q_h, q_w, dim))?;
|
||||
// rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
|
||||
let rel_h = r_q.apply_op2_no_bwd(&r_h, &EinSum1)?;
|
||||
// rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
|
||||
let rel_w = r_q.apply_op2_no_bwd(&r_w, &EinSum2)?;
|
||||
(attn.reshape((b, q_h, q_w, k_h, k_w))?
|
||||
+ rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
|
||||
.reshape((b, q_h * q_w, k_h * k_w))
|
||||
}
|
||||
None => Ok(attn),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_rel_pos(q_size: usize, k_size: usize, rel_pos: &Tensor) -> Result<Tensor> {
|
||||
let max_rel_dist = 2 * usize::max(q_size, k_size) - 1;
|
||||
let dev = rel_pos.device();
|
||||
let rel_pos_resized = if rel_pos.dim(0)? != max_rel_dist {
|
||||
todo!("interpolation")
|
||||
} else {
|
||||
rel_pos
|
||||
};
|
||||
let q_coords = Tensor::arange(0u32, q_size as u32, dev)?
|
||||
.reshape((q_size, 1))?
|
||||
.to_dtype(DType::F32)?;
|
||||
let k_coords = Tensor::arange(0u32, k_size as u32, dev)?
|
||||
.reshape((1, k_size))?
|
||||
.to_dtype(DType::F32)?;
|
||||
let q_coords = (q_coords * f64::max(1f64, k_size as f64 / q_size as f64))?;
|
||||
let k_coords = (k_coords * f64::max(1f64, q_size as f64 / k_size as f64))?;
|
||||
let relative_coords = (q_coords.broadcast_sub(&k_coords)?
|
||||
+ (k_size as f64 - 1.) * f64::max(1f64, q_size as f64 / k_size as f64))?;
|
||||
let (d1, d2) = relative_coords.dims2()?;
|
||||
let relative_coords = relative_coords.to_dtype(DType::U32)?;
|
||||
rel_pos_resized
|
||||
.index_select(&relative_coords.reshape(d1 * d2)?, 0)?
|
||||
.reshape((d1, d2, ()))
|
||||
}
|
||||
|
||||
impl Module for Attention {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let (b, h, w, c) = xs.dims4()?;
|
||||
let qkv = self
|
||||
.qkv
|
||||
.forward(&xs.flatten_to(1)?)?
|
||||
.reshape((b, h * w, 3, self.num_heads, c / self.num_heads))?
|
||||
.permute((2, 0, 3, 1, 4))?
|
||||
.reshape((3, b * self.num_heads, h * w, c / self.num_heads))?;
|
||||
let q = qkv.i(0)?;
|
||||
let k = qkv.i(1)?;
|
||||
let v = qkv.i(2)?;
|
||||
let attn = (&q * self.scale)?.matmul(&k.t()?)?;
|
||||
let attn = {
|
||||
let _enter = self.span_rel_pos.enter();
|
||||
self.add_decomposed_rel_pos(attn, &q, (h, w), (h, w))?
|
||||
};
|
||||
let attn = {
|
||||
let _enter = self.span_softmax.enter();
|
||||
candle_nn::ops::softmax_last_dim(&attn)?
|
||||
};
|
||||
let attn = attn.matmul(&v)?;
|
||||
let attn = attn
|
||||
.reshape((b, self.num_heads, h, w, c / self.num_heads))?
|
||||
.permute((0, 2, 3, 1, 4))?
|
||||
.reshape((b, h * w, c))?;
|
||||
self.proj.forward(&attn)?.reshape((b, h, w, c))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Block {
|
||||
norm1: LayerNorm,
|
||||
attn: Attention,
|
||||
norm2: LayerNorm,
|
||||
mlp: crate::MlpBlock,
|
||||
window_size: usize,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Block {
|
||||
fn new(
|
||||
dim: usize,
|
||||
num_heads: usize,
|
||||
qkv_bias: bool,
|
||||
use_rel_pos: bool,
|
||||
window_size: usize,
|
||||
input_size: (usize, usize),
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let norm1 = layer_norm(dim, 1e-6, vb.pp("norm1"))?;
|
||||
let norm2 = layer_norm(dim, 1e-6, vb.pp("norm2"))?;
|
||||
let input_size_attn = if window_size == 0 {
|
||||
input_size
|
||||
} else {
|
||||
(window_size, window_size)
|
||||
};
|
||||
let attn = Attention::new(
|
||||
dim,
|
||||
num_heads,
|
||||
qkv_bias,
|
||||
use_rel_pos,
|
||||
input_size_attn,
|
||||
vb.pp("attn"),
|
||||
)?;
|
||||
let mlp = crate::MlpBlock::new(dim, dim * 4, candle_nn::Activation::Gelu, vb.pp("mlp"))?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "ie-block");
|
||||
Ok(Self {
|
||||
norm1,
|
||||
attn,
|
||||
norm2,
|
||||
mlp,
|
||||
window_size,
|
||||
span,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn window_partition(xs: Tensor, window_size: usize) -> Result<(Tensor, (usize, usize))> {
|
||||
let (b, h, w, c) = xs.dims4()?;
|
||||
let pad_h = (window_size - h % window_size) % window_size;
|
||||
let pad_w = (window_size - w % window_size) % window_size;
|
||||
let xs = if pad_h > 0 {
|
||||
xs.pad_with_zeros(1, 0, pad_h)?
|
||||
} else {
|
||||
xs
|
||||
};
|
||||
let xs = if pad_w > 0 {
|
||||
xs.pad_with_zeros(2, 0, pad_w)?
|
||||
} else {
|
||||
xs
|
||||
};
|
||||
let (h_p, w_p) = (h + pad_h, w + pad_w);
|
||||
let windows = xs
|
||||
.reshape((
|
||||
b,
|
||||
h_p / window_size,
|
||||
window_size,
|
||||
w_p / window_size,
|
||||
window_size,
|
||||
c,
|
||||
))?
|
||||
.transpose(2, 3)?
|
||||
.contiguous()?
|
||||
.flatten_to(2)?;
|
||||
Ok((windows, (h_p, w_p)))
|
||||
}
|
||||
|
||||
fn window_unpartition(
|
||||
windows: Tensor,
|
||||
window_size: usize,
|
||||
(h_p, w_p): (usize, usize),
|
||||
(h, w): (usize, usize),
|
||||
) -> Result<Tensor> {
|
||||
let b = windows.dim(0)? / (h_p * w_p / window_size / window_size);
|
||||
let xs = windows
|
||||
.reshape((
|
||||
b,
|
||||
h_p / window_size,
|
||||
w_p / window_size,
|
||||
window_size,
|
||||
window_size,
|
||||
windows.elem_count() / b / h_p / w_p,
|
||||
))?
|
||||
.transpose(2, 3)?
|
||||
.contiguous()?
|
||||
.reshape((b, h_p, w_p, ()))?;
|
||||
let xs = if h_p > h { xs.narrow(1, 0, h)? } else { xs };
|
||||
let xs = if w_p > w { xs.narrow(2, 0, w)? } else { xs };
|
||||
Ok(xs)
|
||||
}
|
||||
|
||||
impl Module for Block {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let shortcut = xs;
|
||||
let xs = self.norm1.forward(xs)?;
|
||||
let hw = (xs.dim(1)?, xs.dim(2)?);
|
||||
let (xs, pad_hw) = if self.window_size > 0 {
|
||||
window_partition(xs, self.window_size)?
|
||||
} else {
|
||||
(xs, (0, 0))
|
||||
};
|
||||
let xs = self.attn.forward(&xs)?;
|
||||
let xs = if self.window_size > 0 {
|
||||
window_unpartition(xs, self.window_size, pad_hw, hw)?
|
||||
} else {
|
||||
xs
|
||||
};
|
||||
let xs = (xs + shortcut)?;
|
||||
&xs + xs.apply(&self.norm2)?.apply(&self.mlp)?
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ImageEncoderViT {
|
||||
patch_embed: PatchEmbed,
|
||||
blocks: Vec<Block>,
|
||||
neck_conv1: candle_nn::Conv2d,
|
||||
neck_ln1: crate::LayerNorm2d,
|
||||
neck_conv2: candle_nn::Conv2d,
|
||||
neck_ln2: crate::LayerNorm2d,
|
||||
pos_embed: Option<Tensor>,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl ImageEncoderViT {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
img_size: usize,
|
||||
patch_size: usize,
|
||||
in_chans: usize,
|
||||
embed_dim: usize,
|
||||
depth: usize,
|
||||
num_heads: usize,
|
||||
out_chans: usize,
|
||||
qkv_bias: bool,
|
||||
use_rel_pos: bool,
|
||||
use_abs_pos: bool,
|
||||
window_size: usize,
|
||||
global_attn_indexes: &[usize],
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let patch_embed = PatchEmbed::new(
|
||||
in_chans,
|
||||
embed_dim,
|
||||
patch_size,
|
||||
patch_size,
|
||||
0,
|
||||
vb.pp("patch_embed"),
|
||||
)?;
|
||||
let mut blocks = Vec::with_capacity(depth);
|
||||
let vb_b = vb.pp("blocks");
|
||||
for i in 0..depth {
|
||||
let window_size = if global_attn_indexes.contains(&i) {
|
||||
0
|
||||
} else {
|
||||
window_size
|
||||
};
|
||||
let block = Block::new(
|
||||
embed_dim,
|
||||
num_heads,
|
||||
qkv_bias,
|
||||
use_rel_pos,
|
||||
window_size,
|
||||
(img_size / patch_size, img_size / patch_size),
|
||||
vb_b.pp(i),
|
||||
)?;
|
||||
blocks.push(block)
|
||||
}
|
||||
let neck_conv1 = candle_nn::conv2d_no_bias(
|
||||
embed_dim,
|
||||
out_chans,
|
||||
1,
|
||||
Default::default(),
|
||||
vb.pp("neck.0"),
|
||||
)?;
|
||||
let neck_ln1 = crate::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.1"))?;
|
||||
let cfg = candle_nn::Conv2dConfig {
|
||||
padding: 1,
|
||||
..Default::default()
|
||||
};
|
||||
let neck_conv2 = candle_nn::conv2d_no_bias(out_chans, out_chans, 3, cfg, vb.pp("neck.2"))?;
|
||||
let neck_ln2 = crate::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.3"))?;
|
||||
let pos_embed = if use_abs_pos {
|
||||
let p = vb.get(
|
||||
(1, img_size / patch_size, img_size / patch_size, embed_dim),
|
||||
"pos_embed",
|
||||
)?;
|
||||
Some(p)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let span = tracing::span!(tracing::Level::TRACE, "image-encoder-vit");
|
||||
Ok(Self {
|
||||
patch_embed,
|
||||
blocks,
|
||||
neck_conv1,
|
||||
neck_ln1,
|
||||
neck_conv2,
|
||||
neck_ln2,
|
||||
pos_embed,
|
||||
span,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for ImageEncoderViT {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let xs = self.patch_embed.forward(xs)?;
|
||||
let mut xs = match &self.pos_embed {
|
||||
Some(pos_embed) => (xs + pos_embed)?,
|
||||
None => xs,
|
||||
};
|
||||
for block in self.blocks.iter() {
|
||||
xs = block.forward(&xs)?
|
||||
}
|
||||
xs.permute((0, 3, 1, 2))?
|
||||
.apply(&self.neck_conv1)?
|
||||
.apply(&self.neck_ln1)?
|
||||
.apply(&self.neck_conv2)?
|
||||
.apply(&self.neck_ln2)
|
||||
}
|
||||
}
|
@ -1,239 +0,0 @@
|
||||
use candle::{IndexOp, Result, Tensor};
|
||||
use candle_nn::{Module, VarBuilder};
|
||||
|
||||
use crate::model_transformer::TwoWayTransformer;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct MlpMaskDecoder {
|
||||
layers: Vec<crate::Linear>,
|
||||
sigmoid_output: bool,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl MlpMaskDecoder {
|
||||
fn new(
|
||||
input_dim: usize,
|
||||
hidden_dim: usize,
|
||||
output_dim: usize,
|
||||
num_layers: usize,
|
||||
sigmoid_output: bool,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let mut layers = Vec::with_capacity(num_layers);
|
||||
let vb = vb.pp("layers");
|
||||
for i in 0..num_layers {
|
||||
let in_dim = if i == 0 { input_dim } else { hidden_dim };
|
||||
let out_dim = if i + 1 == num_layers {
|
||||
output_dim
|
||||
} else {
|
||||
hidden_dim
|
||||
};
|
||||
let layer = crate::linear(vb.pp(i), in_dim, out_dim, true)?;
|
||||
layers.push(layer)
|
||||
}
|
||||
let span = tracing::span!(tracing::Level::TRACE, "mlp-mask-decoder");
|
||||
Ok(Self {
|
||||
layers,
|
||||
sigmoid_output,
|
||||
span,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for MlpMaskDecoder {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let mut xs = xs.clone();
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
xs = layer.forward(&xs)?;
|
||||
if i + 1 < self.layers.len() {
|
||||
xs = xs.relu()?
|
||||
}
|
||||
}
|
||||
if self.sigmoid_output {
|
||||
candle_nn::ops::sigmoid(&xs)
|
||||
} else {
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MaskDecoder {
|
||||
iou_token: candle_nn::Embedding,
|
||||
mask_tokens: candle_nn::Embedding,
|
||||
iou_prediction_head: MlpMaskDecoder,
|
||||
output_upscaling_conv1: candle_nn::ConvTranspose2d,
|
||||
output_upscaling_ln: crate::LayerNorm2d,
|
||||
output_upscaling_conv2: candle_nn::ConvTranspose2d,
|
||||
num_mask_tokens: usize,
|
||||
output_hypernetworks_mlps: Vec<MlpMaskDecoder>,
|
||||
transformer: TwoWayTransformer,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl MaskDecoder {
|
||||
pub fn new(
|
||||
transformer_dim: usize,
|
||||
num_multimask_outputs: usize,
|
||||
iou_head_depth: usize,
|
||||
iou_head_hidden_dim: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let num_mask_tokens = num_multimask_outputs + 1;
|
||||
let iou_prediction_head = MlpMaskDecoder::new(
|
||||
transformer_dim,
|
||||
iou_head_hidden_dim,
|
||||
num_mask_tokens,
|
||||
iou_head_depth,
|
||||
false,
|
||||
vb.pp("iou_prediction_head"),
|
||||
)?;
|
||||
let iou_token = candle_nn::embedding(1, transformer_dim, vb.pp("iou_token"))?;
|
||||
let mask_tokens =
|
||||
candle_nn::embedding(num_mask_tokens, transformer_dim, vb.pp("mask_tokens"))?;
|
||||
let cfg = candle_nn::ConvTranspose2dConfig {
|
||||
stride: 2,
|
||||
..Default::default()
|
||||
};
|
||||
let output_upscaling_conv1 = candle_nn::conv_transpose2d(
|
||||
transformer_dim,
|
||||
transformer_dim / 4,
|
||||
2,
|
||||
cfg,
|
||||
vb.pp("output_upscaling.0"),
|
||||
)?;
|
||||
let output_upscaling_ln =
|
||||
crate::LayerNorm2d::new(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?;
|
||||
let output_upscaling_conv2 = candle_nn::conv_transpose2d(
|
||||
transformer_dim / 4,
|
||||
transformer_dim / 8,
|
||||
2,
|
||||
cfg,
|
||||
vb.pp("output_upscaling.3"),
|
||||
)?;
|
||||
let mut output_hypernetworks_mlps = Vec::with_capacity(num_mask_tokens);
|
||||
let vb_o = vb.pp("output_hypernetworks_mlps");
|
||||
for i in 0..num_mask_tokens {
|
||||
let mlp = MlpMaskDecoder::new(
|
||||
transformer_dim,
|
||||
transformer_dim,
|
||||
transformer_dim / 8,
|
||||
3,
|
||||
false,
|
||||
vb_o.pp(i),
|
||||
)?;
|
||||
output_hypernetworks_mlps.push(mlp)
|
||||
}
|
||||
let transformer = TwoWayTransformer::new(
|
||||
/* depth */ 2,
|
||||
/* embedding_dim */ transformer_dim,
|
||||
/* num_heads */ 8,
|
||||
/* mlp_dim */ 2048,
|
||||
vb.pp("transformer"),
|
||||
)?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "mask-decoder");
|
||||
Ok(Self {
|
||||
iou_token,
|
||||
mask_tokens,
|
||||
iou_prediction_head,
|
||||
output_upscaling_conv1,
|
||||
output_upscaling_ln,
|
||||
output_upscaling_conv2,
|
||||
num_mask_tokens,
|
||||
output_hypernetworks_mlps,
|
||||
transformer,
|
||||
span,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
image_embeddings: &Tensor,
|
||||
image_pe: &Tensor,
|
||||
sparse_prompt_embeddings: &Tensor,
|
||||
dense_prompt_embeddings: &Tensor,
|
||||
multimask_output: bool,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let _enter = self.span.enter();
|
||||
let (masks, iou_pred) = self.predict_masks(
|
||||
image_embeddings,
|
||||
image_pe,
|
||||
sparse_prompt_embeddings,
|
||||
dense_prompt_embeddings,
|
||||
)?;
|
||||
let masks = if multimask_output {
|
||||
masks.i((.., 1..))?
|
||||
} else {
|
||||
masks.i((.., 0..1))?
|
||||
};
|
||||
let iou_pred = if multimask_output {
|
||||
iou_pred.i((.., 1..))?
|
||||
} else {
|
||||
iou_pred.i((.., 0..1))?
|
||||
};
|
||||
Ok((masks, iou_pred))
|
||||
}
|
||||
|
||||
fn predict_masks(
|
||||
&self,
|
||||
image_embeddings: &Tensor,
|
||||
image_pe: &Tensor,
|
||||
sparse_prompt_embeddings: &Tensor,
|
||||
dense_prompt_embeddings: &Tensor,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
// Concatenate ouput tokens.
|
||||
let output_tokens = Tensor::cat(
|
||||
&[self.iou_token.embeddings(), self.mask_tokens.embeddings()],
|
||||
0,
|
||||
)?;
|
||||
let (d1, d2) = output_tokens.dims2()?;
|
||||
let output_tokens =
|
||||
output_tokens
|
||||
.unsqueeze(0)?
|
||||
.expand((sparse_prompt_embeddings.dim(0)?, d1, d2))?;
|
||||
let tokens = Tensor::cat(&[&output_tokens, sparse_prompt_embeddings], 1)?;
|
||||
|
||||
// Expand per-image data in batch direction to be per mask
|
||||
let src = repeat_interleave(image_embeddings, tokens.dim(0)?, 0)?;
|
||||
let src = src.broadcast_add(dense_prompt_embeddings)?;
|
||||
let pos_src = repeat_interleave(image_pe, tokens.dim(0)?, 0)?;
|
||||
let (b, c, h, w) = src.dims4()?;
|
||||
|
||||
// Run the transformer
|
||||
let (hs, src) = self.transformer.forward(&src, &pos_src, &tokens)?;
|
||||
let iou_token_out = hs.i((.., 0))?;
|
||||
let mask_tokens_out = hs.i((.., 1..1 + self.num_mask_tokens))?;
|
||||
|
||||
// Upscale mask embeddings and predict masks using the masks tokens.
|
||||
let src = src.transpose(1, 2)?.reshape((b, c, h, w))?;
|
||||
let upscaled_embedding = self
|
||||
.output_upscaling_conv1
|
||||
.forward(&src)?
|
||||
.apply(&self.output_upscaling_ln)?
|
||||
.gelu()?
|
||||
.apply(&self.output_upscaling_conv2)?
|
||||
.gelu()?;
|
||||
let mut hyper_in_list = Vec::with_capacity(self.num_mask_tokens);
|
||||
for (i, mlp) in self.output_hypernetworks_mlps.iter().enumerate() {
|
||||
let h = mlp.forward(&mask_tokens_out.i((.., i))?)?;
|
||||
hyper_in_list.push(h)
|
||||
}
|
||||
let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?.contiguous()?;
|
||||
let (b, c, h, w) = upscaled_embedding.dims4()?;
|
||||
let masks = hyper_in.matmul(&upscaled_embedding.reshape((b, c, h * w))?)?;
|
||||
let masks = masks.reshape((b, (), h, w))?;
|
||||
|
||||
// Generate mask quality predictions.
|
||||
let iou_pred = self.iou_prediction_head.forward(&iou_token_out)?;
|
||||
Ok((masks, iou_pred))
|
||||
}
|
||||
}
|
||||
|
||||
// Equivalent to torch.repeat_interleave
|
||||
fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> {
|
||||
let img = img.unsqueeze(dim + 1)?;
|
||||
let mut dims = img.dims().to_vec();
|
||||
dims[dim + 1] = repeats;
|
||||
img.broadcast_as(dims)?.flatten(dim, dim + 1)
|
||||
}
|
@ -1,239 +0,0 @@
|
||||
use candle::{DType, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::VarBuilder;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct PostionEmbeddingRandom {
|
||||
positional_encoding_gaussian_matrix: Tensor,
|
||||
}
|
||||
|
||||
impl PostionEmbeddingRandom {
|
||||
fn new(num_pos_feats: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let positional_encoding_gaussian_matrix =
|
||||
vb.get((2, num_pos_feats), "positional_encoding_gaussian_matrix")?;
|
||||
Ok(Self {
|
||||
positional_encoding_gaussian_matrix,
|
||||
})
|
||||
}
|
||||
|
||||
fn pe_encoding(&self, coords: &Tensor) -> Result<Tensor> {
|
||||
let coords = coords.affine(2., -1.)?;
|
||||
let coords = coords.broadcast_matmul(&self.positional_encoding_gaussian_matrix)?;
|
||||
let coords = (coords * (2. * std::f64::consts::PI))?;
|
||||
Tensor::cat(&[coords.sin()?, coords.cos()?], D::Minus1)
|
||||
}
|
||||
|
||||
fn forward(&self, h: usize, w: usize) -> Result<Tensor> {
|
||||
let device = self.positional_encoding_gaussian_matrix.device();
|
||||
let x_embed = (Tensor::arange(0u32, w as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
|
||||
let y_embed = (Tensor::arange(0u32, h as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
|
||||
let x_embed = (x_embed / w as f64)?
|
||||
.reshape((1, ()))?
|
||||
.broadcast_as((h, w))?;
|
||||
let y_embed = (y_embed / h as f64)?
|
||||
.reshape(((), 1))?
|
||||
.broadcast_as((h, w))?;
|
||||
let coords = Tensor::stack(&[&x_embed, &y_embed], D::Minus1)?;
|
||||
self.pe_encoding(&coords)?.permute((2, 0, 1))
|
||||
}
|
||||
|
||||
fn forward_with_coords(
|
||||
&self,
|
||||
coords_input: &Tensor,
|
||||
image_size: (usize, usize),
|
||||
) -> Result<Tensor> {
|
||||
let coords0 = (coords_input.narrow(D::Minus1, 0, 1)? / image_size.1 as f64)?;
|
||||
let coords1 = (coords_input.narrow(D::Minus1, 1, 1)? / image_size.0 as f64)?;
|
||||
let c = coords_input.dim(D::Minus1)?;
|
||||
let coords_rest = coords_input.narrow(D::Minus1, 2, c - 2)?;
|
||||
let coords = Tensor::cat(&[&coords0, &coords1, &coords_rest], D::Minus1)?;
|
||||
self.pe_encoding(&coords)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PromptEncoder {
|
||||
pe_layer: PostionEmbeddingRandom,
|
||||
point_embeddings: Vec<candle_nn::Embedding>,
|
||||
not_a_point_embed: candle_nn::Embedding,
|
||||
mask_downscaling_conv1: candle_nn::Conv2d,
|
||||
mask_downscaling_ln1: crate::LayerNorm2d,
|
||||
mask_downscaling_conv2: candle_nn::Conv2d,
|
||||
mask_downscaling_ln2: crate::LayerNorm2d,
|
||||
mask_downscaling_conv3: candle_nn::Conv2d,
|
||||
no_mask_embed: candle_nn::Embedding,
|
||||
image_embedding_size: (usize, usize),
|
||||
input_image_size: (usize, usize),
|
||||
embed_dim: usize,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl PromptEncoder {
|
||||
pub fn new(
|
||||
embed_dim: usize,
|
||||
image_embedding_size: (usize, usize),
|
||||
input_image_size: (usize, usize),
|
||||
mask_in_chans: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let num_points_embeddings = 4;
|
||||
let pe_layer = PostionEmbeddingRandom::new(embed_dim / 2, vb.pp("pe_layer"))?;
|
||||
let not_a_point_embed = candle_nn::embedding(1, embed_dim, vb.pp("not_a_point_embed"))?;
|
||||
let no_mask_embed = candle_nn::embedding(1, embed_dim, vb.pp("no_mask_embed"))?;
|
||||
let cfg = candle_nn::Conv2dConfig {
|
||||
stride: 2,
|
||||
..Default::default()
|
||||
};
|
||||
let mask_downscaling_conv1 =
|
||||
candle_nn::conv2d(1, mask_in_chans / 4, 2, cfg, vb.pp("mask_downscaling.0"))?;
|
||||
let mask_downscaling_conv2 = candle_nn::conv2d(
|
||||
mask_in_chans / 4,
|
||||
mask_in_chans,
|
||||
2,
|
||||
cfg,
|
||||
vb.pp("mask_downscaling.3"),
|
||||
)?;
|
||||
let mask_downscaling_conv3 = candle_nn::conv2d(
|
||||
mask_in_chans,
|
||||
embed_dim,
|
||||
1,
|
||||
Default::default(),
|
||||
vb.pp("mask_downscaling.6"),
|
||||
)?;
|
||||
let mask_downscaling_ln1 =
|
||||
crate::LayerNorm2d::new(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?;
|
||||
let mask_downscaling_ln2 =
|
||||
crate::LayerNorm2d::new(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?;
|
||||
let mut point_embeddings = Vec::with_capacity(num_points_embeddings);
|
||||
let vb_e = vb.pp("point_embeddings");
|
||||
for i in 0..num_points_embeddings {
|
||||
let emb = candle_nn::embedding(1, embed_dim, vb_e.pp(i))?;
|
||||
point_embeddings.push(emb)
|
||||
}
|
||||
let span = tracing::span!(tracing::Level::TRACE, "prompt-encoder");
|
||||
Ok(Self {
|
||||
pe_layer,
|
||||
point_embeddings,
|
||||
not_a_point_embed,
|
||||
mask_downscaling_conv1,
|
||||
mask_downscaling_ln1,
|
||||
mask_downscaling_conv2,
|
||||
mask_downscaling_ln2,
|
||||
mask_downscaling_conv3,
|
||||
no_mask_embed,
|
||||
image_embedding_size,
|
||||
input_image_size,
|
||||
embed_dim,
|
||||
span,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_dense_pe(&self) -> Result<Tensor> {
|
||||
self.pe_layer
|
||||
.forward(self.image_embedding_size.0, self.image_embedding_size.1)?
|
||||
.unsqueeze(0)
|
||||
}
|
||||
|
||||
fn embed_masks(&self, masks: &Tensor) -> Result<Tensor> {
|
||||
masks
|
||||
.apply(&self.mask_downscaling_conv1)?
|
||||
.apply(&self.mask_downscaling_ln1)?
|
||||
.gelu()?
|
||||
.apply(&self.mask_downscaling_conv2)?
|
||||
.apply(&self.mask_downscaling_ln2)?
|
||||
.gelu()?
|
||||
.apply(&self.mask_downscaling_conv3)
|
||||
}
|
||||
|
||||
fn embed_points(&self, points: &Tensor, labels: &Tensor, pad: bool) -> Result<Tensor> {
|
||||
let points = (points + 0.5)?;
|
||||
let dev = points.device();
|
||||
let (points, labels) = if pad {
|
||||
let padding_point = Tensor::zeros((points.dim(0)?, 1, 2), DType::F32, dev)?;
|
||||
let padding_label = (Tensor::ones((labels.dim(0)?, 1), DType::F32, dev)? * (-1f64))?;
|
||||
let points = Tensor::cat(&[&points, &padding_point], 1)?;
|
||||
let labels = Tensor::cat(&[labels, &padding_label], 1)?;
|
||||
(points, labels)
|
||||
} else {
|
||||
(points, labels.clone())
|
||||
};
|
||||
let point_embedding = self
|
||||
.pe_layer
|
||||
.forward_with_coords(&points, self.input_image_size)?;
|
||||
let labels = labels.unsqueeze(2)?.broadcast_as(point_embedding.shape())?;
|
||||
let zeros = point_embedding.zeros_like()?;
|
||||
let point_embedding = labels.lt(0f32)?.where_cond(
|
||||
&self
|
||||
.not_a_point_embed
|
||||
.embeddings()
|
||||
.broadcast_as(zeros.shape())?,
|
||||
&point_embedding,
|
||||
)?;
|
||||
let labels0 = labels.eq(0f32)?.where_cond(
|
||||
&self.point_embeddings[0]
|
||||
.embeddings()
|
||||
.broadcast_as(zeros.shape())?,
|
||||
&zeros,
|
||||
)?;
|
||||
let point_embedding = (point_embedding + labels0)?;
|
||||
let labels1 = labels.eq(1f32)?.where_cond(
|
||||
&self.point_embeddings[1]
|
||||
.embeddings()
|
||||
.broadcast_as(zeros.shape())?,
|
||||
&zeros,
|
||||
)?;
|
||||
let point_embedding = (point_embedding + labels1)?;
|
||||
Ok(point_embedding)
|
||||
}
|
||||
|
||||
fn embed_boxes(&self, boxes: &Tensor) -> Result<Tensor> {
|
||||
let boxes = (boxes + 0.5)?;
|
||||
let coords = boxes.reshape(((), 2, 2))?;
|
||||
let corner_embedding = self
|
||||
.pe_layer
|
||||
.forward_with_coords(&coords, self.input_image_size)?;
|
||||
let ce1 = corner_embedding.i((.., 0))?;
|
||||
let ce2 = corner_embedding.i((.., 1))?;
|
||||
let ce1 = (ce1 + self.point_embeddings[2].embeddings())?;
|
||||
let ce2 = (ce2 + self.point_embeddings[3].embeddings())?;
|
||||
Tensor::cat(&[&ce1, &ce2], 1)
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
points: Option<(&Tensor, &Tensor)>,
|
||||
boxes: Option<&Tensor>,
|
||||
masks: Option<&Tensor>,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let _enter = self.span.enter();
|
||||
let se_points = match points {
|
||||
Some((coords, labels)) => Some(self.embed_points(coords, labels, boxes.is_none())?),
|
||||
None => None,
|
||||
};
|
||||
let se_boxes = match boxes {
|
||||
Some(boxes) => Some(self.embed_boxes(boxes)?),
|
||||
None => None,
|
||||
};
|
||||
let sparse_embeddings = match (se_points, se_boxes) {
|
||||
(Some(se_points), Some(se_boxes)) => Tensor::cat(&[se_points, se_boxes], 1)?,
|
||||
(Some(se_points), None) => se_points,
|
||||
(None, Some(se_boxes)) => se_boxes,
|
||||
(None, None) => {
|
||||
Tensor::zeros((1, 0, self.embed_dim), DType::F32, &candle::Device::Cpu)?
|
||||
}
|
||||
};
|
||||
|
||||
let dense_embeddings = match masks {
|
||||
None => {
|
||||
let emb = self.no_mask_embed.embeddings();
|
||||
emb.reshape((1, (), 1, 1))?.expand((
|
||||
1,
|
||||
emb.elem_count(),
|
||||
self.image_embedding_size.0,
|
||||
self.image_embedding_size.1,
|
||||
))?
|
||||
}
|
||||
Some(masks) => self.embed_masks(masks)?,
|
||||
};
|
||||
Ok((sparse_embeddings, dense_embeddings))
|
||||
}
|
||||
}
|
@ -1,364 +0,0 @@
|
||||
use candle::{DType, IndexOp, Result, Tensor};
|
||||
use candle_nn::{Module, VarBuilder};
|
||||
|
||||
use crate::model_image_encoder::ImageEncoderViT;
|
||||
use crate::model_mask_decoder::MaskDecoder;
|
||||
use crate::model_prompt_encoder::PromptEncoder;
|
||||
|
||||
const PROMPT_EMBED_DIM: usize = 256;
|
||||
pub const IMAGE_SIZE: usize = 1024;
|
||||
const VIT_PATCH_SIZE: usize = 16;
|
||||
const PRED_IOU_THRESH: f32 = 0.88;
|
||||
const STABILITY_SCORE_OFFSET: f32 = 1.0;
|
||||
const STABILITY_SCORE_THRESHOLD: f32 = 0.95;
|
||||
const MODEL_MASK_THRESHOLD: f32 = 0.0;
|
||||
const CROP_NMS_THRESH: f32 = 0.7;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Sam {
|
||||
image_encoder: ImageEncoderViT,
|
||||
prompt_encoder: PromptEncoder,
|
||||
mask_decoder: MaskDecoder,
|
||||
pixel_mean: Tensor,
|
||||
pixel_std: Tensor,
|
||||
}
|
||||
|
||||
impl Sam {
|
||||
pub fn new(
|
||||
encoder_embed_dim: usize,
|
||||
encoder_depth: usize,
|
||||
encoder_num_heads: usize,
|
||||
encoder_global_attn_indexes: &[usize],
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE;
|
||||
|
||||
let image_encoder = ImageEncoderViT::new(
|
||||
IMAGE_SIZE,
|
||||
VIT_PATCH_SIZE,
|
||||
3,
|
||||
encoder_embed_dim,
|
||||
encoder_depth,
|
||||
encoder_num_heads,
|
||||
PROMPT_EMBED_DIM,
|
||||
/* qkv_bias */ true,
|
||||
/* use_rel_pos */ true,
|
||||
/* use_abs_pos */ true,
|
||||
/* window_size */ 14,
|
||||
/* global_attn_indexes */ encoder_global_attn_indexes,
|
||||
vb.pp("image_encoder"),
|
||||
)?;
|
||||
let prompt_encoder = PromptEncoder::new(
|
||||
PROMPT_EMBED_DIM,
|
||||
(image_embedding_size, image_embedding_size),
|
||||
(IMAGE_SIZE, IMAGE_SIZE),
|
||||
16,
|
||||
vb.pp("prompt_encoder"),
|
||||
)?;
|
||||
let mask_decoder = MaskDecoder::new(
|
||||
PROMPT_EMBED_DIM,
|
||||
/* num_multitask_outputs */ 3,
|
||||
/* iou_head_depth */ 3,
|
||||
/* iou_head_hidden_dim */ 256,
|
||||
vb.pp("mask_decoder"),
|
||||
)?;
|
||||
let pixel_mean =
|
||||
Tensor::new(&[123.675f32, 116.28, 103.53], vb.device())?.reshape((3, 1, 1))?;
|
||||
let pixel_std =
|
||||
Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
|
||||
Ok(Self {
|
||||
image_encoder,
|
||||
prompt_encoder,
|
||||
mask_decoder,
|
||||
pixel_std,
|
||||
pixel_mean,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
img: &Tensor,
|
||||
point: Option<(f64, f64)>,
|
||||
multimask_output: bool,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let (_c, original_h, original_w) = img.dims3()?;
|
||||
let img = self.preprocess(img)?.unsqueeze(0)?;
|
||||
let img_embeddings = self.image_encoder.forward(&img)?;
|
||||
let image_pe = self.prompt_encoder.get_dense_pe()?;
|
||||
let points = match point {
|
||||
None => None,
|
||||
Some((x, y)) => {
|
||||
let points = Tensor::new(
|
||||
&[[[x as f32 * original_w as f32, y as f32 * original_h as f32]]],
|
||||
img.device(),
|
||||
)?;
|
||||
let labels = Tensor::ones((1, 1), DType::F32, img.device())?;
|
||||
Some((points, labels))
|
||||
}
|
||||
};
|
||||
let points = points.as_ref().map(|(x, y)| (x, y));
|
||||
let (sparse_prompt_embeddings, dense_prompt_embeddings) =
|
||||
self.prompt_encoder.forward(points, None, None)?;
|
||||
let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
|
||||
&img_embeddings,
|
||||
&image_pe,
|
||||
&sparse_prompt_embeddings,
|
||||
&dense_prompt_embeddings,
|
||||
multimask_output,
|
||||
)?;
|
||||
let mask = low_res_mask
|
||||
.upsample_nearest2d(IMAGE_SIZE, IMAGE_SIZE)?
|
||||
.get(0)?
|
||||
.i((.., ..original_h, ..original_w))?;
|
||||
Ok((mask, iou_predictions))
|
||||
}
|
||||
|
||||
pub fn unpreprocess(&self, img: &Tensor) -> Result<Tensor> {
|
||||
let img = img
|
||||
.broadcast_mul(&self.pixel_std)?
|
||||
.broadcast_add(&self.pixel_mean)?;
|
||||
img.maximum(&img.zeros_like()?)?
|
||||
.minimum(&(img.ones_like()? * 255.)?)
|
||||
}
|
||||
|
||||
pub fn preprocess(&self, img: &Tensor) -> Result<Tensor> {
|
||||
let (_c, h, w) = img.dims3()?;
|
||||
let img = img
|
||||
.to_dtype(DType::F32)?
|
||||
.broadcast_sub(&self.pixel_mean)?
|
||||
.broadcast_div(&self.pixel_std)?;
|
||||
if h > IMAGE_SIZE || w > IMAGE_SIZE {
|
||||
candle::bail!("image is too large ({w}, {h}), maximum size {IMAGE_SIZE}")
|
||||
}
|
||||
let img = img.pad_with_zeros(1, 0, IMAGE_SIZE - h)?;
|
||||
img.pad_with_zeros(2, 0, IMAGE_SIZE - w)
|
||||
}
|
||||
|
||||
fn process_crop(
|
||||
&self,
|
||||
img: &Tensor,
|
||||
cb: CropBox,
|
||||
point_grids: &[(f64, f64)],
|
||||
) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> {
|
||||
// Crop the image and calculate embeddings.
|
||||
let img = img.i((.., cb.y0..cb.y1, cb.x0..cb.x1))?;
|
||||
let img = self.preprocess(&img)?.unsqueeze(0)?;
|
||||
let img_embeddings = self.image_encoder.forward(&img)?;
|
||||
|
||||
let crop_w = cb.x1 - cb.x0;
|
||||
let crop_h = cb.y1 - cb.y0;
|
||||
|
||||
// Generate masks for this crop.
|
||||
let image_pe = self.prompt_encoder.get_dense_pe()?;
|
||||
let points = point_grids
|
||||
.iter()
|
||||
.map(|&(x, y)| vec![x as f32 * crop_w as f32, y as f32 * crop_h as f32])
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut bboxes = Vec::new();
|
||||
for points in points.chunks(64) {
|
||||
// Run the model on this batch.
|
||||
let points_len = points.len();
|
||||
let in_points = Tensor::new(points.to_vec(), img.device())?.unsqueeze(1)?;
|
||||
let in_labels = Tensor::ones((points_len, 1), DType::F32, img.device())?;
|
||||
let (sparse_prompt_embeddings, dense_prompt_embeddings) =
|
||||
self.prompt_encoder
|
||||
.forward(Some((&in_points, &in_labels)), None, None)?;
|
||||
|
||||
let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
|
||||
&img_embeddings,
|
||||
&image_pe,
|
||||
&sparse_prompt_embeddings,
|
||||
&dense_prompt_embeddings,
|
||||
/* multimask_output */ true,
|
||||
)?;
|
||||
let low_res_mask = low_res_mask.flatten(0, 1)?;
|
||||
let iou_predictions = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?;
|
||||
let dev = low_res_mask.device();
|
||||
|
||||
for (i, iou) in iou_predictions.iter().enumerate() {
|
||||
// Filter by predicted IoU.
|
||||
if *iou < PRED_IOU_THRESH {
|
||||
continue;
|
||||
}
|
||||
let low_res_mask = low_res_mask.get(i)?;
|
||||
|
||||
// Calculate stability score.
|
||||
let bound = Tensor::new(MODEL_MASK_THRESHOLD + STABILITY_SCORE_OFFSET, dev)?
|
||||
.broadcast_as(low_res_mask.shape())?;
|
||||
let intersections = low_res_mask
|
||||
.ge(&bound)?
|
||||
.to_dtype(DType::F32)?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
let bound = Tensor::new(MODEL_MASK_THRESHOLD - STABILITY_SCORE_OFFSET, dev)?
|
||||
.broadcast_as(low_res_mask.shape())?;
|
||||
let unions = low_res_mask
|
||||
.ge(&bound)?
|
||||
.to_dtype(DType::F32)?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
let stability_score = intersections / unions;
|
||||
if stability_score < STABILITY_SCORE_THRESHOLD {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Threshold masks and calculate boxes.
|
||||
let low_res_mask = low_res_mask
|
||||
.ge(&Tensor::new(0f32, dev)?.broadcast_as(low_res_mask.shape())?)?
|
||||
.to_dtype(DType::U32)?;
|
||||
let low_res_mask_per_x = low_res_mask.sum(0)?.to_vec1::<u32>()?;
|
||||
let low_res_mask_per_y = low_res_mask.sum(1)?.to_vec1::<u32>()?;
|
||||
let min_max_x = min_max_indexes(&low_res_mask_per_x);
|
||||
let min_max_y = min_max_indexes(&low_res_mask_per_y);
|
||||
if let Some(((x0, x1), (y0, y1))) = min_max_x.zip(min_max_y) {
|
||||
let bbox = candle_examples::object_detection::Bbox {
|
||||
xmin: x0 as f32,
|
||||
ymin: y0 as f32,
|
||||
xmax: x1 as f32,
|
||||
ymax: y1 as f32,
|
||||
confidence: *iou,
|
||||
data: low_res_mask,
|
||||
};
|
||||
bboxes.push(bbox);
|
||||
}
|
||||
// TODO:
|
||||
// Filter boxes that touch crop boundaries
|
||||
// Compress to RLE.
|
||||
}
|
||||
}
|
||||
|
||||
let mut bboxes = vec![bboxes];
|
||||
// Remove duplicates within this crop.
|
||||
candle_examples::object_detection::non_maximum_suppression(&mut bboxes, CROP_NMS_THRESH);
|
||||
|
||||
// TODO: Return to the original image frame.
|
||||
Ok(bboxes.remove(0))
|
||||
}
|
||||
|
||||
pub fn generate_masks(
|
||||
&self,
|
||||
img: &Tensor,
|
||||
points_per_side: usize,
|
||||
crop_n_layer: usize,
|
||||
crop_overlap_ratio: f64,
|
||||
crop_n_points_downscale_factor: usize,
|
||||
) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> {
|
||||
let (_c, h, w) = img.dims3()?;
|
||||
let point_grids = build_all_layer_point_grids(
|
||||
points_per_side,
|
||||
crop_n_layer,
|
||||
crop_n_points_downscale_factor,
|
||||
);
|
||||
let crop_boxes = generate_crop_boxes((h, w), crop_n_layer, crop_overlap_ratio);
|
||||
let mut bboxes = Vec::new();
|
||||
for crop_box in crop_boxes.into_iter() {
|
||||
let layer_idx = crop_box.layer_idx;
|
||||
let b = self.process_crop(img, crop_box, &point_grids[layer_idx])?;
|
||||
bboxes.extend(b)
|
||||
}
|
||||
// TODO: remove duplicates
|
||||
Ok(bboxes)
|
||||
}
|
||||
}
|
||||
|
||||
// Return the first and last indexes i for which values[i] > 0
|
||||
fn min_max_indexes(values: &[u32]) -> Option<(usize, usize)> {
|
||||
let (mut min_i, mut max_i) = (usize::MAX, usize::MIN);
|
||||
for (i, &s) in values.iter().enumerate() {
|
||||
if s == 0 {
|
||||
continue;
|
||||
}
|
||||
min_i = usize::min(i, min_i);
|
||||
max_i = usize::max(i, max_i);
|
||||
}
|
||||
if max_i < min_i {
|
||||
None
|
||||
} else {
|
||||
Some((min_i, max_i))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct CropBox {
|
||||
x0: usize,
|
||||
y0: usize,
|
||||
x1: usize,
|
||||
y1: usize,
|
||||
layer_idx: usize,
|
||||
}
|
||||
|
||||
impl CropBox {
|
||||
fn new(x0: usize, y0: usize, x1: usize, y1: usize, layer_idx: usize) -> Self {
|
||||
Self {
|
||||
x0,
|
||||
y0,
|
||||
x1,
|
||||
y1,
|
||||
layer_idx,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_crop_boxes(
|
||||
(im_h, im_w): (usize, usize),
|
||||
n_layers: usize,
|
||||
overlap_ratio: f64,
|
||||
) -> Vec<CropBox> {
|
||||
fn crop_len(orig_len: usize, n_crops: usize, overlap: usize) -> usize {
|
||||
f64::ceil((overlap * (n_crops - 1) + orig_len) as f64 / n_crops as f64) as usize
|
||||
}
|
||||
|
||||
let short_side = usize::min(im_h, im_w);
|
||||
|
||||
let mut crop_boxes = Vec::new();
|
||||
|
||||
// Original image.
|
||||
crop_boxes.push(CropBox::new(0, 0, im_w, im_h, 0));
|
||||
|
||||
for layer_idx in 1..=n_layers {
|
||||
let n_crops_per_side = 1 << layer_idx;
|
||||
let overlap = (overlap_ratio * short_side as f64 * 2. / n_crops_per_side as f64) as usize;
|
||||
let crop_w = crop_len(im_w, n_crops_per_side, overlap);
|
||||
let crop_h = crop_len(im_w, n_crops_per_side, overlap);
|
||||
|
||||
for i_x in 0..n_crops_per_side {
|
||||
let x0 = (crop_w - overlap) * i_x;
|
||||
for i_y in 0..n_crops_per_side {
|
||||
let y0 = (crop_h - overlap) * i_y;
|
||||
let x1 = usize::min(im_w, x0 + crop_w);
|
||||
let y1 = usize::min(im_h, y0 + crop_h);
|
||||
crop_boxes.push(CropBox::new(x0, y0, x1, y1, layer_idx));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
crop_boxes
|
||||
}
|
||||
|
||||
// Generates a 2D grid of points evenly spaced in [0,1]x[0,1].
|
||||
fn build_point_grid(n_per_side: usize) -> Vec<(f64, f64)> {
|
||||
let offset = 1f64 / (2 * n_per_side) as f64;
|
||||
let mut points = Vec::with_capacity(n_per_side * n_per_side);
|
||||
for i_x in 0..n_per_side {
|
||||
let x = offset + i_x as f64 / n_per_side as f64;
|
||||
for i_y in 0..n_per_side {
|
||||
let y = offset + i_y as f64 / n_per_side as f64;
|
||||
points.push((x, y))
|
||||
}
|
||||
}
|
||||
points
|
||||
}
|
||||
|
||||
fn build_all_layer_point_grids(
|
||||
n_per_side: usize,
|
||||
n_layers: usize,
|
||||
scale_per_layer: usize,
|
||||
) -> Vec<Vec<(f64, f64)>> {
|
||||
let mut points_by_layer = Vec::with_capacity(n_layers + 1);
|
||||
for i in 0..=n_layers {
|
||||
let n_points = n_per_side / scale_per_layer.pow(i as u32);
|
||||
points_by_layer.push(build_point_grid(n_points))
|
||||
}
|
||||
points_by_layer
|
||||
}
|
@ -1,221 +0,0 @@
|
||||
use candle::{Result, Tensor};
|
||||
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Attention {
|
||||
q_proj: Linear,
|
||||
k_proj: Linear,
|
||||
v_proj: Linear,
|
||||
out_proj: Linear,
|
||||
num_heads: usize,
|
||||
}
|
||||
|
||||
impl Attention {
|
||||
fn new(
|
||||
embedding_dim: usize,
|
||||
num_heads: usize,
|
||||
downsample_rate: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let internal_dim = embedding_dim / downsample_rate;
|
||||
let q_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("q_proj"))?;
|
||||
let k_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("k_proj"))?;
|
||||
let v_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("v_proj"))?;
|
||||
let out_proj = candle_nn::linear(internal_dim, embedding_dim, vb.pp("out_proj"))?;
|
||||
Ok(Self {
|
||||
q_proj,
|
||||
k_proj,
|
||||
v_proj,
|
||||
out_proj,
|
||||
num_heads,
|
||||
})
|
||||
}
|
||||
|
||||
fn separate_heads(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let (b, n, c) = x.dims3()?;
|
||||
x.reshape((b, n, self.num_heads, c / self.num_heads))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()
|
||||
}
|
||||
|
||||
fn recombine_heads(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let (b, n_heads, n_tokens, c_per_head) = x.dims4()?;
|
||||
x.transpose(1, 2)?
|
||||
.reshape((b, n_tokens, n_heads * c_per_head))
|
||||
}
|
||||
|
||||
fn forward(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
|
||||
let q = self.q_proj.forward(&q.contiguous()?)?;
|
||||
let k = self.k_proj.forward(&k.contiguous()?)?;
|
||||
let v = self.v_proj.forward(&v.contiguous()?)?;
|
||||
|
||||
let q = self.separate_heads(&q)?;
|
||||
let k = self.separate_heads(&k)?;
|
||||
let v = self.separate_heads(&v)?;
|
||||
|
||||
let (_, _, _, c_per_head) = q.dims4()?;
|
||||
let attn = (q.matmul(&k.t()?)? / (c_per_head as f64).sqrt())?;
|
||||
let attn = candle_nn::ops::softmax_last_dim(&attn)?;
|
||||
|
||||
let out = attn.matmul(&v)?;
|
||||
self.recombine_heads(&out)?.apply(&self.out_proj)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct TwoWayAttentionBlock {
|
||||
self_attn: Attention,
|
||||
norm1: LayerNorm,
|
||||
cross_attn_token_to_image: Attention,
|
||||
norm2: LayerNorm,
|
||||
mlp: crate::MlpBlock,
|
||||
norm3: LayerNorm,
|
||||
norm4: LayerNorm,
|
||||
cross_attn_image_to_token: Attention,
|
||||
skip_first_layer_pe: bool,
|
||||
}
|
||||
|
||||
impl TwoWayAttentionBlock {
|
||||
fn new(
|
||||
embedding_dim: usize,
|
||||
num_heads: usize,
|
||||
mlp_dim: usize,
|
||||
skip_first_layer_pe: bool,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let norm1 = layer_norm(embedding_dim, 1e-5, vb.pp("norm1"))?;
|
||||
let norm2 = layer_norm(embedding_dim, 1e-5, vb.pp("norm2"))?;
|
||||
let norm3 = layer_norm(embedding_dim, 1e-5, vb.pp("norm3"))?;
|
||||
let norm4 = layer_norm(embedding_dim, 1e-5, vb.pp("norm4"))?;
|
||||
let self_attn = Attention::new(embedding_dim, num_heads, 1, vb.pp("self_attn"))?;
|
||||
let cross_attn_token_to_image = Attention::new(
|
||||
embedding_dim,
|
||||
num_heads,
|
||||
2,
|
||||
vb.pp("cross_attn_token_to_image"),
|
||||
)?;
|
||||
let cross_attn_image_to_token = Attention::new(
|
||||
embedding_dim,
|
||||
num_heads,
|
||||
2,
|
||||
vb.pp("cross_attn_image_to_token"),
|
||||
)?;
|
||||
let mlp = crate::MlpBlock::new(
|
||||
embedding_dim,
|
||||
mlp_dim,
|
||||
candle_nn::Activation::Relu,
|
||||
vb.pp("mlp"),
|
||||
)?;
|
||||
Ok(Self {
|
||||
self_attn,
|
||||
norm1,
|
||||
cross_attn_image_to_token,
|
||||
norm2,
|
||||
mlp,
|
||||
norm3,
|
||||
norm4,
|
||||
cross_attn_token_to_image,
|
||||
skip_first_layer_pe,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
queries: &Tensor,
|
||||
keys: &Tensor,
|
||||
query_pe: &Tensor,
|
||||
key_pe: &Tensor,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
// Self attention block
|
||||
let queries = if self.skip_first_layer_pe {
|
||||
self.self_attn.forward(queries, queries, queries)?
|
||||
} else {
|
||||
let q = (queries + query_pe)?;
|
||||
let attn_out = self.self_attn.forward(&q, &q, queries)?;
|
||||
(queries + attn_out)?
|
||||
};
|
||||
let queries = self.norm1.forward(&queries)?;
|
||||
|
||||
// Cross attention block, tokens attending to image embedding
|
||||
let q = (&queries + query_pe)?;
|
||||
let k = (keys + key_pe)?;
|
||||
let attn_out = self.cross_attn_token_to_image.forward(&q, &k, keys)?;
|
||||
let queries = (&queries + attn_out)?;
|
||||
let queries = self.norm2.forward(&queries)?;
|
||||
|
||||
// MLP block
|
||||
let mlp_out = self.mlp.forward(&queries);
|
||||
let queries = (queries + mlp_out)?;
|
||||
let queries = self.norm3.forward(&queries)?;
|
||||
|
||||
// Cross attention block, image embedding attending to tokens
|
||||
let q = (&queries + query_pe)?;
|
||||
let k = (keys + key_pe)?;
|
||||
let attn_out = self.cross_attn_image_to_token.forward(&k, &q, &queries)?;
|
||||
let keys = (keys + attn_out)?;
|
||||
let keys = self.norm4.forward(&keys)?;
|
||||
|
||||
Ok((queries, keys))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TwoWayTransformer {
|
||||
layers: Vec<TwoWayAttentionBlock>,
|
||||
final_attn_token_to_image: Attention,
|
||||
norm_final_attn: LayerNorm,
|
||||
}
|
||||
|
||||
impl TwoWayTransformer {
|
||||
pub fn new(
|
||||
depth: usize,
|
||||
embedding_dim: usize,
|
||||
num_heads: usize,
|
||||
mlp_dim: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let vb_l = vb.pp("layers");
|
||||
let mut layers = Vec::with_capacity(depth);
|
||||
for i in 0..depth {
|
||||
let layer =
|
||||
TwoWayAttentionBlock::new(embedding_dim, num_heads, mlp_dim, i == 0, vb_l.pp(i))?;
|
||||
layers.push(layer)
|
||||
}
|
||||
let final_attn_token_to_image = Attention::new(
|
||||
embedding_dim,
|
||||
num_heads,
|
||||
2,
|
||||
vb.pp("final_attn_token_to_image"),
|
||||
)?;
|
||||
let norm_final_attn = layer_norm(embedding_dim, 1e-5, vb.pp("norm_final_attn"))?;
|
||||
Ok(Self {
|
||||
layers,
|
||||
final_attn_token_to_image,
|
||||
norm_final_attn,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
image_embedding: &Tensor,
|
||||
image_pe: &Tensor,
|
||||
point_embedding: &Tensor,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let image_embedding = image_embedding.flatten_from(2)?.permute((0, 2, 1))?;
|
||||
let image_pe = image_pe.flatten_from(2)?.permute((0, 2, 1))?;
|
||||
|
||||
let mut queries = point_embedding.clone();
|
||||
let mut keys = image_embedding;
|
||||
|
||||
for layer in self.layers.iter() {
|
||||
(queries, keys) = layer.forward(&queries, &keys, point_embedding, &image_pe)?
|
||||
}
|
||||
|
||||
let q = (&queries + point_embedding)?;
|
||||
let k = (&keys + image_pe)?;
|
||||
let attn_out = self.final_attn_token_to_image.forward(&q, &k, &keys)?;
|
||||
let queries = (queries + attn_out)?.apply(&self.norm_final_attn)?;
|
||||
|
||||
Ok((queries, keys))
|
||||
}
|
||||
}
|
@ -1,7 +1,7 @@
|
||||
#![allow(dead_code)]
|
||||
//! Attention Based Building Blocks
|
||||
use candle::{DType, IndexOp, Result, Tensor, D};
|
||||
use candle::{IndexOp, Result, Tensor, D};
|
||||
use candle_nn as nn;
|
||||
use candle_nn::Module;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct GeGlu {
|
||||
@ -61,22 +61,6 @@ impl FeedForward {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "flash-attn")]
|
||||
fn flash_attn(
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
v: &Tensor,
|
||||
softmax_scale: f32,
|
||||
causal: bool,
|
||||
) -> Result<Tensor> {
|
||||
candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "flash-attn"))]
|
||||
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
|
||||
unimplemented!("compile with '--features flash-attn'")
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct CrossAttention {
|
||||
to_q: nn::Linear,
|
||||
@ -88,8 +72,6 @@ struct CrossAttention {
|
||||
slice_size: Option<usize>,
|
||||
span: tracing::Span,
|
||||
span_attn: tracing::Span,
|
||||
span_softmax: tracing::Span,
|
||||
use_flash_attn: bool,
|
||||
}
|
||||
|
||||
impl CrossAttention {
|
||||
@ -101,7 +83,6 @@ impl CrossAttention {
|
||||
heads: usize,
|
||||
dim_head: usize,
|
||||
slice_size: Option<usize>,
|
||||
use_flash_attn: bool,
|
||||
) -> Result<Self> {
|
||||
let inner_dim = dim_head * heads;
|
||||
let context_dim = context_dim.unwrap_or(query_dim);
|
||||
@ -112,7 +93,6 @@ impl CrossAttention {
|
||||
let to_out = nn::linear(inner_dim, query_dim, vs.pp("to_out.0"))?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "xa");
|
||||
let span_attn = tracing::span!(tracing::Level::TRACE, "xa-attn");
|
||||
let span_softmax = tracing::span!(tracing::Level::TRACE, "xa-softmax");
|
||||
Ok(Self {
|
||||
to_q,
|
||||
to_k,
|
||||
@ -123,8 +103,6 @@ impl CrossAttention {
|
||||
slice_size,
|
||||
span,
|
||||
span_attn,
|
||||
span_softmax,
|
||||
use_flash_attn,
|
||||
})
|
||||
}
|
||||
|
||||
@ -151,10 +129,6 @@ impl CrossAttention {
|
||||
) -> Result<Tensor> {
|
||||
let batch_size_attention = query.dim(0)?;
|
||||
let mut hidden_states = Vec::with_capacity(batch_size_attention / slice_size);
|
||||
let in_dtype = query.dtype();
|
||||
let query = query.to_dtype(DType::F32)?;
|
||||
let key = key.to_dtype(DType::F32)?;
|
||||
let value = value.to_dtype(DType::F32)?;
|
||||
|
||||
for i in 0..batch_size_attention / slice_size {
|
||||
let start_idx = i * slice_size;
|
||||
@ -166,65 +140,35 @@ impl CrossAttention {
|
||||
let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(&value.i(start_idx..end_idx)?)?;
|
||||
hidden_states.push(xs)
|
||||
}
|
||||
let hidden_states = Tensor::stack(&hidden_states, 0)?.to_dtype(in_dtype)?;
|
||||
let hidden_states = Tensor::stack(&hidden_states, 0)?;
|
||||
self.reshape_batch_dim_to_heads(&hidden_states)
|
||||
}
|
||||
|
||||
fn attention(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span_attn.enter();
|
||||
let xs = if self.use_flash_attn {
|
||||
let init_dtype = query.dtype();
|
||||
let q = query
|
||||
.to_dtype(candle::DType::F16)?
|
||||
.unsqueeze(0)?
|
||||
.transpose(1, 2)?;
|
||||
let k = key
|
||||
.to_dtype(candle::DType::F16)?
|
||||
.unsqueeze(0)?
|
||||
.transpose(1, 2)?;
|
||||
let v = value
|
||||
.to_dtype(candle::DType::F16)?
|
||||
.unsqueeze(0)?
|
||||
.transpose(1, 2)?;
|
||||
flash_attn(&q, &k, &v, self.scale as f32, false)?
|
||||
.transpose(1, 2)?
|
||||
.squeeze(0)?
|
||||
.to_dtype(init_dtype)?
|
||||
} else {
|
||||
let in_dtype = query.dtype();
|
||||
let query = query.to_dtype(DType::F32)?;
|
||||
let key = key.to_dtype(DType::F32)?;
|
||||
let value = value.to_dtype(DType::F32)?;
|
||||
let xs = query.matmul(&(key.t()? * self.scale)?)?;
|
||||
let xs = {
|
||||
let _enter = self.span_softmax.enter();
|
||||
nn::ops::softmax_last_dim(&xs)?
|
||||
};
|
||||
xs.matmul(&value)?.to_dtype(in_dtype)?
|
||||
};
|
||||
let xs = query.matmul(&(key.transpose(D::Minus1, D::Minus2)? * self.scale)?)?;
|
||||
let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(value)?;
|
||||
self.reshape_batch_dim_to_heads(&xs)
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let query = self.to_q.forward(xs)?;
|
||||
let context = context.unwrap_or(xs).contiguous()?;
|
||||
let key = self.to_k.forward(&context)?;
|
||||
let value = self.to_v.forward(&context)?;
|
||||
let context = context.unwrap_or(xs);
|
||||
let key = self.to_k.forward(context)?;
|
||||
let value = self.to_v.forward(context)?;
|
||||
let query = self.reshape_heads_to_batch_dim(&query)?;
|
||||
let key = self.reshape_heads_to_batch_dim(&key)?;
|
||||
let value = self.reshape_heads_to_batch_dim(&value)?;
|
||||
let dim0 = query.dim(0)?;
|
||||
let slice_size = self.slice_size.and_then(|slice_size| {
|
||||
if dim0 < slice_size {
|
||||
None
|
||||
} else {
|
||||
Some(slice_size)
|
||||
}
|
||||
});
|
||||
let xs = match slice_size {
|
||||
let xs = match self.slice_size {
|
||||
None => self.attention(&query, &key, &value)?,
|
||||
Some(slice_size) => self.sliced_attention(&query, &key, &value, slice_size)?,
|
||||
Some(slice_size) => {
|
||||
if query.dim(0)? / slice_size <= 1 {
|
||||
self.attention(&query, &key, &value)?
|
||||
} else {
|
||||
self.sliced_attention(&query, &key, &value, slice_size)?
|
||||
}
|
||||
}
|
||||
};
|
||||
self.to_out.forward(&xs)
|
||||
}
|
||||
@ -250,7 +194,6 @@ impl BasicTransformerBlock {
|
||||
d_head: usize,
|
||||
context_dim: Option<usize>,
|
||||
sliced_attention_size: Option<usize>,
|
||||
use_flash_attn: bool,
|
||||
) -> Result<Self> {
|
||||
let attn1 = CrossAttention::new(
|
||||
vs.pp("attn1"),
|
||||
@ -259,7 +202,6 @@ impl BasicTransformerBlock {
|
||||
n_heads,
|
||||
d_head,
|
||||
sliced_attention_size,
|
||||
use_flash_attn,
|
||||
)?;
|
||||
let ff = FeedForward::new(vs.pp("ff"), dim, None, 4)?;
|
||||
let attn2 = CrossAttention::new(
|
||||
@ -269,7 +211,6 @@ impl BasicTransformerBlock {
|
||||
n_heads,
|
||||
d_head,
|
||||
sliced_attention_size,
|
||||
use_flash_attn,
|
||||
)?;
|
||||
let norm1 = nn::layer_norm(dim, 1e-5, vs.pp("norm1"))?;
|
||||
let norm2 = nn::layer_norm(dim, 1e-5, vs.pp("norm2"))?;
|
||||
@ -338,7 +279,6 @@ impl SpatialTransformer {
|
||||
in_channels: usize,
|
||||
n_heads: usize,
|
||||
d_head: usize,
|
||||
use_flash_attn: bool,
|
||||
config: SpatialTransformerConfig,
|
||||
) -> Result<Self> {
|
||||
let inner_dim = n_heads * d_head;
|
||||
@ -364,7 +304,6 @@ impl SpatialTransformer {
|
||||
d_head,
|
||||
config.context_dim,
|
||||
config.sliced_attention_size,
|
||||
use_flash_attn,
|
||||
)?;
|
||||
transformer_blocks.push(tb)
|
||||
}
|
||||
@ -473,15 +412,10 @@ impl AttentionBlock {
|
||||
let num_heads = channels / num_head_channels;
|
||||
let group_norm =
|
||||
nn::group_norm(config.num_groups, channels, config.eps, vs.pp("group_norm"))?;
|
||||
let (q_path, k_path, v_path, out_path) = if vs.contains_tensor("to_q.weight") {
|
||||
("to_q", "to_k", "to_v", "to_out.0")
|
||||
} else {
|
||||
("query", "key", "value", "proj_attn")
|
||||
};
|
||||
let query = nn::linear(channels, channels, vs.pp(q_path))?;
|
||||
let key = nn::linear(channels, channels, vs.pp(k_path))?;
|
||||
let value = nn::linear(channels, channels, vs.pp(v_path))?;
|
||||
let proj_attn = nn::linear(channels, channels, vs.pp(out_path))?;
|
||||
let query = nn::linear(channels, channels, vs.pp("query"))?;
|
||||
let key = nn::linear(channels, channels, vs.pp("key"))?;
|
||||
let value = nn::linear(channels, channels, vs.pp("value"))?;
|
||||
let proj_attn = nn::linear(channels, channels, vs.pp("proj_attn"))?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "attn-block");
|
||||
Ok(Self {
|
||||
group_norm,
|
||||
@ -504,7 +438,6 @@ impl AttentionBlock {
|
||||
|
||||
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let in_dtype = xs.dtype();
|
||||
let residual = xs;
|
||||
let (batch, channel, height, width) = xs.dims4()?;
|
||||
let xs = self
|
||||
@ -517,13 +450,9 @@ impl AttentionBlock {
|
||||
let key_proj = self.key.forward(&xs)?;
|
||||
let value_proj = self.value.forward(&xs)?;
|
||||
|
||||
let query_states = self
|
||||
.transpose_for_scores(query_proj)?
|
||||
.to_dtype(DType::F32)?;
|
||||
let key_states = self.transpose_for_scores(key_proj)?.to_dtype(DType::F32)?;
|
||||
let value_states = self
|
||||
.transpose_for_scores(value_proj)?
|
||||
.to_dtype(DType::F32)?;
|
||||
let query_states = self.transpose_for_scores(query_proj)?;
|
||||
let key_states = self.transpose_for_scores(key_proj)?;
|
||||
let value_states = self.transpose_for_scores(value_proj)?;
|
||||
|
||||
let scale = f64::powf((self.channels as f64) / (self.num_heads as f64), -0.25);
|
||||
let attention_scores =
|
||||
@ -532,7 +461,6 @@ impl AttentionBlock {
|
||||
let attention_probs = nn::ops::softmax(&attention_scores, D::Minus1)?;
|
||||
|
||||
let xs = attention_probs.matmul(&value_states.contiguous()?)?;
|
||||
let xs = xs.to_dtype(in_dtype)?;
|
||||
let xs = xs.transpose(1, 2)?.contiguous()?;
|
||||
let xs = xs.flatten_from(D::Minus2)?;
|
||||
let xs = self
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user