Mirror of https://github.com/huggingface/candle.git (synced 2025-06-18 03:28:50 +00:00)

Compare commits: 0.5.1...matmul-slo — 3 commits

- 69c1fb1ee8
- c55ebaf477
- 4c91dd2ff4
.github/dependabot.yml (vendored) — 7 deleted lines

@@ -1,7 +0,0 @@
-version: 2
-updates:
-  - package-ecosystem: "cargo"
-    directory: "/"
-    schedule:
-      interval: "weekly"
-    open-pull-requests-limit: 5
.github/workflows/ci_cuda.yaml (vendored) — 72 changed lines

@@ -5,15 +5,47 @@ on:
   pull_request:

 jobs:
+  start-runner:
+    name: Start self-hosted EC2 runner
+    runs-on: ubuntu-latest
+    env:
+      AWS_REGION: us-east-1
+      EC2_AMI_ID: ami-03cfed9ea28f4b002
+      EC2_INSTANCE_TYPE: g5.xlarge
+      EC2_SUBNET_ID: subnet-931b34f5,subnet-ecb993cd,subnet-943dc2d8,subnet-45371f1a,subnet-ee93e0df,subnet-fddc3dfc
+      EC2_SECURITY_GROUP: sg-030175c435ac141d6
+    outputs:
+      label: ${{ steps.start-ec2-runner.outputs.label }}
+      ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }}
+    steps:
+      - name: Configure AWS credentials
+        uses: aws-actions/configure-aws-credentials@v1
+        with:
+          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
+          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+          aws-region: ${{ env.AWS_REGION }}
+      - name: Start EC2 runner
+        id: start-ec2-runner
+        uses: philschmid/philschmid-ec2-github-runner@main
+        with:
+          mode: start
+          github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
+          ec2-image-id: ${{ env.EC2_AMI_ID }}
+          ec2-instance-type: ${{ env.EC2_INSTANCE_TYPE }}
+          subnet-id: ${{ env.EC2_SUBNET_ID }}
+          security-group-id: ${{ env.EC2_SECURITY_GROUP }}
+          aws-resource-tags: > # optional, requires additional permissions
+            [
+              {"Key": "Name", "Value": "ec2-tgi-github-runner"},
+              {"Key": "GitHubRepository", "Value": "${{ github.repository }}"}
+            ]
+
   test-cuda:
     concurrency:
       group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
       cancel-in-progress: true
-    runs-on: [single-gpu, nvidia-gpu, t4, ci]
-    container:
-      image: nvidia/cuda:12.3.1-devel-ubuntu22.04
-      options: --gpus 0
-    if: ${{ github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }}
+    needs: start-runner # required to start the main job when the runner is ready
+    runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
     permissions:
       contents: write
       packages: write
@@ -24,10 +56,32 @@ jobs:
     steps:
       - name: Checkout repository
        uses: actions/checkout@v3
-      - name: Install dependencies
-        run: apt-get update && apt install curl build-essential libssl-dev protobuf-compiler pkg-config -y
       - name: Install Rust Stable
-        uses: actions-rust-lang/setup-rust-toolchain@v1
+        run: curl https://sh.rustup.rs -sSf | sh -s -- -y
       - uses: Swatinem/rust-cache@v2
+      - run: apt-get update -y && apt-get install libssl-dev -y
       - name: Test (cuda)
-        run: cargo test --features cuda
+        run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda
+
+  stop-runner:
+    name: Stop self-hosted EC2 runner
+    needs:
+      - start-runner
+      - test-cuda
+    runs-on: ubuntu-latest
+    env:
+      AWS_REGION: us-east-1
+    if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs
+    steps:
+      - name: Configure AWS credentials
+        uses: aws-actions/configure-aws-credentials@v1
+        with:
+          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
+          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+          aws-region: ${{ env.AWS_REGION }}
+      - name: Stop EC2 runner
+        uses: philschmid/philschmid-ec2-github-runner@main
+        with:
+          mode: stop
+          github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
+          label: ${{ needs.start-runner.outputs.label }}
+          ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}
.github/workflows/maturin.yml (vendored) — BIN
Binary file not shown.
.github/workflows/python.yml (vendored) — 68 deleted lines

@@ -1,68 +0,0 @@
-name: PyO3-CI
-
-on:
-  workflow_dispatch:
-  push:
-    branches:
-      - main
-    paths:
-      - candle-pyo3/**
-  pull_request:
-    paths:
-      - candle-pyo3/**
-
-jobs:
-  build_and_test:
-    name: Check everything builds & tests
-    runs-on: ${{ matrix.os }}
-    strategy:
-      matrix:
-        os: [ubuntu-latest] # For now, only test on Linux
-    steps:
-      - name: Checkout repository
-        uses: actions/checkout@v2
-
-      - name: Install Rust
-        uses: actions-rs/toolchain@v1
-        with:
-          toolchain: stable
-
-      - name: Install Python
-        uses: actions/setup-python@v4
-        with:
-          python-version: 3.11
-          architecture: "x64"
-
-      - name: Cache Cargo Registry
-        uses: actions/cache@v1
-        with:
-          path: ~/.cargo/registry
-          key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }}
-
-      - name: Install Protoc
-        uses: arduino/setup-protoc@v2
-        with:
-          version: "25.0"
-          repo-token: ${{ secrets.GITHUB_TOKEN }}
-
-      - name: Install
-        working-directory: ./candle-pyo3
-        run: |
-          python -m venv .env
-          source .env/bin/activate
-          pip install -U pip
-          pip install pytest maturin black
-          python -m maturin develop -r --features onnx
-
-      - name: Check style
-        working-directory: ./candle-pyo3
-        run: |
-          source .env/bin/activate
-          python stub.py --check
-          black --check .
-
-      - name: Run tests
-        working-directory: ./candle-pyo3
-        run: |
-          source .env/bin/activate
-          python -m pytest -s -v tests
@@ -63,7 +63,7 @@ This documents the main changes to the `candle` crate.
   [760](https://github.com/huggingface/candle/pull/760).
 - Add the Segment-Anything Model (SAM) as an example
   [773](https://github.com/huggingface/candle/pull/773).
-- TinyViT backbone for the segment anything example
+- TinyViT backbone for the segemnt anything example
   [787](https://github.com/huggingface/candle/pull/787).
 - Shape with holes support
   [770](https://github.com/huggingface/candle/pull/770).
Cargo.toml — 53 changed lines

@@ -7,20 +7,20 @@ members = [
     "candle-nn",
     "candle-pyo3",
     "candle-transformers",
-    "candle-wasm-examples/*",
+    "candle-wasm-examples/llama2-c",
+    "candle-wasm-examples/segment-anything",
+    "candle-wasm-examples/whisper",
+    "candle-wasm-examples/yolo",
+    "candle-wasm-examples/bert",
+    "candle-wasm-examples/phi",
+    "candle-wasm-examples/t5",
     "candle-wasm-tests",
-    "tensor-tools",
 ]
-exclude = [
-    "candle-flash-attn",
-    "candle-kernels",
-    "candle-metal-kernels",
-    "candle-onnx",
-]
+exclude = ["candle-flash-attn", "candle-kernels"]
 resolver = "2"

 [workspace.package]
-version = "0.5.1"
+version = "0.3.0"
 edition = "2021"
 description = "Minimalist ML framework."
 repository = "https://github.com/huggingface/candle"
@@ -29,50 +29,39 @@ categories = ["science"]
 license = "MIT OR Apache-2.0"

 [workspace.dependencies]
-ab_glyph = "0.2.23"
 accelerate-src = { version = "0.3.2" }
 anyhow = { version = "1", features = ["backtrace"] }
 byteorder = "1.4.3"
-candle = { path = "./candle-core", package = "candle-core", version = "0.5.1" }
-candle-datasets = { path = "./candle-datasets", version = "0.5.1" }
-candle-flash-attn = { path = "./candle-flash-attn", version = "0.5.1" }
-candle-kernels = { path = "./candle-kernels", version = "0.5.1" }
-candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.5.1" }
-candle-nn = { path = "./candle-nn", version = "0.5.1" }
-candle-onnx = { path = "./candle-onnx", version = "0.5.1" }
-candle-transformers = { path = "./candle-transformers", version = "0.5.1" }
 clap = { version = "4.2.4", features = ["derive"] }
-criterion = { version = "0.5.1", default-features=false }
-cudarc = { version = "0.11.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
-fancy-regex = "0.13.0"
-gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
+cudarc = { version = "0.9.14", features = ["f16"] }
+# TODO: Switch back to the official gemm implementation once it has caught up.
+gemm = { version = "0.16.0", package = "candle-gemm" }
 hf-hub = "0.3.0"
 half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
-hound = "3.5.1"
-image = { version = "0.25.0", default-features = false, features = ["jpeg", "png"] }
-imageproc = { version = "0.24.0", default-features = false }
+image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
+imageproc = { version = "0.23.0", default-features = false }
 intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
 libc = { version = "0.2.147" }
 log = "0.4"
-memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
+memmap2 = { version = "0.7.1", features = ["stable_deref_trait"] }
 num_cpus = "1.15.0"
 num-traits = "0.2.15"
-parquet = { version = "51.0.0" }
+parquet = { version = "45.0.0" }
 rand = "0.8.5"
 rand_distr = "0.4.3"
 rayon = "1.7.0"
-safetensors = "0.4.1"
+rusttype = { version = "0.9", default-features = false }
+safetensors = "0.3.1"
 serde = { version = "1.0.171", features = ["derive"] }
-serde_plain = "1.0.2"
 serde_json = "1.0.99"
 thiserror = "1"
-tokenizers = { version = "0.19.1", default-features = false }
+tokenizers = { version = "0.13.4", default-features = false }
 tracing = "0.1.37"
 tracing-chrome = "0.7.1"
 tracing-subscriber = "0.3.7"
+wav = "1.0.0"
 yoke = { version = "0.7.2", features = ["derive"] }
-zip = { version = "1.1.1", default-features = false }
-metal = { version = "0.27.0", features = ["mps"]}
+zip = { version = "0.6.6", default-features = false }

 [profile.release-with-debug]
 inherits = "release"
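Both sides of the workspace manifest above expose the same `candle-core` crate, just at different versions. For orientation, here is a minimal sketch of a downstream program using the `candle-core` API that appears throughout this diff (the tensor values and the `anyhow` dependency are illustrative, not part of the diff):

```rust
// A minimal sketch of a crate depending on candle-core (either version).
// Assumes `candle-core` and `anyhow` are listed under [dependencies].
use anyhow::Result;
use candle_core::{Device, Tensor};

fn main() -> Result<()> {
    let a = Tensor::new(&[[0.0f32, 1.0], [2.0, 3.0]], &Device::Cpu)?;
    let b = Tensor::new(&[[1.0f32, 0.0], [0.0, 1.0]], &Device::Cpu)?;
    // Multiplying by the identity leaves `a` unchanged.
    let c = a.matmul(&b)?;
    assert_eq!(c.to_vec2::<f32>()?, [[0.0, 1.0], [2.0, 3.0]]);
    Ok(())
}
```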
README.md — 134 changed lines

@@ -51,41 +51,22 @@ For more advanced examples, please have a look at the following section.
 These online demos run entirely in your browser:
 - [yolo](https://huggingface.co/spaces/lmz/candle-yolo): pose estimation and
   object recognition.
-- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): speech recognition.
+- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): text to speech.
 - [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2): text generation.
 - [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm): text generation.
-- [Phi-1.5, and Phi-2](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation.
+- [Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation.
 - [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm): Image segmentation.
-- [BLIP](https://huggingface.co/spaces/radames/Candle-BLIP-Image-Captioning): image captioning.

 We also provide a some command line based examples using state of the art models:

-- [LLaMA v1, v2, and v3](./candle-examples/examples/llama/): general LLM, includes
-  the SOLAR-10.7B variant.
+- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM.
 - [Falcon](./candle-examples/examples/falcon/): general LLM.
-- [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google Deepmind.
-- [RecurrentGemma](./candle-examples/examples/recurrent-gemma/): 2b and 7b
-  Griffin based models from Google that mix attention with a RNN like state.
-- [Phi-1, Phi-1.5, Phi-2, and Phi-3](./candle-examples/examples/phi/): 1.3b,
-  2.7b, and 3.8b general LLMs with performance on par with 7b models.
+- [Phi-v1.5](./candle-examples/examples/phi/): a 1.3b general LLM with performance on par with LLaMA-v2 7b.
 - [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
-  pre-trained on 1T tokens of English and code datasets. Also supports
-  StableLM-2, a 1.6b LLM trained on 2T tokens, as well as the code variants.
-- [Mamba](./candle-examples/examples/mamba/): an inference only
-  implementation of the Mamba state space model.
+  pre-trained on 1T tokens of English and code datasets.
 - [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
-  better performance than all publicly available 13b models as of 2023-09-28.
-- [Mixtral8x7b-v0.1](./candle-examples/examples/mixtral/): a sparse mixture of
-  experts 8x7b general LLM with better performance than a Llama 2 70B model with
-  much faster inference.
-- [StarCoder](./candle-examples/examples/bigcode/) and
-  [StarCoder2](./candle-examples/examples/starcoder2/): LLM specialized to code generation.
-- [Qwen1.5](./candle-examples/examples/qwen/): Bilingual (English/Chinese) LLMs.
-- [RWKV v5 and v6](./candle-examples/examples/rwkv/): An RNN with transformer level LLM
-  performance.
-- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
-- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
-  (English/Chinese) general LLMs with 6b and 34b parameters.
+  performance larger than all publicly available 13b models as of 2023-09-28.
+- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
 - [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of
   the LLaMA model using the same quantization techniques as
   [llama.cpp](https://github.com/ggerganov/llama.cpp).

@@ -93,7 +74,7 @@ We also provide a some command line based examples using state of the art models
 <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/quantized/assets/aoc.gif" width="600">

 - [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
-  image generative model, support for the 1.5, 2.1, SDXL 1.0 and Turbo versions.
+  image generative model, support for the 1.5, 2.1, and SDXL 1.0 versions.

 <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg" width="200">

@@ -112,29 +93,11 @@ We also provide a some command line based examples using state of the art models

 <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">

-- [SegFormer](./candle-examples/examples/segformer/): transformer based semantic segmentation model.
 - [Whisper](./candle-examples/examples/whisper/): speech recognition model.
-- [EnCodec](./candle-examples/examples/encodec/): high-quality audio compression
-  model using residual vector quantization.
-- [MetaVoice](./candle-examples/examples/metavoice/): foundational model for
-  text-to-speech.
-- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/),
-  [JinaBert](./candle-examples/examples/jina-bert/) : useful for sentence embeddings.
+- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
 - [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
   using self-supervision (can be used for imagenet classification, depth
   evaluation, segmentation).
-- [VGG](./candle-examples/examples/vgg/),
-  [RepVGG](./candle-examples/examples/repvgg): computer vision models.
-- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
-  generate captions for an image.
-- [CLIP](./candle-examples/examples/clip/): multi-model vision and language
-  model.
-- [TrOCR](./candle-examples/examples/trocr/): a transformer OCR model, with
-  dedicated submodels for hand-writing and printed recognition.
-- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
-  model, generates the translated text from the input text.
-- [Moondream](./candle-examples/examples/moondream/): tiny computer-vision model
-  that can answer real-world questions about images.

 Run them using commands like:
 ```

@@ -150,7 +113,7 @@ There are also some wasm examples for whisper and
 [whisper](https://huggingface.co/spaces/lmz/candle-whisper),
 [llama2](https://huggingface.co/spaces/lmz/candle-llama2),
 [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm),
-[Phi-1.5, and Phi-2](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm),
+[Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm),
 [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).

 For LLaMA2, run the following command to retrieve the weight files and start a

@@ -166,23 +129,8 @@ And then head over to

 <!--- ANCHOR: useful_libraries --->

-## Useful External Resources
-- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): A
-  very detailed tutorial showing how to convert a PyTorch model to Candle.
-- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): Efficient and
-  ergonomic LoRA implementation for Candle. `candle-lora` has
-  out-of-the-box LoRA support for many models from Candle, which can be found
-  [here](https://github.com/EricLBuehler/candle-lora/tree/master/candle-lora-transformers/examples).
-- [`optimisers`](https://github.com/KGrewal1/optimisers): A collection of optimisers
-  including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
-- [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
-  serving local LLMs including an OpenAI compatible API server.
-- [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle.
-- [`candle-coursera-ml`](https://github.com/vishpat/candle-coursera-ml): Implementation of ML algorithms from Coursera's [Machine Learning Specialization](https://www.coursera.org/specializations/machine-learning-introduction) course.
-- [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
-- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
-- [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem.
-- [`candle-einops`](https://github.com/tomsanbear/candle-einops): A pure rust implementation of the python [einops](https://github.com/arogozhnikov/einops) library.
+## Useful Libraries
+- [`candle-lora`](https://github.com/EricLBuehler/candle-lora) provides a LoRA implementation that conforms to the official `peft` implementation.

 If you have an addition to this list, please submit a pull request.

@@ -201,45 +149,23 @@ If you have an addition to this list, please submit a pull request.
 - WASM support, run your models in a browser.
 - Included models.
   - Language Models.
-    - LLaMA v1, v2, and v3 with variants such as SOLAR-10.7B.
+    - LLaMA v1 and v2.
     - Falcon.
-    - StarCoder, StarCoder2.
-    - Phi 1, 1.5, 2, and 3.
-    - Mamba, Minimal Mamba
-    - Gemma 2b and 7b.
+    - StarCoder.
+    - Phi v1.5.
     - Mistral 7b v0.1.
-    - Mixtral 8x7b v0.1.
-    - StableLM-3B-4E1T, StableLM-2-1.6B, Stable-Code-3B.
-    - Replit-code-v1.5-3B.
+    - StableLM-3B-4E1T.
+    - T5.
     - Bert.
-    - Yi-6B and Yi-34B.
-    - Qwen1.5, Qwen1.5 MoE.
-    - RWKV v5 and v6.
-  - Quantized LLMs.
-    - Llama 7b, 13b, 70b, as well as the chat and code variants.
-    - Mistral 7b, and 7b instruct.
-    - Mixtral 8x7b.
-    - Zephyr 7b a and b (Mistral-7b based).
-    - OpenChat 3.5 (Mistral-7b based).
-  - Text to text.
-    - T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
-    - Marian MT (Machine Translation).
-  - Text to image.
-    - Stable Diffusion v1.5, v2.1, XL v1.0.
-    - Wurstchen v2.
-  - Image to text.
-    - BLIP.
-    - TrOCR.
-  - Audio.
-    - Whisper, multi-lingual speech-to-text.
-    - EnCodec, audio compression model.
-    - MetaVoice-1B, text-to-speech model.
+    - Whisper (multi-lingual support).
+    - Stable Diffusion v1.5, v2.1, XL v1.0.
+    - Wurstchen v2.
   - Computer Vision Models.
-    - DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT,
-      ConvNeXTv2, MobileOne, EfficientVit (MSRA).
-    - yolo-v3, yolo-v8.
+    - DINOv2.
+    - EfficientNet.
+    - yolo-v3.
+    - yolo-v8.
     - Segment-Anything Model (SAM).
-    - SegFormer.
 - File formats: load models from safetensors, npz, ggml, or PyTorch files.
 - Serverless (on CPU), small and fast deployments.
 - Quantization support using the llama.cpp quantized types.

@@ -276,7 +202,6 @@ Cheatsheet:
 - [candle-datasets](./candle-datasets/): Datasets and data loaders.
 - [candle-transformers](./candle-transformers): transformers-related utilities.
 - [candle-flash-attn](./candle-flash-attn): Flash attention v2 layer.
-- [candle-onnx](./candle-onnx/): ONNX model evaluation.

 ## FAQ

@@ -376,9 +301,9 @@ git submodule update --init
 /usr/include/c++/11/bits/std_function.h:530:146: error: parameter packs not expanded with ‘...’:
 ```

-This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the NVCC_CCBIN environment variable.
+This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the CANDLE_NVCC_CCBIN environment variable.
 ```
-env NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
+env CANDLE_NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
 ```

 #### Linking error on windows when running rustdoc or mdbook tests

@@ -408,10 +333,3 @@ This may be caused by the models being loaded from `/mnt/c`, more details on

 You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle
 error is generated.
-
-#### CudaRC error
-
-If you encounter an error like this one `called `Result::unwrap()` on an `Err` value: LoadLibraryExW { source: Os { code: 126, kind: Uncategorized, message: "The specified module could not be found." } }` on windows. To fix copy and rename these 3 files (make sure they are in path). The paths depend on your cuda version.
-`c:\Windows\System32\nvcuda.dll` -> `cuda.dll`
-`c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\cublas64_12.dll` -> `cublas.dll`
-`c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\curand64_10.dll` -> `curand.dll`
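The README's file-format bullet ("load models from safetensors, npz, ggml, or PyTorch files") is exercised later in this diff by the book tests. A minimal sketch of the safetensors path, reusing the `safetensors::load` call those tests make — the file name here is a placeholder:

```rust
// Sketch: enumerate the tensors stored in a safetensors file.
// "model.safetensors" is a placeholder path, not from the diff.
use candle_core::Device;

fn main() -> candle_core::Result<()> {
    let weights = candle_core::safetensors::load("model.safetensors", &Device::Cpu)?;
    for (name, tensor) in weights.iter() {
        println!("{name}: {:?}", tensor.dims());
    }
    Ok(())
}
```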
@@ -11,11 +11,11 @@ readme = "README.md"

 [dependencies]
 accelerate-src = { workspace = true, optional = true }
-candle = { workspace = true }
-candle-datasets = { workspace = true }
-candle-nn = { workspace = true }
-candle-transformers = { workspace = true }
-candle-flash-attn = { workspace = true, optional = true }
+candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
+candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
+candle-nn = { path = "../candle-nn", version = "0.3.0" }
+candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
+candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
 safetensors = { workspace = true }
 serde = { workspace = true }
 serde_json = { workspace = true }
@@ -37,6 +37,7 @@ tokenizers = { workspace = true, features = ["onig"] }
 tracing = { workspace = true }
 tracing-chrome = { workspace = true }
 tracing-subscriber = { workspace = true }
+wav = { workspace = true }
 # Necessary to disambiguate with tokio in wasm examples which are 1.28.1
 parquet = { workspace = true }
 image = { workspace = true }
@@ -12,9 +12,6 @@ compute_cap
 8.9
 ```

-You can also compile the Cuda kernels for a specific compute cap using the
-`CUDA_COMPUTE_CAP=<compute cap>` environment variable.
-
 If any of the above commands errors out, please make sure to update your Cuda version.

 2. Create a new app and add [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) with Cuda support.
@@ -28,7 +28,6 @@ let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap()
 #[rustfmt::skip]
 #[test]
 fn book_hub_2() {
-    {
 // ANCHOR: book_hub_2
 use candle::Device;
 use hf_hub::api::sync::Api;
@@ -46,10 +45,9 @@ let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap()
 assert_eq!(weights.len(), 206);
 }

-// #[rustfmt::skip]
-// #[test]
-// fn book_hub_3() {
-    {
+#[rustfmt::skip]
+#[test]
+fn book_hub_3() {
 // ANCHOR: book_hub_3
 use candle::{DType, Device, Tensor};
 use hf_hub::api::sync::Api;
@@ -81,7 +79,7 @@ let mut tp_shape = view.shape().to_vec();
 let size = tp_shape[0];

 if size % world_size != 0 {
-    panic!("The dimension is not divisible by `world_size`");
+    panic!("The dimension is not divisble by `world_size`");
 }
 let block_size = size / world_size;
 let start = rank * block_size;
@@ -104,7 +102,6 @@ let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap()
 assert_eq!(view.shape(), &[768, 768]);
 assert_eq!(tp_tensor.dims(), &[192, 768]);
 }
-}

 #[rustfmt::skip]
 #[test]
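The `book_hub_3` test above shards dimension 0 of a `[768, 768]` safetensors view evenly across `world_size` ranks. The arithmetic is easy to check in isolation; a standalone sketch (the `shard_range` helper is illustrative, not a candle API):

```rust
// Sketch of the sharding arithmetic used in book_hub_3: dim 0 of size
// `size` is split into `world_size` equal blocks, one per rank.
fn shard_range(size: usize, world_size: usize, rank: usize) -> (usize, usize) {
    assert!(size % world_size == 0, "The dimension is not divisible by `world_size`");
    let block_size = size / world_size;
    let start = rank * block_size;
    (start, start + block_size)
}

fn main() {
    // With size = 768 and world_size = 4, rank 0 owns rows 0..192, which
    // matches the `[192, 768]` shard asserted in the test above.
    assert_eq!(shard_range(768, 4, 0), (0, 192));
    assert_eq!(shard_range(768, 4, 3), (576, 768));
}
```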
@@ -12,9 +12,7 @@ readme = "README.md"
 [dependencies]
 accelerate-src = { workspace = true, optional = true }
 byteorder = { workspace = true }
-candle-kernels = { workspace = true, optional = true }
-candle-metal-kernels = { workspace = true, optional = true }
-metal = { workspace = true, optional = true}
+candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
 cudarc = { workspace = true, optional = true }
 gemm = { workspace = true }
 half = { workspace = true }
@@ -34,8 +32,6 @@ zip = { workspace = true }
 [dev-dependencies]
 anyhow = { workspace = true }
 clap = { workspace = true }
-criterion = { workspace = true }
-

 [features]
 default = []
@@ -43,8 +39,3 @@ cuda = ["cudarc", "dep:candle-kernels"]
 cudnn = ["cuda", "cudarc/cudnn"]
 mkl = ["dep:libc", "dep:intel-mkl-src"]
 accelerate = ["dep:libc", "dep:accelerate-src"]
-metal = ["dep:metal", "dep:candle-metal-kernels"]
-
-[[bench]]
-name = "bench_main"
-harness = false
@@ -1,12 +0,0 @@
-mod benchmarks;
-
-use criterion::criterion_main;
-criterion_main!(
-    benchmarks::affine::benches,
-    benchmarks::matmul::benches,
-    benchmarks::random::benches,
-    benchmarks::where_cond::benches,
-    benchmarks::conv_transpose2d::benches,
-    benchmarks::qmatmul::benches,
-    benchmarks::unary::benches
-);
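The deleted file above is the standard criterion entry point: each benchmark module registers its functions through `criterion_group!`, and `criterion_main!` generates the `main` function. A minimal sketch of the same pattern, assuming a bench target with `harness = false` as in the `[[bench]]` section removed earlier:

```rust
// Sketch of a standalone criterion bench (e.g. benches/my_bench.rs).
// A real bench would call into candle ops instead of the trivial body.
use criterion::{criterion_group, criterion_main, Criterion};

fn criterion_benchmark(c: &mut Criterion) {
    // black_box keeps the compiler from optimizing the measured body away.
    c.bench_function("noop", |b| b.iter(|| std::hint::black_box(1 + 1)));
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
```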
@@ -1,43 +0,0 @@
-use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
-use candle_core::{DType, Device, Tensor};
-use criterion::{black_box, criterion_group, Criterion, Throughput};
-use std::time::Instant;
-
-fn run(a: &Tensor) {
-    a.affine(12.34, 56.78).unwrap();
-}
-
-fn run_affine_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
-    let b = 1;
-    let m = 1024;
-    let k = 1024;
-
-    let tensor = Tensor::zeros((b, m, k), dtype, &device).unwrap();
-
-    let flops = b * m * k * dtype.size_in_bytes();
-
-    let mut group = c.benchmark_group(device.bench_name(name));
-    group.throughput(Throughput::Bytes(flops as u64));
-    group.bench_function("iter", move |b| {
-        b.iter_custom(|iters| {
-            let start = Instant::now();
-            for _i in 0..iters {
-                run(black_box(&tensor));
-            }
-            device.sync().unwrap();
-            start.elapsed()
-        })
-    });
-    group.finish();
-}
-
-fn criterion_benchmark(c: &mut Criterion) {
-    let handler = BenchDeviceHandler::new().unwrap();
-    for device in handler.devices {
-        run_affine_benchmark(c, &device, DType::F32, "affine_f32");
-        run_affine_benchmark(c, &device, DType::F16, "affine_f16");
-        run_affine_benchmark(c, &device, DType::BF16, "affine_bf16");
-    }
-}
-
-criterion_group!(benches, criterion_benchmark);
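For reference, `affine(mul, add)` computes `x * mul + add` elementwise, which is why the benchmark's throughput figure is just the bytes touched. A small sketch of the op itself (values illustrative):

```rust
// Sketch: candle's affine op applies x * mul + add elementwise.
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let x = Tensor::new(&[1.0f32, 2.0, 3.0], &Device::Cpu)?;
    let y = x.affine(2.0, 1.0)?; // elementwise 2*x + 1
    assert_eq!(y.to_vec1::<f32>()?, [3.0, 5.0, 7.0]);
    Ok(())
}
```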
@@ -1,59 +0,0 @@
-use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
-use candle_core::{DType, Device, Tensor};
-use criterion::{black_box, criterion_group, Criterion, Throughput};
-use std::time::Instant;
-
-fn run(
-    x: &Tensor,
-    k: &Tensor,
-    padding: usize,
-    output_padding: usize,
-    stride: usize,
-    dilation: usize,
-) {
-    x.conv_transpose2d(k, padding, output_padding, stride, dilation)
-        .unwrap();
-}
-
-fn run_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
-    let t = Tensor::arange(0.0f32, 10000.0, device)
-        .unwrap()
-        .reshape((1, 4, 50, 50))
-        .unwrap()
-        .to_dtype(dtype)
-        .unwrap();
-
-    let kernel = Tensor::arange(0.0f32, 100.0, device)
-        .unwrap()
-        .reshape((4, 1, 5, 5))
-        .unwrap()
-        .to_dtype(dtype)
-        .unwrap();
-
-    let flops = t.dims().iter().product::<usize>() * dtype.size_in_bytes();
-
-    let mut group = c.benchmark_group(device.bench_name(name));
-    group.throughput(Throughput::Bytes(flops as u64));
-    group.bench_function("iter", move |b| {
-        b.iter_custom(|iters| {
-            let start = Instant::now();
-            for _i in 0..iters {
-                run(black_box(&t), black_box(&kernel), 1, 0, 1, 2);
-            }
-            device.sync().unwrap();
-            start.elapsed()
-        })
-    });
-    group.finish();
-}
-
-fn criterion_benchmark(c: &mut Criterion) {
-    let handler = BenchDeviceHandler::new().unwrap();
-    for device in handler.devices {
-        run_benchmark(c, &device, DType::F32, "conv_transpose2d_f32");
-        run_benchmark(c, &device, DType::F16, "conv_transpose2d_f16");
-        run_benchmark(c, &device, DType::BF16, "conv_transpose2d_bf16");
-    }
-}
-
-criterion_group!(benches, criterion_benchmark);
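The benchmark above calls `conv_transpose2d` with padding 1, output padding 0, stride 1, and dilation 2 on 50x50 spatial dims. Assuming the usual transposed-convolution size formula, the output spatial size works out to 56x56; a worked sketch:

```rust
// Worked sketch of the output-size arithmetic for the call above, using
// the standard transposed-convolution formula (an assumption here, not
// something stated in the diff).
fn conv_transpose_out(inp: usize, kernel: usize, padding: usize,
                      output_padding: usize, stride: usize, dilation: usize) -> usize {
    (inp - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1
}

fn main() {
    // (50 - 1) * 1 - 2 * 1 + 2 * (5 - 1) + 0 + 1 = 56
    assert_eq!(conv_transpose_out(50, 5, 1, 0, 1, 2), 56);
}
```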
@@ -1,44 +0,0 @@
-use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
-use candle_core::{DType, Device, Tensor};
-use criterion::{black_box, criterion_group, Criterion, Throughput};
-use std::time::Instant;
-
-fn run(a: &Tensor, b: &Tensor) {
-    a.matmul(&b.t().unwrap()).unwrap();
-}
-
-fn run_bench(c: &mut Criterion, device: &Device) {
-    let b = 1;
-    let m = 1;
-    let n = 2048;
-    let k = 2048;
-
-    let dtype = DType::F32;
-    let lhs = Tensor::zeros((b, m, k), dtype, device).unwrap();
-    let rhs = Tensor::zeros((b, n, k), dtype, device).unwrap();
-
-    let flops = b * m * n * k;
-
-    let mut group = c.benchmark_group(device.bench_name("matmul"));
-    group.throughput(Throughput::Bytes(flops as u64));
-    group.bench_function("iter", move |b| {
-        b.iter_custom(|iters| {
-            let start = Instant::now();
-            for _i in 0..iters {
-                run(black_box(&lhs), black_box(&rhs));
-            }
-            device.sync().unwrap();
-            start.elapsed()
-        })
-    });
-    group.finish();
-}
-
-fn criterion_benchmark(c: &mut Criterion) {
-    let handler = BenchDeviceHandler::new().unwrap();
-    for device in handler.devices {
-        run_bench(c, &device);
-    }
-}
-
-criterion_group!(benches, criterion_benchmark);
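The matmul benchmark builds `lhs` as `(b, m, k)` and `rhs` as `(b, n, k)`, then multiplies by `rhs.t()` so the contraction happens over `k`. A shape-only sketch of that layout:

```rust
// Shape sketch of the benchmarked layout: t() swaps the last two dims,
// so (b, m, k) x (b, k, n) contracts over k and yields (b, m, n).
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let (b, m, n, k) = (1, 1, 2048, 2048);
    let lhs = Tensor::zeros((b, m, k), DType::F32, &Device::Cpu)?;
    let rhs = Tensor::zeros((b, n, k), DType::F32, &Device::Cpu)?;
    let out = lhs.matmul(&rhs.t()?)?;
    assert_eq!(out.dims(), &[b, m, n]);
    Ok(())
}
```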
@@ -1,69 +0,0 @@
-pub(crate) mod affine;
-pub(crate) mod conv_transpose2d;
-pub(crate) mod matmul;
-pub(crate) mod qmatmul;
-pub(crate) mod random;
-pub(crate) mod unary;
-pub(crate) mod where_cond;
-
-use candle_core::{Device, Result};
-
-pub(crate) trait BenchDevice {
-    fn sync(&self) -> Result<()>;
-
-    fn bench_name<S: Into<String>>(&self, name: S) -> String;
-}
-
-impl BenchDevice for Device {
-    fn sync(&self) -> Result<()> {
-        match self {
-            Device::Cpu => Ok(()),
-            Device::Cuda(device) => {
-                #[cfg(feature = "cuda")]
-                return Ok(device.synchronize()?);
-                #[cfg(not(feature = "cuda"))]
-                panic!("Cuda device without cuda feature enabled: {:?}", device)
-            }
-            Device::Metal(device) => {
-                #[cfg(feature = "metal")]
-                return Ok(device.wait_until_completed()?);
-                #[cfg(not(feature = "metal"))]
-                panic!("Metal device without metal feature enabled: {:?}", device)
-            }
-        }
-    }
-
-    fn bench_name<S: Into<String>>(&self, name: S) -> String {
-        match self {
-            Device::Cpu => {
-                let cpu_type = if cfg!(feature = "accelerate") {
-                    "accelerate"
-                } else if cfg!(feature = "mkl") {
-                    "mkl"
-                } else {
-                    "cpu"
-                };
-                format!("{}_{}", cpu_type, name.into())
-            }
-            Device::Cuda(_) => format!("cuda_{}", name.into()),
-            Device::Metal(_) => format!("metal_{}", name.into()),
-        }
-    }
-}
-
-struct BenchDeviceHandler {
-    devices: Vec<Device>,
-}
-
-impl BenchDeviceHandler {
-    pub fn new() -> Result<Self> {
-        let mut devices = Vec::new();
-        if cfg!(feature = "metal") {
-            devices.push(Device::new_metal(0)?);
-        } else if cfg!(feature = "cuda") {
-            devices.push(Device::new_cuda(0)?);
-        }
-        devices.push(Device::Cpu);
-        Ok(Self { devices })
-    }
-}
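The point of `BenchDevice::sync` above is that CUDA and Metal launches are asynchronous: without a synchronize before reading the clock, `iter_custom` would time only kernel submission. A device-agnostic sketch of that timing pattern (`time_n` and its closures are illustrative, not candle APIs):

```rust
// Sketch: launch N iterations, then synchronize once before stopping the
// clock, so queued device work is included in the measurement.
use std::time::{Duration, Instant};

fn time_n(iters: u64, mut run: impl FnMut(), sync: impl Fn()) -> Duration {
    let start = Instant::now();
    for _ in 0..iters {
        run(); // enqueues work; on a GPU this may return before completion
    }
    sync(); // wait for all queued work before reading the clock
    start.elapsed()
}

fn main() {
    // CPU stand-ins: a trivial body and a no-op sync.
    let d = time_n(1_000, || { std::hint::black_box(1 + 1); }, || {});
    println!("{d:?}");
}
```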
@@ -1,72 +0,0 @@
-use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
-use candle_core::{
-    quantized::{self, GgmlDType, QMatMul},
-    Device, Module, Tensor,
-};
-use criterion::{black_box, criterion_group, Criterion, Throughput};
-use std::time::Instant;
-
-fn run(matmul: &QMatMul, x: &Tensor) {
-    matmul.forward(&x).unwrap();
-}
-
-fn run_bench(c: &mut Criterion, device: &Device, dtype: GgmlDType) {
-    let b = 1;
-    let m = 1;
-    let n = 1024;
-    let k = 1024;
-
-    let lhs = (0..(m * k))
-        .map(|v| v as f32 / (m * k) as f32)
-        .collect::<Vec<_>>();
-    let rhs = (0..(k * n))
-        .map(|v| v as f32 / (n * k) as f32)
-        .collect::<Vec<_>>();
-
-    let lhs = Tensor::from_slice(&lhs, (m, k), device).unwrap();
-    let rhs = Tensor::from_slice(&rhs, (k, n), device).unwrap();
-
-    let qtensor = quantized::QTensor::quantize(&rhs.t().unwrap(), dtype).unwrap();
-    let matmul = quantized::QMatMul::from_qtensor(qtensor).unwrap();
-
-    let flops = b * m * n * k;
-
-    let mut group = c.benchmark_group(device.bench_name(format!("qmatmul_{:?}", dtype)));
-    group.sample_size(200);
-    group.throughput(Throughput::Bytes(flops as u64));
-    group.bench_function("iter", move |b| {
-        b.iter_custom(|iters| {
-            let start = Instant::now();
-            for _i in 0..iters {
-                run(black_box(&matmul), black_box(&lhs));
-            }
-            device.sync().unwrap();
-            start.elapsed()
-        })
-    });
-    group.finish();
-}
-
-fn criterion_benchmark(c: &mut Criterion) {
-    let handler = BenchDeviceHandler::new().unwrap();
-    for device in handler.devices {
-        for dtype in vec![
-            GgmlDType::F32,
-            GgmlDType::F16,
-            GgmlDType::Q4_0,
-            GgmlDType::Q4_1,
-            GgmlDType::Q5_0,
-            GgmlDType::Q5_1,
-            GgmlDType::Q8_0,
-            GgmlDType::Q2K,
-            GgmlDType::Q3K,
-            GgmlDType::Q4K,
-            GgmlDType::Q5K,
-            GgmlDType::Q6K,
-        ] {
-            run_bench(c, &device, dtype);
-        }
-    }
-}
-
-criterion_group!(benches, criterion_benchmark);
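The qmatmul benchmark quantizes the transposed `rhs` and multiplies through `QMatMul::forward`. A minimal sketch of that path using the 0.5.1-side API shown in this diff (`QTensor::quantize` taking a `GgmlDType`); the shapes are illustrative:

```rust
// Sketch of the quantized-matmul path (0.5.1-side API as shown above):
// quantize an (n, k) weight, then run x through the quantized kernel.
use candle_core::quantized::{GgmlDType, QMatMul, QTensor};
use candle_core::{Device, Module, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::Cpu;
    let weight = Tensor::randn(0f32, 1.0, (1024, 1024), &dev)?;
    let q = QTensor::quantize(&weight, GgmlDType::Q4_0)?; // k must be a multiple of the block size
    let matmul = QMatMul::from_qtensor(q)?;
    let x = Tensor::randn(0f32, 1.0, (1, 1024), &dev)?;
    let y = matmul.forward(&x)?;
    assert_eq!(y.dims(), &[1, 1024]);
    Ok(())
}
```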
@@ -1,63 +0,0 @@
-use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
-use candle_core::{DType, Device, Tensor};
-use criterion::{black_box, criterion_group, Criterion, Throughput};
-use std::time::Instant;
-
-fn rand_uniform(a: &Tensor) {
-    a.rand_like(-1.0, 123.0).unwrap();
-}
-
-fn rand_normal(a: &Tensor) {
-    a.randn_like(100.0, 15.0).unwrap();
-}
-
-fn run_random_bench(c: &mut Criterion, device: &Device) {
-    let b = 1;
-
-    let rows = 2048;
-    let cols = 2048;
-
-    let dtype = DType::F32;
-    let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
-
-    let flops = b * rows * cols * dtype.size_in_bytes();
-
-    let mut group = c.benchmark_group(device.bench_name("random_uniform"));
-    group.throughput(Throughput::Bytes(flops as u64));
-    group.bench_function("iter", move |benches| {
-        benches.iter_custom(|iters| {
-            let start = Instant::now();
-            for _i in 0..iters {
-                rand_uniform(black_box(&tensor));
-            }
-            device.sync().unwrap();
-            start.elapsed()
-        })
-    });
-    group.finish();
-
-    let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
-
-    let mut group = c.benchmark_group(device.bench_name("random_normal"));
-    group.throughput(Throughput::Bytes(flops as u64));
-    group.bench_function("iter", move |benches| {
-        benches.iter_custom(|iters| {
-            let start = Instant::now();
-            for _i in 0..iters {
-                rand_normal(black_box(&tensor));
-            }
-            device.sync().unwrap();
-            start.elapsed()
-        })
-    });
-    group.finish();
-}
-
-fn criterion_benchmark(c: &mut Criterion) {
-    let handler = BenchDeviceHandler::new().unwrap();
-    for device in handler.devices {
-        run_random_bench(c, &device);
-    }
-}
-
-criterion_group!(benches, criterion_benchmark);
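For reference, both sampled ops allocate a fresh tensor with the receiver's shape: `rand_like(lo, up)` draws uniformly from the given range and `randn_like(mean, std)` from a normal distribution. A small sketch:

```rust
// Sketch of the two sampling ops benchmarked above; the receiver only
// supplies shape, dtype, and device.
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let t = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    let u = t.rand_like(-1.0, 1.0)?; // uniform over the given range
    let n = t.randn_like(0.0, 1.0)?; // normal with mean 0, std 1
    assert_eq!(u.dims(), n.dims());
    Ok(())
}
```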
@@ -1,49 +0,0 @@
-use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
-use candle_core::{DType, Device, Tensor};
-use criterion::{black_box, criterion_group, Criterion, Throughput};
-use std::time::Instant;
-
-fn run(a: &Tensor) {
-    a.sqrt().unwrap();
-}
-
-fn run_unary_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
-    let b = 1;
-    let m = 1024;
-    let k = 1024;
-
-    let tensor = Tensor::arange(0.0f32, (b * m * k) as f32, &device)
-        .unwrap()
-        .to_dtype(dtype)
-        .unwrap()
-        .reshape((b, m, k))
-        .unwrap();
-
-    let flops = b * m * k * dtype.size_in_bytes();
-
-    let mut group = c.benchmark_group(device.bench_name(name));
-    group.throughput(Throughput::Bytes(flops as u64));
-    group.bench_function("iter", move |b| {
-        b.iter_custom(|iters| {
-            let start = Instant::now();
-            for _i in 0..iters {
-                run(black_box(&tensor));
-            }
-            device.sync().unwrap();
-            start.elapsed()
-        })
-    });
-    group.finish();
-}
-
-fn criterion_benchmark(c: &mut Criterion) {
-    let handler = BenchDeviceHandler::new().unwrap();
-    for device in handler.devices {
-        for dtype in [DType::F32, DType::BF16, DType::F16] {
-            let name = format!("sqrt_{:?}", dtype);
-            run_unary_benchmark(c, &device, dtype, &name);
-        }
-    }
-}
-
-criterion_group!(benches, criterion_benchmark);
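The unary benchmark sweeps elementwise `sqrt` over f32, bf16, and f16 via `to_dtype`. A tiny sketch of the op itself, with inputs chosen so the results are exact:

```rust
// Sketch: elementwise sqrt, the op timed by the deleted benchmark above.
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let x = Tensor::new(&[1.0f32, 4.0, 9.0], &Device::Cpu)?;
    let y = x.sqrt()?;
    assert_eq!(y.to_vec1::<f32>()?, [1.0, 2.0, 3.0]);
    Ok(())
}
```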
@@ -1,64 +0,0 @@
-use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
-use candle_core::{DType, Device, Tensor};
-use criterion::{black_box, criterion_group, Criterion, Throughput};
-use std::time::Instant;
-
-fn run(a: &Tensor, b: &Tensor, c: &Tensor) {
-    a.where_cond(b, c).unwrap();
-}
-
-const fn create_cond_arr<const N: usize>() -> [u8; N] {
-    let mut arr = [0u8; N];
-    let mut i = 0;
-    while i < N {
-        arr[i] = (i % 2) as u8;
-        i += 1;
-    }
-    arr
-}
-
-const B: usize = 1;
-const M: usize = 1024;
-const K: usize = 1024;
-const SIZE: usize = B * M * K;
-
-const DATA: [u8; SIZE] = create_cond_arr::<SIZE>();
-
-fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
-    let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), &device).unwrap();
-    let on_true = Tensor::ones((B, M, K), dtype, &device).unwrap();
-    let on_false = Tensor::zeros((B, M, K), dtype, &device).unwrap();
-
-    let elements = B * M * K;
-    // E.g. 2 f32 tensors + 1 u8 tensor
-    let flops = (2 * elements * dtype.size_in_bytes()) + elements;
-
-    let mut group = c.benchmark_group(device.bench_name(name));
-    group.throughput(Throughput::Bytes(flops as u64));
-    group.bench_function("iter", move |b| {
-        b.iter_custom(|iters| {
-            let start = Instant::now();
-            for _i in 0..iters {
-                run(
-                    black_box(&tensor),
-                    black_box(&on_true),
-                    black_box(&on_false),
-                );
-            }
-            device.sync().unwrap();
-            start.elapsed()
-        })
-    });
-    group.finish();
-}
-
-fn criterion_benchmark(c: &mut Criterion) {
-    let device = BenchDeviceHandler::new().unwrap();
-    for d in device.devices {
-        run_where_cond_benchmark(c, &d, DType::F32, "where_cond_f32");
-        run_where_cond_benchmark(c, &d, DType::BF16, "where_cond_bf16");
-        run_where_cond_benchmark(c, &d, DType::F16, "where_cond_f16");
-    }
-}
-
-criterion_group!(benches, criterion_benchmark);
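`where_cond` selects elementwise between `on_true` and `on_false` using the receiver as a u8 mask, which is why the byte estimate above counts two dtype-sized tensors plus one u8 tensor. A small sketch:

```rust
// Sketch: elementwise select through a u8 mask, as in the deleted bench.
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::Cpu;
    let mask = Tensor::new(&[1u8, 0, 1, 0], &dev)?;
    let on_true = Tensor::new(&[10.0f32, 10.0, 10.0, 10.0], &dev)?;
    let on_false = Tensor::new(&[0.0f32, 0.0, 0.0, 0.0], &dev)?;
    let out = mask.where_cond(&on_true, &on_false)?;
    assert_eq!(out.to_vec1::<f32>()?, [10.0, 0.0, 10.0, 0.0]);
    Ok(())
}
```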
@@ -8,10 +8,11 @@ use anyhow::Result;
 use candle_core::{Device, Tensor};

 fn main() -> Result<()> {
-    let a = Tensor::new(&[[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]], &Device::Cpu)?;
-    let b = Tensor::new(&[[88.0f32, 99.0]], &Device::Cpu)?;
-    let new_a = a.slice_scatter(&b, 1, 2)?;
-    assert_eq!(a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
-    assert_eq!(new_a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
+    let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
+    let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
+    let start = std::time::Instant::now();
+    let res = inp.conv2d(&w, 0, 1, 1, 1)?;
+    println!("{:?}", start.elapsed());
+    println!("{res:?}");
     Ok(())
 }
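On the matmul-slo side this example times a CPU `conv2d` with padding 0, stride 1, dilation 1, and groups 1. Assuming the usual convolution size formula, the `(2, 320, 96, 96)` input and 3x3 kernel give a `(2, 320, 94, 94)` output; a worked sketch:

```rust
// Worked sketch of the conv2d output size for the call above, using the
// standard formula (an assumption, not stated in the diff).
fn conv_out(inp: usize, kernel: usize, padding: usize, stride: usize, dilation: usize) -> usize {
    (inp + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
}

fn main() {
    // (96 + 0 - 2 - 1) / 1 + 1 = 94 on each spatial dim.
    assert_eq!(conv_out(96, 3, 0, 1, 1), 94);
}
```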
@@ -9,22 +9,21 @@ use candle_core::{Device, Tensor};

 fn main() -> Result<()> {
     let device = Device::new_cuda(0)?;
-    let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?;
-    candle_core::cuda::set_gemm_reduced_precision_f32(false);
-    let _x1 = x.matmul(&x)?;
-    drop(_x1);
-    let start_time = std::time::Instant::now();
-    let _x1 = x.matmul(&x)?;
-    device.synchronize()?;
-    println!("fp32: {:?}", start_time.elapsed());
-    drop(_x1);
-    candle_core::cuda::set_gemm_reduced_precision_f32(true);
-    let _x1 = x.matmul(&x)?;
-    drop(_x1);
-    let start_time = std::time::Instant::now();
-    let _x1 = x.matmul(&x)?;
-    device.synchronize()?;
-    println!("tf32: {:?}", start_time.elapsed());
-    drop(_x1);
+    let in_t = Tensor::rand(-1f32, 1f32, (1, 3, 12, 7), &device)?;
+    let k_t = Tensor::rand(-1f32, 1f32, (6, 3, 1, 1), &device)?;
+    let out_t = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
+    println!("{out_t}");
+    let in_t = in_t.to_device(&Device::Cpu)?;
+    let k_t = k_t.to_device(&Device::Cpu)?;
+    let out_t2 = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
+    let diff = (out_t.to_device(&Device::Cpu)? - out_t2)?
+        .sqr()?
+        .sum_all()?;
+    println!("{diff}");
+
+    let t = Tensor::randn(0f32, 1f32, (2, 4, 96, 96), &device)?;
+    let w = Tensor::randn(0f32, 1f32, (320, 4, 3, 3), &device)?;
+    let res = t.conv2d(&w, 1, 1, 1, 1)?;
+    println!("{res:?}");
     Ok(())
 }
@@ -1,5 +1,5 @@
-use candle::quantized::{gguf_file, GgmlDType, QTensor};
-use candle::{Device, Result};
+use candle_core::quantized::{gguf_file, k_quants, QTensor};
+use candle_core::{Device, Result, Tensor};
 use clap::{Parser, Subcommand, ValueEnum};
 use rayon::prelude::*;
 
@@ -11,7 +11,12 @@ enum QuantizationMode {
 }
 
 impl QuantizationMode {
-    fn quantize(&self, name: &str, tensor: QTensor, dtype: GgmlDType) -> Result<QTensor> {
+    fn quantize(
+        &self,
+        name: &str,
+        tensor: QTensor,
+        default: fn(&Tensor) -> Result<QTensor>,
+    ) -> Result<QTensor> {
         match self {
             Self::Llama => {
                 // Same behavior as the llama.cpp quantization.
@@ -19,9 +24,9 @@ impl QuantizationMode {
                 if should_quantize {
                     let tensor = tensor.dequantize(&Device::Cpu)?;
                     if name == "output.weight" {
-                        QTensor::quantize(&tensor, GgmlDType::Q6K)
+                        QTensor::quantize::<k_quants::BlockQ6K>(&tensor)
                     } else {
-                        QTensor::quantize(&tensor, dtype)
+                        default(&tensor)
                     }
                 } else {
                     Ok(tensor)
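Note on the hunk above: the 0.5.1 side selects the quantization scheme at runtime through a `GgmlDType` value instead of a compile-time block type. A minimal sketch of that dtype-driven call, assuming the crate is imported as `candle` as in this file, with illustrative shapes:

```rust
use candle::quantized::{GgmlDType, QTensor};
use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    // Q4K blocks span 256 elements, so the trailing dim should be a multiple of 256.
    let t = Tensor::randn(0f32, 1.0, (4, 512), &Device::Cpu)?;
    // The dtype argument picks the scheme at runtime, replacing the old
    // `QTensor::quantize::<k_quants::BlockQ4K>(&t)` generic form.
    let q = QTensor::quantize(&t, GgmlDType::Q4K)?;
    // Round-trip back to f32 to inspect the quantization error.
    let err = (t - q.dequantize(&Device::Cpu)?)?.sqr()?.sum_all()?;
    println!("{err}");
    Ok(())
}
```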
@@ -55,27 +60,6 @@ enum Quantization {
     F32,
 }
 
-impl Quantization {
-    fn dtype(&self) -> GgmlDType {
-        match self {
-            Quantization::Q4_0 => GgmlDType::Q4_0,
-            Quantization::Q4_1 => GgmlDType::Q4_1,
-            Quantization::Q5_0 => GgmlDType::Q5_0,
-            Quantization::Q5_1 => GgmlDType::Q5_1,
-            Quantization::Q8_0 => GgmlDType::Q8_0,
-            Quantization::Q8_1 => GgmlDType::Q8_1,
-            Quantization::Q2k => GgmlDType::Q2K,
-            Quantization::Q3k => GgmlDType::Q3K,
-            Quantization::Q4k => GgmlDType::Q4K,
-            Quantization::Q5k => GgmlDType::Q5K,
-            Quantization::Q6k => GgmlDType::Q6K,
-            Quantization::Q8k => GgmlDType::Q8K,
-            Quantization::F16 => GgmlDType::F16,
-            Quantization::F32 => GgmlDType::F32,
-        }
-    }
-}
-
 #[derive(ValueEnum, Debug, Clone)]
 enum Format {
     Safetensors,
@@ -117,26 +101,8 @@ enum Command {
         verbose: bool,
     },
 
-    Print {
-        file: std::path::PathBuf,
-
-        names: Vec<String>,
-
-        /// The file format to use, if unspecified infer from the file extension.
-        #[arg(long, value_enum)]
-        format: Option<Format>,
-
-        /// Print the whole content of each tensor.
-        #[arg(long)]
-        full: bool,
-
-        /// Line width for printing the tensors.
-        #[arg(long)]
-        line_width: Option<usize>,
-    },
-
     Quantize {
-        /// The input file(s), in safetensors format.
+        /// The input file, in gguf format.
         in_file: Vec<std::path::PathBuf>,
 
         /// The output file, in gguf format.
@@ -151,15 +117,6 @@ enum Command {
         #[arg(long, value_enum, default_value_t = QuantizationMode::Llama)]
         mode: QuantizationMode,
     },
-
-    Dequantize {
-        /// The input file, in gguf format.
-        in_file: std::path::PathBuf,
-
-        /// The output file, in safetensors format.
-        #[arg(long)]
-        out_file: std::path::PathBuf,
-    },
 }
 
 #[derive(Parser, Debug, Clone)]
@@ -168,20 +125,7 @@ struct Args {
     command: Command,
 }
 
-fn run_print(
-    file: &std::path::PathBuf,
-    names: Vec<String>,
-    format: Option<Format>,
-    full: bool,
-    line_width: Option<usize>,
-    device: &Device,
-) -> Result<()> {
-    if full {
-        candle::display::set_print_options_full();
-    }
-    if let Some(line_width) = line_width {
-        candle::display::set_line_width(line_width)
-    }
+fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> Result<()> {
     let format = match format {
         Some(format) => format,
         None => match Format::infer(file) {
@@ -196,98 +140,7 @@ fn run_print(
     };
     match format {
         Format::Npz => {
-            let tensors = candle::npy::NpzTensors::new(file)?;
-            for name in names.iter() {
-                println!("==== {name} ====");
-                match tensors.get(name)? {
-                    Some(tensor) => println!("{tensor}"),
-                    None => println!("not found"),
-                }
-            }
-        }
-        Format::Safetensors => {
-            use candle::safetensors::Load;
-            let tensors = unsafe { candle::safetensors::MmapedSafetensors::new(file)? };
-            let tensors: std::collections::HashMap<_, _> = tensors.tensors().into_iter().collect();
-            for name in names.iter() {
-                println!("==== {name} ====");
-                match tensors.get(name) {
-                    Some(tensor_view) => {
-                        let tensor = tensor_view.load(device)?;
-                        println!("{tensor}")
-                    }
-                    None => println!("not found"),
-                }
-            }
-        }
-        Format::Pth => {
-            let pth_file = candle::pickle::PthTensors::new(file, None)?;
-            for name in names.iter() {
-                println!("==== {name} ====");
-                match pth_file.get(name)? {
-                    Some(tensor) => {
-                        println!("{tensor}")
-                    }
-                    None => println!("not found"),
-                }
-            }
-        }
-        Format::Pickle => {
-            candle::bail!("pickle format is not supported for print")
-        }
-        Format::Ggml => {
-            let mut file = std::fs::File::open(file)?;
-            let content = candle::quantized::ggml_file::Content::read(&mut file, device)?;
-            for name in names.iter() {
-                println!("==== {name} ====");
-                match content.tensors.get(name) {
-                    Some(tensor) => {
-                        let tensor = tensor.dequantize(device)?;
-                        println!("{tensor}")
-                    }
-                    None => println!("not found"),
-                }
-            }
-        }
-        Format::Gguf => {
-            let mut file = std::fs::File::open(file)?;
-            let content = gguf_file::Content::read(&mut file)?;
-            for name in names.iter() {
-                println!("==== {name} ====");
-                match content.tensor(&mut file, name, device) {
-                    Ok(tensor) => {
-                        let tensor = tensor.dequantize(device)?;
-                        println!("{tensor}")
-                    }
-                    Err(_) => println!("not found"),
-                }
-            }
-        }
-    }
-    Ok(())
-}
-
-fn run_ls(
-    file: &std::path::PathBuf,
-    format: Option<Format>,
-    verbose: bool,
-    device: &Device,
-) -> Result<()> {
-    let format = match format {
-        Some(format) => format,
-        None => match Format::infer(file) {
-            Some(format) => format,
-            None => {
-                println!(
-                    "{file:?}: cannot infer format from file extension, use the --format flag"
-                );
-                return Ok(());
-            }
-        },
-    };
-    match format {
-        Format::Npz => {
-            let tensors = candle::npy::NpzTensors::new(file)?;
+            let tensors = candle_core::npy::NpzTensors::new(file)?;
             let mut names = tensors.names();
             names.sort();
             for name in names {
@@ -299,12 +152,12 @@ fn run_ls(
             }
         }
         Format::Safetensors => {
-            let tensors = unsafe { candle::safetensors::MmapedSafetensors::new(file)? };
+            let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::new(file)? };
             let mut tensors = tensors.tensors();
             tensors.sort_by(|a, b| a.0.cmp(&b.0));
             for (name, view) in tensors.iter() {
                 let dtype = view.dtype();
-                let dtype = match candle::DType::try_from(dtype) {
+                let dtype = match candle_core::DType::try_from(dtype) {
                     Ok(dtype) => format!("{dtype:?}"),
                     Err(_) => format!("{dtype:?}"),
                 };
@@ -313,7 +166,7 @@ fn run_ls(
             }
         }
         Format::Pth => {
-            let mut tensors = candle::pickle::read_pth_tensor_info(file, verbose, None)?;
+            let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose)?;
             tensors.sort_by(|a, b| a.name.cmp(&b.name));
             for tensor_info in tensors.iter() {
                 println!(
@@ -330,7 +183,7 @@ fn run_ls(
         Format::Pickle => {
            let file = std::fs::File::open(file)?;
            let mut reader = std::io::BufReader::new(file);
-            let mut stack = candle::pickle::Stack::empty();
+            let mut stack = candle_core::pickle::Stack::empty();
            stack.read_loop(&mut reader)?;
            for (i, obj) in stack.stack().iter().enumerate() {
                println!("{i} {obj:?}");
@@ -338,7 +191,7 @@ fn run_ls(
         }
         Format::Ggml => {
             let mut file = std::fs::File::open(file)?;
-            let content = candle::quantized::ggml_file::Content::read(&mut file, device)?;
+            let content = candle_core::quantized::ggml_file::Content::read(&mut file)?;
             let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
             tensors.sort_by(|a, b| a.0.cmp(&b.0));
             for (name, qtensor) in tensors.iter() {
@@ -374,13 +227,42 @@ fn run_quantize_safetensors(
     let mut out_file = std::fs::File::create(out_file)?;
     let mut tensors = std::collections::HashMap::new();
     for in_file in in_files.iter() {
-        let in_tensors = candle::safetensors::load(in_file, &Device::Cpu)?;
+        let in_tensors = candle_core::safetensors::load(in_file, &Device::Cpu)?;
         tensors.extend(in_tensors)
     }
     println!("tensors: {}", tensors.len());
 
-    let dtype = q.dtype();
-    let block_size = dtype.block_size();
+    let quantize_fn = match q {
+        Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
+        Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
+        Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
+        Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
+        Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
+        Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
+        Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
+        Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
+        Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
+        Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
+        Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
+        Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
+        Quantization::F16 => QTensor::quantize::<half::f16>,
+        Quantization::F32 => QTensor::quantize::<f32>,
+    };
+    let block_size = match q {
+        Quantization::Q4_0 => k_quants::QK4_0,
+        Quantization::Q4_1 => k_quants::QK4_1,
+        Quantization::Q5_0 => k_quants::QK5_0,
+        Quantization::Q5_1 => k_quants::QK5_1,
+        Quantization::Q8_0 => k_quants::QK8_0,
+        Quantization::Q8_1 => k_quants::QK8_1,
+        Quantization::Q2k
+        | Quantization::Q3k
+        | Quantization::Q4k
+        | Quantization::Q5k
+        | Quantization::Q6k
+        | Quantization::Q8k => k_quants::QK_K,
+        Quantization::F16 | Quantization::F32 => 1,
+    };
 
     let qtensors = tensors
         .into_par_iter()
@@ -388,9 +270,9 @@ fn run_quantize_safetensors(
             let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0;
             println!("  quantizing {name} {tensor:?} {should_quantize}");
             let tensor = if should_quantize {
-                QTensor::quantize(&tensor, dtype)?
+                quantize_fn(&tensor)?
             } else {
-                QTensor::quantize(&tensor, GgmlDType::F32)?
+                QTensor::quantize::<f32>(&tensor)?
             };
             Ok((name, tensor))
         })
@@ -403,36 +285,18 @@ fn run_quantize_safetensors(
     Ok(())
 }
 
-fn run_dequantize(
-    in_file: std::path::PathBuf,
-    out_file: std::path::PathBuf,
-    device: &Device,
-) -> Result<()> {
-    let mut in_file = std::fs::File::open(in_file)?;
-    let content = gguf_file::Content::read(&mut in_file)?;
-    let mut tensors = std::collections::HashMap::new();
-    for (tensor_name, _) in content.tensor_infos.iter() {
-        let tensor = content.tensor(&mut in_file, tensor_name, device)?;
-        let tensor = tensor.dequantize(device)?;
-        tensors.insert(tensor_name.to_string(), tensor);
-    }
-    candle::safetensors::save(&tensors, out_file)?;
-    Ok(())
-}
-
 fn run_quantize(
     in_files: &[std::path::PathBuf],
     out_file: std::path::PathBuf,
     q: Quantization,
     qmode: QuantizationMode,
-    device: &Device,
 ) -> Result<()> {
     if in_files.is_empty() {
-        candle::bail!("no specified input files")
+        candle_core::bail!("no specified input files")
     }
     if let Some(extension) = out_file.extension() {
         if extension == "safetensors" {
-            candle::bail!("the generated file cannot use the safetensors extension")
+            candle_core::bail!("the generated file cannot use the safetensors extension")
         }
     }
     if let Some(extension) = in_files[0].extension() {
@@ -442,7 +306,7 @@ fn run_quantize(
     }
 
     if in_files.len() != 1 {
-        candle::bail!("only a single in-file can be used when quantizing gguf files")
+        candle_core::bail!("only a single in-file can be used when quantizing gguf files")
     }
 
     // Open the out file early so as to fail directly on missing directories etc.
@@ -451,15 +315,31 @@ fn run_quantize(
     let content = gguf_file::Content::read(&mut in_)?;
     println!("tensors: {}", content.tensor_infos.len());
 
-    let dtype = q.dtype();
+    let quantize_fn = match q {
+        Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
+        Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
+        Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
+        Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
+        Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
+        Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
+        Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
+        Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
+        Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
+        Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
+        Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
+        Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
+        Quantization::F16 => QTensor::quantize::<half::f16>,
+        Quantization::F32 => QTensor::quantize::<f32>,
+    };
 
     let qtensors = content
         .tensor_infos
         .par_iter()
         .map(|(name, _)| {
             println!("  quantizing {name}");
             let mut in_file = std::fs::File::open(&in_files[0])?;
-            let tensor = content.tensor(&mut in_file, name, device)?;
-            let tensor = qmode.quantize(name, tensor, dtype)?;
+            let tensor = content.tensor(&mut in_file, name)?;
+            let tensor = qmode.quantize(name, tensor, quantize_fn)?;
             Ok((name, tensor))
         })
         .collect::<Result<Vec<_>>>()?;
@@ -479,7 +359,6 @@ fn run_quantize(
 
 fn main() -> anyhow::Result<()> {
     let args = Args::parse();
-    let device = Device::Cpu;
     match args.command {
         Command::Ls {
             files,
@@ -491,23 +370,15 @@ fn main() -> anyhow::Result<()> {
                 if multiple_files {
                     println!("--- {file:?} ---");
                 }
-                run_ls(file, format.clone(), verbose, &device)?
+                run_ls(file, format.clone(), verbose)?
             }
         }
-        Command::Print {
-            file,
-            names,
-            format,
-            full,
-            line_width,
-        } => run_print(&file, names, format, full, line_width, &device)?,
         Command::Quantize {
             in_file,
             out_file,
             quantization,
             mode,
-        } => run_quantize(&in_file, out_file, quantization, mode, &device)?,
-        Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file, &device)?,
+        } => run_quantize(&in_file, out_file, quantization, mode)?,
     }
     Ok(())
 }
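Taken together, these tensor-tools hunks drop the `print` and `dequantize` subcommands and the explicit device plumbing on the matmul-slo side. For orientation, a plausible invocation of the remaining subcommands; the file names here are hypothetical and the exact flag spellings follow from the clap derives above:

```console
$ tensor-tools ls model.safetensors
$ tensor-tools quantize model.safetensors --out-file model.gguf --quantization q4k
```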
@@ -380,16 +380,6 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
     unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
 }
 
-#[inline]
-pub fn vs_exp_inplace(y: &mut [f32]) {
-    unsafe { ffi::vvexpf(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
-}
-
-#[inline]
-pub fn vd_exp_inplace(y: &mut [f64]) {
-    unsafe { ffi::vvexp(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
-}
-
 #[inline]
 pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
     for (&v, y) in vs.iter().zip(ys.iter_mut()) {
@@ -412,28 +402,6 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
     }
 }
 
-#[inline]
-pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
-    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
-        *y = -v
-    }
-    vs_exp_inplace(ys);
-    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
-        *y = v / (1.0 + *y)
-    }
-}
-
-#[inline]
-pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
-    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
-        *y = -v
-    }
-    vd_exp_inplace(ys);
-    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
-        *y = v / (1.0 + *y)
-    }
-}
-
 macro_rules! binary_op {
     ($fn_name:ident, $ty:ty, $accelerate_name:ident) => {
         #[inline]
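The `vs_silu`/`vd_silu` helpers removed above compute SiLU in two passes: the first loop stores $-x$, the in-place Accelerate exp (`vvexpf`/`vvexp`) exponentiates, and the second loop divides, which keeps the only transcendental call vectorized. Written out, with $\sigma$ the logistic sigmoid:

```latex
\mathrm{silu}(x) = x\,\sigma(x) = \frac{x}{1 + e^{-x}}
```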
@@ -39,14 +39,6 @@ pub trait BackendStorage: Sized {
         _params: &crate::conv::ParamsConv1D,
     ) -> Result<Self>;
 
-    fn conv_transpose1d(
-        &self,
-        _l: &Layout,
-        _kernel: &Self,
-        _kernel_l: &Layout,
-        _params: &crate::conv::ParamsConvTranspose1D,
-    ) -> Result<Self>;
-
     fn conv2d(
         &self,
         _l: &Layout,
@@ -98,19 +90,6 @@ pub trait BackendStorage: Sized {
     ) -> Result<Self>;
 
     fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()>;
 
-    #[allow(clippy::too_many_arguments)]
-    // Similar to cudaMemcpy2D, though values are in elements and not in bytes.
-    fn copy2d(
-        &self,
-        _: &mut Self,
-        _d1: usize,
-        _d2: usize,
-        _src_stride1: usize,
-        _dst_stride1: usize,
-        _src_offset: usize,
-        _dst_offset: usize,
-    ) -> Result<()>;
 }
 
 pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
@@ -127,24 +106,11 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
 
     fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
 
-    /// # Safety
-    /// This function is unsafe as it doesn't initialize the underlying data store.
-    /// The caller should ensure that the data is properly initialized as early as possible
-    /// after this call.
-    unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
-
-    fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage>;
-
     fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage>;
 
-    fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage>;
-
     fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
 
     fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
 
     fn set_seed(&self, _: u64) -> Result<()>;
 
-    /// Synchronize should block until all the operations on the device are completed.
-    fn synchronize(&self) -> Result<()>;
 }
@@ -1,4 +1,3 @@
-/// Methods for backpropagation of gradients.
 use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
 use crate::{Error, Result, Tensor, TensorId};
 use std::collections::HashMap;
@@ -16,17 +15,6 @@ fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result
     }
 }
 
-thread_local! {
-    static CANDLE_GRAD_DO_NOT_DETACH: bool = {
-        match std::env::var("CANDLE_GRAD_DO_NOT_DETACH") {
-            Ok(s) => {
-                !s.is_empty() && s != "0"
-            },
-            Err(_) => false,
-        }
-    }
-}
-
 impl Tensor {
     /// Return all the nodes that lead to this value in a topologically sorted vec, the first
     /// elements having dependencies on the latter ones, e.g. the first element if any is the
@@ -48,8 +36,6 @@ impl Tensor {
                 // Do not call recursively on the "leaf" nodes.
                 track_grad = true;
                 nodes
-            } else if node.dtype().is_int() {
-                nodes
             } else if let Some(op) = node.op() {
                 match op {
                     Op::IndexAdd(t1, t2, t3, _)
@@ -69,11 +55,6 @@ impl Tensor {
                         kernel: rhs,
                         ..
                     }
-                    | Op::ConvTranspose1D {
-                        arg: lhs,
-                        kernel: rhs,
-                        ..
-                    }
                    | Op::Conv2D {
                         arg: lhs,
                         kernel: rhs,
@@ -112,17 +93,17 @@ impl Tensor {
                     }
                     Op::Unary(_node, UnaryOp::Ceil)
                     | Op::Unary(_node, UnaryOp::Floor)
-                    | Op::Unary(_node, UnaryOp::Round)
-                    | Op::Unary(_node, UnaryOp::Sign) => nodes,
+                    | Op::Unary(_node, UnaryOp::Round) => nodes,
                     Op::Reshape(node)
-                    | Op::UpsampleNearest1D { arg: node, .. }
-                    | Op::UpsampleNearest2D { arg: node, .. }
+                    | Op::UpsampleNearest1D(node)
+                    | Op::UpsampleNearest2D(node)
                     | Op::AvgPool2D { arg: node, .. }
                     | Op::MaxPool2D { arg: node, .. }
                     | Op::Copy(node)
                     | Op::Broadcast(node)
                     | Op::Cmp(node, _)
                     | Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
+                    | Op::ToDType(node)
                     | Op::ToDevice(node)
                     | Op::Transpose(node, _, _)
                     | Op::Permute(node, _)
@@ -135,15 +116,6 @@ impl Tensor {
                         track_grad |= tg;
                         nodes
                     }
-                    Op::ToDType(node) => {
-                        if node.dtype().is_float() {
-                            let (tg, nodes) = walk(node, nodes, already_seen);
-                            track_grad |= tg;
-                            nodes
-                        } else {
-                            nodes
-                        }
-                    }
                     Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
                 }
             } else {
@@ -168,16 +140,10 @@ impl Tensor {
             if node.is_variable() {
                 continue;
             }
-            let grad = grads
-                .remove(node)
-                .expect("candle internal error - grad not populated");
-            // https://github.com/huggingface/candle/issues/1241
-            // Ideally, we would make these operations in place where possible to ensure that we
-            // do not have to allocate too often. Here we just call `.detach` to avoid computing
-            // the backprop graph of the backprop itself. This would be an issue for second order
-            // derivatives but these are out of scope at the moment.
-            let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
-            let grad = if do_not_detach { grad } else { grad.detach() };
+            let grad = grads.remove(node).unwrap();
+            // TODO: We should perform all these operations in place (or at least not track the
+            // whole graph). The only drawback would be if we wanted to support grad of grad but
+            // this is out of scope.
             if let Some(op) = node.op() {
                 match op {
                     Op::Binary(lhs, rhs, BinaryOp::Add) => {
@@ -232,45 +198,7 @@ impl Tensor {
                         let f_grad = pred.where_cond(&zeros, &grad)?;
                         *f_sum_grad = f_sum_grad.add(&f_grad)?;
                     }
-                    Op::Conv1D {
-                        arg,
-                        kernel,
-                        padding,
-                        stride,
-                        dilation,
-                    } => {
-                        // The output height for conv_transpose1d is:
-                        // (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1
-                        let grad_l_in = grad.dim(2)?;
-                        let k_size = kernel.dim(2)?;
-                        let out_size =
-                            (grad_l_in - 1) * stride + dilation * (k_size - 1) + 1 - 2 * padding;
-                        let out_padding = arg.dim(2)? - out_size;
-                        let grad_arg = grad.conv_transpose1d(
-                            kernel,
-                            *padding,
-                            out_padding,
-                            *stride,
-                            *dilation,
-                            /* groups */ 1,
-                        )?;
-                        let sum_grad = grads.or_insert(arg)?;
-                        *sum_grad = sum_grad.add(&grad_arg)?;
-
-                        let grad_kernel = arg
-                            .transpose(0, 1)?
-                            .conv1d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
-                            .transpose(0, 1)?;
-                        let sum_grad = grads.or_insert(kernel)?;
-                        let (_, _, k0) = kernel.dims3()?;
-                        let (_, _, g_k0) = grad_kernel.dims3()?;
-                        let grad_kernel = if g_k0 != k0 {
-                            grad_kernel.narrow(2, 0, k0)?
-                        } else {
-                            grad_kernel
-                        };
-                        *sum_grad = sum_grad.add(&grad_kernel)?;
-                    }
+                    Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
                     Op::Conv2D {
                         arg,
                         kernel,
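The deleted `Op::Conv1D` branch recovers the input gradient with a transposed convolution, choosing `out_padding` so the transposed output matches the original input length. A quick numeric check, under assumed values $l_{in} = 10$, $k = 3$, $s = 2$, $d = 1$, $p = 0$:

```latex
l_{grad} = \left\lfloor \frac{l_{in} + 2p - d(k-1) - 1}{s} \right\rfloor + 1 = 4, \quad
\mathit{out\_size} = (l_{grad} - 1)\,s + d(k-1) + 1 - 2p = 9, \quad
\mathit{out\_padding} = l_{in} - \mathit{out\_size} = 1.
```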
@@ -300,44 +228,11 @@ impl Tensor {
                             .conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
                             .transpose(0, 1)?;
                         let sum_grad = grads.or_insert(kernel)?;
-                        let (_, _, k0, k1) = kernel.dims4()?;
-                        let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
-                        let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
-                            grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
-                        } else {
-                            grad_kernel
-                        };
                         *sum_grad = sum_grad.add(&grad_kernel)?;
                     }
-                    Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported {
-                        op: "conv-transpose1d",
+                    Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
+                        op: "conv-transpose2d",
                     })?,
-                    Op::ConvTranspose2D {
-                        arg,
-                        kernel,
-                        padding,
-                        stride,
-                        dilation,
-                        output_padding: _output_padding,
-                    } => {
-                        let grad_arg = grad.conv2d(kernel, *padding, *dilation, *stride, 1)?;
-                        let sum_grad = grads.or_insert(arg)?;
-                        *sum_grad = sum_grad.add(&grad_arg)?;
-
-                        let grad_kernel = grad
-                            .transpose(0, 1)?
-                            .conv2d(&arg.transpose(0, 1)?, *padding, *stride, *dilation, 1)?
-                            .transpose(0, 1)?;
-                        let sum_grad = grads.or_insert(kernel)?;
-                        let (_, _, k0, k1) = kernel.dims4()?;
-                        let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
-                        let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
-                            grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
-                        } else {
-                            grad_kernel
-                        };
-                        *sum_grad = sum_grad.add(&grad_kernel)?;
-                    }
                     Op::AvgPool2D {
                         arg,
                         kernel_size,
@@ -373,39 +268,12 @@ impl Tensor {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&grad_arg)?;
                     }
-                    Op::UpsampleNearest1D { arg, target_size } => {
-                        let (_n, c, size) = arg.dims3()?;
-                        if target_size % size != 0 {
-                            crate::bail!("backward not supported for non integer upscaling factors")
-                        }
-                        let scale = target_size / size;
-
-                        let kernel = Tensor::ones((c, 1, scale), arg.dtype(), arg.device())?;
-                        let conv_sum = grad.conv1d(&kernel, 0, scale, 1, c)?;
-                        let sum_grad = grads.or_insert(arg)?;
-                        *sum_grad = conv_sum;
-                    }
-                    Op::UpsampleNearest2D {
-                        arg,
-                        target_h,
-                        target_w,
-                    } => {
-                        let (_n, c, h, w) = arg.dims4()?;
-                        if target_h % h != 0 || target_w % w != 0 {
-                            crate::bail!("backward not supported for non integer upscaling factors")
-                        }
-                        let scale_h = target_h / h;
-                        let scale_w = target_w / w;
-
-                        if scale_h != scale_w {
-                            crate::bail!("backward not supported for non uniform upscaling factors")
-                        };
-                        let kernel =
-                            Tensor::ones((c, 1, scale_h, scale_w), arg.dtype(), arg.device())?;
-                        let conv_sum = grad.conv2d(&kernel, 0, scale_h, 1, c)?;
-                        let sum_grad = grads.or_insert(arg)?;
-                        *sum_grad = conv_sum;
-                    }
+                    Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
+                        op: "upsample-nearest1d",
+                    })?,
+                    Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
+                        op: "upsample-nearest2d",
+                    })?,
                     Op::SliceScatter0(lhs, rhs, start_rhs) => {
                         let rhs_sum_grad = grads.or_insert(rhs)?;
                         let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
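For reference, the removed `UpsampleNearest2D` backward relies on each input pixel being copied to an $s \times s$ output patch, so its gradient is the patch sum:

```latex
\frac{\partial L}{\partial x_{c,i,j}} = \sum_{u=0}^{s-1} \sum_{v=0}^{s-1} \frac{\partial L}{\partial y_{c,\,si+u,\,sj+v}}
```

which is exactly the depthwise convolution with a ones kernel of shape $(c, 1, s, s)$, stride $s$ and `groups = c` used in the deleted code.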
@@ -489,6 +357,7 @@ impl Tensor {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.add(&grad)?;
                 }
+                Op::Cmp(_args, _) => {}
                 Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
                     let node = broadcast_back(arg, node, reduced_dims)?;
                     let grad = broadcast_back(arg, &grad, reduced_dims)?;
@@ -505,7 +374,7 @@ impl Tensor {
                 }
                 Op::ToDType(arg) => {
                     let sum_grad = grads.or_insert(arg)?;
-                    *sum_grad = sum_grad.add(&grad.to_dtype(arg.dtype())?)?
+                    *sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
                 }
                 Op::Copy(arg) => {
                     let sum_grad = grads.or_insert(arg)?;
@@ -578,66 +447,31 @@ impl Tensor {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.add(&arg_grad)?
                 }
-                Op::Unary(_, UnaryOp::Floor)
-                | Op::Unary(_, UnaryOp::Round)
-                | Op::Reduce(_, ReduceOp::ArgMin, _)
-                | Op::Reduce(_, ReduceOp::ArgMax, _)
-                | Op::Unary(_, UnaryOp::Sign)
-                | Op::Cmp(_, _) => {}
+                Op::Reduce(_, ReduceOp::ArgMin, _) => {}
+                Op::Reduce(_, ReduceOp::ArgMax, _) => {}
                 Op::Reshape(arg) => {
                     let arg_grad = grad.reshape(arg.dims())?;
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.add(&arg_grad)?
                 }
                 Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?,
-                Op::Unary(arg, UnaryOp::Gelu) => {
-                    let sum_grad = grads.or_insert(arg)?;
-                    let cube = arg.powf(3.)?;
-                    let tanh = (0.0356774 * &cube + (0.797885 * arg)?)?.tanh()?;
-                    let gelu_grad = (((0.5 * &tanh)?
-                        + (0.0535161 * cube + (0.398942 * arg)?)? * (1. - tanh.powf(2.)?))?
-                        + 0.5)?;
-                    *sum_grad = sum_grad.add(&(&grad * gelu_grad)?)?
-                }
-                Op::Unary(arg, UnaryOp::Erf) => {
-                    let sum_grad = grads.or_insert(arg)?;
-                    // d/dx erf(x) = 2/sqrt(pi) * e^(-x^2)
-                    let erf_grad =
-                        (2. / std::f64::consts::PI.sqrt()) * (arg.sqr()?.neg()?).exp()?;
-                    *sum_grad = sum_grad.add(&(&grad * erf_grad)?)?
-                }
-                Op::Unary(arg, UnaryOp::GeluErf) => {
-                    let sum_grad = grads.or_insert(arg)?;
-                    // d/dx gelu_erf(x) = 0.5 + 0.398942 e^(-x^2/2) x + 0.5 erf(x/sqrt(2))
-                    let neg_half_square = (arg.sqr()?.neg()? / 2.)?;
-                    let scaled_exp_arg = (0.398942 * neg_half_square.exp()? * arg)?;
-                    let arg_scaled_sqrt = (arg / 2f64.sqrt())?;
-                    let erf_scaled_sqrt = (0.5 * arg_scaled_sqrt.erf()?)?;
-                    let gelu_erf_grad = (0.5 + scaled_exp_arg + erf_scaled_sqrt)?;
-                    *sum_grad = sum_grad.add(&(&grad * gelu_erf_grad)?)?;
-                }
+                Op::Unary(_, UnaryOp::Floor) => {
+                    Err(Error::BackwardNotSupported { op: "floor" })?
+                }
+                Op::Unary(_, UnaryOp::Round) => {
+                    Err(Error::BackwardNotSupported { op: "round" })?
+                }
+                Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
+                Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
+                Op::Unary(_, UnaryOp::GeluErf) => {
+                    Err(Error::BackwardNotSupported { op: "gelu-erf" })?
+                }
                 Op::Unary(arg, UnaryOp::Relu) => {
                     let sum_grad = grads.or_insert(arg)?;
                     let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
                     *sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
                 }
-                Op::Unary(arg, UnaryOp::Silu) => {
-                    let sum_grad = grads.or_insert(arg)?;
-                    // d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
-                    let sigmoid_arg = (arg.neg()?.exp()? + 1.)?.recip()?;
-                    let silu_grad = (&sigmoid_arg * (1. + (arg * (1. - &sigmoid_arg)?)?)?)?;
-                    *sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
-                }
-                Op::Elu(arg, alpha) => {
-                    // d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
-                    let sum_grad = grads.or_insert(arg)?;
-                    let zeros = arg.zeros_like()?;
-                    let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
-                    let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
-                    let negative_exp_mask = ((negative_mask * arg.exp())? * *alpha)?;
-                    let combined_mask = (positive_mask + negative_exp_mask)?;
-                    *sum_grad = sum_grad.add(&(grad * combined_mask)?)?
-                }
+                Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
                 Op::Powf(arg, e) => {
                     let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?;
                     let sum_grad = grads.or_insert(arg)?;
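The removed `Gelu` branch implements the derivative of the tanh-approximated GELU; with $u(x) = 0.797885\,x + 0.0356774\,x^3$ it evaluates:

```latex
\frac{d}{dx}\,\mathrm{gelu}(x) \approx \tfrac{1}{2}\tanh(u) + \left(0.0535161\,x^3 + 0.398942\,x\right)\left(1 - \tanh^2(u)\right) + \tfrac{1}{2}
```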
@@ -712,38 +546,30 @@ impl Tensor {
     }
 }
 
-/// A store for gradients, associating a tensor id to the corresponding gradient tensor, used for back propagation.
 #[derive(Debug)]
 pub struct GradStore(HashMap<TensorId, Tensor>);
 
 impl GradStore {
-    /// Create a new gradient store
     fn new() -> Self {
         GradStore(HashMap::new())
     }
 
-    /// Get the gradient tensor corresponding to the given tensor id
     pub fn get_id(&self, id: TensorId) -> Option<&Tensor> {
         self.0.get(&id)
     }
 
-    /// Get the gradient tensor associated with the given tensor
     pub fn get(&self, tensor: &Tensor) -> Option<&Tensor> {
         self.0.get(&tensor.id())
     }
 
-    /// Remove the gradient tensor associated with the given tensor, returning it if it exists
     pub fn remove(&mut self, tensor: &Tensor) -> Option<Tensor> {
         self.0.remove(&tensor.id())
     }
 
-    /// Insert a gradient tensor associated with the given tensor, returning the previous gradient tensor if it existed
     pub fn insert(&mut self, tensor: &Tensor, grad: Tensor) -> Option<Tensor> {
         self.0.insert(tensor.id(), grad)
     }
 
-    /// Get the gradient tensor associated with the given tensor, or, if it does not exist,
-    /// insert a tensor of zeroes, with the same shape and type as the given tensors and return it
     fn or_insert(&mut self, tensor: &Tensor) -> Result<&mut Tensor> {
         use std::collections::hash_map::Entry;
         let grad = match self.0.entry(tensor.id()) {
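The `GradStore` above is what `Tensor::backward` returns. A minimal sketch of driving it; the scalar example and names are illustrative:

```rust
use candle_core::{Device, Result, Var};

fn main() -> Result<()> {
    // Var marks a tensor as a trainable variable so backward() keeps a gradient for it.
    let x = Var::new(3f32, &Device::Cpu)?;
    let y = (x.as_tensor() * x.as_tensor())?; // y = x^2
    let grads = y.backward()?; // a GradStore keyed by tensor id
    let dy_dx = grads.get(x.as_tensor()); // a tensor holding 2x = 6
    println!("{dy_dx:?}");
    Ok(())
}
```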
@@ -25,33 +25,6 @@ impl ParamsConv1D {
     }
 }
 
-#[derive(Debug, Clone, PartialEq, Eq)]
-pub struct ParamsConvTranspose1D {
-    pub(crate) b_size: usize,
-    pub(crate) l_in: usize,
-    pub(crate) c_out: usize,
-    pub(crate) c_in: usize,
-    pub(crate) k_size: usize,
-    pub(crate) padding: usize,
-    pub(crate) output_padding: usize,
-    pub(crate) stride: usize,
-    pub(crate) dilation: usize,
-}
-
-impl ParamsConvTranspose1D {
-    pub(crate) fn l_out(&self) -> usize {
-        (self.l_in - 1) * self.stride - 2 * self.padding
-            + self.dilation * (self.k_size - 1)
-            + self.output_padding
-            + 1
-    }
-
-    pub(crate) fn out_dims(&self) -> Vec<usize> {
-        let l_out = self.l_out();
-        vec![self.b_size, self.c_out, l_out]
-    }
-}
-
 #[derive(Debug, Clone, PartialEq, Eq, Hash)]
 pub enum CudnnFwdAlgo {
     ImplicitGemm,
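The deleted `l_out` inverts the forward 1D-convolution length formula, with `output_padding` ($p_{out}$) absorbing the length information lost to the forward stride's rounding:

```latex
l_{out} = (l_{in} - 1)\,s - 2p + d\,(k - 1) + p_{out} + 1
```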
@@ -187,72 +160,6 @@ impl Tensor {
         }
     }
 
-    fn conv_transpose1d_single_group(
-        &self,
-        kernel: &Self,
-        params: &ParamsConvTranspose1D,
-    ) -> Result<Self> {
-        let storage = self.storage().conv_transpose1d(
-            self.layout(),
-            &kernel.storage(),
-            kernel.layout(),
-            params,
-        )?;
-        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D {
-            arg,
-            kernel,
-            padding: params.padding,
-            output_padding: params.output_padding,
-            stride: params.stride,
-            dilation: params.dilation,
-        });
-        let out_dims = params.out_dims();
-        Ok(crate::tensor::from_storage(storage, out_dims, op, false))
-    }
-
-    /// Applies a 1D transposed convolution over the input tensor.
-    pub fn conv_transpose1d(
-        &self,
-        kernel: &Self,
-        padding: usize,
-        output_padding: usize,
-        stride: usize,
-        dilation: usize,
-        groups: usize,
-    ) -> Result<Self> {
-        let (c_in_k, c_out, k_size) = kernel.dims3()?;
-        let (b_size, c_in, l_in) = self.dims3()?;
-        if c_in != c_in_k {
-            crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
-        }
-        if c_in % groups != 0 {
-            crate::bail!("in_channel {c_in} is not divisible by the number of groups")
-        }
-        let params = ParamsConvTranspose1D {
-            b_size,
-            l_in,
-            k_size,
-            c_out,
-            c_in: c_in / groups,
-            padding,
-            output_padding,
-            stride,
-            dilation,
-        };
-        if groups == 1 {
-            self.conv_transpose1d_single_group(kernel, &params)
-        } else {
-            let blocks = self.chunk(groups, 1)?;
-            let kernel = kernel.chunk(groups, 0)?;
-            let blocks = blocks
-                .iter()
-                .zip(&kernel)
-                .map(|(block, kernel)| block.conv_transpose1d_single_group(kernel, &params))
-                .collect::<Result<Vec<_>>>()?;
-            Tensor::cat(&blocks, 1)
-        }
-    }
-
     fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
         let storage =
             self.storage()
@@ -1,7 +1,6 @@
 pub mod erf;
 pub mod kernels;
 
-#[allow(unused)]
 trait Cpu<const ARR: usize> {
     type Unit;
     type Array;
@@ -19,7 +18,6 @@ trait Cpu<const ARR: usize> {
     unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit);
 }
 
-#[allow(unused)]
 trait CpuF16<const ARR: usize> {
     type Unit;
     type Array;
@@ -4,13 +4,7 @@ use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
 use half::{bf16, f16};
 use rayon::prelude::*;
 
-mod utils;
-pub use utils::{
-    binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2U8,
-};
-
 const USE_IM2COL_CONV1D: bool = true;
-const USE_IM2COL_CONV1D_TR: bool = true;
 const USE_IM2COL_CONV2D: bool = true;
 
 // TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
@@ -27,18 +21,103 @@ pub enum CpuStorage {
 }
 
 #[derive(Debug, Clone)]
-pub enum CpuStorageRef<'a> {
-    U8(&'a [u8]),
-    U32(&'a [u32]),
-    I64(&'a [i64]),
-    BF16(&'a [bf16]),
-    F16(&'a [f16]),
-    F32(&'a [f32]),
-    F64(&'a [f64]),
-}
-
-#[derive(Debug, Clone)]
 pub struct CpuDevice;
 
+pub trait Map1 {
+    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>;
+
+    fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
+        match vs {
+            CpuStorage::U8(vs) => Ok(CpuStorage::U8(self.f(vs, layout)?)),
+            CpuStorage::U32(vs) => Ok(CpuStorage::U32(self.f(vs, layout)?)),
+            CpuStorage::I64(vs) => Ok(CpuStorage::I64(self.f(vs, layout)?)),
+            CpuStorage::BF16(vs) => Ok(CpuStorage::BF16(self.f(vs, layout)?)),
+            CpuStorage::F16(vs) => Ok(CpuStorage::F16(self.f(vs, layout)?)),
+            CpuStorage::F32(vs) => Ok(CpuStorage::F32(self.f(vs, layout)?)),
+            CpuStorage::F64(vs) => Ok(CpuStorage::F64(self.f(vs, layout)?)),
+        }
+    }
+}
+
+pub trait Map1Any {
+    fn f<T: WithDType, W: Fn(Vec<T>) -> CpuStorage>(
+        &self,
+        vs: &[T],
+        layout: &Layout,
+        wrap: W,
+    ) -> Result<CpuStorage>;
+
+    fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
+        match vs {
+            CpuStorage::U8(vs) => Ok(self.f(vs, layout, CpuStorage::U8)?),
+            CpuStorage::U32(vs) => Ok(self.f(vs, layout, CpuStorage::U32)?),
+            CpuStorage::I64(vs) => Ok(self.f(vs, layout, CpuStorage::I64)?),
+            CpuStorage::BF16(vs) => Ok(self.f(vs, layout, CpuStorage::BF16)?),
+            CpuStorage::F16(vs) => Ok(self.f(vs, layout, CpuStorage::F16)?),
+            CpuStorage::F32(vs) => Ok(self.f(vs, layout, CpuStorage::F32)?),
+            CpuStorage::F64(vs) => Ok(self.f(vs, layout, CpuStorage::F64)?),
+        }
+    }
+}
+
+type C = CpuStorage;
+pub trait Map2 {
+    const OP: &'static str;
+    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<T>>;
+
+    fn map(
+        &self,
+        v1: &CpuStorage,
+        l1: &Layout,
+        v2: &CpuStorage,
+        l2: &Layout,
+    ) -> Result<CpuStorage> {
+        match (v1, v2) {
+            (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+            (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
+            (C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)),
+            (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
+            (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
+            (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
+            (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
+            _ => Err(Error::DTypeMismatchBinaryOp {
+                lhs: v1.dtype(),
+                rhs: v2.dtype(),
+                op: Self::OP,
+            }
+            .bt()),
+        }
+    }
+}
+
+pub trait Map2U8 {
+    const OP: &'static str;
+    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
+
+    fn map(
+        &self,
+        v1: &CpuStorage,
+        l1: &Layout,
+        v2: &CpuStorage,
+        l2: &Layout,
+    ) -> Result<CpuStorage> {
+        match (v1, v2) {
+            (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+            (C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+            (C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+            (C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+            (C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+            (C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+            (C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+            _ => Err(Error::DTypeMismatchBinaryOp {
+                lhs: v1.dtype(),
+                rhs: v2.dtype(),
+                op: Self::OP,
+            }
+            .bt()),
+        }
+    }
+}
+
 struct Cmp(CmpOp);
 impl Map2U8 for Cmp {
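To make the dispatch pattern above concrete, a hypothetical `Map1` implementation written as if it lived in this module; it materializes a possibly strided storage into a fresh contiguous buffer:

```rust
// Assumes the surrounding module's imports: CpuStorage, Layout, Result, WithDType.
struct Materialize;

impl Map1 for Materialize {
    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
        // strided_index() visits every element in logical order whatever the
        // strides are; this is the slow path that unary_map below specializes.
        Ok(layout.strided_index().map(|i| vs[i]).collect())
    }
}

// The trait's provided `map` method then dispatches over the dtype:
// let out: CpuStorage = Materialize.map(&storage, &layout)?;
```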
@@ -286,6 +365,275 @@ impl<'a> Map1 for ReduceSum<'a> {
     }
 }
 
+pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
+    vs: &[T],
+    layout: &Layout,
+    mut f: F,
+) -> Vec<U> {
+    match layout.strided_blocks() {
+        crate::StridedBlocks::SingleBlock { start_offset, len } => vs
+            [start_offset..start_offset + len]
+            .iter()
+            .map(|&v| f(v))
+            .collect(),
+        crate::StridedBlocks::MultipleBlocks {
+            block_start_index,
+            block_len,
+        } => {
+            let mut result = Vec::with_capacity(layout.shape().elem_count());
+            // Specialize the case where block_len is one to avoid the second loop.
+            if block_len == 1 {
+                for index in block_start_index {
+                    let v = unsafe { vs.get_unchecked(index) };
+                    result.push(f(*v))
+                }
+            } else {
+                for index in block_start_index {
+                    for offset in 0..block_len {
+                        let v = unsafe { vs.get_unchecked(index + offset) };
+                        result.push(f(*v))
+                    }
+                }
+            }
+            result
+        }
+    }
+}
+
+pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
+    vs: &[T],
+    layout: &Layout,
+    mut f: F,
+    mut f_vec: FV,
+) -> Vec<U> {
+    match layout.strided_blocks() {
+        crate::StridedBlocks::SingleBlock { start_offset, len } => {
+            let mut ys: Vec<U> = Vec::with_capacity(len);
+            let ys_to_set = ys.spare_capacity_mut();
+            let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
+            f_vec(&vs[start_offset..start_offset + len], ys_to_set);
+            // SAFETY: values are all set by f_vec.
+            unsafe { ys.set_len(len) };
+            ys
+        }
+        crate::StridedBlocks::MultipleBlocks {
+            block_start_index,
+            block_len,
+        } => {
+            let el_count = layout.shape().elem_count();
+            // Specialize the case where block_len is one to avoid the second loop.
+            if block_len == 1 {
+                let mut result = Vec::with_capacity(el_count);
+                for index in block_start_index {
+                    let v = unsafe { vs.get_unchecked(index) };
+                    result.push(f(*v))
+                }
+                result
+            } else {
+                let mut ys: Vec<U> = Vec::with_capacity(el_count);
+                let ys_to_set = ys.spare_capacity_mut();
+                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
+                let mut dst_index = 0;
+                for src_index in block_start_index {
+                    let vs = &vs[src_index..src_index + block_len];
+                    let ys = &mut ys_to_set[dst_index..dst_index + block_len];
+                    f_vec(vs, ys);
+                    dst_index += block_len;
+                }
+                // SAFETY: values are all set by f_vec.
+                unsafe { ys.set_len(el_count) };
+                ys
+            }
+        }
+    }
+}
+
+// This function maps over two strided index sequences.
+pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
+    lhs_l: &Layout,
+    rhs_l: &Layout,
+    lhs: &[T],
+    rhs: &[T],
+    mut f: F,
+) -> Vec<U> {
+    match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
+        (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
+            .iter()
+            .zip(rhs[o_r1..o_r2].iter())
+            .map(|(&l, &r)| f(l, r))
+            .collect(),
+        (Some((o_l1, o_l2)), None) => {
+            // TODO: Maybe we want to avoid going through the layout twice.
+            match rhs_l.offsets_b() {
+                Some(ob) => {
+                    let mut i_in_block = 0;
+                    let mut i_right_broadcast = 0;
+                    lhs[o_l1..o_l2]
+                        .iter()
+                        .map(|&l| {
+                            let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) };
+                            i_right_broadcast += 1;
+                            if i_right_broadcast >= ob.right_broadcast {
+                                i_in_block += 1;
+                                i_right_broadcast = 0;
+                            }
+                            if i_in_block >= ob.len {
+                                i_in_block = 0
+                            }
+                            f(l, *r)
+                        })
+                        .collect()
+                }
+                None => lhs_l
+                    .strided_index()
+                    .zip(rhs_l.strided_index())
+                    .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
+                    .collect(),
+            }
+        }
+        (None, Some((o_r1, o_r2))) => {
+            // TODO: Maybe we want to avoid going through the layout twice.
+            match lhs_l.offsets_b() {
+                Some(ob) => {
+                    let mut i_in_block = 0;
+                    let mut i_right_broadcast = 0;
+                    rhs[o_r1..o_r2]
+                        .iter()
+                        .map(|&r| {
+                            let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) };
+                            i_right_broadcast += 1;
+                            if i_right_broadcast >= ob.right_broadcast {
+                                i_in_block += 1;
+                                i_right_broadcast = 0;
+                            }
+                            if i_in_block >= ob.len {
+                                i_in_block = 0
+                            }
+                            f(*l, r)
+                        })
+                        .collect()
+                }
+                None => lhs_l
+                    .strided_index()
+                    .zip(rhs_l.strided_index())
+                    .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
+                    .collect(),
+            }
+        }
+        _ => lhs_l
+            .strided_index()
+            .zip(rhs_l.strided_index())
+            .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
+            .collect(),
+    }
+}
+
+// Similar to binary_map but with vectorized variants.
+pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
+    lhs_l: &Layout,
+    rhs_l: &Layout,
+    lhs: &[T],
+    rhs: &[T],
+    mut f: F,
+    mut f_vec: FV,
+) -> Vec<T> {
+    let el_count = lhs_l.shape().elem_count();
+    match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
+        (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
+            let mut ys: Vec<T> = Vec::with_capacity(el_count);
+            let ys_to_set = ys.spare_capacity_mut();
+            let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
+            f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
+            // SAFETY: values are all set by f_vec.
+            unsafe { ys.set_len(el_count) };
+            ys
+        }
+        (Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
+            Some(ob) if ob.right_broadcast == 1 => {
+                let rhs = &rhs[ob.start..ob.start + ob.len];
+                let mut ys: Vec<T> = Vec::with_capacity(el_count);
+                let ys_to_set = ys.spare_capacity_mut();
+                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
|
||||||
|
let mut dst_i = 0;
|
||||||
|
for src_i in (o_l1..o_l2).step_by(ob.len) {
|
||||||
|
f_vec(
|
||||||
|
&lhs[src_i..src_i + ob.len],
|
||||||
|
rhs,
|
||||||
|
&mut ys_to_set[dst_i..dst_i + ob.len],
|
||||||
|
);
|
||||||
|
dst_i += ob.len;
|
||||||
|
}
|
||||||
|
// SAFETY: values are all set by f_vec.
|
||||||
|
unsafe { ys.set_len(el_count) };
|
||||||
|
ys
|
||||||
|
}
|
||||||
|
Some(ob) => {
|
||||||
|
let rhs = &rhs[ob.start..ob.start + ob.len];
|
||||||
|
let mut ys = lhs[o_l1..o_l2].to_vec();
|
||||||
|
for idx_l in 0..ob.left_broadcast {
|
||||||
|
let start = idx_l * ob.len * ob.right_broadcast;
|
||||||
|
for (i, &r) in rhs.iter().enumerate() {
|
||||||
|
let start = start + i * ob.right_broadcast;
|
||||||
|
for v in ys[start..start + ob.right_broadcast].iter_mut() {
|
||||||
|
*v = f(*v, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ys
|
||||||
|
}
|
||||||
|
None => lhs_l
|
||||||
|
.strided_index()
|
||||||
|
.zip(rhs_l.strided_index())
|
||||||
|
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
||||||
|
.collect(),
|
||||||
|
},
|
||||||
|
(None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
|
||||||
|
Some(ob) if ob.right_broadcast == 1 => {
|
||||||
|
let lhs = &lhs[ob.start..ob.start + ob.len];
|
||||||
|
let mut ys: Vec<T> = Vec::with_capacity(el_count);
|
||||||
|
let ys_to_set = ys.spare_capacity_mut();
|
||||||
|
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
|
||||||
|
let mut dst_i = 0;
|
||||||
|
for src_i in (o_r1..o_r2).step_by(ob.len) {
|
||||||
|
f_vec(
|
||||||
|
lhs,
|
||||||
|
&rhs[src_i..src_i + ob.len],
|
||||||
|
&mut ys_to_set[dst_i..dst_i + ob.len],
|
||||||
|
);
|
||||||
|
dst_i += ob.len;
|
||||||
|
}
|
||||||
|
// SAFETY: values are all set by f_vec.
|
||||||
|
unsafe { ys.set_len(el_count) };
|
||||||
|
ys
|
||||||
|
}
|
||||||
|
Some(ob) => {
|
||||||
|
let lhs = &lhs[ob.start..ob.start + ob.len];
|
||||||
|
let mut ys = rhs[o_r1..o_r2].to_vec();
|
||||||
|
for idx_l in 0..ob.left_broadcast {
|
||||||
|
let start = idx_l * ob.len * ob.right_broadcast;
|
||||||
|
for (i, &l) in lhs.iter().enumerate() {
|
||||||
|
let start = start + i * ob.right_broadcast;
|
||||||
|
for v in ys[start..start + ob.right_broadcast].iter_mut() {
|
||||||
|
*v = f(l, *v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ys
|
||||||
|
}
|
||||||
|
None => lhs_l
|
||||||
|
.strided_index()
|
||||||
|
.zip(rhs_l.strided_index())
|
||||||
|
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
||||||
|
.collect(),
|
||||||
|
},
|
||||||
|
_ => lhs_l
|
||||||
|
.strided_index()
|
||||||
|
.zip(rhs_l.strided_index())
|
||||||
|
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
||||||
|
.collect(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct Affine(f64, f64);
|
struct Affine(f64, f64);
|
||||||
|
|
||||||
impl Map1 for Affine {
|
impl Map1 for Affine {
|
||||||
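The `binary_map` fast paths added above avoid materializing a full strided index for the broadcast operand: they walk the contiguous side linearly and recompute the other side's index from two small counters derived from `offsets_b()`. Below is a minimal, self-contained sketch of that counter scheme; `Block` is a hypothetical stand-in for the descriptor candle's `Layout::offsets_b()` returns, and the sizes in `main` are made up.

    // Sketch of binary_map's broadcast index walk.
    struct Block {
        start: usize,
        len: usize,
        right_broadcast: usize,
    }

    fn broadcast_indices(ob: &Block, n: usize) -> Vec<usize> {
        let (mut i_in_block, mut i_right) = (0usize, 0usize);
        (0..n)
            .map(|_| {
                let idx = ob.start + i_in_block;
                i_right += 1;
                if i_right >= ob.right_broadcast {
                    i_in_block += 1;
                    i_right = 0;
                }
                if i_in_block >= ob.len {
                    i_in_block = 0
                }
                idx
            })
            .collect()
    }

    fn main() {
        // len = 3 values, each repeated twice, then the whole block repeats.
        let ob = Block { start: 0, len: 3, right_broadcast: 2 };
        assert_eq!(broadcast_indices(&ob, 12), [0, 0, 1, 1, 2, 2, 0, 0, 1, 1, 2, 2]);
    }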
@@ -456,11 +804,11 @@ impl<'a, I: IntDType> Map1 for Gather<'a, I> {
     fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
         let ids = match self.ids_l.contiguous_offsets() {
             Some((a, b)) => &self.ids[a..b],
-            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
+            None => Err(Error::RequiresContiguous { op: "gather" })?,
         };
         let src = match src_l.contiguous_offsets() {
             Some((a, b)) => &src[a..b],
-            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
+            None => Err(Error::RequiresContiguous { op: "gather" })?,
         };
         let dim = self.dim;
         let ids_dims = self.ids_l.dims();
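The only change in this hunk and the several that follow is dropping the trailing `.bt()` call, which in candle attaches a captured backtrace to the error before `?` propagates it; the head branch simply predates that decoration. A rough sketch of the pattern, using a simplified error type rather than candle's real `Error` (whose exact variants are an assumption here):

    // Sketch: attach a backtrace to an error before propagating it.
    #[derive(Debug)]
    enum Error {
        RequiresContiguous { op: &'static str },
        WithBacktrace { inner: Box<Error>, backtrace: String },
    }

    impl Error {
        fn bt(self) -> Self {
            Error::WithBacktrace {
                inner: Box::new(self),
                backtrace: std::backtrace::Backtrace::capture().to_string(),
            }
        }
    }

    fn main() {
        let e = Error::RequiresContiguous { op: "gather" }.bt();
        println!("{e:?}");
    }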
@@ -509,7 +857,7 @@ impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> {
     fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
         let src = match layout.contiguous_offsets() {
             Some((a, b)) => &src[a..b],
-            None => Err(Error::RequiresContiguous { op: "index-select" }.bt())?,
+            None => Err(Error::RequiresContiguous { op: "index-select" })?,
         };
         let dim = self.dim;
         let n_ids = match self.ids_l.dims() {
@@ -565,7 +913,7 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
         let mut dst = vec![T::zero(); dst_len];
         copy_strided_src_(v1, &mut dst, 0, l1);
         let src = match src_l.contiguous_offsets() {
-            None => Err(Error::RequiresContiguous { op: "scatter-add" }.bt())?,
+            None => Err(Error::RequiresContiguous { op: "scatter-add" })?,
             Some((o1, o2)) => &src[o1..o2],
         };

@@ -581,7 +929,7 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {

         let ids = match self.ids_l.contiguous_offsets() {
             Some((a, b)) => &self.ids[a..b],
-            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
+            None => Err(Error::RequiresContiguous { op: "gather" })?,
         };
         for left_i in 0..ids_left_len {
             let start_ids_idx = left_i * ids_right_len * ids_dim_len;
@@ -623,7 +971,7 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
         let mut dst = vec![T::zero(); dst_len];
         copy_strided_src_(v1, &mut dst, 0, l1);
         let src = match src_l.contiguous_offsets() {
-            None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
+            None => Err(Error::RequiresContiguous { op: "index-add" })?,
             Some((o1, o2)) => &src[o1..o2],
         };
         let dim = self.dim;
@@ -674,26 +1022,6 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
     }
 }

-#[allow(clippy::too_many_arguments)]
-fn copy2d_<T: Copy>(
-    src: &[T],
-    dst: &mut [T],
-    d1: usize,
-    d2: usize,
-    src_stride1: usize,
-    dst_stride1: usize,
-    src_offset: usize,
-    dst_offset: usize,
-) {
-    for i1 in 0..d1 {
-        let dst_idx = i1 * dst_stride1 + dst_offset;
-        let src_idx = i1 * src_stride1 + src_offset;
-        let dst = &mut dst[dst_idx..dst_idx + d2];
-        let src = &src[src_idx..src_idx + d2];
-        dst.copy_from_slice(src)
-    }
-}
-
 fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) {
     match src_l.strided_blocks() {
         crate::StridedBlocks::SingleBlock { start_offset, len } => {
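The removed `copy2d_` helper copies a d1 x d2 sub-matrix between two strided buffers, one contiguous row of d2 elements at a time. A small self-contained usage sketch, with the helper inlined so it runs on its own:

    fn copy2d_<T: Copy>(
        src: &[T], dst: &mut [T],
        d1: usize, d2: usize,
        src_stride1: usize, dst_stride1: usize,
        src_offset: usize, dst_offset: usize,
    ) {
        for i1 in 0..d1 {
            let dst_idx = i1 * dst_stride1 + dst_offset;
            let src_idx = i1 * src_stride1 + src_offset;
            // Each row is contiguous, so it can be memcpy'd in one go.
            dst[dst_idx..dst_idx + d2].copy_from_slice(&src[src_idx..src_idx + d2]);
        }
    }

    fn main() {
        // Copy the top-left 2x2 block of a 3x4 row-major matrix into a 2x2 buffer.
        let src: Vec<i32> = (0..12).collect(); // row stride 4
        let mut dst = vec![0; 4]; // row stride 2
        copy2d_(&src, &mut dst, 2, 2, 4, 2, 0, 0);
        assert_eq!(dst, [0, 1, 4, 5]);
    }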
@@ -928,103 +1256,6 @@ impl Map1 for Im2Col {
     }
 }

-struct Col2Im1D {
-    stride: usize,
-}
-
-impl Map1 for Col2Im1D {
-    fn f<T: WithDType>(&self, col: &[T], l: &Layout) -> Result<Vec<T>> {
-        let (b_size, l_in, c_out, k_size) = l.shape().dims4()?;
-        let stride = self.stride;
-        let l_out = (l_in - 1) * stride + k_size;
-        let mut im = vec![T::zero(); b_size * c_out * l_out];
-        let (dst_s0, dst_s1) = (c_out * l_out, l_out);
-        let (src_s0, src_s1, src_s2) = (c_out * k_size * l_in, c_out * k_size, k_size);
-        for l_in_i in 0..l_in {
-            for k_i in 0..k_size {
-                let l_out_i = l_in_i * stride + k_i;
-                for b_i in 0..b_size {
-                    for c_i in 0..c_out {
-                        let dst_idx = b_i * dst_s0 + c_i * dst_s1 + l_out_i;
-                        let src_idx = b_i * src_s0 + l_in_i * src_s1 + c_i * src_s2 + k_i;
-                        im[dst_idx] += col[src_idx]
-                    }
-                }
-            }
-        }
-        Ok(im)
-    }
-}
-
-struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
-
-impl<'a> Map2 for ConvTranspose1D<'a> {
-    const OP: &'static str = "conv_transpose1d";
-    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
-        let p = self.0;
-        let inp = &inp[inp_l.start_offset()..];
-        let k = &k[k_l.start_offset()..];
-        let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
-        let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
-        let l_out = p.l_out();
-
-        // Output shape: [b_size, c_out, l_out].
-        let dst_elems = p.c_out * l_out * p.b_size;
-        let dst = vec![T::zero(); dst_elems];
-        let dst_s0 = p.c_out * l_out;
-        let dst_s1 = l_out;
-        let dst_s2 = 1;
-
-        // TODO: Avoid making this copy if `inp` already has the appropriate layout.
-        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
-        let cont_s0 = p.l_in * p.c_in;
-        let cont_s1 = p.c_in;
-        for b_idx in 0..p.b_size {
-            for l_idx in 0..p.l_in {
-                for c_idx in 0..p.c_in {
-                    let src_idx = b_idx * inp_s0 + c_idx * inp_s1 + l_idx * inp_s2;
-                    let dst_idx = b_idx * cont_s0 + l_idx * cont_s1 + c_idx;
-                    inp_cont[dst_idx] = inp[src_idx]
-                }
-            }
-        }
-
-        for k_idx in 0..p.k_size {
-            (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
-                let k_cont = (0..p.c_in)
-                    .map(|c_in_idx| k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_idx * k_s2])
-                    .collect::<Vec<_>>();
-                for b_idx in 0..p.b_size {
-                    for l_idx in 0..p.l_in {
-                        let out_idx = l_idx * p.stride + k_idx * p.dilation;
-                        if out_idx < p.padding {
-                            continue;
-                        }
-                        let out_idx = out_idx - p.padding;
-                        if out_idx < l_out {
-                            let inp_cont = &inp_cont[b_idx * cont_s0 + l_idx * cont_s1..];
-                            let dst_idx = b_idx * dst_s0 + out_idx * dst_s2 + dst_c_idx * dst_s1;
-                            let mut d = T::zero();
-                            unsafe {
-                                T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
-                            }
-                            let dst_p = dst.as_ptr();
-                            // Safety: dst_idx are uniques per dst_c_idx which is used to
-                            // parallelise the different tasks so no two threads can try to
-                            // write at the same location.
-                            unsafe {
-                                let ptr = dst_p.add(dst_idx) as *mut T;
-                                *ptr += d
-                            }
-                        }
-                    }
-                }
-            })
-        }
-        Ok(dst)
-    }
-}
-
 struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);

 impl<'a> Map2 for Conv2D<'a> {
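For the removed `Col2Im1D` kernel, the output length follows the usual transposed-convolution relation, specialized to the no-padding, no-dilation case this fast path guards on: l_out = (l_in - 1) * stride + k_size. A quick check of that arithmetic with made-up sizes:

    fn main() {
        // l_out = (l_in - 1) * stride + k_size, as in the removed Col2Im1D kernel.
        let (l_in, stride, k_size) = (5usize, 2usize, 3usize);
        let l_out = (l_in - 1) * stride + k_size;
        assert_eq!(l_out, 11); // 4 * 2 + 3
    }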
@@ -1215,30 +1446,6 @@ impl MatMul {
         }))
         .bt()
     }
-
-    fn ab_skip(&self, lhs_l: &Layout, rhs_l: &Layout) -> Result<(usize, usize)> {
-        let lhs_stride = lhs_l.stride();
-        let rhs_stride = rhs_l.stride();
-        let rank = lhs_stride.len();
-        let (_b, m, n, k) = self.0;
-        let a_skip: usize = match lhs_stride[..rank - 2] {
-            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
-            [_, stride] if lhs_l.dims()[0] == 1 => stride,
-            [stride, _] if lhs_l.dims()[1] == 1 => stride,
-            [stride] => stride,
-            [] => m * k,
-            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
-        };
-        let b_skip: usize = match rhs_stride[..rank - 2] {
-            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
-            [_, stride] if rhs_l.dims()[0] == 1 => stride,
-            [stride, _] if rhs_l.dims()[1] == 1 => stride,
-            [stride] => stride,
-            [] => n * k,
-            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
-        };
-        Ok((a_skip, b_skip))
-    }
 }

 impl Map2 for MatMul {
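The deleted `ab_skip` helper centralizes how the per-batch element skip is derived from the strides above the last two dimensions; the inlined versions that replace it in the hunks below keep only the contiguous cases. A standalone sketch of the same match, assuming plain `stride`/`dims` slices instead of candle's `Layout` (names here are illustrative):

    // Sketch: per-batch skip for a [b, m, k] operand given its strides.
    fn batch_skip(stride: &[usize], dims: &[usize], m: usize, k: usize) -> Option<usize> {
        let rank = stride.len();
        match stride[..rank - 2] {
            [s1, s2] if s1 == s2 * dims[1] => Some(s2), // contiguous [b1, b2, m, k]
            [s] => Some(s),                             // single batch dimension
            [] => Some(m * k),                          // plain [m, k] matmul
            _ => None,                                  // needs a contiguous copy first
        }
    }

    fn main() {
        // A contiguous [2, 3, 4] lhs has strides [12, 4, 1]: batch skip = m * k = 12.
        assert_eq!(batch_skip(&[12, 4, 1], &[2, 3, 4], 3, 4), Some(12));
        // Rank-2 case: no batch dimension, skip falls back to m * k.
        assert_eq!(batch_skip(&[4, 1], &[3, 4], 3, 4), Some(12));
    }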
@@ -1272,7 +1479,18 @@ impl Map2 for MatMul {
         let rhs_cs = rhs_stride[rank - 1];
         let rhs_rs = rhs_stride[rank - 2];

-        let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
+        let a_skip: usize = match lhs_stride[..rank - 2] {
+            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
+            [stride] => stride,
+            [] => m * k,
+            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
+        };
+        let b_skip: usize = match rhs_stride[..rank - 2] {
+            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
+            [stride] => stride,
+            [] => n * k,
+            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
+        };
         let c_skip: usize = m * n;

         let dst_shape: Shape = (m, n).into();
@@ -1332,8 +1550,20 @@ impl Map2 for MatMul {

         let lhs_stride = lhs_l.stride();
         let rhs_stride = rhs_l.stride();
+        let rank = lhs_stride.len();

-        let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
+        let a_skip: usize = match lhs_stride[..rank - 2] {
+            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
+            [stride] => stride,
+            [] => m * k,
+            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
+        };
+        let b_skip: usize = match rhs_stride[..rank - 2] {
+            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
+            [stride] => stride,
+            [] => n * k,
+            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
+        };
         let c_skip: usize = m * n;

         let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
@@ -1341,7 +1571,7 @@ impl Map2 for MatMul {
         let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
         let lhs_m2 = lhs_stride[lhs_stride.len() - 2];

-        let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
+        let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
             (n as i32, b'N')
         } else if rhs_m1 == k && rhs_m2 == 1 {
             (k as i32, b'T')
@@ -1349,7 +1579,7 @@ impl Map2 for MatMul {
             Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
         };
         // The b tensor has dims batching, m, k (lhs)
-        let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
+        let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
             (k as i32, b'N')
         } else if lhs_m1 == m && lhs_m2 == 1 {
             (m as i32, b'T')
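The tightened conditions above pick the BLAS leading dimension and transpose flag strictly from the last two strides: a row-major operand passes through as 'N' with lda = n, a column-major one as 'T' with lda = k, and the dropped `|| n == 1` / `|| k == 1` escapes had additionally accepted degenerate single-row or single-column layouts. A minimal sketch of the same selection, with made-up stride inputs:

    // Sketch: choose (leading dimension, transpose flag) for a k x n rhs
    // from its last two strides, mirroring the stricter check in this diff.
    fn lda_transa(rhs_m1: usize, rhs_m2: usize, n: usize, k: usize) -> Option<(i32, u8)> {
        if rhs_m1 == 1 && rhs_m2 == n {
            Some((n as i32, b'N')) // row-major: rows are contiguous
        } else if rhs_m1 == k && rhs_m2 == 1 {
            Some((k as i32, b'T')) // column-major: columns are contiguous
        } else {
            None // non-contiguous: the caller must materialize a copy
        }
    }

    fn main() {
        let (n, k) = (4, 3);
        assert_eq!(lda_transa(1, n, n, k), Some((4, b'N')));
        assert_eq!(lda_transa(k, 1, n, k), Some((3, b'T')));
    }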
@@ -1423,8 +1653,20 @@ impl Map2 for MatMul {

         let lhs_stride = lhs_l.stride();
         let rhs_stride = rhs_l.stride();
+        let rank = lhs_stride.len();

-        let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
+        let a_skip: usize = match lhs_stride[..rank - 2] {
+            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
+            [stride] => stride,
+            [] => m * k,
+            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
+        };
+        let b_skip: usize = match rhs_stride[..rank - 2] {
+            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
+            [stride] => stride,
+            [] => n * k,
+            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
+        };
         let c_skip: usize = m * n;

         let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
@@ -1432,7 +1674,7 @@ impl Map2 for MatMul {
         let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
         let lhs_m2 = lhs_stride[lhs_stride.len() - 2];

-        let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
+        let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
             (n as i32, b'N')
         } else if rhs_m1 == k && rhs_m2 == 1 {
             (k as i32, b'T')
@@ -1440,7 +1682,7 @@ impl Map2 for MatMul {
             Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
         };
         // The b tensor has dims batching, m, k (lhs)
-        let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
+        let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
             (k as i32, b'N')
         } else if lhs_m1 == m && lhs_m2 == 1 {
             (m as i32, b'T')
@@ -2112,48 +2354,6 @@ impl BackendStorage for CpuStorage {
         }
     }

-    fn copy2d(
-        &self,
-        dst: &mut Self,
-        d1: usize,
-        d2: usize,
-        src_s: usize,
-        dst_s: usize,
-        src_o: usize,
-        dst_o: usize,
-    ) -> Result<()> {
-        match (self, dst) {
-            (Self::U8(src), Self::U8(dst)) => copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o),
-            (Self::U32(src), Self::U32(dst)) => {
-                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
-            }
-            (Self::I64(src), Self::I64(dst)) => {
-                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
-            }
-            (Self::BF16(src), Self::BF16(dst)) => {
-                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
-            }
-            (Self::F16(src), Self::F16(dst)) => {
-                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
-            }
-            (Self::F32(src), Self::F32(dst)) => {
-                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
-            }
-            (Self::F64(src), Self::F64(dst)) => {
-                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
-            }
-            (_, dst) => {
-                return Err(Error::DTypeMismatchBinaryOp {
-                    lhs: self.dtype(),
-                    rhs: dst.dtype(),
-                    op: "copy2d",
-                }
-                .bt());
-            }
-        }
-        Ok(())
-    }
-
     fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
         match (self, dst) {
             (Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
@@ -2222,10 +2422,7 @@ impl BackendStorage for CpuStorage {
             col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
         } else {
             // Make the kernel contiguous if not already the case.
-            let mut kernel_c = unsafe {
-                self.device()
-                    .alloc_uninit(kernel_l.shape(), kernel.dtype())?
-            };
+            let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
             kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
             let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
                 .transpose(1, 2)?
@@ -2233,66 +2430,11 @@ impl BackendStorage for CpuStorage {
             col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
         };
         let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?;
-        let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
+        let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
         res.copy_strided_src(&mut res_t, 0, &res_l)?;
         Ok(res_t)
     }

-    fn conv_transpose1d(
-        &self,
-        l: &Layout,
-        kernel: &Self,
-        kernel_l: &Layout,
-        params: &crate::conv::ParamsConvTranspose1D,
-    ) -> Result<Self> {
-        let can_use_col2im = kernel_l.is_contiguous()
-            && params.dilation == 1
-            && params.padding == 0
-            && params.output_padding == 0;
-        if USE_IM2COL_CONV1D_TR && can_use_col2im {
-            let (b_size, c_in, l_in) = l.shape().dims3()?;
-            let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
-            if !kernel_l.is_contiguous() {
-                crate::bail!(
-                    "convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}"
-                )
-            }
-            if c_in != c_in2 {
-                crate::bail!(
-                    "convtr1d: shape mismatch on c_in {:?} {:?}",
-                    l.shape(),
-                    kernel_l.shape()
-                )
-            }
-            let col = {
-                // This merges the last two dimensions of the kernel together.
-                let kernel_l_mm = Layout::new(
-                    (b_size, c_in, k_size * c_out).into(),
-                    vec![0, k_size * c_out, 1],
-                    kernel_l.start_offset(),
-                );
-                self.matmul(
-                    kernel,
-                    (
-                        b_size,
-                        /* m */ l_in,
-                        /* n */ c_out * k_size,
-                        /* k */ c_in,
-                    ),
-                    &l.transpose(1, 2)?,
-                    &kernel_l_mm,
-                )?
-            };
-            let col_l = Layout::contiguous((b_size, l_in, c_out, k_size));
-            Col2Im1D {
-                stride: params.stride,
-            }
-            .map(&col, &col_l)
-        } else {
-            ConvTranspose1D(params).map(self, l, kernel, kernel_l)
-        }
-    }
-
     fn conv2d(
         &self,
         l: &Layout,
@@ -2324,10 +2466,7 @@ impl BackendStorage for CpuStorage {
             col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
         } else {
             // Make the kernel contiguous if not already the case.
-            let mut kernel_c = unsafe {
-                self.device()
-                    .alloc_uninit(kernel_l.shape(), kernel.dtype())?
-            };
+            let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
             kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
             let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
                 .transpose(1, 2)?
@@ -2337,7 +2476,7 @@ impl BackendStorage for CpuStorage {
         let res_l = Layout::contiguous((b, h_out, w_out, params.c_out))
             .transpose(1, 2)?
             .transpose(1, 3)?;
-        let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
+        let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
         res.copy_strided_src(&mut res_t, 0, &res_l)?;
         Ok(res_t)
     }
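These hunks trade `alloc_uninit` (allocate without initializing, then overwrite everything via `copy_strided_src`) for `zeros_impl`, which zero-fills first; since the copy overwrites every element anyway, the zeroing is redundant work, which is the cost side of this change. A hedged sketch of the two allocation styles for a plain `Vec<f32>` (the candle device methods wrap the same idea per dtype):

    fn main() {
        let n = 1024;

        // zeros_impl-style: allocate and zero-fill; always safe, one extra pass.
        let zeroed: Vec<f32> = vec![0.0; n];

        // alloc_uninit-style: reserve capacity and let a later copy set every
        // element. Calling set_len before the writes is only defensible because
        // f32 has no drop glue and every bit pattern is a valid value, and even
        // then it is the risky pattern the removed code itself flagged.
        let mut uninit: Vec<f32> = Vec::with_capacity(n);
        #[allow(clippy::uninit_vec)]
        unsafe { uninit.set_len(n) };
        uninit.copy_from_slice(&zeroed); // stand-in for copy_strided_src

        assert_eq!(zeroed, uninit);
    }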
@@ -2357,7 +2496,7 @@ impl BackendStorage for CpuStorage {
             Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
             Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
             Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
-            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select").bt()),
+            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select")),
         }
     }

@@ -2366,7 +2505,7 @@ impl BackendStorage for CpuStorage {
             Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l),
             Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l),
             Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l),
-            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather").bt()),
+            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather")),
         }
     }

@@ -2383,7 +2522,7 @@ impl BackendStorage for CpuStorage {
             Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
             Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
            Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
-            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()),
+            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add")),
         }
     }

@@ -2400,25 +2539,25 @@ impl BackendStorage for CpuStorage {
             Self::U8(ids) => {
                 let ids = match ids_l.contiguous_offsets() {
                     Some((a, b)) => &ids[a..b],
-                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
+                    None => Err(Error::RequiresContiguous { op: "index-add" })?,
                 };
                 IndexAdd { ids, dim }.map(self, l, src, src_l)
             }
             Self::U32(ids) => {
                 let ids = match ids_l.contiguous_offsets() {
                     Some((a, b)) => &ids[a..b],
-                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
+                    None => Err(Error::RequiresContiguous { op: "index-add" })?,
                 };
                 IndexAdd { ids, dim }.map(self, l, src, src_l)
             }
             Self::I64(ids) => {
                 let ids = match ids_l.contiguous_offsets() {
                     Some((a, b)) => &ids[a..b],
-                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
+                    None => Err(Error::RequiresContiguous { op: "index-add" })?,
                 };
                 IndexAdd { ids, dim }.map(self, l, src, src_l)
             }
-            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add").bt()),
+            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add")),
         }
     }

@@ -2456,18 +2595,10 @@ impl BackendDevice for CpuDevice {
         true
     }

-    fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
-        Ok(T::to_cpu_storage(s))
-    }
-
     fn storage_from_cpu_storage(&self, s: &CpuStorage) -> Result<Self::Storage> {
         Ok(s.clone())
     }

-    fn storage_from_cpu_storage_owned(&self, s: CpuStorage) -> Result<Self::Storage> {
-        Ok(s)
-    }
-
     fn new(_: usize) -> Result<Self> {
         Ok(Self)
     }
@@ -2569,53 +2700,6 @@ impl BackendDevice for CpuDevice {
         }
     }

-    #[allow(clippy::uninit_vec)]
-    unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
-        let elem_count = shape.elem_count();
-        // The code below is highly unsafe but hopefully not directly unsound as we only consider
-        // types that are Copy, not Drop, and for which all bit patterns are proper values.
-        // It's still pretty risky, see the following for more details:
-        // https://github.com/rust-lang/rust-clippy/issues/4483
-        let storage = match dtype {
-            DType::U8 => {
-                let mut v = Vec::with_capacity(elem_count);
-                v.set_len(elem_count);
-                CpuStorage::U8(v)
-            }
-            DType::U32 => {
-                let mut v = Vec::with_capacity(elem_count);
-                v.set_len(elem_count);
-                CpuStorage::U32(v)
-            }
-            DType::I64 => {
-                let mut v = Vec::with_capacity(elem_count);
-                v.set_len(elem_count);
-                CpuStorage::I64(v)
-            }
-            DType::BF16 => {
-                let mut v = Vec::with_capacity(elem_count);
-                v.set_len(elem_count);
-                CpuStorage::BF16(v)
-            }
-            DType::F16 => {
-                let mut v = Vec::with_capacity(elem_count);
-                v.set_len(elem_count);
-                CpuStorage::F16(v)
-            }
-            DType::F32 => {
-                let mut v = Vec::with_capacity(elem_count);
-                v.set_len(elem_count);
-                CpuStorage::F32(v)
-            }
-            DType::F64 => {
-                let mut v = Vec::with_capacity(elem_count);
-                v.set_len(elem_count);
-                CpuStorage::F64(v)
-            }
-        };
-        Ok(storage)
-    }
-
     fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
         let elem_count = shape.elem_count();
         let storage = match dtype {
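The index-select, gather, scatter-add, and index-add dispatches above accept only integer id dtypes (U8, U32, I64) and error out for everything else. A compact sketch of the same guard, with a hypothetical DType enum standing in for candle's:

    // Sketch: only integer dtypes may serve as index tensors.
    #[derive(Clone, Copy, Debug, PartialEq)]
    enum DType { U8, U32, I64, BF16, F16, F32, F64 }

    fn check_index_dtype(dt: DType, op: &'static str) -> Result<(), String> {
        match dt {
            DType::U8 | DType::U32 | DType::I64 => Ok(()),
            _ => Err(format!("unsupported dtype {dt:?} for op {op}")),
        }
    }

    fn main() {
        assert!(check_index_dtype(DType::U32, "gather").is_ok());
        assert!(check_index_dtype(DType::F32, "gather").is_err());
    }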
@@ -2643,10 +2727,6 @@ impl BackendDevice for CpuDevice {
         };
         Ok(storage)
     }
-
-    fn synchronize(&self) -> Result<()> {
-        Ok(())
-    }
 }

 #[macro_export]
@@ -1,350 +0,0 @@
-/// Helper functions to write CPU kernels.
-use crate::backend::BackendStorage;
-use crate::{Error, Layout, Result, WithDType};
-
-type C = super::CpuStorage;
-pub trait Map1 {
-    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>;
-
-    fn map(&self, vs: &C, layout: &Layout) -> Result<C> {
-        match vs {
-            C::U8(vs) => Ok(C::U8(self.f(vs, layout)?)),
-            C::U32(vs) => Ok(C::U32(self.f(vs, layout)?)),
-            C::I64(vs) => Ok(C::I64(self.f(vs, layout)?)),
-            C::BF16(vs) => Ok(C::BF16(self.f(vs, layout)?)),
-            C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)),
-            C::F32(vs) => Ok(C::F32(self.f(vs, layout)?)),
-            C::F64(vs) => Ok(C::F64(self.f(vs, layout)?)),
-        }
-    }
-}
-
-pub trait Map1Any {
-    fn f<T: WithDType, W: Fn(Vec<T>) -> C>(&self, vs: &[T], layout: &Layout, wrap: W) -> Result<C>;
-
-    fn map(&self, vs: &C, layout: &Layout) -> Result<C> {
-        match vs {
-            C::U8(vs) => Ok(self.f(vs, layout, C::U8)?),
-            C::U32(vs) => Ok(self.f(vs, layout, C::U32)?),
-            C::I64(vs) => Ok(self.f(vs, layout, C::I64)?),
-            C::BF16(vs) => Ok(self.f(vs, layout, C::BF16)?),
-            C::F16(vs) => Ok(self.f(vs, layout, C::F16)?),
-            C::F32(vs) => Ok(self.f(vs, layout, C::F32)?),
-            C::F64(vs) => Ok(self.f(vs, layout, C::F64)?),
-        }
-    }
-}
-
-pub trait Map2 {
-    const OP: &'static str;
-    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<T>>;
-
-    fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result<C> {
-        match (v1, v2) {
-            (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
-            (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
-            (C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)),
-            (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
-            (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
-            (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
-            (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
-            _ => Err(Error::DTypeMismatchBinaryOp {
-                lhs: v1.dtype(),
-                rhs: v2.dtype(),
-                op: Self::OP,
-            }
-            .bt()),
-        }
-    }
-}
-
-pub trait Map2U8 {
-    const OP: &'static str;
-    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
-
-    fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result<C> {
-        match (v1, v2) {
-            (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
-            (C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
-            (C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
-            (C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
-            (C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
-            (C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
-            (C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
-            _ => Err(Error::DTypeMismatchBinaryOp {
-                lhs: v1.dtype(),
-                rhs: v2.dtype(),
-                op: Self::OP,
-            }
-            .bt()),
-        }
-    }
-}
-
-pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
-    lhs_l: &Layout,
-    rhs_l: &Layout,
-    lhs: &[T],
-    rhs: &[T],
-    mut f: F,
-) -> Vec<U> {
-    match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
-        (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
-            .iter()
-            .zip(rhs[o_r1..o_r2].iter())
-            .map(|(&l, &r)| f(l, r))
-            .collect(),
-        (Some((o_l1, o_l2)), None) => {
-            // TODO: Maybe we want to avoid going through the layout twice.
-            match rhs_l.offsets_b() {
-                Some(ob) => {
-                    let mut i_in_block = 0;
-                    let mut i_right_broadcast = 0;
-                    lhs[o_l1..o_l2]
-                        .iter()
-                        .map(|&l| {
-                            let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) };
-                            i_right_broadcast += 1;
-                            if i_right_broadcast >= ob.right_broadcast {
-                                i_in_block += 1;
-                                i_right_broadcast = 0;
-                            }
-                            if i_in_block >= ob.len {
-                                i_in_block = 0
-                            }
-                            f(l, *r)
-                        })
-                        .collect()
-                }
-                None => lhs_l
-                    .strided_index()
-                    .zip(rhs_l.strided_index())
-                    .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
-                    .collect(),
-            }
-        }
-        (None, Some((o_r1, o_r2))) => {
-            // TODO: Maybe we want to avoid going through the layout twice.
-            match lhs_l.offsets_b() {
-                Some(ob) => {
-                    let mut i_in_block = 0;
-                    let mut i_right_broadcast = 0;
-                    rhs[o_r1..o_r2]
-                        .iter()
-                        .map(|&r| {
-                            let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) };
-                            i_right_broadcast += 1;
-                            if i_right_broadcast >= ob.right_broadcast {
-                                i_in_block += 1;
-                                i_right_broadcast = 0;
-                            }
-                            if i_in_block >= ob.len {
-                                i_in_block = 0
-                            }
-                            f(*l, r)
-                        })
-                        .collect()
-                }
-                None => lhs_l
-                    .strided_index()
-                    .zip(rhs_l.strided_index())
-                    .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
-                    .collect(),
-            }
-        }
-        _ => lhs_l
-            .strided_index()
-            .zip(rhs_l.strided_index())
-            .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
-            .collect(),
-    }
-}
-
-// Similar to binary_map but with vectorized variants.
-pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
-    lhs_l: &Layout,
-    rhs_l: &Layout,
-    lhs: &[T],
-    rhs: &[T],
-    mut f: F,
-    mut f_vec: FV,
-) -> Vec<T> {
-    let el_count = lhs_l.shape().elem_count();
-    match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
-        (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
-            let mut ys: Vec<T> = Vec::with_capacity(el_count);
-            let ys_to_set = ys.spare_capacity_mut();
-            let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
-            f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
-            // SAFETY: values are all set by f_vec.
-            unsafe { ys.set_len(el_count) };
-            ys
-        }
-        (Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
-            Some(ob) if ob.right_broadcast == 1 => {
-                let rhs = &rhs[ob.start..ob.start + ob.len];
-                let mut ys: Vec<T> = Vec::with_capacity(el_count);
-                let ys_to_set = ys.spare_capacity_mut();
-                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
-                let mut dst_i = 0;
-                for src_i in (o_l1..o_l2).step_by(ob.len) {
-                    f_vec(
-                        &lhs[src_i..src_i + ob.len],
-                        rhs,
-                        &mut ys_to_set[dst_i..dst_i + ob.len],
-                    );
-                    dst_i += ob.len;
-                }
-                // SAFETY: values are all set by f_vec.
-                unsafe { ys.set_len(el_count) };
-                ys
-            }
-            Some(ob) => {
-                let rhs = &rhs[ob.start..ob.start + ob.len];
-                let mut ys = lhs[o_l1..o_l2].to_vec();
-                for idx_l in 0..ob.left_broadcast {
-                    let start = idx_l * ob.len * ob.right_broadcast;
-                    for (i, &r) in rhs.iter().enumerate() {
-                        let start = start + i * ob.right_broadcast;
-                        for v in ys[start..start + ob.right_broadcast].iter_mut() {
-                            *v = f(*v, r)
-                        }
-                    }
-                }
-                ys
-            }
-            None => lhs_l
-                .strided_index()
-                .zip(rhs_l.strided_index())
-                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
-                .collect(),
-        },
-        (None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
-            Some(ob) if ob.right_broadcast == 1 => {
-                let lhs = &lhs[ob.start..ob.start + ob.len];
-                let mut ys: Vec<T> = Vec::with_capacity(el_count);
-                let ys_to_set = ys.spare_capacity_mut();
-                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
-                let mut dst_i = 0;
-                for src_i in (o_r1..o_r2).step_by(ob.len) {
-                    f_vec(
-                        lhs,
-                        &rhs[src_i..src_i + ob.len],
-                        &mut ys_to_set[dst_i..dst_i + ob.len],
-                    );
-                    dst_i += ob.len;
-                }
-                // SAFETY: values are all set by f_vec.
-                unsafe { ys.set_len(el_count) };
-                ys
-            }
-            Some(ob) => {
-                let lhs = &lhs[ob.start..ob.start + ob.len];
-                let mut ys = rhs[o_r1..o_r2].to_vec();
-                for idx_l in 0..ob.left_broadcast {
-                    let start = idx_l * ob.len * ob.right_broadcast;
-                    for (i, &l) in lhs.iter().enumerate() {
-                        let start = start + i * ob.right_broadcast;
-                        for v in ys[start..start + ob.right_broadcast].iter_mut() {
-                            *v = f(l, *v)
-                        }
-                    }
-                }
-                ys
-            }
-            None => lhs_l
-                .strided_index()
-                .zip(rhs_l.strided_index())
-                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
-                .collect(),
-        },
-        _ => lhs_l
-            .strided_index()
-            .zip(rhs_l.strided_index())
-            .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
-            .collect(),
-    }
-}
-
-pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
-    vs: &[T],
-    layout: &Layout,
-    mut f: F,
-) -> Vec<U> {
-    match layout.strided_blocks() {
-        crate::StridedBlocks::SingleBlock { start_offset, len } => vs
-            [start_offset..start_offset + len]
-            .iter()
-            .map(|&v| f(v))
-            .collect(),
-        crate::StridedBlocks::MultipleBlocks {
-            block_start_index,
-            block_len,
-        } => {
-            let mut result = Vec::with_capacity(layout.shape().elem_count());
-            // Specialize the case where block_len is one to avoid the second loop.
-            if block_len == 1 {
-                for index in block_start_index {
-                    let v = unsafe { vs.get_unchecked(index) };
-                    result.push(f(*v))
-                }
-            } else {
-                for index in block_start_index {
-                    for offset in 0..block_len {
-                        let v = unsafe { vs.get_unchecked(index + offset) };
-                        result.push(f(*v))
-                    }
-                }
-            }
-            result
-        }
-    }
-}
-
-pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
-    vs: &[T],
-    layout: &Layout,
-    mut f: F,
-    mut f_vec: FV,
-) -> Vec<U> {
-    match layout.strided_blocks() {
-        crate::StridedBlocks::SingleBlock { start_offset, len } => {
-            let mut ys: Vec<U> = Vec::with_capacity(len);
-            let ys_to_set = ys.spare_capacity_mut();
-            let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
-            f_vec(&vs[start_offset..start_offset + len], ys_to_set);
-            // SAFETY: values are all set by f_vec.
-            unsafe { ys.set_len(len) };
-            ys
-        }
-        crate::StridedBlocks::MultipleBlocks {
-            block_start_index,
-            block_len,
-        } => {
-            let el_count = layout.shape().elem_count();
-            // Specialize the case where block_len is one to avoid the second loop.
-            if block_len == 1 {
-                let mut result = Vec::with_capacity(el_count);
-                for index in block_start_index {
-                    let v = unsafe { vs.get_unchecked(index) };
-                    result.push(f(*v))
-                }
-                result
-            } else {
-                let mut ys: Vec<U> = Vec::with_capacity(el_count);
-                let ys_to_set = ys.spare_capacity_mut();
-                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
-                let mut dst_index = 0;
-                for src_index in block_start_index {
-                    let vs = &vs[src_index..src_index + block_len];
-                    let ys = &mut ys_to_set[dst_index..dst_index + block_len];
-                    f_vec(vs, ys);
-                    dst_index += block_len;
-                }
-                // SAFETY: values are all set by f_vec.
-                unsafe { ys.set_len(el_count) };
-                ys
-            }
-        }
-    }
-}
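The deleted helper module above is built around one pattern: a trait exposing a generic `f` over element types plus a `map` that matches on the storage enum once and dispatches. A minimal self-contained sketch of that enum-dispatch idea, using a toy `Storage` with two variants instead of candle's seven:

    // Sketch: generic-kernel-over-enum-storage dispatch, as in the Map1 trait.
    enum Storage {
        F32(Vec<f32>),
        F64(Vec<f64>),
    }

    trait Map1 {
        fn f<T: Copy + std::ops::Add<Output = T>>(&self, vs: &[T]) -> Vec<T>;

        // One match per op keeps the per-dtype kernels monomorphized.
        fn map(&self, vs: &Storage) -> Storage {
            match vs {
                Storage::F32(vs) => Storage::F32(self.f(vs)),
                Storage::F64(vs) => Storage::F64(self.f(vs)),
            }
        }
    }

    struct Double;
    impl Map1 for Double {
        fn f<T: Copy + std::ops::Add<Output = T>>(&self, vs: &[T]) -> Vec<T> {
            vs.iter().map(|&v| v + v).collect()
        }
    }

    fn main() {
        let s = Storage::F32(vec![1.0, 2.0]);
        if let Storage::F32(out) = Double.map(&s) {
            assert_eq!(out, vec![2.0, 4.0]);
        }
    }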
File diff suppressed because it is too large
@ -1,452 +0,0 @@
|
|||||||
use crate::backend::BackendDevice;
|
|
||||||
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
|
|
||||||
pub use candle_kernels as kernels;
|
|
||||||
pub use cudarc;
|
|
||||||
use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig};
|
|
||||||
use half::{bf16, f16};
|
|
||||||
use std::sync::{Arc, Mutex};
|
|
||||||
|
|
||||||
use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr};
|
|
||||||
|
|
||||||
/// Unique identifier for cuda devices.
|
|
||||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
|
||||||
pub struct DeviceId(usize);
|
|
||||||
|
|
||||||
impl DeviceId {
|
|
||||||
fn new() -> Self {
|
|
||||||
// https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
|
|
||||||
use std::sync::atomic;
|
|
||||||
static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
|
|
||||||
Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct CudaRng(cudarc::curand::CudaRng);
|
|
||||||
unsafe impl Send for CudaRng {}
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct CudaDevice {
|
|
||||||
id: DeviceId,
|
|
||||||
device: Arc<cudarc::driver::CudaDevice>,
|
|
||||||
pub(crate) blas: Arc<cudarc::cublas::CudaBlas>,
|
|
||||||
curand: Arc<Mutex<CudaRng>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::fmt::Debug for CudaDevice {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
write!(f, "CudaDevice({:?})", self.id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::ops::Deref for CudaDevice {
|
|
||||||
type Target = Arc<cudarc::driver::CudaDevice>;
|
|
||||||
|
|
||||||
fn deref(&self) -> &Self::Target {
|
|
||||||
&self.device
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl CudaDevice {
|
|
||||||
pub fn cuda_device(&self) -> Arc<cudarc::driver::CudaDevice> {
|
|
||||||
self.device.clone()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn id(&self) -> DeviceId {
|
|
||||||
self.id
|
|
||||||
}
|
|
||||||
|
|
||||||
fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
|
||||||
let elem_count = shape.elem_count();
|
|
||||||
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
|
||||||
let slice = match dtype {
|
|
||||||
DType::U8 => {
|
|
||||||
// SAFETY: Set later by running the fill kernel.
|
|
||||||
let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
|
|
||||||
let func = self.get_or_load_func("fill_u8", kernels::FILL)?;
|
|
||||||
let params = (&data, v as u8, elem_count);
|
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
|
||||||
CudaStorageSlice::U8(data)
|
|
||||||
}
|
|
||||||
DType::U32 => {
|
|
||||||
// SAFETY: Set later by running the fill kernel.
|
|
||||||
let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
|
|
||||||
let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
|
|
||||||
let params = (&data, v as u32, elem_count);
|
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
|
||||||
CudaStorageSlice::U32(data)
|
|
||||||
}
|
|
||||||
DType::I64 => {
|
|
||||||
// SAFETY: Set later by running the fill kernel.
|
|
||||||
let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
|
|
||||||
let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
|
|
||||||
let params = (&data, v as i64, elem_count);
|
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
|
||||||
CudaStorageSlice::I64(data)
|
|
||||||
}
|
|
||||||
DType::BF16 => {
|
|
||||||
// SAFETY: Set later by running the fill kernel.
|
|
||||||
let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
|
|
||||||
let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
|
|
||||||
let params = (&data, bf16::from_f64(v), elem_count);
|
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
|
||||||
CudaStorageSlice::BF16(data)
|
|
||||||
}
|
|
||||||
DType::F16 => {
|
|
||||||
// SAFETY: Set later by running the fill kernel.
|
|
||||||
let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
|
|
||||||
let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
|
|
||||||
let params = (&data, f16::from_f64(v), elem_count);
|
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
|
||||||
CudaStorageSlice::F16(data)
|
|
||||||
}
|
|
||||||
DType::F32 => {
|
|
||||||
// SAFETY: Set later by running the fill kernel.
|
|
||||||
let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
|
|
||||||
let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
|
|
||||||
let params = (&data, v as f32, elem_count);
|
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
|
||||||
CudaStorageSlice::F32(data)
|
|
||||||
}
|
|
||||||
DType::F64 => {
|
|
||||||
// SAFETY: Set later by running the fill kernel.
|
|
||||||
let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
|
|
||||||
let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
|
|
||||||
let params = (&data, v, elem_count);
|
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
|
||||||
CudaStorageSlice::F64(data)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
Ok(CudaStorage {
|
|
||||||
slice,
|
|
||||||
device: self.clone(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
|
|
||||||
if !self.has_func(module_name, module_name) {
|
|
||||||
// Leaking the string here is a bit sad but we need a &'static str and this is only
|
|
||||||
// done once per kernel name.
|
|
||||||
let static_module_name = Box::leak(module_name.to_string().into_boxed_str());
|
|
||||||
self.load_ptx(ptx.into(), module_name, &[static_module_name])
|
|
||||||
.map_err(|cuda| CudaError::Load {
|
|
||||||
cuda,
|
|
||||||
module_name: module_name.to_string(),
|
|
||||||
})
|
|
||||||
.w()?;
|
|
||||||
}
|
|
||||||
self.get_func(module_name, module_name)
|
|
||||||
// Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
|
|
||||||
// able to only build the error value if needed.
|
|
||||||
.ok_or(CudaError::MissingKernel {
|
|
||||||
module_name: module_name.to_string(),
|
|
||||||
})
|
|
||||||
.w()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl BackendDevice for CudaDevice {
|
|
||||||
type Storage = CudaStorage;
|
|
||||||
|
|
||||||
fn new(ordinal: usize) -> Result<Self> {
|
|
||||||
let device = cudarc::driver::CudaDevice::new(ordinal).w()?;
|
|
||||||
let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
|
|
||||||
let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
|
|
||||||
Ok(Self {
|
|
||||||
id: DeviceId::new(),
|
|
||||||
device,
|
|
||||||
blas: Arc::new(blas),
|
|
||||||
curand: Arc::new(Mutex::new(CudaRng(curand))),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn set_seed(&self, seed: u64) -> Result<()> {
|
|
||||||
// We do not call set_seed but instead create a new curand object. This ensures that the
|
|
||||||
// state will be identical and the same random numbers will be generated.
|
|
||||||
let mut curand = self.curand.lock().unwrap();
|
|
||||||
curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||

    fn location(&self) -> crate::DeviceLocation {
        crate::DeviceLocation::Cuda {
            gpu_id: self.device.ordinal(),
        }
    }

    fn same_device(&self, rhs: &Self) -> bool {
        self.id == rhs.id
    }

    fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
        let elem_count = shape.elem_count();
        let slice = match dtype {
            DType::U8 => {
                let data = self.alloc_zeros::<u8>(elem_count).w()?;
                CudaStorageSlice::U8(data)
            }
            DType::U32 => {
                let data = self.alloc_zeros::<u32>(elem_count).w()?;
                CudaStorageSlice::U32(data)
            }
            DType::I64 => {
                let data = self.alloc_zeros::<i64>(elem_count).w()?;
                CudaStorageSlice::I64(data)
            }
            DType::BF16 => {
                let data = self.alloc_zeros::<bf16>(elem_count).w()?;
                CudaStorageSlice::BF16(data)
            }
            DType::F16 => {
                let data = self.alloc_zeros::<f16>(elem_count).w()?;
                CudaStorageSlice::F16(data)
            }
            DType::F32 => {
                let data = self.alloc_zeros::<f32>(elem_count).w()?;
                CudaStorageSlice::F32(data)
            }
            DType::F64 => {
                let data = self.alloc_zeros::<f64>(elem_count).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }

    fn rand_uniform(&self, shape: &Shape, dtype: DType, lo: f64, up: f64) -> Result<CudaStorage> {
        let elem_count = shape.elem_count();
        let curand = self.curand.lock().unwrap();
        let slice = match dtype {
            // TODO: Add support for F16 and BF16 though this is likely to require some upstream
            // cudarc changes.
            DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
                Err(CudaError::UnsupportedDtype {
                    dtype,
                    op: "rand_uniform",
                })
                .w()?
            }
            DType::F32 => {
                let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
                curand.0.fill_with_uniform(&mut data).w()?;
                CudaStorageSlice::F32(data)
            }
            DType::F64 => {
                let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
                curand.0.fill_with_uniform(&mut data).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        let slice = if lo == 0. && up == 1.0 {
            slice
        } else {
            use super::utils::Map1;
            let layout = Layout::contiguous(shape);
            super::Affine(up - lo, lo).map(&slice, self, &layout)?
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }
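    // The `Affine(up - lo, lo)` rescaling above is the usual affine map from unit-uniform
    // samples onto [lo, up): x * (up - lo) + lo. A self-contained sketch of that step:
    //
    //     fn rescale_uniform(samples: &[f64], lo: f64, up: f64) -> Vec<f64> {
    //         samples.iter().map(|x| x * (up - lo) + lo).collect()
    //     }
    //
    //     fn main() {
    //         let unit = [0.0, 0.25, 0.5, 1.0];
    //         assert_eq!(rescale_uniform(&unit, -1.0, 1.0), vec![-1.0, -0.5, 0.0, 1.0]);
    //     }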

    fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CudaStorage> {
        // TODO: Add support for F16 and BF16 though this is likely to require some upstream
        // cudarc changes.
        let elem_count = shape.elem_count();
        let curand = self.curand.lock().unwrap();
        // curand can only generate an even number of values.
        // https://github.com/huggingface/candle/issues/734
        let elem_count_round = if elem_count % 2 == 1 {
            elem_count + 1
        } else {
            elem_count
        };
        let slice = match dtype {
            DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
                Err(CudaError::UnsupportedDtype {
                    dtype,
                    op: "rand_normal",
                })
                .w()?
            }
            DType::F32 => {
                let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
                curand
                    .0
                    .fill_with_normal(&mut data, mean as f32, std as f32)
                    .w()?;
                CudaStorageSlice::F32(data)
            }
            DType::F64 => {
                let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
                curand.0.fill_with_normal(&mut data, mean, std).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }
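    // The rounding above only ever grows an odd request by one; the extra sample sits at the
    // end of the allocation and is ignored by the tensor's layout. A minimal sketch:
    //
    //     fn round_up_to_even(elem_count: usize) -> usize {
    //         if elem_count % 2 == 1 { elem_count + 1 } else { elem_count }
    //     }
    //
    //     fn main() {
    //         assert_eq!(round_up_to_even(7), 8);
    //         assert_eq!(round_up_to_even(8), 8);
    //     }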

    fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
        self.const_impl(1., shape, dtype)
    }

    unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
        let elem_count = shape.elem_count();
        let slice = match dtype {
            DType::U8 => {
                let data = self.alloc::<u8>(elem_count).w()?;
                CudaStorageSlice::U8(data)
            }
            DType::U32 => {
                let data = self.alloc::<u32>(elem_count).w()?;
                CudaStorageSlice::U32(data)
            }
            DType::I64 => {
                let data = self.alloc::<i64>(elem_count).w()?;
                CudaStorageSlice::I64(data)
            }
            DType::BF16 => {
                let data = self.alloc::<bf16>(elem_count).w()?;
                CudaStorageSlice::BF16(data)
            }
            DType::F16 => {
                let data = self.alloc::<f16>(elem_count).w()?;
                CudaStorageSlice::F16(data)
            }
            DType::F32 => {
                let data = self.alloc::<f32>(elem_count).w()?;
                CudaStorageSlice::F32(data)
            }
            DType::F64 => {
                let data = self.alloc::<f64>(elem_count).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }

    fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
        let slice = match T::cpu_storage_ref(s) {
            CpuStorageRef::U8(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::U8(data)
            }
            CpuStorageRef::U32(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::U32(data)
            }
            CpuStorageRef::I64(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::I64(data)
            }
            CpuStorageRef::BF16(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::BF16(data)
            }
            CpuStorageRef::F16(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::F16(data)
            }
            CpuStorageRef::F32(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::F32(data)
            }
            CpuStorageRef::F64(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }

    fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
        let slice = match storage {
            CpuStorage::U8(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::U8(data)
            }
            CpuStorage::U32(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::U32(data)
            }
            CpuStorage::I64(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::I64(data)
            }
            CpuStorage::BF16(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::BF16(data)
            }
            CpuStorage::F16(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::F16(data)
            }
            CpuStorage::F32(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::F32(data)
            }
            CpuStorage::F64(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }

    fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> {
        let slice = match storage {
            CpuStorage::U8(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::U8(data)
            }
            CpuStorage::U32(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::U32(data)
            }
            CpuStorage::I64(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::I64(data)
            }
            CpuStorage::BF16(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::BF16(data)
            }
            CpuStorage::F16(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::F16(data)
            }
            CpuStorage::F32(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::F32(data)
            }
            CpuStorage::F64(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }

    fn synchronize(&self) -> Result<()> {
        self.device.synchronize().map_err(crate::Error::wrap)?;
        Ok(())
    }
}
@ -1,62 +0,0 @@
use crate::{DType, Layout};

/// cudarc related errors
#[derive(thiserror::Error, Debug)]
pub enum CudaError {
    #[error(transparent)]
    Cuda(#[from] cudarc::driver::DriverError),

    #[error(transparent)]
    Compiler(#[from] cudarc::nvrtc::CompileError),

    #[error(transparent)]
    Cublas(#[from] cudarc::cublas::result::CublasError),

    #[error(transparent)]
    Curand(#[from] cudarc::curand::result::CurandError),

    #[error("missing kernel '{module_name}'")]
    MissingKernel { module_name: String },

    #[error("unsupported dtype {dtype:?} for {op}")]
    UnsupportedDtype { dtype: DType, op: &'static str },

    #[error("internal error '{0}'")]
    InternalError(&'static str),

    #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
    MatMulNonContiguous {
        lhs_stride: Layout,
        rhs_stride: Layout,
        mnk: (usize, usize, usize),
    },

    #[error("{msg}, expected: {expected:?}, got: {got:?}")]
    UnexpectedDType {
        msg: &'static str,
        expected: DType,
        got: DType,
    },

    #[error("{cuda} when loading {module_name}")]
    Load {
        cuda: cudarc::driver::DriverError,
        module_name: String,
    },
}

impl From<CudaError> for crate::Error {
    fn from(val: CudaError) -> Self {
        crate::Error::Cuda(Box::new(val)).bt()
    }
}

pub trait WrapErr<O> {
    fn w(self) -> std::result::Result<O, crate::Error>;
}

impl<O, E: Into<CudaError>> WrapErr<O> for std::result::Result<O, E> {
    fn w(self) -> std::result::Result<O, crate::Error> {
        self.map_err(|e| crate::Error::Cuda(Box::new(e.into())).bt())
    }
}
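The `WrapErr` blanket impl above is what makes the pervasive `.w()` calls in this backend possible: any low-level result convertible into `CudaError` is adapted into the crate-wide error (with backtrace) at the use site. A stand-alone sketch of the same adapter pattern with stand-in error types, not candle's actual ones:

#[derive(Debug)]
struct LowLevelError(&'static str);

#[derive(Debug)]
struct CrateError(String);

trait WrapErr<O> {
    fn w(self) -> Result<O, CrateError>;
}

// Blanket impl: every Result with a Debug-printable error gains `.w()`.
impl<O, E: std::fmt::Debug> WrapErr<O> for Result<O, E> {
    fn w(self) -> Result<O, CrateError> {
        self.map_err(|e| CrateError(format!("{e:?}")))
    }
}

fn fallible() -> Result<u32, LowLevelError> {
    Err(LowLevelError("device lost"))
}

fn main() {
    let e = fallible().w().unwrap_err();
    println!("{e:?}");
}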
@ -1,134 +0,0 @@
/// Helper functions to plug cuda kernels in candle.
use crate::{Layout, Result, Shape, WithDType};
pub use cudarc;
use cudarc::driver::{CudaSlice, DeviceRepr, ValidAsZeroBits};

use super::{CudaDevice, CudaError, WrapErr};

pub type S = super::CudaStorageSlice;

pub trait Map1 {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        src: &CudaSlice<T>,
        dev: &CudaDevice,
        layout: &Layout,
    ) -> Result<CudaSlice<T>>;

    fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
        let out = match s {
            S::U8(s) => S::U8(self.f(s, d, l)?),
            S::U32(s) => S::U32(self.f(s, d, l)?),
            S::I64(s) => S::I64(self.f(s, d, l)?),
            S::BF16(s) => S::BF16(self.f(s, d, l)?),
            S::F16(s) => S::F16(self.f(s, d, l)?),
            S::F32(s) => S::F32(self.f(s, d, l)?),
            S::F64(s) => S::F64(self.f(s, d, l)?),
        };
        Ok(out)
    }
}
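The `Map1`/`map` pair is a dtype dispatcher: `map` matches on the storage variant, runs the generic `f` once, and re-wraps the result in the same variant. A self-contained analogy with stand-in CPU storage (candle's real `f` launches a CUDA kernel over a `CudaSlice<T>`):

enum Storage {
    F32(Vec<f32>),
    F64(Vec<f64>),
}

trait Map1 {
    fn f<T: Copy + std::ops::Add<Output = T>>(&self, src: &[T]) -> Vec<T>;

    // One match over dtypes, one generic worker: new ops only implement `f`.
    fn map(&self, s: &Storage) -> Storage {
        match s {
            Storage::F32(s) => Storage::F32(self.f(s)),
            Storage::F64(s) => Storage::F64(self.f(s)),
        }
    }
}

struct Double;
impl Map1 for Double {
    fn f<T: Copy + std::ops::Add<Output = T>>(&self, src: &[T]) -> Vec<T> {
        src.iter().map(|&x| x + x).collect()
    }
}

fn main() {
    if let Storage::F32(v) = Double.map(&Storage::F32(vec![1.0, 2.0])) {
        assert_eq!(v, vec![2.0, 4.0]);
    }
}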

pub trait Map2 {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        src1: &CudaSlice<T>,
        layout1: &Layout,
        src2: &CudaSlice<T>,
        layout2: &Layout,
        dev: &CudaDevice,
    ) -> Result<CudaSlice<T>>;

    fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
        let out = match (s1, s2) {
            (S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?),
            (S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
            (S::I64(s1), S::I64(s2)) => S::I64(self.f(s1, l1, s2, l2, d)?),
            (S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
            (S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
            (S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
            (S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?),
            _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
        };
        Ok(out)
    }
}

pub trait Map2InPlace {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        dst: &mut CudaSlice<T>,
        dst_shape: &Shape,
        src: &CudaSlice<T>,
        src_l: &Layout,
        dev: &CudaDevice,
    ) -> Result<()>;

    fn map(
        &self,
        dst: &mut S,
        dst_s: &Shape,
        src: &S,
        src_l: &Layout,
        d: &CudaDevice,
    ) -> Result<()> {
        match (dst, src) {
            (S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d),
            _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
        }
    }
}

pub trait Map1Any {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
        &self,
        src: &CudaSlice<T>,
        dev: &CudaDevice,
        layout: &Layout,
        wrap: W,
    ) -> Result<S>;

    fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
        let out = match s {
            S::U8(s) => self.f(s, d, l, S::U8)?,
            S::U32(s) => self.f(s, d, l, S::U32)?,
            S::I64(s) => self.f(s, d, l, S::I64)?,
            S::BF16(s) => self.f(s, d, l, S::BF16)?,
            S::F16(s) => self.f(s, d, l, S::F16)?,
            S::F32(s) => self.f(s, d, l, S::F32)?,
            S::F64(s) => self.f(s, d, l, S::F64)?,
        };
        Ok(out)
    }
}

pub trait Map2Any {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        src1: &CudaSlice<T>,
        layout1: &Layout,
        src2: &CudaSlice<T>,
        layout2: &Layout,
        dev: &CudaDevice,
    ) -> Result<S>;

    fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
        let out = match (s1, s2) {
            (S::U8(s1), S::U8(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::U32(s1), S::U32(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::I64(s1), S::I64(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::BF16(s1), S::BF16(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?,
            _ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?,
        };
        Ok(out)
    }
}
@ -1,377 +0,0 @@
use crate::op::{BackpropOp, Op};
use crate::tensor::from_storage;
use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
use std::sync::Arc;

/// Unary ops that can be defined in user-land.
pub trait CustomOp1 {
    // Box<dyn> does not support const yet, so use a function to get the name.
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>;

    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn metal_fwd(
        &self,
        _storage: &MetalStorage,
        _layout: &Layout,
    ) -> Result<(MetalStorage, Shape)> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }

    /// This function takes as argument the argument `arg` used in the forward pass, the result
    /// produced by the forward operation `res` and the gradient of the result `grad_res`.
    /// The function should return the gradient of the argument.
    fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Option<Tensor>> {
        Err(crate::Error::BackwardNotSupported { op: self.name() })
    }
}
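A usage sketch for `CustomOp1` through the public `candle_core` re-exports: an f32-only, contiguous-only square-root op applied with `apply_op1` (defined further down). `Sqrt` and its restrictions are illustrative, not part of the crate:

use candle_core::{bail, CpuStorage, CustomOp1, Device, Layout, Result, Shape, Tensor};

struct Sqrt;

impl CustomOp1 for Sqrt {
    fn name(&self) -> &'static str {
        "sqrt"
    }

    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        let slice = storage.as_slice::<f32>()?;
        // Respect the layout: this sketch only handles the contiguous case.
        let src = match layout.contiguous_offsets() {
            Some((o1, o2)) => &slice[o1..o2],
            None => bail!("input must be contiguous"),
        };
        let dst: Vec<f32> = src.iter().map(|v| v.sqrt()).collect();
        Ok((CpuStorage::F32(dst), layout.shape().clone()))
    }
}

fn main() -> Result<()> {
    let t = Tensor::new(&[1f32, 4., 9.], &Device::Cpu)?;
    let r = t.apply_op1(Sqrt)?;
    println!("{r}");
    Ok(())
}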

pub trait CustomOp2 {
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
    ) -> Result<(CpuStorage, Shape)>;

    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cuda_fwd(
        &self,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
    ) -> Result<(CudaStorage, Shape)> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn metal_fwd(
        &self,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
    ) -> Result<(MetalStorage, Shape)> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }

    fn bwd(
        &self,
        _arg1: &Tensor,
        _arg2: &Tensor,
        _res: &Tensor,
        _grad_res: &Tensor,
    ) -> Result<(Option<Tensor>, Option<Tensor>)> {
        Err(crate::Error::BackwardNotSupported { op: self.name() })
    }
}

pub trait CustomOp3 {
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
        s3: &CpuStorage,
        l3: &Layout,
    ) -> Result<(CpuStorage, Shape)>;

    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cuda_fwd(
        &self,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
    ) -> Result<(CudaStorage, Shape)> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn metal_fwd(
        &self,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
    ) -> Result<(MetalStorage, Shape)> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }

    fn bwd(
        &self,
        _arg1: &Tensor,
        _arg2: &Tensor,
        _arg3: &Tensor,
        _res: &Tensor,
        _grad_res: &Tensor,
    ) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
        Err(crate::Error::BackwardNotSupported { op: self.name() })
    }
}

impl Tensor {
    /// Applies a unary custom op without backward support
    pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self> {
        let (storage, shape) = self.storage().apply_op1(self.layout(), c)?;
        Ok(from_storage(storage, shape, BackpropOp::none(), false))
    }

    /// Applies a binary custom op without backward support
    pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self> {
        let (storage, shape) =
            self.storage()
                .apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?;
        Ok(from_storage(storage, shape, BackpropOp::none(), false))
    }

    /// Applies a ternary custom op without backward support
    pub fn apply_op3_no_bwd<C: CustomOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<Self> {
        let (storage, shape) = self.storage().apply_op3(
            self.layout(),
            &t2.storage(),
            t2.layout(),
            &t3.storage(),
            t3.layout(),
            c,
        )?;
        Ok(from_storage(storage, shape, BackpropOp::none(), false))
    }

    /// Applies a unary custom op.
    pub fn apply_op1_arc(&self, c: Arc<Box<dyn CustomOp1 + Send + Sync>>) -> Result<Self> {
        let (storage, shape) = self
            .storage()
            .apply_op1(self.layout(), c.as_ref().as_ref())?;
        let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
        Ok(from_storage(storage, shape, op, false))
    }

    pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>(&self, c: C) -> Result<Self> {
        self.apply_op1_arc(Arc::new(Box::new(c)))
    }

    /// Applies a binary custom op.
    pub fn apply_op2_arc(
        &self,
        rhs: &Self,
        c: Arc<Box<dyn CustomOp2 + Send + Sync>>,
    ) -> Result<Self> {
        let (storage, shape) = self.storage().apply_op2(
            self.layout(),
            &rhs.storage(),
            rhs.layout(),
            c.as_ref().as_ref(),
        )?;
        let op = BackpropOp::new2(self, rhs, |t1, t2| Op::CustomOp2(t1, t2, c.clone()));
        Ok(from_storage(storage, shape, op, false))
    }

    pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>(&self, r: &Self, c: C) -> Result<Self> {
        self.apply_op2_arc(r, Arc::new(Box::new(c)))
    }

    /// Applies a ternary custom op.
    pub fn apply_op3_arc(
        &self,
        t2: &Self,
        t3: &Self,
        c: Arc<Box<dyn CustomOp3 + Send + Sync>>,
    ) -> Result<Self> {
        let (storage, shape) = self.storage().apply_op3(
            self.layout(),
            &t2.storage(),
            t2.layout(),
            &t3.storage(),
            t3.layout(),
            c.as_ref().as_ref(),
        )?;
        let op = BackpropOp::new3(self, t2, t3, |t1, t2, t3| {
            Op::CustomOp3(t1, t2, t3, c.clone())
        });
        Ok(from_storage(storage, shape, op, false))
    }

    pub fn apply_op3<C: 'static + CustomOp3 + Send + Sync>(
        &self,
        t2: &Self,
        t3: &Self,
        c: C,
    ) -> Result<Self> {
        self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
    }
}

// In place ops.

/// Unary ops that can be defined in user-land.
/// These ops work in place and as such back-prop is unsupported.
pub trait InplaceOp1 {
    // Box<dyn> does not support const yet, so use a function to get the name.
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cpu_fwd(&self, storage: &mut CpuStorage, layout: &Layout) -> Result<()>;

    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cuda_fwd(&self, _storage: &mut CudaStorage, _layout: &Layout) -> Result<()> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn metal_fwd(&self, _storage: &mut MetalStorage, _layout: &Layout) -> Result<()> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }
}
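A matching sketch for the in-place variant: the op mutates the storage directly, so no new tensor is produced and backprop is unavailable. `ZeroOut` is illustrative only:

use candle_core::{bail, CpuStorage, Device, InplaceOp1, Layout, Result, Tensor};

struct ZeroOut;

impl InplaceOp1 for ZeroOut {
    fn name(&self) -> &'static str {
        "zero-out"
    }

    fn cpu_fwd(&self, storage: &mut CpuStorage, _layout: &Layout) -> Result<()> {
        match storage {
            // Overwrite the backing buffer in place.
            CpuStorage::F32(vs) => vs.iter_mut().for_each(|v| *v = 0.0),
            _ => bail!("zero-out is only sketched for f32"),
        }
        Ok(())
    }
}

fn main() -> Result<()> {
    let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    t.inplace_op1(&ZeroOut)?;
    assert_eq!(t.to_vec1::<f32>()?, vec![0., 0., 0.]);
    Ok(())
}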

pub trait InplaceOp2 {
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cpu_fwd(&self, s1: &mut CpuStorage, l1: &Layout, s2: &CpuStorage, l2: &Layout)
        -> Result<()>;

    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cuda_fwd(&self, _: &mut CudaStorage, _: &Layout, _: &CudaStorage, _: &Layout) -> Result<()> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn metal_fwd(
        &self,
        _: &mut MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
    ) -> Result<()> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }
}

pub trait InplaceOp3 {
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cpu_fwd(
        &self,
        s1: &mut CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
        s3: &CpuStorage,
        l3: &Layout,
    ) -> Result<()>;

    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cuda_fwd(
        &self,
        _: &mut CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
    ) -> Result<()> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn metal_fwd(
        &self,
        _: &mut MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
    ) -> Result<()> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }
}

impl Tensor {
    /// Applies a unary custom op in place.
    pub fn inplace_op1<C: InplaceOp1>(&self, c: &C) -> Result<()> {
        self.storage_mut().inplace_op1(self.layout(), c)
    }

    /// Applies a binary custom op in place (for the first tensor).
    pub fn inplace_op2<C: InplaceOp2>(&self, rhs: &Self, c: &C) -> Result<()> {
        self.storage_mut()
            .inplace_op2(self.layout(), &rhs.storage(), rhs.layout(), c)
    }

    /// Applies a ternary custom op in place (for the first tensor).
    pub fn inplace_op3<C: InplaceOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<()> {
        self.storage_mut().inplace_op3(
            self.layout(),
            &t2.storage(),
            t2.layout(),
            &t3.storage(),
            t3.layout(),
            c,
        )
    }
}
@ -8,14 +8,12 @@ use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
 pub enum DeviceLocation {
     Cpu,
     Cuda { gpu_id: usize },
-    Metal { gpu_id: usize },
 }
 
 #[derive(Debug, Clone)]
 pub enum Device {
     Cpu,
     Cuda(crate::CudaDevice),
-    Metal(crate::MetalDevice),
 }
 
 pub trait NdArray {
@ -130,23 +128,10 @@ impl Device {
         Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
     }
 
-    pub fn new_metal(ordinal: usize) -> Result<Self> {
-        Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
-    }
-
-    pub fn set_seed(&self, seed: u64) -> Result<()> {
-        match self {
-            Self::Cpu => CpuDevice.set_seed(seed),
-            Self::Cuda(c) => c.set_seed(seed),
-            Self::Metal(m) => m.set_seed(seed),
-        }
-    }
-
     pub fn same_device(&self, rhs: &Self) -> bool {
         match (self, rhs) {
             (Self::Cpu, Self::Cpu) => true,
             (Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
-            (Self::Metal(lhs), Self::Metal(rhs)) => lhs.same_device(rhs),
             _ => false,
         }
     }
@ -155,20 +140,21 @@ impl Device {
         match self {
             Self::Cpu => DeviceLocation::Cpu,
             Self::Cuda(device) => device.location(),
-            Device::Metal(device) => device.location(),
         }
     }
 
     pub fn is_cpu(&self) -> bool {
-        matches!(self, Self::Cpu)
+        match self {
+            Self::Cpu => true,
+            Self::Cuda(_) => false,
+        }
     }
 
     pub fn is_cuda(&self) -> bool {
-        matches!(self, Self::Cuda(_))
-    }
-
-    pub fn is_metal(&self) -> bool {
-        matches!(self, Self::Metal(_))
+        match self {
+            Self::Cpu => false,
+            Self::Cuda(_) => true,
+        }
     }
 
     pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
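A usage sketch for the helpers in this hunk: pick a GPU when one is available and fall back to the CPU otherwise.

use candle_core::{Device, Result};

fn main() -> Result<()> {
    // Prefer cuda:0 when compiled with CUDA and a GPU is present, else CPU.
    let device = Device::cuda_if_available(0)?;
    if device.is_cuda() {
        println!("running on cuda:0");
    } else {
        println!("running on cpu");
    }
    Ok(())
}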
@ -192,18 +178,8 @@ impl Device {
                 Ok(Storage::Cpu(storage))
             }
             Device::Cuda(device) => {
-                // TODO: Remove the special case if we start supporting generating f16/bf16 directly.
-                if dtype == DType::F16 || dtype == DType::BF16 {
-                    let storage = device.rand_uniform(shape, DType::F32, lo, up)?;
-                    Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
-                } else {
-                    let storage = device.rand_uniform(shape, dtype, lo, up)?;
-                    Ok(Storage::Cuda(storage))
-                }
-            }
-            Device::Metal(device) => {
                 let storage = device.rand_uniform(shape, dtype, lo, up)?;
-                Ok(Storage::Metal(storage))
+                Ok(Storage::Cuda(storage))
             }
         }
     }
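The special case removed here can be reproduced from user code: sample in f32 (what curand supports), then cast, since curand cannot fill f16/bf16 buffers directly. A sketch via the public API:

use candle_core::{DType, Device, Result, Tensor};

fn rand_f16(device: &Device) -> Result<Tensor> {
    // Draw uniform f32 samples, then cast down to f16.
    let t = Tensor::rand(0f32, 1f32, (2, 3), device)?;
    t.to_dtype(DType::F16)
}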
@ -230,18 +206,8 @@ impl Device {
                 Ok(Storage::Cpu(storage))
             }
             Device::Cuda(device) => {
-                // TODO: Remove the special case if we start supporting generating f16/bf16 directly.
-                if dtype == DType::F16 || dtype == DType::BF16 {
-                    let storage = device.rand_normal(shape, DType::F32, mean, std)?;
-                    Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
-                } else {
-                    let storage = device.rand_normal(shape, dtype, mean, std)?;
-                    Ok(Storage::Cuda(storage))
-                }
-            }
-            Device::Metal(device) => {
                 let storage = device.rand_normal(shape, dtype, mean, std)?;
-                Ok(Storage::Metal(storage))
+                Ok(Storage::Cuda(storage))
             }
         }
     }
@ -265,10 +231,6 @@ impl Device {
                 let storage = device.ones_impl(shape, dtype)?;
                 Ok(Storage::Cuda(storage))
             }
-            Device::Metal(device) => {
-                let storage = device.ones_impl(shape, dtype)?;
-                Ok(Storage::Metal(storage))
-            }
         }
     }
 
@ -282,41 +244,6 @@ impl Device {
                 let storage = device.zeros_impl(shape, dtype)?;
                 Ok(Storage::Cuda(storage))
             }
-            Device::Metal(device) => {
-                let storage = device.zeros_impl(shape, dtype)?;
-                Ok(Storage::Metal(storage))
-            }
-        }
-    }
-
-    pub(crate) unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
-        match self {
-            Device::Cpu => {
-                let storage = CpuDevice.alloc_uninit(shape, dtype)?;
-                Ok(Storage::Cpu(storage))
-            }
-            Device::Cuda(device) => {
-                let storage = device.alloc_uninit(shape, dtype)?;
-                Ok(Storage::Cuda(storage))
-            }
-            Device::Metal(device) => {
-                let storage = device.alloc_uninit(shape, dtype)?;
-                Ok(Storage::Metal(storage))
-            }
-        }
-    }
-
-    pub(crate) fn storage_from_slice<D: WithDType>(&self, data: &[D]) -> Result<Storage> {
-        match self {
-            Device::Cpu => Ok(Storage::Cpu(data.to_cpu_storage())),
-            Device::Cuda(device) => {
-                let storage = device.storage_from_slice(data)?;
-                Ok(Storage::Cuda(storage))
-            }
-            Device::Metal(device) => {
-                let storage = device.storage_from_slice(data)?;
-                Ok(Storage::Metal(storage))
-            }
         }
     }
 
@ -325,14 +252,9 @@ impl Device {
             Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
             Device::Cuda(device) => {
                 let storage = array.to_cpu_storage();
-                let storage = device.storage_from_cpu_storage_owned(storage)?;
+                let storage = device.storage_from_cpu_storage(&storage)?;
                 Ok(Storage::Cuda(storage))
             }
-            Device::Metal(device) => {
-                let storage = array.to_cpu_storage();
-                let storage = device.storage_from_cpu_storage_owned(storage)?;
-                Ok(Storage::Metal(storage))
-            }
         }
     }
 
@ -341,22 +263,9 @@ impl Device {
             Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))),
             Device::Cuda(device) => {
                 let storage = S::to_cpu_storage_owned(data);
-                let storage = device.storage_from_cpu_storage_owned(storage)?;
+                let storage = device.storage_from_cpu_storage(&storage)?;
                 Ok(Storage::Cuda(storage))
             }
-            Device::Metal(device) => {
-                let storage = S::to_cpu_storage_owned(data);
-                let storage = device.storage_from_cpu_storage_owned(storage)?;
-                Ok(Storage::Metal(storage))
-            }
-        }
-    }
-
-    pub fn synchronize(&self) -> Result<()> {
-        match self {
-            Self::Cpu => Ok(()),
-            Self::Cuda(d) => d.synchronize(),
-            Self::Metal(d) => d.synchronize(),
         }
     }
 }
@ -14,9 +14,6 @@ impl Tensor {
             crate::DeviceLocation::Cuda { gpu_id } => {
                 format!(", cuda:{}", gpu_id)
             }
-            crate::DeviceLocation::Metal { gpu_id } => {
-                format!(", metal:{}", gpu_id)
-            }
         };
 
         write!(f, "Tensor[")?;
@ -65,13 +62,12 @@ impl std::fmt::Debug for Tensor {
 }
 
 /// Options for Tensor pretty printing
-#[derive(Debug, Clone)]
 pub struct PrinterOptions {
-    pub precision: usize,
-    pub threshold: usize,
-    pub edge_items: usize,
-    pub line_width: usize,
-    pub sci_mode: Option<bool>,
+    precision: usize,
+    threshold: usize,
+    edge_items: usize,
+    line_width: usize,
+    sci_mode: Option<bool>,
 }
 
 static PRINT_OPTS: std::sync::Mutex<PrinterOptions> =
@ -90,10 +86,6 @@ impl PrinterOptions {
     }
 }
 
-pub fn print_options() -> &'static std::sync::Mutex<PrinterOptions> {
-    &PRINT_OPTS
-}
-
 pub fn set_print_options(options: PrinterOptions) {
     *PRINT_OPTS.lock().unwrap() = options
 }
@ -122,26 +114,6 @@ pub fn set_print_options_full() {
     }
 }
 
-pub fn set_line_width(line_width: usize) {
-    PRINT_OPTS.lock().unwrap().line_width = line_width
-}
-
-pub fn set_precision(precision: usize) {
-    PRINT_OPTS.lock().unwrap().precision = precision
-}
-
-pub fn set_edge_items(edge_items: usize) {
-    PRINT_OPTS.lock().unwrap().edge_items = edge_items
-}
-
-pub fn set_threshold(threshold: usize) {
-    PRINT_OPTS.lock().unwrap().threshold = threshold
-}
-
-pub fn set_sci_mode(sci_mode: Option<bool>) {
-    PRINT_OPTS.lock().unwrap().sci_mode = sci_mode
-}
-
 struct FmtSize {
     current_size: usize,
 }
@ -504,9 +476,6 @@ impl std::fmt::Display for Tensor {
             crate::DeviceLocation::Cuda { gpu_id } => {
                 format!(", cuda:{}", gpu_id)
             }
-            crate::DeviceLocation::Metal { gpu_id } => {
-                format!(", metal:{}", gpu_id)
-            }
         };
 
         write!(
@ -1,7 +1,7 @@
 //! Types for elements that can be stored and manipulated using tensors.
 #![allow(clippy::redundant_closure_call)]
 use crate::backend::BackendStorage;
-use crate::{CpuStorage, CpuStorageRef, Error, Result};
+use crate::{CpuStorage, Error, Result};
 
 /// The different types of elements allowed in tensors.
 #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
@ -23,15 +23,7 @@ pub enum DType {
 }
 
 #[derive(Debug, PartialEq, Eq)]
-pub struct DTypeParseError(String);
+pub struct DTypeParseError;
 
-impl std::fmt::Display for DTypeParseError {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "cannot parse '{}' as a dtype", self.0)
-    }
-}
-
-impl std::error::Error for DTypeParseError {}
-
 impl std::str::FromStr for DType {
     type Err = DTypeParseError;
@ -44,7 +36,7 @@ impl std::str::FromStr for DType {
             "f16" => Ok(Self::F16),
             "f32" => Ok(Self::F32),
             "f64" => Ok(Self::F64),
-            _ => Err(DTypeParseError(s.to_string())),
+            _ => Err(DTypeParseError),
         }
     }
 }
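A round-trip sketch for the parser in this hunk ("f16" is one of the literals matched above, and `DType` derives `PartialEq`):

use candle_core::DType;

fn main() {
    let dt: DType = "f16".parse().unwrap();
    assert_eq!(dt, DType::F16);
    // Unknown names hit the fallthrough arm and error out.
    assert!("not-a-dtype".parse::<DType>().is_err());
}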
@ -100,14 +92,12 @@ pub trait WithDType:
     + 'static
     + Send
     + Sync
-    + std::any::Any
     + crate::cpu::kernels::VecOps
 {
     const DTYPE: DType;
 
     fn from_f64(v: f64) -> Self;
     fn to_f64(self) -> f64;
-    fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_>;
     fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
 
     fn to_cpu_storage(data: &[Self]) -> CpuStorage {
@ -131,10 +121,6 @@ macro_rules! with_dtype {
                 $to_f64(self)
             }
 
-            fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_> {
-                CpuStorageRef::$dtype(data)
-            }
-
             fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
                 CpuStorage::$dtype(data)
             }
@ -79,16 +79,6 @@ impl crate::backend::BackendStorage for CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    fn conv_transpose1d(
-        &self,
-        _: &Layout,
-        _: &Self,
-        _: &Layout,
-        _: &crate::conv::ParamsConvTranspose1D,
-    ) -> Result<Self> {
-        Err(Error::NotCompiledWithCudaSupport)
-    }
-
     fn conv2d(
         &self,
         _: &Layout,
@ -154,19 +144,6 @@ impl crate::backend::BackendStorage for CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    fn copy2d(
-        &self,
-        _: &mut Self,
-        _: usize,
-        _: usize,
-        _: usize,
-        _: usize,
-        _: usize,
-        _: usize,
-    ) -> Result<()> {
-        Err(Error::NotCompiledWithCudaSupport)
-    }
-
     fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
@ -210,22 +187,10 @@ impl crate::backend::BackendDevice for CudaDevice {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
-        Err(Error::NotCompiledWithCudaSupport)
-    }
-
-    fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage> {
-        Err(Error::NotCompiledWithCudaSupport)
-    }
-
     fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage> {
-        Err(Error::NotCompiledWithCudaSupport)
-    }
-
     fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
         Err(Error::NotCompiledWithCudaSupport)
     }
@ -233,38 +198,4 @@ impl crate::backend::BackendDevice for CudaDevice {
     fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
         Err(Error::NotCompiledWithCudaSupport)
     }
-
-    fn synchronize(&self) -> Result<()> {
-        Ok(())
-    }
 }
-
-/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
-/// allowed with f16 GEMMs.
-pub fn gemm_reduced_precision_f16() -> bool {
-    true
-}
-
-/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
-/// allowed with f16 GEMMs.
-pub fn set_gemm_reduced_precision_f16(_: bool) {}
-
-/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
-/// allowed with bf16 GEMMs.
-pub fn gemm_reduced_precision_bf16() -> bool {
-    true
-}
-
-/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
-/// allowed with bf16 GEMMs.
-pub fn set_gemm_reduced_precision_bf16(_: bool) {}
-
-/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
-/// allowed with f32 GEMMs.
-pub fn gemm_reduced_precision_f32() -> bool {
-    true
-}
-
-/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
-/// allowed with f32 GEMMs.
-pub fn set_gemm_reduced_precision_f32(_b: bool) {}
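A usage sketch for the toggles removed on this branch, assuming the `candle_core::cuda` re-export used by the candle examples (on non-CUDA builds these are the no-op versions shown above):

fn main() {
    // Allow tf32 accumulation in f32 GEMMs; trades a little precision for speed.
    candle_core::cuda::set_gemm_reduced_precision_f32(true);
    assert!(candle_core::cuda::gemm_reduced_precision_f32());
}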
@ -1,252 +0,0 @@
|
|||||||
#![allow(dead_code)]
|
|
||||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
|
||||||
use crate::{CpuStorage, DType, Error, Layout, Result, Shape};
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct MetalDevice;
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct MetalStorage;
|
|
||||||
|
|
||||||
#[derive(thiserror::Error, Debug)]
|
|
||||||
pub enum MetalError {
|
|
||||||
#[error("{0}")]
|
|
||||||
Message(String),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<String> for MetalError {
|
|
||||||
fn from(e: String) -> Self {
|
|
||||||
MetalError::Message(e)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
macro_rules! fail {
|
|
||||||
() => {
|
|
||||||
unimplemented!("metal support has not been enabled, add `metal` feature to enable.")
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
impl crate::backend::BackendStorage for MetalStorage {
|
|
||||||
type Device = MetalDevice;
|
|
||||||
|
|
||||||
fn try_clone(&self, _: &Layout) -> Result<Self> {
|
|
||||||
Err(Error::NotCompiledWithMetalSupport)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn dtype(&self) -> DType {
|
|
||||||
fail!()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn device(&self) -> &Self::Device {
|
|
||||||
fail!()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
|
||||||
Err(Error::NotCompiledWithMetalSupport)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
|
|
||||||
Err(Error::NotCompiledWithMetalSupport)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
|
|
||||||
Err(Error::NotCompiledWithMetalSupport)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
|
|
||||||
Err(Error::NotCompiledWithMetalSupport)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
|
|
||||||
Err(Error::NotCompiledWithMetalSupport)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
|
|
||||||
Err(Error::NotCompiledWithMetalSupport)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
|
|
||||||
Err(Error::NotCompiledWithMetalSupport)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
|
|
||||||
Err(Error::NotCompiledWithMetalSupport)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
|
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
-    fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
-    fn conv1d(
-        &self,
-        _: &Layout,
-        _: &Self,
-        _: &Layout,
-        _: &crate::conv::ParamsConv1D,
-    ) -> Result<Self> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
-    fn conv_transpose1d(
-        &self,
-        _l: &Layout,
-        _kernel: &Self,
-        _kernel_l: &Layout,
-        _params: &crate::conv::ParamsConvTranspose1D,
-    ) -> Result<Self> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
-    fn conv2d(
-        &self,
-        _: &Layout,
-        _: &Self,
-        _: &Layout,
-        _: &crate::conv::ParamsConv2D,
-    ) -> Result<Self> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
-    fn conv_transpose2d(
-        &self,
-        _l: &Layout,
-        _kernel: &Self,
-        _kernel_l: &Layout,
-        _params: &crate::conv::ParamsConvTranspose2D,
-    ) -> Result<Self> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
-    fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
-    fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
-    fn scatter_add(
-        &self,
-        _: &Layout,
-        _: &Self,
-        _: &Layout,
-        _: &Self,
-        _: &Layout,
-        _: usize,
-    ) -> Result<Self> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
-    fn index_add(
-        &self,
-        _: &Layout,
-        _: &Self,
-        _: &Layout,
-        _: &Self,
-        _: &Layout,
-        _: usize,
-    ) -> Result<Self> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
-    fn matmul(
-        &self,
-        _: &Self,
-        _: (usize, usize, usize, usize),
-        _: &Layout,
-        _: &Layout,
-    ) -> Result<Self> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
-    fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
-    fn copy2d(
-        &self,
-        _: &mut Self,
-        _: usize,
-        _: usize,
-        _: usize,
-        _: usize,
-        _: usize,
-        _: usize,
-    ) -> Result<()> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
-    fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
-    fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
-    fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
-    fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-}
-
-impl crate::backend::BackendDevice for MetalDevice {
-    type Storage = MetalStorage;
-    fn new(_: usize) -> Result<Self> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
-    fn set_seed(&self, _: u64) -> Result<()> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
-    fn location(&self) -> crate::DeviceLocation {
-        fail!()
-    }
-
-    fn same_device(&self, _: &Self) -> bool {
-        fail!()
-    }
-
-    fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
-    fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
-    unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
-    fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
-    fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
-    fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
-    fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
-    fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
-    fn synchronize(&self) -> Result<()> {
-        Ok(())
-    }
-}
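The block above closes out the stub Metal backend on the 0.5.1 side: every storage and device method compiles, but each one fails with `Error::NotCompiledWithMetalSupport` at runtime (except `synchronize`, which is a no-op). A minimal sketch of the same stub-backend pattern, using hypothetical `Gpu*` names rather than candle's types:

```rust
// A minimal sketch of the stub-backend pattern, assuming hypothetical
// Gpu* names. A real build would re-export either the real backend or
// this stub behind a feature flag, so downstream code compiles either way.
pub struct GpuDevice;

#[derive(Debug)]
pub enum Error {
    NotCompiledWithGpuSupport,
}

pub trait BackendDevice: Sized {
    fn new(ordinal: usize) -> Result<Self, Error>;
    fn synchronize(&self) -> Result<(), Error>;
}

impl BackendDevice for GpuDevice {
    fn new(_: usize) -> Result<Self, Error> {
        // Every constructor fails: this build has no GPU support compiled in.
        Err(Error::NotCompiledWithGpuSupport)
    }
    fn synchronize(&self) -> Result<(), Error> {
        // Nothing to wait for, so this is a harmless no-op.
        Ok(())
    }
}

fn main() {
    match GpuDevice::new(0) {
        Err(Error::NotCompiledWithGpuSupport) => println!("stub backend active"),
        Ok(_) => unreachable!(),
    }
}
```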
@@ -1,4 +1,4 @@
-use crate::{DType, DeviceLocation, Layout, MetalError, Shape};
+use crate::{DType, DeviceLocation, Layout, Shape};
 
 #[derive(Debug, Clone)]
 pub struct MatMulUnexpectedStriding {
@@ -142,9 +142,6 @@ pub enum Error {
     #[error("{op} expects at least one tensor")]
     OpRequiresAtLeastOneTensor { op: &'static str },
 
-    #[error("{op} expects at least two tensors")]
-    OpRequiresAtLeastTwoTensors { op: &'static str },
-
     #[error("backward is not supported for {op}")]
     BackwardNotSupported { op: &'static str },
 
@@ -152,9 +149,6 @@
     #[error("the candle crate has not been built with cuda support")]
     NotCompiledWithCudaSupport,
 
-    #[error("the candle crate has not been built with metal support")]
-    NotCompiledWithMetalSupport,
-
     #[error("cannot find tensor {path}")]
     CannotFindTensor { path: String },
 
@@ -162,9 +156,6 @@
     #[error(transparent)]
     Cuda(Box<dyn std::error::Error + Send + Sync>),
 
-    #[error("Metal error {0}")]
-    Metal(#[from] MetalError),
-
     #[error(transparent)]
     TryFromIntError(#[from] core::num::TryFromIntError),
 
@@ -219,14 +210,10 @@ impl Error {
         Self::Wrapped(Box::new(err)).bt()
     }
 
-    pub fn msg(err: impl std::error::Error) -> Self {
+    pub fn msg(err: impl std::error::Error + Send + Sync + 'static) -> Self {
         Self::Msg(err.to_string()).bt()
     }
 
-    pub fn debug(err: impl std::fmt::Debug) -> Self {
-        Self::Msg(format!("{err:?}")).bt()
-    }
-
     pub fn bt(self) -> Self {
         let backtrace = std::backtrace::Backtrace::capture();
         match backtrace.status() {
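The `msg` hunk above widens the bound from `impl std::error::Error` to `impl std::error::Error + Send + Sync + 'static`. A hedged sketch of why those extra bounds matter in practice (illustrative only, not candle's `Error` type): they are exactly what is needed to box the error as a thread-safe trait object.

```rust
use std::error::Error as StdError;

// Sketch: boxing as a thread-safe trait object requires Send + Sync +
// 'static on the concrete error type, which is what the widened bound buys.
fn into_shared(err: impl StdError + Send + Sync + 'static) -> Box<dyn StdError + Send + Sync> {
    Box::new(err)
}

fn main() {
    let err = std::io::Error::new(std::io::ErrorKind::Other, "boom");
    let boxed = into_shared(err);
    // The boxed error can now move across a thread boundary.
    let handle = std::thread::spawn(move || boxed.to_string());
    println!("{}", handle.join().unwrap());
}
```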
@@ -64,7 +64,7 @@ impl Tensor {
 #[derive(Debug)]
 /// Generic structure used to index a slice of the tensor
 pub enum TensorIndexer {
-    /// This selects the elements for which an index has some specific value.
+    /// This selects the elemnts for which an index has some specific value.
     Select(usize),
     /// This is a regular slice, purely indexing a chunk of the tensor
     Narrow(Bound<usize>, Bound<usize>),
@@ -104,31 +104,37 @@ impl From<&Tensor> for TensorIndexer {
     }
 }
 
-trait RB: RangeBounds<usize> {}
-impl RB for Range<usize> {}
-impl RB for RangeFrom<usize> {}
-impl RB for RangeFull {}
-impl RB for RangeInclusive<usize> {}
-impl RB for RangeTo<usize> {}
-impl RB for RangeToInclusive<usize> {}
-
-impl<T: RB> From<T> for TensorIndexer {
-    fn from(range: T) -> Self {
-        use std::ops::Bound::*;
-        let start = match range.start_bound() {
-            Included(idx) => Included(*idx),
-            Excluded(idx) => Excluded(*idx),
-            Unbounded => Unbounded,
-        };
-        let end = match range.end_bound() {
-            Included(idx) => Included(*idx),
-            Excluded(idx) => Excluded(*idx),
-            Unbounded => Unbounded,
-        };
-        TensorIndexer::Narrow(start, end)
-    }
-}
+macro_rules! impl_from_range {
+    ($range_type:ty) => {
+        impl From<$range_type> for TensorIndexer {
+            fn from(range: $range_type) -> Self {
+                use std::ops::Bound::*;
+
+                let start = match range.start_bound() {
+                    Included(idx) => Included(*idx),
+                    Excluded(idx) => Excluded(*idx),
+                    Unbounded => Unbounded,
+                };
+
+                let end = match range.end_bound() {
+                    Included(idx) => Included(*idx),
+                    Excluded(idx) => Excluded(*idx),
+                    Unbounded => Unbounded,
+                };
+
+                TensorIndexer::Narrow(start, end)
+            }
+        }
+    };
+}
+
+impl_from_range!(Range<usize>);
+impl_from_range!(RangeFrom<usize>);
+impl_from_range!(RangeFull);
+impl_from_range!(RangeInclusive<usize>);
+impl_from_range!(RangeTo<usize>);
+impl_from_range!(RangeToInclusive<usize>);
 
 /// Trait used to implement multiple signatures for ease of use of the slicing
 /// of a tensor
 pub trait IndexOp<T> {
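Both sides of this hunk dodge the same coherence problem: a blanket `impl<T: RangeBounds<usize>> From<T> for TensorIndexer` could collide with the enum's other `From` impls, so the `-` side restricts a generic impl with a private marker trait `RB`, while the `+` side generates one concrete impl per range type through a macro. A standalone sketch of the macro approach on a toy `Idx` type:

```rust
use std::ops::{Bound, Range, RangeBounds, RangeFrom};

// Toy re-creation of the per-type macro shown above: one concrete From
// impl is emitted per range type, avoiding a blanket impl over all
// RangeBounds<usize> that could conflict with other From impls.
#[derive(Debug)]
enum Idx {
    Narrow(Bound<usize>, Bound<usize>),
}

macro_rules! impl_from_range {
    ($range_type:ty) => {
        impl From<$range_type> for Idx {
            fn from(range: $range_type) -> Self {
                let own = |b: Bound<&usize>| match b {
                    Bound::Included(i) => Bound::Included(*i),
                    Bound::Excluded(i) => Bound::Excluded(*i),
                    Bound::Unbounded => Bound::Unbounded,
                };
                Idx::Narrow(own(range.start_bound()), own(range.end_bound()))
            }
        }
    };
}

impl_from_range!(Range<usize>);
impl_from_range!(RangeFrom<usize>);

fn main() {
    println!("{:?}", Idx::from(2..5)); // Narrow(Included(2), Excluded(5))
    println!("{:?}", Idx::from(3..)); // Narrow(Included(3), Unbounded)
}
```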
@@ -70,7 +70,7 @@ impl Layout {
         self.shape.is_fortran_contiguous(&self.stride)
     }
 
-    pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
+    pub(crate) fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
         let dims = self.shape().dims();
         if dim >= dims.len() {
             Err(Error::DimOutOfRange {
@@ -99,7 +99,7 @@ impl Layout {
         })
     }
 
-    pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
+    pub(crate) fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
         let rank = self.shape.rank();
         if rank <= dim1 || rank <= dim2 {
             Err(Error::UnexpectedNumberOfDims {
@@ -120,7 +120,7 @@ impl Layout {
         })
     }
 
-    pub fn permute(&self, idxs: &[usize]) -> Result<Self> {
+    pub(crate) fn permute(&self, idxs: &[usize]) -> Result<Self> {
         let is_permutation =
             idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i));
         if !is_permutation {
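These three hunks only tighten visibility from `pub` to `pub(crate)`; the methods themselves are metadata rewrites that never touch the data. As a reminder of what `narrow` computes, here is a standalone sketch of the stride arithmetic under a simplified row-major layout model (not candle's `Layout` type):

```rust
// Sketch of the metadata-only arithmetic behind `narrow` on dimension
// `dim`: strides are unchanged, the shape entry shrinks, and the start
// offset advances. Assumes a simplified row-major layout model.
fn narrow(shape: &mut [usize], stride: &[usize], start_offset: &mut usize,
          dim: usize, start: usize, len: usize) {
    assert!(start + len <= shape[dim], "narrow out of range");
    shape[dim] = len;
    *start_offset += start * stride[dim];
}

fn main() {
    let mut shape = vec![4, 6];
    let stride = vec![6, 1]; // contiguous strides for a 4x6 tensor
    let mut offset = 0;
    // Keep columns 2..5 of every row.
    narrow(&mut shape, &stride, &mut offset, 1, 2, 3);
    assert_eq!(shape, vec![4, 3]);
    assert_eq!(offset, 2); // rows keep stride 6, data now starts 2 elements in
}
```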
@@ -14,7 +14,7 @@
 //!
 //! ## Features
 //!
-//! - Simple syntax (looks and feels like PyTorch)
+//! - Simple syntax (looks and like PyTorch)
 //! - CPU and Cuda backends (and M1 support)
 //! - Enable serverless (CPU) small and fast deployments
 //! - Model training
@@ -37,51 +37,44 @@
 mod accelerate;
 pub mod backend;
 pub mod backprop;
-pub mod conv;
+mod conv;
 mod convert;
 pub mod cpu;
 pub mod cpu_backend;
 #[cfg(feature = "cuda")]
 pub mod cuda_backend;
-mod custom_op;
+#[cfg(feature = "cudnn")]
+pub mod cudnn;
 mod device;
 pub mod display;
 mod dtype;
-pub mod dummy_cuda_backend;
-mod dummy_metal_backend;
+mod dummy_cuda_backend;
 pub mod error;
 mod indexer;
 pub mod layout;
-#[cfg(feature = "metal")]
-pub mod metal_backend;
 #[cfg(feature = "mkl")]
 mod mkl;
 pub mod npy;
-pub mod op;
+mod op;
 pub mod pickle;
 pub mod quantized;
 pub mod safetensors;
 pub mod scalar;
 pub mod shape;
-mod sort;
 mod storage;
 mod strided_index;
 mod tensor;
-mod tensor_cat;
 pub mod test_utils;
 pub mod utils;
 mod variable;
 
-#[cfg(feature = "cudnn")]
-pub use cuda_backend::cudnn;
-
-pub use cpu_backend::{CpuStorage, CpuStorageRef};
-pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
-pub use device::{Device, DeviceLocation, NdArray};
-pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
+pub use cpu_backend::CpuStorage;
+pub use device::{Device, DeviceLocation};
+pub use dtype::{DType, FloatDType, IntDType, WithDType};
 pub use error::{Error, Result};
 pub use indexer::IndexOp;
 pub use layout::Layout;
+pub use op::{CustomOp1, CustomOp2, CustomOp3};
 pub use shape::{Shape, D};
 pub use storage::Storage;
 pub use strided_index::{StridedBlocks, StridedIndex};
@@ -89,18 +82,10 @@ pub use tensor::{Tensor, TensorId};
 pub use variable::Var;
 
 #[cfg(feature = "cuda")]
-pub use cuda_backend as cuda;
+pub use cuda_backend::{CudaDevice, CudaStorage};
 
 #[cfg(not(feature = "cuda"))]
-pub use dummy_cuda_backend as cuda;
+pub use dummy_cuda_backend::{CudaDevice, CudaStorage};
 
-pub use cuda::{CudaDevice, CudaStorage};
-
-#[cfg(feature = "metal")]
-pub use metal_backend::{MetalDevice, MetalError, MetalStorage};
-
-#[cfg(not(feature = "metal"))]
-pub use dummy_metal_backend::{MetalDevice, MetalError, MetalStorage};
-
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
@@ -129,29 +114,14 @@ pub trait Module {
     fn forward(&self, xs: &Tensor) -> Result<Tensor>;
 }
 
+impl Module for quantized::QMatMul {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        self.forward(xs)
+    }
+}
+
 impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         self(xs)
     }
 }
-
-impl<M: Module> Module for Option<&M> {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        match self {
-            None => Ok(xs.clone()),
-            Some(m) => m.forward(xs),
-        }
-    }
-}
-
-// A trait defining a module with forward method using a single tensor argument and a flag to
-// separate the training and evaluation behaviors.
-pub trait ModuleT {
-    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor>;
-}
-
-impl<M: Module> ModuleT for M {
-    fn forward_t(&self, xs: &Tensor, _train: bool) -> Result<Tensor> {
-        self.forward(xs)
-    }
-}
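The deleted trailer of `lib.rs` carried two blanket impls: `Option<&M>` acts as the identity module when `None`, and any `Module` automatically satisfies `ModuleT` by ignoring the train flag. A toy sketch of both, on a stand-in `Tensor` type:

```rust
// Toy re-creation of the blanket impls removed above, on stand-in types.
#[derive(Clone, Debug, PartialEq)]
struct Tensor(f32);

trait Module {
    fn forward(&self, xs: &Tensor) -> Tensor;
}

trait ModuleT {
    fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor;
}

// Any Module is a ModuleT that simply ignores the train flag.
impl<M: Module> ModuleT for M {
    fn forward_t(&self, xs: &Tensor, _train: bool) -> Tensor {
        self.forward(xs)
    }
}

struct Double;
impl Module for Double {
    fn forward(&self, xs: &Tensor) -> Tensor {
        Tensor(xs.0 * 2.0)
    }
}

// An optional module is the identity when absent.
impl<M: Module> Module for Option<&M> {
    fn forward(&self, xs: &Tensor) -> Tensor {
        match self {
            None => xs.clone(),
            Some(m) => m.forward(xs),
        }
    }
}

fn main() {
    assert_eq!(Double.forward_t(&Tensor(3.0), true), Tensor(6.0));
    assert_eq!(None::<&Double>.forward(&Tensor(3.0)), Tensor(3.0));
}
```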
@@ -1,287 +0,0 @@
-use crate::{DType, Result};
-use candle_metal_kernels::Kernels;
-use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
-use std::collections::HashMap;
-use std::ffi::c_void;
-use std::path::Path;
-use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard};
-
-use super::MetalError;
-
-/// Unique identifier for cuda devices.
-#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
-pub struct DeviceId(usize);
-
-impl DeviceId {
-    pub(crate) fn new() -> Self {
-        // https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
-        use std::sync::atomic;
-        static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
-        Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
-    }
-}
-
-type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>;
-type AllocatedBuffers = Arc<RwLock<BufferMap>>;
-
-#[derive(Clone)]
-pub struct MetalDevice {
-    /// Unique identifier, the registryID is not sufficient as it identifies the GPU rather than
-    /// the device itself.
-    pub(crate) id: DeviceId,
-
-    /// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
-    pub(crate) device: metal::Device,
-
-    /// Single command queue for the entire device.
-    pub(crate) command_queue: CommandQueue,
-    /// One command buffer at a time.
-    /// The scheduler works by allowing multiple
-    /// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
-    /// on a single command buffer. Using a single command buffer would be fastest on the GPU but
-    /// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed
-    /// to start to work).
-    /// Despite what the documentation says, command buffers are NOT ordered. They are ordered
-    /// for their START time, but there's no guarantee that command buffer1 will finish before
-    /// command buffer2 starts (or there are metal bugs there)
-    pub(crate) command_buffer: Arc<RwLock<CommandBuffer>>,
-    /// Keeps track of the current amount of compute command encoders on the current
-    /// command buffer
-    /// Arc, RwLock because of the interior mutability.
-    pub(crate) command_buffer_index: Arc<RwLock<usize>>,
-    /// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
-    pub(crate) compute_per_buffer: usize,
-    /// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
-    /// Heavily used by [`candle_metal_kernels`]
-    pub(crate) kernels: Arc<Kernels>,
-    /// Simple allocator struct.
-    /// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
-    /// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting
-    /// (could be linked to FFI communication overhead).
-    ///
-    /// Whenever a buffer has a strong_count==1, we can reuse it, it means it was dropped in the
-    /// graph calculation, and only we the allocator kept a reference to it, therefore it's free
-    /// to be reused. However, in order for this to work, we need to guarantee the order of
-    /// operation, so that this buffer is not being used by another kernel at the same time.
-    /// Arc is the CPU reference count, it doesn't mean anything on the GPU side of things.
-    ///
-    /// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers
-    /// (strong_count = 1).
-    pub(crate) buffers: AllocatedBuffers,
-    /// Seed for random number generation.
-    pub(crate) seed: Arc<Mutex<Buffer>>,
-}
-
-impl std::fmt::Debug for MetalDevice {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "MetalDevice({:?})", self.id)
-    }
-}
-
-impl std::ops::Deref for MetalDevice {
-    type Target = metal::DeviceRef;
-
-    fn deref(&self) -> &Self::Target {
-        &self.device
-    }
-}
-
-impl MetalDevice {
-    pub fn id(&self) -> DeviceId {
-        self.id
-    }
-
-    pub fn metal_device(&self) -> &metal::Device {
-        &self.device
-    }
-
-    pub fn command_queue(&self) -> &CommandQueue {
-        &self.command_queue
-    }
-
-    pub fn command_buffer(&self) -> Result<CommandBuffer> {
-        let mut command_buffer_lock = self.command_buffer.write().map_err(MetalError::from)?;
-        let mut command_buffer = command_buffer_lock.to_owned();
-        let mut index = self
-            .command_buffer_index
-            .write()
-            .map_err(MetalError::from)?;
-        if *index > self.compute_per_buffer {
-            command_buffer.commit();
-            command_buffer = self.command_queue.new_command_buffer().to_owned();
-            *command_buffer_lock = command_buffer.clone();
-            *index = 0;
-
-            self.drop_unused_buffers()?;
-        }
-        *index += 1;
-        Ok(command_buffer)
-    }
-
-    pub fn wait_until_completed(&self) -> Result<()> {
-        let mut command_buffer = self.command_buffer.write().map_err(MetalError::from)?;
-        match command_buffer.status() {
-            metal::MTLCommandBufferStatus::Committed
-            | metal::MTLCommandBufferStatus::Scheduled
-            | metal::MTLCommandBufferStatus::Completed => {
-                panic!("Already committed");
-            }
-            _ => {}
-        }
-        command_buffer.commit();
-        command_buffer.wait_until_completed();
-        *command_buffer = self.command_queue.new_command_buffer().to_owned();
-
-        Ok(())
-    }
-
-    pub fn kernels(&self) -> &Kernels {
-        &self.kernels
-    }
-
-    pub fn device(&self) -> &metal::Device {
-        &self.device
-    }
-
-    /// Creates a new buffer (not necessarily zeroed).
-    /// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode)
-    /// This means the buffer data cannot be read on the CPU directly.
-    ///
-    /// [`name`] is only used to keep track of the resource origin in case of bugs
-    pub fn new_buffer(
-        &self,
-        element_count: usize,
-        dtype: DType,
-        name: &str,
-    ) -> Result<Arc<Buffer>> {
-        let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
-        self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name)
-    }
-
-    /// Creates a new buffer (not necessarily zeroed).
-    /// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
-    /// This means the buffer can be read on the CPU but will require manual
-    /// synchronization when the CPU memory is modified
-    /// Used as a bridge to gather data back from the GPU
-    pub fn new_buffer_managed(&self, size: NSUInteger) -> Result<Arc<Buffer>> {
-        self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
-    }
-
-    /// Creates a new buffer from data.
-    /// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
-    ///
-    /// Does not require synchronization, as [newBufferWithBytes](https://developer.apple.com/documentation/metal/mtldevice/1433429-newbufferwithbytes)
-    /// allocates the buffer and copies over the existing data before returning the MTLBuffer.
-    pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> {
-        let size = core::mem::size_of_val(data) as NSUInteger;
-        let new_buffer = self.device.new_buffer_with_data(
-            data.as_ptr() as *const c_void,
-            size,
-            MTLResourceOptions::StorageModeManaged,
-        );
-        let mut buffers = self.buffers.write().map_err(MetalError::from)?;
-        let subbuffers = buffers
-            .entry((size, MTLResourceOptions::StorageModeManaged))
-            .or_insert(vec![]);
-
-        let new_buffer = Arc::new(new_buffer);
-        subbuffers.push(new_buffer.clone());
-        Ok(new_buffer)
-    }
-
-    pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
-        let buffer = self.allocate_buffer(
-            size_in_bytes as NSUInteger,
-            MTLResourceOptions::StorageModePrivate,
-            "allocate_zeros",
-        )?;
-        let command_buffer = self.command_buffer()?;
-        command_buffer.set_label("zeros");
-        let blit = command_buffer.new_blit_command_encoder();
-        blit.fill_buffer(
-            &buffer,
-            metal::NSRange {
-                location: 0,
-                length: buffer.length(),
-            },
-            0,
-        );
-        blit.end_encoding();
-        Ok(buffer)
-    }
-
-    fn find_available_buffer(
-        &self,
-        size: NSUInteger,
-        option: MTLResourceOptions,
-        buffers: &RwLockWriteGuard<BufferMap>,
-    ) -> Option<Arc<Buffer>> {
-        let mut best_buffer: Option<&Arc<Buffer>> = None;
-        let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
-        for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
-            if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
-                for sub in subbuffers {
-                    if Arc::strong_count(sub) == 1 {
-                        best_buffer = Some(sub);
-                        best_buffer_size = *buffer_size;
-                    }
-                }
-            }
-        }
-        best_buffer.cloned()
-    }
-
-    fn drop_unused_buffers(&self) -> Result<()> {
-        let mut buffers = self.buffers.write().map_err(MetalError::from)?;
-        for subbuffers in buffers.values_mut() {
-            let newbuffers = subbuffers
-                .iter()
-                .filter(|s| Arc::strong_count(*s) > 1)
-                .map(Arc::clone)
-                .collect();
-            *subbuffers = newbuffers;
-        }
-        Ok(())
-    }
-
-    /// The critical allocator algorithm
-    fn allocate_buffer(
-        &self,
-        size: NSUInteger,
-        option: MTLResourceOptions,
-        _name: &str,
-    ) -> Result<Arc<Buffer>> {
-        let mut buffers = self.buffers.write().map_err(MetalError::from)?;
-        if let Some(b) = self.find_available_buffer(size, option, &buffers) {
-            // Cloning also ensures we increment the strong count
-            return Ok(b.clone());
-        }
-
-        let size = buf_size(size);
-        let subbuffers = buffers.entry((size, option)).or_insert(vec![]);
-
-        let new_buffer = self.device.new_buffer(size as NSUInteger, option);
-        let new_buffer = Arc::new(new_buffer);
-        subbuffers.push(new_buffer.clone());
-
-        Ok(new_buffer)
-    }
-
-    /// Create a metal GPU capture trace on [`path`].
-    pub fn capture<P: AsRef<Path>>(&self, path: P) -> Result<()> {
-        let capture = metal::CaptureManager::shared();
-        let descriptor = metal::CaptureDescriptor::new();
-        descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
-        descriptor.set_capture_device(self);
-        descriptor.set_output_url(path);
-
-        capture
-            .start_capture(&descriptor)
-            .map_err(MetalError::from)?;
-        Ok(())
-    }
-}
-
-fn buf_size(size: NSUInteger) -> NSUInteger {
-    size.saturating_sub(1).next_power_of_two() as NSUInteger
-}
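The comments in the deleted device wrapper describe the allocator contract: buffers live in size buckets, and a buffer whose `Arc` strong count has dropped back to 1 is owned only by the allocator and is therefore free to hand out again. A standalone sketch of that reuse rule, with `Vec<u8>` standing in for `metal::Buffer`:

```rust
use std::collections::HashMap;
use std::sync::Arc;

// Standalone sketch of the allocator strategy documented above: buffers
// live in per-size buckets and a buffer counts as free for reuse as soon
// as its Arc strong count drops back to 1 (only the allocator holds it).
struct Allocator {
    buffers: HashMap<usize, Vec<Arc<Vec<u8>>>>,
}

impl Allocator {
    fn allocate(&mut self, size: usize) -> Arc<Vec<u8>> {
        let bucket = self.buffers.entry(size).or_default();
        if let Some(free) = bucket.iter().find(|b| Arc::strong_count(b) == 1) {
            // Reused: cloning bumps the strong count, marking it "in use".
            return free.clone();
        }
        let fresh = Arc::new(vec![0u8; size]);
        bucket.push(fresh.clone());
        fresh
    }
}

fn main() {
    let mut alloc = Allocator { buffers: HashMap::new() };
    let a = alloc.allocate(64);
    drop(a); // strong count back to 1 -> eligible for reuse
    let b = alloc.allocate(64);
    assert_eq!(alloc.buffers[&64].len(), 1); // no second allocation happened
    let _keep = b;
}
```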
File diff suppressed because it is too large
@@ -333,16 +333,6 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
     unsafe { ffi::vdTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
 }
 
-#[inline]
-pub fn vs_exp_inplace(y: &mut [f32]) {
-    unsafe { ffi::vsExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
-}
-
-#[inline]
-pub fn vd_exp_inplace(y: &mut [f64]) {
-    unsafe { ffi::vdExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
-}
-
 #[inline]
 pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
     for (&v, y) in vs.iter().zip(ys.iter_mut()) {
@@ -365,28 +355,6 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
     }
 }
 
-#[inline]
-pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
-    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
-        *y = -v
-    }
-    vs_exp_inplace(ys);
-    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
-        *y = v / (1.0 + *y)
-    }
-}
-
-#[inline]
-pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
-    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
-        *y = -v
-    }
-    vd_exp_inplace(ys);
-    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
-        *y = v / (1.0 + *y)
-    }
-}
-
 macro_rules! binary_op {
     ($op:ident, $ty:ty, $mkl_name:ident) => {
         #[inline]
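The removed `vs_silu`/`vd_silu` helpers compute `silu(v) = v / (1 + exp(-v))` in three passes so the middle pass can call MKL's vectorized exponential. The same scheme in plain Rust, with a scalar loop standing in for `ffi::vsExp`:

```rust
// Plain-Rust sketch of the deleted vs_silu scheme: write -v into the
// output, exponentiate in place (the step MKL vectorizes), then finish
// with y = v / (1 + e^{-v}) = silu(v).
fn exp_inplace(y: &mut [f32]) {
    for v in y.iter_mut() {
        *v = v.exp();
    }
}

fn silu(vs: &[f32], ys: &mut [f32]) {
    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
        *y = -v;
    }
    exp_inplace(ys);
    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
        *y = v / (1.0 + *y);
    }
}

fn main() {
    let xs = [0.0f32, 1.0, -1.0];
    let mut ys = [0.0f32; 3];
    silu(&xs, &mut ys);
    // silu(0) = 0, silu(1) ~ 0.7311, silu(-1) ~ -0.2689
    println!("{ys:?}");
}
```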
@@ -250,6 +250,8 @@ impl Tensor {
         if header.fortran_order {
             return Err(Error::Npy("fortran order not supported".to_string()));
         }
+        let mut data: Vec<u8> = vec![];
+        reader.read_to_end(&mut data)?;
         Self::from_reader(header.shape(), header.descr, &mut reader)
     }
 
@@ -330,7 +332,7 @@ impl Tensor {
         path: P,
     ) -> Result<()> {
         let mut zip = zip::ZipWriter::new(File::create(path.as_ref())?);
-        let options: zip::write::FileOptions<()> =
+        let options =
             zip::write::FileOptions::default().compression_method(zip::CompressionMethod::Stored);
 
         for (name, tensor) in ts.iter() {
@@ -1,5 +1,5 @@
 #![allow(clippy::redundant_closure_call)]
-use crate::Tensor;
+use crate::{CpuStorage, CudaStorage, Layout, Result, Shape, Tensor};
 use half::{bf16, f16};
 use num_traits::float::Float;
 
@@ -61,12 +61,10 @@ pub enum UnaryOp {
     GeluErf,
     Erf,
     Relu,
-    Silu,
     Tanh,
     Floor,
     Ceil,
     Round,
-    Sign,
 }
 
 #[derive(Clone)]
@@ -92,16 +90,6 @@ pub enum Op {
         dilation: usize,
     },
 
-    #[allow(dead_code)]
-    ConvTranspose1D {
-        arg: Tensor,
-        kernel: Tensor,
-        padding: usize,
-        output_padding: usize,
-        stride: usize,
-        dilation: usize,
-    },
-
     #[allow(dead_code)]
     Conv2D {
         arg: Tensor,
@@ -133,15 +121,8 @@ pub enum Op {
         stride: (usize, usize),
     },
 
-    UpsampleNearest1D {
-        arg: Tensor,
-        target_size: usize,
-    },
-    UpsampleNearest2D {
-        arg: Tensor,
-        target_h: usize,
-        target_w: usize,
-    },
+    UpsampleNearest1D(Tensor),
+    UpsampleNearest2D(Tensor),
 
     Cat(Vec<Tensor>, usize),
 
@@ -162,23 +143,126 @@ pub enum Op {
     Permute(Tensor, Vec<usize>),
     Elu(Tensor, f64),
     Powf(Tensor, f64),
-    CustomOp1(
-        Tensor,
-        std::sync::Arc<Box<dyn crate::CustomOp1 + Send + Sync>>,
-    ),
+    CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1 + Send + Sync>>),
     CustomOp2(
         Tensor,
         Tensor,
-        std::sync::Arc<Box<dyn crate::CustomOp2 + Send + Sync>>,
+        std::sync::Arc<Box<dyn CustomOp2 + Send + Sync>>,
     ),
     CustomOp3(
         Tensor,
         Tensor,
         Tensor,
-        std::sync::Arc<Box<dyn crate::CustomOp3 + Send + Sync>>,
+        std::sync::Arc<Box<dyn CustomOp3 + Send + Sync>>,
     ),
 }
 
+/// Unary ops that can be defined in user-land.
+pub trait CustomOp1 {
+    // Box<dyn> does not support const yet, so use a function to get the name.
+    fn name(&self) -> &'static str;
+
+    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
+    /// offsets etc so the associated layout should be used to access it.
+    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>;
+
+    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
+    /// offsets etc so the associated layout should be used to access it.
+    fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> {
+        Err(crate::Error::Cuda(
+            format!("no cuda implementation for {}", self.name()).into(),
+        ))
+    }
+
+    /// This function takes as argument the argument `arg` used in the forward pass, the result
+    /// produced by the forward operation `res` and the gradient of the result `grad_res`.
+    /// The function should return the gradient of the argument.
+    fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Option<Tensor>> {
+        Err(crate::Error::BackwardNotSupported { op: self.name() })
+    }
+}
+
+pub trait CustomOp2 {
+    fn name(&self) -> &'static str;
+
+    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
+    /// offsets etc so the associated layout should be used to access it.
+    fn cpu_fwd(
+        &self,
+        s1: &CpuStorage,
+        l1: &Layout,
+        s2: &CpuStorage,
+        l2: &Layout,
+    ) -> Result<(CpuStorage, Shape)>;
+
+    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
+    /// offsets etc so the associated layout should be used to access it.
+    fn cuda_fwd(
+        &self,
+        _: &CudaStorage,
+        _: &Layout,
+        _: &CudaStorage,
+        _: &Layout,
+    ) -> Result<(CudaStorage, Shape)> {
+        Err(crate::Error::Cuda(
+            format!("no cuda implementation for {}", self.name()).into(),
+        ))
+    }
+
+    fn bwd(
+        &self,
+        _arg1: &Tensor,
+        _arg2: &Tensor,
+        _res: &Tensor,
+        _grad_res: &Tensor,
+    ) -> Result<(Option<Tensor>, Option<Tensor>)> {
+        Err(crate::Error::BackwardNotSupported { op: self.name() })
+    }
+}
+
+pub trait CustomOp3 {
+    fn name(&self) -> &'static str;
+
+    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
+    /// offsets etc so the associated layout should be used to access it.
+    fn cpu_fwd(
+        &self,
+        s1: &CpuStorage,
+        l1: &Layout,
+        s2: &CpuStorage,
+        l2: &Layout,
+        s3: &CpuStorage,
+        l3: &Layout,
+    ) -> Result<(CpuStorage, Shape)>;
+
+    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
+    /// offsets etc so the associated layout should be used to access it.
+    fn cuda_fwd(
+        &self,
+        _: &CudaStorage,
+        _: &Layout,
+        _: &CudaStorage,
+        _: &Layout,
+        _: &CudaStorage,
+        _: &Layout,
+    ) -> Result<(CudaStorage, Shape)> {
+        Err(crate::Error::Cuda(
+            format!("no cuda implementation for {}", self.name()).into(),
+        ))
+    }
+
+    fn bwd(
+        &self,
+        _arg1: &Tensor,
+        _arg2: &Tensor,
+        _arg3: &Tensor,
+        _res: &Tensor,
+        _grad_res: &Tensor,
+    ) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
+        Err(crate::Error::BackwardNotSupported { op: self.name() })
+    }
+}
+
 pub trait UnaryOpT {
     const NAME: &'static str;
     const KERNEL: &'static str;
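The `+` side above defines the `CustomOp1`/`CustomOp2`/`CustomOp3` traits inline in `op.rs`: a required CPU forward, a CUDA forward that defaults to an error naming the op, and a backward that defaults to `BackwardNotSupported`. A toy sketch of the same shape, with a simplified `Storage` type and a hypothetical `Square` op:

```rust
// Toy sketch of the CustomOp1 pattern above: a mandatory CPU forward, a
// CUDA forward that defaults to an error, and a user-defined Square op.
// Vec<f32> stands in for CpuStorage and the layout/shape plumbing.
type Storage = Vec<f32>;

trait CustomOp1 {
    fn name(&self) -> &'static str;
    fn cpu_fwd(&self, storage: &Storage) -> Result<Storage, String>;
    fn cuda_fwd(&self, _storage: &Storage) -> Result<Storage, String> {
        Err(format!("no cuda implementation for {}", self.name()))
    }
}

struct Square;

impl CustomOp1 for Square {
    fn name(&self) -> &'static str {
        "square"
    }
    fn cpu_fwd(&self, storage: &Storage) -> Result<Storage, String> {
        Ok(storage.iter().map(|v| v * v).collect())
    }
}

fn main() {
    let op = Square;
    assert_eq!(op.cpu_fwd(&vec![1.0, 2.0, 3.0]).unwrap(), vec![1.0, 4.0, 9.0]);
    assert!(op.cuda_fwd(&vec![1.0]).is_err());
}
```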
@@ -250,12 +334,10 @@ pub(crate) struct Gelu;
 pub(crate) struct GeluErf;
 pub(crate) struct Erf;
 pub(crate) struct Relu;
-pub(crate) struct Silu;
 pub(crate) struct Tanh;
 pub(crate) struct Floor;
 pub(crate) struct Ceil;
 pub(crate) struct Round;
-pub(crate) struct Sign;
 
 macro_rules! bin_op {
     ($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
@@ -454,20 +536,13 @@ unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln);
 unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
 unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
 unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh);
+unary_op!(Abs, "abs", v, v.abs());
 unary_op!(Neg, "neg", v, -v);
 unary_op!(Recip, "recip", v, v.recip());
 unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
 unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
 
-// Hardcode the value for sqrt(2/pi)
-// https://github.com/huggingface/candle/issues/1982
-#[allow(clippy::excessive_precision)]
-const SQRT_TWO_OVER_PI_F32: f32 = 0.79788456080286535587989211986876373;
-#[allow(clippy::excessive_precision)]
-const SQRT_TWO_OVER_PI_F64: f64 = 0.79788456080286535587989211986876373;
-
-/// Tanh based approximation of the `gelu` operation
-/// GeluErf is the more precise one.
+/// `gelu` operation
 /// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
 impl UnaryOpT for Gelu {
     const NAME: &'static str = "gelu";
@@ -478,7 +553,7 @@ impl UnaryOpT for Gelu {
             * v
             * (bf16::ONE
                 + bf16::tanh(
-                    bf16::from_f32_const(SQRT_TWO_OVER_PI_F32)
+                    (bf16::from_f32_const(2.0) / bf16::PI).sqrt()
                         * v
                         * (bf16::ONE + bf16::from_f32_const(0.044715) * v * v),
                 ))
@@ -489,18 +564,22 @@ impl UnaryOpT for Gelu {
             * v
             * (f16::ONE
                 + f16::tanh(
-                    f16::from_f32_const(SQRT_TWO_OVER_PI_F32)
+                    (f16::from_f32_const(2.0) / f16::PI).sqrt()
                         * v
                         * (f16::ONE + f16::from_f32_const(0.044715) * v * v),
                 ))
     }
     #[inline(always)]
     fn f32(v: f32) -> f32 {
-        0.5 * v * (1.0 + f32::tanh(SQRT_TWO_OVER_PI_F32 * v * (1.0 + 0.044715 * v * v)))
+        0.5 * v
+            * (1.0
+                + f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
     }
     #[inline(always)]
     fn f64(v: f64) -> f64 {
-        0.5 * v * (1.0 + f64::tanh(SQRT_TWO_OVER_PI_F64 * v * (1.0 + 0.044715 * v * v)))
+        0.5 * v
+            * (1.0
+                + f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
     }
     #[inline(always)]
     fn u8(_: u8) -> u8 {
@@ -553,8 +632,6 @@ impl UnaryOpT for Gelu {
     }
 }
 
-/// `erf` operation
-/// <https://en.wikipedia.org/wiki/Error_function>
 impl UnaryOpT for Erf {
     const NAME: &'static str = "erf";
     const KERNEL: &'static str = "uerf";
@@ -589,111 +666,6 @@ impl UnaryOpT for Erf {
     }
 }
 
-/// Silu operation
-impl UnaryOpT for Silu {
-    const NAME: &'static str = "silu";
-    const V: Self = Silu;
-    #[inline(always)]
-    fn bf16(v: bf16) -> bf16 {
-        v / (bf16::ONE + (-v).exp())
-    }
-    #[inline(always)]
-    fn f16(v: f16) -> f16 {
-        v / (f16::ONE + (-v).exp())
-    }
-    #[inline(always)]
-    fn f32(v: f32) -> f32 {
-        v / (1.0 + (-v).exp())
-    }
-    #[inline(always)]
-    fn f64(v: f64) -> f64 {
-        v / (1.0 + (-v).exp())
-    }
-    #[inline(always)]
-    fn u8(_: u8) -> u8 {
-        0
-    }
-    #[inline(always)]
-    fn u32(_: u32) -> u32 {
-        0
-    }
-    #[inline(always)]
-    fn i64(_: i64) -> i64 {
-        0
-    }
-    const KERNEL: &'static str = "usilu";
-
-    #[cfg(feature = "mkl")]
-    const F32_VEC: bool = true;
-
-    #[cfg(feature = "mkl")]
-    #[inline(always)]
-    fn f32_vec(xs: &[f32], ys: &mut [f32]) {
-        crate::mkl::vs_silu(xs, ys)
-    }
-
-    #[cfg(feature = "mkl")]
-    const F64_VEC: bool = true;
-
-    #[cfg(feature = "mkl")]
-    #[inline(always)]
-    fn f64_vec(xs: &[f64], ys: &mut [f64]) {
-        crate::mkl::vd_silu(xs, ys)
-    }
-
-    #[cfg(feature = "accelerate")]
-    const F32_VEC: bool = true;
-
-    #[cfg(feature = "accelerate")]
-    #[inline(always)]
-    fn f32_vec(xs: &[f32], ys: &mut [f32]) {
-        crate::accelerate::vs_silu(xs, ys)
-    }
-
-    #[cfg(feature = "accelerate")]
-    const F64_VEC: bool = true;
-
-    #[cfg(feature = "accelerate")]
-    #[inline(always)]
-    fn f64_vec(xs: &[f64], ys: &mut [f64]) {
-        crate::accelerate::vd_silu(xs, ys)
-    }
-}
-
-impl UnaryOpT for Abs {
-    const NAME: &'static str = "abs";
-    const KERNEL: &'static str = "uabs";
-    const V: Self = Abs;
-    #[inline(always)]
-    fn bf16(v: bf16) -> bf16 {
-        v.abs()
-    }
-    #[inline(always)]
-    fn f16(v: f16) -> f16 {
-        v.abs()
-    }
-    #[inline(always)]
-    fn f32(v: f32) -> f32 {
-        v.abs()
-    }
-    #[inline(always)]
-    fn f64(v: f64) -> f64 {
-        v.abs()
-    }
-    #[inline(always)]
-    fn u8(v: u8) -> u8 {
-        v
-    }
-    #[inline(always)]
-    fn u32(v: u32) -> u32 {
-        v
-    }
-    #[inline(always)]
-    fn i64(v: i64) -> i64 {
-        v.abs()
-    }
-}
-
 impl UnaryOpT for Ceil {
     const NAME: &'static str = "ceil";
     const KERNEL: &'static str = "uceil";
@@ -915,10 +887,6 @@ impl BackpropOp {
         };
         Self(op)
     }
-
-    pub(crate) fn is_none(&self) -> bool {
-        self.0.is_none()
-    }
 }
 
 impl std::ops::Deref for BackpropOp {
@@ -927,37 +895,3 @@ impl std::ops::Deref for BackpropOp {
         &self.0
     }
 }
-
-impl UnaryOpT for Sign {
-    const NAME: &'static str = "sign";
-    const KERNEL: &'static str = "usign";
-    const V: Self = Sign;
-    #[inline(always)]
-    fn bf16(v: bf16) -> bf16 {
-        bf16::from((v > bf16::ZERO) as i8) - bf16::from((v < bf16::ZERO) as i8)
-    }
-    #[inline(always)]
-    fn f16(v: f16) -> f16 {
-        f16::from((v > f16::ZERO) as i8) - f16::from((v < f16::ZERO) as i8)
-    }
-    #[inline(always)]
-    fn f32(v: f32) -> f32 {
-        f32::from(v > 0.) - f32::from(v < 0.)
-    }
-    #[inline(always)]
-    fn f64(v: f64) -> f64 {
-        f64::from(v > 0.) - f64::from(v < 0.)
-    }
-    #[inline(always)]
-    fn u8(v: u8) -> u8 {
-        u8::min(1, v)
-    }
-    #[inline(always)]
-    fn u32(v: u32) -> u32 {
-        u32::min(1, v)
-    }
-    #[inline(always)]
-    fn i64(v: i64) -> i64 {
-        (v > 0) as i64 - (v < 0) as i64
-    }
-}
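The gelu hunks above trade a runtime `(2.0 / PI).sqrt()` on one side for hardcoded `SQRT_TWO_OVER_PI` constants on the other; both feed the same tanh approximation. A quick numeric check that the two agree (a sketch, not candle's code):

```rust
// Quick check relating the two sides of the gelu diff above: the
// hardcoded constant is just sqrt(2/pi) evaluated ahead of time.
fn gelu_tanh(v: f32) -> f32 {
    const SQRT_TWO_OVER_PI: f32 = 0.797_884_6;
    0.5 * v * (1.0 + f32::tanh(SQRT_TWO_OVER_PI * v * (1.0 + 0.044715 * v * v)))
}

fn main() {
    let computed = (2.0f32 / std::f32::consts::PI).sqrt();
    assert!((computed - 0.797_884_6).abs() < 1e-6);
    // gelu(1.0) ~ 0.8412 under the tanh approximation
    println!("{}", gelu_tanh(1.0));
}
```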
@@ -42,7 +42,7 @@ pub enum OpCode {
     Stop = b'.',
     NewObj = 0x81,
     EmptyList = b']',
-    BinFloat = b'G',
+    BinFloat = b'g',
     Append = b'a',
     Appends = b'e',
 }
@@ -193,55 +193,6 @@ impl Object {
             _ => Err(self),
         }
     }
-
-    pub fn into_tensor_info(
-        self,
-        name: Self,
-        dir_name: &std::path::Path,
-    ) -> Result<Option<TensorInfo>> {
-        let name = match name.unicode() {
-            Ok(name) => name,
-            Err(_) => return Ok(None),
-        };
-        let (callable, args) = match self.reduce() {
-            Ok(callable_args) => callable_args,
-            _ => return Ok(None),
-        };
-        let (callable, args) = match callable {
-            Object::Class {
-                module_name,
-                class_name,
-            } if module_name == "torch._tensor" && class_name == "_rebuild_from_type_v2" => {
-                let mut args = args.tuple()?;
-                let callable = args.remove(0);
-                let args = args.remove(1);
-                (callable, args)
-            }
-            Object::Class {
-                module_name,
-                class_name,
-            } if module_name == "torch._utils" && class_name == "_rebuild_parameter" => {
-                let mut args = args.tuple()?;
-                args.remove(0).reduce()?
-            }
-            _ => (callable, args),
-        };
-        match callable {
-            Object::Class {
-                module_name,
-                class_name,
-            } if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {}
-            _ => return Ok(None),
-        };
-        let (layout, dtype, file_path, storage_size) = rebuild_args(args)?;
-        Ok(Some(TensorInfo {
-            name,
-            dtype,
-            layout,
-            path: format!("{}/{}", dir_name.to_string_lossy(), file_path),
-            storage_size,
-        }))
-    }
 }
 
 impl TryFrom<Object> for String {
@@ -350,10 +301,8 @@ impl Stack {
             module_name,
             class_name,
         } => {
-            if module_name == "collections"
-                && (class_name == "OrderedDict" || class_name == "defaultdict")
-            {
-                // TODO: have a separate ordered dict and a separate default dict.
+            if module_name == "collections" && class_name == "OrderedDict" {
+                // TODO: have a separate ordered dict.
                 Some(Object::Dict(vec![]))
             } else {
                 None
@@ -462,10 +411,7 @@ impl Stack {
                 self.push(Object::Int(arg))
             }
             OpCode::BinFloat => {
-                // Somehow floats are encoded using BigEndian whereas int types use LittleEndian.
-                // https://github.com/python/cpython/blob/0c80da4c14d904a367968955544dd6ae58c8101c/Lib/pickletools.py#L855
-                // https://github.com/pytorch/pytorch/blob/372d078f361e726bb4ac0884ac334b04c58179ef/torch/_weights_only_unpickler.py#L243
-                let arg = r.read_f64::<byteorder::BigEndian>()?;
+                let arg = r.read_f64::<LittleEndian>()?;
                 self.push(Object::Float(arg))
             }
             OpCode::BinUnicode => {
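The comments deleted in the `BinFloat` hunk record the subtle bit: pickle's `'G'` opcode stores its f64 payload big-endian, while the integer opcodes are little-endian. A small demonstration using the `byteorder` crate, which the diff itself relies on:

```rust
// Sketch of the endianness point above: pickle's 'G' opcode stores the
// f64 payload big-endian, unlike the little-endian integer opcodes.
use byteorder::{BigEndian, LittleEndian, ReadBytesExt};

fn main() -> std::io::Result<()> {
    let payload: [u8; 8] = 1.5f64.to_be_bytes();

    let mut r = &payload[..];
    assert_eq!(r.read_f64::<BigEndian>()?, 1.5);

    let mut r = &payload[..];
    // Reading the same bytes little-endian yields garbage.
    assert_ne!(r.read_f64::<LittleEndian>()?, 1.5);
    Ok(())
}
```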
@ -619,7 +565,6 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
|
|||||||
"HalfStorage" => DType::F16,
|
"HalfStorage" => DType::F16,
|
||||||
"BFloat16Storage" => DType::BF16,
|
"BFloat16Storage" => DType::BF16,
|
||||||
"ByteStorage" => DType::U8,
|
"ByteStorage" => DType::U8,
|
||||||
"LongStorage" => DType::I64,
|
|
||||||
other => {
|
other => {
|
||||||
crate::bail!("unsupported storage type {other}")
|
crate::bail!("unsupported storage type {other}")
|
||||||
}
|
}
|
||||||
@ -637,16 +582,9 @@ pub struct TensorInfo {
|
|||||||
pub storage_size: usize,
|
pub storage_size: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Read the tensor info from a .pth file.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
/// * `file` - The path to the .pth file.
|
|
||||||
/// * `verbose` - Whether to print debug information.
|
|
||||||
/// * `key` - Optional key to retrieve `state_dict` from the pth file.
|
|
||||||
pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
|
pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
|
||||||
file: P,
|
file: P,
|
||||||
verbose: bool,
|
verbose: bool,
|
||||||
key: Option<&str>,
|
|
||||||
) -> Result<Vec<TensorInfo>> {
|
) -> Result<Vec<TensorInfo>> {
|
||||||
let file = std::fs::File::open(file)?;
|
let file = std::fs::File::open(file)?;
|
||||||
let zip_reader = std::io::BufReader::new(file);
|
let zip_reader = std::io::BufReader::new(file);
|
||||||
@ -668,9 +606,8 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
|
|||||||
stack.read_loop(&mut reader)?;
|
stack.read_loop(&mut reader)?;
|
||||||
let obj = stack.finalize()?;
|
let obj = stack.finalize()?;
|
||||||
if VERBOSE || verbose {
|
if VERBOSE || verbose {
|
||||||
println!("{obj:#?}");
|
println!("{obj:?}");
|
||||||
}
|
}
|
||||||
|
|
||||||
let obj = match obj {
|
let obj = match obj {
|
||||||
Object::Build { callable, args } => match *callable {
|
Object::Build { callable, args } => match *callable {
|
||||||
Object::Reduce { callable, args: _ } => match *callable {
|
Object::Reduce { callable, args: _ } => match *callable {
|
||||||
@ -684,30 +621,52 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
|
|||||||
},
|
},
|
||||||
obj => obj,
|
obj => obj,
|
||||||
};
|
};
|
||||||
|
|
||||||
// If key is provided, then we need to extract the state_dict from the object.
|
|
||||||
let obj = if let Some(key) = key {
|
|
||||||
if let Object::Dict(key_values) = obj {
|
|
||||||
key_values
|
|
||||||
.into_iter()
|
|
||||||
.find(|(k, _)| *k == Object::Unicode(key.to_owned()))
|
|
||||||
.map(|(_, v)| v)
|
|
||||||
.ok_or_else(|| E::Msg(format!("key {key} not found")))?
|
|
||||||
} else {
|
|
||||||
obj
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
obj
|
|
||||||
};
|
|
||||||
|
|
||||||
// If the object is a dict, then we can extract the tensor info from it.
|
|
||||||
// NOTE: We are assuming that the `obj` is state_dict by this stage.
|
|
||||||
if let Object::Dict(key_values) = obj {
|
if let Object::Dict(key_values) = obj {
|
||||||
for (name, value) in key_values.into_iter() {
|
for (name, value) in key_values.into_iter() {
|
||||||
match value.into_tensor_info(name, &dir_name) {
|
let name = match name.unicode() {
|
||||||
Ok(Some(tensor_info)) => tensor_infos.push(tensor_info),
|
Ok(name) => name,
|
||||||
Ok(None) => {}
|
Err(_) => continue,
|
||||||
Err(err) => eprintln!("skipping: {err:?}"),
|
};
|
||||||
|
let (callable, args) = match value.reduce() {
|
||||||
|
Ok(callable_args) => callable_args,
|
||||||
|
_ => continue,
|
||||||
|
};
|
||||||
|
let (callable, args) = match callable {
|
||||||
|
Object::Class {
|
||||||
|
module_name,
|
||||||
|
class_name,
|
||||||
|
} if module_name == "torch._tensor"
|
||||||
|
&& class_name == "_rebuild_from_type_v2" =>
|
||||||
|
{
|
||||||
|
let mut args = args.tuple()?;
|
||||||
|
let callable = args.remove(0);
|
||||||
|
let args = args.remove(1);
|
||||||
|
(callable, args)
|
||||||
|
}
|
||||||
|
_ => (callable, args),
|
||||||
|
};
|
||||||
|
match callable {
|
||||||
|
Object::Class {
|
||||||
|
module_name,
|
||||||
|
class_name,
|
||||||
|
} if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {}
|
||||||
|
_ => continue,
|
||||||
|
};
|
||||||
|
match rebuild_args(args) {
|
||||||
|
Ok((layout, dtype, file_path, storage_size)) => {
|
||||||
|
let mut path = dir_name.clone();
|
||||||
|
path.push(file_path);
|
||||||
|
tensor_infos.push(TensorInfo {
|
||||||
|
name,
|
||||||
|
dtype,
|
||||||
|
layout,
|
||||||
|
path: path.to_string_lossy().into_owned(),
|
||||||
|
storage_size,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
eprintln!("skipping {name}: {err:?}")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -724,8 +683,8 @@ pub struct PthTensors {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl PthTensors {
|
impl PthTensors {
|
||||||
pub fn new<P: AsRef<std::path::Path>>(path: P, key: Option<&str>) -> Result<Self> {
|
pub fn new<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
|
||||||
let tensor_infos = read_pth_tensor_info(path.as_ref(), false, key)?;
|
let tensor_infos = read_pth_tensor_info(path.as_ref(), false)?;
|
||||||
let tensor_infos = tensor_infos
|
let tensor_infos = tensor_infos
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|ti| (ti.name.to_string(), ti))
|
.map(|ti| (ti.name.to_string(), ti))
|
||||||
@ -739,7 +698,6 @@ impl PthTensors {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
|
pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
|
||||||
use std::io::Read;
|
|
||||||
let tensor_info = match self.tensor_infos.get(name) {
|
let tensor_info = match self.tensor_infos.get(name) {
|
||||||
None => return Ok(None),
|
None => return Ok(None),
|
||||||
Some(tensor_info) => tensor_info,
|
Some(tensor_info) => tensor_info,
|
||||||
@@ -748,70 +706,20 @@ impl PthTensors {
         let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?);
         let mut zip = zip::ZipArchive::new(zip_reader)?;
         let mut reader = zip.by_name(&tensor_info.path)?;
-        let is_fortran_contiguous = tensor_info.layout.is_fortran_contiguous();
-        let rank = tensor_info.layout.shape().rank();
-
-        // Reading the data is a bit tricky as it can be strided, for now only support the basic
-        // case and when the tensor is fortran contiguous.
-        if !tensor_info.layout.is_contiguous() && !is_fortran_contiguous {
+        // Reading the data is a bit tricky as it can be strided, use an offset, etc.
+        // For now only support the basic case.
+        if tensor_info.layout.start_offset() != 0 || !tensor_info.layout.is_contiguous() {
             crate::bail!(
                 "cannot retrieve non-contiguous tensors {:?}",
                 tensor_info.layout
             )
         }
-        let start_offset = tensor_info.layout.start_offset();
-        if start_offset > 0 {
-            std::io::copy(
-                &mut reader.by_ref().take(start_offset as u64),
-                &mut std::io::sink(),
-            )?;
-        }
         let tensor = Tensor::from_reader(
             tensor_info.layout.shape().clone(),
             tensor_info.dtype,
             &mut reader,
         )?;
-        if rank > 1 && is_fortran_contiguous {
-            // Reverse the shape, e.g. Shape(2, 3, 4) -> Shape(4, 3, 2)
-            let shape_reversed: Vec<_> = tensor_info.layout.dims().iter().rev().cloned().collect();
-            let tensor = tensor.reshape(shape_reversed)?;
-
-            // Permute (transpose) the dimensions, e.g. Shape(4, 3, 2) -> Shape(2, 3, 4)
-            let dim_indeces_reversed: Vec<_> = (0..rank).rev().collect();
-            let tensor = tensor.permute(dim_indeces_reversed)?;
-            Ok(Some(tensor))
-        } else {
-            Ok(Some(tensor))
-        }
+        Ok(Some(tensor))
     }
 }
 
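The reverse-then-permute trick on the removed side is worth spelling out: column-major (Fortran) data for a shape (d1, ..., dk) is exactly row-major data for the reversed shape (dk, ..., d1), so reading it with the reversed shape and permuting the axes back recovers the tensor without a manual gather. A standalone sketch against candle's public API, assuming a CPU device:

use candle_core::{Device, Result, Tensor};

// Column-major storage of [[1, 2, 3], [4, 5, 6]] is [1, 4, 2, 5, 3, 6].
// Read it as the reversed shape (3, 2), then permute back to (2, 3).
fn main() -> Result<()> {
    let raw = [1f32, 4., 2., 5., 3., 6.];
    let t = Tensor::from_slice(&raw, (3, 2), &Device::Cpu)?;
    let t = t.permute((1, 0))?;
    assert_eq!(t.to_vec2::<f32>()?, vec![vec![1., 2., 3.], vec![4., 5., 6.]]);
    Ok(())
}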
-/// Read all the tensors from a PyTorch pth file with a given key.
-///
-/// # Arguments
-/// * `path` - Path to the pth file.
-/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
-///   contains multiple objects and the state_dict is the one we are interested in.
-pub fn read_all_with_key<P: AsRef<std::path::Path>>(
-    path: P,
-    key: Option<&str>,
-) -> Result<Vec<(String, Tensor)>> {
-    let pth = PthTensors::new(path, key)?;
-    let tensor_names = pth.tensor_infos.keys();
-    let mut tensors = Vec::with_capacity(tensor_names.len());
-    for name in tensor_names {
-        if let Some(tensor) = pth.get(name)? {
-            tensors.push((name.to_string(), tensor))
-        }
-    }
-    Ok(tensors)
-}
-
-/// Read all the tensors from a PyTorch pth file.
-///
-/// # Arguments
-/// * `path` - Path to the pth file.
-pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
-    read_all_with_key(path, None)
-}
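For reference, a hedged usage sketch of the `read_all_with_key` helper that only exists on the 0.5.1 side; the checkpoint file name is a placeholder:

use candle_core::pickle;

// Load only the `state_dict` entry from a pth file that pickles several objects.
fn main() -> candle_core::Result<()> {
    let tensors = pickle::read_all_with_key("model.pth", Some("state_dict"))?;
    for (name, tensor) in tensors.iter() {
        println!("{name}: {:?}", tensor.shape());
    }
    Ok(())
}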
@@ -50,9 +50,14 @@ pub(crate) unsafe fn mul_sum_i8_pairs_float(x: __m256i, y: __m256i) -> __m256 {
 #[inline(always)]
 pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
     let qk = QK8_0;
+    let nb = n / qk;
     if n % QK8_0 != 0 {
         crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
     }
+    if nb % 2 != 0 {
+        crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
+    }
+
     unsafe {
         let mut acc = _mm256_setzero_ps();
         for (x, y) in xs.iter().zip(ys.iter()) {
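The extra guard on the matmul-slo side also requires an even number of blocks. Worked through: with QK8_0 = 32, n = 256 gives nb = 8 and passes both checks, while n = 96 gives nb = 3 and bails.

// Standalone copy of the two guards above.
const QK8_0: usize = 32;

fn check(n: usize) -> Result<usize, String> {
    if n % QK8_0 != 0 {
        return Err(format!("{n} is not divisible by {QK8_0}"));
    }
    let nb = n / QK8_0;
    if nb % 2 != 0 {
        return Err(format!("{nb} is not even"));
    }
    Ok(nb)
}

fn main() {
    assert_eq!(check(256), Ok(8));
    assert!(check(96).is_err());
}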
@@ -353,7 +358,7 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
             q3 = q3.add(32);
 
             // Prepare low and high bits
-            // We hardcode the shifts here to avoid loading them into a separate register
+            // We hardcode the shifts here to avoid loading them into a seperate register
             let q3l_0 = _mm256_and_si256(q3bits, m3);
             let q3h_0 = if j == 0 {
                 _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 0)), 0)
@@ -586,7 +591,7 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res
             let q5bits = _mm256_loadu_si256(q5 as *const __m256i);
             q5 = q5.add(32);
 
-            //Similar to q3k we hardcode the shifts here to avoid loading them into a separate register
+            //Similar to q3k we hardcode the shifts here to avoid loading them into a seperate register
             let q5l_0 = _mm256_and_si256(q5bits, m4);
             let q5l_0_shift_input = _mm256_and_si256(hbits, hmask);
             let q5l_0_right_shift = match j {
@@ -1,680 +0,0 @@
-use super::{GgmlDType, QStorage};
-use crate::quantized::k_quants::GgmlType;
-use crate::{backend::BackendDevice, cuda_backend::WrapErr};
-use crate::{CudaDevice, CudaStorage, Result};
-use half::f16;
-
-use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};
-
-#[derive(Clone, Debug)]
-pub struct QCudaStorage {
-    data: CudaSlice<u8>,
-    dtype: GgmlDType,
-    device: CudaDevice,
-}
-
-static FORCE_DMMV: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
-
-pub fn set_force_dmmv(f: bool) {
-    FORCE_DMMV.store(f, std::sync::atomic::Ordering::Relaxed)
-}
-
-pub const WARP_SIZE: usize = 32;
-pub const MMQ_X_Q4_0_AMPERE: usize = 4;
-pub const MMQ_Y_Q4_0_AMPERE: usize = 32;
-pub const NWARPS_Q4_0_AMPERE: usize = 4;
-pub const GGML_CUDA_MMV_X: usize = 32;
-pub const GGML_CUDA_MMV_Y: usize = 1;
-pub const CUDA_QUANTIZE_BLOCK_SIZE: usize = 256;
-pub const CUDA_DEQUANTIZE_BLOCK_SIZE: usize = 256;
-pub const MATRIX_ROW_PADDING: usize = 512;
-
-fn ceil_div(p: usize, q: usize) -> usize {
-    (p + q - 1) / q
-}
-
-fn pad(p: usize, q: usize) -> usize {
-    ceil_div(p, q) * q
-}
-
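The two helpers just above do rounded-up division and round-up-to-a-multiple; a quick numeric check of what they produce for the MATRIX_ROW_PADDING of 512:

// Standalone copies of ceil_div/pad with a couple of sanity checks.
fn ceil_div(p: usize, q: usize) -> usize {
    (p + q - 1) / q
}

fn pad(p: usize, q: usize) -> usize {
    ceil_div(p, q) * q
}

fn main() {
    assert_eq!(pad(4096, 512), 4096); // already aligned, no-op
    assert_eq!(pad(4100, 512), 4608); // rounds up to the next multiple
}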
-fn quantize_q8_1(
-    src: &CudaView<f32>,
-    dst: &mut CudaSlice<u8>,
-    elem_count: usize,
-    ky: usize,
-    dev: &CudaDevice,
-) -> Result<()> {
-    use cudarc::driver::LaunchAsync;
-
-    let kx = elem_count;
-    let kx_padded = pad(kx, MATRIX_ROW_PADDING);
-    let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE);
-    let func = dev.get_or_load_func("quantize_q8_1", candle_kernels::QUANTIZED)?;
-    let cfg = cudarc::driver::LaunchConfig {
-        grid_dim: (num_blocks as u32, ky as u32, 1),
-        block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
-        shared_mem_bytes: 0,
-    };
-    let params = (src, dst, kx as i32, kx_padded as i32);
-    unsafe { func.launch(cfg, params) }.w()?;
-    Ok(())
-}
-
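Grid sizing for the quantize launcher above, worked through: a 4096-wide row padded to MATRIX_ROW_PADDING stays at 4096, and with CUDA_QUANTIZE_BLOCK_SIZE = 256 that is 16 blocks along x, times `ky` rows along y.

// The same arithmetic as in quantize_q8_1, with the constants inlined.
fn main() {
    let (kx, ky) = (4096usize, 3usize);
    let kx_padded = ((kx + 512 - 1) / 512) * 512; // pad(kx, MATRIX_ROW_PADDING)
    let num_blocks = (kx_padded + 256 - 1) / 256; // ceil_div(.., CUDA_QUANTIZE_BLOCK_SIZE)
    assert_eq!((num_blocks, ky), (16, 3)); // grid_dim = (16, 3, 1)
}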
-fn dequantize_f32(
-    data: &CudaSlice<u8>,
-    dtype: GgmlDType,
-    elem_count: usize,
-    dev: &CudaDevice,
-) -> Result<CudaStorage> {
-    use cudarc::driver::LaunchAsync;
-
-    let nb = (elem_count + 255) / 256;
-    let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
-        GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb),
-        GgmlDType::Q4_1 => ("dequantize_block_q4_1_f32", false, 32, nb),
-        GgmlDType::Q5_0 => (
-            "dequantize_block_q5_0_f32",
-            false,
-            CUDA_DEQUANTIZE_BLOCK_SIZE,
-            ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
-        ),
-        GgmlDType::Q5_1 => (
-            "dequantize_block_q5_1_f32",
-            false,
-            CUDA_DEQUANTIZE_BLOCK_SIZE,
-            ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
-        ),
-        GgmlDType::Q8_0 => ("dequantize_block_q8_0_f32", false, 32, nb),
-        GgmlDType::Q2K => ("dequantize_block_q2_K_f32", true, 64, nb),
-        GgmlDType::Q3K => ("dequantize_block_q3_K_f32", true, 64, nb),
-        GgmlDType::Q4K => ("dequantize_block_q4_K_f32", true, 32, nb),
-        GgmlDType::Q5K => ("dequantize_block_q5_K_f32", true, 64, nb),
-        GgmlDType::Q6K => ("dequantize_block_q6_K_f32", true, 64, nb),
-        GgmlDType::Q8K => ("dequantize_block_q8_K_f32", true, 32, nb),
-        _ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
-    };
-    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
-    let dst = unsafe { dev.alloc::<f32>(elem_count).w()? };
-    // See e.g.
-    // https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
-    let cfg = cudarc::driver::LaunchConfig {
-        grid_dim: (num_blocks as u32, 1, 1),
-        block_dim: (block_dim as u32, 1, 1),
-        shared_mem_bytes: 0,
-    };
-
-    if is_k {
-        let params = (data, &dst);
-        unsafe { func.launch(cfg, params) }.w()?;
-    } else {
-        let nb32 = match dtype {
-            GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
-            _ => elem_count / 32,
-        };
-        let params = (data, &dst, nb32 as i32);
-        unsafe { func.launch(cfg, params) }.w()?;
-    }
-    Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
-}
-
-fn dequantize_f16(
-    data: &CudaSlice<u8>,
-    dtype: GgmlDType,
-    elem_count: usize,
-    dev: &CudaDevice,
-) -> Result<CudaStorage> {
-    use cudarc::driver::LaunchAsync;
-
-    let nb = (elem_count + 255) / 256;
-    let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
-        GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb),
-        GgmlDType::Q4_1 => ("dequantize_block_q4_1_f16", false, 32, nb),
-        GgmlDType::Q5_0 => (
-            "dequantize_block_q5_0_f16",
-            false,
-            CUDA_DEQUANTIZE_BLOCK_SIZE,
-            ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
-        ),
-        GgmlDType::Q5_1 => (
-            "dequantize_block_q5_1_f16",
-            false,
-            CUDA_DEQUANTIZE_BLOCK_SIZE,
-            ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
-        ),
-        GgmlDType::Q8_0 => ("dequantize_block_q8_0_f16", false, 32, nb),
-        GgmlDType::Q2K => ("dequantize_block_q2_K_f16", true, 64, nb),
-        GgmlDType::Q3K => ("dequantize_block_q3_K_f16", true, 64, nb),
-        GgmlDType::Q4K => ("dequantize_block_q4_K_f16", true, 32, nb),
-        GgmlDType::Q5K => ("dequantize_block_q5_K_f16", true, 64, nb),
-        GgmlDType::Q6K => ("dequantize_block_q6_K_f16", true, 64, nb),
-        GgmlDType::Q8K => ("dequantize_block_q8_K_f16", true, 32, nb),
-        _ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
-    };
-    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
-    let dst = unsafe { dev.alloc::<f16>(elem_count).w()? };
-    // See e.g.
-    // https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
-    let cfg = cudarc::driver::LaunchConfig {
-        grid_dim: (num_blocks as u32, 1, 1),
-        block_dim: (block_dim as u32, 1, 1),
-        shared_mem_bytes: 0,
-    };
-
-    if is_k {
-        let params = (data, &dst);
-        unsafe { func.launch(cfg, params) }.w()?;
-    } else {
-        let nb32 = match dtype {
-            GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
-            _ => elem_count / 32,
-        };
-        let params = (data, &dst, nb32 as i32);
-        unsafe { func.launch(cfg, params) }.w()?;
-    }
-    Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
-}
-
-fn dequantize_mul_mat_vec(
-    data: &CudaSlice<u8>,
-    y: &CudaView<f32>,
-    dtype: GgmlDType,
-    ncols: usize,
-    nrows: usize,
-    dev: &CudaDevice,
-) -> Result<CudaStorage> {
-    use cudarc::driver::LaunchAsync;
-
-    let data_elems = data.len() / dtype.type_size() * dtype.block_size();
-    if data_elems < ncols * nrows {
-        crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
-    }
-    if y.len() != ncols {
-        crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
-    }
-    let kernel_name = match dtype {
-        GgmlDType::Q4_0 => "dequantize_mul_mat_vec_q4_0_cuda",
-        GgmlDType::Q4_1 => "dequantize_mul_mat_vec_q4_1_cuda",
-        GgmlDType::Q5_0 => "dequantize_mul_mat_vec_q5_0_cuda",
-        GgmlDType::Q5_1 => "dequantize_mul_mat_vec_q5_1_cuda",
-        GgmlDType::Q8_0 => "dequantize_mul_mat_vec_q8_0_cuda",
-        GgmlDType::Q2K => "dequantize_mul_mat_vec_q2_k",
-        GgmlDType::Q3K => "dequantize_mul_mat_vec_q3_k",
-        GgmlDType::Q4K => "dequantize_mul_mat_vec_q4_k",
-        GgmlDType::Q5K => "dequantize_mul_mat_vec_q5_k",
-        GgmlDType::Q6K => "dequantize_mul_mat_vec_q6_k",
-        _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
-    };
-    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
-    let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
-    let block_num_y = ceil_div(nrows, GGML_CUDA_MMV_Y);
-    let cfg = cudarc::driver::LaunchConfig {
-        grid_dim: (block_num_y as u32, 1, 1),
-        block_dim: (WARP_SIZE as u32, GGML_CUDA_MMV_Y as u32, 1),
-        shared_mem_bytes: 0,
-    };
-
-    let params = (data, y, &dst, ncols as i32, nrows as i32);
-    unsafe { func.launch(cfg, params) }.w()?;
-    Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
-}
-
-fn mul_mat_vec_via_q8_1(
-    data: &CudaSlice<u8>,
-    y: &CudaView<f32>,
-    dtype: GgmlDType,
-    ncols: usize,
-    nrows: usize,
-    b_size: usize,
-    dev: &CudaDevice,
-) -> Result<CudaStorage> {
-    use cudarc::driver::LaunchAsync;
-
-    let data_elems = data.len() / dtype.type_size() * dtype.block_size();
-    if data_elems < ncols * nrows {
-        crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
-    }
-    if y.len() != ncols * b_size {
-        crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
-    }
-    if b_size == 0 || b_size > 8 {
-        crate::bail!("only bsize between 1 and 8 are supported, got {b_size}")
-    }
-    // Start by quantizing y
-    let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
-    let y_size_in_bytes =
-        b_size * ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
-    let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
-    quantize_q8_1(y, &mut y_q8_1, ncols, b_size, dev)?;
-
-    let kernel_name = match dtype {
-        GgmlDType::Q4_0 => "mul_mat_vec_q4_0_q8_1_cuda",
-        GgmlDType::Q4_1 => "mul_mat_vec_q4_1_q8_1_cuda",
-        GgmlDType::Q5_0 => "mul_mat_vec_q5_0_q8_1_cuda",
-        GgmlDType::Q5_1 => "mul_mat_vec_q5_1_q8_1_cuda",
-        GgmlDType::Q8_0 => "mul_mat_vec_q8_0_q8_1_cuda",
-        GgmlDType::Q2K => "mul_mat_vec_q2_K_q8_1_cuda",
-        GgmlDType::Q3K => "mul_mat_vec_q3_K_q8_1_cuda",
-        GgmlDType::Q4K => "mul_mat_vec_q4_K_q8_1_cuda",
-        GgmlDType::Q5K => "mul_mat_vec_q5_K_q8_1_cuda",
-        GgmlDType::Q6K => "mul_mat_vec_q6_K_q8_1_cuda",
-        _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
-    };
-    let kernel_name = format!("{kernel_name}{b_size}");
-    let func = dev.get_or_load_func(&kernel_name, candle_kernels::QUANTIZED)?;
-    let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
-    // https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
-    let (nblocks, nwarps) = match b_size {
-        1 => (nrows as u32, 4),
-        2..=4 => ((nrows as u32 + 1) / 2, 4),
-        5..=8 => ((nrows as u32 + 1) / 2, 2),
-        _ => crate::bail!("unexpected bsize {b_size}"),
-    };
-    let cfg = cudarc::driver::LaunchConfig {
-        grid_dim: (nblocks, 1, 1),
-        block_dim: (WARP_SIZE as u32, nwarps, 1),
-        shared_mem_bytes: 0,
-    };
-
-    let params = (
-        data,
-        &y_q8_1,
-        &dst,
-        /* ncols_x */ ncols as i32,
-        /* nrows_x */ nrows as i32,
-        /* nrows_y */ ncols_padded as i32,
-        /* nrows_dst */ nrows as i32,
-    );
-    unsafe { func.launch(cfg, params) }.w()?;
-    Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
-}
-
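The y-buffer size above follows the Q8_1 block layout: 32 quantized i8 values plus two f16 scalars per block, i.e. 36 bytes per 32 elements. Worked through for ncols = 4096 padded to 4608 and b_size = 4:

// Q8_1 sizing math with type_size/block_size inlined (36 and 32 in this layout).
fn main() {
    let (type_size, block_size) = (36usize, 32usize);
    let (ncols_padded, b_size) = (4608usize, 4usize);
    let y_size_in_bytes = b_size * ncols_padded * type_size / block_size;
    assert_eq!(y_size_in_bytes, 20_736);
}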
-#[allow(clippy::too_many_arguments)]
-fn mul_mat_via_q8_1(
-    data: &CudaSlice<u8>,
-    y: &CudaView<f32>,
-    dtype: GgmlDType,
-    x_rows: usize,
-    x_cols: usize,
-    y_rows: usize,
-    y_cols: usize,
-    dev: &CudaDevice,
-) -> Result<CudaStorage> {
-    use cudarc::driver::LaunchAsync;
-
-    let data_elems = data.len() / dtype.type_size() * dtype.block_size();
-    if data_elems < x_rows * x_cols {
-        crate::bail!("unexpected lhs size {}, {x_rows} {x_cols}", data_elems)
-    }
-    if y.len() != y_rows * y_cols {
-        crate::bail!("unexpected y size {}, {y_rows} {y_cols}", y.len())
-    }
-    if x_cols != y_rows {
-        crate::bail!("unexpected x/y size {x_rows} {x_cols} {y_rows} {y_cols}")
-    }
-    let k = x_cols;
-    // Start by quantizing y
-    let k_padded = pad(k, MATRIX_ROW_PADDING);
-    let y_size_in_bytes =
-        k_padded * y_rows * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
-    let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
-    quantize_q8_1(y, &mut y_q8_1, k, y_cols, dev)?;
-
-    let (kernel_name, mmq_x, mmq_y) = match dtype {
-        GgmlDType::Q4_0 => ("mul_mat_q4_0", 64, 128),
-        GgmlDType::Q4_1 => ("mul_mat_q4_1", 64, 128),
-        GgmlDType::Q5_0 => ("mul_mat_q5_0", 128, 64),
-        GgmlDType::Q5_1 => ("mul_mat_q5_1", 128, 64),
-        GgmlDType::Q8_0 => ("mul_mat_q8_0", 128, 64),
-        GgmlDType::Q2K => ("mul_mat_q2_K", 64, 128),
-        GgmlDType::Q3K => ("mul_mat_q3_K", 128, 128),
-        GgmlDType::Q4K => ("mul_mat_q4_K", 64, 128),
-        GgmlDType::Q5K => ("mul_mat_q5_K", 64, 128),
-        GgmlDType::Q6K => ("mul_mat_q6_K", 64, 64),
-        _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
-    };
-    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
-    let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols).w()? };
-    let cfg = cudarc::driver::LaunchConfig {
-        grid_dim: (
-            ceil_div(x_rows, mmq_y) as u32,
-            ceil_div(y_cols, mmq_x) as u32,
-            1,
-        ),
-        block_dim: (WARP_SIZE as u32, 4, 1),
-        shared_mem_bytes: 0,
-    };
-
-    let params = (
-        /* vx */ data,
-        /* vy */ &y_q8_1,
-        /* dst */ &dst,
-        /* ncols_x */ x_cols as i32,
-        /* nrows_x */ x_rows as i32,
-        /* ncols_y */ y_cols as i32,
-        /* nrows_y */ k_padded as i32,
-        /* nrows_dst */ x_rows as i32,
-    );
-    unsafe { func.launch(cfg, params) }.w()?;
-    Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
-}
-
-impl QCudaStorage {
-    pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> {
-        let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size();
-        let data = device.alloc_zeros::<u8>(size_in_bytes).w()?;
-        Ok(QCudaStorage {
-            data,
-            device: device.clone(),
-            dtype,
-        })
-    }
-
-    pub fn dtype(&self) -> GgmlDType {
-        self.dtype
-    }
-
-    pub fn device(&self) -> &CudaDevice {
-        &self.device
-    }
-
-    pub fn dequantize(&self, elem_count: usize) -> Result<CudaStorage> {
-        fn deq<T: GgmlType>(buffer: &[u8], n: usize, dst: &mut [f32]) -> Result<()> {
-            let slice = unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const T, n) };
-            let vec = slice.to_vec();
-            T::to_float(&vec, dst)
-        }
-
-        let fast_kernel = matches!(
-            self.dtype,
-            GgmlDType::Q4_0
-                | GgmlDType::Q4_1
-                | GgmlDType::Q5_0
-                | GgmlDType::Q5_1
-                | GgmlDType::Q8_0
-                | GgmlDType::Q2K
-                | GgmlDType::Q3K
-                | GgmlDType::Q4K
-                | GgmlDType::Q5K
-                | GgmlDType::Q6K
-                | GgmlDType::Q8K
-        );
-        if fast_kernel {
-            return dequantize_f32(&self.data, self.dtype, elem_count, self.device());
-        }
-        // Run the dequantization on cpu.
-
-        let buffer = self.device.dtoh_sync_copy(&self.data).w()?;
-        let mut out = vec![0.0; elem_count];
-        let block_len = elem_count / self.dtype.block_size();
-        match self.dtype {
-            GgmlDType::F32 => deq::<f32>(&buffer, block_len, &mut out)?,
-            GgmlDType::F16 => deq::<half::f16>(&buffer, block_len, &mut out)?,
-            GgmlDType::Q4_0 => deq::<crate::quantized::BlockQ4_0>(&buffer, block_len, &mut out)?,
-            GgmlDType::Q4_1 => deq::<crate::quantized::BlockQ4_1>(&buffer, block_len, &mut out)?,
-            GgmlDType::Q5_0 => deq::<crate::quantized::BlockQ5_0>(&buffer, block_len, &mut out)?,
-            GgmlDType::Q5_1 => deq::<crate::quantized::BlockQ5_1>(&buffer, block_len, &mut out)?,
-            GgmlDType::Q8_0 => deq::<crate::quantized::BlockQ8_0>(&buffer, block_len, &mut out)?,
-            GgmlDType::Q8_1 => deq::<crate::quantized::BlockQ8_1>(&buffer, block_len, &mut out)?,
-            GgmlDType::Q2K => deq::<crate::quantized::BlockQ2K>(&buffer, block_len, &mut out)?,
-            GgmlDType::Q3K => deq::<crate::quantized::BlockQ3K>(&buffer, block_len, &mut out)?,
-            GgmlDType::Q4K => deq::<crate::quantized::BlockQ4K>(&buffer, block_len, &mut out)?,
-            GgmlDType::Q5K => deq::<crate::quantized::BlockQ5K>(&buffer, block_len, &mut out)?,
-            GgmlDType::Q6K => deq::<crate::quantized::BlockQ6K>(&buffer, block_len, &mut out)?,
-            GgmlDType::Q8K => deq::<crate::quantized::BlockQ8K>(&buffer, block_len, &mut out)?,
-        }
-
-        self.device
-            .storage_from_cpu_storage(&crate::CpuStorage::F32(out))
-    }
-
-    pub fn dequantize_f16(&self, elem_count: usize) -> Result<CudaStorage> {
-        dequantize_f16(&self.data, self.dtype, elem_count, self.device())
-    }
-
-    pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
-        // Run the quantization on cpu.
-        let src = match &src.slice {
-            crate::cuda_backend::CudaStorageSlice::F32(data) => {
-                self.device.dtoh_sync_copy(data).w()?
-            }
-            _ => crate::bail!("only f32 can be quantized"),
-        };
-        let src_len = src.len();
-        let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
-        let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?;
-        qcpu_storage.quantize(&src)?;
-        let data = qcpu_storage.data()?;
-        let data = self.device.htod_sync_copy(data.as_ref()).w()?;
-        self.data = data;
-        Ok(())
-    }
-
-    pub fn storage_size_in_bytes(&self) -> usize {
-        self.data.len()
-    }
-
-    pub fn fwd(
-        &self,
-        self_shape: &crate::Shape,
-        storage: &CudaStorage,
-        layout: &crate::Layout,
-    ) -> Result<(CudaStorage, crate::Shape)> {
-        let max_bm = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
-            1
-        } else {
-            8
-        };
-        let use_vec_kernel = match layout.shape().dims() {
-            [b, m, _k] => b * m <= max_bm,
-            [b, _k] => *b <= max_bm,
-            _ => false,
-        };
-        if use_vec_kernel {
-            self.dequantize_matmul_vec(self_shape, storage, layout)
-        } else {
-            self.dequantize_matmul(self_shape, storage, layout)
-        }
-    }
-}
-
-impl QCudaStorage {
-    fn dequantize_matmul_vec(
-        &self,
-        self_shape: &crate::Shape,
-        rhs: &CudaStorage,
-        rhs_l: &crate::Layout,
-    ) -> Result<(CudaStorage, crate::Shape)> {
-        let (nrows, ncols) = self_shape.dims2()?;
-        let rhs = rhs.as_cuda_slice::<f32>()?;
-        let rhs = match rhs_l.contiguous_offsets() {
-            Some((o1, o2)) => rhs.slice(o1..o2),
-            None => Err(crate::Error::RequiresContiguous { op: "dmmv" }.bt())?,
-        };
-        let (b_size, k) = match rhs_l.shape().dims() {
-            [b, m, k] => (b * m, *k),
-            [b, k] => (*b, *k),
-            _ => crate::bail!("unexpected rhs shape in dmmv {:?}", rhs_l.shape()),
-        };
-        if ncols != k {
-            crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape())
-        }
-
-        let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
-            dequantize_mul_mat_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
-        } else {
-            mul_mat_vec_via_q8_1(
-                &self.data,
-                &rhs,
-                self.dtype,
-                ncols,
-                nrows,
-                b_size,
-                self.device(),
-            )?
-        };
-        let mut out_shape = rhs_l.shape().dims().to_vec();
-        out_shape.pop();
-        out_shape.push(nrows);
-        Ok((out, out_shape.into()))
-    }
-
-    fn dequantize_matmul(
-        &self,
-        self_shape: &crate::Shape,
-        storage: &CudaStorage,
-        layout: &crate::Layout,
-    ) -> Result<(CudaStorage, crate::Shape)> {
-        use crate::backend::BackendStorage;
-        let (n, k) = self_shape.dims2()?;
-        let (b, m, k2) = match layout.shape().dims() {
-            &[b, m, k2] => (b, m, k2),
-            &[m, k2] => (1, m, k2),
-            s => crate::bail!("unexpected shape for input {s:?}"),
-        };
-        if k2 != k {
-            crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", layout.shape())
-        }
-
-        let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
-            let data_f32 = self.dequantize(n * k)?;
-            let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?;
-            storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?
-        } else {
-            let storage = storage.as_cuda_slice::<f32>()?;
-            let storage = match layout.contiguous_offsets() {
-                Some((o1, o2)) => storage.slice(o1..o2),
-                None => Err(crate::Error::RequiresContiguous {
-                    op: "quantized-matmul",
-                }
-                .bt())?,
-            };
-            mul_mat_via_q8_1(
-                &self.data,
-                &storage,
-                self.dtype,
-                /* x_rows */ n,
-                /* x_cols */ k,
-                /* y_rows */ k,
-                /* y_cols */ b * m,
-                self.device(),
-            )?
-        };
-        let mut out_shape = layout.shape().dims().to_vec();
-        out_shape.pop();
-        out_shape.push(n);
-        Ok((out, out_shape.into()))
-    }
-}
-
-pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
-    device: &CudaDevice,
-    data: &[T],
-) -> Result<super::QStorage> {
-    let data = unsafe {
-        std::slice::from_raw_parts(data.as_ptr() as *const u8, core::mem::size_of_val(data))
-    };
-    let data = device.htod_sync_copy(data).w()?;
-    Ok(QStorage::Cuda(QCudaStorage {
-        data,
-        device: device.clone(),
-        dtype: T::DTYPE,
-    }))
-}
-
-#[cfg(test)]
-mod test {
-    use super::*;
-
-    #[test]
-    fn cuda_quantize_q8_1() -> Result<()> {
-        let dev = CudaDevice::new(0)?;
-        let el = 256;
-        let el_padded = pad(el, MATRIX_ROW_PADDING);
-        let y_size_in_bytes =
-            el_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
-        let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
-        let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
-        let y = dev.htod_sync_copy(&vs).w()?;
-        quantize_q8_1(&y.slice(..), &mut y_q8_1, el, 1, &dev)?;
-        Ok(())
-    }
-
-    #[test]
-    fn cuda_mmv_q8_1() -> Result<()> {
-        let dev = CudaDevice::new(0)?;
-        let ncols = 256;
-        let vs: Vec<f32> = (0..ncols).map(|v| v as f32).collect();
-        let y = dev.htod_sync_copy(&vs).w()?;
-        let mut xs = QCudaStorage::zeros(&dev, ncols, GgmlDType::Q4_0)?;
-        xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
-        let cuda_storage = mul_mat_vec_via_q8_1(
-            &xs.data,
-            &y.slice(..),
-            /* dtype */ GgmlDType::Q4_0,
-            /* ncols */ ncols,
-            /* nrows */ 1,
-            /* b_size */ 1,
-            &dev,
-        )?;
-        let vs = cuda_storage.as_cuda_slice::<f32>()?;
-        let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
-        assert_eq!(vs.len(), 1);
-        // for n = 255, n.(n+1).(2n+1) / 6 = 5559680
-        // Q8 means 1/256 precision.
-        assert_eq!(vs[0], 5561664.5);
-
-        let cuda_storage = dequantize_mul_mat_vec(
-            &xs.data,
-            &y.slice(..),
-            /* dtype */ GgmlDType::Q4_0,
-            /* ncols */ ncols,
-            /* nrows */ 1,
-            &dev,
-        )?;
-        let vs = cuda_storage.as_cuda_slice::<f32>()?;
-        let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
-        assert_eq!(vs.len(), 1);
-        assert_eq!(vs[0], 5561851.0);
-        Ok(())
-    }
-
-    #[test]
-    fn cuda_mm_q8_1() -> Result<()> {
-        let dev = CudaDevice::new(0)?;
-        let ncols = 256;
-        let vs: Vec<f32> = (0..ncols * 4).map(|v| v as f32 / 4.).collect();
-        let y = dev.htod_sync_copy(&vs).w()?;
-        let mut xs = QCudaStorage::zeros(&dev, ncols * 4, GgmlDType::Q4_0)?;
-        xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
-        let cuda_storage = mul_mat_via_q8_1(
-            &xs.data,
-            &y.slice(..),
-            /* dtype */ GgmlDType::Q4_0,
-            /* x_rows */ 4,
-            /* x_cols */ ncols,
-            /* y_rows */ ncols,
-            /* y_cols */ 4,
-            &dev,
-        )?;
-        let vs = cuda_storage.as_cuda_slice::<f32>()?;
-        let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
-
-        /*
-        x = torch.tensor([float(v) for v in range(1024)]).reshape(4, 256)
-        x @ x.t() / 16
-        tensor([[ 347480.0000,  869720.0000, 1391960.0000, 1914200.0000],
-                [ 869720.0000, 2440536.0000, 4011352.0000, 5582166.5000],
-                [1391960.0000, 4011352.0000, 6630742.0000, 9250132.0000],
-                [1914200.0000, 5582166.5000, 9250132.0000, 12918099.0000]])
-        */
-        assert_eq!(vs.len(), 16);
-        assert_eq!(vs[0], 347604.0);
-        assert_eq!(vs[1], 888153.06);
-        assert_eq!(vs[4], 869780.7);
-        assert_eq!(vs[5], 2483145.0);
-        assert_eq!(vs[11], 9407368.0);
-        assert_eq!(vs[14], 9470856.0);
-        assert_eq!(vs[15], 13138824.0);
-        Ok(())
-    }
-}
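One behavioural knob worth calling out in the deleted module: `fwd` routes small batches to a vector kernel, and the FORCE_DMMV atomic selects the older dequantize-mul-mat-vec path over the q8_1-based one. A hedged sketch of opting into the old kernels at startup, assuming the `cuda` module is public as it is in 0.5.1:

// Flip the toggle once, before any quantized matmul runs.
fn main() {
    #[cfg(feature = "cuda")]
    candle_core::quantized::cuda::set_force_dmmv(true);
}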
@@ -1,54 +0,0 @@
-#![allow(unused)]
-use super::GgmlDType;
-use crate::{CudaDevice, CudaStorage, Error, Result};
-
-pub struct QCudaStorage {
-    dtype: GgmlDType,
-    device: CudaDevice,
-}
-
-impl QCudaStorage {
-    pub fn zeros(_: &CudaDevice, _: usize, _: GgmlDType) -> Result<Self> {
-        Err(Error::NotCompiledWithCudaSupport)
-    }
-
-    pub fn dtype(&self) -> GgmlDType {
-        self.dtype
-    }
-
-    pub fn device(&self) -> &CudaDevice {
-        &self.device
-    }
-
-    pub fn dequantize(&self, _elem_count: usize) -> Result<CudaStorage> {
-        Err(Error::NotCompiledWithCudaSupport)
-    }
-
-    pub fn dequantize_f16(&self, _elem_count: usize) -> Result<CudaStorage> {
-        Err(Error::NotCompiledWithCudaSupport)
-    }
-
-    pub fn quantize(&mut self, _src: &CudaStorage) -> Result<()> {
-        Err(Error::NotCompiledWithCudaSupport)
-    }
-
-    pub fn storage_size_in_bytes(&self) -> usize {
-        0
-    }
-
-    pub fn fwd(
-        &self,
-        _self_shape: &crate::Shape,
-        _storage: &CudaStorage,
-        _layout: &crate::Layout,
-    ) -> Result<(CudaStorage, crate::Shape)> {
-        Err(Error::NotCompiledWithCudaSupport)
-    }
-}
-
-pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
-    _device: &CudaDevice,
-    _data: &[T],
-) -> Result<super::QStorage> {
-    Err(Error::NotCompiledWithCudaSupport)
-}
@@ -1,50 +0,0 @@
-#![allow(unused)]
-use super::GgmlDType;
-use crate::{Error, MetalDevice, MetalStorage, Result};
-
-pub struct QMetalStorage {
-    dtype: GgmlDType,
-    device: MetalDevice,
-}
-
-impl QMetalStorage {
-    pub fn zeros(_: &MetalDevice, _: usize, _: GgmlDType) -> Result<Self> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
-    pub fn dtype(&self) -> GgmlDType {
-        self.dtype
-    }
-
-    pub fn device(&self) -> &MetalDevice {
-        &self.device
-    }
-
-    pub fn dequantize(&self, _elem_count: usize) -> Result<MetalStorage> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
-    pub fn quantize(&mut self, _src: &MetalStorage) -> Result<()> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
-    pub fn storage_size_in_bytes(&self) -> usize {
-        0
-    }
-
-    pub fn fwd(
-        &self,
-        _self_shape: &crate::Shape,
-        _storage: &MetalStorage,
-        _layout: &crate::Layout,
-    ) -> Result<(MetalStorage, crate::Shape)> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-}
-
-pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
-    _device: &MetalDevice,
-    _data: &[T],
-) -> Result<super::QStorage> {
-    Err(Error::NotCompiledWithMetalSupport)
-}
@@ -1,7 +1,7 @@
 //! Support for the GGML file format.
 
-use super::{k_quants, GgmlDType, QStorage};
-use crate::{Device, Result};
+use super::{k_quants, GgmlDType};
+use crate::Result;
 use byteorder::{LittleEndian, ReadBytesExt};
 use std::collections::HashMap;
 
@@ -121,17 +121,11 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
     raw_data: &[u8],
     size_in_bytes: usize,
     dims: Vec<usize>,
-    device: &Device,
 ) -> Result<super::QTensor> {
     let raw_data_ptr = raw_data.as_ptr();
     let n_blocks = size_in_bytes / std::mem::size_of::<T>();
     let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
-    let data: QStorage = match device {
-        Device::Cpu => QStorage::Cpu(Box::new(data.to_vec())),
-        Device::Metal(metal) => super::metal::load_quantized(metal, data)?,
-        Device::Cuda(cuda) => super::cuda::load_quantized(cuda, data)?,
-    };
-    super::QTensor::new(data, dims)
+    super::QTensor::new(data.to_vec(), dims)
 }
 
 /// Creates a [Tensor] from a raw GGML tensor.
@@ -139,50 +133,29 @@ pub fn qtensor_from_ggml(
     ggml_dtype: GgmlDType,
     raw_data: &[u8],
     dims: Vec<usize>,
-    device: &Device,
 ) -> Result<super::QTensor> {
     let tensor_elems = dims.iter().product::<usize>();
-    let block_size = ggml_dtype.block_size();
-    if tensor_elems % block_size != 0 {
+    let blck_size = ggml_dtype.blck_size();
+    if tensor_elems % blck_size != 0 {
         crate::bail!(
-            "the number of elements {tensor_elems} is not divisible by the block size {block_size}"
+            "the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
         )
     }
-    let size_in_bytes = tensor_elems / block_size * ggml_dtype.type_size();
+    let size_in_bytes = tensor_elems / blck_size * ggml_dtype.type_size();
 
     match ggml_dtype {
-        GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims, device),
-        GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims, device),
-        GgmlDType::Q4_0 => {
-            from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device)
-        }
-        GgmlDType::Q4_1 => {
-            from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims, device)
-        }
-        GgmlDType::Q5_0 => {
-            from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims, device)
-        }
-        GgmlDType::Q5_1 => {
-            from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims, device)
-        }
-        GgmlDType::Q8_0 => {
-            from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims, device)
-        }
-        GgmlDType::Q2K => {
-            from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims, device)
-        }
-        GgmlDType::Q3K => {
-            from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims, device)
-        }
-        GgmlDType::Q4K => {
-            from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims, device)
-        }
-        GgmlDType::Q5K => {
-            from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims, device)
-        }
-        GgmlDType::Q6K => {
-            from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims, device)
-        }
+        GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),
+        GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q4_0 => from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q4_1 => from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q5_0 => from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q5_1 => from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q8_0 => from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q2K => from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q3K => from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q4K => from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q5K => from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q6K => from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims),
         _ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
     }
 }
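The size check in `qtensor_from_ggml` is easy to sanity-check by hand for Q4_0, where a block covers 32 elements in 18 bytes (an f16 scale plus 16 packed nibble bytes):

// Raw-data size for a 4096 x 4096 Q4_0 tensor, with Q4_0's block_size/type_size
// values (32 and 18) inlined.
fn main() {
    let tensor_elems = 4096usize * 4096;
    let (block_size, type_size) = (32usize, 18usize);
    assert_eq!(tensor_elems % block_size, 0);
    let size_in_bytes = tensor_elems / block_size * type_size;
    assert_eq!(size_in_bytes, 9_437_184);
}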
@@ -190,7 +163,6 @@ pub fn qtensor_from_ggml(
 fn read_one_tensor<R: std::io::Seek + std::io::Read>(
     reader: &mut R,
     magic: VersionedMagic,
-    device: &Device,
 ) -> Result<(String, super::QTensor)> {
     let n_dims = reader.read_u32::<LittleEndian>()?;
     let name_len = reader.read_u32::<LittleEndian>()?;
@@ -211,11 +183,11 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
     }
     let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>();
     let tensor_elems = dims.iter().product::<usize>();
-    let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.block_size();
+    let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
     // TODO: Mmap version to avoid copying the data around?
     let mut raw_data = vec![0u8; size_in_bytes];
     reader.read_exact(&mut raw_data)?;
-    match qtensor_from_ggml(ggml_dtype, &raw_data, dims, device) {
+    match qtensor_from_ggml(ggml_dtype, &raw_data, dims) {
         Ok(tensor) => Ok((name, tensor)),
         Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
     }
@@ -226,14 +198,10 @@ pub struct Content {
     pub hparams: HParams,
     pub vocab: Vocab,
     pub tensors: HashMap<String, super::QTensor>,
-    pub device: Device,
 }
 
 impl Content {
-    pub fn read<R: std::io::Seek + std::io::Read>(
-        reader: &mut R,
-        device: &Device,
-    ) -> Result<Content> {
+    pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Content> {
         // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
         let last_position = reader.seek(std::io::SeekFrom::End(0))?;
         reader.seek(std::io::SeekFrom::Start(0))?;
@@ -243,16 +211,14 @@ impl Content {
         let mut tensors = HashMap::new();
 
         while reader.stream_position()? != last_position {
-            let (name, tensor) = read_one_tensor(reader, magic, device)?;
+            let (name, tensor) = read_one_tensor(reader, magic)?;
             tensors.insert(name, tensor);
         }
-        let device = device.clone();
         Ok(Self {
             magic,
             hparams,
             vocab,
             tensors,
-            device,
         })
     }
 
@@ -3,7 +3,7 @@
 //! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md
 
 use super::{GgmlDType, QTensor};
-use crate::{Device, Result};
+use crate::Result;
 use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
 use std::collections::HashMap;
 
@@ -29,7 +29,6 @@ impl TryFrom<u32> for Magic {
 pub enum VersionedMagic {
     GgufV1,
     GgufV2,
-    GgufV3,
 }
 
 impl VersionedMagic {
@@ -40,8 +39,7 @@ impl VersionedMagic {
         let versioned_magic = match (magic, version) {
             (Magic::Gguf, 1) => Self::GgufV1,
             (Magic::Gguf, 2) => Self::GgufV2,
-            (Magic::Gguf, 3) => Self::GgufV3,
-            _ => crate::bail!("gguf: unsupported magic/version {magic:?}/{version}"),
+            _ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"),
         };
         Ok(versioned_magic)
     }
@@ -59,25 +57,19 @@ impl TensorInfo {
         &self,
         reader: &mut R,
         tensor_data_offset: u64,
-        device: &Device,
     ) -> Result<QTensor> {
         let tensor_elems = self.shape.elem_count();
-        let block_size = self.ggml_dtype.block_size();
-        if tensor_elems % block_size != 0 {
+        let blck_size = self.ggml_dtype.blck_size();
+        if tensor_elems % blck_size != 0 {
             crate::bail!(
-                "the number of elements {tensor_elems} is not divisible by the block size {block_size}"
+                "the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
             )
         }
-        let size_in_bytes = tensor_elems / block_size * self.ggml_dtype.type_size();
+        let size_in_bytes = tensor_elems / blck_size * self.ggml_dtype.type_size();
         let mut raw_data = vec![0u8; size_in_bytes];
         reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
         reader.read_exact(&mut raw_data)?;
-        super::ggml_file::qtensor_from_ggml(
-            self.ggml_dtype,
-            &raw_data,
-            self.shape.dims().to_vec(),
-            device,
-        )
+        super::ggml_file::qtensor_from_ggml(self.ggml_dtype, &raw_data, self.shape.dims().to_vec())
     }
 }
 
@@ -92,9 +84,7 @@ pub struct Content {
 fn read_string<R: std::io::Read>(reader: &mut R, magic: &VersionedMagic) -> Result<String> {
     let len = match magic {
         VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
-        VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
-            reader.read_u64::<LittleEndian>()? as usize
-        }
+        VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
     };
     let mut v = vec![0u8; len];
     reader.read_exact(&mut v)?;
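The v1-to-v2 change that keeps recurring in these hunks is that every length prefix widens from u32 to u64. A minimal reader for the two on-disk encodings of the same string, using the same byteorder crate as the file above; `read_string` here is a hypothetical helper, not the candle API:

use byteorder::{LittleEndian, ReadBytesExt};
use std::io::Read;

// Read one length-prefixed string, v1 (u32 prefix) or v2 (u64 prefix).
fn read_string<R: Read>(r: &mut R, v2: bool) -> std::io::Result<String> {
    let len = if v2 {
        r.read_u64::<LittleEndian>()? as usize
    } else {
        r.read_u32::<LittleEndian>()? as usize
    };
    let mut buf = vec![0u8; len];
    r.read_exact(&mut buf)?;
    Ok(String::from_utf8_lossy(&buf).into_owned())
}

fn main() -> std::io::Result<()> {
    let v1 = [4u8, 0, 0, 0, b'g', b'g', b'u', b'f'];
    let v2 = [4u8, 0, 0, 0, 0, 0, 0, 0, b'g', b'g', b'u', b'f'];
    assert_eq!(read_string(&mut &v1[..], false)?, "gguf");
    assert_eq!(read_string(&mut &v2[..], true)?, "gguf");
    Ok(())
}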
@@ -135,6 +125,7 @@ pub enum ValueType {
     // The value is a UTF-8 non-null-terminated string, with length prepended.
     String,
     // The value is an array of other values, with the length and type prepended.
+    ///
     // Arrays can be nested, and the length of the array is the number of elements in the array, not the number of bytes.
     Array,
 }
@@ -293,9 +284,7 @@ impl Value {
         let value_type = ValueType::from_u32(value_type)?;
         let len = match magic {
             VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
-            VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
-                reader.read_u64::<LittleEndian>()? as usize
-            }
+            VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
         };
         let mut vs = Vec::with_capacity(len);
         for _ in 0..len {
@@ -392,15 +381,11 @@ impl Content {
 
         let tensor_count = match magic {
             VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
-            VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
-                reader.read_u64::<LittleEndian>()? as usize
-            }
+            VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
         };
         let metadata_kv_count = match magic {
             VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
-            VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
-                reader.read_u64::<LittleEndian>()? as usize
-            }
+            VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
         };
 
         let mut metadata = HashMap::new();
@@ -422,7 +407,7 @@ impl Content {
                 reader.read_u32_into::<LittleEndian>(&mut dimensions)?;
                 dimensions.into_iter().map(|c| c as usize).collect()
             }
-            VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
+            VersionedMagic::GgufV2 => {
                 let mut dimensions = vec![0; n_dimensions as usize];
                 reader.read_u64_into::<LittleEndian>(&mut dimensions)?;
                 dimensions.into_iter().map(|c| c as usize).collect()
@@ -465,13 +450,12 @@ impl Content {
         &self,
         reader: &mut R,
         name: &str,
-        device: &Device,
     ) -> Result<QTensor> {
         let tensor_info = match self.tensor_infos.get(name) {
             Some(tensor_info) => tensor_info,
-            None => crate::bail!("cannot find tensor info for {name}"),
+            None => crate::bail!("cannot find tensor-infor for {name}"),
         };
-        tensor_info.read(reader, self.tensor_data_offset, device)
+        tensor_info.read(reader, self.tensor_data_offset)
     }
 }
 
@@ -523,9 +507,10 @@ pub fn write<W: std::io::Seek + std::io::Write>(
                 "internal error, unexpected current position {tensor_start_pos} {offset} {pos}"
             )
         }
-        let data = tensor.data()?;
-        let size_in_bytes = data.len();
-        w.write_all(&data)?;
+        let data_ptr = tensor.as_ptr();
+        let size_in_bytes = tensor.storage_size_in_bytes();
+        let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
+        w.write_all(data)?;
         let padding = 31 - (31 + size_in_bytes) % 32;
         w.write_all(&vec![0u8; padding])?;
     }
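The padding expression in the write path is a branch-free way to align each tensor's data to the next 32-byte boundary without ever emitting a full 32 bytes of padding; a quick exhaustive check:

// 31 - (31 + size) % 32 always lands the next offset on a multiple of 32.
fn main() {
    for size in 0usize..=128 {
        let padding = 31 - (31 + size) % 32;
        assert!(padding < 32);
        assert_eq!((size + padding) % 32, 0);
    }
}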
@@ -236,9 +236,14 @@ impl GgmlType for BlockQ4_0 {
 
     fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
         let qk = QK8_0;
+        let nb = n / qk;
         if n % QK8_0 != 0 {
             crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
         }
+        if nb % 2 != 0 {
+            crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
+        }
+
         // Generic implementation.
         let mut sumf = 0f32;
         for (xs, ys) in xs.iter().zip(ys.iter()) {
|
|||||||
let d2 = d * sc as f32;
|
let d2 = d * sc as f32;
|
||||||
let m2 = min * m as f32;
|
let m2 = min * m as f32;
|
||||||
for (ql, qh) in ql.iter().zip(qh) {
|
for (ql, qh) in ql.iter().zip(qh) {
|
||||||
let to_add = if qh & u1 != 0 { 16f32 } else { 0f32 };
|
let to_add = if qh & u1 != 0 { 16 } else { 1 };
|
||||||
y[ys_index] = d1 * ((ql & 0xF) as f32 + to_add) - m1;
|
y[ys_index] = d1 * ((ql & 0xF) + to_add) as f32 - m1;
|
||||||
ys_index += 1;
|
ys_index += 1;
|
||||||
}
|
}
|
||||||
for (ql, qh) in ql.iter().zip(qh) {
|
for (ql, qh) in ql.iter().zip(qh) {
|
||||||
let to_add = if qh & u2 != 0 { 16f32 } else { 0f32 };
|
let to_add = if qh & u2 != 0 { 16 } else { 1 };
|
||||||
y[ys_index] = d2 * ((ql >> 4) as f32 + to_add) - m2;
|
y[ys_index] = d2 * ((ql >> 4) + to_add) as f32 - m2;
|
||||||
ys_index += 1;
|
ys_index += 1;
|
||||||
}
|
}
|
||||||
is += 2;
|
is += 2;
|
||||||
|
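The Q5K hunk above is a real behavioural difference, not a rename: the fifth bit stored in `qh` should contribute an offset of 16 when set and 0 otherwise, so the `else { 1 }` on the matmul-slo side biases every value whose high bit is clear by one quantization step. Worked through for one value:

// One Q5K value with d1 = 0.5, m1 = 1.0, low nibble 11 and the high bit set.
fn main() {
    let (d1, m1) = (0.5f32, 1.0f32);
    let (ql, qh, u1) = (0x0Bu8, 0x01u8, 0x01u8);
    let to_add = if qh & u1 != 0 { 16f32 } else { 0f32 };
    let y = d1 * ((ql & 0xF) as f32 + to_add) - m1;
    assert_eq!(y, 12.5); // 0.5 * (11 + 16) - 1.0
    // With the old `else { 1 }`, a cleared high bit would add d1 * 1 to y.
}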
@@ -1,230 +0,0 @@
-use super::{GgmlDType, QStorage};
-use crate::backend::BackendStorage;
-use crate::{DType, MetalDevice, MetalStorage, Result, Shape};
-use metal::Buffer;
-use std::sync::Arc;
-
-pub struct QMetalStorage {
-    dtype: GgmlDType,
-    device: MetalDevice,
-    buffer: Arc<Buffer>,
-}
-
-impl QMetalStorage {
-    pub fn zeros(device: &MetalDevice, elem_count: usize, dtype: GgmlDType) -> Result<Self> {
-        let size = elem_count * dtype.type_size() / dtype.block_size();
-        let buffer = device.allocate_zeros(size)?;
-        Ok(Self {
-            buffer,
-            device: device.clone(),
-            dtype,
-        })
-    }
-
-    pub fn dtype(&self) -> GgmlDType {
-        self.dtype
-    }
-
-    pub fn device(&self) -> &MetalDevice {
-        &self.device
-    }
-
-    pub fn buffer(&self) -> &Buffer {
-        &self.buffer
-    }
-
-    pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
-        use crate::quantized::k_quants::GgmlType;
-
-        let buffer = self.device.new_buffer_managed(self.buffer.length())?;
-        let command_buffer = self.device.command_buffer()?;
-        command_buffer.set_label("to_cpu");
-        let blit = command_buffer.new_blit_command_encoder();
-        blit.set_label("blit_to_cpu");
-        blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
-        blit.end_encoding();
-        self.device.wait_until_completed()?;
-        let mut out = vec![0.0; elem_count];
-        let block_len = elem_count / self.dtype.block_size();
-        match self.dtype {
-            GgmlDType::F32 => {
-                let vec: Vec<f32> = read_to_vec(&buffer, block_len);
-                f32::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::F16 => {
-                let vec: Vec<half::f16> = read_to_vec(&buffer, block_len);
-                half::f16::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q4_0 => {
-                let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, block_len);
-                crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q4_1 => {
-                let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, block_len);
-                crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q5_0 => {
-                let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, block_len);
-                crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q5_1 => {
-                let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, block_len);
-                crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q8_0 => {
-                let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, block_len);
-                crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q8_1 => {
-                let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, block_len);
-                crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q2K => {
-                let vec: Vec<crate::quantized::BlockQ2K> = read_to_vec(&buffer, block_len);
-                crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q3K => {
-                let vec: Vec<crate::quantized::BlockQ3K> = read_to_vec(&buffer, block_len);
-                crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q4K => {
-                let vec: Vec<crate::quantized::BlockQ4K> = read_to_vec(&buffer, block_len);
-                crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q5K => {
-                let vec: Vec<crate::quantized::BlockQ5K> = read_to_vec(&buffer, block_len);
-                crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q6K => {
-                let vec: Vec<crate::quantized::BlockQ6K> = read_to_vec(&buffer, block_len);
-                crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q8K => {
-                let vec: Vec<crate::quantized::BlockQ8K> = read_to_vec(&buffer, block_len);
-                crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
-            }
-        }
-
-        let buffer = self.device.new_buffer_with_data(&out)?;
-        Ok(MetalStorage::new(
-            buffer,
-            self.device.clone(),
-            elem_count,
|
|
||||||
DType::F32,
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn quantize(&mut self, src: &MetalStorage) -> Result<()> {
|
|
||||||
// Quantization only happens on CPU for now.
|
|
||||||
let src = src.to_cpu::<f32>()?;
|
|
||||||
let elem_count = src.len();
|
|
||||||
let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
|
|
||||||
let mut qcpu_storage = crate::Device::Cpu.qzeros(elem_count, self.dtype)?;
|
|
||||||
qcpu_storage.quantize(&src)?;
|
|
||||||
let buffer = self.device.new_buffer_with_data(&qcpu_storage.data()?)?;
|
|
||||||
self.buffer = buffer;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn storage_size_in_bytes(&self) -> usize {
|
|
||||||
self.buffer.length() as usize
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn fwd(
|
|
||||||
&self,
|
|
||||||
self_shape: &Shape,
|
|
||||||
storage: &MetalStorage,
|
|
||||||
layout: &crate::Layout,
|
|
||||||
) -> Result<(MetalStorage, Shape)> {
|
|
||||||
use crate::MetalError;
|
|
||||||
|
|
||||||
if !layout.is_contiguous() {
|
|
||||||
crate::bail!("input tensor is not contiguous {layout:?}")
|
|
||||||
}
|
|
||||||
let src_shape = layout.shape();
|
|
||||||
// self is transposed so n is first then k.
|
|
||||||
if src_shape.rank() < 2 {
|
|
||||||
crate::bail!("input tensor has only one dimension {layout:?}")
|
|
||||||
}
|
|
||||||
let (n, k) = self_shape.dims2()?;
|
|
||||||
let mut dst_shape = src_shape.dims().to_vec();
|
|
||||||
|
|
||||||
// We always use a single batch dimension and stack all the tensors in the batch on the
|
|
||||||
// second dimension as the implementation in candle-metal-kernels doesn't handle batch
|
|
||||||
// properly.
|
|
||||||
let m = match dst_shape.len() {
|
|
||||||
3 => dst_shape[0] * dst_shape[1],
|
|
||||||
2 => dst_shape[0],
|
|
||||||
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
|
|
||||||
};
|
|
||||||
let last_k = dst_shape.pop().unwrap();
|
|
||||||
if last_k != k {
|
|
||||||
crate::bail!("input tensor {layout:?} incompatible with {:?}", self_shape)
|
|
||||||
}
|
|
||||||
dst_shape.push(n);
|
|
||||||
let dst_shape = Shape::from(dst_shape);
|
|
||||||
let device = storage.device().clone();
|
|
||||||
let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
|
|
||||||
let command_buffer = device.command_buffer()?;
|
|
||||||
// In some cases it would be better to use the mm variant, though it has its drawbacks
|
|
||||||
// around memory alignemnt.
|
|
||||||
for batch_id in 0..m {
|
|
||||||
candle_metal_kernels::call_quantized_matmul_mv_t(
|
|
||||||
device.device(),
|
|
||||||
&command_buffer,
|
|
||||||
device.kernels(),
|
|
||||||
self.dtype.into(),
|
|
||||||
(1, 1, n, k),
|
|
||||||
storage.buffer(),
|
|
||||||
(layout.start_offset() + batch_id * k) * storage.dtype().size_in_bytes(),
|
|
||||||
&self.buffer,
|
|
||||||
batch_id * n * DType::F32.size_in_bytes(),
|
|
||||||
&dst,
|
|
||||||
)
|
|
||||||
.map_err(MetalError::from)?;
|
|
||||||
}
|
|
||||||
let dst_storage = crate::MetalStorage::new(dst, device, dst_shape.elem_count(), DType::F32);
|
|
||||||
Ok((dst_storage, dst_shape))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
|
|
||||||
device: &MetalDevice,
|
|
||||||
data: &[T],
|
|
||||||
) -> Result<QStorage> {
|
|
||||||
let buffer = device.new_buffer_with_data(data)?;
|
|
||||||
let device = device.clone();
|
|
||||||
Ok(QStorage::Metal(QMetalStorage {
|
|
||||||
dtype: T::DTYPE,
|
|
||||||
device,
|
|
||||||
buffer,
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
|
|
||||||
let ptr = buffer.contents() as *const T;
|
|
||||||
assert!(!ptr.is_null());
|
|
||||||
let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
|
|
||||||
slice.to_vec()
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
|
|
||||||
fn from(value: GgmlDType) -> Self {
|
|
||||||
match value {
|
|
||||||
GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0,
|
|
||||||
GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1,
|
|
||||||
GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0,
|
|
||||||
GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1,
|
|
||||||
GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0,
|
|
||||||
GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1,
|
|
||||||
GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K,
|
|
||||||
GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K,
|
|
||||||
GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K,
|
|
||||||
GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
|
|
||||||
GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
|
|
||||||
GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
|
|
||||||
GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
|
|
||||||
GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
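`zeros` in the deleted file sizes its buffer as `elem_count * type_size / block_size`: each block of `block_size` elements occupies `type_size` bytes. A quick sanity check with Q4_0's constants (32 elements per block, 18 bytes per block: a 2-byte f16 scale plus 16 nibble bytes), inlined here rather than read from the crate:

```rust
/// Bytes needed to store `elem_count` quantized elements when each block of
/// `block_size` elements takes `type_size` bytes; mirrors the sizing logic in
/// the deleted QMetalStorage::zeros above.
fn quantized_size_in_bytes(elem_count: usize, type_size: usize, block_size: usize) -> usize {
    elem_count * type_size / block_size
}

fn main() {
    // Q4_0: 4096 f32 values -> 128 blocks of 18 bytes -> 2304 bytes.
    assert_eq!(quantized_size_in_bytes(4096, 18, 32), 2304);
}
```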
@@ -1,134 +1,23 @@
-use crate::{CpuStorage, DType, Device, Result, Shape, Storage, Tensor};
-use k_quants::*;
-use std::borrow::Cow;
+use crate::{Device, Result, Shape, Tensor};

 #[cfg(target_feature = "avx")]
 pub mod avx;
-mod dummy_cuda;
-mod dummy_metal;
 pub mod ggml_file;
 pub mod gguf_file;
 pub mod k_quants;
-#[cfg(feature = "metal")]
-pub mod metal;
-#[cfg(not(feature = "metal"))]
-mod metal {
-    pub use super::dummy_metal::*;
-}
-#[cfg(feature = "cuda")]
-pub mod cuda;
-#[cfg(not(feature = "cuda"))]
-mod cuda {
-    pub use super::dummy_cuda::*;
-}

 #[cfg(target_feature = "neon")]
 pub mod neon;
 #[cfg(target_feature = "simd128")]
 pub mod simd128;
 pub mod utils;
-use half::f16;

 pub use k_quants::GgmlType;

 pub struct QTensor {
-    storage: QStorage,
+    data: Box<dyn QuantizedType>,
     shape: Shape,
 }

-impl Device {
-    fn qzeros(&self, elem_count: usize, dtype: GgmlDType) -> Result<QStorage> {
-        match self {
-            Device::Cpu => {
-                let storage = dtype.cpu_zeros(elem_count);
-                Ok(QStorage::Cpu(storage))
-            }
-            Device::Metal(metal) => {
-                let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?;
-                Ok(QStorage::Metal(storage))
-            }
-            Device::Cuda(cuda) => {
-                let storage = cuda::QCudaStorage::zeros(cuda, elem_count, dtype)?;
-                Ok(QStorage::Cuda(storage))
-            }
-        }
-    }
-}
-
-pub enum QStorage {
-    Cpu(Box<dyn QuantizedType>),
-    Metal(metal::QMetalStorage),
-    Cuda(cuda::QCudaStorage),
-}
-
-impl QStorage {
-    fn block_size(&self) -> usize {
-        match self {
-            QStorage::Cpu(storage) => storage.block_size(),
-            QStorage::Metal(storage) => storage.dtype().block_size(),
-            QStorage::Cuda(storage) => storage.dtype().block_size(),
-        }
-    }
-
-    fn dtype(&self) -> GgmlDType {
-        match self {
-            QStorage::Cpu(storage) => storage.dtype(),
-            QStorage::Metal(storage) => storage.dtype(),
-            QStorage::Cuda(storage) => storage.dtype(),
-        }
-    }
-
-    fn device(&self) -> Device {
-        match self {
-            QStorage::Cpu(_storage) => Device::Cpu,
-            QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
-            QStorage::Cuda(storage) => Device::Cuda(storage.device().clone()),
-        }
-    }
-
-    fn size_in_bytes(&self) -> usize {
-        match self {
-            QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
-            QStorage::Metal(storage) => storage.storage_size_in_bytes(),
-            QStorage::Cuda(storage) => storage.storage_size_in_bytes(),
-        }
-    }
-
-    fn quantize(&mut self, src: &Storage) -> Result<()> {
-        match (self, src) {
-            (QStorage::Cpu(storage), Storage::Cpu(src)) => {
-                storage.from_float(src.as_slice::<f32>()?)?;
-            }
-            (QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
-            (QStorage::Cuda(storage), Storage::Cuda(src)) => storage.quantize(src)?,
-            _ => crate::bail!("Invalid dequantize storage locations do not match"),
-        }
-        Ok(())
-    }
-
-    fn dequantize(&self, elem_count: usize) -> Result<Storage> {
-        match self {
-            QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
-            QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
-            QStorage::Cuda(storage) => Ok(Storage::Cuda(storage.dequantize(elem_count)?)),
-        }
-    }
-
-    fn data(&self) -> Result<Cow<[u8]>> {
-        match self {
-            QStorage::Cpu(storage) => {
-                let data_ptr = storage.as_ptr();
-                let size_in_bytes = storage.storage_size_in_bytes();
-                let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
-                Ok(Cow::from(data))
-            }
-            QStorage::Metal(_) | QStorage::Cuda(_) => {
-                crate::bail!("not implemented");
-            }
-        }
-    }
-}
-
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
 pub enum GgmlDType {
     F32,
@@ -188,25 +77,6 @@ impl GgmlDType {
        }
    }

-    /// The block dtype
-    pub fn cpu_zeros(&self, elem_count: usize) -> Box<dyn QuantizedType> {
-        match self {
-            Self::F32 => Box::new(vec![f32::zeros(); elem_count]),
-            Self::F16 => Box::new(vec![f16::zeros(); elem_count]),
-            Self::Q4_0 => Box::new(vec![BlockQ4_0::zeros(); elem_count / BlockQ4_0::BLCK_SIZE]),
-            Self::Q4_1 => Box::new(vec![BlockQ4_1::zeros(); elem_count / BlockQ4_1::BLCK_SIZE]),
-            Self::Q5_0 => Box::new(vec![BlockQ5_0::zeros(); elem_count / BlockQ5_0::BLCK_SIZE]),
-            Self::Q5_1 => Box::new(vec![BlockQ5_1::zeros(); elem_count / BlockQ5_1::BLCK_SIZE]),
-            Self::Q8_0 => Box::new(vec![BlockQ8_0::zeros(); elem_count / BlockQ8_0::BLCK_SIZE]),
-            Self::Q8_1 => Box::new(vec![BlockQ8_1::zeros(); elem_count / BlockQ8_1::BLCK_SIZE]),
-            Self::Q2K => Box::new(vec![BlockQ2K::zeros(); elem_count / BlockQ2K::BLCK_SIZE]),
-            Self::Q3K => Box::new(vec![BlockQ3K::zeros(); elem_count / BlockQ3K::BLCK_SIZE]),
-            Self::Q4K => Box::new(vec![BlockQ4K::zeros(); elem_count / BlockQ4K::BLCK_SIZE]),
-            Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]),
-            Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]),
-            Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]),
-        }
-    }
    /// The type size for blocks in bytes.
    pub fn type_size(&self) -> usize {
        use k_quants::*;
@@ -230,7 +100,7 @@ impl GgmlDType {
    }

    /// The block size, i.e. the number of elements stored in each block.
-    pub fn block_size(&self) -> usize {
+    pub fn blck_size(&self) -> usize {
        match self {
            Self::F32 => 1,
            Self::F16 => 1,
@@ -249,13 +119,9 @@ impl GgmlDType {
 pub trait QuantizedType: Send + Sync {
    fn dtype(&self) -> GgmlDType;
    fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
-    fn dequantize(&self, elem_count: usize) -> Result<CpuStorage>;
+    fn to_float(&self, ys: &mut [f32]) -> Result<()>;
    fn storage_size_in_bytes(&self) -> usize;
    fn as_ptr(&self) -> *const u8;
-    fn block_size(&self) -> usize;
-    #[allow(clippy::wrong_self_convention)]
-    fn from_float(&mut self, xs: &[f32]) -> Result<()>;
-    fn size(&self) -> usize;
 }

 impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
@@ -263,26 +129,12 @@ impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
        k_quants::matmul(mkn, lhs, self.as_slice(), dst)
    }

-    fn size(&self) -> usize {
-        self.len() * core::mem::size_of::<T>()
-    }
-
-    fn from_float(&mut self, xs: &[f32]) -> Result<()> {
-        T::from_float(xs, self)
-    }
-
    fn dtype(&self) -> GgmlDType {
        T::DTYPE
    }

-    fn block_size(&self) -> usize {
-        T::BLCK_SIZE
-    }
-
-    fn dequantize(&self, elem_count: usize) -> Result<CpuStorage> {
-        let mut ys = vec![0.0f32; elem_count];
-        T::to_float(self.as_slice(), &mut ys)?;
-        Ok(CpuStorage::F32(ys))
+    fn to_float(&self, ys: &mut [f32]) -> Result<()> {
+        T::to_float(self.as_slice(), ys)
    }

    fn storage_size_in_bytes(&self) -> usize {
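The added side keeps the quantized payload type-erased: any `Vec<T>` of blocks hides behind `Box<dyn QuantizedType>`, and its dtype is recovered through a vtable call. A reduced model of that pattern with simplified stand-in types (none of these are the crate's real definitions):

```rust
// Minimal model of the type erasure above: one virtual call recovers the
// dtype of an otherwise opaque buffer of quantized blocks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Dt {
    Q4_0,
}

trait Quantized {
    fn dtype(&self) -> Dt;
    fn num_blocks(&self) -> usize;
}

#[derive(Clone, Copy)]
struct BlockQ4_0Like {
    _d: u16,       // stand-in for the f16 scale
    _qs: [u8; 16], // 32 packed 4-bit quants
}

impl Quantized for Vec<BlockQ4_0Like> {
    fn dtype(&self) -> Dt {
        Dt::Q4_0
    }
    fn num_blocks(&self) -> usize {
        self.len()
    }
}

fn main() {
    let blocks = vec![BlockQ4_0Like { _d: 0, _qs: [0; 16] }; 128];
    let erased: Box<dyn Quantized> = Box::new(blocks);
    assert_eq!(erased.dtype(), Dt::Q4_0);
    assert_eq!(erased.num_blocks(), 128);
}
```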
@@ -300,53 +152,56 @@ impl std::fmt::Debug for QTensor {
    }
 }

-fn check_shape(shape: &Shape, block_size: usize) -> Result<()> {
+fn check_shape<T: k_quants::GgmlType>(shape: &Shape) -> Result<()> {
    let dims = shape.dims();
    if dims.is_empty() {
        crate::bail!("scalar tensor cannot be quantized {shape:?}")
    }
-    if dims[dims.len() - 1] % block_size != 0 {
+    if dims[dims.len() - 1] % T::BLCK_SIZE != 0 {
        crate::bail!(
            "quantized tensor must have their last dim divisible by block size {shape:?} {}",
-            block_size
+            T::BLCK_SIZE
        )
    }
    Ok(())
 }

 impl QTensor {
-    pub fn new<S: Into<Shape>>(storage: QStorage, shape: S) -> Result<Self> {
+    pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
+        data: Vec<T>,
+        shape: S,
+    ) -> Result<Self> {
        let shape = shape.into();
-        check_shape(&shape, storage.block_size())?;
-        Ok(Self { storage, shape })
+        check_shape::<T>(&shape)?;
+        Ok(Self {
+            data: Box::new(data),
+            shape,
+        })
    }

-    pub fn quantize(src: &Tensor, dtype: GgmlDType) -> Result<Self> {
+    pub fn quantize<T: k_quants::GgmlType + Send + Sync + 'static>(src: &Tensor) -> Result<Self> {
        let shape = src.shape();
-        let block_size = dtype.block_size();
-        check_shape(shape, block_size)?;
-        let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
-        let elem_count = shape.elem_count();
-        if elem_count % block_size != 0 {
+        check_shape::<T>(shape)?;
+        let src = src
+            .to_dtype(crate::DType::F32)?
+            .flatten_all()?
+            .to_vec1::<f32>()?;
+        if src.len() % T::BLCK_SIZE != 0 {
            crate::bail!(
                "tensor size ({shape:?}) is not divisible by block size {}",
-                block_size
+                T::BLCK_SIZE
            )
        }
-        let mut storage = src.device().qzeros(elem_count, dtype)?;
-        storage.quantize(&src.storage())?;
+        let mut data = vec![T::zeros(); src.len() / T::BLCK_SIZE];
+        T::from_float(&src, &mut data)?;
        Ok(Self {
-            storage,
+            data: Box::new(data),
            shape: shape.clone(),
        })
    }

    pub fn dtype(&self) -> GgmlDType {
-        self.storage.dtype()
-    }
-
-    pub fn device(&self) -> Device {
-        self.storage.device()
+        self.data.dtype()
    }

    pub fn rank(&self) -> usize {
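This hunk is the user-visible API change of the comparison: the removed `quantize` picks the codec from a runtime `GgmlDType` argument, while the added one fixes the block type as a compile-time parameter. Hypothetical call sites for both, assuming the signatures shown above:

```rust
use candle_core::quantized::{GgmlDType, QTensor};
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    // The last dim must be a multiple of the block size (32 for Q4_0).
    let t = Tensor::randn(0f32, 1f32, (8, 256), &Device::Cpu)?;

    // Removed-side API: dtype chosen at runtime.
    let q = QTensor::quantize(&t, GgmlDType::Q4_0)?;

    // Added-side API: block type chosen at compile time, per the hunk above:
    // let q = QTensor::quantize::<candle_core::quantized::k_quants::BlockQ4_0>(&t)?;

    println!("{} bytes", q.storage_size_in_bytes());
    Ok(())
}
```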
@@ -358,34 +213,21 @@ impl QTensor {
    }

    pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
-        let storage = self.storage.dequantize(self.shape.elem_count())?;
-        let none = crate::op::BackpropOp::none();
-        crate::tensor::from_storage(storage, self.shape.clone(), none, false).to_device(device)
+        let mut f32_data = vec![0f32; self.shape.elem_count()];
+        self.data.to_float(&mut f32_data)?;
+        Tensor::from_vec(f32_data, &self.shape, device)
    }

-    pub fn dequantize_f16(&self, device: &Device) -> Result<Tensor> {
-        // In the CUDA case, we have a specialized kernel as this can be useful for volta
-        // architectures. https://github.com/huggingface/candle/issues/2136
-        match &self.storage {
-            QStorage::Cuda(s) => {
-                let s = s.dequantize_f16(self.shape.elem_count())?;
-                let none = crate::op::BackpropOp::none();
-                crate::tensor::from_storage(Storage::Cuda(s), self.shape.clone(), none, false)
-                    .to_device(device)
-            }
-            _ => {
-                let s = self.dequantize(device)?.to_dtype(crate::DType::F16)?;
-                Ok(s)
-            }
-        }
+    pub fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
+        self.data.matmul_t(mkn, lhs, dst)
    }

    pub fn storage_size_in_bytes(&self) -> usize {
-        self.storage.size_in_bytes()
+        self.data.storage_size_in_bytes()
    }

-    pub fn data(&self) -> Result<Cow<'_, [u8]>> {
-        self.storage.data()
+    pub fn as_ptr(&self) -> *const u8 {
+        self.data.as_ptr()
    }
 }

@@ -393,7 +235,6 @@ impl QTensor {
 pub enum QMatMul {
    QTensor(std::sync::Arc<QTensor>),
    Tensor(Tensor),
-    TensorF16(Tensor),
 }

 thread_local! {
@@ -407,17 +248,6 @@ thread_local! {
    }
 }

-thread_local! {
-    static DEQUANTIZE_ALL_F16: bool = {
-        match std::env::var("CANDLE_DEQUANTIZE_ALL_F16") {
-            Ok(s) => {
-                !s.is_empty() && s != "0"
-            },
-            Err(_) => false,
-        }
-    }
-}
-
 impl QMatMul {
    pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
        let dequantize = match qtensor.dtype() {
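The deleted `DEQUANTIZE_ALL_F16` block uses the same env-flag idiom as the surviving `DEQUANTIZE_ALL`: the variable counts as set when it is non-empty and not "0", and the answer is cached per thread. That truthiness test, factored into a standalone helper (our naming, not the crate's):

```rust
/// True when the environment variable is set to anything other than "" or "0".
/// Mirrors the gating used by the thread-local flags above.
fn env_flag(name: &str) -> bool {
    match std::env::var(name) {
        Ok(s) => !s.is_empty() && s != "0",
        Err(_) => false,
    }
}

fn main() {
    std::env::set_var("CANDLE_DEQUANTIZE_ALL_F16", "1");
    assert!(env_flag("CANDLE_DEQUANTIZE_ALL_F16"));
    std::env::set_var("CANDLE_DEQUANTIZE_ALL_F16", "0");
    assert!(!env_flag("CANDLE_DEQUANTIZE_ALL_F16"));
}
```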
@@ -425,11 +255,8 @@ impl QMatMul {
            _ => DEQUANTIZE_ALL.with(|b| *b),
        };
        let t = if dequantize {
-            let tensor = qtensor.dequantize(&qtensor.device())?;
+            let tensor = qtensor.dequantize(&Device::Cpu)?;
            Self::Tensor(tensor)
-        } else if DEQUANTIZE_ALL_F16.with(|b| *b) {
-            let tensor = qtensor.dequantize_f16(&qtensor.device())?;
-            Self::TensorF16(tensor)
        } else {
            Self::QTensor(qtensor)
        };
@@ -439,25 +266,6 @@ impl QMatMul {
    pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
        Self::from_arc(std::sync::Arc::new(qtensor))
    }
-
-    pub fn dequantize_f16(&self) -> Result<Tensor> {
-        match self {
-            Self::QTensor(t) => t.dequantize_f16(&t.device()),
-            Self::Tensor(t) => t.to_dtype(DType::F16),
-            Self::TensorF16(t) => Ok(t.clone()),
-        }
-    }
-
-    pub fn forward_via_f16(&self, xs: &Tensor) -> Result<Tensor> {
-        let w = self.dequantize_f16()?;
-        let in_dtype = xs.dtype();
-        let w = match *xs.dims() {
-            [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
-            [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
-            _ => w.t()?,
-        };
-        xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
-    }
 }

 impl crate::CustomOp1 for QTensor {
@@ -486,45 +294,21 @@ impl crate::CustomOp1 for QTensor {
        }
        dst_shape.push(n);
        let dst_shape = Shape::from(dst_shape);
-        #[allow(clippy::infallible_destructuring_match)]
-        let self_storage = match &self.storage {
-            QStorage::Cpu(storage) => storage,
-            QStorage::Metal(_) | QStorage::Cuda(_) => crate::bail!("Invalid storage"),
-        };
-        let slice = storage.as_slice::<f32>()?;
-        let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
+        let storage = storage.as_slice::<f32>()?;
+        let storage =
+            &storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
        let mut dst_storage = vec![0f32; dst_shape.elem_count()];
-        self_storage.matmul_t((dst_shape.elem_count() / n, k, n), slice, &mut dst_storage)?;
+        self.matmul_t(
+            (dst_shape.elem_count() / n, k, n),
+            storage,
+            &mut dst_storage,
+        )?;
        Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
    }
-
-    fn metal_fwd(
-        &self,
-        storage: &crate::MetalStorage,
-        layout: &crate::Layout,
-    ) -> Result<(crate::MetalStorage, Shape)> {
-        let self_storage = match &self.storage {
-            QStorage::Metal(metal) => metal,
-            _ => unreachable!("Cannot call metal matmul on non metal QTensor"),
-        };
-        self_storage.fwd(&self.shape, storage, layout)
-    }
-
-    fn cuda_fwd(
-        &self,
-        storage: &crate::CudaStorage,
-        layout: &crate::Layout,
-    ) -> Result<(crate::CudaStorage, Shape)> {
-        let self_storage = match &self.storage {
-            QStorage::Cuda(cuda) => cuda,
-            _ => unreachable!("Cannot call cuda matmul on non cuda QTensor"),
-        };
-        self_storage.fwd(&self.shape, storage, layout)
-    }
 }

-impl crate::Module for QMatMul {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+impl QMatMul {
+    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
            Self::Tensor(w) => {
@@ -535,15 +319,6 @@ impl crate::Module for QMatMul {
                };
                xs.matmul(&w)
            }
-            Self::TensorF16(w) => {
-                let in_dtype = xs.dtype();
-                let w = match *xs.dims() {
-                    [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
-                    [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
-                    _ => w.t()?,
-                };
-                xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
-            }
        }
    }
 }
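The removed lines implement `candle_core::Module` for `QMatMul`, which is what lets it plug into code that is generic over modules; the added lines keep only an inherent method. A hypothetical caller showing what the trait impl enables:

```rust
use candle_core::{Module, Result, Tensor};

// Generic over any module: works for QMatMul when it implements
// candle_core::Module (the removed side of the hunk), just like Linear.
fn run_layer<M: Module>(layer: &M, xs: &Tensor) -> Result<Tensor> {
    xs.apply(layer)
}
```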
@@ -12,14 +12,6 @@ use core::arch::arm::*;
 #[cfg(target_arch = "aarch64")]
 use core::arch::aarch64::*;

-#[inline(always)]
-unsafe fn vdotq_s32(a: int8x16_t, b: int8x16_t) -> int32x4_t {
-    // TODO: dotprod
-    let p0 = vmull_s8(vget_low_s8(a), vget_low_s8(b));
-    let p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
-    vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1))
-}
-
 #[inline(always)]
 pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
    let qk = QK8_0;
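The deleted helper is the crux of this file's changes: without the ARM `dotprod` extension, a 16-lane i8×i8 dot product is emulated by two widening multiplies (`vmull_s8` on the low and high halves, producing i16 lanes), a pairwise widening add (`vpaddlq_s16`, i16 to i32), and a final `vaddq_s32`. What those intrinsics compute, restated as scalar Rust:

```rust
/// Scalar model of the deleted NEON helper vdotq_s32: the dot product of two
/// 16-byte i8 vectors (here folded all the way down to a single i32 instead
/// of four i32 lanes). Widening to i32 before summing avoids i8/i16 overflow.
fn dot_i8x16(a: [i8; 16], b: [i8; 16]) -> i32 {
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| x as i32 * y as i32)
        .sum()
}

fn main() {
    let a = [1i8; 16];
    let b = [2i8; 16];
    assert_eq!(dot_i8x16(a, b), 32);
}
```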
@@ -27,39 +19,71 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
    if n % QK8_0 != 0 {
        crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
    }
+    if nb % 2 != 0 {
+        crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
+    }

    unsafe {
        let mut sumv0 = vdupq_n_f32(0.0f32);
-        for i in 0..nb {
+        let mut sumv1 = vdupq_n_f32(0.0f32);
+        for i in (0..nb).step_by(2) {
            let x0 = &xs[i];
+            let x1 = &xs[i + 1];
            let y0 = &ys[i];
+            let y1 = &ys[i + 1];

            let m4b = vdupq_n_u8(0x0F);
            let s8b = vdupq_n_s8(0x8);

            let v0_0 = vld1q_u8(x0.qs.as_ptr());
+            let v0_1 = vld1q_u8(x1.qs.as_ptr());

            // 4-bit -> 8-bit
            let v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
            let v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+            let v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
+            let v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

            // sub 8
            let v0_0ls = vsubq_s8(v0_0l, s8b);
            let v0_0hs = vsubq_s8(v0_0h, s8b);
+            let v0_1ls = vsubq_s8(v0_1l, s8b);
+            let v0_1hs = vsubq_s8(v0_1h, s8b);

            // load y
            let v1_0l = vld1q_s8(y0.qs.as_ptr());
            let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));
+            let v1_1l = vld1q_s8(y1.qs.as_ptr());
+            let v1_1h = vld1q_s8(y1.qs.as_ptr().add(16));
+
+            // TODO: Support dotprod when it's available outside of nightly.
+            let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l));
+            let pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
+            let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h));
+            let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
+
+            let pl1l = vmull_s8(vget_low_s8(v0_1ls), vget_low_s8(v1_1l));
+            let pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
+            let ph1l = vmull_s8(vget_low_s8(v0_1hs), vget_low_s8(v1_1h));
+            let ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
+
+            let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+            let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+            let pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+            let ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));

-            let pl0 = vdotq_s32(v0_0ls, v1_0l);
-            let ph0 = vdotq_s32(v0_0hs, v1_0h);
            sumv0 = vmlaq_n_f32(
                sumv0,
                vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
                x0.d.to_f32() * y0.d.to_f32(),
            );
+            sumv1 = vmlaq_n_f32(
+                sumv1,
+                vcvtq_f32_s32(vaddq_s32(pl1, ph1)),
+                x1.d.to_f32() * y1.d.to_f32(),
+            );
        }
-        Ok(vaddvq_f32(sumv0))
+        Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1))
    }
 }

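The rewritten loop above consumes blocks two at a time into independent accumulators (`sumv0`, `sumv1`) that are only merged on return, removing the serial dependency between iterations; this is also why `nb` must now be even. The control-flow shape, reduced to scalar Rust:

```rust
/// Scalar model of the two-way unrolling applied in the hunk above: two
/// independent accumulators, combined only at the end.
fn sum_pairs(vals: &[f32]) -> f32 {
    assert!(vals.len() % 2 == 0, "needs an even number of blocks");
    let (mut acc0, mut acc1) = (0f32, 0f32);
    for pair in vals.chunks_exact(2) {
        acc0 += pair[0];
        acc1 += pair[1];
    }
    acc0 + acc1
}

fn main() {
    assert_eq!(sum_pairs(&[1., 2., 3., 4.]), 10.);
}
```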
@@ -70,29 +94,57 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
        crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
    }
    let nb = n / QK8_0;
+    if nb % 2 != 0 {
+        crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even")
+    }
    unsafe {
        let mut sumv0 = vdupq_n_f32(0.0f32);
-        for i in 0..nb {
+        let mut sumv1 = vdupq_n_f32(0.0f32);
+        for i in (0..nb).step_by(2) {
            let x0 = &xs[i];
+            let x1 = &xs[i + 1];
            let y0 = &ys[i];
+            let y1 = &ys[i + 1];

            let x0_0 = vld1q_s8(x0.qs.as_ptr());
            let x0_1 = vld1q_s8(x0.qs.as_ptr().add(16));
+            let x1_0 = vld1q_s8(x1.qs.as_ptr());
+            let x1_1 = vld1q_s8(x1.qs.as_ptr().add(16));

            // load y
            let y0_0 = vld1q_s8(y0.qs.as_ptr());
            let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));
+            let y1_0 = vld1q_s8(y1.qs.as_ptr());
+            let y1_1 = vld1q_s8(y1.qs.as_ptr().add(16));

-            let p0 = vdotq_s32(x0_0, y0_0);
-            let p1 = vdotq_s32(x0_1, y0_1);
+            // TODO: dotprod, once the intrinsics are available.
+            let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0));
+            let p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
+            let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1));
+            let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
+
+            let p1_0 = vmull_s8(vget_low_s8(x1_0), vget_low_s8(y1_0));
+            let p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
+            let p1_2 = vmull_s8(vget_low_s8(x1_1), vget_low_s8(y1_1));
+            let p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
+
+            let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
+            let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
+            let p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
+            let p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));

            sumv0 = vmlaq_n_f32(
                sumv0,
                vcvtq_f32_s32(vaddq_s32(p0, p1)),
                x0.d.to_f32() * y0.d.to_f32(),
            );
+            sumv1 = vmlaq_n_f32(
+                sumv1,
+                vcvtq_f32_s32(vaddq_s32(p2, p3)),
+                x1.d.to_f32() * y1.d.to_f32(),
+            );
        }
-        Ok(vaddvq_f32(sumv0))
+        Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1))
    }
 }

@@ -113,7 +165,10 @@ pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
            for i in (0..QK_K).step_by(16) {
                let xs = vld1q_s8(xs.add(i));
                let ys = vld1q_s8(ys.add(i));
-                let xy = vdotq_s32(xs, ys);
+                let xy_lo = vmull_s8(vget_low_s8(xs), vget_low_s8(ys));
+                let xy_up = vmull_s8(vget_high_s8(xs), vget_high_s8(ys));
+
+                let xy = vaddq_s32(vpaddlq_s16(xy_lo), vpaddlq_s16(xy_up));
                sum_i = vaddq_s32(sum_i, xy)
            }
            sumf += vaddvq_s32(sum_i) as f32 * scale
@@ -183,16 +238,30 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
            let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.2, m4b), q6h_2));
            let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.3, m4b), q6h_3));

-            let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
-            let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
+            // TODO: dotprod
+
+            let p0 = vaddq_s16(
+                vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
+                vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
+            );
+            let p1 = vaddq_s16(
+                vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
+                vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
+            );
            let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
-            isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
+            isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
            scale = scale.add(2);

-            let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
-            let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
+            let p2 = vaddq_s16(
+                vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
+                vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
+            );
+            let p3 = vaddq_s16(
+                vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
+                vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
+            );
            let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
-            isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
+            isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
            scale = scale.add(2);

            let q8bytes = vld1q_s8_x4(q8);
@@ -212,16 +281,29 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
            let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.2, 4), q6h_2));
            let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.3, 4), q6h_3));

-            let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
-            let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
+            // TODO: dotprod case.
+            let p0 = vaddq_s16(
+                vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
+                vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
+            );
+            let p1 = vaddq_s16(
+                vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
+                vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
+            );
            let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
-            isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
+            isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
            scale = scale.add(2);

-            let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
-            let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
+            let p2 = vaddq_s16(
+                vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
+                vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
+            );
+            let p3 = vaddq_s16(
+                vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
+                vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
+            );
            let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
-            isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
+            isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
            scale = scale.add(2);
        }
        sum += d_all * y.d * ((isum - 32 * isum_mins) as f32);
@@ -298,14 +380,28 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Result<f32> {
            let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2));
            let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3));

-            let p0 = vdotq_s32(q5bytes_0, q8bytes.0);
-            let p1 = vdotq_s32(q5bytes_1, q8bytes.1);
-            sumi += vaddvq_s32(vaddq_s32(p0, p1)) * *scales as i32;
+            // TODO: dotprod
+
+            let p0 = vaddq_s16(
+                vmull_s8(vget_low_s8(q5bytes_0), vget_low_s8(q8bytes.0)),
+                vmull_s8(vget_high_s8(q5bytes_0), vget_high_s8(q8bytes.0)),
+            );
+            let p1 = vaddq_s16(
+                vmull_s8(vget_low_s8(q5bytes_1), vget_low_s8(q8bytes.1)),
+                vmull_s8(vget_high_s8(q5bytes_1), vget_high_s8(q8bytes.1)),
+            );
+            sumi += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * *scales as i32;
            scales = scales.add(1);

-            let p2 = vdotq_s32(q5bytes_2, q8bytes.2);
-            let p3 = vdotq_s32(q5bytes_3, q8bytes.3);
-            sumi += vaddvq_s32(vaddq_s32(p2, p3)) * *scales as i32;
+            let p2 = vaddq_s16(
+                vmull_s8(vget_low_s8(q5bytes_2), vget_low_s8(q8bytes.2)),
+                vmull_s8(vget_high_s8(q5bytes_2), vget_high_s8(q8bytes.2)),
+            );
+            let p3 = vaddq_s16(
+                vmull_s8(vget_low_s8(q5bytes_3), vget_low_s8(q8bytes.3)),
+                vmull_s8(vget_high_s8(q5bytes_3), vget_high_s8(q8bytes.3)),
+            );
+            sumi += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * *scales as i32;
            scales = scales.add(1);
        }
        sumf += d * sumi as f32 - dmin * sumi_mins as f32;
@@ -368,15 +464,22 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
        for j in 0..QK_K / 64 {
            let q4bits = vld1q_u8_x2(q4);
            q4 = q4.add(32);
+            // TODO: dotprod
            let q8bytes = vld1q_s8_x2(q8);
            q8 = q8.add(32);
            let q4bytes = int8x16x2_t(
                vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)),
                vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)),
            );
-            let p0 = vdotq_s32(q4bytes.0, q8bytes.0);
-            let p1 = vdotq_s32(q4bytes.1, q8bytes.1);
-            sumi1 += vaddvq_s32(vaddq_s32(p0, p1)) * scales[2 * j] as i32;
+            let p0 = vaddq_s16(
+                vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
+                vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
+            );
+            let p1 = vaddq_s16(
+                vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
+                vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
+            );
+            sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * scales[2 * j] as i32;

            let q8bytes = vld1q_s8_x2(q8);
            q8 = q8.add(32);
@@ -384,9 +487,15 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
                vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)),
                vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)),
            );
-            let p2 = vdotq_s32(q4bytes.0, q8bytes.0);
-            let p3 = vdotq_s32(q4bytes.1, q8bytes.1);
-            sumi2 += vaddvq_s32(vaddq_s32(p2, p3)) * scales[2 * j + 1] as i32;
+            let p2 = vaddq_s16(
+                vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
+                vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
+            );
+            let p3 = vaddq_s16(
+                vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
+                vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
+            );
+            sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * scales[2 * j + 1] as i32;
        }
        sumf += d * (sumi1 + sumi2) as f32;
    }
@@ -464,14 +573,27 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Result<f32> {
                vreinterpretq_s8_u8(q3h_3),
            );

-            let p0 = vdotq_s32(q3bytes_0, q8bytes_1.0);
-            let p1 = vdotq_s32(q3bytes_1, q8bytes_1.1);
-            let p2 = vdotq_s32(q3bytes_2, q8bytes_1.2);
-            let p3 = vdotq_s32(q3bytes_3, q8bytes_1.3);
-            isum += vaddvq_s32(p0) * *scale as i32
-                + vaddvq_s32(p1) * *scale.add(1) as i32
-                + vaddvq_s32(p2) * *scale.add(2) as i32
-                + vaddvq_s32(p3) * *scale.add(3) as i32;
+            // TODO: dotprod
+            let p0 = vaddq_s16(
+                vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_1.0)),
+                vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_1.0)),
+            );
+            let p1 = vaddq_s16(
+                vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_1.1)),
+                vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_1.1)),
+            );
+            let p2 = vaddq_s16(
+                vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_1.2)),
+                vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_1.2)),
+            );
+            let p3 = vaddq_s16(
+                vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_1.3)),
+                vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_1.3)),
+            );
+            isum += vaddvq_s16(p0) as i32 * *scale as i32
+                + vaddvq_s16(p1) as i32 * *scale.add(1) as i32
+                + vaddvq_s16(p2) as i32 * *scale.add(2) as i32
+                + vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
            scale = scale.add(4);

            let q3h_0 = vbicq_u8(m2, qhbits.0);
@@ -496,14 +618,27 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Result<f32> {
                vreinterpretq_s8_u8(q3h_3),
            );

-            let p0 = vdotq_s32(q3bytes_0, q8bytes_2.0);
-            let p1 = vdotq_s32(q3bytes_1, q8bytes_2.1);
-            let p2 = vdotq_s32(q3bytes_2, q8bytes_2.2);
-            let p3 = vdotq_s32(q3bytes_3, q8bytes_2.3);
-            isum += vaddvq_s32(p0) * *scale as i32
-                + vaddvq_s32(p1) * *scale.add(1) as i32
-                + vaddvq_s32(p2) * *scale.add(2) as i32
-                + vaddvq_s32(p3) * *scale.add(3) as i32;
+            // TODO: dotprod
+            let p0 = vaddq_s16(
+                vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_2.0)),
+                vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_2.0)),
+            );
+            let p1 = vaddq_s16(
+                vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_2.1)),
+                vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_2.1)),
+            );
+            let p2 = vaddq_s16(
+                vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_2.2)),
+                vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_2.2)),
+            );
+            let p3 = vaddq_s16(
+                vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_2.3)),
+                vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_2.3)),
+            );
+            isum += vaddvq_s16(p0) as i32 * *scale as i32
+                + vaddvq_s16(p1) as i32 * *scale.add(1) as i32
+                + vaddvq_s16(p2) as i32 * *scale.add(2) as i32
+                + vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
            scale = scale.add(4);

            if j == 0 {
@@ -561,6 +696,7 @@ pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
        let mut is = 0usize;

        // TODO: dotprod
+
        for _j in 0..QK_K / 128 {
            let q2bits = vld1q_u8_x2(q2);
            q2 = q2.add(32);
@@ -607,7 +743,14 @@ unsafe fn multiply_accum_with_scale(
    q2bytes: int8x16x2_t,
    q8bytes: int8x16x2_t,
 ) -> i32 {
-    let p1 = vdotq_s32(q2bytes.0, q8bytes.0);
-    let p2 = vdotq_s32(q2bytes.1, q8bytes.1);
-    vaddvq_s32(p1) * aux[is + index] as i32 + vaddvq_s32(p2) * aux[is + 1 + index] as i32
+    let p1 = vaddq_s16(
+        vmull_s8(vget_low_s8(q2bytes.0), vget_low_s8(q8bytes.0)),
+        vmull_s8(vget_high_s8(q2bytes.0), vget_high_s8(q8bytes.0)),
+    );
+    let p2 = vaddq_s16(
+        vmull_s8(vget_low_s8(q2bytes.1), vget_low_s8(q8bytes.1)),
+        vmull_s8(vget_high_s8(q2bytes.1), vget_high_s8(q8bytes.1)),
+    );
+    vaddvq_s16(p1) as i32 * aux[is + index] as i32
+        + vaddvq_s16(p2) as i32 * aux[is + 1 + index] as i32
 }
@@ -11,6 +11,10 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
    if n % QK8_0 != 0 {
        crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
    }
+    let nb = n / QK8_0;
+    if nb % 2 != 0 {
+        crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
+    }
    unsafe {
        let mut acc = f32x4_splat(0.0f32);
        for (x, y) in xs.iter().zip(ys.iter()) {
@@ -57,6 +61,10 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
    if n % QK8_0 != 0 {
        crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
    }
+    let nb = n / QK8_0;
+    if nb % 2 != 0 {
+        crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even")
+    }
    unsafe {
        let mut acc = f32x4_splat(0.0f32);
        for (x, y) in xs.iter().zip(ys.iter()) {
@@ -349,30 +349,6 @@ impl MmapedSafetensors {
    }
 }

-pub struct SliceSafetensors<'a> {
-    safetensors: SafeTensors<'a>,
-}
-
-impl<'a> SliceSafetensors<'a> {
-    /// Creates a wrapper around a binary buffer and deserializes the safetensors header.
-    pub fn new(buffer: &'a [u8]) -> Result<Self> {
-        let safetensors = safetensors::SafeTensors::deserialize(buffer)?;
-        Ok(Self { safetensors })
-    }
-
-    pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
-        self.safetensors.tensor(name)?.load(dev)
-    }
-
-    pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
-        self.safetensors.tensors()
-    }
-
-    pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
-        Ok(self.safetensors.tensor(name)?)
-    }
-}
-
 pub struct BufferedSafetensors {
    safetensors: yoke::Yoke<SafeTensors_<'static>, Vec<u8>>,
 }
@@ -171,7 +171,7 @@ impl Shape {
        }
        let mut acc = 1;
        for (&stride, &dim) in stride.iter().zip(self.0.iter()).rev() {
-            if dim > 1 && stride != acc {
+            if stride != acc {
                return false;
            }
            acc *= dim;
@@ -186,7 +186,7 @@ impl Shape {
        }
        let mut acc = 1;
        for (&stride, &dim) in stride.iter().zip(self.0.iter()) {
-            if dim > 1 && stride != acc {
+            if stride != acc {
                return false;
            }
            acc *= dim;
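The removed condition skips strides of size-1 dimensions, which never contribute to an offset: shape (2, 1, 3) with strides (3, 17, 1) walks memory exactly like a contiguous layout. The added, stricter condition would reject it. The lenient row-major check, restated standalone:

```rust
/// Row-major contiguity check matching the removed-side semantics: strides of
/// size-1 dims are ignored since they never affect an element's offset.
fn is_contiguous(dims: &[usize], strides: &[usize]) -> bool {
    let mut acc = 1;
    for (&stride, &dim) in strides.iter().zip(dims.iter()).rev() {
        if dim > 1 && stride != acc {
            return false;
        }
        acc *= dim;
    }
    true
}

fn main() {
    // The middle dim has size 1, so its stride (17) is irrelevant.
    assert!(is_contiguous(&[2, 1, 3], &[3, 17, 1]));
    assert!(!is_contiguous(&[2, 3], &[1, 2])); // column-major, not row-major
}
```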
@@ -203,7 +203,7 @@ impl Shape {

    /// Check whether the two shapes are compatible for broadcast, and if it is the case return the
    /// broadcasted shape. This is to be used for binary pointwise ops.
-    pub fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
+    pub(crate) fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
        let lhs = self;
        let lhs_dims = lhs.dims();
        let rhs_dims = rhs.dims();
@@ -478,139 +478,6 @@ extract_dims!(
     (usize, usize, usize, usize, usize)
 );
 
-pub trait ShapeWithOneHole {
-    fn into_shape(self, el_count: usize) -> Result<Shape>;
-}
-
-impl<S: Into<Shape>> ShapeWithOneHole for S {
-    fn into_shape(self, _el_count: usize) -> Result<Shape> {
-        Ok(self.into())
-    }
-}
-
-impl ShapeWithOneHole for ((),) {
-    fn into_shape(self, el_count: usize) -> Result<Shape> {
-        Ok(el_count.into())
-    }
-}
-
-fn hole_size(el_count: usize, prod_d: usize, s: &dyn std::fmt::Debug) -> Result<usize> {
-    if prod_d == 0 {
-        crate::bail!("cannot reshape tensor of {el_count} elements to {s:?}")
-    }
-    if el_count % prod_d != 0 {
-        crate::bail!("cannot reshape tensor with {el_count} elements to {s:?}")
-    }
-    Ok(el_count / prod_d)
-}
-
-impl ShapeWithOneHole for ((), usize) {
-    fn into_shape(self, el_count: usize) -> Result<Shape> {
-        let ((), d1) = self;
-        Ok((hole_size(el_count, d1, &self)?, d1).into())
-    }
-}
-
-impl ShapeWithOneHole for (usize, ()) {
-    fn into_shape(self, el_count: usize) -> Result<Shape> {
-        let (d1, ()) = self;
-        Ok((d1, hole_size(el_count, d1, &self)?).into())
-    }
-}
-
-impl ShapeWithOneHole for ((), usize, usize) {
-    fn into_shape(self, el_count: usize) -> Result<Shape> {
-        let ((), d1, d2) = self;
-        Ok((hole_size(el_count, d1 * d2, &self)?, d1, d2).into())
-    }
-}
-
-impl ShapeWithOneHole for (usize, (), usize) {
-    fn into_shape(self, el_count: usize) -> Result<Shape> {
-        let (d1, (), d2) = self;
-        Ok((d1, hole_size(el_count, d1 * d2, &self)?, d2).into())
-    }
-}
-
-impl ShapeWithOneHole for (usize, usize, ()) {
-    fn into_shape(self, el_count: usize) -> Result<Shape> {
-        let (d1, d2, ()) = self;
-        Ok((d1, d2, hole_size(el_count, d1 * d2, &self)?).into())
-    }
-}
-
-impl ShapeWithOneHole for ((), usize, usize, usize) {
-    fn into_shape(self, el_count: usize) -> Result<Shape> {
-        let ((), d1, d2, d3) = self;
-        let d = hole_size(el_count, d1 * d2 * d3, &self)?;
-        Ok((d, d1, d2, d3).into())
-    }
-}
-
-impl ShapeWithOneHole for (usize, (), usize, usize) {
-    fn into_shape(self, el_count: usize) -> Result<Shape> {
-        let (d1, (), d2, d3) = self;
-        let d = hole_size(el_count, d1 * d2 * d3, &self)?;
-        Ok((d1, d, d2, d3).into())
-    }
-}
-
-impl ShapeWithOneHole for (usize, usize, (), usize) {
-    fn into_shape(self, el_count: usize) -> Result<Shape> {
-        let (d1, d2, (), d3) = self;
-        let d = hole_size(el_count, d1 * d2 * d3, &self)?;
-        Ok((d1, d2, d, d3).into())
-    }
-}
-
-impl ShapeWithOneHole for (usize, usize, usize, ()) {
-    fn into_shape(self, el_count: usize) -> Result<Shape> {
-        let (d1, d2, d3, ()) = self;
-        let d = hole_size(el_count, d1 * d2 * d3, &self)?;
-        Ok((d1, d2, d3, d).into())
-    }
-}
-
-impl ShapeWithOneHole for ((), usize, usize, usize, usize) {
-    fn into_shape(self, el_count: usize) -> Result<Shape> {
-        let ((), d1, d2, d3, d4) = self;
-        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
-        Ok((d, d1, d2, d3, d4).into())
-    }
-}
-
-impl ShapeWithOneHole for (usize, (), usize, usize, usize) {
-    fn into_shape(self, el_count: usize) -> Result<Shape> {
-        let (d1, (), d2, d3, d4) = self;
-        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
-        Ok((d1, d, d2, d3, d4).into())
-    }
-}
-
-impl ShapeWithOneHole for (usize, usize, (), usize, usize) {
-    fn into_shape(self, el_count: usize) -> Result<Shape> {
-        let (d1, d2, (), d3, d4) = self;
-        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
-        Ok((d1, d2, d, d3, d4).into())
-    }
-}
-
-impl ShapeWithOneHole for (usize, usize, usize, (), usize) {
-    fn into_shape(self, el_count: usize) -> Result<Shape> {
-        let (d1, d2, d3, (), d4) = self;
-        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
-        Ok((d1, d2, d3, d, d4).into())
-    }
-}
-
-impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
-    fn into_shape(self, el_count: usize) -> Result<Shape> {
-        let (d1, d2, d3, d4, ()) = self;
-        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
-        Ok((d1, d2, d3, d4, d).into())
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -627,3 +494,171 @@ mod tests {
         assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
     }
 }
+
+pub trait ShapeWithOneHole {
+    fn into_shape(self, el_count: usize) -> Result<Shape>;
+}
+
+impl<S: Into<Shape>> ShapeWithOneHole for S {
+    fn into_shape(self, _el_count: usize) -> Result<Shape> {
+        Ok(self.into())
+    }
+}
+
+impl ShapeWithOneHole for ((),) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        Ok(el_count.into())
+    }
+}
+
+impl ShapeWithOneHole for ((), usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let ((), d1) = self;
+        if el_count % d1 != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
+        }
+        Ok((el_count / d1, d1).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, ()) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, ()) = self;
+        if el_count % d1 != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
+        }
+        Ok((d1, el_count / d1).into())
+    }
+}
+
+impl ShapeWithOneHole for ((), usize, usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let ((), d1, d2) = self;
+        let d = d1 * d2;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((el_count / d, d1, d2).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, (), usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, (), d2) = self;
+        let d = d1 * d2;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, el_count / d, d2).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, usize, ()) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, d2, ()) = self;
+        let d = d1 * d2;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, d2, el_count / d).into())
+    }
+}
+
+impl ShapeWithOneHole for ((), usize, usize, usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let ((), d1, d2, d3) = self;
+        let d = d1 * d2 * d3;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((el_count / d, d1, d2, d3).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, (), usize, usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, (), d2, d3) = self;
+        let d = d1 * d2 * d3;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, el_count / d, d2, d3).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, usize, (), usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, d2, (), d3) = self;
+        let d = d1 * d2 * d3;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, d2, el_count / d, d3).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, usize, usize, ()) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, d2, d3, ()) = self;
+        let d = d1 * d2 * d3;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, d2, d3, el_count / d).into())
+    }
+}
+
+impl ShapeWithOneHole for ((), usize, usize, usize, usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let ((), d1, d2, d3, d4) = self;
+        let d = d1 * d2 * d3 * d4;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((el_count / d, d1, d2, d3, d4).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, (), usize, usize, usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, (), d2, d3, d4) = self;
+        let d = d1 * d2 * d3 * d4;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, el_count / d, d2, d3, d4).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, usize, (), usize, usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, d2, (), d3, d4) = self;
+        let d = d1 * d2 * d3 * d4;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, d2, el_count / d, d3, d4).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, usize, usize, (), usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, d2, d3, (), d4) = self;
+        let d = d1 * d2 * d3 * d4;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, d2, d3, el_count / d, d4).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, d2, d3, d4, ()) = self;
+        let d = d1 * d2 * d3 * d4;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, d2, d3, d4, el_count / d).into())
+    }
+}
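
Both sides implement the same reshape-with-a-hole idea: a `()` in the shape tuple stands for whichever dimension makes the element count come out right; the left side funnels the arithmetic through `hole_size`, the right side inlines the divisibility check into each impl. A minimal sketch of that arithmetic (the `infer_hole` helper is illustrative, not a crate API):

    // Sketch of the hole-size computation shared by both versions above;
    // `infer_hole` is an illustrative name, not a candle API.
    fn infer_hole(el_count: usize, known_dims: &[usize]) -> Result<usize, String> {
        let d: usize = known_dims.iter().product();
        if d == 0 || el_count % d != 0 {
            return Err(format!(
                "cannot place {el_count} elements into a shape with known dims {known_dims:?}"
            ));
        }
        Ok(el_count / d)
    }

    fn main() {
        // 12 elements reshaped to ((), 3, 2): the hole must be 12 / (3 * 2) = 2.
        assert_eq!(infer_hole(12, &[3, 2]), Ok(2));
        // 12 elements cannot be reshaped to ((), 5): 12 is not divisible by 5.
        assert!(infer_hole(12, &[5]).is_err());
    }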
@@ -1,239 +0,0 @@
-use crate::{Result, Tensor};
-use rayon::prelude::*;
-
-#[derive(Debug, Clone, Copy)]
-struct ArgSort {
-    asc: bool,
-    last_dim: usize,
-}
-
-impl ArgSort {
-    fn asort<T: crate::WithDType>(&self, vs: &[T], layout: &crate::Layout) -> Vec<u32> {
-        #[allow(clippy::uninit_vec)]
-        // Safety: indexes are set later in the parallelized section.
-        let mut sort_indexes = unsafe {
-            let el_count = layout.shape().elem_count();
-            let mut v = Vec::with_capacity(el_count);
-            v.set_len(el_count);
-            v
-        };
-        if self.asc {
-            sort_indexes
-                .par_chunks_exact_mut(self.last_dim)
-                .zip(vs.par_chunks_exact(self.last_dim))
-                .for_each(|(indexes, vs)| {
-                    indexes
-                        .iter_mut()
-                        .enumerate()
-                        .for_each(|(i, v)| *v = i as u32);
-                    indexes.sort_by(|&i, &j| {
-                        vs[i as usize]
-                            .partial_cmp(&vs[j as usize])
-                            .unwrap_or(std::cmp::Ordering::Greater)
-                    })
-                });
-        } else {
-            sort_indexes
-                .par_chunks_exact_mut(self.last_dim)
-                .zip(vs.par_chunks_exact(self.last_dim))
-                .for_each(|(indexes, vs)| {
-                    indexes
-                        .iter_mut()
-                        .enumerate()
-                        .for_each(|(i, v)| *v = i as u32);
-                    indexes.sort_by(|&j, &i| {
-                        vs[i as usize]
-                            .partial_cmp(&vs[j as usize])
-                            .unwrap_or(std::cmp::Ordering::Greater)
-                    })
-                });
-        }
-        sort_indexes
-    }
-}
-
-impl crate::CustomOp1 for ArgSort {
-    fn name(&self) -> &'static str {
-        "argsort"
-    }
-
-    fn cpu_fwd(
-        &self,
-        storage: &crate::CpuStorage,
-        layout: &crate::Layout,
-    ) -> Result<(crate::CpuStorage, crate::Shape)> {
-        let sort_indexes = match storage {
-            crate::CpuStorage::U8(vs) => self.asort(vs, layout),
-            crate::CpuStorage::U32(vs) => self.asort(vs, layout),
-            crate::CpuStorage::I64(vs) => self.asort(vs, layout),
-            crate::CpuStorage::BF16(vs) => self.asort(vs, layout),
-            crate::CpuStorage::F16(vs) => self.asort(vs, layout),
-            crate::CpuStorage::F32(vs) => self.asort(vs, layout),
-            crate::CpuStorage::F64(vs) => self.asort(vs, layout),
-        };
-        let sort_indexes = crate::CpuStorage::U32(sort_indexes);
-        Ok((sort_indexes, layout.shape().into()))
-    }
-
-    #[cfg(feature = "cuda")]
-    fn cuda_fwd(
-        &self,
-        storage: &crate::CudaStorage,
-        layout: &crate::Layout,
-    ) -> Result<(crate::CudaStorage, crate::Shape)> {
-        use crate::cuda_backend::cudarc::driver::{
-            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
-        };
-        use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, Map1Any, WrapErr};
-        use crate::{CudaDevice, WithDType};
-
-        impl Map1Any for ArgSort {
-            fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
-                &self,
-                src: &CudaSlice<T>,
-                dev: &CudaDevice,
-                layout: &crate::Layout,
-                _wrap: W,
-            ) -> Result<S> {
-                let slice = match layout.contiguous_offsets() {
-                    None => crate::bail!("input has to be contiguous"),
-                    Some((o1, o2)) => src.slice(o1..o2),
-                };
-                let elem_count = layout.shape().elem_count();
-                let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
-                let func = if self.asc {
-                    dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
-                } else {
-                    dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
-                };
-                let ncols = self.last_dim;
-                let nrows = elem_count / ncols;
-                let ncols_pad = next_power_of_2(ncols);
-                let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
-                let cfg = LaunchConfig {
-                    grid_dim: (1, nrows as u32, 1),
-                    block_dim: (ncols_pad as u32, 1, 1),
-                    shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
-                };
-                unsafe { func.launch(cfg, params) }.w()?;
-                Ok(S::U32(dst))
-            }
-        }
-
-        use crate::backend::BackendStorage;
-        let dev = storage.device();
-        let slice = self.map(&storage.slice, dev, layout)?;
-        let dst = crate::cuda_backend::CudaStorage {
-            slice,
-            device: dev.clone(),
-        };
-        Ok((dst, layout.shape().clone()))
-    }
-
-    #[cfg(feature = "metal")]
-    fn metal_fwd(
-        &self,
-        storage: &crate::MetalStorage,
-        layout: &crate::Layout,
-    ) -> Result<(crate::MetalStorage, crate::Shape)> {
-        use crate::backend::BackendStorage;
-        use crate::DType;
-
-        let name = {
-            if self.asc {
-                match storage.dtype() {
-                    DType::BF16 => "asort_asc_bf16",
-                    DType::F16 => "asort_asc_f16",
-                    DType::F32 => "asort_asc_f32",
-                    DType::F64 => "asort_asc_f64",
-                    DType::U8 => "asort_asc_u8",
-                    DType::U32 => "asort_asc_u32",
-                    DType::I64 => "asort_asc_i64",
-                }
-            } else {
-                match storage.dtype() {
-                    DType::BF16 => "asort_desc_bf16",
-                    DType::F16 => "asort_desc_f16",
-                    DType::F32 => "asort_desc_f32",
-                    DType::F64 => "asort_desc_f64",
-                    DType::U8 => "asort_desc_u8",
-                    DType::U32 => "asort_desc_u32",
-                    DType::I64 => "asort_desc_i64",
-                }
-            }
-        };
-        let device = storage.device();
-        let kernels = device.kernels();
-        let command_buffer = device.command_buffer()?;
-        let el = layout.shape().elem_count();
-        let ncols = self.last_dim;
-        let nrows = el / ncols;
-        let src = crate::metal_backend::buffer_o(storage.buffer(), layout, storage.dtype());
-        let dst = device.new_buffer(el, DType::U32, "asort")?;
-        let mut ncols_pad = 1;
-        while ncols_pad < ncols {
-            ncols_pad *= 2;
-        }
-        candle_metal_kernels::call_arg_sort(
-            device.metal_device(),
-            &command_buffer,
-            kernels,
-            name,
-            nrows,
-            ncols,
-            ncols_pad,
-            src,
-            &dst,
-        )
-        .map_err(crate::Error::wrap)?;
-        let dst = crate::MetalStorage::new(dst, device.clone(), el, DType::U32);
-        Ok((dst, layout.shape().clone()))
-    }
-}
-
-#[allow(unused)]
-fn next_power_of_2(x: usize) -> usize {
-    let mut n = 1;
-    while n < x {
-        n *= 2
-    }
-    n
-}
-
-impl Tensor {
-    /// Returns the indices that sort the tensor along the last dimension.
-    ///
-    /// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
-    /// descending order. The sort is unstable so there is no guarantees on the final order when it
-    /// comes to ties.
-    pub fn arg_sort_last_dim(&self, asc: bool) -> Result<Tensor> {
-        if !self.is_contiguous() {
-            return Err(crate::Error::RequiresContiguous {
-                op: "arg_sort_last_dim",
-            });
-        }
-        let last_dim = match self.dims().last() {
-            None => crate::bail!("empty last-dim in arg-sort"),
-            Some(last_dim) => *last_dim,
-        };
-        // No need for a backward pass for arg sort.
-        self.apply_op1_no_bwd(&ArgSort { asc, last_dim })
-    }
-
-    /// Sorts the tensor along the last dimension, returns the sorted tensor together with the
-    /// sorted indexes.
-    ///
-    /// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
-    /// descending order. The sort is unstable so there is no guarantees on the final order when it
-    /// comes to ties.
-    pub fn sort_last_dim(&self, asc: bool) -> Result<(Tensor, Tensor)> {
-        if !self.is_contiguous() {
-            return Err(crate::Error::RequiresContiguous {
-                op: "sort_last_dim",
-            });
-        }
-        let asort = self.arg_sort_last_dim(asc)?;
-        let sorted = self.gather(&asort, crate::D::Minus1)?;
-        Ok((sorted, asort))
-    }
-}
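
The deleted file backed `Tensor::arg_sort_last_dim` and `Tensor::sort_last_dim`, which are available on the 0.5.1 side of the comparison. A short usage sketch against that API:

    // Sketch against the candle 0.5.1 API documented in the deleted code above.
    use candle_core::{Device, Result, Tensor};

    fn main() -> Result<()> {
        let t = Tensor::new(&[[3f32, 1., 2.], [0., 5., 4.]], &Device::Cpu)?;
        // Indices that would sort each row in ascending order.
        let idx = t.arg_sort_last_dim(true)?;
        assert_eq!(idx.to_vec2::<u32>()?, [[1, 2, 0], [0, 2, 1]]);
        // Sorted values together with the sorting indices.
        let (sorted, idx) = t.sort_last_dim(true)?;
        assert_eq!(sorted.to_vec2::<f32>()?, [[1., 2., 3.], [0., 4., 5.]]);
        assert_eq!(idx.to_vec2::<u32>()?, [[1, 2, 0], [0, 2, 1]]);
        Ok(())
    }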
@@ -1,7 +1,6 @@
 use crate::backend::BackendStorage;
-use crate::op::{self, CmpOp, ReduceOp};
-use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
-use crate::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
+use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
+use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
 
 // We do not want to implement Clone on Storage as cloning may fail because of
 // out of memory. Instead try_clone should be used.
@@ -9,7 +8,6 @@ use crate::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3}
 pub enum Storage {
     Cpu(CpuStorage),
     Cuda(CudaStorage),
-    Metal(MetalStorage),
 }
 
 impl Storage {
@@ -20,10 +18,6 @@ impl Storage {
                 let storage = storage.try_clone(layout)?;
                 Ok(Self::Cuda(storage))
             }
-            Self::Metal(storage) => {
-                let storage = storage.try_clone(layout)?;
-                Ok(Self::Metal(storage))
-            }
         }
     }
 
@@ -31,7 +25,6 @@ impl Storage {
         match self {
            Self::Cpu(_) => Device::Cpu,
            Self::Cuda(storage) => Device::Cuda(storage.device().clone()),
-           Self::Metal(storage) => Device::Metal(storage.device().clone()),
         }
     }
 
@@ -39,24 +32,13 @@ impl Storage {
         match self {
            Self::Cpu(storage) => storage.dtype(),
            Self::Cuda(storage) => storage.dtype(),
-           Self::Metal(storage) => storage.dtype(),
         }
     }
 
     pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
-        let lhs_device = self.device();
-        let rhs_device = rhs.device();
-        let lhs = lhs_device.location();
-        let rhs = rhs_device.location();
-        let same_device = if self.device().is_metal() {
-            // On metal, we require the device to be exactly the same rather than
-            // having the same location. In cuda this is not necessary as all CudaDevice on the
-            // same GPU will use the same cuda stream.
-            lhs_device.same_device(&rhs_device)
-        } else {
-            lhs == rhs
-        };
-        if !same_device {
+        let lhs = self.device().location();
+        let rhs = rhs.device().location();
+        if lhs != rhs {
            Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }.bt())
        } else {
            Ok(())
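
The removed branch encodes a Metal-specific policy spelled out in its comment: on Metal two storages must be the very same device object, while elsewhere matching locations suffice because all CudaDevice handles on one GPU share a stream. A toy model of that decision (stand-in types, not candle's Device):

    // Illustrative device-equality policy mirroring the removed comment;
    // the types here are stand-ins, not candle's Device/DeviceLocation.
    #[derive(Clone, Copy, PartialEq)]
    enum Location {
        Cpu,
        Cuda { gpu_id: usize },
        Metal { gpu_id: usize },
    }

    struct Device {
        location: Location,
        // Unique per device object; two Metal devices on the same GPU may
        // still own distinct command queues, so identity matters there.
        instance_id: usize,
    }

    fn same_device(lhs: &Device, rhs: &Device) -> bool {
        match lhs.location {
            // Exact object identity required on Metal.
            Location::Metal { .. } => lhs.instance_id == rhs.instance_id,
            // Matching location is enough elsewhere.
            _ => lhs.location == rhs.location,
        }
    }

    fn main() {
        let a = Device { location: Location::Metal { gpu_id: 0 }, instance_id: 1 };
        let b = Device { location: Location::Metal { gpu_id: 0 }, instance_id: 2 };
        assert!(!same_device(&a, &b)); // same location, different Metal objects
    }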
@@ -83,10 +65,6 @@ impl Storage {
                 let storage = storage.affine(layout, mul, add)?;
                 Ok(Self::Cuda(storage))
             }
-            Self::Metal(storage) => {
-                let storage = storage.affine(layout, mul, add)?;
-                Ok(Self::Metal(storage))
-            }
         }
     }
 
@@ -100,10 +78,6 @@ impl Storage {
                 let storage = storage.powf(layout, alpha)?;
                 Ok(Self::Cuda(storage))
             }
-            Self::Metal(storage) => {
-                let storage = storage.powf(layout, alpha)?;
-                Ok(Self::Metal(storage))
-            }
         }
     }
 
@@ -117,10 +91,6 @@ impl Storage {
                 let storage = storage.elu(layout, alpha)?;
                 Ok(Self::Cuda(storage))
             }
-            Self::Metal(storage) => {
-                let storage = storage.elu(layout, alpha)?;
-                Ok(Self::Metal(storage))
-            }
         }
     }
 
@@ -142,10 +112,6 @@ impl Storage {
                 let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
                 Ok(Self::Cuda(storage))
             }
-            (Self::Metal(lhs), Self::Metal(rhs)) => {
-                let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
-                Ok(Self::Metal(storage))
-            }
             (lhs, rhs) => {
                 // Should not happen because of the same device check above but we're defensive
                 // anyway.
@@ -169,10 +135,6 @@ impl Storage {
                 let storage = storage.reduce_op(op, layout, s)?;
                 Ok(Self::Cuda(storage))
             }
-            Self::Metal(storage) => {
-                let storage = storage.reduce_op(op, layout, s)?;
-                Ok(Self::Metal(storage))
-            }
         }
     }
 
@@ -186,10 +148,6 @@ impl Storage {
                 let storage = storage.to_dtype(layout, dtype)?;
                 Ok(Self::Cuda(storage))
             }
-            Self::Metal(storage) => {
-                let storage = storage.to_dtype(layout, dtype)?;
-                Ok(Self::Metal(storage))
-            }
         }
     }
 
@@ -203,10 +161,6 @@ impl Storage {
                 let (storage, shape) = c.cuda_fwd(storage, l)?;
                 Ok((Self::Cuda(storage), shape))
             }
-            Self::Metal(storage) => {
-                let (storage, shape) = c.metal_fwd(storage, l)?;
-                Ok((Self::Metal(storage), shape))
-            }
         }
     }
 
@@ -227,10 +181,6 @@ impl Storage {
                 let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?;
                 Ok((Self::Cuda(s), shape))
             }
-            (Self::Metal(s1), Self::Metal(s2)) => {
-                let (s, shape) = c.metal_fwd(s1, l1, s2, l2)?;
-                Ok((Self::Metal(s), shape))
-            }
             _ => unreachable!(),
         }
     }
@@ -255,55 +205,6 @@ impl Storage {
                 let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?;
                 Ok((Self::Cuda(s), shape))
             }
-            (Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
-                let (s, shape) = c.metal_fwd(s1, l1, s2, l2, s3, l3)?;
-                Ok((Self::Metal(s), shape))
-            }
-            _ => unreachable!(),
-        }
-    }
-
-    pub(crate) fn inplace_op1(&mut self, l: &Layout, c: &dyn InplaceOp1) -> Result<()> {
-        match self {
-            Self::Cpu(storage) => c.cpu_fwd(storage, l),
-            Self::Cuda(storage) => c.cuda_fwd(storage, l),
-            Self::Metal(storage) => c.metal_fwd(storage, l),
-        }
-    }
-
-    pub(crate) fn inplace_op2(
-        &mut self,
-        l1: &Layout,
-        t2: &Self,
-        l2: &Layout,
-        c: &dyn InplaceOp2,
-    ) -> Result<()> {
-        self.same_device(t2, c.name())?;
-        match (self, t2) {
-            (Self::Cpu(s1), Self::Cpu(s2)) => c.cpu_fwd(s1, l1, s2, l2),
-            (Self::Cuda(s1), Self::Cuda(s2)) => c.cuda_fwd(s1, l1, s2, l2),
-            (Self::Metal(s1), Self::Metal(s2)) => c.metal_fwd(s1, l1, s2, l2),
-            _ => unreachable!(),
-        }
-    }
-
-    pub(crate) fn inplace_op3(
-        &mut self,
-        l1: &Layout,
-        t2: &Self,
-        l2: &Layout,
-        t3: &Self,
-        l3: &Layout,
-        c: &dyn InplaceOp3,
-    ) -> Result<()> {
-        self.same_device(t2, c.name())?;
-        self.same_device(t3, c.name())?;
-        match (self, t2, t3) {
-            (Self::Cpu(s1), Self::Cpu(s2), Self::Cpu(s3)) => c.cpu_fwd(s1, l1, s2, l2, s3, l3),
-            (Self::Cuda(s1), Self::Cuda(s2), Self::Cuda(s3)) => c.cuda_fwd(s1, l1, s2, l2, s3, l3),
-            (Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
-                c.metal_fwd(s1, l1, s2, l2, s3, l3)
-            }
             _ => unreachable!(),
         }
     }
@@ -318,10 +219,6 @@ impl Storage {
                 let storage = storage.unary_impl::<B>(layout)?;
                 Ok(Self::Cuda(storage))
             }
-            Self::Metal(storage) => {
-                let storage = storage.unary_impl::<B>(layout)?;
-                Ok(Self::Metal(storage))
-            }
         }
     }
 
@@ -342,10 +239,6 @@ impl Storage {
                 let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
                 Ok(Self::Cuda(storage))
             }
-            (Self::Metal(lhs), Self::Metal(rhs)) => {
-                let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
-                Ok(Self::Metal(storage))
-            }
             (lhs, rhs) => {
                 // Should not happen because of the same device check above but we're defensive
                 // anyway.
@@ -377,10 +270,6 @@ impl Storage {
                 let s = inp.conv1d(l, kernel, kernel_l, params)?;
                 Ok(Self::Cuda(s))
             }
-            (Storage::Metal(inp), Storage::Metal(kernel)) => {
-                let s = inp.conv1d(l, kernel, kernel_l, params)?;
-                Ok(Self::Metal(s))
-            }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
@@ -390,37 +279,6 @@ impl Storage {
         }
     }
 
-    pub(crate) fn conv_transpose1d(
-        &self,
-        l: &Layout,
-        kernel: &Self,
-        kernel_l: &Layout,
-        params: &crate::conv::ParamsConvTranspose1D,
-    ) -> Result<Self> {
-        self.same_device(kernel, "conv-transpose1d")?;
-        self.same_dtype(kernel, "conv-transpose1d")?;
-        match (self, &kernel) {
-            (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
-                let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
-                Ok(Self::Cpu(s))
-            }
-            (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
-                let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
-                Ok(Self::Cuda(s))
-            }
-            (Storage::Metal(inp), Storage::Metal(kernel)) => {
-                let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
-                Ok(Self::Metal(s))
-            }
-            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
-                lhs: lhs.device().location(),
-                rhs: rhs.device().location(),
-                op: "conv-transpose1d",
-            }
-            .bt()),
-        }
-    }
-
     pub(crate) fn conv2d(
         &self,
         l: &Layout,
@@ -439,10 +297,6 @@ impl Storage {
                 let s = inp.conv2d(l, kernel, kernel_l, params)?;
                 Ok(Self::Cuda(s))
             }
-            (Storage::Metal(inp), Storage::Metal(kernel)) => {
-                let s = inp.conv2d(l, kernel, kernel_l, params)?;
-                Ok(Self::Metal(s))
-            }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
@@ -470,10 +324,6 @@ impl Storage {
                 let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
                 Ok(Self::Cuda(s))
             }
-            (Storage::Metal(inp), Storage::Metal(kernel)) => {
-                let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
-                Ok(Self::Metal(s))
-            }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
@@ -498,10 +348,6 @@ impl Storage {
                 let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
                 Ok(Self::Cuda(storage))
             }
-            Self::Metal(storage) => {
-                let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
-                Ok(Self::Metal(storage))
-            }
         }
     }
 
@@ -520,10 +366,6 @@ impl Storage {
                 let storage = storage.max_pool2d(layout, kernel_size, stride)?;
                 Ok(Self::Cuda(storage))
             }
-            Self::Metal(storage) => {
-                let storage = storage.max_pool2d(layout, kernel_size, stride)?;
-                Ok(Self::Metal(storage))
-            }
         }
     }
 
@@ -537,10 +379,6 @@ impl Storage {
                 let storage = storage.upsample_nearest1d(layout, sz)?;
                 Ok(Self::Cuda(storage))
             }
-            Self::Metal(storage) => {
-                let storage = storage.upsample_nearest1d(layout, sz)?;
-                Ok(Self::Metal(storage))
-            }
         }
     }
 
@@ -554,10 +392,6 @@ impl Storage {
                 let storage = storage.upsample_nearest2d(layout, h, w)?;
                 Ok(Self::Cuda(storage))
             }
-            Self::Metal(storage) => {
-                let storage = storage.upsample_nearest2d(layout, h, w)?;
-                Ok(Self::Metal(storage))
-            }
         }
     }
 
@@ -581,10 +415,6 @@ impl Storage {
                 let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
                 Ok(Self::Cuda(storage))
             }
-            (Self::Metal(cond), Self::Metal(t), Self::Metal(f)) => {
-                let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
-                Ok(Self::Metal(storage))
-            }
             (_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
@@ -611,10 +441,6 @@ impl Storage {
                 let storage = s.gather(l, indexes, indexes_l, d)?;
                 Ok(Self::Cuda(storage))
             }
-            (Self::Metal(s), Self::Metal(indexes)) => {
-                let storage = s.gather(l, indexes, indexes_l, d)?;
-                Ok(Self::Metal(storage))
-            }
             _ => unreachable!(),
         }
     }
@@ -639,10 +465,6 @@ impl Storage {
                 let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
                 Ok(Self::Cuda(storage))
             }
-            (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
-                let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
-                Ok(Self::Metal(storage))
-            }
             _ => unreachable!(),
         }
     }
@@ -667,10 +489,6 @@ impl Storage {
                 let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
                 Ok(Self::Cuda(storage))
             }
-            (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
-                let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
-                Ok(Self::Metal(storage))
-            }
             _ => unreachable!(),
         }
     }
@@ -692,10 +510,6 @@ impl Storage {
                 let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
                 Ok(Self::Cuda(storage))
             }
-            (Self::Metal(lhs), Self::Metal(rhs)) => {
-                let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
-                Ok(Self::Metal(storage))
-            }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
@@ -723,10 +537,6 @@ impl Storage {
                 let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
                 Ok(Self::Cuda(storage))
             }
-            (Self::Metal(lhs), Self::Metal(rhs)) => {
-                let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
-                Ok(Self::Metal(storage))
-            }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
@@ -746,9 +556,6 @@ impl Storage {
         match (self, dst) {
             (Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l),
             (Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
-            (Self::Metal(src), Self::Metal(dst)) => {
-                Ok(src.copy_strided_src(dst, dst_offset, src_l)?)
-            }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
@@ -757,32 +564,4 @@ impl Storage {
             .bt()),
         }
     }
-
-    #[allow(clippy::too_many_arguments)]
-    pub(crate) fn copy2d(
-        &self,
-        dst: &mut Self,
-        d1: usize,
-        d2: usize,
-        src_s: usize,
-        dst_s: usize,
-        src_o: usize,
-        dst_o: usize,
-    ) -> Result<()> {
-        match (self, dst) {
-            (Self::Cpu(src), Self::Cpu(dst)) => src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o),
-            (Self::Cuda(src), Self::Cuda(dst)) => {
-                Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
-            }
-            (Self::Metal(src), Self::Metal(dst)) => {
-                Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
-            }
-            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
-                lhs: lhs.device().location(),
-                rhs: rhs.device().location(),
-                op: "copy2d",
-            }
-            .bt()),
-        }
-    }
 }
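
The removed `copy2d` copies `d1` rows of `d2` contiguous elements between buffers whose row strides differ; the `cat_contiguous` and `slice_set` code deleted further below is built on it. A CPU reference of the indexing it performs (a sketch, not the backend kernels):

    // Reference semantics for the removed copy2d: d1 rows of d2 elements,
    // read with row stride src_s from offset src_o, written with row stride
    // dst_s from offset dst_o.
    fn copy2d_ref<T: Copy>(
        src: &[T],
        dst: &mut [T],
        d1: usize,
        d2: usize,
        src_s: usize,
        dst_s: usize,
        src_o: usize,
        dst_o: usize,
    ) {
        for row in 0..d1 {
            let src_row = &src[src_o + row * src_s..src_o + row * src_s + d2];
            dst[dst_o + row * dst_s..dst_o + row * dst_s + d2].copy_from_slice(src_row);
        }
    }

    fn main() {
        // Write a 2x2 block into the left half of a 2x4 destination.
        let src = [1, 2, 3, 4];
        let mut dst = [0; 8];
        copy2d_ref(&src, &mut dst, 2, 2, 2, 4, 0, 0);
        assert_eq!(dst, [1, 2, 0, 0, 3, 4, 0, 0]);
    }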
[File diff suppressed because it is too large.]
@@ -1,300 +0,0 @@
-use crate::{shape::Dim, Error, Result, Shape, Tensor};
-
-impl Tensor {
-    /// Concatenates two or more tensors along a particular dimension.
-    ///
-    /// All tensors must of the same rank, and the output will have
-    /// the same rank
-    ///
-    /// ```rust
-    /// # use candle_core::{Tensor, DType, Device};
-    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
-    /// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
-    ///
-    /// let c = Tensor::cat(&[&a, &b], 0)?;
-    /// assert_eq!(c.shape().dims(), &[4, 3]);
-    ///
-    /// let c = Tensor::cat(&[&a, &b], 1)?;
-    /// assert_eq!(c.shape().dims(), &[2, 6]);
-    /// # Ok::<(), candle_core::Error>(())
-    /// ```
-    pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
-        if args.is_empty() {
-            Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
-        }
-        let arg0 = args[0].as_ref();
-        if args.len() == 1 {
-            return Ok(arg0.clone());
-        }
-        let dim = dim.to_index(arg0.shape(), "cat")?;
-        for arg in args {
-            arg.as_ref().check_dim(dim, "cat")?;
-        }
-        for (arg_idx, arg) in args.iter().enumerate() {
-            let arg = arg.as_ref();
-            if arg0.rank() != arg.rank() {
-                Err(Error::UnexpectedNumberOfDims {
-                    expected: arg0.rank(),
-                    got: arg.rank(),
-                    shape: arg.shape().clone(),
-                }
-                .bt())?
-            }
-            for (dim_idx, (v1, v2)) in arg0
-                .shape()
-                .dims()
-                .iter()
-                .zip(arg.shape().dims().iter())
-                .enumerate()
-            {
-                if dim_idx != dim && v1 != v2 {
-                    Err(Error::ShapeMismatchCat {
-                        dim: dim_idx,
-                        first_shape: arg0.shape().clone(),
-                        n: arg_idx + 1,
-                        nth_shape: arg.shape().clone(),
-                    }
-                    .bt())?
-                }
-            }
-        }
-        let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous());
-        if all_contiguous {
-            Self::cat_contiguous(args, dim)
-        } else if dim == 0 {
-            Self::cat0(args)
-        } else {
-            let args: Vec<Tensor> = args
-                .iter()
-                .map(|a| a.as_ref().transpose(0, dim))
-                .collect::<Result<Vec<_>>>()?;
-            let cat = Self::cat0(&args)?;
-            cat.transpose(0, dim)
-        }
-    }
-
-    fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
-        if args.is_empty() {
-            Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
-        }
-        let arg0 = args[0].as_ref();
-        if args.len() == 1 {
-            return Ok(arg0.clone());
-        }
-        let rank = arg0.rank();
-        let device = arg0.device();
-        let dtype = arg0.dtype();
-        let first_dims = arg0.shape().dims();
-        let mut cat_dims = first_dims.to_vec();
-        cat_dims[0] = 0;
-        let mut offsets = vec![0usize];
-        for (arg_idx, arg) in args.iter().enumerate() {
-            let arg = arg.as_ref();
-            if arg.dtype() != dtype {
-                Err(Error::DTypeMismatchBinaryOp {
-                    lhs: dtype,
-                    rhs: arg.dtype(),
-                    op: "cat",
-                }
-                .bt())?
-            }
-            if arg.device().location() != device.location() {
-                Err(Error::DeviceMismatchBinaryOp {
-                    lhs: device.location(),
-                    rhs: arg.device().location(),
-                    op: "cat",
-                }
-                .bt())?
-            }
-            if rank != arg.rank() {
-                Err(Error::UnexpectedNumberOfDims {
-                    expected: rank,
-                    got: arg.rank(),
-                    shape: arg.shape().clone(),
-                }
-                .bt())?
-            }
-            for (dim_idx, (v1, v2)) in arg0
-                .shape()
-                .dims()
-                .iter()
-                .zip(arg.shape().dims().iter())
-                .enumerate()
-            {
-                if dim_idx == 0 {
-                    cat_dims[0] += v2;
-                }
-                if dim_idx != 0 && v1 != v2 {
-                    Err(Error::ShapeMismatchCat {
-                        dim: dim_idx,
-                        first_shape: arg0.shape().clone(),
-                        n: arg_idx + 1,
-                        nth_shape: arg.shape().clone(),
-                    }
-                    .bt())?
-                }
-            }
-            let next_offset = offsets.last().unwrap() + arg.elem_count();
-            offsets.push(next_offset);
-        }
-        let shape = Shape::from(cat_dims);
-        let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, 0));
-        let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
-        for (arg, &offset) in args.iter().zip(offsets.iter()) {
-            let arg = arg.as_ref();
-            arg.storage()
-                .copy_strided_src(&mut storage, offset, arg.layout())?;
-        }
-        Ok(crate::tensor::from_storage(storage, shape, op, false))
-    }
-
-    fn cat_contiguous<A: AsRef<Tensor>>(args: &[A], dim: usize) -> Result<Self> {
-        if args.is_empty() {
-            Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
-        }
-        let arg0 = args[0].as_ref();
-        if args.len() == 1 {
-            return Ok(arg0.clone());
-        }
-        let rank = arg0.rank();
-        let device = arg0.device();
-        let dtype = arg0.dtype();
-        let first_dims = arg0.shape().dims();
-        let mut cat_dims = first_dims.to_vec();
-        cat_dims[dim] = 0;
-        for (arg_idx, arg) in args.iter().enumerate() {
-            let arg = arg.as_ref();
-            if arg.dtype() != dtype {
-                Err(Error::DTypeMismatchBinaryOp {
-                    lhs: dtype,
-                    rhs: arg.dtype(),
-                    op: "cat",
-                }
-                .bt())?
-            }
-            if arg.device().location() != device.location() {
-                Err(Error::DeviceMismatchBinaryOp {
-                    lhs: device.location(),
-                    rhs: arg.device().location(),
-                    op: "cat",
-                }
-                .bt())?
-            }
-            if rank != arg.rank() {
-                Err(Error::UnexpectedNumberOfDims {
-                    expected: rank,
-                    got: arg.rank(),
-                    shape: arg.shape().clone(),
-                }
-                .bt())?
-            }
-            for (dim_idx, (v1, v2)) in arg0
-                .shape()
-                .dims()
-                .iter()
-                .zip(arg.shape().dims().iter())
-                .enumerate()
-            {
-                if dim_idx == dim {
-                    cat_dims[dim] += v2;
-                }
-                if dim_idx != dim && v1 != v2 {
-                    Err(Error::ShapeMismatchCat {
-                        dim: dim_idx,
-                        first_shape: arg0.shape().clone(),
-                        n: arg_idx + 1,
-                        nth_shape: arg.shape().clone(),
-                    }
-                    .bt())?
-                }
-            }
-        }
-        let cat_target_dim_len = cat_dims[dim];
-        let block_size: usize = cat_dims.iter().skip(1 + dim).product();
-        let shape = Shape::from(cat_dims);
-        let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, dim));
-        let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
-        let mut dst_o = 0;
-        for arg in args.iter() {
-            let arg = arg.as_ref();
-            let arg_dims = arg.shape().dims();
-            let d1: usize = arg_dims.iter().take(dim).product();
-            let d2 = block_size * arg_dims[dim];
-            let dst_s = block_size * cat_target_dim_len;
-            let src_o = arg.layout().start_offset();
-            arg.storage().copy2d(
-                &mut storage,
-                d1,
-                d2,
-                /* src_s */ d2,
-                dst_s,
-                src_o,
-                dst_o,
-            )?;
-            dst_o += d2;
-        }
-        Ok(crate::tensor::from_storage(storage, shape, op, false))
-    }
-
-    /// Set the values on `self` using values from `src`. The copy starts at the specified
-    /// `offset` for the target dimension `dim` on `self`.
-    /// `self` and `src` must have the same shape except on dimension `dim` where the `self` size
-    /// has to be greater than or equal to `offset` plus the `src` size.
-    ///
-    /// Note that this modifies `self` in place and as such is not compatibel with
-    /// back-propagation.
-    pub fn slice_set<D: Dim>(&self, src: &Self, dim: D, offset: usize) -> Result<()> {
-        let dim = dim.to_index(self.shape(), "slice-set")?;
-        if !self.is_contiguous() || !src.is_contiguous() {
-            Err(Error::RequiresContiguous { op: "slice-set" }.bt())?
-        }
-        if self.dtype() != src.dtype() {
-            Err(Error::DTypeMismatchBinaryOp {
-                lhs: self.dtype(),
-                rhs: src.dtype(),
-                op: "slice-set",
-            }
-            .bt())?
-        }
-        if self.device().location() != src.device().location() {
-            Err(Error::DeviceMismatchBinaryOp {
-                lhs: self.device().location(),
-                rhs: src.device().location(),
-                op: "slice-set",
-            }
-            .bt())?
-        }
-        if self.rank() != src.rank() {
-            Err(Error::UnexpectedNumberOfDims {
-                expected: self.rank(),
-                got: src.rank(),
-                shape: self.shape().clone(),
-            }
-            .bt())?
-        }
-        for (dim_idx, (v1, v2)) in self.dims().iter().zip(src.dims().iter()).enumerate() {
-            if dim_idx == dim && *v2 + offset > *v1 {
-                crate::bail!("shape mismatch on target dim, dst: {v1}, src: {v2} + {offset}")
-            }
-            if dim_idx != dim && v1 != v2 {
-                crate::bail!("shape mismatch on dim {dim_idx}, {v1} <> {v2}")
-            }
-        }
-        let block_size: usize = src.dims().iter().skip(1 + dim).product();
-        let d1: usize = src.dims().iter().take(dim).product();
-        let d2 = block_size * src.dims()[dim];
-        let dst_o = self.layout().start_offset() + offset * block_size;
-        let src_o = src.layout().start_offset();
-        src.storage().copy2d(
-            &mut self.storage_mut(),
-            d1,
-            d2,
-            /* src_s */ d2,
-            /* dst_s */ block_size * self.dims()[dim],
-            src_o,
-            dst_o,
-        )?;
-        Ok(())
-    }
-}
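
The deleted `slice_set` (present on the 0.5.1 side of the comparison) writes a source tensor into a slice of the destination along one dimension, under the shape constraints its doc comment states. A minimal usage sketch against the 0.5.1 API:

    use candle_core::{DType, Device, IndexOp, Result, Tensor};

    fn main() -> Result<()> {
        // Destination is (4, 3); write a (2, 3) source at offset 1 along dim 0.
        let dst = Tensor::zeros((4, 3), DType::F32, &Device::Cpu)?;
        let src = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
        dst.slice_set(&src, 0, 1)?;
        // Rows 1 and 2 are now ones, rows 0 and 3 stay zero.
        assert_eq!(dst.i(1)?.to_vec1::<f32>()?, [1., 1., 1.]);
        Ok(())
    }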
@@ -4,7 +4,7 @@ use crate::{Result, Tensor};
 macro_rules! test_device {
     // TODO: Switch to generating the two last arguments automatically once concat_idents is
     // stable. https://github.com/rust-lang/rust/issues/29599
-    ($fn_name: ident, $test_cpu: ident, $test_cuda: ident, $test_metal: ident) => {
+    ($fn_name: ident, $test_cpu: ident, $test_cuda: ident) => {
         #[test]
         fn $test_cpu() -> Result<()> {
             $fn_name(&Device::Cpu)
@@ -15,12 +15,6 @@ macro_rules! test_device {
         fn $test_cuda() -> Result<()> {
             $fn_name(&Device::new_cuda(0)?)
         }
-
-        #[cfg(feature = "metal")]
-        #[test]
-        fn $test_metal() -> Result<()> {
-            $fn_name(&Device::new_metal(0)?)
-        }
     };
 }
 
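
Call sites of the macro drop the final identifier accordingly; a hypothetical invocation under the two-test form (the identifiers are made up for illustration):

    // Hypothetical call site for the two-test form of the macro: generates a
    // CPU test and a feature-gated CUDA test for a `conv1d` test helper.
    test_device!(conv1d, conv1d_cpu, conv1d_gpu);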
@@ -23,10 +23,6 @@ pub fn cuda_is_available() -> bool {
     cfg!(feature = "cuda")
 }
 
-pub fn metal_is_available() -> bool {
-    cfg!(feature = "metal")
-}
-
 pub fn with_avx() -> bool {
     cfg!(target_feature = "avx")
 }
@@ -34,14 +34,9 @@ impl Var {
         Ok(Self(inner))
     }
 
-    // Convert a tensor to a variable, if the tensor is already a variable then it is returned as is.
     pub fn from_tensor(t: &Tensor) -> Result<Self> {
-        if t.is_variable() {
-            Ok(Self(t.clone()))
-        } else {
-            let inner = t.make_var()?;
-            Ok(Self(inner))
-        }
+        let inner = t.make_var()?;
+        Ok(Self(inner))
     }
 
     pub fn rand_f64<S: Into<Shape>>(
@@ -112,10 +107,6 @@ impl Var {
         Ok(Self(inner))
     }
 
-    pub fn as_detached_tensor(&self) -> Tensor {
-        self.0.detach()
-    }
-
     pub fn as_tensor(&self) -> &Tensor {
         &self.0
     }
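
The two sides of `Var::from_tensor` differ observably: the left-hand (0.5.1) version returns the tensor itself when it is already a variable, while the right-hand version always builds a fresh variable via `make_var`. A small sketch of the call:

    use candle_core::{Device, Result, Tensor, Var};

    fn main() -> Result<()> {
        let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
        // Under the left-hand (0.5.1) version this checks `t.is_variable()`
        // first and reuses `t` when it already is one; under the right-hand
        // version a fresh variable is always created via `make_var`.
        let v = Var::from_tensor(&t)?;
        assert_eq!(v.as_tensor().to_vec1::<f32>()?, [1., 2., 3.]);
        Ok(())
    }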
@@ -13,14 +13,6 @@ res = torch.nn.functional.conv1d(t, w)
 print(res.flatten())
 res = torch.nn.functional.conv1d(t, w, padding=1)
 print(res.flatten())
-
-w_t = w.transpose(0, 1)
-res = torch.nn.functional.conv_transpose1d(t, w_t)
-print(res.shape)
-print(res)
-res = torch.nn.functional.conv_transpose1d(t, w_t, groups=2)
-print(res.shape)
-print(res)
 */
 fn conv1d(dev: &Device) -> Result<()> {
     let t = Tensor::new(
@@ -53,31 +45,6 @@ fn conv1d(dev: &Device) -> Result<()> {
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
         [2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
     );
-
-    let w = w.transpose(0, 1)?;
-    // The CPU kernels applied in the contiguous and non contiguous cases are different.
-    for w in [w.clone(), w.contiguous()?] {
-        let res = t.conv_transpose1d(&w, 0, 0, 1, 1, 1)?;
-        assert_eq!(res.dims(), [1, 2, 7]);
-        assert_eq!(
-            test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
-            [
-                0.0699, -1.2899, 8.3018, 5.5873, 2.4572, -2.6143, -0.0706, 1.8765, 4.8318, 1.1538,
-                4.7076, -5.9745, -0.8276, 1.621
-            ],
-        );
-        let res = t.conv_transpose1d(&w, 0, 0, 1, 1, 2)?;
-        assert_eq!(res.dims(), [1, 4, 7]);
-        assert_eq!(
-            test_utils::to_vec2_round(&res.squeeze(0)?, 4)?,
-            [
-                [-1.5596, -1.8099, 2.0407, 4.8764, -0.1743, -0.735, -0.7819],
-                [0.7816, 3.8152, -0.5926, 2.2515, -5.1844, -0.3157, 1.4721],
-                [1.6295, 0.52, 6.2611, 0.7109, 2.6315, -1.8793, 0.7113],
-                [1.0949, 1.0166, 1.7464, 2.4561, -0.79, -0.5119, 0.1488]
-            ]
-        );
-    }
     Ok(())
 }
 
@@ -135,7 +102,7 @@ fn conv2d(dev: &Device) -> Result<()> {
             0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
             0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
             -0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
-            -0.8, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
+            -0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
         ],
         dev,
     )?;
@@ -163,9 +130,7 @@ fn conv2d(dev: &Device) -> Result<()> {
            10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
         ]
     );
-
     let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
-
     assert_eq!(res.dims(), [1, 2, 7, 7]);
     assert_eq!(
         test_utils::to_vec3_round(&res.i(0)?, 4)?,
@@ -190,7 +155,6 @@ fn conv2d(dev: &Device) -> Result<()> {
             ]
         ]
     );
-
     // Dilations.
     let res = t.conv2d(&w, 0, 1, 2, 1)?;
     assert_eq!(res.dims(), [1, 2, 1, 1]);
@@ -229,7 +193,6 @@ fn conv2d(dev: &Device) -> Result<()> {
             ]
         ]
     );
-
     Ok(())
 }
 
@@ -276,13 +239,13 @@ fn conv2d_small(dev: &Device) -> Result<()> {
     assert_eq!(
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
         [
-            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1640,
-            -0.0111, -0.1742, 0.0, 0.0, 0.0, 0.0, 2.6437, -2.0268, 1.1823, 0.0, 0.0, 0.0, 0.0,
-            3.2855, -1.0324, 0.2539, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
-            0.0, 0.0, 0.0, 0.0
+            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
+            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1640, -0.0111, -0.1742, 0.0000, 0.0000,
+            0.0000, 0.0000, 2.6437, -2.0268, 1.1823, 0.0000, 0.0000, 0.0000, 0.0000, 3.2855,
+            -1.0324, 0.2539, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
+            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000
         ]
     );
 
     let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
     assert_eq!(res.dims(), [1, 1, 3, 3]);
     assert_eq!(
@@ -384,7 +347,6 @@ print(w.grad.shape)
 print(w.grad[0])
 */
 fn conv2d_grad(dev: &Device) -> Result<()> {
-    // conv-transposes are not implemented for metal
     use candle_core::Var;
     let t = Var::from_slice(
         &[
@@ -397,7 +359,7 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
             0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
             0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
             -0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
-            -0.8, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
+            -0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
         ],
         (1, 4, 5, 5),
         dev,
@@ -517,251 +479,17 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
             ]
         ]
     );
-
-    // Replicate the issue from https://github.com/huggingface/candle/issues/1212
-    let res = t.i((.., .., 0..4, 0..4))?.conv2d(&w, 0, 2, 1, 1)?;
-    let loss = res.sqr()?.sum_all()?;
-    assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 21.12f32);
-    let grads = loss.backward()?;
-    let grad_t = grads.get(&t).unwrap();
-    let grad_w = grads.get(&w).unwrap();
-    assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
-    assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
-    assert_eq!(
-        test_utils::to_vec3_round(&grad_t.i(0)?, 2)?,
-        [
-            [
-                [9.29, -7.03, 7.87, 0.0, 0.0],
-                [-1.8, -7.82, 5.9, 0.0, 0.0],
-                [-3.12, 4.49, 5.52, 0.0, 0.0],
-                [0.0, 0.0, 0.0, 0.0, 0.0],
-                [0.0, 0.0, 0.0, 0.0, 0.0]
-            ],
-            [
-                [21.73, 3.39, 4.77, 0.0, 0.0],
-                [8.25, 3.73, 27.61, 0.0, 0.0],
[-20.55, -5.61, -2.77, 0.0, 0.0],
|
|
||||||
[0.0, 0.0, 0.0, 0.0, 0.0],
|
|
||||||
[0.0, 0.0, 0.0, 0.0, 0.0]
|
|
||||||
],
|
|
||||||
[
|
|
||||||
[-8.98, 9.91, -7.15, 0.0, 0.0],
|
|
||||||
[4.93, -0.33, 4.56, 0.0, 0.0],
|
|
||||||
[-6.7, -5.76, -8.05, 0.0, 0.0],
|
|
||||||
[0.0, 0.0, 0.0, 0.0, 0.0],
|
|
||||||
[0.0, 0.0, 0.0, 0.0, 0.0]
|
|
||||||
],
|
|
||||||
[
|
|
||||||
[23.54, 6.98, -10.0, 0.0, 0.0],
|
|
||||||
[9.65, 6.18, 18.72, 0.0, 0.0],
|
|
||||||
[3.29, -5.27, 0.79, 0.0, 0.0],
|
|
||||||
[0.0, 0.0, 0.0, 0.0, 0.0],
|
|
||||||
[0.0, 0.0, 0.0, 0.0, 0.0]
|
|
||||||
]
|
|
||||||
]
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
test_utils::to_vec3_round(&grad_w.i(0)?, 2)?,
|
|
||||||
[
|
|
||||||
[
|
|
||||||
[-3.47, 7.44, 0.66],
|
|
||||||
[12.89, -3.4, -9.29],
|
|
||||||
[-14.16, -0.83, 7.14]
|
|
||||||
],
|
|
||||||
[
|
|
||||||
[-3.23, 5.37, -3.02],
|
|
||||||
[-2.12, -11.24, 1.94],
|
|
||||||
[6.97, 7.2, 2.99]
|
|
||||||
],
|
|
||||||
[
|
|
||||||
[-4.04, -3.31, 4.87],
|
|
||||||
[-6.68, -5.68, 1.73],
|
|
||||||
[-5.54, 4.32, 0.52]
|
|
||||||
],
|
|
||||||
[[-4.72, 1.5, 4.72], [3.79, 4.04, 6.76], [-4.6, 5.8, 6.93]]
|
|
||||||
]
|
|
||||||
);
|
|
||||||
|
|
||||||
// Conv Transpose 2d Test
|
|
||||||
//tested against following python
|
|
||||||
|
|
||||||
// import torch
|
|
||||||
// torch.manual_seed(4242)
|
|
||||||
// padding = 4
|
|
||||||
// outpadding = 2
|
|
||||||
// dilation = 3
|
|
||||||
// stride = 3
|
|
||||||
// input = torch.randn((1, 4, 7, 5), requires_grad=True)
|
|
||||||
// kernel = torch.randn((4, 2, 3, 5), requires_grad=True)
|
|
||||||
// print("input", input.flatten())
|
|
||||||
// print("kernel", kernel.flatten())
|
|
||||||
// res = torch.nn.functional.conv_transpose2d(
|
|
||||||
// input,
|
|
||||||
// kernel,
|
|
||||||
// stride=stride,
|
|
||||||
// padding=padding,
|
|
||||||
// dilation=dilation,
|
|
||||||
// output_padding=outpadding,
|
|
||||||
// )
|
|
||||||
// res.retain_grad()
|
|
||||||
// print(res.shape)
|
|
||||||
// loss = (res**2).sum()
|
|
||||||
// print(loss)
|
|
||||||
// loss.backward()
|
|
||||||
// print(input.grad.shape)
|
|
||||||
// print("input grad", torch.round(input.grad, decimals=1))
|
|
||||||
// print(kernel.grad.shape)
|
|
||||||
// print("kernel grad", torch.round(kernel.grad.flatten(), decimals=1))
|
|
||||||
|
|
||||||
let padding = 4;
|
|
||||||
let outpadding = 2;
|
|
||||||
let dilation = 3;
|
|
||||||
let stride = 3;
|
|
||||||
|
|
||||||
let t = Var::from_slice(
|
|
||||||
&[
|
|
||||||
0.4056_f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997,
|
|
||||||
3.0616, 1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843,
|
|
||||||
0.2395, 1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013,
|
|
||||||
-0.6836, 0.2477, 1.3127, -0.2260, 0.2622, -1.2974, -0.8140, -0.8404, -0.3490, 0.0130,
|
|
||||||
1.3123, 1.7569, -0.3956, -1.8255, 0.1727, -0.3538, 2.6941, 1.0529, 0.4219, -0.2071,
|
|
||||||
1.1586, 0.4717, 0.3865, -0.5690, -0.5010, -0.1310, 0.7796, 0.6630, -0.2021, 2.6090,
|
|
||||||
0.2049, 0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323,
|
|
||||||
-1.3712, 0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742,
|
|
||||||
0.3790, -0.4431, -0.4720, -0.7890, 0.2620, 0.5411, -1.1715, -2.4997, 2.3249, -0.8912,
|
|
||||||
-0.4733, -0.5701, -2.8888, -1.4112, -0.5471, -0.9234, -1.1660, 0.4189, -0.7465,
|
|
||||||
-0.6473, 0.1402, 0.7875, 0.5377, -0.6779, -0.8088, -0.4864, -0.2312, 0.9279, 0.1264,
|
|
||||||
1.5480, 0.8265, -0.1025, 0.5138, -0.2512, 0.1576, 1.2705, 0.3641, -0.9325, 0.6451,
|
|
||||||
-0.8537, 0.2378, 0.1794, 0.2752, -0.3687, -1.1149, -0.1410, -0.5829, -0.0892, 1.4258,
|
|
||||||
-2.2789, 0.5270, 0.1825, 1.7007, -0.5263, -0.2954, 0.4440, 0.5537, 0.3492, 0.6186,
|
|
||||||
1.6475, 0.2219,
|
|
||||||
],
|
|
||||||
(1, 4, 7, 5),
|
|
||||||
dev,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
#[rustfmt::skip]
|
|
||||||
let w = Var::from_slice(
|
|
||||||
&[
|
|
||||||
-1.1744_f32, 0.3266, 2.5893, 1.0142, 0.1763, 0.7752, 0.6604, 0.2029, -0.2145, 0.7234,
|
|
||||||
-0.3441, -1.5400, -0.6333, 0.6613, 0.2083, 0.6230, -1.7002, 0.3393, 0.4049, 1.0762,
|
|
||||||
0.2723, 1.4181, 0.0029, -0.2122, 1.7668, 1.4168, 0.3320, -0.2719, 0.7932, -0.7204,
|
|
||||||
0.4447, 0.1211, 0.5908, 1.0089, -0.1646, 1.8033, -0.6286, 0.2016, -0.3370, 1.2555,
|
|
||||||
0.8009, -0.6488, -0.4652, -1.5685, 1.5860, 0.5583, 0.4623, 0.6026, 0.8828, 2.4990,
|
|
||||||
0.6811, -0.3369, 1.3320, 1.7669, -1.1067, 1.2958, -0.9415, -0.9655, -0.4462, 0.7181,
|
|
||||||
0.5181, -1.1658, -1.8467, -0.7763, 1.2769, 0.8651, 0.9890, 1.5092, 0.7207, -0.8481,
|
|
||||||
0.7417, 0.3375, -1.2685, 1.4572, 1.0915, 0.1093, -0.8550, -0.5831, -0.6309, -0.2509,
|
|
||||||
0.5220, -0.0914, 0.7900, 0.1096, 0.3258, 0.2723, -1.0942, -0.3393, -0.1653, 0.5732,
|
|
||||||
-0.8014, 1.8194, -1.9023, 0.2127, 1.8636, -0.8979, 0.1927, -0.2778, 0.3105, 0.0071,
|
|
||||||
-1.1823, 0.2476, -0.7178, -1.3821, 1.0769, -0.4376, -0.9967, -0.1227, 1.6197, -1.0604,
|
|
||||||
0.1372, 0.8141, -0.6163, 0.7304, -0.8285, 2.0636, -0.7176, 0.2495, -0.2581, -0.4478,
|
|
||||||
],
|
|
||||||
(4, 2, 3, 5),
|
|
||||||
dev,
|
|
||||||
)?;
|
|
||||||
let res = t.conv_transpose2d(&w, padding, outpadding, stride, dilation)?;
|
|
||||||
let loss = res.sqr()?.sum_all()?;
|
|
||||||
assert_eq!(test_utils::to_vec0_round(&loss, 0)?, 2904.0);
|
|
||||||
let grads = loss.backward()?;
|
|
||||||
|
|
||||||
let grad_t = grads.get(&t).unwrap();
|
|
||||||
let grad_w = grads.get(&w).unwrap();
|
|
||||||
assert_eq!(grad_t.dims(), [1, 4, 7, 5]);
|
|
||||||
assert_eq!(grad_w.dims(), [4, 2, 3, 5]);
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
test_utils::to_vec1_round(&grad_w.flatten_all()?, 1)?,
|
|
||||||
[
|
|
||||||
// torch gets 89.1
|
|
||||||
-89.0, -135.3, 136.7, 102.0, -53.4, 117.9, 118.6, -43.9, -218.0, -58.5, -114.3, -150.0,
|
|
||||||
-15.6, 172.1, 66.3, -64.3, -27.9, -19.8, 31.7, 62.1, 5.5, 92.6, 28.2, -29.6, 55.9,
|
|
||||||
52.7, -72.7, -119.8, 53.8, -25.5, 128.8, 19.3, 68.0, 190.9, -64.1, -86.2, -111.2,
|
|
||||||
106.6, -67.7, 37.8, 115.9, 50.4, -77.7, -54.9, 22.3, -4.6, 89.8, 61.7, 122.4, 192.6,
|
|
||||||
-27.8, -104.6, 57.0, 166.4, 27.1, 6.1, 18.7, -93.2, 31.5, 168.2, -3.7, -99.5, -55.5,
|
|
||||||
-10.8, 17.5, 20.8, 16.9, 43.8, 42.0, -89.2, 18.8, -9.6, -84.1, 212.6, 19.7, -50.0,
|
|
||||||
-52.0, -40.0, -166.6, -73.2, -10.8, -73.3, 31.5, -23.4, -79.3, -27.0, -84.4, -42.9,
|
|
||||||
-20.3, 51.8, -16.7, 76.3, -120.5, -65.8, 96.5, -10.7, -45.9, -88.1, 65.4, -7.0, -1.5,
|
|
||||||
92.8, -25.1, -114.2, -5.8, -14.8, -51.2, -20.7, 54.2, -79.8, 47.7, -29.2, -8.8, 53.5,
|
|
||||||
-28.4, 85.0, -18.3, 107.0, 28.3, -71.8
|
|
||||||
]
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
test_utils::to_vec3_round(&grad_t.i(0)?, 1)?,
|
|
||||||
[
|
|
||||||
[
|
|
||||||
[32.3, -41.6, -24.0, 14.1, 17.6],
|
|
||||||
[-11.8, 72.5, 87.6, 46.4, 61.5],
|
|
||||||
[115.0, 108.5, -48.6, -63.4, -50.0],
|
|
||||||
[51.3, 5.4, 31.3, 91.1, -30.9],
|
|
||||||
[52.7, 92.8, -68.0, -47.0, 83.0],
|
|
||||||
// pytorch gets -107.1
|
|
||||||
[-10.2, -107.0, -5.4, 213.1, -31.4],
|
|
||||||
[-2.4, 65.1, 9.2, -146.2, -24.2]
|
|
||||||
],
|
|
||||||
[
|
|
||||||
[-72.6, -63.9, -61.9, 45.3, 33.0],
|
|
||||||
[79.3, -0.5, -26.2, 78.2, 42.7],
|
|
||||||
[90.9, 141.6, 40.1, -62.7, 37.0],
|
|
||||||
[32.8, 198.2, -0.8, -31.1, 27.3],
|
|
||||||
// torch gets 48.0
|
|
||||||
[34.5, 34.9, -47.9, 127.6, -12.3],
|
|
||||||
[-61.4, -3.2, -2.9, -10.9, -16.6],
|
|
||||||
[74.6, 60.1, -68.9, 34.5, -50.4]
|
|
||||||
],
|
|
||||||
[
|
|
||||||
[37.5, -56.9, -43.6, -13.5, -9.9],
|
|
||||||
[40.0, 97.3, 28.6, 14.2, -30.1],
|
|
||||||
[-22.3, -126.3, -68.8, -8.2, 26.1],
|
|
||||||
[-32.9, 37.3, 108.5, -54.8, 29.6],
|
|
||||||
[34.9, -176.9, -125.0, -28.3, -13.9],
|
|
||||||
[-54.9, 142.6, 62.1, -80.4, -65.6],
|
|
||||||
[7.4, -91.1, -67.6, 35.0, 39.7]
|
|
||||||
],
|
|
||||||
[
|
|
||||||
[-57.2, -40.9, -10.1, 32.6, 29.4],
|
|
||||||
[18.7, -18.0, 29.5, -1.2, 59.2],
|
|
||||||
[-14.0, -74.4, 19.8, -117.0, 58.2],
|
|
||||||
[-21.8, 163.5, -71.1, -99.0, 80.9],
|
|
||||||
[-58.9, -10.9, 93.8, -139.6, 98.0],
|
|
||||||
// torch gets 54.5
|
|
||||||
[-54.4, 135.3, 6.0, -79.1, 134.6],
|
|
||||||
[27.5, -76.0, 43.4, -2.8, -7.8]
|
|
||||||
]
|
|
||||||
]
|
|
||||||
);
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
test_device!(conv1d, conv1d_cpu, conv1d_gpu, conv1d_metal);
|
test_device!(conv1d, conv1d_cpu, conv1d_gpu);
|
||||||
test_device!(
|
test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu);
|
||||||
conv1d_small,
|
test_device!(conv2d, conv2d_cpu, conv2d_gpu);
|
||||||
conv1d_small_cpu,
|
|
||||||
conv1d_small_gpu,
|
|
||||||
conv1d_small_metal
|
|
||||||
);
|
|
||||||
test_device!(conv2d, conv2d_cpu, conv2d_gpu, conv2d_metal);
|
|
||||||
test_device!(
|
test_device!(
|
||||||
conv2d_non_square,
|
conv2d_non_square,
|
||||||
conv2d_non_square_cpu,
|
conv2d_non_square_cpu,
|
||||||
conv2d_non_square_gpu,
|
conv2d_non_square_gpu
|
||||||
conv2d_non_square_metal
|
|
||||||
);
|
|
||||||
test_device!(
|
|
||||||
conv2d_small,
|
|
||||||
conv2d_small_cpu,
|
|
||||||
conv2d_small_gpu,
|
|
||||||
conv2d_small_metal
|
|
||||||
);
|
|
||||||
test_device!(
|
|
||||||
conv2d_smaller,
|
|
||||||
conv2d_smaller_cpu,
|
|
||||||
conv2d_smaller_gpu,
|
|
||||||
conv2d_smaller_metal
|
|
||||||
);
|
|
||||||
test_device!(
|
|
||||||
conv2d_grad,
|
|
||||||
conv2d_grad_cpu,
|
|
||||||
conv2d_grad_gpu,
|
|
||||||
conv2_grad_metal
|
|
||||||
);
|
);
|
||||||
|
test_device!(conv2d_small, conv2d_small_cpu, conv2d_small_gpu);
|
||||||
|
test_device!(conv2d_smaller, conv2d_smaller_cpu, conv2d_smaller_gpu);
|
||||||
|
test_device!(conv2d_grad, conv2d_grad_cpu, conv2d_grad_gpu);
|
||||||
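The `test_device!` invocations rewritten above register one `#[test]` per backend, so dropping the fourth identifier drops the Metal variant of each test. As a reading aid, a minimal sketch of what such a macro can expand to follows; this is an illustration written for this note, not candle's actual definition, and the `cuda` feature name is an assumption.

// Hypothetical test_device!-style macro: each listed identifier becomes a
// #[test] wrapper that runs the shared test body on one backend.
macro_rules! test_device {
    ($fn_name:ident, $cpu:ident, $gpu:ident) => {
        #[test]
        fn $cpu() -> Result<()> {
            $fn_name(&Device::Cpu)
        }
        #[cfg(feature = "cuda")]
        #[test]
        fn $gpu() -> Result<()> {
            $fn_name(&Device::new_cuda(0)?)
        }
    };
}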
|
@@ -112,34 +112,3 @@ fn custom_op1_with_backward() -> Result<()> {
 
     Ok(())
 }
-
-impl candle_core::InplaceOp1 for Elu {
-    fn name(&self) -> &'static str {
-        "elu"
-    }
-
-    fn cpu_fwd(&self, s: &mut CpuStorage, _l: &Layout) -> Result<()> {
-        let alpha = self.alpha;
-        match s {
-            CpuStorage::BF16(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)),
-            CpuStorage::F16(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)),
-            CpuStorage::F32(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)),
-            CpuStorage::F64(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)),
-            _ => candle_core::bail!("unsupported dtype for inplace elu"),
-        }
-        Ok(())
-    }
-}
-
-#[test]
-fn inplace_op1() -> Result<()> {
-    let cpu = &Device::Cpu;
-    let t = Tensor::arange(0u32, 12u32, cpu)?.to_dtype(DType::F32)?;
-    let t = (t - 5.)?;
-    t.inplace_op1(&Elu { alpha: 1. })?;
-    assert_eq!(
-        to_vec1_round(&t, 4)?,
-        &[-0.9933, -0.9817, -0.9502, -0.8647, -0.6321, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
-    );
-    Ok(())
-}
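The removed `InplaceOp1` impl dispatches every float dtype to a `fwd` helper that is not visible in this hunk. A plausible sketch of that helper, assuming the standard ELU definition (`x` for `x >= 0`, `alpha * (exp(x) - 1)` otherwise), is shown below; the generic signature is an assumption, but the formula matches the asserted values, e.g. `elu(-5) = exp(-5) - 1 ≈ -0.9933` with `alpha = 1`.

// Sketch of the `fwd` helper referenced by the deleted match arms above.
// The trait bound and exact signature are assumptions for illustration.
fn fwd<T: num_traits::Float>(v: T, alpha: f64) -> T {
    if v.is_sign_positive() {
        v
    } else {
        let alpha = T::from(alpha).unwrap_or_else(T::nan);
        (v.exp() - T::one()) * alpha
    }
}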
|
Binary file not shown.
@@ -1,4 +1,3 @@
-#![allow(clippy::approx_constant)]
 use anyhow::{Context, Result};
 use candle_core::{test_device, test_utils, Device, Shape, Tensor, Var};
 
@@ -97,24 +96,24 @@ fn unary_grad(device: &Device) -> Result<()> {
     let grads = y.backward()?;
     let grad_x = grads.get(x).context("no grad for x")?;
     assert_eq!(
-        test_utils::to_vec1_round(&y, 4)?,
-        [20.0855, 2.7183, 54.5982, 1.1618]
+        y.to_vec1::<f32>()?,
+        [20.085537, 2.7182817, 54.59815, 1.1618342]
     );
     assert_eq!(
-        test_utils::to_vec1_round(grad_x, 4)?,
-        [20.0855, 2.7183, 54.5982, 1.1618]
+        grad_x.to_vec1::<f32>()?,
+        [20.085537, 2.7182817, 54.59815, 1.1618342]
     );
     let y = x.exp()?.sqr()?;
     let grads = y.backward()?;
     let grad_x = grads.get(x).context("no grad for x")?;
     assert_eq!(
-        test_utils::to_vec1_round(&y, 3)?,
-        [403.429, 7.389, 2980.958, 1.35]
+        y.to_vec1::<f32>()?,
+        [403.4288, 7.3890557, 2980.9578, 1.3498588]
     );
     // exp(x)^2 = exp(2*x)
     assert_eq!(
-        test_utils::to_vec1_round(grad_x, 2)?,
-        [806.86, 14.78, 5961.92, 2.7]
+        grad_x.to_vec1::<f32>()?,
+        [806.8576, 14.778111, 5961.9155, 2.6997175]
     );
     let y = x.sin()?;
     let grads = y.backward()?;
@@ -193,273 +192,6 @@ fn unary_grad(device: &Device) -> Result<()> {
         test_utils::to_vec1_round(grad_x, 2)?,
         [0.01, 0.42, 0.0, 0.98],
     );
-
-    // testing compared to pytorch nn.GELU(approximate = 'tanh')
-    let y = x.gelu()?;
-    let grads = y.backward()?;
-    let grad_x = grads.get(&x).context("no grad for x")?;
-    assert_eq!(
-        test_utils::to_vec1_round(&y, 4)?,
-        [2.9964, 0.8412, 3.9999, 0.0839]
-    );
-    assert_eq!(
-        test_utils::to_vec1_round(grad_x, 4)?,
-        [1.0116, 1.0830, 1.0003, 0.6188],
-    );
-
-    // Testing compared to pytorch torch.erf
-    //
-    // import torch
-    // x = torch.tensor([3.0, 1.0, 4.0, 0.15], requires_grad=True)
-    // y = x.erf()
-    // print(y)
-    // loss = y.sum()
-    // loss.backward()
-    // print(x.grad)
-    let y = x.erf()?;
-    let grads = y.backward()?;
-    let grad_x = grads.get(&x).context("no grad for x")?;
-    assert_eq!(test_utils::to_vec1_round(&y, 4)?, [1.0, 0.8427, 1.0, 0.168]);
-    assert_eq!(
-        test_utils::to_vec1_round(grad_x, 4)?,
-        [0.0001, 0.4151, 0.0, 1.1033],
-    );
-
-    // Testing compared to pytorch nn.GELU(approximate = 'none')
-    //
-    // import torch
-    // import torch.nn.functional as F
-    // x = torch.tensor([3.0, 1.0, 4.0, 0.15], requires_grad=True)
-    // y = F.gelu(x, approximate='none')
-    // print(y)
-    // loss = y.sum()
-    // loss.backward()
-    // print(x.grad)
-    let y = x.gelu_erf()?;
-    let grads = y.backward()?;
-    let grad_x = grads.get(&x).context("no grad for x")?;
-    assert_eq!(
-        test_utils::to_vec1_round(&y, 4)?,
-        [2.9960, 0.8413, 3.9999, 0.0839]
-    );
-    assert_eq!(
-        test_utils::to_vec1_round(grad_x, 4)?,
-        [1.0119, 1.0833, 1.0005, 0.6188],
-    );
-
-    // Testing compared to pytorch elu
-    //
-    // import torch
-    // import torch.nn.functional as F
-    // x = torch.tensor([-1.0, 0.0, -2.0, 3.0], requires_grad=True)
-    // y = F.elu(x, alpha=2.0)
-    // print(y)
-    // loss = y.min
-    // loss = y.sum()
-    // loss.backward()
-    // print(x.grad)
-    let elu_x = Var::new(&[-1.0f32, 0., -2., 3.], device)?;
-    let y = elu_x.elu(2.)?;
-    let grads = y.backward()?;
-    let grad_x = grads.get(&elu_x).context("no grad for x")?;
-
-    assert_eq!(
-        test_utils::to_vec1_round(&y, 4)?,
-        [-1.2642, 0.0000, -1.7293, 3.0000]
-    );
-    assert_eq!(
-        test_utils::to_vec1_round(grad_x, 4)?,
-        [0.7358, 2.0000, 0.2707, 1.0000]
-    );
-
-    // testing compared to pytorch nn.Silu()
-    let y = x.silu()?;
-    let grads = y.backward()?;
-    let grad_x = grads.get(&x).context("no grad for x")?;
-    assert_eq!(
-        test_utils::to_vec1_round(&y, 4)?,
-        [2.8577, 0.7311, 3.9281, 0.0806]
-    );
-    assert_eq!(
-        test_utils::to_vec1_round(grad_x, 4)?,
-        [1.0881, 0.9277, 1.0527, 0.5747],
-    );
-
-    if device.is_cpu() {
-        let x = Var::new(&[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]], device)?;
-        let y = x.interpolate1d(12)?.reshape(36)?;
-
-        let z = Tensor::new(
-            &[
-                1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16.,
-                17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32.,
-                33., 34., 35., 36.,
-            ],
-            device,
-        )?;
-
-        let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
-        let grads = loss.backward()?;
-        let grad_x = grads.get(&x).context("no grad for x")?;
-
-        assert_eq!(
-            test_utils::to_vec3_round(grad_x, 4)?,
-            [[[10_f32, 26., 42.], [58., 74., 90.], [106., 122., 138.]]]
-        );
-    }
-
-    // manually checked: see comments
-    let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
-    let y = x.interpolate2d(6, 6)?.reshape(36)?;
-
-    let z = Tensor::new(
-        &[
-            1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16., 17.,
-            18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34.,
-            35., 36.,
-        ],
-        device,
-    )?;
-    // gradient should be
-    // row 1
-    // 1+2+7+8 = 18
-    // 3+4+9+10 = 26
-    // 5+6+11+12 = 34
-    // row 2
-    // 13+14+19+20 = 66
-    // 15+16+21+22 = 74
-    // 17+18+23+24 = 82
-    // row 3
-    // 25+26+31+32 = 114
-    // 27+28+33+34 = 122
-    // 29+30+35+36 = 130
-    let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
-
-    let grads = loss.backward()?;
-
-    let grad_x = grads.get(&x).context("no grad for x")?;
-    assert_eq!(
-        test_utils::to_vec2_round(&grad_x.flatten(0, 2)?, 4)?,
-        [[18_f32, 26., 34.], [66., 74., 82.], [114., 122., 130.]]
-    );
-
-    // manually checked: see comments
-    let x = Var::new(&[[[[1f32, 2.], [4., 5.]]]], device)?;
-    let y = x.interpolate2d(6, 6)?.reshape(36)?;
-
-    let z = Tensor::new(
-        &[
-            1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16., 17.,
-            18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34.,
-            35., 36.,
-        ],
-        device,
-    )?;
-    // gradient should be
-    // row 1
-    // 1+2+3+7+8+9+13+14+15 = 72
-    // 4+5+6+10+11+12+16+17+18 = 99
-    // row 2
-    // 19+20+21+25+26+27+31+32+33 = 234
-    // 22+23+24+28+29+30+34+35+36 = 243
-    let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
-
-    let grads = loss.backward()?;
-
-    let grad_x = grads.get(&x).context("no grad for x")?;
-    assert_eq!(
-        test_utils::to_vec2_round(&grad_x.flatten(0, 2)?, 4)?,
-        [[72_f32, 99.], [234., 261.]]
-    );
-
-    // manually checked: see comments
-    let x = Var::new(&[[[[1f32, 2.], [4., 5.]], [[6f32, 7.], [8., 9.]]]], device)?;
-
-    let y = x.interpolate2d(4, 4)?.reshape(32)?;
-
-    #[rustfmt::skip]
-    let z = Tensor::new(
-        &[
-            1_f32, 02., 03., 04.,
-            05., 06., 07., 08.,
-            09., 10., 11., 12.,
-            13., 14., 15., 16.,
-            17., 18., 19., 20.,
-            21., 22., 23., 24.,
-            25., 26., 27., 28.,
-            29., 30., 31., 32.
-        ],
-        device,
-    )?;
-    // gradient should be
-    // m1r1
-    // 1+2+5+6=14
-    // 3+4+7+8=22
-    // m1r2
-    // 9+10+13+14=46
-    // 11+12+15+16=54
-    // m2r1
-    // 17+18+21+22=78
-    // 19+20+23+24=86
-    // m2r2
-    // 25+26+29+30=110
-    // 27+28+31+32=118
-    let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
-
-    let grads = loss.backward()?;
-
-    let grad_x = grads.get(&x).context("no grad for x")?;
-
-    assert_eq!(
-        test_utils::to_vec3_round(&grad_x.flatten(0, 1)?, 4)?,
-        [[[14_f32, 22.], [46., 54.]], [[78., 86.], [110., 118.]]]
-    );
-
-    // manually checked: see comments
-    let x = Var::new(
-        &[[[[1f32, 2.], [4., 5.]]], [[[6f32, 7.], [8., 9.]]]],
-        device,
-    )?;
-
-    let y = x.interpolate2d(4, 4)?.reshape(32)?;
-
-    #[rustfmt::skip]
-    let z = Tensor::new(
-        &[
-            1_f32, 02., 03., 04.,
-            05., 06., 07., 08.,
-            09., 10., 11., 12.,
-            13., 14., 15., 16.,
-            17., 18., 19., 20.,
-            21., 22., 23., 24.,
-            25., 26., 27., 28.,
-            29., 30., 31., 32.
-        ],
-        device,
-    )?;
-    // gradient should be
-    // m1r1
-    // 1+2+5+6=14
-    // 3+4+7+8=22
-    // m1r2
-    // 9+10+13+14=46
-    // 11+12+15+16=54
-    // m2r1
-    // 17+18+21+22=78
-    // 19+20+23+24=86
-    // m2r2
-    // 25+26+29+30=110
-    // 27+28+31+32=118
-    let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
-
-    let grads = loss.backward()?;
-
-    let grad_x = grads.get(&x).context("no grad for x")?;
-
-    assert_eq!(
-        test_utils::to_vec3_round(&grad_x.flatten(0, 1)?, 4)?,
-        [[[14_f32, 22.], [46., 54.]], [[78., 86.], [110., 118.]]]
-    );
     Ok(())
 }
 
@@ -505,29 +237,9 @@ fn binary_grad(device: &Device) -> Result<()> {
     Ok(())
 }
 
-test_device!(
-    simple_grad,
-    simple_grad_cpu,
-    simple_grad_gpu,
-    simple_grad_metal
-);
-test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu, sum_grad_metal);
-test_device!(
-    matmul_grad,
-    matmul_grad_cpu,
-    matmul_grad_gpu,
-    matmul_grad_metal
-);
-test_device!(
-    grad_descent,
-    grad_descent_cpu,
-    grad_descent_gpu,
-    grad_descent_metal
-);
-test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu, unary_grad_metal);
-test_device!(
-    binary_grad,
-    binary_grad_cpu,
-    binary_grad_gpu,
-    binary_grad_metal
-);
+test_device!(simple_grad, simple_grad_cpu, simple_grad_gpu);
+test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu);
+test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu);
+test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu);
+test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu);
+test_device!(binary_grad, binary_grad_cpu, binary_grad_gpu);
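The `// exp(x)^2 = exp(2*x)` comment in the hunk above explains the expected gradient: with y = exp(x)^2 = exp(2x), the derivative of sum(y) is 2*exp(2x) = 2*y, which is why the asserted gradient [806.8576, ...] is exactly twice the asserted y [403.4288, ...]. A minimal standalone check of that identity, assuming the same x values as the surrounding test and candle_core as a dependency:

use candle_core::{Device, Result, Var};

fn exp_sqr_grad_is_twice_y() -> Result<()> {
    let x = Var::new(&[3f32, 1., 4., 0.15], &Device::Cpu)?;
    let y = x.exp()?.sqr()?;
    let grads = y.backward()?;
    let grad_x = grads.get(&x).expect("no grad for x");
    // d/dx exp(2x) = 2 exp(2x), i.e. the gradient is 2 * y elementwise.
    let two_y = (&y * 2.)?;
    let diff = (grad_x - &two_y)?.abs()?.sum_all()?.to_vec0::<f32>()?;
    assert!(diff < 1e-2); // equal up to float rounding
    Ok(())
}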
|
@@ -91,32 +91,3 @@ fn index_3d() -> Result<()> {
     assert_eq!(tensor.i((1, .., 3))?.to_vec1::<u32>()?, &[15, 19, 23]);
     Ok(())
 }
-
-#[test]
-fn slice_assign() -> Result<()> {
-    let dev = Device::Cpu;
-
-    let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?;
-    let src = Tensor::arange(0u32, 2 * 3, &dev)?.reshape((3, 2))?;
-    let out = tensor.slice_assign(&[1..4, 3..5], &src)?;
-    assert_eq!(
-        out.to_vec2::<u32>()?,
-        &[
-            [0, 1, 2, 3, 4],
-            [5, 6, 7, 0, 1],
-            [10, 11, 12, 2, 3],
-            [15, 16, 17, 4, 5]
-        ]
-    );
-    let out = tensor.slice_assign(&[0..3, 0..2], &src)?;
-    assert_eq!(
-        out.to_vec2::<u32>()?,
-        &[
-            [0, 1, 2, 3, 4],
-            [2, 3, 7, 8, 9],
-            [4, 5, 12, 13, 14],
-            [15, 16, 17, 18, 19]
-        ]
-    );
-    Ok(())
-}
|
@@ -49,7 +49,7 @@ fn contiguous(device: &Device) -> Result<()> {
     Ok(())
 }
 
-test_device!(contiguous, contiguous_cpu, contiguous_gpu, contiguous_metal);
+test_device!(contiguous, contiguous_cpu, contiguous_gpu);
 
 #[test]
 fn strided_blocks() -> Result<()> {
@@ -88,7 +88,7 @@ fn strided_blocks() -> Result<()> {
         }
     };
     let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
-    let tensor = tensor.i((.., 1))?.contiguous()?;
+    let tensor = tensor.i((.., 1))?;
     match tensor.strided_blocks() {
         candle::StridedBlocks::SingleBlock { start_offset, len } => {
             assert_eq!(start_offset, 0);
@@ -100,20 +100,6 @@ fn strided_blocks() -> Result<()> {
         }
     };
     let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
-    let tensor = tensor.i((.., 1))?;
-    match tensor.strided_blocks() {
-        candle::StridedBlocks::SingleBlock { .. } => {
-            panic!("unexpected block structure")
-        }
-        candle::StridedBlocks::MultipleBlocks {
-            block_len,
-            block_start_index,
-        } => {
-            assert_eq!(block_len, 4);
-            assert_eq!(block_start_index.collect::<Vec<_>>(), &[4, 16])
-        }
-    };
-    let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
     match tensor.t()?.strided_blocks() {
         candle::StridedBlocks::SingleBlock { .. } => {
             panic!("unexpected block structure")
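The `&[4, 16]` in the deleted `MultipleBlocks` assertion falls out of row-major stride arithmetic: a (2, 3, 4) tensor has strides [12, 4, 1], so fixing index 1 on the middle dim leaves one contiguous block of 4 elements per outer index, starting at 0*12 + 1*4 = 4 and 1*12 + 1*4 = 16. A small sketch of that computation:

// Block starts after indexing dim 1 of a row-major (2, 3, 4) tensor at 1.
fn block_starts() -> Vec<usize> {
    let strides = [3 * 4, 4, 1]; // row-major strides of shape (2, 3, 4)
    let fixed = 1; // index selected along dim 1
    (0..2)
        .map(|outer| outer * strides[0] + fixed * strides[1])
        .collect()
    // -> [4, 16]; each block then spans 4 consecutive elements (block_len).
}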
|
@@ -1,106 +0,0 @@
-use candle_core::{test_device, DType, Device, IndexOp, Result, Tensor};
-
-fn matmul(device: &Device) -> Result<()> {
-    let data = vec![1.0f32, 2.0, 3.0, 4.0];
-    let a = Tensor::from_slice(&data, (2, 2), device)?;
-    let data = vec![1.0f32, 2.0, 3.0, 4.0];
-    let b = Tensor::from_slice(&data, (2, 2), device)?;
-
-    let c = a.matmul(&b)?;
-    assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);
-
-    let data = vec![1.0f32, 2.0];
-    let a = Tensor::from_slice(&data, (2, 1), device)?;
-    let data = vec![3.0f32, 4.0];
-    let b = Tensor::from_slice(&data, (1, 2), device)?;
-    let c = a.matmul(&b)?;
-    assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);
-
-    let data: Vec<_> = (0..6).map(|i| i as f32).collect();
-    let a = Tensor::from_slice(&data, (2, 3), device)?;
-    let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
-    let b = Tensor::from_slice(&data, (3, 2), device)?;
-    let c = a.matmul(&b)?;
-    assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);
-
-    let data: Vec<_> = (0..12).map(|i| i as f32).collect();
-    let a = Tensor::from_slice(&data, (2, 2, 3), device)?;
-    let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
-    let b = Tensor::from_slice(&data, (2, 3, 2), device)?;
-    let expected = [[[16., 19.], [52., 64.]], [[214., 235.], [304., 334.]]];
-
-    let c = a.matmul(&b)?;
-    assert_eq!(c.to_vec3::<f32>()?, &expected);
-
-    // Also perform the matmul on contiguous transposed versions.
-    let a_tt = a.t()?.contiguous()?.t()?;
-    assert!(!a_tt.is_contiguous());
-    assert_eq!(a.dims(), a_tt.dims());
-    assert_eq!(a_tt.stride(), &[6, 1, 2]);
-
-    let b_tt = b.t()?.contiguous()?.t()?;
-    assert!(!b_tt.is_contiguous());
-    assert_eq!(b.dims(), b_tt.dims());
-    assert_eq!(b_tt.stride(), &[6, 1, 3]);
-
-    assert_eq!(a_tt.matmul(&b)?.to_vec3::<f32>()?, &expected);
-    assert_eq!(a.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
-    assert_eq!(a_tt.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
-    Ok(())
-}
-
-fn broadcast_matmul(device: &Device) -> Result<()> {
-    let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?;
-    let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?;
-    let out = lhs.broadcast_matmul(&rhs)?;
-    assert_eq!(out.dims(), &[3, 6, 4, 2]);
-    for idx1 in 0..3 {
-        for idx2 in 0..6 {
-            let out = out.i((idx1, idx2))?;
-            let lhs = lhs.i((idx1, 0))?;
-            let rhs = rhs.i(idx2)?;
-            let out2 = lhs.matmul(&rhs);
-            let sum_diff2 = (out - out2)?.sqr()?.sum_all()?;
-            // With cuda, we see errors of up to ~1e-12.
-            assert!(sum_diff2.to_vec0::<f32>()? < 1e-6)
-        }
-    }
-    Ok(())
-}
-
-// https://github.com/huggingface/candle/issues/1948
-fn squeeze_mm(device: &Device) -> Result<()> {
-    let seq_len = 8_usize;
-    let a = Tensor::zeros((1, seq_len, 16), DType::F32, device)?;
-    let x = a.i((.., seq_len - 1, ..))?;
-    let w = Tensor::zeros((32, 16), DType::F32, device)?.t()?;
-    let x = x.matmul(&w)?;
-    assert_eq!(x.dims(), &[1, 32]);
-    Ok(())
-}
-
-// https://github.com/huggingface/candle/issues/1992
-fn mm_layout(device: &Device) -> Result<()> {
-    let a = Tensor::arange(0f32, 16f32, device)?.reshape((1, 1, 4, 4))?;
-    let b = Tensor::arange(0f32, 8f32, device)?.reshape((1, 1, 4, 2))?;
-    let mm1 = a.matmul(&b)?;
-    // Forces the layout to be:
-    // shape: [1, 1, 4, 2], stride: [8, 2, 2, 1], start_offset: 0
-    // This is still a contiguous matrix but matmul checks are only the two last dimensions have
-    // non 1 sizes but matmul check may be reluctant to handle it.
-    let b = b.transpose(1, 2)?.force_contiguous()?.transpose(1, 2)?;
-    let mm2 = a.matmul(&b)?;
-    let diff = (mm1 - mm2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
-    assert_eq!(diff, 0.);
-    Ok(())
-}
-
-test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
-test_device!(
-    broadcast_matmul,
-    broadcast_matmul_cpu,
-    broadcast_matmul_gpu,
-    broadcast_matmul_metal
-);
-test_device!(squeeze_mm, squeeze_mm_cpu, squeeze_mm_gpu, squeeze_mm_metal);
-test_device!(mm_layout, mm_layout_cpu, mm_layout_gpu, mm_layout_metal);
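The deleted `broadcast_matmul` test checks a (3, 1, 4, 5) lhs against a (6, 5, 2) rhs and expects a (3, 6, 4, 2) output: the trailing two dims contract like an ordinary (m, k) x (k, n) matmul while the remaining batch dims broadcast pairwise, with missing dims treated as 1. Below is a sketch of that shape rule, written for this note rather than taken from candle's internals:

// Output shape of a broadcasting matmul, mirroring what the deleted
// per-index loop verifies: batch dims broadcast, (m, k) x (k, n) -> (m, n).
fn broadcast_matmul_dims(lhs: &[usize], rhs: &[usize]) -> Option<Vec<usize>> {
    let (m, k1) = (lhs[lhs.len() - 2], lhs[lhs.len() - 1]);
    let (k2, n) = (rhs[rhs.len() - 2], rhs[rhs.len() - 1]);
    if k1 != k2 {
        return None; // inner dimensions must match
    }
    let (lb, rb) = (&lhs[..lhs.len() - 2], &rhs[..rhs.len() - 2]);
    let len = lb.len().max(rb.len());
    let mut out = vec![0; len];
    for i in 0..len {
        // Right-aligned broadcasting: a missing dim counts as 1.
        let a = if i < lb.len() { lb[lb.len() - 1 - i] } else { 1 };
        let b = if i < rb.len() { rb[rb.len() - 1 - i] } else { 1 };
        out[len - 1 - i] = match (a, b) {
            (a, b) if a == b => a,
            (1, b) => b,
            (a, 1) => a,
            _ => return None, // incompatible batch dims
        };
    }
    out.extend_from_slice(&[m, n]);
    Some(out)
}
// broadcast_matmul_dims(&[3, 1, 4, 5], &[6, 5, 2]) == Some(vec![3, 6, 4, 2])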
|
@@ -1,9 +0,0 @@
-import numpy as np
-x = np.arange(10)
-
-# Write a npy file.
-np.save("test.npy", x)
-
-# Write multiple values to a npz file.
-values = { "x": x, "x_plus_one": x + 1 }
-np.savez("test.npz", **values)
|
@@ -43,9 +43,6 @@ res = torch.nn.functional.avg_pool2d(t, 2)
 print(res)
 */
 fn avg_pool2d_pytorch(dev: &Device) -> Result<()> {
-    if dev.is_metal() {
-        return Ok(());
-    }
     let t = Tensor::new(
         &[
             0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
@@ -101,17 +98,15 @@ fn upsample_nearest2d(dev: &Device) -> Result<()> {
     Ok(())
 }
 
-test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu, avg_pool2d_metal);
+test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu);
 test_device!(
     avg_pool2d_pytorch,
     avg_pool2d_pytorch_cpu,
-    avg_pool2d_pytorch_gpu,
-    avg_pool2d_pytorch_metal
+    avg_pool2d_pytorch_gpu
 );
-test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu, max_pool2d_metal);
+test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu);
 test_device!(
     upsample_nearest2d,
     upsample_nearest2d_cpu,
-    upsample_nearest2d_gpu,
-    upsample_nearest2d_metal
+    upsample_nearest2d_gpu
 );
|
@@ -1,37 +0,0 @@
-import torch
-from collections import OrderedDict
-
-# Write a trivial tensor to a pt file
-a= torch.tensor([[1,2,3,4], [5,6,7,8]])
-o = OrderedDict()
-o["test"] = a
-
-# Write a trivial tensor to a pt file
-torch.save(o, "test.pt")
-
-############################################################################################################
-# Write a trivial tensor to a pt file with a key
-torch.save({"model_state_dict": o}, "test_with_key.pt")
-
-############################################################################################################
-# Create a tensor with fortran contiguous memory layout
-import numpy as np
-
-# Step 1: Create a 3D NumPy array with Fortran order using a range of numbers
-# For example, creating a 2x3x4 array
-array_fortran = np.asfortranarray(np.arange(1, 2*3*4 + 1).reshape(2, 3, 4))
-
-# Verify the memory order
-print("Is Fortran contiguous (F order):", array_fortran.flags['F_CONTIGUOUS'])  # Should be True
-print("Is C contiguous (C order):", array_fortran.flags['C_CONTIGUOUS'])  # Should be False
-
-# Step 2: Convert the NumPy array to a PyTorch tensor
-tensor_fortran = torch.from_numpy(array_fortran)
-
-# Verify the tensor layout
-print("Tensor stride:", tensor_fortran.stride())  # Stride will reflect the Fortran memory layout
-
-# Step 3: Save the PyTorch tensor to a .pth file
-torch.save({"tensor_fortran": tensor_fortran}, 'fortran_tensor_3d.pth')
-
-print("3D Tensor saved with Fortran layout.")
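The point of the removed script is to produce a tensor whose memory layout is column-major so the Rust loader gets exercised on non-C strides. For a (2, 3, 4) array, C order gives strides [12, 4, 1] (in elements) while Fortran order gives [1, 2, 6], which is what the `tensor_fortran.stride()` print shows. A small stride calculator illustrating both conventions (plain Rust, no candle dependency):

// Row-major (C) strides: the rightmost dim varies fastest.
fn c_strides(shape: &[usize]) -> Vec<usize> {
    let mut s = vec![1; shape.len()];
    for i in (0..shape.len().saturating_sub(1)).rev() {
        s[i] = s[i + 1] * shape[i + 1];
    }
    s
}

// Column-major (Fortran) strides: the leftmost dim varies fastest.
fn f_strides(shape: &[usize]) -> Vec<usize> {
    let mut s = vec![1; shape.len()];
    for i in 1..shape.len() {
        s[i] = s[i - 1] * shape[i - 1];
    }
    s
}
// c_strides(&[2, 3, 4]) == [12, 4, 1]; f_strides(&[2, 3, 4]) == [1, 2, 6]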
|
@@ -1,31 +0,0 @@
-/// Regression test for pth files not loading on Windows.
-#[test]
-fn test_pth() {
-    let tensors = candle_core::pickle::PthTensors::new("tests/test.pt", None).unwrap();
-    tensors.get("test").unwrap().unwrap();
-}
-
-#[test]
-fn test_pth_with_key() {
-    let tensors =
-        candle_core::pickle::PthTensors::new("tests/test_with_key.pt", Some("model_state_dict"))
-            .unwrap();
-    tensors.get("test").unwrap().unwrap();
-}
-
-#[test]
-fn test_pth_fortran_congiguous() {
-    let tensors =
-        candle_core::pickle::PthTensors::new("tests/fortran_tensor_3d.pth", None).unwrap();
-    let tensor = tensors.get("tensor_fortran").unwrap().unwrap();
-
-    assert_eq!(tensor.dims3().unwrap(), (2, 3, 4));
-
-    assert_eq!(
-        tensor.to_vec3::<i64>().unwrap(),
-        [
-            [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
-            [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]
-        ]
-    );
-}
File diff suppressed because it is too large
@@ -1,71 +0,0 @@
-use candle_core::{DType, Result, Tensor};
-
-struct TmpFile(std::path::PathBuf);
-
-impl TmpFile {
-    fn create(base: &str) -> TmpFile {
-        let filename = std::env::temp_dir().join(format!(
-            "candle-{}-{}-{:?}",
-            base,
-            std::process::id(),
-            std::thread::current().id(),
-        ));
-        TmpFile(filename)
-    }
-}
-
-impl std::convert::AsRef<std::path::Path> for TmpFile {
-    fn as_ref(&self) -> &std::path::Path {
-        self.0.as_path()
-    }
-}
-
-impl Drop for TmpFile {
-    fn drop(&mut self) {
-        std::fs::remove_file(&self.0).unwrap()
-    }
-}
-
-#[test]
-fn npy() -> Result<()> {
-    let npy = Tensor::read_npy("tests/test.npy")?;
-    assert_eq!(
-        npy.to_dtype(DType::U8)?.to_vec1::<u8>()?,
-        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
-    );
-    Ok(())
-}
-
-#[test]
-fn npz() -> Result<()> {
-    let npz = Tensor::read_npz("tests/test.npz")?;
-    assert_eq!(npz.len(), 2);
-    assert_eq!(npz[0].0, "x");
-    assert_eq!(npz[1].0, "x_plus_one");
-    assert_eq!(
-        npz[1].1.to_dtype(DType::U8)?.to_vec1::<u8>()?,
-        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
-    );
-    Ok(())
-}
-
-#[test]
-fn safetensors() -> Result<()> {
-    use candle_core::safetensors::Load;
-
-    let tmp_file = TmpFile::create("st");
-    let t = Tensor::arange(0f32, 24f32, &candle_core::Device::Cpu)?;
-    t.save_safetensors("t", &tmp_file)?;
-    // Load from file.
-    let st = candle_core::safetensors::load(&tmp_file, &candle_core::Device::Cpu)?;
-    let t2 = st.get("t").unwrap();
-    let diff = (&t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
-    assert_eq!(diff, 0f32);
-    // Load from bytes.
-    let bytes = std::fs::read(tmp_file)?;
-    let st = candle_core::safetensors::SliceSafetensors::new(&bytes)?;
-    let t2 = st.get("t").unwrap().load(&candle_core::Device::Cpu);
-    let diff = (&t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
-    assert_eq!(diff, 0f32);
-    Ok(())
-}
@@ -1,4 +1,4 @@
-use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor, D};
+use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor};
 
 fn zeros(device: &Device) -> Result<()> {
     let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
@@ -29,34 +29,7 @@ fn ones(device: &Device) -> Result<()> {
         Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
         [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
     );
-    Ok(())
-}
-
-fn full(device: &Device) -> Result<()> {
-    assert_eq!(
-        Tensor::full(42u32, (2, 3), device)?.to_vec2::<u32>()?,
-        [[42, 42, 42], [42, 42, 42]],
-    );
-    Ok(())
-}
-
-fn arange(device: &Device) -> Result<()> {
-    assert_eq!(
-        Tensor::arange(0u8, 5u8, device)?.to_vec1::<u8>()?,
-        [0, 1, 2, 3, 4],
-    );
-    assert_eq!(
-        Tensor::arange_step(0u8, 5u8, 2, device)?.to_vec1::<u8>()?,
-        [0, 2, 4],
-    );
-    assert_eq!(
-        Tensor::arange_step(0u8, 5u8, 3, device)?.to_vec1::<u8>()?,
-        [0, 3],
-    );
-    assert_eq!(
-        Tensor::arange_step(5i64, 0i64, -1, device)?.to_vec1::<i64>()?,
-        [5, 4, 3, 2, 1],
-    );
     Ok(())
 }
 
@@ -96,40 +69,6 @@ fn clamp(device: &Device) -> Result<()> {
     Ok(())
 }
 
-fn asort(device: &Device) -> Result<()> {
-    let data = &[[3f32, 1., 4., 1.1, 5.], [2.1, 1., 7., 8., 2.]];
-    let tensor = Tensor::new(data, device)?;
-    let indexes = tensor.arg_sort_last_dim(true)?;
-    assert_eq!(
-        indexes.to_vec2::<u32>()?,
-        [[1, 3, 0, 2, 4], [1, 4, 0, 2, 3]],
-    );
-    let indexes = tensor.arg_sort_last_dim(false)?;
-    assert_eq!(
-        indexes.to_vec2::<u32>()?,
-        [[4, 2, 0, 3, 1], [3, 2, 0, 4, 1]],
-    );
-    let (sorted, indexes) = tensor.sort_last_dim(true)?;
-    assert_eq!(
-        indexes.to_vec2::<u32>()?,
-        [[1, 3, 0, 2, 4], [1, 4, 0, 2, 3]],
-    );
-    assert_eq!(
-        sorted.to_vec2::<f32>()?,
-        [[1.0, 1.1, 3.0, 4.0, 5.0], [1.0, 2.0, 2.1, 7.0, 8.0]]
-    );
-    let (sorted, indexes) = tensor.sort_last_dim(false)?;
-    assert_eq!(
-        indexes.to_vec2::<u32>()?,
-        [[4, 2, 0, 3, 1], [3, 2, 0, 4, 1]],
-    );
-    assert_eq!(
-        sorted.to_vec2::<f32>()?,
-        [[5.0, 4.0, 3.0, 1.1, 1.0], [8.0, 7.0, 2.1, 2.0, 1.0]]
-    );
-    Ok(())
-}
-
 fn unary_op(device: &Device) -> Result<()> {
     let data = &[[-3f32, 1., 4., -0.1, 0.5], [2.7, -1.8, -0.28, 1.8, 2.8]];
     let tensor = Tensor::new(data, device)?;
@@ -140,9 +79,6 @@ fn unary_op(device: &Device) -> Result<()> {
             [2.6911, -0.0647, -0.1091, 1.7353, 2.7933]
         ]
     );
-    let t_f16 = tensor.to_dtype(DType::F16)?.gelu()?.to_dtype(DType::F32)?;
-    let max_diff = (tensor.gelu()? - t_f16)?.flatten_all()?.max(0)?;
-    assert!(max_diff.to_vec0::<f32>()? < 5e-3);
     assert_eq!(
         test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?,
         [
@@ -157,13 +93,6 @@ fn unary_op(device: &Device) -> Result<()> {
             [0.9999, -0.9891, -0.3079, 0.9891, 0.9999]
         ]
     );
-    assert_eq!(
-        test_utils::to_vec2_round(&tensor.silu()?, 4)?,
-        [
-            [-0.1423, 0.7311, 3.9281, -0.0475, 0.3112],
-            [2.53, -0.2553, -0.1205, 1.5447, 2.6395]
-        ]
-    );
     assert_eq!(
         test_utils::to_vec2_round(&tensor.ceil()?, 4)?,
         [[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -1.0, -0.0, 2.0, 3.0]]
@@ -185,14 +114,6 @@ fn unary_op(device: &Device) -> Result<()> {
         test_utils::to_vec1_round(&tensor.round_to(-2)?, 4)?,
         [3000.0, 300.]
     );
-    let tensor = Tensor::new(
-        &[-1.01f32, -0.9, -0.1, 0.0, -0.0, 0.1, 0.9, 1.0, 1.1],
-        device,
-    )?;
-    assert_eq!(
-        tensor.sign()?.to_vec1::<f32>()?,
-        [-1., -1., -1., 0., 0., 1., 1., 1., 1.]
-    );
     Ok(())
 }
 
@@ -240,22 +161,6 @@ fn transpose(device: &Device) -> Result<()> {
     Ok(())
 }
 
-fn var(device: &Device) -> Result<()> {
-    // Values taken from https://pytorch.org/docs/stable/generated/torch.var.html
-    let data = &[
-        [0.2035f32, 1.2959, 1.8101, -0.4644],
-        [1.5027, -0.3270, 0.5905, 0.6538],
-        [-1.5745, 1.3330, -0.5596, -0.6548],
-        [0.1264, -0.5080, 1.6420, 0.1992],
-    ];
-    let tensor = Tensor::new(data, device)?;
-    assert_eq!(
-        test_utils::to_vec2_round(&tensor.var_keepdim(1)?, 4)?,
-        &[[1.0631], [0.559], [1.4893], [0.8258]]
-    );
-    Ok(())
-}
-
 fn sum(device: &Device) -> Result<()> {
     let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
     let tensor = Tensor::new(data, device)?;
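The removed `var` test pins `var_keepdim` to PyTorch's `torch.var` defaults, i.e. the unbiased estimator that divides by n - 1. The expected 1.0631 for the first row can be reproduced by hand with a few lines:

// Unbiased sample variance, matching torch.var's default correction of 1.
fn unbiased_var(xs: &[f32]) -> f32 {
    let n = xs.len() as f32;
    let mean = xs.iter().sum::<f32>() / n;
    xs.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / (n - 1.)
}
// unbiased_var(&[0.2035, 1.2959, 1.8101, -0.4644]) ≈ 1.0631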
@ -665,30 +570,6 @@ fn broadcast(device: &Device) -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn slice_set(device: &Device) -> Result<()> {
|
|
||||||
let (b, h, max_t, d) = (2, 4, 7, 3);
|
|
||||||
let cache = Tensor::zeros((b, h, max_t, d), DType::F32, device)?;
|
|
||||||
let tensor = Tensor::randn(0f32, 1f32, (b, h, 4, d), device)?;
|
|
||||||
cache.slice_set(&tensor, 2, 0)?;
|
|
||||||
let cache_t = cache.narrow(2, 0, 4)?;
|
|
||||||
let diff = (cache_t - &tensor)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
|
||||||
assert_eq!(diff, 0.);
|
|
||||||
cache.slice_set(&tensor, 2, 1)?;
|
|
||||||
let cache_t = cache.narrow(2, 1, 4)?;
|
|
||||||
let diff = (cache_t - &tensor)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
|
||||||
assert_eq!(diff, 0.);
|
|
||||||
let ones = Tensor::ones((b, h, 1, d), DType::F32, device)?;
|
|
||||||
cache.slice_set(&ones, 2, 6)?;
|
|
||||||
let diff = cache.narrow(2, 5, 1)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
|
||||||
assert_eq!(diff, 0.);
|
|
||||||
let diff = (cache.narrow(2, 6, 1)? - 1.)?
|
|
||||||
.abs()?
|
|
||||||
.sum_all()?
|
|
||||||
.to_vec0::<f32>()?;
|
|
||||||
assert_eq!(diff, 0.);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn cat(device: &Device) -> Result<()> {
|
fn cat(device: &Device) -> Result<()> {
|
||||||
// 1D
|
// 1D
|
||||||
let t1 = Tensor::new(&[3f32, 1., 4.], device)?;
|
let t1 = Tensor::new(&[3f32, 1., 4.], device)?;
|
||||||
@ -741,31 +622,6 @@ fn cat(device: &Device) -> Result<()> {
|
|||||||
[2.0, 7.0, 1.0, 8.0, 2.0, 2.0, 7.0, 1.0, 8.0, 2.0]
|
[2.0, 7.0, 1.0, 8.0, 2.0, 2.0, 7.0, 1.0, 8.0, 2.0]
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
// 3D
|
|
||||||
let t1 = Tensor::arange(0, 48i64, device)?.reshape((2, 6, 4))?;
|
|
||||||
let t2 = Tensor::arange(100, 124i64, device)?.reshape((2, 3, 4))?;
|
|
||||||
let t3 = Tensor::arange(10000, 10032i64, device)?.reshape((2, 4, 4))?;
|
|
||||||
|
|
||||||
let t_cat = Tensor::cat(&[&t1, &t2, &t3], 1)?;
|
|
||||||
|
|
||||||
let t1 = t1.t()?.contiguous()?.t()?;
|
|
||||||
let t2 = t2.t()?.contiguous()?.t()?;
|
|
||||||
let t3 = t3.t()?.contiguous()?.t()?;
|
|
||||||
let t_cat2 = Tensor::cat(&[&t1, &t2, &t3], 1)?;
|
|
||||||
|
|
||||||
let diff = t_cat.eq(&t_cat2)?.to_dtype(DType::F32)?.sum_all()?;
|
|
||||||
assert_eq!(diff.to_vec0::<f32>()?, 104.0);
|
|
||||||
assert_eq!(t_cat.i((0, 0, 0))?.to_vec0::<i64>()?, 0);
|
|
||||||
assert_eq!(t_cat.i((0, 4, 0))?.to_vec0::<i64>()?, 16);
|
|
||||||
assert_eq!(t_cat.i((0, 5, 0))?.to_vec0::<i64>()?, 20);
|
|
||||||
assert_eq!(t_cat.i((1, 5, 0))?.to_vec0::<i64>()?, 44);
|
|
||||||
assert_eq!(t_cat.i((0, 6, 0))?.to_vec0::<i64>()?, 100);
|
|
||||||
assert_eq!(t_cat.i((1, 6, 0))?.to_vec0::<i64>()?, 112);
|
|
||||||
assert_eq!(t_cat.i((0, 6, 1))?.to_vec0::<i64>()?, 101);
|
|
||||||
assert_eq!(t_cat.i((0, 7, 1))?.to_vec0::<i64>()?, 105);
|
|
||||||
assert_eq!(t_cat.i((0, 12, 1))?.to_vec0::<i64>()?, 10013);
|
|
||||||
assert_eq!(t_cat.i((1, 12, 3))?.to_vec0::<i64>()?, 10031);
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -776,8 +632,6 @@ fn embeddings(device: &Device) -> Result<()> {
|
|||||||
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
||||||
let hs = t.index_select(&ids, 0)?;
|
let hs = t.index_select(&ids, 0)?;
|
||||||
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
||||||
let hs = t.index_select(&ids.to_dtype(DType::I64)?, 0)?;
|
|
||||||
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -805,47 +659,44 @@ fn index_select(device: &Device) -> Result<()> {
|
|||||||
[9.0, 10.0, 11.0]
|
[9.0, 10.0, 11.0]
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
for dtype in [DType::U8, DType::U32, DType::I64] {
|
let hs = t.index_select(&ids, 1)?;
|
||||||
let ids = ids.to_dtype(dtype)?;
|
assert_eq!(
|
||||||
let hs = t.index_select(&ids, 1)?;
|
hs.to_vec2::<f32>()?,
|
||||||
assert_eq!(
|
&[
|
||||||
hs.to_vec2::<f32>()?,
|
[0.0, 2.0, 1.0],
|
||||||
&[
|
[3.0, 5.0, 4.0],
|
||||||
[0.0, 2.0, 1.0],
|
[6.0, 8.0, 7.0],
|
||||||
[3.0, 5.0, 4.0],
|
[9.0, 11.0, 10.0]
|
||||||
[6.0, 8.0, 7.0],
|
]
|
||||||
[9.0, 11.0, 10.0]
|
);
|
||||||
]
|
let hs = t.index_select(&ids, 0)?;
|
||||||
);
|
assert_eq!(
|
||||||
let hs = t.index_select(&ids, 0)?;
|
hs.to_vec2::<f32>()?,
|
||||||
assert_eq!(
|
&[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
|
||||||
hs.to_vec2::<f32>()?,
|
);
|
||||||
&[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
|
// Prior to https://github.com/huggingface/candle/pull/1022
|
||||||
);
|
// There would be a bug where the last values in the result tensor would be set to 0.
|
||||||
// Prior to https://github.com/huggingface/candle/pull/1022
|
let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
|
||||||
// There would be a bug where the last values in the result tensor would be set to 0.
|
let hs = t.index_select(&ids, 0)?;
|
||||||
let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
|
assert_eq!(
|
||||||
let hs = t.index_select(&ids, 0)?;
|
hs.to_vec2::<f32>()?,
|
||||||
assert_eq!(
|
&[
|
||||||
hs.to_vec2::<f32>()?,
|
[0.0, 1.0, 2.0],
|
||||||
&[
|
[6.0, 7.0, 8.0],
|
||||||
[0.0, 1.0, 2.0],
|
[3.0, 4.0, 5.0],
|
||||||
[6.0, 7.0, 8.0],
|
[0.0, 1.0, 2.0],
|
||||||
[3.0, 4.0, 5.0],
|
[6.0, 7.0, 8.0],
|
||||||
[0.0, 1.0, 2.0],
|
[3.0, 4.0, 5.0],
|
||||||
[6.0, 7.0, 8.0],
|
]
|
||||||
[3.0, 4.0, 5.0],
|
);
|
||||||
]
|
|
||||||
);
|
|
||||||
|
|
||||||
// Test when selecting dim > 0 with ids size different from elem count of
|
// Test when selecting dim > 0 with ids size different from elem count of
|
||||||
// target dim in source/input.
|
// target dim in source/input.
|
||||||
let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
|
let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
|
||||||
let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
|
let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
|
||||||
assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
|
assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
|
||||||
let hs = t.index_select(&ids, 1)?;
|
let hs = t.index_select(&ids, 1)?;
|
||||||
assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
|
assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -1007,6 +858,74 @@ fn gather(device: &Device) -> Result<()> {
     Ok(())
 }

+fn matmul(device: &Device) -> Result<()> {
+    let data = vec![1.0f32, 2.0, 3.0, 4.0];
+    let a = Tensor::from_slice(&data, (2, 2), device)?;
+    let data = vec![1.0f32, 2.0, 3.0, 4.0];
+    let b = Tensor::from_slice(&data, (2, 2), device)?;
+
+    let c = a.matmul(&b)?;
+    assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);
+
+    let data = vec![1.0f32, 2.0];
+    let a = Tensor::from_slice(&data, (2, 1), device)?;
+    let data = vec![3.0f32, 4.0];
+    let b = Tensor::from_slice(&data, (1, 2), device)?;
+    let c = a.matmul(&b)?;
+    assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);
+
+    let data: Vec<_> = (0..6).map(|i| i as f32).collect();
+    let a = Tensor::from_slice(&data, (2, 3), device)?;
+    let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
+    let b = Tensor::from_slice(&data, (3, 2), device)?;
+    let c = a.matmul(&b)?;
+    assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);
+
+    let data: Vec<_> = (0..12).map(|i| i as f32).collect();
+    let a = Tensor::from_slice(&data, (2, 2, 3), device)?;
+    let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
+    let b = Tensor::from_slice(&data, (2, 3, 2), device)?;
+    let expected = [[[16., 19.], [52., 64.]], [[214., 235.], [304., 334.]]];
+
+    let c = a.matmul(&b)?;
+    assert_eq!(c.to_vec3::<f32>()?, &expected);
+
+    // Also perform the matmul on contiguous transposed versions.
+    let a_tt = a.t()?.contiguous()?.t()?;
+    assert!(!a_tt.is_contiguous());
+    assert_eq!(a.dims(), a_tt.dims());
+    assert_eq!(a_tt.stride(), &[6, 1, 2]);
+
+    let b_tt = b.t()?.contiguous()?.t()?;
+    assert!(!b_tt.is_contiguous());
+    assert_eq!(b.dims(), b_tt.dims());
+    assert_eq!(b_tt.stride(), &[6, 1, 3]);
+
+    assert_eq!(a_tt.matmul(&b)?.to_vec3::<f32>()?, &expected);
+    assert_eq!(a.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
+    assert_eq!(a_tt.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
+    Ok(())
+}
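The expected constants in the added test can be re-derived with a plain nested-loop matmul. For the 2x3 @ 3x2 case, `a` holds 0..6 and `b` holds 2..8 in row-major order, so a minimal standalone sketch, independent of candle, reproduces the `[[16., 19.], [52., 64.]]` constant:

```rust
// Sanity check for the 2x3 @ 3x2 expectation above:
// a[i][p] = i * 3 + p and b[p][j] = p * 2 + j + 2, both row-major.
fn main() {
    let a: Vec<f32> = (0..6).map(|i| i as f32).collect();
    let b: Vec<f32> = (0..6).map(|i| (i + 2) as f32).collect();
    let (m, k, n) = (2, 3, 2);
    let mut c = vec![0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            for p in 0..k {
                // c[i][j] += a[i][p] * b[p][j] with row-major indexing.
                c[i * n + j] += a[i * k + p] * b[p * n + j];
            }
        }
    }
    assert_eq!(c, vec![16., 19., 52., 64.]);
}
```

For instance, c[0][0] = 0*2 + 1*4 + 2*6 = 16 and c[1][1] = 3*3 + 4*5 + 5*7 = 64.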
+
+fn broadcast_matmul(device: &Device) -> Result<()> {
+    let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?;
+    let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?;
+    let out = lhs.broadcast_matmul(&rhs)?;
+    assert_eq!(out.dims(), &[3, 6, 4, 2]);
+    for idx1 in 0..3 {
+        for idx2 in 0..6 {
+            let out = out.i((idx1, idx2))?;
+            let lhs = lhs.i((idx1, 0))?;
+            let rhs = rhs.i(idx2)?;
+            let out2 = lhs.matmul(&rhs);
+            let sum_diff2 = (out - out2)?.sqr()?.sum_all()?;
+            // With cuda, we see errors of up to ~1e-12.
+            assert!(sum_diff2.to_vec0::<f32>()? < 1e-6)
+        }
+    }
+    Ok(())
+}

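On the shapes in `broadcast_matmul`: the lhs (3, 1, 4, 5) is a (3, 1) batch of 4x5 matrices and the rhs (6, 5, 2) a (6,) batch of 5x2 matrices. The batch dimensions (3, 1) and (6) broadcast to (3, 6), and each 4x5 @ 5x2 product is 4x2, which gives the asserted output shape (3, 6, 4, 2). The nested loops then check every broadcast pair against a plain matmul of the corresponding slices, comparing via a summed squared difference rather than exact equality.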
 fn broadcasting(device: &Device) -> Result<()> {
     let t1 = Tensor::arange(0f32, 24f32, device)?.reshape((4, 2, 3))?;
     let t2 = Tensor::new(&[100f32, 200f32], device)?;
@@ -1111,108 +1030,38 @@ fn broadcasting(device: &Device) -> Result<()> {
 fn randn(device: &Device) -> Result<()> {
     let tensor = Tensor::randn(0f32, 1f32, (5, 3), device)?;
     assert_eq!(tensor.dims(), [5, 3]);
-    // Check that the seed gets updated by checking that
-    // a new series of numbers is generated each time
-    let tensor2 = Tensor::randn(0f32, 1f32, (5, 3), device)?;
-    assert_ne!(tensor.to_vec2::<f32>()?, tensor2.to_vec2::<f32>()?);
     let tensor = Tensor::rand(0f32, 1f32, (5, 3), device)?;
     assert_eq!(tensor.dims(), [5, 3]);
-    // Check that the seed gets updated by checking that
-    // a new series of numbers is generated each time
-    let tensor2 = Tensor::rand(0f32, 1f32, (5, 3), device)?;
-    assert_ne!(tensor.to_vec2::<f32>()?, tensor2.to_vec2::<f32>()?);
-    // We do not expect deterministic elements at any index.
-    // There once was a bug that had a deterministic zero element in evenly sized tensors.
-    const N: usize = 2;
-    let v = (0..100)
-        .map(|_| Tensor::randn(0f32, 1f32, N, device).and_then(|t| t.to_vec1::<f32>()))
-        .collect::<Result<Vec<_>>>()?;
-    assert!(
-        (0..N).all(|i| v.windows(2).any(|pair| pair[0][i] != pair[1][i])),
-        "There are deterministic values in the randn tensors"
-    );
-    let v = (0..100)
-        .map(|_| Tensor::rand(0f32, 1f32, N, device).and_then(|t| t.to_vec1::<f32>()))
-        .collect::<Result<Vec<_>>>()?;
-    assert!(
-        (0..N).all(|i| v.windows(2).any(|pair| pair[0][i] != pair[1][i])),
-        "There are deterministic values in the rand tensors"
-    );
     Ok(())
 }
 
-fn zero_dim(device: &Device) -> Result<()> {
-    let t = Tensor::zeros((4, 0, 1), DType::F32, device)?;
-    assert_eq!(t.dims3()?, (4, 0, 1));
-    let t2 = Tensor::zeros((4, 3, 1), DType::F32, device)?;
-    let t_cat = Tensor::cat(&[&t, &t2], 1)?;
-    assert_eq!(t_cat.dims3()?, (4, 3, 1));
-    let t_cat = Tensor::cat(&[&t, &t], 1)?;
-    assert_eq!(t_cat.dims3()?, (4, 0, 1));
-    let t_unary = t.sqrt()?;
-    assert_eq!(t_unary.dims3()?, (4, 0, 1));
-    let t_plus = (&t + 1.)?;
-    assert_eq!(t_plus.dims3()?, (4, 0, 1));
-    let t_mm = t2.matmul(&t.t()?)?;
-    assert_eq!(t_mm.dims3()?, (4, 3, 0));
-    let t_mm = t.matmul(&t2.t()?)?;
-    assert_eq!(t_mm.dims3()?, (4, 0, 3));
-    let t_mm = t.t()?.matmul(&t)?;
-    assert_eq!(t_mm.dims3()?, (4, 1, 1));
-    Ok(())
-}
-
-test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
-test_device!(ones, ones_cpu, ones_gpu, ones_metal);
-test_device!(full, full_cpu, full_gpu, full_metal);
-test_device!(arange, arange_cpu, arange_gpu, arange_metal);
-test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
-test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
-test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal);
-test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal);
-test_device!(slice_set, ss_cpu, ss_gpu, ss_metal);
-test_device!(cat, cat_cpu, cat_gpu, cat_metal);
-test_device!(sum, sum_cpu, sum_gpu, sum_metal);
-test_device!(min, min_cpu, min_gpu, min_metal);
-test_device!(max, max_cpu, max_gpu, max_metal);
-test_device!(argmax, argmax_cpu, argmax_gpu, argmax_metal);
-test_device!(argmin, argmin_cpu, argmin_gpu, argmin_metal);
-test_device!(transpose, transpose_cpu, transpose_gpu, transpose_metal);
-test_device!(unary_op, unary_op_cpu, unary_op_gpu, unary_op_metal);
-test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal);
-test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal);
-test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal);
-test_device!(
-    broadcasting,
-    broadcasting_cpu,
-    broadcasting_gpu,
-    broadcasting_metal
-);
-test_device!(
-    index_select,
-    index_select_cpu,
-    index_select_gpu,
-    index_select_metal
-);
-test_device!(index_add, index_add_cpu, index_add_gpu, index_add_metal);
-test_device!(gather, gather_cpu, gather_gpu, gather_metal);
-test_device!(
-    scatter_add,
-    scatter_add_cpu,
-    scatter_add_gpu,
-    scatter_add_metal
-);
-test_device!(
-    slice_scatter,
-    slice_scatter_cpu,
-    slice_scatter_gpu,
-    slice_scatter_metal
-);
-test_device!(randn, randn_cpu, randn_gpu, randn_metal);
-test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
-test_device!(asort, asort_cpu, asort_gpu, asort_metal);
-test_device!(var, var_cpu, var_gpu, var_metal);
-test_device!(zero_dim, zero_dim_cpu, zero_dim_gpu, zero_dim_metal);
+test_device!(zeros, zeros_cpu, zeros_gpu);
+test_device!(ones, ones_cpu, ones_gpu);
+test_device!(add_mul, add_mul_cpu, add_mul_gpu);
+test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
+test_device!(narrow, narrow_cpu, narrow_gpu);
+test_device!(broadcast, broadcast_cpu, broadcast_gpu);
+test_device!(cat, cat_cpu, cat_gpu);
+test_device!(sum, sum_cpu, sum_gpu);
+test_device!(min, min_cpu, min_gpu);
+test_device!(max, max_cpu, max_gpu);
+test_device!(argmax, argmax_cpu, argmax_gpu);
+test_device!(argmin, argmin_cpu, argmin_gpu);
+test_device!(transpose, transpose_cpu, transpose_gpu);
+test_device!(unary_op, unary_op_cpu, unary_op_gpu);
+test_device!(binary_op, binary_op_cpu, binary_op_gpu);
+test_device!(embeddings, embeddings_cpu, embeddings_gpu);
+test_device!(cmp, cmp_cpu, cmp_gpu);
+test_device!(matmul, matmul_cpu, matmul_gpu);
+test_device!(broadcast_matmul, broadcast_matmul_cpu, broadcast_matmul_gpu);
+test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu);
+test_device!(index_select, index_select_cpu, index_select_gpu);
+test_device!(index_add, index_add_cpu, index_add_gpu);
+test_device!(gather, gather_cpu, gather_gpu);
+test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
+test_device!(slice_scatter, slice_scatter_cpu, slice_scatter_gpu);
+test_device!(randn, randn_cpu, randn_gpu);
+test_device!(clamp, clamp_cpu, clamp_gpu);
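Both columns of this change rely on the `test_device!` macro, which fans a single `fn test(device: &Device)` body out into one `#[test]` per backend; the 0.5.1 side simply takes a fourth identifier for the metal variant. The exact definition lives in candle's test utilities and may differ, but a minimal sketch of the two-backend shape, assuming the `Device::new_cuda` constructor, looks like this:

```rust
// Hypothetical sketch of a test_device!-style macro: one backend-agnostic
// test body, expanded into a #[test] per available device.
macro_rules! test_device {
    ($fn_name:ident, $test_cpu:ident, $test_gpu:ident) => {
        #[test]
        fn $test_cpu() -> Result<()> {
            $fn_name(&Device::Cpu)
        }
        #[cfg(feature = "cuda")]
        #[test]
        fn $test_gpu() -> Result<()> {
            $fn_name(&Device::new_cuda(0)?)
        }
    };
}
```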
 // There was originally a bug on the CPU implementation for randn
 // https://github.com/huggingface/candle/issues/381
@@ -1224,124 +1073,3 @@ fn randn_hasneg() -> Result<()> {
     }
     Ok(())
 }
-
-#[test]
-fn pad_with_same() -> Result<()> {
-    let t = Tensor::arange(1f32, 5f32, &Device::Cpu)?.reshape((2, 2))?;
-    let t0 = t.pad_with_same(0, 1, 2)?;
-    assert_eq!(
-        t0.to_vec2::<f32>()?,
-        [[1.0, 2.0], [1.0, 2.0], [3.0, 4.0], [3.0, 4.0], [3.0, 4.0]]
-    );
-    let t1 = t.pad_with_same(1, 1, 2)?;
-    assert_eq!(
-        t1.to_vec2::<f32>()?,
-        [[1.0, 1.0, 2.0, 2.0, 2.0], [3.0, 3.0, 4.0, 4.0, 4.0]]
-    );
-    Ok(())
-}
-
-#[test]
-fn i64_abs() -> Result<()> {
-    let t = Tensor::new(&[-42i64, 1337], &Device::Cpu)?;
-    let t = t.abs()?;
-    assert_eq!(t.to_vec1::<i64>()?, [42, 1337]);
-    Ok(())
-}
-
-#[test]
-fn tril_triu_eye() -> Result<()> {
-    let t = Tensor::tril2(4, DType::F32, &Device::Cpu)?;
-    assert_eq!(
-        t.to_vec2::<f32>()?,
-        [
-            [1.0, 0.0, 0.0, 0.0],
-            [1.0, 1.0, 0.0, 0.0],
-            [1.0, 1.0, 1.0, 0.0],
-            [1.0, 1.0, 1.0, 1.0]
-        ],
-    );
-    let t = Tensor::triu2(4, DType::F32, &Device::Cpu)?;
-    assert_eq!(
-        t.to_vec2::<f32>()?,
-        [
-            [1.0, 1.0, 1.0, 1.0],
-            [0.0, 1.0, 1.0, 1.0],
-            [0.0, 0.0, 1.0, 1.0],
-            [0.0, 0.0, 0.0, 1.0]
-        ]
-    );
-    let t = Tensor::eye(4, DType::F32, &Device::Cpu)?;
-    assert_eq!(
-        t.to_vec2::<f32>()?,
-        [
-            [1.0, 0.0, 0.0, 0.0],
-            [0.0, 1.0, 0.0, 0.0],
-            [0.0, 0.0, 1.0, 0.0],
-            [0.0, 0.0, 0.0, 1.0]
-        ]
-    );
-    Ok(())
-}
-
-#[test]
-fn cumsum() -> Result<()> {
-    let t = &[3f32, 1., 4., 1., 5.];
-    let t = Tensor::new(t, &Device::Cpu)?;
-    assert_eq!(t.cumsum(0)?.to_vec1::<f32>()?, [3., 4., 8., 9., 14.]);
-    let t = t.unsqueeze(1)?;
-    assert_eq!(
-        t.cumsum(0)?.to_vec2::<f32>()?,
-        [[3.0], [4.0], [8.0], [9.0], [14.0]]
-    );
-    assert_eq!(
-        t.cumsum(1)?.to_vec2::<f32>()?,
-        [[3.0], [1.0], [4.0], [1.0], [5.0]]
-    );
-    let t = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
-    let t = Tensor::new(t, &Device::Cpu)?;
-    assert_eq!(
-        t.cumsum(1)?.to_vec2::<f32>()?,
-        [[3.0, 4.0, 8.0, 9.0, 14.0], [2.0, 3.0, 10.0, 18.0, 20.0]],
-    );
-    assert_eq!(
-        t.cumsum(0)?.to_vec2::<f32>()?,
-        [[3.0, 1.0, 4.0, 1.0, 5.0], [5.0, 2.0, 11.0, 9.0, 7.0]]
-    );
-    Ok(())
-}
-
-/// A helper function for floating point comparison. Both a and b must be 1D tensors
-/// containing the same number of elements.
-/// The assertion passes if every pairwise difference between a and b is smaller than epsilon.
-fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> {
-    let a_vec: Vec<f64> = a.to_vec1()?;
-    let b_vec: Vec<f64> = b.to_vec1()?;
-
-    assert_eq!(a_vec.len(), b_vec.len());
-    for (a, b) in a_vec.iter().zip(b_vec.iter()) {
-        assert!((a - b).abs() < epsilon);
-    }
-    Ok(())
-}
-
-#[test]
-fn log_sum_exp() -> Result<()> {
-    let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
-    let output = input.log_sum_exp(D::Minus1)?;
-    // The expectations obtained from pytorch.
-    let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?;
-    assert_close(&output, &expected, 0.00001)?;
-    Ok(())
-}
-
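The removed expectation is easy to verify by hand: logsumexp(1, 2, 3) = ln(e^1 + e^2 + e^3) = 3 + ln(1 + e^-1 + e^-2) ≈ 3 + ln(1.5032) ≈ 3.4076. Since logsumexp is shift-equivariant, the second row [4, 5, 6] comes out exactly 3 higher, 6.4076, matching the PyTorch-derived constants.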
-#[test]
-fn pow() -> Result<()> {
-    let lhs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
-    let rhs = (&lhs - 2.)?;
-    let res = lhs.pow(&rhs)?;
-    assert_eq!(
-        test_utils::to_vec2_round(&res, 3)?,
-        [[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0]]
-    );
-    Ok(())
-}
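Likewise for the removed pow test: `rhs = lhs - 2`, so the result is x^(x-2) elementwise, i.e. 1^-1 = 1, 2^0 = 1, 3^1 = 3, 4^2 = 16, 5^3 = 125, 6^4 = 1296, which is exactly the rounded `[[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0]]`.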
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
candle-datasets/Cargo.toml
@@ -11,8 +11,8 @@ readme = "README.md"

 [dependencies]
 byteorder = { workspace = true }
-candle = { workspace = true }
+candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
-candle-nn = { workspace = true }
+candle-nn = { path = "../candle-nn", version = "0.3.0" }
 hf-hub = { workspace = true}
 intel-mkl-src = { workspace = true, optional = true }
 memmap2 = { workspace = true }
candle-datasets/src/vision/cifar.rs
@@ -4,9 +4,7 @@
 //! <https://www.cs.toronto.edu/~kriz/cifar.html>
 //! The binary version of the dataset is used.
 use crate::vision::Dataset;
-use candle::{DType, Device, Error, Result, Tensor};
+use candle::{DType, Device, Result, Tensor};
-use hf_hub::{api::sync::Api, Repo, RepoType};
-use parquet::file::reader::{FileReader, SerializedFileReader};
 use std::fs::File;
 use std::io::{BufReader, Read};

@@ -62,58 +60,3 @@ pub fn load_dir<T: AsRef<std::path::Path>>(dir: T) -> Result<Dataset> {
         labels: 10,
     })
 }
-
-fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor, Tensor)> {
-    let samples = parquet.metadata().file_metadata().num_rows() as usize;
-    let mut buffer_images: Vec<u8> = Vec::with_capacity(samples * 1_024);
-    let mut buffer_labels: Vec<u8> = Vec::with_capacity(samples);
-    for row in parquet.into_iter().flatten() {
-        for (_name, field) in row.get_column_iter() {
-            if let parquet::record::Field::Group(subrow) = field {
-                for (_name, field) in subrow.get_column_iter() {
-                    if let parquet::record::Field::Bytes(value) = field {
-                        let image = image::load_from_memory(value.data()).unwrap();
-                        buffer_images.extend(image.to_rgb8().as_raw());
-                    }
-                }
-            } else if let parquet::record::Field::Long(label) = field {
-                buffer_labels.push(*label as u8);
-            }
-        }
-    }
-    let images = (Tensor::from_vec(buffer_images, (samples, 3, 32, 32), &Device::Cpu)?
-        .to_dtype(DType::U8)?
-        / 255.)?;
-    let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
-    Ok((images, labels))
-}
-
-pub fn load() -> Result<Dataset> {
-    let api = Api::new().map_err(|e| Error::Msg(format!("Api error: {e}")))?;
-    let dataset_id = "cifar10".to_string();
-    let repo = Repo::with_revision(
-        dataset_id,
-        RepoType::Dataset,
-        "refs/convert/parquet".to_string(),
-    );
-    let repo = api.repo(repo);
-    let test_parquet_filename = repo
-        .get("plain_text/test/0000.parquet")
-        .map_err(|e| Error::Msg(format!("Api error: {e}")))?;
-    let train_parquet_filename = repo
-        .get("plain_text/train/0000.parquet")
-        .map_err(|e| Error::Msg(format!("Api error: {e}")))?;
-    let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)
-        .map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
-    let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)
-        .map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
-    let (test_images, test_labels) = load_parquet(test_parquet)?;
-    let (train_images, train_labels) = load_parquet(train_parquet)?;
-    Ok(crate::vision::Dataset {
-        train_images,
-        train_labels,
-        test_images,
-        test_labels,
-        labels: 10,
-    })
-}
candle-examples/Cargo.toml
@@ -11,93 +11,50 @@ readme = "README.md"

 [dependencies]
 accelerate-src = { workspace = true, optional = true }
-candle = { workspace = true }
+candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
-candle-datasets = { workspace = true, optional = true }
+candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
-candle-nn = { workspace = true }
+candle-nn = { path = "../candle-nn", version = "0.3.0" }
-candle-transformers = { workspace = true }
+candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
-candle-flash-attn = { workspace = true, optional = true }
+candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
-candle-onnx = { workspace = true, optional = true }

-csv = "1.3.0"
 cudarc = { workspace = true, optional = true }
 half = { workspace = true, optional = true }
-hf-hub = { workspace = true, features = ["tokio"] }
 image = { workspace = true }
 intel-mkl-src = { workspace = true, optional = true }
 num-traits = { workspace = true }
-pyo3 = { version = "0.21.0", features = ["auto-initialize"], optional = true }
 rayon = { workspace = true }
-rubato = { version = "0.15.0", optional = true }
 safetensors = { workspace = true }
 serde = { workspace = true }
 serde_json = { workspace = true }
-symphonia = { version = "0.5.3", features = ["all"], optional = true }
 tokenizers = { workspace = true, features = ["onig"] }
-cpal= { version = "0.15.2", optional = true }

 [dev-dependencies]
 anyhow = { workspace = true }
 byteorder = { workspace = true }
 clap = { workspace = true }
+hf-hub = { workspace = true, features=["tokio"]}
 imageproc = { workspace = true }
 memmap2 = { workspace = true }
 rand = { workspace = true }
-ab_glyph = { workspace = true }
+rusttype = { workspace = true }
 tracing = { workspace = true }
 tracing-chrome = { workspace = true }
 tracing-subscriber = { workspace = true }
+wav = { workspace = true }
 # Necessary to disambiguate with tokio in wasm examples which are 1.28.1
 tokio = "1.29.1"

 [build-dependencies]
 anyhow = { workspace = true }
-bindgen_cuda = { version = "0.1.1", optional = true }

 [features]
 default = []
 accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
-cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"]
+cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
 cudnn = ["candle/cudnn"]
 flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
 mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
 nccl = ["cuda", "cudarc/nccl", "dep:half"]
-onnx = ["candle-onnx"]
-metal = ["candle/metal", "candle-nn/metal"]
-microphone = ["cpal"]
-encodec = ["cpal", "symphonia", "rubato"]

 [[example]]
 name = "llama_multiprocess"
 required-features = ["cuda", "nccl", "flash-attn"]

-[[example]]
-name = "reinforcement-learning"
-required-features = ["pyo3"]
-
-[[example]]
-name = "onnx"
-required-features = ["onnx"]
-
-[[example]]
-name = "onnx_basics"
-required-features = ["onnx"]
-
-[[example]]
-name = "whisper"
-required-features = ["symphonia"]
-
-[[example]]
-name = "whisper-microphone"
-required-features = ["microphone"]
-
-[[example]]
-name = "mnist-training"
-required-features = ["candle-datasets"]
-
-[[example]]
-name = "llama2-c"
-required-features = ["candle-datasets"]
-
-[[example]]
-name = "encodec"
-required-features = ["encodec"]
candle-examples/build.rs
@@ -4,28 +4,235 @@ use std::io::Write;
 use std::path::PathBuf;

 struct KernelDirectories {
-    kernel_glob: &'static str,
+    kernel_dir: &'static str,
     rust_target: &'static str,
     include_dirs: &'static [&'static str],
 }

-const KERNEL_DIRS: [KernelDirectories; 1] = [KernelDirectories {
-    kernel_glob: "examples/custom-ops/kernels/*.cu",
+const DIRS: [KernelDirectories; 1] = [KernelDirectories {
+    kernel_dir: "examples/custom-ops/kernels/",
     rust_target: "examples/custom-ops/cuda_kernels.rs",
     include_dirs: &[],
 }];

+impl KernelDirectories {
+    fn maybe_build_ptx(
+        &self,
+        cu_file: &std::path::Path,
+        ptx_file: &std::path::Path,
+        compute_cap: usize,
+    ) -> Result<()> {
+        let should_compile = if ptx_file.exists() {
+            let ptx_modified = ptx_file.metadata()?.modified()?;
+            let cu_modified = cu_file.metadata()?.modified()?;
+            cu_modified.duration_since(ptx_modified).is_ok()
+        } else {
+            true
+        };
+        if should_compile {
+            #[cfg(feature = "cuda")]
+            {
+                let mut command = std::process::Command::new("nvcc");
+                let out_dir = ptx_file.parent().context("no parent for ptx file")?;
+                let include_dirs: Vec<String> =
+                    self.include_dirs.iter().map(|c| format!("-I{c}")).collect();
+                command
+                    .arg(format!("--gpu-architecture=sm_{compute_cap}"))
+                    .arg("--ptx")
+                    .args(["--default-stream", "per-thread"])
+                    .args(["--output-directory", out_dir.to_str().unwrap()])
+                    .arg(format!("-I/{}", self.kernel_dir))
+                    .args(include_dirs)
+                    .arg(cu_file);
+                let output = command
+                    .spawn()
+                    .context("failed spawning nvcc")?
+                    .wait_with_output()?;
+                if !output.status.success() {
+                    anyhow::bail!(
+                        "nvcc error while compiling {cu_file:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
+                        String::from_utf8_lossy(&output.stdout),
+                        String::from_utf8_lossy(&output.stderr)
+                    )
+                }
+            }
+            #[cfg(not(feature = "cuda"))]
+            std::fs::OpenOptions::new()
+                .create(true)
+                .write(true)
+                .open(ptx_file)?;
+        }
+        Ok(())
+    }
+    fn process(&self, out_dir: &std::path::Path, compute_cap: usize) -> Result<()> {
+        println!("cargo:rerun-if-changed={}", self.kernel_dir);
+        let kernel_dir = PathBuf::from(self.kernel_dir);
+        let out_dir = out_dir.join(self.kernel_dir);
+        if !out_dir.exists() {
+            std::fs::create_dir_all(&out_dir)?;
+        }
+        let mut cu_files = vec![];
+        let mut cuh_files = vec![];
+        for file in std::fs::read_dir(kernel_dir)?.flatten() {
+            let file = file.path();
+            match file.extension().and_then(|v| v.to_str()) {
+                Some("cu") => cu_files.push(file),
+                Some("cuh") => cuh_files.push(file),
+                _ => {}
+            }
+        }
+
+        let mut ptx_paths = vec![];
+        for cu_file in cu_files.iter() {
+            let file_stem = cu_file
+                .file_stem()
+                .with_context(|| format!("no stem {cu_file:?}"))?;
+            let file_stem = file_stem.to_string_lossy().into_owned();
+            let ptx_file = out_dir.join(&format!("{file_stem}.ptx"));
+            self.maybe_build_ptx(cu_file, &ptx_file, compute_cap)?;
+            ptx_paths.push(ptx_file);
+        }
+
+        let regenerate_rs_file = true;
+        if regenerate_rs_file {
+            let mut file = std::fs::File::create(self.rust_target)?;
+            for ptx_path in ptx_paths {
+                let name = ptx_path
+                    .file_stem()
+                    .context("empty stem")?
+                    .to_string_lossy();
+                file.write_all(b"#[rustfmt::skip]\n")?;
+                let const_definition = format!(
+                    r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}/{name}.ptx"));"#,
+                    name.to_uppercase().replace('.', "_"),
+                    self.kernel_dir,
+                );
+                file.write_all(const_definition.as_bytes())?;
+                file.write_all(b"\n")?;
+            }
+        }
+        Ok(())
+    }
+}
+
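What `process` ultimately writes into `rust_target` is one `include_str!` constant per compiled kernel. Illustrative only (the kernel name `my_op.cu` is hypothetical), the generated `examples/custom-ops/cuda_kernels.rs` would contain roughly:

```rust
// Sketch of the generated output for a hypothetical my_op.cu kernel; the
// doubled slash comes from kernel_dir's trailing '/' and is harmless on
// most platforms.
#[rustfmt::skip]
pub const MY_OP: &str = include_str!(concat!(env!("OUT_DIR"), "/examples/custom-ops/kernels//my_op.ptx"));
```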
 fn main() -> Result<()> {
     println!("cargo:rerun-if-changed=build.rs");
+    let out_dir = std::env::var("OUT_DIR").context("OUT_DIR not set")?;
+    let out_dir = PathBuf::from(out_dir);
     #[cfg(feature = "cuda")]
-    {
-        for kdir in KERNEL_DIRS.iter() {
-            let builder = bindgen_cuda::Builder::default().kernel_paths_glob(kdir.kernel_glob);
-            println!("cargo:info={builder:?}");
-            let bindings = builder.build_ptx().unwrap();
-            bindings.write(kdir.rust_target).unwrap()
-        }
+    set_cuda_include_dir()?;
+    #[cfg(feature = "cuda")]
+    let compute_cap = compute_cap()?;
+    #[cfg(not(feature = "cuda"))]
+    let compute_cap = 0;
+    for d in DIRS {
+        d.process(&out_dir, compute_cap)?
     }
     Ok(())
 }
+
+fn set_cuda_include_dir() -> Result<()> {
+    // NOTE: copied from cudarc build.rs.
+    let env_vars = [
+        "CUDA_PATH",
+        "CUDA_ROOT",
+        "CUDA_TOOLKIT_ROOT_DIR",
+        "CUDNN_LIB",
+    ];
+    let env_vars = env_vars
+        .into_iter()
+        .map(std::env::var)
+        .filter_map(Result::ok)
+        .map(Into::<PathBuf>::into);
+
+    let roots = [
+        "/usr",
+        "/usr/local/cuda",
+        "/opt/cuda",
+        "/usr/lib/cuda",
+        "C:/Program Files/NVIDIA GPU Computing Toolkit",
+        "C:/CUDA",
+    ];
+    let roots = roots.into_iter().map(Into::<PathBuf>::into);
+    let root = env_vars
+        .chain(roots)
+        .find(|path| path.join("include").join("cuda.h").is_file())
+        .context("cannot find include/cuda.h")?;
+    println!(
+        "cargo:rustc-env=CUDA_INCLUDE_DIR={}",
+        root.join("include").display()
+    );
+    Ok(())
+}
+
+#[allow(unused)]
+fn compute_cap() -> Result<usize> {
+    // Grab compute code from nvidia-smi
+    let mut compute_cap = {
+        let out = std::process::Command::new("nvidia-smi")
+            .arg("--query-gpu=compute_cap")
+            .arg("--format=csv")
+            .output()
+            .context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
+        let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
+        let mut lines = out.lines();
+        assert_eq!(
+            lines.next().context("missing line in stdout")?,
+            "compute_cap"
+        );
+        let cap = lines
+            .next()
+            .context("missing line in stdout")?
+            .replace('.', "");
+        cap.parse::<usize>()
+            .with_context(|| format!("cannot parse as int {cap}"))?
+    };
+
+    // Grab available GPU codes from nvcc and select the highest one
+    let max_nvcc_code = {
+        let out = std::process::Command::new("nvcc")
+            .arg("--list-gpu-code")
+            .output()
+            .expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
+        let out = std::str::from_utf8(&out.stdout).unwrap();
+
+        let out = out.lines().collect::<Vec<&str>>();
+        let mut codes = Vec::with_capacity(out.len());
+        for code in out {
+            let code = code.split('_').collect::<Vec<&str>>();
+            if !code.is_empty() && code.contains(&"sm") {
+                if let Ok(num) = code[1].parse::<usize>() {
+                    codes.push(num);
+                }
+            }
+        }
+        codes.sort();
+        if !codes.contains(&compute_cap) {
+            anyhow::bail!(
+                "nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}."
+            );
+        }
+        *codes.last().unwrap()
+    };
+
+    // If nvidia-smi compute_cap is higher than the highest gpu code from nvcc,
+    // then choose the highest gpu code in nvcc
+    if compute_cap > max_nvcc_code {
+        println!(
+            "cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}."
+        );
+        compute_cap = max_nvcc_code;
+    }
+
+    println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
+
+    if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
+        compute_cap = compute_cap_str
+            .parse::<usize>()
+            .with_context(|| format!("cannot parse as usize '{compute_cap_str}'"))?;
+        println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP");
+    }
+    println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
+    Ok(compute_cap)
+}
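To make the parsing in `compute_cap` concrete, this is the shape of output it expects (sample values, here from a hypothetical machine with an Ampere-class GPU):

```
$ nvidia-smi --query-gpu=compute_cap --format=csv
compute_cap
8.6

$ nvcc --list-gpu-code
sm_52
sm_60
sm_70
sm_80
sm_86
```

`8.6` becomes `86` once the dot is stripped, and the largest `sm_*` code reported by nvcc is used to cap the value before it is exported as `CUDA_COMPUTE_CAP=sm_86`.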
candle-examples/examples/bert/README.md
@@ -2,10 +2,10 @@

 Bert is a general large language model. In this example it can be used for two
 different tasks:

 - Compute sentence embeddings for a prompt.
 - Compute similarities between a set of sentences.


 ## Sentence embeddings

 Bert is used to compute the sentence embeddings for a prompt. The model weights
@@ -24,48 +24,6 @@ cargo run --example bert --release -- --prompt "Here is a test sentence"
 > Tensor[[1, 7, 384], f32]
 ```

-### Custom models
-
-You can specify different models, such as BGE, with the `--model-id` flag:
-
-```bash
-cargo run --example bert --release -- \
-    --model-id BAAI/bge-large-zh-v1.5 \
-    --prompt "Here is a test sentence"
-Loaded and encoded 435.70775ms
-[[[ 3.0944e-1, -7.8455e-5, -1.2768e0, ..., 1.3755e-2, -3.2371e-1, 2.3819e-1],
- [-2.8506e-1, 1.9953e-1, -1.3076e0, ..., 6.9819e-2, 1.0833e-2, -1.1512e0],
- [ 3.9892e-1, 2.0000e-1, -9.3178e-1, ..., -4.1393e-1, -4.9644e-2, -3.3786e-1],
- ...
- [ 6.0345e-1, 3.5744e-1, -1.2672e0, ..., -6.9165e-1, -3.4973e-3, -8.4214e-1],
- [ 3.9218e-1, -3.2735e-1, -1.3123e0, ..., -4.9318e-1, -5.1334e-1, -3.6391e-1],
- [ 3.0978e-1, 2.5662e-4, -1.2773e0, ..., 1.3357e-2, -3.2390e-1, 2.3858e-1]]]
-Tensor[[1, 9, 1024], f32]
-Took 176.744667ms
-```
-
-### Gelu approximation
-
-You can get a speedup by using an approximation of the gelu activation, with a
-small loss of precision, by passing the `--approximate-gelu` flag:
-
-```bash
-$ cargo run --example bert --release -- \
-    --model-id BAAI/bge-large-zh-v1.5 \
-    --prompt "Here is a test sentence" \
-    --approximate-gelu
-Loaded and encoded 244.388042ms
-[[[ 3.1048e-1, -6.0339e-4, -1.2758e0, ..., 1.3718e-2, -3.2362e-1, 2.3775e-1],
- [-2.8354e-1, 1.9984e-1, -1.3077e0, ..., 6.9390e-2, 9.9681e-3, -1.1531e0],
- [ 3.9947e-1, 1.9917e-1, -9.3178e-1, ..., -4.1301e-1, -5.0719e-2, -3.3955e-1],
- ...
- [ 6.0499e-1, 3.5664e-1, -1.2642e0, ..., -6.9134e-1, -3.4581e-3, -8.4471e-1],
- [ 3.9311e-1, -3.2812e-1, -1.3105e0, ..., -4.9291e-1, -5.1270e-1, -3.6543e-1],
- [ 3.1082e-1, -2.6737e-4, -1.2762e0, ..., 1.3319e-2, -3.2381e-1, 2.3815e-1]]]
-Tensor[[1, 9, 1024], f32]
-Took 116.840791ms
-```
-
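For reference on the dropped flag: the exact gelu and its tanh-based approximation are

```
gelu(x)        = 0.5 * x * (1 + erf(x / sqrt(2)))
gelu_approx(x) = 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
```

The approximation avoids evaluating erf and stays within a small absolute error of the exact form over typical activation ranges, which is the speed-for-precision trade-off the removed section described.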
 ## Similarities

 In this example, Bert is used to compute the sentence embeddings for a set of
candle-examples/examples/bert/main.rs
@@ -3,13 +3,13 @@ extern crate intel_mkl_src;

 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;
-use candle_transformers::models::bert::{BertModel, Config, HiddenAct, DTYPE};
+use candle_transformers::models::bert::{BertModel, Config, DTYPE};

-use anyhow::{Error as E, Result};
+use anyhow::{anyhow, Error as E, Result};
 use candle::Tensor;
 use candle_nn::VarBuilder;
 use clap::Parser;
-use hf_hub::{api::sync::Api, Repo, RepoType};
+use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
 use tokenizers::{PaddingParams, Tokenizer};

 #[derive(Parser, Debug)]
@@ -19,6 +19,10 @@ struct Args {
     #[arg(long)]
     cpu: bool,

+    /// Run offline (you must have the files already cached)
+    #[arg(long)]
+    offline: bool,
+
     /// Enable tracing (generates a trace-timestamp.json file).
     #[arg(long)]
     tracing: bool,
@@ -34,10 +38,6 @@ struct Args {
     #[arg(long)]
     prompt: Option<String>,

-    /// Use the pytorch weights rather than the safetensors ones
-    #[arg(long)]
-    use_pth: bool,
-
     /// The number of times to run the prompt.
     #[arg(long, default_value = "1")]
     n: usize,
@@ -45,10 +45,6 @@ struct Args {
     /// L2 normalization for embeddings.
     #[arg(long, default_value = "true")]
     normalize_embeddings: bool,
-
-    /// Use tanh based approximation for Gelu instead of erf implementation.
-    #[arg(long, default_value = "false")]
-    approximate_gelu: bool,
 }

 impl Args {
@@ -64,30 +60,34 @@ impl Args {
         };

         let repo = Repo::with_revision(model_id, RepoType::Model, revision);
-        let (config_filename, tokenizer_filename, weights_filename) = {
+        let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
+            let cache = Cache::default().repo(repo);
+            (
+                cache
+                    .get("config.json")
+                    .ok_or(anyhow!("Missing config file in cache"))?,
+                cache
+                    .get("tokenizer.json")
+                    .ok_or(anyhow!("Missing tokenizer file in cache"))?,
+                cache
+                    .get("model.safetensors")
+                    .ok_or(anyhow!("Missing weights file in cache"))?,
+            )
+        } else {
             let api = Api::new()?;
             let api = api.repo(repo);
-            let config = api.get("config.json")?;
-            let tokenizer = api.get("tokenizer.json")?;
-            let weights = if self.use_pth {
-                api.get("pytorch_model.bin")?
-            } else {
-                api.get("model.safetensors")?
-            };
-            (config, tokenizer, weights)
+            (
+                api.get("config.json")?,
+                api.get("tokenizer.json")?,
+                api.get("model.safetensors")?,
+            )
         };
         let config = std::fs::read_to_string(config_filename)?;
-        let mut config: Config = serde_json::from_str(&config)?;
+        let config: Config = serde_json::from_str(&config)?;
         let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

-        let vb = if self.use_pth {
-            VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
-        } else {
-            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
-        };
-        if self.approximate_gelu {
-            config.hidden_act = HiddenAct::GeluApproximate;
-        }
+        let vb =
+            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? };
         let model = BertModel::load(vb, &config)?;
         Ok((model, tokenizer))
     }
candle-examples/examples/blip/README.md
@@ -1,19 +0,0 @@
-# candle-blip
-
-The
-[blip-image-captioning](https://huggingface.co/Salesforce/blip-image-captioning-base)
-model can generate captions for an input image.
-
-## Running on an example
-
-```bash
-cargo run --example blip --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
-```
-
-```
-Running on CPU, to run on GPU, build this example with `--features cuda`
-loaded image Tensor[dims 3, 384, 384; f32]
-model built
-several cyclists are riding down a road with cars behind them%
-```
-
candle-examples/examples/blip/main.rs
@@ -1,154 +0,0 @@
-#[cfg(feature = "mkl")]
-extern crate intel_mkl_src;
-
-#[cfg(feature = "accelerate")]
-extern crate accelerate_src;
-
-use anyhow::Error as E;
-use clap::Parser;
-
-use candle::{DType, Device, Result, Tensor};
-use candle_examples::token_output_stream::TokenOutputStream;
-use candle_nn::VarBuilder;
-use candle_transformers::models::blip;
-use candle_transformers::models::quantized_blip;
-
-use tokenizers::Tokenizer;
-
-enum Model {
-    M(blip::BlipForConditionalGeneration),
-    Q(quantized_blip::BlipForConditionalGeneration),
-}
-
-impl Model {
-    fn text_decoder_forward(&mut self, xs: &Tensor, img_xs: &Tensor) -> Result<Tensor> {
-        match self {
-            Self::M(m) => m.text_decoder().forward(xs, img_xs),
-            Self::Q(m) => m.text_decoder().forward(xs, img_xs),
-        }
-    }
-}
-
-// TODO: Maybe add support for the conditional prompt.
-#[derive(Parser)]
-struct Args {
-    #[arg(long)]
-    model: Option<String>,
-
-    #[arg(long)]
-    tokenizer: Option<String>,
-
-    #[arg(long)]
-    image: String,
-
-    /// Run on CPU rather than on GPU.
-    #[arg(long)]
-    cpu: bool,
-
-    /// Use the quantized version of the model.
-    #[arg(long)]
-    quantized: bool,
-}
-
-const SEP_TOKEN_ID: u32 = 102;
-
-/// Loads an image from disk using the image crate, this returns a tensor with shape
-/// (3, 384, 384). OpenAI normalization is applied.
-pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
-    let img = image::io::Reader::open(p)?
-        .decode()
-        .map_err(candle::Error::wrap)?
-        .resize_to_fill(384, 384, image::imageops::FilterType::Triangle);
-    let img = img.to_rgb8();
-    let data = img.into_raw();
-    let data = Tensor::from_vec(data, (384, 384, 3), &Device::Cpu)?.permute((2, 0, 1))?;
-    let mean =
-        Tensor::new(&[0.48145466f32, 0.4578275, 0.40821073], &Device::Cpu)?.reshape((3, 1, 1))?;
-    let std = Tensor::new(&[0.26862954f32, 0.261_302_6, 0.275_777_1], &Device::Cpu)?
-        .reshape((3, 1, 1))?;
-    (data.to_dtype(candle::DType::F32)? / 255.)?
-        .broadcast_sub(&mean)?
-        .broadcast_div(&std)
-}
-
-pub fn main() -> anyhow::Result<()> {
-    let args = Args::parse();
-
-    let model_file = match args.model {
-        None => {
-            let api = hf_hub::api::sync::Api::new()?;
-            if args.quantized {
-                let api = api.model("lmz/candle-blip".to_string());
-                api.get("blip-image-captioning-large-q4k.gguf")?
-            } else {
-                let api = api.repo(hf_hub::Repo::with_revision(
-                    "Salesforce/blip-image-captioning-large".to_string(),
-                    hf_hub::RepoType::Model,
-                    "refs/pr/18".to_string(),
-                ));
-                api.get("model.safetensors")?
-            }
-        }
-        Some(model) => model.into(),
-    };
-    let tokenizer = match args.tokenizer {
-        None => {
-            let api = hf_hub::api::sync::Api::new()?;
-            let api = api.model("Salesforce/blip-image-captioning-large".to_string());
-            api.get("tokenizer.json")?
-        }
-        Some(file) => file.into(),
-    };
-    let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
-    let mut tokenizer = TokenOutputStream::new(tokenizer);
-    let mut logits_processor =
-        candle_transformers::generation::LogitsProcessor::new(1337, None, None);
-
-    let config = blip::Config::image_captioning_large();
-
-    let device = candle_examples::device(args.cpu)?;
-    let (image_embeds, device, mut model) = if args.quantized {
-        let device = Device::Cpu;
-        let image = load_image(args.image)?.to_device(&device)?;
-        println!("loaded image {image:?}");
-
-        let vb = quantized_blip::VarBuilder::from_gguf(model_file, &device)?;
-        let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
-        let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
-        (image_embeds, device, Model::Q(model))
-    } else {
-        let image = load_image(args.image)?.to_device(&device)?;
-        println!("loaded image {image:?}");
-
-        let vb =
-            unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
-        let model = blip::BlipForConditionalGeneration::new(&config, vb)?;
-        let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
-        (image_embeds, device, Model::M(model))
-    };
-
-    let mut token_ids = vec![30522u32];
-    for index in 0..1000 {
-        let context_size = if index > 0 { 1 } else { token_ids.len() };
-        let start_pos = token_ids.len().saturating_sub(context_size);
-        let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
-        let logits = model.text_decoder_forward(&input_ids, &image_embeds)?;
-        let logits = logits.squeeze(0)?;
-        let logits = logits.get(logits.dim(0)? - 1)?;
-        let token = logits_processor.sample(&logits)?;
-        if token == SEP_TOKEN_ID {
-            break;
-        }
-        token_ids.push(token);
-        if let Some(t) = tokenizer.next_token(token)? {
-            use std::io::Write;
-            print!("{t}");
-            std::io::stdout().flush()?;
-        }
-    }
-    if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
-        print!("{rest}");
-    }
-    println!();
-    Ok(())
-}
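One detail worth noting in the removed decoding loop: after the first iteration, `context_size` drops to 1, so only the newest token is fed back in while `image_embeds` is passed on every step as the cross-attention input. Feeding a single token without losing context only works if the text decoder keeps its own key/value cache for earlier positions internally, which this pattern presumes.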
candle-examples/examples/chatglm/main.rs
@@ -1,237 +0,0 @@
-#[cfg(feature = "mkl")]
-extern crate intel_mkl_src;
-
-#[cfg(feature = "accelerate")]
-extern crate accelerate_src;
-
-use anyhow::{Error as E, Result};
-use clap::Parser;
-
-use candle_transformers::models::chatglm::{Config, Model};
-
-use candle::{DType, Device, Tensor};
-use candle_nn::VarBuilder;
-use candle_transformers::generation::LogitsProcessor;
-use hf_hub::{api::sync::Api, Repo, RepoType};
-use tokenizers::Tokenizer;
-
-struct TextGeneration {
-    model: Model,
-    device: Device,
-    tokenizer: Tokenizer,
-    logits_processor: LogitsProcessor,
-    repeat_penalty: f32,
-    repeat_last_n: usize,
-    verbose_prompt: bool,
-}
-
-impl TextGeneration {
-    #[allow(clippy::too_many_arguments)]
-    fn new(
-        model: Model,
-        tokenizer: Tokenizer,
-        seed: u64,
-        temp: Option<f64>,
-        top_p: Option<f64>,
-        repeat_penalty: f32,
-        repeat_last_n: usize,
-        verbose_prompt: bool,
-        device: &Device,
-    ) -> Self {
-        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
-        Self {
-            model,
-            tokenizer,
-            logits_processor,
-            repeat_penalty,
-            repeat_last_n,
-            verbose_prompt,
-            device: device.clone(),
-        }
-    }
-
-    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
-        use std::io::Write;
-        println!("starting the inference loop");
-        let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
-        if tokens.is_empty() {
-            anyhow::bail!("Empty prompts are not supported in the chatglm model.")
-        }
-        if self.verbose_prompt {
-            for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
-                let token = token.replace('▁', " ").replace("<0x0A>", "\n");
-                println!("{id:7} -> '{token}'");
-            }
-        }
-        let mut tokens = tokens.get_ids().to_vec();
-        let mut generated_tokens = 0usize;
-        let eos_token = match self.tokenizer.get_vocab(true).get("</s>") {
-            Some(token) => *token,
-            None => anyhow::bail!("cannot find the endoftext token"),
-        };
-        print!("{prompt}");
-        std::io::stdout().flush()?;
-        let start_gen = std::time::Instant::now();
-        for index in 0..sample_len {
-            let context_size = if index > 0 { 1 } else { tokens.len() };
-            let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
-            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
-            let logits = self.model.forward(&input)?;
-            let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
-            let logits = if self.repeat_penalty == 1. {
-                logits
-            } else {
-                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
-                candle_transformers::utils::apply_repeat_penalty(
-                    &logits,
-                    self.repeat_penalty,
-                    &tokens[start_at..],
-                )?
-            };
-
-            let next_token = self.logits_processor.sample(&logits)?;
-            tokens.push(next_token);
-            generated_tokens += 1;
-            if next_token == eos_token {
-                break;
-            }
-            let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
-            print!("{token}");
-            std::io::stdout().flush()?;
-        }
-        let dt = start_gen.elapsed();
-        println!(
-            "\n{generated_tokens} tokens generated ({:.2} token/s)",
-            generated_tokens as f64 / dt.as_secs_f64(),
-        );
-        Ok(())
-    }
-}
-
-#[derive(Parser, Debug)]
-#[command(author, version, about, long_about = None)]
-struct Args {
-    /// Run on CPU rather than on GPU.
-    #[arg(long)]
-    cpu: bool,
-
-    /// Enable tracing (generates a trace-timestamp.json file).
-    #[arg(long)]
-    tracing: bool,
-
-    /// Display the token for the specified prompt.
-    #[arg(long)]
-    verbose_prompt: bool,
-
-    #[arg(long)]
-    prompt: String,
-
-    /// The temperature used to generate samples.
-    #[arg(long)]
-    temperature: Option<f64>,
-
-    /// Nucleus sampling probability cutoff.
-    #[arg(long)]
-    top_p: Option<f64>,
-
-    /// The seed to use when generating random samples.
-    #[arg(long, default_value_t = 299792458)]
-    seed: u64,
-
-    /// The length of the sample to generate (in tokens).
-    #[arg(long, short = 'n', default_value_t = 5000)]
-    sample_len: usize,
-
-    #[arg(long)]
-    model_id: Option<String>,
-
-    #[arg(long)]
-    revision: Option<String>,
-
-    #[arg(long)]
-    weight_file: Option<String>,
-
-    #[arg(long)]
-    tokenizer: Option<String>,
-
-    /// Penalty to be applied for repeating tokens, 1. means no penalty.
-    #[arg(long, default_value_t = 1.1)]
-    repeat_penalty: f32,
-
-    /// The context size to consider for the repeat penalty.
-    #[arg(long, default_value_t = 64)]
-    repeat_last_n: usize,
-}
-
-fn main() -> Result<()> {
-    use tracing_chrome::ChromeLayerBuilder;
-    use tracing_subscriber::prelude::*;
-
-    let args = Args::parse();
-    let _guard = if args.tracing {
-        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
-        tracing_subscriber::registry().with(chrome_layer).init();
-        Some(guard)
-    } else {
-        None
-    };
-    println!(
-        "avx: {}, neon: {}, simd128: {}, f16c: {}",
-        candle::utils::with_avx(),
-        candle::utils::with_neon(),
-        candle::utils::with_simd128(),
-        candle::utils::with_f16c()
-    );
-    println!(
-        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
-        args.temperature.unwrap_or(0.),
-        args.repeat_penalty,
-        args.repeat_last_n
-    );
-
-    let start = std::time::Instant::now();
-    let api = Api::new()?;
-    let model_id = match args.model_id {
-        Some(model_id) => model_id.to_string(),
-        None => "THUDM/chatglm3-6b".to_string(),
-    };
-    let revision = match args.revision {
-        Some(rev) => rev.to_string(),
-        None => "main".to_string(),
-    };
-    let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
-    let tokenizer_filename = match args.tokenizer {
-        Some(file) => std::path::PathBuf::from(file),
-        None => api
-            .model("lmz/candle-chatglm".to_string())
-            .get("chatglm-tokenizer.json")?,
-    };
-    let filenames = match args.weight_file {
-        Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
-        None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
-    };
-    println!("retrieved the files in {:?}", start.elapsed());
-    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
-
-    let start = std::time::Instant::now();
-    let config = Config::glm3_6b();
-    let device = candle_examples::device(args.cpu)?;
-    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
-    let model = Model::new(&config, vb)?;
-
-    println!("loaded the model in {:?}", start.elapsed());
-
-    let mut pipeline = TextGeneration::new(
-        model,
-        tokenizer,
-        args.seed,
-        args.temperature,
-        args.top_p,
-        args.repeat_penalty,
-        args.repeat_last_n,
-        args.verbose_prompt,
-        &device,
-    );
-    pipeline.run(&args.prompt, args.sample_len)?;
-    Ok(())
-}
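For context on `apply_repeat_penalty` in the deleted loop above: a repeat penalty dampens the logits of recently generated tokens by a factor greater than 1, making verbatim repeats less likely. A minimal sketch of the idea (not candle's exact implementation):

```rust
// Minimal sketch: dampen logits of tokens seen in the recent context.
// `penalty` > 1.0 weakens repeated tokens; 1.0 is a no-op.
fn apply_repeat_penalty_sketch(logits: &mut [f32], penalty: f32, recent_tokens: &[u32]) {
    for &t in recent_tokens {
        let l = &mut logits[t as usize];
        // Dividing positive logits and multiplying negative ones both push
        // the score toward "less likely".
        *l = if *l >= 0.0 { *l / penalty } else { *l * penalty };
    }
}
```

The `repeat_last_n` argument in the deleted code bounds how far back this window of recent tokens reaches.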
candle-examples/examples/clip/README.md
@@ -1,46 +0,0 @@
-Contrastive Language-Image Pre-Training
-
-Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
-pairs of images with related texts.
-
-https://github.com/openai/CLIP
-
-https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip
-
-## Running on an example on cpu
-
-```
-$ cargo run --example clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "a cycling race","a photo of two cats","a robot holding a candle"
-
-
-Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
-
-INFO clip: Probability: 0.0000% Text: a cycling race
-INFO clip: Probability: 0.0000% Text: a photo of two cats
-INFO clip: Probability: 100.0000% Text: a robot holding a candle
-
-Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
-
-INFO clip: Probability: 99.9999% Text: a cycling race
-INFO clip: Probability: 0.0001% Text: a photo of two cats
-INFO clip: Probability: 0.0000% Text: a robot holding a candle
-```
-
-## Running on an example with metal feature (mac)
-
-```
-$ cargo run --features metal --example clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "a cycling race","a photo of two cats","a robot holding a candle"
-
-
-Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
-
-INFO clip: Probability: 0.0000% Text: a cycling race
-INFO clip: Probability: 0.0000% Text: a photo of two cats
-INFO clip: Probability: 100.0000% Text: a robot holding a candle
-
-Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
-
-INFO clip: Probability: 99.9999% Text: a cycling race
-INFO clip: Probability: 0.0001% Text: a photo of two cats
-INFO clip: Probability: 0.0000% Text: a robot holding a candle
-```
Some files were not shown because too many files have changed in this diff.