Compare commits

..

1 Commits

Author SHA1 Message Date
0bb344f798 [RFC] Start removal of VarBuilder.
- Uses `Initializer` trait instead.
- Allows more decoupling between init and load, which are very different
  ops.
- Allows more decoupling between backends (safetensors, npy, ggml,
  etc...)

This is a minimum viable change.

There are 3 kind of objects with various relations.

The `Model`:
    This is `Llama`, `Linear`, `Rms` ...
    They contain tensors (and possibly other things).  and are used to
    call `forward` basically.
    They should have no ownership of any internals like Rng state or
    actual shapes of the tensors (the tensors already own those)

The `Initializer`:
    This is a struct containing necessary information to generate new
    random tensors. Typically they should own a random generator, and
    generate different kind of random tensors based on what kind of
    `Model` they are initializing.
    This do not own any information about the `Model` itself.
    Default init stores the `Vec<Var>` for now, in order to send to the
    optimizer.

Ths `Config`:
    This is the necessary information to link between the `Model` and
    the `Initializer`. This is another struct which is a companion of
    the implementation of the initalization.
    Typical information is the shape of the tensors for simple `Model`,
    the `eps` for RMS, the `use_bias` boolean to know whether we should
    have a bias in the linear layer.

This should remove all need for `VarBuilder` during intialization, and
allow removing every initialization bit within `VarBuilder`.

Modifying `llama2-c` to follow that initialization is left on purpose
for a follow-up to keep the current PR rather small.
2023-08-16 14:39:36 +02:00
545 changed files with 5778 additions and 87997 deletions

View File

@ -1,8 +1,8 @@
[build]
[target.x86_64-unknown-linux-gnu]
rustflags = ["-C", "target-cpu=native"]
[target.aarch64-apple-darwin]
rustflags = ["-C", "target-cpu=native"]
[target.wasm32-unknown-unknown]
rustflags = ["-C", "target-feature=+simd128"]
[target.x86_64-apple-darwin]
rustflags = ["-C", "target-feature=-avx,-avx2"]

View File

@ -1,7 +0,0 @@
version: 2
updates:
- package-ecosystem: "cargo"
directory: "/"
schedule:
interval: "weekly"
open-pull-requests-limit: 5

View File

@ -5,15 +5,47 @@ on:
pull_request:
jobs:
start-runner:
name: Start self-hosted EC2 runner
runs-on: ubuntu-latest
env:
AWS_REGION: us-east-1
EC2_AMI_ID: ami-03cfed9ea28f4b002
EC2_INSTANCE_TYPE: g5.xlarge
EC2_SUBNET_ID: subnet-931b34f5,subnet-ecb993cd,subnet-943dc2d8,subnet-45371f1a,subnet-ee93e0df,subnet-fddc3dfc
EC2_SECURITY_GROUP: sg-030175c435ac141d6
outputs:
label: ${{ steps.start-ec2-runner.outputs.label }}
ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }}
steps:
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v1
with:
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
aws-region: ${{ env.AWS_REGION }}
- name: Start EC2 runner
id: start-ec2-runner
uses: philschmid/philschmid-ec2-github-runner@main
with:
mode: start
github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
ec2-image-id: ${{ env.EC2_AMI_ID }}
ec2-instance-type: ${{ env.EC2_INSTANCE_TYPE }}
subnet-id: ${{ env.EC2_SUBNET_ID }}
security-group-id: ${{ env.EC2_SECURITY_GROUP }}
aws-resource-tags: > # optional, requires additional permissions
[
{"Key": "Name", "Value": "ec2-tgi-github-runner"},
{"Key": "GitHubRepository", "Value": "${{ github.repository }}"}
]
test-cuda:
concurrency:
group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
runs-on: [single-gpu, nvidia-gpu, t4, ci]
container:
image: nvidia/cuda:12.3.1-devel-ubuntu22.04
options: --gpus 0
if: ${{ github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }}
needs: start-runner # required to start the main job when the runner is ready
runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
permissions:
contents: write
packages: write
@ -24,10 +56,32 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v3
- name: Install dependencies
run: apt-get update && apt install curl build-essential libssl-dev protobuf-compiler pkg-config -y
- name: Install Rust Stable
uses: actions-rust-lang/setup-rust-toolchain@v1
run: curl https://sh.rustup.rs -sSf | sh -s -- -y
- uses: Swatinem/rust-cache@v2
- run: apt update -y && apt install libssl-dev -y
- name: Test (cuda)
run: cargo test --features cuda
run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda
stop-runner:
name: Stop self-hosted EC2 runner
needs:
- start-runner
- test-cuda
runs-on: ubuntu-latest
env:
AWS_REGION: us-east-1
if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs
steps:
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v1
with:
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
aws-region: ${{ env.AWS_REGION }}
- name: Stop EC2 runner
uses: philschmid/philschmid-ec2-github-runner@main
with:
mode: stop
github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
label: ${{ needs.start-runner.outputs.label }}
ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}

Binary file not shown.

View File

@ -1,68 +0,0 @@
name: PyO3-CI
on:
workflow_dispatch:
push:
branches:
- main
paths:
- candle-pyo3/**
pull_request:
paths:
- candle-pyo3/**
jobs:
build_and_test:
name: Check everything builds & tests
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest] # For now, only test on Linux
steps:
- name: Checkout repository
uses: actions/checkout@v2
- name: Install Rust
uses: actions-rs/toolchain@v1
with:
toolchain: stable
- name: Install Python
uses: actions/setup-python@v4
with:
python-version: 3.11
architecture: "x64"
- name: Cache Cargo Registry
uses: actions/cache@v1
with:
path: ~/.cargo/registry
key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }}
- name: Install Protoc
uses: arduino/setup-protoc@v2
with:
version: "25.0"
repo-token: ${{ secrets.GITHUB_TOKEN }}
- name: Install
working-directory: ./candle-pyo3
run: |
python -m venv .env
source .env/bin/activate
pip install -U pip
pip install pytest maturin black
python -m maturin develop -r --features onnx
- name: Check style
working-directory: ./candle-pyo3
run: |
source .env/bin/activate
python stub.py --check
black --check .
- name: Run tests
working-directory: ./candle-pyo3
run: |
source .env/bin/activate
python -m pytest -s -v tests

11
.gitignore vendored
View File

@ -23,16 +23,9 @@ flamegraph.svg
*.dylib
*.so
*.swp
*.swo
trace-*.json
candle-wasm-examples/*/build
candle-wasm-examples/*/*.bin
candle-wasm-examples/*/*.jpeg
candle-wasm-examples/*/audios/*.wav
candle-wasm-examples/**/*.safetensors
candle-wasm-examples/**/*.gguf
candle-wasm-examples/*/*.wav
candle-wasm-examples/*/*.safetensors
candle-wasm-examples/*/package-lock.json
candle-wasm-examples/**/config*.json
.DS_Store
.idea/*

11
.vscode/settings.json vendored
View File

@ -1,11 +0,0 @@
{
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter"
},
"python.formatting.provider": "none",
"python.testing.pytestArgs": [
"candle-pyo3"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}

View File

@ -1,113 +0,0 @@
# Changelog
This documents the main changes to the `candle` crate.
## v0.3.1 - Unreleased
### Added
### Modified
## v0.3.0 - 2023-10-01
### Added
- Added the Mistral 7b v0.1 model
[983](https://github.com/huggingface/candle/pull/983).
- Quantized version of the Mistral model
[1009](https://github.com/huggingface/candle/pull/1009).
- Add the gelu-erf op and activation function
[969](https://github.com/huggingface/candle/pull/969).
- Add the mixformer/phi-v1.5 model
[930](https://github.com/huggingface/candle/pull/930).
- Add the sclice-scatter op
[927](https://github.com/huggingface/candle/pull/927).
- Add the Wuerstchen diffusion model
[911](https://github.com/huggingface/candle/pull/911).
### Modified
- Support for simd128 intrinsics in some quantized vecdots
[982](https://github.com/huggingface/candle/pull/982).
- Optimize the index-select cuda kernel
[976](https://github.com/huggingface/candle/pull/976).
- Self-contained safetensor wrappers
[946](https://github.com/huggingface/candle/pull/946).
## v0.2.2 - 2023-09-18
### Added
- Support for `top_p` sampling
[819](https://github.com/huggingface/candle/pull/819).
- T5 model including decoding
[864](https://github.com/huggingface/candle/pull/864).
- 1-d upsampling
[839](https://github.com/huggingface/candle/pull/839).
### Modified
- Bugfix for conv2d
[820](https://github.com/huggingface/candle/pull/820).
- Support tensor based indexing using `.i`
[842](https://github.com/huggingface/candle/pull/842).
## v0.2.1 - 2023-09-11
### Added
- Add some RNNs (GRU and LSTM) in `candle-nn`
[674](https://github.com/huggingface/candle/pull/674),
[688](https://github.com/huggingface/candle/pull/688).
- gguf v2 support
[725](https://github.com/huggingface/candle/pull/725).
- Quantized llama example in Python using the pyo3 api
[716](https://github.com/huggingface/candle/pull/716).
- `candle-nn` layer for conv2d-transposed
[760](https://github.com/huggingface/candle/pull/760).
- Add the Segment-Anything Model (SAM) as an example
[773](https://github.com/huggingface/candle/pull/773).
- TinyViT backbone for the segment anything example
[787](https://github.com/huggingface/candle/pull/787).
- Shape with holes support
[770](https://github.com/huggingface/candle/pull/770).
### Modified
- Dilations are now supported in conv-transpose2d.
[671](https://github.com/huggingface/candle/pull/671).
- Interactive mode for the quantized model
[690](https://github.com/huggingface/candle/pull/690).
- Faster softmax operation
[747](https://github.com/huggingface/candle/pull/747).
- Faster convolution operations on CPU and CUDA via im2col
[802](https://github.com/huggingface/candle/pull/802).
- Moving some models to a more central location
[796](https://github.com/huggingface/candle/pull/796).
## v0.2.0 - 2023-08-30
### Added
- Add the powf op
[664](https://github.com/huggingface/candle/pull/664).
- Stable Diffusion XL support
[647](https://github.com/huggingface/candle/pull/647).
- Add the conv-transpose2d op
[635](https://github.com/huggingface/candle/pull/635).
- Refactor the VarBuilder api
[627](https://github.com/huggingface/candle/pull/627).
- Add some quantization command
[625](https://github.com/huggingface/candle/pull/625).
- Support more quantized types, e.g. Q2K, Q4K, Q5K...
[586](https://github.com/huggingface/candle/pull/586).
- Add pose estimation to the yolo example
[589](https://github.com/huggingface/candle/pull/589).
- Api to write GGUF files
[585](https://github.com/huggingface/candle/pull/585).
- Support more quantization types
[580](https://github.com/huggingface/candle/pull/580).
- Add EfficientNet as an example Computer Vision model
[572](https://github.com/huggingface/candle/pull/572).
- Add a group parameter to convolutions
[566](https://github.com/huggingface/candle/pull/566).
- New dtype: int64
[563](https://github.com/huggingface/candle/pull/563).
- Handling of the GGUF file format.
[559](https://github.com/huggingface/candle/pull/559).
## v0.1.2 - 2023-08-21

View File

@ -3,23 +3,19 @@ members = [
"candle-core",
"candle-datasets",
"candle-examples",
"candle-book",
"candle-nn",
"candle-pyo3",
"candle-transformers",
"candle-wasm-examples/*",
"candle-wasm-tests",
"candle-wasm-examples/llama2-c",
"candle-wasm-examples/whisper",
]
exclude = [
"candle-flash-attn",
"candle-kernels",
"candle-metal-kernels",
"candle-onnx",
"candle-flash-attn",
"candle-kernels",
]
resolver = "2"
[workspace.package]
version = "0.4.0"
version = "0.1.1"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@ -31,46 +27,32 @@ license = "MIT OR Apache-2.0"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.4.0" }
candle-datasets = { path = "./candle-datasets", version = "0.4.0" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.4.0" }
candle-kernels = { path = "./candle-kernels", version = "0.4.0" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.4.0" }
candle-nn = { path = "./candle-nn", version = "0.4.0" }
candle-onnx = { path = "./candle-onnx", version = "0.4.0" }
candle-transformers = { path = "./candle-transformers", version = "0.4.0" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.10.0", features = ["f16"] }
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = "0.3.0"
cudarc = { version = "0.9.14", features = ["f16"] }
# TODO: Switch back to the official gemm implementation once it has caught up.
gemm = { version = "0.15.6", package = "candle-gemm" }
hf-hub = "0.2.0"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
imageproc = { version = "0.23.0", default-features = false }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
libc = { version = "0.2.147" }
log = "0.4"
memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
memmap2 = "0.7.1"
num_cpus = "1.15.0"
num-traits = "0.2.15"
parquet = { version = "50.0.0" }
rand = "0.8.5"
rand_distr = "0.4.3"
rayon = "1.7.0"
rusttype = { version = "0.9", default-features = false }
safetensors = "0.4.1"
safetensors = "0.3.1"
serde = { version = "1.0.171", features = ["derive"] }
serde_plain = "1.0.2"
serde_json = "1.0.99"
thiserror = "1"
tokenizers = { version = "0.15.0", default-features = false }
tokenizers = { version = "0.13.4", default-features = false }
tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"
wav = "1.0.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "0.6.6", default-features = false }
metal = { version = "0.27.0", features = ["mps"]}
[profile.release-with-debug]
inherits = "release"

View File

@ -1,5 +1,3 @@
.PHONY: clean-ptx clean test
clean-ptx:
find target -name "*.ptx" -type f -delete
echo "" > candle-kernels/src/lib.rs
@ -13,4 +11,8 @@ clean:
test:
cargo test
pyo3-test:
cargo build --profile=release-with-debug --package candle-pyo3
python3 candle-pyo3/test.py
all: test

281
README.md
View File

@ -1,5 +1,5 @@
# candle
[![discord server](https://dcbadge.vercel.app/api/server/hugging-face-879548962464493619)](https://discord.gg/hugging-face-879548962464493619)
[![discord server](https://dcbadge.vercel.app/api/server/hugging-face-879548962464493619)](https://discord.com/channels/879548962464493619/1136218819447238726)
[![Latest version](https://img.shields.io/crates/v/candle-core.svg)](https://crates.io/crates/candle-core)
[![Documentation](https://docs.rs/candle-core/badge.svg)](https://docs.rs/candle-core)
![License](https://img.shields.io/crates/l/candle-core.svg)
@ -7,167 +7,57 @@
Candle is a minimalist ML framework for Rust with a focus on performance (including GPU support)
and ease of use. Try our online demos:
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
[LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2),
[T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm),
[yolo](https://huggingface.co/spaces/lmz/candle-yolo),
[Segment
Anything](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).
[llama2](https://huggingface.co/spaces/lmz/candle-llama2).
## Get started
Make sure that you have [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) correctly installed as described in [**Installation**](https://huggingface.github.io/candle/guide/installation.html).
Let's see how to run a simple matrix multiplication.
Write the following to your `myapp/src/main.rs` file:
```rust
use candle_core::{Device, Tensor};
let a = Tensor::randn(0f32, 1., (2, 3), &Device::Cpu)?;
let b = Tensor::randn(0f32, 1., (3, 4), &Device::Cpu)?;
fn main() -> Result<(), Box<dyn std::error::Error>> {
let device = Device::Cpu;
let a = Tensor::randn(0f32, 1., (2, 3), &device)?;
let b = Tensor::randn(0f32, 1., (3, 4), &device)?;
let c = a.matmul(&b)?;
println!("{c}");
Ok(())
}
let c = a.matmul(&b)?;
println!("{c}");
```
`cargo run` should display a tensor of shape `Tensor[[2, 4], f32]`.
Having installed `candle` with Cuda support, simply define the `device` to be on GPU:
```diff
- let device = Device::Cpu;
+ let device = Device::new_cuda(0)?;
```
For more advanced examples, please have a look at the following section.
## Check out our examples
These online demos run entirely in your browser:
- [yolo](https://huggingface.co/spaces/lmz/candle-yolo): pose estimation and
object recognition.
- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): speech recognition.
- [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2): text generation.
- [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm): text generation.
- [Phi-1.5, and Phi-2](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation.
- [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm): Image segmentation.
- [BLIP](https://huggingface.co/spaces/radames/Candle-BLIP-Image-Captioning): image captioning.
We also provide a some command line based examples using state of the art models:
- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM, includes
the SOLAR-10.7B variant.
- [Falcon](./candle-examples/examples/falcon/): general LLM.
- [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
pre-trained on 1T tokens of English and code datasets. Also supports
StableLM-2, a 1.6b LLM trained on 2T tokens, as well as the code variants.
- [Mamba](./candle-examples/examples/mamba/): an inference only
implementation of the Mamba state space model.
- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
better performance than all publicly available 13b models as of 2023-09-28.
- [Mixtral8x7b-v0.1](./candle-examples/examples/mixtral/): a sparse mixture of
experts 8x7b general LLM with better performance than a Llama 2 70B model with
much faster inference.
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
(English/Chinese) general LLMs with 6b and 34b parameters.
- [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of
the LLaMA model using the same quantization techniques as
[llama.cpp](https://github.com/ggerganov/llama.cpp).
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/quantized/assets/aoc.gif" width="600">
- [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
image generative model, support for the 1.5, 2.1, SDXL 1.0 and Turbo versions.
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg" width="200">
- [Wuerstchen](./candle-examples/examples/wuerstchen/): another text to
image generative model.
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/wuerstchen/assets/cat.jpg" width="200">
- [yolo-v3](./candle-examples/examples/yolo-v3/) and
[yolo-v8](./candle-examples/examples/yolo-v8/): object detection and pose
estimation models.
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/yolo-v8/assets/bike.od.jpg" width="200"><img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/yolo-v8/assets/bike.pose.jpg" width="200">
- [segment-anything](./candle-examples/examples/segment-anything/): image
segmentation model with prompt.
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">
Check out our [examples](./candle-examples/examples/):
- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/),
[JinaBert](./candle-examples/examples/jina-bert/) : useful for sentence embeddings.
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
using self-supervision (can be used for imagenet classification, depth
evaluation, segmentation).
- [VGG](./candle-examples/examples/vgg/),
[RepVGG](./candle-examples/examples/repvgg): computer vision models.
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
generate captions for an image.
- [TrOCR](./candle-examples/examples/trocr/): a transformer OCR model, with
dedicated submodels for hand-writing and printed recognition.
- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
model, generates the translated text from the input text.
- [Llama and Llama-v2](./candle-examples/examples/llama/): general LLM.
- [Falcon](./candle-examples/examples/falcon/): general LLM.
- [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code
generation.
- [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
image generative model, yet to be optimized.
Run them using commands like:
Run them using the following commands:
```
cargo run --example quantized --release
cargo run --example whisper --release
cargo run --example llama --release
cargo run --example falcon --release
cargo run --example bert --release
cargo run --example bigcode --release
cargo run --example stable-diffusion --release --features image -- --prompt "a rusty robot holding a fire torch"
```
In order to use **CUDA** add `--features cuda` to the example command line. If
you have cuDNN installed, use `--features cudnn` for even more speedups.
In order to use **CUDA** add `--features cuda` to the example command line.
There are also some wasm examples for whisper and
[llama2.c](https://github.com/karpathy/llama2.c). You can either build them with
`trunk` or try them online:
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
[llama2](https://huggingface.co/spaces/lmz/candle-llama2),
[T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm),
[Phi-1.5, and Phi-2](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm),
[Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).
[llama2](https://huggingface.co/spaces/lmz/candle-llama2).
For LLaMA2, run the following command to retrieve the weight files and start a
For llama2, run the following command to retrieve the weight files and start a
test server:
```bash
cd candle-wasm-examples/llama2-c
wget https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/model.bin
wget https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/tokenizer.json
trunk serve --release --port 8081
trunk serve --release --public-url /candle-llama2/ --port 8081
```
And then head over to
[http://localhost:8081/](http://localhost:8081/).
<!--- ANCHOR: useful_libraries --->
## Useful External Resources
- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): A
very detailed tutorial showing how to convert a PyTorch model to Candle.
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): Efficient and
ergonomic LoRA implementation for Candle. `candle-lora` has
out-of-the-box LoRA support for many models from Candle, which can be found
[here](https://github.com/EricLBuehler/candle-lora/tree/master/candle-lora-transformers/examples).
- [`optimisers`](https://github.com/KGrewal1/optimisers): A collection of optimisers
including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
- [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
serving local LLMs including an OpenAI compatible API server.
- [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle.
- [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
- [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem.
If you have an addition to this list, please submit a pull request.
<!--- ANCHOR_END: useful_libraries --->
[http://localhost:8081/candle-llama2](http://localhost:8081/candle-llama2).
<!--- ANCHOR: features --->
@ -180,42 +70,11 @@ If you have an addition to this list, please submit a pull request.
- Optimized CPU backend with optional MKL support for x86 and Accelerate for macs.
- CUDA backend for efficiently running on GPUs, multiple GPU distribution via NCCL.
- WASM support, run your models in a browser.
- Included models.
- Language Models.
- LLaMA v1 and v2 with variants such as SOLAR-10.7B.
- Falcon.
- StarCoder.
- Phi 1, 1.5, and 2.
- Mamba, Minimal Mamba
- Mistral 7b v0.1.
- Mixtral 8x7b v0.1.
- StableLM-3B-4E1T, StableLM-2-1.6B, Stable-Code-3B.
- Replit-code-v1.5-3B.
- Bert.
- Yi-6B and Yi-34B.
- Quantized LLMs.
- Llama 7b, 13b, 70b, as well as the chat and code variants.
- Mistral 7b, and 7b instruct.
- Mixtral 8x7b.
- Zephyr 7b a and b (Mistral-7b based).
- OpenChat 3.5 (Mistral-7b based).
- Text to text.
- T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
- Marian MT (Machine Translation).
- Whisper (multi-lingual support).
- Text to image.
- Stable Diffusion v1.5, v2.1, XL v1.0.
- Wurstchen v2.
- Image to text.
- BLIP.
- TrOCR.
- Computer Vision Models.
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT.
- yolo-v3, yolo-v8.
- Segment-Anything Model (SAM).
- File formats: load models from safetensors, npz, ggml, or PyTorch files.
- Model support out of the box.
- LLMs: Llama v1 and v2, Falcon, StarCoder.
- Whisper.
- Stable Diffusion.
- Serverless (on CPU), small and fast deployments.
- Quantization support using the llama.cpp quantized types.
<!--- ANCHOR_END: features --->
@ -232,7 +91,7 @@ Cheatsheet:
| Operations | `tensor.view((2, 2))` | `tensor.reshape((2, 2))?` |
| Operations | `a.matmul(b)` | `a.matmul(&b)?` |
| Arithmetic | `a + b` | `&a + &b` |
| Device | `tensor.to(device="cuda")` | `tensor.to_device(&Device::new_cuda(0)?)?` |
| Device | `tensor.to(device="cuda")` | `tensor.to_device(&Device::Cuda(0))?` |
| Dtype | `tensor.to(dtype=torch.float16)` | `tensor.to_dtype(&DType::F16)?` |
| Saving | `torch.save({"A": A}, "model.bin")` | `candle::safetensors::save(&HashMap::from([("A", A)]), "model.safetensors")?` |
| Loading | `weights = torch.load("model.bin")` | `candle::safetensors::load("model.safetensors", &device)` |
@ -249,7 +108,6 @@ Cheatsheet:
- [candle-datasets](./candle-datasets/): Datasets and data loaders.
- [candle-transformers](./candle-transformers): transformers-related utilities.
- [candle-flash-attn](./candle-flash-attn): Flash attention v2 layer.
- [candle-onnx](./candle-onnx/): ONNX model evaluation.
## FAQ
@ -286,97 +144,34 @@ Finally, Rust is cool! A lot of the HF ecosystem already has Rust crates, like [
#### Missing symbols when compiling with the mkl feature.
If you get some missing symbols when compiling binaries/tests using the mkl
or accelerate features, e.g. for mkl you get:
features, e.g.:
```
= note: /usr/bin/ld: (....o): in function `blas::sgemm':
.../blas-0.22.0/src/lib.rs:1944: undefined reference to `sgemm_' collect2: error: ld returned 1 exit status
= note: some `extern` functions couldn't be found; some native libraries may need to be installed or have their path specified
= note: use the `-l` flag to specify native libraries to link
= note: use the `cargo:rustc-link-lib` directive to specify the native libraries to link with Cargo
```
or for accelerate:
```
Undefined symbols for architecture arm64:
"_dgemm_", referenced from:
candle_core::accelerate::dgemm::h1b71a038552bcabe in libcandle_core...
"_sgemm_", referenced from:
candle_core::accelerate::sgemm::h2cf21c592cba3c47 in libcandle_core...
ld: symbol(s) not found for architecture arm64
= note: use the `cargo:rustc-link-lib` directive to specify the native libraries to link with Cargo (see https://doc.rust-lang.org/cargo/reference/build-scripts.html#cargorustc-link-libkindname)
```
This is likely due to a missing linker flag that was needed to enable the mkl library. You
can try adding the following for mkl at the top of your binary:
```rust
can try adding the following at the top of your binary:
```
extern crate intel_mkl_src;
```
or for accelerate:
```rust
extern crate accelerate_src;
```
#### Cannot run the LLaMA examples: access to source requires login credentials
#### Cannot run llama example : access to source requires login credentials
```
Error: request error: https://huggingface.co/meta-llama/Llama-2-7b-hf/resolve/main/tokenizer.json: status code 401
```
This is likely because you're not permissioned for the LLaMA-v2 model. To fix
this, you have to register on the huggingface-hub, accept the [LLaMA-v2 model
This is likely because you're not permissioned for the llama-v2 model. To fix
this, you have to register on the huggingface-hub, accept the [llama-v2 model
conditions](https://huggingface.co/meta-llama/Llama-2-7b-hf), and set up your
authentication token. See issue
[#350](https://github.com/huggingface/candle/issues/350) for more details.
#### Missing cute/cutlass headers when compiling flash-attn
```
In file included from kernels/flash_fwd_launch_template.h:11:0,
from kernels/flash_fwd_hdim224_fp16_sm80.cu:5:
kernels/flash_fwd_kernel.h:8:10: fatal error: cute/algorithm/copy.hpp: No such file or directory
#include <cute/algorithm/copy.hpp>
^~~~~~~~~~~~~~~~~~~~~~~~~
compilation terminated.
Error: nvcc error while compiling:
```
[cutlass](https://github.com/NVIDIA/cutlass) is provided as a git submodule so you may want to run the following command to check it in properly.
```bash
git submodule update --init
```
#### Compiling with flash-attention fails
```
/usr/include/c++/11/bits/std_function.h:530:146: error: parameter packs not expanded with ...:
```
This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the CANDLE_NVCC_CCBIN environment variable.
```
env CANDLE_NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
```
#### Linking error on windows when running rustdoc or mdbook tests
```
Couldn't compile the test.
---- .\candle-book\src\inference\hub.md - Using_the_hub::Using_in_a_real_model_ (line 50) stdout ----
error: linking with `link.exe` failed: exit code: 1181
//very long chain of linking
= note: LINK : fatal error LNK1181: cannot open input file 'windows.0.48.5.lib'
```
Make sure you link all native libraries that might be located outside a project target, e.g., to run mdbook tests, you should run:
```
mdbook test candle-book -L .\target\debug\deps\ `
-L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.42.2\lib `
-L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.48.5\lib
```
#### Extremely slow model load time with WSL
This may be caused by the models being loaded from `/mnt/c`, more details on
[stackoverflow](https://stackoverflow.com/questions/68972448/why-is-wsl-extremely-slow-when-compared-with-native-windows-npm-yarn-processing).
#### Tracking down errors
You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle

View File

@ -1,49 +0,0 @@
[package]
name = "candle-book"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true
readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { workspace = true }
candle-datasets = { workspace = true }
candle-nn = { workspace = true }
candle-transformers = { workspace = true }
candle-flash-attn = { workspace = true, optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
num-traits = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
image = { workspace = true, optional = true }
anyhow = { workspace = true }
tokio = "1.29.1"
[dev-dependencies]
byteorder = { workspace = true }
hf-hub = { workspace = true, features=["tokio"]}
clap = { workspace = true }
memmap2 = { workspace = true }
rand = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
tracing = { workspace = true }
tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
wav = { workspace = true }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
parquet = { workspace = true }
image = { workspace = true }
[build-dependencies]
anyhow = { workspace = true }
[features]
default = []

View File

@ -10,14 +10,9 @@
# Reference Guide
- [Running a model](inference/inference.md)
- [Running a model](inference/README.md)
- [Using the hub](inference/hub.md)
- [Error management](error_manage.md)
- [Training](training/training.md)
- [Simplified](training/simplified.md)
- [MNIST](training/mnist.md)
- [Fine-tuning]()
- [Serialization]()
- [Error management]()
- [Advanced Cuda usage]()
- [Writing a custom kernel]()
- [Porting a custom kernel]()
@ -26,3 +21,7 @@
- [Creating a WASM app]()
- [Creating a REST api webserver]()
- [Creating a desktop Tauri app]()
- [Training]()
- [MNIST]()
- [Fine-tuning]()
- [Serialization]()

View File

@ -29,7 +29,7 @@ After adding `RUST_BACKTRACE=1`:
Error: WithBacktrace { inner: ShapeMismatchBinaryOp { lhs: [1, 784], rhs: [1, 784], op: "matmul" }, backtrace: Backtrace [{ fn: "candle::error::Error::bt", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/error.rs", line: 200 }, { fn: "candle::tensor::Tensor::matmul", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/tensor.rs", line: 816 }, { fn: "myapp::main", file: "./src/main.rs", line: 29 }, { fn: "core::ops::function::FnOnce::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 250 }, { fn: "std::sys_common::backtrace::__rust_begin_short_backtrace", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/sys_common/backtrace.rs", line: 135 }, { fn: "std::rt::lang_start::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 166 }, { fn: "core::ops::function::impls::<impl core::ops::function::FnOnce<A> for &F>::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 284 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::rt::lang_start", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 165 }, { fn: "main" }, { fn: "__libc_start_main" }, { fn: "_start" }] }
```
Not super pretty at the moment, but we can see error occurred on `{ fn: "myapp::main", file: "./src/main.rs", line: 29 }`
Not super pretty at the moment, but we can see error occured on `{ fn: "myapp::main", file: "./src/main.rs", line: 29 }`
Another thing to note, is that since Rust is compiled it is not necessarily as easy to recover proper stacktraces

View File

@ -6,7 +6,7 @@ Open `src/main.rs` and fill in this content:
```rust
# extern crate candle_core;
use candle_core::{Device, Result, Tensor};
use candle_core::{DType, Device, Result, Tensor};
struct Model {
first: Tensor,
@ -25,11 +25,11 @@ fn main() -> Result<()> {
// Use Device::new_cuda(0)?; to use the GPU.
let device = Device::Cpu;
let first = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
let second = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
let first = Tensor::zeros((784, 100), DType::F32, &device)?;
let second = Tensor::zeros((100, 10), DType::F32, &device)?;
let model = Model { first, second };
let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
let digit = model.forward(&dummy_image)?;
println!("Digit {digit:?} digit");
@ -50,7 +50,7 @@ the classical `Linear` layer. We can do as such
```rust
# extern crate candle_core;
# use candle_core::{Device, Result, Tensor};
# use candle_core::{DType, Device, Result, Tensor};
struct Linear{
weight: Tensor,
bias: Tensor,
@ -80,7 +80,7 @@ This will change the model running code into a new function
```rust
# extern crate candle_core;
# use candle_core::{Device, Result, Tensor};
# use candle_core::{DType, Device, Result, Tensor};
# struct Linear{
# weight: Tensor,
# bias: Tensor,
@ -110,15 +110,15 @@ fn main() -> Result<()> {
let device = Device::cuda_if_available(0)?;
// Creating a dummy model
let weight = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
let bias = Tensor::randn(0f32, 1.0, (100, ), &device)?;
let weight = Tensor::zeros((784, 100), DType::F32, &device)?;
let bias = Tensor::zeros((100, ), DType::F32, &device)?;
let first = Linear{weight, bias};
let weight = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
let bias = Tensor::randn(0f32, 1.0, (10, ), &device)?;
let weight = Tensor::zeros((100, 10), DType::F32, &device)?;
let bias = Tensor::zeros((10, ), DType::F32, &device)?;
let second = Linear{weight, bias};
let model = Model { first, second };
let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
// Inference on the model
let digit = model.forward(&dummy_image)?;
@ -146,8 +146,8 @@ And rewrite our examples using it
```rust
# extern crate candle_core;
# extern crate candle_nn;
use candle_core::{Device, Result, Tensor};
use candle_nn::{Linear, Module};
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::Linear;
struct Model {
first: Linear,
@ -167,15 +167,15 @@ fn main() -> Result<()> {
let device = Device::Cpu;
// This has changed (784, 100) -> (100, 784) !
let weight = Tensor::randn(0f32, 1.0, (100, 784), &device)?;
let bias = Tensor::randn(0f32, 1.0, (100, ), &device)?;
let weight = Tensor::zeros((100, 784), DType::F32, &device)?;
let bias = Tensor::zeros((100, ), DType::F32, &device)?;
let first = Linear::new(weight, Some(bias));
let weight = Tensor::randn(0f32, 1.0, (10, 100), &device)?;
let bias = Tensor::randn(0f32, 1.0, (10, ), &device)?;
let weight = Tensor::zeros((10, 100), DType::F32, &device)?;
let bias = Tensor::zeros((10, ), DType::F32, &device)?;
let second = Linear::new(weight, Some(bias));
let model = Model { first, second };
let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
let digit = model.forward(&dummy_image)?;
println!("Digit {digit:?} digit");
@ -188,8 +188,8 @@ Feel free to modify this example to use `Conv2d` to create a classical convnet i
Now that we have the running dummy code we can get to more advanced topics:
- [For PyTorch users](../guide/cheatsheet.md)
- [Running existing models](../inference/inference.md)
- [Training models](../training/training.md)
- [For PyTorch users](./guide/cheatsheet.md)
- [Running existing models](./inference/README.md)
- [Training models](./training/README.md)

View File

@ -1,46 +1,6 @@
# Installation
**With Cuda support**:
1. First, make sure that Cuda is correctly installed.
- `nvcc --version` should print information about your Cuda compiler driver.
- `nvidia-smi --query-gpu=compute_cap --format=csv` should print your GPUs compute capability, e.g. something
like:
```bash
compute_cap
8.9
```
You can also compile the Cuda kernels for a specific compute cap using the
`CUDA_COMPUTE_CAP=<compute cap>` environment variable.
If any of the above commands errors out, please make sure to update your Cuda version.
2. Create a new app and add [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) with Cuda support.
Start by creating a new cargo:
```bash
cargo new myapp
cd myapp
```
Make sure to add the `candle-core` crate with the cuda feature:
```bash
cargo add --git https://github.com/huggingface/candle.git candle-core --features "cuda"
```
Run `cargo build` to make sure everything can be correctly built.
```bash
cargo build
```
**Without Cuda support**:
Create a new app and add [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) as follows:
Start by creating a new app:
```bash
cargo new myapp
@ -48,12 +8,17 @@ cd myapp
cargo add --git https://github.com/huggingface/candle.git candle-core
```
Finally, run `cargo build` to make sure everything can be correctly built.
At this point, candle will be built **without** CUDA support.
To get CUDA support use the `cuda` feature
```bash
cargo add --git https://github.com/huggingface/candle.git candle-core --features cuda
```
You can check everything works properly:
```bash
cargo build
```
**With mkl support**
You can also see the `mkl` feature which could be interesting to get faster inference on CPU. [Using mkl](./advanced/mkl.md)

View File

@ -39,7 +39,7 @@ cargo add hf-hub --features tokio
```rust,ignore
# This is tested directly in examples crate because it needs external dependencies unfortunately:
# See [this](https://github.com/rust-lang/mdBook/issues/706)
{{#include ../lib.rs:book_hub_1}}
{{#include ../../../candle-examples/src/lib.rs:book_hub_1}}
```
@ -58,7 +58,7 @@ Now that we have our weights, we can use them in our bert architecture:
#
# let weights = repo.get("model.safetensors").unwrap();
use candle_core::{Device, Tensor, DType};
use candle_nn::{Linear, Module};
use candle_nn::Linear;
let weights = candle_core::safetensors::load(weights, &Device::Cpu).unwrap();
@ -81,7 +81,7 @@ For more efficient loading, instead of reading the file, you could use [`memmap2
and will definitely be slower on network mounted disk, because it will issue more read calls.
```rust,ignore
{{#include ../lib.rs:book_hub_2}}
{{#include ../../../candle-examples/src/lib.rs:book_hub_2}}
```
**Note**: This operation is **unsafe**. [See the safety notice](https://docs.rs/memmap2/latest/memmap2/struct.Mmap.html#safety).
@ -100,5 +100,5 @@ cargo add safetensors
```rust,ignore
{{#include ../lib.rs:book_hub_3}}
{{#include ../../../candle-examples/src/lib.rs:book_hub_3}}
```

View File

@ -1,199 +0,0 @@
#[cfg(test)]
pub mod simplified;
#[cfg(test)]
mod tests {
use anyhow::Result;
use candle::{DType, Device, Tensor};
use parquet::file::reader::SerializedFileReader;
// NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
#[rustfmt::skip]
#[tokio::test]
async fn book_hub_1() {
// ANCHOR: book_hub_1
use candle::Device;
use hf_hub::api::tokio::Api;
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights_filename = repo.get("model.safetensors").await.unwrap();
let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap();
// ANCHOR_END: book_hub_1
assert_eq!(weights.len(), 206);
}
#[rustfmt::skip]
#[test]
fn book_hub_2() {
{
// ANCHOR: book_hub_2
use candle::Device;
use hf_hub::api::sync::Api;
use memmap2::Mmap;
use std::fs;
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights_filename = repo.get("model.safetensors").unwrap();
let file = fs::File::open(weights_filename).unwrap();
let mmap = unsafe { Mmap::map(&file).unwrap() };
let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap();
// ANCHOR_END: book_hub_2
assert_eq!(weights.len(), 206);
}
// #[rustfmt::skip]
// #[test]
// fn book_hub_3() {
{
// ANCHOR: book_hub_3
use candle::{DType, Device, Tensor};
use hf_hub::api::sync::Api;
use memmap2::Mmap;
use safetensors::slice::IndexOp;
use safetensors::SafeTensors;
use std::fs;
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights_filename = repo.get("model.safetensors").unwrap();
let file = fs::File::open(weights_filename).unwrap();
let mmap = unsafe { Mmap::map(&file).unwrap() };
// Use safetensors directly
let tensors = SafeTensors::deserialize(&mmap[..]).unwrap();
let view = tensors
.tensor("bert.encoder.layer.0.attention.self.query.weight")
.unwrap();
// We're going to load shard with rank 1, within a world_size of 4
// We're going to split along dimension 0 doing VIEW[start..stop, :]
let rank = 1;
let world_size = 4;
let dim = 0;
let dtype = view.dtype();
let mut tp_shape = view.shape().to_vec();
let size = tp_shape[0];
if size % world_size != 0 {
panic!("The dimension is not divisble by `world_size`");
}
let block_size = size / world_size;
let start = rank * block_size;
let stop = (rank + 1) * block_size;
// Everything is expressed in tensor dimension
// bytes offsets is handled automatically for safetensors.
let iterator = view.slice(start..stop).unwrap();
tp_shape[dim] = block_size;
// Convert safetensors Dtype to candle DType
let dtype: DType = dtype.try_into().unwrap();
// TODO: Implement from_buffer_iterator so we can skip the extra CPU alloc.
let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap();
// ANCHOR_END: book_hub_3
assert_eq!(view.shape(), &[768, 768]);
assert_eq!(tp_tensor.dims(), &[192, 768]);
}
}
#[rustfmt::skip]
#[test]
fn book_training_1() -> Result<()>{
// ANCHOR: book_training_1
use hf_hub::{api::sync::Api, Repo, RepoType};
let dataset_id = "mnist".to_string();
let api = Api::new()?;
let repo = Repo::with_revision(
dataset_id,
RepoType::Dataset,
"refs/convert/parquet".to_string(),
);
let repo = api.repo(repo);
let test_parquet_filename = repo.get("mnist/test/0000.parquet")?;
let train_parquet_filename = repo.get("mnist/train/0000.parquet")?;
let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
// ANCHOR_END: book_training_1
// Ignore unused
let _train = train_parquet;
// ANCHOR: book_training_2
for row in test_parquet {
for (idx, (name, field)) in row?.get_column_iter().enumerate() {
println!("Column id {idx}, name {name}, value {field}");
}
}
// ANCHOR_END: book_training_2
let test_parquet_filename = repo.get("mnist/test/0000.parquet")?;
let train_parquet_filename = repo.get("mnist/train/0000.parquet")?;
let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
// ANCHOR: book_training_3
let test_samples = 10_000;
let mut test_buffer_images: Vec<u8> = Vec::with_capacity(test_samples * 784);
let mut test_buffer_labels: Vec<u8> = Vec::with_capacity(test_samples);
for row in test_parquet{
for (_name, field) in row?.get_column_iter() {
if let parquet::record::Field::Group(subrow) = field {
for (_name, field) in subrow.get_column_iter() {
if let parquet::record::Field::Bytes(value) = field {
let image = image::load_from_memory(value.data()).unwrap();
test_buffer_images.extend(image.to_luma8().as_raw());
}
}
}else if let parquet::record::Field::Long(label) = field {
test_buffer_labels.push(*label as u8);
}
}
}
let test_images = (Tensor::from_vec(test_buffer_images, (test_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
let test_labels = Tensor::from_vec(test_buffer_labels, (test_samples, ), &Device::Cpu)?;
let train_samples = 60_000;
let mut train_buffer_images: Vec<u8> = Vec::with_capacity(train_samples * 784);
let mut train_buffer_labels: Vec<u8> = Vec::with_capacity(train_samples);
for row in train_parquet{
for (_name, field) in row?.get_column_iter() {
if let parquet::record::Field::Group(subrow) = field {
for (_name, field) in subrow.get_column_iter() {
if let parquet::record::Field::Bytes(value) = field {
let image = image::load_from_memory(value.data()).unwrap();
train_buffer_images.extend(image.to_luma8().as_raw());
}
}
}else if let parquet::record::Field::Long(label) = field {
train_buffer_labels.push(*label as u8);
}
}
}
let train_images = (Tensor::from_vec(train_buffer_images, (train_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
let train_labels = Tensor::from_vec(train_buffer_labels, (train_samples, ), &Device::Cpu)?;
let mnist = candle_datasets::vision::Dataset {
train_images,
train_labels,
test_images,
test_labels,
labels: 10,
};
// ANCHOR_END: book_training_3
assert_eq!(mnist.test_images.dims(), &[10_000, 784]);
assert_eq!(mnist.test_labels.dims(), &[10_000]);
assert_eq!(mnist.train_images.dims(), &[60_000, 784]);
assert_eq!(mnist.train_labels.dims(), &[60_000]);
Ok(())
}
}

View File

@ -1,196 +0,0 @@
//! #A simplified example in Rust of training a neural network and then using it based on the Candle Framework by Hugging Face.
//! Author: Evgeny Igumnov 2023 igumnovnsk@gmail.com
//! This program implements a neural network to predict the winner of the second round of elections based on the results of the first round.
//!
//! ##Basic moments:
//!
//! A multilayer perceptron with two hidden layers is used. The first hidden layer has 4 neurons, the second has 2 neurons.
//! The input is a vector of 2 numbers - the percentage of votes for the first and second candidates in the first stage.
//! The output is the number 0 or 1, where 1 means that the first candidate will win in the second stage, 0 means that he will lose.
//! For training, samples with real data on the results of the first and second stages of different elections are used.
//! The model is trained by backpropagation using gradient descent and the cross-entropy loss function.
//! Model parameters (weights of neurons) are initialized randomly, then optimized during training.
//! After training, the model is tested on a deferred sample to evaluate the accuracy.
//! If the accuracy on the test set is below 100%, the model is considered underfit and the learning process is repeated.
//! Thus, this neural network learns to find hidden relationships between the results of the first and second rounds of voting in order to make predictions for new data.
#[rustfmt::skip]
mod tests {
use candle::{DType, Result, Tensor, D, Device};
use candle_nn::{loss, ops, Linear, Module, VarBuilder, VarMap, Optimizer};
// ANCHOR: book_training_simplified1
const VOTE_DIM: usize = 2;
const RESULTS: usize = 1;
const EPOCHS: usize = 10;
const LAYER1_OUT_SIZE: usize = 4;
const LAYER2_OUT_SIZE: usize = 2;
const LEARNING_RATE: f64 = 0.05;
#[derive(Clone)]
pub struct Dataset {
pub train_votes: Tensor,
pub train_results: Tensor,
pub test_votes: Tensor,
pub test_results: Tensor,
}
struct MultiLevelPerceptron {
ln1: Linear,
ln2: Linear,
ln3: Linear,
}
impl MultiLevelPerceptron {
fn new(vs: VarBuilder) -> Result<Self> {
let ln1 = candle_nn::linear(VOTE_DIM, LAYER1_OUT_SIZE, vs.pp("ln1"))?;
let ln2 = candle_nn::linear(LAYER1_OUT_SIZE, LAYER2_OUT_SIZE, vs.pp("ln2"))?;
let ln3 = candle_nn::linear(LAYER2_OUT_SIZE, RESULTS + 1, vs.pp("ln3"))?;
Ok(Self { ln1, ln2, ln3 })
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.ln1.forward(xs)?;
let xs = xs.relu()?;
let xs = self.ln2.forward(&xs)?;
let xs = xs.relu()?;
self.ln3.forward(&xs)
}
}
// ANCHOR_END: book_training_simplified1
// ANCHOR: book_training_simplified3
#[tokio::test]
async fn simplified() -> anyhow::Result<()> {
let dev = Device::cuda_if_available(0)?;
let train_votes_vec: Vec<u32> = vec![
15, 10,
10, 15,
5, 12,
30, 20,
16, 12,
13, 25,
6, 14,
31, 21,
];
let train_votes_tensor = Tensor::from_vec(train_votes_vec.clone(), (train_votes_vec.len() / VOTE_DIM, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;
let train_results_vec: Vec<u32> = vec![
1,
0,
0,
1,
1,
0,
0,
1,
];
let train_results_tensor = Tensor::from_vec(train_results_vec, train_votes_vec.len() / VOTE_DIM, &dev)?;
let test_votes_vec: Vec<u32> = vec![
13, 9,
8, 14,
3, 10,
];
let test_votes_tensor = Tensor::from_vec(test_votes_vec.clone(), (test_votes_vec.len() / VOTE_DIM, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;
let test_results_vec: Vec<u32> = vec![
1,
0,
0,
];
let test_results_tensor = Tensor::from_vec(test_results_vec.clone(), test_results_vec.len(), &dev)?;
let m = Dataset {
train_votes: train_votes_tensor,
train_results: train_results_tensor,
test_votes: test_votes_tensor,
test_results: test_results_tensor,
};
let trained_model: MultiLevelPerceptron;
loop {
println!("Trying to train neural network.");
match train(m.clone(), &dev) {
Ok(model) => {
trained_model = model;
break;
},
Err(e) => {
println!("Error: {}", e);
continue;
}
}
}
let real_world_votes: Vec<u32> = vec![
13, 22,
];
let tensor_test_votes = Tensor::from_vec(real_world_votes.clone(), (1, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;
let final_result = trained_model.forward(&tensor_test_votes)?;
let result = final_result
.argmax(D::Minus1)?
.to_dtype(DType::F32)?
.get(0).map(|x| x.to_scalar::<f32>())??;
println!("real_life_votes: {:?}", real_world_votes);
println!("neural_network_prediction_result: {:?}", result);
Ok(())
}
// ANCHOR_END: book_training_simplified3
// ANCHOR: book_training_simplified2
fn train(m: Dataset, dev: &Device) -> anyhow::Result<MultiLevelPerceptron> {
let train_results = m.train_results.to_device(dev)?;
let train_votes = m.train_votes.to_device(dev)?;
let varmap = VarMap::new();
let vs = VarBuilder::from_varmap(&varmap, DType::F32, dev);
let model = MultiLevelPerceptron::new(vs.clone())?;
let mut sgd = candle_nn::SGD::new(varmap.all_vars(), LEARNING_RATE)?;
let test_votes = m.test_votes.to_device(dev)?;
let test_results = m.test_results.to_device(dev)?;
let mut final_accuracy: f32 = 0.0;
for epoch in 1..EPOCHS + 1 {
let logits = model.forward(&train_votes)?;
let log_sm = ops::log_softmax(&logits, D::Minus1)?;
let loss = loss::nll(&log_sm, &train_results)?;
sgd.backward_step(&loss)?;
let test_logits = model.forward(&test_votes)?;
let sum_ok = test_logits
.argmax(D::Minus1)?
.eq(&test_results)?
.to_dtype(DType::F32)?
.sum_all()?
.to_scalar::<f32>()?;
let test_accuracy = sum_ok / test_results.dims1()? as f32;
final_accuracy = 100. * test_accuracy;
println!("Epoch: {epoch:3} Train loss: {:8.5} Test accuracy: {:5.2}%",
loss.to_scalar::<f32>()?,
final_accuracy
);
if final_accuracy == 100.0 {
break;
}
}
if final_accuracy < 100.0 {
Err(anyhow::Error::msg("The model is not trained well enough."))
} else {
Ok(model)
}
}
// ANCHOR_END: book_training_simplified2
}

View File

@ -0,0 +1 @@
# Training

View File

@ -1,10 +1 @@
# MNIST
So we now have downloaded the MNIST parquet files, let's put them in a simple struct.
```rust,ignore
{{#include ../lib.rs:book_training_3}}
```
The parsing of the file and putting it into single tensors requires the dataset to fit the entire memory.
It is quite rudimentary, but simple enough for a small dataset like MNIST.

View File

@ -1,45 +0,0 @@
# Simplified
## How its works
This program implements a neural network to predict the winner of the second round of elections based on the results of the first round.
Basic moments:
1. A multilayer perceptron with two hidden layers is used. The first hidden layer has 4 neurons, the second has 2 neurons.
2. The input is a vector of 2 numbers - the percentage of votes for the first and second candidates in the first stage.
3. The output is the number 0 or 1, where 1 means that the first candidate will win in the second stage, 0 means that he will lose.
4. For training, samples with real data on the results of the first and second stages of different elections are used.
5. The model is trained by backpropagation using gradient descent and the cross-entropy loss function.
6. Model parameters (weights of neurons) are initialized randomly, then optimized during training.
7. After training, the model is tested on a deferred sample to evaluate the accuracy.
8. If the accuracy on the test set is below 100%, the model is considered underfit and the learning process is repeated.
Thus, this neural network learns to find hidden relationships between the results of the first and second rounds of voting in order to make predictions for new data.
```rust,ignore
{{#include ../simplified.rs:book_training_simplified1}}
```
```rust,ignore
{{#include ../simplified.rs:book_training_simplified2}}
```
```rust,ignore
{{#include ../simplified.rs:book_training_simplified3}}
```
## Example output
```bash
Trying to train neural network.
Epoch: 1 Train loss: 4.42555 Test accuracy: 0.00%
Epoch: 2 Train loss: 0.84677 Test accuracy: 33.33%
Epoch: 3 Train loss: 2.54335 Test accuracy: 33.33%
Epoch: 4 Train loss: 0.37806 Test accuracy: 33.33%
Epoch: 5 Train loss: 0.36647 Test accuracy: 100.00%
real_life_votes: [13, 22]
neural_network_prediction_result: 0.0
```

View File

@ -1,39 +0,0 @@
# Training
Training starts with data. We're going to use the huggingface hub and
start with the Hello world dataset of machine learning, MNIST.
Let's start with downloading `MNIST` from [huggingface](https://huggingface.co/datasets/mnist).
This requires [`hf-hub`](https://github.com/huggingface/hf-hub).
```bash
cargo add hf-hub
```
This is going to be very hands-on for now.
```rust,ignore
{{#include ../../../candle-examples/src/lib.rs:book_training_1}}
```
This uses the standardized `parquet` files from the `refs/convert/parquet` branch on every dataset.
Our handles are now [`parquet::file::serialized_reader::SerializedFileReader`].
We can inspect the content of the files with:
```rust,ignore
{{#include ../../../candle-examples/src/lib.rs:book_training_2}}
```
You should see something like:
```bash
Column id 1, name label, value 6
Column id 0, name image, value {bytes: [137, ....]
Column id 1, name label, value 8
Column id 0, name image, value {bytes: [137, ....]
```
So each row contains 2 columns (image, label) with image being saved as bytes.
Let's put them into a useful struct.

View File

@ -12,9 +12,7 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
candle-kernels = { workspace = true, optional = true }
candle-metal-kernels = { workspace = true, optional = true }
metal = { workspace = true, optional = true}
candle-kernels = { path = "../candle-kernels", version = "0.1.1", optional = true }
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
half = { workspace = true }
@ -28,14 +26,11 @@ rand_distr = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
thiserror = { workspace = true }
yoke = { workspace = true }
zip = { workspace = true }
[dev-dependencies]
anyhow = { workspace = true }
clap = { workspace = true }
criterion = { workspace = true }
[features]
default = []
@ -43,8 +38,3 @@ cuda = ["cudarc", "dep:candle-kernels"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal", "dep:candle-metal-kernels"]
[[bench]]
name = "bench_main"
harness = false

View File

@ -1,9 +0,0 @@
mod benchmarks;
use criterion::criterion_main;
criterion_main!(
benchmarks::affine::benches,
benchmarks::matmul::benches,
benchmarks::random::benches,
benchmarks::where_cond::benches
);

View File

@ -1,43 +0,0 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn run(a: &Tensor) {
a.affine(12.34, 56.78).unwrap();
}
fn run_affine_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
let b = 1;
let m = 1024;
let k = 1024;
let tensor = Tensor::zeros((b, m, k), dtype, &device).unwrap();
let flops = b * m * k * dtype.size_in_bytes();
let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run(black_box(&tensor));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
run_affine_benchmark(c, &device, DType::F32, "affine_f32");
run_affine_benchmark(c, &device, DType::F16, "affine_f16");
run_affine_benchmark(c, &device, DType::BF16, "affine_bf16");
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -1,44 +0,0 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn run(a: &Tensor, b: &Tensor) {
a.matmul(&b.t().unwrap()).unwrap();
}
fn run_bench(c: &mut Criterion, device: &Device) {
let b = 1;
let m = 1;
let n = 2048;
let k = 2048;
let dtype = DType::F32;
let lhs = Tensor::zeros((b, m, k), dtype, device).unwrap();
let rhs = Tensor::zeros((b, n, k), dtype, device).unwrap();
let flops = b * m * n * k;
let mut group = c.benchmark_group(device.bench_name("matmul"));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run(black_box(&lhs), black_box(&rhs));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
run_bench(c, &device);
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -1,66 +0,0 @@
pub(crate) mod affine;
pub(crate) mod matmul;
pub(crate) mod random;
pub(crate) mod where_cond;
use candle_core::{Device, Result};
pub(crate) trait BenchDevice {
fn sync(&self) -> Result<()>;
fn bench_name<S: Into<String>>(&self, name: S) -> String;
}
impl BenchDevice for Device {
fn sync(&self) -> Result<()> {
match self {
Device::Cpu => Ok(()),
Device::Cuda(device) => {
#[cfg(feature = "cuda")]
return Ok(device.synchronize()?);
#[cfg(not(feature = "cuda"))]
panic!("Cuda device without cuda feature enabled: {:?}", device)
}
Device::Metal(device) => {
#[cfg(feature = "metal")]
return Ok(device.wait_until_completed()?);
#[cfg(not(feature = "metal"))]
panic!("Metal device without metal feature enabled: {:?}", device)
}
}
}
fn bench_name<S: Into<String>>(&self, name: S) -> String {
match self {
Device::Cpu => {
let cpu_type = if cfg!(feature = "accelerate") {
"accelerate"
} else if cfg!(feature = "mkl") {
"mkl"
} else {
"cpu"
};
format!("{}_{}", cpu_type, name.into())
}
Device::Cuda(_) => format!("cuda_{}", name.into()),
Device::Metal(_) => format!("metal_{}", name.into()),
}
}
}
struct BenchDeviceHandler {
devices: Vec<Device>,
}
impl BenchDeviceHandler {
pub fn new() -> Result<Self> {
let mut devices = Vec::new();
if cfg!(feature = "metal") {
devices.push(Device::new_metal(0)?);
} else if cfg!(feature = "cuda") {
devices.push(Device::new_cuda(0)?);
}
devices.push(Device::Cpu);
Ok(Self { devices })
}
}

View File

@ -1,63 +0,0 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn rand_uniform(a: &Tensor) {
a.rand_like(-1.0, 123.0).unwrap();
}
fn rand_normal(a: &Tensor) {
a.randn_like(100.0, 15.0).unwrap();
}
fn run_random_bench(c: &mut Criterion, device: &Device) {
let b = 1;
let rows = 2048;
let cols = 2048;
let dtype = DType::F32;
let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
let flops = b * rows * cols * dtype.size_in_bytes();
let mut group = c.benchmark_group(device.bench_name("random_uniform"));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |benches| {
benches.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
rand_uniform(black_box(&tensor));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
let mut group = c.benchmark_group(device.bench_name("random_normal"));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |benches| {
benches.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
rand_normal(black_box(&tensor));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
run_random_bench(c, &device);
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -1,64 +0,0 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn run(a: &Tensor, b: &Tensor, c: &Tensor) {
a.where_cond(b, c).unwrap();
}
const fn create_cond_arr<const N: usize>() -> [u8; N] {
let mut arr = [0u8; N];
let mut i = 0;
while i < N {
arr[i] = (i % 2) as u8;
i += 1;
}
arr
}
const B: usize = 1;
const M: usize = 1024;
const K: usize = 1024;
const SIZE: usize = B * M * K;
const DATA: [u8; SIZE] = create_cond_arr::<SIZE>();
fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), &device).unwrap();
let on_true = Tensor::ones((B, M, K), dtype, &device).unwrap();
let on_false = Tensor::zeros((B, M, K), dtype, &device).unwrap();
let elements = B * M * K;
// E.g. 2 f32 tensors + 1 u8 tensor
let flops = (2 * elements * dtype.size_in_bytes()) + elements;
let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run(
black_box(&tensor),
black_box(&on_true),
black_box(&on_false),
);
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let device = BenchDeviceHandler::new().unwrap();
for d in device.devices {
run_where_cond_benchmark(c, &d, DType::F32, "where_cond_f32");
run_where_cond_benchmark(c, &d, DType::BF16, "where_cond_bf16");
run_where_cond_benchmark(c, &d, DType::F16, "where_cond_f16");
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -8,10 +8,11 @@ use anyhow::Result;
use candle_core::{Device, Tensor};
fn main() -> Result<()> {
let a = Tensor::new(&[[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]], &Device::Cpu)?;
let b = Tensor::new(&[[88.0f32, 99.0]], &Device::Cpu)?;
let new_a = a.slice_scatter(&b, 1, 2)?;
assert_eq!(a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
assert_eq!(new_a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
let start = std::time::Instant::now();
let res = inp.conv2d(&w, 0, 1);
println!("{:?}", start.elapsed());
println!("{res:?}");
Ok(())
}

View File

@ -0,0 +1,142 @@
/// This example contains some simple benchmarks so that it's easy to run them in perf etc.
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle_core::{Device, Result, Tensor, D};
use clap::{Parser, Subcommand};
fn softmax<D: candle_core::shape::Dim>(xs: &Tensor, dim: D) -> Result<Tensor> {
let dim = dim.to_index(xs.shape(), "softmax")?;
let max = xs.max_keepdim(dim)?;
let diff = xs.broadcast_sub(&max)?;
let num = diff.exp()?;
let den = num.sum_keepdim(dim)?;
num.broadcast_div(&den)
}
trait Benchmark {
type PreProcessData;
type RunResult;
fn preprocess() -> Result<Self::PreProcessData>;
fn run_one(_: &Self::PreProcessData) -> Result<Self::RunResult>;
const ITERS: usize;
}
// Conv1d example as used in whisper.
struct Conv1d;
impl Benchmark for Conv1d {
type PreProcessData = (Tensor, Tensor);
type RunResult = Tensor;
fn preprocess() -> Result<Self::PreProcessData> {
let inp = Tensor::randn(0f32, 1., (1, 384, 3000), &Device::Cpu)?;
let w = Tensor::randn(0f32, 1., (384, 384, 3), &Device::Cpu)?;
Ok((inp, w))
}
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
d.0.conv1d(&d.1, 0, 1)
}
const ITERS: usize = 5;
}
// Conv2d example as used in stable-diffusion.
struct Conv2d;
impl Benchmark for Conv2d {
type PreProcessData = (Tensor, Tensor);
type RunResult = Tensor;
fn preprocess() -> Result<Self::PreProcessData> {
let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
Ok((inp, w))
}
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
d.0.conv2d(&d.1, 0, 1)
}
const ITERS: usize = 1;
}
struct Matmul;
impl Benchmark for Matmul {
type PreProcessData = (Tensor, Tensor);
type RunResult = Tensor;
fn preprocess() -> Result<Self::PreProcessData> {
let lhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?;
let rhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?;
Ok((lhs, rhs))
}
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
d.0.matmul(&d.1)
}
const ITERS: usize = 100;
}
struct Softmax;
impl Benchmark for Softmax {
type PreProcessData = Tensor;
type RunResult = Tensor;
fn preprocess() -> Result<Self::PreProcessData> {
// Typical whisper tiny size.
let x = Tensor::randn(0f32, 1., (1, 6, 200, 1500), &Device::Cpu)?;
Ok(x)
}
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
softmax(d, D::Minus1)
}
const ITERS: usize = 100;
}
fn run<B: Benchmark>(iters: Option<usize>) -> Result<()> {
use std::hint::black_box;
let iters = iters.unwrap_or(B::ITERS);
let d = B::preprocess()?;
let start = std::time::Instant::now();
for _iter in 0..iters {
let _res = black_box(B::run_one(black_box(&d))?);
}
println!("{:?}", start.elapsed() / iters as u32);
Ok(())
}
#[derive(Subcommand, Debug, Clone)]
enum Task {
Conv1d,
Conv2d,
Matmul,
Softmax,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
/// The benchmark to be run.
#[command(subcommand)]
task: Task,
#[arg(long)]
iters: Option<usize>,
}
fn main() -> Result<()> {
let args = Args::parse();
match args.task {
Task::Conv1d => run::<Conv1d>(args.iters)?,
Task::Conv2d => run::<Conv2d>(args.iters)?,
Task::Matmul => run::<Matmul>(args.iters)?,
Task::Softmax => run::<Softmax>(args.iters)?,
}
Ok(())
}

View File

@ -9,21 +9,9 @@ use candle_core::{Device, Tensor};
fn main() -> Result<()> {
let device = Device::new_cuda(0)?;
let in_t = Tensor::rand(-1f32, 1f32, (1, 3, 12, 7), &device)?;
let k_t = Tensor::rand(-1f32, 1f32, (6, 3, 1, 1), &device)?;
let out_t = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
println!("{out_t}");
let in_t = in_t.to_device(&Device::Cpu)?;
let k_t = k_t.to_device(&Device::Cpu)?;
let out_t2 = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
let diff = (out_t.to_device(&Device::Cpu)? - out_t2)?
.sqr()?
.sum_all()?;
println!("{diff}");
let t = Tensor::randn(0f32, 1f32, (2, 4, 96, 96), &device)?;
let w = Tensor::randn(0f32, 1f32, (320, 4, 3, 3), &device)?;
let res = t.conv2d(&w, 1, 1, 1, 1)?;
let res = t.conv2d(&w, 1, 1)?;
println!("{res:?}");
Ok(())
}

View File

@ -1,389 +0,0 @@
use candle_core::quantized::{gguf_file, GgmlDType, QTensor};
use candle_core::{Device, Result};
use clap::{Parser, Subcommand, ValueEnum};
use rayon::prelude::*;
#[derive(ValueEnum, Debug, Clone)]
enum QuantizationMode {
/// The default quantization includes all 2d tensors, except the output tensor which always
/// uses Q6_K.
Llama,
}
impl QuantizationMode {
fn quantize(&self, name: &str, tensor: QTensor, dtype: GgmlDType) -> Result<QTensor> {
match self {
Self::Llama => {
// Same behavior as the llama.cpp quantization.
let should_quantize = name.ends_with(".weight") && tensor.rank() == 2;
if should_quantize {
let tensor = tensor.dequantize(&Device::Cpu)?;
if name == "output.weight" {
QTensor::quantize(&tensor, GgmlDType::Q6K)
} else {
QTensor::quantize(&tensor, dtype)
}
} else {
Ok(tensor)
}
}
}
}
}
#[derive(ValueEnum, Debug, Clone)]
enum Quantization {
#[value(name = "q4_0")]
Q4_0,
#[value(name = "q4_1")]
Q4_1,
#[value(name = "q5_0")]
Q5_0,
#[value(name = "q5_1")]
Q5_1,
#[value(name = "q8_0")]
Q8_0,
#[value(name = "q8_1")]
Q8_1,
Q2k,
Q3k,
Q4k,
Q5k,
Q6k,
Q8k,
F16,
F32,
}
impl Quantization {
fn dtype(&self) -> GgmlDType {
match self {
Quantization::Q4_0 => GgmlDType::Q4_0,
Quantization::Q4_1 => GgmlDType::Q4_1,
Quantization::Q5_0 => GgmlDType::Q5_0,
Quantization::Q5_1 => GgmlDType::Q5_1,
Quantization::Q8_0 => GgmlDType::Q8_0,
Quantization::Q8_1 => GgmlDType::Q8_1,
Quantization::Q2k => GgmlDType::Q2K,
Quantization::Q3k => GgmlDType::Q3K,
Quantization::Q4k => GgmlDType::Q4K,
Quantization::Q5k => GgmlDType::Q5K,
Quantization::Q6k => GgmlDType::Q6K,
Quantization::Q8k => GgmlDType::Q8K,
Quantization::F16 => GgmlDType::F16,
Quantization::F32 => GgmlDType::F32,
}
}
}
#[derive(ValueEnum, Debug, Clone)]
enum Format {
Safetensors,
Npz,
Ggml,
Gguf,
Pth,
Pickle,
}
impl Format {
fn infer<P: AsRef<std::path::Path>>(p: P) -> Option<Self> {
p.as_ref()
.extension()
.and_then(|e| e.to_str())
.and_then(|e| match e {
// We don't infer any format for .bin as it can be used for ggml/gguf or pytorch.
"safetensors" | "safetensor" => Some(Self::Safetensors),
"npz" => Some(Self::Npz),
"pth" | "pt" => Some(Self::Pth),
"ggml" => Some(Self::Ggml),
"gguf" => Some(Self::Gguf),
_ => None,
})
}
}
#[derive(Subcommand, Debug, Clone)]
enum Command {
Ls {
files: Vec<std::path::PathBuf>,
/// The file format to use, if unspecified infer from the file extension.
#[arg(long, value_enum)]
format: Option<Format>,
/// Enable verbose mode.
#[arg(short, long)]
verbose: bool,
},
Quantize {
/// The input file(s), in safetensors format.
in_file: Vec<std::path::PathBuf>,
/// The output file, in gguf format.
#[arg(long)]
out_file: std::path::PathBuf,
/// The quantization schema to apply.
#[arg(long, value_enum)]
quantization: Quantization,
/// Which tensor to quantize.
#[arg(long, value_enum, default_value_t = QuantizationMode::Llama)]
mode: QuantizationMode,
},
Dequantize {
/// The input file, in gguf format.
in_file: std::path::PathBuf,
/// The output file, in safetensors format.
#[arg(long)]
out_file: std::path::PathBuf,
},
}
#[derive(Parser, Debug, Clone)]
struct Args {
#[command(subcommand)]
command: Command,
}
fn run_ls(
file: &std::path::PathBuf,
format: Option<Format>,
verbose: bool,
device: &Device,
) -> Result<()> {
let format = match format {
Some(format) => format,
None => match Format::infer(file) {
Some(format) => format,
None => {
println!(
"{file:?}: cannot infer format from file extension, use the --format flag"
);
return Ok(());
}
},
};
match format {
Format::Npz => {
let tensors = candle_core::npy::NpzTensors::new(file)?;
let mut names = tensors.names();
names.sort();
for name in names {
let shape_dtype = match tensors.get_shape_and_dtype(name) {
Ok((shape, dtype)) => format!("[{shape:?}; {dtype:?}]"),
Err(err) => err.to_string(),
};
println!("{name}: {shape_dtype}")
}
}
Format::Safetensors => {
let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::new(file)? };
let mut tensors = tensors.tensors();
tensors.sort_by(|a, b| a.0.cmp(&b.0));
for (name, view) in tensors.iter() {
let dtype = view.dtype();
let dtype = match candle_core::DType::try_from(dtype) {
Ok(dtype) => format!("{dtype:?}"),
Err(_) => format!("{dtype:?}"),
};
let shape = view.shape();
println!("{name}: [{shape:?}; {dtype}]")
}
}
Format::Pth => {
let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose, None)?;
tensors.sort_by(|a, b| a.name.cmp(&b.name));
for tensor_info in tensors.iter() {
println!(
"{}: [{:?}; {:?}]",
tensor_info.name,
tensor_info.layout.shape(),
tensor_info.dtype,
);
if verbose {
println!(" {:?}", tensor_info);
}
}
}
Format::Pickle => {
let file = std::fs::File::open(file)?;
let mut reader = std::io::BufReader::new(file);
let mut stack = candle_core::pickle::Stack::empty();
stack.read_loop(&mut reader)?;
for (i, obj) in stack.stack().iter().enumerate() {
println!("{i} {obj:?}");
}
}
Format::Ggml => {
let mut file = std::fs::File::open(file)?;
let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?;
let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
tensors.sort_by(|a, b| a.0.cmp(&b.0));
for (name, qtensor) in tensors.iter() {
println!("{name}: [{:?}; {:?}]", qtensor.shape(), qtensor.dtype());
}
}
Format::Gguf => {
let mut file = std::fs::File::open(file)?;
let content = gguf_file::Content::read(&mut file)?;
if verbose {
let mut metadata = content.metadata.into_iter().collect::<Vec<_>>();
metadata.sort_by(|a, b| a.0.cmp(&b.0));
println!("metadata entries ({})", metadata.len());
for (key, value) in metadata.iter() {
println!(" {key}: {value:?}");
}
}
let mut tensors = content.tensor_infos.into_iter().collect::<Vec<_>>();
tensors.sort_by(|a, b| a.0.cmp(&b.0));
for (name, info) in tensors.iter() {
println!("{name}: [{:?}; {:?}]", info.shape, info.ggml_dtype);
}
}
}
Ok(())
}
fn run_quantize_safetensors(
in_files: &[std::path::PathBuf],
out_file: std::path::PathBuf,
q: Quantization,
) -> Result<()> {
let mut out_file = std::fs::File::create(out_file)?;
let mut tensors = std::collections::HashMap::new();
for in_file in in_files.iter() {
let in_tensors = candle_core::safetensors::load(in_file, &Device::Cpu)?;
tensors.extend(in_tensors)
}
println!("tensors: {}", tensors.len());
let dtype = q.dtype();
let block_size = dtype.block_size();
let qtensors = tensors
.into_par_iter()
.map(|(name, tensor)| {
let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0;
println!(" quantizing {name} {tensor:?} {should_quantize}");
let tensor = if should_quantize {
QTensor::quantize(&tensor, dtype)?
} else {
QTensor::quantize(&tensor, GgmlDType::F32)?
};
Ok((name, tensor))
})
.collect::<Result<Vec<_>>>()?;
let qtensors = qtensors
.iter()
.map(|(k, v)| (k.as_str(), v))
.collect::<Vec<_>>();
gguf_file::write(&mut out_file, &[], &qtensors)?;
Ok(())
}
fn run_dequantize(
in_file: std::path::PathBuf,
out_file: std::path::PathBuf,
device: &Device,
) -> Result<()> {
let mut in_file = std::fs::File::open(in_file)?;
let content = gguf_file::Content::read(&mut in_file)?;
let mut tensors = std::collections::HashMap::new();
for (tensor_name, _) in content.tensor_infos.iter() {
let tensor = content.tensor(&mut in_file, tensor_name, device)?;
let tensor = tensor.dequantize(device)?;
tensors.insert(tensor_name.to_string(), tensor);
}
candle_core::safetensors::save(&tensors, out_file)?;
Ok(())
}
fn run_quantize(
in_files: &[std::path::PathBuf],
out_file: std::path::PathBuf,
q: Quantization,
qmode: QuantizationMode,
device: &Device,
) -> Result<()> {
if in_files.is_empty() {
candle_core::bail!("no specified input files")
}
if let Some(extension) = out_file.extension() {
if extension == "safetensors" {
candle_core::bail!("the generated file cannot use the safetensors extension")
}
}
if let Some(extension) = in_files[0].extension() {
if extension == "safetensors" {
return run_quantize_safetensors(in_files, out_file, q);
}
}
if in_files.len() != 1 {
candle_core::bail!("only a single in-file can be used when quantizing gguf files")
}
// Open the out file early so as to fail directly on missing directories etc.
let mut out_file = std::fs::File::create(out_file)?;
let mut in_ = std::fs::File::open(&in_files[0])?;
let content = gguf_file::Content::read(&mut in_)?;
println!("tensors: {}", content.tensor_infos.len());
let dtype = q.dtype();
let qtensors = content
.tensor_infos
.par_iter()
.map(|(name, _)| {
println!(" quantizing {name}");
let mut in_file = std::fs::File::open(&in_files[0])?;
let tensor = content.tensor(&mut in_file, name, device)?;
let tensor = qmode.quantize(name, tensor, dtype)?;
Ok((name, tensor))
})
.collect::<Result<Vec<_>>>()?;
let qtensors = qtensors
.iter()
.map(|(k, v)| (k.as_str(), v))
.collect::<Vec<_>>();
let metadata = content
.metadata
.iter()
.map(|(k, v)| (k.as_str(), v))
.collect::<Vec<_>>();
gguf_file::write(&mut out_file, metadata.as_slice(), &qtensors)?;
Ok(())
}
fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = Device::Cpu;
match args.command {
Command::Ls {
files,
format,
verbose,
} => {
let multiple_files = files.len() > 1;
for file in files.iter() {
if multiple_files {
println!("--- {file:?} ---");
}
run_ls(file, format.clone(), verbose, &device)?
}
}
Command::Quantize {
in_file,
out_file,
quantization,
mode,
} => run_quantize(&in_file, out_file, quantization, mode, &device)?,
Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file, &device)?,
}
Ok(())
}

View File

@ -50,8 +50,6 @@ mod ffi {
pub fn vvcos(dst: *mut c_double, src: *const c_double, len: *const c_int);
pub fn vvlogf(dst: *mut c_float, src: *const c_float, len: *const c_int);
pub fn vvlog(dst: *mut c_double, src: *const c_double, len: *const c_int);
pub fn vvtanhf(dst: *mut c_float, src: *const c_float, len: *const c_int);
pub fn vvtanh(dst: *mut c_double, src: *const c_double, len: *const c_int);
pub fn vDSP_vaddD(
_: *const c_double,
@ -125,42 +123,6 @@ mod ffi {
_: c_long,
_: c_ulong,
);
pub fn vDSP_vminD(
_: *const c_double,
_: c_long,
_: *const c_double,
_: c_long,
_: *mut c_double,
_: c_long,
_: c_ulong,
);
pub fn vDSP_vmin(
_: *const c_float,
_: c_long,
_: *const c_float,
_: c_long,
_: *mut c_float,
_: c_long,
_: c_ulong,
);
pub fn vDSP_vmaxD(
_: *const c_double,
_: c_long,
_: *const c_double,
_: c_long,
_: *mut c_double,
_: c_long,
_: c_ulong,
);
pub fn vDSP_vmax(
_: *const c_float,
_: c_long,
_: *const c_float,
_: c_long,
_: *mut c_float,
_: c_long,
_: c_ulong,
);
}
}
@ -310,26 +272,6 @@ pub fn vd_cos(a: &[f64], y: &mut [f64]) {
}
unsafe { ffi::vvcos(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}
#[inline]
pub fn vs_tanh(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vvtanhf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}
#[inline]
pub fn vd_tanh(a: &[f64], y: &mut [f64]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vvtanh(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}
#[inline]
pub fn vs_ln(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
@ -370,38 +312,6 @@ pub fn vd_sqr(a: &[f64], y: &mut [f64]) {
y.iter_mut().zip(a.iter()).for_each(|(y, a)| *y = *a * *a)
}
#[inline]
pub fn vs_tanh_inplace(y: &mut [f32]) {
unsafe { ffi::vvtanhf(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
}
#[inline]
pub fn vd_tanh_inplace(y: &mut [f64]) {
unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
}
#[inline]
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = (2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
}
vs_tanh_inplace(ys);
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = 0.5 * v * (1.0 + *y)
}
}
#[inline]
pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = (2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
}
vd_tanh_inplace(ys);
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = 0.5 * v * (1.0 + *y)
}
}
macro_rules! binary_op {
($fn_name:ident, $ty:ty, $accelerate_name:ident) => {
#[inline]
@ -438,7 +348,3 @@ binary_op!(vs_mul, f32, vDSP_vmul);
binary_op!(vd_mul, f64, vDSP_vmulD);
binary_op!(vs_div, f32, vDSP_vdiv);
binary_op!(vd_div, f64, vDSP_vdivD);
binary_op!(vs_max, f32, vDSP_vmax);
binary_op!(vd_max, f64, vDSP_vmaxD);
binary_op!(vs_min, f32, vDSP_vmin);
binary_op!(vd_min, f64, vDSP_vminD);

View File

@ -15,8 +15,6 @@ pub trait BackendStorage: Sized {
fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self>;
fn powf(&self, _: &Layout, _: f64) -> Result<Self>;
fn elu(&self, _: &Layout, _: f64) -> Result<Self>;
fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self>;
@ -39,14 +37,6 @@ pub trait BackendStorage: Sized {
_params: &crate::conv::ParamsConv1D,
) -> Result<Self>;
fn conv_transpose1d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self>;
fn conv2d(
&self,
_l: &Layout,
@ -55,17 +45,8 @@ pub trait BackendStorage: Sized {
_params: &crate::conv::ParamsConv2D,
) -> Result<Self>;
fn conv_transpose2d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose2D,
) -> Result<Self>;
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self>;
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
@ -119,6 +100,4 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
fn set_seed(&self, _: u64) -> Result<()>;
}

View File

@ -15,17 +15,6 @@ fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result
}
}
thread_local! {
static CANDLE_GRAD_DO_NOT_DETACH: bool = {
match std::env::var("CANDLE_GRAD_DO_NOT_DETACH") {
Ok(s) => {
!s.is_empty() && s != "0"
},
Err(_) => false,
}
}
}
impl Tensor {
/// Return all the nodes that lead to this value in a topologically sorted vec, the first
/// elements having dependencies on the latter ones, e.g. the first element if any is the
@ -47,8 +36,6 @@ impl Tensor {
// Do not call recursively on the "leaf" nodes.
track_grad = true;
nodes
} else if node.dtype().is_int() {
nodes
} else if let Some(op) = node.op() {
match op {
Op::IndexAdd(t1, t2, t3, _)
@ -68,27 +55,16 @@ impl Tensor {
kernel: rhs,
..
}
| Op::ConvTranspose1D {
arg: lhs,
kernel: rhs,
..
}
| Op::Conv2D {
arg: lhs,
kernel: rhs,
..
}
| Op::ConvTranspose2D {
arg: lhs,
kernel: rhs,
..
}
| Op::CustomOp2(lhs, rhs, _)
| Op::Binary(lhs, rhs, _)
| Op::Gather(lhs, rhs, _)
| Op::IndexSelect(lhs, rhs, _)
| Op::Matmul(lhs, rhs)
| Op::SliceScatter0(lhs, rhs, _) => {
| Op::Matmul(lhs, rhs) => {
let (tg, nodes) = walk(lhs, nodes, already_seen);
track_grad |= tg;
let (tg, nodes) = walk(rhs, nodes, already_seen);
@ -109,40 +85,25 @@ impl Tensor {
nodes
}
}
Op::Unary(_node, UnaryOp::Ceil)
| Op::Unary(_node, UnaryOp::Floor)
| Op::Unary(_node, UnaryOp::Round) => nodes,
Op::Reshape(node)
| Op::UpsampleNearest1D(node)
| Op::UpsampleNearest2D { arg: node, .. }
| Op::UpsampleNearest2D(node)
| Op::AvgPool2D { arg: node, .. }
| Op::MaxPool2D { arg: node, .. }
| Op::Copy(node)
| Op::Broadcast(node)
| Op::Cmp(node, _)
| Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
| Op::Reduce(node, _, _)
| Op::ToDType(node)
| Op::ToDevice(node)
| Op::Transpose(node, _, _)
| Op::Permute(node, _)
| Op::Narrow(node, _, _, _)
| Op::Unary(node, _)
| Op::Elu(node, _)
| Op::Powf(node, _)
| Op::CustomOp1(node, _) => {
let (tg, nodes) = walk(node, nodes, already_seen);
track_grad |= tg;
nodes
}
Op::ToDType(node) => {
if node.dtype().is_float() {
let (tg, nodes) = walk(node, nodes, already_seen);
track_grad |= tg;
nodes
} else {
nodes
}
}
Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
}
} else {
nodes
@ -166,16 +127,10 @@ impl Tensor {
if node.is_variable() {
continue;
}
let grad = grads
.remove(node)
.expect("candle internal error - grad not populated");
// https://github.com/huggingface/candle/issues/1241
// Ideally, we would make these operations in place where possible to ensure that we
// do not have to allocate too often. Here we just call `.detach` to avoid computing
// the backprop graph of the backprop itself. This would be an issue for second order
// derivatives but these are out of scope at the moment.
let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
let grad = if do_not_detach { grad } else { grad.detach() };
let grad = grads.remove(node).unwrap();
// TODO: We should perform all these operations in place (or at least not track the
// whole graph). The only drawback would be if we wanted to support grad of grad but
// this is out of scope.
if let Some(op) = node.op() {
match op {
Op::Binary(lhs, rhs, BinaryOp::Add) => {
@ -206,21 +161,6 @@ impl Tensor {
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.sub(&rhs_grad)?;
}
Op::Binary(lhs, rhs, BinaryOp::Minimum)
| Op::Binary(lhs, rhs, BinaryOp::Maximum) => {
let mask_lhs = node.eq(lhs)?.to_dtype(grad.dtype())?;
let mask_rhs = node.eq(rhs)?.to_dtype(grad.dtype())?;
// If both masks are 1 one the same point, we want to scale the
// gradient by 0.5 rather than 1.
let lhs_grad = mask_lhs.mul(&grad)?.div(&(&mask_rhs + 1.)?)?;
let lhs_sum_grad = grads.or_insert(lhs)?;
*lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
let rhs_grad = mask_rhs.mul(&grad)?.div(&(&mask_lhs + 1.)?)?;
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
}
Op::WhereCond(pred, t, f) => {
let zeros = grad.zeros_like()?;
let t_sum_grad = grads.or_insert(t)?;
@ -230,156 +170,13 @@ impl Tensor {
let f_grad = pred.where_cond(&zeros, &grad)?;
*f_sum_grad = f_sum_grad.add(&f_grad)?;
}
Op::Conv1D {
arg,
kernel,
padding,
stride,
dilation,
} => {
// The output height for conv_transpose1d is:
// (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1
let grad_l_in = grad.dim(2)?;
let k_size = kernel.dim(2)?;
let out_size =
(grad_l_in - 1) * stride + dilation * (k_size - 1) + 1 - 2 * padding;
let out_padding = arg.dim(2)? - out_size;
let grad_arg = grad.conv_transpose1d(
kernel,
*padding,
out_padding,
*stride,
*dilation,
)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
let grad_kernel = arg
.transpose(0, 1)?
.conv1d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
.transpose(0, 1)?;
let sum_grad = grads.or_insert(kernel)?;
let (_, _, k0) = kernel.dims3()?;
let (_, _, g_k0) = grad_kernel.dims3()?;
let grad_kernel = if g_k0 != k0 {
grad_kernel.narrow(2, 0, k0)?
} else {
grad_kernel
};
*sum_grad = sum_grad.add(&grad_kernel)?;
}
Op::Conv2D {
arg,
kernel,
padding,
stride,
dilation,
} => {
// The output height for conv_transpose2d is:
// (i_h - 1) * stride - 2 * padding + dilation * (k_h - 1) + out_padding + 1
let grad_h = grad.dim(2)?;
let k_h = kernel.dim(2)?;
let out_size =
(grad_h - 1) * stride + dilation * (k_h - 1) + 1 - 2 * padding;
let out_padding = arg.dim(2)? - out_size;
let grad_arg = grad.conv_transpose2d(
kernel,
*padding,
out_padding,
*stride,
*dilation,
)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
let grad_kernel = arg
.transpose(0, 1)?
.conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
.transpose(0, 1)?;
let sum_grad = grads.or_insert(kernel)?;
let (_, _, k0, k1) = kernel.dims4()?;
let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
} else {
grad_kernel
};
*sum_grad = sum_grad.add(&grad_kernel)?;
}
Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported {
op: "conv-transpose1d",
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
Op::Conv2D { .. } => Err(Error::BackwardNotSupported { op: "conv2d" })?,
Op::AvgPool2D { .. } => Err(Error::BackwardNotSupported { op: "avg-pool2d" })?,
Op::MaxPool2D { .. } => Err(Error::BackwardNotSupported { op: "max-pool2d" })?,
Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
op: "upsample-nearest2d",
})?,
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
op: "conv-transpose2d",
})?,
Op::AvgPool2D {
arg,
kernel_size,
stride,
} => {
if kernel_size != stride {
crate::bail!("backward not supported for avgpool2d if ksize {kernel_size:?} != stride {stride:?}")
}
let (_n, _c, h, w) = arg.dims4()?;
let grad_arg = grad.upsample_nearest2d(h, w)?;
let grad_arg =
(grad_arg * (1f64 / (kernel_size.0 * kernel_size.1) as f64))?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
}
Op::MaxPool2D {
arg,
kernel_size,
stride,
} => {
if kernel_size != stride {
crate::bail!("backward not supported for maxpool2d if ksize {kernel_size:?} != stride {stride:?}")
}
let (_n, _c, h, w) = arg.dims4()?;
// For computing the max-pool gradient, we compute a mask where a 1 means
// that the element is the maximum, then we apply this mask to the
// upsampled gradient (taking into account that multiple max may exist so
// we scale the gradient for this case).
let node_upsampled = node.upsample_nearest2d(h, w)?;
let mask = arg.eq(&node_upsampled)?.to_dtype(arg.dtype())?;
let avg = mask.avg_pool2d_with_stride(*kernel_size, *stride)?;
let grad_arg = ((grad * avg)?.upsample_nearest2d(h, w)? * mask)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
}
Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
op: "upsample-nearest1d",
})?,
Op::UpsampleNearest2D {
arg,
target_h,
target_w,
} => {
let (_n, c, h, w) = arg.dims4()?;
if target_h % h != 0 || target_w % w != 0 {
crate::bail!("backward not supported for non integer upscaling factors")
}
let scale_h = target_h / h;
let scale_w = target_w / w;
if scale_h != scale_w {
crate::bail!("backward not supported for non uniform upscaling factors")
};
let kernel =
Tensor::ones((c, 1, scale_h, scale_w), arg.dtype(), arg.device())?;
let conv_sum = grad.conv2d(&kernel, 0, scale_h, 1, c)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = conv_sum;
}
Op::SliceScatter0(lhs, rhs, start_rhs) => {
let rhs_sum_grad = grads.or_insert(rhs)?;
let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
let lhs_sum_grad = grads.or_insert(lhs)?;
let lhs_grad = grad.slice_scatter0(&rhs.zeros_like()?, *start_rhs)?;
*lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?
}
Op::Gather(arg, indexes, dim) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
@ -471,7 +268,7 @@ impl Tensor {
}
Op::ToDType(arg) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad.to_dtype(arg.dtype())?)?
*sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
}
Op::Copy(arg) => {
let sum_grad = grads.or_insert(arg)?;
@ -494,11 +291,6 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
}
Op::Unary(arg, UnaryOp::Tanh) => {
let sum_grad = grads.or_insert(arg)?;
let minus_dtanh = (node.sqr()? - 1.)?;
*sum_grad = sum_grad.sub(&(&grad * &minus_dtanh)?)?
}
Op::Unary(arg, UnaryOp::Abs) => {
let sum_grad = grads.or_insert(arg)?;
let ones = arg.ones_like()?;
@ -551,59 +343,13 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?,
Op::Unary(_, UnaryOp::Floor) => {
Err(Error::BackwardNotSupported { op: "floor" })?
}
Op::Unary(_, UnaryOp::Round) => {
Err(Error::BackwardNotSupported { op: "round" })?
}
Op::Unary(arg, UnaryOp::Gelu) => {
let sum_grad = grads.or_insert(arg)?;
let cube = arg.powf(3.)?;
let tanh = (0.0356774 * &cube + (0.797885 * arg)?)?.tanh()?;
let gelu_grad = (((0.5 * &tanh)?
+ (0.0535161 * cube + (0.398942 * arg)?)? * (1. - tanh.powf(2.)?))?
+ 0.5)?;
*sum_grad = sum_grad.add(&(&grad * gelu_grad)?)?
}
Op::Unary(arg, UnaryOp::Erf) => {
let sum_grad = grads.or_insert(arg)?;
// d/dx erf(x) = 2/sqrt(pi) * e^(-x^2)
let erf_grad =
(2. / std::f64::consts::PI.sqrt()) * (arg.sqr()?.neg()?).exp()?;
*sum_grad = sum_grad.add(&(&grad * erf_grad)?)?
}
Op::Unary(arg, UnaryOp::GeluErf) => {
let sum_grad = grads.or_insert(arg)?;
// d/dx gelu_erf(x) = 0.5 + 0.398942 e^(-x^2/2) x + 0.5 erf(x/sqrt(2))
let neg_half_square = (arg.sqr()?.neg()? / 2.)?;
let scaled_exp_arg = (0.398942 * neg_half_square.exp()? * arg)?;
let arg_scaled_sqrt = (arg / 2f64.sqrt())?;
let erf_scaled_sqrt = (0.5 * arg_scaled_sqrt.erf()?)?;
let gelu_erf_grad = (0.5 + scaled_exp_arg + erf_scaled_sqrt)?;
*sum_grad = sum_grad.add(&(&grad * gelu_erf_grad)?)?;
}
Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
Op::Unary(arg, UnaryOp::Relu) => {
let sum_grad = grads.or_insert(arg)?;
let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
*sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
}
Op::Elu(arg, alpha) => {
// d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
let sum_grad = grads.or_insert(arg)?;
let zeros = arg.zeros_like()?;
let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
let negative_exp_mask = ((negative_mask * arg.exp())? * *alpha)?;
let combined_mask = (positive_mask + negative_exp_mask)?;
*sum_grad = sum_grad.add(&(grad * combined_mask)?)?
}
Op::Powf(arg, e) => {
let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
Op::CustomOp1(arg, c) => {
if let Some(arg_grad) = c.bwd(arg, node, &grad)? {
let sum_grad = grads.or_insert(arg)?;
@ -657,15 +403,6 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Permute(arg, dims) => {
let mut inv_dims = vec![0; dims.len()];
for (i, &dim_idx) in dims.iter().enumerate() {
inv_dims[dim_idx] = i
}
let arg_grad = grad.permute(inv_dims)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
};
}
}
@ -673,7 +410,6 @@ impl Tensor {
}
}
#[derive(Debug)]
pub struct GradStore(HashMap<TensorId, Tensor>);
impl GradStore {

View File

@ -1,5 +1,3 @@
use crate::{op::BackpropOp, op::Op, Error, Result, Tensor};
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv1D {
pub(crate) b_size: usize,
@ -11,12 +9,12 @@ pub struct ParamsConv1D {
pub(crate) k_size: usize,
pub(crate) padding: usize,
pub(crate) stride: usize,
pub(crate) dilation: usize,
}
impl ParamsConv1D {
pub(crate) fn l_out(&self) -> usize {
(self.l_in + 2 * self.padding - self.dilation * (self.k_size - 1) - 1) / self.stride + 1
let dilation = 1;
(self.l_in + 2 * self.padding - dilation * (self.k_size - 1) - 1) / self.stride + 1
}
pub(crate) fn out_dims(&self) -> Vec<usize> {
@ -25,46 +23,6 @@ impl ParamsConv1D {
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConvTranspose1D {
pub(crate) b_size: usize,
pub(crate) l_in: usize,
pub(crate) c_out: usize,
pub(crate) c_in: usize,
pub(crate) k_size: usize,
pub(crate) padding: usize,
pub(crate) output_padding: usize,
pub(crate) stride: usize,
pub(crate) dilation: usize,
}
impl ParamsConvTranspose1D {
pub(crate) fn l_out(&self) -> usize {
(self.l_in - 1) * self.stride - 2 * self.padding
+ self.dilation * (self.k_size - 1)
+ self.output_padding
+ 1
}
pub(crate) fn out_dims(&self) -> Vec<usize> {
let l_out = self.l_out();
vec![self.b_size, self.c_out, l_out]
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CudnnFwdAlgo {
ImplicitGemm,
ImplicitPrecompGemm,
Gemm,
Direct,
Fft,
FftTiling,
Winograd,
WinogradNonFused,
Count,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv2D {
pub(crate) b_size: usize,
@ -76,260 +34,20 @@ pub struct ParamsConv2D {
pub(crate) c_in: usize,
pub(crate) padding: usize,
pub(crate) stride: usize,
pub(crate) dilation: usize,
pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
}
impl ParamsConv2D {
pub(crate) fn out_h(&self) -> usize {
(self.i_h + 2 * self.padding - self.dilation * (self.k_h - 1) - 1) / self.stride + 1
let dilation = 1;
(self.i_h + 2 * self.padding - dilation * (self.k_h - 1) - 1) / self.stride + 1
}
pub(crate) fn out_w(&self) -> usize {
(self.i_w + 2 * self.padding - self.dilation * (self.k_w - 1) - 1) / self.stride + 1
let dilation = 1;
(self.i_w + 2 * self.padding - dilation * (self.k_w - 1) - 1) / self.stride + 1
}
pub(crate) fn out_dims(&self) -> Vec<usize> {
vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConvTranspose2D {
pub(crate) b_size: usize,
pub(crate) i_h: usize,
pub(crate) i_w: usize,
pub(crate) k_h: usize,
pub(crate) k_w: usize,
pub(crate) c_out: usize,
pub(crate) c_in: usize,
pub(crate) padding: usize,
pub(crate) output_padding: usize,
pub(crate) stride: usize,
pub(crate) dilation: usize,
}
impl ParamsConvTranspose2D {
pub(crate) fn out_h(&self) -> usize {
(self.i_h - 1) * self.stride + self.dilation * (self.k_h - 1) + self.output_padding + 1
- 2 * self.padding
}
pub(crate) fn out_w(&self) -> usize {
(self.i_w - 1) * self.stride + self.dilation * (self.k_w - 1) + self.output_padding + 1
- 2 * self.padding
}
pub(crate) fn out_dims(&self) -> Vec<usize> {
vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
}
}
impl Tensor {
fn conv1d_single_group(&self, kernel: &Self, params: &ParamsConv1D) -> Result<Self> {
let storage =
self.storage()
.conv1d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D {
arg,
kernel,
padding: params.padding,
stride: params.stride,
dilation: params.dilation,
});
let out_dims = params.out_dims();
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
}
/// Applies a 1D convolution over the input tensor.
pub fn conv1d(
&self,
kernel: &Self,
padding: usize,
stride: usize,
dilation: usize,
groups: usize,
) -> Result<Self> {
let (c_out, c_in_k, k_size) = kernel.dims3()?;
let (b_size, c_in, l_in) = self.dims3()?;
if c_in != c_in_k * groups {
Err(Error::Conv1dInvalidArgs {
inp_shape: self.shape().clone(),
k_shape: kernel.shape().clone(),
padding,
stride,
msg: "the number of in-channels on the input doesn't match the kernel size",
}
.bt())?
}
let params = ParamsConv1D {
b_size,
l_in,
c_out: c_out / groups,
c_in: c_in / groups,
k_size,
padding,
stride,
dilation,
};
if groups == 1 {
self.conv1d_single_group(kernel, &params)
} else {
let blocks = self.chunk(groups, 1)?;
let kernel = kernel.chunk(groups, 0)?;
let blocks = blocks
.iter()
.zip(&kernel)
.map(|(block, kernel)| block.conv1d_single_group(kernel, &params))
.collect::<Result<Vec<_>>>()?;
Tensor::cat(&blocks, 1)
}
}
/// Applies a 1D transposed convolution over the input tensor.
pub fn conv_transpose1d(
&self,
kernel: &Self,
padding: usize,
output_padding: usize,
stride: usize,
dilation: usize,
) -> Result<Self> {
let (b_size, c_in, l_in) = self.dims3()?;
let (c_in_k, c_out, k_size) = kernel.dims3()?;
if c_in != c_in_k {
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
}
let params = ParamsConvTranspose1D {
b_size,
l_in,
k_size,
c_out,
c_in,
padding,
output_padding,
stride,
dilation,
};
let storage = self.storage().conv_transpose1d(
self.layout(),
&kernel.storage(),
kernel.layout(),
&params,
)?;
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D {
arg,
kernel,
padding: params.padding,
output_padding: params.output_padding,
stride: params.stride,
dilation: params.dilation,
});
let out_dims = params.out_dims();
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
}
fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
let storage =
self.storage()
.conv2d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D {
arg,
kernel,
padding: params.padding,
stride: params.stride,
dilation: params.dilation,
});
let out_dims = params.out_dims();
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
}
/// Applies a 2D convolution over the input tensor.
pub fn conv2d(
&self,
kernel: &Self,
padding: usize,
stride: usize,
dilation: usize,
groups: usize,
) -> Result<Self> {
let (b_size, c_in, i_h, i_w) = self.dims4()?;
let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
if c_in != c_in_k * groups {
crate::bail!(
"in_channel mismatch between input ({c_in}, groups {groups}) and kernel ({c_in_k})"
)
}
let params = ParamsConv2D {
b_size,
i_h,
i_w,
k_h,
k_w,
c_out: c_out / groups,
c_in: c_in / groups,
padding,
stride,
dilation,
cudnn_fwd_algo: None,
};
if groups == 1 {
self.conv2d_single_group(kernel, &params)
} else {
let blocks = self.chunk(groups, 1)?;
let kernel = kernel.chunk(groups, 0)?;
let blocks = blocks
.iter()
.zip(&kernel)
.map(|(block, kernel)| block.conv2d_single_group(kernel, &params))
.collect::<Result<Vec<_>>>()?;
Tensor::cat(&blocks, 1)
}
}
/// Applies a 2D transposed convolution over the input tensor.
pub fn conv_transpose2d(
&self,
kernel: &Self,
padding: usize,
output_padding: usize,
stride: usize,
dilation: usize,
) -> Result<Self> {
let (b_size, c_in, i_h, i_w) = self.dims4()?;
let (c_in_k, c_out, k_h, k_w) = kernel.dims4()?;
if c_in != c_in_k {
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
}
let params = ParamsConvTranspose2D {
b_size,
i_h,
i_w,
k_h,
k_w,
c_out,
c_in,
padding,
output_padding,
stride,
dilation,
};
let storage = self.storage().conv_transpose2d(
self.layout(),
&kernel.storage(),
kernel.layout(),
&params,
)?;
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose2D {
arg,
kernel,
padding: params.padding,
output_padding: params.output_padding,
stride: params.stride,
dilation: params.dilation,
});
let out_dims = params.out_dims();
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
}
}

View File

@ -92,7 +92,6 @@ from_tensor!(f64);
from_tensor!(f32);
from_tensor!(f16);
from_tensor!(bf16);
from_tensor!(i64);
from_tensor!(u32);
from_tensor!(u8);
@ -130,11 +129,6 @@ impl Tensor {
f.write_u32::<LittleEndian>(v)?
}
}
DType::I64 => {
for v in vs.to_vec1::<i64>()? {
f.write_i64::<LittleEndian>(v)?
}
}
DType::U8 => {
let vs = vs.to_vec1::<u8>()?;
f.write_all(&vs)?;

View File

@ -103,7 +103,7 @@ impl CpuF16<ARR> for CurrentCpuF16 {
for i in 0..8 {
tmp[i] = (*mem_addr.add(i)).to_f32();
}
_mm256_loadu_ps(tmp.as_ptr())
_mm_loadu_ps(tmp.as_ptr())
}
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {

View File

@ -1,763 +0,0 @@
#![allow(clippy::excessive_precision)]
// Code taken from https://github.com/statrs-dev/statrs
//! Provides the [error](https://en.wikipedia.org/wiki/Error_function) and
//! related functions
mod evaluate {
//! Provides functions that don't have a numerical solution and must
//! be solved computationally (e.g. evaluation of a polynomial)
/// evaluates a polynomial at `z` where `coeff` are the coeffecients
/// to a polynomial of order `k` where `k` is the length of `coeff` and the
/// coeffecient
/// to the `k`th power is the `k`th element in coeff. E.g. [3,-1,2] equates to
/// `2z^2 - z + 3`
///
/// # Remarks
///
/// Returns 0 for a 0 length coefficient slice
pub fn polynomial(z: f64, coeff: &[f64]) -> f64 {
let n = coeff.len();
if n == 0 {
return 0.0;
}
let mut sum = *coeff.last().unwrap();
for c in coeff[0..n - 1].iter().rev() {
sum = *c + z * sum;
}
sum
}
}
use std::f64;
/// `erf` calculates the error function at `x`.
pub fn erf(x: f64) -> f64 {
if x.is_nan() {
f64::NAN
} else if x >= 0.0 && x.is_infinite() {
1.0
} else if x <= 0.0 && x.is_infinite() {
-1.0
} else if x == 0. {
0.0
} else {
erf_impl(x, false)
}
}
/// `erf_inv` calculates the inverse error function
/// at `x`.
pub fn erf_inv(x: f64) -> f64 {
if x == 0.0 {
0.0
} else if x >= 1.0 {
f64::INFINITY
} else if x <= -1.0 {
f64::NEG_INFINITY
} else if x < 0.0 {
erf_inv_impl(-x, 1.0 + x, -1.0)
} else {
erf_inv_impl(x, 1.0 - x, 1.0)
}
}
/// `erfc` calculates the complementary error function
/// at `x`.
pub fn erfc(x: f64) -> f64 {
if x.is_nan() {
f64::NAN
} else if x == f64::INFINITY {
0.0
} else if x == f64::NEG_INFINITY {
2.0
} else {
erf_impl(x, true)
}
}
/// `erfc_inv` calculates the complementary inverse
/// error function at `x`.
pub fn erfc_inv(x: f64) -> f64 {
if x <= 0.0 {
f64::INFINITY
} else if x >= 2.0 {
f64::NEG_INFINITY
} else if x > 1.0 {
erf_inv_impl(-1.0 + x, 2.0 - x, -1.0)
} else {
erf_inv_impl(1.0 - x, x, 1.0)
}
}
// **********************************************************
// ********** Coefficients for erf_impl polynomial **********
// **********************************************************
/// Polynomial coefficients for a numerator of `erf_impl`
/// in the interval [1e-10, 0.5].
const ERF_IMPL_AN: &[f64] = &[
0.00337916709551257388990745,
-0.00073695653048167948530905,
-0.374732337392919607868241,
0.0817442448733587196071743,
-0.0421089319936548595203468,
0.0070165709512095756344528,
-0.00495091255982435110337458,
0.000871646599037922480317225,
];
/// Polynomial coefficients for a denominator of `erf_impl`
/// in the interval [1e-10, 0.5]
const ERF_IMPL_AD: &[f64] = &[
1.0,
-0.218088218087924645390535,
0.412542972725442099083918,
-0.0841891147873106755410271,
0.0655338856400241519690695,
-0.0120019604454941768171266,
0.00408165558926174048329689,
-0.000615900721557769691924509,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [0.5, 0.75].
const ERF_IMPL_BN: &[f64] = &[
-0.0361790390718262471360258,
0.292251883444882683221149,
0.281447041797604512774415,
0.125610208862766947294894,
0.0274135028268930549240776,
0.00250839672168065762786937,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [0.5, 0.75].
const ERF_IMPL_BD: &[f64] = &[
1.0,
1.8545005897903486499845,
1.43575803037831418074962,
0.582827658753036572454135,
0.124810476932949746447682,
0.0113724176546353285778481,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [0.75, 1.25].
const ERF_IMPL_CN: &[f64] = &[
-0.0397876892611136856954425,
0.153165212467878293257683,
0.191260295600936245503129,
0.10276327061989304213645,
0.029637090615738836726027,
0.0046093486780275489468812,
0.000307607820348680180548455,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [0.75, 1.25].
const ERF_IMPL_CD: &[f64] = &[
1.0,
1.95520072987627704987886,
1.64762317199384860109595,
0.768238607022126250082483,
0.209793185936509782784315,
0.0319569316899913392596356,
0.00213363160895785378615014,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [1.25, 2.25].
const ERF_IMPL_DN: &[f64] = &[
-0.0300838560557949717328341,
0.0538578829844454508530552,
0.0726211541651914182692959,
0.0367628469888049348429018,
0.00964629015572527529605267,
0.00133453480075291076745275,
0.778087599782504251917881e-4,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [1.25, 2.25].
const ERF_IMPL_DD: &[f64] = &[
1.0,
1.75967098147167528287343,
1.32883571437961120556307,
0.552528596508757581287907,
0.133793056941332861912279,
0.0179509645176280768640766,
0.00104712440019937356634038,
-0.106640381820357337177643e-7,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [2.25, 3.5].
const ERF_IMPL_EN: &[f64] = &[
-0.0117907570137227847827732,
0.014262132090538809896674,
0.0202234435902960820020765,
0.00930668299990432009042239,
0.00213357802422065994322516,
0.00025022987386460102395382,
0.120534912219588189822126e-4,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [2.25, 3.5].
const ERF_IMPL_ED: &[f64] = &[
1.0,
1.50376225203620482047419,
0.965397786204462896346934,
0.339265230476796681555511,
0.0689740649541569716897427,
0.00771060262491768307365526,
0.000371421101531069302990367,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [3.5, 5.25].
const ERF_IMPL_FN: &[f64] = &[
-0.00546954795538729307482955,
0.00404190278731707110245394,
0.0054963369553161170521356,
0.00212616472603945399437862,
0.000394984014495083900689956,
0.365565477064442377259271e-4,
0.135485897109932323253786e-5,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [3.5, 5.25].
const ERF_IMPL_FD: &[f64] = &[
1.0,
1.21019697773630784832251,
0.620914668221143886601045,
0.173038430661142762569515,
0.0276550813773432047594539,
0.00240625974424309709745382,
0.891811817251336577241006e-4,
-0.465528836283382684461025e-11,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [5.25, 8].
const ERF_IMPL_GN: &[f64] = &[
-0.00270722535905778347999196,
0.0013187563425029400461378,
0.00119925933261002333923989,
0.00027849619811344664248235,
0.267822988218331849989363e-4,
0.923043672315028197865066e-6,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [5.25, 8].
const ERF_IMPL_GD: &[f64] = &[
1.0,
0.814632808543141591118279,
0.268901665856299542168425,
0.0449877216103041118694989,
0.00381759663320248459168994,
0.000131571897888596914350697,
0.404815359675764138445257e-11,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [8, 11.5].
const ERF_IMPL_HN: &[f64] = &[
-0.00109946720691742196814323,
0.000406425442750422675169153,
0.000274499489416900707787024,
0.465293770646659383436343e-4,
0.320955425395767463401993e-5,
0.778286018145020892261936e-7,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [8, 11.5].
const ERF_IMPL_HD: &[f64] = &[
1.0,
0.588173710611846046373373,
0.139363331289409746077541,
0.0166329340417083678763028,
0.00100023921310234908642639,
0.24254837521587225125068e-4,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [11.5, 17].
const ERF_IMPL_IN: &[f64] = &[
-0.00056907993601094962855594,
0.000169498540373762264416984,
0.518472354581100890120501e-4,
0.382819312231928859704678e-5,
0.824989931281894431781794e-7,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [11.5, 17].
const ERF_IMPL_ID: &[f64] = &[
1.0,
0.339637250051139347430323,
0.043472647870310663055044,
0.00248549335224637114641629,
0.535633305337152900549536e-4,
-0.117490944405459578783846e-12,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [17, 24].
const ERF_IMPL_JN: &[f64] = &[
-0.000241313599483991337479091,
0.574224975202501512365975e-4,
0.115998962927383778460557e-4,
0.581762134402593739370875e-6,
0.853971555085673614607418e-8,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [17, 24].
const ERF_IMPL_JD: &[f64] = &[
1.0,
0.233044138299687841018015,
0.0204186940546440312625597,
0.000797185647564398289151125,
0.117019281670172327758019e-4,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [24, 38].
const ERF_IMPL_KN: &[f64] = &[
-0.000146674699277760365803642,
0.162666552112280519955647e-4,
0.269116248509165239294897e-5,
0.979584479468091935086972e-7,
0.101994647625723465722285e-8,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [24, 38].
const ERF_IMPL_KD: &[f64] = &[
1.0,
0.165907812944847226546036,
0.0103361716191505884359634,
0.000286593026373868366935721,
0.298401570840900340874568e-5,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [38, 60].
const ERF_IMPL_LN: &[f64] = &[
-0.583905797629771786720406e-4,
0.412510325105496173512992e-5,
0.431790922420250949096906e-6,
0.993365155590013193345569e-8,
0.653480510020104699270084e-10,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [38, 60].
const ERF_IMPL_LD: &[f64] = &[
1.0,
0.105077086072039915406159,
0.00414278428675475620830226,
0.726338754644523769144108e-4,
0.477818471047398785369849e-6,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [60, 85].
const ERF_IMPL_MN: &[f64] = &[
-0.196457797609229579459841e-4,
0.157243887666800692441195e-5,
0.543902511192700878690335e-7,
0.317472492369117710852685e-9,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [60, 85].
const ERF_IMPL_MD: &[f64] = &[
1.0,
0.052803989240957632204885,
0.000926876069151753290378112,
0.541011723226630257077328e-5,
0.535093845803642394908747e-15,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [85, 110].
const ERF_IMPL_NN: &[f64] = &[
-0.789224703978722689089794e-5,
0.622088451660986955124162e-6,
0.145728445676882396797184e-7,
0.603715505542715364529243e-10,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [85, 110].
const ERF_IMPL_ND: &[f64] = &[
1.0,
0.0375328846356293715248719,
0.000467919535974625308126054,
0.193847039275845656900547e-5,
];
// **********************************************************
// ********** Coefficients for erf_inv_impl polynomial ******
// **********************************************************
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0, 0.5].
const ERF_INV_IMPL_AN: &[f64] = &[
-0.000508781949658280665617,
-0.00836874819741736770379,
0.0334806625409744615033,
-0.0126926147662974029034,
-0.0365637971411762664006,
0.0219878681111168899165,
0.00822687874676915743155,
-0.00538772965071242932965,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0, 0.5].
const ERF_INV_IMPL_AD: &[f64] = &[
1.0,
-0.970005043303290640362,
-1.56574558234175846809,
1.56221558398423026363,
0.662328840472002992063,
-0.71228902341542847553,
-0.0527396382340099713954,
0.0795283687341571680018,
-0.00233393759374190016776,
0.000886216390456424707504,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.5, 0.75].
const ERF_INV_IMPL_BN: &[f64] = &[
-0.202433508355938759655,
0.105264680699391713268,
8.37050328343119927838,
17.6447298408374015486,
-18.8510648058714251895,
-44.6382324441786960818,
17.445385985570866523,
21.1294655448340526258,
-3.67192254707729348546,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.5, 0.75].
const ERF_INV_IMPL_BD: &[f64] = &[
1.0,
6.24264124854247537712,
3.9713437953343869095,
-28.6608180499800029974,
-20.1432634680485188801,
48.5609213108739935468,
10.8268667355460159008,
-22.6436933413139721736,
1.72114765761200282724,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.75, 1] with x less than 3.
const ERF_INV_IMPL_CN: &[f64] = &[
-0.131102781679951906451,
-0.163794047193317060787,
0.117030156341995252019,
0.387079738972604337464,
0.337785538912035898924,
0.142869534408157156766,
0.0290157910005329060432,
0.00214558995388805277169,
-0.679465575181126350155e-6,
0.285225331782217055858e-7,
-0.681149956853776992068e-9,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.75, 1] with x less than 3.
const ERF_INV_IMPL_CD: &[f64] = &[
1.0,
3.46625407242567245975,
5.38168345707006855425,
4.77846592945843778382,
2.59301921623620271374,
0.848854343457902036425,
0.152264338295331783612,
0.01105924229346489121,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 3 and 6.
const ERF_INV_IMPL_DN: &[f64] = &[
-0.0350353787183177984712,
-0.00222426529213447927281,
0.0185573306514231072324,
0.00950804701325919603619,
0.00187123492819559223345,
0.000157544617424960554631,
0.460469890584317994083e-5,
-0.230404776911882601748e-9,
0.266339227425782031962e-11,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 3 and 6.
const ERF_INV_IMPL_DD: &[f64] = &[
1.0,
1.3653349817554063097,
0.762059164553623404043,
0.220091105764131249824,
0.0341589143670947727934,
0.00263861676657015992959,
0.764675292302794483503e-4,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 6 and 18.
const ERF_INV_IMPL_EN: &[f64] = &[
-0.0167431005076633737133,
-0.00112951438745580278863,
0.00105628862152492910091,
0.000209386317487588078668,
0.149624783758342370182e-4,
0.449696789927706453732e-6,
0.462596163522878599135e-8,
-0.281128735628831791805e-13,
0.99055709973310326855e-16,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 6 and 18.
const ERF_INV_IMPL_ED: &[f64] = &[
1.0,
0.591429344886417493481,
0.138151865749083321638,
0.0160746087093676504695,
0.000964011807005165528527,
0.275335474764726041141e-4,
0.282243172016108031869e-6,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 18 and 44.
const ERF_INV_IMPL_FN: &[f64] = &[
-0.0024978212791898131227,
-0.779190719229053954292e-5,
0.254723037413027451751e-4,
0.162397777342510920873e-5,
0.396341011304801168516e-7,
0.411632831190944208473e-9,
0.145596286718675035587e-11,
-0.116765012397184275695e-17,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 18 and 44.
const ERF_INV_IMPL_FD: &[f64] = &[
1.0,
0.207123112214422517181,
0.0169410838120975906478,
0.000690538265622684595676,
0.145007359818232637924e-4,
0.144437756628144157666e-6,
0.509761276599778486139e-9,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.75, 1] with x greater than 44.
const ERF_INV_IMPL_GN: &[f64] = &[
-0.000539042911019078575891,
-0.28398759004727721098e-6,
0.899465114892291446442e-6,
0.229345859265920864296e-7,
0.225561444863500149219e-9,
0.947846627503022684216e-12,
0.135880130108924861008e-14,
-0.348890393399948882918e-21,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.75, 1] with x greater than 44.
const ERF_INV_IMPL_GD: &[f64] = &[
1.0,
0.0845746234001899436914,
0.00282092984726264681981,
0.468292921940894236786e-4,
0.399968812193862100054e-6,
0.161809290887904476097e-8,
0.231558608310259605225e-11,
];
/// `erf_impl` computes the error function at `z`.
/// If `inv` is true, `1 - erf` is calculated as opposed to `erf`
fn erf_impl(z: f64, inv: bool) -> f64 {
if z < 0.0 {
if !inv {
return -erf_impl(-z, false);
}
if z < -0.5 {
return 2.0 - erf_impl(-z, true);
}
return 1.0 + erf_impl(-z, false);
}
let result = if z < 0.5 {
if z < 1e-10 {
z * 1.125 + z * 0.003379167095512573896158903121545171688
} else {
z * 1.125
+ z * evaluate::polynomial(z, ERF_IMPL_AN) / evaluate::polynomial(z, ERF_IMPL_AD)
}
} else if z < 110.0 {
let (r, b) = if z < 0.75 {
(
evaluate::polynomial(z - 0.5, ERF_IMPL_BN)
/ evaluate::polynomial(z - 0.5, ERF_IMPL_BD),
0.3440242112,
)
} else if z < 1.25 {
(
evaluate::polynomial(z - 0.75, ERF_IMPL_CN)
/ evaluate::polynomial(z - 0.75, ERF_IMPL_CD),
0.419990927,
)
} else if z < 2.25 {
(
evaluate::polynomial(z - 1.25, ERF_IMPL_DN)
/ evaluate::polynomial(z - 1.25, ERF_IMPL_DD),
0.4898625016,
)
} else if z < 3.5 {
(
evaluate::polynomial(z - 2.25, ERF_IMPL_EN)
/ evaluate::polynomial(z - 2.25, ERF_IMPL_ED),
0.5317370892,
)
} else if z < 5.25 {
(
evaluate::polynomial(z - 3.5, ERF_IMPL_FN)
/ evaluate::polynomial(z - 3.5, ERF_IMPL_FD),
0.5489973426,
)
} else if z < 8.0 {
(
evaluate::polynomial(z - 5.25, ERF_IMPL_GN)
/ evaluate::polynomial(z - 5.25, ERF_IMPL_GD),
0.5571740866,
)
} else if z < 11.5 {
(
evaluate::polynomial(z - 8.0, ERF_IMPL_HN)
/ evaluate::polynomial(z - 8.0, ERF_IMPL_HD),
0.5609807968,
)
} else if z < 17.0 {
(
evaluate::polynomial(z - 11.5, ERF_IMPL_IN)
/ evaluate::polynomial(z - 11.5, ERF_IMPL_ID),
0.5626493692,
)
} else if z < 24.0 {
(
evaluate::polynomial(z - 17.0, ERF_IMPL_JN)
/ evaluate::polynomial(z - 17.0, ERF_IMPL_JD),
0.5634598136,
)
} else if z < 38.0 {
(
evaluate::polynomial(z - 24.0, ERF_IMPL_KN)
/ evaluate::polynomial(z - 24.0, ERF_IMPL_KD),
0.5638477802,
)
} else if z < 60.0 {
(
evaluate::polynomial(z - 38.0, ERF_IMPL_LN)
/ evaluate::polynomial(z - 38.0, ERF_IMPL_LD),
0.5640528202,
)
} else if z < 85.0 {
(
evaluate::polynomial(z - 60.0, ERF_IMPL_MN)
/ evaluate::polynomial(z - 60.0, ERF_IMPL_MD),
0.5641309023,
)
} else {
(
evaluate::polynomial(z - 85.0, ERF_IMPL_NN)
/ evaluate::polynomial(z - 85.0, ERF_IMPL_ND),
0.5641584396,
)
};
let g = (-z * z).exp() / z;
g * b + g * r
} else {
0.0
};
if inv && z >= 0.5 {
result
} else if z >= 0.5 || inv {
1.0 - result
} else {
result
}
}
// `erf_inv_impl` computes the inverse error function where
// `p`,`q`, and `s` are the first, second, and third intermediate
// parameters respectively
fn erf_inv_impl(p: f64, q: f64, s: f64) -> f64 {
let result = if p <= 0.5 {
let y = 0.0891314744949340820313;
let g = p * (p + 10.0);
let r = evaluate::polynomial(p, ERF_INV_IMPL_AN) / evaluate::polynomial(p, ERF_INV_IMPL_AD);
g * y + g * r
} else if q >= 0.25 {
let y = 2.249481201171875;
let g = (-2.0 * q.ln()).sqrt();
let xs = q - 0.25;
let r =
evaluate::polynomial(xs, ERF_INV_IMPL_BN) / evaluate::polynomial(xs, ERF_INV_IMPL_BD);
g / (y + r)
} else {
let x = (-q.ln()).sqrt();
if x < 3.0 {
let y = 0.807220458984375;
let xs = x - 1.125;
let r = evaluate::polynomial(xs, ERF_INV_IMPL_CN)
/ evaluate::polynomial(xs, ERF_INV_IMPL_CD);
y * x + r * x
} else if x < 6.0 {
let y = 0.93995571136474609375;
let xs = x - 3.0;
let r = evaluate::polynomial(xs, ERF_INV_IMPL_DN)
/ evaluate::polynomial(xs, ERF_INV_IMPL_DD);
y * x + r * x
} else if x < 18.0 {
let y = 0.98362827301025390625;
let xs = x - 6.0;
let r = evaluate::polynomial(xs, ERF_INV_IMPL_EN)
/ evaluate::polynomial(xs, ERF_INV_IMPL_ED);
y * x + r * x
} else if x < 44.0 {
let y = 0.99714565277099609375;
let xs = x - 18.0;
let r = evaluate::polynomial(xs, ERF_INV_IMPL_FN)
/ evaluate::polynomial(xs, ERF_INV_IMPL_FD);
y * x + r * x
} else {
let y = 0.99941349029541015625;
let xs = x - 44.0;
let r = evaluate::polynomial(xs, ERF_INV_IMPL_GN)
/ evaluate::polynomial(xs, ERF_INV_IMPL_GD);
y * x + r * x
}
};
s * result
}

View File

@ -1,7 +1,4 @@
pub trait VecOps: num_traits::NumAssign + Copy {
fn min(self, rhs: Self) -> Self;
fn max(self, rhs: Self) -> Self;
/// Dot-product of two vectors.
///
/// # Safety
@ -29,47 +26,9 @@ pub trait VecOps: num_traits::NumAssign + Copy {
*res += *xs.add(i)
}
}
/// Maximum element in a non-empty vector.
///
/// # Safety
///
/// The length of `xs` must be at least `len` and positive. `res` has to point to a valid
/// element.
#[inline(always)]
unsafe fn vec_reduce_max(xs: *const Self, res: *mut Self, len: usize) {
*res = *xs;
for i in 1..len {
*res = (*res).max(*xs.add(i))
}
}
/// Minimum element in a non-empty vector.
///
/// # Safety
///
/// The length of `xs` must be at least `len` and positive. `res` has to point to a valid
/// element.
#[inline(always)]
unsafe fn vec_reduce_min(xs: *const Self, res: *mut Self, len: usize) {
*res = *xs;
for i in 1..len {
*res = (*res).min(*xs.add(i))
}
}
}
impl VecOps for f32 {
#[inline(always)]
fn min(self, other: Self) -> Self {
Self::min(self, other)
}
#[inline(always)]
fn max(self, other: Self) -> Self {
Self::max(self, other)
}
#[inline(always)]
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
super::vec_dot_f32(lhs, rhs, res, len)
@ -82,16 +41,6 @@ impl VecOps for f32 {
}
impl VecOps for half::f16 {
#[inline(always)]
fn min(self, other: Self) -> Self {
Self::min(self, other)
}
#[inline(always)]
fn max(self, other: Self) -> Self {
Self::max(self, other)
}
#[inline(always)]
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
let mut res_f32 = 0f32;
@ -100,61 +49,10 @@ impl VecOps for half::f16 {
}
}
impl VecOps for f64 {
#[inline(always)]
fn min(self, other: Self) -> Self {
Self::min(self, other)
}
#[inline(always)]
fn max(self, other: Self) -> Self {
Self::max(self, other)
}
}
impl VecOps for half::bf16 {
#[inline(always)]
fn min(self, other: Self) -> Self {
Self::min(self, other)
}
#[inline(always)]
fn max(self, other: Self) -> Self {
Self::max(self, other)
}
}
impl VecOps for u8 {
#[inline(always)]
fn min(self, other: Self) -> Self {
<Self as Ord>::min(self, other)
}
#[inline(always)]
fn max(self, other: Self) -> Self {
<Self as Ord>::max(self, other)
}
}
impl VecOps for u32 {
#[inline(always)]
fn min(self, other: Self) -> Self {
<Self as Ord>::min(self, other)
}
#[inline(always)]
fn max(self, other: Self) -> Self {
<Self as Ord>::max(self, other)
}
}
impl VecOps for i64 {
#[inline(always)]
fn min(self, other: Self) -> Self {
<Self as Ord>::min(self, other)
}
#[inline(always)]
fn max(self, other: Self) -> Self {
<Self as Ord>::max(self, other)
}
}
impl VecOps for f64 {}
impl VecOps for half::bf16 {}
impl VecOps for u8 {}
impl VecOps for u32 {}
#[inline(always)]
pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {

View File

@ -1,4 +1,3 @@
pub mod erf;
pub mod kernels;
trait Cpu<const ARR: usize> {

File diff suppressed because it is too large Load Diff

View File

@ -1,7 +1,7 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType};
pub use candle_kernels as kernels;
use candle_kernels as kernels;
pub use cudarc;
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
use cudarc::driver::{
@ -139,14 +139,6 @@ impl CudaDevice {
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::U32(data)
}
DType::I64 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
let params = (&data, v as i64, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
@ -223,14 +215,6 @@ impl BackendDevice for CudaDevice {
})
}
fn set_seed(&self, seed: u64) -> Result<()> {
// We do not call set_seed but instead create a new curand object. This ensures that the
// state will be identical and the same random numbers will be generated.
let mut curand = self.curand.lock().unwrap();
curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
Ok(())
}
fn location(&self) -> crate::DeviceLocation {
crate::DeviceLocation::Cuda {
gpu_id: self.device.ordinal(),
@ -252,10 +236,6 @@ impl BackendDevice for CudaDevice {
let data = self.alloc_zeros::<u32>(elem_count).w()?;
CudaStorageSlice::U32(data)
}
DType::I64 => {
let data = self.alloc_zeros::<i64>(elem_count).w()?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
let data = self.alloc_zeros::<bf16>(elem_count).w()?;
CudaStorageSlice::BF16(data)
@ -285,13 +265,11 @@ impl BackendDevice for CudaDevice {
let slice = match dtype {
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
// cudarc changes.
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_uniform",
})
.w()?
}
DType::U8 | DType::U32 | DType::F16 | DType::BF16 => Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_uniform",
})
.w()?,
DType::F32 => {
let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
curand.0.fill_with_uniform(&mut data).w()?;
@ -303,12 +281,10 @@ impl BackendDevice for CudaDevice {
CudaStorageSlice::F64(data)
}
};
let slice = if lo == 0. && up == 1.0 {
slice
} else {
if lo != 0.0 || up != 1.0 {
let layout = Layout::contiguous(shape);
Affine(up - lo, lo).map(&slice, self, &layout)?
};
Affine(up - lo, lo).map(&slice, self, &layout)?;
}
Ok(CudaStorage {
slice,
device: self.clone(),
@ -320,23 +296,14 @@ impl BackendDevice for CudaDevice {
// cudarc changes.
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
// curand can only generate an odd number of values.
// https://github.com/huggingface/candle/issues/734
let elem_count_round = if elem_count % 2 == 1 {
elem_count + 1
} else {
elem_count
};
let slice = match dtype {
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_normal",
})
.w()?
}
DType::U8 | DType::U32 | DType::F16 | DType::BF16 => Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_normal",
})
.w()?,
DType::F32 => {
let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
curand
.0
.fill_with_normal(&mut data, mean as f32, std as f32)
@ -344,7 +311,7 @@ impl BackendDevice for CudaDevice {
CudaStorageSlice::F32(data)
}
DType::F64 => {
let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
curand.0.fill_with_normal(&mut data, mean, std).w()?;
CudaStorageSlice::F64(data)
}
@ -369,10 +336,6 @@ impl BackendDevice for CudaDevice {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::U32(data)
}
CpuStorage::I64(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::I64(data)
}
CpuStorage::BF16(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::BF16(data)
@ -398,10 +361,9 @@ impl BackendDevice for CudaDevice {
}
#[derive(Debug)]
pub enum CudaStorageSlice {
enum CudaStorageSlice {
U8(CudaSlice<u8>),
U32(CudaSlice<u32>),
I64(CudaSlice<i64>),
BF16(CudaSlice<bf16>),
F16(CudaSlice<f16>),
F32(CudaSlice<f32>),
@ -409,7 +371,7 @@ pub enum CudaStorageSlice {
}
type S = CudaStorageSlice;
pub trait Map1 {
trait Map1 {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src: &CudaSlice<T>,
@ -421,7 +383,6 @@ pub trait Map1 {
let out = match s {
S::U8(s) => S::U8(self.f(s, d, l)?),
S::U32(s) => S::U32(self.f(s, d, l)?),
S::I64(s) => S::I64(self.f(s, d, l)?),
S::BF16(s) => S::BF16(self.f(s, d, l)?),
S::F16(s) => S::F16(self.f(s, d, l)?),
S::F32(s) => S::F32(self.f(s, d, l)?),
@ -431,7 +392,7 @@ pub trait Map1 {
}
}
pub trait Map2 {
trait Map2 {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src1: &CudaSlice<T>,
@ -445,7 +406,6 @@ pub trait Map2 {
let out = match (s1, s2) {
(S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?),
(S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
(S::I64(s1), S::I64(s2)) => S::I64(self.f(s1, l1, s2, l2, d)?),
(S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
(S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
(S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
@ -456,7 +416,7 @@ pub trait Map2 {
}
}
pub trait Map2InPlace {
trait Map2InPlace {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
dst: &mut CudaSlice<T>,
@ -477,7 +437,6 @@ pub trait Map2InPlace {
match (dst, src) {
(S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
(S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
(S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d),
(S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
@ -487,7 +446,7 @@ pub trait Map2InPlace {
}
}
pub trait Map1Any {
trait Map1Any {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
&self,
src: &CudaSlice<T>,
@ -500,7 +459,6 @@ pub trait Map1Any {
let out = match s {
S::U8(s) => self.f(s, d, l, S::U8)?,
S::U32(s) => self.f(s, d, l, S::U32)?,
S::I64(s) => self.f(s, d, l, S::I64)?,
S::BF16(s) => self.f(s, d, l, S::BF16)?,
S::F16(s) => self.f(s, d, l, S::F16)?,
S::F32(s) => self.f(s, d, l, S::F32)?,
@ -510,7 +468,7 @@ pub trait Map1Any {
}
}
pub trait Map2Any {
trait Map2Any {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src1: &CudaSlice<T>,
@ -524,7 +482,6 @@ pub trait Map2Any {
let out = match (s1, s2) {
(S::U8(s1), S::U8(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::U32(s1), S::U32(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::I64(s1), S::I64(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::BF16(s1), S::BF16(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?,
@ -547,7 +504,7 @@ impl Map1 for Clone {
}
}
pub fn kernel_name<T: WithDType>(root: &str) -> String {
fn kernel_name<T: WithDType>(root: &str) -> String {
let dtype = T::DTYPE.as_str();
format!("{root}_{dtype}")
}
@ -608,129 +565,6 @@ impl Map1 for Elu {
}
}
struct Im2Col1D {
l_k: usize,
stride: usize,
dilation: usize,
padding: usize,
}
impl Im2Col1D {
fn l_out(&self, l: usize) -> usize {
(l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
}
}
impl Map1 for Im2Col1D {
fn f<T: DeviceRepr + WithDType>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>> {
let shape = layout.shape();
let dims = shape.dims();
let l_out = self.l_out(dims[2]);
let dst_el = dims[0] * l_out * dims[1] * self.l_k;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("im2col1d"), kernels::CONV)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let params = (
dst_el,
l_out,
self.l_k,
self.stride,
self.padding,
self.dilation,
&ds,
src,
&dst,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)
}
}
struct Im2Col {
h_k: usize,
w_k: usize,
stride: usize,
dilation: usize,
padding: usize,
}
impl Im2Col {
fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
(h_out, w_out)
}
}
impl Map1 for Im2Col {
fn f<T: DeviceRepr + WithDType>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>> {
let shape = layout.shape();
let dims = shape.dims();
let (h_out, w_out) = self.hw_out(dims[2], dims[3]);
let dst_el = dims[0] * h_out * w_out * dims[1] * self.h_k * self.w_k;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("im2col"), kernels::CONV)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let params = (
dst_el,
h_out,
w_out,
self.h_k,
self.w_k,
self.stride,
self.padding,
self.dilation,
&ds,
src,
&dst,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)
}
}
struct Powf(f64);
impl Map1 for Powf {
fn f<T: DeviceRepr + WithDType>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>> {
let shape = layout.shape();
let dims = shape.dims();
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("upowf"), kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (el, dims.len(), &ds, T::from_f64(self.0), src, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
}
}
struct Sum<'a>(&'a [usize]);
impl<'a> Map1 for Sum<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@ -880,9 +714,6 @@ impl<'a> Map1 for IndexSelect<'a> {
CudaStorageSlice::U8(slice) => {
("is_u8", *slice.slice(ids_l.start_offset()..).device_ptr())
}
CudaStorageSlice::I64(slice) => {
("is_i64", *slice.slice(ids_l.start_offset()..).device_ptr())
}
_ => Err(CudaError::UnexpectedDType {
msg: "index_select ids should be u8 or u32",
expected: DType::U32,
@ -892,6 +723,8 @@ impl<'a> Map1 for IndexSelect<'a> {
};
let ids_shape = ids_l.shape();
let ids_dims = ids_shape.dims();
let ids_el = ids_shape.elem_count();
let cfg = LaunchConfig::for_num_elems(ids_el as u32);
let ds = dev.htod_copy([ids_dims, ids_l.stride()].concat()).w()?;
let src = match src_l.contiguous_offsets() {
Some((o1, o2)) => src.slice(o1..o2),
@ -899,23 +732,19 @@ impl<'a> Map1 for IndexSelect<'a> {
};
let left_size: usize = src_l.dims()[..self.2].iter().product();
let right_size: usize = src_l.dims()[self.2 + 1..].iter().product();
let src_dim_size = src_l.dims()[self.2];
let ids_dim_size = ids_shape.elem_count();
let dst_el = ids_shape.elem_count() * left_size * right_size;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let dim_size = src_l.dims()[self.2];
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let out = unsafe { dev.alloc::<T>(ids_el * left_size * right_size) }.w()?;
let params = (
dst_el,
ids_el,
ids_dims.len(),
&ds,
ids,
&src,
&out,
left_size,
src_dim_size,
ids_dim_size,
dim_size,
right_size,
);
// SAFETY: ffi.
@ -944,11 +773,8 @@ impl<'a> Map1 for Gather<'a> {
("gather_u32", *slice.slice(ids_o1..ids_o2).device_ptr())
}
CudaStorageSlice::U8(slice) => ("gather_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::I64(slice) => {
("gather_i64", *slice.slice(ids_o1..ids_o2).device_ptr())
}
_ => Err(CudaError::UnexpectedDType {
msg: "gather ids should be u8/u32/i64",
msg: "gather ids should be u8 or u32",
expected: DType::U32,
got: ids.dtype(),
})?,
@ -994,10 +820,9 @@ impl<'a> Map2InPlace for IndexAdd<'a> {
};
let (name, ids) = match &ids.slice {
CudaStorageSlice::U32(slice) => ("ia_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::I64(slice) => ("ia_i64", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::U8(slice) => ("ia_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
_ => Err(CudaError::UnexpectedDType {
msg: "index-add ids should be u8/u32/i64",
msg: "index-add ids should be u8 or u32",
expected: DType::U32,
got: ids.dtype(),
})?,
@ -1042,10 +867,9 @@ impl<'a> Map2InPlace for ScatterAdd<'a> {
};
let (name, ids) = match &ids.slice {
CudaStorageSlice::U32(slice) => ("sa_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::I64(slice) => ("sa_i64", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::U8(slice) => ("sa_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
_ => Err(CudaError::UnexpectedDType {
msg: "scatter-add ids should be u8/u32/i64",
msg: "scatter-add ids should be u8 or u32",
expected: DType::U32,
got: ids.dtype(),
})?,
@ -1097,12 +921,10 @@ impl<'a> Map2 for Conv1D<'a> {
} else if dims.len() == 2 {
[&[1], dims, &[1], inp_l.stride(), k_l.dims(), k_l.stride()].concat()
} else {
crate::bail!("unexpected input shape for conv1d {dims:?}")
panic!("unexpected input shape for conv1d {dims:?}")
};
let ds = dev.htod_copy(ds).w()?;
let params = (
el, l_out, p.stride, p.padding, p.dilation, &ds, inp, k, &out,
);
let params = (el, l_out, p.stride, p.padding, &ds, inp, k, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
@ -1119,8 +941,8 @@ impl<'a> Map2 for Conv2D<'a> {
k_l: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>> {
// Kernel shape: (c_out, c_in_k, h_k, w_k)
// Input shape: (b_size, c_in, h_in, w_in)
// Kernel shape: (c_out, c_in_k, w_k, h_k)
// Input shape: (b_size, c_in, w_in, c_in)
let p = &self.0;
let (out_w, out_h) = (p.out_w(), p.out_h());
let dst_el = p.c_out * out_w * out_h * p.b_size;
@ -1137,111 +959,10 @@ impl<'a> Map2 for Conv2D<'a> {
let ds = if dims.len() == 4 {
[dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
} else {
crate::bail!("unexpected input shape for conv2d {dims:?}")
panic!("unexpected input shape for conv1d {dims:?}")
};
let ds = dev.htod_copy(ds).w()?;
let params = (
el, out_w, out_h, p.stride, p.padding, p.dilation, &ds, inp, k, &out,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
}
}
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
impl<'a> Map2 for ConvTranspose1D<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
inp: &CudaSlice<T>,
inp_l: &Layout,
k: &CudaSlice<T>,
k_l: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>> {
// Kernel shape: (c_in_k, c_out, l_k)
// Input shape: (b_size, c_in, l_in)
let p = &self.0;
let l_out = p.l_out();
let dst_el = p.c_out * l_out * p.b_size;
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(k_l.start_offset()..);
let shape = inp_l.shape();
let dims = shape.dims();
let el = shape.elem_count();
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose1d"), kernels::CONV)?;
let ds = if dims.len() == 3 {
[dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
} else {
crate::bail!("unexpected input shape for conv_transpose1d {dims:?}")
};
let ds = dev.htod_copy(ds).w()?;
let params = (
el,
l_out,
p.stride,
p.padding,
p.output_padding,
p.dilation,
&ds,
inp,
k,
&out,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
}
}
struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
impl<'a> Map2 for ConvTranspose2D<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
inp: &CudaSlice<T>,
inp_l: &Layout,
k: &CudaSlice<T>,
k_l: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>> {
// Kernel shape: (c_in_k, c_out, h_k, w_k)
// Input shape: (b_size, c_in, h_in, w_in)
let p = &self.0;
let (out_w, out_h) = (p.out_w(), p.out_h());
let dst_el = p.c_out * out_w * out_h * p.b_size;
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(k_l.start_offset()..);
let shape = inp_l.shape();
let dims = shape.dims();
let el = shape.elem_count();
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose2d"), kernels::CONV)?;
let ds = if dims.len() == 4 {
[dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
} else {
crate::bail!("unexpected input shape for conv_transpose2d {dims:?}")
};
let ds = dev.htod_copy(ds).w()?;
let params = (
el,
out_w,
out_h,
p.stride,
p.padding,
p.output_padding,
p.dilation,
&ds,
inp,
k,
&out,
);
let params = (el, out_w, out_h, p.stride, p.padding, &ds, inp, k, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
@ -1275,7 +996,7 @@ impl Map1 for Pool2D {
let ds = if dims.len() == 4 {
[dims, inp_l.stride()].concat()
} else {
crate::bail!("unexpected input shape for pool {dims:?}")
panic!("unexpected input shape for conv1d {dims:?}")
};
let el = shape.elem_count();
let out_w = (dims[2] - self.w_k) / self.w_stride + 1;
@ -1321,7 +1042,7 @@ impl Map1 for UpsampleNearest2D {
let ds = if dims.len() == 4 {
[dims, inp_l.stride()].concat()
} else {
crate::bail!("unexpected input shape for upsample {dims:?}")
panic!("unexpected input shape for conv1d {dims:?}")
};
let (out_w, out_h) = (self.0, self.1);
let dst_el = out_w * out_h * dims[0] * dims[1];
@ -1359,12 +1080,8 @@ impl<'a> Map2 for WhereCond<'a> {
let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
(ptr, "where_u32")
}
CudaStorageSlice::I64(slice) => {
let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
(ptr, "where_i64")
}
_ => Err(CudaError::UnexpectedDType {
msg: "where conditions should be u8/u32/i64",
msg: "where conditions should be u8 or u32",
expected: DType::U32,
got: self.0.dtype(),
})
@ -1475,8 +1192,8 @@ fn slice_src_and_dst<'a, T>(
#[derive(Debug)]
pub struct CudaStorage {
pub slice: CudaStorageSlice,
pub device: CudaDevice,
slice: CudaStorageSlice,
device: CudaDevice,
}
pub trait CudaDType: Sized {
@ -1508,7 +1225,6 @@ macro_rules! cuda_dtype {
}
cuda_dtype!(u8, U8);
cuda_dtype!(u32, U32);
cuda_dtype!(i64, I64);
cuda_dtype!(f16, F16);
cuda_dtype!(bf16, BF16);
cuda_dtype!(f32, F32);
@ -1622,7 +1338,6 @@ impl BackendStorage for CudaStorage {
match self.slice {
CudaStorageSlice::U8(_) => DType::U8,
CudaStorageSlice::U32(_) => DType::U32,
CudaStorageSlice::I64(_) => DType::I64,
CudaStorageSlice::BF16(_) => DType::BF16,
CudaStorageSlice::F16(_) => DType::F16,
CudaStorageSlice::F32(_) => DType::F32,
@ -1648,7 +1363,6 @@ impl BackendStorage for CudaStorage {
let inp = match &self.slice {
CudaStorageSlice::U8(inp) => *inp.slice(start_o..).device_ptr(),
CudaStorageSlice::U32(inp) => *inp.slice(start_o..).device_ptr(),
CudaStorageSlice::I64(inp) => *inp.slice(start_o..).device_ptr(),
CudaStorageSlice::BF16(inp) => *inp.slice(start_o..).device_ptr(),
CudaStorageSlice::F16(inp) => *inp.slice(start_o..).device_ptr(),
CudaStorageSlice::F32(inp) => *inp.slice(start_o..).device_ptr(),
@ -1671,12 +1385,6 @@ impl BackendStorage for CudaStorage {
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::U32(out)
}
DType::I64 => {
let out = unsafe { dev.alloc::<i64>(el) }.w()?;
let params = (el, dims.len(), &ds, *inp, &out);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::I64(out)
}
DType::BF16 => {
let out = unsafe { dev.alloc::<bf16>(el) }.w()?;
let params = (el, dims.len(), &ds, *inp, &out);
@ -1714,12 +1422,6 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device })
}
fn powf(&self, layout: &Layout, e: f64) -> Result<Self> {
let device = self.device().clone();
let slice = Powf(e).map(&self.slice, &device, layout)?;
Ok(Self { slice, device })
}
fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
let device = self.device().clone();
let slice = Elu(alpha).map(&self.slice, &device, layout)?;
@ -1767,11 +1469,6 @@ impl BackendStorage for CudaStorage {
let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
Ok(CpuStorage::U32(cpu_storage))
}
CudaStorageSlice::I64(slice) => {
let dev = slice.device();
let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
Ok(CpuStorage::I64(cpu_storage))
}
CudaStorageSlice::BF16(slice) => {
let dev = slice.device();
let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
@ -1815,58 +1512,8 @@ impl BackendStorage for CudaStorage {
kernel_l: &Layout,
params: &crate::conv::ParamsConv1D,
) -> Result<Self> {
const USE_IM2COL_CONV1D: bool = true;
let device = self.device().clone();
if !USE_IM2COL_CONV1D {
let slice = Conv1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
return Ok(Self { slice, device });
}
let col = Im2Col1D {
l_k: params.k_size,
stride: params.stride,
dilation: params.dilation,
padding: params.padding,
}
.map(&self.slice, &device, l)?;
let col = Self { slice: col, device };
let l_out = params.l_out();
let b = params.b_size;
let n = params.c_out;
let k = params.k_size * params.c_in;
let m = l_out;
let col_l = Layout::contiguous((b, m, k));
let res = if kernel_l.is_contiguous() {
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
};
let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t)
}
fn conv_transpose1d(
&self,
l: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
let device = self.device().clone();
let slice =
ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
let slice = Conv1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
Ok(Self { slice, device })
}
@ -1878,50 +1525,9 @@ impl BackendStorage for CudaStorage {
kernel_l: &Layout,
params: &crate::conv::ParamsConv2D,
) -> Result<Self> {
const USE_IM2COL_CONV2D: bool = true;
let device = self.device().clone();
if !USE_IM2COL_CONV2D {
let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
return Ok(Self { slice, device });
}
let col = Im2Col {
h_k: params.k_h,
w_k: params.k_w,
stride: params.stride,
dilation: params.dilation,
padding: params.padding,
}
.map(&self.slice, &device, l)?;
let col = Self { slice: col, device };
let h_out = params.out_h();
let w_out = params.out_w();
let b = params.b_size;
let n = params.c_out;
let k = params.k_h * params.k_w * params.c_in;
let m = h_out * w_out;
let col_l = Layout::contiguous((b, m, k));
let res = if kernel_l.is_contiguous() {
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
};
let res_l = Layout::contiguous((b, h_out, w_out, n))
.transpose(1, 2)?
.transpose(1, 3)?;
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t)
let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
Ok(Self { slice, device })
}
#[cfg(feature = "cudnn")]
@ -1964,6 +1570,7 @@ impl BackendStorage for CudaStorage {
.map_err(crate::Error::wrap)?;
S::F16(out)
}
(S::F32(inp), S::F32(k)) => {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
@ -1981,25 +1588,11 @@ impl BackendStorage for CudaStorage {
S::F64(out)
}
(S::U32(_), S::U32(_)) => Err(CudaError::InternalError("conv2d does not support u32"))?,
(S::I64(_), S::I64(_)) => Err(CudaError::InternalError("conv2d does not support i64"))?,
_ => Err(CudaError::InternalError("dtype mismatch in conv2d"))?,
};
Ok(Self { slice, device })
}
fn conv_transpose2d(
&self,
l: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &crate::conv::ParamsConvTranspose2D,
) -> Result<Self> {
let device = self.device().clone();
let slice =
ConvTranspose2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
Ok(Self { slice, device })
}
fn avg_pool2d(&self, l: &Layout, k: (usize, usize), stride: (usize, usize)) -> Result<Self> {
let device = self.device().clone();
let slice = Pool2D {
@ -2026,10 +1619,6 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device })
}
fn upsample_nearest1d(&self, _: &Layout, _out_sz: usize) -> Result<Self> {
crate::bail!("upsample-nearest1d is not supported on cuda")
}
fn upsample_nearest2d(&self, l: &Layout, out_w: usize, out_h: usize) -> Result<Self> {
let device = self.device().clone();
let slice = UpsampleNearest2D(out_w, out_h).map(&self.slice, &device, l)?;
@ -2149,9 +1738,6 @@ impl BackendStorage for CudaStorage {
let src_shape = src_l.shape();
let dims = src_shape.dims();
let el_count = src_shape.elem_count();
if el_count == 0 {
return Ok(());
}
let cfg = LaunchConfig::for_num_elems(el_count as u32);
let dev = &self.device;
let ds = dev.htod_copy([dims, src_l.stride()].concat()).w()?;
@ -2216,24 +1802,12 @@ impl BackendStorage for CudaStorage {
unsafe { func.launch(cfg, params) }.w()?
}
}
(CudaStorageSlice::I64(src), CudaStorageSlice::I64(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
if src_l.is_contiguous() {
dev.dtod_copy(&src, &mut dst).w()?
} else {
let func = dev.get_or_load_func("ucopy_i64", kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let params = (el_count, dims.len(), &ds, &src, &mut dst);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?
}
}
(CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
if src_l.is_contiguous() {
dev.dtod_copy(&src, &mut dst).w()?
} else {
let func = dev.get_or_load_func("ucopy_f64", kernels::UNARY)?;
let func = dev.get_or_load_func("ucopy_64", kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let params = (el_count, dims.len(), &ds, &src, &mut dst);
// SAFETY: ffi.

View File

@ -34,9 +34,6 @@ pub(crate) fn launch_conv2d<
params: &crate::conv::ParamsConv2D,
dev: &crate::cuda_backend::CudaDevice,
) -> crate::Result<()> {
use crate::conv::CudnnFwdAlgo as CandleAlgo;
use cudarc::cudnn::sys::cudnnConvolutionFwdAlgo_t as A;
let device_id = dev.id();
let cudnn = CUDNN.with(|cudnn| {
if let Some(cudnn) = cudnn.borrow().get(&device_id) {
@ -51,14 +48,14 @@ pub(crate) fn launch_conv2d<
let conv = cudnn.create_conv2d::<T>(
/* pad */ [params.padding as i32, params.padding as i32],
/* stride */ [params.stride as i32, params.stride as i32],
/* dilation */ [params.dilation as i32, params.dilation as i32],
/* dilation */ [1, 1],
cudarc::cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
)?;
let x_shape = [
params.b_size as i32,
params.c_in as i32,
params.i_h as i32,
params.i_w as i32,
params.i_h as i32,
];
// Note that `src` already starts at the proper offset.
let x = if src_l.is_contiguous() {
@ -78,14 +75,14 @@ pub(crate) fn launch_conv2d<
[
params.c_out as i32,
params.c_in as i32,
params.k_h as i32,
params.k_w as i32,
params.k_h as i32,
],
)?;
let (w_out, h_out) = (params.out_w() as i32, params.out_h() as i32);
let y = cudnn.create_4d_tensor(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
[params.b_size as i32, params.c_out as i32, h_out, w_out],
[params.b_size as i32, params.c_out as i32, w_out, h_out],
)?;
let conv2d = Conv2dForward {
conv: &conv,
@ -93,20 +90,7 @@ pub(crate) fn launch_conv2d<
w: &w,
y: &y,
};
let alg = match params.cudnn_fwd_algo {
None => conv2d.pick_algorithm()?,
Some(CandleAlgo::ImplicitGemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
Some(CandleAlgo::ImplicitPrecompGemm) => {
A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
}
Some(CandleAlgo::Gemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
Some(CandleAlgo::Direct) => A::CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
Some(CandleAlgo::Fft) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT,
Some(CandleAlgo::FftTiling) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
Some(CandleAlgo::Winograd) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
Some(CandleAlgo::WinogradNonFused) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
};
let alg = conv2d.pick_algorithm()?;
let workspace_size = conv2d.get_workspace_size(alg)?;
let mut workspace = dev.cuda_device().alloc_zeros::<u8>(workspace_size)?;
unsafe {

View File

@ -8,16 +8,15 @@ use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
pub enum DeviceLocation {
Cpu,
Cuda { gpu_id: usize },
Metal { gpu_id: usize },
}
#[derive(Debug, Clone)]
pub enum Device {
Cpu,
Cuda(crate::CudaDevice),
Metal(crate::MetalDevice),
}
// TODO: Should we back the cpu implementation using the NdArray crate or similar?
pub trait NdArray {
fn shape(&self) -> Result<Shape>;
@ -82,71 +81,15 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
}
}
impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4: usize> NdArray
for &[[[[S; N4]; N3]; N2]; N1]
{
fn shape(&self) -> Result<Shape> {
Ok(Shape::from((N1, N2, N3, N4)))
}
fn to_cpu_storage(&self) -> CpuStorage {
let mut vec = Vec::with_capacity(N1 * N2 * N3 * N4);
for i1 in 0..N1 {
for i2 in 0..N2 {
for i3 in 0..N3 {
vec.extend(self[i1][i2][i3])
}
}
}
S::to_cpu_storage_owned(vec)
}
}
impl<S: NdArray> NdArray for Vec<S> {
fn shape(&self) -> Result<Shape> {
if self.is_empty() {
crate::bail!("empty array")
}
let shape0 = self[0].shape()?;
let n = self.len();
for v in self.iter() {
let shape = v.shape()?;
if shape != shape0 {
crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
}
}
Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
}
fn to_cpu_storage(&self) -> CpuStorage {
// This allocates intermediary memory and shouldn't be necessary.
let storages = self.iter().map(|v| v.to_cpu_storage()).collect::<Vec<_>>();
CpuStorage::concat(storages.as_slice()).unwrap()
}
}
impl Device {
pub fn new_cuda(ordinal: usize) -> Result<Self> {
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
}
pub fn new_metal(ordinal: usize) -> Result<Self> {
Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
}
pub fn set_seed(&self, seed: u64) -> Result<()> {
match self {
Self::Cpu => CpuDevice.set_seed(seed),
Self::Cuda(c) => c.set_seed(seed),
Self::Metal(m) => m.set_seed(seed),
}
}
pub fn same_device(&self, rhs: &Self) -> bool {
match (self, rhs) {
(Self::Cpu, Self::Cpu) => true,
(Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
(Self::Metal(lhs), Self::Metal(rhs)) => lhs.same_device(rhs),
_ => false,
}
}
@ -155,20 +98,21 @@ impl Device {
match self {
Self::Cpu => DeviceLocation::Cpu,
Self::Cuda(device) => device.location(),
Device::Metal(device) => device.location(),
}
}
pub fn is_cpu(&self) -> bool {
matches!(self, Self::Cpu)
match self {
Self::Cpu => true,
Self::Cuda(_) => false,
}
}
pub fn is_cuda(&self) -> bool {
matches!(self, Self::Cuda(_))
}
pub fn is_metal(&self) -> bool {
matches!(self, Self::Metal(_))
match self {
Self::Cpu => false,
Self::Cuda(_) => true,
}
}
pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
@ -192,18 +136,8 @@ impl Device {
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
// TODO: Remove the special case if we start supporting generating f16/bf16 directly.
if dtype == DType::F16 || dtype == DType::BF16 {
let storage = device.rand_uniform(shape, DType::F32, lo, up)?;
Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
} else {
let storage = device.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Cuda(storage))
}
}
Device::Metal(device) => {
let storage = device.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Metal(storage))
Ok(Storage::Cuda(storage))
}
}
}
@ -230,18 +164,8 @@ impl Device {
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
// TODO: Remove the special case if we start supporting generating f16/bf16 directly.
if dtype == DType::F16 || dtype == DType::BF16 {
let storage = device.rand_normal(shape, DType::F32, mean, std)?;
Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
} else {
let storage = device.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Cuda(storage))
}
}
Device::Metal(device) => {
let storage = device.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Metal(storage))
Ok(Storage::Cuda(storage))
}
}
}
@ -265,10 +189,6 @@ impl Device {
let storage = device.ones_impl(shape, dtype)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = device.ones_impl(shape, dtype)?;
Ok(Storage::Metal(storage))
}
}
}
@ -282,10 +202,6 @@ impl Device {
let storage = device.zeros_impl(shape, dtype)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = device.zeros_impl(shape, dtype)?;
Ok(Storage::Metal(storage))
}
}
}
@ -297,11 +213,6 @@ impl Device {
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = array.to_cpu_storage();
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Metal(storage))
}
}
}
@ -313,11 +224,6 @@ impl Device {
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = S::to_cpu_storage_owned(data);
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Metal(storage))
}
}
}
}

View File

@ -9,17 +9,11 @@ impl Tensor {
&self,
f: &mut std::fmt::Formatter,
) -> std::fmt::Result {
let device_str = match self.device().location() {
crate::DeviceLocation::Cpu => "".to_owned(),
crate::DeviceLocation::Cuda { gpu_id } => {
format!(", cuda:{}", gpu_id)
}
crate::DeviceLocation::Metal { gpu_id } => {
format!(", metal:{}", gpu_id)
}
let prefix = match self.device() {
crate::Device::Cpu => "Cpu",
crate::Device::Cuda(_) => "Cuda",
};
write!(f, "Tensor[")?;
write!(f, "{prefix}Tensor[")?;
match self.dims() {
[] => {
if let Ok(v) = self.to_scalar::<T>() {
@ -46,7 +40,7 @@ impl Tensor {
}
}
}
write!(f, "; {}{}]", self.dtype().as_str(), device_str)
write!(f, "; {}]", self.dtype().as_str())
}
}
@ -55,7 +49,6 @@ impl std::fmt::Debug for Tensor {
match self.dtype() {
DType::U8 => self.fmt_dt::<u8>(f),
DType::U32 => self.fmt_dt::<u32>(f),
DType::I64 => self.fmt_dt::<i64>(f),
DType::BF16 => self.fmt_dt::<bf16>(f),
DType::F16 => self.fmt_dt::<f16>(f),
DType::F32 => self.fmt_dt::<f32>(f),
@ -438,12 +431,6 @@ impl std::fmt::Display for Tensor {
tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
writeln!(f)?;
}
DType::I64 => {
let tf: IntFormatter<i64> = IntFormatter::new();
let max_w = tf.max_width(&to_display);
tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
writeln!(f)?;
}
DType::BF16 => {
if let Ok(tf) = FloatFormatter::<bf16>::new(&to_display, &po) {
let max_w = tf.max_width(&to_display);
@ -473,23 +460,6 @@ impl std::fmt::Display for Tensor {
}
}
};
let device_str = match self.device().location() {
crate::DeviceLocation::Cpu => "".to_owned(),
crate::DeviceLocation::Cuda { gpu_id } => {
format!(", cuda:{}", gpu_id)
}
crate::DeviceLocation::Metal { gpu_id } => {
format!(", metal:{}", gpu_id)
}
};
write!(
f,
"Tensor[{:?}, {}{}]",
self.dims(),
self.dtype().as_str(),
device_str
)
write!(f, "Tensor[{:?}, {}]", self.dims(), self.dtype().as_str())
}
}

View File

@ -1,24 +1,13 @@
//! Types for elements that can be stored and manipulated using tensors.
#![allow(clippy::redundant_closure_call)]
use crate::backend::BackendStorage;
use crate::{CpuStorage, Error, Result};
/// The different types of elements allowed in tensors.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DType {
// Unsigned 8 bits integer.
U8,
// Unsigned 32 bits integer.
U32,
// Signed 64 bits integer.
I64,
// Brain floating-point using half precision (16 bits).
BF16,
// Floating-point using half precision (16 bits).
F16,
// Floating-point using single precision (32 bits).
F32,
// Floating-point using double precision (64 bits).
F64,
}
@ -31,7 +20,6 @@ impl std::str::FromStr for DType {
match s {
"u8" => Ok(Self::U8),
"u32" => Ok(Self::U32),
"i64" => Ok(Self::I64),
"bf16" => Ok(Self::BF16),
"f16" => Ok(Self::F16),
"f32" => Ok(Self::F32),
@ -42,12 +30,10 @@ impl std::str::FromStr for DType {
}
impl DType {
/// String representation for dtypes.
pub fn as_str(&self) -> &'static str {
match self {
Self::U8 => "u8",
Self::U32 => "u32",
Self::I64 => "i64",
Self::BF16 => "bf16",
Self::F16 => "f16",
Self::F32 => "f32",
@ -55,32 +41,16 @@ impl DType {
}
}
/// The size used by each element in bytes, i.e. 1 for `U8`, 4 for `F32`.
pub fn size_in_bytes(&self) -> usize {
match self {
Self::U8 => 1,
Self::U32 => 4,
Self::I64 => 8,
Self::BF16 => 2,
Self::F16 => 2,
Self::F32 => 4,
Self::F64 => 8,
}
}
pub fn is_int(&self) -> bool {
match self {
Self::U8 | Self::U32 | Self::I64 => true,
Self::BF16 | Self::F16 | Self::F32 | Self::F64 => false,
}
}
pub fn is_float(&self) -> bool {
match self {
Self::U8 | Self::U32 | Self::I64 => false,
Self::BF16 | Self::F16 | Self::F32 | Self::F64 => true,
}
}
}
pub trait WithDType:
@ -155,7 +125,6 @@ use half::{bf16, f16};
with_dtype!(u8, U8, |v: f64| v as u8, |v: u8| v as f64);
with_dtype!(u32, U32, |v: f64| v as u32, |v: u32| v as f64);
with_dtype!(i64, I64, |v: f64| v as i64, |v: i64| v as f64);
with_dtype!(f16, F16, f16::from_f64, f16::to_f64);
with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64);
with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64);
@ -166,15 +135,6 @@ pub trait IntDType: WithDType {
fn as_usize(&self) -> usize;
}
impl IntDType for i64 {
fn is_true(&self) -> bool {
*self != 0
}
fn as_usize(&self) -> usize {
*self as usize
}
}
impl IntDType for u32 {
fn is_true(&self) -> bool {
*self != 0

View File

@ -37,10 +37,6 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
@ -79,16 +75,6 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
fn conv_transpose1d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn conv2d(
&self,
_: &Layout,
@ -99,16 +85,6 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
fn conv_transpose2d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose2D,
) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
@ -162,10 +138,6 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
@ -177,10 +149,6 @@ impl crate::backend::BackendDevice for CudaDevice {
Err(Error::NotCompiledWithCudaSupport)
}
fn set_seed(&self, _: u64) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
fn location(&self) -> crate::DeviceLocation {
fail!()
}

View File

@ -1,223 +0,0 @@
#![allow(dead_code)]
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Error, Layout, Result, Shape};
#[derive(Debug, Clone)]
pub struct MetalDevice;
#[derive(Debug)]
pub struct MetalStorage;
#[derive(thiserror::Error, Debug)]
pub enum MetalError {
#[error("{0}")]
Message(String),
}
impl From<String> for MetalError {
fn from(e: String) -> Self {
MetalError::Message(e)
}
}
macro_rules! fail {
() => {
unimplemented!("metal support has not been enabled, add `metal` feature to enable.")
};
}
impl crate::backend::BackendStorage for MetalStorage {
type Device = MetalDevice;
fn try_clone(&self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn dtype(&self) -> DType {
fail!()
}
fn device(&self) -> &Self::Device {
fail!()
}
fn to_cpu_storage(&self) -> Result<CpuStorage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn conv1d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConv1D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn conv_transpose1d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn conv2d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConv2D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn conv_transpose2d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose2D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn scatter_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn index_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn matmul(
&self,
_: &Self,
_: (usize, usize, usize, usize),
_: &Layout,
_: &Layout,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
}
impl crate::backend::BackendDevice for MetalDevice {
type Storage = MetalStorage;
fn new(_: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn set_seed(&self, _: u64) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}
fn location(&self) -> crate::DeviceLocation {
fail!()
}
fn same_device(&self, _: &Self) -> bool {
fail!()
}
fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
}

View File

@ -1,4 +1,4 @@
use crate::{DType, DeviceLocation, Layout, MetalError, Shape};
use crate::{DType, DeviceLocation, Layout, Shape};
#[derive(Debug, Clone)]
pub struct MatMulUnexpectedStriding {
@ -30,7 +30,7 @@ pub enum Error {
UnsupportedDTypeForOp(DType, &'static str),
// === Dimension Index Errors ===
#[error("{op}: dimension index {dim} out of range for shape {shape:?}")]
#[error("{op}: dimension index {dim} out of range for {shape:?}")]
DimOutOfRange {
shape: Shape,
dim: i32,
@ -142,9 +142,6 @@ pub enum Error {
#[error("{op} expects at least one tensor")]
OpRequiresAtLeastOneTensor { op: &'static str },
#[error("{op} expects at least two tensors")]
OpRequiresAtLeastTwoTensors { op: &'static str },
#[error("backward is not supported for {op}")]
BackwardNotSupported { op: &'static str },
@ -152,9 +149,6 @@ pub enum Error {
#[error("the candle crate has not been built with cuda support")]
NotCompiledWithCudaSupport,
#[error("the candle crate has not been built with metal support")]
NotCompiledWithMetalSupport,
#[error("cannot find tensor {path}")]
CannotFindTensor { path: String },
@ -162,9 +156,6 @@ pub enum Error {
#[error(transparent)]
Cuda(Box<dyn std::error::Error + Send + Sync>),
#[error("Metal error {0}")]
Metal(#[from] MetalError),
#[error(transparent)]
TryFromIntError(#[from] core::num::TryFromIntError),
@ -216,11 +207,7 @@ pub type Result<T> = std::result::Result<T, Error>;
impl Error {
pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self {
Self::Wrapped(Box::new(err)).bt()
}
pub fn msg(err: impl std::error::Error + Send + Sync + 'static) -> Self {
Self::Msg(err.to_string()).bt()
Self::Wrapped(Box::new(err))
}
pub fn bt(self) -> Self {

View File

@ -46,31 +46,19 @@ impl Tensor {
current_dim += 1;
out
}
TensorIndexer::IndexSelect(indexes) => {
if indexes.rank() != 1 {
crate::bail!("multi-dimensional tensor indexing is not supported")
}
let out = x.index_select(&indexes.to_device(x.device())?, current_dim)?;
current_dim += 1;
out
}
TensorIndexer::Err(e) => crate::bail!("indexing error {e:?}"),
};
}
Ok(x)
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
/// Generic structure used to index a slice of the tensor
pub enum TensorIndexer {
/// This selects the elements for which an index has some specific value.
/// This selects the elemnts for which an index has some specific value.
Select(usize),
/// This is a regular slice, purely indexing a chunk of the tensor
Narrow(Bound<usize>, Bound<usize>),
/// Indexing via a 1d tensor
IndexSelect(Tensor),
Err(Error),
}
impl From<usize> for TensorIndexer {
@ -79,55 +67,36 @@ impl From<usize> for TensorIndexer {
}
}
impl From<&[u32]> for TensorIndexer {
fn from(index: &[u32]) -> Self {
match Tensor::new(index, &crate::Device::Cpu) {
Ok(tensor) => TensorIndexer::IndexSelect(tensor),
Err(e) => TensorIndexer::Err(e),
macro_rules! impl_from_range {
($range_type:ty) => {
impl From<$range_type> for TensorIndexer {
fn from(range: $range_type) -> Self {
use std::ops::Bound::*;
let start = match range.start_bound() {
Included(idx) => Included(*idx),
Excluded(idx) => Excluded(*idx),
Unbounded => Unbounded,
};
let end = match range.end_bound() {
Included(idx) => Included(*idx),
Excluded(idx) => Excluded(*idx),
Unbounded => Unbounded,
};
TensorIndexer::Narrow(start, end)
}
}
}
};
}
impl From<Vec<u32>> for TensorIndexer {
fn from(index: Vec<u32>) -> Self {
let len = index.len();
match Tensor::from_vec(index, len, &crate::Device::Cpu) {
Ok(tensor) => TensorIndexer::IndexSelect(tensor),
Err(e) => TensorIndexer::Err(e),
}
}
}
impl From<&Tensor> for TensorIndexer {
fn from(tensor: &Tensor) -> Self {
TensorIndexer::IndexSelect(tensor.clone())
}
}
trait RB: RangeBounds<usize> {}
impl RB for Range<usize> {}
impl RB for RangeFrom<usize> {}
impl RB for RangeFull {}
impl RB for RangeInclusive<usize> {}
impl RB for RangeTo<usize> {}
impl RB for RangeToInclusive<usize> {}
impl<T: RB> From<T> for TensorIndexer {
fn from(range: T) -> Self {
use std::ops::Bound::*;
let start = match range.start_bound() {
Included(idx) => Included(*idx),
Excluded(idx) => Excluded(*idx),
Unbounded => Unbounded,
};
let end = match range.end_bound() {
Included(idx) => Included(*idx),
Excluded(idx) => Excluded(*idx),
Unbounded => Unbounded,
};
TensorIndexer::Narrow(start, end)
}
}
impl_from_range!(Range<usize>);
impl_from_range!(RangeFrom<usize>);
impl_from_range!(RangeFull);
impl_from_range!(RangeInclusive<usize>);
impl_from_range!(RangeTo<usize>);
impl_from_range!(RangeToInclusive<usize>);
/// Trait used to implement multiple signatures for ease of use of the slicing
/// of a tensor

View File

@ -9,14 +9,6 @@ pub struct Layout {
}
impl Layout {
pub fn new(shape: Shape, stride: Vec<usize>, start_offset: usize) -> Self {
Self {
shape,
stride,
start_offset,
}
}
pub fn contiguous_with_offset<S: Into<Shape>>(shape: S, start_offset: usize) -> Self {
let shape = shape.into();
let stride = shape.stride_contiguous();
@ -120,31 +112,6 @@ impl Layout {
})
}
pub(crate) fn permute(&self, idxs: &[usize]) -> Result<Self> {
let is_permutation =
idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i));
if !is_permutation {
crate::bail!(
"dimension mismatch in permute, tensor {:?}, dims: {:?}",
self.dims(),
idxs
)
}
let stride = self.stride();
let dims = self.shape().dims();
let mut perm_stride = stride.to_vec();
let mut perm_dims = dims.to_vec();
for (i, &idx) in idxs.iter().enumerate() {
perm_stride[i] = stride[idx];
perm_dims[i] = dims[idx];
}
Ok(Self {
shape: Shape::from(perm_dims),
stride: perm_stride,
start_offset: self.start_offset,
})
}
pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
let shape = shape.into();
if shape.rank() < self.shape().rank() {

View File

@ -49,30 +49,24 @@ mod device;
pub mod display;
mod dtype;
mod dummy_cuda_backend;
mod dummy_metal_backend;
pub mod error;
mod indexer;
pub mod layout;
#[cfg(feature = "metal")]
pub mod metal_backend;
#[cfg(feature = "mkl")]
mod mkl;
pub mod npy;
mod op;
pub mod pickle;
pub mod quantized;
pub mod safetensors;
pub mod scalar;
pub mod shape;
mod storage;
mod strided_index;
mod tensor;
pub mod test_utils;
pub mod utils;
mod variable;
pub use cpu_backend::CpuStorage;
pub use device::{Device, DeviceLocation, NdArray};
pub use device::{Device, DeviceLocation};
pub use dtype::{DType, FloatDType, IntDType, WithDType};
pub use error::{Error, Result};
pub use indexer::IndexOp;
@ -90,53 +84,5 @@ pub use cuda_backend::{CudaDevice, CudaStorage};
#[cfg(not(feature = "cuda"))]
pub use dummy_cuda_backend::{CudaDevice, CudaStorage};
#[cfg(feature = "metal")]
pub use metal_backend::{MetalDevice, MetalError, MetalStorage};
#[cfg(not(feature = "metal"))]
pub use dummy_metal_backend::{MetalDevice, MetalError, MetalStorage};
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
pub trait ToUsize2 {
fn to_usize2(self) -> (usize, usize);
}
impl ToUsize2 for usize {
fn to_usize2(self) -> (usize, usize) {
(self, self)
}
}
impl ToUsize2 for (usize, usize) {
fn to_usize2(self) -> (usize, usize) {
self
}
}
// A simple trait defining a module with forward method using a single argument.
pub trait Module {
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
}
impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
self(xs)
}
}
// A trait defining a module with forward method using a single tensor argument and a flag to
// separate the training and evaluation behaviors.
pub trait ModuleT {
fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor>;
}
impl<M: Module> ModuleT for M {
fn forward_t(&self, xs: &Tensor, _train: bool) -> Result<Tensor> {
self.forward(xs)
}
}

File diff suppressed because it is too large Load Diff

View File

@ -25,10 +25,6 @@ mod ffi {
pub fn vdMul(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
pub fn vsDiv(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
pub fn vdDiv(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
pub fn vsFmax(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
pub fn vdFmax(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
pub fn vsFmin(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
pub fn vdFmin(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
pub fn sgemm_(
transa: *const c_char,
@ -301,7 +297,7 @@ pub fn vd_sqr(a: &[f64], y: &mut [f64]) {
}
#[inline]
pub fn vs_tanh(a: &[f32], y: &mut [f32]) {
fn vs_tanh(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
@ -311,7 +307,7 @@ pub fn vs_tanh(a: &[f32], y: &mut [f32]) {
}
#[inline]
pub fn vd_tanh(a: &[f64], y: &mut [f64]) {
fn vd_tanh(a: &[f64], y: &mut [f64]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
@ -380,7 +376,3 @@ binary_op!(vs_mul, f32, vsMul);
binary_op!(vd_mul, f64, vdMul);
binary_op!(vs_div, f32, vsDiv);
binary_op!(vd_div, f64, vdDiv);
binary_op!(vs_max, f32, vsFmax);
binary_op!(vd_max, f64, vdFmax);
binary_op!(vs_min, f32, vsFmin);
binary_op!(vd_min, f64, vdFmin);

View File

@ -85,7 +85,6 @@ impl Header {
DType::F16 => "f2",
DType::F32 => "f4",
DType::F64 => "f8",
DType::I64 => "i8",
DType::U32 => "u4",
DType::U8 => "u1",
};
@ -161,7 +160,7 @@ impl Header {
"f" | "f4" => DType::F32,
"d" | "f8" => DType::F64,
// "i" | "i4" => DType::S32,
"q" | "i8" => DType::I64,
// "q" | "i8" => DType::S64,
// "h" | "i2" => DType::S16,
// "b" | "i1" => DType::S8,
"B" | "u1" => DType::U8,
@ -197,11 +196,7 @@ impl Header {
impl Tensor {
// TODO: Add the possibility to read directly to a device?
pub(crate) fn from_reader<R: std::io::Read>(
shape: Shape,
dtype: DType,
reader: &mut R,
) -> Result<Self> {
fn from_reader<R: std::io::Read>(shape: Shape, dtype: DType, reader: &mut R) -> Result<Self> {
let elem_count = shape.elem_count();
match dtype {
DType::BF16 => {
@ -234,11 +229,6 @@ impl Tensor {
reader.read_u32_into::<LittleEndian>(&mut data_t)?;
Tensor::from_vec(data_t, shape, &Device::Cpu)
}
DType::I64 => {
let mut data_t = vec![0i64; elem_count];
reader.read_i64_into::<LittleEndian>(&mut data_t)?;
Tensor::from_vec(data_t, shape, &Device::Cpu)
}
}
}
@ -250,6 +240,8 @@ impl Tensor {
if header.fortran_order {
return Err(Error::Npy("fortran order not supported".to_string()));
}
let mut data: Vec<u8> = vec![];
reader.read_to_end(&mut data)?;
Self::from_reader(header.shape(), header.descr, &mut reader)
}
@ -369,25 +361,6 @@ impl NpzTensors {
})
}
pub fn names(&self) -> Vec<&String> {
self.index_per_name.keys().collect()
}
/// This only returns the shape and dtype for a named tensor. Compared to `get`, this avoids
/// reading the whole tensor data.
pub fn get_shape_and_dtype(&self, name: &str) -> Result<(Shape, DType)> {
let index = match self.index_per_name.get(name) {
None => crate::bail!("cannot find tensor {name}"),
Some(index) => *index,
};
let zip_reader = BufReader::new(File::open(&self.path)?);
let mut zip = zip::ZipArchive::new(zip_reader)?;
let mut reader = zip.by_index(index)?;
let header = read_header(&mut reader)?;
let header = Header::parse(&header)?;
Ok((header.shape(), header.descr))
}
pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
let index = match self.index_per_name.get(name) {
None => return Ok(None),

View File

@ -1,5 +1,4 @@
#![allow(clippy::redundant_closure_call)]
use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
use crate::{CpuStorage, CudaStorage, Layout, Result, Shape, Tensor};
use half::{bf16, f16};
use num_traits::float::Float;
@ -41,8 +40,6 @@ pub enum BinaryOp {
Mul,
Sub,
Div,
Maximum,
Minimum,
}
// Unary ops with no argument
@ -58,13 +55,7 @@ pub enum UnaryOp {
Sqr,
Sqrt,
Gelu,
GeluErf,
Erf,
Relu,
Tanh,
Floor,
Ceil,
Round,
}
#[derive(Clone)]
@ -87,17 +78,6 @@ pub enum Op {
kernel: Tensor,
padding: usize,
stride: usize,
dilation: usize,
},
#[allow(dead_code)]
ConvTranspose1D {
arg: Tensor,
kernel: Tensor,
padding: usize,
output_padding: usize,
stride: usize,
dilation: usize,
},
#[allow(dead_code)]
@ -106,17 +86,6 @@ pub enum Op {
kernel: Tensor,
padding: usize,
stride: usize,
dilation: usize,
},
#[allow(dead_code)]
ConvTranspose2D {
arg: Tensor,
kernel: Tensor,
padding: usize,
output_padding: usize,
stride: usize,
dilation: usize,
},
AvgPool2D {
@ -131,12 +100,7 @@ pub enum Op {
stride: (usize, usize),
},
UpsampleNearest1D(Tensor),
UpsampleNearest2D {
arg: Tensor,
target_h: usize,
target_w: usize,
},
UpsampleNearest2D(Tensor),
Cat(Vec<Tensor>, usize),
@ -150,29 +114,17 @@ pub enum Op {
Copy(Tensor),
Broadcast(Tensor),
Narrow(Tensor, usize, usize, usize),
SliceScatter0(Tensor, Tensor, usize),
Reshape(Tensor),
ToDevice(Tensor),
Transpose(Tensor, usize, usize),
Permute(Tensor, Vec<usize>),
Elu(Tensor, f64),
Powf(Tensor, f64),
CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1 + Send + Sync>>),
CustomOp2(
Tensor,
Tensor,
std::sync::Arc<Box<dyn CustomOp2 + Send + Sync>>,
),
CustomOp3(
Tensor,
Tensor,
Tensor,
std::sync::Arc<Box<dyn CustomOp3 + Send + Sync>>,
),
CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1>>),
CustomOp2(Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp2>>),
CustomOp3(Tensor, Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp3>>),
}
/// Unary ops that can be defined in user-land.
pub trait CustomOp1 {
pub trait CustomOp1: Send + Sync {
// Box<dyn> does not support const yet, so use a function to get the name.
fn name(&self) -> &'static str;
@ -188,18 +140,6 @@ pub trait CustomOp1 {
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_storage: &MetalStorage,
_layout: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
/// This function takes as argument the argument `arg` used in the forward pass, the result
/// produced by the forward operation `res` and the gradient of the result `grad_res`.
/// The function should return the gradient of the argument.
@ -208,7 +148,7 @@ pub trait CustomOp1 {
}
}
pub trait CustomOp2 {
pub trait CustomOp2: Send + Sync {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
@ -235,20 +175,6 @@ pub trait CustomOp2 {
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
fn bwd(
&self,
_arg1: &Tensor,
@ -260,7 +186,7 @@ pub trait CustomOp2 {
}
}
pub trait CustomOp3 {
pub trait CustomOp3: Send + Sync {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
@ -291,22 +217,6 @@ pub trait CustomOp3 {
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
fn bwd(
&self,
_arg1: &Tensor,
@ -329,7 +239,6 @@ pub trait UnaryOpT {
fn f64(v1: f64) -> f64;
fn u8(v1: u8) -> u8;
fn u32(v1: u32) -> u32;
fn i64(v1: i64) -> i64;
// There is no very good way to represent optional function in traits so we go for an explicit
// boolean flag to mark the function as existing.
@ -353,7 +262,6 @@ pub trait BinaryOpT {
fn f64(v1: f64, v2: f64) -> f64;
fn u8(v1: u8, v2: u8) -> u8;
fn u32(v1: u32, v2: u32) -> u32;
fn i64(v1: i64, v2: i64) -> i64;
const BF16_VEC: bool = false;
fn bf16_vec(_xs1: &[bf16], _xs2: &[bf16], _ys: &mut [bf16]) {}
@ -367,16 +275,12 @@ pub trait BinaryOpT {
fn u8_vec(_xs1: &[u8], _xs2: &[u8], _ys: &mut [u8]) {}
const U32_VEC: bool = false;
fn u32_vec(_xs1: &[u32], _xs2: &[u32], _ys: &mut [u32]) {}
const I64_VEC: bool = false;
fn i64_vec(_xs1: &[i64], _xs2: &[i64], _ys: &mut [i64]) {}
}
pub(crate) struct Add;
pub(crate) struct Div;
pub(crate) struct Mul;
pub(crate) struct Sub;
pub(crate) struct Maximum;
pub(crate) struct Minimum;
pub(crate) struct Exp;
pub(crate) struct Log;
pub(crate) struct Sin;
@ -387,13 +291,7 @@ pub(crate) struct Recip;
pub(crate) struct Sqr;
pub(crate) struct Sqrt;
pub(crate) struct Gelu;
pub(crate) struct GeluErf;
pub(crate) struct Erf;
pub(crate) struct Relu;
pub(crate) struct Tanh;
pub(crate) struct Floor;
pub(crate) struct Ceil;
pub(crate) struct Round;
macro_rules! bin_op {
($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
@ -425,10 +323,6 @@ macro_rules! bin_op {
fn u32(v1: u32, v2: u32) -> u32 {
$e(v1, v2)
}
#[inline(always)]
fn i64(v1: i64, v2: i64) -> i64 {
$e(v1, v2)
}
#[cfg(feature = "mkl")]
const F32_VEC: bool = true;
@ -467,22 +361,7 @@ bin_op!(Add, "add", |v1, v2| v1 + v2, vs_add, vd_add);
bin_op!(Sub, "sub", |v1, v2| v1 - v2, vs_sub, vd_sub);
bin_op!(Mul, "mul", |v1, v2| v1 * v2, vs_mul, vd_mul);
bin_op!(Div, "div", |v1, v2| v1 / v2, vs_div, vd_div);
bin_op!(
Minimum,
"minimum",
|v1, v2| if v1 > v2 { v2 } else { v1 },
vs_min,
vd_min
);
bin_op!(
Maximum,
"maximum",
|v1, v2| if v1 < v2 { v2 } else { v1 },
vs_max,
vd_max
);
#[allow(clippy::redundant_closure_call)]
macro_rules! unary_op {
($op: ident, $name: literal, $a: ident, $e: expr) => {
impl UnaryOpT for $op {
@ -513,10 +392,6 @@ macro_rules! unary_op {
fn u32(_: u32) -> u32 {
todo!("no unary function for u32")
}
#[inline(always)]
fn i64(_: i64) -> i64 {
todo!("no unary function for i64")
}
}
};
@ -549,10 +424,6 @@ macro_rules! unary_op {
fn u32(_: u32) -> u32 {
todo!("no unary function for u32")
}
#[inline(always)]
fn i64(_: i64) -> i64 {
todo!("no unary function for i64")
}
#[cfg(feature = "mkl")]
const F32_VEC: bool = true;
@ -591,14 +462,13 @@ unary_op!(Exp, "exp", v, v.exp(), vs_exp, vd_exp);
unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln);
unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh);
unary_op!(Abs, "abs", v, v.abs());
unary_op!(Neg, "neg", v, -v);
unary_op!(Recip, "recip", v, v.recip());
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
/// Tanh based approximation of the `gelu` operation
/// GeluErf is the more precise one.
/// `gelu` operation
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
impl UnaryOpT for Gelu {
const NAME: &'static str = "gelu";
@ -645,10 +515,6 @@ impl UnaryOpT for Gelu {
fn u32(_: u32) -> u32 {
0
}
#[inline(always)]
fn i64(_: i64) -> i64 {
0
}
const KERNEL: &'static str = "ugelu";
#[cfg(feature = "mkl")]
@ -668,230 +534,6 @@ impl UnaryOpT for Gelu {
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::mkl::vd_gelu(xs, ys)
}
#[cfg(feature = "accelerate")]
const F32_VEC: bool = true;
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
crate::accelerate::vs_gelu(xs, ys)
}
#[cfg(feature = "accelerate")]
const F64_VEC: bool = true;
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::accelerate::vd_gelu(xs, ys)
}
}
/// `erf` operation
/// <https://en.wikipedia.org/wiki/Error_function>
impl UnaryOpT for Erf {
const NAME: &'static str = "erf";
const KERNEL: &'static str = "uerf";
const V: Self = Erf;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
bf16::from_f64(Self::f64(v.to_f64()))
}
#[inline(always)]
fn f16(v: f16) -> f16 {
f16::from_f64(Self::f64(v.to_f64()))
}
#[inline(always)]
fn f32(v: f32) -> f32 {
Self::f64(v as f64) as f32
}
#[inline(always)]
fn f64(v: f64) -> f64 {
crate::cpu::erf::erf(v)
}
#[inline(always)]
fn u8(_: u8) -> u8 {
0
}
#[inline(always)]
fn u32(_: u32) -> u32 {
0
}
#[inline(always)]
fn i64(_: i64) -> i64 {
0
}
}
impl UnaryOpT for Abs {
const NAME: &'static str = "abs";
const KERNEL: &'static str = "uabs";
const V: Self = Abs;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
v.abs()
}
#[inline(always)]
fn f16(v: f16) -> f16 {
v.abs()
}
#[inline(always)]
fn f32(v: f32) -> f32 {
v.abs()
}
#[inline(always)]
fn f64(v: f64) -> f64 {
v.abs()
}
#[inline(always)]
fn u8(v: u8) -> u8 {
v
}
#[inline(always)]
fn u32(v: u32) -> u32 {
v
}
#[inline(always)]
fn i64(v: i64) -> i64 {
v.abs()
}
}
impl UnaryOpT for Ceil {
const NAME: &'static str = "ceil";
const KERNEL: &'static str = "uceil";
const V: Self = Ceil;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
v.ceil()
}
#[inline(always)]
fn f16(v: f16) -> f16 {
v.ceil()
}
#[inline(always)]
fn f32(v: f32) -> f32 {
v.ceil()
}
#[inline(always)]
fn f64(v: f64) -> f64 {
v.ceil()
}
#[inline(always)]
fn u8(v: u8) -> u8 {
v
}
#[inline(always)]
fn u32(v: u32) -> u32 {
v
}
#[inline(always)]
fn i64(v: i64) -> i64 {
v
}
}
impl UnaryOpT for Floor {
const NAME: &'static str = "floor";
const KERNEL: &'static str = "ufloor";
const V: Self = Floor;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
v.floor()
}
#[inline(always)]
fn f16(v: f16) -> f16 {
v.floor()
}
#[inline(always)]
fn f32(v: f32) -> f32 {
v.floor()
}
#[inline(always)]
fn f64(v: f64) -> f64 {
v.floor()
}
#[inline(always)]
fn u8(v: u8) -> u8 {
v
}
#[inline(always)]
fn u32(v: u32) -> u32 {
v
}
#[inline(always)]
fn i64(v: i64) -> i64 {
v
}
}
impl UnaryOpT for Round {
const NAME: &'static str = "round";
const KERNEL: &'static str = "uround";
const V: Self = Round;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
v.round()
}
#[inline(always)]
fn f16(v: f16) -> f16 {
v.round()
}
#[inline(always)]
fn f32(v: f32) -> f32 {
v.round()
}
#[inline(always)]
fn f64(v: f64) -> f64 {
v.round()
}
#[inline(always)]
fn u8(v: u8) -> u8 {
v
}
#[inline(always)]
fn u32(v: u32) -> u32 {
v
}
#[inline(always)]
fn i64(v: i64) -> i64 {
v
}
}
impl UnaryOpT for GeluErf {
const NAME: &'static str = "gelu_erf";
const KERNEL: &'static str = "ugelu_erf";
const V: Self = GeluErf;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
bf16::from_f64(Self::f64(v.to_f64()))
}
#[inline(always)]
fn f16(v: f16) -> f16 {
f16::from_f64(Self::f64(v.to_f64()))
}
#[inline(always)]
fn f32(v: f32) -> f32 {
Self::f64(v as f64) as f32
}
#[inline(always)]
fn f64(v: f64) -> f64 {
(crate::cpu::erf::erf(v / 2f64.sqrt()) + 1.) * 0.5 * v
}
#[inline(always)]
fn u8(_: u8) -> u8 {
0
}
#[inline(always)]
fn u32(_: u32) -> u32 {
0
}
#[inline(always)]
fn i64(_: i64) -> i64 {
0
}
}
impl UnaryOpT for Relu {
@ -922,10 +564,6 @@ impl UnaryOpT for Relu {
fn u32(v: u32) -> u32 {
v
}
#[inline(always)]
fn i64(v: i64) -> i64 {
v
}
}
/// `BackpropOp` is a wrapper around `Option<Op>`. The main goal is to ensure that dependencies are
@ -979,10 +617,6 @@ impl BackpropOp {
};
Self(op)
}
pub(crate) fn is_none(&self) -> bool {
self.0.is_none()
}
}
impl std::ops::Deref for BackpropOp {

View File

@ -1,814 +0,0 @@
// Just enough pickle support to be able to read PyTorch checkpoints.
// This hardcodes objects that are required for tensor reading, we may want to make this a bit more
// composable/tensor agnostic at some point.
use crate::{DType, Error as E, Layout, Result, Tensor};
use byteorder::{LittleEndian, ReadBytesExt};
use std::collections::HashMap;
use std::io::BufRead;
const VERBOSE: bool = false;
// https://docs.juliahub.com/Pickle/LAUNc/0.1.0/opcode/
#[repr(u8)]
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum OpCode {
// https://github.com/python/cpython/blob/ed25f097160b5cbb0c9a1f9a746d2f1bbc96515a/Lib/pickletools.py#L2123
Proto = 0x80,
Global = b'c',
BinPut = b'q',
LongBinPut = b'r',
EmptyTuple = b')',
Reduce = b'R',
Mark = b'(',
BinUnicode = b'X',
BinInt = b'J',
Tuple = b't',
BinPersId = b'Q',
BinInt1 = b'K',
BinInt2 = b'M',
Tuple1 = 0x85,
Tuple2 = 0x86,
Tuple3 = 0x87,
NewTrue = 0x88,
NewFalse = 0x89,
None = b'N',
BinGet = b'h',
LongBinGet = b'j',
SetItem = b's',
SetItems = b'u',
EmptyDict = b'}',
Dict = b'd',
Build = b'b',
Stop = b'.',
NewObj = 0x81,
EmptyList = b']',
BinFloat = b'g',
Append = b'a',
Appends = b'e',
}
// Avoid using FromPrimitive so as not to drag another dependency.
impl TryFrom<u8> for OpCode {
type Error = u8;
fn try_from(value: u8) -> std::result::Result<Self, Self::Error> {
match value {
0x80 => Ok(Self::Proto),
b'c' => Ok(Self::Global),
b'q' => Ok(Self::BinPut),
b'r' => Ok(Self::LongBinPut),
b')' => Ok(Self::EmptyTuple),
b'R' => Ok(Self::Reduce),
b'(' => Ok(Self::Mark),
b'X' => Ok(Self::BinUnicode),
b'J' => Ok(Self::BinInt),
b't' => Ok(Self::Tuple),
b'Q' => Ok(Self::BinPersId),
b'K' => Ok(Self::BinInt1),
b'M' => Ok(Self::BinInt2),
b'N' => Ok(Self::None),
0x85 => Ok(Self::Tuple1),
0x86 => Ok(Self::Tuple2),
0x87 => Ok(Self::Tuple3),
0x88 => Ok(Self::NewTrue),
0x89 => Ok(Self::NewFalse),
b'h' => Ok(Self::BinGet),
b'j' => Ok(Self::LongBinGet),
b's' => Ok(Self::SetItem),
b'u' => Ok(Self::SetItems),
b'}' => Ok(Self::EmptyDict),
b'd' => Ok(Self::EmptyDict),
b'b' => Ok(Self::Build),
b'.' => Ok(Self::Stop),
0x81 => Ok(Self::NewObj),
b']' => Ok(Self::EmptyList),
b'G' => Ok(Self::BinFloat),
b'a' => Ok(Self::Append),
b'e' => Ok(Self::Appends),
value => Err(value),
}
}
}
fn read_to_newline<R: BufRead>(r: &mut R) -> Result<Vec<u8>> {
let mut data: Vec<u8> = Vec::with_capacity(32);
r.read_until(b'\n', &mut data)?;
data.pop();
if data.last() == Some(&b'\r') {
data.pop();
}
Ok(data)
}
#[derive(Debug, Clone, PartialEq)]
pub enum Object {
Class {
module_name: String,
class_name: String,
},
Int(i32),
Float(f64),
Unicode(String),
Bool(bool),
None,
Tuple(Vec<Object>),
List(Vec<Object>),
Mark,
Dict(Vec<(Object, Object)>),
Reduce {
callable: Box<Object>,
args: Box<Object>,
},
Build {
callable: Box<Object>,
args: Box<Object>,
},
PersistentLoad(Box<Object>),
}
type OResult<T> = std::result::Result<T, Object>;
impl Object {
pub fn unicode(self) -> OResult<String> {
match self {
Self::Unicode(t) => Ok(t),
_ => Err(self),
}
}
pub fn reduce(self) -> OResult<(Self, Self)> {
match self {
Self::Reduce { callable, args } => Ok((*callable, *args)),
_ => Err(self),
}
}
pub fn none(self) -> OResult<()> {
match self {
Self::None => Ok(()),
_ => Err(self),
}
}
pub fn persistent_load(self) -> OResult<Self> {
match self {
Self::PersistentLoad(t) => Ok(*t),
_ => Err(self),
}
}
pub fn bool(self) -> OResult<bool> {
match self {
Self::Bool(t) => Ok(t),
_ => Err(self),
}
}
pub fn int(self) -> OResult<i32> {
match self {
Self::Int(t) => Ok(t),
_ => Err(self),
}
}
pub fn tuple(self) -> OResult<Vec<Self>> {
match self {
Self::Tuple(t) => Ok(t),
_ => Err(self),
}
}
pub fn dict(self) -> OResult<Vec<(Self, Self)>> {
match self {
Self::Dict(t) => Ok(t),
_ => Err(self),
}
}
pub fn class(self) -> OResult<(String, String)> {
match self {
Self::Class {
module_name,
class_name,
} => Ok((module_name, class_name)),
_ => Err(self),
}
}
pub fn into_tensor_info(
self,
name: Self,
dir_name: &std::path::Path,
) -> Result<Option<TensorInfo>> {
let name = match name.unicode() {
Ok(name) => name,
Err(_) => return Ok(None),
};
let (callable, args) = match self.reduce() {
Ok(callable_args) => callable_args,
_ => return Ok(None),
};
let (callable, args) = match callable {
Object::Class {
module_name,
class_name,
} if module_name == "torch._tensor" && class_name == "_rebuild_from_type_v2" => {
let mut args = args.tuple()?;
let callable = args.remove(0);
let args = args.remove(1);
(callable, args)
}
Object::Class {
module_name,
class_name,
} if module_name == "torch._utils" && class_name == "_rebuild_parameter" => {
let mut args = args.tuple()?;
args.remove(0).reduce()?
}
_ => (callable, args),
};
match callable {
Object::Class {
module_name,
class_name,
} if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {}
_ => return Ok(None),
};
let (layout, dtype, file_path, storage_size) = rebuild_args(args)?;
Ok(Some(TensorInfo {
name,
dtype,
layout,
path: format!("{}/{}", dir_name.to_string_lossy(), file_path),
storage_size,
}))
}
}
impl TryFrom<Object> for String {
type Error = Object;
fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
match value {
Object::Unicode(s) => Ok(s),
other => Err(other),
}
}
}
impl TryFrom<Object> for usize {
type Error = Object;
fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
match value {
Object::Int(s) if s >= 0 => Ok(s as usize),
other => Err(other),
}
}
}
impl<T: TryFrom<Object, Error = Object>> TryFrom<Object> for Vec<T> {
type Error = Object;
fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
match value {
Object::Tuple(values) => {
// This does not return the appropriate value in the error case but instead return
// the object related to the first error.
values
.into_iter()
.map(|v| T::try_from(v))
.collect::<std::result::Result<Vec<T>, Self::Error>>()
}
other => Err(other),
}
}
}
#[derive(Debug)]
pub struct Stack {
stack: Vec<Object>,
memo: HashMap<u32, Object>,
}
impl Stack {
pub fn empty() -> Self {
Self {
stack: Vec::with_capacity(512),
memo: HashMap::new(),
}
}
pub fn stack(&self) -> &[Object] {
self.stack.as_slice()
}
pub fn read_loop<R: BufRead>(&mut self, r: &mut R) -> Result<()> {
loop {
if self.read(r)? {
break;
}
}
Ok(())
}
pub fn finalize(mut self) -> Result<Object> {
self.pop()
}
fn push(&mut self, obj: Object) {
self.stack.push(obj)
}
fn pop(&mut self) -> Result<Object> {
match self.stack.pop() {
None => crate::bail!("unexpected empty stack"),
Some(obj) => Ok(obj),
}
}
// https://docs.juliahub.com/Pickle/LAUNc/0.1.0/opcode/#Pickle.OpCodes.BUILD
fn build(&mut self) -> Result<()> {
let args = self.pop()?;
let obj = self.pop()?;
let obj = match (obj, args) {
(Object::Dict(mut obj), Object::Dict(mut args)) => {
obj.append(&mut args);
Object::Dict(obj)
}
(obj, args) => Object::Build {
callable: Box::new(obj),
args: Box::new(args),
},
};
self.push(obj);
Ok(())
}
fn reduce(&mut self) -> Result<()> {
let args = self.pop()?;
let callable = self.pop()?;
#[allow(clippy::single_match)]
let reduced = match &callable {
Object::Class {
module_name,
class_name,
} => {
if module_name == "collections"
&& (class_name == "OrderedDict" || class_name == "defaultdict")
{
// TODO: have a separate ordered dict and a separate default dict.
Some(Object::Dict(vec![]))
} else {
None
}
}
_ => None,
};
let reduced = reduced.unwrap_or_else(|| Object::Reduce {
callable: Box::new(callable),
args: Box::new(args),
});
self.push(reduced);
Ok(())
}
fn last(&mut self) -> Result<&mut Object> {
match self.stack.last_mut() {
None => crate::bail!("unexpected empty stack"),
Some(obj) => Ok(obj),
}
}
fn memo_get(&self, id: u32) -> Result<Object> {
match self.memo.get(&id) {
None => crate::bail!("missing object in memo {id}"),
Some(obj) => {
// Maybe we should use refcounting rather than doing potential large clones here.
Ok(obj.clone())
}
}
}
fn memo_put(&mut self, id: u32) -> Result<()> {
let obj = self.last()?.clone();
self.memo.insert(id, obj);
Ok(())
}
fn persistent_load(&self, id: Object) -> Result<Object> {
Ok(Object::PersistentLoad(Box::new(id)))
}
fn new_obj(&self, class: Object, args: Object) -> Result<Object> {
Ok(Object::Reduce {
callable: Box::new(class),
args: Box::new(args),
})
}
fn pop_to_marker(&mut self) -> Result<Vec<Object>> {
let mut mark_idx = None;
for (idx, obj) in self.stack.iter().enumerate().rev() {
if obj == &Object::Mark {
mark_idx = Some(idx);
break;
}
}
match mark_idx {
Some(mark_idx) => {
let objs = self.stack.split_off(mark_idx + 1);
self.stack.pop();
Ok(objs)
}
None => {
crate::bail!("marker object not found")
}
}
}
pub fn read<R: BufRead>(&mut self, r: &mut R) -> Result<bool> {
let op_code = match OpCode::try_from(r.read_u8()?) {
Ok(op_code) => op_code,
Err(op_code) => {
crate::bail!("unknown op-code {op_code}")
}
};
// println!("op: {op_code:?}");
// println!("{:?}", self.stack);
match op_code {
OpCode::Proto => {
let version = r.read_u8()?;
if VERBOSE {
println!("proto {version}");
}
}
OpCode::Global => {
let module_name = read_to_newline(r)?;
let class_name = read_to_newline(r)?;
let module_name = String::from_utf8_lossy(&module_name).to_string();
let class_name = String::from_utf8_lossy(&class_name).to_string();
self.push(Object::Class {
module_name,
class_name,
})
}
OpCode::BinInt1 => {
let arg = r.read_u8()?;
self.push(Object::Int(arg as i32))
}
OpCode::BinInt2 => {
let arg = r.read_u16::<LittleEndian>()?;
self.push(Object::Int(arg as i32))
}
OpCode::BinInt => {
let arg = r.read_i32::<LittleEndian>()?;
self.push(Object::Int(arg))
}
OpCode::BinFloat => {
let arg = r.read_f64::<LittleEndian>()?;
self.push(Object::Float(arg))
}
OpCode::BinUnicode => {
let len = r.read_u32::<LittleEndian>()?;
let mut data = vec![0u8; len as usize];
r.read_exact(&mut data)?;
let data = String::from_utf8(data).map_err(E::wrap)?;
self.push(Object::Unicode(data))
}
OpCode::BinPersId => {
let id = self.pop()?;
let obj = self.persistent_load(id)?;
self.push(obj)
}
OpCode::Tuple => {
let objs = self.pop_to_marker()?;
self.push(Object::Tuple(objs))
}
OpCode::Tuple1 => {
let obj = self.pop()?;
self.push(Object::Tuple(vec![obj]))
}
OpCode::Tuple2 => {
let obj2 = self.pop()?;
let obj1 = self.pop()?;
self.push(Object::Tuple(vec![obj1, obj2]))
}
OpCode::Tuple3 => {
let obj3 = self.pop()?;
let obj2 = self.pop()?;
let obj1 = self.pop()?;
self.push(Object::Tuple(vec![obj1, obj2, obj3]))
}
OpCode::NewTrue => self.push(Object::Bool(true)),
OpCode::NewFalse => self.push(Object::Bool(false)),
OpCode::Append => {
let value = self.pop()?;
let pylist = self.last()?;
if let Object::List(d) = pylist {
d.push(value)
} else {
crate::bail!("expected a list, got {pylist:?}")
}
}
OpCode::Appends => {
let objs = self.pop_to_marker()?;
let pylist = self.last()?;
if let Object::List(d) = pylist {
d.extend(objs)
} else {
crate::bail!("expected a list, got {pylist:?}")
}
}
OpCode::SetItem => {
let value = self.pop()?;
let key = self.pop()?;
let pydict = self.last()?;
if let Object::Dict(d) = pydict {
d.push((key, value))
} else {
crate::bail!("expected a dict, got {pydict:?}")
}
}
OpCode::SetItems => {
let mut objs = self.pop_to_marker()?;
let pydict = self.last()?;
if let Object::Dict(d) = pydict {
if objs.len() % 2 != 0 {
crate::bail!("setitems: not an even number of objects")
}
while let Some(value) = objs.pop() {
let key = objs.pop().unwrap();
d.push((key, value))
}
} else {
crate::bail!("expected a dict, got {pydict:?}")
}
}
OpCode::None => self.push(Object::None),
OpCode::Stop => {
return Ok(true);
}
OpCode::Build => self.build()?,
OpCode::EmptyDict => self.push(Object::Dict(vec![])),
OpCode::Dict => {
let mut objs = self.pop_to_marker()?;
let mut pydict = vec![];
if objs.len() % 2 != 0 {
crate::bail!("setitems: not an even number of objects")
}
while let Some(value) = objs.pop() {
let key = objs.pop().unwrap();
pydict.push((key, value))
}
self.push(Object::Dict(pydict))
}
OpCode::Mark => self.push(Object::Mark),
OpCode::Reduce => self.reduce()?,
OpCode::EmptyTuple => self.push(Object::Tuple(vec![])),
OpCode::EmptyList => self.push(Object::List(vec![])),
OpCode::BinGet => {
let arg = r.read_u8()?;
let obj = self.memo_get(arg as u32)?;
self.push(obj)
}
OpCode::LongBinGet => {
let arg = r.read_u32::<LittleEndian>()?;
let obj = self.memo_get(arg)?;
self.push(obj)
}
OpCode::BinPut => {
let arg = r.read_u8()?;
self.memo_put(arg as u32)?
}
OpCode::LongBinPut => {
let arg = r.read_u32::<LittleEndian>()?;
self.memo_put(arg)?
}
OpCode::NewObj => {
let args = self.pop()?;
let class = self.pop()?;
let obj = self.new_obj(class, args)?;
self.push(obj)
}
}
Ok(false)
}
}
impl From<Object> for E {
fn from(value: Object) -> Self {
E::Msg(format!("conversion error on {value:?}"))
}
}
// https://github.com/pytorch/pytorch/blob/4eac43d046ded0f0a5a5fa8db03eb40f45bf656e/torch/_utils.py#L198
// Arguments: storage, storage_offset, size, stride, requires_grad, backward_hooks
fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
let mut args = args.tuple()?;
let stride = Vec::<usize>::try_from(args.remove(3))?;
let size = Vec::<usize>::try_from(args.remove(2))?;
let offset = args.remove(1).int()? as usize;
let storage = args.remove(0).persistent_load()?;
let mut storage = storage.tuple()?;
let storage_size = storage.remove(4).int()? as usize;
let path = storage.remove(2).unicode()?;
let (_module_name, class_name) = storage.remove(1).class()?;
let dtype = match class_name.as_str() {
"FloatStorage" => DType::F32,
"DoubleStorage" => DType::F64,
"HalfStorage" => DType::F16,
"BFloat16Storage" => DType::BF16,
"ByteStorage" => DType::U8,
"LongStorage" => DType::I64,
other => {
crate::bail!("unsupported storage type {other}")
}
};
let layout = Layout::new(crate::Shape::from(size), stride, offset);
Ok((layout, dtype, path, storage_size))
}
#[derive(Debug, Clone)]
pub struct TensorInfo {
pub name: String,
pub dtype: DType,
pub layout: Layout,
pub path: String,
pub storage_size: usize,
}
/// Read the tensor info from a .pth file.
///
/// # Arguments
/// * `file` - The path to the .pth file.
/// * `verbose` - Whether to print debug information.
/// * `key` - Optional key to retrieve `state_dict` from the pth file.
pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
file: P,
verbose: bool,
key: Option<&str>,
) -> Result<Vec<TensorInfo>> {
let file = std::fs::File::open(file)?;
let zip_reader = std::io::BufReader::new(file);
let mut zip = zip::ZipArchive::new(zip_reader)?;
let zip_file_names = zip
.file_names()
.map(|f| f.to_string())
.collect::<Vec<String>>();
let mut tensor_infos = vec![];
for file_name in zip_file_names.iter() {
if !file_name.ends_with("data.pkl") {
continue;
}
let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").unwrap());
let reader = zip.by_name(file_name)?;
let mut reader = std::io::BufReader::new(reader);
let mut stack = Stack::empty();
stack.read_loop(&mut reader)?;
let obj = stack.finalize()?;
if VERBOSE || verbose {
println!("{obj:#?}");
}
let obj = match obj {
Object::Build { callable, args } => match *callable {
Object::Reduce { callable, args: _ } => match *callable {
Object::Class {
module_name,
class_name,
} if module_name == "__torch__" && class_name == "Module" => *args,
_ => continue,
},
_ => continue,
},
obj => obj,
};
// If key is provided, then we need to extract the state_dict from the object.
let obj = if let Some(key) = key {
if let Object::Dict(key_values) = obj {
key_values
.into_iter()
.find(|(k, _)| *k == Object::Unicode(key.to_owned()))
.map(|(_, v)| v)
.ok_or_else(|| E::Msg(format!("key {key} not found")))?
} else {
obj
}
} else {
obj
};
// If the object is a dict, then we can extract the tensor info from it.
// NOTE: We are assuming that the `obj` is state_dict by this stage.
if let Object::Dict(key_values) = obj {
for (name, value) in key_values.into_iter() {
match value.into_tensor_info(name, &dir_name) {
Ok(Some(tensor_info)) => tensor_infos.push(tensor_info),
Ok(None) => {}
Err(err) => eprintln!("skipping: {err:?}"),
}
}
}
}
Ok(tensor_infos)
}
/// Lazy tensor loader.
pub struct PthTensors {
tensor_infos: HashMap<String, TensorInfo>,
path: std::path::PathBuf,
// We do not store a zip reader as it needs mutable access to extract data. Instead we
// re-create a zip reader for each tensor.
}
impl PthTensors {
pub fn new<P: AsRef<std::path::Path>>(path: P, key: Option<&str>) -> Result<Self> {
let tensor_infos = read_pth_tensor_info(path.as_ref(), false, key)?;
let tensor_infos = tensor_infos
.into_iter()
.map(|ti| (ti.name.to_string(), ti))
.collect();
let path = path.as_ref().to_owned();
Ok(Self { tensor_infos, path })
}
pub fn tensor_infos(&self) -> &HashMap<String, TensorInfo> {
&self.tensor_infos
}
pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
use std::io::Read;
let tensor_info = match self.tensor_infos.get(name) {
None => return Ok(None),
Some(tensor_info) => tensor_info,
};
// We hope that the file has not changed since first reading it.
let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?);
let mut zip = zip::ZipArchive::new(zip_reader)?;
let mut reader = zip.by_name(&tensor_info.path)?;
let is_fortran_contiguous = tensor_info.layout.is_fortran_contiguous();
let rank = tensor_info.layout.shape().rank();
// Reading the data is a bit tricky as it can be strided, for now only support the basic
// case and when the tensor is fortran contiguous.
if !tensor_info.layout.is_contiguous() && !is_fortran_contiguous {
crate::bail!(
"cannot retrieve non-contiguous tensors {:?}",
tensor_info.layout
)
}
let start_offset = tensor_info.layout.start_offset();
if start_offset > 0 {
std::io::copy(
&mut reader.by_ref().take(start_offset as u64),
&mut std::io::sink(),
)?;
}
let tensor = Tensor::from_reader(
tensor_info.layout.shape().clone(),
tensor_info.dtype,
&mut reader,
)?;
if rank > 1 && is_fortran_contiguous {
// Reverse the shape, e.g. Shape(2, 3, 4) -> Shape(4, 3, 2)
let shape_reversed: Vec<_> = tensor_info.layout.dims().iter().rev().cloned().collect();
let tensor = tensor.reshape(shape_reversed)?;
// Permute (transpose) the dimensions, e.g. Shape(4, 3, 2) -> Shape(2, 3, 4)
let dim_indeces_reversed: Vec<_> = (0..rank).rev().collect();
let tensor = tensor.permute(dim_indeces_reversed)?;
Ok(Some(tensor))
} else {
Ok(Some(tensor))
}
}
}
/// Read all the tensors from a PyTorch pth file with a given key.
///
/// # Arguments
/// * `path` - Path to the pth file.
/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
/// contains multiple objects and the state_dict is the one we are interested in.
pub fn read_all_with_key<P: AsRef<std::path::Path>>(
path: P,
key: Option<&str>,
) -> Result<Vec<(String, Tensor)>> {
let pth = PthTensors::new(path, key)?;
let tensor_names = pth.tensor_infos.keys();
let mut tensors = Vec::with_capacity(tensor_names.len());
for name in tensor_names {
if let Some(tensor) = pth.get(name)? {
tensors.push((name.to_string(), tensor))
}
}
Ok(tensors)
}
/// Read all the tensors from a PyTorch pth file.
///
/// # Arguments
/// * `path` - Path to the pth file.
pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
read_all_with_key(path, None)
}

View File

@ -1,667 +0,0 @@
use super::k_quants::{
BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K,
};
use crate::Result;
use byteorder::{ByteOrder, LittleEndian};
use half::f16;
#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
#[inline(always)]
pub(crate) unsafe fn sum_i16_pairs_float(x: __m256i) -> __m256 {
let ones = _mm256_set1_epi16(1);
let summed_pairs = _mm256_madd_epi16(ones, x);
_mm256_cvtepi32_ps(summed_pairs)
}
#[inline(always)]
pub(crate) unsafe fn mul_sum_us8_pairs_float(ax: __m256i, sy: __m256i) -> __m256 {
let dot = _mm256_maddubs_epi16(ax, sy);
sum_i16_pairs_float(dot)
}
#[inline(always)]
pub(crate) unsafe fn hsum_float_8(x: __m256) -> f32 {
let res = _mm256_extractf128_ps(x, 1);
let res = _mm_add_ps(res, _mm256_castps256_ps128(x));
let res = _mm_add_ps(res, _mm_movehl_ps(res, res));
let res = _mm_add_ss(res, _mm_movehdup_ps(res));
_mm_cvtss_f32(res)
}
#[inline(always)]
pub(crate) unsafe fn bytes_from_nibbles_32(rsi: *const u8) -> __m256i {
let tmp = _mm_loadu_si128(rsi as *const __m128i);
let bytes = _mm256_insertf128_si256::<1>(_mm256_castsi128_si256(tmp), _mm_srli_epi16(tmp, 4));
let low_mask = _mm256_set1_epi8(0xF);
_mm256_and_si256(low_mask, bytes)
}
#[inline(always)]
pub(crate) unsafe fn mul_sum_i8_pairs_float(x: __m256i, y: __m256i) -> __m256 {
let ax = _mm256_sign_epi8(x, x);
let sy = _mm256_sign_epi8(y, x);
mul_sum_us8_pairs_float(ax, sy)
}
#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
}
unsafe {
let mut acc = _mm256_setzero_ps();
for (x, y) in xs.iter().zip(ys.iter()) {
let d = _mm256_set1_ps(f16::to_f32(x.d) * f16::to_f32(y.d));
let bx = bytes_from_nibbles_32(x.qs.as_ptr());
let off = _mm256_set1_epi8(8);
let bx = _mm256_sub_epi8(bx, off);
let by = _mm256_loadu_si256(y.qs.as_ptr() as *const __m256i);
let q = mul_sum_i8_pairs_float(bx, by);
acc = _mm256_fmadd_ps(d, q, acc);
}
Ok(hsum_float_8(acc))
}
}
#[inline(always)]
pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
}
unsafe {
let mut acc = _mm256_setzero_ps();
for (x, y) in xs.iter().zip(ys.iter()) {
let d = _mm256_set1_ps(f16::to_f32(x.d) * f16::to_f32(y.d));
let bx = _mm256_loadu_si256(x.qs.as_ptr() as *const __m256i);
let by = _mm256_loadu_si256(y.qs.as_ptr() as *const __m256i);
let q = mul_sum_i8_pairs_float(bx, by);
acc = _mm256_fmadd_ps(d, q, acc);
}
Ok(hsum_float_8(acc))
}
}
#[inline(always)]
unsafe fn get_scale_shuffle(i: usize) -> __m128i {
const K_SHUFFLE: [u8; 128] = [
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3,
3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7,
7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10,
11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13,
13, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15, 15,
];
_mm_loadu_si128((K_SHUFFLE.as_ptr() as *const __m128i).add(i))
}
#[inline(always)]
unsafe fn get_scale_shuffle_k4(i: usize) -> __m256i {
const K_SHUFFLE: [u8; 256] = [
0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
2, 3, 2, 3, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
6, 7, 6, 7, 6, 7, 6, 7, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10,
11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 12, 13, 12, 13, 12, 13,
12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12,
13, 12, 13, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
];
_mm256_loadu_si256((K_SHUFFLE.as_ptr() as *const __m256i).add(i))
}
#[inline(always)]
unsafe fn get_scale_shuffle_q3k(i: usize) -> __m256i {
const K_SHUFFLE: [u8; 128] = [
0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
2, 3, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
6, 7, 6, 7, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10, 11, 10, 11, 10, 11, 10, 11,
10, 11, 10, 11, 10, 11, 10, 11, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12,
13, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
];
_mm256_loadu_si256((K_SHUFFLE.as_ptr() as *const __m256i).add(i))
}
#[inline(always)]
pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
let qk = QK_K;
if n % qk != 0 {
crate::bail!("vec_dot_q6k_8k: {n} is not divisible by {qk}")
}
unsafe {
let m4 = _mm256_set1_epi8(0xF);
let m2 = _mm256_set1_epi8(3);
let m32s = _mm256_set1_epi8(32);
let mut acc = _mm256_setzero_ps();
for (x, y) in xs.iter().zip(ys.iter()) {
let d = y.d * x.d.to_f32();
let mut q4 = x.ql.as_ptr();
let mut qh = x.qh.as_ptr();
let mut q8 = y.qs.as_ptr();
let scales = _mm_loadu_si128(x.scales.as_ptr() as *const __m128i);
let mut sumi = _mm256_setzero_si256();
for j in 0..QK_K / 128 {
let is = j * 4;
let scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is));
let scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
let scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
let scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
let q4bits1 = _mm256_loadu_si256(q4 as *const __m256i);
q4 = q4.add(32);
let q4bits2 = _mm256_loadu_si256(q4 as *const __m256i);
q4 = q4.add(32);
let q4bits_h = _mm256_loadu_si256(qh as *const __m256i);
qh = qh.add(32);
let q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bits_h, m2), 4);
let q4h_1 =
_mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bits_h, 2), m2), 4);
let q4h_2 =
_mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bits_h, 4), m2), 4);
let q4h_3 =
_mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bits_h, 6), m2), 4);
let q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
let q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1);
let q4_2 =
_mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2);
let q4_3 =
_mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3);
let q8_0 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_1 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_2 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_3 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
let q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
let q8s_2 = _mm256_maddubs_epi16(m32s, q8_2);
let q8s_3 = _mm256_maddubs_epi16(m32s, q8_3);
let p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
let p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
let p16_2 = _mm256_maddubs_epi16(q4_2, q8_2);
let p16_3 = _mm256_maddubs_epi16(q4_3, q8_3);
let p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
let p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
let p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
let p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
let p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
let p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
let p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2);
let p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3);
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3));
}
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
}
Ok(hsum_float_8(acc))
}
}
#[inline(always)]
unsafe fn mm256_set_m128i(a: __m128i, b: __m128i) -> __m256i {
_mm256_insertf128_si256(_mm256_castsi128_si256(b), a, 1)
}
#[inline(always)]
pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
}
unsafe {
let m3 = _mm256_set1_epi8(3);
let m4 = _mm_set1_epi8(0xF);
let mut acc = _mm256_setzero_ps();
for (x, y) in xs.iter().zip(ys.iter()) {
let d = y.d * x.d.to_f32();
let dmin = -y.d * x.dmin.to_f32();
let mut q2 = x.qs.as_ptr();
let mut q8 = y.qs.as_ptr();
let mins_and_scales = _mm_loadu_si128(x.scales.as_ptr() as *const __m128i);
let scales8 = _mm_and_si128(mins_and_scales, m4);
let mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
let mins = _mm256_cvtepi8_epi16(mins8);
let prod =
_mm256_madd_epi16(mins, _mm256_loadu_si256(y.bsums.as_ptr() as *const __m256i));
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc);
let all_scales = _mm256_cvtepi8_epi16(scales8);
let l_scales = _mm256_extracti128_si256(all_scales, 0);
let h_scales = _mm256_extracti128_si256(all_scales, 1);
let scales = [
mm256_set_m128i(l_scales, l_scales),
mm256_set_m128i(h_scales, h_scales),
];
let mut sumi = _mm256_setzero_si256();
for scale in scales {
let q2bits = _mm256_loadu_si256(q2 as *const __m256i);
q2 = q2.add(32);
let q8_0 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_1 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_2 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_3 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q2_0 = _mm256_and_si256(q2bits, m3);
let q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3);
let q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3);
let q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3);
let p0 = _mm256_maddubs_epi16(q2_0, q8_0);
let p1 = _mm256_maddubs_epi16(q2_1, q8_1);
let p2 = _mm256_maddubs_epi16(q2_2, q8_2);
let p3 = _mm256_maddubs_epi16(q2_3, q8_3);
let p0 =
_mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(0)), p0);
let p1 =
_mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(1)), p1);
let p2 =
_mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(2)), p2);
let p3 =
_mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(3)), p3);
let p0 = _mm256_add_epi32(p0, p1);
let p2 = _mm256_add_epi32(p2, p3);
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2));
}
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
}
Ok(hsum_float_8(acc))
}
}
#[inline(always)]
pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}")
}
const KMASK1: u32 = 0x03030303;
const KMASK2: u32 = 0x0f0f0f0f;
let mut aux = [0u32; 3];
unsafe {
let m3 = _mm256_set1_epi8(3);
let mone = _mm256_set1_epi8(1);
let m32 = _mm_set1_epi8(32);
let mut acc = _mm256_setzero_ps();
for (x, y) in xs.iter().zip(ys.iter()) {
let d = y.d * x.d.to_f32();
let mut q3 = x.qs.as_ptr();
let mut q8 = y.qs.as_ptr();
LittleEndian::read_u32_into(&x.scales, &mut aux);
let scales128 = _mm_set_epi32(
(((aux[1] >> 4) & KMASK2) | (((aux[2] >> 6) & KMASK1) << 4)) as i32,
(((aux[0] >> 4) & KMASK2) | (((aux[2] >> 4) & KMASK1) << 4)) as i32,
((aux[1] & KMASK2) | (((aux[2] >> 2) & KMASK1) << 4)) as i32,
((aux[0] & KMASK2) | (((aux[2]) & KMASK1) << 4)) as i32,
);
let scales128 = _mm_sub_epi8(scales128, m32);
let all_scales = _mm256_cvtepi8_epi16(scales128);
let l_scales = _mm256_extracti128_si256(all_scales, 0);
let h_scales = _mm256_extracti128_si256(all_scales, 1);
let scales = [
mm256_set_m128i(l_scales, l_scales),
mm256_set_m128i(h_scales, h_scales),
];
// high bit
let hbits = _mm256_loadu_si256(x.hmask.as_ptr() as *const __m256i);
let mut sumi = _mm256_setzero_si256();
for (j, scale) in scales.iter().enumerate() {
// load low 2 bits
let q3bits = _mm256_loadu_si256(q3 as *const __m256i);
q3 = q3.add(32);
// Prepare low and high bits
// We hardcode the shifts here to avoid loading them into a separate register
let q3l_0 = _mm256_and_si256(q3bits, m3);
let q3h_0 = if j == 0 {
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 0)), 0)
} else {
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 4)), 4)
};
let q3h_0 = _mm256_slli_epi16(q3h_0, 2);
let q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3);
let q3h_1 = if j == 0 {
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 1)), 1)
} else {
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 5)), 5)
};
let q3h_1 = _mm256_slli_epi16(q3h_1, 2);
let q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3);
let q3h_2 = if j == 0 {
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 2)), 2)
} else {
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 6)), 6)
};
let q3h_2 = _mm256_slli_epi16(q3h_2, 2);
let q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3);
let q3h_3 = if j == 0 {
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 3)), 3)
} else {
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 7)), 7)
};
let q3h_3 = _mm256_slli_epi16(q3h_3, 2);
// load Q8 quants
let q8_0 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_1 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_2 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_3 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
// Dot product: we multiply the 2 low bits and 1 high bit part separately, so we
// can use _mm256_maddubs_epi16, and then subtract. The high bit part has the 2
// already subtracted (and so, it is zero if the high bit was not set, and 2 if the
// high bit was set)
let q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
let q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
let q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2);
let q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3);
let p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
let p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
let p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2);
let p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3);
let p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
let p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
let p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
let p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
// multiply with scales
let p16_0 =
_mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(0)), p16_0);
let p16_1 =
_mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(1)), p16_1);
let p16_2 =
_mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(2)), p16_2);
let p16_3 =
_mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(3)), p16_3);
// accumulate
let p16_0 = _mm256_add_epi32(p16_0, p16_1);
let p16_2 = _mm256_add_epi32(p16_2, p16_3);
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2));
}
// multiply with block scale and accumulate
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
}
Ok(hsum_float_8(acc))
}
}
#[inline(always)]
pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
}
let mut utmp = [0u32; 4];
const KMASK1: u32 = 0x3f3f3f3f;
const KMASK2: u32 = 0x0f0f0f0f;
const KMASK3: u32 = 0x03030303;
unsafe {
let m4 = _mm256_set1_epi8(0xF);
let mut acc = _mm256_setzero_ps();
let mut acc_m = _mm_setzero_ps();
for (x, y) in xs.iter().zip(ys.iter()) {
let d = y.d * x.d.to_f32();
let dmin = -y.d * x.dmin.to_f32();
LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
let uaux = utmp[1] & KMASK1;
utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
utmp[2] = uaux;
utmp[0] &= KMASK1;
let mut q4 = x.qs.as_ptr();
let mut q8 = y.qs.as_ptr();
let mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(
utmp[3] as i32,
utmp[2] as i32,
utmp[1] as i32,
utmp[0] as i32,
));
let q8sums = _mm256_loadu_si256(y.bsums.as_ptr() as *const __m256i);
let q8s = _mm_hadd_epi16(
_mm256_extracti128_si256(q8sums, 0),
_mm256_extracti128_si256(q8sums, 1),
);
let prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);
let sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
let scales = mm256_set_m128i(sc128, sc128);
let mut sumi = _mm256_setzero_si256();
for j in 0..QK_K / 64 {
let scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j));
let scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j + 1));
let q4bits = _mm256_loadu_si256(q4 as *const __m256i);
q4 = q4.add(32);
let q4l = _mm256_and_si256(q4bits, m4);
let q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
let q8l = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let p16l = _mm256_maddubs_epi16(q4l, q8l);
let p16l = _mm256_madd_epi16(scale_l, p16l);
sumi = _mm256_add_epi32(sumi, p16l);
let q8h = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let p16h = _mm256_maddubs_epi16(q4h, q8h);
let p16h = _mm256_madd_epi16(scale_h, p16h);
sumi = _mm256_add_epi32(sumi, p16h);
}
let vd = _mm256_set1_ps(d);
acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
}
let acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
let acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
Ok(hsum_float_8(acc) + _mm_cvtss_f32(acc_m))
}
}
#[inline(always)]
pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}")
}
let mut utmp = [0u32; 4];
const KMASK1: u32 = 0x3f3f3f3f;
const KMASK2: u32 = 0x0f0f0f0f;
const KMASK3: u32 = 0x03030303;
unsafe {
let m4 = _mm256_set1_epi8(0xF);
let mzero = _mm_setzero_si128();
let mone = _mm256_set1_epi8(1);
let mut acc = _mm256_setzero_ps();
let mut summs = 0.0;
for (x, y) in xs.iter().zip(ys.iter()) {
let d = y.d * x.d.to_f32();
let dmin = -y.d * x.dmin.to_f32();
LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
let uaux = utmp[1] & KMASK1;
utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
utmp[2] = uaux;
utmp[0] &= KMASK1;
let mut q5 = x.qs.as_ptr();
let mut q8 = y.qs.as_ptr();
let mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(
utmp[3] as i32,
utmp[2] as i32,
utmp[1] as i32,
utmp[0] as i32,
));
let q8sums = _mm256_loadu_si256(y.bsums.as_ptr() as *const __m256i);
let q8s = _mm_hadd_epi16(
_mm256_extracti128_si256(q8sums, 0),
_mm256_extracti128_si256(q8sums, 1),
);
let prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
let hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
summs += dmin * _mm_extract_epi32(hsum, 0) as f32;
let sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
let scales = mm256_set_m128i(sc128, sc128);
let hbits = _mm256_loadu_si256(x.qh.as_ptr() as *const __m256i);
let mut hmask = mone;
let mut sumi = _mm256_setzero_si256();
for j in 0..QK_K / 64 {
let scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j));
let scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j + 1));
let q5bits = _mm256_loadu_si256(q5 as *const __m256i);
q5 = q5.add(32);
//Similar to q3k we hardcode the shifts here to avoid loading them into a separate register
let q5l_0 = _mm256_and_si256(q5bits, m4);
let q5l_0_shift_input = _mm256_and_si256(hbits, hmask);
let q5l_0_right_shift = match j {
0 => _mm256_srli_epi16(q5l_0_shift_input, 0),
1 => _mm256_srli_epi16(q5l_0_shift_input, 2),
2 => _mm256_srli_epi16(q5l_0_shift_input, 4),
3 => _mm256_srli_epi16(q5l_0_shift_input, 6),
_ => unreachable!(),
};
let q5h_0 = _mm256_slli_epi16(q5l_0_right_shift, 4);
let q5_0 = _mm256_add_epi8(q5l_0, q5h_0);
hmask = _mm256_slli_epi16(hmask, 1);
let q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
let q5l_1_shift_input = _mm256_and_si256(hbits, hmask);
let q5l_1_right_shift = match j {
0 => _mm256_srli_epi16(q5l_1_shift_input, 1),
1 => _mm256_srli_epi16(q5l_1_shift_input, 3),
2 => _mm256_srli_epi16(q5l_1_shift_input, 5),
3 => _mm256_srli_epi16(q5l_1_shift_input, 7),
_ => unreachable!(),
};
let q5h_1 = _mm256_slli_epi16(q5l_1_right_shift, 4);
let q5_1 = _mm256_add_epi8(q5l_1, q5h_1);
hmask = _mm256_slli_epi16(hmask, 1);
let q8_0 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_1 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let p16_0 = _mm256_maddubs_epi16(q5_0, q8_0);
let p16_1 = _mm256_maddubs_epi16(q5_1, q8_1);
let p16_0 = _mm256_madd_epi16(scale_0, p16_0);
let p16_1 = _mm256_madd_epi16(scale_1, p16_1);
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
}
let vd = _mm256_set1_ps(d);
acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
}
Ok(hsum_float_8(acc) + summs)
}
}
#[inline(always)]
pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
let qk = QK_K;
if n % qk != 0 {
crate::bail!("vec_dot_q8k_8k: {n} is not divisible by {qk}")
}
unsafe {
let mut acc = _mm256_setzero_ps();
for (xs, ys) in xs.iter().zip(ys.iter()) {
let mut sumi = _mm256_setzero_si256();
let x_qs = xs.qs.as_ptr();
let y_qs = ys.qs.as_ptr();
for j in (0..QK_K).step_by(32) {
let xs = _mm256_loadu_si256(x_qs.add(j) as *const __m256i);
let ys = _mm256_loadu_si256(y_qs.add(j) as *const __m256i);
let xs0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xs, 0));
let ys0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(ys, 0));
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(xs0, ys0));
let xs1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xs, 1));
let ys1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(ys, 1));
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(xs1, ys1));
}
let d = _mm256_set1_ps(xs.d * ys.d);
acc = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi), acc);
}
Ok(hsum_float_8(acc))
}
}

View File

@ -1,43 +0,0 @@
#![allow(unused)]
use super::GgmlDType;
use crate::{Error, MetalDevice, MetalStorage, Result};
pub struct QMetalStorage {
dtype: GgmlDType,
device: MetalDevice,
}
impl QMetalStorage {
pub fn zeros(_: &MetalDevice, _: usize, _: GgmlDType) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
pub fn dtype(&self) -> GgmlDType {
self.dtype
}
pub fn device(&self) -> &MetalDevice {
&self.device
}
pub fn dequantize(&self, _elem_count: usize) -> Result<MetalStorage> {
Err(Error::NotCompiledWithMetalSupport)
}
pub fn quantize(&mut self, _src: &MetalStorage) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}
pub fn storage_size_in_bytes(&self) -> usize {
0
}
pub fn fwd(
&self,
_self_shape: &crate::Shape,
_storage: &MetalStorage,
_layout: &crate::Layout,
) -> Result<(MetalStorage, crate::Shape)> {
Err(Error::NotCompiledWithMetalSupport)
}
}

View File

@ -1,11 +1,8 @@
//! Support for the GGML file format.
#[cfg(feature = "metal")]
use super::metal::load_quantized_metal;
use super::{k_quants, GgmlDType, QStorage};
use crate::{Device, Result};
use super::{k_quants, GgmlDType};
use crate::Result;
use byteorder::{LittleEndian, ReadBytesExt};
use std::collections::HashMap;
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.h#L37
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@ -123,22 +120,11 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
raw_data: &[u8],
size_in_bytes: usize,
dims: Vec<usize>,
device: &Device,
) -> Result<super::QTensor> {
let raw_data_ptr = raw_data.as_ptr();
let n_blocks = size_in_bytes / std::mem::size_of::<T>();
let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
let data: QStorage = match device {
Device::Cpu => QStorage::Cpu(Box::new(data.to_vec())),
#[cfg(feature = "metal")]
Device::Metal(metal) => load_quantized_metal(metal, data)?,
#[cfg(not(feature = "metal"))]
Device::Metal(_metal) => {
crate::bail!("Metal backend requires `metal` feature")
}
device => unimplemented!("Implement quantized tensor for device {device:?}"),
};
super::QTensor::new(data, dims)
Ok(super::QTensor::new(data.to_vec(), dims))
}
/// Creates a [Tensor] from a raw GGML tensor.
@ -146,50 +132,23 @@ pub fn qtensor_from_ggml(
ggml_dtype: GgmlDType,
raw_data: &[u8],
dims: Vec<usize>,
device: &Device,
) -> Result<super::QTensor> {
let tensor_elems = dims.iter().product::<usize>();
let block_size = ggml_dtype.block_size();
if tensor_elems % block_size != 0 {
crate::bail!(
"the number of elements {tensor_elems} is not divisible by the block size {block_size}"
)
}
let size_in_bytes = tensor_elems / block_size * ggml_dtype.type_size();
let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
match ggml_dtype {
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims, device),
GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims, device),
GgmlDType::Q4_0 => {
from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q4_1 => {
from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q5_0 => {
from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q5_1 => {
from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q8_0 => {
from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q2K => {
from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q3K => {
from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q4K => {
from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q5K => {
from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q6K => {
from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),
GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims),
GgmlDType::Q4_0 => from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims),
GgmlDType::Q4_1 => from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims),
GgmlDType::Q5_0 => from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims),
GgmlDType::Q5_1 => from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims),
GgmlDType::Q8_0 => from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims),
GgmlDType::Q2K => from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims),
GgmlDType::Q3K => from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims),
GgmlDType::Q4K => from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims),
GgmlDType::Q5K => from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims),
GgmlDType::Q6K => from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims),
_ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
}
}
@ -197,7 +156,6 @@ pub fn qtensor_from_ggml(
fn read_one_tensor<R: std::io::Seek + std::io::Read>(
reader: &mut R,
magic: VersionedMagic,
device: &Device,
) -> Result<(String, super::QTensor)> {
let n_dims = reader.read_u32::<LittleEndian>()?;
let name_len = reader.read_u32::<LittleEndian>()?;
@ -205,9 +163,6 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
let ggml_dtype = GgmlDType::from_u32(ggml_dtype)?;
let mut dims = vec![0u32; n_dims as usize];
reader.read_u32_into::<LittleEndian>(&mut dims)?;
// The dimensions are stored in reverse order, see for example:
// https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/convert.py#L969
dims.reverse();
let mut name = vec![0u8; name_len as usize];
reader.read_exact(&mut name)?;
let name = String::from_utf8_lossy(&name).into_owned();
@ -218,11 +173,12 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
}
let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>();
let tensor_elems = dims.iter().product::<usize>();
let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.block_size();
let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
println!("{name} {ggml_dtype:?} {dims:?}");
// TODO: Mmap version to avoid copying the data around?
let mut raw_data = vec![0u8; size_in_bytes];
reader.read_exact(&mut raw_data)?;
match qtensor_from_ggml(ggml_dtype, &raw_data, dims, device) {
match qtensor_from_ggml(ggml_dtype, &raw_data, dims) {
Ok(tensor) => Ok((name, tensor)),
Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
}
@ -232,41 +188,28 @@ pub struct Content {
pub magic: VersionedMagic,
pub hparams: HParams,
pub vocab: Vocab,
pub tensors: HashMap<String, super::QTensor>,
pub device: Device,
pub tensors: Vec<(String, super::QTensor)>,
}
impl Content {
pub fn read<R: std::io::Seek + std::io::Read>(
reader: &mut R,
device: &Device,
) -> Result<Content> {
pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Content> {
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
let last_position = reader.seek(std::io::SeekFrom::End(0))?;
reader.seek(std::io::SeekFrom::Start(0))?;
let magic = VersionedMagic::read(reader)?;
let hparams = HParams::read(reader)?;
let vocab = Vocab::read(reader, hparams.n_vocab as usize)?;
let mut tensors = HashMap::new();
let mut tensors = vec![];
while reader.stream_position()? != last_position {
let (name, tensor) = read_one_tensor(reader, magic, device)?;
tensors.insert(name, tensor);
let (name, tensor) = read_one_tensor(reader, magic)?;
tensors.push((name, tensor))
}
let device = device.clone();
Ok(Self {
magic,
hparams,
vocab,
tensors,
device,
})
}
pub fn remove(&mut self, name: &str) -> Result<super::QTensor> {
match self.tensors.remove(name) {
None => crate::bail!("cannot find tensor with name '{name}'"),
Some(tensor) => Ok(tensor),
}
}
}

View File

@ -1,534 +0,0 @@
//! Support for the GGUF file format.
//!
//! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md
use super::{GgmlDType, QTensor};
use crate::{Device, Result};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::collections::HashMap;
pub const DEFAULT_ALIGNMENT: u64 = 32;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Magic {
Gguf,
}
impl TryFrom<u32> for Magic {
type Error = crate::Error;
fn try_from(value: u32) -> Result<Self> {
let magic = match value {
0x46554747 | 0x47475546 => Self::Gguf,
_ => crate::bail!("unknown magic 0x{value:08x}"),
};
Ok(magic)
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VersionedMagic {
GgufV1,
GgufV2,
GgufV3,
}
impl VersionedMagic {
fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> {
let magic = reader.read_u32::<LittleEndian>()?;
let magic = Magic::try_from(magic)?;
let version = reader.read_u32::<LittleEndian>()?;
let versioned_magic = match (magic, version) {
(Magic::Gguf, 1) => Self::GgufV1,
(Magic::Gguf, 2) => Self::GgufV2,
(Magic::Gguf, 3) => Self::GgufV3,
_ => crate::bail!("gguf: unsupported magic/version {magic:?}/{version}"),
};
Ok(versioned_magic)
}
}
#[derive(Debug)]
pub struct TensorInfo {
pub ggml_dtype: GgmlDType,
pub shape: crate::Shape,
pub offset: u64,
}
impl TensorInfo {
pub fn read<R: std::io::Seek + std::io::Read>(
&self,
reader: &mut R,
tensor_data_offset: u64,
device: &Device,
) -> Result<QTensor> {
let tensor_elems = self.shape.elem_count();
let block_size = self.ggml_dtype.block_size();
if tensor_elems % block_size != 0 {
crate::bail!(
"the number of elements {tensor_elems} is not divisible by the block size {block_size}"
)
}
let size_in_bytes = tensor_elems / block_size * self.ggml_dtype.type_size();
let mut raw_data = vec![0u8; size_in_bytes];
reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
reader.read_exact(&mut raw_data)?;
super::ggml_file::qtensor_from_ggml(
self.ggml_dtype,
&raw_data,
self.shape.dims().to_vec(),
device,
)
}
}
#[derive(Debug)]
pub struct Content {
pub magic: VersionedMagic,
pub metadata: HashMap<String, Value>,
pub tensor_infos: HashMap<String, TensorInfo>,
pub tensor_data_offset: u64,
}
fn read_string<R: std::io::Read>(reader: &mut R, magic: &VersionedMagic) -> Result<String> {
let len = match magic {
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
reader.read_u64::<LittleEndian>()? as usize
}
};
let mut v = vec![0u8; len];
reader.read_exact(&mut v)?;
// GGUF strings are supposed to be non-null terminated but in practice this happens.
while let Some(0) = v.last() {
v.pop();
}
// GGUF strings are utf8 encoded but there are cases that don't seem to be valid.
Ok(String::from_utf8_lossy(&v).into_owned())
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ValueType {
// The value is a 8-bit unsigned integer.
U8,
// The value is a 8-bit signed integer.
I8,
// The value is a 16-bit unsigned little-endian integer.
U16,
// The value is a 16-bit signed little-endian integer.
I16,
// The value is a 32-bit unsigned little-endian integer.
U32,
// The value is a 32-bit signed little-endian integer.
I32,
// The value is a 64-bit unsigned little-endian integer.
U64,
// The value is a 64-bit signed little-endian integer.
I64,
// The value is a 32-bit IEEE754 floating point number.
F32,
// The value is a 64-bit IEEE754 floating point number.
F64,
// The value is a boolean.
// 1-byte value where 0 is false and 1 is true.
// Anything else is invalid, and should be treated as either the model being invalid or the reader being buggy.
Bool,
// The value is a UTF-8 non-null-terminated string, with length prepended.
String,
// The value is an array of other values, with the length and type prepended.
///
// Arrays can be nested, and the length of the array is the number of elements in the array, not the number of bytes.
Array,
}
#[derive(Debug, Clone)]
pub enum Value {
U8(u8),
I8(i8),
U16(u16),
I16(i16),
U32(u32),
I32(i32),
U64(u64),
I64(i64),
F32(f32),
F64(f64),
Bool(bool),
String(String),
Array(Vec<Value>),
}
impl Value {
pub fn value_type(&self) -> ValueType {
match self {
Self::U8(_) => ValueType::U8,
Self::I8(_) => ValueType::I8,
Self::U16(_) => ValueType::U16,
Self::I16(_) => ValueType::I16,
Self::U32(_) => ValueType::U32,
Self::I32(_) => ValueType::I32,
Self::U64(_) => ValueType::U64,
Self::I64(_) => ValueType::I64,
Self::F32(_) => ValueType::F32,
Self::F64(_) => ValueType::F64,
Self::Bool(_) => ValueType::Bool,
Self::String(_) => ValueType::String,
Self::Array(_) => ValueType::Array,
}
}
pub fn to_u8(&self) -> Result<u8> {
match self {
Self::U8(v) => Ok(*v),
v => crate::bail!("not a u8 {v:?}"),
}
}
pub fn to_i8(&self) -> Result<i8> {
match self {
Self::I8(v) => Ok(*v),
v => crate::bail!("not a i8 {v:?}"),
}
}
pub fn to_u16(&self) -> Result<u16> {
match self {
Self::U16(v) => Ok(*v),
v => crate::bail!("not a u16 {v:?}"),
}
}
pub fn to_i16(&self) -> Result<i16> {
match self {
Self::I16(v) => Ok(*v),
v => crate::bail!("not a i16 {v:?}"),
}
}
pub fn to_u32(&self) -> Result<u32> {
match self {
Self::U32(v) => Ok(*v),
v => crate::bail!("not a u32 {v:?}"),
}
}
pub fn to_i32(&self) -> Result<i32> {
match self {
Self::I32(v) => Ok(*v),
v => crate::bail!("not a i32 {v:?}"),
}
}
pub fn to_u64(&self) -> Result<u64> {
match self {
Self::U64(v) => Ok(*v),
v => crate::bail!("not a u64 {v:?}"),
}
}
pub fn to_i64(&self) -> Result<i64> {
match self {
Self::I64(v) => Ok(*v),
v => crate::bail!("not a i64 {v:?}"),
}
}
pub fn to_f32(&self) -> Result<f32> {
match self {
Self::F32(v) => Ok(*v),
v => crate::bail!("not a f32 {v:?}"),
}
}
pub fn to_f64(&self) -> Result<f64> {
match self {
Self::F64(v) => Ok(*v),
v => crate::bail!("not a f64 {v:?}"),
}
}
pub fn to_bool(&self) -> Result<bool> {
match self {
Self::Bool(v) => Ok(*v),
v => crate::bail!("not a bool {v:?}"),
}
}
pub fn to_vec(&self) -> Result<&Vec<Value>> {
match self {
Self::Array(v) => Ok(v),
v => crate::bail!("not a vec {v:?}"),
}
}
pub fn to_string(&self) -> Result<&String> {
match self {
Self::String(v) => Ok(v),
v => crate::bail!("not a string {v:?}"),
}
}
fn read<R: std::io::Read>(
reader: &mut R,
value_type: ValueType,
magic: &VersionedMagic,
) -> Result<Self> {
let v = match value_type {
ValueType::U8 => Self::U8(reader.read_u8()?),
ValueType::I8 => Self::I8(reader.read_i8()?),
ValueType::U16 => Self::U16(reader.read_u16::<LittleEndian>()?),
ValueType::I16 => Self::I16(reader.read_i16::<LittleEndian>()?),
ValueType::U32 => Self::U32(reader.read_u32::<LittleEndian>()?),
ValueType::I32 => Self::I32(reader.read_i32::<LittleEndian>()?),
ValueType::U64 => Self::U64(reader.read_u64::<LittleEndian>()?),
ValueType::I64 => Self::I64(reader.read_i64::<LittleEndian>()?),
ValueType::F32 => Self::F32(reader.read_f32::<LittleEndian>()?),
ValueType::F64 => Self::F64(reader.read_f64::<LittleEndian>()?),
ValueType::Bool => match reader.read_u8()? {
0 => Self::Bool(false),
1 => Self::Bool(true),
b => crate::bail!("unexpected bool value {b}"),
},
ValueType::String => Self::String(read_string(reader, magic)?),
ValueType::Array => {
let value_type = reader.read_u32::<LittleEndian>()?;
let value_type = ValueType::from_u32(value_type)?;
let len = match magic {
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
reader.read_u64::<LittleEndian>()? as usize
}
};
let mut vs = Vec::with_capacity(len);
for _ in 0..len {
vs.push(Value::read(reader, value_type, magic)?)
}
Self::Array(vs)
}
};
Ok(v)
}
fn write<W: std::io::Write>(&self, w: &mut W) -> Result<()> {
match self {
&Self::U8(v) => w.write_u8(v)?,
&Self::I8(v) => w.write_i8(v)?,
&Self::U16(v) => w.write_u16::<LittleEndian>(v)?,
&Self::I16(v) => w.write_i16::<LittleEndian>(v)?,
&Self::U32(v) => w.write_u32::<LittleEndian>(v)?,
&Self::I32(v) => w.write_i32::<LittleEndian>(v)?,
&Self::U64(v) => w.write_u64::<LittleEndian>(v)?,
&Self::I64(v) => w.write_i64::<LittleEndian>(v)?,
&Self::F32(v) => w.write_f32::<LittleEndian>(v)?,
&Self::F64(v) => w.write_f64::<LittleEndian>(v)?,
&Self::Bool(v) => w.write_u8(u8::from(v))?,
Self::String(v) => write_string(w, v.as_str())?,
Self::Array(v) => {
// The `Value` type does not enforce that all the values in an Array have the same
// type.
let value_type = if v.is_empty() {
// Doesn't matter, the array is empty.
ValueType::U32
} else {
let value_type: std::collections::HashSet<_> =
v.iter().map(|elem| elem.value_type()).collect();
if value_type.len() != 1 {
crate::bail!("multiple value-types in the same array {value_type:?}")
}
value_type.into_iter().next().unwrap()
};
w.write_u32::<LittleEndian>(value_type.to_u32())?;
w.write_u64::<LittleEndian>(v.len() as u64)?;
for elem in v.iter() {
elem.write(w)?
}
}
}
Ok(())
}
}
impl ValueType {
fn from_u32(v: u32) -> Result<Self> {
let v = match v {
0 => Self::U8,
1 => Self::I8,
2 => Self::U16,
3 => Self::I16,
4 => Self::U32,
5 => Self::I32,
6 => Self::F32,
7 => Self::Bool,
8 => Self::String,
9 => Self::Array,
10 => Self::U64,
11 => Self::I64,
12 => Self::F64,
v => crate::bail!("unrecognized value-type {v:#08x}"),
};
Ok(v)
}
fn to_u32(self) -> u32 {
match self {
Self::U8 => 0,
Self::I8 => 1,
Self::U16 => 2,
Self::I16 => 3,
Self::U32 => 4,
Self::I32 => 5,
Self::F32 => 6,
Self::Bool => 7,
Self::String => 8,
Self::Array => 9,
Self::U64 => 10,
Self::I64 => 11,
Self::F64 => 12,
}
}
}
impl Content {
pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Self> {
let magic = VersionedMagic::read(reader)?;
let tensor_count = match magic {
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
reader.read_u64::<LittleEndian>()? as usize
}
};
let metadata_kv_count = match magic {
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
reader.read_u64::<LittleEndian>()? as usize
}
};
let mut metadata = HashMap::new();
for _idx in 0..metadata_kv_count {
let key = read_string(reader, &magic)?;
let value_type = reader.read_u32::<LittleEndian>()?;
let value_type = ValueType::from_u32(value_type)?;
let value = Value::read(reader, value_type, &magic)?;
metadata.insert(key, value);
}
let mut tensor_infos = HashMap::new();
for _idx in 0..tensor_count {
let tensor_name = read_string(reader, &magic)?;
let n_dimensions = reader.read_u32::<LittleEndian>()?;
let mut dimensions: Vec<usize> = match magic {
VersionedMagic::GgufV1 => {
let mut dimensions = vec![0; n_dimensions as usize];
reader.read_u32_into::<LittleEndian>(&mut dimensions)?;
dimensions.into_iter().map(|c| c as usize).collect()
}
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
let mut dimensions = vec![0; n_dimensions as usize];
reader.read_u64_into::<LittleEndian>(&mut dimensions)?;
dimensions.into_iter().map(|c| c as usize).collect()
}
};
dimensions.reverse();
let ggml_dtype = reader.read_u32::<LittleEndian>()?;
let ggml_dtype = GgmlDType::from_u32(ggml_dtype)?;
let offset = reader.read_u64::<LittleEndian>()?;
tensor_infos.insert(
tensor_name,
TensorInfo {
shape: crate::Shape::from(dimensions),
offset,
ggml_dtype,
},
);
}
let position = reader.stream_position()?;
let alignment = match metadata.get("general.alignment") {
Some(Value::U8(v)) => *v as u64,
Some(Value::U16(v)) => *v as u64,
Some(Value::U32(v)) => *v as u64,
Some(Value::I8(v)) if *v >= 0 => *v as u64,
Some(Value::I16(v)) if *v >= 0 => *v as u64,
Some(Value::I32(v)) if *v >= 0 => *v as u64,
_ => DEFAULT_ALIGNMENT,
};
let tensor_data_offset = (position + alignment - 1) / alignment * alignment;
Ok(Self {
magic,
metadata,
tensor_infos,
tensor_data_offset,
})
}
pub fn tensor<R: std::io::Seek + std::io::Read>(
&self,
reader: &mut R,
name: &str,
device: &Device,
) -> Result<QTensor> {
let tensor_info = match self.tensor_infos.get(name) {
Some(tensor_info) => tensor_info,
None => crate::bail!("cannot find tensor info for {name}"),
};
tensor_info.read(reader, self.tensor_data_offset, device)
}
}
fn write_string<W: std::io::Write>(w: &mut W, str: &str) -> Result<()> {
let bytes = str.as_bytes();
w.write_u64::<LittleEndian>(bytes.len() as u64)?;
w.write_all(bytes)?;
Ok(())
}
pub fn write<W: std::io::Seek + std::io::Write>(
w: &mut W,
metadata: &[(&str, &Value)],
tensors: &[(&str, &QTensor)],
) -> Result<()> {
w.write_u32::<LittleEndian>(0x46554747)?;
w.write_u32::<LittleEndian>(2)?; // version 2.
w.write_u64::<LittleEndian>(tensors.len() as u64)?;
w.write_u64::<LittleEndian>(metadata.len() as u64)?;
for (name, value) in metadata.iter() {
write_string(w, name)?;
w.write_u32::<LittleEndian>(value.value_type().to_u32())?;
value.write(w)?;
}
let mut offset = 0usize;
let mut offsets = Vec::with_capacity(tensors.len());
for (name, tensor) in tensors.iter() {
write_string(w, name)?;
let dims = tensor.shape().dims();
w.write_u32::<LittleEndian>(dims.len() as u32)?;
for &dim in dims.iter().rev() {
w.write_u64::<LittleEndian>(dim as u64)?;
}
w.write_u32::<LittleEndian>(tensor.dtype().to_u32())?;
w.write_u64::<LittleEndian>(offset as u64)?;
offsets.push(offset);
let size_in_bytes = tensor.storage_size_in_bytes();
let padding = 31 - (31 + size_in_bytes) % 32;
offset += size_in_bytes + padding;
}
let pos = w.stream_position()? as usize;
let padding = 31 - (31 + pos) % 32;
w.write_all(&vec![0u8; padding])?;
let tensor_start_pos = w.stream_position()? as usize;
for (offset, (_name, tensor)) in offsets.iter().zip(tensors.iter()) {
let pos = w.stream_position()? as usize;
if tensor_start_pos + offset != pos {
crate::bail!(
"internal error, unexpected current position {tensor_start_pos} {offset} {pos}"
)
}
let data = tensor.data()?;
let size_in_bytes = data.len();
w.write_all(&data)?;
let padding = 31 - (31 + size_in_bytes) % 32;
w.write_all(&vec![0u8; padding])?;
}
Ok(())
}

File diff suppressed because it is too large Load Diff

View File

@ -1,234 +0,0 @@
use super::{GgmlDType, QStorage};
use crate::backend::BackendStorage;
use crate::{DType, MetalDevice, MetalStorage, Result, Shape};
use metal::Buffer;
use std::sync::Arc;
pub struct QMetalStorage {
dtype: GgmlDType,
device: MetalDevice,
buffer: Arc<Buffer>,
}
impl QMetalStorage {
pub fn zeros(device: &MetalDevice, elem_count: usize, dtype: GgmlDType) -> Result<Self> {
let size = elem_count * dtype.type_size() / dtype.block_size();
let buffer = device.allocate_zeros(size)?;
Ok(Self {
buffer,
device: device.clone(),
dtype,
})
}
pub fn dtype(&self) -> GgmlDType {
self.dtype
}
pub fn device(&self) -> &MetalDevice {
&self.device
}
pub fn buffer(&self) -> &Buffer {
&self.buffer
}
pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
let buffer = self.device.new_buffer_managed(self.buffer.length())?;
let command_buffer = self.device.command_buffer()?;
command_buffer.set_label("to_cpu");
let blit = command_buffer.new_blit_command_encoder();
blit.set_label("blit_to_cpu");
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
blit.end_encoding();
self.device.wait_until_completed()?;
let mut out = vec![0.0; elem_count];
match self.dtype {
GgmlDType::F32 => {
let vec: Vec<f32> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
f32::to_float(&vec, &mut out)?;
}
GgmlDType::F16 => {
let vec: Vec<half::f16> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
half::f16::to_float(&vec, &mut out)?;
}
GgmlDType::Q4_0 => {
let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
}
GgmlDType::Q4_1 => {
let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
}
GgmlDType::Q5_0 => {
let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
}
GgmlDType::Q5_1 => {
let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
}
GgmlDType::Q8_0 => {
let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
}
GgmlDType::Q8_1 => {
let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
}
GgmlDType::Q2K => {
let vec: Vec<crate::quantized::BlockQ2K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
}
GgmlDType::Q3K => {
let vec: Vec<crate::quantized::BlockQ3K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
}
GgmlDType::Q4K => {
let vec: Vec<crate::quantized::BlockQ4K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
}
GgmlDType::Q5K => {
let vec: Vec<crate::quantized::BlockQ5K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
}
GgmlDType::Q6K => {
let vec: Vec<crate::quantized::BlockQ6K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
}
GgmlDType::Q8K => {
let vec: Vec<crate::quantized::BlockQ8K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
}
}
let buffer = self.device.new_buffer_with_data(&out)?;
Ok(MetalStorage::new(buffer, self.device.clone(), DType::F32))
}
pub fn quantize(&mut self, src: &MetalStorage) -> Result<()> {
// Quantization only happens on CPU for now.
let src = src.to_cpu::<f32>()?;
let elem_count = src.len();
let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
let mut qcpu_storage = crate::Device::Cpu.qzeros(elem_count, self.dtype)?;
qcpu_storage.quantize(&src)?;
let buffer = self.device.new_buffer_with_data(&qcpu_storage.data()?)?;
self.buffer = buffer;
Ok(())
}
pub fn storage_size_in_bytes(&self) -> usize {
self.buffer.length() as usize
}
pub fn fwd(
&self,
self_shape: &Shape,
storage: &MetalStorage,
layout: &crate::Layout,
) -> Result<(MetalStorage, Shape)> {
use crate::MetalError;
if !layout.is_contiguous() {
crate::bail!("input tensor is not contiguous {layout:?}")
}
let src_shape = layout.shape();
// self is transposed so n is first then k.
if src_shape.rank() < 2 {
crate::bail!("input tensor has only one dimension {layout:?}")
}
let (n, k) = self_shape.dims2()?;
let mut dst_shape = src_shape.dims().to_vec();
let (b, m) = match dst_shape.len() {
3 => (dst_shape[0], dst_shape[1]),
2 => (1, dst_shape[0]),
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
};
let last_k = dst_shape.pop().unwrap();
if last_k != k {
crate::bail!("input tensor {layout:?} incompatible with {:?}", self_shape)
}
dst_shape.push(n);
let dst_shape = Shape::from(dst_shape);
let device = storage.device().clone();
let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
let command_buffer = device.command_buffer()?;
candle_metal_kernels::call_quantized_matmul_t(
device.device(),
&command_buffer,
device.kernels(),
self.dtype.into(),
(b, m, n, k),
storage.buffer(),
layout.start_offset() * storage.dtype().size_in_bytes(),
&self.buffer,
&dst,
)
.map_err(MetalError::from)?;
let dst_storage = crate::MetalStorage::new(dst, device, DType::F32);
Ok((dst_storage, dst_shape))
}
}
pub fn load_quantized_metal<T: super::GgmlType + Send + Sync + 'static>(
device: &MetalDevice,
data: &[T],
) -> Result<QStorage> {
let buffer = device.new_buffer_with_data(data)?;
let device = device.clone();
Ok(QStorage::Metal(QMetalStorage {
dtype: T::DTYPE,
device,
buffer,
}))
}
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
let ptr = buffer.contents() as *const T;
assert!(!ptr.is_null());
let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
slice.to_vec()
}
impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
fn from(value: GgmlDType) -> Self {
match value {
GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0,
GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1,
GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0,
GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1,
GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0,
GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1,
GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K,
GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K,
GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K,
GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
}
}
}

View File

@ -1,119 +1,16 @@
use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor};
use k_quants::*;
use std::borrow::Cow;
use crate::{Device, Result, Shape, Tensor};
#[cfg(target_feature = "avx")]
pub mod avx;
mod dummy_metal;
pub mod ggml_file;
pub mod gguf_file;
pub mod k_quants;
#[cfg(feature = "metal")]
pub mod metal;
#[cfg(not(feature = "metal"))]
mod metal {
pub use super::dummy_metal::*;
}
#[cfg(target_feature = "neon")]
pub mod neon;
#[cfg(target_feature = "simd128")]
pub mod simd128;
pub mod utils;
use half::f16;
pub use k_quants::GgmlType;
pub struct QTensor {
storage: QStorage,
data: Box<dyn QuantizedType>,
shape: Shape,
}
impl Device {
fn qzeros(&self, elem_count: usize, dtype: GgmlDType) -> Result<QStorage> {
match self {
Device::Cpu => {
let storage = dtype.cpu_zeros(elem_count);
Ok(QStorage::Cpu(storage))
}
Device::Metal(metal) => {
let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?;
Ok(QStorage::Metal(storage))
}
Device::Cuda(_cuda) => {
crate::bail!("Cuda ggml quantization not supported");
}
}
}
}
pub enum QStorage {
Cpu(Box<dyn QuantizedType>),
Metal(metal::QMetalStorage),
}
impl QStorage {
fn block_size(&self) -> usize {
match self {
QStorage::Cpu(storage) => storage.block_size(),
QStorage::Metal(storage) => storage.dtype().block_size(),
}
}
fn dtype(&self) -> GgmlDType {
match self {
QStorage::Cpu(storage) => storage.dtype(),
QStorage::Metal(storage) => storage.dtype(),
}
}
fn device(&self) -> Device {
match self {
QStorage::Cpu(_storage) => Device::Cpu,
QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
}
}
fn size_in_bytes(&self) -> usize {
match self {
QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
QStorage::Metal(storage) => storage.storage_size_in_bytes(),
}
}
fn quantize(&mut self, src: &Storage) -> Result<()> {
match (self, src) {
(QStorage::Cpu(storage), Storage::Cpu(src)) => {
storage.from_float(src.as_slice::<f32>()?)?;
}
(QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
_ => crate::bail!("Invalid dequantize storage locations do not match"),
}
Ok(())
}
fn dequantize(&self, elem_count: usize) -> Result<Storage> {
match self {
QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
}
}
fn data(&self) -> Result<Cow<[u8]>> {
match self {
QStorage::Cpu(storage) => {
let data_ptr = storage.as_ptr();
let size_in_bytes = storage.storage_size_in_bytes();
let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
Ok(Cow::from(data))
}
QStorage::Metal(_storage) => {
crate::bail!("not implemented");
}
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GgmlDType {
F32,
F16,
@ -153,46 +50,7 @@ impl GgmlDType {
Ok(dtype)
}
pub(crate) fn to_u32(self) -> u32 {
match self {
Self::F32 => 0,
Self::F16 => 1,
Self::Q4_0 => 2,
Self::Q4_1 => 3,
Self::Q5_0 => 6,
Self::Q5_1 => 7,
Self::Q8_0 => 8,
Self::Q8_1 => 9,
Self::Q2K => 10,
Self::Q3K => 11,
Self::Q4K => 12,
Self::Q5K => 13,
Self::Q6K => 14,
Self::Q8K => 15,
}
}
/// The block dtype
pub fn cpu_zeros(&self, elem_count: usize) -> Box<dyn QuantizedType> {
match self {
Self::F32 => Box::new(vec![f32::zeros(); elem_count]),
Self::F16 => Box::new(vec![f16::zeros(); elem_count]),
Self::Q4_0 => Box::new(vec![BlockQ4_0::zeros(); elem_count / BlockQ4_0::BLCK_SIZE]),
Self::Q4_1 => Box::new(vec![BlockQ4_1::zeros(); elem_count / BlockQ4_1::BLCK_SIZE]),
Self::Q5_0 => Box::new(vec![BlockQ5_0::zeros(); elem_count / BlockQ5_0::BLCK_SIZE]),
Self::Q5_1 => Box::new(vec![BlockQ5_1::zeros(); elem_count / BlockQ5_1::BLCK_SIZE]),
Self::Q8_0 => Box::new(vec![BlockQ8_0::zeros(); elem_count / BlockQ8_0::BLCK_SIZE]),
Self::Q8_1 => Box::new(vec![BlockQ8_1::zeros(); elem_count / BlockQ8_1::BLCK_SIZE]),
Self::Q2K => Box::new(vec![BlockQ2K::zeros(); elem_count / BlockQ2K::BLCK_SIZE]),
Self::Q3K => Box::new(vec![BlockQ3K::zeros(); elem_count / BlockQ3K::BLCK_SIZE]),
Self::Q4K => Box::new(vec![BlockQ4K::zeros(); elem_count / BlockQ4K::BLCK_SIZE]),
Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]),
Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]),
Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]),
}
}
/// The type size for blocks in bytes.
pub fn type_size(&self) -> usize {
fn type_size(&self) -> usize {
use k_quants::*;
match self {
Self::F32 => 4,
@ -213,8 +71,7 @@ impl GgmlDType {
}
}
/// The block size, i.e. the number of elements stored in each block.
pub fn block_size(&self) -> usize {
fn blck_size(&self) -> usize {
match self {
Self::F32 => 1,
Self::F16 => 1,
@ -233,13 +90,7 @@ impl GgmlDType {
pub trait QuantizedType: Send + Sync {
fn dtype(&self) -> GgmlDType;
fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
fn dequantize(&self, elem_count: usize) -> Result<CpuStorage>;
fn storage_size_in_bytes(&self) -> usize;
fn as_ptr(&self) -> *const u8;
fn block_size(&self) -> usize;
#[allow(clippy::wrong_self_convention)]
fn from_float(&mut self, xs: &[f32]) -> Result<()>;
fn size(&self) -> usize;
fn to_float(&self, ys: &mut [f32]) -> Result<()>;
}
impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
@ -247,34 +98,12 @@ impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
k_quants::matmul(mkn, lhs, self.as_slice(), dst)
}
fn size(&self) -> usize {
self.len() * core::mem::size_of::<T>()
}
fn from_float(&mut self, xs: &[f32]) -> Result<()> {
T::from_float(xs, self)
}
fn dtype(&self) -> GgmlDType {
T::DTYPE
}
fn block_size(&self) -> usize {
T::BLCK_SIZE
}
fn dequantize(&self, elem_count: usize) -> Result<CpuStorage> {
let mut ys = vec![0.0f32; elem_count];
T::to_float(self.as_slice(), &mut ys)?;
Ok(CpuStorage::F32(ys))
}
fn storage_size_in_bytes(&self) -> usize {
self.len() * std::mem::size_of::<T>()
}
fn as_ptr(&self) -> *const u8 {
self.as_ptr() as *const u8
fn to_float(&self, ys: &mut [f32]) -> Result<()> {
T::to_float(self.as_slice(), ys)
}
}
@ -284,57 +113,19 @@ impl std::fmt::Debug for QTensor {
}
}
fn check_shape(shape: &Shape, block_size: usize) -> Result<()> {
let dims = shape.dims();
if dims.is_empty() {
crate::bail!("scalar tensor cannot be quantized {shape:?}")
}
if dims[dims.len() - 1] % block_size != 0 {
crate::bail!(
"quantized tensor must have their last dim divisible by block size {shape:?} {}",
block_size
)
}
Ok(())
}
impl QTensor {
pub fn new<S: Into<Shape>>(storage: QStorage, shape: S) -> Result<Self> {
let shape = shape.into();
check_shape(&shape, storage.block_size())?;
Ok(Self { storage, shape })
}
pub fn quantize(src: &Tensor, dtype: GgmlDType) -> Result<Self> {
let shape = src.shape();
let block_size = dtype.block_size();
check_shape(shape, block_size)?;
let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
let elem_count = shape.elem_count();
if elem_count % block_size != 0 {
crate::bail!(
"tensor size ({shape:?}) is not divisible by block size {}",
block_size
)
pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
data: Vec<T>,
shape: S,
) -> Self {
Self {
data: Box::new(data),
shape: shape.into(),
}
let mut storage = src.device().qzeros(elem_count, dtype)?;
storage.quantize(&src.storage())?;
Ok(Self {
storage,
shape: shape.clone(),
})
}
pub fn dtype(&self) -> GgmlDType {
self.storage.dtype()
}
pub fn device(&self) -> Device {
self.storage.device()
}
pub fn rank(&self) -> usize {
self.shape.rank()
self.data.dtype()
}
pub fn shape(&self) -> &Shape {
@ -342,60 +133,26 @@ impl QTensor {
}
pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
let storage = self.storage.dequantize(self.shape.elem_count())?;
let none = crate::op::BackpropOp::none();
let is_variable = false;
crate::tensor::from_storage(storage, self.shape.clone(), none, is_variable)
.to_device(device)
let mut f32_data = vec![0f32; self.shape.elem_count()];
self.data.to_float(&mut f32_data)?;
Tensor::from_vec(f32_data, &self.shape, device)
}
pub fn storage_size_in_bytes(&self) -> usize {
self.storage.size_in_bytes()
}
pub fn data(&self) -> Result<Cow<'_, [u8]>> {
self.storage.data()
pub fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
self.data.matmul_t(mkn, lhs, dst)
}
}
#[derive(Clone, Debug)]
pub enum QMatMul {
QTensor(std::sync::Arc<QTensor>),
Tensor(Tensor),
}
thread_local! {
static DEQUANTIZE_ALL: bool = {
match std::env::var("CANDLE_DEQUANTIZE_ALL") {
Ok(s) => {
!s.is_empty() && s != "0"
},
Err(_) => false,
}
}
}
#[derive(Debug, Clone)]
pub struct QMatMul(std::sync::Arc<QTensor>);
impl QMatMul {
pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
let dequantize = match qtensor.dtype() {
GgmlDType::F32 | GgmlDType::F16 => true,
_ => DEQUANTIZE_ALL.with(|b| *b),
};
let t = if dequantize {
let tensor = qtensor.dequantize(&Device::Cpu)?;
Self::Tensor(tensor)
} else {
Self::QTensor(qtensor)
};
Ok(t)
}
pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
Self::from_arc(std::sync::Arc::new(qtensor))
pub fn new(qtensor: std::sync::Arc<QTensor>) -> Self {
Self(qtensor)
}
}
impl crate::CustomOp1 for QTensor {
impl crate::CustomOp1 for QMatMul {
fn name(&self) -> &'static str {
"qmatmul"
}
@ -409,55 +166,29 @@ impl crate::CustomOp1 for QTensor {
crate::bail!("input tensor is not contiguous {layout:?}")
}
let src_shape = layout.shape();
// self is transposed so n is first then k.
let (n, k) = self.shape.dims2()?;
let (k, n) = self.0.shape.dims2()?;
if src_shape.rank() < 2 {
crate::bail!("input tensor has only one dimension {layout:?}")
}
let mut dst_shape = src_shape.dims().to_vec();
let last_k = dst_shape.pop().unwrap();
if last_k != k {
crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
crate::bail!(
"input tensor {layout:?} incompatible with {:?}",
self.0.shape
)
}
dst_shape.push(n);
let dst_shape = Shape::from(dst_shape);
#[allow(clippy::infallible_destructuring_match)]
let self_storage = match &self.storage {
QStorage::Cpu(storage) => storage,
QStorage::Metal(_) => crate::bail!("Invalid storage"),
};
let slice = storage.as_slice::<f32>()?;
let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
let storage = storage.as_slice::<f32>()?;
let storage =
&storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
let mut dst_storage = vec![0f32; dst_shape.elem_count()];
self_storage.matmul_t((dst_shape.elem_count() / n, k, n), slice, &mut dst_storage)?;
self.0.matmul_t(
(dst_shape.elem_count() / n, k, n),
storage,
&mut dst_storage,
)?;
Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
}
fn metal_fwd(
&self,
storage: &crate::MetalStorage,
layout: &crate::Layout,
) -> Result<(crate::MetalStorage, Shape)> {
let self_storage = match &self.storage {
QStorage::Metal(metal) => metal,
_ => unreachable!("Cannot call metal matmul on non metal QTensor"),
};
self_storage.fwd(&self.shape, storage, layout)
}
}
impl crate::Module for QMatMul {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
Self::Tensor(w) => {
let w = match *xs.dims() {
[b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
[bsize, _, _] => w.broadcast_left(bsize)?.t()?,
_ => w.t()?,
};
xs.matmul(&w)
}
}
}
}

View File

@ -1,613 +0,0 @@
use super::k_quants::{
BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K,
};
use crate::Result;
use byteorder::{ByteOrder, LittleEndian};
#[allow(unused_imports)]
#[cfg(target_arch = "arm")]
use core::arch::arm::*;
#[allow(unused_imports)]
#[cfg(target_arch = "aarch64")]
use core::arch::aarch64::*;
#[inline(always)]
unsafe fn vdotq_s32(a: int8x16_t, b: int8x16_t) -> int32x4_t {
// TODO: dotprod
let p0 = vmull_s8(vget_low_s8(a), vget_low_s8(b));
let p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1))
}
#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0;
let nb = n / qk;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
}
unsafe {
let mut sumv0 = vdupq_n_f32(0.0f32);
for i in 0..nb {
let x0 = &xs[i];
let y0 = &ys[i];
let m4b = vdupq_n_u8(0x0F);
let s8b = vdupq_n_s8(0x8);
let v0_0 = vld1q_u8(x0.qs.as_ptr());
// 4-bit -> 8-bit
let v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
let v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
// sub 8
let v0_0ls = vsubq_s8(v0_0l, s8b);
let v0_0hs = vsubq_s8(v0_0h, s8b);
// load y
let v1_0l = vld1q_s8(y0.qs.as_ptr());
let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));
let pl0 = vdotq_s32(v0_0ls, v1_0l);
let ph0 = vdotq_s32(v0_0hs, v1_0h);
sumv0 = vmlaq_n_f32(
sumv0,
vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
x0.d.to_f32() * y0.d.to_f32(),
);
}
Ok(vaddvq_f32(sumv0))
}
}
#[inline(always)]
pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
}
let nb = n / QK8_0;
unsafe {
let mut sumv0 = vdupq_n_f32(0.0f32);
for i in 0..nb {
let x0 = &xs[i];
let y0 = &ys[i];
let x0_0 = vld1q_s8(x0.qs.as_ptr());
let x0_1 = vld1q_s8(x0.qs.as_ptr().add(16));
// load y
let y0_0 = vld1q_s8(y0.qs.as_ptr());
let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));
let p0 = vdotq_s32(x0_0, y0_0);
let p1 = vdotq_s32(x0_1, y0_1);
sumv0 = vmlaq_n_f32(
sumv0,
vcvtq_f32_s32(vaddq_s32(p0, p1)),
x0.d.to_f32() * y0.d.to_f32(),
);
}
Ok(vaddvq_f32(sumv0))
}
}
#[inline(always)]
pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
let qk = QK_K;
if n % QK_K != 0 {
crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
}
let mut sumf = 0f32;
for (xs, ys) in xs.iter().zip(ys.iter()) {
unsafe {
let mut sum_i = vdupq_n_s32(0);
let scale = xs.d * ys.d;
let xs = xs.qs.as_ptr();
let ys = ys.qs.as_ptr();
for i in (0..QK_K).step_by(16) {
let xs = vld1q_s8(xs.add(i));
let ys = vld1q_s8(ys.add(i));
let xy = vdotq_s32(xs, ys);
sum_i = vaddq_s32(sum_i, xy)
}
sumf += vaddvq_s32(sum_i) as f32 * scale
}
}
Ok(sumf)
}
#[inline(always)]
pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}")
}
let mut sum = 0f32;
unsafe {
let m4b = vdupq_n_u8(0xF);
let mone = vdupq_n_u8(3);
for (x, y) in xs.iter().zip(ys.iter()) {
let d_all = x.d.to_f32();
let mut q6 = x.ql.as_ptr();
let mut qh = x.qh.as_ptr();
let mut q8 = y.qs.as_ptr();
let mut scale = x.scales.as_ptr();
let q8sums = vld1q_s16_x2(y.bsums.as_ptr());
let scales = vld1q_s8(scale);
let q6scales = int16x8x2_t(
vmovl_s8(vget_low_s8(scales)),
vmovl_s8(vget_high_s8(scales)),
);
let prod = vaddq_s32(
vaddq_s32(
vmull_s16(vget_low_s16(q8sums.0), vget_low_s16(q6scales.0)),
vmull_s16(vget_high_s16(q8sums.0), vget_high_s16(q6scales.0)),
),
vaddq_s32(
vmull_s16(vget_low_s16(q8sums.1), vget_low_s16(q6scales.1)),
vmull_s16(vget_high_s16(q8sums.1), vget_high_s16(q6scales.1)),
),
);
let isum_mins = vaddvq_s32(prod);
let mut isum = 0i32;
for _j in 0..QK_K / 128 {
let qhbits = vld1q_u8_x2(qh);
qh = qh.add(32);
let q6bits = vld1q_u8_x4(q6);
q6 = q6.add(64);
let q8bytes = vld1q_s8_x4(q8);
q8 = q8.add(64);
let q6h_0 = vshlq_n_u8(vandq_u8(mone, qhbits.0), 4);
let q6h_1 = vshlq_n_u8(vandq_u8(mone, qhbits.1), 4);
let shifted = vshrq_n_u8(qhbits.0, 2);
let q6h_2 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
let shifted = vshrq_n_u8(qhbits.1, 2);
let q6h_3 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
let q6bytes_0 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.0, m4b), q6h_0));
let q6bytes_1 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.1, m4b), q6h_1));
let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.2, m4b), q6h_2));
let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.3, m4b), q6h_3));
let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
scale = scale.add(2);
let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
scale = scale.add(2);
let q8bytes = vld1q_s8_x4(q8);
q8 = q8.add(64);
let shifted = vshrq_n_u8(qhbits.0, 4);
let q6h_0 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
let shifted = vshrq_n_u8(qhbits.1, 4);
let q6h_1 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
let shifted = vshrq_n_u8(qhbits.0, 6);
let q6h_2 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
let shifted = vshrq_n_u8(qhbits.1, 6);
let q6h_3 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
let q6bytes_0 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.0, 4), q6h_0));
let q6bytes_1 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.1, 4), q6h_1));
let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.2, 4), q6h_2));
let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.3, 4), q6h_3));
let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
scale = scale.add(2);
let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
scale = scale.add(2);
}
sum += d_all * y.d * ((isum - 32 * isum_mins) as f32);
}
}
Ok(sum)
}
#[inline(always)]
pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}")
}
let mut sumf = 0f32;
let mut utmp = [0u32; 4];
const KMASK1: u32 = 0x3f3f3f3f;
const KMASK2: u32 = 0x0f0f0f0f;
const KMASK3: u32 = 0x03030303;
unsafe {
let m4b = vdupq_n_u8(0xF);
let mone = vdupq_n_u8(1);
let mtwo = vdupq_n_u8(2);
for (x, y) in xs.iter().zip(ys.iter()) {
let d = y.d * x.d.to_f32();
let dmin = y.d * x.dmin.to_f32();
let q8sums = vpaddq_s16(
vld1q_s16(y.bsums.as_ptr()),
vld1q_s16(y.bsums.as_ptr().add(8)),
);
LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
let uaux = utmp[1] & KMASK1;
utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
utmp[2] = uaux;
utmp[0] &= KMASK1;
let mins8 = vld1_u8((utmp.as_ptr() as *const u8).add(8));
let mins = vreinterpretq_s16_u16(vmovl_u8(mins8));
let prod = vaddq_s32(
vmull_s16(vget_low_s16(q8sums), vget_low_s16(mins)),
vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)),
);
let sumi_mins = vaddvq_s32(prod);
let mut scales = utmp.as_ptr() as *const u8;
let mut q5 = x.qs.as_ptr();
let mut q8 = y.qs.as_ptr();
let mut qhbits = vld1q_u8_x2(x.qh.as_ptr());
let mut sumi = 0i32;
for _j in 0..QK_K / 64 {
let q5bits = vld1q_u8_x2(q5);
q5 = q5.add(32);
let q8bytes = vld1q_s8_x4(q8);
q8 = q8.add(64);
let q5h_0 = vshlq_n_u8(vandq_u8(mone, qhbits.0), 4);
let q5h_1 = vshlq_n_u8(vandq_u8(mone, qhbits.1), 4);
let q5h_2 = vshlq_n_u8(vandq_u8(mtwo, qhbits.0), 3);
let q5h_3 = vshlq_n_u8(vandq_u8(mtwo, qhbits.1), 3);
qhbits.0 = vshrq_n_u8(qhbits.0, 2);
qhbits.1 = vshrq_n_u8(qhbits.1, 2);
let q5bytes_0 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.0, m4b), q5h_0));
let q5bytes_1 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.1, m4b), q5h_1));
let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2));
let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3));
let p0 = vdotq_s32(q5bytes_0, q8bytes.0);
let p1 = vdotq_s32(q5bytes_1, q8bytes.1);
sumi += vaddvq_s32(vaddq_s32(p0, p1)) * *scales as i32;
scales = scales.add(1);
let p2 = vdotq_s32(q5bytes_2, q8bytes.2);
let p3 = vdotq_s32(q5bytes_3, q8bytes.3);
sumi += vaddvq_s32(vaddq_s32(p2, p3)) * *scales as i32;
scales = scales.add(1);
}
sumf += d * sumi as f32 - dmin * sumi_mins as f32;
}
}
Ok(sumf)
}
#[inline(always)]
pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
}
let mut sumf = 0f32;
let mut utmp = [0u32; 4];
let mut scales = [0u8; 16];
const KMASK1: u32 = 0x3f3f3f3f;
const KMASK2: u32 = 0x0f0f0f0f;
const KMASK3: u32 = 0x03030303;
unsafe {
let m4b = vdupq_n_u8(0xF);
for (x, y) in xs.iter().zip(ys.iter()) {
let d = y.d * x.d.to_f32();
let dmin = y.d * x.dmin.to_f32();
let q8sums = vpaddq_s16(
vld1q_s16(y.bsums.as_ptr()),
vld1q_s16(y.bsums.as_ptr().add(8)),
);
LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
let mins8 = vld1_u32(
[
utmp[1] & KMASK1,
((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4),
]
.as_ptr(),
);
utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
utmp[0] &= KMASK1;
let mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
let prod = vaddq_s32(
vmull_s16(vget_low_s16(q8sums), vget_low_s16(mins)),
vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)),
);
sumf -= dmin * vaddvq_s32(prod) as f32;
LittleEndian::write_u32_into(&utmp, &mut scales);
let mut q4 = x.qs.as_ptr();
let mut q8 = y.qs.as_ptr();
let mut sumi1 = 0i32;
let mut sumi2 = 0i32;
for j in 0..QK_K / 64 {
let q4bits = vld1q_u8_x2(q4);
q4 = q4.add(32);
let q8bytes = vld1q_s8_x2(q8);
q8 = q8.add(32);
let q4bytes = int8x16x2_t(
vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)),
vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)),
);
let p0 = vdotq_s32(q4bytes.0, q8bytes.0);
let p1 = vdotq_s32(q4bytes.1, q8bytes.1);
sumi1 += vaddvq_s32(vaddq_s32(p0, p1)) * scales[2 * j] as i32;
let q8bytes = vld1q_s8_x2(q8);
q8 = q8.add(32);
let q4bytes = int8x16x2_t(
vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)),
vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)),
);
let p2 = vdotq_s32(q4bytes.0, q8bytes.0);
let p3 = vdotq_s32(q4bytes.1, q8bytes.1);
sumi2 += vaddvq_s32(vaddq_s32(p2, p3)) * scales[2 * j + 1] as i32;
}
sumf += d * (sumi1 + sumi2) as f32;
}
}
Ok(sumf)
}
#[inline(always)]
pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}")
}
let mut sumf = 0f32;
let mut utmp = [0u32; 4];
let mut aux = [0u32; 3];
const KMASK1: u32 = 0x03030303;
const KMASK2: u32 = 0x0f0f0f0f;
unsafe {
let m3b = vdupq_n_u8(0x3);
let m0 = vdupq_n_u8(1);
let m1 = vshlq_n_u8(m0, 1);
let m2 = vshlq_n_u8(m0, 2);
let m3 = vshlq_n_u8(m0, 3);
for (x, y) in xs.iter().zip(ys.iter()) {
let d = y.d * x.d.to_f32();
let mut q3 = x.qs.as_ptr();
let qh = x.hmask.as_ptr();
let mut q8 = y.qs.as_ptr();
let mut qhbits = vld1q_u8_x2(qh);
let mut isum = 0i32;
// Set up scales
LittleEndian::read_u32_into(&x.scales, &mut aux);
utmp[3] = ((aux[1] >> 4) & KMASK2) | (((aux[2] >> 6) & KMASK1) << 4);
utmp[2] = ((aux[0] >> 4) & KMASK2) | (((aux[2] >> 4) & KMASK1) << 4);
utmp[1] = (aux[1] & KMASK2) | (((aux[2] >> 2) & KMASK1) << 4);
utmp[0] = (aux[0] & KMASK2) | ((aux[2] & KMASK1) << 4);
let mut scale = utmp.as_mut_ptr() as *mut i8;
for j in 0..16 {
*scale.add(j) -= 32i8
}
for j in 0..QK_K / 128 {
let q3bits = vld1q_u8_x2(q3);
q3 = q3.add(32);
let q8bytes_1 = vld1q_s8_x4(q8);
q8 = q8.add(64);
let q8bytes_2 = vld1q_s8_x4(q8);
q8 = q8.add(64);
let q3h_0 = vshlq_n_u8(vbicq_u8(m0, qhbits.0), 2);
let q3h_1 = vshlq_n_u8(vbicq_u8(m0, qhbits.1), 2);
let q3h_2 = vshlq_n_u8(vbicq_u8(m1, qhbits.0), 1);
let q3h_3 = vshlq_n_u8(vbicq_u8(m1, qhbits.1), 1);
let q3bytes_0 = vsubq_s8(
vreinterpretq_s8_u8(vandq_u8(q3bits.0, m3b)),
vreinterpretq_s8_u8(q3h_0),
);
let q3bytes_1 = vsubq_s8(
vreinterpretq_s8_u8(vandq_u8(q3bits.1, m3b)),
vreinterpretq_s8_u8(q3h_1),
);
let q3bytes_2 = vsubq_s8(
vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 2), m3b)),
vreinterpretq_s8_u8(q3h_2),
);
let q3bytes_3 = vsubq_s8(
vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 2), m3b)),
vreinterpretq_s8_u8(q3h_3),
);
let p0 = vdotq_s32(q3bytes_0, q8bytes_1.0);
let p1 = vdotq_s32(q3bytes_1, q8bytes_1.1);
let p2 = vdotq_s32(q3bytes_2, q8bytes_1.2);
let p3 = vdotq_s32(q3bytes_3, q8bytes_1.3);
isum += vaddvq_s32(p0) * *scale as i32
+ vaddvq_s32(p1) * *scale.add(1) as i32
+ vaddvq_s32(p2) * *scale.add(2) as i32
+ vaddvq_s32(p3) * *scale.add(3) as i32;
scale = scale.add(4);
let q3h_0 = vbicq_u8(m2, qhbits.0);
let q3h_1 = vbicq_u8(m2, qhbits.1);
let q3h_2 = vshrq_n_u8(vbicq_u8(m3, qhbits.0), 1);
let q3h_3 = vshrq_n_u8(vbicq_u8(m3, qhbits.1), 1);
let q3bytes_0 = vsubq_s8(
vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 4), m3b)),
vreinterpretq_s8_u8(q3h_0),
);
let q3bytes_1 = vsubq_s8(
vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 4), m3b)),
vreinterpretq_s8_u8(q3h_1),
);
let q3bytes_2 = vsubq_s8(
vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 6), m3b)),
vreinterpretq_s8_u8(q3h_2),
);
let q3bytes_3 = vsubq_s8(
vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 6), m3b)),
vreinterpretq_s8_u8(q3h_3),
);
let p0 = vdotq_s32(q3bytes_0, q8bytes_2.0);
let p1 = vdotq_s32(q3bytes_1, q8bytes_2.1);
let p2 = vdotq_s32(q3bytes_2, q8bytes_2.2);
let p3 = vdotq_s32(q3bytes_3, q8bytes_2.3);
isum += vaddvq_s32(p0) * *scale as i32
+ vaddvq_s32(p1) * *scale.add(1) as i32
+ vaddvq_s32(p2) * *scale.add(2) as i32
+ vaddvq_s32(p3) * *scale.add(3) as i32;
scale = scale.add(4);
if j == 0 {
qhbits.0 = vshrq_n_u8(qhbits.0, 4);
qhbits.1 = vshrq_n_u8(qhbits.1, 4);
}
}
sumf += d * isum as f32;
}
}
Ok(sumf)
}
#[inline(always)]
pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
}
let mut sumf = 0f32;
let mut aux = [0u8; 16];
unsafe {
let m3 = vdupq_n_u8(0x3);
let m4 = vdupq_n_u8(0xF);
for (x, y) in xs.iter().zip(ys.iter()) {
let d = y.d * x.d.to_f32();
let dmin = -y.d * x.dmin.to_f32();
let mut q2 = x.qs.as_ptr();
let mut q8 = y.qs.as_ptr();
let sc = x.scales.as_ptr();
let mins_and_scales = vld1q_u8(sc);
let scales = vandq_u8(mins_and_scales, m4);
vst1q_u8(aux.as_mut_ptr(), scales);
let mins = vshrq_n_u8(mins_and_scales, 4);
let q8sums = vld1q_s16_x2(y.bsums.as_ptr());
let mins16 = int16x8x2_t(
vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))),
vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins))),
);
let s0 = vaddq_s32(
vmull_s16(vget_low_s16(mins16.0), vget_low_s16(q8sums.0)),
vmull_s16(vget_high_s16(mins16.0), vget_high_s16(q8sums.0)),
);
let s1 = vaddq_s32(
vmull_s16(vget_low_s16(mins16.1), vget_low_s16(q8sums.1)),
vmull_s16(vget_high_s16(mins16.1), vget_high_s16(q8sums.1)),
);
sumf += dmin * vaddvq_s32(vaddq_s32(s0, s1)) as f32;
let mut isum = 0i32;
let mut is = 0usize;
// TODO: dotprod
for _j in 0..QK_K / 128 {
let q2bits = vld1q_u8_x2(q2);
q2 = q2.add(32);
let q8bytes = vld1q_s8_x2(q8);
q8 = q8.add(32);
let mut q2bytes = int8x16x2_t(
vreinterpretq_s8_u8(vandq_u8(q2bits.0, m3)),
vreinterpretq_s8_u8(vandq_u8(q2bits.1, m3)),
);
isum += multiply_accum_with_scale(&aux, is, 0, q2bytes, q8bytes);
let q8bytes = vld1q_s8_x2(q8);
q8 = q8.add(32);
q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 2), m3));
q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 2), m3));
isum += multiply_accum_with_scale(&aux, is, 2, q2bytes, q8bytes);
let q8bytes = vld1q_s8_x2(q8);
q8 = q8.add(32);
q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 4), m3));
q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 4), m3));
isum += multiply_accum_with_scale(&aux, is, 4, q2bytes, q8bytes);
let q8bytes = vld1q_s8_x2(q8);
q8 = q8.add(32);
q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 6), m3));
q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 6), m3));
isum += multiply_accum_with_scale(&aux, is, 6, q2bytes, q8bytes);
is += 8;
}
sumf += d * isum as f32;
}
}
Ok(sumf)
}
#[inline(always)]
unsafe fn multiply_accum_with_scale(
aux: &[u8; 16],
is: usize,
index: usize,
q2bytes: int8x16x2_t,
q8bytes: int8x16x2_t,
) -> i32 {
let p1 = vdotq_s32(q2bytes.0, q8bytes.0);
let p2 = vdotq_s32(q2bytes.1, q8bytes.1);
vaddvq_s32(p1) * aux[is + index] as i32 + vaddvq_s32(p2) * aux[is + 1 + index] as i32
}

View File

@ -1,419 +0,0 @@
use super::k_quants::{BlockQ2K, BlockQ4K, BlockQ4_0, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K};
use crate::Result;
use byteorder::{ByteOrder, LittleEndian};
use half::f16;
use core::arch::wasm32::*;
#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
}
unsafe {
let mut acc = f32x4_splat(0.0f32);
for (x, y) in xs.iter().zip(ys.iter()) {
let x1234 = v128_load(x.qs.as_ptr() as *const v128);
let x12 = v128_and(x1234, u8x16_splat(0x0F));
let x12 = i8x16_sub(x12, i8x16_splat(8));
let x34 = u8x16_shr(x1234, 4);
let x34 = i8x16_sub(x34, i8x16_splat(8));
let x1 = i16x8_extend_low_i8x16(x12);
let y1 = i16x8_load_extend_i8x8(y.qs.as_ptr());
let sum_xy = i32x4_dot_i16x8(x1, y1);
let x2 = i16x8_extend_high_i8x16(x12);
let y2 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(8));
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x2, y2));
let x3 = i16x8_extend_low_i8x16(x34);
let y3 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(16));
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x3, y3));
let x4 = i16x8_extend_high_i8x16(x34);
let y4 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(24));
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x4, y4));
let sum_xy = f32x4_convert_i32x4(sum_xy);
// f32x4_relaxed_madd is nightly only.
let d = f32x4_splat(f16::to_f32(x.d) * f16::to_f32(y.d));
let scaled = f32x4_mul(sum_xy, d);
acc = f32x4_add(acc, scaled)
}
let res = f32x4_extract_lane::<0>(acc)
+ f32x4_extract_lane::<1>(acc)
+ f32x4_extract_lane::<2>(acc)
+ f32x4_extract_lane::<3>(acc);
Ok(res)
}
}
#[inline(always)]
pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
}
unsafe {
let mut acc = f32x4_splat(0.0f32);
for (x, y) in xs.iter().zip(ys.iter()) {
let x1 = i16x8_load_extend_i8x8(x.qs.as_ptr());
let y1 = i16x8_load_extend_i8x8(y.qs.as_ptr());
let sum_xy = i32x4_dot_i16x8(x1, y1);
let x2 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(8));
let y2 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(8));
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x2, y2));
let x3 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(16));
let y3 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(16));
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x3, y3));
let x4 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(24));
let y4 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(24));
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x4, y4));
let sum_xy = f32x4_convert_i32x4(sum_xy);
// f32x4_relaxed_madd is nightly only.
let d = f32x4_splat(f16::to_f32(x.d) * f16::to_f32(y.d));
let scaled = f32x4_mul(sum_xy, d);
acc = f32x4_add(acc, scaled)
}
let res = f32x4_extract_lane::<0>(acc)
+ f32x4_extract_lane::<1>(acc)
+ f32x4_extract_lane::<2>(acc)
+ f32x4_extract_lane::<3>(acc);
Ok(res)
}
}
#[inline(always)]
pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
}
unsafe {
let mut sumf = f32x4_splat(0f32);
for (x, y) in xs.iter().zip(ys.iter()) {
let mut q2: &[_] = &x.qs;
let mut q8: &[_] = &y.qs;
let sc = &x.scales;
let mut summs = i32x4_splat(0);
for i in (0..(QK_K / 16)).step_by(4) {
let bsums = i32x4_load_extend_i16x4(y.bsums.as_ptr().add(i));
let scales = i32x4_shr(
i32x4(
sc[i] as i32,
sc[i + 1] as i32,
sc[i + 2] as i32,
sc[i + 3] as i32,
),
4,
);
summs = i32x4_add(summs, i32x4_mul(bsums, scales))
}
let summs = f32x4_convert_i32x4(summs);
let dall = y.d * x.d.to_f32();
let dmin = y.d * x.dmin.to_f32();
let mut isum = i32x4_splat(0);
let mut is = 0;
for _ in 0..(QK_K / 128) {
let mut shift = 0;
for _ in 0..4 {
let d = (sc[is] & 0xF) as i32;
is += 1;
let mut isuml = i16x8_splat(0);
for l in (0..16).step_by(8) {
let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(l));
let q2 = i16x8_load_extend_u8x8(q2.as_ptr().add(l));
let q2 = v128_and(i16x8_shr(q2, shift), i16x8_splat(3));
isuml = i16x8_add(isuml, i16x8_mul(q2, q8))
}
let dd = i32x4_splat(d);
isum = i32x4_add(isum, i32x4_mul(i32x4_extend_low_i16x8(isuml), dd));
isum = i32x4_add(isum, i32x4_mul(i32x4_extend_high_i16x8(isuml), dd));
let d = (sc[is] & 0xF) as i32;
is += 1;
let mut isuml = i16x8_splat(0);
for l in (16..32).step_by(8) {
let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(l));
let q2 = i16x8_load_extend_u8x8(q2.as_ptr().add(l));
let q2 = v128_and(i16x8_shr(q2, shift), i16x8_splat(3));
isuml = i16x8_add(isuml, i16x8_mul(q2, q8))
}
let dd = i32x4_splat(d);
isum = i32x4_add(isum, i32x4_mul(i32x4_extend_low_i16x8(isuml), dd));
isum = i32x4_add(isum, i32x4_mul(i32x4_extend_high_i16x8(isuml), dd));
shift += 2;
// adjust the indexing
q8 = &q8[32..];
}
// adjust the indexing
q2 = &q2[32..];
}
let isum = f32x4_convert_i32x4(isum);
sumf = f32x4_add(
sumf,
f32x4_sub(
f32x4_mul(isum, f32x4_splat(dall)),
f32x4_mul(summs, f32x4_splat(dmin)),
),
);
}
let sumf = f32x4_extract_lane::<0>(sumf)
+ f32x4_extract_lane::<1>(sumf)
+ f32x4_extract_lane::<2>(sumf)
+ f32x4_extract_lane::<3>(sumf);
Ok(sumf)
}
}
#[inline(always)]
pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
}
const KMASK1: u32 = 0x3f3f3f3f;
const KMASK2: u32 = 0x0f0f0f0f;
const KMASK3: u32 = 0x03030303;
let mut utmp: [u32; 4] = [0; 4];
let mut scales: [u8; 8] = [0; 8];
let mut mins: [u8; 8] = [0; 8];
let mut aux8: [u8; QK_K] = [0; QK_K];
let mut sums = f32x4_splat(0f32);
unsafe {
for (y, x) in ys.iter().zip(xs.iter()) {
let q4 = &x.qs;
let q8 = &y.qs;
for j in 0..QK_K / 64 {
let q4_1 = v128_load(q4.as_ptr().add(32 * j) as *const v128);
let q4_2 = v128_load(q4.as_ptr().add(32 * j + 16) as *const v128);
v128_store(
aux8.as_mut_ptr().add(64 * j) as *mut v128,
v128_and(q4_1, u8x16_splat(0x0F)),
);
v128_store(
aux8.as_mut_ptr().add(64 * j + 16) as *mut v128,
v128_and(q4_2, u8x16_splat(0x0F)),
);
v128_store(
aux8.as_mut_ptr().add(64 * j + 32) as *mut v128,
u8x16_shr(q4_1, 4),
);
v128_store(
aux8.as_mut_ptr().add(64 * j + 48) as *mut v128,
u8x16_shr(q4_2, 4),
);
}
LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
let uaux = utmp[1] & KMASK1;
utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
utmp[2] = uaux;
utmp[0] &= KMASK1;
//extract scales and mins
LittleEndian::write_u32_into(&utmp[0..2], &mut scales);
LittleEndian::write_u32_into(&utmp[2..4], &mut mins);
let mut sumi = i32x4_splat(0);
for j in (0..QK_K / 16).step_by(4) {
let bsums = i32x4_load_extend_i16x4(y.bsums.as_ptr().add(j));
let (m1, m2) = (mins[j / 2] as i32, mins[j / 2 + 1] as i32);
let mins = i32x4(m1, m1, m2, m2);
sumi = i32x4_add(sumi, i32x4_mul(bsums, mins));
}
let mut aux32 = i32x4_splat(0i32);
for (scale_i, scale) in scales.iter().enumerate() {
let scale = i32x4_splat(*scale as i32);
for j in 0..4 {
let i = 32 * scale_i + 8 * j;
let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(i));
let aux8 = i16x8_load_extend_u8x8(aux8.as_ptr().add(i));
let aux16 = i16x8_mul(q8, aux8);
aux32 = i32x4_add(aux32, i32x4_mul(scale, i32x4_extend_low_i16x8(aux16)));
aux32 = i32x4_add(aux32, i32x4_mul(scale, i32x4_extend_high_i16x8(aux16)));
}
}
let aux32 = f32x4_convert_i32x4(aux32);
let d = f32x4_splat(x.d.to_f32() * y.d);
sums = f32x4_add(sums, f32x4_mul(aux32, d));
let dmin = x.dmin.to_f32() * y.d;
let dmin = f32x4_splat(dmin);
let sumi = f32x4_convert_i32x4(sumi);
sums = f32x4_sub(sums, f32x4_mul(sumi, dmin));
}
let sums = f32x4_extract_lane::<0>(sums)
+ f32x4_extract_lane::<1>(sums)
+ f32x4_extract_lane::<2>(sums)
+ f32x4_extract_lane::<3>(sums);
Ok(sums)
}
}
#[inline(always)]
pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}")
}
let mut aux8 = [0i8; QK_K];
unsafe {
let mut sums = f32x4_splat(0f32);
for (x, y) in xs.iter().zip(ys.iter()) {
let q4 = &x.ql;
let qh = &x.qh;
let q8 = &y.qs;
let mut aux32 = f32x4_splat(0f32);
for j in (0..QK_K).step_by(128) {
let aux8 = aux8.as_mut_ptr().add(j);
let q4 = &q4.as_ptr().add(j / 2);
let qh = &qh.as_ptr().add(j / 4);
for l in (0..32).step_by(16) {
// aux8[l] = (((q4[l] & 0xF) | ((qh[l] & 3) << 4)) as i32 - 32) as i8;
let a8 = v128_or(
v128_and(v128_load(q4.add(l) as *const v128), u8x16_splat(0xF)),
u8x16_shl(
v128_and(v128_load(qh.add(l) as *const v128), u8x16_splat(3)),
4,
),
);
let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
v128_store(
aux8.add(l) as *mut v128,
i8x16_narrow_i16x8(a8_low, a8_high),
);
// aux8[l + 32] =
// (((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i32 - 32) as i8;
let a8 = v128_or(
v128_and(v128_load(q4.add(l + 32) as *const v128), u8x16_splat(0xF)),
u8x16_shl(
v128_and(
u8x16_shr(v128_load(qh.add(l) as *const v128), 2),
u8x16_splat(3),
),
4,
),
);
let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
v128_store(
aux8.add(l + 32) as *mut v128,
i8x16_narrow_i16x8(a8_low, a8_high),
);
// aux8[l + 64] = (((q4[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i32 - 32) as i8;
let a8 = v128_or(
u8x16_shr(v128_load(q4.add(l) as *const v128), 4),
u8x16_shl(
v128_and(
u8x16_shr(v128_load(qh.add(l) as *const v128), 4),
u8x16_splat(3),
),
4,
),
);
let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
v128_store(
aux8.add(l + 64) as *mut v128,
i8x16_narrow_i16x8(a8_low, a8_high),
);
// aux8[l + 96] =
// (((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i32 - 32) as i8;
let a8 = v128_or(
u8x16_shr(v128_load(q4.add(l + 32) as *const v128), 4),
u8x16_shl(
v128_and(
u8x16_shr(v128_load(qh.add(l) as *const v128), 6),
u8x16_splat(3),
),
4,
),
);
let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
v128_store(
aux8.add(l + 96) as *mut v128,
i8x16_narrow_i16x8(a8_low, a8_high),
);
}
}
for (j, &scale) in x.scales.iter().enumerate() {
let scale = f32x4_splat(scale as f32);
for offset in [0, 8] {
let aux16 = i16x8_mul(
i16x8_load_extend_i8x8(q8.as_ptr().add(16 * j + offset)),
i16x8_load_extend_i8x8(aux8.as_ptr().add(16 * j + offset)),
);
aux32 = f32x4_add(
aux32,
f32x4_mul(f32x4_convert_i32x4(i32x4_extend_low_i16x8(aux16)), scale),
);
aux32 = f32x4_add(
aux32,
f32x4_mul(f32x4_convert_i32x4(i32x4_extend_high_i16x8(aux16)), scale),
);
}
}
let d = f32x4_splat(x.d.to_f32() * y.d);
sums = f32x4_add(sums, f32x4_mul(aux32, d));
}
let sums = f32x4_extract_lane::<0>(sums)
+ f32x4_extract_lane::<1>(sums)
+ f32x4_extract_lane::<2>(sums)
+ f32x4_extract_lane::<3>(sums);
Ok(sums)
}
}
#[inline(always)]
pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
let qk = QK_K;
if n % QK_K != 0 {
crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
}
unsafe {
let mut acc = f32x4_splat(0.0f32);
for (xs, ys) in xs.iter().zip(ys.iter()) {
let x_qs = xs.qs.as_ptr();
let y_qs = ys.qs.as_ptr();
let mut sumi = i32x4_splat(0);
for j in (0..QK_K).step_by(8) {
let xs = i16x8_load_extend_i8x8(x_qs.add(j));
let ys = i16x8_load_extend_i8x8(y_qs.add(j));
let sum_xy = i32x4_dot_i16x8(xs, ys);
sumi = i32x4_add(sumi, sum_xy)
}
let d = f32x4_splat(xs.d * ys.d);
acc = f32x4_add(acc, f32x4_mul(f32x4_convert_i32x4(sumi), d))
}
let res = f32x4_extract_lane::<0>(acc)
+ f32x4_extract_lane::<1>(acc)
+ f32x4_extract_lane::<2>(acc)
+ f32x4_extract_lane::<3>(acc);
Ok(res)
}
}

View File

@ -1,326 +0,0 @@
use crate::Result;
pub(super) fn nearest_int(v: f32) -> i32 {
v.round() as i32
}
/// Validates that the input and output are the right size and returns an iterator which maps each
/// input region `xs` to its corresponding output block in `ys`. Each output region is guaranteed
/// to be `T::BLCK_SIZE` long.
pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>(
xs: &'b [f32],
ys: &'a mut [T],
) -> Result<Vec<(&'a mut T, &'b [f32])>> {
let block_size = T::BLCK_SIZE;
let dtype = T::DTYPE;
let expected_blocks = xs.len() / block_size;
let actual_blocks = ys.len();
// Validate that the input is the right size
if expected_blocks != actual_blocks {
crate::bail!("quantize {dtype:?}: expected {expected_blocks} blocks but only {actual_blocks} were provided!")
}
Ok(ys.iter_mut().zip(xs.chunks_exact(block_size)).collect())
}
/// Validates that the input and output are the right size and returns an iterator which maps each
/// input block `xs` to its corresponding output region in `ys`. Each output region is guaranteed
/// to be `T::BLCK_SIZE` long.
pub(super) fn group_for_dequantization<'a, 'b, T: super::k_quants::GgmlType>(
xs: &'a [T],
ys: &'b mut [f32],
) -> Result<Vec<(&'a T, &'b mut [f32])>> {
let block_size = T::BLCK_SIZE;
let dtype = T::DTYPE;
let actual_output_len = ys.len();
let expected_output_len = xs.len() * block_size;
// Validate that the output is the right size
if expected_output_len != actual_output_len {
crate::bail!("dequantize {dtype:?}: ys (len = {actual_output_len}) does not match the expected length of {expected_output_len}!")
}
// Zip the blocks and outputs together
Ok(xs.iter().zip(ys.chunks_exact_mut(block_size)).collect())
}
pub(super) fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) {
if j < 4 {
let d = q[j] & 63;
let m = q[j + 4] & 63;
(d, m)
} else {
let d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
let m = (q[j + 4] >> 4) | ((q[j] >> 6) << 4);
(d, m)
}
}
pub(super) unsafe fn make_qx_quants(
n: usize,
nmax: i32,
x: *const f32,
ls: *mut i8,
rmse_type: i32,
) -> f32 {
let mut max = 0f32;
let mut amax = 0f32;
for i in 0..n {
let x = *x.add(i);
let ax = x.abs();
if ax > amax {
amax = ax;
max = x;
}
}
if amax == 0. {
// all zero
for i in 0..n {
*ls.add(i) = 0;
}
return 0.;
}
let mut iscale = -(nmax as f32) / max;
if rmse_type == 0 {
for i in 0..n {
let x = *x.add(i);
let l = nearest_int(iscale * x);
*ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8;
}
return 1.0 / iscale;
}
let weight_type = rmse_type % 2;
let mut sumlx = 0f32;
let mut suml2 = 0f32;
for i in 0..n {
let x = *x.add(i);
let l = nearest_int(iscale * x);
let l = l.clamp(-nmax, nmax - 1);
*ls.add(i) = (l + nmax) as i8;
let w = if weight_type == 1 { x * x } else { 1.0 };
let l = l as f32;
sumlx += w * x * l;
suml2 += w * l * l;
}
let mut scale = sumlx / suml2;
let mut best = scale * sumlx;
for _itry in 0..3 {
let iscale = 1.0 / scale;
let mut slx = 0f32;
let mut sl2 = 0f32;
let mut changed = false;
for i in 0..n {
let x = *x.add(i);
let l = nearest_int(iscale * x);
let l = l.clamp(-nmax, nmax - 1);
if l + nmax != *ls.add(i) as i32 {
changed = true;
}
let w = if weight_type == 1 { x * x } else { 1f32 };
let l = l as f32;
slx += w * x * l;
sl2 += w * l * l;
}
if !changed || sl2 == 0.0 || slx * slx <= best * sl2 {
break;
}
for i in 0..n {
let x = *x.add(i);
let l = nearest_int(iscale * x);
*ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8;
}
sumlx = slx;
suml2 = sl2;
scale = sumlx / suml2;
best = scale * sumlx;
}
for _itry in 0..5 {
let mut n_changed = 0;
for i in 0..n {
let x = *x.add(i);
let w = if weight_type == 1 { x * x } else { 1. };
let l = *ls.add(i) as i32 - nmax;
let mut slx = sumlx - w * x * l as f32;
if slx > 0. {
let mut sl2 = suml2 - w * l as f32 * l as f32;
let new_l = nearest_int(x * sl2 / slx);
let new_l = new_l.clamp(-nmax, nmax - 1);
if new_l != l {
slx += w * x * new_l as f32;
sl2 += w * new_l as f32 * new_l as f32;
if sl2 > 0. && slx * slx * suml2 > sumlx * sumlx * sl2 {
*ls.add(i) = (nmax + new_l) as i8;
sumlx = slx;
suml2 = sl2;
scale = sumlx / suml2;
best = scale * sumlx;
n_changed += 1;
}
}
}
}
if n_changed == 0 {
break;
}
}
if rmse_type < 3 {
return scale;
}
for is in -4..4 {
if is == 0 {
continue;
}
iscale = -(nmax as f32 + 0.1f32 * is as f32) / max;
let mut sumlx = 0.;
let mut suml2 = 0.;
for i in 0..n {
let x = *x.add(i);
let l = nearest_int(iscale * x);
let l = l.clamp(-nmax, nmax - 1);
let w = if weight_type == 1 { x * x } else { 1. };
let l = l as f32;
sumlx += w * x * l;
suml2 += w * l * l;
}
if suml2 > 0. && sumlx * sumlx > best * suml2 {
for i in 0..n {
let x = *x.add(i);
let l = nearest_int(iscale * x);
*ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8;
}
scale = sumlx / suml2;
best = scale * sumlx;
}
}
scale
}
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L224
pub(super) fn make_qkx1_quants(nmax: i32, ntry: usize, x: &[f32]) -> (f32, f32) {
let n = x.len();
let mut l = vec![0; n];
// Get min/max
let min = *x
.iter()
.take(n)
.min_by(|a, b| a.total_cmp(b))
.unwrap_or(&x[0]);
let max = *x.iter().max_by(|a, b| a.total_cmp(b)).unwrap_or(&x[0]);
// If min == max, all values are the same => nothing to do here
if max == min {
return (0.0, 0.0);
}
// Ensure min <= 0.0
let mut min = min.min(0.);
// Compute scale and inverse scale
let mut iscale = nmax as f32 / (max - min);
let mut scale = 1.0 / iscale;
for _ in 0..ntry {
let mut sumlx = 0.0;
let mut suml2 = 0;
let mut did_change = false;
for (i, value) in x.iter().enumerate().take(n) {
let li = nearest_int(iscale * (value - min)).clamp(0, nmax);
let clamped_li = li as u8;
if clamped_li != l[i] {
l[i] = clamped_li;
did_change = true;
}
sumlx += (value - min) * li as f32;
suml2 += li * li;
}
scale = sumlx / suml2 as f32;
let sum: f32 = x
.iter()
.take(n)
.zip(l.iter().take(n))
.map(|(xi, &li)| xi - scale * li as f32)
.sum();
min = sum / n as f32;
if min > 0.0 {
min = 0.0;
}
iscale = 1.0 / scale;
if !did_change {
break;
}
}
(scale, -min)
}
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L165
pub(super) fn make_q3_quants(x: &[f32], nmax: i32, do_rmse: bool) -> f32 {
let n = x.len();
let mut l = vec![0i8; n];
let mut max = 0.0;
let mut amax = 0.0;
for &xi in x.iter().take(n) {
let ax = xi.abs();
if ax > amax {
amax = ax;
max = xi;
}
}
if amax == 0.0 {
return 0.0;
}
let iscale = -(nmax as f32) / max;
if do_rmse {
let mut sumlx = 0.0;
let mut suml2 = 0.0;
for i in 0..n {
let li = (iscale * x[i]).round() as i32;
let li = li.clamp(-nmax, nmax - 1);
l[i] = li as i8;
let w = x[i] * x[i];
sumlx += w * x[i] * li as f32;
suml2 += w * (li * li) as f32;
}
for _ in 0..5 {
let mut n_changed = 0;
for i in 0..n {
let w = x[i] * x[i];
let mut slx = sumlx - w * x[i] * l[i] as f32;
if slx > 0.0 {
let mut sl2 = suml2 - w * (l[i] as i32 * l[i] as i32) as f32;
let mut new_l = (x[i] * sl2 / slx).round() as i32;
new_l = new_l.clamp(-nmax, nmax - 1);
if new_l != l[i] as i32 {
slx += w * x[i] * new_l as f32;
sl2 += w * (new_l * new_l) as f32;
if sl2 > 0.0 && slx * slx * suml2 > sumlx * sumlx * sl2 {
l[i] = new_l as i8;
sumlx = slx;
suml2 = sl2;
n_changed += 1;
}
}
}
}
if n_changed == 0 {
break;
}
}
for li in l.iter_mut() {
*li += nmax as i8;
}
return sumlx / suml2;
}
for i in 0..n {
let li = (iscale * x[i]).round() as i32;
l[i] = (li.clamp(-nmax, nmax - 1) + nmax) as i8;
}
1.0 / iscale
}

View File

@ -10,7 +10,6 @@ impl From<DType> for st::Dtype {
match value {
DType::U8 => st::Dtype::U8,
DType::U32 => st::Dtype::U32,
DType::I64 => st::Dtype::I64,
DType::BF16 => st::Dtype::BF16,
DType::F16 => st::Dtype::F16,
DType::F32 => st::Dtype::F32,
@ -25,7 +24,6 @@ impl TryFrom<st::Dtype> for DType {
match value {
st::Dtype::U8 => Ok(DType::U8),
st::Dtype::U32 => Ok(DType::U32),
st::Dtype::I64 => Ok(DType::I64),
st::Dtype::BF16 => Ok(DType::BF16),
st::Dtype::F16 => Ok(DType::F16),
st::Dtype::F32 => Ok(DType::F32),
@ -78,7 +76,11 @@ impl st::View for &Tensor {
}
impl Tensor {
pub fn save_safetensors<P: AsRef<Path>>(&self, name: &str, filename: P) -> Result<()> {
pub fn save_safetensors<P: AsRef<std::path::Path>>(
&self,
name: &str,
filename: P,
) -> Result<()> {
let data = [(name, self.clone())];
Ok(st::serialize_to_file(data, &None, filename.as_ref())?)
}
@ -187,7 +189,6 @@ impl Tensor {
match dtype {
DType::U8 => convert_slice::<u8>(data, shape, device),
DType::U32 => convert_slice::<u32>(data, shape, device),
DType::I64 => convert_slice::<i64>(data, shape, device),
DType::BF16 => convert_slice::<half::bf16>(data, shape, device),
DType::F16 => convert_slice::<half::f16>(data, shape, device),
DType::F32 => convert_slice::<f32>(data, shape, device),
@ -204,15 +205,24 @@ fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
convert_with_cast_::<u16, u32, _>(view, device, conv)
}
st::Dtype::U32 => convert_::<u32>(view, device),
st::Dtype::I32 => {
let conv = |x| Ok(i64::from(x));
convert_with_cast_::<i32, i64, _>(view, device, conv)
}
st::Dtype::I64 => convert_::<i64>(view, device),
st::Dtype::BF16 => convert_::<half::bf16>(view, device),
st::Dtype::F16 => convert_::<half::f16>(view, device),
st::Dtype::F32 => convert_::<f32>(view, device),
st::Dtype::F64 => convert_::<f64>(view, device),
st::Dtype::I32 => {
let conv = |x| {
u32::try_from(x)
.map_err(|_| Error::Msg(format!("out of bounds value for u32: {x}")))
};
convert_with_cast_::<i32, u32, _>(view, device, conv)
}
st::Dtype::I64 => {
let conv = |x| {
u32::try_from(x)
.map_err(|_| Error::Msg(format!("out of bounds value for u32: {x}")))
};
convert_with_cast_::<i64, u32, _>(view, device, conv)
}
dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
}
}
@ -223,7 +233,6 @@ fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
match tensor.dtype() {
DType::U8 => Ok(convert_back_::<u8>(tensor.to_vec1()?)),
DType::U32 => Ok(convert_back_::<u32>(tensor.to_vec1()?)),
DType::I64 => Ok(convert_back_::<i64>(tensor.to_vec1()?)),
DType::F16 => Ok(convert_back_::<half::f16>(tensor.to_vec1()?)),
DType::BF16 => Ok(convert_back_::<half::bf16>(tensor.to_vec1()?)),
DType::F32 => Ok(convert_back_::<f32>(tensor.to_vec1()?)),
@ -251,134 +260,6 @@ pub fn save<K: AsRef<str> + Ord + std::fmt::Display, P: AsRef<Path>>(
Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?)
}
#[derive(yoke::Yokeable)]
struct SafeTensors_<'a>(SafeTensors<'a>);
pub struct MmapedSafetensors {
safetensors: Vec<yoke::Yoke<SafeTensors_<'static>, memmap2::Mmap>>,
routing: Option<HashMap<String, usize>>,
}
impl MmapedSafetensors {
/// Creates a wrapper around a memory mapped file and deserialize the safetensors header.
///
/// # Safety
///
/// The unsafe is inherited from [`memmap2::MmapOptions`].
pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
let p = p.as_ref();
let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
let file = memmap2::MmapOptions::new()
.map(&file)
.map_err(|e| Error::from(e).with_path(p))?;
let safetensors = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
file,
|data: &[u8]| {
let st = safetensors::SafeTensors::deserialize(data)
.map_err(|e| Error::from(e).with_path(p))?;
Ok::<_, Error>(SafeTensors_(st))
},
)?;
Ok(Self {
safetensors: vec![safetensors],
routing: None,
})
}
/// Creates a wrapper around multiple memory mapped file and deserialize the safetensors headers.
///
/// If a tensor name appears in multiple files, the last entry is returned.
///
/// # Safety
///
/// The unsafe is inherited from [`memmap2::MmapOptions`].
pub unsafe fn multi<P: AsRef<Path>>(paths: &[P]) -> Result<Self> {
let mut routing = HashMap::new();
let mut safetensors = vec![];
for (index, p) in paths.iter().enumerate() {
let p = p.as_ref();
let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
let file = memmap2::MmapOptions::new()
.map(&file)
.map_err(|e| Error::from(e).with_path(p))?;
let data = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
file,
|data: &[u8]| {
let st = safetensors::SafeTensors::deserialize(data)
.map_err(|e| Error::from(e).with_path(p))?;
Ok::<_, Error>(SafeTensors_(st))
},
)?;
for k in data.get().0.names() {
routing.insert(k.to_string(), index);
}
safetensors.push(data)
}
Ok(Self {
safetensors,
routing: Some(routing),
})
}
pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
self.get(name)?.load(dev)
}
pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
let mut tensors = vec![];
for safetensors in self.safetensors.iter() {
tensors.push(safetensors.get().0.tensors())
}
tensors.into_iter().flatten().collect()
}
pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
let index = match &self.routing {
None => 0,
Some(routing) => {
let index = routing.get(name).ok_or_else(|| {
Error::CannotFindTensor {
path: name.to_string(),
}
.bt()
})?;
*index
}
};
Ok(self.safetensors[index].get().0.tensor(name)?)
}
}
pub struct BufferedSafetensors {
safetensors: yoke::Yoke<SafeTensors_<'static>, Vec<u8>>,
}
impl BufferedSafetensors {
/// Creates a wrapper around a binary buffer and deserialize the safetensors header.
pub fn new(buffer: Vec<u8>) -> Result<Self> {
let safetensors = yoke::Yoke::<SafeTensors_<'static>, Vec<u8>>::try_attach_to_cart(
buffer,
|data: &[u8]| {
let st = safetensors::SafeTensors::deserialize(data)?;
Ok::<_, Error>(SafeTensors_(st))
},
)?;
Ok(Self { safetensors })
}
pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
self.get(name)?.load(dev)
}
pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
self.safetensors.get().0.tensors()
}
pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
Ok(self.safetensors.get().0.tensor(name)?)
}
}
pub struct MmapedFile {
path: std::path::PathBuf,
inner: memmap2::Mmap,
@ -391,7 +272,7 @@ impl MmapedFile {
/// # Safety
///
/// The unsafe is inherited from [`memmap2::MmapOptions`].
pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
pub unsafe fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
let p = p.as_ref();
let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
let inner = memmap2::MmapOptions::new()

View File

@ -1,23 +0,0 @@
use crate::{Result, Tensor, WithDType};
pub enum TensorScalar {
Tensor(Tensor),
Scalar(Tensor),
}
pub trait TensorOrScalar {
fn to_tensor_scalar(self) -> Result<TensorScalar>;
}
impl TensorOrScalar for &Tensor {
fn to_tensor_scalar(self) -> Result<TensorScalar> {
Ok(TensorScalar::Tensor(self.clone()))
}
}
impl<T: WithDType> TensorOrScalar for T {
fn to_tensor_scalar(self) -> Result<TensorScalar> {
let scalar = Tensor::new(self, &crate::Device::Cpu)?;
Ok(TensorScalar::Scalar(scalar))
}
}

View File

@ -1,5 +1,3 @@
//! The shape of a tensor is a tuple with the size of each of its dimensions.
#![allow(clippy::redundant_closure_call)]
use crate::{Error, Result};
#[derive(Clone, PartialEq, Eq)]
@ -73,14 +71,6 @@ impl From<(usize, usize, usize, usize, usize)> for Shape {
}
}
impl From<(usize, usize, usize, usize, usize, usize)> for Shape {
fn from(d123456: (usize, usize, usize, usize, usize, usize)) -> Self {
Self(vec![
d123456.0, d123456.1, d123456.2, d123456.3, d123456.4, d123456.5,
])
}
}
impl From<Vec<usize>> for Shape {
fn from(dims: Vec<usize>) -> Self {
Self(dims)
@ -128,7 +118,6 @@ impl Shape {
Self(dims.to_vec())
}
/// The rank is the number of dimensions, 0 for a scalar value, 1 for a vector, etc.
pub fn rank(&self) -> usize {
self.0.len()
}
@ -137,12 +126,10 @@ impl Shape {
self.0
}
/// The dimensions as a slice of `usize`.
pub fn dims(&self) -> &[usize] {
&self.0
}
/// The total number of elements, this is the product of all dimension sizes.
pub fn elem_count(&self) -> usize {
self.0.iter().product()
}
@ -194,75 +181,10 @@ impl Shape {
true
}
/// Modifies the shape by adding a list of additional dimensions at the end of the existing
/// dimensions.
pub fn extend(mut self, additional_dims: &[usize]) -> Self {
self.0.extend(additional_dims);
self
}
/// Check whether the two shapes are compatible for broadcast, and if it is the case return the
/// broadcasted shape. This is to be used for binary pointwise ops.
pub fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
let lhs = self;
let lhs_dims = lhs.dims();
let rhs_dims = rhs.dims();
let lhs_ndims = lhs_dims.len();
let rhs_ndims = rhs_dims.len();
let bcast_ndims = usize::max(lhs_ndims, rhs_ndims);
let mut bcast_dims = vec![0; bcast_ndims];
for (idx, bcast_value) in bcast_dims.iter_mut().enumerate() {
let rev_idx = bcast_ndims - idx;
let l_value = if lhs_ndims < rev_idx {
1
} else {
lhs_dims[lhs_ndims - rev_idx]
};
let r_value = if rhs_ndims < rev_idx {
1
} else {
rhs_dims[rhs_ndims - rev_idx]
};
*bcast_value = if l_value == r_value {
l_value
} else if l_value == 1 {
r_value
} else if r_value == 1 {
l_value
} else {
Err(Error::ShapeMismatchBinaryOp {
lhs: lhs.clone(),
rhs: rhs.clone(),
op,
}
.bt())?
}
}
Ok(Shape::from(bcast_dims))
}
pub(crate) fn broadcast_shape_matmul(&self, rhs: &Self) -> Result<(Shape, Shape)> {
let lhs = self;
let lhs_dims = lhs.dims();
let rhs_dims = rhs.dims();
if lhs_dims.len() < 2 || rhs_dims.len() < 2 {
crate::bail!("only 2d matrixes are supported {lhs:?} {rhs:?}")
}
let (m, lhs_k) = (lhs_dims[lhs_dims.len() - 2], lhs_dims[lhs_dims.len() - 1]);
let (rhs_k, n) = (rhs_dims[rhs_dims.len() - 2], rhs_dims[rhs_dims.len() - 1]);
if lhs_k != rhs_k {
crate::bail!("different inner dimensions in broadcast matmul {lhs:?} {rhs:?}")
}
let lhs_b = Self::from(&lhs_dims[..lhs_dims.len() - 2]);
let rhs_b = Self::from(&rhs_dims[..rhs_dims.len() - 2]);
let bcast = lhs_b.broadcast_shape_binary_op(&rhs_b, "broadcast_matmul")?;
let bcast_dims = bcast.dims();
let bcast_lhs = [bcast_dims, &[m, lhs_k]].concat();
let bcast_rhs = [bcast_dims, &[rhs_k, n]].concat();
Ok((Shape::from(bcast_lhs), Shape::from(bcast_rhs)))
}
}
pub trait Dim {
@ -423,39 +345,6 @@ impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
}
}
impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim> Dims for (D1, D2, D3, D4) {
fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
let d0 = self.0.to_index(shape, op)?;
let d1 = self.1.to_index(shape, op)?;
let d2 = self.2.to_index(shape, op)?;
let d3 = self.3.to_index(shape, op)?;
Ok(vec![d0, d1, d2, d3])
}
}
impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim> Dims for (D1, D2, D3, D4, D5) {
fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
let d0 = self.0.to_index(shape, op)?;
let d1 = self.1.to_index(shape, op)?;
let d2 = self.2.to_index(shape, op)?;
let d3 = self.3.to_index(shape, op)?;
let d4 = self.4.to_index(shape, op)?;
Ok(vec![d0, d1, d2, d3, d4])
}
}
impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim, D6: Dim> Dims for (D1, D2, D3, D4, D5, D6) {
fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
let d0 = self.0.to_index(shape, op)?;
let d1 = self.1.to_index(shape, op)?;
let d2 = self.2.to_index(shape, op)?;
let d3 = self.3.to_index(shape, op)?;
let d4 = self.4.to_index(shape, op)?;
let d5 = self.5.to_index(shape, op)?;
Ok(vec![d0, d1, d2, d3, d4, d5])
}
}
extract_dims!(dims0, 0, |_: &[usize]| (), ());
extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
@ -478,139 +367,6 @@ extract_dims!(
(usize, usize, usize, usize, usize)
);
pub trait ShapeWithOneHole {
fn into_shape(self, el_count: usize) -> Result<Shape>;
}
impl<S: Into<Shape>> ShapeWithOneHole for S {
fn into_shape(self, _el_count: usize) -> Result<Shape> {
Ok(self.into())
}
}
impl ShapeWithOneHole for ((),) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
Ok(el_count.into())
}
}
fn hole_size(el_count: usize, prod_d: usize, s: &dyn std::fmt::Debug) -> Result<usize> {
if prod_d == 0 {
crate::bail!("cannot reshape tensor of {el_count} elements to {s:?}")
}
if el_count % prod_d != 0 {
crate::bail!("cannot reshape tensor with {el_count} elements to {s:?}")
}
Ok(el_count / prod_d)
}
impl ShapeWithOneHole for ((), usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let ((), d1) = self;
Ok((hole_size(el_count, d1, &self)?, d1).into())
}
}
impl ShapeWithOneHole for (usize, ()) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, ()) = self;
Ok((d1, hole_size(el_count, d1, &self)?).into())
}
}
impl ShapeWithOneHole for ((), usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let ((), d1, d2) = self;
Ok((hole_size(el_count, d1 * d2, &self)?, d1, d2).into())
}
}
impl ShapeWithOneHole for (usize, (), usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, (), d2) = self;
Ok((d1, hole_size(el_count, d1 * d2, &self)?, d2).into())
}
}
impl ShapeWithOneHole for (usize, usize, ()) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, ()) = self;
Ok((d1, d2, hole_size(el_count, d1 * d2, &self)?).into())
}
}
impl ShapeWithOneHole for ((), usize, usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let ((), d1, d2, d3) = self;
let d = hole_size(el_count, d1 * d2 * d3, &self)?;
Ok((d, d1, d2, d3).into())
}
}
impl ShapeWithOneHole for (usize, (), usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, (), d2, d3) = self;
let d = hole_size(el_count, d1 * d2 * d3, &self)?;
Ok((d1, d, d2, d3).into())
}
}
impl ShapeWithOneHole for (usize, usize, (), usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, (), d3) = self;
let d = hole_size(el_count, d1 * d2 * d3, &self)?;
Ok((d1, d2, d, d3).into())
}
}
impl ShapeWithOneHole for (usize, usize, usize, ()) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, d3, ()) = self;
let d = hole_size(el_count, d1 * d2 * d3, &self)?;
Ok((d1, d2, d3, d).into())
}
}
impl ShapeWithOneHole for ((), usize, usize, usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let ((), d1, d2, d3, d4) = self;
let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
Ok((d, d1, d2, d3, d4).into())
}
}
impl ShapeWithOneHole for (usize, (), usize, usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, (), d2, d3, d4) = self;
let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
Ok((d1, d, d2, d3, d4).into())
}
}
impl ShapeWithOneHole for (usize, usize, (), usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, (), d3, d4) = self;
let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
Ok((d1, d2, d, d3, d4).into())
}
}
impl ShapeWithOneHole for (usize, usize, usize, (), usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, d3, (), d4) = self;
let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
Ok((d1, d2, d3, d, d4).into())
}
}
impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, d3, d4, ()) = self;
let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
Ok((d1, d2, d3, d4, d).into())
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@ -1,6 +1,6 @@
use crate::backend::BackendStorage;
use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
// We do not want to implement Clone on Storage as cloning may fail because of
// out of memory. Instead try_clone should be used.
@ -8,7 +8,6 @@ use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage,
pub enum Storage {
Cpu(CpuStorage),
Cuda(CudaStorage),
Metal(MetalStorage),
}
impl Storage {
@ -19,10 +18,6 @@ impl Storage {
let storage = storage.try_clone(layout)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.try_clone(layout)?;
Ok(Self::Metal(storage))
}
}
}
@ -30,7 +25,6 @@ impl Storage {
match self {
Self::Cpu(_) => Device::Cpu,
Self::Cuda(storage) => Device::Cuda(storage.device().clone()),
Self::Metal(storage) => Device::Metal(storage.device().clone()),
}
}
@ -38,7 +32,6 @@ impl Storage {
match self {
Self::Cpu(storage) => storage.dtype(),
Self::Cuda(storage) => storage.dtype(),
Self::Metal(storage) => storage.dtype(),
}
}
@ -72,27 +65,6 @@ impl Storage {
let storage = storage.affine(layout, mul, add)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.affine(layout, mul, add)?;
Ok(Self::Metal(storage))
}
}
}
pub(crate) fn powf(&self, layout: &Layout, alpha: f64) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
let storage = storage.powf(layout, alpha)?;
Ok(Self::Cpu(storage))
}
Self::Cuda(storage) => {
let storage = storage.powf(layout, alpha)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.powf(layout, alpha)?;
Ok(Self::Metal(storage))
}
}
}
@ -106,10 +78,6 @@ impl Storage {
let storage = storage.elu(layout, alpha)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.elu(layout, alpha)?;
Ok(Self::Metal(storage))
}
}
}
@ -131,10 +99,6 @@ impl Storage {
let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
Ok(Self::Cuda(storage))
}
(Self::Metal(lhs), Self::Metal(rhs)) => {
let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
Ok(Self::Metal(storage))
}
(lhs, rhs) => {
// Should not happen because of the same device check above but we're defensive
// anyway.
@ -158,10 +122,6 @@ impl Storage {
let storage = storage.reduce_op(op, layout, s)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.reduce_op(op, layout, s)?;
Ok(Self::Metal(storage))
}
}
}
@ -175,14 +135,10 @@ impl Storage {
let storage = storage.to_dtype(layout, dtype)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.to_dtype(layout, dtype)?;
Ok(Self::Metal(storage))
}
}
}
pub(crate) fn apply_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> {
pub(crate) fn custom_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> {
match self {
Self::Cpu(storage) => {
let (storage, shape) = c.cpu_fwd(storage, l)?;
@ -192,14 +148,10 @@ impl Storage {
let (storage, shape) = c.cuda_fwd(storage, l)?;
Ok((Self::Cuda(storage), shape))
}
Self::Metal(storage) => {
let (storage, shape) = c.metal_fwd(storage, l)?;
Ok((Self::Metal(storage), shape))
}
}
}
pub(crate) fn apply_op2(
pub(crate) fn custom_op2(
&self,
l1: &Layout,
t2: &Self,
@ -216,15 +168,11 @@ impl Storage {
let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?;
Ok((Self::Cuda(s), shape))
}
(Self::Metal(s1), Self::Metal(s2)) => {
let (s, shape) = c.metal_fwd(s1, l1, s2, l2)?;
Ok((Self::Metal(s), shape))
}
_ => unreachable!(),
}
}
pub(crate) fn apply_op3(
pub(crate) fn custom_op3(
&self,
l1: &Layout,
t2: &Self,
@ -244,10 +192,6 @@ impl Storage {
let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?;
Ok((Self::Cuda(s), shape))
}
(Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
let (s, shape) = c.metal_fwd(s1, l1, s2, l2, s3, l3)?;
Ok((Self::Metal(s), shape))
}
_ => unreachable!(),
}
}
@ -262,10 +206,6 @@ impl Storage {
let storage = storage.unary_impl::<B>(layout)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.unary_impl::<B>(layout)?;
Ok(Self::Metal(storage))
}
}
}
@ -286,10 +226,6 @@ impl Storage {
let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
Ok(Self::Cuda(storage))
}
(Self::Metal(lhs), Self::Metal(rhs)) => {
let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
Ok(Self::Metal(storage))
}
(lhs, rhs) => {
// Should not happen because of the same device check above but we're defensive
// anyway.
@ -321,10 +257,6 @@ impl Storage {
let s = inp.conv1d(l, kernel, kernel_l, params)?;
Ok(Self::Cuda(s))
}
(Storage::Metal(inp), Storage::Metal(kernel)) => {
let s = inp.conv1d(l, kernel, kernel_l, params)?;
Ok(Self::Metal(s))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
@ -334,33 +266,6 @@ impl Storage {
}
}
pub(crate) fn conv_transpose1d(
&self,
l: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
self.same_device(kernel, "conv-transpose1d")?;
self.same_dtype(kernel, "conv-transpose1d")?;
match (self, &kernel) {
(Storage::Cpu(inp), Storage::Cpu(kernel)) => {
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
Ok(Self::Cpu(s))
}
(Storage::Cuda(inp), Storage::Cuda(kernel)) => {
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
Ok(Self::Cuda(s))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
op: "conv-transpose1d",
}
.bt()),
}
}
pub(crate) fn conv2d(
&self,
l: &Layout,
@ -379,10 +284,6 @@ impl Storage {
let s = inp.conv2d(l, kernel, kernel_l, params)?;
Ok(Self::Cuda(s))
}
(Storage::Metal(inp), Storage::Metal(kernel)) => {
let s = inp.conv2d(l, kernel, kernel_l, params)?;
Ok(Self::Metal(s))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
@ -392,37 +293,6 @@ impl Storage {
}
}
pub(crate) fn conv_transpose2d(
&self,
l: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &crate::conv::ParamsConvTranspose2D,
) -> Result<Self> {
self.same_device(kernel, "conv_transpose2d")?;
self.same_dtype(kernel, "conv_transpose2d")?;
match (self, &kernel) {
(Storage::Cpu(inp), Storage::Cpu(kernel)) => {
let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
Ok(Self::Cpu(s))
}
(Storage::Cuda(inp), Storage::Cuda(kernel)) => {
let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
Ok(Self::Cuda(s))
}
(Storage::Metal(inp), Storage::Metal(kernel)) => {
let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
Ok(Self::Metal(s))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
op: "conv_transpose2d",
}
.bt()),
}
}
pub(crate) fn avg_pool2d(
&self,
layout: &Layout,
@ -438,10 +308,6 @@ impl Storage {
let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
Ok(Self::Metal(storage))
}
}
}
@ -460,27 +326,6 @@ impl Storage {
let storage = storage.max_pool2d(layout, kernel_size, stride)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.max_pool2d(layout, kernel_size, stride)?;
Ok(Self::Metal(storage))
}
}
}
pub(crate) fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
let storage = storage.upsample_nearest1d(layout, sz)?;
Ok(Self::Cpu(storage))
}
Self::Cuda(storage) => {
let storage = storage.upsample_nearest1d(layout, sz)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.upsample_nearest1d(layout, sz)?;
Ok(Self::Metal(storage))
}
}
}
@ -494,10 +339,6 @@ impl Storage {
let storage = storage.upsample_nearest2d(layout, h, w)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.upsample_nearest2d(layout, h, w)?;
Ok(Self::Metal(storage))
}
}
}
@ -521,10 +362,6 @@ impl Storage {
let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
Ok(Self::Cuda(storage))
}
(Self::Metal(cond), Self::Metal(t), Self::Metal(f)) => {
let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
Ok(Self::Metal(storage))
}
(_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
@ -551,10 +388,6 @@ impl Storage {
let storage = s.gather(l, indexes, indexes_l, d)?;
Ok(Self::Cuda(storage))
}
(Self::Metal(s), Self::Metal(indexes)) => {
let storage = s.gather(l, indexes, indexes_l, d)?;
Ok(Self::Metal(storage))
}
_ => unreachable!(),
}
}
@ -579,10 +412,6 @@ impl Storage {
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Cuda(storage))
}
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Metal(storage))
}
_ => unreachable!(),
}
}
@ -607,10 +436,6 @@ impl Storage {
let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Cuda(storage))
}
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Metal(storage))
}
_ => unreachable!(),
}
}
@ -632,10 +457,6 @@ impl Storage {
let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
Ok(Self::Cuda(storage))
}
(Self::Metal(lhs), Self::Metal(rhs)) => {
let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
Ok(Self::Metal(storage))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
@ -663,10 +484,6 @@ impl Storage {
let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
Ok(Self::Cuda(storage))
}
(Self::Metal(lhs), Self::Metal(rhs)) => {
let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
Ok(Self::Metal(storage))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
@ -686,9 +503,6 @@ impl Storage {
match (self, dst) {
(Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l),
(Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
(Self::Metal(src), Self::Metal(dst)) => {
Ok(src.copy_strided_src(dst, dst_offset, src_l)?)
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),

File diff suppressed because it is too large Load Diff

View File

@ -22,23 +22,3 @@ pub fn has_mkl() -> bool {
pub fn cuda_is_available() -> bool {
cfg!(feature = "cuda")
}
pub fn metal_is_available() -> bool {
cfg!(feature = "metal")
}
pub fn with_avx() -> bool {
cfg!(target_feature = "avx")
}
pub fn with_neon() -> bool {
cfg!(target_feature = "neon")
}
pub fn with_simd128() -> bool {
cfg!(target_feature = "simd128")
}
pub fn with_f16c() -> bool {
cfg!(target_feature = "f16c")
}

View File

@ -107,10 +107,6 @@ impl Var {
Ok(Self(inner))
}
pub fn as_detached_tensor(&self) -> Tensor {
self.0.detach()
}
pub fn as_tensor(&self) -> &Tensor {
&self.0
}

View File

@ -1,5 +1,6 @@
mod test_utils;
use anyhow::Result;
use candle_core::{test_device, test_utils, Device, IndexOp, Tensor};
use candle_core::{Device, Tensor};
/* This test is based on the following script.
import torch
@ -13,11 +14,6 @@ res = torch.nn.functional.conv1d(t, w)
print(res.flatten())
res = torch.nn.functional.conv1d(t, w, padding=1)
print(res.flatten())
w_t = w.transpose(0, 1)
res = torch.nn.functional.conv_transpose1d(t, w_t)
print(res.shape)
print(res)
*/
fn conv1d(dev: &Device) -> Result<()> {
let t = Tensor::new(
@ -37,41 +33,32 @@ fn conv1d(dev: &Device) -> Result<()> {
dev,
)?
.reshape((2, 4, 3))?;
let res = t.conv1d(&w, 0, 1, 1, 1)?;
let res = t.conv1d(&w, 0, 1)?;
assert_eq!(res.dims(), [1, 2, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.6357, -1.3336, 4.1393, -1.1784, 3.5675, 0.5069]
);
let res = t.conv1d(&w, /*padding*/ 1, 1, 1, 1)?;
let res = t.conv1d(&w, /*padding*/ 1, 1)?;
assert_eq!(res.dims(), [1, 2, 5]);
// Same as pytorch default padding: use zeros.
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
);
let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 7]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[
0.0699, -1.2899, 8.3018, 5.5873, 2.4572, -2.6143, -0.0706, 1.8765, 4.8318, 1.1538,
4.7076, -5.9745, -0.8276, 1.621
],
);
Ok(())
}
fn conv1d_small(dev: &Device) -> Result<()> {
let t = Tensor::new(&[0.4056f32, -0.8689, -0.0773, -1.5630], dev)?.reshape((1, 1, 4))?;
let w = Tensor::new(&[1f32, 0., 0.], dev)?.reshape((1, 1, 3))?;
let res = t.conv1d(&w, 0, 1, 1, 1)?;
let res = t.conv1d(&w, 0, 1)?;
assert_eq!(res.dims(), [1, 1, 2]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[0.4056, -0.8689]
);
let res = t.conv1d(&w, /*padding*/ 1, 1, 1, 1)?;
let res = t.conv1d(&w, /*padding*/ 1, 1)?;
assert_eq!(res.dims(), [1, 1, 4]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@ -90,19 +77,6 @@ print(t.flatten())
print(w.flatten())
res = torch.nn.functional.conv2d(t, w)
print(res.flatten())
w_t = w.transpose(0, 1)
res = torch.nn.functional.conv_transpose2d(t, w_t)
print(res.shape)
print(res)
res = torch.nn.functional.conv2d(t, w, dilation=2)
print(res.shape)
print(res[0])
res = torch.nn.functional.conv_transpose2d(t, w_t, dilation=2)
print(res.shape)
print(res)
*/
fn conv2d(dev: &Device) -> Result<()> {
let t = Tensor::new(
@ -135,7 +109,7 @@ fn conv2d(dev: &Device) -> Result<()> {
)?;
let t = t.reshape((1, 4, 5, 5))?;
let w = w.reshape((2, 4, 3, 3))?;
let res = t.conv2d(&w, 0, 1, 1, 1)?;
let res = t.conv2d(&w, 0, 1)?;
assert_eq!(res.dims(), [1, 2, 3, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@ -144,69 +118,6 @@ fn conv2d(dev: &Device) -> Result<()> {
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
]
);
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 7, 7]);
assert_eq!(
test_utils::to_vec3_round(&res.i(0)?, 4)?,
[
[
[-1.9918, 2.6797, -0.4599, -1.6037, 1.4131, -2.4012, 2.9277],
[1.8016, -3.5361, 1.0757, 3.5395, -8.2168, -3.2023, 0.5375],
[0.8243, 1.8675, 7.8929, -4.0746, -6.4415, 5.1139, 1.6889],
[0.2722, 8.9679, 3.3477, 1.8514, -4.2896, -3.8228, -7.5632],
[-8.5412, -5.8142, -7.1587, -1.6095, 0.4651, 0.2748, -2.0985],
[2.0833, -0.6482, -12.1692, -4.1284, -2.9765, -0.0656, -4.5114],
[5.307, 2.6957, 2.3087, 1.0478, 0.7808, -1.1519, -0.9579]
],
[
[1.089, 0.1872, -0.6408, -0.9897, 0.8503, 1.1019, -0.9211],
[-0.1741, -0.2915, 4.2472, 1.9417, 1.65, 0.6303, -4.7131],
[1.6555, 2.4026, -2.9293, 2.9953, 0.5328, 3.5873, -0.9621],
[-1.4289, -3.2787, 4.1747, -6.0341, -4.6341, -5.7945, 4.142],
[7.5973, 6.4431, 5.9872, 2.1639, -8.6566, 3.3143, -3.4059],
[-0.8775, -3.048, 11.6543, 0.6442, 2.3218, -0.4765, 1.1516],
[-5.5423, -2.5188, 1.0754, -0.0563, -2.9386, -1.1504, 1.0171]
]
]
);
// Dilations.
let res = t.conv2d(&w, 0, 1, 2, 1)?;
assert_eq!(res.dims(), [1, 2, 1, 1]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.45, -2.3504],
);
// Transpose and dilations.
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2)?;
assert_eq!(res.dims(), [1, 2, 9, 9]);
assert_eq!(
test_utils::to_vec3_round(&res.i(0)?, 4)?,
[
[
[-1.9918, 3.1652, -0.6778, -4.3442, 4.4351, 0.6652, -3.0124, -0.6031, 2.9277],
[2.7036, -1.7156, -0.3969, 1.0516, 1.6381, -2.8886, -0.205, 2.4682, -1.0499],
[-0.9459, 3.1631, 3.707, -4.8369, -8.5166, -1.4496, -2.7559, -3.2698, 1.4376],
[-0.2157, 3.7786, -2.0252, -4.2633, 3.6731, -1.5142, 5.9391, -0.2622, -0.141],
[-6.8121, -3.1744, 1.5945, 3.0637, -9.6088, 1.4446, 2.9489, -3.0082, -7.3822],
[0.2371, 3.3303, 0.3861, 2.2646, -4.6784, 4.1235, -0.0109, 0.3176, -0.03],
[-2.5339, -2.9564, -3.4518, -4.4594, -9.1873, -1.9709, -0.4676, 0.51, -3.5024],
[4.007, 0.3067, -2.2954, 1.1105, -0.1992, 1.6372, -2.9268, 0.2807, -1.2787],
[5.307, 1.1317, 1.3518, 0.9049, 3.8116, -0.4075, -0.8874, -0.2241, -0.9579]
],
[
[1.089, -0.6483, 0.0726, -0.4752, -1.3283, 1.7103, 1.0703, 0.1076, -0.9211],
[-0.8629, 0.1376, 0.3202, 2.0955, 0.9696, 2.8988, -1.0012, 1.5049, -0.1278],
[1.9286, -1.5255, -2.9563, 2.4589, 3.3611, -0.6951, 0.3525, -1.7724, -5.9861],
[1.1226, 2.1561, 3.6417, 4.7546, -0.692, 4.4126, -5.1902, 6.0805, 2.3185],
[1.0111, 0.3604, 0.6432, -3.6605, 7.9517, -9.2955, -5.2988, -3.7803, -2.0642],
[3.3172, -1.7967, -3.6576, -2.0942, 1.3158, 0.112, -1.7405, 2.9167, 0.7957],
[5.1001, 1.8995, -1.8639, 1.1262, 9.9629, 2.683, -3.6319, -1.1607, 0.5856],
[-4.8445, -0.5642, 4.2317, 0.0856, 1.2267, -0.5712, 1.736, 1.0997, 0.6908],
[-5.5423, -1.1831, -1.2176, 0.0843, 0.0446, -0.7545, -2.4798, -0.0827, 1.0171]
]
]
);
Ok(())
}
@ -220,16 +131,6 @@ print(t.flatten())
print(w.flatten())
res = torch.nn.functional.conv2d(t, w)
print(res.flatten())
w_t = w.transpose(0, 1)
res = torch.nn.functional.conv_transpose2d(t, w_t)
print(res.shape)
print(res.flatten())
t_t = w.transpose(0, 1)
res = torch.nn.functional.conv_transpose2d(t_t, w)
print(res.shape)
print(res.flatten())
*/
fn conv2d_small(dev: &Device) -> Result<()> {
let t = Tensor::new(
@ -242,41 +143,12 @@ fn conv2d_small(dev: &Device) -> Result<()> {
let w = Tensor::new(&[-0.9259f32, 1.3017], dev)?;
let t = t.reshape((1, 2, 3, 3))?;
let w = w.reshape((1, 2, 1, 1))?;
let res = t.conv2d(&w, 0, 1, 1, 1)?;
let res = t.conv2d(&w, 0, 1)?;
assert_eq!(res.dims(), [1, 1, 3, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539]
);
let res = t.conv2d(&w, 2, 1, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 7, 7]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1640, -0.0111, -0.1742, 0.0000, 0.0000,
0.0000, 0.0000, 2.6437, -2.0268, 1.1823, 0.0000, 0.0000, 0.0000, 0.0000, 3.2855,
-1.0324, 0.2539, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000
]
);
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 3, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539],
);
let res = t.transpose(0, 1)?.conv_transpose2d(&w, 0, 0, 1, 1)?;
assert_eq!(res.dims(), [2, 2, 3, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[
-0.3755, 0.8045, -0.6336, -0.2218, -1.1369, 0.8599, 1.5768, -0.1268, -0.1728, 0.528,
-1.131, 0.8908, 0.3118, 1.5984, -1.2089, -2.2168, 0.1783, 0.2429, -0.3838, 0.5802,
-0.3268, -2.0382, 0.6329, -0.2293, -1.2154, 0.6441, -0.3035, 0.5396, -0.8156, 0.4594,
2.8654, -0.8898, 0.3224, 1.7087, -0.9056, 0.4267
]
);
Ok(())
}
@ -290,7 +162,7 @@ fn conv2d_smaller(dev: &Device) -> Result<()> {
let w = Tensor::new(&[1f32, 1., 1., 1., 1., 1., 1., 1., 1.], dev)?;
let t = t.reshape((1, 1, 3, 3))?;
let w = w.reshape((1, 1, 3, 3))?;
let res = t.conv2d(&w, 0, 1, 1, 1)?;
let res = t.conv2d(&w, 0, 1)?;
assert_eq!(res.dims(), [1, 1, 1, 1]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@ -299,297 +171,8 @@ fn conv2d_smaller(dev: &Device) -> Result<()> {
Ok(())
}
/* This test is based on the following script.
import torch
torch.manual_seed(4242)
t = torch.randn((1, 2, 4, 2))
w = torch.randn((1, 2, 1, 1))
print(t.flatten())
print(w.flatten())
res = torch.nn.functional.conv2d(t, w)
print(res.flatten())
*/
fn conv2d_non_square(dev: &Device) -> Result<()> {
let t = Tensor::new(
&[
0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699,
],
dev,
)?;
let w = Tensor::new(&[-1.1351f32, 1.3841], dev)?;
let t = t.reshape((1, 2, 4, 2))?;
let w = w.reshape((1, 2, 1, 1))?;
let res = t.conv2d(&w, 0, 1, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 4, 2]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[0.2312, 5.2238, 2.3772, 1.9076, 2.0256, -0.5776, -1.6028, -1.467]
);
Ok(())
}
/*
import torch
torch.manual_seed(4242)
t = torch.randn((1, 4, 5, 5), requires_grad=True)
w = torch.randn((2, 4, 3, 3), requires_grad=True)
print(t.flatten())
print(w.flatten())
res = torch.nn.functional.conv2d(t, w)
print(res.flatten())
loss = (res ** 2).sum()
print(loss)
loss.backward()
print(t.grad.shape)
print(t.grad.flatten())
print(w.grad.shape)
print(w.grad.flatten())
t.grad.zero_()
w.grad.zero_()
res = torch.nn.functional.conv2d(t, w, stride=2)
print(res.flatten())
loss = (res ** 2).sum()
print(loss)
loss.backward()
print(t.grad.shape)
print(t.grad[0])
print(w.grad.shape)
print(w.grad[0])
*/
fn conv2d_grad(dev: &Device) -> Result<()> {
use candle_core::Var;
let t = Var::from_slice(
&[
0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843, 0.2395,
1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013, -0.6836,
0.2477, 1.3127, -0.2260, 0.2622, -1.2974, -0.8140, -0.8404, -0.3490, 0.0130, 1.3123,
1.7569, -0.3956, -1.8255, 0.1727, -0.3538, 2.6941, 1.0529, 0.4219, -0.2071, 1.1586,
0.4717, 0.3865, -0.5690, -0.5010, -0.1310, 0.7796, 0.6630, -0.2021, 2.6090, 0.2049,
0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
-0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
-0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
],
(1, 4, 5, 5),
dev,
)?;
let w = Var::from_slice(
&[
-0.9325f32, 0.6451, -0.8537, 0.2378, 0.8764, -0.1832, 0.2987, -0.6488, -0.2273,
-2.4184, -0.1192, -0.4821, -0.5079, -0.5766, -2.4729, 1.6734, 0.4558, 0.2851, 1.1514,
-0.9013, 1.0662, -0.1817, -0.0259, 0.1709, 0.5367, 0.7513, 0.8086, -2.2586, -0.5027,
0.9141, -1.3086, -1.3343, -1.5669, -0.1657, 0.7958, 0.1432, 0.3896, -0.4501, 0.1667,
0.0714, -0.0952, 1.2970, -0.1674, -0.3178, 1.0677, 0.3060, 0.7080, 0.1914, 1.1679,
-0.3602, 1.9265, -1.8626, -0.5112, -0.0982, 0.2621, 0.6565, 0.5908, 1.0089, -0.1646,
1.8032, -0.6286, 0.2016, -0.3370, 1.2555, 0.8009, -0.6488, -0.4652, -1.5685, 1.5860,
0.5583, 0.4623, 0.6026,
],
(2, 4, 3, 3),
dev,
)?;
let res = t.conv2d(&w, 0, 1, 1, 1)?;
let loss = res.sqr()?.sum_all()?;
assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 741.12f32);
let grads = loss.backward()?;
let grad_t = grads.get(&t).unwrap();
let grad_w = grads.get(&w).unwrap();
assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
assert_eq!(
test_utils::to_vec1_round(&grad_t.flatten_all()?, 2)?,
[
9.29, -2.84, -5.71, 3.38, -7.71, -19.15, 7.02, 29.1, 9.34, 34.73, -22.87, 24.35,
-39.88, -14.01, 21.08, 9.94, 13.63, -34.68, 11.21, -6.26, 7.72, -6.32, -16.64, -1.08,
-20.22, 21.73, -0.37, -4.06, 5.82, -3.65, -30.73, 14.55, 87.7, 31.6, 4.53, -89.78,
-75.37, -57.43, -7.56, 92.96, 18.79, -4.63, -159.75, -42.47, -47.26, 52.88, 37.32,
49.0, 12.82, 2.01, -8.98, 20.18, 16.62, 12.06, 15.38, 20.0, 2.57, -15.22, 72.62,
-10.75, 2.25, -31.2, 3.75, -0.2, 9.76, -0.68, 5.21, -40.44, -22.59, -61.61, 17.28,
20.41, 37.55, 5.23, 6.81, 23.54, 23.62, -9.99, -9.13, 4.87, -35.06, -26.1, 63.48,
25.81, -39.21, -70.68, -46.96, 2.33, 41.81, 82.42, -28.63, -11.78, -35.33, -10.28,
-28.57, -9.13, 7.21, -9.05, -9.62, -11.25
]
);
assert_eq!(
test_utils::to_vec1_round(&grad_w.flatten_all()?, 2)?,
[
-28.92, -22.88, -141.23, 73.35, 61.07, 47.81, -20.0, -73.71, -41.82, -13.59, 21.5,
28.72, 28.57, -46.85, -90.19, 143.61, 16.68, 7.43, 18.88, -90.81, -20.29, 54.79, 82.63,
22.94, 77.81, -16.39, -13.2, 9.34, -40.39, -26.62, 5.33, -60.91, 9.09, -59.37, 7.08,
58.64, 5.55, 20.52, 2.5, -17.25, -6.8, 22.21, 30.15, -7.52, -37.46, 5.67, 22.58, 9.03,
47.05, 17.61, 37.31, -98.13, -14.61, -4.8, -6.36, 44.69, 23.34, 8.37, -13.52, 80.05,
-34.24, -16.36, -12.31, 1.92, -33.62, -14.1, -49.23, -7.39, 11.5, -9.98, 9.66, 29.6
]
);
// Same as before but with stride.
let res = t.conv2d(&w, 0, 2, 1, 1)?;
let loss = res.sqr()?.sum_all()?;
assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 277.16f32);
let grads = loss.backward()?;
let grad_t = grads.get(&t).unwrap();
let grad_w = grads.get(&w).unwrap();
assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
assert_eq!(
test_utils::to_vec3_round(&grad_t.i(0)?, 2)?,
[
[
[9.29, -7.03, 0.94, 3.49, -7.71],
[-1.8, -7.82, 8.9, 8.46, 7.43],
[-25.84, 22.09, -19.27, -0.22, 1.69],
[4.02, 18.53, -18.37, 2.3, -24.51],
[7.72, -9.68, -12.34, 5.6, -20.22]
],
[
[21.73, 3.39, -18.27, 3.86, -3.65],
[8.25, 3.73, 30.73, -8.61, -11.93],
[-72.15, -15.36, -17.53, -12.32, -1.61],
[-22.32, -7.79, -91.82, 6.44, -37.69],
[52.88, 14.44, 42.75, 9.88, 2.01]
],
[
[-8.98, 9.91, 6.75, -4.68, 15.38],
[4.93, -0.33, 9.94, -1.46, 14.78],
[13.62, -30.63, 3.96, -3.58, -4.48],
[-14.13, 1.19, -34.43, 3.08, -33.83],
[17.28, 12.94, 31.83, -3.35, 6.81]
],
[
[23.54, 6.98, -24.52, 0.52, 4.87],
[9.65, 6.18, 1.71, -25.23, -4.93],
[-54.99, -23.66, 3.19, -3.73, 18.58],
[-21.35, -10.39, -39.88, 28.73, -30.76],
[-9.13, 11.12, -14.0, -8.23, -11.25]
]
]
);
assert_eq!(
test_utils::to_vec3_round(&grad_w.i(0)?, 2)?,
[
[
[28.34, -7.91, -45.75],
[21.03, 3.86, 29.86],
[0.72, -36.58, -35.28]
],
[
[-16.04, 11.53, -16.38],
[29.62, -16.32, -48.35],
[57.5, 28.29, 25.81]
],
[
[2.93, -19.6, 1.57],
[27.15, 53.88, -24.64],
[12.74, -22.6, -26.2]
],
[
[-0.18, -14.86, -6.82],
[-19.55, -2.72, 45.9],
[-2.54, 36.97, 27.11]
]
]
);
// Replicate the issue from https://github.com/huggingface/candle/issues/1212
let res = t.i((.., .., 0..4, 0..4))?.conv2d(&w, 0, 2, 1, 1)?;
let loss = res.sqr()?.sum_all()?;
assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 21.12f32);
let grads = loss.backward()?;
let grad_t = grads.get(&t).unwrap();
let grad_w = grads.get(&w).unwrap();
assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
assert_eq!(
test_utils::to_vec3_round(&grad_t.i(0)?, 2)?,
[
[
[9.29, -7.03, 7.87, 0.0, 0.0],
[-1.8, -7.82, 5.9, 0.0, 0.0],
[-3.12, 4.49, 5.52, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]
],
[
[21.73, 3.39, 4.77, 0.0, 0.0],
[8.25, 3.73, 27.61, 0.0, 0.0],
[-20.55, -5.61, -2.77, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]
],
[
[-8.98, 9.91, -7.15, 0.0, 0.0],
[4.93, -0.33, 4.56, 0.0, 0.0],
[-6.7, -5.76, -8.05, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]
],
[
[23.54, 6.98, -10.0, 0.0, 0.0],
[9.65, 6.18, 18.72, 0.0, 0.0],
[3.29, -5.27, 0.79, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]
]
]
);
assert_eq!(
test_utils::to_vec3_round(&grad_w.i(0)?, 2)?,
[
[
[-3.47, 7.44, 0.66],
[12.89, -3.4, -9.29],
[-14.16, -0.83, 7.14]
],
[
[-3.23, 5.37, -3.02],
[-2.12, -11.24, 1.94],
[6.97, 7.2, 2.99]
],
[
[-4.04, -3.31, 4.87],
[-6.68, -5.68, 1.73],
[-5.54, 4.32, 0.52]
],
[[-4.72, 1.5, 4.72], [3.79, 4.04, 6.76], [-4.6, 5.8, 6.93]]
]
);
Ok(())
}
test_device!(conv1d, conv1d_cpu, conv1d_gpu, conv1d_metal);
test_device!(
conv1d_small,
conv1d_small_cpu,
conv1d_small_gpu,
conv1d_small_metal
);
test_device!(conv2d, conv2d_cpu, conv2d_gpu, conv2d_metal);
test_device!(
conv2d_non_square,
conv2d_non_square_cpu,
conv2d_non_square_gpu,
conv2d_non_square_metal
);
test_device!(
conv2d_small,
conv2d_small_cpu,
conv2d_small_gpu,
conv2d_small_metal
);
test_device!(
conv2d_smaller,
conv2d_smaller_cpu,
conv2d_smaller_gpu,
conv2d_smaller_metal
);
test_device!(
conv2d_grad,
conv2d_grad_cpu,
conv2d_grad_gpu,
conv2_grad_metal
);
test_device!(conv1d, conv1d_cpu, conv1d_gpu);
test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu);
test_device!(conv2d, conv2d_cpu, conv2d_gpu);
test_device!(conv2d_small, conv2d_small_cpu, conv2d_small_gpu);
test_device!(conv2d_smaller, conv2d_smaller_cpu, conv2d_smaller_gpu);

View File

@ -1,8 +1,10 @@
use candle_core::backend::BackendStorage;
use candle_core::cpu_backend;
use candle_core::test_utils::to_vec1_round;
use candle_core::{CpuStorage, CustomOp1, DType, Device, Error, Layout, Result, Shape, Tensor};
mod test_utils;
use test_utils::to_vec1_round;
fn fwd<T: num_traits::Float>(v: T, alpha: f64) -> T {
if v.is_sign_positive() {
v
@ -37,7 +39,7 @@ fn custom_op1_no_backward() -> Result<()> {
let cpu = &Device::Cpu;
let t = Tensor::arange(0u32, 12u32, cpu)?.to_dtype(DType::F32)?;
let t = (t - 5.)?;
let elu_t = t.apply_op1_no_bwd(&Elu { alpha: 1. })?;
let elu_t = t.custom_op1(Elu { alpha: 1. })?;
assert_eq!(
to_vec1_round(&elu_t, 4)?,
&[-0.9933, -0.9817, -0.9502, -0.8647, -0.6321, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
@ -94,7 +96,7 @@ impl CustomOp1 for EluWithBackward {
fn bwd(&self, arg: &Tensor, _res: &Tensor, grad_res: &Tensor) -> Result<Option<Tensor>> {
let alpha = self.0.alpha;
let bwd = arg.apply_op1(EluBackward { alpha })?;
let bwd = arg.custom_op1(EluBackward { alpha })?;
Ok(Some(grad_res.mul(&bwd)?))
}
}
@ -103,7 +105,7 @@ impl CustomOp1 for EluWithBackward {
fn custom_op1_with_backward() -> Result<()> {
let cpu = &Device::Cpu;
let t = candle_core::Var::new(&[-2f32, 0f32, 2f32], cpu)?;
let elu_t = t.apply_op1(EluWithBackward::new(2.))?;
let elu_t = t.custom_op1(EluWithBackward::new(2.))?;
assert_eq!(to_vec1_round(&elu_t, 4)?, &[-1.7293, 0.0, 2.0]);
let grads = elu_t.backward()?;

View File

@ -1,5 +1,6 @@
use anyhow::{Context, Result};
use candle_core::{test_device, test_utils, Device, Shape, Tensor, Var};
use candle_core::{Device, Shape, Tensor, Var};
mod test_utils;
fn simple_grad(device: &Device) -> Result<()> {
let x = Var::new(&[3f32, 1., 4.], device)?;
@ -173,331 +174,11 @@ fn unary_grad(device: &Device) -> Result<()> {
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(y.to_vec1::<f32>()?, [6., 2., 8., 0.3]);
assert_eq!(grad_x.to_vec1::<f32>()?, [2., 2., 2., 2.]);
let x = Var::new(&[3f32, 1., 4., 0.15], device)?;
let y = x.powf(2.5)?;
let grads = y.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(test_utils::to_vec1_round(&y, 2)?, [15.59, 1.0, 32.0, 0.01]);
assert_eq!(
test_utils::to_vec1_round(grad_x, 2)?,
[12.99, 2.5, 20.0, 0.15]
);
let y = x.tanh()?;
let grads = y.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(test_utils::to_vec1_round(&y, 2)?, [1.0, 0.76, 1.0, 0.15]);
assert_eq!(
test_utils::to_vec1_round(grad_x, 2)?,
[0.01, 0.42, 0.0, 0.98],
);
// testing compared to pytorch nn.GELU(approximate = 'tanh')
let y = x.gelu()?;
let grads = y.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec1_round(&y, 4)?,
[2.9964, 0.8412, 3.9999, 0.0839]
);
assert_eq!(
test_utils::to_vec1_round(grad_x, 4)?,
[1.0116, 1.0830, 1.0003, 0.6188],
);
// Testing compared to pytorch torch.erf
//
// import torch
// x = torch.tensor([3.0, 1.0, 4.0, 0.15], requires_grad=True)
// y = x.erf()
// print(y)
// loss = y.sum()
// loss.backward()
// print(x.grad)
let y = x.erf()?;
let grads = y.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(test_utils::to_vec1_round(&y, 4)?, [1.0, 0.8427, 1.0, 0.168]);
assert_eq!(
test_utils::to_vec1_round(grad_x, 4)?,
[0.0001, 0.4151, 0.0, 1.1033],
);
// Testing compared to pytorch nn.GELU(approximate = 'none')
//
// import torch
// import torch.nn.functional as F
// x = torch.tensor([3.0, 1.0, 4.0, 0.15], requires_grad=True)
// y = F.gelu(x, approximate='none')
// print(y)
// loss = y.sum()
// loss.backward()
// print(x.grad)
let y = x.gelu_erf()?;
let grads = y.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec1_round(&y, 4)?,
[2.9960, 0.8413, 3.9999, 0.0839]
);
assert_eq!(
test_utils::to_vec1_round(grad_x, 4)?,
[1.0119, 1.0833, 1.0005, 0.6188],
);
// Testing compared to pytorch elu
//
// import torch
// import torch.nn.functional as F
// x = torch.tensor([-1.0, 0.0, -2.0, 3.0], requires_grad=True)
// y = F.elu(x, alpha=2.0)
// print(y)
// loss = y.min
// loss = y.sum()
// loss.backward()
// print(x.grad)
let elu_x = Var::new(&[-1.0f32, 0., -2., 3.], device)?;
let y = elu_x.elu(2.)?;
let grads = y.backward()?;
let grad_x = grads.get(&elu_x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec1_round(&y, 4)?,
[-1.2642, 0.0000, -1.7293, 3.0000]
);
assert_eq!(
test_utils::to_vec1_round(grad_x, 4)?,
[0.7358, 2.0000, 0.2707, 1.0000]
);
// manually checked: see comments
let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
let y = x.interpolate2d(6, 6)?.reshape(36)?;
#[rustfmt::skip]
let z = Tensor::new(
&[
1_f32, 02., 03., 04., 05., 06.,
07., 08., 09., 10., 11., 12.,
13., 14., 15., 16., 17., 18.,
19., 20., 21., 22., 23., 24.,
25., 26., 27., 28., 29., 30.,
31., 32., 33., 34., 35., 36.,
],
device,
)?;
// gradient should be
// row 1
// 1+2+7+8 = 18
// 3+4+9+10 = 26
// 5+6+11+12 = 34
// row 2
// 13+14+19+20 = 66
// 15+16+21+22 = 74
// 17+18+23+24 = 82
// row 3
// 25+26+31+32 = 114
// 27+28+33+34 = 122
// 29+30+35+36 = 130
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
let grads = loss.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec2_round(&grad_x.flatten(0, 2)?, 4)?,
[[18_f32, 26., 34.], [66., 74., 82.], [114., 122., 130.]]
);
// manually checked: see comments
let x = Var::new(&[[[[1f32, 2.], [4., 5.]]]], device)?;
let y = x.interpolate2d(6, 6)?.reshape(36)?;
#[rustfmt::skip]
let z = Tensor::new(
&[
1_f32, 02., 03., 04., 05., 06.,
07., 08., 09., 10., 11., 12.,
13., 14., 15., 16., 17., 18.,
19., 20., 21., 22., 23., 24.,
25., 26., 27., 28., 29., 30.,
31., 32., 33., 34., 35., 36.,
],
device,
)?;
// gradient should be
// row 1
// 1+2+3+7+8+9+13+14+15 = 72
// 4+5+6+10+11+12+16+17+18 = 99
// row 2
// 19+20+21+25+26+27+31+32+33 = 234
// 22+23+24+28+29+30+34+35+36 = 243
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
let grads = loss.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec2_round(&grad_x.flatten(0, 2)?, 4)?,
[[72_f32, 99.], [234., 261.]]
);
// manually checked: see comments
let x = Var::new(&[[[[1f32, 2.], [4., 5.]], [[6f32, 7.], [8., 9.]]]], device)?;
let y = x.interpolate2d(4, 4)?.reshape(32)?;
#[rustfmt::skip]
let z = Tensor::new(
&[
1_f32, 02., 03., 04.,
05., 06., 07., 08.,
09., 10., 11., 12.,
13., 14., 15., 16.,
17., 18., 19., 20.,
21., 22., 23., 24.,
25., 26., 27., 28.,
29., 30., 31., 32.
],
device,
)?;
// gradient should be
// m1r1
// 1+2+5+6=14
// 3+4+7+8=22
// m1r2
// 9+10+13+14=46
// 11+12+15+16=54
// m2r1
// 17+18+21+22=78
// 19+20+23+24=86
// m2r2
// 25+26+29+30=110
// 27+28+31+32=118
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
let grads = loss.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec3_round(&grad_x.flatten(0, 1)?, 4)?,
[[[14_f32, 22.], [46., 54.]], [[78., 86.], [110., 118.]]]
);
// manually checked: see comments
let x = Var::new(
&[[[[1f32, 2.], [4., 5.]]], [[[6f32, 7.], [8., 9.]]]],
device,
)?;
let y = x.interpolate2d(4, 4)?.reshape(32)?;
#[rustfmt::skip]
let z = Tensor::new(
&[
1_f32, 02., 03., 04.,
05., 06., 07., 08.,
09., 10., 11., 12.,
13., 14., 15., 16.,
17., 18., 19., 20.,
21., 22., 23., 24.,
25., 26., 27., 28.,
29., 30., 31., 32.
],
device,
)?;
// gradient should be
// m1r1
// 1+2+5+6=14
// 3+4+7+8=22
// m1r2
// 9+10+13+14=46
// 11+12+15+16=54
// m2r1
// 17+18+21+22=78
// 19+20+23+24=86
// m2r2
// 25+26+29+30=110
// 27+28+31+32=118
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
let grads = loss.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec3_round(&grad_x.flatten(0, 1)?, 4)?,
[[[14_f32, 22.], [46., 54.]], [[78., 86.], [110., 118.]]]
);
Ok(())
}
fn binary_grad(device: &Device) -> Result<()> {
let x = Var::new(&[3f32, 1., -4., -1.], device)?;
let x = x.as_tensor();
// leaky relu
let y = x.maximum(&(x * 0.1)?)?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(x.to_vec1::<f32>()?, [3., 1., -4., -1.]);
assert_eq!(y.to_vec1::<f32>()?, [3., 1., -0.4, -0.1]);
assert_eq!(grad_x.to_vec1::<f32>()?, [1., 1., 0.1, 0.1]);
let y = x.minimum(&(x * 0.1)?)?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(y.to_vec1::<f32>()?, [0.3, 0.1, -4., -1.]);
assert_eq!(grad_x.to_vec1::<f32>()?, [0.1, 0.1, 1., 1.]);
// This one is easy to mess up, we want the gradient to be one as it is the identity function.
let y = x.minimum(x)?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(y.to_vec1::<f32>()?, [3., 1., -4., -1.]);
assert_eq!(grad_x.to_vec1::<f32>()?, [1., 1., 1., 1.]);
let x_var = Var::new(&[3f32, 1., -4., -1., 5., 9.], device)?;
let x = x_var.as_tensor();
let y_var = Var::new(&[2f32, 7., 1.], device)?;
let y = y_var.as_tensor();
let ss = x
.reshape((2, 3))?
.slice_scatter0(&y.reshape((1, 3))?, 1)?
.sqr()?;
let grads = ss.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
let grad_y = grads.get(y).context("no grad for y")?;
assert_eq!(ss.to_vec2::<f32>()?, [[9., 1., 16.], [4., 49., 1.]]);
assert_eq!(grad_x.to_vec1::<f32>()?, [6.0, 2.0, -8.0, 0.0, 0.0, 0.0]);
assert_eq!(grad_y.to_vec1::<f32>()?, [4.0, 14.0, 2.0]);
Ok(())
}
test_device!(
simple_grad,
simple_grad_cpu,
simple_grad_gpu,
simple_grad_metal
);
test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu, sum_grad_metal);
test_device!(
matmul_grad,
matmul_grad_cpu,
matmul_grad_gpu,
matmul_grad_metal
);
test_device!(
grad_descent,
grad_descent_cpu,
grad_descent_gpu,
grad_descent_metal
);
test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu, unary_grad_metal);
test_device!(
binary_grad,
binary_grad_cpu,
binary_grad_gpu,
binary_grad_metal
);
test_device!(simple_grad, simple_grad_cpu, simple_grad_gpu);
test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu);
test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu);
test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu);
test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu);

View File

@ -1,6 +1,8 @@
use anyhow::Result;
use candle_core::{Device, IndexOp, Tensor};
mod test_utils;
#[test]
fn integer_index() -> Result<()> {
let dev = Device::Cpu;
@ -91,32 +93,3 @@ fn index_3d() -> Result<()> {
assert_eq!(tensor.i((1, .., 3))?.to_vec1::<u32>()?, &[15, 19, 23]);
Ok(())
}
#[test]
fn slice_assign() -> Result<()> {
let dev = Device::Cpu;
let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?;
let src = Tensor::arange(0u32, 2 * 3, &dev)?.reshape((3, 2))?;
let out = tensor.slice_assign(&[1..4, 3..5], &src)?;
assert_eq!(
out.to_vec2::<u32>()?,
&[
[0, 1, 2, 3, 4],
[5, 6, 7, 0, 1],
[10, 11, 12, 2, 3],
[15, 16, 17, 4, 5]
]
);
let out = tensor.slice_assign(&[0..3, 0..2], &src)?;
assert_eq!(
out.to_vec2::<u32>()?,
&[
[0, 1, 2, 3, 4],
[2, 3, 7, 8, 9],
[4, 5, 12, 13, 14],
[15, 16, 17, 18, 19]
]
);
Ok(())
}

View File

@ -1,4 +1,5 @@
use candle::{test_device, Device, IndexOp, Result, Tensor};
mod test_utils;
use candle::{Device, IndexOp, Result, Tensor};
use candle_core as candle;
fn contiguous(device: &Device) -> Result<()> {
@ -49,7 +50,7 @@ fn contiguous(device: &Device) -> Result<()> {
Ok(())
}
test_device!(contiguous, contiguous_cpu, contiguous_gpu, contiguous_metal);
test_device!(contiguous, contiguous_cpu, contiguous_gpu);
#[test]
fn strided_blocks() -> Result<()> {

View File

@ -1,9 +0,0 @@
import numpy as np
x = np.arange(10)
# Write a npy file.
np.save("test.npy", x)
# Write multiple values to a npz file.
values = { "x": x, "x_plus_one": x + 1 }
np.savez("test.npz", **values)

View File

@ -1,4 +1,5 @@
use candle_core::{test_device, test_utils, Device, IndexOp, Result, Tensor};
mod test_utils;
use candle_core::{Device, IndexOp, Result, Tensor};
// https://github.com/huggingface/candle/issues/364
fn avg_pool2d(dev: &Device) -> Result<()> {
@ -6,15 +7,8 @@ fn avg_pool2d(dev: &Device) -> Result<()> {
1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
];
let t = Tensor::from_vec(data, (1, 1, 4, 4), dev)?;
let pool = t.avg_pool2d(2)?.squeeze(0)?.squeeze(0)?;
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
assert_eq!(pool.to_vec2::<f32>()?, [[0.5f32, 1.], [1., 1.]]);
let data: Vec<f32> = vec![
1., 2., 1., 3., 0., 0., 1., 1., 1., 1., 1., 1., 5., 1., 1., 1.,
];
let t = Tensor::from_vec(data, (1, 1, 2, 8), dev)?;
let pool = t.avg_pool2d(2)?.squeeze(0)?.squeeze(0)?;
assert_eq!(pool.to_vec2::<f32>()?, [[5. / 4., 6. / 4., 6. / 4., 1.]]);
Ok(())
}
@ -24,12 +18,8 @@ fn max_pool2d(dev: &Device) -> Result<()> {
];
let t = Tensor::from_vec(data, (1, 1, 4, 4), dev)?;
let pool = t.max_pool2d(2)?.squeeze(0)?.squeeze(0)?;
let pool = t.max_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
assert_eq!(pool.to_vec2::<f32>()?, [[2f32, 3.], [5., 1.]]);
let t = t.reshape((1, 1, 2, 8))?;
let pool = t.max_pool2d(2)?.squeeze(0)?.squeeze(0)?;
assert_eq!(pool.to_vec2::<f32>()?, [[2.0, 3.0, 5.0, 1.0]]);
Ok(())
}
@ -53,29 +43,16 @@ fn avg_pool2d_pytorch(dev: &Device) -> Result<()> {
dev,
)?
.reshape((1, 2, 4, 4))?;
let pool = t.avg_pool2d(2)?.squeeze(0)?;
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?;
assert_eq!(
test_utils::to_vec3_round(&pool, 4)?,
test_utils::to_vec3_round(pool, 4)?,
[
[[-1.1926, -0.0395], [0.2688, 0.1871]],
[[0.1835, -0.1606], [0.6249, 0.3217]]
]
);
let pool = t.avg_pool2d(3)?.squeeze(0)?;
assert_eq!(
test_utils::to_vec3_round(&pool, 4)?,
[[[0.085]], [[0.0078]]]
);
let t = t.reshape((1, 1, 4, 8))?;
let pool = t.avg_pool2d(2)?.squeeze(0)?.squeeze(0)?;
assert_eq!(
test_utils::to_vec2_round(&pool, 4)?,
[
[0.7745, 0.0276, -1.6983, 0.12],
[0.3542, 0.1625, 0.4542, -0.0014]
]
);
let pool = t.avg_pool2d((3, 3), (3, 3))?.squeeze(0)?;
assert_eq!(test_utils::to_vec3_round(pool, 4)?, [[[0.085]], [[0.0078]]]);
Ok(())
}
@ -98,17 +75,15 @@ fn upsample_nearest2d(dev: &Device) -> Result<()> {
Ok(())
}
test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu, avg_pool2d_metal);
test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu);
test_device!(
avg_pool2d_pytorch,
avg_pool2d_pytorch_cpu,
avg_pool2d_pytorch_gpu,
avg_pool2d_pytorch_metal
avg_pool2d_pytorch_gpu
);
test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu, max_pool2d_metal);
test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu);
test_device!(
upsample_nearest2d,
upsample_nearest2d_cpu,
upsample_nearest2d_gpu,
upsample_nearest2d_metal
upsample_nearest2d_gpu
);

View File

@ -1,37 +0,0 @@
import torch
from collections import OrderedDict
# Write a trivial tensor to a pt file
a= torch.tensor([[1,2,3,4], [5,6,7,8]])
o = OrderedDict()
o["test"] = a
# Write a trivial tensor to a pt file
torch.save(o, "test.pt")
############################################################################################################
# Write a trivial tensor to a pt file with a key
torch.save({"model_state_dict": o}, "test_with_key.pt")
############################################################################################################
# Create a tensor with fortran contiguous memory layout
import numpy as np
# Step 1: Create a 3D NumPy array with Fortran order using a range of numbers
# For example, creating a 2x3x4 array
array_fortran = np.asfortranarray(np.arange(1, 2*3*4 + 1).reshape(2, 3, 4))
# Verify the memory order
print("Is Fortran contiguous (F order):", array_fortran.flags['F_CONTIGUOUS']) # Should be True
print("Is C contiguous (C order):", array_fortran.flags['C_CONTIGUOUS']) # Should be False
# Step 2: Convert the NumPy array to a PyTorch tensor
tensor_fortran = torch.from_numpy(array_fortran)
# Verify the tensor layout
print("Tensor stride:", tensor_fortran.stride()) # Stride will reflect the Fortran memory layout
# Step 3: Save the PyTorch tensor to a .pth file
torch.save({"tensor_fortran": tensor_fortran}, 'fortran_tensor_3d.pth')
print("3D Tensor saved with Fortran layout.")

View File

@ -1,31 +0,0 @@
/// Regression test for pth files not loading on Windows.
#[test]
fn test_pth() {
let tensors = candle_core::pickle::PthTensors::new("tests/test.pt", None).unwrap();
tensors.get("test").unwrap().unwrap();
}
#[test]
fn test_pth_with_key() {
let tensors =
candle_core::pickle::PthTensors::new("tests/test_with_key.pt", Some("model_state_dict"))
.unwrap();
tensors.get("test").unwrap().unwrap();
}
#[test]
fn test_pth_fortran_congiguous() {
let tensors =
candle_core::pickle::PthTensors::new("tests/fortran_tensor_3d.pth", None).unwrap();
let tensor = tensors.get("tensor_fortran").unwrap().unwrap();
assert_eq!(tensor.dims3().unwrap(), (2, 3, 4));
assert_eq!(
tensor.to_vec3::<i64>().unwrap(),
[
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]
]
);
}

File diff suppressed because it is too large Load Diff

View File

@ -1,24 +0,0 @@
use candle_core::{DType, Result, Tensor};
#[test]
fn npy() -> Result<()> {
let npy = Tensor::read_npy("tests/test.npy")?;
assert_eq!(
npy.to_dtype(DType::U8)?.to_vec1::<u8>()?,
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
);
Ok(())
}
#[test]
fn npz() -> Result<()> {
let npz = Tensor::read_npz("tests/test.npz")?;
assert_eq!(npz.len(), 2);
assert_eq!(npz[0].0, "x");
assert_eq!(npz[1].0, "x_plus_one");
assert_eq!(
npz[1].1.to_dtype(DType::U8)?.to_vec1::<u8>()?,
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
);
Ok(())
}

View File

@ -1,4 +1,5 @@
use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor, D};
mod test_utils;
use candle_core::{DType, Device, IndexOp, Result, Tensor};
fn zeros(device: &Device) -> Result<()> {
let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
@ -8,58 +9,6 @@ fn zeros(device: &Device) -> Result<()> {
Ok(())
}
fn ones(device: &Device) -> Result<()> {
assert_eq!(
Tensor::ones((2, 3), DType::U8, device)?.to_vec2::<u8>()?,
[[1, 1, 1], [1, 1, 1]],
);
assert_eq!(
Tensor::ones((2, 3), DType::U32, device)?.to_vec2::<u32>()?,
[[1, 1, 1], [1, 1, 1]],
);
assert_eq!(
Tensor::ones((2, 3), DType::I64, device)?.to_vec2::<i64>()?,
[[1, 1, 1], [1, 1, 1]],
);
assert_eq!(
Tensor::ones((2, 3), DType::F32, device)?.to_vec2::<f32>()?,
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
);
assert_eq!(
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
);
Ok(())
}
fn full(device: &Device) -> Result<()> {
assert_eq!(
Tensor::full(42u32, (2, 3), device)?.to_vec2::<u32>()?,
[[42, 42, 42], [42, 42, 42]],
);
Ok(())
}
fn arange(device: &Device) -> Result<()> {
assert_eq!(
Tensor::arange(0u8, 5u8, device)?.to_vec1::<u8>()?,
[0, 1, 2, 3, 4],
);
assert_eq!(
Tensor::arange_step(0u8, 5u8, 2, device)?.to_vec1::<u8>()?,
[0, 2, 4],
);
assert_eq!(
Tensor::arange_step(0u8, 5u8, 3, device)?.to_vec1::<u8>()?,
[0, 3],
);
assert_eq!(
Tensor::arange_step(5i64, 0i64, -1, device)?.to_vec1::<i64>()?,
[5, 4, 3, 2, 1],
);
Ok(())
}
fn add_mul(device: &Device) -> Result<()> {
let tensor = Tensor::new(&[3f32, 1., 4.], device)?;
let dim1 = tensor.dims1()?;
@ -85,71 +34,12 @@ fn tensor_2d(device: &Device) -> Result<()> {
Ok(())
}
fn clamp(device: &Device) -> Result<()> {
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let tensor = Tensor::new(data, device)?;
let tensor = tensor.clamp(1.5, 6.2)?;
assert_eq!(
tensor.to_vec2::<f32>()?,
[[3.0, 1.5, 4.0, 1.5, 5.0], [2.0, 1.5, 6.2, 6.2, 2.0]],
);
Ok(())
}
fn unary_op(device: &Device) -> Result<()> {
let data = &[[-3f32, 1., 4., -0.1, 0.5], [2.7, -1.8, -0.28, 1.8, 2.8]];
let tensor = Tensor::new(data, device)?;
assert_eq!(
test_utils::to_vec2_round(&tensor.gelu()?, 4)?,
[
[-0.0036, 0.8412, 3.9999, -0.046, 0.3457],
[2.6911, -0.0647, -0.1091, 1.7353, 2.7933]
]
);
assert_eq!(
test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?,
[
[-0.004, 0.8413, 3.9999, -0.046, 0.3457],
[2.6906, -0.0647, -0.1091, 1.7353, 2.7928]
]
);
assert_eq!(
test_utils::to_vec2_round(&tensor.erf()?, 4)?,
[
[-1.0, 0.8427, 1.0, -0.1125, 0.5205],
[0.9999, -0.9891, -0.3079, 0.9891, 0.9999]
]
);
assert_eq!(
test_utils::to_vec2_round(&tensor.ceil()?, 4)?,
[[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -1.0, -0.0, 2.0, 3.0]]
);
assert_eq!(
test_utils::to_vec2_round(&tensor.floor()?, 4)?,
[[-3.0, 1.0, 4.0, -1.0, 0.0], [2.0, -2.0, -1.0, 1.0, 2.0]]
);
assert_eq!(
test_utils::to_vec2_round(&tensor.round()?, 4)?,
[[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -2.0, -0.0, 2.0, 3.0]]
);
let tensor = Tensor::new(&[2997.9246, 314.15926f32], device)?;
assert_eq!(
test_utils::to_vec1_round(&tensor.round_to(2)?, 4)?,
[2997.92, 314.16]
);
assert_eq!(
test_utils::to_vec1_round(&tensor.round_to(-2)?, 4)?,
[3000.0, 300.]
);
Ok(())
}
fn binary_op(device: &Device) -> Result<()> {
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let tensor1 = Tensor::new(data, device)?;
let tensor = Tensor::new(data, device)?;
let data2 = &[[5f32, 5., 5., 5., 5.], [2., 1., 7., 8., 2.]];
let tensor2 = Tensor::new(data2, device)?;
let tensor = (&tensor1 + (&tensor1 * &tensor1)? / (&tensor1 + &tensor2))?;
let tensor = (&tensor + (&tensor * &tensor)? / (&tensor + &tensor2))?;
let dims = tensor.dims2()?;
assert_eq!(dims, (2, 5));
let content: Vec<Vec<f32>> = tensor.to_vec2()?;
@ -159,17 +49,6 @@ fn binary_op(device: &Device) -> Result<()> {
let tensor = (&tensor - &tensor)?;
let content: Vec<Vec<f32>> = tensor.to_vec2()?;
assert_eq!(content[0], [0., 0., 0., 0., 0.]);
let min = tensor1.minimum(&(&tensor2 * 0.5)?)?;
let max = tensor1.maximum(&(&tensor2 * 0.5)?)?;
assert_eq!(
min.to_vec2::<f32>()?,
[[2.5, 1.0, 2.5, 1.0, 2.5], [1.0, 0.5, 3.5, 4.0, 1.0]],
);
assert_eq!(
max.to_vec2::<f32>()?,
[[3.0, 2.5, 4.0, 2.5, 5.0], [2.0, 1.0, 7.0, 8.0, 2.0]]
);
Ok(())
}
@ -188,22 +67,6 @@ fn transpose(device: &Device) -> Result<()> {
Ok(())
}
fn var(device: &Device) -> Result<()> {
// Values taken from https://pytorch.org/docs/stable/generated/torch.var.html
let data = &[
[0.2035f32, 1.2959, 1.8101, -0.4644],
[1.5027, -0.3270, 0.5905, 0.6538],
[-1.5745, 1.3330, -0.5596, -0.6548],
[0.1264, -0.5080, 1.6420, 0.1992],
];
let tensor = Tensor::new(data, device)?;
assert_eq!(
test_utils::to_vec2_round(&tensor.var_keepdim(1)?, 4)?,
&[[1.0631], [0.559], [1.4893], [0.8258]]
);
Ok(())
}
fn sum(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, device)?;
@ -717,30 +580,6 @@ fn index_select(device: &Device) -> Result<()> {
hs.to_vec2::<f32>()?,
&[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
);
// Prior to https://github.com/huggingface/candle/pull/1022
// There would be a bug where the last values in the result tensor would be set to 0.
let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
let hs = t.index_select(&ids, 0)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[0.0, 1.0, 2.0],
[6.0, 7.0, 8.0],
[3.0, 4.0, 5.0],
[0.0, 1.0, 2.0],
[6.0, 7.0, 8.0],
[3.0, 4.0, 5.0],
]
);
// Test when selecting dim > 0 with ids size different from elem count of
// target dim in source/input.
let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
let hs = t.index_select(&ids, 1)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
Ok(())
}
@ -787,48 +626,6 @@ fn index_add(device: &Device) -> Result<()> {
Ok(())
}
fn slice_scatter(device: &Device) -> Result<()> {
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
assert_eq!(
t.to_vec2::<f32>()?,
&[
[0.0, 1.0, 2.0],
[3.0, 4.0, 5.0],
[6.0, 7.0, 8.0],
[9.0, 10.0, 11.0]
]
);
let src = Tensor::arange(100f32, 106f32, device)?.reshape((2, 3))?;
assert_eq!(
t.slice_scatter0(&src, 0)?.to_vec2::<f32>()?,
&[
[100.0, 101.0, 102.0],
[103.0, 104.0, 105.0],
[6.0, 7.0, 8.0],
[9.0, 10.0, 11.0]
]
);
assert_eq!(
t.slice_scatter0(&src, 1)?.to_vec2::<f32>()?,
&[
[0.0, 1.0, 2.0],
[100.0, 101.0, 102.0],
[103.0, 104.0, 105.0],
[9.0, 10.0, 11.0]
]
);
assert_eq!(
t.slice_scatter0(&src, 2)?.to_vec2::<f32>()?,
&[
[0.0, 1.0, 2.0],
[3.0, 4.0, 5.0],
[100.0, 101.0, 102.0],
[103.0, 104.0, 105.0],
]
);
Ok(())
}
fn scatter_add(device: &Device) -> Result<()> {
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
assert_eq!(
@ -950,25 +747,6 @@ fn matmul(device: &Device) -> Result<()> {
Ok(())
}
fn broadcast_matmul(device: &Device) -> Result<()> {
let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?;
let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?;
let out = lhs.broadcast_matmul(&rhs)?;
assert_eq!(out.dims(), &[3, 6, 4, 2]);
for idx1 in 0..3 {
for idx2 in 0..6 {
let out = out.i((idx1, idx2))?;
let lhs = lhs.i((idx1, 0))?;
let rhs = rhs.i(idx2)?;
let out2 = lhs.matmul(&rhs);
let sum_diff2 = (out - out2)?.sqr()?.sum_all()?;
// With cuda, we see errors of up to ~1e-12.
assert!(sum_diff2.to_vec0::<f32>()? < 1e-6)
}
}
Ok(())
}
fn broadcasting(device: &Device) -> Result<()> {
let t1 = Tensor::arange(0f32, 24f32, device)?.reshape((4, 2, 3))?;
let t2 = Tensor::new(&[100f32, 200f32], device)?;
@ -1070,69 +848,27 @@ fn broadcasting(device: &Device) -> Result<()> {
Ok(())
}
fn randn(device: &Device) -> Result<()> {
let tensor = Tensor::randn(0f32, 1f32, (5, 3), device)?;
assert_eq!(tensor.dims(), [5, 3]);
let tensor = Tensor::rand(0f32, 1f32, (5, 3), device)?;
assert_eq!(tensor.dims(), [5, 3]);
Ok(())
}
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
test_device!(full, full_cpu, full_gpu, full_metal);
test_device!(arange, arange_cpu, arange_gpu, arange_metal);
test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal);
test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal);
test_device!(cat, cat_cpu, cat_gpu, cat_metal);
test_device!(sum, sum_cpu, sum_gpu, sum_metal);
test_device!(min, min_cpu, min_gpu, min_metal);
test_device!(max, max_cpu, max_gpu, max_metal);
test_device!(argmax, argmax_cpu, argmax_gpu, argmax_metal);
test_device!(argmin, argmin_cpu, argmin_gpu, argmin_metal);
test_device!(transpose, transpose_cpu, transpose_gpu, transpose_metal);
test_device!(unary_op, unary_op_cpu, unary_op_gpu, unary_op_metal);
test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal);
test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal);
test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal);
test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
test_device!(
broadcast_matmul,
broadcast_matmul_cpu,
broadcast_matmul_gpu,
broadcast_matmul_metal
);
test_device!(
broadcasting,
broadcasting_cpu,
broadcasting_gpu,
broadcasting_metal
);
test_device!(
index_select,
index_select_cpu,
index_select_gpu,
index_select_metal
);
test_device!(index_add, index_add_cpu, index_add_gpu, index_add_metal);
test_device!(gather, gather_cpu, gather_gpu, gather_metal);
test_device!(
scatter_add,
scatter_add_cpu,
scatter_add_gpu,
scatter_add_metal
);
test_device!(
slice_scatter,
slice_scatter_cpu,
slice_scatter_gpu,
slice_scatter_metal
);
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
test_device!(var, var_cpu, var_gpu, var_metal);
test_device!(zeros, zeros_cpu, zeros_gpu);
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
test_device!(narrow, narrow_cpu, narrow_gpu);
test_device!(broadcast, broadcast_cpu, broadcast_gpu);
test_device!(cat, cat_cpu, cat_gpu);
test_device!(sum, sum_cpu, sum_gpu);
test_device!(min, min_cpu, min_gpu);
test_device!(max, max_cpu, max_gpu);
test_device!(argmax, argmax_cpu, argmax_gpu);
test_device!(argmin, argmin_cpu, argmin_gpu);
test_device!(transpose, transpose_cpu, transpose_gpu);
test_device!(binary_op, binary_op_cpu, binary_op_gpu);
test_device!(embeddings, embeddings_cpu, embeddings_gpu);
test_device!(cmp, cmp_cpu, cmp_gpu);
test_device!(matmul, matmul_cpu, matmul_gpu);
test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu);
test_device!(index_select, index_select_cpu, index_select_gpu);
test_device!(index_add, index_add_cpu, index_add_gpu);
test_device!(gather, gather_cpu, gather_gpu);
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
// There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381
@ -1144,124 +880,3 @@ fn randn_hasneg() -> Result<()> {
}
Ok(())
}
#[test]
fn pad_with_same() -> Result<()> {
let t = Tensor::arange(1f32, 5f32, &Device::Cpu)?.reshape((2, 2))?;
let t0 = t.pad_with_same(0, 1, 2)?;
assert_eq!(
t0.to_vec2::<f32>()?,
[[1.0, 2.0], [1.0, 2.0], [3.0, 4.0], [3.0, 4.0], [3.0, 4.0]]
);
let t1 = t.pad_with_same(1, 1, 2)?;
assert_eq!(
t1.to_vec2::<f32>()?,
[[1.0, 1.0, 2.0, 2.0, 2.0], [3.0, 3.0, 4.0, 4.0, 4.0]]
);
Ok(())
}
#[test]
fn i64_abs() -> Result<()> {
let t = Tensor::new(&[-42i64, 1337], &Device::Cpu)?;
let t = t.abs()?;
assert_eq!(t.to_vec1::<i64>()?, [42, 1337]);
Ok(())
}
#[test]
fn tril_triu_eye() -> Result<()> {
let t = Tensor::tril2(4, DType::F32, &Device::Cpu)?;
assert_eq!(
t.to_vec2::<f32>()?,
[
[1.0, 0.0, 0.0, 0.0],
[1.0, 1.0, 0.0, 0.0],
[1.0, 1.0, 1.0, 0.0],
[1.0, 1.0, 1.0, 1.0]
],
);
let t = Tensor::triu2(4, DType::F32, &Device::Cpu)?;
assert_eq!(
t.to_vec2::<f32>()?,
[
[1.0, 1.0, 1.0, 1.0],
[0.0, 1.0, 1.0, 1.0],
[0.0, 0.0, 1.0, 1.0],
[0.0, 0.0, 0.0, 1.0]
]
);
let t = Tensor::eye(4, DType::F32, &Device::Cpu)?;
assert_eq!(
t.to_vec2::<f32>()?,
[
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0]
]
);
Ok(())
}
#[test]
fn cumsum() -> Result<()> {
let t = &[3f32, 1., 4., 1., 5.];
let t = Tensor::new(t, &Device::Cpu)?;
assert_eq!(t.cumsum(0)?.to_vec1::<f32>()?, [3., 4., 8., 9., 14.]);
let t = t.unsqueeze(1)?;
assert_eq!(
t.cumsum(0)?.to_vec2::<f32>()?,
[[3.0], [4.0], [8.0], [9.0], [14.0]]
);
assert_eq!(
t.cumsum(1)?.to_vec2::<f32>()?,
[[3.0], [1.0], [4.0], [1.0], [5.0]]
);
let t = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let t = Tensor::new(t, &Device::Cpu)?;
assert_eq!(
t.cumsum(1)?.to_vec2::<f32>()?,
[[3.0, 4.0, 8.0, 9.0, 14.0], [2.0, 3.0, 10.0, 18.0, 20.0]],
);
assert_eq!(
t.cumsum(0)?.to_vec2::<f32>()?,
[[3.0, 1.0, 4.0, 1.0, 5.0], [5.0, 2.0, 11.0, 9.0, 7.0]]
);
Ok(())
}
/// A helper function for floating point comparison. Both a and b must be 1D Tensor and contains the same amount of data.
/// Assertion passes if the difference of all pairs of a and b is smaller than epsilon.
fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> {
let a_vec: Vec<f64> = a.to_vec1()?;
let b_vec: Vec<f64> = b.to_vec1()?;
assert_eq!(a_vec.len(), b_vec.len());
for (a, b) in a_vec.iter().zip(b_vec.iter()) {
assert!((a - b).abs() < epsilon);
}
Ok(())
}
#[test]
fn log_sum_exp() -> Result<()> {
let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
let output = input.log_sum_exp(D::Minus1)?;
// The expectations obtained from pytorch.
let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?;
assert_close(&output, &expected, 0.00001)?;
Ok(())
}
#[test]
fn pow() -> Result<()> {
let lhs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
let rhs = (&lhs - 2.)?;
let res = lhs.pow(&rhs)?;
assert_eq!(
test_utils::to_vec2_round(&res, 4)?,
[[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0001]]
);
Ok(())
}

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -1,10 +1,15 @@
use crate::{Result, Tensor};
#![allow(dead_code)]
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle_core::{Result, Tensor};
#[macro_export]
macro_rules! test_device {
// TODO: Switch to generating the two last arguments automatically once concat_idents is
// stable. https://github.com/rust-lang/rust/issues/29599
($fn_name: ident, $test_cpu: ident, $test_cuda: ident, $test_metal: ident) => {
($fn_name: ident, $test_cpu: ident, $test_cuda: ident) => {
#[test]
fn $test_cpu() -> Result<()> {
$fn_name(&Device::Cpu)
@ -15,21 +20,9 @@ macro_rules! test_device {
fn $test_cuda() -> Result<()> {
$fn_name(&Device::new_cuda(0)?)
}
#[cfg(feature = "metal")]
#[test]
fn $test_metal() -> Result<()> {
$fn_name(&Device::new_metal(0)?)
}
};
}
pub fn to_vec0_round(t: &Tensor, digits: i32) -> Result<f32> {
let b = 10f32.powi(digits);
let t = t.to_vec0::<f32>()?;
Ok(f32::round(t * b) / b)
}
pub fn to_vec1_round(t: &Tensor, digits: i32) -> Result<Vec<f32>> {
let b = 10f32.powi(digits);
let t = t.to_vec1::<f32>()?;
@ -47,7 +40,7 @@ pub fn to_vec2_round(t: &Tensor, digits: i32) -> Result<Vec<Vec<f32>>> {
Ok(t)
}
pub fn to_vec3_round(t: &Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
pub fn to_vec3_round(t: Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
let b = 10f32.powi(digits);
let t = t.to_vec3::<f32>()?;
let t = t

Binary file not shown.

View File

@ -11,13 +11,10 @@ readme = "README.md"
[dependencies]
byteorder = { workspace = true }
candle = { workspace = true }
candle-nn = { workspace = true }
candle = { path = "../candle-core", version = "0.1.1", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.1.1" }
hf-hub = { workspace = true}
intel-mkl-src = { workspace = true, optional = true }
memmap2 = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
rand = { workspace = true }
thiserror = { workspace = true }
parquet = { workspace = true}
image = { workspace = true }

View File

@ -1,73 +0,0 @@
use hf_hub::{
api::sync::{Api, ApiRepo},
Repo, RepoType,
};
use parquet::file::reader::SerializedFileReader;
use std::fs::File;
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("ApiError : {0}")]
ApiError(#[from] hf_hub::api::sync::ApiError),
#[error("IoError : {0}")]
IoError(#[from] std::io::Error),
#[error("ParquetError : {0}")]
ParquetError(#[from] parquet::errors::ParquetError),
}
fn sibling_to_parquet(
rfilename: &str,
repo: &ApiRepo,
) -> Result<SerializedFileReader<File>, Error> {
let local = repo.get(rfilename)?;
let file = File::open(local)?;
let reader = SerializedFileReader::new(file)?;
Ok(reader)
}
pub fn from_hub(api: &Api, dataset_id: String) -> Result<Vec<SerializedFileReader<File>>, Error> {
let repo = Repo::with_revision(
dataset_id,
RepoType::Dataset,
"refs/convert/parquet".to_string(),
);
let repo = api.repo(repo);
let info = repo.info()?;
let files: Result<Vec<_>, _> = info
.siblings
.into_iter()
.filter_map(|s| -> Option<Result<_, _>> {
let filename = s.rfilename;
if filename.ends_with(".parquet") {
let reader_result = sibling_to_parquet(&filename, &repo);
Some(reader_result)
} else {
None
}
})
.collect();
let files = files?;
Ok(files)
}
#[cfg(test)]
mod tests {
use super::*;
use parquet::file::reader::FileReader;
#[test]
fn test_dataset() {
let api = Api::new().unwrap();
let files = from_hub(
&api,
"hf-internal-testing/dummy_image_text_data".to_string(),
)
.unwrap();
assert_eq!(files.len(), 1);
assert_eq!(files[0].metadata().file_metadata().num_rows(), 20);
}
}

View File

@ -1,6 +1,5 @@
//! Datasets & Dataloaders for Candle
pub mod batcher;
pub mod hub;
pub mod nlp;
pub mod vision;

Some files were not shown because too many files have changed in this diff Show More