mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Compare commits
203 Commits
w-uncond
...
meshgrid-f
Author | SHA1 | Date | |
---|---|---|---|
20da4f44ef | |||
e4c9adfdbe | |||
b6053b938b | |||
45dbe541bc | |||
7bd0faba75 | |||
807e3f9f52 | |||
eae94a451b | |||
86e1803191 | |||
25c3cc4149 | |||
a11af79e23 | |||
8a82d623e5 | |||
df2f89b6cf | |||
62fc965617 | |||
5b32c2a41e | |||
3115fe42e4 | |||
2531b13bf8 | |||
0d9bb4eb18 | |||
e8f760ee44 | |||
94e3373883 | |||
34d9e91748 | |||
cfb423ab76 | |||
7366aeac21 | |||
99cf13e8e2 | |||
b43ab6cd1d | |||
31ca4897bb | |||
55351ef57d | |||
6684b7127a | |||
93c25e8844 | |||
cd53c472df | |||
6f76383f38 | |||
8e773cc0c6 | |||
87eb1658e1 | |||
902d0b9166 | |||
185b54a33b | |||
620c94d12e | |||
86e7d539d2 | |||
cb034506cd | |||
63c204c79e | |||
767a6578f1 | |||
662c186fd5 | |||
2cd745a97c | |||
a72b50e2c0 | |||
872c3f14b0 | |||
f9e93f5b69 | |||
b355ab4e2e | |||
2fe24ac5b1 | |||
00948eb656 | |||
af67672207 | |||
6c588c4792 | |||
122da87580 | |||
75629981bc | |||
0106b0b04c | |||
588ad4835a | |||
b73c35cc57 | |||
8f310cc666 | |||
8921d5027c | |||
29c7f2565d | |||
9309cfc47d | |||
a193bf5f60 | |||
2c110ac7d9 | |||
75989fc3b7 | |||
07af87a1d8 | |||
eefad2b95f | |||
5e6df4a3f7 | |||
7473c4ceca | |||
c096f02411 | |||
e7560443e4 | |||
89b525b5e7 | |||
37dbbff261 | |||
9fea56d28e | |||
bc3351bce4 | |||
b34d7f0248 | |||
4d04ac83c7 | |||
392fe02fba | |||
59ab6d7832 | |||
783735cf22 | |||
9abeddd750 | |||
2e5fb0b251 | |||
823fe23f9b | |||
d833527fda | |||
a4967600d0 | |||
aa53368aeb | |||
955e00b2e8 | |||
d5f7267087 | |||
904bbdae65 | |||
b0442eff8a | |||
4631c48273 | |||
716883e9b0 | |||
47c25a567b | |||
7f7d95e2c3 | |||
f47bd9bab5 | |||
8f7973958c | |||
f0c619a4af | |||
b86ac0c507 | |||
27e70a5093 | |||
c18a856e76 | |||
3349c89252 | |||
11d3687cc6 | |||
dac73edb34 | |||
b4da19d1be | |||
ff513314fc | |||
043cc25766 | |||
7b06872f90 | |||
65825e7240 | |||
7670fe7d1f | |||
cddfc3944c | |||
089fc3b584 | |||
e04c789230 | |||
263a172202 | |||
638ccf9f46 | |||
0baf5a1e19 | |||
5130a7da32 | |||
41143db1af | |||
096dee7073 | |||
f6054e9d60 | |||
328167ec04 | |||
4e55aaa51f | |||
deee7612da | |||
06207332bc | |||
4021272875 | |||
87e3a4e175 | |||
6203ced495 | |||
34842fb234 | |||
d188d6a764 | |||
0ac2db577b | |||
fc59bc31bf | |||
03348e2e6f | |||
49fa184a35 | |||
6f17ef82be | |||
01b92cd959 | |||
53510ce427 | |||
23b3576c47 | |||
716ab2ccdc | |||
ada8851a23 | |||
c05a348e36 | |||
25657804ef | |||
5e1c595e00 | |||
8a49e01b9d | |||
9cb110c44c | |||
667f01c173 | |||
e59784e353 | |||
29bd6b2979 | |||
9571b200c9 | |||
ce0a4e3a85 | |||
4abc1ea34d | |||
2dd43d6cdd | |||
1fcac4afed | |||
a084f65f9a | |||
c798184c2b | |||
c78a294323 | |||
a36d883254 | |||
7f2bbcf746 | |||
dc47224ab9 | |||
1ce7fe2543 | |||
402ddcfcb4 | |||
f5069dd354 | |||
0007ae9c11 | |||
e15862cfdb | |||
4aeb449017 | |||
bcb0ed8f1c | |||
7edd755756 | |||
e32c89d90c | |||
bb3471ea31 | |||
890d069092 | |||
5dbe46b389 | |||
ccf352f3d1 | |||
402d207f0f | |||
7582937a32 | |||
b54acfa3d0 | |||
cda1786eed | |||
912a3d63b0 | |||
3ef328c53d | |||
0c8e983514 | |||
df6f5240ba | |||
a46b1b4657 | |||
19e52e5007 | |||
8601537e31 | |||
4ac6039a42 | |||
52a60ca3ad | |||
a96878f235 | |||
aa8ec06fd2 | |||
b43ca493f6 | |||
3b557765e8 | |||
2619c4307f | |||
c89b82b2d4 | |||
7b26e513f1 | |||
ab1d40ea97 | |||
3a0d3e05df | |||
9b24d89d2d | |||
fb1c2ac535 | |||
728e167334 | |||
7b1ddcff47 | |||
f685b2231c | |||
c0b49d5a50 | |||
098dd0d1e9 | |||
05626ef492 | |||
67a486d18d | |||
7ad82b87e4 | |||
8696f64bae | |||
d7e48234d4 | |||
34f2ecbc3b | |||
4f91c8e109 | |||
06e46d7c3b |
62
.github/workflows/python.yml
vendored
Normal file
62
.github/workflows/python.yml
vendored
Normal file
@ -0,0 +1,62 @@
|
||||
name: PyO3-CI
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- candle-pyo3/**
|
||||
pull_request:
|
||||
paths:
|
||||
- candle-pyo3/**
|
||||
|
||||
jobs:
|
||||
build_and_test:
|
||||
name: Check everything builds & tests
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest] # For now, only test on Linux
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Install Rust
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: stable
|
||||
|
||||
- name: Install Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.11
|
||||
architecture: "x64"
|
||||
|
||||
- name: Cache Cargo Registry
|
||||
uses: actions/cache@v1
|
||||
with:
|
||||
path: ~/.cargo/registry
|
||||
key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }}
|
||||
|
||||
- name: Install
|
||||
working-directory: ./candle-pyo3
|
||||
run: |
|
||||
python -m venv .env
|
||||
source .env/bin/activate
|
||||
pip install -U pip
|
||||
pip install pytest maturin black
|
||||
python -m maturin develop -r
|
||||
|
||||
- name: Check style
|
||||
working-directory: ./candle-pyo3
|
||||
run: |
|
||||
source .env/bin/activate
|
||||
python stub.py --check
|
||||
black --check .
|
||||
|
||||
- name: Run tests
|
||||
working-directory: ./candle-pyo3
|
||||
run: |
|
||||
source .env/bin/activate
|
||||
python -m pytest -s -v tests
|
8
.gitignore
vendored
8
.gitignore
vendored
@ -23,14 +23,16 @@ flamegraph.svg
|
||||
*.dylib
|
||||
*.so
|
||||
*.swp
|
||||
*.swo
|
||||
trace-*.json
|
||||
|
||||
candle-wasm-examples/*/build
|
||||
candle-wasm-examples/*/*.bin
|
||||
candle-wasm-examples/*/*.jpeg
|
||||
candle-wasm-examples/*/*.wav
|
||||
candle-wasm-examples/*/*.safetensors
|
||||
candle-wasm-examples/*/audios/*.wav
|
||||
candle-wasm-examples/**/*.safetensors
|
||||
candle-wasm-examples/**/*.gguf
|
||||
candle-wasm-examples/*/package-lock.json
|
||||
|
||||
candle-wasm-examples/**/config*.json
|
||||
.DS_Store
|
||||
.idea/*
|
||||
|
11
.vscode/settings.json
vendored
Normal file
11
.vscode/settings.json
vendored
Normal file
@ -0,0 +1,11 @@
|
||||
{
|
||||
"[python]": {
|
||||
"editor.defaultFormatter": "ms-python.black-formatter"
|
||||
},
|
||||
"python.formatting.provider": "none",
|
||||
"python.testing.pytestArgs": [
|
||||
"candle-pyo3"
|
||||
],
|
||||
"python.testing.unittestEnabled": false,
|
||||
"python.testing.pytestEnabled": true
|
||||
}
|
28
CHANGELOG.md
28
CHANGELOG.md
@ -1,12 +1,38 @@
|
||||
# Changelog
|
||||
This documents the main changes to the `candle` crate.
|
||||
|
||||
## v0.2.3 - Unreleased
|
||||
## v0.3.1 - Unreleased
|
||||
|
||||
### Added
|
||||
|
||||
### Modified
|
||||
|
||||
## v0.3.0 - 2023-10-01
|
||||
|
||||
### Added
|
||||
|
||||
- Added the Mistral 7b v0.1 model
|
||||
[983](https://github.com/huggingface/candle/pull/983).
|
||||
- Quantized version of the Mistral model
|
||||
[1009](https://github.com/huggingface/candle/pull/1009).
|
||||
- Add the gelu-erf op and activation function
|
||||
[969](https://github.com/huggingface/candle/pull/969).
|
||||
- Add the mixformer/phi-v1.5 model
|
||||
[930](https://github.com/huggingface/candle/pull/930).
|
||||
- Add the sclice-scatter op
|
||||
[927](https://github.com/huggingface/candle/pull/927).
|
||||
- Add the Wuerstchen diffusion model
|
||||
[911](https://github.com/huggingface/candle/pull/911).
|
||||
|
||||
### Modified
|
||||
|
||||
- Support for simd128 intrinsics in some quantized vecdots
|
||||
[982](https://github.com/huggingface/candle/pull/982).
|
||||
- Optimize the index-select cuda kernel
|
||||
[976](https://github.com/huggingface/candle/pull/976).
|
||||
- Self-contained safetensor wrappers
|
||||
[946](https://github.com/huggingface/candle/pull/946).
|
||||
|
||||
## v0.2.2 - 2023-09-18
|
||||
|
||||
### Added
|
||||
|
19
Cargo.toml
19
Cargo.toml
@ -11,15 +11,16 @@ members = [
|
||||
"candle-wasm-examples/segment-anything",
|
||||
"candle-wasm-examples/whisper",
|
||||
"candle-wasm-examples/yolo",
|
||||
"candle-wasm-examples/bert",
|
||||
"candle-wasm-examples/phi",
|
||||
"candle-wasm-examples/t5",
|
||||
"candle-wasm-tests",
|
||||
]
|
||||
exclude = [
|
||||
"candle-flash-attn",
|
||||
"candle-kernels",
|
||||
]
|
||||
exclude = ["candle-flash-attn", "candle-kernels"]
|
||||
resolver = "2"
|
||||
|
||||
[workspace.package]
|
||||
version = "0.2.3"
|
||||
version = "0.3.0"
|
||||
edition = "2021"
|
||||
description = "Minimalist ML framework."
|
||||
repository = "https://github.com/huggingface/candle"
|
||||
@ -33,8 +34,7 @@ anyhow = { version = "1", features = ["backtrace"] }
|
||||
byteorder = "1.4.3"
|
||||
clap = { version = "4.2.4", features = ["derive"] }
|
||||
cudarc = { version = "0.9.14", features = ["f16"] }
|
||||
# TODO: Switch back to the official gemm implementation once it has caught up.
|
||||
gemm = { version = "0.16.0", package = "candle-gemm" }
|
||||
gemm = { version = "0.16.6", features = ["wasm-simd128-enable"] }
|
||||
hf-hub = "0.3.0"
|
||||
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
|
||||
image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
|
||||
@ -42,9 +42,10 @@ imageproc = { version = "0.23.0", default-features = false }
|
||||
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
|
||||
libc = { version = "0.2.147" }
|
||||
log = "0.4"
|
||||
memmap2 = "0.7.1"
|
||||
memmap2 = { version = "0.7.1", features = ["stable_deref_trait"] }
|
||||
num_cpus = "1.15.0"
|
||||
num-traits = "0.2.15"
|
||||
parquet = { version = "45.0.0" }
|
||||
rand = "0.8.5"
|
||||
rand_distr = "0.4.3"
|
||||
rayon = "1.7.0"
|
||||
@ -58,8 +59,8 @@ tracing = "0.1.37"
|
||||
tracing-chrome = "0.7.1"
|
||||
tracing-subscriber = "0.3.7"
|
||||
wav = "1.0.0"
|
||||
yoke = { version = "0.7.2", features = ["derive"] }
|
||||
zip = { version = "0.6.6", default-features = false }
|
||||
parquet = { version = "45.0.0" }
|
||||
|
||||
[profile.release-with-debug]
|
||||
inherits = "release"
|
||||
|
49
README.md
49
README.md
@ -8,6 +8,7 @@ Candle is a minimalist ML framework for Rust with a focus on performance (includ
|
||||
and ease of use. Try our online demos:
|
||||
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
|
||||
[LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2),
|
||||
[T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm),
|
||||
[yolo](https://huggingface.co/spaces/lmz/candle-yolo),
|
||||
[Segment
|
||||
Anything](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).
|
||||
@ -52,14 +53,21 @@ These online demos run entirely in your browser:
|
||||
object recognition.
|
||||
- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): text to speech.
|
||||
- [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2): text generation.
|
||||
- [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm): text generation.
|
||||
- [Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation.
|
||||
- [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm): Image segmentation.
|
||||
|
||||
We also provide a some command line based examples using state of the art models:
|
||||
|
||||
- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM.
|
||||
- [Falcon](./candle-examples/examples/falcon/): general LLM.
|
||||
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code
|
||||
generation.
|
||||
- [Phi-v1 and Phi-v1.5](./candle-examples/examples/phi/): a 1.3b general LLM with performance on par with LLaMA-v2 7b.
|
||||
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
|
||||
pre-trained on 1T tokens of English and code datasets.
|
||||
- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
|
||||
performance larger than all publicly available 13b models as of 2023-09-28.
|
||||
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
|
||||
- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
|
||||
- [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of
|
||||
the LLaMA model using the same quantization techniques as
|
||||
[llama.cpp](https://github.com/ggerganov/llama.cpp).
|
||||
@ -71,6 +79,11 @@ We also provide a some command line based examples using state of the art models
|
||||
|
||||
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg" width="200">
|
||||
|
||||
- [Wuerstchen](./candle-examples/examples/wuerstchen/): another text to
|
||||
image generative model.
|
||||
|
||||
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/wuerstchen/assets/cat.jpg" width="200">
|
||||
|
||||
- [yolo-v3](./candle-examples/examples/yolo-v3/) and
|
||||
[yolo-v8](./candle-examples/examples/yolo-v8/): object detection and pose
|
||||
estimation models.
|
||||
@ -86,6 +99,8 @@ We also provide a some command line based examples using state of the art models
|
||||
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
|
||||
using self-supervision (can be used for imagenet classification, depth
|
||||
evaluation, segmentation).
|
||||
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
|
||||
generate captions for an image.
|
||||
|
||||
Run them using commands like:
|
||||
```
|
||||
@ -100,6 +115,8 @@ There are also some wasm examples for whisper and
|
||||
`trunk` or try them online:
|
||||
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
|
||||
[llama2](https://huggingface.co/spaces/lmz/candle-llama2),
|
||||
[T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm),
|
||||
[Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm),
|
||||
[Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).
|
||||
|
||||
For LLaMA2, run the following command to retrieve the weight files and start a
|
||||
@ -115,8 +132,11 @@ And then head over to
|
||||
|
||||
<!--- ANCHOR: useful_libraries --->
|
||||
|
||||
## Useful Libraries
|
||||
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora) provides a LoRA implementation that conforms to the official `peft` implementation.
|
||||
## Useful External Resources
|
||||
- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): a
|
||||
very detailed tutorial showing how to convert a PyTorch model to Candle.
|
||||
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): a LoRA implementation
|
||||
that conforms to the official `peft` implementation.
|
||||
|
||||
If you have an addition to this list, please submit a pull request.
|
||||
|
||||
@ -138,15 +158,21 @@ If you have an addition to this list, please submit a pull request.
|
||||
- LLaMA v1 and v2.
|
||||
- Falcon.
|
||||
- StarCoder.
|
||||
- Phi v1.5.
|
||||
- Mistral 7b v0.1.
|
||||
- StableLM-3B-4E1T.
|
||||
- Replit-code-v1.5-3B.
|
||||
- T5.
|
||||
- Bert.
|
||||
- Whisper (multi-lingual support).
|
||||
- Stable Diffusion v1.5, v2.1, XL v1.0.
|
||||
- Text to image.
|
||||
- Stable Diffusion v1.5, v2.1, XL v1.0.
|
||||
- Wurstchen v2.
|
||||
- Image to text.
|
||||
- BLIP.
|
||||
- Computer Vision Models.
|
||||
- DINOv2.
|
||||
- EfficientNet.
|
||||
- yolo-v3.
|
||||
- yolo-v8.
|
||||
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT.
|
||||
- yolo-v3, yolo-v8.
|
||||
- Segment-Anything Model (SAM).
|
||||
- File formats: load models from safetensors, npz, ggml, or PyTorch files.
|
||||
- Serverless (on CPU), small and fast deployments.
|
||||
@ -306,6 +332,11 @@ mdbook test candle-book -L .\target\debug\deps\ `
|
||||
-L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.48.5\lib
|
||||
```
|
||||
|
||||
#### Extremely slow model load time with WSL
|
||||
|
||||
This may be caused by the models being loaded from `/mnt/c`, more details on
|
||||
[stackoverflow](https://stackoverflow.com/questions/68972448/why-is-wsl-extremely-slow-when-compared-with-native-windows-npm-yarn-processing).
|
||||
|
||||
#### Tracking down errors
|
||||
|
||||
You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle
|
||||
|
@ -11,11 +11,11 @@ readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
|
||||
candle-datasets = { path = "../candle-datasets", version = "0.2.3" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.2.3" }
|
||||
candle-transformers = { path = "../candle-transformers", version = "0.2.3" }
|
||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.3", optional = true }
|
||||
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
||||
candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.0" }
|
||||
candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
|
||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
|
||||
safetensors = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
@ -24,9 +24,10 @@ intel-mkl-src = { workspace = true, optional = true }
|
||||
cudarc = { workspace = true, optional = true }
|
||||
half = { workspace = true, optional = true }
|
||||
image = { workspace = true, optional = true }
|
||||
anyhow = { workspace = true }
|
||||
tokio = "1.29.1"
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
byteorder = { workspace = true }
|
||||
hf-hub = { workspace = true, features=["tokio"]}
|
||||
clap = { workspace = true }
|
||||
@ -38,7 +39,6 @@ tracing-chrome = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
wav = { workspace = true }
|
||||
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
|
||||
tokio = "1.29.1"
|
||||
parquet = { workspace = true }
|
||||
image = { workspace = true }
|
||||
|
||||
|
@ -14,6 +14,7 @@
|
||||
- [Using the hub](inference/hub.md)
|
||||
- [Error management](error_manage.md)
|
||||
- [Training](training/training.md)
|
||||
- [Simplified](training/simplified.md)
|
||||
- [MNIST](training/mnist.md)
|
||||
- [Fine-tuning]()
|
||||
- [Serialization]()
|
||||
|
@ -12,6 +12,9 @@ compute_cap
|
||||
8.9
|
||||
```
|
||||
|
||||
You can also compile the Cuda kernels for a specific compute cap using the
|
||||
`CUDA_COMPUTE_CAP=<compute cap>` environment variable.
|
||||
|
||||
If any of the above commands errors out, please make sure to update your Cuda version.
|
||||
|
||||
2. Create a new app and add [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) with Cuda support.
|
||||
|
@ -1,3 +1,6 @@
|
||||
#[cfg(test)]
|
||||
pub mod simplified;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::Result;
|
||||
|
196
candle-book/src/simplified.rs
Normal file
196
candle-book/src/simplified.rs
Normal file
@ -0,0 +1,196 @@
|
||||
//! #A simplified example in Rust of training a neural network and then using it based on the Candle Framework by Hugging Face.
|
||||
//! Author: Evgeny Igumnov 2023 igumnovnsk@gmail.com
|
||||
//! This program implements a neural network to predict the winner of the second round of elections based on the results of the first round.
|
||||
//!
|
||||
//! ##Basic moments:
|
||||
//!
|
||||
//! A multilayer perceptron with two hidden layers is used. The first hidden layer has 4 neurons, the second has 2 neurons.
|
||||
//! The input is a vector of 2 numbers - the percentage of votes for the first and second candidates in the first stage.
|
||||
//! The output is the number 0 or 1, where 1 means that the first candidate will win in the second stage, 0 means that he will lose.
|
||||
//! For training, samples with real data on the results of the first and second stages of different elections are used.
|
||||
//! The model is trained by backpropagation using gradient descent and the cross-entropy loss function.
|
||||
//! Model parameters (weights of neurons) are initialized randomly, then optimized during training.
|
||||
//! After training, the model is tested on a deferred sample to evaluate the accuracy.
|
||||
//! If the accuracy on the test set is below 100%, the model is considered underfit and the learning process is repeated.
|
||||
//! Thus, this neural network learns to find hidden relationships between the results of the first and second rounds of voting in order to make predictions for new data.
|
||||
|
||||
#[rustfmt::skip]
|
||||
mod tests {
|
||||
|
||||
use candle::{DType, Result, Tensor, D, Device};
|
||||
use candle_nn::{loss, ops, Linear, Module, VarBuilder, VarMap, Optimizer};
|
||||
|
||||
// ANCHOR: book_training_simplified1
|
||||
const VOTE_DIM: usize = 2;
|
||||
const RESULTS: usize = 1;
|
||||
const EPOCHS: usize = 10;
|
||||
const LAYER1_OUT_SIZE: usize = 4;
|
||||
const LAYER2_OUT_SIZE: usize = 2;
|
||||
const LEARNING_RATE: f64 = 0.05;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Dataset {
|
||||
pub train_votes: Tensor,
|
||||
pub train_results: Tensor,
|
||||
pub test_votes: Tensor,
|
||||
pub test_results: Tensor,
|
||||
}
|
||||
|
||||
struct MultiLevelPerceptron {
|
||||
ln1: Linear,
|
||||
ln2: Linear,
|
||||
ln3: Linear,
|
||||
}
|
||||
|
||||
impl MultiLevelPerceptron {
|
||||
fn new(vs: VarBuilder) -> Result<Self> {
|
||||
let ln1 = candle_nn::linear(VOTE_DIM, LAYER1_OUT_SIZE, vs.pp("ln1"))?;
|
||||
let ln2 = candle_nn::linear(LAYER1_OUT_SIZE, LAYER2_OUT_SIZE, vs.pp("ln2"))?;
|
||||
let ln3 = candle_nn::linear(LAYER2_OUT_SIZE, RESULTS + 1, vs.pp("ln3"))?;
|
||||
Ok(Self { ln1, ln2, ln3 })
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = self.ln1.forward(xs)?;
|
||||
let xs = xs.relu()?;
|
||||
let xs = self.ln2.forward(&xs)?;
|
||||
let xs = xs.relu()?;
|
||||
self.ln3.forward(&xs)
|
||||
}
|
||||
}
|
||||
|
||||
// ANCHOR_END: book_training_simplified1
|
||||
|
||||
|
||||
|
||||
// ANCHOR: book_training_simplified3
|
||||
#[tokio::test]
|
||||
async fn simplified() -> anyhow::Result<()> {
|
||||
|
||||
let dev = Device::cuda_if_available(0)?;
|
||||
|
||||
let train_votes_vec: Vec<u32> = vec![
|
||||
15, 10,
|
||||
10, 15,
|
||||
5, 12,
|
||||
30, 20,
|
||||
16, 12,
|
||||
13, 25,
|
||||
6, 14,
|
||||
31, 21,
|
||||
];
|
||||
let train_votes_tensor = Tensor::from_vec(train_votes_vec.clone(), (train_votes_vec.len() / VOTE_DIM, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;
|
||||
|
||||
let train_results_vec: Vec<u32> = vec![
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
];
|
||||
let train_results_tensor = Tensor::from_vec(train_results_vec, train_votes_vec.len() / VOTE_DIM, &dev)?;
|
||||
|
||||
let test_votes_vec: Vec<u32> = vec![
|
||||
13, 9,
|
||||
8, 14,
|
||||
3, 10,
|
||||
];
|
||||
let test_votes_tensor = Tensor::from_vec(test_votes_vec.clone(), (test_votes_vec.len() / VOTE_DIM, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;
|
||||
|
||||
let test_results_vec: Vec<u32> = vec![
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
];
|
||||
let test_results_tensor = Tensor::from_vec(test_results_vec.clone(), test_results_vec.len(), &dev)?;
|
||||
|
||||
let m = Dataset {
|
||||
train_votes: train_votes_tensor,
|
||||
train_results: train_results_tensor,
|
||||
test_votes: test_votes_tensor,
|
||||
test_results: test_results_tensor,
|
||||
};
|
||||
|
||||
let trained_model: MultiLevelPerceptron;
|
||||
loop {
|
||||
println!("Trying to train neural network.");
|
||||
match train(m.clone(), &dev) {
|
||||
Ok(model) => {
|
||||
trained_model = model;
|
||||
break;
|
||||
},
|
||||
Err(e) => {
|
||||
println!("Error: {}", e);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
let real_world_votes: Vec<u32> = vec![
|
||||
13, 22,
|
||||
];
|
||||
|
||||
let tensor_test_votes = Tensor::from_vec(real_world_votes.clone(), (1, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;
|
||||
|
||||
let final_result = trained_model.forward(&tensor_test_votes)?;
|
||||
|
||||
let result = final_result
|
||||
.argmax(D::Minus1)?
|
||||
.to_dtype(DType::F32)?
|
||||
.get(0).map(|x| x.to_scalar::<f32>())??;
|
||||
println!("real_life_votes: {:?}", real_world_votes);
|
||||
println!("neural_network_prediction_result: {:?}", result);
|
||||
|
||||
Ok(())
|
||||
|
||||
}
|
||||
// ANCHOR_END: book_training_simplified3
|
||||
|
||||
// ANCHOR: book_training_simplified2
|
||||
fn train(m: Dataset, dev: &Device) -> anyhow::Result<MultiLevelPerceptron> {
|
||||
let train_results = m.train_results.to_device(dev)?;
|
||||
let train_votes = m.train_votes.to_device(dev)?;
|
||||
let varmap = VarMap::new();
|
||||
let vs = VarBuilder::from_varmap(&varmap, DType::F32, dev);
|
||||
let model = MultiLevelPerceptron::new(vs.clone())?;
|
||||
let mut sgd = candle_nn::SGD::new(varmap.all_vars(), LEARNING_RATE)?;
|
||||
let test_votes = m.test_votes.to_device(dev)?;
|
||||
let test_results = m.test_results.to_device(dev)?;
|
||||
let mut final_accuracy: f32 = 0.0;
|
||||
for epoch in 1..EPOCHS + 1 {
|
||||
let logits = model.forward(&train_votes)?;
|
||||
let log_sm = ops::log_softmax(&logits, D::Minus1)?;
|
||||
let loss = loss::nll(&log_sm, &train_results)?;
|
||||
sgd.backward_step(&loss)?;
|
||||
|
||||
let test_logits = model.forward(&test_votes)?;
|
||||
let sum_ok = test_logits
|
||||
.argmax(D::Minus1)?
|
||||
.eq(&test_results)?
|
||||
.to_dtype(DType::F32)?
|
||||
.sum_all()?
|
||||
.to_scalar::<f32>()?;
|
||||
let test_accuracy = sum_ok / test_results.dims1()? as f32;
|
||||
final_accuracy = 100. * test_accuracy;
|
||||
println!("Epoch: {epoch:3} Train loss: {:8.5} Test accuracy: {:5.2}%",
|
||||
loss.to_scalar::<f32>()?,
|
||||
final_accuracy
|
||||
);
|
||||
if final_accuracy == 100.0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if final_accuracy < 100.0 {
|
||||
Err(anyhow::Error::msg("The model is not trained well enough."))
|
||||
} else {
|
||||
Ok(model)
|
||||
}
|
||||
}
|
||||
// ANCHOR_END: book_training_simplified2
|
||||
|
||||
|
||||
}
|
45
candle-book/src/training/simplified.md
Normal file
45
candle-book/src/training/simplified.md
Normal file
@ -0,0 +1,45 @@
|
||||
# Simplified
|
||||
|
||||
## How its works
|
||||
|
||||
This program implements a neural network to predict the winner of the second round of elections based on the results of the first round.
|
||||
|
||||
Basic moments:
|
||||
|
||||
1. A multilayer perceptron with two hidden layers is used. The first hidden layer has 4 neurons, the second has 2 neurons.
|
||||
2. The input is a vector of 2 numbers - the percentage of votes for the first and second candidates in the first stage.
|
||||
3. The output is the number 0 or 1, where 1 means that the first candidate will win in the second stage, 0 means that he will lose.
|
||||
4. For training, samples with real data on the results of the first and second stages of different elections are used.
|
||||
5. The model is trained by backpropagation using gradient descent and the cross-entropy loss function.
|
||||
6. Model parameters (weights of neurons) are initialized randomly, then optimized during training.
|
||||
7. After training, the model is tested on a deferred sample to evaluate the accuracy.
|
||||
8. If the accuracy on the test set is below 100%, the model is considered underfit and the learning process is repeated.
|
||||
|
||||
Thus, this neural network learns to find hidden relationships between the results of the first and second rounds of voting in order to make predictions for new data.
|
||||
|
||||
|
||||
```rust,ignore
|
||||
{{#include ../simplified.rs:book_training_simplified1}}
|
||||
```
|
||||
|
||||
```rust,ignore
|
||||
{{#include ../simplified.rs:book_training_simplified2}}
|
||||
```
|
||||
|
||||
```rust,ignore
|
||||
{{#include ../simplified.rs:book_training_simplified3}}
|
||||
```
|
||||
|
||||
|
||||
## Example output
|
||||
|
||||
```bash
|
||||
Trying to train neural network.
|
||||
Epoch: 1 Train loss: 4.42555 Test accuracy: 0.00%
|
||||
Epoch: 2 Train loss: 0.84677 Test accuracy: 33.33%
|
||||
Epoch: 3 Train loss: 2.54335 Test accuracy: 33.33%
|
||||
Epoch: 4 Train loss: 0.37806 Test accuracy: 33.33%
|
||||
Epoch: 5 Train loss: 0.36647 Test accuracy: 100.00%
|
||||
real_life_votes: [13, 22]
|
||||
neural_network_prediction_result: 0.0
|
||||
```
|
@ -12,7 +12,7 @@ readme = "README.md"
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
byteorder = { workspace = true }
|
||||
candle-kernels = { path = "../candle-kernels", version = "0.2.3", optional = true }
|
||||
candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
|
||||
cudarc = { workspace = true, optional = true }
|
||||
gemm = { workspace = true }
|
||||
half = { workspace = true }
|
||||
@ -26,6 +26,7 @@ rand_distr = { workspace = true }
|
||||
rayon = { workspace = true }
|
||||
safetensors = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
yoke = { workspace = true }
|
||||
zip = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
|
@ -103,8 +103,10 @@ enum Command {
|
||||
|
||||
Quantize {
|
||||
/// The input file, in gguf format.
|
||||
in_file: std::path::PathBuf,
|
||||
in_file: Vec<std::path::PathBuf>,
|
||||
|
||||
/// The output file, in gguf format.
|
||||
#[arg(long)]
|
||||
out_file: std::path::PathBuf,
|
||||
|
||||
/// The quantization schema to apply.
|
||||
@ -150,8 +152,7 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
|
||||
}
|
||||
}
|
||||
Format::Safetensors => {
|
||||
let tensors = unsafe { candle_core::safetensors::MmapedFile::new(file)? };
|
||||
let tensors = tensors.deserialize()?;
|
||||
let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::new(file)? };
|
||||
let mut tensors = tensors.tensors();
|
||||
tensors.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
for (name, view) in tensors.iter() {
|
||||
@ -218,15 +219,99 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_quantize_safetensors(
|
||||
in_files: &[std::path::PathBuf],
|
||||
out_file: std::path::PathBuf,
|
||||
q: Quantization,
|
||||
) -> Result<()> {
|
||||
let mut out_file = std::fs::File::create(out_file)?;
|
||||
let mut tensors = std::collections::HashMap::new();
|
||||
for in_file in in_files.iter() {
|
||||
let in_tensors = candle_core::safetensors::load(in_file, &Device::Cpu)?;
|
||||
tensors.extend(in_tensors)
|
||||
}
|
||||
println!("tensors: {}", tensors.len());
|
||||
|
||||
let quantize_fn = match q {
|
||||
Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
|
||||
Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
|
||||
Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
|
||||
Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
|
||||
Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
|
||||
Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
|
||||
Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
|
||||
Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
|
||||
Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
|
||||
Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
|
||||
Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
|
||||
Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
|
||||
Quantization::F16 => QTensor::quantize::<half::f16>,
|
||||
Quantization::F32 => QTensor::quantize::<f32>,
|
||||
};
|
||||
let block_size = match q {
|
||||
Quantization::Q4_0 => k_quants::QK4_0,
|
||||
Quantization::Q4_1 => k_quants::QK4_1,
|
||||
Quantization::Q5_0 => k_quants::QK5_0,
|
||||
Quantization::Q5_1 => k_quants::QK5_1,
|
||||
Quantization::Q8_0 => k_quants::QK8_0,
|
||||
Quantization::Q8_1 => k_quants::QK8_1,
|
||||
Quantization::Q2k
|
||||
| Quantization::Q3k
|
||||
| Quantization::Q4k
|
||||
| Quantization::Q5k
|
||||
| Quantization::Q6k
|
||||
| Quantization::Q8k => k_quants::QK_K,
|
||||
Quantization::F16 | Quantization::F32 => 1,
|
||||
};
|
||||
|
||||
let qtensors = tensors
|
||||
.into_par_iter()
|
||||
.map(|(name, tensor)| {
|
||||
let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0;
|
||||
println!(" quantizing {name} {tensor:?} {should_quantize}");
|
||||
let tensor = if should_quantize {
|
||||
quantize_fn(&tensor)?
|
||||
} else {
|
||||
QTensor::quantize::<f32>(&tensor)?
|
||||
};
|
||||
Ok((name, tensor))
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let qtensors = qtensors
|
||||
.iter()
|
||||
.map(|(k, v)| (k.as_str(), v))
|
||||
.collect::<Vec<_>>();
|
||||
gguf_file::write(&mut out_file, &[], &qtensors)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_quantize(
|
||||
in_file: std::path::PathBuf,
|
||||
in_files: &[std::path::PathBuf],
|
||||
out_file: std::path::PathBuf,
|
||||
q: Quantization,
|
||||
qmode: QuantizationMode,
|
||||
) -> Result<()> {
|
||||
if in_files.is_empty() {
|
||||
candle_core::bail!("no specified input files")
|
||||
}
|
||||
if let Some(extension) = out_file.extension() {
|
||||
if extension == "safetensors" {
|
||||
candle_core::bail!("the generated file cannot use the safetensors extension")
|
||||
}
|
||||
}
|
||||
if let Some(extension) = in_files[0].extension() {
|
||||
if extension == "safetensors" {
|
||||
return run_quantize_safetensors(in_files, out_file, q);
|
||||
}
|
||||
}
|
||||
|
||||
if in_files.len() != 1 {
|
||||
candle_core::bail!("only a single in-file can be used when quantizing gguf files")
|
||||
}
|
||||
|
||||
// Open the out file early so as to fail directly on missing directories etc.
|
||||
let mut out_file = std::fs::File::create(out_file)?;
|
||||
let mut in_ = std::fs::File::open(&in_file)?;
|
||||
let mut in_ = std::fs::File::open(&in_files[0])?;
|
||||
let content = gguf_file::Content::read(&mut in_)?;
|
||||
println!("tensors: {}", content.tensor_infos.len());
|
||||
|
||||
@ -252,7 +337,7 @@ fn run_quantize(
|
||||
.par_iter()
|
||||
.map(|(name, _)| {
|
||||
println!(" quantizing {name}");
|
||||
let mut in_file = std::fs::File::open(&in_file)?;
|
||||
let mut in_file = std::fs::File::open(&in_files[0])?;
|
||||
let tensor = content.tensor(&mut in_file, name)?;
|
||||
let tensor = qmode.quantize(name, tensor, quantize_fn)?;
|
||||
Ok((name, tensor))
|
||||
@ -293,7 +378,7 @@ fn main() -> anyhow::Result<()> {
|
||||
out_file,
|
||||
quantization,
|
||||
mode,
|
||||
} => run_quantize(in_file, out_file, quantization, mode)?,
|
||||
} => run_quantize(&in_file, out_file, quantization, mode)?,
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -111,4 +111,6 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
|
||||
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
|
||||
|
||||
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
|
||||
|
||||
fn set_seed(&self, _: u64) -> Result<()>;
|
||||
}
|
||||
|
@ -36,6 +36,8 @@ impl Tensor {
|
||||
// Do not call recursively on the "leaf" nodes.
|
||||
track_grad = true;
|
||||
nodes
|
||||
} else if node.dtype().is_int() {
|
||||
nodes
|
||||
} else if let Some(op) = node.op() {
|
||||
match op {
|
||||
Op::IndexAdd(t1, t2, t3, _)
|
||||
@ -69,7 +71,8 @@ impl Tensor {
|
||||
| Op::Binary(lhs, rhs, _)
|
||||
| Op::Gather(lhs, rhs, _)
|
||||
| Op::IndexSelect(lhs, rhs, _)
|
||||
| Op::Matmul(lhs, rhs) => {
|
||||
| Op::Matmul(lhs, rhs)
|
||||
| Op::SliceScatter0(lhs, rhs, _) => {
|
||||
let (tg, nodes) = walk(lhs, nodes, already_seen);
|
||||
track_grad |= tg;
|
||||
let (tg, nodes) = walk(rhs, nodes, already_seen);
|
||||
@ -90,6 +93,9 @@ impl Tensor {
|
||||
nodes
|
||||
}
|
||||
}
|
||||
Op::Unary(_node, UnaryOp::Ceil)
|
||||
| Op::Unary(_node, UnaryOp::Floor)
|
||||
| Op::Unary(_node, UnaryOp::Round) => nodes,
|
||||
Op::Reshape(node)
|
||||
| Op::UpsampleNearest1D(node)
|
||||
| Op::UpsampleNearest2D(node)
|
||||
@ -99,7 +105,6 @@ impl Tensor {
|
||||
| Op::Broadcast(node)
|
||||
| Op::Cmp(node, _)
|
||||
| Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
|
||||
| Op::ToDType(node)
|
||||
| Op::ToDevice(node)
|
||||
| Op::Transpose(node, _, _)
|
||||
| Op::Permute(node, _)
|
||||
@ -112,6 +117,15 @@ impl Tensor {
|
||||
track_grad |= tg;
|
||||
nodes
|
||||
}
|
||||
Op::ToDType(node) => {
|
||||
if node.dtype().is_float() {
|
||||
let (tg, nodes) = walk(node, nodes, already_seen);
|
||||
track_grad |= tg;
|
||||
nodes
|
||||
} else {
|
||||
nodes
|
||||
}
|
||||
}
|
||||
Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
|
||||
}
|
||||
} else {
|
||||
@ -270,6 +284,15 @@ impl Tensor {
|
||||
Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
|
||||
op: "upsample-nearest2d",
|
||||
})?,
|
||||
Op::SliceScatter0(lhs, rhs, start_rhs) => {
|
||||
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||
let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
|
||||
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
|
||||
|
||||
let lhs_sum_grad = grads.or_insert(lhs)?;
|
||||
let lhs_grad = grad.slice_scatter0(&rhs.zeros_like()?, *start_rhs)?;
|
||||
*lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?
|
||||
}
|
||||
Op::Gather(arg, indexes, dim) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
|
||||
@ -361,7 +384,7 @@ impl Tensor {
|
||||
}
|
||||
Op::ToDType(arg) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
|
||||
*sum_grad = sum_grad.add(&grad.to_dtype(arg.dtype())?)?
|
||||
}
|
||||
Op::Copy(arg) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
@ -441,7 +464,26 @@ impl Tensor {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
}
|
||||
Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
|
||||
Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?,
|
||||
Op::Unary(_, UnaryOp::Floor) => {
|
||||
Err(Error::BackwardNotSupported { op: "floor" })?
|
||||
}
|
||||
Op::Unary(_, UnaryOp::Round) => {
|
||||
Err(Error::BackwardNotSupported { op: "round" })?
|
||||
}
|
||||
Op::Unary(arg, UnaryOp::Gelu) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
let cube = arg.powf(3.)?;
|
||||
let tanh = (0.0356774 * &cube + (0.797885 * arg)?)?.tanh()?;
|
||||
let gelu_grad = (((0.5 * &tanh)?
|
||||
+ (0.0535161 * cube + (0.398942 * arg)?)? * (1. - tanh.powf(2.)?))?
|
||||
+ 0.5)?;
|
||||
*sum_grad = sum_grad.add(&(&grad * gelu_grad)?)?
|
||||
}
|
||||
Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
|
||||
Op::Unary(_, UnaryOp::GeluErf) => {
|
||||
Err(Error::BackwardNotSupported { op: "gelu-erf" })?
|
||||
}
|
||||
Op::Unary(arg, UnaryOp::Relu) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
|
||||
|
@ -25,6 +25,19 @@ impl ParamsConv1D {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum CudnnFwdAlgo {
|
||||
ImplicitGemm,
|
||||
ImplicitPrecompGemm,
|
||||
Gemm,
|
||||
Direct,
|
||||
Fft,
|
||||
FftTiling,
|
||||
Winograd,
|
||||
WinogradNonFused,
|
||||
Count,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ParamsConv2D {
|
||||
pub(crate) b_size: usize,
|
||||
@ -37,6 +50,7 @@ pub struct ParamsConv2D {
|
||||
pub(crate) padding: usize,
|
||||
pub(crate) stride: usize,
|
||||
pub(crate) dilation: usize,
|
||||
pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
|
||||
}
|
||||
|
||||
impl ParamsConv2D {
|
||||
@ -188,6 +202,7 @@ impl Tensor {
|
||||
padding,
|
||||
stride,
|
||||
dilation,
|
||||
cudnn_fwd_algo: None,
|
||||
};
|
||||
if groups == 1 {
|
||||
self.conv2d_single_group(kernel, ¶ms)
|
||||
|
763
candle-core/src/cpu/erf.rs
Normal file
763
candle-core/src/cpu/erf.rs
Normal file
@ -0,0 +1,763 @@
|
||||
#![allow(clippy::excessive_precision)]
|
||||
// Code taken from https://github.com/statrs-dev/statrs
|
||||
//! Provides the [error](https://en.wikipedia.org/wiki/Error_function) and
|
||||
//! related functions
|
||||
|
||||
mod evaluate {
|
||||
//! Provides functions that don't have a numerical solution and must
|
||||
//! be solved computationally (e.g. evaluation of a polynomial)
|
||||
|
||||
/// evaluates a polynomial at `z` where `coeff` are the coeffecients
|
||||
/// to a polynomial of order `k` where `k` is the length of `coeff` and the
|
||||
/// coeffecient
|
||||
/// to the `k`th power is the `k`th element in coeff. E.g. [3,-1,2] equates to
|
||||
/// `2z^2 - z + 3`
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// Returns 0 for a 0 length coefficient slice
|
||||
pub fn polynomial(z: f64, coeff: &[f64]) -> f64 {
|
||||
let n = coeff.len();
|
||||
if n == 0 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let mut sum = *coeff.last().unwrap();
|
||||
for c in coeff[0..n - 1].iter().rev() {
|
||||
sum = *c + z * sum;
|
||||
}
|
||||
sum
|
||||
}
|
||||
}
|
||||
use std::f64;
|
||||
|
||||
/// `erf` calculates the error function at `x`.
|
||||
pub fn erf(x: f64) -> f64 {
|
||||
if x.is_nan() {
|
||||
f64::NAN
|
||||
} else if x >= 0.0 && x.is_infinite() {
|
||||
1.0
|
||||
} else if x <= 0.0 && x.is_infinite() {
|
||||
-1.0
|
||||
} else if x == 0. {
|
||||
0.0
|
||||
} else {
|
||||
erf_impl(x, false)
|
||||
}
|
||||
}
|
||||
|
||||
/// `erf_inv` calculates the inverse error function
|
||||
/// at `x`.
|
||||
pub fn erf_inv(x: f64) -> f64 {
|
||||
if x == 0.0 {
|
||||
0.0
|
||||
} else if x >= 1.0 {
|
||||
f64::INFINITY
|
||||
} else if x <= -1.0 {
|
||||
f64::NEG_INFINITY
|
||||
} else if x < 0.0 {
|
||||
erf_inv_impl(-x, 1.0 + x, -1.0)
|
||||
} else {
|
||||
erf_inv_impl(x, 1.0 - x, 1.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// `erfc` calculates the complementary error function
|
||||
/// at `x`.
|
||||
pub fn erfc(x: f64) -> f64 {
|
||||
if x.is_nan() {
|
||||
f64::NAN
|
||||
} else if x == f64::INFINITY {
|
||||
0.0
|
||||
} else if x == f64::NEG_INFINITY {
|
||||
2.0
|
||||
} else {
|
||||
erf_impl(x, true)
|
||||
}
|
||||
}
|
||||
|
||||
/// `erfc_inv` calculates the complementary inverse
|
||||
/// error function at `x`.
|
||||
pub fn erfc_inv(x: f64) -> f64 {
|
||||
if x <= 0.0 {
|
||||
f64::INFINITY
|
||||
} else if x >= 2.0 {
|
||||
f64::NEG_INFINITY
|
||||
} else if x > 1.0 {
|
||||
erf_inv_impl(-1.0 + x, 2.0 - x, -1.0)
|
||||
} else {
|
||||
erf_inv_impl(1.0 - x, x, 1.0)
|
||||
}
|
||||
}
|
||||
|
||||
// **********************************************************
|
||||
// ********** Coefficients for erf_impl polynomial **********
|
||||
// **********************************************************
|
||||
|
||||
/// Polynomial coefficients for a numerator of `erf_impl`
|
||||
/// in the interval [1e-10, 0.5].
|
||||
const ERF_IMPL_AN: &[f64] = &[
|
||||
0.00337916709551257388990745,
|
||||
-0.00073695653048167948530905,
|
||||
-0.374732337392919607868241,
|
||||
0.0817442448733587196071743,
|
||||
-0.0421089319936548595203468,
|
||||
0.0070165709512095756344528,
|
||||
-0.00495091255982435110337458,
|
||||
0.000871646599037922480317225,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator of `erf_impl`
|
||||
/// in the interval [1e-10, 0.5]
|
||||
const ERF_IMPL_AD: &[f64] = &[
|
||||
1.0,
|
||||
-0.218088218087924645390535,
|
||||
0.412542972725442099083918,
|
||||
-0.0841891147873106755410271,
|
||||
0.0655338856400241519690695,
|
||||
-0.0120019604454941768171266,
|
||||
0.00408165558926174048329689,
|
||||
-0.000615900721557769691924509,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [0.5, 0.75].
|
||||
const ERF_IMPL_BN: &[f64] = &[
|
||||
-0.0361790390718262471360258,
|
||||
0.292251883444882683221149,
|
||||
0.281447041797604512774415,
|
||||
0.125610208862766947294894,
|
||||
0.0274135028268930549240776,
|
||||
0.00250839672168065762786937,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [0.5, 0.75].
|
||||
const ERF_IMPL_BD: &[f64] = &[
|
||||
1.0,
|
||||
1.8545005897903486499845,
|
||||
1.43575803037831418074962,
|
||||
0.582827658753036572454135,
|
||||
0.124810476932949746447682,
|
||||
0.0113724176546353285778481,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [0.75, 1.25].
|
||||
const ERF_IMPL_CN: &[f64] = &[
|
||||
-0.0397876892611136856954425,
|
||||
0.153165212467878293257683,
|
||||
0.191260295600936245503129,
|
||||
0.10276327061989304213645,
|
||||
0.029637090615738836726027,
|
||||
0.0046093486780275489468812,
|
||||
0.000307607820348680180548455,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [0.75, 1.25].
|
||||
const ERF_IMPL_CD: &[f64] = &[
|
||||
1.0,
|
||||
1.95520072987627704987886,
|
||||
1.64762317199384860109595,
|
||||
0.768238607022126250082483,
|
||||
0.209793185936509782784315,
|
||||
0.0319569316899913392596356,
|
||||
0.00213363160895785378615014,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [1.25, 2.25].
|
||||
const ERF_IMPL_DN: &[f64] = &[
|
||||
-0.0300838560557949717328341,
|
||||
0.0538578829844454508530552,
|
||||
0.0726211541651914182692959,
|
||||
0.0367628469888049348429018,
|
||||
0.00964629015572527529605267,
|
||||
0.00133453480075291076745275,
|
||||
0.778087599782504251917881e-4,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [1.25, 2.25].
|
||||
const ERF_IMPL_DD: &[f64] = &[
|
||||
1.0,
|
||||
1.75967098147167528287343,
|
||||
1.32883571437961120556307,
|
||||
0.552528596508757581287907,
|
||||
0.133793056941332861912279,
|
||||
0.0179509645176280768640766,
|
||||
0.00104712440019937356634038,
|
||||
-0.106640381820357337177643e-7,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [2.25, 3.5].
|
||||
const ERF_IMPL_EN: &[f64] = &[
|
||||
-0.0117907570137227847827732,
|
||||
0.014262132090538809896674,
|
||||
0.0202234435902960820020765,
|
||||
0.00930668299990432009042239,
|
||||
0.00213357802422065994322516,
|
||||
0.00025022987386460102395382,
|
||||
0.120534912219588189822126e-4,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [2.25, 3.5].
|
||||
const ERF_IMPL_ED: &[f64] = &[
|
||||
1.0,
|
||||
1.50376225203620482047419,
|
||||
0.965397786204462896346934,
|
||||
0.339265230476796681555511,
|
||||
0.0689740649541569716897427,
|
||||
0.00771060262491768307365526,
|
||||
0.000371421101531069302990367,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [3.5, 5.25].
|
||||
const ERF_IMPL_FN: &[f64] = &[
|
||||
-0.00546954795538729307482955,
|
||||
0.00404190278731707110245394,
|
||||
0.0054963369553161170521356,
|
||||
0.00212616472603945399437862,
|
||||
0.000394984014495083900689956,
|
||||
0.365565477064442377259271e-4,
|
||||
0.135485897109932323253786e-5,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [3.5, 5.25].
|
||||
const ERF_IMPL_FD: &[f64] = &[
|
||||
1.0,
|
||||
1.21019697773630784832251,
|
||||
0.620914668221143886601045,
|
||||
0.173038430661142762569515,
|
||||
0.0276550813773432047594539,
|
||||
0.00240625974424309709745382,
|
||||
0.891811817251336577241006e-4,
|
||||
-0.465528836283382684461025e-11,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [5.25, 8].
|
||||
const ERF_IMPL_GN: &[f64] = &[
|
||||
-0.00270722535905778347999196,
|
||||
0.0013187563425029400461378,
|
||||
0.00119925933261002333923989,
|
||||
0.00027849619811344664248235,
|
||||
0.267822988218331849989363e-4,
|
||||
0.923043672315028197865066e-6,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [5.25, 8].
|
||||
const ERF_IMPL_GD: &[f64] = &[
|
||||
1.0,
|
||||
0.814632808543141591118279,
|
||||
0.268901665856299542168425,
|
||||
0.0449877216103041118694989,
|
||||
0.00381759663320248459168994,
|
||||
0.000131571897888596914350697,
|
||||
0.404815359675764138445257e-11,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [8, 11.5].
|
||||
const ERF_IMPL_HN: &[f64] = &[
|
||||
-0.00109946720691742196814323,
|
||||
0.000406425442750422675169153,
|
||||
0.000274499489416900707787024,
|
||||
0.465293770646659383436343e-4,
|
||||
0.320955425395767463401993e-5,
|
||||
0.778286018145020892261936e-7,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [8, 11.5].
|
||||
const ERF_IMPL_HD: &[f64] = &[
|
||||
1.0,
|
||||
0.588173710611846046373373,
|
||||
0.139363331289409746077541,
|
||||
0.0166329340417083678763028,
|
||||
0.00100023921310234908642639,
|
||||
0.24254837521587225125068e-4,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [11.5, 17].
|
||||
const ERF_IMPL_IN: &[f64] = &[
|
||||
-0.00056907993601094962855594,
|
||||
0.000169498540373762264416984,
|
||||
0.518472354581100890120501e-4,
|
||||
0.382819312231928859704678e-5,
|
||||
0.824989931281894431781794e-7,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [11.5, 17].
|
||||
const ERF_IMPL_ID: &[f64] = &[
|
||||
1.0,
|
||||
0.339637250051139347430323,
|
||||
0.043472647870310663055044,
|
||||
0.00248549335224637114641629,
|
||||
0.535633305337152900549536e-4,
|
||||
-0.117490944405459578783846e-12,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [17, 24].
|
||||
const ERF_IMPL_JN: &[f64] = &[
|
||||
-0.000241313599483991337479091,
|
||||
0.574224975202501512365975e-4,
|
||||
0.115998962927383778460557e-4,
|
||||
0.581762134402593739370875e-6,
|
||||
0.853971555085673614607418e-8,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [17, 24].
|
||||
const ERF_IMPL_JD: &[f64] = &[
|
||||
1.0,
|
||||
0.233044138299687841018015,
|
||||
0.0204186940546440312625597,
|
||||
0.000797185647564398289151125,
|
||||
0.117019281670172327758019e-4,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [24, 38].
|
||||
const ERF_IMPL_KN: &[f64] = &[
|
||||
-0.000146674699277760365803642,
|
||||
0.162666552112280519955647e-4,
|
||||
0.269116248509165239294897e-5,
|
||||
0.979584479468091935086972e-7,
|
||||
0.101994647625723465722285e-8,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [24, 38].
|
||||
const ERF_IMPL_KD: &[f64] = &[
|
||||
1.0,
|
||||
0.165907812944847226546036,
|
||||
0.0103361716191505884359634,
|
||||
0.000286593026373868366935721,
|
||||
0.298401570840900340874568e-5,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [38, 60].
|
||||
const ERF_IMPL_LN: &[f64] = &[
|
||||
-0.583905797629771786720406e-4,
|
||||
0.412510325105496173512992e-5,
|
||||
0.431790922420250949096906e-6,
|
||||
0.993365155590013193345569e-8,
|
||||
0.653480510020104699270084e-10,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [38, 60].
|
||||
const ERF_IMPL_LD: &[f64] = &[
|
||||
1.0,
|
||||
0.105077086072039915406159,
|
||||
0.00414278428675475620830226,
|
||||
0.726338754644523769144108e-4,
|
||||
0.477818471047398785369849e-6,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [60, 85].
|
||||
const ERF_IMPL_MN: &[f64] = &[
|
||||
-0.196457797609229579459841e-4,
|
||||
0.157243887666800692441195e-5,
|
||||
0.543902511192700878690335e-7,
|
||||
0.317472492369117710852685e-9,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [60, 85].
|
||||
const ERF_IMPL_MD: &[f64] = &[
|
||||
1.0,
|
||||
0.052803989240957632204885,
|
||||
0.000926876069151753290378112,
|
||||
0.541011723226630257077328e-5,
|
||||
0.535093845803642394908747e-15,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [85, 110].
|
||||
const ERF_IMPL_NN: &[f64] = &[
|
||||
-0.789224703978722689089794e-5,
|
||||
0.622088451660986955124162e-6,
|
||||
0.145728445676882396797184e-7,
|
||||
0.603715505542715364529243e-10,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [85, 110].
|
||||
const ERF_IMPL_ND: &[f64] = &[
|
||||
1.0,
|
||||
0.0375328846356293715248719,
|
||||
0.000467919535974625308126054,
|
||||
0.193847039275845656900547e-5,
|
||||
];
|
||||
|
||||
// **********************************************************
|
||||
// ********** Coefficients for erf_inv_impl polynomial ******
|
||||
// **********************************************************
|
||||
|
||||
/// Polynomial coefficients for a numerator of `erf_inv_impl`
|
||||
/// in the interval [0, 0.5].
|
||||
const ERF_INV_IMPL_AN: &[f64] = &[
|
||||
-0.000508781949658280665617,
|
||||
-0.00836874819741736770379,
|
||||
0.0334806625409744615033,
|
||||
-0.0126926147662974029034,
|
||||
-0.0365637971411762664006,
|
||||
0.0219878681111168899165,
|
||||
0.00822687874676915743155,
|
||||
-0.00538772965071242932965,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator of `erf_inv_impl`
|
||||
/// in the interval [0, 0.5].
|
||||
const ERF_INV_IMPL_AD: &[f64] = &[
|
||||
1.0,
|
||||
-0.970005043303290640362,
|
||||
-1.56574558234175846809,
|
||||
1.56221558398423026363,
|
||||
0.662328840472002992063,
|
||||
-0.71228902341542847553,
|
||||
-0.0527396382340099713954,
|
||||
0.0795283687341571680018,
|
||||
-0.00233393759374190016776,
|
||||
0.000886216390456424707504,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator of `erf_inv_impl`
|
||||
/// in the interval [0.5, 0.75].
|
||||
const ERF_INV_IMPL_BN: &[f64] = &[
|
||||
-0.202433508355938759655,
|
||||
0.105264680699391713268,
|
||||
8.37050328343119927838,
|
||||
17.6447298408374015486,
|
||||
-18.8510648058714251895,
|
||||
-44.6382324441786960818,
|
||||
17.445385985570866523,
|
||||
21.1294655448340526258,
|
||||
-3.67192254707729348546,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator of `erf_inv_impl`
|
||||
/// in the interval [0.5, 0.75].
|
||||
const ERF_INV_IMPL_BD: &[f64] = &[
|
||||
1.0,
|
||||
6.24264124854247537712,
|
||||
3.9713437953343869095,
|
||||
-28.6608180499800029974,
|
||||
-20.1432634680485188801,
|
||||
48.5609213108739935468,
|
||||
10.8268667355460159008,
|
||||
-22.6436933413139721736,
|
||||
1.72114765761200282724,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x less than 3.
|
||||
const ERF_INV_IMPL_CN: &[f64] = &[
|
||||
-0.131102781679951906451,
|
||||
-0.163794047193317060787,
|
||||
0.117030156341995252019,
|
||||
0.387079738972604337464,
|
||||
0.337785538912035898924,
|
||||
0.142869534408157156766,
|
||||
0.0290157910005329060432,
|
||||
0.00214558995388805277169,
|
||||
-0.679465575181126350155e-6,
|
||||
0.285225331782217055858e-7,
|
||||
-0.681149956853776992068e-9,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x less than 3.
|
||||
const ERF_INV_IMPL_CD: &[f64] = &[
|
||||
1.0,
|
||||
3.46625407242567245975,
|
||||
5.38168345707006855425,
|
||||
4.77846592945843778382,
|
||||
2.59301921623620271374,
|
||||
0.848854343457902036425,
|
||||
0.152264338295331783612,
|
||||
0.01105924229346489121,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x between 3 and 6.
|
||||
const ERF_INV_IMPL_DN: &[f64] = &[
|
||||
-0.0350353787183177984712,
|
||||
-0.00222426529213447927281,
|
||||
0.0185573306514231072324,
|
||||
0.00950804701325919603619,
|
||||
0.00187123492819559223345,
|
||||
0.000157544617424960554631,
|
||||
0.460469890584317994083e-5,
|
||||
-0.230404776911882601748e-9,
|
||||
0.266339227425782031962e-11,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x between 3 and 6.
|
||||
const ERF_INV_IMPL_DD: &[f64] = &[
|
||||
1.0,
|
||||
1.3653349817554063097,
|
||||
0.762059164553623404043,
|
||||
0.220091105764131249824,
|
||||
0.0341589143670947727934,
|
||||
0.00263861676657015992959,
|
||||
0.764675292302794483503e-4,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x between 6 and 18.
|
||||
const ERF_INV_IMPL_EN: &[f64] = &[
|
||||
-0.0167431005076633737133,
|
||||
-0.00112951438745580278863,
|
||||
0.00105628862152492910091,
|
||||
0.000209386317487588078668,
|
||||
0.149624783758342370182e-4,
|
||||
0.449696789927706453732e-6,
|
||||
0.462596163522878599135e-8,
|
||||
-0.281128735628831791805e-13,
|
||||
0.99055709973310326855e-16,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x between 6 and 18.
|
||||
const ERF_INV_IMPL_ED: &[f64] = &[
|
||||
1.0,
|
||||
0.591429344886417493481,
|
||||
0.138151865749083321638,
|
||||
0.0160746087093676504695,
|
||||
0.000964011807005165528527,
|
||||
0.275335474764726041141e-4,
|
||||
0.282243172016108031869e-6,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x between 18 and 44.
|
||||
const ERF_INV_IMPL_FN: &[f64] = &[
|
||||
-0.0024978212791898131227,
|
||||
-0.779190719229053954292e-5,
|
||||
0.254723037413027451751e-4,
|
||||
0.162397777342510920873e-5,
|
||||
0.396341011304801168516e-7,
|
||||
0.411632831190944208473e-9,
|
||||
0.145596286718675035587e-11,
|
||||
-0.116765012397184275695e-17,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x between 18 and 44.
|
||||
const ERF_INV_IMPL_FD: &[f64] = &[
|
||||
1.0,
|
||||
0.207123112214422517181,
|
||||
0.0169410838120975906478,
|
||||
0.000690538265622684595676,
|
||||
0.145007359818232637924e-4,
|
||||
0.144437756628144157666e-6,
|
||||
0.509761276599778486139e-9,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x greater than 44.
|
||||
const ERF_INV_IMPL_GN: &[f64] = &[
|
||||
-0.000539042911019078575891,
|
||||
-0.28398759004727721098e-6,
|
||||
0.899465114892291446442e-6,
|
||||
0.229345859265920864296e-7,
|
||||
0.225561444863500149219e-9,
|
||||
0.947846627503022684216e-12,
|
||||
0.135880130108924861008e-14,
|
||||
-0.348890393399948882918e-21,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x greater than 44.
|
||||
const ERF_INV_IMPL_GD: &[f64] = &[
|
||||
1.0,
|
||||
0.0845746234001899436914,
|
||||
0.00282092984726264681981,
|
||||
0.468292921940894236786e-4,
|
||||
0.399968812193862100054e-6,
|
||||
0.161809290887904476097e-8,
|
||||
0.231558608310259605225e-11,
|
||||
];
|
||||
|
||||
/// `erf_impl` computes the error function at `z`.
|
||||
/// If `inv` is true, `1 - erf` is calculated as opposed to `erf`
|
||||
fn erf_impl(z: f64, inv: bool) -> f64 {
|
||||
if z < 0.0 {
|
||||
if !inv {
|
||||
return -erf_impl(-z, false);
|
||||
}
|
||||
if z < -0.5 {
|
||||
return 2.0 - erf_impl(-z, true);
|
||||
}
|
||||
return 1.0 + erf_impl(-z, false);
|
||||
}
|
||||
|
||||
let result = if z < 0.5 {
|
||||
if z < 1e-10 {
|
||||
z * 1.125 + z * 0.003379167095512573896158903121545171688
|
||||
} else {
|
||||
z * 1.125
|
||||
+ z * evaluate::polynomial(z, ERF_IMPL_AN) / evaluate::polynomial(z, ERF_IMPL_AD)
|
||||
}
|
||||
} else if z < 110.0 {
|
||||
let (r, b) = if z < 0.75 {
|
||||
(
|
||||
evaluate::polynomial(z - 0.5, ERF_IMPL_BN)
|
||||
/ evaluate::polynomial(z - 0.5, ERF_IMPL_BD),
|
||||
0.3440242112,
|
||||
)
|
||||
} else if z < 1.25 {
|
||||
(
|
||||
evaluate::polynomial(z - 0.75, ERF_IMPL_CN)
|
||||
/ evaluate::polynomial(z - 0.75, ERF_IMPL_CD),
|
||||
0.419990927,
|
||||
)
|
||||
} else if z < 2.25 {
|
||||
(
|
||||
evaluate::polynomial(z - 1.25, ERF_IMPL_DN)
|
||||
/ evaluate::polynomial(z - 1.25, ERF_IMPL_DD),
|
||||
0.4898625016,
|
||||
)
|
||||
} else if z < 3.5 {
|
||||
(
|
||||
evaluate::polynomial(z - 2.25, ERF_IMPL_EN)
|
||||
/ evaluate::polynomial(z - 2.25, ERF_IMPL_ED),
|
||||
0.5317370892,
|
||||
)
|
||||
} else if z < 5.25 {
|
||||
(
|
||||
evaluate::polynomial(z - 3.5, ERF_IMPL_FN)
|
||||
/ evaluate::polynomial(z - 3.5, ERF_IMPL_FD),
|
||||
0.5489973426,
|
||||
)
|
||||
} else if z < 8.0 {
|
||||
(
|
||||
evaluate::polynomial(z - 5.25, ERF_IMPL_GN)
|
||||
/ evaluate::polynomial(z - 5.25, ERF_IMPL_GD),
|
||||
0.5571740866,
|
||||
)
|
||||
} else if z < 11.5 {
|
||||
(
|
||||
evaluate::polynomial(z - 8.0, ERF_IMPL_HN)
|
||||
/ evaluate::polynomial(z - 8.0, ERF_IMPL_HD),
|
||||
0.5609807968,
|
||||
)
|
||||
} else if z < 17.0 {
|
||||
(
|
||||
evaluate::polynomial(z - 11.5, ERF_IMPL_IN)
|
||||
/ evaluate::polynomial(z - 11.5, ERF_IMPL_ID),
|
||||
0.5626493692,
|
||||
)
|
||||
} else if z < 24.0 {
|
||||
(
|
||||
evaluate::polynomial(z - 17.0, ERF_IMPL_JN)
|
||||
/ evaluate::polynomial(z - 17.0, ERF_IMPL_JD),
|
||||
0.5634598136,
|
||||
)
|
||||
} else if z < 38.0 {
|
||||
(
|
||||
evaluate::polynomial(z - 24.0, ERF_IMPL_KN)
|
||||
/ evaluate::polynomial(z - 24.0, ERF_IMPL_KD),
|
||||
0.5638477802,
|
||||
)
|
||||
} else if z < 60.0 {
|
||||
(
|
||||
evaluate::polynomial(z - 38.0, ERF_IMPL_LN)
|
||||
/ evaluate::polynomial(z - 38.0, ERF_IMPL_LD),
|
||||
0.5640528202,
|
||||
)
|
||||
} else if z < 85.0 {
|
||||
(
|
||||
evaluate::polynomial(z - 60.0, ERF_IMPL_MN)
|
||||
/ evaluate::polynomial(z - 60.0, ERF_IMPL_MD),
|
||||
0.5641309023,
|
||||
)
|
||||
} else {
|
||||
(
|
||||
evaluate::polynomial(z - 85.0, ERF_IMPL_NN)
|
||||
/ evaluate::polynomial(z - 85.0, ERF_IMPL_ND),
|
||||
0.5641584396,
|
||||
)
|
||||
};
|
||||
let g = (-z * z).exp() / z;
|
||||
g * b + g * r
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
if inv && z >= 0.5 {
|
||||
result
|
||||
} else if z >= 0.5 || inv {
|
||||
1.0 - result
|
||||
} else {
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
// `erf_inv_impl` computes the inverse error function where
|
||||
// `p`,`q`, and `s` are the first, second, and third intermediate
|
||||
// parameters respectively
|
||||
fn erf_inv_impl(p: f64, q: f64, s: f64) -> f64 {
|
||||
let result = if p <= 0.5 {
|
||||
let y = 0.0891314744949340820313;
|
||||
let g = p * (p + 10.0);
|
||||
let r = evaluate::polynomial(p, ERF_INV_IMPL_AN) / evaluate::polynomial(p, ERF_INV_IMPL_AD);
|
||||
g * y + g * r
|
||||
} else if q >= 0.25 {
|
||||
let y = 2.249481201171875;
|
||||
let g = (-2.0 * q.ln()).sqrt();
|
||||
let xs = q - 0.25;
|
||||
let r =
|
||||
evaluate::polynomial(xs, ERF_INV_IMPL_BN) / evaluate::polynomial(xs, ERF_INV_IMPL_BD);
|
||||
g / (y + r)
|
||||
} else {
|
||||
let x = (-q.ln()).sqrt();
|
||||
if x < 3.0 {
|
||||
let y = 0.807220458984375;
|
||||
let xs = x - 1.125;
|
||||
let r = evaluate::polynomial(xs, ERF_INV_IMPL_CN)
|
||||
/ evaluate::polynomial(xs, ERF_INV_IMPL_CD);
|
||||
y * x + r * x
|
||||
} else if x < 6.0 {
|
||||
let y = 0.93995571136474609375;
|
||||
let xs = x - 3.0;
|
||||
let r = evaluate::polynomial(xs, ERF_INV_IMPL_DN)
|
||||
/ evaluate::polynomial(xs, ERF_INV_IMPL_DD);
|
||||
y * x + r * x
|
||||
} else if x < 18.0 {
|
||||
let y = 0.98362827301025390625;
|
||||
let xs = x - 6.0;
|
||||
let r = evaluate::polynomial(xs, ERF_INV_IMPL_EN)
|
||||
/ evaluate::polynomial(xs, ERF_INV_IMPL_ED);
|
||||
y * x + r * x
|
||||
} else if x < 44.0 {
|
||||
let y = 0.99714565277099609375;
|
||||
let xs = x - 18.0;
|
||||
let r = evaluate::polynomial(xs, ERF_INV_IMPL_FN)
|
||||
/ evaluate::polynomial(xs, ERF_INV_IMPL_FD);
|
||||
y * x + r * x
|
||||
} else {
|
||||
let y = 0.99941349029541015625;
|
||||
let xs = x - 44.0;
|
||||
let r = evaluate::polynomial(xs, ERF_INV_IMPL_GN)
|
||||
/ evaluate::polynomial(xs, ERF_INV_IMPL_GD);
|
||||
y * x + r * x
|
||||
}
|
||||
};
|
||||
s * result
|
||||
}
|
@ -1,3 +1,4 @@
|
||||
pub mod erf;
|
||||
pub mod kernels;
|
||||
|
||||
trait Cpu<const ARR: usize> {
|
||||
|
@ -2603,6 +2603,10 @@ impl BackendDevice for CpuDevice {
|
||||
Ok(Self)
|
||||
}
|
||||
|
||||
fn set_seed(&self, _seed: u64) -> Result<()> {
|
||||
crate::bail!("cannot seed the CPU rng with set_seed")
|
||||
}
|
||||
|
||||
fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result<CpuStorage> {
|
||||
use rand::prelude::*;
|
||||
|
||||
|
@ -223,6 +223,14 @@ impl BackendDevice for CudaDevice {
|
||||
})
|
||||
}
|
||||
|
||||
fn set_seed(&self, seed: u64) -> Result<()> {
|
||||
// We do not call set_seed but instead create a new curand object. This ensures that the
|
||||
// state will be identical and the same random numbers will be generated.
|
||||
let mut curand = self.curand.lock().unwrap();
|
||||
curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn location(&self) -> crate::DeviceLocation {
|
||||
crate::DeviceLocation::Cuda {
|
||||
gpu_id: self.device.ordinal(),
|
||||
@ -884,8 +892,6 @@ impl<'a> Map1 for IndexSelect<'a> {
|
||||
};
|
||||
let ids_shape = ids_l.shape();
|
||||
let ids_dims = ids_shape.dims();
|
||||
let ids_el = ids_shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(ids_el as u32);
|
||||
let ds = dev.htod_copy([ids_dims, ids_l.stride()].concat()).w()?;
|
||||
let src = match src_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => src.slice(o1..o2),
|
||||
@ -893,19 +899,23 @@ impl<'a> Map1 for IndexSelect<'a> {
|
||||
};
|
||||
let left_size: usize = src_l.dims()[..self.2].iter().product();
|
||||
let right_size: usize = src_l.dims()[self.2 + 1..].iter().product();
|
||||
let dim_size = src_l.dims()[self.2];
|
||||
let src_dim_size = src_l.dims()[self.2];
|
||||
let ids_dim_size = ids_shape.elem_count();
|
||||
let dst_el = ids_shape.elem_count() * left_size * right_size;
|
||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(ids_el * left_size * right_size) }.w()?;
|
||||
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||
let params = (
|
||||
ids_el,
|
||||
dst_el,
|
||||
ids_dims.len(),
|
||||
&ds,
|
||||
ids,
|
||||
&src,
|
||||
&out,
|
||||
left_size,
|
||||
dim_size,
|
||||
src_dim_size,
|
||||
ids_dim_size,
|
||||
right_size,
|
||||
);
|
||||
// SAFETY: ffi.
|
||||
@ -2161,7 +2171,7 @@ impl BackendStorage for CudaStorage {
|
||||
if src_l.is_contiguous() {
|
||||
dev.dtod_copy(&src, &mut dst).w()?
|
||||
} else {
|
||||
let func = dev.get_or_load_func("ucopy_64", kernels::UNARY)?;
|
||||
let func = dev.get_or_load_func("ucopy_f64", kernels::UNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let params = (el_count, dims.len(), &ds, &src, &mut dst);
|
||||
// SAFETY: ffi.
|
||||
|
@ -34,6 +34,9 @@ pub(crate) fn launch_conv2d<
|
||||
params: &crate::conv::ParamsConv2D,
|
||||
dev: &crate::cuda_backend::CudaDevice,
|
||||
) -> crate::Result<()> {
|
||||
use crate::conv::CudnnFwdAlgo as CandleAlgo;
|
||||
use cudarc::cudnn::sys::cudnnConvolutionFwdAlgo_t as A;
|
||||
|
||||
let device_id = dev.id();
|
||||
let cudnn = CUDNN.with(|cudnn| {
|
||||
if let Some(cudnn) = cudnn.borrow().get(&device_id) {
|
||||
@ -90,7 +93,20 @@ pub(crate) fn launch_conv2d<
|
||||
w: &w,
|
||||
y: &y,
|
||||
};
|
||||
let alg = conv2d.pick_algorithm()?;
|
||||
let alg = match params.cudnn_fwd_algo {
|
||||
None => conv2d.pick_algorithm()?,
|
||||
Some(CandleAlgo::ImplicitGemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
|
||||
Some(CandleAlgo::ImplicitPrecompGemm) => {
|
||||
A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
|
||||
}
|
||||
Some(CandleAlgo::Gemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
|
||||
Some(CandleAlgo::Direct) => A::CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
|
||||
Some(CandleAlgo::Fft) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT,
|
||||
Some(CandleAlgo::FftTiling) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
|
||||
Some(CandleAlgo::Winograd) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
|
||||
Some(CandleAlgo::WinogradNonFused) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
|
||||
Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
|
||||
};
|
||||
let workspace_size = conv2d.get_workspace_size(alg)?;
|
||||
let mut workspace = dev.cuda_device().alloc_zeros::<u8>(workspace_size)?;
|
||||
unsafe {
|
||||
|
@ -128,6 +128,13 @@ impl Device {
|
||||
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
|
||||
}
|
||||
|
||||
pub fn set_seed(&self, seed: u64) -> Result<()> {
|
||||
match self {
|
||||
Self::Cpu => crate::cpu_backend::CpuDevice.set_seed(seed),
|
||||
Self::Cuda(c) => c.set_seed(seed),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn same_device(&self, rhs: &Self) -> bool {
|
||||
match (self, rhs) {
|
||||
(Self::Cpu, Self::Cpu) => true,
|
||||
|
@ -67,6 +67,20 @@ impl DType {
|
||||
Self::F64 => 8,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_int(&self) -> bool {
|
||||
match self {
|
||||
Self::U8 | Self::U32 | Self::I64 => true,
|
||||
Self::BF16 | Self::F16 | Self::F32 | Self::F64 => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_float(&self) -> bool {
|
||||
match self {
|
||||
Self::U8 | Self::U32 | Self::I64 => false,
|
||||
Self::BF16 | Self::F16 | Self::F32 | Self::F64 => true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait WithDType:
|
||||
|
@ -167,6 +167,10 @@ impl crate::backend::BackendDevice for CudaDevice {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn set_seed(&self, _: u64) -> Result<()> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn location(&self) -> crate::DeviceLocation {
|
||||
fail!()
|
||||
}
|
||||
|
@ -142,6 +142,9 @@ pub enum Error {
|
||||
#[error("{op} expects at least one tensor")]
|
||||
OpRequiresAtLeastOneTensor { op: &'static str },
|
||||
|
||||
#[error("{op} expects at least two tensors")]
|
||||
OpRequiresAtLeastTwoTensors { op: &'static str },
|
||||
|
||||
#[error("backward is not supported for {op}")]
|
||||
BackwardNotSupported { op: &'static str },
|
||||
|
||||
|
@ -250,8 +250,6 @@ impl Tensor {
|
||||
if header.fortran_order {
|
||||
return Err(Error::Npy("fortran order not supported".to_string()));
|
||||
}
|
||||
let mut data: Vec<u8> = vec![];
|
||||
reader.read_to_end(&mut data)?;
|
||||
Self::from_reader(header.shape(), header.descr, &mut reader)
|
||||
}
|
||||
|
||||
|
@ -58,8 +58,13 @@ pub enum UnaryOp {
|
||||
Sqr,
|
||||
Sqrt,
|
||||
Gelu,
|
||||
GeluErf,
|
||||
Erf,
|
||||
Relu,
|
||||
Tanh,
|
||||
Floor,
|
||||
Ceil,
|
||||
Round,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@ -131,6 +136,7 @@ pub enum Op {
|
||||
Copy(Tensor),
|
||||
Broadcast(Tensor),
|
||||
Narrow(Tensor, usize, usize, usize),
|
||||
SliceScatter0(Tensor, Tensor, usize),
|
||||
Reshape(Tensor),
|
||||
ToDevice(Tensor),
|
||||
Transpose(Tensor, usize, usize),
|
||||
@ -325,8 +331,13 @@ pub(crate) struct Recip;
|
||||
pub(crate) struct Sqr;
|
||||
pub(crate) struct Sqrt;
|
||||
pub(crate) struct Gelu;
|
||||
pub(crate) struct GeluErf;
|
||||
pub(crate) struct Erf;
|
||||
pub(crate) struct Relu;
|
||||
pub(crate) struct Tanh;
|
||||
pub(crate) struct Floor;
|
||||
pub(crate) struct Ceil;
|
||||
pub(crate) struct Round;
|
||||
|
||||
macro_rules! bin_op {
|
||||
($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
|
||||
@ -621,6 +632,176 @@ impl UnaryOpT for Gelu {
|
||||
}
|
||||
}
|
||||
|
||||
impl UnaryOpT for Erf {
|
||||
const NAME: &'static str = "erf";
|
||||
const KERNEL: &'static str = "uerf";
|
||||
const V: Self = Erf;
|
||||
#[inline(always)]
|
||||
fn bf16(v: bf16) -> bf16 {
|
||||
bf16::from_f64(Self::f64(v.to_f64()))
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f16(v: f16) -> f16 {
|
||||
f16::from_f64(Self::f64(v.to_f64()))
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f32(v: f32) -> f32 {
|
||||
Self::f64(v as f64) as f32
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f64(v: f64) -> f64 {
|
||||
crate::cpu::erf::erf(v)
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u8(_: u8) -> u8 {
|
||||
0
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u32(_: u32) -> u32 {
|
||||
0
|
||||
}
|
||||
#[inline(always)]
|
||||
fn i64(_: i64) -> i64 {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
impl UnaryOpT for Ceil {
|
||||
const NAME: &'static str = "ceil";
|
||||
const KERNEL: &'static str = "uceil";
|
||||
const V: Self = Ceil;
|
||||
#[inline(always)]
|
||||
fn bf16(v: bf16) -> bf16 {
|
||||
v.ceil()
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f16(v: f16) -> f16 {
|
||||
v.ceil()
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f32(v: f32) -> f32 {
|
||||
v.ceil()
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f64(v: f64) -> f64 {
|
||||
v.ceil()
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u8(v: u8) -> u8 {
|
||||
v
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u32(v: u32) -> u32 {
|
||||
v
|
||||
}
|
||||
#[inline(always)]
|
||||
fn i64(v: i64) -> i64 {
|
||||
v
|
||||
}
|
||||
}
|
||||
|
||||
impl UnaryOpT for Floor {
|
||||
const NAME: &'static str = "floor";
|
||||
const KERNEL: &'static str = "ufloor";
|
||||
const V: Self = Floor;
|
||||
#[inline(always)]
|
||||
fn bf16(v: bf16) -> bf16 {
|
||||
v.floor()
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f16(v: f16) -> f16 {
|
||||
v.floor()
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f32(v: f32) -> f32 {
|
||||
v.floor()
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f64(v: f64) -> f64 {
|
||||
v.floor()
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u8(v: u8) -> u8 {
|
||||
v
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u32(v: u32) -> u32 {
|
||||
v
|
||||
}
|
||||
#[inline(always)]
|
||||
fn i64(v: i64) -> i64 {
|
||||
v
|
||||
}
|
||||
}
|
||||
|
||||
impl UnaryOpT for Round {
|
||||
const NAME: &'static str = "round";
|
||||
const KERNEL: &'static str = "uround";
|
||||
const V: Self = Round;
|
||||
#[inline(always)]
|
||||
fn bf16(v: bf16) -> bf16 {
|
||||
v.round()
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f16(v: f16) -> f16 {
|
||||
v.round()
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f32(v: f32) -> f32 {
|
||||
v.round()
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f64(v: f64) -> f64 {
|
||||
v.round()
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u8(v: u8) -> u8 {
|
||||
v
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u32(v: u32) -> u32 {
|
||||
v
|
||||
}
|
||||
#[inline(always)]
|
||||
fn i64(v: i64) -> i64 {
|
||||
v
|
||||
}
|
||||
}
|
||||
|
||||
impl UnaryOpT for GeluErf {
|
||||
const NAME: &'static str = "gelu_erf";
|
||||
const KERNEL: &'static str = "ugelu_erf";
|
||||
const V: Self = GeluErf;
|
||||
#[inline(always)]
|
||||
fn bf16(v: bf16) -> bf16 {
|
||||
bf16::from_f64(Self::f64(v.to_f64()))
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f16(v: f16) -> f16 {
|
||||
f16::from_f64(Self::f64(v.to_f64()))
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f32(v: f32) -> f32 {
|
||||
Self::f64(v as f64) as f32
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f64(v: f64) -> f64 {
|
||||
(crate::cpu::erf::erf(v / 2f64.sqrt()) + 1.) * 0.5 * v
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u8(_: u8) -> u8 {
|
||||
0
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u32(_: u32) -> u32 {
|
||||
0
|
||||
}
|
||||
#[inline(always)]
|
||||
fn i64(_: i64) -> i64 {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
impl UnaryOpT for Relu {
|
||||
const NAME: &'static str = "relu";
|
||||
const KERNEL: &'static str = "urelu";
|
||||
|
@ -193,6 +193,50 @@ impl Object {
|
||||
_ => Err(self),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_tensor_info(
|
||||
self,
|
||||
name: Self,
|
||||
dir_name: &std::path::Path,
|
||||
) -> Result<Option<TensorInfo>> {
|
||||
let name = match name.unicode() {
|
||||
Ok(name) => name,
|
||||
Err(_) => return Ok(None),
|
||||
};
|
||||
let (callable, args) = match self.reduce() {
|
||||
Ok(callable_args) => callable_args,
|
||||
_ => return Ok(None),
|
||||
};
|
||||
let (callable, args) = match callable {
|
||||
Object::Class {
|
||||
module_name,
|
||||
class_name,
|
||||
} if module_name == "torch._tensor" && class_name == "_rebuild_from_type_v2" => {
|
||||
let mut args = args.tuple()?;
|
||||
let callable = args.remove(0);
|
||||
let args = args.remove(1);
|
||||
(callable, args)
|
||||
}
|
||||
_ => (callable, args),
|
||||
};
|
||||
match callable {
|
||||
Object::Class {
|
||||
module_name,
|
||||
class_name,
|
||||
} if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {}
|
||||
_ => return Ok(None),
|
||||
};
|
||||
let (layout, dtype, file_path, storage_size) = rebuild_args(args)?;
|
||||
let mut path = dir_name.to_path_buf();
|
||||
path.push(file_path);
|
||||
Ok(Some(TensorInfo {
|
||||
name,
|
||||
dtype,
|
||||
layout,
|
||||
path: path.to_string_lossy().into_owned(),
|
||||
storage_size,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<Object> for String {
|
||||
@ -565,6 +609,7 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
|
||||
"HalfStorage" => DType::F16,
|
||||
"BFloat16Storage" => DType::BF16,
|
||||
"ByteStorage" => DType::U8,
|
||||
"LongStorage" => DType::I64,
|
||||
other => {
|
||||
crate::bail!("unsupported storage type {other}")
|
||||
}
|
||||
@ -623,50 +668,10 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
|
||||
};
|
||||
if let Object::Dict(key_values) = obj {
|
||||
for (name, value) in key_values.into_iter() {
|
||||
let name = match name.unicode() {
|
||||
Ok(name) => name,
|
||||
Err(_) => continue,
|
||||
};
|
||||
let (callable, args) = match value.reduce() {
|
||||
Ok(callable_args) => callable_args,
|
||||
_ => continue,
|
||||
};
|
||||
let (callable, args) = match callable {
|
||||
Object::Class {
|
||||
module_name,
|
||||
class_name,
|
||||
} if module_name == "torch._tensor"
|
||||
&& class_name == "_rebuild_from_type_v2" =>
|
||||
{
|
||||
let mut args = args.tuple()?;
|
||||
let callable = args.remove(0);
|
||||
let args = args.remove(1);
|
||||
(callable, args)
|
||||
}
|
||||
_ => (callable, args),
|
||||
};
|
||||
match callable {
|
||||
Object::Class {
|
||||
module_name,
|
||||
class_name,
|
||||
} if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {}
|
||||
_ => continue,
|
||||
};
|
||||
match rebuild_args(args) {
|
||||
Ok((layout, dtype, file_path, storage_size)) => {
|
||||
let mut path = dir_name.clone();
|
||||
path.push(file_path);
|
||||
tensor_infos.push(TensorInfo {
|
||||
name,
|
||||
dtype,
|
||||
layout,
|
||||
path: path.to_string_lossy().into_owned(),
|
||||
storage_size,
|
||||
})
|
||||
}
|
||||
Err(err) => {
|
||||
eprintln!("skipping {name}: {err:?}")
|
||||
}
|
||||
match value.into_tensor_info(name, &dir_name) {
|
||||
Ok(Some(tensor_info)) => tensor_infos.push(tensor_info),
|
||||
Ok(None) => {}
|
||||
Err(err) => eprintln!("skipping: {err:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -723,3 +728,16 @@ impl PthTensors {
|
||||
Ok(Some(tensor))
|
||||
}
|
||||
}
|
||||
|
||||
/// Read all the tensors from a PyTorch pth file.
|
||||
pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
|
||||
let pth = PthTensors::new(path)?;
|
||||
let tensor_names = pth.tensor_infos.keys();
|
||||
let mut tensors = Vec::with_capacity(tensor_names.len());
|
||||
for name in tensor_names {
|
||||
if let Some(tensor) = pth.get(name)? {
|
||||
tensors.push((name.to_string(), tensor))
|
||||
}
|
||||
}
|
||||
Ok(tensors)
|
||||
}
|
||||
|
@ -638,3 +638,35 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res
|
||||
Ok(hsum_float_8(acc) + summs)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
|
||||
let qk = QK_K;
|
||||
if n % qk != 0 {
|
||||
crate::bail!("vec_dot_q8k_8k: {n} is not divisible by {qk}")
|
||||
}
|
||||
|
||||
unsafe {
|
||||
let mut acc = _mm256_setzero_ps();
|
||||
for (xs, ys) in xs.iter().zip(ys.iter()) {
|
||||
let mut sumi = _mm256_setzero_si256();
|
||||
let x_qs = xs.qs.as_ptr();
|
||||
let y_qs = ys.qs.as_ptr();
|
||||
for j in (0..QK_K).step_by(32) {
|
||||
let xs = _mm256_loadu_si256(x_qs.add(j) as *const __m256i);
|
||||
let ys = _mm256_loadu_si256(y_qs.add(j) as *const __m256i);
|
||||
|
||||
let xs0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xs, 0));
|
||||
let ys0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(ys, 0));
|
||||
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(xs0, ys0));
|
||||
|
||||
let xs1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xs, 1));
|
||||
let ys1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(ys, 1));
|
||||
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(xs1, ys1));
|
||||
}
|
||||
let d = _mm256_set1_ps(xs.d * ys.d);
|
||||
acc = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi), acc);
|
||||
}
|
||||
Ok(hsum_float_8(acc))
|
||||
}
|
||||
}
|
||||
|
@ -135,7 +135,13 @@ pub fn qtensor_from_ggml(
|
||||
dims: Vec<usize>,
|
||||
) -> Result<super::QTensor> {
|
||||
let tensor_elems = dims.iter().product::<usize>();
|
||||
let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
|
||||
let blck_size = ggml_dtype.blck_size();
|
||||
if tensor_elems % blck_size != 0 {
|
||||
crate::bail!(
|
||||
"the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
|
||||
)
|
||||
}
|
||||
let size_in_bytes = tensor_elems / blck_size * ggml_dtype.type_size();
|
||||
|
||||
match ggml_dtype {
|
||||
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),
|
||||
|
@ -59,8 +59,13 @@ impl TensorInfo {
|
||||
tensor_data_offset: u64,
|
||||
) -> Result<QTensor> {
|
||||
let tensor_elems = self.shape.elem_count();
|
||||
let size_in_bytes =
|
||||
tensor_elems * self.ggml_dtype.type_size() / self.ggml_dtype.blck_size();
|
||||
let blck_size = self.ggml_dtype.blck_size();
|
||||
if tensor_elems % blck_size != 0 {
|
||||
crate::bail!(
|
||||
"the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
|
||||
)
|
||||
}
|
||||
let size_in_bytes = tensor_elems / blck_size * self.ggml_dtype.type_size();
|
||||
let mut raw_data = vec![0u8; size_in_bytes];
|
||||
reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
|
||||
reader.read_exact(&mut raw_data)?;
|
||||
|
@ -34,6 +34,9 @@ pub trait GgmlType: Sized + Clone + Send + Sync {
|
||||
/// Dot product used as a building block for quantized mat-mul.
|
||||
/// n is the number of elements to be considered.
|
||||
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32>;
|
||||
|
||||
/// Generic implementation of the dot product without simd optimizations.
|
||||
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
@ -225,6 +228,13 @@ impl GgmlType for BlockQ4_0 {
|
||||
#[cfg(target_feature = "neon")]
|
||||
return super::neon::vec_dot_q4_0_q8_0(n, xs, ys);
|
||||
|
||||
#[cfg(target_feature = "simd128")]
|
||||
return super::simd128::vec_dot_q4_0_q8_0(n, xs, ys);
|
||||
|
||||
Self::vec_dot_unopt(n, xs, ys)
|
||||
}
|
||||
|
||||
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
let qk = QK8_0;
|
||||
let nb = n / qk;
|
||||
if n % QK8_0 != 0 {
|
||||
@ -255,6 +265,10 @@ impl GgmlType for BlockQ4_1 {
|
||||
type VecDotType = BlockQ8_1;
|
||||
|
||||
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
Self::vec_dot_unopt(n, xs, ys)
|
||||
}
|
||||
|
||||
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
// ggml_vec_dot_q4_1_q8_1
|
||||
let qk = QK8_1;
|
||||
if n % qk != 0 {
|
||||
@ -354,7 +368,10 @@ impl GgmlType for BlockQ5_0 {
|
||||
if nb % 2 != 0 {
|
||||
crate::bail!("vec_dot_q5_0_q8_0: {n}, nb is not divisible by 2")
|
||||
}
|
||||
Self::vec_dot_unopt(n, xs, ys)
|
||||
}
|
||||
|
||||
fn vec_dot_unopt(_n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
// Generic implementation.
|
||||
let mut sumf = 0f32;
|
||||
|
||||
@ -445,6 +462,10 @@ impl GgmlType for BlockQ5_1 {
|
||||
type VecDotType = BlockQ8_1;
|
||||
|
||||
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
Self::vec_dot_unopt(n, xs, ys)
|
||||
}
|
||||
|
||||
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
let qk = Self::BLCK_SIZE;
|
||||
if n % Self::BLCK_SIZE != 0 {
|
||||
crate::bail!("vec_dot_q5_1_q8_1: {n} is not divisible by {qk}")
|
||||
@ -606,6 +627,13 @@ impl GgmlType for BlockQ8_0 {
|
||||
#[cfg(target_feature = "neon")]
|
||||
return super::neon::vec_dot_q8_0_q8_0(n, xs, ys);
|
||||
|
||||
#[cfg(target_feature = "simd128")]
|
||||
return super::simd128::vec_dot_q8_0_q8_0(n, xs, ys);
|
||||
|
||||
Self::vec_dot_unopt(n, xs, ys)
|
||||
}
|
||||
|
||||
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
let qk = QK8_0;
|
||||
if n % QK8_0 != 0 {
|
||||
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
|
||||
@ -631,7 +659,11 @@ impl GgmlType for BlockQ8_1 {
|
||||
const BLCK_SIZE: usize = QK8_1;
|
||||
type VecDotType = BlockQ8_1;
|
||||
|
||||
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
Self::vec_dot_unopt(n, xs, ys)
|
||||
}
|
||||
|
||||
fn vec_dot_unopt(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
unimplemented!("no support for vec-dot on Q8_1")
|
||||
}
|
||||
|
||||
@ -681,6 +713,13 @@ impl GgmlType for BlockQ2K {
|
||||
#[cfg(target_feature = "neon")]
|
||||
return super::neon::vec_dot_q2k_q8k(n, xs, ys);
|
||||
|
||||
#[cfg(target_feature = "simd128")]
|
||||
return super::simd128::vec_dot_q2k_q8k(n, xs, ys);
|
||||
|
||||
Self::vec_dot_unopt(n, xs, ys)
|
||||
}
|
||||
|
||||
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
if n % QK_K != 0 {
|
||||
crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
|
||||
}
|
||||
@ -701,18 +740,17 @@ impl GgmlType for BlockQ2K {
|
||||
|
||||
let mut isum = 0;
|
||||
let mut is = 0;
|
||||
let mut d;
|
||||
for _ in 0..(QK_K / 128) {
|
||||
let mut shift = 0;
|
||||
for _ in 0..4 {
|
||||
d = (sc[is] & 0xF) as i32;
|
||||
let d = (sc[is] & 0xF) as i32;
|
||||
is += 1;
|
||||
let mut isuml = 0;
|
||||
for l in 0..16 {
|
||||
isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32);
|
||||
}
|
||||
isum += d * isuml;
|
||||
d = (sc[is] & 0xF) as i32;
|
||||
let d = (sc[is] & 0xF) as i32;
|
||||
is += 1;
|
||||
isuml = 0;
|
||||
for l in 16..32 {
|
||||
@ -851,6 +889,10 @@ impl GgmlType for BlockQ3K {
|
||||
#[cfg(target_feature = "neon")]
|
||||
return super::neon::vec_dot_q3k_q8k(n, xs, ys);
|
||||
|
||||
Self::vec_dot_unopt(n, xs, ys)
|
||||
}
|
||||
|
||||
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
if n % QK_K != 0 {
|
||||
crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}")
|
||||
}
|
||||
@ -1077,7 +1119,6 @@ impl GgmlType for BlockQ3K {
|
||||
let d_all = block.d.to_f32();
|
||||
let mut m = 1;
|
||||
let mut is = 0;
|
||||
let mut dl;
|
||||
|
||||
// Dequantize both 128 long blocks
|
||||
// 32 qs values per 128 long block
|
||||
@ -1088,7 +1129,7 @@ impl GgmlType for BlockQ3K {
|
||||
for (scale_index, scale_scoped_y) in
|
||||
shift_scoped_y.chunks_exact_mut(16).enumerate()
|
||||
{
|
||||
dl = d_all * (scales[is] as f32 - 32.0);
|
||||
let dl = d_all * (scales[is] as f32 - 32.0);
|
||||
for (i, inner_y) in scale_scoped_y.iter_mut().enumerate() {
|
||||
let new_y = dl
|
||||
* (((qs[i + 16 * scale_index] >> shift) & 3) as i8
|
||||
@ -1126,6 +1167,13 @@ impl GgmlType for BlockQ4K {
|
||||
#[cfg(target_feature = "neon")]
|
||||
return super::neon::vec_dot_q4k_q8k(n, xs, ys);
|
||||
|
||||
#[cfg(target_feature = "simd128")]
|
||||
return super::simd128::vec_dot_q4k_q8k(n, xs, ys);
|
||||
|
||||
Self::vec_dot_unopt(n, xs, ys)
|
||||
}
|
||||
|
||||
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
if n % QK_K != 0 {
|
||||
crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
|
||||
}
|
||||
@ -1312,6 +1360,10 @@ impl GgmlType for BlockQ5K {
|
||||
#[cfg(target_feature = "neon")]
|
||||
return super::neon::vec_dot_q5k_q8k(n, xs, ys);
|
||||
|
||||
Self::vec_dot_unopt(n, xs, ys)
|
||||
}
|
||||
|
||||
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
if n % QK_K != 0 {
|
||||
crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}")
|
||||
}
|
||||
@ -1529,6 +1581,13 @@ impl GgmlType for BlockQ6K {
|
||||
#[cfg(target_feature = "neon")]
|
||||
return super::neon::vec_dot_q6k_q8k(n, xs, ys);
|
||||
|
||||
#[cfg(target_feature = "simd128")]
|
||||
return super::simd128::vec_dot_q6k_q8k(n, xs, ys);
|
||||
|
||||
Self::vec_dot_unopt(n, xs, ys)
|
||||
}
|
||||
|
||||
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
if n % QK_K != 0 {
|
||||
crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}")
|
||||
}
|
||||
@ -1697,8 +1756,38 @@ impl GgmlType for BlockQ8K {
|
||||
const BLCK_SIZE: usize = QK_K;
|
||||
type VecDotType = BlockQ8K;
|
||||
|
||||
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
unreachable!()
|
||||
#[allow(unreachable_code)]
|
||||
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
#[cfg(target_feature = "avx")]
|
||||
return super::avx::vec_dot_q8k_q8k(n, xs, ys);
|
||||
|
||||
#[cfg(target_feature = "neon")]
|
||||
return super::neon::vec_dot_q8k_q8k(n, xs, ys);
|
||||
|
||||
#[cfg(target_feature = "simd128")]
|
||||
return super::simd128::vec_dot_q8k_q8k(n, xs, ys);
|
||||
|
||||
Self::vec_dot_unopt(n, xs, ys)
|
||||
}
|
||||
|
||||
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
let qk = QK_K;
|
||||
if n % QK_K != 0 {
|
||||
crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
|
||||
}
|
||||
|
||||
// Generic implementation.
|
||||
let mut sumf = 0f32;
|
||||
for (xs, ys) in xs.iter().zip(ys.iter()) {
|
||||
let sum_i = xs
|
||||
.qs
|
||||
.iter()
|
||||
.zip(ys.qs.iter())
|
||||
.map(|(&x, &y)| x as i32 * y as i32)
|
||||
.sum::<i32>();
|
||||
sumf += sum_i as f32 * xs.d * ys.d
|
||||
}
|
||||
Ok(sumf)
|
||||
}
|
||||
|
||||
fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
|
||||
@ -1804,6 +1893,10 @@ impl GgmlType for f32 {
|
||||
type VecDotType = f32;
|
||||
|
||||
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
Self::vec_dot_unopt(n, xs, ys)
|
||||
}
|
||||
|
||||
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
if xs.len() < n {
|
||||
crate::bail!("size mismatch {} < {n}", xs.len())
|
||||
}
|
||||
@ -1838,6 +1931,10 @@ impl GgmlType for f16 {
|
||||
type VecDotType = f16;
|
||||
|
||||
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
Self::vec_dot_unopt(n, xs, ys)
|
||||
}
|
||||
|
||||
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
if xs.len() < n {
|
||||
crate::bail!("size mismatch {} < {n}", xs.len())
|
||||
}
|
||||
|
@ -7,6 +7,8 @@ pub mod gguf_file;
|
||||
pub mod k_quants;
|
||||
#[cfg(target_feature = "neon")]
|
||||
pub mod neon;
|
||||
#[cfg(target_feature = "simd128")]
|
||||
pub mod simd128;
|
||||
pub mod utils;
|
||||
|
||||
pub use k_quants::GgmlType;
|
||||
@ -229,20 +231,40 @@ impl QTensor {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct QMatMul(std::sync::Arc<QTensor>);
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum QMatMul {
|
||||
QTensor(std::sync::Arc<QTensor>),
|
||||
Tensor(Tensor),
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
static DEQUANTIZE_ALL: bool = {
|
||||
match std::env::var("CANDLE_DEQUANTIZE_ALL") {
|
||||
Ok(s) => {
|
||||
!s.is_empty() && s != "0"
|
||||
},
|
||||
Err(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl QMatMul {
|
||||
pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Self {
|
||||
Self(qtensor)
|
||||
pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
|
||||
let dequantize = match qtensor.dtype() {
|
||||
GgmlDType::F32 | GgmlDType::F16 => true,
|
||||
_ => DEQUANTIZE_ALL.with(|b| *b),
|
||||
};
|
||||
let t = if dequantize {
|
||||
let tensor = qtensor.dequantize(&Device::Cpu)?;
|
||||
Self::Tensor(tensor)
|
||||
} else {
|
||||
Self::QTensor(qtensor)
|
||||
};
|
||||
Ok(t)
|
||||
}
|
||||
|
||||
pub fn from_qtensor(qtensor: QTensor) -> Self {
|
||||
Self(std::sync::Arc::new(qtensor))
|
||||
}
|
||||
|
||||
pub fn inner(&self) -> &std::sync::Arc<QTensor> {
|
||||
&self.0
|
||||
pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
|
||||
Self::from_arc(std::sync::Arc::new(qtensor))
|
||||
}
|
||||
}
|
||||
|
||||
@ -287,6 +309,16 @@ impl crate::CustomOp1 for QTensor {
|
||||
|
||||
impl QMatMul {
|
||||
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
xs.apply_op1_no_bwd(self.0.as_ref())
|
||||
match self {
|
||||
Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
|
||||
Self::Tensor(w) => {
|
||||
let w = match *xs.dims() {
|
||||
[b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
|
||||
[bsize, _, _] => w.broadcast_left(bsize)?.t()?,
|
||||
_ => w.t()?,
|
||||
};
|
||||
xs.matmul(&w)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -148,6 +148,35 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
|
||||
let qk = QK_K;
|
||||
if n % QK_K != 0 {
|
||||
crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
|
||||
}
|
||||
|
||||
let mut sumf = 0f32;
|
||||
for (xs, ys) in xs.iter().zip(ys.iter()) {
|
||||
unsafe {
|
||||
let mut sum_i = vdupq_n_s32(0);
|
||||
let scale = xs.d * ys.d;
|
||||
let xs = xs.qs.as_ptr();
|
||||
let ys = ys.qs.as_ptr();
|
||||
for i in (0..QK_K).step_by(16) {
|
||||
let xs = vld1q_s8(xs.add(i));
|
||||
let ys = vld1q_s8(ys.add(i));
|
||||
let xy_lo = vmull_s8(vget_low_s8(xs), vget_low_s8(ys));
|
||||
let xy_up = vmull_s8(vget_high_s8(xs), vget_high_s8(ys));
|
||||
|
||||
let xy = vaddq_s32(vpaddlq_s16(xy_lo), vpaddlq_s16(xy_up));
|
||||
sum_i = vaddq_s32(sum_i, xy)
|
||||
}
|
||||
sumf += vaddvq_s32(sum_i) as f32 * scale
|
||||
}
|
||||
}
|
||||
Ok(sumf)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
|
||||
if n % QK_K != 0 {
|
||||
|
427
candle-core/src/quantized/simd128.rs
Normal file
427
candle-core/src/quantized/simd128.rs
Normal file
@ -0,0 +1,427 @@
|
||||
use super::k_quants::{BlockQ2K, BlockQ4K, BlockQ4_0, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K};
|
||||
use crate::Result;
|
||||
use byteorder::{ByteOrder, LittleEndian};
|
||||
use half::f16;
|
||||
|
||||
use core::arch::wasm32::*;
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
|
||||
let qk = QK8_0;
|
||||
if n % QK8_0 != 0 {
|
||||
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
|
||||
}
|
||||
let nb = n / QK8_0;
|
||||
if nb % 2 != 0 {
|
||||
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
|
||||
}
|
||||
unsafe {
|
||||
let mut acc = f32x4_splat(0.0f32);
|
||||
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||
let x1234 = v128_load(x.qs.as_ptr() as *const v128);
|
||||
let x12 = v128_and(x1234, u8x16_splat(0x0F));
|
||||
let x12 = i8x16_sub(x12, i8x16_splat(8));
|
||||
let x34 = u8x16_shr(x1234, 4);
|
||||
let x34 = i8x16_sub(x34, i8x16_splat(8));
|
||||
|
||||
let x1 = i16x8_extend_low_i8x16(x12);
|
||||
let y1 = i16x8_load_extend_i8x8(y.qs.as_ptr());
|
||||
let sum_xy = i32x4_dot_i16x8(x1, y1);
|
||||
|
||||
let x2 = i16x8_extend_high_i8x16(x12);
|
||||
let y2 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(8));
|
||||
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x2, y2));
|
||||
|
||||
let x3 = i16x8_extend_low_i8x16(x34);
|
||||
let y3 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(16));
|
||||
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x3, y3));
|
||||
|
||||
let x4 = i16x8_extend_high_i8x16(x34);
|
||||
let y4 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(24));
|
||||
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x4, y4));
|
||||
|
||||
let sum_xy = f32x4_convert_i32x4(sum_xy);
|
||||
|
||||
// f32x4_relaxed_madd is nightly only.
|
||||
let d = f32x4_splat(f16::to_f32(x.d) * f16::to_f32(y.d));
|
||||
let scaled = f32x4_mul(sum_xy, d);
|
||||
acc = f32x4_add(acc, scaled)
|
||||
}
|
||||
let res = f32x4_extract_lane::<0>(acc)
|
||||
+ f32x4_extract_lane::<1>(acc)
|
||||
+ f32x4_extract_lane::<2>(acc)
|
||||
+ f32x4_extract_lane::<3>(acc);
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
|
||||
let qk = QK8_0;
|
||||
if n % QK8_0 != 0 {
|
||||
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
|
||||
}
|
||||
let nb = n / QK8_0;
|
||||
if nb % 2 != 0 {
|
||||
crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even")
|
||||
}
|
||||
unsafe {
|
||||
let mut acc = f32x4_splat(0.0f32);
|
||||
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||
let x1 = i16x8_load_extend_i8x8(x.qs.as_ptr());
|
||||
let y1 = i16x8_load_extend_i8x8(y.qs.as_ptr());
|
||||
let sum_xy = i32x4_dot_i16x8(x1, y1);
|
||||
|
||||
let x2 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(8));
|
||||
let y2 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(8));
|
||||
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x2, y2));
|
||||
|
||||
let x3 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(16));
|
||||
let y3 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(16));
|
||||
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x3, y3));
|
||||
|
||||
let x4 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(24));
|
||||
let y4 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(24));
|
||||
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x4, y4));
|
||||
|
||||
let sum_xy = f32x4_convert_i32x4(sum_xy);
|
||||
|
||||
// f32x4_relaxed_madd is nightly only.
|
||||
let d = f32x4_splat(f16::to_f32(x.d) * f16::to_f32(y.d));
|
||||
let scaled = f32x4_mul(sum_xy, d);
|
||||
acc = f32x4_add(acc, scaled)
|
||||
}
|
||||
let res = f32x4_extract_lane::<0>(acc)
|
||||
+ f32x4_extract_lane::<1>(acc)
|
||||
+ f32x4_extract_lane::<2>(acc)
|
||||
+ f32x4_extract_lane::<3>(acc);
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
|
||||
if n % QK_K != 0 {
|
||||
crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
|
||||
}
|
||||
unsafe {
|
||||
let mut sumf = f32x4_splat(0f32);
|
||||
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||
let mut q2: &[_] = &x.qs;
|
||||
let mut q8: &[_] = &y.qs;
|
||||
let sc = &x.scales;
|
||||
|
||||
let mut summs = i32x4_splat(0);
|
||||
for i in (0..(QK_K / 16)).step_by(4) {
|
||||
let bsums = i32x4_load_extend_i16x4(y.bsums.as_ptr().add(i));
|
||||
let scales = i32x4_shr(
|
||||
i32x4(
|
||||
sc[i] as i32,
|
||||
sc[i + 1] as i32,
|
||||
sc[i + 2] as i32,
|
||||
sc[i + 3] as i32,
|
||||
),
|
||||
4,
|
||||
);
|
||||
summs = i32x4_add(summs, i32x4_mul(bsums, scales))
|
||||
}
|
||||
let summs = f32x4_convert_i32x4(summs);
|
||||
|
||||
let dall = y.d * x.d.to_f32();
|
||||
let dmin = y.d * x.dmin.to_f32();
|
||||
|
||||
let mut isum = i32x4_splat(0);
|
||||
let mut is = 0;
|
||||
for _ in 0..(QK_K / 128) {
|
||||
let mut shift = 0;
|
||||
for _ in 0..4 {
|
||||
let d = (sc[is] & 0xF) as i32;
|
||||
is += 1;
|
||||
let mut isuml = i16x8_splat(0);
|
||||
for l in (0..16).step_by(8) {
|
||||
let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(l));
|
||||
let q2 = i16x8_load_extend_u8x8(q2.as_ptr().add(l));
|
||||
let q2 = v128_and(i16x8_shr(q2, shift), i16x8_splat(3));
|
||||
isuml = i16x8_add(isuml, i16x8_mul(q2, q8))
|
||||
}
|
||||
let dd = i32x4_splat(d);
|
||||
isum = i32x4_add(isum, i32x4_mul(i32x4_extend_low_i16x8(isuml), dd));
|
||||
isum = i32x4_add(isum, i32x4_mul(i32x4_extend_high_i16x8(isuml), dd));
|
||||
let d = (sc[is] & 0xF) as i32;
|
||||
is += 1;
|
||||
let mut isuml = i16x8_splat(0);
|
||||
for l in (16..32).step_by(8) {
|
||||
let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(l));
|
||||
let q2 = i16x8_load_extend_u8x8(q2.as_ptr().add(l));
|
||||
let q2 = v128_and(i16x8_shr(q2, shift), i16x8_splat(3));
|
||||
isuml = i16x8_add(isuml, i16x8_mul(q2, q8))
|
||||
}
|
||||
let dd = i32x4_splat(d);
|
||||
isum = i32x4_add(isum, i32x4_mul(i32x4_extend_low_i16x8(isuml), dd));
|
||||
isum = i32x4_add(isum, i32x4_mul(i32x4_extend_high_i16x8(isuml), dd));
|
||||
shift += 2;
|
||||
// adjust the indexing
|
||||
q8 = &q8[32..];
|
||||
}
|
||||
// adjust the indexing
|
||||
q2 = &q2[32..];
|
||||
}
|
||||
let isum = f32x4_convert_i32x4(isum);
|
||||
sumf = f32x4_add(
|
||||
sumf,
|
||||
f32x4_sub(
|
||||
f32x4_mul(isum, f32x4_splat(dall)),
|
||||
f32x4_mul(summs, f32x4_splat(dmin)),
|
||||
),
|
||||
);
|
||||
}
|
||||
let sumf = f32x4_extract_lane::<0>(sumf)
|
||||
+ f32x4_extract_lane::<1>(sumf)
|
||||
+ f32x4_extract_lane::<2>(sumf)
|
||||
+ f32x4_extract_lane::<3>(sumf);
|
||||
Ok(sumf)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
|
||||
if n % QK_K != 0 {
|
||||
crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
|
||||
}
|
||||
|
||||
const KMASK1: u32 = 0x3f3f3f3f;
|
||||
const KMASK2: u32 = 0x0f0f0f0f;
|
||||
const KMASK3: u32 = 0x03030303;
|
||||
|
||||
let mut utmp: [u32; 4] = [0; 4];
|
||||
let mut scales: [u8; 8] = [0; 8];
|
||||
let mut mins: [u8; 8] = [0; 8];
|
||||
|
||||
let mut aux8: [u8; QK_K] = [0; QK_K];
|
||||
let mut sums = f32x4_splat(0f32);
|
||||
unsafe {
|
||||
for (y, x) in ys.iter().zip(xs.iter()) {
|
||||
let q4 = &x.qs;
|
||||
let q8 = &y.qs;
|
||||
|
||||
for j in 0..QK_K / 64 {
|
||||
let q4_1 = v128_load(q4.as_ptr().add(32 * j) as *const v128);
|
||||
let q4_2 = v128_load(q4.as_ptr().add(32 * j + 16) as *const v128);
|
||||
v128_store(
|
||||
aux8.as_mut_ptr().add(64 * j) as *mut v128,
|
||||
v128_and(q4_1, u8x16_splat(0x0F)),
|
||||
);
|
||||
v128_store(
|
||||
aux8.as_mut_ptr().add(64 * j + 16) as *mut v128,
|
||||
v128_and(q4_2, u8x16_splat(0x0F)),
|
||||
);
|
||||
v128_store(
|
||||
aux8.as_mut_ptr().add(64 * j + 32) as *mut v128,
|
||||
u8x16_shr(q4_1, 4),
|
||||
);
|
||||
v128_store(
|
||||
aux8.as_mut_ptr().add(64 * j + 48) as *mut v128,
|
||||
u8x16_shr(q4_2, 4),
|
||||
);
|
||||
}
|
||||
|
||||
LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
|
||||
|
||||
utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
|
||||
let uaux = utmp[1] & KMASK1;
|
||||
utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
|
||||
utmp[2] = uaux;
|
||||
utmp[0] &= KMASK1;
|
||||
|
||||
//extract scales and mins
|
||||
LittleEndian::write_u32_into(&utmp[0..2], &mut scales);
|
||||
LittleEndian::write_u32_into(&utmp[2..4], &mut mins);
|
||||
|
||||
let mut sumi = i32x4_splat(0);
|
||||
for j in (0..QK_K / 16).step_by(4) {
|
||||
let bsums = i32x4_load_extend_i16x4(y.bsums.as_ptr().add(j));
|
||||
let (m1, m2) = (mins[j / 2] as i32, mins[j / 2 + 1] as i32);
|
||||
let mins = i32x4(m1, m1, m2, m2);
|
||||
sumi = i32x4_add(sumi, i32x4_mul(bsums, mins));
|
||||
}
|
||||
|
||||
let mut aux32 = i32x4_splat(0i32);
|
||||
for (scale_i, scale) in scales.iter().enumerate() {
|
||||
let scale = i32x4_splat(*scale as i32);
|
||||
for j in 0..4 {
|
||||
let i = 32 * scale_i + 8 * j;
|
||||
let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(i));
|
||||
let aux8 = i16x8_load_extend_u8x8(aux8.as_ptr().add(i));
|
||||
let aux16 = i16x8_mul(q8, aux8);
|
||||
aux32 = i32x4_add(aux32, i32x4_mul(scale, i32x4_extend_low_i16x8(aux16)));
|
||||
aux32 = i32x4_add(aux32, i32x4_mul(scale, i32x4_extend_high_i16x8(aux16)));
|
||||
}
|
||||
}
|
||||
let aux32 = f32x4_convert_i32x4(aux32);
|
||||
let d = f32x4_splat(x.d.to_f32() * y.d);
|
||||
sums = f32x4_add(sums, f32x4_mul(aux32, d));
|
||||
let dmin = x.dmin.to_f32() * y.d;
|
||||
let dmin = f32x4_splat(dmin);
|
||||
let sumi = f32x4_convert_i32x4(sumi);
|
||||
sums = f32x4_sub(sums, f32x4_mul(sumi, dmin));
|
||||
}
|
||||
let sums = f32x4_extract_lane::<0>(sums)
|
||||
+ f32x4_extract_lane::<1>(sums)
|
||||
+ f32x4_extract_lane::<2>(sums)
|
||||
+ f32x4_extract_lane::<3>(sums);
|
||||
Ok(sums)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
|
||||
if n % QK_K != 0 {
|
||||
crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}")
|
||||
}
|
||||
|
||||
let mut aux8 = [0i8; QK_K];
|
||||
unsafe {
|
||||
let mut sums = f32x4_splat(0f32);
|
||||
|
||||
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||
let q4 = &x.ql;
|
||||
let qh = &x.qh;
|
||||
let q8 = &y.qs;
|
||||
let mut aux32 = f32x4_splat(0f32);
|
||||
|
||||
for j in (0..QK_K).step_by(128) {
|
||||
let aux8 = aux8.as_mut_ptr().add(j);
|
||||
let q4 = &q4.as_ptr().add(j / 2);
|
||||
let qh = &qh.as_ptr().add(j / 4);
|
||||
for l in (0..32).step_by(16) {
|
||||
// aux8[l] = (((q4[l] & 0xF) | ((qh[l] & 3) << 4)) as i32 - 32) as i8;
|
||||
let a8 = v128_or(
|
||||
v128_and(v128_load(q4.add(l) as *const v128), u8x16_splat(0xF)),
|
||||
u8x16_shl(
|
||||
v128_and(v128_load(qh.add(l) as *const v128), u8x16_splat(3)),
|
||||
4,
|
||||
),
|
||||
);
|
||||
let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
|
||||
let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
|
||||
v128_store(
|
||||
aux8.add(l) as *mut v128,
|
||||
i8x16_narrow_i16x8(a8_low, a8_high),
|
||||
);
|
||||
|
||||
// aux8[l + 32] =
|
||||
// (((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i32 - 32) as i8;
|
||||
let a8 = v128_or(
|
||||
v128_and(v128_load(q4.add(l + 32) as *const v128), u8x16_splat(0xF)),
|
||||
u8x16_shl(
|
||||
v128_and(
|
||||
u8x16_shr(v128_load(qh.add(l) as *const v128), 2),
|
||||
u8x16_splat(3),
|
||||
),
|
||||
4,
|
||||
),
|
||||
);
|
||||
let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
|
||||
let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
|
||||
v128_store(
|
||||
aux8.add(l + 32) as *mut v128,
|
||||
i8x16_narrow_i16x8(a8_low, a8_high),
|
||||
);
|
||||
|
||||
// aux8[l + 64] = (((q4[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i32 - 32) as i8;
|
||||
let a8 = v128_or(
|
||||
u8x16_shr(v128_load(q4.add(l) as *const v128), 4),
|
||||
u8x16_shl(
|
||||
v128_and(
|
||||
u8x16_shr(v128_load(qh.add(l) as *const v128), 4),
|
||||
u8x16_splat(3),
|
||||
),
|
||||
4,
|
||||
),
|
||||
);
|
||||
let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
|
||||
let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
|
||||
v128_store(
|
||||
aux8.add(l + 64) as *mut v128,
|
||||
i8x16_narrow_i16x8(a8_low, a8_high),
|
||||
);
|
||||
|
||||
// aux8[l + 96] =
|
||||
// (((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i32 - 32) as i8;
|
||||
let a8 = v128_or(
|
||||
u8x16_shr(v128_load(q4.add(l + 32) as *const v128), 4),
|
||||
u8x16_shl(
|
||||
v128_and(
|
||||
u8x16_shr(v128_load(qh.add(l) as *const v128), 6),
|
||||
u8x16_splat(3),
|
||||
),
|
||||
4,
|
||||
),
|
||||
);
|
||||
let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
|
||||
let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
|
||||
v128_store(
|
||||
aux8.add(l + 96) as *mut v128,
|
||||
i8x16_narrow_i16x8(a8_low, a8_high),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
for (j, &scale) in x.scales.iter().enumerate() {
|
||||
let scale = f32x4_splat(scale as f32);
|
||||
for offset in [0, 8] {
|
||||
let aux16 = i16x8_mul(
|
||||
i16x8_load_extend_i8x8(q8.as_ptr().add(16 * j + offset)),
|
||||
i16x8_load_extend_i8x8(aux8.as_ptr().add(16 * j + offset)),
|
||||
);
|
||||
aux32 = f32x4_add(
|
||||
aux32,
|
||||
f32x4_mul(f32x4_convert_i32x4(i32x4_extend_low_i16x8(aux16)), scale),
|
||||
);
|
||||
aux32 = f32x4_add(
|
||||
aux32,
|
||||
f32x4_mul(f32x4_convert_i32x4(i32x4_extend_high_i16x8(aux16)), scale),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let d = f32x4_splat(x.d.to_f32() * y.d);
|
||||
sums = f32x4_add(sums, f32x4_mul(aux32, d));
|
||||
}
|
||||
let sums = f32x4_extract_lane::<0>(sums)
|
||||
+ f32x4_extract_lane::<1>(sums)
|
||||
+ f32x4_extract_lane::<2>(sums)
|
||||
+ f32x4_extract_lane::<3>(sums);
|
||||
Ok(sums)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
|
||||
let qk = QK_K;
|
||||
if n % QK_K != 0 {
|
||||
crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
|
||||
}
|
||||
|
||||
unsafe {
|
||||
let mut acc = f32x4_splat(0.0f32);
|
||||
for (xs, ys) in xs.iter().zip(ys.iter()) {
|
||||
let x_qs = xs.qs.as_ptr();
|
||||
let y_qs = ys.qs.as_ptr();
|
||||
let mut sumi = i32x4_splat(0);
|
||||
for j in (0..QK_K).step_by(8) {
|
||||
let xs = i16x8_load_extend_i8x8(x_qs.add(j));
|
||||
let ys = i16x8_load_extend_i8x8(y_qs.add(j));
|
||||
let sum_xy = i32x4_dot_i16x8(xs, ys);
|
||||
sumi = i32x4_add(sumi, sum_xy)
|
||||
}
|
||||
let d = f32x4_splat(xs.d * ys.d);
|
||||
acc = f32x4_add(acc, f32x4_mul(f32x4_convert_i32x4(sumi), d))
|
||||
}
|
||||
let res = f32x4_extract_lane::<0>(acc)
|
||||
+ f32x4_extract_lane::<1>(acc)
|
||||
+ f32x4_extract_lane::<2>(acc)
|
||||
+ f32x4_extract_lane::<3>(acc);
|
||||
Ok(res)
|
||||
}
|
||||
}
|
@ -17,7 +17,7 @@ pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>(
|
||||
let expected_blocks = xs.len() / block_size;
|
||||
let actual_blocks = ys.len();
|
||||
|
||||
//validate that the input is the right size
|
||||
// Validate that the input is the right size
|
||||
if expected_blocks != actual_blocks {
|
||||
crate::bail!("quantize {dtype:?}: expected {expected_blocks} blocks but only {actual_blocks} were provided!")
|
||||
}
|
||||
@ -37,12 +37,12 @@ pub(super) fn group_for_dequantization<'a, 'b, T: super::k_quants::GgmlType>(
|
||||
|
||||
let actual_output_len = ys.len();
|
||||
let expected_output_len = xs.len() * block_size;
|
||||
//validate that the output is the right size
|
||||
// Validate that the output is the right size
|
||||
if expected_output_len != actual_output_len {
|
||||
crate::bail!("dequantize {dtype:?}: ys (len = {actual_output_len}) does not match the expected length of {expected_output_len}!")
|
||||
}
|
||||
|
||||
//zip the blocks and outputs together
|
||||
// Zip the blocks and outputs together
|
||||
Ok(xs.iter().zip(ys.chunks_exact_mut(block_size)).collect())
|
||||
}
|
||||
|
||||
|
@ -251,6 +251,134 @@ pub fn save<K: AsRef<str> + Ord + std::fmt::Display, P: AsRef<Path>>(
|
||||
Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?)
|
||||
}
|
||||
|
||||
#[derive(yoke::Yokeable)]
|
||||
struct SafeTensors_<'a>(SafeTensors<'a>);
|
||||
|
||||
pub struct MmapedSafetensors {
|
||||
safetensors: Vec<yoke::Yoke<SafeTensors_<'static>, memmap2::Mmap>>,
|
||||
routing: Option<HashMap<String, usize>>,
|
||||
}
|
||||
|
||||
impl MmapedSafetensors {
|
||||
/// Creates a wrapper around a memory mapped file and deserialize the safetensors header.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// The unsafe is inherited from [`memmap2::MmapOptions`].
|
||||
pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
|
||||
let p = p.as_ref();
|
||||
let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
|
||||
let file = memmap2::MmapOptions::new()
|
||||
.map(&file)
|
||||
.map_err(|e| Error::from(e).with_path(p))?;
|
||||
let safetensors = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
|
||||
file,
|
||||
|data: &[u8]| {
|
||||
let st = safetensors::SafeTensors::deserialize(data)
|
||||
.map_err(|e| Error::from(e).with_path(p))?;
|
||||
Ok::<_, Error>(SafeTensors_(st))
|
||||
},
|
||||
)?;
|
||||
Ok(Self {
|
||||
safetensors: vec![safetensors],
|
||||
routing: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Creates a wrapper around multiple memory mapped file and deserialize the safetensors headers.
|
||||
///
|
||||
/// If a tensor name appears in multiple files, the last entry is returned.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// The unsafe is inherited from [`memmap2::MmapOptions`].
|
||||
pub unsafe fn multi<P: AsRef<Path>>(paths: &[P]) -> Result<Self> {
|
||||
let mut routing = HashMap::new();
|
||||
let mut safetensors = vec![];
|
||||
for (index, p) in paths.iter().enumerate() {
|
||||
let p = p.as_ref();
|
||||
let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
|
||||
let file = memmap2::MmapOptions::new()
|
||||
.map(&file)
|
||||
.map_err(|e| Error::from(e).with_path(p))?;
|
||||
let data = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
|
||||
file,
|
||||
|data: &[u8]| {
|
||||
let st = safetensors::SafeTensors::deserialize(data)
|
||||
.map_err(|e| Error::from(e).with_path(p))?;
|
||||
Ok::<_, Error>(SafeTensors_(st))
|
||||
},
|
||||
)?;
|
||||
for k in data.get().0.names() {
|
||||
routing.insert(k.to_string(), index);
|
||||
}
|
||||
safetensors.push(data)
|
||||
}
|
||||
Ok(Self {
|
||||
safetensors,
|
||||
routing: Some(routing),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
|
||||
self.get(name)?.load(dev)
|
||||
}
|
||||
|
||||
pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
|
||||
let mut tensors = vec![];
|
||||
for safetensors in self.safetensors.iter() {
|
||||
tensors.push(safetensors.get().0.tensors())
|
||||
}
|
||||
tensors.into_iter().flatten().collect()
|
||||
}
|
||||
|
||||
pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
|
||||
let index = match &self.routing {
|
||||
None => 0,
|
||||
Some(routing) => {
|
||||
let index = routing.get(name).ok_or_else(|| {
|
||||
Error::CannotFindTensor {
|
||||
path: name.to_string(),
|
||||
}
|
||||
.bt()
|
||||
})?;
|
||||
*index
|
||||
}
|
||||
};
|
||||
Ok(self.safetensors[index].get().0.tensor(name)?)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct BufferedSafetensors {
|
||||
safetensors: yoke::Yoke<SafeTensors_<'static>, Vec<u8>>,
|
||||
}
|
||||
|
||||
impl BufferedSafetensors {
|
||||
/// Creates a wrapper around a binary buffer and deserialize the safetensors header.
|
||||
pub fn new(buffer: Vec<u8>) -> Result<Self> {
|
||||
let safetensors = yoke::Yoke::<SafeTensors_<'static>, Vec<u8>>::try_attach_to_cart(
|
||||
buffer,
|
||||
|data: &[u8]| {
|
||||
let st = safetensors::SafeTensors::deserialize(data)?;
|
||||
Ok::<_, Error>(SafeTensors_(st))
|
||||
},
|
||||
)?;
|
||||
Ok(Self { safetensors })
|
||||
}
|
||||
|
||||
pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
|
||||
self.get(name)?.load(dev)
|
||||
}
|
||||
|
||||
pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
|
||||
self.safetensors.get().0.tensors()
|
||||
}
|
||||
|
||||
pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
|
||||
Ok(self.safetensors.get().0.tensor(name)?)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MmapedFile {
|
||||
path: std::path::PathBuf,
|
||||
inner: memmap2::Mmap,
|
||||
|
@ -511,154 +511,119 @@ impl ShapeWithOneHole for ((),) {
|
||||
}
|
||||
}
|
||||
|
||||
fn hole_size(el_count: usize, prod_d: usize, s: &dyn std::fmt::Debug) -> Result<usize> {
|
||||
if prod_d == 0 {
|
||||
crate::bail!("cannot reshape tensor of {el_count} elements to {s:?}")
|
||||
}
|
||||
if el_count % prod_d != 0 {
|
||||
crate::bail!("cannot reshape tensor with {el_count} elements to {s:?}")
|
||||
}
|
||||
Ok(el_count / prod_d)
|
||||
}
|
||||
|
||||
impl ShapeWithOneHole for ((), usize) {
|
||||
fn into_shape(self, el_count: usize) -> Result<Shape> {
|
||||
let ((), d1) = self;
|
||||
if el_count % d1 != 0 {
|
||||
crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
|
||||
}
|
||||
Ok((el_count / d1, d1).into())
|
||||
Ok((hole_size(el_count, d1, &self)?, d1).into())
|
||||
}
|
||||
}
|
||||
|
||||
impl ShapeWithOneHole for (usize, ()) {
|
||||
fn into_shape(self, el_count: usize) -> Result<Shape> {
|
||||
let (d1, ()) = self;
|
||||
if el_count % d1 != 0 {
|
||||
crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
|
||||
}
|
||||
Ok((d1, el_count / d1).into())
|
||||
Ok((d1, hole_size(el_count, d1, &self)?).into())
|
||||
}
|
||||
}
|
||||
|
||||
impl ShapeWithOneHole for ((), usize, usize) {
|
||||
fn into_shape(self, el_count: usize) -> Result<Shape> {
|
||||
let ((), d1, d2) = self;
|
||||
let d = d1 * d2;
|
||||
if el_count % d != 0 {
|
||||
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
|
||||
}
|
||||
Ok((el_count / d, d1, d2).into())
|
||||
Ok((hole_size(el_count, d1 * d2, &self)?, d1, d2).into())
|
||||
}
|
||||
}
|
||||
|
||||
impl ShapeWithOneHole for (usize, (), usize) {
|
||||
fn into_shape(self, el_count: usize) -> Result<Shape> {
|
||||
let (d1, (), d2) = self;
|
||||
let d = d1 * d2;
|
||||
if el_count % d != 0 {
|
||||
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
|
||||
}
|
||||
Ok((d1, el_count / d, d2).into())
|
||||
Ok((d1, hole_size(el_count, d1 * d2, &self)?, d2).into())
|
||||
}
|
||||
}
|
||||
|
||||
impl ShapeWithOneHole for (usize, usize, ()) {
|
||||
fn into_shape(self, el_count: usize) -> Result<Shape> {
|
||||
let (d1, d2, ()) = self;
|
||||
let d = d1 * d2;
|
||||
if el_count % d != 0 {
|
||||
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
|
||||
}
|
||||
Ok((d1, d2, el_count / d).into())
|
||||
Ok((d1, d2, hole_size(el_count, d1 * d2, &self)?).into())
|
||||
}
|
||||
}
|
||||
|
||||
impl ShapeWithOneHole for ((), usize, usize, usize) {
|
||||
fn into_shape(self, el_count: usize) -> Result<Shape> {
|
||||
let ((), d1, d2, d3) = self;
|
||||
let d = d1 * d2 * d3;
|
||||
if el_count % d != 0 {
|
||||
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
|
||||
}
|
||||
Ok((el_count / d, d1, d2, d3).into())
|
||||
let d = hole_size(el_count, d1 * d2 * d3, &self)?;
|
||||
Ok((d, d1, d2, d3).into())
|
||||
}
|
||||
}
|
||||
|
||||
impl ShapeWithOneHole for (usize, (), usize, usize) {
|
||||
fn into_shape(self, el_count: usize) -> Result<Shape> {
|
||||
let (d1, (), d2, d3) = self;
|
||||
let d = d1 * d2 * d3;
|
||||
if el_count % d != 0 {
|
||||
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
|
||||
}
|
||||
Ok((d1, el_count / d, d2, d3).into())
|
||||
let d = hole_size(el_count, d1 * d2 * d3, &self)?;
|
||||
Ok((d1, d, d2, d3).into())
|
||||
}
|
||||
}
|
||||
|
||||
impl ShapeWithOneHole for (usize, usize, (), usize) {
|
||||
fn into_shape(self, el_count: usize) -> Result<Shape> {
|
||||
let (d1, d2, (), d3) = self;
|
||||
let d = d1 * d2 * d3;
|
||||
if el_count % d != 0 {
|
||||
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
|
||||
}
|
||||
Ok((d1, d2, el_count / d, d3).into())
|
||||
let d = hole_size(el_count, d1 * d2 * d3, &self)?;
|
||||
Ok((d1, d2, d, d3).into())
|
||||
}
|
||||
}
|
||||
|
||||
impl ShapeWithOneHole for (usize, usize, usize, ()) {
|
||||
fn into_shape(self, el_count: usize) -> Result<Shape> {
|
||||
let (d1, d2, d3, ()) = self;
|
||||
let d = d1 * d2 * d3;
|
||||
if el_count % d != 0 {
|
||||
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
|
||||
}
|
||||
Ok((d1, d2, d3, el_count / d).into())
|
||||
let d = hole_size(el_count, d1 * d2 * d3, &self)?;
|
||||
Ok((d1, d2, d3, d).into())
|
||||
}
|
||||
}
|
||||
|
||||
impl ShapeWithOneHole for ((), usize, usize, usize, usize) {
|
||||
fn into_shape(self, el_count: usize) -> Result<Shape> {
|
||||
let ((), d1, d2, d3, d4) = self;
|
||||
let d = d1 * d2 * d3 * d4;
|
||||
if el_count % d != 0 {
|
||||
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
|
||||
}
|
||||
Ok((el_count / d, d1, d2, d3, d4).into())
|
||||
let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
|
||||
Ok((d, d1, d2, d3, d4).into())
|
||||
}
|
||||
}
|
||||
|
||||
impl ShapeWithOneHole for (usize, (), usize, usize, usize) {
|
||||
fn into_shape(self, el_count: usize) -> Result<Shape> {
|
||||
let (d1, (), d2, d3, d4) = self;
|
||||
let d = d1 * d2 * d3 * d4;
|
||||
if el_count % d != 0 {
|
||||
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
|
||||
}
|
||||
Ok((d1, el_count / d, d2, d3, d4).into())
|
||||
let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
|
||||
Ok((d1, d, d2, d3, d4).into())
|
||||
}
|
||||
}
|
||||
|
||||
impl ShapeWithOneHole for (usize, usize, (), usize, usize) {
|
||||
fn into_shape(self, el_count: usize) -> Result<Shape> {
|
||||
let (d1, d2, (), d3, d4) = self;
|
||||
let d = d1 * d2 * d3 * d4;
|
||||
if el_count % d != 0 {
|
||||
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
|
||||
}
|
||||
Ok((d1, d2, el_count / d, d3, d4).into())
|
||||
let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
|
||||
Ok((d1, d2, d, d3, d4).into())
|
||||
}
|
||||
}
|
||||
|
||||
impl ShapeWithOneHole for (usize, usize, usize, (), usize) {
|
||||
fn into_shape(self, el_count: usize) -> Result<Shape> {
|
||||
let (d1, d2, d3, (), d4) = self;
|
||||
let d = d1 * d2 * d3 * d4;
|
||||
if el_count % d != 0 {
|
||||
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
|
||||
}
|
||||
Ok((d1, d2, d3, el_count / d, d4).into())
|
||||
let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
|
||||
Ok((d1, d2, d3, d, d4).into())
|
||||
}
|
||||
}
|
||||
|
||||
impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
|
||||
fn into_shape(self, el_count: usize) -> Result<Shape> {
|
||||
let (d1, d2, d3, d4, ()) = self;
|
||||
let d = d1 * d2 * d3 * d4;
|
||||
if el_count % d != 0 {
|
||||
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
|
||||
}
|
||||
Ok((d1, d2, d3, d4, el_count / d).into())
|
||||
let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
|
||||
Ok((d1, d2, d3, d4, d).into())
|
||||
}
|
||||
}
|
||||
|
@ -177,14 +177,9 @@ impl Tensor {
|
||||
is_variable: bool,
|
||||
) -> Result<Self> {
|
||||
let none = BackpropOp::none();
|
||||
if is_variable {
|
||||
let shape = shape.into();
|
||||
let storage = device.ones(&shape, dtype)?;
|
||||
Ok(from_storage(storage, shape, none, is_variable))
|
||||
} else {
|
||||
let storage = device.ones(&crate::shape::SCALAR, dtype)?;
|
||||
from_storage(storage, crate::shape::SCALAR, none, is_variable).broadcast_as(shape)
|
||||
}
|
||||
let shape = shape.into();
|
||||
let storage = device.ones(&shape, dtype)?;
|
||||
Ok(from_storage(storage, shape, none, is_variable))
|
||||
}
|
||||
|
||||
/// Creates a new tensor filled with ones.
|
||||
@ -222,14 +217,9 @@ impl Tensor {
|
||||
is_variable: bool,
|
||||
) -> Result<Self> {
|
||||
let none = BackpropOp::none();
|
||||
if is_variable {
|
||||
let shape = shape.into();
|
||||
let storage = device.zeros(&shape, dtype)?;
|
||||
Ok(from_storage(storage, shape, none, is_variable))
|
||||
} else {
|
||||
let storage = device.zeros(&crate::shape::SCALAR, dtype)?;
|
||||
from_storage(storage, crate::shape::SCALAR, none, is_variable).broadcast_as(shape)
|
||||
}
|
||||
let shape = shape.into();
|
||||
let storage = device.zeros(&shape, dtype)?;
|
||||
Ok(from_storage(storage, shape, none, is_variable))
|
||||
}
|
||||
|
||||
/// Creates a new tensor filled with zeros.
|
||||
@ -459,7 +449,7 @@ impl Tensor {
|
||||
|
||||
/// Returns true if the computation graph should track this op, that is if it is
|
||||
/// a variable or if it has some variable as dependencies.
|
||||
pub(crate) fn track_op(&self) -> bool {
|
||||
pub fn track_op(&self) -> bool {
|
||||
self.is_variable || self.op.is_some()
|
||||
}
|
||||
|
||||
@ -489,7 +479,21 @@ impl Tensor {
|
||||
unary_op!(sqr, Sqr);
|
||||
unary_op!(sqrt, Sqrt);
|
||||
unary_op!(gelu, Gelu);
|
||||
unary_op!(gelu_erf, GeluErf);
|
||||
unary_op!(erf, Erf);
|
||||
unary_op!(relu, Relu);
|
||||
unary_op!(ceil, Ceil);
|
||||
unary_op!(floor, Floor);
|
||||
unary_op!(round, Round);
|
||||
|
||||
/// Round element of the input tensor to the nearest integer.
|
||||
///
|
||||
/// If the number of decimals is negative, it specifies the number of positions to the left of
|
||||
/// the decimal point.
|
||||
pub fn round_to(&self, decimals: i32) -> Result<Self> {
|
||||
let mult = 10f64.powi(decimals);
|
||||
(self * mult)?.round()? * (1f64 / mult)
|
||||
}
|
||||
|
||||
/// Retrieves the single scalar value hold in the tensor. If the tensor contains multiple
|
||||
/// dimensions, an error is returned instead.
|
||||
@ -536,6 +540,73 @@ impl Tensor {
|
||||
Ok(inp)
|
||||
}
|
||||
|
||||
/// Creates grids of coordinates specified by the 1D inputs.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `args` - A slice of 1D tensors.
|
||||
/// * `xy_indexing` - Whether to use xy indexing or ij indexing. If xy is selected, the
|
||||
/// first dimension corresponds to the cardinality of the second input and the second
|
||||
/// dimension corresponds to the cardinality of the first input. If ij is selected, the
|
||||
/// dimensions are in the same order as the cardinality of the inputs.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust
|
||||
/// use candle_core::{Tensor, Device, Shape};
|
||||
/// let x = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
|
||||
/// let y = Tensor::new(&[4f32, 5., 6.], &Device::Cpu)?;
|
||||
///
|
||||
/// let grids_xy = Tensor::meshgrid(&[&x, &y], true)?;
|
||||
///
|
||||
/// assert_eq!(grids_xy.len(), 2);
|
||||
/// assert_eq!(grids_xy[0].dims(), &[3, 3]);
|
||||
///
|
||||
/// assert_eq!(grids_xy[0].to_vec2::<f32>()?, &[[1., 2., 3.], [1., 2., 3.], [1., 2., 3.]]);
|
||||
/// assert_eq!(grids_xy[1].to_vec2::<f32>()?, &[[4., 4., 4.], [5., 5., 5.], [6., 6., 6.]]);
|
||||
///
|
||||
/// let grids_ij = Tensor::meshgrid(&[&x, &y], false)?;
|
||||
///
|
||||
/// assert_eq!(grids_ij[0].to_vec2::<f32>()?, &[[1., 1., 1.], [2., 2., 2.], [3., 3., 3.]]);
|
||||
/// assert_eq!(grids_ij[1].to_vec2::<f32>()?, &[[4., 5., 6.], [4., 5., 6.], [4., 5., 6.]]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// * Will return `Err` if `args` contains less than 2 tensors.
|
||||
///
|
||||
pub fn meshgrid<A: AsRef<Tensor>>(args: &[A], xy_indexing: bool) -> Result<Vec<Self>> {
|
||||
if args.len() <= 1 {
|
||||
Err(Error::OpRequiresAtLeastTwoTensors { op: "meshgrid" }.bt())?
|
||||
}
|
||||
let args: Vec<_> = if xy_indexing {
|
||||
args.iter().rev().collect()
|
||||
} else {
|
||||
args.iter().collect()
|
||||
};
|
||||
|
||||
let mut shape = Vec::with_capacity(args.len());
|
||||
for arg in args.iter() {
|
||||
shape.push(arg.as_ref().dims1()?)
|
||||
}
|
||||
|
||||
let mut grids = Vec::with_capacity(args.len());
|
||||
for idx in 0..args.len() {
|
||||
let mut ones = vec![1usize; args.len()];
|
||||
ones[idx] = shape[idx];
|
||||
let arg = args[idx].as_ref().reshape(ones)?;
|
||||
let mut repeats = shape.clone();
|
||||
repeats[idx] = 1;
|
||||
let repeated_tensor = arg.repeat(repeats)?;
|
||||
grids.push(repeated_tensor);
|
||||
}
|
||||
if xy_indexing {
|
||||
grids.reverse();
|
||||
}
|
||||
Ok(grids)
|
||||
}
|
||||
|
||||
/// This operation multiplies the input tensor by `mul` then adds `add` and return the result.
|
||||
/// The input values `mul` and `add` are casted to the appropriate type so some rounding might
|
||||
/// be performed.
|
||||
@ -611,15 +682,23 @@ impl Tensor {
|
||||
pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
|
||||
let dims = self.dims();
|
||||
let dim = dim.to_index(self.shape(), "narrow")?;
|
||||
if start + len > dims[dim] {
|
||||
Err(Error::NarrowInvalidArgs {
|
||||
shape: self.shape().clone(),
|
||||
dim,
|
||||
start,
|
||||
len,
|
||||
msg: "start + len > dim_len",
|
||||
}
|
||||
.bt())?
|
||||
let err = |msg| {
|
||||
Err::<(), _>(
|
||||
Error::NarrowInvalidArgs {
|
||||
shape: self.shape().clone(),
|
||||
dim,
|
||||
start,
|
||||
len,
|
||||
msg,
|
||||
}
|
||||
.bt(),
|
||||
)
|
||||
};
|
||||
if start > dims[dim] {
|
||||
err("start > dim_len")?
|
||||
}
|
||||
if start.saturating_add(len) > dims[dim] {
|
||||
err("start + len > dim_len")?
|
||||
}
|
||||
if start == 0 && dims[dim] == len {
|
||||
Ok(self.clone())
|
||||
@ -1130,6 +1209,74 @@ impl Tensor {
|
||||
Ok(from_storage(storage, self.shape(), op, false))
|
||||
}
|
||||
|
||||
/// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension.
|
||||
pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: D, start: usize) -> Result<Self> {
|
||||
let dim = dim.to_index(self.shape(), "slice-scatter")?;
|
||||
if dim == 0 {
|
||||
self.slice_scatter0(src, start)
|
||||
} else {
|
||||
// TODO: Maybe we want to add a more efficient implementation at some point.
|
||||
self.transpose(0, dim)?
|
||||
.slice_scatter0(&src.transpose(0, dim)?, start)?
|
||||
.transpose(0, dim)
|
||||
}
|
||||
}
|
||||
|
||||
/// Embeds the values of the `src` tensor into the `self` tensor on the first dimension.
|
||||
pub fn slice_scatter0(&self, src: &Self, start: usize) -> Result<Self> {
|
||||
if self.dtype() != src.dtype() {
|
||||
Err(Error::DTypeMismatchBinaryOp {
|
||||
lhs: self.dtype(),
|
||||
rhs: src.dtype(),
|
||||
op: "slice-scatter",
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
if self.device().location() != src.device.location() {
|
||||
Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: self.device().location(),
|
||||
rhs: src.device().location(),
|
||||
op: "slice-scatter",
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
if self.rank() != src.rank() {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: self.rank(),
|
||||
got: src.rank(),
|
||||
shape: src.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
let shape_ok =
|
||||
self.dims()
|
||||
.iter()
|
||||
.zip(src.dims().iter())
|
||||
.enumerate()
|
||||
.all(|(dim_idx, (&d1, &d2))| {
|
||||
if 0 == dim_idx {
|
||||
d2 + start <= d1
|
||||
} else {
|
||||
d1 == d2
|
||||
}
|
||||
});
|
||||
if !shape_ok {
|
||||
Err(Error::ShapeMismatchBinaryOp {
|
||||
op: "slice-scatter (self, src)",
|
||||
lhs: self.shape().clone(),
|
||||
rhs: src.shape().clone(),
|
||||
})?
|
||||
}
|
||||
let mut storage = self.device().zeros(self.shape(), self.dtype())?;
|
||||
self.storage()
|
||||
.copy_strided_src(&mut storage, 0, self.layout())?;
|
||||
let offset = start * src.dims()[1..].iter().product::<usize>();
|
||||
src.storage()
|
||||
.copy_strided_src(&mut storage, offset, src.layout())?;
|
||||
let op = BackpropOp::new2(self, src, |t1, t2| Op::SliceScatter0(t1, t2, start));
|
||||
Ok(from_storage(storage, self.shape(), op, false))
|
||||
}
|
||||
|
||||
/// Accumulate element from `source` at indexes `indexes` and add them to `self`.
|
||||
pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
|
||||
let dim = dim.to_index(self.shape(), "index-add")?;
|
||||
@ -1518,6 +1665,24 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the sub-tensor fixing the index at `index` on the dimension `dim`.
|
||||
///
|
||||
/// ```rust
|
||||
/// use candle_core::{Tensor, Device};
|
||||
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
|
||||
/// let t = tensor.get_on_dim(1, 0)?;
|
||||
/// assert_eq!(t.to_vec1::<f32>()?, &[0., 2., 4.]);
|
||||
/// let t = tensor.get_on_dim(1, 1)?;
|
||||
/// assert_eq!(t.to_vec1::<f32>()?, &[1., 3., 5.]);
|
||||
/// let t = tensor.get_on_dim(0, 1)?;
|
||||
/// assert_eq!(t.to_vec1::<f32>()?, &[2., 3.]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn get_on_dim<D: Dim>(&self, dim: D, index: usize) -> Result<Tensor> {
|
||||
let dim = dim.to_index(self.shape(), "get_on_dim")?;
|
||||
self.narrow(dim, index, 1)?.squeeze(dim)
|
||||
}
|
||||
|
||||
/// Returns a tensor that is a transposed version of the input, the two last dimensions of the
|
||||
/// input are swapped.
|
||||
///
|
||||
@ -1546,6 +1711,9 @@ impl Tensor {
|
||||
pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor> {
|
||||
let dim1 = dim1.to_index(self.shape(), "transpose")?;
|
||||
let dim2 = dim2.to_index(self.shape(), "transpose")?;
|
||||
if dim1 == dim2 {
|
||||
return Ok(self.clone());
|
||||
}
|
||||
let op = BackpropOp::new1(self, |t| Op::Transpose(t, dim1, dim2));
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
@ -1907,6 +2075,34 @@ impl Tensor {
|
||||
for arg in args {
|
||||
arg.as_ref().check_dim(dim, "cat")?;
|
||||
}
|
||||
for (arg_idx, arg) in args.iter().enumerate() {
|
||||
let arg = arg.as_ref();
|
||||
if arg0.rank() != arg.rank() {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: arg0.rank(),
|
||||
got: arg.rank(),
|
||||
shape: arg.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
for (dim_idx, (v1, v2)) in arg0
|
||||
.shape()
|
||||
.dims()
|
||||
.iter()
|
||||
.zip(arg.shape().dims().iter())
|
||||
.enumerate()
|
||||
{
|
||||
if dim_idx != dim && v1 != v2 {
|
||||
Err(Error::ShapeMismatchCat {
|
||||
dim: dim_idx,
|
||||
first_shape: arg0.shape().clone(),
|
||||
n: arg_idx + 1,
|
||||
nth_shape: arg.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
}
|
||||
}
|
||||
if dim == 0 {
|
||||
Self::cat0(args)
|
||||
} else {
|
||||
@ -2024,6 +2220,46 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Pad the input tensor using same values along dimension `dim`. This adds `left` elements before the
|
||||
/// input tensor values and `right` elements after.
|
||||
pub fn pad_with_same<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
|
||||
if left == 0 && right == 0 {
|
||||
Ok(self.clone())
|
||||
} else if self.elem_count() == 0 {
|
||||
crate::bail!("cannot use pad_with_same on an empty tensor")
|
||||
} else if left == 0 {
|
||||
let dim = dim.to_index(self.shape(), "pad_with_same")?;
|
||||
let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
|
||||
let mut v = vec![self];
|
||||
for _ in 0..right {
|
||||
v.push(&r)
|
||||
}
|
||||
Tensor::cat(&v, dim)
|
||||
} else if right == 0 {
|
||||
let dim = dim.to_index(self.shape(), "pad_with_same")?;
|
||||
let l = self.narrow(dim, 0, 1)?;
|
||||
let mut v = vec![];
|
||||
for _ in 0..left {
|
||||
v.push(&l)
|
||||
}
|
||||
v.push(self);
|
||||
Tensor::cat(&v, dim)
|
||||
} else {
|
||||
let dim = dim.to_index(self.shape(), "pad_with_same")?;
|
||||
let l = self.narrow(dim, 0, 1)?;
|
||||
let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
|
||||
let mut v = vec![];
|
||||
for _ in 0..left {
|
||||
v.push(&l)
|
||||
}
|
||||
v.push(self);
|
||||
for _ in 0..right {
|
||||
v.push(&r)
|
||||
}
|
||||
Tensor::cat(&v, dim)
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the `forward` method of `m` on `self`.
|
||||
pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
|
||||
m.forward(self)
|
||||
|
@ -192,6 +192,19 @@ fn unary_grad(device: &Device) -> Result<()> {
|
||||
test_utils::to_vec1_round(grad_x, 2)?,
|
||||
[0.01, 0.42, 0.0, 0.98],
|
||||
);
|
||||
|
||||
// testing compared to pytorch nn.GELU(approximate = 'tanh')
|
||||
let y = x.gelu()?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&y, 4)?,
|
||||
[2.9964, 0.8412, 3.9999, 0.0839]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(grad_x, 4)?,
|
||||
[1.0116, 1.0830, 1.0003, 0.6188],
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -218,6 +231,22 @@ fn binary_grad(device: &Device) -> Result<()> {
|
||||
let grad_x = grads.get(x).context("no grad for x")?;
|
||||
assert_eq!(y.to_vec1::<f32>()?, [3., 1., -4., -1.]);
|
||||
assert_eq!(grad_x.to_vec1::<f32>()?, [1., 1., 1., 1.]);
|
||||
|
||||
let x_var = Var::new(&[3f32, 1., -4., -1., 5., 9.], device)?;
|
||||
let x = x_var.as_tensor();
|
||||
let y_var = Var::new(&[2f32, 7., 1.], device)?;
|
||||
let y = y_var.as_tensor();
|
||||
|
||||
let ss = x
|
||||
.reshape((2, 3))?
|
||||
.slice_scatter0(&y.reshape((1, 3))?, 1)?
|
||||
.sqr()?;
|
||||
let grads = ss.backward()?;
|
||||
let grad_x = grads.get(x).context("no grad for x")?;
|
||||
let grad_y = grads.get(y).context("no grad for y")?;
|
||||
assert_eq!(ss.to_vec2::<f32>()?, [[9., 1., 16.], [4., 49., 1.]]);
|
||||
assert_eq!(grad_x.to_vec1::<f32>()?, [6.0, 2.0, -8.0, 0.0, 0.0, 0.0]);
|
||||
assert_eq!(grad_y.to_vec1::<f32>()?, [4.0, 14.0, 2.0]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
9
candle-core/tests/npy.py
Normal file
9
candle-core/tests/npy.py
Normal file
@ -0,0 +1,9 @@
|
||||
import numpy as np
|
||||
x = np.arange(10)
|
||||
|
||||
# Write a npy file.
|
||||
np.save("test.npy", x)
|
||||
|
||||
# Write multiple values to a npz file.
|
||||
values = { "x": x, "x_plus_one": x + 1 }
|
||||
np.savez("test.npz", **values)
|
@ -43,7 +43,7 @@ fn quantized_matmul() -> Result<()> {
|
||||
);
|
||||
|
||||
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
|
||||
let matmul = quantized::QMatMul::from_qtensor(qtensor);
|
||||
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
|
||||
let res = matmul.forward(&tensor_lhs)?;
|
||||
assert_eq!(
|
||||
to_vec2_round(&res, 0)?,
|
||||
@ -91,7 +91,7 @@ fn quantized_matmul_neg() -> Result<()> {
|
||||
);
|
||||
|
||||
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
|
||||
let matmul = quantized::QMatMul::from_qtensor(qtensor);
|
||||
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
|
||||
let res = matmul.forward(&tensor_lhs)?;
|
||||
assert_eq!(
|
||||
to_vec2_round(&res, 0)?,
|
||||
@ -491,6 +491,9 @@ fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result<f32> {
|
||||
GgmlDType::Q5_0 => 0.001353,
|
||||
GgmlDType::Q5_1 => 0.001363,
|
||||
GgmlDType::Q8_0 => 0.000092,
|
||||
|
||||
// Not from the ggml repo.
|
||||
GgmlDType::Q8K => 0.00065,
|
||||
_ => candle_core::bail!("No GGML results for quantization type {dtype:?}",),
|
||||
};
|
||||
Ok(err)
|
||||
@ -508,17 +511,22 @@ fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
|
||||
T::VecDotType::from_float(&b, &mut b_quant)?;
|
||||
|
||||
let result = T::vec_dot(length, &a_quant, &b_quant)?;
|
||||
let result_unopt = T::vec_dot_unopt(length, &a_quant, &b_quant)?;
|
||||
let reference_result = vec_dot_reference(&a, &b);
|
||||
|
||||
if (result - result_unopt).abs() / length as f32 > 1e-6 {
|
||||
candle_core::bail!(
|
||||
"the opt and unopt vec-dot returned different values, opt {result}, unopt {result_unopt}"
|
||||
)
|
||||
}
|
||||
|
||||
let error = (result - reference_result).abs() / length as f32;
|
||||
|
||||
let ggml_error = ggml_reference_matmul_error(T::DTYPE)?;
|
||||
|
||||
if error > GGML_MAX_DOT_PRODUCT_ERROR {
|
||||
if !error.is_finite() || error > GGML_MAX_DOT_PRODUCT_ERROR {
|
||||
candle_core::bail!(
|
||||
"Dot product error {} exceeds max error {}",
|
||||
error,
|
||||
GGML_MAX_DOT_PRODUCT_ERROR
|
||||
"Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",
|
||||
);
|
||||
}
|
||||
|
||||
@ -571,7 +579,7 @@ fn quantized_matmul_q2k() -> Result<()> {
|
||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
|
||||
let rhs = quantized::QTensor::quantize::<BlockQ2K>(&rhs)?;
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs);
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||
let mm = rhs.forward(&lhs)?;
|
||||
|
||||
assert_eq!(mm.dims(), [m, n]);
|
||||
@ -597,7 +605,7 @@ fn quantized_matmul_q3k() -> Result<()> {
|
||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
|
||||
let rhs = quantized::QTensor::quantize::<BlockQ3K>(&rhs)?;
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs);
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||
let mm = rhs.forward(&lhs)?;
|
||||
|
||||
assert_eq!(mm.dims(), [m, n]);
|
||||
@ -623,7 +631,7 @@ fn quantized_matmul_q4k() -> Result<()> {
|
||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
|
||||
let rhs = quantized::QTensor::quantize::<BlockQ4K>(&rhs)?;
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs);
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||
let mm = rhs.forward(&lhs)?;
|
||||
|
||||
assert_eq!(mm.dims(), [m, n]);
|
||||
@ -649,7 +657,7 @@ fn quantized_matmul_q5k() -> Result<()> {
|
||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
|
||||
let rhs = quantized::QTensor::quantize::<BlockQ5K>(&rhs)?;
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs);
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||
let mm = rhs.forward(&lhs)?;
|
||||
|
||||
assert_eq!(mm.dims(), [m, n]);
|
||||
@ -676,7 +684,7 @@ fn quantized_matmul_q6k() -> Result<()> {
|
||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
|
||||
let rhs = quantized::QTensor::quantize::<BlockQ6K>(&rhs)?;
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs);
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||
let mm = rhs.forward(&lhs)?;
|
||||
|
||||
assert_eq!(mm.dims(), [m, n]);
|
||||
@ -687,3 +695,28 @@ fn quantized_matmul_q6k() -> Result<()> {
|
||||
ggml_matmul_error_test::<BlockQ6K>()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantized_matmul_q8k() -> Result<()> {
|
||||
use k_quants::BlockQ8K;
|
||||
|
||||
let cpu = &Device::Cpu;
|
||||
let (m, k, n) = (11, 512, 21);
|
||||
let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?;
|
||||
assert_eq!(mm.dims(), [m, n]);
|
||||
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
|
||||
let rhs = quantized::QTensor::quantize::<BlockQ8K>(&rhs)?;
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||
let mm = rhs.forward(&lhs)?;
|
||||
|
||||
assert_eq!(mm.dims(), [m, n]);
|
||||
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.266, 1.504, -0.204, 1.7]);
|
||||
|
||||
ggml_matmul_error_test::<BlockQ8K>()?;
|
||||
Ok(())
|
||||
}
|
||||
|
24
candle-core/tests/serialization_tests.rs
Normal file
24
candle-core/tests/serialization_tests.rs
Normal file
@ -0,0 +1,24 @@
|
||||
use candle_core::{DType, Result, Tensor};
|
||||
|
||||
#[test]
|
||||
fn npy() -> Result<()> {
|
||||
let npy = Tensor::read_npy("tests/test.npy")?;
|
||||
assert_eq!(
|
||||
npy.to_dtype(DType::U8)?.to_vec1::<u8>()?,
|
||||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn npz() -> Result<()> {
|
||||
let npz = Tensor::read_npz("tests/test.npz")?;
|
||||
assert_eq!(npz.len(), 2);
|
||||
assert_eq!(npz[0].0, "x");
|
||||
assert_eq!(npz[1].0, "x_plus_one");
|
||||
assert_eq!(
|
||||
npz[1].1.to_dtype(DType::U8)?.to_vec1::<u8>()?,
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
);
|
||||
Ok(())
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
use candle_core::{test_device, DType, Device, IndexOp, Result, Tensor};
|
||||
use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor};
|
||||
|
||||
fn zeros(device: &Device) -> Result<()> {
|
||||
let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
|
||||
@ -8,6 +8,31 @@ fn zeros(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn ones(device: &Device) -> Result<()> {
|
||||
assert_eq!(
|
||||
Tensor::ones((2, 3), DType::U8, device)?.to_vec2::<u8>()?,
|
||||
[[1, 1, 1], [1, 1, 1]],
|
||||
);
|
||||
assert_eq!(
|
||||
Tensor::ones((2, 3), DType::U32, device)?.to_vec2::<u32>()?,
|
||||
[[1, 1, 1], [1, 1, 1]],
|
||||
);
|
||||
assert_eq!(
|
||||
Tensor::ones((2, 3), DType::I64, device)?.to_vec2::<i64>()?,
|
||||
[[1, 1, 1], [1, 1, 1]],
|
||||
);
|
||||
assert_eq!(
|
||||
Tensor::ones((2, 3), DType::F32, device)?.to_vec2::<f32>()?,
|
||||
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
|
||||
);
|
||||
assert_eq!(
|
||||
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
|
||||
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn add_mul(device: &Device) -> Result<()> {
|
||||
let tensor = Tensor::new(&[3f32, 1., 4.], device)?;
|
||||
let dim1 = tensor.dims1()?;
|
||||
@ -44,6 +69,54 @@ fn clamp(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn unary_op(device: &Device) -> Result<()> {
|
||||
let data = &[[-3f32, 1., 4., -0.1, 0.5], [2.7, -1.8, -0.28, 1.8, 2.8]];
|
||||
let tensor = Tensor::new(data, device)?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&tensor.gelu()?, 4)?,
|
||||
[
|
||||
[-0.0036, 0.8412, 3.9999, -0.046, 0.3457],
|
||||
[2.6911, -0.0647, -0.1091, 1.7353, 2.7933]
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?,
|
||||
[
|
||||
[-0.004, 0.8413, 3.9999, -0.046, 0.3457],
|
||||
[2.6906, -0.0647, -0.1091, 1.7353, 2.7928]
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&tensor.erf()?, 4)?,
|
||||
[
|
||||
[-1.0, 0.8427, 1.0, -0.1125, 0.5205],
|
||||
[0.9999, -0.9891, -0.3079, 0.9891, 0.9999]
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&tensor.ceil()?, 4)?,
|
||||
[[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -1.0, -0.0, 2.0, 3.0]]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&tensor.floor()?, 4)?,
|
||||
[[-3.0, 1.0, 4.0, -1.0, 0.0], [2.0, -2.0, -1.0, 1.0, 2.0]]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&tensor.round()?, 4)?,
|
||||
[[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -2.0, -0.0, 2.0, 3.0]]
|
||||
);
|
||||
let tensor = Tensor::new(&[2997.9246, 314.15926f32], device)?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&tensor.round_to(2)?, 4)?,
|
||||
[2997.92, 314.16]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&tensor.round_to(-2)?, 4)?,
|
||||
[3000.0, 300.]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn binary_op(device: &Device) -> Result<()> {
|
||||
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
|
||||
let tensor1 = Tensor::new(data, device)?;
|
||||
@ -601,6 +674,30 @@ fn index_select(device: &Device) -> Result<()> {
|
||||
hs.to_vec2::<f32>()?,
|
||||
&[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
|
||||
);
|
||||
// Prior to https://github.com/huggingface/candle/pull/1022
|
||||
// There would be a bug where the last values in the result tensor would be set to 0.
|
||||
let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
|
||||
let hs = t.index_select(&ids, 0)?;
|
||||
assert_eq!(
|
||||
hs.to_vec2::<f32>()?,
|
||||
&[
|
||||
[0.0, 1.0, 2.0],
|
||||
[6.0, 7.0, 8.0],
|
||||
[3.0, 4.0, 5.0],
|
||||
[0.0, 1.0, 2.0],
|
||||
[6.0, 7.0, 8.0],
|
||||
[3.0, 4.0, 5.0],
|
||||
]
|
||||
);
|
||||
|
||||
// Test when selecting dim > 0 with ids size different from elem count of
|
||||
// target dim in source/input.
|
||||
let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
|
||||
let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
|
||||
assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
|
||||
let hs = t.index_select(&ids, 1)?;
|
||||
assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -647,6 +744,48 @@ fn index_add(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn slice_scatter(device: &Device) -> Result<()> {
|
||||
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
|
||||
assert_eq!(
|
||||
t.to_vec2::<f32>()?,
|
||||
&[
|
||||
[0.0, 1.0, 2.0],
|
||||
[3.0, 4.0, 5.0],
|
||||
[6.0, 7.0, 8.0],
|
||||
[9.0, 10.0, 11.0]
|
||||
]
|
||||
);
|
||||
let src = Tensor::arange(100f32, 106f32, device)?.reshape((2, 3))?;
|
||||
assert_eq!(
|
||||
t.slice_scatter0(&src, 0)?.to_vec2::<f32>()?,
|
||||
&[
|
||||
[100.0, 101.0, 102.0],
|
||||
[103.0, 104.0, 105.0],
|
||||
[6.0, 7.0, 8.0],
|
||||
[9.0, 10.0, 11.0]
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
t.slice_scatter0(&src, 1)?.to_vec2::<f32>()?,
|
||||
&[
|
||||
[0.0, 1.0, 2.0],
|
||||
[100.0, 101.0, 102.0],
|
||||
[103.0, 104.0, 105.0],
|
||||
[9.0, 10.0, 11.0]
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
t.slice_scatter0(&src, 2)?.to_vec2::<f32>()?,
|
||||
&[
|
||||
[0.0, 1.0, 2.0],
|
||||
[3.0, 4.0, 5.0],
|
||||
[100.0, 101.0, 102.0],
|
||||
[103.0, 104.0, 105.0],
|
||||
]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn scatter_add(device: &Device) -> Result<()> {
|
||||
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
|
||||
assert_eq!(
|
||||
@ -897,6 +1036,7 @@ fn randn(device: &Device) -> Result<()> {
|
||||
}
|
||||
|
||||
test_device!(zeros, zeros_cpu, zeros_gpu);
|
||||
test_device!(ones, ones_cpu, ones_gpu);
|
||||
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
|
||||
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
|
||||
test_device!(narrow, narrow_cpu, narrow_gpu);
|
||||
@ -908,6 +1048,7 @@ test_device!(max, max_cpu, max_gpu);
|
||||
test_device!(argmax, argmax_cpu, argmax_gpu);
|
||||
test_device!(argmin, argmin_cpu, argmin_gpu);
|
||||
test_device!(transpose, transpose_cpu, transpose_gpu);
|
||||
test_device!(unary_op, unary_op_cpu, unary_op_gpu);
|
||||
test_device!(binary_op, binary_op_cpu, binary_op_gpu);
|
||||
test_device!(embeddings, embeddings_cpu, embeddings_gpu);
|
||||
test_device!(cmp, cmp_cpu, cmp_gpu);
|
||||
@ -918,6 +1059,7 @@ test_device!(index_select, index_select_cpu, index_select_gpu);
|
||||
test_device!(index_add, index_add_cpu, index_add_gpu);
|
||||
test_device!(gather, gather_cpu, gather_gpu);
|
||||
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
|
||||
test_device!(slice_scatter, slice_scatter_cpu, slice_scatter_gpu);
|
||||
test_device!(randn, randn_cpu, randn_gpu);
|
||||
test_device!(clamp, clamp_cpu, clamp_gpu);
|
||||
|
||||
@ -931,3 +1073,19 @@ fn randn_hasneg() -> Result<()> {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pad_with_same() -> Result<()> {
|
||||
let t = Tensor::arange(1f32, 5f32, &Device::Cpu)?.reshape((2, 2))?;
|
||||
let t0 = t.pad_with_same(0, 1, 2)?;
|
||||
assert_eq!(
|
||||
t0.to_vec2::<f32>()?,
|
||||
[[1.0, 2.0], [1.0, 2.0], [3.0, 4.0], [3.0, 4.0], [3.0, 4.0]]
|
||||
);
|
||||
let t1 = t.pad_with_same(1, 1, 2)?;
|
||||
assert_eq!(
|
||||
t1.to_vec2::<f32>()?,
|
||||
[[1.0, 1.0, 2.0, 2.0, 2.0], [3.0, 3.0, 4.0, 4.0, 4.0]]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
BIN
candle-core/tests/test.npy
Normal file
BIN
candle-core/tests/test.npy
Normal file
Binary file not shown.
BIN
candle-core/tests/test.npz
Normal file
BIN
candle-core/tests/test.npz
Normal file
Binary file not shown.
@ -11,8 +11,8 @@ readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
byteorder = { workspace = true }
|
||||
candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.2.3" }
|
||||
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.0" }
|
||||
hf-hub = { workspace = true}
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
memmap2 = { workspace = true }
|
||||
|
@ -11,20 +11,22 @@ readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
|
||||
candle-datasets = { path = "../candle-datasets", version = "0.2.3" }
|
||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.3", optional = true }
|
||||
candle-nn = { path = "../candle-nn", version = "0.2.3" }
|
||||
candle-transformers = { path = "../candle-transformers", version = "0.2.3" }
|
||||
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
||||
candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.0" }
|
||||
candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
|
||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
|
||||
cudarc = { workspace = true, optional = true }
|
||||
half = { workspace = true, optional = true }
|
||||
image = { workspace = true }
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
num-traits = { workspace = true }
|
||||
pyo3 = { version = "0.19.0", features = ["auto-initialize"], optional = true }
|
||||
rayon = { workspace = true }
|
||||
safetensors = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
tokenizers = { workspace = true, features = ["onig"] }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
@ -35,7 +37,6 @@ imageproc = { workspace = true }
|
||||
memmap2 = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
rusttype = { workspace = true }
|
||||
tokenizers = { workspace = true, features = ["onig"] }
|
||||
tracing = { workspace = true }
|
||||
tracing-chrome = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
@ -51,10 +52,14 @@ default = []
|
||||
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
|
||||
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
|
||||
cudnn = ["candle/cudnn"]
|
||||
flash-attn = ["cuda", "dep:candle-flash-attn"]
|
||||
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
|
||||
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
|
||||
nccl = ["cuda", "cudarc/nccl", "dep:half"]
|
||||
|
||||
[[example]]
|
||||
name = "llama_multiprocess"
|
||||
required-features = ["cuda", "nccl", "flash-attn"]
|
||||
|
||||
[[example]]
|
||||
name = "reinforcement-learning"
|
||||
required-features = ["pyo3"]
|
||||
|
@ -5,11 +5,11 @@ extern crate intel_mkl_src;
|
||||
extern crate accelerate_src;
|
||||
use candle_transformers::models::bert::{BertModel, Config, DTYPE};
|
||||
|
||||
use anyhow::{anyhow, Error as E, Result};
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle::Tensor;
|
||||
use candle_nn::VarBuilder;
|
||||
use clap::Parser;
|
||||
use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::{PaddingParams, Tokenizer};
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
@ -19,10 +19,6 @@ struct Args {
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Run offline (you must have the files already cached)
|
||||
#[arg(long)]
|
||||
offline: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
@ -38,6 +34,10 @@ struct Args {
|
||||
#[arg(long)]
|
||||
prompt: Option<String>,
|
||||
|
||||
/// Use the pytorch weights rather than the safetensors ones
|
||||
#[arg(long)]
|
||||
use_pth: bool,
|
||||
|
||||
/// The number of times to run the prompt.
|
||||
#[arg(long, default_value = "1")]
|
||||
n: usize,
|
||||
@ -60,35 +60,27 @@ impl Args {
|
||||
};
|
||||
|
||||
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
|
||||
let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
|
||||
let cache = Cache::default().repo(repo);
|
||||
(
|
||||
cache
|
||||
.get("config.json")
|
||||
.ok_or(anyhow!("Missing config file in cache"))?,
|
||||
cache
|
||||
.get("tokenizer.json")
|
||||
.ok_or(anyhow!("Missing tokenizer file in cache"))?,
|
||||
cache
|
||||
.get("model.safetensors")
|
||||
.ok_or(anyhow!("Missing weights file in cache"))?,
|
||||
)
|
||||
} else {
|
||||
let (config_filename, tokenizer_filename, weights_filename) = {
|
||||
let api = Api::new()?;
|
||||
let api = api.repo(repo);
|
||||
(
|
||||
api.get("config.json")?,
|
||||
api.get("tokenizer.json")?,
|
||||
api.get("model.safetensors")?,
|
||||
)
|
||||
let config = api.get("config.json")?;
|
||||
let tokenizer = api.get("tokenizer.json")?;
|
||||
let weights = if self.use_pth {
|
||||
api.get("pytorch_model.bin")?
|
||||
} else {
|
||||
api.get("model.safetensors")?
|
||||
};
|
||||
(config, tokenizer, weights)
|
||||
};
|
||||
let config = std::fs::read_to_string(config_filename)?;
|
||||
let config: Config = serde_json::from_str(&config)?;
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
|
||||
let vb = if self.use_pth {
|
||||
VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
|
||||
} else {
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
|
||||
};
|
||||
let model = BertModel::load(vb, &config)?;
|
||||
Ok((model, tokenizer))
|
||||
}
|
||||
|
@ -138,18 +138,9 @@ fn main() -> Result<()> {
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let weights = filenames
|
||||
.iter()
|
||||
.map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? }))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let weights = weights
|
||||
.iter()
|
||||
.map(|f| Ok(f.deserialize()?))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let vb = VarBuilder::from_safetensors(weights, DType::F32, &device);
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
||||
let config = Config::starcoder_1b();
|
||||
let model = GPTBigCode::load(vb, config)?;
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
19
candle-examples/examples/blip/README.md
Normal file
19
candle-examples/examples/blip/README.md
Normal file
@ -0,0 +1,19 @@
|
||||
# candle-blip
|
||||
|
||||
The
|
||||
[blip-image-captioning](https://huggingface.co/Salesforce/blip-image-captioning-base)
|
||||
model can generate captions for an input image.
|
||||
|
||||
## Running on an example
|
||||
|
||||
```bash
|
||||
cargo run --example blip --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||
```
|
||||
|
||||
```
|
||||
Running on CPU, to run on GPU, build this example with `--features cuda`
|
||||
loaded image Tensor[dims 3, 384, 384; f32]
|
||||
model built
|
||||
several cyclists are riding down a road with cars behind them%
|
||||
```
|
||||

|
154
candle-examples/examples/blip/main.rs
Normal file
154
candle-examples/examples/blip/main.rs
Normal file
@ -0,0 +1,154 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::Error as E;
|
||||
use clap::Parser;
|
||||
|
||||
use candle::{DType, Device, Result, Tensor};
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::blip;
|
||||
use candle_transformers::models::quantized_blip;
|
||||
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
enum Model {
|
||||
M(blip::BlipForConditionalGeneration),
|
||||
Q(quantized_blip::BlipForConditionalGeneration),
|
||||
}
|
||||
|
||||
impl Model {
|
||||
fn text_decoder_forward(&mut self, xs: &Tensor, img_xs: &Tensor) -> Result<Tensor> {
|
||||
match self {
|
||||
Self::M(m) => m.text_decoder().forward(xs, img_xs),
|
||||
Self::Q(m) => m.text_decoder().forward(xs, img_xs),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Maybe add support for the conditional prompt.
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
image: String,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Use the quantized version of the model.
|
||||
#[arg(long)]
|
||||
quantized: bool,
|
||||
}
|
||||
|
||||
const SEP_TOKEN_ID: u32 = 102;
|
||||
|
||||
/// Loads an image from disk using the image crate, this returns a tensor with shape
|
||||
/// (3, 384, 384). OpenAI normalization is applied.
|
||||
pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
|
||||
let img = image::io::Reader::open(p)?
|
||||
.decode()
|
||||
.map_err(candle::Error::wrap)?
|
||||
.resize_to_fill(384, 384, image::imageops::FilterType::Triangle);
|
||||
let img = img.to_rgb8();
|
||||
let data = img.into_raw();
|
||||
let data = Tensor::from_vec(data, (384, 384, 3), &Device::Cpu)?.permute((2, 0, 1))?;
|
||||
let mean =
|
||||
Tensor::new(&[0.48145466f32, 0.4578275, 0.40821073], &Device::Cpu)?.reshape((3, 1, 1))?;
|
||||
let std = Tensor::new(&[0.26862954f32, 0.261_302_6, 0.275_777_1], &Device::Cpu)?
|
||||
.reshape((3, 1, 1))?;
|
||||
(data.to_dtype(candle::DType::F32)? / 255.)?
|
||||
.broadcast_sub(&mean)?
|
||||
.broadcast_div(&std)
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
let model_file = match args.model {
|
||||
None => {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
if args.quantized {
|
||||
let api = api.model("lmz/candle-blip".to_string());
|
||||
api.get("blip-image-captioning-large-q4k.gguf")?
|
||||
} else {
|
||||
let api = api.repo(hf_hub::Repo::with_revision(
|
||||
"Salesforce/blip-image-captioning-large".to_string(),
|
||||
hf_hub::RepoType::Model,
|
||||
"refs/pr/18".to_string(),
|
||||
));
|
||||
api.get("model.safetensors")?
|
||||
}
|
||||
}
|
||||
Some(model) => model.into(),
|
||||
};
|
||||
let tokenizer = match args.tokenizer {
|
||||
None => {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model("Salesforce/blip-image-captioning-large".to_string());
|
||||
api.get("tokenizer.json")?
|
||||
}
|
||||
Some(file) => file.into(),
|
||||
};
|
||||
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
||||
let mut tokenizer = TokenOutputStream::new(tokenizer);
|
||||
let mut logits_processor =
|
||||
candle_transformers::generation::LogitsProcessor::new(1337, None, None);
|
||||
|
||||
let config = blip::Config::image_captioning_large();
|
||||
|
||||
let (image_embeds, device, mut model) = if args.quantized {
|
||||
let device = Device::Cpu;
|
||||
let image = load_image(args.image)?.to_device(&device)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let vb = quantized_blip::VarBuilder::from_gguf(model_file)?;
|
||||
let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
|
||||
let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
|
||||
(image_embeds, device, Model::Q(model))
|
||||
} else {
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let image = load_image(args.image)?.to_device(&device)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let vb =
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
||||
let model = blip::BlipForConditionalGeneration::new(&config, vb)?;
|
||||
let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
|
||||
(image_embeds, device, Model::M(model))
|
||||
};
|
||||
|
||||
let mut token_ids = vec![30522u32];
|
||||
for index in 0..1000 {
|
||||
let context_size = if index > 0 { 1 } else { token_ids.len() };
|
||||
let start_pos = token_ids.len().saturating_sub(context_size);
|
||||
let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
|
||||
let logits = model.text_decoder_forward(&input_ids, &image_embeds)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
let logits = logits.get(logits.dim(0)? - 1)?;
|
||||
let token = logits_processor.sample(&logits)?;
|
||||
if token == SEP_TOKEN_ID {
|
||||
break;
|
||||
}
|
||||
token_ids.push(token);
|
||||
if let Some(t) = tokenizer.next_token(token)? {
|
||||
use std::io::Write;
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
}
|
||||
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
|
||||
print!("{rest}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
59
candle-examples/examples/convmixer/main.rs
Normal file
59
candle-examples/examples/convmixer/main.rs
Normal file
@ -0,0 +1,59 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use clap::Parser;
|
||||
|
||||
use candle::{DType, IndexOp, D};
|
||||
use candle_nn::{Module, VarBuilder};
|
||||
use candle_transformers::models::convmixer;
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
image: String,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let model_file = match args.model {
|
||||
None => {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model("lmz/candle-convmixer".into());
|
||||
api.get("convmixer_1024_20_ks9_p14.safetensors")?
|
||||
}
|
||||
Some(model) => model.into(),
|
||||
};
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
||||
let model = convmixer::c1024_20(1000, vb)?;
|
||||
println!("model built");
|
||||
let logits = model.forward(&image.unsqueeze(0)?)?;
|
||||
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
|
||||
.i(0)?
|
||||
.to_vec1::<f32>()?;
|
||||
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
|
||||
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
||||
for &(category_idx, pr) in prs.iter().take(5) {
|
||||
println!(
|
||||
"{:24}: {:.2}%",
|
||||
candle_examples::imagenet::CLASSES[category_idx],
|
||||
100. * pr
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -42,9 +42,7 @@ pub fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
Some(model) => model.into(),
|
||||
};
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
||||
let model = dinov2::vit_small(vb)?;
|
||||
println!("model built");
|
||||
let logits = model.forward(&image.unsqueeze(0)?)?;
|
||||
|
@ -68,9 +68,7 @@ pub fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
Some(model) => model.into(),
|
||||
};
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
||||
let cfg = match args.which {
|
||||
Which::B0 => MBConvConfig::b0(),
|
||||
Which::B1 => MBConvConfig::b1(),
|
||||
|
@ -177,21 +177,12 @@ fn main() -> Result<()> {
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let weights = filenames
|
||||
.iter()
|
||||
.map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? }))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let weights = weights
|
||||
.iter()
|
||||
.map(|f| Ok(f.deserialize()?))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
let dtype = if args.use_f32 {
|
||||
DType::F32
|
||||
} else {
|
||||
DType::BF16
|
||||
};
|
||||
let vb = VarBuilder::from_safetensors(weights, dtype, &device);
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
let config = Config::falcon7b();
|
||||
config.validate()?;
|
||||
let model = Falcon::load(vb, config)?;
|
||||
|
@ -172,17 +172,9 @@ fn main() -> Result<()> {
|
||||
}
|
||||
|
||||
println!("building the model");
|
||||
let handles = filenames
|
||||
.iter()
|
||||
.map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f.as_path())? }))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let tensors: Vec<_> = handles
|
||||
.iter()
|
||||
.map(|h| Ok(h.deserialize()?))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
|
||||
|
||||
let vb = VarBuilder::from_safetensors(tensors, dtype, &device);
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
(Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
|
||||
}
|
||||
};
|
||||
|
@ -89,6 +89,10 @@ struct Args {
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
@ -201,16 +205,9 @@ fn main() -> Result<()> {
|
||||
let cache = model::Cache::new(dtype, &config, &device)?;
|
||||
|
||||
println!("building the model");
|
||||
let handles = filenames
|
||||
.iter()
|
||||
.map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f.as_path())? }))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let tensors: Vec<_> = handles
|
||||
.iter()
|
||||
.map(|h| Ok(h.deserialize()?))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
let vb = candle_nn::var_builder::ShardedSafeTensors::var_builder(tensors, dtype, &device);
|
||||
let vb = unsafe {
|
||||
candle_nn::var_builder::ShardedSafeTensors::var_builder(&filenames, dtype, &device)?
|
||||
};
|
||||
let llama = Llama::load(vb, &cache, &config, comm)?;
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
@ -222,7 +219,7 @@ fn main() -> Result<()> {
|
||||
.to_vec();
|
||||
|
||||
println!("starting the inference loop");
|
||||
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
|
||||
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
|
||||
let mut new_tokens = vec![];
|
||||
let start_gen = std::time::Instant::now();
|
||||
let mut index_pos = 0;
|
||||
|
90
candle-examples/examples/mistral/README.md
Normal file
90
candle-examples/examples/mistral/README.md
Normal file
@ -0,0 +1,90 @@
|
||||
# candle-mistral: 7b LLM with Apache 2.0 licensed weights
|
||||
|
||||
Mistral-7B-v0.1 is a pretrained generative LLM with 7 billion parameters. It outperforms all the publicly available 13b models
|
||||
as of 2023-09-28. Weights (and the original Python model code) are released under the permissive Apache 2.0 license.
|
||||
|
||||
- [Blog post](https://mistral.ai/news/announcing-mistral-7b/) from Mistral announcing the model release.
|
||||
- [Model card](https://huggingface.co/mistralai/Mistral-7B-v0.1) on the
|
||||
HuggingFace Hub.
|
||||
This example supports the initial model as well as a quantized variant.
|
||||
|
||||
## Running the example
|
||||
|
||||
```bash
|
||||
$ cargo run --example mistral --release --features cuda -- --prompt 'Write helloworld code in Rust' --sample-len 150
|
||||
|
||||
Generated text:
|
||||
Write helloworld code in Rust
|
||||
=============================
|
||||
|
||||
This is a simple example of how to write "Hello, world!" program in Rust.
|
||||
|
||||
## Compile and run
|
||||
|
||||
``bash
|
||||
$ cargo build --release
|
||||
Compiling hello-world v0.1.0 (/home/user/rust/hello-world)
|
||||
Finished release [optimized] target(s) in 0.26s
|
||||
$ ./target/release/hello-world
|
||||
Hello, world!
|
||||
``
|
||||
|
||||
## Source code
|
||||
|
||||
``rust
|
||||
fn main() {
|
||||
println!("Hello, world!");
|
||||
}
|
||||
``
|
||||
|
||||
## License
|
||||
|
||||
This example is released under the terms
|
||||
```
|
||||
|
||||
## Running the quantized version of the model
|
||||
|
||||
```bash
|
||||
$ cargo run --example mistral --features accelerate --release -- \
|
||||
$ --prompt "Here is a sample quick sort implementation in rust " --quantized -n 400
|
||||
avx: false, neon: true, simd128: false, f16c: false
|
||||
temp: 0.00 repeat-penalty: 1.10 repeat-last-n: 64
|
||||
retrieved the files in 562.292µs
|
||||
loaded the model in 1.100323667s
|
||||
Here is a sample quick sort implementation in rust
|
||||
|
||||
``rust
|
||||
fn quick_sort(arr: &mut [i32]) {
|
||||
if arr.len() <= 1 {
|
||||
return;
|
||||
}
|
||||
|
||||
let pivot = arr[0];
|
||||
let mut left = vec![];
|
||||
let mut right = vec![];
|
||||
|
||||
for i in 1..arr.len() {
|
||||
if arr[i] < pivot {
|
||||
left.push(arr[i]);
|
||||
} else {
|
||||
right.push(arr[i]);
|
||||
}
|
||||
}
|
||||
|
||||
quick_sort(&mut left);
|
||||
quick_sort(&mut right);
|
||||
|
||||
let mut i = 0;
|
||||
for _ in &left {
|
||||
arr[i] = left.pop().unwrap();
|
||||
i += 1;
|
||||
}
|
||||
|
||||
for _ in &right {
|
||||
arr[i] = right.pop().unwrap();
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
``
|
||||
226 tokens generated (10.91 token/s)
|
||||
```
|
271
candle-examples/examples/mistral/main.rs
Normal file
271
candle-examples/examples/mistral/main.rs
Normal file
@ -0,0 +1,271 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
|
||||
use candle_transformers::models::mistral::{Config, Model as Mistral};
|
||||
use candle_transformers::models::quantized_mistral::Model as QMistral;
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
enum Model {
|
||||
Mistral(Mistral),
|
||||
Quantized(QMistral),
|
||||
}
|
||||
|
||||
struct TextGeneration {
|
||||
model: Model,
|
||||
device: Device,
|
||||
tokenizer: TokenOutputStream,
|
||||
logits_processor: LogitsProcessor,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
impl TextGeneration {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
model: Model,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
device: &Device,
|
||||
) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
Self {
|
||||
model,
|
||||
tokenizer: TokenOutputStream::new(tokenizer),
|
||||
logits_processor,
|
||||
repeat_penalty,
|
||||
repeat_last_n,
|
||||
device: device.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||
use std::io::Write;
|
||||
self.tokenizer.clear();
|
||||
let mut tokens = self
|
||||
.tokenizer
|
||||
.tokenizer()
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
for &t in tokens.iter() {
|
||||
if let Some(t) = self.tokenizer.next_token(t)? {
|
||||
print!("{t}")
|
||||
}
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
|
||||
let mut generated_tokens = 0usize;
|
||||
let eos_token = match self.tokenizer.get_token("</s>") {
|
||||
Some(token) => token,
|
||||
None => anyhow::bail!("cannot find the </s> token"),
|
||||
};
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..sample_len {
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let start_pos = tokens.len().saturating_sub(context_size);
|
||||
let ctxt = &tokens[start_pos..];
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = match &mut self.model {
|
||||
Model::Mistral(m) => m.forward(&input, start_pos)?,
|
||||
Model::Quantized(m) => m.forward(&input, start_pos)?,
|
||||
};
|
||||
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
self.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == eos_token {
|
||||
break;
|
||||
}
|
||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
||||
print!("{rest}");
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
println!(
|
||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||
generated_tokens as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
#[arg(long)]
|
||||
use_flash_attn: bool,
|
||||
|
||||
#[arg(long)]
|
||||
prompt: String,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(long, short = 'n', default_value_t = 100)]
|
||||
sample_len: usize,
|
||||
|
||||
#[arg(long, default_value = "lmz/candle-mistral")]
|
||||
model_id: String,
|
||||
|
||||
#[arg(long, default_value = "main")]
|
||||
revision: String,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
weight_files: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
quantized: bool,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
println!(
|
||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||
args.temperature.unwrap_or(0.),
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new()?;
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
args.model_id,
|
||||
RepoType::Model,
|
||||
args.revision,
|
||||
));
|
||||
let tokenizer_filename = match args.tokenizer_file {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => repo.get("tokenizer.json")?,
|
||||
};
|
||||
let filenames = match args.weight_files {
|
||||
Some(files) => files
|
||||
.split(',')
|
||||
.map(std::path::PathBuf::from)
|
||||
.collect::<Vec<_>>(),
|
||||
None => {
|
||||
if args.quantized {
|
||||
vec![repo.get("model-q4k.gguf")?]
|
||||
} else {
|
||||
vec![
|
||||
repo.get("pytorch_model-00001-of-00002.safetensors")?,
|
||||
repo.get("pytorch_model-00002-of-00002.safetensors")?,
|
||||
]
|
||||
}
|
||||
}
|
||||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let config = Config::config_7b_v0_1(args.use_flash_attn);
|
||||
let (model, device) = if args.quantized {
|
||||
let filename = &filenames[0];
|
||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
|
||||
let model = QMistral::new(&config, vb)?;
|
||||
(Model::Quantized(model), Device::Cpu)
|
||||
} else {
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let dtype = if device.is_cuda() {
|
||||
DType::BF16
|
||||
} else {
|
||||
DType::F32
|
||||
};
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
let model = Mistral::new(&config, vb)?;
|
||||
(Model::Mistral(model), device)
|
||||
};
|
||||
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let mut pipeline = TextGeneration::new(
|
||||
model,
|
||||
tokenizer,
|
||||
args.seed,
|
||||
args.temperature,
|
||||
args.top_p,
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n,
|
||||
&device,
|
||||
);
|
||||
pipeline.run(&args.prompt, args.sample_len)?;
|
||||
Ok(())
|
||||
}
|
@ -1,6 +1,6 @@
|
||||
use crate::nn::conv1d_weight_norm;
|
||||
use candle::{DType, IndexOp, Result, Tensor};
|
||||
use candle_nn::{conv1d, Conv1d, Conv1dConfig, Module, VarBuilder};
|
||||
use candle::{DType, IndexOp, Module, Result, Tensor};
|
||||
use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder};
|
||||
|
||||
// Encodec Model
|
||||
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py
|
||||
@ -199,25 +199,34 @@ impl EncodecResidualVectorQuantizer {
|
||||
// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L226
|
||||
#[derive(Debug)]
|
||||
struct EncodecLSTM {
|
||||
layers: Vec<(Tensor, Tensor, Tensor, Tensor)>,
|
||||
layers: Vec<candle_nn::LSTM>,
|
||||
}
|
||||
|
||||
impl EncodecLSTM {
|
||||
fn load(dim: usize, vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let vb = &vb.pp("lstm");
|
||||
let mut layers = vec![];
|
||||
for i in 0..cfg.num_lstm_layers {
|
||||
let w_hh = vb.get((4 * dim, dim), &format!("weight_hh_l{i}"))?;
|
||||
let w_ih = vb.get((4 * dim, dim), &format!("weight_ih_l{i}"))?;
|
||||
let b_hh = vb.get(4 * dim, &format!("bias_hh_l{i}"))?;
|
||||
let b_ih = vb.get(4 * dim, &format!("bias_ih_l{i}"))?;
|
||||
layers.push((w_hh, w_ih, b_hh, b_ih))
|
||||
for layer_idx in 0..cfg.num_lstm_layers {
|
||||
let config = candle_nn::LSTMConfig {
|
||||
layer_idx,
|
||||
..Default::default()
|
||||
};
|
||||
let lstm = candle_nn::lstm(dim, dim, config, vb.clone())?;
|
||||
layers.push(lstm)
|
||||
}
|
||||
Ok(Self { layers })
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
impl Module for EncodecLSTM {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
use candle_nn::RNN;
|
||||
let mut xs = xs.clone();
|
||||
for layer in self.layers.iter() {
|
||||
let states = layer.seq(&xs)?;
|
||||
xs = layer.states_to_tensor(&states)?;
|
||||
}
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
@ -247,7 +256,9 @@ impl EncodecConvTranspose1d {
|
||||
bias,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for EncodecConvTranspose1d {
|
||||
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
}
|
||||
@ -299,7 +310,9 @@ impl EncodecConv1d {
|
||||
conv,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for EncodecConv1d {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
// TODO: padding, depending on causal.
|
||||
let xs = self.conv.forward(xs)?;
|
||||
@ -340,7 +353,9 @@ impl EncodecResnetBlock {
|
||||
shortcut,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for EncodecResnetBlock {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let residual = xs.clone();
|
||||
let xs = xs.elu(1.)?;
|
||||
@ -439,8 +454,17 @@ impl EncodecEncoder {
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = xs.apply(&self.init_conv)?;
|
||||
for (resnets, conv) in self.sampling_layers.iter() {
|
||||
for resnet in resnets.iter() {
|
||||
xs = xs.apply(resnet)?;
|
||||
}
|
||||
xs = xs.elu(1.0)?.apply(conv)?;
|
||||
}
|
||||
xs.apply(&self.final_lstm)?
|
||||
.elu(1.0)?
|
||||
.apply(&self.final_conv)
|
||||
}
|
||||
}
|
||||
|
||||
@ -507,8 +531,15 @@ impl EncodecDecoder {
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = xs.apply(&self.init_conv)?.apply(&self.init_lstm)?;
|
||||
for (conv, resnets) in self.sampling_layers.iter() {
|
||||
xs = xs.elu(1.)?.apply(conv)?;
|
||||
for resnet in resnets.iter() {
|
||||
xs = xs.apply(resnet)?
|
||||
}
|
||||
}
|
||||
xs.elu(1.)?.apply(&self.final_conv)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -73,9 +73,7 @@ fn main() -> Result<()> {
|
||||
))
|
||||
.get("model.safetensors")?,
|
||||
};
|
||||
let model = unsafe { candle::safetensors::MmapedFile::new(model)? };
|
||||
let model = model.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device);
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DTYPE, &device)? };
|
||||
let config = GenConfig::small();
|
||||
let mut model = MusicgenForConditionalGeneration::load(vb, config)?;
|
||||
|
||||
|
@ -40,7 +40,7 @@ impl Default for Config {
|
||||
num_attention_heads: 16,
|
||||
layerdrop: 0.0,
|
||||
use_cache: true,
|
||||
activation_function: Activation::Gelu, // TODO: Handle old style gelu.
|
||||
activation_function: Activation::Gelu,
|
||||
hidden_size: 1024,
|
||||
dropout: 0.1,
|
||||
attention_dropout: 0.0,
|
||||
@ -66,7 +66,7 @@ impl Config {
|
||||
num_attention_heads: 16,
|
||||
layerdrop: 0.0,
|
||||
use_cache: true,
|
||||
activation_function: Activation::Gelu, // TODO: Handle old style gelu.
|
||||
activation_function: Activation::Gelu,
|
||||
hidden_size: 1024,
|
||||
dropout: 0.1,
|
||||
attention_dropout: 0.0,
|
||||
|
56
candle-examples/examples/phi/README.md
Normal file
56
candle-examples/examples/phi/README.md
Normal file
@ -0,0 +1,56 @@
|
||||
# candle-phi: 1.3b LLM with state of the art performance for <10b models.
|
||||
|
||||
[Phi-1.5](https://huggingface.co/microsoft/phi-1_5) is a language model using
|
||||
only 1.3 billion parameters but with state of the art performance compared to
|
||||
models with up to 10 billion parameters.
|
||||
|
||||
The candle implementation provides both the standard version as well as a
|
||||
quantized variant.
|
||||
|
||||
## Running some example
|
||||
|
||||
```bash
|
||||
$ cargo run --example phi --release -- --prompt "def print_prime(n): "
|
||||
|
||||
def print_prime(n):
|
||||
print("Printing prime numbers")
|
||||
for i in range(2, n+1):
|
||||
if is_prime(i):
|
||||
print(i)
|
||||
|
||||
def is_prime(n):
|
||||
if n <= 1:
|
||||
return False
|
||||
for i in range(2, int(math.sqrt(n))+1):
|
||||
if n % i == 0:
|
||||
return False
|
||||
return True
|
||||
|
||||
$ cargo run --example phi --release -- \
|
||||
--prompt "Explain how to find the median of an array and write the corresponding python function.\nAnswer:" \
|
||||
--quantized --sample-len 200
|
||||
|
||||
Explain how to find the median of an array and write the corresponding python function.
|
||||
Answer: The median is the middle value in an array. If the array has an even number of elements, the median is the average of the two middle values.
|
||||
|
||||
def median(arr):
|
||||
arr.sort()
|
||||
n = len(arr)
|
||||
if n % 2 == 0:
|
||||
return (arr[n//2 - 1] + arr[n//2]) / 2
|
||||
else:
|
||||
return arr[n//2]
|
||||
```
|
||||
|
||||
This also supports the [Puffin Phi v2
|
||||
model](https://huggingface.co/teknium/Puffin-Phi-v2) for human interaction.
|
||||
```
|
||||
$ cargo run --example phi --release -- \
|
||||
--prompt "USER: What would you do on a sunny day in Paris?\nASSISTANT:" \
|
||||
--sample-len 200 --model puffin-phi-v2 --quantized
|
||||
USER: What would you do on a sunny day in Paris?
|
||||
ASSISTANT: On a sunny day in Paris, you could visit the Musée du Louvre to admire the famous
|
||||
painting "Mona Lisa" by Leonardo da Vinci. You might also want to stroll along the Champs-Élysées
|
||||
and enjoy the beautiful architecture of the buildings around you. Don't forget to stop by a café
|
||||
for a cup of coffee and to soak up the sun!"
|
||||
```
|
305
candle-examples/examples/phi/main.rs
Normal file
305
candle-examples/examples/phi/main.rs
Normal file
@ -0,0 +1,305 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
|
||||
use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
enum Model {
|
||||
MixFormer(MixFormer),
|
||||
Quantized(QMixFormer),
|
||||
}
|
||||
|
||||
struct TextGeneration {
|
||||
model: Model,
|
||||
device: Device,
|
||||
tokenizer: Tokenizer,
|
||||
logits_processor: LogitsProcessor,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
verbose_prompt: bool,
|
||||
}
|
||||
|
||||
impl TextGeneration {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
model: Model,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
verbose_prompt: bool,
|
||||
device: &Device,
|
||||
) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
Self {
|
||||
model,
|
||||
tokenizer,
|
||||
logits_processor,
|
||||
repeat_penalty,
|
||||
repeat_last_n,
|
||||
verbose_prompt,
|
||||
device: device.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||
use std::io::Write;
|
||||
println!("starting the inference loop");
|
||||
let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
|
||||
if tokens.is_empty() {
|
||||
anyhow::bail!("Empty prompts are not supported in the phi model.")
|
||||
}
|
||||
if self.verbose_prompt {
|
||||
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
|
||||
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
|
||||
println!("{id:7} -> '{token}'");
|
||||
}
|
||||
}
|
||||
let mut tokens = tokens.get_ids().to_vec();
|
||||
let mut generated_tokens = 0usize;
|
||||
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
|
||||
Some(token) => *token,
|
||||
None => anyhow::bail!("cannot find the endoftext token"),
|
||||
};
|
||||
print!("{prompt}");
|
||||
std::io::stdout().flush()?;
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..sample_len {
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = match &mut self.model {
|
||||
Model::MixFormer(m) => m.forward(&input)?,
|
||||
Model::Quantized(m) => m.forward(&input)?,
|
||||
};
|
||||
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
self.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == eos_token {
|
||||
break;
|
||||
}
|
||||
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
|
||||
print!("{token}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
println!(
|
||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||
generated_tokens as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum WhichModel {
|
||||
#[value(name = "1")]
|
||||
V1,
|
||||
#[value(name = "1.5")]
|
||||
V1_5,
|
||||
PuffinPhiV2,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
/// Display the token for the specified prompt.
|
||||
#[arg(long)]
|
||||
verbose_prompt: bool,
|
||||
|
||||
#[arg(long)]
|
||||
prompt: String,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(long, short = 'n', default_value_t = 100)]
|
||||
sample_len: usize,
|
||||
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long, default_value = "1.5")]
|
||||
model: WhichModel,
|
||||
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
weight_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
quantized: bool,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
println!(
|
||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||
args.temperature.unwrap_or(0.),
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new()?;
|
||||
let model_id = match args.model_id {
|
||||
Some(model_id) => model_id.to_string(),
|
||||
None => {
|
||||
if args.quantized {
|
||||
"lmz/candle-quantized-phi".to_string()
|
||||
} else {
|
||||
match args.model {
|
||||
WhichModel::V1 => "microsoft/phi-1".to_string(),
|
||||
WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
|
||||
WhichModel::PuffinPhiV2 => "lmz/candle-quantized-phi".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
let revision = match args.revision {
|
||||
Some(rev) => rev.to_string(),
|
||||
None => {
|
||||
if args.quantized {
|
||||
"main".to_string()
|
||||
} else {
|
||||
match args.model {
|
||||
WhichModel::V1 => "refs/pr/2".to_string(),
|
||||
WhichModel::V1_5 => "refs/pr/18".to_string(),
|
||||
WhichModel::PuffinPhiV2 => "main".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
||||
let tokenizer_filename = match args.tokenizer {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => match args.model {
|
||||
WhichModel::V1 | WhichModel::V1_5 => repo.get("tokenizer.json")?,
|
||||
WhichModel::PuffinPhiV2 => repo.get("tokenizer-puffin-phi-v2.json")?,
|
||||
},
|
||||
};
|
||||
let filename = match args.weight_file {
|
||||
Some(weight_file) => std::path::PathBuf::from(weight_file),
|
||||
None => {
|
||||
if args.quantized {
|
||||
match args.model {
|
||||
WhichModel::V1 => repo.get("model-v1-q4k.gguf")?,
|
||||
WhichModel::V1_5 => repo.get("model-q4k.gguf")?,
|
||||
WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2-q4k.gguf")?,
|
||||
}
|
||||
} else {
|
||||
match args.model {
|
||||
WhichModel::V1 | WhichModel::V1_5 => repo.get("model.safetensors")?,
|
||||
WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2.safetensors")?,
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let config = match args.model {
|
||||
WhichModel::V1 => Config::v1(),
|
||||
WhichModel::V1_5 => Config::v1_5(),
|
||||
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
|
||||
};
|
||||
let (model, device) = if args.quantized {
|
||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
|
||||
let model = QMixFormer::new(&config, vb)?;
|
||||
(Model::Quantized(model), Device::Cpu)
|
||||
} else {
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
|
||||
let model = MixFormer::new(&config, vb)?;
|
||||
(Model::MixFormer(model), device)
|
||||
};
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let mut pipeline = TextGeneration::new(
|
||||
model,
|
||||
tokenizer,
|
||||
args.seed,
|
||||
args.temperature,
|
||||
args.top_p,
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n,
|
||||
args.verbose_prompt,
|
||||
&device,
|
||||
);
|
||||
pipeline.run(&args.prompt, args.sample_len)?;
|
||||
Ok(())
|
||||
}
|
42
candle-examples/examples/quantized-t5/README.md
Normal file
42
candle-examples/examples/quantized-t5/README.md
Normal file
@ -0,0 +1,42 @@
|
||||
# candle-quantized-t5
|
||||
|
||||
This example uses a quantized version of the t5 model.
|
||||
|
||||
```bash
|
||||
$ cargo run --example quantized-t5 --release -- --prompt "translate to German: A beautiful candle."
|
||||
...
|
||||
Eine schöne Kerze.
|
||||
```
|
||||
|
||||
The weight file is automatically retrieved from the hub. It is also possible to
|
||||
generate quantized weight files from the original safetensors file by using the
|
||||
`tensor-tools` command line utility via:
|
||||
|
||||
```bash
|
||||
$ cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
|
||||
```
|
||||
|
||||
To use a different model, specify the `model-id`. For example, you can use
|
||||
quantized [CoEdit models](https://huggingface.co/jbochi/candle-coedit-quantized).
|
||||
|
||||
```bash
|
||||
$ cargo run --example quantized-t5 --release -- \
|
||||
--model-id "jbochi/candle-coedit-quantized" \
|
||||
--prompt "Make this text coherent: Their flight is weak. They run quickly through the tree canopy." \
|
||||
--temperature 0
|
||||
...
|
||||
Although their flight is weak, they run quickly through the tree canopy.
|
||||
|
||||
By default, it will look for `model.gguf` and `config.json`, but you can specify
|
||||
custom local or remote `weight-file` and `config-file`s:
|
||||
|
||||
```bash
|
||||
cargo run --example quantized-t5 --release -- \
|
||||
--model-id "jbochi/candle-coedit-quantized" \
|
||||
--weight-file "model-xl.gguf" \
|
||||
--config-file "config-xl.json" \
|
||||
--prompt "Rewrite to make this easier to understand: Note that a storm surge is what forecasters consider a hurricane's most treacherous aspect." \
|
||||
--temperature 0
|
||||
...
|
||||
Note that a storm surge is what forecasters consider a hurricane's most dangerous part.
|
||||
```
|
228
candle-examples/examples/quantized-t5/main.rs
Normal file
228
candle-examples/examples/quantized-t5/main.rs
Normal file
@ -0,0 +1,228 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
use std::io::Write;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use candle_transformers::models::quantized_t5 as t5;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle::{Device, Tensor};
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use clap::{Parser, ValueEnum};
|
||||
use hf_hub::{api::sync::Api, api::sync::ApiRepo, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
#[derive(Clone, Debug, Copy, ValueEnum)]
|
||||
enum Which {
|
||||
T5Small,
|
||||
FlanT5Small,
|
||||
FlanT5Base,
|
||||
FlanT5Large,
|
||||
FlanT5Xl,
|
||||
FlanT5Xxl,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug, Clone)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
/// The model repository to use on the HuggingFace hub.
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
weight_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
config_file: Option<String>,
|
||||
|
||||
// Enable/disable decoding.
|
||||
#[arg(long, default_value = "false")]
|
||||
disable_cache: bool,
|
||||
|
||||
/// Use this prompt, otherwise compute sentence similarities.
|
||||
#[arg(long)]
|
||||
prompt: String,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long, default_value_t = 0.8)]
|
||||
temperature: f64,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
|
||||
/// The model size to use.
|
||||
#[arg(long, default_value = "t5-small")]
|
||||
which: Which,
|
||||
}
|
||||
|
||||
struct T5ModelBuilder {
|
||||
device: Device,
|
||||
config: t5::Config,
|
||||
weights_filename: PathBuf,
|
||||
}
|
||||
|
||||
impl T5ModelBuilder {
|
||||
pub fn load(args: &Args) -> Result<(Self, Tokenizer)> {
|
||||
let device = Device::Cpu;
|
||||
let default_model = "lmz/candle-quantized-t5".to_string();
|
||||
let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
|
||||
(Some(model_id), Some(revision)) => (model_id, revision),
|
||||
(Some(model_id), None) => (model_id, "main".to_string()),
|
||||
(None, Some(revision)) => (default_model, revision),
|
||||
(None, None) => (default_model, "main".to_string()),
|
||||
};
|
||||
|
||||
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
|
||||
let api = Api::new()?;
|
||||
let api = api.repo(repo);
|
||||
let config_filename = match &args.config_file {
|
||||
Some(filename) => Self::get_local_or_remote_file(filename, &api)?,
|
||||
None => match args.which {
|
||||
Which::T5Small => api.get("config.json")?,
|
||||
Which::FlanT5Small => api.get("config-flan-t5-small.json")?,
|
||||
Which::FlanT5Base => api.get("config-flan-t5-base.json")?,
|
||||
Which::FlanT5Large => api.get("config-flan-t5-large.json")?,
|
||||
Which::FlanT5Xl => api.get("config-flan-t5-xl.json")?,
|
||||
Which::FlanT5Xxl => api.get("config-flan-t5-xxl.json")?,
|
||||
},
|
||||
};
|
||||
let tokenizer_filename = api.get("tokenizer.json")?;
|
||||
let weights_filename = match &args.weight_file {
|
||||
Some(filename) => Self::get_local_or_remote_file(filename, &api)?,
|
||||
None => match args.which {
|
||||
Which::T5Small => api.get("model.gguf")?,
|
||||
Which::FlanT5Small => api.get("model-flan-t5-small.gguf")?,
|
||||
Which::FlanT5Base => api.get("model-flan-t5-base.gguf")?,
|
||||
Which::FlanT5Large => api.get("model-flan-t5-large.gguf")?,
|
||||
Which::FlanT5Xl => api.get("model-flan-t5-xl.gguf")?,
|
||||
Which::FlanT5Xxl => api.get("model-flan-t5-xxl.gguf")?,
|
||||
},
|
||||
};
|
||||
let config = std::fs::read_to_string(config_filename)?;
|
||||
let mut config: t5::Config = serde_json::from_str(&config)?;
|
||||
config.use_cache = !args.disable_cache;
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
Ok((
|
||||
Self {
|
||||
device,
|
||||
config,
|
||||
weights_filename,
|
||||
},
|
||||
tokenizer,
|
||||
))
|
||||
}
|
||||
|
||||
pub fn build_model(&self) -> Result<t5::T5ForConditionalGeneration> {
|
||||
let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?;
|
||||
Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
|
||||
}
|
||||
|
||||
fn get_local_or_remote_file(filename: &str, api: &ApiRepo) -> Result<PathBuf> {
|
||||
let local_filename = std::path::PathBuf::from(filename);
|
||||
if local_filename.exists() {
|
||||
Ok(local_filename)
|
||||
} else {
|
||||
Ok(api.get(filename)?)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?;
|
||||
let device = &builder.device;
|
||||
let tokenizer = tokenizer
|
||||
.with_padding(None)
|
||||
.with_truncation(None)
|
||||
.map_err(E::msg)?;
|
||||
let tokens = tokenizer
|
||||
.encode(args.prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
|
||||
let mut model = builder.build_model()?;
|
||||
let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
|
||||
let temperature = if args.temperature <= 0. {
|
||||
None
|
||||
} else {
|
||||
Some(args.temperature)
|
||||
};
|
||||
let mut logits_processor = LogitsProcessor::new(299792458, temperature, args.top_p);
|
||||
let encoder_output = model.encode(&input_token_ids)?;
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
for index in 0.. {
|
||||
if output_token_ids.len() > 512 {
|
||||
break;
|
||||
}
|
||||
let decoder_token_ids = if index == 0 || !builder.config.use_cache {
|
||||
Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)?
|
||||
} else {
|
||||
let last_token = *output_token_ids.last().unwrap();
|
||||
Tensor::new(&[last_token], device)?.unsqueeze(0)?
|
||||
};
|
||||
let logits = model
|
||||
.decode(&decoder_token_ids, &encoder_output)?
|
||||
.squeeze(0)?;
|
||||
let logits = if args.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = output_token_ids.len().saturating_sub(args.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
args.repeat_penalty,
|
||||
&output_token_ids[start_at..],
|
||||
)?
|
||||
};
|
||||
|
||||
let next_token_id = logits_processor.sample(&logits)?;
|
||||
if next_token_id as usize == builder.config.eos_token_id {
|
||||
break;
|
||||
}
|
||||
output_token_ids.push(next_token_id);
|
||||
if let Some(text) = tokenizer.id_to_token(next_token_id) {
|
||||
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
|
||||
print!("{text}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
}
|
||||
let dt = start.elapsed();
|
||||
println!(
|
||||
"\n{} tokens generated ({:.2} token/s)\n",
|
||||
output_token_ids.len(),
|
||||
output_token_ids.len() as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
@ -44,6 +44,29 @@ enum Which {
|
||||
L13bCode,
|
||||
#[value(name = "32b-code")]
|
||||
L34bCode,
|
||||
#[value(name = "7b-mistral")]
|
||||
Mistral7b,
|
||||
#[value(name = "7b-mistral-instruct")]
|
||||
Mistral7bInstruct,
|
||||
#[value(name = "7b-zephyr")]
|
||||
Zephyr7b,
|
||||
}
|
||||
|
||||
impl Which {
|
||||
fn is_mistral(&self) -> bool {
|
||||
match self {
|
||||
Self::L7b
|
||||
| Self::L13b
|
||||
| Self::L70b
|
||||
| Self::L7bChat
|
||||
| Self::L13bChat
|
||||
| Self::L70bChat
|
||||
| Self::L7bCode
|
||||
| Self::L13bCode
|
||||
| Self::L34bCode => false,
|
||||
Self::Mistral7b | Self::Mistral7bInstruct | Self::Zephyr7b => true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
@ -110,7 +133,12 @@ impl Args {
|
||||
Some(config) => std::path::PathBuf::from(config),
|
||||
None => {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model("hf-internal-testing/llama-tokenizer".to_string());
|
||||
let repo = if self.which.is_mistral() {
|
||||
"mistralai/Mistral-7B-v0.1"
|
||||
} else {
|
||||
"hf-internal-testing/llama-tokenizer"
|
||||
};
|
||||
let api = api.model(repo.to_string());
|
||||
api.get("tokenizer.json")?
|
||||
}
|
||||
};
|
||||
@ -140,6 +168,18 @@ impl Args {
|
||||
Which::L7bCode => ("TheBloke/CodeLlama-7B-GGUF", "codellama-7b.Q8_0.gguf"),
|
||||
Which::L13bCode => ("TheBloke/CodeLlama-13B-GGUF", "codellama-13b.Q8_0.gguf"),
|
||||
Which::L34bCode => ("TheBloke/CodeLlama-34B-GGUF", "codellama-34b.Q8_0.gguf"),
|
||||
Which::Mistral7b => (
|
||||
"TheBloke/Mistral-7B-v0.1-GGUF",
|
||||
"mistral-7b-v0.1.Q4_K_S.gguf",
|
||||
),
|
||||
Which::Mistral7bInstruct => (
|
||||
"TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
|
||||
"mistral-7b-instruct-v0.1.Q4_K_S.gguf",
|
||||
),
|
||||
Which::Zephyr7b => (
|
||||
"TheBloke/zephyr-7B-alpha-GGUF",
|
||||
"zephyr-7b-alpha.Q4_K_M.gguf",
|
||||
),
|
||||
};
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model(repo.to_string());
|
||||
@ -261,7 +301,11 @@ fn main() -> anyhow::Result<()> {
|
||||
| Which::L7bCode
|
||||
| Which::L13bCode
|
||||
| Which::L34bCode => 1,
|
||||
Which::L70b | Which::L70bChat => 8,
|
||||
Which::Mistral7b
|
||||
| Which::Mistral7bInstruct
|
||||
| Which::Zephyr7b
|
||||
| Which::L70b
|
||||
| Which::L70bChat => 8,
|
||||
};
|
||||
ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
|
||||
}
|
||||
@ -291,7 +335,11 @@ fn main() -> anyhow::Result<()> {
|
||||
prompt.pop();
|
||||
}
|
||||
}
|
||||
prompt
|
||||
if args.which.is_mistral() {
|
||||
format!("[INST] {prompt} [/INST]")
|
||||
} else {
|
||||
prompt
|
||||
}
|
||||
}
|
||||
};
|
||||
print!("{}", &prompt_str);
|
||||
@ -327,6 +375,8 @@ fn main() -> anyhow::Result<()> {
|
||||
all_tokens.push(next_token);
|
||||
print_token(next_token, &tokenizer);
|
||||
|
||||
let eos_token = *tokenizer.get_vocab(true).get("</s>").unwrap();
|
||||
|
||||
let start_post_prompt = std::time::Instant::now();
|
||||
for index in 0..to_sample {
|
||||
let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
|
||||
@ -345,6 +395,9 @@ fn main() -> anyhow::Result<()> {
|
||||
next_token = logits_processor.sample(&logits)?;
|
||||
all_tokens.push(next_token);
|
||||
print_token(next_token, &tokenizer);
|
||||
if next_token == eos_token {
|
||||
break;
|
||||
};
|
||||
}
|
||||
let dt = start_post_prompt.elapsed();
|
||||
println!(
|
||||
|
16
candle-examples/examples/reinforcement-learning/README.md
Normal file
16
candle-examples/examples/reinforcement-learning/README.md
Normal file
@ -0,0 +1,16 @@
|
||||
# candle-reinforcement-learning
|
||||
|
||||
Reinforcement Learning examples for candle.
|
||||
|
||||
This has been tested with `gymnasium` version `0.29.1`. You can install the
|
||||
Python package with:
|
||||
```bash
|
||||
pip install "gymnasium[accept-rom-license]"
|
||||
```
|
||||
|
||||
In order to run the example, use the following command. Note the additional
|
||||
`--package` flag to ensure that there is no conflict with the `candle-pyo3`
|
||||
crate.
|
||||
```bash
|
||||
cargo run --example reinforcement-learning --features=pyo3 --package candle-examples
|
||||
```
|
@ -0,0 +1,308 @@
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
from collections import deque
|
||||
from PIL import Image
|
||||
from multiprocessing import Process, Pipe
|
||||
|
||||
# atari_wrappers.py
|
||||
class NoopResetEnv(gym.Wrapper):
|
||||
def __init__(self, env, noop_max=30):
|
||||
"""Sample initial states by taking random number of no-ops on reset.
|
||||
No-op is assumed to be action 0.
|
||||
"""
|
||||
gym.Wrapper.__init__(self, env)
|
||||
self.noop_max = noop_max
|
||||
self.override_num_noops = None
|
||||
assert env.unwrapped.get_action_meanings()[0] == 'NOOP'
|
||||
|
||||
def reset(self):
|
||||
""" Do no-op action for a number of steps in [1, noop_max]."""
|
||||
self.env.reset()
|
||||
if self.override_num_noops is not None:
|
||||
noops = self.override_num_noops
|
||||
else:
|
||||
noops = self.unwrapped.np_random.integers(1, self.noop_max + 1) #pylint: disable=E1101
|
||||
assert noops > 0
|
||||
obs = None
|
||||
for _ in range(noops):
|
||||
obs, _, done, _ = self.env.step(0)
|
||||
if done:
|
||||
obs = self.env.reset()
|
||||
return obs
|
||||
|
||||
class FireResetEnv(gym.Wrapper):
|
||||
def __init__(self, env):
|
||||
"""Take action on reset for environments that are fixed until firing."""
|
||||
gym.Wrapper.__init__(self, env)
|
||||
assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
|
||||
assert len(env.unwrapped.get_action_meanings()) >= 3
|
||||
|
||||
def reset(self):
|
||||
self.env.reset()
|
||||
obs, _, done, _ = self.env.step(1)
|
||||
if done:
|
||||
self.env.reset()
|
||||
obs, _, done, _ = self.env.step(2)
|
||||
if done:
|
||||
self.env.reset()
|
||||
return obs
|
||||
|
||||
class ImageSaver(gym.Wrapper):
|
||||
def __init__(self, env, img_path, rank):
|
||||
gym.Wrapper.__init__(self, env)
|
||||
self._cnt = 0
|
||||
self._img_path = img_path
|
||||
self._rank = rank
|
||||
|
||||
def step(self, action):
|
||||
step_result = self.env.step(action)
|
||||
obs, _, _, _ = step_result
|
||||
img = Image.fromarray(obs, 'RGB')
|
||||
img.save('%s/out%d-%05d.png' % (self._img_path, self._rank, self._cnt))
|
||||
self._cnt += 1
|
||||
return step_result
|
||||
|
||||
class EpisodicLifeEnv(gym.Wrapper):
|
||||
def __init__(self, env):
|
||||
"""Make end-of-life == end-of-episode, but only reset on true game over.
|
||||
Done by DeepMind for the DQN and co. since it helps value estimation.
|
||||
"""
|
||||
gym.Wrapper.__init__(self, env)
|
||||
self.lives = 0
|
||||
self.was_real_done = True
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, done, info = self.env.step(action)
|
||||
self.was_real_done = done
|
||||
# check current lives, make loss of life terminal,
|
||||
# then update lives to handle bonus lives
|
||||
lives = self.env.unwrapped.ale.lives()
|
||||
if lives < self.lives and lives > 0:
|
||||
# for Qbert somtimes we stay in lives == 0 condtion for a few frames
|
||||
# so its important to keep lives > 0, so that we only reset once
|
||||
# the environment advertises done.
|
||||
done = True
|
||||
self.lives = lives
|
||||
return obs, reward, done, info
|
||||
|
||||
def reset(self):
|
||||
"""Reset only when lives are exhausted.
|
||||
This way all states are still reachable even though lives are episodic,
|
||||
and the learner need not know about any of this behind-the-scenes.
|
||||
"""
|
||||
if self.was_real_done:
|
||||
obs = self.env.reset()
|
||||
else:
|
||||
# no-op step to advance from terminal/lost life state
|
||||
obs, _, _, _ = self.env.step(0)
|
||||
self.lives = self.env.unwrapped.ale.lives()
|
||||
return obs
|
||||
|
||||
class MaxAndSkipEnv(gym.Wrapper):
|
||||
def __init__(self, env, skip=4):
|
||||
"""Return only every `skip`-th frame"""
|
||||
gym.Wrapper.__init__(self, env)
|
||||
# most recent raw observations (for max pooling across time steps)
|
||||
self._obs_buffer = deque(maxlen=2)
|
||||
self._skip = skip
|
||||
|
||||
def step(self, action):
|
||||
"""Repeat action, sum reward, and max over last observations."""
|
||||
total_reward = 0.0
|
||||
done = None
|
||||
for _ in range(self._skip):
|
||||
obs, reward, done, info = self.env.step(action)
|
||||
self._obs_buffer.append(obs)
|
||||
total_reward += reward
|
||||
if done:
|
||||
break
|
||||
max_frame = np.max(np.stack(self._obs_buffer), axis=0)
|
||||
|
||||
return max_frame, total_reward, done, info
|
||||
|
||||
def reset(self):
|
||||
"""Clear past frame buffer and init. to first obs. from inner env."""
|
||||
self._obs_buffer.clear()
|
||||
obs = self.env.reset()
|
||||
self._obs_buffer.append(obs)
|
||||
return obs
|
||||
|
||||
class ClipRewardEnv(gym.RewardWrapper):
|
||||
def reward(self, reward):
|
||||
"""Bin reward to {+1, 0, -1} by its sign."""
|
||||
return np.sign(reward)
|
||||
|
||||
class WarpFrame(gym.ObservationWrapper):
|
||||
def __init__(self, env):
|
||||
"""Warp frames to 84x84 as done in the Nature paper and later work."""
|
||||
gym.ObservationWrapper.__init__(self, env)
|
||||
self.res = 84
|
||||
self.observation_space = gym.spaces.Box(low=0, high=255, shape=(self.res, self.res, 1), dtype='uint8')
|
||||
|
||||
def observation(self, obs):
|
||||
frame = np.dot(obs.astype('float32'), np.array([0.299, 0.587, 0.114], 'float32'))
|
||||
frame = np.array(Image.fromarray(frame).resize((self.res, self.res),
|
||||
resample=Image.BILINEAR), dtype=np.uint8)
|
||||
return frame.reshape((self.res, self.res, 1))
|
||||
|
||||
class FrameStack(gym.Wrapper):
|
||||
def __init__(self, env, k):
|
||||
"""Buffer observations and stack across channels (last axis)."""
|
||||
gym.Wrapper.__init__(self, env)
|
||||
self.k = k
|
||||
self.frames = deque([], maxlen=k)
|
||||
shp = env.observation_space.shape
|
||||
assert shp[2] == 1 # can only stack 1-channel frames
|
||||
self.observation_space = gym.spaces.Box(low=0, high=255, shape=(shp[0], shp[1], k), dtype='uint8')
|
||||
|
||||
def reset(self):
|
||||
"""Clear buffer and re-fill by duplicating the first observation."""
|
||||
ob = self.env.reset()
|
||||
for _ in range(self.k): self.frames.append(ob)
|
||||
return self.observation()
|
||||
|
||||
def step(self, action):
|
||||
ob, reward, done, info = self.env.step(action)
|
||||
self.frames.append(ob)
|
||||
return self.observation(), reward, done, info
|
||||
|
||||
def observation(self):
|
||||
assert len(self.frames) == self.k
|
||||
return np.concatenate(self.frames, axis=2)
|
||||
|
||||
def wrap_deepmind(env, episode_life=True, clip_rewards=True):
|
||||
"""Configure environment for DeepMind-style Atari.
|
||||
|
||||
Note: this does not include frame stacking!"""
|
||||
assert 'NoFrameskip' in env.spec.id # required for DeepMind-style skip
|
||||
if episode_life:
|
||||
env = EpisodicLifeEnv(env)
|
||||
env = NoopResetEnv(env, noop_max=30)
|
||||
env = MaxAndSkipEnv(env, skip=4)
|
||||
if 'FIRE' in env.unwrapped.get_action_meanings():
|
||||
env = FireResetEnv(env)
|
||||
env = WarpFrame(env)
|
||||
if clip_rewards:
|
||||
env = ClipRewardEnv(env)
|
||||
return env
|
||||
|
||||
# envs.py
|
||||
def make_env(env_id, img_dir, seed, rank):
|
||||
def _thunk():
|
||||
env = gym.make(env_id)
|
||||
env.reset(seed=(seed + rank))
|
||||
if img_dir is not None:
|
||||
env = ImageSaver(env, img_dir, rank)
|
||||
env = wrap_deepmind(env)
|
||||
env = WrapPyTorch(env)
|
||||
return env
|
||||
|
||||
return _thunk
|
||||
|
||||
class WrapPyTorch(gym.ObservationWrapper):
|
||||
def __init__(self, env=None):
|
||||
super(WrapPyTorch, self).__init__(env)
|
||||
self.observation_space = gym.spaces.Box(0.0, 1.0, [1, 84, 84], dtype='float32')
|
||||
|
||||
def observation(self, observation):
|
||||
return observation.transpose(2, 0, 1)
|
||||
|
||||
# vecenv.py
|
||||
class VecEnv(object):
|
||||
"""
|
||||
Vectorized environment base class
|
||||
"""
|
||||
def step(self, vac):
|
||||
"""
|
||||
Apply sequence of actions to sequence of environments
|
||||
actions -> (observations, rewards, news)
|
||||
|
||||
where 'news' is a boolean vector indicating whether each element is new.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
def reset(self):
|
||||
"""
|
||||
Reset all environments
|
||||
"""
|
||||
raise NotImplementedError
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
# subproc_vec_env.py
|
||||
def worker(remote, env_fn_wrapper):
|
||||
env = env_fn_wrapper.x()
|
||||
while True:
|
||||
cmd, data = remote.recv()
|
||||
if cmd == 'step':
|
||||
ob, reward, done, info = env.step(data)
|
||||
if done:
|
||||
ob = env.reset()
|
||||
remote.send((ob, reward, done, info))
|
||||
elif cmd == 'reset':
|
||||
ob = env.reset()
|
||||
remote.send(ob)
|
||||
elif cmd == 'close':
|
||||
remote.close()
|
||||
break
|
||||
elif cmd == 'get_spaces':
|
||||
remote.send((env.action_space, env.observation_space))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
class CloudpickleWrapper(object):
|
||||
"""
|
||||
Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
|
||||
"""
|
||||
def __init__(self, x):
|
||||
self.x = x
|
||||
def __getstate__(self):
|
||||
import cloudpickle
|
||||
return cloudpickle.dumps(self.x)
|
||||
def __setstate__(self, ob):
|
||||
import pickle
|
||||
self.x = pickle.loads(ob)
|
||||
|
||||
class SubprocVecEnv(VecEnv):
|
||||
def __init__(self, env_fns):
|
||||
"""
|
||||
envs: list of gym environments to run in subprocesses
|
||||
"""
|
||||
nenvs = len(env_fns)
|
||||
self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
|
||||
self.ps = [Process(target=worker, args=(work_remote, CloudpickleWrapper(env_fn)))
|
||||
for (work_remote, env_fn) in zip(self.work_remotes, env_fns)]
|
||||
for p in self.ps:
|
||||
p.start()
|
||||
|
||||
self.remotes[0].send(('get_spaces', None))
|
||||
self.action_space, self.observation_space = self.remotes[0].recv()
|
||||
|
||||
|
||||
def step(self, actions):
|
||||
for remote, action in zip(self.remotes, actions):
|
||||
remote.send(('step', action))
|
||||
results = [remote.recv() for remote in self.remotes]
|
||||
obs, rews, dones, infos = zip(*results)
|
||||
return np.stack(obs), np.stack(rews), np.stack(dones), infos
|
||||
|
||||
def reset(self):
|
||||
for remote in self.remotes:
|
||||
remote.send(('reset', None))
|
||||
return np.stack([remote.recv() for remote in self.remotes])
|
||||
|
||||
def close(self):
|
||||
for remote in self.remotes:
|
||||
remote.send(('close', None))
|
||||
for p in self.ps:
|
||||
p.join()
|
||||
|
||||
@property
|
||||
def num_envs(self):
|
||||
return len(self.remotes)
|
||||
|
||||
# Create the environment.
|
||||
def make(env_name, img_dir, num_processes):
|
||||
envs = SubprocVecEnv([
|
||||
make_env(env_name, img_dir, 1337, i) for i in range(num_processes)
|
||||
])
|
||||
return envs
|
108
candle-examples/examples/reinforcement-learning/gym_env.rs
Normal file
108
candle-examples/examples/reinforcement-learning/gym_env.rs
Normal file
@ -0,0 +1,108 @@
|
||||
#![allow(unused)]
|
||||
//! Wrappers around the Python API of Gymnasium (the new version of OpenAI gym)
|
||||
use candle::{Device, Result, Tensor};
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::PyDict;
|
||||
|
||||
/// The return value for a step.
|
||||
#[derive(Debug)]
|
||||
pub struct Step<A> {
|
||||
pub obs: Tensor,
|
||||
pub action: A,
|
||||
pub reward: f64,
|
||||
pub is_done: bool,
|
||||
}
|
||||
|
||||
impl<A: Copy> Step<A> {
|
||||
/// Returns a copy of this step changing the observation tensor.
|
||||
pub fn copy_with_obs(&self, obs: &Tensor) -> Step<A> {
|
||||
Step {
|
||||
obs: obs.clone(),
|
||||
action: self.action,
|
||||
reward: self.reward,
|
||||
is_done: self.is_done,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// An OpenAI Gym session.
|
||||
pub struct GymEnv {
|
||||
env: PyObject,
|
||||
action_space: usize,
|
||||
observation_space: Vec<usize>,
|
||||
}
|
||||
|
||||
fn w(res: PyErr) -> candle::Error {
|
||||
candle::Error::wrap(res)
|
||||
}
|
||||
|
||||
impl GymEnv {
|
||||
/// Creates a new session of the specified OpenAI Gym environment.
|
||||
pub fn new(name: &str) -> Result<GymEnv> {
|
||||
Python::with_gil(|py| {
|
||||
let gym = py.import("gymnasium")?;
|
||||
let make = gym.getattr("make")?;
|
||||
let env = make.call1((name,))?;
|
||||
let action_space = env.getattr("action_space")?;
|
||||
let action_space = if let Ok(val) = action_space.getattr("n") {
|
||||
val.extract()?
|
||||
} else {
|
||||
let action_space: Vec<usize> = action_space.getattr("shape")?.extract()?;
|
||||
action_space[0]
|
||||
};
|
||||
let observation_space = env.getattr("observation_space")?;
|
||||
let observation_space = observation_space.getattr("shape")?.extract()?;
|
||||
Ok(GymEnv {
|
||||
env: env.into(),
|
||||
action_space,
|
||||
observation_space,
|
||||
})
|
||||
})
|
||||
.map_err(w)
|
||||
}
|
||||
|
||||
/// Resets the environment, returning the observation tensor.
|
||||
pub fn reset(&self, seed: u64) -> Result<Tensor> {
|
||||
let obs: Vec<f32> = Python::with_gil(|py| {
|
||||
let kwargs = PyDict::new(py);
|
||||
kwargs.set_item("seed", seed)?;
|
||||
let obs = self.env.call_method(py, "reset", (), Some(kwargs))?;
|
||||
obs.as_ref(py).get_item(0)?.extract()
|
||||
})
|
||||
.map_err(w)?;
|
||||
Tensor::new(obs, &Device::Cpu)
|
||||
}
|
||||
|
||||
/// Applies an environment step using the specified action.
|
||||
pub fn step<A: pyo3::IntoPy<pyo3::Py<pyo3::PyAny>> + Clone>(
|
||||
&self,
|
||||
action: A,
|
||||
) -> Result<Step<A>> {
|
||||
let (obs, reward, is_done) = Python::with_gil(|py| {
|
||||
let step = self.env.call_method(py, "step", (action.clone(),), None)?;
|
||||
let step = step.as_ref(py);
|
||||
let obs: Vec<f32> = step.get_item(0)?.extract()?;
|
||||
let reward: f64 = step.get_item(1)?.extract()?;
|
||||
let is_done: bool = step.get_item(2)?.extract()?;
|
||||
Ok((obs, reward, is_done))
|
||||
})
|
||||
.map_err(w)?;
|
||||
let obs = Tensor::new(obs, &Device::Cpu)?;
|
||||
Ok(Step {
|
||||
obs,
|
||||
reward,
|
||||
is_done,
|
||||
action,
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the number of allowed actions for this environment.
|
||||
pub fn action_space(&self) -> usize {
|
||||
self.action_space
|
||||
}
|
||||
|
||||
/// Returns the shape of the observation tensors.
|
||||
pub fn observation_space(&self) -> &[usize] {
|
||||
&self.observation_space
|
||||
}
|
||||
}
|
75
candle-examples/examples/reinforcement-learning/main.rs
Normal file
75
candle-examples/examples/reinforcement-learning/main.rs
Normal file
@ -0,0 +1,75 @@
|
||||
#![allow(unused)]
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
mod gym_env;
|
||||
mod vec_gym_env;
|
||||
|
||||
use candle::Result;
|
||||
use clap::Parser;
|
||||
use rand::Rng;
|
||||
|
||||
// The total number of episodes.
|
||||
const MAX_EPISODES: usize = 100;
|
||||
// The maximum length of an episode.
|
||||
const EPISODE_LENGTH: usize = 200;
|
||||
|
||||
#[derive(Parser, Debug, Clone)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let env = gym_env::GymEnv::new("Pendulum-v1")?;
|
||||
println!("action space: {}", env.action_space());
|
||||
println!("observation space: {:?}", env.observation_space());
|
||||
|
||||
let _num_obs = env.observation_space().iter().product::<usize>();
|
||||
let _num_actions = env.action_space();
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
for episode in 0..MAX_EPISODES {
|
||||
let mut obs = env.reset(episode as u64)?;
|
||||
|
||||
let mut total_reward = 0.0;
|
||||
for _ in 0..EPISODE_LENGTH {
|
||||
let actions = rng.gen_range(-2.0..2.0);
|
||||
|
||||
let step = env.step(vec![actions])?;
|
||||
total_reward += step.reward;
|
||||
|
||||
if step.is_done {
|
||||
break;
|
||||
}
|
||||
obs = step.obs;
|
||||
}
|
||||
|
||||
println!("episode {episode} with total reward of {total_reward}");
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -0,0 +1,91 @@
|
||||
#![allow(unused)]
|
||||
//! Vectorized version of the gym environment.
|
||||
use candle::{DType, Device, Result, Tensor};
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::PyDict;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Step {
|
||||
pub obs: Tensor,
|
||||
pub reward: Tensor,
|
||||
pub is_done: Tensor,
|
||||
}
|
||||
|
||||
pub struct VecGymEnv {
|
||||
env: PyObject,
|
||||
action_space: usize,
|
||||
observation_space: Vec<usize>,
|
||||
}
|
||||
|
||||
fn w(res: PyErr) -> candle::Error {
|
||||
candle::Error::wrap(res)
|
||||
}
|
||||
|
||||
impl VecGymEnv {
|
||||
pub fn new(name: &str, img_dir: Option<&str>, nprocesses: usize) -> Result<VecGymEnv> {
|
||||
Python::with_gil(|py| {
|
||||
let sys = py.import("sys")?;
|
||||
let path = sys.getattr("path")?;
|
||||
let _ = path.call_method1(
|
||||
"append",
|
||||
("candle-examples/examples/reinforcement-learning",),
|
||||
)?;
|
||||
let gym = py.import("atari_wrappers")?;
|
||||
let make = gym.getattr("make")?;
|
||||
let env = make.call1((name, img_dir, nprocesses))?;
|
||||
let action_space = env.getattr("action_space")?;
|
||||
let action_space = action_space.getattr("n")?.extract()?;
|
||||
let observation_space = env.getattr("observation_space")?;
|
||||
let observation_space: Vec<usize> = observation_space.getattr("shape")?.extract()?;
|
||||
let observation_space =
|
||||
[vec![nprocesses].as_slice(), observation_space.as_slice()].concat();
|
||||
Ok(VecGymEnv {
|
||||
env: env.into(),
|
||||
action_space,
|
||||
observation_space,
|
||||
})
|
||||
})
|
||||
.map_err(w)
|
||||
}
|
||||
|
||||
pub fn reset(&self) -> Result<Tensor> {
|
||||
let obs = Python::with_gil(|py| {
|
||||
let obs = self.env.call_method0(py, "reset")?;
|
||||
let obs = obs.call_method0(py, "flatten")?;
|
||||
obs.extract::<Vec<f32>>(py)
|
||||
})
|
||||
.map_err(w)?;
|
||||
Tensor::new(obs, &Device::Cpu)?.reshape(self.observation_space.as_slice())
|
||||
}
|
||||
|
||||
pub fn step(&self, action: Vec<usize>) -> Result<Step> {
|
||||
let (obs, reward, is_done) = Python::with_gil(|py| {
|
||||
let step = self.env.call_method(py, "step", (action,), None)?;
|
||||
let step = step.as_ref(py);
|
||||
let obs = step.get_item(0)?.call_method("flatten", (), None)?;
|
||||
let obs_buffer = pyo3::buffer::PyBuffer::get(obs)?;
|
||||
let obs: Vec<u8> = obs_buffer.to_vec(py)?;
|
||||
let reward: Vec<f32> = step.get_item(1)?.extract()?;
|
||||
let is_done: Vec<f32> = step.get_item(2)?.extract()?;
|
||||
Ok((obs, reward, is_done))
|
||||
})
|
||||
.map_err(w)?;
|
||||
let obs = Tensor::from_vec(obs, self.observation_space.as_slice(), &Device::Cpu)?
|
||||
.to_dtype(DType::F32)?;
|
||||
let reward = Tensor::new(reward, &Device::Cpu)?;
|
||||
let is_done = Tensor::new(is_done, &Device::Cpu)?;
|
||||
Ok(Step {
|
||||
obs,
|
||||
reward,
|
||||
is_done,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn action_space(&self) -> usize {
|
||||
self.action_space
|
||||
}
|
||||
|
||||
pub fn observation_space(&self) -> &[usize] {
|
||||
&self.observation_space
|
||||
}
|
||||
}
|
40
candle-examples/examples/replit-code/README.md
Normal file
40
candle-examples/examples/replit-code/README.md
Normal file
@ -0,0 +1,40 @@
|
||||
# candle-replit-code: code completion specialized model.
|
||||
|
||||
[replit-code-v1_5-3b](https://huggingface.co/replit/replit-code-v1_5-3b) is a
|
||||
language model specialized for code completion. This model uses 3.3B parameters
|
||||
in `bfloat16` (so the GPU version will only work on recent nvidia cards).
|
||||
|
||||
## Running some example
|
||||
|
||||
```bash
|
||||
cargo run --example replit-code --release -- --prompt 'def fibonacci(n): '
|
||||
```
|
||||
This produces the following output.
|
||||
|
||||
```
|
||||
def fibonacci(n): # write Fibonacci series up to n
|
||||
"""Print a Fibonacci series up to n."""
|
||||
a, b = 0, 1
|
||||
while a < n:
|
||||
print(a, end=' ')
|
||||
a, b = b, a+b
|
||||
print()
|
||||
|
||||
|
||||
def fibonacci_loop(n): # write Fibonacci series up to n
|
||||
"""Print a Fibonacci series up to n."""
|
||||
result = []
|
||||
a, b = 0, 1
|
||||
while a < n:
|
||||
result.append(a)
|
||||
a, b = b, a+b
|
||||
return result
|
||||
|
||||
|
||||
def fibonacci_generator(n): # write Fibonacci series up to n
|
||||
"""Print a Fibonacci series up to n."""
|
||||
a, b = 0, 1
|
||||
while a < n:
|
||||
yield a
|
||||
a, b = b, a+b
|
||||
```
|
265
candle-examples/examples/replit-code/main.rs
Normal file
265
candle-examples/examples/replit-code/main.rs
Normal file
@ -0,0 +1,265 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
|
||||
use candle_transformers::models::mpt::{Config, Model as M};
|
||||
use candle_transformers::models::quantized_mpt::Model as Q;
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
enum Model {
|
||||
M(M),
|
||||
Q(Q),
|
||||
}
|
||||
|
||||
impl Model {
|
||||
fn forward(&mut self, xs: &Tensor) -> candle::Result<Tensor> {
|
||||
match self {
|
||||
Self::M(model) => model.forward(xs),
|
||||
Self::Q(model) => model.forward(xs),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct TextGeneration {
|
||||
model: Model,
|
||||
device: Device,
|
||||
tokenizer: Tokenizer,
|
||||
logits_processor: LogitsProcessor,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
verbose_prompt: bool,
|
||||
}
|
||||
|
||||
impl TextGeneration {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
model: Model,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
verbose_prompt: bool,
|
||||
device: &Device,
|
||||
) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
Self {
|
||||
model,
|
||||
tokenizer,
|
||||
logits_processor,
|
||||
repeat_penalty,
|
||||
repeat_last_n,
|
||||
verbose_prompt,
|
||||
device: device.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||
use std::io::Write;
|
||||
println!("starting the inference loop");
|
||||
let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
|
||||
if tokens.is_empty() {
|
||||
anyhow::bail!("Empty prompts are not supported in the phi model.")
|
||||
}
|
||||
if self.verbose_prompt {
|
||||
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
|
||||
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
|
||||
println!("{id:7} -> '{token}'");
|
||||
}
|
||||
}
|
||||
let mut tokens = tokens.get_ids().to_vec();
|
||||
let mut generated_tokens = 0usize;
|
||||
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
|
||||
Some(token) => *token,
|
||||
None => anyhow::bail!("cannot find the endoftext token"),
|
||||
};
|
||||
print!("{prompt}");
|
||||
std::io::stdout().flush()?;
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..sample_len {
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = self.model.forward(&input)?;
|
||||
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
self.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == eos_token {
|
||||
break;
|
||||
}
|
||||
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
|
||||
print!("{token}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
println!(
|
||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||
generated_tokens as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
/// Display the token for the specified prompt.
|
||||
#[arg(long)]
|
||||
verbose_prompt: bool,
|
||||
|
||||
#[arg(long)]
|
||||
prompt: String,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(long, short = 'n', default_value_t = 1000)]
|
||||
sample_len: usize,
|
||||
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
quantized: bool,
|
||||
|
||||
#[arg(long)]
|
||||
weight_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer: Option<String>,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
println!(
|
||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||
args.temperature.unwrap_or(0.),
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new()?;
|
||||
let model_id = match args.model_id {
|
||||
Some(model_id) => model_id.to_string(),
|
||||
None => "lmz/candle-replit-code".to_string(),
|
||||
};
|
||||
let revision = match args.revision {
|
||||
Some(rev) => rev.to_string(),
|
||||
None => "main".to_string(),
|
||||
};
|
||||
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
||||
let tokenizer_filename = match args.tokenizer {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => repo.get("tokenizer.json")?,
|
||||
};
|
||||
let filename = match args.weight_file {
|
||||
Some(weight_file) => std::path::PathBuf::from(weight_file),
|
||||
None => {
|
||||
if args.quantized {
|
||||
repo.get("model-replit-code-v1_5-q4k.gguf")?
|
||||
} else {
|
||||
repo.get("model.safetensors")?
|
||||
}
|
||||
}
|
||||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let config = Config::replit_code_v1_5_3b();
|
||||
let (model, device) = if args.quantized {
|
||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
|
||||
let model = Model::Q(Q::new(&config, vb.pp("transformer"))?);
|
||||
(model, Device::Cpu)
|
||||
} else {
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
|
||||
let model = Model::M(M::new(&config, vb.pp("transformer"))?);
|
||||
(model, device)
|
||||
};
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let mut pipeline = TextGeneration::new(
|
||||
model,
|
||||
tokenizer,
|
||||
args.seed,
|
||||
args.temperature,
|
||||
args.top_p,
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n,
|
||||
args.verbose_prompt,
|
||||
&device,
|
||||
);
|
||||
pipeline.run(&args.prompt, args.sample_len)?;
|
||||
Ok(())
|
||||
}
|
19
candle-examples/examples/resnet/README.md
Normal file
19
candle-examples/examples/resnet/README.md
Normal file
@ -0,0 +1,19 @@
|
||||
# candle-resnet
|
||||
|
||||
A candle implementation of inference using a pre-trained [ResNet](https://arxiv.org/abs/1512.03385).
|
||||
This uses a classification head trained on the ImageNet dataset and returns the
|
||||
probabilities for the top-5 classes.
|
||||
|
||||
## Running an example
|
||||
|
||||
```
|
||||
$ cargo run --example resnet --release -- --image tiger.jpg
|
||||
|
||||
loaded image Tensor[dims 3, 224, 224; f32]
|
||||
model built
|
||||
tiger, Panthera tigris : 90.21%
|
||||
tiger cat : 8.93%
|
||||
lion, king of beasts, Panthera leo: 0.35%
|
||||
leopard, Panthera pardus: 0.16%
|
||||
jaguar, panther, Panthera onca, Felis onca: 0.09%
|
||||
```
|
12
candle-examples/examples/resnet/export_models.py
Normal file
12
candle-examples/examples/resnet/export_models.py
Normal file
@ -0,0 +1,12 @@
|
||||
# This script exports pre-trained model weights in the safetensors format.
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision
|
||||
from safetensors import torch as stt
|
||||
|
||||
m = torchvision.models.resnet50(pretrained=True)
|
||||
stt.save_file(m.state_dict(), 'resnet50.safetensors')
|
||||
m = torchvision.models.resnet101(pretrained=True)
|
||||
stt.save_file(m.state_dict(), 'resnet101.safetensors')
|
||||
m = torchvision.models.resnet152(pretrained=True)
|
||||
stt.save_file(m.state_dict(), 'resnet152.safetensors')
|
90
candle-examples/examples/resnet/main.rs
Normal file
90
candle-examples/examples/resnet/main.rs
Normal file
@ -0,0 +1,90 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use candle::{DType, IndexOp, D};
|
||||
use candle_nn::{Module, VarBuilder};
|
||||
use candle_transformers::models::resnet;
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum Which {
|
||||
#[value(name = "18")]
|
||||
Resnet18,
|
||||
#[value(name = "34")]
|
||||
Resnet34,
|
||||
#[value(name = "50")]
|
||||
Resnet50,
|
||||
#[value(name = "101")]
|
||||
Resnet101,
|
||||
#[value(name = "152")]
|
||||
Resnet152,
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
image: String,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Variant of the model to use.
|
||||
#[arg(value_enum, long, default_value_t = Which::Resnet18)]
|
||||
which: Which,
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let model_file = match args.model {
|
||||
None => {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model("lmz/candle-resnet".into());
|
||||
let filename = match args.which {
|
||||
Which::Resnet18 => "resnet18.safetensors",
|
||||
Which::Resnet34 => "resnet34.safetensors",
|
||||
Which::Resnet50 => "resnet50.safetensors",
|
||||
Which::Resnet101 => "resnet101.safetensors",
|
||||
Which::Resnet152 => "resnet152.safetensors",
|
||||
};
|
||||
api.get(filename)?
|
||||
}
|
||||
Some(model) => model.into(),
|
||||
};
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
||||
let class_count = candle_examples::imagenet::CLASS_COUNT as usize;
|
||||
let model = match args.which {
|
||||
Which::Resnet18 => resnet::resnet18(class_count, vb)?,
|
||||
Which::Resnet34 => resnet::resnet34(class_count, vb)?,
|
||||
Which::Resnet50 => resnet::resnet50(class_count, vb)?,
|
||||
Which::Resnet101 => resnet::resnet101(class_count, vb)?,
|
||||
Which::Resnet152 => resnet::resnet152(class_count, vb)?,
|
||||
};
|
||||
println!("model built");
|
||||
let logits = model.forward(&image.unsqueeze(0)?)?;
|
||||
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
|
||||
.i(0)?
|
||||
.to_vec1::<f32>()?;
|
||||
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
|
||||
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
||||
for &(category_idx, pr) in prs.iter().take(5) {
|
||||
println!(
|
||||
"{:24}: {:.2}%",
|
||||
candle_examples::imagenet::CLASSES[category_idx],
|
||||
100. * pr
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -16,25 +16,29 @@ based on [MobileSAM](https://github.com/ChaoningZhang/MobileSAM).
|
||||
cargo run --example segment-anything --release -- \
|
||||
--image candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||
--use-tiny
|
||||
--point-x 0.4
|
||||
--point-y 0.3
|
||||
--point 0.6,0.6 --point 0.6,0.55
|
||||
```
|
||||
|
||||
Running this command generates a `sam_merged.jpg` file containing the original
|
||||
image with a blue overlay of the selected mask. The red dot represents the prompt
|
||||
specified by `--point-x 0.4 --point-y 0.3`, this prompt is assumed to be part
|
||||
image with a blue overlay of the selected mask. The red dots represent the prompt
|
||||
specified by `--point 0.6,0.6 --point 0.6,0.55`, this prompt is assumed to be part
|
||||
of the target mask.
|
||||
|
||||
The values used for `--point-x` and `--point-y` should be between 0 and 1 and
|
||||
are proportional to the image dimension, i.e. use 0.5 for the image center.
|
||||
The values used for `--point` should be a comma delimited pair of float values.
|
||||
They are proportional to the image dimension, i.e. use 0.5 for the image center.
|
||||
|
||||
Original image:
|
||||

|
||||
|
||||

|
||||
Segment results by prompting with a single point `--point 0.6,0.55`:
|
||||

|
||||
|
||||
Segment results by prompting with multiple points `--point 0.6,0.6 --point 0.6,0.55`:
|
||||

|
||||
|
||||
### Command-line flags
|
||||
- `--use-tiny`: use the TinyViT based MobileSAM backbone rather than the default
|
||||
one.
|
||||
- `--point-x`, `--point-y`: specifies the location of the target point.
|
||||
- `--point`: specifies the location of the target points.
|
||||
- `--threshold`: sets the threshold value to be part of the mask, a negative
|
||||
value results in a larger mask and can be specified via `--threshold=-1.2`.
|
||||
|
Binary file not shown.
After Width: | Height: | Size: 158 KiB |
Binary file not shown.
After Width: | Height: | Size: 158 KiB |
@ -27,13 +27,15 @@ struct Args {
|
||||
#[arg(long)]
|
||||
generate_masks: bool,
|
||||
|
||||
/// The target point x coordinate, between 0 and 1 (0.5 is at the middle of the image).
|
||||
#[arg(long, default_value_t = 0.5)]
|
||||
point_x: f64,
|
||||
/// List of x,y coordinates, between 0 and 1 (0.5 is at the middle of the image). These points
|
||||
/// should be part of the generated mask.
|
||||
#[arg(long)]
|
||||
point: Vec<String>,
|
||||
|
||||
/// The target point y coordinate, between 0 and 1 (0.5 is at the middle of the image).
|
||||
#[arg(long, default_value_t = 0.5)]
|
||||
point_y: f64,
|
||||
/// List of x,y coordinates, between 0 and 1 (0.5 is at the middle of the image). These points
|
||||
/// should not be part of the generated mask and should be part of the background instead.
|
||||
#[arg(long)]
|
||||
neg_point: Vec<String>,
|
||||
|
||||
/// The detection threshold for the mask, 0 is the default value, negative values mean a larger
|
||||
/// mask, positive makes the mask more selective.
|
||||
@ -82,9 +84,7 @@ pub fn main() -> anyhow::Result<()> {
|
||||
api.get(filename)?
|
||||
}
|
||||
};
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
|
||||
let sam = if args.use_tiny {
|
||||
sam::Sam::new_tiny(vb)? // tiny vit_t
|
||||
} else {
|
||||
@ -113,15 +113,27 @@ pub fn main() -> anyhow::Result<()> {
|
||||
)?;
|
||||
}
|
||||
} else {
|
||||
let point = Some((args.point_x, args.point_y));
|
||||
let iter_points = args.point.iter().map(|p| (p, true));
|
||||
let iter_neg_points = args.neg_point.iter().map(|p| (p, false));
|
||||
let points = iter_points
|
||||
.chain(iter_neg_points)
|
||||
.map(|(point, b)| {
|
||||
use std::str::FromStr;
|
||||
let xy = point.split(',').collect::<Vec<_>>();
|
||||
if xy.len() != 2 {
|
||||
anyhow::bail!("expected format for points is 0.4,0.2")
|
||||
}
|
||||
Ok((f64::from_str(xy[0])?, f64::from_str(xy[1])?, b))
|
||||
})
|
||||
.collect::<anyhow::Result<Vec<_>>>()?;
|
||||
let start_time = std::time::Instant::now();
|
||||
let (mask, iou_predictions) = sam.forward(&image, point, false)?;
|
||||
let (mask, iou_predictions) = sam.forward(&image, &points, false)?;
|
||||
println!(
|
||||
"mask generated in {:.2}s",
|
||||
start_time.elapsed().as_secs_f32()
|
||||
);
|
||||
println!("mask:\n{mask}");
|
||||
println!("iou_predictions: {iou_predictions:?}");
|
||||
println!("iou_predictions: {iou_predictions}");
|
||||
|
||||
let mask = (mask.ge(args.threshold)? * 255.)?;
|
||||
let (_one, h, w) = mask.dims3()?;
|
||||
@ -153,12 +165,17 @@ pub fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
}
|
||||
}
|
||||
let (x, y) = (
|
||||
(args.point_x * img.width() as f64) as i32,
|
||||
(args.point_y * img.height() as f64) as i32,
|
||||
);
|
||||
imageproc::drawing::draw_filled_circle(&img, (x, y), 3, image::Rgba([255, 0, 0, 200]))
|
||||
.save("sam_merged.jpg")?
|
||||
for (x, y, b) in points {
|
||||
let x = (x * img.width() as f64) as i32;
|
||||
let y = (y * img.height() as f64) as i32;
|
||||
let color = if b {
|
||||
image::Rgba([255, 0, 0, 200])
|
||||
} else {
|
||||
image::Rgba([0, 255, 0, 200])
|
||||
};
|
||||
imageproc::drawing::draw_filled_circle_mut(&mut img, (x, y), 3, color);
|
||||
}
|
||||
img.save("sam_merged.jpg")?
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -50,6 +50,9 @@ cached.
|
||||
Enabling flash-attention requires both a feature flag, `--feature flash-attn`
|
||||
and using the command line flag `--use-flash-attn`.
|
||||
|
||||
Note that flash-attention-v2 is only compatible with Ampere, Ada, or Hopper GPUs
|
||||
(e.g., A100/H100, RTX 3090/4090).
|
||||
|
||||
## Image to Image Pipeline
|
||||
...
|
||||
|
||||
|
@ -97,14 +97,13 @@ struct Args {
|
||||
img2img_strength: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, clap::ValueEnum)]
|
||||
#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)]
|
||||
enum StableDiffusionVersion {
|
||||
V1_5,
|
||||
V2_1,
|
||||
Xl,
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum ModelFile {
|
||||
Tokenizer,
|
||||
@ -204,7 +203,18 @@ impl ModelFile {
|
||||
Self::Clip => (version.repo(), version.clip_file(use_f16)),
|
||||
Self::Clip2 => (version.repo(), version.clip2_file(use_f16)),
|
||||
Self::Unet => (version.repo(), version.unet_file(use_f16)),
|
||||
Self::Vae => (version.repo(), version.vae_file(use_f16)),
|
||||
Self::Vae => {
|
||||
// Override for SDXL when using f16 weights.
|
||||
// See https://github.com/huggingface/candle/issues/1060
|
||||
if version == StableDiffusionVersion::Xl && use_f16 {
|
||||
(
|
||||
"madebyollin/sdxl-vae-fp16-fix",
|
||||
"diffusion_pytorch_model.safetensors",
|
||||
)
|
||||
} else {
|
||||
(version.repo(), version.vae_file(use_f16))
|
||||
}
|
||||
}
|
||||
};
|
||||
let filename = Api::new()?.model(repo.to_string()).get(path)?;
|
||||
Ok(filename)
|
||||
@ -484,9 +494,8 @@ fn run(args: Args) -> Result<()> {
|
||||
num_samples
|
||||
);
|
||||
let image = vae.decode(&(&latents / 0.18215)?)?;
|
||||
// TODO: Add the clamping between 0 and 1.
|
||||
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
|
||||
let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
|
||||
let image = (image.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?.i(0)?;
|
||||
let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
|
||||
candle_examples::save_image(&image, image_filename)?
|
||||
}
|
||||
|
25
candle-examples/examples/stable-lm/README.md
Normal file
25
candle-examples/examples/stable-lm/README.md
Normal file
@ -0,0 +1,25 @@
|
||||
# candle-stable-lm
|
||||
|
||||
StableLM-3B-4E1T is a 3 billion parameter decoder-only language model
|
||||
pre-trained on 1 trillion tokens of diverse English and code datasets for 4
|
||||
epochs. See the [HuggingFace Hub Model
|
||||
Card](https://huggingface.co/stabilityai/stablelm-3b-4e1t).
|
||||
|
||||
Note that this model is gated so you will have to request access on the Hub in
|
||||
order to be able to use it.
|
||||
|
||||
## Running some example
|
||||
|
||||
```bash
|
||||
$ cargo run --example stable-lm --release --features cuda -- --prompt 'What is the most efficient programming language in use?' --sample-len 150
|
||||
avx: true, neon: false, simd128: false, f16c: true
|
||||
temp: 0.00 repeat-penalty: 1.10 repeat-last-n: 64
|
||||
retrieved the files in 126.593µs
|
||||
loaded the model in 3.474148965s
|
||||
What is the most efficient programming language in use?
|
||||
The answer to this question depends on what you mean by "efficient". If you're talking about speed, then C++ and Java are probably your best bets. But if you're talking about ease of development, then Python is probably the way to go.
|
||||
Python is a high-level, interpreted language that is easy to learn and use. It has a large community of developers who are always working on new features and improvements.
|
||||
C++ is a low-level, compiled language that can be used for both desktop applications and web development. It's more difficult to learn than Python but offers greater control over the code.
|
||||
Java is another high-level language that is popular with programmers because it runs on many different platforms (including Android phones
|
||||
150 tokens generated (37.61 token/s)
|
||||
```
|
268
candle-examples/examples/stable-lm/main.rs
Normal file
268
candle-examples/examples/stable-lm/main.rs
Normal file
@ -0,0 +1,268 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
|
||||
use candle_transformers::models::quantized_stable_lm::Model as QStableLM;
|
||||
use candle_transformers::models::stable_lm::{Config, Model as StableLM};
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
enum Model {
|
||||
StableLM(StableLM),
|
||||
Quantized(QStableLM),
|
||||
}
|
||||
|
||||
struct TextGeneration {
|
||||
model: Model,
|
||||
device: Device,
|
||||
tokenizer: TokenOutputStream,
|
||||
logits_processor: LogitsProcessor,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
impl TextGeneration {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
model: Model,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
device: &Device,
|
||||
) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
Self {
|
||||
model,
|
||||
tokenizer: TokenOutputStream::new(tokenizer),
|
||||
logits_processor,
|
||||
repeat_penalty,
|
||||
repeat_last_n,
|
||||
device: device.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||
use std::io::Write;
|
||||
self.tokenizer.clear();
|
||||
let mut tokens = self
|
||||
.tokenizer
|
||||
.tokenizer()
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
for &t in tokens.iter() {
|
||||
if let Some(t) = self.tokenizer.next_token(t)? {
|
||||
print!("{t}")
|
||||
}
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
|
||||
let mut generated_tokens = 0usize;
|
||||
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
|
||||
Some(token) => token,
|
||||
None => anyhow::bail!("cannot find the <|endoftext|> token"),
|
||||
};
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..sample_len {
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let start_pos = tokens.len().saturating_sub(context_size);
|
||||
let ctxt = &tokens[start_pos..];
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = match &mut self.model {
|
||||
Model::StableLM(m) => m.forward(&input, start_pos)?,
|
||||
Model::Quantized(m) => m.forward(&input, start_pos)?,
|
||||
};
|
||||
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
self.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == eos_token {
|
||||
break;
|
||||
}
|
||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
||||
print!("{rest}");
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
println!(
|
||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||
generated_tokens as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
#[arg(long)]
|
||||
use_flash_attn: bool,
|
||||
|
||||
#[arg(long)]
|
||||
prompt: String,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(long, short = 'n', default_value_t = 100)]
|
||||
sample_len: usize,
|
||||
|
||||
#[arg(long, default_value = "lmz/candle-stablelm-3b-4e1t")]
|
||||
model_id: String,
|
||||
|
||||
#[arg(long, default_value = "main")]
|
||||
revision: String,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
weight_files: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
quantized: bool,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
println!(
|
||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||
args.temperature.unwrap_or(0.),
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new()?;
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
args.model_id,
|
||||
RepoType::Model,
|
||||
args.revision,
|
||||
));
|
||||
let tokenizer_filename = match args.tokenizer_file {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => repo.get("tokenizer.json")?,
|
||||
};
|
||||
let filenames = match args.weight_files {
|
||||
Some(files) => files
|
||||
.split(',')
|
||||
.map(std::path::PathBuf::from)
|
||||
.collect::<Vec<_>>(),
|
||||
None => {
|
||||
if args.quantized {
|
||||
vec![repo.get("model-q4k.gguf")?]
|
||||
} else {
|
||||
vec![repo.get("model.safetensors")?]
|
||||
}
|
||||
}
|
||||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let config = Config::stablelm_3b_4e1t(args.use_flash_attn);
|
||||
let (model, device) = if args.quantized {
|
||||
let filename = &filenames[0];
|
||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
|
||||
let model = QStableLM::new(&config, vb)?;
|
||||
(Model::Quantized(model), Device::Cpu)
|
||||
} else {
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let dtype = if device.is_cuda() {
|
||||
DType::BF16
|
||||
} else {
|
||||
DType::F32
|
||||
};
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
let model = StableLM::new(&config, vb)?;
|
||||
(Model::StableLM(model), device)
|
||||
};
|
||||
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let mut pipeline = TextGeneration::new(
|
||||
model,
|
||||
tokenizer,
|
||||
args.seed,
|
||||
args.temperature,
|
||||
args.top_p,
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n,
|
||||
&device,
|
||||
);
|
||||
pipeline.run(&args.prompt, args.sample_len)?;
|
||||
Ok(())
|
||||
}
|
@ -8,12 +8,12 @@ use std::path::PathBuf;
|
||||
|
||||
use candle_transformers::models::t5;
|
||||
|
||||
use anyhow::{anyhow, Error as E, Result};
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use clap::Parser;
|
||||
use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
const DTYPE: DType = DType::F32;
|
||||
@ -25,10 +25,6 @@ struct Args {
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Run offline (you must have the files already cached)
|
||||
#[arg(long)]
|
||||
offline: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
@ -46,12 +42,16 @@ struct Args {
|
||||
|
||||
// Enable/disable decoding.
|
||||
#[arg(long, default_value = "false")]
|
||||
use_cache: bool,
|
||||
disable_cache: bool,
|
||||
|
||||
/// Use this prompt, otherwise compute sentence similarities.
|
||||
#[arg(long)]
|
||||
prompt: Option<String>,
|
||||
|
||||
/// If set along with --decode, will use this prompt to initialize the decoder.
|
||||
#[arg(long)]
|
||||
decoder_prompt: Option<String>,
|
||||
|
||||
/// L2 normalization for embeddings.
|
||||
#[arg(long, default_value = "true")]
|
||||
normalize_embeddings: bool,
|
||||
@ -76,7 +76,7 @@ struct Args {
|
||||
struct T5ModelBuilder {
|
||||
device: Device,
|
||||
config: t5::Config,
|
||||
weights_filename: PathBuf,
|
||||
weights_filename: Vec<PathBuf>,
|
||||
}
|
||||
|
||||
impl T5ModelBuilder {
|
||||
@ -91,32 +91,25 @@ impl T5ModelBuilder {
|
||||
(None, None) => (default_model, default_revision),
|
||||
};
|
||||
|
||||
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
|
||||
let (config_filename, tokenizer_filename, weights_filename) = if args.offline {
|
||||
let cache = Cache::default().repo(repo);
|
||||
(
|
||||
cache
|
||||
.get("config.json")
|
||||
.ok_or(anyhow!("Missing config file in cache"))?,
|
||||
cache
|
||||
.get("tokenizer.json")
|
||||
.ok_or(anyhow!("Missing tokenizer file in cache"))?,
|
||||
cache
|
||||
.get("model.safetensors")
|
||||
.ok_or(anyhow!("Missing weights file in cache"))?,
|
||||
)
|
||||
let repo = Repo::with_revision(model_id.clone(), RepoType::Model, revision);
|
||||
let api = Api::new()?;
|
||||
let api = api.repo(repo);
|
||||
let config_filename = api.get("config.json")?;
|
||||
let tokenizer_filename = api.get("tokenizer.json")?;
|
||||
let weights_filename = if model_id == "google/flan-t5-xxl" {
|
||||
vec![
|
||||
api.get("model-00001-of-00005.safetensors")?,
|
||||
api.get("model-00002-of-00005.safetensors")?,
|
||||
api.get("model-00003-of-00005.safetensors")?,
|
||||
api.get("model-00004-of-00005.safetensors")?,
|
||||
api.get("model-00005-of-00005.safetensors")?,
|
||||
]
|
||||
} else {
|
||||
let api = Api::new()?;
|
||||
let api = api.repo(repo);
|
||||
(
|
||||
api.get("config.json")?,
|
||||
api.get("tokenizer.json")?,
|
||||
api.get("model.safetensors")?,
|
||||
)
|
||||
vec![api.get("model.safetensors")?]
|
||||
};
|
||||
let config = std::fs::read_to_string(config_filename)?;
|
||||
let mut config: t5::Config = serde_json::from_str(&config)?;
|
||||
config.use_cache = args.use_cache;
|
||||
config.use_cache = !args.disable_cache;
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
Ok((
|
||||
Self {
|
||||
@ -129,24 +122,34 @@ impl T5ModelBuilder {
|
||||
}
|
||||
|
||||
pub fn build_encoder(&self) -> Result<t5::T5EncoderModel> {
|
||||
let weights =
|
||||
unsafe { candle::safetensors::MmapedFile::new(self.weights_filename.clone())? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &self.device);
|
||||
let vb = unsafe {
|
||||
VarBuilder::from_mmaped_safetensors(&self.weights_filename, DTYPE, &self.device)?
|
||||
};
|
||||
Ok(t5::T5EncoderModel::load(vb, &self.config)?)
|
||||
}
|
||||
|
||||
pub fn build_conditional_generation(&self) -> Result<t5::T5ForConditionalGeneration> {
|
||||
let weights =
|
||||
unsafe { candle::safetensors::MmapedFile::new(self.weights_filename.clone())? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &self.device);
|
||||
let vb = unsafe {
|
||||
VarBuilder::from_mmaped_safetensors(&self.weights_filename, DTYPE, &self.device)?
|
||||
};
|
||||
Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?;
|
||||
let device = &builder.device;
|
||||
let tokenizer = tokenizer
|
||||
@ -170,6 +173,16 @@ fn main() -> Result<()> {
|
||||
} else {
|
||||
let mut model = builder.build_conditional_generation()?;
|
||||
let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
|
||||
if let Some(decoder_prompt) = &args.decoder_prompt {
|
||||
print!("{decoder_prompt}");
|
||||
output_token_ids.extend(
|
||||
tokenizer
|
||||
.encode(decoder_prompt.to_string(), false)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec(),
|
||||
);
|
||||
}
|
||||
let temperature = if args.temperature <= 0. {
|
||||
None
|
||||
} else {
|
||||
@ -195,11 +208,11 @@ fn main() -> Result<()> {
|
||||
let logits = if args.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(args.repeat_last_n);
|
||||
let start_at = output_token_ids.len().saturating_sub(args.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
args.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
&output_token_ids[start_at..],
|
||||
)?
|
||||
};
|
||||
|
||||
@ -217,8 +230,8 @@ fn main() -> Result<()> {
|
||||
let dt = start.elapsed();
|
||||
println!(
|
||||
"\n{} tokens generated ({:.2} token/s)\n",
|
||||
tokens.len(),
|
||||
tokens.len() as f64 / dt.as_secs_f64(),
|
||||
output_token_ids.len(),
|
||||
output_token_ids.len() as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
20
candle-examples/examples/vit/README.md
Normal file
20
candle-examples/examples/vit/README.md
Normal file
@ -0,0 +1,20 @@
|
||||
# candle-vit
|
||||
|
||||
Vision Transformer (ViT) model implementation following the lines of
|
||||
[vit-base-patch16-224](https://huggingface.co/google/vit-base-patch16-224)
|
||||
This uses a classification head trained on the ImageNet dataset and returns the
|
||||
probabilities for the top-5 classes.
|
||||
|
||||
## Running an example
|
||||
|
||||
```
|
||||
$ cargo run --example vit --release -- --image tiger.jpg
|
||||
|
||||
loaded image Tensor[dims 3, 224, 224; f32]
|
||||
model built
|
||||
tiger, Panthera tigris : 100.00%
|
||||
tiger cat : 0.00%
|
||||
jaguar, panther, Panthera onca, Felis onca: 0.00%
|
||||
leopard, Panthera pardus: 0.00%
|
||||
lion, king of beasts, Panthera leo: 0.00%
|
||||
```
|
59
candle-examples/examples/vit/main.rs
Normal file
59
candle-examples/examples/vit/main.rs
Normal file
@ -0,0 +1,59 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use clap::Parser;
|
||||
|
||||
use candle::{DType, IndexOp, D};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::vit;
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
image: String,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let model_file = match args.model {
|
||||
None => {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model("google/vit-base-patch16-224".into());
|
||||
api.get("model.safetensors")?
|
||||
}
|
||||
Some(model) => model.into(),
|
||||
};
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
||||
let model = vit::Model::new(&vit::Config::vit_base_patch16_224(), 1000, vb)?;
|
||||
println!("model built");
|
||||
let logits = model.forward(&image.unsqueeze(0)?)?;
|
||||
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
|
||||
.i(0)?
|
||||
.to_vec1::<f32>()?;
|
||||
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
|
||||
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
||||
for &(category_idx, pr) in prs.iter().take(5) {
|
||||
println!(
|
||||
"{:24}: {:.2}%",
|
||||
candle_examples::imagenet::CLASSES[category_idx],
|
||||
100. * pr
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -18,8 +18,48 @@ use rand::{distributions::Distribution, SeedableRng};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
mod multilingual;
|
||||
use candle_transformers::models::whisper::{self as m, audio, model};
|
||||
use model::{Config, Whisper};
|
||||
use candle_transformers::models::whisper::{self as m, audio, Config};
|
||||
|
||||
pub enum Model {
|
||||
Normal(m::model::Whisper),
|
||||
Quantized(m::quantized_model::Whisper),
|
||||
}
|
||||
|
||||
// Maybe we should use some traits rather than doing the dispatch for all these.
|
||||
impl Model {
|
||||
pub fn config(&self) -> &Config {
|
||||
match self {
|
||||
Self::Normal(m) => &m.config,
|
||||
Self::Quantized(m) => &m.config,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn encoder_forward(&mut self, x: &Tensor, flush: bool) -> candle::Result<Tensor> {
|
||||
match self {
|
||||
Self::Normal(m) => m.encoder.forward(x, flush),
|
||||
Self::Quantized(m) => m.encoder.forward(x, flush),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decoder_forward(
|
||||
&mut self,
|
||||
x: &Tensor,
|
||||
xa: &Tensor,
|
||||
flush: bool,
|
||||
) -> candle::Result<Tensor> {
|
||||
match self {
|
||||
Self::Normal(m) => m.decoder.forward(x, xa, flush),
|
||||
Self::Quantized(m) => m.decoder.forward(x, xa, flush),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decoder_final_linear(&self, x: &Tensor) -> candle::Result<Tensor> {
|
||||
match self {
|
||||
Self::Normal(m) => m.decoder.final_linear(x),
|
||||
Self::Quantized(m) => m.decoder.final_linear(x),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug, Clone)]
|
||||
@ -41,7 +81,7 @@ struct Segment {
|
||||
}
|
||||
|
||||
struct Decoder {
|
||||
model: Whisper,
|
||||
model: Model,
|
||||
rng: rand::rngs::StdRng,
|
||||
task: Option<Task>,
|
||||
timestamps: bool,
|
||||
@ -60,7 +100,7 @@ struct Decoder {
|
||||
impl Decoder {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
model: Whisper,
|
||||
model: Model,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
device: &Device,
|
||||
@ -72,9 +112,9 @@ impl Decoder {
|
||||
let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?;
|
||||
// Suppress the notimestamps token when in timestamps mode.
|
||||
// https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L452
|
||||
let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
|
||||
let suppress_tokens: Vec<f32> = (0..model.config().vocab_size as u32)
|
||||
.map(|i| {
|
||||
if model.config.suppress_tokens.contains(&i)
|
||||
if model.config().suppress_tokens.contains(&i)
|
||||
|| timestamps && i == no_timestamps_token
|
||||
{
|
||||
f32::NEG_INFINITY
|
||||
@ -109,11 +149,11 @@ impl Decoder {
|
||||
|
||||
fn decode(&mut self, mel: &Tensor, t: f64) -> Result<DecodingResult> {
|
||||
let model = &mut self.model;
|
||||
let audio_features = model.encoder.forward(mel, true)?;
|
||||
let audio_features = model.encoder_forward(mel, true)?;
|
||||
if self.verbose {
|
||||
println!("audio features: {:?}", audio_features.dims());
|
||||
}
|
||||
let sample_len = model.config.max_target_positions / 2;
|
||||
let sample_len = model.config().max_target_positions / 2;
|
||||
let mut sum_logprob = 0f64;
|
||||
let mut no_speech_prob = f64::NAN;
|
||||
let mut tokens = vec![self.sot_token];
|
||||
@ -133,12 +173,12 @@ impl Decoder {
|
||||
// The model expects a batch dim but this inference loop does not handle
|
||||
// it so we add it at this point.
|
||||
let tokens_t = tokens_t.unsqueeze(0)?;
|
||||
let ys = model.decoder.forward(&tokens_t, &audio_features, i == 0)?;
|
||||
let ys = model.decoder_forward(&tokens_t, &audio_features, i == 0)?;
|
||||
|
||||
// Extract the no speech probability on the first iteration by looking at the first
|
||||
// token logits and the probability for the according token.
|
||||
if i == 0 {
|
||||
let logits = model.decoder.final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
|
||||
let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
|
||||
no_speech_prob = softmax(&logits, 0)?
|
||||
.i(self.no_speech_token as usize)?
|
||||
.to_scalar::<f32>()? as f64;
|
||||
@ -146,8 +186,7 @@ impl Decoder {
|
||||
|
||||
let (_, seq_len, _) = ys.dims3()?;
|
||||
let logits = model
|
||||
.decoder
|
||||
.final_linear(&ys.i((..1, seq_len - 1..))?)?
|
||||
.decoder_final_linear(&ys.i((..1, seq_len - 1..))?)?
|
||||
.i(0)?
|
||||
.i(0)?;
|
||||
// TODO: Besides suppress tokens, we should apply the heuristics from
|
||||
@ -176,7 +215,7 @@ impl Decoder {
|
||||
let prob = softmax(&logits, candle::D::Minus1)?
|
||||
.i(next_token as usize)?
|
||||
.to_scalar::<f32>()? as f64;
|
||||
if next_token == self.eot_token || tokens.len() > model.config.max_target_positions {
|
||||
if next_token == self.eot_token || tokens.len() > model.config().max_target_positions {
|
||||
break;
|
||||
}
|
||||
sum_logprob += prob.ln();
|
||||
@ -333,6 +372,7 @@ impl WhichModel {
|
||||
Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn model_and_revision(&self) -> (&'static str, &'static str) {
|
||||
match self {
|
||||
Self::Tiny => ("openai/whisper-tiny", "main"),
|
||||
@ -382,6 +422,9 @@ struct Args {
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
#[arg(long)]
|
||||
quantized: bool,
|
||||
|
||||
/// Language.
|
||||
#[arg(long)]
|
||||
language: Option<String>,
|
||||
@ -413,10 +456,13 @@ fn main() -> Result<()> {
|
||||
None
|
||||
};
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let (default_model, default_revision) = args.model.model_and_revision();
|
||||
let (default_model, default_revision) = if args.quantized {
|
||||
("lmz/candle-whisper", "main")
|
||||
} else {
|
||||
args.model.model_and_revision()
|
||||
};
|
||||
let default_model = default_model.to_string();
|
||||
let default_revision = default_revision.to_string();
|
||||
let path = std::path::PathBuf::from(default_model.clone());
|
||||
let (model_id, revision) = match (args.model_id, args.revision) {
|
||||
(Some(model_id), Some(revision)) => (model_id, revision),
|
||||
(Some(model_id), None) => (model_id, "main".to_string()),
|
||||
@ -424,20 +470,7 @@ fn main() -> Result<()> {
|
||||
(None, None) => (default_model, default_revision),
|
||||
};
|
||||
|
||||
let (config_filename, tokenizer_filename, weights_filename, input) = if path.exists() {
|
||||
let mut config_filename = path.clone();
|
||||
config_filename.push("config.json");
|
||||
let mut tokenizer_filename = path.clone();
|
||||
tokenizer_filename.push("tokenizer.json");
|
||||
let mut model_filename = path;
|
||||
model_filename.push("model.safetensors");
|
||||
(
|
||||
config_filename,
|
||||
tokenizer_filename,
|
||||
model_filename,
|
||||
std::path::PathBuf::from(args.input.expect("You didn't specify a file to read from yet, are using a local model, please add `--input example.wav` to read some audio file")),
|
||||
)
|
||||
} else {
|
||||
let (config_filename, tokenizer_filename, weights_filename, input) = {
|
||||
let api = Api::new()?;
|
||||
let dataset = api.dataset("Narsil/candle-examples".to_string());
|
||||
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
||||
@ -451,12 +484,25 @@ fn main() -> Result<()> {
|
||||
println!("No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav");
|
||||
dataset.get("samples_jfk.wav")?
|
||||
};
|
||||
(
|
||||
repo.get("config.json")?,
|
||||
repo.get("tokenizer.json")?,
|
||||
repo.get("model.safetensors")?,
|
||||
sample,
|
||||
)
|
||||
let (config, tokenizer, model) = if args.quantized {
|
||||
let ext = match args.model {
|
||||
WhichModel::TinyEn => "tiny-en",
|
||||
WhichModel::Tiny => "tiny",
|
||||
_ => unimplemented!("no quantized support for {:?}", args.model),
|
||||
};
|
||||
(
|
||||
repo.get(&format!("config-{ext}.json"))?,
|
||||
repo.get(&format!("tokenizer-{ext}.json"))?,
|
||||
repo.get(&format!("model-{ext}-q80.gguf"))?,
|
||||
)
|
||||
} else {
|
||||
(
|
||||
repo.get("config.json")?,
|
||||
repo.get("tokenizer.json")?,
|
||||
repo.get("model.safetensors")?,
|
||||
)
|
||||
};
|
||||
(config, tokenizer, model, sample)
|
||||
};
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
@ -481,11 +527,16 @@ fn main() -> Result<()> {
|
||||
let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
|
||||
println!("loaded mel: {:?}", mel.dims());
|
||||
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], m::DTYPE, &device);
|
||||
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
|
||||
let mut model = Whisper::load(&vb, config)?;
|
||||
let mut model = if args.quantized {
|
||||
let vb =
|
||||
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&weights_filename)?;
|
||||
Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
|
||||
} else {
|
||||
let vb =
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], m::DTYPE, &device)? };
|
||||
Model::Normal(m::model::Whisper::load(&vb, config)?)
|
||||
};
|
||||
|
||||
let language_token = match (args.model.is_multilingual(), args.language) {
|
||||
(true, None) => Some(multilingual::detect_language(&mut model, &tokenizer, &mel)?),
|
||||
|
@ -1,4 +1,3 @@
|
||||
use crate::Whisper;
|
||||
use candle::{IndexOp, Result, Tensor, D};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
@ -105,20 +104,28 @@ const LANGUAGES: [(&str, &str); 99] = [
|
||||
];
|
||||
|
||||
/// Returns the token id for the selected language.
|
||||
pub fn detect_language(model: &mut Whisper, tokenizer: &Tokenizer, mel: &Tensor) -> Result<u32> {
|
||||
pub fn detect_language(
|
||||
model: &mut super::Model,
|
||||
tokenizer: &Tokenizer,
|
||||
mel: &Tensor,
|
||||
) -> Result<u32> {
|
||||
let (_bsize, _, seq_len) = mel.dims3()?;
|
||||
let mel = mel.narrow(2, 0, usize::min(seq_len, model.config.max_source_positions))?;
|
||||
let mel = mel.narrow(
|
||||
2,
|
||||
0,
|
||||
usize::min(seq_len, model.config().max_source_positions),
|
||||
)?;
|
||||
let device = mel.device();
|
||||
let language_token_ids = LANGUAGES
|
||||
.iter()
|
||||
.map(|(t, _)| crate::token_id(tokenizer, &format!("<|{t}|>")))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let sot_token = crate::token_id(tokenizer, crate::m::SOT_TOKEN)?;
|
||||
let audio_features = model.encoder.forward(&mel, true)?;
|
||||
let audio_features = model.encoder_forward(&mel, true)?;
|
||||
let tokens = Tensor::new(&[[sot_token]], device)?;
|
||||
let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
|
||||
let ys = model.decoder.forward(&tokens, &audio_features, true)?;
|
||||
let logits = model.decoder.final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
|
||||
let ys = model.decoder_forward(&tokens, &audio_features, true)?;
|
||||
let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
|
||||
let logits = logits.index_select(&language_token_ids, 0)?;
|
||||
let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
|
||||
let probs = probs.to_vec1::<f32>()?;
|
||||
|
27
candle-examples/examples/wuerstchen/README.md
Normal file
27
candle-examples/examples/wuerstchen/README.md
Normal file
@ -0,0 +1,27 @@
|
||||
# candle-wuerstchen: Efficient Pretraining of Text-to-Image Models
|
||||
|
||||

|
||||
|
||||
The `wuerstchen` example is a port of the [diffusers
|
||||
implementation](https://github.com/huggingface/diffusers/tree/19edca82f1ff194c07317369a92b470dbae97f34/src/diffusers/pipelines/wuerstchen) for Würstchen v2.
|
||||
The candle implementation reproduces the same structure/files for models and
|
||||
pipelines. Useful resources:
|
||||
|
||||
- [Official implementation](https://github.com/dome272/Wuerstchen).
|
||||
- [Arxiv paper](https://arxiv.org/abs/2306.00637).
|
||||
- Blog post: [Introducing Würstchen: Fast Diffusion for Image Generation](https://huggingface.co/blog/wuerstchen).
|
||||
|
||||
## Getting the weights
|
||||
|
||||
The weights are automatically downloaded for you from the [HuggingFace
|
||||
Hub](https://huggingface.co/) on the first run. There are various command line
|
||||
flags to use local files instead, run with `--help` to learn about them.
|
||||
|
||||
## Running some example.
|
||||
|
||||
```bash
|
||||
cargo run --example wuerstchen --release --features cuda,cudnn -- \
|
||||
--prompt "Anthropomorphic cat dressed as a fire fighter"
|
||||
```
|
||||
|
||||
The final image is named `sd_final.png` by default.
|
BIN
candle-examples/examples/wuerstchen/assets/cat.jpg
Normal file
BIN
candle-examples/examples/wuerstchen/assets/cat.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 38 KiB |
@ -1,5 +1,3 @@
|
||||
#![allow(unused)]
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
@ -10,11 +8,11 @@ use candle_transformers::models::stable_diffusion;
|
||||
use candle_transformers::models::wuerstchen;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle::{DType, Device, IndexOp, Module, Tensor, D};
|
||||
use candle::{DType, Device, IndexOp, Tensor};
|
||||
use clap::Parser;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
const PRIOR_GUIDANCE_SCALE: f64 = 8.0;
|
||||
const PRIOR_GUIDANCE_SCALE: f64 = 4.0;
|
||||
const RESOLUTION_MULTIPLE: f64 = 42.67;
|
||||
const LATENT_DIM_SCALE: f64 = 10.67;
|
||||
const PRIOR_CIN: usize = 16;
|
||||
@ -41,6 +39,9 @@ struct Args {
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
#[arg(long)]
|
||||
use_flash_attn: bool,
|
||||
|
||||
/// The height in pixels of the generated image.
|
||||
#[arg(long)]
|
||||
height: Option<usize>,
|
||||
@ -77,14 +78,6 @@ struct Args {
|
||||
/// The file specifying the tokenizer to used for prior tokenization.
|
||||
prior_tokenizer: Option<String>,
|
||||
|
||||
/// The size of the sliced attention or 0 for automatic slicing (disabled by default)
|
||||
#[arg(long)]
|
||||
sliced_attention_size: Option<usize>,
|
||||
|
||||
/// The number of steps to run the diffusion for.
|
||||
#[arg(long, default_value_t = 30)]
|
||||
n_steps: usize,
|
||||
|
||||
/// The number of samples to generate.
|
||||
#[arg(long, default_value_t = 1)]
|
||||
num_samples: i64,
|
||||
@ -217,10 +210,8 @@ fn run(args: Args) -> Result<()> {
|
||||
cpu,
|
||||
height,
|
||||
width,
|
||||
n_steps,
|
||||
tokenizer,
|
||||
final_image,
|
||||
sliced_attention_size,
|
||||
num_samples,
|
||||
clip_weights,
|
||||
prior_weights,
|
||||
@ -284,23 +275,27 @@ fn run(args: Args) -> Result<()> {
|
||||
)?;
|
||||
|
||||
let prior = {
|
||||
let prior_weights = ModelFile::Prior.get(prior_weights)?;
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(prior_weights)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
|
||||
let file = ModelFile::Prior.get(prior_weights)?;
|
||||
let vb = unsafe {
|
||||
candle_nn::VarBuilder::from_mmaped_safetensors(&[file], DType::F32, &device)?
|
||||
};
|
||||
wuerstchen::prior::WPrior::new(
|
||||
/* c_in */ PRIOR_CIN, /* c */ 1536, /* c_cond */ 1280,
|
||||
/* c_r */ 64, /* depth */ 32, /* nhead */ 24, vb,
|
||||
/* c_in */ PRIOR_CIN,
|
||||
/* c */ 1536,
|
||||
/* c_cond */ 1280,
|
||||
/* c_r */ 64,
|
||||
/* depth */ 32,
|
||||
/* nhead */ 24,
|
||||
args.use_flash_attn,
|
||||
vb,
|
||||
)?
|
||||
};
|
||||
let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
|
||||
let timesteps = prior_scheduler.timesteps();
|
||||
let timesteps = ×teps[..timesteps.len() - 1];
|
||||
println!("prior denoising");
|
||||
for (index, &t) in timesteps.iter().enumerate() {
|
||||
let start_time = std::time::Instant::now();
|
||||
if index == timesteps.len() - 1 {
|
||||
continue;
|
||||
}
|
||||
let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
|
||||
let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
|
||||
let noise_pred = prior.forward(&latent_model_input, &ratio, &prior_text_embeddings)?;
|
||||
@ -317,10 +312,10 @@ fn run(args: Args) -> Result<()> {
|
||||
|
||||
println!("Building the vqgan.");
|
||||
let vqgan = {
|
||||
let vqgan_weights = ModelFile::VqGan.get(vqgan_weights)?;
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(vqgan_weights)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
|
||||
let file = ModelFile::VqGan.get(vqgan_weights)?;
|
||||
let vb = unsafe {
|
||||
candle_nn::VarBuilder::from_mmaped_safetensors(&[file], DType::F32, &device)?
|
||||
};
|
||||
wuerstchen::paella_vq::PaellaVQ::new(vb)?
|
||||
};
|
||||
|
||||
@ -328,10 +323,10 @@ fn run(args: Args) -> Result<()> {
|
||||
|
||||
// https://huggingface.co/warp-ai/wuerstchen/blob/main/decoder/config.json
|
||||
let decoder = {
|
||||
let decoder_weights = ModelFile::Decoder.get(decoder_weights)?;
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(decoder_weights)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
|
||||
let file = ModelFile::Decoder.get(decoder_weights)?;
|
||||
let vb = unsafe {
|
||||
candle_nn::VarBuilder::from_mmaped_safetensors(&[file], DType::F32, &device)?
|
||||
};
|
||||
wuerstchen::diffnext::WDiffNeXt::new(
|
||||
/* c_in */ DECODER_CIN,
|
||||
/* c_out */ DECODER_CIN,
|
||||
@ -339,6 +334,7 @@ fn run(args: Args) -> Result<()> {
|
||||
/* c_cond */ 1024,
|
||||
/* clip_embd */ 1024,
|
||||
/* patch_size */ 2,
|
||||
args.use_flash_attn,
|
||||
vb,
|
||||
)?
|
||||
};
|
||||
@ -356,13 +352,11 @@ fn run(args: Args) -> Result<()> {
|
||||
)?;
|
||||
|
||||
println!("diffusion process with prior {image_embeddings:?}");
|
||||
let scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
|
||||
let scheduler = wuerstchen::ddpm::DDPMWScheduler::new(12, Default::default())?;
|
||||
let timesteps = scheduler.timesteps();
|
||||
let timesteps = ×teps[..timesteps.len() - 1];
|
||||
for (index, &t) in timesteps.iter().enumerate() {
|
||||
let start_time = std::time::Instant::now();
|
||||
if index == timesteps.len() - 1 {
|
||||
continue;
|
||||
}
|
||||
let ratio = (Tensor::ones(1, DType::F32, &device)? * t)?;
|
||||
let noise_pred =
|
||||
decoder.forward(&latents, &ratio, &image_embeddings, Some(&text_embeddings))?;
|
||||
@ -376,9 +370,9 @@ fn run(args: Args) -> Result<()> {
|
||||
num_samples
|
||||
);
|
||||
let image = vqgan.decode(&(&latents * 0.3764)?)?;
|
||||
// TODO: Add the clamping between 0 and 1.
|
||||
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
|
||||
let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
|
||||
let image = (image.clamp(0f32, 1f32)? * 255.)?
|
||||
.to_dtype(DType::U8)?
|
||||
.i(0)?;
|
||||
let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
|
||||
candle_examples::save_image(&image, image_filename)?
|
||||
}
|
||||
|
@ -108,7 +108,7 @@ pub fn parse_config<T: AsRef<Path>>(path: T) -> Result<Darknet> {
|
||||
}
|
||||
|
||||
enum Bl {
|
||||
Layer(Box<dyn candle_nn::Module + Send>),
|
||||
Layer(Box<dyn candle_nn::Module + Send + Sync>),
|
||||
Route(Vec<usize>),
|
||||
Shortcut(usize),
|
||||
Yolo(usize, Vec<(usize, usize)>),
|
||||
|
@ -146,9 +146,7 @@ pub fn main() -> Result<()> {
|
||||
|
||||
// Create the model and load the weights from the file.
|
||||
let model = args.model()?;
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu);
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &Device::Cpu)? };
|
||||
let config = args.config()?;
|
||||
let darknet = darknet::parse_config(config)?;
|
||||
let model = darknet.build_model(vb)?;
|
||||
|
@ -7,7 +7,7 @@ extern crate accelerate_src;
|
||||
mod model;
|
||||
use model::{Multiples, YoloV8, YoloV8Pose};
|
||||
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor};
|
||||
use candle::{DType, IndexOp, Result, Tensor};
|
||||
use candle_nn::{Module, VarBuilder};
|
||||
use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
|
||||
use clap::{Parser, ValueEnum};
|
||||
@ -253,6 +253,14 @@ enum YoloTask {
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
pub struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
/// Model weights, in safetensors format.
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
@ -363,6 +371,7 @@ impl Task for YoloV8Pose {
|
||||
}
|
||||
|
||||
pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
// Create the model and load the weights from the file.
|
||||
let multiples = match args.which {
|
||||
Which::N => Multiples::n(),
|
||||
@ -372,9 +381,7 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
|
||||
Which::X => Multiples::x(),
|
||||
};
|
||||
let model = args.model()?;
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu);
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
|
||||
let model = T::load(vb, multiples)?;
|
||||
println!("model loaded");
|
||||
for image_name in args.images.iter() {
|
||||
@ -405,7 +412,7 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
|
||||
Tensor::from_vec(
|
||||
data,
|
||||
(img.height() as usize, img.width() as usize, 3),
|
||||
&Device::Cpu,
|
||||
&device,
|
||||
)?
|
||||
.permute((2, 0, 1))?
|
||||
};
|
||||
@ -430,7 +437,19 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
match args.task {
|
||||
YoloTask::Detect => run::<YoloV8>(args)?,
|
||||
YoloTask::Pose => run::<YoloV8Pose>(args)?,
|
||||
|
@ -77,6 +77,7 @@ impl Module for Upsample {
|
||||
struct ConvBlock {
|
||||
conv: Conv2d,
|
||||
bn: BatchNorm,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl ConvBlock {
|
||||
@ -97,12 +98,17 @@ impl ConvBlock {
|
||||
};
|
||||
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
|
||||
let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
|
||||
Ok(Self { conv, bn })
|
||||
Ok(Self {
|
||||
conv,
|
||||
bn,
|
||||
span: tracing::span!(tracing::Level::TRACE, "conv-block"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for ConvBlock {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let xs = self.conv.forward(xs)?;
|
||||
let xs = self.bn.forward(&xs)?;
|
||||
candle_nn::ops::silu(&xs)
|
||||
@ -114,6 +120,7 @@ struct Bottleneck {
|
||||
cv1: ConvBlock,
|
||||
cv2: ConvBlock,
|
||||
residual: bool,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Bottleneck {
|
||||
@ -123,12 +130,18 @@ impl Bottleneck {
|
||||
let cv1 = ConvBlock::load(vb.pp("cv1"), c1, c_, 3, 1, None)?;
|
||||
let cv2 = ConvBlock::load(vb.pp("cv2"), c_, c2, 3, 1, None)?;
|
||||
let residual = c1 == c2 && shortcut;
|
||||
Ok(Self { cv1, cv2, residual })
|
||||
Ok(Self {
|
||||
cv1,
|
||||
cv2,
|
||||
residual,
|
||||
span: tracing::span!(tracing::Level::TRACE, "bottleneck"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Bottleneck {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let ys = self.cv2.forward(&self.cv1.forward(xs)?)?;
|
||||
if self.residual {
|
||||
xs + ys
|
||||
@ -143,6 +156,7 @@ struct C2f {
|
||||
cv1: ConvBlock,
|
||||
cv2: ConvBlock,
|
||||
bottleneck: Vec<Bottleneck>,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl C2f {
|
||||
@ -159,12 +173,14 @@ impl C2f {
|
||||
cv1,
|
||||
cv2,
|
||||
bottleneck,
|
||||
span: tracing::span!(tracing::Level::TRACE, "c2f"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for C2f {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let ys = self.cv1.forward(xs)?;
|
||||
let mut ys = ys.chunk(2, 1)?;
|
||||
for m in self.bottleneck.iter() {
|
||||
@ -180,6 +196,7 @@ struct Sppf {
|
||||
cv1: ConvBlock,
|
||||
cv2: ConvBlock,
|
||||
k: usize,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Sppf {
|
||||
@ -187,12 +204,18 @@ impl Sppf {
|
||||
let c_ = c1 / 2;
|
||||
let cv1 = ConvBlock::load(vb.pp("cv1"), c1, c_, 1, 1, None)?;
|
||||
let cv2 = ConvBlock::load(vb.pp("cv2"), c_ * 4, c2, 1, 1, None)?;
|
||||
Ok(Self { cv1, cv2, k })
|
||||
Ok(Self {
|
||||
cv1,
|
||||
cv2,
|
||||
k,
|
||||
span: tracing::span!(tracing::Level::TRACE, "sppf"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Sppf {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let (_, _, _, _) = xs.dims4()?;
|
||||
let xs = self.cv1.forward(xs)?;
|
||||
let xs2 = xs
|
||||
@ -215,17 +238,23 @@ impl Module for Sppf {
|
||||
struct Dfl {
|
||||
conv: Conv2d,
|
||||
num_classes: usize,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Dfl {
|
||||
fn load(vb: VarBuilder, num_classes: usize) -> Result<Self> {
|
||||
let conv = conv2d_no_bias(num_classes, 1, 1, Default::default(), vb.pp("conv"))?;
|
||||
Ok(Self { conv, num_classes })
|
||||
Ok(Self {
|
||||
conv,
|
||||
num_classes,
|
||||
span: tracing::span!(tracing::Level::TRACE, "dfl"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Dfl {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let (b_sz, _channels, anchors) = xs.dims3()?;
|
||||
let xs = xs
|
||||
.reshape((b_sz, 4, self.num_classes, anchors))?
|
||||
@ -247,6 +276,7 @@ struct DarkNet {
|
||||
b4_0: ConvBlock,
|
||||
b4_1: C2f,
|
||||
b5: Sppf,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl DarkNet {
|
||||
@ -330,10 +360,12 @@ impl DarkNet {
|
||||
b4_0,
|
||||
b4_1,
|
||||
b5,
|
||||
span: tracing::span!(tracing::Level::TRACE, "darknet"),
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
|
||||
let _enter = self.span.enter();
|
||||
let x1 = self.b1_1.forward(&self.b1_0.forward(xs)?)?;
|
||||
let x2 = self
|
||||
.b2_2
|
||||
@ -354,6 +386,7 @@ struct YoloV8Neck {
|
||||
n4: C2f,
|
||||
n5: ConvBlock,
|
||||
n6: C2f,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl YoloV8Neck {
|
||||
@ -413,10 +446,12 @@ impl YoloV8Neck {
|
||||
n4,
|
||||
n5,
|
||||
n6,
|
||||
span: tracing::span!(tracing::Level::TRACE, "neck"),
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, p3: &Tensor, p4: &Tensor, p5: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
|
||||
let _enter = self.span.enter();
|
||||
let x = self
|
||||
.n1
|
||||
.forward(&Tensor::cat(&[&self.up.forward(p5)?, p4], 1)?)?;
|
||||
@ -440,6 +475,7 @@ struct DetectionHead {
|
||||
cv3: [(ConvBlock, ConvBlock, Conv2d); 3],
|
||||
ch: usize,
|
||||
no: usize,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@ -447,6 +483,7 @@ struct PoseHead {
|
||||
detect: DetectionHead,
|
||||
cv4: [(ConvBlock, ConvBlock, Conv2d); 3],
|
||||
kpt: (usize, usize),
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
fn make_anchors(
|
||||
@ -519,6 +556,7 @@ impl DetectionHead {
|
||||
cv3,
|
||||
ch,
|
||||
no,
|
||||
span: tracing::span!(tracing::Level::TRACE, "detection-head"),
|
||||
})
|
||||
}
|
||||
|
||||
@ -547,6 +585,7 @@ impl DetectionHead {
|
||||
}
|
||||
|
||||
fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<DetectionHeadOut> {
|
||||
let _enter = self.span.enter();
|
||||
let forward_cv = |xs, i: usize| {
|
||||
let xs_2 = self.cv2[i].0.forward(xs)?;
|
||||
let xs_2 = self.cv2[i].1.forward(&xs_2)?;
|
||||
@ -606,7 +645,12 @@ impl PoseHead {
|
||||
Self::load_cv4(vb.pp("cv4.1"), c4, nk, filters.1)?,
|
||||
Self::load_cv4(vb.pp("cv4.2"), c4, nk, filters.2)?,
|
||||
];
|
||||
Ok(Self { detect, cv4, kpt })
|
||||
Ok(Self {
|
||||
detect,
|
||||
cv4,
|
||||
kpt,
|
||||
span: tracing::span!(tracing::Level::TRACE, "pose-head"),
|
||||
})
|
||||
}
|
||||
|
||||
fn load_cv4(
|
||||
@ -622,6 +666,7 @@ impl PoseHead {
|
||||
}
|
||||
|
||||
fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let d = self.detect.forward(xs0, xs1, xs2)?;
|
||||
let forward_cv = |xs: &Tensor, i: usize| {
|
||||
let (b_sz, _, h, w) = xs.dims4()?;
|
||||
@ -650,6 +695,7 @@ pub struct YoloV8 {
|
||||
net: DarkNet,
|
||||
fpn: YoloV8Neck,
|
||||
head: DetectionHead,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl YoloV8 {
|
||||
@ -657,12 +703,18 @@ impl YoloV8 {
|
||||
let net = DarkNet::load(vb.pp("net"), m)?;
|
||||
let fpn = YoloV8Neck::load(vb.pp("fpn"), m)?;
|
||||
let head = DetectionHead::load(vb.pp("head"), num_classes, m.filters())?;
|
||||
Ok(Self { net, fpn, head })
|
||||
Ok(Self {
|
||||
net,
|
||||
fpn,
|
||||
head,
|
||||
span: tracing::span!(tracing::Level::TRACE, "yolo-v8"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for YoloV8 {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let (xs1, xs2, xs3) = self.net.forward(xs)?;
|
||||
let (xs1, xs2, xs3) = self.fpn.forward(&xs1, &xs2, &xs3)?;
|
||||
Ok(self.head.forward(&xs1, &xs2, &xs3)?.pred)
|
||||
@ -674,6 +726,7 @@ pub struct YoloV8Pose {
|
||||
net: DarkNet,
|
||||
fpn: YoloV8Neck,
|
||||
head: PoseHead,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl YoloV8Pose {
|
||||
@ -686,12 +739,18 @@ impl YoloV8Pose {
|
||||
let net = DarkNet::load(vb.pp("net"), m)?;
|
||||
let fpn = YoloV8Neck::load(vb.pp("fpn"), m)?;
|
||||
let head = PoseHead::load(vb.pp("head"), num_classes, kpt, m.filters())?;
|
||||
Ok(Self { net, fpn, head })
|
||||
Ok(Self {
|
||||
net,
|
||||
fpn,
|
||||
head,
|
||||
span: tracing::span!(tracing::Level::TRACE, "yolo-v8-pose"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for YoloV8Pose {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let (xs1, xs2, xs3) = self.net.forward(xs)?;
|
||||
let (xs1, xs2, xs3) = self.fpn.forward(&xs1, &xs2, &xs3)?;
|
||||
self.head.forward(&xs1, &xs2, &xs3)
|
||||
|
@ -1,5 +1,6 @@
|
||||
pub mod coco_classes;
|
||||
pub mod imagenet;
|
||||
pub mod token_output_stream;
|
||||
|
||||
use candle::{Device, Result, Tensor};
|
||||
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user