mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 11:08:52 +00:00)

Compare commits (8 commits):

- 4114872aae
- f2a648f313
- ec895453cd
- 3769d8bf71
- 5d8e214dfe
- 576bf7c21f
- 49a4fa44bb
- b936e32e11
.github/workflows/ci_cuda.yaml (vendored): 2 lines changed

@@ -59,7 +59,7 @@ jobs:
      - name: Install Rust Stable
        run: curl https://sh.rustup.rs -sSf | sh -s -- -y
      - uses: Swatinem/rust-cache@v2
      - run: apt-get update -y && apt-get install libssl-dev protobuf-compiler -y
      - run: apt-get update -y && apt-get install libssl-dev -y
      - name: Test (cuda)
        run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda
  stop-runner:
.github/workflows/maturin.yml (vendored, binary): Binary file not shown.
.github/workflows/python.yml (vendored): 68 lines changed

@@ -1,68 +0,0 @@
name: PyO3-CI

on:
  workflow_dispatch:
  push:
    branches:
      - main
    paths:
      - candle-pyo3/**
  pull_request:
    paths:
      - candle-pyo3/**

jobs:
  build_and_test:
    name: Check everything builds & tests
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-latest] # For now, only test on Linux
    steps:
      - name: Checkout repository
        uses: actions/checkout@v2

      - name: Install Rust
        uses: actions-rs/toolchain@v1
        with:
          toolchain: stable

      - name: Install Python
        uses: actions/setup-python@v4
        with:
          python-version: 3.11
          architecture: "x64"

      - name: Cache Cargo Registry
        uses: actions/cache@v1
        with:
          path: ~/.cargo/registry
          key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }}

      - name: Install Protoc
        uses: arduino/setup-protoc@v2
        with:
          version: "25.0"
          repo-token: ${{ secrets.GITHUB_TOKEN }}

      - name: Install
        working-directory: ./candle-pyo3
        run: |
          python -m venv .env
          source .env/bin/activate
          pip install -U pip
          pip install pytest maturin black
          python -m maturin develop -r --features onnx

      - name: Check style
        working-directory: ./candle-pyo3
        run: |
          source .env/bin/activate
          python stub.py --check
          black --check .

      - name: Run tests
        working-directory: ./candle-pyo3
        run: |
          source .env/bin/activate
          python -m pytest -s -v tests
.gitignore (vendored): 8 lines changed

@@ -23,16 +23,14 @@ flamegraph.svg
*.dylib
*.so
*.swp
*.swo
trace-*.json

candle-wasm-examples/*/build
candle-wasm-examples/*/*.bin
candle-wasm-examples/*/*.jpeg
candle-wasm-examples/*/audios/*.wav
candle-wasm-examples/**/*.safetensors
candle-wasm-examples/**/*.gguf
candle-wasm-examples/*/*.wav
candle-wasm-examples/*/*.safetensors
candle-wasm-examples/*/package-lock.json
candle-wasm-examples/**/config*.json

.DS_Store
.idea/*
.vscode/settings.json (vendored): 11 lines changed

@@ -1,11 +0,0 @@
{
    "[python]": {
        "editor.defaultFormatter": "ms-python.black-formatter"
    },
    "python.formatting.provider": "none",
    "python.testing.pytestArgs": [
        "candle-pyo3"
    ],
    "python.testing.unittestEnabled": false,
    "python.testing.pytestEnabled": true
}
CHANGELOG.md: 28 lines changed

@@ -1,38 +1,12 @@
# Changelog
This documents the main changes to the `candle` crate.

## v0.3.1 - Unreleased
## v0.2.3 - Unreleased

### Added

### Modified

## v0.3.0 - 2023-10-01

### Added

- Added the Mistral 7b v0.1 model
  [983](https://github.com/huggingface/candle/pull/983).
- Quantized version of the Mistral model
  [1009](https://github.com/huggingface/candle/pull/1009).
- Add the gelu-erf op and activation function
  [969](https://github.com/huggingface/candle/pull/969).
- Add the mixformer/phi-v1.5 model
  [930](https://github.com/huggingface/candle/pull/930).
- Add the slice-scatter op
  [927](https://github.com/huggingface/candle/pull/927).
- Add the Wuerstchen diffusion model
  [911](https://github.com/huggingface/candle/pull/911).

### Modified

- Support for simd128 intrinsics in some quantized vecdots
  [982](https://github.com/huggingface/candle/pull/982).
- Optimize the index-select cuda kernel
  [976](https://github.com/huggingface/candle/pull/976).
- Self-contained safetensor wrappers
  [946](https://github.com/huggingface/candle/pull/946).

## v0.2.2 - 2023-09-18

### Added
Cargo.toml: 24 lines changed

@@ -7,19 +7,19 @@ members = [
    "candle-nn",
    "candle-pyo3",
    "candle-transformers",
    "candle-wasm-examples/*",
    "candle-wasm-tests",
    "candle-wasm-examples/llama2-c",
    "candle-wasm-examples/segment-anything",
    "candle-wasm-examples/whisper",
    "candle-wasm-examples/yolo",
]
exclude = [
    "candle-flash-attn",
    "candle-kernels",
    "candle-metal-kernels",
    "candle-onnx",
    "candle-flash-attn",
    "candle-kernels",
]
resolver = "2"

[workspace.package]
version = "0.3.0"
version = "0.2.3"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"

@@ -33,7 +33,8 @@ anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
clap = { version = "4.2.4", features = ["derive"] }
cudarc = { version = "0.9.14", features = ["f16"] }
gemm = { version = "0.16.6", features = ["wasm-simd128-enable"] }
# TODO: Switch back to the official gemm implementation once it has caught up.
gemm = { version = "0.16.0", package = "candle-gemm" }
hf-hub = "0.3.0"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }

@@ -41,17 +42,15 @@ imageproc = { version = "0.23.0", default-features = false }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
libc = { version = "0.2.147" }
log = "0.4"
memmap2 = { version = "0.7.1", features = ["stable_deref_trait"] }
memmap2 = "0.7.1"
num_cpus = "1.15.0"
num-traits = "0.2.15"
parquet = { version = "45.0.0" }
rand = "0.8.5"
rand_distr = "0.4.3"
rayon = "1.7.0"
rusttype = { version = "0.9", default-features = false }
safetensors = "0.3.1"
serde = { version = "1.0.171", features = ["derive"] }
serde_plain = "1.0.2"
serde_json = "1.0.99"
thiserror = "1"
tokenizers = { version = "0.13.4", default-features = false }

@@ -59,9 +58,8 @@ tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"
wav = "1.0.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "0.6.6", default-features = false }
metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
parquet = { version = "45.0.0" }

[profile.release-with-debug]
inherits = "release"
README.md: 72 lines changed

@@ -8,7 +8,6 @@ Candle is a minimalist ML framework for Rust with a focus on performance (includ
and ease of use. Try our online demos:
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
[LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2),
[T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm),
[yolo](https://huggingface.co/spaces/lmz/candle-yolo),
[Segment
Anything](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).

@@ -51,26 +50,16 @@ For more advanced examples, please have a look at the following section.
These online demos run entirely in your browser:
- [yolo](https://huggingface.co/spaces/lmz/candle-yolo): pose estimation and
  object recognition.
- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): speech recognition.
- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): text to speech.
- [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2): text generation.
- [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm): text generation.
- [Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation.
- [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm): image segmentation.
- [BLIP](https://huggingface.co/spaces/radames/Candle-BLIP-Image-Captioning): image captioning.

We also provide some command-line examples using state of the art models:

- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM.
- [Falcon](./candle-examples/examples/falcon/): general LLM.
- [Phi-v1 and Phi-v1.5](./candle-examples/examples/phi/): a 1.3b general LLM with performance on par with LLaMA-v2 7b.
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
  pre-trained on 1T tokens of English and code datasets.
- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
  better performance than all publicly available 13b models as of 2023-09-28.
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
  (English/Chinese) general LLMs with 6b and 34b parameters.
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code
  generation.
- [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of
  the LLaMA model using the same quantization techniques as
  [llama.cpp](https://github.com/ggerganov/llama.cpp).

@@ -82,11 +71,6 @@ We also provide some command-line examples using state of the art models

<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg" width="200">

- [Wuerstchen](./candle-examples/examples/wuerstchen/): another text to
  image generative model.

<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/wuerstchen/assets/cat.jpg" width="200">

- [yolo-v3](./candle-examples/examples/yolo-v3/) and
  [yolo-v8](./candle-examples/examples/yolo-v8/): object detection and pose
  estimation models.

@@ -98,15 +82,10 @@ We also provide some command-line examples using state of the art models

<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">

- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/),
  [JinaBert](./candle-examples/examples/jina-bert/): useful for sentence embeddings.
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
  using self-supervision (can be used for imagenet classification, depth
  evaluation, segmentation).
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
  generate captions for an image.
- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
  model, generates the translated text from the input text.

Run them using commands like:
```

@@ -121,8 +100,6 @@ There are also some wasm examples for whisper and
`trunk` or try them online:
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
[llama2](https://huggingface.co/spaces/lmz/candle-llama2),
[T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm),
[Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm),
[Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).

For LLaMA2, run the following command to retrieve the weight files and start a

@@ -138,18 +115,8 @@ And then head over to

<!--- ANCHOR: useful_libraries --->

## Useful External Resources
- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): a
  very detailed tutorial showing how to convert a PyTorch model to Candle.
- [`optimisers`](https://github.com/KGrewal1/optimisers): a collection of optimisers
  including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): a LoRA implementation
  that conforms to the official `peft` implementation.
- [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
  serving local LLMs including an OpenAI compatible API server.
- [`candle-ext`](https://github.com/mokeyish/candle-ext): an extension library to Candle that provides PyTorch functions not currently available in Candle.
- [`kalosm`](https://github.com/floneum/floneum/tree/master/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
## Useful Libraries
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora) provides a LoRA implementation that conforms to the official `peft` implementation.

If you have an addition to this list, please submit a pull request.

@@ -171,24 +138,15 @@ If you have an addition to this list, please submit a pull request.
- LLaMA v1 and v2.
- Falcon.
- StarCoder.
- Phi v1.5.
- Mistral 7b v0.1.
- StableLM-3B-4E1T.
- Replit-code-v1.5-3B.
- T5.
- Bert.
- Yi-6B and Yi-34B.
- Text to text.
  - T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
  - Marian MT (Machine Translation).
- Whisper (multi-lingual support).
- Text to image.
  - Stable Diffusion v1.5, v2.1, XL v1.0.
  - Wurstchen v2.
- Image to text.
  - BLIP.
- Stable Diffusion v1.5, v2.1, XL v1.0.
- Computer Vision Models.
  - DINOv2, ConvMixer, EfficientNet, ResNet, ViT.
  - yolo-v3, yolo-v8.
- DINOv2.
- EfficientNet.
- yolo-v3.
- yolo-v8.
- Segment-Anything Model (SAM).
- File formats: load models from safetensors, npz, ggml, or PyTorch files.
- Serverless (on CPU), small and fast deployments.

@@ -226,7 +184,6 @@ Cheatsheet:
- [candle-datasets](./candle-datasets/): Datasets and data loaders.
- [candle-transformers](./candle-transformers): transformers-related utilities.
- [candle-flash-attn](./candle-flash-attn): Flash attention v2 layer.
- [candle-onnx](./candle-onnx/): ONNX model evaluation.

## FAQ

@@ -349,11 +306,6 @@ mdbook test candle-book -L .\target\debug\deps\ `
-L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.48.5\lib
```

#### Extremely slow model load time with WSL

This may be caused by the models being loaded from `/mnt/c`, more details on
[stackoverflow](https://stackoverflow.com/questions/68972448/why-is-wsl-extremely-slow-when-compared-with-native-windows-npm-yarn-processing).

#### Tracking down errors

You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle
@@ -11,11 +11,11 @@ readme = "README.md"

[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.2.3" }
candle-nn = { path = "../candle-nn", version = "0.2.3" }
candle-transformers = { path = "../candle-transformers", version = "0.2.3" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.3", optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }

@@ -24,10 +24,9 @@ intel-mkl-src = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
image = { workspace = true, optional = true }
anyhow = { workspace = true }
tokio = "1.29.1"

[dev-dependencies]
anyhow = { workspace = true }
byteorder = { workspace = true }
hf-hub = { workspace = true, features=["tokio"]}
clap = { workspace = true }

@@ -39,6 +38,7 @@ tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
wav = { workspace = true }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
tokio = "1.29.1"
parquet = { workspace = true }
image = { workspace = true }
@@ -14,7 +14,6 @@
- [Using the hub](inference/hub.md)
- [Error management](error_manage.md)
- [Training](training/training.md)
- [Simplified](training/simplified.md)
- [MNIST](training/mnist.md)
- [Fine-tuning]()
- [Serialization]()
@@ -12,9 +12,6 @@ compute_cap
8.9
```

You can also compile the Cuda kernels for a specific compute cap using the
`CUDA_COMPUTE_CAP=<compute cap>` environment variable.
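For example, on a GPU that reports `8.6`, something like `CUDA_COMPUTE_CAP=86 cargo build --features cuda` would restrict the kernel build to that capability; the value `86` here is purely illustrative, use whatever value the `compute_cap` command above printed for your own card.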
If any of the above commands errors out, please make sure to update your Cuda version.

2. Create a new app and add [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) with Cuda support.
@@ -1,6 +1,3 @@
#[cfg(test)]
pub mod simplified;

#[cfg(test)]
mod tests {
    use anyhow::Result;
@@ -1,196 +0,0 @@
//! # A simplified example in Rust of training a neural network and then using it, based on the Candle framework by Hugging Face.
//! Author: Evgeny Igumnov 2023 igumnovnsk@gmail.com
//! This program implements a neural network to predict the winner of the second round of an election based on the results of the first round.
//!
//! ## Key points:
//!
//! A multilayer perceptron with two hidden layers is used. The first hidden layer has 4 neurons, the second has 2 neurons.
//! The input is a vector of 2 numbers: the percentage of votes for the first and second candidates in the first round.
//! The output is the number 0 or 1, where 1 means that the first candidate will win in the second round and 0 means that he will lose.
//! For training, samples with real data on the results of the first and second rounds of different elections are used.
//! The model is trained by backpropagation using gradient descent and the cross-entropy loss function.
//! The model parameters (neuron weights) are initialized randomly, then optimized during training.
//! After training, the model is tested on a held-out sample to evaluate its accuracy.
//! If the accuracy on the test set is below 100%, the model is considered underfit and the training process is repeated.
//! Thus, this neural network learns to find hidden relationships between the results of the first and second rounds of voting in order to make predictions for new data.

#[rustfmt::skip]
mod tests {

    use candle::{DType, Result, Tensor, D, Device};
    use candle_nn::{loss, ops, Linear, Module, VarBuilder, VarMap, Optimizer};

    // ANCHOR: book_training_simplified1
    const VOTE_DIM: usize = 2;
    const RESULTS: usize = 1;
    const EPOCHS: usize = 10;
    const LAYER1_OUT_SIZE: usize = 4;
    const LAYER2_OUT_SIZE: usize = 2;
    const LEARNING_RATE: f64 = 0.05;

    #[derive(Clone)]
    pub struct Dataset {
        pub train_votes: Tensor,
        pub train_results: Tensor,
        pub test_votes: Tensor,
        pub test_results: Tensor,
    }

    struct MultiLevelPerceptron {
        ln1: Linear,
        ln2: Linear,
        ln3: Linear,
    }

    impl MultiLevelPerceptron {
        fn new(vs: VarBuilder) -> Result<Self> {
            let ln1 = candle_nn::linear(VOTE_DIM, LAYER1_OUT_SIZE, vs.pp("ln1"))?;
            let ln2 = candle_nn::linear(LAYER1_OUT_SIZE, LAYER2_OUT_SIZE, vs.pp("ln2"))?;
            let ln3 = candle_nn::linear(LAYER2_OUT_SIZE, RESULTS + 1, vs.pp("ln3"))?;
            Ok(Self { ln1, ln2, ln3 })
        }

        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
            let xs = self.ln1.forward(xs)?;
            let xs = xs.relu()?;
            let xs = self.ln2.forward(&xs)?;
            let xs = xs.relu()?;
            self.ln3.forward(&xs)
        }
    }

    // ANCHOR_END: book_training_simplified1
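    // Note: ln3 emits RESULTS + 1 = 2 logits, one per class ("first candidate
    // wins" vs "loses"), because training below applies log_softmax followed
    // by the nll loss, which expects one logit per class rather than a single
    // scalar output.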
    // ANCHOR: book_training_simplified3
    #[tokio::test]
    async fn simplified() -> anyhow::Result<()> {

        let dev = Device::cuda_if_available(0)?;

        let train_votes_vec: Vec<u32> = vec![
            15, 10,
            10, 15,
            5, 12,
            30, 20,
            16, 12,
            13, 25,
            6, 14,
            31, 21,
        ];
        let train_votes_tensor = Tensor::from_vec(train_votes_vec.clone(), (train_votes_vec.len() / VOTE_DIM, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;

        let train_results_vec: Vec<u32> = vec![
            1,
            0,
            0,
            1,
            1,
            0,
            0,
            1,
        ];
        let train_results_tensor = Tensor::from_vec(train_results_vec, train_votes_vec.len() / VOTE_DIM, &dev)?;

        let test_votes_vec: Vec<u32> = vec![
            13, 9,
            8, 14,
            3, 10,
        ];
        let test_votes_tensor = Tensor::from_vec(test_votes_vec.clone(), (test_votes_vec.len() / VOTE_DIM, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;

        let test_results_vec: Vec<u32> = vec![
            1,
            0,
            0,
        ];
        let test_results_tensor = Tensor::from_vec(test_results_vec.clone(), test_results_vec.len(), &dev)?;

        let m = Dataset {
            train_votes: train_votes_tensor,
            train_results: train_results_tensor,
            test_votes: test_votes_tensor,
            test_results: test_results_tensor,
        };

        let trained_model: MultiLevelPerceptron;
        loop {
            println!("Trying to train neural network.");
            match train(m.clone(), &dev) {
                Ok(model) => {
                    trained_model = model;
                    break;
                },
                Err(e) => {
                    println!("Error: {}", e);
                    continue;
                }
            }
        }

        let real_world_votes: Vec<u32> = vec![
            13, 22,
        ];

        let tensor_test_votes = Tensor::from_vec(real_world_votes.clone(), (1, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;

        let final_result = trained_model.forward(&tensor_test_votes)?;

        let result = final_result
            .argmax(D::Minus1)?
            .to_dtype(DType::F32)?
            .get(0).map(|x| x.to_scalar::<f32>())??;
        println!("real_life_votes: {:?}", real_world_votes);
        println!("neural_network_prediction_result: {:?}", result);

        Ok(())
    }
    // ANCHOR_END: book_training_simplified3

    // ANCHOR: book_training_simplified2
    fn train(m: Dataset, dev: &Device) -> anyhow::Result<MultiLevelPerceptron> {
        let train_results = m.train_results.to_device(dev)?;
        let train_votes = m.train_votes.to_device(dev)?;
        let varmap = VarMap::new();
        let vs = VarBuilder::from_varmap(&varmap, DType::F32, dev);
        let model = MultiLevelPerceptron::new(vs.clone())?;
        let mut sgd = candle_nn::SGD::new(varmap.all_vars(), LEARNING_RATE)?;
        let test_votes = m.test_votes.to_device(dev)?;
        let test_results = m.test_results.to_device(dev)?;
        let mut final_accuracy: f32 = 0.0;
        for epoch in 1..EPOCHS + 1 {
            let logits = model.forward(&train_votes)?;
            let log_sm = ops::log_softmax(&logits, D::Minus1)?;
            let loss = loss::nll(&log_sm, &train_results)?;
            sgd.backward_step(&loss)?;

            let test_logits = model.forward(&test_votes)?;
            let sum_ok = test_logits
                .argmax(D::Minus1)?
                .eq(&test_results)?
                .to_dtype(DType::F32)?
                .sum_all()?
                .to_scalar::<f32>()?;
            let test_accuracy = sum_ok / test_results.dims1()? as f32;
            final_accuracy = 100. * test_accuracy;
            println!("Epoch: {epoch:3} Train loss: {:8.5} Test accuracy: {:5.2}%",
                loss.to_scalar::<f32>()?,
                final_accuracy
            );
            if final_accuracy == 100.0 {
                break;
            }
        }
        if final_accuracy < 100.0 {
            Err(anyhow::Error::msg("The model is not trained well enough."))
        } else {
            Ok(model)
        }
    }
    // ANCHOR_END: book_training_simplified2

}
@@ -1,45 +0,0 @@
# Simplified

## How it works

This program implements a neural network to predict the winner of the second round of an election based on the results of the first round.

Key points:

1. A multilayer perceptron with two hidden layers is used. The first hidden layer has 4 neurons, the second has 2 neurons.
2. The input is a vector of 2 numbers: the percentage of votes for the first and second candidates in the first round.
3. The output is the number 0 or 1, where 1 means that the first candidate will win in the second round and 0 means that he will lose.
4. For training, samples with real data on the results of the first and second rounds of different elections are used.
5. The model is trained by backpropagation using gradient descent and the cross-entropy loss function.
6. The model parameters (neuron weights) are initialized randomly, then optimized during training.
7. After training, the model is tested on a held-out sample to evaluate its accuracy.
8. If the accuracy on the test set is below 100%, the model is considered underfit and the training process is repeated.

Thus, this neural network learns to find hidden relationships between the results of the first and second rounds of voting in order to make predictions for new data.

```rust,ignore
{{#include ../simplified.rs:book_training_simplified1}}
```

```rust,ignore
{{#include ../simplified.rs:book_training_simplified2}}
```

```rust,ignore
{{#include ../simplified.rs:book_training_simplified3}}
```

## Example output

```bash
Trying to train neural network.
Epoch:   1 Train loss:  4.42555 Test accuracy:  0.00%
Epoch:   2 Train loss:  0.84677 Test accuracy: 33.33%
Epoch:   3 Train loss:  2.54335 Test accuracy: 33.33%
Epoch:   4 Train loss:  0.37806 Test accuracy: 33.33%
Epoch:   5 Train loss:  0.36647 Test accuracy: 100.00%
real_life_votes: [13, 22]
neural_network_prediction_result: 0.0
```
@@ -12,9 +12,7 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
metal = { workspace = true, optional = true }
candle-kernels = { path = "../candle-kernels", version = "0.2.3", optional = true }
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
half = { workspace = true }

@@ -28,7 +26,6 @@ rand_distr = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
thiserror = { workspace = true }
yoke = { workspace = true }
zip = { workspace = true }

[dev-dependencies]

@@ -41,4 +38,3 @@ cuda = ["cudarc", "dep:candle-kernels"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal", "dep:candle-metal-kernels"]
@@ -103,10 +103,8 @@ enum Command {

    Quantize {
        /// The input file, in gguf format.
        in_file: Vec<std::path::PathBuf>,

        in_file: std::path::PathBuf,
        /// The output file, in gguf format.
        #[arg(long)]
        out_file: std::path::PathBuf,

        /// The quantization schema to apply.

@@ -152,7 +150,8 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
        }
    }
    Format::Safetensors => {
        let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::new(file)? };
        let tensors = unsafe { candle_core::safetensors::MmapedFile::new(file)? };
        let tensors = tensors.deserialize()?;
        let mut tensors = tensors.tensors();
        tensors.sort_by(|a, b| a.0.cmp(&b.0));
        for (name, view) in tensors.iter() {

@@ -219,99 +218,15 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
    Ok(())
}

fn run_quantize_safetensors(
    in_files: &[std::path::PathBuf],
    out_file: std::path::PathBuf,
    q: Quantization,
) -> Result<()> {
    let mut out_file = std::fs::File::create(out_file)?;
    let mut tensors = std::collections::HashMap::new();
    for in_file in in_files.iter() {
        let in_tensors = candle_core::safetensors::load(in_file, &Device::Cpu)?;
        tensors.extend(in_tensors)
    }
    println!("tensors: {}", tensors.len());

    let quantize_fn = match q {
        Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
        Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
        Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
        Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
        Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
        Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
        Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
        Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
        Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
        Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
        Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
        Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
        Quantization::F16 => QTensor::quantize::<half::f16>,
        Quantization::F32 => QTensor::quantize::<f32>,
    };
    let block_size = match q {
        Quantization::Q4_0 => k_quants::QK4_0,
        Quantization::Q4_1 => k_quants::QK4_1,
        Quantization::Q5_0 => k_quants::QK5_0,
        Quantization::Q5_1 => k_quants::QK5_1,
        Quantization::Q8_0 => k_quants::QK8_0,
        Quantization::Q8_1 => k_quants::QK8_1,
        Quantization::Q2k
        | Quantization::Q3k
        | Quantization::Q4k
        | Quantization::Q5k
        | Quantization::Q6k
        | Quantization::Q8k => k_quants::QK_K,
        Quantization::F16 | Quantization::F32 => 1,
    };

    let qtensors = tensors
        .into_par_iter()
        .map(|(name, tensor)| {
            let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0;
            println!(" quantizing {name} {tensor:?} {should_quantize}");
            let tensor = if should_quantize {
                quantize_fn(&tensor)?
            } else {
                QTensor::quantize::<f32>(&tensor)?
            };
            Ok((name, tensor))
        })
        .collect::<Result<Vec<_>>>()?;
    let qtensors = qtensors
        .iter()
        .map(|(k, v)| (k.as_str(), v))
        .collect::<Vec<_>>();
    gguf_file::write(&mut out_file, &[], &qtensors)?;
    Ok(())
}
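// Note on the `should_quantize` predicate above: a tensor is actually
// quantized only when it is a rank-2 matrix whose inner dimension is a
// multiple of the selected block size; everything else goes through the f32
// fallback so the resulting gguf file stays loadable. As a hypothetical
// illustration, with an assumed block size of 32 a 4096x4096 weight matrix
// qualifies, while a rank-1 bias vector of length 4096 does not (its rank is
// 1) and is stored via the f32 path.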
fn run_quantize(
    in_files: &[std::path::PathBuf],
    in_file: std::path::PathBuf,
    out_file: std::path::PathBuf,
    q: Quantization,
    qmode: QuantizationMode,
) -> Result<()> {
    if in_files.is_empty() {
        candle_core::bail!("no specified input files")
    }
    if let Some(extension) = out_file.extension() {
        if extension == "safetensors" {
            candle_core::bail!("the generated file cannot use the safetensors extension")
        }
    }
    if let Some(extension) = in_files[0].extension() {
        if extension == "safetensors" {
            return run_quantize_safetensors(in_files, out_file, q);
        }
    }

    if in_files.len() != 1 {
        candle_core::bail!("only a single in-file can be used when quantizing gguf files")
    }

    // Open the out file early so as to fail directly on missing directories etc.
    let mut out_file = std::fs::File::create(out_file)?;
    let mut in_ = std::fs::File::open(&in_files[0])?;
    let mut in_ = std::fs::File::open(&in_file)?;
    let content = gguf_file::Content::read(&mut in_)?;
    println!("tensors: {}", content.tensor_infos.len());

@@ -337,7 +252,7 @@ fn run_quantize(
        .par_iter()
        .map(|(name, _)| {
            println!(" quantizing {name}");
            let mut in_file = std::fs::File::open(&in_files[0])?;
            let mut in_file = std::fs::File::open(&in_file)?;
            let tensor = content.tensor(&mut in_file, name)?;
            let tensor = qmode.quantize(name, tensor, quantize_fn)?;
            Ok((name, tensor))

@@ -378,7 +293,7 @@ fn main() -> anyhow::Result<()> {
        out_file,
        quantization,
        mode,
    } => run_quantize(&in_file, out_file, quantization, mode)?,
    } => run_quantize(in_file, out_file, quantization, mode)?,
    }
    Ok(())
}
@@ -39,14 +39,6 @@ pub trait BackendStorage: Sized {
        _params: &crate::conv::ParamsConv1D,
    ) -> Result<Self>;

    fn conv_transpose1d(
        &self,
        _l: &Layout,
        _kernel: &Self,
        _kernel_l: &Layout,
        _params: &crate::conv::ParamsConvTranspose1D,
    ) -> Result<Self>;

    fn conv2d(
        &self,
        _l: &Layout,

@@ -119,6 +111,4 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
    fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;

    fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;

    fn set_seed(&self, _: u64) -> Result<()>;
}
@@ -15,17 +15,6 @@ fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result
    }
}

thread_local! {
    static CANDLE_GRAD_DO_NOT_DETACH: bool = {
        match std::env::var("CANDLE_GRAD_DO_NOT_DETACH") {
            Ok(s) => {
                !s.is_empty() && s != "0"
            },
            Err(_) => false,
        }
    }
}
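// For example, running a program with CANDLE_GRAD_DO_NOT_DETACH=1 makes this
// flag true: any non-empty value other than "0" enables it, and it controls
// whether gradients keep their backprop graph instead of being detached
// during the backward pass (see the `detach` call further below).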
impl Tensor {
    /// Return all the nodes that lead to this value in a topologically sorted vec, the first
    /// elements having dependencies on the latter ones, e.g. the first element if any is the

@@ -47,8 +36,6 @@ impl Tensor {
    // Do not call recursively on the "leaf" nodes.
    track_grad = true;
    nodes
} else if node.dtype().is_int() {
    nodes
} else if let Some(op) = node.op() {
    match op {
        Op::IndexAdd(t1, t2, t3, _)

@@ -68,11 +55,6 @@ impl Tensor {
    kernel: rhs,
    ..
}
| Op::ConvTranspose1D {
    arg: lhs,
    kernel: rhs,
    ..
}
| Op::Conv2D {
    arg: lhs,
    kernel: rhs,

@@ -87,8 +69,7 @@ impl Tensor {
| Op::Binary(lhs, rhs, _)
| Op::Gather(lhs, rhs, _)
| Op::IndexSelect(lhs, rhs, _)
| Op::Matmul(lhs, rhs)
| Op::SliceScatter0(lhs, rhs, _) => {
| Op::Matmul(lhs, rhs) => {
    let (tg, nodes) = walk(lhs, nodes, already_seen);
    track_grad |= tg;
    let (tg, nodes) = walk(rhs, nodes, already_seen);

@@ -109,9 +90,6 @@ impl Tensor {
        nodes
    }
}
Op::Unary(_node, UnaryOp::Ceil)
| Op::Unary(_node, UnaryOp::Floor)
| Op::Unary(_node, UnaryOp::Round) => nodes,
Op::Reshape(node)
| Op::UpsampleNearest1D(node)
| Op::UpsampleNearest2D(node)

@@ -121,6 +99,7 @@ impl Tensor {
| Op::Broadcast(node)
| Op::Cmp(node, _)
| Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
| Op::ToDType(node)
| Op::ToDevice(node)
| Op::Transpose(node, _, _)
| Op::Permute(node, _)

@@ -133,15 +112,6 @@ impl Tensor {
    track_grad |= tg;
    nodes
}
Op::ToDType(node) => {
    if node.dtype().is_float() {
        let (tg, nodes) = walk(node, nodes, already_seen);
        track_grad |= tg;
        nodes
    } else {
        nodes
    }
}
Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
}
} else {

@@ -166,16 +136,10 @@ impl Tensor {
if node.is_variable() {
    continue;
}
let grad = grads
    .remove(node)
    .expect("candle internal error - grad not populated");
// https://github.com/huggingface/candle/issues/1241
// Ideally, we would make these operations in place where possible to ensure that we
// do not have to allocate too often. Here we just call `.detach` to avoid computing
// the backprop graph of the backprop itself. This would be an issue for second order
// derivatives but these are out of scope at the moment.
let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
let grad = if do_not_detach { grad } else { grad.detach()? };
let grad = grads.remove(node).unwrap();
// TODO: We should perform all these operations in place (or at least not track the
// whole graph). The only drawback would be if we wanted to support grad of grad but
// this is out of scope.
if let Some(op) = node.op() {
    match op {
        Op::Binary(lhs, rhs, BinaryOp::Add) => {

@@ -230,44 +194,7 @@ impl Tensor {
    let f_grad = pred.where_cond(&zeros, &grad)?;
    *f_sum_grad = f_sum_grad.add(&f_grad)?;
}
Op::Conv1D {
    arg,
    kernel,
    padding,
    stride,
    dilation,
} => {
    // The output length for conv_transpose1d is:
    // (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1
    let grad_l_in = grad.dim(2)?;
    let k_size = kernel.dim(2)?;
    let out_size =
        (grad_l_in - 1) * stride + dilation * (k_size - 1) + 1 - 2 * padding;
    let out_padding = arg.dim(2)? - out_size;
    let grad_arg = grad.conv_transpose1d(
        kernel,
        *padding,
        out_padding,
        *stride,
        *dilation,
    )?;
    let sum_grad = grads.or_insert(arg)?;
    *sum_grad = sum_grad.add(&grad_arg)?;

    let grad_kernel = arg
        .transpose(0, 1)?
        .conv1d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
        .transpose(0, 1)?;
    let sum_grad = grads.or_insert(kernel)?;
    let (_, _, k0) = kernel.dims3()?;
    let (_, _, g_k0) = grad_kernel.dims3()?;
    let grad_kernel = if g_k0 != k0 {
        grad_kernel.narrow(2, 0, k0)?
    } else {
        grad_kernel
    };
    *sum_grad = sum_grad.add(&grad_kernel)?;
}
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
Op::Conv2D {
    arg,
    kernel,

@@ -297,18 +224,8 @@ impl Tensor {
    .conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
    .transpose(0, 1)?;
let sum_grad = grads.or_insert(kernel)?;
let (_, _, k0, k1) = kernel.dims4()?;
let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
    grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
} else {
    grad_kernel
};
*sum_grad = sum_grad.add(&grad_kernel)?;
}
Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported {
    op: "conv-transpose1d",
})?,
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
    op: "conv-transpose2d",
})?,

@@ -353,15 +270,6 @@ impl Tensor {
Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
    op: "upsample-nearest2d",
})?,
Op::SliceScatter0(lhs, rhs, start_rhs) => {
    let rhs_sum_grad = grads.or_insert(rhs)?;
    let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
    *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;

    let lhs_sum_grad = grads.or_insert(lhs)?;
    let lhs_grad = grad.slice_scatter0(&rhs.zeros_like()?, *start_rhs)?;
    *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?
}
Op::Gather(arg, indexes, dim) => {
    let sum_grad = grads.or_insert(arg)?;
    *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;

@@ -453,7 +361,7 @@ impl Tensor {
}
Op::ToDType(arg) => {
    let sum_grad = grads.or_insert(arg)?;
    *sum_grad = sum_grad.add(&grad.to_dtype(arg.dtype())?)?
    *sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
}
Op::Copy(arg) => {
    let sum_grad = grads.or_insert(arg)?;

@@ -533,54 +441,13 @@ impl Tensor {
    let sum_grad = grads.or_insert(arg)?;
    *sum_grad = sum_grad.add(&arg_grad)?
}
Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?,
Op::Unary(_, UnaryOp::Floor) => {
    Err(Error::BackwardNotSupported { op: "floor" })?
}
Op::Unary(_, UnaryOp::Round) => {
    Err(Error::BackwardNotSupported { op: "round" })?
}
Op::Unary(arg, UnaryOp::Gelu) => {
    let sum_grad = grads.or_insert(arg)?;
    let cube = arg.powf(3.)?;
    let tanh = (0.0356774 * &cube + (0.797885 * arg)?)?.tanh()?;
    let gelu_grad = (((0.5 * &tanh)?
        + (0.0535161 * cube + (0.398942 * arg)?)? * (1. - tanh.powf(2.)?))?
        + 0.5)?;
    *sum_grad = sum_grad.add(&(&grad * gelu_grad)?)?
}
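// Derivation sketch for the constants in the gelu gradient above, using the
// tanh approximation gelu(x) ~ 0.5 * x * (1 + tanh(u)) with
// u = 0.797885*x + 0.0356774*x^3:
//   d/dx gelu(x) = 0.5*(1 + tanh(u)) + 0.5*x*(1 - tanh(u)^2)*u'
// where u' = 0.797885 + 3*0.0356774*x^2, so the 0.5*x*u' term expands to
// 0.398942*x + 0.0535161*x^3, exactly the factor the code multiplies by
// (1 - tanh^2) before adding 0.5*tanh + 0.5.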
Op::Unary(arg, UnaryOp::Erf) => {
    let sum_grad = grads.or_insert(arg)?;
    // d/dx erf(x) = 2/sqrt(pi) * e^(-x^2)
    let erf_grad =
        (2. / std::f64::consts::PI.sqrt()) * (arg.sqr()?.neg()?).exp()?;
    *sum_grad = sum_grad.add(&(&grad * erf_grad)?)?
}
Op::Unary(arg, UnaryOp::GeluErf) => {
    let sum_grad = grads.or_insert(arg)?;
    // d/dx gelu_erf(x) = 0.5 + 0.398942 e^(-x^2/2) x + 0.5 erf(x/sqrt(2))
    let neg_half_square = (arg.sqr()?.neg()? / 2.)?;
    let scaled_exp_arg = (0.398942 * neg_half_square.exp()? * arg)?;
    let arg_scaled_sqrt = (arg / 2f64.sqrt())?;
    let erf_scaled_sqrt = (0.5 * arg_scaled_sqrt.erf()?)?;
    let gelu_erf_grad = (0.5 + scaled_exp_arg + erf_scaled_sqrt)?;
    *sum_grad = sum_grad.add(&(&grad * gelu_erf_grad)?)?;
}
Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
Op::Unary(arg, UnaryOp::Relu) => {
    let sum_grad = grads.or_insert(arg)?;
    let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
    *sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
}
Op::Elu(arg, alpha) => {
    // d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
    let sum_grad = grads.or_insert(arg)?;
    let zeros = arg.zeros_like()?;
    let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
    let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
    let negative_exp_mask = ((negative_mask * arg.exp())? * *alpha)?;
    let combined_mask = (positive_mask + negative_exp_mask)?;
    *sum_grad = sum_grad.add(&(grad * combined_mask)?)?
}
Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
Op::Powf(arg, e) => {
    let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?;
    let sum_grad = grads.or_insert(arg)?;
@@ -25,46 +25,6 @@ impl ParamsConv1D {
    }
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConvTranspose1D {
    pub(crate) b_size: usize,
    pub(crate) l_in: usize,
    pub(crate) c_out: usize,
    pub(crate) c_in: usize,
    pub(crate) k_size: usize,
    pub(crate) padding: usize,
    pub(crate) output_padding: usize,
    pub(crate) stride: usize,
    pub(crate) dilation: usize,
}

impl ParamsConvTranspose1D {
    pub(crate) fn l_out(&self) -> usize {
        (self.l_in - 1) * self.stride - 2 * self.padding
            + self.dilation * (self.k_size - 1)
            + self.output_padding
            + 1
    }

    pub(crate) fn out_dims(&self) -> Vec<usize> {
        let l_out = self.l_out();
        vec![self.b_size, self.c_out, l_out]
    }
}
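// Worked example for l_out() above: with l_in = 4, stride = 2, padding = 1,
// dilation = 1, k_size = 3 and output_padding = 0 (illustrative values),
//   (4 - 1) * 2 - 2 * 1 + 1 * (3 - 1) + 0 + 1 = 6 - 2 + 2 + 0 + 1 = 7,
// so a length-4 input produces a length-7 output.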
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CudnnFwdAlgo {
    ImplicitGemm,
    ImplicitPrecompGemm,
    Gemm,
    Direct,
    Fft,
    FftTiling,
    Winograd,
    WinogradNonFused,
    Count,
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv2D {
    pub(crate) b_size: usize,

@@ -77,7 +37,6 @@ pub struct ParamsConv2D {
    pub(crate) padding: usize,
    pub(crate) stride: usize,
    pub(crate) dilation: usize,
    pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
}

impl ParamsConv2D {

@@ -187,49 +146,6 @@ impl Tensor {
    }
}

/// Applies a 1D transposed convolution over the input tensor.
pub fn conv_transpose1d(
    &self,
    kernel: &Self,
    padding: usize,
    output_padding: usize,
    stride: usize,
    dilation: usize,
) -> Result<Self> {
    let (b_size, c_in, l_in) = self.dims3()?;
    let (c_in_k, c_out, k_size) = kernel.dims3()?;
    if c_in != c_in_k {
        crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
    }
    let params = ParamsConvTranspose1D {
        b_size,
        l_in,
        k_size,
        c_out,
        c_in,
        padding,
        output_padding,
        stride,
        dilation,
    };
    let storage = self.storage().conv_transpose1d(
        self.layout(),
        &kernel.storage(),
        kernel.layout(),
        &params,
    )?;
    let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D {
        arg,
        kernel,
        padding: params.padding,
        output_padding: params.output_padding,
        stride: params.stride,
        dilation: params.dilation,
    });
    let out_dims = params.out_dims();
    Ok(crate::tensor::from_storage(storage, out_dims, op, false))
}

fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
    let storage =
        self.storage()

@@ -272,7 +188,6 @@ impl Tensor {
    padding,
    stride,
    dilation,
    cudnn_fwd_algo: None,
};
if groups == 1 {
    self.conv2d_single_group(kernel, &params)
@@ -1,763 +0,0 @@
#![allow(clippy::excessive_precision)]
// Code taken from https://github.com/statrs-dev/statrs
//! Provides the [error](https://en.wikipedia.org/wiki/Error_function) and
//! related functions

mod evaluate {
    //! Provides functions that don't have a numerical solution and must
    //! be solved computationally (e.g. evaluation of a polynomial)

    /// Evaluates a polynomial at `z` where `coeff` are the coefficients
    /// of a polynomial of degree `k - 1`, `k` being the length of `coeff`;
    /// the coefficient of the `k`th power is the `k`th element of `coeff`.
    /// E.g. [3,-1,2] equates to `2z^2 - z + 3`
    ///
    /// # Remarks
    ///
    /// Returns 0 for a 0 length coefficient slice
    pub fn polynomial(z: f64, coeff: &[f64]) -> f64 {
        let n = coeff.len();
        if n == 0 {
            return 0.0;
        }

        let mut sum = *coeff.last().unwrap();
        for c in coeff[0..n - 1].iter().rev() {
            sum = *c + z * sum;
        }
        sum
    }
}
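// Quick check of the doc example above: coeff = [3.0, -1.0, 2.0] encodes
// 2z^2 - z + 3, so polynomial(2.0, &[3.0, -1.0, 2.0]) evaluates via
// Horner's rule as ((2.0)*2 + (-1.0))*2 + 3.0 = 9.0, matching
// 2*4 - 2 + 3 = 9 computed directly, with only two multiplications.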
use std::f64;

/// `erf` calculates the error function at `x`.
pub fn erf(x: f64) -> f64 {
    if x.is_nan() {
        f64::NAN
    } else if x >= 0.0 && x.is_infinite() {
        1.0
    } else if x <= 0.0 && x.is_infinite() {
        -1.0
    } else if x == 0. {
        0.0
    } else {
        erf_impl(x, false)
    }
}

/// `erf_inv` calculates the inverse error function
/// at `x`.
pub fn erf_inv(x: f64) -> f64 {
    if x == 0.0 {
        0.0
    } else if x >= 1.0 {
        f64::INFINITY
    } else if x <= -1.0 {
        f64::NEG_INFINITY
    } else if x < 0.0 {
        erf_inv_impl(-x, 1.0 + x, -1.0)
    } else {
        erf_inv_impl(x, 1.0 - x, 1.0)
    }
}

/// `erfc` calculates the complementary error function
/// at `x`.
pub fn erfc(x: f64) -> f64 {
    if x.is_nan() {
        f64::NAN
    } else if x == f64::INFINITY {
        0.0
    } else if x == f64::NEG_INFINITY {
        2.0
    } else {
        erf_impl(x, true)
    }
}

/// `erfc_inv` calculates the complementary inverse
/// error function at `x`.
pub fn erfc_inv(x: f64) -> f64 {
    if x <= 0.0 {
        f64::INFINITY
    } else if x >= 2.0 {
        f64::NEG_INFINITY
    } else if x > 1.0 {
        erf_inv_impl(-1.0 + x, 2.0 - x, -1.0)
    } else {
        erf_inv_impl(1.0 - x, x, 1.0)
    }
}
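// The four functions above are tied together by the identities
// erfc(x) = 1 - erf(x) and erfc_inv(x) = erf_inv(1 - x), which is why both
// inverses reduce to erf_inv_impl with only sign and argument adjustments.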
// **********************************************************
// ********** Coefficients for erf_impl polynomial **********
// **********************************************************

/// Polynomial coefficients for a numerator of `erf_impl`
/// in the interval [1e-10, 0.5].
const ERF_IMPL_AN: &[f64] = &[
    0.00337916709551257388990745,
    -0.00073695653048167948530905,
    -0.374732337392919607868241,
    0.0817442448733587196071743,
    -0.0421089319936548595203468,
    0.0070165709512095756344528,
    -0.00495091255982435110337458,
    0.000871646599037922480317225,
];

/// Polynomial coefficients for a denominator of `erf_impl`
/// in the interval [1e-10, 0.5]
const ERF_IMPL_AD: &[f64] = &[
    1.0,
    -0.218088218087924645390535,
    0.412542972725442099083918,
    -0.0841891147873106755410271,
    0.0655338856400241519690695,
    -0.0120019604454941768171266,
    0.00408165558926174048329689,
    -0.000615900721557769691924509,
];

/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [0.5, 0.75].
const ERF_IMPL_BN: &[f64] = &[
    -0.0361790390718262471360258,
    0.292251883444882683221149,
    0.281447041797604512774415,
    0.125610208862766947294894,
    0.0274135028268930549240776,
    0.00250839672168065762786937,
];

/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [0.5, 0.75].
const ERF_IMPL_BD: &[f64] = &[
    1.0,
    1.8545005897903486499845,
    1.43575803037831418074962,
    0.582827658753036572454135,
    0.124810476932949746447682,
    0.0113724176546353285778481,
];

/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [0.75, 1.25].
const ERF_IMPL_CN: &[f64] = &[
    -0.0397876892611136856954425,
    0.153165212467878293257683,
    0.191260295600936245503129,
    0.10276327061989304213645,
    0.029637090615738836726027,
    0.0046093486780275489468812,
    0.000307607820348680180548455,
];

/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [0.75, 1.25].
const ERF_IMPL_CD: &[f64] = &[
    1.0,
    1.95520072987627704987886,
    1.64762317199384860109595,
    0.768238607022126250082483,
    0.209793185936509782784315,
    0.0319569316899913392596356,
    0.00213363160895785378615014,
];

/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [1.25, 2.25].
const ERF_IMPL_DN: &[f64] = &[
    -0.0300838560557949717328341,
    0.0538578829844454508530552,
    0.0726211541651914182692959,
    0.0367628469888049348429018,
    0.00964629015572527529605267,
    0.00133453480075291076745275,
    0.778087599782504251917881e-4,
];

/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [1.25, 2.25].
const ERF_IMPL_DD: &[f64] = &[
    1.0,
    1.75967098147167528287343,
    1.32883571437961120556307,
    0.552528596508757581287907,
    0.133793056941332861912279,
    0.0179509645176280768640766,
    0.00104712440019937356634038,
    -0.106640381820357337177643e-7,
];

/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [2.25, 3.5].
const ERF_IMPL_EN: &[f64] = &[
    -0.0117907570137227847827732,
    0.014262132090538809896674,
    0.0202234435902960820020765,
    0.00930668299990432009042239,
    0.00213357802422065994322516,
    0.00025022987386460102395382,
    0.120534912219588189822126e-4,
];

/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [2.25, 3.5].
const ERF_IMPL_ED: &[f64] = &[
    1.0,
    1.50376225203620482047419,
    0.965397786204462896346934,
    0.339265230476796681555511,
    0.0689740649541569716897427,
    0.00771060262491768307365526,
    0.000371421101531069302990367,
];

/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [3.5, 5.25].
const ERF_IMPL_FN: &[f64] = &[
    -0.00546954795538729307482955,
    0.00404190278731707110245394,
    0.0054963369553161170521356,
    0.00212616472603945399437862,
    0.000394984014495083900689956,
    0.365565477064442377259271e-4,
    0.135485897109932323253786e-5,
];

/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [3.5, 5.25].
const ERF_IMPL_FD: &[f64] = &[
    1.0,
    1.21019697773630784832251,
    0.620914668221143886601045,
    0.173038430661142762569515,
    0.0276550813773432047594539,
    0.00240625974424309709745382,
    0.891811817251336577241006e-4,
    -0.465528836283382684461025e-11,
];

/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [5.25, 8].
const ERF_IMPL_GN: &[f64] = &[
    -0.00270722535905778347999196,
    0.0013187563425029400461378,
    0.00119925933261002333923989,
    0.00027849619811344664248235,
    0.267822988218331849989363e-4,
    0.923043672315028197865066e-6,
];

/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [5.25, 8].
const ERF_IMPL_GD: &[f64] = &[
    1.0,
    0.814632808543141591118279,
    0.268901665856299542168425,
    0.0449877216103041118694989,
    0.00381759663320248459168994,
    0.000131571897888596914350697,
    0.404815359675764138445257e-11,
];

/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [8, 11.5].
const ERF_IMPL_HN: &[f64] = &[
    -0.00109946720691742196814323,
    0.000406425442750422675169153,
    0.000274499489416900707787024,
    0.465293770646659383436343e-4,
    0.320955425395767463401993e-5,
    0.778286018145020892261936e-7,
];

/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [8, 11.5].
const ERF_IMPL_HD: &[f64] = &[
    1.0,
    0.588173710611846046373373,
    0.139363331289409746077541,
    0.0166329340417083678763028,
    0.00100023921310234908642639,
    0.24254837521587225125068e-4,
];

/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [11.5, 17].
const ERF_IMPL_IN: &[f64] = &[
    -0.00056907993601094962855594,
    0.000169498540373762264416984,
    0.518472354581100890120501e-4,
    0.382819312231928859704678e-5,
    0.824989931281894431781794e-7,
];

/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [11.5, 17].
const ERF_IMPL_ID: &[f64] = &[
    1.0,
    0.339637250051139347430323,
    0.043472647870310663055044,
    0.00248549335224637114641629,
    0.535633305337152900549536e-4,
    -0.117490944405459578783846e-12,
];

/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [17, 24].
const ERF_IMPL_JN: &[f64] = &[
    -0.000241313599483991337479091,
    0.574224975202501512365975e-4,
    0.115998962927383778460557e-4,
    0.581762134402593739370875e-6,
    0.853971555085673614607418e-8,
];

/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [17, 24].
const ERF_IMPL_JD: &[f64] = &[
    1.0,
    0.233044138299687841018015,
    0.0204186940546440312625597,
    0.000797185647564398289151125,
    0.117019281670172327758019e-4,
];

/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [24, 38].
const ERF_IMPL_KN: &[f64] = &[
    -0.000146674699277760365803642,
    0.162666552112280519955647e-4,
    0.269116248509165239294897e-5,
    0.979584479468091935086972e-7,
    0.101994647625723465722285e-8,
];

/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [24, 38].
const ERF_IMPL_KD: &[f64] = &[
    1.0,
    0.165907812944847226546036,
    0.0103361716191505884359634,
    0.000286593026373868366935721,
    0.298401570840900340874568e-5,
];

/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [38, 60].
const ERF_IMPL_LN: &[f64] = &[
    -0.583905797629771786720406e-4,
|
||||
0.412510325105496173512992e-5,
|
||||
0.431790922420250949096906e-6,
|
||||
0.993365155590013193345569e-8,
|
||||
0.653480510020104699270084e-10,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [38, 60].
|
||||
const ERF_IMPL_LD: &[f64] = &[
|
||||
1.0,
|
||||
0.105077086072039915406159,
|
||||
0.00414278428675475620830226,
|
||||
0.726338754644523769144108e-4,
|
||||
0.477818471047398785369849e-6,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [60, 85].
|
||||
const ERF_IMPL_MN: &[f64] = &[
|
||||
-0.196457797609229579459841e-4,
|
||||
0.157243887666800692441195e-5,
|
||||
0.543902511192700878690335e-7,
|
||||
0.317472492369117710852685e-9,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [60, 85].
|
||||
const ERF_IMPL_MD: &[f64] = &[
|
||||
1.0,
|
||||
0.052803989240957632204885,
|
||||
0.000926876069151753290378112,
|
||||
0.541011723226630257077328e-5,
|
||||
0.535093845803642394908747e-15,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [85, 110].
|
||||
const ERF_IMPL_NN: &[f64] = &[
|
||||
-0.789224703978722689089794e-5,
|
||||
0.622088451660986955124162e-6,
|
||||
0.145728445676882396797184e-7,
|
||||
0.603715505542715364529243e-10,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [85, 110].
|
||||
const ERF_IMPL_ND: &[f64] = &[
|
||||
1.0,
|
||||
0.0375328846356293715248719,
|
||||
0.000467919535974625308126054,
|
||||
0.193847039275845656900547e-5,
|
||||
];
|
||||
|
||||
// **********************************************************
|
||||
// ********** Coefficients for erf_inv_impl polynomial ******
|
||||
// **********************************************************
|
||||
|
||||
/// Polynomial coefficients for a numerator of `erf_inv_impl`
|
||||
/// in the interval [0, 0.5].
|
||||
const ERF_INV_IMPL_AN: &[f64] = &[
|
||||
-0.000508781949658280665617,
|
||||
-0.00836874819741736770379,
|
||||
0.0334806625409744615033,
|
||||
-0.0126926147662974029034,
|
||||
-0.0365637971411762664006,
|
||||
0.0219878681111168899165,
|
||||
0.00822687874676915743155,
|
||||
-0.00538772965071242932965,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator of `erf_inv_impl`
|
||||
/// in the interval [0, 0.5].
|
||||
const ERF_INV_IMPL_AD: &[f64] = &[
|
||||
1.0,
|
||||
-0.970005043303290640362,
|
||||
-1.56574558234175846809,
|
||||
1.56221558398423026363,
|
||||
0.662328840472002992063,
|
||||
-0.71228902341542847553,
|
||||
-0.0527396382340099713954,
|
||||
0.0795283687341571680018,
|
||||
-0.00233393759374190016776,
|
||||
0.000886216390456424707504,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator of `erf_inv_impl`
|
||||
/// in the interval [0.5, 0.75].
|
||||
const ERF_INV_IMPL_BN: &[f64] = &[
|
||||
-0.202433508355938759655,
|
||||
0.105264680699391713268,
|
||||
8.37050328343119927838,
|
||||
17.6447298408374015486,
|
||||
-18.8510648058714251895,
|
||||
-44.6382324441786960818,
|
||||
17.445385985570866523,
|
||||
21.1294655448340526258,
|
||||
-3.67192254707729348546,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator of `erf_inv_impl`
|
||||
/// in the interval [0.5, 0.75].
|
||||
const ERF_INV_IMPL_BD: &[f64] = &[
|
||||
1.0,
|
||||
6.24264124854247537712,
|
||||
3.9713437953343869095,
|
||||
-28.6608180499800029974,
|
||||
-20.1432634680485188801,
|
||||
48.5609213108739935468,
|
||||
10.8268667355460159008,
|
||||
-22.6436933413139721736,
|
||||
1.72114765761200282724,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x less than 3.
|
||||
const ERF_INV_IMPL_CN: &[f64] = &[
|
||||
-0.131102781679951906451,
|
||||
-0.163794047193317060787,
|
||||
0.117030156341995252019,
|
||||
0.387079738972604337464,
|
||||
0.337785538912035898924,
|
||||
0.142869534408157156766,
|
||||
0.0290157910005329060432,
|
||||
0.00214558995388805277169,
|
||||
-0.679465575181126350155e-6,
|
||||
0.285225331782217055858e-7,
|
||||
-0.681149956853776992068e-9,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x less than 3.
|
||||
const ERF_INV_IMPL_CD: &[f64] = &[
|
||||
1.0,
|
||||
3.46625407242567245975,
|
||||
5.38168345707006855425,
|
||||
4.77846592945843778382,
|
||||
2.59301921623620271374,
|
||||
0.848854343457902036425,
|
||||
0.152264338295331783612,
|
||||
0.01105924229346489121,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x between 3 and 6.
|
||||
const ERF_INV_IMPL_DN: &[f64] = &[
|
||||
-0.0350353787183177984712,
|
||||
-0.00222426529213447927281,
|
||||
0.0185573306514231072324,
|
||||
0.00950804701325919603619,
|
||||
0.00187123492819559223345,
|
||||
0.000157544617424960554631,
|
||||
0.460469890584317994083e-5,
|
||||
-0.230404776911882601748e-9,
|
||||
0.266339227425782031962e-11,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x between 3 and 6.
|
||||
const ERF_INV_IMPL_DD: &[f64] = &[
|
||||
1.0,
|
||||
1.3653349817554063097,
|
||||
0.762059164553623404043,
|
||||
0.220091105764131249824,
|
||||
0.0341589143670947727934,
|
||||
0.00263861676657015992959,
|
||||
0.764675292302794483503e-4,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x between 6 and 18.
|
||||
const ERF_INV_IMPL_EN: &[f64] = &[
|
||||
-0.0167431005076633737133,
|
||||
-0.00112951438745580278863,
|
||||
0.00105628862152492910091,
|
||||
0.000209386317487588078668,
|
||||
0.149624783758342370182e-4,
|
||||
0.449696789927706453732e-6,
|
||||
0.462596163522878599135e-8,
|
||||
-0.281128735628831791805e-13,
|
||||
0.99055709973310326855e-16,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x between 6 and 18.
|
||||
const ERF_INV_IMPL_ED: &[f64] = &[
|
||||
1.0,
|
||||
0.591429344886417493481,
|
||||
0.138151865749083321638,
|
||||
0.0160746087093676504695,
|
||||
0.000964011807005165528527,
|
||||
0.275335474764726041141e-4,
|
||||
0.282243172016108031869e-6,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x between 18 and 44.
|
||||
const ERF_INV_IMPL_FN: &[f64] = &[
|
||||
-0.0024978212791898131227,
|
||||
-0.779190719229053954292e-5,
|
||||
0.254723037413027451751e-4,
|
||||
0.162397777342510920873e-5,
|
||||
0.396341011304801168516e-7,
|
||||
0.411632831190944208473e-9,
|
||||
0.145596286718675035587e-11,
|
||||
-0.116765012397184275695e-17,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x between 18 and 44.
|
||||
const ERF_INV_IMPL_FD: &[f64] = &[
|
||||
1.0,
|
||||
0.207123112214422517181,
|
||||
0.0169410838120975906478,
|
||||
0.000690538265622684595676,
|
||||
0.145007359818232637924e-4,
|
||||
0.144437756628144157666e-6,
|
||||
0.509761276599778486139e-9,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x greater than 44.
|
||||
const ERF_INV_IMPL_GN: &[f64] = &[
|
||||
-0.000539042911019078575891,
|
||||
-0.28398759004727721098e-6,
|
||||
0.899465114892291446442e-6,
|
||||
0.229345859265920864296e-7,
|
||||
0.225561444863500149219e-9,
|
||||
0.947846627503022684216e-12,
|
||||
0.135880130108924861008e-14,
|
||||
-0.348890393399948882918e-21,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x greater than 44.
|
||||
const ERF_INV_IMPL_GD: &[f64] = &[
|
||||
1.0,
|
||||
0.0845746234001899436914,
|
||||
0.00282092984726264681981,
|
||||
0.468292921940894236786e-4,
|
||||
0.399968812193862100054e-6,
|
||||
0.161809290887904476097e-8,
|
||||
0.231558608310259605225e-11,
|
||||
];
|
||||
|
||||
/// `erf_impl` computes the error function at `z`.
/// If `inv` is true, `1 - erf` is calculated as opposed to `erf`.
fn erf_impl(z: f64, inv: bool) -> f64 {
    if z < 0.0 {
        if !inv {
            return -erf_impl(-z, false);
        }
        if z < -0.5 {
            return 2.0 - erf_impl(-z, true);
        }
        return 1.0 + erf_impl(-z, false);
    }

    let result = if z < 0.5 {
        if z < 1e-10 {
            z * 1.125 + z * 0.003379167095512573896158903121545171688
        } else {
            z * 1.125
                + z * evaluate::polynomial(z, ERF_IMPL_AN) / evaluate::polynomial(z, ERF_IMPL_AD)
        }
    } else if z < 110.0 {
        let (r, b) = if z < 0.75 {
            (
                evaluate::polynomial(z - 0.5, ERF_IMPL_BN)
                    / evaluate::polynomial(z - 0.5, ERF_IMPL_BD),
                0.3440242112,
            )
        } else if z < 1.25 {
            (
                evaluate::polynomial(z - 0.75, ERF_IMPL_CN)
                    / evaluate::polynomial(z - 0.75, ERF_IMPL_CD),
                0.419990927,
            )
        } else if z < 2.25 {
            (
                evaluate::polynomial(z - 1.25, ERF_IMPL_DN)
                    / evaluate::polynomial(z - 1.25, ERF_IMPL_DD),
                0.4898625016,
            )
        } else if z < 3.5 {
            (
                evaluate::polynomial(z - 2.25, ERF_IMPL_EN)
                    / evaluate::polynomial(z - 2.25, ERF_IMPL_ED),
                0.5317370892,
            )
        } else if z < 5.25 {
            (
                evaluate::polynomial(z - 3.5, ERF_IMPL_FN)
                    / evaluate::polynomial(z - 3.5, ERF_IMPL_FD),
                0.5489973426,
            )
        } else if z < 8.0 {
            (
                evaluate::polynomial(z - 5.25, ERF_IMPL_GN)
                    / evaluate::polynomial(z - 5.25, ERF_IMPL_GD),
                0.5571740866,
            )
        } else if z < 11.5 {
            (
                evaluate::polynomial(z - 8.0, ERF_IMPL_HN)
                    / evaluate::polynomial(z - 8.0, ERF_IMPL_HD),
                0.5609807968,
            )
        } else if z < 17.0 {
            (
                evaluate::polynomial(z - 11.5, ERF_IMPL_IN)
                    / evaluate::polynomial(z - 11.5, ERF_IMPL_ID),
                0.5626493692,
            )
        } else if z < 24.0 {
            (
                evaluate::polynomial(z - 17.0, ERF_IMPL_JN)
                    / evaluate::polynomial(z - 17.0, ERF_IMPL_JD),
                0.5634598136,
            )
        } else if z < 38.0 {
            (
                evaluate::polynomial(z - 24.0, ERF_IMPL_KN)
                    / evaluate::polynomial(z - 24.0, ERF_IMPL_KD),
                0.5638477802,
            )
        } else if z < 60.0 {
            (
                evaluate::polynomial(z - 38.0, ERF_IMPL_LN)
                    / evaluate::polynomial(z - 38.0, ERF_IMPL_LD),
                0.5640528202,
            )
        } else if z < 85.0 {
            (
                evaluate::polynomial(z - 60.0, ERF_IMPL_MN)
                    / evaluate::polynomial(z - 60.0, ERF_IMPL_MD),
                0.5641309023,
            )
        } else {
            (
                evaluate::polynomial(z - 85.0, ERF_IMPL_NN)
                    / evaluate::polynomial(z - 85.0, ERF_IMPL_ND),
                0.5641584396,
            )
        };
        let g = (-z * z).exp() / z;
        g * b + g * r
    } else {
        0.0
    };

    if inv && z >= 0.5 {
        result
    } else if z >= 0.5 || inv {
        1.0 - result
    } else {
        result
    }
}

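Note: `evaluate::polynomial`, used throughout `erf_impl` above and `erf_inv_impl` below, is defined elsewhere in the module and is not part of this excerpt. Assuming the coefficient tables are ordered from the constant term upward, a minimal Horner-scheme sketch of such a helper would be (an illustration, not necessarily the crate's exact implementation):

// Sketch: evaluate a polynomial whose coefficients are ordered from the
// constant term upward, using Horner's method.
fn polynomial(z: f64, coefficients: &[f64]) -> f64 {
    coefficients.iter().rev().fold(0.0, |sum, &c| sum * z + c)
}
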
// `erf_inv_impl` computes the inverse error function, where `p` is the
// argument, `q = 1 - p` is its complement (used for the tail regions), and
// `s` is the sign applied to the result.
fn erf_inv_impl(p: f64, q: f64, s: f64) -> f64 {
    let result = if p <= 0.5 {
        let y = 0.0891314744949340820313;
        let g = p * (p + 10.0);
        let r = evaluate::polynomial(p, ERF_INV_IMPL_AN) / evaluate::polynomial(p, ERF_INV_IMPL_AD);
        g * y + g * r
    } else if q >= 0.25 {
        let y = 2.249481201171875;
        let g = (-2.0 * q.ln()).sqrt();
        let xs = q - 0.25;
        let r =
            evaluate::polynomial(xs, ERF_INV_IMPL_BN) / evaluate::polynomial(xs, ERF_INV_IMPL_BD);
        g / (y + r)
    } else {
        let x = (-q.ln()).sqrt();
        if x < 3.0 {
            let y = 0.807220458984375;
            let xs = x - 1.125;
            let r = evaluate::polynomial(xs, ERF_INV_IMPL_CN)
                / evaluate::polynomial(xs, ERF_INV_IMPL_CD);
            y * x + r * x
        } else if x < 6.0 {
            let y = 0.93995571136474609375;
            let xs = x - 3.0;
            let r = evaluate::polynomial(xs, ERF_INV_IMPL_DN)
                / evaluate::polynomial(xs, ERF_INV_IMPL_DD);
            y * x + r * x
        } else if x < 18.0 {
            let y = 0.98362827301025390625;
            let xs = x - 6.0;
            let r = evaluate::polynomial(xs, ERF_INV_IMPL_EN)
                / evaluate::polynomial(xs, ERF_INV_IMPL_ED);
            y * x + r * x
        } else if x < 44.0 {
            let y = 0.99714565277099609375;
            let xs = x - 18.0;
            let r = evaluate::polynomial(xs, ERF_INV_IMPL_FN)
                / evaluate::polynomial(xs, ERF_INV_IMPL_FD);
            y * x + r * x
        } else {
            let y = 0.99941349029541015625;
            let xs = x - 44.0;
            let r = evaluate::polynomial(xs, ERF_INV_IMPL_GN)
                / evaluate::polynomial(xs, ERF_INV_IMPL_GD);
            y * x + r * x
        }
    };
    s * result
}
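The public entry points wrapping these two `_impl` routines fall outside this hunk; given the Boost/statrs lineage of the coefficients, they presumably look roughly like the following sketch (names and edge-case handling are assumptions, not taken from the diff):

// Sketch of the assumed public wrappers: erf_impl already handles negative
// arguments internally, while erf_inv maps x onto the (p, q, s) parameters
// that erf_inv_impl expects.
pub fn erf(x: f64) -> f64 {
    erf_impl(x, false)
}

pub fn erf_inv(x: f64) -> f64 {
    if x == 0.0 {
        0.0
    } else if x >= 1.0 {
        f64::INFINITY
    } else if x <= -1.0 {
        f64::NEG_INFINITY
    } else if x < 0.0 {
        erf_inv_impl(-x, 1.0 + x, -1.0)
    } else {
        erf_inv_impl(x, 1.0 - x, 1.0)
    }
}
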
@ -1,4 +1,3 @@
pub mod erf;
pub mod kernels;

trait Cpu<const ARR: usize> {

@ -804,11 +804,11 @@ impl<'a, I: IntDType> Map1 for Gather<'a, I> {
    fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
        let ids = match self.ids_l.contiguous_offsets() {
            Some((a, b)) => &self.ids[a..b],
            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
            None => Err(Error::RequiresContiguous { op: "gather" })?,
        };
        let src = match src_l.contiguous_offsets() {
            Some((a, b)) => &src[a..b],
            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
            None => Err(Error::RequiresContiguous { op: "gather" })?,
        };
        let dim = self.dim;
        let ids_dims = self.ids_l.dims();

@ -857,7 +857,7 @@ impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> {
    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
        let src = match layout.contiguous_offsets() {
            Some((a, b)) => &src[a..b],
            None => Err(Error::RequiresContiguous { op: "index-select" }.bt())?,
            None => Err(Error::RequiresContiguous { op: "index-select" })?,
        };
        let dim = self.dim;
        let n_ids = match self.ids_l.dims() {

@ -913,7 +913,7 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
        let mut dst = vec![T::zero(); dst_len];
        copy_strided_src_(v1, &mut dst, 0, l1);
        let src = match src_l.contiguous_offsets() {
            None => Err(Error::RequiresContiguous { op: "scatter-add" }.bt())?,
            None => Err(Error::RequiresContiguous { op: "scatter-add" })?,
            Some((o1, o2)) => &src[o1..o2],
        };

@ -929,7 +929,7 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {

        let ids = match self.ids_l.contiguous_offsets() {
            Some((a, b)) => &self.ids[a..b],
            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
            None => Err(Error::RequiresContiguous { op: "gather" })?,
        };
        for left_i in 0..ids_left_len {
            let start_ids_idx = left_i * ids_right_len * ids_dim_len;

@ -971,7 +971,7 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
        let mut dst = vec![T::zero(); dst_len];
        copy_strided_src_(v1, &mut dst, 0, l1);
        let src = match src_l.contiguous_offsets() {
            None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
            None => Err(Error::RequiresContiguous { op: "index-add" })?,
            Some((o1, o2)) => &src[o1..o2],
        };
        let dim = self.dim;

@ -1256,74 +1256,6 @@ impl Map1 for Im2Col {
    }
}

struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);

impl<'a> Map2 for ConvTranspose1D<'a> {
    const OP: &'static str = "conv_transpose1d";
    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
        let p = self.0;
        let inp = &inp[inp_l.start_offset()..];
        let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
        let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
        let l_out = p.l_out();

        // Output shape: [b_size, c_out, l_out].
        let dst_elems = p.c_out * l_out * p.b_size;
        let dst = vec![T::zero(); dst_elems];
        let dst_s0 = p.c_out * l_out;
        let dst_s1 = l_out;
        let dst_s2 = 1;

        // TODO: Avoid making this copy if `inp` already has the appropriate layout.
        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
        let cont_s0 = p.l_in * p.c_in;
        let cont_s1 = p.c_in;
        for b_idx in 0..p.b_size {
            for l_idx in 0..p.l_in {
                for c_idx in 0..p.c_in {
                    let src_idx = b_idx * inp_s0 + c_idx * inp_s1 + l_idx * inp_s2;
                    let dst_idx = b_idx * cont_s0 + l_idx * cont_s1 + c_idx;
                    inp_cont[dst_idx] = inp[src_idx]
                }
            }
        }

        for k_idx in 0..p.k_size {
            (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
                let k_cont = (0..p.c_in)
                    .map(|c_in_idx| k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_idx * k_s2])
                    .collect::<Vec<_>>();
                for b_idx in 0..p.b_size {
                    for l_idx in 0..p.l_in {
                        let out_idx = l_idx * p.stride + k_idx * p.dilation;
                        if out_idx < p.padding {
                            continue;
                        }
                        let out_idx = out_idx - p.padding;
                        if out_idx < l_out {
                            let inp_cont = &inp_cont[b_idx * cont_s0 + l_idx * cont_s1..];
                            let dst_idx = b_idx * dst_s0 + out_idx * dst_s2 + dst_c_idx * dst_s1;
                            let mut d = T::zero();
                            unsafe {
                                T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
                            }
                            let dst_p = dst.as_ptr();
                            // Safety: each dst_idx is unique per dst_c_idx, which is used to
                            // parallelise the different tasks, so no two threads can try to
                            // write at the same location.
                            unsafe {
                                let ptr = dst_p.add(dst_idx) as *mut T;
                                *ptr += d
                            }
                        }
                    }
                }
            })
        }
        Ok(dst)
    }
}

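For context on the `p.l_out()` call in the transposed-convolution kernel above: the output length of a 1D transposed convolution is conventionally derived from the input length and the kernel geometry. Assuming candle follows the standard definition, it would be computed along these lines (a sketch; the real accessor lives on `ParamsConvTranspose1D` elsewhere in the crate):

// Assumed standard conv-transpose1d output length (sketch, not the crate's
// actual accessor).
fn l_out(l_in: usize, k_size: usize, stride: usize, padding: usize, output_padding: usize, dilation: usize) -> usize {
    (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + output_padding + 1
}
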
struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);

impl<'a> Map2 for Conv2D<'a> {

@ -2503,16 +2435,6 @@ impl BackendStorage for CpuStorage {
        Ok(res_t)
    }

    fn conv_transpose1d(
        &self,
        l: &Layout,
        kernel: &Self,
        kernel_l: &Layout,
        params: &crate::conv::ParamsConvTranspose1D,
    ) -> Result<Self> {
        ConvTranspose1D(params).map(self, l, kernel, kernel_l)
    }

    fn conv2d(
        &self,
        l: &Layout,

@ -2617,25 +2539,25 @@ impl BackendStorage for CpuStorage {
            Self::U8(ids) => {
                let ids = match ids_l.contiguous_offsets() {
                    Some((a, b)) => &ids[a..b],
                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
                    None => Err(Error::RequiresContiguous { op: "index-add" })?,
                };
                IndexAdd { ids, dim }.map(self, l, src, src_l)
            }
            Self::U32(ids) => {
                let ids = match ids_l.contiguous_offsets() {
                    Some((a, b)) => &ids[a..b],
                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
                    None => Err(Error::RequiresContiguous { op: "index-add" })?,
                };
                IndexAdd { ids, dim }.map(self, l, src, src_l)
            }
            Self::I64(ids) => {
                let ids = match ids_l.contiguous_offsets() {
                    Some((a, b)) => &ids[a..b],
                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
                    None => Err(Error::RequiresContiguous { op: "index-add" })?,
                };
                IndexAdd { ids, dim }.map(self, l, src, src_l)
            }
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add").bt()),
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add")),
        }
    }

@ -2681,10 +2603,6 @@ impl BackendDevice for CpuDevice {
        Ok(Self)
    }

    fn set_seed(&self, _seed: u64) -> Result<()> {
        crate::bail!("cannot seed the CPU rng with set_seed")
    }

    fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result<CpuStorage> {
        use rand::prelude::*;

@ -223,14 +223,6 @@ impl BackendDevice for CudaDevice {
        })
    }

    fn set_seed(&self, seed: u64) -> Result<()> {
        // We do not call set_seed but instead create a new curand object. This ensures that the
        // state will be identical and the same random numbers will be generated.
        let mut curand = self.curand.lock().unwrap();
        curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
        Ok(())
    }

    fn location(&self) -> crate::DeviceLocation {
        crate::DeviceLocation::Cuda {
            gpu_id: self.device.ordinal(),

@ -892,6 +884,8 @@ impl<'a> Map1 for IndexSelect<'a> {
        };
        let ids_shape = ids_l.shape();
        let ids_dims = ids_shape.dims();
        let ids_el = ids_shape.elem_count();
        let cfg = LaunchConfig::for_num_elems(ids_el as u32);
        let ds = dev.htod_copy([ids_dims, ids_l.stride()].concat()).w()?;
        let src = match src_l.contiguous_offsets() {
            Some((o1, o2)) => src.slice(o1..o2),

@ -899,23 +893,19 @@ impl<'a> Map1 for IndexSelect<'a> {
        };
        let left_size: usize = src_l.dims()[..self.2].iter().product();
        let right_size: usize = src_l.dims()[self.2 + 1..].iter().product();
        let src_dim_size = src_l.dims()[self.2];
        let ids_dim_size = ids_shape.elem_count();
        let dst_el = ids_shape.elem_count() * left_size * right_size;
        let cfg = LaunchConfig::for_num_elems(dst_el as u32);
        let dim_size = src_l.dims()[self.2];
        let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
        // SAFETY: Set later by running the kernel.
        let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
        let out = unsafe { dev.alloc::<T>(ids_el * left_size * right_size) }.w()?;
        let params = (
            dst_el,
            ids_el,
            ids_dims.len(),
            &ds,
            ids,
            &src,
            &out,
            left_size,
            src_dim_size,
            ids_dim_size,
            dim_size,
            right_size,
        );
        // SAFETY: ffi.

@ -1808,16 +1798,6 @@ impl BackendStorage for CudaStorage {
        Ok(res_t)
    }

    fn conv_transpose1d(
        &self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: &crate::conv::ParamsConvTranspose1D,
    ) -> Result<Self> {
        todo!()
    }

    #[cfg(not(feature = "cudnn"))]
    fn conv2d(
        &self,

@ -2181,7 +2161,7 @@ impl BackendStorage for CudaStorage {
        if src_l.is_contiguous() {
            dev.dtod_copy(&src, &mut dst).w()?
        } else {
            let func = dev.get_or_load_func("ucopy_f64", kernels::UNARY)?;
            let func = dev.get_or_load_func("ucopy_64", kernels::UNARY)?;
            // SAFETY: Set later by running the kernel.
            let params = (el_count, dims.len(), &ds, &src, &mut dst);
            // SAFETY: ffi.

@ -34,9 +34,6 @@ pub(crate) fn launch_conv2d<
    params: &crate::conv::ParamsConv2D,
    dev: &crate::cuda_backend::CudaDevice,
) -> crate::Result<()> {
    use crate::conv::CudnnFwdAlgo as CandleAlgo;
    use cudarc::cudnn::sys::cudnnConvolutionFwdAlgo_t as A;

    let device_id = dev.id();
    let cudnn = CUDNN.with(|cudnn| {
        if let Some(cudnn) = cudnn.borrow().get(&device_id) {

@ -93,20 +90,7 @@ pub(crate) fn launch_conv2d<
        w: &w,
        y: &y,
    };
    let alg = match params.cudnn_fwd_algo {
        None => conv2d.pick_algorithm()?,
        Some(CandleAlgo::ImplicitGemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
        Some(CandleAlgo::ImplicitPrecompGemm) => {
            A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
        }
        Some(CandleAlgo::Gemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
        Some(CandleAlgo::Direct) => A::CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
        Some(CandleAlgo::Fft) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT,
        Some(CandleAlgo::FftTiling) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
        Some(CandleAlgo::Winograd) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
        Some(CandleAlgo::WinogradNonFused) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
        Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
    };
    let alg = conv2d.pick_algorithm()?;
    let workspace_size = conv2d.get_workspace_size(alg)?;
    let mut workspace = dev.cuda_device().alloc_zeros::<u8>(workspace_size)?;
    unsafe {

@ -8,14 +8,12 @@ use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
pub enum DeviceLocation {
    Cpu,
    Cuda { gpu_id: usize },
    Metal { gpu_id: usize },
}

#[derive(Debug, Clone)]
pub enum Device {
    Cpu,
    Cuda(crate::CudaDevice),
    Metal(crate::MetalDevice),
}

pub trait NdArray {

@ -130,23 +128,10 @@ impl Device {
        Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
    }

    pub fn new_metal(ordinal: usize) -> Result<Self> {
        Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
    }

    pub fn set_seed(&self, seed: u64) -> Result<()> {
        match self {
            Self::Cpu => CpuDevice.set_seed(seed),
            Self::Cuda(c) => c.set_seed(seed),
            Self::Metal(m) => m.set_seed(seed),
        }
    }

    pub fn same_device(&self, rhs: &Self) -> bool {
        match (self, rhs) {
            (Self::Cpu, Self::Cpu) => true,
            (Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
            (Self::Metal(lhs), Self::Metal(rhs)) => lhs.same_device(rhs),
            _ => false,
        }
    }

@ -155,20 +140,21 @@ impl Device {
        match self {
            Self::Cpu => DeviceLocation::Cpu,
            Self::Cuda(device) => device.location(),
            Device::Metal(device) => device.location(),
        }
    }

    pub fn is_cpu(&self) -> bool {
        matches!(self, Self::Cpu)
        match self {
            Self::Cpu => true,
            Self::Cuda(_) => false,
        }
    }

    pub fn is_cuda(&self) -> bool {
        matches!(self, Self::Cuda(_))
    }

    pub fn is_metal(&self) -> bool {
        matches!(self, Self::Metal(_))
        match self {
            Self::Cpu => false,
            Self::Cuda(_) => true,
        }
    }

    pub fn cuda_if_available(ordinal: usize) -> Result<Self> {

@ -192,19 +178,8 @@ impl Device {
                Ok(Storage::Cpu(storage))
            }
            Device::Cuda(device) => {
                // TODO: Remove the special case if we start supporting generating f16/bf16 directly.
                if dtype == DType::F16 || dtype == DType::BF16 {
                    let storage = device.rand_uniform(shape, DType::F32, lo, up)?;
                    Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
                } else {
                    let storage = device.rand_uniform(shape, dtype, lo, up)?;
                    Ok(Storage::Cuda(storage))
                }
            }
            Device::Metal(_device) => {
                // let storage = device.rand_uniform(shape, dtype, lo, up)?;
                // Ok(Storage::Metal(storage))
                crate::bail!("Metal rand_uniform not implemented")
                let storage = device.rand_uniform(shape, dtype, lo, up)?;
                Ok(Storage::Cuda(storage))
            }
        }
    }

@ -231,18 +206,8 @@ impl Device {
                Ok(Storage::Cpu(storage))
            }
            Device::Cuda(device) => {
                // TODO: Remove the special case if we start supporting generating f16/bf16 directly.
                if dtype == DType::F16 || dtype == DType::BF16 {
                    let storage = device.rand_normal(shape, DType::F32, mean, std)?;
                    Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
                } else {
                    let storage = device.rand_normal(shape, dtype, mean, std)?;
                    Ok(Storage::Cuda(storage))
                }
            }
            Device::Metal(device) => {
                let storage = device.rand_normal(shape, dtype, mean, std)?;
                Ok(Storage::Metal(storage))
                Ok(Storage::Cuda(storage))
            }
        }
    }

@ -266,10 +231,6 @@ impl Device {
                let storage = device.ones_impl(shape, dtype)?;
                Ok(Storage::Cuda(storage))
            }
            Device::Metal(device) => {
                let storage = device.ones_impl(shape, dtype)?;
                Ok(Storage::Metal(storage))
            }
        }
    }

@ -283,10 +244,6 @@ impl Device {
                let storage = device.zeros_impl(shape, dtype)?;
                Ok(Storage::Cuda(storage))
            }
            Device::Metal(device) => {
                let storage = device.zeros_impl(shape, dtype)?;
                Ok(Storage::Metal(storage))
            }
        }
    }

@ -298,11 +255,6 @@ impl Device {
                let storage = device.storage_from_cpu_storage(&storage)?;
                Ok(Storage::Cuda(storage))
            }
            Device::Metal(device) => {
                let storage = array.to_cpu_storage();
                let storage = device.storage_from_cpu_storage(&storage)?;
                Ok(Storage::Metal(storage))
            }
        }
    }

@ -314,11 +266,6 @@ impl Device {
                let storage = device.storage_from_cpu_storage(&storage)?;
                Ok(Storage::Cuda(storage))
            }
            Device::Metal(device) => {
                let storage = S::to_cpu_storage_owned(data);
                let storage = device.storage_from_cpu_storage(&storage)?;
                Ok(Storage::Metal(storage))
            }
        }
    }
}

@ -14,9 +14,6 @@ impl Tensor {
            crate::DeviceLocation::Cuda { gpu_id } => {
                format!(", cuda:{}", gpu_id)
            }
            crate::DeviceLocation::Metal { gpu_id } => {
                format!(", metal:{}", gpu_id)
            }
        };

        write!(f, "Tensor[")?;

@ -479,9 +476,6 @@ impl std::fmt::Display for Tensor {
            crate::DeviceLocation::Cuda { gpu_id } => {
                format!(", cuda:{}", gpu_id)
            }
            crate::DeviceLocation::Metal { gpu_id } => {
                format!(", metal:{}", gpu_id)
            }
        };

        write!(

@ -67,20 +67,6 @@ impl DType {
            Self::F64 => 8,
        }
    }

    pub fn is_int(&self) -> bool {
        match self {
            Self::U8 | Self::U32 | Self::I64 => true,
            Self::BF16 | Self::F16 | Self::F32 | Self::F64 => false,
        }
    }

    pub fn is_float(&self) -> bool {
        match self {
            Self::U8 | Self::U32 | Self::I64 => false,
            Self::BF16 | Self::F16 | Self::F32 | Self::F64 => true,
        }
    }
}

pub trait WithDType:

@ -79,16 +79,6 @@ impl crate::backend::BackendStorage for CudaStorage {
        Err(Error::NotCompiledWithCudaSupport)
    }

    fn conv_transpose1d(
        &self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: &crate::conv::ParamsConvTranspose1D,
    ) -> Result<Self> {
        Err(Error::NotCompiledWithCudaSupport)
    }

    fn conv2d(
        &self,
        _: &Layout,

@ -177,10 +167,6 @@ impl crate::backend::BackendDevice for CudaDevice {
        Err(Error::NotCompiledWithCudaSupport)
    }

    fn set_seed(&self, _: u64) -> Result<()> {
        Err(Error::NotCompiledWithCudaSupport)
    }

    fn location(&self) -> crate::DeviceLocation {
        fail!()
    }

@ -1,223 +0,0 @@
#![allow(dead_code)]
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Error, Layout, Result, Shape};

#[derive(Debug, Clone)]
pub struct MetalDevice;

#[derive(Debug)]
pub struct MetalStorage;

#[derive(thiserror::Error, Debug)]
pub enum MetalError {
    #[error("{0}")]
    Message(String),
}

impl From<String> for MetalError {
    fn from(e: String) -> Self {
        MetalError::Message(e)
    }
}

macro_rules! fail {
    () => {
        unimplemented!("metal support has not been enabled, add `metal` feature to enable.")
    };
}

impl crate::backend::BackendStorage for MetalStorage {
    type Device = MetalDevice;

    fn try_clone(&self, _: &Layout) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn dtype(&self) -> DType {
        fail!()
    }

    fn device(&self) -> &Self::Device {
        fail!()
    }

    fn to_cpu_storage(&self) -> Result<CpuStorage> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn conv1d(
        &self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: &crate::conv::ParamsConv1D,
    ) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn conv_transpose1d(
        &self,
        _l: &Layout,
        _kernel: &Self,
        _kernel_l: &Layout,
        _params: &crate::conv::ParamsConvTranspose1D,
    ) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn conv2d(
        &self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: &crate::conv::ParamsConv2D,
    ) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn conv_transpose2d(
        &self,
        _l: &Layout,
        _kernel: &Self,
        _kernel_l: &Layout,
        _params: &crate::conv::ParamsConvTranspose2D,
    ) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }
    fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn scatter_add(
        &self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: usize,
    ) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn index_add(
        &self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: usize,
    ) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn matmul(
        &self,
        _: &Self,
        _: (usize, usize, usize, usize),
        _: &Layout,
        _: &Layout,
    ) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }
}

impl crate::backend::BackendDevice for MetalDevice {
    type Storage = MetalStorage;
    fn new(_: usize) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn set_seed(&self, _: u64) -> Result<()> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn location(&self) -> crate::DeviceLocation {
        fail!()
    }

    fn same_device(&self, _: &Self) -> bool {
        fail!()
    }

    fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
        Err(Error::NotCompiledWithMetalSupport)
    }
}

@ -1,4 +1,4 @@
use crate::{DType, DeviceLocation, Layout, MetalError, Shape};
use crate::{DType, DeviceLocation, Layout, Shape};

#[derive(Debug, Clone)]
pub struct MatMulUnexpectedStriding {

@ -142,9 +142,6 @@ pub enum Error {
    #[error("{op} expects at least one tensor")]
    OpRequiresAtLeastOneTensor { op: &'static str },

    #[error("{op} expects at least two tensors")]
    OpRequiresAtLeastTwoTensors { op: &'static str },

    #[error("backward is not supported for {op}")]
    BackwardNotSupported { op: &'static str },

@ -152,9 +149,6 @@ pub enum Error {
    #[error("the candle crate has not been built with cuda support")]
    NotCompiledWithCudaSupport,

    #[error("the candle crate has not been built with metal support")]
    NotCompiledWithMetalSupport,

    #[error("cannot find tensor {path}")]
    CannotFindTensor { path: String },

@ -162,9 +156,6 @@ pub enum Error {
    #[error(transparent)]
    Cuda(Box<dyn std::error::Error + Send + Sync>),

    #[error("Metal error {0}")]
    Metal(#[from] MetalError),

    #[error(transparent)]
    TryFromIntError(#[from] core::num::TryFromIntError),

@ -49,12 +49,9 @@ mod device;
pub mod display;
mod dtype;
mod dummy_cuda_backend;
mod dummy_metal_backend;
pub mod error;
mod indexer;
pub mod layout;
#[cfg(feature = "metal")]
pub mod metal_backend;
#[cfg(feature = "mkl")]
mod mkl;
pub mod npy;

@ -90,12 +87,6 @@ pub use cuda_backend::{CudaDevice, CudaStorage};
#[cfg(not(feature = "cuda"))]
pub use dummy_cuda_backend::{CudaDevice, CudaStorage};

#[cfg(feature = "metal")]
pub use metal_backend::{MetalDevice, MetalError, MetalStorage};

#[cfg(not(feature = "metal"))]
pub use dummy_metal_backend::{MetalDevice, MetalError, MetalStorage};

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

@ -134,15 +125,3 @@ impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
        self(xs)
    }
}

// A trait defining a module with a forward method using a single tensor
// argument and a flag to separate the training and evaluation behaviors.
pub trait ModuleT {
    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor>;
}

impl<M: Module> ModuleT for M {
    fn forward_t(&self, xs: &Tensor, _train: bool) -> Result<Tensor> {
        self.forward(xs)
    }
}

File diff suppressed because it is too large
@ -250,6 +250,8 @@ impl Tensor {
|
||||
if header.fortran_order {
|
||||
return Err(Error::Npy("fortran order not supported".to_string()));
|
||||
}
|
||||
let mut data: Vec<u8> = vec![];
|
||||
reader.read_to_end(&mut data)?;
|
||||
Self::from_reader(header.shape(), header.descr, &mut reader)
|
||||
}
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
#![allow(clippy::redundant_closure_call)]
|
||||
use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
|
||||
use crate::{CpuStorage, CudaStorage, Layout, Result, Shape, Tensor};
|
||||
use half::{bf16, f16};
|
||||
use num_traits::float::Float;
|
||||
|
||||
@ -58,13 +58,8 @@ pub enum UnaryOp {
|
||||
Sqr,
|
||||
Sqrt,
|
||||
Gelu,
|
||||
GeluErf,
|
||||
Erf,
|
||||
Relu,
|
||||
Tanh,
|
||||
Floor,
|
||||
Ceil,
|
||||
Round,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@ -90,16 +85,6 @@ pub enum Op {
|
||||
dilation: usize,
|
||||
},
|
||||
|
||||
#[allow(dead_code)]
|
||||
ConvTranspose1D {
|
||||
arg: Tensor,
|
||||
kernel: Tensor,
|
||||
padding: usize,
|
||||
output_padding: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
},
|
||||
|
||||
#[allow(dead_code)]
|
||||
Conv2D {
|
||||
arg: Tensor,
|
||||
@ -146,7 +131,6 @@ pub enum Op {
|
||||
Copy(Tensor),
|
||||
Broadcast(Tensor),
|
||||
Narrow(Tensor, usize, usize, usize),
|
||||
SliceScatter0(Tensor, Tensor, usize),
|
||||
Reshape(Tensor),
|
||||
ToDevice(Tensor),
|
||||
Transpose(Tensor, usize, usize),
|
||||
@ -184,18 +168,6 @@ pub trait CustomOp1 {
|
||||
))
|
||||
}
|
||||
|
||||
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
_storage: &MetalStorage,
|
||||
_layout: &Layout,
|
||||
) -> Result<(MetalStorage, Shape)> {
|
||||
Err(crate::Error::Metal(
|
||||
format!("no metal implementation for {}", self.name()).into(),
|
||||
))
|
||||
}
|
||||
|
||||
/// This function takes as argument the argument `arg` used in the forward pass, the result
|
||||
/// produced by the forward operation `res` and the gradient of the result `grad_res`.
|
||||
/// The function should return the gradient of the argument.
|
||||
@ -231,20 +203,6 @@ pub trait CustomOp2 {
|
||||
))
|
||||
}
|
||||
|
||||
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
_: &MetalStorage,
|
||||
_: &Layout,
|
||||
_: &MetalStorage,
|
||||
_: &Layout,
|
||||
) -> Result<(MetalStorage, Shape)> {
|
||||
Err(crate::Error::Metal(
|
||||
format!("no metal implementation for {}", self.name()).into(),
|
||||
))
|
||||
}
|
||||
|
||||
fn bwd(
|
||||
&self,
|
||||
_arg1: &Tensor,
|
||||
@ -287,22 +245,6 @@ pub trait CustomOp3 {
|
||||
))
|
||||
}
|
||||
|
||||
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
_: &MetalStorage,
|
||||
_: &Layout,
|
||||
_: &MetalStorage,
|
||||
_: &Layout,
|
||||
_: &MetalStorage,
|
||||
_: &Layout,
|
||||
) -> Result<(MetalStorage, Shape)> {
|
||||
Err(crate::Error::Metal(
|
||||
format!("no metal implementation for {}", self.name()).into(),
|
||||
))
|
||||
}
|
||||
|
||||
fn bwd(
|
||||
&self,
|
||||
_arg1: &Tensor,
|
||||
@ -383,13 +325,8 @@ pub(crate) struct Recip;
|
||||
pub(crate) struct Sqr;
|
||||
pub(crate) struct Sqrt;
|
||||
pub(crate) struct Gelu;
|
||||
pub(crate) struct GeluErf;
|
||||
pub(crate) struct Erf;
|
||||
pub(crate) struct Relu;
|
||||
pub(crate) struct Tanh;
|
||||
pub(crate) struct Floor;
|
||||
pub(crate) struct Ceil;
|
||||
pub(crate) struct Round;
|
||||
|
||||
macro_rules! bin_op {
|
||||
($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
|
||||
@ -588,13 +525,13 @@ unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln);
|
||||
unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
|
||||
unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
|
||||
unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh);
|
||||
unary_op!(Abs, "abs", v, v.abs());
|
||||
unary_op!(Neg, "neg", v, -v);
|
||||
unary_op!(Recip, "recip", v, v.recip());
|
||||
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
|
||||
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
|
||||
|
||||
/// Tanh based approximation of the `gelu` operation
|
||||
/// GeluErf is the more precise one.
|
||||
/// `gelu` operation
|
||||
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
|
||||
impl UnaryOpT for Gelu {
|
||||
const NAME: &'static str = "gelu";
|
||||
@ -684,212 +621,6 @@ impl UnaryOpT for Gelu {
|
||||
}
|
||||
}
|
||||
|
||||
/// `erf` operation
|
||||
/// <https://en.wikipedia.org/wiki/Error_function>
|
||||
impl UnaryOpT for Erf {
|
||||
const NAME: &'static str = "erf";
|
||||
const KERNEL: &'static str = "uerf";
|
||||
const V: Self = Erf;
|
||||
#[inline(always)]
|
||||
fn bf16(v: bf16) -> bf16 {
|
||||
bf16::from_f64(Self::f64(v.to_f64()))
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f16(v: f16) -> f16 {
|
||||
f16::from_f64(Self::f64(v.to_f64()))
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f32(v: f32) -> f32 {
|
||||
Self::f64(v as f64) as f32
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f64(v: f64) -> f64 {
|
||||
crate::cpu::erf::erf(v)
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u8(_: u8) -> u8 {
|
||||
0
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u32(_: u32) -> u32 {
|
||||
0
|
||||
}
|
||||
#[inline(always)]
|
||||
fn i64(_: i64) -> i64 {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
impl UnaryOpT for Abs {
|
||||
const NAME: &'static str = "abs";
|
||||
const KERNEL: &'static str = "uabs";
|
||||
const V: Self = Abs;
|
||||
#[inline(always)]
|
||||
fn bf16(v: bf16) -> bf16 {
|
||||
v.abs()
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f16(v: f16) -> f16 {
|
||||
v.abs()
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f32(v: f32) -> f32 {
|
||||
v.abs()
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f64(v: f64) -> f64 {
|
||||
v.abs()
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u8(v: u8) -> u8 {
|
||||
v
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u32(v: u32) -> u32 {
|
||||
v
|
||||
}
|
||||
#[inline(always)]
|
||||
fn i64(v: i64) -> i64 {
|
||||
v.abs()
|
||||
}
|
||||
}
|
||||
|
||||
impl UnaryOpT for Ceil {
|
||||
const NAME: &'static str = "ceil";
|
||||
const KERNEL: &'static str = "uceil";
|
||||
const V: Self = Ceil;
|
||||
#[inline(always)]
|
||||
fn bf16(v: bf16) -> bf16 {
|
||||
v.ceil()
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f16(v: f16) -> f16 {
|
||||
v.ceil()
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f32(v: f32) -> f32 {
|
||||
v.ceil()
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f64(v: f64) -> f64 {
|
||||
v.ceil()
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u8(v: u8) -> u8 {
|
||||
v
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u32(v: u32) -> u32 {
|
||||
v
|
||||
}
|
||||
#[inline(always)]
|
||||
fn i64(v: i64) -> i64 {
|
||||
v
|
||||
}
|
||||
}
|
||||
|
||||
impl UnaryOpT for Floor {
|
||||
const NAME: &'static str = "floor";
|
||||
const KERNEL: &'static str = "ufloor";
|
||||
const V: Self = Floor;
|
||||
#[inline(always)]
|
||||
fn bf16(v: bf16) -> bf16 {
|
||||
v.floor()
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f16(v: f16) -> f16 {
|
||||
v.floor()
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f32(v: f32) -> f32 {
|
||||
v.floor()
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f64(v: f64) -> f64 {
|
||||
v.floor()
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u8(v: u8) -> u8 {
|
||||
v
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u32(v: u32) -> u32 {
|
||||
v
|
||||
}
|
||||
#[inline(always)]
|
||||
fn i64(v: i64) -> i64 {
|
||||
v
|
||||
}
|
||||
}
|
||||
|
||||
impl UnaryOpT for Round {
|
||||
const NAME: &'static str = "round";
|
||||
const KERNEL: &'static str = "uround";
|
||||
const V: Self = Round;
|
||||
#[inline(always)]
|
||||
fn bf16(v: bf16) -> bf16 {
|
||||
v.round()
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f16(v: f16) -> f16 {
|
||||
v.round()
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f32(v: f32) -> f32 {
|
||||
v.round()
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f64(v: f64) -> f64 {
|
||||
v.round()
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u8(v: u8) -> u8 {
|
||||
v
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u32(v: u32) -> u32 {
|
||||
v
|
||||
}
|
||||
#[inline(always)]
|
||||
fn i64(v: i64) -> i64 {
|
||||
v
|
||||
}
|
||||
}
|
||||
|
||||
impl UnaryOpT for GeluErf {
|
||||
const NAME: &'static str = "gelu_erf";
|
||||
const KERNEL: &'static str = "ugelu_erf";
|
||||
const V: Self = GeluErf;
|
||||
#[inline(always)]
|
||||
fn bf16(v: bf16) -> bf16 {
|
||||
bf16::from_f64(Self::f64(v.to_f64()))
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f16(v: f16) -> f16 {
|
||||
f16::from_f64(Self::f64(v.to_f64()))
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f32(v: f32) -> f32 {
|
||||
Self::f64(v as f64) as f32
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f64(v: f64) -> f64 {
|
||||
(crate::cpu::erf::erf(v / 2f64.sqrt()) + 1.) * 0.5 * v
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u8(_: u8) -> u8 {
|
||||
0
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u32(_: u32) -> u32 {
|
||||
0
|
||||
}
|
||||
#[inline(always)]
|
||||
fn i64(_: i64) -> i64 {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
impl UnaryOpT for Relu {
|
||||
const NAME: &'static str = "relu";
|
||||
const KERNEL: &'static str = "urelu";
|
||||
@ -975,10 +706,6 @@ impl BackpropOp {
|
||||
};
|
||||
Self(op)
|
||||
}
|
||||
|
||||
pub(crate) fn is_none(&self) -> bool {
|
||||
self.0.is_none()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for BackpropOp {
|
||||
|
@ -193,50 +193,6 @@ impl Object {
|
||||
_ => Err(self),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_tensor_info(
|
||||
self,
|
||||
name: Self,
|
||||
dir_name: &std::path::Path,
|
||||
) -> Result<Option<TensorInfo>> {
|
||||
let name = match name.unicode() {
|
||||
Ok(name) => name,
|
||||
Err(_) => return Ok(None),
|
||||
};
|
||||
let (callable, args) = match self.reduce() {
|
||||
Ok(callable_args) => callable_args,
|
||||
_ => return Ok(None),
|
||||
};
|
||||
let (callable, args) = match callable {
|
||||
Object::Class {
|
||||
module_name,
|
||||
class_name,
|
||||
} if module_name == "torch._tensor" && class_name == "_rebuild_from_type_v2" => {
|
||||
let mut args = args.tuple()?;
|
||||
let callable = args.remove(0);
|
||||
let args = args.remove(1);
|
||||
(callable, args)
|
||||
}
|
||||
_ => (callable, args),
|
||||
};
|
||||
match callable {
|
||||
Object::Class {
|
||||
module_name,
|
||||
class_name,
|
||||
} if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {}
|
||||
_ => return Ok(None),
|
||||
};
|
||||
let (layout, dtype, file_path, storage_size) = rebuild_args(args)?;
|
||||
let mut path = dir_name.to_path_buf();
|
||||
path.push(file_path);
|
||||
Ok(Some(TensorInfo {
|
||||
name,
|
||||
dtype,
|
||||
layout,
|
||||
path: path.to_string_lossy().into_owned(),
|
||||
storage_size,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<Object> for String {
|
||||
@ -609,7 +565,6 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
|
||||
"HalfStorage" => DType::F16,
|
||||
"BFloat16Storage" => DType::BF16,
|
||||
"ByteStorage" => DType::U8,
|
||||
"LongStorage" => DType::I64,
|
||||
other => {
|
||||
crate::bail!("unsupported storage type {other}")
|
||||
}
|
||||
@@ -668,10 +623,50 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
    };
    if let Object::Dict(key_values) = obj {
        for (name, value) in key_values.into_iter() {
            match value.into_tensor_info(name, &dir_name) {
                Ok(Some(tensor_info)) => tensor_infos.push(tensor_info),
                Ok(None) => {}
                Err(err) => eprintln!("skipping: {err:?}"),
            let name = match name.unicode() {
                Ok(name) => name,
                Err(_) => continue,
            };
            let (callable, args) = match value.reduce() {
                Ok(callable_args) => callable_args,
                _ => continue,
            };
            let (callable, args) = match callable {
                Object::Class {
                    module_name,
                    class_name,
                } if module_name == "torch._tensor"
                    && class_name == "_rebuild_from_type_v2" =>
                {
                    let mut args = args.tuple()?;
                    let callable = args.remove(0);
                    let args = args.remove(1);
                    (callable, args)
                }
                _ => (callable, args),
            };
            match callable {
                Object::Class {
                    module_name,
                    class_name,
                } if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {}
                _ => continue,
            };
            match rebuild_args(args) {
                Ok((layout, dtype, file_path, storage_size)) => {
                    let mut path = dir_name.clone();
                    path.push(file_path);
                    tensor_infos.push(TensorInfo {
                        name,
                        dtype,
                        layout,
                        path: path.to_string_lossy().into_owned(),
                        storage_size,
                    })
                }
                Err(err) => {
                    eprintln!("skipping {name}: {err:?}")
                }
            }
        }
    }
@@ -728,16 +723,3 @@ impl PthTensors {
        Ok(Some(tensor))
    }
}

/// Read all the tensors from a PyTorch pth file.
pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
    let pth = PthTensors::new(path)?;
    let tensor_names = pth.tensor_infos.keys();
    let mut tensors = Vec::with_capacity(tensor_names.len());
    for name in tensor_names {
        if let Some(tensor) = pth.get(name)? {
            tensors.push((name.to_string(), tensor))
        }
    }
    Ok(tensors)
}
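A usage sketch for the `read_all` helper shown above (assuming the crate is pulled in as `candle_core` and a local `model.pth` file exists):

```rust
use candle_core::pickle;

fn main() -> candle_core::Result<()> {
    // Load every tensor stored in a PyTorch checkpoint.
    let tensors = pickle::read_all("model.pth")?;
    for (name, tensor) in tensors.iter() {
        println!("{name}: {:?}", tensor.shape());
    }
    Ok(())
}
```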
@@ -50,9 +50,14 @@ pub(crate) unsafe fn mul_sum_i8_pairs_float(x: __m256i, y: __m256i) -> __m256 {
#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
    let qk = QK8_0;
    let nb = n / qk;
    if n % QK8_0 != 0 {
        crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
    }
    if nb % 2 != 0 {
        crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
    }

    unsafe {
        let mut acc = _mm256_setzero_ps();
        for (x, y) in xs.iter().zip(ys.iter()) {
@@ -633,35 +638,3 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res
        Ok(hsum_float_8(acc) + summs)
    }
}

#[inline(always)]
pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
    let qk = QK_K;
    if n % qk != 0 {
        crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
    }

    unsafe {
        let mut acc = _mm256_setzero_ps();
        for (xs, ys) in xs.iter().zip(ys.iter()) {
            let mut sumi = _mm256_setzero_si256();
            let x_qs = xs.qs.as_ptr();
            let y_qs = ys.qs.as_ptr();
            for j in (0..QK_K).step_by(32) {
                let xs = _mm256_loadu_si256(x_qs.add(j) as *const __m256i);
                let ys = _mm256_loadu_si256(y_qs.add(j) as *const __m256i);

                let xs0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xs, 0));
                let ys0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(ys, 0));
                sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(xs0, ys0));

                let xs1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xs, 1));
                let ys1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(ys, 1));
                sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(xs1, ys1));
            }
            let d = _mm256_set1_ps(xs.d * ys.d);
            acc = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi), acc);
        }
        Ok(hsum_float_8(acc))
    }
}
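The AVX2 routine above widens each 32-byte chunk of i8 quants to i16 and relies on `_mm256_madd_epi16` to multiply-accumulate adjacent pairs into i32 lanes. A scalar sketch of the same reduction (illustrative only, using a simplified block type rather than the crate's `BlockQ8K`):

```rust
/// Simplified stand-in for a Q8K block: 256 i8 quants plus an f32 scale.
struct Block {
    d: f32,
    qs: [i8; 256],
}

/// Scalar equivalent of the SIMD kernel: integer dot per block, scaled by d_x * d_y.
fn dot(xs: &[Block], ys: &[Block]) -> f32 {
    xs.iter()
        .zip(ys.iter())
        .map(|(x, y)| {
            let sumi: i32 = x
                .qs
                .iter()
                .zip(y.qs.iter())
                .map(|(&a, &b)| a as i32 * b as i32)
                .sum();
            sumi as f32 * x.d * y.d
        })
        .sum()
}
```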
@@ -135,13 +135,7 @@ pub fn qtensor_from_ggml(
    dims: Vec<usize>,
) -> Result<super::QTensor> {
    let tensor_elems = dims.iter().product::<usize>();
    let blck_size = ggml_dtype.blck_size();
    if tensor_elems % blck_size != 0 {
        crate::bail!(
            "the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
        )
    }
    let size_in_bytes = tensor_elems / blck_size * ggml_dtype.type_size();
    let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();

    match ggml_dtype {
        GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),
@@ -29,7 +29,6 @@ impl TryFrom<u32> for Magic {
pub enum VersionedMagic {
    GgufV1,
    GgufV2,
    GgufV3,
}

impl VersionedMagic {
@@ -40,7 +39,6 @@ impl VersionedMagic {
        let versioned_magic = match (magic, version) {
            (Magic::Gguf, 1) => Self::GgufV1,
            (Magic::Gguf, 2) => Self::GgufV2,
            (Magic::Gguf, 3) => Self::GgufV3,
            _ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"),
        };
        Ok(versioned_magic)
@@ -61,13 +59,8 @@ impl TensorInfo {
        tensor_data_offset: u64,
    ) -> Result<QTensor> {
        let tensor_elems = self.shape.elem_count();
        let blck_size = self.ggml_dtype.blck_size();
        if tensor_elems % blck_size != 0 {
            crate::bail!(
                "the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
            )
        }
        let size_in_bytes = tensor_elems / blck_size * self.ggml_dtype.type_size();
        let size_in_bytes =
            tensor_elems * self.ggml_dtype.type_size() / self.ggml_dtype.blck_size();
        let mut raw_data = vec![0u8; size_in_bytes];
        reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
        reader.read_exact(&mut raw_data)?;
@@ -86,9 +79,7 @@ pub struct Content {
fn read_string<R: std::io::Read>(reader: &mut R, magic: &VersionedMagic) -> Result<String> {
    let len = match magic {
        VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
        VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
            reader.read_u64::<LittleEndian>()? as usize
        }
        VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
    };
    let mut v = vec![0u8; len];
    reader.read_exact(&mut v)?;
@@ -288,9 +279,7 @@ impl Value {
        let value_type = ValueType::from_u32(value_type)?;
        let len = match magic {
            VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
            VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
                reader.read_u64::<LittleEndian>()? as usize
            }
            VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
        };
        let mut vs = Vec::with_capacity(len);
        for _ in 0..len {
@@ -387,15 +376,11 @@ impl Content {

        let tensor_count = match magic {
            VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
            VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
                reader.read_u64::<LittleEndian>()? as usize
            }
            VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
        };
        let metadata_kv_count = match magic {
            VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
            VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
                reader.read_u64::<LittleEndian>()? as usize
            }
            VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
        };

        let mut metadata = HashMap::new();
@@ -417,7 +402,7 @@ impl Content {
                reader.read_u32_into::<LittleEndian>(&mut dimensions)?;
                dimensions.into_iter().map(|c| c as usize).collect()
            }
            VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
            VersionedMagic::GgufV2 => {
                let mut dimensions = vec![0; n_dimensions as usize];
                reader.read_u64_into::<LittleEndian>(&mut dimensions)?;
                dimensions.into_iter().map(|c| c as usize).collect()
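The recurring pattern in these hunks: GGUF v1 encodes lengths and counts as u32 while v2 (and v3) use u64, and dropping the `GgufV3` arms narrows the reader back to v1/v2. A minimal sketch of the same dispatch (a `Version` enum standing in for the crate's `VersionedMagic`):

```rust
use byteorder::{LittleEndian, ReadBytesExt};

enum Version {
    V1,
    V2,
}

// GGUF v1 stores lengths as u32, later versions as u64.
fn read_len<R: std::io::Read>(reader: &mut R, version: &Version) -> std::io::Result<usize> {
    Ok(match version {
        Version::V1 => reader.read_u32::<LittleEndian>()? as usize,
        Version::V2 => reader.read_u64::<LittleEndian>()? as usize,
    })
}
```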
@@ -34,9 +34,6 @@ pub trait GgmlType: Sized + Clone + Send + Sync {
    /// Dot product used as a building block for quantized mat-mul.
    /// n is the number of elements to be considered.
    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32>;

    /// Generic implementation of the dot product without simd optimizations.
    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32>;
}

#[derive(Debug, Clone, PartialEq)]
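For context on the block types these impls operate on: a Q4_0 block packs 32 weights as 4-bit quants into 16 bytes plus a single f16 scale, so each value dequantizes as `d * (q - 8)` — low nibbles give the first 16 values, high nibbles the next 16 (matching the `vandq_u8`/`vshrq_n_u8` split in the NEON kernel below). An illustrative stand-alone sketch with a toy struct, not the crate's `BlockQ4_0`:

```rust
use half::f16;

/// Toy Q4_0 block: one scale for 32 weights, two 4-bit quants per byte.
struct Q4_0 {
    d: f16,
    qs: [u8; 16],
}

/// Dequantize the 32 values of a block: low nibbles first, then high nibbles.
fn dequantize(block: &Q4_0) -> [f32; 32] {
    let d = block.d.to_f32();
    let mut out = [0f32; 32];
    for (i, &b) in block.qs.iter().enumerate() {
        out[i] = ((b & 0x0F) as i32 - 8) as f32 * d;
        out[i + 16] = ((b >> 4) as i32 - 8) as f32 * d;
    }
    out
}
```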
@@ -228,17 +225,15 @@ impl GgmlType for BlockQ4_0 {
        #[cfg(target_feature = "neon")]
        return super::neon::vec_dot_q4_0_q8_0(n, xs, ys);

        #[cfg(target_feature = "simd128")]
        return super::simd128::vec_dot_q4_0_q8_0(n, xs, ys);

        Self::vec_dot_unopt(n, xs, ys)
    }

    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        let qk = QK8_0;
        let nb = n / qk;
        if n % QK8_0 != 0 {
            crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
        }
        if nb % 2 != 0 {
            crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
        }

        // Generic implementation.
        let mut sumf = 0f32;
        for (xs, ys) in xs.iter().zip(ys.iter()) {
@@ -260,10 +255,6 @@ impl GgmlType for BlockQ4_1 {
    type VecDotType = BlockQ8_1;

    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        Self::vec_dot_unopt(n, xs, ys)
    }

    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        // ggml_vec_dot_q4_1_q8_1
        let qk = QK8_1;
        if n % qk != 0 {
@@ -363,10 +354,7 @@ impl GgmlType for BlockQ5_0 {
        if nb % 2 != 0 {
            crate::bail!("vec_dot_q5_0_q8_0: {n}, nb is not divisible by 2")
        }
        Self::vec_dot_unopt(n, xs, ys)
    }

    fn vec_dot_unopt(_n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        // Generic implementation.
        let mut sumf = 0f32;

@@ -457,10 +445,6 @@ impl GgmlType for BlockQ5_1 {
    type VecDotType = BlockQ8_1;

    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        Self::vec_dot_unopt(n, xs, ys)
    }

    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        let qk = Self::BLCK_SIZE;
        if n % Self::BLCK_SIZE != 0 {
            crate::bail!("vec_dot_q5_1_q8_1: {n} is not divisible by {qk}")
@@ -622,13 +606,6 @@ impl GgmlType for BlockQ8_0 {
        #[cfg(target_feature = "neon")]
        return super::neon::vec_dot_q8_0_q8_0(n, xs, ys);

        #[cfg(target_feature = "simd128")]
        return super::simd128::vec_dot_q8_0_q8_0(n, xs, ys);

        Self::vec_dot_unopt(n, xs, ys)
    }

    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        let qk = QK8_0;
        if n % QK8_0 != 0 {
            crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
@@ -654,11 +631,7 @@ impl GgmlType for BlockQ8_1 {
    const BLCK_SIZE: usize = QK8_1;
    type VecDotType = BlockQ8_1;

    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        Self::vec_dot_unopt(n, xs, ys)
    }

    fn vec_dot_unopt(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
    fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
        unimplemented!("no support for vec-dot on Q8_1")
    }

@@ -708,13 +681,6 @@ impl GgmlType for BlockQ2K {
        #[cfg(target_feature = "neon")]
        return super::neon::vec_dot_q2k_q8k(n, xs, ys);

        #[cfg(target_feature = "simd128")]
        return super::simd128::vec_dot_q2k_q8k(n, xs, ys);

        Self::vec_dot_unopt(n, xs, ys)
    }

    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        if n % QK_K != 0 {
            crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
        }
@@ -735,17 +701,18 @@ impl GgmlType for BlockQ2K {

        let mut isum = 0;
        let mut is = 0;
        let mut d;
        for _ in 0..(QK_K / 128) {
            let mut shift = 0;
            for _ in 0..4 {
                let d = (sc[is] & 0xF) as i32;
                d = (sc[is] & 0xF) as i32;
                is += 1;
                let mut isuml = 0;
                for l in 0..16 {
                    isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32);
                }
                isum += d * isuml;
                let d = (sc[is] & 0xF) as i32;
                d = (sc[is] & 0xF) as i32;
                is += 1;
                isuml = 0;
                for l in 16..32 {
@@ -884,10 +851,6 @@ impl GgmlType for BlockQ3K {
        #[cfg(target_feature = "neon")]
        return super::neon::vec_dot_q3k_q8k(n, xs, ys);

        Self::vec_dot_unopt(n, xs, ys)
    }

    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        if n % QK_K != 0 {
            crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}")
        }
@@ -1114,6 +1077,7 @@ impl GgmlType for BlockQ3K {
            let d_all = block.d.to_f32();
            let mut m = 1;
            let mut is = 0;
            let mut dl;

            // Dequantize both 128 long blocks
            // 32 qs values per 128 long block
@@ -1124,7 +1088,7 @@ impl GgmlType for BlockQ3K {
                    for (scale_index, scale_scoped_y) in
                        shift_scoped_y.chunks_exact_mut(16).enumerate()
                    {
                        let dl = d_all * (scales[is] as f32 - 32.0);
                        dl = d_all * (scales[is] as f32 - 32.0);
                        for (i, inner_y) in scale_scoped_y.iter_mut().enumerate() {
                            let new_y = dl
                                * (((qs[i + 16 * scale_index] >> shift) & 3) as i8
@@ -1162,13 +1126,6 @@ impl GgmlType for BlockQ4K {
        #[cfg(target_feature = "neon")]
        return super::neon::vec_dot_q4k_q8k(n, xs, ys);

        #[cfg(target_feature = "simd128")]
        return super::simd128::vec_dot_q4k_q8k(n, xs, ys);

        Self::vec_dot_unopt(n, xs, ys)
    }

    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        if n % QK_K != 0 {
            crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
        }
@@ -1355,10 +1312,6 @@ impl GgmlType for BlockQ5K {
        #[cfg(target_feature = "neon")]
        return super::neon::vec_dot_q5k_q8k(n, xs, ys);

        Self::vec_dot_unopt(n, xs, ys)
    }

    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        if n % QK_K != 0 {
            crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}")
        }
@@ -1576,13 +1529,6 @@ impl GgmlType for BlockQ6K {
        #[cfg(target_feature = "neon")]
        return super::neon::vec_dot_q6k_q8k(n, xs, ys);

        #[cfg(target_feature = "simd128")]
        return super::simd128::vec_dot_q6k_q8k(n, xs, ys);

        Self::vec_dot_unopt(n, xs, ys)
    }

    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        if n % QK_K != 0 {
            crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}")
        }
@@ -1751,38 +1697,8 @@ impl GgmlType for BlockQ8K {
    const BLCK_SIZE: usize = QK_K;
    type VecDotType = BlockQ8K;

    #[allow(unreachable_code)]
    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        #[cfg(target_feature = "avx")]
        return super::avx::vec_dot_q8k_q8k(n, xs, ys);

        #[cfg(target_feature = "neon")]
        return super::neon::vec_dot_q8k_q8k(n, xs, ys);

        #[cfg(target_feature = "simd128")]
        return super::simd128::vec_dot_q8k_q8k(n, xs, ys);

        Self::vec_dot_unopt(n, xs, ys)
    }

    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        let qk = QK_K;
        if n % QK_K != 0 {
            crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
        }

        // Generic implementation.
        let mut sumf = 0f32;
        for (xs, ys) in xs.iter().zip(ys.iter()) {
            let sum_i = xs
                .qs
                .iter()
                .zip(ys.qs.iter())
                .map(|(&x, &y)| x as i32 * y as i32)
                .sum::<i32>();
            sumf += sum_i as f32 * xs.d * ys.d
        }
        Ok(sumf)
    fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
        unreachable!()
    }

    fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
@@ -1888,10 +1804,6 @@ impl GgmlType for f32 {
    type VecDotType = f32;

    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        Self::vec_dot_unopt(n, xs, ys)
    }

    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        if xs.len() < n {
            crate::bail!("size mismatch {} < {n}", xs.len())
        }
@@ -1926,10 +1838,6 @@ impl GgmlType for f16 {
    type VecDotType = f16;

    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        Self::vec_dot_unopt(n, xs, ys)
    }

    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        if xs.len() < n {
            crate::bail!("size mismatch {} < {n}", xs.len())
        }
@@ -7,8 +7,6 @@ pub mod gguf_file;
pub mod k_quants;
#[cfg(target_feature = "neon")]
pub mod neon;
#[cfg(target_feature = "simd128")]
pub mod simd128;
pub mod utils;

pub use k_quants::GgmlType;
@@ -231,40 +229,20 @@ impl QTensor {
    }
}

#[derive(Clone, Debug)]
pub enum QMatMul {
    QTensor(std::sync::Arc<QTensor>),
    Tensor(Tensor),
}

thread_local! {
    static DEQUANTIZE_ALL: bool = {
        match std::env::var("CANDLE_DEQUANTIZE_ALL") {
            Ok(s) => {
                !s.is_empty() && s != "0"
            },
            Err(_) => false,
        }
    }
}
#[derive(Debug)]
pub struct QMatMul(std::sync::Arc<QTensor>);

impl QMatMul {
    pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
        let dequantize = match qtensor.dtype() {
            GgmlDType::F32 | GgmlDType::F16 => true,
            _ => DEQUANTIZE_ALL.with(|b| *b),
        };
        let t = if dequantize {
            let tensor = qtensor.dequantize(&Device::Cpu)?;
            Self::Tensor(tensor)
        } else {
            Self::QTensor(qtensor)
        };
        Ok(t)
    pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Self {
        Self(qtensor)
    }

    pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
        Self::from_arc(std::sync::Arc::new(qtensor))
    pub fn from_qtensor(qtensor: QTensor) -> Self {
        Self(std::sync::Arc::new(qtensor))
    }

    pub fn inner(&self) -> &std::sync::Arc<QTensor> {
        &self.0
    }
}

@@ -309,16 +287,6 @@ impl crate::CustomOp1 for QTensor {

impl QMatMul {
    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
            Self::Tensor(w) => {
                let w = match *xs.dims() {
                    [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
                    [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
                    _ => w.t()?,
                };
                xs.matmul(&w)
            }
        }
        xs.apply_op1_no_bwd(self.0.as_ref())
    }
}
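A usage sketch for `QMatMul` assuming the post-change API shown above, where `from_qtensor` returns `Self` directly (the function and tensor names here are hypothetical):

```rust
use candle_core::quantized::{QMatMul, QTensor};
use candle_core::{Result, Tensor};

fn project(qtensor: QTensor, xs: &Tensor) -> Result<Tensor> {
    // Wrap the quantized weights; the matmul runs through a custom op,
    // so the weights are never fully dequantized up front.
    let mm = QMatMul::from_qtensor(qtensor);
    mm.forward(xs)
}
```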
@@ -19,29 +19,42 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
    if n % QK8_0 != 0 {
        crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
    }
    if nb % 2 != 0 {
        crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
    }

    unsafe {
        let mut sumv0 = vdupq_n_f32(0.0f32);
        for i in 0..nb {
        let mut sumv1 = vdupq_n_f32(0.0f32);
        for i in (0..nb).step_by(2) {
            let x0 = &xs[i];
            let x1 = &xs[i + 1];
            let y0 = &ys[i];
            let y1 = &ys[i + 1];

            let m4b = vdupq_n_u8(0x0F);
            let s8b = vdupq_n_s8(0x8);

            let v0_0 = vld1q_u8(x0.qs.as_ptr());
            let v0_1 = vld1q_u8(x1.qs.as_ptr());

            // 4-bit -> 8-bit
            let v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
            let v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
            let v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
            let v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

            // sub 8
            let v0_0ls = vsubq_s8(v0_0l, s8b);
            let v0_0hs = vsubq_s8(v0_0h, s8b);
            let v0_1ls = vsubq_s8(v0_1l, s8b);
            let v0_1hs = vsubq_s8(v0_1h, s8b);

            // load y
            let v1_0l = vld1q_s8(y0.qs.as_ptr());
            let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));
            let v1_1l = vld1q_s8(y1.qs.as_ptr());
            let v1_1h = vld1q_s8(y1.qs.as_ptr().add(16));

            // TODO: Support dotprod when it's available outside of nightly.
            let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l));
@@ -49,16 +62,28 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
            let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h));
            let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));

            let pl1l = vmull_s8(vget_low_s8(v0_1ls), vget_low_s8(v1_1l));
            let pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
            let ph1l = vmull_s8(vget_low_s8(v0_1hs), vget_low_s8(v1_1h));
            let ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));

            let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
            let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
            let pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
            let ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));

            sumv0 = vmlaq_n_f32(
                sumv0,
                vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
                x0.d.to_f32() * y0.d.to_f32(),
            );
            sumv1 = vmlaq_n_f32(
                sumv1,
                vcvtq_f32_s32(vaddq_s32(pl1, ph1)),
                x1.d.to_f32() * y1.d.to_f32(),
            );
        }
        Ok(vaddvq_f32(sumv0))
        Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1))
    }
}

@@ -69,18 +94,28 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
        crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
    }
    let nb = n / QK8_0;
    if nb % 2 != 0 {
        crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even")
    }
    unsafe {
        let mut sumv0 = vdupq_n_f32(0.0f32);
        for i in 0..nb {
        let mut sumv1 = vdupq_n_f32(0.0f32);
        for i in (0..nb).step_by(2) {
            let x0 = &xs[i];
            let x1 = &xs[i + 1];
            let y0 = &ys[i];
            let y1 = &ys[i + 1];

            let x0_0 = vld1q_s8(x0.qs.as_ptr());
            let x0_1 = vld1q_s8(x0.qs.as_ptr().add(16));
            let x1_0 = vld1q_s8(x1.qs.as_ptr());
            let x1_1 = vld1q_s8(x1.qs.as_ptr().add(16));

            // load y
            let y0_0 = vld1q_s8(y0.qs.as_ptr());
            let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));
            let y1_0 = vld1q_s8(y1.qs.as_ptr());
            let y1_1 = vld1q_s8(y1.qs.as_ptr().add(16));

            // TODO: use dotprod once the intrinsics are available.
            let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0));
@@ -88,48 +123,31 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
            let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1));
            let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));

            let p1_0 = vmull_s8(vget_low_s8(x1_0), vget_low_s8(y1_0));
            let p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
            let p1_2 = vmull_s8(vget_low_s8(x1_1), vget_low_s8(y1_1));
            let p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));

            let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
            let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
            let p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
            let p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));

            sumv0 = vmlaq_n_f32(
                sumv0,
                vcvtq_f32_s32(vaddq_s32(p0, p1)),
                x0.d.to_f32() * y0.d.to_f32(),
            );
            sumv1 = vmlaq_n_f32(
                sumv1,
                vcvtq_f32_s32(vaddq_s32(p2, p3)),
                x1.d.to_f32() * y1.d.to_f32(),
            );
        }
        Ok(vaddvq_f32(sumv0))
        Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1))
    }
}

#[inline(always)]
pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
    let qk = QK_K;
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
    }

    let mut sumf = 0f32;
    for (xs, ys) in xs.iter().zip(ys.iter()) {
        unsafe {
            let mut sum_i = vdupq_n_s32(0);
            let scale = xs.d * ys.d;
            let xs = xs.qs.as_ptr();
            let ys = ys.qs.as_ptr();
            for i in (0..QK_K).step_by(16) {
                let xs = vld1q_s8(xs.add(i));
                let ys = vld1q_s8(ys.add(i));
                let xy_lo = vmull_s8(vget_low_s8(xs), vget_low_s8(ys));
                let xy_up = vmull_s8(vget_high_s8(xs), vget_high_s8(ys));

                let xy = vaddq_s32(vpaddlq_s16(xy_lo), vpaddlq_s16(xy_up));
                sum_i = vaddq_s32(sum_i, xy)
            }
            sumf += vaddvq_s32(sum_i) as f32 * scale
        }
    }
    Ok(sumf)
}

#[inline(always)]
pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
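The change visible in these NEON hunks unrolls the loops by two blocks, with independent `sumv0`/`sumv1` accumulators that hide the latency of the multiply-accumulate chain; the even-block-count check guards the `i + 1` access. A scalar sketch of the same pattern (illustrative only):

```rust
// Process elements two at a time with independent accumulators so the
// two dependency chains can execute in parallel.
fn dot_unrolled(xs: &[f32], ys: &[f32]) -> f32 {
    assert_eq!(xs.len(), ys.len());
    assert_eq!(xs.len() % 2, 0);
    let (mut acc0, mut acc1) = (0f32, 0f32);
    for i in (0..xs.len()).step_by(2) {
        acc0 += xs[i] * ys[i];
        acc1 += xs[i + 1] * ys[i + 1];
    }
    acc0 + acc1
}
```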
@@ -1,419 +0,0 @@
use super::k_quants::{BlockQ2K, BlockQ4K, BlockQ4_0, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K};
use crate::Result;
use byteorder::{ByteOrder, LittleEndian};
use half::f16;

use core::arch::wasm32::*;

#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
    let qk = QK8_0;
    if n % QK8_0 != 0 {
        crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
    }
    unsafe {
        let mut acc = f32x4_splat(0.0f32);
        for (x, y) in xs.iter().zip(ys.iter()) {
            let x1234 = v128_load(x.qs.as_ptr() as *const v128);
            let x12 = v128_and(x1234, u8x16_splat(0x0F));
            let x12 = i8x16_sub(x12, i8x16_splat(8));
            let x34 = u8x16_shr(x1234, 4);
            let x34 = i8x16_sub(x34, i8x16_splat(8));

            let x1 = i16x8_extend_low_i8x16(x12);
            let y1 = i16x8_load_extend_i8x8(y.qs.as_ptr());
            let sum_xy = i32x4_dot_i16x8(x1, y1);

            let x2 = i16x8_extend_high_i8x16(x12);
            let y2 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(8));
            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x2, y2));

            let x3 = i16x8_extend_low_i8x16(x34);
            let y3 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(16));
            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x3, y3));

            let x4 = i16x8_extend_high_i8x16(x34);
            let y4 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(24));
            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x4, y4));

            let sum_xy = f32x4_convert_i32x4(sum_xy);

            // f32x4_relaxed_madd is nightly only.
            let d = f32x4_splat(f16::to_f32(x.d) * f16::to_f32(y.d));
            let scaled = f32x4_mul(sum_xy, d);
            acc = f32x4_add(acc, scaled)
        }
        let res = f32x4_extract_lane::<0>(acc)
            + f32x4_extract_lane::<1>(acc)
            + f32x4_extract_lane::<2>(acc)
            + f32x4_extract_lane::<3>(acc);
        Ok(res)
    }
}

#[inline(always)]
pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
    let qk = QK8_0;
    if n % QK8_0 != 0 {
        crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
    }
    unsafe {
        let mut acc = f32x4_splat(0.0f32);
        for (x, y) in xs.iter().zip(ys.iter()) {
            let x1 = i16x8_load_extend_i8x8(x.qs.as_ptr());
            let y1 = i16x8_load_extend_i8x8(y.qs.as_ptr());
            let sum_xy = i32x4_dot_i16x8(x1, y1);

            let x2 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(8));
            let y2 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(8));
            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x2, y2));

            let x3 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(16));
            let y3 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(16));
            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x3, y3));

            let x4 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(24));
            let y4 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(24));
            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x4, y4));

            let sum_xy = f32x4_convert_i32x4(sum_xy);

            // f32x4_relaxed_madd is nightly only.
            let d = f32x4_splat(f16::to_f32(x.d) * f16::to_f32(y.d));
            let scaled = f32x4_mul(sum_xy, d);
            acc = f32x4_add(acc, scaled)
        }
        let res = f32x4_extract_lane::<0>(acc)
            + f32x4_extract_lane::<1>(acc)
            + f32x4_extract_lane::<2>(acc)
            + f32x4_extract_lane::<3>(acc);
        Ok(res)
    }
}

#[inline(always)]
pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
    }
    unsafe {
        let mut sumf = f32x4_splat(0f32);
        for (x, y) in xs.iter().zip(ys.iter()) {
            let mut q2: &[_] = &x.qs;
            let mut q8: &[_] = &y.qs;
            let sc = &x.scales;

            let mut summs = i32x4_splat(0);
            for i in (0..(QK_K / 16)).step_by(4) {
                let bsums = i32x4_load_extend_i16x4(y.bsums.as_ptr().add(i));
                let scales = i32x4_shr(
                    i32x4(
                        sc[i] as i32,
                        sc[i + 1] as i32,
                        sc[i + 2] as i32,
                        sc[i + 3] as i32,
                    ),
                    4,
                );
                summs = i32x4_add(summs, i32x4_mul(bsums, scales))
            }
            let summs = f32x4_convert_i32x4(summs);

            let dall = y.d * x.d.to_f32();
            let dmin = y.d * x.dmin.to_f32();

            let mut isum = i32x4_splat(0);
            let mut is = 0;
            for _ in 0..(QK_K / 128) {
                let mut shift = 0;
                for _ in 0..4 {
                    let d = (sc[is] & 0xF) as i32;
                    is += 1;
                    let mut isuml = i16x8_splat(0);
                    for l in (0..16).step_by(8) {
                        let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(l));
                        let q2 = i16x8_load_extend_u8x8(q2.as_ptr().add(l));
                        let q2 = v128_and(i16x8_shr(q2, shift), i16x8_splat(3));
                        isuml = i16x8_add(isuml, i16x8_mul(q2, q8))
                    }
                    let dd = i32x4_splat(d);
                    isum = i32x4_add(isum, i32x4_mul(i32x4_extend_low_i16x8(isuml), dd));
                    isum = i32x4_add(isum, i32x4_mul(i32x4_extend_high_i16x8(isuml), dd));
                    let d = (sc[is] & 0xF) as i32;
                    is += 1;
                    let mut isuml = i16x8_splat(0);
                    for l in (16..32).step_by(8) {
                        let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(l));
                        let q2 = i16x8_load_extend_u8x8(q2.as_ptr().add(l));
                        let q2 = v128_and(i16x8_shr(q2, shift), i16x8_splat(3));
                        isuml = i16x8_add(isuml, i16x8_mul(q2, q8))
                    }
                    let dd = i32x4_splat(d);
                    isum = i32x4_add(isum, i32x4_mul(i32x4_extend_low_i16x8(isuml), dd));
                    isum = i32x4_add(isum, i32x4_mul(i32x4_extend_high_i16x8(isuml), dd));
                    shift += 2;
                    // adjust the indexing
                    q8 = &q8[32..];
                }
                // adjust the indexing
                q2 = &q2[32..];
            }
            let isum = f32x4_convert_i32x4(isum);
            sumf = f32x4_add(
                sumf,
                f32x4_sub(
                    f32x4_mul(isum, f32x4_splat(dall)),
                    f32x4_mul(summs, f32x4_splat(dmin)),
                ),
            );
        }
        let sumf = f32x4_extract_lane::<0>(sumf)
            + f32x4_extract_lane::<1>(sumf)
            + f32x4_extract_lane::<2>(sumf)
            + f32x4_extract_lane::<3>(sumf);
        Ok(sumf)
    }
}

#[inline(always)]
pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
    }

    const KMASK1: u32 = 0x3f3f3f3f;
    const KMASK2: u32 = 0x0f0f0f0f;
    const KMASK3: u32 = 0x03030303;

    let mut utmp: [u32; 4] = [0; 4];
    let mut scales: [u8; 8] = [0; 8];
    let mut mins: [u8; 8] = [0; 8];

    let mut aux8: [u8; QK_K] = [0; QK_K];
    let mut sums = f32x4_splat(0f32);
    unsafe {
        for (y, x) in ys.iter().zip(xs.iter()) {
            let q4 = &x.qs;
            let q8 = &y.qs;

            for j in 0..QK_K / 64 {
                let q4_1 = v128_load(q4.as_ptr().add(32 * j) as *const v128);
                let q4_2 = v128_load(q4.as_ptr().add(32 * j + 16) as *const v128);
                v128_store(
                    aux8.as_mut_ptr().add(64 * j) as *mut v128,
                    v128_and(q4_1, u8x16_splat(0x0F)),
                );
                v128_store(
                    aux8.as_mut_ptr().add(64 * j + 16) as *mut v128,
                    v128_and(q4_2, u8x16_splat(0x0F)),
                );
                v128_store(
                    aux8.as_mut_ptr().add(64 * j + 32) as *mut v128,
                    u8x16_shr(q4_1, 4),
                );
                v128_store(
                    aux8.as_mut_ptr().add(64 * j + 48) as *mut v128,
                    u8x16_shr(q4_2, 4),
                );
            }

            LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);

            utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
            let uaux = utmp[1] & KMASK1;
            utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
            utmp[2] = uaux;
            utmp[0] &= KMASK1;

            // extract scales and mins
            LittleEndian::write_u32_into(&utmp[0..2], &mut scales);
            LittleEndian::write_u32_into(&utmp[2..4], &mut mins);

            let mut sumi = i32x4_splat(0);
            for j in (0..QK_K / 16).step_by(4) {
                let bsums = i32x4_load_extend_i16x4(y.bsums.as_ptr().add(j));
                let (m1, m2) = (mins[j / 2] as i32, mins[j / 2 + 1] as i32);
                let mins = i32x4(m1, m1, m2, m2);
                sumi = i32x4_add(sumi, i32x4_mul(bsums, mins));
            }

            let mut aux32 = i32x4_splat(0i32);
            for (scale_i, scale) in scales.iter().enumerate() {
                let scale = i32x4_splat(*scale as i32);
                for j in 0..4 {
                    let i = 32 * scale_i + 8 * j;
                    let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(i));
                    let aux8 = i16x8_load_extend_u8x8(aux8.as_ptr().add(i));
                    let aux16 = i16x8_mul(q8, aux8);
                    aux32 = i32x4_add(aux32, i32x4_mul(scale, i32x4_extend_low_i16x8(aux16)));
                    aux32 = i32x4_add(aux32, i32x4_mul(scale, i32x4_extend_high_i16x8(aux16)));
                }
            }
            let aux32 = f32x4_convert_i32x4(aux32);
            let d = f32x4_splat(x.d.to_f32() * y.d);
            sums = f32x4_add(sums, f32x4_mul(aux32, d));
            let dmin = x.dmin.to_f32() * y.d;
            let dmin = f32x4_splat(dmin);
            let sumi = f32x4_convert_i32x4(sumi);
            sums = f32x4_sub(sums, f32x4_mul(sumi, dmin));
        }
        let sums = f32x4_extract_lane::<0>(sums)
            + f32x4_extract_lane::<1>(sums)
            + f32x4_extract_lane::<2>(sums)
            + f32x4_extract_lane::<3>(sums);
        Ok(sums)
    }
}

#[inline(always)]
pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}")
    }

    let mut aux8 = [0i8; QK_K];
    unsafe {
        let mut sums = f32x4_splat(0f32);

        for (x, y) in xs.iter().zip(ys.iter()) {
            let q4 = &x.ql;
            let qh = &x.qh;
            let q8 = &y.qs;
            let mut aux32 = f32x4_splat(0f32);

            for j in (0..QK_K).step_by(128) {
                let aux8 = aux8.as_mut_ptr().add(j);
                let q4 = &q4.as_ptr().add(j / 2);
                let qh = &qh.as_ptr().add(j / 4);
                for l in (0..32).step_by(16) {
                    // aux8[l] = (((q4[l] & 0xF) | ((qh[l] & 3) << 4)) as i32 - 32) as i8;
                    let a8 = v128_or(
                        v128_and(v128_load(q4.add(l) as *const v128), u8x16_splat(0xF)),
                        u8x16_shl(
                            v128_and(v128_load(qh.add(l) as *const v128), u8x16_splat(3)),
                            4,
                        ),
                    );
                    let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
                    let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
                    v128_store(
                        aux8.add(l) as *mut v128,
                        i8x16_narrow_i16x8(a8_low, a8_high),
                    );

                    // aux8[l + 32] =
                    //     (((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i32 - 32) as i8;
                    let a8 = v128_or(
                        v128_and(v128_load(q4.add(l + 32) as *const v128), u8x16_splat(0xF)),
                        u8x16_shl(
                            v128_and(
                                u8x16_shr(v128_load(qh.add(l) as *const v128), 2),
                                u8x16_splat(3),
                            ),
                            4,
                        ),
                    );
                    let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
                    let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
                    v128_store(
                        aux8.add(l + 32) as *mut v128,
                        i8x16_narrow_i16x8(a8_low, a8_high),
                    );

                    // aux8[l + 64] = (((q4[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i32 - 32) as i8;
                    let a8 = v128_or(
                        u8x16_shr(v128_load(q4.add(l) as *const v128), 4),
                        u8x16_shl(
                            v128_and(
                                u8x16_shr(v128_load(qh.add(l) as *const v128), 4),
                                u8x16_splat(3),
                            ),
                            4,
                        ),
                    );
                    let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
                    let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
                    v128_store(
                        aux8.add(l + 64) as *mut v128,
                        i8x16_narrow_i16x8(a8_low, a8_high),
                    );

                    // aux8[l + 96] =
                    //     (((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i32 - 32) as i8;
                    let a8 = v128_or(
                        u8x16_shr(v128_load(q4.add(l + 32) as *const v128), 4),
                        u8x16_shl(
                            v128_and(
                                u8x16_shr(v128_load(qh.add(l) as *const v128), 6),
                                u8x16_splat(3),
                            ),
                            4,
                        ),
                    );
                    let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
                    let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
                    v128_store(
                        aux8.add(l + 96) as *mut v128,
                        i8x16_narrow_i16x8(a8_low, a8_high),
                    );
                }
            }

            for (j, &scale) in x.scales.iter().enumerate() {
                let scale = f32x4_splat(scale as f32);
                for offset in [0, 8] {
                    let aux16 = i16x8_mul(
                        i16x8_load_extend_i8x8(q8.as_ptr().add(16 * j + offset)),
                        i16x8_load_extend_i8x8(aux8.as_ptr().add(16 * j + offset)),
                    );
                    aux32 = f32x4_add(
                        aux32,
                        f32x4_mul(f32x4_convert_i32x4(i32x4_extend_low_i16x8(aux16)), scale),
                    );
                    aux32 = f32x4_add(
                        aux32,
                        f32x4_mul(f32x4_convert_i32x4(i32x4_extend_high_i16x8(aux16)), scale),
                    );
                }
            }

            let d = f32x4_splat(x.d.to_f32() * y.d);
            sums = f32x4_add(sums, f32x4_mul(aux32, d));
        }
        let sums = f32x4_extract_lane::<0>(sums)
            + f32x4_extract_lane::<1>(sums)
            + f32x4_extract_lane::<2>(sums)
            + f32x4_extract_lane::<3>(sums);
        Ok(sums)
    }
}

#[inline(always)]
pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
    let qk = QK_K;
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
    }

    unsafe {
        let mut acc = f32x4_splat(0.0f32);
        for (xs, ys) in xs.iter().zip(ys.iter()) {
            let x_qs = xs.qs.as_ptr();
            let y_qs = ys.qs.as_ptr();
            let mut sumi = i32x4_splat(0);
            for j in (0..QK_K).step_by(8) {
                let xs = i16x8_load_extend_i8x8(x_qs.add(j));
                let ys = i16x8_load_extend_i8x8(y_qs.add(j));
                let sum_xy = i32x4_dot_i16x8(xs, ys);
                sumi = i32x4_add(sumi, sum_xy)
            }
            let d = f32x4_splat(xs.d * ys.d);
            acc = f32x4_add(acc, f32x4_mul(f32x4_convert_i32x4(sumi), d))
        }
        let res = f32x4_extract_lane::<0>(acc)
            + f32x4_extract_lane::<1>(acc)
            + f32x4_extract_lane::<2>(acc)
            + f32x4_extract_lane::<3>(acc);
        Ok(res)
    }
}
@@ -17,7 +17,7 @@ pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>(
    let expected_blocks = xs.len() / block_size;
    let actual_blocks = ys.len();

    // Validate that the input is the right size
    //validate that the input is the right size
    if expected_blocks != actual_blocks {
        crate::bail!("quantize {dtype:?}: expected {expected_blocks} blocks but only {actual_blocks} were provided!")
    }
@@ -37,12 +37,12 @@ pub(super) fn group_for_dequantization<'a, 'b, T: super::k_quants::GgmlType>(

    let actual_output_len = ys.len();
    let expected_output_len = xs.len() * block_size;
    // Validate that the output is the right size
    //validate that the output is the right size
    if expected_output_len != actual_output_len {
        crate::bail!("dequantize {dtype:?}: ys (len = {actual_output_len}) does not match the expected length of {expected_output_len}!")
    }

    // Zip the blocks and outputs together
    //zip the blocks and outputs together
    Ok(xs.iter().zip(ys.chunks_exact_mut(block_size)).collect())
}
@@ -251,134 +251,6 @@ pub fn save<K: AsRef<str> + Ord + std::fmt::Display, P: AsRef<Path>>(
    Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?)
}

#[derive(yoke::Yokeable)]
struct SafeTensors_<'a>(SafeTensors<'a>);

pub struct MmapedSafetensors {
    safetensors: Vec<yoke::Yoke<SafeTensors_<'static>, memmap2::Mmap>>,
    routing: Option<HashMap<String, usize>>,
}

impl MmapedSafetensors {
    /// Creates a wrapper around a memory mapped file and deserializes the safetensors header.
    ///
    /// # Safety
    ///
    /// The unsafe is inherited from [`memmap2::MmapOptions`].
    pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
        let p = p.as_ref();
        let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
        let file = memmap2::MmapOptions::new()
            .map(&file)
            .map_err(|e| Error::from(e).with_path(p))?;
        let safetensors = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
            file,
            |data: &[u8]| {
                let st = safetensors::SafeTensors::deserialize(data)
                    .map_err(|e| Error::from(e).with_path(p))?;
                Ok::<_, Error>(SafeTensors_(st))
            },
        )?;
        Ok(Self {
            safetensors: vec![safetensors],
            routing: None,
        })
    }

    /// Creates a wrapper around multiple memory mapped files and deserializes the safetensors headers.
    ///
    /// If a tensor name appears in multiple files, the last entry is returned.
    ///
    /// # Safety
    ///
    /// The unsafe is inherited from [`memmap2::MmapOptions`].
    pub unsafe fn multi<P: AsRef<Path>>(paths: &[P]) -> Result<Self> {
        let mut routing = HashMap::new();
        let mut safetensors = vec![];
        for (index, p) in paths.iter().enumerate() {
            let p = p.as_ref();
            let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
            let file = memmap2::MmapOptions::new()
                .map(&file)
                .map_err(|e| Error::from(e).with_path(p))?;
            let data = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
                file,
                |data: &[u8]| {
                    let st = safetensors::SafeTensors::deserialize(data)
                        .map_err(|e| Error::from(e).with_path(p))?;
                    Ok::<_, Error>(SafeTensors_(st))
                },
            )?;
            for k in data.get().0.names() {
                routing.insert(k.to_string(), index);
            }
            safetensors.push(data)
        }
        Ok(Self {
            safetensors,
            routing: Some(routing),
        })
    }

    pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
        self.get(name)?.load(dev)
    }

    pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
        let mut tensors = vec![];
        for safetensors in self.safetensors.iter() {
            tensors.push(safetensors.get().0.tensors())
        }
        tensors.into_iter().flatten().collect()
    }

    pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
        let index = match &self.routing {
            None => 0,
            Some(routing) => {
                let index = routing.get(name).ok_or_else(|| {
                    Error::CannotFindTensor {
                        path: name.to_string(),
                    }
                    .bt()
                })?;
                *index
            }
        };
        Ok(self.safetensors[index].get().0.tensor(name)?)
    }
}

pub struct BufferedSafetensors {
    safetensors: yoke::Yoke<SafeTensors_<'static>, Vec<u8>>,
}

impl BufferedSafetensors {
    /// Creates a wrapper around a binary buffer and deserializes the safetensors header.
    pub fn new(buffer: Vec<u8>) -> Result<Self> {
        let safetensors = yoke::Yoke::<SafeTensors_<'static>, Vec<u8>>::try_attach_to_cart(
            buffer,
            |data: &[u8]| {
                let st = safetensors::SafeTensors::deserialize(data)?;
                Ok::<_, Error>(SafeTensors_(st))
            },
        )?;
        Ok(Self { safetensors })
    }

    pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
        self.get(name)?.load(dev)
    }

    pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
        self.safetensors.get().0.tensors()
    }

    pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
        Ok(self.safetensors.get().0.tensor(name)?)
    }
}

pub struct MmapedFile {
    path: std::path::PathBuf,
    inner: memmap2::Mmap,
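A usage sketch for the `MmapedSafetensors` wrapper shown above (file and tensor names are hypothetical; `new` is unsafe because the mapped file must not be mutated while the mapping is alive):

```rust
use candle_core::safetensors::MmapedSafetensors;
use candle_core::{Device, Result};

fn load_weight() -> Result<()> {
    // Safety: the file must not be modified while the mapping is in use.
    let st = unsafe { MmapedSafetensors::new("model.safetensors")? };
    let w = st.load("transformer.wte.weight", &Device::Cpu)?;
    println!("{:?}", w.shape());
    Ok(())
}
```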
@@ -203,7 +203,7 @@ impl Shape {

    /// Check whether the two shapes are compatible for broadcast, and if it is the case return the
    /// broadcasted shape. This is to be used for binary pointwise ops.
    pub fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
    pub(crate) fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
        let lhs = self;
        let lhs_dims = lhs.dims();
        let rhs_dims = rhs.dims();
@@ -511,119 +511,154 @@ impl ShapeWithOneHole for ((),) {
    }
}

fn hole_size(el_count: usize, prod_d: usize, s: &dyn std::fmt::Debug) -> Result<usize> {
    if prod_d == 0 {
        crate::bail!("cannot reshape tensor of {el_count} elements to {s:?}")
    }
    if el_count % prod_d != 0 {
        crate::bail!("cannot reshape tensor with {el_count} elements to {s:?}")
    }
    Ok(el_count / prod_d)
}

impl ShapeWithOneHole for ((), usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let ((), d1) = self;
        Ok((hole_size(el_count, d1, &self)?, d1).into())
        if el_count % d1 != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
        }
        Ok((el_count / d1, d1).into())
    }
}

impl ShapeWithOneHole for (usize, ()) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, ()) = self;
        Ok((d1, hole_size(el_count, d1, &self)?).into())
        if el_count % d1 != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
        }
        Ok((d1, el_count / d1).into())
    }
}

impl ShapeWithOneHole for ((), usize, usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let ((), d1, d2) = self;
        Ok((hole_size(el_count, d1 * d2, &self)?, d1, d2).into())
        let d = d1 * d2;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((el_count / d, d1, d2).into())
    }
}

impl ShapeWithOneHole for (usize, (), usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, (), d2) = self;
        Ok((d1, hole_size(el_count, d1 * d2, &self)?, d2).into())
        let d = d1 * d2;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((d1, el_count / d, d2).into())
    }
}

impl ShapeWithOneHole for (usize, usize, ()) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, d2, ()) = self;
        Ok((d1, d2, hole_size(el_count, d1 * d2, &self)?).into())
        let d = d1 * d2;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((d1, d2, el_count / d).into())
    }
}

impl ShapeWithOneHole for ((), usize, usize, usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let ((), d1, d2, d3) = self;
        let d = hole_size(el_count, d1 * d2 * d3, &self)?;
        Ok((d, d1, d2, d3).into())
        let d = d1 * d2 * d3;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((el_count / d, d1, d2, d3).into())
    }
}

impl ShapeWithOneHole for (usize, (), usize, usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, (), d2, d3) = self;
        let d = hole_size(el_count, d1 * d2 * d3, &self)?;
        Ok((d1, d, d2, d3).into())
        let d = d1 * d2 * d3;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((d1, el_count / d, d2, d3).into())
    }
}

impl ShapeWithOneHole for (usize, usize, (), usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, d2, (), d3) = self;
        let d = hole_size(el_count, d1 * d2 * d3, &self)?;
        Ok((d1, d2, d, d3).into())
        let d = d1 * d2 * d3;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((d1, d2, el_count / d, d3).into())
    }
}

impl ShapeWithOneHole for (usize, usize, usize, ()) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, d2, d3, ()) = self;
        let d = hole_size(el_count, d1 * d2 * d3, &self)?;
        Ok((d1, d2, d3, d).into())
        let d = d1 * d2 * d3;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((d1, d2, d3, el_count / d).into())
    }
}

impl ShapeWithOneHole for ((), usize, usize, usize, usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let ((), d1, d2, d3, d4) = self;
        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
        Ok((d, d1, d2, d3, d4).into())
        let d = d1 * d2 * d3 * d4;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((el_count / d, d1, d2, d3, d4).into())
    }
}

impl ShapeWithOneHole for (usize, (), usize, usize, usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, (), d2, d3, d4) = self;
        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
        Ok((d1, d, d2, d3, d4).into())
        let d = d1 * d2 * d3 * d4;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((d1, el_count / d, d2, d3, d4).into())
    }
}

impl ShapeWithOneHole for (usize, usize, (), usize, usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, d2, (), d3, d4) = self;
        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
        Ok((d1, d2, d, d3, d4).into())
        let d = d1 * d2 * d3 * d4;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((d1, d2, el_count / d, d3, d4).into())
    }
}

impl ShapeWithOneHole for (usize, usize, usize, (), usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, d2, d3, (), d4) = self;
        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
        Ok((d1, d2, d3, d, d4).into())
        let d = d1 * d2 * d3 * d4;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((d1, d2, d3, el_count / d, d4).into())
    }
}

impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, d2, d3, d4, ()) = self;
        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
        Ok((d1, d2, d3, d4, d).into())
        let d = d1 * d2 * d3 * d4;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((d1, d2, d3, d4, el_count / d).into())
    }
}
|
||||
use crate::backend::BackendStorage;
|
||||
use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
|
||||
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
|
||||
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
|
||||
|
||||
// We do not want to implement Clone on Storage as cloning may fail because of
|
||||
// out of memory. Instead try_clone should be used.
|
||||
@ -8,7 +8,6 @@ use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage,
|
||||
pub enum Storage {
|
||||
Cpu(CpuStorage),
|
||||
Cuda(CudaStorage),
|
||||
Metal(MetalStorage),
|
||||
}
|
||||
|
||||
impl Storage {
|
||||
@ -19,10 +18,6 @@ impl Storage {
|
||||
let storage = storage.try_clone(layout)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
Self::Metal(storage) => {
|
||||
let storage = storage.try_clone(layout)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -30,7 +25,6 @@ impl Storage {
|
||||
match self {
|
||||
Self::Cpu(_) => Device::Cpu,
|
||||
Self::Cuda(storage) => Device::Cuda(storage.device().clone()),
|
||||
Self::Metal(storage) => Device::Metal(storage.device().clone()),
|
||||
}
|
||||
}
|
||||
|
||||
@ -38,7 +32,6 @@ impl Storage {
|
||||
match self {
|
||||
Self::Cpu(storage) => storage.dtype(),
|
||||
Self::Cuda(storage) => storage.dtype(),
|
||||
Self::Metal(storage) => storage.dtype(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -72,10 +65,6 @@ impl Storage {
|
||||
let storage = storage.affine(layout, mul, add)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
Self::Metal(storage) => {
|
||||
let storage = storage.affine(layout, mul, add)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -89,10 +78,6 @@ impl Storage {
|
||||
let storage = storage.powf(layout, alpha)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
Self::Metal(storage) => {
|
||||
let storage = storage.powf(layout, alpha)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -106,10 +91,6 @@ impl Storage {
|
||||
let storage = storage.elu(layout, alpha)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
Self::Metal(storage) => {
|
||||
let storage = storage.elu(layout, alpha)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -131,10 +112,6 @@ impl Storage {
|
||||
let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
(Self::Metal(lhs), Self::Metal(rhs)) => {
|
||||
let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
(lhs, rhs) => {
|
||||
// Should not happen because of the same device check above but we're defensive
|
||||
// anyway.
|
||||
@ -158,10 +135,6 @@ impl Storage {
|
||||
let storage = storage.reduce_op(op, layout, s)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
Self::Metal(storage) => {
|
||||
let storage = storage.reduce_op(op, layout, s)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -175,10 +148,6 @@ impl Storage {
|
||||
let storage = storage.to_dtype(layout, dtype)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
Self::Metal(storage) => {
|
||||
let storage = storage.to_dtype(layout, dtype)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -192,10 +161,6 @@ impl Storage {
|
||||
let (storage, shape) = c.cuda_fwd(storage, l)?;
|
||||
Ok((Self::Cuda(storage), shape))
|
||||
}
|
||||
Self::Metal(storage) => {
|
||||
let (storage, shape) = c.metal_fwd(storage, l)?;
|
||||
Ok((Self::Metal(storage), shape))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -216,10 +181,6 @@ impl Storage {
|
||||
let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?;
|
||||
Ok((Self::Cuda(s), shape))
|
||||
}
|
||||
(Self::Metal(s1), Self::Metal(s2)) => {
|
||||
let (s, shape) = c.metal_fwd(s1, l1, s2, l2)?;
|
||||
Ok((Self::Metal(s), shape))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
@ -244,10 +205,6 @@ impl Storage {
|
||||
let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?;
|
||||
Ok((Self::Cuda(s), shape))
|
||||
}
|
||||
(Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
|
||||
let (s, shape) = c.metal_fwd(s1, l1, s2, l2, s3, l3)?;
|
||||
Ok((Self::Metal(s), shape))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
@ -262,10 +219,6 @@ impl Storage {
|
||||
let storage = storage.unary_impl::<B>(layout)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
Self::Metal(storage) => {
|
||||
let storage = storage.unary_impl::<B>(layout)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -286,10 +239,6 @@ impl Storage {
|
||||
let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
(Self::Metal(lhs), Self::Metal(rhs)) => {
|
||||
let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
(lhs, rhs) => {
|
||||
// Should not happen because of the same device check above but we're defensive
|
||||
// anyway.
|
||||
@ -321,10 +270,6 @@ impl Storage {
|
||||
let s = inp.conv1d(l, kernel, kernel_l, params)?;
|
||||
Ok(Self::Cuda(s))
|
||||
}
|
||||
(Storage::Metal(inp), Storage::Metal(kernel)) => {
|
||||
let s = inp.conv1d(l, kernel, kernel_l, params)?;
|
||||
Ok(Self::Metal(s))
|
||||
}
|
||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
@ -334,33 +279,6 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn conv_transpose1d(
|
||||
&self,
|
||||
l: &Layout,
|
||||
kernel: &Self,
|
||||
kernel_l: &Layout,
|
||||
params: &crate::conv::ParamsConvTranspose1D,
|
||||
) -> Result<Self> {
|
||||
self.same_device(kernel, "conv-transpose1d")?;
|
||||
self.same_dtype(kernel, "conv-transpose1d")?;
|
||||
match (self, &kernel) {
|
||||
(Storage::Cpu(inp), Storage::Cpu(kernel)) => {
|
||||
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
|
||||
Ok(Self::Cpu(s))
|
||||
}
|
||||
(Storage::Cuda(inp), Storage::Cuda(kernel)) => {
|
||||
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
|
||||
Ok(Self::Cuda(s))
|
||||
}
|
||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
op: "conv-transpose1d",
|
||||
}
|
||||
.bt()),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn conv2d(
|
||||
&self,
|
||||
l: &Layout,
|
||||
@ -379,10 +297,6 @@ impl Storage {
|
||||
let s = inp.conv2d(l, kernel, kernel_l, params)?;
|
||||
Ok(Self::Cuda(s))
|
||||
}
|
||||
(Storage::Metal(inp), Storage::Metal(kernel)) => {
|
||||
let s = inp.conv2d(l, kernel, kernel_l, params)?;
|
||||
Ok(Self::Metal(s))
|
||||
}
|
||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
@ -410,10 +324,6 @@ impl Storage {
|
||||
let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
|
||||
Ok(Self::Cuda(s))
|
||||
}
|
||||
(Storage::Metal(inp), Storage::Metal(kernel)) => {
|
||||
let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
|
||||
Ok(Self::Metal(s))
|
||||
}
|
||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
@ -438,10 +348,6 @@ impl Storage {
|
||||
let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
Self::Metal(storage) => {
|
||||
let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -460,10 +366,6 @@ impl Storage {
|
||||
let storage = storage.max_pool2d(layout, kernel_size, stride)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
Self::Metal(storage) => {
|
||||
let storage = storage.max_pool2d(layout, kernel_size, stride)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -477,10 +379,6 @@ impl Storage {
|
||||
let storage = storage.upsample_nearest1d(layout, sz)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
Self::Metal(storage) => {
|
||||
let storage = storage.upsample_nearest1d(layout, sz)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -494,10 +392,6 @@ impl Storage {
|
||||
let storage = storage.upsample_nearest2d(layout, h, w)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
Self::Metal(storage) => {
|
||||
let storage = storage.upsample_nearest2d(layout, h, w)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -521,10 +415,6 @@ impl Storage {
|
||||
let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
(Self::Metal(cond), Self::Metal(t), Self::Metal(f)) => {
|
||||
let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
(_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
@ -551,10 +441,6 @@ impl Storage {
|
||||
let storage = s.gather(l, indexes, indexes_l, d)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
(Self::Metal(s), Self::Metal(indexes)) => {
|
||||
let storage = s.gather(l, indexes, indexes_l, d)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
@ -579,10 +465,6 @@ impl Storage {
|
||||
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
|
||||
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
@ -607,10 +489,6 @@ impl Storage {
|
||||
let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
|
||||
let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
@ -632,10 +510,6 @@ impl Storage {
|
||||
let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
(Self::Metal(lhs), Self::Metal(rhs)) => {
|
||||
let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
@ -663,10 +537,6 @@ impl Storage {
|
||||
let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
(Self::Metal(lhs), Self::Metal(rhs)) => {
|
||||
let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
@ -686,9 +556,6 @@ impl Storage {
|
||||
match (self, dst) {
|
||||
(Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l),
|
||||
(Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
|
||||
(Self::Metal(src), Self::Metal(dst)) => {
|
||||
Ok(src.copy_strided_src(dst, dst_offset, src_l)?)
|
||||
}
|
||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
|
@ -6,7 +6,7 @@ use crate::op::{
};
use crate::scalar::TensorOrScalar;
use crate::shape::{Dim, Dims};
use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::{Arc, RwLock};

/// Unique identifier for tensors.
@ -177,9 +177,14 @@ impl Tensor {
        is_variable: bool,
    ) -> Result<Self> {
        let none = BackpropOp::none();
        let shape = shape.into();
        let storage = device.ones(&shape, dtype)?;
        Ok(from_storage(storage, shape, none, is_variable))
        if is_variable {
            let shape = shape.into();
            let storage = device.ones(&shape, dtype)?;
            Ok(from_storage(storage, shape, none, is_variable))
        } else {
            let storage = device.ones(&crate::shape::SCALAR, dtype)?;
            from_storage(storage, crate::shape::SCALAR, none, is_variable).broadcast_as(shape)
        }
    }

    /// Creates a new tensor filled with ones.

@ -217,9 +222,14 @@ impl Tensor {
        is_variable: bool,
    ) -> Result<Self> {
        let none = BackpropOp::none();
        let shape = shape.into();
        let storage = device.zeros(&shape, dtype)?;
        Ok(from_storage(storage, shape, none, is_variable))
        if is_variable {
            let shape = shape.into();
            let storage = device.zeros(&shape, dtype)?;
            Ok(from_storage(storage, shape, none, is_variable))
        } else {
            let storage = device.zeros(&crate::shape::SCALAR, dtype)?;
            from_storage(storage, crate::shape::SCALAR, none, is_variable).broadcast_as(shape)
        }
    }
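
One side of this hunk allocates the full shape, the other allocates a single scalar and broadcasts it; the observable result is the same either way. A usage sketch grounded in the `ones` test later in this diff (assuming the crate is consumed as `candle_core`):

use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    // A (2, 3) tensor of ones; internally this may be a scalar broadcast
    // to the requested shape rather than a full per-element allocation.
    let t = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
    assert_eq!(t.to_vec2::<f32>()?, [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]);
    Ok(())
}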

    /// Creates a new tensor filled with zeros.

@ -385,21 +395,11 @@ impl Tensor {
        step: D,
        device: &Device,
    ) -> Result<Self> {
        if D::is_zero(&step) {
            crate::bail!("step cannot be zero")
        }
        let mut data = vec![];
        let mut current = start;
        if step >= D::zero() {
            while current < end {
                data.push(current);
                current += step;
            }
        } else {
            while current > end {
                data.push(current);
                current += step;
            }
        while current < end {
            data.push(current);
            current += step;
        }
        let len = data.len();
        Self::from_vec_impl(data, len, device, false)
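
A usage sketch of the stepped range logic above, matching the `arange` assertions later in this diff (inside a function returning `candle_core::Result<()>`):

use candle_core::{Device, Tensor};

// Positive steps count up towards `end`, negative steps count down.
let up = Tensor::arange_step(0u8, 5u8, 2, &Device::Cpu)?;
assert_eq!(up.to_vec1::<u8>()?, [0, 2, 4]);
let down = Tensor::arange_step(5i64, 0i64, -1, &Device::Cpu)?;
assert_eq!(down.to_vec1::<i64>()?, [5, 4, 3, 2, 1]);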

@ -459,7 +459,7 @@ impl Tensor {

    /// Returns true if the computation graph should track this op, that is if it is
    /// a variable or if it has some variable among its dependencies.
    pub fn track_op(&self) -> bool {
    pub(crate) fn track_op(&self) -> bool {
        self.is_variable || self.op.is_some()
    }

@ -477,12 +477,6 @@ impl Tensor {
    broadcast_binary_op!(broadcast_div, div);
    broadcast_binary_op!(broadcast_maximum, maximum);
    broadcast_binary_op!(broadcast_minimum, minimum);
    broadcast_binary_op!(broadcast_eq, eq);
    broadcast_binary_op!(broadcast_ne, ne);
    broadcast_binary_op!(broadcast_lt, lt);
    broadcast_binary_op!(broadcast_le, le);
    broadcast_binary_op!(broadcast_gt, gt);
    broadcast_binary_op!(broadcast_ge, ge);

    unary_op!(recip, Recip);
    unary_op!(neg, Neg);

@ -495,21 +489,7 @@ impl Tensor {
    unary_op!(sqr, Sqr);
    unary_op!(sqrt, Sqrt);
    unary_op!(gelu, Gelu);
    unary_op!(gelu_erf, GeluErf);
    unary_op!(erf, Erf);
    unary_op!(relu, Relu);
    unary_op!(ceil, Ceil);
    unary_op!(floor, Floor);
    unary_op!(round, Round);

    /// Rounds each element of the input tensor to the nearest integer.
    ///
    /// If the number of decimals is negative, it specifies the number of positions to the left of
    /// the decimal point.
    pub fn round_to(&self, decimals: i32) -> Result<Self> {
        let mult = 10f64.powi(decimals);
        (self * mult)?.round()? * (1f64 / mult)
    }
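
A usage sketch of `round_to`, taken from the `unary_op` test later in this diff (inside a function returning `candle_core::Result<()>`):

use candle_core::{test_utils, Device, Tensor};

let t = Tensor::new(&[2997.9246f32, 314.15926], &Device::Cpu)?;
// Two decimal places to the right, then two to the left of the point.
assert_eq!(test_utils::to_vec1_round(&t.round_to(2)?, 4)?, [2997.92, 314.16]);
assert_eq!(test_utils::to_vec1_round(&t.round_to(-2)?, 4)?, [3000.0, 300.0]);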

    /// Retrieves the single scalar value held in the tensor. If the tensor contains multiple
    /// dimensions, an error is returned instead.

@ -529,7 +509,6 @@ impl Tensor {
        match &*self.storage() {
            Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
        }
    }

@ -557,73 +536,6 @@ impl Tensor {
        Ok(inp)
    }

    /// Creates grids of coordinates specified by the 1D inputs.
    ///
    /// # Arguments
    ///
    /// * `args` - A slice of 1D tensors.
    /// * `xy_indexing` - Whether to use xy indexing or ij indexing. If xy is selected, the
    ///   first dimension corresponds to the cardinality of the second input and the second
    ///   dimension corresponds to the cardinality of the first input. If ij is selected, the
    ///   dimensions are in the same order as the cardinality of the inputs.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device, Shape};
    /// let x = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    /// let y = Tensor::new(&[4f32, 5., 6.], &Device::Cpu)?;
    ///
    /// let grids_xy = Tensor::meshgrid(&[&x, &y], true)?;
    ///
    /// assert_eq!(grids_xy.len(), 2);
    /// assert_eq!(grids_xy[0].dims(), &[3, 3]);
    ///
    /// assert_eq!(grids_xy[0].to_vec2::<f32>()?, &[[1., 2., 3.], [1., 2., 3.], [1., 2., 3.]]);
    /// assert_eq!(grids_xy[1].to_vec2::<f32>()?, &[[4., 4., 4.], [5., 5., 5.], [6., 6., 6.]]);
    ///
    /// let grids_ij = Tensor::meshgrid(&[&x, &y], false)?;
    ///
    /// assert_eq!(grids_ij[0].to_vec2::<f32>()?, &[[1., 1., 1.], [2., 2., 2.], [3., 3., 3.]]);
    /// assert_eq!(grids_ij[1].to_vec2::<f32>()?, &[[4., 5., 6.], [4., 5., 6.], [4., 5., 6.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```
    ///
    /// # Errors
    ///
    /// * Will return `Err` if `args` contains fewer than 2 tensors.
    ///
    pub fn meshgrid<A: AsRef<Tensor>>(args: &[A], xy_indexing: bool) -> Result<Vec<Self>> {
        if args.len() <= 1 {
            Err(Error::OpRequiresAtLeastTwoTensors { op: "meshgrid" }.bt())?
        }
        let args: Vec<_> = if xy_indexing {
            args.iter().rev().collect()
        } else {
            args.iter().collect()
        };

        let mut shape = Vec::with_capacity(args.len());
        for arg in args.iter() {
            shape.push(arg.as_ref().dims1()?)
        }

        let mut grids = Vec::with_capacity(args.len());
        for idx in 0..args.len() {
            let mut ones = vec![1usize; args.len()];
            ones[idx] = shape[idx];
            let arg = args[idx].as_ref().reshape(ones)?;
            let mut repeats = shape.clone();
            repeats[idx] = 1;
            let repeated_tensor = arg.repeat(repeats)?;
            grids.push(repeated_tensor);
        }
        if xy_indexing {
            grids.reverse();
        }
        Ok(grids)
    }

    /// This operation multiplies the input tensor by `mul` then adds `add` and returns the result.
    /// The input values `mul` and `add` are cast to the appropriate type, so some rounding might
    /// be performed.

@ -699,23 +611,15 @@ impl Tensor {
    pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
        let dims = self.dims();
        let dim = dim.to_index(self.shape(), "narrow")?;
        let err = |msg| {
            Err::<(), _>(
                Error::NarrowInvalidArgs {
                    shape: self.shape().clone(),
                    dim,
                    start,
                    len,
                    msg,
                }
                .bt(),
            )
        };
        if start > dims[dim] {
            err("start > dim_len")?
        }
        if start.saturating_add(len) > dims[dim] {
            err("start + len > dim_len")?
        if start + len > dims[dim] {
            Err(Error::NarrowInvalidArgs {
                shape: self.shape().clone(),
                dim,
                start,
                len,
                msg: "start + len > dim_len",
            }
            .bt())?
        }
        if start == 0 && dims[dim] == len {
            Ok(self.clone())
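
A usage sketch of `narrow` under the bounds checks above (inside a function returning `candle_core::Result<()>`; values follow from `arange`):

use candle_core::{Device, Tensor};

let t = Tensor::arange(0f32, 12f32, &Device::Cpu)?.reshape((4, 3))?;
// Keep rows 1 and 2 of the first dimension.
let mid = t.narrow(0, 1, 2)?;
assert_eq!(mid.to_vec2::<f32>()?, [[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]);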

@ -856,20 +760,6 @@ impl Tensor {
        self.sum_impl(mean_dims, false)? * scale
    }

    /// Returns the unbiased variance over the selected dimension.
    pub fn var_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
        let dim = dim.to_index(self.shape(), "var")?;
        let mean = self.mean_keepdim(dim)?;
        let squares = self.broadcast_sub(&mean)?.sqr()?;
        squares.sum_impl(dim, true)? / (self.dim(dim)? - 1) as f64
    }

    /// Returns the unbiased variance over the selected dimension.
    pub fn var<D: Dim>(&self, dim: D) -> Result<Self> {
        let dim = dim.to_index(self.shape(), "var")?;
        self.var_keepdim(dim)?.squeeze(dim)
    }
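
The variance computed above is the unbiased estimator sum((x_i - mean)^2) / (n - 1). A usage sketch matching the `var` test later in this diff (data originally from the torch.var documentation):

use candle_core::{test_utils, Device, Tensor};

let t = Tensor::new(
    &[
        [0.2035f32, 1.2959, 1.8101, -0.4644],
        [1.5027, -0.3270, 0.5905, 0.6538],
        [-1.5745, 1.3330, -0.5596, -0.6548],
        [0.1264, -0.5080, 1.6420, 0.1992],
    ],
    &Device::Cpu,
)?;
// One variance per row, with the reduced dimension kept as size 1.
assert_eq!(
    test_utils::to_vec2_round(&t.var_keepdim(1)?, 4)?,
    [[1.0631], [0.559], [1.4893], [0.8258]]
);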

    /// Gathers the maximum value across the selected dimension. The resulting shape has the same
    /// number of dimensions as the original tensor and the selected dimension has a single element.
    pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {

@ -1217,16 +1107,14 @@ impl Tensor {
                op: "scatter-add (self, src)",
                lhs: self.shape().clone(),
                rhs: source.shape().clone(),
            }
            .bt())?
            })?
        }
        if indexes.dims() != source.dims() {
            Err(Error::ShapeMismatchBinaryOp {
                op: "scatter-add (indexes, src)",
                lhs: indexes.shape().clone(),
                rhs: source.shape().clone(),
            }
            .bt())?
            })?
        }
        let storage = self.storage().scatter_add(
            self.layout(),

@ -1242,75 +1130,6 @@ impl Tensor {
        Ok(from_storage(storage, self.shape(), op, false))
    }
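
A usage sketch of `scatter_add` under these shape checks; the example below is constructed here (not taken from the test suite), with `out[i][ids[i][j]] += src[i][j]` along dimension 1:

use candle_core::{DType, Device, Tensor};

let init = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
let src = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
// Identity index rows, so each source value lands on its own column.
let ids = Tensor::new(&[[0u32, 1, 2], [0, 1, 2]], &Device::Cpu)?;
let out = init.scatter_add(&ids, &src, 1)?;
assert_eq!(out.to_vec2::<f32>()?, [[2.0, 3.0, 4.0], [5.0, 6.0, 7.0]]);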

    /// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension.
    pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: D, start: usize) -> Result<Self> {
        let dim = dim.to_index(self.shape(), "slice-scatter")?;
        if dim == 0 {
            self.slice_scatter0(src, start)
        } else {
            // TODO: Maybe we want to add a more efficient implementation at some point.
            self.transpose(0, dim)?
                .slice_scatter0(&src.transpose(0, dim)?, start)?
                .transpose(0, dim)
        }
    }

    /// Embeds the values of the `src` tensor into the `self` tensor on the first dimension.
    pub fn slice_scatter0(&self, src: &Self, start: usize) -> Result<Self> {
        if self.dtype() != src.dtype() {
            Err(Error::DTypeMismatchBinaryOp {
                lhs: self.dtype(),
                rhs: src.dtype(),
                op: "slice-scatter",
            }
            .bt())?
        }
        if self.device().location() != src.device().location() {
            Err(Error::DeviceMismatchBinaryOp {
                lhs: self.device().location(),
                rhs: src.device().location(),
                op: "slice-scatter",
            }
            .bt())?
        }
        if self.rank() != src.rank() {
            Err(Error::UnexpectedNumberOfDims {
                expected: self.rank(),
                got: src.rank(),
                shape: src.shape().clone(),
            }
            .bt())?
        }
        let shape_ok =
            self.dims()
                .iter()
                .zip(src.dims().iter())
                .enumerate()
                .all(|(dim_idx, (&d1, &d2))| {
                    if 0 == dim_idx {
                        d2 + start <= d1
                    } else {
                        d1 == d2
                    }
                });
        if !shape_ok {
            Err(Error::ShapeMismatchBinaryOp {
                op: "slice-scatter (self, src)",
                lhs: self.shape().clone(),
                rhs: src.shape().clone(),
            }
            .bt())?
        }
        let mut storage = self.device().zeros(self.shape(), self.dtype())?;
        self.storage()
            .copy_strided_src(&mut storage, 0, self.layout())?;
        let offset = start * src.dims()[1..].iter().product::<usize>();
        src.storage()
            .copy_strided_src(&mut storage, offset, src.layout())?;
        let op = BackpropOp::new2(self, src, |t1, t2| Op::SliceScatter0(t1, t2, start));
        Ok(from_storage(storage, self.shape(), op, false))
    }
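
A usage sketch of `slice_scatter0`, matching the `slice_scatter` test later in this diff:

use candle_core::{Device, Tensor};

let t = Tensor::arange(0f32, 12f32, &Device::Cpu)?.reshape((4, 3))?;
let src = Tensor::arange(100f32, 106f32, &Device::Cpu)?.reshape((2, 3))?;
// Overwrite rows 1 and 2 of `t` with the two rows of `src`.
assert_eq!(
    t.slice_scatter0(&src, 1)?.to_vec2::<f32>()?,
    [
        [0.0, 1.0, 2.0],
        [100.0, 101.0, 102.0],
        [103.0, 104.0, 105.0],
        [9.0, 10.0, 11.0]
    ]
);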

    /// Accumulates elements from `source` at the indexes `indexes` and adds them to `self`.
    pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
        let dim = dim.to_index(self.shape(), "index-add")?;

@ -1333,8 +1152,7 @@ impl Tensor {
                op: "index-add (self, source)",
                lhs: self.shape().clone(),
                rhs: source.shape().clone(),
            }
            .bt())?
            })?
        }
        // The number of elements in `indexes` must match the dimension on which the add is
        // performed on the source tensor (and the index values from `indexes` are taken from

@ -1345,8 +1163,7 @@ impl Tensor {
                op: "index-add (ids, source))",
                lhs: indexes.shape().clone(),
                rhs: source.shape().clone(),
            }
            .bt())?
            })?
        }
        let storage = self.storage().index_add(
            self.layout(),
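
A usage sketch of `index_add`; the example below is constructed here rather than taken from the test suite (`out[ids[i]] += src[i]` along dimension 0):

use candle_core::{DType, Device, Tensor};

let t = Tensor::zeros((3,), DType::F32, &Device::Cpu)?;
let src = Tensor::new(&[1f32, 2., 3., 4.], &Device::Cpu)?;
// Index 1 receives two contributions (2 + 3).
let ids = Tensor::new(&[0u32, 1, 1, 2], &Device::Cpu)?;
let out = t.index_add(&ids, &src, 0)?;
assert_eq!(out.to_vec1::<f32>()?, [1.0, 5.0, 4.0]);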

@ -1394,8 +1211,7 @@ impl Tensor {
                op: "gather",
                lhs: self.shape().clone(),
                rhs: indexes.shape().clone(),
            }
            .bt())?
            })?
        }
        let storage =
            self.storage()
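
A usage sketch of `gather` under this shape check; the example is constructed here, with `out[i][j] = t[i][ids[i][j]]` along dimension 1:

use candle_core::{Device, Tensor};

let t = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
// One index per row: column 1 from row 0, column 0 from row 1.
let ids = Tensor::new(&[[1u32], [0u32]], &Device::Cpu)?;
let g = t.gather(&ids, 1)?;
assert_eq!(g.to_vec2::<f32>()?, [[2.0], [3.0]]);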

@ -1469,7 +1285,6 @@ impl Tensor {
        match &*self.storage() {
            Storage::Cpu(storage) => from_cpu_storage(storage),
            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
        }
    }

@ -1500,7 +1315,6 @@ impl Tensor {
        match &*self.storage() {
            Storage::Cpu(storage) => from_cpu_storage(storage),
            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
        }
    }

@ -1541,7 +1355,6 @@ impl Tensor {
        match &*self.storage() {
            Storage::Cpu(storage) => from_cpu_storage(storage),
            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
        }
    }

@ -1705,24 +1518,6 @@ impl Tensor {
        }
    }

    /// Returns the sub-tensor fixing the index at `index` on the dimension `dim`.
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
    /// let t = tensor.get_on_dim(1, 0)?;
    /// assert_eq!(t.to_vec1::<f32>()?, &[0., 2., 4.]);
    /// let t = tensor.get_on_dim(1, 1)?;
    /// assert_eq!(t.to_vec1::<f32>()?, &[1., 3., 5.]);
    /// let t = tensor.get_on_dim(0, 1)?;
    /// assert_eq!(t.to_vec1::<f32>()?, &[2., 3.]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```
    pub fn get_on_dim<D: Dim>(&self, dim: D, index: usize) -> Result<Tensor> {
        let dim = dim.to_index(self.shape(), "get_on_dim")?;
        self.narrow(dim, index, 1)?.squeeze(dim)
    }

    /// Returns a tensor that is a transposed version of the input, the two last dimensions of the
    /// input are swapped.
    ///

@ -1751,9 +1546,6 @@ impl Tensor {
    pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor> {
        let dim1 = dim1.to_index(self.shape(), "transpose")?;
        let dim2 = dim2.to_index(self.shape(), "transpose")?;
        if dim1 == dim2 {
            return Ok(self.clone());
        }
        let op = BackpropOp::new1(self, |t| Op::Transpose(t, dim1, dim2));
        let tensor_ = Tensor_ {
            id: TensorId::new(),

@ -1831,23 +1623,17 @@ impl Tensor {

    /// Returns a new tensor detached from the current graph; gradients are not propagated through
    /// this new node. The storage of this tensor is shared with the initial tensor.
    ///
    /// If the tensor is already detached from the computation graph, the same tensor is returned.
    pub fn detach(&self) -> Result<Tensor> {
        if self.op.is_none() && !self.is_variable {
            Ok(self.clone())
        } else {
            let tensor_ = Tensor_ {
                id: TensorId::new(),
                storage: self.storage.clone(),
                layout: self.layout.clone(),
                op: BackpropOp::none(),
                is_variable: false,
                dtype: self.dtype,
                device: self.device.clone(),
            };
            Ok(Tensor(Arc::new(tensor_)))
        }
        let tensor_ = Tensor_ {
            id: TensorId::new(),
            storage: self.storage.clone(),
            layout: self.layout.clone(),
            op: BackpropOp::none(),
            is_variable: false,
            dtype: self.dtype,
            device: self.device.clone(),
        };
        Ok(Tensor(Arc::new(tensor_)))
    }
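
A usage sketch of `detach` (a constructed example; `Var` builds a tracked variable):

use candle_core::{Device, Var};

let x = Var::new(&[3f32, 1., 4.], &Device::Cpu)?;
let y = (x.as_tensor() * 2.)?;
// `d` shares storage with `y` but records no backprop op, so gradients do
// not flow through it; detaching an already-detached tensor is a no-op.
let d = y.detach()?;
assert_eq!(d.to_vec1::<f32>()?, [6.0, 2.0, 8.0]);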

    /// If the target device is the same as the tensor device, only a shallow copy is performed.

@ -1859,14 +1645,7 @@ impl Tensor {
            (Storage::Cpu(storage), Device::Cuda(cuda)) => {
                Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
            }
            (Storage::Cpu(storage), Device::Metal(metal)) => {
                Storage::Metal(metal.storage_from_cpu_storage(storage)?)
            }
            (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
            (Storage::Metal(storage), Device::Cpu) => {
                println!("{storage:?} - {:?}", storage.to_cpu_storage()?);
                Storage::Cpu(storage.to_cpu_storage()?)
            }
            (Storage::Cuda(storage), Device::Cuda(cuda)) => {
                // TODO: Avoid passing through the cpu storage here, especially if the gpu ids
                // are the same.

@ -1874,9 +1653,6 @@ impl Tensor {
                Storage::Cuda(cuda.storage_from_cpu_storage(&cpu_storage)?)
            }
            (Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
            _ => {
                bail!("not implemented yet")
            }
        };
        let op = BackpropOp::new1(self, Op::ToDevice);
        let tensor_ = Tensor_ {

@ -2131,34 +1907,6 @@ impl Tensor {
        for arg in args {
            arg.as_ref().check_dim(dim, "cat")?;
        }
        for (arg_idx, arg) in args.iter().enumerate() {
            let arg = arg.as_ref();
            if arg0.rank() != arg.rank() {
                Err(Error::UnexpectedNumberOfDims {
                    expected: arg0.rank(),
                    got: arg.rank(),
                    shape: arg.shape().clone(),
                }
                .bt())?
            }
            for (dim_idx, (v1, v2)) in arg0
                .shape()
                .dims()
                .iter()
                .zip(arg.shape().dims().iter())
                .enumerate()
            {
                if dim_idx != dim && v1 != v2 {
                    Err(Error::ShapeMismatchCat {
                        dim: dim_idx,
                        first_shape: arg0.shape().clone(),
                        n: arg_idx + 1,
                        nth_shape: arg.shape().clone(),
                    }
                    .bt())?
                }
            }
        }
        if dim == 0 {
            Self::cat0(args)
        } else {

@ -2276,56 +2024,11 @@ impl Tensor {
        }
    }

    /// Pad the input tensor by replicating its edge values along dimension `dim`. This adds `left`
    /// elements before the input tensor values and `right` elements after.
    pub fn pad_with_same<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
        if left == 0 && right == 0 {
            Ok(self.clone())
        } else if self.elem_count() == 0 {
            crate::bail!("cannot use pad_with_same on an empty tensor")
        } else if left == 0 {
            let dim = dim.to_index(self.shape(), "pad_with_same")?;
            let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
            let mut v = vec![self];
            for _ in 0..right {
                v.push(&r)
            }
            Tensor::cat(&v, dim)
        } else if right == 0 {
            let dim = dim.to_index(self.shape(), "pad_with_same")?;
            let l = self.narrow(dim, 0, 1)?;
            let mut v = vec![];
            for _ in 0..left {
                v.push(&l)
            }
            v.push(self);
            Tensor::cat(&v, dim)
        } else {
            let dim = dim.to_index(self.shape(), "pad_with_same")?;
            let l = self.narrow(dim, 0, 1)?;
            let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
            let mut v = vec![];
            for _ in 0..left {
                v.push(&l)
            }
            v.push(self);
            for _ in 0..right {
                v.push(&r)
            }
            Tensor::cat(&v, dim)
        }
    }
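
A usage sketch of `pad_with_same`, matching the `pad_with_same` test later in this diff:

use candle_core::{Device, Tensor};

let t = Tensor::arange(1f32, 5f32, &Device::Cpu)?.reshape((2, 2))?;
// One replicated edge row before, two after.
let padded = t.pad_with_same(0, 1, 2)?;
assert_eq!(
    padded.to_vec2::<f32>()?,
    [[1.0, 2.0], [1.0, 2.0], [3.0, 4.0], [3.0, 4.0], [3.0, 4.0]]
);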

    /// Run the `forward` method of `m` on `self`.
    pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
        m.forward(self)
    }

    /// Run the `forward` method of `m` on `self`.
    pub fn apply_t<M: crate::ModuleT>(&self, m: &M, train: bool) -> Result<Self> {
        m.forward_t(self, train)
    }

    pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
        self.storage.read().unwrap()
    }

@ -2440,23 +2143,6 @@ impl Tensor {
    ) -> Result<Self> {
        self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
    }

    /// Normalize a 'relative' axis value: positive values are kept, negative
    /// values mean counting the dimensions from the back.
    pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
        let rank = self.rank() as i64;
        if rank <= axis {
            crate::bail!("axis {axis} is too large, tensor rank {rank}")
        } else if 0 <= axis {
            Ok(axis as usize)
        } else {
            let naxis = rank + axis;
            if naxis < 0 {
                crate::bail!("axis {axis} is too small, tensor rank {rank}")
            }
            Ok(naxis as usize)
        }
    }
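
A constructed usage sketch of `normalize_axis`:

use candle_core::{DType, Device, Tensor};

let t = Tensor::zeros((2, 3, 4), DType::F32, &Device::Cpu)?;
assert_eq!(t.normalize_axis(-1)?, 2); // counted from the back
assert_eq!(t.normalize_axis(0)?, 0); // non-negative axes are kept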

macro_rules! bin_trait {

@ -4,7 +4,7 @@ use crate::{Result, Tensor};
macro_rules! test_device {
    // TODO: Switch to generating the two last arguments automatically once concat_idents is
    // stable. https://github.com/rust-lang/rust/issues/29599
    ($fn_name: ident, $test_cpu: ident, $test_cuda: ident, $test_metal: ident) => {
    ($fn_name: ident, $test_cpu: ident, $test_cuda: ident) => {
        #[test]
        fn $test_cpu() -> Result<()> {
            $fn_name(&Device::Cpu)

@ -15,12 +15,6 @@ macro_rules! test_device {
        fn $test_cuda() -> Result<()> {
            $fn_name(&Device::new_cuda(0)?)
        }

        #[cfg(feature = "metal")]
        #[test]
        fn $test_metal() -> Result<()> {
            $fn_name(&Device::new_metal(0)?)
        }
    };
}
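
A hypothetical usage sketch of the macro: the four-name form on one side of this hunk generates CPU, CUDA and metal tests, while the three-name form drops the metal one. `my_op` and the generated test names below are illustrative only, assuming `use candle_core::{DType, Device, Result, Tensor};` is in scope:

fn my_op(device: &Device) -> Result<()> {
    // Exercise the op under test on the given device.
    let t = Tensor::zeros((2, 2), DType::F32, device)?;
    assert_eq!(t.dims(), &[2, 2]);
    Ok(())
}

// Expands into my_op_cpu, my_op_gpu and my_op_metal #[test] functions,
// the non-CPU ones gated behind the corresponding cargo features.
test_device!(my_op, my_op_cpu, my_op_gpu, my_op_metal);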

@ -23,10 +23,6 @@ pub fn cuda_is_available() -> bool {
    cfg!(feature = "cuda")
}

pub fn metal_is_available() -> bool {
    cfg!(feature = "metal")
}

pub fn with_avx() -> bool {
    cfg!(target_feature = "avx")
}
@ -13,11 +13,6 @@ res = torch.nn.functional.conv1d(t, w)
print(res.flatten())
res = torch.nn.functional.conv1d(t, w, padding=1)
print(res.flatten())

w_t = w.transpose(0, 1)
res = torch.nn.functional.conv_transpose1d(t, w_t)
print(res.shape)
print(res)
*/
fn conv1d(dev: &Device) -> Result<()> {
    let t = Tensor::new(

@ -50,17 +45,6 @@ fn conv1d(dev: &Device) -> Result<()> {
        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
        [2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
    );
    if dev.is_cpu() {
        let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
        assert_eq!(res.dims(), [1, 2, 7]);
        assert_eq!(
            test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
            [
                0.0699, -1.2899, 8.3018, 5.5873, 2.4572, -2.6143, -0.0706, 1.8765, 4.8318, 1.1538,
                4.7076, -5.9745, -0.8276, 1.621
            ],
        );
    }
    Ok(())
}

@ -495,103 +479,17 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
            ]
        ]
    );

    // Replicate the issue from https://github.com/huggingface/candle/issues/1212
    let res = t.i((.., .., 0..4, 0..4))?.conv2d(&w, 0, 2, 1, 1)?;
    let loss = res.sqr()?.sum_all()?;
    assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 21.12f32);
    let grads = loss.backward()?;
    let grad_t = grads.get(&t).unwrap();
    let grad_w = grads.get(&w).unwrap();
    assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
    assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
    assert_eq!(
        test_utils::to_vec3_round(&grad_t.i(0)?, 2)?,
        [
            [
                [9.29, -7.03, 7.87, 0.0, 0.0],
                [-1.8, -7.82, 5.9, 0.0, 0.0],
                [-3.12, 4.49, 5.52, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0]
            ],
            [
                [21.73, 3.39, 4.77, 0.0, 0.0],
                [8.25, 3.73, 27.61, 0.0, 0.0],
                [-20.55, -5.61, -2.77, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0]
            ],
            [
                [-8.98, 9.91, -7.15, 0.0, 0.0],
                [4.93, -0.33, 4.56, 0.0, 0.0],
                [-6.7, -5.76, -8.05, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0]
            ],
            [
                [23.54, 6.98, -10.0, 0.0, 0.0],
                [9.65, 6.18, 18.72, 0.0, 0.0],
                [3.29, -5.27, 0.79, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0]
            ]
        ]
    );
    assert_eq!(
        test_utils::to_vec3_round(&grad_w.i(0)?, 2)?,
        [
            [
                [-3.47, 7.44, 0.66],
                [12.89, -3.4, -9.29],
                [-14.16, -0.83, 7.14]
            ],
            [
                [-3.23, 5.37, -3.02],
                [-2.12, -11.24, 1.94],
                [6.97, 7.2, 2.99]
            ],
            [
                [-4.04, -3.31, 4.87],
                [-6.68, -5.68, 1.73],
                [-5.54, 4.32, 0.52]
            ],
            [[-4.72, 1.5, 4.72], [3.79, 4.04, 6.76], [-4.6, 5.8, 6.93]]
        ]
    );

    Ok(())
}

test_device!(conv1d, conv1d_cpu, conv1d_gpu, conv1d_metal);
test_device!(
    conv1d_small,
    conv1d_small_cpu,
    conv1d_small_gpu,
    conv1d_small_metal
);
test_device!(conv2d, conv2d_cpu, conv2d_gpu, conv2d_metal);
test_device!(conv1d, conv1d_cpu, conv1d_gpu);
test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu);
test_device!(conv2d, conv2d_cpu, conv2d_gpu);
test_device!(
    conv2d_non_square,
    conv2d_non_square_cpu,
    conv2d_non_square_gpu,
    conv2d_non_square_metal
);
test_device!(
    conv2d_small,
    conv2d_small_cpu,
    conv2d_small_gpu,
    conv2d_small_metal
);
test_device!(
    conv2d_smaller,
    conv2d_smaller_cpu,
    conv2d_smaller_gpu,
    conv2d_smaller_metal
);
test_device!(
    conv2d_grad,
    conv2d_grad_cpu,
    conv2d_grad_gpu,
    conv2_grad_metal
    conv2d_non_square_gpu
);
test_device!(conv2d_small, conv2d_small_cpu, conv2d_small_gpu);
test_device!(conv2d_smaller, conv2d_smaller_cpu, conv2d_smaller_gpu);
test_device!(conv2d_grad, conv2d_grad_cpu, conv2d_grad_gpu);
@ -192,84 +192,6 @@ fn unary_grad(device: &Device) -> Result<()> {
        test_utils::to_vec1_round(grad_x, 2)?,
        [0.01, 0.42, 0.0, 0.98],
    );

    // testing compared to pytorch nn.GELU(approximate = 'tanh')
    let y = x.gelu()?;
    let grads = y.backward()?;
    let grad_x = grads.get(&x).context("no grad for x")?;
    assert_eq!(
        test_utils::to_vec1_round(&y, 4)?,
        [2.9964, 0.8412, 3.9999, 0.0839]
    );
    assert_eq!(
        test_utils::to_vec1_round(grad_x, 4)?,
        [1.0116, 1.0830, 1.0003, 0.6188],
    );

    // Testing compared to pytorch torch.erf
    //
    // import torch
    // x = torch.tensor([3.0, 1.0, 4.0, 0.15], requires_grad=True)
    // y = x.erf()
    // print(y)
    // loss = y.sum()
    // loss.backward()
    // print(x.grad)
    let y = x.erf()?;
    let grads = y.backward()?;
    let grad_x = grads.get(&x).context("no grad for x")?;
    assert_eq!(test_utils::to_vec1_round(&y, 4)?, [1.0, 0.8427, 1.0, 0.168]);
    assert_eq!(
        test_utils::to_vec1_round(grad_x, 4)?,
        [0.0001, 0.4151, 0.0, 1.1033],
    );

    // Testing compared to pytorch nn.GELU(approximate = 'none')
    //
    // import torch
    // import torch.nn.functional as F
    // x = torch.tensor([3.0, 1.0, 4.0, 0.15], requires_grad=True)
    // y = F.gelu(x, approximate='none')
    // print(y)
    // loss = y.sum()
    // loss.backward()
    // print(x.grad)
    let y = x.gelu_erf()?;
    let grads = y.backward()?;
    let grad_x = grads.get(&x).context("no grad for x")?;
    assert_eq!(
        test_utils::to_vec1_round(&y, 4)?,
        [2.9960, 0.8413, 3.9999, 0.0839]
    );
    assert_eq!(
        test_utils::to_vec1_round(grad_x, 4)?,
        [1.0119, 1.0833, 1.0005, 0.6188],
    );

    // Testing compared to pytorch elu
    //
    // import torch
    // import torch.nn.functional as F
    // x = torch.tensor([-1.0, 0.0, -2.0, 3.0], requires_grad=True)
    // y = F.elu(x, alpha=2.0)
    // print(y)
    // loss = y.min
    // loss = y.sum()
    // loss.backward()
    // print(x.grad)
    let elu_x = Var::new(&[-1.0f32, 0., -2., 3.], device)?;
    let y = elu_x.elu(2.)?;
    let grads = y.backward()?;
    let grad_x = grads.get(&elu_x).context("no grad for x")?;
    assert_eq!(
        test_utils::to_vec1_round(&y, 4)?,
        [-1.2642, 0.0000, -1.7293, 3.0000]
    );
    assert_eq!(
        test_utils::to_vec1_round(grad_x, 4)?,
        [0.7358, 2.0000, 0.2707, 1.0000]
    );

    Ok(())
}
@ -296,48 +218,12 @@ fn binary_grad(device: &Device) -> Result<()> {
    let grad_x = grads.get(x).context("no grad for x")?;
    assert_eq!(y.to_vec1::<f32>()?, [3., 1., -4., -1.]);
    assert_eq!(grad_x.to_vec1::<f32>()?, [1., 1., 1., 1.]);

    let x_var = Var::new(&[3f32, 1., -4., -1., 5., 9.], device)?;
    let x = x_var.as_tensor();
    let y_var = Var::new(&[2f32, 7., 1.], device)?;
    let y = y_var.as_tensor();

    let ss = x
        .reshape((2, 3))?
        .slice_scatter0(&y.reshape((1, 3))?, 1)?
        .sqr()?;
    let grads = ss.backward()?;
    let grad_x = grads.get(x).context("no grad for x")?;
    let grad_y = grads.get(y).context("no grad for y")?;
    assert_eq!(ss.to_vec2::<f32>()?, [[9., 1., 16.], [4., 49., 1.]]);
    assert_eq!(grad_x.to_vec1::<f32>()?, [6.0, 2.0, -8.0, 0.0, 0.0, 0.0]);
    assert_eq!(grad_y.to_vec1::<f32>()?, [4.0, 14.0, 2.0]);
    Ok(())
}

test_device!(
    simple_grad,
    simple_grad_cpu,
    simple_grad_gpu,
    simple_grad_metal
);
test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu, sum_grad_metal);
test_device!(
    matmul_grad,
    matmul_grad_cpu,
    matmul_grad_gpu,
    matmul_grad_metal
);
test_device!(
    grad_descent,
    grad_descent_cpu,
    grad_descent_gpu,
    grad_descent_metal
);
test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu, unary_grad_metal);
test_device!(
    binary_grad,
    binary_grad_cpu,
    binary_grad_gpu,
    binary_grad_metal
);
test_device!(simple_grad, simple_grad_cpu, simple_grad_gpu);
test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu);
test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu);
test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu);
test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu);
test_device!(binary_grad, binary_grad_cpu, binary_grad_gpu);
@ -49,7 +49,7 @@ fn contiguous(device: &Device) -> Result<()> {
    Ok(())
}

test_device!(contiguous, contiguous_cpu, contiguous_gpu, contiguous_metal);
test_device!(contiguous, contiguous_cpu, contiguous_gpu);

#[test]
fn strided_blocks() -> Result<()> {
@ -1,9 +0,0 @@
import numpy as np
x = np.arange(10)

# Write a npy file.
np.save("test.npy", x)

# Write multiple values to a npz file.
values = { "x": x, "x_plus_one": x + 1 }
np.savez("test.npz", **values)
@ -98,17 +98,15 @@ fn upsample_nearest2d(dev: &Device) -> Result<()> {
    Ok(())
}

test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu, avg_pool2d_metal);
test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu);
test_device!(
    avg_pool2d_pytorch,
    avg_pool2d_pytorch_cpu,
    avg_pool2d_pytorch_gpu,
    avg_pool2d_pytorch_metal
    avg_pool2d_pytorch_gpu
);
test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu, max_pool2d_metal);
test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu);
test_device!(
    upsample_nearest2d,
    upsample_nearest2d_cpu,
    upsample_nearest2d_gpu,
    upsample_nearest2d_metal
    upsample_nearest2d_gpu
);
@ -43,7 +43,7 @@ fn quantized_matmul() -> Result<()> {
    );

    let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
    let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
    let matmul = quantized::QMatMul::from_qtensor(qtensor);
    let res = matmul.forward(&tensor_lhs)?;
    assert_eq!(
        to_vec2_round(&res, 0)?,

@ -91,7 +91,7 @@ fn quantized_matmul_neg() -> Result<()> {
    );

    let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
    let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
    let matmul = quantized::QMatMul::from_qtensor(qtensor);
    let res = matmul.forward(&tensor_lhs)?;
    assert_eq!(
        to_vec2_round(&res, 0)?,

@ -491,9 +491,6 @@ fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result<f32> {
        GgmlDType::Q5_0 => 0.001353,
        GgmlDType::Q5_1 => 0.001363,
        GgmlDType::Q8_0 => 0.000092,

        // Not from the ggml repo.
        GgmlDType::Q8K => 0.00065,
        _ => candle_core::bail!("No GGML results for quantization type {dtype:?}",),
    };
    Ok(err)

@ -511,22 +508,17 @@ fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
    T::VecDotType::from_float(&b, &mut b_quant)?;

    let result = T::vec_dot(length, &a_quant, &b_quant)?;
    let result_unopt = T::vec_dot_unopt(length, &a_quant, &b_quant)?;
    let reference_result = vec_dot_reference(&a, &b);

    if (result - result_unopt).abs() / length as f32 > 1e-6 {
        candle_core::bail!(
            "the opt and unopt vec-dot returned different values, opt {result}, unopt {result_unopt}"
        )
    }

    let error = (result - reference_result).abs() / length as f32;

    let ggml_error = ggml_reference_matmul_error(T::DTYPE)?;

    if !error.is_finite() || error > GGML_MAX_DOT_PRODUCT_ERROR {
    if error > GGML_MAX_DOT_PRODUCT_ERROR {
        candle_core::bail!(
            "Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",
            "Dot product error {} exceeds max error {}",
            error,
            GGML_MAX_DOT_PRODUCT_ERROR
        );
    }

@ -579,7 +571,7 @@ fn quantized_matmul_q2k() -> Result<()> {
    assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);

    let rhs = quantized::QTensor::quantize::<BlockQ2K>(&rhs)?;
    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
    let rhs = quantized::QMatMul::from_qtensor(rhs);
    let mm = rhs.forward(&lhs)?;

    assert_eq!(mm.dims(), [m, n]);

@ -605,7 +597,7 @@ fn quantized_matmul_q3k() -> Result<()> {
    assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);

    let rhs = quantized::QTensor::quantize::<BlockQ3K>(&rhs)?;
    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
    let rhs = quantized::QMatMul::from_qtensor(rhs);
    let mm = rhs.forward(&lhs)?;

    assert_eq!(mm.dims(), [m, n]);

@ -631,7 +623,7 @@ fn quantized_matmul_q4k() -> Result<()> {
    assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);

    let rhs = quantized::QTensor::quantize::<BlockQ4K>(&rhs)?;
    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
    let rhs = quantized::QMatMul::from_qtensor(rhs);
    let mm = rhs.forward(&lhs)?;

    assert_eq!(mm.dims(), [m, n]);

@ -657,7 +649,7 @@ fn quantized_matmul_q5k() -> Result<()> {
    assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);

    let rhs = quantized::QTensor::quantize::<BlockQ5K>(&rhs)?;
    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
    let rhs = quantized::QMatMul::from_qtensor(rhs);
    let mm = rhs.forward(&lhs)?;

    assert_eq!(mm.dims(), [m, n]);

@ -684,7 +676,7 @@ fn quantized_matmul_q6k() -> Result<()> {
    assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);

    let rhs = quantized::QTensor::quantize::<BlockQ6K>(&rhs)?;
    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
    let rhs = quantized::QMatMul::from_qtensor(rhs);
    let mm = rhs.forward(&lhs)?;

    assert_eq!(mm.dims(), [m, n]);

@ -695,28 +687,3 @@ fn quantized_matmul_q6k() -> Result<()> {
    ggml_matmul_error_test::<BlockQ6K>()?;
    Ok(())
}

#[test]
fn quantized_matmul_q8k() -> Result<()> {
    use k_quants::BlockQ8K;

    let cpu = &Device::Cpu;
    let (m, k, n) = (11, 512, 21);
    let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?;
    assert_eq!(mm.dims(), [m, n]);
    let dst = mm.flatten_all()?.to_vec1::<f32>()?;
    let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
    assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);

    let rhs = quantized::QTensor::quantize::<BlockQ8K>(&rhs)?;
    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
    let mm = rhs.forward(&lhs)?;

    assert_eq!(mm.dims(), [m, n]);
    let dst = mm.flatten_all()?.to_vec1::<f32>()?;
    let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
    assert_eq!(dst, [1.266, 1.504, -0.204, 1.7]);

    ggml_matmul_error_test::<BlockQ8K>()?;
    Ok(())
}
@ -1,24 +0,0 @@
use candle_core::{DType, Result, Tensor};

#[test]
fn npy() -> Result<()> {
    let npy = Tensor::read_npy("tests/test.npy")?;
    assert_eq!(
        npy.to_dtype(DType::U8)?.to_vec1::<u8>()?,
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    );
    Ok(())
}

#[test]
fn npz() -> Result<()> {
    let npz = Tensor::read_npz("tests/test.npz")?;
    assert_eq!(npz.len(), 2);
    assert_eq!(npz[0].0, "x");
    assert_eq!(npz[1].0, "x_plus_one");
    assert_eq!(
        npz[1].1.to_dtype(DType::U8)?.to_vec1::<u8>()?,
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    );
    Ok(())
}
@ -1,4 +1,4 @@
|
||||
use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor};
|
||||
use candle_core::{test_device, DType, Device, IndexOp, Result, Tensor};
|
||||
|
||||
fn zeros(device: &Device) -> Result<()> {
|
||||
let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
|
||||
@ -8,50 +8,6 @@ fn zeros(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn ones(device: &Device) -> Result<()> {
|
||||
assert_eq!(
|
||||
Tensor::ones((2, 3), DType::U8, device)?.to_vec2::<u8>()?,
|
||||
[[1, 1, 1], [1, 1, 1]],
|
||||
);
|
||||
assert_eq!(
|
||||
Tensor::ones((2, 3), DType::U32, device)?.to_vec2::<u32>()?,
|
||||
[[1, 1, 1], [1, 1, 1]],
|
||||
);
|
||||
assert_eq!(
|
||||
Tensor::ones((2, 3), DType::I64, device)?.to_vec2::<i64>()?,
|
||||
[[1, 1, 1], [1, 1, 1]],
|
||||
);
|
||||
assert_eq!(
|
||||
Tensor::ones((2, 3), DType::F32, device)?.to_vec2::<f32>()?,
|
||||
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
|
||||
);
|
||||
assert_eq!(
|
||||
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
|
||||
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn arange(device: &Device) -> Result<()> {
|
||||
assert_eq!(
|
||||
Tensor::arange(0u8, 5u8, device)?.to_vec1::<u8>()?,
|
||||
[0, 1, 2, 3, 4],
|
||||
);
|
||||
assert_eq!(
|
||||
Tensor::arange_step(0u8, 5u8, 2, device)?.to_vec1::<u8>()?,
|
||||
[0, 2, 4],
|
||||
);
|
||||
assert_eq!(
|
||||
Tensor::arange_step(0u8, 5u8, 3, device)?.to_vec1::<u8>()?,
|
||||
[0, 3],
|
||||
);
|
||||
assert_eq!(
|
||||
Tensor::arange_step(5i64, 0i64, -1, device)?.to_vec1::<i64>()?,
|
||||
[5, 4, 3, 2, 1],
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn add_mul(device: &Device) -> Result<()> {
|
||||
let tensor = Tensor::new(&[3f32, 1., 4.], device)?;
|
||||
let dim1 = tensor.dims1()?;
|
||||
@ -88,54 +44,6 @@ fn clamp(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn unary_op(device: &Device) -> Result<()> {
|
||||
let data = &[[-3f32, 1., 4., -0.1, 0.5], [2.7, -1.8, -0.28, 1.8, 2.8]];
|
||||
let tensor = Tensor::new(data, device)?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&tensor.gelu()?, 4)?,
|
||||
[
|
||||
[-0.0036, 0.8412, 3.9999, -0.046, 0.3457],
|
||||
[2.6911, -0.0647, -0.1091, 1.7353, 2.7933]
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?,
|
||||
[
|
||||
[-0.004, 0.8413, 3.9999, -0.046, 0.3457],
|
||||
[2.6906, -0.0647, -0.1091, 1.7353, 2.7928]
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&tensor.erf()?, 4)?,
|
||||
[
|
||||
[-1.0, 0.8427, 1.0, -0.1125, 0.5205],
|
||||
[0.9999, -0.9891, -0.3079, 0.9891, 0.9999]
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&tensor.ceil()?, 4)?,
|
||||
[[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -1.0, -0.0, 2.0, 3.0]]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&tensor.floor()?, 4)?,
|
||||
[[-3.0, 1.0, 4.0, -1.0, 0.0], [2.0, -2.0, -1.0, 1.0, 2.0]]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&tensor.round()?, 4)?,
|
||||
[[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -2.0, -0.0, 2.0, 3.0]]
|
||||
);
|
||||
let tensor = Tensor::new(&[2997.9246, 314.15926f32], device)?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&tensor.round_to(2)?, 4)?,
|
||||
[2997.92, 314.16]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&tensor.round_to(-2)?, 4)?,
|
||||
[3000.0, 300.]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn binary_op(device: &Device) -> Result<()> {
|
||||
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
|
||||
let tensor1 = Tensor::new(data, device)?;
|
||||
@ -180,22 +88,6 @@ fn transpose(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn var(device: &Device) -> Result<()> {
|
||||
// Values taken from https://pytorch.org/docs/stable/generated/torch.var.html
|
||||
let data = &[
|
||||
[0.2035f32, 1.2959, 1.8101, -0.4644],
|
||||
[1.5027, -0.3270, 0.5905, 0.6538],
|
||||
[-1.5745, 1.3330, -0.5596, -0.6548],
|
||||
[0.1264, -0.5080, 1.6420, 0.1992],
|
||||
];
|
||||
let tensor = Tensor::new(data, device)?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&tensor.var_keepdim(1)?, 4)?,
|
||||
&[[1.0631], [0.559], [1.4893], [0.8258]]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn sum(device: &Device) -> Result<()> {
|
||||
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
|
||||
let tensor = Tensor::new(data, device)?;
|
||||
@ -709,30 +601,6 @@ fn index_select(device: &Device) -> Result<()> {
|
||||
hs.to_vec2::<f32>()?,
|
||||
&[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
|
||||
);
|
||||
// Prior to https://github.com/huggingface/candle/pull/1022
|
||||
// There would be a bug where the last values in the result tensor would be set to 0.
|
||||
let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
|
||||
let hs = t.index_select(&ids, 0)?;
|
||||
assert_eq!(
|
||||
hs.to_vec2::<f32>()?,
|
||||
&[
|
||||
[0.0, 1.0, 2.0],
|
||||
[6.0, 7.0, 8.0],
|
||||
[3.0, 4.0, 5.0],
|
||||
[0.0, 1.0, 2.0],
|
||||
[6.0, 7.0, 8.0],
|
||||
[3.0, 4.0, 5.0],
|
||||
]
|
||||
);
|
||||
|
||||
// Test when selecting dim > 0 with ids size different from elem count of
|
||||
// target dim in source/input.
|
||||
let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
|
||||
let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
|
||||
assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
|
||||
let hs = t.index_select(&ids, 1)?;
|
||||
assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -779,48 +647,6 @@ fn index_add(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn slice_scatter(device: &Device) -> Result<()> {
|
||||
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
|
||||
assert_eq!(
|
||||
t.to_vec2::<f32>()?,
|
||||
&[
|
||||
[0.0, 1.0, 2.0],
|
||||
[3.0, 4.0, 5.0],
|
||||
[6.0, 7.0, 8.0],
|
||||
[9.0, 10.0, 11.0]
|
||||
]
|
||||
);
|
||||
let src = Tensor::arange(100f32, 106f32, device)?.reshape((2, 3))?;
|
||||
assert_eq!(
|
||||
t.slice_scatter0(&src, 0)?.to_vec2::<f32>()?,
|
||||
&[
|
||||
[100.0, 101.0, 102.0],
|
||||
[103.0, 104.0, 105.0],
|
||||
[6.0, 7.0, 8.0],
|
||||
[9.0, 10.0, 11.0]
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
t.slice_scatter0(&src, 1)?.to_vec2::<f32>()?,
|
||||
&[
|
||||
[0.0, 1.0, 2.0],
|
||||
[100.0, 101.0, 102.0],
|
||||
[103.0, 104.0, 105.0],
|
||||
[9.0, 10.0, 11.0]
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
t.slice_scatter0(&src, 2)?.to_vec2::<f32>()?,
|
||||
&[
|
||||
[0.0, 1.0, 2.0],
|
||||
[3.0, 4.0, 5.0],
|
||||
[100.0, 101.0, 102.0],
|
||||
[103.0, 104.0, 105.0],
|
||||
]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn scatter_add(device: &Device) -> Result<()> {
    let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
    assert_eq!(
@ -1070,60 +896,30 @@ fn randn(device: &Device) -> Result<()> {
    Ok(())
}

test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
test_device!(arange, arange_cpu, arange_gpu, arange_metal);
test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal);
test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal);
test_device!(cat, cat_cpu, cat_gpu, cat_metal);
test_device!(sum, sum_cpu, sum_gpu, sum_metal);
test_device!(min, min_cpu, min_gpu, min_metal);
test_device!(max, max_cpu, max_gpu, max_metal);
test_device!(argmax, argmax_cpu, argmax_gpu, argmax_metal);
test_device!(argmin, argmin_cpu, argmin_gpu, argmin_metal);
test_device!(transpose, transpose_cpu, transpose_gpu, transpose_metal);
test_device!(unary_op, unary_op_cpu, unary_op_gpu, unary_op_metal);
test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal);
test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal);
test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal);
test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
test_device!(
    broadcast_matmul,
    broadcast_matmul_cpu,
    broadcast_matmul_gpu,
    broadcast_matmul_metal
);
test_device!(
    broadcasting,
    broadcasting_cpu,
    broadcasting_gpu,
    broadcasting_metal
);
test_device!(
    index_select,
    index_select_cpu,
    index_select_gpu,
    index_select_metal
);
test_device!(index_add, index_add_cpu, index_add_gpu, index_add_metal);
test_device!(gather, gather_cpu, gather_gpu, gather_metal);
test_device!(
    scatter_add,
    scatter_add_cpu,
    scatter_add_gpu,
    scatter_add_metal
);
test_device!(
    slice_scatter,
    slice_scatter_cpu,
    slice_scatter_gpu,
    slice_scatter_metal
);
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
test_device!(var, var_cpu, var_gpu, var_metal);
test_device!(zeros, zeros_cpu, zeros_gpu);
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
test_device!(narrow, narrow_cpu, narrow_gpu);
test_device!(broadcast, broadcast_cpu, broadcast_gpu);
test_device!(cat, cat_cpu, cat_gpu);
test_device!(sum, sum_cpu, sum_gpu);
test_device!(min, min_cpu, min_gpu);
test_device!(max, max_cpu, max_gpu);
test_device!(argmax, argmax_cpu, argmax_gpu);
test_device!(argmin, argmin_cpu, argmin_gpu);
test_device!(transpose, transpose_cpu, transpose_gpu);
test_device!(binary_op, binary_op_cpu, binary_op_gpu);
test_device!(embeddings, embeddings_cpu, embeddings_gpu);
test_device!(cmp, cmp_cpu, cmp_gpu);
test_device!(matmul, matmul_cpu, matmul_gpu);
test_device!(broadcast_matmul, broadcast_matmul_cpu, broadcast_matmul_gpu);
test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu);
test_device!(index_select, index_select_cpu, index_select_gpu);
test_device!(index_add, index_add_cpu, index_add_gpu);
test_device!(gather, gather_cpu, gather_gpu);
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
test_device!(randn, randn_cpu, randn_gpu);
test_device!(clamp, clamp_cpu, clamp_gpu);

// There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381
@ -1135,27 +931,3 @@ fn randn_hasneg() -> Result<()> {
    }
    Ok(())
}

#[test]
fn pad_with_same() -> Result<()> {
    let t = Tensor::arange(1f32, 5f32, &Device::Cpu)?.reshape((2, 2))?;
    let t0 = t.pad_with_same(0, 1, 2)?;
    assert_eq!(
        t0.to_vec2::<f32>()?,
        [[1.0, 2.0], [1.0, 2.0], [3.0, 4.0], [3.0, 4.0], [3.0, 4.0]]
    );
    let t1 = t.pad_with_same(1, 1, 2)?;
    assert_eq!(
        t1.to_vec2::<f32>()?,
        [[1.0, 1.0, 2.0, 2.0, 2.0], [3.0, 3.0, 4.0, 4.0, 4.0]]
    );
    Ok(())
}

#[test]
fn i64_abs() -> Result<()> {
    let t = Tensor::new(&[-42i64, 1337], &Device::Cpu)?;
    let t = t.abs()?;
    assert_eq!(t.to_vec1::<i64>()?, [42, 1337]);
    Ok(())
}
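The hunk above extends `test_device!` with a fourth identifier naming a Metal-backed test alongside the CPU and CUDA ones. A minimal sketch of how such a macro can expand, assuming feature-gated `Device::new_cuda` / `Device::new_metal` constructors; the real macro in candle may differ in details:

```rust
// Sketch: expand one shared test body into per-backend #[test] functions.
// `Result` and `Device` are assumed to be candle's, as in the test file above.
macro_rules! test_device {
    ($fn_name:ident, $test_cpu:ident, $test_cuda:ident, $test_metal:ident) => {
        #[test]
        fn $test_cpu() -> Result<()> {
            $fn_name(&Device::Cpu)
        }

        #[cfg(feature = "cuda")]
        #[test]
        fn $test_cuda() -> Result<()> {
            $fn_name(&Device::new_cuda(0)?)
        }

        #[cfg(feature = "metal")]
        #[test]
        fn $test_metal() -> Result<()> {
            $fn_name(&Device::new_metal(0)?)
        }
    };
}
```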
Binary file not shown.
Binary file not shown.
@ -11,8 +11,8 @@ readme = "README.md"

[dependencies]
byteorder = { workspace = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.2.3" }
hf-hub = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
memmap2 = { workspace = true }

@ -4,9 +4,7 @@
//! <https://www.cs.toronto.edu/~kriz/cifar.html>
//! The binary version of the dataset is used.
use crate::vision::Dataset;
use candle::{DType, Device, Error, Result, Tensor};
use hf_hub::{api::sync::Api, Repo, RepoType};
use parquet::file::reader::{FileReader, SerializedFileReader};
use candle::{DType, Device, Result, Tensor};
use std::fs::File;
use std::io::{BufReader, Read};

@ -62,58 +60,3 @@ pub fn load_dir<T: AsRef<std::path::Path>>(dir: T) -> Result<Dataset> {
        labels: 10,
    })
}

fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor, Tensor)> {
    let samples = parquet.metadata().file_metadata().num_rows() as usize;
    let mut buffer_images: Vec<u8> = Vec::with_capacity(samples * 1_024);
    let mut buffer_labels: Vec<u8> = Vec::with_capacity(samples);
    for row in parquet.into_iter().flatten() {
        for (_name, field) in row.get_column_iter() {
            if let parquet::record::Field::Group(subrow) = field {
                for (_name, field) in subrow.get_column_iter() {
                    if let parquet::record::Field::Bytes(value) = field {
                        let image = image::load_from_memory(value.data()).unwrap();
                        buffer_images.extend(image.to_rgb8().as_raw());
                    }
                }
            } else if let parquet::record::Field::Long(label) = field {
                buffer_labels.push(*label as u8);
            }
        }
    }
    let images = (Tensor::from_vec(buffer_images, (samples, 3, 32, 32), &Device::Cpu)?
        .to_dtype(DType::F32)?
        / 255.)?;
    let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
    Ok((images, labels))
}

pub fn load() -> Result<Dataset> {
    let api = Api::new().map_err(|e| Error::Msg(format!("Api error: {e}")))?;
    let dataset_id = "cifar10".to_string();
    let repo = Repo::with_revision(
        dataset_id,
        RepoType::Dataset,
        "refs/convert/parquet".to_string(),
    );
    let repo = api.repo(repo);
    let test_parquet_filename = repo
        .get("plain_text/test/0000.parquet")
        .map_err(|e| Error::Msg(format!("Api error: {e}")))?;
    let train_parquet_filename = repo
        .get("plain_text/train/0000.parquet")
        .map_err(|e| Error::Msg(format!("Api error: {e}")))?;
    let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)
        .map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
    let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)
        .map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
    let (test_images, test_labels) = load_parquet(test_parquet)?;
    let (train_images, train_labels) = load_parquet(train_parquet)?;
    Ok(crate::vision::Dataset {
        train_images,
        train_labels,
        test_images,
        test_labels,
        labels: 10,
    })
}
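With the parquet-backed `load` above, the dataset no longer needs a local copy of the binary files. A usage sketch; the field names follow the `Dataset` struct assembled above:

```rust
// Sketch: fetch CIFAR-10 from the hub and check the resulting tensor shapes.
fn inspect_cifar() -> candle::Result<()> {
    let dataset = candle_datasets::vision::cifar::load()?;
    println!("train images: {:?}", dataset.train_images.shape());
    println!("train labels: {:?}", dataset.train_labels.shape());
    assert_eq!(dataset.labels, 10);
    Ok(())
}
```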
@ -11,23 +11,20 @@ readme = "README.md"

[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
candle-onnx = { path = "../candle-onnx", version = "0.3.0", optional = true }
candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.2.3" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.3", optional = true }
candle-nn = { path = "../candle-nn", version = "0.2.3" }
candle-transformers = { path = "../candle-transformers", version = "0.2.3" }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
image = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
pyo3 = { version = "0.20.0", features = ["auto-initialize"], optional = true }
rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }

[dev-dependencies]
anyhow = { workspace = true }
@ -38,6 +35,7 @@ imageproc = { workspace = true }
memmap2 = { workspace = true }
rand = { workspace = true }
rusttype = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
tracing = { workspace = true }
tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
@ -53,24 +51,10 @@ default = []
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
cudnn = ["candle/cudnn"]
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
flash-attn = ["cuda", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
nccl = ["cuda", "cudarc/nccl", "dep:half"]
onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"]

[[example]]
name = "llama_multiprocess"
required-features = ["cuda", "nccl", "flash-attn"]

[[example]]
name = "reinforcement-learning"
required-features = ["pyo3"]

[[example]]
name = "onnx"
required-features = ["onnx"]

[[example]]
name = "onnx_basics"
required-features = ["onnx"]
@ -5,11 +5,11 @@ extern crate intel_mkl_src;
extern crate accelerate_src;
use candle_transformers::models::bert::{BertModel, Config, DTYPE};

use anyhow::{Error as E, Result};
use anyhow::{anyhow, Error as E, Result};
use candle::Tensor;
use candle_nn::VarBuilder;
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
use tokenizers::{PaddingParams, Tokenizer};

#[derive(Parser, Debug)]
@ -19,6 +19,10 @@ struct Args {
    #[arg(long)]
    cpu: bool,

    /// Run offline (you must have the files already cached)
    #[arg(long)]
    offline: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,
@ -34,10 +38,6 @@ struct Args {
    #[arg(long)]
    prompt: Option<String>,

    /// Use the pytorch weights rather than the safetensors ones
    #[arg(long)]
    use_pth: bool,

    /// The number of times to run the prompt.
    #[arg(long, default_value = "1")]
    n: usize,
@ -60,27 +60,35 @@ impl Args {
        };

        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
        let (config_filename, tokenizer_filename, weights_filename) = {
        let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
            let cache = Cache::default().repo(repo);
            (
                cache
                    .get("config.json")
                    .ok_or(anyhow!("Missing config file in cache"))?,
                cache
                    .get("tokenizer.json")
                    .ok_or(anyhow!("Missing tokenizer file in cache"))?,
                cache
                    .get("model.safetensors")
                    .ok_or(anyhow!("Missing weights file in cache"))?,
            )
        } else {
            let api = Api::new()?;
            let api = api.repo(repo);
            let config = api.get("config.json")?;
            let tokenizer = api.get("tokenizer.json")?;
            let weights = if self.use_pth {
                api.get("pytorch_model.bin")?
            } else {
                api.get("model.safetensors")?
            };
            (config, tokenizer, weights)
            (
                api.get("config.json")?,
                api.get("tokenizer.json")?,
                api.get("model.safetensors")?,
            )
        };
        let config = std::fs::read_to_string(config_filename)?;
        let config: Config = serde_json::from_str(&config)?;
        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

        let vb = if self.use_pth {
            VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
        } else {
            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
        };
        let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
        let weights = weights.deserialize()?;
        let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
        let model = BertModel::load(vb, &config)?;
        Ok((model, tokenizer))
    }
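The hunk above is representative of a change that recurs throughout this comparison: the two-step `MmapedFile::new` / `deserialize` loading is collapsed into a single `VarBuilder::from_mmaped_safetensors` call (with `VarBuilder::from_pth` covering the pytorch weights). The call is `unsafe` because it memory-maps the weight files, which is only sound while the files are not modified behind the mapping. A minimal sketch of the new pattern, assuming a local `model.safetensors`:

```rust
use candle::{DType, Device};
use candle_nn::VarBuilder;

fn load_weights() -> candle::Result<VarBuilder<'static>> {
    let device = Device::Cpu;
    let filenames = [std::path::PathBuf::from("model.safetensors")];
    // Safety: the mapped files must not be mutated while the VarBuilder is alive.
    unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device) }
}
```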
@ -138,9 +138,18 @@ fn main() -> Result<()> {
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let weights = filenames
        .iter()
        .map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? }))
        .collect::<Result<Vec<_>>>()?;
    let weights = weights
        .iter()
        .map(|f| Ok(f.deserialize()?))
        .collect::<Result<Vec<_>>>()?;

    let start = std::time::Instant::now();
    let device = candle_examples::device(args.cpu)?;
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
    let vb = VarBuilder::from_safetensors(weights, DType::F32, &device);
    let config = Config::starcoder_1b();
    let model = GPTBigCode::load(vb, config)?;
    println!("loaded the model in {:?}", start.elapsed());
@ -1,19 +0,0 @@
# candle-blip

The
[blip-image-captioning](https://huggingface.co/Salesforce/blip-image-captioning-base)
model can generate captions for an input image.

## Running on an example

```bash
cargo run --example blip --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
```

```
Running on CPU, to run on GPU, build this example with `--features cuda`
loaded image Tensor[dims 3, 384, 384; f32]
model built
several cyclists are riding down a road with cars behind them
```

@ -1,154 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Error as E;
use clap::Parser;

use candle::{DType, Device, Result, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::models::blip;
use candle_transformers::models::quantized_blip;

use tokenizers::Tokenizer;

enum Model {
    M(blip::BlipForConditionalGeneration),
    Q(quantized_blip::BlipForConditionalGeneration),
}

impl Model {
    fn text_decoder_forward(&mut self, xs: &Tensor, img_xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::M(m) => m.text_decoder().forward(xs, img_xs),
            Self::Q(m) => m.text_decoder().forward(xs, img_xs),
        }
    }
}

// TODO: Maybe add support for the conditional prompt.
#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    tokenizer: Option<String>,

    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Use the quantized version of the model.
    #[arg(long)]
    quantized: bool,
}

const SEP_TOKEN_ID: u32 = 102;

/// Loads an image from disk using the image crate, this returns a tensor with shape
/// (3, 384, 384). OpenAI normalization is applied.
pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
    let img = image::io::Reader::open(p)?
        .decode()
        .map_err(candle::Error::wrap)?
        .resize_to_fill(384, 384, image::imageops::FilterType::Triangle);
    let img = img.to_rgb8();
    let data = img.into_raw();
    let data = Tensor::from_vec(data, (384, 384, 3), &Device::Cpu)?.permute((2, 0, 1))?;
    let mean =
        Tensor::new(&[0.48145466f32, 0.4578275, 0.40821073], &Device::Cpu)?.reshape((3, 1, 1))?;
    let std = Tensor::new(&[0.26862954f32, 0.261_302_6, 0.275_777_1], &Device::Cpu)?
        .reshape((3, 1, 1))?;
    (data.to_dtype(candle::DType::F32)? / 255.)?
        .broadcast_sub(&mean)?
        .broadcast_div(&std)
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let model_file = match args.model {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            if args.quantized {
                let api = api.model("lmz/candle-blip".to_string());
                api.get("blip-image-captioning-large-q4k.gguf")?
            } else {
                let api = api.repo(hf_hub::Repo::with_revision(
                    "Salesforce/blip-image-captioning-large".to_string(),
                    hf_hub::RepoType::Model,
                    "refs/pr/18".to_string(),
                ));
                api.get("model.safetensors")?
            }
        }
        Some(model) => model.into(),
    };
    let tokenizer = match args.tokenizer {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model("Salesforce/blip-image-captioning-large".to_string());
            api.get("tokenizer.json")?
        }
        Some(file) => file.into(),
    };
    let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
    let mut tokenizer = TokenOutputStream::new(tokenizer);
    let mut logits_processor =
        candle_transformers::generation::LogitsProcessor::new(1337, None, None);

    let config = blip::Config::image_captioning_large();

    let (image_embeds, device, mut model) = if args.quantized {
        let device = Device::Cpu;
        let image = load_image(args.image)?.to_device(&device)?;
        println!("loaded image {image:?}");

        let vb = quantized_blip::VarBuilder::from_gguf(model_file)?;
        let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
        let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
        (image_embeds, device, Model::Q(model))
    } else {
        let device = candle_examples::device(args.cpu)?;
        let image = load_image(args.image)?.to_device(&device)?;
        println!("loaded image {image:?}");

        let vb =
            unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
        let model = blip::BlipForConditionalGeneration::new(&config, vb)?;
        let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
        (image_embeds, device, Model::M(model))
    };

    let mut token_ids = vec![30522u32];
    for index in 0..1000 {
        let context_size = if index > 0 { 1 } else { token_ids.len() };
        let start_pos = token_ids.len().saturating_sub(context_size);
        let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
        let logits = model.text_decoder_forward(&input_ids, &image_embeds)?;
        let logits = logits.squeeze(0)?;
        let logits = logits.get(logits.dim(0)? - 1)?;
        let token = logits_processor.sample(&logits)?;
        if token == SEP_TOKEN_ID {
            break;
        }
        token_ids.push(token);
        if let Some(t) = tokenizer.next_token(token)? {
            use std::io::Write;
            print!("{t}");
            std::io::stdout().flush()?;
        }
    }
    if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
        print!("{rest}");
    }
    println!();
    Ok(())
}
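The generation loop in the blip example above illustrates incremental decoding against a cached decoder: the full context (here the single start id 30522) is only fed on the first step, after which `context_size` drops to 1 and each call forwards just the newest token, with generation stopping at `SEP_TOKEN_ID`. A sketch of the pattern in isolation; `forward_step` is a hypothetical closure wrapping the model call plus sampling:

```rust
// Sketch: incremental decoding with a KV cache. After the first step the
// cache holds all earlier positions, so only the newest token is fed.
fn generate(
    bos: u32,
    sep: u32,
    max_len: usize,
    mut forward_step: impl FnMut(&[u32]) -> anyhow::Result<u32>,
) -> anyhow::Result<Vec<u32>> {
    let mut token_ids = vec![bos];
    for index in 0..max_len {
        let context_size = if index > 0 { 1 } else { token_ids.len() };
        let start_pos = token_ids.len().saturating_sub(context_size);
        let token = forward_step(&token_ids[start_pos..])?;
        if token == sep {
            break;
        }
        token_ids.push(token);
    }
    Ok(token_ids)
}
```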
@ -1,59 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::Parser;

use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::convmixer;

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model("lmz/candle-convmixer".into());
            api.get("convmixer_1024_20_ks9_p14.safetensors")?
        }
        Some(model) => model.into(),
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let model = convmixer::c1024_20(1000, vb)?;
    println!("model built");
    let logits = model.forward(&image.unsqueeze(0)?)?;
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;
    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    for &(category_idx, pr) in prs.iter().take(5) {
        println!(
            "{:24}: {:.2}%",
            candle_examples::imagenet::CLASSES[category_idx],
            100. * pr
        );
    }
    Ok(())
}
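The ConvMixer example above follows the same invocation pattern as the other vision classifiers in this comparison; a usage sketch, reusing an image asset referenced elsewhere in the repo:

```bash
cargo run --example convmixer --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
```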
@ -42,7 +42,9 @@ pub fn main() -> anyhow::Result<()> {
        }
        Some(model) => model.into(),
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
    let weights = weights.deserialize()?;
    let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
    let model = dinov2::vit_small(vb)?;
    println!("model built");
    let logits = model.forward(&image.unsqueeze(0)?)?;
@ -68,7 +68,9 @@ pub fn main() -> anyhow::Result<()> {
        }
        Some(model) => model.into(),
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
    let weights = weights.deserialize()?;
    let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
    let cfg = match args.which {
        Which::B0 => MBConvConfig::b0(),
        Which::B1 => MBConvConfig::b1(),
@ -177,12 +177,21 @@ fn main() -> Result<()> {
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let weights = filenames
        .iter()
        .map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? }))
        .collect::<Result<Vec<_>>>()?;
    let weights = weights
        .iter()
        .map(|f| Ok(f.deserialize()?))
        .collect::<Result<Vec<_>>>()?;

    let dtype = if args.use_f32 {
        DType::F32
    } else {
        DType::BF16
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
    let vb = VarBuilder::from_safetensors(weights, dtype, &device);
    let config = Config::falcon7b();
    config.validate()?;
    let model = Falcon::load(vb, config)?;
@ -1,45 +0,0 @@
# candle-jina-bert

Jina-Bert is a general large language model with a context size of 8192, [model
card](https://huggingface.co/jinaai/jina-embeddings-v2-base-en). In this example
it can be used for two different tasks:
- Compute sentence embeddings for a prompt.
- Compute similarities between a set of sentences.

## Sentence embeddings

Jina-Bert is used to compute the sentence embeddings for a prompt. The model weights
are downloaded from the hub on the first run.

```bash
cargo run --example jina-bert --release -- --prompt "Here is a test sentence"

> [[[ 0.1595, -0.9885, 0.6494, ..., 0.3003, -0.6901, -1.2355],
>  [ 0.0374, -0.1798, 1.3359, ..., 0.6731, 0.2133, -1.6807],
>  [ 0.1700, -0.8534, 0.8924, ..., -0.1785, -0.0727, -1.5087],
> ...
>  [-0.3113, -1.3665, 0.2027, ..., -0.2519, 0.1711, -1.5811],
>  [ 0.0907, -1.0492, 0.5382, ..., 0.0242, -0.7077, -1.0830],
>  [ 0.0369, -0.6343, 0.6105, ..., 0.0671, 0.3778, -1.1505]]]
> Tensor[[1, 7, 768], f32]
```

## Similarities

In this example, Jina-Bert is used to compute the sentence embeddings for a set of
sentences (hardcoded in the examples). Then cosine similarities are computed for
each sentence pair and they are reported by decreasing values, hence the first
reported pair contains the two sentences that have the highest similarity score.
The sentence embeddings are computed using average pooling through all the
sentence tokens, including some potential padding.

```bash
cargo run --example jina-bert --release

> score: 0.94 'The new movie is awesome' 'The new movie is so great'
> score: 0.81 'The cat sits outside' 'The cat plays in the garden'
> score: 0.78 'I love pasta' 'Do you like pizza?'
> score: 0.68 'I love pasta' 'The new movie is awesome'
> score: 0.67 'A man is playing guitar' 'A woman watches TV'
```
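One property the README relies on implicitly: with `normalize_embeddings` left at its default of `true` (see the example source below), each pooled embedding has unit L2 norm, so the cosine similarity scores above are plain dot products. A minimal sketch of that shortcut, assuming unit-norm inputs:

```rust
use candle::Tensor;

// Sketch: for L2-normalized embeddings, cosine similarity is a dot product.
fn cosine_of_unit_vectors(e_i: &Tensor, e_j: &Tensor) -> candle::Result<f32> {
    (e_i * e_j)?.sum_all()?.to_scalar::<f32>()
}
```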
@ -1,180 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle_transformers::models::jina_bert::{BertModel, Config};

use anyhow::Error as E;
use candle::{DType, Module, Tensor};
use candle_nn::VarBuilder;
use clap::Parser;

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// When set, compute embeddings for this prompt.
    #[arg(long)]
    prompt: Option<String>,

    /// The number of times to run the prompt.
    #[arg(long, default_value = "1")]
    n: usize,

    /// L2 normalization for embeddings.
    #[arg(long, default_value = "true")]
    normalize_embeddings: bool,

    #[arg(long)]
    tokenizer: Option<String>,

    #[arg(long)]
    model: Option<String>,
}

impl Args {
    fn build_model_and_tokenizer(&self) -> anyhow::Result<(BertModel, tokenizers::Tokenizer)> {
        use hf_hub::{api::sync::Api, Repo, RepoType};
        let model = match &self.model {
            Some(model_file) => std::path::PathBuf::from(model_file),
            None => Api::new()?
                .repo(Repo::new(
                    "jinaai/jina-embeddings-v2-base-en".to_string(),
                    RepoType::Model,
                ))
                .get("model.safetensors")?,
        };
        let tokenizer = match &self.tokenizer {
            Some(file) => std::path::PathBuf::from(file),
            None => Api::new()?
                .repo(Repo::new(
                    "sentence-transformers/all-MiniLM-L6-v2".to_string(),
                    RepoType::Model,
                ))
                .get("tokenizer.json")?,
        };
        let device = candle_examples::device(self.cpu)?;
        let config = Config::v2_base();
        let tokenizer = tokenizers::Tokenizer::from_file(tokenizer).map_err(E::msg)?;
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
        let model = BertModel::new(vb, &config)?;
        Ok((model, tokenizer))
    }
}

fn main() -> anyhow::Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        println!("tracing...");
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    let start = std::time::Instant::now();

    let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
    let device = &model.device;

    if let Some(prompt) = args.prompt {
        let tokenizer = tokenizer
            .with_padding(None)
            .with_truncation(None)
            .map_err(E::msg)?;
        let tokens = tokenizer
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
        println!("Loaded and encoded {:?}", start.elapsed());
        for idx in 0..args.n {
            let start = std::time::Instant::now();
            let ys = model.forward(&token_ids)?;
            if idx == 0 {
                println!("{ys}");
            }
            println!("Took {:?}", start.elapsed());
        }
    } else {
        let sentences = [
            "The cat sits outside",
            "A man is playing guitar",
            "I love pasta",
            "The new movie is awesome",
            "The cat plays in the garden",
            "A woman watches TV",
            "The new movie is so great",
            "Do you like pizza?",
        ];
        let n_sentences = sentences.len();
        if let Some(pp) = tokenizer.get_padding_mut() {
            pp.strategy = tokenizers::PaddingStrategy::BatchLongest
        } else {
            let pp = tokenizers::PaddingParams {
                strategy: tokenizers::PaddingStrategy::BatchLongest,
                ..Default::default()
            };
            tokenizer.with_padding(Some(pp));
        }
        let tokens = tokenizer
            .encode_batch(sentences.to_vec(), true)
            .map_err(E::msg)?;
        let token_ids = tokens
            .iter()
            .map(|tokens| {
                let tokens = tokens.get_ids().to_vec();
                Tensor::new(tokens.as_slice(), device)
            })
            .collect::<candle::Result<Vec<_>>>()?;

        let token_ids = Tensor::stack(&token_ids, 0)?;
        println!("running inference on batch {:?}", token_ids.shape());
        let embeddings = model.forward(&token_ids)?;
        println!("generated embeddings {:?}", embeddings.shape());
        // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
        let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
        let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
        let embeddings = if args.normalize_embeddings {
            normalize_l2(&embeddings)?
        } else {
            embeddings
        };
        println!("pooled embeddings {:?}", embeddings.shape());

        let mut similarities = vec![];
        for i in 0..n_sentences {
            let e_i = embeddings.get(i)?;
            for j in (i + 1)..n_sentences {
                let e_j = embeddings.get(j)?;
                let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
                let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
                let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
                let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
                similarities.push((cosine_similarity, i, j))
            }
        }
        similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
        for &(score, i, j) in similarities[..5].iter() {
            println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
        }
    }
    Ok(())
}

pub fn normalize_l2(v: &Tensor) -> candle::Result<Tensor> {
    v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
}
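`normalize_l2` above divides each row of a `(batch, dim)` tensor by its L2 norm along dimension 1. A small worked sketch against that function:

```rust
use candle::{Device, Tensor};

// Sketch: a row with norm 5 becomes (0.6, 0.8) after L2 normalization.
fn check_normalize_l2() -> candle::Result<()> {
    let v = Tensor::new(&[[3f32, 4.]], &Device::Cpu)?;
    let n = normalize_l2(&v)?;
    assert_eq!(n.to_vec2::<f32>()?, [[0.6, 0.8]]);
    Ok(())
}
```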
@ -172,9 +172,17 @@ fn main() -> Result<()> {
        }

        println!("building the model");
        let handles = filenames
            .iter()
            .map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f.as_path())? }))
            .collect::<Result<Vec<_>>>()?;
        let tensors: Vec<_> = handles
            .iter()
            .map(|h| Ok(h.deserialize()?))
            .collect::<Result<Vec<_>>>()?;
        let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;

        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
        let vb = VarBuilder::from_safetensors(tensors, dtype, &device);
        (Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
    }
};
@ -6,10 +6,9 @@ extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use candle_transformers::models::llama2_c as model;
use candle_transformers::models::llama2_c_weights as weights;
use candle_transformers::models::quantized_llama2_c as qmodel;
mod model;
mod training;
mod weights;
use clap::{Parser, Subcommand};

use anyhow::{Error as E, Result};
@ -20,7 +19,6 @@ use std::io::Write;
use tokenizers::Tokenizer;

use model::{Config, Llama};
use qmodel::QLlama;
use weights::TransformerWeights;

#[derive(Parser, Debug, Clone)]
@ -154,20 +152,6 @@ fn main() -> anyhow::Result<()> {
    Ok(())
}

enum Model {
    Llama(Llama),
    QLlama(QLlama),
}

impl Model {
    fn forward(&self, xs: &Tensor, pos: usize) -> anyhow::Result<Tensor> {
        match self {
            Self::Llama(l) => Ok(l.forward(xs, pos)?),
            Self::QLlama(l) => Ok(l.forward(xs, pos)?),
        }
    }
}

fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
    use std::io::BufRead;

@ -257,66 +241,24 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {

    let device = candle_examples::device(common_args.cpu)?;

    let is_gguf = config_path.extension().map_or(false, |v| v == "gguf");
    let is_safetensors = config_path
        .extension()
        .map_or(false, |v| v == "safetensors");
    let (model, config) = if is_gguf {
        let vb = qmodel::VarBuilder::from_gguf(config_path)?;
        let (_vocab_size, dim) = vb
            .get_no_shape("model.embed_tokens.weight")?
            .shape()
            .dims2()?;
        let config = match dim {
            64 => Config::tiny_260k(),
            288 => Config::tiny_15m(),
            512 => Config::tiny_42m(),
            768 => Config::tiny_110m(),
            _ => anyhow::bail!("no config for dim {dim}"),
        };
        let freq_cis_real = vb
            .get(
                (config.seq_len, config.head_size() / 2),
                "rot.freq_cis_real",
            )?
            .dequantize(&candle::Device::Cpu)?;
        let freq_cis_imag = vb
            .get(
                (config.seq_len, config.head_size() / 2),
                "rot.freq_cis_imag",
            )?
            .dequantize(&candle::Device::Cpu)?;

        let fake_vb = candle_nn::VarBuilder::from_tensors(
            [
                ("freq_cis_real".to_string(), freq_cis_real),
                ("freq_cis_imag".to_string(), freq_cis_imag),
            ]
            .into_iter()
            .collect(),
            candle::DType::F32,
            &candle::Device::Cpu,
        );
        let cache = model::Cache::new(true, &config, fake_vb)?;
        let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
        (model, config)
    } else if is_safetensors {
        let config = Config::tiny_15m();
    let (vb, config) = if is_safetensors {
        let config = Config::tiny();
        let tensors = candle::safetensors::load(config_path, &device)?;
        let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
        let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
        let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
        (model, config)
        (vb, config)
    } else {
        let mut file = std::fs::File::open(config_path)?;
        let config = Config::from_reader(&mut file)?;
        println!("{config:?}");
        let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
        let vb = weights.var_builder(&config, &device)?;
        let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
        let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
        (model, config)
        (vb, config)
    };
    let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
    let model = Llama::load(vb, &cache, config)?;

    println!("starting the inference loop");
    let mut logits_processor = LogitsProcessor::new(299792458, args.temperature, args.top_p);
@ -331,7 +273,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {

    let start_gen = std::time::Instant::now();
    for index in 0.. {
        if tokens.len() >= config.seq_len {
        if tokens.len() >= model.config.seq_len {
            break;
        }
        let context_size = if index > 0 { 1 } else { tokens.len() };
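In the gguf branch above, the configuration is inferred from the width of `model.embed_tokens.weight`, which uniquely identifies the published llama2.c checkpoints (64, 288, 512, 768 for the 260k/15m/42m/110m variants). The rotary tables it then loads have shape `(seq_len, head_size() / 2)`; a worked sketch using the fully visible `tiny_260k` config from `model.rs` below, under the assumption that `head_size()` is `dim / n_heads`:

```rust
// Sketch: shape of rot.freq_cis_real for the tiny_260k config.
fn rot_table_shape() -> (usize, usize) {
    // tiny_260k (see model.rs below): dim = 64, n_heads = 8, seq_len = 512.
    let (dim, n_heads, seq_len) = (64usize, 8usize, 512usize);
    let head_size = dim / n_heads; // 8, assuming head_size() = dim / n_heads
    (seq_len, head_size / 2) // (512, 4)
}
```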
@ -17,20 +17,7 @@ pub struct Config {
}

impl Config {
    pub fn tiny_260k() -> Self {
        Self {
            dim: 64,
            hidden_dim: 768,
            n_layers: 5,
            n_heads: 8,
            n_kv_heads: 4,
            vocab_size: 32000,
            seq_len: 512,
            norm_eps: 1e-5,
        }
    }

    pub fn tiny_15m() -> Self {
    pub fn tiny() -> Self {
        Self {
            dim: 288,
            hidden_dim: 768,
@ -42,32 +29,6 @@ impl Config {
            norm_eps: 1e-5,
        }
    }

    pub fn tiny_42m() -> Self {
        Self {
            dim: 512,
            hidden_dim: 768,
            n_layers: 8,
            n_heads: 8,
            n_kv_heads: 8,
            vocab_size: 32000,
            seq_len: 1024,
            norm_eps: 1e-5,
        }
    }

    pub fn tiny_110m() -> Self {
        Self {
            dim: 768,
            hidden_dim: 768,
            n_layers: 12,
            n_heads: 12,
            n_kv_heads: 12,
            vocab_size: 32000,
            seq_len: 1024,
            norm_eps: 1e-5,
        }
    }
}

#[derive(Clone)]
@ -75,9 +36,9 @@ pub struct Cache {
    masks: Arc<Mutex<HashMap<usize, Tensor>>>,
    pub use_kv_cache: bool,
    #[allow(clippy::type_complexity)]
    pub kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
    pub cos: Tensor,
    pub sin: Tensor,
    kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
    cos: Tensor,
    sin: Tensor,
    device: Device,
}

@ -114,7 +75,7 @@ impl Cache {
        })
    }

    pub fn mask(&self, t: usize) -> Result<Tensor> {
    fn mask(&self, t: usize) -> Result<Tensor> {
        let mut masks = self.masks.lock().unwrap();
        if let Some(mask) = masks.get(&t) {
            Ok(mask.clone())
@ -33,7 +33,7 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
    );
    let varmap = candle_nn::VarMap::new();
    let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
    let config = Config::tiny_15m();
    let config = Config::tiny();
    let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
    let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);

@ -1,8 +1,9 @@
use anyhow::Result;
use byteorder::{LittleEndian, ReadBytesExt};
use candle::{DType, Device, IndexOp, Result, Shape, Tensor};
use candle::{DType, Device, IndexOp, Shape, Tensor};
use candle_nn::VarBuilder;

use super::llama2_c::Config;
use crate::model::Config;

pub struct TransformerWeights {
    // token embedding table
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
@ -205,9 +201,16 @@ fn main() -> Result<()> {
|
||||
let cache = model::Cache::new(dtype, &config, &device)?;
|
||||
|
||||
println!("building the model");
|
||||
let vb = unsafe {
|
||||
candle_nn::var_builder::ShardedSafeTensors::var_builder(&filenames, dtype, &device)?
|
||||
};
|
||||
let handles = filenames
|
||||
.iter()
|
||||
.map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f.as_path())? }))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let tensors: Vec<_> = handles
|
||||
.iter()
|
||||
.map(|h| Ok(h.deserialize()?))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
let vb = candle_nn::var_builder::ShardedSafeTensors::var_builder(tensors, dtype, &device);
|
||||
let llama = Llama::load(vb, &cache, &config, comm)?;
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
@ -219,7 +222,7 @@ fn main() -> Result<()> {
|
||||
.to_vec();
|
||||
|
||||
println!("starting the inference loop");
|
||||
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
|
||||
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
|
||||
let mut new_tokens = vec![];
|
||||
let start_gen = std::time::Instant::now();
|
||||
let mut index_pos = 0;
|
||||
|
@ -1,38 +0,0 @@
|
||||
# candle-marian-mt
|
||||
|
||||
`marian-mt` is a neural machine translation model. In this example it is used to
|
||||
translate text from French to English. See the associated [model
|
||||
card](https://huggingface.co/Helsinki-NLP/opus-mt-tc-big-fr-en) for details on
|
||||
the model itself.
|
||||
|
||||
## Running an example
|
||||
|
||||
```bash
|
||||
cargo run --example marian-mt --release -- \
|
||||
--text "Demain, dès l'aube, à l'heure où blanchit la campagne, Je partirai. Vois-tu, je sais que tu m'attends. J'irai par la forêt, j'irai par la montagne. Je ne puis demeurer loin de toi plus longtemps."
|
||||
```
|
||||
|
||||
```
|
||||
<NIL> Tomorrow, at dawn, at the time when the country is whitening, I will go. See,
|
||||
I know you are waiting for me. I will go through the forest, I will go through the
|
||||
mountain. I cannot stay far from you any longer.</s>
|
||||
```
|
||||
|
||||
## Generating the tokenizer.json files
|
||||
|
||||
You can use the following script to generate the `tokenizer.json` config files
from the hf-hub repos. This requires the `tokenizers` and `sentencepiece`
packages to be installed and uses the `convert_slow_tokenizer.py` script from this
directory.

```python
from convert_slow_tokenizer import MarianConverter
from transformers import AutoTokenizer


tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-fr-en", use_fast=False)
fast_tokenizer = MarianConverter(tokenizer, index=0).converted()
fast_tokenizer.save(f"tokenizer-marian-base-fr.json")
fast_tokenizer = MarianConverter(tokenizer, index=1).converted()
fast_tokenizer.save(f"tokenizer-marian-base-en.json")
```
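The generated files can then presumably be passed to the example via the `--tokenizer` and `--tokenizer-dec` flags defined in the example's `main.rs` below, e.g.:

```bash
cargo run --example marian-mt --release -- \
    --which base \
    --tokenizer tokenizer-marian-base-fr.json \
    --tokenizer-dec tokenizer-marian-base-en.json \
    --text "Demain, dès l'aube, à l'heure où blanchit la campagne, Je partirai."
```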
File diff suppressed because it is too large
@ -1,152 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Error as E;
use clap::{Parser, ValueEnum};

use candle::{DType, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::models::marian;

use tokenizers::Tokenizer;

#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
    Base,
    Big,
}

// TODO: Maybe add support for the conditional prompt.
#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    tokenizer: Option<String>,

    #[arg(long)]
    tokenizer_dec: Option<String>,

    /// Choose the variant of the model to run.
    #[arg(long, default_value = "big")]
    which: Which,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Use the quantized version of the model.
    #[arg(long)]
    quantized: bool,

    /// Text to be translated
    #[arg(long)]
    text: String,
}

pub fn main() -> anyhow::Result<()> {
    use hf_hub::api::sync::Api;
    let args = Args::parse();

    let config = match args.which {
        Which::Base => marian::Config::opus_mt_fr_en(),
        Which::Big => marian::Config::opus_mt_tc_big_fr_en(),
    };
    let tokenizer = {
        let tokenizer = match args.tokenizer {
            Some(tokenizer) => std::path::PathBuf::from(tokenizer),
            None => {
                let name = match args.which {
                    Which::Base => "tokenizer-marian-base-fr.json",
                    Which::Big => "tokenizer-marian-fr.json",
                };
                Api::new()?
                    .model("lmz/candle-marian".to_string())
                    .get(name)?
            }
        };
        Tokenizer::from_file(&tokenizer).map_err(E::msg)?
    };

    let tokenizer_dec = {
        let tokenizer = match args.tokenizer_dec {
            Some(tokenizer) => std::path::PathBuf::from(tokenizer),
            None => {
                let name = match args.which {
                    Which::Base => "tokenizer-marian-base-en.json",
                    Which::Big => "tokenizer-marian-en.json",
                };
                Api::new()?
                    .model("lmz/candle-marian".to_string())
                    .get(name)?
            }
        };
        Tokenizer::from_file(&tokenizer).map_err(E::msg)?
    };
    let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);

    let device = candle_examples::device(args.cpu)?;
    let vb = {
        let model = match args.model {
            Some(model) => std::path::PathBuf::from(model),
            None => match args.which {
                Which::Base => Api::new()?
                    .repo(hf_hub::Repo::with_revision(
                        "Helsinki-NLP/opus-mt-fr-en".to_string(),
                        hf_hub::RepoType::Model,
                        "refs/pr/4".to_string(),
                    ))
                    .get("model.safetensors")?,
                Which::Big => Api::new()?
                    .model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
                    .get("model.safetensors")?,
            },
        };
        unsafe { VarBuilder::from_mmaped_safetensors(&[&model], DType::F32, &device)? }
    };
    let mut model = marian::MTModel::new(&config, vb)?;

    let mut logits_processor =
        candle_transformers::generation::LogitsProcessor::new(1337, None, None);

    let encoder_xs = {
        let mut tokens = tokenizer
            .encode(args.text, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        tokens.push(config.eos_token_id);
        let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
        model.encoder().forward(&tokens, 0)?
    };

    let mut token_ids = vec![config.decoder_start_token_id];
    for index in 0..1000 {
        let context_size = if index >= 1 { 1 } else { token_ids.len() };
        let start_pos = token_ids.len().saturating_sub(context_size);
        let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
        let logits = model.decode(&input_ids, &encoder_xs, start_pos)?;
        let logits = logits.squeeze(0)?;
        let logits = logits.get(logits.dim(0)? - 1)?;
        let token = logits_processor.sample(&logits)?;
        token_ids.push(token);
        if let Some(t) = tokenizer_dec.next_token(token)? {
            use std::io::Write;
            print!("{t}");
            std::io::stdout().flush()?;
        }
        if token == config.eos_token_id || token == config.forced_eos_token_id {
            break;
        }
    }
    if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
        print!("{rest}");
    }
    println!();
    Ok(())
}
@ -1,90 +0,0 @@
# candle-mistral: 7b LLM with Apache 2.0 licensed weights

Mistral-7B-v0.1 is a pretrained generative LLM with 7 billion parameters. It outperforms all the publicly available 13b models
as of 2023-09-28. Weights (and the original Python model code) are released under the permissive Apache 2.0 license.

- [Blog post](https://mistral.ai/news/announcing-mistral-7b/) from Mistral announcing the model release.
- [Model card](https://huggingface.co/mistralai/Mistral-7B-v0.1) on the
  HuggingFace Hub.

This example supports the initial model as well as a quantized variant.

## Running the example

```bash
$ cargo run --example mistral --release --features cuda -- --prompt 'Write helloworld code in Rust' --sample-len 150

Generated text:
Write helloworld code in Rust
=============================

This is a simple example of how to write "Hello, world!" program in Rust.

## Compile and run

``bash
$ cargo build --release
   Compiling hello-world v0.1.0 (/home/user/rust/hello-world)
    Finished release [optimized] target(s) in 0.26s
$ ./target/release/hello-world
Hello, world!
``

## Source code

``rust
fn main() {
    println!("Hello, world!");
}
``

## License

This example is released under the terms
```

## Running the quantized version of the model

```bash
$ cargo run --example mistral --features accelerate --release -- \
$ --prompt "Here is a sample quick sort implementation in rust " --quantized -n 400
avx: false, neon: true, simd128: false, f16c: false
temp: 0.00 repeat-penalty: 1.10 repeat-last-n: 64
retrieved the files in 562.292µs
loaded the model in 1.100323667s
Here is a sample quick sort implementation in rust

``rust
fn quick_sort(arr: &mut [i32]) {
    if arr.len() <= 1 {
        return;
    }

    let pivot = arr[0];
    let mut left = vec![];
    let mut right = vec![];

    for i in 1..arr.len() {
        if arr[i] < pivot {
            left.push(arr[i]);
        } else {
            right.push(arr[i]);
        }
    }

    quick_sort(&mut left);
    quick_sort(&mut right);

    let mut i = 0;
    for _ in &left {
        arr[i] = left.pop().unwrap();
        i += 1;
    }

    for _ in &right {
        arr[i] = right.pop().unwrap();
        i += 1;
    }
}
``
226 tokens generated (10.91 token/s)
```
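On the `temp: 0.00 repeat-penalty: 1.10 repeat-last-n: 64` line in the log above: before sampling, the example penalizes tokens that already appear in the last `repeat-last-n` positions. A sketch of the usual sign-aware transform (as in the CTRL paper; the exact behavior of `apply_repeat_penalty` in the `main.rs` below may differ slightly):

```rust
use std::collections::HashSet;

// Sketch: push the logits of recently seen tokens toward lower probability.
fn penalize_repeats(logits: &mut [f32], recent: &[u32], penalty: f32) {
    let recent: HashSet<u32> = recent.iter().copied().collect();
    for (token_id, logit) in logits.iter_mut().enumerate() {
        if recent.contains(&(token_id as u32)) {
            *logit = if *logit >= 0.0 {
                *logit / penalty
            } else {
                *logit * penalty
            };
        }
    }
}
```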
@ -1,271 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::Parser;

use candle_transformers::models::mistral::{Config, Model as Mistral};
use candle_transformers::models::quantized_mistral::Model as QMistral;

use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

enum Model {
    Mistral(Mistral),
    Quantized(QMistral),
}

struct TextGeneration {
    model: Model,
    device: Device,
    tokenizer: TokenOutputStream,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            tokenizer: TokenOutputStream::new(tokenizer),
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            device: device.clone(),
        }
    }

    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        self.tokenizer.clear();
        let mut tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        for &t in tokens.iter() {
            if let Some(t) = self.tokenizer.next_token(t)? {
                print!("{t}")
            }
        }
        std::io::stdout().flush()?;

        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_token("</s>") {
            Some(token) => token,
            None => anyhow::bail!("cannot find the </s> token"),
        };
        let start_gen = std::time::Instant::now();
        for index in 0..sample_len {
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let start_pos = tokens.len().saturating_sub(context_size);
            let ctxt = &tokens[start_pos..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = match &mut self.model {
                Model::Mistral(m) => m.forward(&input, start_pos)?,
                Model::Quantized(m) => m.forward(&input, start_pos)?,
            };
            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token {
                break;
            }
            if let Some(t) = self.tokenizer.next_token(next_token)? {
                print!("{t}");
                std::io::stdout().flush()?;
            }
        }
        let dt = start_gen.elapsed();
        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
            print!("{rest}");
        }
        std::io::stdout().flush()?;
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        Ok(())
    }
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    use_flash_attn: bool,

    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 100)]
    sample_len: usize,

    #[arg(long, default_value = "lmz/candle-mistral")]
    model_id: String,

    #[arg(long, default_value = "main")]
    revision: String,

    #[arg(long)]
    tokenizer_file: Option<String>,

    #[arg(long)]
    weight_files: Option<String>,

    #[arg(long)]
    quantized: bool,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let repo = api.repo(Repo::with_revision(
        args.model_id,
        RepoType::Model,
        args.revision,
    ));
    let tokenizer_filename = match args.tokenizer_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("tokenizer.json")?,
    };
    let filenames = match args.weight_files {
        Some(files) => files
            .split(',')
            .map(std::path::PathBuf::from)
            .collect::<Vec<_>>(),
        None => {
            if args.quantized {
                vec![repo.get("model-q4k.gguf")?]
            } else {
                vec![
                    repo.get("pytorch_model-00001-of-00002.safetensors")?,
                    repo.get("pytorch_model-00002-of-00002.safetensors")?,
                ]
            }
        }
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let config = Config::config_7b_v0_1(args.use_flash_attn);
    let (model, device) = if args.quantized {
        let filename = &filenames[0];
        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
        let model = QMistral::new(&config, vb)?;
        (Model::Quantized(model), Device::Cpu)
    } else {
        let device = candle_examples::device(args.cpu)?;
        let dtype = if device.is_cuda() {
            DType::BF16
        } else {
            DType::F32
        };
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
        let model = Mistral::new(&config, vb)?;
        (Model::Mistral(model), device)
    };

    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        &device,
    );
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}
@@ -9,7 +9,7 @@ use clap::{Parser, ValueEnum};
use rand::prelude::*;

use candle::{DType, Result, Tensor, D};
use candle_nn::{loss, ops, Conv2d, Linear, Module, ModuleT, Optimizer, VarBuilder, VarMap};
use candle_nn::{loss, ops, Conv2d, Linear, Module, Optimizer, VarBuilder, VarMap};

const IMAGE_DIM: usize = 784;
const LABELS: usize = 10;
@@ -95,7 +95,7 @@ impl ConvNet {
            .flatten_from(1)?
            .apply(&self.fc1)?
            .relu()?;
        self.dropout.forward_t(&xs, train)?.apply(&self.fc2)
        self.dropout.forward(&xs, train)?.apply(&self.fc2)
    }
}

@@ -1,6 +1,6 @@
use crate::nn::conv1d_weight_norm;
use candle::{DType, IndexOp, Module, Result, Tensor};
use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder};
use candle::{DType, IndexOp, Result, Tensor};
use candle_nn::{conv1d, Conv1d, Conv1dConfig, Module, VarBuilder};

// Encodec Model
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py
@@ -8,7 +8,6 @@ use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder};
#[derive(Debug, Clone, PartialEq)]
enum NormType {
    WeightNorm,
    TimeGroupNorm,
    None,
}

@@ -200,34 +199,25 @@ impl EncodecResidualVectorQuantizer {
// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L226
#[derive(Debug)]
struct EncodecLSTM {
    layers: Vec<candle_nn::LSTM>,
    layers: Vec<(Tensor, Tensor, Tensor, Tensor)>,
}

impl EncodecLSTM {
    fn load(dim: usize, vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let vb = &vb.pp("lstm");
        let mut layers = vec![];
        for layer_idx in 0..cfg.num_lstm_layers {
            let config = candle_nn::LSTMConfig {
                layer_idx,
                ..Default::default()
            };
            let lstm = candle_nn::lstm(dim, dim, config, vb.clone())?;
            layers.push(lstm)
        for i in 0..cfg.num_lstm_layers {
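            // PyTorch LSTMs stack the four gate blocks (input, forget, cell, output)
            // along dim 0, hence the 4 * dim leading dimension of the weights and biases.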
            let w_hh = vb.get((4 * dim, dim), &format!("weight_hh_l{i}"))?;
            let w_ih = vb.get((4 * dim, dim), &format!("weight_ih_l{i}"))?;
            let b_hh = vb.get(4 * dim, &format!("bias_hh_l{i}"))?;
            let b_ih = vb.get(4 * dim, &format!("bias_ih_l{i}"))?;
            layers.push((w_hh, w_ih, b_hh, b_ih))
        }
        Ok(Self { layers })
    }
}

impl Module for EncodecLSTM {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        use candle_nn::RNN;
        let mut xs = xs.clone();
        for layer in self.layers.iter() {
            let states = layer.seq(&xs)?;
            xs = layer.states_to_tensor(&states)?;
        }
        Ok(xs)
    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
        todo!()
    }
}

@@ -257,9 +247,7 @@ impl EncodecConvTranspose1d {
            bias,
        })
    }
}

impl Module for EncodecConvTranspose1d {
    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
        todo!()
    }
@@ -269,7 +257,6 @@ impl Module for EncodecConvTranspose1d {
struct EncodecConv1d {
    causal: bool,
    conv: Conv1d,
    norm: Option<candle_nn::GroupNorm>,
}

impl EncodecConv1d {
@@ -294,7 +281,7 @@ impl EncodecConv1d {
                },
                vb.pp("conv"),
            )?,
            NormType::None | NormType::TimeGroupNorm => conv1d(
            NormType::None => conv1d(
                in_c,
                out_c,
                kernel_size,
@@ -307,29 +294,17 @@ impl EncodecConv1d {
                vb.pp("conv"),
            )?,
        };
        let norm = match cfg.norm_type {
            NormType::None | NormType::WeightNorm => None,
            NormType::TimeGroupNorm => {
                let gn = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
                Some(gn)
            }
        };
        Ok(Self {
            causal: cfg.use_causal_conv,
            conv,
            norm,
        })
    }
}

impl Module for EncodecConv1d {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // TODO: padding, depending on causal.
        let xs = self.conv.forward(xs)?;
        match &self.norm {
            None => Ok(xs),
            Some(norm) => xs.apply(norm),
        }
        // If we add support for NormType "time_group_norm", we should add some normalization here.
        Ok(xs)
    }
}

@@ -365,9 +340,7 @@ impl EncodecResnetBlock {
            shortcut,
        })
    }
}

impl Module for EncodecResnetBlock {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let residual = xs.clone();
        let xs = xs.elu(1.)?;
@@ -466,17 +439,8 @@ impl EncodecEncoder {
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut xs = xs.apply(&self.init_conv)?;
        for (resnets, conv) in self.sampling_layers.iter() {
            for resnet in resnets.iter() {
                xs = xs.apply(resnet)?;
            }
            xs = xs.elu(1.0)?.apply(conv)?;
        }
        xs.apply(&self.final_lstm)?
            .elu(1.0)?
            .apply(&self.final_conv)
    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
        todo!()
    }
}

@@ -543,15 +507,8 @@ impl EncodecDecoder {
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut xs = xs.apply(&self.init_conv)?.apply(&self.init_lstm)?;
        for (conv, resnets) in self.sampling_layers.iter() {
            xs = xs.elu(1.)?.apply(conv)?;
            for resnet in resnets.iter() {
                xs = xs.apply(resnet)?
            }
        }
        xs.elu(1.)?.apply(&self.final_conv)
    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
        todo!()
    }
}

@@ -73,7 +73,9 @@ fn main() -> Result<()> {
        ))
        .get("model.safetensors")?,
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DTYPE, &device)? };
    let model = unsafe { candle::safetensors::MmapedFile::new(model)? };
    let model = model.deserialize()?;
    let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device);
    let config = GenConfig::small();
    let mut model = MusicgenForConditionalGeneration::load(vb, config)?;

@@ -40,7 +40,7 @@ impl Default for Config {
            num_attention_heads: 16,
            layerdrop: 0.0,
            use_cache: true,
            activation_function: Activation::Gelu,
            activation_function: Activation::Gelu, // TODO: Handle old style gelu.
            hidden_size: 1024,
            dropout: 0.1,
            attention_dropout: 0.0,
@@ -66,7 +66,7 @@ impl Config {
            num_attention_heads: 16,
            layerdrop: 0.0,
            use_cache: true,
            activation_function: Activation::Gelu,
            activation_function: Activation::Gelu, // TODO: Handle old style gelu.
            hidden_size: 1024,
            dropout: 0.1,
            attention_dropout: 0.0,

@@ -1,10 +0,0 @@
## Using ONNX models in Candle

This example demonstrates how to run ONNX-based models in Candle; the model
used here is a small SqueezeNet variant.

You can run the example with the following command:

```bash
cargo run --example squeezenet-onnx --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
```

@@ -1,78 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle::{IndexOp, D};
use clap::{Parser, ValueEnum};

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
    SqueezeNet,
    EfficientNet,
}

#[derive(Parser)]
struct Args {
    #[arg(long)]
    image: String,

    #[arg(long)]
    model: Option<String>,

    /// The model to be used.
    #[arg(value_enum, long, default_value_t = Which::SqueezeNet)]
    which: Which,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    let image = candle_examples::imagenet::load_image224(args.image)?;
    let image = match args.which {
        Which::SqueezeNet => image,
        Which::EfficientNet => image.permute((1, 2, 0))?,
    };

    println!("loaded image {image:?}");

    let model = match args.model {
        Some(model) => std::path::PathBuf::from(model),
        None => match args.which {
            Which::SqueezeNet => hf_hub::api::sync::Api::new()?
                .model("lmz/candle-onnx".into())
                .get("squeezenet1.1-7.onnx")?,
            Which::EfficientNet => hf_hub::api::sync::Api::new()?
                .model("onnx/EfficientNet-Lite4".into())
                .get("efficientnet-lite4-11.onnx")?,
        },
    };

    let model = candle_onnx::read_file(model)?;
    let graph = model.graph.as_ref().unwrap();
    let mut inputs = std::collections::HashMap::new();
    inputs.insert(graph.input[0].name.to_string(), image.unsqueeze(0)?);
    let mut outputs = candle_onnx::simple_eval(&model, inputs)?;
    let output = outputs.remove(&graph.output[0].name).unwrap();
    let prs = match args.which {
        Which::SqueezeNet => candle_nn::ops::softmax(&output, D::Minus1)?,
        Which::EfficientNet => output,
    };
    let prs = prs.i(0)?.to_vec1::<f32>()?;

    // Sort the predictions and take the top 5
    let mut top: Vec<_> = prs.iter().enumerate().collect();
    top.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
    let top = top.into_iter().take(5).collect::<Vec<_>>();

    // Print the top predictions
    for &(i, p) in &top {
        println!(
            "{:50}: {:.2}%",
            candle_examples::imagenet::CLASSES[i],
            p * 100.0
        );
    }

    Ok(())
}

@@ -1,87 +0,0 @@
use anyhow::Result;
use candle::{Device, Tensor};

use clap::{Parser, Subcommand};

#[derive(Subcommand, Debug, Clone)]
enum Command {
    Print {
        #[arg(long)]
        file: String,
    },
    SimpleEval {
        #[arg(long)]
        file: String,
    },
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
    #[command(subcommand)]
    command: Command,
}

pub fn main() -> Result<()> {
    let args = Args::parse();
    match args.command {
        Command::Print { file } => {
            let model = candle_onnx::read_file(file)?;
            println!("{model:?}");
            let graph = model.graph.unwrap();
            for node in graph.node.iter() {
                println!("{node:?}");
            }
        }
        Command::SimpleEval { file } => {
            let model = candle_onnx::read_file(file)?;
            let graph = model.graph.as_ref().unwrap();
            let constants: std::collections::HashSet<_> =
                graph.initializer.iter().map(|i| i.name.as_str()).collect();
            let mut inputs = std::collections::HashMap::new();
            for input in graph.input.iter() {
                use candle_onnx::onnx::tensor_proto::DataType;
                if constants.contains(input.name.as_str()) {
                    continue;
                }

                let type_ = input.r#type.as_ref().expect("no type for input");
                let type_ = type_.value.as_ref().expect("no type.value for input");
                let value = match type_ {
                    candle_onnx::onnx::type_proto::Value::TensorType(tt) => {
                        let dt = match DataType::try_from(tt.elem_type) {
                            Ok(dt) => match candle_onnx::dtype(dt) {
                                Some(dt) => dt,
                                None => {
                                    anyhow::bail!(
                                        "unsupported 'value' data-type {dt:?} for {}",
                                        input.name
                                    )
                                }
                            },
                            type_ => anyhow::bail!("unsupported input type {type_:?}"),
                        };
                        let shape = tt.shape.as_ref().expect("no tensortype.shape for input");
                        let dims = shape
                            .dim
                            .iter()
                            .map(|dim| match dim.value.as_ref().expect("no dim value") {
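                                // Named (symbolic) dimensions carry no concrete size in the
                                // graph, so an arbitrary placeholder value of 42 is used.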
                                candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize),
                                candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimParam(_) => Ok(42),
                            })
                            .collect::<Result<Vec<usize>>>()?;
                        Tensor::zeros(dims, dt, &Device::Cpu)?
                    }
                    type_ => anyhow::bail!("unsupported input type {type_:?}"),
                };
                println!("input {}: {value:?}", input.name);
                inputs.insert(input.name.clone(), value);
            }
            let outputs = candle_onnx::simple_eval(&model, inputs)?;
            for (name, value) in outputs.iter() {
                println!("output {name}: {value:?}")
            }
        }
    }
    Ok(())
}

@@ -1,56 +0,0 @@
# candle-phi: 1.3b LLM with state-of-the-art performance for <10b models.

[Phi-1.5](https://huggingface.co/microsoft/phi-1_5) is a language model using
only 1.3 billion parameters but with state-of-the-art performance compared to
models with up to 10 billion parameters.

The candle implementation provides both the standard version as well as a
quantized variant.

## Running some examples

```bash
$ cargo run --example phi --release -- --prompt "def print_prime(n): "

def print_prime(n):
    print("Printing prime numbers")
    for i in range(2, n+1):
        if is_prime(i):
            print(i)

def is_prime(n):
    if n <= 1:
        return False
    for i in range(2, int(math.sqrt(n))+1):
        if n % i == 0:
            return False
    return True

$ cargo run --example phi --release -- \
  --prompt "Explain how to find the median of an array and write the corresponding python function.\nAnswer:" \
  --quantized --sample-len 200

Explain how to find the median of an array and write the corresponding python function.
Answer: The median is the middle value in an array. If the array has an even number of elements, the median is the average of the two middle values.

def median(arr):
    arr.sort()
    n = len(arr)
    if n % 2 == 0:
        return (arr[n//2 - 1] + arr[n//2]) / 2
    else:
        return arr[n//2]
```

This also supports the [Puffin Phi v2
model](https://huggingface.co/teknium/Puffin-Phi-v2) for human interaction.
```
$ cargo run --example phi --release -- \
  --prompt "USER: What would you do on a sunny day in Paris?\nASSISTANT:" \
  --sample-len 200 --model puffin-phi-v2 --quantized
USER: What would you do on a sunny day in Paris?
ASSISTANT: On a sunny day in Paris, you could visit the Musée du Louvre to admire the famous
painting "Mona Lisa" by Leonardo da Vinci. You might also want to stroll along the Champs-Élysées
and enjoy the beautiful architecture of the buildings around you. Don't forget to stop by a café
for a cup of coffee and to soak up the sun!"
```

@@ -1,313 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};

use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;

use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

enum Model {
    MixFormer(MixFormer),
    Quantized(QMixFormer),
}

struct TextGeneration {
    model: Model,
    device: Device,
    tokenizer: Tokenizer,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
    verbose_prompt: bool,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        verbose_prompt: bool,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            tokenizer,
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            verbose_prompt,
            device: device.clone(),
        }
    }

    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        println!("starting the inference loop");
        let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
        if tokens.is_empty() {
            anyhow::bail!("Empty prompts are not supported in the phi model.")
        }
        if self.verbose_prompt {
            for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
                let token = token.replace('▁', " ").replace("<0x0A>", "\n");
                println!("{id:7} -> '{token}'");
            }
        }
        let mut tokens = tokens.get_ids().to_vec();
        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
            Some(token) => *token,
            None => anyhow::bail!("cannot find the endoftext token"),
        };
        print!("{prompt}");
        std::io::stdout().flush()?;
        let start_gen = std::time::Instant::now();
        for index in 0..sample_len {
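            // Once the prompt has been processed, the model caches its state and only
            // the most recent token needs to be fed back in at each step.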
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = match &mut self.model {
                Model::MixFormer(m) => m.forward(&input)?,
                Model::Quantized(m) => m.forward(&input)?,
            };
            let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
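            // Penalize the logits of tokens seen in the last `repeat_last_n` positions
            // before sampling; a penalty of 1 leaves the logits untouched.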
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token {
                break;
            }
            let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
            print!("{token}");
            std::io::stdout().flush()?;
        }
        let dt = start_gen.elapsed();
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        Ok(())
    }
}

#[derive(Clone, Copy, Debug, ValueEnum)]
enum WhichModel {
    #[value(name = "1")]
    V1,
    #[value(name = "1.5")]
    V1_5,
    PuffinPhiV2,
    PhiHermes,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// Display the token for the specified prompt.
    #[arg(long)]
    verbose_prompt: bool,

    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 100)]
    sample_len: usize,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long, default_value = "1.5")]
    model: WhichModel,

    #[arg(long)]
    revision: Option<String>,

    #[arg(long)]
    weight_file: Option<String>,

    #[arg(long)]
    tokenizer: Option<String>,

    #[arg(long)]
    quantized: bool,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let model_id = match args.model_id {
        Some(model_id) => model_id.to_string(),
        None => {
            if args.quantized {
                "lmz/candle-quantized-phi".to_string()
            } else {
                match args.model {
                    WhichModel::V1 => "microsoft/phi-1".to_string(),
                    WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
                    WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
                        "lmz/candle-quantized-phi".to_string()
                    }
                }
            }
        }
    };
    let revision = match args.revision {
        Some(rev) => rev.to_string(),
        None => {
            if args.quantized {
                "main".to_string()
            } else {
                match args.model {
                    WhichModel::V1 => "refs/pr/2".to_string(),
                    WhichModel::V1_5 => "refs/pr/18".to_string(),
                    WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => "main".to_string(),
                }
            }
        }
    };
    let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
    let tokenizer_filename = match args.tokenizer {
        Some(file) => std::path::PathBuf::from(file),
        None => match args.model {
            WhichModel::V1 | WhichModel::V1_5 => repo.get("tokenizer.json")?,
            WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
                repo.get("tokenizer-puffin-phi-v2.json")?
            }
        },
    };
    let filename = match args.weight_file {
        Some(weight_file) => std::path::PathBuf::from(weight_file),
        None => {
            if args.quantized {
                match args.model {
                    WhichModel::V1 => repo.get("model-v1-q4k.gguf")?,
                    WhichModel::V1_5 => repo.get("model-q4k.gguf")?,
                    WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2-q4k.gguf")?,
                    WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B-q4k.gguf")?,
                }
            } else {
                match args.model {
                    WhichModel::V1 | WhichModel::V1_5 => repo.get("model.safetensors")?,
                    WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2.safetensors")?,
                    WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B.safetensors")?,
                }
            }
        }
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let config = match args.model {
        WhichModel::V1 => Config::v1(),
        WhichModel::V1_5 => Config::v1_5(),
        WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
        WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
    };
    let (model, device) = if args.quantized {
        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
        let model = QMixFormer::new(&config, vb)?;
        (Model::Quantized(model), Device::Cpu)
    } else {
        let device = candle_examples::device(args.cpu)?;
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
        let model = MixFormer::new(&config, vb)?;
        (Model::MixFormer(model), device)
    };
    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        args.verbose_prompt,
        &device,
    );
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}

@@ -1,63 +0,0 @@
# candle-quantized-t5

## Seq2Seq example

This example uses a quantized version of the T5 model.

```bash
$ cargo run --example quantized-t5 --release -- --prompt "translate to German: A beautiful candle."
...
Eine schöne Kerze.
```

## Generating quantized weight files

The weight file is automatically retrieved from the hub. It is also possible to
generate quantized weight files from the original safetensors file by using the
`tensor-tools` command line utility via:

```bash
$ cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
```

## Using custom models

To use a different model, specify the `model-id`.

For example, for text editing, you can use quantized [CoEdit models](https://huggingface.co/jbochi/candle-coedit-quantized).

```bash
$ cargo run --example quantized-t5 --release -- \
  --model-id "jbochi/candle-coedit-quantized" \
  --prompt "Make this text coherent: Their flight is weak. They run quickly through the tree canopy." \
  --temperature 0
...
Although their flight is weak, they run quickly through the tree canopy.
```

By default, it will look for `model.gguf` and `config.json`, but you can specify
custom local or remote `weight-file` and `config-file`s:

```bash
cargo run --example quantized-t5 --release -- \
  --model-id "jbochi/candle-coedit-quantized" \
  --weight-file "model-xl.gguf" \
  --config-file "config-xl.json" \
  --prompt "Rewrite to make this easier to understand: Note that a storm surge is what forecasters consider a hurricane's most treacherous aspect." \
  --temperature 0
...
Note that a storm surge is what forecasters consider a hurricane's most dangerous part.
```

### [MADLAD-400](https://arxiv.org/abs/2309.04662)

MADLAD-400 is a series of multilingual machine translation T5 models trained on 250 billion tokens covering over 450 languages using publicly available data. These models are competitive with significantly larger models.

```bash
cargo run --example quantized-t5 --release -- \
  --model-id "jbochi/madlad400-3b-mt" --weight-file "model-q4k.gguf" \
  --prompt "<2de> How are you, my friend?" \
  --temperature 0
...
Wie geht es dir, mein Freund?
```

@@ -1,232 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use std::io::Write;
use std::path::PathBuf;

use candle_transformers::models::quantized_t5 as t5;

use anyhow::{Error as E, Result};
use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, api::sync::ApiRepo, Repo, RepoType};
use tokenizers::Tokenizer;

#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
    T5Small,
    FlanT5Small,
    FlanT5Base,
    FlanT5Large,
    FlanT5Xl,
    FlanT5Xxl,
}

#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// The model repository to use on the HuggingFace hub.
    #[arg(long)]
    model_id: Option<String>,

    #[arg(long)]
    revision: Option<String>,

    #[arg(long)]
    weight_file: Option<String>,

    #[arg(long)]
    config_file: Option<String>,

    // Disable the key-value cache used to speed up decoding.
    #[arg(long, default_value = "false")]
    disable_cache: bool,

    /// Use this prompt, otherwise compute sentence similarities.
    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long, default_value_t = 0.8)]
    temperature: f64,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,

    /// The model size to use.
    #[arg(long, default_value = "t5-small")]
    which: Which,
}

struct T5ModelBuilder {
    device: Device,
    config: t5::Config,
    weights_filename: PathBuf,
}

impl T5ModelBuilder {
    pub fn load(args: &Args) -> Result<(Self, Tokenizer)> {
        let device = Device::Cpu;
        let default_model = "lmz/candle-quantized-t5".to_string();
        let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
            (Some(model_id), Some(revision)) => (model_id, revision),
            (Some(model_id), None) => (model_id, "main".to_string()),
            (None, Some(revision)) => (default_model, revision),
            (None, None) => (default_model, "main".to_string()),
        };

        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
        let api = Api::new()?;
        let api = api.repo(repo);
        let config_filename = match &args.config_file {
            Some(filename) => Self::get_local_or_remote_file(filename, &api)?,
            None => match args.which {
                Which::T5Small => api.get("config.json")?,
                Which::FlanT5Small => api.get("config-flan-t5-small.json")?,
                Which::FlanT5Base => api.get("config-flan-t5-base.json")?,
                Which::FlanT5Large => api.get("config-flan-t5-large.json")?,
                Which::FlanT5Xl => api.get("config-flan-t5-xl.json")?,
                Which::FlanT5Xxl => api.get("config-flan-t5-xxl.json")?,
            },
        };
        let tokenizer_filename = api.get("tokenizer.json")?;
        let weights_filename = match &args.weight_file {
            Some(filename) => Self::get_local_or_remote_file(filename, &api)?,
            None => match args.which {
                Which::T5Small => api.get("model.gguf")?,
                Which::FlanT5Small => api.get("model-flan-t5-small.gguf")?,
                Which::FlanT5Base => api.get("model-flan-t5-base.gguf")?,
                Which::FlanT5Large => api.get("model-flan-t5-large.gguf")?,
                Which::FlanT5Xl => api.get("model-flan-t5-xl.gguf")?,
                Which::FlanT5Xxl => api.get("model-flan-t5-xxl.gguf")?,
            },
        };
        let config = std::fs::read_to_string(config_filename)?;
        let mut config: t5::Config = serde_json::from_str(&config)?;
        config.use_cache = !args.disable_cache;
        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
        Ok((
            Self {
                device,
                config,
                weights_filename,
            },
            tokenizer,
        ))
    }

    pub fn build_model(&self) -> Result<t5::T5ForConditionalGeneration> {
        let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?;
        Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
    }

    fn get_local_or_remote_file(filename: &str, api: &ApiRepo) -> Result<PathBuf> {
        let local_filename = std::path::PathBuf::from(filename);
        if local_filename.exists() {
            Ok(local_filename)
        } else {
            Ok(api.get(filename)?)
        }
    }
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();

    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?;
    let device = &builder.device;
    let tokenizer = tokenizer
        .with_padding(None)
        .with_truncation(None)
        .map_err(E::msg)?;
    let tokens = tokenizer
        .encode(args.prompt, true)
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
    let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
    let mut model = builder.build_model()?;
    let mut output_token_ids = [builder
        .config
        .decoder_start_token_id
        .unwrap_or(builder.config.pad_token_id) as u32]
    .to_vec();
    let temperature = if args.temperature <= 0. {
        None
    } else {
        Some(args.temperature)
    };
    let mut logits_processor = LogitsProcessor::new(299792458, temperature, args.top_p);
    let encoder_output = model.encode(&input_token_ids)?;
    let start = std::time::Instant::now();

    for index in 0.. {
        if output_token_ids.len() > 512 {
            break;
        }
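        // With the KV cache enabled, only the newest token is fed to the decoder;
        // otherwise the full output prefix is re-decoded at every step.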
        let decoder_token_ids = if index == 0 || !builder.config.use_cache {
            Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)?
        } else {
            let last_token = *output_token_ids.last().unwrap();
            Tensor::new(&[last_token], device)?.unsqueeze(0)?
        };
        let logits = model
            .decode(&decoder_token_ids, &encoder_output)?
            .squeeze(0)?;
        let logits = if args.repeat_penalty == 1. {
            logits
        } else {
            let start_at = output_token_ids.len().saturating_sub(args.repeat_last_n);
            candle_transformers::utils::apply_repeat_penalty(
                &logits,
                args.repeat_penalty,
                &output_token_ids[start_at..],
            )?
        };

        let next_token_id = logits_processor.sample(&logits)?;
        if next_token_id as usize == builder.config.eos_token_id {
            break;
        }
        output_token_ids.push(next_token_id);
        if let Some(text) = tokenizer.id_to_token(next_token_id) {
            let text = text.replace('▁', " ").replace("<0x0A>", "\n");
            print!("{text}");
            std::io::stdout().flush()?;
        }
    }
    let dt = start.elapsed();
    println!(
        "\n{} tokens generated ({:.2} token/s)\n",
        output_token_ids.len(),
        output_token_ids.len() as f64 / dt.as_secs_f64(),
    );
    Ok(())
}

@@ -12,7 +12,6 @@ use candle::quantized::{ggml_file, gguf_file};
use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;

use candle_examples::token_output_stream::TokenOutputStream;
use candle_transformers::models::quantized_llama as model;
use model::ModelWeights;

@@ -25,7 +24,7 @@ enum Prompt {
    One(String),
}

#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
    #[value(name = "7b")]
    L7b,
@@ -45,52 +44,6 @@ enum Which {
    L13bCode,
    #[value(name = "34b-code")]
    L34bCode,
    #[value(name = "7b-mistral")]
    Mistral7b,
    #[value(name = "7b-mistral-instruct")]
    Mistral7bInstruct,
    #[value(name = "7b-zephyr-a")]
    Zephyr7bAlpha,
    #[value(name = "7b-zephyr-b")]
    Zephyr7bBeta,
}

impl Which {
    fn is_mistral(&self) -> bool {
        match self {
            Self::L7b
            | Self::L13b
            | Self::L70b
            | Self::L7bChat
            | Self::L13bChat
            | Self::L70bChat
            | Self::L7bCode
            | Self::L13bCode
            | Self::L34bCode => false,
            // Zephyr is a fine-tuned version of mistral and should be treated in the same way.
            Self::Zephyr7bAlpha
            | Self::Zephyr7bBeta
            | Self::Mistral7b
            | Self::Mistral7bInstruct => true,
        }
    }

    fn is_zephyr(&self) -> bool {
        match self {
            Self::L7b
            | Self::L13b
            | Self::L70b
            | Self::L7bChat
            | Self::L13bChat
            | Self::L70bChat
            | Self::L7bCode
            | Self::L13bCode
            | Self::L34bCode
            | Self::Mistral7b
            | Self::Mistral7bInstruct => false,
            Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
        }
    }
}

#[derive(Parser, Debug)]
@@ -107,7 +60,7 @@ struct Args {
    prompt: Option<String>,

    /// The length of the sample to generate (in tokens).
    #[arg(short = 'n', long, default_value_t = 1000)]
    #[arg(short = 'n', long, default_value_t = 100)]
    sample_len: usize,

    /// The tokenizer config in json format.
@@ -157,12 +110,7 @@ impl Args {
            Some(config) => std::path::PathBuf::from(config),
            None => {
                let api = hf_hub::api::sync::Api::new()?;
                let repo = if self.which.is_mistral() {
                    "mistralai/Mistral-7B-v0.1"
                } else {
                    "hf-internal-testing/llama-tokenizer"
                };
                let api = api.model(repo.to_string());
                let api = api.model("hf-internal-testing/llama-tokenizer".to_string());
                api.get("tokenizer.json")?
            }
        };
@@ -192,21 +140,6 @@ impl Args {
            Which::L7bCode => ("TheBloke/CodeLlama-7B-GGUF", "codellama-7b.Q8_0.gguf"),
            Which::L13bCode => ("TheBloke/CodeLlama-13B-GGUF", "codellama-13b.Q8_0.gguf"),
            Which::L34bCode => ("TheBloke/CodeLlama-34B-GGUF", "codellama-34b.Q8_0.gguf"),
            Which::Mistral7b => (
                "TheBloke/Mistral-7B-v0.1-GGUF",
                "mistral-7b-v0.1.Q4_K_S.gguf",
            ),
            Which::Mistral7bInstruct => (
                "TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
                "mistral-7b-instruct-v0.1.Q4_K_S.gguf",
            ),
            Which::Zephyr7bAlpha => (
                "TheBloke/zephyr-7B-alpha-GGUF",
                "zephyr-7b-alpha.Q4_K_M.gguf",
            ),
            Which::Zephyr7bBeta => {
                ("TheBloke/zephyr-7B-beta-GGUF", "zephyr-7b-beta.Q4_K_M.gguf")
            }
        };
        let api = hf_hub::api::sync::Api::new()?;
        let api = api.model(repo.to_string());
@@ -217,6 +150,31 @@ impl Args {
    }
}

fn print_token(next_token: u32, tokenizer: &Tokenizer) {
    // Extracting the last token as a string is complicated, here we just apply some simple
    // heuristics as it seems to work well enough for this example. See the following for more
    // details:
    // https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
    if let Some(text) = tokenizer.id_to_token(next_token) {
        let text = text.replace('▁', " ");
        let ascii = text
            .strip_prefix("<0x")
            .and_then(|t| t.strip_suffix('>'))
            .and_then(|t| u8::from_str_radix(t, 16).ok());
        match ascii {
            None => print!("{text}"),
            Some(ascii) => {
                if let Some(chr) = char::from_u32(ascii as u32) {
                    if chr.is_ascii() {
                        print!("{chr}")
                    }
                }
            }
        }
        let _ = std::io::stdout().flush();
    }
}

fn format_size(size_in_bytes: usize) -> String {
    if size_in_bytes < 1_000 {
        format!("{}B", size_in_bytes)
@@ -303,12 +261,7 @@ fn main() -> anyhow::Result<()> {
                | Which::L7bCode
                | Which::L13bCode
                | Which::L34bCode => 1,
                Which::Mistral7b
                | Which::Mistral7bInstruct
                | Which::Zephyr7bAlpha
                | Which::Zephyr7bBeta
                | Which::L70b
                | Which::L70bChat => 8,
                Which::L70b | Which::L70bChat => 8,
            };
            ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
        }
@@ -316,7 +269,6 @@ fn main() -> anyhow::Result<()> {
    println!("model built");

    let tokenizer = args.tokenizer()?;
    let mut tos = TokenOutputStream::new(tokenizer);
    let prompt = match args.prompt.as_deref() {
        Some("chat") => Prompt::Chat,
        Some("interactive") => Prompt::Interactive,
@@ -325,11 +277,10 @@ fn main() -> anyhow::Result<()> {
    };

    let mut pre_prompt_tokens = vec![];
    for prompt_index in 0.. {
    loop {
        let prompt_str = match &prompt {
            Prompt::One(prompt) => prompt.clone(),
            Prompt::Interactive | Prompt::Chat => {
                let is_interactive = matches!(prompt, Prompt::Interactive);
                print!("> ");
                std::io::stdout().flush()?;
                let mut prompt = String::new();
@@ -340,22 +291,11 @@ fn main() -> anyhow::Result<()> {
                        prompt.pop();
                    }
                }
                if args.which.is_zephyr() {
                    if prompt_index == 0 || is_interactive {
                        format!("<|system|>\n</s>\n<|user|>\n{prompt}</s>\n<|assistant|>",)
                    } else {
                        format!("<|user|>\n{prompt}</s>\n<|assistant|>")
                    }
                } else if args.which.is_mistral() {
                    format!("[INST] {prompt} [/INST]")
                } else {
                    prompt
                }
                prompt
            }
        };
        print!("{}", &prompt_str);
        let tokens = tos
            .tokenizer()
        let tokens = tokenizer
            .encode(prompt_str, true)
            .map_err(anyhow::Error::msg)?;
        if args.verbose_prompt {
@@ -385,15 +325,9 @@ fn main() -> anyhow::Result<()> {
        };
        let prompt_dt = start_prompt_processing.elapsed();
        all_tokens.push(next_token);
        if let Some(t) = tos.next_token(next_token)? {
            print!("{t}");
            std::io::stdout().flush()?;
        }

        let eos_token = *tos.tokenizer().get_vocab(true).get("</s>").unwrap();
        print_token(next_token, &tokenizer);

        let start_post_prompt = std::time::Instant::now();
        let mut sampled = 0;
        for index in 0..to_sample {
            let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
            let logits = model.forward(&input, prompt_tokens.len() + index)?;
@@ -410,19 +344,8 @@ fn main() -> anyhow::Result<()> {
            };
            next_token = logits_processor.sample(&logits)?;
            all_tokens.push(next_token);
            if let Some(t) = tos.next_token(next_token)? {
                print!("{t}");
                std::io::stdout().flush()?;
            }
            sampled += 1;
            if next_token == eos_token {
                break;
            };
            print_token(next_token, &tokenizer);
        }
        if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
            print!("{rest}");
        }
        std::io::stdout().flush()?;
        let dt = start_post_prompt.elapsed();
        println!(
            "\n\n{:4} prompt tokens processed: {:.2} token/s",
@@ -430,8 +353,9 @@ fn main() -> anyhow::Result<()> {
            prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
        );
        println!(
            "{sampled:4} tokens generated: {:.2} token/s",
            sampled as f64 / dt.as_secs_f64(),
            "{:4} tokens generated: {:.2} token/s",
            to_sample,
            to_sample as f64 / dt.as_secs_f64(),
        );

        match prompt {

@@ -1,16 +0,0 @@
# candle-reinforcement-learning

Reinforcement Learning examples for candle.

This has been tested with `gymnasium` version `0.29.1`. You can install the
Python package with:
```bash
pip install "gymnasium[accept-rom-license]"
```

In order to run the example, use the following command. Note the additional
`--package` flag to ensure that there is no conflict with the `candle-pyo3`
crate.
```bash
cargo run --example reinforcement-learning --features=pyo3 --package candle-examples
```

@@ -1,308 +0,0 @@
import gymnasium as gym
import numpy as np
from collections import deque
from PIL import Image
from multiprocessing import Process, Pipe

# atari_wrappers.py
class NoopResetEnv(gym.Wrapper):
    def __init__(self, env, noop_max=30):
        """Sample initial states by taking random number of no-ops on reset.
        No-op is assumed to be action 0.
        """
        gym.Wrapper.__init__(self, env)
        self.noop_max = noop_max
        self.override_num_noops = None
        assert env.unwrapped.get_action_meanings()[0] == 'NOOP'

    def reset(self):
        """Do no-op action for a number of steps in [1, noop_max]."""
        self.env.reset()
        if self.override_num_noops is not None:
            noops = self.override_num_noops
        else:
            noops = self.unwrapped.np_random.integers(1, self.noop_max + 1) #pylint: disable=E1101
        assert noops > 0
        obs = None
        for _ in range(noops):
            obs, _, done, _ = self.env.step(0)
            if done:
                obs = self.env.reset()
        return obs

class FireResetEnv(gym.Wrapper):
    def __init__(self, env):
        """Take action on reset for environments that are fixed until firing."""
        gym.Wrapper.__init__(self, env)
        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def reset(self):
        self.env.reset()
        obs, _, done, _ = self.env.step(1)
        if done:
            self.env.reset()
        obs, _, done, _ = self.env.step(2)
        if done:
            self.env.reset()
        return obs

class ImageSaver(gym.Wrapper):
    def __init__(self, env, img_path, rank):
        gym.Wrapper.__init__(self, env)
        self._cnt = 0
        self._img_path = img_path
        self._rank = rank

    def step(self, action):
        step_result = self.env.step(action)
        obs, _, _, _ = step_result
        img = Image.fromarray(obs, 'RGB')
        img.save('%s/out%d-%05d.png' % (self._img_path, self._rank, self._cnt))
        self._cnt += 1
        return step_result

class EpisodicLifeEnv(gym.Wrapper):
    def __init__(self, env):
        """Make end-of-life == end-of-episode, but only reset on true game over.
        Done by DeepMind for the DQN and co. since it helps value estimation.
        """
        gym.Wrapper.__init__(self, env)
        self.lives = 0
        self.was_real_done = True

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self.was_real_done = done
        # check current lives, make loss of life terminal,
        # then update lives to handle bonus lives
        lives = self.env.unwrapped.ale.lives()
        if lives < self.lives and lives > 0:
            # for Qbert we sometimes stay in the lives == 0 condition for a few
            # frames, so it's important to keep lives > 0 so that we only reset
            # once the environment advertises done.
            done = True
        self.lives = lives
        return obs, reward, done, info

    def reset(self):
        """Reset only when lives are exhausted.
        This way all states are still reachable even though lives are episodic,
        and the learner need not know about any of this behind-the-scenes.
        """
        if self.was_real_done:
            obs = self.env.reset()
        else:
            # no-op step to advance from terminal/lost life state
            obs, _, _, _ = self.env.step(0)
        self.lives = self.env.unwrapped.ale.lives()
        return obs

class MaxAndSkipEnv(gym.Wrapper):
    def __init__(self, env, skip=4):
        """Return only every `skip`-th frame"""
        gym.Wrapper.__init__(self, env)
        # most recent raw observations (for max pooling across time steps)
        self._obs_buffer = deque(maxlen=2)
        self._skip = skip

    def step(self, action):
        """Repeat action, sum reward, and max over last observations."""
        total_reward = 0.0
        done = None
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            self._obs_buffer.append(obs)
            total_reward += reward
            if done:
                break
        max_frame = np.max(np.stack(self._obs_buffer), axis=0)

        return max_frame, total_reward, done, info

    def reset(self):
        """Clear past frame buffer and init. to first obs. from inner env."""
        self._obs_buffer.clear()
        obs = self.env.reset()
        self._obs_buffer.append(obs)
        return obs

class ClipRewardEnv(gym.RewardWrapper):
    def reward(self, reward):
        """Bin reward to {+1, 0, -1} by its sign."""
        return np.sign(reward)

class WarpFrame(gym.ObservationWrapper):
    def __init__(self, env):
        """Warp frames to 84x84 as done in the Nature paper and later work."""
        gym.ObservationWrapper.__init__(self, env)
        self.res = 84
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(self.res, self.res, 1), dtype='uint8')

    def observation(self, obs):
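        # Convert RGB to grayscale with ITU-R BT.601 luma weights, then resize to 84x84.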
        frame = np.dot(obs.astype('float32'), np.array([0.299, 0.587, 0.114], 'float32'))
        frame = np.array(Image.fromarray(frame).resize((self.res, self.res),
            resample=Image.BILINEAR), dtype=np.uint8)
        return frame.reshape((self.res, self.res, 1))

class FrameStack(gym.Wrapper):
    def __init__(self, env, k):
        """Buffer observations and stack across channels (last axis)."""
        gym.Wrapper.__init__(self, env)
        self.k = k
        self.frames = deque([], maxlen=k)
        shp = env.observation_space.shape
        assert shp[2] == 1  # can only stack 1-channel frames
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(shp[0], shp[1], k), dtype='uint8')

    def reset(self):
        """Clear buffer and re-fill by duplicating the first observation."""
        ob = self.env.reset()
        for _ in range(self.k): self.frames.append(ob)
        return self.observation()

    def step(self, action):
        ob, reward, done, info = self.env.step(action)
        self.frames.append(ob)
        return self.observation(), reward, done, info

    def observation(self):
        assert len(self.frames) == self.k
        return np.concatenate(self.frames, axis=2)

def wrap_deepmind(env, episode_life=True, clip_rewards=True):
    """Configure environment for DeepMind-style Atari.

    Note: this does not include frame stacking!"""
    assert 'NoFrameskip' in env.spec.id  # required for DeepMind-style skip
    if episode_life:
        env = EpisodicLifeEnv(env)
    env = NoopResetEnv(env, noop_max=30)
    env = MaxAndSkipEnv(env, skip=4)
    if 'FIRE' in env.unwrapped.get_action_meanings():
        env = FireResetEnv(env)
    env = WarpFrame(env)
    if clip_rewards:
        env = ClipRewardEnv(env)
    return env

# envs.py
def make_env(env_id, img_dir, seed, rank):
    def _thunk():
        env = gym.make(env_id)
        env.reset(seed=(seed + rank))
        if img_dir is not None:
            env = ImageSaver(env, img_dir, rank)
        env = wrap_deepmind(env)
        env = WrapPyTorch(env)
        return env

    return _thunk

class WrapPyTorch(gym.ObservationWrapper):
    def __init__(self, env=None):
        super(WrapPyTorch, self).__init__(env)
        self.observation_space = gym.spaces.Box(0.0, 1.0, [1, 84, 84], dtype='float32')

    def observation(self, observation):
        return observation.transpose(2, 0, 1)

# vecenv.py
class VecEnv(object):
    """
    Vectorized environment base class
    """
    def step(self, vac):
        """
        Apply sequence of actions to sequence of environments
        actions -> (observations, rewards, news)

        where 'news' is a boolean vector indicating whether each element is new.
        """
        raise NotImplementedError
    def reset(self):
        """
        Reset all environments
        """
        raise NotImplementedError
    def close(self):
        pass

# subproc_vec_env.py
def worker(remote, env_fn_wrapper):
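    # Each worker owns a single environment and serves step/reset/close/get_spaces
    # commands sent over its end of the pipe.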
    env = env_fn_wrapper.x()
    while True:
        cmd, data = remote.recv()
        if cmd == 'step':
            ob, reward, done, info = env.step(data)
            if done:
                ob = env.reset()
            remote.send((ob, reward, done, info))
        elif cmd == 'reset':
            ob = env.reset()
            remote.send(ob)
        elif cmd == 'close':
            remote.close()
            break
        elif cmd == 'get_spaces':
            remote.send((env.action_space, env.observation_space))
        else:
            raise NotImplementedError

class CloudpickleWrapper(object):
    """
    Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
    """
    def __init__(self, x):
        self.x = x
    def __getstate__(self):
        import cloudpickle
        return cloudpickle.dumps(self.x)
    def __setstate__(self, ob):
        import pickle
        self.x = pickle.loads(ob)

class SubprocVecEnv(VecEnv):
    def __init__(self, env_fns):
        """
        envs: list of gym environments to run in subprocesses
        """
        nenvs = len(env_fns)
        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
        self.ps = [Process(target=worker, args=(work_remote, CloudpickleWrapper(env_fn)))
            for (work_remote, env_fn) in zip(self.work_remotes, env_fns)]
        for p in self.ps:
            p.start()

        self.remotes[0].send(('get_spaces', None))
        self.action_space, self.observation_space = self.remotes[0].recv()

    def step(self, actions):
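        # Scatter one action to each worker, then gather and stack their results.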
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))
        results = [remote.recv() for remote in self.remotes]
        obs, rews, dones, infos = zip(*results)
        return np.stack(obs), np.stack(rews), np.stack(dones), infos

    def reset(self):
        for remote in self.remotes:
            remote.send(('reset', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def close(self):
        for remote in self.remotes:
            remote.send(('close', None))
        for p in self.ps:
            p.join()

    @property
    def num_envs(self):
        return len(self.remotes)

# Create the environment.
def make(env_name, img_dir, num_processes):
    envs = SubprocVecEnv([
        make_env(env_name, img_dir, 1337, i) for i in range(num_processes)
    ])
    return envs

@@ -1,451 +0,0 @@
use std::collections::VecDeque;
use std::fmt::Display;

use candle::{DType, Device, Error, Module, Result, Tensor, Var};
use candle_nn::{
    func, linear, sequential::seq, Activation, AdamW, Optimizer, ParamsAdamW, Sequential,
    VarBuilder, VarMap,
};
use rand::{distributions::Uniform, thread_rng, Rng};

pub struct OuNoise {
    mu: f64,
    theta: f64,
    sigma: f64,
    state: Tensor,
}
impl OuNoise {
    pub fn new(mu: f64, theta: f64, sigma: f64, size_action: usize) -> Result<Self> {
        Ok(Self {
            mu,
            theta,
            sigma,
            state: Tensor::ones(size_action, DType::F32, &Device::Cpu)?,
        })
    }

    pub fn sample(&mut self) -> Result<Tensor> {
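        // Ornstein-Uhlenbeck update: dx = theta * (mu - x) + sigma * N(0, 1),
        // i.e. Gaussian noise that mean-reverts towards mu.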
let rand = Tensor::randn_like(&self.state, 0.0, 1.0)?;
|
||||
let dx = ((self.theta * (self.mu - &self.state)?)? + (self.sigma * rand)?)?;
|
||||
self.state = (&self.state + dx)?;
|
||||
Ok(self.state.clone())
|
||||
}
|
||||
}
|
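
// `sample` is the discretized Ornstein-Uhlenbeck update with unit time step:
// x <- x + theta * (mu - x) + sigma * N(0, 1). The noise is mean-reverting
// towards mu and correlated across steps, which gives smoother exploration
// for continuous actions than i.i.d. Gaussian noise.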

#[derive(Clone)]
struct Transition {
    state: Tensor,
    action: Tensor,
    reward: Tensor,
    next_state: Tensor,
    terminated: bool,
    truncated: bool,
}

impl Transition {
    fn new(
        state: &Tensor,
        action: &Tensor,
        reward: &Tensor,
        next_state: &Tensor,
        terminated: bool,
        truncated: bool,
    ) -> Self {
        Self {
            state: state.clone(),
            action: action.clone(),
            reward: reward.clone(),
            next_state: next_state.clone(),
            terminated,
            truncated,
        }
    }
}

pub struct ReplayBuffer {
    buffer: VecDeque<Transition>,
    capacity: usize,
    size: usize,
}

impl ReplayBuffer {
    pub fn new(capacity: usize) -> Self {
        Self {
            buffer: VecDeque::with_capacity(capacity),
            capacity,
            size: 0,
        }
    }

    pub fn push(
        &mut self,
        state: &Tensor,
        action: &Tensor,
        reward: &Tensor,
        next_state: &Tensor,
        terminated: bool,
        truncated: bool,
    ) {
        if self.size == self.capacity {
            self.buffer.pop_front();
        } else {
            self.size += 1;
        }
        self.buffer.push_back(Transition::new(
            state, action, reward, next_state, terminated, truncated,
        ));
    }

    #[allow(clippy::type_complexity)]
    pub fn random_batch(
        &self,
        batch_size: usize,
    ) -> Result<Option<(Tensor, Tensor, Tensor, Tensor, Vec<bool>, Vec<bool>)>> {
        if self.size < batch_size {
            Ok(None)
        } else {
            let transitions: Vec<&Transition> = thread_rng()
                .sample_iter(Uniform::from(0..self.size))
                .take(batch_size)
                .map(|i| self.buffer.get(i).unwrap())
                .collect();

            let states: Vec<Tensor> = transitions
                .iter()
                .map(|t| t.state.unsqueeze(0))
                .collect::<Result<_>>()?;
            let actions: Vec<Tensor> = transitions
                .iter()
                .map(|t| t.action.unsqueeze(0))
                .collect::<Result<_>>()?;
            let rewards: Vec<Tensor> = transitions
                .iter()
                .map(|t| t.reward.unsqueeze(0))
                .collect::<Result<_>>()?;
            let next_states: Vec<Tensor> = transitions
                .iter()
                .map(|t| t.next_state.unsqueeze(0))
                .collect::<Result<_>>()?;
            let terminateds: Vec<bool> = transitions.iter().map(|t| t.terminated).collect();
            let truncateds: Vec<bool> = transitions.iter().map(|t| t.truncated).collect();

            Ok(Some((
                Tensor::cat(&states, 0)?,
                Tensor::cat(&actions, 0)?,
                Tensor::cat(&rewards, 0)?,
                Tensor::cat(&next_states, 0)?,
                terminateds,
                truncateds,
            )))
        }
    }
}
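
// Batch indices are drawn uniformly with replacement, so a batch may contain
// duplicate transitions; with a 100k-entry buffer and batches of 100 this is
// harmless and keeps sampling O(batch_size).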

fn track(
    varmap: &mut VarMap,
    vb: &VarBuilder,
    target_prefix: &str,
    network_prefix: &str,
    dims: &[(usize, usize)],
    tau: f64,
) -> Result<()> {
    for (i, &(in_dim, out_dim)) in dims.iter().enumerate() {
        let target_w = vb.get((out_dim, in_dim), &format!("{target_prefix}-fc{i}.weight"))?;
        let network_w = vb.get((out_dim, in_dim), &format!("{network_prefix}-fc{i}.weight"))?;
        varmap.set_one(
            format!("{target_prefix}-fc{i}.weight"),
            ((tau * network_w)? + ((1.0 - tau) * target_w)?)?,
        )?;

        let target_b = vb.get(out_dim, &format!("{target_prefix}-fc{i}.bias"))?;
        let network_b = vb.get(out_dim, &format!("{network_prefix}-fc{i}.bias"))?;
        varmap.set_one(
            format!("{target_prefix}-fc{i}.bias"),
            ((tau * network_b)? + ((1.0 - tau) * target_b)?)?,
        )?;
    }
    Ok(())
}
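
// Polyak averaging: target <- tau * network + (1 - tau) * target. With
// tau = 0.005 the target networks trail the online networks slowly, which
// stabilizes the bootstrapped critic targets; tau = 1.0 is a hard copy.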

struct Actor<'a> {
    varmap: VarMap,
    vb: VarBuilder<'a>,
    network: Sequential,
    target_network: Sequential,
    size_state: usize,
    size_action: usize,
    dims: Vec<(usize, usize)>,
}

impl Actor<'_> {
    fn new(device: &Device, dtype: DType, size_state: usize, size_action: usize) -> Result<Self> {
        let mut varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, dtype, device);

        let dims = vec![(size_state, 400), (400, 300), (300, size_action)];

        let make_network = |prefix: &str| {
            let seq = seq()
                .add(linear(
                    dims[0].0,
                    dims[0].1,
                    vb.pp(format!("{prefix}-fc0")),
                )?)
                .add(Activation::Relu)
                .add(linear(
                    dims[1].0,
                    dims[1].1,
                    vb.pp(format!("{prefix}-fc1")),
                )?)
                .add(Activation::Relu)
                .add(linear(
                    dims[2].0,
                    dims[2].1,
                    vb.pp(format!("{prefix}-fc2")),
                )?)
                .add(func(|xs| xs.tanh()));
            Ok::<Sequential, Error>(seq)
        };

        let network = make_network("actor")?;
        let target_network = make_network("target-actor")?;

        // this sets the two networks to be equal to each other using tau = 1.0
        track(&mut varmap, &vb, "target-actor", "actor", &dims, 1.0)?;

        Ok(Self {
            varmap,
            vb,
            network,
            target_network,
            size_state,
            size_action,
            dims,
        })
    }

    fn forward(&self, state: &Tensor) -> Result<Tensor> {
        self.network.forward(state)
    }

    fn target_forward(&self, state: &Tensor) -> Result<Tensor> {
        self.target_network.forward(state)
    }

    fn track(&mut self, tau: f64) -> Result<()> {
        track(
            &mut self.varmap,
            &self.vb,
            "target-actor",
            "actor",
            &self.dims,
            tau,
        )
    }
}
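
// The 400/300 hidden layout and the tanh output squashing follow the original
// DDPG paper (Lillicrap et al., 2015); tanh bounds the raw action to [-1, 1],
// which the caller rescales to the environment's action range.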

struct Critic<'a> {
    varmap: VarMap,
    vb: VarBuilder<'a>,
    network: Sequential,
    target_network: Sequential,
    size_state: usize,
    size_action: usize,
    dims: Vec<(usize, usize)>,
}

impl Critic<'_> {
    fn new(device: &Device, dtype: DType, size_state: usize, size_action: usize) -> Result<Self> {
        let mut varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, dtype, device);

        let dims: Vec<(usize, usize)> = vec![(size_state + size_action, 400), (400, 300), (300, 1)];

        let make_network = |prefix: &str| {
            let seq = seq()
                .add(linear(
                    dims[0].0,
                    dims[0].1,
                    vb.pp(format!("{prefix}-fc0")),
                )?)
                .add(Activation::Relu)
                .add(linear(
                    dims[1].0,
                    dims[1].1,
                    vb.pp(format!("{prefix}-fc1")),
                )?)
                .add(Activation::Relu)
                .add(linear(
                    dims[2].0,
                    dims[2].1,
                    vb.pp(format!("{prefix}-fc2")),
                )?);
            Ok::<Sequential, Error>(seq)
        };

        let network = make_network("critic")?;
        let target_network = make_network("target-critic")?;

        // this sets the two networks to be equal to each other using tau = 1.0
        track(&mut varmap, &vb, "target-critic", "critic", &dims, 1.0)?;

        Ok(Self {
            varmap,
            vb,
            network,
            target_network,
            size_state,
            size_action,
            dims,
        })
    }

    fn forward(&self, state: &Tensor, action: &Tensor) -> Result<Tensor> {
        let xs = Tensor::cat(&[action, state], 1)?;
        self.network.forward(&xs)
    }

    fn target_forward(&self, state: &Tensor, action: &Tensor) -> Result<Tensor> {
        let xs = Tensor::cat(&[action, state], 1)?;
        self.target_network.forward(&xs)
    }

    fn track(&mut self, tau: f64) -> Result<()> {
        track(
            &mut self.varmap,
            &self.vb,
            "target-critic",
            "critic",
            &self.dims,
            tau,
        )
    }
}
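
// Note the concatenation order: the critic sees [action, state] along dim 1,
// so its input width is size_action + size_state, matching dims[0].0 above.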

#[allow(clippy::upper_case_acronyms)]
pub struct DDPG<'a> {
    actor: Actor<'a>,
    actor_optim: AdamW,
    critic: Critic<'a>,
    critic_optim: AdamW,
    gamma: f64,
    tau: f64,
    replay_buffer: ReplayBuffer,
    ou_noise: OuNoise,

    size_state: usize,
    size_action: usize,
    pub train: bool,
}

impl DDPG<'_> {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        device: &Device,
        size_state: usize,
        size_action: usize,
        train: bool,
        actor_lr: f64,
        critic_lr: f64,
        gamma: f64,
        tau: f64,
        buffer_capacity: usize,
        ou_noise: OuNoise,
    ) -> Result<Self> {
        let filter_by_prefix = |varmap: &VarMap, prefix: &str| {
            varmap
                .data()
                .lock()
                .unwrap()
                .iter()
                .filter_map(|(name, var)| name.starts_with(prefix).then_some(var.clone()))
                .collect::<Vec<Var>>()
        };

        let actor = Actor::new(device, DType::F32, size_state, size_action)?;
        let actor_optim = AdamW::new(
            filter_by_prefix(&actor.varmap, "actor"),
            ParamsAdamW {
                lr: actor_lr,
                ..Default::default()
            },
        )?;

        let critic = Critic::new(device, DType::F32, size_state, size_action)?;
        let critic_optim = AdamW::new(
            filter_by_prefix(&critic.varmap, "critic"),
            ParamsAdamW {
                lr: critic_lr,
                ..Default::default()
            },
        )?;

        Ok(Self {
            actor,
            actor_optim,
            critic,
            critic_optim,
            gamma,
            tau,
            replay_buffer: ReplayBuffer::new(buffer_capacity),
            ou_noise,
            size_state,
            size_action,
            train,
        })
    }

    pub fn remember(
        &mut self,
        state: &Tensor,
        action: &Tensor,
        reward: &Tensor,
        next_state: &Tensor,
        terminated: bool,
        truncated: bool,
    ) {
        self.replay_buffer
            .push(state, action, reward, next_state, terminated, truncated)
    }

    pub fn actions(&mut self, state: &Tensor) -> Result<f32> {
        let actions = self
            .actor
            .forward(&state.detach().unsqueeze(0)?)?
            .squeeze(0)?;
        let actions = if self.train {
            (actions + self.ou_noise.sample()?)?
        } else {
            actions
        };
        actions.squeeze(0)?.to_scalar::<f32>()
    }

    pub fn train(&mut self, batch_size: usize) -> Result<()> {
        let (states, actions, rewards, next_states, _, _) =
            match self.replay_buffer.random_batch(batch_size)? {
                Some(v) => v,
                _ => return Ok(()),
            };

        let q_target = self
            .critic
            .target_forward(&next_states, &self.actor.target_forward(&next_states)?)?;
        let q_target = (rewards + (self.gamma * q_target)?.detach())?;
        let q = self.critic.forward(&states, &actions)?;
        let diff = (q_target - q)?;

        let critic_loss = diff.sqr()?.mean_all()?;
        self.critic_optim.backward_step(&critic_loss)?;

        let actor_loss = self
            .critic
            .forward(&states, &self.actor.forward(&states)?)?
            .mean_all()?
            .neg()?;
        self.actor_optim.backward_step(&actor_loss)?;

        self.critic.track(self.tau)?;
        self.actor.track(self.tau)?;

        Ok(())
    }
}
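
// The two backward steps implement the DDPG objectives:
//   critic: minimize E[(r + gamma * Q'(s', mu'(s')) - Q(s, a))^2]
//   actor:  maximize E[Q(s, mu(s))] (hence the negated mean as a loss)
// where Q' and mu' are the target networks. The terminated/truncated flags
// are ignored when forming the target, a simplification that is benign for
// Pendulum's fixed-length episodes.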
@ -1,112 +0,0 @@
#![allow(unused)]
//! Wrappers around the Python API of Gymnasium (the new version of OpenAI gym)
use candle::{Device, Result, Tensor};
use pyo3::prelude::*;
use pyo3::types::PyDict;

/// The return value for a step.
#[derive(Debug)]
pub struct Step<A> {
    pub state: Tensor,
    pub action: A,
    pub reward: f64,
    pub terminated: bool,
    pub truncated: bool,
}

impl<A: Copy> Step<A> {
    /// Returns a copy of this step changing the observation tensor.
    pub fn copy_with_obs(&self, state: &Tensor) -> Step<A> {
        Step {
            state: state.clone(),
            action: self.action,
            reward: self.reward,
            terminated: self.terminated,
            truncated: self.truncated,
        }
    }
}

/// An OpenAI Gym session.
pub struct GymEnv {
    env: PyObject,
    action_space: usize,
    observation_space: Vec<usize>,
}

fn w(res: PyErr) -> candle::Error {
    candle::Error::wrap(res)
}

impl GymEnv {
    /// Creates a new session of the specified OpenAI Gym environment.
    pub fn new(name: &str) -> Result<GymEnv> {
        Python::with_gil(|py| {
            let gym = py.import("gymnasium")?;
            let make = gym.getattr("make")?;
            let env = make.call1((name,))?;
            let action_space = env.getattr("action_space")?;
            let action_space = if let Ok(val) = action_space.getattr("n") {
                val.extract()?
            } else {
                let action_space: Vec<usize> = action_space.getattr("shape")?.extract()?;
                action_space[0]
            };
            let observation_space = env.getattr("observation_space")?;
            let observation_space = observation_space.getattr("shape")?.extract()?;
            Ok(GymEnv {
                env: env.into(),
                action_space,
                observation_space,
            })
        })
        .map_err(w)
    }

    /// Resets the environment, returning the observation tensor.
    pub fn reset(&self, seed: u64) -> Result<Tensor> {
        let state: Vec<f32> = Python::with_gil(|py| {
            let kwargs = PyDict::new(py);
            kwargs.set_item("seed", seed)?;
            let state = self.env.call_method(py, "reset", (), Some(kwargs))?;
            state.as_ref(py).get_item(0)?.extract()
        })
        .map_err(w)?;
        Tensor::new(state, &Device::Cpu)
    }

    /// Applies an environment step using the specified action.
    pub fn step<A: pyo3::IntoPy<pyo3::Py<pyo3::PyAny>> + Clone>(
        &self,
        action: A,
    ) -> Result<Step<A>> {
        let (state, reward, terminated, truncated) = Python::with_gil(|py| {
            let step = self.env.call_method(py, "step", (action.clone(),), None)?;
            let step = step.as_ref(py);
            let state: Vec<f32> = step.get_item(0)?.extract()?;
            let reward: f64 = step.get_item(1)?.extract()?;
            let terminated: bool = step.get_item(2)?.extract()?;
            let truncated: bool = step.get_item(3)?.extract()?;
            Ok((state, reward, terminated, truncated))
        })
        .map_err(w)?;
        let state = Tensor::new(state, &Device::Cpu)?;
        Ok(Step {
            state,
            action,
            reward,
            terminated,
            truncated,
        })
    }

    /// Returns the number of allowed actions for this environment.
    pub fn action_space(&self) -> usize {
        self.action_space
    }

    /// Returns the shape of the observation tensors.
    pub fn observation_space(&self) -> &[usize] {
        &self.observation_space
    }
}
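
// Gymnasium's `reset` returns (obs, info) and `step` returns the 5-tuple
// (obs, reward, terminated, truncated, info); the wrapper extracts what it
// needs by index and ignores the info dict.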
@ -1,144 +0,0 @@
#![allow(unused)]

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

mod gym_env;
mod vec_gym_env;

mod ddpg;

use candle::{Device, Result, Tensor};
use clap::Parser;
use rand::Rng;

// The discount factor: how strongly the next state's q value contributes to
// the current state's q value.
const GAMMA: f64 = 0.99;
// The weight used when updating the target networks towards the online ones.
const TAU: f64 = 0.005;
// The capacity of the replay buffer used for sampling training data.
const REPLAY_BUFFER_CAPACITY: usize = 100_000;
// The training batch size for each training iteration.
const TRAINING_BATCH_SIZE: usize = 100;
// The total number of episodes.
const MAX_EPISODES: usize = 100;
// The maximum length of an episode.
const EPISODE_LENGTH: usize = 200;
// The number of training iterations after one episode finishes.
const TRAINING_ITERATIONS: usize = 200;

// Ornstein-Uhlenbeck process parameters.
const MU: f64 = 0.0;
const THETA: f64 = 0.15;
const SIGMA: f64 = 0.1;

const ACTOR_LEARNING_RATE: f64 = 1e-4;
const CRITIC_LEARNING_RATE: f64 = 1e-3;

#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();

    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    let env = gym_env::GymEnv::new("Pendulum-v1")?;
    println!("action space: {}", env.action_space());
    println!("observation space: {:?}", env.observation_space());

    let size_state = env.observation_space().iter().product::<usize>();
    let size_action = env.action_space();

    let mut agent = ddpg::DDPG::new(
        &Device::Cpu,
        size_state,
        size_action,
        true,
        ACTOR_LEARNING_RATE,
        CRITIC_LEARNING_RATE,
        GAMMA,
        TAU,
        REPLAY_BUFFER_CAPACITY,
        ddpg::OuNoise::new(MU, THETA, SIGMA, size_action)?,
    )?;

    let mut rng = rand::thread_rng();

    for episode in 0..MAX_EPISODES {
        // let mut state = env.reset(episode as u64)?;
        let mut state = env.reset(rng.gen::<u64>())?;

        let mut total_reward = 0.0;
        for _ in 0..EPISODE_LENGTH {
            let mut action = 2.0 * agent.actions(&state)?;
            action = action.clamp(-2.0, 2.0);

            let step = env.step(vec![action])?;
            total_reward += step.reward;

            agent.remember(
                &state,
                &Tensor::new(vec![action], &Device::Cpu)?,
                &Tensor::new(vec![step.reward as f32], &Device::Cpu)?,
                &step.state,
                step.terminated,
                step.truncated,
            );

            if step.terminated || step.truncated {
                break;
            }
            state = step.state;
        }

        println!("episode {episode} with total reward of {total_reward}");

        for _ in 0..TRAINING_ITERATIONS {
            agent.train(TRAINING_BATCH_SIZE)?;
        }
    }

    println!("Testing...");
    agent.train = false;
    for episode in 0..10 {
        // let mut state = env.reset(episode as u64)?;
        let mut state = env.reset(rng.gen::<u64>())?;
        let mut total_reward = 0.0;
        for _ in 0..EPISODE_LENGTH {
            let mut action = 2.0 * agent.actions(&state)?;
            action = action.clamp(-2.0, 2.0);

            let step = env.step(vec![action])?;
            total_reward += step.reward;

            if step.terminated || step.truncated {
                break;
            }
            state = step.state;
        }
        println!("episode {episode} with total reward of {total_reward}");
    }
    Ok(())
}
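
// Pendulum-v1 expects torques in [-2.0, 2.0] while the actor's tanh output is
// in [-1.0, 1.0]; multiplying by 2.0 rescales it, and the clamp keeps the
// OU-noised action inside the environment's valid range.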
Some files were not shown because too many files have changed in this diff