Compare commits

..

1 Commits

Author SHA1 Message Date
a910ec5993 CustomOp for einsum. 2023-09-08 20:46:30 +01:00
395 changed files with 3534 additions and 45753 deletions

View File

@ -59,7 +59,7 @@ jobs:
- name: Install Rust Stable - name: Install Rust Stable
run: curl https://sh.rustup.rs -sSf | sh -s -- -y run: curl https://sh.rustup.rs -sSf | sh -s -- -y
- uses: Swatinem/rust-cache@v2 - uses: Swatinem/rust-cache@v2
- run: apt-get update -y && apt-get install libssl-dev protobuf-compiler -y - run: apt-get update -y && apt-get install libssl-dev -y
- name: Test (cuda) - name: Test (cuda)
run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda
stop-runner: stop-runner:

Binary file not shown.

View File

@ -1,68 +0,0 @@
name: PyO3-CI
on:
workflow_dispatch:
push:
branches:
- main
paths:
- candle-pyo3/**
pull_request:
paths:
- candle-pyo3/**
jobs:
build_and_test:
name: Check everything builds & tests
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest] # For now, only test on Linux
steps:
- name: Checkout repository
uses: actions/checkout@v2
- name: Install Rust
uses: actions-rs/toolchain@v1
with:
toolchain: stable
- name: Install Python
uses: actions/setup-python@v4
with:
python-version: 3.11
architecture: "x64"
- name: Cache Cargo Registry
uses: actions/cache@v1
with:
path: ~/.cargo/registry
key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }}
- name: Install Protoc
uses: arduino/setup-protoc@v2
with:
version: "25.0"
repo-token: ${{ secrets.GITHUB_TOKEN }}
- name: Install
working-directory: ./candle-pyo3
run: |
python -m venv .env
source .env/bin/activate
pip install -U pip
pip install pytest maturin black
python -m maturin develop -r --features onnx
- name: Check style
working-directory: ./candle-pyo3
run: |
source .env/bin/activate
python stub.py --check
black --check .
- name: Run tests
working-directory: ./candle-pyo3
run: |
source .env/bin/activate
python -m pytest -s -v tests

8
.gitignore vendored
View File

@ -23,16 +23,14 @@ flamegraph.svg
*.dylib *.dylib
*.so *.so
*.swp *.swp
*.swo
trace-*.json trace-*.json
candle-wasm-examples/*/build candle-wasm-examples/*/build
candle-wasm-examples/*/*.bin candle-wasm-examples/*/*.bin
candle-wasm-examples/*/*.jpeg candle-wasm-examples/*/*.jpeg
candle-wasm-examples/*/audios/*.wav candle-wasm-examples/*/*.wav
candle-wasm-examples/**/*.safetensors candle-wasm-examples/*/*.safetensors
candle-wasm-examples/**/*.gguf
candle-wasm-examples/*/package-lock.json candle-wasm-examples/*/package-lock.json
candle-wasm-examples/**/config*.json
.DS_Store .DS_Store
.idea/* .idea/*

11
.vscode/settings.json vendored
View File

@ -1,11 +0,0 @@
{
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter"
},
"python.formatting.provider": "none",
"python.testing.pytestArgs": [
"candle-pyo3"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}

View File

@ -1,84 +1,13 @@
# Changelog # Changelog
This documents the main changes to the `candle` crate. This documents the main changes to the `candle` crate.
## v0.3.1 - Unreleased ## v0.2.1 - Unreleased
### Added ### Added
### Modified
## v0.3.0 - 2023-10-01
### Added
- Added the Mistral 7b v0.1 model
[983](https://github.com/huggingface/candle/pull/983).
- Quantized version of the Mistral model
[1009](https://github.com/huggingface/candle/pull/1009).
- Add the gelu-erf op and activation function
[969](https://github.com/huggingface/candle/pull/969).
- Add the mixformer/phi-v1.5 model
[930](https://github.com/huggingface/candle/pull/930).
- Add the sclice-scatter op
[927](https://github.com/huggingface/candle/pull/927).
- Add the Wuerstchen diffusion model
[911](https://github.com/huggingface/candle/pull/911).
### Modified
- Support for simd128 intrinsics in some quantized vecdots
[982](https://github.com/huggingface/candle/pull/982).
- Optimize the index-select cuda kernel
[976](https://github.com/huggingface/candle/pull/976).
- Self-contained safetensor wrappers
[946](https://github.com/huggingface/candle/pull/946).
## v0.2.2 - 2023-09-18
### Added
- Support for `top_p` sampling
[819](https://github.com/huggingface/candle/pull/819).
- T5 model including decoding
[864](https://github.com/huggingface/candle/pull/864).
- 1-d upsampling
[839](https://github.com/huggingface/candle/pull/839).
### Modified
- Bugfix for conv2d
[820](https://github.com/huggingface/candle/pull/820).
- Support tensor based indexing using `.i`
[842](https://github.com/huggingface/candle/pull/842).
## v0.2.1 - 2023-09-11
### Added
- Add some RNNs (GRU and LSTM) in `candle-nn`
[674](https://github.com/huggingface/candle/pull/674),
[688](https://github.com/huggingface/candle/pull/688).
- gguf v2 support
[725](https://github.com/huggingface/candle/pull/725).
- Quantized llama example in Python using the pyo3 api
[716](https://github.com/huggingface/candle/pull/716).
- `candle-nn` layer for conv2d-transposed
[760](https://github.com/huggingface/candle/pull/760).
- Add the Segment-Anything Model (SAM) as an example
[773](https://github.com/huggingface/candle/pull/773).
- TinyViT backbone for the segemnt anything example
[787](https://github.com/huggingface/candle/pull/787).
- Shape with holes support
[770](https://github.com/huggingface/candle/pull/770).
### Modified ### Modified
- Dilations are now supported in conv-transpose2d. - Dilations are now supported in conv-transpose2d.
[671](https://github.com/huggingface/candle/pull/671). [671](https://github.com/huggingface/candle/pull/671).
- Interactive mode for the quantized model
[690](https://github.com/huggingface/candle/pull/690).
- Faster softmax operation
[747](https://github.com/huggingface/candle/pull/747).
- Faster convolution operations on CPU and CUDA via im2col
[802](https://github.com/huggingface/candle/pull/802).
- Moving some models to a more central location
[796](https://github.com/huggingface/candle/pull/796).
## v0.2.0 - 2023-08-30 ## v0.2.0 - 2023-08-30

View File

@ -7,19 +7,18 @@ members = [
"candle-nn", "candle-nn",
"candle-pyo3", "candle-pyo3",
"candle-transformers", "candle-transformers",
"candle-wasm-examples/*", "candle-wasm-examples/llama2-c",
"candle-wasm-tests", "candle-wasm-examples/whisper",
"candle-wasm-examples/yolo",
] ]
exclude = [ exclude = [
"candle-flash-attn", "candle-flash-attn",
"candle-kernels", "candle-kernels",
"candle-metal-kernels",
"candle-onnx",
] ]
resolver = "2" resolver = "2"
[workspace.package] [workspace.package]
version = "0.3.1" version = "0.2.1"
edition = "2021" edition = "2021"
description = "Minimalist ML framework." description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle" repository = "https://github.com/huggingface/candle"
@ -33,7 +32,8 @@ anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3" byteorder = "1.4.3"
clap = { version = "4.2.4", features = ["derive"] } clap = { version = "4.2.4", features = ["derive"] }
cudarc = { version = "0.9.14", features = ["f16"] } cudarc = { version = "0.9.14", features = ["f16"] }
gemm = { version = "0.16.6", features = ["wasm-simd128-enable"] } # TODO: Switch back to the official gemm implementation once it has caught up.
gemm = { version = "0.15.6", package = "candle-gemm" }
hf-hub = "0.3.0" hf-hub = "0.3.0"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] } half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] } image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
@ -41,17 +41,15 @@ imageproc = { version = "0.23.0", default-features = false }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] } intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
libc = { version = "0.2.147" } libc = { version = "0.2.147" }
log = "0.4" log = "0.4"
memmap2 = { version = "0.7.1", features = ["stable_deref_trait"] } memmap2 = "0.7.1"
num_cpus = "1.15.0" num_cpus = "1.15.0"
num-traits = "0.2.15" num-traits = "0.2.15"
parquet = { version = "45.0.0" }
rand = "0.8.5" rand = "0.8.5"
rand_distr = "0.4.3" rand_distr = "0.4.3"
rayon = "1.7.0" rayon = "1.7.0"
rusttype = { version = "0.9", default-features = false } rusttype = { version = "0.9", default-features = false }
safetensors = "0.3.1" safetensors = "0.3.1"
serde = { version = "1.0.171", features = ["derive"] } serde = { version = "1.0.171", features = ["derive"] }
serde_plain = "1.0.2"
serde_json = "1.0.99" serde_json = "1.0.99"
thiserror = "1" thiserror = "1"
tokenizers = { version = "0.13.4", default-features = false } tokenizers = { version = "0.13.4", default-features = false }
@ -59,9 +57,8 @@ tracing = "0.1.37"
tracing-chrome = "0.7.1" tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7" tracing-subscriber = "0.3.7"
wav = "1.0.0" wav = "1.0.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "0.6.6", default-features = false } zip = { version = "0.6.6", default-features = false }
metal = { version = "0.27.0", features = ["mps"]} parquet = { version = "45.0.0" }
[profile.release-with-debug] [profile.release-with-debug]
inherits = "release" inherits = "release"

166
README.md
View File

@ -8,10 +8,7 @@ Candle is a minimalist ML framework for Rust with a focus on performance (includ
and ease of use. Try our online demos: and ease of use. Try our online demos:
[whisper](https://huggingface.co/spaces/lmz/candle-whisper), [whisper](https://huggingface.co/spaces/lmz/candle-whisper),
[LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2), [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2),
[T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm), [yolo](https://huggingface.co/spaces/lmz/candle-yolo).
[yolo](https://huggingface.co/spaces/lmz/candle-yolo),
[Segment
Anything](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).
## Get started ## Get started
@ -48,69 +45,40 @@ For more advanced examples, please have a look at the following section.
## Check out our examples ## Check out our examples
These online demos run entirely in your browser: Check out our [examples](./candle-examples/examples/):
- [yolo](https://huggingface.co/spaces/lmz/candle-yolo): pose estimation and
object recognition.
- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): speech recognition.
- [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2): text generation.
- [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm): text generation.
- [Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation.
- [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm): Image segmentation.
- [BLIP](https://huggingface.co/spaces/radames/Candle-BLIP-Image-Captioning): image captioning.
We also provide a some command line based examples using state of the art models:
- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM.
- [Falcon](./candle-examples/examples/falcon/): general LLM.
- [Phi-v1 and Phi-v1.5](./candle-examples/examples/phi/): a 1.3b general LLM with performance on par with LLaMA-v2 7b.
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
pre-trained on 1T tokens of English and code datasets.
- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
performance larger than all publicly available 13b models as of 2023-09-28.
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
(English/Chinese) general LLMs with 6b and 34b parameters.
- [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of
the LLaMA model using the same quantization techniques as
[llama.cpp](https://github.com/ggerganov/llama.cpp).
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/quantized/assets/aoc.gif" width="600">
- [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
image generative model, support for the 1.5, 2.1, and SDXL 1.0 versions.
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg" width="200">
- [Wuerstchen](./candle-examples/examples/wuerstchen/): another text to
image generative model.
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/wuerstchen/assets/cat.jpg" width="200">
- [yolo-v3](./candle-examples/examples/yolo-v3/) and
[yolo-v8](./candle-examples/examples/yolo-v8/): object detection and pose
estimation models.
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/yolo-v8/assets/bike.od.jpg" width="200"><img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/yolo-v8/assets/bike.pose.jpg" width="200">
- [segment-anything](./candle-examples/examples/segment-anything/): image
segmentation model with prompt.
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">
- [Whisper](./candle-examples/examples/whisper/): speech recognition model. - [Whisper](./candle-examples/examples/whisper/): speech recognition model.
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/), - [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM.
[JinaBert](./candle-examples/examples/jina-bert/) : useful for sentence embeddings. - [Falcon](./candle-examples/examples/falcon/): general LLM.
- [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code
generation.
- [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
image generative model, support for the 1.5, 2.1, and SDXL 1.0 versions.
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained - [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
using self-supervision (can be used for imagenet classification, depth using self-supervision (can be used for imagenet classification, depth
evaluation, segmentation). evaluation, segmentation).
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to - [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of
generate captions for an image. the LLaMA model using the same quantization techniques as
- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation [llama.cpp](https://github.com/ggerganov/llama.cpp).
model, generates the translated text from the input text. - [yolo-v3](./candle-examples/examples/yolo-v3/) and
[yolo-v8](./candle-examples/examples/yolo-v8/): object detection and pose
Run them using commands like: estimation models.
[segment-anything](./candle-examples/examples/segment-anything/): image
segmentation model with prompt.
Run them using the following commands:
``` ```
cargo run --example whisper --release
cargo run --example llama --release
cargo run --example falcon --release
cargo run --example bert --release
cargo run --example bigcode --release
cargo run --example stable-diffusion --release -- --prompt "a rusty robot holding a fire torch"
cargo run --example dinov2 --release -- --image path/to/myinput.jpg
cargo run --example quantized --release cargo run --example quantized --release
cargo run --example yolo-v3 --release -- myimage.jpg
cargo run --example yolo-v8 --release -- myimage.jpg # for pose estimation, add --task pose
cargo run --example segment-anything --release -- --image myimage.jpg
``` ```
In order to use **CUDA** add `--features cuda` to the example command line. If In order to use **CUDA** add `--features cuda` to the example command line. If
@ -120,10 +88,7 @@ There are also some wasm examples for whisper and
[llama2.c](https://github.com/karpathy/llama2.c). You can either build them with [llama2.c](https://github.com/karpathy/llama2.c). You can either build them with
`trunk` or try them online: `trunk` or try them online:
[whisper](https://huggingface.co/spaces/lmz/candle-whisper), [whisper](https://huggingface.co/spaces/lmz/candle-whisper),
[llama2](https://huggingface.co/spaces/lmz/candle-llama2), [llama2](https://huggingface.co/spaces/lmz/candle-llama2).
[T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm),
[Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm),
[Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).
For LLaMA2, run the following command to retrieve the weight files and start a For LLaMA2, run the following command to retrieve the weight files and start a
test server: test server:
@ -136,25 +101,6 @@ trunk serve --release --port 8081
And then head over to And then head over to
[http://localhost:8081/](http://localhost:8081/). [http://localhost:8081/](http://localhost:8081/).
<!--- ANCHOR: useful_libraries --->
## Useful External Resources
- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): A
very detailed tutorial showing how to convert a PyTorch model to Candle.
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): Efficient and ergonomic LoRA implemenation for Candle. `candle-lora` has
out-of-the-box LoRA support for many models from Candle, which can be found [here](https://github.com/EricLBuehler/candle-lora/tree/master/candle-lora-transformers/examples).
- [`optimisers`](https://github.com/KGrewal1/optimisers): A collection of optimisers
including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
- [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
serving local LLMs including an OpenAI compatible API server.
- [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle.
- [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
If you have an addition to this list, please submit a pull request.
<!--- ANCHOR_END: useful_libraries --->
<!--- ANCHOR: features ---> <!--- ANCHOR: features --->
## Features ## Features
@ -167,34 +113,10 @@ If you have an addition to this list, please submit a pull request.
- CUDA backend for efficiently running on GPUs, multiple GPU distribution via NCCL. - CUDA backend for efficiently running on GPUs, multiple GPU distribution via NCCL.
- WASM support, run your models in a browser. - WASM support, run your models in a browser.
- Included models. - Included models.
- Language Models. - LLMs: LLaMA v1 and v2, Falcon, StarCoder.
- LLaMA v1 and v2.
- Falcon.
- StarCoder.
- Phi v1.5.
- Mistral 7b v0.1.
- StableLM-3B-4E1T.
- Replit-code-v1.5-3B.
- Bert.
- Yi-6B and Yi-34B.
- Quantized LLMs.
- Llama 7b, 13b, 70b, as well as the chat and code variants.
- Mistral 7b, and 7b instruct.
- Zephyr 7b a and b (Mistral based).
- OpenChat 3.5 (Mistral based).
- Text to text.
- T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
- Marian MT (Machine Translation).
- Whisper (multi-lingual support). - Whisper (multi-lingual support).
- Text to image. - Stable Diffusion.
- Stable Diffusion v1.5, v2.1, XL v1.0. - Computer Vision: DINOv2, EfficientNet, yolo-v3, yolo-v8.
- Wurstchen v2.
- Image to text.
- BLIP.
- Computer Vision Models.
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT.
- yolo-v3, yolo-v8.
- Segment-Anything Model (SAM).
- File formats: load models from safetensors, npz, ggml, or PyTorch files. - File formats: load models from safetensors, npz, ggml, or PyTorch files.
- Serverless (on CPU), small and fast deployments. - Serverless (on CPU), small and fast deployments.
- Quantization support using the llama.cpp quantized types. - Quantization support using the llama.cpp quantized types.
@ -231,7 +153,6 @@ Cheatsheet:
- [candle-datasets](./candle-datasets/): Datasets and data loaders. - [candle-datasets](./candle-datasets/): Datasets and data loaders.
- [candle-transformers](./candle-transformers): transformers-related utilities. - [candle-transformers](./candle-transformers): transformers-related utilities.
- [candle-flash-attn](./candle-flash-attn): Flash attention v2 layer. - [candle-flash-attn](./candle-flash-attn): Flash attention v2 layer.
- [candle-onnx](./candle-onnx/): ONNX model evaluation.
## FAQ ## FAQ
@ -336,29 +257,6 @@ This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a d
env CANDLE_NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ... env CANDLE_NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
``` ```
#### Linking error on windows when running rustdoc or mdbook tests
```
Couldn't compile the test.
---- .\candle-book\src\inference\hub.md - Using_the_hub::Using_in_a_real_model_ (line 50) stdout ----
error: linking with `link.exe` failed: exit code: 1181
//very long chain of linking
= note: LINK : fatal error LNK1181: cannot open input file 'windows.0.48.5.lib'
```
Make sure you link all native libraries that might be located outside a project target, e.g., to run mdbook tests, you should run:
```
mdbook test candle-book -L .\target\debug\deps\ `
-L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.42.2\lib `
-L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.48.5\lib
```
#### Extremely slow model load time with WSL
This may be caused by the models being loaded from `/mnt/c`, more details on
[stackoverflow](https://stackoverflow.com/questions/68972448/why-is-wsl-extremely-slow-when-compared-with-native-windows-npm-yarn-processing).
#### Tracking down errors #### Tracking down errors
You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle

View File

@ -11,11 +11,11 @@ readme = "README.md"
[dependencies] [dependencies]
accelerate-src = { workspace = true, optional = true } accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" } candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.3.1" } candle-datasets = { path = "../candle-datasets", version = "0.2.1" }
candle-nn = { path = "../candle-nn", version = "0.3.1" } candle-nn = { path = "../candle-nn", version = "0.2.1" }
candle-transformers = { path = "../candle-transformers", version = "0.3.1" } candle-transformers = { path = "../candle-transformers", version = "0.2.1" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.1", optional = true } candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.1", optional = true }
safetensors = { workspace = true } safetensors = { workspace = true }
serde = { workspace = true } serde = { workspace = true }
serde_json = { workspace = true } serde_json = { workspace = true }
@ -24,10 +24,9 @@ intel-mkl-src = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true } cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true } half = { workspace = true, optional = true }
image = { workspace = true, optional = true } image = { workspace = true, optional = true }
anyhow = { workspace = true }
tokio = "1.29.1"
[dev-dependencies] [dev-dependencies]
anyhow = { workspace = true }
byteorder = { workspace = true } byteorder = { workspace = true }
hf-hub = { workspace = true, features=["tokio"]} hf-hub = { workspace = true, features=["tokio"]}
clap = { workspace = true } clap = { workspace = true }
@ -39,6 +38,7 @@ tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true } tracing-subscriber = { workspace = true }
wav = { workspace = true } wav = { workspace = true }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1 # Necessary to disambiguate with tokio in wasm examples which are 1.28.1
tokio = "1.29.1"
parquet = { workspace = true } parquet = { workspace = true }
image = { workspace = true } image = { workspace = true }

View File

@ -10,11 +10,10 @@
# Reference Guide # Reference Guide
- [Running a model](inference/inference.md) - [Running a model](inference/README.md)
- [Using the hub](inference/hub.md) - [Using the hub](inference/hub.md)
- [Error management](error_manage.md) - [Error management](error_manage.md)
- [Training](training/training.md) - [Training](training/README.md)
- [Simplified](training/simplified.md)
- [MNIST](training/mnist.md) - [MNIST](training/mnist.md)
- [Fine-tuning]() - [Fine-tuning]()
- [Serialization]() - [Serialization]()

View File

@ -29,7 +29,7 @@ After adding `RUST_BACKTRACE=1`:
Error: WithBacktrace { inner: ShapeMismatchBinaryOp { lhs: [1, 784], rhs: [1, 784], op: "matmul" }, backtrace: Backtrace [{ fn: "candle::error::Error::bt", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/error.rs", line: 200 }, { fn: "candle::tensor::Tensor::matmul", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/tensor.rs", line: 816 }, { fn: "myapp::main", file: "./src/main.rs", line: 29 }, { fn: "core::ops::function::FnOnce::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 250 }, { fn: "std::sys_common::backtrace::__rust_begin_short_backtrace", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/sys_common/backtrace.rs", line: 135 }, { fn: "std::rt::lang_start::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 166 }, { fn: "core::ops::function::impls::<impl core::ops::function::FnOnce<A> for &F>::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 284 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::rt::lang_start", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 165 }, { fn: "main" }, { fn: "__libc_start_main" }, { fn: "_start" }] } Error: WithBacktrace { inner: ShapeMismatchBinaryOp { lhs: [1, 784], rhs: [1, 784], op: "matmul" }, backtrace: Backtrace [{ fn: "candle::error::Error::bt", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/error.rs", line: 200 }, { fn: "candle::tensor::Tensor::matmul", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/tensor.rs", line: 816 }, { fn: "myapp::main", file: "./src/main.rs", line: 29 }, { fn: "core::ops::function::FnOnce::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 250 }, { fn: "std::sys_common::backtrace::__rust_begin_short_backtrace", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/sys_common/backtrace.rs", line: 135 }, { fn: "std::rt::lang_start::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 166 }, { fn: "core::ops::function::impls::<impl core::ops::function::FnOnce<A> for &F>::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 284 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::rt::lang_start", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 165 }, { fn: "main" }, { fn: "__libc_start_main" }, { fn: "_start" }] }
``` ```
Not super pretty at the moment, but we can see error occurred on `{ fn: "myapp::main", file: "./src/main.rs", line: 29 }` Not super pretty at the moment, but we can see error occured on `{ fn: "myapp::main", file: "./src/main.rs", line: 29 }`
Another thing to note, is that since Rust is compiled it is not necessarily as easy to recover proper stacktraces Another thing to note, is that since Rust is compiled it is not necessarily as easy to recover proper stacktraces

View File

@ -6,7 +6,7 @@ Open `src/main.rs` and fill in this content:
```rust ```rust
# extern crate candle_core; # extern crate candle_core;
use candle_core::{Device, Result, Tensor}; use candle_core::{DType, Device, Result, Tensor};
struct Model { struct Model {
first: Tensor, first: Tensor,
@ -25,11 +25,11 @@ fn main() -> Result<()> {
// Use Device::new_cuda(0)?; to use the GPU. // Use Device::new_cuda(0)?; to use the GPU.
let device = Device::Cpu; let device = Device::Cpu;
let first = Tensor::randn(0f32, 1.0, (784, 100), &device)?; let first = Tensor::zeros((784, 100), DType::F32, &device)?;
let second = Tensor::randn(0f32, 1.0, (100, 10), &device)?; let second = Tensor::zeros((100, 10), DType::F32, &device)?;
let model = Model { first, second }; let model = Model { first, second };
let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?; let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
let digit = model.forward(&dummy_image)?; let digit = model.forward(&dummy_image)?;
println!("Digit {digit:?} digit"); println!("Digit {digit:?} digit");
@ -50,7 +50,7 @@ the classical `Linear` layer. We can do as such
```rust ```rust
# extern crate candle_core; # extern crate candle_core;
# use candle_core::{Device, Result, Tensor}; # use candle_core::{DType, Device, Result, Tensor};
struct Linear{ struct Linear{
weight: Tensor, weight: Tensor,
bias: Tensor, bias: Tensor,
@ -80,7 +80,7 @@ This will change the model running code into a new function
```rust ```rust
# extern crate candle_core; # extern crate candle_core;
# use candle_core::{Device, Result, Tensor}; # use candle_core::{DType, Device, Result, Tensor};
# struct Linear{ # struct Linear{
# weight: Tensor, # weight: Tensor,
# bias: Tensor, # bias: Tensor,
@ -110,15 +110,15 @@ fn main() -> Result<()> {
let device = Device::cuda_if_available(0)?; let device = Device::cuda_if_available(0)?;
// Creating a dummy model // Creating a dummy model
let weight = Tensor::randn(0f32, 1.0, (784, 100), &device)?; let weight = Tensor::zeros((784, 100), DType::F32, &device)?;
let bias = Tensor::randn(0f32, 1.0, (100, ), &device)?; let bias = Tensor::zeros((100, ), DType::F32, &device)?;
let first = Linear{weight, bias}; let first = Linear{weight, bias};
let weight = Tensor::randn(0f32, 1.0, (100, 10), &device)?; let weight = Tensor::zeros((100, 10), DType::F32, &device)?;
let bias = Tensor::randn(0f32, 1.0, (10, ), &device)?; let bias = Tensor::zeros((10, ), DType::F32, &device)?;
let second = Linear{weight, bias}; let second = Linear{weight, bias};
let model = Model { first, second }; let model = Model { first, second };
let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?; let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
// Inference on the model // Inference on the model
let digit = model.forward(&dummy_image)?; let digit = model.forward(&dummy_image)?;
@ -146,7 +146,7 @@ And rewrite our examples using it
```rust ```rust
# extern crate candle_core; # extern crate candle_core;
# extern crate candle_nn; # extern crate candle_nn;
use candle_core::{Device, Result, Tensor}; use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Linear, Module}; use candle_nn::{Linear, Module};
struct Model { struct Model {
@ -167,15 +167,15 @@ fn main() -> Result<()> {
let device = Device::Cpu; let device = Device::Cpu;
// This has changed (784, 100) -> (100, 784) ! // This has changed (784, 100) -> (100, 784) !
let weight = Tensor::randn(0f32, 1.0, (100, 784), &device)?; let weight = Tensor::zeros((100, 784), DType::F32, &device)?;
let bias = Tensor::randn(0f32, 1.0, (100, ), &device)?; let bias = Tensor::zeros((100, ), DType::F32, &device)?;
let first = Linear::new(weight, Some(bias)); let first = Linear::new(weight, Some(bias));
let weight = Tensor::randn(0f32, 1.0, (10, 100), &device)?; let weight = Tensor::zeros((10, 100), DType::F32, &device)?;
let bias = Tensor::randn(0f32, 1.0, (10, ), &device)?; let bias = Tensor::zeros((10, ), DType::F32, &device)?;
let second = Linear::new(weight, Some(bias)); let second = Linear::new(weight, Some(bias));
let model = Model { first, second }; let model = Model { first, second };
let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?; let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
let digit = model.forward(&dummy_image)?; let digit = model.forward(&dummy_image)?;
println!("Digit {digit:?} digit"); println!("Digit {digit:?} digit");
@ -188,8 +188,8 @@ Feel free to modify this example to use `Conv2d` to create a classical convnet i
Now that we have the running dummy code we can get to more advanced topics: Now that we have the running dummy code we can get to more advanced topics:
- [For PyTorch users](../guide/cheatsheet.md) - [For PyTorch users](./guide/cheatsheet.md)
- [Running existing models](../inference/inference.md) - [Running existing models](./inference/README.md)
- [Training models](../training/training.md) - [Training models](./training/README.md)

View File

@ -12,9 +12,6 @@ compute_cap
8.9 8.9
``` ```
You can also compile the Cuda kernels for a specific compute cap using the
`CUDA_COMPUTE_CAP=<compute cap>` environment variable.
If any of the above commands errors out, please make sure to update your Cuda version. If any of the above commands errors out, please make sure to update your Cuda version.
2. Create a new app and add [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) with Cuda support. 2. Create a new app and add [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) with Cuda support.

View File

@ -1,6 +1,3 @@
#[cfg(test)]
pub mod simplified;
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use anyhow::Result; use anyhow::Result;

View File

@ -1,196 +0,0 @@
//! #A simplified example in Rust of training a neural network and then using it based on the Candle Framework by Hugging Face.
//! Author: Evgeny Igumnov 2023 igumnovnsk@gmail.com
//! This program implements a neural network to predict the winner of the second round of elections based on the results of the first round.
//!
//! ##Basic moments:
//!
//! A multilayer perceptron with two hidden layers is used. The first hidden layer has 4 neurons, the second has 2 neurons.
//! The input is a vector of 2 numbers - the percentage of votes for the first and second candidates in the first stage.
//! The output is the number 0 or 1, where 1 means that the first candidate will win in the second stage, 0 means that he will lose.
//! For training, samples with real data on the results of the first and second stages of different elections are used.
//! The model is trained by backpropagation using gradient descent and the cross-entropy loss function.
//! Model parameters (weights of neurons) are initialized randomly, then optimized during training.
//! After training, the model is tested on a deferred sample to evaluate the accuracy.
//! If the accuracy on the test set is below 100%, the model is considered underfit and the learning process is repeated.
//! Thus, this neural network learns to find hidden relationships between the results of the first and second rounds of voting in order to make predictions for new data.
#[rustfmt::skip]
mod tests {
use candle::{DType, Result, Tensor, D, Device};
use candle_nn::{loss, ops, Linear, Module, VarBuilder, VarMap, Optimizer};
// ANCHOR: book_training_simplified1
const VOTE_DIM: usize = 2;
const RESULTS: usize = 1;
const EPOCHS: usize = 10;
const LAYER1_OUT_SIZE: usize = 4;
const LAYER2_OUT_SIZE: usize = 2;
const LEARNING_RATE: f64 = 0.05;
#[derive(Clone)]
pub struct Dataset {
pub train_votes: Tensor,
pub train_results: Tensor,
pub test_votes: Tensor,
pub test_results: Tensor,
}
struct MultiLevelPerceptron {
ln1: Linear,
ln2: Linear,
ln3: Linear,
}
impl MultiLevelPerceptron {
fn new(vs: VarBuilder) -> Result<Self> {
let ln1 = candle_nn::linear(VOTE_DIM, LAYER1_OUT_SIZE, vs.pp("ln1"))?;
let ln2 = candle_nn::linear(LAYER1_OUT_SIZE, LAYER2_OUT_SIZE, vs.pp("ln2"))?;
let ln3 = candle_nn::linear(LAYER2_OUT_SIZE, RESULTS + 1, vs.pp("ln3"))?;
Ok(Self { ln1, ln2, ln3 })
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.ln1.forward(xs)?;
let xs = xs.relu()?;
let xs = self.ln2.forward(&xs)?;
let xs = xs.relu()?;
self.ln3.forward(&xs)
}
}
// ANCHOR_END: book_training_simplified1
// ANCHOR: book_training_simplified3
#[tokio::test]
async fn simplified() -> anyhow::Result<()> {
let dev = Device::cuda_if_available(0)?;
let train_votes_vec: Vec<u32> = vec![
15, 10,
10, 15,
5, 12,
30, 20,
16, 12,
13, 25,
6, 14,
31, 21,
];
let train_votes_tensor = Tensor::from_vec(train_votes_vec.clone(), (train_votes_vec.len() / VOTE_DIM, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;
let train_results_vec: Vec<u32> = vec![
1,
0,
0,
1,
1,
0,
0,
1,
];
let train_results_tensor = Tensor::from_vec(train_results_vec, train_votes_vec.len() / VOTE_DIM, &dev)?;
let test_votes_vec: Vec<u32> = vec![
13, 9,
8, 14,
3, 10,
];
let test_votes_tensor = Tensor::from_vec(test_votes_vec.clone(), (test_votes_vec.len() / VOTE_DIM, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;
let test_results_vec: Vec<u32> = vec![
1,
0,
0,
];
let test_results_tensor = Tensor::from_vec(test_results_vec.clone(), test_results_vec.len(), &dev)?;
let m = Dataset {
train_votes: train_votes_tensor,
train_results: train_results_tensor,
test_votes: test_votes_tensor,
test_results: test_results_tensor,
};
let trained_model: MultiLevelPerceptron;
loop {
println!("Trying to train neural network.");
match train(m.clone(), &dev) {
Ok(model) => {
trained_model = model;
break;
},
Err(e) => {
println!("Error: {}", e);
continue;
}
}
}
let real_world_votes: Vec<u32> = vec![
13, 22,
];
let tensor_test_votes = Tensor::from_vec(real_world_votes.clone(), (1, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;
let final_result = trained_model.forward(&tensor_test_votes)?;
let result = final_result
.argmax(D::Minus1)?
.to_dtype(DType::F32)?
.get(0).map(|x| x.to_scalar::<f32>())??;
println!("real_life_votes: {:?}", real_world_votes);
println!("neural_network_prediction_result: {:?}", result);
Ok(())
}
// ANCHOR_END: book_training_simplified3
// ANCHOR: book_training_simplified2
fn train(m: Dataset, dev: &Device) -> anyhow::Result<MultiLevelPerceptron> {
let train_results = m.train_results.to_device(dev)?;
let train_votes = m.train_votes.to_device(dev)?;
let varmap = VarMap::new();
let vs = VarBuilder::from_varmap(&varmap, DType::F32, dev);
let model = MultiLevelPerceptron::new(vs.clone())?;
let mut sgd = candle_nn::SGD::new(varmap.all_vars(), LEARNING_RATE)?;
let test_votes = m.test_votes.to_device(dev)?;
let test_results = m.test_results.to_device(dev)?;
let mut final_accuracy: f32 = 0.0;
for epoch in 1..EPOCHS + 1 {
let logits = model.forward(&train_votes)?;
let log_sm = ops::log_softmax(&logits, D::Minus1)?;
let loss = loss::nll(&log_sm, &train_results)?;
sgd.backward_step(&loss)?;
let test_logits = model.forward(&test_votes)?;
let sum_ok = test_logits
.argmax(D::Minus1)?
.eq(&test_results)?
.to_dtype(DType::F32)?
.sum_all()?
.to_scalar::<f32>()?;
let test_accuracy = sum_ok / test_results.dims1()? as f32;
final_accuracy = 100. * test_accuracy;
println!("Epoch: {epoch:3} Train loss: {:8.5} Test accuracy: {:5.2}%",
loss.to_scalar::<f32>()?,
final_accuracy
);
if final_accuracy == 100.0 {
break;
}
}
if final_accuracy < 100.0 {
Err(anyhow::Error::msg("The model is not trained well enough."))
} else {
Ok(model)
}
}
// ANCHOR_END: book_training_simplified2
}

View File

@ -1,45 +0,0 @@
# Simplified
## How its works
This program implements a neural network to predict the winner of the second round of elections based on the results of the first round.
Basic moments:
1. A multilayer perceptron with two hidden layers is used. The first hidden layer has 4 neurons, the second has 2 neurons.
2. The input is a vector of 2 numbers - the percentage of votes for the first and second candidates in the first stage.
3. The output is the number 0 or 1, where 1 means that the first candidate will win in the second stage, 0 means that he will lose.
4. For training, samples with real data on the results of the first and second stages of different elections are used.
5. The model is trained by backpropagation using gradient descent and the cross-entropy loss function.
6. Model parameters (weights of neurons) are initialized randomly, then optimized during training.
7. After training, the model is tested on a deferred sample to evaluate the accuracy.
8. If the accuracy on the test set is below 100%, the model is considered underfit and the learning process is repeated.
Thus, this neural network learns to find hidden relationships between the results of the first and second rounds of voting in order to make predictions for new data.
```rust,ignore
{{#include ../simplified.rs:book_training_simplified1}}
```
```rust,ignore
{{#include ../simplified.rs:book_training_simplified2}}
```
```rust,ignore
{{#include ../simplified.rs:book_training_simplified3}}
```
## Example output
```bash
Trying to train neural network.
Epoch: 1 Train loss: 4.42555 Test accuracy: 0.00%
Epoch: 2 Train loss: 0.84677 Test accuracy: 33.33%
Epoch: 3 Train loss: 2.54335 Test accuracy: 33.33%
Epoch: 4 Train loss: 0.37806 Test accuracy: 33.33%
Epoch: 5 Train loss: 0.36647 Test accuracy: 100.00%
real_life_votes: [13, 22]
neural_network_prediction_result: 0.0
```

View File

@ -12,9 +12,7 @@ readme = "README.md"
[dependencies] [dependencies]
accelerate-src = { workspace = true, optional = true } accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true } byteorder = { workspace = true }
candle-kernels = { path = "../candle-kernels", version = "0.3.1", optional = true } candle-kernels = { path = "../candle-kernels", version = "0.2.1", optional = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.1", optional = true }
metal = { workspace = true, optional = true}
cudarc = { workspace = true, optional = true } cudarc = { workspace = true, optional = true }
gemm = { workspace = true } gemm = { workspace = true }
half = { workspace = true } half = { workspace = true }
@ -28,7 +26,6 @@ rand_distr = { workspace = true }
rayon = { workspace = true } rayon = { workspace = true }
safetensors = { workspace = true } safetensors = { workspace = true }
thiserror = { workspace = true } thiserror = { workspace = true }
yoke = { workspace = true }
zip = { workspace = true } zip = { workspace = true }
[dev-dependencies] [dev-dependencies]
@ -41,4 +38,3 @@ cuda = ["cudarc", "dep:candle-kernels"]
cudnn = ["cuda", "cudarc/cudnn"] cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"] mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"] accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal", "dep:candle-metal-kernels"]

View File

@ -8,10 +8,11 @@ use anyhow::Result;
use candle_core::{Device, Tensor}; use candle_core::{Device, Tensor};
fn main() -> Result<()> { fn main() -> Result<()> {
let a = Tensor::new(&[[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]], &Device::Cpu)?; let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
let b = Tensor::new(&[[88.0f32, 99.0]], &Device::Cpu)?; let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
let new_a = a.slice_scatter(&b, 1, 2)?; let start = std::time::Instant::now();
assert_eq!(a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let res = inp.conv2d(&w, 0, 1, 1, 1)?;
assert_eq!(new_a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); println!("{:?}", start.elapsed());
println!("{res:?}");
Ok(()) Ok(())
} }

View File

@ -103,10 +103,8 @@ enum Command {
Quantize { Quantize {
/// The input file, in gguf format. /// The input file, in gguf format.
in_file: Vec<std::path::PathBuf>, in_file: std::path::PathBuf,
/// The output file, in gguf format. /// The output file, in gguf format.
#[arg(long)]
out_file: std::path::PathBuf, out_file: std::path::PathBuf,
/// The quantization schema to apply. /// The quantization schema to apply.
@ -152,7 +150,8 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
} }
} }
Format::Safetensors => { Format::Safetensors => {
let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::new(file)? }; let tensors = unsafe { candle_core::safetensors::MmapedFile::new(file)? };
let tensors = tensors.deserialize()?;
let mut tensors = tensors.tensors(); let mut tensors = tensors.tensors();
tensors.sort_by(|a, b| a.0.cmp(&b.0)); tensors.sort_by(|a, b| a.0.cmp(&b.0));
for (name, view) in tensors.iter() { for (name, view) in tensors.iter() {
@ -219,99 +218,15 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
Ok(()) Ok(())
} }
fn run_quantize_safetensors(
in_files: &[std::path::PathBuf],
out_file: std::path::PathBuf,
q: Quantization,
) -> Result<()> {
let mut out_file = std::fs::File::create(out_file)?;
let mut tensors = std::collections::HashMap::new();
for in_file in in_files.iter() {
let in_tensors = candle_core::safetensors::load(in_file, &Device::Cpu)?;
tensors.extend(in_tensors)
}
println!("tensors: {}", tensors.len());
let quantize_fn = match q {
Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
Quantization::F16 => QTensor::quantize::<half::f16>,
Quantization::F32 => QTensor::quantize::<f32>,
};
let block_size = match q {
Quantization::Q4_0 => k_quants::QK4_0,
Quantization::Q4_1 => k_quants::QK4_1,
Quantization::Q5_0 => k_quants::QK5_0,
Quantization::Q5_1 => k_quants::QK5_1,
Quantization::Q8_0 => k_quants::QK8_0,
Quantization::Q8_1 => k_quants::QK8_1,
Quantization::Q2k
| Quantization::Q3k
| Quantization::Q4k
| Quantization::Q5k
| Quantization::Q6k
| Quantization::Q8k => k_quants::QK_K,
Quantization::F16 | Quantization::F32 => 1,
};
let qtensors = tensors
.into_par_iter()
.map(|(name, tensor)| {
let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0;
println!(" quantizing {name} {tensor:?} {should_quantize}");
let tensor = if should_quantize {
quantize_fn(&tensor)?
} else {
QTensor::quantize::<f32>(&tensor)?
};
Ok((name, tensor))
})
.collect::<Result<Vec<_>>>()?;
let qtensors = qtensors
.iter()
.map(|(k, v)| (k.as_str(), v))
.collect::<Vec<_>>();
gguf_file::write(&mut out_file, &[], &qtensors)?;
Ok(())
}
fn run_quantize( fn run_quantize(
in_files: &[std::path::PathBuf], in_file: std::path::PathBuf,
out_file: std::path::PathBuf, out_file: std::path::PathBuf,
q: Quantization, q: Quantization,
qmode: QuantizationMode, qmode: QuantizationMode,
) -> Result<()> { ) -> Result<()> {
if in_files.is_empty() {
candle_core::bail!("no specified input files")
}
if let Some(extension) = out_file.extension() {
if extension == "safetensors" {
candle_core::bail!("the generated file cannot use the safetensors extension")
}
}
if let Some(extension) = in_files[0].extension() {
if extension == "safetensors" {
return run_quantize_safetensors(in_files, out_file, q);
}
}
if in_files.len() != 1 {
candle_core::bail!("only a single in-file can be used when quantizing gguf files")
}
// Open the out file early so as to fail directly on missing directories etc. // Open the out file early so as to fail directly on missing directories etc.
let mut out_file = std::fs::File::create(out_file)?; let mut out_file = std::fs::File::create(out_file)?;
let mut in_ = std::fs::File::open(&in_files[0])?; let mut in_ = std::fs::File::open(&in_file)?;
let content = gguf_file::Content::read(&mut in_)?; let content = gguf_file::Content::read(&mut in_)?;
println!("tensors: {}", content.tensor_infos.len()); println!("tensors: {}", content.tensor_infos.len());
@ -337,7 +252,7 @@ fn run_quantize(
.par_iter() .par_iter()
.map(|(name, _)| { .map(|(name, _)| {
println!(" quantizing {name}"); println!(" quantizing {name}");
let mut in_file = std::fs::File::open(&in_files[0])?; let mut in_file = std::fs::File::open(&in_file)?;
let tensor = content.tensor(&mut in_file, name)?; let tensor = content.tensor(&mut in_file, name)?;
let tensor = qmode.quantize(name, tensor, quantize_fn)?; let tensor = qmode.quantize(name, tensor, quantize_fn)?;
Ok((name, tensor)) Ok((name, tensor))
@ -378,7 +293,7 @@ fn main() -> anyhow::Result<()> {
out_file, out_file,
quantization, quantization,
mode, mode,
} => run_quantize(&in_file, out_file, quantization, mode)?, } => run_quantize(in_file, out_file, quantization, mode)?,
} }
Ok(()) Ok(())
} }

View File

@ -370,38 +370,6 @@ pub fn vd_sqr(a: &[f64], y: &mut [f64]) {
y.iter_mut().zip(a.iter()).for_each(|(y, a)| *y = *a * *a) y.iter_mut().zip(a.iter()).for_each(|(y, a)| *y = *a * *a)
} }
#[inline]
pub fn vs_tanh_inplace(y: &mut [f32]) {
unsafe { ffi::vvtanhf(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
}
#[inline]
pub fn vd_tanh_inplace(y: &mut [f64]) {
unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
}
#[inline]
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = (2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
}
vs_tanh_inplace(ys);
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = 0.5 * v * (1.0 + *y)
}
}
#[inline]
pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = (2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
}
vd_tanh_inplace(ys);
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = 0.5 * v * (1.0 + *y)
}
}
macro_rules! binary_op { macro_rules! binary_op {
($fn_name:ident, $ty:ty, $accelerate_name:ident) => { ($fn_name:ident, $ty:ty, $accelerate_name:ident) => {
#[inline] #[inline]

View File

@ -39,14 +39,6 @@ pub trait BackendStorage: Sized {
_params: &crate::conv::ParamsConv1D, _params: &crate::conv::ParamsConv1D,
) -> Result<Self>; ) -> Result<Self>;
fn conv_transpose1d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self>;
fn conv2d( fn conv2d(
&self, &self,
_l: &Layout, _l: &Layout,
@ -65,7 +57,6 @@ pub trait BackendStorage: Sized {
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>; fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>; fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self>;
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>; fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>; fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
@ -119,6 +110,4 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>; fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>; fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
fn set_seed(&self, _: u64) -> Result<()>;
} }

View File

@ -15,17 +15,6 @@ fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result
} }
} }
thread_local! {
static CANDLE_GRAD_DO_NOT_DETACH: bool = {
match std::env::var("CANDLE_GRAD_DO_NOT_DETACH") {
Ok(s) => {
!s.is_empty() && s != "0"
},
Err(_) => false,
}
}
}
impl Tensor { impl Tensor {
/// Return all the nodes that lead to this value in a topologically sorted vec, the first /// Return all the nodes that lead to this value in a topologically sorted vec, the first
/// elements having dependencies on the latter ones, e.g. the first element if any is the /// elements having dependencies on the latter ones, e.g. the first element if any is the
@ -47,8 +36,6 @@ impl Tensor {
// Do not call recursively on the "leaf" nodes. // Do not call recursively on the "leaf" nodes.
track_grad = true; track_grad = true;
nodes nodes
} else if node.dtype().is_int() {
nodes
} else if let Some(op) = node.op() { } else if let Some(op) = node.op() {
match op { match op {
Op::IndexAdd(t1, t2, t3, _) Op::IndexAdd(t1, t2, t3, _)
@ -68,11 +55,6 @@ impl Tensor {
kernel: rhs, kernel: rhs,
.. ..
} }
| Op::ConvTranspose1D {
arg: lhs,
kernel: rhs,
..
}
| Op::Conv2D { | Op::Conv2D {
arg: lhs, arg: lhs,
kernel: rhs, kernel: rhs,
@ -87,8 +69,7 @@ impl Tensor {
| Op::Binary(lhs, rhs, _) | Op::Binary(lhs, rhs, _)
| Op::Gather(lhs, rhs, _) | Op::Gather(lhs, rhs, _)
| Op::IndexSelect(lhs, rhs, _) | Op::IndexSelect(lhs, rhs, _)
| Op::Matmul(lhs, rhs) | Op::Matmul(lhs, rhs) => {
| Op::SliceScatter0(lhs, rhs, _) => {
let (tg, nodes) = walk(lhs, nodes, already_seen); let (tg, nodes) = walk(lhs, nodes, already_seen);
track_grad |= tg; track_grad |= tg;
let (tg, nodes) = walk(rhs, nodes, already_seen); let (tg, nodes) = walk(rhs, nodes, already_seen);
@ -109,18 +90,15 @@ impl Tensor {
nodes nodes
} }
} }
Op::Unary(_node, UnaryOp::Ceil)
| Op::Unary(_node, UnaryOp::Floor)
| Op::Unary(_node, UnaryOp::Round) => nodes,
Op::Reshape(node) Op::Reshape(node)
| Op::UpsampleNearest1D(node)
| Op::UpsampleNearest2D(node) | Op::UpsampleNearest2D(node)
| Op::AvgPool2D { arg: node, .. } | Op::AvgPool2D { arg: node, .. }
| Op::MaxPool2D { arg: node, .. } | Op::MaxPool2D { arg: node, .. }
| Op::Copy(node) | Op::Copy(node)
| Op::Broadcast(node) | Op::Broadcast(node)
| Op::Cmp(node, _) | Op::Cmp(node, _)
| Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _) | Op::Reduce(node, _, _)
| Op::ToDType(node)
| Op::ToDevice(node) | Op::ToDevice(node)
| Op::Transpose(node, _, _) | Op::Transpose(node, _, _)
| Op::Permute(node, _) | Op::Permute(node, _)
@ -133,16 +111,6 @@ impl Tensor {
track_grad |= tg; track_grad |= tg;
nodes nodes
} }
Op::ToDType(node) => {
if node.dtype().is_float() {
let (tg, nodes) = walk(node, nodes, already_seen);
track_grad |= tg;
nodes
} else {
nodes
}
}
Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
} }
} else { } else {
nodes nodes
@ -166,16 +134,10 @@ impl Tensor {
if node.is_variable() { if node.is_variable() {
continue; continue;
} }
let grad = grads let grad = grads.remove(node).unwrap();
.remove(node) // TODO: We should perform all these operations in place (or at least not track the
.expect("candle internal error - grad not populated"); // whole graph). The only drawback would be if we wanted to support grad of grad but
// https://github.com/huggingface/candle/issues/1241 // this is out of scope.
// Ideally, we would make these operations in place where possible to ensure that we
// do not have to allocate too often. Here we just call `.detach` to avoid computing
// the backprop graph of the backprop itself. This would be an issue for second order
// derivatives but these are out of scope at the moment.
let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
let grad = if do_not_detach { grad } else { grad.detach()? };
if let Some(op) = node.op() { if let Some(op) = node.op() {
match op { match op {
Op::Binary(lhs, rhs, BinaryOp::Add) => { Op::Binary(lhs, rhs, BinaryOp::Add) => {
@ -230,44 +192,7 @@ impl Tensor {
let f_grad = pred.where_cond(&zeros, &grad)?; let f_grad = pred.where_cond(&zeros, &grad)?;
*f_sum_grad = f_sum_grad.add(&f_grad)?; *f_sum_grad = f_sum_grad.add(&f_grad)?;
} }
Op::Conv1D { Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
arg,
kernel,
padding,
stride,
dilation,
} => {
// The output height for conv_transpose1d is:
// (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1
let grad_l_in = grad.dim(2)?;
let k_size = kernel.dim(2)?;
let out_size =
(grad_l_in - 1) * stride + dilation * (k_size - 1) + 1 - 2 * padding;
let out_padding = arg.dim(2)? - out_size;
let grad_arg = grad.conv_transpose1d(
kernel,
*padding,
out_padding,
*stride,
*dilation,
)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
let grad_kernel = arg
.transpose(0, 1)?
.conv1d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
.transpose(0, 1)?;
let sum_grad = grads.or_insert(kernel)?;
let (_, _, k0) = kernel.dims3()?;
let (_, _, g_k0) = grad_kernel.dims3()?;
let grad_kernel = if g_k0 != k0 {
grad_kernel.narrow(2, 0, k0)?
} else {
grad_kernel
};
*sum_grad = sum_grad.add(&grad_kernel)?;
}
Op::Conv2D { Op::Conv2D {
arg, arg,
kernel, kernel,
@ -297,18 +222,8 @@ impl Tensor {
.conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)? .conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
.transpose(0, 1)?; .transpose(0, 1)?;
let sum_grad = grads.or_insert(kernel)?; let sum_grad = grads.or_insert(kernel)?;
let (_, _, k0, k1) = kernel.dims4()?;
let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
} else {
grad_kernel
};
*sum_grad = sum_grad.add(&grad_kernel)?; *sum_grad = sum_grad.add(&grad_kernel)?;
} }
Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported {
op: "conv-transpose1d",
})?,
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported { Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
op: "conv-transpose2d", op: "conv-transpose2d",
})?, })?,
@ -347,21 +262,9 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?; let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?; *sum_grad = sum_grad.add(&grad_arg)?;
} }
Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
op: "upsample-nearest1d",
})?,
Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported { Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
op: "upsample-nearest2d", op: "upsample-nearest2d",
})?, })?,
Op::SliceScatter0(lhs, rhs, start_rhs) => {
let rhs_sum_grad = grads.or_insert(rhs)?;
let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
let lhs_sum_grad = grads.or_insert(lhs)?;
let lhs_grad = grad.slice_scatter0(&rhs.zeros_like()?, *start_rhs)?;
*lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?
}
Op::Gather(arg, indexes, dim) => { Op::Gather(arg, indexes, dim) => {
let sum_grad = grads.or_insert(arg)?; let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?; *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
@ -453,7 +356,7 @@ impl Tensor {
} }
Op::ToDType(arg) => { Op::ToDType(arg) => {
let sum_grad = grads.or_insert(arg)?; let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad.to_dtype(arg.dtype())?)? *sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
} }
Op::Copy(arg) => { Op::Copy(arg) => {
let sum_grad = grads.or_insert(arg)?; let sum_grad = grads.or_insert(arg)?;
@ -533,54 +436,13 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?; let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)? *sum_grad = sum_grad.add(&arg_grad)?
} }
Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?, Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
Op::Unary(_, UnaryOp::Floor) => {
Err(Error::BackwardNotSupported { op: "floor" })?
}
Op::Unary(_, UnaryOp::Round) => {
Err(Error::BackwardNotSupported { op: "round" })?
}
Op::Unary(arg, UnaryOp::Gelu) => {
let sum_grad = grads.or_insert(arg)?;
let cube = arg.powf(3.)?;
let tanh = (0.0356774 * &cube + (0.797885 * arg)?)?.tanh()?;
let gelu_grad = (((0.5 * &tanh)?
+ (0.0535161 * cube + (0.398942 * arg)?)? * (1. - tanh.powf(2.)?))?
+ 0.5)?;
*sum_grad = sum_grad.add(&(&grad * gelu_grad)?)?
}
Op::Unary(arg, UnaryOp::Erf) => {
let sum_grad = grads.or_insert(arg)?;
// d/dx erf(x) = 2/sqrt(pi) * e^(-x^2)
let erf_grad =
(2. / std::f64::consts::PI.sqrt()) * (arg.sqr()?.neg()?).exp()?;
*sum_grad = sum_grad.add(&(&grad * erf_grad)?)?
}
Op::Unary(arg, UnaryOp::GeluErf) => {
let sum_grad = grads.or_insert(arg)?;
// d/dx gelu_erf(x) = 0.5 + 0.398942 e^(-x^2/2) x + 0.5 erf(x/sqrt(2))
let neg_half_square = (arg.sqr()?.neg()? / 2.)?;
let scaled_exp_arg = (0.398942 * neg_half_square.exp()? * arg)?;
let arg_scaled_sqrt = (arg / 2f64.sqrt())?;
let erf_scaled_sqrt = (0.5 * arg_scaled_sqrt.erf()?)?;
let gelu_erf_grad = (0.5 + scaled_exp_arg + erf_scaled_sqrt)?;
*sum_grad = sum_grad.add(&(&grad * gelu_erf_grad)?)?;
}
Op::Unary(arg, UnaryOp::Relu) => { Op::Unary(arg, UnaryOp::Relu) => {
let sum_grad = grads.or_insert(arg)?; let sum_grad = grads.or_insert(arg)?;
let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?; let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
*sum_grad = sum_grad.add(&(&grad * relu_grad)?)? *sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
} }
Op::Elu(arg, alpha) => { Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
// d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
let sum_grad = grads.or_insert(arg)?;
let zeros = arg.zeros_like()?;
let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
let negative_exp_mask = ((negative_mask * arg.exp())? * *alpha)?;
let combined_mask = (positive_mask + negative_exp_mask)?;
*sum_grad = sum_grad.add(&(grad * combined_mask)?)?
}
Op::Powf(arg, e) => { Op::Powf(arg, e) => {
let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?; let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?;
let sum_grad = grads.or_insert(arg)?; let sum_grad = grads.or_insert(arg)?;
@ -655,7 +517,6 @@ impl Tensor {
} }
} }
#[derive(Debug)]
pub struct GradStore(HashMap<TensorId, Tensor>); pub struct GradStore(HashMap<TensorId, Tensor>);
impl GradStore { impl GradStore {

View File

@ -25,46 +25,6 @@ impl ParamsConv1D {
} }
} }
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConvTranspose1D {
pub(crate) b_size: usize,
pub(crate) l_in: usize,
pub(crate) c_out: usize,
pub(crate) c_in: usize,
pub(crate) k_size: usize,
pub(crate) padding: usize,
pub(crate) output_padding: usize,
pub(crate) stride: usize,
pub(crate) dilation: usize,
}
impl ParamsConvTranspose1D {
pub(crate) fn l_out(&self) -> usize {
(self.l_in - 1) * self.stride - 2 * self.padding
+ self.dilation * (self.k_size - 1)
+ self.output_padding
+ 1
}
pub(crate) fn out_dims(&self) -> Vec<usize> {
let l_out = self.l_out();
vec![self.b_size, self.c_out, l_out]
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CudnnFwdAlgo {
ImplicitGemm,
ImplicitPrecompGemm,
Gemm,
Direct,
Fft,
FftTiling,
Winograd,
WinogradNonFused,
Count,
}
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv2D { pub struct ParamsConv2D {
pub(crate) b_size: usize, pub(crate) b_size: usize,
@ -77,7 +37,6 @@ pub struct ParamsConv2D {
pub(crate) padding: usize, pub(crate) padding: usize,
pub(crate) stride: usize, pub(crate) stride: usize,
pub(crate) dilation: usize, pub(crate) dilation: usize,
pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
} }
impl ParamsConv2D { impl ParamsConv2D {
@ -187,49 +146,6 @@ impl Tensor {
} }
} }
/// Applies a 1D transposed convolution over the input tensor.
pub fn conv_transpose1d(
&self,
kernel: &Self,
padding: usize,
output_padding: usize,
stride: usize,
dilation: usize,
) -> Result<Self> {
let (b_size, c_in, l_in) = self.dims3()?;
let (c_in_k, c_out, k_size) = kernel.dims3()?;
if c_in != c_in_k {
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
}
let params = ParamsConvTranspose1D {
b_size,
l_in,
k_size,
c_out,
c_in,
padding,
output_padding,
stride,
dilation,
};
let storage = self.storage().conv_transpose1d(
self.layout(),
&kernel.storage(),
kernel.layout(),
&params,
)?;
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D {
arg,
kernel,
padding: params.padding,
output_padding: params.output_padding,
stride: params.stride,
dilation: params.dilation,
});
let out_dims = params.out_dims();
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
}
fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> { fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
let storage = let storage =
self.storage() self.storage()
@ -272,7 +188,6 @@ impl Tensor {
padding, padding,
stride, stride,
dilation, dilation,
cudnn_fwd_algo: None,
}; };
if groups == 1 { if groups == 1 {
self.conv2d_single_group(kernel, &params) self.conv2d_single_group(kernel, &params)

View File

@ -1,763 +0,0 @@
#![allow(clippy::excessive_precision)]
// Code taken from https://github.com/statrs-dev/statrs
//! Provides the [error](https://en.wikipedia.org/wiki/Error_function) and
//! related functions
mod evaluate {
//! Provides functions that don't have a numerical solution and must
//! be solved computationally (e.g. evaluation of a polynomial)
/// evaluates a polynomial at `z` where `coeff` are the coeffecients
/// to a polynomial of order `k` where `k` is the length of `coeff` and the
/// coeffecient
/// to the `k`th power is the `k`th element in coeff. E.g. [3,-1,2] equates to
/// `2z^2 - z + 3`
///
/// # Remarks
///
/// Returns 0 for a 0 length coefficient slice
pub fn polynomial(z: f64, coeff: &[f64]) -> f64 {
let n = coeff.len();
if n == 0 {
return 0.0;
}
let mut sum = *coeff.last().unwrap();
for c in coeff[0..n - 1].iter().rev() {
sum = *c + z * sum;
}
sum
}
}
use std::f64;
/// `erf` calculates the error function at `x`.
pub fn erf(x: f64) -> f64 {
if x.is_nan() {
f64::NAN
} else if x >= 0.0 && x.is_infinite() {
1.0
} else if x <= 0.0 && x.is_infinite() {
-1.0
} else if x == 0. {
0.0
} else {
erf_impl(x, false)
}
}
/// `erf_inv` calculates the inverse error function
/// at `x`.
pub fn erf_inv(x: f64) -> f64 {
if x == 0.0 {
0.0
} else if x >= 1.0 {
f64::INFINITY
} else if x <= -1.0 {
f64::NEG_INFINITY
} else if x < 0.0 {
erf_inv_impl(-x, 1.0 + x, -1.0)
} else {
erf_inv_impl(x, 1.0 - x, 1.0)
}
}
/// `erfc` calculates the complementary error function
/// at `x`.
pub fn erfc(x: f64) -> f64 {
if x.is_nan() {
f64::NAN
} else if x == f64::INFINITY {
0.0
} else if x == f64::NEG_INFINITY {
2.0
} else {
erf_impl(x, true)
}
}
/// `erfc_inv` calculates the complementary inverse
/// error function at `x`.
pub fn erfc_inv(x: f64) -> f64 {
if x <= 0.0 {
f64::INFINITY
} else if x >= 2.0 {
f64::NEG_INFINITY
} else if x > 1.0 {
erf_inv_impl(-1.0 + x, 2.0 - x, -1.0)
} else {
erf_inv_impl(1.0 - x, x, 1.0)
}
}
// **********************************************************
// ********** Coefficients for erf_impl polynomial **********
// **********************************************************
/// Polynomial coefficients for a numerator of `erf_impl`
/// in the interval [1e-10, 0.5].
const ERF_IMPL_AN: &[f64] = &[
0.00337916709551257388990745,
-0.00073695653048167948530905,
-0.374732337392919607868241,
0.0817442448733587196071743,
-0.0421089319936548595203468,
0.0070165709512095756344528,
-0.00495091255982435110337458,
0.000871646599037922480317225,
];
/// Polynomial coefficients for a denominator of `erf_impl`
/// in the interval [1e-10, 0.5]
const ERF_IMPL_AD: &[f64] = &[
1.0,
-0.218088218087924645390535,
0.412542972725442099083918,
-0.0841891147873106755410271,
0.0655338856400241519690695,
-0.0120019604454941768171266,
0.00408165558926174048329689,
-0.000615900721557769691924509,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [0.5, 0.75].
const ERF_IMPL_BN: &[f64] = &[
-0.0361790390718262471360258,
0.292251883444882683221149,
0.281447041797604512774415,
0.125610208862766947294894,
0.0274135028268930549240776,
0.00250839672168065762786937,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [0.5, 0.75].
const ERF_IMPL_BD: &[f64] = &[
1.0,
1.8545005897903486499845,
1.43575803037831418074962,
0.582827658753036572454135,
0.124810476932949746447682,
0.0113724176546353285778481,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [0.75, 1.25].
const ERF_IMPL_CN: &[f64] = &[
-0.0397876892611136856954425,
0.153165212467878293257683,
0.191260295600936245503129,
0.10276327061989304213645,
0.029637090615738836726027,
0.0046093486780275489468812,
0.000307607820348680180548455,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [0.75, 1.25].
const ERF_IMPL_CD: &[f64] = &[
1.0,
1.95520072987627704987886,
1.64762317199384860109595,
0.768238607022126250082483,
0.209793185936509782784315,
0.0319569316899913392596356,
0.00213363160895785378615014,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [1.25, 2.25].
const ERF_IMPL_DN: &[f64] = &[
-0.0300838560557949717328341,
0.0538578829844454508530552,
0.0726211541651914182692959,
0.0367628469888049348429018,
0.00964629015572527529605267,
0.00133453480075291076745275,
0.778087599782504251917881e-4,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [1.25, 2.25].
const ERF_IMPL_DD: &[f64] = &[
1.0,
1.75967098147167528287343,
1.32883571437961120556307,
0.552528596508757581287907,
0.133793056941332861912279,
0.0179509645176280768640766,
0.00104712440019937356634038,
-0.106640381820357337177643e-7,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [2.25, 3.5].
const ERF_IMPL_EN: &[f64] = &[
-0.0117907570137227847827732,
0.014262132090538809896674,
0.0202234435902960820020765,
0.00930668299990432009042239,
0.00213357802422065994322516,
0.00025022987386460102395382,
0.120534912219588189822126e-4,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [2.25, 3.5].
const ERF_IMPL_ED: &[f64] = &[
1.0,
1.50376225203620482047419,
0.965397786204462896346934,
0.339265230476796681555511,
0.0689740649541569716897427,
0.00771060262491768307365526,
0.000371421101531069302990367,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [3.5, 5.25].
const ERF_IMPL_FN: &[f64] = &[
-0.00546954795538729307482955,
0.00404190278731707110245394,
0.0054963369553161170521356,
0.00212616472603945399437862,
0.000394984014495083900689956,
0.365565477064442377259271e-4,
0.135485897109932323253786e-5,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [3.5, 5.25].
const ERF_IMPL_FD: &[f64] = &[
1.0,
1.21019697773630784832251,
0.620914668221143886601045,
0.173038430661142762569515,
0.0276550813773432047594539,
0.00240625974424309709745382,
0.891811817251336577241006e-4,
-0.465528836283382684461025e-11,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [5.25, 8].
const ERF_IMPL_GN: &[f64] = &[
-0.00270722535905778347999196,
0.0013187563425029400461378,
0.00119925933261002333923989,
0.00027849619811344664248235,
0.267822988218331849989363e-4,
0.923043672315028197865066e-6,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [5.25, 8].
const ERF_IMPL_GD: &[f64] = &[
1.0,
0.814632808543141591118279,
0.268901665856299542168425,
0.0449877216103041118694989,
0.00381759663320248459168994,
0.000131571897888596914350697,
0.404815359675764138445257e-11,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [8, 11.5].
const ERF_IMPL_HN: &[f64] = &[
-0.00109946720691742196814323,
0.000406425442750422675169153,
0.000274499489416900707787024,
0.465293770646659383436343e-4,
0.320955425395767463401993e-5,
0.778286018145020892261936e-7,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [8, 11.5].
const ERF_IMPL_HD: &[f64] = &[
1.0,
0.588173710611846046373373,
0.139363331289409746077541,
0.0166329340417083678763028,
0.00100023921310234908642639,
0.24254837521587225125068e-4,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [11.5, 17].
const ERF_IMPL_IN: &[f64] = &[
-0.00056907993601094962855594,
0.000169498540373762264416984,
0.518472354581100890120501e-4,
0.382819312231928859704678e-5,
0.824989931281894431781794e-7,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [11.5, 17].
const ERF_IMPL_ID: &[f64] = &[
1.0,
0.339637250051139347430323,
0.043472647870310663055044,
0.00248549335224637114641629,
0.535633305337152900549536e-4,
-0.117490944405459578783846e-12,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [17, 24].
const ERF_IMPL_JN: &[f64] = &[
-0.000241313599483991337479091,
0.574224975202501512365975e-4,
0.115998962927383778460557e-4,
0.581762134402593739370875e-6,
0.853971555085673614607418e-8,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [17, 24].
const ERF_IMPL_JD: &[f64] = &[
1.0,
0.233044138299687841018015,
0.0204186940546440312625597,
0.000797185647564398289151125,
0.117019281670172327758019e-4,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [24, 38].
const ERF_IMPL_KN: &[f64] = &[
-0.000146674699277760365803642,
0.162666552112280519955647e-4,
0.269116248509165239294897e-5,
0.979584479468091935086972e-7,
0.101994647625723465722285e-8,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [24, 38].
const ERF_IMPL_KD: &[f64] = &[
1.0,
0.165907812944847226546036,
0.0103361716191505884359634,
0.000286593026373868366935721,
0.298401570840900340874568e-5,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [38, 60].
const ERF_IMPL_LN: &[f64] = &[
-0.583905797629771786720406e-4,
0.412510325105496173512992e-5,
0.431790922420250949096906e-6,
0.993365155590013193345569e-8,
0.653480510020104699270084e-10,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [38, 60].
const ERF_IMPL_LD: &[f64] = &[
1.0,
0.105077086072039915406159,
0.00414278428675475620830226,
0.726338754644523769144108e-4,
0.477818471047398785369849e-6,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [60, 85].
const ERF_IMPL_MN: &[f64] = &[
-0.196457797609229579459841e-4,
0.157243887666800692441195e-5,
0.543902511192700878690335e-7,
0.317472492369117710852685e-9,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [60, 85].
const ERF_IMPL_MD: &[f64] = &[
1.0,
0.052803989240957632204885,
0.000926876069151753290378112,
0.541011723226630257077328e-5,
0.535093845803642394908747e-15,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [85, 110].
const ERF_IMPL_NN: &[f64] = &[
-0.789224703978722689089794e-5,
0.622088451660986955124162e-6,
0.145728445676882396797184e-7,
0.603715505542715364529243e-10,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [85, 110].
const ERF_IMPL_ND: &[f64] = &[
1.0,
0.0375328846356293715248719,
0.000467919535974625308126054,
0.193847039275845656900547e-5,
];
// **********************************************************
// ********** Coefficients for erf_inv_impl polynomial ******
// **********************************************************
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0, 0.5].
const ERF_INV_IMPL_AN: &[f64] = &[
-0.000508781949658280665617,
-0.00836874819741736770379,
0.0334806625409744615033,
-0.0126926147662974029034,
-0.0365637971411762664006,
0.0219878681111168899165,
0.00822687874676915743155,
-0.00538772965071242932965,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0, 0.5].
const ERF_INV_IMPL_AD: &[f64] = &[
1.0,
-0.970005043303290640362,
-1.56574558234175846809,
1.56221558398423026363,
0.662328840472002992063,
-0.71228902341542847553,
-0.0527396382340099713954,
0.0795283687341571680018,
-0.00233393759374190016776,
0.000886216390456424707504,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.5, 0.75].
const ERF_INV_IMPL_BN: &[f64] = &[
-0.202433508355938759655,
0.105264680699391713268,
8.37050328343119927838,
17.6447298408374015486,
-18.8510648058714251895,
-44.6382324441786960818,
17.445385985570866523,
21.1294655448340526258,
-3.67192254707729348546,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.5, 0.75].
const ERF_INV_IMPL_BD: &[f64] = &[
1.0,
6.24264124854247537712,
3.9713437953343869095,
-28.6608180499800029974,
-20.1432634680485188801,
48.5609213108739935468,
10.8268667355460159008,
-22.6436933413139721736,
1.72114765761200282724,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.75, 1] with x less than 3.
const ERF_INV_IMPL_CN: &[f64] = &[
-0.131102781679951906451,
-0.163794047193317060787,
0.117030156341995252019,
0.387079738972604337464,
0.337785538912035898924,
0.142869534408157156766,
0.0290157910005329060432,
0.00214558995388805277169,
-0.679465575181126350155e-6,
0.285225331782217055858e-7,
-0.681149956853776992068e-9,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.75, 1] with x less than 3.
const ERF_INV_IMPL_CD: &[f64] = &[
1.0,
3.46625407242567245975,
5.38168345707006855425,
4.77846592945843778382,
2.59301921623620271374,
0.848854343457902036425,
0.152264338295331783612,
0.01105924229346489121,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 3 and 6.
const ERF_INV_IMPL_DN: &[f64] = &[
-0.0350353787183177984712,
-0.00222426529213447927281,
0.0185573306514231072324,
0.00950804701325919603619,
0.00187123492819559223345,
0.000157544617424960554631,
0.460469890584317994083e-5,
-0.230404776911882601748e-9,
0.266339227425782031962e-11,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 3 and 6.
const ERF_INV_IMPL_DD: &[f64] = &[
1.0,
1.3653349817554063097,
0.762059164553623404043,
0.220091105764131249824,
0.0341589143670947727934,
0.00263861676657015992959,
0.764675292302794483503e-4,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 6 and 18.
const ERF_INV_IMPL_EN: &[f64] = &[
-0.0167431005076633737133,
-0.00112951438745580278863,
0.00105628862152492910091,
0.000209386317487588078668,
0.149624783758342370182e-4,
0.449696789927706453732e-6,
0.462596163522878599135e-8,
-0.281128735628831791805e-13,
0.99055709973310326855e-16,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 6 and 18.
const ERF_INV_IMPL_ED: &[f64] = &[
1.0,
0.591429344886417493481,
0.138151865749083321638,
0.0160746087093676504695,
0.000964011807005165528527,
0.275335474764726041141e-4,
0.282243172016108031869e-6,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 18 and 44.
const ERF_INV_IMPL_FN: &[f64] = &[
-0.0024978212791898131227,
-0.779190719229053954292e-5,
0.254723037413027451751e-4,
0.162397777342510920873e-5,
0.396341011304801168516e-7,
0.411632831190944208473e-9,
0.145596286718675035587e-11,
-0.116765012397184275695e-17,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 18 and 44.
const ERF_INV_IMPL_FD: &[f64] = &[
1.0,
0.207123112214422517181,
0.0169410838120975906478,
0.000690538265622684595676,
0.145007359818232637924e-4,
0.144437756628144157666e-6,
0.509761276599778486139e-9,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.75, 1] with x greater than 44.
const ERF_INV_IMPL_GN: &[f64] = &[
-0.000539042911019078575891,
-0.28398759004727721098e-6,
0.899465114892291446442e-6,
0.229345859265920864296e-7,
0.225561444863500149219e-9,
0.947846627503022684216e-12,
0.135880130108924861008e-14,
-0.348890393399948882918e-21,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.75, 1] with x greater than 44.
const ERF_INV_IMPL_GD: &[f64] = &[
1.0,
0.0845746234001899436914,
0.00282092984726264681981,
0.468292921940894236786e-4,
0.399968812193862100054e-6,
0.161809290887904476097e-8,
0.231558608310259605225e-11,
];
/// `erf_impl` computes the error function at `z`.
/// If `inv` is true, `1 - erf` is calculated as opposed to `erf`
fn erf_impl(z: f64, inv: bool) -> f64 {
if z < 0.0 {
if !inv {
return -erf_impl(-z, false);
}
if z < -0.5 {
return 2.0 - erf_impl(-z, true);
}
return 1.0 + erf_impl(-z, false);
}
let result = if z < 0.5 {
if z < 1e-10 {
z * 1.125 + z * 0.003379167095512573896158903121545171688
} else {
z * 1.125
+ z * evaluate::polynomial(z, ERF_IMPL_AN) / evaluate::polynomial(z, ERF_IMPL_AD)
}
} else if z < 110.0 {
let (r, b) = if z < 0.75 {
(
evaluate::polynomial(z - 0.5, ERF_IMPL_BN)
/ evaluate::polynomial(z - 0.5, ERF_IMPL_BD),
0.3440242112,
)
} else if z < 1.25 {
(
evaluate::polynomial(z - 0.75, ERF_IMPL_CN)
/ evaluate::polynomial(z - 0.75, ERF_IMPL_CD),
0.419990927,
)
} else if z < 2.25 {
(
evaluate::polynomial(z - 1.25, ERF_IMPL_DN)
/ evaluate::polynomial(z - 1.25, ERF_IMPL_DD),
0.4898625016,
)
} else if z < 3.5 {
(
evaluate::polynomial(z - 2.25, ERF_IMPL_EN)
/ evaluate::polynomial(z - 2.25, ERF_IMPL_ED),
0.5317370892,
)
} else if z < 5.25 {
(
evaluate::polynomial(z - 3.5, ERF_IMPL_FN)
/ evaluate::polynomial(z - 3.5, ERF_IMPL_FD),
0.5489973426,
)
} else if z < 8.0 {
(
evaluate::polynomial(z - 5.25, ERF_IMPL_GN)
/ evaluate::polynomial(z - 5.25, ERF_IMPL_GD),
0.5571740866,
)
} else if z < 11.5 {
(
evaluate::polynomial(z - 8.0, ERF_IMPL_HN)
/ evaluate::polynomial(z - 8.0, ERF_IMPL_HD),
0.5609807968,
)
} else if z < 17.0 {
(
evaluate::polynomial(z - 11.5, ERF_IMPL_IN)
/ evaluate::polynomial(z - 11.5, ERF_IMPL_ID),
0.5626493692,
)
} else if z < 24.0 {
(
evaluate::polynomial(z - 17.0, ERF_IMPL_JN)
/ evaluate::polynomial(z - 17.0, ERF_IMPL_JD),
0.5634598136,
)
} else if z < 38.0 {
(
evaluate::polynomial(z - 24.0, ERF_IMPL_KN)
/ evaluate::polynomial(z - 24.0, ERF_IMPL_KD),
0.5638477802,
)
} else if z < 60.0 {
(
evaluate::polynomial(z - 38.0, ERF_IMPL_LN)
/ evaluate::polynomial(z - 38.0, ERF_IMPL_LD),
0.5640528202,
)
} else if z < 85.0 {
(
evaluate::polynomial(z - 60.0, ERF_IMPL_MN)
/ evaluate::polynomial(z - 60.0, ERF_IMPL_MD),
0.5641309023,
)
} else {
(
evaluate::polynomial(z - 85.0, ERF_IMPL_NN)
/ evaluate::polynomial(z - 85.0, ERF_IMPL_ND),
0.5641584396,
)
};
let g = (-z * z).exp() / z;
g * b + g * r
} else {
0.0
};
if inv && z >= 0.5 {
result
} else if z >= 0.5 || inv {
1.0 - result
} else {
result
}
}
// `erf_inv_impl` computes the inverse error function where
// `p`,`q`, and `s` are the first, second, and third intermediate
// parameters respectively
fn erf_inv_impl(p: f64, q: f64, s: f64) -> f64 {
let result = if p <= 0.5 {
let y = 0.0891314744949340820313;
let g = p * (p + 10.0);
let r = evaluate::polynomial(p, ERF_INV_IMPL_AN) / evaluate::polynomial(p, ERF_INV_IMPL_AD);
g * y + g * r
} else if q >= 0.25 {
let y = 2.249481201171875;
let g = (-2.0 * q.ln()).sqrt();
let xs = q - 0.25;
let r =
evaluate::polynomial(xs, ERF_INV_IMPL_BN) / evaluate::polynomial(xs, ERF_INV_IMPL_BD);
g / (y + r)
} else {
let x = (-q.ln()).sqrt();
if x < 3.0 {
let y = 0.807220458984375;
let xs = x - 1.125;
let r = evaluate::polynomial(xs, ERF_INV_IMPL_CN)
/ evaluate::polynomial(xs, ERF_INV_IMPL_CD);
y * x + r * x
} else if x < 6.0 {
let y = 0.93995571136474609375;
let xs = x - 3.0;
let r = evaluate::polynomial(xs, ERF_INV_IMPL_DN)
/ evaluate::polynomial(xs, ERF_INV_IMPL_DD);
y * x + r * x
} else if x < 18.0 {
let y = 0.98362827301025390625;
let xs = x - 6.0;
let r = evaluate::polynomial(xs, ERF_INV_IMPL_EN)
/ evaluate::polynomial(xs, ERF_INV_IMPL_ED);
y * x + r * x
} else if x < 44.0 {
let y = 0.99714565277099609375;
let xs = x - 18.0;
let r = evaluate::polynomial(xs, ERF_INV_IMPL_FN)
/ evaluate::polynomial(xs, ERF_INV_IMPL_FD);
y * x + r * x
} else {
let y = 0.99941349029541015625;
let xs = x - 44.0;
let r = evaluate::polynomial(xs, ERF_INV_IMPL_GN)
/ evaluate::polynomial(xs, ERF_INV_IMPL_GD);
y * x + r * x
}
};
s * result
}

View File

@ -1,4 +1,3 @@
pub mod erf;
pub mod kernels; pub mod kernels;
trait Cpu<const ARR: usize> { trait Cpu<const ARR: usize> {

View File

@ -4,9 +4,6 @@ use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
use half::{bf16, f16}; use half::{bf16, f16};
use rayon::prelude::*; use rayon::prelude::*;
const USE_IM2COL_CONV1D: bool = true;
const USE_IM2COL_CONV2D: bool = true;
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator + // TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
// intercept the oom errors to avoid panicking and provide a proper error. // intercept the oom errors to avoid panicking and provide a proper error.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -727,36 +724,6 @@ impl Map1 for MaxPool2D {
} }
} }
struct UpsampleNearest1D(usize);
impl Map1 for UpsampleNearest1D {
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
// TODO: Specialized implementation for the case 2*sz?
let dst_sz = self.0;
let (b_sz, c, src_sz) = layout.shape().dims3()?;
let stride = layout.stride();
let stride_sz = stride[2];
let src_index = layout.start_offset();
let scale_sz = src_sz as f64 / dst_sz as f64;
let mut dst = vec![T::zero(); b_sz * c * dst_sz];
let src_idxs = (0..dst_sz)
.map(|idx| usize::min(src_sz - 1, (idx as f64 * scale_sz) as usize))
.collect::<Vec<_>>();
for b_idx in 0..b_sz {
let dst = &mut dst[b_idx * c * dst_sz..];
let src_index = src_index + b_idx * stride[0];
for c_idx in 0..c {
let dst = &mut dst[c_idx * dst_sz..];
let src_index = src_index + c_idx * stride[1];
for (idx, src_idx) in src_idxs.iter().enumerate() {
dst[idx] = src[src_index + src_idx * stride_sz]
}
}
}
Ok(dst)
}
}
struct UpsampleNearest2D(usize, usize); struct UpsampleNearest2D(usize, usize);
impl Map1 for UpsampleNearest2D { impl Map1 for UpsampleNearest2D {
@ -804,11 +771,11 @@ impl<'a, I: IntDType> Map1 for Gather<'a, I> {
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> { fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
let ids = match self.ids_l.contiguous_offsets() { let ids = match self.ids_l.contiguous_offsets() {
Some((a, b)) => &self.ids[a..b], Some((a, b)) => &self.ids[a..b],
None => Err(Error::RequiresContiguous { op: "gather" }.bt())?, None => Err(Error::RequiresContiguous { op: "gather" })?,
}; };
let src = match src_l.contiguous_offsets() { let src = match src_l.contiguous_offsets() {
Some((a, b)) => &src[a..b], Some((a, b)) => &src[a..b],
None => Err(Error::RequiresContiguous { op: "gather" }.bt())?, None => Err(Error::RequiresContiguous { op: "gather" })?,
}; };
let dim = self.dim; let dim = self.dim;
let ids_dims = self.ids_l.dims(); let ids_dims = self.ids_l.dims();
@ -857,7 +824,7 @@ impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> {
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> { fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
let src = match layout.contiguous_offsets() { let src = match layout.contiguous_offsets() {
Some((a, b)) => &src[a..b], Some((a, b)) => &src[a..b],
None => Err(Error::RequiresContiguous { op: "index-select" }.bt())?, None => Err(Error::RequiresContiguous { op: "index-select" })?,
}; };
let dim = self.dim; let dim = self.dim;
let n_ids = match self.ids_l.dims() { let n_ids = match self.ids_l.dims() {
@ -913,7 +880,7 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
let mut dst = vec![T::zero(); dst_len]; let mut dst = vec![T::zero(); dst_len];
copy_strided_src_(v1, &mut dst, 0, l1); copy_strided_src_(v1, &mut dst, 0, l1);
let src = match src_l.contiguous_offsets() { let src = match src_l.contiguous_offsets() {
None => Err(Error::RequiresContiguous { op: "scatter-add" }.bt())?, None => Err(Error::RequiresContiguous { op: "scatter-add" })?,
Some((o1, o2)) => &src[o1..o2], Some((o1, o2)) => &src[o1..o2],
}; };
@ -929,7 +896,7 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
let ids = match self.ids_l.contiguous_offsets() { let ids = match self.ids_l.contiguous_offsets() {
Some((a, b)) => &self.ids[a..b], Some((a, b)) => &self.ids[a..b],
None => Err(Error::RequiresContiguous { op: "gather" }.bt())?, None => Err(Error::RequiresContiguous { op: "gather" })?,
}; };
for left_i in 0..ids_left_len { for left_i in 0..ids_left_len {
let start_ids_idx = left_i * ids_right_len * ids_dim_len; let start_ids_idx = left_i * ids_right_len * ids_dim_len;
@ -971,7 +938,7 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
let mut dst = vec![T::zero(); dst_len]; let mut dst = vec![T::zero(); dst_len];
copy_strided_src_(v1, &mut dst, 0, l1); copy_strided_src_(v1, &mut dst, 0, l1);
let src = match src_l.contiguous_offsets() { let src = match src_l.contiguous_offsets() {
None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?, None => Err(Error::RequiresContiguous { op: "index-add" })?,
Some((o1, o2)) => &src[o1..o2], Some((o1, o2)) => &src[o1..o2],
}; };
let dim = self.dim; let dim = self.dim;
@ -1122,208 +1089,6 @@ impl<'a> Map2 for Conv1D<'a> {
} }
} }
struct Im2Col1D {
l_k: usize,
stride: usize,
dilation: usize,
padding: usize,
}
impl Im2Col1D {
fn l_out(&self, l: usize) -> usize {
(l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
}
}
impl Map1 for Im2Col1D {
fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
let &Self {
l_k,
stride,
dilation,
padding,
} = self;
let (b, c, l) = layout.shape().dims3()?;
let l_out = self.l_out(l);
let src = &vs[layout.start_offset()..];
let mut dst = vec![T::zero(); b * l_out * c * l_k];
let (src_s0, src_s1, src_s2) = {
let s = layout.stride();
(s[0], s[1], s[2])
};
// TODO: provide specialized kernels for the common use cases.
// - l_k = 1
// - padding = 0
// - stride = 1
// - dilation = 1
for b_idx in 0..b {
let src_idx = b_idx * src_s0;
let dst_idx = b_idx * l_out * c * l_k;
for l_idx in 0..l_out {
let dst_idx = dst_idx + l_idx * c * l_k;
for c_idx in 0..c {
let dst_idx = dst_idx + c_idx * l_k;
let src_idx = c_idx * src_s1 + src_idx;
for l_k_idx in 0..l_k {
let src_l = l_idx * stride + l_k_idx * dilation;
if padding != 0 && (src_l < padding || src_l >= l + padding) {
continue;
}
let src_l = src_l - padding;
let src_idx = src_idx + src_l * src_s2;
let dst_idx = dst_idx + l_k_idx;
dst[dst_idx] = src[src_idx]
}
}
}
}
Ok(dst)
}
}
struct Im2Col {
h_k: usize,
w_k: usize,
stride: usize,
dilation: usize,
padding: usize,
}
impl Im2Col {
fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
(h_out, w_out)
}
}
impl Map1 for Im2Col {
fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
let &Self {
h_k,
w_k,
stride,
dilation,
padding,
} = self;
let (b, c, h, w) = layout.shape().dims4()?;
let (h_out, w_out) = self.hw_out(h, w);
let src = &vs[layout.start_offset()..];
let mut dst = vec![T::zero(); b * h_out * w_out * c * h_k * w_k];
let (src_s0, src_s1, src_s2, src_s3) = {
let s = layout.stride();
(s[0], s[1], s[2], s[3])
};
// TODO: provide specialized kernels for the common use cases.
// - h_k = w_k = 1
// - padding = 0
// - stride = 1
// - dilation = 1
for b_idx in 0..b {
let src_idx = b_idx * src_s0;
let dst_idx = b_idx * h_out * w_out * c * h_k * w_k;
for h_idx in 0..h_out {
let dst_idx = dst_idx + h_idx * w_out * c * h_k * w_k;
for w_idx in 0..w_out {
let dst_idx = dst_idx + w_idx * c * h_k * w_k;
for c_idx in 0..c {
let dst_idx = dst_idx + c_idx * h_k * w_k;
let src_idx = c_idx * src_s1 + src_idx;
for h_k_idx in 0..h_k {
let src_h = h_idx * stride + h_k_idx * dilation;
if padding != 0 && (src_h < padding || src_h >= h + padding) {
continue;
}
let src_h = src_h - padding;
let src_idx = src_idx + src_h * src_s2;
let dst_idx = dst_idx + h_k_idx * w_k;
for w_k_idx in 0..w_k {
let src_w = w_idx * stride + w_k_idx * dilation;
if padding != 0 && (src_w < padding || src_w >= w + padding) {
continue;
}
let src_w = src_w - padding;
let src_idx = src_idx + src_w * src_s3;
let dst_idx = dst_idx + w_k_idx;
dst[dst_idx] = src[src_idx]
}
}
}
}
}
}
Ok(dst)
}
}
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
impl<'a> Map2 for ConvTranspose1D<'a> {
const OP: &'static str = "conv_transpose1d";
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0;
let inp = &inp[inp_l.start_offset()..];
let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
let l_out = p.l_out();
// Output shape: [b_size, c_out, l_out].
let dst_elems = p.c_out * l_out * p.b_size;
let dst = vec![T::zero(); dst_elems];
let dst_s0 = p.c_out * l_out;
let dst_s1 = l_out;
let dst_s2 = 1;
// TODO: Avoid making this copy if `inp` already has the appropriate layout.
let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
let cont_s0 = p.l_in * p.c_in;
let cont_s1 = p.c_in;
for b_idx in 0..p.b_size {
for l_idx in 0..p.l_in {
for c_idx in 0..p.c_in {
let src_idx = b_idx * inp_s0 + c_idx * inp_s1 + l_idx * inp_s2;
let dst_idx = b_idx * cont_s0 + l_idx * cont_s1 + c_idx;
inp_cont[dst_idx] = inp[src_idx]
}
}
}
for k_idx in 0..p.k_size {
(0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
let k_cont = (0..p.c_in)
.map(|c_in_idx| k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_idx * k_s2])
.collect::<Vec<_>>();
for b_idx in 0..p.b_size {
for l_idx in 0..p.l_in {
let out_idx = l_idx * p.stride + k_idx * p.dilation;
if out_idx < p.padding {
continue;
}
let out_idx = out_idx - p.padding;
if out_idx < l_out {
let inp_cont = &inp_cont[b_idx * cont_s0 + l_idx * cont_s1..];
let dst_idx = b_idx * dst_s0 + out_idx * dst_s2 + dst_c_idx * dst_s1;
let mut d = T::zero();
unsafe {
T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
}
let dst_p = dst.as_ptr();
// Safety: dst_idx are uniques per dst_c_idx which is used to
// parallelise the different tasks so no two threads can try to
// write at the same location.
unsafe {
let ptr = dst_p.add(dst_idx) as *mut T;
*ptr += d
}
}
}
}
})
}
Ok(dst)
}
}
struct Conv2D<'a>(&'a crate::conv::ParamsConv2D); struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
impl<'a> Map2 for Conv2D<'a> { impl<'a> Map2 for Conv2D<'a> {
@ -1529,9 +1294,8 @@ impl Map2 for MatMul {
) -> Result<Vec<T>> { ) -> Result<Vec<T>> {
use gemm::{gemm, Parallelism}; use gemm::{gemm, Parallelism};
match T::DTYPE { if T::DTYPE == DType::BF16 {
DType::F16 | DType::F32 | DType::F64 => {} return Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?;
_ => Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?,
} }
let (b, m, n, k) = self.0; let (b, m, n, k) = self.0;
@ -2235,10 +1999,6 @@ impl BackendStorage for CpuStorage {
MaxPool2D(kernel_size, stride).map(self, layout) MaxPool2D(kernel_size, stride).map(self, layout)
} }
fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
UpsampleNearest1D(sz).map(self, layout)
}
fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> { fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
UpsampleNearest2D(h, w).map(self, layout) UpsampleNearest2D(h, w).map(self, layout)
} }
@ -2467,50 +2227,7 @@ impl BackendStorage for CpuStorage {
kernel_l: &Layout, kernel_l: &Layout,
params: &crate::conv::ParamsConv1D, params: &crate::conv::ParamsConv1D,
) -> Result<Self> { ) -> Result<Self> {
if !USE_IM2COL_CONV1D { Conv1D(params).map(self, l, kernel, kernel_l)
return Conv1D(params).map(self, l, kernel, kernel_l);
}
let op = Im2Col1D {
l_k: params.k_size,
padding: params.padding,
stride: params.stride,
dilation: params.dilation,
};
let col = op.map(self, l)?;
let b = params.b_size;
let n = params.c_out;
let l_out = params.l_out();
let k = op.l_k * params.c_in;
let m = l_out;
let col_l = Layout::contiguous((b, m, k));
let res = if kernel_l.is_contiguous() {
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
};
let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?;
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t)
}
fn conv_transpose1d(
&self,
l: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
ConvTranspose1D(params).map(self, l, kernel, kernel_l)
} }
fn conv2d( fn conv2d(
@ -2520,43 +2237,7 @@ impl BackendStorage for CpuStorage {
kernel_l: &Layout, kernel_l: &Layout,
params: &crate::conv::ParamsConv2D, params: &crate::conv::ParamsConv2D,
) -> Result<Self> { ) -> Result<Self> {
if !USE_IM2COL_CONV2D { Conv2D(params).map(self, l, kernel, kernel_l)
return Conv2D(params).map(self, l, kernel, kernel_l);
}
let op = Im2Col {
h_k: params.k_h,
w_k: params.k_w,
padding: params.padding,
stride: params.stride,
dilation: params.dilation,
};
let col = op.map(self, l)?;
let b = params.b_size;
let n = params.c_out;
let (h_out, w_out) = (params.out_h(), params.out_w());
let k = op.h_k * op.w_k * params.c_in;
let m = h_out * w_out;
let col_l = Layout::contiguous((b, m, k));
let res = if kernel_l.is_contiguous() {
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
};
let res_l = Layout::contiguous((b, h_out, w_out, params.c_out))
.transpose(1, 2)?
.transpose(1, 3)?;
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t)
} }
fn conv_transpose2d( fn conv_transpose2d(
@ -2617,25 +2298,25 @@ impl BackendStorage for CpuStorage {
Self::U8(ids) => { Self::U8(ids) => {
let ids = match ids_l.contiguous_offsets() { let ids = match ids_l.contiguous_offsets() {
Some((a, b)) => &ids[a..b], Some((a, b)) => &ids[a..b],
None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?, None => Err(Error::RequiresContiguous { op: "index-add" })?,
}; };
IndexAdd { ids, dim }.map(self, l, src, src_l) IndexAdd { ids, dim }.map(self, l, src, src_l)
} }
Self::U32(ids) => { Self::U32(ids) => {
let ids = match ids_l.contiguous_offsets() { let ids = match ids_l.contiguous_offsets() {
Some((a, b)) => &ids[a..b], Some((a, b)) => &ids[a..b],
None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?, None => Err(Error::RequiresContiguous { op: "index-add" })?,
}; };
IndexAdd { ids, dim }.map(self, l, src, src_l) IndexAdd { ids, dim }.map(self, l, src, src_l)
} }
Self::I64(ids) => { Self::I64(ids) => {
let ids = match ids_l.contiguous_offsets() { let ids = match ids_l.contiguous_offsets() {
Some((a, b)) => &ids[a..b], Some((a, b)) => &ids[a..b],
None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?, None => Err(Error::RequiresContiguous { op: "index-add" })?,
}; };
IndexAdd { ids, dim }.map(self, l, src, src_l) IndexAdd { ids, dim }.map(self, l, src, src_l)
} }
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add").bt()), _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add")),
} }
} }
@ -2681,10 +2362,6 @@ impl BackendDevice for CpuDevice {
Ok(Self) Ok(Self)
} }
fn set_seed(&self, _seed: u64) -> Result<()> {
crate::bail!("cannot seed the CPU rng with set_seed")
}
fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result<CpuStorage> { fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result<CpuStorage> {
use rand::prelude::*; use rand::prelude::*;

View File

@ -223,14 +223,6 @@ impl BackendDevice for CudaDevice {
}) })
} }
fn set_seed(&self, seed: u64) -> Result<()> {
// We do not call set_seed but instead create a new curand object. This ensures that the
// state will be identical and the same random numbers will be generated.
let mut curand = self.curand.lock().unwrap();
curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
Ok(())
}
fn location(&self) -> crate::DeviceLocation { fn location(&self) -> crate::DeviceLocation {
crate::DeviceLocation::Cuda { crate::DeviceLocation::Cuda {
gpu_id: self.device.ordinal(), gpu_id: self.device.ordinal(),
@ -320,13 +312,6 @@ impl BackendDevice for CudaDevice {
// cudarc changes. // cudarc changes.
let elem_count = shape.elem_count(); let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap(); let curand = self.curand.lock().unwrap();
// curand can only generate an odd number of values.
// https://github.com/huggingface/candle/issues/734
let elem_count_round = if elem_count % 2 == 1 {
elem_count + 1
} else {
elem_count
};
let slice = match dtype { let slice = match dtype {
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => { DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
Err(CudaError::UnsupportedDtype { Err(CudaError::UnsupportedDtype {
@ -336,7 +321,7 @@ impl BackendDevice for CudaDevice {
.w()? .w()?
} }
DType::F32 => { DType::F32 => {
let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?; let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
curand curand
.0 .0
.fill_with_normal(&mut data, mean as f32, std as f32) .fill_with_normal(&mut data, mean as f32, std as f32)
@ -344,7 +329,7 @@ impl BackendDevice for CudaDevice {
CudaStorageSlice::F32(data) CudaStorageSlice::F32(data)
} }
DType::F64 => { DType::F64 => {
let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?; let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
curand.0.fill_with_normal(&mut data, mean, std).w()?; curand.0.fill_with_normal(&mut data, mean, std).w()?;
CudaStorageSlice::F64(data) CudaStorageSlice::F64(data)
} }
@ -608,105 +593,6 @@ impl Map1 for Elu {
} }
} }
struct Im2Col1D {
l_k: usize,
stride: usize,
dilation: usize,
padding: usize,
}
impl Im2Col1D {
fn l_out(&self, l: usize) -> usize {
(l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
}
}
impl Map1 for Im2Col1D {
fn f<T: DeviceRepr + WithDType>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>> {
let shape = layout.shape();
let dims = shape.dims();
let l_out = self.l_out(dims[2]);
let dst_el = dims[0] * l_out * dims[1] * self.l_k;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("im2col1d"), kernels::CONV)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let params = (
dst_el,
l_out,
self.l_k,
self.stride,
self.padding,
self.dilation,
&ds,
src,
&dst,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)
}
}
struct Im2Col {
h_k: usize,
w_k: usize,
stride: usize,
dilation: usize,
padding: usize,
}
impl Im2Col {
fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
(h_out, w_out)
}
}
impl Map1 for Im2Col {
fn f<T: DeviceRepr + WithDType>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>> {
let shape = layout.shape();
let dims = shape.dims();
let (h_out, w_out) = self.hw_out(dims[2], dims[3]);
let dst_el = dims[0] * h_out * w_out * dims[1] * self.h_k * self.w_k;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("im2col"), kernels::CONV)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let params = (
dst_el,
h_out,
w_out,
self.h_k,
self.w_k,
self.stride,
self.padding,
self.dilation,
&ds,
src,
&dst,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)
}
}
struct Powf(f64); struct Powf(f64);
impl Map1 for Powf { impl Map1 for Powf {
fn f<T: DeviceRepr + WithDType>( fn f<T: DeviceRepr + WithDType>(
@ -892,6 +778,8 @@ impl<'a> Map1 for IndexSelect<'a> {
}; };
let ids_shape = ids_l.shape(); let ids_shape = ids_l.shape();
let ids_dims = ids_shape.dims(); let ids_dims = ids_shape.dims();
let ids_el = ids_shape.elem_count();
let cfg = LaunchConfig::for_num_elems(ids_el as u32);
let ds = dev.htod_copy([ids_dims, ids_l.stride()].concat()).w()?; let ds = dev.htod_copy([ids_dims, ids_l.stride()].concat()).w()?;
let src = match src_l.contiguous_offsets() { let src = match src_l.contiguous_offsets() {
Some((o1, o2)) => src.slice(o1..o2), Some((o1, o2)) => src.slice(o1..o2),
@ -899,23 +787,19 @@ impl<'a> Map1 for IndexSelect<'a> {
}; };
let left_size: usize = src_l.dims()[..self.2].iter().product(); let left_size: usize = src_l.dims()[..self.2].iter().product();
let right_size: usize = src_l.dims()[self.2 + 1..].iter().product(); let right_size: usize = src_l.dims()[self.2 + 1..].iter().product();
let src_dim_size = src_l.dims()[self.2]; let dim_size = src_l.dims()[self.2];
let ids_dim_size = ids_shape.elem_count();
let dst_el = ids_shape.elem_count() * left_size * right_size;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?; let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
// SAFETY: Set later by running the kernel. // SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?; let out = unsafe { dev.alloc::<T>(ids_el * left_size * right_size) }.w()?;
let params = ( let params = (
dst_el, ids_el,
ids_dims.len(), ids_dims.len(),
&ds, &ds,
ids, ids,
&src, &src,
&out, &out,
left_size, left_size,
src_dim_size, dim_size,
ids_dim_size,
right_size, right_size,
); );
// SAFETY: ffi. // SAFETY: ffi.
@ -1766,56 +1650,9 @@ impl BackendStorage for CudaStorage {
kernel_l: &Layout, kernel_l: &Layout,
params: &crate::conv::ParamsConv1D, params: &crate::conv::ParamsConv1D,
) -> Result<Self> { ) -> Result<Self> {
const USE_IM2COL_CONV1D: bool = true;
let device = self.device().clone(); let device = self.device().clone();
if !USE_IM2COL_CONV1D {
let slice = Conv1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?; let slice = Conv1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
return Ok(Self { slice, device }); Ok(Self { slice, device })
}
let col = Im2Col1D {
l_k: params.k_size,
stride: params.stride,
dilation: params.dilation,
padding: params.padding,
}
.map(&self.slice, &device, l)?;
let col = Self { slice: col, device };
let l_out = params.l_out();
let b = params.b_size;
let n = params.c_out;
let k = params.k_size * params.c_in;
let m = l_out;
let col_l = Layout::contiguous((b, m, k));
let res = if kernel_l.is_contiguous() {
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
};
let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t)
}
fn conv_transpose1d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
todo!()
} }
#[cfg(not(feature = "cudnn"))] #[cfg(not(feature = "cudnn"))]
@ -1826,50 +1663,9 @@ impl BackendStorage for CudaStorage {
kernel_l: &Layout, kernel_l: &Layout,
params: &crate::conv::ParamsConv2D, params: &crate::conv::ParamsConv2D,
) -> Result<Self> { ) -> Result<Self> {
const USE_IM2COL_CONV2D: bool = true;
let device = self.device().clone(); let device = self.device().clone();
if !USE_IM2COL_CONV2D {
let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?; let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
return Ok(Self { slice, device }); Ok(Self { slice, device })
}
let col = Im2Col {
h_k: params.k_h,
w_k: params.k_w,
stride: params.stride,
dilation: params.dilation,
padding: params.padding,
}
.map(&self.slice, &device, l)?;
let col = Self { slice: col, device };
let h_out = params.out_h();
let w_out = params.out_w();
let b = params.b_size;
let n = params.c_out;
let k = params.k_h * params.k_w * params.c_in;
let m = h_out * w_out;
let col_l = Layout::contiguous((b, m, k));
let res = if kernel_l.is_contiguous() {
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
};
let res_l = Layout::contiguous((b, h_out, w_out, n))
.transpose(1, 2)?
.transpose(1, 3)?;
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t)
} }
#[cfg(feature = "cudnn")] #[cfg(feature = "cudnn")]
@ -1974,10 +1770,6 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device }) Ok(Self { slice, device })
} }
fn upsample_nearest1d(&self, _: &Layout, _out_sz: usize) -> Result<Self> {
crate::bail!("upsample-nearest1d is not supported on cuda")
}
fn upsample_nearest2d(&self, l: &Layout, out_w: usize, out_h: usize) -> Result<Self> { fn upsample_nearest2d(&self, l: &Layout, out_w: usize, out_h: usize) -> Result<Self> {
let device = self.device().clone(); let device = self.device().clone();
let slice = UpsampleNearest2D(out_w, out_h).map(&self.slice, &device, l)?; let slice = UpsampleNearest2D(out_w, out_h).map(&self.slice, &device, l)?;
@ -2181,7 +1973,7 @@ impl BackendStorage for CudaStorage {
if src_l.is_contiguous() { if src_l.is_contiguous() {
dev.dtod_copy(&src, &mut dst).w()? dev.dtod_copy(&src, &mut dst).w()?
} else { } else {
let func = dev.get_or_load_func("ucopy_f64", kernels::UNARY)?; let func = dev.get_or_load_func("ucopy_64", kernels::UNARY)?;
// SAFETY: Set later by running the kernel. // SAFETY: Set later by running the kernel.
let params = (el_count, dims.len(), &ds, &src, &mut dst); let params = (el_count, dims.len(), &ds, &src, &mut dst);
// SAFETY: ffi. // SAFETY: ffi.

View File

@ -34,9 +34,6 @@ pub(crate) fn launch_conv2d<
params: &crate::conv::ParamsConv2D, params: &crate::conv::ParamsConv2D,
dev: &crate::cuda_backend::CudaDevice, dev: &crate::cuda_backend::CudaDevice,
) -> crate::Result<()> { ) -> crate::Result<()> {
use crate::conv::CudnnFwdAlgo as CandleAlgo;
use cudarc::cudnn::sys::cudnnConvolutionFwdAlgo_t as A;
let device_id = dev.id(); let device_id = dev.id();
let cudnn = CUDNN.with(|cudnn| { let cudnn = CUDNN.with(|cudnn| {
if let Some(cudnn) = cudnn.borrow().get(&device_id) { if let Some(cudnn) = cudnn.borrow().get(&device_id) {
@ -93,20 +90,7 @@ pub(crate) fn launch_conv2d<
w: &w, w: &w,
y: &y, y: &y,
}; };
let alg = match params.cudnn_fwd_algo { let alg = conv2d.pick_algorithm()?;
None => conv2d.pick_algorithm()?,
Some(CandleAlgo::ImplicitGemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
Some(CandleAlgo::ImplicitPrecompGemm) => {
A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
}
Some(CandleAlgo::Gemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
Some(CandleAlgo::Direct) => A::CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
Some(CandleAlgo::Fft) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT,
Some(CandleAlgo::FftTiling) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
Some(CandleAlgo::Winograd) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
Some(CandleAlgo::WinogradNonFused) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
};
let workspace_size = conv2d.get_workspace_size(alg)?; let workspace_size = conv2d.get_workspace_size(alg)?;
let mut workspace = dev.cuda_device().alloc_zeros::<u8>(workspace_size)?; let mut workspace = dev.cuda_device().alloc_zeros::<u8>(workspace_size)?;
unsafe { unsafe {

View File

@ -8,14 +8,12 @@ use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
pub enum DeviceLocation { pub enum DeviceLocation {
Cpu, Cpu,
Cuda { gpu_id: usize }, Cuda { gpu_id: usize },
Metal { gpu_id: usize },
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum Device { pub enum Device {
Cpu, Cpu,
Cuda(crate::CudaDevice), Cuda(crate::CudaDevice),
Metal(crate::MetalDevice),
} }
pub trait NdArray { pub trait NdArray {
@ -130,23 +128,10 @@ impl Device {
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?)) Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
} }
pub fn new_metal(ordinal: usize) -> Result<Self> {
Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
}
pub fn set_seed(&self, seed: u64) -> Result<()> {
match self {
Self::Cpu => CpuDevice.set_seed(seed),
Self::Cuda(c) => c.set_seed(seed),
Self::Metal(m) => m.set_seed(seed),
}
}
pub fn same_device(&self, rhs: &Self) -> bool { pub fn same_device(&self, rhs: &Self) -> bool {
match (self, rhs) { match (self, rhs) {
(Self::Cpu, Self::Cpu) => true, (Self::Cpu, Self::Cpu) => true,
(Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs), (Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
(Self::Metal(lhs), Self::Metal(rhs)) => lhs.same_device(rhs),
_ => false, _ => false,
} }
} }
@ -155,20 +140,21 @@ impl Device {
match self { match self {
Self::Cpu => DeviceLocation::Cpu, Self::Cpu => DeviceLocation::Cpu,
Self::Cuda(device) => device.location(), Self::Cuda(device) => device.location(),
Device::Metal(device) => device.location(),
} }
} }
pub fn is_cpu(&self) -> bool { pub fn is_cpu(&self) -> bool {
matches!(self, Self::Cpu) match self {
Self::Cpu => true,
Self::Cuda(_) => false,
}
} }
pub fn is_cuda(&self) -> bool { pub fn is_cuda(&self) -> bool {
matches!(self, Self::Cuda(_)) match self {
Self::Cpu => false,
Self::Cuda(_) => true,
} }
pub fn is_metal(&self) -> bool {
matches!(self, Self::Metal(_))
} }
pub fn cuda_if_available(ordinal: usize) -> Result<Self> { pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
@ -192,20 +178,10 @@ impl Device {
Ok(Storage::Cpu(storage)) Ok(Storage::Cpu(storage))
} }
Device::Cuda(device) => { Device::Cuda(device) => {
// TODO: Remove the special case if we start supporting generating f16/bf16 directly.
if dtype == DType::F16 || dtype == DType::BF16 {
let storage = device.rand_uniform(shape, DType::F32, lo, up)?;
Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
} else {
let storage = device.rand_uniform(shape, dtype, lo, up)?; let storage = device.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Cuda(storage)) Ok(Storage::Cuda(storage))
} }
} }
Device::Metal(device) => {
let storage = device.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Metal(storage))
}
}
} }
pub(crate) fn rand_uniform<T: crate::FloatDType>( pub(crate) fn rand_uniform<T: crate::FloatDType>(
@ -230,20 +206,10 @@ impl Device {
Ok(Storage::Cpu(storage)) Ok(Storage::Cpu(storage))
} }
Device::Cuda(device) => { Device::Cuda(device) => {
// TODO: Remove the special case if we start supporting generating f16/bf16 directly.
if dtype == DType::F16 || dtype == DType::BF16 {
let storage = device.rand_normal(shape, DType::F32, mean, std)?;
Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
} else {
let storage = device.rand_normal(shape, dtype, mean, std)?; let storage = device.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Cuda(storage)) Ok(Storage::Cuda(storage))
} }
} }
Device::Metal(device) => {
let storage = device.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Metal(storage))
}
}
} }
pub(crate) fn rand_normal<T: crate::FloatDType>( pub(crate) fn rand_normal<T: crate::FloatDType>(
@ -265,10 +231,6 @@ impl Device {
let storage = device.ones_impl(shape, dtype)?; let storage = device.ones_impl(shape, dtype)?;
Ok(Storage::Cuda(storage)) Ok(Storage::Cuda(storage))
} }
Device::Metal(device) => {
let storage = device.ones_impl(shape, dtype)?;
Ok(Storage::Metal(storage))
}
} }
} }
@ -282,10 +244,6 @@ impl Device {
let storage = device.zeros_impl(shape, dtype)?; let storage = device.zeros_impl(shape, dtype)?;
Ok(Storage::Cuda(storage)) Ok(Storage::Cuda(storage))
} }
Device::Metal(device) => {
let storage = device.zeros_impl(shape, dtype)?;
Ok(Storage::Metal(storage))
}
} }
} }
@ -297,11 +255,6 @@ impl Device {
let storage = device.storage_from_cpu_storage(&storage)?; let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Cuda(storage)) Ok(Storage::Cuda(storage))
} }
Device::Metal(device) => {
let storage = array.to_cpu_storage();
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Metal(storage))
}
} }
} }
@ -313,11 +266,6 @@ impl Device {
let storage = device.storage_from_cpu_storage(&storage)?; let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Cuda(storage)) Ok(Storage::Cuda(storage))
} }
Device::Metal(device) => {
let storage = S::to_cpu_storage_owned(data);
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Metal(storage))
}
} }
} }
} }

View File

@ -14,9 +14,6 @@ impl Tensor {
crate::DeviceLocation::Cuda { gpu_id } => { crate::DeviceLocation::Cuda { gpu_id } => {
format!(", cuda:{}", gpu_id) format!(", cuda:{}", gpu_id)
} }
crate::DeviceLocation::Metal { gpu_id } => {
format!(", metal:{}", gpu_id)
}
}; };
write!(f, "Tensor[")?; write!(f, "Tensor[")?;
@ -479,9 +476,6 @@ impl std::fmt::Display for Tensor {
crate::DeviceLocation::Cuda { gpu_id } => { crate::DeviceLocation::Cuda { gpu_id } => {
format!(", cuda:{}", gpu_id) format!(", cuda:{}", gpu_id)
} }
crate::DeviceLocation::Metal { gpu_id } => {
format!(", metal:{}", gpu_id)
}
}; };
write!( write!(

View File

@ -67,20 +67,6 @@ impl DType {
Self::F64 => 8, Self::F64 => 8,
} }
} }
pub fn is_int(&self) -> bool {
match self {
Self::U8 | Self::U32 | Self::I64 => true,
Self::BF16 | Self::F16 | Self::F32 | Self::F64 => false,
}
}
pub fn is_float(&self) -> bool {
match self {
Self::U8 | Self::U32 | Self::I64 => false,
Self::BF16 | Self::F16 | Self::F32 | Self::F64 => true,
}
}
} }
pub trait WithDType: pub trait WithDType:

View File

@ -79,16 +79,6 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport) Err(Error::NotCompiledWithCudaSupport)
} }
fn conv_transpose1d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn conv2d( fn conv2d(
&self, &self,
_: &Layout, _: &Layout,
@ -162,10 +152,6 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport) Err(Error::NotCompiledWithCudaSupport)
} }
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> { fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport) Err(Error::NotCompiledWithCudaSupport)
} }
@ -177,10 +163,6 @@ impl crate::backend::BackendDevice for CudaDevice {
Err(Error::NotCompiledWithCudaSupport) Err(Error::NotCompiledWithCudaSupport)
} }
fn set_seed(&self, _: u64) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
fn location(&self) -> crate::DeviceLocation { fn location(&self) -> crate::DeviceLocation {
fail!() fail!()
} }

View File

@ -1,223 +0,0 @@
#![allow(dead_code)]
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Error, Layout, Result, Shape};
#[derive(Debug, Clone)]
pub struct MetalDevice;
#[derive(Debug)]
pub struct MetalStorage;
#[derive(thiserror::Error, Debug)]
pub enum MetalError {
#[error("{0}")]
Message(String),
}
impl From<String> for MetalError {
fn from(e: String) -> Self {
MetalError::Message(e)
}
}
macro_rules! fail {
() => {
unimplemented!("metal support has not been enabled, add `metal` feature to enable.")
};
}
impl crate::backend::BackendStorage for MetalStorage {
type Device = MetalDevice;
fn try_clone(&self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn dtype(&self) -> DType {
fail!()
}
fn device(&self) -> &Self::Device {
fail!()
}
fn to_cpu_storage(&self) -> Result<CpuStorage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn conv1d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConv1D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn conv_transpose1d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn conv2d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConv2D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn conv_transpose2d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose2D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn scatter_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn index_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn matmul(
&self,
_: &Self,
_: (usize, usize, usize, usize),
_: &Layout,
_: &Layout,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
}
impl crate::backend::BackendDevice for MetalDevice {
type Storage = MetalStorage;
fn new(_: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn set_seed(&self, _: u64) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}
fn location(&self) -> crate::DeviceLocation {
fail!()
}
fn same_device(&self, _: &Self) -> bool {
fail!()
}
fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
}

View File

@ -1,4 +1,4 @@
use crate::{DType, DeviceLocation, Layout, MetalError, Shape}; use crate::{DType, DeviceLocation, Layout, Shape};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct MatMulUnexpectedStriding { pub struct MatMulUnexpectedStriding {
@ -142,9 +142,6 @@ pub enum Error {
#[error("{op} expects at least one tensor")] #[error("{op} expects at least one tensor")]
OpRequiresAtLeastOneTensor { op: &'static str }, OpRequiresAtLeastOneTensor { op: &'static str },
#[error("{op} expects at least two tensors")]
OpRequiresAtLeastTwoTensors { op: &'static str },
#[error("backward is not supported for {op}")] #[error("backward is not supported for {op}")]
BackwardNotSupported { op: &'static str }, BackwardNotSupported { op: &'static str },
@ -152,9 +149,6 @@ pub enum Error {
#[error("the candle crate has not been built with cuda support")] #[error("the candle crate has not been built with cuda support")]
NotCompiledWithCudaSupport, NotCompiledWithCudaSupport,
#[error("the candle crate has not been built with metal support")]
NotCompiledWithMetalSupport,
#[error("cannot find tensor {path}")] #[error("cannot find tensor {path}")]
CannotFindTensor { path: String }, CannotFindTensor { path: String },
@ -162,9 +156,6 @@ pub enum Error {
#[error(transparent)] #[error(transparent)]
Cuda(Box<dyn std::error::Error + Send + Sync>), Cuda(Box<dyn std::error::Error + Send + Sync>),
#[error("Metal error {0}")]
Metal(#[from] MetalError),
#[error(transparent)] #[error(transparent)]
TryFromIntError(#[from] core::num::TryFromIntError), TryFromIntError(#[from] core::num::TryFromIntError),

View File

@ -46,31 +46,19 @@ impl Tensor {
current_dim += 1; current_dim += 1;
out out
} }
TensorIndexer::IndexSelect(indexes) => {
if indexes.rank() != 1 {
crate::bail!("multi-dimensional tensor indexing is not supported")
}
let out = x.index_select(&indexes.to_device(x.device())?, current_dim)?;
current_dim += 1;
out
}
TensorIndexer::Err(e) => crate::bail!("indexing error {e:?}"),
}; };
} }
Ok(x) Ok(x)
} }
} }
#[derive(Debug)] #[derive(Debug, Clone)]
/// Generic structure used to index a slice of the tensor /// Generic structure used to index a slice of the tensor
pub enum TensorIndexer { pub enum TensorIndexer {
/// This selects the elemnts for which an index has some specific value. /// This selects the elemnts for which an index has some specific value.
Select(usize), Select(usize),
/// This is a regular slice, purely indexing a chunk of the tensor /// This is a regular slice, purely indexing a chunk of the tensor
Narrow(Bound<usize>, Bound<usize>), Narrow(Bound<usize>, Bound<usize>),
/// Indexing via a 1d tensor
IndexSelect(Tensor),
Err(Error),
} }
impl From<usize> for TensorIndexer { impl From<usize> for TensorIndexer {
@ -79,55 +67,36 @@ impl From<usize> for TensorIndexer {
} }
} }
impl From<&[u32]> for TensorIndexer { macro_rules! impl_from_range {
fn from(index: &[u32]) -> Self { ($range_type:ty) => {
match Tensor::new(index, &crate::Device::Cpu) { impl From<$range_type> for TensorIndexer {
Ok(tensor) => TensorIndexer::IndexSelect(tensor), fn from(range: $range_type) -> Self {
Err(e) => TensorIndexer::Err(e),
}
}
}
impl From<Vec<u32>> for TensorIndexer {
fn from(index: Vec<u32>) -> Self {
let len = index.len();
match Tensor::from_vec(index, len, &crate::Device::Cpu) {
Ok(tensor) => TensorIndexer::IndexSelect(tensor),
Err(e) => TensorIndexer::Err(e),
}
}
}
impl From<&Tensor> for TensorIndexer {
fn from(tensor: &Tensor) -> Self {
TensorIndexer::IndexSelect(tensor.clone())
}
}
trait RB: RangeBounds<usize> {}
impl RB for Range<usize> {}
impl RB for RangeFrom<usize> {}
impl RB for RangeFull {}
impl RB for RangeInclusive<usize> {}
impl RB for RangeTo<usize> {}
impl RB for RangeToInclusive<usize> {}
impl<T: RB> From<T> for TensorIndexer {
fn from(range: T) -> Self {
use std::ops::Bound::*; use std::ops::Bound::*;
let start = match range.start_bound() { let start = match range.start_bound() {
Included(idx) => Included(*idx), Included(idx) => Included(*idx),
Excluded(idx) => Excluded(*idx), Excluded(idx) => Excluded(*idx),
Unbounded => Unbounded, Unbounded => Unbounded,
}; };
let end = match range.end_bound() { let end = match range.end_bound() {
Included(idx) => Included(*idx), Included(idx) => Included(*idx),
Excluded(idx) => Excluded(*idx), Excluded(idx) => Excluded(*idx),
Unbounded => Unbounded, Unbounded => Unbounded,
}; };
TensorIndexer::Narrow(start, end) TensorIndexer::Narrow(start, end)
} }
} }
};
}
impl_from_range!(Range<usize>);
impl_from_range!(RangeFrom<usize>);
impl_from_range!(RangeFull);
impl_from_range!(RangeInclusive<usize>);
impl_from_range!(RangeTo<usize>);
impl_from_range!(RangeToInclusive<usize>);
/// Trait used to implement multiple signatures for ease of use of the slicing /// Trait used to implement multiple signatures for ease of use of the slicing
/// of a tensor /// of a tensor

View File

@ -49,12 +49,9 @@ mod device;
pub mod display; pub mod display;
mod dtype; mod dtype;
mod dummy_cuda_backend; mod dummy_cuda_backend;
mod dummy_metal_backend;
pub mod error; pub mod error;
mod indexer; mod indexer;
pub mod layout; pub mod layout;
#[cfg(feature = "metal")]
pub mod metal_backend;
#[cfg(feature = "mkl")] #[cfg(feature = "mkl")]
mod mkl; mod mkl;
pub mod npy; pub mod npy;
@ -90,12 +87,6 @@ pub use cuda_backend::{CudaDevice, CudaStorage};
#[cfg(not(feature = "cuda"))] #[cfg(not(feature = "cuda"))]
pub use dummy_cuda_backend::{CudaDevice, CudaStorage}; pub use dummy_cuda_backend::{CudaDevice, CudaStorage};
#[cfg(feature = "metal")]
pub use metal_backend::{MetalDevice, MetalError, MetalStorage};
#[cfg(not(feature = "metal"))]
pub use dummy_metal_backend::{MetalDevice, MetalError, MetalStorage};
#[cfg(feature = "mkl")] #[cfg(feature = "mkl")]
extern crate intel_mkl_src; extern crate intel_mkl_src;
@ -119,24 +110,18 @@ impl ToUsize2 for (usize, usize) {
} }
// A simple trait defining a module with forward method using a single argument. // A simple trait defining a module with forward method using a single argument.
pub trait Module { pub trait Module: std::fmt::Debug {
fn forward(&self, xs: &Tensor) -> Result<Tensor>; fn forward(&self, xs: &Tensor) -> Result<Tensor>;
/// Change the module to use training mode vs eval mode.
///
/// The default implementation does nothing as this is only used for a couple modules such as
/// dropout or batch-normalization.
fn set_training(&mut self, _training: bool) {}
} }
impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T { impl Module for quantized::QMatMul {
fn forward(&self, xs: &Tensor) -> Result<Tensor> { fn forward(&self, xs: &Tensor) -> Result<Tensor> {
self(xs)
}
}
// A trait defining a module with forward method using a single tensor argument and a flag to
// separate the training and evaluation behaviors.
pub trait ModuleT {
fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor>;
}
impl<M: Module> ModuleT for M {
fn forward_t(&self, xs: &Tensor, _train: bool) -> Result<Tensor> {
self.forward(xs) self.forward(xs)
} }
} }

File diff suppressed because it is too large Load Diff

View File

@ -250,6 +250,8 @@ impl Tensor {
if header.fortran_order { if header.fortran_order {
return Err(Error::Npy("fortran order not supported".to_string())); return Err(Error::Npy("fortran order not supported".to_string()));
} }
let mut data: Vec<u8> = vec![];
reader.read_to_end(&mut data)?;
Self::from_reader(header.shape(), header.descr, &mut reader) Self::from_reader(header.shape(), header.descr, &mut reader)
} }

View File

@ -1,5 +1,5 @@
#![allow(clippy::redundant_closure_call)] #![allow(clippy::redundant_closure_call)]
use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor}; use crate::{CpuStorage, CudaStorage, Layout, Result, Shape, Tensor};
use half::{bf16, f16}; use half::{bf16, f16};
use num_traits::float::Float; use num_traits::float::Float;
@ -58,13 +58,8 @@ pub enum UnaryOp {
Sqr, Sqr,
Sqrt, Sqrt,
Gelu, Gelu,
GeluErf,
Erf,
Relu, Relu,
Tanh, Tanh,
Floor,
Ceil,
Round,
} }
#[derive(Clone)] #[derive(Clone)]
@ -90,16 +85,6 @@ pub enum Op {
dilation: usize, dilation: usize,
}, },
#[allow(dead_code)]
ConvTranspose1D {
arg: Tensor,
kernel: Tensor,
padding: usize,
output_padding: usize,
stride: usize,
dilation: usize,
},
#[allow(dead_code)] #[allow(dead_code)]
Conv2D { Conv2D {
arg: Tensor, arg: Tensor,
@ -131,7 +116,6 @@ pub enum Op {
stride: (usize, usize), stride: (usize, usize),
}, },
UpsampleNearest1D(Tensor),
UpsampleNearest2D(Tensor), UpsampleNearest2D(Tensor),
Cat(Vec<Tensor>, usize), Cat(Vec<Tensor>, usize),
@ -146,7 +130,6 @@ pub enum Op {
Copy(Tensor), Copy(Tensor),
Broadcast(Tensor), Broadcast(Tensor),
Narrow(Tensor, usize, usize, usize), Narrow(Tensor, usize, usize, usize),
SliceScatter0(Tensor, Tensor, usize),
Reshape(Tensor), Reshape(Tensor),
ToDevice(Tensor), ToDevice(Tensor),
Transpose(Tensor, usize, usize), Transpose(Tensor, usize, usize),
@ -184,18 +167,6 @@ pub trait CustomOp1 {
)) ))
} }
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_storage: &MetalStorage,
_layout: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
/// This function takes as argument the argument `arg` used in the forward pass, the result /// This function takes as argument the argument `arg` used in the forward pass, the result
/// produced by the forward operation `res` and the gradient of the result `grad_res`. /// produced by the forward operation `res` and the gradient of the result `grad_res`.
/// The function should return the gradient of the argument. /// The function should return the gradient of the argument.
@ -231,20 +202,6 @@ pub trait CustomOp2 {
)) ))
} }
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
fn bwd( fn bwd(
&self, &self,
_arg1: &Tensor, _arg1: &Tensor,
@ -287,22 +244,6 @@ pub trait CustomOp3 {
)) ))
} }
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
fn bwd( fn bwd(
&self, &self,
_arg1: &Tensor, _arg1: &Tensor,
@ -383,13 +324,8 @@ pub(crate) struct Recip;
pub(crate) struct Sqr; pub(crate) struct Sqr;
pub(crate) struct Sqrt; pub(crate) struct Sqrt;
pub(crate) struct Gelu; pub(crate) struct Gelu;
pub(crate) struct GeluErf;
pub(crate) struct Erf;
pub(crate) struct Relu; pub(crate) struct Relu;
pub(crate) struct Tanh; pub(crate) struct Tanh;
pub(crate) struct Floor;
pub(crate) struct Ceil;
pub(crate) struct Round;
macro_rules! bin_op { macro_rules! bin_op {
($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => { ($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
@ -588,13 +524,13 @@ unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln);
unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin); unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos); unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh); unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh);
unary_op!(Abs, "abs", v, v.abs());
unary_op!(Neg, "neg", v, -v); unary_op!(Neg, "neg", v, -v);
unary_op!(Recip, "recip", v, v.recip()); unary_op!(Recip, "recip", v, v.recip());
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr); unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt); unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
/// Tanh based approximation of the `gelu` operation /// `gelu` operation
/// GeluErf is the more precise one.
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions> /// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
impl UnaryOpT for Gelu { impl UnaryOpT for Gelu {
const NAME: &'static str = "gelu"; const NAME: &'static str = "gelu";
@ -664,230 +600,6 @@ impl UnaryOpT for Gelu {
fn f64_vec(xs: &[f64], ys: &mut [f64]) { fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::mkl::vd_gelu(xs, ys) crate::mkl::vd_gelu(xs, ys)
} }
#[cfg(feature = "accelerate")]
const F32_VEC: bool = true;
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
crate::accelerate::vs_gelu(xs, ys)
}
#[cfg(feature = "accelerate")]
const F64_VEC: bool = true;
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::accelerate::vd_gelu(xs, ys)
}
}
/// `erf` operation
/// <https://en.wikipedia.org/wiki/Error_function>
impl UnaryOpT for Erf {
const NAME: &'static str = "erf";
const KERNEL: &'static str = "uerf";
const V: Self = Erf;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
bf16::from_f64(Self::f64(v.to_f64()))
}
#[inline(always)]
fn f16(v: f16) -> f16 {
f16::from_f64(Self::f64(v.to_f64()))
}
#[inline(always)]
fn f32(v: f32) -> f32 {
Self::f64(v as f64) as f32
}
#[inline(always)]
fn f64(v: f64) -> f64 {
crate::cpu::erf::erf(v)
}
#[inline(always)]
fn u8(_: u8) -> u8 {
0
}
#[inline(always)]
fn u32(_: u32) -> u32 {
0
}
#[inline(always)]
fn i64(_: i64) -> i64 {
0
}
}
impl UnaryOpT for Abs {
const NAME: &'static str = "abs";
const KERNEL: &'static str = "uabs";
const V: Self = Abs;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
v.abs()
}
#[inline(always)]
fn f16(v: f16) -> f16 {
v.abs()
}
#[inline(always)]
fn f32(v: f32) -> f32 {
v.abs()
}
#[inline(always)]
fn f64(v: f64) -> f64 {
v.abs()
}
#[inline(always)]
fn u8(v: u8) -> u8 {
v
}
#[inline(always)]
fn u32(v: u32) -> u32 {
v
}
#[inline(always)]
fn i64(v: i64) -> i64 {
v.abs()
}
}
impl UnaryOpT for Ceil {
const NAME: &'static str = "ceil";
const KERNEL: &'static str = "uceil";
const V: Self = Ceil;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
v.ceil()
}
#[inline(always)]
fn f16(v: f16) -> f16 {
v.ceil()
}
#[inline(always)]
fn f32(v: f32) -> f32 {
v.ceil()
}
#[inline(always)]
fn f64(v: f64) -> f64 {
v.ceil()
}
#[inline(always)]
fn u8(v: u8) -> u8 {
v
}
#[inline(always)]
fn u32(v: u32) -> u32 {
v
}
#[inline(always)]
fn i64(v: i64) -> i64 {
v
}
}
impl UnaryOpT for Floor {
const NAME: &'static str = "floor";
const KERNEL: &'static str = "ufloor";
const V: Self = Floor;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
v.floor()
}
#[inline(always)]
fn f16(v: f16) -> f16 {
v.floor()
}
#[inline(always)]
fn f32(v: f32) -> f32 {
v.floor()
}
#[inline(always)]
fn f64(v: f64) -> f64 {
v.floor()
}
#[inline(always)]
fn u8(v: u8) -> u8 {
v
}
#[inline(always)]
fn u32(v: u32) -> u32 {
v
}
#[inline(always)]
fn i64(v: i64) -> i64 {
v
}
}
impl UnaryOpT for Round {
const NAME: &'static str = "round";
const KERNEL: &'static str = "uround";
const V: Self = Round;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
v.round()
}
#[inline(always)]
fn f16(v: f16) -> f16 {
v.round()
}
#[inline(always)]
fn f32(v: f32) -> f32 {
v.round()
}
#[inline(always)]
fn f64(v: f64) -> f64 {
v.round()
}
#[inline(always)]
fn u8(v: u8) -> u8 {
v
}
#[inline(always)]
fn u32(v: u32) -> u32 {
v
}
#[inline(always)]
fn i64(v: i64) -> i64 {
v
}
}
impl UnaryOpT for GeluErf {
const NAME: &'static str = "gelu_erf";
const KERNEL: &'static str = "ugelu_erf";
const V: Self = GeluErf;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
bf16::from_f64(Self::f64(v.to_f64()))
}
#[inline(always)]
fn f16(v: f16) -> f16 {
f16::from_f64(Self::f64(v.to_f64()))
}
#[inline(always)]
fn f32(v: f32) -> f32 {
Self::f64(v as f64) as f32
}
#[inline(always)]
fn f64(v: f64) -> f64 {
(crate::cpu::erf::erf(v / 2f64.sqrt()) + 1.) * 0.5 * v
}
#[inline(always)]
fn u8(_: u8) -> u8 {
0
}
#[inline(always)]
fn u32(_: u32) -> u32 {
0
}
#[inline(always)]
fn i64(_: i64) -> i64 {
0
}
} }
impl UnaryOpT for Relu { impl UnaryOpT for Relu {
@ -975,10 +687,6 @@ impl BackpropOp {
}; };
Self(op) Self(op)
} }
pub(crate) fn is_none(&self) -> bool {
self.0.is_none()
}
} }
impl std::ops::Deref for BackpropOp { impl std::ops::Deref for BackpropOp {

View File

@ -193,50 +193,6 @@ impl Object {
_ => Err(self), _ => Err(self),
} }
} }
pub fn into_tensor_info(
self,
name: Self,
dir_name: &std::path::Path,
) -> Result<Option<TensorInfo>> {
let name = match name.unicode() {
Ok(name) => name,
Err(_) => return Ok(None),
};
let (callable, args) = match self.reduce() {
Ok(callable_args) => callable_args,
_ => return Ok(None),
};
let (callable, args) = match callable {
Object::Class {
module_name,
class_name,
} if module_name == "torch._tensor" && class_name == "_rebuild_from_type_v2" => {
let mut args = args.tuple()?;
let callable = args.remove(0);
let args = args.remove(1);
(callable, args)
}
_ => (callable, args),
};
match callable {
Object::Class {
module_name,
class_name,
} if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {}
_ => return Ok(None),
};
let (layout, dtype, file_path, storage_size) = rebuild_args(args)?;
let mut path = dir_name.to_path_buf();
path.push(file_path);
Ok(Some(TensorInfo {
name,
dtype,
layout,
path: path.to_string_lossy().into_owned(),
storage_size,
}))
}
} }
impl TryFrom<Object> for String { impl TryFrom<Object> for String {
@ -609,7 +565,6 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
"HalfStorage" => DType::F16, "HalfStorage" => DType::F16,
"BFloat16Storage" => DType::BF16, "BFloat16Storage" => DType::BF16,
"ByteStorage" => DType::U8, "ByteStorage" => DType::U8,
"LongStorage" => DType::I64,
other => { other => {
crate::bail!("unsupported storage type {other}") crate::bail!("unsupported storage type {other}")
} }
@ -668,10 +623,50 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
}; };
if let Object::Dict(key_values) = obj { if let Object::Dict(key_values) = obj {
for (name, value) in key_values.into_iter() { for (name, value) in key_values.into_iter() {
match value.into_tensor_info(name, &dir_name) { let name = match name.unicode() {
Ok(Some(tensor_info)) => tensor_infos.push(tensor_info), Ok(name) => name,
Ok(None) => {} Err(_) => continue,
Err(err) => eprintln!("skipping: {err:?}"), };
let (callable, args) = match value.reduce() {
Ok(callable_args) => callable_args,
_ => continue,
};
let (callable, args) = match callable {
Object::Class {
module_name,
class_name,
} if module_name == "torch._tensor"
&& class_name == "_rebuild_from_type_v2" =>
{
let mut args = args.tuple()?;
let callable = args.remove(0);
let args = args.remove(1);
(callable, args)
}
_ => (callable, args),
};
match callable {
Object::Class {
module_name,
class_name,
} if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {}
_ => continue,
};
match rebuild_args(args) {
Ok((layout, dtype, file_path, storage_size)) => {
let mut path = dir_name.clone();
path.push(file_path);
tensor_infos.push(TensorInfo {
name,
dtype,
layout,
path: path.to_string_lossy().into_owned(),
storage_size,
})
}
Err(err) => {
eprintln!("skipping {name}: {err:?}")
}
} }
} }
} }
@ -728,16 +723,3 @@ impl PthTensors {
Ok(Some(tensor)) Ok(Some(tensor))
} }
} }
/// Read all the tensors from a PyTorch pth file.
pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
let pth = PthTensors::new(path)?;
let tensor_names = pth.tensor_infos.keys();
let mut tensors = Vec::with_capacity(tensor_names.len());
for name in tensor_names {
if let Some(tensor) = pth.get(name)? {
tensors.push((name.to_string(), tensor))
}
}
Ok(tensors)
}

View File

@ -50,9 +50,14 @@ pub(crate) unsafe fn mul_sum_i8_pairs_float(x: __m256i, y: __m256i) -> __m256 {
#[inline(always)] #[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> { pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0; let qk = QK8_0;
let nb = n / qk;
if n % QK8_0 != 0 { if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}") crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
} }
if nb % 2 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
}
unsafe { unsafe {
let mut acc = _mm256_setzero_ps(); let mut acc = _mm256_setzero_ps();
for (x, y) in xs.iter().zip(ys.iter()) { for (x, y) in xs.iter().zip(ys.iter()) {
@ -633,35 +638,3 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res
Ok(hsum_float_8(acc) + summs) Ok(hsum_float_8(acc) + summs)
} }
} }
#[inline(always)]
pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
let qk = QK_K;
if n % qk != 0 {
crate::bail!("vec_dot_q8k_8k: {n} is not divisible by {qk}")
}
unsafe {
let mut acc = _mm256_setzero_ps();
for (xs, ys) in xs.iter().zip(ys.iter()) {
let mut sumi = _mm256_setzero_si256();
let x_qs = xs.qs.as_ptr();
let y_qs = ys.qs.as_ptr();
for j in (0..QK_K).step_by(32) {
let xs = _mm256_loadu_si256(x_qs.add(j) as *const __m256i);
let ys = _mm256_loadu_si256(y_qs.add(j) as *const __m256i);
let xs0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xs, 0));
let ys0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(ys, 0));
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(xs0, ys0));
let xs1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xs, 1));
let ys1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(ys, 1));
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(xs1, ys1));
}
let d = _mm256_set1_ps(xs.d * ys.d);
acc = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi), acc);
}
Ok(hsum_float_8(acc))
}
}

View File

@ -135,13 +135,7 @@ pub fn qtensor_from_ggml(
dims: Vec<usize>, dims: Vec<usize>,
) -> Result<super::QTensor> { ) -> Result<super::QTensor> {
let tensor_elems = dims.iter().product::<usize>(); let tensor_elems = dims.iter().product::<usize>();
let blck_size = ggml_dtype.blck_size(); let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
if tensor_elems % blck_size != 0 {
crate::bail!(
"the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
)
}
let size_in_bytes = tensor_elems / blck_size * ggml_dtype.type_size();
match ggml_dtype { match ggml_dtype {
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims), GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),

View File

@ -29,7 +29,6 @@ impl TryFrom<u32> for Magic {
pub enum VersionedMagic { pub enum VersionedMagic {
GgufV1, GgufV1,
GgufV2, GgufV2,
GgufV3,
} }
impl VersionedMagic { impl VersionedMagic {
@ -40,7 +39,6 @@ impl VersionedMagic {
let versioned_magic = match (magic, version) { let versioned_magic = match (magic, version) {
(Magic::Gguf, 1) => Self::GgufV1, (Magic::Gguf, 1) => Self::GgufV1,
(Magic::Gguf, 2) => Self::GgufV2, (Magic::Gguf, 2) => Self::GgufV2,
(Magic::Gguf, 3) => Self::GgufV3,
_ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"), _ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"),
}; };
Ok(versioned_magic) Ok(versioned_magic)
@ -61,13 +59,8 @@ impl TensorInfo {
tensor_data_offset: u64, tensor_data_offset: u64,
) -> Result<QTensor> { ) -> Result<QTensor> {
let tensor_elems = self.shape.elem_count(); let tensor_elems = self.shape.elem_count();
let blck_size = self.ggml_dtype.blck_size(); let size_in_bytes =
if tensor_elems % blck_size != 0 { tensor_elems * self.ggml_dtype.type_size() / self.ggml_dtype.blck_size();
crate::bail!(
"the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
)
}
let size_in_bytes = tensor_elems / blck_size * self.ggml_dtype.type_size();
let mut raw_data = vec![0u8; size_in_bytes]; let mut raw_data = vec![0u8; size_in_bytes];
reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?; reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
reader.read_exact(&mut raw_data)?; reader.read_exact(&mut raw_data)?;
@ -86,9 +79,7 @@ pub struct Content {
fn read_string<R: std::io::Read>(reader: &mut R, magic: &VersionedMagic) -> Result<String> { fn read_string<R: std::io::Read>(reader: &mut R, magic: &VersionedMagic) -> Result<String> {
let len = match magic { let len = match magic {
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize, VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => { VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
reader.read_u64::<LittleEndian>()? as usize
}
}; };
let mut v = vec![0u8; len]; let mut v = vec![0u8; len];
reader.read_exact(&mut v)?; reader.read_exact(&mut v)?;
@ -288,9 +279,7 @@ impl Value {
let value_type = ValueType::from_u32(value_type)?; let value_type = ValueType::from_u32(value_type)?;
let len = match magic { let len = match magic {
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize, VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => { VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
reader.read_u64::<LittleEndian>()? as usize
}
}; };
let mut vs = Vec::with_capacity(len); let mut vs = Vec::with_capacity(len);
for _ in 0..len { for _ in 0..len {
@ -387,15 +376,11 @@ impl Content {
let tensor_count = match magic { let tensor_count = match magic {
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize, VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => { VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
reader.read_u64::<LittleEndian>()? as usize
}
}; };
let metadata_kv_count = match magic { let metadata_kv_count = match magic {
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize, VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => { VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
reader.read_u64::<LittleEndian>()? as usize
}
}; };
let mut metadata = HashMap::new(); let mut metadata = HashMap::new();
@ -417,7 +402,7 @@ impl Content {
reader.read_u32_into::<LittleEndian>(&mut dimensions)?; reader.read_u32_into::<LittleEndian>(&mut dimensions)?;
dimensions.into_iter().map(|c| c as usize).collect() dimensions.into_iter().map(|c| c as usize).collect()
} }
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => { VersionedMagic::GgufV2 => {
let mut dimensions = vec![0; n_dimensions as usize]; let mut dimensions = vec![0; n_dimensions as usize];
reader.read_u64_into::<LittleEndian>(&mut dimensions)?; reader.read_u64_into::<LittleEndian>(&mut dimensions)?;
dimensions.into_iter().map(|c| c as usize).collect() dimensions.into_iter().map(|c| c as usize).collect()

View File

@ -34,9 +34,6 @@ pub trait GgmlType: Sized + Clone + Send + Sync {
/// Dot product used as a building block for quantized mat-mul. /// Dot product used as a building block for quantized mat-mul.
/// n is the number of elements to be considered. /// n is the number of elements to be considered.
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32>; fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32>;
/// Generic implementation of the dot product without simd optimizations.
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32>;
} }
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
@ -228,17 +225,15 @@ impl GgmlType for BlockQ4_0 {
#[cfg(target_feature = "neon")] #[cfg(target_feature = "neon")]
return super::neon::vec_dot_q4_0_q8_0(n, xs, ys); return super::neon::vec_dot_q4_0_q8_0(n, xs, ys);
#[cfg(target_feature = "simd128")]
return super::simd128::vec_dot_q4_0_q8_0(n, xs, ys);
Self::vec_dot_unopt(n, xs, ys)
}
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
let qk = QK8_0; let qk = QK8_0;
let nb = n / qk;
if n % QK8_0 != 0 { if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}") crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
} }
if nb % 2 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
}
// Generic implementation. // Generic implementation.
let mut sumf = 0f32; let mut sumf = 0f32;
for (xs, ys) in xs.iter().zip(ys.iter()) { for (xs, ys) in xs.iter().zip(ys.iter()) {
@ -260,10 +255,6 @@ impl GgmlType for BlockQ4_1 {
type VecDotType = BlockQ8_1; type VecDotType = BlockQ8_1;
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> { fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
Self::vec_dot_unopt(n, xs, ys)
}
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
// ggml_vec_dot_q4_1_q8_1 // ggml_vec_dot_q4_1_q8_1
let qk = QK8_1; let qk = QK8_1;
if n % qk != 0 { if n % qk != 0 {
@ -363,10 +354,7 @@ impl GgmlType for BlockQ5_0 {
if nb % 2 != 0 { if nb % 2 != 0 {
crate::bail!("vec_dot_q5_0_q8_0: {n}, nb is not divisible by 2") crate::bail!("vec_dot_q5_0_q8_0: {n}, nb is not divisible by 2")
} }
Self::vec_dot_unopt(n, xs, ys)
}
fn vec_dot_unopt(_n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
// Generic implementation. // Generic implementation.
let mut sumf = 0f32; let mut sumf = 0f32;
@ -457,10 +445,6 @@ impl GgmlType for BlockQ5_1 {
type VecDotType = BlockQ8_1; type VecDotType = BlockQ8_1;
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> { fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
Self::vec_dot_unopt(n, xs, ys)
}
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
let qk = Self::BLCK_SIZE; let qk = Self::BLCK_SIZE;
if n % Self::BLCK_SIZE != 0 { if n % Self::BLCK_SIZE != 0 {
crate::bail!("vec_dot_q5_1_q8_1: {n} is not divisible by {qk}") crate::bail!("vec_dot_q5_1_q8_1: {n} is not divisible by {qk}")
@ -622,13 +606,6 @@ impl GgmlType for BlockQ8_0 {
#[cfg(target_feature = "neon")] #[cfg(target_feature = "neon")]
return super::neon::vec_dot_q8_0_q8_0(n, xs, ys); return super::neon::vec_dot_q8_0_q8_0(n, xs, ys);
#[cfg(target_feature = "simd128")]
return super::simd128::vec_dot_q8_0_q8_0(n, xs, ys);
Self::vec_dot_unopt(n, xs, ys)
}
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
let qk = QK8_0; let qk = QK8_0;
if n % QK8_0 != 0 { if n % QK8_0 != 0 {
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}") crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
@ -654,11 +631,7 @@ impl GgmlType for BlockQ8_1 {
const BLCK_SIZE: usize = QK8_1; const BLCK_SIZE: usize = QK8_1;
type VecDotType = BlockQ8_1; type VecDotType = BlockQ8_1;
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> { fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
Self::vec_dot_unopt(n, xs, ys)
}
fn vec_dot_unopt(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
unimplemented!("no support for vec-dot on Q8_1") unimplemented!("no support for vec-dot on Q8_1")
} }
@ -708,13 +681,6 @@ impl GgmlType for BlockQ2K {
#[cfg(target_feature = "neon")] #[cfg(target_feature = "neon")]
return super::neon::vec_dot_q2k_q8k(n, xs, ys); return super::neon::vec_dot_q2k_q8k(n, xs, ys);
#[cfg(target_feature = "simd128")]
return super::simd128::vec_dot_q2k_q8k(n, xs, ys);
Self::vec_dot_unopt(n, xs, ys)
}
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
if n % QK_K != 0 { if n % QK_K != 0 {
crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}") crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
} }
@ -735,17 +701,18 @@ impl GgmlType for BlockQ2K {
let mut isum = 0; let mut isum = 0;
let mut is = 0; let mut is = 0;
let mut d;
for _ in 0..(QK_K / 128) { for _ in 0..(QK_K / 128) {
let mut shift = 0; let mut shift = 0;
for _ in 0..4 { for _ in 0..4 {
let d = (sc[is] & 0xF) as i32; d = (sc[is] & 0xF) as i32;
is += 1; is += 1;
let mut isuml = 0; let mut isuml = 0;
for l in 0..16 { for l in 0..16 {
isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32); isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32);
} }
isum += d * isuml; isum += d * isuml;
let d = (sc[is] & 0xF) as i32; d = (sc[is] & 0xF) as i32;
is += 1; is += 1;
isuml = 0; isuml = 0;
for l in 16..32 { for l in 16..32 {
@ -884,10 +851,6 @@ impl GgmlType for BlockQ3K {
#[cfg(target_feature = "neon")] #[cfg(target_feature = "neon")]
return super::neon::vec_dot_q3k_q8k(n, xs, ys); return super::neon::vec_dot_q3k_q8k(n, xs, ys);
Self::vec_dot_unopt(n, xs, ys)
}
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
if n % QK_K != 0 { if n % QK_K != 0 {
crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}") crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}")
} }
@ -1114,6 +1077,7 @@ impl GgmlType for BlockQ3K {
let d_all = block.d.to_f32(); let d_all = block.d.to_f32();
let mut m = 1; let mut m = 1;
let mut is = 0; let mut is = 0;
let mut dl;
// Dequantize both 128 long blocks // Dequantize both 128 long blocks
// 32 qs values per 128 long block // 32 qs values per 128 long block
@ -1124,7 +1088,7 @@ impl GgmlType for BlockQ3K {
for (scale_index, scale_scoped_y) in for (scale_index, scale_scoped_y) in
shift_scoped_y.chunks_exact_mut(16).enumerate() shift_scoped_y.chunks_exact_mut(16).enumerate()
{ {
let dl = d_all * (scales[is] as f32 - 32.0); dl = d_all * (scales[is] as f32 - 32.0);
for (i, inner_y) in scale_scoped_y.iter_mut().enumerate() { for (i, inner_y) in scale_scoped_y.iter_mut().enumerate() {
let new_y = dl let new_y = dl
* (((qs[i + 16 * scale_index] >> shift) & 3) as i8 * (((qs[i + 16 * scale_index] >> shift) & 3) as i8
@ -1162,13 +1126,6 @@ impl GgmlType for BlockQ4K {
#[cfg(target_feature = "neon")] #[cfg(target_feature = "neon")]
return super::neon::vec_dot_q4k_q8k(n, xs, ys); return super::neon::vec_dot_q4k_q8k(n, xs, ys);
#[cfg(target_feature = "simd128")]
return super::simd128::vec_dot_q4k_q8k(n, xs, ys);
Self::vec_dot_unopt(n, xs, ys)
}
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
if n % QK_K != 0 { if n % QK_K != 0 {
crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}") crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
} }
@ -1355,10 +1312,6 @@ impl GgmlType for BlockQ5K {
#[cfg(target_feature = "neon")] #[cfg(target_feature = "neon")]
return super::neon::vec_dot_q5k_q8k(n, xs, ys); return super::neon::vec_dot_q5k_q8k(n, xs, ys);
Self::vec_dot_unopt(n, xs, ys)
}
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
if n % QK_K != 0 { if n % QK_K != 0 {
crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}") crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}")
} }
@ -1576,13 +1529,6 @@ impl GgmlType for BlockQ6K {
#[cfg(target_feature = "neon")] #[cfg(target_feature = "neon")]
return super::neon::vec_dot_q6k_q8k(n, xs, ys); return super::neon::vec_dot_q6k_q8k(n, xs, ys);
#[cfg(target_feature = "simd128")]
return super::simd128::vec_dot_q6k_q8k(n, xs, ys);
Self::vec_dot_unopt(n, xs, ys)
}
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
if n % QK_K != 0 { if n % QK_K != 0 {
crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}") crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}")
} }
@ -1751,38 +1697,8 @@ impl GgmlType for BlockQ8K {
const BLCK_SIZE: usize = QK_K; const BLCK_SIZE: usize = QK_K;
type VecDotType = BlockQ8K; type VecDotType = BlockQ8K;
#[allow(unreachable_code)] fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> { unreachable!()
#[cfg(target_feature = "avx")]
return super::avx::vec_dot_q8k_q8k(n, xs, ys);
#[cfg(target_feature = "neon")]
return super::neon::vec_dot_q8k_q8k(n, xs, ys);
#[cfg(target_feature = "simd128")]
return super::simd128::vec_dot_q8k_q8k(n, xs, ys);
Self::vec_dot_unopt(n, xs, ys)
}
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
let qk = QK_K;
if n % QK_K != 0 {
crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
}
// Generic implementation.
let mut sumf = 0f32;
for (xs, ys) in xs.iter().zip(ys.iter()) {
let sum_i = xs
.qs
.iter()
.zip(ys.qs.iter())
.map(|(&x, &y)| x as i32 * y as i32)
.sum::<i32>();
sumf += sum_i as f32 * xs.d * ys.d
}
Ok(sumf)
} }
fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
@ -1888,10 +1804,6 @@ impl GgmlType for f32 {
type VecDotType = f32; type VecDotType = f32;
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> { fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
Self::vec_dot_unopt(n, xs, ys)
}
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
if xs.len() < n { if xs.len() < n {
crate::bail!("size mismatch {} < {n}", xs.len()) crate::bail!("size mismatch {} < {n}", xs.len())
} }
@ -1926,10 +1838,6 @@ impl GgmlType for f16 {
type VecDotType = f16; type VecDotType = f16;
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> { fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
Self::vec_dot_unopt(n, xs, ys)
}
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
if xs.len() < n { if xs.len() < n {
crate::bail!("size mismatch {} < {n}", xs.len()) crate::bail!("size mismatch {} < {n}", xs.len())
} }

View File

@ -7,8 +7,6 @@ pub mod gguf_file;
pub mod k_quants; pub mod k_quants;
#[cfg(target_feature = "neon")] #[cfg(target_feature = "neon")]
pub mod neon; pub mod neon;
#[cfg(target_feature = "simd128")]
pub mod simd128;
pub mod utils; pub mod utils;
pub use k_quants::GgmlType; pub use k_quants::GgmlType;
@ -231,40 +229,20 @@ impl QTensor {
} }
} }
#[derive(Clone, Debug)] #[derive(Debug)]
pub enum QMatMul { pub struct QMatMul(std::sync::Arc<QTensor>);
QTensor(std::sync::Arc<QTensor>),
Tensor(Tensor),
}
thread_local! {
static DEQUANTIZE_ALL: bool = {
match std::env::var("CANDLE_DEQUANTIZE_ALL") {
Ok(s) => {
!s.is_empty() && s != "0"
},
Err(_) => false,
}
}
}
impl QMatMul { impl QMatMul {
pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> { pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Self {
let dequantize = match qtensor.dtype() { Self(qtensor)
GgmlDType::F32 | GgmlDType::F16 => true,
_ => DEQUANTIZE_ALL.with(|b| *b),
};
let t = if dequantize {
let tensor = qtensor.dequantize(&Device::Cpu)?;
Self::Tensor(tensor)
} else {
Self::QTensor(qtensor)
};
Ok(t)
} }
pub fn from_qtensor(qtensor: QTensor) -> Result<Self> { pub fn from_qtensor(qtensor: QTensor) -> Self {
Self::from_arc(std::sync::Arc::new(qtensor)) Self(std::sync::Arc::new(qtensor))
}
pub fn inner(&self) -> &std::sync::Arc<QTensor> {
&self.0
} }
} }
@ -307,18 +285,8 @@ impl crate::CustomOp1 for QTensor {
} }
} }
impl crate::Module for QMatMul { impl QMatMul {
fn forward(&self, xs: &Tensor) -> Result<Tensor> { pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self { xs.apply_op1_no_bwd(self.0.as_ref())
Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
Self::Tensor(w) => {
let w = match *xs.dims() {
[b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
[bsize, _, _] => w.broadcast_left(bsize)?.t()?,
_ => w.t()?,
};
xs.matmul(&w)
}
}
} }
} }

View File

@ -19,29 +19,42 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
if n % QK8_0 != 0 { if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}") crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
} }
if nb % 2 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
}
unsafe { unsafe {
let mut sumv0 = vdupq_n_f32(0.0f32); let mut sumv0 = vdupq_n_f32(0.0f32);
for i in 0..nb { let mut sumv1 = vdupq_n_f32(0.0f32);
for i in (0..nb).step_by(2) {
let x0 = &xs[i]; let x0 = &xs[i];
let x1 = &xs[i + 1];
let y0 = &ys[i]; let y0 = &ys[i];
let y1 = &ys[i + 1];
let m4b = vdupq_n_u8(0x0F); let m4b = vdupq_n_u8(0x0F);
let s8b = vdupq_n_s8(0x8); let s8b = vdupq_n_s8(0x8);
let v0_0 = vld1q_u8(x0.qs.as_ptr()); let v0_0 = vld1q_u8(x0.qs.as_ptr());
let v0_1 = vld1q_u8(x1.qs.as_ptr());
// 4-bit -> 8-bit // 4-bit -> 8-bit
let v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b)); let v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
let v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); let v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
let v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
let v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
// sub 8 // sub 8
let v0_0ls = vsubq_s8(v0_0l, s8b); let v0_0ls = vsubq_s8(v0_0l, s8b);
let v0_0hs = vsubq_s8(v0_0h, s8b); let v0_0hs = vsubq_s8(v0_0h, s8b);
let v0_1ls = vsubq_s8(v0_1l, s8b);
let v0_1hs = vsubq_s8(v0_1h, s8b);
// load y // load y
let v1_0l = vld1q_s8(y0.qs.as_ptr()); let v1_0l = vld1q_s8(y0.qs.as_ptr());
let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16)); let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));
let v1_1l = vld1q_s8(y1.qs.as_ptr());
let v1_1h = vld1q_s8(y1.qs.as_ptr().add(16));
// TODO: Support dotprod when it's available outside of nightly. // TODO: Support dotprod when it's available outside of nightly.
let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l)); let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l));
@ -49,16 +62,28 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h)); let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h));
let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h)); let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
let pl1l = vmull_s8(vget_low_s8(v0_1ls), vget_low_s8(v1_1l));
let pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
let ph1l = vmull_s8(vget_low_s8(v0_1hs), vget_low_s8(v1_1h));
let ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
let pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
let ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
sumv0 = vmlaq_n_f32( sumv0 = vmlaq_n_f32(
sumv0, sumv0,
vcvtq_f32_s32(vaddq_s32(pl0, ph0)), vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
x0.d.to_f32() * y0.d.to_f32(), x0.d.to_f32() * y0.d.to_f32(),
); );
sumv1 = vmlaq_n_f32(
sumv1,
vcvtq_f32_s32(vaddq_s32(pl1, ph1)),
x1.d.to_f32() * y1.d.to_f32(),
);
} }
Ok(vaddvq_f32(sumv0)) Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1))
} }
} }
@ -69,18 +94,28 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}") crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
} }
let nb = n / QK8_0; let nb = n / QK8_0;
if nb % 2 != 0 {
crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even")
}
unsafe { unsafe {
let mut sumv0 = vdupq_n_f32(0.0f32); let mut sumv0 = vdupq_n_f32(0.0f32);
for i in 0..nb { let mut sumv1 = vdupq_n_f32(0.0f32);
for i in (0..nb).step_by(2) {
let x0 = &xs[i]; let x0 = &xs[i];
let x1 = &xs[i + 1];
let y0 = &ys[i]; let y0 = &ys[i];
let y1 = &ys[i + 1];
let x0_0 = vld1q_s8(x0.qs.as_ptr()); let x0_0 = vld1q_s8(x0.qs.as_ptr());
let x0_1 = vld1q_s8(x0.qs.as_ptr().add(16)); let x0_1 = vld1q_s8(x0.qs.as_ptr().add(16));
let x1_0 = vld1q_s8(x1.qs.as_ptr());
let x1_1 = vld1q_s8(x1.qs.as_ptr().add(16));
// load y // load y
let y0_0 = vld1q_s8(y0.qs.as_ptr()); let y0_0 = vld1q_s8(y0.qs.as_ptr());
let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16)); let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));
let y1_0 = vld1q_s8(y1.qs.as_ptr());
let y1_1 = vld1q_s8(y1.qs.as_ptr().add(16));
// TODO dotprod once this is the intrinsics are. // TODO dotprod once this is the intrinsics are.
let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0)); let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0));
@ -88,48 +123,31 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1)); let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1));
let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1)); let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
let p1_0 = vmull_s8(vget_low_s8(x1_0), vget_low_s8(y1_0));
let p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
let p1_2 = vmull_s8(vget_low_s8(x1_1), vget_low_s8(y1_1));
let p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1)); let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3)); let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
let p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
let p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
sumv0 = vmlaq_n_f32( sumv0 = vmlaq_n_f32(
sumv0, sumv0,
vcvtq_f32_s32(vaddq_s32(p0, p1)), vcvtq_f32_s32(vaddq_s32(p0, p1)),
x0.d.to_f32() * y0.d.to_f32(), x0.d.to_f32() * y0.d.to_f32(),
); );
sumv1 = vmlaq_n_f32(
sumv1,
vcvtq_f32_s32(vaddq_s32(p2, p3)),
x1.d.to_f32() * y1.d.to_f32(),
);
} }
Ok(vaddvq_f32(sumv0)) Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1))
} }
} }
#[inline(always)]
pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
let qk = QK_K;
if n % QK_K != 0 {
crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
}
let mut sumf = 0f32;
for (xs, ys) in xs.iter().zip(ys.iter()) {
unsafe {
let mut sum_i = vdupq_n_s32(0);
let scale = xs.d * ys.d;
let xs = xs.qs.as_ptr();
let ys = ys.qs.as_ptr();
for i in (0..QK_K).step_by(16) {
let xs = vld1q_s8(xs.add(i));
let ys = vld1q_s8(ys.add(i));
let xy_lo = vmull_s8(vget_low_s8(xs), vget_low_s8(ys));
let xy_up = vmull_s8(vget_high_s8(xs), vget_high_s8(ys));
let xy = vaddq_s32(vpaddlq_s16(xy_lo), vpaddlq_s16(xy_up));
sum_i = vaddq_s32(sum_i, xy)
}
sumf += vaddvq_s32(sum_i) as f32 * scale
}
}
Ok(sumf)
}
#[inline(always)] #[inline(always)]
pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> { pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 { if n % QK_K != 0 {

View File

@ -1,419 +0,0 @@
use super::k_quants::{BlockQ2K, BlockQ4K, BlockQ4_0, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K};
use crate::Result;
use byteorder::{ByteOrder, LittleEndian};
use half::f16;
use core::arch::wasm32::*;
#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
}
unsafe {
let mut acc = f32x4_splat(0.0f32);
for (x, y) in xs.iter().zip(ys.iter()) {
let x1234 = v128_load(x.qs.as_ptr() as *const v128);
let x12 = v128_and(x1234, u8x16_splat(0x0F));
let x12 = i8x16_sub(x12, i8x16_splat(8));
let x34 = u8x16_shr(x1234, 4);
let x34 = i8x16_sub(x34, i8x16_splat(8));
let x1 = i16x8_extend_low_i8x16(x12);
let y1 = i16x8_load_extend_i8x8(y.qs.as_ptr());
let sum_xy = i32x4_dot_i16x8(x1, y1);
let x2 = i16x8_extend_high_i8x16(x12);
let y2 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(8));
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x2, y2));
let x3 = i16x8_extend_low_i8x16(x34);
let y3 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(16));
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x3, y3));
let x4 = i16x8_extend_high_i8x16(x34);
let y4 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(24));
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x4, y4));
let sum_xy = f32x4_convert_i32x4(sum_xy);
// f32x4_relaxed_madd is nightly only.
let d = f32x4_splat(f16::to_f32(x.d) * f16::to_f32(y.d));
let scaled = f32x4_mul(sum_xy, d);
acc = f32x4_add(acc, scaled)
}
let res = f32x4_extract_lane::<0>(acc)
+ f32x4_extract_lane::<1>(acc)
+ f32x4_extract_lane::<2>(acc)
+ f32x4_extract_lane::<3>(acc);
Ok(res)
}
}
#[inline(always)]
pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
}
unsafe {
let mut acc = f32x4_splat(0.0f32);
for (x, y) in xs.iter().zip(ys.iter()) {
let x1 = i16x8_load_extend_i8x8(x.qs.as_ptr());
let y1 = i16x8_load_extend_i8x8(y.qs.as_ptr());
let sum_xy = i32x4_dot_i16x8(x1, y1);
let x2 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(8));
let y2 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(8));
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x2, y2));
let x3 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(16));
let y3 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(16));
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x3, y3));
let x4 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(24));
let y4 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(24));
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x4, y4));
let sum_xy = f32x4_convert_i32x4(sum_xy);
// f32x4_relaxed_madd is nightly only.
let d = f32x4_splat(f16::to_f32(x.d) * f16::to_f32(y.d));
let scaled = f32x4_mul(sum_xy, d);
acc = f32x4_add(acc, scaled)
}
let res = f32x4_extract_lane::<0>(acc)
+ f32x4_extract_lane::<1>(acc)
+ f32x4_extract_lane::<2>(acc)
+ f32x4_extract_lane::<3>(acc);
Ok(res)
}
}
#[inline(always)]
pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
}
unsafe {
let mut sumf = f32x4_splat(0f32);
for (x, y) in xs.iter().zip(ys.iter()) {
let mut q2: &[_] = &x.qs;
let mut q8: &[_] = &y.qs;
let sc = &x.scales;
let mut summs = i32x4_splat(0);
for i in (0..(QK_K / 16)).step_by(4) {
let bsums = i32x4_load_extend_i16x4(y.bsums.as_ptr().add(i));
let scales = i32x4_shr(
i32x4(
sc[i] as i32,
sc[i + 1] as i32,
sc[i + 2] as i32,
sc[i + 3] as i32,
),
4,
);
summs = i32x4_add(summs, i32x4_mul(bsums, scales))
}
let summs = f32x4_convert_i32x4(summs);
let dall = y.d * x.d.to_f32();
let dmin = y.d * x.dmin.to_f32();
let mut isum = i32x4_splat(0);
let mut is = 0;
for _ in 0..(QK_K / 128) {
let mut shift = 0;
for _ in 0..4 {
let d = (sc[is] & 0xF) as i32;
is += 1;
let mut isuml = i16x8_splat(0);
for l in (0..16).step_by(8) {
let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(l));
let q2 = i16x8_load_extend_u8x8(q2.as_ptr().add(l));
let q2 = v128_and(i16x8_shr(q2, shift), i16x8_splat(3));
isuml = i16x8_add(isuml, i16x8_mul(q2, q8))
}
let dd = i32x4_splat(d);
isum = i32x4_add(isum, i32x4_mul(i32x4_extend_low_i16x8(isuml), dd));
isum = i32x4_add(isum, i32x4_mul(i32x4_extend_high_i16x8(isuml), dd));
let d = (sc[is] & 0xF) as i32;
is += 1;
let mut isuml = i16x8_splat(0);
for l in (16..32).step_by(8) {
let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(l));
let q2 = i16x8_load_extend_u8x8(q2.as_ptr().add(l));
let q2 = v128_and(i16x8_shr(q2, shift), i16x8_splat(3));
isuml = i16x8_add(isuml, i16x8_mul(q2, q8))
}
let dd = i32x4_splat(d);
isum = i32x4_add(isum, i32x4_mul(i32x4_extend_low_i16x8(isuml), dd));
isum = i32x4_add(isum, i32x4_mul(i32x4_extend_high_i16x8(isuml), dd));
shift += 2;
// adjust the indexing
q8 = &q8[32..];
}
// adjust the indexing
q2 = &q2[32..];
}
let isum = f32x4_convert_i32x4(isum);
sumf = f32x4_add(
sumf,
f32x4_sub(
f32x4_mul(isum, f32x4_splat(dall)),
f32x4_mul(summs, f32x4_splat(dmin)),
),
);
}
let sumf = f32x4_extract_lane::<0>(sumf)
+ f32x4_extract_lane::<1>(sumf)
+ f32x4_extract_lane::<2>(sumf)
+ f32x4_extract_lane::<3>(sumf);
Ok(sumf)
}
}
#[inline(always)]
pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
}
const KMASK1: u32 = 0x3f3f3f3f;
const KMASK2: u32 = 0x0f0f0f0f;
const KMASK3: u32 = 0x03030303;
let mut utmp: [u32; 4] = [0; 4];
let mut scales: [u8; 8] = [0; 8];
let mut mins: [u8; 8] = [0; 8];
let mut aux8: [u8; QK_K] = [0; QK_K];
let mut sums = f32x4_splat(0f32);
unsafe {
for (y, x) in ys.iter().zip(xs.iter()) {
let q4 = &x.qs;
let q8 = &y.qs;
for j in 0..QK_K / 64 {
let q4_1 = v128_load(q4.as_ptr().add(32 * j) as *const v128);
let q4_2 = v128_load(q4.as_ptr().add(32 * j + 16) as *const v128);
v128_store(
aux8.as_mut_ptr().add(64 * j) as *mut v128,
v128_and(q4_1, u8x16_splat(0x0F)),
);
v128_store(
aux8.as_mut_ptr().add(64 * j + 16) as *mut v128,
v128_and(q4_2, u8x16_splat(0x0F)),
);
v128_store(
aux8.as_mut_ptr().add(64 * j + 32) as *mut v128,
u8x16_shr(q4_1, 4),
);
v128_store(
aux8.as_mut_ptr().add(64 * j + 48) as *mut v128,
u8x16_shr(q4_2, 4),
);
}
LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
let uaux = utmp[1] & KMASK1;
utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
utmp[2] = uaux;
utmp[0] &= KMASK1;
//extract scales and mins
LittleEndian::write_u32_into(&utmp[0..2], &mut scales);
LittleEndian::write_u32_into(&utmp[2..4], &mut mins);
let mut sumi = i32x4_splat(0);
for j in (0..QK_K / 16).step_by(4) {
let bsums = i32x4_load_extend_i16x4(y.bsums.as_ptr().add(j));
let (m1, m2) = (mins[j / 2] as i32, mins[j / 2 + 1] as i32);
let mins = i32x4(m1, m1, m2, m2);
sumi = i32x4_add(sumi, i32x4_mul(bsums, mins));
}
let mut aux32 = i32x4_splat(0i32);
for (scale_i, scale) in scales.iter().enumerate() {
let scale = i32x4_splat(*scale as i32);
for j in 0..4 {
let i = 32 * scale_i + 8 * j;
let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(i));
let aux8 = i16x8_load_extend_u8x8(aux8.as_ptr().add(i));
let aux16 = i16x8_mul(q8, aux8);
aux32 = i32x4_add(aux32, i32x4_mul(scale, i32x4_extend_low_i16x8(aux16)));
aux32 = i32x4_add(aux32, i32x4_mul(scale, i32x4_extend_high_i16x8(aux16)));
}
}
let aux32 = f32x4_convert_i32x4(aux32);
let d = f32x4_splat(x.d.to_f32() * y.d);
sums = f32x4_add(sums, f32x4_mul(aux32, d));
let dmin = x.dmin.to_f32() * y.d;
let dmin = f32x4_splat(dmin);
let sumi = f32x4_convert_i32x4(sumi);
sums = f32x4_sub(sums, f32x4_mul(sumi, dmin));
}
let sums = f32x4_extract_lane::<0>(sums)
+ f32x4_extract_lane::<1>(sums)
+ f32x4_extract_lane::<2>(sums)
+ f32x4_extract_lane::<3>(sums);
Ok(sums)
}
}
#[inline(always)]
pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}")
}
let mut aux8 = [0i8; QK_K];
unsafe {
let mut sums = f32x4_splat(0f32);
for (x, y) in xs.iter().zip(ys.iter()) {
let q4 = &x.ql;
let qh = &x.qh;
let q8 = &y.qs;
let mut aux32 = f32x4_splat(0f32);
for j in (0..QK_K).step_by(128) {
let aux8 = aux8.as_mut_ptr().add(j);
let q4 = &q4.as_ptr().add(j / 2);
let qh = &qh.as_ptr().add(j / 4);
for l in (0..32).step_by(16) {
// aux8[l] = (((q4[l] & 0xF) | ((qh[l] & 3) << 4)) as i32 - 32) as i8;
let a8 = v128_or(
v128_and(v128_load(q4.add(l) as *const v128), u8x16_splat(0xF)),
u8x16_shl(
v128_and(v128_load(qh.add(l) as *const v128), u8x16_splat(3)),
4,
),
);
let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
v128_store(
aux8.add(l) as *mut v128,
i8x16_narrow_i16x8(a8_low, a8_high),
);
// aux8[l + 32] =
// (((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i32 - 32) as i8;
let a8 = v128_or(
v128_and(v128_load(q4.add(l + 32) as *const v128), u8x16_splat(0xF)),
u8x16_shl(
v128_and(
u8x16_shr(v128_load(qh.add(l) as *const v128), 2),
u8x16_splat(3),
),
4,
),
);
let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
v128_store(
aux8.add(l + 32) as *mut v128,
i8x16_narrow_i16x8(a8_low, a8_high),
);
// aux8[l + 64] = (((q4[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i32 - 32) as i8;
let a8 = v128_or(
u8x16_shr(v128_load(q4.add(l) as *const v128), 4),
u8x16_shl(
v128_and(
u8x16_shr(v128_load(qh.add(l) as *const v128), 4),
u8x16_splat(3),
),
4,
),
);
let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
v128_store(
aux8.add(l + 64) as *mut v128,
i8x16_narrow_i16x8(a8_low, a8_high),
);
// aux8[l + 96] =
// (((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i32 - 32) as i8;
let a8 = v128_or(
u8x16_shr(v128_load(q4.add(l + 32) as *const v128), 4),
u8x16_shl(
v128_and(
u8x16_shr(v128_load(qh.add(l) as *const v128), 6),
u8x16_splat(3),
),
4,
),
);
let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
v128_store(
aux8.add(l + 96) as *mut v128,
i8x16_narrow_i16x8(a8_low, a8_high),
);
}
}
for (j, &scale) in x.scales.iter().enumerate() {
let scale = f32x4_splat(scale as f32);
for offset in [0, 8] {
let aux16 = i16x8_mul(
i16x8_load_extend_i8x8(q8.as_ptr().add(16 * j + offset)),
i16x8_load_extend_i8x8(aux8.as_ptr().add(16 * j + offset)),
);
aux32 = f32x4_add(
aux32,
f32x4_mul(f32x4_convert_i32x4(i32x4_extend_low_i16x8(aux16)), scale),
);
aux32 = f32x4_add(
aux32,
f32x4_mul(f32x4_convert_i32x4(i32x4_extend_high_i16x8(aux16)), scale),
);
}
}
let d = f32x4_splat(x.d.to_f32() * y.d);
sums = f32x4_add(sums, f32x4_mul(aux32, d));
}
let sums = f32x4_extract_lane::<0>(sums)
+ f32x4_extract_lane::<1>(sums)
+ f32x4_extract_lane::<2>(sums)
+ f32x4_extract_lane::<3>(sums);
Ok(sums)
}
}
#[inline(always)]
pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
let qk = QK_K;
if n % QK_K != 0 {
crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
}
unsafe {
let mut acc = f32x4_splat(0.0f32);
for (xs, ys) in xs.iter().zip(ys.iter()) {
let x_qs = xs.qs.as_ptr();
let y_qs = ys.qs.as_ptr();
let mut sumi = i32x4_splat(0);
for j in (0..QK_K).step_by(8) {
let xs = i16x8_load_extend_i8x8(x_qs.add(j));
let ys = i16x8_load_extend_i8x8(y_qs.add(j));
let sum_xy = i32x4_dot_i16x8(xs, ys);
sumi = i32x4_add(sumi, sum_xy)
}
let d = f32x4_splat(xs.d * ys.d);
acc = f32x4_add(acc, f32x4_mul(f32x4_convert_i32x4(sumi), d))
}
let res = f32x4_extract_lane::<0>(acc)
+ f32x4_extract_lane::<1>(acc)
+ f32x4_extract_lane::<2>(acc)
+ f32x4_extract_lane::<3>(acc);
Ok(res)
}
}

View File

@ -17,7 +17,7 @@ pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>(
let expected_blocks = xs.len() / block_size; let expected_blocks = xs.len() / block_size;
let actual_blocks = ys.len(); let actual_blocks = ys.len();
// Validate that the input is the right size //validate that the input is the right size
if expected_blocks != actual_blocks { if expected_blocks != actual_blocks {
crate::bail!("quantize {dtype:?}: expected {expected_blocks} blocks but only {actual_blocks} were provided!") crate::bail!("quantize {dtype:?}: expected {expected_blocks} blocks but only {actual_blocks} were provided!")
} }
@ -37,12 +37,12 @@ pub(super) fn group_for_dequantization<'a, 'b, T: super::k_quants::GgmlType>(
let actual_output_len = ys.len(); let actual_output_len = ys.len();
let expected_output_len = xs.len() * block_size; let expected_output_len = xs.len() * block_size;
// Validate that the output is the right size //validate that the output is the right size
if expected_output_len != actual_output_len { if expected_output_len != actual_output_len {
crate::bail!("dequantize {dtype:?}: ys (len = {actual_output_len}) does not match the expected length of {expected_output_len}!") crate::bail!("dequantize {dtype:?}: ys (len = {actual_output_len}) does not match the expected length of {expected_output_len}!")
} }
// Zip the blocks and outputs together //zip the blocks and outputs together
Ok(xs.iter().zip(ys.chunks_exact_mut(block_size)).collect()) Ok(xs.iter().zip(ys.chunks_exact_mut(block_size)).collect())
} }

View File

@ -78,7 +78,11 @@ impl st::View for &Tensor {
} }
impl Tensor { impl Tensor {
pub fn save_safetensors<P: AsRef<Path>>(&self, name: &str, filename: P) -> Result<()> { pub fn save_safetensors<P: AsRef<std::path::Path>>(
&self,
name: &str,
filename: P,
) -> Result<()> {
let data = [(name, self.clone())]; let data = [(name, self.clone())];
Ok(st::serialize_to_file(data, &None, filename.as_ref())?) Ok(st::serialize_to_file(data, &None, filename.as_ref())?)
} }
@ -251,134 +255,6 @@ pub fn save<K: AsRef<str> + Ord + std::fmt::Display, P: AsRef<Path>>(
Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?) Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?)
} }
#[derive(yoke::Yokeable)]
struct SafeTensors_<'a>(SafeTensors<'a>);
pub struct MmapedSafetensors {
safetensors: Vec<yoke::Yoke<SafeTensors_<'static>, memmap2::Mmap>>,
routing: Option<HashMap<String, usize>>,
}
impl MmapedSafetensors {
/// Creates a wrapper around a memory mapped file and deserialize the safetensors header.
///
/// # Safety
///
/// The unsafe is inherited from [`memmap2::MmapOptions`].
pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
let p = p.as_ref();
let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
let file = memmap2::MmapOptions::new()
.map(&file)
.map_err(|e| Error::from(e).with_path(p))?;
let safetensors = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
file,
|data: &[u8]| {
let st = safetensors::SafeTensors::deserialize(data)
.map_err(|e| Error::from(e).with_path(p))?;
Ok::<_, Error>(SafeTensors_(st))
},
)?;
Ok(Self {
safetensors: vec![safetensors],
routing: None,
})
}
/// Creates a wrapper around multiple memory mapped file and deserialize the safetensors headers.
///
/// If a tensor name appears in multiple files, the last entry is returned.
///
/// # Safety
///
/// The unsafe is inherited from [`memmap2::MmapOptions`].
pub unsafe fn multi<P: AsRef<Path>>(paths: &[P]) -> Result<Self> {
let mut routing = HashMap::new();
let mut safetensors = vec![];
for (index, p) in paths.iter().enumerate() {
let p = p.as_ref();
let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
let file = memmap2::MmapOptions::new()
.map(&file)
.map_err(|e| Error::from(e).with_path(p))?;
let data = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
file,
|data: &[u8]| {
let st = safetensors::SafeTensors::deserialize(data)
.map_err(|e| Error::from(e).with_path(p))?;
Ok::<_, Error>(SafeTensors_(st))
},
)?;
for k in data.get().0.names() {
routing.insert(k.to_string(), index);
}
safetensors.push(data)
}
Ok(Self {
safetensors,
routing: Some(routing),
})
}
pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
self.get(name)?.load(dev)
}
pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
let mut tensors = vec![];
for safetensors in self.safetensors.iter() {
tensors.push(safetensors.get().0.tensors())
}
tensors.into_iter().flatten().collect()
}
pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
let index = match &self.routing {
None => 0,
Some(routing) => {
let index = routing.get(name).ok_or_else(|| {
Error::CannotFindTensor {
path: name.to_string(),
}
.bt()
})?;
*index
}
};
Ok(self.safetensors[index].get().0.tensor(name)?)
}
}
pub struct BufferedSafetensors {
safetensors: yoke::Yoke<SafeTensors_<'static>, Vec<u8>>,
}
impl BufferedSafetensors {
/// Creates a wrapper around a binary buffer and deserialize the safetensors header.
pub fn new(buffer: Vec<u8>) -> Result<Self> {
let safetensors = yoke::Yoke::<SafeTensors_<'static>, Vec<u8>>::try_attach_to_cart(
buffer,
|data: &[u8]| {
let st = safetensors::SafeTensors::deserialize(data)?;
Ok::<_, Error>(SafeTensors_(st))
},
)?;
Ok(Self { safetensors })
}
pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
self.get(name)?.load(dev)
}
pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
self.safetensors.get().0.tensors()
}
pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
Ok(self.safetensors.get().0.tensor(name)?)
}
}
pub struct MmapedFile { pub struct MmapedFile {
path: std::path::PathBuf, path: std::path::PathBuf,
inner: memmap2::Mmap, inner: memmap2::Mmap,
@ -391,7 +267,7 @@ impl MmapedFile {
/// # Safety /// # Safety
/// ///
/// The unsafe is inherited from [`memmap2::MmapOptions`]. /// The unsafe is inherited from [`memmap2::MmapOptions`].
pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> { pub unsafe fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
let p = p.as_ref(); let p = p.as_ref();
let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?; let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
let inner = memmap2::MmapOptions::new() let inner = memmap2::MmapOptions::new()

View File

@ -203,7 +203,7 @@ impl Shape {
/// Check whether the two shapes are compatible for broadcast, and if it is the case return the /// Check whether the two shapes are compatible for broadcast, and if it is the case return the
/// broadcasted shape. This is to be used for binary pointwise ops. /// broadcasted shape. This is to be used for binary pointwise ops.
pub fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> { pub(crate) fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
let lhs = self; let lhs = self;
let lhs_dims = lhs.dims(); let lhs_dims = lhs.dims();
let rhs_dims = rhs.dims(); let rhs_dims = rhs.dims();
@ -444,18 +444,6 @@ impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim> Dims for (D1, D2, D3, D4, D5)
} }
} }
impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim, D6: Dim> Dims for (D1, D2, D3, D4, D5, D6) {
fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
let d0 = self.0.to_index(shape, op)?;
let d1 = self.1.to_index(shape, op)?;
let d2 = self.2.to_index(shape, op)?;
let d3 = self.3.to_index(shape, op)?;
let d4 = self.4.to_index(shape, op)?;
let d5 = self.5.to_index(shape, op)?;
Ok(vec![d0, d1, d2, d3, d4, d5])
}
}
extract_dims!(dims0, 0, |_: &[usize]| (), ()); extract_dims!(dims0, 0, |_: &[usize]| (), ());
extract_dims!(dims1, 1, |d: &[usize]| d[0], usize); extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize)); extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
@ -511,119 +499,154 @@ impl ShapeWithOneHole for ((),) {
} }
} }
fn hole_size(el_count: usize, prod_d: usize, s: &dyn std::fmt::Debug) -> Result<usize> {
if prod_d == 0 {
crate::bail!("cannot reshape tensor of {el_count} elements to {s:?}")
}
if el_count % prod_d != 0 {
crate::bail!("cannot reshape tensor with {el_count} elements to {s:?}")
}
Ok(el_count / prod_d)
}
impl ShapeWithOneHole for ((), usize) { impl ShapeWithOneHole for ((), usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> { fn into_shape(self, el_count: usize) -> Result<Shape> {
let ((), d1) = self; let ((), d1) = self;
Ok((hole_size(el_count, d1, &self)?, d1).into()) if el_count % d1 != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
}
Ok((el_count / d1, d1).into())
} }
} }
impl ShapeWithOneHole for (usize, ()) { impl ShapeWithOneHole for (usize, ()) {
fn into_shape(self, el_count: usize) -> Result<Shape> { fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, ()) = self; let (d1, ()) = self;
Ok((d1, hole_size(el_count, d1, &self)?).into()) if el_count % d1 != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
}
Ok((d1, el_count / d1).into())
} }
} }
impl ShapeWithOneHole for ((), usize, usize) { impl ShapeWithOneHole for ((), usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> { fn into_shape(self, el_count: usize) -> Result<Shape> {
let ((), d1, d2) = self; let ((), d1, d2) = self;
Ok((hole_size(el_count, d1 * d2, &self)?, d1, d2).into()) let d = d1 * d2;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((el_count / d, d1, d2).into())
} }
} }
impl ShapeWithOneHole for (usize, (), usize) { impl ShapeWithOneHole for (usize, (), usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> { fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, (), d2) = self; let (d1, (), d2) = self;
Ok((d1, hole_size(el_count, d1 * d2, &self)?, d2).into()) let d = d1 * d2;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, el_count / d, d2).into())
} }
} }
impl ShapeWithOneHole for (usize, usize, ()) { impl ShapeWithOneHole for (usize, usize, ()) {
fn into_shape(self, el_count: usize) -> Result<Shape> { fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, ()) = self; let (d1, d2, ()) = self;
Ok((d1, d2, hole_size(el_count, d1 * d2, &self)?).into()) let d = d1 * d2;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, el_count / d).into())
} }
} }
impl ShapeWithOneHole for ((), usize, usize, usize) { impl ShapeWithOneHole for ((), usize, usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> { fn into_shape(self, el_count: usize) -> Result<Shape> {
let ((), d1, d2, d3) = self; let ((), d1, d2, d3) = self;
let d = hole_size(el_count, d1 * d2 * d3, &self)?; let d = d1 * d2 * d3;
Ok((d, d1, d2, d3).into()) if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((el_count / d, d1, d2, d3).into())
} }
} }
impl ShapeWithOneHole for (usize, (), usize, usize) { impl ShapeWithOneHole for (usize, (), usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> { fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, (), d2, d3) = self; let (d1, (), d2, d3) = self;
let d = hole_size(el_count, d1 * d2 * d3, &self)?; let d = d1 * d2 * d3;
Ok((d1, d, d2, d3).into()) if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, el_count / d, d2, d3).into())
} }
} }
impl ShapeWithOneHole for (usize, usize, (), usize) { impl ShapeWithOneHole for (usize, usize, (), usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> { fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, (), d3) = self; let (d1, d2, (), d3) = self;
let d = hole_size(el_count, d1 * d2 * d3, &self)?; let d = d1 * d2 * d3;
Ok((d1, d2, d, d3).into()) if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, el_count / d, d3).into())
} }
} }
impl ShapeWithOneHole for (usize, usize, usize, ()) { impl ShapeWithOneHole for (usize, usize, usize, ()) {
fn into_shape(self, el_count: usize) -> Result<Shape> { fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, d3, ()) = self; let (d1, d2, d3, ()) = self;
let d = hole_size(el_count, d1 * d2 * d3, &self)?; let d = d1 * d2 * d3;
Ok((d1, d2, d3, d).into()) if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, d3, el_count / d).into())
} }
} }
impl ShapeWithOneHole for ((), usize, usize, usize, usize) { impl ShapeWithOneHole for ((), usize, usize, usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> { fn into_shape(self, el_count: usize) -> Result<Shape> {
let ((), d1, d2, d3, d4) = self; let ((), d1, d2, d3, d4) = self;
let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?; let d = d1 * d2 * d3 * d4;
Ok((d, d1, d2, d3, d4).into()) if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((el_count / d, d1, d2, d3, d4).into())
} }
} }
impl ShapeWithOneHole for (usize, (), usize, usize, usize) { impl ShapeWithOneHole for (usize, (), usize, usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> { fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, (), d2, d3, d4) = self; let (d1, (), d2, d3, d4) = self;
let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?; let d = d1 * d2 * d3 * d4;
Ok((d1, d, d2, d3, d4).into()) if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, el_count / d, d2, d3, d4).into())
} }
} }
impl ShapeWithOneHole for (usize, usize, (), usize, usize) { impl ShapeWithOneHole for (usize, usize, (), usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> { fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, (), d3, d4) = self; let (d1, d2, (), d3, d4) = self;
let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?; let d = d1 * d2 * d3 * d4;
Ok((d1, d2, d, d3, d4).into()) if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, el_count / d, d3, d4).into())
} }
} }
impl ShapeWithOneHole for (usize, usize, usize, (), usize) { impl ShapeWithOneHole for (usize, usize, usize, (), usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> { fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, d3, (), d4) = self; let (d1, d2, d3, (), d4) = self;
let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?; let d = d1 * d2 * d3 * d4;
Ok((d1, d2, d3, d, d4).into()) if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, d3, el_count / d, d4).into())
} }
} }
impl ShapeWithOneHole for (usize, usize, usize, usize, ()) { impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
fn into_shape(self, el_count: usize) -> Result<Shape> { fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, d3, d4, ()) = self; let (d1, d2, d3, d4, ()) = self;
let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?; let d = d1 * d2 * d3 * d4;
Ok((d1, d2, d3, d4, d).into()) if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, d3, d4, el_count / d).into())
} }
} }

View File

@ -1,6 +1,6 @@
use crate::backend::BackendStorage; use crate::backend::BackendStorage;
use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp}; use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape}; use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
// We do not want to implement Clone on Storage as cloning may fail because of // We do not want to implement Clone on Storage as cloning may fail because of
// out of memory. Instead try_clone should be used. // out of memory. Instead try_clone should be used.
@ -8,7 +8,6 @@ use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage,
pub enum Storage { pub enum Storage {
Cpu(CpuStorage), Cpu(CpuStorage),
Cuda(CudaStorage), Cuda(CudaStorage),
Metal(MetalStorage),
} }
impl Storage { impl Storage {
@ -19,10 +18,6 @@ impl Storage {
let storage = storage.try_clone(layout)?; let storage = storage.try_clone(layout)?;
Ok(Self::Cuda(storage)) Ok(Self::Cuda(storage))
} }
Self::Metal(storage) => {
let storage = storage.try_clone(layout)?;
Ok(Self::Metal(storage))
}
} }
} }
@ -30,7 +25,6 @@ impl Storage {
match self { match self {
Self::Cpu(_) => Device::Cpu, Self::Cpu(_) => Device::Cpu,
Self::Cuda(storage) => Device::Cuda(storage.device().clone()), Self::Cuda(storage) => Device::Cuda(storage.device().clone()),
Self::Metal(storage) => Device::Metal(storage.device().clone()),
} }
} }
@ -38,7 +32,6 @@ impl Storage {
match self { match self {
Self::Cpu(storage) => storage.dtype(), Self::Cpu(storage) => storage.dtype(),
Self::Cuda(storage) => storage.dtype(), Self::Cuda(storage) => storage.dtype(),
Self::Metal(storage) => storage.dtype(),
} }
} }
@ -72,10 +65,6 @@ impl Storage {
let storage = storage.affine(layout, mul, add)?; let storage = storage.affine(layout, mul, add)?;
Ok(Self::Cuda(storage)) Ok(Self::Cuda(storage))
} }
Self::Metal(storage) => {
let storage = storage.affine(layout, mul, add)?;
Ok(Self::Metal(storage))
}
} }
} }
@ -89,10 +78,6 @@ impl Storage {
let storage = storage.powf(layout, alpha)?; let storage = storage.powf(layout, alpha)?;
Ok(Self::Cuda(storage)) Ok(Self::Cuda(storage))
} }
Self::Metal(storage) => {
let storage = storage.powf(layout, alpha)?;
Ok(Self::Metal(storage))
}
} }
} }
@ -106,10 +91,6 @@ impl Storage {
let storage = storage.elu(layout, alpha)?; let storage = storage.elu(layout, alpha)?;
Ok(Self::Cuda(storage)) Ok(Self::Cuda(storage))
} }
Self::Metal(storage) => {
let storage = storage.elu(layout, alpha)?;
Ok(Self::Metal(storage))
}
} }
} }
@ -131,10 +112,6 @@ impl Storage {
let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?; let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
Ok(Self::Cuda(storage)) Ok(Self::Cuda(storage))
} }
(Self::Metal(lhs), Self::Metal(rhs)) => {
let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
Ok(Self::Metal(storage))
}
(lhs, rhs) => { (lhs, rhs) => {
// Should not happen because of the same device check above but we're defensive // Should not happen because of the same device check above but we're defensive
// anyway. // anyway.
@ -158,10 +135,6 @@ impl Storage {
let storage = storage.reduce_op(op, layout, s)?; let storage = storage.reduce_op(op, layout, s)?;
Ok(Self::Cuda(storage)) Ok(Self::Cuda(storage))
} }
Self::Metal(storage) => {
let storage = storage.reduce_op(op, layout, s)?;
Ok(Self::Metal(storage))
}
} }
} }
@ -175,10 +148,6 @@ impl Storage {
let storage = storage.to_dtype(layout, dtype)?; let storage = storage.to_dtype(layout, dtype)?;
Ok(Self::Cuda(storage)) Ok(Self::Cuda(storage))
} }
Self::Metal(storage) => {
let storage = storage.to_dtype(layout, dtype)?;
Ok(Self::Metal(storage))
}
} }
} }
@ -192,10 +161,6 @@ impl Storage {
let (storage, shape) = c.cuda_fwd(storage, l)?; let (storage, shape) = c.cuda_fwd(storage, l)?;
Ok((Self::Cuda(storage), shape)) Ok((Self::Cuda(storage), shape))
} }
Self::Metal(storage) => {
let (storage, shape) = c.metal_fwd(storage, l)?;
Ok((Self::Metal(storage), shape))
}
} }
} }
@ -216,10 +181,6 @@ impl Storage {
let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?; let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?;
Ok((Self::Cuda(s), shape)) Ok((Self::Cuda(s), shape))
} }
(Self::Metal(s1), Self::Metal(s2)) => {
let (s, shape) = c.metal_fwd(s1, l1, s2, l2)?;
Ok((Self::Metal(s), shape))
}
_ => unreachable!(), _ => unreachable!(),
} }
} }
@ -244,10 +205,6 @@ impl Storage {
let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?; let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?;
Ok((Self::Cuda(s), shape)) Ok((Self::Cuda(s), shape))
} }
(Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
let (s, shape) = c.metal_fwd(s1, l1, s2, l2, s3, l3)?;
Ok((Self::Metal(s), shape))
}
_ => unreachable!(), _ => unreachable!(),
} }
} }
@ -262,10 +219,6 @@ impl Storage {
let storage = storage.unary_impl::<B>(layout)?; let storage = storage.unary_impl::<B>(layout)?;
Ok(Self::Cuda(storage)) Ok(Self::Cuda(storage))
} }
Self::Metal(storage) => {
let storage = storage.unary_impl::<B>(layout)?;
Ok(Self::Metal(storage))
}
} }
} }
@ -286,10 +239,6 @@ impl Storage {
let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?; let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
Ok(Self::Cuda(storage)) Ok(Self::Cuda(storage))
} }
(Self::Metal(lhs), Self::Metal(rhs)) => {
let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
Ok(Self::Metal(storage))
}
(lhs, rhs) => { (lhs, rhs) => {
// Should not happen because of the same device check above but we're defensive // Should not happen because of the same device check above but we're defensive
// anyway. // anyway.
@ -321,10 +270,6 @@ impl Storage {
let s = inp.conv1d(l, kernel, kernel_l, params)?; let s = inp.conv1d(l, kernel, kernel_l, params)?;
Ok(Self::Cuda(s)) Ok(Self::Cuda(s))
} }
(Storage::Metal(inp), Storage::Metal(kernel)) => {
let s = inp.conv1d(l, kernel, kernel_l, params)?;
Ok(Self::Metal(s))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(), lhs: lhs.device().location(),
rhs: rhs.device().location(), rhs: rhs.device().location(),
@ -334,33 +279,6 @@ impl Storage {
} }
} }
pub(crate) fn conv_transpose1d(
&self,
l: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
self.same_device(kernel, "conv-transpose1d")?;
self.same_dtype(kernel, "conv-transpose1d")?;
match (self, &kernel) {
(Storage::Cpu(inp), Storage::Cpu(kernel)) => {
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
Ok(Self::Cpu(s))
}
(Storage::Cuda(inp), Storage::Cuda(kernel)) => {
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
Ok(Self::Cuda(s))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
op: "conv-transpose1d",
}
.bt()),
}
}
pub(crate) fn conv2d( pub(crate) fn conv2d(
&self, &self,
l: &Layout, l: &Layout,
@ -379,10 +297,6 @@ impl Storage {
let s = inp.conv2d(l, kernel, kernel_l, params)?; let s = inp.conv2d(l, kernel, kernel_l, params)?;
Ok(Self::Cuda(s)) Ok(Self::Cuda(s))
} }
(Storage::Metal(inp), Storage::Metal(kernel)) => {
let s = inp.conv2d(l, kernel, kernel_l, params)?;
Ok(Self::Metal(s))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(), lhs: lhs.device().location(),
rhs: rhs.device().location(), rhs: rhs.device().location(),
@ -410,10 +324,6 @@ impl Storage {
let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?; let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
Ok(Self::Cuda(s)) Ok(Self::Cuda(s))
} }
(Storage::Metal(inp), Storage::Metal(kernel)) => {
let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
Ok(Self::Metal(s))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(), lhs: lhs.device().location(),
rhs: rhs.device().location(), rhs: rhs.device().location(),
@ -438,10 +348,6 @@ impl Storage {
let storage = storage.avg_pool2d(layout, kernel_size, stride)?; let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
Ok(Self::Cuda(storage)) Ok(Self::Cuda(storage))
} }
Self::Metal(storage) => {
let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
Ok(Self::Metal(storage))
}
} }
} }
@ -460,27 +366,6 @@ impl Storage {
let storage = storage.max_pool2d(layout, kernel_size, stride)?; let storage = storage.max_pool2d(layout, kernel_size, stride)?;
Ok(Self::Cuda(storage)) Ok(Self::Cuda(storage))
} }
Self::Metal(storage) => {
let storage = storage.max_pool2d(layout, kernel_size, stride)?;
Ok(Self::Metal(storage))
}
}
}
pub(crate) fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
let storage = storage.upsample_nearest1d(layout, sz)?;
Ok(Self::Cpu(storage))
}
Self::Cuda(storage) => {
let storage = storage.upsample_nearest1d(layout, sz)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.upsample_nearest1d(layout, sz)?;
Ok(Self::Metal(storage))
}
} }
} }
@ -494,10 +379,6 @@ impl Storage {
let storage = storage.upsample_nearest2d(layout, h, w)?; let storage = storage.upsample_nearest2d(layout, h, w)?;
Ok(Self::Cuda(storage)) Ok(Self::Cuda(storage))
} }
Self::Metal(storage) => {
let storage = storage.upsample_nearest2d(layout, h, w)?;
Ok(Self::Metal(storage))
}
} }
} }
@ -521,10 +402,6 @@ impl Storage {
let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?; let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
Ok(Self::Cuda(storage)) Ok(Self::Cuda(storage))
} }
(Self::Metal(cond), Self::Metal(t), Self::Metal(f)) => {
let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
Ok(Self::Metal(storage))
}
(_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { (_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(), lhs: lhs.device().location(),
rhs: rhs.device().location(), rhs: rhs.device().location(),
@ -551,10 +428,6 @@ impl Storage {
let storage = s.gather(l, indexes, indexes_l, d)?; let storage = s.gather(l, indexes, indexes_l, d)?;
Ok(Self::Cuda(storage)) Ok(Self::Cuda(storage))
} }
(Self::Metal(s), Self::Metal(indexes)) => {
let storage = s.gather(l, indexes, indexes_l, d)?;
Ok(Self::Metal(storage))
}
_ => unreachable!(), _ => unreachable!(),
} }
} }
@ -579,10 +452,6 @@ impl Storage {
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?; let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Cuda(storage)) Ok(Self::Cuda(storage))
} }
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Metal(storage))
}
_ => unreachable!(), _ => unreachable!(),
} }
} }
@ -607,10 +476,6 @@ impl Storage {
let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?; let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Cuda(storage)) Ok(Self::Cuda(storage))
} }
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Metal(storage))
}
_ => unreachable!(), _ => unreachable!(),
} }
} }
@ -632,10 +497,6 @@ impl Storage {
let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?; let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
Ok(Self::Cuda(storage)) Ok(Self::Cuda(storage))
} }
(Self::Metal(lhs), Self::Metal(rhs)) => {
let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
Ok(Self::Metal(storage))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(), lhs: lhs.device().location(),
rhs: rhs.device().location(), rhs: rhs.device().location(),
@ -663,10 +524,6 @@ impl Storage {
let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?; let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
Ok(Self::Cuda(storage)) Ok(Self::Cuda(storage))
} }
(Self::Metal(lhs), Self::Metal(rhs)) => {
let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
Ok(Self::Metal(storage))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(), lhs: lhs.device().location(),
rhs: rhs.device().location(), rhs: rhs.device().location(),
@ -686,9 +543,6 @@ impl Storage {
match (self, dst) { match (self, dst) {
(Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l), (Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l),
(Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?), (Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
(Self::Metal(src), Self::Metal(dst)) => {
Ok(src.copy_strided_src(dst, dst_offset, src_l)?)
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(), lhs: lhs.device().location(),
rhs: rhs.device().location(), rhs: rhs.device().location(),

View File

@ -6,7 +6,7 @@ use crate::op::{
}; };
use crate::scalar::TensorOrScalar; use crate::scalar::TensorOrScalar;
use crate::shape::{Dim, Dims}; use crate::shape::{Dim, Dims};
use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape}; use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
/// Unique identifier for tensors. /// Unique identifier for tensors.
@ -105,28 +105,6 @@ macro_rules! binary_op {
}; };
} }
macro_rules! binary_op_scalar {
($fn_name:ident, $op_name:ident) => {
pub fn $fn_name<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
let rhs = match rhs.to_tensor_scalar()? {
crate::scalar::TensorScalar::Tensor(rhs) => rhs,
crate::scalar::TensorScalar::Scalar(rhs) => rhs
.to_dtype(self.dtype())?
.to_device(self.device())?
.broadcast_as(self.shape())?,
};
let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?;
let storage = self.storage().binary_impl::<crate::op::$op_name>(
&*rhs.storage(),
self.layout(),
rhs.layout(),
)?;
let op = BackpropOp::new2(self, &rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name));
Ok(from_storage(storage, shape.clone(), op, false))
}
};
}
macro_rules! broadcast_binary_op { macro_rules! broadcast_binary_op {
($fn_name:ident, $inner_fn_name:ident) => { ($fn_name:ident, $inner_fn_name:ident) => {
pub fn $fn_name(&self, rhs: &Self) -> Result<Self> { pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
@ -177,9 +155,14 @@ impl Tensor {
is_variable: bool, is_variable: bool,
) -> Result<Self> { ) -> Result<Self> {
let none = BackpropOp::none(); let none = BackpropOp::none();
if is_variable {
let shape = shape.into(); let shape = shape.into();
let storage = device.ones(&shape, dtype)?; let storage = device.ones(&shape, dtype)?;
Ok(from_storage(storage, shape, none, is_variable)) Ok(from_storage(storage, shape, none, is_variable))
} else {
let storage = device.ones(&crate::shape::SCALAR, dtype)?;
from_storage(storage, crate::shape::SCALAR, none, is_variable).broadcast_as(shape)
}
} }
/// Creates a new tensor filled with ones. /// Creates a new tensor filled with ones.
@ -217,9 +200,14 @@ impl Tensor {
is_variable: bool, is_variable: bool,
) -> Result<Self> { ) -> Result<Self> {
let none = BackpropOp::none(); let none = BackpropOp::none();
if is_variable {
let shape = shape.into(); let shape = shape.into();
let storage = device.zeros(&shape, dtype)?; let storage = device.zeros(&shape, dtype)?;
Ok(from_storage(storage, shape, none, is_variable)) Ok(from_storage(storage, shape, none, is_variable))
} else {
let storage = device.zeros(&crate::shape::SCALAR, dtype)?;
from_storage(storage, crate::shape::SCALAR, none, is_variable).broadcast_as(shape)
}
} }
/// Creates a new tensor filled with zeros. /// Creates a new tensor filled with zeros.
@ -385,22 +373,12 @@ impl Tensor {
step: D, step: D,
device: &Device, device: &Device,
) -> Result<Self> { ) -> Result<Self> {
if D::is_zero(&step) {
crate::bail!("step cannot be zero")
}
let mut data = vec![]; let mut data = vec![];
let mut current = start; let mut current = start;
if step >= D::zero() {
while current < end { while current < end {
data.push(current); data.push(current);
current += step; current += step;
} }
} else {
while current > end {
data.push(current);
current += step;
}
}
let len = data.len(); let len = data.len();
Self::from_vec_impl(data, len, device, false) Self::from_vec_impl(data, len, device, false)
} }
@ -459,7 +437,7 @@ impl Tensor {
/// Returns true if the computation graph should track this op, that is if it is /// Returns true if the computation graph should track this op, that is if it is
/// a variable or if it has some variable as dependencies. /// a variable or if it has some variable as dependencies.
pub fn track_op(&self) -> bool { pub(crate) fn track_op(&self) -> bool {
self.is_variable || self.op.is_some() self.is_variable || self.op.is_some()
} }
@ -469,20 +447,14 @@ impl Tensor {
binary_op!(mul, Mul); binary_op!(mul, Mul);
binary_op!(sub, Sub); binary_op!(sub, Sub);
binary_op!(div, Div); binary_op!(div, Div);
binary_op_scalar!(maximum, Maximum); binary_op!(maximum, Maximum);
binary_op_scalar!(minimum, Minimum); binary_op!(minimum, Minimum);
broadcast_binary_op!(broadcast_add, add); broadcast_binary_op!(broadcast_add, add);
broadcast_binary_op!(broadcast_mul, mul); broadcast_binary_op!(broadcast_mul, mul);
broadcast_binary_op!(broadcast_sub, sub); broadcast_binary_op!(broadcast_sub, sub);
broadcast_binary_op!(broadcast_div, div); broadcast_binary_op!(broadcast_div, div);
broadcast_binary_op!(broadcast_maximum, maximum); broadcast_binary_op!(broadcast_maximum, maximum);
broadcast_binary_op!(broadcast_minimum, minimum); broadcast_binary_op!(broadcast_minimum, minimum);
broadcast_binary_op!(broadcast_eq, eq);
broadcast_binary_op!(broadcast_ne, ne);
broadcast_binary_op!(broadcast_lt, lt);
broadcast_binary_op!(broadcast_le, le);
broadcast_binary_op!(broadcast_gt, gt);
broadcast_binary_op!(broadcast_ge, ge);
unary_op!(recip, Recip); unary_op!(recip, Recip);
unary_op!(neg, Neg); unary_op!(neg, Neg);
@ -495,21 +467,7 @@ impl Tensor {
unary_op!(sqr, Sqr); unary_op!(sqr, Sqr);
unary_op!(sqrt, Sqrt); unary_op!(sqrt, Sqrt);
unary_op!(gelu, Gelu); unary_op!(gelu, Gelu);
unary_op!(gelu_erf, GeluErf);
unary_op!(erf, Erf);
unary_op!(relu, Relu); unary_op!(relu, Relu);
unary_op!(ceil, Ceil);
unary_op!(floor, Floor);
unary_op!(round, Round);
/// Round element of the input tensor to the nearest integer.
///
/// If the number of decimals is negative, it specifies the number of positions to the left of
/// the decimal point.
pub fn round_to(&self, decimals: i32) -> Result<Self> {
let mult = 10f64.powi(decimals);
(self * mult)?.round()? * (1f64 / mult)
}
/// Retrieves the single scalar value hold in the tensor. If the tensor contains multiple /// Retrieves the single scalar value hold in the tensor. If the tensor contains multiple
/// dimensions, an error is returned instead. /// dimensions, an error is returned instead.
@ -529,7 +487,6 @@ impl Tensor {
match &*self.storage() { match &*self.storage() {
Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage), Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?), Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
} }
} }
@ -557,73 +514,6 @@ impl Tensor {
Ok(inp) Ok(inp)
} }
/// Creates grids of coordinates specified by the 1D inputs.
///
/// # Arguments
///
/// * `args` - A slice of 1D tensors.
/// * `xy_indexing` - Whether to use xy indexing or ij indexing. If xy is selected, the
/// first dimension corresponds to the cardinality of the second input and the second
/// dimension corresponds to the cardinality of the first input. If ij is selected, the
/// dimensions are in the same order as the cardinality of the inputs.
///
/// # Examples
///
/// ```rust
/// use candle_core::{Tensor, Device, Shape};
/// let x = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
/// let y = Tensor::new(&[4f32, 5., 6.], &Device::Cpu)?;
///
/// let grids_xy = Tensor::meshgrid(&[&x, &y], true)?;
///
/// assert_eq!(grids_xy.len(), 2);
/// assert_eq!(grids_xy[0].dims(), &[3, 3]);
///
/// assert_eq!(grids_xy[0].to_vec2::<f32>()?, &[[1., 2., 3.], [1., 2., 3.], [1., 2., 3.]]);
/// assert_eq!(grids_xy[1].to_vec2::<f32>()?, &[[4., 4., 4.], [5., 5., 5.], [6., 6., 6.]]);
///
/// let grids_ij = Tensor::meshgrid(&[&x, &y], false)?;
///
/// assert_eq!(grids_ij[0].to_vec2::<f32>()?, &[[1., 1., 1.], [2., 2., 2.], [3., 3., 3.]]);
/// assert_eq!(grids_ij[1].to_vec2::<f32>()?, &[[4., 5., 6.], [4., 5., 6.], [4., 5., 6.]]);
/// # Ok::<(), candle_core::Error>(())
/// ```
///
/// # Errors
///
/// * Will return `Err` if `args` contains less than 2 tensors.
///
pub fn meshgrid<A: AsRef<Tensor>>(args: &[A], xy_indexing: bool) -> Result<Vec<Self>> {
if args.len() <= 1 {
Err(Error::OpRequiresAtLeastTwoTensors { op: "meshgrid" }.bt())?
}
let args: Vec<_> = if xy_indexing {
args.iter().rev().collect()
} else {
args.iter().collect()
};
let mut shape = Vec::with_capacity(args.len());
for arg in args.iter() {
shape.push(arg.as_ref().dims1()?)
}
let mut grids = Vec::with_capacity(args.len());
for idx in 0..args.len() {
let mut ones = vec![1usize; args.len()];
ones[idx] = shape[idx];
let arg = args[idx].as_ref().reshape(ones)?;
let mut repeats = shape.clone();
repeats[idx] = 1;
let repeated_tensor = arg.repeat(repeats)?;
grids.push(repeated_tensor);
}
if xy_indexing {
grids.reverse();
}
Ok(grids)
}
/// This operation multiplies the input tensor by `mul` then adds `add` and return the result. /// This operation multiplies the input tensor by `mul` then adds `add` and return the result.
/// The input values `mul` and `add` are casted to the appropriate type so some rounding might /// The input values `mul` and `add` are casted to the appropriate type so some rounding might
/// be performed. /// be performed.
@ -699,23 +589,15 @@ impl Tensor {
pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> { pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
let dims = self.dims(); let dims = self.dims();
let dim = dim.to_index(self.shape(), "narrow")?; let dim = dim.to_index(self.shape(), "narrow")?;
let err = |msg| { if start + len > dims[dim] {
Err::<(), _>( Err(Error::NarrowInvalidArgs {
Error::NarrowInvalidArgs {
shape: self.shape().clone(), shape: self.shape().clone(),
dim, dim,
start, start,
len, len,
msg, msg: "start + len > dim_len",
} }
.bt(), .bt())?
)
};
if start > dims[dim] {
err("start > dim_len")?
}
if start.saturating_add(len) > dims[dim] {
err("start + len > dim_len")?
} }
if start == 0 && dims[dim] == len { if start == 0 && dims[dim] == len {
Ok(self.clone()) Ok(self.clone())
@ -762,12 +644,7 @@ impl Tensor {
let storage = self.storage().reduce_op(op, self.layout(), &[dim])?; let storage = self.storage().reduce_op(op, self.layout(), &[dim])?;
let mut dims = self.dims().to_vec(); let mut dims = self.dims().to_vec();
dims[dim] = 1; dims[dim] = 1;
let op = match op { let op = BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()));
ReduceOp::Sum | ReduceOp::Min | ReduceOp::Max => {
BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()))
}
ReduceOp::ArgMin | ReduceOp::ArgMax => BackpropOp::none(),
};
let res = from_storage(storage, dims, op, false); let res = from_storage(storage, dims, op, false);
if keepdim { if keepdim {
Ok(res) Ok(res)
@ -856,20 +733,6 @@ impl Tensor {
self.sum_impl(mean_dims, false)? * scale self.sum_impl(mean_dims, false)? * scale
} }
/// Returns the unbiased variance over the selected dimension.
pub fn var_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "var")?;
let mean = self.mean_keepdim(dim)?;
let squares = self.broadcast_sub(&mean)?.sqr()?;
squares.sum_impl(dim, true)? / (self.dim(dim)? - 1) as f64
}
/// Returns the unbiased variance over the selected dimension.
pub fn var<D: Dim>(&self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "var")?;
self.var_keepdim(dim)?.squeeze(dim)
}
/// Gathers the maximum value across the selected dimension. The resulting shape has the same /// Gathers the maximum value across the selected dimension. The resulting shape has the same
/// number of dimensions as the original tensor and the select dimension has a single element. /// number of dimensions as the original tensor and the select dimension has a single element.
pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> { pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
@ -964,35 +827,12 @@ impl Tensor {
self.cmp(rhs, CmpOp::Le) self.cmp(rhs, CmpOp::Le)
} }
/// Clamp the tensor values to be between `min` and `max`. /// Upsample the input tensor to the `(target_h, target_w)` size, taking the value of the
pub fn clamp<T1: TensorOrScalar, T2: TensorOrScalar>(&self, min: T1, max: T2) -> Result<Self> {
self.maximum(min)?.minimum(max)
}
/// Interpolate the input tensor to the `target_size` size, taking the value of the nearest element.
///
/// The input tensor should have three dimensions, `(batch, channels, l)`, the returned
/// tensor also has three dimensions, `(batch, channels, target_size)`.
pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
let (n, c, _l) = self.dims3()?;
let op = BackpropOp::new1(self, Op::UpsampleNearest1D);
let storage = self
.storage()
.upsample_nearest1d(self.layout(), target_size)?;
Ok(from_storage(storage, (n, c, target_size), op, false))
}
/// Alias for `interpolate1d`.
pub fn upsample_nearest1d(&self, target_size: usize) -> Result<Self> {
self.interpolate1d(target_size)
}
/// Interpolate the input tensor to the `(target_h, target_w)` size, taking the value of the
/// nearest element. /// nearest element.
/// ///
/// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned /// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
/// tensor also has four dimensions, `(batch, channels, target_h, target_w)`. /// tensor also has four dimensions, `(batch, channels, target_h, target_w)`.
pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> { pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
let (n, c, _h, _w) = self.dims4()?; let (n, c, _h, _w) = self.dims4()?;
let op = BackpropOp::new1(self, Op::UpsampleNearest2D); let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
let storage = self let storage = self
@ -1001,11 +841,6 @@ impl Tensor {
Ok(from_storage(storage, (n, c, target_h, target_w), op, false)) Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
} }
/// Alias for `interpolate2d`.
pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
self.interpolate2d(target_h, target_w)
}
/// 2D average pooling over an input tensor with multiple channels. /// 2D average pooling over an input tensor with multiple channels.
/// ///
/// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned /// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
@ -1217,16 +1052,14 @@ impl Tensor {
op: "scatter-add (self, src)", op: "scatter-add (self, src)",
lhs: self.shape().clone(), lhs: self.shape().clone(),
rhs: source.shape().clone(), rhs: source.shape().clone(),
} })?
.bt())?
} }
if indexes.dims() != source.dims() { if indexes.dims() != source.dims() {
Err(Error::ShapeMismatchBinaryOp { Err(Error::ShapeMismatchBinaryOp {
op: "scatter-add (indexes, src)", op: "scatter-add (indexes, src)",
lhs: indexes.shape().clone(), lhs: indexes.shape().clone(),
rhs: source.shape().clone(), rhs: source.shape().clone(),
} })?
.bt())?
} }
let storage = self.storage().scatter_add( let storage = self.storage().scatter_add(
self.layout(), self.layout(),
@ -1242,75 +1075,6 @@ impl Tensor {
Ok(from_storage(storage, self.shape(), op, false)) Ok(from_storage(storage, self.shape(), op, false))
} }
/// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension.
pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: D, start: usize) -> Result<Self> {
let dim = dim.to_index(self.shape(), "slice-scatter")?;
if dim == 0 {
self.slice_scatter0(src, start)
} else {
// TODO: Maybe we want to add a more efficient implementation at some point.
self.transpose(0, dim)?
.slice_scatter0(&src.transpose(0, dim)?, start)?
.transpose(0, dim)
}
}
/// Embeds the values of the `src` tensor into the `self` tensor on the first dimension.
pub fn slice_scatter0(&self, src: &Self, start: usize) -> Result<Self> {
if self.dtype() != src.dtype() {
Err(Error::DTypeMismatchBinaryOp {
lhs: self.dtype(),
rhs: src.dtype(),
op: "slice-scatter",
}
.bt())?
}
if self.device().location() != src.device.location() {
Err(Error::DeviceMismatchBinaryOp {
lhs: self.device().location(),
rhs: src.device().location(),
op: "slice-scatter",
}
.bt())?
}
if self.rank() != src.rank() {
Err(Error::UnexpectedNumberOfDims {
expected: self.rank(),
got: src.rank(),
shape: src.shape().clone(),
}
.bt())?
}
let shape_ok =
self.dims()
.iter()
.zip(src.dims().iter())
.enumerate()
.all(|(dim_idx, (&d1, &d2))| {
if 0 == dim_idx {
d2 + start <= d1
} else {
d1 == d2
}
});
if !shape_ok {
Err(Error::ShapeMismatchBinaryOp {
op: "slice-scatter (self, src)",
lhs: self.shape().clone(),
rhs: src.shape().clone(),
}
.bt())?
}
let mut storage = self.device().zeros(self.shape(), self.dtype())?;
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
let offset = start * src.dims()[1..].iter().product::<usize>();
src.storage()
.copy_strided_src(&mut storage, offset, src.layout())?;
let op = BackpropOp::new2(self, src, |t1, t2| Op::SliceScatter0(t1, t2, start));
Ok(from_storage(storage, self.shape(), op, false))
}
/// Accumulate element from `source` at indexes `indexes` and add them to `self`. /// Accumulate element from `source` at indexes `indexes` and add them to `self`.
pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> { pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "index-add")?; let dim = dim.to_index(self.shape(), "index-add")?;
@ -1333,8 +1097,7 @@ impl Tensor {
op: "index-add (self, source)", op: "index-add (self, source)",
lhs: self.shape().clone(), lhs: self.shape().clone(),
rhs: source.shape().clone(), rhs: source.shape().clone(),
} })?
.bt())?
} }
// The number of element in indexes must match the dimension on which the add is // The number of element in indexes must match the dimension on which the add is
// performed on the source tensor (and the index values from `indexes` are taken from // performed on the source tensor (and the index values from `indexes` are taken from
@ -1345,8 +1108,7 @@ impl Tensor {
op: "index-add (ids, source))", op: "index-add (ids, source))",
lhs: indexes.shape().clone(), lhs: indexes.shape().clone(),
rhs: source.shape().clone(), rhs: source.shape().clone(),
} })?
.bt())?
} }
let storage = self.storage().index_add( let storage = self.storage().index_add(
self.layout(), self.layout(),
@ -1394,8 +1156,7 @@ impl Tensor {
op: "gather", op: "gather",
lhs: self.shape().clone(), lhs: self.shape().clone(),
rhs: indexes.shape().clone(), rhs: indexes.shape().clone(),
} })?
.bt())?
} }
let storage = let storage =
self.storage() self.storage()
@ -1469,7 +1230,6 @@ impl Tensor {
match &*self.storage() { match &*self.storage() {
Storage::Cpu(storage) => from_cpu_storage(storage), Storage::Cpu(storage) => from_cpu_storage(storage),
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?), Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
} }
} }
@ -1500,7 +1260,6 @@ impl Tensor {
match &*self.storage() { match &*self.storage() {
Storage::Cpu(storage) => from_cpu_storage(storage), Storage::Cpu(storage) => from_cpu_storage(storage),
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?), Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
} }
} }
@ -1541,7 +1300,6 @@ impl Tensor {
match &*self.storage() { match &*self.storage() {
Storage::Cpu(storage) => from_cpu_storage(storage), Storage::Cpu(storage) => from_cpu_storage(storage),
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?), Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
} }
} }
@ -1705,24 +1463,6 @@ impl Tensor {
} }
} }
/// Returns the sub-tensor fixing the index at `index` on the dimension `dim`.
///
/// ```rust
/// use candle_core::{Tensor, Device};
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
/// let t = tensor.get_on_dim(1, 0)?;
/// assert_eq!(t.to_vec1::<f32>()?, &[0., 2., 4.]);
/// let t = tensor.get_on_dim(1, 1)?;
/// assert_eq!(t.to_vec1::<f32>()?, &[1., 3., 5.]);
/// let t = tensor.get_on_dim(0, 1)?;
/// assert_eq!(t.to_vec1::<f32>()?, &[2., 3.]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn get_on_dim<D: Dim>(&self, dim: D, index: usize) -> Result<Tensor> {
let dim = dim.to_index(self.shape(), "get_on_dim")?;
self.narrow(dim, index, 1)?.squeeze(dim)
}
/// Returns a tensor that is a transposed version of the input, the two last dimensions of the /// Returns a tensor that is a transposed version of the input, the two last dimensions of the
/// input are swapped. /// input are swapped.
/// ///
@ -1751,9 +1491,6 @@ impl Tensor {
pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor> { pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor> {
let dim1 = dim1.to_index(self.shape(), "transpose")?; let dim1 = dim1.to_index(self.shape(), "transpose")?;
let dim2 = dim2.to_index(self.shape(), "transpose")?; let dim2 = dim2.to_index(self.shape(), "transpose")?;
if dim1 == dim2 {
return Ok(self.clone());
}
let op = BackpropOp::new1(self, |t| Op::Transpose(t, dim1, dim2)); let op = BackpropOp::new1(self, |t| Op::Transpose(t, dim1, dim2));
let tensor_ = Tensor_ { let tensor_ = Tensor_ {
id: TensorId::new(), id: TensorId::new(),
@ -1831,12 +1568,7 @@ impl Tensor {
/// Returns a new tensor detached from the current graph, gradient are not propagated through /// Returns a new tensor detached from the current graph, gradient are not propagated through
/// this new node. The storage of this tensor is shared with the initial tensor. /// this new node. The storage of this tensor is shared with the initial tensor.
///
/// If the tensor is already detached from the computation graph, the same tensor is returned.
pub fn detach(&self) -> Result<Tensor> { pub fn detach(&self) -> Result<Tensor> {
if self.op.is_none() && !self.is_variable {
Ok(self.clone())
} else {
let tensor_ = Tensor_ { let tensor_ = Tensor_ {
id: TensorId::new(), id: TensorId::new(),
storage: self.storage.clone(), storage: self.storage.clone(),
@ -1848,7 +1580,6 @@ impl Tensor {
}; };
Ok(Tensor(Arc::new(tensor_))) Ok(Tensor(Arc::new(tensor_)))
} }
}
/// If the target device is the same as the tensor device, only a shallow copy is performed. /// If the target device is the same as the tensor device, only a shallow copy is performed.
pub fn to_device(&self, device: &Device) -> Result<Tensor> { pub fn to_device(&self, device: &Device) -> Result<Tensor> {
@ -1859,11 +1590,7 @@ impl Tensor {
(Storage::Cpu(storage), Device::Cuda(cuda)) => { (Storage::Cpu(storage), Device::Cuda(cuda)) => {
Storage::Cuda(cuda.storage_from_cpu_storage(storage)?) Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
} }
(Storage::Cpu(storage), Device::Metal(metal)) => {
Storage::Metal(metal.storage_from_cpu_storage(storage)?)
}
(Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?), (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
(Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
(Storage::Cuda(storage), Device::Cuda(cuda)) => { (Storage::Cuda(storage), Device::Cuda(cuda)) => {
// TODO: Avoid passing through the cpu storage here, especially if the gpu ids // TODO: Avoid passing through the cpu storage here, especially if the gpu ids
// are the same. // are the same.
@ -1871,9 +1598,6 @@ impl Tensor {
Storage::Cuda(cuda.storage_from_cpu_storage(&cpu_storage)?) Storage::Cuda(cuda.storage_from_cpu_storage(&cpu_storage)?)
} }
(Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()), (Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
_ => {
bail!("not implemented yet")
}
}; };
let op = BackpropOp::new1(self, Op::ToDevice); let op = BackpropOp::new1(self, Op::ToDevice);
let tensor_ = Tensor_ { let tensor_ = Tensor_ {
@ -2128,34 +1852,6 @@ impl Tensor {
for arg in args { for arg in args {
arg.as_ref().check_dim(dim, "cat")?; arg.as_ref().check_dim(dim, "cat")?;
} }
for (arg_idx, arg) in args.iter().enumerate() {
let arg = arg.as_ref();
if arg0.rank() != arg.rank() {
Err(Error::UnexpectedNumberOfDims {
expected: arg0.rank(),
got: arg.rank(),
shape: arg.shape().clone(),
}
.bt())?
}
for (dim_idx, (v1, v2)) in arg0
.shape()
.dims()
.iter()
.zip(arg.shape().dims().iter())
.enumerate()
{
if dim_idx != dim && v1 != v2 {
Err(Error::ShapeMismatchCat {
dim: dim_idx,
first_shape: arg0.shape().clone(),
n: arg_idx + 1,
nth_shape: arg.shape().clone(),
}
.bt())?
}
}
}
if dim == 0 { if dim == 0 {
Self::cat0(args) Self::cat0(args)
} else { } else {
@ -2273,56 +1969,11 @@ impl Tensor {
} }
} }
/// Pad the input tensor using same values along dimension `dim`. This adds `left` elements before the
/// input tensor values and `right` elements after.
pub fn pad_with_same<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
if left == 0 && right == 0 {
Ok(self.clone())
} else if self.elem_count() == 0 {
crate::bail!("cannot use pad_with_same on an empty tensor")
} else if left == 0 {
let dim = dim.to_index(self.shape(), "pad_with_same")?;
let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
let mut v = vec![self];
for _ in 0..right {
v.push(&r)
}
Tensor::cat(&v, dim)
} else if right == 0 {
let dim = dim.to_index(self.shape(), "pad_with_same")?;
let l = self.narrow(dim, 0, 1)?;
let mut v = vec![];
for _ in 0..left {
v.push(&l)
}
v.push(self);
Tensor::cat(&v, dim)
} else {
let dim = dim.to_index(self.shape(), "pad_with_same")?;
let l = self.narrow(dim, 0, 1)?;
let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
let mut v = vec![];
for _ in 0..left {
v.push(&l)
}
v.push(self);
for _ in 0..right {
v.push(&r)
}
Tensor::cat(&v, dim)
}
}
/// Run the `forward` method of `m` on `self`. /// Run the `forward` method of `m` on `self`.
pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> { pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
m.forward(self) m.forward(self)
} }
/// Run the `forward` method of `m` on `self`.
pub fn apply_t<M: crate::ModuleT>(&self, m: &M, train: bool) -> Result<Self> {
m.forward_t(self, train)
}
pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> { pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
self.storage.read().unwrap() self.storage.read().unwrap()
} }
@ -2437,127 +2088,6 @@ impl Tensor {
) -> Result<Self> { ) -> Result<Self> {
self.apply_op3_arc(t2, t3, Arc::new(Box::new(c))) self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
} }
/// Normalize a 'relative' axis value: positive values are kept, negative
/// values means counting the dimensions from the back.
pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
let rank = self.rank() as i64;
if rank <= axis {
crate::bail!("axis {axis} is too large, tensor rank {rank}")
} else if 0 <= axis {
Ok(axis as usize)
} else {
let naxis = rank + axis;
if naxis < 0 {
crate::bail!("axis {axis} is too small, tensor rank {rank}")
}
Ok(naxis as usize)
}
}
/// Returns a lower triangular matrix of ones of size n by n.
pub fn tril2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
let t = Tensor::arange(0u32, n as u32, device)?;
let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
t1.le(&t2)?.to_dtype(dtype)
}
/// Returns an upper triangular matrix of ones of size n by n.
pub fn triu2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
let t = Tensor::arange(0u32, n as u32, device)?;
let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
t1.ge(&t2)?.to_dtype(dtype)
}
/// Returns a matrix with a diagonal of ones of size n by n.
pub fn eye(n: usize, dtype: DType, device: &Device) -> Result<Self> {
let t = Tensor::arange(0u32, n as u32, device)?;
let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
t1.eq(&t2)?.to_dtype(dtype)
}
/// Returns the cumulative sum of elements of the input tensor summed over the specified
/// dimension.
///
/// This operation is most efficient when dim is the last dimension of the tensor.
pub fn cumsum<D: Dim>(&self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "cumsum")?;
let rank = self.rank();
if rank == 0 {
return Ok(self.clone());
}
let n_axis = self.dim(dim)?;
let triu = Tensor::triu2(n_axis, self.dtype(), self.device())?;
if rank == 1 {
self.unsqueeze(0)?.matmul(&triu)?.squeeze(0)
} else {
let last = rank - 1;
let t = self.transpose(dim, last)?;
let t = t.broadcast_matmul(&triu)?;
t.transpose(dim, last)
}
}
/// Returns a copy of `self` where the values within `ranges` have been replaced with the
/// content of `src`.
pub fn slice_assign<D: std::ops::RangeBounds<usize>>(
&self,
ranges: &[D],
src: &Tensor,
) -> Result<Self> {
let src_dims = src.dims();
let self_dims = self.dims();
if self_dims.len() != src_dims.len() {
crate::bail!(
"slice-assign requires input with the same rank {} <> {}",
self_dims.len(),
src_dims.len()
)
}
if self_dims.len() != ranges.len() {
crate::bail!(
"slice-assign requires input with the same rank as there are ranges {} <> {}",
self_dims.len(),
ranges.len()
)
}
let mut src = src.clone();
let mut mask = Self::ones(src.shape(), DType::U8, src.device())?;
for (i, range) in ranges.iter().enumerate() {
let start_included = match range.start_bound() {
std::ops::Bound::Unbounded => 0,
std::ops::Bound::Included(v) => *v,
std::ops::Bound::Excluded(v) => *v + 1,
};
let end_excluded = match range.end_bound() {
std::ops::Bound::Unbounded => self_dims[i],
std::ops::Bound::Included(v) => *v + 1,
std::ops::Bound::Excluded(v) => *v,
};
if end_excluded <= start_included {
crate::bail!(
"slice-assign: empty range for dim {i}, {start_included} {end_excluded}"
)
}
if self_dims[i] < end_excluded {
crate::bail!(
"slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}",
self_dims[i]
)
}
if end_excluded - start_included != src_dims[i] {
crate::bail!(
"slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i]
)
}
src = src.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?;
mask = mask.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?
}
mask.where_cond(/* on_true= */ &src, /* on_false= */ self)
}
} }
macro_rules! bin_trait { macro_rules! bin_trait {

View File

@ -4,7 +4,7 @@ use crate::{Result, Tensor};
macro_rules! test_device { macro_rules! test_device {
// TODO: Switch to generating the two last arguments automatically once concat_idents is // TODO: Switch to generating the two last arguments automatically once concat_idents is
// stable. https://github.com/rust-lang/rust/issues/29599 // stable. https://github.com/rust-lang/rust/issues/29599
($fn_name: ident, $test_cpu: ident, $test_cuda: ident, $test_metal: ident) => { ($fn_name: ident, $test_cpu: ident, $test_cuda: ident) => {
#[test] #[test]
fn $test_cpu() -> Result<()> { fn $test_cpu() -> Result<()> {
$fn_name(&Device::Cpu) $fn_name(&Device::Cpu)
@ -15,12 +15,6 @@ macro_rules! test_device {
fn $test_cuda() -> Result<()> { fn $test_cuda() -> Result<()> {
$fn_name(&Device::new_cuda(0)?) $fn_name(&Device::new_cuda(0)?)
} }
#[cfg(feature = "metal")]
#[test]
fn $test_metal() -> Result<()> {
$fn_name(&Device::new_metal(0)?)
}
}; };
} }

View File

@ -23,10 +23,6 @@ pub fn cuda_is_available() -> bool {
cfg!(feature = "cuda") cfg!(feature = "cuda")
} }
pub fn metal_is_available() -> bool {
cfg!(feature = "metal")
}
pub fn with_avx() -> bool { pub fn with_avx() -> bool {
cfg!(target_feature = "avx") cfg!(target_feature = "avx")
} }

View File

@ -13,11 +13,6 @@ res = torch.nn.functional.conv1d(t, w)
print(res.flatten()) print(res.flatten())
res = torch.nn.functional.conv1d(t, w, padding=1) res = torch.nn.functional.conv1d(t, w, padding=1)
print(res.flatten()) print(res.flatten())
w_t = w.transpose(0, 1)
res = torch.nn.functional.conv_transpose1d(t, w_t)
print(res.shape)
print(res)
*/ */
fn conv1d(dev: &Device) -> Result<()> { fn conv1d(dev: &Device) -> Result<()> {
let t = Tensor::new( let t = Tensor::new(
@ -50,17 +45,6 @@ fn conv1d(dev: &Device) -> Result<()> {
test_utils::to_vec1_round(&res.flatten_all()?, 4)?, test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352] [2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
); );
if dev.is_cpu() {
let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 7]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[
0.0699, -1.2899, 8.3018, 5.5873, 2.4572, -2.6143, -0.0706, 1.8765, 4.8318, 1.1538,
4.7076, -5.9745, -0.8276, 1.621
],
);
}
Ok(()) Ok(())
} }
@ -495,103 +479,17 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
] ]
] ]
); );
// Replicate the issue from https://github.com/huggingface/candle/issues/1212
let res = t.i((.., .., 0..4, 0..4))?.conv2d(&w, 0, 2, 1, 1)?;
let loss = res.sqr()?.sum_all()?;
assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 21.12f32);
let grads = loss.backward()?;
let grad_t = grads.get(&t).unwrap();
let grad_w = grads.get(&w).unwrap();
assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
assert_eq!(
test_utils::to_vec3_round(&grad_t.i(0)?, 2)?,
[
[
[9.29, -7.03, 7.87, 0.0, 0.0],
[-1.8, -7.82, 5.9, 0.0, 0.0],
[-3.12, 4.49, 5.52, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]
],
[
[21.73, 3.39, 4.77, 0.0, 0.0],
[8.25, 3.73, 27.61, 0.0, 0.0],
[-20.55, -5.61, -2.77, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]
],
[
[-8.98, 9.91, -7.15, 0.0, 0.0],
[4.93, -0.33, 4.56, 0.0, 0.0],
[-6.7, -5.76, -8.05, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]
],
[
[23.54, 6.98, -10.0, 0.0, 0.0],
[9.65, 6.18, 18.72, 0.0, 0.0],
[3.29, -5.27, 0.79, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]
]
]
);
assert_eq!(
test_utils::to_vec3_round(&grad_w.i(0)?, 2)?,
[
[
[-3.47, 7.44, 0.66],
[12.89, -3.4, -9.29],
[-14.16, -0.83, 7.14]
],
[
[-3.23, 5.37, -3.02],
[-2.12, -11.24, 1.94],
[6.97, 7.2, 2.99]
],
[
[-4.04, -3.31, 4.87],
[-6.68, -5.68, 1.73],
[-5.54, 4.32, 0.52]
],
[[-4.72, 1.5, 4.72], [3.79, 4.04, 6.76], [-4.6, 5.8, 6.93]]
]
);
Ok(()) Ok(())
} }
test_device!(conv1d, conv1d_cpu, conv1d_gpu, conv1d_metal); test_device!(conv1d, conv1d_cpu, conv1d_gpu);
test_device!( test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu);
conv1d_small, test_device!(conv2d, conv2d_cpu, conv2d_gpu);
conv1d_small_cpu,
conv1d_small_gpu,
conv1d_small_metal
);
test_device!(conv2d, conv2d_cpu, conv2d_gpu, conv2d_metal);
test_device!( test_device!(
conv2d_non_square, conv2d_non_square,
conv2d_non_square_cpu, conv2d_non_square_cpu,
conv2d_non_square_gpu, conv2d_non_square_gpu
conv2d_non_square_metal
);
test_device!(
conv2d_small,
conv2d_small_cpu,
conv2d_small_gpu,
conv2d_small_metal
);
test_device!(
conv2d_smaller,
conv2d_smaller_cpu,
conv2d_smaller_gpu,
conv2d_smaller_metal
);
test_device!(
conv2d_grad,
conv2d_grad_cpu,
conv2d_grad_gpu,
conv2_grad_metal
); );
test_device!(conv2d_small, conv2d_small_cpu, conv2d_small_gpu);
test_device!(conv2d_smaller, conv2d_smaller_cpu, conv2d_smaller_gpu);
test_device!(conv2d_grad, conv2d_grad_cpu, conv2d_grad_gpu);

View File

@ -192,84 +192,6 @@ fn unary_grad(device: &Device) -> Result<()> {
test_utils::to_vec1_round(grad_x, 2)?, test_utils::to_vec1_round(grad_x, 2)?,
[0.01, 0.42, 0.0, 0.98], [0.01, 0.42, 0.0, 0.98],
); );
// testing compared to pytorch nn.GELU(approximate = 'tanh')
let y = x.gelu()?;
let grads = y.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec1_round(&y, 4)?,
[2.9964, 0.8412, 3.9999, 0.0839]
);
assert_eq!(
test_utils::to_vec1_round(grad_x, 4)?,
[1.0116, 1.0830, 1.0003, 0.6188],
);
// Testing compared to pytorch torch.erf
//
// import torch
// x = torch.tensor([3.0, 1.0, 4.0, 0.15], requires_grad=True)
// y = x.erf()
// print(y)
// loss = y.sum()
// loss.backward()
// print(x.grad)
let y = x.erf()?;
let grads = y.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(test_utils::to_vec1_round(&y, 4)?, [1.0, 0.8427, 1.0, 0.168]);
assert_eq!(
test_utils::to_vec1_round(grad_x, 4)?,
[0.0001, 0.4151, 0.0, 1.1033],
);
// Testing compared to pytorch nn.GELU(approximate = 'none')
//
// import torch
// import torch.nn.functional as F
// x = torch.tensor([3.0, 1.0, 4.0, 0.15], requires_grad=True)
// y = F.gelu(x, approximate='none')
// print(y)
// loss = y.sum()
// loss.backward()
// print(x.grad)
let y = x.gelu_erf()?;
let grads = y.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec1_round(&y, 4)?,
[2.9960, 0.8413, 3.9999, 0.0839]
);
assert_eq!(
test_utils::to_vec1_round(grad_x, 4)?,
[1.0119, 1.0833, 1.0005, 0.6188],
);
// Testing compared to pytorch elu
//
// import torch
// import torch.nn.functional as F
// x = torch.tensor([-1.0, 0.0, -2.0, 3.0], requires_grad=True)
// y = F.elu(x, alpha=2.0)
// print(y)
// loss = y.min
// loss = y.sum()
// loss.backward()
// print(x.grad)
let elu_x = Var::new(&[-1.0f32, 0., -2., 3.], device)?;
let y = elu_x.elu(2.)?;
let grads = y.backward()?;
let grad_x = grads.get(&elu_x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec1_round(&y, 4)?,
[-1.2642, 0.0000, -1.7293, 3.0000]
);
assert_eq!(
test_utils::to_vec1_round(grad_x, 4)?,
[0.7358, 2.0000, 0.2707, 1.0000]
);
Ok(()) Ok(())
} }
@ -296,48 +218,12 @@ fn binary_grad(device: &Device) -> Result<()> {
let grad_x = grads.get(x).context("no grad for x")?; let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(y.to_vec1::<f32>()?, [3., 1., -4., -1.]); assert_eq!(y.to_vec1::<f32>()?, [3., 1., -4., -1.]);
assert_eq!(grad_x.to_vec1::<f32>()?, [1., 1., 1., 1.]); assert_eq!(grad_x.to_vec1::<f32>()?, [1., 1., 1., 1.]);
let x_var = Var::new(&[3f32, 1., -4., -1., 5., 9.], device)?;
let x = x_var.as_tensor();
let y_var = Var::new(&[2f32, 7., 1.], device)?;
let y = y_var.as_tensor();
let ss = x
.reshape((2, 3))?
.slice_scatter0(&y.reshape((1, 3))?, 1)?
.sqr()?;
let grads = ss.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
let grad_y = grads.get(y).context("no grad for y")?;
assert_eq!(ss.to_vec2::<f32>()?, [[9., 1., 16.], [4., 49., 1.]]);
assert_eq!(grad_x.to_vec1::<f32>()?, [6.0, 2.0, -8.0, 0.0, 0.0, 0.0]);
assert_eq!(grad_y.to_vec1::<f32>()?, [4.0, 14.0, 2.0]);
Ok(()) Ok(())
} }
test_device!( test_device!(simple_grad, simple_grad_cpu, simple_grad_gpu);
simple_grad, test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu);
simple_grad_cpu, test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu);
simple_grad_gpu, test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu);
simple_grad_metal test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu);
); test_device!(binary_grad, binary_grad_cpu, binary_grad_gpu);
test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu, sum_grad_metal);
test_device!(
matmul_grad,
matmul_grad_cpu,
matmul_grad_gpu,
matmul_grad_metal
);
test_device!(
grad_descent,
grad_descent_cpu,
grad_descent_gpu,
grad_descent_metal
);
test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu, unary_grad_metal);
test_device!(
binary_grad,
binary_grad_cpu,
binary_grad_gpu,
binary_grad_metal
);

View File

@ -91,32 +91,3 @@ fn index_3d() -> Result<()> {
assert_eq!(tensor.i((1, .., 3))?.to_vec1::<u32>()?, &[15, 19, 23]); assert_eq!(tensor.i((1, .., 3))?.to_vec1::<u32>()?, &[15, 19, 23]);
Ok(()) Ok(())
} }
#[test]
fn slice_assign() -> Result<()> {
let dev = Device::Cpu;
let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?;
let src = Tensor::arange(0u32, 2 * 3, &dev)?.reshape((3, 2))?;
let out = tensor.slice_assign(&[1..4, 3..5], &src)?;
assert_eq!(
out.to_vec2::<u32>()?,
&[
[0, 1, 2, 3, 4],
[5, 6, 7, 0, 1],
[10, 11, 12, 2, 3],
[15, 16, 17, 4, 5]
]
);
let out = tensor.slice_assign(&[0..3, 0..2], &src)?;
assert_eq!(
out.to_vec2::<u32>()?,
&[
[0, 1, 2, 3, 4],
[2, 3, 7, 8, 9],
[4, 5, 12, 13, 14],
[15, 16, 17, 18, 19]
]
);
Ok(())
}

View File

@ -49,7 +49,7 @@ fn contiguous(device: &Device) -> Result<()> {
Ok(()) Ok(())
} }
test_device!(contiguous, contiguous_cpu, contiguous_gpu, contiguous_metal); test_device!(contiguous, contiguous_cpu, contiguous_gpu);
#[test] #[test]
fn strided_blocks() -> Result<()> { fn strided_blocks() -> Result<()> {

View File

@ -1,9 +0,0 @@
import numpy as np
x = np.arange(10)
# Write a npy file.
np.save("test.npy", x)
# Write multiple values to a npz file.
values = { "x": x, "x_plus_one": x + 1 }
np.savez("test.npz", **values)

View File

@ -98,17 +98,15 @@ fn upsample_nearest2d(dev: &Device) -> Result<()> {
Ok(()) Ok(())
} }
test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu, avg_pool2d_metal); test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu);
test_device!( test_device!(
avg_pool2d_pytorch, avg_pool2d_pytorch,
avg_pool2d_pytorch_cpu, avg_pool2d_pytorch_cpu,
avg_pool2d_pytorch_gpu, avg_pool2d_pytorch_gpu
avg_pool2d_pytorch_metal
); );
test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu, max_pool2d_metal); test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu);
test_device!( test_device!(
upsample_nearest2d, upsample_nearest2d,
upsample_nearest2d_cpu, upsample_nearest2d_cpu,
upsample_nearest2d_gpu, upsample_nearest2d_gpu
upsample_nearest2d_metal
); );

View File

@ -1,7 +1,7 @@
use candle_core::{ use candle_core::{
quantized::{self, GgmlDType}, quantized::{self, GgmlDType},
test_utils::to_vec2_round, test_utils::to_vec2_round,
Device, Module, Result, Tensor, Device, Result, Tensor,
}; };
use quantized::{k_quants, GgmlType}; use quantized::{k_quants, GgmlType};
use rand::prelude::*; use rand::prelude::*;
@ -43,7 +43,7 @@ fn quantized_matmul() -> Result<()> {
); );
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?; let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
let matmul = quantized::QMatMul::from_qtensor(qtensor)?; let matmul = quantized::QMatMul::from_qtensor(qtensor);
let res = matmul.forward(&tensor_lhs)?; let res = matmul.forward(&tensor_lhs)?;
assert_eq!( assert_eq!(
to_vec2_round(&res, 0)?, to_vec2_round(&res, 0)?,
@ -91,7 +91,7 @@ fn quantized_matmul_neg() -> Result<()> {
); );
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?; let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
let matmul = quantized::QMatMul::from_qtensor(qtensor)?; let matmul = quantized::QMatMul::from_qtensor(qtensor);
let res = matmul.forward(&tensor_lhs)?; let res = matmul.forward(&tensor_lhs)?;
assert_eq!( assert_eq!(
to_vec2_round(&res, 0)?, to_vec2_round(&res, 0)?,
@ -491,9 +491,6 @@ fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result<f32> {
GgmlDType::Q5_0 => 0.001353, GgmlDType::Q5_0 => 0.001353,
GgmlDType::Q5_1 => 0.001363, GgmlDType::Q5_1 => 0.001363,
GgmlDType::Q8_0 => 0.000092, GgmlDType::Q8_0 => 0.000092,
// Not from the ggml repo.
GgmlDType::Q8K => 0.00065,
_ => candle_core::bail!("No GGML results for quantization type {dtype:?}",), _ => candle_core::bail!("No GGML results for quantization type {dtype:?}",),
}; };
Ok(err) Ok(err)
@ -511,22 +508,17 @@ fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
T::VecDotType::from_float(&b, &mut b_quant)?; T::VecDotType::from_float(&b, &mut b_quant)?;
let result = T::vec_dot(length, &a_quant, &b_quant)?; let result = T::vec_dot(length, &a_quant, &b_quant)?;
let result_unopt = T::vec_dot_unopt(length, &a_quant, &b_quant)?;
let reference_result = vec_dot_reference(&a, &b); let reference_result = vec_dot_reference(&a, &b);
if (result - result_unopt).abs() / length as f32 > 1e-6 {
candle_core::bail!(
"the opt and unopt vec-dot returned different values, opt {result}, unopt {result_unopt}"
)
}
let error = (result - reference_result).abs() / length as f32; let error = (result - reference_result).abs() / length as f32;
let ggml_error = ggml_reference_matmul_error(T::DTYPE)?; let ggml_error = ggml_reference_matmul_error(T::DTYPE)?;
if !error.is_finite() || error > GGML_MAX_DOT_PRODUCT_ERROR { if error > GGML_MAX_DOT_PRODUCT_ERROR {
candle_core::bail!( candle_core::bail!(
"Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}", "Dot product error {} exceeds max error {}",
error,
GGML_MAX_DOT_PRODUCT_ERROR
); );
} }
@ -579,7 +571,7 @@ fn quantized_matmul_q2k() -> Result<()> {
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize::<BlockQ2K>(&rhs)?; let rhs = quantized::QTensor::quantize::<BlockQ2K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?; let rhs = quantized::QMatMul::from_qtensor(rhs);
let mm = rhs.forward(&lhs)?; let mm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]); assert_eq!(mm.dims(), [m, n]);
@ -605,7 +597,7 @@ fn quantized_matmul_q3k() -> Result<()> {
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize::<BlockQ3K>(&rhs)?; let rhs = quantized::QTensor::quantize::<BlockQ3K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?; let rhs = quantized::QMatMul::from_qtensor(rhs);
let mm = rhs.forward(&lhs)?; let mm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]); assert_eq!(mm.dims(), [m, n]);
@ -631,7 +623,7 @@ fn quantized_matmul_q4k() -> Result<()> {
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize::<BlockQ4K>(&rhs)?; let rhs = quantized::QTensor::quantize::<BlockQ4K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?; let rhs = quantized::QMatMul::from_qtensor(rhs);
let mm = rhs.forward(&lhs)?; let mm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]); assert_eq!(mm.dims(), [m, n]);
@ -657,7 +649,7 @@ fn quantized_matmul_q5k() -> Result<()> {
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize::<BlockQ5K>(&rhs)?; let rhs = quantized::QTensor::quantize::<BlockQ5K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?; let rhs = quantized::QMatMul::from_qtensor(rhs);
let mm = rhs.forward(&lhs)?; let mm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]); assert_eq!(mm.dims(), [m, n]);
@ -684,7 +676,7 @@ fn quantized_matmul_q6k() -> Result<()> {
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize::<BlockQ6K>(&rhs)?; let rhs = quantized::QTensor::quantize::<BlockQ6K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?; let rhs = quantized::QMatMul::from_qtensor(rhs);
let mm = rhs.forward(&lhs)?; let mm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]); assert_eq!(mm.dims(), [m, n]);
@ -695,28 +687,3 @@ fn quantized_matmul_q6k() -> Result<()> {
ggml_matmul_error_test::<BlockQ6K>()?; ggml_matmul_error_test::<BlockQ6K>()?;
Ok(()) Ok(())
} }
#[test]
fn quantized_matmul_q8k() -> Result<()> {
use k_quants::BlockQ8K;
let cpu = &Device::Cpu;
let (m, k, n) = (11, 512, 21);
let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?;
assert_eq!(mm.dims(), [m, n]);
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize::<BlockQ8K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]);
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.266, 1.504, -0.204, 1.7]);
ggml_matmul_error_test::<BlockQ8K>()?;
Ok(())
}

View File

@ -1,24 +0,0 @@
use candle_core::{DType, Result, Tensor};
#[test]
fn npy() -> Result<()> {
let npy = Tensor::read_npy("tests/test.npy")?;
assert_eq!(
npy.to_dtype(DType::U8)?.to_vec1::<u8>()?,
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
);
Ok(())
}
#[test]
fn npz() -> Result<()> {
let npz = Tensor::read_npz("tests/test.npz")?;
assert_eq!(npz.len(), 2);
assert_eq!(npz[0].0, "x");
assert_eq!(npz[1].0, "x_plus_one");
assert_eq!(
npz[1].1.to_dtype(DType::U8)?.to_vec1::<u8>()?,
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
);
Ok(())
}

View File

@ -1,4 +1,4 @@
use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor}; use candle_core::{test_device, DType, Device, IndexOp, Result, Tensor};
fn zeros(device: &Device) -> Result<()> { fn zeros(device: &Device) -> Result<()> {
let tensor = Tensor::zeros((5, 2), DType::F32, device)?; let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
@ -8,50 +8,6 @@ fn zeros(device: &Device) -> Result<()> {
Ok(()) Ok(())
} }
fn ones(device: &Device) -> Result<()> {
assert_eq!(
Tensor::ones((2, 3), DType::U8, device)?.to_vec2::<u8>()?,
[[1, 1, 1], [1, 1, 1]],
);
assert_eq!(
Tensor::ones((2, 3), DType::U32, device)?.to_vec2::<u32>()?,
[[1, 1, 1], [1, 1, 1]],
);
assert_eq!(
Tensor::ones((2, 3), DType::I64, device)?.to_vec2::<i64>()?,
[[1, 1, 1], [1, 1, 1]],
);
assert_eq!(
Tensor::ones((2, 3), DType::F32, device)?.to_vec2::<f32>()?,
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
);
assert_eq!(
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
);
Ok(())
}
fn arange(device: &Device) -> Result<()> {
assert_eq!(
Tensor::arange(0u8, 5u8, device)?.to_vec1::<u8>()?,
[0, 1, 2, 3, 4],
);
assert_eq!(
Tensor::arange_step(0u8, 5u8, 2, device)?.to_vec1::<u8>()?,
[0, 2, 4],
);
assert_eq!(
Tensor::arange_step(0u8, 5u8, 3, device)?.to_vec1::<u8>()?,
[0, 3],
);
assert_eq!(
Tensor::arange_step(5i64, 0i64, -1, device)?.to_vec1::<i64>()?,
[5, 4, 3, 2, 1],
);
Ok(())
}
fn add_mul(device: &Device) -> Result<()> { fn add_mul(device: &Device) -> Result<()> {
let tensor = Tensor::new(&[3f32, 1., 4.], device)?; let tensor = Tensor::new(&[3f32, 1., 4.], device)?;
let dim1 = tensor.dims1()?; let dim1 = tensor.dims1()?;
@ -77,65 +33,6 @@ fn tensor_2d(device: &Device) -> Result<()> {
Ok(()) Ok(())
} }
fn clamp(device: &Device) -> Result<()> {
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let tensor = Tensor::new(data, device)?;
let tensor = tensor.clamp(1.5, 6.2)?;
assert_eq!(
tensor.to_vec2::<f32>()?,
[[3.0, 1.5, 4.0, 1.5, 5.0], [2.0, 1.5, 6.2, 6.2, 2.0]],
);
Ok(())
}
fn unary_op(device: &Device) -> Result<()> {
let data = &[[-3f32, 1., 4., -0.1, 0.5], [2.7, -1.8, -0.28, 1.8, 2.8]];
let tensor = Tensor::new(data, device)?;
assert_eq!(
test_utils::to_vec2_round(&tensor.gelu()?, 4)?,
[
[-0.0036, 0.8412, 3.9999, -0.046, 0.3457],
[2.6911, -0.0647, -0.1091, 1.7353, 2.7933]
]
);
assert_eq!(
test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?,
[
[-0.004, 0.8413, 3.9999, -0.046, 0.3457],
[2.6906, -0.0647, -0.1091, 1.7353, 2.7928]
]
);
assert_eq!(
test_utils::to_vec2_round(&tensor.erf()?, 4)?,
[
[-1.0, 0.8427, 1.0, -0.1125, 0.5205],
[0.9999, -0.9891, -0.3079, 0.9891, 0.9999]
]
);
assert_eq!(
test_utils::to_vec2_round(&tensor.ceil()?, 4)?,
[[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -1.0, -0.0, 2.0, 3.0]]
);
assert_eq!(
test_utils::to_vec2_round(&tensor.floor()?, 4)?,
[[-3.0, 1.0, 4.0, -1.0, 0.0], [2.0, -2.0, -1.0, 1.0, 2.0]]
);
assert_eq!(
test_utils::to_vec2_round(&tensor.round()?, 4)?,
[[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -2.0, -0.0, 2.0, 3.0]]
);
let tensor = Tensor::new(&[2997.9246, 314.15926f32], device)?;
assert_eq!(
test_utils::to_vec1_round(&tensor.round_to(2)?, 4)?,
[2997.92, 314.16]
);
assert_eq!(
test_utils::to_vec1_round(&tensor.round_to(-2)?, 4)?,
[3000.0, 300.]
);
Ok(())
}
fn binary_op(device: &Device) -> Result<()> { fn binary_op(device: &Device) -> Result<()> {
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]]; let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let tensor1 = Tensor::new(data, device)?; let tensor1 = Tensor::new(data, device)?;
@ -180,22 +77,6 @@ fn transpose(device: &Device) -> Result<()> {
Ok(()) Ok(())
} }
fn var(device: &Device) -> Result<()> {
// Values taken from https://pytorch.org/docs/stable/generated/torch.var.html
let data = &[
[0.2035f32, 1.2959, 1.8101, -0.4644],
[1.5027, -0.3270, 0.5905, 0.6538],
[-1.5745, 1.3330, -0.5596, -0.6548],
[0.1264, -0.5080, 1.6420, 0.1992],
];
let tensor = Tensor::new(data, device)?;
assert_eq!(
test_utils::to_vec2_round(&tensor.var_keepdim(1)?, 4)?,
&[[1.0631], [0.559], [1.4893], [0.8258]]
);
Ok(())
}
fn sum(device: &Device) -> Result<()> { fn sum(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]]; let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, device)?; let tensor = Tensor::new(data, device)?;
@ -709,30 +590,6 @@ fn index_select(device: &Device) -> Result<()> {
hs.to_vec2::<f32>()?, hs.to_vec2::<f32>()?,
&[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]] &[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
); );
// Prior to https://github.com/huggingface/candle/pull/1022
// There would be a bug where the last values in the result tensor would be set to 0.
let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
let hs = t.index_select(&ids, 0)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[0.0, 1.0, 2.0],
[6.0, 7.0, 8.0],
[3.0, 4.0, 5.0],
[0.0, 1.0, 2.0],
[6.0, 7.0, 8.0],
[3.0, 4.0, 5.0],
]
);
// Test when selecting dim > 0 with ids size different from elem count of
// target dim in source/input.
let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
let hs = t.index_select(&ids, 1)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
Ok(()) Ok(())
} }
@ -779,48 +636,6 @@ fn index_add(device: &Device) -> Result<()> {
Ok(()) Ok(())
} }
fn slice_scatter(device: &Device) -> Result<()> {
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
assert_eq!(
t.to_vec2::<f32>()?,
&[
[0.0, 1.0, 2.0],
[3.0, 4.0, 5.0],
[6.0, 7.0, 8.0],
[9.0, 10.0, 11.0]
]
);
let src = Tensor::arange(100f32, 106f32, device)?.reshape((2, 3))?;
assert_eq!(
t.slice_scatter0(&src, 0)?.to_vec2::<f32>()?,
&[
[100.0, 101.0, 102.0],
[103.0, 104.0, 105.0],
[6.0, 7.0, 8.0],
[9.0, 10.0, 11.0]
]
);
assert_eq!(
t.slice_scatter0(&src, 1)?.to_vec2::<f32>()?,
&[
[0.0, 1.0, 2.0],
[100.0, 101.0, 102.0],
[103.0, 104.0, 105.0],
[9.0, 10.0, 11.0]
]
);
assert_eq!(
t.slice_scatter0(&src, 2)?.to_vec2::<f32>()?,
&[
[0.0, 1.0, 2.0],
[3.0, 4.0, 5.0],
[100.0, 101.0, 102.0],
[103.0, 104.0, 105.0],
]
);
Ok(())
}
fn scatter_add(device: &Device) -> Result<()> { fn scatter_add(device: &Device) -> Result<()> {
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?; let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
assert_eq!( assert_eq!(
@ -1062,68 +877,28 @@ fn broadcasting(device: &Device) -> Result<()> {
Ok(()) Ok(())
} }
fn randn(device: &Device) -> Result<()> { test_device!(zeros, zeros_cpu, zeros_gpu);
let tensor = Tensor::randn(0f32, 1f32, (5, 3), device)?; test_device!(add_mul, add_mul_cpu, add_mul_gpu);
assert_eq!(tensor.dims(), [5, 3]); test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
let tensor = Tensor::rand(0f32, 1f32, (5, 3), device)?; test_device!(narrow, narrow_cpu, narrow_gpu);
assert_eq!(tensor.dims(), [5, 3]); test_device!(broadcast, broadcast_cpu, broadcast_gpu);
Ok(()) test_device!(cat, cat_cpu, cat_gpu);
} test_device!(sum, sum_cpu, sum_gpu);
test_device!(min, min_cpu, min_gpu);
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal); test_device!(max, max_cpu, max_gpu);
test_device!(ones, ones_cpu, ones_gpu, ones_metal); test_device!(argmax, argmax_cpu, argmax_gpu);
test_device!(arange, arange_cpu, arange_gpu, arange_metal); test_device!(argmin, argmin_cpu, argmin_gpu);
test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal); test_device!(transpose, transpose_cpu, transpose_gpu);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal); test_device!(binary_op, binary_op_cpu, binary_op_gpu);
test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal); test_device!(embeddings, embeddings_cpu, embeddings_gpu);
test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal); test_device!(cmp, cmp_cpu, cmp_gpu);
test_device!(cat, cat_cpu, cat_gpu, cat_metal); test_device!(matmul, matmul_cpu, matmul_gpu);
test_device!(sum, sum_cpu, sum_gpu, sum_metal); test_device!(broadcast_matmul, broadcast_matmul_cpu, broadcast_matmul_gpu);
test_device!(min, min_cpu, min_gpu, min_metal); test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu);
test_device!(max, max_cpu, max_gpu, max_metal); test_device!(index_select, index_select_cpu, index_select_gpu);
test_device!(argmax, argmax_cpu, argmax_gpu, argmax_metal); test_device!(index_add, index_add_cpu, index_add_gpu);
test_device!(argmin, argmin_cpu, argmin_gpu, argmin_metal); test_device!(gather, gather_cpu, gather_gpu);
test_device!(transpose, transpose_cpu, transpose_gpu, transpose_metal); test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
test_device!(unary_op, unary_op_cpu, unary_op_gpu, unary_op_metal);
test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal);
test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal);
test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal);
test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
test_device!(
broadcast_matmul,
broadcast_matmul_cpu,
broadcast_matmul_gpu,
broadcast_matmul_metal
);
test_device!(
broadcasting,
broadcasting_cpu,
broadcasting_gpu,
broadcasting_metal
);
test_device!(
index_select,
index_select_cpu,
index_select_gpu,
index_select_metal
);
test_device!(index_add, index_add_cpu, index_add_gpu, index_add_metal);
test_device!(gather, gather_cpu, gather_gpu, gather_metal);
test_device!(
scatter_add,
scatter_add_cpu,
scatter_add_gpu,
scatter_add_metal
);
test_device!(
slice_scatter,
slice_scatter_cpu,
slice_scatter_gpu,
slice_scatter_metal
);
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
test_device!(var, var_cpu, var_gpu, var_metal);
// There was originally a bug on the CPU implementation for randn // There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381 // https://github.com/huggingface/candle/issues/381
@ -1135,89 +910,3 @@ fn randn_hasneg() -> Result<()> {
} }
Ok(()) Ok(())
} }
#[test]
fn pad_with_same() -> Result<()> {
let t = Tensor::arange(1f32, 5f32, &Device::Cpu)?.reshape((2, 2))?;
let t0 = t.pad_with_same(0, 1, 2)?;
assert_eq!(
t0.to_vec2::<f32>()?,
[[1.0, 2.0], [1.0, 2.0], [3.0, 4.0], [3.0, 4.0], [3.0, 4.0]]
);
let t1 = t.pad_with_same(1, 1, 2)?;
assert_eq!(
t1.to_vec2::<f32>()?,
[[1.0, 1.0, 2.0, 2.0, 2.0], [3.0, 3.0, 4.0, 4.0, 4.0]]
);
Ok(())
}
#[test]
fn i64_abs() -> Result<()> {
let t = Tensor::new(&[-42i64, 1337], &Device::Cpu)?;
let t = t.abs()?;
assert_eq!(t.to_vec1::<i64>()?, [42, 1337]);
Ok(())
}
#[test]
fn tril_triu_eye() -> Result<()> {
let t = Tensor::tril2(4, DType::F32, &Device::Cpu)?;
assert_eq!(
t.to_vec2::<f32>()?,
[
[1.0, 0.0, 0.0, 0.0],
[1.0, 1.0, 0.0, 0.0],
[1.0, 1.0, 1.0, 0.0],
[1.0, 1.0, 1.0, 1.0]
],
);
let t = Tensor::triu2(4, DType::F32, &Device::Cpu)?;
assert_eq!(
t.to_vec2::<f32>()?,
[
[1.0, 1.0, 1.0, 1.0],
[0.0, 1.0, 1.0, 1.0],
[0.0, 0.0, 1.0, 1.0],
[0.0, 0.0, 0.0, 1.0]
]
);
let t = Tensor::eye(4, DType::F32, &Device::Cpu)?;
assert_eq!(
t.to_vec2::<f32>()?,
[
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0]
]
);
Ok(())
}
#[test]
fn cumsum() -> Result<()> {
let t = &[3f32, 1., 4., 1., 5.];
let t = Tensor::new(t, &Device::Cpu)?;
assert_eq!(t.cumsum(0)?.to_vec1::<f32>()?, [3., 4., 8., 9., 14.]);
let t = t.unsqueeze(1)?;
assert_eq!(
t.cumsum(0)?.to_vec2::<f32>()?,
[[3.0], [4.0], [8.0], [9.0], [14.0]]
);
assert_eq!(
t.cumsum(1)?.to_vec2::<f32>()?,
[[3.0], [1.0], [4.0], [1.0], [5.0]]
);
let t = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let t = Tensor::new(t, &Device::Cpu)?;
assert_eq!(
t.cumsum(1)?.to_vec2::<f32>()?,
[[3.0, 4.0, 8.0, 9.0, 14.0], [2.0, 3.0, 10.0, 18.0, 20.0]],
);
assert_eq!(
t.cumsum(0)?.to_vec2::<f32>()?,
[[3.0, 1.0, 4.0, 1.0, 5.0], [5.0, 2.0, 11.0, 9.0, 7.0]]
);
Ok(())
}

Binary file not shown.

Binary file not shown.

View File

@ -11,8 +11,8 @@ readme = "README.md"
[dependencies] [dependencies]
byteorder = { workspace = true } byteorder = { workspace = true }
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" } candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.1" } candle-nn = { path = "../candle-nn", version = "0.2.1" }
hf-hub = { workspace = true} hf-hub = { workspace = true}
intel-mkl-src = { workspace = true, optional = true } intel-mkl-src = { workspace = true, optional = true }
memmap2 = { workspace = true } memmap2 = { workspace = true }

View File

@ -4,9 +4,7 @@
//! <https://www.cs.toronto.edu/~kriz/cifar.html> //! <https://www.cs.toronto.edu/~kriz/cifar.html>
//! The binary version of the dataset is used. //! The binary version of the dataset is used.
use crate::vision::Dataset; use crate::vision::Dataset;
use candle::{DType, Device, Error, Result, Tensor}; use candle::{DType, Device, Result, Tensor};
use hf_hub::{api::sync::Api, Repo, RepoType};
use parquet::file::reader::{FileReader, SerializedFileReader};
use std::fs::File; use std::fs::File;
use std::io::{BufReader, Read}; use std::io::{BufReader, Read};
@ -62,58 +60,3 @@ pub fn load_dir<T: AsRef<std::path::Path>>(dir: T) -> Result<Dataset> {
labels: 10, labels: 10,
}) })
} }
fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor, Tensor)> {
let samples = parquet.metadata().file_metadata().num_rows() as usize;
let mut buffer_images: Vec<u8> = Vec::with_capacity(samples * 1_024);
let mut buffer_labels: Vec<u8> = Vec::with_capacity(samples);
for row in parquet.into_iter().flatten() {
for (_name, field) in row.get_column_iter() {
if let parquet::record::Field::Group(subrow) = field {
for (_name, field) in subrow.get_column_iter() {
if let parquet::record::Field::Bytes(value) = field {
let image = image::load_from_memory(value.data()).unwrap();
buffer_images.extend(image.to_rgb8().as_raw());
}
}
} else if let parquet::record::Field::Long(label) = field {
buffer_labels.push(*label as u8);
}
}
}
let images = (Tensor::from_vec(buffer_images, (samples, 3, 32, 32), &Device::Cpu)?
.to_dtype(DType::U8)?
/ 255.)?;
let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
Ok((images, labels))
}
pub fn load() -> Result<Dataset> {
let api = Api::new().map_err(|e| Error::Msg(format!("Api error: {e}")))?;
let dataset_id = "cifar10".to_string();
let repo = Repo::with_revision(
dataset_id,
RepoType::Dataset,
"refs/convert/parquet".to_string(),
);
let repo = api.repo(repo);
let test_parquet_filename = repo
.get("plain_text/test/0000.parquet")
.map_err(|e| Error::Msg(format!("Api error: {e}")))?;
let train_parquet_filename = repo
.get("plain_text/train/0000.parquet")
.map_err(|e| Error::Msg(format!("Api error: {e}")))?;
let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)
.map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)
.map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
let (test_images, test_labels) = load_parquet(test_parquet)?;
let (train_images, train_labels) = load_parquet(train_parquet)?;
Ok(crate::vision::Dataset {
train_images,
train_labels,
test_images,
test_labels,
labels: 10,
})
}

View File

@ -8,9 +8,13 @@ use parquet::file::reader::{FileReader, SerializedFileReader};
use std::fs::File; use std::fs::File;
use std::io::{self, BufReader, Read}; use std::io::{self, BufReader, Read};
fn read_u32<T: Read>(reader: &mut T) -> std::io::Result<u32> { fn read_u32<T: Read>(reader: &mut T) -> Result<u32> {
use byteorder::ReadBytesExt; let mut b = vec![0u8; 4];
reader.read_u32::<byteorder::BigEndian>() reader.read_exact(&mut b)?;
let (result, _) = b.iter().rev().fold((0u64, 1u64), |(s, basis), &x| {
(s + basis * u64::from(x), basis * 256)
});
Ok(result as u32)
} }
fn check_magic_number<T: Read>(reader: &mut T, expected: u32) -> Result<()> { fn check_magic_number<T: Read>(reader: &mut T, expected: u32) -> Result<()> {

View File

@ -11,23 +11,19 @@ readme = "README.md"
[dependencies] [dependencies]
accelerate-src = { workspace = true, optional = true } accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" } candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.3.1" } candle-datasets = { path = "../candle-datasets", version = "0.2.1" }
candle-nn = { path = "../candle-nn", version = "0.3.1" } candle-nn = { path = "../candle-nn", version = "0.2.1" }
candle-transformers = { path = "../candle-transformers", version = "0.3.1" } candle-transformers = { path = "../candle-transformers", version = "0.2.1" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.1", optional = true } candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.1", optional = true }
candle-onnx = { path = "../candle-onnx", version = "0.3.1", optional = true }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
image = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
pyo3 = { version = "0.20.0", features = ["auto-initialize"], optional = true }
rayon = { workspace = true }
safetensors = { workspace = true } safetensors = { workspace = true }
serde = { workspace = true } serde = { workspace = true }
serde_json = { workspace = true } serde_json = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] } num-traits = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
image = { workspace = true }
[dev-dependencies] [dev-dependencies]
anyhow = { workspace = true } anyhow = { workspace = true }
@ -38,6 +34,7 @@ imageproc = { workspace = true }
memmap2 = { workspace = true } memmap2 = { workspace = true }
rand = { workspace = true } rand = { workspace = true }
rusttype = { workspace = true } rusttype = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
tracing = { workspace = true } tracing = { workspace = true }
tracing-chrome = { workspace = true } tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true } tracing-subscriber = { workspace = true }
@ -53,24 +50,10 @@ default = []
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"] accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"] cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
cudnn = ["candle/cudnn"] cudnn = ["candle/cudnn"]
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"] flash-attn = ["cuda", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"] mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
nccl = ["cuda", "cudarc/nccl", "dep:half"] nccl = ["cuda", "cudarc/nccl", "dep:half"]
onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"]
[[example]] [[example]]
name = "llama_multiprocess" name = "llama_multiprocess"
required-features = ["cuda", "nccl", "flash-attn"] required-features = ["cuda", "nccl", "flash-attn"]
[[example]]
name = "reinforcement-learning"
required-features = ["pyo3"]
[[example]]
name = "onnx"
required-features = ["onnx"]
[[example]]
name = "onnx_basics"
required-features = ["onnx"]

View File

@ -1,44 +0,0 @@
# candle-bert
Bert is a general large language model. In this example it can be used for two
different tasks:
- Compute sentence embeddings for a prompt.
- Compute similarities between a set of sentences.
## Sentence embeddings
Bert is used to compute the sentence embeddings for a prompt. The model weights
are downloaded from the hub on the first run.
```bash
cargo run --example bert --release -- --prompt "Here is a test sentence"
> [[[ 0.0798, -0.0665, -0.0247, ..., -0.1082, -0.1000, -0.2751],
> [ 0.4218, 0.2690, 0.2740, ..., 0.3889, 1.3503, 0.9908],
> [ 0.0466, 0.3041, -0.1143, ..., 0.4427, 0.6926, -0.1515],
> ...
> [ 0.3396, 0.4320, -0.4408, ..., 0.9212, 0.2331, -0.6777],
> [ 0.2789, 0.7539, 0.4306, ..., -0.0095, 0.3375, -1.7529],
> [ 0.6737, 0.7882, 0.0548, ..., 0.1836, 0.7299, -0.6617]]]
> Tensor[[1, 7, 384], f32]
```
## Similarities
In this example, Bert is used to compute the sentence embeddings for a set of
sentences (hardcoded in the examples). Then cosine similarities are computed for
each sentence pair and they are reported by decreasing values, hence the first
reported pair contains the two sentences that have the highest similarity score.
The sentence embeddings are computed using average pooling through all the
sentence tokens, including some potential padding.
```bash
cargo run --example bert --release
> score: 0.85 'The new movie is awesome' 'The new movie is so great'
> score: 0.61 'The cat sits outside' 'The cat plays in the garden'
> score: 0.52 'I love pasta' 'Do you like pizza?'
> score: 0.23 'The new movie is awesome' 'Do you like pizza?'
> score: 0.22 'I love pasta' 'The new movie is awesome'
```

View File

@ -3,13 +3,14 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")] #[cfg(feature = "accelerate")]
extern crate accelerate_src; extern crate accelerate_src;
use candle_transformers::models::bert::{BertModel, Config, DTYPE}; mod model;
use anyhow::{Error as E, Result}; use anyhow::{anyhow, Error as E, Result};
use candle::Tensor; use candle::Tensor;
use candle_nn::VarBuilder; use candle_nn::VarBuilder;
use clap::Parser; use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType}; use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
use model::{BertModel, Config, DTYPE};
use tokenizers::{PaddingParams, Tokenizer}; use tokenizers::{PaddingParams, Tokenizer};
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
@ -19,6 +20,10 @@ struct Args {
#[arg(long)] #[arg(long)]
cpu: bool, cpu: bool,
/// Run offline (you must have the files already cached)
#[arg(long)]
offline: bool,
/// Enable tracing (generates a trace-timestamp.json file). /// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)] #[arg(long)]
tracing: bool, tracing: bool,
@ -34,10 +39,6 @@ struct Args {
#[arg(long)] #[arg(long)]
prompt: Option<String>, prompt: Option<String>,
/// Use the pytorch weights rather than the safetensors ones
#[arg(long)]
use_pth: bool,
/// The number of times to run the prompt. /// The number of times to run the prompt.
#[arg(long, default_value = "1")] #[arg(long, default_value = "1")]
n: usize, n: usize,
@ -60,27 +61,35 @@ impl Args {
}; };
let repo = Repo::with_revision(model_id, RepoType::Model, revision); let repo = Repo::with_revision(model_id, RepoType::Model, revision);
let (config_filename, tokenizer_filename, weights_filename) = { let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
let cache = Cache::default().repo(repo);
(
cache
.get("config.json")
.ok_or(anyhow!("Missing config file in cache"))?,
cache
.get("tokenizer.json")
.ok_or(anyhow!("Missing tokenizer file in cache"))?,
cache
.get("model.safetensors")
.ok_or(anyhow!("Missing weights file in cache"))?,
)
} else {
let api = Api::new()?; let api = Api::new()?;
let api = api.repo(repo); let api = api.repo(repo);
let config = api.get("config.json")?; (
let tokenizer = api.get("tokenizer.json")?; api.get("config.json")?,
let weights = if self.use_pth { api.get("tokenizer.json")?,
api.get("pytorch_model.bin")? api.get("model.safetensors")?,
} else { )
api.get("model.safetensors")?
};
(config, tokenizer, weights)
}; };
let config = std::fs::read_to_string(config_filename)?; let config = std::fs::read_to_string(config_filename)?;
let config: Config = serde_json::from_str(&config)?; let config: Config = serde_json::from_str(&config)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let vb = if self.use_pth { let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
VarBuilder::from_pth(&weights_filename, DTYPE, &device)? let weights = weights.deserialize()?;
} else { let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
};
let model = BertModel::load(vb, &config)?; let model = BertModel::load(vb, &config)?;
Ok((model, tokenizer)) Ok((model, tokenizer))
} }

View File

@ -1,4 +1,3 @@
use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
use candle::{DType, Device, Result, Tensor}; use candle::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module, VarBuilder}; use candle_nn::{Embedding, Module, VarBuilder};
use serde::Deserialize; use serde::Deserialize;
@ -26,13 +25,81 @@ impl HiddenActLayer {
fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> { fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
let _enter = self.span.enter(); let _enter = self.span.enter();
match self.act { match self.act {
// TODO: The all-MiniLM-L6-v2 model uses "gelu" whereas this is "gelu_new", this explains some
// small numerical difference.
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213 // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
HiddenAct::Gelu => xs.gelu_erf(), HiddenAct::Gelu => xs.gelu(),
HiddenAct::Relu => xs.relu(), HiddenAct::Relu => xs.relu(),
} }
} }
} }
#[derive(Debug)]
pub struct Linear {
weight: Tensor,
bias: Option<Tensor>,
span: tracing::Span,
}
impl Linear {
pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
let span = tracing::span!(tracing::Level::TRACE, "linear");
Self { weight, bias, span }
}
pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
let _enter = self.span.enter();
let w = match x.dims() {
&[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
_ => self.weight.t()?,
};
let x = x.matmul(&w)?;
match &self.bias {
None => Ok(x),
Some(bias) => x.broadcast_add(bias),
}
}
}
#[derive(Debug)]
pub struct LayerNorm {
weight: Tensor,
bias: Tensor,
eps: f64,
span: tracing::Span,
}
impl LayerNorm {
pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
Self {
weight,
bias,
eps,
span,
}
}
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let x_dtype = x.dtype();
let internal_dtype = match x_dtype {
DType::F16 | DType::BF16 => DType::F32,
d => d,
};
let (_bsize, _seq_len, hidden_size) = x.dims3()?;
let x = x.to_dtype(internal_dtype)?;
let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
let x = x.broadcast_sub(&mean_x)?;
let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
let x = x_normed
.to_dtype(x_dtype)?
.broadcast_mul(&self.weight)?
.broadcast_add(&self.bias)?;
Ok(x)
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
enum PositionEmbeddingType { enum PositionEmbeddingType {
@ -115,6 +182,12 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Em
Ok(Embedding::new(embeddings, hidden_size)) Ok(Embedding::new(embeddings, hidden_size))
} }
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?;
let bias = vb.get(size2, "bias")?;
Ok(Linear::new(weight, Some(bias)))
}
struct Dropout { struct Dropout {
#[allow(dead_code)] #[allow(dead_code)]
pr: f64, pr: f64,
@ -124,15 +197,27 @@ impl Dropout {
fn new(pr: f64) -> Self { fn new(pr: f64) -> Self {
Self { pr } Self { pr }
} }
}
impl Module for Dropout {
fn forward(&self, x: &Tensor) -> Result<Tensor> { fn forward(&self, x: &Tensor) -> Result<Tensor> {
// TODO // TODO
Ok(x.clone()) Ok(x.clone())
} }
} }
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
(Ok(weight), Ok(bias)) => (weight, bias),
(Err(err), _) | (_, Err(err)) => {
if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
(weight, bias)
} else {
return Err(err);
}
}
};
Ok(LayerNorm::new(weight, bias, eps))
}
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180
struct BertEmbeddings { struct BertEmbeddings {
word_embeddings: Embedding, word_embeddings: Embedding,
@ -233,9 +318,7 @@ impl BertSelfAttention {
let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?; let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
xs.contiguous() xs.contiguous()
} }
}
impl Module for BertSelfAttention {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> { fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter(); let _enter = self.span.enter();
let query_layer = self.query.forward(hidden_states)?; let query_layer = self.query.forward(hidden_states)?;
@ -310,9 +393,7 @@ impl BertAttention {
span: tracing::span!(tracing::Level::TRACE, "attn"), span: tracing::span!(tracing::Level::TRACE, "attn"),
}) })
} }
}
impl Module for BertAttention {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> { fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter(); let _enter = self.span.enter();
let self_outputs = self.self_attention.forward(hidden_states)?; let self_outputs = self.self_attention.forward(hidden_states)?;
@ -337,9 +418,7 @@ impl BertIntermediate {
span: tracing::span!(tracing::Level::TRACE, "inter"), span: tracing::span!(tracing::Level::TRACE, "inter"),
}) })
} }
}
impl Module for BertIntermediate {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> { fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter(); let _enter = self.span.enter();
let hidden_states = self.dense.forward(hidden_states)?; let hidden_states = self.dense.forward(hidden_states)?;
@ -401,9 +480,7 @@ impl BertLayer {
span: tracing::span!(tracing::Level::TRACE, "layer"), span: tracing::span!(tracing::Level::TRACE, "layer"),
}) })
} }
}
impl Module for BertLayer {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> { fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter(); let _enter = self.span.enter();
let attention_output = self.attention.forward(hidden_states)?; let attention_output = self.attention.forward(hidden_states)?;
@ -432,9 +509,7 @@ impl BertEncoder {
let span = tracing::span!(tracing::Level::TRACE, "encoder"); let span = tracing::span!(tracing::Level::TRACE, "encoder");
Ok(BertEncoder { layers, span }) Ok(BertEncoder { layers, span })
} }
}
impl Module for BertEncoder {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> { fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter(); let _enter = self.span.enter();
let mut hidden_states = hidden_states.clone(); let mut hidden_states = hidden_states.clone();

View File

@ -1,19 +0,0 @@
# candle-starcoder: code generation model
[StarCoder/BigCode](https://huggingface.co/bigcode/starcoderbase-1b) is a LLM
model specialized to code generation. The initial model was trained on 80
programming languages.
## Running some example
```bash
cargo run --example bigcode --release -- --prompt "fn fact(n: u64) -> u64 "
> fn fact(n: u64) -> u64 {
> if n == 0 {
> 1
> } else {
> n * fact(n - 1)
> }
> }
```

View File

@ -7,7 +7,8 @@ extern crate accelerate_src;
use anyhow::{Error as E, Result}; use anyhow::{Error as E, Result};
use clap::Parser; use clap::Parser;
use candle_transformers::models::bigcode::{Config, GPTBigCode}; mod model;
use model::{Config, GPTBigCode};
use candle::{DType, Device, Tensor}; use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder; use candle_nn::VarBuilder;
@ -28,10 +29,9 @@ impl TextGeneration {
tokenizer: Tokenizer, tokenizer: Tokenizer,
seed: u64, seed: u64,
temp: Option<f64>, temp: Option<f64>,
top_p: Option<f64>,
device: &Device, device: &Device,
) -> Self { ) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p); let logits_processor = LogitsProcessor::new(seed, temp);
Self { Self {
model, model,
tokenizer, tokenizer,
@ -95,10 +95,6 @@ struct Args {
#[arg(long)] #[arg(long)]
temperature: Option<f64>, temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples. /// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)] #[arg(long, default_value_t = 299792458)]
seed: u64, seed: u64,
@ -138,21 +134,23 @@ fn main() -> Result<()> {
println!("retrieved the files in {:?}", start.elapsed()); println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let weights = filenames
.iter()
.map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? }))
.collect::<Result<Vec<_>>>()?;
let weights = weights
.iter()
.map(|f| Ok(f.deserialize()?))
.collect::<Result<Vec<_>>>()?;
let start = std::time::Instant::now(); let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?; let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? }; let vb = VarBuilder::from_safetensors(weights, DType::F32, &device);
let config = Config::starcoder_1b(); let config = Config::starcoder_1b();
let model = GPTBigCode::load(vb, config)?; let model = GPTBigCode::load(vb, config)?;
println!("loaded the model in {:?}", start.elapsed()); println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new( let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, args.temperature, &device);
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?; pipeline.run(&args.prompt, args.sample_len)?;
Ok(()) Ok(())
} }

View File

@ -182,7 +182,7 @@ impl Attention {
let mask_value = let mask_value =
Tensor::new(f32::NEG_INFINITY, query.device())?.broadcast_as(attn_shape)?; Tensor::new(f32::NEG_INFINITY, query.device())?.broadcast_as(attn_shape)?;
let attn_weights = attention_mask.where_cond(&attn_weights, &mask_value)?; let attn_weights = attention_mask.where_cond(&attn_weights, &mask_value)?;
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
let value = value.contiguous()?; let value = value.contiguous()?;
let attn_output = if self.multi_query { let attn_output = if self.multi_query {
attn_weights attn_weights

View File

@ -1,19 +0,0 @@
# candle-blip
The
[blip-image-captioning](https://huggingface.co/Salesforce/blip-image-captioning-base)
model can generate captions for an input image.
## Running on an example
```bash
cargo run --example blip --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
```
```
Running on CPU, to run on GPU, build this example with `--features cuda`
loaded image Tensor[dims 3, 384, 384; f32]
model built
several cyclists are riding down a road with cars behind them%
```
![Leading group, Giro d'Italia 2021](../yolo-v8/assets/bike.jpg)

View File

@ -1,154 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Error as E;
use clap::Parser;
use candle::{DType, Device, Result, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::models::blip;
use candle_transformers::models::quantized_blip;
use tokenizers::Tokenizer;
enum Model {
M(blip::BlipForConditionalGeneration),
Q(quantized_blip::BlipForConditionalGeneration),
}
impl Model {
fn text_decoder_forward(&mut self, xs: &Tensor, img_xs: &Tensor) -> Result<Tensor> {
match self {
Self::M(m) => m.text_decoder().forward(xs, img_xs),
Self::Q(m) => m.text_decoder().forward(xs, img_xs),
}
}
}
// TODO: Maybe add support for the conditional prompt.
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
tokenizer: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Use the quantized version of the model.
#[arg(long)]
quantized: bool,
}
const SEP_TOKEN_ID: u32 = 102;
/// Loads an image from disk using the image crate, this returns a tensor with shape
/// (3, 384, 384). OpenAI normalization is applied.
pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
let img = image::io::Reader::open(p)?
.decode()
.map_err(candle::Error::wrap)?
.resize_to_fill(384, 384, image::imageops::FilterType::Triangle);
let img = img.to_rgb8();
let data = img.into_raw();
let data = Tensor::from_vec(data, (384, 384, 3), &Device::Cpu)?.permute((2, 0, 1))?;
let mean =
Tensor::new(&[0.48145466f32, 0.4578275, 0.40821073], &Device::Cpu)?.reshape((3, 1, 1))?;
let std = Tensor::new(&[0.26862954f32, 0.261_302_6, 0.275_777_1], &Device::Cpu)?
.reshape((3, 1, 1))?;
(data.to_dtype(candle::DType::F32)? / 255.)?
.broadcast_sub(&mean)?
.broadcast_div(&std)
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let model_file = match args.model {
None => {
let api = hf_hub::api::sync::Api::new()?;
if args.quantized {
let api = api.model("lmz/candle-blip".to_string());
api.get("blip-image-captioning-large-q4k.gguf")?
} else {
let api = api.repo(hf_hub::Repo::with_revision(
"Salesforce/blip-image-captioning-large".to_string(),
hf_hub::RepoType::Model,
"refs/pr/18".to_string(),
));
api.get("model.safetensors")?
}
}
Some(model) => model.into(),
};
let tokenizer = match args.tokenizer {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("Salesforce/blip-image-captioning-large".to_string());
api.get("tokenizer.json")?
}
Some(file) => file.into(),
};
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let mut tokenizer = TokenOutputStream::new(tokenizer);
let mut logits_processor =
candle_transformers::generation::LogitsProcessor::new(1337, None, None);
let config = blip::Config::image_captioning_large();
let (image_embeds, device, mut model) = if args.quantized {
let device = Device::Cpu;
let image = load_image(args.image)?.to_device(&device)?;
println!("loaded image {image:?}");
let vb = quantized_blip::VarBuilder::from_gguf(model_file)?;
let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
(image_embeds, device, Model::Q(model))
} else {
let device = candle_examples::device(args.cpu)?;
let image = load_image(args.image)?.to_device(&device)?;
println!("loaded image {image:?}");
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = blip::BlipForConditionalGeneration::new(&config, vb)?;
let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
(image_embeds, device, Model::M(model))
};
let mut token_ids = vec![30522u32];
for index in 0..1000 {
let context_size = if index > 0 { 1 } else { token_ids.len() };
let start_pos = token_ids.len().saturating_sub(context_size);
let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
let logits = model.text_decoder_forward(&input_ids, &image_embeds)?;
let logits = logits.squeeze(0)?;
let logits = logits.get(logits.dim(0)? - 1)?;
let token = logits_processor.sample(&logits)?;
if token == SEP_TOKEN_ID {
break;
}
token_ids.push(token);
if let Some(t) = tokenizer.next_token(token)? {
use std::io::Write;
print!("{t}");
std::io::stdout().flush()?;
}
}
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
println!();
Ok(())
}

View File

@ -1,59 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::Parser;
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::convmixer;
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?;
println!("loaded image {image:?}");
let model_file = match args.model {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("lmz/candle-convmixer".into());
api.get("convmixer_1024_20_ks9_p14.safetensors")?
}
Some(model) => model.into(),
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = convmixer::c1024_20(1000, vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for &(category_idx, pr) in prs.iter().take(5) {
println!(
"{:24}: {:.2}%",
candle_examples::imagenet::CLASSES[category_idx],
100. * pr
);
}
Ok(())
}

View File

@ -1,19 +0,0 @@
# candle-dinov2
[DINOv2](https://github.com/facebookresearch/dinov2) is a computer vision model.
In this example, it is used as an ImageNet classifier: the model returns the
probability for the image to belong to each of the 1000 ImageNet categories.
## Running some example
```bash
cargo run --example dinov2 --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
> mountain bike, all-terrain bike, off-roader: 43.67%
> bicycle-built-for-two, tandem bicycle, tandem: 33.20%
> crash helmet : 13.23%
> unicycle, monocycle : 2.44%
> maillot : 2.42%
```
![Leading group, Giro d'Italia 2021](../yolo-v8/assets/bike.jpg)

View File

@ -9,10 +9,285 @@ extern crate accelerate_src;
use clap::Parser; use clap::Parser;
use candle::{DType, IndexOp, D}; use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{Module, VarBuilder}; use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
use candle_transformers::models::dinov2;
const IMG_SIZE: usize = 518;
const PATCH_SIZE: usize = 14;
const NUM_CLASSES: usize = 1000;
fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
if bias {
candle_nn::linear(in_dim, out_dim, vb)
} else {
candle_nn::linear_no_bias(in_dim, out_dim, vb)
}
}
#[derive(Debug)]
struct Attention {
qkv: Linear,
proj: Linear,
num_heads: usize,
scale: f64,
}
impl Attention {
fn new(
vb: VarBuilder,
dim: usize,
num_heads: usize,
qkv_bias: bool,
proj_bias: bool,
) -> Result<Self> {
let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?;
let scale = 1. / ((dim / num_heads) as f64).sqrt();
Ok(Self {
qkv,
proj,
num_heads,
scale,
})
}
}
impl Module for Attention {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (b, n, c) = xs.dims3()?;
let qkv = self
.qkv
.forward(xs)?
.reshape((b, n, 3, self.num_heads, c / self.num_heads))?
.transpose(1, 2)? // 02134
.transpose(0, 1)? // 20134
.transpose(2, 3)?; // 20314
let q = (qkv.i(0)? * self.scale)?;
let k = qkv.i(1)?;
let v = qkv.i(2)?;
let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?;
let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?;
self.proj.forward(&attn)
}
}
#[derive(Debug)]
struct LayerScale {
gamma: Tensor,
}
impl LayerScale {
fn new(vb: VarBuilder, dim: usize) -> Result<Self> {
let gamma = vb.get(dim, "gamma")?;
Ok(Self { gamma })
}
}
impl Module for LayerScale {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.broadcast_mul(&self.gamma)
}
}
#[derive(Debug)]
struct Mlp {
fc1: Linear,
fc2: Linear,
}
impl Mlp {
fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result<Self> {
let out_features = in_features;
let fc1 = linear(vb.pp("fc1"), in_features, hidden_features, bias)?;
let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?;
Ok(Self { fc1, fc2 })
}
}
impl Module for Mlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.fc1.forward(xs)?.gelu()?;
self.fc2.forward(&xs)
}
}
#[derive(Debug)]
struct Block {
norm1: LayerNorm,
attn: Attention,
ls1: LayerScale,
norm2: LayerNorm,
mlp: Mlp,
ls2: LayerScale,
}
impl Block {
fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> {
let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?;
let ls1 = LayerScale::new(vb.pp("ls1"), dim)?;
let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?;
let ls2 = LayerScale::new(vb.pp("ls2"), dim)?;
Ok(Self {
norm1,
attn,
ls1,
norm2,
mlp,
ls2,
})
}
}
impl Module for Block {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let residual = xs;
let xs = self
.ls1
.forward(&self.attn.forward(&self.norm1.forward(xs)?)?)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = self
.ls2
.forward(&self.mlp.forward(&self.norm2.forward(&xs)?)?)?;
xs + residual
}
}
#[derive(Debug)]
struct PatchEmbed {
proj: candle_nn::Conv2d,
patch_size: (usize, usize),
num_patches: usize,
}
impl PatchEmbed {
fn new(
vb: VarBuilder,
img_size: usize,
patch_size: usize,
in_chans: usize,
embed_dim: usize,
) -> Result<Self> {
let config = candle_nn::Conv2dConfig {
stride: patch_size,
..Default::default()
};
let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?;
let num_patches = (img_size / patch_size) * (img_size / patch_size);
Ok(Self {
proj,
patch_size: (patch_size, patch_size),
num_patches,
})
}
}
impl Module for PatchEmbed {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (_b, _c, h, w) = xs.dims4()?;
let (patch_h, patch_w) = self.patch_size;
if (h % patch_h) != 0 {
candle::bail!("image height {h} is not a multiple of patch height {patch_h}")
}
if (w % patch_w) != 0 {
candle::bail!("image width {w} is not a multiple of patch width {patch_w}")
}
let xs = self.proj.forward(xs)?;
let (b, c, h, w) = xs.dims4()?;
// flatten embeddings.
xs.reshape((b, c, h * w))?.transpose(1, 2)
}
}
#[derive(Debug)]
pub struct DinoVisionTransformer {
patch_embed: PatchEmbed,
cls_token: Tensor,
pos_embed: Tensor,
blocks: Vec<Block>,
norm: LayerNorm,
head: Linear,
}
impl DinoVisionTransformer {
pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result<Self> {
let patch_embed =
PatchEmbed::new(vb.pp("patch_embed"), IMG_SIZE, PATCH_SIZE, 3, embed_dim)?;
let cls_token = vb.get((1, 1, embed_dim), "cls_token")?;
let num_tokens = 1;
let pos_embed = vb.get(
(1, patch_embed.num_patches + num_tokens, embed_dim),
"pos_embed",
)?;
let head = linear(vb.pp("head"), 2 * embed_dim, NUM_CLASSES, true)?;
let norm = layer_norm(embed_dim, 1e-5, vb.pp("norm"))?;
let vb_b = vb.pp("blocks");
let blocks = (0..depth)
.map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
.collect::<Result<Vec<_>>>()?;
Ok(Self {
patch_embed,
cls_token,
pos_embed,
blocks,
norm,
head,
})
}
fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> {
let npatch = xs.dim(1)? - 1;
let n = self.pos_embed.dim(1)? - 1;
let sqrt_n = (n as f64).sqrt();
if npatch == n && w == h {
return Ok(xs.clone());
}
let class_pos_embed = self.pos_embed.i((.., ..1))?;
let patch_pos_embed = self.pos_embed.i((.., 1..))?;
let dim = xs.dim(D::Minus1)?;
let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1);
let patch_pos_embed = patch_pos_embed
.reshape((1, sqrt_n as usize, sqrt_n as usize, dim))?
.transpose(2, 3)?
.transpose(1, 2)?;
// This uses bicubic interpolation in the original implementation.
let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?;
let el_count = patch_pos_embed.shape().elem_count();
let patch_pos_embed =
patch_pos_embed
.transpose(1, 2)?
.transpose(2, 3)?
.reshape((1, el_count / dim, dim))?;
Tensor::cat(&[&class_pos_embed, &patch_pos_embed], 1)
}
fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> {
let (_b, _nc, w, h) = xs.dims4()?;
let xs = self.patch_embed.forward(xs)?;
let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
&xs + &self.interpolate_pos_encoding(&xs, w, h)?
}
}
impl Module for DinoVisionTransformer {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = self.prepare_tokens_with_mask(xs)?;
for blk in self.blocks.iter() {
xs = blk.forward(&xs)?
}
let xs = self.norm.forward(&xs)?;
let xs_norm_clstoken = xs.i((.., 0))?;
let xs_norm_patchtokens = xs.i((.., 1..))?.mean(1)?;
let xs = Tensor::cat(&[xs_norm_clstoken, xs_norm_patchtokens], D::Minus1)?;
self.head.forward(&xs)
}
}
pub fn vit_small(vb: VarBuilder) -> Result<DinoVisionTransformer> {
DinoVisionTransformer::new(vb, 12, 384, 6)
}
#[derive(Parser)] #[derive(Parser)]
struct Args { struct Args {
#[arg(long)] #[arg(long)]
@ -42,8 +317,10 @@ pub fn main() -> anyhow::Result<()> {
} }
Some(model) => model.into(), Some(model) => model.into(),
}; };
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? }; let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
let model = dinov2::vit_small(vb)?; let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
let model = vit_small(vb)?;
println!("model built"); println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?; let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)? let prs = candle_nn::ops::softmax(&logits, D::Minus1)?

View File

@ -1,22 +0,0 @@
# candle-distilbert
DistilBert is a distiled version of the Bert model.
## Sentence embeddings
DistilBert is used to compute the sentence embeddings for a prompt. The model weights
are downloaded from the hub on the first run.
```bash
cargo run --example distilbert --release -- --prompt "Here is a test sentence"
> [[[ 0.5109, 0.1280, -0.2635, ..., 0.3462, -1.0434, 0.1441],
> [ 0.1735, 0.0818, -0.5549, ..., 0.3472, -0.8264, -0.0244],
> [ 0.0702, -0.1311, -0.4914, ..., 0.3483, -0.6194, 0.1829],
> ...
> [ 0.2993, -0.0106, -0.4640, ..., 0.2844, -0.6732, 0.0042],
> [ 0.1066, -0.0081, -0.4299, ..., 0.3435, -0.7729, 0.0190],
> [ 0.8903, 0.2055, -0.2541, ..., 0.3208, -0.6585, 0.0586]]]
> Tensor[[1, 7, 768], f32]
```

View File

@ -1,135 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle_transformers::models::distilbert::{Config, DistilBertModel, DTYPE};
use anyhow::{Error as E, Result};
use candle::{Device, Tensor};
use candle_nn::VarBuilder;
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
/// When set, compute embeddings for this prompt.
#[arg(long)]
prompt: String,
/// Use the pytorch weights rather than the safetensors ones
#[arg(long)]
use_pth: bool,
/// The number of times to run the prompt.
#[arg(long, default_value = "1")]
n: usize,
/// L2 normalization for embeddings.
#[arg(long, default_value = "true")]
normalize_embeddings: bool,
}
impl Args {
fn build_model_and_tokenizer(&self) -> Result<(DistilBertModel, Tokenizer)> {
let device = candle_examples::device(self.cpu)?;
let default_model = "distilbert-base-uncased".to_string();
let default_revision = "main".to_string();
let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
(Some(model_id), Some(revision)) => (model_id, revision),
(Some(model_id), None) => (model_id, "main".to_string()),
(None, Some(revision)) => (default_model, revision),
(None, None) => (default_model, default_revision),
};
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
let (config_filename, tokenizer_filename, weights_filename) = {
let api = Api::new()?;
let api = api.repo(repo);
let config = api.get("config.json")?;
let tokenizer = api.get("tokenizer.json")?;
let weights = if self.use_pth {
api.get("pytorch_model.bin")?
} else {
api.get("model.safetensors")?
};
(config, tokenizer, weights)
};
let config = std::fs::read_to_string(config_filename)?;
let config: Config = serde_json::from_str(&config)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let vb = if self.use_pth {
VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
} else {
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
};
let model = DistilBertModel::load(vb, &config)?;
Ok((model, tokenizer))
}
}
fn get_mask(size: usize, device: &Device) -> Tensor {
let mask: Vec<_> = (0..size)
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
.collect();
Tensor::from_slice(&mask, (size, size), device).unwrap()
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
println!("tracing...");
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
let device = &model.device;
let tokenizer = tokenizer
.with_padding(None)
.with_truncation(None)
.map_err(E::msg)?;
let tokens = tokenizer
.encode(args.prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
let mask = get_mask(tokens.len(), device);
println!("token_ids: {:?}", token_ids.to_vec2::<u32>());
println!("mask: {:?}", mask.to_vec2::<u8>());
let ys = model.forward(&token_ids, &mask)?;
println!("{ys}");
Ok(())
}
pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
}

View File

@ -8,11 +8,340 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")] #[cfg(feature = "accelerate")]
extern crate accelerate_src; extern crate accelerate_src;
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::efficientnet::{EfficientNet, MBConvConfig};
use clap::{Parser, ValueEnum}; use clap::{Parser, ValueEnum};
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn as nn;
use nn::{Module, VarBuilder};
// Based on the Python version from torchvision.
// https://github.com/pytorch/vision/blob/0d75d9e5516f446c9c0ef93bd4ed9fea13992d06/torchvision/models/efficientnet.py#L47
#[derive(Debug, Clone, Copy)]
pub struct MBConvConfig {
expand_ratio: f64,
kernel: usize,
stride: usize,
input_channels: usize,
out_channels: usize,
num_layers: usize,
}
fn make_divisible(v: f64, divisor: usize) -> usize {
let min_value = divisor;
let new_v = usize::max(
min_value,
(v + divisor as f64 * 0.5) as usize / divisor * divisor,
);
if (new_v as f64) < 0.9 * v {
new_v + divisor
} else {
new_v
}
}
fn bneck_confs(width_mult: f64, depth_mult: f64) -> Vec<MBConvConfig> {
let bneck_conf = |e, k, s, i, o, n| {
let input_channels = make_divisible(i as f64 * width_mult, 8);
let out_channels = make_divisible(o as f64 * width_mult, 8);
let num_layers = (n as f64 * depth_mult).ceil() as usize;
MBConvConfig {
expand_ratio: e,
kernel: k,
stride: s,
input_channels,
out_channels,
num_layers,
}
};
vec![
bneck_conf(1., 3, 1, 32, 16, 1),
bneck_conf(6., 3, 2, 16, 24, 2),
bneck_conf(6., 5, 2, 24, 40, 2),
bneck_conf(6., 3, 2, 40, 80, 3),
bneck_conf(6., 5, 1, 80, 112, 3),
bneck_conf(6., 5, 2, 112, 192, 4),
bneck_conf(6., 3, 1, 192, 320, 1),
]
}
impl MBConvConfig {
fn b0() -> Vec<Self> {
bneck_confs(1.0, 1.0)
}
fn b1() -> Vec<Self> {
bneck_confs(1.0, 1.1)
}
fn b2() -> Vec<Self> {
bneck_confs(1.1, 1.2)
}
fn b3() -> Vec<Self> {
bneck_confs(1.2, 1.4)
}
fn b4() -> Vec<Self> {
bneck_confs(1.4, 1.8)
}
fn b5() -> Vec<Self> {
bneck_confs(1.6, 2.2)
}
fn b6() -> Vec<Self> {
bneck_confs(1.8, 2.6)
}
fn b7() -> Vec<Self> {
bneck_confs(2.0, 3.1)
}
}
/// Conv2D with same padding.
#[derive(Debug)]
struct Conv2DSame {
conv2d: nn::Conv2d,
s: usize,
k: usize,
}
impl Conv2DSame {
fn new(
vb: VarBuilder,
i: usize,
o: usize,
k: usize,
stride: usize,
groups: usize,
bias: bool,
) -> Result<Self> {
let conv_config = nn::Conv2dConfig {
stride,
groups,
..Default::default()
};
let conv2d = if bias {
nn::conv2d(i, o, k, conv_config, vb)?
} else {
nn::conv2d_no_bias(i, o, k, conv_config, vb)?
};
Ok(Self {
conv2d,
s: stride,
k,
})
}
}
impl Module for Conv2DSame {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let s = self.s;
let k = self.k;
let (_, _, ih, iw) = xs.dims4()?;
let oh = (ih + s - 1) / s;
let ow = (iw + s - 1) / s;
let pad_h = usize::max((oh - 1) * s + k - ih, 0);
let pad_w = usize::max((ow - 1) * s + k - iw, 0);
if pad_h > 0 || pad_w > 0 {
let xs = xs.pad_with_zeros(2, pad_h / 2, pad_h - pad_h / 2)?;
let xs = xs.pad_with_zeros(3, pad_w / 2, pad_w - pad_w / 2)?;
self.conv2d.forward(&xs)
} else {
self.conv2d.forward(xs)
}
}
}
#[derive(Debug)]
struct ConvNormActivation {
conv2d: Conv2DSame,
bn2d: nn::BatchNorm,
activation: bool,
}
impl ConvNormActivation {
fn new(
vb: VarBuilder,
i: usize,
o: usize,
k: usize,
stride: usize,
groups: usize,
) -> Result<Self> {
let conv2d = Conv2DSame::new(vb.pp("0"), i, o, k, stride, groups, false)?;
let bn2d = nn::batch_norm(o, 1e-3, vb.pp("1"))?;
Ok(Self {
conv2d,
bn2d,
activation: true,
})
}
fn no_activation(self) -> Self {
Self {
activation: false,
..self
}
}
}
impl Module for ConvNormActivation {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.conv2d.forward(xs)?;
let xs = self.bn2d.forward(&xs)?;
if self.activation {
swish(&xs)
} else {
Ok(xs)
}
}
}
#[derive(Debug)]
struct SqueezeExcitation {
fc1: Conv2DSame,
fc2: Conv2DSame,
}
impl SqueezeExcitation {
fn new(vb: VarBuilder, in_channels: usize, squeeze_channels: usize) -> Result<Self> {
let fc1 = Conv2DSame::new(vb.pp("fc1"), in_channels, squeeze_channels, 1, 1, 1, true)?;
let fc2 = Conv2DSame::new(vb.pp("fc2"), squeeze_channels, in_channels, 1, 1, 1, true)?;
Ok(Self { fc1, fc2 })
}
}
impl Module for SqueezeExcitation {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let residual = xs;
// equivalent to adaptive_avg_pool2d([1, 1])
let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
let xs = self.fc1.forward(&xs)?;
let xs = swish(&xs)?;
let xs = self.fc2.forward(&xs)?;
let xs = nn::ops::sigmoid(&xs)?;
residual.broadcast_mul(&xs)
}
}
#[derive(Debug)]
struct MBConv {
expand_cna: Option<ConvNormActivation>,
depthwise_cna: ConvNormActivation,
squeeze_excitation: SqueezeExcitation,
project_cna: ConvNormActivation,
config: MBConvConfig,
}
impl MBConv {
fn new(vb: VarBuilder, c: MBConvConfig) -> Result<Self> {
let vb = vb.pp("block");
let exp = make_divisible(c.input_channels as f64 * c.expand_ratio, 8);
let expand_cna = if exp != c.input_channels {
Some(ConvNormActivation::new(
vb.pp("0"),
c.input_channels,
exp,
1,
1,
1,
)?)
} else {
None
};
let start_index = if expand_cna.is_some() { 1 } else { 0 };
let depthwise_cna =
ConvNormActivation::new(vb.pp(start_index), exp, exp, c.kernel, c.stride, exp)?;
let squeeze_channels = usize::max(1, c.input_channels / 4);
let squeeze_excitation =
SqueezeExcitation::new(vb.pp(start_index + 1), exp, squeeze_channels)?;
let project_cna =
ConvNormActivation::new(vb.pp(start_index + 2), exp, c.out_channels, 1, 1, 1)?
.no_activation();
Ok(Self {
expand_cna,
depthwise_cna,
squeeze_excitation,
project_cna,
config: c,
})
}
}
impl Module for MBConv {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let use_res_connect =
self.config.stride == 1 && self.config.input_channels == self.config.out_channels;
let ys = match &self.expand_cna {
Some(expand_cna) => expand_cna.forward(xs)?,
None => xs.clone(),
};
let ys = self.depthwise_cna.forward(&ys)?;
let ys = self.squeeze_excitation.forward(&ys)?;
let ys = self.project_cna.forward(&ys)?;
if use_res_connect {
ys + xs
} else {
Ok(ys)
}
}
}
fn swish(s: &Tensor) -> Result<Tensor> {
s * nn::ops::sigmoid(s)?
}
#[derive(Debug)]
struct EfficientNet {
init_cna: ConvNormActivation,
blocks: Vec<MBConv>,
final_cna: ConvNormActivation,
classifier: nn::Linear,
}
impl EfficientNet {
fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> {
let f_p = p.pp("features");
let first_in_c = configs[0].input_channels;
let last_out_c = configs.last().unwrap().out_channels;
let final_out_c = 4 * last_out_c;
let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?;
let nconfigs = configs.len();
let mut blocks = vec![];
for (index, cnf) in configs.into_iter().enumerate() {
let f_p = f_p.pp(index + 1);
for r_index in 0..cnf.num_layers {
let cnf = if r_index == 0 {
cnf
} else {
MBConvConfig {
input_channels: cnf.out_channels,
stride: 1,
..cnf
}
};
blocks.push(MBConv::new(f_p.pp(r_index), cnf)?)
}
}
let final_cna =
ConvNormActivation::new(f_p.pp(nconfigs + 1), last_out_c, final_out_c, 1, 1, 1)?;
let classifier = nn::linear(final_out_c, nclasses, p.pp("classifier.1"))?;
Ok(Self {
init_cna,
blocks,
final_cna,
classifier,
})
}
}
impl Module for EfficientNet {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = self.init_cna.forward(xs)?;
for block in self.blocks.iter() {
xs = block.forward(&xs)?
}
let xs = self.final_cna.forward(&xs)?;
// Equivalent to adaptive_avg_pool2d([1, 1]) -> squeeze(-1) -> squeeze(-1)
let xs = xs.mean(D::Minus1)?.mean(D::Minus1)?;
self.classifier.forward(&xs)
}
}
#[derive(Clone, Copy, Debug, ValueEnum)] #[derive(Clone, Copy, Debug, ValueEnum)]
enum Which { enum Which {
B0, B0,
@ -68,7 +397,9 @@ pub fn main() -> anyhow::Result<()> {
} }
Some(model) => model.into(), Some(model) => model.into(),
}; };
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? }; let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
let cfg = match args.which { let cfg = match args.which {
Which::B0 => MBConvConfig::b0(), Which::B0 => MBConvConfig::b0(),
Which::B1 => MBConvConfig::b1(), Which::B1 => MBConvConfig::b1(),

View File

@ -1,3 +0,0 @@
# candle-falcon
Falcon is a general large language model.

View File

@ -14,7 +14,8 @@ use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType}; use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use candle_transformers::models::falcon::{Config, Falcon}; mod model;
use model::{Config, Falcon};
struct TextGeneration { struct TextGeneration {
model: Falcon, model: Falcon,
@ -25,25 +26,17 @@ struct TextGeneration {
repeat_last_n: usize, repeat_last_n: usize,
} }
struct GenerationOptions {
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration { impl TextGeneration {
fn new( fn new(
model: Falcon, model: Falcon,
tokenizer: Tokenizer, tokenizer: Tokenizer,
generation_options: GenerationOptions,
seed: u64, seed: u64,
temp: Option<f64>,
device: &Device, device: &Device,
repeat_penalty: f32,
repeat_last_n: usize,
) -> Self { ) -> Self {
let logits_processor = let logits_processor = LogitsProcessor::new(seed, temp);
LogitsProcessor::new(seed, generation_options.temp, generation_options.top_p);
let repeat_penalty = generation_options.repeat_penalty;
let repeat_last_n = generation_options.repeat_last_n;
Self { Self {
model, model,
tokenizer, tokenizer,
@ -126,10 +119,6 @@ struct Args {
#[arg(long)] #[arg(long)]
temperature: Option<f64>, temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples. /// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)] #[arg(long, default_value_t = 299792458)]
seed: u64, seed: u64,
@ -177,25 +166,35 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now(); let start = std::time::Instant::now();
let weights = filenames
.iter()
.map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? }))
.collect::<Result<Vec<_>>>()?;
let weights = weights
.iter()
.map(|f| Ok(f.deserialize()?))
.collect::<Result<Vec<_>>>()?;
let dtype = if args.use_f32 { let dtype = if args.use_f32 {
DType::F32 DType::F32
} else { } else {
DType::BF16 DType::BF16
}; };
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; let vb = VarBuilder::from_safetensors(weights, dtype, &device);
let config = Config::falcon7b(); let config = Config::falcon7b();
config.validate()?; config.validate()?;
let model = Falcon::load(vb, config)?; let model = Falcon::load(vb, config)?;
println!("loaded the model in {:?}", start.elapsed()); println!("loaded the model in {:?}", start.elapsed());
let generation_options = GenerationOptions { let mut pipeline = TextGeneration::new(
temp: args.temperature, model,
top_p: args.top_p, tokenizer,
repeat_penalty: args.repeat_penalty, args.seed,
repeat_last_n: args.repeat_last_n, args.temperature,
}; &device,
let mut pipeline = args.repeat_penalty,
TextGeneration::new(model, tokenizer, generation_options, args.seed, &device); args.repeat_last_n,
);
pipeline.run(&args.prompt, args.sample_len)?; pipeline.run(&args.prompt, args.sample_len)?;
Ok(()) Ok(())
} }

View File

@ -1,4 +1,5 @@
use candle::{DType, Device, Result, Tensor, D}; use anyhow::Result;
use candle::{DType, Device, Tensor, D};
use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder}; use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
const MAX_SEQ_LEN: usize = 5000; const MAX_SEQ_LEN: usize = 5000;
@ -20,7 +21,7 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) { if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
(weight, bias) (weight, bias)
} else { } else {
return Err(err); return Err(err.into());
} }
} }
}; };
@ -81,13 +82,13 @@ impl Default for Config {
impl Config { impl Config {
pub fn validate(&self) -> Result<()> { pub fn validate(&self) -> Result<()> {
if self.alibi { if self.alibi {
candle::bail!("alibi is not supported"); anyhow::bail!("alibi is not supported");
} }
if self.new_decoder_architecture { if self.new_decoder_architecture {
candle::bail!("new_decoder_architecture is not supported"); anyhow::bail!("new_decoder_architecture is not supported");
} }
if self.n_head_kv.is_some() { if self.n_head_kv.is_some() {
candle::bail!("n_head_kv is not supported"); anyhow::bail!("n_head_kv is not supported");
} }
Ok(()) Ok(())
} }

View File

@ -1,45 +0,0 @@
# candle-jina-bert
Jina-Bert is a general large language model with a context size of 8192, [model
card](https://huggingface.co/jinaai/jina-embeddings-v2-base-en). In this example
it can be used for two different tasks:
- Compute sentence embeddings for a prompt.
- Compute similarities between a set of sentences.
## Sentence embeddings
Jina-Bert is used to compute the sentence embeddings for a prompt. The model weights
are downloaded from the hub on the first run.
```bash
cargo run --example jina-bert --release -- --prompt "Here is a test sentence"
> [[[ 0.1595, -0.9885, 0.6494, ..., 0.3003, -0.6901, -1.2355],
> [ 0.0374, -0.1798, 1.3359, ..., 0.6731, 0.2133, -1.6807],
> [ 0.1700, -0.8534, 0.8924, ..., -0.1785, -0.0727, -1.5087],
> ...
> [-0.3113, -1.3665, 0.2027, ..., -0.2519, 0.1711, -1.5811],
> [ 0.0907, -1.0492, 0.5382, ..., 0.0242, -0.7077, -1.0830],
> [ 0.0369, -0.6343, 0.6105, ..., 0.0671, 0.3778, -1.1505]]]
> Tensor[[1, 7, 768], f32]
```
## Similarities
In this example, Jina-Bert is used to compute the sentence embeddings for a set of
sentences (hardcoded in the examples). Then cosine similarities are computed for
each sentence pair and they are reported by decreasing values, hence the first
reported pair contains the two sentences that have the highest similarity score.
The sentence embeddings are computed using average pooling through all the
sentence tokens, including some potential padding.
```bash
cargo run --example jina-bert --release
> score: 0.94 'The new movie is awesome' 'The new movie is so great'
> score: 0.81 'The cat sits outside' 'The cat plays in the garden'
> score: 0.78 'I love pasta' 'Do you like pizza?'
> score: 0.68 'I love pasta' 'The new movie is awesome'
> score: 0.67 'A man is playing guitar' 'A woman watches TV'
```

View File

@ -1,180 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle_transformers::models::jina_bert::{BertModel, Config};
use anyhow::Error as E;
use candle::{DType, Module, Tensor};
use candle_nn::VarBuilder;
use clap::Parser;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// When set, compute embeddings for this prompt.
#[arg(long)]
prompt: Option<String>,
/// The number of times to run the prompt.
#[arg(long, default_value = "1")]
n: usize,
/// L2 normalization for embeddings.
#[arg(long, default_value = "true")]
normalize_embeddings: bool,
#[arg(long)]
tokenizer: Option<String>,
#[arg(long)]
model: Option<String>,
}
impl Args {
fn build_model_and_tokenizer(&self) -> anyhow::Result<(BertModel, tokenizers::Tokenizer)> {
use hf_hub::{api::sync::Api, Repo, RepoType};
let model = match &self.model {
Some(model_file) => std::path::PathBuf::from(model_file),
None => Api::new()?
.repo(Repo::new(
"jinaai/jina-embeddings-v2-base-en".to_string(),
RepoType::Model,
))
.get("model.safetensors")?,
};
let tokenizer = match &self.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => Api::new()?
.repo(Repo::new(
"sentence-transformers/all-MiniLM-L6-v2".to_string(),
RepoType::Model,
))
.get("tokenizer.json")?,
};
let device = candle_examples::device(self.cpu)?;
let config = Config::v2_base();
let tokenizer = tokenizers::Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
let model = BertModel::new(vb, &config)?;
Ok((model, tokenizer))
}
}
fn main() -> anyhow::Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
println!("tracing...");
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let start = std::time::Instant::now();
let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
let device = &model.device;
if let Some(prompt) = args.prompt {
let tokenizer = tokenizer
.with_padding(None)
.with_truncation(None)
.map_err(E::msg)?;
let tokens = tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
println!("Loaded and encoded {:?}", start.elapsed());
for idx in 0..args.n {
let start = std::time::Instant::now();
let ys = model.forward(&token_ids)?;
if idx == 0 {
println!("{ys}");
}
println!("Took {:?}", start.elapsed());
}
} else {
let sentences = [
"The cat sits outside",
"A man is playing guitar",
"I love pasta",
"The new movie is awesome",
"The cat plays in the garden",
"A woman watches TV",
"The new movie is so great",
"Do you like pizza?",
];
let n_sentences = sentences.len();
if let Some(pp) = tokenizer.get_padding_mut() {
pp.strategy = tokenizers::PaddingStrategy::BatchLongest
} else {
let pp = tokenizers::PaddingParams {
strategy: tokenizers::PaddingStrategy::BatchLongest,
..Default::default()
};
tokenizer.with_padding(Some(pp));
}
let tokens = tokenizer
.encode_batch(sentences.to_vec(), true)
.map_err(E::msg)?;
let token_ids = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_ids().to_vec();
Tensor::new(tokens.as_slice(), device)
})
.collect::<candle::Result<Vec<_>>>()?;
let token_ids = Tensor::stack(&token_ids, 0)?;
println!("running inference on batch {:?}", token_ids.shape());
let embeddings = model.forward(&token_ids)?;
println!("generated embeddings {:?}", embeddings.shape());
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
let embeddings = if args.normalize_embeddings {
normalize_l2(&embeddings)?
} else {
embeddings
};
println!("pooled embeddings {:?}", embeddings.shape());
let mut similarities = vec![];
for i in 0..n_sentences {
let e_i = embeddings.get(i)?;
for j in (i + 1)..n_sentences {
let e_j = embeddings.get(j)?;
let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
similarities.push((cosine_similarity, i, j))
}
}
similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
for &(score, i, j) in similarities[..5].iter() {
println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
}
}
Ok(())
}
pub fn normalize_l2(v: &Tensor) -> candle::Result<Tensor> {
v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
}

View File

@ -21,10 +21,11 @@ use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType}; use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write; use std::io::Write;
use candle_transformers::models::llama as model; mod model;
use model::{Config, Llama, LlamaConfig}; use model::{Config, Llama, LlamaConfig};
const EOS_TOKEN: &str = "</s>"; const EOS_TOKEN: &str = "</s>";
const MAX_SEQ_LEN: usize = 4096;
const DEFAULT_PROMPT: &str = "My favorite theorem is "; const DEFAULT_PROMPT: &str = "My favorite theorem is ";
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
@ -42,10 +43,6 @@ struct Args {
#[arg(long)] #[arg(long)]
temperature: Option<f64>, temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples. /// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)] #[arg(long, default_value_t = 299792458)]
seed: u64, seed: u64,
@ -172,9 +169,17 @@ fn main() -> Result<()> {
} }
println!("building the model"); println!("building the model");
let handles = filenames
.iter()
.map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f.as_path())? }))
.collect::<Result<Vec<_>>>()?;
let tensors: Vec<_> = handles
.iter()
.map(|h| Ok(h.deserialize()?))
.collect::<Result<Vec<_>>>()?;
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?; let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; let vb = VarBuilder::from_safetensors(tensors, dtype, &device);
(Llama::load(vb, &cache, &config)?, tokenizer_filename, cache) (Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
} }
}; };
@ -189,7 +194,7 @@ fn main() -> Result<()> {
println!("starting the inference loop"); println!("starting the inference loop");
print!("{prompt}"); print!("{prompt}");
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p); let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
let start_gen = std::time::Instant::now(); let start_gen = std::time::Instant::now();
let mut index_pos = 0; let mut index_pos = 0;
let mut token_generated = 0; let mut token_generated = 0;

View File

@ -1,11 +1,10 @@
use super::with_tracing::{linear_no_bias as linear, Linear};
use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, Module, VarBuilder}; use candle_nn::{Embedding, Module, VarBuilder};
use serde::Deserialize; use serde::Deserialize;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
pub const MAX_SEQ_LEN: usize = 4096; use super::MAX_SEQ_LEN;
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct LlamaConfig { pub struct LlamaConfig {
@ -82,6 +81,21 @@ impl Config {
} }
} }
// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
// model.
#[derive(Debug)]
pub struct Linear {
inner: candle_nn::Linear,
span: tracing::Span,
}
impl Linear {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
}
}
#[derive(Clone)] #[derive(Clone)]
pub struct Cache { pub struct Cache {
masks: Arc<Mutex<HashMap<usize, Tensor>>>, masks: Arc<Mutex<HashMap<usize, Tensor>>>,
@ -136,6 +150,12 @@ impl Cache {
} }
} }
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let span = tracing::span!(tracing::Level::TRACE, "linear");
let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
Ok(Linear { inner, span })
}
fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> { fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?; let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
Ok(Embedding::new(embeddings, cfg.hidden_size)) Ok(Embedding::new(embeddings, cfg.hidden_size))

View File

@ -6,10 +6,9 @@ extern crate accelerate_src;
#[cfg(feature = "mkl")] #[cfg(feature = "mkl")]
extern crate intel_mkl_src; extern crate intel_mkl_src;
use candle_transformers::models::llama2_c as model; mod model;
use candle_transformers::models::llama2_c_weights as weights;
use candle_transformers::models::quantized_llama2_c as qmodel;
mod training; mod training;
mod weights;
use clap::{Parser, Subcommand}; use clap::{Parser, Subcommand};
use anyhow::{Error as E, Result}; use anyhow::{Error as E, Result};
@ -20,7 +19,6 @@ use std::io::Write;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use model::{Config, Llama}; use model::{Config, Llama};
use qmodel::QLlama;
use weights::TransformerWeights; use weights::TransformerWeights;
#[derive(Parser, Debug, Clone)] #[derive(Parser, Debug, Clone)]
@ -29,10 +27,6 @@ struct InferenceCmd {
#[arg(long)] #[arg(long)]
temperature: Option<f64>, temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
#[arg(long, default_value = "")] #[arg(long, default_value = "")]
prompt: String, prompt: String,
@ -139,7 +133,6 @@ fn main() -> anyhow::Result<()> {
None => { None => {
let cmd = InferenceCmd { let cmd = InferenceCmd {
temperature: None, temperature: None,
top_p: None,
prompt: "".to_string(), prompt: "".to_string(),
config: None, config: None,
model_id: "karpathy/tinyllamas".to_string(), model_id: "karpathy/tinyllamas".to_string(),
@ -154,20 +147,6 @@ fn main() -> anyhow::Result<()> {
Ok(()) Ok(())
} }
enum Model {
Llama(Llama),
QLlama(QLlama),
}
impl Model {
fn forward(&self, xs: &Tensor, pos: usize) -> anyhow::Result<Tensor> {
match self {
Self::Llama(l) => Ok(l.forward(xs, pos)?),
Self::QLlama(l) => Ok(l.forward(xs, pos)?),
}
}
}
fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> { fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
use std::io::BufRead; use std::io::BufRead;
@ -257,69 +236,27 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let device = candle_examples::device(common_args.cpu)?; let device = candle_examples::device(common_args.cpu)?;
let is_gguf = config_path.extension().map_or(false, |v| v == "gguf");
let is_safetensors = config_path let is_safetensors = config_path
.extension() .extension()
.map_or(false, |v| v == "safetensors"); .map_or(false, |v| v == "safetensors");
let (model, config) = if is_gguf { let (vb, config) = if is_safetensors {
let vb = qmodel::VarBuilder::from_gguf(config_path)?; let config = Config::tiny();
let (_vocab_size, dim) = vb
.get_no_shape("model.embed_tokens.weight")?
.shape()
.dims2()?;
let config = match dim {
64 => Config::tiny_260k(),
288 => Config::tiny_15m(),
512 => Config::tiny_42m(),
768 => Config::tiny_110m(),
_ => anyhow::bail!("no config for dim {dim}"),
};
let freq_cis_real = vb
.get(
(config.seq_len, config.head_size() / 2),
"rot.freq_cis_real",
)?
.dequantize(&candle::Device::Cpu)?;
let freq_cis_imag = vb
.get(
(config.seq_len, config.head_size() / 2),
"rot.freq_cis_imag",
)?
.dequantize(&candle::Device::Cpu)?;
let fake_vb = candle_nn::VarBuilder::from_tensors(
[
("freq_cis_real".to_string(), freq_cis_real),
("freq_cis_imag".to_string(), freq_cis_imag),
]
.into_iter()
.collect(),
candle::DType::F32,
&candle::Device::Cpu,
);
let cache = model::Cache::new(true, &config, fake_vb)?;
let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
(model, config)
} else if is_safetensors {
let config = Config::tiny_15m();
let tensors = candle::safetensors::load(config_path, &device)?; let tensors = candle::safetensors::load(config_path, &device)?;
let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device); let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
let cache = model::Cache::new(true, &config, vb.pp("rot"))?; (vb, config)
let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
(model, config)
} else { } else {
let mut file = std::fs::File::open(config_path)?; let mut file = std::fs::File::open(config_path)?;
let config = Config::from_reader(&mut file)?; let config = Config::from_reader(&mut file)?;
println!("{config:?}"); println!("{config:?}");
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?; let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
let vb = weights.var_builder(&config, &device)?; let vb = weights.var_builder(&config, &device)?;
let cache = model::Cache::new(true, &config, vb.pp("rot"))?; (vb, config)
let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
(model, config)
}; };
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
let model = Llama::load(vb, &cache, config)?;
println!("starting the inference loop"); println!("starting the inference loop");
let mut logits_processor = LogitsProcessor::new(299792458, args.temperature, args.top_p); let mut logits_processor = LogitsProcessor::new(299792458, args.temperature);
let mut index_pos = 0; let mut index_pos = 0;
print!("{}", args.prompt); print!("{}", args.prompt);
@ -331,7 +268,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let start_gen = std::time::Instant::now(); let start_gen = std::time::Instant::now();
for index in 0.. { for index in 0.. {
if tokens.len() >= config.seq_len { if tokens.len() >= model.config.seq_len {
break; break;
} }
let context_size = if index > 0 { 1 } else { tokens.len() }; let context_size = if index > 0 { 1 } else { tokens.len() };

View File

@ -17,20 +17,7 @@ pub struct Config {
} }
impl Config { impl Config {
pub fn tiny_260k() -> Self { pub fn tiny() -> Self {
Self {
dim: 64,
hidden_dim: 768,
n_layers: 5,
n_heads: 8,
n_kv_heads: 4,
vocab_size: 32000,
seq_len: 512,
norm_eps: 1e-5,
}
}
pub fn tiny_15m() -> Self {
Self { Self {
dim: 288, dim: 288,
hidden_dim: 768, hidden_dim: 768,
@ -42,32 +29,6 @@ impl Config {
norm_eps: 1e-5, norm_eps: 1e-5,
} }
} }
pub fn tiny_42m() -> Self {
Self {
dim: 512,
hidden_dim: 768,
n_layers: 8,
n_heads: 8,
n_kv_heads: 8,
vocab_size: 32000,
seq_len: 1024,
norm_eps: 1e-5,
}
}
pub fn tiny_110m() -> Self {
Self {
dim: 768,
hidden_dim: 768,
n_layers: 12,
n_heads: 12,
n_kv_heads: 12,
vocab_size: 32000,
seq_len: 1024,
norm_eps: 1e-5,
}
}
} }
#[derive(Clone)] #[derive(Clone)]
@ -75,9 +36,9 @@ pub struct Cache {
masks: Arc<Mutex<HashMap<usize, Tensor>>>, masks: Arc<Mutex<HashMap<usize, Tensor>>>,
pub use_kv_cache: bool, pub use_kv_cache: bool,
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
pub kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>, kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
pub cos: Tensor, cos: Tensor,
pub sin: Tensor, sin: Tensor,
device: Device, device: Device,
} }
@ -114,7 +75,7 @@ impl Cache {
}) })
} }
pub fn mask(&self, t: usize) -> Result<Tensor> { fn mask(&self, t: usize) -> Result<Tensor> {
let mut masks = self.masks.lock().unwrap(); let mut masks = self.masks.lock().unwrap();
if let Some(mask) = masks.get(&t) { if let Some(mask) = masks.get(&t) {
Ok(mask.clone()) Ok(mask.clone())

View File

@ -33,7 +33,7 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
); );
let varmap = candle_nn::VarMap::new(); let varmap = candle_nn::VarMap::new();
let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device); let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
let config = Config::tiny_15m(); let config = Config::tiny();
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone()); let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size); let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);

View File

@ -1,8 +1,9 @@
use anyhow::Result;
use byteorder::{LittleEndian, ReadBytesExt}; use byteorder::{LittleEndian, ReadBytesExt};
use candle::{DType, Device, IndexOp, Result, Shape, Tensor}; use candle::{DType, Device, IndexOp, Shape, Tensor};
use candle_nn::VarBuilder; use candle_nn::VarBuilder;
use super::llama2_c::Config; use crate::model::Config;
pub struct TransformerWeights { pub struct TransformerWeights {
// token embedding table // token embedding table

View File

@ -89,10 +89,6 @@ struct Args {
#[arg(long)] #[arg(long)]
temperature: Option<f64>, temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples. /// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)] #[arg(long, default_value_t = 299792458)]
seed: u64, seed: u64,
@ -205,9 +201,16 @@ fn main() -> Result<()> {
let cache = model::Cache::new(dtype, &config, &device)?; let cache = model::Cache::new(dtype, &config, &device)?;
println!("building the model"); println!("building the model");
let vb = unsafe { let handles = filenames
candle_nn::var_builder::ShardedSafeTensors::var_builder(&filenames, dtype, &device)? .iter()
}; .map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f.as_path())? }))
.collect::<Result<Vec<_>>>()?;
let tensors: Vec<_> = handles
.iter()
.map(|h| Ok(h.deserialize()?))
.collect::<Result<Vec<_>>>()?;
let vb = candle_nn::var_builder::ShardedSafeTensors::var_builder(tensors, dtype, &device);
let llama = Llama::load(vb, &cache, &config, comm)?; let llama = Llama::load(vb, &cache, &config, comm)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
@ -219,7 +222,7 @@ fn main() -> Result<()> {
.to_vec(); .to_vec();
println!("starting the inference loop"); println!("starting the inference loop");
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p); let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
let mut new_tokens = vec![]; let mut new_tokens = vec![];
let start_gen = std::time::Instant::now(); let start_gen = std::time::Instant::now();
let mut index_pos = 0; let mut index_pos = 0;

View File

@ -1,38 +0,0 @@
# candle-marian-mt
`marian-mt` is a neural machine translation model. In this example it is used to
translate text from French to English. See the associated [model
card](https://huggingface.co/Helsinki-NLP/opus-mt-tc-big-fr-en) for details on
the model itself.
## Running an example
```bash
cargo run --example marian-mt --release -- \
--text "Demain, dès l'aube, à l'heure où blanchit la campagne, Je partirai. Vois-tu, je sais que tu m'attends. J'irai par la forêt, j'irai par la montagne. Je ne puis demeurer loin de toi plus longtemps."
```
```
<NIL> Tomorrow, at dawn, at the time when the country is whitening, I will go. See,
I know you are waiting for me. I will go through the forest, I will go through the
mountain. I cannot stay far from you any longer.</s>
```
## Generating the tokenizer.json files
You can use the following script to generate the `tokenizer.json` config files
from the hf-hub repos. This requires the `tokenizers` and `sentencepiece`
packages to be install and use the `convert_slow_tokenizer.py` script from this
directory.
```python
from convert_slow_tokenizer import MarianConverter
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-fr-en", use_fast=False)
fast_tokenizer = MarianConverter(tokenizer, index=0).converted()
fast_tokenizer.save(f"tokenizer-marian-base-fr.json")
fast_tokenizer = MarianConverter(tokenizer, index=1).converted()
fast_tokenizer.save(f"tokenizer-marian-base-en.json")
```

File diff suppressed because it is too large Load Diff

View File

@ -1,152 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Error as E;
use clap::{Parser, ValueEnum};
use candle::{DType, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::models::marian;
use tokenizers::Tokenizer;
#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
Base,
Big,
}
// TODO: Maybe add support for the conditional prompt.
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
tokenizer: Option<String>,
#[arg(long)]
tokenizer_dec: Option<String>,
/// Choose the variant of the model to run.
#[arg(long, default_value = "big")]
which: Which,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Use the quantized version of the model.
#[arg(long)]
quantized: bool,
/// Text to be translated
#[arg(long)]
text: String,
}
pub fn main() -> anyhow::Result<()> {
use hf_hub::api::sync::Api;
let args = Args::parse();
let config = match args.which {
Which::Base => marian::Config::opus_mt_fr_en(),
Which::Big => marian::Config::opus_mt_tc_big_fr_en(),
};
let tokenizer = {
let tokenizer = match args.tokenizer {
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
None => {
let name = match args.which {
Which::Base => "tokenizer-marian-base-fr.json",
Which::Big => "tokenizer-marian-fr.json",
};
Api::new()?
.model("lmz/candle-marian".to_string())
.get(name)?
}
};
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
};
let tokenizer_dec = {
let tokenizer = match args.tokenizer_dec {
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
None => {
let name = match args.which {
Which::Base => "tokenizer-marian-base-en.json",
Which::Big => "tokenizer-marian-en.json",
};
Api::new()?
.model("lmz/candle-marian".to_string())
.get(name)?
}
};
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
};
let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);
let device = candle_examples::device(args.cpu)?;
let vb = {
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
None => match args.which {
Which::Base => Api::new()?
.repo(hf_hub::Repo::with_revision(
"Helsinki-NLP/opus-mt-fr-en".to_string(),
hf_hub::RepoType::Model,
"refs/pr/4".to_string(),
))
.get("model.safetensors")?,
Which::Big => Api::new()?
.model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
.get("model.safetensors")?,
},
};
unsafe { VarBuilder::from_mmaped_safetensors(&[&model], DType::F32, &device)? }
};
let mut model = marian::MTModel::new(&config, vb)?;
let mut logits_processor =
candle_transformers::generation::LogitsProcessor::new(1337, None, None);
let encoder_xs = {
let mut tokens = tokenizer
.encode(args.text, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
tokens.push(config.eos_token_id);
let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
model.encoder().forward(&tokens, 0)?
};
let mut token_ids = vec![config.decoder_start_token_id];
for index in 0..1000 {
let context_size = if index >= 1 { 1 } else { token_ids.len() };
let start_pos = token_ids.len().saturating_sub(context_size);
let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
let logits = model.decode(&input_ids, &encoder_xs, start_pos)?;
let logits = logits.squeeze(0)?;
let logits = logits.get(logits.dim(0)? - 1)?;
let token = logits_processor.sample(&logits)?;
token_ids.push(token);
if let Some(t) = tokenizer_dec.next_token(token)? {
use std::io::Write;
print!("{t}");
std::io::stdout().flush()?;
}
if token == config.eos_token_id || token == config.forced_eos_token_id {
break;
}
}
if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
println!();
Ok(())
}

Some files were not shown because too many files have changed in this diff Show More