mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Compare commits
1 Commits
bf16_metal
...
initialize
Author | SHA1 | Date | |
---|---|---|---|
0bb344f798 |
@ -1,8 +1,8 @@
|
||||
[build]
|
||||
[target.x86_64-unknown-linux-gnu]
|
||||
rustflags = ["-C", "target-cpu=native"]
|
||||
|
||||
[target.aarch64-apple-darwin]
|
||||
rustflags = ["-C", "target-cpu=native"]
|
||||
|
||||
[target.wasm32-unknown-unknown]
|
||||
rustflags = ["-C", "target-feature=+simd128"]
|
||||
|
||||
[target.x86_64-apple-darwin]
|
||||
rustflags = ["-C", "target-feature=-avx,-avx2"]
|
7
.github/dependabot.yml
vendored
7
.github/dependabot.yml
vendored
@ -1,7 +0,0 @@
|
||||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: "cargo"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
open-pull-requests-limit: 5
|
72
.github/workflows/ci_cuda.yaml
vendored
72
.github/workflows/ci_cuda.yaml
vendored
@ -5,15 +5,47 @@ on:
|
||||
pull_request:
|
||||
|
||||
jobs:
|
||||
start-runner:
|
||||
name: Start self-hosted EC2 runner
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
AWS_REGION: us-east-1
|
||||
EC2_AMI_ID: ami-03cfed9ea28f4b002
|
||||
EC2_INSTANCE_TYPE: g5.xlarge
|
||||
EC2_SUBNET_ID: subnet-931b34f5,subnet-ecb993cd,subnet-943dc2d8,subnet-45371f1a,subnet-ee93e0df,subnet-fddc3dfc
|
||||
EC2_SECURITY_GROUP: sg-030175c435ac141d6
|
||||
outputs:
|
||||
label: ${{ steps.start-ec2-runner.outputs.label }}
|
||||
ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }}
|
||||
steps:
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@v1
|
||||
with:
|
||||
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
||||
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||
aws-region: ${{ env.AWS_REGION }}
|
||||
- name: Start EC2 runner
|
||||
id: start-ec2-runner
|
||||
uses: philschmid/philschmid-ec2-github-runner@main
|
||||
with:
|
||||
mode: start
|
||||
github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
|
||||
ec2-image-id: ${{ env.EC2_AMI_ID }}
|
||||
ec2-instance-type: ${{ env.EC2_INSTANCE_TYPE }}
|
||||
subnet-id: ${{ env.EC2_SUBNET_ID }}
|
||||
security-group-id: ${{ env.EC2_SECURITY_GROUP }}
|
||||
aws-resource-tags: > # optional, requires additional permissions
|
||||
[
|
||||
{"Key": "Name", "Value": "ec2-tgi-github-runner"},
|
||||
{"Key": "GitHubRepository", "Value": "${{ github.repository }}"}
|
||||
]
|
||||
|
||||
test-cuda:
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
runs-on: [single-gpu, nvidia-gpu, t4, ci]
|
||||
container:
|
||||
image: nvidia/cuda:12.3.1-devel-ubuntu22.04
|
||||
options: --gpus 0
|
||||
if: ${{ github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }}
|
||||
needs: start-runner # required to start the main job when the runner is ready
|
||||
runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
|
||||
permissions:
|
||||
contents: write
|
||||
packages: write
|
||||
@ -24,10 +56,32 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
- name: Install dependencies
|
||||
run: apt-get update && apt install curl build-essential libssl-dev protobuf-compiler pkg-config -y
|
||||
- name: Install Rust Stable
|
||||
uses: actions-rust-lang/setup-rust-toolchain@v1
|
||||
run: curl https://sh.rustup.rs -sSf | sh -s -- -y
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- run: apt update -y && apt install libssl-dev -y
|
||||
- name: Test (cuda)
|
||||
run: cargo test --features cuda
|
||||
run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda
|
||||
stop-runner:
|
||||
name: Stop self-hosted EC2 runner
|
||||
needs:
|
||||
- start-runner
|
||||
- test-cuda
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
AWS_REGION: us-east-1
|
||||
if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs
|
||||
steps:
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@v1
|
||||
with:
|
||||
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
||||
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||
aws-region: ${{ env.AWS_REGION }}
|
||||
- name: Stop EC2 runner
|
||||
uses: philschmid/philschmid-ec2-github-runner@main
|
||||
with:
|
||||
mode: stop
|
||||
github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
|
||||
label: ${{ needs.start-runner.outputs.label }}
|
||||
ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}
|
||||
|
BIN
.github/workflows/maturin.yml
vendored
BIN
.github/workflows/maturin.yml
vendored
Binary file not shown.
68
.github/workflows/python.yml
vendored
68
.github/workflows/python.yml
vendored
@ -1,68 +0,0 @@
|
||||
name: PyO3-CI
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- candle-pyo3/**
|
||||
pull_request:
|
||||
paths:
|
||||
- candle-pyo3/**
|
||||
|
||||
jobs:
|
||||
build_and_test:
|
||||
name: Check everything builds & tests
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest] # For now, only test on Linux
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Install Rust
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: stable
|
||||
|
||||
- name: Install Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.11
|
||||
architecture: "x64"
|
||||
|
||||
- name: Cache Cargo Registry
|
||||
uses: actions/cache@v1
|
||||
with:
|
||||
path: ~/.cargo/registry
|
||||
key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }}
|
||||
|
||||
- name: Install Protoc
|
||||
uses: arduino/setup-protoc@v2
|
||||
with:
|
||||
version: "25.0"
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Install
|
||||
working-directory: ./candle-pyo3
|
||||
run: |
|
||||
python -m venv .env
|
||||
source .env/bin/activate
|
||||
pip install -U pip
|
||||
pip install pytest maturin black
|
||||
python -m maturin develop -r --features onnx
|
||||
|
||||
- name: Check style
|
||||
working-directory: ./candle-pyo3
|
||||
run: |
|
||||
source .env/bin/activate
|
||||
python stub.py --check
|
||||
black --check .
|
||||
|
||||
- name: Run tests
|
||||
working-directory: ./candle-pyo3
|
||||
run: |
|
||||
source .env/bin/activate
|
||||
python -m pytest -s -v tests
|
11
.gitignore
vendored
11
.gitignore
vendored
@ -23,16 +23,9 @@ flamegraph.svg
|
||||
*.dylib
|
||||
*.so
|
||||
*.swp
|
||||
*.swo
|
||||
trace-*.json
|
||||
|
||||
candle-wasm-examples/*/build
|
||||
candle-wasm-examples/*/*.bin
|
||||
candle-wasm-examples/*/*.jpeg
|
||||
candle-wasm-examples/*/audios/*.wav
|
||||
candle-wasm-examples/**/*.safetensors
|
||||
candle-wasm-examples/**/*.gguf
|
||||
candle-wasm-examples/*/*.wav
|
||||
candle-wasm-examples/*/*.safetensors
|
||||
candle-wasm-examples/*/package-lock.json
|
||||
candle-wasm-examples/**/config*.json
|
||||
.DS_Store
|
||||
.idea/*
|
||||
|
11
.vscode/settings.json
vendored
11
.vscode/settings.json
vendored
@ -1,11 +0,0 @@
|
||||
{
|
||||
"[python]": {
|
||||
"editor.defaultFormatter": "ms-python.black-formatter"
|
||||
},
|
||||
"python.formatting.provider": "none",
|
||||
"python.testing.pytestArgs": [
|
||||
"candle-pyo3"
|
||||
],
|
||||
"python.testing.unittestEnabled": false,
|
||||
"python.testing.pytestEnabled": true
|
||||
}
|
113
CHANGELOG.md
113
CHANGELOG.md
@ -1,113 +0,0 @@
|
||||
# Changelog
|
||||
This documents the main changes to the `candle` crate.
|
||||
|
||||
## v0.3.1 - Unreleased
|
||||
|
||||
### Added
|
||||
|
||||
### Modified
|
||||
|
||||
## v0.3.0 - 2023-10-01
|
||||
|
||||
### Added
|
||||
|
||||
- Added the Mistral 7b v0.1 model
|
||||
[983](https://github.com/huggingface/candle/pull/983).
|
||||
- Quantized version of the Mistral model
|
||||
[1009](https://github.com/huggingface/candle/pull/1009).
|
||||
- Add the gelu-erf op and activation function
|
||||
[969](https://github.com/huggingface/candle/pull/969).
|
||||
- Add the mixformer/phi-v1.5 model
|
||||
[930](https://github.com/huggingface/candle/pull/930).
|
||||
- Add the sclice-scatter op
|
||||
[927](https://github.com/huggingface/candle/pull/927).
|
||||
- Add the Wuerstchen diffusion model
|
||||
[911](https://github.com/huggingface/candle/pull/911).
|
||||
|
||||
### Modified
|
||||
|
||||
- Support for simd128 intrinsics in some quantized vecdots
|
||||
[982](https://github.com/huggingface/candle/pull/982).
|
||||
- Optimize the index-select cuda kernel
|
||||
[976](https://github.com/huggingface/candle/pull/976).
|
||||
- Self-contained safetensor wrappers
|
||||
[946](https://github.com/huggingface/candle/pull/946).
|
||||
|
||||
## v0.2.2 - 2023-09-18
|
||||
|
||||
### Added
|
||||
- Support for `top_p` sampling
|
||||
[819](https://github.com/huggingface/candle/pull/819).
|
||||
- T5 model including decoding
|
||||
[864](https://github.com/huggingface/candle/pull/864).
|
||||
- 1-d upsampling
|
||||
[839](https://github.com/huggingface/candle/pull/839).
|
||||
|
||||
### Modified
|
||||
- Bugfix for conv2d
|
||||
[820](https://github.com/huggingface/candle/pull/820).
|
||||
- Support tensor based indexing using `.i`
|
||||
[842](https://github.com/huggingface/candle/pull/842).
|
||||
|
||||
## v0.2.1 - 2023-09-11
|
||||
|
||||
### Added
|
||||
- Add some RNNs (GRU and LSTM) in `candle-nn`
|
||||
[674](https://github.com/huggingface/candle/pull/674),
|
||||
[688](https://github.com/huggingface/candle/pull/688).
|
||||
- gguf v2 support
|
||||
[725](https://github.com/huggingface/candle/pull/725).
|
||||
- Quantized llama example in Python using the pyo3 api
|
||||
[716](https://github.com/huggingface/candle/pull/716).
|
||||
- `candle-nn` layer for conv2d-transposed
|
||||
[760](https://github.com/huggingface/candle/pull/760).
|
||||
- Add the Segment-Anything Model (SAM) as an example
|
||||
[773](https://github.com/huggingface/candle/pull/773).
|
||||
- TinyViT backbone for the segment anything example
|
||||
[787](https://github.com/huggingface/candle/pull/787).
|
||||
- Shape with holes support
|
||||
[770](https://github.com/huggingface/candle/pull/770).
|
||||
|
||||
### Modified
|
||||
- Dilations are now supported in conv-transpose2d.
|
||||
[671](https://github.com/huggingface/candle/pull/671).
|
||||
- Interactive mode for the quantized model
|
||||
[690](https://github.com/huggingface/candle/pull/690).
|
||||
- Faster softmax operation
|
||||
[747](https://github.com/huggingface/candle/pull/747).
|
||||
- Faster convolution operations on CPU and CUDA via im2col
|
||||
[802](https://github.com/huggingface/candle/pull/802).
|
||||
- Moving some models to a more central location
|
||||
[796](https://github.com/huggingface/candle/pull/796).
|
||||
|
||||
## v0.2.0 - 2023-08-30
|
||||
|
||||
### Added
|
||||
- Add the powf op
|
||||
[664](https://github.com/huggingface/candle/pull/664).
|
||||
- Stable Diffusion XL support
|
||||
[647](https://github.com/huggingface/candle/pull/647).
|
||||
- Add the conv-transpose2d op
|
||||
[635](https://github.com/huggingface/candle/pull/635).
|
||||
- Refactor the VarBuilder api
|
||||
[627](https://github.com/huggingface/candle/pull/627).
|
||||
- Add some quantization command
|
||||
[625](https://github.com/huggingface/candle/pull/625).
|
||||
- Support more quantized types, e.g. Q2K, Q4K, Q5K...
|
||||
[586](https://github.com/huggingface/candle/pull/586).
|
||||
- Add pose estimation to the yolo example
|
||||
[589](https://github.com/huggingface/candle/pull/589).
|
||||
- Api to write GGUF files
|
||||
[585](https://github.com/huggingface/candle/pull/585).
|
||||
- Support more quantization types
|
||||
[580](https://github.com/huggingface/candle/pull/580).
|
||||
- Add EfficientNet as an example Computer Vision model
|
||||
[572](https://github.com/huggingface/candle/pull/572).
|
||||
- Add a group parameter to convolutions
|
||||
[566](https://github.com/huggingface/candle/pull/566).
|
||||
- New dtype: int64
|
||||
[563](https://github.com/huggingface/candle/pull/563).
|
||||
- Handling of the GGUF file format.
|
||||
[559](https://github.com/huggingface/candle/pull/559).
|
||||
|
||||
## v0.1.2 - 2023-08-21
|
42
Cargo.toml
42
Cargo.toml
@ -3,23 +3,19 @@ members = [
|
||||
"candle-core",
|
||||
"candle-datasets",
|
||||
"candle-examples",
|
||||
"candle-book",
|
||||
"candle-nn",
|
||||
"candle-pyo3",
|
||||
"candle-transformers",
|
||||
"candle-wasm-examples/*",
|
||||
"candle-wasm-tests",
|
||||
"candle-wasm-examples/llama2-c",
|
||||
"candle-wasm-examples/whisper",
|
||||
]
|
||||
exclude = [
|
||||
"candle-flash-attn",
|
||||
"candle-kernels",
|
||||
"candle-metal-kernels",
|
||||
"candle-onnx",
|
||||
"candle-flash-attn",
|
||||
"candle-kernels",
|
||||
]
|
||||
resolver = "2"
|
||||
|
||||
[workspace.package]
|
||||
version = "0.4.0"
|
||||
version = "0.1.1"
|
||||
edition = "2021"
|
||||
description = "Minimalist ML framework."
|
||||
repository = "https://github.com/huggingface/candle"
|
||||
@ -31,46 +27,32 @@ license = "MIT OR Apache-2.0"
|
||||
accelerate-src = { version = "0.3.2" }
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
byteorder = "1.4.3"
|
||||
candle = { path = "./candle-core", package = "candle-core", version = "0.4.0" }
|
||||
candle-datasets = { path = "./candle-datasets", version = "0.4.0" }
|
||||
candle-flash-attn = { path = "./candle-flash-attn", version = "0.4.0" }
|
||||
candle-kernels = { path = "./candle-kernels", version = "0.4.0" }
|
||||
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.4.0" }
|
||||
candle-nn = { path = "./candle-nn", version = "0.4.0" }
|
||||
candle-onnx = { path = "./candle-onnx", version = "0.4.0" }
|
||||
candle-transformers = { path = "./candle-transformers", version = "0.4.0" }
|
||||
clap = { version = "4.2.4", features = ["derive"] }
|
||||
criterion = { version = "0.5.1", default-features=false }
|
||||
cudarc = { version = "0.10.0", features = ["f16"] }
|
||||
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
|
||||
hf-hub = "0.3.0"
|
||||
cudarc = { version = "0.9.14", features = ["f16"] }
|
||||
# TODO: Switch back to the official gemm implementation once it has caught up.
|
||||
gemm = { version = "0.15.6", package = "candle-gemm" }
|
||||
hf-hub = "0.2.0"
|
||||
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
|
||||
image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
|
||||
imageproc = { version = "0.23.0", default-features = false }
|
||||
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
|
||||
libc = { version = "0.2.147" }
|
||||
log = "0.4"
|
||||
memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
|
||||
memmap2 = "0.7.1"
|
||||
num_cpus = "1.15.0"
|
||||
num-traits = "0.2.15"
|
||||
parquet = { version = "50.0.0" }
|
||||
rand = "0.8.5"
|
||||
rand_distr = "0.4.3"
|
||||
rayon = "1.7.0"
|
||||
rusttype = { version = "0.9", default-features = false }
|
||||
safetensors = "0.4.1"
|
||||
safetensors = "0.3.1"
|
||||
serde = { version = "1.0.171", features = ["derive"] }
|
||||
serde_plain = "1.0.2"
|
||||
serde_json = "1.0.99"
|
||||
thiserror = "1"
|
||||
tokenizers = { version = "0.15.0", default-features = false }
|
||||
tokenizers = { version = "0.13.4", default-features = false }
|
||||
tracing = "0.1.37"
|
||||
tracing-chrome = "0.7.1"
|
||||
tracing-subscriber = "0.3.7"
|
||||
wav = "1.0.0"
|
||||
yoke = { version = "0.7.2", features = ["derive"] }
|
||||
zip = { version = "0.6.6", default-features = false }
|
||||
metal = { version = "0.27.0", features = ["mps"]}
|
||||
|
||||
[profile.release-with-debug]
|
||||
inherits = "release"
|
||||
|
6
Makefile
6
Makefile
@ -1,5 +1,3 @@
|
||||
.PHONY: clean-ptx clean test
|
||||
|
||||
clean-ptx:
|
||||
find target -name "*.ptx" -type f -delete
|
||||
echo "" > candle-kernels/src/lib.rs
|
||||
@ -13,4 +11,8 @@ clean:
|
||||
test:
|
||||
cargo test
|
||||
|
||||
pyo3-test:
|
||||
cargo build --profile=release-with-debug --package candle-pyo3
|
||||
python3 candle-pyo3/test.py
|
||||
|
||||
all: test
|
||||
|
281
README.md
281
README.md
@ -1,5 +1,5 @@
|
||||
# candle
|
||||
[](https://discord.gg/hugging-face-879548962464493619)
|
||||
[](https://discord.com/channels/879548962464493619/1136218819447238726)
|
||||
[](https://crates.io/crates/candle-core)
|
||||
[](https://docs.rs/candle-core)
|
||||

|
||||
@ -7,167 +7,57 @@
|
||||
Candle is a minimalist ML framework for Rust with a focus on performance (including GPU support)
|
||||
and ease of use. Try our online demos:
|
||||
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
|
||||
[LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2),
|
||||
[T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm),
|
||||
[yolo](https://huggingface.co/spaces/lmz/candle-yolo),
|
||||
[Segment
|
||||
Anything](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).
|
||||
[llama2](https://huggingface.co/spaces/lmz/candle-llama2).
|
||||
|
||||
## Get started
|
||||
|
||||
Make sure that you have [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) correctly installed as described in [**Installation**](https://huggingface.github.io/candle/guide/installation.html).
|
||||
|
||||
Let's see how to run a simple matrix multiplication.
|
||||
Write the following to your `myapp/src/main.rs` file:
|
||||
```rust
|
||||
use candle_core::{Device, Tensor};
|
||||
let a = Tensor::randn(0f32, 1., (2, 3), &Device::Cpu)?;
|
||||
let b = Tensor::randn(0f32, 1., (3, 4), &Device::Cpu)?;
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let device = Device::Cpu;
|
||||
|
||||
let a = Tensor::randn(0f32, 1., (2, 3), &device)?;
|
||||
let b = Tensor::randn(0f32, 1., (3, 4), &device)?;
|
||||
|
||||
let c = a.matmul(&b)?;
|
||||
println!("{c}");
|
||||
Ok(())
|
||||
}
|
||||
let c = a.matmul(&b)?;
|
||||
println!("{c}");
|
||||
```
|
||||
|
||||
`cargo run` should display a tensor of shape `Tensor[[2, 4], f32]`.
|
||||
|
||||
|
||||
Having installed `candle` with Cuda support, simply define the `device` to be on GPU:
|
||||
|
||||
```diff
|
||||
- let device = Device::Cpu;
|
||||
+ let device = Device::new_cuda(0)?;
|
||||
```
|
||||
|
||||
For more advanced examples, please have a look at the following section.
|
||||
|
||||
## Check out our examples
|
||||
|
||||
These online demos run entirely in your browser:
|
||||
- [yolo](https://huggingface.co/spaces/lmz/candle-yolo): pose estimation and
|
||||
object recognition.
|
||||
- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): speech recognition.
|
||||
- [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2): text generation.
|
||||
- [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm): text generation.
|
||||
- [Phi-1.5, and Phi-2](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation.
|
||||
- [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm): Image segmentation.
|
||||
- [BLIP](https://huggingface.co/spaces/radames/Candle-BLIP-Image-Captioning): image captioning.
|
||||
|
||||
We also provide a some command line based examples using state of the art models:
|
||||
|
||||
- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM, includes
|
||||
the SOLAR-10.7B variant.
|
||||
- [Falcon](./candle-examples/examples/falcon/): general LLM.
|
||||
- [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
|
||||
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
|
||||
pre-trained on 1T tokens of English and code datasets. Also supports
|
||||
StableLM-2, a 1.6b LLM trained on 2T tokens, as well as the code variants.
|
||||
- [Mamba](./candle-examples/examples/mamba/): an inference only
|
||||
implementation of the Mamba state space model.
|
||||
- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
|
||||
better performance than all publicly available 13b models as of 2023-09-28.
|
||||
- [Mixtral8x7b-v0.1](./candle-examples/examples/mixtral/): a sparse mixture of
|
||||
experts 8x7b general LLM with better performance than a Llama 2 70B model with
|
||||
much faster inference.
|
||||
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
|
||||
- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
|
||||
- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
|
||||
(English/Chinese) general LLMs with 6b and 34b parameters.
|
||||
- [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of
|
||||
the LLaMA model using the same quantization techniques as
|
||||
[llama.cpp](https://github.com/ggerganov/llama.cpp).
|
||||
|
||||
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/quantized/assets/aoc.gif" width="600">
|
||||
|
||||
- [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
|
||||
image generative model, support for the 1.5, 2.1, SDXL 1.0 and Turbo versions.
|
||||
|
||||
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg" width="200">
|
||||
|
||||
- [Wuerstchen](./candle-examples/examples/wuerstchen/): another text to
|
||||
image generative model.
|
||||
|
||||
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/wuerstchen/assets/cat.jpg" width="200">
|
||||
|
||||
- [yolo-v3](./candle-examples/examples/yolo-v3/) and
|
||||
[yolo-v8](./candle-examples/examples/yolo-v8/): object detection and pose
|
||||
estimation models.
|
||||
|
||||
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/yolo-v8/assets/bike.od.jpg" width="200"><img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/yolo-v8/assets/bike.pose.jpg" width="200">
|
||||
- [segment-anything](./candle-examples/examples/segment-anything/): image
|
||||
segmentation model with prompt.
|
||||
|
||||
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">
|
||||
Check out our [examples](./candle-examples/examples/):
|
||||
|
||||
- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
|
||||
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/),
|
||||
[JinaBert](./candle-examples/examples/jina-bert/) : useful for sentence embeddings.
|
||||
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
|
||||
using self-supervision (can be used for imagenet classification, depth
|
||||
evaluation, segmentation).
|
||||
- [VGG](./candle-examples/examples/vgg/),
|
||||
[RepVGG](./candle-examples/examples/repvgg): computer vision models.
|
||||
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
|
||||
generate captions for an image.
|
||||
- [TrOCR](./candle-examples/examples/trocr/): a transformer OCR model, with
|
||||
dedicated submodels for hand-writing and printed recognition.
|
||||
- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
|
||||
model, generates the translated text from the input text.
|
||||
- [Llama and Llama-v2](./candle-examples/examples/llama/): general LLM.
|
||||
- [Falcon](./candle-examples/examples/falcon/): general LLM.
|
||||
- [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
|
||||
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code
|
||||
generation.
|
||||
- [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
|
||||
image generative model, yet to be optimized.
|
||||
|
||||
Run them using commands like:
|
||||
Run them using the following commands:
|
||||
```
|
||||
cargo run --example quantized --release
|
||||
cargo run --example whisper --release
|
||||
cargo run --example llama --release
|
||||
cargo run --example falcon --release
|
||||
cargo run --example bert --release
|
||||
cargo run --example bigcode --release
|
||||
cargo run --example stable-diffusion --release --features image -- --prompt "a rusty robot holding a fire torch"
|
||||
```
|
||||
|
||||
In order to use **CUDA** add `--features cuda` to the example command line. If
|
||||
you have cuDNN installed, use `--features cudnn` for even more speedups.
|
||||
In order to use **CUDA** add `--features cuda` to the example command line.
|
||||
|
||||
There are also some wasm examples for whisper and
|
||||
[llama2.c](https://github.com/karpathy/llama2.c). You can either build them with
|
||||
`trunk` or try them online:
|
||||
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
|
||||
[llama2](https://huggingface.co/spaces/lmz/candle-llama2),
|
||||
[T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm),
|
||||
[Phi-1.5, and Phi-2](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm),
|
||||
[Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).
|
||||
[llama2](https://huggingface.co/spaces/lmz/candle-llama2).
|
||||
|
||||
For LLaMA2, run the following command to retrieve the weight files and start a
|
||||
For llama2, run the following command to retrieve the weight files and start a
|
||||
test server:
|
||||
```bash
|
||||
cd candle-wasm-examples/llama2-c
|
||||
wget https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/model.bin
|
||||
wget https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/tokenizer.json
|
||||
trunk serve --release --port 8081
|
||||
trunk serve --release --public-url /candle-llama2/ --port 8081
|
||||
```
|
||||
And then head over to
|
||||
[http://localhost:8081/](http://localhost:8081/).
|
||||
|
||||
<!--- ANCHOR: useful_libraries --->
|
||||
|
||||
## Useful External Resources
|
||||
- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): A
|
||||
very detailed tutorial showing how to convert a PyTorch model to Candle.
|
||||
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): Efficient and
|
||||
ergonomic LoRA implementation for Candle. `candle-lora` has
|
||||
out-of-the-box LoRA support for many models from Candle, which can be found
|
||||
[here](https://github.com/EricLBuehler/candle-lora/tree/master/candle-lora-transformers/examples).
|
||||
- [`optimisers`](https://github.com/KGrewal1/optimisers): A collection of optimisers
|
||||
including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
|
||||
- [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
|
||||
serving local LLMs including an OpenAI compatible API server.
|
||||
- [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle.
|
||||
- [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
|
||||
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
|
||||
- [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem.
|
||||
|
||||
If you have an addition to this list, please submit a pull request.
|
||||
|
||||
<!--- ANCHOR_END: useful_libraries --->
|
||||
[http://localhost:8081/candle-llama2](http://localhost:8081/candle-llama2).
|
||||
|
||||
<!--- ANCHOR: features --->
|
||||
|
||||
@ -180,42 +70,11 @@ If you have an addition to this list, please submit a pull request.
|
||||
- Optimized CPU backend with optional MKL support for x86 and Accelerate for macs.
|
||||
- CUDA backend for efficiently running on GPUs, multiple GPU distribution via NCCL.
|
||||
- WASM support, run your models in a browser.
|
||||
- Included models.
|
||||
- Language Models.
|
||||
- LLaMA v1 and v2 with variants such as SOLAR-10.7B.
|
||||
- Falcon.
|
||||
- StarCoder.
|
||||
- Phi 1, 1.5, and 2.
|
||||
- Mamba, Minimal Mamba
|
||||
- Mistral 7b v0.1.
|
||||
- Mixtral 8x7b v0.1.
|
||||
- StableLM-3B-4E1T, StableLM-2-1.6B, Stable-Code-3B.
|
||||
- Replit-code-v1.5-3B.
|
||||
- Bert.
|
||||
- Yi-6B and Yi-34B.
|
||||
- Quantized LLMs.
|
||||
- Llama 7b, 13b, 70b, as well as the chat and code variants.
|
||||
- Mistral 7b, and 7b instruct.
|
||||
- Mixtral 8x7b.
|
||||
- Zephyr 7b a and b (Mistral-7b based).
|
||||
- OpenChat 3.5 (Mistral-7b based).
|
||||
- Text to text.
|
||||
- T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
|
||||
- Marian MT (Machine Translation).
|
||||
- Whisper (multi-lingual support).
|
||||
- Text to image.
|
||||
- Stable Diffusion v1.5, v2.1, XL v1.0.
|
||||
- Wurstchen v2.
|
||||
- Image to text.
|
||||
- BLIP.
|
||||
- TrOCR.
|
||||
- Computer Vision Models.
|
||||
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT.
|
||||
- yolo-v3, yolo-v8.
|
||||
- Segment-Anything Model (SAM).
|
||||
- File formats: load models from safetensors, npz, ggml, or PyTorch files.
|
||||
- Model support out of the box.
|
||||
- LLMs: Llama v1 and v2, Falcon, StarCoder.
|
||||
- Whisper.
|
||||
- Stable Diffusion.
|
||||
- Serverless (on CPU), small and fast deployments.
|
||||
- Quantization support using the llama.cpp quantized types.
|
||||
|
||||
<!--- ANCHOR_END: features --->
|
||||
|
||||
@ -232,7 +91,7 @@ Cheatsheet:
|
||||
| Operations | `tensor.view((2, 2))` | `tensor.reshape((2, 2))?` |
|
||||
| Operations | `a.matmul(b)` | `a.matmul(&b)?` |
|
||||
| Arithmetic | `a + b` | `&a + &b` |
|
||||
| Device | `tensor.to(device="cuda")` | `tensor.to_device(&Device::new_cuda(0)?)?` |
|
||||
| Device | `tensor.to(device="cuda")` | `tensor.to_device(&Device::Cuda(0))?` |
|
||||
| Dtype | `tensor.to(dtype=torch.float16)` | `tensor.to_dtype(&DType::F16)?` |
|
||||
| Saving | `torch.save({"A": A}, "model.bin")` | `candle::safetensors::save(&HashMap::from([("A", A)]), "model.safetensors")?` |
|
||||
| Loading | `weights = torch.load("model.bin")` | `candle::safetensors::load("model.safetensors", &device)` |
|
||||
@ -249,7 +108,6 @@ Cheatsheet:
|
||||
- [candle-datasets](./candle-datasets/): Datasets and data loaders.
|
||||
- [candle-transformers](./candle-transformers): transformers-related utilities.
|
||||
- [candle-flash-attn](./candle-flash-attn): Flash attention v2 layer.
|
||||
- [candle-onnx](./candle-onnx/): ONNX model evaluation.
|
||||
|
||||
## FAQ
|
||||
|
||||
@ -286,97 +144,34 @@ Finally, Rust is cool! A lot of the HF ecosystem already has Rust crates, like [
|
||||
#### Missing symbols when compiling with the mkl feature.
|
||||
|
||||
If you get some missing symbols when compiling binaries/tests using the mkl
|
||||
or accelerate features, e.g. for mkl you get:
|
||||
features, e.g.:
|
||||
```
|
||||
= note: /usr/bin/ld: (....o): in function `blas::sgemm':
|
||||
.../blas-0.22.0/src/lib.rs:1944: undefined reference to `sgemm_' collect2: error: ld returned 1 exit status
|
||||
|
||||
= note: some `extern` functions couldn't be found; some native libraries may need to be installed or have their path specified
|
||||
= note: use the `-l` flag to specify native libraries to link
|
||||
= note: use the `cargo:rustc-link-lib` directive to specify the native libraries to link with Cargo
|
||||
```
|
||||
or for accelerate:
|
||||
```
|
||||
Undefined symbols for architecture arm64:
|
||||
"_dgemm_", referenced from:
|
||||
candle_core::accelerate::dgemm::h1b71a038552bcabe in libcandle_core...
|
||||
"_sgemm_", referenced from:
|
||||
candle_core::accelerate::sgemm::h2cf21c592cba3c47 in libcandle_core...
|
||||
ld: symbol(s) not found for architecture arm64
|
||||
= note: use the `cargo:rustc-link-lib` directive to specify the native libraries to link with Cargo (see https://doc.rust-lang.org/cargo/reference/build-scripts.html#cargorustc-link-libkindname)
|
||||
```
|
||||
|
||||
This is likely due to a missing linker flag that was needed to enable the mkl library. You
|
||||
can try adding the following for mkl at the top of your binary:
|
||||
```rust
|
||||
can try adding the following at the top of your binary:
|
||||
```
|
||||
extern crate intel_mkl_src;
|
||||
```
|
||||
or for accelerate:
|
||||
```rust
|
||||
extern crate accelerate_src;
|
||||
```
|
||||
|
||||
#### Cannot run the LLaMA examples: access to source requires login credentials
|
||||
#### Cannot run llama example : access to source requires login credentials
|
||||
|
||||
```
|
||||
Error: request error: https://huggingface.co/meta-llama/Llama-2-7b-hf/resolve/main/tokenizer.json: status code 401
|
||||
```
|
||||
|
||||
This is likely because you're not permissioned for the LLaMA-v2 model. To fix
|
||||
this, you have to register on the huggingface-hub, accept the [LLaMA-v2 model
|
||||
This is likely because you're not permissioned for the llama-v2 model. To fix
|
||||
this, you have to register on the huggingface-hub, accept the [llama-v2 model
|
||||
conditions](https://huggingface.co/meta-llama/Llama-2-7b-hf), and set up your
|
||||
authentication token. See issue
|
||||
[#350](https://github.com/huggingface/candle/issues/350) for more details.
|
||||
|
||||
#### Missing cute/cutlass headers when compiling flash-attn
|
||||
|
||||
```
|
||||
In file included from kernels/flash_fwd_launch_template.h:11:0,
|
||||
from kernels/flash_fwd_hdim224_fp16_sm80.cu:5:
|
||||
kernels/flash_fwd_kernel.h:8:10: fatal error: cute/algorithm/copy.hpp: No such file or directory
|
||||
#include <cute/algorithm/copy.hpp>
|
||||
^~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
compilation terminated.
|
||||
Error: nvcc error while compiling:
|
||||
```
|
||||
[cutlass](https://github.com/NVIDIA/cutlass) is provided as a git submodule so you may want to run the following command to check it in properly.
|
||||
```bash
|
||||
git submodule update --init
|
||||
```
|
||||
|
||||
#### Compiling with flash-attention fails
|
||||
|
||||
```
|
||||
/usr/include/c++/11/bits/std_function.h:530:146: error: parameter packs not expanded with ‘...’:
|
||||
```
|
||||
|
||||
This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the CANDLE_NVCC_CCBIN environment variable.
|
||||
```
|
||||
env CANDLE_NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
|
||||
```
|
||||
|
||||
#### Linking error on windows when running rustdoc or mdbook tests
|
||||
|
||||
```
|
||||
Couldn't compile the test.
|
||||
---- .\candle-book\src\inference\hub.md - Using_the_hub::Using_in_a_real_model_ (line 50) stdout ----
|
||||
error: linking with `link.exe` failed: exit code: 1181
|
||||
//very long chain of linking
|
||||
= note: LINK : fatal error LNK1181: cannot open input file 'windows.0.48.5.lib'
|
||||
```
|
||||
|
||||
Make sure you link all native libraries that might be located outside a project target, e.g., to run mdbook tests, you should run:
|
||||
|
||||
```
|
||||
mdbook test candle-book -L .\target\debug\deps\ `
|
||||
-L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.42.2\lib `
|
||||
-L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.48.5\lib
|
||||
```
|
||||
|
||||
#### Extremely slow model load time with WSL
|
||||
|
||||
This may be caused by the models being loaded from `/mnt/c`, more details on
|
||||
[stackoverflow](https://stackoverflow.com/questions/68972448/why-is-wsl-extremely-slow-when-compared-with-native-windows-npm-yarn-processing).
|
||||
|
||||
#### Tracking down errors
|
||||
|
||||
You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle
|
||||
|
@ -1,49 +0,0 @@
|
||||
[package]
|
||||
name = "candle-book"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
description.workspace = true
|
||||
repository.workspace = true
|
||||
keywords.workspace = true
|
||||
categories.workspace = true
|
||||
license.workspace = true
|
||||
readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
candle = { workspace = true }
|
||||
candle-datasets = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
candle-transformers = { workspace = true }
|
||||
candle-flash-attn = { workspace = true, optional = true }
|
||||
safetensors = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
num-traits = { workspace = true }
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
cudarc = { workspace = true, optional = true }
|
||||
half = { workspace = true, optional = true }
|
||||
image = { workspace = true, optional = true }
|
||||
anyhow = { workspace = true }
|
||||
tokio = "1.29.1"
|
||||
|
||||
[dev-dependencies]
|
||||
byteorder = { workspace = true }
|
||||
hf-hub = { workspace = true, features=["tokio"]}
|
||||
clap = { workspace = true }
|
||||
memmap2 = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
tokenizers = { workspace = true, features = ["onig"] }
|
||||
tracing = { workspace = true }
|
||||
tracing-chrome = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
wav = { workspace = true }
|
||||
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
|
||||
parquet = { workspace = true }
|
||||
image = { workspace = true }
|
||||
|
||||
[build-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
|
||||
[features]
|
||||
default = []
|
@ -10,14 +10,9 @@
|
||||
|
||||
# Reference Guide
|
||||
|
||||
- [Running a model](inference/inference.md)
|
||||
- [Running a model](inference/README.md)
|
||||
- [Using the hub](inference/hub.md)
|
||||
- [Error management](error_manage.md)
|
||||
- [Training](training/training.md)
|
||||
- [Simplified](training/simplified.md)
|
||||
- [MNIST](training/mnist.md)
|
||||
- [Fine-tuning]()
|
||||
- [Serialization]()
|
||||
- [Error management]()
|
||||
- [Advanced Cuda usage]()
|
||||
- [Writing a custom kernel]()
|
||||
- [Porting a custom kernel]()
|
||||
@ -26,3 +21,7 @@
|
||||
- [Creating a WASM app]()
|
||||
- [Creating a REST api webserver]()
|
||||
- [Creating a desktop Tauri app]()
|
||||
- [Training]()
|
||||
- [MNIST]()
|
||||
- [Fine-tuning]()
|
||||
- [Serialization]()
|
||||
|
@ -29,7 +29,7 @@ After adding `RUST_BACKTRACE=1`:
|
||||
Error: WithBacktrace { inner: ShapeMismatchBinaryOp { lhs: [1, 784], rhs: [1, 784], op: "matmul" }, backtrace: Backtrace [{ fn: "candle::error::Error::bt", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/error.rs", line: 200 }, { fn: "candle::tensor::Tensor::matmul", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/tensor.rs", line: 816 }, { fn: "myapp::main", file: "./src/main.rs", line: 29 }, { fn: "core::ops::function::FnOnce::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 250 }, { fn: "std::sys_common::backtrace::__rust_begin_short_backtrace", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/sys_common/backtrace.rs", line: 135 }, { fn: "std::rt::lang_start::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 166 }, { fn: "core::ops::function::impls::<impl core::ops::function::FnOnce<A> for &F>::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 284 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::rt::lang_start", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 165 }, { fn: "main" }, { fn: "__libc_start_main" }, { fn: "_start" }] }
|
||||
```
|
||||
|
||||
Not super pretty at the moment, but we can see error occurred on `{ fn: "myapp::main", file: "./src/main.rs", line: 29 }`
|
||||
Not super pretty at the moment, but we can see error occured on `{ fn: "myapp::main", file: "./src/main.rs", line: 29 }`
|
||||
|
||||
|
||||
Another thing to note, is that since Rust is compiled it is not necessarily as easy to recover proper stacktraces
|
||||
|
@ -6,7 +6,7 @@ Open `src/main.rs` and fill in this content:
|
||||
|
||||
```rust
|
||||
# extern crate candle_core;
|
||||
use candle_core::{Device, Result, Tensor};
|
||||
use candle_core::{DType, Device, Result, Tensor};
|
||||
|
||||
struct Model {
|
||||
first: Tensor,
|
||||
@ -25,11 +25,11 @@ fn main() -> Result<()> {
|
||||
// Use Device::new_cuda(0)?; to use the GPU.
|
||||
let device = Device::Cpu;
|
||||
|
||||
let first = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
|
||||
let second = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
|
||||
let first = Tensor::zeros((784, 100), DType::F32, &device)?;
|
||||
let second = Tensor::zeros((100, 10), DType::F32, &device)?;
|
||||
let model = Model { first, second };
|
||||
|
||||
let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
|
||||
let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
|
||||
|
||||
let digit = model.forward(&dummy_image)?;
|
||||
println!("Digit {digit:?} digit");
|
||||
@ -50,7 +50,7 @@ the classical `Linear` layer. We can do as such
|
||||
|
||||
```rust
|
||||
# extern crate candle_core;
|
||||
# use candle_core::{Device, Result, Tensor};
|
||||
# use candle_core::{DType, Device, Result, Tensor};
|
||||
struct Linear{
|
||||
weight: Tensor,
|
||||
bias: Tensor,
|
||||
@ -80,7 +80,7 @@ This will change the model running code into a new function
|
||||
|
||||
```rust
|
||||
# extern crate candle_core;
|
||||
# use candle_core::{Device, Result, Tensor};
|
||||
# use candle_core::{DType, Device, Result, Tensor};
|
||||
# struct Linear{
|
||||
# weight: Tensor,
|
||||
# bias: Tensor,
|
||||
@ -110,15 +110,15 @@ fn main() -> Result<()> {
|
||||
let device = Device::cuda_if_available(0)?;
|
||||
|
||||
// Creating a dummy model
|
||||
let weight = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
|
||||
let bias = Tensor::randn(0f32, 1.0, (100, ), &device)?;
|
||||
let weight = Tensor::zeros((784, 100), DType::F32, &device)?;
|
||||
let bias = Tensor::zeros((100, ), DType::F32, &device)?;
|
||||
let first = Linear{weight, bias};
|
||||
let weight = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
|
||||
let bias = Tensor::randn(0f32, 1.0, (10, ), &device)?;
|
||||
let weight = Tensor::zeros((100, 10), DType::F32, &device)?;
|
||||
let bias = Tensor::zeros((10, ), DType::F32, &device)?;
|
||||
let second = Linear{weight, bias};
|
||||
let model = Model { first, second };
|
||||
|
||||
let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
|
||||
let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
|
||||
|
||||
// Inference on the model
|
||||
let digit = model.forward(&dummy_image)?;
|
||||
@ -146,8 +146,8 @@ And rewrite our examples using it
|
||||
```rust
|
||||
# extern crate candle_core;
|
||||
# extern crate candle_nn;
|
||||
use candle_core::{Device, Result, Tensor};
|
||||
use candle_nn::{Linear, Module};
|
||||
use candle_core::{DType, Device, Result, Tensor};
|
||||
use candle_nn::Linear;
|
||||
|
||||
struct Model {
|
||||
first: Linear,
|
||||
@ -167,15 +167,15 @@ fn main() -> Result<()> {
|
||||
let device = Device::Cpu;
|
||||
|
||||
// This has changed (784, 100) -> (100, 784) !
|
||||
let weight = Tensor::randn(0f32, 1.0, (100, 784), &device)?;
|
||||
let bias = Tensor::randn(0f32, 1.0, (100, ), &device)?;
|
||||
let weight = Tensor::zeros((100, 784), DType::F32, &device)?;
|
||||
let bias = Tensor::zeros((100, ), DType::F32, &device)?;
|
||||
let first = Linear::new(weight, Some(bias));
|
||||
let weight = Tensor::randn(0f32, 1.0, (10, 100), &device)?;
|
||||
let bias = Tensor::randn(0f32, 1.0, (10, ), &device)?;
|
||||
let weight = Tensor::zeros((10, 100), DType::F32, &device)?;
|
||||
let bias = Tensor::zeros((10, ), DType::F32, &device)?;
|
||||
let second = Linear::new(weight, Some(bias));
|
||||
let model = Model { first, second };
|
||||
|
||||
let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
|
||||
let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
|
||||
|
||||
let digit = model.forward(&dummy_image)?;
|
||||
println!("Digit {digit:?} digit");
|
||||
@ -188,8 +188,8 @@ Feel free to modify this example to use `Conv2d` to create a classical convnet i
|
||||
|
||||
Now that we have the running dummy code we can get to more advanced topics:
|
||||
|
||||
- [For PyTorch users](../guide/cheatsheet.md)
|
||||
- [Running existing models](../inference/inference.md)
|
||||
- [Training models](../training/training.md)
|
||||
- [For PyTorch users](./guide/cheatsheet.md)
|
||||
- [Running existing models](./inference/README.md)
|
||||
- [Training models](./training/README.md)
|
||||
|
||||
|
||||
|
@ -1,46 +1,6 @@
|
||||
# Installation
|
||||
|
||||
**With Cuda support**:
|
||||
|
||||
1. First, make sure that Cuda is correctly installed.
|
||||
- `nvcc --version` should print information about your Cuda compiler driver.
|
||||
- `nvidia-smi --query-gpu=compute_cap --format=csv` should print your GPUs compute capability, e.g. something
|
||||
like:
|
||||
|
||||
```bash
|
||||
compute_cap
|
||||
8.9
|
||||
```
|
||||
|
||||
You can also compile the Cuda kernels for a specific compute cap using the
|
||||
`CUDA_COMPUTE_CAP=<compute cap>` environment variable.
|
||||
|
||||
If any of the above commands errors out, please make sure to update your Cuda version.
|
||||
|
||||
2. Create a new app and add [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) with Cuda support.
|
||||
|
||||
Start by creating a new cargo:
|
||||
|
||||
```bash
|
||||
cargo new myapp
|
||||
cd myapp
|
||||
```
|
||||
|
||||
Make sure to add the `candle-core` crate with the cuda feature:
|
||||
|
||||
```bash
|
||||
cargo add --git https://github.com/huggingface/candle.git candle-core --features "cuda"
|
||||
```
|
||||
|
||||
Run `cargo build` to make sure everything can be correctly built.
|
||||
|
||||
```bash
|
||||
cargo build
|
||||
```
|
||||
|
||||
**Without Cuda support**:
|
||||
|
||||
Create a new app and add [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) as follows:
|
||||
Start by creating a new app:
|
||||
|
||||
```bash
|
||||
cargo new myapp
|
||||
@ -48,12 +8,17 @@ cd myapp
|
||||
cargo add --git https://github.com/huggingface/candle.git candle-core
|
||||
```
|
||||
|
||||
Finally, run `cargo build` to make sure everything can be correctly built.
|
||||
At this point, candle will be built **without** CUDA support.
|
||||
To get CUDA support use the `cuda` feature
|
||||
```bash
|
||||
cargo add --git https://github.com/huggingface/candle.git candle-core --features cuda
|
||||
```
|
||||
|
||||
You can check everything works properly:
|
||||
|
||||
```bash
|
||||
cargo build
|
||||
```
|
||||
|
||||
**With mkl support**
|
||||
|
||||
You can also see the `mkl` feature which could be interesting to get faster inference on CPU. [Using mkl](./advanced/mkl.md)
|
||||
|
@ -39,7 +39,7 @@ cargo add hf-hub --features tokio
|
||||
```rust,ignore
|
||||
# This is tested directly in examples crate because it needs external dependencies unfortunately:
|
||||
# See [this](https://github.com/rust-lang/mdBook/issues/706)
|
||||
{{#include ../lib.rs:book_hub_1}}
|
||||
{{#include ../../../candle-examples/src/lib.rs:book_hub_1}}
|
||||
```
|
||||
|
||||
|
||||
@ -58,7 +58,7 @@ Now that we have our weights, we can use them in our bert architecture:
|
||||
#
|
||||
# let weights = repo.get("model.safetensors").unwrap();
|
||||
use candle_core::{Device, Tensor, DType};
|
||||
use candle_nn::{Linear, Module};
|
||||
use candle_nn::Linear;
|
||||
|
||||
let weights = candle_core::safetensors::load(weights, &Device::Cpu).unwrap();
|
||||
|
||||
@ -81,7 +81,7 @@ For more efficient loading, instead of reading the file, you could use [`memmap2
|
||||
and will definitely be slower on network mounted disk, because it will issue more read calls.
|
||||
|
||||
```rust,ignore
|
||||
{{#include ../lib.rs:book_hub_2}}
|
||||
{{#include ../../../candle-examples/src/lib.rs:book_hub_2}}
|
||||
```
|
||||
|
||||
**Note**: This operation is **unsafe**. [See the safety notice](https://docs.rs/memmap2/latest/memmap2/struct.Mmap.html#safety).
|
||||
@ -100,5 +100,5 @@ cargo add safetensors
|
||||
|
||||
|
||||
```rust,ignore
|
||||
{{#include ../lib.rs:book_hub_3}}
|
||||
{{#include ../../../candle-examples/src/lib.rs:book_hub_3}}
|
||||
```
|
||||
|
@ -1,199 +0,0 @@
|
||||
#[cfg(test)]
|
||||
pub mod simplified;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::Result;
|
||||
use candle::{DType, Device, Tensor};
|
||||
use parquet::file::reader::SerializedFileReader;
|
||||
|
||||
// NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
|
||||
#[rustfmt::skip]
|
||||
#[tokio::test]
|
||||
async fn book_hub_1() {
|
||||
// ANCHOR: book_hub_1
|
||||
use candle::Device;
|
||||
use hf_hub::api::tokio::Api;
|
||||
|
||||
let api = Api::new().unwrap();
|
||||
let repo = api.model("bert-base-uncased".to_string());
|
||||
|
||||
let weights_filename = repo.get("model.safetensors").await.unwrap();
|
||||
|
||||
let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap();
|
||||
// ANCHOR_END: book_hub_1
|
||||
assert_eq!(weights.len(), 206);
|
||||
}
|
||||
|
||||
#[rustfmt::skip]
|
||||
#[test]
|
||||
fn book_hub_2() {
|
||||
{
|
||||
// ANCHOR: book_hub_2
|
||||
use candle::Device;
|
||||
use hf_hub::api::sync::Api;
|
||||
use memmap2::Mmap;
|
||||
use std::fs;
|
||||
|
||||
let api = Api::new().unwrap();
|
||||
let repo = api.model("bert-base-uncased".to_string());
|
||||
let weights_filename = repo.get("model.safetensors").unwrap();
|
||||
|
||||
let file = fs::File::open(weights_filename).unwrap();
|
||||
let mmap = unsafe { Mmap::map(&file).unwrap() };
|
||||
let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap();
|
||||
// ANCHOR_END: book_hub_2
|
||||
assert_eq!(weights.len(), 206);
|
||||
}
|
||||
|
||||
// #[rustfmt::skip]
|
||||
// #[test]
|
||||
// fn book_hub_3() {
|
||||
{
|
||||
// ANCHOR: book_hub_3
|
||||
use candle::{DType, Device, Tensor};
|
||||
use hf_hub::api::sync::Api;
|
||||
use memmap2::Mmap;
|
||||
use safetensors::slice::IndexOp;
|
||||
use safetensors::SafeTensors;
|
||||
use std::fs;
|
||||
|
||||
let api = Api::new().unwrap();
|
||||
let repo = api.model("bert-base-uncased".to_string());
|
||||
let weights_filename = repo.get("model.safetensors").unwrap();
|
||||
|
||||
let file = fs::File::open(weights_filename).unwrap();
|
||||
let mmap = unsafe { Mmap::map(&file).unwrap() };
|
||||
|
||||
// Use safetensors directly
|
||||
let tensors = SafeTensors::deserialize(&mmap[..]).unwrap();
|
||||
let view = tensors
|
||||
.tensor("bert.encoder.layer.0.attention.self.query.weight")
|
||||
.unwrap();
|
||||
|
||||
// We're going to load shard with rank 1, within a world_size of 4
|
||||
// We're going to split along dimension 0 doing VIEW[start..stop, :]
|
||||
let rank = 1;
|
||||
let world_size = 4;
|
||||
let dim = 0;
|
||||
let dtype = view.dtype();
|
||||
let mut tp_shape = view.shape().to_vec();
|
||||
let size = tp_shape[0];
|
||||
|
||||
if size % world_size != 0 {
|
||||
panic!("The dimension is not divisble by `world_size`");
|
||||
}
|
||||
let block_size = size / world_size;
|
||||
let start = rank * block_size;
|
||||
let stop = (rank + 1) * block_size;
|
||||
|
||||
// Everything is expressed in tensor dimension
|
||||
// bytes offsets is handled automatically for safetensors.
|
||||
|
||||
let iterator = view.slice(start..stop).unwrap();
|
||||
|
||||
tp_shape[dim] = block_size;
|
||||
|
||||
// Convert safetensors Dtype to candle DType
|
||||
let dtype: DType = dtype.try_into().unwrap();
|
||||
|
||||
// TODO: Implement from_buffer_iterator so we can skip the extra CPU alloc.
|
||||
let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
|
||||
let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap();
|
||||
// ANCHOR_END: book_hub_3
|
||||
assert_eq!(view.shape(), &[768, 768]);
|
||||
assert_eq!(tp_tensor.dims(), &[192, 768]);
|
||||
}
|
||||
}
|
||||
|
||||
#[rustfmt::skip]
|
||||
#[test]
|
||||
fn book_training_1() -> Result<()>{
|
||||
// ANCHOR: book_training_1
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
|
||||
let dataset_id = "mnist".to_string();
|
||||
|
||||
let api = Api::new()?;
|
||||
let repo = Repo::with_revision(
|
||||
dataset_id,
|
||||
RepoType::Dataset,
|
||||
"refs/convert/parquet".to_string(),
|
||||
);
|
||||
let repo = api.repo(repo);
|
||||
let test_parquet_filename = repo.get("mnist/test/0000.parquet")?;
|
||||
let train_parquet_filename = repo.get("mnist/train/0000.parquet")?;
|
||||
let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
|
||||
let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
|
||||
// ANCHOR_END: book_training_1
|
||||
// Ignore unused
|
||||
let _train = train_parquet;
|
||||
// ANCHOR: book_training_2
|
||||
for row in test_parquet {
|
||||
for (idx, (name, field)) in row?.get_column_iter().enumerate() {
|
||||
println!("Column id {idx}, name {name}, value {field}");
|
||||
}
|
||||
}
|
||||
// ANCHOR_END: book_training_2
|
||||
let test_parquet_filename = repo.get("mnist/test/0000.parquet")?;
|
||||
let train_parquet_filename = repo.get("mnist/train/0000.parquet")?;
|
||||
let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
|
||||
let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
|
||||
// ANCHOR: book_training_3
|
||||
|
||||
let test_samples = 10_000;
|
||||
let mut test_buffer_images: Vec<u8> = Vec::with_capacity(test_samples * 784);
|
||||
let mut test_buffer_labels: Vec<u8> = Vec::with_capacity(test_samples);
|
||||
for row in test_parquet{
|
||||
for (_name, field) in row?.get_column_iter() {
|
||||
if let parquet::record::Field::Group(subrow) = field {
|
||||
for (_name, field) in subrow.get_column_iter() {
|
||||
if let parquet::record::Field::Bytes(value) = field {
|
||||
let image = image::load_from_memory(value.data()).unwrap();
|
||||
test_buffer_images.extend(image.to_luma8().as_raw());
|
||||
}
|
||||
}
|
||||
}else if let parquet::record::Field::Long(label) = field {
|
||||
test_buffer_labels.push(*label as u8);
|
||||
}
|
||||
}
|
||||
}
|
||||
let test_images = (Tensor::from_vec(test_buffer_images, (test_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
|
||||
let test_labels = Tensor::from_vec(test_buffer_labels, (test_samples, ), &Device::Cpu)?;
|
||||
|
||||
let train_samples = 60_000;
|
||||
let mut train_buffer_images: Vec<u8> = Vec::with_capacity(train_samples * 784);
|
||||
let mut train_buffer_labels: Vec<u8> = Vec::with_capacity(train_samples);
|
||||
for row in train_parquet{
|
||||
for (_name, field) in row?.get_column_iter() {
|
||||
if let parquet::record::Field::Group(subrow) = field {
|
||||
for (_name, field) in subrow.get_column_iter() {
|
||||
if let parquet::record::Field::Bytes(value) = field {
|
||||
let image = image::load_from_memory(value.data()).unwrap();
|
||||
train_buffer_images.extend(image.to_luma8().as_raw());
|
||||
}
|
||||
}
|
||||
}else if let parquet::record::Field::Long(label) = field {
|
||||
train_buffer_labels.push(*label as u8);
|
||||
}
|
||||
}
|
||||
}
|
||||
let train_images = (Tensor::from_vec(train_buffer_images, (train_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
|
||||
let train_labels = Tensor::from_vec(train_buffer_labels, (train_samples, ), &Device::Cpu)?;
|
||||
|
||||
let mnist = candle_datasets::vision::Dataset {
|
||||
train_images,
|
||||
train_labels,
|
||||
test_images,
|
||||
test_labels,
|
||||
labels: 10,
|
||||
};
|
||||
|
||||
// ANCHOR_END: book_training_3
|
||||
assert_eq!(mnist.test_images.dims(), &[10_000, 784]);
|
||||
assert_eq!(mnist.test_labels.dims(), &[10_000]);
|
||||
assert_eq!(mnist.train_images.dims(), &[60_000, 784]);
|
||||
assert_eq!(mnist.train_labels.dims(), &[60_000]);
|
||||
Ok(())
|
||||
}
|
||||
}
|
@ -1,196 +0,0 @@
|
||||
//! #A simplified example in Rust of training a neural network and then using it based on the Candle Framework by Hugging Face.
|
||||
//! Author: Evgeny Igumnov 2023 igumnovnsk@gmail.com
|
||||
//! This program implements a neural network to predict the winner of the second round of elections based on the results of the first round.
|
||||
//!
|
||||
//! ##Basic moments:
|
||||
//!
|
||||
//! A multilayer perceptron with two hidden layers is used. The first hidden layer has 4 neurons, the second has 2 neurons.
|
||||
//! The input is a vector of 2 numbers - the percentage of votes for the first and second candidates in the first stage.
|
||||
//! The output is the number 0 or 1, where 1 means that the first candidate will win in the second stage, 0 means that he will lose.
|
||||
//! For training, samples with real data on the results of the first and second stages of different elections are used.
|
||||
//! The model is trained by backpropagation using gradient descent and the cross-entropy loss function.
|
||||
//! Model parameters (weights of neurons) are initialized randomly, then optimized during training.
|
||||
//! After training, the model is tested on a deferred sample to evaluate the accuracy.
|
||||
//! If the accuracy on the test set is below 100%, the model is considered underfit and the learning process is repeated.
|
||||
//! Thus, this neural network learns to find hidden relationships between the results of the first and second rounds of voting in order to make predictions for new data.
|
||||
|
||||
#[rustfmt::skip]
|
||||
mod tests {
|
||||
|
||||
use candle::{DType, Result, Tensor, D, Device};
|
||||
use candle_nn::{loss, ops, Linear, Module, VarBuilder, VarMap, Optimizer};
|
||||
|
||||
// ANCHOR: book_training_simplified1
|
||||
const VOTE_DIM: usize = 2;
|
||||
const RESULTS: usize = 1;
|
||||
const EPOCHS: usize = 10;
|
||||
const LAYER1_OUT_SIZE: usize = 4;
|
||||
const LAYER2_OUT_SIZE: usize = 2;
|
||||
const LEARNING_RATE: f64 = 0.05;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Dataset {
|
||||
pub train_votes: Tensor,
|
||||
pub train_results: Tensor,
|
||||
pub test_votes: Tensor,
|
||||
pub test_results: Tensor,
|
||||
}
|
||||
|
||||
struct MultiLevelPerceptron {
|
||||
ln1: Linear,
|
||||
ln2: Linear,
|
||||
ln3: Linear,
|
||||
}
|
||||
|
||||
impl MultiLevelPerceptron {
|
||||
fn new(vs: VarBuilder) -> Result<Self> {
|
||||
let ln1 = candle_nn::linear(VOTE_DIM, LAYER1_OUT_SIZE, vs.pp("ln1"))?;
|
||||
let ln2 = candle_nn::linear(LAYER1_OUT_SIZE, LAYER2_OUT_SIZE, vs.pp("ln2"))?;
|
||||
let ln3 = candle_nn::linear(LAYER2_OUT_SIZE, RESULTS + 1, vs.pp("ln3"))?;
|
||||
Ok(Self { ln1, ln2, ln3 })
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = self.ln1.forward(xs)?;
|
||||
let xs = xs.relu()?;
|
||||
let xs = self.ln2.forward(&xs)?;
|
||||
let xs = xs.relu()?;
|
||||
self.ln3.forward(&xs)
|
||||
}
|
||||
}
|
||||
|
||||
// ANCHOR_END: book_training_simplified1
|
||||
|
||||
|
||||
|
||||
// ANCHOR: book_training_simplified3
|
||||
#[tokio::test]
|
||||
async fn simplified() -> anyhow::Result<()> {
|
||||
|
||||
let dev = Device::cuda_if_available(0)?;
|
||||
|
||||
let train_votes_vec: Vec<u32> = vec![
|
||||
15, 10,
|
||||
10, 15,
|
||||
5, 12,
|
||||
30, 20,
|
||||
16, 12,
|
||||
13, 25,
|
||||
6, 14,
|
||||
31, 21,
|
||||
];
|
||||
let train_votes_tensor = Tensor::from_vec(train_votes_vec.clone(), (train_votes_vec.len() / VOTE_DIM, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;
|
||||
|
||||
let train_results_vec: Vec<u32> = vec![
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
];
|
||||
let train_results_tensor = Tensor::from_vec(train_results_vec, train_votes_vec.len() / VOTE_DIM, &dev)?;
|
||||
|
||||
let test_votes_vec: Vec<u32> = vec![
|
||||
13, 9,
|
||||
8, 14,
|
||||
3, 10,
|
||||
];
|
||||
let test_votes_tensor = Tensor::from_vec(test_votes_vec.clone(), (test_votes_vec.len() / VOTE_DIM, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;
|
||||
|
||||
let test_results_vec: Vec<u32> = vec![
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
];
|
||||
let test_results_tensor = Tensor::from_vec(test_results_vec.clone(), test_results_vec.len(), &dev)?;
|
||||
|
||||
let m = Dataset {
|
||||
train_votes: train_votes_tensor,
|
||||
train_results: train_results_tensor,
|
||||
test_votes: test_votes_tensor,
|
||||
test_results: test_results_tensor,
|
||||
};
|
||||
|
||||
let trained_model: MultiLevelPerceptron;
|
||||
loop {
|
||||
println!("Trying to train neural network.");
|
||||
match train(m.clone(), &dev) {
|
||||
Ok(model) => {
|
||||
trained_model = model;
|
||||
break;
|
||||
},
|
||||
Err(e) => {
|
||||
println!("Error: {}", e);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
let real_world_votes: Vec<u32> = vec![
|
||||
13, 22,
|
||||
];
|
||||
|
||||
let tensor_test_votes = Tensor::from_vec(real_world_votes.clone(), (1, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;
|
||||
|
||||
let final_result = trained_model.forward(&tensor_test_votes)?;
|
||||
|
||||
let result = final_result
|
||||
.argmax(D::Minus1)?
|
||||
.to_dtype(DType::F32)?
|
||||
.get(0).map(|x| x.to_scalar::<f32>())??;
|
||||
println!("real_life_votes: {:?}", real_world_votes);
|
||||
println!("neural_network_prediction_result: {:?}", result);
|
||||
|
||||
Ok(())
|
||||
|
||||
}
|
||||
// ANCHOR_END: book_training_simplified3
|
||||
|
||||
// ANCHOR: book_training_simplified2
|
||||
fn train(m: Dataset, dev: &Device) -> anyhow::Result<MultiLevelPerceptron> {
|
||||
let train_results = m.train_results.to_device(dev)?;
|
||||
let train_votes = m.train_votes.to_device(dev)?;
|
||||
let varmap = VarMap::new();
|
||||
let vs = VarBuilder::from_varmap(&varmap, DType::F32, dev);
|
||||
let model = MultiLevelPerceptron::new(vs.clone())?;
|
||||
let mut sgd = candle_nn::SGD::new(varmap.all_vars(), LEARNING_RATE)?;
|
||||
let test_votes = m.test_votes.to_device(dev)?;
|
||||
let test_results = m.test_results.to_device(dev)?;
|
||||
let mut final_accuracy: f32 = 0.0;
|
||||
for epoch in 1..EPOCHS + 1 {
|
||||
let logits = model.forward(&train_votes)?;
|
||||
let log_sm = ops::log_softmax(&logits, D::Minus1)?;
|
||||
let loss = loss::nll(&log_sm, &train_results)?;
|
||||
sgd.backward_step(&loss)?;
|
||||
|
||||
let test_logits = model.forward(&test_votes)?;
|
||||
let sum_ok = test_logits
|
||||
.argmax(D::Minus1)?
|
||||
.eq(&test_results)?
|
||||
.to_dtype(DType::F32)?
|
||||
.sum_all()?
|
||||
.to_scalar::<f32>()?;
|
||||
let test_accuracy = sum_ok / test_results.dims1()? as f32;
|
||||
final_accuracy = 100. * test_accuracy;
|
||||
println!("Epoch: {epoch:3} Train loss: {:8.5} Test accuracy: {:5.2}%",
|
||||
loss.to_scalar::<f32>()?,
|
||||
final_accuracy
|
||||
);
|
||||
if final_accuracy == 100.0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if final_accuracy < 100.0 {
|
||||
Err(anyhow::Error::msg("The model is not trained well enough."))
|
||||
} else {
|
||||
Ok(model)
|
||||
}
|
||||
}
|
||||
// ANCHOR_END: book_training_simplified2
|
||||
|
||||
|
||||
}
|
1
candle-book/src/training/README.md
Normal file
1
candle-book/src/training/README.md
Normal file
@ -0,0 +1 @@
|
||||
# Training
|
@ -1,10 +1 @@
|
||||
# MNIST
|
||||
|
||||
So we now have downloaded the MNIST parquet files, let's put them in a simple struct.
|
||||
|
||||
```rust,ignore
|
||||
{{#include ../lib.rs:book_training_3}}
|
||||
```
|
||||
|
||||
The parsing of the file and putting it into single tensors requires the dataset to fit the entire memory.
|
||||
It is quite rudimentary, but simple enough for a small dataset like MNIST.
|
||||
|
@ -1,45 +0,0 @@
|
||||
# Simplified
|
||||
|
||||
## How its works
|
||||
|
||||
This program implements a neural network to predict the winner of the second round of elections based on the results of the first round.
|
||||
|
||||
Basic moments:
|
||||
|
||||
1. A multilayer perceptron with two hidden layers is used. The first hidden layer has 4 neurons, the second has 2 neurons.
|
||||
2. The input is a vector of 2 numbers - the percentage of votes for the first and second candidates in the first stage.
|
||||
3. The output is the number 0 or 1, where 1 means that the first candidate will win in the second stage, 0 means that he will lose.
|
||||
4. For training, samples with real data on the results of the first and second stages of different elections are used.
|
||||
5. The model is trained by backpropagation using gradient descent and the cross-entropy loss function.
|
||||
6. Model parameters (weights of neurons) are initialized randomly, then optimized during training.
|
||||
7. After training, the model is tested on a deferred sample to evaluate the accuracy.
|
||||
8. If the accuracy on the test set is below 100%, the model is considered underfit and the learning process is repeated.
|
||||
|
||||
Thus, this neural network learns to find hidden relationships between the results of the first and second rounds of voting in order to make predictions for new data.
|
||||
|
||||
|
||||
```rust,ignore
|
||||
{{#include ../simplified.rs:book_training_simplified1}}
|
||||
```
|
||||
|
||||
```rust,ignore
|
||||
{{#include ../simplified.rs:book_training_simplified2}}
|
||||
```
|
||||
|
||||
```rust,ignore
|
||||
{{#include ../simplified.rs:book_training_simplified3}}
|
||||
```
|
||||
|
||||
|
||||
## Example output
|
||||
|
||||
```bash
|
||||
Trying to train neural network.
|
||||
Epoch: 1 Train loss: 4.42555 Test accuracy: 0.00%
|
||||
Epoch: 2 Train loss: 0.84677 Test accuracy: 33.33%
|
||||
Epoch: 3 Train loss: 2.54335 Test accuracy: 33.33%
|
||||
Epoch: 4 Train loss: 0.37806 Test accuracy: 33.33%
|
||||
Epoch: 5 Train loss: 0.36647 Test accuracy: 100.00%
|
||||
real_life_votes: [13, 22]
|
||||
neural_network_prediction_result: 0.0
|
||||
```
|
@ -1,39 +0,0 @@
|
||||
# Training
|
||||
|
||||
|
||||
Training starts with data. We're going to use the huggingface hub and
|
||||
start with the Hello world dataset of machine learning, MNIST.
|
||||
|
||||
Let's start with downloading `MNIST` from [huggingface](https://huggingface.co/datasets/mnist).
|
||||
|
||||
This requires [`hf-hub`](https://github.com/huggingface/hf-hub).
|
||||
```bash
|
||||
cargo add hf-hub
|
||||
```
|
||||
|
||||
This is going to be very hands-on for now.
|
||||
|
||||
```rust,ignore
|
||||
{{#include ../../../candle-examples/src/lib.rs:book_training_1}}
|
||||
```
|
||||
|
||||
This uses the standardized `parquet` files from the `refs/convert/parquet` branch on every dataset.
|
||||
Our handles are now [`parquet::file::serialized_reader::SerializedFileReader`].
|
||||
|
||||
We can inspect the content of the files with:
|
||||
|
||||
```rust,ignore
|
||||
{{#include ../../../candle-examples/src/lib.rs:book_training_2}}
|
||||
```
|
||||
|
||||
You should see something like:
|
||||
|
||||
```bash
|
||||
Column id 1, name label, value 6
|
||||
Column id 0, name image, value {bytes: [137, ....]
|
||||
Column id 1, name label, value 8
|
||||
Column id 0, name image, value {bytes: [137, ....]
|
||||
```
|
||||
|
||||
So each row contains 2 columns (image, label) with image being saved as bytes.
|
||||
Let's put them into a useful struct.
|
@ -12,9 +12,7 @@ readme = "README.md"
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
byteorder = { workspace = true }
|
||||
candle-kernels = { workspace = true, optional = true }
|
||||
candle-metal-kernels = { workspace = true, optional = true }
|
||||
metal = { workspace = true, optional = true}
|
||||
candle-kernels = { path = "../candle-kernels", version = "0.1.1", optional = true }
|
||||
cudarc = { workspace = true, optional = true }
|
||||
gemm = { workspace = true }
|
||||
half = { workspace = true }
|
||||
@ -28,14 +26,11 @@ rand_distr = { workspace = true }
|
||||
rayon = { workspace = true }
|
||||
safetensors = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
yoke = { workspace = true }
|
||||
zip = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
clap = { workspace = true }
|
||||
criterion = { workspace = true }
|
||||
|
||||
|
||||
[features]
|
||||
default = []
|
||||
@ -43,8 +38,3 @@ cuda = ["cudarc", "dep:candle-kernels"]
|
||||
cudnn = ["cuda", "cudarc/cudnn"]
|
||||
mkl = ["dep:libc", "dep:intel-mkl-src"]
|
||||
accelerate = ["dep:libc", "dep:accelerate-src"]
|
||||
metal = ["dep:metal", "dep:candle-metal-kernels"]
|
||||
|
||||
[[bench]]
|
||||
name = "bench_main"
|
||||
harness = false
|
||||
|
@ -1,9 +0,0 @@
|
||||
mod benchmarks;
|
||||
|
||||
use criterion::criterion_main;
|
||||
criterion_main!(
|
||||
benchmarks::affine::benches,
|
||||
benchmarks::matmul::benches,
|
||||
benchmarks::random::benches,
|
||||
benchmarks::where_cond::benches
|
||||
);
|
@ -1,43 +0,0 @@
|
||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||
use std::time::Instant;
|
||||
|
||||
fn run(a: &Tensor) {
|
||||
a.affine(12.34, 56.78).unwrap();
|
||||
}
|
||||
|
||||
fn run_affine_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
|
||||
let b = 1;
|
||||
let m = 1024;
|
||||
let k = 1024;
|
||||
|
||||
let tensor = Tensor::zeros((b, m, k), dtype, &device).unwrap();
|
||||
|
||||
let flops = b * m * k * dtype.size_in_bytes();
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name(name));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
run(black_box(&tensor));
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let handler = BenchDeviceHandler::new().unwrap();
|
||||
for device in handler.devices {
|
||||
run_affine_benchmark(c, &device, DType::F32, "affine_f32");
|
||||
run_affine_benchmark(c, &device, DType::F16, "affine_f16");
|
||||
run_affine_benchmark(c, &device, DType::BF16, "affine_bf16");
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
@ -1,44 +0,0 @@
|
||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||
use std::time::Instant;
|
||||
|
||||
fn run(a: &Tensor, b: &Tensor) {
|
||||
a.matmul(&b.t().unwrap()).unwrap();
|
||||
}
|
||||
|
||||
fn run_bench(c: &mut Criterion, device: &Device) {
|
||||
let b = 1;
|
||||
let m = 1;
|
||||
let n = 2048;
|
||||
let k = 2048;
|
||||
|
||||
let dtype = DType::F32;
|
||||
let lhs = Tensor::zeros((b, m, k), dtype, device).unwrap();
|
||||
let rhs = Tensor::zeros((b, n, k), dtype, device).unwrap();
|
||||
|
||||
let flops = b * m * n * k;
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name("matmul"));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
run(black_box(&lhs), black_box(&rhs));
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let handler = BenchDeviceHandler::new().unwrap();
|
||||
for device in handler.devices {
|
||||
run_bench(c, &device);
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
@ -1,66 +0,0 @@
|
||||
pub(crate) mod affine;
|
||||
pub(crate) mod matmul;
|
||||
pub(crate) mod random;
|
||||
pub(crate) mod where_cond;
|
||||
|
||||
use candle_core::{Device, Result};
|
||||
|
||||
pub(crate) trait BenchDevice {
|
||||
fn sync(&self) -> Result<()>;
|
||||
|
||||
fn bench_name<S: Into<String>>(&self, name: S) -> String;
|
||||
}
|
||||
|
||||
impl BenchDevice for Device {
|
||||
fn sync(&self) -> Result<()> {
|
||||
match self {
|
||||
Device::Cpu => Ok(()),
|
||||
Device::Cuda(device) => {
|
||||
#[cfg(feature = "cuda")]
|
||||
return Ok(device.synchronize()?);
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
panic!("Cuda device without cuda feature enabled: {:?}", device)
|
||||
}
|
||||
Device::Metal(device) => {
|
||||
#[cfg(feature = "metal")]
|
||||
return Ok(device.wait_until_completed()?);
|
||||
#[cfg(not(feature = "metal"))]
|
||||
panic!("Metal device without metal feature enabled: {:?}", device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn bench_name<S: Into<String>>(&self, name: S) -> String {
|
||||
match self {
|
||||
Device::Cpu => {
|
||||
let cpu_type = if cfg!(feature = "accelerate") {
|
||||
"accelerate"
|
||||
} else if cfg!(feature = "mkl") {
|
||||
"mkl"
|
||||
} else {
|
||||
"cpu"
|
||||
};
|
||||
format!("{}_{}", cpu_type, name.into())
|
||||
}
|
||||
Device::Cuda(_) => format!("cuda_{}", name.into()),
|
||||
Device::Metal(_) => format!("metal_{}", name.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct BenchDeviceHandler {
|
||||
devices: Vec<Device>,
|
||||
}
|
||||
|
||||
impl BenchDeviceHandler {
|
||||
pub fn new() -> Result<Self> {
|
||||
let mut devices = Vec::new();
|
||||
if cfg!(feature = "metal") {
|
||||
devices.push(Device::new_metal(0)?);
|
||||
} else if cfg!(feature = "cuda") {
|
||||
devices.push(Device::new_cuda(0)?);
|
||||
}
|
||||
devices.push(Device::Cpu);
|
||||
Ok(Self { devices })
|
||||
}
|
||||
}
|
@ -1,63 +0,0 @@
|
||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||
use std::time::Instant;
|
||||
|
||||
fn rand_uniform(a: &Tensor) {
|
||||
a.rand_like(-1.0, 123.0).unwrap();
|
||||
}
|
||||
|
||||
fn rand_normal(a: &Tensor) {
|
||||
a.randn_like(100.0, 15.0).unwrap();
|
||||
}
|
||||
|
||||
fn run_random_bench(c: &mut Criterion, device: &Device) {
|
||||
let b = 1;
|
||||
|
||||
let rows = 2048;
|
||||
let cols = 2048;
|
||||
|
||||
let dtype = DType::F32;
|
||||
let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
|
||||
|
||||
let flops = b * rows * cols * dtype.size_in_bytes();
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name("random_uniform"));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |benches| {
|
||||
benches.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
rand_uniform(black_box(&tensor));
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
|
||||
let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name("random_normal"));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |benches| {
|
||||
benches.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
rand_normal(black_box(&tensor));
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let handler = BenchDeviceHandler::new().unwrap();
|
||||
for device in handler.devices {
|
||||
run_random_bench(c, &device);
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
@ -1,64 +0,0 @@
|
||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||
use std::time::Instant;
|
||||
|
||||
fn run(a: &Tensor, b: &Tensor, c: &Tensor) {
|
||||
a.where_cond(b, c).unwrap();
|
||||
}
|
||||
|
||||
const fn create_cond_arr<const N: usize>() -> [u8; N] {
|
||||
let mut arr = [0u8; N];
|
||||
let mut i = 0;
|
||||
while i < N {
|
||||
arr[i] = (i % 2) as u8;
|
||||
i += 1;
|
||||
}
|
||||
arr
|
||||
}
|
||||
|
||||
const B: usize = 1;
|
||||
const M: usize = 1024;
|
||||
const K: usize = 1024;
|
||||
const SIZE: usize = B * M * K;
|
||||
|
||||
const DATA: [u8; SIZE] = create_cond_arr::<SIZE>();
|
||||
|
||||
fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
|
||||
let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), &device).unwrap();
|
||||
let on_true = Tensor::ones((B, M, K), dtype, &device).unwrap();
|
||||
let on_false = Tensor::zeros((B, M, K), dtype, &device).unwrap();
|
||||
|
||||
let elements = B * M * K;
|
||||
// E.g. 2 f32 tensors + 1 u8 tensor
|
||||
let flops = (2 * elements * dtype.size_in_bytes()) + elements;
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name(name));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
run(
|
||||
black_box(&tensor),
|
||||
black_box(&on_true),
|
||||
black_box(&on_false),
|
||||
);
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let device = BenchDeviceHandler::new().unwrap();
|
||||
for d in device.devices {
|
||||
run_where_cond_benchmark(c, &d, DType::F32, "where_cond_f32");
|
||||
run_where_cond_benchmark(c, &d, DType::BF16, "where_cond_bf16");
|
||||
run_where_cond_benchmark(c, &d, DType::F16, "where_cond_f16");
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
@ -8,10 +8,11 @@ use anyhow::Result;
|
||||
use candle_core::{Device, Tensor};
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let a = Tensor::new(&[[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]], &Device::Cpu)?;
|
||||
let b = Tensor::new(&[[88.0f32, 99.0]], &Device::Cpu)?;
|
||||
let new_a = a.slice_scatter(&b, 1, 2)?;
|
||||
assert_eq!(a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
assert_eq!(new_a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
|
||||
let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
|
||||
let start = std::time::Instant::now();
|
||||
let res = inp.conv2d(&w, 0, 1);
|
||||
println!("{:?}", start.elapsed());
|
||||
println!("{res:?}");
|
||||
Ok(())
|
||||
}
|
||||
|
142
candle-core/examples/cpu_benchmarks.rs
Normal file
142
candle-core/examples/cpu_benchmarks.rs
Normal file
@ -0,0 +1,142 @@
|
||||
/// This example contains some simple benchmarks so that it's easy to run them in perf etc.
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use candle_core::{Device, Result, Tensor, D};
|
||||
use clap::{Parser, Subcommand};
|
||||
|
||||
fn softmax<D: candle_core::shape::Dim>(xs: &Tensor, dim: D) -> Result<Tensor> {
|
||||
let dim = dim.to_index(xs.shape(), "softmax")?;
|
||||
let max = xs.max_keepdim(dim)?;
|
||||
let diff = xs.broadcast_sub(&max)?;
|
||||
let num = diff.exp()?;
|
||||
let den = num.sum_keepdim(dim)?;
|
||||
num.broadcast_div(&den)
|
||||
}
|
||||
|
||||
trait Benchmark {
|
||||
type PreProcessData;
|
||||
type RunResult;
|
||||
|
||||
fn preprocess() -> Result<Self::PreProcessData>;
|
||||
fn run_one(_: &Self::PreProcessData) -> Result<Self::RunResult>;
|
||||
|
||||
const ITERS: usize;
|
||||
}
|
||||
|
||||
// Conv1d example as used in whisper.
|
||||
struct Conv1d;
|
||||
impl Benchmark for Conv1d {
|
||||
type PreProcessData = (Tensor, Tensor);
|
||||
type RunResult = Tensor;
|
||||
fn preprocess() -> Result<Self::PreProcessData> {
|
||||
let inp = Tensor::randn(0f32, 1., (1, 384, 3000), &Device::Cpu)?;
|
||||
let w = Tensor::randn(0f32, 1., (384, 384, 3), &Device::Cpu)?;
|
||||
Ok((inp, w))
|
||||
}
|
||||
|
||||
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
|
||||
d.0.conv1d(&d.1, 0, 1)
|
||||
}
|
||||
|
||||
const ITERS: usize = 5;
|
||||
}
|
||||
|
||||
// Conv2d example as used in stable-diffusion.
|
||||
struct Conv2d;
|
||||
impl Benchmark for Conv2d {
|
||||
type PreProcessData = (Tensor, Tensor);
|
||||
type RunResult = Tensor;
|
||||
|
||||
fn preprocess() -> Result<Self::PreProcessData> {
|
||||
let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
|
||||
let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
|
||||
Ok((inp, w))
|
||||
}
|
||||
|
||||
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
|
||||
d.0.conv2d(&d.1, 0, 1)
|
||||
}
|
||||
|
||||
const ITERS: usize = 1;
|
||||
}
|
||||
|
||||
struct Matmul;
|
||||
impl Benchmark for Matmul {
|
||||
type PreProcessData = (Tensor, Tensor);
|
||||
type RunResult = Tensor;
|
||||
fn preprocess() -> Result<Self::PreProcessData> {
|
||||
let lhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?;
|
||||
let rhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?;
|
||||
Ok((lhs, rhs))
|
||||
}
|
||||
|
||||
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
|
||||
d.0.matmul(&d.1)
|
||||
}
|
||||
|
||||
const ITERS: usize = 100;
|
||||
}
|
||||
|
||||
struct Softmax;
|
||||
impl Benchmark for Softmax {
|
||||
type PreProcessData = Tensor;
|
||||
type RunResult = Tensor;
|
||||
fn preprocess() -> Result<Self::PreProcessData> {
|
||||
// Typical whisper tiny size.
|
||||
let x = Tensor::randn(0f32, 1., (1, 6, 200, 1500), &Device::Cpu)?;
|
||||
Ok(x)
|
||||
}
|
||||
|
||||
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
|
||||
softmax(d, D::Minus1)
|
||||
}
|
||||
|
||||
const ITERS: usize = 100;
|
||||
}
|
||||
|
||||
fn run<B: Benchmark>(iters: Option<usize>) -> Result<()> {
|
||||
use std::hint::black_box;
|
||||
|
||||
let iters = iters.unwrap_or(B::ITERS);
|
||||
let d = B::preprocess()?;
|
||||
let start = std::time::Instant::now();
|
||||
for _iter in 0..iters {
|
||||
let _res = black_box(B::run_one(black_box(&d))?);
|
||||
}
|
||||
println!("{:?}", start.elapsed() / iters as u32);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug, Clone)]
|
||||
enum Task {
|
||||
Conv1d,
|
||||
Conv2d,
|
||||
Matmul,
|
||||
Softmax,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
pub struct Args {
|
||||
/// The benchmark to be run.
|
||||
#[command(subcommand)]
|
||||
task: Task,
|
||||
|
||||
#[arg(long)]
|
||||
iters: Option<usize>,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
match args.task {
|
||||
Task::Conv1d => run::<Conv1d>(args.iters)?,
|
||||
Task::Conv2d => run::<Conv2d>(args.iters)?,
|
||||
Task::Matmul => run::<Matmul>(args.iters)?,
|
||||
Task::Softmax => run::<Softmax>(args.iters)?,
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -9,21 +9,9 @@ use candle_core::{Device, Tensor};
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let device = Device::new_cuda(0)?;
|
||||
let in_t = Tensor::rand(-1f32, 1f32, (1, 3, 12, 7), &device)?;
|
||||
let k_t = Tensor::rand(-1f32, 1f32, (6, 3, 1, 1), &device)?;
|
||||
let out_t = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
|
||||
println!("{out_t}");
|
||||
let in_t = in_t.to_device(&Device::Cpu)?;
|
||||
let k_t = k_t.to_device(&Device::Cpu)?;
|
||||
let out_t2 = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
|
||||
let diff = (out_t.to_device(&Device::Cpu)? - out_t2)?
|
||||
.sqr()?
|
||||
.sum_all()?;
|
||||
println!("{diff}");
|
||||
|
||||
let t = Tensor::randn(0f32, 1f32, (2, 4, 96, 96), &device)?;
|
||||
let w = Tensor::randn(0f32, 1f32, (320, 4, 3, 3), &device)?;
|
||||
let res = t.conv2d(&w, 1, 1, 1, 1)?;
|
||||
let res = t.conv2d(&w, 1, 1)?;
|
||||
println!("{res:?}");
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1,389 +0,0 @@
|
||||
use candle_core::quantized::{gguf_file, GgmlDType, QTensor};
|
||||
use candle_core::{Device, Result};
|
||||
use clap::{Parser, Subcommand, ValueEnum};
|
||||
use rayon::prelude::*;
|
||||
|
||||
#[derive(ValueEnum, Debug, Clone)]
|
||||
enum QuantizationMode {
|
||||
/// The default quantization includes all 2d tensors, except the output tensor which always
|
||||
/// uses Q6_K.
|
||||
Llama,
|
||||
}
|
||||
|
||||
impl QuantizationMode {
|
||||
fn quantize(&self, name: &str, tensor: QTensor, dtype: GgmlDType) -> Result<QTensor> {
|
||||
match self {
|
||||
Self::Llama => {
|
||||
// Same behavior as the llama.cpp quantization.
|
||||
let should_quantize = name.ends_with(".weight") && tensor.rank() == 2;
|
||||
if should_quantize {
|
||||
let tensor = tensor.dequantize(&Device::Cpu)?;
|
||||
if name == "output.weight" {
|
||||
QTensor::quantize(&tensor, GgmlDType::Q6K)
|
||||
} else {
|
||||
QTensor::quantize(&tensor, dtype)
|
||||
}
|
||||
} else {
|
||||
Ok(tensor)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(ValueEnum, Debug, Clone)]
|
||||
enum Quantization {
|
||||
#[value(name = "q4_0")]
|
||||
Q4_0,
|
||||
#[value(name = "q4_1")]
|
||||
Q4_1,
|
||||
#[value(name = "q5_0")]
|
||||
Q5_0,
|
||||
#[value(name = "q5_1")]
|
||||
Q5_1,
|
||||
#[value(name = "q8_0")]
|
||||
Q8_0,
|
||||
#[value(name = "q8_1")]
|
||||
Q8_1,
|
||||
Q2k,
|
||||
Q3k,
|
||||
Q4k,
|
||||
Q5k,
|
||||
Q6k,
|
||||
Q8k,
|
||||
F16,
|
||||
F32,
|
||||
}
|
||||
|
||||
impl Quantization {
|
||||
fn dtype(&self) -> GgmlDType {
|
||||
match self {
|
||||
Quantization::Q4_0 => GgmlDType::Q4_0,
|
||||
Quantization::Q4_1 => GgmlDType::Q4_1,
|
||||
Quantization::Q5_0 => GgmlDType::Q5_0,
|
||||
Quantization::Q5_1 => GgmlDType::Q5_1,
|
||||
Quantization::Q8_0 => GgmlDType::Q8_0,
|
||||
Quantization::Q8_1 => GgmlDType::Q8_1,
|
||||
Quantization::Q2k => GgmlDType::Q2K,
|
||||
Quantization::Q3k => GgmlDType::Q3K,
|
||||
Quantization::Q4k => GgmlDType::Q4K,
|
||||
Quantization::Q5k => GgmlDType::Q5K,
|
||||
Quantization::Q6k => GgmlDType::Q6K,
|
||||
Quantization::Q8k => GgmlDType::Q8K,
|
||||
Quantization::F16 => GgmlDType::F16,
|
||||
Quantization::F32 => GgmlDType::F32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(ValueEnum, Debug, Clone)]
|
||||
enum Format {
|
||||
Safetensors,
|
||||
Npz,
|
||||
Ggml,
|
||||
Gguf,
|
||||
Pth,
|
||||
Pickle,
|
||||
}
|
||||
|
||||
impl Format {
|
||||
fn infer<P: AsRef<std::path::Path>>(p: P) -> Option<Self> {
|
||||
p.as_ref()
|
||||
.extension()
|
||||
.and_then(|e| e.to_str())
|
||||
.and_then(|e| match e {
|
||||
// We don't infer any format for .bin as it can be used for ggml/gguf or pytorch.
|
||||
"safetensors" | "safetensor" => Some(Self::Safetensors),
|
||||
"npz" => Some(Self::Npz),
|
||||
"pth" | "pt" => Some(Self::Pth),
|
||||
"ggml" => Some(Self::Ggml),
|
||||
"gguf" => Some(Self::Gguf),
|
||||
_ => None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug, Clone)]
|
||||
enum Command {
|
||||
Ls {
|
||||
files: Vec<std::path::PathBuf>,
|
||||
|
||||
/// The file format to use, if unspecified infer from the file extension.
|
||||
#[arg(long, value_enum)]
|
||||
format: Option<Format>,
|
||||
|
||||
/// Enable verbose mode.
|
||||
#[arg(short, long)]
|
||||
verbose: bool,
|
||||
},
|
||||
|
||||
Quantize {
|
||||
/// The input file(s), in safetensors format.
|
||||
in_file: Vec<std::path::PathBuf>,
|
||||
|
||||
/// The output file, in gguf format.
|
||||
#[arg(long)]
|
||||
out_file: std::path::PathBuf,
|
||||
|
||||
/// The quantization schema to apply.
|
||||
#[arg(long, value_enum)]
|
||||
quantization: Quantization,
|
||||
|
||||
/// Which tensor to quantize.
|
||||
#[arg(long, value_enum, default_value_t = QuantizationMode::Llama)]
|
||||
mode: QuantizationMode,
|
||||
},
|
||||
|
||||
Dequantize {
|
||||
/// The input file, in gguf format.
|
||||
in_file: std::path::PathBuf,
|
||||
|
||||
/// The output file, in safetensors format.
|
||||
#[arg(long)]
|
||||
out_file: std::path::PathBuf,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug, Clone)]
|
||||
struct Args {
|
||||
#[command(subcommand)]
|
||||
command: Command,
|
||||
}
|
||||
|
||||
fn run_ls(
|
||||
file: &std::path::PathBuf,
|
||||
format: Option<Format>,
|
||||
verbose: bool,
|
||||
device: &Device,
|
||||
) -> Result<()> {
|
||||
let format = match format {
|
||||
Some(format) => format,
|
||||
None => match Format::infer(file) {
|
||||
Some(format) => format,
|
||||
None => {
|
||||
println!(
|
||||
"{file:?}: cannot infer format from file extension, use the --format flag"
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
},
|
||||
};
|
||||
match format {
|
||||
Format::Npz => {
|
||||
let tensors = candle_core::npy::NpzTensors::new(file)?;
|
||||
let mut names = tensors.names();
|
||||
names.sort();
|
||||
for name in names {
|
||||
let shape_dtype = match tensors.get_shape_and_dtype(name) {
|
||||
Ok((shape, dtype)) => format!("[{shape:?}; {dtype:?}]"),
|
||||
Err(err) => err.to_string(),
|
||||
};
|
||||
println!("{name}: {shape_dtype}")
|
||||
}
|
||||
}
|
||||
Format::Safetensors => {
|
||||
let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::new(file)? };
|
||||
let mut tensors = tensors.tensors();
|
||||
tensors.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
for (name, view) in tensors.iter() {
|
||||
let dtype = view.dtype();
|
||||
let dtype = match candle_core::DType::try_from(dtype) {
|
||||
Ok(dtype) => format!("{dtype:?}"),
|
||||
Err(_) => format!("{dtype:?}"),
|
||||
};
|
||||
let shape = view.shape();
|
||||
println!("{name}: [{shape:?}; {dtype}]")
|
||||
}
|
||||
}
|
||||
Format::Pth => {
|
||||
let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose, None)?;
|
||||
tensors.sort_by(|a, b| a.name.cmp(&b.name));
|
||||
for tensor_info in tensors.iter() {
|
||||
println!(
|
||||
"{}: [{:?}; {:?}]",
|
||||
tensor_info.name,
|
||||
tensor_info.layout.shape(),
|
||||
tensor_info.dtype,
|
||||
);
|
||||
if verbose {
|
||||
println!(" {:?}", tensor_info);
|
||||
}
|
||||
}
|
||||
}
|
||||
Format::Pickle => {
|
||||
let file = std::fs::File::open(file)?;
|
||||
let mut reader = std::io::BufReader::new(file);
|
||||
let mut stack = candle_core::pickle::Stack::empty();
|
||||
stack.read_loop(&mut reader)?;
|
||||
for (i, obj) in stack.stack().iter().enumerate() {
|
||||
println!("{i} {obj:?}");
|
||||
}
|
||||
}
|
||||
Format::Ggml => {
|
||||
let mut file = std::fs::File::open(file)?;
|
||||
let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?;
|
||||
let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
|
||||
tensors.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
for (name, qtensor) in tensors.iter() {
|
||||
println!("{name}: [{:?}; {:?}]", qtensor.shape(), qtensor.dtype());
|
||||
}
|
||||
}
|
||||
Format::Gguf => {
|
||||
let mut file = std::fs::File::open(file)?;
|
||||
let content = gguf_file::Content::read(&mut file)?;
|
||||
if verbose {
|
||||
let mut metadata = content.metadata.into_iter().collect::<Vec<_>>();
|
||||
metadata.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
println!("metadata entries ({})", metadata.len());
|
||||
for (key, value) in metadata.iter() {
|
||||
println!(" {key}: {value:?}");
|
||||
}
|
||||
}
|
||||
let mut tensors = content.tensor_infos.into_iter().collect::<Vec<_>>();
|
||||
tensors.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
for (name, info) in tensors.iter() {
|
||||
println!("{name}: [{:?}; {:?}]", info.shape, info.ggml_dtype);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_quantize_safetensors(
|
||||
in_files: &[std::path::PathBuf],
|
||||
out_file: std::path::PathBuf,
|
||||
q: Quantization,
|
||||
) -> Result<()> {
|
||||
let mut out_file = std::fs::File::create(out_file)?;
|
||||
let mut tensors = std::collections::HashMap::new();
|
||||
for in_file in in_files.iter() {
|
||||
let in_tensors = candle_core::safetensors::load(in_file, &Device::Cpu)?;
|
||||
tensors.extend(in_tensors)
|
||||
}
|
||||
println!("tensors: {}", tensors.len());
|
||||
|
||||
let dtype = q.dtype();
|
||||
let block_size = dtype.block_size();
|
||||
|
||||
let qtensors = tensors
|
||||
.into_par_iter()
|
||||
.map(|(name, tensor)| {
|
||||
let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0;
|
||||
println!(" quantizing {name} {tensor:?} {should_quantize}");
|
||||
let tensor = if should_quantize {
|
||||
QTensor::quantize(&tensor, dtype)?
|
||||
} else {
|
||||
QTensor::quantize(&tensor, GgmlDType::F32)?
|
||||
};
|
||||
Ok((name, tensor))
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let qtensors = qtensors
|
||||
.iter()
|
||||
.map(|(k, v)| (k.as_str(), v))
|
||||
.collect::<Vec<_>>();
|
||||
gguf_file::write(&mut out_file, &[], &qtensors)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_dequantize(
|
||||
in_file: std::path::PathBuf,
|
||||
out_file: std::path::PathBuf,
|
||||
device: &Device,
|
||||
) -> Result<()> {
|
||||
let mut in_file = std::fs::File::open(in_file)?;
|
||||
let content = gguf_file::Content::read(&mut in_file)?;
|
||||
let mut tensors = std::collections::HashMap::new();
|
||||
for (tensor_name, _) in content.tensor_infos.iter() {
|
||||
let tensor = content.tensor(&mut in_file, tensor_name, device)?;
|
||||
let tensor = tensor.dequantize(device)?;
|
||||
tensors.insert(tensor_name.to_string(), tensor);
|
||||
}
|
||||
candle_core::safetensors::save(&tensors, out_file)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_quantize(
|
||||
in_files: &[std::path::PathBuf],
|
||||
out_file: std::path::PathBuf,
|
||||
q: Quantization,
|
||||
qmode: QuantizationMode,
|
||||
device: &Device,
|
||||
) -> Result<()> {
|
||||
if in_files.is_empty() {
|
||||
candle_core::bail!("no specified input files")
|
||||
}
|
||||
if let Some(extension) = out_file.extension() {
|
||||
if extension == "safetensors" {
|
||||
candle_core::bail!("the generated file cannot use the safetensors extension")
|
||||
}
|
||||
}
|
||||
if let Some(extension) = in_files[0].extension() {
|
||||
if extension == "safetensors" {
|
||||
return run_quantize_safetensors(in_files, out_file, q);
|
||||
}
|
||||
}
|
||||
|
||||
if in_files.len() != 1 {
|
||||
candle_core::bail!("only a single in-file can be used when quantizing gguf files")
|
||||
}
|
||||
|
||||
// Open the out file early so as to fail directly on missing directories etc.
|
||||
let mut out_file = std::fs::File::create(out_file)?;
|
||||
let mut in_ = std::fs::File::open(&in_files[0])?;
|
||||
let content = gguf_file::Content::read(&mut in_)?;
|
||||
println!("tensors: {}", content.tensor_infos.len());
|
||||
|
||||
let dtype = q.dtype();
|
||||
let qtensors = content
|
||||
.tensor_infos
|
||||
.par_iter()
|
||||
.map(|(name, _)| {
|
||||
println!(" quantizing {name}");
|
||||
let mut in_file = std::fs::File::open(&in_files[0])?;
|
||||
let tensor = content.tensor(&mut in_file, name, device)?;
|
||||
let tensor = qmode.quantize(name, tensor, dtype)?;
|
||||
Ok((name, tensor))
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let qtensors = qtensors
|
||||
.iter()
|
||||
.map(|(k, v)| (k.as_str(), v))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let metadata = content
|
||||
.metadata
|
||||
.iter()
|
||||
.map(|(k, v)| (k.as_str(), v))
|
||||
.collect::<Vec<_>>();
|
||||
gguf_file::write(&mut out_file, metadata.as_slice(), &qtensors)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
let device = Device::Cpu;
|
||||
match args.command {
|
||||
Command::Ls {
|
||||
files,
|
||||
format,
|
||||
verbose,
|
||||
} => {
|
||||
let multiple_files = files.len() > 1;
|
||||
for file in files.iter() {
|
||||
if multiple_files {
|
||||
println!("--- {file:?} ---");
|
||||
}
|
||||
run_ls(file, format.clone(), verbose, &device)?
|
||||
}
|
||||
}
|
||||
Command::Quantize {
|
||||
in_file,
|
||||
out_file,
|
||||
quantization,
|
||||
mode,
|
||||
} => run_quantize(&in_file, out_file, quantization, mode, &device)?,
|
||||
Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file, &device)?,
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -50,8 +50,6 @@ mod ffi {
|
||||
pub fn vvcos(dst: *mut c_double, src: *const c_double, len: *const c_int);
|
||||
pub fn vvlogf(dst: *mut c_float, src: *const c_float, len: *const c_int);
|
||||
pub fn vvlog(dst: *mut c_double, src: *const c_double, len: *const c_int);
|
||||
pub fn vvtanhf(dst: *mut c_float, src: *const c_float, len: *const c_int);
|
||||
pub fn vvtanh(dst: *mut c_double, src: *const c_double, len: *const c_int);
|
||||
|
||||
pub fn vDSP_vaddD(
|
||||
_: *const c_double,
|
||||
@ -125,42 +123,6 @@ mod ffi {
|
||||
_: c_long,
|
||||
_: c_ulong,
|
||||
);
|
||||
pub fn vDSP_vminD(
|
||||
_: *const c_double,
|
||||
_: c_long,
|
||||
_: *const c_double,
|
||||
_: c_long,
|
||||
_: *mut c_double,
|
||||
_: c_long,
|
||||
_: c_ulong,
|
||||
);
|
||||
pub fn vDSP_vmin(
|
||||
_: *const c_float,
|
||||
_: c_long,
|
||||
_: *const c_float,
|
||||
_: c_long,
|
||||
_: *mut c_float,
|
||||
_: c_long,
|
||||
_: c_ulong,
|
||||
);
|
||||
pub fn vDSP_vmaxD(
|
||||
_: *const c_double,
|
||||
_: c_long,
|
||||
_: *const c_double,
|
||||
_: c_long,
|
||||
_: *mut c_double,
|
||||
_: c_long,
|
||||
_: c_ulong,
|
||||
);
|
||||
pub fn vDSP_vmax(
|
||||
_: *const c_float,
|
||||
_: c_long,
|
||||
_: *const c_float,
|
||||
_: c_long,
|
||||
_: *mut c_float,
|
||||
_: c_long,
|
||||
_: c_ulong,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -310,26 +272,6 @@ pub fn vd_cos(a: &[f64], y: &mut [f64]) {
|
||||
}
|
||||
unsafe { ffi::vvcos(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
|
||||
}
|
||||
#[inline]
|
||||
pub fn vs_tanh(a: &[f32], y: &mut [f32]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vvtanhf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_tanh(a: &[f64], y: &mut [f64]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vvtanh(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_ln(a: &[f32], y: &mut [f32]) {
|
||||
let a_len = a.len();
|
||||
@ -370,38 +312,6 @@ pub fn vd_sqr(a: &[f64], y: &mut [f64]) {
|
||||
y.iter_mut().zip(a.iter()).for_each(|(y, a)| *y = *a * *a)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_tanh_inplace(y: &mut [f32]) {
|
||||
unsafe { ffi::vvtanhf(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_tanh_inplace(y: &mut [f64]) {
|
||||
unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = (2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
|
||||
}
|
||||
vs_tanh_inplace(ys);
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = 0.5 * v * (1.0 + *y)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = (2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
|
||||
}
|
||||
vd_tanh_inplace(ys);
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = 0.5 * v * (1.0 + *y)
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! binary_op {
|
||||
($fn_name:ident, $ty:ty, $accelerate_name:ident) => {
|
||||
#[inline]
|
||||
@ -438,7 +348,3 @@ binary_op!(vs_mul, f32, vDSP_vmul);
|
||||
binary_op!(vd_mul, f64, vDSP_vmulD);
|
||||
binary_op!(vs_div, f32, vDSP_vdiv);
|
||||
binary_op!(vd_div, f64, vDSP_vdivD);
|
||||
binary_op!(vs_max, f32, vDSP_vmax);
|
||||
binary_op!(vd_max, f64, vDSP_vmaxD);
|
||||
binary_op!(vs_min, f32, vDSP_vmin);
|
||||
binary_op!(vd_min, f64, vDSP_vminD);
|
||||
|
@ -15,8 +15,6 @@ pub trait BackendStorage: Sized {
|
||||
|
||||
fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self>;
|
||||
|
||||
fn powf(&self, _: &Layout, _: f64) -> Result<Self>;
|
||||
|
||||
fn elu(&self, _: &Layout, _: f64) -> Result<Self>;
|
||||
|
||||
fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self>;
|
||||
@ -39,14 +37,6 @@ pub trait BackendStorage: Sized {
|
||||
_params: &crate::conv::ParamsConv1D,
|
||||
) -> Result<Self>;
|
||||
|
||||
fn conv_transpose1d(
|
||||
&self,
|
||||
_l: &Layout,
|
||||
_kernel: &Self,
|
||||
_kernel_l: &Layout,
|
||||
_params: &crate::conv::ParamsConvTranspose1D,
|
||||
) -> Result<Self>;
|
||||
|
||||
fn conv2d(
|
||||
&self,
|
||||
_l: &Layout,
|
||||
@ -55,17 +45,8 @@ pub trait BackendStorage: Sized {
|
||||
_params: &crate::conv::ParamsConv2D,
|
||||
) -> Result<Self>;
|
||||
|
||||
fn conv_transpose2d(
|
||||
&self,
|
||||
_l: &Layout,
|
||||
_kernel: &Self,
|
||||
_kernel_l: &Layout,
|
||||
_params: &crate::conv::ParamsConvTranspose2D,
|
||||
) -> Result<Self>;
|
||||
|
||||
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
|
||||
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
|
||||
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self>;
|
||||
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;
|
||||
|
||||
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
|
||||
@ -119,6 +100,4 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
|
||||
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
|
||||
|
||||
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
|
||||
|
||||
fn set_seed(&self, _: u64) -> Result<()>;
|
||||
}
|
||||
|
@ -15,17 +15,6 @@ fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result
|
||||
}
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
static CANDLE_GRAD_DO_NOT_DETACH: bool = {
|
||||
match std::env::var("CANDLE_GRAD_DO_NOT_DETACH") {
|
||||
Ok(s) => {
|
||||
!s.is_empty() && s != "0"
|
||||
},
|
||||
Err(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Tensor {
|
||||
/// Return all the nodes that lead to this value in a topologically sorted vec, the first
|
||||
/// elements having dependencies on the latter ones, e.g. the first element if any is the
|
||||
@ -47,8 +36,6 @@ impl Tensor {
|
||||
// Do not call recursively on the "leaf" nodes.
|
||||
track_grad = true;
|
||||
nodes
|
||||
} else if node.dtype().is_int() {
|
||||
nodes
|
||||
} else if let Some(op) = node.op() {
|
||||
match op {
|
||||
Op::IndexAdd(t1, t2, t3, _)
|
||||
@ -68,27 +55,16 @@ impl Tensor {
|
||||
kernel: rhs,
|
||||
..
|
||||
}
|
||||
| Op::ConvTranspose1D {
|
||||
arg: lhs,
|
||||
kernel: rhs,
|
||||
..
|
||||
}
|
||||
| Op::Conv2D {
|
||||
arg: lhs,
|
||||
kernel: rhs,
|
||||
..
|
||||
}
|
||||
| Op::ConvTranspose2D {
|
||||
arg: lhs,
|
||||
kernel: rhs,
|
||||
..
|
||||
}
|
||||
| Op::CustomOp2(lhs, rhs, _)
|
||||
| Op::Binary(lhs, rhs, _)
|
||||
| Op::Gather(lhs, rhs, _)
|
||||
| Op::IndexSelect(lhs, rhs, _)
|
||||
| Op::Matmul(lhs, rhs)
|
||||
| Op::SliceScatter0(lhs, rhs, _) => {
|
||||
| Op::Matmul(lhs, rhs) => {
|
||||
let (tg, nodes) = walk(lhs, nodes, already_seen);
|
||||
track_grad |= tg;
|
||||
let (tg, nodes) = walk(rhs, nodes, already_seen);
|
||||
@ -109,40 +85,25 @@ impl Tensor {
|
||||
nodes
|
||||
}
|
||||
}
|
||||
Op::Unary(_node, UnaryOp::Ceil)
|
||||
| Op::Unary(_node, UnaryOp::Floor)
|
||||
| Op::Unary(_node, UnaryOp::Round) => nodes,
|
||||
Op::Reshape(node)
|
||||
| Op::UpsampleNearest1D(node)
|
||||
| Op::UpsampleNearest2D { arg: node, .. }
|
||||
| Op::UpsampleNearest2D(node)
|
||||
| Op::AvgPool2D { arg: node, .. }
|
||||
| Op::MaxPool2D { arg: node, .. }
|
||||
| Op::Copy(node)
|
||||
| Op::Broadcast(node)
|
||||
| Op::Cmp(node, _)
|
||||
| Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
|
||||
| Op::Reduce(node, _, _)
|
||||
| Op::ToDType(node)
|
||||
| Op::ToDevice(node)
|
||||
| Op::Transpose(node, _, _)
|
||||
| Op::Permute(node, _)
|
||||
| Op::Narrow(node, _, _, _)
|
||||
| Op::Unary(node, _)
|
||||
| Op::Elu(node, _)
|
||||
| Op::Powf(node, _)
|
||||
| Op::CustomOp1(node, _) => {
|
||||
let (tg, nodes) = walk(node, nodes, already_seen);
|
||||
track_grad |= tg;
|
||||
nodes
|
||||
}
|
||||
Op::ToDType(node) => {
|
||||
if node.dtype().is_float() {
|
||||
let (tg, nodes) = walk(node, nodes, already_seen);
|
||||
track_grad |= tg;
|
||||
nodes
|
||||
} else {
|
||||
nodes
|
||||
}
|
||||
}
|
||||
Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
|
||||
}
|
||||
} else {
|
||||
nodes
|
||||
@ -166,16 +127,10 @@ impl Tensor {
|
||||
if node.is_variable() {
|
||||
continue;
|
||||
}
|
||||
let grad = grads
|
||||
.remove(node)
|
||||
.expect("candle internal error - grad not populated");
|
||||
// https://github.com/huggingface/candle/issues/1241
|
||||
// Ideally, we would make these operations in place where possible to ensure that we
|
||||
// do not have to allocate too often. Here we just call `.detach` to avoid computing
|
||||
// the backprop graph of the backprop itself. This would be an issue for second order
|
||||
// derivatives but these are out of scope at the moment.
|
||||
let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
|
||||
let grad = if do_not_detach { grad } else { grad.detach() };
|
||||
let grad = grads.remove(node).unwrap();
|
||||
// TODO: We should perform all these operations in place (or at least not track the
|
||||
// whole graph). The only drawback would be if we wanted to support grad of grad but
|
||||
// this is out of scope.
|
||||
if let Some(op) = node.op() {
|
||||
match op {
|
||||
Op::Binary(lhs, rhs, BinaryOp::Add) => {
|
||||
@ -206,21 +161,6 @@ impl Tensor {
|
||||
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||
*rhs_sum_grad = rhs_sum_grad.sub(&rhs_grad)?;
|
||||
}
|
||||
Op::Binary(lhs, rhs, BinaryOp::Minimum)
|
||||
| Op::Binary(lhs, rhs, BinaryOp::Maximum) => {
|
||||
let mask_lhs = node.eq(lhs)?.to_dtype(grad.dtype())?;
|
||||
let mask_rhs = node.eq(rhs)?.to_dtype(grad.dtype())?;
|
||||
|
||||
// If both masks are 1 one the same point, we want to scale the
|
||||
// gradient by 0.5 rather than 1.
|
||||
let lhs_grad = mask_lhs.mul(&grad)?.div(&(&mask_rhs + 1.)?)?;
|
||||
let lhs_sum_grad = grads.or_insert(lhs)?;
|
||||
*lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
|
||||
|
||||
let rhs_grad = mask_rhs.mul(&grad)?.div(&(&mask_lhs + 1.)?)?;
|
||||
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
|
||||
}
|
||||
Op::WhereCond(pred, t, f) => {
|
||||
let zeros = grad.zeros_like()?;
|
||||
let t_sum_grad = grads.or_insert(t)?;
|
||||
@ -230,156 +170,13 @@ impl Tensor {
|
||||
let f_grad = pred.where_cond(&zeros, &grad)?;
|
||||
*f_sum_grad = f_sum_grad.add(&f_grad)?;
|
||||
}
|
||||
Op::Conv1D {
|
||||
arg,
|
||||
kernel,
|
||||
padding,
|
||||
stride,
|
||||
dilation,
|
||||
} => {
|
||||
// The output height for conv_transpose1d is:
|
||||
// (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1
|
||||
let grad_l_in = grad.dim(2)?;
|
||||
let k_size = kernel.dim(2)?;
|
||||
let out_size =
|
||||
(grad_l_in - 1) * stride + dilation * (k_size - 1) + 1 - 2 * padding;
|
||||
let out_padding = arg.dim(2)? - out_size;
|
||||
let grad_arg = grad.conv_transpose1d(
|
||||
kernel,
|
||||
*padding,
|
||||
out_padding,
|
||||
*stride,
|
||||
*dilation,
|
||||
)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&grad_arg)?;
|
||||
|
||||
let grad_kernel = arg
|
||||
.transpose(0, 1)?
|
||||
.conv1d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
|
||||
.transpose(0, 1)?;
|
||||
let sum_grad = grads.or_insert(kernel)?;
|
||||
let (_, _, k0) = kernel.dims3()?;
|
||||
let (_, _, g_k0) = grad_kernel.dims3()?;
|
||||
let grad_kernel = if g_k0 != k0 {
|
||||
grad_kernel.narrow(2, 0, k0)?
|
||||
} else {
|
||||
grad_kernel
|
||||
};
|
||||
*sum_grad = sum_grad.add(&grad_kernel)?;
|
||||
}
|
||||
Op::Conv2D {
|
||||
arg,
|
||||
kernel,
|
||||
padding,
|
||||
stride,
|
||||
dilation,
|
||||
} => {
|
||||
// The output height for conv_transpose2d is:
|
||||
// (i_h - 1) * stride - 2 * padding + dilation * (k_h - 1) + out_padding + 1
|
||||
let grad_h = grad.dim(2)?;
|
||||
let k_h = kernel.dim(2)?;
|
||||
let out_size =
|
||||
(grad_h - 1) * stride + dilation * (k_h - 1) + 1 - 2 * padding;
|
||||
let out_padding = arg.dim(2)? - out_size;
|
||||
let grad_arg = grad.conv_transpose2d(
|
||||
kernel,
|
||||
*padding,
|
||||
out_padding,
|
||||
*stride,
|
||||
*dilation,
|
||||
)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&grad_arg)?;
|
||||
|
||||
let grad_kernel = arg
|
||||
.transpose(0, 1)?
|
||||
.conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
|
||||
.transpose(0, 1)?;
|
||||
let sum_grad = grads.or_insert(kernel)?;
|
||||
let (_, _, k0, k1) = kernel.dims4()?;
|
||||
let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
|
||||
let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
|
||||
grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
|
||||
} else {
|
||||
grad_kernel
|
||||
};
|
||||
*sum_grad = sum_grad.add(&grad_kernel)?;
|
||||
}
|
||||
Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported {
|
||||
op: "conv-transpose1d",
|
||||
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
|
||||
Op::Conv2D { .. } => Err(Error::BackwardNotSupported { op: "conv2d" })?,
|
||||
Op::AvgPool2D { .. } => Err(Error::BackwardNotSupported { op: "avg-pool2d" })?,
|
||||
Op::MaxPool2D { .. } => Err(Error::BackwardNotSupported { op: "max-pool2d" })?,
|
||||
Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
|
||||
op: "upsample-nearest2d",
|
||||
})?,
|
||||
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
|
||||
op: "conv-transpose2d",
|
||||
})?,
|
||||
Op::AvgPool2D {
|
||||
arg,
|
||||
kernel_size,
|
||||
stride,
|
||||
} => {
|
||||
if kernel_size != stride {
|
||||
crate::bail!("backward not supported for avgpool2d if ksize {kernel_size:?} != stride {stride:?}")
|
||||
}
|
||||
let (_n, _c, h, w) = arg.dims4()?;
|
||||
let grad_arg = grad.upsample_nearest2d(h, w)?;
|
||||
let grad_arg =
|
||||
(grad_arg * (1f64 / (kernel_size.0 * kernel_size.1) as f64))?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&grad_arg)?;
|
||||
}
|
||||
Op::MaxPool2D {
|
||||
arg,
|
||||
kernel_size,
|
||||
stride,
|
||||
} => {
|
||||
if kernel_size != stride {
|
||||
crate::bail!("backward not supported for maxpool2d if ksize {kernel_size:?} != stride {stride:?}")
|
||||
}
|
||||
let (_n, _c, h, w) = arg.dims4()?;
|
||||
// For computing the max-pool gradient, we compute a mask where a 1 means
|
||||
// that the element is the maximum, then we apply this mask to the
|
||||
// upsampled gradient (taking into account that multiple max may exist so
|
||||
// we scale the gradient for this case).
|
||||
let node_upsampled = node.upsample_nearest2d(h, w)?;
|
||||
let mask = arg.eq(&node_upsampled)?.to_dtype(arg.dtype())?;
|
||||
let avg = mask.avg_pool2d_with_stride(*kernel_size, *stride)?;
|
||||
let grad_arg = ((grad * avg)?.upsample_nearest2d(h, w)? * mask)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&grad_arg)?;
|
||||
}
Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
    op: "upsample-nearest1d",
})?,
Op::UpsampleNearest2D {
    arg,
    target_h,
    target_w,
} => {
    let (_n, c, h, w) = arg.dims4()?;
    if target_h % h != 0 || target_w % w != 0 {
        crate::bail!("backward not supported for non integer upscaling factors")
    }
    let scale_h = target_h / h;
    let scale_w = target_w / w;

    if scale_h != scale_w {
        crate::bail!("backward not supported for non uniform upscaling factors")
    };
    let kernel =
        Tensor::ones((c, 1, scale_h, scale_w), arg.dtype(), arg.device())?;
    let conv_sum = grad.conv2d(&kernel, 0, scale_h, 1, c)?;
    let sum_grad = grads.or_insert(arg)?;
    *sum_grad = conv_sum;
}
Op::SliceScatter0(lhs, rhs, start_rhs) => {
    let rhs_sum_grad = grads.or_insert(rhs)?;
    let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
    *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;

    let lhs_sum_grad = grads.or_insert(lhs)?;
    let lhs_grad = grad.slice_scatter0(&rhs.zeros_like()?, *start_rhs)?;
    *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?
}
Op::Gather(arg, indexes, dim) => {
    let sum_grad = grads.or_insert(arg)?;
    *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
@ -471,7 +268,7 @@ impl Tensor {
}
Op::ToDType(arg) => {
    let sum_grad = grads.or_insert(arg)?;
    *sum_grad = sum_grad.add(&grad.to_dtype(arg.dtype())?)?
    *sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
}
Op::Copy(arg) => {
    let sum_grad = grads.or_insert(arg)?;
@ -494,11 +291,6 @@ impl Tensor {
    let sum_grad = grads.or_insert(arg)?;
    *sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
}
Op::Unary(arg, UnaryOp::Tanh) => {
    let sum_grad = grads.or_insert(arg)?;
    let minus_dtanh = (node.sqr()? - 1.)?;
    *sum_grad = sum_grad.sub(&(&grad * &minus_dtanh)?)?
}
Op::Unary(arg, UnaryOp::Abs) => {
    let sum_grad = grads.or_insert(arg)?;
    let ones = arg.ones_like()?;
@ -551,59 +343,13 @@ impl Tensor {
    let sum_grad = grads.or_insert(arg)?;
    *sum_grad = sum_grad.add(&arg_grad)?
}
Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?,
Op::Unary(_, UnaryOp::Floor) => {
    Err(Error::BackwardNotSupported { op: "floor" })?
}
Op::Unary(_, UnaryOp::Round) => {
    Err(Error::BackwardNotSupported { op: "round" })?
}
Op::Unary(arg, UnaryOp::Gelu) => {
    let sum_grad = grads.or_insert(arg)?;
    let cube = arg.powf(3.)?;
    let tanh = (0.0356774 * &cube + (0.797885 * arg)?)?.tanh()?;
    let gelu_grad = (((0.5 * &tanh)?
        + (0.0535161 * cube + (0.398942 * arg)?)? * (1. - tanh.powf(2.)?))?
        + 0.5)?;
    *sum_grad = sum_grad.add(&(&grad * gelu_grad)?)?
}
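// Editor's sketch (not part of the diff): the same tanh-based gelu derivative
// as the tensor code above, written for a scalar f64 so the constants are
// easy to check in isolation.
fn gelu_grad_scalar(x: f64) -> f64 {
    let cube = x.powi(3);
    let tanh = (0.0356774 * cube + 0.797885 * x).tanh();
    0.5 * tanh + (0.0535161 * cube + 0.398942 * x) * (1. - tanh.powi(2)) + 0.5
}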
Op::Unary(arg, UnaryOp::Erf) => {
    let sum_grad = grads.or_insert(arg)?;
    // d/dx erf(x) = 2/sqrt(pi) * e^(-x^2)
    let erf_grad =
        (2. / std::f64::consts::PI.sqrt()) * (arg.sqr()?.neg()?).exp()?;
    *sum_grad = sum_grad.add(&(&grad * erf_grad)?)?
}
Op::Unary(arg, UnaryOp::GeluErf) => {
    let sum_grad = grads.or_insert(arg)?;
    // d/dx gelu_erf(x) = 0.5 + 0.398942 e^(-x^2/2) x + 0.5 erf(x/sqrt(2))
    let neg_half_square = (arg.sqr()?.neg()? / 2.)?;
    let scaled_exp_arg = (0.398942 * neg_half_square.exp()? * arg)?;
    let arg_scaled_sqrt = (arg / 2f64.sqrt())?;
    let erf_scaled_sqrt = (0.5 * arg_scaled_sqrt.erf()?)?;
    let gelu_erf_grad = (0.5 + scaled_exp_arg + erf_scaled_sqrt)?;
    *sum_grad = sum_grad.add(&(&grad * gelu_erf_grad)?)?;
}
Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
Op::Unary(arg, UnaryOp::Relu) => {
    let sum_grad = grads.or_insert(arg)?;
    let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
    *sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
}
Op::Elu(arg, alpha) => {
    // d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
    let sum_grad = grads.or_insert(arg)?;
    let zeros = arg.zeros_like()?;
    let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
    let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
    let negative_exp_mask = ((negative_mask * arg.exp())? * *alpha)?;
    let combined_mask = (positive_mask + negative_exp_mask)?;
    *sum_grad = sum_grad.add(&(grad * combined_mask)?)?
}
Op::Powf(arg, e) => {
    let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?;
    let sum_grad = grads.or_insert(arg)?;
    *sum_grad = sum_grad.add(&arg_grad)?
}
Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
Op::CustomOp1(arg, c) => {
    if let Some(arg_grad) = c.bwd(arg, node, &grad)? {
        let sum_grad = grads.or_insert(arg)?;
@ -657,15 +403,6 @@ impl Tensor {
    let sum_grad = grads.or_insert(arg)?;
    *sum_grad = sum_grad.add(&arg_grad)?
}
Op::Permute(arg, dims) => {
    let mut inv_dims = vec![0; dims.len()];
    for (i, &dim_idx) in dims.iter().enumerate() {
        inv_dims[dim_idx] = i
    }
    let arg_grad = grad.permute(inv_dims)?;
    let sum_grad = grads.or_insert(arg)?;
    *sum_grad = sum_grad.add(&arg_grad)?
}
};
}
}
@ -673,7 +410,6 @@ impl Tensor {
}
}

#[derive(Debug)]
pub struct GradStore(HashMap<TensorId, Tensor>);

impl GradStore {

@ -1,5 +1,3 @@
use crate::{op::BackpropOp, op::Op, Error, Result, Tensor};

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv1D {
    pub(crate) b_size: usize,
@ -11,12 +9,12 @@ pub struct ParamsConv1D {
    pub(crate) k_size: usize,
    pub(crate) padding: usize,
    pub(crate) stride: usize,
    pub(crate) dilation: usize,
}

impl ParamsConv1D {
    pub(crate) fn l_out(&self) -> usize {
        (self.l_in + 2 * self.padding - self.dilation * (self.k_size - 1) - 1) / self.stride + 1
        let dilation = 1;
        (self.l_in + 2 * self.padding - dilation * (self.k_size - 1) - 1) / self.stride + 1
    }

    pub(crate) fn out_dims(&self) -> Vec<usize> {
@ -25,46 +23,6 @@ impl ParamsConv1D {
    }
}
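// Editor's sketch (not part of the diff): the standard dilated-convolution
// output-length formula from `l_out` above on concrete numbers; the free
// function is hypothetical.
fn conv_len(l_in: usize, k: usize, padding: usize, stride: usize, dilation: usize) -> usize {
    (l_in + 2 * padding - dilation * (k - 1) - 1) / stride + 1
}

#[test]
fn conv_len_example() {
    // l_in=10, k=3, padding=1, stride=2, dilation=1: (10 + 2 - 2 - 1) / 2 + 1 = 5
    assert_eq!(conv_len(10, 3, 1, 2, 1), 5);
}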
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConvTranspose1D {
    pub(crate) b_size: usize,
    pub(crate) l_in: usize,
    pub(crate) c_out: usize,
    pub(crate) c_in: usize,
    pub(crate) k_size: usize,
    pub(crate) padding: usize,
    pub(crate) output_padding: usize,
    pub(crate) stride: usize,
    pub(crate) dilation: usize,
}

impl ParamsConvTranspose1D {
    pub(crate) fn l_out(&self) -> usize {
        (self.l_in - 1) * self.stride - 2 * self.padding
            + self.dilation * (self.k_size - 1)
            + self.output_padding
            + 1
    }

    pub(crate) fn out_dims(&self) -> Vec<usize> {
        let l_out = self.l_out();
        vec![self.b_size, self.c_out, l_out]
    }
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CudnnFwdAlgo {
    ImplicitGemm,
    ImplicitPrecompGemm,
    Gemm,
    Direct,
    Fft,
    FftTiling,
    Winograd,
    WinogradNonFused,
    Count,
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv2D {
    pub(crate) b_size: usize,
@ -76,260 +34,20 @@ pub struct ParamsConv2D {
    pub(crate) c_in: usize,
    pub(crate) padding: usize,
    pub(crate) stride: usize,
    pub(crate) dilation: usize,
    pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
}

impl ParamsConv2D {
    pub(crate) fn out_h(&self) -> usize {
        (self.i_h + 2 * self.padding - self.dilation * (self.k_h - 1) - 1) / self.stride + 1
        let dilation = 1;
        (self.i_h + 2 * self.padding - dilation * (self.k_h - 1) - 1) / self.stride + 1
    }

    pub(crate) fn out_w(&self) -> usize {
        (self.i_w + 2 * self.padding - self.dilation * (self.k_w - 1) - 1) / self.stride + 1
        let dilation = 1;
        (self.i_w + 2 * self.padding - dilation * (self.k_w - 1) - 1) / self.stride + 1
    }

    pub(crate) fn out_dims(&self) -> Vec<usize> {
        vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
    }
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConvTranspose2D {
    pub(crate) b_size: usize,
    pub(crate) i_h: usize,
    pub(crate) i_w: usize,
    pub(crate) k_h: usize,
    pub(crate) k_w: usize,
    pub(crate) c_out: usize,
    pub(crate) c_in: usize,
    pub(crate) padding: usize,
    pub(crate) output_padding: usize,
    pub(crate) stride: usize,
    pub(crate) dilation: usize,
}

impl ParamsConvTranspose2D {
    pub(crate) fn out_h(&self) -> usize {
        (self.i_h - 1) * self.stride + self.dilation * (self.k_h - 1) + self.output_padding + 1
            - 2 * self.padding
    }

    pub(crate) fn out_w(&self) -> usize {
        (self.i_w - 1) * self.stride + self.dilation * (self.k_w - 1) + self.output_padding + 1
            - 2 * self.padding
    }

    pub(crate) fn out_dims(&self) -> Vec<usize> {
        vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
    }
}

impl Tensor {
    fn conv1d_single_group(&self, kernel: &Self, params: &ParamsConv1D) -> Result<Self> {
        let storage =
            self.storage()
                .conv1d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D {
            arg,
            kernel,
            padding: params.padding,
            stride: params.stride,
            dilation: params.dilation,
        });
        let out_dims = params.out_dims();
        Ok(crate::tensor::from_storage(storage, out_dims, op, false))
    }

    /// Applies a 1D convolution over the input tensor.
    pub fn conv1d(
        &self,
        kernel: &Self,
        padding: usize,
        stride: usize,
        dilation: usize,
        groups: usize,
    ) -> Result<Self> {
        let (c_out, c_in_k, k_size) = kernel.dims3()?;
        let (b_size, c_in, l_in) = self.dims3()?;
        if c_in != c_in_k * groups {
            Err(Error::Conv1dInvalidArgs {
                inp_shape: self.shape().clone(),
                k_shape: kernel.shape().clone(),
                padding,
                stride,
                msg: "the number of in-channels on the input doesn't match the kernel size",
            }
            .bt())?
        }

        let params = ParamsConv1D {
            b_size,
            l_in,
            c_out: c_out / groups,
            c_in: c_in / groups,
            k_size,
            padding,
            stride,
            dilation,
        };
        if groups == 1 {
            self.conv1d_single_group(kernel, &params)
        } else {
            let blocks = self.chunk(groups, 1)?;
            let kernel = kernel.chunk(groups, 0)?;
            let blocks = blocks
                .iter()
                .zip(&kernel)
                .map(|(block, kernel)| block.conv1d_single_group(kernel, &params))
                .collect::<Result<Vec<_>>>()?;
            Tensor::cat(&blocks, 1)
        }
    }

    /// Applies a 1D transposed convolution over the input tensor.
    pub fn conv_transpose1d(
        &self,
        kernel: &Self,
        padding: usize,
        output_padding: usize,
        stride: usize,
        dilation: usize,
    ) -> Result<Self> {
        let (b_size, c_in, l_in) = self.dims3()?;
        let (c_in_k, c_out, k_size) = kernel.dims3()?;
        if c_in != c_in_k {
            crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
        }
        let params = ParamsConvTranspose1D {
            b_size,
            l_in,
            k_size,
            c_out,
            c_in,
            padding,
            output_padding,
            stride,
            dilation,
        };
        let storage = self.storage().conv_transpose1d(
            self.layout(),
            &kernel.storage(),
            kernel.layout(),
            &params,
        )?;
        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D {
            arg,
            kernel,
            padding: params.padding,
            output_padding: params.output_padding,
            stride: params.stride,
            dilation: params.dilation,
        });
        let out_dims = params.out_dims();
        Ok(crate::tensor::from_storage(storage, out_dims, op, false))
    }

    fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
        let storage =
            self.storage()
                .conv2d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D {
            arg,
            kernel,
            padding: params.padding,
            stride: params.stride,
            dilation: params.dilation,
        });
        let out_dims = params.out_dims();
        Ok(crate::tensor::from_storage(storage, out_dims, op, false))
    }

    /// Applies a 2D convolution over the input tensor.
    pub fn conv2d(
        &self,
        kernel: &Self,
        padding: usize,
        stride: usize,
        dilation: usize,
        groups: usize,
    ) -> Result<Self> {
        let (b_size, c_in, i_h, i_w) = self.dims4()?;
        let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
        if c_in != c_in_k * groups {
            crate::bail!(
                "in_channel mismatch between input ({c_in}, groups {groups}) and kernel ({c_in_k})"
            )
        }
        let params = ParamsConv2D {
            b_size,
            i_h,
            i_w,
            k_h,
            k_w,
            c_out: c_out / groups,
            c_in: c_in / groups,
            padding,
            stride,
            dilation,
            cudnn_fwd_algo: None,
        };
        if groups == 1 {
            self.conv2d_single_group(kernel, &params)
        } else {
            let blocks = self.chunk(groups, 1)?;
            let kernel = kernel.chunk(groups, 0)?;
            let blocks = blocks
                .iter()
                .zip(&kernel)
                .map(|(block, kernel)| block.conv2d_single_group(kernel, &params))
                .collect::<Result<Vec<_>>>()?;
            Tensor::cat(&blocks, 1)
        }
    }

    /// Applies a 2D transposed convolution over the input tensor.
    pub fn conv_transpose2d(
        &self,
        kernel: &Self,
        padding: usize,
        output_padding: usize,
        stride: usize,
        dilation: usize,
    ) -> Result<Self> {
        let (b_size, c_in, i_h, i_w) = self.dims4()?;
        let (c_in_k, c_out, k_h, k_w) = kernel.dims4()?;
        if c_in != c_in_k {
            crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
        }
        let params = ParamsConvTranspose2D {
            b_size,
            i_h,
            i_w,
            k_h,
            k_w,
            c_out,
            c_in,
            padding,
            output_padding,
            stride,
            dilation,
        };
        let storage = self.storage().conv_transpose2d(
            self.layout(),
            &kernel.storage(),
            kernel.layout(),
            &params,
        )?;
        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose2D {
            arg,
            kernel,
            padding: params.padding,
            output_padding: params.output_padding,
            stride: params.stride,
            dilation: params.dilation,
        });
        let out_dims = params.out_dims();
        Ok(crate::tensor::from_storage(storage, out_dims, op, false))
    }
}

@ -92,7 +92,6 @@ from_tensor!(f64);
from_tensor!(f32);
from_tensor!(f16);
from_tensor!(bf16);
from_tensor!(i64);
from_tensor!(u32);
from_tensor!(u8);

@ -130,11 +129,6 @@ impl Tensor {
            f.write_u32::<LittleEndian>(v)?
        }
    }
    DType::I64 => {
        for v in vs.to_vec1::<i64>()? {
            f.write_i64::<LittleEndian>(v)?
        }
    }
    DType::U8 => {
        let vs = vs.to_vec1::<u8>()?;
        f.write_all(&vs)?;

@ -103,7 +103,7 @@ impl CpuF16<ARR> for CurrentCpuF16 {
    for i in 0..8 {
        tmp[i] = (*mem_addr.add(i)).to_f32();
    }
    _mm256_loadu_ps(tmp.as_ptr())
    _mm_loadu_ps(tmp.as_ptr())
}

unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {

@ -1,763 +0,0 @@
#![allow(clippy::excessive_precision)]
// Code taken from https://github.com/statrs-dev/statrs
//! Provides the [error](https://en.wikipedia.org/wiki/Error_function) and
//! related functions

mod evaluate {
    //! Provides functions that don't have a numerical solution and must
    //! be solved computationally (e.g. evaluation of a polynomial)

    /// evaluates a polynomial at `z` where `coeff` are the coefficients
    /// to a polynomial of order `k` where `k` is the length of `coeff` and the
    /// coefficient
    /// to the `k`th power is the `k`th element in coeff. E.g. [3,-1,2] equates to
    /// `2z^2 - z + 3`
    ///
    /// # Remarks
    ///
    /// Returns 0 for a 0 length coefficient slice
    pub fn polynomial(z: f64, coeff: &[f64]) -> f64 {
        let n = coeff.len();
        if n == 0 {
            return 0.0;
        }

        let mut sum = *coeff.last().unwrap();
        for c in coeff[0..n - 1].iter().rev() {
            sum = *c + z * sum;
        }
        sum
    }
}
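// Editor's note (not part of the diff): `polynomial` evaluates with Horner's
// rule, so `[3., -1., 2.]` encodes 2z^2 - z + 3; a quick check at z = 2:
#[cfg(test)]
mod polynomial_example {
    #[test]
    fn horner_example() {
        // 2 * 4 - 2 + 3 = 9
        assert_eq!(super::evaluate::polynomial(2.0, &[3.0, -1.0, 2.0]), 9.0);
    }
}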
use std::f64;

/// `erf` calculates the error function at `x`.
pub fn erf(x: f64) -> f64 {
    if x.is_nan() {
        f64::NAN
    } else if x >= 0.0 && x.is_infinite() {
        1.0
    } else if x <= 0.0 && x.is_infinite() {
        -1.0
    } else if x == 0. {
        0.0
    } else {
        erf_impl(x, false)
    }
}

/// `erf_inv` calculates the inverse error function
/// at `x`.
pub fn erf_inv(x: f64) -> f64 {
    if x == 0.0 {
        0.0
    } else if x >= 1.0 {
        f64::INFINITY
    } else if x <= -1.0 {
        f64::NEG_INFINITY
    } else if x < 0.0 {
        erf_inv_impl(-x, 1.0 + x, -1.0)
    } else {
        erf_inv_impl(x, 1.0 - x, 1.0)
    }
}

/// `erfc` calculates the complementary error function
/// at `x`.
pub fn erfc(x: f64) -> f64 {
    if x.is_nan() {
        f64::NAN
    } else if x == f64::INFINITY {
        0.0
    } else if x == f64::NEG_INFINITY {
        2.0
    } else {
        erf_impl(x, true)
    }
}

/// `erfc_inv` calculates the complementary inverse
/// error function at `x`.
pub fn erfc_inv(x: f64) -> f64 {
    if x <= 0.0 {
        f64::INFINITY
    } else if x >= 2.0 {
        f64::NEG_INFINITY
    } else if x > 1.0 {
        erf_inv_impl(-1.0 + x, 2.0 - x, -1.0)
    } else {
        erf_inv_impl(1.0 - x, x, 1.0)
    }
}

// **********************************************************
// ********** Coefficients for erf_impl polynomial **********
// **********************************************************

/// Polynomial coefficients for a numerator of `erf_impl`
/// in the interval [1e-10, 0.5].
const ERF_IMPL_AN: &[f64] = &[
    0.00337916709551257388990745,
    -0.00073695653048167948530905,
    -0.374732337392919607868241,
    0.0817442448733587196071743,
    -0.0421089319936548595203468,
    0.0070165709512095756344528,
    -0.00495091255982435110337458,
    0.000871646599037922480317225,
];

/// Polynomial coefficients for a denominator of `erf_impl`
/// in the interval [1e-10, 0.5]
const ERF_IMPL_AD: &[f64] = &[
    1.0,
    -0.218088218087924645390535,
    0.412542972725442099083918,
    -0.0841891147873106755410271,
    0.0655338856400241519690695,
    -0.0120019604454941768171266,
    0.00408165558926174048329689,
    -0.000615900721557769691924509,
];

/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [0.5, 0.75].
const ERF_IMPL_BN: &[f64] = &[
    -0.0361790390718262471360258,
    0.292251883444882683221149,
    0.281447041797604512774415,
    0.125610208862766947294894,
    0.0274135028268930549240776,
    0.00250839672168065762786937,
];

/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [0.5, 0.75].
const ERF_IMPL_BD: &[f64] = &[
    1.0,
    1.8545005897903486499845,
    1.43575803037831418074962,
    0.582827658753036572454135,
    0.124810476932949746447682,
    0.0113724176546353285778481,
];

/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [0.75, 1.25].
const ERF_IMPL_CN: &[f64] = &[
    -0.0397876892611136856954425,
    0.153165212467878293257683,
    0.191260295600936245503129,
    0.10276327061989304213645,
    0.029637090615738836726027,
    0.0046093486780275489468812,
    0.000307607820348680180548455,
];

/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [0.75, 1.25].
const ERF_IMPL_CD: &[f64] = &[
    1.0,
    1.95520072987627704987886,
    1.64762317199384860109595,
    0.768238607022126250082483,
    0.209793185936509782784315,
    0.0319569316899913392596356,
    0.00213363160895785378615014,
];

/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [1.25, 2.25].
const ERF_IMPL_DN: &[f64] = &[
    -0.0300838560557949717328341,
    0.0538578829844454508530552,
    0.0726211541651914182692959,
    0.0367628469888049348429018,
    0.00964629015572527529605267,
    0.00133453480075291076745275,
    0.778087599782504251917881e-4,
];

/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [1.25, 2.25].
const ERF_IMPL_DD: &[f64] = &[
    1.0,
    1.75967098147167528287343,
    1.32883571437961120556307,
    0.552528596508757581287907,
    0.133793056941332861912279,
    0.0179509645176280768640766,
    0.00104712440019937356634038,
    -0.106640381820357337177643e-7,
];

/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [2.25, 3.5].
const ERF_IMPL_EN: &[f64] = &[
    -0.0117907570137227847827732,
    0.014262132090538809896674,
    0.0202234435902960820020765,
    0.00930668299990432009042239,
    0.00213357802422065994322516,
    0.00025022987386460102395382,
    0.120534912219588189822126e-4,
];

/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [2.25, 3.5].
const ERF_IMPL_ED: &[f64] = &[
    1.0,
    1.50376225203620482047419,
    0.965397786204462896346934,
    0.339265230476796681555511,
    0.0689740649541569716897427,
    0.00771060262491768307365526,
    0.000371421101531069302990367,
];

/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [3.5, 5.25].
const ERF_IMPL_FN: &[f64] = &[
    -0.00546954795538729307482955,
    0.00404190278731707110245394,
    0.0054963369553161170521356,
    0.00212616472603945399437862,
    0.000394984014495083900689956,
    0.365565477064442377259271e-4,
    0.135485897109932323253786e-5,
];

/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [3.5, 5.25].
const ERF_IMPL_FD: &[f64] = &[
    1.0,
    1.21019697773630784832251,
    0.620914668221143886601045,
    0.173038430661142762569515,
    0.0276550813773432047594539,
    0.00240625974424309709745382,
    0.891811817251336577241006e-4,
    -0.465528836283382684461025e-11,
];

/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [5.25, 8].
const ERF_IMPL_GN: &[f64] = &[
    -0.00270722535905778347999196,
    0.0013187563425029400461378,
    0.00119925933261002333923989,
    0.00027849619811344664248235,
    0.267822988218331849989363e-4,
    0.923043672315028197865066e-6,
];

/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [5.25, 8].
const ERF_IMPL_GD: &[f64] = &[
    1.0,
    0.814632808543141591118279,
    0.268901665856299542168425,
    0.0449877216103041118694989,
    0.00381759663320248459168994,
    0.000131571897888596914350697,
    0.404815359675764138445257e-11,
];

/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [8, 11.5].
const ERF_IMPL_HN: &[f64] = &[
    -0.00109946720691742196814323,
    0.000406425442750422675169153,
    0.000274499489416900707787024,
    0.465293770646659383436343e-4,
    0.320955425395767463401993e-5,
    0.778286018145020892261936e-7,
];

/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [8, 11.5].
const ERF_IMPL_HD: &[f64] = &[
    1.0,
    0.588173710611846046373373,
    0.139363331289409746077541,
    0.0166329340417083678763028,
    0.00100023921310234908642639,
    0.24254837521587225125068e-4,
];

/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [11.5, 17].
const ERF_IMPL_IN: &[f64] = &[
    -0.00056907993601094962855594,
    0.000169498540373762264416984,
    0.518472354581100890120501e-4,
    0.382819312231928859704678e-5,
    0.824989931281894431781794e-7,
];

/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [11.5, 17].
const ERF_IMPL_ID: &[f64] = &[
    1.0,
    0.339637250051139347430323,
    0.043472647870310663055044,
    0.00248549335224637114641629,
    0.535633305337152900549536e-4,
    -0.117490944405459578783846e-12,
];

/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [17, 24].
const ERF_IMPL_JN: &[f64] = &[
    -0.000241313599483991337479091,
    0.574224975202501512365975e-4,
    0.115998962927383778460557e-4,
    0.581762134402593739370875e-6,
    0.853971555085673614607418e-8,
];

/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [17, 24].
const ERF_IMPL_JD: &[f64] = &[
    1.0,
    0.233044138299687841018015,
    0.0204186940546440312625597,
    0.000797185647564398289151125,
    0.117019281670172327758019e-4,
];

/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [24, 38].
const ERF_IMPL_KN: &[f64] = &[
    -0.000146674699277760365803642,
    0.162666552112280519955647e-4,
    0.269116248509165239294897e-5,
    0.979584479468091935086972e-7,
    0.101994647625723465722285e-8,
];

/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [24, 38].
const ERF_IMPL_KD: &[f64] = &[
    1.0,
    0.165907812944847226546036,
    0.0103361716191505884359634,
    0.000286593026373868366935721,
    0.298401570840900340874568e-5,
];

/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [38, 60].
const ERF_IMPL_LN: &[f64] = &[
    -0.583905797629771786720406e-4,
    0.412510325105496173512992e-5,
    0.431790922420250949096906e-6,
    0.993365155590013193345569e-8,
    0.653480510020104699270084e-10,
];

/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [38, 60].
const ERF_IMPL_LD: &[f64] = &[
    1.0,
    0.105077086072039915406159,
    0.00414278428675475620830226,
    0.726338754644523769144108e-4,
    0.477818471047398785369849e-6,
];

/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [60, 85].
const ERF_IMPL_MN: &[f64] = &[
    -0.196457797609229579459841e-4,
    0.157243887666800692441195e-5,
    0.543902511192700878690335e-7,
    0.317472492369117710852685e-9,
];

/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [60, 85].
const ERF_IMPL_MD: &[f64] = &[
    1.0,
    0.052803989240957632204885,
    0.000926876069151753290378112,
    0.541011723226630257077328e-5,
    0.535093845803642394908747e-15,
];

/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [85, 110].
const ERF_IMPL_NN: &[f64] = &[
    -0.789224703978722689089794e-5,
    0.622088451660986955124162e-6,
    0.145728445676882396797184e-7,
    0.603715505542715364529243e-10,
];

/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [85, 110].
const ERF_IMPL_ND: &[f64] = &[
    1.0,
    0.0375328846356293715248719,
    0.000467919535974625308126054,
    0.193847039275845656900547e-5,
];

// **********************************************************
// ********** Coefficients for erf_inv_impl polynomial ******
// **********************************************************

/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0, 0.5].
const ERF_INV_IMPL_AN: &[f64] = &[
    -0.000508781949658280665617,
    -0.00836874819741736770379,
    0.0334806625409744615033,
    -0.0126926147662974029034,
    -0.0365637971411762664006,
    0.0219878681111168899165,
    0.00822687874676915743155,
    -0.00538772965071242932965,
];

/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0, 0.5].
const ERF_INV_IMPL_AD: &[f64] = &[
    1.0,
    -0.970005043303290640362,
    -1.56574558234175846809,
    1.56221558398423026363,
    0.662328840472002992063,
    -0.71228902341542847553,
    -0.0527396382340099713954,
    0.0795283687341571680018,
    -0.00233393759374190016776,
    0.000886216390456424707504,
];

/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.5, 0.75].
const ERF_INV_IMPL_BN: &[f64] = &[
    -0.202433508355938759655,
    0.105264680699391713268,
    8.37050328343119927838,
    17.6447298408374015486,
    -18.8510648058714251895,
    -44.6382324441786960818,
    17.445385985570866523,
    21.1294655448340526258,
    -3.67192254707729348546,
];

/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.5, 0.75].
const ERF_INV_IMPL_BD: &[f64] = &[
    1.0,
    6.24264124854247537712,
    3.9713437953343869095,
    -28.6608180499800029974,
    -20.1432634680485188801,
    48.5609213108739935468,
    10.8268667355460159008,
    -22.6436933413139721736,
    1.72114765761200282724,
];

/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.75, 1] with x less than 3.
const ERF_INV_IMPL_CN: &[f64] = &[
    -0.131102781679951906451,
    -0.163794047193317060787,
    0.117030156341995252019,
    0.387079738972604337464,
    0.337785538912035898924,
    0.142869534408157156766,
    0.0290157910005329060432,
    0.00214558995388805277169,
    -0.679465575181126350155e-6,
    0.285225331782217055858e-7,
    -0.681149956853776992068e-9,
];

/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.75, 1] with x less than 3.
const ERF_INV_IMPL_CD: &[f64] = &[
    1.0,
    3.46625407242567245975,
    5.38168345707006855425,
    4.77846592945843778382,
    2.59301921623620271374,
    0.848854343457902036425,
    0.152264338295331783612,
    0.01105924229346489121,
];

/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 3 and 6.
const ERF_INV_IMPL_DN: &[f64] = &[
    -0.0350353787183177984712,
    -0.00222426529213447927281,
    0.0185573306514231072324,
    0.00950804701325919603619,
    0.00187123492819559223345,
    0.000157544617424960554631,
    0.460469890584317994083e-5,
    -0.230404776911882601748e-9,
    0.266339227425782031962e-11,
];

/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 3 and 6.
const ERF_INV_IMPL_DD: &[f64] = &[
    1.0,
    1.3653349817554063097,
    0.762059164553623404043,
    0.220091105764131249824,
    0.0341589143670947727934,
    0.00263861676657015992959,
    0.764675292302794483503e-4,
];

/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 6 and 18.
const ERF_INV_IMPL_EN: &[f64] = &[
    -0.0167431005076633737133,
    -0.00112951438745580278863,
    0.00105628862152492910091,
    0.000209386317487588078668,
    0.149624783758342370182e-4,
    0.449696789927706453732e-6,
    0.462596163522878599135e-8,
    -0.281128735628831791805e-13,
    0.99055709973310326855e-16,
];

/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 6 and 18.
const ERF_INV_IMPL_ED: &[f64] = &[
    1.0,
    0.591429344886417493481,
    0.138151865749083321638,
    0.0160746087093676504695,
    0.000964011807005165528527,
    0.275335474764726041141e-4,
    0.282243172016108031869e-6,
];

/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 18 and 44.
const ERF_INV_IMPL_FN: &[f64] = &[
    -0.0024978212791898131227,
    -0.779190719229053954292e-5,
    0.254723037413027451751e-4,
    0.162397777342510920873e-5,
    0.396341011304801168516e-7,
    0.411632831190944208473e-9,
    0.145596286718675035587e-11,
    -0.116765012397184275695e-17,
];

/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 18 and 44.
const ERF_INV_IMPL_FD: &[f64] = &[
    1.0,
    0.207123112214422517181,
    0.0169410838120975906478,
    0.000690538265622684595676,
    0.145007359818232637924e-4,
    0.144437756628144157666e-6,
    0.509761276599778486139e-9,
];

/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.75, 1] with x greater than 44.
const ERF_INV_IMPL_GN: &[f64] = &[
    -0.000539042911019078575891,
    -0.28398759004727721098e-6,
    0.899465114892291446442e-6,
    0.229345859265920864296e-7,
    0.225561444863500149219e-9,
    0.947846627503022684216e-12,
    0.135880130108924861008e-14,
    -0.348890393399948882918e-21,
];

/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.75, 1] with x greater than 44.
const ERF_INV_IMPL_GD: &[f64] = &[
    1.0,
    0.0845746234001899436914,
    0.00282092984726264681981,
    0.468292921940894236786e-4,
    0.399968812193862100054e-6,
    0.161809290887904476097e-8,
    0.231558608310259605225e-11,
];

/// `erf_impl` computes the error function at `z`.
/// If `inv` is true, `1 - erf` is calculated as opposed to `erf`
fn erf_impl(z: f64, inv: bool) -> f64 {
    if z < 0.0 {
        if !inv {
            return -erf_impl(-z, false);
        }
        if z < -0.5 {
            return 2.0 - erf_impl(-z, true);
        }
        return 1.0 + erf_impl(-z, false);
    }

    let result = if z < 0.5 {
        if z < 1e-10 {
            z * 1.125 + z * 0.003379167095512573896158903121545171688
        } else {
            z * 1.125
                + z * evaluate::polynomial(z, ERF_IMPL_AN) / evaluate::polynomial(z, ERF_IMPL_AD)
        }
    } else if z < 110.0 {
        let (r, b) = if z < 0.75 {
            (
                evaluate::polynomial(z - 0.5, ERF_IMPL_BN)
                    / evaluate::polynomial(z - 0.5, ERF_IMPL_BD),
                0.3440242112,
            )
        } else if z < 1.25 {
            (
                evaluate::polynomial(z - 0.75, ERF_IMPL_CN)
                    / evaluate::polynomial(z - 0.75, ERF_IMPL_CD),
                0.419990927,
            )
        } else if z < 2.25 {
            (
                evaluate::polynomial(z - 1.25, ERF_IMPL_DN)
                    / evaluate::polynomial(z - 1.25, ERF_IMPL_DD),
                0.4898625016,
            )
        } else if z < 3.5 {
            (
                evaluate::polynomial(z - 2.25, ERF_IMPL_EN)
                    / evaluate::polynomial(z - 2.25, ERF_IMPL_ED),
                0.5317370892,
            )
        } else if z < 5.25 {
            (
                evaluate::polynomial(z - 3.5, ERF_IMPL_FN)
                    / evaluate::polynomial(z - 3.5, ERF_IMPL_FD),
                0.5489973426,
            )
        } else if z < 8.0 {
            (
                evaluate::polynomial(z - 5.25, ERF_IMPL_GN)
                    / evaluate::polynomial(z - 5.25, ERF_IMPL_GD),
                0.5571740866,
            )
        } else if z < 11.5 {
            (
                evaluate::polynomial(z - 8.0, ERF_IMPL_HN)
                    / evaluate::polynomial(z - 8.0, ERF_IMPL_HD),
                0.5609807968,
            )
        } else if z < 17.0 {
            (
                evaluate::polynomial(z - 11.5, ERF_IMPL_IN)
                    / evaluate::polynomial(z - 11.5, ERF_IMPL_ID),
                0.5626493692,
            )
        } else if z < 24.0 {
            (
                evaluate::polynomial(z - 17.0, ERF_IMPL_JN)
                    / evaluate::polynomial(z - 17.0, ERF_IMPL_JD),
                0.5634598136,
            )
        } else if z < 38.0 {
            (
                evaluate::polynomial(z - 24.0, ERF_IMPL_KN)
                    / evaluate::polynomial(z - 24.0, ERF_IMPL_KD),
                0.5638477802,
            )
        } else if z < 60.0 {
            (
                evaluate::polynomial(z - 38.0, ERF_IMPL_LN)
                    / evaluate::polynomial(z - 38.0, ERF_IMPL_LD),
                0.5640528202,
            )
        } else if z < 85.0 {
            (
                evaluate::polynomial(z - 60.0, ERF_IMPL_MN)
                    / evaluate::polynomial(z - 60.0, ERF_IMPL_MD),
                0.5641309023,
            )
        } else {
            (
                evaluate::polynomial(z - 85.0, ERF_IMPL_NN)
                    / evaluate::polynomial(z - 85.0, ERF_IMPL_ND),
                0.5641584396,
            )
        };
        let g = (-z * z).exp() / z;
        g * b + g * r
    } else {
        0.0
    };

    if inv && z >= 0.5 {
        result
    } else if z >= 0.5 || inv {
        1.0 - result
    } else {
        result
    }
}

// `erf_inv_impl` computes the inverse error function where
// `p`,`q`, and `s` are the first, second, and third intermediate
// parameters respectively
fn erf_inv_impl(p: f64, q: f64, s: f64) -> f64 {
    let result = if p <= 0.5 {
        let y = 0.0891314744949340820313;
        let g = p * (p + 10.0);
        let r = evaluate::polynomial(p, ERF_INV_IMPL_AN) / evaluate::polynomial(p, ERF_INV_IMPL_AD);
        g * y + g * r
    } else if q >= 0.25 {
        let y = 2.249481201171875;
        let g = (-2.0 * q.ln()).sqrt();
        let xs = q - 0.25;
        let r =
            evaluate::polynomial(xs, ERF_INV_IMPL_BN) / evaluate::polynomial(xs, ERF_INV_IMPL_BD);
        g / (y + r)
    } else {
        let x = (-q.ln()).sqrt();
        if x < 3.0 {
            let y = 0.807220458984375;
            let xs = x - 1.125;
            let r = evaluate::polynomial(xs, ERF_INV_IMPL_CN)
                / evaluate::polynomial(xs, ERF_INV_IMPL_CD);
            y * x + r * x
        } else if x < 6.0 {
            let y = 0.93995571136474609375;
            let xs = x - 3.0;
            let r = evaluate::polynomial(xs, ERF_INV_IMPL_DN)
                / evaluate::polynomial(xs, ERF_INV_IMPL_DD);
            y * x + r * x
        } else if x < 18.0 {
            let y = 0.98362827301025390625;
            let xs = x - 6.0;
            let r = evaluate::polynomial(xs, ERF_INV_IMPL_EN)
                / evaluate::polynomial(xs, ERF_INV_IMPL_ED);
            y * x + r * x
        } else if x < 44.0 {
            let y = 0.99714565277099609375;
            let xs = x - 18.0;
            let r = evaluate::polynomial(xs, ERF_INV_IMPL_FN)
                / evaluate::polynomial(xs, ERF_INV_IMPL_FD);
            y * x + r * x
        } else {
            let y = 0.99941349029541015625;
            let xs = x - 44.0;
            let r = evaluate::polynomial(xs, ERF_INV_IMPL_GN)
                / evaluate::polynomial(xs, ERF_INV_IMPL_GD);
            y * x + r * x
        }
    };
    s * result
}

@ -1,7 +1,4 @@
pub trait VecOps: num_traits::NumAssign + Copy {
    fn min(self, rhs: Self) -> Self;
    fn max(self, rhs: Self) -> Self;

    /// Dot-product of two vectors.
    ///
    /// # Safety
@ -29,47 +26,9 @@ pub trait VecOps: num_traits::NumAssign + Copy {
            *res += *xs.add(i)
        }
    }

    /// Maximum element in a non-empty vector.
    ///
    /// # Safety
    ///
    /// The length of `xs` must be at least `len` and positive. `res` has to point to a valid
    /// element.
    #[inline(always)]
    unsafe fn vec_reduce_max(xs: *const Self, res: *mut Self, len: usize) {
        *res = *xs;
        for i in 1..len {
            *res = (*res).max(*xs.add(i))
        }
    }

    /// Minimum element in a non-empty vector.
    ///
    /// # Safety
    ///
    /// The length of `xs` must be at least `len` and positive. `res` has to point to a valid
    /// element.
    #[inline(always)]
    unsafe fn vec_reduce_min(xs: *const Self, res: *mut Self, len: usize) {
        *res = *xs;
        for i in 1..len {
            *res = (*res).min(*xs.add(i))
        }
    }
}
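// Editor's sketch (not part of the diff): calling the raw-pointer reductions
// above from safe code; `max_of` is hypothetical and only exists for this
// illustration.
fn max_of(xs: &[f32]) -> f32 {
    assert!(!xs.is_empty());
    let mut out = 0f32;
    // SAFETY: `xs` is non-empty and `out` points at a valid f32; the helper
    // overwrites `out` with `xs[0]` before reducing.
    unsafe { <f32 as VecOps>::vec_reduce_max(xs.as_ptr(), &mut out, xs.len()) };
    out
}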
impl VecOps for f32 {
    #[inline(always)]
    fn min(self, other: Self) -> Self {
        Self::min(self, other)
    }

    #[inline(always)]
    fn max(self, other: Self) -> Self {
        Self::max(self, other)
    }

    #[inline(always)]
    unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
        super::vec_dot_f32(lhs, rhs, res, len)
@ -82,16 +41,6 @@ impl VecOps for f32 {
    }
}

impl VecOps for half::f16 {
    #[inline(always)]
    fn min(self, other: Self) -> Self {
        Self::min(self, other)
    }

    #[inline(always)]
    fn max(self, other: Self) -> Self {
        Self::max(self, other)
    }

    #[inline(always)]
    unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
        let mut res_f32 = 0f32;
@ -100,61 +49,10 @@ impl VecOps for half::f16 {
    }
}

impl VecOps for f64 {
    #[inline(always)]
    fn min(self, other: Self) -> Self {
        Self::min(self, other)
    }

    #[inline(always)]
    fn max(self, other: Self) -> Self {
        Self::max(self, other)
    }
}
impl VecOps for half::bf16 {
    #[inline(always)]
    fn min(self, other: Self) -> Self {
        Self::min(self, other)
    }

    #[inline(always)]
    fn max(self, other: Self) -> Self {
        Self::max(self, other)
    }
}
impl VecOps for u8 {
    #[inline(always)]
    fn min(self, other: Self) -> Self {
        <Self as Ord>::min(self, other)
    }

    #[inline(always)]
    fn max(self, other: Self) -> Self {
        <Self as Ord>::max(self, other)
    }
}
impl VecOps for u32 {
    #[inline(always)]
    fn min(self, other: Self) -> Self {
        <Self as Ord>::min(self, other)
    }

    #[inline(always)]
    fn max(self, other: Self) -> Self {
        <Self as Ord>::max(self, other)
    }
}
impl VecOps for i64 {
    #[inline(always)]
    fn min(self, other: Self) -> Self {
        <Self as Ord>::min(self, other)
    }

    #[inline(always)]
    fn max(self, other: Self) -> Self {
        <Self as Ord>::max(self, other)
    }
}
impl VecOps for f64 {}
impl VecOps for half::bf16 {}
impl VecOps for u8 {}
impl VecOps for u32 {}

#[inline(always)]
pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {

@ -1,4 +1,3 @@
pub mod erf;
pub mod kernels;

trait Cpu<const ARR: usize> {

File diff suppressed because it is too large
@ -1,7 +1,7 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType};
pub use candle_kernels as kernels;
use candle_kernels as kernels;
pub use cudarc;
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
use cudarc::driver::{
@ -139,14 +139,6 @@ impl CudaDevice {
    unsafe { func.launch(cfg, params) }.w()?;
    CudaStorageSlice::U32(data)
}
DType::I64 => {
    // SAFETY: Set later by running the fill kernel.
    let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
    let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
    let params = (&data, v as i64, elem_count);
    unsafe { func.launch(cfg, params) }.w()?;
    CudaStorageSlice::I64(data)
}
DType::BF16 => {
    // SAFETY: Set later by running the fill kernel.
    let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
@ -223,14 +215,6 @@ impl BackendDevice for CudaDevice {
    })
}

fn set_seed(&self, seed: u64) -> Result<()> {
    // We do not call set_seed but instead create a new curand object. This ensures that the
    // state will be identical and the same random numbers will be generated.
    let mut curand = self.curand.lock().unwrap();
    curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
    Ok(())
}

fn location(&self) -> crate::DeviceLocation {
    crate::DeviceLocation::Cuda {
        gpu_id: self.device.ordinal(),
@ -252,10 +236,6 @@ impl BackendDevice for CudaDevice {
    let data = self.alloc_zeros::<u32>(elem_count).w()?;
    CudaStorageSlice::U32(data)
}
DType::I64 => {
    let data = self.alloc_zeros::<i64>(elem_count).w()?;
    CudaStorageSlice::I64(data)
}
DType::BF16 => {
    let data = self.alloc_zeros::<bf16>(elem_count).w()?;
    CudaStorageSlice::BF16(data)
@ -285,13 +265,11 @@ impl BackendDevice for CudaDevice {
let slice = match dtype {
    // TODO: Add support for F16 and BF16 though this is likely to require some upstream
    // cudarc changes.
    DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
        Err(CudaError::UnsupportedDtype {
            dtype,
            op: "rand_uniform",
        })
        .w()?
    }
    DType::U8 | DType::U32 | DType::F16 | DType::BF16 => Err(CudaError::UnsupportedDtype {
        dtype,
        op: "rand_uniform",
    })
    .w()?,
    DType::F32 => {
        let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
        curand.0.fill_with_uniform(&mut data).w()?;
@ -303,12 +281,10 @@ impl BackendDevice for CudaDevice {
        CudaStorageSlice::F64(data)
    }
};
let slice = if lo == 0. && up == 1.0 {
    slice
} else {
if lo != 0.0 || up != 1.0 {
    let layout = Layout::contiguous(shape);
    Affine(up - lo, lo).map(&slice, self, &layout)?
};
    Affine(up - lo, lo).map(&slice, self, &layout)?;
}
Ok(CudaStorage {
    slice,
    device: self.clone(),
@ -320,23 +296,14 @@ impl BackendDevice for CudaDevice {
// cudarc changes.
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
// curand can only generate an odd number of values.
// https://github.com/huggingface/candle/issues/734
let elem_count_round = if elem_count % 2 == 1 {
    elem_count + 1
} else {
    elem_count
};
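// Editor's note (not part of the diff): per the issue linked above, curand's
// normal generator cannot fill an odd number of elements, so an odd request
// is rounded up to the next even count and the surplus value is ignored; the
// helper below is hypothetical.
fn round_up_to_even(n: usize) -> usize {
    if n % 2 == 1 { n + 1 } else { n }
}
// round_up_to_even(5) == 6, round_up_to_even(8) == 8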
let slice = match dtype {
    DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
        Err(CudaError::UnsupportedDtype {
            dtype,
            op: "rand_normal",
        })
        .w()?
    }
    DType::U8 | DType::U32 | DType::F16 | DType::BF16 => Err(CudaError::UnsupportedDtype {
        dtype,
        op: "rand_normal",
    })
    .w()?,
    DType::F32 => {
        let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
        let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
        curand
            .0
            .fill_with_normal(&mut data, mean as f32, std as f32)
@ -344,7 +311,7 @@ impl BackendDevice for CudaDevice {
        CudaStorageSlice::F32(data)
    }
    DType::F64 => {
        let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
        let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
        curand.0.fill_with_normal(&mut data, mean, std).w()?;
        CudaStorageSlice::F64(data)
    }
@ -369,10 +336,6 @@ impl BackendDevice for CudaDevice {
    let data = self.htod_sync_copy(storage).w()?;
    CudaStorageSlice::U32(data)
}
CpuStorage::I64(storage) => {
    let data = self.htod_sync_copy(storage).w()?;
    CudaStorageSlice::I64(data)
}
CpuStorage::BF16(storage) => {
    let data = self.htod_sync_copy(storage).w()?;
    CudaStorageSlice::BF16(data)
@ -398,10 +361,9 @@ impl BackendDevice for CudaDevice {
}

#[derive(Debug)]
pub enum CudaStorageSlice {
enum CudaStorageSlice {
    U8(CudaSlice<u8>),
    U32(CudaSlice<u32>),
    I64(CudaSlice<i64>),
    BF16(CudaSlice<bf16>),
    F16(CudaSlice<f16>),
    F32(CudaSlice<f32>),
@ -409,7 +371,7 @@ pub enum CudaStorageSlice {
}
type S = CudaStorageSlice;

pub trait Map1 {
trait Map1 {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        src: &CudaSlice<T>,
@ -421,7 +383,6 @@ pub trait Map1 {
    let out = match s {
        S::U8(s) => S::U8(self.f(s, d, l)?),
        S::U32(s) => S::U32(self.f(s, d, l)?),
        S::I64(s) => S::I64(self.f(s, d, l)?),
        S::BF16(s) => S::BF16(self.f(s, d, l)?),
        S::F16(s) => S::F16(self.f(s, d, l)?),
        S::F32(s) => S::F32(self.f(s, d, l)?),
@ -431,7 +392,7 @@ pub trait Map1 {
    }
}

pub trait Map2 {
trait Map2 {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        src1: &CudaSlice<T>,
@ -445,7 +406,6 @@ pub trait Map2 {
    let out = match (s1, s2) {
        (S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?),
        (S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
        (S::I64(s1), S::I64(s2)) => S::I64(self.f(s1, l1, s2, l2, d)?),
        (S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
        (S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
        (S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
@ -456,7 +416,7 @@ pub trait Map2 {
    }
}

pub trait Map2InPlace {
trait Map2InPlace {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        dst: &mut CudaSlice<T>,
@ -477,7 +437,6 @@ pub trait Map2InPlace {
    match (dst, src) {
        (S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
        (S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
        (S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d),
        (S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
        (S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
        (S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
@ -487,7 +446,7 @@ pub trait Map2InPlace {
    }
}

pub trait Map1Any {
trait Map1Any {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
        &self,
        src: &CudaSlice<T>,
@ -500,7 +459,6 @@ pub trait Map1Any {
    let out = match s {
        S::U8(s) => self.f(s, d, l, S::U8)?,
        S::U32(s) => self.f(s, d, l, S::U32)?,
        S::I64(s) => self.f(s, d, l, S::I64)?,
        S::BF16(s) => self.f(s, d, l, S::BF16)?,
        S::F16(s) => self.f(s, d, l, S::F16)?,
        S::F32(s) => self.f(s, d, l, S::F32)?,
@ -510,7 +468,7 @@ pub trait Map1Any {
    }
}

pub trait Map2Any {
trait Map2Any {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        src1: &CudaSlice<T>,
@ -524,7 +482,6 @@ pub trait Map2Any {
    let out = match (s1, s2) {
        (S::U8(s1), S::U8(s2)) => self.f(s1, l1, s2, l2, d)?,
        (S::U32(s1), S::U32(s2)) => self.f(s1, l1, s2, l2, d)?,
        (S::I64(s1), S::I64(s2)) => self.f(s1, l1, s2, l2, d)?,
        (S::BF16(s1), S::BF16(s2)) => self.f(s1, l1, s2, l2, d)?,
        (S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?,
        (S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?,
@ -547,7 +504,7 @@ impl Map1 for Clone {
    }
}

pub fn kernel_name<T: WithDType>(root: &str) -> String {
fn kernel_name<T: WithDType>(root: &str) -> String {
    let dtype = T::DTYPE.as_str();
    format!("{root}_{dtype}")
}
@ -608,129 +565,6 @@ impl Map1 for Elu {
    }
}

struct Im2Col1D {
    l_k: usize,
    stride: usize,
    dilation: usize,
    padding: usize,
}

impl Im2Col1D {
    fn l_out(&self, l: usize) -> usize {
        (l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
    }
}
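// Editor's sketch (not part of the diff): the size of the im2col1d scratch
// buffer computed below; every output position receives its own copy of the
// c * l_k window, so with dims = [b, c, l] the destination holds
// b * l_out * c * l_k elements. The helper is hypothetical.
fn im2col1d_dst_el(b: usize, c: usize, l: usize, l_k: usize, stride: usize, padding: usize, dilation: usize) -> usize {
    let l_out = (l + 2 * padding - dilation * (l_k - 1) - 1) / stride + 1;
    b * l_out * c * l_k
}
// im2col1d_dst_el(1, 4, 10, 3, 1, 0, 1) == 1 * 8 * 4 * 3 == 96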
|
||||
|
||||
impl Map1 for Im2Col1D {
fn f<T: DeviceRepr + WithDType>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>> {
let shape = layout.shape();
let dims = shape.dims();
let l_out = self.l_out(dims[2]);
let dst_el = dims[0] * l_out * dims[1] * self.l_k;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("im2col1d"), kernels::CONV)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let params = (
dst_el,
l_out,
self.l_k,
self.stride,
self.padding,
self.dilation,
&ds,
src,
&dst,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)
}
}

struct Im2Col {
h_k: usize,
w_k: usize,
stride: usize,
dilation: usize,
padding: usize,
}

impl Im2Col {
fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
(h_out, w_out)
}
}

impl Map1 for Im2Col {
fn f<T: DeviceRepr + WithDType>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>> {
let shape = layout.shape();
let dims = shape.dims();
let (h_out, w_out) = self.hw_out(dims[2], dims[3]);
let dst_el = dims[0] * h_out * w_out * dims[1] * self.h_k * self.w_k;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("im2col"), kernels::CONV)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let params = (
dst_el,
h_out,
w_out,
self.h_k,
self.w_k,
self.stride,
self.padding,
self.dilation,
&ds,
src,
&dst,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)
}
}

struct Powf(f64);
impl Map1 for Powf {
fn f<T: DeviceRepr + WithDType>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>> {
let shape = layout.shape();
let dims = shape.dims();
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("upowf"), kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (el, dims.len(), &ds, T::from_f64(self.0), src, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
}
}

struct Sum<'a>(&'a [usize]);
impl<'a> Map1 for Sum<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@@ -880,9 +714,6 @@ impl<'a> Map1 for IndexSelect<'a> {
CudaStorageSlice::U8(slice) => {
("is_u8", *slice.slice(ids_l.start_offset()..).device_ptr())
}
CudaStorageSlice::I64(slice) => {
("is_i64", *slice.slice(ids_l.start_offset()..).device_ptr())
}
_ => Err(CudaError::UnexpectedDType {
msg: "index_select ids should be u8 or u32",
expected: DType::U32,
@@ -892,6 +723,8 @@ impl<'a> Map1 for IndexSelect<'a> {
};
let ids_shape = ids_l.shape();
let ids_dims = ids_shape.dims();
let ids_el = ids_shape.elem_count();
let cfg = LaunchConfig::for_num_elems(ids_el as u32);
let ds = dev.htod_copy([ids_dims, ids_l.stride()].concat()).w()?;
let src = match src_l.contiguous_offsets() {
Some((o1, o2)) => src.slice(o1..o2),
@@ -899,23 +732,19 @@ impl<'a> Map1 for IndexSelect<'a> {
};
let left_size: usize = src_l.dims()[..self.2].iter().product();
let right_size: usize = src_l.dims()[self.2 + 1..].iter().product();
let src_dim_size = src_l.dims()[self.2];
let ids_dim_size = ids_shape.elem_count();
let dst_el = ids_shape.elem_count() * left_size * right_size;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let dim_size = src_l.dims()[self.2];
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let out = unsafe { dev.alloc::<T>(ids_el * left_size * right_size) }.w()?;
let params = (
dst_el,
ids_el,
ids_dims.len(),
&ds,
ids,
&src,
&out,
left_size,
src_dim_size,
ids_dim_size,
dim_size,
right_size,
);
// SAFETY: ffi.
@@ -944,11 +773,8 @@ impl<'a> Map1 for Gather<'a> {
("gather_u32", *slice.slice(ids_o1..ids_o2).device_ptr())
}
CudaStorageSlice::U8(slice) => ("gather_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::I64(slice) => {
("gather_i64", *slice.slice(ids_o1..ids_o2).device_ptr())
}
_ => Err(CudaError::UnexpectedDType {
msg: "gather ids should be u8/u32/i64",
msg: "gather ids should be u8 or u32",
expected: DType::U32,
got: ids.dtype(),
})?,
@@ -994,10 +820,9 @@ impl<'a> Map2InPlace for IndexAdd<'a> {
};
let (name, ids) = match &ids.slice {
CudaStorageSlice::U32(slice) => ("ia_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::I64(slice) => ("ia_i64", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::U8(slice) => ("ia_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
_ => Err(CudaError::UnexpectedDType {
msg: "index-add ids should be u8/u32/i64",
msg: "index-add ids should be u8 or u32",
expected: DType::U32,
got: ids.dtype(),
})?,
@@ -1042,10 +867,9 @@ impl<'a> Map2InPlace for ScatterAdd<'a> {
};
let (name, ids) = match &ids.slice {
CudaStorageSlice::U32(slice) => ("sa_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::I64(slice) => ("sa_i64", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::U8(slice) => ("sa_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
_ => Err(CudaError::UnexpectedDType {
msg: "scatter-add ids should be u8/u32/i64",
msg: "scatter-add ids should be u8 or u32",
expected: DType::U32,
got: ids.dtype(),
})?,
@@ -1097,12 +921,10 @@ impl<'a> Map2 for Conv1D<'a> {
} else if dims.len() == 2 {
[&[1], dims, &[1], inp_l.stride(), k_l.dims(), k_l.stride()].concat()
} else {
crate::bail!("unexpected input shape for conv1d {dims:?}")
panic!("unexpected input shape for conv1d {dims:?}")
};
let ds = dev.htod_copy(ds).w()?;
let params = (
el, l_out, p.stride, p.padding, p.dilation, &ds, inp, k, &out,
);
let params = (el, l_out, p.stride, p.padding, &ds, inp, k, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
@@ -1119,8 +941,8 @@ impl<'a> Map2 for Conv2D<'a> {
k_l: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>> {
// Kernel shape: (c_out, c_in_k, h_k, w_k)
// Input shape: (b_size, c_in, h_in, w_in)
// Kernel shape: (c_out, c_in_k, w_k, h_k)
// Input shape: (b_size, c_in, w_in, c_in)
let p = &self.0;
let (out_w, out_h) = (p.out_w(), p.out_h());
let dst_el = p.c_out * out_w * out_h * p.b_size;
@@ -1137,111 +959,10 @@ impl<'a> Map2 for Conv2D<'a> {
let ds = if dims.len() == 4 {
[dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
} else {
crate::bail!("unexpected input shape for conv2d {dims:?}")
panic!("unexpected input shape for conv1d {dims:?}")
};
let ds = dev.htod_copy(ds).w()?;
let params = (
el, out_w, out_h, p.stride, p.padding, p.dilation, &ds, inp, k, &out,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
}
}

struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
impl<'a> Map2 for ConvTranspose1D<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
inp: &CudaSlice<T>,
inp_l: &Layout,
k: &CudaSlice<T>,
k_l: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>> {
// Kernel shape: (c_in_k, c_out, l_k)
// Input shape: (b_size, c_in, l_in)
let p = &self.0;
let l_out = p.l_out();
let dst_el = p.c_out * l_out * p.b_size;
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(k_l.start_offset()..);
let shape = inp_l.shape();
let dims = shape.dims();
let el = shape.elem_count();

// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose1d"), kernels::CONV)?;
let ds = if dims.len() == 3 {
[dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
} else {
crate::bail!("unexpected input shape for conv_transpose1d {dims:?}")
};
let ds = dev.htod_copy(ds).w()?;
let params = (
el,
l_out,
p.stride,
p.padding,
p.output_padding,
p.dilation,
&ds,
inp,
k,
&out,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
}
}

struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
impl<'a> Map2 for ConvTranspose2D<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
inp: &CudaSlice<T>,
inp_l: &Layout,
k: &CudaSlice<T>,
k_l: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>> {
// Kernel shape: (c_in_k, c_out, h_k, w_k)
// Input shape: (b_size, c_in, h_in, w_in)
let p = &self.0;
let (out_w, out_h) = (p.out_w(), p.out_h());
let dst_el = p.c_out * out_w * out_h * p.b_size;
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(k_l.start_offset()..);
let shape = inp_l.shape();
let dims = shape.dims();
let el = shape.elem_count();

// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose2d"), kernels::CONV)?;
let ds = if dims.len() == 4 {
[dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
} else {
crate::bail!("unexpected input shape for conv_transpose2d {dims:?}")
};
let ds = dev.htod_copy(ds).w()?;
let params = (
el,
out_w,
out_h,
p.stride,
p.padding,
p.output_padding,
p.dilation,
&ds,
inp,
k,
&out,
);
let params = (el, out_w, out_h, p.stride, p.padding, &ds, inp, k, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
@@ -1275,7 +996,7 @@ impl Map1 for Pool2D {
let ds = if dims.len() == 4 {
[dims, inp_l.stride()].concat()
} else {
crate::bail!("unexpected input shape for pool {dims:?}")
panic!("unexpected input shape for conv1d {dims:?}")
};
let el = shape.elem_count();
let out_w = (dims[2] - self.w_k) / self.w_stride + 1;
@@ -1321,7 +1042,7 @@ impl Map1 for UpsampleNearest2D {
let ds = if dims.len() == 4 {
[dims, inp_l.stride()].concat()
} else {
crate::bail!("unexpected input shape for upsample {dims:?}")
panic!("unexpected input shape for conv1d {dims:?}")
};
let (out_w, out_h) = (self.0, self.1);
let dst_el = out_w * out_h * dims[0] * dims[1];
@@ -1359,12 +1080,8 @@ impl<'a> Map2 for WhereCond<'a> {
let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
(ptr, "where_u32")
}
CudaStorageSlice::I64(slice) => {
let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
(ptr, "where_i64")
}
_ => Err(CudaError::UnexpectedDType {
msg: "where conditions should be u8/u32/i64",
msg: "where conditions should be u8 or u32",
expected: DType::U32,
got: self.0.dtype(),
})
@@ -1475,8 +1192,8 @@ fn slice_src_and_dst<'a, T>(

#[derive(Debug)]
pub struct CudaStorage {
pub slice: CudaStorageSlice,
pub device: CudaDevice,
slice: CudaStorageSlice,
device: CudaDevice,
}

pub trait CudaDType: Sized {
@@ -1508,7 +1225,6 @@ macro_rules! cuda_dtype {
}
cuda_dtype!(u8, U8);
cuda_dtype!(u32, U32);
cuda_dtype!(i64, I64);
cuda_dtype!(f16, F16);
cuda_dtype!(bf16, BF16);
cuda_dtype!(f32, F32);
@@ -1622,7 +1338,6 @@ impl BackendStorage for CudaStorage {
match self.slice {
CudaStorageSlice::U8(_) => DType::U8,
CudaStorageSlice::U32(_) => DType::U32,
CudaStorageSlice::I64(_) => DType::I64,
CudaStorageSlice::BF16(_) => DType::BF16,
CudaStorageSlice::F16(_) => DType::F16,
CudaStorageSlice::F32(_) => DType::F32,
@@ -1648,7 +1363,6 @@ impl BackendStorage for CudaStorage {
let inp = match &self.slice {
CudaStorageSlice::U8(inp) => *inp.slice(start_o..).device_ptr(),
CudaStorageSlice::U32(inp) => *inp.slice(start_o..).device_ptr(),
CudaStorageSlice::I64(inp) => *inp.slice(start_o..).device_ptr(),
CudaStorageSlice::BF16(inp) => *inp.slice(start_o..).device_ptr(),
CudaStorageSlice::F16(inp) => *inp.slice(start_o..).device_ptr(),
CudaStorageSlice::F32(inp) => *inp.slice(start_o..).device_ptr(),
@@ -1671,12 +1385,6 @@ impl BackendStorage for CudaStorage {
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::U32(out)
}
DType::I64 => {
let out = unsafe { dev.alloc::<i64>(el) }.w()?;
let params = (el, dims.len(), &ds, *inp, &out);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::I64(out)
}
DType::BF16 => {
let out = unsafe { dev.alloc::<bf16>(el) }.w()?;
let params = (el, dims.len(), &ds, *inp, &out);
@@ -1714,12 +1422,6 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device })
}

fn powf(&self, layout: &Layout, e: f64) -> Result<Self> {
let device = self.device().clone();
let slice = Powf(e).map(&self.slice, &device, layout)?;
Ok(Self { slice, device })
}

fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
let device = self.device().clone();
let slice = Elu(alpha).map(&self.slice, &device, layout)?;
@@ -1767,11 +1469,6 @@ impl BackendStorage for CudaStorage {
let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
Ok(CpuStorage::U32(cpu_storage))
}
CudaStorageSlice::I64(slice) => {
let dev = slice.device();
let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
Ok(CpuStorage::I64(cpu_storage))
}
CudaStorageSlice::BF16(slice) => {
let dev = slice.device();
let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
@@ -1815,58 +1512,8 @@ impl BackendStorage for CudaStorage {
kernel_l: &Layout,
params: &crate::conv::ParamsConv1D,
) -> Result<Self> {
const USE_IM2COL_CONV1D: bool = true;

let device = self.device().clone();
if !USE_IM2COL_CONV1D {
let slice = Conv1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
return Ok(Self { slice, device });
}

let col = Im2Col1D {
l_k: params.k_size,
stride: params.stride,
dilation: params.dilation,
padding: params.padding,
}
.map(&self.slice, &device, l)?;
let col = Self { slice: col, device };
let l_out = params.l_out();
let b = params.b_size;
let n = params.c_out;
let k = params.k_size * params.c_in;
let m = l_out;
let col_l = Layout::contiguous((b, m, k));
let res = if kernel_l.is_contiguous() {
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
};
let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t)
}
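The conv1d path above reduces the convolution to im2col followed by a batched matmul; a rough sketch of the shape bookkeeping, using hypothetical sizes that are not taken from the diff:

// Hypothetical sizes: input (b, c_in, l_in) = (2, 4, 10),
// kernel (c_out, c_in, k_size) = (8, 4, 3), stride 1, no padding or dilation.
let p = Im2Col1D { l_k: 3, stride: 1, dilation: 1, padding: 0 };
let m = p.l_out(10);           // l_out = 8, the rows of the col matrix
let (b, k, n) = (2, 4 * 3, 8); // k = c_in * k_size, n = c_out
// col: (b, m, k) = (2, 8, 12); the kernel is viewed as (1, n, k), transposed
// and broadcast to (b, k, n) = (2, 12, 8); matmul -> (b, m, n) = (2, 8, 8),
// i.e. (b, l_out, c_out); transpose(1, 2) plus a strided copy gives NCL.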
fn conv_transpose1d(
&self,
l: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
let device = self.device().clone();
let slice =
ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
let slice = Conv1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
Ok(Self { slice, device })
}

@@ -1878,50 +1525,9 @@ impl BackendStorage for CudaStorage {
kernel_l: &Layout,
params: &crate::conv::ParamsConv2D,
) -> Result<Self> {
const USE_IM2COL_CONV2D: bool = true;

let device = self.device().clone();
if !USE_IM2COL_CONV2D {
let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
return Ok(Self { slice, device });
}

let col = Im2Col {
h_k: params.k_h,
w_k: params.k_w,
stride: params.stride,
dilation: params.dilation,
padding: params.padding,
}
.map(&self.slice, &device, l)?;
let col = Self { slice: col, device };
let h_out = params.out_h();
let w_out = params.out_w();
let b = params.b_size;
let n = params.c_out;
let k = params.k_h * params.k_w * params.c_in;
let m = h_out * w_out;
let col_l = Layout::contiguous((b, m, k));
let res = if kernel_l.is_contiguous() {
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
};
let res_l = Layout::contiguous((b, h_out, w_out, n))
.transpose(1, 2)?
.transpose(1, 3)?;
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t)
let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
Ok(Self { slice, device })
}
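The two transposes above rearrange the matmul result from channels-last back to NCHW; a small check of the axis ordering using the crate's own Layout type (sizes hypothetical):

// (b, h_out, w_out, n) --transpose(1,2)--> (b, w_out, h_out, n)
//                      --transpose(1,3)--> (b, n, h_out, w_out), i.e. NCHW.
let (b, h_out, w_out, n) = (2, 5, 7, 8);
let l = Layout::contiguous((b, h_out, w_out, n)).transpose(1, 2)?.transpose(1, 3)?;
assert_eq!(l.dims(), &[b, n, h_out, w_out]);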
#[cfg(feature = "cudnn")]
@@ -1964,6 +1570,7 @@ impl BackendStorage for CudaStorage {
.map_err(crate::Error::wrap)?;
S::F16(out)
}

(S::F32(inp), S::F32(k)) => {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
@@ -1981,25 +1588,11 @@ impl BackendStorage for CudaStorage {
S::F64(out)
}
(S::U32(_), S::U32(_)) => Err(CudaError::InternalError("conv2d does not support u32"))?,
(S::I64(_), S::I64(_)) => Err(CudaError::InternalError("conv2d does not support i64"))?,
_ => Err(CudaError::InternalError("dtype mismatch in conv2d"))?,
};
Ok(Self { slice, device })
}

fn conv_transpose2d(
&self,
l: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &crate::conv::ParamsConvTranspose2D,
) -> Result<Self> {
let device = self.device().clone();
let slice =
ConvTranspose2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
Ok(Self { slice, device })
}

fn avg_pool2d(&self, l: &Layout, k: (usize, usize), stride: (usize, usize)) -> Result<Self> {
let device = self.device().clone();
let slice = Pool2D {
@@ -2026,10 +1619,6 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device })
}

fn upsample_nearest1d(&self, _: &Layout, _out_sz: usize) -> Result<Self> {
crate::bail!("upsample-nearest1d is not supported on cuda")
}

fn upsample_nearest2d(&self, l: &Layout, out_w: usize, out_h: usize) -> Result<Self> {
let device = self.device().clone();
let slice = UpsampleNearest2D(out_w, out_h).map(&self.slice, &device, l)?;
@@ -2149,9 +1738,6 @@ impl BackendStorage for CudaStorage {
let src_shape = src_l.shape();
let dims = src_shape.dims();
let el_count = src_shape.elem_count();
if el_count == 0 {
return Ok(());
}
let cfg = LaunchConfig::for_num_elems(el_count as u32);
let dev = &self.device;
let ds = dev.htod_copy([dims, src_l.stride()].concat()).w()?;
@@ -2216,24 +1802,12 @@ impl BackendStorage for CudaStorage {
unsafe { func.launch(cfg, params) }.w()?
}
}
(CudaStorageSlice::I64(src), CudaStorageSlice::I64(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
if src_l.is_contiguous() {
dev.dtod_copy(&src, &mut dst).w()?
} else {
let func = dev.get_or_load_func("ucopy_i64", kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let params = (el_count, dims.len(), &ds, &src, &mut dst);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?
}
}
(CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
if src_l.is_contiguous() {
dev.dtod_copy(&src, &mut dst).w()?
} else {
let func = dev.get_or_load_func("ucopy_f64", kernels::UNARY)?;
let func = dev.get_or_load_func("ucopy_64", kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let params = (el_count, dims.len(), &ds, &src, &mut dst);
// SAFETY: ffi.

@@ -34,9 +34,6 @@ pub(crate) fn launch_conv2d<
params: &crate::conv::ParamsConv2D,
dev: &crate::cuda_backend::CudaDevice,
) -> crate::Result<()> {
use crate::conv::CudnnFwdAlgo as CandleAlgo;
use cudarc::cudnn::sys::cudnnConvolutionFwdAlgo_t as A;

let device_id = dev.id();
let cudnn = CUDNN.with(|cudnn| {
if let Some(cudnn) = cudnn.borrow().get(&device_id) {
@@ -51,14 +48,14 @@ pub(crate) fn launch_conv2d<
let conv = cudnn.create_conv2d::<T>(
/* pad */ [params.padding as i32, params.padding as i32],
/* stride */ [params.stride as i32, params.stride as i32],
/* dilation */ [params.dilation as i32, params.dilation as i32],
/* dilation */ [1, 1],
cudarc::cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
)?;
let x_shape = [
params.b_size as i32,
params.c_in as i32,
params.i_h as i32,
params.i_w as i32,
params.i_h as i32,
];
// Note that `src` already starts at the proper offset.
let x = if src_l.is_contiguous() {
@@ -78,14 +75,14 @@ pub(crate) fn launch_conv2d<
[
params.c_out as i32,
params.c_in as i32,
params.k_h as i32,
params.k_w as i32,
params.k_h as i32,
],
)?;
let (w_out, h_out) = (params.out_w() as i32, params.out_h() as i32);
let y = cudnn.create_4d_tensor(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
[params.b_size as i32, params.c_out as i32, h_out, w_out],
[params.b_size as i32, params.c_out as i32, w_out, h_out],
)?;
let conv2d = Conv2dForward {
conv: &conv,
@@ -93,20 +90,7 @@ pub(crate) fn launch_conv2d<
w: &w,
y: &y,
};
let alg = match params.cudnn_fwd_algo {
None => conv2d.pick_algorithm()?,
Some(CandleAlgo::ImplicitGemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
Some(CandleAlgo::ImplicitPrecompGemm) => {
A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
}
Some(CandleAlgo::Gemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
Some(CandleAlgo::Direct) => A::CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
Some(CandleAlgo::Fft) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT,
Some(CandleAlgo::FftTiling) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
Some(CandleAlgo::Winograd) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
Some(CandleAlgo::WinogradNonFused) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
};
let alg = conv2d.pick_algorithm()?;
let workspace_size = conv2d.get_workspace_size(alg)?;
let mut workspace = dev.cuda_device().alloc_zeros::<u8>(workspace_size)?;
unsafe {
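Note that the two sides of the descriptor hunks disagree on the trailing entries (`i_w`/`k_w` versus a repeated `i_h`/`k_h`); with CUDNN_TENSOR_NCHW the four entries must be ordered [batch, channels, height, width]. A small illustration with hypothetical sizes:

// NCHW ordering with hypothetical sizes: 8 RGB images of 224x160 (h x w):
let x_shape = [8i32, 3, 224, 160]; // [n, c, h, w]
// and the filter descriptor follows the same trailing (h, w) order:
let w_shape = [64i32, 3, 3, 5];    // [c_out, c_in, k_h, k_w]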
@@ -8,16 +8,15 @@ use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
pub enum DeviceLocation {
Cpu,
Cuda { gpu_id: usize },
Metal { gpu_id: usize },
}

#[derive(Debug, Clone)]
pub enum Device {
Cpu,
Cuda(crate::CudaDevice),
Metal(crate::MetalDevice),
}

// TODO: Should we back the cpu implementation using the NdArray crate or similar?
pub trait NdArray {
fn shape(&self) -> Result<Shape>;

@@ -82,71 +81,15 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
}
}

impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4: usize> NdArray
for &[[[[S; N4]; N3]; N2]; N1]
{
fn shape(&self) -> Result<Shape> {
Ok(Shape::from((N1, N2, N3, N4)))
}

fn to_cpu_storage(&self) -> CpuStorage {
let mut vec = Vec::with_capacity(N1 * N2 * N3 * N4);
for i1 in 0..N1 {
for i2 in 0..N2 {
for i3 in 0..N3 {
vec.extend(self[i1][i2][i3])
}
}
}
S::to_cpu_storage_owned(vec)
}
}

impl<S: NdArray> NdArray for Vec<S> {
fn shape(&self) -> Result<Shape> {
if self.is_empty() {
crate::bail!("empty array")
}
let shape0 = self[0].shape()?;
let n = self.len();
for v in self.iter() {
let shape = v.shape()?;
if shape != shape0 {
crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
}
}
Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
}

fn to_cpu_storage(&self) -> CpuStorage {
// This allocates intermediary memory and shouldn't be necessary.
let storages = self.iter().map(|v| v.to_cpu_storage()).collect::<Vec<_>>();
CpuStorage::concat(storages.as_slice()).unwrap()
}
}
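The Vec impl prepends the outer length onto the element shape, which is how nested Rust values map onto tensor shapes; a hedged illustration (values hypothetical, assuming the usual `Tensor::new` entry point):

// Two rows of three f32s stack into a (2, 3) tensor; elements with
// mismatched shapes would bail in NdArray::shape.
let rows: Vec<Vec<f32>> = vec![vec![1., 2., 3.], vec![4., 5., 6.]];
let t = Tensor::new(rows, &Device::Cpu)?;
assert_eq!(t.dims(), &[2, 3]);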
impl Device {
pub fn new_cuda(ordinal: usize) -> Result<Self> {
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
}

pub fn new_metal(ordinal: usize) -> Result<Self> {
Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
}

pub fn set_seed(&self, seed: u64) -> Result<()> {
match self {
Self::Cpu => CpuDevice.set_seed(seed),
Self::Cuda(c) => c.set_seed(seed),
Self::Metal(m) => m.set_seed(seed),
}
}

pub fn same_device(&self, rhs: &Self) -> bool {
match (self, rhs) {
(Self::Cpu, Self::Cpu) => true,
(Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
(Self::Metal(lhs), Self::Metal(rhs)) => lhs.same_device(rhs),
_ => false,
}
}
@@ -155,20 +98,21 @@ impl Device {
match self {
Self::Cpu => DeviceLocation::Cpu,
Self::Cuda(device) => device.location(),
Device::Metal(device) => device.location(),
}
}

pub fn is_cpu(&self) -> bool {
matches!(self, Self::Cpu)
match self {
Self::Cpu => true,
Self::Cuda(_) => false,
}
}

pub fn is_cuda(&self) -> bool {
matches!(self, Self::Cuda(_))
}

pub fn is_metal(&self) -> bool {
matches!(self, Self::Metal(_))
match self {
Self::Cpu => false,
Self::Cuda(_) => true,
}
}

pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
@@ -192,18 +136,8 @@ impl Device {
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
// TODO: Remove the special case if we start supporting generating f16/bf16 directly.
if dtype == DType::F16 || dtype == DType::BF16 {
let storage = device.rand_uniform(shape, DType::F32, lo, up)?;
Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
} else {
let storage = device.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Cuda(storage))
}
}
Device::Metal(device) => {
let storage = device.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Metal(storage))
Ok(Storage::Cuda(storage))
}
}
}
@@ -230,18 +164,8 @@ impl Device {
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
// TODO: Remove the special case if we start supporting generating f16/bf16 directly.
if dtype == DType::F16 || dtype == DType::BF16 {
let storage = device.rand_normal(shape, DType::F32, mean, std)?;
Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
} else {
let storage = device.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Cuda(storage))
}
}
Device::Metal(device) => {
let storage = device.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Metal(storage))
Ok(Storage::Cuda(storage))
}
}
}
@@ -265,10 +189,6 @@ impl Device {
let storage = device.ones_impl(shape, dtype)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = device.ones_impl(shape, dtype)?;
Ok(Storage::Metal(storage))
}
}
}

@@ -282,10 +202,6 @@ impl Device {
let storage = device.zeros_impl(shape, dtype)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = device.zeros_impl(shape, dtype)?;
Ok(Storage::Metal(storage))
}
}
}

@@ -297,11 +213,6 @@ impl Device {
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = array.to_cpu_storage();
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Metal(storage))
}
}
}

@@ -313,11 +224,6 @@ impl Device {
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = S::to_cpu_storage_owned(data);
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Metal(storage))
}
}
}
}

@@ -9,17 +9,11 @@ impl Tensor {
&self,
f: &mut std::fmt::Formatter,
) -> std::fmt::Result {
let device_str = match self.device().location() {
crate::DeviceLocation::Cpu => "".to_owned(),
crate::DeviceLocation::Cuda { gpu_id } => {
format!(", cuda:{}", gpu_id)
}
crate::DeviceLocation::Metal { gpu_id } => {
format!(", metal:{}", gpu_id)
}
let prefix = match self.device() {
crate::Device::Cpu => "Cpu",
crate::Device::Cuda(_) => "Cuda",
};

write!(f, "Tensor[")?;
write!(f, "{prefix}Tensor[")?;
match self.dims() {
[] => {
if let Ok(v) = self.to_scalar::<T>() {
@@ -46,7 +40,7 @@ impl Tensor {
}
}
}
write!(f, "; {}{}]", self.dtype().as_str(), device_str)
write!(f, "; {}]", self.dtype().as_str())
}
}

@@ -55,7 +49,6 @@ impl std::fmt::Debug for Tensor {
match self.dtype() {
DType::U8 => self.fmt_dt::<u8>(f),
DType::U32 => self.fmt_dt::<u32>(f),
DType::I64 => self.fmt_dt::<i64>(f),
DType::BF16 => self.fmt_dt::<bf16>(f),
DType::F16 => self.fmt_dt::<f16>(f),
DType::F32 => self.fmt_dt::<f32>(f),
@@ -438,12 +431,6 @@ impl std::fmt::Display for Tensor {
tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
writeln!(f)?;
}
DType::I64 => {
let tf: IntFormatter<i64> = IntFormatter::new();
let max_w = tf.max_width(&to_display);
tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
writeln!(f)?;
}
DType::BF16 => {
if let Ok(tf) = FloatFormatter::<bf16>::new(&to_display, &po) {
let max_w = tf.max_width(&to_display);
@@ -473,23 +460,6 @@ impl std::fmt::Display for Tensor {
}
}
};

let device_str = match self.device().location() {
crate::DeviceLocation::Cpu => "".to_owned(),
crate::DeviceLocation::Cuda { gpu_id } => {
format!(", cuda:{}", gpu_id)
}
crate::DeviceLocation::Metal { gpu_id } => {
format!(", metal:{}", gpu_id)
}
};

write!(
f,
"Tensor[{:?}, {}{}]",
self.dims(),
self.dtype().as_str(),
device_str
)
write!(f, "Tensor[{:?}, {}]", self.dims(), self.dtype().as_str())
}
}

@@ -1,24 +1,13 @@
//! Types for elements that can be stored and manipulated using tensors.
#![allow(clippy::redundant_closure_call)]
use crate::backend::BackendStorage;
use crate::{CpuStorage, Error, Result};

/// The different types of elements allowed in tensors.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DType {
// Unsigned 8 bits integer.
U8,
// Unsigned 32 bits integer.
U32,
// Signed 64 bits integer.
I64,
// Brain floating-point using half precision (16 bits).
BF16,
// Floating-point using half precision (16 bits).
F16,
// Floating-point using single precision (32 bits).
F32,
// Floating-point using double precision (64 bits).
F64,
}

@@ -31,7 +20,6 @@ impl std::str::FromStr for DType {
match s {
"u8" => Ok(Self::U8),
"u32" => Ok(Self::U32),
"i64" => Ok(Self::I64),
"bf16" => Ok(Self::BF16),
"f16" => Ok(Self::F16),
"f32" => Ok(Self::F32),
@@ -42,12 +30,10 @@ impl std::str::FromStr for DType {
}

impl DType {
/// String representation for dtypes.
pub fn as_str(&self) -> &'static str {
match self {
Self::U8 => "u8",
Self::U32 => "u32",
Self::I64 => "i64",
Self::BF16 => "bf16",
Self::F16 => "f16",
Self::F32 => "f32",
@@ -55,32 +41,16 @@ impl DType {
}
}

/// The size used by each element in bytes, i.e. 1 for `U8`, 4 for `F32`.
pub fn size_in_bytes(&self) -> usize {
match self {
Self::U8 => 1,
Self::U32 => 4,
Self::I64 => 8,
Self::BF16 => 2,
Self::F16 => 2,
Self::F32 => 4,
Self::F64 => 8,
}
}

pub fn is_int(&self) -> bool {
match self {
Self::U8 | Self::U32 | Self::I64 => true,
Self::BF16 | Self::F16 | Self::F32 | Self::F64 => false,
}
}

pub fn is_float(&self) -> bool {
match self {
Self::U8 | Self::U32 | Self::I64 => false,
Self::BF16 | Self::F16 | Self::F32 | Self::F64 => true,
}
}
}
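A quick illustration of the dtype helpers above, with behavior read directly from the match arms:

use std::str::FromStr;
// Round-trip between the string form and the enum, plus the size table:
let dt = DType::from_str("bf16").unwrap();
assert_eq!(dt, DType::BF16);
assert_eq!(dt.as_str(), "bf16");
assert_eq!(dt.size_in_bytes(), 2);
assert!(dt.is_float() && !dt.is_int());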
pub trait WithDType:
@@ -155,7 +125,6 @@ use half::{bf16, f16};

with_dtype!(u8, U8, |v: f64| v as u8, |v: u8| v as f64);
with_dtype!(u32, U32, |v: f64| v as u32, |v: u32| v as f64);
with_dtype!(i64, I64, |v: f64| v as i64, |v: i64| v as f64);
with_dtype!(f16, F16, f16::from_f64, f16::to_f64);
with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64);
with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64);
@@ -166,15 +135,6 @@ pub trait IntDType: WithDType {
fn as_usize(&self) -> usize;
}

impl IntDType for i64 {
fn is_true(&self) -> bool {
*self != 0
}
fn as_usize(&self) -> usize {
*self as usize
}
}

impl IntDType for u32 {
fn is_true(&self) -> bool {
*self != 0

@@ -37,10 +37,6 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}

fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}

fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
@@ -79,16 +75,6 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}

fn conv_transpose1d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}

fn conv2d(
&self,
_: &Layout,
@@ -99,16 +85,6 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}

fn conv_transpose2d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose2D,
) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}

fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
@@ -162,10 +138,6 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}

fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}

fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
@@ -177,10 +149,6 @@ impl crate::backend::BackendDevice for CudaDevice {
Err(Error::NotCompiledWithCudaSupport)
}

fn set_seed(&self, _: u64) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}

fn location(&self) -> crate::DeviceLocation {
fail!()
}

@@ -1,223 +0,0 @@
#![allow(dead_code)]
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Error, Layout, Result, Shape};

#[derive(Debug, Clone)]
pub struct MetalDevice;

#[derive(Debug)]
pub struct MetalStorage;

#[derive(thiserror::Error, Debug)]
pub enum MetalError {
#[error("{0}")]
Message(String),
}

impl From<String> for MetalError {
fn from(e: String) -> Self {
MetalError::Message(e)
}
}

macro_rules! fail {
() => {
unimplemented!("metal support has not been enabled, add `metal` feature to enable.")
};
}

impl crate::backend::BackendStorage for MetalStorage {
type Device = MetalDevice;

fn try_clone(&self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn dtype(&self) -> DType {
fail!()
}

fn device(&self) -> &Self::Device {
fail!()
}

fn to_cpu_storage(&self) -> Result<CpuStorage> {
Err(Error::NotCompiledWithMetalSupport)
}

fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn conv1d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConv1D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn conv_transpose1d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn conv2d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConv2D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn conv_transpose2d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose2D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn scatter_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn index_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn matmul(
&self,
_: &Self,
_: (usize, usize, usize, usize),
_: &Layout,
_: &Layout,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}

fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
}

impl crate::backend::BackendDevice for MetalDevice {
type Storage = MetalStorage;
fn new(_: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn set_seed(&self, _: u64) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}

fn location(&self) -> crate::DeviceLocation {
fail!()
}

fn same_device(&self, _: &Self) -> bool {
fail!()
}

fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}

fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}

fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}

fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}

fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
}

@@ -1,4 +1,4 @@
use crate::{DType, DeviceLocation, Layout, MetalError, Shape};
use crate::{DType, DeviceLocation, Layout, Shape};

#[derive(Debug, Clone)]
pub struct MatMulUnexpectedStriding {
@@ -30,7 +30,7 @@ pub enum Error {
UnsupportedDTypeForOp(DType, &'static str),

// === Dimension Index Errors ===
#[error("{op}: dimension index {dim} out of range for shape {shape:?}")]
#[error("{op}: dimension index {dim} out of range for {shape:?}")]
DimOutOfRange {
shape: Shape,
dim: i32,
@@ -142,9 +142,6 @@ pub enum Error {
#[error("{op} expects at least one tensor")]
OpRequiresAtLeastOneTensor { op: &'static str },

#[error("{op} expects at least two tensors")]
OpRequiresAtLeastTwoTensors { op: &'static str },

#[error("backward is not supported for {op}")]
BackwardNotSupported { op: &'static str },

@@ -152,9 +149,6 @@ pub enum Error {
#[error("the candle crate has not been built with cuda support")]
NotCompiledWithCudaSupport,

#[error("the candle crate has not been built with metal support")]
NotCompiledWithMetalSupport,

#[error("cannot find tensor {path}")]
CannotFindTensor { path: String },

@@ -162,9 +156,6 @@ pub enum Error {
#[error(transparent)]
Cuda(Box<dyn std::error::Error + Send + Sync>),

#[error("Metal error {0}")]
Metal(#[from] MetalError),

#[error(transparent)]
TryFromIntError(#[from] core::num::TryFromIntError),

@@ -216,11 +207,7 @@ pub type Result<T> = std::result::Result<T, Error>;

impl Error {
pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self {
Self::Wrapped(Box::new(err)).bt()
}

pub fn msg(err: impl std::error::Error + Send + Sync + 'static) -> Self {
Self::Msg(err.to_string()).bt()
Self::Wrapped(Box::new(err))
}

pub fn bt(self) -> Self {

@@ -46,31 +46,19 @@ impl Tensor {
current_dim += 1;
out
}
TensorIndexer::IndexSelect(indexes) => {
if indexes.rank() != 1 {
crate::bail!("multi-dimensional tensor indexing is not supported")
}
let out = x.index_select(&indexes.to_device(x.device())?, current_dim)?;
current_dim += 1;
out
}
TensorIndexer::Err(e) => crate::bail!("indexing error {e:?}"),
};
}
Ok(x)
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
/// Generic structure used to index a slice of the tensor
pub enum TensorIndexer {
/// This selects the elements for which an index has some specific value.
/// This selects the elemnts for which an index has some specific value.
Select(usize),
/// This is a regular slice, purely indexing a chunk of the tensor
Narrow(Bound<usize>, Bound<usize>),
/// Indexing via a 1d tensor
IndexSelect(Tensor),
Err(Error),
}

impl From<usize> for TensorIndexer {
@@ -79,55 +67,36 @@ impl From<usize> for TensorIndexer {
}
}

impl From<&[u32]> for TensorIndexer {
fn from(index: &[u32]) -> Self {
match Tensor::new(index, &crate::Device::Cpu) {
Ok(tensor) => TensorIndexer::IndexSelect(tensor),
Err(e) => TensorIndexer::Err(e),
macro_rules! impl_from_range {
($range_type:ty) => {
impl From<$range_type> for TensorIndexer {
fn from(range: $range_type) -> Self {
use std::ops::Bound::*;

let start = match range.start_bound() {
Included(idx) => Included(*idx),
Excluded(idx) => Excluded(*idx),
Unbounded => Unbounded,
};

let end = match range.end_bound() {
Included(idx) => Included(*idx),
Excluded(idx) => Excluded(*idx),
Unbounded => Unbounded,
};

TensorIndexer::Narrow(start, end)
}
}
}
};
}

impl From<Vec<u32>> for TensorIndexer {
fn from(index: Vec<u32>) -> Self {
let len = index.len();
match Tensor::from_vec(index, len, &crate::Device::Cpu) {
Ok(tensor) => TensorIndexer::IndexSelect(tensor),
Err(e) => TensorIndexer::Err(e),
}
}
}

impl From<&Tensor> for TensorIndexer {
fn from(tensor: &Tensor) -> Self {
TensorIndexer::IndexSelect(tensor.clone())
}
}

trait RB: RangeBounds<usize> {}
impl RB for Range<usize> {}
impl RB for RangeFrom<usize> {}
impl RB for RangeFull {}
impl RB for RangeInclusive<usize> {}
impl RB for RangeTo<usize> {}
impl RB for RangeToInclusive<usize> {}

impl<T: RB> From<T> for TensorIndexer {
fn from(range: T) -> Self {
use std::ops::Bound::*;
let start = match range.start_bound() {
Included(idx) => Included(*idx),
Excluded(idx) => Excluded(*idx),
Unbounded => Unbounded,
};
let end = match range.end_bound() {
Included(idx) => Included(*idx),
Excluded(idx) => Excluded(*idx),
Unbounded => Unbounded,
};
TensorIndexer::Narrow(start, end)
}
}
impl_from_range!(Range<usize>);
impl_from_range!(RangeFrom<usize>);
impl_from_range!(RangeFull);
impl_from_range!(RangeInclusive<usize>);
impl_from_range!(RangeTo<usize>);
impl_from_range!(RangeToInclusive<usize>);
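Both sides of this hunk end up supporting the same call sites; a hedged usage sketch of range and tensor indexing through `IndexOp` (shapes hypothetical, crate paths assumed, in a function returning a candle `Result`):

use candle_core::{Device, IndexOp, Tensor};
let t = Tensor::arange(0u32, 12u32, &Device::Cpu)?.reshape((3, 4))?;
let row = t.i(1)?;            // Select(1)           -> shape (4,)
let cols = t.i((.., 1..3))?;  // Narrow via 1..3     -> shape (3, 2)
let ids = Tensor::new(&[0u32, 2], &Device::Cpu)?;
let picked = t.i(&ids)?;      // IndexSelect(tensor) -> shape (2, 4)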
/// Trait used to implement multiple signatures for ease of use of the slicing
/// of a tensor

@@ -9,14 +9,6 @@ pub struct Layout {
}

impl Layout {
pub fn new(shape: Shape, stride: Vec<usize>, start_offset: usize) -> Self {
Self {
shape,
stride,
start_offset,
}
}

pub fn contiguous_with_offset<S: Into<Shape>>(shape: S, start_offset: usize) -> Self {
let shape = shape.into();
let stride = shape.stride_contiguous();
@@ -120,31 +112,6 @@ impl Layout {
})
}

pub(crate) fn permute(&self, idxs: &[usize]) -> Result<Self> {
let is_permutation =
idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i));
if !is_permutation {
crate::bail!(
"dimension mismatch in permute, tensor {:?}, dims: {:?}",
self.dims(),
idxs
)
}
let stride = self.stride();
let dims = self.shape().dims();
let mut perm_stride = stride.to_vec();
let mut perm_dims = dims.to_vec();
for (i, &idx) in idxs.iter().enumerate() {
perm_stride[i] = stride[idx];
perm_dims[i] = dims[idx];
}
Ok(Self {
shape: Shape::from(perm_dims),
stride: perm_stride,
start_offset: self.start_offset,
})
}
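`permute` only reorders the dims and strides, leaving the data untouched; a worked, crate-internal example with hypothetical values:

// A contiguous (2, 3, 4) layout has strides (12, 4, 1); permuting with
// idxs = [2, 0, 1] picks the dim/stride pairs in that order:
let l = Layout::contiguous((2, 3, 4)).permute(&[2, 0, 1])?;
assert_eq!(l.dims(), &[4, 2, 3]);
assert_eq!(l.stride(), &[1, 12, 4]);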
pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
let shape = shape.into();
if shape.rank() < self.shape().rank() {

@@ -49,30 +49,24 @@ mod device;
pub mod display;
mod dtype;
mod dummy_cuda_backend;
mod dummy_metal_backend;
pub mod error;
mod indexer;
pub mod layout;
#[cfg(feature = "metal")]
pub mod metal_backend;
#[cfg(feature = "mkl")]
mod mkl;
pub mod npy;
mod op;
pub mod pickle;
pub mod quantized;
pub mod safetensors;
pub mod scalar;
pub mod shape;
mod storage;
mod strided_index;
mod tensor;
pub mod test_utils;
pub mod utils;
mod variable;

pub use cpu_backend::CpuStorage;
pub use device::{Device, DeviceLocation, NdArray};
pub use device::{Device, DeviceLocation};
pub use dtype::{DType, FloatDType, IntDType, WithDType};
pub use error::{Error, Result};
pub use indexer::IndexOp;
@@ -90,53 +84,5 @@ pub use cuda_backend::{CudaDevice, CudaStorage};
#[cfg(not(feature = "cuda"))]
pub use dummy_cuda_backend::{CudaDevice, CudaStorage};

#[cfg(feature = "metal")]
pub use metal_backend::{MetalDevice, MetalError, MetalStorage};

#[cfg(not(feature = "metal"))]
pub use dummy_metal_backend::{MetalDevice, MetalError, MetalStorage};

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

pub trait ToUsize2 {
fn to_usize2(self) -> (usize, usize);
}

impl ToUsize2 for usize {
fn to_usize2(self) -> (usize, usize) {
(self, self)
}
}

impl ToUsize2 for (usize, usize) {
fn to_usize2(self) -> (usize, usize) {
self
}
}

// A simple trait defining a module with forward method using a single argument.
pub trait Module {
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
}

impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
self(xs)
}
}

// A trait defining a module with forward method using a single tensor argument and a flag to
// separate the training and evaluation behaviors.
pub trait ModuleT {
fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor>;
}

impl<M: Module> ModuleT for M {
fn forward_t(&self, xs: &Tensor, _train: bool) -> Result<Tensor> {
self.forward(xs)
}
}
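The blanket impls above mean any closure from `&Tensor` to `Result<Tensor>` is already a `Module`, and every `Module` is a `ModuleT` that ignores the train flag; a small usage sketch (crate paths assumed):

use candle_core::{Device, Module, ModuleT, Result, Tensor};
// Any conforming closure is a Module thanks to the blanket impl:
let act = |xs: &Tensor| -> Result<Tensor> { xs.relu() };
let xs = Tensor::new(&[-1f32, 0., 2.], &Device::Cpu)?;
let ys = act.forward(&xs)?;
// ...and the same closure works wherever a ModuleT is expected:
let ys_t = act.forward_t(&xs, true)?;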
File diff suppressed because it is too large
@@ -25,10 +25,6 @@ mod ffi {
pub fn vdMul(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
pub fn vsDiv(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
pub fn vdDiv(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
pub fn vsFmax(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
pub fn vdFmax(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
pub fn vsFmin(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
pub fn vdFmin(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);

pub fn sgemm_(
transa: *const c_char,
@@ -301,7 +297,7 @@ pub fn vd_sqr(a: &[f64], y: &mut [f64]) {
}

#[inline]
pub fn vs_tanh(a: &[f32], y: &mut [f32]) {
fn vs_tanh(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
@@ -311,7 +307,7 @@ pub fn vs_tanh(a: &[f32], y: &mut [f32]) {
}

#[inline]
pub fn vd_tanh(a: &[f64], y: &mut [f64]) {
fn vd_tanh(a: &[f64], y: &mut [f64]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
@@ -380,7 +376,3 @@ binary_op!(vs_mul, f32, vsMul);
binary_op!(vd_mul, f64, vdMul);
binary_op!(vs_div, f32, vsDiv);
binary_op!(vd_div, f64, vdDiv);
binary_op!(vs_max, f32, vsFmax);
binary_op!(vd_max, f64, vdFmax);
binary_op!(vs_min, f32, vsFmin);
binary_op!(vd_min, f64, vdFmin);
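For orientation, each `binary_op!` line above wraps one MKL VML entry point behind a safe slice API; a hedged sketch of what the expansion plausibly looks like, since the real macro body is outside this diff:

// Hedged sketch of what binary_op!(vs_max, f32, vsFmax) might expand to;
// the actual macro body is not shown in this diff.
#[inline]
pub fn vs_max(a: &[f32], b: &[f32], y: &mut [f32]) {
    let n = y.len();
    assert!(a.len() == n && b.len() == n, "length mismatch in vs_max");
    // SAFETY: the three slices were just checked to share the same length.
    unsafe { ffi::vsFmax(n as i32, a.as_ptr(), b.as_ptr(), y.as_mut_ptr()) }
}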
@@ -85,7 +85,6 @@ impl Header {
            DType::F16 => "f2",
            DType::F32 => "f4",
            DType::F64 => "f8",
            DType::I64 => "i8",
            DType::U32 => "u4",
            DType::U8 => "u1",
        };
@@ -161,7 +160,7 @@ impl Header {
            "f" | "f4" => DType::F32,
            "d" | "f8" => DType::F64,
            // "i" | "i4" => DType::S32,
            "q" | "i8" => DType::I64,
            // "q" | "i8" => DType::S64,
            // "h" | "i2" => DType::S16,
            // "b" | "i1" => DType::S8,
            "B" | "u1" => DType::U8,
@@ -197,11 +196,7 @@ impl Header {

impl Tensor {
    // TODO: Add the possibility to read directly to a device?
    pub(crate) fn from_reader<R: std::io::Read>(
        shape: Shape,
        dtype: DType,
        reader: &mut R,
    ) -> Result<Self> {
    fn from_reader<R: std::io::Read>(shape: Shape, dtype: DType, reader: &mut R) -> Result<Self> {
        let elem_count = shape.elem_count();
        match dtype {
            DType::BF16 => {
@@ -234,11 +229,6 @@ impl Tensor {
                reader.read_u32_into::<LittleEndian>(&mut data_t)?;
                Tensor::from_vec(data_t, shape, &Device::Cpu)
            }
            DType::I64 => {
                let mut data_t = vec![0i64; elem_count];
                reader.read_i64_into::<LittleEndian>(&mut data_t)?;
                Tensor::from_vec(data_t, shape, &Device::Cpu)
            }
        }
    }

@@ -250,6 +240,8 @@ impl Tensor {
        if header.fortran_order {
            return Err(Error::Npy("fortran order not supported".to_string()));
        }
        let mut data: Vec<u8> = vec![];
        reader.read_to_end(&mut data)?;
        Self::from_reader(header.shape(), header.descr, &mut reader)
    }

@@ -369,25 +361,6 @@ impl NpzTensors {
        })
    }

    pub fn names(&self) -> Vec<&String> {
        self.index_per_name.keys().collect()
    }

    /// This only returns the shape and dtype for a named tensor. Compared to `get`, this avoids
    /// reading the whole tensor data.
    pub fn get_shape_and_dtype(&self, name: &str) -> Result<(Shape, DType)> {
        let index = match self.index_per_name.get(name) {
            None => crate::bail!("cannot find tensor {name}"),
            Some(index) => *index,
        };
        let zip_reader = BufReader::new(File::open(&self.path)?);
        let mut zip = zip::ZipArchive::new(zip_reader)?;
        let mut reader = zip.by_index(index)?;
        let header = read_header(&mut reader)?;
        let header = Header::parse(&header)?;
        Ok((header.shape(), header.descr))
    }

    pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
        let index = match self.index_per_name.get(name) {
            None => return Ok(None),
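The `NpzTensors` accessors removed above supported lazy inspection of an `.npz` archive; a small usage sketch (not from the diff; `model.npz` is a hypothetical file name):

// List tensors and peek at shapes/dtypes without reading tensor data.
let npz = candle_core::npy::NpzTensors::new("model.npz")?;
for name in npz.names() {
    let (shape, dtype) = npz.get_shape_and_dtype(name)?;
    println!("{name}: {shape:?} {dtype:?}");
}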
@@ -1,5 +1,4 @@
#![allow(clippy::redundant_closure_call)]
use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
use crate::{CpuStorage, CudaStorage, Layout, Result, Shape, Tensor};
use half::{bf16, f16};
use num_traits::float::Float;

@@ -41,8 +40,6 @@ pub enum BinaryOp {
    Mul,
    Sub,
    Div,
    Maximum,
    Minimum,
}

// Unary ops with no argument
@@ -58,13 +55,7 @@ pub enum UnaryOp {
    Sqr,
    Sqrt,
    Gelu,
    GeluErf,
    Erf,
    Relu,
    Tanh,
    Floor,
    Ceil,
    Round,
}

#[derive(Clone)]
@@ -87,17 +78,6 @@ pub enum Op {
        kernel: Tensor,
        padding: usize,
        stride: usize,
        dilation: usize,
    },

    #[allow(dead_code)]
    ConvTranspose1D {
        arg: Tensor,
        kernel: Tensor,
        padding: usize,
        output_padding: usize,
        stride: usize,
        dilation: usize,
    },

    #[allow(dead_code)]
@@ -106,17 +86,6 @@ pub enum Op {
        kernel: Tensor,
        padding: usize,
        stride: usize,
        dilation: usize,
    },

    #[allow(dead_code)]
    ConvTranspose2D {
        arg: Tensor,
        kernel: Tensor,
        padding: usize,
        output_padding: usize,
        stride: usize,
        dilation: usize,
    },

    AvgPool2D {
@@ -131,12 +100,7 @@ pub enum Op {
        stride: (usize, usize),
    },

    UpsampleNearest1D(Tensor),
    UpsampleNearest2D {
        arg: Tensor,
        target_h: usize,
        target_w: usize,
    },
    UpsampleNearest2D(Tensor),

    Cat(Vec<Tensor>, usize),

@@ -150,29 +114,17 @@ pub enum Op {
    Copy(Tensor),
    Broadcast(Tensor),
    Narrow(Tensor, usize, usize, usize),
    SliceScatter0(Tensor, Tensor, usize),
    Reshape(Tensor),
    ToDevice(Tensor),
    Transpose(Tensor, usize, usize),
    Permute(Tensor, Vec<usize>),
    Elu(Tensor, f64),
    Powf(Tensor, f64),
    CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1 + Send + Sync>>),
    CustomOp2(
        Tensor,
        Tensor,
        std::sync::Arc<Box<dyn CustomOp2 + Send + Sync>>,
    ),
    CustomOp3(
        Tensor,
        Tensor,
        Tensor,
        std::sync::Arc<Box<dyn CustomOp3 + Send + Sync>>,
    ),
    CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1>>),
    CustomOp2(Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp2>>),
    CustomOp3(Tensor, Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp3>>),
}

/// Unary ops that can be defined in user-land.
pub trait CustomOp1 {
pub trait CustomOp1: Send + Sync {
    // Box<dyn> does not support const yet, so use a function to get the name.
    fn name(&self) -> &'static str;

@@ -188,18 +140,6 @@ pub trait CustomOp1 {
        ))
    }

    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn metal_fwd(
        &self,
        _storage: &MetalStorage,
        _layout: &Layout,
    ) -> Result<(MetalStorage, Shape)> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }

    /// This function takes as argument the argument `arg` used in the forward pass, the result
    /// produced by the forward operation `res` and the gradient of the result `grad_res`.
    /// The function should return the gradient of the argument.
@@ -208,7 +148,7 @@ pub trait CustomOp1 {
    }
}
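To illustrate the trait, here is a minimal user-land sketch of a `CustomOp1` that doubles its input on the cpu backend. This is not part of the diff; it assumes the `cpu_fwd` signature elided by the hunk above together with the `CpuStorage::as_slice`, `Layout::contiguous_offsets` and `Tensor::apply_op1` helpers from the crate:

use candle_core::{bail, CpuStorage, CustomOp1, Layout, Result, Shape, Tensor};

struct Double;

impl CustomOp1 for Double {
    fn name(&self) -> &'static str {
        "double"
    }

    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        // Only the contiguous f32 case is handled in this sketch.
        let data = storage.as_slice::<f32>()?;
        let data = match layout.contiguous_offsets() {
            Some((o1, o2)) => &data[o1..o2],
            None => bail!("input has to be contiguous"),
        };
        let out: Vec<f32> = data.iter().map(|v| v * 2.0).collect();
        Ok((CpuStorage::F32(out), layout.shape().clone()))
    }
}

// The op is then applied through Tensor::apply_op1.
fn double(xs: &Tensor) -> Result<Tensor> {
    xs.apply_op1(Double)
}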
pub trait CustomOp2 {
pub trait CustomOp2: Send + Sync {
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
@@ -235,20 +175,6 @@ pub trait CustomOp2 {
        ))
    }

    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn metal_fwd(
        &self,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
    ) -> Result<(MetalStorage, Shape)> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }

    fn bwd(
        &self,
        _arg1: &Tensor,
@@ -260,7 +186,7 @@ pub trait CustomOp2 {
    }
}

pub trait CustomOp3 {
pub trait CustomOp3: Send + Sync {
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
@@ -291,22 +217,6 @@ pub trait CustomOp3 {
        ))
    }

    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn metal_fwd(
        &self,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
    ) -> Result<(MetalStorage, Shape)> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }

    fn bwd(
        &self,
        _arg1: &Tensor,
@@ -329,7 +239,6 @@ pub trait UnaryOpT {
    fn f64(v1: f64) -> f64;
    fn u8(v1: u8) -> u8;
    fn u32(v1: u32) -> u32;
    fn i64(v1: i64) -> i64;

    // There is no very good way to represent optional function in traits so we go for an explicit
    // boolean flag to mark the function as existing.
@@ -353,7 +262,6 @@ pub trait BinaryOpT {
    fn f64(v1: f64, v2: f64) -> f64;
    fn u8(v1: u8, v2: u8) -> u8;
    fn u32(v1: u32, v2: u32) -> u32;
    fn i64(v1: i64, v2: i64) -> i64;

    const BF16_VEC: bool = false;
    fn bf16_vec(_xs1: &[bf16], _xs2: &[bf16], _ys: &mut [bf16]) {}
@@ -367,16 +275,12 @@ pub trait BinaryOpT {
    fn u8_vec(_xs1: &[u8], _xs2: &[u8], _ys: &mut [u8]) {}
    const U32_VEC: bool = false;
    fn u32_vec(_xs1: &[u32], _xs2: &[u32], _ys: &mut [u32]) {}
    const I64_VEC: bool = false;
    fn i64_vec(_xs1: &[i64], _xs2: &[i64], _ys: &mut [i64]) {}
}

pub(crate) struct Add;
pub(crate) struct Div;
pub(crate) struct Mul;
pub(crate) struct Sub;
pub(crate) struct Maximum;
pub(crate) struct Minimum;
pub(crate) struct Exp;
pub(crate) struct Log;
pub(crate) struct Sin;
@@ -387,13 +291,7 @@ pub(crate) struct Recip;
pub(crate) struct Sqr;
pub(crate) struct Sqrt;
pub(crate) struct Gelu;
pub(crate) struct GeluErf;
pub(crate) struct Erf;
pub(crate) struct Relu;
pub(crate) struct Tanh;
pub(crate) struct Floor;
pub(crate) struct Ceil;
pub(crate) struct Round;

macro_rules! bin_op {
    ($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
@@ -425,10 +323,6 @@ macro_rules! bin_op {
            fn u32(v1: u32, v2: u32) -> u32 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn i64(v1: i64, v2: i64) -> i64 {
                $e(v1, v2)
            }

            #[cfg(feature = "mkl")]
            const F32_VEC: bool = true;
@@ -467,22 +361,7 @@ bin_op!(Add, "add", |v1, v2| v1 + v2, vs_add, vd_add);
bin_op!(Sub, "sub", |v1, v2| v1 - v2, vs_sub, vd_sub);
bin_op!(Mul, "mul", |v1, v2| v1 * v2, vs_mul, vd_mul);
bin_op!(Div, "div", |v1, v2| v1 / v2, vs_div, vd_div);
bin_op!(
    Minimum,
    "minimum",
    |v1, v2| if v1 > v2 { v2 } else { v1 },
    vs_min,
    vd_min
);
bin_op!(
    Maximum,
    "maximum",
    |v1, v2| if v1 < v2 { v2 } else { v1 },
    vs_max,
    vd_max
);

#[allow(clippy::redundant_closure_call)]
macro_rules! unary_op {
    ($op: ident, $name: literal, $a: ident, $e: expr) => {
        impl UnaryOpT for $op {
@@ -513,10 +392,6 @@ macro_rules! unary_op {
            fn u32(_: u32) -> u32 {
                todo!("no unary function for u32")
            }
            #[inline(always)]
            fn i64(_: i64) -> i64 {
                todo!("no unary function for i64")
            }
        }
    };

@@ -549,10 +424,6 @@ macro_rules! unary_op {
            fn u32(_: u32) -> u32 {
                todo!("no unary function for u32")
            }
            #[inline(always)]
            fn i64(_: i64) -> i64 {
                todo!("no unary function for i64")
            }

            #[cfg(feature = "mkl")]
            const F32_VEC: bool = true;
@@ -591,14 +462,13 @@ unary_op!(Exp, "exp", v, v.exp(), vs_exp, vd_exp);
unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln);
unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh);
unary_op!(Abs, "abs", v, v.abs());
unary_op!(Neg, "neg", v, -v);
unary_op!(Recip, "recip", v, v.recip());
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);

/// Tanh based approximation of the `gelu` operation
/// GeluErf is the more precise one.
/// `gelu` operation
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
impl UnaryOpT for Gelu {
    const NAME: &'static str = "gelu";
@@ -645,10 +515,6 @@ impl UnaryOpT for Gelu {
    fn u32(_: u32) -> u32 {
        0
    }
    #[inline(always)]
    fn i64(_: i64) -> i64 {
        0
    }
    const KERNEL: &'static str = "ugelu";

    #[cfg(feature = "mkl")]
@@ -668,230 +534,6 @@ impl UnaryOpT for Gelu {
    fn f64_vec(xs: &[f64], ys: &mut [f64]) {
        crate::mkl::vd_gelu(xs, ys)
    }

    #[cfg(feature = "accelerate")]
    const F32_VEC: bool = true;

    #[cfg(feature = "accelerate")]
    #[inline(always)]
    fn f32_vec(xs: &[f32], ys: &mut [f32]) {
        crate::accelerate::vs_gelu(xs, ys)
    }

    #[cfg(feature = "accelerate")]
    const F64_VEC: bool = true;

    #[cfg(feature = "accelerate")]
    #[inline(always)]
    fn f64_vec(xs: &[f64], ys: &mut [f64]) {
        crate::accelerate::vd_gelu(xs, ys)
    }
}
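For reference, the tanh-based approximation that the `Gelu` op implements is the standard one (background knowledge on the usual gelu approximation, not read off the diff):

\operatorname{gelu}(x) \approx \frac{x}{2}\left(1 + \tanh\left(\sqrt{2/\pi}\,\left(x + 0.044715\,x^{3}\right)\right)\right)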
/// `erf` operation
/// <https://en.wikipedia.org/wiki/Error_function>
impl UnaryOpT for Erf {
    const NAME: &'static str = "erf";
    const KERNEL: &'static str = "uerf";
    const V: Self = Erf;
    #[inline(always)]
    fn bf16(v: bf16) -> bf16 {
        bf16::from_f64(Self::f64(v.to_f64()))
    }
    #[inline(always)]
    fn f16(v: f16) -> f16 {
        f16::from_f64(Self::f64(v.to_f64()))
    }
    #[inline(always)]
    fn f32(v: f32) -> f32 {
        Self::f64(v as f64) as f32
    }
    #[inline(always)]
    fn f64(v: f64) -> f64 {
        crate::cpu::erf::erf(v)
    }
    #[inline(always)]
    fn u8(_: u8) -> u8 {
        0
    }
    #[inline(always)]
    fn u32(_: u32) -> u32 {
        0
    }
    #[inline(always)]
    fn i64(_: i64) -> i64 {
        0
    }
}
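The `crate::cpu::erf::erf` call above evaluates the usual Gauss error function:

\operatorname{erf}(x) = \frac{2}{\sqrt{\pi}} \int_{0}^{x} e^{-t^{2}}\,dt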
impl UnaryOpT for Abs {
    const NAME: &'static str = "abs";
    const KERNEL: &'static str = "uabs";
    const V: Self = Abs;
    #[inline(always)]
    fn bf16(v: bf16) -> bf16 {
        v.abs()
    }
    #[inline(always)]
    fn f16(v: f16) -> f16 {
        v.abs()
    }
    #[inline(always)]
    fn f32(v: f32) -> f32 {
        v.abs()
    }
    #[inline(always)]
    fn f64(v: f64) -> f64 {
        v.abs()
    }
    #[inline(always)]
    fn u8(v: u8) -> u8 {
        v
    }
    #[inline(always)]
    fn u32(v: u32) -> u32 {
        v
    }
    #[inline(always)]
    fn i64(v: i64) -> i64 {
        v.abs()
    }
}

impl UnaryOpT for Ceil {
    const NAME: &'static str = "ceil";
    const KERNEL: &'static str = "uceil";
    const V: Self = Ceil;
    #[inline(always)]
    fn bf16(v: bf16) -> bf16 {
        v.ceil()
    }
    #[inline(always)]
    fn f16(v: f16) -> f16 {
        v.ceil()
    }
    #[inline(always)]
    fn f32(v: f32) -> f32 {
        v.ceil()
    }
    #[inline(always)]
    fn f64(v: f64) -> f64 {
        v.ceil()
    }
    #[inline(always)]
    fn u8(v: u8) -> u8 {
        v
    }
    #[inline(always)]
    fn u32(v: u32) -> u32 {
        v
    }
    #[inline(always)]
    fn i64(v: i64) -> i64 {
        v
    }
}

impl UnaryOpT for Floor {
    const NAME: &'static str = "floor";
    const KERNEL: &'static str = "ufloor";
    const V: Self = Floor;
    #[inline(always)]
    fn bf16(v: bf16) -> bf16 {
        v.floor()
    }
    #[inline(always)]
    fn f16(v: f16) -> f16 {
        v.floor()
    }
    #[inline(always)]
    fn f32(v: f32) -> f32 {
        v.floor()
    }
    #[inline(always)]
    fn f64(v: f64) -> f64 {
        v.floor()
    }
    #[inline(always)]
    fn u8(v: u8) -> u8 {
        v
    }
    #[inline(always)]
    fn u32(v: u32) -> u32 {
        v
    }
    #[inline(always)]
    fn i64(v: i64) -> i64 {
        v
    }
}

impl UnaryOpT for Round {
    const NAME: &'static str = "round";
    const KERNEL: &'static str = "uround";
    const V: Self = Round;
    #[inline(always)]
    fn bf16(v: bf16) -> bf16 {
        v.round()
    }
    #[inline(always)]
    fn f16(v: f16) -> f16 {
        v.round()
    }
    #[inline(always)]
    fn f32(v: f32) -> f32 {
        v.round()
    }
    #[inline(always)]
    fn f64(v: f64) -> f64 {
        v.round()
    }
    #[inline(always)]
    fn u8(v: u8) -> u8 {
        v
    }
    #[inline(always)]
    fn u32(v: u32) -> u32 {
        v
    }
    #[inline(always)]
    fn i64(v: i64) -> i64 {
        v
    }
}

impl UnaryOpT for GeluErf {
    const NAME: &'static str = "gelu_erf";
    const KERNEL: &'static str = "ugelu_erf";
    const V: Self = GeluErf;
    #[inline(always)]
    fn bf16(v: bf16) -> bf16 {
        bf16::from_f64(Self::f64(v.to_f64()))
    }
    #[inline(always)]
    fn f16(v: f16) -> f16 {
        f16::from_f64(Self::f64(v.to_f64()))
    }
    #[inline(always)]
    fn f32(v: f32) -> f32 {
        Self::f64(v as f64) as f32
    }
    #[inline(always)]
    fn f64(v: f64) -> f64 {
        (crate::cpu::erf::erf(v / 2f64.sqrt()) + 1.) * 0.5 * v
    }
    #[inline(always)]
    fn u8(_: u8) -> u8 {
        0
    }
    #[inline(always)]
    fn u32(_: u32) -> u32 {
        0
    }
    #[inline(always)]
    fn i64(_: i64) -> i64 {
        0
    }
}
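The `f64` branch of `GeluErf` above is a direct transcription of the exact gelu definition:

\operatorname{gelu\_erf}(x) = \frac{x}{2}\left(1 + \operatorname{erf}\left(\frac{x}{\sqrt{2}}\right)\right)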
impl UnaryOpT for Relu {
@@ -922,10 +564,6 @@ impl UnaryOpT for Relu {
    fn u32(v: u32) -> u32 {
        v
    }
    #[inline(always)]
    fn i64(v: i64) -> i64 {
        v
    }
}

/// `BackpropOp` is a wrapper around `Option<Op>`. The main goal is to ensure that dependencies are
@@ -979,10 +617,6 @@ impl BackpropOp {
        };
        Self(op)
    }

    pub(crate) fn is_none(&self) -> bool {
        self.0.is_none()
    }
}

impl std::ops::Deref for BackpropOp {
@@ -1,814 +0,0 @@
// Just enough pickle support to be able to read PyTorch checkpoints.
// This hardcodes objects that are required for tensor reading, we may want to make this a bit more
// composable/tensor agnostic at some point.
use crate::{DType, Error as E, Layout, Result, Tensor};
use byteorder::{LittleEndian, ReadBytesExt};
use std::collections::HashMap;
use std::io::BufRead;

const VERBOSE: bool = false;

// https://docs.juliahub.com/Pickle/LAUNc/0.1.0/opcode/
#[repr(u8)]
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum OpCode {
    // https://github.com/python/cpython/blob/ed25f097160b5cbb0c9a1f9a746d2f1bbc96515a/Lib/pickletools.py#L2123
    Proto = 0x80,
    Global = b'c',
    BinPut = b'q',
    LongBinPut = b'r',
    EmptyTuple = b')',
    Reduce = b'R',
    Mark = b'(',
    BinUnicode = b'X',
    BinInt = b'J',
    Tuple = b't',
    BinPersId = b'Q',
    BinInt1 = b'K',
    BinInt2 = b'M',
    Tuple1 = 0x85,
    Tuple2 = 0x86,
    Tuple3 = 0x87,
    NewTrue = 0x88,
    NewFalse = 0x89,
    None = b'N',
    BinGet = b'h',
    LongBinGet = b'j',
    SetItem = b's',
    SetItems = b'u',
    EmptyDict = b'}',
    Dict = b'd',
    Build = b'b',
    Stop = b'.',
    NewObj = 0x81,
    EmptyList = b']',
    // The BINFLOAT opcode is an uppercase 'G' in the pickle protocol.
    BinFloat = b'G',
    Append = b'a',
    Appends = b'e',
}

// Avoid using FromPrimitive so as not to drag another dependency.
impl TryFrom<u8> for OpCode {
    type Error = u8;
    fn try_from(value: u8) -> std::result::Result<Self, Self::Error> {
        match value {
            0x80 => Ok(Self::Proto),
            b'c' => Ok(Self::Global),
            b'q' => Ok(Self::BinPut),
            b'r' => Ok(Self::LongBinPut),
            b')' => Ok(Self::EmptyTuple),
            b'R' => Ok(Self::Reduce),
            b'(' => Ok(Self::Mark),
            b'X' => Ok(Self::BinUnicode),
            b'J' => Ok(Self::BinInt),
            b't' => Ok(Self::Tuple),
            b'Q' => Ok(Self::BinPersId),
            b'K' => Ok(Self::BinInt1),
            b'M' => Ok(Self::BinInt2),
            b'N' => Ok(Self::None),
            0x85 => Ok(Self::Tuple1),
            0x86 => Ok(Self::Tuple2),
            0x87 => Ok(Self::Tuple3),
            0x88 => Ok(Self::NewTrue),
            0x89 => Ok(Self::NewFalse),
            b'h' => Ok(Self::BinGet),
            b'j' => Ok(Self::LongBinGet),
            b's' => Ok(Self::SetItem),
            b'u' => Ok(Self::SetItems),
            b'}' => Ok(Self::EmptyDict),
            b'd' => Ok(Self::Dict),
            b'b' => Ok(Self::Build),
            b'.' => Ok(Self::Stop),
            0x81 => Ok(Self::NewObj),
            b']' => Ok(Self::EmptyList),
            b'G' => Ok(Self::BinFloat),
            b'a' => Ok(Self::Append),
            b'e' => Ok(Self::Appends),
            value => Err(value),
        }
    }
}
fn read_to_newline<R: BufRead>(r: &mut R) -> Result<Vec<u8>> {
    let mut data: Vec<u8> = Vec::with_capacity(32);
    r.read_until(b'\n', &mut data)?;
    data.pop();
    if data.last() == Some(&b'\r') {
        data.pop();
    }
    Ok(data)
}

#[derive(Debug, Clone, PartialEq)]
pub enum Object {
    Class {
        module_name: String,
        class_name: String,
    },
    Int(i32),
    Float(f64),
    Unicode(String),
    Bool(bool),
    None,
    Tuple(Vec<Object>),
    List(Vec<Object>),
    Mark,
    Dict(Vec<(Object, Object)>),
    Reduce {
        callable: Box<Object>,
        args: Box<Object>,
    },
    Build {
        callable: Box<Object>,
        args: Box<Object>,
    },
    PersistentLoad(Box<Object>),
}

type OResult<T> = std::result::Result<T, Object>;

impl Object {
    pub fn unicode(self) -> OResult<String> {
        match self {
            Self::Unicode(t) => Ok(t),
            _ => Err(self),
        }
    }

    pub fn reduce(self) -> OResult<(Self, Self)> {
        match self {
            Self::Reduce { callable, args } => Ok((*callable, *args)),
            _ => Err(self),
        }
    }

    pub fn none(self) -> OResult<()> {
        match self {
            Self::None => Ok(()),
            _ => Err(self),
        }
    }

    pub fn persistent_load(self) -> OResult<Self> {
        match self {
            Self::PersistentLoad(t) => Ok(*t),
            _ => Err(self),
        }
    }

    pub fn bool(self) -> OResult<bool> {
        match self {
            Self::Bool(t) => Ok(t),
            _ => Err(self),
        }
    }

    pub fn int(self) -> OResult<i32> {
        match self {
            Self::Int(t) => Ok(t),
            _ => Err(self),
        }
    }

    pub fn tuple(self) -> OResult<Vec<Self>> {
        match self {
            Self::Tuple(t) => Ok(t),
            _ => Err(self),
        }
    }

    pub fn dict(self) -> OResult<Vec<(Self, Self)>> {
        match self {
            Self::Dict(t) => Ok(t),
            _ => Err(self),
        }
    }

    pub fn class(self) -> OResult<(String, String)> {
        match self {
            Self::Class {
                module_name,
                class_name,
            } => Ok((module_name, class_name)),
            _ => Err(self),
        }
    }

    pub fn into_tensor_info(
        self,
        name: Self,
        dir_name: &std::path::Path,
    ) -> Result<Option<TensorInfo>> {
        let name = match name.unicode() {
            Ok(name) => name,
            Err(_) => return Ok(None),
        };
        let (callable, args) = match self.reduce() {
            Ok(callable_args) => callable_args,
            _ => return Ok(None),
        };
        let (callable, args) = match callable {
            Object::Class {
                module_name,
                class_name,
            } if module_name == "torch._tensor" && class_name == "_rebuild_from_type_v2" => {
                let mut args = args.tuple()?;
                let callable = args.remove(0);
                let args = args.remove(1);
                (callable, args)
            }
            Object::Class {
                module_name,
                class_name,
            } if module_name == "torch._utils" && class_name == "_rebuild_parameter" => {
                let mut args = args.tuple()?;
                args.remove(0).reduce()?
            }
            _ => (callable, args),
        };
        match callable {
            Object::Class {
                module_name,
                class_name,
            } if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {}
            _ => return Ok(None),
        };
        let (layout, dtype, file_path, storage_size) = rebuild_args(args)?;
        Ok(Some(TensorInfo {
            name,
            dtype,
            layout,
            path: format!("{}/{}", dir_name.to_string_lossy(), file_path),
            storage_size,
        }))
    }
}

impl TryFrom<Object> for String {
    type Error = Object;
    fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
        match value {
            Object::Unicode(s) => Ok(s),
            other => Err(other),
        }
    }
}

impl TryFrom<Object> for usize {
    type Error = Object;
    fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
        match value {
            Object::Int(s) if s >= 0 => Ok(s as usize),
            other => Err(other),
        }
    }
}

impl<T: TryFrom<Object, Error = Object>> TryFrom<Object> for Vec<T> {
    type Error = Object;
    fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
        match value {
            Object::Tuple(values) => {
                // This does not return the appropriate value in the error case but instead
                // returns the object related to the first error.
                values
                    .into_iter()
                    .map(|v| T::try_from(v))
                    .collect::<std::result::Result<Vec<T>, Self::Error>>()
            }
            other => Err(other),
        }
    }
}

#[derive(Debug)]
pub struct Stack {
    stack: Vec<Object>,
    memo: HashMap<u32, Object>,
}
impl Stack {
    pub fn empty() -> Self {
        Self {
            stack: Vec::with_capacity(512),
            memo: HashMap::new(),
        }
    }

    pub fn stack(&self) -> &[Object] {
        self.stack.as_slice()
    }

    pub fn read_loop<R: BufRead>(&mut self, r: &mut R) -> Result<()> {
        loop {
            if self.read(r)? {
                break;
            }
        }
        Ok(())
    }

    pub fn finalize(mut self) -> Result<Object> {
        self.pop()
    }

    fn push(&mut self, obj: Object) {
        self.stack.push(obj)
    }

    fn pop(&mut self) -> Result<Object> {
        match self.stack.pop() {
            None => crate::bail!("unexpected empty stack"),
            Some(obj) => Ok(obj),
        }
    }

    // https://docs.juliahub.com/Pickle/LAUNc/0.1.0/opcode/#Pickle.OpCodes.BUILD
    fn build(&mut self) -> Result<()> {
        let args = self.pop()?;
        let obj = self.pop()?;
        let obj = match (obj, args) {
            (Object::Dict(mut obj), Object::Dict(mut args)) => {
                obj.append(&mut args);
                Object::Dict(obj)
            }
            (obj, args) => Object::Build {
                callable: Box::new(obj),
                args: Box::new(args),
            },
        };
        self.push(obj);
        Ok(())
    }

    fn reduce(&mut self) -> Result<()> {
        let args = self.pop()?;
        let callable = self.pop()?;
        #[allow(clippy::single_match)]
        let reduced = match &callable {
            Object::Class {
                module_name,
                class_name,
            } => {
                if module_name == "collections"
                    && (class_name == "OrderedDict" || class_name == "defaultdict")
                {
                    // TODO: have a separate ordered dict and a separate default dict.
                    Some(Object::Dict(vec![]))
                } else {
                    None
                }
            }
            _ => None,
        };
        let reduced = reduced.unwrap_or_else(|| Object::Reduce {
            callable: Box::new(callable),
            args: Box::new(args),
        });
        self.push(reduced);
        Ok(())
    }

    fn last(&mut self) -> Result<&mut Object> {
        match self.stack.last_mut() {
            None => crate::bail!("unexpected empty stack"),
            Some(obj) => Ok(obj),
        }
    }

    fn memo_get(&self, id: u32) -> Result<Object> {
        match self.memo.get(&id) {
            None => crate::bail!("missing object in memo {id}"),
            Some(obj) => {
                // Maybe we should use refcounting rather than doing potential large clones here.
                Ok(obj.clone())
            }
        }
    }

    fn memo_put(&mut self, id: u32) -> Result<()> {
        let obj = self.last()?.clone();
        self.memo.insert(id, obj);
        Ok(())
    }

    fn persistent_load(&self, id: Object) -> Result<Object> {
        Ok(Object::PersistentLoad(Box::new(id)))
    }

    fn new_obj(&self, class: Object, args: Object) -> Result<Object> {
        Ok(Object::Reduce {
            callable: Box::new(class),
            args: Box::new(args),
        })
    }

    fn pop_to_marker(&mut self) -> Result<Vec<Object>> {
        let mut mark_idx = None;
        for (idx, obj) in self.stack.iter().enumerate().rev() {
            if obj == &Object::Mark {
                mark_idx = Some(idx);
                break;
            }
        }
        match mark_idx {
            Some(mark_idx) => {
                let objs = self.stack.split_off(mark_idx + 1);
                self.stack.pop();
                Ok(objs)
            }
            None => {
                crate::bail!("marker object not found")
            }
        }
    }

    pub fn read<R: BufRead>(&mut self, r: &mut R) -> Result<bool> {
        let op_code = match OpCode::try_from(r.read_u8()?) {
            Ok(op_code) => op_code,
            Err(op_code) => {
                crate::bail!("unknown op-code {op_code}")
            }
        };
        // println!("op: {op_code:?}");
        // println!("{:?}", self.stack);
        match op_code {
            OpCode::Proto => {
                let version = r.read_u8()?;
                if VERBOSE {
                    println!("proto {version}");
                }
            }
            OpCode::Global => {
                let module_name = read_to_newline(r)?;
                let class_name = read_to_newline(r)?;
                let module_name = String::from_utf8_lossy(&module_name).to_string();
                let class_name = String::from_utf8_lossy(&class_name).to_string();
                self.push(Object::Class {
                    module_name,
                    class_name,
                })
            }
            OpCode::BinInt1 => {
                let arg = r.read_u8()?;
                self.push(Object::Int(arg as i32))
            }
            OpCode::BinInt2 => {
                let arg = r.read_u16::<LittleEndian>()?;
                self.push(Object::Int(arg as i32))
            }
            OpCode::BinInt => {
                let arg = r.read_i32::<LittleEndian>()?;
                self.push(Object::Int(arg))
            }
            OpCode::BinFloat => {
                let arg = r.read_f64::<LittleEndian>()?;
                self.push(Object::Float(arg))
            }
            OpCode::BinUnicode => {
                let len = r.read_u32::<LittleEndian>()?;
                let mut data = vec![0u8; len as usize];
                r.read_exact(&mut data)?;
                let data = String::from_utf8(data).map_err(E::wrap)?;
                self.push(Object::Unicode(data))
            }
            OpCode::BinPersId => {
                let id = self.pop()?;
                let obj = self.persistent_load(id)?;
                self.push(obj)
            }
            OpCode::Tuple => {
                let objs = self.pop_to_marker()?;
                self.push(Object::Tuple(objs))
            }
            OpCode::Tuple1 => {
                let obj = self.pop()?;
                self.push(Object::Tuple(vec![obj]))
            }
            OpCode::Tuple2 => {
                let obj2 = self.pop()?;
                let obj1 = self.pop()?;
                self.push(Object::Tuple(vec![obj1, obj2]))
            }
            OpCode::Tuple3 => {
                let obj3 = self.pop()?;
                let obj2 = self.pop()?;
                let obj1 = self.pop()?;
                self.push(Object::Tuple(vec![obj1, obj2, obj3]))
            }
            OpCode::NewTrue => self.push(Object::Bool(true)),
            OpCode::NewFalse => self.push(Object::Bool(false)),
            OpCode::Append => {
                let value = self.pop()?;
                let pylist = self.last()?;
                if let Object::List(d) = pylist {
                    d.push(value)
                } else {
                    crate::bail!("expected a list, got {pylist:?}")
                }
            }
            OpCode::Appends => {
                let objs = self.pop_to_marker()?;
                let pylist = self.last()?;
                if let Object::List(d) = pylist {
                    d.extend(objs)
                } else {
                    crate::bail!("expected a list, got {pylist:?}")
                }
            }
            OpCode::SetItem => {
                let value = self.pop()?;
                let key = self.pop()?;
                let pydict = self.last()?;
                if let Object::Dict(d) = pydict {
                    d.push((key, value))
                } else {
                    crate::bail!("expected a dict, got {pydict:?}")
                }
            }
            OpCode::SetItems => {
                let mut objs = self.pop_to_marker()?;
                let pydict = self.last()?;
                if let Object::Dict(d) = pydict {
                    if objs.len() % 2 != 0 {
                        crate::bail!("setitems: not an even number of objects")
                    }
                    while let Some(value) = objs.pop() {
                        let key = objs.pop().unwrap();
                        d.push((key, value))
                    }
                } else {
                    crate::bail!("expected a dict, got {pydict:?}")
                }
            }
            OpCode::None => self.push(Object::None),
            OpCode::Stop => {
                return Ok(true);
            }
            OpCode::Build => self.build()?,
            OpCode::EmptyDict => self.push(Object::Dict(vec![])),
            OpCode::Dict => {
                let mut objs = self.pop_to_marker()?;
                let mut pydict = vec![];
                if objs.len() % 2 != 0 {
                    crate::bail!("setitems: not an even number of objects")
                }
                while let Some(value) = objs.pop() {
                    let key = objs.pop().unwrap();
                    pydict.push((key, value))
                }
                self.push(Object::Dict(pydict))
            }
            OpCode::Mark => self.push(Object::Mark),
            OpCode::Reduce => self.reduce()?,
            OpCode::EmptyTuple => self.push(Object::Tuple(vec![])),
            OpCode::EmptyList => self.push(Object::List(vec![])),
            OpCode::BinGet => {
                let arg = r.read_u8()?;
                let obj = self.memo_get(arg as u32)?;
                self.push(obj)
            }
            OpCode::LongBinGet => {
                let arg = r.read_u32::<LittleEndian>()?;
                let obj = self.memo_get(arg)?;
                self.push(obj)
            }
            OpCode::BinPut => {
                let arg = r.read_u8()?;
                self.memo_put(arg as u32)?
            }
            OpCode::LongBinPut => {
                let arg = r.read_u32::<LittleEndian>()?;
                self.memo_put(arg)?
            }
            OpCode::NewObj => {
                let args = self.pop()?;
                let class = self.pop()?;
                let obj = self.new_obj(class, args)?;
                self.push(obj)
            }
        }
        Ok(false)
    }
}
impl From<Object> for E {
    fn from(value: Object) -> Self {
        E::Msg(format!("conversion error on {value:?}"))
    }
}

// https://github.com/pytorch/pytorch/blob/4eac43d046ded0f0a5a5fa8db03eb40f45bf656e/torch/_utils.py#L198
// Arguments: storage, storage_offset, size, stride, requires_grad, backward_hooks
fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
    let mut args = args.tuple()?;
    let stride = Vec::<usize>::try_from(args.remove(3))?;
    let size = Vec::<usize>::try_from(args.remove(2))?;
    let offset = args.remove(1).int()? as usize;
    let storage = args.remove(0).persistent_load()?;
    let mut storage = storage.tuple()?;
    let storage_size = storage.remove(4).int()? as usize;
    let path = storage.remove(2).unicode()?;
    let (_module_name, class_name) = storage.remove(1).class()?;
    let dtype = match class_name.as_str() {
        "FloatStorage" => DType::F32,
        "DoubleStorage" => DType::F64,
        "HalfStorage" => DType::F16,
        "BFloat16Storage" => DType::BF16,
        "ByteStorage" => DType::U8,
        "LongStorage" => DType::I64,
        other => {
            crate::bail!("unsupported storage type {other}")
        }
    };
    let layout = Layout::new(crate::Shape::from(size), stride, offset);
    Ok((layout, dtype, path, storage_size))
}

#[derive(Debug, Clone)]
pub struct TensorInfo {
    pub name: String,
    pub dtype: DType,
    pub layout: Layout,
    pub path: String,
    pub storage_size: usize,
}

/// Read the tensor info from a .pth file.
///
/// # Arguments
/// * `file` - The path to the .pth file.
/// * `verbose` - Whether to print debug information.
/// * `key` - Optional key to retrieve `state_dict` from the pth file.
pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
    file: P,
    verbose: bool,
    key: Option<&str>,
) -> Result<Vec<TensorInfo>> {
    let file = std::fs::File::open(file)?;
    let zip_reader = std::io::BufReader::new(file);
    let mut zip = zip::ZipArchive::new(zip_reader)?;
    let zip_file_names = zip
        .file_names()
        .map(|f| f.to_string())
        .collect::<Vec<String>>();

    let mut tensor_infos = vec![];
    for file_name in zip_file_names.iter() {
        if !file_name.ends_with("data.pkl") {
            continue;
        }
        let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").unwrap());
        let reader = zip.by_name(file_name)?;
        let mut reader = std::io::BufReader::new(reader);
        let mut stack = Stack::empty();
        stack.read_loop(&mut reader)?;
        let obj = stack.finalize()?;
        if VERBOSE || verbose {
            println!("{obj:#?}");
        }

        let obj = match obj {
            Object::Build { callable, args } => match *callable {
                Object::Reduce { callable, args: _ } => match *callable {
                    Object::Class {
                        module_name,
                        class_name,
                    } if module_name == "__torch__" && class_name == "Module" => *args,
                    _ => continue,
                },
                _ => continue,
            },
            obj => obj,
        };

        // If a key is provided, then we need to extract the state_dict from the object.
        let obj = if let Some(key) = key {
            if let Object::Dict(key_values) = obj {
                key_values
                    .into_iter()
                    .find(|(k, _)| *k == Object::Unicode(key.to_owned()))
                    .map(|(_, v)| v)
                    .ok_or_else(|| E::Msg(format!("key {key} not found")))?
            } else {
                obj
            }
        } else {
            obj
        };

        // If the object is a dict, then we can extract the tensor info from it.
        // NOTE: We are assuming that the `obj` is a state_dict by this stage.
        if let Object::Dict(key_values) = obj {
            for (name, value) in key_values.into_iter() {
                match value.into_tensor_info(name, &dir_name) {
                    Ok(Some(tensor_info)) => tensor_infos.push(tensor_info),
                    Ok(None) => {}
                    Err(err) => eprintln!("skipping: {err:?}"),
                }
            }
        }
    }
    Ok(tensor_infos)
}

/// Lazy tensor loader.
pub struct PthTensors {
    tensor_infos: HashMap<String, TensorInfo>,
    path: std::path::PathBuf,
    // We do not store a zip reader as it needs mutable access to extract data. Instead we
    // re-create a zip reader for each tensor.
}

impl PthTensors {
    pub fn new<P: AsRef<std::path::Path>>(path: P, key: Option<&str>) -> Result<Self> {
        let tensor_infos = read_pth_tensor_info(path.as_ref(), false, key)?;
        let tensor_infos = tensor_infos
            .into_iter()
            .map(|ti| (ti.name.to_string(), ti))
            .collect();
        let path = path.as_ref().to_owned();
        Ok(Self { tensor_infos, path })
    }

    pub fn tensor_infos(&self) -> &HashMap<String, TensorInfo> {
        &self.tensor_infos
    }

    pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
        use std::io::Read;
        let tensor_info = match self.tensor_infos.get(name) {
            None => return Ok(None),
            Some(tensor_info) => tensor_info,
        };
        // We hope that the file has not changed since first reading it.
        let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?);
        let mut zip = zip::ZipArchive::new(zip_reader)?;
        let mut reader = zip.by_name(&tensor_info.path)?;
        let is_fortran_contiguous = tensor_info.layout.is_fortran_contiguous();
        let rank = tensor_info.layout.shape().rank();

        // Reading the data is a bit tricky as it can be strided, for now only support the basic
        // case and when the tensor is fortran contiguous.
        if !tensor_info.layout.is_contiguous() && !is_fortran_contiguous {
            crate::bail!(
                "cannot retrieve non-contiguous tensors {:?}",
                tensor_info.layout
            )
        }
        let start_offset = tensor_info.layout.start_offset();
        if start_offset > 0 {
            std::io::copy(
                &mut reader.by_ref().take(start_offset as u64),
                &mut std::io::sink(),
            )?;
        }
        let tensor = Tensor::from_reader(
            tensor_info.layout.shape().clone(),
            tensor_info.dtype,
            &mut reader,
        )?;

        if rank > 1 && is_fortran_contiguous {
            // Reverse the shape, e.g. Shape(2, 3, 4) -> Shape(4, 3, 2)
            let shape_reversed: Vec<_> = tensor_info.layout.dims().iter().rev().cloned().collect();
            let tensor = tensor.reshape(shape_reversed)?;

            // Permute (transpose) the dimensions, e.g. Shape(4, 3, 2) -> Shape(2, 3, 4)
            let dim_indices_reversed: Vec<_> = (0..rank).rev().collect();
            let tensor = tensor.permute(dim_indices_reversed)?;
            Ok(Some(tensor))
        } else {
            Ok(Some(tensor))
        }
    }
}

/// Read all the tensors from a PyTorch pth file with a given key.
///
/// # Arguments
/// * `path` - Path to the pth file.
/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
///   contains multiple objects and the state_dict is the one we are interested in.
pub fn read_all_with_key<P: AsRef<std::path::Path>>(
    path: P,
    key: Option<&str>,
) -> Result<Vec<(String, Tensor)>> {
    let pth = PthTensors::new(path, key)?;
    let tensor_names = pth.tensor_infos.keys();
    let mut tensors = Vec::with_capacity(tensor_names.len());
    for name in tensor_names {
        if let Some(tensor) = pth.get(name)? {
            tensors.push((name.to_string(), tensor))
        }
    }
    Ok(tensors)
}

/// Read all the tensors from a PyTorch pth file.
///
/// # Arguments
/// * `path` - Path to the pth file.
pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
    read_all_with_key(path, None)
}
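As a usage sketch for the loaders above (not part of the file; `model.pth` and the tensor name are hypothetical):

// Eagerly load every tensor from a checkpoint.
let tensors = candle_core::pickle::read_all("model.pth")?;
// Or load lazily and pull a single tensor by name.
let pth = candle_core::pickle::PthTensors::new("model.pth", None)?;
let w = pth.get("encoder.weight")?;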
@@ -1,667 +0,0 @@
use super::k_quants::{
    BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K,
};
use crate::Result;
use byteorder::{ByteOrder, LittleEndian};
use half::f16;

#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

#[inline(always)]
pub(crate) unsafe fn sum_i16_pairs_float(x: __m256i) -> __m256 {
    let ones = _mm256_set1_epi16(1);
    let summed_pairs = _mm256_madd_epi16(ones, x);
    _mm256_cvtepi32_ps(summed_pairs)
}

#[inline(always)]
pub(crate) unsafe fn mul_sum_us8_pairs_float(ax: __m256i, sy: __m256i) -> __m256 {
    let dot = _mm256_maddubs_epi16(ax, sy);
    sum_i16_pairs_float(dot)
}

#[inline(always)]
pub(crate) unsafe fn hsum_float_8(x: __m256) -> f32 {
    let res = _mm256_extractf128_ps(x, 1);
    let res = _mm_add_ps(res, _mm256_castps256_ps128(x));
    let res = _mm_add_ps(res, _mm_movehl_ps(res, res));
    let res = _mm_add_ss(res, _mm_movehdup_ps(res));
    _mm_cvtss_f32(res)
}

#[inline(always)]
pub(crate) unsafe fn bytes_from_nibbles_32(rsi: *const u8) -> __m256i {
    let tmp = _mm_loadu_si128(rsi as *const __m128i);
    let bytes = _mm256_insertf128_si256::<1>(_mm256_castsi128_si256(tmp), _mm_srli_epi16(tmp, 4));
    let low_mask = _mm256_set1_epi8(0xF);
    _mm256_and_si256(low_mask, bytes)
}

#[inline(always)]
pub(crate) unsafe fn mul_sum_i8_pairs_float(x: __m256i, y: __m256i) -> __m256 {
    let ax = _mm256_sign_epi8(x, x);
    let sy = _mm256_sign_epi8(y, x);
    mul_sum_us8_pairs_float(ax, sy)
}

#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
    let qk = QK8_0;
    if n % QK8_0 != 0 {
        crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
    }
    unsafe {
        let mut acc = _mm256_setzero_ps();
        for (x, y) in xs.iter().zip(ys.iter()) {
            let d = _mm256_set1_ps(f16::to_f32(x.d) * f16::to_f32(y.d));
            let bx = bytes_from_nibbles_32(x.qs.as_ptr());
            let off = _mm256_set1_epi8(8);
            let bx = _mm256_sub_epi8(bx, off);
            let by = _mm256_loadu_si256(y.qs.as_ptr() as *const __m256i);
            let q = mul_sum_i8_pairs_float(bx, by);
            acc = _mm256_fmadd_ps(d, q, acc);
        }
        Ok(hsum_float_8(acc))
    }
}

#[inline(always)]
pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
    let qk = QK8_0;
    if n % QK8_0 != 0 {
        crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
    }
    unsafe {
        let mut acc = _mm256_setzero_ps();
        for (x, y) in xs.iter().zip(ys.iter()) {
            let d = _mm256_set1_ps(f16::to_f32(x.d) * f16::to_f32(y.d));
            let bx = _mm256_loadu_si256(x.qs.as_ptr() as *const __m256i);
            let by = _mm256_loadu_si256(y.qs.as_ptr() as *const __m256i);
            let q = mul_sum_i8_pairs_float(bx, by);
            acc = _mm256_fmadd_ps(d, q, acc);
        }
        Ok(hsum_float_8(acc))
    }
}
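In scalar form, the q8_0 x q8_0 kernel above computes a per-block scaled integer dot product; a reference sketch (not part of the file; it assumes the `BlockQ8_0` layout used above, i.e. an `f16` scale `d` and 32 signed byte quants `qs`):

// Scalar equivalent of vec_dot_q8_0_q8_0: multiply the int8 quants
// pairwise, sum them, and scale by the product of the block scales.
fn vec_dot_q8_0_q8_0_scalar(xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> f32 {
    let mut acc = 0f32;
    for (x, y) in xs.iter().zip(ys.iter()) {
        let sum_i32: i32 = x
            .qs
            .iter()
            .zip(y.qs.iter())
            .map(|(&a, &b)| a as i32 * b as i32)
            .sum();
        acc += sum_i32 as f32 * f16::to_f32(x.d) * f16::to_f32(y.d);
    }
    acc
}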
#[inline(always)]
unsafe fn get_scale_shuffle(i: usize) -> __m128i {
    const K_SHUFFLE: [u8; 128] = [
        0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3,
        3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7,
        7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10,
        11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13,
        13, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15, 15,
    ];
    _mm_loadu_si128((K_SHUFFLE.as_ptr() as *const __m128i).add(i))
}

#[inline(always)]
unsafe fn get_scale_shuffle_k4(i: usize) -> __m256i {
    const K_SHUFFLE: [u8; 256] = [
        0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
        0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
        2, 3, 2, 3, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
        4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
        6, 7, 6, 7, 6, 7, 6, 7, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
        8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10,
        11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 12, 13, 12, 13, 12, 13,
        12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12,
        13, 12, 13, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
        14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
    ];
    _mm256_loadu_si256((K_SHUFFLE.as_ptr() as *const __m256i).add(i))
}

#[inline(always)]
unsafe fn get_scale_shuffle_q3k(i: usize) -> __m256i {
    const K_SHUFFLE: [u8; 128] = [
        0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
        2, 3, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
        6, 7, 6, 7, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10, 11, 10, 11, 10, 11, 10, 11,
        10, 11, 10, 11, 10, 11, 10, 11, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12,
        13, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
    ];
    _mm256_loadu_si256((K_SHUFFLE.as_ptr() as *const __m256i).add(i))
}

#[inline(always)]
pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
    let qk = QK_K;
    if n % qk != 0 {
        crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {qk}")
    }

    unsafe {
        let m4 = _mm256_set1_epi8(0xF);
        let m2 = _mm256_set1_epi8(3);
        let m32s = _mm256_set1_epi8(32);
        let mut acc = _mm256_setzero_ps();
        for (x, y) in xs.iter().zip(ys.iter()) {
            let d = y.d * x.d.to_f32();
            let mut q4 = x.ql.as_ptr();
            let mut qh = x.qh.as_ptr();
            let mut q8 = y.qs.as_ptr();

            let scales = _mm_loadu_si128(x.scales.as_ptr() as *const __m128i);
            let mut sumi = _mm256_setzero_si256();

            for j in 0..QK_K / 128 {
                let is = j * 4;
                let scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is));
                let scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
                let scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
                let scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));

                let q4bits1 = _mm256_loadu_si256(q4 as *const __m256i);
                q4 = q4.add(32);
                let q4bits2 = _mm256_loadu_si256(q4 as *const __m256i);
                q4 = q4.add(32);
                let q4bits_h = _mm256_loadu_si256(qh as *const __m256i);
                qh = qh.add(32);

                let q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bits_h, m2), 4);
                let q4h_1 =
                    _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bits_h, 2), m2), 4);
                let q4h_2 =
                    _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bits_h, 4), m2), 4);
                let q4h_3 =
                    _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bits_h, 6), m2), 4);

                let q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
                let q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1);
                let q4_2 =
                    _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2);
                let q4_3 =
                    _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3);

                let q8_0 = _mm256_loadu_si256(q8 as *const __m256i);
                q8 = q8.add(32);
                let q8_1 = _mm256_loadu_si256(q8 as *const __m256i);
                q8 = q8.add(32);
                let q8_2 = _mm256_loadu_si256(q8 as *const __m256i);
                q8 = q8.add(32);
                let q8_3 = _mm256_loadu_si256(q8 as *const __m256i);
                q8 = q8.add(32);

                let q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
                let q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
                let q8s_2 = _mm256_maddubs_epi16(m32s, q8_2);
                let q8s_3 = _mm256_maddubs_epi16(m32s, q8_3);

                let p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
                let p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
                let p16_2 = _mm256_maddubs_epi16(q4_2, q8_2);
                let p16_3 = _mm256_maddubs_epi16(q4_3, q8_3);

                let p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
                let p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
                let p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
                let p16_3 = _mm256_sub_epi16(p16_3, q8s_3);

                let p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
                let p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
                let p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2);
                let p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3);

                sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
                sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3));
            }
            acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
        }
        Ok(hsum_float_8(acc))
    }
}

#[inline(always)]
unsafe fn mm256_set_m128i(a: __m128i, b: __m128i) -> __m256i {
    _mm256_insertf128_si256(_mm256_castsi128_si256(b), a, 1)
}

#[inline(always)]
pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
    }

    unsafe {
        let m3 = _mm256_set1_epi8(3);
        let m4 = _mm_set1_epi8(0xF);

        let mut acc = _mm256_setzero_ps();

        for (x, y) in xs.iter().zip(ys.iter()) {
            let d = y.d * x.d.to_f32();
            let dmin = -y.d * x.dmin.to_f32();

            let mut q2 = x.qs.as_ptr();
            let mut q8 = y.qs.as_ptr();

            let mins_and_scales = _mm_loadu_si128(x.scales.as_ptr() as *const __m128i);
            let scales8 = _mm_and_si128(mins_and_scales, m4);
            let mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
            let mins = _mm256_cvtepi8_epi16(mins8);
            let prod =
                _mm256_madd_epi16(mins, _mm256_loadu_si256(y.bsums.as_ptr() as *const __m256i));

            acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc);

            let all_scales = _mm256_cvtepi8_epi16(scales8);
            let l_scales = _mm256_extracti128_si256(all_scales, 0);
            let h_scales = _mm256_extracti128_si256(all_scales, 1);
            let scales = [
                mm256_set_m128i(l_scales, l_scales),
                mm256_set_m128i(h_scales, h_scales),
            ];

            let mut sumi = _mm256_setzero_si256();

            for scale in scales {
                let q2bits = _mm256_loadu_si256(q2 as *const __m256i);
                q2 = q2.add(32);

                let q8_0 = _mm256_loadu_si256(q8 as *const __m256i);
                q8 = q8.add(32);
                let q8_1 = _mm256_loadu_si256(q8 as *const __m256i);
                q8 = q8.add(32);
                let q8_2 = _mm256_loadu_si256(q8 as *const __m256i);
                q8 = q8.add(32);
                let q8_3 = _mm256_loadu_si256(q8 as *const __m256i);
                q8 = q8.add(32);

                let q2_0 = _mm256_and_si256(q2bits, m3);
                let q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3);
                let q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3);
                let q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3);

                let p0 = _mm256_maddubs_epi16(q2_0, q8_0);
                let p1 = _mm256_maddubs_epi16(q2_1, q8_1);
                let p2 = _mm256_maddubs_epi16(q2_2, q8_2);
                let p3 = _mm256_maddubs_epi16(q2_3, q8_3);

                let p0 =
                    _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(0)), p0);
                let p1 =
                    _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(1)), p1);
                let p2 =
                    _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(2)), p2);
                let p3 =
                    _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(3)), p3);

                let p0 = _mm256_add_epi32(p0, p1);
                let p2 = _mm256_add_epi32(p2, p3);

                sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2));
            }
            acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
        }

        Ok(hsum_float_8(acc))
    }
}

#[inline(always)]
pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}")
    }

    const KMASK1: u32 = 0x03030303;
    const KMASK2: u32 = 0x0f0f0f0f;

    let mut aux = [0u32; 3];

    unsafe {
        let m3 = _mm256_set1_epi8(3);
        let mone = _mm256_set1_epi8(1);
        let m32 = _mm_set1_epi8(32);

        let mut acc = _mm256_setzero_ps();
        for (x, y) in xs.iter().zip(ys.iter()) {
            let d = y.d * x.d.to_f32();

            let mut q3 = x.qs.as_ptr();
            let mut q8 = y.qs.as_ptr();

            LittleEndian::read_u32_into(&x.scales, &mut aux);
            let scales128 = _mm_set_epi32(
                (((aux[1] >> 4) & KMASK2) | (((aux[2] >> 6) & KMASK1) << 4)) as i32,
                (((aux[0] >> 4) & KMASK2) | (((aux[2] >> 4) & KMASK1) << 4)) as i32,
                ((aux[1] & KMASK2) | (((aux[2] >> 2) & KMASK1) << 4)) as i32,
                ((aux[0] & KMASK2) | (((aux[2]) & KMASK1) << 4)) as i32,
            );
            let scales128 = _mm_sub_epi8(scales128, m32);
            let all_scales = _mm256_cvtepi8_epi16(scales128);
            let l_scales = _mm256_extracti128_si256(all_scales, 0);
            let h_scales = _mm256_extracti128_si256(all_scales, 1);
            let scales = [
                mm256_set_m128i(l_scales, l_scales),
                mm256_set_m128i(h_scales, h_scales),
            ];

            // high bit
            let hbits = _mm256_loadu_si256(x.hmask.as_ptr() as *const __m256i);

            let mut sumi = _mm256_setzero_si256();

            for (j, scale) in scales.iter().enumerate() {
                // load low 2 bits
                let q3bits = _mm256_loadu_si256(q3 as *const __m256i);
                q3 = q3.add(32);

                // Prepare low and high bits
                // We hardcode the shifts here to avoid loading them into a separate register
                let q3l_0 = _mm256_and_si256(q3bits, m3);
                let q3h_0 = if j == 0 {
                    _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 0)), 0)
                } else {
                    _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 4)), 4)
                };
                let q3h_0 = _mm256_slli_epi16(q3h_0, 2);

                let q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3);
                let q3h_1 = if j == 0 {
                    _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 1)), 1)
                } else {
                    _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 5)), 5)
                };
                let q3h_1 = _mm256_slli_epi16(q3h_1, 2);

                let q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3);
                let q3h_2 = if j == 0 {
                    _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 2)), 2)
                } else {
                    _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 6)), 6)
                };
                let q3h_2 = _mm256_slli_epi16(q3h_2, 2);

                let q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3);
                let q3h_3 = if j == 0 {
                    _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 3)), 3)
                } else {
                    _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 7)), 7)
                };
                let q3h_3 = _mm256_slli_epi16(q3h_3, 2);

                // load Q8 quants
                let q8_0 = _mm256_loadu_si256(q8 as *const __m256i);
                q8 = q8.add(32);
                let q8_1 = _mm256_loadu_si256(q8 as *const __m256i);
                q8 = q8.add(32);
                let q8_2 = _mm256_loadu_si256(q8 as *const __m256i);
                q8 = q8.add(32);
                let q8_3 = _mm256_loadu_si256(q8 as *const __m256i);
                q8 = q8.add(32);

                // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we
                // can use _mm256_maddubs_epi16, and then subtract. The high bit part has the 2
                // already subtracted (and so, it is zero if the high bit was not set, and 2 if the
                // high bit was set)
                let q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
                let q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
                let q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2);
                let q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3);

                let p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
                let p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
                let p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2);
                let p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3);

                let p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
                let p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
                let p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
                let p16_3 = _mm256_sub_epi16(p16_3, q8s_3);

                // multiply with scales
                let p16_0 =
                    _mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(0)), p16_0);
                let p16_1 =
                    _mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(1)), p16_1);
                let p16_2 =
                    _mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(2)), p16_2);
|
||||
let p16_3 =
|
||||
_mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(3)), p16_3);
|
||||
|
||||
// accumulate
|
||||
let p16_0 = _mm256_add_epi32(p16_0, p16_1);
|
||||
let p16_2 = _mm256_add_epi32(p16_2, p16_3);
|
||||
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2));
|
||||
}
|
||||
|
||||
// multiply with block scale and accumulate
|
||||
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
|
||||
}
|
||||
Ok(hsum_float_8(acc))
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
|
||||
if n % QK_K != 0 {
|
||||
crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
|
||||
}
|
||||
let mut utmp = [0u32; 4];
|
||||
const KMASK1: u32 = 0x3f3f3f3f;
|
||||
const KMASK2: u32 = 0x0f0f0f0f;
|
||||
const KMASK3: u32 = 0x03030303;
|
||||
|
||||
unsafe {
|
||||
let m4 = _mm256_set1_epi8(0xF);
|
||||
|
||||
let mut acc = _mm256_setzero_ps();
|
||||
let mut acc_m = _mm_setzero_ps();
|
||||
|
||||
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||
let d = y.d * x.d.to_f32();
|
||||
let dmin = -y.d * x.dmin.to_f32();
|
||||
|
||||
LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
|
||||
|
||||
utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
|
||||
let uaux = utmp[1] & KMASK1;
|
||||
utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
|
||||
utmp[2] = uaux;
|
||||
utmp[0] &= KMASK1;
|
||||
|
||||
let mut q4 = x.qs.as_ptr();
|
||||
let mut q8 = y.qs.as_ptr();
|
||||
|
||||
let mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(
|
||||
utmp[3] as i32,
|
||||
utmp[2] as i32,
|
||||
utmp[1] as i32,
|
||||
utmp[0] as i32,
|
||||
));
|
||||
|
||||
let q8sums = _mm256_loadu_si256(y.bsums.as_ptr() as *const __m256i);
|
||||
let q8s = _mm_hadd_epi16(
|
||||
_mm256_extracti128_si256(q8sums, 0),
|
||||
_mm256_extracti128_si256(q8sums, 1),
|
||||
);
|
||||
let prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
|
||||
acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);
|
||||
|
||||
let sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
|
||||
let scales = mm256_set_m128i(sc128, sc128);
|
||||
|
||||
let mut sumi = _mm256_setzero_si256();
|
||||
|
||||
for j in 0..QK_K / 64 {
|
||||
let scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j));
|
||||
let scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j + 1));
|
||||
|
||||
let q4bits = _mm256_loadu_si256(q4 as *const __m256i);
|
||||
q4 = q4.add(32);
|
||||
let q4l = _mm256_and_si256(q4bits, m4);
|
||||
let q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
|
||||
|
||||
let q8l = _mm256_loadu_si256(q8 as *const __m256i);
|
||||
q8 = q8.add(32);
|
||||
let p16l = _mm256_maddubs_epi16(q4l, q8l);
|
||||
let p16l = _mm256_madd_epi16(scale_l, p16l);
|
||||
sumi = _mm256_add_epi32(sumi, p16l);
|
||||
|
||||
let q8h = _mm256_loadu_si256(q8 as *const __m256i);
|
||||
q8 = q8.add(32);
|
||||
let p16h = _mm256_maddubs_epi16(q4h, q8h);
|
||||
let p16h = _mm256_madd_epi16(scale_h, p16h);
|
||||
sumi = _mm256_add_epi32(sumi, p16h);
|
||||
}
|
||||
|
||||
let vd = _mm256_set1_ps(d);
|
||||
acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
|
||||
}
|
||||
|
||||
let acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
|
||||
let acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
|
||||
|
||||
Ok(hsum_float_8(acc) + _mm_cvtss_f32(acc_m))
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Result<f32> {
|
||||
if n % QK_K != 0 {
|
||||
crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}")
|
||||
}
|
||||
let mut utmp = [0u32; 4];
|
||||
const KMASK1: u32 = 0x3f3f3f3f;
|
||||
const KMASK2: u32 = 0x0f0f0f0f;
|
||||
const KMASK3: u32 = 0x03030303;
|
||||
|
||||
unsafe {
|
||||
let m4 = _mm256_set1_epi8(0xF);
|
||||
let mzero = _mm_setzero_si128();
|
||||
let mone = _mm256_set1_epi8(1);
|
||||
|
||||
let mut acc = _mm256_setzero_ps();
|
||||
let mut summs = 0.0;
|
||||
|
||||
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||
let d = y.d * x.d.to_f32();
|
||||
let dmin = -y.d * x.dmin.to_f32();
|
||||
|
||||
LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
|
||||
|
||||
utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
|
||||
let uaux = utmp[1] & KMASK1;
|
||||
utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
|
||||
utmp[2] = uaux;
|
||||
utmp[0] &= KMASK1;
|
||||
|
||||
let mut q5 = x.qs.as_ptr();
|
||||
let mut q8 = y.qs.as_ptr();
|
||||
|
||||
let mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(
|
||||
utmp[3] as i32,
|
||||
utmp[2] as i32,
|
||||
utmp[1] as i32,
|
||||
utmp[0] as i32,
|
||||
));
|
||||
|
||||
let q8sums = _mm256_loadu_si256(y.bsums.as_ptr() as *const __m256i);
|
||||
let q8s = _mm_hadd_epi16(
|
||||
_mm256_extracti128_si256(q8sums, 0),
|
||||
_mm256_extracti128_si256(q8sums, 1),
|
||||
);
|
||||
let prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
|
||||
let hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
|
||||
summs += dmin * _mm_extract_epi32(hsum, 0) as f32;
|
||||
|
||||
let sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
|
||||
let scales = mm256_set_m128i(sc128, sc128);
|
||||
|
||||
let hbits = _mm256_loadu_si256(x.qh.as_ptr() as *const __m256i);
|
||||
let mut hmask = mone;
|
||||
|
||||
let mut sumi = _mm256_setzero_si256();
|
||||
|
||||
for j in 0..QK_K / 64 {
|
||||
let scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j));
|
||||
let scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j + 1));
|
||||
|
||||
let q5bits = _mm256_loadu_si256(q5 as *const __m256i);
|
||||
q5 = q5.add(32);
|
||||
|
||||
//Similar to q3k we hardcode the shifts here to avoid loading them into a separate register
|
||||
let q5l_0 = _mm256_and_si256(q5bits, m4);
|
||||
let q5l_0_shift_input = _mm256_and_si256(hbits, hmask);
|
||||
let q5l_0_right_shift = match j {
|
||||
0 => _mm256_srli_epi16(q5l_0_shift_input, 0),
|
||||
1 => _mm256_srli_epi16(q5l_0_shift_input, 2),
|
||||
2 => _mm256_srli_epi16(q5l_0_shift_input, 4),
|
||||
3 => _mm256_srli_epi16(q5l_0_shift_input, 6),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let q5h_0 = _mm256_slli_epi16(q5l_0_right_shift, 4);
|
||||
let q5_0 = _mm256_add_epi8(q5l_0, q5h_0);
|
||||
hmask = _mm256_slli_epi16(hmask, 1);
|
||||
|
||||
let q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
|
||||
let q5l_1_shift_input = _mm256_and_si256(hbits, hmask);
|
||||
let q5l_1_right_shift = match j {
|
||||
0 => _mm256_srli_epi16(q5l_1_shift_input, 1),
|
||||
1 => _mm256_srli_epi16(q5l_1_shift_input, 3),
|
||||
2 => _mm256_srli_epi16(q5l_1_shift_input, 5),
|
||||
3 => _mm256_srli_epi16(q5l_1_shift_input, 7),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let q5h_1 = _mm256_slli_epi16(q5l_1_right_shift, 4);
|
||||
let q5_1 = _mm256_add_epi8(q5l_1, q5h_1);
|
||||
hmask = _mm256_slli_epi16(hmask, 1);
|
||||
|
||||
let q8_0 = _mm256_loadu_si256(q8 as *const __m256i);
|
||||
q8 = q8.add(32);
|
||||
let q8_1 = _mm256_loadu_si256(q8 as *const __m256i);
|
||||
q8 = q8.add(32);
|
||||
|
||||
let p16_0 = _mm256_maddubs_epi16(q5_0, q8_0);
|
||||
let p16_1 = _mm256_maddubs_epi16(q5_1, q8_1);
|
||||
|
||||
let p16_0 = _mm256_madd_epi16(scale_0, p16_0);
|
||||
let p16_1 = _mm256_madd_epi16(scale_1, p16_1);
|
||||
|
||||
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
|
||||
}
|
||||
let vd = _mm256_set1_ps(d);
|
||||
acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
|
||||
}
|
||||
Ok(hsum_float_8(acc) + summs)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
|
||||
let qk = QK_K;
|
||||
if n % qk != 0 {
|
||||
crate::bail!("vec_dot_q8k_8k: {n} is not divisible by {qk}")
|
||||
}
|
||||
|
||||
unsafe {
|
||||
let mut acc = _mm256_setzero_ps();
|
||||
for (xs, ys) in xs.iter().zip(ys.iter()) {
|
||||
let mut sumi = _mm256_setzero_si256();
|
||||
let x_qs = xs.qs.as_ptr();
|
||||
let y_qs = ys.qs.as_ptr();
|
||||
for j in (0..QK_K).step_by(32) {
|
||||
let xs = _mm256_loadu_si256(x_qs.add(j) as *const __m256i);
|
||||
let ys = _mm256_loadu_si256(y_qs.add(j) as *const __m256i);
|
||||
|
||||
let xs0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xs, 0));
|
||||
let ys0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(ys, 0));
|
||||
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(xs0, ys0));
|
||||
|
||||
let xs1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xs, 1));
|
||||
let ys1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(ys, 1));
|
||||
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(xs1, ys1));
|
||||
}
|
||||
let d = _mm256_set1_ps(xs.d * ys.d);
|
||||
acc = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi), acc);
|
||||
}
|
||||
Ok(hsum_float_8(acc))
|
||||
}
|
||||
}
|
@ -1,43 +0,0 @@
|
||||
#![allow(unused)]
|
||||
use super::GgmlDType;
|
||||
use crate::{Error, MetalDevice, MetalStorage, Result};
|
||||
|
||||
pub struct QMetalStorage {
|
||||
dtype: GgmlDType,
|
||||
device: MetalDevice,
|
||||
}
|
||||
|
||||
impl QMetalStorage {
|
||||
pub fn zeros(_: &MetalDevice, _: usize, _: GgmlDType) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
pub fn dtype(&self) -> GgmlDType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
pub fn device(&self) -> &MetalDevice {
|
||||
&self.device
|
||||
}
|
||||
|
||||
pub fn dequantize(&self, _elem_count: usize) -> Result<MetalStorage> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
pub fn quantize(&mut self, _src: &MetalStorage) -> Result<()> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
pub fn storage_size_in_bytes(&self) -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
pub fn fwd(
|
||||
&self,
|
||||
_self_shape: &crate::Shape,
|
||||
_storage: &MetalStorage,
|
||||
_layout: &crate::Layout,
|
||||
) -> Result<(MetalStorage, crate::Shape)> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
}
|
@ -1,11 +1,8 @@
|
||||
//! Support for the GGML file format.
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
use super::metal::load_quantized_metal;
|
||||
use super::{k_quants, GgmlDType, QStorage};
|
||||
use crate::{Device, Result};
|
||||
use super::{k_quants, GgmlDType};
|
||||
use crate::Result;
|
||||
use byteorder::{LittleEndian, ReadBytesExt};
|
||||
use std::collections::HashMap;
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.h#L37
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
@ -123,22 +120,11 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
|
||||
raw_data: &[u8],
|
||||
size_in_bytes: usize,
|
||||
dims: Vec<usize>,
|
||||
device: &Device,
|
||||
) -> Result<super::QTensor> {
|
||||
let raw_data_ptr = raw_data.as_ptr();
|
||||
let n_blocks = size_in_bytes / std::mem::size_of::<T>();
|
||||
let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
|
||||
let data: QStorage = match device {
|
||||
Device::Cpu => QStorage::Cpu(Box::new(data.to_vec())),
|
||||
#[cfg(feature = "metal")]
|
||||
Device::Metal(metal) => load_quantized_metal(metal, data)?,
|
||||
#[cfg(not(feature = "metal"))]
|
||||
Device::Metal(_metal) => {
|
||||
crate::bail!("Metal backend requires `metal` feature")
|
||||
}
|
||||
device => unimplemented!("Implement quantized tensor for device {device:?}"),
|
||||
};
|
||||
super::QTensor::new(data, dims)
|
||||
Ok(super::QTensor::new(data.to_vec(), dims))
|
||||
}
|
||||
|
||||
/// Creates a [Tensor] from a raw GGML tensor.
|
||||
@ -146,50 +132,23 @@ pub fn qtensor_from_ggml(
|
||||
ggml_dtype: GgmlDType,
|
||||
raw_data: &[u8],
|
||||
dims: Vec<usize>,
|
||||
device: &Device,
|
||||
) -> Result<super::QTensor> {
|
||||
let tensor_elems = dims.iter().product::<usize>();
|
||||
let block_size = ggml_dtype.block_size();
|
||||
if tensor_elems % block_size != 0 {
|
||||
crate::bail!(
|
||||
"the number of elements {tensor_elems} is not divisible by the block size {block_size}"
|
||||
)
|
||||
}
|
||||
let size_in_bytes = tensor_elems / block_size * ggml_dtype.type_size();
|
||||
let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
|
||||
|
||||
match ggml_dtype {
|
||||
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims, device),
|
||||
GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims, device),
|
||||
GgmlDType::Q4_0 => {
|
||||
from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q4_1 => {
|
||||
from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q5_0 => {
|
||||
from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q5_1 => {
|
||||
from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q8_0 => {
|
||||
from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q2K => {
|
||||
from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q3K => {
|
||||
from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q4K => {
|
||||
from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q5K => {
|
||||
from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q6K => {
|
||||
from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q4_0 => from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q4_1 => from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q5_0 => from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q5_1 => from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q8_0 => from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q2K => from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q3K => from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q4K => from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q5K => from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q6K => from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims),
|
||||
_ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
|
||||
}
|
||||
}
|
||||
@ -197,7 +156,6 @@ pub fn qtensor_from_ggml(
|
||||
fn read_one_tensor<R: std::io::Seek + std::io::Read>(
|
||||
reader: &mut R,
|
||||
magic: VersionedMagic,
|
||||
device: &Device,
|
||||
) -> Result<(String, super::QTensor)> {
|
||||
let n_dims = reader.read_u32::<LittleEndian>()?;
|
||||
let name_len = reader.read_u32::<LittleEndian>()?;
|
||||
@ -205,9 +163,6 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
|
||||
let ggml_dtype = GgmlDType::from_u32(ggml_dtype)?;
|
||||
let mut dims = vec![0u32; n_dims as usize];
|
||||
reader.read_u32_into::<LittleEndian>(&mut dims)?;
|
||||
// The dimensions are stored in reverse order, see for example:
|
||||
// https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/convert.py#L969
|
||||
dims.reverse();
|
||||
let mut name = vec![0u8; name_len as usize];
|
||||
reader.read_exact(&mut name)?;
|
||||
let name = String::from_utf8_lossy(&name).into_owned();
|
||||
@ -218,11 +173,12 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
|
||||
}
|
||||
let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>();
|
||||
let tensor_elems = dims.iter().product::<usize>();
|
||||
let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.block_size();
|
||||
let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
|
||||
println!("{name} {ggml_dtype:?} {dims:?}");
|
||||
// TODO: Mmap version to avoid copying the data around?
|
||||
let mut raw_data = vec![0u8; size_in_bytes];
|
||||
reader.read_exact(&mut raw_data)?;
|
||||
match qtensor_from_ggml(ggml_dtype, &raw_data, dims, device) {
|
||||
match qtensor_from_ggml(ggml_dtype, &raw_data, dims) {
|
||||
Ok(tensor) => Ok((name, tensor)),
|
||||
Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
|
||||
}
|
||||
@ -232,41 +188,28 @@ pub struct Content {
|
||||
pub magic: VersionedMagic,
|
||||
pub hparams: HParams,
|
||||
pub vocab: Vocab,
|
||||
pub tensors: HashMap<String, super::QTensor>,
|
||||
pub device: Device,
|
||||
pub tensors: Vec<(String, super::QTensor)>,
|
||||
}
|
||||
|
||||
impl Content {
|
||||
pub fn read<R: std::io::Seek + std::io::Read>(
|
||||
reader: &mut R,
|
||||
device: &Device,
|
||||
) -> Result<Content> {
|
||||
pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Content> {
|
||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
|
||||
let last_position = reader.seek(std::io::SeekFrom::End(0))?;
|
||||
reader.seek(std::io::SeekFrom::Start(0))?;
|
||||
let magic = VersionedMagic::read(reader)?;
|
||||
let hparams = HParams::read(reader)?;
|
||||
let vocab = Vocab::read(reader, hparams.n_vocab as usize)?;
|
||||
let mut tensors = HashMap::new();
|
||||
let mut tensors = vec![];
|
||||
|
||||
while reader.stream_position()? != last_position {
|
||||
let (name, tensor) = read_one_tensor(reader, magic, device)?;
|
||||
tensors.insert(name, tensor);
|
||||
let (name, tensor) = read_one_tensor(reader, magic)?;
|
||||
tensors.push((name, tensor))
|
||||
}
|
||||
let device = device.clone();
|
||||
Ok(Self {
|
||||
magic,
|
||||
hparams,
|
||||
vocab,
|
||||
tensors,
|
||||
device,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn remove(&mut self, name: &str) -> Result<super::QTensor> {
|
||||
match self.tensors.remove(name) {
|
||||
None => crate::bail!("cannot find tensor with name '{name}'"),
|
||||
Some(tensor) => Ok(tensor),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,534 +0,0 @@
|
||||
//! Support for the GGUF file format.
|
||||
//!
|
||||
//! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md
|
||||
|
||||
use super::{GgmlDType, QTensor};
|
||||
use crate::{Device, Result};
|
||||
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub const DEFAULT_ALIGNMENT: u64 = 32;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum Magic {
|
||||
Gguf,
|
||||
}
|
||||
|
||||
impl TryFrom<u32> for Magic {
|
||||
type Error = crate::Error;
|
||||
fn try_from(value: u32) -> Result<Self> {
|
||||
let magic = match value {
|
||||
0x46554747 | 0x47475546 => Self::Gguf,
|
||||
_ => crate::bail!("unknown magic 0x{value:08x}"),
|
||||
};
|
||||
Ok(magic)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum VersionedMagic {
|
||||
GgufV1,
|
||||
GgufV2,
|
||||
GgufV3,
|
||||
}
|
||||
|
||||
impl VersionedMagic {
|
||||
fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> {
|
||||
let magic = reader.read_u32::<LittleEndian>()?;
|
||||
let magic = Magic::try_from(magic)?;
|
||||
let version = reader.read_u32::<LittleEndian>()?;
|
||||
let versioned_magic = match (magic, version) {
|
||||
(Magic::Gguf, 1) => Self::GgufV1,
|
||||
(Magic::Gguf, 2) => Self::GgufV2,
|
||||
(Magic::Gguf, 3) => Self::GgufV3,
|
||||
_ => crate::bail!("gguf: unsupported magic/version {magic:?}/{version}"),
|
||||
};
|
||||
Ok(versioned_magic)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TensorInfo {
|
||||
pub ggml_dtype: GgmlDType,
|
||||
pub shape: crate::Shape,
|
||||
pub offset: u64,
|
||||
}
|
||||
|
||||
impl TensorInfo {
|
||||
pub fn read<R: std::io::Seek + std::io::Read>(
|
||||
&self,
|
||||
reader: &mut R,
|
||||
tensor_data_offset: u64,
|
||||
device: &Device,
|
||||
) -> Result<QTensor> {
|
||||
let tensor_elems = self.shape.elem_count();
|
||||
let block_size = self.ggml_dtype.block_size();
|
||||
if tensor_elems % block_size != 0 {
|
||||
crate::bail!(
|
||||
"the number of elements {tensor_elems} is not divisible by the block size {block_size}"
|
||||
)
|
||||
}
|
||||
let size_in_bytes = tensor_elems / block_size * self.ggml_dtype.type_size();
|
||||
let mut raw_data = vec![0u8; size_in_bytes];
|
||||
reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
|
||||
reader.read_exact(&mut raw_data)?;
|
||||
super::ggml_file::qtensor_from_ggml(
|
||||
self.ggml_dtype,
|
||||
&raw_data,
|
||||
self.shape.dims().to_vec(),
|
||||
device,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Content {
|
||||
pub magic: VersionedMagic,
|
||||
pub metadata: HashMap<String, Value>,
|
||||
pub tensor_infos: HashMap<String, TensorInfo>,
|
||||
pub tensor_data_offset: u64,
|
||||
}
|
||||
|
||||
fn read_string<R: std::io::Read>(reader: &mut R, magic: &VersionedMagic) -> Result<String> {
|
||||
let len = match magic {
|
||||
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
|
||||
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
|
||||
reader.read_u64::<LittleEndian>()? as usize
|
||||
}
|
||||
};
|
||||
let mut v = vec![0u8; len];
|
||||
reader.read_exact(&mut v)?;
|
||||
// GGUF strings are supposed to be non-null terminated but in practice this happens.
|
||||
while let Some(0) = v.last() {
|
||||
v.pop();
|
||||
}
|
||||
// GGUF strings are utf8 encoded but there are cases that don't seem to be valid.
|
||||
Ok(String::from_utf8_lossy(&v).into_owned())
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum ValueType {
|
||||
// The value is a 8-bit unsigned integer.
|
||||
U8,
|
||||
// The value is a 8-bit signed integer.
|
||||
I8,
|
||||
// The value is a 16-bit unsigned little-endian integer.
|
||||
U16,
|
||||
// The value is a 16-bit signed little-endian integer.
|
||||
I16,
|
||||
// The value is a 32-bit unsigned little-endian integer.
|
||||
U32,
|
||||
// The value is a 32-bit signed little-endian integer.
|
||||
I32,
|
||||
// The value is a 64-bit unsigned little-endian integer.
|
||||
U64,
|
||||
// The value is a 64-bit signed little-endian integer.
|
||||
I64,
|
||||
// The value is a 32-bit IEEE754 floating point number.
|
||||
F32,
|
||||
// The value is a 64-bit IEEE754 floating point number.
|
||||
F64,
|
||||
// The value is a boolean.
|
||||
// 1-byte value where 0 is false and 1 is true.
|
||||
// Anything else is invalid, and should be treated as either the model being invalid or the reader being buggy.
|
||||
Bool,
|
||||
// The value is a UTF-8 non-null-terminated string, with length prepended.
|
||||
String,
|
||||
// The value is an array of other values, with the length and type prepended.
|
||||
///
|
||||
// Arrays can be nested, and the length of the array is the number of elements in the array, not the number of bytes.
|
||||
Array,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Value {
|
||||
U8(u8),
|
||||
I8(i8),
|
||||
U16(u16),
|
||||
I16(i16),
|
||||
U32(u32),
|
||||
I32(i32),
|
||||
U64(u64),
|
||||
I64(i64),
|
||||
F32(f32),
|
||||
F64(f64),
|
||||
Bool(bool),
|
||||
String(String),
|
||||
Array(Vec<Value>),
|
||||
}
|
||||
|
||||
impl Value {
|
||||
pub fn value_type(&self) -> ValueType {
|
||||
match self {
|
||||
Self::U8(_) => ValueType::U8,
|
||||
Self::I8(_) => ValueType::I8,
|
||||
Self::U16(_) => ValueType::U16,
|
||||
Self::I16(_) => ValueType::I16,
|
||||
Self::U32(_) => ValueType::U32,
|
||||
Self::I32(_) => ValueType::I32,
|
||||
Self::U64(_) => ValueType::U64,
|
||||
Self::I64(_) => ValueType::I64,
|
||||
Self::F32(_) => ValueType::F32,
|
||||
Self::F64(_) => ValueType::F64,
|
||||
Self::Bool(_) => ValueType::Bool,
|
||||
Self::String(_) => ValueType::String,
|
||||
Self::Array(_) => ValueType::Array,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_u8(&self) -> Result<u8> {
|
||||
match self {
|
||||
Self::U8(v) => Ok(*v),
|
||||
v => crate::bail!("not a u8 {v:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_i8(&self) -> Result<i8> {
|
||||
match self {
|
||||
Self::I8(v) => Ok(*v),
|
||||
v => crate::bail!("not a i8 {v:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_u16(&self) -> Result<u16> {
|
||||
match self {
|
||||
Self::U16(v) => Ok(*v),
|
||||
v => crate::bail!("not a u16 {v:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_i16(&self) -> Result<i16> {
|
||||
match self {
|
||||
Self::I16(v) => Ok(*v),
|
||||
v => crate::bail!("not a i16 {v:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_u32(&self) -> Result<u32> {
|
||||
match self {
|
||||
Self::U32(v) => Ok(*v),
|
||||
v => crate::bail!("not a u32 {v:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_i32(&self) -> Result<i32> {
|
||||
match self {
|
||||
Self::I32(v) => Ok(*v),
|
||||
v => crate::bail!("not a i32 {v:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_u64(&self) -> Result<u64> {
|
||||
match self {
|
||||
Self::U64(v) => Ok(*v),
|
||||
v => crate::bail!("not a u64 {v:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_i64(&self) -> Result<i64> {
|
||||
match self {
|
||||
Self::I64(v) => Ok(*v),
|
||||
v => crate::bail!("not a i64 {v:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_f32(&self) -> Result<f32> {
|
||||
match self {
|
||||
Self::F32(v) => Ok(*v),
|
||||
v => crate::bail!("not a f32 {v:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_f64(&self) -> Result<f64> {
|
||||
match self {
|
||||
Self::F64(v) => Ok(*v),
|
||||
v => crate::bail!("not a f64 {v:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_bool(&self) -> Result<bool> {
|
||||
match self {
|
||||
Self::Bool(v) => Ok(*v),
|
||||
v => crate::bail!("not a bool {v:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_vec(&self) -> Result<&Vec<Value>> {
|
||||
match self {
|
||||
Self::Array(v) => Ok(v),
|
||||
v => crate::bail!("not a vec {v:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_string(&self) -> Result<&String> {
|
||||
match self {
|
||||
Self::String(v) => Ok(v),
|
||||
v => crate::bail!("not a string {v:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn read<R: std::io::Read>(
|
||||
reader: &mut R,
|
||||
value_type: ValueType,
|
||||
magic: &VersionedMagic,
|
||||
) -> Result<Self> {
|
||||
let v = match value_type {
|
||||
ValueType::U8 => Self::U8(reader.read_u8()?),
|
||||
ValueType::I8 => Self::I8(reader.read_i8()?),
|
||||
ValueType::U16 => Self::U16(reader.read_u16::<LittleEndian>()?),
|
||||
ValueType::I16 => Self::I16(reader.read_i16::<LittleEndian>()?),
|
||||
ValueType::U32 => Self::U32(reader.read_u32::<LittleEndian>()?),
|
||||
ValueType::I32 => Self::I32(reader.read_i32::<LittleEndian>()?),
|
||||
ValueType::U64 => Self::U64(reader.read_u64::<LittleEndian>()?),
|
||||
ValueType::I64 => Self::I64(reader.read_i64::<LittleEndian>()?),
|
||||
ValueType::F32 => Self::F32(reader.read_f32::<LittleEndian>()?),
|
||||
ValueType::F64 => Self::F64(reader.read_f64::<LittleEndian>()?),
|
||||
ValueType::Bool => match reader.read_u8()? {
|
||||
0 => Self::Bool(false),
|
||||
1 => Self::Bool(true),
|
||||
b => crate::bail!("unexpected bool value {b}"),
|
||||
},
|
||||
ValueType::String => Self::String(read_string(reader, magic)?),
|
||||
ValueType::Array => {
|
||||
let value_type = reader.read_u32::<LittleEndian>()?;
|
||||
let value_type = ValueType::from_u32(value_type)?;
|
||||
let len = match magic {
|
||||
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
|
||||
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
|
||||
reader.read_u64::<LittleEndian>()? as usize
|
||||
}
|
||||
};
|
||||
let mut vs = Vec::with_capacity(len);
|
||||
for _ in 0..len {
|
||||
vs.push(Value::read(reader, value_type, magic)?)
|
||||
}
|
||||
Self::Array(vs)
|
||||
}
|
||||
};
|
||||
Ok(v)
|
||||
}
|
||||
|
||||
fn write<W: std::io::Write>(&self, w: &mut W) -> Result<()> {
|
||||
match self {
|
||||
&Self::U8(v) => w.write_u8(v)?,
|
||||
&Self::I8(v) => w.write_i8(v)?,
|
||||
&Self::U16(v) => w.write_u16::<LittleEndian>(v)?,
|
||||
&Self::I16(v) => w.write_i16::<LittleEndian>(v)?,
|
||||
&Self::U32(v) => w.write_u32::<LittleEndian>(v)?,
|
||||
&Self::I32(v) => w.write_i32::<LittleEndian>(v)?,
|
||||
&Self::U64(v) => w.write_u64::<LittleEndian>(v)?,
|
||||
&Self::I64(v) => w.write_i64::<LittleEndian>(v)?,
|
||||
&Self::F32(v) => w.write_f32::<LittleEndian>(v)?,
|
||||
&Self::F64(v) => w.write_f64::<LittleEndian>(v)?,
|
||||
&Self::Bool(v) => w.write_u8(u8::from(v))?,
|
||||
Self::String(v) => write_string(w, v.as_str())?,
|
||||
Self::Array(v) => {
|
||||
// The `Value` type does not enforce that all the values in an Array have the same
|
||||
// type.
|
||||
let value_type = if v.is_empty() {
|
||||
// Doesn't matter, the array is empty.
|
||||
ValueType::U32
|
||||
} else {
|
||||
let value_type: std::collections::HashSet<_> =
|
||||
v.iter().map(|elem| elem.value_type()).collect();
|
||||
if value_type.len() != 1 {
|
||||
crate::bail!("multiple value-types in the same array {value_type:?}")
|
||||
}
|
||||
value_type.into_iter().next().unwrap()
|
||||
};
|
||||
w.write_u32::<LittleEndian>(value_type.to_u32())?;
|
||||
w.write_u64::<LittleEndian>(v.len() as u64)?;
|
||||
for elem in v.iter() {
|
||||
elem.write(w)?
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl ValueType {
|
||||
fn from_u32(v: u32) -> Result<Self> {
|
||||
let v = match v {
|
||||
0 => Self::U8,
|
||||
1 => Self::I8,
|
||||
2 => Self::U16,
|
||||
3 => Self::I16,
|
||||
4 => Self::U32,
|
||||
5 => Self::I32,
|
||||
6 => Self::F32,
|
||||
7 => Self::Bool,
|
||||
8 => Self::String,
|
||||
9 => Self::Array,
|
||||
10 => Self::U64,
|
||||
11 => Self::I64,
|
||||
12 => Self::F64,
|
||||
v => crate::bail!("unrecognized value-type {v:#08x}"),
|
||||
};
|
||||
Ok(v)
|
||||
}
|
||||
|
||||
fn to_u32(self) -> u32 {
|
||||
match self {
|
||||
Self::U8 => 0,
|
||||
Self::I8 => 1,
|
||||
Self::U16 => 2,
|
||||
Self::I16 => 3,
|
||||
Self::U32 => 4,
|
||||
Self::I32 => 5,
|
||||
Self::F32 => 6,
|
||||
Self::Bool => 7,
|
||||
Self::String => 8,
|
||||
Self::Array => 9,
|
||||
Self::U64 => 10,
|
||||
Self::I64 => 11,
|
||||
Self::F64 => 12,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Content {
|
||||
pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Self> {
|
||||
let magic = VersionedMagic::read(reader)?;
|
||||
|
||||
let tensor_count = match magic {
|
||||
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
|
||||
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
|
||||
reader.read_u64::<LittleEndian>()? as usize
|
||||
}
|
||||
};
|
||||
let metadata_kv_count = match magic {
|
||||
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
|
||||
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
|
||||
reader.read_u64::<LittleEndian>()? as usize
|
||||
}
|
||||
};
|
||||
|
||||
let mut metadata = HashMap::new();
|
||||
for _idx in 0..metadata_kv_count {
|
||||
let key = read_string(reader, &magic)?;
|
||||
let value_type = reader.read_u32::<LittleEndian>()?;
|
||||
let value_type = ValueType::from_u32(value_type)?;
|
||||
let value = Value::read(reader, value_type, &magic)?;
|
||||
metadata.insert(key, value);
|
||||
}
|
||||
let mut tensor_infos = HashMap::new();
|
||||
for _idx in 0..tensor_count {
|
||||
let tensor_name = read_string(reader, &magic)?;
|
||||
let n_dimensions = reader.read_u32::<LittleEndian>()?;
|
||||
|
||||
let mut dimensions: Vec<usize> = match magic {
|
||||
VersionedMagic::GgufV1 => {
|
||||
let mut dimensions = vec![0; n_dimensions as usize];
|
||||
reader.read_u32_into::<LittleEndian>(&mut dimensions)?;
|
||||
dimensions.into_iter().map(|c| c as usize).collect()
|
||||
}
|
||||
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
|
||||
let mut dimensions = vec![0; n_dimensions as usize];
|
||||
reader.read_u64_into::<LittleEndian>(&mut dimensions)?;
|
||||
dimensions.into_iter().map(|c| c as usize).collect()
|
||||
}
|
||||
};
|
||||
|
||||
dimensions.reverse();
|
||||
let ggml_dtype = reader.read_u32::<LittleEndian>()?;
|
||||
let ggml_dtype = GgmlDType::from_u32(ggml_dtype)?;
|
||||
let offset = reader.read_u64::<LittleEndian>()?;
|
||||
tensor_infos.insert(
|
||||
tensor_name,
|
||||
TensorInfo {
|
||||
shape: crate::Shape::from(dimensions),
|
||||
offset,
|
||||
ggml_dtype,
|
||||
},
|
||||
);
|
||||
}
|
||||
let position = reader.stream_position()?;
|
||||
let alignment = match metadata.get("general.alignment") {
|
||||
Some(Value::U8(v)) => *v as u64,
|
||||
Some(Value::U16(v)) => *v as u64,
|
||||
Some(Value::U32(v)) => *v as u64,
|
||||
Some(Value::I8(v)) if *v >= 0 => *v as u64,
|
||||
Some(Value::I16(v)) if *v >= 0 => *v as u64,
|
||||
Some(Value::I32(v)) if *v >= 0 => *v as u64,
|
||||
_ => DEFAULT_ALIGNMENT,
|
||||
};
|
||||
let tensor_data_offset = (position + alignment - 1) / alignment * alignment;
|
||||
Ok(Self {
|
||||
magic,
|
||||
metadata,
|
||||
tensor_infos,
|
||||
tensor_data_offset,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn tensor<R: std::io::Seek + std::io::Read>(
|
||||
&self,
|
||||
reader: &mut R,
|
||||
name: &str,
|
||||
device: &Device,
|
||||
) -> Result<QTensor> {
|
||||
let tensor_info = match self.tensor_infos.get(name) {
|
||||
Some(tensor_info) => tensor_info,
|
||||
None => crate::bail!("cannot find tensor info for {name}"),
|
||||
};
|
||||
tensor_info.read(reader, self.tensor_data_offset, device)
|
||||
}
|
||||
}
|
||||
|
||||
fn write_string<W: std::io::Write>(w: &mut W, str: &str) -> Result<()> {
|
||||
let bytes = str.as_bytes();
|
||||
w.write_u64::<LittleEndian>(bytes.len() as u64)?;
|
||||
w.write_all(bytes)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn write<W: std::io::Seek + std::io::Write>(
|
||||
w: &mut W,
|
||||
metadata: &[(&str, &Value)],
|
||||
tensors: &[(&str, &QTensor)],
|
||||
) -> Result<()> {
|
||||
w.write_u32::<LittleEndian>(0x46554747)?;
|
||||
w.write_u32::<LittleEndian>(2)?; // version 2.
|
||||
w.write_u64::<LittleEndian>(tensors.len() as u64)?;
|
||||
w.write_u64::<LittleEndian>(metadata.len() as u64)?;
|
||||
for (name, value) in metadata.iter() {
|
||||
write_string(w, name)?;
|
||||
w.write_u32::<LittleEndian>(value.value_type().to_u32())?;
|
||||
value.write(w)?;
|
||||
}
|
||||
let mut offset = 0usize;
|
||||
let mut offsets = Vec::with_capacity(tensors.len());
|
||||
for (name, tensor) in tensors.iter() {
|
||||
write_string(w, name)?;
|
||||
let dims = tensor.shape().dims();
|
||||
w.write_u32::<LittleEndian>(dims.len() as u32)?;
|
||||
for &dim in dims.iter().rev() {
|
||||
w.write_u64::<LittleEndian>(dim as u64)?;
|
||||
}
|
||||
w.write_u32::<LittleEndian>(tensor.dtype().to_u32())?;
|
||||
w.write_u64::<LittleEndian>(offset as u64)?;
|
||||
offsets.push(offset);
|
||||
let size_in_bytes = tensor.storage_size_in_bytes();
|
||||
let padding = 31 - (31 + size_in_bytes) % 32;
|
||||
offset += size_in_bytes + padding;
|
||||
}
|
||||
let pos = w.stream_position()? as usize;
|
||||
let padding = 31 - (31 + pos) % 32;
|
||||
w.write_all(&vec![0u8; padding])?;
|
||||
let tensor_start_pos = w.stream_position()? as usize;
|
||||
for (offset, (_name, tensor)) in offsets.iter().zip(tensors.iter()) {
|
||||
let pos = w.stream_position()? as usize;
|
||||
if tensor_start_pos + offset != pos {
|
||||
crate::bail!(
|
||||
"internal error, unexpected current position {tensor_start_pos} {offset} {pos}"
|
||||
)
|
||||
}
|
||||
let data = tensor.data()?;
|
||||
let size_in_bytes = data.len();
|
||||
w.write_all(&data)?;
|
||||
let padding = 31 - (31 + size_in_bytes) % 32;
|
||||
w.write_all(&vec![0u8; padding])?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -1,234 +0,0 @@
|
||||
use super::{GgmlDType, QStorage};
|
||||
use crate::backend::BackendStorage;
|
||||
use crate::{DType, MetalDevice, MetalStorage, Result, Shape};
|
||||
use metal::Buffer;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct QMetalStorage {
|
||||
dtype: GgmlDType,
|
||||
device: MetalDevice,
|
||||
buffer: Arc<Buffer>,
|
||||
}
|
||||
|
||||
impl QMetalStorage {
|
||||
pub fn zeros(device: &MetalDevice, elem_count: usize, dtype: GgmlDType) -> Result<Self> {
|
||||
let size = elem_count * dtype.type_size() / dtype.block_size();
|
||||
let buffer = device.allocate_zeros(size)?;
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn dtype(&self) -> GgmlDType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
pub fn device(&self) -> &MetalDevice {
|
||||
&self.device
|
||||
}
|
||||
|
||||
pub fn buffer(&self) -> &Buffer {
|
||||
&self.buffer
|
||||
}
|
||||
|
||||
pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
|
||||
let buffer = self.device.new_buffer_managed(self.buffer.length())?;
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
command_buffer.set_label("to_cpu");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
blit.set_label("blit_to_cpu");
|
||||
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
|
||||
blit.end_encoding();
|
||||
self.device.wait_until_completed()?;
|
||||
let mut out = vec![0.0; elem_count];
|
||||
match self.dtype {
|
||||
GgmlDType::F32 => {
|
||||
let vec: Vec<f32> = read_to_vec(&buffer, elem_count);
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
f32::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::F16 => {
|
||||
let vec: Vec<half::f16> = read_to_vec(&buffer, elem_count);
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
half::f16::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q4_0 => {
|
||||
let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, elem_count);
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q4_1 => {
|
||||
let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, elem_count);
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q5_0 => {
|
||||
let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, elem_count);
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q5_1 => {
|
||||
let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, elem_count);
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q8_0 => {
|
||||
let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, elem_count);
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q8_1 => {
|
||||
let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, elem_count);
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q2K => {
|
||||
let vec: Vec<crate::quantized::BlockQ2K> =
|
||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q3K => {
|
||||
let vec: Vec<crate::quantized::BlockQ3K> =
|
||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q4K => {
|
||||
let vec: Vec<crate::quantized::BlockQ4K> =
|
||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q5K => {
|
||||
let vec: Vec<crate::quantized::BlockQ5K> =
|
||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q6K => {
|
||||
let vec: Vec<crate::quantized::BlockQ6K> =
|
||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q8K => {
|
||||
let vec: Vec<crate::quantized::BlockQ8K> =
|
||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
}
|
||||
|
||||
let buffer = self.device.new_buffer_with_data(&out)?;
|
||||
Ok(MetalStorage::new(buffer, self.device.clone(), DType::F32))
|
||||
}
|
||||
|
||||
pub fn quantize(&mut self, src: &MetalStorage) -> Result<()> {
|
||||
// Quantization only happens on CPU for now.
|
||||
let src = src.to_cpu::<f32>()?;
|
||||
let elem_count = src.len();
|
||||
let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
|
||||
let mut qcpu_storage = crate::Device::Cpu.qzeros(elem_count, self.dtype)?;
|
||||
qcpu_storage.quantize(&src)?;
|
||||
let buffer = self.device.new_buffer_with_data(&qcpu_storage.data()?)?;
|
||||
self.buffer = buffer;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn storage_size_in_bytes(&self) -> usize {
|
||||
self.buffer.length() as usize
|
||||
}
|
||||
|
||||
pub fn fwd(
|
||||
&self,
|
||||
self_shape: &Shape,
|
||||
storage: &MetalStorage,
|
||||
layout: &crate::Layout,
|
||||
) -> Result<(MetalStorage, Shape)> {
|
||||
use crate::MetalError;
|
||||
|
||||
if !layout.is_contiguous() {
|
||||
crate::bail!("input tensor is not contiguous {layout:?}")
|
||||
}
|
||||
let src_shape = layout.shape();
|
||||
// self is transposed so n is first then k.
|
||||
if src_shape.rank() < 2 {
|
||||
crate::bail!("input tensor has only one dimension {layout:?}")
|
||||
}
|
||||
let (n, k) = self_shape.dims2()?;
|
||||
let mut dst_shape = src_shape.dims().to_vec();
|
||||
|
||||
let (b, m) = match dst_shape.len() {
|
||||
3 => (dst_shape[0], dst_shape[1]),
|
||||
2 => (1, dst_shape[0]),
|
||||
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
|
||||
};
|
||||
let last_k = dst_shape.pop().unwrap();
|
||||
if last_k != k {
|
||||
crate::bail!("input tensor {layout:?} incompatible with {:?}", self_shape)
|
||||
}
|
||||
dst_shape.push(n);
|
||||
let dst_shape = Shape::from(dst_shape);
|
||||
let device = storage.device().clone();
|
||||
let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
|
||||
let command_buffer = device.command_buffer()?;
|
||||
candle_metal_kernels::call_quantized_matmul_t(
|
||||
device.device(),
|
||||
&command_buffer,
|
||||
device.kernels(),
|
||||
self.dtype.into(),
|
||||
(b, m, n, k),
|
||||
storage.buffer(),
|
||||
layout.start_offset() * storage.dtype().size_in_bytes(),
|
||||
&self.buffer,
|
||||
&dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
let dst_storage = crate::MetalStorage::new(dst, device, DType::F32);
|
||||
Ok((dst_storage, dst_shape))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_quantized_metal<T: super::GgmlType + Send + Sync + 'static>(
|
||||
device: &MetalDevice,
|
||||
data: &[T],
|
||||
) -> Result<QStorage> {
|
||||
let buffer = device.new_buffer_with_data(data)?;
|
||||
let device = device.clone();
|
||||
Ok(QStorage::Metal(QMetalStorage {
|
||||
dtype: T::DTYPE,
|
||||
device,
|
||||
buffer,
|
||||
}))
|
||||
}
|
||||
|
||||
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
|
||||
let ptr = buffer.contents() as *const T;
|
||||
assert!(!ptr.is_null());
|
||||
let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
|
||||
slice.to_vec()
|
||||
}
|
||||
|
||||
impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
|
||||
fn from(value: GgmlDType) -> Self {
|
||||
match value {
|
||||
GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0,
|
||||
GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1,
|
||||
GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0,
|
||||
GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1,
|
||||
GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0,
|
||||
GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1,
|
||||
GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K,
|
||||
GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K,
|
||||
GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K,
|
||||
GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
|
||||
GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
|
||||
GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
|
||||
GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
|
||||
GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
|
||||
}
|
||||
}
|
||||
}
|
@ -1,119 +1,16 @@
|
||||
use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor};
|
||||
use k_quants::*;
|
||||
use std::borrow::Cow;
|
||||
use crate::{Device, Result, Shape, Tensor};
|
||||
|
||||
#[cfg(target_feature = "avx")]
|
||||
pub mod avx;
|
||||
mod dummy_metal;
|
||||
pub mod ggml_file;
|
||||
pub mod gguf_file;
|
||||
pub mod k_quants;
|
||||
#[cfg(feature = "metal")]
|
||||
pub mod metal;
|
||||
#[cfg(not(feature = "metal"))]
|
||||
mod metal {
|
||||
pub use super::dummy_metal::*;
|
||||
}
|
||||
#[cfg(target_feature = "neon")]
|
||||
pub mod neon;
|
||||
#[cfg(target_feature = "simd128")]
|
||||
pub mod simd128;
|
||||
pub mod utils;
|
||||
use half::f16;
|
||||
|
||||
pub use k_quants::GgmlType;
|
||||
|
||||
pub struct QTensor {
|
||||
storage: QStorage,
|
||||
data: Box<dyn QuantizedType>,
|
||||
shape: Shape,
|
||||
}
|
||||
|
||||
impl Device {
|
||||
fn qzeros(&self, elem_count: usize, dtype: GgmlDType) -> Result<QStorage> {
|
||||
match self {
|
||||
Device::Cpu => {
|
||||
let storage = dtype.cpu_zeros(elem_count);
|
||||
Ok(QStorage::Cpu(storage))
|
||||
}
|
||||
Device::Metal(metal) => {
|
||||
let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?;
|
||||
Ok(QStorage::Metal(storage))
|
||||
}
|
||||
Device::Cuda(_cuda) => {
|
||||
crate::bail!("Cuda ggml quantization not supported");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum QStorage {
|
||||
Cpu(Box<dyn QuantizedType>),
|
||||
Metal(metal::QMetalStorage),
|
||||
}
|
||||
|
||||
impl QStorage {
|
||||
fn block_size(&self) -> usize {
|
||||
match self {
|
||||
QStorage::Cpu(storage) => storage.block_size(),
|
||||
QStorage::Metal(storage) => storage.dtype().block_size(),
|
||||
}
|
||||
}
|
||||
|
||||
fn dtype(&self) -> GgmlDType {
|
||||
match self {
|
||||
QStorage::Cpu(storage) => storage.dtype(),
|
||||
QStorage::Metal(storage) => storage.dtype(),
|
||||
}
|
||||
}
|
||||
|
||||
fn device(&self) -> Device {
|
||||
match self {
|
||||
QStorage::Cpu(_storage) => Device::Cpu,
|
||||
QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
|
||||
}
|
||||
}
|
||||
|
||||
fn size_in_bytes(&self) -> usize {
|
||||
match self {
|
||||
QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
|
||||
QStorage::Metal(storage) => storage.storage_size_in_bytes(),
|
||||
}
|
||||
}
|
||||
|
||||
fn quantize(&mut self, src: &Storage) -> Result<()> {
|
||||
match (self, src) {
|
||||
(QStorage::Cpu(storage), Storage::Cpu(src)) => {
|
||||
storage.from_float(src.as_slice::<f32>()?)?;
|
||||
}
|
||||
(QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
|
||||
_ => crate::bail!("Invalid dequantize storage locations do not match"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn dequantize(&self, elem_count: usize) -> Result<Storage> {
|
||||
match self {
|
||||
QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
|
||||
QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
|
||||
}
|
||||
}
|
||||
|
||||
fn data(&self) -> Result<Cow<[u8]>> {
|
||||
match self {
|
||||
QStorage::Cpu(storage) => {
|
||||
let data_ptr = storage.as_ptr();
|
||||
let size_in_bytes = storage.storage_size_in_bytes();
|
||||
let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
|
||||
Ok(Cow::from(data))
|
||||
}
|
||||
QStorage::Metal(_storage) => {
|
||||
crate::bail!("not implemented");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum GgmlDType {
|
||||
F32,
|
||||
F16,
|
||||
@ -153,46 +50,7 @@ impl GgmlDType {
|
||||
Ok(dtype)
|
||||
}
|
||||
|
||||
pub(crate) fn to_u32(self) -> u32 {
|
||||
match self {
|
||||
Self::F32 => 0,
|
||||
Self::F16 => 1,
|
||||
Self::Q4_0 => 2,
|
||||
Self::Q4_1 => 3,
|
||||
Self::Q5_0 => 6,
|
||||
Self::Q5_1 => 7,
|
||||
Self::Q8_0 => 8,
|
||||
Self::Q8_1 => 9,
|
||||
Self::Q2K => 10,
|
||||
Self::Q3K => 11,
|
||||
Self::Q4K => 12,
|
||||
Self::Q5K => 13,
|
||||
Self::Q6K => 14,
|
||||
Self::Q8K => 15,
|
||||
}
|
||||
}
|
||||
|
||||
/// The block dtype
|
||||
pub fn cpu_zeros(&self, elem_count: usize) -> Box<dyn QuantizedType> {
|
||||
match self {
|
||||
Self::F32 => Box::new(vec![f32::zeros(); elem_count]),
|
||||
Self::F16 => Box::new(vec![f16::zeros(); elem_count]),
|
||||
Self::Q4_0 => Box::new(vec![BlockQ4_0::zeros(); elem_count / BlockQ4_0::BLCK_SIZE]),
|
||||
Self::Q4_1 => Box::new(vec![BlockQ4_1::zeros(); elem_count / BlockQ4_1::BLCK_SIZE]),
|
||||
Self::Q5_0 => Box::new(vec![BlockQ5_0::zeros(); elem_count / BlockQ5_0::BLCK_SIZE]),
|
||||
Self::Q5_1 => Box::new(vec![BlockQ5_1::zeros(); elem_count / BlockQ5_1::BLCK_SIZE]),
|
||||
Self::Q8_0 => Box::new(vec![BlockQ8_0::zeros(); elem_count / BlockQ8_0::BLCK_SIZE]),
|
||||
Self::Q8_1 => Box::new(vec![BlockQ8_1::zeros(); elem_count / BlockQ8_1::BLCK_SIZE]),
|
||||
Self::Q2K => Box::new(vec![BlockQ2K::zeros(); elem_count / BlockQ2K::BLCK_SIZE]),
|
||||
Self::Q3K => Box::new(vec![BlockQ3K::zeros(); elem_count / BlockQ3K::BLCK_SIZE]),
|
||||
Self::Q4K => Box::new(vec![BlockQ4K::zeros(); elem_count / BlockQ4K::BLCK_SIZE]),
|
||||
Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]),
|
||||
Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]),
|
||||
Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]),
|
||||
}
|
||||
}
|
||||
/// The type size for blocks in bytes.
|
||||
pub fn type_size(&self) -> usize {
|
||||
fn type_size(&self) -> usize {
|
||||
use k_quants::*;
|
||||
match self {
|
||||
Self::F32 => 4,
|
||||
@ -213,8 +71,7 @@ impl GgmlDType {
|
||||
}
|
||||
}
|
||||
|
||||
/// The block size, i.e. the number of elements stored in each block.
|
||||
pub fn block_size(&self) -> usize {
|
||||
fn blck_size(&self) -> usize {
|
||||
match self {
|
||||
Self::F32 => 1,
|
||||
Self::F16 => 1,
|
||||
@ -233,13 +90,7 @@ impl GgmlDType {
|
||||
pub trait QuantizedType: Send + Sync {
|
||||
fn dtype(&self) -> GgmlDType;
|
||||
fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
|
||||
fn dequantize(&self, elem_count: usize) -> Result<CpuStorage>;
|
||||
fn storage_size_in_bytes(&self) -> usize;
|
||||
fn as_ptr(&self) -> *const u8;
|
||||
fn block_size(&self) -> usize;
|
||||
#[allow(clippy::wrong_self_convention)]
|
||||
fn from_float(&mut self, xs: &[f32]) -> Result<()>;
|
||||
fn size(&self) -> usize;
|
||||
fn to_float(&self, ys: &mut [f32]) -> Result<()>;
|
||||
}
|
||||
|
||||
impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
|
||||
@ -247,34 +98,12 @@ impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
|
||||
k_quants::matmul(mkn, lhs, self.as_slice(), dst)
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.len() * core::mem::size_of::<T>()
|
||||
}
|
||||
|
||||
fn from_float(&mut self, xs: &[f32]) -> Result<()> {
|
||||
T::from_float(xs, self)
|
||||
}
|
||||
|
||||
fn dtype(&self) -> GgmlDType {
|
||||
T::DTYPE
|
||||
}
|
||||
|
||||
fn block_size(&self) -> usize {
|
||||
T::BLCK_SIZE
|
||||
}
|
||||
|
||||
fn dequantize(&self, elem_count: usize) -> Result<CpuStorage> {
|
||||
let mut ys = vec![0.0f32; elem_count];
|
||||
T::to_float(self.as_slice(), &mut ys)?;
|
||||
Ok(CpuStorage::F32(ys))
|
||||
}
|
||||
|
||||
fn storage_size_in_bytes(&self) -> usize {
|
||||
self.len() * std::mem::size_of::<T>()
|
||||
}
|
||||
|
||||
fn as_ptr(&self) -> *const u8 {
|
||||
self.as_ptr() as *const u8
|
||||
fn to_float(&self, ys: &mut [f32]) -> Result<()> {
|
||||
T::to_float(self.as_slice(), ys)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -284,57 +113,19 @@ impl std::fmt::Debug for QTensor {
    }
}

fn check_shape(shape: &Shape, block_size: usize) -> Result<()> {
    let dims = shape.dims();
    if dims.is_empty() {
        crate::bail!("scalar tensor cannot be quantized {shape:?}")
    }
    if dims[dims.len() - 1] % block_size != 0 {
        crate::bail!(
            "quantized tensor must have their last dim divisible by block size {shape:?} {}",
            block_size
        )
    }
    Ok(())
}

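A hedged restatement of the rule check_shape enforces, as a standalone illustration:

// Only the last dimension matters: (4, 64) fits 32-element blocks, (4, 50) does not.
fn last_dim_ok(dims: &[usize], block_size: usize) -> bool {
    !dims.is_empty() && dims[dims.len() - 1] % block_size == 0
}

fn main() {
    assert!(last_dim_ok(&[4, 64], 32));
    assert!(!last_dim_ok(&[4, 50], 32));
}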
impl QTensor {
    pub fn new<S: Into<Shape>>(storage: QStorage, shape: S) -> Result<Self> {
        let shape = shape.into();
        check_shape(&shape, storage.block_size())?;
        Ok(Self { storage, shape })
    }

    pub fn quantize(src: &Tensor, dtype: GgmlDType) -> Result<Self> {
        let shape = src.shape();
        let block_size = dtype.block_size();
        check_shape(shape, block_size)?;
        let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
        let elem_count = shape.elem_count();
        if elem_count % block_size != 0 {
            crate::bail!(
                "tensor size ({shape:?}) is not divisible by block size {}",
                block_size
            )
    pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
        data: Vec<T>,
        shape: S,
    ) -> Self {
        Self {
            data: Box::new(data),
            shape: shape.into(),
        }
        let mut storage = src.device().qzeros(elem_count, dtype)?;
        storage.quantize(&src.storage())?;
        Ok(Self {
            storage,
            shape: shape.clone(),
        })
    }

    pub fn dtype(&self) -> GgmlDType {
        self.storage.dtype()
    }

    pub fn device(&self) -> Device {
        self.storage.device()
    }

    pub fn rank(&self) -> usize {
        self.shape.rank()
        self.data.dtype()
    }

    pub fn shape(&self) -> &Shape {
@@ -342,60 +133,26 @@ impl QTensor {
    }

    pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
        let storage = self.storage.dequantize(self.shape.elem_count())?;
        let none = crate::op::BackpropOp::none();
        let is_variable = false;
        crate::tensor::from_storage(storage, self.shape.clone(), none, is_variable)
            .to_device(device)
        let mut f32_data = vec![0f32; self.shape.elem_count()];
        self.data.to_float(&mut f32_data)?;
        Tensor::from_vec(f32_data, &self.shape, device)
    }

    pub fn storage_size_in_bytes(&self) -> usize {
        self.storage.size_in_bytes()
    }

    pub fn data(&self) -> Result<Cow<'_, [u8]>> {
        self.storage.data()
    pub fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
        self.data.matmul_t(mkn, lhs, dst)
    }
}

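For orientation, a hedged usage sketch of the storage-backed QTensor API on one side of this hunk (assumes candle_core exports these names as in the diff; Q4_0 needs the last dim to be a multiple of 32):

use candle_core::quantized::{GgmlDType, QTensor};
use candle_core::{Device, Tensor};

fn demo() -> candle_core::Result<()> {
    let dev = Device::Cpu;
    let w = Tensor::rand(-1f32, 1f32, (64, 128), &dev)?;
    let qw = QTensor::quantize(&w, GgmlDType::Q4_0)?;
    let w_approx = qw.dequantize(&dev)?; // lossy round-trip back to f32
    assert_eq!(w_approx.dims(), w.dims());
    Ok(())
}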
#[derive(Clone, Debug)]
pub enum QMatMul {
    QTensor(std::sync::Arc<QTensor>),
    Tensor(Tensor),
}

thread_local! {
    static DEQUANTIZE_ALL: bool = {
        match std::env::var("CANDLE_DEQUANTIZE_ALL") {
            Ok(s) => {
                !s.is_empty() && s != "0"
            },
            Err(_) => false,
        }
    }
}
#[derive(Debug, Clone)]
pub struct QMatMul(std::sync::Arc<QTensor>);

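Hedged aside: the thread-local above is initialized from the environment once per thread, so the dequantize-all behaviour is most reliably toggled before any QMatMul is built. Sketch, std only:

fn main() {
    // Force QMatMul::from_arc to dequantize every weight to f32
    // (slower matmuls, but full f32 fidelity).
    std::env::set_var("CANDLE_DEQUANTIZE_ALL", "1");
    // ... build models / QMatMul instances after this point ...
}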
impl QMatMul {
    pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
        let dequantize = match qtensor.dtype() {
            GgmlDType::F32 | GgmlDType::F16 => true,
            _ => DEQUANTIZE_ALL.with(|b| *b),
        };
        let t = if dequantize {
            let tensor = qtensor.dequantize(&Device::Cpu)?;
            Self::Tensor(tensor)
        } else {
            Self::QTensor(qtensor)
        };
        Ok(t)
    }

    pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
        Self::from_arc(std::sync::Arc::new(qtensor))
    pub fn new(qtensor: std::sync::Arc<QTensor>) -> Self {
        Self(qtensor)
    }
}

impl crate::CustomOp1 for QTensor {
impl crate::CustomOp1 for QMatMul {
    fn name(&self) -> &'static str {
        "qmatmul"
    }
@@ -409,55 +166,29 @@ impl crate::CustomOp1 for QTensor {
            crate::bail!("input tensor is not contiguous {layout:?}")
        }
        let src_shape = layout.shape();
        // self is transposed so n is first then k.
        let (n, k) = self.shape.dims2()?;
        let (k, n) = self.0.shape.dims2()?;
        if src_shape.rank() < 2 {
            crate::bail!("input tensor has only one dimension {layout:?}")
        }
        let mut dst_shape = src_shape.dims().to_vec();
        let last_k = dst_shape.pop().unwrap();
        if last_k != k {
            crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
            crate::bail!(
                "input tensor {layout:?} incompatible with {:?}",
                self.0.shape
            )
        }
        dst_shape.push(n);
        let dst_shape = Shape::from(dst_shape);
        #[allow(clippy::infallible_destructuring_match)]
        let self_storage = match &self.storage {
            QStorage::Cpu(storage) => storage,
            QStorage::Metal(_) => crate::bail!("Invalid storage"),
        };
        let slice = storage.as_slice::<f32>()?;
        let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
        let storage = storage.as_slice::<f32>()?;
        let storage =
            &storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
        let mut dst_storage = vec![0f32; dst_shape.elem_count()];
        self_storage.matmul_t((dst_shape.elem_count() / n, k, n), slice, &mut dst_storage)?;
        self.0.matmul_t(
            (dst_shape.elem_count() / n, k, n),
            storage,
            &mut dst_storage,
        )?;
        Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
    }

    fn metal_fwd(
        &self,
        storage: &crate::MetalStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::MetalStorage, Shape)> {
        let self_storage = match &self.storage {
            QStorage::Metal(metal) => metal,
            _ => unreachable!("Cannot call metal matmul on non metal QTensor"),
        };
        self_storage.fwd(&self.shape, storage, layout)
    }
}

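A hedged restatement of the shape bookkeeping in cpu_fwd above: the weight is stored transposed as (n, k), so an input of shape [..., k] maps to an output of shape [..., n]. Illustration only:

fn qmatmul_dst_shape(src_dims: &[usize], n: usize, k: usize) -> Option<Vec<usize>> {
    let (&last, rest) = src_dims.split_last()?;
    (last == k).then(|| {
        let mut dst = rest.to_vec();
        dst.push(n); // last dim k is replaced by n
        dst
    })
}

fn main() {
    assert_eq!(qmatmul_dst_shape(&[2, 10, 64], 128, 64), Some(vec![2, 10, 128]));
    assert_eq!(qmatmul_dst_shape(&[2, 10, 63], 128, 64), None); // k mismatch
}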
impl crate::Module for QMatMul {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
            Self::Tensor(w) => {
                let w = match *xs.dims() {
                    [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
                    [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
                    _ => w.t()?,
                };
                xs.matmul(&w)
            }
        }
    }
}

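End-to-end, the Module impl makes QMatMul a drop-in linear layer. A hedged usage sketch (names assumed to be exported as in the diff):

use candle_core::quantized::{GgmlDType, QMatMul, QTensor};
use candle_core::{Device, Module, Tensor};

fn run() -> candle_core::Result<()> {
    let dev = Device::Cpu;
    let w = Tensor::rand(-1f32, 1f32, (128, 64), &dev)?; // (n, k): stored transposed
    let qw = QTensor::quantize(&w, GgmlDType::Q8_0)?;
    let mm = QMatMul::from_qtensor(qw)?;
    let xs = Tensor::rand(-1f32, 1f32, (2, 10, 64), &dev)?;
    let ys = mm.forward(&xs)?;
    assert_eq!(ys.dims(), &[2, 10, 128]);
    Ok(())
}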
@@ -1,613 +0,0 @@
use super::k_quants::{
    BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K,
};
use crate::Result;
use byteorder::{ByteOrder, LittleEndian};

#[allow(unused_imports)]
#[cfg(target_arch = "arm")]
use core::arch::arm::*;

#[allow(unused_imports)]
#[cfg(target_arch = "aarch64")]
use core::arch::aarch64::*;

#[inline(always)]
unsafe fn vdotq_s32(a: int8x16_t, b: int8x16_t) -> int32x4_t {
    // TODO: dotprod
    let p0 = vmull_s8(vget_low_s8(a), vget_low_s8(b));
    let p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
    vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1))
}

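The helper above emulates the AArch64 SDOT instruction (hence the `// TODO: dotprod`) with widening multiplies and pairwise adds. A hedged scalar model of the lane layout it produces:

// Lane i holds products {2i, 2i+1} of the low half plus products {8+2i, 8+2i+1}
// of the high half; summing all four lanes gives the full 16-element i8 dot product.
fn vdotq_s32_model(a: [i8; 16], b: [i8; 16]) -> [i32; 4] {
    let p: Vec<i32> = a.iter().zip(b.iter()).map(|(&x, &y)| x as i32 * y as i32).collect();
    let mut out = [0i32; 4];
    for i in 0..4 {
        out[i] = p[2 * i] + p[2 * i + 1] + p[8 + 2 * i] + p[8 + 2 * i + 1];
    }
    out
}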
#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
    let qk = QK8_0;
    let nb = n / qk;
    if n % QK8_0 != 0 {
        crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
    }

    unsafe {
        let mut sumv0 = vdupq_n_f32(0.0f32);
        for i in 0..nb {
            let x0 = &xs[i];
            let y0 = &ys[i];

            let m4b = vdupq_n_u8(0x0F);
            let s8b = vdupq_n_s8(0x8);

            let v0_0 = vld1q_u8(x0.qs.as_ptr());

            // 4-bit -> 8-bit
            let v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
            let v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));

            // sub 8
            let v0_0ls = vsubq_s8(v0_0l, s8b);
            let v0_0hs = vsubq_s8(v0_0h, s8b);

            // load y
            let v1_0l = vld1q_s8(y0.qs.as_ptr());
            let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));

            let pl0 = vdotq_s32(v0_0ls, v1_0l);
            let ph0 = vdotq_s32(v0_0hs, v1_0h);
            sumv0 = vmlaq_n_f32(
                sumv0,
                vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
                x0.d.to_f32() * y0.d.to_f32(),
            );
        }
        Ok(vaddvq_f32(sumv0))
    }
}

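For readers without NEON to hand, a hedged scalar reference for the same per-block computation (assumes the standard GGML Q4_0 layout: low nibbles are the first 16 values, high nibbles the next 16, both recentred by -8):

fn q4_0_q8_0_block_dot(d_x: f32, qs4: &[u8; 16], d_y: f32, qs8: &[i8; 32]) -> f32 {
    let mut sum = 0i32;
    for i in 0..16 {
        let lo = (qs4[i] & 0x0F) as i32 - 8; // pairs with qs8[0..16]
        let hi = (qs4[i] >> 4) as i32 - 8; // pairs with qs8[16..32]
        sum += lo * qs8[i] as i32 + hi * qs8[i + 16] as i32;
    }
    d_x * d_y * sum as f32
}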
#[inline(always)]
pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
    let qk = QK8_0;
    if n % QK8_0 != 0 {
        crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
    }
    let nb = n / QK8_0;
    unsafe {
        let mut sumv0 = vdupq_n_f32(0.0f32);
        for i in 0..nb {
            let x0 = &xs[i];
            let y0 = &ys[i];

            let x0_0 = vld1q_s8(x0.qs.as_ptr());
            let x0_1 = vld1q_s8(x0.qs.as_ptr().add(16));

            // load y
            let y0_0 = vld1q_s8(y0.qs.as_ptr());
            let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));

            let p0 = vdotq_s32(x0_0, y0_0);
            let p1 = vdotq_s32(x0_1, y0_1);

            sumv0 = vmlaq_n_f32(
                sumv0,
                vcvtq_f32_s32(vaddq_s32(p0, p1)),
                x0.d.to_f32() * y0.d.to_f32(),
            );
        }
        Ok(vaddvq_f32(sumv0))
    }
}

#[inline(always)]
pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
    let qk = QK_K;
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
    }

    let mut sumf = 0f32;
    for (xs, ys) in xs.iter().zip(ys.iter()) {
        unsafe {
            let mut sum_i = vdupq_n_s32(0);
            let scale = xs.d * ys.d;
            let xs = xs.qs.as_ptr();
            let ys = ys.qs.as_ptr();
            for i in (0..QK_K).step_by(16) {
                let xs = vld1q_s8(xs.add(i));
                let ys = vld1q_s8(ys.add(i));
                let xy = vdotq_s32(xs, ys);
                sum_i = vaddq_s32(sum_i, xy)
            }
            sumf += vaddvq_s32(sum_i) as f32 * scale
        }
    }
    Ok(sumf)
}

#[inline(always)]
pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}")
    }
    let mut sum = 0f32;
    unsafe {
        let m4b = vdupq_n_u8(0xF);

        let mone = vdupq_n_u8(3);

        for (x, y) in xs.iter().zip(ys.iter()) {
            let d_all = x.d.to_f32();

            let mut q6 = x.ql.as_ptr();
            let mut qh = x.qh.as_ptr();
            let mut q8 = y.qs.as_ptr();

            let mut scale = x.scales.as_ptr();

            let q8sums = vld1q_s16_x2(y.bsums.as_ptr());
            let scales = vld1q_s8(scale);
            let q6scales = int16x8x2_t(
                vmovl_s8(vget_low_s8(scales)),
                vmovl_s8(vget_high_s8(scales)),
            );

            let prod = vaddq_s32(
                vaddq_s32(
                    vmull_s16(vget_low_s16(q8sums.0), vget_low_s16(q6scales.0)),
                    vmull_s16(vget_high_s16(q8sums.0), vget_high_s16(q6scales.0)),
                ),
                vaddq_s32(
                    vmull_s16(vget_low_s16(q8sums.1), vget_low_s16(q6scales.1)),
                    vmull_s16(vget_high_s16(q8sums.1), vget_high_s16(q6scales.1)),
                ),
            );
            let isum_mins = vaddvq_s32(prod);

            let mut isum = 0i32;

            for _j in 0..QK_K / 128 {
                let qhbits = vld1q_u8_x2(qh);
                qh = qh.add(32);
                let q6bits = vld1q_u8_x4(q6);
                q6 = q6.add(64);
                let q8bytes = vld1q_s8_x4(q8);
                q8 = q8.add(64);

                let q6h_0 = vshlq_n_u8(vandq_u8(mone, qhbits.0), 4);
                let q6h_1 = vshlq_n_u8(vandq_u8(mone, qhbits.1), 4);
                let shifted = vshrq_n_u8(qhbits.0, 2);
                let q6h_2 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
                let shifted = vshrq_n_u8(qhbits.1, 2);
                let q6h_3 = vshlq_n_u8(vandq_u8(mone, shifted), 4);

                let q6bytes_0 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.0, m4b), q6h_0));
                let q6bytes_1 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.1, m4b), q6h_1));
                let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.2, m4b), q6h_2));
                let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.3, m4b), q6h_3));

                let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
                let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
                let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
                isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
                scale = scale.add(2);

                let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
                let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
                let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
                isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
                scale = scale.add(2);

                let q8bytes = vld1q_s8_x4(q8);
                q8 = q8.add(64);

                let shifted = vshrq_n_u8(qhbits.0, 4);
                let q6h_0 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
                let shifted = vshrq_n_u8(qhbits.1, 4);
                let q6h_1 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
                let shifted = vshrq_n_u8(qhbits.0, 6);
                let q6h_2 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
                let shifted = vshrq_n_u8(qhbits.1, 6);
                let q6h_3 = vshlq_n_u8(vandq_u8(mone, shifted), 4);

                let q6bytes_0 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.0, 4), q6h_0));
                let q6bytes_1 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.1, 4), q6h_1));
                let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.2, 4), q6h_2));
                let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.3, 4), q6h_3));

                let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
                let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
                let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
                isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
                scale = scale.add(2);

                let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
                let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
                let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
                isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
                scale = scale.add(2);
            }
            sum += d_all * y.d * ((isum - 32 * isum_mins) as f32);
        }
    }
    Ok(sum)
}

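The nibble and high-bit juggling above implements the usual Q6_K decode. A hedged scalar statement of the per-weight rule:

// A Q6_K weight is a 4-bit low part from ql plus a 2-bit high part from qh,
// recentred by -32: q = (lo | (hi << 4)) - 32, so q is in [-32, 31].
fn q6_decode(ql_nibble: u8, qh_two_bits: u8) -> i8 {
    ((ql_nibble & 0x0F) | ((qh_two_bits & 0x03) << 4)) as i8 - 32
}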
#[inline(always)]
pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}")
    }
    let mut sumf = 0f32;
    let mut utmp = [0u32; 4];
    const KMASK1: u32 = 0x3f3f3f3f;
    const KMASK2: u32 = 0x0f0f0f0f;
    const KMASK3: u32 = 0x03030303;

    unsafe {
        let m4b = vdupq_n_u8(0xF);
        let mone = vdupq_n_u8(1);
        let mtwo = vdupq_n_u8(2);

        for (x, y) in xs.iter().zip(ys.iter()) {
            let d = y.d * x.d.to_f32();
            let dmin = y.d * x.dmin.to_f32();

            let q8sums = vpaddq_s16(
                vld1q_s16(y.bsums.as_ptr()),
                vld1q_s16(y.bsums.as_ptr().add(8)),
            );

            LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);

            utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
            let uaux = utmp[1] & KMASK1;
            utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
            utmp[2] = uaux;
            utmp[0] &= KMASK1;

            let mins8 = vld1_u8((utmp.as_ptr() as *const u8).add(8));
            let mins = vreinterpretq_s16_u16(vmovl_u8(mins8));
            let prod = vaddq_s32(
                vmull_s16(vget_low_s16(q8sums), vget_low_s16(mins)),
                vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)),
            );
            let sumi_mins = vaddvq_s32(prod);

            let mut scales = utmp.as_ptr() as *const u8;

            let mut q5 = x.qs.as_ptr();
            let mut q8 = y.qs.as_ptr();

            let mut qhbits = vld1q_u8_x2(x.qh.as_ptr());

            let mut sumi = 0i32;

            for _j in 0..QK_K / 64 {
                let q5bits = vld1q_u8_x2(q5);
                q5 = q5.add(32);
                let q8bytes = vld1q_s8_x4(q8);
                q8 = q8.add(64);

                let q5h_0 = vshlq_n_u8(vandq_u8(mone, qhbits.0), 4);
                let q5h_1 = vshlq_n_u8(vandq_u8(mone, qhbits.1), 4);
                let q5h_2 = vshlq_n_u8(vandq_u8(mtwo, qhbits.0), 3);
                let q5h_3 = vshlq_n_u8(vandq_u8(mtwo, qhbits.1), 3);
                qhbits.0 = vshrq_n_u8(qhbits.0, 2);
                qhbits.1 = vshrq_n_u8(qhbits.1, 2);

                let q5bytes_0 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.0, m4b), q5h_0));
                let q5bytes_1 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.1, m4b), q5h_1));
                let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2));
                let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3));

                let p0 = vdotq_s32(q5bytes_0, q8bytes.0);
                let p1 = vdotq_s32(q5bytes_1, q8bytes.1);
                sumi += vaddvq_s32(vaddq_s32(p0, p1)) * *scales as i32;
                scales = scales.add(1);

                let p2 = vdotq_s32(q5bytes_2, q8bytes.2);
                let p3 = vdotq_s32(q5bytes_3, q8bytes.3);
                sumi += vaddvq_s32(vaddq_s32(p2, p3)) * *scales as i32;
                scales = scales.add(1);
            }
            sumf += d * sumi as f32 - dmin * sumi_mins as f32;
        }
    }
    Ok(sumf)
}

#[inline(always)]
pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
    }
    let mut sumf = 0f32;
    let mut utmp = [0u32; 4];
    let mut scales = [0u8; 16];
    const KMASK1: u32 = 0x3f3f3f3f;
    const KMASK2: u32 = 0x0f0f0f0f;
    const KMASK3: u32 = 0x03030303;

    unsafe {
        let m4b = vdupq_n_u8(0xF);

        for (x, y) in xs.iter().zip(ys.iter()) {
            let d = y.d * x.d.to_f32();
            let dmin = y.d * x.dmin.to_f32();

            let q8sums = vpaddq_s16(
                vld1q_s16(y.bsums.as_ptr()),
                vld1q_s16(y.bsums.as_ptr().add(8)),
            );

            LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);

            let mins8 = vld1_u32(
                [
                    utmp[1] & KMASK1,
                    ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4),
                ]
                .as_ptr(),
            );
            utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
            utmp[0] &= KMASK1;

            let mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
            let prod = vaddq_s32(
                vmull_s16(vget_low_s16(q8sums), vget_low_s16(mins)),
                vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)),
            );
            sumf -= dmin * vaddvq_s32(prod) as f32;

            LittleEndian::write_u32_into(&utmp, &mut scales);

            let mut q4 = x.qs.as_ptr();
            let mut q8 = y.qs.as_ptr();

            let mut sumi1 = 0i32;
            let mut sumi2 = 0i32;

            for j in 0..QK_K / 64 {
                let q4bits = vld1q_u8_x2(q4);
                q4 = q4.add(32);
                let q8bytes = vld1q_s8_x2(q8);
                q8 = q8.add(32);
                let q4bytes = int8x16x2_t(
                    vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)),
                    vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)),
                );
                let p0 = vdotq_s32(q4bytes.0, q8bytes.0);
                let p1 = vdotq_s32(q4bytes.1, q8bytes.1);
                sumi1 += vaddvq_s32(vaddq_s32(p0, p1)) * scales[2 * j] as i32;

                let q8bytes = vld1q_s8_x2(q8);
                q8 = q8.add(32);
                let q4bytes = int8x16x2_t(
                    vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)),
                    vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)),
                );
                let p2 = vdotq_s32(q4bytes.0, q8bytes.0);
                let p3 = vdotq_s32(q4bytes.1, q8bytes.1);
                sumi2 += vaddvq_s32(vaddq_s32(p2, p3)) * scales[2 * j + 1] as i32;
            }
            sumf += d * (sumi1 + sumi2) as f32;
        }
    }
    Ok(sumf)
}

#[inline(always)]
pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}")
    }
    let mut sumf = 0f32;
    let mut utmp = [0u32; 4];
    let mut aux = [0u32; 3];
    const KMASK1: u32 = 0x03030303;
    const KMASK2: u32 = 0x0f0f0f0f;

    unsafe {
        let m3b = vdupq_n_u8(0x3);
        let m0 = vdupq_n_u8(1);
        let m1 = vshlq_n_u8(m0, 1);
        let m2 = vshlq_n_u8(m0, 2);
        let m3 = vshlq_n_u8(m0, 3);
        for (x, y) in xs.iter().zip(ys.iter()) {
            let d = y.d * x.d.to_f32();
            let mut q3 = x.qs.as_ptr();
            let qh = x.hmask.as_ptr();
            let mut q8 = y.qs.as_ptr();

            let mut qhbits = vld1q_u8_x2(qh);

            let mut isum = 0i32;

            // Set up scales
            LittleEndian::read_u32_into(&x.scales, &mut aux);

            utmp[3] = ((aux[1] >> 4) & KMASK2) | (((aux[2] >> 6) & KMASK1) << 4);
            utmp[2] = ((aux[0] >> 4) & KMASK2) | (((aux[2] >> 4) & KMASK1) << 4);
            utmp[1] = (aux[1] & KMASK2) | (((aux[2] >> 2) & KMASK1) << 4);
            utmp[0] = (aux[0] & KMASK2) | ((aux[2] & KMASK1) << 4);

            let mut scale = utmp.as_mut_ptr() as *mut i8;
            for j in 0..16 {
                *scale.add(j) -= 32i8
            }

            for j in 0..QK_K / 128 {
                let q3bits = vld1q_u8_x2(q3);
                q3 = q3.add(32);
                let q8bytes_1 = vld1q_s8_x4(q8);
                q8 = q8.add(64);
                let q8bytes_2 = vld1q_s8_x4(q8);
                q8 = q8.add(64);

                let q3h_0 = vshlq_n_u8(vbicq_u8(m0, qhbits.0), 2);
                let q3h_1 = vshlq_n_u8(vbicq_u8(m0, qhbits.1), 2);
                let q3h_2 = vshlq_n_u8(vbicq_u8(m1, qhbits.0), 1);
                let q3h_3 = vshlq_n_u8(vbicq_u8(m1, qhbits.1), 1);

                let q3bytes_0 = vsubq_s8(
                    vreinterpretq_s8_u8(vandq_u8(q3bits.0, m3b)),
                    vreinterpretq_s8_u8(q3h_0),
                );
                let q3bytes_1 = vsubq_s8(
                    vreinterpretq_s8_u8(vandq_u8(q3bits.1, m3b)),
                    vreinterpretq_s8_u8(q3h_1),
                );
                let q3bytes_2 = vsubq_s8(
                    vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 2), m3b)),
                    vreinterpretq_s8_u8(q3h_2),
                );
                let q3bytes_3 = vsubq_s8(
                    vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 2), m3b)),
                    vreinterpretq_s8_u8(q3h_3),
                );

                let p0 = vdotq_s32(q3bytes_0, q8bytes_1.0);
                let p1 = vdotq_s32(q3bytes_1, q8bytes_1.1);
                let p2 = vdotq_s32(q3bytes_2, q8bytes_1.2);
                let p3 = vdotq_s32(q3bytes_3, q8bytes_1.3);
                isum += vaddvq_s32(p0) * *scale as i32
                    + vaddvq_s32(p1) * *scale.add(1) as i32
                    + vaddvq_s32(p2) * *scale.add(2) as i32
                    + vaddvq_s32(p3) * *scale.add(3) as i32;
                scale = scale.add(4);

                let q3h_0 = vbicq_u8(m2, qhbits.0);
                let q3h_1 = vbicq_u8(m2, qhbits.1);
                let q3h_2 = vshrq_n_u8(vbicq_u8(m3, qhbits.0), 1);
                let q3h_3 = vshrq_n_u8(vbicq_u8(m3, qhbits.1), 1);

                let q3bytes_0 = vsubq_s8(
                    vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 4), m3b)),
                    vreinterpretq_s8_u8(q3h_0),
                );
                let q3bytes_1 = vsubq_s8(
                    vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 4), m3b)),
                    vreinterpretq_s8_u8(q3h_1),
                );
                let q3bytes_2 = vsubq_s8(
                    vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 6), m3b)),
                    vreinterpretq_s8_u8(q3h_2),
                );
                let q3bytes_3 = vsubq_s8(
                    vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 6), m3b)),
                    vreinterpretq_s8_u8(q3h_3),
                );

                let p0 = vdotq_s32(q3bytes_0, q8bytes_2.0);
                let p1 = vdotq_s32(q3bytes_1, q8bytes_2.1);
                let p2 = vdotq_s32(q3bytes_2, q8bytes_2.2);
                let p3 = vdotq_s32(q3bytes_3, q8bytes_2.3);
                isum += vaddvq_s32(p0) * *scale as i32
                    + vaddvq_s32(p1) * *scale.add(1) as i32
                    + vaddvq_s32(p2) * *scale.add(2) as i32
                    + vaddvq_s32(p3) * *scale.add(3) as i32;
                scale = scale.add(4);

                if j == 0 {
                    qhbits.0 = vshrq_n_u8(qhbits.0, 4);
                    qhbits.1 = vshrq_n_u8(qhbits.1, 4);
                }
            }
            sumf += d * isum as f32;
        }
    }
    Ok(sumf)
}

#[inline(always)]
pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
    }
    let mut sumf = 0f32;
    let mut aux = [0u8; 16];

    unsafe {
        let m3 = vdupq_n_u8(0x3);
        let m4 = vdupq_n_u8(0xF);

        for (x, y) in xs.iter().zip(ys.iter()) {
            let d = y.d * x.d.to_f32();
            let dmin = -y.d * x.dmin.to_f32();

            let mut q2 = x.qs.as_ptr();
            let mut q8 = y.qs.as_ptr();
            let sc = x.scales.as_ptr();

            let mins_and_scales = vld1q_u8(sc);
            let scales = vandq_u8(mins_and_scales, m4);
            vst1q_u8(aux.as_mut_ptr(), scales);

            let mins = vshrq_n_u8(mins_and_scales, 4);
            let q8sums = vld1q_s16_x2(y.bsums.as_ptr());
            let mins16 = int16x8x2_t(
                vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))),
                vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins))),
            );
            let s0 = vaddq_s32(
                vmull_s16(vget_low_s16(mins16.0), vget_low_s16(q8sums.0)),
                vmull_s16(vget_high_s16(mins16.0), vget_high_s16(q8sums.0)),
            );
            let s1 = vaddq_s32(
                vmull_s16(vget_low_s16(mins16.1), vget_low_s16(q8sums.1)),
                vmull_s16(vget_high_s16(mins16.1), vget_high_s16(q8sums.1)),
            );
            sumf += dmin * vaddvq_s32(vaddq_s32(s0, s1)) as f32;

            let mut isum = 0i32;
            let mut is = 0usize;

            // TODO: dotprod
            for _j in 0..QK_K / 128 {
                let q2bits = vld1q_u8_x2(q2);
                q2 = q2.add(32);

                let q8bytes = vld1q_s8_x2(q8);
                q8 = q8.add(32);
                let mut q2bytes = int8x16x2_t(
                    vreinterpretq_s8_u8(vandq_u8(q2bits.0, m3)),
                    vreinterpretq_s8_u8(vandq_u8(q2bits.1, m3)),
                );
                isum += multiply_accum_with_scale(&aux, is, 0, q2bytes, q8bytes);

                let q8bytes = vld1q_s8_x2(q8);
                q8 = q8.add(32);
                q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 2), m3));
                q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 2), m3));
                isum += multiply_accum_with_scale(&aux, is, 2, q2bytes, q8bytes);

                let q8bytes = vld1q_s8_x2(q8);
                q8 = q8.add(32);
                q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 4), m3));
                q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 4), m3));
                isum += multiply_accum_with_scale(&aux, is, 4, q2bytes, q8bytes);

                let q8bytes = vld1q_s8_x2(q8);
                q8 = q8.add(32);
                q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 6), m3));
                q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 6), m3));
                isum += multiply_accum_with_scale(&aux, is, 6, q2bytes, q8bytes);

                is += 8;
            }
            sumf += d * isum as f32;
        }
    }
    Ok(sumf)
}

#[inline(always)]
unsafe fn multiply_accum_with_scale(
    aux: &[u8; 16],
    is: usize,
    index: usize,
    q2bytes: int8x16x2_t,
    q8bytes: int8x16x2_t,
) -> i32 {
    let p1 = vdotq_s32(q2bytes.0, q8bytes.0);
    let p2 = vdotq_s32(q2bytes.1, q8bytes.1);
    vaddvq_s32(p1) * aux[is + index] as i32 + vaddvq_s32(p2) * aux[is + 1 + index] as i32
}
@@ -1,419 +0,0 @@
use super::k_quants::{BlockQ2K, BlockQ4K, BlockQ4_0, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K};
use crate::Result;
use byteorder::{ByteOrder, LittleEndian};
use half::f16;

use core::arch::wasm32::*;

#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
    let qk = QK8_0;
    if n % QK8_0 != 0 {
        crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
    }
    unsafe {
        let mut acc = f32x4_splat(0.0f32);
        for (x, y) in xs.iter().zip(ys.iter()) {
            let x1234 = v128_load(x.qs.as_ptr() as *const v128);
            let x12 = v128_and(x1234, u8x16_splat(0x0F));
            let x12 = i8x16_sub(x12, i8x16_splat(8));
            let x34 = u8x16_shr(x1234, 4);
            let x34 = i8x16_sub(x34, i8x16_splat(8));

            let x1 = i16x8_extend_low_i8x16(x12);
            let y1 = i16x8_load_extend_i8x8(y.qs.as_ptr());
            let sum_xy = i32x4_dot_i16x8(x1, y1);

            let x2 = i16x8_extend_high_i8x16(x12);
            let y2 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(8));
            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x2, y2));

            let x3 = i16x8_extend_low_i8x16(x34);
            let y3 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(16));
            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x3, y3));

            let x4 = i16x8_extend_high_i8x16(x34);
            let y4 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(24));
            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x4, y4));

            let sum_xy = f32x4_convert_i32x4(sum_xy);

            // f32x4_relaxed_madd is nightly only.
            let d = f32x4_splat(f16::to_f32(x.d) * f16::to_f32(y.d));
            let scaled = f32x4_mul(sum_xy, d);
            acc = f32x4_add(acc, scaled)
        }
        let res = f32x4_extract_lane::<0>(acc)
            + f32x4_extract_lane::<1>(acc)
            + f32x4_extract_lane::<2>(acc)
            + f32x4_extract_lane::<3>(acc);
        Ok(res)
    }
}

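The four extract_lane adds recur at the end of every kernel in this file; a hedged helper that could factor them out (wasm32 target only; illustration, not in the diff):

use core::arch::wasm32::*;

// Horizontal sum of the four f32 lanes of a v128 accumulator.
#[inline(always)]
fn f32x4_hsum(v: v128) -> f32 {
    f32x4_extract_lane::<0>(v)
        + f32x4_extract_lane::<1>(v)
        + f32x4_extract_lane::<2>(v)
        + f32x4_extract_lane::<3>(v)
}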
#[inline(always)]
pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
    let qk = QK8_0;
    if n % QK8_0 != 0 {
        crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
    }
    unsafe {
        let mut acc = f32x4_splat(0.0f32);
        for (x, y) in xs.iter().zip(ys.iter()) {
            let x1 = i16x8_load_extend_i8x8(x.qs.as_ptr());
            let y1 = i16x8_load_extend_i8x8(y.qs.as_ptr());
            let sum_xy = i32x4_dot_i16x8(x1, y1);

            let x2 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(8));
            let y2 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(8));
            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x2, y2));

            let x3 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(16));
            let y3 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(16));
            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x3, y3));

            let x4 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(24));
            let y4 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(24));
            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x4, y4));

            let sum_xy = f32x4_convert_i32x4(sum_xy);

            // f32x4_relaxed_madd is nightly only.
            let d = f32x4_splat(f16::to_f32(x.d) * f16::to_f32(y.d));
            let scaled = f32x4_mul(sum_xy, d);
            acc = f32x4_add(acc, scaled)
        }
        let res = f32x4_extract_lane::<0>(acc)
            + f32x4_extract_lane::<1>(acc)
            + f32x4_extract_lane::<2>(acc)
            + f32x4_extract_lane::<3>(acc);
        Ok(res)
    }
}

#[inline(always)]
pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
    }
    unsafe {
        let mut sumf = f32x4_splat(0f32);
        for (x, y) in xs.iter().zip(ys.iter()) {
            let mut q2: &[_] = &x.qs;
            let mut q8: &[_] = &y.qs;
            let sc = &x.scales;

            let mut summs = i32x4_splat(0);
            for i in (0..(QK_K / 16)).step_by(4) {
                let bsums = i32x4_load_extend_i16x4(y.bsums.as_ptr().add(i));
                let scales = i32x4_shr(
                    i32x4(
                        sc[i] as i32,
                        sc[i + 1] as i32,
                        sc[i + 2] as i32,
                        sc[i + 3] as i32,
                    ),
                    4,
                );
                summs = i32x4_add(summs, i32x4_mul(bsums, scales))
            }
            let summs = f32x4_convert_i32x4(summs);

            let dall = y.d * x.d.to_f32();
            let dmin = y.d * x.dmin.to_f32();

            let mut isum = i32x4_splat(0);
            let mut is = 0;
            for _ in 0..(QK_K / 128) {
                let mut shift = 0;
                for _ in 0..4 {
                    let d = (sc[is] & 0xF) as i32;
                    is += 1;
                    let mut isuml = i16x8_splat(0);
                    for l in (0..16).step_by(8) {
                        let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(l));
                        let q2 = i16x8_load_extend_u8x8(q2.as_ptr().add(l));
                        let q2 = v128_and(i16x8_shr(q2, shift), i16x8_splat(3));
                        isuml = i16x8_add(isuml, i16x8_mul(q2, q8))
                    }
                    let dd = i32x4_splat(d);
                    isum = i32x4_add(isum, i32x4_mul(i32x4_extend_low_i16x8(isuml), dd));
                    isum = i32x4_add(isum, i32x4_mul(i32x4_extend_high_i16x8(isuml), dd));
                    let d = (sc[is] & 0xF) as i32;
                    is += 1;
                    let mut isuml = i16x8_splat(0);
                    for l in (16..32).step_by(8) {
                        let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(l));
                        let q2 = i16x8_load_extend_u8x8(q2.as_ptr().add(l));
                        let q2 = v128_and(i16x8_shr(q2, shift), i16x8_splat(3));
                        isuml = i16x8_add(isuml, i16x8_mul(q2, q8))
                    }
                    let dd = i32x4_splat(d);
                    isum = i32x4_add(isum, i32x4_mul(i32x4_extend_low_i16x8(isuml), dd));
                    isum = i32x4_add(isum, i32x4_mul(i32x4_extend_high_i16x8(isuml), dd));
                    shift += 2;
                    // adjust the indexing
                    q8 = &q8[32..];
                }
                // adjust the indexing
                q2 = &q2[32..];
            }
            let isum = f32x4_convert_i32x4(isum);
            sumf = f32x4_add(
                sumf,
                f32x4_sub(
                    f32x4_mul(isum, f32x4_splat(dall)),
                    f32x4_mul(summs, f32x4_splat(dmin)),
                ),
            );
        }
        let sumf = f32x4_extract_lane::<0>(sumf)
            + f32x4_extract_lane::<1>(sumf)
            + f32x4_extract_lane::<2>(sumf)
            + f32x4_extract_lane::<3>(sumf);
        Ok(sumf)
    }
}

#[inline(always)]
pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
    }

    const KMASK1: u32 = 0x3f3f3f3f;
    const KMASK2: u32 = 0x0f0f0f0f;
    const KMASK3: u32 = 0x03030303;

    let mut utmp: [u32; 4] = [0; 4];
    let mut scales: [u8; 8] = [0; 8];
    let mut mins: [u8; 8] = [0; 8];

    let mut aux8: [u8; QK_K] = [0; QK_K];
    let mut sums = f32x4_splat(0f32);
    unsafe {
        for (y, x) in ys.iter().zip(xs.iter()) {
            let q4 = &x.qs;
            let q8 = &y.qs;

            for j in 0..QK_K / 64 {
                let q4_1 = v128_load(q4.as_ptr().add(32 * j) as *const v128);
                let q4_2 = v128_load(q4.as_ptr().add(32 * j + 16) as *const v128);
                v128_store(
                    aux8.as_mut_ptr().add(64 * j) as *mut v128,
                    v128_and(q4_1, u8x16_splat(0x0F)),
                );
                v128_store(
                    aux8.as_mut_ptr().add(64 * j + 16) as *mut v128,
                    v128_and(q4_2, u8x16_splat(0x0F)),
                );
                v128_store(
                    aux8.as_mut_ptr().add(64 * j + 32) as *mut v128,
                    u8x16_shr(q4_1, 4),
                );
                v128_store(
                    aux8.as_mut_ptr().add(64 * j + 48) as *mut v128,
                    u8x16_shr(q4_2, 4),
                );
            }

            LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);

            utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
            let uaux = utmp[1] & KMASK1;
            utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
            utmp[2] = uaux;
            utmp[0] &= KMASK1;

            // extract scales and mins
            LittleEndian::write_u32_into(&utmp[0..2], &mut scales);
            LittleEndian::write_u32_into(&utmp[2..4], &mut mins);

            let mut sumi = i32x4_splat(0);
            for j in (0..QK_K / 16).step_by(4) {
                let bsums = i32x4_load_extend_i16x4(y.bsums.as_ptr().add(j));
                let (m1, m2) = (mins[j / 2] as i32, mins[j / 2 + 1] as i32);
                let mins = i32x4(m1, m1, m2, m2);
                sumi = i32x4_add(sumi, i32x4_mul(bsums, mins));
            }

            let mut aux32 = i32x4_splat(0i32);
            for (scale_i, scale) in scales.iter().enumerate() {
                let scale = i32x4_splat(*scale as i32);
                for j in 0..4 {
                    let i = 32 * scale_i + 8 * j;
                    let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(i));
                    let aux8 = i16x8_load_extend_u8x8(aux8.as_ptr().add(i));
                    let aux16 = i16x8_mul(q8, aux8);
                    aux32 = i32x4_add(aux32, i32x4_mul(scale, i32x4_extend_low_i16x8(aux16)));
                    aux32 = i32x4_add(aux32, i32x4_mul(scale, i32x4_extend_high_i16x8(aux16)));
                }
            }
            let aux32 = f32x4_convert_i32x4(aux32);
            let d = f32x4_splat(x.d.to_f32() * y.d);
            sums = f32x4_add(sums, f32x4_mul(aux32, d));
            let dmin = x.dmin.to_f32() * y.d;
            let dmin = f32x4_splat(dmin);
            let sumi = f32x4_convert_i32x4(sumi);
            sums = f32x4_sub(sums, f32x4_mul(sumi, dmin));
        }
        let sums = f32x4_extract_lane::<0>(sums)
            + f32x4_extract_lane::<1>(sums)
            + f32x4_extract_lane::<2>(sums)
            + f32x4_extract_lane::<3>(sums);
        Ok(sums)
    }
}

#[inline(always)]
pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}")
    }

    let mut aux8 = [0i8; QK_K];
    unsafe {
        let mut sums = f32x4_splat(0f32);

        for (x, y) in xs.iter().zip(ys.iter()) {
            let q4 = &x.ql;
            let qh = &x.qh;
            let q8 = &y.qs;
            let mut aux32 = f32x4_splat(0f32);

            for j in (0..QK_K).step_by(128) {
                let aux8 = aux8.as_mut_ptr().add(j);
                let q4 = &q4.as_ptr().add(j / 2);
                let qh = &qh.as_ptr().add(j / 4);
                for l in (0..32).step_by(16) {
                    // aux8[l] = (((q4[l] & 0xF) | ((qh[l] & 3) << 4)) as i32 - 32) as i8;
                    let a8 = v128_or(
                        v128_and(v128_load(q4.add(l) as *const v128), u8x16_splat(0xF)),
                        u8x16_shl(
                            v128_and(v128_load(qh.add(l) as *const v128), u8x16_splat(3)),
                            4,
                        ),
                    );
                    let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
                    let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
                    v128_store(
                        aux8.add(l) as *mut v128,
                        i8x16_narrow_i16x8(a8_low, a8_high),
                    );

                    // aux8[l + 32] =
                    //     (((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i32 - 32) as i8;
                    let a8 = v128_or(
                        v128_and(v128_load(q4.add(l + 32) as *const v128), u8x16_splat(0xF)),
                        u8x16_shl(
                            v128_and(
                                u8x16_shr(v128_load(qh.add(l) as *const v128), 2),
                                u8x16_splat(3),
                            ),
                            4,
                        ),
                    );
                    let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
                    let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
                    v128_store(
                        aux8.add(l + 32) as *mut v128,
                        i8x16_narrow_i16x8(a8_low, a8_high),
                    );

                    // aux8[l + 64] = (((q4[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i32 - 32) as i8;
                    let a8 = v128_or(
                        u8x16_shr(v128_load(q4.add(l) as *const v128), 4),
                        u8x16_shl(
                            v128_and(
                                u8x16_shr(v128_load(qh.add(l) as *const v128), 4),
                                u8x16_splat(3),
                            ),
                            4,
                        ),
                    );
                    let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
                    let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
                    v128_store(
                        aux8.add(l + 64) as *mut v128,
                        i8x16_narrow_i16x8(a8_low, a8_high),
                    );

                    // aux8[l + 96] =
                    //     (((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i32 - 32) as i8;
                    let a8 = v128_or(
                        u8x16_shr(v128_load(q4.add(l + 32) as *const v128), 4),
                        u8x16_shl(
                            v128_and(
                                u8x16_shr(v128_load(qh.add(l) as *const v128), 6),
                                u8x16_splat(3),
                            ),
                            4,
                        ),
                    );
                    let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
                    let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
                    v128_store(
                        aux8.add(l + 96) as *mut v128,
                        i8x16_narrow_i16x8(a8_low, a8_high),
                    );
                }
            }

            for (j, &scale) in x.scales.iter().enumerate() {
                let scale = f32x4_splat(scale as f32);
                for offset in [0, 8] {
                    let aux16 = i16x8_mul(
                        i16x8_load_extend_i8x8(q8.as_ptr().add(16 * j + offset)),
                        i16x8_load_extend_i8x8(aux8.as_ptr().add(16 * j + offset)),
                    );
                    aux32 = f32x4_add(
                        aux32,
                        f32x4_mul(f32x4_convert_i32x4(i32x4_extend_low_i16x8(aux16)), scale),
                    );
                    aux32 = f32x4_add(
                        aux32,
                        f32x4_mul(f32x4_convert_i32x4(i32x4_extend_high_i16x8(aux16)), scale),
                    );
                }
            }

            let d = f32x4_splat(x.d.to_f32() * y.d);
            sums = f32x4_add(sums, f32x4_mul(aux32, d));
        }
        let sums = f32x4_extract_lane::<0>(sums)
            + f32x4_extract_lane::<1>(sums)
            + f32x4_extract_lane::<2>(sums)
            + f32x4_extract_lane::<3>(sums);
        Ok(sums)
    }
}

#[inline(always)]
pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
    let qk = QK_K;
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
    }

    unsafe {
        let mut acc = f32x4_splat(0.0f32);
        for (xs, ys) in xs.iter().zip(ys.iter()) {
            let x_qs = xs.qs.as_ptr();
            let y_qs = ys.qs.as_ptr();
            let mut sumi = i32x4_splat(0);
            for j in (0..QK_K).step_by(8) {
                let xs = i16x8_load_extend_i8x8(x_qs.add(j));
                let ys = i16x8_load_extend_i8x8(y_qs.add(j));
                let sum_xy = i32x4_dot_i16x8(xs, ys);
                sumi = i32x4_add(sumi, sum_xy)
            }
            let d = f32x4_splat(xs.d * ys.d);
            acc = f32x4_add(acc, f32x4_mul(f32x4_convert_i32x4(sumi), d))
        }
        let res = f32x4_extract_lane::<0>(acc)
            + f32x4_extract_lane::<1>(acc)
            + f32x4_extract_lane::<2>(acc)
            + f32x4_extract_lane::<3>(acc);
        Ok(res)
    }
}
@@ -1,326 +0,0 @@
use crate::Result;

pub(super) fn nearest_int(v: f32) -> i32 {
    v.round() as i32
}

/// Validates that the input and output are the right size and returns an iterator which maps each
/// input region `xs` to its corresponding output block in `ys`. Each output region is guaranteed
/// to be `T::BLCK_SIZE` long.
pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>(
    xs: &'b [f32],
    ys: &'a mut [T],
) -> Result<Vec<(&'a mut T, &'b [f32])>> {
    let block_size = T::BLCK_SIZE;
    let dtype = T::DTYPE;

    let expected_blocks = xs.len() / block_size;
    let actual_blocks = ys.len();

    // Validate that the input is the right size
    if expected_blocks != actual_blocks {
        crate::bail!("quantize {dtype:?}: expected {expected_blocks} blocks but only {actual_blocks} were provided!")
    }

    Ok(ys.iter_mut().zip(xs.chunks_exact(block_size)).collect())
}

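A hedged picture of what the grouping buys callers: every block is paired with exactly one BLCK_SIZE-wide f32 chunk, so per-block quantizers never see ragged input. Modeled with plain slices:

fn main() {
    let xs = vec![0f32; 96];
    let block_size = 32; // stand-in for T::BLCK_SIZE
    let chunks: Vec<&[f32]> = xs.chunks_exact(block_size).collect();
    assert_eq!(chunks.len(), 96 / block_size); // one chunk per output block
}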
/// Validates that the input and output are the right size and returns an iterator which maps each
/// input block `xs` to its corresponding output region in `ys`. Each output region is guaranteed
/// to be `T::BLCK_SIZE` long.
pub(super) fn group_for_dequantization<'a, 'b, T: super::k_quants::GgmlType>(
    xs: &'a [T],
    ys: &'b mut [f32],
) -> Result<Vec<(&'a T, &'b mut [f32])>> {
    let block_size = T::BLCK_SIZE;
    let dtype = T::DTYPE;

    let actual_output_len = ys.len();
    let expected_output_len = xs.len() * block_size;
    // Validate that the output is the right size
    if expected_output_len != actual_output_len {
        crate::bail!("dequantize {dtype:?}: ys (len = {actual_output_len}) does not match the expected length of {expected_output_len}!")
    }

    // Zip the blocks and outputs together
    Ok(xs.iter().zip(ys.chunks_exact_mut(block_size)).collect())
}

pub(super) fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) {
    if j < 4 {
        let d = q[j] & 63;
        let m = q[j + 4] & 63;
        (d, m)
    } else {
        let d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
        let m = (q[j + 4] >> 4) | ((q[j] >> 6) << 4);
        (d, m)
    }
}

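To make the packing concrete: K-quants store eight 6-bit (scale, min) pairs in 12 bytes; for j < 4 both values sit in the low 6 bits of bytes j and j + 4, while pairs 4..7 are reassembled from nibbles of bytes 8..11 plus the spare top bits. A hedged check of the easy half:

fn main() {
    let mut q = [0u8; 12];
    q[1] = 0b1011_0101; // scale for j = 1: low 6 bits only
    q[5] = 0b0100_0111; // min for j = 1: low 6 bits only
    let (d, m) = (q[1] & 63, q[5] & 63);
    assert_eq!((d, m), (0b11_0101, 0b00_0111));
}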
pub(super) unsafe fn make_qx_quants(
    n: usize,
    nmax: i32,
    x: *const f32,
    ls: *mut i8,
    rmse_type: i32,
) -> f32 {
    let mut max = 0f32;
    let mut amax = 0f32;
    for i in 0..n {
        let x = *x.add(i);
        let ax = x.abs();
        if ax > amax {
            amax = ax;
            max = x;
        }
    }
    if amax == 0. {
        // all zero
        for i in 0..n {
            *ls.add(i) = 0;
        }
        return 0.;
    }
    let mut iscale = -(nmax as f32) / max;
    if rmse_type == 0 {
        for i in 0..n {
            let x = *x.add(i);
            let l = nearest_int(iscale * x);
            *ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8;
        }
        return 1.0 / iscale;
    }
    let weight_type = rmse_type % 2;
    let mut sumlx = 0f32;
    let mut suml2 = 0f32;
    for i in 0..n {
        let x = *x.add(i);
        let l = nearest_int(iscale * x);
        let l = l.clamp(-nmax, nmax - 1);
        *ls.add(i) = (l + nmax) as i8;
        let w = if weight_type == 1 { x * x } else { 1.0 };
        let l = l as f32;
        sumlx += w * x * l;
        suml2 += w * l * l;
    }
    let mut scale = sumlx / suml2;
    let mut best = scale * sumlx;
    for _itry in 0..3 {
        let iscale = 1.0 / scale;
        let mut slx = 0f32;
        let mut sl2 = 0f32;
        let mut changed = false;
        for i in 0..n {
            let x = *x.add(i);
            let l = nearest_int(iscale * x);
            let l = l.clamp(-nmax, nmax - 1);
            if l + nmax != *ls.add(i) as i32 {
                changed = true;
            }
            let w = if weight_type == 1 { x * x } else { 1f32 };
            let l = l as f32;
            slx += w * x * l;
            sl2 += w * l * l;
        }
        if !changed || sl2 == 0.0 || slx * slx <= best * sl2 {
            break;
        }
        for i in 0..n {
            let x = *x.add(i);
            let l = nearest_int(iscale * x);
            *ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8;
        }
        sumlx = slx;
        suml2 = sl2;
        scale = sumlx / suml2;
        best = scale * sumlx;
    }
    for _itry in 0..5 {
        let mut n_changed = 0;
        for i in 0..n {
            let x = *x.add(i);
            let w = if weight_type == 1 { x * x } else { 1. };
            let l = *ls.add(i) as i32 - nmax;
            let mut slx = sumlx - w * x * l as f32;
            if slx > 0. {
                let mut sl2 = suml2 - w * l as f32 * l as f32;
                let new_l = nearest_int(x * sl2 / slx);
                let new_l = new_l.clamp(-nmax, nmax - 1);
                if new_l != l {
                    slx += w * x * new_l as f32;
                    sl2 += w * new_l as f32 * new_l as f32;
                    if sl2 > 0. && slx * slx * suml2 > sumlx * sumlx * sl2 {
                        *ls.add(i) = (nmax + new_l) as i8;
                        sumlx = slx;
                        suml2 = sl2;
                        scale = sumlx / suml2;
                        best = scale * sumlx;
                        n_changed += 1;
                    }
                }
            }
        }
        if n_changed == 0 {
            break;
        }
    }
    if rmse_type < 3 {
        return scale;
    }
    for is in -4..4 {
        if is == 0 {
            continue;
        }
        iscale = -(nmax as f32 + 0.1f32 * is as f32) / max;
        let mut sumlx = 0.;
        let mut suml2 = 0.;
        for i in 0..n {
            let x = *x.add(i);
            let l = nearest_int(iscale * x);
            let l = l.clamp(-nmax, nmax - 1);
            let w = if weight_type == 1 { x * x } else { 1. };
            let l = l as f32;
            sumlx += w * x * l;
            suml2 += w * l * l;
        }
        if suml2 > 0. && sumlx * sumlx > best * suml2 {
            for i in 0..n {
                let x = *x.add(i);
                let l = nearest_int(iscale * x);
                *ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8;
            }
            scale = sumlx / suml2;
            best = scale * sumlx;
        }
    }
    scale
}

// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L224
pub(super) fn make_qkx1_quants(nmax: i32, ntry: usize, x: &[f32]) -> (f32, f32) {
    let n = x.len();
    let mut l = vec![0; n];
    // Get min/max
    let min = *x
        .iter()
        .take(n)
        .min_by(|a, b| a.total_cmp(b))
        .unwrap_or(&x[0]);
    let max = *x.iter().max_by(|a, b| a.total_cmp(b)).unwrap_or(&x[0]);

    // If min == max, all values are the same => nothing to do here
    if max == min {
        return (0.0, 0.0);
    }

    // Ensure min <= 0.0
    let mut min = min.min(0.);

    // Compute scale and inverse scale
    let mut iscale = nmax as f32 / (max - min);
    let mut scale = 1.0 / iscale;

    for _ in 0..ntry {
        let mut sumlx = 0.0;
        let mut suml2 = 0;
        let mut did_change = false;

        for (i, value) in x.iter().enumerate().take(n) {
            let li = nearest_int(iscale * (value - min)).clamp(0, nmax);
            let clamped_li = li as u8;
            if clamped_li != l[i] {
                l[i] = clamped_li;
                did_change = true;
            }
            sumlx += (value - min) * li as f32;
            suml2 += li * li;
        }
        scale = sumlx / suml2 as f32;

        let sum: f32 = x
            .iter()
            .take(n)
            .zip(l.iter().take(n))
            .map(|(xi, &li)| xi - scale * li as f32)
            .sum();

        min = sum / n as f32;
        if min > 0.0 {
            min = 0.0;
        }
        iscale = 1.0 / scale;
        if !did_change {
            break;
        }
    }
    (scale, -min)
}

// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L165
pub(super) fn make_q3_quants(x: &[f32], nmax: i32, do_rmse: bool) -> f32 {
    let n = x.len();
    let mut l = vec![0i8; n];

    let mut max = 0.0;
    let mut amax = 0.0;
    for &xi in x.iter().take(n) {
        let ax = xi.abs();
        if ax > amax {
            amax = ax;
            max = xi;
        }
    }

    if amax == 0.0 {
        return 0.0;
    }

    let iscale = -(nmax as f32) / max;
    if do_rmse {
        let mut sumlx = 0.0;
        let mut suml2 = 0.0;
        for i in 0..n {
            let li = (iscale * x[i]).round() as i32;
            let li = li.clamp(-nmax, nmax - 1);
            l[i] = li as i8;
            let w = x[i] * x[i];
            sumlx += w * x[i] * li as f32;
            suml2 += w * (li * li) as f32;
        }
        for _ in 0..5 {
            let mut n_changed = 0;
            for i in 0..n {
                let w = x[i] * x[i];
                let mut slx = sumlx - w * x[i] * l[i] as f32;
                if slx > 0.0 {
                    let mut sl2 = suml2 - w * (l[i] as i32 * l[i] as i32) as f32;
                    let mut new_l = (x[i] * sl2 / slx).round() as i32;
                    new_l = new_l.clamp(-nmax, nmax - 1);
                    if new_l != l[i] as i32 {
                        slx += w * x[i] * new_l as f32;
                        sl2 += w * (new_l * new_l) as f32;
                        if sl2 > 0.0 && slx * slx * suml2 > sumlx * sumlx * sl2 {
                            l[i] = new_l as i8;
                            sumlx = slx;
                            suml2 = sl2;
                            n_changed += 1;
                        }
                    }
                }
            }
            if n_changed == 0 {
                break;
            }
        }
        for li in l.iter_mut() {
            *li += nmax as i8;
        }
        return sumlx / suml2;
    }
    for i in 0..n {
        let li = (iscale * x[i]).round() as i32;
        l[i] = (li.clamp(-nmax, nmax - 1) + nmax) as i8;
    }
    1.0 / iscale
}
@@ -10,7 +10,6 @@ impl From<DType> for st::Dtype {
        match value {
            DType::U8 => st::Dtype::U8,
            DType::U32 => st::Dtype::U32,
            DType::I64 => st::Dtype::I64,
            DType::BF16 => st::Dtype::BF16,
            DType::F16 => st::Dtype::F16,
            DType::F32 => st::Dtype::F32,
@@ -25,7 +24,6 @@ impl TryFrom<st::Dtype> for DType {
        match value {
            st::Dtype::U8 => Ok(DType::U8),
            st::Dtype::U32 => Ok(DType::U32),
            st::Dtype::I64 => Ok(DType::I64),
            st::Dtype::BF16 => Ok(DType::BF16),
            st::Dtype::F16 => Ok(DType::F16),
            st::Dtype::F32 => Ok(DType::F32),
@@ -78,7 +76,11 @@ impl st::View for &Tensor {
}

impl Tensor {
    pub fn save_safetensors<P: AsRef<Path>>(&self, name: &str, filename: P) -> Result<()> {
    pub fn save_safetensors<P: AsRef<std::path::Path>>(
        &self,
        name: &str,
        filename: P,
    ) -> Result<()> {
        let data = [(name, self.clone())];
        Ok(st::serialize_to_file(data, &None, filename.as_ref())?)
    }
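A hedged usage sketch of the widened save_safetensors signature (assumes candle_core as the crate name):

use candle_core::{DType, Device, Tensor};

fn save_one() -> candle_core::Result<()> {
    let t = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    // Writes a single-entry safetensors file keyed by "weight".
    t.save_safetensors("weight", "weight.safetensors")?;
    Ok(())
}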
@@ -187,7 +189,6 @@ impl Tensor {
        match dtype {
            DType::U8 => convert_slice::<u8>(data, shape, device),
            DType::U32 => convert_slice::<u32>(data, shape, device),
            DType::I64 => convert_slice::<i64>(data, shape, device),
            DType::BF16 => convert_slice::<half::bf16>(data, shape, device),
            DType::F16 => convert_slice::<half::f16>(data, shape, device),
            DType::F32 => convert_slice::<f32>(data, shape, device),
@@ -204,15 +205,24 @@ fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
            convert_with_cast_::<u16, u32, _>(view, device, conv)
        }
        st::Dtype::U32 => convert_::<u32>(view, device),
        st::Dtype::I32 => {
            let conv = |x| Ok(i64::from(x));
            convert_with_cast_::<i32, i64, _>(view, device, conv)
        }
        st::Dtype::I64 => convert_::<i64>(view, device),
        st::Dtype::BF16 => convert_::<half::bf16>(view, device),
        st::Dtype::F16 => convert_::<half::f16>(view, device),
        st::Dtype::F32 => convert_::<f32>(view, device),
        st::Dtype::F64 => convert_::<f64>(view, device),
        st::Dtype::I32 => {
            let conv = |x| {
                u32::try_from(x)
                    .map_err(|_| Error::Msg(format!("out of bounds value for u32: {x}")))
            };
            convert_with_cast_::<i32, u32, _>(view, device, conv)
        }
        st::Dtype::I64 => {
            let conv = |x| {
                u32::try_from(x)
                    .map_err(|_| Error::Msg(format!("out of bounds value for u32: {x}")))
            };
            convert_with_cast_::<i64, u32, _>(view, device, conv)
        }
        dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
    }
}
@ -223,7 +233,6 @@ fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
|
||||
match tensor.dtype() {
|
||||
DType::U8 => Ok(convert_back_::<u8>(tensor.to_vec1()?)),
|
||||
DType::U32 => Ok(convert_back_::<u32>(tensor.to_vec1()?)),
|
||||
DType::I64 => Ok(convert_back_::<i64>(tensor.to_vec1()?)),
|
||||
DType::F16 => Ok(convert_back_::<half::f16>(tensor.to_vec1()?)),
|
||||
DType::BF16 => Ok(convert_back_::<half::bf16>(tensor.to_vec1()?)),
|
||||
DType::F32 => Ok(convert_back_::<f32>(tensor.to_vec1()?)),
|
||||
@ -251,134 +260,6 @@ pub fn save<K: AsRef<str> + Ord + std::fmt::Display, P: AsRef<Path>>(
|
||||
Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?)
|
||||
}
|
||||
|
||||
#[derive(yoke::Yokeable)]
struct SafeTensors_<'a>(SafeTensors<'a>);

pub struct MmapedSafetensors {
    safetensors: Vec<yoke::Yoke<SafeTensors_<'static>, memmap2::Mmap>>,
    routing: Option<HashMap<String, usize>>,
}

impl MmapedSafetensors {
    /// Creates a wrapper around a memory mapped file and deserializes the safetensors header.
    ///
    /// # Safety
    ///
    /// The unsafe is inherited from [`memmap2::MmapOptions`].
    pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
        let p = p.as_ref();
        let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
        let file = memmap2::MmapOptions::new()
            .map(&file)
            .map_err(|e| Error::from(e).with_path(p))?;
        let safetensors = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
            file,
            |data: &[u8]| {
                let st = safetensors::SafeTensors::deserialize(data)
                    .map_err(|e| Error::from(e).with_path(p))?;
                Ok::<_, Error>(SafeTensors_(st))
            },
        )?;
        Ok(Self {
            safetensors: vec![safetensors],
            routing: None,
        })
    }

    /// Creates a wrapper around multiple memory mapped files and deserializes the safetensors headers.
    ///
    /// If a tensor name appears in multiple files, the last entry is returned.
    ///
    /// # Safety
    ///
    /// The unsafe is inherited from [`memmap2::MmapOptions`].
    pub unsafe fn multi<P: AsRef<Path>>(paths: &[P]) -> Result<Self> {
        let mut routing = HashMap::new();
        let mut safetensors = vec![];
        for (index, p) in paths.iter().enumerate() {
            let p = p.as_ref();
            let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
            let file = memmap2::MmapOptions::new()
                .map(&file)
                .map_err(|e| Error::from(e).with_path(p))?;
            let data = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
                file,
                |data: &[u8]| {
                    let st = safetensors::SafeTensors::deserialize(data)
                        .map_err(|e| Error::from(e).with_path(p))?;
                    Ok::<_, Error>(SafeTensors_(st))
                },
            )?;
            for k in data.get().0.names() {
                routing.insert(k.to_string(), index);
            }
            safetensors.push(data)
        }
        Ok(Self {
            safetensors,
            routing: Some(routing),
        })
    }

    pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
        self.get(name)?.load(dev)
    }

    pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
        let mut tensors = vec![];
        for safetensors in self.safetensors.iter() {
            tensors.push(safetensors.get().0.tensors())
        }
        tensors.into_iter().flatten().collect()
    }

    pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
        let index = match &self.routing {
            None => 0,
            Some(routing) => {
                let index = routing.get(name).ok_or_else(|| {
                    Error::CannotFindTensor {
                        path: name.to_string(),
                    }
                    .bt()
                })?;
                *index
            }
        };
        Ok(self.safetensors[index].get().0.tensor(name)?)
    }
}

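A minimal usage sketch for the mmap-backed loader above; the file path and tensor name are placeholders, and `new` is unsafe because the mapped file must not be mutated while the mapping is alive:

// Sketch: open a safetensors file and load one tensor onto the CPU.
let st = unsafe { MmapedSafetensors::new("model.safetensors")? };
let t = st.load("embeddings.weight", &Device::Cpu)?;
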
pub struct BufferedSafetensors {
    safetensors: yoke::Yoke<SafeTensors_<'static>, Vec<u8>>,
}

impl BufferedSafetensors {
    /// Creates a wrapper around a binary buffer and deserializes the safetensors header.
    pub fn new(buffer: Vec<u8>) -> Result<Self> {
        let safetensors = yoke::Yoke::<SafeTensors_<'static>, Vec<u8>>::try_attach_to_cart(
            buffer,
            |data: &[u8]| {
                let st = safetensors::SafeTensors::deserialize(data)?;
                Ok::<_, Error>(SafeTensors_(st))
            },
        )?;
        Ok(Self { safetensors })
    }

    pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
        self.get(name)?.load(dev)
    }

    pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
        self.safetensors.get().0.tensors()
    }

    pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
        Ok(self.safetensors.get().0.tensor(name)?)
    }
}

pub struct MmapedFile {
    path: std::path::PathBuf,
    inner: memmap2::Mmap,
@ -391,7 +272,7 @@ impl MmapedFile {
    /// # Safety
    ///
    /// The unsafe is inherited from [`memmap2::MmapOptions`].
    pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
    pub unsafe fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
        let p = p.as_ref();
        let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
        let inner = memmap2::MmapOptions::new()

@ -1,23 +0,0 @@
use crate::{Result, Tensor, WithDType};

pub enum TensorScalar {
    Tensor(Tensor),
    Scalar(Tensor),
}

pub trait TensorOrScalar {
    fn to_tensor_scalar(self) -> Result<TensorScalar>;
}

impl TensorOrScalar for &Tensor {
    fn to_tensor_scalar(self) -> Result<TensorScalar> {
        Ok(TensorScalar::Tensor(self.clone()))
    }
}

impl<T: WithDType> TensorOrScalar for T {
    fn to_tensor_scalar(self) -> Result<TensorScalar> {
        let scalar = Tensor::new(self, &crate::Device::Cpu)?;
        Ok(TensorScalar::Scalar(scalar))
    }
}
@ -1,5 +1,3 @@
//! The shape of a tensor is a tuple with the size of each of its dimensions.
#![allow(clippy::redundant_closure_call)]
use crate::{Error, Result};

#[derive(Clone, PartialEq, Eq)]
@ -73,14 +71,6 @@ impl From<(usize, usize, usize, usize, usize)> for Shape {
    }
}

impl From<(usize, usize, usize, usize, usize, usize)> for Shape {
    fn from(d123456: (usize, usize, usize, usize, usize, usize)) -> Self {
        Self(vec![
            d123456.0, d123456.1, d123456.2, d123456.3, d123456.4, d123456.5,
        ])
    }
}

impl From<Vec<usize>> for Shape {
    fn from(dims: Vec<usize>) -> Self {
        Self(dims)
@ -128,7 +118,6 @@ impl Shape {
        Self(dims.to_vec())
    }

    /// The rank is the number of dimensions, 0 for a scalar value, 1 for a vector, etc.
    pub fn rank(&self) -> usize {
        self.0.len()
    }
@ -137,12 +126,10 @@
        self.0
    }

    /// The dimensions as a slice of `usize`.
    pub fn dims(&self) -> &[usize] {
        &self.0
    }

    /// The total number of elements, this is the product of all dimension sizes.
    pub fn elem_count(&self) -> usize {
        self.0.iter().product()
    }
@ -194,75 +181,10 @@ impl Shape {
        true
    }

    /// Modifies the shape by adding a list of additional dimensions at the end of the existing
    /// dimensions.
    pub fn extend(mut self, additional_dims: &[usize]) -> Self {
        self.0.extend(additional_dims);
        self
    }

    /// Check whether the two shapes are compatible for broadcast, and if it is the case return the
    /// broadcasted shape. This is to be used for binary pointwise ops.
    pub fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
        let lhs = self;
        let lhs_dims = lhs.dims();
        let rhs_dims = rhs.dims();
        let lhs_ndims = lhs_dims.len();
        let rhs_ndims = rhs_dims.len();
        let bcast_ndims = usize::max(lhs_ndims, rhs_ndims);
        let mut bcast_dims = vec![0; bcast_ndims];
        for (idx, bcast_value) in bcast_dims.iter_mut().enumerate() {
            let rev_idx = bcast_ndims - idx;
            let l_value = if lhs_ndims < rev_idx {
                1
            } else {
                lhs_dims[lhs_ndims - rev_idx]
            };
            let r_value = if rhs_ndims < rev_idx {
                1
            } else {
                rhs_dims[rhs_ndims - rev_idx]
            };
            *bcast_value = if l_value == r_value {
                l_value
            } else if l_value == 1 {
                r_value
            } else if r_value == 1 {
                l_value
            } else {
                Err(Error::ShapeMismatchBinaryOp {
                    lhs: lhs.clone(),
                    rhs: rhs.clone(),
                    op,
                }
                .bt())?
            }
        }
        Ok(Shape::from(bcast_dims))
    }

    pub(crate) fn broadcast_shape_matmul(&self, rhs: &Self) -> Result<(Shape, Shape)> {
        let lhs = self;
        let lhs_dims = lhs.dims();
        let rhs_dims = rhs.dims();
        if lhs_dims.len() < 2 || rhs_dims.len() < 2 {
            crate::bail!("only 2d matrixes are supported {lhs:?} {rhs:?}")
        }
        let (m, lhs_k) = (lhs_dims[lhs_dims.len() - 2], lhs_dims[lhs_dims.len() - 1]);
        let (rhs_k, n) = (rhs_dims[rhs_dims.len() - 2], rhs_dims[rhs_dims.len() - 1]);
        if lhs_k != rhs_k {
            crate::bail!("different inner dimensions in broadcast matmul {lhs:?} {rhs:?}")
        }

        let lhs_b = Self::from(&lhs_dims[..lhs_dims.len() - 2]);
        let rhs_b = Self::from(&rhs_dims[..rhs_dims.len() - 2]);
        let bcast = lhs_b.broadcast_shape_binary_op(&rhs_b, "broadcast_matmul")?;
        let bcast_dims = bcast.dims();

        let bcast_lhs = [bcast_dims, &[m, lhs_k]].concat();
        let bcast_rhs = [bcast_dims, &[rhs_k, n]].concat();
        Ok((Shape::from(bcast_lhs), Shape::from(bcast_rhs)))
    }
}

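A short worked example of the broadcast rule implemented in broadcast_shape_binary_op above: dimensions are aligned from the right, missing leading dimensions count as 1, and a size of 1 stretches to match the other side. The shapes here are illustrative:

// Sketch: (3, 1, 5) against (4, 5) broadcasts to (3, 4, 5).
//   lhs: 3 1 5
//   rhs:   4 5
//   out: 3 4 5   (1 vs 4 -> 4; the missing leading rhs dim counts as 1)
let lhs = Shape::from((3usize, 1, 5));
let rhs = Shape::from((4usize, 5));
let out = lhs.broadcast_shape_binary_op(&rhs, "add")?;
assert_eq!(out.dims(), &[3, 4, 5]);
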
pub trait Dim {
@ -423,39 +345,6 @@ impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
    }
}

impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim> Dims for (D1, D2, D3, D4) {
    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
        let d0 = self.0.to_index(shape, op)?;
        let d1 = self.1.to_index(shape, op)?;
        let d2 = self.2.to_index(shape, op)?;
        let d3 = self.3.to_index(shape, op)?;
        Ok(vec![d0, d1, d2, d3])
    }
}

impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim> Dims for (D1, D2, D3, D4, D5) {
    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
        let d0 = self.0.to_index(shape, op)?;
        let d1 = self.1.to_index(shape, op)?;
        let d2 = self.2.to_index(shape, op)?;
        let d3 = self.3.to_index(shape, op)?;
        let d4 = self.4.to_index(shape, op)?;
        Ok(vec![d0, d1, d2, d3, d4])
    }
}

impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim, D6: Dim> Dims for (D1, D2, D3, D4, D5, D6) {
    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
        let d0 = self.0.to_index(shape, op)?;
        let d1 = self.1.to_index(shape, op)?;
        let d2 = self.2.to_index(shape, op)?;
        let d3 = self.3.to_index(shape, op)?;
        let d4 = self.4.to_index(shape, op)?;
        let d5 = self.5.to_index(shape, op)?;
        Ok(vec![d0, d1, d2, d3, d4, d5])
    }
}

extract_dims!(dims0, 0, |_: &[usize]| (), ());
extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
@ -478,139 +367,6 @@ extract_dims!(
    (usize, usize, usize, usize, usize)
);

pub trait ShapeWithOneHole {
    fn into_shape(self, el_count: usize) -> Result<Shape>;
}

impl<S: Into<Shape>> ShapeWithOneHole for S {
    fn into_shape(self, _el_count: usize) -> Result<Shape> {
        Ok(self.into())
    }
}

impl ShapeWithOneHole for ((),) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        Ok(el_count.into())
    }
}

fn hole_size(el_count: usize, prod_d: usize, s: &dyn std::fmt::Debug) -> Result<usize> {
    if prod_d == 0 {
        crate::bail!("cannot reshape tensor of {el_count} elements to {s:?}")
    }
    if el_count % prod_d != 0 {
        crate::bail!("cannot reshape tensor with {el_count} elements to {s:?}")
    }
    Ok(el_count / prod_d)
}

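In the impls that follow, the unit type () marks the single dimension to infer: hole_size divides the total element count by the product of the known dimensions and bails if it does not divide evenly. A hedged usage sketch:

// Sketch: reshaping 12 elements to ((), 4) resolves the hole to 12 / 4 = 3.
let shape = ((), 4usize).into_shape(12)?;
assert_eq!(shape.dims(), &[3, 4]);
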
impl ShapeWithOneHole for ((), usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let ((), d1) = self;
        Ok((hole_size(el_count, d1, &self)?, d1).into())
    }
}

impl ShapeWithOneHole for (usize, ()) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, ()) = self;
        Ok((d1, hole_size(el_count, d1, &self)?).into())
    }
}

impl ShapeWithOneHole for ((), usize, usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let ((), d1, d2) = self;
        Ok((hole_size(el_count, d1 * d2, &self)?, d1, d2).into())
    }
}

impl ShapeWithOneHole for (usize, (), usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, (), d2) = self;
        Ok((d1, hole_size(el_count, d1 * d2, &self)?, d2).into())
    }
}

impl ShapeWithOneHole for (usize, usize, ()) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, d2, ()) = self;
        Ok((d1, d2, hole_size(el_count, d1 * d2, &self)?).into())
    }
}

impl ShapeWithOneHole for ((), usize, usize, usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let ((), d1, d2, d3) = self;
        let d = hole_size(el_count, d1 * d2 * d3, &self)?;
        Ok((d, d1, d2, d3).into())
    }
}

impl ShapeWithOneHole for (usize, (), usize, usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, (), d2, d3) = self;
        let d = hole_size(el_count, d1 * d2 * d3, &self)?;
        Ok((d1, d, d2, d3).into())
    }
}

impl ShapeWithOneHole for (usize, usize, (), usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, d2, (), d3) = self;
        let d = hole_size(el_count, d1 * d2 * d3, &self)?;
        Ok((d1, d2, d, d3).into())
    }
}

impl ShapeWithOneHole for (usize, usize, usize, ()) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, d2, d3, ()) = self;
        let d = hole_size(el_count, d1 * d2 * d3, &self)?;
        Ok((d1, d2, d3, d).into())
    }
}

impl ShapeWithOneHole for ((), usize, usize, usize, usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let ((), d1, d2, d3, d4) = self;
        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
        Ok((d, d1, d2, d3, d4).into())
    }
}

impl ShapeWithOneHole for (usize, (), usize, usize, usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, (), d2, d3, d4) = self;
        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
        Ok((d1, d, d2, d3, d4).into())
    }
}

impl ShapeWithOneHole for (usize, usize, (), usize, usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, d2, (), d3, d4) = self;
        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
        Ok((d1, d2, d, d3, d4).into())
    }
}

impl ShapeWithOneHole for (usize, usize, usize, (), usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, d2, d3, (), d4) = self;
        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
        Ok((d1, d2, d3, d, d4).into())
    }
}

impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, d2, d3, d4, ()) = self;
        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
        Ok((d1, d2, d3, d4, d).into())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

@ -1,6 +1,6 @@
use crate::backend::BackendStorage;
use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};

// We do not want to implement Clone on Storage as cloning may fail because of
// out of memory. Instead try_clone should be used.
@ -8,7 +8,6 @@ use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage,
pub enum Storage {
    Cpu(CpuStorage),
    Cuda(CudaStorage),
    Metal(MetalStorage),
}

impl Storage {
@ -19,10 +18,6 @@ impl Storage {
                let storage = storage.try_clone(layout)?;
                Ok(Self::Cuda(storage))
            }
            Self::Metal(storage) => {
                let storage = storage.try_clone(layout)?;
                Ok(Self::Metal(storage))
            }
        }
    }

@ -30,7 +25,6 @@ impl Storage {
        match self {
            Self::Cpu(_) => Device::Cpu,
            Self::Cuda(storage) => Device::Cuda(storage.device().clone()),
            Self::Metal(storage) => Device::Metal(storage.device().clone()),
        }
    }

@ -38,7 +32,6 @@ impl Storage {
        match self {
            Self::Cpu(storage) => storage.dtype(),
            Self::Cuda(storage) => storage.dtype(),
            Self::Metal(storage) => storage.dtype(),
        }
    }

@ -72,27 +65,6 @@ impl Storage {
                let storage = storage.affine(layout, mul, add)?;
                Ok(Self::Cuda(storage))
            }
            Self::Metal(storage) => {
                let storage = storage.affine(layout, mul, add)?;
                Ok(Self::Metal(storage))
            }
        }
    }

    pub(crate) fn powf(&self, layout: &Layout, alpha: f64) -> Result<Self> {
        match self {
            Storage::Cpu(storage) => {
                let storage = storage.powf(layout, alpha)?;
                Ok(Self::Cpu(storage))
            }
            Self::Cuda(storage) => {
                let storage = storage.powf(layout, alpha)?;
                Ok(Self::Cuda(storage))
            }
            Self::Metal(storage) => {
                let storage = storage.powf(layout, alpha)?;
                Ok(Self::Metal(storage))
            }
        }
    }

@ -106,10 +78,6 @@ impl Storage {
                let storage = storage.elu(layout, alpha)?;
                Ok(Self::Cuda(storage))
            }
            Self::Metal(storage) => {
                let storage = storage.elu(layout, alpha)?;
                Ok(Self::Metal(storage))
            }
        }
    }

@ -131,10 +99,6 @@ impl Storage {
                let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
                Ok(Self::Cuda(storage))
            }
            (Self::Metal(lhs), Self::Metal(rhs)) => {
                let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
                Ok(Self::Metal(storage))
            }
            (lhs, rhs) => {
                // Should not happen because of the same device check above but we're defensive
                // anyway.
@ -158,10 +122,6 @@ impl Storage {
                let storage = storage.reduce_op(op, layout, s)?;
                Ok(Self::Cuda(storage))
            }
            Self::Metal(storage) => {
                let storage = storage.reduce_op(op, layout, s)?;
                Ok(Self::Metal(storage))
            }
        }
    }

@ -175,14 +135,10 @@ impl Storage {
                let storage = storage.to_dtype(layout, dtype)?;
                Ok(Self::Cuda(storage))
            }
            Self::Metal(storage) => {
                let storage = storage.to_dtype(layout, dtype)?;
                Ok(Self::Metal(storage))
            }
        }
    }

    pub(crate) fn apply_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> {
    pub(crate) fn custom_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> {
        match self {
            Self::Cpu(storage) => {
                let (storage, shape) = c.cpu_fwd(storage, l)?;
@ -192,14 +148,10 @@ impl Storage {
                let (storage, shape) = c.cuda_fwd(storage, l)?;
                Ok((Self::Cuda(storage), shape))
            }
            Self::Metal(storage) => {
                let (storage, shape) = c.metal_fwd(storage, l)?;
                Ok((Self::Metal(storage), shape))
            }
        }
    }

    pub(crate) fn apply_op2(
    pub(crate) fn custom_op2(
        &self,
        l1: &Layout,
        t2: &Self,
@ -216,15 +168,11 @@ impl Storage {
                let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?;
                Ok((Self::Cuda(s), shape))
            }
            (Self::Metal(s1), Self::Metal(s2)) => {
                let (s, shape) = c.metal_fwd(s1, l1, s2, l2)?;
                Ok((Self::Metal(s), shape))
            }
            _ => unreachable!(),
        }
    }

    pub(crate) fn apply_op3(
    pub(crate) fn custom_op3(
        &self,
        l1: &Layout,
        t2: &Self,
@ -244,10 +192,6 @@ impl Storage {
                let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?;
                Ok((Self::Cuda(s), shape))
            }
            (Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
                let (s, shape) = c.metal_fwd(s1, l1, s2, l2, s3, l3)?;
                Ok((Self::Metal(s), shape))
            }
            _ => unreachable!(),
        }
    }
@ -262,10 +206,6 @@ impl Storage {
                let storage = storage.unary_impl::<B>(layout)?;
                Ok(Self::Cuda(storage))
            }
            Self::Metal(storage) => {
                let storage = storage.unary_impl::<B>(layout)?;
                Ok(Self::Metal(storage))
            }
        }
    }

@ -286,10 +226,6 @@ impl Storage {
                let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
                Ok(Self::Cuda(storage))
            }
            (Self::Metal(lhs), Self::Metal(rhs)) => {
                let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
                Ok(Self::Metal(storage))
            }
            (lhs, rhs) => {
                // Should not happen because of the same device check above but we're defensive
                // anyway.
@ -321,10 +257,6 @@ impl Storage {
                let s = inp.conv1d(l, kernel, kernel_l, params)?;
                Ok(Self::Cuda(s))
            }
            (Storage::Metal(inp), Storage::Metal(kernel)) => {
                let s = inp.conv1d(l, kernel, kernel_l, params)?;
                Ok(Self::Metal(s))
            }
            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                lhs: lhs.device().location(),
                rhs: rhs.device().location(),
@ -334,33 +266,6 @@ impl Storage {
        }
    }

    pub(crate) fn conv_transpose1d(
        &self,
        l: &Layout,
        kernel: &Self,
        kernel_l: &Layout,
        params: &crate::conv::ParamsConvTranspose1D,
    ) -> Result<Self> {
        self.same_device(kernel, "conv-transpose1d")?;
        self.same_dtype(kernel, "conv-transpose1d")?;
        match (self, &kernel) {
            (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
                let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
                Ok(Self::Cpu(s))
            }
            (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
                let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
                Ok(Self::Cuda(s))
            }
            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                lhs: lhs.device().location(),
                rhs: rhs.device().location(),
                op: "conv-transpose1d",
            }
            .bt()),
        }
    }

    pub(crate) fn conv2d(
        &self,
        l: &Layout,
@ -379,10 +284,6 @@ impl Storage {
                let s = inp.conv2d(l, kernel, kernel_l, params)?;
                Ok(Self::Cuda(s))
            }
            (Storage::Metal(inp), Storage::Metal(kernel)) => {
                let s = inp.conv2d(l, kernel, kernel_l, params)?;
                Ok(Self::Metal(s))
            }
            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                lhs: lhs.device().location(),
                rhs: rhs.device().location(),
@ -392,37 +293,6 @@ impl Storage {
        }
    }

    pub(crate) fn conv_transpose2d(
        &self,
        l: &Layout,
        kernel: &Self,
        kernel_l: &Layout,
        params: &crate::conv::ParamsConvTranspose2D,
    ) -> Result<Self> {
        self.same_device(kernel, "conv_transpose2d")?;
        self.same_dtype(kernel, "conv_transpose2d")?;
        match (self, &kernel) {
            (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
                let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
                Ok(Self::Cpu(s))
            }
            (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
                let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
                Ok(Self::Cuda(s))
            }
            (Storage::Metal(inp), Storage::Metal(kernel)) => {
                let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
                Ok(Self::Metal(s))
            }
            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                lhs: lhs.device().location(),
                rhs: rhs.device().location(),
                op: "conv_transpose2d",
            }
            .bt()),
        }
    }

    pub(crate) fn avg_pool2d(
        &self,
        layout: &Layout,
@ -438,10 +308,6 @@ impl Storage {
                let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
                Ok(Self::Cuda(storage))
            }
            Self::Metal(storage) => {
                let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
                Ok(Self::Metal(storage))
            }
        }
    }

@ -460,27 +326,6 @@ impl Storage {
                let storage = storage.max_pool2d(layout, kernel_size, stride)?;
                Ok(Self::Cuda(storage))
            }
            Self::Metal(storage) => {
                let storage = storage.max_pool2d(layout, kernel_size, stride)?;
                Ok(Self::Metal(storage))
            }
        }
    }

    pub(crate) fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
        match self {
            Storage::Cpu(storage) => {
                let storage = storage.upsample_nearest1d(layout, sz)?;
                Ok(Self::Cpu(storage))
            }
            Self::Cuda(storage) => {
                let storage = storage.upsample_nearest1d(layout, sz)?;
                Ok(Self::Cuda(storage))
            }
            Self::Metal(storage) => {
                let storage = storage.upsample_nearest1d(layout, sz)?;
                Ok(Self::Metal(storage))
            }
        }
    }

@ -494,10 +339,6 @@ impl Storage {
                let storage = storage.upsample_nearest2d(layout, h, w)?;
                Ok(Self::Cuda(storage))
            }
            Self::Metal(storage) => {
                let storage = storage.upsample_nearest2d(layout, h, w)?;
                Ok(Self::Metal(storage))
            }
        }
    }

@ -521,10 +362,6 @@ impl Storage {
                let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
                Ok(Self::Cuda(storage))
            }
            (Self::Metal(cond), Self::Metal(t), Self::Metal(f)) => {
                let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
                Ok(Self::Metal(storage))
            }
            (_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                lhs: lhs.device().location(),
                rhs: rhs.device().location(),
@ -551,10 +388,6 @@ impl Storage {
                let storage = s.gather(l, indexes, indexes_l, d)?;
                Ok(Self::Cuda(storage))
            }
            (Self::Metal(s), Self::Metal(indexes)) => {
                let storage = s.gather(l, indexes, indexes_l, d)?;
                Ok(Self::Metal(storage))
            }
            _ => unreachable!(),
        }
    }
@ -579,10 +412,6 @@ impl Storage {
                let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
                Ok(Self::Cuda(storage))
            }
            (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
                let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
                Ok(Self::Metal(storage))
            }
            _ => unreachable!(),
        }
    }
@ -607,10 +436,6 @@ impl Storage {
                let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
                Ok(Self::Cuda(storage))
            }
            (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
                let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
                Ok(Self::Metal(storage))
            }
            _ => unreachable!(),
        }
    }
@ -632,10 +457,6 @@ impl Storage {
                let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
                Ok(Self::Cuda(storage))
            }
            (Self::Metal(lhs), Self::Metal(rhs)) => {
                let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
                Ok(Self::Metal(storage))
            }
            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                lhs: lhs.device().location(),
                rhs: rhs.device().location(),
@ -663,10 +484,6 @@ impl Storage {
                let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
                Ok(Self::Cuda(storage))
            }
            (Self::Metal(lhs), Self::Metal(rhs)) => {
                let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
                Ok(Self::Metal(storage))
            }
            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                lhs: lhs.device().location(),
                rhs: rhs.device().location(),
@ -686,9 +503,6 @@ impl Storage {
        match (self, dst) {
            (Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l),
            (Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
            (Self::Metal(src), Self::Metal(dst)) => {
                Ok(src.copy_strided_src(dst, dst_offset, src_l)?)
            }
            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                lhs: lhs.device().location(),
                rhs: rhs.device().location(),
File diff suppressed because it is too large
@ -22,23 +22,3 @@ pub fn has_mkl() -> bool {
pub fn cuda_is_available() -> bool {
    cfg!(feature = "cuda")
}

pub fn metal_is_available() -> bool {
    cfg!(feature = "metal")
}

pub fn with_avx() -> bool {
    cfg!(target_feature = "avx")
}

pub fn with_neon() -> bool {
    cfg!(target_feature = "neon")
}

pub fn with_simd128() -> bool {
    cfg!(target_feature = "simd128")
}

pub fn with_f16c() -> bool {
    cfg!(target_feature = "f16c")
}

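These helpers report compile-time features via cfg!, so they cost nothing at runtime. A small sketch that prints a capability summary using the functions from the hunk above:

// Sketch: log which optional backends and SIMD paths this build carries.
fn print_capabilities() {
    println!("cuda: {}", cuda_is_available());
    println!("metal: {}", metal_is_available());
    println!("avx: {}, neon: {}", with_avx(), with_neon());
    println!("simd128: {}, f16c: {}", with_simd128(), with_f16c());
}
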
@ -107,10 +107,6 @@ impl Var {
        Ok(Self(inner))
    }

    pub fn as_detached_tensor(&self) -> Tensor {
        self.0.detach()
    }

    pub fn as_tensor(&self) -> &Tensor {
        &self.0
    }

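as_detached_tensor hands back a copy of the variable's tensor that is cut off from the autograd graph, while as_tensor borrows the tracked one. A hedged usage sketch:

// Sketch: gradients flow through the borrowed tensor, not the detached copy.
let v = Var::new(&[1f32, 2., 3.], &Device::Cpu)?;
let tracked = v.as_tensor();          // participates in backward()
let frozen = v.as_detached_tensor();  // ops on this record no gradient
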
@ -1,5 +1,6 @@
mod test_utils;
use anyhow::Result;
use candle_core::{test_device, test_utils, Device, IndexOp, Tensor};
use candle_core::{Device, Tensor};

/* This test is based on the following script.
import torch
@ -13,11 +14,6 @@ res = torch.nn.functional.conv1d(t, w)
print(res.flatten())
res = torch.nn.functional.conv1d(t, w, padding=1)
print(res.flatten())

w_t = w.transpose(0, 1)
res = torch.nn.functional.conv_transpose1d(t, w_t)
print(res.shape)
print(res)
*/
fn conv1d(dev: &Device) -> Result<()> {
    let t = Tensor::new(
@ -37,41 +33,32 @@ fn conv1d(dev: &Device) -> Result<()> {
        dev,
    )?
    .reshape((2, 4, 3))?;
    let res = t.conv1d(&w, 0, 1, 1, 1)?;
    let res = t.conv1d(&w, 0, 1)?;
    assert_eq!(res.dims(), [1, 2, 3]);
    assert_eq!(
        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
        [2.6357, -1.3336, 4.1393, -1.1784, 3.5675, 0.5069]
    );
    let res = t.conv1d(&w, /*padding*/ 1, 1, 1, 1)?;
    let res = t.conv1d(&w, /*padding*/ 1, 1)?;
    assert_eq!(res.dims(), [1, 2, 5]);
    // Same as pytorch default padding: use zeros.
    assert_eq!(
        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
        [2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
    );
    let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
    assert_eq!(res.dims(), [1, 2, 7]);
    assert_eq!(
        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
        [
            0.0699, -1.2899, 8.3018, 5.5873, 2.4572, -2.6143, -0.0706, 1.8765, 4.8318, 1.1538,
            4.7076, -5.9745, -0.8276, 1.621
        ],
    );
    Ok(())
}

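The paired lines above show conv1d gaining two trailing arguments; judging from the dilation test later in this file, the full order appears to be (kernel, padding, stride, dilation, groups), though that reading is an inference from the hunks rather than something documented here:

// Sketch, assuming the argument order above: padding 0, stride 1,
// dilation 1, groups 1.
let res = t.conv1d(&w, 0, 1, 1, 1)?;
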
fn conv1d_small(dev: &Device) -> Result<()> {
    let t = Tensor::new(&[0.4056f32, -0.8689, -0.0773, -1.5630], dev)?.reshape((1, 1, 4))?;
    let w = Tensor::new(&[1f32, 0., 0.], dev)?.reshape((1, 1, 3))?;
    let res = t.conv1d(&w, 0, 1, 1, 1)?;
    let res = t.conv1d(&w, 0, 1)?;
    assert_eq!(res.dims(), [1, 1, 2]);
    assert_eq!(
        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
        [0.4056, -0.8689]
    );
    let res = t.conv1d(&w, /*padding*/ 1, 1, 1, 1)?;
    let res = t.conv1d(&w, /*padding*/ 1, 1)?;
    assert_eq!(res.dims(), [1, 1, 4]);
    assert_eq!(
        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@ -90,19 +77,6 @@ print(t.flatten())
print(w.flatten())
res = torch.nn.functional.conv2d(t, w)
print(res.flatten())

w_t = w.transpose(0, 1)
res = torch.nn.functional.conv_transpose2d(t, w_t)
print(res.shape)
print(res)

res = torch.nn.functional.conv2d(t, w, dilation=2)
print(res.shape)
print(res[0])

res = torch.nn.functional.conv_transpose2d(t, w_t, dilation=2)
print(res.shape)
print(res)
*/
fn conv2d(dev: &Device) -> Result<()> {
    let t = Tensor::new(
@ -135,7 +109,7 @@ fn conv2d(dev: &Device) -> Result<()> {
    )?;
    let t = t.reshape((1, 4, 5, 5))?;
    let w = w.reshape((2, 4, 3, 3))?;
    let res = t.conv2d(&w, 0, 1, 1, 1)?;
    let res = t.conv2d(&w, 0, 1)?;
    assert_eq!(res.dims(), [1, 2, 3, 3]);
    assert_eq!(
        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@ -144,69 +118,6 @@ fn conv2d(dev: &Device) -> Result<()> {
            10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
        ]
    );
    let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
    assert_eq!(res.dims(), [1, 2, 7, 7]);
    assert_eq!(
        test_utils::to_vec3_round(&res.i(0)?, 4)?,
        [
            [
                [-1.9918, 2.6797, -0.4599, -1.6037, 1.4131, -2.4012, 2.9277],
                [1.8016, -3.5361, 1.0757, 3.5395, -8.2168, -3.2023, 0.5375],
                [0.8243, 1.8675, 7.8929, -4.0746, -6.4415, 5.1139, 1.6889],
                [0.2722, 8.9679, 3.3477, 1.8514, -4.2896, -3.8228, -7.5632],
                [-8.5412, -5.8142, -7.1587, -1.6095, 0.4651, 0.2748, -2.0985],
                [2.0833, -0.6482, -12.1692, -4.1284, -2.9765, -0.0656, -4.5114],
                [5.307, 2.6957, 2.3087, 1.0478, 0.7808, -1.1519, -0.9579]
            ],
            [
                [1.089, 0.1872, -0.6408, -0.9897, 0.8503, 1.1019, -0.9211],
                [-0.1741, -0.2915, 4.2472, 1.9417, 1.65, 0.6303, -4.7131],
                [1.6555, 2.4026, -2.9293, 2.9953, 0.5328, 3.5873, -0.9621],
                [-1.4289, -3.2787, 4.1747, -6.0341, -4.6341, -5.7945, 4.142],
                [7.5973, 6.4431, 5.9872, 2.1639, -8.6566, 3.3143, -3.4059],
                [-0.8775, -3.048, 11.6543, 0.6442, 2.3218, -0.4765, 1.1516],
                [-5.5423, -2.5188, 1.0754, -0.0563, -2.9386, -1.1504, 1.0171]
            ]
        ]
    );
    // Dilations.
    let res = t.conv2d(&w, 0, 1, 2, 1)?;
    assert_eq!(res.dims(), [1, 2, 1, 1]);
    assert_eq!(
        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
        [2.45, -2.3504],
    );

    // Transpose and dilations.
    let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2)?;
    assert_eq!(res.dims(), [1, 2, 9, 9]);
    assert_eq!(
        test_utils::to_vec3_round(&res.i(0)?, 4)?,
        [
            [
                [-1.9918, 3.1652, -0.6778, -4.3442, 4.4351, 0.6652, -3.0124, -0.6031, 2.9277],
                [2.7036, -1.7156, -0.3969, 1.0516, 1.6381, -2.8886, -0.205, 2.4682, -1.0499],
                [-0.9459, 3.1631, 3.707, -4.8369, -8.5166, -1.4496, -2.7559, -3.2698, 1.4376],
                [-0.2157, 3.7786, -2.0252, -4.2633, 3.6731, -1.5142, 5.9391, -0.2622, -0.141],
                [-6.8121, -3.1744, 1.5945, 3.0637, -9.6088, 1.4446, 2.9489, -3.0082, -7.3822],
                [0.2371, 3.3303, 0.3861, 2.2646, -4.6784, 4.1235, -0.0109, 0.3176, -0.03],
                [-2.5339, -2.9564, -3.4518, -4.4594, -9.1873, -1.9709, -0.4676, 0.51, -3.5024],
                [4.007, 0.3067, -2.2954, 1.1105, -0.1992, 1.6372, -2.9268, 0.2807, -1.2787],
                [5.307, 1.1317, 1.3518, 0.9049, 3.8116, -0.4075, -0.8874, -0.2241, -0.9579]
            ],
            [
                [1.089, -0.6483, 0.0726, -0.4752, -1.3283, 1.7103, 1.0703, 0.1076, -0.9211],
                [-0.8629, 0.1376, 0.3202, 2.0955, 0.9696, 2.8988, -1.0012, 1.5049, -0.1278],
                [1.9286, -1.5255, -2.9563, 2.4589, 3.3611, -0.6951, 0.3525, -1.7724, -5.9861],
                [1.1226, 2.1561, 3.6417, 4.7546, -0.692, 4.4126, -5.1902, 6.0805, 2.3185],
                [1.0111, 0.3604, 0.6432, -3.6605, 7.9517, -9.2955, -5.2988, -3.7803, -2.0642],
                [3.3172, -1.7967, -3.6576, -2.0942, 1.3158, 0.112, -1.7405, 2.9167, 0.7957],
                [5.1001, 1.8995, -1.8639, 1.1262, 9.9629, 2.683, -3.6319, -1.1607, 0.5856],
                [-4.8445, -0.5642, 4.2317, 0.0856, 1.2267, -0.5712, 1.736, 1.0997, 0.6908],
                [-5.5423, -1.1831, -1.2176, 0.0843, 0.0446, -0.7545, -2.4798, -0.0827, 1.0171]
            ]
        ]
    );
    Ok(())
}

@ -220,16 +131,6 @@ print(t.flatten())
print(w.flatten())
res = torch.nn.functional.conv2d(t, w)
print(res.flatten())

w_t = w.transpose(0, 1)
res = torch.nn.functional.conv_transpose2d(t, w_t)
print(res.shape)
print(res.flatten())

t_t = w.transpose(0, 1)
res = torch.nn.functional.conv_transpose2d(t_t, w)
print(res.shape)
print(res.flatten())
*/
fn conv2d_small(dev: &Device) -> Result<()> {
    let t = Tensor::new(
@ -242,41 +143,12 @@ fn conv2d_small(dev: &Device) -> Result<()> {
    let w = Tensor::new(&[-0.9259f32, 1.3017], dev)?;
    let t = t.reshape((1, 2, 3, 3))?;
    let w = w.reshape((1, 2, 1, 1))?;
    let res = t.conv2d(&w, 0, 1, 1, 1)?;
    let res = t.conv2d(&w, 0, 1)?;
    assert_eq!(res.dims(), [1, 1, 3, 3]);
    assert_eq!(
        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
        [0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539]
    );
    let res = t.conv2d(&w, 2, 1, 1, 1)?;
    assert_eq!(res.dims(), [1, 1, 7, 7]);
    assert_eq!(
        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
        [
            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1640, -0.0111, -0.1742, 0.0000, 0.0000,
            0.0000, 0.0000, 2.6437, -2.0268, 1.1823, 0.0000, 0.0000, 0.0000, 0.0000, 3.2855,
            -1.0324, 0.2539, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000
        ]
    );
    let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
    assert_eq!(res.dims(), [1, 1, 3, 3]);
    assert_eq!(
        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
        [0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539],
    );
    let res = t.transpose(0, 1)?.conv_transpose2d(&w, 0, 0, 1, 1)?;
    assert_eq!(res.dims(), [2, 2, 3, 3]);
    assert_eq!(
        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
        [
            -0.3755, 0.8045, -0.6336, -0.2218, -1.1369, 0.8599, 1.5768, -0.1268, -0.1728, 0.528,
            -1.131, 0.8908, 0.3118, 1.5984, -1.2089, -2.2168, 0.1783, 0.2429, -0.3838, 0.5802,
            -0.3268, -2.0382, 0.6329, -0.2293, -1.2154, 0.6441, -0.3035, 0.5396, -0.8156, 0.4594,
            2.8654, -0.8898, 0.3224, 1.7087, -0.9056, 0.4267
        ]
    );
    Ok(())
}

@ -290,7 +162,7 @@ fn conv2d_smaller(dev: &Device) -> Result<()> {
    let w = Tensor::new(&[1f32, 1., 1., 1., 1., 1., 1., 1., 1.], dev)?;
    let t = t.reshape((1, 1, 3, 3))?;
    let w = w.reshape((1, 1, 3, 3))?;
    let res = t.conv2d(&w, 0, 1, 1, 1)?;
    let res = t.conv2d(&w, 0, 1)?;
    assert_eq!(res.dims(), [1, 1, 1, 1]);
    assert_eq!(
        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@ -299,297 +171,8 @@ fn conv2d_smaller(dev: &Device) -> Result<()> {
    Ok(())
}

/* This test is based on the following script.
import torch
torch.manual_seed(4242)

t = torch.randn((1, 2, 4, 2))
w = torch.randn((1, 2, 1, 1))
print(t.flatten())
print(w.flatten())
res = torch.nn.functional.conv2d(t, w)
print(res.flatten())
*/
fn conv2d_non_square(dev: &Device) -> Result<()> {
    let t = Tensor::new(
        &[
            0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
            1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699,
        ],
        dev,
    )?;
    let w = Tensor::new(&[-1.1351f32, 1.3841], dev)?;
    let t = t.reshape((1, 2, 4, 2))?;
    let w = w.reshape((1, 2, 1, 1))?;
    let res = t.conv2d(&w, 0, 1, 1, 1)?;
    assert_eq!(res.dims(), [1, 1, 4, 2]);
    assert_eq!(
        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
        [0.2312, 5.2238, 2.3772, 1.9076, 2.0256, -0.5776, -1.6028, -1.467]
    );
    Ok(())
}

/*
import torch
torch.manual_seed(4242)

t = torch.randn((1, 4, 5, 5), requires_grad=True)
w = torch.randn((2, 4, 3, 3), requires_grad=True)
print(t.flatten())
print(w.flatten())
res = torch.nn.functional.conv2d(t, w)
print(res.flatten())
loss = (res ** 2).sum()
print(loss)
loss.backward()
print(t.grad.shape)
print(t.grad.flatten())
print(w.grad.shape)
print(w.grad.flatten())

t.grad.zero_()
w.grad.zero_()
res = torch.nn.functional.conv2d(t, w, stride=2)
print(res.flatten())
loss = (res ** 2).sum()
print(loss)
loss.backward()
print(t.grad.shape)
print(t.grad[0])
print(w.grad.shape)
print(w.grad[0])
*/
fn conv2d_grad(dev: &Device) -> Result<()> {
    use candle_core::Var;
    let t = Var::from_slice(
        &[
            0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
            1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843, 0.2395,
            1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013, -0.6836,
            0.2477, 1.3127, -0.2260, 0.2622, -1.2974, -0.8140, -0.8404, -0.3490, 0.0130, 1.3123,
            1.7569, -0.3956, -1.8255, 0.1727, -0.3538, 2.6941, 1.0529, 0.4219, -0.2071, 1.1586,
            0.4717, 0.3865, -0.5690, -0.5010, -0.1310, 0.7796, 0.6630, -0.2021, 2.6090, 0.2049,
            0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
            0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
            -0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
            -0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
        ],
        (1, 4, 5, 5),
        dev,
    )?;
    let w = Var::from_slice(
        &[
            -0.9325f32, 0.6451, -0.8537, 0.2378, 0.8764, -0.1832, 0.2987, -0.6488, -0.2273,
            -2.4184, -0.1192, -0.4821, -0.5079, -0.5766, -2.4729, 1.6734, 0.4558, 0.2851, 1.1514,
            -0.9013, 1.0662, -0.1817, -0.0259, 0.1709, 0.5367, 0.7513, 0.8086, -2.2586, -0.5027,
            0.9141, -1.3086, -1.3343, -1.5669, -0.1657, 0.7958, 0.1432, 0.3896, -0.4501, 0.1667,
            0.0714, -0.0952, 1.2970, -0.1674, -0.3178, 1.0677, 0.3060, 0.7080, 0.1914, 1.1679,
            -0.3602, 1.9265, -1.8626, -0.5112, -0.0982, 0.2621, 0.6565, 0.5908, 1.0089, -0.1646,
            1.8032, -0.6286, 0.2016, -0.3370, 1.2555, 0.8009, -0.6488, -0.4652, -1.5685, 1.5860,
            0.5583, 0.4623, 0.6026,
        ],
        (2, 4, 3, 3),
        dev,
    )?;
    let res = t.conv2d(&w, 0, 1, 1, 1)?;
    let loss = res.sqr()?.sum_all()?;
    assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 741.12f32);
    let grads = loss.backward()?;
    let grad_t = grads.get(&t).unwrap();
    let grad_w = grads.get(&w).unwrap();
    assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
    assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
    assert_eq!(
        test_utils::to_vec1_round(&grad_t.flatten_all()?, 2)?,
        [
            9.29, -2.84, -5.71, 3.38, -7.71, -19.15, 7.02, 29.1, 9.34, 34.73, -22.87, 24.35,
            -39.88, -14.01, 21.08, 9.94, 13.63, -34.68, 11.21, -6.26, 7.72, -6.32, -16.64, -1.08,
            -20.22, 21.73, -0.37, -4.06, 5.82, -3.65, -30.73, 14.55, 87.7, 31.6, 4.53, -89.78,
            -75.37, -57.43, -7.56, 92.96, 18.79, -4.63, -159.75, -42.47, -47.26, 52.88, 37.32,
            49.0, 12.82, 2.01, -8.98, 20.18, 16.62, 12.06, 15.38, 20.0, 2.57, -15.22, 72.62,
            -10.75, 2.25, -31.2, 3.75, -0.2, 9.76, -0.68, 5.21, -40.44, -22.59, -61.61, 17.28,
            20.41, 37.55, 5.23, 6.81, 23.54, 23.62, -9.99, -9.13, 4.87, -35.06, -26.1, 63.48,
            25.81, -39.21, -70.68, -46.96, 2.33, 41.81, 82.42, -28.63, -11.78, -35.33, -10.28,
            -28.57, -9.13, 7.21, -9.05, -9.62, -11.25
        ]
    );
    assert_eq!(
        test_utils::to_vec1_round(&grad_w.flatten_all()?, 2)?,
        [
            -28.92, -22.88, -141.23, 73.35, 61.07, 47.81, -20.0, -73.71, -41.82, -13.59, 21.5,
            28.72, 28.57, -46.85, -90.19, 143.61, 16.68, 7.43, 18.88, -90.81, -20.29, 54.79, 82.63,
            22.94, 77.81, -16.39, -13.2, 9.34, -40.39, -26.62, 5.33, -60.91, 9.09, -59.37, 7.08,
            58.64, 5.55, 20.52, 2.5, -17.25, -6.8, 22.21, 30.15, -7.52, -37.46, 5.67, 22.58, 9.03,
            47.05, 17.61, 37.31, -98.13, -14.61, -4.8, -6.36, 44.69, 23.34, 8.37, -13.52, 80.05,
            -34.24, -16.36, -12.31, 1.92, -33.62, -14.1, -49.23, -7.39, 11.5, -9.98, 9.66, 29.6
        ]
    );

    // Same as before but with stride.
    let res = t.conv2d(&w, 0, 2, 1, 1)?;
    let loss = res.sqr()?.sum_all()?;
    assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 277.16f32);
    let grads = loss.backward()?;
    let grad_t = grads.get(&t).unwrap();
    let grad_w = grads.get(&w).unwrap();
    assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
    assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
    assert_eq!(
        test_utils::to_vec3_round(&grad_t.i(0)?, 2)?,
        [
            [
                [9.29, -7.03, 0.94, 3.49, -7.71],
                [-1.8, -7.82, 8.9, 8.46, 7.43],
                [-25.84, 22.09, -19.27, -0.22, 1.69],
                [4.02, 18.53, -18.37, 2.3, -24.51],
                [7.72, -9.68, -12.34, 5.6, -20.22]
            ],
            [
                [21.73, 3.39, -18.27, 3.86, -3.65],
                [8.25, 3.73, 30.73, -8.61, -11.93],
                [-72.15, -15.36, -17.53, -12.32, -1.61],
                [-22.32, -7.79, -91.82, 6.44, -37.69],
                [52.88, 14.44, 42.75, 9.88, 2.01]
            ],
            [
                [-8.98, 9.91, 6.75, -4.68, 15.38],
                [4.93, -0.33, 9.94, -1.46, 14.78],
                [13.62, -30.63, 3.96, -3.58, -4.48],
                [-14.13, 1.19, -34.43, 3.08, -33.83],
                [17.28, 12.94, 31.83, -3.35, 6.81]
            ],
            [
                [23.54, 6.98, -24.52, 0.52, 4.87],
                [9.65, 6.18, 1.71, -25.23, -4.93],
                [-54.99, -23.66, 3.19, -3.73, 18.58],
                [-21.35, -10.39, -39.88, 28.73, -30.76],
                [-9.13, 11.12, -14.0, -8.23, -11.25]
            ]
        ]
    );
    assert_eq!(
        test_utils::to_vec3_round(&grad_w.i(0)?, 2)?,
        [
            [
                [28.34, -7.91, -45.75],
                [21.03, 3.86, 29.86],
                [0.72, -36.58, -35.28]
            ],
            [
                [-16.04, 11.53, -16.38],
                [29.62, -16.32, -48.35],
                [57.5, 28.29, 25.81]
            ],
            [
                [2.93, -19.6, 1.57],
                [27.15, 53.88, -24.64],
                [12.74, -22.6, -26.2]
            ],
            [
                [-0.18, -14.86, -6.82],
                [-19.55, -2.72, 45.9],
                [-2.54, 36.97, 27.11]
            ]
        ]
    );

    // Replicate the issue from https://github.com/huggingface/candle/issues/1212
    let res = t.i((.., .., 0..4, 0..4))?.conv2d(&w, 0, 2, 1, 1)?;
    let loss = res.sqr()?.sum_all()?;
    assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 21.12f32);
    let grads = loss.backward()?;
    let grad_t = grads.get(&t).unwrap();
    let grad_w = grads.get(&w).unwrap();
    assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
    assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
    assert_eq!(
        test_utils::to_vec3_round(&grad_t.i(0)?, 2)?,
        [
            [
                [9.29, -7.03, 7.87, 0.0, 0.0],
                [-1.8, -7.82, 5.9, 0.0, 0.0],
                [-3.12, 4.49, 5.52, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0]
            ],
            [
                [21.73, 3.39, 4.77, 0.0, 0.0],
                [8.25, 3.73, 27.61, 0.0, 0.0],
                [-20.55, -5.61, -2.77, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0]
            ],
            [
                [-8.98, 9.91, -7.15, 0.0, 0.0],
                [4.93, -0.33, 4.56, 0.0, 0.0],
                [-6.7, -5.76, -8.05, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0]
            ],
            [
                [23.54, 6.98, -10.0, 0.0, 0.0],
                [9.65, 6.18, 18.72, 0.0, 0.0],
                [3.29, -5.27, 0.79, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0]
            ]
        ]
    );
    assert_eq!(
        test_utils::to_vec3_round(&grad_w.i(0)?, 2)?,
        [
            [
                [-3.47, 7.44, 0.66],
                [12.89, -3.4, -9.29],
                [-14.16, -0.83, 7.14]
            ],
            [
                [-3.23, 5.37, -3.02],
                [-2.12, -11.24, 1.94],
                [6.97, 7.2, 2.99]
            ],
            [
                [-4.04, -3.31, 4.87],
                [-6.68, -5.68, 1.73],
                [-5.54, 4.32, 0.52]
            ],
            [[-4.72, 1.5, 4.72], [3.79, 4.04, 6.76], [-4.6, 5.8, 6.93]]
        ]
    );

    Ok(())
}

test_device!(conv1d, conv1d_cpu, conv1d_gpu, conv1d_metal);
test_device!(
    conv1d_small,
    conv1d_small_cpu,
    conv1d_small_gpu,
    conv1d_small_metal
);
test_device!(conv2d, conv2d_cpu, conv2d_gpu, conv2d_metal);
test_device!(
    conv2d_non_square,
    conv2d_non_square_cpu,
    conv2d_non_square_gpu,
    conv2d_non_square_metal
);
test_device!(
    conv2d_small,
    conv2d_small_cpu,
    conv2d_small_gpu,
    conv2d_small_metal
);
test_device!(
    conv2d_smaller,
    conv2d_smaller_cpu,
    conv2d_smaller_gpu,
    conv2d_smaller_metal
);
test_device!(
    conv2d_grad,
    conv2d_grad_cpu,
    conv2d_grad_gpu,
    conv2_grad_metal
);
test_device!(conv1d, conv1d_cpu, conv1d_gpu);
test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu);
test_device!(conv2d, conv2d_cpu, conv2d_gpu);
test_device!(conv2d_small, conv2d_small_cpu, conv2d_small_gpu);
test_device!(conv2d_smaller, conv2d_smaller_cpu, conv2d_smaller_gpu);

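The test_device! invocations above gain a fourth identifier for a Metal variant. Presumably the macro expands to one #[test] wrapper per backend; the expansion below is an illustrative guess, not the macro's actual definition:

// Illustrative expansion only.
#[test]
fn conv1d_cpu() -> Result<()> {
    conv1d(&Device::Cpu)
}
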
@ -1,8 +1,10 @@
use candle_core::backend::BackendStorage;
use candle_core::cpu_backend;
use candle_core::test_utils::to_vec1_round;
use candle_core::{CpuStorage, CustomOp1, DType, Device, Error, Layout, Result, Shape, Tensor};

mod test_utils;
use test_utils::to_vec1_round;

fn fwd<T: num_traits::Float>(v: T, alpha: f64) -> T {
    if v.is_sign_positive() {
        v
@ -37,7 +39,7 @@ fn custom_op1_no_backward() -> Result<()> {
    let cpu = &Device::Cpu;
    let t = Tensor::arange(0u32, 12u32, cpu)?.to_dtype(DType::F32)?;
    let t = (t - 5.)?;
    let elu_t = t.apply_op1_no_bwd(&Elu { alpha: 1. })?;
    let elu_t = t.custom_op1(Elu { alpha: 1. })?;
    assert_eq!(
        to_vec1_round(&elu_t, 4)?,
        &[-0.9933, -0.9817, -0.9502, -0.8647, -0.6321, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
@ -94,7 +96,7 @@ impl CustomOp1 for EluWithBackward {

    fn bwd(&self, arg: &Tensor, _res: &Tensor, grad_res: &Tensor) -> Result<Option<Tensor>> {
        let alpha = self.0.alpha;
        let bwd = arg.apply_op1(EluBackward { alpha })?;
        let bwd = arg.custom_op1(EluBackward { alpha })?;
        Ok(Some(grad_res.mul(&bwd)?))
    }
}
@ -103,7 +105,7 @@ impl CustomOp1 for EluWithBackward {
fn custom_op1_with_backward() -> Result<()> {
    let cpu = &Device::Cpu;
    let t = candle_core::Var::new(&[-2f32, 0f32, 2f32], cpu)?;
    let elu_t = t.apply_op1(EluWithBackward::new(2.))?;
    let elu_t = t.custom_op1(EluWithBackward::new(2.))?;
    assert_eq!(to_vec1_round(&elu_t, 4)?, &[-1.7293, 0.0, 2.0]);

    let grads = elu_t.backward()?;

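For reference, a minimal custom-op sketch matching the trait surface visible in this diff (a name plus a CPU forward; the signatures are inferred from the hunks, so treat them as assumptions rather than a stable API):

// Sketch only: a no-backward scaling op. Assumes a contiguous layout and
// handles just the f32 case; other dtypes are omitted for brevity.
struct Scale(f64);

impl CustomOp1 for Scale {
    fn name(&self) -> &'static str {
        "scale"
    }
    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        let data = storage.as_slice::<f32>()?;
        let out: Vec<f32> = data.iter().map(|&v| v * self.0 as f32).collect();
        Ok((CpuStorage::F32(out), layout.shape().clone()))
    }
}
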
Binary file not shown.
@ -1,5 +1,6 @@
|
||||
use anyhow::{Context, Result};
|
||||
use candle_core::{test_device, test_utils, Device, Shape, Tensor, Var};
|
||||
use candle_core::{Device, Shape, Tensor, Var};
|
||||
mod test_utils;
|
||||
|
||||
fn simple_grad(device: &Device) -> Result<()> {
|
||||
let x = Var::new(&[3f32, 1., 4.], device)?;
|
||||
@ -173,331 +174,11 @@ fn unary_grad(device: &Device) -> Result<()> {
|
||||
let grad_x = grads.get(x).context("no grad for x")?;
|
||||
assert_eq!(y.to_vec1::<f32>()?, [6., 2., 8., 0.3]);
|
||||
assert_eq!(grad_x.to_vec1::<f32>()?, [2., 2., 2., 2.]);
|
||||
|
||||
let x = Var::new(&[3f32, 1., 4., 0.15], device)?;
|
||||
let y = x.powf(2.5)?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||
assert_eq!(test_utils::to_vec1_round(&y, 2)?, [15.59, 1.0, 32.0, 0.01]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(grad_x, 2)?,
|
||||
[12.99, 2.5, 20.0, 0.15]
|
||||
);
|
||||
|
||||
let y = x.tanh()?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||
assert_eq!(test_utils::to_vec1_round(&y, 2)?, [1.0, 0.76, 1.0, 0.15]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(grad_x, 2)?,
|
||||
[0.01, 0.42, 0.0, 0.98],
|
||||
);
|
||||
|
||||
// testing compared to pytorch nn.GELU(approximate = 'tanh')
|
||||
let y = x.gelu()?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&y, 4)?,
|
||||
[2.9964, 0.8412, 3.9999, 0.0839]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(grad_x, 4)?,
|
||||
[1.0116, 1.0830, 1.0003, 0.6188],
|
||||
);
|
||||
|
||||
// Testing compared to pytorch torch.erf
|
||||
//
|
||||
// import torch
|
||||
// x = torch.tensor([3.0, 1.0, 4.0, 0.15], requires_grad=True)
|
||||
// y = x.erf()
|
||||
// print(y)
|
||||
// loss = y.sum()
|
||||
// loss.backward()
|
||||
// print(x.grad)
|
||||
let y = x.erf()?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||
assert_eq!(test_utils::to_vec1_round(&y, 4)?, [1.0, 0.8427, 1.0, 0.168]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(grad_x, 4)?,
|
||||
[0.0001, 0.4151, 0.0, 1.1033],
|
||||
);
|
||||
|
||||
// Testing compared to pytorch nn.GELU(approximate = 'none')
|
||||
//
|
||||
// import torch
|
||||
// import torch.nn.functional as F
|
||||
// x = torch.tensor([3.0, 1.0, 4.0, 0.15], requires_grad=True)
|
||||
// y = F.gelu(x, approximate='none')
|
||||
// print(y)
|
||||
// loss = y.sum()
|
||||
// loss.backward()
|
||||
// print(x.grad)
|
||||
let y = x.gelu_erf()?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&y, 4)?,
|
||||
[2.9960, 0.8413, 3.9999, 0.0839]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(grad_x, 4)?,
|
||||
[1.0119, 1.0833, 1.0005, 0.6188],
|
||||
);
|
||||
|
||||
// Testing compared to pytorch elu
|
||||
//
|
||||
// import torch
|
||||
// import torch.nn.functional as F
|
||||
// x = torch.tensor([-1.0, 0.0, -2.0, 3.0], requires_grad=True)
|
||||
// y = F.elu(x, alpha=2.0)
|
||||
// print(y)
|
||||
// loss = y.min
|
||||
// loss = y.sum()
|
||||
// loss.backward()
|
||||
// print(x.grad)
|
||||
let elu_x = Var::new(&[-1.0f32, 0., -2., 3.], device)?;
|
||||
let y = elu_x.elu(2.)?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(&elu_x).context("no grad for x")?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&y, 4)?,
|
||||
[-1.2642, 0.0000, -1.7293, 3.0000]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(grad_x, 4)?,
|
||||
[0.7358, 2.0000, 0.2707, 1.0000]
|
||||
);
|
||||
|
||||
// manually checked: see comments
|
||||
let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
|
||||
let y = x.interpolate2d(6, 6)?.reshape(36)?;
|
||||
|
||||
#[rustfmt::skip]
|
||||
let z = Tensor::new(
|
||||
&[
|
||||
1_f32, 02., 03., 04., 05., 06.,
|
||||
07., 08., 09., 10., 11., 12.,
|
||||
13., 14., 15., 16., 17., 18.,
|
||||
19., 20., 21., 22., 23., 24.,
|
||||
25., 26., 27., 28., 29., 30.,
|
||||
31., 32., 33., 34., 35., 36.,
|
||||
],
|
||||
device,
|
||||
)?;
|
||||
// gradient should be
|
||||
// row 1
|
||||
// 1+2+7+8 = 18
|
||||
// 3+4+9+10 = 26
|
||||
// 5+6+11+12 = 34
|
||||
// row 2
|
||||
// 13+14+19+20 = 66
|
||||
// 15+16+21+22 = 74
|
||||
// 17+18+23+24 = 82
|
||||
// row 3
|
||||
// 25+26+31+32 = 114
|
||||
// 27+28+33+34 = 122
|
||||
// 29+30+35+36 = 130
|
||||
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
|
||||
|
||||
let grads = loss.backward()?;
|
||||
|
||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&grad_x.flatten(0, 2)?, 4)?,
|
||||
[[18_f32, 26., 34.], [66., 74., 82.], [114., 122., 130.]]
|
||||
);
|
||||
|
||||
// manually checked: see comments
|
||||
let x = Var::new(&[[[[1f32, 2.], [4., 5.]]]], device)?;
|
||||
let y = x.interpolate2d(6, 6)?.reshape(36)?;
|
||||
|
||||
#[rustfmt::skip]
|
||||
let z = Tensor::new(
|
||||
&[
|
||||
1_f32, 02., 03., 04., 05., 06.,
|
||||
07., 08., 09., 10., 11., 12.,
|
||||
13., 14., 15., 16., 17., 18.,
|
||||
19., 20., 21., 22., 23., 24.,
|
||||
25., 26., 27., 28., 29., 30.,
|
||||
31., 32., 33., 34., 35., 36.,
|
||||
],
|
||||
device,
|
||||
)?;
|
||||
// gradient should be
|
||||
// row 1
|
||||
// 1+2+3+7+8+9+13+14+15 = 72
|
||||
// 4+5+6+10+11+12+16+17+18 = 99
|
||||
// row 2
|
||||
// 19+20+21+25+26+27+31+32+33 = 234
|
||||
// 22+23+24+28+29+30+34+35+36 = 243
|
||||
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
|
||||
|
||||
let grads = loss.backward()?;
|
||||
|
||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&grad_x.flatten(0, 2)?, 4)?,
|
||||
[[72_f32, 99.], [234., 261.]]
|
||||
);

// manually checked: see comments
let x = Var::new(&[[[[1f32, 2.], [4., 5.]], [[6f32, 7.], [8., 9.]]]], device)?;

let y = x.interpolate2d(4, 4)?.reshape(32)?;

#[rustfmt::skip]
let z = Tensor::new(
&[
1_f32, 02., 03., 04.,
05., 06., 07., 08.,
09., 10., 11., 12.,
13., 14., 15., 16.,
17., 18., 19., 20.,
21., 22., 23., 24.,
25., 26., 27., 28.,
29., 30., 31., 32.
],
device,
)?;
// gradient should be
// m1r1
// 1+2+5+6=14
// 3+4+7+8=22
// m1r2
// 9+10+13+14=46
// 11+12+15+16=54
// m2r1
// 17+18+21+22=78
// 19+20+23+24=86
// m2r2
// 25+26+29+30=110
// 27+28+31+32=118
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;

let grads = loss.backward()?;

let grad_x = grads.get(&x).context("no grad for x")?;

assert_eq!(
test_utils::to_vec3_round(&grad_x.flatten(0, 1)?, 4)?,
[[[14_f32, 22.], [46., 54.]], [[78., 86.], [110., 118.]]]
);

// manually checked: see comments
let x = Var::new(
&[[[[1f32, 2.], [4., 5.]]], [[[6f32, 7.], [8., 9.]]]],
device,
)?;

let y = x.interpolate2d(4, 4)?.reshape(32)?;

#[rustfmt::skip]
let z = Tensor::new(
&[
1_f32, 02., 03., 04.,
05., 06., 07., 08.,
09., 10., 11., 12.,
13., 14., 15., 16.,
17., 18., 19., 20.,
21., 22., 23., 24.,
25., 26., 27., 28.,
29., 30., 31., 32.
],
device,
)?;
// gradient should be
// m1r1
// 1+2+5+6=14
// 3+4+7+8=22
// m1r2
// 9+10+13+14=46
// 11+12+15+16=54
// m2r1
// 17+18+21+22=78
// 19+20+23+24=86
// m2r2
// 25+26+29+30=110
// 27+28+31+32=118
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;

let grads = loss.backward()?;

let grad_x = grads.get(&x).context("no grad for x")?;

assert_eq!(
test_utils::to_vec3_round(&grad_x.flatten(0, 1)?, 4)?,
[[[14_f32, 22.], [46., 54.]], [[78., 86.], [110., 118.]]]
);
Ok(())
}

fn binary_grad(device: &Device) -> Result<()> {
let x = Var::new(&[3f32, 1., -4., -1.], device)?;
let x = x.as_tensor();
// leaky relu
let y = x.maximum(&(x * 0.1)?)?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(x.to_vec1::<f32>()?, [3., 1., -4., -1.]);
assert_eq!(y.to_vec1::<f32>()?, [3., 1., -0.4, -0.1]);
assert_eq!(grad_x.to_vec1::<f32>()?, [1., 1., 0.1, 0.1]);
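
// Note (explanatory, not in the original test): maximum(a, b) routes the
// incoming gradient to whichever operand is larger, so x > 0 takes the slope
// of x (1.0) and x < 0 takes the slope of 0.1 * x (0.1), i.e. leaky relu.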

let y = x.minimum(&(x * 0.1)?)?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(y.to_vec1::<f32>()?, [0.3, 0.1, -4., -1.]);
assert_eq!(grad_x.to_vec1::<f32>()?, [0.1, 0.1, 1., 1.]);

// This one is easy to mess up, we want the gradient to be one as it is the identity function.
let y = x.minimum(x)?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(y.to_vec1::<f32>()?, [3., 1., -4., -1.]);
assert_eq!(grad_x.to_vec1::<f32>()?, [1., 1., 1., 1.]);

let x_var = Var::new(&[3f32, 1., -4., -1., 5., 9.], device)?;
let x = x_var.as_tensor();
let y_var = Var::new(&[2f32, 7., 1.], device)?;
let y = y_var.as_tensor();

let ss = x
.reshape((2, 3))?
.slice_scatter0(&y.reshape((1, 3))?, 1)?
.sqr()?;
let grads = ss.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
let grad_y = grads.get(y).context("no grad for y")?;
assert_eq!(ss.to_vec2::<f32>()?, [[9., 1., 16.], [4., 49., 1.]]);
assert_eq!(grad_x.to_vec1::<f32>()?, [6.0, 2.0, -8.0, 0.0, 0.0, 0.0]);
assert_eq!(grad_y.to_vec1::<f32>()?, [4.0, 14.0, 2.0]);
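
// Why these values (explanatory note, not in the original test): sqr() has
// gradient 2 * input, and slice_scatter0 sends the gradient of the overwritten
// row to y rather than x, hence grad_x = [2*3, 2*1, 2*(-4), 0, 0, 0] and
// grad_y = [2*2, 2*7, 2*1].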
Ok(())
}

test_device!(
simple_grad,
simple_grad_cpu,
simple_grad_gpu,
simple_grad_metal
);
test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu, sum_grad_metal);
test_device!(
matmul_grad,
matmul_grad_cpu,
matmul_grad_gpu,
matmul_grad_metal
);
test_device!(
grad_descent,
grad_descent_cpu,
grad_descent_gpu,
grad_descent_metal
);
test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu, unary_grad_metal);
test_device!(
binary_grad,
binary_grad_cpu,
binary_grad_gpu,
binary_grad_metal
);
test_device!(simple_grad, simple_grad_cpu, simple_grad_gpu);
test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu);
test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu);
test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu);
test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu);

@ -1,6 +1,8 @@
use anyhow::Result;
use candle_core::{Device, IndexOp, Tensor};

mod test_utils;

#[test]
fn integer_index() -> Result<()> {
let dev = Device::Cpu;
@ -91,32 +93,3 @@ fn index_3d() -> Result<()> {
assert_eq!(tensor.i((1, .., 3))?.to_vec1::<u32>()?, &[15, 19, 23]);
Ok(())
}

#[test]
fn slice_assign() -> Result<()> {
let dev = Device::Cpu;

let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?;
let src = Tensor::arange(0u32, 2 * 3, &dev)?.reshape((3, 2))?;
let out = tensor.slice_assign(&[1..4, 3..5], &src)?;
assert_eq!(
out.to_vec2::<u32>()?,
&[
[0, 1, 2, 3, 4],
[5, 6, 7, 0, 1],
[10, 11, 12, 2, 3],
[15, 16, 17, 4, 5]
]
);
let out = tensor.slice_assign(&[0..3, 0..2], &src)?;
assert_eq!(
out.to_vec2::<u32>()?,
&[
[0, 1, 2, 3, 4],
[2, 3, 7, 8, 9],
[4, 5, 12, 13, 14],
[15, 16, 17, 18, 19]
]
);
Ok(())
}

@ -1,4 +1,5 @@
use candle::{test_device, Device, IndexOp, Result, Tensor};
mod test_utils;
use candle::{Device, IndexOp, Result, Tensor};
use candle_core as candle;

fn contiguous(device: &Device) -> Result<()> {
@ -49,7 +50,7 @@ fn contiguous(device: &Device) -> Result<()> {
Ok(())
}

test_device!(contiguous, contiguous_cpu, contiguous_gpu, contiguous_metal);
test_device!(contiguous, contiguous_cpu, contiguous_gpu);

#[test]
fn strided_blocks() -> Result<()> {

@ -1,9 +0,0 @@
import numpy as np
x = np.arange(10)

# Write a npy file.
np.save("test.npy", x)

# Write multiple values to a npz file.
values = { "x": x, "x_plus_one": x + 1 }
np.savez("test.npz", **values)

@ -1,4 +1,5 @@
use candle_core::{test_device, test_utils, Device, IndexOp, Result, Tensor};
mod test_utils;
use candle_core::{Device, IndexOp, Result, Tensor};

// https://github.com/huggingface/candle/issues/364
fn avg_pool2d(dev: &Device) -> Result<()> {
@ -6,15 +7,8 @@ fn avg_pool2d(dev: &Device) -> Result<()> {
1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
];
let t = Tensor::from_vec(data, (1, 1, 4, 4), dev)?;
let pool = t.avg_pool2d(2)?.squeeze(0)?.squeeze(0)?;
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
assert_eq!(pool.to_vec2::<f32>()?, [[0.5f32, 1.], [1., 1.]]);

let data: Vec<f32> = vec![
1., 2., 1., 3., 0., 0., 1., 1., 1., 1., 1., 1., 5., 1., 1., 1.,
];
let t = Tensor::from_vec(data, (1, 1, 2, 8), dev)?;
let pool = t.avg_pool2d(2)?.squeeze(0)?.squeeze(0)?;
assert_eq!(pool.to_vec2::<f32>()?, [[5. / 4., 6. / 4., 6. / 4., 1.]]);
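
// Worked example for the first asserted value (a sketch, not from the original
// test): with a (2, 2) kernel and stride (2, 2) on this (2, 8) input, the
// leftmost window holds [1., 2., 1., 1.], whose mean is 5. / 4.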
Ok(())
}

@ -24,12 +18,8 @@ fn max_pool2d(dev: &Device) -> Result<()> {
];
let t = Tensor::from_vec(data, (1, 1, 4, 4), dev)?;

let pool = t.max_pool2d(2)?.squeeze(0)?.squeeze(0)?;
let pool = t.max_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
assert_eq!(pool.to_vec2::<f32>()?, [[2f32, 3.], [5., 1.]]);

let t = t.reshape((1, 1, 2, 8))?;
let pool = t.max_pool2d(2)?.squeeze(0)?.squeeze(0)?;
assert_eq!(pool.to_vec2::<f32>()?, [[2.0, 3.0, 5.0, 1.0]]);
Ok(())
}

@ -53,29 +43,16 @@ fn avg_pool2d_pytorch(dev: &Device) -> Result<()> {
dev,
)?
.reshape((1, 2, 4, 4))?;
let pool = t.avg_pool2d(2)?.squeeze(0)?;
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?;
assert_eq!(
test_utils::to_vec3_round(&pool, 4)?,
test_utils::to_vec3_round(pool, 4)?,
[
[[-1.1926, -0.0395], [0.2688, 0.1871]],
[[0.1835, -0.1606], [0.6249, 0.3217]]
]
);
let pool = t.avg_pool2d(3)?.squeeze(0)?;
assert_eq!(
test_utils::to_vec3_round(&pool, 4)?,
[[[0.085]], [[0.0078]]]
);

let t = t.reshape((1, 1, 4, 8))?;
let pool = t.avg_pool2d(2)?.squeeze(0)?.squeeze(0)?;
assert_eq!(
test_utils::to_vec2_round(&pool, 4)?,
[
[0.7745, 0.0276, -1.6983, 0.12],
[0.3542, 0.1625, 0.4542, -0.0014]
]
);
let pool = t.avg_pool2d((3, 3), (3, 3))?.squeeze(0)?;
assert_eq!(test_utils::to_vec3_round(pool, 4)?, [[[0.085]], [[0.0078]]]);
Ok(())
}

@ -98,17 +75,15 @@ fn upsample_nearest2d(dev: &Device) -> Result<()> {
Ok(())
}

test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu, avg_pool2d_metal);
test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu);
test_device!(
avg_pool2d_pytorch,
avg_pool2d_pytorch_cpu,
avg_pool2d_pytorch_gpu,
avg_pool2d_pytorch_metal
avg_pool2d_pytorch_gpu
);
test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu, max_pool2d_metal);
test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu);
test_device!(
upsample_nearest2d,
upsample_nearest2d_cpu,
upsample_nearest2d_gpu,
upsample_nearest2d_metal
upsample_nearest2d_gpu
);

@ -1,37 +0,0 @@
import torch
from collections import OrderedDict

# Create a trivial tensor
a = torch.tensor([[1,2,3,4], [5,6,7,8]])
o = OrderedDict()
o["test"] = a

# Write a trivial tensor to a pt file
torch.save(o, "test.pt")

############################################################################################################
# Write a trivial tensor to a pt file with a key
torch.save({"model_state_dict": o}, "test_with_key.pt")

############################################################################################################
# Create a tensor with fortran contiguous memory layout
import numpy as np

# Step 1: Create a 3D NumPy array with Fortran order using a range of numbers
# For example, creating a 2x3x4 array
array_fortran = np.asfortranarray(np.arange(1, 2*3*4 + 1).reshape(2, 3, 4))

# Verify the memory order
print("Is Fortran contiguous (F order):", array_fortran.flags['F_CONTIGUOUS']) # Should be True
print("Is C contiguous (C order):", array_fortran.flags['C_CONTIGUOUS']) # Should be False

# Step 2: Convert the NumPy array to a PyTorch tensor
tensor_fortran = torch.from_numpy(array_fortran)

# Verify the tensor layout
print("Tensor stride:", tensor_fortran.stride()) # Stride will reflect the Fortran memory layout

# Step 3: Save the PyTorch tensor to a .pth file
torch.save({"tensor_fortran": tensor_fortran}, 'fortran_tensor_3d.pth')

print("3D Tensor saved with Fortran layout.")

@ -1,31 +0,0 @@
/// Regression test for pth files not loading on Windows.
#[test]
fn test_pth() {
let tensors = candle_core::pickle::PthTensors::new("tests/test.pt", None).unwrap();
tensors.get("test").unwrap().unwrap();
}

#[test]
fn test_pth_with_key() {
let tensors =
candle_core::pickle::PthTensors::new("tests/test_with_key.pt", Some("model_state_dict"))
.unwrap();
tensors.get("test").unwrap().unwrap();
}

#[test]
fn test_pth_fortran_contiguous() {
let tensors =
candle_core::pickle::PthTensors::new("tests/fortran_tensor_3d.pth", None).unwrap();
let tensor = tensors.get("tensor_fortran").unwrap().unwrap();

assert_eq!(tensor.dims3().unwrap(), (2, 3, 4));

assert_eq!(
tensor.to_vec3::<i64>().unwrap(),
[
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]
]
);
}
File diff suppressed because it is too large
@ -1,24 +0,0 @@
use candle_core::{DType, Result, Tensor};

#[test]
fn npy() -> Result<()> {
let npy = Tensor::read_npy("tests/test.npy")?;
assert_eq!(
npy.to_dtype(DType::U8)?.to_vec1::<u8>()?,
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
);
Ok(())
}

#[test]
fn npz() -> Result<()> {
let npz = Tensor::read_npz("tests/test.npz")?;
assert_eq!(npz.len(), 2);
assert_eq!(npz[0].0, "x");
assert_eq!(npz[1].0, "x_plus_one");
assert_eq!(
npz[1].1.to_dtype(DType::U8)?.to_vec1::<u8>()?,
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
);
Ok(())
}

@ -1,4 +1,5 @@
use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor, D};
mod test_utils;
use candle_core::{DType, Device, IndexOp, Result, Tensor};

fn zeros(device: &Device) -> Result<()> {
let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
@ -8,58 +9,6 @@ fn zeros(device: &Device) -> Result<()> {
Ok(())
}

fn ones(device: &Device) -> Result<()> {
assert_eq!(
Tensor::ones((2, 3), DType::U8, device)?.to_vec2::<u8>()?,
[[1, 1, 1], [1, 1, 1]],
);
assert_eq!(
Tensor::ones((2, 3), DType::U32, device)?.to_vec2::<u32>()?,
[[1, 1, 1], [1, 1, 1]],
);
assert_eq!(
Tensor::ones((2, 3), DType::I64, device)?.to_vec2::<i64>()?,
[[1, 1, 1], [1, 1, 1]],
);
assert_eq!(
Tensor::ones((2, 3), DType::F32, device)?.to_vec2::<f32>()?,
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
);
assert_eq!(
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
);
Ok(())
}

fn full(device: &Device) -> Result<()> {
assert_eq!(
Tensor::full(42u32, (2, 3), device)?.to_vec2::<u32>()?,
[[42, 42, 42], [42, 42, 42]],
);
Ok(())
}

fn arange(device: &Device) -> Result<()> {
assert_eq!(
Tensor::arange(0u8, 5u8, device)?.to_vec1::<u8>()?,
[0, 1, 2, 3, 4],
);
assert_eq!(
Tensor::arange_step(0u8, 5u8, 2, device)?.to_vec1::<u8>()?,
[0, 2, 4],
);
assert_eq!(
Tensor::arange_step(0u8, 5u8, 3, device)?.to_vec1::<u8>()?,
[0, 3],
);
assert_eq!(
Tensor::arange_step(5i64, 0i64, -1, device)?.to_vec1::<i64>()?,
[5, 4, 3, 2, 1],
);
Ok(())
}

fn add_mul(device: &Device) -> Result<()> {
let tensor = Tensor::new(&[3f32, 1., 4.], device)?;
let dim1 = tensor.dims1()?;
@ -85,71 +34,12 @@ fn tensor_2d(device: &Device) -> Result<()> {
Ok(())
}

fn clamp(device: &Device) -> Result<()> {
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let tensor = Tensor::new(data, device)?;
let tensor = tensor.clamp(1.5, 6.2)?;
assert_eq!(
tensor.to_vec2::<f32>()?,
[[3.0, 1.5, 4.0, 1.5, 5.0], [2.0, 1.5, 6.2, 6.2, 2.0]],
);
Ok(())
}

fn unary_op(device: &Device) -> Result<()> {
let data = &[[-3f32, 1., 4., -0.1, 0.5], [2.7, -1.8, -0.28, 1.8, 2.8]];
let tensor = Tensor::new(data, device)?;
assert_eq!(
test_utils::to_vec2_round(&tensor.gelu()?, 4)?,
[
[-0.0036, 0.8412, 3.9999, -0.046, 0.3457],
[2.6911, -0.0647, -0.1091, 1.7353, 2.7933]
]
);
assert_eq!(
test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?,
[
[-0.004, 0.8413, 3.9999, -0.046, 0.3457],
[2.6906, -0.0647, -0.1091, 1.7353, 2.7928]
]
);
assert_eq!(
test_utils::to_vec2_round(&tensor.erf()?, 4)?,
[
[-1.0, 0.8427, 1.0, -0.1125, 0.5205],
[0.9999, -0.9891, -0.3079, 0.9891, 0.9999]
]
);
assert_eq!(
test_utils::to_vec2_round(&tensor.ceil()?, 4)?,
[[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -1.0, -0.0, 2.0, 3.0]]
);
assert_eq!(
test_utils::to_vec2_round(&tensor.floor()?, 4)?,
[[-3.0, 1.0, 4.0, -1.0, 0.0], [2.0, -2.0, -1.0, 1.0, 2.0]]
);
assert_eq!(
test_utils::to_vec2_round(&tensor.round()?, 4)?,
[[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -2.0, -0.0, 2.0, 3.0]]
);
let tensor = Tensor::new(&[2997.9246, 314.15926f32], device)?;
assert_eq!(
test_utils::to_vec1_round(&tensor.round_to(2)?, 4)?,
[2997.92, 314.16]
);
assert_eq!(
test_utils::to_vec1_round(&tensor.round_to(-2)?, 4)?,
[3000.0, 300.]
);
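
// Note (explanatory, not in the original test): round_to(digits) rounds to
// `digits` decimal places, so a negative value rounds to the left of the
// decimal point: round_to(-2) maps 2997.9246 to 3000.0 and 314.15926 to 300.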
Ok(())
}

fn binary_op(device: &Device) -> Result<()> {
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let tensor1 = Tensor::new(data, device)?;
let tensor = Tensor::new(data, device)?;
let data2 = &[[5f32, 5., 5., 5., 5.], [2., 1., 7., 8., 2.]];
let tensor2 = Tensor::new(data2, device)?;
let tensor = (&tensor1 + (&tensor1 * &tensor1)? / (&tensor1 + &tensor2))?;
let tensor = (&tensor + (&tensor * &tensor)? / (&tensor + &tensor2))?;
let dims = tensor.dims2()?;
assert_eq!(dims, (2, 5));
let content: Vec<Vec<f32>> = tensor.to_vec2()?;
@ -159,17 +49,6 @@ fn binary_op(device: &Device) -> Result<()> {
let tensor = (&tensor - &tensor)?;
let content: Vec<Vec<f32>> = tensor.to_vec2()?;
assert_eq!(content[0], [0., 0., 0., 0., 0.]);

let min = tensor1.minimum(&(&tensor2 * 0.5)?)?;
let max = tensor1.maximum(&(&tensor2 * 0.5)?)?;
assert_eq!(
min.to_vec2::<f32>()?,
[[2.5, 1.0, 2.5, 1.0, 2.5], [1.0, 0.5, 3.5, 4.0, 1.0]],
);
assert_eq!(
max.to_vec2::<f32>()?,
[[3.0, 2.5, 4.0, 2.5, 5.0], [2.0, 1.0, 7.0, 8.0, 2.0]]
);
Ok(())
}

@ -188,22 +67,6 @@ fn transpose(device: &Device) -> Result<()> {
Ok(())
}

fn var(device: &Device) -> Result<()> {
// Values taken from https://pytorch.org/docs/stable/generated/torch.var.html
let data = &[
[0.2035f32, 1.2959, 1.8101, -0.4644],
[1.5027, -0.3270, 0.5905, 0.6538],
[-1.5745, 1.3330, -0.5596, -0.6548],
[0.1264, -0.5080, 1.6420, 0.1992],
];
let tensor = Tensor::new(data, device)?;
assert_eq!(
test_utils::to_vec2_round(&tensor.var_keepdim(1)?, 4)?,
&[[1.0631], [0.559], [1.4893], [0.8258]]
);
Ok(())
}

fn sum(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, device)?;
@ -717,30 +580,6 @@ fn index_select(device: &Device) -> Result<()> {
hs.to_vec2::<f32>()?,
&[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
);
// Prior to https://github.com/huggingface/candle/pull/1022
// There would be a bug where the last values in the result tensor would be set to 0.
let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
let hs = t.index_select(&ids, 0)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[0.0, 1.0, 2.0],
[6.0, 7.0, 8.0],
[3.0, 4.0, 5.0],
[0.0, 1.0, 2.0],
[6.0, 7.0, 8.0],
[3.0, 4.0, 5.0],
]
);

// Test when selecting dim > 0 with ids size different from elem count of
// target dim in source/input.
let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
let hs = t.index_select(&ids, 1)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);

Ok(())
}

@ -787,48 +626,6 @@ fn index_add(device: &Device) -> Result<()> {
Ok(())
}

fn slice_scatter(device: &Device) -> Result<()> {
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
assert_eq!(
t.to_vec2::<f32>()?,
&[
[0.0, 1.0, 2.0],
[3.0, 4.0, 5.0],
[6.0, 7.0, 8.0],
[9.0, 10.0, 11.0]
]
);
let src = Tensor::arange(100f32, 106f32, device)?.reshape((2, 3))?;
assert_eq!(
t.slice_scatter0(&src, 0)?.to_vec2::<f32>()?,
&[
[100.0, 101.0, 102.0],
[103.0, 104.0, 105.0],
[6.0, 7.0, 8.0],
[9.0, 10.0, 11.0]
]
);
assert_eq!(
t.slice_scatter0(&src, 1)?.to_vec2::<f32>()?,
&[
[0.0, 1.0, 2.0],
[100.0, 101.0, 102.0],
[103.0, 104.0, 105.0],
[9.0, 10.0, 11.0]
]
);
assert_eq!(
t.slice_scatter0(&src, 2)?.to_vec2::<f32>()?,
&[
[0.0, 1.0, 2.0],
[3.0, 4.0, 5.0],
[100.0, 101.0, 102.0],
[103.0, 104.0, 105.0],
]
);
Ok(())
}

fn scatter_add(device: &Device) -> Result<()> {
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
assert_eq!(
@ -950,25 +747,6 @@ fn matmul(device: &Device) -> Result<()> {
Ok(())
}

fn broadcast_matmul(device: &Device) -> Result<()> {
let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?;
let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?;
let out = lhs.broadcast_matmul(&rhs)?;
assert_eq!(out.dims(), &[3, 6, 4, 2]);
for idx1 in 0..3 {
for idx2 in 0..6 {
let out = out.i((idx1, idx2))?;
let lhs = lhs.i((idx1, 0))?;
let rhs = rhs.i(idx2)?;
let out2 = lhs.matmul(&rhs);
let sum_diff2 = (out - out2)?.sqr()?.sum_all()?;
// With cuda, we see errors of up to ~1e-12.
assert!(sum_diff2.to_vec0::<f32>()? < 1e-6)
}
}
Ok(())
}
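
// Shape bookkeeping for the test above (a sketch, not from the original test):
// the batch dims (3, 1) of the lhs broadcast against the (6,) batch dim of the
// rhs to give (3, 6), and each (4, 5) x (5, 2) matmul yields (4, 2), hence the
// asserted output shape [3, 6, 4, 2].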

fn broadcasting(device: &Device) -> Result<()> {
let t1 = Tensor::arange(0f32, 24f32, device)?.reshape((4, 2, 3))?;
let t2 = Tensor::new(&[100f32, 200f32], device)?;
@ -1070,69 +848,27 @@ fn broadcasting(device: &Device) -> Result<()> {
Ok(())
}

fn randn(device: &Device) -> Result<()> {
let tensor = Tensor::randn(0f32, 1f32, (5, 3), device)?;
assert_eq!(tensor.dims(), [5, 3]);
let tensor = Tensor::rand(0f32, 1f32, (5, 3), device)?;
assert_eq!(tensor.dims(), [5, 3]);
Ok(())
}

test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
test_device!(full, full_cpu, full_gpu, full_metal);
test_device!(arange, arange_cpu, arange_gpu, arange_metal);
test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal);
test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal);
test_device!(cat, cat_cpu, cat_gpu, cat_metal);
test_device!(sum, sum_cpu, sum_gpu, sum_metal);
test_device!(min, min_cpu, min_gpu, min_metal);
test_device!(max, max_cpu, max_gpu, max_metal);
test_device!(argmax, argmax_cpu, argmax_gpu, argmax_metal);
test_device!(argmin, argmin_cpu, argmin_gpu, argmin_metal);
test_device!(transpose, transpose_cpu, transpose_gpu, transpose_metal);
test_device!(unary_op, unary_op_cpu, unary_op_gpu, unary_op_metal);
test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal);
test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal);
test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal);
test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
test_device!(
broadcast_matmul,
broadcast_matmul_cpu,
broadcast_matmul_gpu,
broadcast_matmul_metal
);
test_device!(
broadcasting,
broadcasting_cpu,
broadcasting_gpu,
broadcasting_metal
);
test_device!(
index_select,
index_select_cpu,
index_select_gpu,
index_select_metal
);
test_device!(index_add, index_add_cpu, index_add_gpu, index_add_metal);
test_device!(gather, gather_cpu, gather_gpu, gather_metal);
test_device!(
scatter_add,
scatter_add_cpu,
scatter_add_gpu,
scatter_add_metal
);
test_device!(
slice_scatter,
slice_scatter_cpu,
slice_scatter_gpu,
slice_scatter_metal
);
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
test_device!(var, var_cpu, var_gpu, var_metal);
test_device!(zeros, zeros_cpu, zeros_gpu);
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
test_device!(narrow, narrow_cpu, narrow_gpu);
test_device!(broadcast, broadcast_cpu, broadcast_gpu);
test_device!(cat, cat_cpu, cat_gpu);
test_device!(sum, sum_cpu, sum_gpu);
test_device!(min, min_cpu, min_gpu);
test_device!(max, max_cpu, max_gpu);
test_device!(argmax, argmax_cpu, argmax_gpu);
test_device!(argmin, argmin_cpu, argmin_gpu);
test_device!(transpose, transpose_cpu, transpose_gpu);
test_device!(binary_op, binary_op_cpu, binary_op_gpu);
test_device!(embeddings, embeddings_cpu, embeddings_gpu);
test_device!(cmp, cmp_cpu, cmp_gpu);
test_device!(matmul, matmul_cpu, matmul_gpu);
test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu);
test_device!(index_select, index_select_cpu, index_select_gpu);
test_device!(index_add, index_add_cpu, index_add_gpu);
test_device!(gather, gather_cpu, gather_gpu);
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);

// There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381
@ -1144,124 +880,3 @@ fn randn_hasneg() -> Result<()> {
}
Ok(())
}

#[test]
fn pad_with_same() -> Result<()> {
let t = Tensor::arange(1f32, 5f32, &Device::Cpu)?.reshape((2, 2))?;
let t0 = t.pad_with_same(0, 1, 2)?;
assert_eq!(
t0.to_vec2::<f32>()?,
[[1.0, 2.0], [1.0, 2.0], [3.0, 4.0], [3.0, 4.0], [3.0, 4.0]]
);
let t1 = t.pad_with_same(1, 1, 2)?;
assert_eq!(
t1.to_vec2::<f32>()?,
[[1.0, 1.0, 2.0, 2.0, 2.0], [3.0, 3.0, 4.0, 4.0, 4.0]]
);
Ok(())
}

#[test]
fn i64_abs() -> Result<()> {
let t = Tensor::new(&[-42i64, 1337], &Device::Cpu)?;
let t = t.abs()?;
assert_eq!(t.to_vec1::<i64>()?, [42, 1337]);
Ok(())
}

#[test]
fn tril_triu_eye() -> Result<()> {
let t = Tensor::tril2(4, DType::F32, &Device::Cpu)?;
assert_eq!(
t.to_vec2::<f32>()?,
[
[1.0, 0.0, 0.0, 0.0],
[1.0, 1.0, 0.0, 0.0],
[1.0, 1.0, 1.0, 0.0],
[1.0, 1.0, 1.0, 1.0]
],
);
let t = Tensor::triu2(4, DType::F32, &Device::Cpu)?;
assert_eq!(
t.to_vec2::<f32>()?,
[
[1.0, 1.0, 1.0, 1.0],
[0.0, 1.0, 1.0, 1.0],
[0.0, 0.0, 1.0, 1.0],
[0.0, 0.0, 0.0, 1.0]
]
);
let t = Tensor::eye(4, DType::F32, &Device::Cpu)?;
assert_eq!(
t.to_vec2::<f32>()?,
[
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0]
]
);
Ok(())
}

#[test]
fn cumsum() -> Result<()> {
let t = &[3f32, 1., 4., 1., 5.];
let t = Tensor::new(t, &Device::Cpu)?;
assert_eq!(t.cumsum(0)?.to_vec1::<f32>()?, [3., 4., 8., 9., 14.]);
let t = t.unsqueeze(1)?;
assert_eq!(
t.cumsum(0)?.to_vec2::<f32>()?,
[[3.0], [4.0], [8.0], [9.0], [14.0]]
);
assert_eq!(
t.cumsum(1)?.to_vec2::<f32>()?,
[[3.0], [1.0], [4.0], [1.0], [5.0]]
);
let t = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let t = Tensor::new(t, &Device::Cpu)?;
assert_eq!(
t.cumsum(1)?.to_vec2::<f32>()?,
[[3.0, 4.0, 8.0, 9.0, 14.0], [2.0, 3.0, 10.0, 18.0, 20.0]],
);
assert_eq!(
t.cumsum(0)?.to_vec2::<f32>()?,
[[3.0, 1.0, 4.0, 1.0, 5.0], [5.0, 2.0, 11.0, 9.0, 7.0]]
);
Ok(())
}

/// A helper function for floating point comparison. Both a and b must be 1D tensors
/// that contain the same number of elements.
/// The assertion passes if the absolute difference of each pair of elements of a and b is smaller than epsilon.
fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> {
let a_vec: Vec<f64> = a.to_vec1()?;
let b_vec: Vec<f64> = b.to_vec1()?;

assert_eq!(a_vec.len(), b_vec.len());
for (a, b) in a_vec.iter().zip(b_vec.iter()) {
assert!((a - b).abs() < epsilon);
}
Ok(())
}

#[test]
fn log_sum_exp() -> Result<()> {
let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
let output = input.log_sum_exp(D::Minus1)?;
// The expected values were obtained from pytorch.
let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?;
assert_close(&output, &expected, 0.00001)?;
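
// Worked check (a sketch, not part of the original test):
// log(exp(1) + exp(2) + exp(3)) = log(30.1929) ~= 3.4076 and
// log(exp(4) + exp(5) + exp(6)) = log(606.4381) ~= 6.4076.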
Ok(())
}

#[test]
fn pow() -> Result<()> {
let lhs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
let rhs = (&lhs - 2.)?;
let res = lhs.pow(&rhs)?;
assert_eq!(
test_utils::to_vec2_round(&res, 4)?,
[[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0001]]
);
Ok(())
}

Binary file not shown.
Binary file not shown.
Binary file not shown.

@ -1,10 +1,15 @@
use crate::{Result, Tensor};
#![allow(dead_code)]

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle_core::{Result, Tensor};

#[macro_export]
macro_rules! test_device {
// TODO: Switch to generating the two last arguments automatically once concat_idents is
// stable. https://github.com/rust-lang/rust/issues/29599
($fn_name: ident, $test_cpu: ident, $test_cuda: ident, $test_metal: ident) => {
($fn_name: ident, $test_cpu: ident, $test_cuda: ident) => {
#[test]
fn $test_cpu() -> Result<()> {
$fn_name(&Device::Cpu)
@ -15,21 +20,9 @@ macro_rules! test_device {
fn $test_cuda() -> Result<()> {
$fn_name(&Device::new_cuda(0)?)
}

#[cfg(feature = "metal")]
#[test]
fn $test_metal() -> Result<()> {
$fn_name(&Device::new_metal(0)?)
}
};
}

pub fn to_vec0_round(t: &Tensor, digits: i32) -> Result<f32> {
let b = 10f32.powi(digits);
let t = t.to_vec0::<f32>()?;
Ok(f32::round(t * b) / b)
}

pub fn to_vec1_round(t: &Tensor, digits: i32) -> Result<Vec<f32>> {
let b = 10f32.powi(digits);
let t = t.to_vec1::<f32>()?;
@ -47,7 +40,7 @@ pub fn to_vec2_round(t: &Tensor, digits: i32) -> Result<Vec<Vec<f32>>> {
Ok(t)
}

pub fn to_vec3_round(t: &Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
pub fn to_vec3_round(t: Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
let b = 10f32.powi(digits);
let t = t.to_vec3::<f32>()?;
let t = t
Binary file not shown.

@ -11,13 +11,10 @@ readme = "README.md"

[dependencies]
byteorder = { workspace = true }
candle = { workspace = true }
candle-nn = { workspace = true }
candle = { path = "../candle-core", version = "0.1.1", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.1.1" }
hf-hub = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
memmap2 = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
rand = { workspace = true }
thiserror = { workspace = true }
parquet = { workspace = true }
image = { workspace = true }

@ -1,73 +0,0 @@
use hf_hub::{
api::sync::{Api, ApiRepo},
Repo, RepoType,
};
use parquet::file::reader::SerializedFileReader;
use std::fs::File;

#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("ApiError : {0}")]
ApiError(#[from] hf_hub::api::sync::ApiError),

#[error("IoError : {0}")]
IoError(#[from] std::io::Error),

#[error("ParquetError : {0}")]
ParquetError(#[from] parquet::errors::ParquetError),
}

fn sibling_to_parquet(
rfilename: &str,
repo: &ApiRepo,
) -> Result<SerializedFileReader<File>, Error> {
let local = repo.get(rfilename)?;
let file = File::open(local)?;
let reader = SerializedFileReader::new(file)?;
Ok(reader)
}

pub fn from_hub(api: &Api, dataset_id: String) -> Result<Vec<SerializedFileReader<File>>, Error> {
let repo = Repo::with_revision(
dataset_id,
RepoType::Dataset,
"refs/convert/parquet".to_string(),
);
let repo = api.repo(repo);
let info = repo.info()?;

let files: Result<Vec<_>, _> = info
.siblings
.into_iter()
.filter_map(|s| -> Option<Result<_, _>> {
let filename = s.rfilename;
if filename.ends_with(".parquet") {
let reader_result = sibling_to_parquet(&filename, &repo);
Some(reader_result)
} else {
None
}
})
.collect();
let files = files?;

Ok(files)
}

#[cfg(test)]
mod tests {
use super::*;
use parquet::file::reader::FileReader;

#[test]
fn test_dataset() {
let api = Api::new().unwrap();
let files = from_hub(
&api,
"hf-internal-testing/dummy_image_text_data".to_string(),
)
.unwrap();
assert_eq!(files.len(), 1);
assert_eq!(files[0].metadata().file_metadata().num_rows(), 20);
}
}

@ -1,6 +1,5 @@
//! Datasets & Dataloaders for Candle
pub mod batcher;
pub mod hub;
pub mod nlp;
pub mod vision;

Some files were not shown because too many files have changed in this diff