Mirror of https://github.com/huggingface/candle.git
Synced 2025-06-20 04:00:28 +00:00

Compare commits: einsum-cus...ddpg (233 commits)
.gitignore (vendored): 8 changes

```diff
@@ -23,14 +23,16 @@ flamegraph.svg
 *.dylib
 *.so
 *.swp
+*.swo
 trace-*.json
 
 candle-wasm-examples/*/build
 candle-wasm-examples/*/*.bin
 candle-wasm-examples/*/*.jpeg
-candle-wasm-examples/*/*.wav
+candle-wasm-examples/*/audios/*.wav
-candle-wasm-examples/*/*.safetensors
+candle-wasm-examples/**/*.safetensors
+candle-wasm-examples/**/*.gguf
 candle-wasm-examples/*/package-lock.json
+candle-wasm-examples/**/config*.json
 .DS_Store
 .idea/*
```
.vscode/settings.json (vendored, new file): 11 lines

```diff
@@ -0,0 +1,11 @@
+{
+    "[python]": {
+        "editor.defaultFormatter": "ms-python.black-formatter"
+    },
+    "python.formatting.provider": "none",
+    "python.testing.pytestArgs": [
+        "candle-pyo3"
+    ],
+    "python.testing.unittestEnabled": false,
+    "python.testing.pytestEnabled": true
+}
```
CHANGELOG.md: 73 changes

```diff
@@ -1,13 +1,84 @@
 # Changelog
 This documents the main changes to the `candle` crate.
 
-## v0.2.1 - Unreleased
+## v0.3.1 - Unreleased
 
 ### Added
 
+### Modified
+
+## v0.3.0 - 2023-10-01
+
+### Added
+
+- Added the Mistral 7b v0.1 model
+  [983](https://github.com/huggingface/candle/pull/983).
+- Quantized version of the Mistral model
+  [1009](https://github.com/huggingface/candle/pull/1009).
+- Add the gelu-erf op and activation function
+  [969](https://github.com/huggingface/candle/pull/969).
+- Add the mixformer/phi-v1.5 model
+  [930](https://github.com/huggingface/candle/pull/930).
+- Add the slice-scatter op
+  [927](https://github.com/huggingface/candle/pull/927).
+- Add the Wuerstchen diffusion model
+  [911](https://github.com/huggingface/candle/pull/911).
+
+### Modified
+
+- Support for simd128 intrinsics in some quantized vecdots
+  [982](https://github.com/huggingface/candle/pull/982).
+- Optimize the index-select cuda kernel
+  [976](https://github.com/huggingface/candle/pull/976).
+- Self-contained safetensor wrappers
+  [946](https://github.com/huggingface/candle/pull/946).
+
+## v0.2.2 - 2023-09-18
+
+### Added
+- Support for `top_p` sampling
+  [819](https://github.com/huggingface/candle/pull/819).
+- T5 model including decoding
+  [864](https://github.com/huggingface/candle/pull/864).
+- 1-d upsampling
+  [839](https://github.com/huggingface/candle/pull/839).
+
+### Modified
+- Bugfix for conv2d
+  [820](https://github.com/huggingface/candle/pull/820).
+- Support tensor based indexing using `.i`
+  [842](https://github.com/huggingface/candle/pull/842).
+
+## v0.2.1 - 2023-09-11
+
+### Added
+- Add some RNNs (GRU and LSTM) in `candle-nn`
+  [674](https://github.com/huggingface/candle/pull/674),
+  [688](https://github.com/huggingface/candle/pull/688).
+- gguf v2 support
+  [725](https://github.com/huggingface/candle/pull/725).
+- Quantized llama example in Python using the pyo3 api
+  [716](https://github.com/huggingface/candle/pull/716).
+- `candle-nn` layer for conv2d-transposed
+  [760](https://github.com/huggingface/candle/pull/760).
+- Add the Segment-Anything Model (SAM) as an example
+  [773](https://github.com/huggingface/candle/pull/773).
+- TinyViT backbone for the segment-anything example
+  [787](https://github.com/huggingface/candle/pull/787).
+- Shape with holes support
+  [770](https://github.com/huggingface/candle/pull/770).
+
 ### Modified
 - Dilations are now supported in conv-transpose2d.
   [671](https://github.com/huggingface/candle/pull/671).
+- Interactive mode for the quantized model
+  [690](https://github.com/huggingface/candle/pull/690).
+- Faster softmax operation
+  [747](https://github.com/huggingface/candle/pull/747).
+- Faster convolution operations on CPU and CUDA via im2col
+  [802](https://github.com/huggingface/candle/pull/802).
+- Moving some models to a more central location
+  [796](https://github.com/huggingface/candle/pull/796).
 
 ## v0.2.0 - 2023-08-30
 
```
Cargo.toml: 20 changes

```diff
@@ -8,17 +8,19 @@ members = [
     "candle-pyo3",
     "candle-transformers",
     "candle-wasm-examples/llama2-c",
+    "candle-wasm-examples/segment-anything",
     "candle-wasm-examples/whisper",
     "candle-wasm-examples/yolo",
+    "candle-wasm-examples/bert",
+    "candle-wasm-examples/phi",
+    "candle-wasm-examples/t5",
+    "candle-wasm-tests",
 ]
-exclude = [
-    "candle-flash-attn",
-    "candle-kernels",
-]
+exclude = ["candle-flash-attn", "candle-kernels"]
 resolver = "2"
 
 [workspace.package]
-version = "0.2.1"
+version = "0.3.0"
 edition = "2021"
 description = "Minimalist ML framework."
 repository = "https://github.com/huggingface/candle"
@@ -32,8 +34,7 @@ anyhow = { version = "1", features = ["backtrace"] }
 byteorder = "1.4.3"
 clap = { version = "4.2.4", features = ["derive"] }
 cudarc = { version = "0.9.14", features = ["f16"] }
-# TODO: Switch back to the official gemm implementation once it has caught up.
-gemm = { version = "0.15.6", package = "candle-gemm" }
+gemm = { version = "0.16.6", features = ["wasm-simd128-enable"] }
 hf-hub = "0.3.0"
 half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
 image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
@@ -41,9 +42,10 @@ imageproc = { version = "0.23.0", default-features = false }
 intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
 libc = { version = "0.2.147" }
 log = "0.4"
-memmap2 = "0.7.1"
+memmap2 = { version = "0.7.1", features = ["stable_deref_trait"] }
 num_cpus = "1.15.0"
 num-traits = "0.2.15"
+parquet = { version = "45.0.0" }
 rand = "0.8.5"
 rand_distr = "0.4.3"
 rayon = "1.7.0"
@@ -57,8 +59,8 @@ tracing = "0.1.37"
 tracing-chrome = "0.7.1"
 tracing-subscriber = "0.3.7"
 wav = "1.0.0"
+yoke = { version = "0.7.2", features = ["derive"] }
 zip = { version = "0.6.6", default-features = false }
-parquet = { version = "45.0.0" }
 
 [profile.release-with-debug]
 inherits = "release"
```
README.md: 127 changes

````diff
@@ -8,7 +8,10 @@ Candle is a minimalist ML framework for Rust with a focus on performance (includ
 and ease of use. Try our online demos:
 [whisper](https://huggingface.co/spaces/lmz/candle-whisper),
 [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2),
-[yolo](https://huggingface.co/spaces/lmz/candle-yolo).
+[T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm),
+[yolo](https://huggingface.co/spaces/lmz/candle-yolo),
+[Segment
+Anything](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).
 
 ## Get started
 
@@ -45,40 +48,60 @@ For more advanced examples, please have a look at the following section.
 
 ## Check out our examples
 
-Check out our [examples](./candle-examples/examples/):
+These online demos run entirely in your browser:
+- [yolo](https://huggingface.co/spaces/lmz/candle-yolo): pose estimation and
+  object recognition.
+- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): speech recognition.
+- [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2): text generation.
+- [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm): text generation.
+- [Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation.
+- [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm): image segmentation.
+
+We also provide some command line based examples using state of the art models:
 
-- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
 - [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM.
 - [Falcon](./candle-examples/examples/falcon/): general LLM.
-- [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
-- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code
-  generation.
-- [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
-  image generative model, support for the 1.5, 2.1, and SDXL 1.0 versions.
-- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
-  using self-supervision (can be used for imagenet classification, depth
-  evaluation, segmentation).
+- [Phi-v1.5](./candle-examples/examples/phi/): a 1.3b general LLM with performance on par with LLaMA-v2 7b.
+- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
+  pre-trained on 1T tokens of English and code datasets.
+- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM that
+  outperforms all publicly available 13b models as of 2023-09-28.
+- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
 - [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of
   the LLaMA model using the same quantization techniques as
   [llama.cpp](https://github.com/ggerganov/llama.cpp).
 
+<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/quantized/assets/aoc.gif" width="600">
+
+- [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
+  image generative model, support for the 1.5, 2.1, and SDXL 1.0 versions.
+
+<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg" width="200">
+
+- [Wuerstchen](./candle-examples/examples/wuerstchen/): another text to
+  image generative model.
+
+<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/wuerstchen/assets/cat.jpg" width="200">
+
 - [yolo-v3](./candle-examples/examples/yolo-v3/) and
   [yolo-v8](./candle-examples/examples/yolo-v8/): object detection and pose
   estimation models.
-  [segment-anything](./candle-examples/examples/segment-anything/): image
+
+<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/yolo-v8/assets/bike.od.jpg" width="200"><img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/yolo-v8/assets/bike.pose.jpg" width="200">
+
+- [segment-anything](./candle-examples/examples/segment-anything/): image
   segmentation model with prompt.
-Run them using the following commands:
+
+<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">
+
+- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
+- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
+- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
+  using self-supervision (can be used for imagenet classification, depth
+  evaluation, segmentation).
+
+Run them using commands like:
 ```
-cargo run --example whisper --release
-cargo run --example llama --release
-cargo run --example falcon --release
-cargo run --example bert --release
-cargo run --example bigcode --release
-cargo run --example stable-diffusion --release -- --prompt "a rusty robot holding a fire torch"
-cargo run --example dinov2 --release -- --image path/to/myinput.jpg
 cargo run --example quantized --release
-cargo run --example yolo-v3 --release -- myimage.jpg
-cargo run --example yolo-v8 --release -- myimage.jpg # for pose estimation, add --task pose
-cargo run --example segment-anything --release -- --image myimage.jpg
 ```
 
 In order to use **CUDA** add `--features cuda` to the example command line. If
@@ -88,7 +111,10 @@ There are also some wasm examples for whisper and
 [llama2.c](https://github.com/karpathy/llama2.c). You can either build them with
 `trunk` or try them online:
 [whisper](https://huggingface.co/spaces/lmz/candle-whisper),
-[llama2](https://huggingface.co/spaces/lmz/candle-llama2).
+[llama2](https://huggingface.co/spaces/lmz/candle-llama2),
+[T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm),
+[Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm),
+[Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).
 
 For LLaMA2, run the following command to retrieve the weight files and start a
 test server:
@@ -101,6 +127,15 @@ trunk serve --release --port 8081
 And then head over to
 [http://localhost:8081/](http://localhost:8081/).
+
+<!--- ANCHOR: useful_libraries --->
+
+## Useful Libraries
+- [`candle-lora`](https://github.com/EricLBuehler/candle-lora) provides a LoRA implementation that conforms to the official `peft` implementation.
+
+If you have an addition to this list, please submit a pull request.
+
+<!--- ANCHOR_END: useful_libraries --->
 
 <!--- ANCHOR: features --->
 
 ## Features
@@ -113,10 +148,25 @@ And then head over to
 - CUDA backend for efficiently running on GPUs, multiple GPU distribution via NCCL.
 - WASM support, run your models in a browser.
 - Included models.
-  - LLMs: LLaMA v1 and v2, Falcon, StarCoder.
+  - Language Models.
+    - LLaMA v1 and v2.
+    - Falcon.
+    - StarCoder.
+    - Phi v1.5.
+    - Mistral 7b v0.1.
+    - StableLM-3B-4E1T.
+    - T5.
+    - Bert.
   - Whisper (multi-lingual support).
-  - Stable Diffusion.
-  - Computer Vision: DINOv2, EfficientNet, yolo-v3, yolo-v8.
+  - Stable Diffusion v1.5, v2.1, XL v1.0.
+  - Wurstchen v2.
+  - Computer Vision Models.
+    - DINOv2.
+    - ConvMixer.
+    - EfficientNet.
+    - yolo-v3.
+    - yolo-v8.
+    - Segment-Anything Model (SAM).
 - File formats: load models from safetensors, npz, ggml, or PyTorch files.
 - Serverless (on CPU), small and fast deployments.
 - Quantization support using the llama.cpp quantized types.
@@ -257,6 +307,29 @@ This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a d
 env CANDLE_NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
 ```
+
+#### Linking error on Windows when running rustdoc or mdbook tests
+
+```
+Couldn't compile the test.
+---- .\candle-book\src\inference\hub.md - Using_the_hub::Using_in_a_real_model_ (line 50) stdout ----
+error: linking with `link.exe` failed: exit code: 1181
+//very long chain of linking
+ = note: LINK : fatal error LNK1181: cannot open input file 'windows.0.48.5.lib'
+```
+
+Make sure you link all native libraries that might be located outside a project target, e.g., to run mdbook tests, you should run:
+
+```
+mdbook test candle-book -L .\target\debug\deps\ `
+-L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.42.2\lib `
+-L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.48.5\lib
+```
+
+#### Extremely slow model load time with WSL
+
+This may be caused by the models being loaded from `/mnt/c`, more details on
+[stackoverflow](https://stackoverflow.com/questions/68972448/why-is-wsl-extremely-slow-when-compared-with-native-windows-npm-yarn-processing).
+
 #### Tracking down errors
 
 You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle
````
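A note on the CUDA switch mentioned in the README above: `--features cuda` only changes how the crate is compiled, while the examples pick their device in ordinary code. A minimal sketch of the usual pattern (the helper name `pick_device` is mine, not from the diff):

```rust
// Minimal sketch: device selection as used throughout the candle examples.
// When built with `--features cuda` and a GPU is visible, this returns the
// first CUDA device; otherwise it falls back to the CPU.
use candle_core::{Device, Result};

fn pick_device() -> Result<Device> {
    Device::cuda_if_available(0)
}
```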
```diff
@@ -11,11 +11,11 @@ readme = "README.md"
 
 [dependencies]
 accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
+candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
-candle-datasets = { path = "../candle-datasets", version = "0.2.1" }
+candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
-candle-nn = { path = "../candle-nn", version = "0.2.1" }
+candle-nn = { path = "../candle-nn", version = "0.3.0" }
-candle-transformers = { path = "../candle-transformers", version = "0.2.1" }
+candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
-candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.1", optional = true }
+candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
 safetensors = { workspace = true }
 serde = { workspace = true }
 serde_json = { workspace = true }
@@ -24,9 +24,10 @@ intel-mkl-src = { workspace = true, optional = true }
 cudarc = { workspace = true, optional = true }
 half = { workspace = true, optional = true }
 image = { workspace = true, optional = true }
+anyhow = { workspace = true }
+tokio = "1.29.1"
 
 [dev-dependencies]
-anyhow = { workspace = true }
 byteorder = { workspace = true }
 hf-hub = { workspace = true, features=["tokio"]}
 clap = { workspace = true }
@@ -38,7 +39,6 @@ tracing-chrome = { workspace = true }
 tracing-subscriber = { workspace = true }
 wav = { workspace = true }
 # Necessary to disambiguate with tokio in wasm examples which are 1.28.1
-tokio = "1.29.1"
 parquet = { workspace = true }
 image = { workspace = true }
 
```
```diff
@@ -10,10 +10,11 @@
 
 # Reference Guide
 
-- [Running a model](inference/README.md)
+- [Running a model](inference/inference.md)
 - [Using the hub](inference/hub.md)
 - [Error management](error_manage.md)
-- [Training](training/README.md)
+- [Training](training/training.md)
+    - [Simplified](training/simplified.md)
     - [MNIST](training/mnist.md)
     - [Fine-tuning]()
     - [Serialization]()
```
````diff
@@ -29,7 +29,7 @@ After adding `RUST_BACKTRACE=1`:
 Error: WithBacktrace { inner: ShapeMismatchBinaryOp { lhs: [1, 784], rhs: [1, 784], op: "matmul" }, backtrace: Backtrace [{ fn: "candle::error::Error::bt", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/error.rs", line: 200 }, { fn: "candle::tensor::Tensor::matmul", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/tensor.rs", line: 816 }, { fn: "myapp::main", file: "./src/main.rs", line: 29 }, { fn: "core::ops::function::FnOnce::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 250 }, { fn: "std::sys_common::backtrace::__rust_begin_short_backtrace", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/sys_common/backtrace.rs", line: 135 }, { fn: "std::rt::lang_start::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 166 }, { fn: "core::ops::function::impls::<impl core::ops::function::FnOnce<A> for &F>::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 284 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::rt::lang_start", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 165 }, { fn: "main" }, { fn: "__libc_start_main" }, { fn: "_start" }] }
 ```
 
-Not super pretty at the moment, but we can see error occured on `{ fn: "myapp::main", file: "./src/main.rs", line: 29 }`
+Not super pretty at the moment, but we can see error occurred on `{ fn: "myapp::main", file: "./src/main.rs", line: 29 }`
 
 Another thing to note, is that since Rust is compiled it is not necessarily as easy to recover proper stacktraces
````
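The backtrace discussed above comes from a shape mismatch in `matmul`. A minimal repro with the same shapes (a sketch, not part of the diff):

```rust
// Sketch: (1, 784) x (1, 784) is not a valid matmul, since the inner
// dimensions (784 and 1) do not line up. The `?` operator surfaces the
// ShapeMismatchBinaryOp error shown above, and RUST_BACKTRACE=1 makes
// candle attach the backtrace to it.
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let lhs = Tensor::zeros((1, 784), DType::F32, &device)?;
    let rhs = Tensor::zeros((1, 784), DType::F32, &device)?;
    let _out = lhs.matmul(&rhs)?; // fails: lhs [1, 784], rhs [1, 784], op "matmul"
    Ok(())
}
```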
````diff
@@ -6,7 +6,7 @@ Open `src/main.rs` and fill in this content:
 
 ```rust
 # extern crate candle_core;
-use candle_core::{DType, Device, Result, Tensor};
+use candle_core::{Device, Result, Tensor};
 
 struct Model {
     first: Tensor,
@@ -25,11 +25,11 @@ fn main() -> Result<()> {
     // Use Device::new_cuda(0)?; to use the GPU.
     let device = Device::Cpu;
 
-    let first = Tensor::zeros((784, 100), DType::F32, &device)?;
+    let first = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
-    let second = Tensor::zeros((100, 10), DType::F32, &device)?;
+    let second = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
     let model = Model { first, second };
 
-    let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
+    let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
 
     let digit = model.forward(&dummy_image)?;
     println!("Digit {digit:?} digit");
@@ -50,7 +50,7 @@ the classical `Linear` layer. We can do as such
 
 ```rust
 # extern crate candle_core;
-# use candle_core::{DType, Device, Result, Tensor};
+# use candle_core::{Device, Result, Tensor};
 struct Linear{
     weight: Tensor,
     bias: Tensor,
@@ -80,7 +80,7 @@ This will change the model running code into a new function
 
 ```rust
 # extern crate candle_core;
-# use candle_core::{DType, Device, Result, Tensor};
+# use candle_core::{Device, Result, Tensor};
 # struct Linear{
 #     weight: Tensor,
 #     bias: Tensor,
@@ -110,15 +110,15 @@ fn main() -> Result<()> {
     let device = Device::cuda_if_available(0)?;
 
     // Creating a dummy model
-    let weight = Tensor::zeros((784, 100), DType::F32, &device)?;
+    let weight = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
-    let bias = Tensor::zeros((100, ), DType::F32, &device)?;
+    let bias = Tensor::randn(0f32, 1.0, (100, ), &device)?;
     let first = Linear{weight, bias};
-    let weight = Tensor::zeros((100, 10), DType::F32, &device)?;
+    let weight = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
-    let bias = Tensor::zeros((10, ), DType::F32, &device)?;
+    let bias = Tensor::randn(0f32, 1.0, (10, ), &device)?;
     let second = Linear{weight, bias};
     let model = Model { first, second };
 
-    let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
+    let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
 
     // Inference on the model
     let digit = model.forward(&dummy_image)?;
@@ -146,7 +146,7 @@ And rewrite our examples using it
 ```rust
 # extern crate candle_core;
 # extern crate candle_nn;
-use candle_core::{DType, Device, Result, Tensor};
+use candle_core::{Device, Result, Tensor};
 use candle_nn::{Linear, Module};
 
 struct Model {
@@ -167,15 +167,15 @@ fn main() -> Result<()> {
     let device = Device::Cpu;
 
     // This has changed (784, 100) -> (100, 784) !
-    let weight = Tensor::zeros((100, 784), DType::F32, &device)?;
+    let weight = Tensor::randn(0f32, 1.0, (100, 784), &device)?;
-    let bias = Tensor::zeros((100, ), DType::F32, &device)?;
+    let bias = Tensor::randn(0f32, 1.0, (100, ), &device)?;
     let first = Linear::new(weight, Some(bias));
-    let weight = Tensor::zeros((10, 100), DType::F32, &device)?;
+    let weight = Tensor::randn(0f32, 1.0, (10, 100), &device)?;
-    let bias = Tensor::zeros((10, ), DType::F32, &device)?;
+    let bias = Tensor::randn(0f32, 1.0, (10, ), &device)?;
     let second = Linear::new(weight, Some(bias));
     let model = Model { first, second };
 
-    let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
+    let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
 
     let digit = model.forward(&dummy_image)?;
     println!("Digit {digit:?} digit");
@@ -188,8 +188,8 @@ Feel free to modify this example to use `Conv2d` to create a classical convnet i
 
 Now that we have the running dummy code we can get to more advanced topics:
 
-- [For PyTorch users](./guide/cheatsheet.md)
+- [For PyTorch users](../guide/cheatsheet.md)
-- [Running existing models](./inference/README.md)
+- [Running existing models](../inference/inference.md)
-- [Training models](./training/README.md)
+- [Training models](../training/training.md)
 
 
````
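The recurring change in this guide swaps `Tensor::zeros` for `Tensor::randn(mean, std, shape, device)`: all-zero weights make every forward pass output zeros, which hides shape mistakes, while random weights actually exercise the model. The `(784, 100) -> (100, 784)` comment reflects that `candle_nn::Linear` stores its weight as `(out_features, in_features)`. A minimal sketch contrasting the two constructors:

```rust
// Sketch contrasting the two tensor constructors used in this guide.
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // zeros takes an explicit dtype; forward passes through all-zero
    // weights always produce all-zero outputs.
    let zeros = Tensor::zeros((2, 3), DType::F32, &device)?;
    // randn samples from N(mean, std^2); the dtype is inferred from `0f32`.
    let randn = Tensor::randn(0f32, 1.0, (2, 3), &device)?;
    println!("{zeros:?}\n{randn:?}");
    Ok(())
}
```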
```diff
@@ -1,3 +1,6 @@
+#[cfg(test)]
+pub mod simplified;
+
 #[cfg(test)]
 mod tests {
     use anyhow::Result;
```
candle-book/src/simplified.rs (new file): 196 lines

```rust
//! # A simplified example in Rust of training a neural network and then using it, based on the Candle framework by Hugging Face.
//! Author: Evgeny Igumnov 2023 igumnovnsk@gmail.com
//! This program implements a neural network to predict the winner of the second round of elections based on the results of the first round.
//!
//! ## Key points:
//!
//! A multilayer perceptron with two hidden layers is used. The first hidden layer has 4 neurons, the second has 2 neurons.
//! The input is a vector of 2 numbers: the percentage of votes for the first and second candidates in the first stage.
//! The output is the number 0 or 1, where 1 means that the first candidate will win in the second stage, 0 means that he will lose.
//! For training, samples with real data on the results of the first and second stages of different elections are used.
//! The model is trained by backpropagation using gradient descent and the cross-entropy loss function.
//! Model parameters (weights of neurons) are initialized randomly, then optimized during training.
//! After training, the model is tested on a held-out sample to evaluate the accuracy.
//! If the accuracy on the test set is below 100%, the model is considered underfit and the learning process is repeated.
//! Thus, this neural network learns to find hidden relationships between the results of the first and second rounds of voting in order to make predictions for new data.

#[rustfmt::skip]
mod tests {

use candle::{DType, Result, Tensor, D, Device};
use candle_nn::{loss, ops, Linear, Module, VarBuilder, VarMap, Optimizer};

// ANCHOR: book_training_simplified1
const VOTE_DIM: usize = 2;
const RESULTS: usize = 1;
const EPOCHS: usize = 10;
const LAYER1_OUT_SIZE: usize = 4;
const LAYER2_OUT_SIZE: usize = 2;
const LEARNING_RATE: f64 = 0.05;

#[derive(Clone)]
pub struct Dataset {
    pub train_votes: Tensor,
    pub train_results: Tensor,
    pub test_votes: Tensor,
    pub test_results: Tensor,
}

struct MultiLevelPerceptron {
    ln1: Linear,
    ln2: Linear,
    ln3: Linear,
}

impl MultiLevelPerceptron {
    fn new(vs: VarBuilder) -> Result<Self> {
        let ln1 = candle_nn::linear(VOTE_DIM, LAYER1_OUT_SIZE, vs.pp("ln1"))?;
        let ln2 = candle_nn::linear(LAYER1_OUT_SIZE, LAYER2_OUT_SIZE, vs.pp("ln2"))?;
        let ln3 = candle_nn::linear(LAYER2_OUT_SIZE, RESULTS + 1, vs.pp("ln3"))?;
        Ok(Self { ln1, ln2, ln3 })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = self.ln1.forward(xs)?;
        let xs = xs.relu()?;
        let xs = self.ln2.forward(&xs)?;
        let xs = xs.relu()?;
        self.ln3.forward(&xs)
    }
}
// ANCHOR_END: book_training_simplified1

// ANCHOR: book_training_simplified3
#[tokio::test]
async fn simplified() -> anyhow::Result<()> {
    let dev = Device::cuda_if_available(0)?;

    let train_votes_vec: Vec<u32> = vec![
        15, 10,
        10, 15,
        5, 12,
        30, 20,
        16, 12,
        13, 25,
        6, 14,
        31, 21,
    ];
    let train_votes_tensor = Tensor::from_vec(train_votes_vec.clone(), (train_votes_vec.len() / VOTE_DIM, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;

    let train_results_vec: Vec<u32> = vec![
        1,
        0,
        0,
        1,
        1,
        0,
        0,
        1,
    ];
    let train_results_tensor = Tensor::from_vec(train_results_vec, train_votes_vec.len() / VOTE_DIM, &dev)?;

    let test_votes_vec: Vec<u32> = vec![
        13, 9,
        8, 14,
        3, 10,
    ];
    let test_votes_tensor = Tensor::from_vec(test_votes_vec.clone(), (test_votes_vec.len() / VOTE_DIM, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;

    let test_results_vec: Vec<u32> = vec![
        1,
        0,
        0,
    ];
    let test_results_tensor = Tensor::from_vec(test_results_vec.clone(), test_results_vec.len(), &dev)?;

    let m = Dataset {
        train_votes: train_votes_tensor,
        train_results: train_results_tensor,
        test_votes: test_votes_tensor,
        test_results: test_results_tensor,
    };

    let trained_model: MultiLevelPerceptron;
    loop {
        println!("Trying to train neural network.");
        match train(m.clone(), &dev) {
            Ok(model) => {
                trained_model = model;
                break;
            },
            Err(e) => {
                println!("Error: {}", e);
                continue;
            }
        }
    }

    let real_world_votes: Vec<u32> = vec![
        13, 22,
    ];

    let tensor_test_votes = Tensor::from_vec(real_world_votes.clone(), (1, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;

    let final_result = trained_model.forward(&tensor_test_votes)?;

    let result = final_result
        .argmax(D::Minus1)?
        .to_dtype(DType::F32)?
        .get(0).map(|x| x.to_scalar::<f32>())??;
    println!("real_life_votes: {:?}", real_world_votes);
    println!("neural_network_prediction_result: {:?}", result);

    Ok(())
}
// ANCHOR_END: book_training_simplified3

// ANCHOR: book_training_simplified2
fn train(m: Dataset, dev: &Device) -> anyhow::Result<MultiLevelPerceptron> {
    let train_results = m.train_results.to_device(dev)?;
    let train_votes = m.train_votes.to_device(dev)?;
    let varmap = VarMap::new();
    let vs = VarBuilder::from_varmap(&varmap, DType::F32, dev);
    let model = MultiLevelPerceptron::new(vs.clone())?;
    let mut sgd = candle_nn::SGD::new(varmap.all_vars(), LEARNING_RATE)?;
    let test_votes = m.test_votes.to_device(dev)?;
    let test_results = m.test_results.to_device(dev)?;
    let mut final_accuracy: f32 = 0.0;
    for epoch in 1..EPOCHS + 1 {
        let logits = model.forward(&train_votes)?;
        let log_sm = ops::log_softmax(&logits, D::Minus1)?;
        let loss = loss::nll(&log_sm, &train_results)?;
        sgd.backward_step(&loss)?;

        let test_logits = model.forward(&test_votes)?;
        let sum_ok = test_logits
            .argmax(D::Minus1)?
            .eq(&test_results)?
            .to_dtype(DType::F32)?
            .sum_all()?
            .to_scalar::<f32>()?;
        let test_accuracy = sum_ok / test_results.dims1()? as f32;
        final_accuracy = 100. * test_accuracy;
        println!("Epoch: {epoch:3} Train loss: {:8.5} Test accuracy: {:5.2}%",
            loss.to_scalar::<f32>()?,
            final_accuracy
        );
        if final_accuracy == 100.0 {
            break;
        }
    }
    if final_accuracy < 100.0 {
        Err(anyhow::Error::msg("The model is not trained well enough."))
    } else {
        Ok(model)
    }
}
// ANCHOR_END: book_training_simplified2

}
```
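The training loop above pairs `ops::log_softmax` with `loss::nll`, which together compute the cross-entropy loss. Spelled out (with the batch-mean convention I read off the scalar loss):

```latex
% Cross-entropy over N samples with logits z and labels y:
\mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N}
  \log \frac{\exp(z_{i,\,y_i})}{\sum_{j} \exp(z_{i,\,j})}
```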
candle-book/src/training/simplified.md (new file): 45 lines

````markdown
# Simplified

## How it works

This program implements a neural network to predict the winner of the second round of elections based on the results of the first round.

Key points:

1. A multilayer perceptron with two hidden layers is used. The first hidden layer has 4 neurons, the second has 2 neurons.
2. The input is a vector of 2 numbers: the percentage of votes for the first and second candidates in the first stage.
3. The output is the number 0 or 1, where 1 means that the first candidate will win in the second stage, 0 means that he will lose.
4. For training, samples with real data on the results of the first and second stages of different elections are used.
5. The model is trained by backpropagation using gradient descent and the cross-entropy loss function.
6. Model parameters (weights of neurons) are initialized randomly, then optimized during training.
7. After training, the model is tested on a held-out sample to evaluate the accuracy.
8. If the accuracy on the test set is below 100%, the model is considered underfit and the learning process is repeated.

Thus, this neural network learns to find hidden relationships between the results of the first and second rounds of voting in order to make predictions for new data.

```rust,ignore
{{#include ../simplified.rs:book_training_simplified1}}
```

```rust,ignore
{{#include ../simplified.rs:book_training_simplified2}}
```

```rust,ignore
{{#include ../simplified.rs:book_training_simplified3}}
```

## Example output

```bash
Trying to train neural network.
Epoch:   1 Train loss:  4.42555 Test accuracy:  0.00%
Epoch:   2 Train loss:  0.84677 Test accuracy: 33.33%
Epoch:   3 Train loss:  2.54335 Test accuracy: 33.33%
Epoch:   4 Train loss:  0.37806 Test accuracy: 33.33%
Epoch:   5 Train loss:  0.36647 Test accuracy: 100.00%
real_life_votes: [13, 22]
neural_network_prediction_result: 0.0
```
````
```diff
@@ -12,7 +12,7 @@ readme = "README.md"
 [dependencies]
 accelerate-src = { workspace = true, optional = true }
 byteorder = { workspace = true }
-candle-kernels = { path = "../candle-kernels", version = "0.2.1", optional = true }
+candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
 cudarc = { workspace = true, optional = true }
 gemm = { workspace = true }
 half = { workspace = true }
@@ -26,6 +26,7 @@ rand_distr = { workspace = true }
 rayon = { workspace = true }
 safetensors = { workspace = true }
 thiserror = { workspace = true }
+yoke = { workspace = true }
 zip = { workspace = true }
 
 [dev-dependencies]
```
@ -103,8 +103,10 @@ enum Command {
|
|||||||
|
|
||||||
Quantize {
|
Quantize {
|
||||||
/// The input file, in gguf format.
|
/// The input file, in gguf format.
|
||||||
in_file: std::path::PathBuf,
|
in_file: Vec<std::path::PathBuf>,
|
||||||
|
|
||||||
/// The output file, in gguf format.
|
/// The output file, in gguf format.
|
||||||
|
#[arg(long)]
|
||||||
out_file: std::path::PathBuf,
|
out_file: std::path::PathBuf,
|
||||||
|
|
||||||
/// The quantization schema to apply.
|
/// The quantization schema to apply.
|
||||||
@ -150,8 +152,7 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
Format::Safetensors => {
|
Format::Safetensors => {
|
||||||
let tensors = unsafe { candle_core::safetensors::MmapedFile::new(file)? };
|
let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::new(file)? };
|
||||||
let tensors = tensors.deserialize()?;
|
|
||||||
let mut tensors = tensors.tensors();
|
let mut tensors = tensors.tensors();
|
||||||
tensors.sort_by(|a, b| a.0.cmp(&b.0));
|
tensors.sort_by(|a, b| a.0.cmp(&b.0));
|
||||||
for (name, view) in tensors.iter() {
|
for (name, view) in tensors.iter() {
|
||||||
@ -218,15 +219,99 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn run_quantize_safetensors(
|
||||||
|
in_files: &[std::path::PathBuf],
|
||||||
|
out_file: std::path::PathBuf,
|
||||||
|
q: Quantization,
|
||||||
|
) -> Result<()> {
|
||||||
|
let mut out_file = std::fs::File::create(out_file)?;
|
||||||
|
let mut tensors = std::collections::HashMap::new();
|
||||||
|
for in_file in in_files.iter() {
|
||||||
|
let in_tensors = candle_core::safetensors::load(in_file, &Device::Cpu)?;
|
||||||
|
tensors.extend(in_tensors)
|
||||||
|
}
|
||||||
|
println!("tensors: {}", tensors.len());
|
||||||
|
|
||||||
|
let quantize_fn = match q {
|
||||||
|
Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
|
||||||
|
Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
|
||||||
|
Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
|
||||||
|
Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
|
||||||
|
Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
|
||||||
|
Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
|
||||||
|
Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
|
||||||
|
Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
|
||||||
|
Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
|
||||||
|
Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
|
||||||
|
Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
|
||||||
|
Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
|
||||||
|
Quantization::F16 => QTensor::quantize::<half::f16>,
|
||||||
|
Quantization::F32 => QTensor::quantize::<f32>,
|
||||||
|
};
|
||||||
|
let block_size = match q {
|
||||||
|
Quantization::Q4_0 => k_quants::QK4_0,
|
||||||
|
Quantization::Q4_1 => k_quants::QK4_1,
|
||||||
|
Quantization::Q5_0 => k_quants::QK5_0,
|
||||||
|
Quantization::Q5_1 => k_quants::QK5_1,
|
||||||
|
Quantization::Q8_0 => k_quants::QK8_0,
|
||||||
|
Quantization::Q8_1 => k_quants::QK8_1,
|
||||||
|
Quantization::Q2k
|
||||||
|
| Quantization::Q3k
|
||||||
|
| Quantization::Q4k
|
||||||
|
| Quantization::Q5k
|
||||||
|
| Quantization::Q6k
|
||||||
|
| Quantization::Q8k => k_quants::QK_K,
|
||||||
|
Quantization::F16 | Quantization::F32 => 1,
|
||||||
|
};
|
||||||
|
|
||||||
|
let qtensors = tensors
|
||||||
|
.into_par_iter()
|
||||||
|
.map(|(name, tensor)| {
|
||||||
|
let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0;
|
||||||
|
println!(" quantizing {name} {tensor:?} {should_quantize}");
|
||||||
|
let tensor = if should_quantize {
|
||||||
|
quantize_fn(&tensor)?
|
||||||
|
} else {
|
||||||
|
QTensor::quantize::<f32>(&tensor)?
|
||||||
|
};
|
||||||
|
Ok((name, tensor))
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
let qtensors = qtensors
|
||||||
|
.iter()
|
||||||
|
.map(|(k, v)| (k.as_str(), v))
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
gguf_file::write(&mut out_file, &[], &qtensors)?;
|
||||||
|
Ok(())
|
||||||
|
}

 fn run_quantize(
-    in_file: std::path::PathBuf,
+    in_files: &[std::path::PathBuf],
     out_file: std::path::PathBuf,
     q: Quantization,
     qmode: QuantizationMode,
 ) -> Result<()> {
+    if in_files.is_empty() {
+        candle_core::bail!("no specified input files")
+    }
+    if let Some(extension) = out_file.extension() {
+        if extension == "safetensors" {
+            candle_core::bail!("the generated file cannot use the safetensors extension")
+        }
+    }
+    if let Some(extension) = in_files[0].extension() {
+        if extension == "safetensors" {
+            return run_quantize_safetensors(in_files, out_file, q);
+        }
+    }
+    if in_files.len() != 1 {
+        candle_core::bail!("only a single in-file can be used when quantizing gguf files")
+    }
     // Open the out file early so as to fail directly on missing directories etc.
     let mut out_file = std::fs::File::create(out_file)?;
-    let mut in_ = std::fs::File::open(&in_file)?;
+    let mut in_ = std::fs::File::open(&in_files[0])?;
     let content = gguf_file::Content::read(&mut in_)?;
     println!("tensors: {}", content.tensor_infos.len());
@@ -252,7 +337,7 @@ fn run_quantize(
         .par_iter()
         .map(|(name, _)| {
             println!("  quantizing {name}");
-            let mut in_file = std::fs::File::open(&in_file)?;
+            let mut in_file = std::fs::File::open(&in_files[0])?;
             let tensor = content.tensor(&mut in_file, name)?;
             let tensor = qmode.quantize(name, tensor, quantize_fn)?;
             Ok((name, tensor))
@@ -293,7 +378,7 @@ fn main() -> anyhow::Result<()> {
             out_file,
             quantization,
             mode,
-        } => run_quantize(in_file, out_file, quantization, mode)?,
+        } => run_quantize(&in_file, out_file, quantization, mode)?,
     }
     Ok(())
 }
@@ -370,6 +370,38 @@ pub fn vd_sqr(a: &[f64], y: &mut [f64]) {
     y.iter_mut().zip(a.iter()).for_each(|(y, a)| *y = *a * *a)
 }

+#[inline]
+pub fn vs_tanh_inplace(y: &mut [f32]) {
+    unsafe { ffi::vvtanhf(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
+}
+
+#[inline]
+pub fn vd_tanh_inplace(y: &mut [f64]) {
+    unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
+}
+
+#[inline]
+pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
+    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+        *y = (2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
+    }
+    vs_tanh_inplace(ys);
+    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+        *y = 0.5 * v * (1.0 + *y)
+    }
+}
+
+#[inline]
+pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
+    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+        *y = (2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
+    }
+    vd_tanh_inplace(ys);
+    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+        *y = 0.5 * v * (1.0 + *y)
+    }
+}
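The three passes above compute the usual tanh approximation of GELU: the first loop writes $\sqrt{2/\pi}\,(x + 0.044715\,x^3)$ into the output buffer, the vectorized vvtanh/vvtanhf call replaces each value with its tanh in place, and the second loop blends with the input:

$$\mathrm{gelu}(x) \approx \frac{x}{2}\left(1 + \tanh\!\left(\sqrt{2/\pi}\,\bigl(x + 0.044715\,x^{3}\bigr)\right)\right)$$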
+
 macro_rules! binary_op {
     ($fn_name:ident, $ty:ty, $accelerate_name:ident) => {
         #[inline]
@@ -57,6 +57,7 @@ pub trait BackendStorage: Sized {

     fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
     fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
+    fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self>;
     fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;

     fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;

@@ -110,4 +111,6 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
     fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;

     fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
+
+    fn set_seed(&self, _: u64) -> Result<()>;
 }
@@ -69,7 +69,8 @@ impl Tensor {
                 | Op::Binary(lhs, rhs, _)
                 | Op::Gather(lhs, rhs, _)
                 | Op::IndexSelect(lhs, rhs, _)
-                | Op::Matmul(lhs, rhs) => {
+                | Op::Matmul(lhs, rhs)
+                | Op::SliceScatter0(lhs, rhs, _) => {
                     let (tg, nodes) = walk(lhs, nodes, already_seen);
                     track_grad |= tg;
                     let (tg, nodes) = walk(rhs, nodes, already_seen);

@@ -90,14 +91,18 @@ impl Tensor {
                         nodes
                     }
                 }
+                Op::Unary(_node, UnaryOp::Ceil)
+                | Op::Unary(_node, UnaryOp::Floor)
+                | Op::Unary(_node, UnaryOp::Round) => nodes,
                 Op::Reshape(node)
+                | Op::UpsampleNearest1D(node)
                 | Op::UpsampleNearest2D(node)
                 | Op::AvgPool2D { arg: node, .. }
                 | Op::MaxPool2D { arg: node, .. }
                 | Op::Copy(node)
                 | Op::Broadcast(node)
                 | Op::Cmp(node, _)
-                | Op::Reduce(node, _, _)
+                | Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
                 | Op::ToDType(node)
                 | Op::ToDevice(node)
                 | Op::Transpose(node, _, _)

@@ -111,6 +116,7 @@ impl Tensor {
                     track_grad |= tg;
                     nodes
                 }
+                Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
             }
         } else {
             nodes
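The new arms prune nodes through which no gradient can flow, presumably because ceil, floor and round are piecewise constant while argmin/argmax return indices rather than differentiable values:

$$\frac{d}{dx}\lfloor x \rfloor = \frac{d}{dx}\lceil x \rceil = \frac{d}{dx}\,\mathrm{round}(x) = 0 \quad \text{for all non-integer } x.$$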
@@ -262,9 +268,21 @@ impl Tensor {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.add(&grad_arg)?;
                 }
+                Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
+                    op: "upsample-nearest1d",
+                })?,
                 Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
                     op: "upsample-nearest2d",
                 })?,
+                Op::SliceScatter0(lhs, rhs, start_rhs) => {
+                    let rhs_sum_grad = grads.or_insert(rhs)?;
+                    let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
+                    *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+
+                    let lhs_sum_grad = grads.or_insert(lhs)?;
+                    let lhs_grad = grad.slice_scatter0(&rhs.zeros_like()?, *start_rhs)?;
+                    *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?
+                }
                 Op::Gather(arg, indexes, dim) => {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
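The SliceScatter0 rule above splits the incoming gradient $g$ of $y = \mathrm{slice\_scatter0}(\mathit{lhs}, \mathit{rhs}, s)$: the scattered rows flow back to rhs via narrow, and lhs receives $g$ with exactly those rows zeroed (achieved by scattering rhs.zeros_like() into $g$). With $r$ the first dimension of rhs:

$$\frac{\partial L}{\partial\,\mathit{rhs}} = g_{[s\,:\,s+r]}, \qquad \left(\frac{\partial L}{\partial\,\mathit{lhs}}\right)_{\!i} = \begin{cases} 0 & s \le i < s+r \\ g_i & \text{otherwise.} \end{cases}$$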
@@ -436,7 +454,18 @@ impl Tensor {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.add(&arg_grad)?
                 }
+                Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?,
+                Op::Unary(_, UnaryOp::Floor) => {
+                    Err(Error::BackwardNotSupported { op: "floor" })?
+                }
+                Op::Unary(_, UnaryOp::Round) => {
+                    Err(Error::BackwardNotSupported { op: "round" })?
+                }
                 Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
+                Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
+                Op::Unary(_, UnaryOp::GeluErf) => {
+                    Err(Error::BackwardNotSupported { op: "gelu-erf" })?
+                }
                 Op::Unary(arg, UnaryOp::Relu) => {
                     let sum_grad = grads.or_insert(arg)?;
                     let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;

@@ -517,6 +546,7 @@ impl Tensor {
     }
 }

+#[derive(Debug)]
 pub struct GradStore(HashMap<TensorId, Tensor>);

 impl GradStore {
@@ -25,6 +25,19 @@ impl ParamsConv1D {
     }
 }

+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub enum CudnnFwdAlgo {
+    ImplicitGemm,
+    ImplicitPrecompGemm,
+    Gemm,
+    Direct,
+    Fft,
+    FftTiling,
+    Winograd,
+    WinogradNonFused,
+    Count,
+}
+
 #[derive(Debug, Clone, PartialEq, Eq)]
 pub struct ParamsConv2D {
     pub(crate) b_size: usize,

@@ -37,6 +50,7 @@ pub struct ParamsConv2D {
     pub(crate) padding: usize,
     pub(crate) stride: usize,
     pub(crate) dilation: usize,
+    pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
 }

 impl ParamsConv2D {

@@ -188,6 +202,7 @@ impl Tensor {
             padding,
             stride,
             dilation,
+            cudnn_fwd_algo: None,
         };
         if groups == 1 {
             self.conv2d_single_group(kernel, &params)

candle-core/src/cpu/erf.rs (new file, 763 lines)
@@ -0,0 +1,763 @@
+#![allow(clippy::excessive_precision)]
+// Code taken from https://github.com/statrs-dev/statrs
+//! Provides the [error](https://en.wikipedia.org/wiki/Error_function) and
+//! related functions
+
+mod evaluate {
+    //! Provides functions that don't have a numerical solution and must
+    //! be solved computationally (e.g. evaluation of a polynomial)
+
+    /// evaluates a polynomial at `z` where `coeff` are the coefficients
+    /// to a polynomial of order `k` where `k` is the length of `coeff` and the
+    /// coefficient
+    /// to the `k`th power is the `k`th element in coeff. E.g. [3,-1,2] equates to
+    /// `2z^2 - z + 3`
+    ///
+    /// # Remarks
+    ///
+    /// Returns 0 for a 0 length coefficient slice
+    pub fn polynomial(z: f64, coeff: &[f64]) -> f64 {
+        let n = coeff.len();
+        if n == 0 {
+            return 0.0;
+        }
+
+        let mut sum = *coeff.last().unwrap();
+        for c in coeff[0..n - 1].iter().rev() {
+            sum = *c + z * sum;
+        }
+        sum
+    }
+}
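The loop is Horner's rule. Working through the doc comment's own example coeff = [3, -1, 2] (i.e. 2z² − z + 3) at z = 2: sum starts at 2, then −1 + 2·2 = 3, then 3 + 2·3 = 9. A quick check in that spirit (the test name is illustrative, not part of the diff):

#[test]
fn polynomial_horner_example() {
    // [3, -1, 2] encodes 2z^2 - z + 3; at z = 2 this is 8 - 2 + 3 = 9.
    assert_eq!(evaluate::polynomial(2.0, &[3.0, -1.0, 2.0]), 9.0);
}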
+use std::f64;
+
+/// `erf` calculates the error function at `x`.
+pub fn erf(x: f64) -> f64 {
+    if x.is_nan() {
+        f64::NAN
+    } else if x >= 0.0 && x.is_infinite() {
+        1.0
+    } else if x <= 0.0 && x.is_infinite() {
+        -1.0
+    } else if x == 0. {
+        0.0
+    } else {
+        erf_impl(x, false)
+    }
+}
+
+/// `erf_inv` calculates the inverse error function
+/// at `x`.
+pub fn erf_inv(x: f64) -> f64 {
+    if x == 0.0 {
+        0.0
+    } else if x >= 1.0 {
+        f64::INFINITY
+    } else if x <= -1.0 {
+        f64::NEG_INFINITY
+    } else if x < 0.0 {
+        erf_inv_impl(-x, 1.0 + x, -1.0)
+    } else {
+        erf_inv_impl(x, 1.0 - x, 1.0)
+    }
+}
+
+/// `erfc` calculates the complementary error function
+/// at `x`.
+pub fn erfc(x: f64) -> f64 {
+    if x.is_nan() {
+        f64::NAN
+    } else if x == f64::INFINITY {
+        0.0
+    } else if x == f64::NEG_INFINITY {
+        2.0
+    } else {
+        erf_impl(x, true)
+    }
+}
+
+/// `erfc_inv` calculates the complementary inverse
+/// error function at `x`.
+pub fn erfc_inv(x: f64) -> f64 {
+    if x <= 0.0 {
+        f64::INFINITY
+    } else if x >= 2.0 {
+        f64::NEG_INFINITY
+    } else if x > 1.0 {
+        erf_inv_impl(-1.0 + x, 2.0 - x, -1.0)
+    } else {
+        erf_inv_impl(1.0 - x, x, 1.0)
+    }
+}
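For reference, the four entry points above implement the standard definitions, with erf_inv and erfc_inv the functional inverses (hence the $\pm\infty$ special cases at the domain edges):

$$\operatorname{erf}(x) = \frac{2}{\sqrt{\pi}} \int_0^x e^{-t^2}\,dt, \qquad \operatorname{erfc}(x) = 1 - \operatorname{erf}(x).$$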
+
+// **********************************************************
+// ********** Coefficients for erf_impl polynomial **********
+// **********************************************************
+
+/// Polynomial coefficients for a numerator of `erf_impl`
+/// in the interval [1e-10, 0.5].
+const ERF_IMPL_AN: &[f64] = &[
+    0.00337916709551257388990745,
+    -0.00073695653048167948530905,
+    -0.374732337392919607868241,
+    0.0817442448733587196071743,
+    -0.0421089319936548595203468,
+    0.0070165709512095756344528,
+    -0.00495091255982435110337458,
+    0.000871646599037922480317225,
+];
+
+/// Polynomial coefficients for a denominator of `erf_impl`
+/// in the interval [1e-10, 0.5]
+const ERF_IMPL_AD: &[f64] = &[
+    1.0,
+    -0.218088218087924645390535,
+    0.412542972725442099083918,
+    -0.0841891147873106755410271,
+    0.0655338856400241519690695,
+    -0.0120019604454941768171266,
+    0.00408165558926174048329689,
+    -0.000615900721557769691924509,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [0.5, 0.75].
+const ERF_IMPL_BN: &[f64] = &[
+    -0.0361790390718262471360258,
+    0.292251883444882683221149,
+    0.281447041797604512774415,
+    0.125610208862766947294894,
+    0.0274135028268930549240776,
+    0.00250839672168065762786937,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [0.5, 0.75].
+const ERF_IMPL_BD: &[f64] = &[
+    1.0,
+    1.8545005897903486499845,
+    1.43575803037831418074962,
+    0.582827658753036572454135,
+    0.124810476932949746447682,
+    0.0113724176546353285778481,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [0.75, 1.25].
+const ERF_IMPL_CN: &[f64] = &[
+    -0.0397876892611136856954425,
+    0.153165212467878293257683,
+    0.191260295600936245503129,
+    0.10276327061989304213645,
+    0.029637090615738836726027,
+    0.0046093486780275489468812,
+    0.000307607820348680180548455,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [0.75, 1.25].
+const ERF_IMPL_CD: &[f64] = &[
+    1.0,
+    1.95520072987627704987886,
+    1.64762317199384860109595,
+    0.768238607022126250082483,
+    0.209793185936509782784315,
+    0.0319569316899913392596356,
+    0.00213363160895785378615014,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [1.25, 2.25].
+const ERF_IMPL_DN: &[f64] = &[
+    -0.0300838560557949717328341,
+    0.0538578829844454508530552,
+    0.0726211541651914182692959,
+    0.0367628469888049348429018,
+    0.00964629015572527529605267,
+    0.00133453480075291076745275,
+    0.778087599782504251917881e-4,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [1.25, 2.25].
+const ERF_IMPL_DD: &[f64] = &[
+    1.0,
+    1.75967098147167528287343,
+    1.32883571437961120556307,
+    0.552528596508757581287907,
+    0.133793056941332861912279,
+    0.0179509645176280768640766,
+    0.00104712440019937356634038,
+    -0.106640381820357337177643e-7,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [2.25, 3.5].
+const ERF_IMPL_EN: &[f64] = &[
+    -0.0117907570137227847827732,
+    0.014262132090538809896674,
+    0.0202234435902960820020765,
+    0.00930668299990432009042239,
+    0.00213357802422065994322516,
+    0.00025022987386460102395382,
+    0.120534912219588189822126e-4,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [2.25, 3.5].
+const ERF_IMPL_ED: &[f64] = &[
+    1.0,
+    1.50376225203620482047419,
+    0.965397786204462896346934,
+    0.339265230476796681555511,
+    0.0689740649541569716897427,
+    0.00771060262491768307365526,
+    0.000371421101531069302990367,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [3.5, 5.25].
+const ERF_IMPL_FN: &[f64] = &[
+    -0.00546954795538729307482955,
+    0.00404190278731707110245394,
+    0.0054963369553161170521356,
+    0.00212616472603945399437862,
+    0.000394984014495083900689956,
+    0.365565477064442377259271e-4,
+    0.135485897109932323253786e-5,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [3.5, 5.25].
+const ERF_IMPL_FD: &[f64] = &[
+    1.0,
+    1.21019697773630784832251,
+    0.620914668221143886601045,
+    0.173038430661142762569515,
+    0.0276550813773432047594539,
+    0.00240625974424309709745382,
+    0.891811817251336577241006e-4,
+    -0.465528836283382684461025e-11,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [5.25, 8].
+const ERF_IMPL_GN: &[f64] = &[
+    -0.00270722535905778347999196,
+    0.0013187563425029400461378,
+    0.00119925933261002333923989,
+    0.00027849619811344664248235,
+    0.267822988218331849989363e-4,
+    0.923043672315028197865066e-6,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [5.25, 8].
+const ERF_IMPL_GD: &[f64] = &[
+    1.0,
+    0.814632808543141591118279,
+    0.268901665856299542168425,
+    0.0449877216103041118694989,
+    0.00381759663320248459168994,
+    0.000131571897888596914350697,
+    0.404815359675764138445257e-11,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [8, 11.5].
+const ERF_IMPL_HN: &[f64] = &[
+    -0.00109946720691742196814323,
+    0.000406425442750422675169153,
+    0.000274499489416900707787024,
+    0.465293770646659383436343e-4,
+    0.320955425395767463401993e-5,
+    0.778286018145020892261936e-7,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [8, 11.5].
+const ERF_IMPL_HD: &[f64] = &[
+    1.0,
+    0.588173710611846046373373,
+    0.139363331289409746077541,
+    0.0166329340417083678763028,
+    0.00100023921310234908642639,
+    0.24254837521587225125068e-4,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [11.5, 17].
+const ERF_IMPL_IN: &[f64] = &[
+    -0.00056907993601094962855594,
+    0.000169498540373762264416984,
+    0.518472354581100890120501e-4,
+    0.382819312231928859704678e-5,
+    0.824989931281894431781794e-7,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [11.5, 17].
+const ERF_IMPL_ID: &[f64] = &[
+    1.0,
+    0.339637250051139347430323,
+    0.043472647870310663055044,
+    0.00248549335224637114641629,
+    0.535633305337152900549536e-4,
+    -0.117490944405459578783846e-12,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [17, 24].
+const ERF_IMPL_JN: &[f64] = &[
+    -0.000241313599483991337479091,
+    0.574224975202501512365975e-4,
+    0.115998962927383778460557e-4,
+    0.581762134402593739370875e-6,
+    0.853971555085673614607418e-8,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [17, 24].
+const ERF_IMPL_JD: &[f64] = &[
+    1.0,
+    0.233044138299687841018015,
+    0.0204186940546440312625597,
+    0.000797185647564398289151125,
+    0.117019281670172327758019e-4,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [24, 38].
+const ERF_IMPL_KN: &[f64] = &[
+    -0.000146674699277760365803642,
+    0.162666552112280519955647e-4,
+    0.269116248509165239294897e-5,
+    0.979584479468091935086972e-7,
+    0.101994647625723465722285e-8,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [24, 38].
+const ERF_IMPL_KD: &[f64] = &[
+    1.0,
+    0.165907812944847226546036,
+    0.0103361716191505884359634,
+    0.000286593026373868366935721,
+    0.298401570840900340874568e-5,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [38, 60].
+const ERF_IMPL_LN: &[f64] = &[
+    -0.583905797629771786720406e-4,
+    0.412510325105496173512992e-5,
+    0.431790922420250949096906e-6,
+    0.993365155590013193345569e-8,
+    0.653480510020104699270084e-10,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [38, 60].
+const ERF_IMPL_LD: &[f64] = &[
+    1.0,
+    0.105077086072039915406159,
+    0.00414278428675475620830226,
+    0.726338754644523769144108e-4,
+    0.477818471047398785369849e-6,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [60, 85].
+const ERF_IMPL_MN: &[f64] = &[
+    -0.196457797609229579459841e-4,
+    0.157243887666800692441195e-5,
+    0.543902511192700878690335e-7,
+    0.317472492369117710852685e-9,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [60, 85].
+const ERF_IMPL_MD: &[f64] = &[
+    1.0,
+    0.052803989240957632204885,
+    0.000926876069151753290378112,
+    0.541011723226630257077328e-5,
+    0.535093845803642394908747e-15,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [85, 110].
+const ERF_IMPL_NN: &[f64] = &[
+    -0.789224703978722689089794e-5,
+    0.622088451660986955124162e-6,
+    0.145728445676882396797184e-7,
+    0.603715505542715364529243e-10,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [85, 110].
+const ERF_IMPL_ND: &[f64] = &[
+    1.0,
+    0.0375328846356293715248719,
+    0.000467919535974625308126054,
+    0.193847039275845656900547e-5,
+];
+
+// **********************************************************
+// ********** Coefficients for erf_inv_impl polynomial ******
+// **********************************************************
+
+/// Polynomial coefficients for a numerator of `erf_inv_impl`
+/// in the interval [0, 0.5].
+const ERF_INV_IMPL_AN: &[f64] = &[
+    -0.000508781949658280665617,
+    -0.00836874819741736770379,
+    0.0334806625409744615033,
+    -0.0126926147662974029034,
+    -0.0365637971411762664006,
+    0.0219878681111168899165,
+    0.00822687874676915743155,
+    -0.00538772965071242932965,
+];
+
+/// Polynomial coefficients for a denominator of `erf_inv_impl`
+/// in the interval [0, 0.5].
+const ERF_INV_IMPL_AD: &[f64] = &[
+    1.0,
+    -0.970005043303290640362,
+    -1.56574558234175846809,
+    1.56221558398423026363,
+    0.662328840472002992063,
+    -0.71228902341542847553,
+    -0.0527396382340099713954,
+    0.0795283687341571680018,
+    -0.00233393759374190016776,
+    0.000886216390456424707504,
+];
+
+/// Polynomial coefficients for a numerator of `erf_inv_impl`
+/// in the interval [0.5, 0.75].
+const ERF_INV_IMPL_BN: &[f64] = &[
+    -0.202433508355938759655,
+    0.105264680699391713268,
+    8.37050328343119927838,
+    17.6447298408374015486,
+    -18.8510648058714251895,
+    -44.6382324441786960818,
+    17.445385985570866523,
+    21.1294655448340526258,
+    -3.67192254707729348546,
+];
+
+/// Polynomial coefficients for a denominator of `erf_inv_impl`
+/// in the interval [0.5, 0.75].
+const ERF_INV_IMPL_BD: &[f64] = &[
+    1.0,
+    6.24264124854247537712,
+    3.9713437953343869095,
+    -28.6608180499800029974,
+    -20.1432634680485188801,
+    48.5609213108739935468,
+    10.8268667355460159008,
+    -22.6436933413139721736,
+    1.72114765761200282724,
+];
+
+/// Polynomial coefficients for a numerator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x less than 3.
+const ERF_INV_IMPL_CN: &[f64] = &[
+    -0.131102781679951906451,
+    -0.163794047193317060787,
+    0.117030156341995252019,
+    0.387079738972604337464,
+    0.337785538912035898924,
+    0.142869534408157156766,
+    0.0290157910005329060432,
+    0.00214558995388805277169,
+    -0.679465575181126350155e-6,
+    0.285225331782217055858e-7,
+    -0.681149956853776992068e-9,
+];
+
+/// Polynomial coefficients for a denominator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x less than 3.
+const ERF_INV_IMPL_CD: &[f64] = &[
+    1.0,
+    3.46625407242567245975,
+    5.38168345707006855425,
+    4.77846592945843778382,
+    2.59301921623620271374,
+    0.848854343457902036425,
+    0.152264338295331783612,
+    0.01105924229346489121,
+];
+
+/// Polynomial coefficients for a numerator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x between 3 and 6.
+const ERF_INV_IMPL_DN: &[f64] = &[
+    -0.0350353787183177984712,
+    -0.00222426529213447927281,
+    0.0185573306514231072324,
+    0.00950804701325919603619,
+    0.00187123492819559223345,
+    0.000157544617424960554631,
+    0.460469890584317994083e-5,
+    -0.230404776911882601748e-9,
+    0.266339227425782031962e-11,
+];
+
+/// Polynomial coefficients for a denominator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x between 3 and 6.
+const ERF_INV_IMPL_DD: &[f64] = &[
+    1.0,
+    1.3653349817554063097,
+    0.762059164553623404043,
+    0.220091105764131249824,
+    0.0341589143670947727934,
+    0.00263861676657015992959,
+    0.764675292302794483503e-4,
+];
+
+/// Polynomial coefficients for a numerator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x between 6 and 18.
+const ERF_INV_IMPL_EN: &[f64] = &[
+    -0.0167431005076633737133,
+    -0.00112951438745580278863,
+    0.00105628862152492910091,
+    0.000209386317487588078668,
+    0.149624783758342370182e-4,
+    0.449696789927706453732e-6,
+    0.462596163522878599135e-8,
+    -0.281128735628831791805e-13,
+    0.99055709973310326855e-16,
+];
+
+/// Polynomial coefficients for a denominator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x between 6 and 18.
+const ERF_INV_IMPL_ED: &[f64] = &[
+    1.0,
+    0.591429344886417493481,
+    0.138151865749083321638,
+    0.0160746087093676504695,
+    0.000964011807005165528527,
+    0.275335474764726041141e-4,
+    0.282243172016108031869e-6,
+];
+
+/// Polynomial coefficients for a numerator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x between 18 and 44.
+const ERF_INV_IMPL_FN: &[f64] = &[
+    -0.0024978212791898131227,
+    -0.779190719229053954292e-5,
+    0.254723037413027451751e-4,
+    0.162397777342510920873e-5,
+    0.396341011304801168516e-7,
+    0.411632831190944208473e-9,
+    0.145596286718675035587e-11,
+    -0.116765012397184275695e-17,
+];
+
+/// Polynomial coefficients for a denominator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x between 18 and 44.
+const ERF_INV_IMPL_FD: &[f64] = &[
+    1.0,
+    0.207123112214422517181,
+    0.0169410838120975906478,
+    0.000690538265622684595676,
+    0.145007359818232637924e-4,
+    0.144437756628144157666e-6,
+    0.509761276599778486139e-9,
+];
+
+/// Polynomial coefficients for a numerator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x greater than 44.
+const ERF_INV_IMPL_GN: &[f64] = &[
+    -0.000539042911019078575891,
+    -0.28398759004727721098e-6,
+    0.899465114892291446442e-6,
+    0.229345859265920864296e-7,
+    0.225561444863500149219e-9,
+    0.947846627503022684216e-12,
+    0.135880130108924861008e-14,
+    -0.348890393399948882918e-21,
+];
+
+/// Polynomial coefficients for a denominator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x greater than 44.
+const ERF_INV_IMPL_GD: &[f64] = &[
+    1.0,
+    0.0845746234001899436914,
+    0.00282092984726264681981,
+    0.468292921940894236786e-4,
+    0.399968812193862100054e-6,
+    0.161809290887904476097e-8,
+    0.231558608310259605225e-11,
+];
+
+/// `erf_impl` computes the error function at `z`.
+/// If `inv` is true, `1 - erf` is calculated as opposed to `erf`
+fn erf_impl(z: f64, inv: bool) -> f64 {
+    if z < 0.0 {
+        if !inv {
+            return -erf_impl(-z, false);
+        }
+        if z < -0.5 {
+            return 2.0 - erf_impl(-z, true);
+        }
+        return 1.0 + erf_impl(-z, false);
+    }
+
+    let result = if z < 0.5 {
+        if z < 1e-10 {
+            z * 1.125 + z * 0.003379167095512573896158903121545171688
+        } else {
+            z * 1.125
+                + z * evaluate::polynomial(z, ERF_IMPL_AN) / evaluate::polynomial(z, ERF_IMPL_AD)
+        }
+    } else if z < 110.0 {
+        let (r, b) = if z < 0.75 {
+            (
+                evaluate::polynomial(z - 0.5, ERF_IMPL_BN)
+                    / evaluate::polynomial(z - 0.5, ERF_IMPL_BD),
+                0.3440242112,
+            )
+        } else if z < 1.25 {
+            (
+                evaluate::polynomial(z - 0.75, ERF_IMPL_CN)
+                    / evaluate::polynomial(z - 0.75, ERF_IMPL_CD),
+                0.419990927,
+            )
+        } else if z < 2.25 {
+            (
+                evaluate::polynomial(z - 1.25, ERF_IMPL_DN)
+                    / evaluate::polynomial(z - 1.25, ERF_IMPL_DD),
+                0.4898625016,
+            )
+        } else if z < 3.5 {
+            (
+                evaluate::polynomial(z - 2.25, ERF_IMPL_EN)
+                    / evaluate::polynomial(z - 2.25, ERF_IMPL_ED),
+                0.5317370892,
+            )
+        } else if z < 5.25 {
+            (
+                evaluate::polynomial(z - 3.5, ERF_IMPL_FN)
+                    / evaluate::polynomial(z - 3.5, ERF_IMPL_FD),
+                0.5489973426,
+            )
+        } else if z < 8.0 {
+            (
+                evaluate::polynomial(z - 5.25, ERF_IMPL_GN)
+                    / evaluate::polynomial(z - 5.25, ERF_IMPL_GD),
+                0.5571740866,
+            )
+        } else if z < 11.5 {
+            (
+                evaluate::polynomial(z - 8.0, ERF_IMPL_HN)
+                    / evaluate::polynomial(z - 8.0, ERF_IMPL_HD),
+                0.5609807968,
+            )
+        } else if z < 17.0 {
+            (
+                evaluate::polynomial(z - 11.5, ERF_IMPL_IN)
+                    / evaluate::polynomial(z - 11.5, ERF_IMPL_ID),
+                0.5626493692,
+            )
+        } else if z < 24.0 {
+            (
+                evaluate::polynomial(z - 17.0, ERF_IMPL_JN)
+                    / evaluate::polynomial(z - 17.0, ERF_IMPL_JD),
+                0.5634598136,
+            )
+        } else if z < 38.0 {
+            (
+                evaluate::polynomial(z - 24.0, ERF_IMPL_KN)
+                    / evaluate::polynomial(z - 24.0, ERF_IMPL_KD),
+                0.5638477802,
+            )
+        } else if z < 60.0 {
+            (
+                evaluate::polynomial(z - 38.0, ERF_IMPL_LN)
+                    / evaluate::polynomial(z - 38.0, ERF_IMPL_LD),
+                0.5640528202,
+            )
+        } else if z < 85.0 {
+            (
+                evaluate::polynomial(z - 60.0, ERF_IMPL_MN)
+                    / evaluate::polynomial(z - 60.0, ERF_IMPL_MD),
+                0.5641309023,
+            )
+        } else {
+            (
+                evaluate::polynomial(z - 85.0, ERF_IMPL_NN)
+                    / evaluate::polynomial(z - 85.0, ERF_IMPL_ND),
+                0.5641584396,
+            )
+        };
+        let g = (-z * z).exp() / z;
+        g * b + g * r
+    } else {
+        0.0
+    };
+
+    if inv && z >= 0.5 {
+        result
+    } else if z >= 0.5 || inv {
+        1.0 - result
+    } else {
+        result
+    }
+}
+
+// `erf_inv_impl` computes the inverse error function where
+// `p`,`q`, and `s` are the first, second, and third intermediate
+// parameters respectively
+fn erf_inv_impl(p: f64, q: f64, s: f64) -> f64 {
+    let result = if p <= 0.5 {
+        let y = 0.0891314744949340820313;
+        let g = p * (p + 10.0);
+        let r = evaluate::polynomial(p, ERF_INV_IMPL_AN) / evaluate::polynomial(p, ERF_INV_IMPL_AD);
+        g * y + g * r
+    } else if q >= 0.25 {
+        let y = 2.249481201171875;
+        let g = (-2.0 * q.ln()).sqrt();
+        let xs = q - 0.25;
+        let r =
+            evaluate::polynomial(xs, ERF_INV_IMPL_BN) / evaluate::polynomial(xs, ERF_INV_IMPL_BD);
+        g / (y + r)
+    } else {
+        let x = (-q.ln()).sqrt();
+        if x < 3.0 {
+            let y = 0.807220458984375;
+            let xs = x - 1.125;
+            let r = evaluate::polynomial(xs, ERF_INV_IMPL_CN)
+                / evaluate::polynomial(xs, ERF_INV_IMPL_CD);
+            y * x + r * x
+        } else if x < 6.0 {
+            let y = 0.93995571136474609375;
+            let xs = x - 3.0;
+            let r = evaluate::polynomial(xs, ERF_INV_IMPL_DN)
+                / evaluate::polynomial(xs, ERF_INV_IMPL_DD);
+            y * x + r * x
+        } else if x < 18.0 {
+            let y = 0.98362827301025390625;
+            let xs = x - 6.0;
+            let r = evaluate::polynomial(xs, ERF_INV_IMPL_EN)
+                / evaluate::polynomial(xs, ERF_INV_IMPL_ED);
+            y * x + r * x
+        } else if x < 44.0 {
+            let y = 0.99714565277099609375;
+            let xs = x - 18.0;
+            let r = evaluate::polynomial(xs, ERF_INV_IMPL_FN)
+                / evaluate::polynomial(xs, ERF_INV_IMPL_FD);
+            y * x + r * x
+        } else {
+            let y = 0.99941349029541015625;
+            let xs = x - 44.0;
+            let r = evaluate::polynomial(xs, ERF_INV_IMPL_GN)
+                / evaluate::polynomial(xs, ERF_INV_IMPL_GD);
+            y * x + r * x
+        }
+    };
+    s * result
+}
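A quick sanity check on the pair, as a sketch rather than part of the diff (the tolerance is picked arbitrarily, and the example assumes erf and erf_inv are in scope):

fn main() {
    for &x in &[-0.99, -0.5, -0.1, 0.0, 0.3, 0.7, 0.99] {
        // erf_inv is the functional inverse of erf on (-1, 1).
        assert!((erf(erf_inv(x)) - x).abs() < 1e-12);
    }
}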

@@ -1,3 +1,4 @@
+pub mod erf;
 pub mod kernels;

 trait Cpu<const ARR: usize> {
@@ -4,6 +4,9 @@ use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
 use half::{bf16, f16};
 use rayon::prelude::*;

+const USE_IM2COL_CONV1D: bool = true;
+const USE_IM2COL_CONV2D: bool = true;
+
 // TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
 // intercept the oom errors to avoid panicking and provide a proper error.
 #[derive(Debug, Clone)]

@@ -724,6 +727,36 @@ impl Map1 for MaxPool2D {
         }
     }
 }

+struct UpsampleNearest1D(usize);
+
+impl Map1 for UpsampleNearest1D {
+    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
+        // TODO: Specialized implementation for the case 2*sz?
+        let dst_sz = self.0;
+        let (b_sz, c, src_sz) = layout.shape().dims3()?;
+        let stride = layout.stride();
+        let stride_sz = stride[2];
+        let src_index = layout.start_offset();
+        let scale_sz = src_sz as f64 / dst_sz as f64;
+        let mut dst = vec![T::zero(); b_sz * c * dst_sz];
+        let src_idxs = (0..dst_sz)
+            .map(|idx| usize::min(src_sz - 1, (idx as f64 * scale_sz) as usize))
+            .collect::<Vec<_>>();
+        for b_idx in 0..b_sz {
+            let dst = &mut dst[b_idx * c * dst_sz..];
+            let src_index = src_index + b_idx * stride[0];
+            for c_idx in 0..c {
+                let dst = &mut dst[c_idx * dst_sz..];
+                let src_index = src_index + c_idx * stride[1];
+                for (idx, src_idx) in src_idxs.iter().enumerate() {
+                    dst[idx] = src[src_index + src_idx * stride_sz]
+                }
+            }
+        }
+        Ok(dst)
+    }
+}
+
 struct UpsampleNearest2D(usize, usize);

 impl Map1 for UpsampleNearest2D {
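The src_idxs table in UpsampleNearest1D above implements floor-based nearest-neighbor sampling: output position idx reads source position min(src_sz − 1, ⌊idx · src_sz / dst_sz⌋). A tiny standalone check of the mapping, with the expected values worked out by hand:

fn main() {
    let (src_sz, dst_sz) = (4usize, 6usize);
    let scale_sz = src_sz as f64 / dst_sz as f64;
    let src_idxs: Vec<usize> = (0..dst_sz)
        .map(|idx| usize::min(src_sz - 1, (idx as f64 * scale_sz) as usize))
        .collect();
    // Upsampling 4 -> 6 repeats source elements 0 and 2.
    assert_eq!(src_idxs, vec![0, 0, 1, 2, 2, 3]);
}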
@@ -1089,6 +1122,140 @@ impl<'a> Map2 for Conv1D<'a> {
         }
     }
 }

+struct Im2Col1D {
+    l_k: usize,
+    stride: usize,
+    dilation: usize,
+    padding: usize,
+}
+
+impl Im2Col1D {
+    fn l_out(&self, l: usize) -> usize {
+        (l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
+    }
+}
+
+impl Map1 for Im2Col1D {
+    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
+        let &Self {
+            l_k,
+            stride,
+            dilation,
+            padding,
+        } = self;
+        let (b, c, l) = layout.shape().dims3()?;
+        let l_out = self.l_out(l);
+        let src = &vs[layout.start_offset()..];
+        let mut dst = vec![T::zero(); b * l_out * c * l_k];
+        let (src_s0, src_s1, src_s2) = {
+            let s = layout.stride();
+            (s[0], s[1], s[2])
+        };
+        // TODO: provide specialized kernels for the common use cases.
+        // - l_k = 1
+        // - padding = 0
+        // - stride = 1
+        // - dilation = 1
+        for b_idx in 0..b {
+            let src_idx = b_idx * src_s0;
+            let dst_idx = b_idx * l_out * c * l_k;
+            for l_idx in 0..l_out {
+                let dst_idx = dst_idx + l_idx * c * l_k;
+                for c_idx in 0..c {
+                    let dst_idx = dst_idx + c_idx * l_k;
+                    let src_idx = c_idx * src_s1 + src_idx;
+                    for l_k_idx in 0..l_k {
+                        let src_l = l_idx * stride + l_k_idx * dilation;
+                        if padding != 0 && (src_l < padding || src_l >= l + padding) {
+                            continue;
+                        }
+                        let src_l = src_l - padding;
+                        let src_idx = src_idx + src_l * src_s2;
+                        let dst_idx = dst_idx + l_k_idx;
+                        dst[dst_idx] = src[src_idx]
+                    }
+                }
+            }
+        }
+        Ok(dst)
+    }
+}
+
+struct Im2Col {
+    h_k: usize,
+    w_k: usize,
+    stride: usize,
+    dilation: usize,
+    padding: usize,
+}
+
+impl Im2Col {
+    fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
+        let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
+        let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
+        (h_out, w_out)
+    }
+}
+
+impl Map1 for Im2Col {
+    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
+        let &Self {
+            h_k,
+            w_k,
+            stride,
+            dilation,
+            padding,
+        } = self;
+        let (b, c, h, w) = layout.shape().dims4()?;
+        let (h_out, w_out) = self.hw_out(h, w);
+        let src = &vs[layout.start_offset()..];
+        let mut dst = vec![T::zero(); b * h_out * w_out * c * h_k * w_k];
+        let (src_s0, src_s1, src_s2, src_s3) = {
+            let s = layout.stride();
+            (s[0], s[1], s[2], s[3])
+        };
+        // TODO: provide specialized kernels for the common use cases.
+        // - h_k = w_k = 1
+        // - padding = 0
+        // - stride = 1
+        // - dilation = 1
+        for b_idx in 0..b {
+            let src_idx = b_idx * src_s0;
+            let dst_idx = b_idx * h_out * w_out * c * h_k * w_k;
+            for h_idx in 0..h_out {
+                let dst_idx = dst_idx + h_idx * w_out * c * h_k * w_k;
+                for w_idx in 0..w_out {
+                    let dst_idx = dst_idx + w_idx * c * h_k * w_k;
+                    for c_idx in 0..c {
+                        let dst_idx = dst_idx + c_idx * h_k * w_k;
+                        let src_idx = c_idx * src_s1 + src_idx;
+                        for h_k_idx in 0..h_k {
+                            let src_h = h_idx * stride + h_k_idx * dilation;
+                            if padding != 0 && (src_h < padding || src_h >= h + padding) {
+                                continue;
+                            }
+                            let src_h = src_h - padding;
+                            let src_idx = src_idx + src_h * src_s2;
+                            let dst_idx = dst_idx + h_k_idx * w_k;
+                            for w_k_idx in 0..w_k {
+                                let src_w = w_idx * stride + w_k_idx * dilation;
+                                if padding != 0 && (src_w < padding || src_w >= w + padding) {
+                                    continue;
+                                }
+                                let src_w = src_w - padding;
+                                let src_idx = src_idx + src_w * src_s3;
+                                let dst_idx = dst_idx + w_k_idx;
+                                dst[dst_idx] = src[src_idx]
+                            }
+                        }
+                    }
+                }
+            }
+        }
+        Ok(dst)
+    }
+}
+
 struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);

 impl<'a> Map2 for Conv2D<'a> {
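Both l_out and hw_out in the hunk above are the standard convolution output-size formula for input length $n$, kernel size $k$, padding $p$, stride $s$ and dilation $d$:

$$n_{\text{out}} = \left\lfloor \frac{n + 2p - d(k-1) - 1}{s} \right\rfloor + 1,$$

e.g. $n = 32$, $k = 3$, $p = 1$, $s = d = 1$ gives $n_{\text{out}} = 32$.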
@@ -1294,8 +1461,9 @@ impl Map2 for MatMul {
     ) -> Result<Vec<T>> {
         use gemm::{gemm, Parallelism};

-        if T::DTYPE == DType::BF16 {
-            return Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?;
+        match T::DTYPE {
+            DType::F16 | DType::F32 | DType::F64 => {}
+            _ => Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?,
         }

         let (b, m, n, k) = self.0;
@@ -1999,6 +2167,10 @@ impl BackendStorage for CpuStorage {
         MaxPool2D(kernel_size, stride).map(self, layout)
     }

+    fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
+        UpsampleNearest1D(sz).map(self, layout)
+    }
+
     fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
         UpsampleNearest2D(h, w).map(self, layout)
     }
@@ -2227,7 +2399,40 @@ impl BackendStorage for CpuStorage {
         kernel_l: &Layout,
         params: &crate::conv::ParamsConv1D,
     ) -> Result<Self> {
-        Conv1D(params).map(self, l, kernel, kernel_l)
+        if !USE_IM2COL_CONV1D {
+            return Conv1D(params).map(self, l, kernel, kernel_l);
+        }
+        let op = Im2Col1D {
+            l_k: params.k_size,
+            padding: params.padding,
+            stride: params.stride,
+            dilation: params.dilation,
+        };
+        let col = op.map(self, l)?;
+        let b = params.b_size;
+        let n = params.c_out;
+        let l_out = params.l_out();
+        let k = op.l_k * params.c_in;
+        let m = l_out;
+        let col_l = Layout::contiguous((b, m, k));
+        let res = if kernel_l.is_contiguous() {
+            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+                .transpose(1, 2)?
+                .broadcast_as((b, k, n))?;
+            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+        } else {
+            // Make the kernel contiguous if not already the case.
+            let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+            kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
+            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+                .transpose(1, 2)?
+                .broadcast_as((b, k, n))?;
+            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+        };
+        let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?;
+        let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+        res.copy_strided_src(&mut res_t, 0, &res_l)?;
+        Ok(res_t)
     }

     fn conv2d(
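With the im2col layout, the convolution above reduces to a batched matmul: col has shape $(b, m, k) = (b, l_{\text{out}}, c_{\text{in}} l_k)$, the kernel is viewed as $(k, n) = (c_{\text{in}} l_k, c_{\text{out}})$ and broadcast over the batch, so the product is $(b, l_{\text{out}}, c_{\text{out}})$, and the final copy_strided_src materializes the transposed $(b, c_{\text{out}}, l_{\text{out}})$ layout callers expect. The conv2d hunk below follows the same pattern with $m = h_{\text{out}} w_{\text{out}}$.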
@@ -2237,7 +2442,43 @@ impl BackendStorage for CpuStorage {
         kernel_l: &Layout,
         params: &crate::conv::ParamsConv2D,
     ) -> Result<Self> {
-        Conv2D(params).map(self, l, kernel, kernel_l)
+        if !USE_IM2COL_CONV2D {
+            return Conv2D(params).map(self, l, kernel, kernel_l);
+        }
+        let op = Im2Col {
+            h_k: params.k_h,
+            w_k: params.k_w,
+            padding: params.padding,
+            stride: params.stride,
+            dilation: params.dilation,
+        };
+        let col = op.map(self, l)?;
+        let b = params.b_size;
+        let n = params.c_out;
+        let (h_out, w_out) = (params.out_h(), params.out_w());
+        let k = op.h_k * op.w_k * params.c_in;
+        let m = h_out * w_out;
+        let col_l = Layout::contiguous((b, m, k));
+        let res = if kernel_l.is_contiguous() {
+            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+                .transpose(1, 2)?
+                .broadcast_as((b, k, n))?;
+            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+        } else {
+            // Make the kernel contiguous if not already the case.
+            let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+            kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
+            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+                .transpose(1, 2)?
+                .broadcast_as((b, k, n))?;
+            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+        };
+        let res_l = Layout::contiguous((b, h_out, w_out, params.c_out))
+            .transpose(1, 2)?
+            .transpose(1, 3)?;
+        let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+        res.copy_strided_src(&mut res_t, 0, &res_l)?;
+        Ok(res_t)
     }

     fn conv_transpose2d(
@@ -2362,6 +2603,10 @@ impl BackendDevice for CpuDevice {
         Ok(Self)
     }

+    fn set_seed(&self, _seed: u64) -> Result<()> {
+        crate::bail!("cannot seed the CPU rng with set_seed")
+    }
+
     fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result<CpuStorage> {
         use rand::prelude::*;
@@ -223,6 +223,14 @@ impl BackendDevice for CudaDevice {
         })
     }

+    fn set_seed(&self, seed: u64) -> Result<()> {
+        // We do not call set_seed but instead create a new curand object. This ensures that the
+        // state will be identical and the same random numbers will be generated.
+        let mut curand = self.curand.lock().unwrap();
+        curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
+        Ok(())
+    }
+
     fn location(&self) -> crate::DeviceLocation {
         crate::DeviceLocation::Cuda {
             gpu_id: self.device.ordinal(),
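A hedged usage sketch of the new seeding hook (assuming the trait method is surfaced on the public Device type, and eliding real error handling):

use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let device = Device::new_cuda(0)?;
    device.set_seed(42)?;
    let a = Tensor::randn(0f32, 1.0, (2, 3), &device)?;
    device.set_seed(42)?;
    let b = Tensor::randn(0f32, 1.0, (2, 3), &device)?;
    // Re-seeding rebuilds the curand generator, so the two draws are identical.
    assert_eq!(a.to_vec2::<f32>()?, b.to_vec2::<f32>()?);
    Ok(())
}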
@@ -312,6 +320,13 @@ impl BackendDevice for CudaDevice {
         // cudarc changes.
         let elem_count = shape.elem_count();
         let curand = self.curand.lock().unwrap();
+        // curand can only generate an even number of values, so round odd
+        // element counts up by one.
+        // https://github.com/huggingface/candle/issues/734
+        let elem_count_round = if elem_count % 2 == 1 {
+            elem_count + 1
+        } else {
+            elem_count
+        };
         let slice = match dtype {
             DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
                 Err(CudaError::UnsupportedDtype {

@@ -321,7 +336,7 @@ impl BackendDevice for CudaDevice {
                 .w()?
             }
             DType::F32 => {
-                let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
+                let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
                 curand
                     .0
                     .fill_with_normal(&mut data, mean as f32, std as f32)

@@ -329,7 +344,7 @@ impl BackendDevice for CudaDevice {
                 CudaStorageSlice::F32(data)
             }
             DType::F64 => {
-                let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
+                let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
                 curand.0.fill_with_normal(&mut data, mean, std).w()?;
                 CudaStorageSlice::F64(data)
             }
@@ -593,6 +608,105 @@ impl Map1 for Elu {
     }
 }
+
+struct Im2Col1D {
+    l_k: usize,
+    stride: usize,
+    dilation: usize,
+    padding: usize,
+}
+
+impl Im2Col1D {
+    fn l_out(&self, l: usize) -> usize {
+        (l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
+    }
+}
+
+impl Map1 for Im2Col1D {
+    fn f<T: DeviceRepr + WithDType>(
+        &self,
+        src: &CudaSlice<T>,
+        dev: &CudaDevice,
+        layout: &Layout,
+    ) -> Result<CudaSlice<T>> {
+        let shape = layout.shape();
+        let dims = shape.dims();
+        let l_out = self.l_out(dims[2]);
+        let dst_el = dims[0] * l_out * dims[1] * self.l_k;
+        let cfg = LaunchConfig::for_num_elems(dst_el as u32);
+        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
+        let src = &src.slice(layout.start_offset()..);
+        let func = dev.get_or_load_func(&kernel_name::<T>("im2col1d"), kernels::CONV)?;
+        // SAFETY: Set later by running the kernel.
+        let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+        let params = (
+            dst_el,
+            l_out,
+            self.l_k,
+            self.stride,
+            self.padding,
+            self.dilation,
+            &ds,
+            src,
+            &dst,
+        );
+        // SAFETY: ffi.
+        unsafe { func.launch(cfg, params) }.w()?;
+        Ok(dst)
+    }
+}
+
+struct Im2Col {
+    h_k: usize,
+    w_k: usize,
+    stride: usize,
+    dilation: usize,
+    padding: usize,
+}
+
+impl Im2Col {
+    fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
+        let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
+        let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
+        (h_out, w_out)
+    }
+}
+
+impl Map1 for Im2Col {
+    fn f<T: DeviceRepr + WithDType>(
+        &self,
+        src: &CudaSlice<T>,
+        dev: &CudaDevice,
+        layout: &Layout,
+    ) -> Result<CudaSlice<T>> {
+        let shape = layout.shape();
+        let dims = shape.dims();
+        let (h_out, w_out) = self.hw_out(dims[2], dims[3]);
+        let dst_el = dims[0] * h_out * w_out * dims[1] * self.h_k * self.w_k;
+        let cfg = LaunchConfig::for_num_elems(dst_el as u32);
+        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
+        let src = &src.slice(layout.start_offset()..);
+        let func = dev.get_or_load_func(&kernel_name::<T>("im2col"), kernels::CONV)?;
+        // SAFETY: Set later by running the kernel.
+        let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+        let params = (
+            dst_el,
+            h_out,
+            w_out,
+            self.h_k,
+            self.w_k,
+            self.stride,
+            self.padding,
+            self.dilation,
+            &ds,
+            src,
+            &dst,
+        );
+        // SAFETY: ffi.
+        unsafe { func.launch(cfg, params) }.w()?;
+        Ok(dst)
+    }
+}
+
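The l_out and hw_out helpers above are the standard convolution output-size formula; for one spatial dimension, with padding p, dilation d, kernel size k and stride s:

$$l_{out} = \left\lfloor \frac{l + 2p - d\,(k - 1) - 1}{s} \right\rfloor + 1$$

hw_out simply applies the same formula independently to the height and width axes.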
 struct Powf(f64);
 impl Map1 for Powf {
     fn f<T: DeviceRepr + WithDType>(
@@ -778,8 +892,6 @@ impl<'a> Map1 for IndexSelect<'a> {
         };
         let ids_shape = ids_l.shape();
         let ids_dims = ids_shape.dims();
-        let ids_el = ids_shape.elem_count();
-        let cfg = LaunchConfig::for_num_elems(ids_el as u32);
         let ds = dev.htod_copy([ids_dims, ids_l.stride()].concat()).w()?;
         let src = match src_l.contiguous_offsets() {
             Some((o1, o2)) => src.slice(o1..o2),
@@ -787,19 +899,23 @@ impl<'a> Map1 for IndexSelect<'a> {
         };
         let left_size: usize = src_l.dims()[..self.2].iter().product();
         let right_size: usize = src_l.dims()[self.2 + 1..].iter().product();
-        let dim_size = src_l.dims()[self.2];
+        let src_dim_size = src_l.dims()[self.2];
+        let ids_dim_size = ids_shape.elem_count();
+        let dst_el = ids_shape.elem_count() * left_size * right_size;
+        let cfg = LaunchConfig::for_num_elems(dst_el as u32);
         let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
         // SAFETY: Set later by running the kernel.
-        let out = unsafe { dev.alloc::<T>(ids_el * left_size * right_size) }.w()?;
+        let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
         let params = (
-            ids_el,
+            dst_el,
             ids_dims.len(),
             &ds,
             ids,
             &src,
             &out,
             left_size,
-            dim_size,
+            src_dim_size,
+            ids_dim_size,
             right_size,
         );
         // SAFETY: ffi.
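Note the launch-size change in this hunk: the kernel is now sized on the number of destination elements rather than the number of indices, i.e.

$$dst\_el = left\_size \times ids\_dim\_size \times right\_size,$$

so one GPU thread is launched per output element instead of per index.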
@@ -1650,9 +1766,46 @@ impl BackendStorage for CudaStorage {
         kernel_l: &Layout,
         params: &crate::conv::ParamsConv1D,
     ) -> Result<Self> {
+        const USE_IM2COL_CONV1D: bool = true;
+
         let device = self.device().clone();
-        let slice = Conv1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
-        Ok(Self { slice, device })
+        if !USE_IM2COL_CONV1D {
+            let slice = Conv1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
+            return Ok(Self { slice, device });
+        }
+
+        let col = Im2Col1D {
+            l_k: params.k_size,
+            stride: params.stride,
+            dilation: params.dilation,
+            padding: params.padding,
+        }
+        .map(&self.slice, &device, l)?;
+        let col = Self { slice: col, device };
+        let l_out = params.l_out();
+        let b = params.b_size;
+        let n = params.c_out;
+        let k = params.k_size * params.c_in;
+        let m = l_out;
+        let col_l = Layout::contiguous((b, m, k));
+        let res = if kernel_l.is_contiguous() {
+            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+                .transpose(1, 2)?
+                .broadcast_as((b, k, n))?;
+            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+        } else {
+            // Make the kernel contiguous if not already the case.
+            let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+            kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
+            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+                .transpose(1, 2)?
+                .broadcast_as((b, k, n))?;
+            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+        };
+        let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
+        let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+        res.copy_strided_src(&mut res_t, 0, &res_l)?;
+        Ok(res_t)
     }

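For reference, here is a minimal CPU sketch of the transform the im2col1d kernel performs, assuming a contiguous (b, c, l) input; the function and parameter names are illustrative only, not part of candle's API:

    // Expand a (b, c, l) signal into (b, l_out, c, l_k) patches so that the
    // convolution reduces to a single batched matmul, as in the hunk above.
    fn im2col1d_ref(
        src: &[f32],
        (b, c, l): (usize, usize, usize),
        l_k: usize,
        stride: usize,
        dilation: usize,
        padding: usize,
    ) -> Vec<f32> {
        let l_out = (l + 2 * padding - dilation * (l_k - 1) - 1) / stride + 1;
        let mut dst = vec![0f32; b * l_out * c * l_k];
        for bi in 0..b {
            for lo in 0..l_out {
                for ci in 0..c {
                    for ki in 0..l_k {
                        // Position in the (virtually) padded input.
                        let pos = lo * stride + ki * dilation;
                        // Taps that fall into the padding stay zero.
                        if pos >= padding && pos < l + padding {
                            let src_idx = (bi * c + ci) * l + (pos - padding);
                            let dst_idx = ((bi * l_out + lo) * c + ci) * l_k + ki;
                            dst[dst_idx] = src[src_idx];
                        }
                    }
                }
            }
        }
        dst
    }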
 #[cfg(not(feature = "cudnn"))]
@@ -1663,9 +1816,50 @@ impl BackendStorage for CudaStorage {
         kernel_l: &Layout,
         params: &crate::conv::ParamsConv2D,
     ) -> Result<Self> {
+        const USE_IM2COL_CONV2D: bool = true;
+
         let device = self.device().clone();
-        let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
-        Ok(Self { slice, device })
+        if !USE_IM2COL_CONV2D {
+            let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
+            return Ok(Self { slice, device });
+        }
+
+        let col = Im2Col {
+            h_k: params.k_h,
+            w_k: params.k_w,
+            stride: params.stride,
+            dilation: params.dilation,
+            padding: params.padding,
+        }
+        .map(&self.slice, &device, l)?;
+        let col = Self { slice: col, device };
+        let h_out = params.out_h();
+        let w_out = params.out_w();
+        let b = params.b_size;
+        let n = params.c_out;
+        let k = params.k_h * params.k_w * params.c_in;
+        let m = h_out * w_out;
+        let col_l = Layout::contiguous((b, m, k));
+        let res = if kernel_l.is_contiguous() {
+            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+                .transpose(1, 2)?
+                .broadcast_as((b, k, n))?;
+            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+        } else {
+            // Make the kernel contiguous if not already the case.
+            let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+            kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
+            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+                .transpose(1, 2)?
+                .broadcast_as((b, k, n))?;
+            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+        };
+        let res_l = Layout::contiguous((b, h_out, w_out, n))
+            .transpose(1, 2)?
+            .transpose(1, 3)?;
+        let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+        res.copy_strided_src(&mut res_t, 0, &res_l)?;
+        Ok(res_t)
     }

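The 2d path follows the same gemm formulation: with m = h_out * w_out, k = k_h * k_w * c_in and n = c_out, the matmul computes

$$\mathrm{col} \in \mathbb{R}^{b \times m \times k}, \quad \mathrm{kernel} \in \mathbb{R}^{k \times n}, \quad \mathrm{res} = \mathrm{col} \cdot \mathrm{kernel} \in \mathbb{R}^{b \times m \times n},$$

after which the two transposes restore the (b, c_out, h_out, w_out) layout that copy_strided_src materializes.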
 #[cfg(feature = "cudnn")]
@@ -1770,6 +1964,10 @@ impl BackendStorage for CudaStorage {
         Ok(Self { slice, device })
     }

+    fn upsample_nearest1d(&self, _: &Layout, _out_sz: usize) -> Result<Self> {
+        crate::bail!("upsample-nearest1d is not supported on cuda")
+    }
+
     fn upsample_nearest2d(&self, l: &Layout, out_w: usize, out_h: usize) -> Result<Self> {
         let device = self.device().clone();
         let slice = UpsampleNearest2D(out_w, out_h).map(&self.slice, &device, l)?;
@@ -34,6 +34,9 @@ pub(crate) fn launch_conv2d<
     params: &crate::conv::ParamsConv2D,
     dev: &crate::cuda_backend::CudaDevice,
 ) -> crate::Result<()> {
+    use crate::conv::CudnnFwdAlgo as CandleAlgo;
+    use cudarc::cudnn::sys::cudnnConvolutionFwdAlgo_t as A;
+
     let device_id = dev.id();
     let cudnn = CUDNN.with(|cudnn| {
         if let Some(cudnn) = cudnn.borrow().get(&device_id) {
@@ -90,7 +93,20 @@ pub(crate) fn launch_conv2d<
         w: &w,
         y: &y,
     };
-    let alg = conv2d.pick_algorithm()?;
+    let alg = match params.cudnn_fwd_algo {
+        None => conv2d.pick_algorithm()?,
+        Some(CandleAlgo::ImplicitGemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
+        Some(CandleAlgo::ImplicitPrecompGemm) => {
+            A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
+        }
+        Some(CandleAlgo::Gemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
+        Some(CandleAlgo::Direct) => A::CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
+        Some(CandleAlgo::Fft) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT,
+        Some(CandleAlgo::FftTiling) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
+        Some(CandleAlgo::Winograd) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
+        Some(CandleAlgo::WinogradNonFused) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
+        Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
+    };
     let workspace_size = conv2d.get_workspace_size(alg)?;
     let mut workspace = dev.cuda_device().alloc_zeros::<u8>(workspace_size)?;
     unsafe {
@@ -128,6 +128,13 @@ impl Device {
         Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
     }

+    pub fn set_seed(&self, seed: u64) -> Result<()> {
+        match self {
+            Self::Cpu => crate::cpu_backend::CpuDevice.set_seed(seed),
+            Self::Cuda(c) => c.set_seed(seed),
+        }
+    }
+
     pub fn same_device(&self, rhs: &Self) -> bool {
         match (self, rhs) {
             (Self::Cpu, Self::Cpu) => true,
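A small usage sketch for the new entry point; the cuda_if_available constructor is the existing one, only set_seed comes from this hunk:

    use candle_core::{Device, Result};

    fn seeded_device() -> Result<Device> {
        // Falls back to the CPU backend when no GPU is present.
        let device = Device::cuda_if_available(0)?;
        // Delegates to the backend-specific set_seed added above.
        device.set_seed(42)?;
        Ok(device)
    }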
@@ -67,6 +67,20 @@ impl DType {
             Self::F64 => 8,
         }
     }

+    pub fn is_int(&self) -> bool {
+        match self {
+            Self::U8 | Self::U32 | Self::I64 => true,
+            Self::BF16 | Self::F16 | Self::F32 | Self::F64 => false,
+        }
+    }
+
+    pub fn is_float(&self) -> bool {
+        match self {
+            Self::U8 | Self::U32 | Self::I64 => false,
+            Self::BF16 | Self::F16 | Self::F32 | Self::F64 => true,
+        }
+    }
 }

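A quick sanity sketch of the two new predicates; the variants shown are exactly the ones covered by the match arms above:

    use candle_core::DType;

    fn main() {
        // U8, U32 and I64 are the integer dtypes...
        assert!(DType::U32.is_int() && !DType::U32.is_float());
        // ...while BF16, F16, F32 and F64 are the float dtypes.
        assert!(DType::F16.is_float() && !DType::F16.is_int());
    }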
 pub trait WithDType:
@@ -152,6 +152,10 @@ impl crate::backend::BackendStorage for CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }

+    fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
@@ -163,6 +167,10 @@ impl crate::backend::BackendDevice for CudaDevice {
         Err(Error::NotCompiledWithCudaSupport)
     }

+    fn set_seed(&self, _: u64) -> Result<()> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     fn location(&self) -> crate::DeviceLocation {
         fail!()
     }
@@ -46,19 +46,31 @@ impl Tensor {
                     current_dim += 1;
                     out
                 }
+                TensorIndexer::IndexSelect(indexes) => {
+                    if indexes.rank() != 1 {
+                        crate::bail!("multi-dimensional tensor indexing is not supported")
+                    }
+                    let out = x.index_select(&indexes.to_device(x.device())?, current_dim)?;
+                    current_dim += 1;
+                    out
+                }
+                TensorIndexer::Err(e) => crate::bail!("indexing error {e:?}"),
             };
         }
         Ok(x)
     }
 }

-#[derive(Debug, Clone)]
+#[derive(Debug)]
 /// Generic structure used to index a slice of the tensor
 pub enum TensorIndexer {
     /// This selects the elemnts for which an index has some specific value.
     Select(usize),
     /// This is a regular slice, purely indexing a chunk of the tensor
     Narrow(Bound<usize>, Bound<usize>),
+    /// Indexing via a 1d tensor
+    IndexSelect(Tensor),
+    Err(Error),
 }

 impl From<usize> for TensorIndexer {
@@ -67,6 +79,31 @@ impl From<usize> for TensorIndexer {
     }
 }
+
+impl From<&[u32]> for TensorIndexer {
+    fn from(index: &[u32]) -> Self {
+        match Tensor::new(index, &crate::Device::Cpu) {
+            Ok(tensor) => TensorIndexer::IndexSelect(tensor),
+            Err(e) => TensorIndexer::Err(e),
+        }
+    }
+}
+
+impl From<Vec<u32>> for TensorIndexer {
+    fn from(index: Vec<u32>) -> Self {
+        let len = index.len();
+        match Tensor::from_vec(index, len, &crate::Device::Cpu) {
+            Ok(tensor) => TensorIndexer::IndexSelect(tensor),
+            Err(e) => TensorIndexer::Err(e),
+        }
+    }
+}
+
+impl From<&Tensor> for TensorIndexer {
+    fn from(tensor: &Tensor) -> Self {
+        TensorIndexer::IndexSelect(tensor.clone())
+    }
+}
+
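A usage sketch for the new tensor-based indexing, assuming the crate's existing IndexOp::i entry point dispatches through these From impls:

    use candle_core::{Device, IndexOp, Result, Tensor};

    fn pick_rows() -> Result<Tensor> {
        let t = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
        // Vec<u32> now converts into TensorIndexer::IndexSelect, so this
        // gathers rows 2 and 0, yielding a (2, 2) tensor.
        t.i(vec![2u32, 0])
    }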
 macro_rules! impl_from_range {
     ($range_type:ty) => {
         impl From<$range_type> for TensorIndexer {
@@ -110,14 +110,8 @@ impl ToUsize2 for (usize, usize) {
 }

 // A simple trait defining a module with forward method using a single argument.
-pub trait Module: std::fmt::Debug {
+pub trait Module {
     fn forward(&self, xs: &Tensor) -> Result<Tensor>;
-
-    /// Change the module to use training mode vs eval mode.
-    ///
-    /// The default implementation does nothing as this is only used for a couple modules such as
-    /// dropout or batch-normalization.
-    fn set_training(&mut self, _training: bool) {}
 }

 impl Module for quantized::QMatMul {
@@ -125,3 +119,9 @@ impl Module for quantized::QMatMul {
         self.forward(xs)
     }
 }
+
+impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        self(xs)
+    }
+}
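With the blanket impl above, any closure from &Tensor to Result<Tensor> can be used wherever a Module is expected; a minimal sketch (candle_core paths assumed):

    use candle_core::{Device, Module, Result, Tensor};

    fn main() -> Result<()> {
        // The closure itself is the module, no struct needed.
        let act = |xs: &Tensor| xs.relu();
        let xs = Tensor::new(&[-1f32, 2.], &Device::Cpu)?;
        let ys = act.forward(&xs)?; // [0., 2.]
        println!("{ys}");
        Ok(())
    }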
@@ -250,8 +250,6 @@ impl Tensor {
         if header.fortran_order {
             return Err(Error::Npy("fortran order not supported".to_string()));
         }
-        let mut data: Vec<u8> = vec![];
-        reader.read_to_end(&mut data)?;
         Self::from_reader(header.shape(), header.descr, &mut reader)
     }

@@ -58,8 +58,13 @@ pub enum UnaryOp {
     Sqr,
     Sqrt,
     Gelu,
+    GeluErf,
+    Erf,
     Relu,
     Tanh,
+    Floor,
+    Ceil,
+    Round,
 }

 #[derive(Clone)]
@@ -116,6 +121,7 @@ pub enum Op {
         stride: (usize, usize),
     },

+    UpsampleNearest1D(Tensor),
     UpsampleNearest2D(Tensor),

     Cat(Vec<Tensor>, usize),
@@ -130,6 +136,7 @@ pub enum Op {
     Copy(Tensor),
     Broadcast(Tensor),
     Narrow(Tensor, usize, usize, usize),
+    SliceScatter0(Tensor, Tensor, usize),
     Reshape(Tensor),
     ToDevice(Tensor),
     Transpose(Tensor, usize, usize),
@@ -324,8 +331,13 @@ pub(crate) struct Recip;
 pub(crate) struct Sqr;
 pub(crate) struct Sqrt;
 pub(crate) struct Gelu;
+pub(crate) struct GeluErf;
+pub(crate) struct Erf;
 pub(crate) struct Relu;
 pub(crate) struct Tanh;
+pub(crate) struct Floor;
+pub(crate) struct Ceil;
+pub(crate) struct Round;

 macro_rules! bin_op {
     ($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
@@ -600,6 +612,194 @@ impl UnaryOpT for Gelu {
     fn f64_vec(xs: &[f64], ys: &mut [f64]) {
         crate::mkl::vd_gelu(xs, ys)
     }
+
+    #[cfg(feature = "accelerate")]
+    const F32_VEC: bool = true;
+
+    #[cfg(feature = "accelerate")]
+    #[inline(always)]
+    fn f32_vec(xs: &[f32], ys: &mut [f32]) {
+        crate::accelerate::vs_gelu(xs, ys)
+    }
+
+    #[cfg(feature = "accelerate")]
+    const F64_VEC: bool = true;
+
+    #[cfg(feature = "accelerate")]
+    #[inline(always)]
+    fn f64_vec(xs: &[f64], ys: &mut [f64]) {
+        crate::accelerate::vd_gelu(xs, ys)
+    }
+}
+
+impl UnaryOpT for Erf {
+    const NAME: &'static str = "erf";
+    const KERNEL: &'static str = "uerf";
+    const V: Self = Erf;
+    #[inline(always)]
+    fn bf16(v: bf16) -> bf16 {
+        bf16::from_f64(Self::f64(v.to_f64()))
+    }
+    #[inline(always)]
+    fn f16(v: f16) -> f16 {
+        f16::from_f64(Self::f64(v.to_f64()))
+    }
+    #[inline(always)]
+    fn f32(v: f32) -> f32 {
+        Self::f64(v as f64) as f32
+    }
+    #[inline(always)]
+    fn f64(v: f64) -> f64 {
+        crate::cpu::erf::erf(v)
+    }
+    #[inline(always)]
+    fn u8(_: u8) -> u8 {
+        0
+    }
+    #[inline(always)]
+    fn u32(_: u32) -> u32 {
+        0
+    }
+    #[inline(always)]
+    fn i64(_: i64) -> i64 {
+        0
+    }
+}
+
+impl UnaryOpT for Ceil {
+    const NAME: &'static str = "ceil";
+    const KERNEL: &'static str = "uceil";
+    const V: Self = Ceil;
+    #[inline(always)]
+    fn bf16(v: bf16) -> bf16 {
+        v.ceil()
+    }
+    #[inline(always)]
+    fn f16(v: f16) -> f16 {
+        v.ceil()
+    }
+    #[inline(always)]
+    fn f32(v: f32) -> f32 {
+        v.ceil()
+    }
+    #[inline(always)]
+    fn f64(v: f64) -> f64 {
+        v.ceil()
+    }
+    #[inline(always)]
+    fn u8(v: u8) -> u8 {
+        v
+    }
+    #[inline(always)]
+    fn u32(v: u32) -> u32 {
+        v
+    }
+    #[inline(always)]
+    fn i64(v: i64) -> i64 {
+        v
+    }
+}
+
+impl UnaryOpT for Floor {
+    const NAME: &'static str = "floor";
+    const KERNEL: &'static str = "ufloor";
+    const V: Self = Floor;
+    #[inline(always)]
+    fn bf16(v: bf16) -> bf16 {
+        v.floor()
+    }
+    #[inline(always)]
+    fn f16(v: f16) -> f16 {
+        v.floor()
+    }
+    #[inline(always)]
+    fn f32(v: f32) -> f32 {
+        v.floor()
+    }
+    #[inline(always)]
+    fn f64(v: f64) -> f64 {
+        v.floor()
+    }
+    #[inline(always)]
+    fn u8(v: u8) -> u8 {
+        v
+    }
+    #[inline(always)]
+    fn u32(v: u32) -> u32 {
+        v
+    }
+    #[inline(always)]
+    fn i64(v: i64) -> i64 {
+        v
+    }
+}
+
+impl UnaryOpT for Round {
+    const NAME: &'static str = "round";
+    const KERNEL: &'static str = "uround";
+    const V: Self = Round;
+    #[inline(always)]
+    fn bf16(v: bf16) -> bf16 {
+        v.round()
+    }
+    #[inline(always)]
+    fn f16(v: f16) -> f16 {
+        v.round()
+    }
+    #[inline(always)]
+    fn f32(v: f32) -> f32 {
+        v.round()
+    }
+    #[inline(always)]
+    fn f64(v: f64) -> f64 {
+        v.round()
+    }
+    #[inline(always)]
+    fn u8(v: u8) -> u8 {
+        v
+    }
+    #[inline(always)]
+    fn u32(v: u32) -> u32 {
+        v
+    }
+    #[inline(always)]
+    fn i64(v: i64) -> i64 {
+        v
+    }
+}
+
+impl UnaryOpT for GeluErf {
+    const NAME: &'static str = "gelu_erf";
+    const KERNEL: &'static str = "ugelu_erf";
+    const V: Self = GeluErf;
+    #[inline(always)]
+    fn bf16(v: bf16) -> bf16 {
+        bf16::from_f64(Self::f64(v.to_f64()))
+    }
+    #[inline(always)]
+    fn f16(v: f16) -> f16 {
+        f16::from_f64(Self::f64(v.to_f64()))
+    }
+    #[inline(always)]
+    fn f32(v: f32) -> f32 {
+        Self::f64(v as f64) as f32
+    }
+    #[inline(always)]
+    fn f64(v: f64) -> f64 {
+        (crate::cpu::erf::erf(v / 2f64.sqrt()) + 1.) * 0.5 * v
+    }
+    #[inline(always)]
+    fn u8(_: u8) -> u8 {
+        0
+    }
+    #[inline(always)]
+    fn u32(_: u32) -> u32 {
+        0
+    }
+    #[inline(always)]
+    fn i64(_: i64) -> i64 {
+        0
+    }
 }

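The f64 case above is the exact erf-based gelu, as opposed to the tanh approximation used by the existing Gelu op:

$$\mathrm{gelu\_erf}(x) = \frac{x}{2}\left(1 + \mathrm{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right)$$

Integer dtypes map to 0 for Erf and GeluErf, and pass through unchanged for Floor, Ceil and Round.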
 impl UnaryOpT for Relu {
@@ -638,3 +638,35 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Result<f32> {
         Ok(hsum_float_8(acc) + summs)
     }
 }
+
+#[inline(always)]
+pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
+    let qk = QK_K;
+    if n % qk != 0 {
+        crate::bail!("vec_dot_q8k_8k: {n} is not divisible by {qk}")
+    }
+
+    unsafe {
+        let mut acc = _mm256_setzero_ps();
+        for (xs, ys) in xs.iter().zip(ys.iter()) {
+            let mut sumi = _mm256_setzero_si256();
+            let x_qs = xs.qs.as_ptr();
+            let y_qs = ys.qs.as_ptr();
+            for j in (0..QK_K).step_by(32) {
+                let xs = _mm256_loadu_si256(x_qs.add(j) as *const __m256i);
+                let ys = _mm256_loadu_si256(y_qs.add(j) as *const __m256i);
+
+                let xs0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xs, 0));
+                let ys0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(ys, 0));
+                sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(xs0, ys0));
+
+                let xs1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xs, 1));
+                let ys1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(ys, 1));
+                sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(xs1, ys1));
+            }
+            let d = _mm256_set1_ps(xs.d * ys.d);
+            acc = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi), acc);
+        }
+        Ok(hsum_float_8(acc))
+    }
+}
@@ -135,7 +135,13 @@ pub fn qtensor_from_ggml(
     dims: Vec<usize>,
 ) -> Result<super::QTensor> {
     let tensor_elems = dims.iter().product::<usize>();
-    let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
+    let blck_size = ggml_dtype.blck_size();
+    if tensor_elems % blck_size != 0 {
+        crate::bail!(
+            "the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
+        )
+    }
+    let size_in_bytes = tensor_elems / blck_size * ggml_dtype.type_size();
+
     match ggml_dtype {
         GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),
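Both call sites now compute the byte size with the division first, which is why the divisibility check is needed:

$$\mathrm{size\_in\_bytes} = \frac{\mathrm{tensor\_elems}}{\mathrm{blck\_size}} \times \mathrm{type\_size}$$

The previous multiply-then-divide order could silently truncate for element counts that were not a multiple of the block size; the new code bails instead.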
@@ -59,8 +59,13 @@ impl TensorInfo {
         tensor_data_offset: u64,
     ) -> Result<QTensor> {
         let tensor_elems = self.shape.elem_count();
-        let size_in_bytes =
-            tensor_elems * self.ggml_dtype.type_size() / self.ggml_dtype.blck_size();
+        let blck_size = self.ggml_dtype.blck_size();
+        if tensor_elems % blck_size != 0 {
+            crate::bail!(
+                "the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
+            )
+        }
+        let size_in_bytes = tensor_elems / blck_size * self.ggml_dtype.type_size();
         let mut raw_data = vec![0u8; size_in_bytes];
         reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
         reader.read_exact(&mut raw_data)?;
@@ -34,6 +34,9 @@ pub trait GgmlType: Sized + Clone + Send + Sync {
     /// Dot product used as a building block for quantized mat-mul.
     /// n is the number of elements to be considered.
     fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32>;
+
+    /// Generic implementation of the dot product without simd optimizations.
+    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32>;
 }

 #[derive(Debug, Clone, PartialEq)]
@@ -225,6 +228,13 @@ impl GgmlType for BlockQ4_0 {
         #[cfg(target_feature = "neon")]
         return super::neon::vec_dot_q4_0_q8_0(n, xs, ys);
+
+        #[cfg(target_feature = "simd128")]
+        return super::simd128::vec_dot_q4_0_q8_0(n, xs, ys);
+
+        Self::vec_dot_unopt(n, xs, ys)
+    }
+
+    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
         let qk = QK8_0;
         let nb = n / qk;
         if n % QK8_0 != 0 {
@@ -255,6 +265,10 @@ impl GgmlType for BlockQ4_1 {
     type VecDotType = BlockQ8_1;

     fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+        Self::vec_dot_unopt(n, xs, ys)
+    }
+
+    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
         // ggml_vec_dot_q4_1_q8_1
         let qk = QK8_1;
         if n % qk != 0 {
@@ -354,7 +368,10 @@ impl GgmlType for BlockQ5_0 {
         if nb % 2 != 0 {
             crate::bail!("vec_dot_q5_0_q8_0: {n}, nb is not divisible by 2")
         }
+        Self::vec_dot_unopt(n, xs, ys)
+    }
+
+    fn vec_dot_unopt(_n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
         // Generic implementation.
         let mut sumf = 0f32;

@@ -445,6 +462,10 @@ impl GgmlType for BlockQ5_1 {
     type VecDotType = BlockQ8_1;

     fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+        Self::vec_dot_unopt(n, xs, ys)
+    }
+
+    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
         let qk = Self::BLCK_SIZE;
         if n % Self::BLCK_SIZE != 0 {
             crate::bail!("vec_dot_q5_1_q8_1: {n} is not divisible by {qk}")
@@ -606,6 +627,13 @@ impl GgmlType for BlockQ8_0 {
         #[cfg(target_feature = "neon")]
         return super::neon::vec_dot_q8_0_q8_0(n, xs, ys);
+
+        #[cfg(target_feature = "simd128")]
+        return super::simd128::vec_dot_q8_0_q8_0(n, xs, ys);
+
+        Self::vec_dot_unopt(n, xs, ys)
+    }
+
+    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
         let qk = QK8_0;
         if n % QK8_0 != 0 {
             crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
@@ -631,7 +659,11 @@ impl GgmlType for BlockQ8_1 {
     const BLCK_SIZE: usize = QK8_1;
     type VecDotType = BlockQ8_1;

-    fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
+    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+        Self::vec_dot_unopt(n, xs, ys)
+    }
+
+    fn vec_dot_unopt(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
         unimplemented!("no support for vec-dot on Q8_1")
     }

@@ -681,6 +713,13 @@ impl GgmlType for BlockQ2K {
         #[cfg(target_feature = "neon")]
         return super::neon::vec_dot_q2k_q8k(n, xs, ys);
+
+        #[cfg(target_feature = "simd128")]
+        return super::simd128::vec_dot_q2k_q8k(n, xs, ys);
+
+        Self::vec_dot_unopt(n, xs, ys)
+    }
+
+    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
         if n % QK_K != 0 {
             crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
         }
@@ -701,18 +740,17 @@ impl GgmlType for BlockQ2K {

             let mut isum = 0;
             let mut is = 0;
-            let mut d;
             for _ in 0..(QK_K / 128) {
                 let mut shift = 0;
                 for _ in 0..4 {
-                    d = (sc[is] & 0xF) as i32;
+                    let d = (sc[is] & 0xF) as i32;
                     is += 1;
                     let mut isuml = 0;
                     for l in 0..16 {
                         isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32);
                     }
                     isum += d * isuml;
-                    d = (sc[is] & 0xF) as i32;
+                    let d = (sc[is] & 0xF) as i32;
                     is += 1;
                     isuml = 0;
                     for l in 16..32 {
@@ -851,6 +889,10 @@ impl GgmlType for BlockQ3K {
         #[cfg(target_feature = "neon")]
         return super::neon::vec_dot_q3k_q8k(n, xs, ys);
+
+        Self::vec_dot_unopt(n, xs, ys)
+    }
+
+    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
         if n % QK_K != 0 {
             crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}")
         }
@@ -1077,7 +1119,6 @@ impl GgmlType for BlockQ3K {
             let d_all = block.d.to_f32();
             let mut m = 1;
             let mut is = 0;
-            let mut dl;

             // Dequantize both 128 long blocks
             // 32 qs values per 128 long block
@@ -1088,7 +1129,7 @@ impl GgmlType for BlockQ3K {
                     for (scale_index, scale_scoped_y) in
                         shift_scoped_y.chunks_exact_mut(16).enumerate()
                     {
-                        dl = d_all * (scales[is] as f32 - 32.0);
+                        let dl = d_all * (scales[is] as f32 - 32.0);
                         for (i, inner_y) in scale_scoped_y.iter_mut().enumerate() {
                             let new_y = dl
                                 * (((qs[i + 16 * scale_index] >> shift) & 3) as i8
@@ -1126,6 +1167,13 @@ impl GgmlType for BlockQ4K {
         #[cfg(target_feature = "neon")]
         return super::neon::vec_dot_q4k_q8k(n, xs, ys);
+
+        #[cfg(target_feature = "simd128")]
+        return super::simd128::vec_dot_q4k_q8k(n, xs, ys);
+
+        Self::vec_dot_unopt(n, xs, ys)
+    }
+
+    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
         if n % QK_K != 0 {
             crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
         }
@@ -1312,6 +1360,10 @@ impl GgmlType for BlockQ5K {
         #[cfg(target_feature = "neon")]
         return super::neon::vec_dot_q5k_q8k(n, xs, ys);
+
+        Self::vec_dot_unopt(n, xs, ys)
+    }
+
+    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
         if n % QK_K != 0 {
             crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}")
         }
@@ -1529,6 +1581,13 @@ impl GgmlType for BlockQ6K {
         #[cfg(target_feature = "neon")]
         return super::neon::vec_dot_q6k_q8k(n, xs, ys);
+
+        #[cfg(target_feature = "simd128")]
+        return super::simd128::vec_dot_q6k_q8k(n, xs, ys);
+
+        Self::vec_dot_unopt(n, xs, ys)
+    }
+
+    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
         if n % QK_K != 0 {
             crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}")
         }
@@ -1697,8 +1756,38 @@ impl GgmlType for BlockQ8K {
     const BLCK_SIZE: usize = QK_K;
     type VecDotType = BlockQ8K;

-    fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
-        unreachable!()
+    #[allow(unreachable_code)]
+    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+        #[cfg(target_feature = "avx")]
+        return super::avx::vec_dot_q8k_q8k(n, xs, ys);
+
+        #[cfg(target_feature = "neon")]
+        return super::neon::vec_dot_q8k_q8k(n, xs, ys);
+
+        #[cfg(target_feature = "simd128")]
+        return super::simd128::vec_dot_q8k_q8k(n, xs, ys);
+
+        Self::vec_dot_unopt(n, xs, ys)
+    }
+
+    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+        let qk = QK_K;
+        if n % QK_K != 0 {
+            crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
+        }
+
+        // Generic implementation.
+        let mut sumf = 0f32;
+        for (xs, ys) in xs.iter().zip(ys.iter()) {
+            let sum_i = xs
+                .qs
+                .iter()
+                .zip(ys.qs.iter())
+                .map(|(&x, &y)| x as i32 * y as i32)
+                .sum::<i32>();
+            sumf += sum_i as f32 * xs.d * ys.d
+        }
+        Ok(sumf)
     }

     fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
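The generic fallback makes the q8k block format easy to read off: each block stores QK_K int8 values plus one f32 scale per side, so the dot product is

$$\mathrm{dot}(x, y) = \sum_{b} d_x^{(b)}\, d_y^{(b)} \sum_{i=0}^{QK\_K - 1} q_{x,i}^{(b)}\, q_{y,i}^{(b)}$$

which is exactly what the avx, neon and simd128 paths vectorize.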
@@ -1804,6 +1893,10 @@ impl GgmlType for f32 {
     type VecDotType = f32;

     fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+        Self::vec_dot_unopt(n, xs, ys)
+    }
+
+    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
         if xs.len() < n {
             crate::bail!("size mismatch {} < {n}", xs.len())
         }
@@ -1838,6 +1931,10 @@ impl GgmlType for f16 {
     type VecDotType = f16;

     fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+        Self::vec_dot_unopt(n, xs, ys)
+    }
+
+    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
         if xs.len() < n {
             crate::bail!("size mismatch {} < {n}", xs.len())
         }
@@ -7,6 +7,8 @@ pub mod gguf_file;
 pub mod k_quants;
 #[cfg(target_feature = "neon")]
 pub mod neon;
+#[cfg(target_feature = "simd128")]
+pub mod simd128;
 pub mod utils;

 pub use k_quants::GgmlType;
@@ -229,20 +231,40 @@ impl QTensor {
     }
 }

-#[derive(Debug)]
-pub struct QMatMul(std::sync::Arc<QTensor>);
+#[derive(Clone, Debug)]
+pub enum QMatMul {
+    QTensor(std::sync::Arc<QTensor>),
+    Tensor(Tensor),
+}
+
+thread_local! {
+    static DEQUANTIZE_ALL: bool = {
+        match std::env::var("CANDLE_DEQUANTIZE_ALL") {
+            Ok(s) => {
+                !s.is_empty() && s != "0"
+            },
+            Err(_) => false,
+        }
+    }
+}

 impl QMatMul {
-    pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Self {
-        Self(qtensor)
+    pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
+        let dequantize = match qtensor.dtype() {
+            GgmlDType::F32 | GgmlDType::F16 => true,
+            _ => DEQUANTIZE_ALL.with(|b| *b),
+        };
+        let t = if dequantize {
+            let tensor = qtensor.dequantize(&Device::Cpu)?;
+            Self::Tensor(tensor)
+        } else {
+            Self::QTensor(qtensor)
+        };
+        Ok(t)
     }

-    pub fn from_qtensor(qtensor: QTensor) -> Self {
-        Self(std::sync::Arc::new(qtensor))
-    }
-
-    pub fn inner(&self) -> &std::sync::Arc<QTensor> {
-        &self.0
+    pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
+        Self::from_arc(std::sync::Arc::new(qtensor))
     }
 }

@@ -287,6 +309,16 @@ impl crate::CustomOp1 for QTensor {

 impl QMatMul {
     pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        xs.apply_op1_no_bwd(self.0.as_ref())
+        match self {
+            Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
+            Self::Tensor(w) => {
+                let w = match *xs.dims() {
+                    [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
+                    [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
+                    _ => w.t()?,
+                };
+                xs.matmul(&w)
+            }
+        }
     }
 }
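A sketch of the new construction flow; from_qtensor is now fallible because it may eagerly dequantize, and the CANDLE_DEQUANTIZE_ALL variable is read once per thread, so it must be set before the first QMatMul is built (candle_core paths assumed):

    use candle_core::quantized::{QMatMul, QTensor};
    use candle_core::{Result, Tensor};

    fn matmul_layer(qt: QTensor, xs: &Tensor) -> Result<Tensor> {
        // Opt in to the dequantized Tensor path for every dtype.
        std::env::set_var("CANDLE_DEQUANTIZE_ALL", "1");
        let mm = QMatMul::from_qtensor(qt)?; // now returns Result<Self>
        // forward dispatches on the QTensor vs Tensor variant.
        mm.forward(xs)
    }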
@@ -148,6 +148,35 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
     }
 }
+
+#[inline(always)]
+pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
+    let qk = QK_K;
+    if n % QK_K != 0 {
+        crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
+    }
+
+    let mut sumf = 0f32;
+    for (xs, ys) in xs.iter().zip(ys.iter()) {
+        unsafe {
+            let mut sum_i = vdupq_n_s32(0);
+            let scale = xs.d * ys.d;
+            let xs = xs.qs.as_ptr();
+            let ys = ys.qs.as_ptr();
+            for i in (0..QK_K).step_by(16) {
+                let xs = vld1q_s8(xs.add(i));
+                let ys = vld1q_s8(ys.add(i));
+                let xy_lo = vmull_s8(vget_low_s8(xs), vget_low_s8(ys));
+                let xy_up = vmull_s8(vget_high_s8(xs), vget_high_s8(ys));
+
+                let xy = vaddq_s32(vpaddlq_s16(xy_lo), vpaddlq_s16(xy_up));
+                sum_i = vaddq_s32(sum_i, xy)
+            }
+            sumf += vaddvq_s32(sum_i) as f32 * scale
+        }
+    }
+    Ok(sumf)
+}
+
 #[inline(always)]
 pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
     if n % QK_K != 0 {
candle-core/src/quantized/simd128.rs (new file, 427 lines)
@@ -0,0 +1,427 @@
+use super::k_quants::{BlockQ2K, BlockQ4K, BlockQ4_0, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K};
+use crate::Result;
+use byteorder::{ByteOrder, LittleEndian};
+use half::f16;
+
+use core::arch::wasm32::*;
+
+#[inline(always)]
+pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
+    let qk = QK8_0;
+    if n % QK8_0 != 0 {
+        crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
+    }
+    let nb = n / QK8_0;
+    if nb % 2 != 0 {
+        crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
+    }
+    unsafe {
+        let mut acc = f32x4_splat(0.0f32);
+        for (x, y) in xs.iter().zip(ys.iter()) {
+            let x1234 = v128_load(x.qs.as_ptr() as *const v128);
+            let x12 = v128_and(x1234, u8x16_splat(0x0F));
+            let x12 = i8x16_sub(x12, i8x16_splat(8));
+            let x34 = u8x16_shr(x1234, 4);
+            let x34 = i8x16_sub(x34, i8x16_splat(8));
+
+            let x1 = i16x8_extend_low_i8x16(x12);
+            let y1 = i16x8_load_extend_i8x8(y.qs.as_ptr());
+            let sum_xy = i32x4_dot_i16x8(x1, y1);
+
+            let x2 = i16x8_extend_high_i8x16(x12);
+            let y2 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(8));
+            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x2, y2));
+
+            let x3 = i16x8_extend_low_i8x16(x34);
+            let y3 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(16));
+            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x3, y3));
+
+            let x4 = i16x8_extend_high_i8x16(x34);
+            let y4 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(24));
+            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x4, y4));
+
+            let sum_xy = f32x4_convert_i32x4(sum_xy);
+
+            // f32x4_relaxed_madd is nightly only.
+            let d = f32x4_splat(f16::to_f32(x.d) * f16::to_f32(y.d));
+            let scaled = f32x4_mul(sum_xy, d);
+            acc = f32x4_add(acc, scaled)
+        }
+        let res = f32x4_extract_lane::<0>(acc)
+            + f32x4_extract_lane::<1>(acc)
+            + f32x4_extract_lane::<2>(acc)
+            + f32x4_extract_lane::<3>(acc);
+        Ok(res)
+    }
+}
+
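The nibble handling at the top of the loop implements the q4_0 dequantization identity: each 4-bit value is centered by subtracting 8 and scaled by the per-block f16 delta, so the block dot product is

$$x_i = d\,(q_i - 8), \qquad \mathrm{dot} = \sum_{b} d_x^{(b)}\, d_y^{(b)} \sum_i (q_{x,i} - 8)\, q_{y,i}$$

with the i32 accumulation done in four i32x4_dot_i16x8 products before a single f32 scale per block.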
+#[inline(always)]
+pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
+    let qk = QK8_0;
+    if n % QK8_0 != 0 {
+        crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
+    }
+    let nb = n / QK8_0;
+    if nb % 2 != 0 {
+        crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even")
+    }
+    unsafe {
+        let mut acc = f32x4_splat(0.0f32);
+        for (x, y) in xs.iter().zip(ys.iter()) {
+            let x1 = i16x8_load_extend_i8x8(x.qs.as_ptr());
+            let y1 = i16x8_load_extend_i8x8(y.qs.as_ptr());
+            let sum_xy = i32x4_dot_i16x8(x1, y1);
+
+            let x2 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(8));
+            let y2 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(8));
+            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x2, y2));
+
+            let x3 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(16));
+            let y3 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(16));
+            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x3, y3));
+
+            let x4 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(24));
+            let y4 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(24));
+            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x4, y4));
+
+            let sum_xy = f32x4_convert_i32x4(sum_xy);
+
+            // f32x4_relaxed_madd is nightly only.
+            let d = f32x4_splat(f16::to_f32(x.d) * f16::to_f32(y.d));
+            let scaled = f32x4_mul(sum_xy, d);
+            acc = f32x4_add(acc, scaled)
+        }
+        let res = f32x4_extract_lane::<0>(acc)
+            + f32x4_extract_lane::<1>(acc)
+            + f32x4_extract_lane::<2>(acc)
+            + f32x4_extract_lane::<3>(acc);
+        Ok(res)
+    }
+}
+
+#[inline(always)]
+pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
+    if n % QK_K != 0 {
+        crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
+    }
+    unsafe {
+        let mut sumf = f32x4_splat(0f32);
+        for (x, y) in xs.iter().zip(ys.iter()) {
+            let mut q2: &[_] = &x.qs;
+            let mut q8: &[_] = &y.qs;
+            let sc = &x.scales;
+
+            let mut summs = i32x4_splat(0);
+            for i in (0..(QK_K / 16)).step_by(4) {
+                let bsums = i32x4_load_extend_i16x4(y.bsums.as_ptr().add(i));
+                let scales = i32x4_shr(
+                    i32x4(
+                        sc[i] as i32,
+                        sc[i + 1] as i32,
+                        sc[i + 2] as i32,
+                        sc[i + 3] as i32,
+                    ),
+                    4,
+                );
+                summs = i32x4_add(summs, i32x4_mul(bsums, scales))
+            }
+            let summs = f32x4_convert_i32x4(summs);
+
+            let dall = y.d * x.d.to_f32();
+            let dmin = y.d * x.dmin.to_f32();
+
+            let mut isum = i32x4_splat(0);
+            let mut is = 0;
+            for _ in 0..(QK_K / 128) {
+                let mut shift = 0;
+                for _ in 0..4 {
+                    let d = (sc[is] & 0xF) as i32;
+                    is += 1;
+                    let mut isuml = i16x8_splat(0);
+                    for l in (0..16).step_by(8) {
+                        let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(l));
+                        let q2 = i16x8_load_extend_u8x8(q2.as_ptr().add(l));
+                        let q2 = v128_and(i16x8_shr(q2, shift), i16x8_splat(3));
+                        isuml = i16x8_add(isuml, i16x8_mul(q2, q8))
+                    }
+                    let dd = i32x4_splat(d);
+                    isum = i32x4_add(isum, i32x4_mul(i32x4_extend_low_i16x8(isuml), dd));
+                    isum = i32x4_add(isum, i32x4_mul(i32x4_extend_high_i16x8(isuml), dd));
+                    let d = (sc[is] & 0xF) as i32;
+                    is += 1;
+                    let mut isuml = i16x8_splat(0);
+                    for l in (16..32).step_by(8) {
+                        let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(l));
+                        let q2 = i16x8_load_extend_u8x8(q2.as_ptr().add(l));
+                        let q2 = v128_and(i16x8_shr(q2, shift), i16x8_splat(3));
+                        isuml = i16x8_add(isuml, i16x8_mul(q2, q8))
+                    }
+                    let dd = i32x4_splat(d);
+                    isum = i32x4_add(isum, i32x4_mul(i32x4_extend_low_i16x8(isuml), dd));
+                    isum = i32x4_add(isum, i32x4_mul(i32x4_extend_high_i16x8(isuml), dd));
+                    shift += 2;
+                    // adjust the indexing
+                    q8 = &q8[32..];
+                }
+                // adjust the indexing
+                q2 = &q2[32..];
+            }
+            let isum = f32x4_convert_i32x4(isum);
+            sumf = f32x4_add(
+                sumf,
+                f32x4_sub(
+                    f32x4_mul(isum, f32x4_splat(dall)),
+                    f32x4_mul(summs, f32x4_splat(dmin)),
+                ),
+            );
+        }
+        let sumf = f32x4_extract_lane::<0>(sumf)
+            + f32x4_extract_lane::<1>(sumf)
+            + f32x4_extract_lane::<2>(sumf)
+            + f32x4_extract_lane::<3>(sumf);
+        Ok(sumf)
+    }
+}
+
+#[inline(always)]
+pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
+    if n % QK_K != 0 {
+        crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
+    }
+
+    const KMASK1: u32 = 0x3f3f3f3f;
+    const KMASK2: u32 = 0x0f0f0f0f;
+    const KMASK3: u32 = 0x03030303;
+
+    let mut utmp: [u32; 4] = [0; 4];
+    let mut scales: [u8; 8] = [0; 8];
+    let mut mins: [u8; 8] = [0; 8];
+
+    let mut aux8: [u8; QK_K] = [0; QK_K];
+    let mut sums = f32x4_splat(0f32);
+    unsafe {
+        for (y, x) in ys.iter().zip(xs.iter()) {
+            let q4 = &x.qs;
+            let q8 = &y.qs;
+
+            for j in 0..QK_K / 64 {
+                let q4_1 = v128_load(q4.as_ptr().add(32 * j) as *const v128);
+                let q4_2 = v128_load(q4.as_ptr().add(32 * j + 16) as *const v128);
+                v128_store(
+                    aux8.as_mut_ptr().add(64 * j) as *mut v128,
+                    v128_and(q4_1, u8x16_splat(0x0F)),
+                );
+                v128_store(
+                    aux8.as_mut_ptr().add(64 * j + 16) as *mut v128,
+                    v128_and(q4_2, u8x16_splat(0x0F)),
+                );
+                v128_store(
+                    aux8.as_mut_ptr().add(64 * j + 32) as *mut v128,
+                    u8x16_shr(q4_1, 4),
+                );
+                v128_store(
+                    aux8.as_mut_ptr().add(64 * j + 48) as *mut v128,
+                    u8x16_shr(q4_2, 4),
+                );
+            }
+
+            LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
+
+            utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
+            let uaux = utmp[1] & KMASK1;
+            utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
+            utmp[2] = uaux;
+            utmp[0] &= KMASK1;
+
+            //extract scales and mins
+            LittleEndian::write_u32_into(&utmp[0..2], &mut scales);
+            LittleEndian::write_u32_into(&utmp[2..4], &mut mins);
+
+            let mut sumi = i32x4_splat(0);
+            for j in (0..QK_K / 16).step_by(4) {
+                let bsums = i32x4_load_extend_i16x4(y.bsums.as_ptr().add(j));
+                let (m1, m2) = (mins[j / 2] as i32, mins[j / 2 + 1] as i32);
+                let mins = i32x4(m1, m1, m2, m2);
+                sumi = i32x4_add(sumi, i32x4_mul(bsums, mins));
+            }
+
+            let mut aux32 = i32x4_splat(0i32);
+            for (scale_i, scale) in scales.iter().enumerate() {
+                let scale = i32x4_splat(*scale as i32);
+                for j in 0..4 {
+                    let i = 32 * scale_i + 8 * j;
+                    let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(i));
+                    let aux8 = i16x8_load_extend_u8x8(aux8.as_ptr().add(i));
+                    let aux16 = i16x8_mul(q8, aux8);
+                    aux32 = i32x4_add(aux32, i32x4_mul(scale, i32x4_extend_low_i16x8(aux16)));
+                    aux32 = i32x4_add(aux32, i32x4_mul(scale, i32x4_extend_high_i16x8(aux16)));
+                }
+            }
+            let aux32 = f32x4_convert_i32x4(aux32);
+            let d = f32x4_splat(x.d.to_f32() * y.d);
+            sums = f32x4_add(sums, f32x4_mul(aux32, d));
+            let dmin = x.dmin.to_f32() * y.d;
+            let dmin = f32x4_splat(dmin);
+            let sumi = f32x4_convert_i32x4(sumi);
+            sums = f32x4_sub(sums, f32x4_mul(sumi, dmin));
+        }
+        let sums = f32x4_extract_lane::<0>(sums)
+            + f32x4_extract_lane::<1>(sums)
+            + f32x4_extract_lane::<2>(sums)
+            + f32x4_extract_lane::<3>(sums);
+        Ok(sums)
+    }
+}
+
+#[inline(always)]
+pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
+    if n % QK_K != 0 {
+        crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}")
+    }
+
+    let mut aux8 = [0i8; QK_K];
+    unsafe {
+        let mut sums = f32x4_splat(0f32);
+
+        for (x, y) in xs.iter().zip(ys.iter()) {
+            let q4 = &x.ql;
+            let qh = &x.qh;
+            let q8 = &y.qs;
+            let mut aux32 = f32x4_splat(0f32);
+
+            for j in (0..QK_K).step_by(128) {
+                let aux8 = aux8.as_mut_ptr().add(j);
+                let q4 = &q4.as_ptr().add(j / 2);
+                let qh = &qh.as_ptr().add(j / 4);
+                for l in (0..32).step_by(16) {
+                    // aux8[l] = (((q4[l] & 0xF) | ((qh[l] & 3) << 4)) as i32 - 32) as i8;
+                    let a8 = v128_or(
+                        v128_and(v128_load(q4.add(l) as *const v128), u8x16_splat(0xF)),
+                        u8x16_shl(
+                            v128_and(v128_load(qh.add(l) as *const v128), u8x16_splat(3)),
+                            4,
+                        ),
+                    );
+                    let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
+                    let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
+                    v128_store(
+                        aux8.add(l) as *mut v128,
+                        i8x16_narrow_i16x8(a8_low, a8_high),
+                    );
+
+                    // aux8[l + 32] =
+                    //     (((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i32 - 32) as i8;
+                    let a8 = v128_or(
+                        v128_and(v128_load(q4.add(l + 32) as *const v128), u8x16_splat(0xF)),
+                        u8x16_shl(
+                            v128_and(
+                                u8x16_shr(v128_load(qh.add(l) as *const v128), 2),
+                                u8x16_splat(3),
+                            ),
+                            4,
+                        ),
+                    );
+                    let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
+                    let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
+                    v128_store(
+                        aux8.add(l + 32) as *mut v128,
+                        i8x16_narrow_i16x8(a8_low, a8_high),
+                    );
+
+                    // aux8[l + 64] = (((q4[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i32 - 32) as i8;
+                    let a8 = v128_or(
+                        u8x16_shr(v128_load(q4.add(l) as *const v128), 4),
+                        u8x16_shl(
+                            v128_and(
+                                u8x16_shr(v128_load(qh.add(l) as *const v128), 4),
+                                u8x16_splat(3),
+                            ),
+                            4,
+                        ),
+                    );
+                    let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
+                    let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
+                    v128_store(
+                        aux8.add(l + 64) as *mut v128,
+                        i8x16_narrow_i16x8(a8_low, a8_high),
+                    );
+
+                    // aux8[l + 96] =
+                    //     (((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i32 - 32) as i8;
+                    let a8 = v128_or(
+                        u8x16_shr(v128_load(q4.add(l + 32) as *const v128), 4),
+                        u8x16_shl(
+                            v128_and(
+                                u8x16_shr(v128_load(qh.add(l) as *const v128), 6),
+                                u8x16_splat(3),
+                            ),
+                            4,
+                        ),
+                    );
+                    let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
+                    let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
+                    v128_store(
+                        aux8.add(l + 96) as *mut v128,
+                        i8x16_narrow_i16x8(a8_low, a8_high),
+                    );
+                }
+            }
+
+            for (j, &scale) in x.scales.iter().enumerate() {
+                let scale = f32x4_splat(scale as f32);
+                for offset in [0, 8] {
+                    let aux16 = i16x8_mul(
+                        i16x8_load_extend_i8x8(q8.as_ptr().add(16 * j + offset)),
+                        i16x8_load_extend_i8x8(aux8.as_ptr().add(16 * j + offset)),
+                    );
+                    aux32 = f32x4_add(
+                        aux32,
+                        f32x4_mul(f32x4_convert_i32x4(i32x4_extend_low_i16x8(aux16)), scale),
+                    );
+                    aux32 = f32x4_add(
+                        aux32,
+                        f32x4_mul(f32x4_convert_i32x4(i32x4_extend_high_i16x8(aux16)), scale),
+                    );
+                }
+            }
+
+            let d = f32x4_splat(x.d.to_f32() * y.d);
+            sums = f32x4_add(sums, f32x4_mul(aux32, d));
|
||||||
|
}
|
||||||
|
let sums = f32x4_extract_lane::<0>(sums)
|
||||||
|
+ f32x4_extract_lane::<1>(sums)
|
||||||
|
+ f32x4_extract_lane::<2>(sums)
|
||||||
|
+ f32x4_extract_lane::<3>(sums);
|
||||||
|
Ok(sums)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
|
||||||
|
let qk = QK_K;
|
||||||
|
if n % QK_K != 0 {
|
||||||
|
crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
let mut acc = f32x4_splat(0.0f32);
|
||||||
|
for (xs, ys) in xs.iter().zip(ys.iter()) {
|
||||||
|
let x_qs = xs.qs.as_ptr();
|
||||||
|
let y_qs = ys.qs.as_ptr();
|
||||||
|
let mut sumi = i32x4_splat(0);
|
||||||
|
for j in (0..QK_K).step_by(8) {
|
||||||
|
let xs = i16x8_load_extend_i8x8(x_qs.add(j));
|
||||||
|
let ys = i16x8_load_extend_i8x8(y_qs.add(j));
|
||||||
|
let sum_xy = i32x4_dot_i16x8(xs, ys);
|
||||||
|
sumi = i32x4_add(sumi, sum_xy)
|
||||||
|
}
|
||||||
|
let d = f32x4_splat(xs.d * ys.d);
|
||||||
|
acc = f32x4_add(acc, f32x4_mul(f32x4_convert_i32x4(sumi), d))
|
||||||
|
}
|
||||||
|
let res = f32x4_extract_lane::<0>(acc)
|
||||||
|
+ f32x4_extract_lane::<1>(acc)
|
||||||
|
+ f32x4_extract_lane::<2>(acc)
|
||||||
|
+ f32x4_extract_lane::<3>(acc);
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
}
|
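The three kernels above share one pattern: widen the 8-bit quantized values to 16-bit lanes, accumulate integer dot products into i32 lanes, and only at the end rescale by the per-block f32 scales. As an editor's aid, here is a scalar sketch of the simplest case, the q8k x q8k product, written against a simplified `(scale, quants)` block layout rather than the crate's actual `BlockQ8K` struct:

```rust
/// Scalar model of vec_dot_q8k_q8k: per block, an integer dot product of the
/// i8 quants, rescaled by the product of the two per-block f32 scales.
/// The (f32, Vec<i8>) pairs stand in for BlockQ8K's `d` and `qs` fields.
fn q8k_dot_scalar(xs: &[(f32, Vec<i8>)], ys: &[(f32, Vec<i8>)]) -> f32 {
    let mut acc = 0f32;
    for ((xd, xqs), (yd, yqs)) in xs.iter().zip(ys.iter()) {
        let sumi: i32 = xqs
            .iter()
            .zip(yqs.iter())
            .map(|(&x, &y)| i32::from(x) * i32::from(y))
            .sum();
        acc += sumi as f32 * xd * yd;
    }
    acc
}
```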
@@ -17,7 +17,7 @@ pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>(
     let expected_blocks = xs.len() / block_size;
     let actual_blocks = ys.len();
 
-    //validate that the input is the right size
+    // Validate that the input is the right size
     if expected_blocks != actual_blocks {
         crate::bail!("quantize {dtype:?}: expected {expected_blocks} blocks but only {actual_blocks} were provided!")
     }
@@ -37,12 +37,12 @@ pub(super) fn group_for_dequantization<'a, 'b, T: super::k_quants::GgmlType>(
 
     let actual_output_len = ys.len();
     let expected_output_len = xs.len() * block_size;
-    //validate that the output is the right size
+    // Validate that the output is the right size
     if expected_output_len != actual_output_len {
         crate::bail!("dequantize {dtype:?}: ys (len = {actual_output_len}) does not match the expected length of {expected_output_len}!")
     }
 
-    //zip the blocks and outputs together
+    // Zip the blocks and outputs together
     Ok(xs.iter().zip(ys.chunks_exact_mut(block_size)).collect())
 }
@@ -78,11 +78,7 @@ impl st::View for &Tensor {
     }
 }
 
 impl Tensor {
-    pub fn save_safetensors<P: AsRef<std::path::Path>>(
-        &self,
-        name: &str,
-        filename: P,
-    ) -> Result<()> {
+    pub fn save_safetensors<P: AsRef<Path>>(&self, name: &str, filename: P) -> Result<()> {
         let data = [(name, self.clone())];
         Ok(st::serialize_to_file(data, &None, filename.as_ref())?)
     }
@@ -255,6 +251,134 @@ pub fn save<K: AsRef<str> + Ord + std::fmt::Display, P: AsRef<Path>>(
     Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?)
 }
 
+#[derive(yoke::Yokeable)]
+struct SafeTensors_<'a>(SafeTensors<'a>);
+
+pub struct MmapedSafetensors {
+    safetensors: Vec<yoke::Yoke<SafeTensors_<'static>, memmap2::Mmap>>,
+    routing: Option<HashMap<String, usize>>,
+}
+
+impl MmapedSafetensors {
+    /// Creates a wrapper around a memory mapped file and deserializes the safetensors header.
+    ///
+    /// # Safety
+    ///
+    /// The unsafe is inherited from [`memmap2::MmapOptions`].
+    pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
+        let p = p.as_ref();
+        let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
+        let file = memmap2::MmapOptions::new()
+            .map(&file)
+            .map_err(|e| Error::from(e).with_path(p))?;
+        let safetensors = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
+            file,
+            |data: &[u8]| {
+                let st = safetensors::SafeTensors::deserialize(data)
+                    .map_err(|e| Error::from(e).with_path(p))?;
+                Ok::<_, Error>(SafeTensors_(st))
+            },
+        )?;
+        Ok(Self {
+            safetensors: vec![safetensors],
+            routing: None,
+        })
+    }
+
+    /// Creates a wrapper around multiple memory mapped files and deserializes the safetensors headers.
+    ///
+    /// If a tensor name appears in multiple files, the last entry is returned.
+    ///
+    /// # Safety
+    ///
+    /// The unsafe is inherited from [`memmap2::MmapOptions`].
+    pub unsafe fn multi<P: AsRef<Path>>(paths: &[P]) -> Result<Self> {
+        let mut routing = HashMap::new();
+        let mut safetensors = vec![];
+        for (index, p) in paths.iter().enumerate() {
+            let p = p.as_ref();
+            let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
+            let file = memmap2::MmapOptions::new()
+                .map(&file)
+                .map_err(|e| Error::from(e).with_path(p))?;
+            let data = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
+                file,
+                |data: &[u8]| {
+                    let st = safetensors::SafeTensors::deserialize(data)
+                        .map_err(|e| Error::from(e).with_path(p))?;
+                    Ok::<_, Error>(SafeTensors_(st))
+                },
+            )?;
+            for k in data.get().0.names() {
+                routing.insert(k.to_string(), index);
+            }
+            safetensors.push(data)
+        }
+        Ok(Self {
+            safetensors,
+            routing: Some(routing),
+        })
+    }
+
+    pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
+        self.get(name)?.load(dev)
+    }
+
+    pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
+        let mut tensors = vec![];
+        for safetensors in self.safetensors.iter() {
+            tensors.push(safetensors.get().0.tensors())
+        }
+        tensors.into_iter().flatten().collect()
+    }
+
+    pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
+        let index = match &self.routing {
+            None => 0,
+            Some(routing) => {
+                let index = routing.get(name).ok_or_else(|| {
+                    Error::CannotFindTensor {
+                        path: name.to_string(),
+                    }
+                    .bt()
+                })?;
+                *index
+            }
+        };
+        Ok(self.safetensors[index].get().0.tensor(name)?)
+    }
+}
+
+pub struct BufferedSafetensors {
+    safetensors: yoke::Yoke<SafeTensors_<'static>, Vec<u8>>,
+}
+
+impl BufferedSafetensors {
+    /// Creates a wrapper around a binary buffer and deserializes the safetensors header.
+    pub fn new(buffer: Vec<u8>) -> Result<Self> {
+        let safetensors = yoke::Yoke::<SafeTensors_<'static>, Vec<u8>>::try_attach_to_cart(
+            buffer,
+            |data: &[u8]| {
+                let st = safetensors::SafeTensors::deserialize(data)?;
+                Ok::<_, Error>(SafeTensors_(st))
+            },
+        )?;
+        Ok(Self { safetensors })
+    }
+
+    pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
+        self.get(name)?.load(dev)
+    }
+
+    pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
+        self.safetensors.get().0.tensors()
+    }
+
+    pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
+        Ok(self.safetensors.get().0.tensor(name)?)
+    }
+}
+
 pub struct MmapedFile {
     path: std::path::PathBuf,
     inner: memmap2::Mmap,
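For context, a minimal sketch of how the new mmap-based loader is meant to be used; the shard file names and tensor name here are hypothetical:

```rust
use candle_core::{safetensors::MmapedSafetensors, Device, Result};

fn load_example() -> Result<()> {
    // Hypothetical sharded checkpoint; any .safetensors files would do.
    let paths = ["model-00001.safetensors", "model-00002.safetensors"];
    // Safety: the files must not be mutated while the mmap is alive,
    // the same contract as memmap2::MmapOptions.
    let st = unsafe { MmapedSafetensors::multi(&paths)? };
    // Lookups are routed by name across shards; the last shard wins on clashes.
    let t = st.load("embeddings.weight", &Device::Cpu)?;
    println!("{:?}", t.shape());
    Ok(())
}
```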
@@ -267,7 +391,7 @@ impl MmapedFile {
     /// # Safety
     ///
     /// The unsafe is inherited from [`memmap2::MmapOptions`].
-    pub unsafe fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
+    pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
         let p = p.as_ref();
         let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
         let inner = memmap2::MmapOptions::new()
@@ -444,6 +444,18 @@ impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim> Dims for (D1, D2, D3, D4, D5)
     }
 }
 
+impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim, D6: Dim> Dims for (D1, D2, D3, D4, D5, D6) {
+    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
+        let d0 = self.0.to_index(shape, op)?;
+        let d1 = self.1.to_index(shape, op)?;
+        let d2 = self.2.to_index(shape, op)?;
+        let d3 = self.3.to_index(shape, op)?;
+        let d4 = self.4.to_index(shape, op)?;
+        let d5 = self.5.to_index(shape, op)?;
+        Ok(vec![d0, d1, d2, d3, d4, d5])
+    }
+}
+
 extract_dims!(dims0, 0, |_: &[usize]| (), ());
 extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
 extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
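This extends `Dims` to six-element tuples, so a reduction can now name all six axes of a rank-6 tensor at once. A small sketch (the shape is arbitrary; `sum` squeezing the reduced dims down to a scalar is assumed, as elsewhere in the crate):

```rust
use candle_core::{DType, Device, Tensor};

fn dims6_example() -> candle_core::Result<()> {
    // A rank-6 tensor of ones; summing over all six axes yields a scalar
    // equal to the element count: 2 * 1 * 3 * 1 * 2 * 2 = 24.
    let t = Tensor::ones(vec![2usize, 1, 3, 1, 2, 2], DType::F32, &Device::Cpu)?;
    let s = t.sum((0, 1, 2, 3, 4, 5))?;
    assert_eq!(s.to_scalar::<f32>()?, 24.0);
    Ok(())
}
```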
@@ -369,6 +369,19 @@ impl Storage {
         }
     }
 
+    pub(crate) fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
+        match self {
+            Storage::Cpu(storage) => {
+                let storage = storage.upsample_nearest1d(layout, sz)?;
+                Ok(Self::Cpu(storage))
+            }
+            Self::Cuda(storage) => {
+                let storage = storage.upsample_nearest1d(layout, sz)?;
+                Ok(Self::Cuda(storage))
+            }
+        }
+    }
+
     pub(crate) fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
         match self {
             Storage::Cpu(storage) => {
@@ -105,6 +105,28 @@ macro_rules! binary_op {
     };
 }
 
+macro_rules! binary_op_scalar {
+    ($fn_name:ident, $op_name:ident) => {
+        pub fn $fn_name<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
+            let rhs = match rhs.to_tensor_scalar()? {
+                crate::scalar::TensorScalar::Tensor(rhs) => rhs,
+                crate::scalar::TensorScalar::Scalar(rhs) => rhs
+                    .to_dtype(self.dtype())?
+                    .to_device(self.device())?
+                    .broadcast_as(self.shape())?,
+            };
+            let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?;
+            let storage = self.storage().binary_impl::<crate::op::$op_name>(
+                &*rhs.storage(),
+                self.layout(),
+                rhs.layout(),
+            )?;
+            let op = BackpropOp::new2(self, &rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name));
+            Ok(from_storage(storage, shape.clone(), op, false))
+        }
+    };
+}
+
 macro_rules! broadcast_binary_op {
     ($fn_name:ident, $inner_fn_name:ident) => {
         pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
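Unlike `binary_op!`, this variant accepts either a tensor or a plain number on the right-hand side; a scalar is promoted to a broadcast tensor with matching dtype and device. A sketch of what that enables once `maximum`/`minimum` are switched over below:

```rust
use candle_core::{Device, Tensor};

fn scalar_rhs_example() -> candle_core::Result<()> {
    let t = Tensor::new(&[-1.5f32, 0.0, 2.5], &Device::Cpu)?;
    // Scalar right-hand side: 0f32 is broadcast internally, giving a relu.
    let relu = t.maximum(0f32)?;
    assert_eq!(relu.to_vec1::<f32>()?, [0.0, 0.0, 2.5]);
    // A tensor right-hand side still works as before.
    let other = Tensor::new(&[0.0f32, 1.0, 1.0], &Device::Cpu)?;
    assert_eq!(t.maximum(&other)?.to_vec1::<f32>()?, [0.0, 1.0, 2.5]);
    Ok(())
}
```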
@@ -155,14 +177,9 @@ impl Tensor {
         is_variable: bool,
     ) -> Result<Self> {
         let none = BackpropOp::none();
-        if is_variable {
-            let shape = shape.into();
-            let storage = device.ones(&shape, dtype)?;
-            Ok(from_storage(storage, shape, none, is_variable))
-        } else {
-            let storage = device.ones(&crate::shape::SCALAR, dtype)?;
-            from_storage(storage, crate::shape::SCALAR, none, is_variable).broadcast_as(shape)
-        }
+        let shape = shape.into();
+        let storage = device.ones(&shape, dtype)?;
+        Ok(from_storage(storage, shape, none, is_variable))
     }
 
     /// Creates a new tensor filled with ones.
@@ -200,14 +217,9 @@ impl Tensor {
         is_variable: bool,
     ) -> Result<Self> {
         let none = BackpropOp::none();
-        if is_variable {
-            let shape = shape.into();
-            let storage = device.zeros(&shape, dtype)?;
-            Ok(from_storage(storage, shape, none, is_variable))
-        } else {
-            let storage = device.zeros(&crate::shape::SCALAR, dtype)?;
-            from_storage(storage, crate::shape::SCALAR, none, is_variable).broadcast_as(shape)
-        }
+        let shape = shape.into();
+        let storage = device.zeros(&shape, dtype)?;
+        Ok(from_storage(storage, shape, none, is_variable))
     }
 
     /// Creates a new tensor filled with zeros.
@@ -447,8 +459,8 @@ impl Tensor {
     binary_op!(mul, Mul);
     binary_op!(sub, Sub);
    binary_op!(div, Div);
-    binary_op!(maximum, Maximum);
-    binary_op!(minimum, Minimum);
+    binary_op_scalar!(maximum, Maximum);
+    binary_op_scalar!(minimum, Minimum);
     broadcast_binary_op!(broadcast_add, add);
     broadcast_binary_op!(broadcast_mul, mul);
     broadcast_binary_op!(broadcast_sub, sub);
@@ -467,7 +479,21 @@ impl Tensor {
     unary_op!(sqr, Sqr);
     unary_op!(sqrt, Sqrt);
     unary_op!(gelu, Gelu);
+    unary_op!(gelu_erf, GeluErf);
+    unary_op!(erf, Erf);
     unary_op!(relu, Relu);
+    unary_op!(ceil, Ceil);
+    unary_op!(floor, Floor);
+    unary_op!(round, Round);
+
+    /// Round elements of the input tensor to the nearest integer.
+    ///
+    /// If the number of decimals is negative, it specifies the number of positions to the left of
+    /// the decimal point.
+    pub fn round_to(&self, decimals: i32) -> Result<Self> {
+        let mult = 10f64.powi(decimals);
+        (self * mult)?.round()? * (1f64 / mult)
+    }
 
     /// Retrieves the single scalar value held in the tensor. If the tensor contains multiple
     /// dimensions, an error is returned instead.
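`round_to` composes three steps: scale by `10^decimals`, round to the nearest integer, then unscale. A small sketch reusing the values from the test suite further down:

```rust
use candle_core::{Device, Tensor};

fn round_to_example() -> candle_core::Result<()> {
    let t = Tensor::new(&[2997.9246f32, 314.15926], &Device::Cpu)?;
    // decimals = 2: t * 100 is rounded, then divided by 100 again.
    println!("{:?}", t.round_to(2)?.to_vec1::<f32>()?); // ~[2997.92, 314.16]
    // Negative decimals round to the left of the decimal point instead.
    println!("{:?}", t.round_to(-2)?.to_vec1::<f32>()?); // ~[3000.0, 300.0]
    Ok(())
}
```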
@@ -644,7 +670,12 @@ impl Tensor {
         let storage = self.storage().reduce_op(op, self.layout(), &[dim])?;
         let mut dims = self.dims().to_vec();
         dims[dim] = 1;
-        let op = BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()));
+        let op = match op {
+            ReduceOp::Sum | ReduceOp::Min | ReduceOp::Max => {
+                BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()))
+            }
+            ReduceOp::ArgMin | ReduceOp::ArgMax => BackpropOp::none(),
+        };
         let res = from_storage(storage, dims, op, false);
         if keepdim {
             Ok(res)
@@ -827,12 +858,35 @@ impl Tensor {
         self.cmp(rhs, CmpOp::Le)
     }
 
-    /// Upsample the input tensor to the `(target_h, target_w)` size, taking the value of the
+    /// Clamp the tensor values to be between `min` and `max`.
+    pub fn clamp<T1: TensorOrScalar, T2: TensorOrScalar>(&self, min: T1, max: T2) -> Result<Self> {
+        self.maximum(min)?.minimum(max)
+    }
+
+    /// Interpolate the input tensor to the `target_size` size, taking the value of the nearest element.
+    ///
+    /// The input tensor should have three dimensions, `(batch, channels, l)`, the returned
+    /// tensor also has three dimensions, `(batch, channels, target_size)`.
+    pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
+        let (n, c, _l) = self.dims3()?;
+        let op = BackpropOp::new1(self, Op::UpsampleNearest1D);
+        let storage = self
+            .storage()
+            .upsample_nearest1d(self.layout(), target_size)?;
+        Ok(from_storage(storage, (n, c, target_size), op, false))
+    }
+
+    /// Alias for `interpolate1d`.
+    pub fn upsample_nearest1d(&self, target_size: usize) -> Result<Self> {
+        self.interpolate1d(target_size)
+    }
+
+    /// Interpolate the input tensor to the `(target_h, target_w)` size, taking the value of the
     /// nearest element.
     ///
     /// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
     /// tensor also has four dimensions, `(batch, channels, target_h, target_w)`.
-    pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
+    pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
         let (n, c, _h, _w) = self.dims4()?;
         let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
         let storage = self
@@ -841,6 +895,11 @@ impl Tensor {
         Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
     }
 
+    /// Alias for `interpolate2d`.
+    pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
+        self.interpolate2d(target_h, target_w)
+    }
+
     /// 2D average pooling over an input tensor with multiple channels.
     ///
     /// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
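A short usage sketch of the new `clamp` and `interpolate1d` entry points; the clamp values mirror the test added later in this changeset, and the interpolation part only checks the output shape:

```rust
use candle_core::{Device, Tensor};

fn clamp_interpolate_example() -> candle_core::Result<()> {
    let dev = Device::Cpu;
    // clamp takes scalars (or tensors) for both bounds.
    let t = Tensor::new(&[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]], &dev)?;
    let clamped = t.clamp(1.5, 6.2)?;
    assert_eq!(
        clamped.to_vec2::<f32>()?,
        [[3.0, 1.5, 4.0, 1.5, 5.0], [2.0, 1.5, 6.2, 6.2, 2.0]],
    );
    // interpolate1d: (batch, channels, l) -> (batch, channels, target_size).
    let t = Tensor::arange(0f32, 6f32, &dev)?.reshape((1, 1, 6))?;
    assert_eq!(t.interpolate1d(12)?.dims(), [1, 1, 12]);
    Ok(())
}
```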
@@ -1075,6 +1134,74 @@ impl Tensor {
         Ok(from_storage(storage, self.shape(), op, false))
     }
 
+    /// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension.
+    pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: D, start: usize) -> Result<Self> {
+        let dim = dim.to_index(self.shape(), "slice-scatter")?;
+        if dim == 0 {
+            self.slice_scatter0(src, start)
+        } else {
+            // TODO: Maybe we want to add a more efficient implementation at some point.
+            self.transpose(0, dim)?
+                .slice_scatter0(&src.transpose(0, dim)?, start)?
+                .transpose(0, dim)
+        }
+    }
+
+    /// Embeds the values of the `src` tensor into the `self` tensor on the first dimension.
+    pub fn slice_scatter0(&self, src: &Self, start: usize) -> Result<Self> {
+        if self.dtype() != src.dtype() {
+            Err(Error::DTypeMismatchBinaryOp {
+                lhs: self.dtype(),
+                rhs: src.dtype(),
+                op: "slice-scatter",
+            }
+            .bt())?
+        }
+        if self.device().location() != src.device().location() {
+            Err(Error::DeviceMismatchBinaryOp {
+                lhs: self.device().location(),
+                rhs: src.device().location(),
+                op: "slice-scatter",
+            }
+            .bt())?
+        }
+        if self.rank() != src.rank() {
+            Err(Error::UnexpectedNumberOfDims {
+                expected: self.rank(),
+                got: src.rank(),
+                shape: src.shape().clone(),
+            }
+            .bt())?
+        }
+        let shape_ok =
+            self.dims()
+                .iter()
+                .zip(src.dims().iter())
+                .enumerate()
+                .all(|(dim_idx, (&d1, &d2))| {
+                    if 0 == dim_idx {
+                        d2 + start <= d1
+                    } else {
+                        d1 == d2
+                    }
+                });
+        if !shape_ok {
+            Err(Error::ShapeMismatchBinaryOp {
+                op: "slice-scatter (self, src)",
+                lhs: self.shape().clone(),
+                rhs: src.shape().clone(),
+            })?
+        }
+        let mut storage = self.device().zeros(self.shape(), self.dtype())?;
+        self.storage()
+            .copy_strided_src(&mut storage, 0, self.layout())?;
+        let offset = start * src.dims()[1..].iter().product::<usize>();
+        src.storage()
+            .copy_strided_src(&mut storage, offset, src.layout())?;
+        let op = BackpropOp::new2(self, src, |t1, t2| Op::SliceScatter0(t1, t2, start));
+        Ok(from_storage(storage, self.shape(), op, false))
+    }
+
     /// Accumulate element from `source` at indexes `indexes` and add them to `self`.
     pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
         let dim = dim.to_index(self.shape(), "index-add")?;
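A usage sketch of `slice_scatter0`, with the same data as the test added later in this changeset: rows `start..start + 2` of `self` are replaced by the two rows of `src`, the rest are copied through:

```rust
use candle_core::{Device, Tensor};

fn slice_scatter_example() -> candle_core::Result<()> {
    let dev = Device::Cpu;
    let t = Tensor::arange(0f32, 12f32, &dev)?.reshape((4, 3))?;
    let src = Tensor::arange(100f32, 106f32, &dev)?.reshape((2, 3))?;
    let out = t.slice_scatter0(&src, 1)?;
    assert_eq!(
        out.to_vec2::<f32>()?,
        [
            [0.0, 1.0, 2.0],
            [100.0, 101.0, 102.0],
            [103.0, 104.0, 105.0],
            [9.0, 10.0, 11.0]
        ]
    );
    Ok(())
}
```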
@@ -1491,6 +1618,9 @@ impl Tensor {
     pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor> {
         let dim1 = dim1.to_index(self.shape(), "transpose")?;
         let dim2 = dim2.to_index(self.shape(), "transpose")?;
+        if dim1 == dim2 {
+            return Ok(self.clone());
+        }
         let op = BackpropOp::new1(self, |t| Op::Transpose(t, dim1, dim2));
         let tensor_ = Tensor_ {
             id: TensorId::new(),
@@ -1852,6 +1982,34 @@ impl Tensor {
         for arg in args {
             arg.as_ref().check_dim(dim, "cat")?;
         }
+        for (arg_idx, arg) in args.iter().enumerate() {
+            let arg = arg.as_ref();
+            if arg0.rank() != arg.rank() {
+                Err(Error::UnexpectedNumberOfDims {
+                    expected: arg0.rank(),
+                    got: arg.rank(),
+                    shape: arg.shape().clone(),
+                }
+                .bt())?
+            }
+            for (dim_idx, (v1, v2)) in arg0
+                .shape()
+                .dims()
+                .iter()
+                .zip(arg.shape().dims().iter())
+                .enumerate()
+            {
+                if dim_idx != dim && v1 != v2 {
+                    Err(Error::ShapeMismatchCat {
+                        dim: dim_idx,
+                        first_shape: arg0.shape().clone(),
+                        n: arg_idx + 1,
+                        nth_shape: arg.shape().clone(),
+                    }
+                    .bt())?
+                }
+            }
+        }
         if dim == 0 {
             Self::cat0(args)
         } else {
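With these checks in place, `cat` fails early when the non-concatenated dimensions disagree instead of producing a malformed tensor; an illustrative sketch:

```rust
use candle_core::{DType, Device, Tensor};

fn cat_checks_example() -> candle_core::Result<()> {
    let dev = Device::Cpu;
    let a = Tensor::zeros((2, 3), DType::F32, &dev)?;
    let b = Tensor::zeros((4, 3), DType::F32, &dev)?;
    // Fine: the shapes only differ on the concatenation dimension.
    assert_eq!(Tensor::cat(&[&a, &b], 0)?.dims(), [6, 3]);
    // Rejected: dim 1 differs (3 vs 4) while concatenating on dim 0.
    let c = Tensor::zeros((2, 4), DType::F32, &dev)?;
    assert!(Tensor::cat(&[&a, &c], 0).is_err());
    Ok(())
}
```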
@@ -218,6 +218,22 @@ fn binary_grad(device: &Device) -> Result<()> {
     let grad_x = grads.get(x).context("no grad for x")?;
     assert_eq!(y.to_vec1::<f32>()?, [3., 1., -4., -1.]);
     assert_eq!(grad_x.to_vec1::<f32>()?, [1., 1., 1., 1.]);
 
+    let x_var = Var::new(&[3f32, 1., -4., -1., 5., 9.], device)?;
+    let x = x_var.as_tensor();
+    let y_var = Var::new(&[2f32, 7., 1.], device)?;
+    let y = y_var.as_tensor();
+
+    let ss = x
+        .reshape((2, 3))?
+        .slice_scatter0(&y.reshape((1, 3))?, 1)?
+        .sqr()?;
+    let grads = ss.backward()?;
+    let grad_x = grads.get(x).context("no grad for x")?;
+    let grad_y = grads.get(y).context("no grad for y")?;
+    assert_eq!(ss.to_vec2::<f32>()?, [[9., 1., 16.], [4., 49., 1.]]);
+    assert_eq!(grad_x.to_vec1::<f32>()?, [6.0, 2.0, -8.0, 0.0, 0.0, 0.0]);
+    assert_eq!(grad_y.to_vec1::<f32>()?, [4.0, 14.0, 2.0]);
     Ok(())
 }
|
|
||||||
|
9
candle-core/tests/npy.py
Normal file
9
candle-core/tests/npy.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
import numpy as np
|
||||||
|
x = np.arange(10)
|
||||||
|
|
||||||
|
# Write a npy file.
|
||||||
|
np.save("test.npy", x)
|
||||||
|
|
||||||
|
# Write multiple values to a npz file.
|
||||||
|
values = { "x": x, "x_plus_one": x + 1 }
|
||||||
|
np.savez("test.npz", **values)
|
@@ -43,7 +43,7 @@ fn quantized_matmul() -> Result<()> {
     );
 
     let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
-    let matmul = quantized::QMatMul::from_qtensor(qtensor);
+    let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
     let res = matmul.forward(&tensor_lhs)?;
     assert_eq!(
         to_vec2_round(&res, 0)?,
@@ -91,7 +91,7 @@ fn quantized_matmul_neg() -> Result<()> {
     );
 
     let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
-    let matmul = quantized::QMatMul::from_qtensor(qtensor);
+    let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
     let res = matmul.forward(&tensor_lhs)?;
     assert_eq!(
         to_vec2_round(&res, 0)?,
@@ -491,6 +491,9 @@ fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result<f32> {
         GgmlDType::Q5_0 => 0.001353,
         GgmlDType::Q5_1 => 0.001363,
         GgmlDType::Q8_0 => 0.000092,
+
+        // Not from the ggml repo.
+        GgmlDType::Q8K => 0.00065,
         _ => candle_core::bail!("No GGML results for quantization type {dtype:?}",),
     };
     Ok(err)
@@ -508,17 +511,22 @@ fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
     T::VecDotType::from_float(&b, &mut b_quant)?;
 
     let result = T::vec_dot(length, &a_quant, &b_quant)?;
+    let result_unopt = T::vec_dot_unopt(length, &a_quant, &b_quant)?;
     let reference_result = vec_dot_reference(&a, &b);
 
+    if (result - result_unopt).abs() / length as f32 > 1e-6 {
+        candle_core::bail!(
+            "the opt and unopt vec-dot returned different values, opt {result}, unopt {result_unopt}"
+        )
+    }
+
     let error = (result - reference_result).abs() / length as f32;
 
     let ggml_error = ggml_reference_matmul_error(T::DTYPE)?;
 
-    if error > GGML_MAX_DOT_PRODUCT_ERROR {
+    if !error.is_finite() || error > GGML_MAX_DOT_PRODUCT_ERROR {
         candle_core::bail!(
-            "Dot product error {} exceeds max error {}",
-            error,
-            GGML_MAX_DOT_PRODUCT_ERROR
+            "Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",
         );
     }
@@ -571,7 +579,7 @@ fn quantized_matmul_q2k() -> Result<()> {
     assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
 
     let rhs = quantized::QTensor::quantize::<BlockQ2K>(&rhs)?;
-    let rhs = quantized::QMatMul::from_qtensor(rhs);
+    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
     let mm = rhs.forward(&lhs)?;
 
     assert_eq!(mm.dims(), [m, n]);
@@ -597,7 +605,7 @@ fn quantized_matmul_q3k() -> Result<()> {
     assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
 
     let rhs = quantized::QTensor::quantize::<BlockQ3K>(&rhs)?;
-    let rhs = quantized::QMatMul::from_qtensor(rhs);
+    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
     let mm = rhs.forward(&lhs)?;
 
     assert_eq!(mm.dims(), [m, n]);
@@ -623,7 +631,7 @@ fn quantized_matmul_q4k() -> Result<()> {
     assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
 
     let rhs = quantized::QTensor::quantize::<BlockQ4K>(&rhs)?;
-    let rhs = quantized::QMatMul::from_qtensor(rhs);
+    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
     let mm = rhs.forward(&lhs)?;
 
     assert_eq!(mm.dims(), [m, n]);
@@ -649,7 +657,7 @@ fn quantized_matmul_q5k() -> Result<()> {
     assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
 
     let rhs = quantized::QTensor::quantize::<BlockQ5K>(&rhs)?;
-    let rhs = quantized::QMatMul::from_qtensor(rhs);
+    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
     let mm = rhs.forward(&lhs)?;
 
     assert_eq!(mm.dims(), [m, n]);
@@ -676,7 +684,7 @@ fn quantized_matmul_q6k() -> Result<()> {
     assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
 
     let rhs = quantized::QTensor::quantize::<BlockQ6K>(&rhs)?;
-    let rhs = quantized::QMatMul::from_qtensor(rhs);
+    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
     let mm = rhs.forward(&lhs)?;
 
     assert_eq!(mm.dims(), [m, n]);
@@ -687,3 +695,28 @@ fn quantized_matmul_q6k() -> Result<()> {
     ggml_matmul_error_test::<BlockQ6K>()?;
     Ok(())
 }
+
+#[test]
+fn quantized_matmul_q8k() -> Result<()> {
+    use k_quants::BlockQ8K;
+
+    let cpu = &Device::Cpu;
+    let (m, k, n) = (11, 512, 21);
+    let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?;
+    assert_eq!(mm.dims(), [m, n]);
+    let dst = mm.flatten_all()?.to_vec1::<f32>()?;
+    let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
+    assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
+
+    let rhs = quantized::QTensor::quantize::<BlockQ8K>(&rhs)?;
+    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
+    let mm = rhs.forward(&lhs)?;
+
+    assert_eq!(mm.dims(), [m, n]);
+    let dst = mm.flatten_all()?.to_vec1::<f32>()?;
+    let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
+    assert_eq!(dst, [1.266, 1.504, -0.204, 1.7]);
+
+    ggml_matmul_error_test::<BlockQ8K>()?;
+    Ok(())
+}
candle-core/tests/serialization_tests.rs (new file, 24 lines)
@@ -0,0 +1,24 @@
+use candle_core::{DType, Result, Tensor};
+
+#[test]
+fn npy() -> Result<()> {
+    let npy = Tensor::read_npy("tests/test.npy")?;
+    assert_eq!(
+        npy.to_dtype(DType::U8)?.to_vec1::<u8>()?,
+        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+    );
+    Ok(())
+}
+
+#[test]
+fn npz() -> Result<()> {
+    let npz = Tensor::read_npz("tests/test.npz")?;
+    assert_eq!(npz.len(), 2);
+    assert_eq!(npz[0].0, "x");
+    assert_eq!(npz[1].0, "x_plus_one");
+    assert_eq!(
+        npz[1].1.to_dtype(DType::U8)?.to_vec1::<u8>()?,
+        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+    );
+    Ok(())
+}
@@ -1,4 +1,4 @@
-use candle_core::{test_device, DType, Device, IndexOp, Result, Tensor};
+use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor};
 
 fn zeros(device: &Device) -> Result<()> {
     let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
@@ -8,6 +8,31 @@ fn zeros(device: &Device) -> Result<()> {
     Ok(())
 }
 
+fn ones(device: &Device) -> Result<()> {
+    assert_eq!(
+        Tensor::ones((2, 3), DType::U8, device)?.to_vec2::<u8>()?,
+        [[1, 1, 1], [1, 1, 1]],
+    );
+    assert_eq!(
+        Tensor::ones((2, 3), DType::U32, device)?.to_vec2::<u32>()?,
+        [[1, 1, 1], [1, 1, 1]],
+    );
+    assert_eq!(
+        Tensor::ones((2, 3), DType::I64, device)?.to_vec2::<i64>()?,
+        [[1, 1, 1], [1, 1, 1]],
+    );
+    assert_eq!(
+        Tensor::ones((2, 3), DType::F32, device)?.to_vec2::<f32>()?,
+        [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
+    );
+    assert_eq!(
+        Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
+        [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
+    );
+
+    Ok(())
+}
+
 fn add_mul(device: &Device) -> Result<()> {
     let tensor = Tensor::new(&[3f32, 1., 4.], device)?;
     let dim1 = tensor.dims1()?;
@@ -33,6 +58,65 @@ fn tensor_2d(device: &Device) -> Result<()> {
     Ok(())
 }
 
+fn clamp(device: &Device) -> Result<()> {
+    let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
+    let tensor = Tensor::new(data, device)?;
+    let tensor = tensor.clamp(1.5, 6.2)?;
+    assert_eq!(
+        tensor.to_vec2::<f32>()?,
+        [[3.0, 1.5, 4.0, 1.5, 5.0], [2.0, 1.5, 6.2, 6.2, 2.0]],
+    );
+    Ok(())
+}
+
+fn unary_op(device: &Device) -> Result<()> {
+    let data = &[[-3f32, 1., 4., -0.1, 0.5], [2.7, -1.8, -0.28, 1.8, 2.8]];
+    let tensor = Tensor::new(data, device)?;
+    assert_eq!(
+        test_utils::to_vec2_round(&tensor.gelu()?, 4)?,
+        [
+            [-0.0036, 0.8412, 3.9999, -0.046, 0.3457],
+            [2.6911, -0.0647, -0.1091, 1.7353, 2.7933]
+        ]
+    );
+    assert_eq!(
+        test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?,
+        [
+            [-0.004, 0.8413, 3.9999, -0.046, 0.3457],
+            [2.6906, -0.0647, -0.1091, 1.7353, 2.7928]
+        ]
+    );
+    assert_eq!(
+        test_utils::to_vec2_round(&tensor.erf()?, 4)?,
+        [
+            [-1.0, 0.8427, 1.0, -0.1125, 0.5205],
+            [0.9999, -0.9891, -0.3079, 0.9891, 0.9999]
+        ]
+    );
+    assert_eq!(
+        test_utils::to_vec2_round(&tensor.ceil()?, 4)?,
+        [[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -1.0, -0.0, 2.0, 3.0]]
+    );
+    assert_eq!(
+        test_utils::to_vec2_round(&tensor.floor()?, 4)?,
+        [[-3.0, 1.0, 4.0, -1.0, 0.0], [2.0, -2.0, -1.0, 1.0, 2.0]]
+    );
+    assert_eq!(
+        test_utils::to_vec2_round(&tensor.round()?, 4)?,
+        [[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -2.0, -0.0, 2.0, 3.0]]
+    );
+    let tensor = Tensor::new(&[2997.9246, 314.15926f32], device)?;
+    assert_eq!(
+        test_utils::to_vec1_round(&tensor.round_to(2)?, 4)?,
+        [2997.92, 314.16]
+    );
+    assert_eq!(
+        test_utils::to_vec1_round(&tensor.round_to(-2)?, 4)?,
+        [3000.0, 300.]
+    );
+    Ok(())
+}
+
 fn binary_op(device: &Device) -> Result<()> {
     let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
     let tensor1 = Tensor::new(data, device)?;
@@ -590,6 +674,30 @@ fn index_select(device: &Device) -> Result<()> {
         hs.to_vec2::<f32>()?,
         &[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
     );
+    // Prior to https://github.com/huggingface/candle/pull/1022
+    // There would be a bug where the last values in the result tensor would be set to 0.
+    let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
+    let hs = t.index_select(&ids, 0)?;
+    assert_eq!(
+        hs.to_vec2::<f32>()?,
+        &[
+            [0.0, 1.0, 2.0],
+            [6.0, 7.0, 8.0],
+            [3.0, 4.0, 5.0],
+            [0.0, 1.0, 2.0],
+            [6.0, 7.0, 8.0],
+            [3.0, 4.0, 5.0],
+        ]
+    );
+
+    // Test when selecting dim > 0 with ids size different from elem count of
+    // target dim in source/input.
+    let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
+    let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
+    assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
+    let hs = t.index_select(&ids, 1)?;
+    assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
+
     Ok(())
 }
@@ -636,6 +744,48 @@ fn index_add(device: &Device) -> Result<()> {
     Ok(())
 }
 
+fn slice_scatter(device: &Device) -> Result<()> {
+    let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
+    assert_eq!(
+        t.to_vec2::<f32>()?,
+        &[
+            [0.0, 1.0, 2.0],
+            [3.0, 4.0, 5.0],
+            [6.0, 7.0, 8.0],
+            [9.0, 10.0, 11.0]
+        ]
+    );
+    let src = Tensor::arange(100f32, 106f32, device)?.reshape((2, 3))?;
+    assert_eq!(
+        t.slice_scatter0(&src, 0)?.to_vec2::<f32>()?,
+        &[
+            [100.0, 101.0, 102.0],
+            [103.0, 104.0, 105.0],
+            [6.0, 7.0, 8.0],
+            [9.0, 10.0, 11.0]
+        ]
+    );
+    assert_eq!(
+        t.slice_scatter0(&src, 1)?.to_vec2::<f32>()?,
+        &[
+            [0.0, 1.0, 2.0],
+            [100.0, 101.0, 102.0],
+            [103.0, 104.0, 105.0],
+            [9.0, 10.0, 11.0]
+        ]
+    );
+    assert_eq!(
+        t.slice_scatter0(&src, 2)?.to_vec2::<f32>()?,
+        &[
+            [0.0, 1.0, 2.0],
+            [3.0, 4.0, 5.0],
+            [100.0, 101.0, 102.0],
+            [103.0, 104.0, 105.0],
+        ]
+    );
+    Ok(())
+}
+
 fn scatter_add(device: &Device) -> Result<()> {
     let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
     assert_eq!(
@@ -877,7 +1027,16 @@ fn broadcasting(device: &Device) -> Result<()> {
     Ok(())
 }
 
+fn randn(device: &Device) -> Result<()> {
+    let tensor = Tensor::randn(0f32, 1f32, (5, 3), device)?;
+    assert_eq!(tensor.dims(), [5, 3]);
+    let tensor = Tensor::rand(0f32, 1f32, (5, 3), device)?;
+    assert_eq!(tensor.dims(), [5, 3]);
+    Ok(())
+}
+
 test_device!(zeros, zeros_cpu, zeros_gpu);
+test_device!(ones, ones_cpu, ones_gpu);
 test_device!(add_mul, add_mul_cpu, add_mul_gpu);
 test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
 test_device!(narrow, narrow_cpu, narrow_gpu);
@@ -889,6 +1048,7 @@ test_device!(max, max_cpu, max_gpu);
 test_device!(argmax, argmax_cpu, argmax_gpu);
 test_device!(argmin, argmin_cpu, argmin_gpu);
 test_device!(transpose, transpose_cpu, transpose_gpu);
+test_device!(unary_op, unary_op_cpu, unary_op_gpu);
 test_device!(binary_op, binary_op_cpu, binary_op_gpu);
 test_device!(embeddings, embeddings_cpu, embeddings_gpu);
 test_device!(cmp, cmp_cpu, cmp_gpu);
@@ -899,6 +1059,9 @@ test_device!(index_select, index_select_cpu, index_select_gpu);
 test_device!(index_add, index_add_cpu, index_add_gpu);
 test_device!(gather, gather_cpu, gather_gpu);
 test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
+test_device!(slice_scatter, slice_scatter_cpu, slice_scatter_gpu);
+test_device!(randn, randn_cpu, randn_gpu);
+test_device!(clamp, clamp_cpu, clamp_gpu);
 
 // There was originally a bug on the CPU implementation for randn
 // https://github.com/huggingface/candle/issues/381
candle-core/tests/test.npy (new binary file, not shown)
candle-core/tests/test.npz (new binary file, not shown)
@@ -11,8 +11,8 @@ readme = "README.md"
 
 [dependencies]
 byteorder = { workspace = true }
-candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
-candle-nn = { path = "../candle-nn", version = "0.2.1" }
+candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
+candle-nn = { path = "../candle-nn", version = "0.3.0" }
 hf-hub = { workspace = true}
 intel-mkl-src = { workspace = true, optional = true }
 memmap2 = { workspace = true }
@@ -8,13 +8,9 @@ use parquet::file::reader::{FileReader, SerializedFileReader};
 use std::fs::File;
 use std::io::{self, BufReader, Read};
 
-fn read_u32<T: Read>(reader: &mut T) -> Result<u32> {
-    let mut b = vec![0u8; 4];
-    reader.read_exact(&mut b)?;
-    let (result, _) = b.iter().rev().fold((0u64, 1u64), |(s, basis), &x| {
-        (s + basis * u64::from(x), basis * 256)
-    });
-    Ok(result as u32)
+fn read_u32<T: Read>(reader: &mut T) -> std::io::Result<u32> {
+    use byteorder::ReadBytesExt;
+    reader.read_u32::<byteorder::BigEndian>()
 }
 
 fn check_magic_number<T: Read>(reader: &mut T, expected: u32) -> Result<()> {
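Both versions read a big-endian u32 (the MNIST header format); a standalone sketch showing the equivalence, with `byteorder` assumed as a dependency exactly as in the diff:

```rust
use byteorder::ReadBytesExt;
use std::io::Cursor;

fn main() -> std::io::Result<()> {
    // 0x00000803 = 2051, the magic number of MNIST image files.
    let bytes = [0x00u8, 0x00, 0x08, 0x03];
    // Manual big-endian fold, equivalent to the removed code.
    let manual = bytes.iter().fold(0u32, |acc, &b| (acc << 8) | u32::from(b));
    // The byteorder call used by the new code.
    let via_byteorder = Cursor::new(bytes).read_u32::<byteorder::BigEndian>()?;
    assert_eq!(manual, via_byteorder);
    assert_eq!(manual, 2051);
    Ok(())
}
```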
@@ -11,19 +11,22 @@ readme = "README.md"
 
 [dependencies]
 accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
-candle-datasets = { path = "../candle-datasets", version = "0.2.1" }
-candle-nn = { path = "../candle-nn", version = "0.2.1" }
-candle-transformers = { path = "../candle-transformers", version = "0.2.1" }
-candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.1", optional = true }
-safetensors = { workspace = true }
-serde = { workspace = true }
-serde_json = { workspace = true }
-num-traits = { workspace = true }
-intel-mkl-src = { workspace = true, optional = true }
+candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
+candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
+candle-nn = { path = "../candle-nn", version = "0.3.0" }
+candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
+candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
 cudarc = { workspace = true, optional = true }
 half = { workspace = true, optional = true }
 image = { workspace = true }
+intel-mkl-src = { workspace = true, optional = true }
+num-traits = { workspace = true }
+pyo3 = { version = "0.19.0", features = ["auto-initialize"], optional = true }
+rayon = { workspace = true }
+safetensors = { workspace = true }
+serde = { workspace = true }
+serde_json = { workspace = true }
+tokenizers = { workspace = true, features = ["onig"] }
 
 [dev-dependencies]
 anyhow = { workspace = true }
@@ -34,7 +37,6 @@ imageproc = { workspace = true }
 memmap2 = { workspace = true }
 rand = { workspace = true }
 rusttype = { workspace = true }
-tokenizers = { workspace = true, features = ["onig"] }
 tracing = { workspace = true }
 tracing-chrome = { workspace = true }
 tracing-subscriber = { workspace = true }
@@ -50,10 +52,14 @@ default = []
 accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
 cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
 cudnn = ["candle/cudnn"]
-flash-attn = ["cuda", "dep:candle-flash-attn"]
+flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
 mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
 nccl = ["cuda", "cudarc/nccl", "dep:half"]
 
 [[example]]
 name = "llama_multiprocess"
 required-features = ["cuda", "nccl", "flash-attn"]
+
+[[example]]
+name = "reinforcement-learning"
+required-features = ["pyo3"]
candle-examples/examples/bert/README.md (new file, 44 lines)
@@ -0,0 +1,44 @@
+# candle-bert
+
+Bert is a general large language model. In this example it can be used for two
+different tasks:
+- Compute sentence embeddings for a prompt.
+- Compute similarities between a set of sentences.
+
+## Sentence embeddings
+
+Bert is used to compute the sentence embeddings for a prompt. The model weights
+are downloaded from the hub on the first run.
+
+```bash
+cargo run --example bert --release -- --prompt "Here is a test sentence"
+
+> [[[ 0.0798, -0.0665, -0.0247, ..., -0.1082, -0.1000, -0.2751],
+>   [ 0.4218,  0.2690,  0.2740, ...,  0.3889,  1.3503,  0.9908],
+>   [ 0.0466,  0.3041, -0.1143, ...,  0.4427,  0.6926, -0.1515],
+>   ...
+>   [ 0.3396,  0.4320, -0.4408, ...,  0.9212,  0.2331, -0.6777],
+>   [ 0.2789,  0.7539,  0.4306, ..., -0.0095,  0.3375, -1.7529],
+>   [ 0.6737,  0.7882,  0.0548, ...,  0.1836,  0.7299, -0.6617]]]
+> Tensor[[1, 7, 384], f32]
+```
+
+## Similarities
+
+In this example, Bert is used to compute the sentence embeddings for a set of
+sentences (hardcoded in the example). Cosine similarities are then computed for
+each sentence pair and reported in decreasing order, so the first reported
+pair contains the two sentences that have the highest similarity score.
+The sentence embeddings are computed using average pooling over all the
+sentence tokens, including some potential padding.
+
+```bash
+cargo run --example bert --release
+
+> score: 0.85 'The new movie is awesome' 'The new movie is so great'
+> score: 0.61 'The cat sits outside' 'The cat plays in the garden'
+> score: 0.52 'I love pasta' 'Do you like pizza?'
+> score: 0.23 'The new movie is awesome' 'Do you like pizza?'
+> score: 0.22 'I love pasta' 'The new movie is awesome'
+```
@@ -3,14 +3,13 @@ extern crate intel_mkl_src;
 
 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;
-mod model;
+use candle_transformers::models::bert::{BertModel, Config, DTYPE};
 
 use anyhow::{anyhow, Error as E, Result};
 use candle::Tensor;
 use candle_nn::VarBuilder;
 use clap::Parser;
 use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
-use model::{BertModel, Config, DTYPE};
 use tokenizers::{PaddingParams, Tokenizer};
 
 #[derive(Parser, Debug)]
@@ -87,9 +86,8 @@ impl Args {
         let config: Config = serde_json::from_str(&config)?;
         let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
 
-        let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
-        let weights = weights.deserialize()?;
-        let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
+        let vb =
+            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? };
         let model = BertModel::load(vb, &config)?;
         Ok((model, tokenizer))
     }
 }
|
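This pattern recurs throughout the diff: the separate `MmapedFile::new` / `deserialize` / `from_safetensors` steps collapse into a single `VarBuilder::from_mmaped_safetensors` call. A minimal sketch of the new loading path, using only the calls visible in the hunk above (the wrapper function and the CPU device are illustrative):

```rust
use candle::Device;
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config, DTYPE};

/// Load a BERT checkpoint straight from a safetensors file. The call is
/// `unsafe` because the memory-mapped file must not change while in use.
fn load_model(weights_path: &str, config: &Config) -> anyhow::Result<BertModel> {
    let device = Device::Cpu;
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights_path], DTYPE, &device)? };
    Ok(BertModel::load(vb, config)?)
}
```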
candle-examples/examples/bigcode/README.md (new file, +19)
@@ -0,0 +1,19 @@
# candle-starcoder: code generation model

[StarCoder/BigCode](https://huggingface.co/bigcode/starcoderbase-1b) is an LLM
specialized in code generation. The initial model was trained on 80
programming languages.

## Running an example

```bash
cargo run --example bigcode --release -- --prompt "fn fact(n: u64) -> u64 "

> fn fact(n: u64) -> u64 {
>     if n == 0 {
>         1
>     } else {
>         n * fact(n - 1)
>     }
> }
```
@@ -7,8 +7,7 @@ extern crate accelerate_src;
 use anyhow::{Error as E, Result};
 use clap::Parser;

-mod model;
-use model::{Config, GPTBigCode};
+use candle_transformers::models::bigcode::{Config, GPTBigCode};

 use candle::{DType, Device, Tensor};
 use candle_nn::VarBuilder;
@@ -29,9 +28,10 @@ impl TextGeneration {
         tokenizer: Tokenizer,
         seed: u64,
         temp: Option<f64>,
+        top_p: Option<f64>,
         device: &Device,
     ) -> Self {
-        let logits_processor = LogitsProcessor::new(seed, temp);
+        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
         Self {
             model,
             tokenizer,
@@ -95,6 +95,10 @@ struct Args {
     #[arg(long)]
     temperature: Option<f64>,

+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
@@ -134,23 +138,21 @@ fn main() -> Result<()> {
     println!("retrieved the files in {:?}", start.elapsed());
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

-    let weights = filenames
-        .iter()
-        .map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? }))
-        .collect::<Result<Vec<_>>>()?;
-    let weights = weights
-        .iter()
-        .map(|f| Ok(f.deserialize()?))
-        .collect::<Result<Vec<_>>>()?;

     let start = std::time::Instant::now();
     let device = candle_examples::device(args.cpu)?;
-    let vb = VarBuilder::from_safetensors(weights, DType::F32, &device);
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
     let config = Config::starcoder_1b();
     let model = GPTBigCode::load(vb, config)?;
     println!("loaded the model in {:?}", start.elapsed());

-    let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, args.temperature, &device);
+    let mut pipeline = TextGeneration::new(
+        model,
+        tokenizer,
+        args.seed,
+        args.temperature,
+        args.top_p,
+        &device,
+    );
     pipeline.run(&args.prompt, args.sample_len)?;
     Ok(())
 }
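The other recurring change is the third `top_p` argument on `LogitsProcessor::new`, enabling nucleus sampling alongside temperature scaling. A small sketch with illustrative values (the three-argument constructor is the one used throughout this diff):

```rust
use candle_transformers::generation::LogitsProcessor;

fn make_sampler(seed: u64) -> LogitsProcessor {
    // temperature = 0.8 softens the distribution; top_p = 0.95 then samples
    // from the smallest token set whose cumulative probability reaches 0.95.
    // Passing None for both falls back to greedy (argmax) decoding.
    LogitsProcessor::new(seed, Some(0.8), Some(0.95))
}
```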
candle-examples/examples/convmixer/main.rs (new file, +59)
@@ -0,0 +1,59 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::Parser;

use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::convmixer;

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model("lmz/candle-convmixer".into());
            api.get("convmixer_1024_20_ks9_p14.safetensors")?
        }
        Some(model) => model.into(),
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let model = convmixer::c1024_20(1000, vb)?;
    println!("model built");
    let logits = model.forward(&image.unsqueeze(0)?)?;
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;
    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    for &(category_idx, pr) in prs.iter().take(5) {
        println!(
            "{:24}: {:.2}%",
            candle_examples::imagenet::CLASSES[category_idx],
            100. * pr
        );
    }
    Ok(())
}
candle-examples/examples/dinov2/README.md (new file, +19)
@@ -0,0 +1,19 @@
# candle-dinov2

[DINOv2](https://github.com/facebookresearch/dinov2) is a computer vision model.
In this example, it is used as an ImageNet classifier: the model returns the
probability for the image to belong to each of the 1000 ImageNet categories.

## Running an example

```bash
cargo run --example dinov2 --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg

> mountain bike, all-terrain bike, off-roader: 43.67%
> bicycle-built-for-two, tandem bicycle, tandem: 33.20%
> crash helmet            : 13.23%
> unicycle, monocycle     : 2.44%
> maillot                 : 2.42%
```
@@ -9,285 +9,10 @@ extern crate accelerate_src;

 use clap::Parser;

-use candle::{DType, IndexOp, Result, Tensor, D};
-use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+use candle::{DType, IndexOp, D};
+use candle_nn::{Module, VarBuilder};
+use candle_transformers::models::dinov2;

-const IMG_SIZE: usize = 518;
-const PATCH_SIZE: usize = 14;
-const NUM_CLASSES: usize = 1000;
-
-fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
-    if bias {
-        candle_nn::linear(in_dim, out_dim, vb)
-    } else {
-        candle_nn::linear_no_bias(in_dim, out_dim, vb)
-    }
-}
-
-#[derive(Debug)]
-struct Attention {
-    qkv: Linear,
-    proj: Linear,
-    num_heads: usize,
-    scale: f64,
-}
-
-impl Attention {
-    fn new(
-        vb: VarBuilder,
-        dim: usize,
-        num_heads: usize,
-        qkv_bias: bool,
-        proj_bias: bool,
-    ) -> Result<Self> {
-        let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
-        let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?;
-        let scale = 1. / ((dim / num_heads) as f64).sqrt();
-        Ok(Self {
-            qkv,
-            proj,
-            num_heads,
-            scale,
-        })
-    }
-}
-
-impl Module for Attention {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let (b, n, c) = xs.dims3()?;
-        let qkv = self
-            .qkv
-            .forward(xs)?
-            .reshape((b, n, 3, self.num_heads, c / self.num_heads))?
-            .transpose(1, 2)? // 02134
-            .transpose(0, 1)? // 20134
-            .transpose(2, 3)?; // 20314
-        let q = (qkv.i(0)? * self.scale)?;
-        let k = qkv.i(1)?;
-        let v = qkv.i(2)?;
-        let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?;
-        let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?;
-        self.proj.forward(&attn)
-    }
-}
-
-#[derive(Debug)]
-struct LayerScale {
-    gamma: Tensor,
-}
-
-impl LayerScale {
-    fn new(vb: VarBuilder, dim: usize) -> Result<Self> {
-        let gamma = vb.get(dim, "gamma")?;
-        Ok(Self { gamma })
-    }
-}
-
-impl Module for LayerScale {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        xs.broadcast_mul(&self.gamma)
-    }
-}
-
-#[derive(Debug)]
-struct Mlp {
-    fc1: Linear,
-    fc2: Linear,
-}
-
-impl Mlp {
-    fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result<Self> {
-        let out_features = in_features;
-        let fc1 = linear(vb.pp("fc1"), in_features, hidden_features, bias)?;
-        let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?;
-        Ok(Self { fc1, fc2 })
-    }
-}
-
-impl Module for Mlp {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let xs = self.fc1.forward(xs)?.gelu()?;
-        self.fc2.forward(&xs)
-    }
-}
-
-#[derive(Debug)]
-struct Block {
-    norm1: LayerNorm,
-    attn: Attention,
-    ls1: LayerScale,
-    norm2: LayerNorm,
-    mlp: Mlp,
-    ls2: LayerScale,
-}
-
-impl Block {
-    fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> {
-        let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
-        let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?;
-        let ls1 = LayerScale::new(vb.pp("ls1"), dim)?;
-        let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
-        let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?;
-        let ls2 = LayerScale::new(vb.pp("ls2"), dim)?;
-        Ok(Self {
-            norm1,
-            attn,
-            ls1,
-            norm2,
-            mlp,
-            ls2,
-        })
-    }
-}
-
-impl Module for Block {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let residual = xs;
-        let xs = self
-            .ls1
-            .forward(&self.attn.forward(&self.norm1.forward(xs)?)?)?;
-        let xs = (xs + residual)?;
-        let residual = &xs;
-        let xs = self
-            .ls2
-            .forward(&self.mlp.forward(&self.norm2.forward(&xs)?)?)?;
-        xs + residual
-    }
-}
-
-#[derive(Debug)]
-struct PatchEmbed {
-    proj: candle_nn::Conv2d,
-    patch_size: (usize, usize),
-    num_patches: usize,
-}
-
-impl PatchEmbed {
-    fn new(
-        vb: VarBuilder,
-        img_size: usize,
-        patch_size: usize,
-        in_chans: usize,
-        embed_dim: usize,
-    ) -> Result<Self> {
-        let config = candle_nn::Conv2dConfig {
-            stride: patch_size,
-            ..Default::default()
-        };
-        let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?;
-        let num_patches = (img_size / patch_size) * (img_size / patch_size);
-        Ok(Self {
-            proj,
-            patch_size: (patch_size, patch_size),
-            num_patches,
-        })
-    }
-}
-
-impl Module for PatchEmbed {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let (_b, _c, h, w) = xs.dims4()?;
-        let (patch_h, patch_w) = self.patch_size;
-        if (h % patch_h) != 0 {
-            candle::bail!("image height {h} is not a multiple of patch height {patch_h}")
-        }
-        if (w % patch_w) != 0 {
-            candle::bail!("image width {w} is not a multiple of patch width {patch_w}")
-        }
-        let xs = self.proj.forward(xs)?;
-        let (b, c, h, w) = xs.dims4()?;
-        // flatten embeddings.
-        xs.reshape((b, c, h * w))?.transpose(1, 2)
-    }
-}
-
-#[derive(Debug)]
-pub struct DinoVisionTransformer {
-    patch_embed: PatchEmbed,
-    cls_token: Tensor,
-    pos_embed: Tensor,
-    blocks: Vec<Block>,
-    norm: LayerNorm,
-    head: Linear,
-}
-
-impl DinoVisionTransformer {
-    pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result<Self> {
-        let patch_embed =
-            PatchEmbed::new(vb.pp("patch_embed"), IMG_SIZE, PATCH_SIZE, 3, embed_dim)?;
-        let cls_token = vb.get((1, 1, embed_dim), "cls_token")?;
-        let num_tokens = 1;
-        let pos_embed = vb.get(
-            (1, patch_embed.num_patches + num_tokens, embed_dim),
-            "pos_embed",
-        )?;
-        let head = linear(vb.pp("head"), 2 * embed_dim, NUM_CLASSES, true)?;
-        let norm = layer_norm(embed_dim, 1e-5, vb.pp("norm"))?;
-        let vb_b = vb.pp("blocks");
-        let blocks = (0..depth)
-            .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
-            .collect::<Result<Vec<_>>>()?;
-        Ok(Self {
-            patch_embed,
-            cls_token,
-            pos_embed,
-            blocks,
-            norm,
-            head,
-        })
-    }
-
-    fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> {
-        let npatch = xs.dim(1)? - 1;
-        let n = self.pos_embed.dim(1)? - 1;
-        let sqrt_n = (n as f64).sqrt();
-        if npatch == n && w == h {
-            return Ok(xs.clone());
-        }
-        let class_pos_embed = self.pos_embed.i((.., ..1))?;
-        let patch_pos_embed = self.pos_embed.i((.., 1..))?;
-        let dim = xs.dim(D::Minus1)?;
-        let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1);
-        let patch_pos_embed = patch_pos_embed
-            .reshape((1, sqrt_n as usize, sqrt_n as usize, dim))?
-            .transpose(2, 3)?
-            .transpose(1, 2)?;
-        // This uses bicubic interpolation in the original implementation.
-        let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?;
-        let el_count = patch_pos_embed.shape().elem_count();
-        let patch_pos_embed =
-            patch_pos_embed
-                .transpose(1, 2)?
-                .transpose(2, 3)?
-                .reshape((1, el_count / dim, dim))?;
-        Tensor::cat(&[&class_pos_embed, &patch_pos_embed], 1)
-    }
-
-    fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> {
-        let (_b, _nc, w, h) = xs.dims4()?;
-        let xs = self.patch_embed.forward(xs)?;
-        let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
-        &xs + &self.interpolate_pos_encoding(&xs, w, h)?
-    }
-}
-
-impl Module for DinoVisionTransformer {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let mut xs = self.prepare_tokens_with_mask(xs)?;
-        for blk in self.blocks.iter() {
-            xs = blk.forward(&xs)?
-        }
-        let xs = self.norm.forward(&xs)?;
-        let xs_norm_clstoken = xs.i((.., 0))?;
-        let xs_norm_patchtokens = xs.i((.., 1..))?.mean(1)?;
-        let xs = Tensor::cat(&[xs_norm_clstoken, xs_norm_patchtokens], D::Minus1)?;
-        self.head.forward(&xs)
-    }
-}
-
-pub fn vit_small(vb: VarBuilder) -> Result<DinoVisionTransformer> {
-    DinoVisionTransformer::new(vb, 12, 384, 6)
-}
 #[derive(Parser)]
 struct Args {
     #[arg(long)]
@@ -317,10 +42,8 @@ pub fn main() -> anyhow::Result<()> {
         }
         Some(model) => model.into(),
     };
-    let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
-    let weights = weights.deserialize()?;
-    let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
-    let model = vit_small(vb)?;
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
+    let model = dinov2::vit_small(vb)?;
     println!("model built");
     let logits = model.forward(&image.unsqueeze(0)?)?;
     let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
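With the ~275 inline lines gone, the example consumes the identical ViT-S/14 model from `candle-transformers`. A minimal sketch of the new call path, assuming a `VarBuilder` already populated from the checkpoint (the wrapper function is illustrative):

```rust
use candle::{Result, Tensor};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::dinov2;

/// Classify a (3, 518, 518) image tensor with the library DINOv2 ViT-S/14.
fn classify(vb: VarBuilder, image: &Tensor) -> Result<Tensor> {
    let model = dinov2::vit_small(vb)?;
    // (3, 518, 518) -> (1, 3, 518, 518) -> (1, 1000) logits.
    model.forward(&image.unsqueeze(0)?)
}
```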
@@ -8,340 +8,11 @@ extern crate intel_mkl_src;
 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;

+use candle::{DType, IndexOp, D};
+use candle_nn::{Module, VarBuilder};
+use candle_transformers::models::efficientnet::{EfficientNet, MBConvConfig};
 use clap::{Parser, ValueEnum};

-use candle::{DType, IndexOp, Result, Tensor, D};
-use candle_nn as nn;
-use nn::{Module, VarBuilder};
-
-// Based on the Python version from torchvision.
-// https://github.com/pytorch/vision/blob/0d75d9e5516f446c9c0ef93bd4ed9fea13992d06/torchvision/models/efficientnet.py#L47
-#[derive(Debug, Clone, Copy)]
-pub struct MBConvConfig {
-    expand_ratio: f64,
-    kernel: usize,
-    stride: usize,
-    input_channels: usize,
-    out_channels: usize,
-    num_layers: usize,
-}
-
-fn make_divisible(v: f64, divisor: usize) -> usize {
-    let min_value = divisor;
-    let new_v = usize::max(
-        min_value,
-        (v + divisor as f64 * 0.5) as usize / divisor * divisor,
-    );
-    if (new_v as f64) < 0.9 * v {
-        new_v + divisor
-    } else {
-        new_v
-    }
-}
-
-fn bneck_confs(width_mult: f64, depth_mult: f64) -> Vec<MBConvConfig> {
-    let bneck_conf = |e, k, s, i, o, n| {
-        let input_channels = make_divisible(i as f64 * width_mult, 8);
-        let out_channels = make_divisible(o as f64 * width_mult, 8);
-        let num_layers = (n as f64 * depth_mult).ceil() as usize;
-        MBConvConfig {
-            expand_ratio: e,
-            kernel: k,
-            stride: s,
-            input_channels,
-            out_channels,
-            num_layers,
-        }
-    };
-    vec![
-        bneck_conf(1., 3, 1, 32, 16, 1),
-        bneck_conf(6., 3, 2, 16, 24, 2),
-        bneck_conf(6., 5, 2, 24, 40, 2),
-        bneck_conf(6., 3, 2, 40, 80, 3),
-        bneck_conf(6., 5, 1, 80, 112, 3),
-        bneck_conf(6., 5, 2, 112, 192, 4),
-        bneck_conf(6., 3, 1, 192, 320, 1),
-    ]
-}
-
-impl MBConvConfig {
-    fn b0() -> Vec<Self> {
-        bneck_confs(1.0, 1.0)
-    }
-    fn b1() -> Vec<Self> {
-        bneck_confs(1.0, 1.1)
-    }
-    fn b2() -> Vec<Self> {
-        bneck_confs(1.1, 1.2)
-    }
-    fn b3() -> Vec<Self> {
-        bneck_confs(1.2, 1.4)
-    }
-    fn b4() -> Vec<Self> {
-        bneck_confs(1.4, 1.8)
-    }
-    fn b5() -> Vec<Self> {
-        bneck_confs(1.6, 2.2)
-    }
-    fn b6() -> Vec<Self> {
-        bneck_confs(1.8, 2.6)
-    }
-    fn b7() -> Vec<Self> {
-        bneck_confs(2.0, 3.1)
-    }
-}
-
-/// Conv2D with same padding.
-#[derive(Debug)]
-struct Conv2DSame {
-    conv2d: nn::Conv2d,
-    s: usize,
-    k: usize,
-}
-
-impl Conv2DSame {
-    fn new(
-        vb: VarBuilder,
-        i: usize,
-        o: usize,
-        k: usize,
-        stride: usize,
-        groups: usize,
-        bias: bool,
-    ) -> Result<Self> {
-        let conv_config = nn::Conv2dConfig {
-            stride,
-            groups,
-            ..Default::default()
-        };
-        let conv2d = if bias {
-            nn::conv2d(i, o, k, conv_config, vb)?
-        } else {
-            nn::conv2d_no_bias(i, o, k, conv_config, vb)?
-        };
-        Ok(Self {
-            conv2d,
-            s: stride,
-            k,
-        })
-    }
-}
-
-impl Module for Conv2DSame {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let s = self.s;
-        let k = self.k;
-        let (_, _, ih, iw) = xs.dims4()?;
-        let oh = (ih + s - 1) / s;
-        let ow = (iw + s - 1) / s;
-        let pad_h = usize::max((oh - 1) * s + k - ih, 0);
-        let pad_w = usize::max((ow - 1) * s + k - iw, 0);
-        if pad_h > 0 || pad_w > 0 {
-            let xs = xs.pad_with_zeros(2, pad_h / 2, pad_h - pad_h / 2)?;
-            let xs = xs.pad_with_zeros(3, pad_w / 2, pad_w - pad_w / 2)?;
-            self.conv2d.forward(&xs)
-        } else {
-            self.conv2d.forward(xs)
-        }
-    }
-}
-
-#[derive(Debug)]
-struct ConvNormActivation {
-    conv2d: Conv2DSame,
-    bn2d: nn::BatchNorm,
-    activation: bool,
-}
-
-impl ConvNormActivation {
-    fn new(
-        vb: VarBuilder,
-        i: usize,
-        o: usize,
-        k: usize,
-        stride: usize,
-        groups: usize,
-    ) -> Result<Self> {
-        let conv2d = Conv2DSame::new(vb.pp("0"), i, o, k, stride, groups, false)?;
-        let bn2d = nn::batch_norm(o, 1e-3, vb.pp("1"))?;
-        Ok(Self {
-            conv2d,
-            bn2d,
-            activation: true,
-        })
-    }
-
-    fn no_activation(self) -> Self {
-        Self {
-            activation: false,
-            ..self
-        }
-    }
-}
-
-impl Module for ConvNormActivation {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let xs = self.conv2d.forward(xs)?;
-        let xs = self.bn2d.forward(&xs)?;
-        if self.activation {
-            swish(&xs)
-        } else {
-            Ok(xs)
-        }
-    }
-}
-
-#[derive(Debug)]
-struct SqueezeExcitation {
-    fc1: Conv2DSame,
-    fc2: Conv2DSame,
-}
-
-impl SqueezeExcitation {
-    fn new(vb: VarBuilder, in_channels: usize, squeeze_channels: usize) -> Result<Self> {
-        let fc1 = Conv2DSame::new(vb.pp("fc1"), in_channels, squeeze_channels, 1, 1, 1, true)?;
-        let fc2 = Conv2DSame::new(vb.pp("fc2"), squeeze_channels, in_channels, 1, 1, 1, true)?;
-        Ok(Self { fc1, fc2 })
-    }
-}
-
-impl Module for SqueezeExcitation {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let residual = xs;
-        // equivalent to adaptive_avg_pool2d([1, 1])
-        let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
-        let xs = self.fc1.forward(&xs)?;
-        let xs = swish(&xs)?;
-        let xs = self.fc2.forward(&xs)?;
-        let xs = nn::ops::sigmoid(&xs)?;
-        residual.broadcast_mul(&xs)
-    }
-}
-
-#[derive(Debug)]
-struct MBConv {
-    expand_cna: Option<ConvNormActivation>,
-    depthwise_cna: ConvNormActivation,
-    squeeze_excitation: SqueezeExcitation,
-    project_cna: ConvNormActivation,
-    config: MBConvConfig,
-}
-
-impl MBConv {
-    fn new(vb: VarBuilder, c: MBConvConfig) -> Result<Self> {
-        let vb = vb.pp("block");
-        let exp = make_divisible(c.input_channels as f64 * c.expand_ratio, 8);
-        let expand_cna = if exp != c.input_channels {
-            Some(ConvNormActivation::new(
-                vb.pp("0"),
-                c.input_channels,
-                exp,
-                1,
-                1,
-                1,
-            )?)
-        } else {
-            None
-        };
-        let start_index = if expand_cna.is_some() { 1 } else { 0 };
-        let depthwise_cna =
-            ConvNormActivation::new(vb.pp(start_index), exp, exp, c.kernel, c.stride, exp)?;
-        let squeeze_channels = usize::max(1, c.input_channels / 4);
-        let squeeze_excitation =
-            SqueezeExcitation::new(vb.pp(start_index + 1), exp, squeeze_channels)?;
-        let project_cna =
-            ConvNormActivation::new(vb.pp(start_index + 2), exp, c.out_channels, 1, 1, 1)?
-                .no_activation();
-        Ok(Self {
-            expand_cna,
-            depthwise_cna,
-            squeeze_excitation,
-            project_cna,
-            config: c,
-        })
-    }
-}
-
-impl Module for MBConv {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let use_res_connect =
-            self.config.stride == 1 && self.config.input_channels == self.config.out_channels;
-        let ys = match &self.expand_cna {
-            Some(expand_cna) => expand_cna.forward(xs)?,
-            None => xs.clone(),
-        };
-        let ys = self.depthwise_cna.forward(&ys)?;
-        let ys = self.squeeze_excitation.forward(&ys)?;
-        let ys = self.project_cna.forward(&ys)?;
-        if use_res_connect {
-            ys + xs
-        } else {
-            Ok(ys)
-        }
-    }
-}
-
-fn swish(s: &Tensor) -> Result<Tensor> {
-    s * nn::ops::sigmoid(s)?
-}
-
-#[derive(Debug)]
-struct EfficientNet {
-    init_cna: ConvNormActivation,
-    blocks: Vec<MBConv>,
-    final_cna: ConvNormActivation,
-    classifier: nn::Linear,
-}
-
-impl EfficientNet {
-    fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> {
-        let f_p = p.pp("features");
-        let first_in_c = configs[0].input_channels;
-        let last_out_c = configs.last().unwrap().out_channels;
-        let final_out_c = 4 * last_out_c;
-        let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?;
-        let nconfigs = configs.len();
-        let mut blocks = vec![];
-        for (index, cnf) in configs.into_iter().enumerate() {
-            let f_p = f_p.pp(index + 1);
-            for r_index in 0..cnf.num_layers {
-                let cnf = if r_index == 0 {
-                    cnf
-                } else {
-                    MBConvConfig {
-                        input_channels: cnf.out_channels,
-                        stride: 1,
-                        ..cnf
-                    }
-                };
-                blocks.push(MBConv::new(f_p.pp(r_index), cnf)?)
-            }
-        }
-        let final_cna =
-            ConvNormActivation::new(f_p.pp(nconfigs + 1), last_out_c, final_out_c, 1, 1, 1)?;
-        let classifier = nn::linear(final_out_c, nclasses, p.pp("classifier.1"))?;
-        Ok(Self {
-            init_cna,
-            blocks,
-            final_cna,
-            classifier,
-        })
-    }
-}
-
-impl Module for EfficientNet {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let mut xs = self.init_cna.forward(xs)?;
-        for block in self.blocks.iter() {
-            xs = block.forward(&xs)?
-        }
-        let xs = self.final_cna.forward(&xs)?;
-        // Equivalent to adaptive_avg_pool2d([1, 1]) -> squeeze(-1) -> squeeze(-1)
-        let xs = xs.mean(D::Minus1)?.mean(D::Minus1)?;
-        self.classifier.forward(&xs)
-    }
-}
-
 #[derive(Clone, Copy, Debug, ValueEnum)]
 enum Which {
     B0,
@@ -397,9 +68,7 @@ pub fn main() -> anyhow::Result<()> {
         }
         Some(model) => model.into(),
     };
-    let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
-    let weights = weights.deserialize()?;
-    let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
     let cfg = match args.which {
         Which::B0 => MBConvConfig::b0(),
         Which::B1 => MBConvConfig::b1(),
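The EfficientNet example follows the same pattern: the torchvision-style blocks now come from `candle-transformers`. A minimal sketch of assembling the B0 variant, assuming the library keeps the same `new` signature as the deleted inline code above (a `VarBuilder`, the `Vec<MBConvConfig>`, and the class count):

```rust
use candle_nn::VarBuilder;
use candle_transformers::models::efficientnet::{EfficientNet, MBConvConfig};

/// EfficientNet-B0 with the standard 1000 ImageNet classes.
fn build_b0(vb: VarBuilder) -> candle::Result<EfficientNet> {
    EfficientNet::new(vb, MBConvConfig::b0(), 1000)
}
```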
candle-examples/examples/falcon/README.md (new file, +3)
@@ -0,0 +1,3 @@
# candle-falcon

Falcon is a general large language model.
@@ -14,8 +14,7 @@ use clap::Parser;
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::Tokenizer;

-mod model;
-use model::{Config, Falcon};
+use candle_transformers::models::falcon::{Config, Falcon};

 struct TextGeneration {
     model: Falcon,
@@ -26,17 +25,25 @@ struct TextGeneration {
     repeat_last_n: usize,
 }

+struct GenerationOptions {
+    temp: Option<f64>,
+    top_p: Option<f64>,
+    repeat_penalty: f32,
+    repeat_last_n: usize,
+}
+
 impl TextGeneration {
     fn new(
         model: Falcon,
         tokenizer: Tokenizer,
+        generation_options: GenerationOptions,
         seed: u64,
-        temp: Option<f64>,
         device: &Device,
-        repeat_penalty: f32,
-        repeat_last_n: usize,
     ) -> Self {
-        let logits_processor = LogitsProcessor::new(seed, temp);
+        let logits_processor =
+            LogitsProcessor::new(seed, generation_options.temp, generation_options.top_p);
+        let repeat_penalty = generation_options.repeat_penalty;
+        let repeat_last_n = generation_options.repeat_last_n;
         Self {
             model,
             tokenizer,
@@ -119,6 +126,10 @@ struct Args {
     #[arg(long)]
     temperature: Option<f64>,

+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
@@ -166,35 +177,25 @@ fn main() -> Result<()> {
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

     let start = std::time::Instant::now();
-    let weights = filenames
-        .iter()
-        .map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? }))
-        .collect::<Result<Vec<_>>>()?;
-    let weights = weights
-        .iter()
-        .map(|f| Ok(f.deserialize()?))
-        .collect::<Result<Vec<_>>>()?;

     let dtype = if args.use_f32 {
         DType::F32
     } else {
         DType::BF16
     };
-    let vb = VarBuilder::from_safetensors(weights, dtype, &device);
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
     let config = Config::falcon7b();
     config.validate()?;
     let model = Falcon::load(vb, config)?;
     println!("loaded the model in {:?}", start.elapsed());

-    let mut pipeline = TextGeneration::new(
-        model,
-        tokenizer,
-        args.seed,
-        args.temperature,
-        &device,
-        args.repeat_penalty,
-        args.repeat_last_n,
-    );
+    let generation_options = GenerationOptions {
+        temp: args.temperature,
+        top_p: args.top_p,
+        repeat_penalty: args.repeat_penalty,
+        repeat_last_n: args.repeat_last_n,
+    };
+    let mut pipeline =
+        TextGeneration::new(model, tokenizer, generation_options, args.seed, &device);
     pipeline.run(&args.prompt, args.sample_len)?;
     Ok(())
 }
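Grouping the four sampling knobs into `GenerationOptions` keeps the `TextGeneration::new` signature short. A sketch of how the CLI flags map onto it (the struct is restated so the snippet stands alone; values are illustrative):

```rust
// Mirrors the struct introduced in the hunk above.
struct GenerationOptions {
    temp: Option<f64>,
    top_p: Option<f64>,
    repeat_penalty: f32,
    repeat_last_n: usize,
}

fn options_from_flags() -> GenerationOptions {
    GenerationOptions {
        temp: Some(0.7),     // --temperature 0.7
        top_p: Some(0.9),    // --top-p 0.9
        repeat_penalty: 1.1, // --repeat-penalty 1.1 (1.0 disables the penalty)
        repeat_last_n: 64,   // --repeat-last-n 64
    }
}
```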
@@ -21,11 +21,10 @@ use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use std::io::Write;

-mod model;
+use candle_transformers::models::llama as model;
 use model::{Config, Llama, LlamaConfig};

 const EOS_TOKEN: &str = "</s>";
-const MAX_SEQ_LEN: usize = 4096;
 const DEFAULT_PROMPT: &str = "My favorite theorem is ";

 #[derive(Parser, Debug)]
@@ -43,6 +42,10 @@ struct Args {
     #[arg(long)]
     temperature: Option<f64>,

+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
@@ -169,17 +172,9 @@ fn main() -> Result<()> {
             }

             println!("building the model");
-            let handles = filenames
-                .iter()
-                .map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f.as_path())? }))
-                .collect::<Result<Vec<_>>>()?;
-            let tensors: Vec<_> = handles
-                .iter()
-                .map(|h| Ok(h.deserialize()?))
-                .collect::<Result<Vec<_>>>()?;
             let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;

-            let vb = VarBuilder::from_safetensors(tensors, dtype, &device);
+            let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
             (Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
         }
     };
@@ -194,7 +189,7 @@ fn main() -> Result<()> {

     println!("starting the inference loop");
     print!("{prompt}");
-    let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
+    let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
     let start_gen = std::time::Instant::now();
     let mut index_pos = 0;
     let mut token_generated = 0;
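`from_mmaped_safetensors` also covers the multi-shard case, which is what lets the whole `Vec` of mmap handles and deserialized tensors above disappear. A sketch with placeholder shard names (the example resolves the real paths from the hub repo at runtime):

```rust
use candle::{DType, Device, Result};
use candle_nn::VarBuilder;

/// Map several checkpoint shards into a single VarBuilder in one call.
fn load_sharded(device: &Device) -> Result<VarBuilder<'static>> {
    let filenames = [
        "model-00001-of-00002.safetensors", // placeholder shard names
        "model-00002-of-00002.safetensors",
    ];
    unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F16, device) }
}
```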
@@ -27,6 +27,10 @@ struct InferenceCmd {
     #[arg(long)]
     temperature: Option<f64>,

+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
     #[arg(long, default_value = "")]
     prompt: String,

@@ -133,6 +137,7 @@ fn main() -> anyhow::Result<()> {
         None => {
             let cmd = InferenceCmd {
                 temperature: None,
+                top_p: None,
                 prompt: "".to_string(),
                 config: None,
                 model_id: "karpathy/tinyllamas".to_string(),
@@ -256,7 +261,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
     let model = Llama::load(vb, &cache, config)?;

     println!("starting the inference loop");
-    let mut logits_processor = LogitsProcessor::new(299792458, args.temperature);
+    let mut logits_processor = LogitsProcessor::new(299792458, args.temperature, args.top_p);
     let mut index_pos = 0;

     print!("{}", args.prompt);
@@ -89,6 +89,10 @@ struct Args {
     #[arg(long)]
     temperature: Option<f64>,

+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
@@ -201,16 +205,9 @@ fn main() -> Result<()> {
     let cache = model::Cache::new(dtype, &config, &device)?;

     println!("building the model");
-    let handles = filenames
-        .iter()
-        .map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f.as_path())? }))
-        .collect::<Result<Vec<_>>>()?;
-    let tensors: Vec<_> = handles
-        .iter()
-        .map(|h| Ok(h.deserialize()?))
-        .collect::<Result<Vec<_>>>()?;
-
-    let vb = candle_nn::var_builder::ShardedSafeTensors::var_builder(tensors, dtype, &device);
+    let vb = unsafe {
+        candle_nn::var_builder::ShardedSafeTensors::var_builder(&filenames, dtype, &device)?
+    };
     let llama = Llama::load(vb, &cache, &config, comm)?;
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

@@ -222,7 +219,7 @@ fn main() -> Result<()> {
         .to_vec();

     println!("starting the inference loop");
-    let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
+    let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
     let mut new_tokens = vec![];
     let start_gen = std::time::Instant::now();
     let mut index_pos = 0;
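For the tensor-parallel example the sharded builder gets the same treatment: it now mmaps the checkpoint files itself, so the call becomes `unsafe` and fallible. A fragment mirroring the hunk above, with `filenames`, `dtype`, and `device` assumed to be in scope:

```rust
// Each rank maps the same safetensors files; the sharded backend slices out
// this rank's portion of a tensor when it is first requested.
let vb = unsafe {
    candle_nn::var_builder::ShardedSafeTensors::var_builder(&filenames, dtype, &device)?
};
```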
candle-examples/examples/mistral/README.md (new file, +90)
@@ -0,0 +1,90 @@
# candle-mistral: 7b LLM with Apache 2.0 licensed weights

Mistral-7B-v0.1 is a pretrained generative LLM with 7 billion parameters. It outperforms all the publicly available 13b models
as of 2023-09-28. Weights (and the original Python model code) are released under the permissive Apache 2.0 license.

- [Blog post](https://mistral.ai/news/announcing-mistral-7b/) from Mistral announcing the model release.
- [Model card](https://huggingface.co/mistralai/Mistral-7B-v0.1) on the
  HuggingFace Hub.

This example supports the initial model as well as a quantized variant.

## Running the example

```bash
$ cargo run --example mistral --release --features cuda -- --prompt 'Write helloworld code in Rust' --sample-len 150

Generated text:
Write helloworld code in Rust
=============================

This is a simple example of how to write "Hello, world!" program in Rust.

## Compile and run

``bash
$ cargo build --release
   Compiling hello-world v0.1.0 (/home/user/rust/hello-world)
    Finished release [optimized] target(s) in 0.26s
$ ./target/release/hello-world
Hello, world!
``

## Source code

``rust
fn main() {
    println!("Hello, world!");
}
``

## License

This example is released under the terms
```

## Running the quantized version of the model

```bash
$ cargo run --example mistral --features accelerate --release -- \
$ --prompt "Here is a sample quick sort implementation in rust " --quantized -n 400
avx: false, neon: true, simd128: false, f16c: false
temp: 0.00 repeat-penalty: 1.10 repeat-last-n: 64
retrieved the files in 562.292µs
loaded the model in 1.100323667s
Here is a sample quick sort implementation in rust

``rust
fn quick_sort(arr: &mut [i32]) {
    if arr.len() <= 1 {
        return;
    }

    let pivot = arr[0];
    let mut left = vec![];
    let mut right = vec![];

    for i in 1..arr.len() {
        if arr[i] < pivot {
            left.push(arr[i]);
        } else {
            right.push(arr[i]);
        }
    }

    quick_sort(&mut left);
    quick_sort(&mut right);

    let mut i = 0;
    for _ in &left {
        arr[i] = left.pop().unwrap();
        i += 1;
    }

    for _ in &right {
        arr[i] = right.pop().unwrap();
        i += 1;
    }
}
``
226 tokens generated (10.91 token/s)
```
candle-examples/examples/mistral/main.rs (new file, +271)
@@ -0,0 +1,271 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::Parser;

use candle_transformers::models::mistral::{Config, Model as Mistral};
use candle_transformers::models::quantized_mistral::Model as QMistral;

use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

enum Model {
    Mistral(Mistral),
    Quantized(QMistral),
}

struct TextGeneration {
    model: Model,
    device: Device,
    tokenizer: TokenOutputStream,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            tokenizer: TokenOutputStream::new(tokenizer),
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            device: device.clone(),
        }
    }

    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        self.tokenizer.clear();
        let mut tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        for &t in tokens.iter() {
            if let Some(t) = self.tokenizer.next_token(t)? {
                print!("{t}")
            }
        }
        std::io::stdout().flush()?;

        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_token("</s>") {
            Some(token) => token,
            None => anyhow::bail!("cannot find the </s> token"),
        };
        let start_gen = std::time::Instant::now();
        for index in 0..sample_len {
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let start_pos = tokens.len().saturating_sub(context_size);
            let ctxt = &tokens[start_pos..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = match &mut self.model {
                Model::Mistral(m) => m.forward(&input, start_pos)?,
                Model::Quantized(m) => m.forward(&input, start_pos)?,
            };
            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token {
                break;
            }
            if let Some(t) = self.tokenizer.next_token(next_token)? {
                print!("{t}");
                std::io::stdout().flush()?;
            }
        }
        let dt = start_gen.elapsed();
        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
            print!("{rest}");
        }
        std::io::stdout().flush()?;
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        Ok(())
    }
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    use_flash_attn: bool,

    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 100)]
    sample_len: usize,

    #[arg(long, default_value = "lmz/candle-mistral")]
    model_id: String,

    #[arg(long, default_value = "main")]
    revision: String,

    #[arg(long)]
    tokenizer_file: Option<String>,

    #[arg(long)]
    weight_files: Option<String>,

    #[arg(long)]
    quantized: bool,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let repo = api.repo(Repo::with_revision(
        args.model_id,
        RepoType::Model,
        args.revision,
    ));
    let tokenizer_filename = match args.tokenizer_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("tokenizer.json")?,
    };
    let filenames = match args.weight_files {
        Some(files) => files
            .split(',')
            .map(std::path::PathBuf::from)
            .collect::<Vec<_>>(),
        None => {
            if args.quantized {
                vec![repo.get("model-q4k.gguf")?]
            } else {
                vec![
                    repo.get("pytorch_model-00001-of-00002.safetensors")?,
                    repo.get("pytorch_model-00002-of-00002.safetensors")?,
                ]
            }
        }
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let config = Config::config_7b_v0_1(args.use_flash_attn);
    let (model, device) = if args.quantized {
        let filename = &filenames[0];
        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
        let model = QMistral::new(&config, vb)?;
        (Model::Quantized(model), Device::Cpu)
    } else {
        let device = candle_examples::device(args.cpu)?;
        let dtype = if device.is_cuda() {
            DType::BF16
        } else {
            DType::F32
        };
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
        let model = Mistral::new(&config, vb)?;
        (Model::Mistral(model), device)
    };

    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        &device,
    );
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}
@@ -1,6 +1,6 @@
 use crate::nn::conv1d_weight_norm;
-use candle::{DType, IndexOp, Result, Tensor};
+use candle::{DType, IndexOp, Module, Result, Tensor};
-use candle_nn::{conv1d, Conv1d, Conv1dConfig, Module, VarBuilder};
+use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder};

 // Encodec Model
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py
@@ -199,25 +199,34 @@ impl EncodecResidualVectorQuantizer {
 // https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L226
 #[derive(Debug)]
 struct EncodecLSTM {
-    layers: Vec<(Tensor, Tensor, Tensor, Tensor)>,
+    layers: Vec<candle_nn::LSTM>,
 }

 impl EncodecLSTM {
     fn load(dim: usize, vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let vb = &vb.pp("lstm");
         let mut layers = vec![];
-        for i in 0..cfg.num_lstm_layers {
-            let w_hh = vb.get((4 * dim, dim), &format!("weight_hh_l{i}"))?;
-            let w_ih = vb.get((4 * dim, dim), &format!("weight_ih_l{i}"))?;
-            let b_hh = vb.get(4 * dim, &format!("bias_hh_l{i}"))?;
-            let b_ih = vb.get(4 * dim, &format!("bias_ih_l{i}"))?;
-            layers.push((w_hh, w_ih, b_hh, b_ih))
+        for layer_idx in 0..cfg.num_lstm_layers {
+            let config = candle_nn::LSTMConfig {
+                layer_idx,
+                ..Default::default()
+            };
+            let lstm = candle_nn::lstm(dim, dim, config, vb.clone())?;
+            layers.push(lstm)
         }
         Ok(Self { layers })
     }
+}

-    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
-        todo!()
+impl Module for EncodecLSTM {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        use candle_nn::RNN;
+        let mut xs = xs.clone();
+        for layer in self.layers.iter() {
+            let states = layer.seq(&xs)?;
+            xs = layer.states_to_tensor(&states)?;
+        }
+        Ok(xs)
     }
 }

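The refactor above swaps the hand-rolled LSTM weight tuples for `candle_nn::lstm` driven through the `RNN` trait. A minimal standalone sketch of that API, assuming fresh variables instead of the Encodec checkpoint and arbitrary shapes (none of which come from this diff):

```rust
use candle::{Device, Result, Tensor};
use candle_nn::{RNN, VarBuilder, VarMap};

fn lstm_demo() -> Result<Tensor> {
    let dev = Device::Cpu;
    // Fresh trainable variables stand in for checkpoint weights here.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, candle::DType::F32, &dev);
    let lstm = candle_nn::lstm(16, 16, candle_nn::LSTMConfig::default(), vb)?;
    // (batch, seq_len, features) input, as in EncodecLSTM::forward above.
    let xs = Tensor::zeros((1, 10, 16), candle::DType::F32, &dev)?;
    let states = lstm.seq(&xs)?; // one LSTM state per time step
    lstm.states_to_tensor(&states) // stack the hidden states back into a tensor
}
```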
@@ -247,7 +256,9 @@ impl EncodecConvTranspose1d {
             bias,
         })
     }
+}

+impl Module for EncodecConvTranspose1d {
     fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
         todo!()
     }
@@ -299,7 +310,9 @@ impl EncodecConv1d {
             conv,
         })
     }
+}

+impl Module for EncodecConv1d {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         // TODO: padding, depending on causal.
         let xs = self.conv.forward(xs)?;
@@ -340,7 +353,9 @@ impl EncodecResnetBlock {
             shortcut,
         })
     }
+}

+impl Module for EncodecResnetBlock {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let residual = xs.clone();
         let xs = xs.elu(1.)?;
@@ -439,8 +454,17 @@ impl EncodecEncoder {
         })
     }

-    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
-        todo!()
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let mut xs = xs.apply(&self.init_conv)?;
+        for (resnets, conv) in self.sampling_layers.iter() {
+            for resnet in resnets.iter() {
+                xs = xs.apply(resnet)?;
+            }
+            xs = xs.elu(1.0)?.apply(conv)?;
+        }
+        xs.apply(&self.final_lstm)?
+            .elu(1.0)?
+            .apply(&self.final_conv)
     }
 }

@@ -507,8 +531,15 @@ impl EncodecDecoder {
         })
     }

-    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
-        todo!()
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let mut xs = xs.apply(&self.init_conv)?.apply(&self.init_lstm)?;
+        for (conv, resnets) in self.sampling_layers.iter() {
+            xs = xs.elu(1.)?.apply(conv)?;
+            for resnet in resnets.iter() {
+                xs = xs.apply(resnet)?
+            }
+        }
+        xs.elu(1.)?.apply(&self.final_conv)
     }
 }

@@ -13,7 +13,6 @@ extern crate accelerate_src;
 mod encodec_model;
 mod musicgen_model;
 mod nn;
-mod t5_model;

 use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};

@@ -74,11 +73,9 @@ fn main() -> Result<()> {
         ))
         .get("model.safetensors")?,
     };
-    let model = unsafe { candle::safetensors::MmapedFile::new(model)? };
-    let model = model.deserialize()?;
-    let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device);
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DTYPE, &device)? };
    let config = GenConfig::small();
-    let model = MusicgenForConditionalGeneration::load(vb, config)?;
+    let mut model = MusicgenForConditionalGeneration::load(vb, config)?;

     let tokens = tokenizer
         .encode(args.prompt.as_str(), true)
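The hunk above replaces the older two-step MmapedFile/deserialize dance with the consolidated `VarBuilder::from_mmaped_safetensors` constructor. A minimal sketch of the new call in isolation (the path and dtype here are placeholders, not taken from the diff):

```rust
use candle::{DType, Device};
use candle_nn::VarBuilder;

fn load_weights() -> candle::Result<VarBuilder<'static>> {
    let device = Device::Cpu;
    let files = vec![std::path::PathBuf::from("model.safetensors")]; // placeholder path
    // Safety: the safetensors files must stay unmodified while mmap'ed.
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&files, DType::F32, &device)? };
    Ok(vb)
}
```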
@@ -1,9 +1,10 @@
-use crate::{encodec_model, t5_model};
+use crate::encodec_model;
 use candle::{DType, Device, Result, Tensor, D};
 use candle_nn::{
     embedding, layer_norm, linear_no_bias, Activation, Embedding, LayerNorm, Linear, Module,
     VarBuilder,
 };
+use candle_transformers::models::t5;

 // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83
 #[derive(Debug, Clone, PartialEq)]
@@ -39,7 +40,7 @@ impl Default for Config {
             num_attention_heads: 16,
             layerdrop: 0.0,
             use_cache: true,
-            activation_function: Activation::Gelu, // TODO: Handle old style gelu.
+            activation_function: Activation::Gelu,
             hidden_size: 1024,
             dropout: 0.1,
             attention_dropout: 0.0,
@@ -65,7 +66,7 @@ impl Config {
             num_attention_heads: 16,
             layerdrop: 0.0,
             use_cache: true,
-            activation_function: Activation::Gelu, // TODO: Handle old style gelu.
+            activation_function: Activation::Gelu,
             hidden_size: 1024,
             dropout: 0.1,
             attention_dropout: 0.0,
@@ -370,7 +371,7 @@ impl MusicgenForCausalLM {

 #[derive(Debug)]
 pub struct MusicgenForConditionalGeneration {
-    pub text_encoder: crate::t5_model::T5EncoderModel,
+    pub text_encoder: t5::T5EncoderModel,
     pub audio_encoder: crate::encodec_model::EncodecModel,
     pub decoder: MusicgenForCausalLM,
     cfg: GenConfig,
@@ -379,7 +380,7 @@ pub struct MusicgenForConditionalGeneration {
 #[derive(Debug, Clone, PartialEq)]
 pub struct GenConfig {
     musicgen: Config,
-    t5: crate::t5_model::Config,
+    t5: t5::Config,
     encodec: crate::encodec_model::Config,
 }

@@ -387,7 +388,7 @@ impl GenConfig {
     pub fn small() -> Self {
         Self {
             musicgen: Config::musicgen_small(),
-            t5: t5_model::Config::musicgen_small(),
+            t5: t5::Config::musicgen_small(),
             encodec: encodec_model::Config::musicgen_small(),
         }
     }
@@ -399,7 +400,7 @@ impl MusicgenForConditionalGeneration {
     }

     pub fn load(vb: VarBuilder, cfg: GenConfig) -> Result<Self> {
-        let text_encoder = t5_model::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?;
+        let text_encoder = t5::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?;
         let audio_encoder =
             encodec_model::EncodecModel::load(vb.pp("audio_encoder"), &cfg.encodec)?;
         let decoder = MusicgenForCausalLM::load(vb.pp("decoder"), &cfg.musicgen)?;
@@ -1,434 +0,0 @@
// T5 Text Encoder
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py

use candle::{DType, Result, Tensor, D};
use candle_nn::{embedding, linear_no_bias, Activation, Embedding, Linear, Module, VarBuilder};
use std::sync::Arc;

#[derive(Debug, Clone, PartialEq)]
pub struct Config {
    vocab_size: usize,
    d_model: usize,
    d_kv: usize,
    d_ff: usize,
    num_layers: usize,
    num_decoder_layers: Option<usize>,
    num_heads: usize,
    relative_attention_num_buckets: usize,
    relative_attention_max_distance: usize,
    dropout_rate: f64,
    layer_norm_epsilon: f64,
    initializer_factor: f64,
    feed_forward_proj: Activation,
    is_decoder: bool,
    is_encoder_decoder: bool,
    use_cache: bool,
    pad_token_id: usize,
    eos_token_id: usize,
}

impl Default for Config {
    fn default() -> Self {
        Self {
            vocab_size: 32128,
            d_model: 512,
            d_kv: 64,
            d_ff: 2048,
            num_layers: 6,
            num_decoder_layers: None,
            num_heads: 8,
            relative_attention_num_buckets: 32,
            relative_attention_max_distance: 128,
            dropout_rate: 0.1,
            layer_norm_epsilon: 1e-6,
            initializer_factor: 1.0,
            feed_forward_proj: Activation::Relu,
            is_decoder: false,
            is_encoder_decoder: true,
            use_cache: true,
            pad_token_id: 0,
            eos_token_id: 1,
        }
    }
}

impl Config {
    // https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L184
    pub fn musicgen_small() -> Self {
        Self {
            d_ff: 3072,
            d_kv: 64,
            d_model: 768,
            dropout_rate: 0.1,
            eos_token_id: 1,
            feed_forward_proj: Activation::Relu,
            initializer_factor: 1.0,
            is_decoder: false,
            is_encoder_decoder: true,
            layer_norm_epsilon: 1e-6,
            num_decoder_layers: Some(12),
            num_heads: 12,
            num_layers: 12,
            pad_token_id: 0,
            relative_attention_max_distance: 128,
            relative_attention_num_buckets: 32,
            use_cache: true,
            vocab_size: 32128,
        }
    }
}

#[derive(Debug)]
struct T5LayerNorm {
    weight: Tensor,
    variance_epsilon: f64,
}

impl T5LayerNorm {
    fn load(h: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
        let weight = vb.get(h, "weight")?;
        Ok(Self {
            weight,
            variance_epsilon: eps,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let dtype = xs.dtype();
        let xs_f32 = xs.to_dtype(DType::F32)?;
        // variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
        let xs = xs.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
        let xs = xs.to_dtype(dtype)?;
        let xs = xs.broadcast_mul(&self.weight)?;
        Ok(xs)
    }
}
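For reference, `T5LayerNorm::forward` above is RMS normalization rather than standard layer norm: there is no mean subtraction and no bias, and the variance is computed in f32. Reading the code as a formula, with `w` the `weight` tensor and epsilon the `variance_epsilon` field:

    y = w \odot \frac{x}{\sqrt{\operatorname{mean}(x^2) + \varepsilon}}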
#[derive(Debug)]
struct T5DenseActDense {
    wi: Linear,
    wo: Linear,
    act: Activation,
}

impl T5DenseActDense {
    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let wi = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi"))?;
        let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
        Ok(Self {
            wi,
            wo,
            act: Activation::Relu,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = self.wi.forward(xs)?;
        let xs = self.act.forward(&xs)?;
        let xs = self.wo.forward(&xs)?;
        Ok(xs)
    }
}

#[derive(Debug)]
struct T5LayerFF {
    dense_relu_dense: T5DenseActDense,
    layer_norm: T5LayerNorm,
}

impl T5LayerFF {
    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        // is_gated_act is not supported.
        let dense_relu_dense = T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?;
        let layer_norm =
            T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
        Ok(Self {
            dense_relu_dense,
            layer_norm,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let ys = self.layer_norm.forward(xs)?;
        let ys = self.dense_relu_dense.forward(&ys)?;
        let xs = (xs + ys)?;
        Ok(xs)
    }
}

#[derive(Debug)]
struct T5Attention {
    q: Linear,
    k: Linear,
    v: Linear,
    o: Linear,
    n_heads: usize,
    d_kv: usize,
    relative_attention_bias: Option<Embedding>,
    relative_attention_num_buckets: usize,
    relative_attention_max_distance: usize,
    inner_dim: usize,
}

impl T5Attention {
    fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let inner_dim = cfg.num_heads * cfg.d_kv;
        let q = linear_no_bias(cfg.d_model, inner_dim, vb.pp("q"))?;
        let k = linear_no_bias(cfg.d_model, inner_dim, vb.pp("k"))?;
        let v = linear_no_bias(cfg.d_model, inner_dim, vb.pp("v"))?;
        let o = linear_no_bias(inner_dim, cfg.d_model, vb.pp("o"))?;
        let relative_attention_bias = if h {
            let emb = embedding(
                cfg.relative_attention_num_buckets,
                cfg.num_heads,
                vb.pp("relative_attention_bias"),
            )?;
            Some(emb)
        } else {
            None
        };
        Ok(Self {
            q,
            k,
            v,
            o,
            n_heads: cfg.num_heads,
            d_kv: cfg.d_kv,
            relative_attention_bias,
            relative_attention_num_buckets: cfg.relative_attention_num_buckets,
            relative_attention_max_distance: cfg.relative_attention_max_distance,
            inner_dim,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        position_bias: Option<&Tensor>,
    ) -> Result<(Tensor, Option<Tensor>)> {
        // TODO: Apply the mask(s)?
        // TODO: kv caching.
        let (b_sz, seq_len) = (xs.dim(0)?, xs.dim(1)?);
        let q = self.q.forward(xs)?;
        let k = self.k.forward(xs)?;
        let v = self.v.forward(xs)?;
        let q = q
            .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
            .transpose(1, 2)?
            .contiguous()?;
        let k = k
            .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
            .transpose(1, 2)?
            .contiguous()?;
        let v = v
            .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
            .transpose(1, 2)?
            .contiguous()?;
        let scores = q.matmul(&k.t()?)?;

        let (scores, position_bias) = match position_bias {
            Some(position_bias) => ((scores + position_bias)?, Some(position_bias.clone())),
            None => match &self.relative_attention_bias {
                None => (scores, None),
                Some(relative_attention_bias) => {
                    let query_length = seq_len;
                    let key_length = seq_len;
                    // This only handles the bidirectional case.
                    let num_buckets = self.relative_attention_num_buckets as u32 / 2;
                    let max_exact = num_buckets / 2;
                    let relative_position = (0..query_length as u32)
                        .map(|i| {
                            (0..key_length as u32)
                                .map(|j| {
                                    if i < j {
                                        if j - i < max_exact {
                                            j - i + num_buckets
                                        } else {
                                            let b = f32::log(
                                                (j - i) as f32 / max_exact as f32,
                                                self.relative_attention_max_distance as f32
                                                    / max_exact as f32,
                                            ) * (num_buckets - max_exact) as f32;
                                            u32::min(
                                                max_exact + num_buckets + b as u32,
                                                self.relative_attention_num_buckets as u32 - 1,
                                            )
                                        }
                                    } else if i - j < max_exact {
                                        i - j
                                    } else {
                                        let b = f32::log(
                                            (i - j) as f32 / max_exact as f32,
                                            self.relative_attention_max_distance as f32
                                                / max_exact as f32,
                                        ) * (num_buckets - max_exact) as f32;
                                        max_exact + b as u32
                                    }
                                })
                                .collect::<Vec<u32>>()
                        })
                        .collect::<Vec<Vec<_>>>();
                    let relative_buckets = Tensor::new(relative_position, q.device())?;
                    let position_bias = relative_attention_bias
                        .forward(&relative_buckets)?
                        .permute((2, 0, 1))?
                        .unsqueeze(0)?;
                    ((scores + &position_bias)?, Some(position_bias))
                    // TODO: position_bias_masked?
                }
            },
        };

        let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1)?;
        let attn_output = attn_weights.matmul(&v)?;
        let attn_output = attn_output
            .transpose(1, 2)?
            .reshape((b_sz, seq_len, self.inner_dim))?;
        let attn_output = self.o.forward(&attn_output)?;
        Ok((attn_output, position_bias))
    }
}
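To make the bucketing in `forward` concrete: with the default `relative_attention_num_buckets = 32`, the bidirectional branch uses `num_buckets = 16` and `max_exact = 8`. A key 3 positions to the right of the query (`j - i = 3 < 8`) maps exactly to bucket `3 + 16 = 19`; a key 20 positions to the right falls in the logarithmic range, where `b = log(20/8) / log(128/8) * 8` is roughly 2.64, truncated to 2, giving bucket `8 + 16 + 2 = 26`.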
#[derive(Debug)]
struct T5LayerSelfAttention {
    self_attention: T5Attention,
    layer_norm: T5LayerNorm,
}

impl T5LayerSelfAttention {
    fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let self_attention = T5Attention::load(h, vb.pp("SelfAttention"), cfg)?;
        let layer_norm =
            T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
        Ok(Self {
            self_attention,
            layer_norm,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        position_bias: Option<&Tensor>,
    ) -> Result<(Tensor, Option<Tensor>)> {
        let normed_xs = self.layer_norm.forward(xs)?;
        let (ys, position_bias) = self.self_attention.forward(&normed_xs, position_bias)?;
        let ys = (xs + ys)?;
        Ok((ys, position_bias))
    }
}

#[derive(Debug)]
struct T5LayerCrossAttention {}

impl T5LayerCrossAttention {
    fn load(_vb: VarBuilder, _cfg: &Config) -> Result<Self> {
        todo!()
    }

    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
        todo!()
    }
}

#[derive(Debug)]
struct T5Block {
    self_attn: T5LayerSelfAttention,
    cross_attn: Option<T5LayerCrossAttention>,
    ff: T5LayerFF,
}

impl T5Block {
    fn load(has_relative_attention_bias: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let vb = vb.pp("layer");
        let self_attn = T5LayerSelfAttention::load(has_relative_attention_bias, vb.pp("0"), cfg)?;
        let cross_attn = if cfg.is_decoder {
            Some(T5LayerCrossAttention::load(vb.pp("1"), cfg)?)
        } else {
            None
        };
        let ff_i = if cross_attn.is_some() { 2 } else { 1 };
        let ff = T5LayerFF::load(vb.pp(&ff_i.to_string()), cfg)?;
        Ok(Self {
            self_attn,
            cross_attn,
            ff,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        position_bias: Option<&Tensor>,
    ) -> Result<(Tensor, Option<Tensor>)> {
        let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias)?;
        // TODO: clamp for f16?
        if let Some(cross_attn) = &self.cross_attn {
            xs = cross_attn.forward(&xs)?;
            // TODO: clamp for f16?
        }
        let xs = self.ff.forward(&xs)?;
        // TODO: clamp for f16?
        Ok((xs, position_bias))
    }
}

#[derive(Debug)]
struct T5Stack {
    block: Vec<T5Block>,
    shared: Arc<Embedding>,
    final_layer_norm: T5LayerNorm,
}

impl T5Stack {
    fn load(vb: VarBuilder, shared: &Arc<Embedding>, cfg: &Config) -> Result<Self> {
        let block = (0..cfg.num_layers)
            .map(|i| T5Block::load(i == 0, vb.pp(&format!("block.{i}")), cfg))
            .collect::<Result<Vec<_>>>()?;
        let final_layer_norm = T5LayerNorm::load(
            cfg.d_model,
            cfg.layer_norm_epsilon,
            vb.pp("final_layer_norm"),
        )?;
        Ok(Self {
            block,
            shared: shared.clone(),
            final_layer_norm,
        })
    }

    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
        let input_embeds = self.shared.as_ref().forward(input_ids)?;
        let (_b_sz, _seq_len) = (input_embeds.dim(0)?, input_embeds.dim(1)?);

        let mut hidden_states = input_embeds;
        let mut position_bias = None;
        for block in self.block.iter() {
            (hidden_states, position_bias) =
                block.forward(&hidden_states, position_bias.as_ref())?
        }
        let hidden_states = self.final_layer_norm.forward(&hidden_states)?;
        Ok(hidden_states)
    }
}

#[derive(Debug)]
pub struct T5EncoderModel {
    shared: Arc<Embedding>,
    encoder: T5Stack,
}

impl T5EncoderModel {
    pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
        let shared = Arc::new(shared);
        let encoder = T5Stack::load(vb.pp("encoder"), &shared, cfg)?;
        Ok(Self { shared, encoder })
    }

    pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
        let encoder_outputs = self.encoder.forward(input_ids)?;
        Ok(encoder_outputs)
    }
}
candle-examples/examples/phi/README.md (new file, 43 lines)
@@ -0,0 +1,43 @@
# candle-phi: 1.3b LLM with state-of-the-art performance for <10b models.

[Phi-1.5](https://huggingface.co/microsoft/phi-1_5) is a language model using
only 1.3 billion parameters but with state-of-the-art performance compared to
models with up to 10 billion parameters.

The candle implementation provides both the standard version and a
quantized variant.

## Running some examples

```bash
$ cargo run --example phi --release -- --prompt "def print_prime(n): "

def print_prime(n):
    print("Printing prime numbers")
    for i in range(2, n+1):
        if is_prime(i):
            print(i)

def is_prime(n):
    if n <= 1:
        return False
    for i in range(2, int(math.sqrt(n))+1):
        if n % i == 0:
            return False
    return True

$ cargo run --example phi --release -- \
  --prompt "Explain how to find the median of an array and write the corresponding python function.\nAnswer:" \
  --quantized --sample-len 200

Explain how to find the median of an array and write the corresponding python function.
Answer: The median is the middle value in an array. If the array has an even number of elements, the median is the average of the two middle values.

def median(arr):
    arr.sort()
    n = len(arr)
    if n % 2 == 0:
        return (arr[n//2 - 1] + arr[n//2]) / 2
    else:
        return arr[n//2]
```
candle-examples/examples/phi/main.rs (new file, 238 lines)
@@ -0,0 +1,238 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::Parser;

use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;

use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

enum Model {
    MixFormer(MixFormer),
    Quantized(QMixFormer),
}

struct TextGeneration {
    model: Model,
    device: Device,
    tokenizer: Tokenizer,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            tokenizer,
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            device: device.clone(),
        }
    }

    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        println!("starting the inference loop");
        print!("{prompt}");
        std::io::stdout().flush()?;
        let mut tokens = self
            .tokenizer
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();

        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
            Some(token) => *token,
            None => anyhow::bail!("cannot find the endoftext token"),
        };
        let start_gen = std::time::Instant::now();
        for index in 0..sample_len {
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = match &mut self.model {
                Model::MixFormer(m) => m.forward(&input)?,
                Model::Quantized(m) => m.forward(&input)?,
            };
            let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token {
                break;
            }
            let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
            print!("{token}");
            std::io::stdout().flush()?;
        }
        let dt = start_gen.elapsed();
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        Ok(())
    }
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 100)]
    sample_len: usize,

    #[arg(long, default_value = "microsoft/phi-1_5")]
    model_id: String,

    #[arg(long, default_value = "refs/pr/18")]
    revision: String,

    #[arg(long)]
    weight_file: Option<String>,

    #[arg(long)]
    quantized: bool,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let repo = api.repo(Repo::with_revision(
        args.model_id,
        RepoType::Model,
        args.revision,
    ));
    let tokenizer_filename = repo.get("tokenizer.json")?;
    let filename = match args.weight_file {
        Some(weight_file) => std::path::PathBuf::from(weight_file),
        None => {
            if args.quantized {
                api.model("lmz/candle-quantized-phi".to_string())
                    .get("model-q4k.gguf")?
            } else {
                repo.get("model.safetensors")?
            }
        }
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let config = Config::v1_5();
    let (model, device) = if args.quantized {
        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
        let model = QMixFormer::new(&config, vb)?;
        (Model::Quantized(model), Device::Cpu)
    } else {
        let device = candle_examples::device(args.cpu)?;
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
        let model = MixFormer::new(&config, vb)?;
        (Model::MixFormer(model), device)
    };
    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        &device,
    );
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}
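Both this example and the quantized-t5 one below gate sampling through `apply_repeat_penalty` over the last `repeat_last_n` tokens. Conceptually, the penalty rescales the logits of recently emitted tokens before sampling; the following is a sketch of the idea only, on a plain slice, and not the actual `candle_transformers::utils` implementation:

```rust
use std::collections::HashSet;

/// Hypothetical helper illustrating the repeat-penalty idea on raw logits.
fn penalize_repeats(logits: &mut [f32], penalty: f32, recent_tokens: &[u32]) {
    let seen: HashSet<u32> = recent_tokens.iter().copied().collect();
    for tok in seen {
        if let Some(l) = logits.get_mut(tok as usize) {
            // Push already-emitted tokens towards "less likely": divide
            // positive logits by the penalty, multiply negative ones by it.
            *l = if *l >= 0.0 { *l / penalty } else { *l * penalty };
        }
    }
}
```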
candle-examples/examples/quantized-t5/README.md (new file, 42 lines)
@@ -0,0 +1,42 @@
# candle-quantized-t5

This example uses a quantized version of the t5 model.

```bash
$ cargo run --example quantized-t5 --release -- --prompt "translate to German: A beautiful candle."
...
Eine schöne Kerze.
```

The weight file is automatically retrieved from the hub. It is also possible to
generate quantized weight files from the original safetensors file by using the
`tensor-tools` command line utility via:

```bash
$ cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
```

To use a different model, specify the `model-id`. For example, you can use
quantized [CoEdit models](https://huggingface.co/jbochi/candle-coedit-quantized).

```bash
$ cargo run --example quantized-t5 --release -- \
  --model-id "jbochi/candle-coedit-quantized" \
  --prompt "Make this text coherent: Their flight is weak. They run quickly through the tree canopy." \
  --temperature 0
...
Although their flight is weak, they run quickly through the tree canopy.
```

By default, it will look for `model.gguf` and `config.json`, but you can specify
custom local or remote `weight-file` and `config-file`s:

```bash
cargo run --example quantized-t5 --release -- \
  --model-id "jbochi/candle-coedit-quantized" \
  --weight-file "model-xl.gguf" \
  --config-file "config-xl.json" \
  --prompt "Rewrite to make this easier to understand: Note that a storm surge is what forecasters consider a hurricane's most treacherous aspect." \
  --temperature 0
...
Note that a storm surge is what forecasters consider a hurricane's most dangerous part.
```
candle-examples/examples/quantized-t5/main.rs (new file, 228 lines)
@@ -0,0 +1,228 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use std::io::Write;
use std::path::PathBuf;

use candle_transformers::models::quantized_t5 as t5;

use anyhow::{Error as E, Result};
use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, api::sync::ApiRepo, Repo, RepoType};
use tokenizers::Tokenizer;

#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
    T5Small,
    FlanT5Small,
    FlanT5Base,
    FlanT5Large,
    FlanT5Xl,
    FlanT5Xxl,
}

#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// The model repository to use on the HuggingFace hub.
    #[arg(long)]
    model_id: Option<String>,

    #[arg(long)]
    revision: Option<String>,

    #[arg(long)]
    weight_file: Option<String>,

    #[arg(long)]
    config_file: Option<String>,

    /// Disable the key-value cache used during decoding.
    #[arg(long, default_value = "false")]
    disable_cache: bool,

    /// The prompt to run the model on.
    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long, default_value_t = 0.8)]
    temperature: f64,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,

    /// The model size to use.
    #[arg(long, default_value = "t5-small")]
    which: Which,
}

struct T5ModelBuilder {
    device: Device,
    config: t5::Config,
    weights_filename: PathBuf,
}

impl T5ModelBuilder {
    pub fn load(args: &Args) -> Result<(Self, Tokenizer)> {
        let device = Device::Cpu;
        let default_model = "lmz/candle-quantized-t5".to_string();
        let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
            (Some(model_id), Some(revision)) => (model_id, revision),
            (Some(model_id), None) => (model_id, "main".to_string()),
            (None, Some(revision)) => (default_model, revision),
            (None, None) => (default_model, "main".to_string()),
        };

        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
        let api = Api::new()?;
        let api = api.repo(repo);
        let config_filename = match &args.config_file {
            Some(filename) => Self::get_local_or_remote_file(filename, &api)?,
            None => match args.which {
                Which::T5Small => api.get("config.json")?,
                Which::FlanT5Small => api.get("config-flan-t5-small.json")?,
                Which::FlanT5Base => api.get("config-flan-t5-base.json")?,
                Which::FlanT5Large => api.get("config-flan-t5-large.json")?,
                Which::FlanT5Xl => api.get("config-flan-t5-xl.json")?,
                Which::FlanT5Xxl => api.get("config-flan-t5-xxl.json")?,
            },
        };
        let tokenizer_filename = api.get("tokenizer.json")?;
        let weights_filename = match &args.weight_file {
            Some(filename) => Self::get_local_or_remote_file(filename, &api)?,
            None => match args.which {
                Which::T5Small => api.get("model.gguf")?,
                Which::FlanT5Small => api.get("model-flan-t5-small.gguf")?,
                Which::FlanT5Base => api.get("model-flan-t5-base.gguf")?,
                Which::FlanT5Large => api.get("model-flan-t5-large.gguf")?,
                Which::FlanT5Xl => api.get("model-flan-t5-xl.gguf")?,
                Which::FlanT5Xxl => api.get("model-flan-t5-xxl.gguf")?,
            },
        };
        let config = std::fs::read_to_string(config_filename)?;
        let mut config: t5::Config = serde_json::from_str(&config)?;
        config.use_cache = !args.disable_cache;
        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
        Ok((
            Self {
                device,
                config,
                weights_filename,
            },
            tokenizer,
        ))
    }

    pub fn build_model(&self) -> Result<t5::T5ForConditionalGeneration> {
        let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?;
        Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
    }

    fn get_local_or_remote_file(filename: &str, api: &ApiRepo) -> Result<PathBuf> {
        let local_filename = std::path::PathBuf::from(filename);
        if local_filename.exists() {
            Ok(local_filename)
        } else {
            Ok(api.get(filename)?)
        }
    }
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();

    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?;
    let device = &builder.device;
    let tokenizer = tokenizer
        .with_padding(None)
        .with_truncation(None)
        .map_err(E::msg)?;
    let tokens = tokenizer
        .encode(args.prompt, true)
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
    let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
    let mut model = builder.build_model()?;
    let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
    let temperature = if args.temperature <= 0. {
        None
    } else {
        Some(args.temperature)
    };
    let mut logits_processor = LogitsProcessor::new(299792458, temperature, args.top_p);
    let encoder_output = model.encode(&input_token_ids)?;
    let start = std::time::Instant::now();

    for index in 0.. {
        if output_token_ids.len() > 512 {
            break;
        }
        let decoder_token_ids = if index == 0 || !builder.config.use_cache {
            Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)?
        } else {
            let last_token = *output_token_ids.last().unwrap();
            Tensor::new(&[last_token], device)?.unsqueeze(0)?
        };
        let logits = model
            .decode(&decoder_token_ids, &encoder_output)?
            .squeeze(0)?;
        let logits = if args.repeat_penalty == 1. {
            logits
        } else {
            let start_at = output_token_ids.len().saturating_sub(args.repeat_last_n);
            candle_transformers::utils::apply_repeat_penalty(
                &logits,
                args.repeat_penalty,
                &output_token_ids[start_at..],
            )?
        };

        let next_token_id = logits_processor.sample(&logits)?;
        if next_token_id as usize == builder.config.eos_token_id {
            break;
        }
        output_token_ids.push(next_token_id);
        if let Some(text) = tokenizer.id_to_token(next_token_id) {
            let text = text.replace('▁', " ").replace("<0x0A>", "\n");
            print!("{text}");
            std::io::stdout().flush()?;
        }
    }
    let dt = start.elapsed();
    println!(
        "\n{} tokens generated ({:.2} token/s)\n",
        output_token_ids.len(),
        output_token_ids.len() as f64 / dt.as_secs_f64(),
    );
    Ok(())
}
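The decode loop above shows the usual key-value-cache pattern for encoder-decoder generation: on the first step, or whenever the cache is disabled, the full `output_token_ids` prefix is fed to the decoder; afterwards only the single most recent token is fed, since the cached attention state already covers the prefix. Generation stops at the configured `eos_token_id` or after 512 output tokens.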
candle-examples/examples/quantized/README.md (new file, 37 lines)
@@ -0,0 +1,37 @@
# candle-quantized-llama: Fast Inference of quantized LLaMA models

This example provides a quantized LLaMA model similar to
[llama.cpp](https://github.com/ggerganov/llama.cpp). It is based on candle's
built-in quantization methods. Supported features include:

- 2-bit, 3-bit, 4-bit, 5-bit, 6-bit and 8-bit integer quantization support.
- SIMD optimizations on Apple Silicon and x86.
- Support for the `gguf` and `ggml` file formats.

The weights are automatically downloaded for you from the [HuggingFace
Hub](https://huggingface.co/) on the first run. There are various command line
flags to use local files instead, run with `--help` to learn about them.

![](assets/aoc.gif)

## Running some examples

```bash
cargo run --example quantized --release -- --prompt "The best thing about coding in rust is "

> avx: true, neon: false, simd128: false, f16c: true
> temp: 0.80 repeat-penalty: 1.10 repeat-last-n: 64
> loaded 291 tensors (3.79GB) in 2.17s
> params: HParams { n_vocab: 32000, n_embd: 4096, n_mult: 256, n_head: 32, n_layer: 32, n_rot: 128, ftype: 2 }
> The best thing about coding in rust is 1.) that I don’t need to worry about memory leaks, 2.) speed and 3.) my program will compile even on old machines.
```

## Command-line flags

Run with `--help` to see all options.

- `--which`: specify the model to use, e.g. `7b`, `13b-chat`, `7b-code`.
- `--prompt interactive`: interactive mode where multiple prompts can be
  entered.
- `--model mymodelfile.gguf`: use a local model file rather than getting one
  from the hub.
candle-examples/examples/quantized/assets/aoc.gif (new binary file, 119 KiB, not shown)
@@ -12,7 +12,7 @@ use candle::quantized::{ggml_file, gguf_file};
 use candle::{Device, Tensor};
 use candle_transformers::generation::LogitsProcessor;

-mod model;
+use candle_transformers::models::quantized_llama as model;
 use model::ModelWeights;

 const DEFAULT_PROMPT: &str = "My favorite theorem is ";
@@ -44,6 +44,27 @@ enum Which {
     L13bCode,
     #[value(name = "34b-code")]
     L34bCode,
+    #[value(name = "7b-mistral")]
+    Mistral7b,
+    #[value(name = "7b-mistral-instruct")]
+    Mistral7bInstruct,
+}
+
+impl Which {
+    fn is_mistral(&self) -> bool {
+        match self {
+            Self::L7b
+            | Self::L13b
+            | Self::L70b
+            | Self::L7bChat
+            | Self::L13bChat
+            | Self::L70bChat
+            | Self::L7bCode
+            | Self::L13bCode
+            | Self::L34bCode => false,
+            Self::Mistral7b | Self::Mistral7bInstruct => true,
+        }
+    }
 }

 #[derive(Parser, Debug)]
@@ -71,6 +92,10 @@ struct Args {
     #[arg(long, default_value_t = 0.8)]
     temperature: f64,

+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
@@ -106,7 +131,12 @@ impl Args {
         Some(config) => std::path::PathBuf::from(config),
         None => {
             let api = hf_hub::api::sync::Api::new()?;
-            let api = api.model("hf-internal-testing/llama-tokenizer".to_string());
+            let repo = if self.which.is_mistral() {
+                "mistralai/Mistral-7B-v0.1"
+            } else {
+                "hf-internal-testing/llama-tokenizer"
+            };
+            let api = api.model(repo.to_string());
             api.get("tokenizer.json")?
         }
     };
@@ -136,6 +166,14 @@ impl Args {
             Which::L7bCode => ("TheBloke/CodeLlama-7B-GGUF", "codellama-7b.Q8_0.gguf"),
             Which::L13bCode => ("TheBloke/CodeLlama-13B-GGUF", "codellama-13b.Q8_0.gguf"),
             Which::L34bCode => ("TheBloke/CodeLlama-34B-GGUF", "codellama-34b.Q8_0.gguf"),
+            Which::Mistral7b => (
+                "TheBloke/Mistral-7B-v0.1-GGUF",
+                "mistral-7b-v0.1.Q4_K_S.gguf",
+            ),
+            Which::Mistral7bInstruct => (
+                "TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
+                "mistral-7b-instruct-v0.1.Q4_K_S.gguf",
+            ),
         };
         let api = hf_hub::api::sync::Api::new()?;
         let api = api.model(repo.to_string());
@@ -257,7 +295,7 @@ fn main() -> anyhow::Result<()> {
             | Which::L7bCode
             | Which::L13bCode
             | Which::L34bCode => 1,
-            Which::L70b | Which::L70bChat => 8,
+            Which::Mistral7b | Which::Mistral7bInstruct | Which::L70b | Which::L70bChat => 8,
         };
         ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
     }
@@ -287,7 +325,11 @@ fn main() -> anyhow::Result<()> {
                 prompt.pop();
             }
         }
-        prompt
+        if args.which.is_mistral() {
+            format!("[INST] {prompt} [/INST]")
+        } else {
+            prompt
+        }
     }
     };
     print!("{}", &prompt_str);
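The `is_mistral` branch above wraps raw prompts in Mistral's instruction template before tokenization, so a prompt such as `Hello` is fed to the instruct model as `[INST] Hello [/INST]` while the LLaMA variants keep the prompt untouched.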
@@ -310,7 +352,7 @@ fn main() -> anyhow::Result<()> {
         prompt_tokens
     };
     let mut all_tokens = vec![];
-    let mut logits_processor = LogitsProcessor::new(args.seed, temperature);
+    let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);

     let start_prompt_processing = std::time::Instant::now();
     let mut next_token = {
@@ -323,6 +365,8 @@ fn main() -> anyhow::Result<()> {
     all_tokens.push(next_token);
     print_token(next_token, &tokenizer);

+    let eos_token = *tokenizer.get_vocab(true).get("</s>").unwrap();
+
     let start_post_prompt = std::time::Instant::now();
     for index in 0..to_sample {
         let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
@@ -341,6 +385,9 @@ fn main() -> anyhow::Result<()> {
         next_token = logits_processor.sample(&logits)?;
         all_tokens.push(next_token);
         print_token(next_token, &tokenizer);
+        if next_token == eos_token {
+            break;
+        };
     }
     let dt = start_post_prompt.elapsed();
     println!(
candle-examples/examples/reinforcement-learning/README.md (new file, 16 lines)
@@ -0,0 +1,16 @@
# candle-reinforcement-learning

Reinforcement Learning examples for candle.

This has been tested with `gymnasium` version `0.29.1`. You can install the
Python package with:
```bash
pip install "gymnasium[accept-rom-license]"
```

In order to run the example, use the following command. Note the additional
`--package` flag to ensure that there is no conflict with the `candle-pyo3`
crate.
```bash
cargo run --example reinforcement-learning --features=pyo3 --package candle-examples
```

candle-examples/examples/reinforcement-learning/atari_wrappers.py (new file, 308 lines)
@@ -0,0 +1,308 @@
import gymnasium as gym
import numpy as np
from collections import deque
from PIL import Image
from multiprocessing import Process, Pipe

# atari_wrappers.py
class NoopResetEnv(gym.Wrapper):
    def __init__(self, env, noop_max=30):
        """Sample initial states by taking random number of no-ops on reset.
        No-op is assumed to be action 0.
        """
        gym.Wrapper.__init__(self, env)
        self.noop_max = noop_max
        self.override_num_noops = None
        assert env.unwrapped.get_action_meanings()[0] == 'NOOP'

    def reset(self):
        """ Do no-op action for a number of steps in [1, noop_max]."""
        self.env.reset()
        if self.override_num_noops is not None:
            noops = self.override_num_noops
        else:
            noops = self.unwrapped.np_random.integers(1, self.noop_max + 1) #pylint: disable=E1101
        assert noops > 0
        obs = None
        for _ in range(noops):
            obs, _, done, _ = self.env.step(0)
            if done:
                obs = self.env.reset()
        return obs


class FireResetEnv(gym.Wrapper):
    def __init__(self, env):
        """Take action on reset for environments that are fixed until firing."""
        gym.Wrapper.__init__(self, env)
        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def reset(self):
        self.env.reset()
        obs, _, done, _ = self.env.step(1)
        if done:
            self.env.reset()
        obs, _, done, _ = self.env.step(2)
        if done:
            self.env.reset()
        return obs


class ImageSaver(gym.Wrapper):
    def __init__(self, env, img_path, rank):
        gym.Wrapper.__init__(self, env)
        self._cnt = 0
        self._img_path = img_path
        self._rank = rank

    def step(self, action):
        step_result = self.env.step(action)
        obs, _, _, _ = step_result
        img = Image.fromarray(obs, 'RGB')
        img.save('%s/out%d-%05d.png' % (self._img_path, self._rank, self._cnt))
        self._cnt += 1
        return step_result

class EpisodicLifeEnv(gym.Wrapper):
    def __init__(self, env):
        """Make end-of-life == end-of-episode, but only reset on true game over.
        Done by DeepMind for the DQN and co. since it helps value estimation.
        """
        gym.Wrapper.__init__(self, env)
        self.lives = 0
        self.was_real_done = True

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self.was_real_done = done
        # check current lives, make loss of life terminal,
        # then update lives to handle bonus lives
        lives = self.env.unwrapped.ale.lives()
        if lives < self.lives and lives > 0:
            # for Qbert sometimes we stay in the lives == 0 condition for a few
            # frames, so it's important to keep lives > 0 so that we only reset
            # once the environment advertises done.
            done = True
        self.lives = lives
        return obs, reward, done, info

    def reset(self):
        """Reset only when lives are exhausted.
        This way all states are still reachable even though lives are episodic,
        and the learner need not know about any of this behind-the-scenes.
        """
        if self.was_real_done:
            obs = self.env.reset()
        else:
            # no-op step to advance from terminal/lost life state
            obs, _, _, _ = self.env.step(0)
        self.lives = self.env.unwrapped.ale.lives()
        return obs

class MaxAndSkipEnv(gym.Wrapper):
    def __init__(self, env, skip=4):
        """Return only every `skip`-th frame"""
        gym.Wrapper.__init__(self, env)
        # most recent raw observations (for max pooling across time steps)
        self._obs_buffer = deque(maxlen=2)
        self._skip = skip

    def step(self, action):
        """Repeat action, sum reward, and max over last observations."""
        total_reward = 0.0
        done = None
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            self._obs_buffer.append(obs)
            total_reward += reward
            if done:
                break
        max_frame = np.max(np.stack(self._obs_buffer), axis=0)

        return max_frame, total_reward, done, info

    def reset(self):
        """Clear past frame buffer and init. to first obs. from inner env."""
        self._obs_buffer.clear()
        obs = self.env.reset()
        self._obs_buffer.append(obs)
        return obs


class ClipRewardEnv(gym.RewardWrapper):
    def reward(self, reward):
        """Bin reward to {+1, 0, -1} by its sign."""
        return np.sign(reward)


class WarpFrame(gym.ObservationWrapper):
    def __init__(self, env):
        """Warp frames to 84x84 as done in the Nature paper and later work."""
        gym.ObservationWrapper.__init__(self, env)
        self.res = 84
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(self.res, self.res, 1), dtype='uint8')

    def observation(self, obs):
        frame = np.dot(obs.astype('float32'), np.array([0.299, 0.587, 0.114], 'float32'))
        frame = np.array(Image.fromarray(frame).resize((self.res, self.res),
            resample=Image.BILINEAR), dtype=np.uint8)
        return frame.reshape((self.res, self.res, 1))


class FrameStack(gym.Wrapper):
    def __init__(self, env, k):
        """Buffer observations and stack across channels (last axis)."""
        gym.Wrapper.__init__(self, env)
        self.k = k
        self.frames = deque([], maxlen=k)
        shp = env.observation_space.shape
        assert shp[2] == 1  # can only stack 1-channel frames
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(shp[0], shp[1], k), dtype='uint8')

    def reset(self):
        """Clear buffer and re-fill by duplicating the first observation."""
        ob = self.env.reset()
        for _ in range(self.k): self.frames.append(ob)
        return self.observation()

    def step(self, action):
        ob, reward, done, info = self.env.step(action)
        self.frames.append(ob)
        return self.observation(), reward, done, info

    def observation(self):
        assert len(self.frames) == self.k
        return np.concatenate(self.frames, axis=2)


def wrap_deepmind(env, episode_life=True, clip_rewards=True):
    """Configure environment for DeepMind-style Atari.

    Note: this does not include frame stacking!"""
    assert 'NoFrameskip' in env.spec.id  # required for DeepMind-style skip
    if episode_life:
        env = EpisodicLifeEnv(env)
    env = NoopResetEnv(env, noop_max=30)
    env = MaxAndSkipEnv(env, skip=4)
    if 'FIRE' in env.unwrapped.get_action_meanings():
        env = FireResetEnv(env)
    env = WarpFrame(env)
    if clip_rewards:
        env = ClipRewardEnv(env)
    return env


# envs.py
def make_env(env_id, img_dir, seed, rank):
    def _thunk():
        env = gym.make(env_id)
        env.reset(seed=(seed + rank))
        if img_dir is not None:
            env = ImageSaver(env, img_dir, rank)
        env = wrap_deepmind(env)
        env = WrapPyTorch(env)
        return env

    return _thunk


class WrapPyTorch(gym.ObservationWrapper):
    def __init__(self, env=None):
        super(WrapPyTorch, self).__init__(env)
        self.observation_space = gym.spaces.Box(0.0, 1.0, [1, 84, 84], dtype='float32')

    def observation(self, observation):
        return observation.transpose(2, 0, 1)


# vecenv.py
class VecEnv(object):
    """
    Vectorized environment base class
    """
    def step(self, vac):
        """
        Apply sequence of actions to sequence of environments
        actions -> (observations, rewards, news)

        where 'news' is a boolean vector indicating whether each element is new.
        """
        raise NotImplementedError
    def reset(self):
        """
        Reset all environments
        """
        raise NotImplementedError
    def close(self):
        pass


# subproc_vec_env.py
def worker(remote, env_fn_wrapper):
    env = env_fn_wrapper.x()
    while True:
        cmd, data = remote.recv()
        if cmd == 'step':
            ob, reward, done, info = env.step(data)
            if done:
                ob = env.reset()
            remote.send((ob, reward, done, info))
        elif cmd == 'reset':
            ob = env.reset()
            remote.send(ob)
        elif cmd == 'close':
            remote.close()
            break
        elif cmd == 'get_spaces':
            remote.send((env.action_space, env.observation_space))
        else:
            raise NotImplementedError


class CloudpickleWrapper(object):
    """
    Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
    """
    def __init__(self, x):
        self.x = x
    def __getstate__(self):
        import cloudpickle
        return cloudpickle.dumps(self.x)
    def __setstate__(self, ob):
        import pickle
        self.x = pickle.loads(ob)


class SubprocVecEnv(VecEnv):
    def __init__(self, env_fns):
        """
        envs: list of gym environments to run in subprocesses
        """
        nenvs = len(env_fns)
        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
        self.ps = [Process(target=worker, args=(work_remote, CloudpickleWrapper(env_fn)))
            for (work_remote, env_fn) in zip(self.work_remotes, env_fns)]
        for p in self.ps:
            p.start()

        self.remotes[0].send(('get_spaces', None))
        self.action_space, self.observation_space = self.remotes[0].recv()

    def step(self, actions):
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))
        results = [remote.recv() for remote in self.remotes]
        obs, rews, dones, infos = zip(*results)
        return np.stack(obs), np.stack(rews), np.stack(dones), infos

    def reset(self):
        for remote in self.remotes:
            remote.send(('reset', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def close(self):
        for remote in self.remotes:
            remote.send(('close', None))
        for p in self.ps:
            p.join()

    @property
    def num_envs(self):
        return len(self.remotes)


# Create the environment.
def make(env_name, img_dir, num_processes):
    envs = SubprocVecEnv([
        make_env(env_name, img_dir, 1337, i) for i in range(num_processes)
    ])
    return envs
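Two details of these wrappers are easiest to see as formulas. `WarpFrame` converts RGB frames to grayscale with the standard BT.601 luma weights used in its `np.dot` call, and `MaxAndSkipEnv` repeats each action for `skip` frames, summing the rewards and max-pooling the last two raw frames to remove Atari sprite flicker:

```latex
% WarpFrame grayscale conversion (BT.601 luma weights):
Y = 0.299\,R + 0.587\,G + 0.114\,B
% MaxAndSkipEnv with skip = k: the returned frame and reward are
\tilde{o}_t = \max\left(o_{t-1},\, o_t\right)\ \text{(elementwise)},\qquad
\tilde{r} = \sum_{i=0}^{k-1} r_{t+i}
```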

candle-examples/examples/reinforcement-learning/ddpg.rs (new file, 360 lines)
@@ -0,0 +1,360 @@
/* Deep Deterministic Policy Gradient.

   Continuous control with deep reinforcement learning, Lillicrap et al. 2015
   https://arxiv.org/abs/1509.02971

   See https://spinningup.openai.com/en/latest/algorithms/ddpg.html for a
   reference python implementation.
*/
use super::gym_env::GymEnv;
use candle::{DType, Device, Result, Tensor};
use candle_nn::{AdamW, Func, Optimizer, ParamsAdamW, VarBuilder, VarMap};
use rand::Rng;

// The impact of the q value of the next state on the current state's q value.
const GAMMA: f64 = 0.99;
// The weight for updating the target networks.
const TAU: f64 = 0.005;
// The capacity of the replay buffer used for sampling training data.
const REPLAY_BUFFER_CAPACITY: usize = 100_000;
// The training batch size for each training iteration.
const TRAINING_BATCH_SIZE: usize = 100;
// The total number of episodes.
const MAX_EPISODES: usize = 100;
// The maximum length of an episode.
const EPISODE_LENGTH: usize = 200;
// The number of training iterations after one episode finishes.
const TRAINING_ITERATIONS: usize = 200;

// Ornstein-Uhlenbeck process parameters.
const MU: f64 = 0.0;
const THETA: f64 = 0.15;
const SIGMA: f64 = 0.1;

const ACTOR_LEARNING_RATE: f64 = 1e-4;
const CRITIC_LEARNING_RATE: f64 = 1e-3;

struct OuNoise {
    mu: f64,
    theta: f64,
    sigma: f64,
    state: Tensor,
}

impl OuNoise {
    fn new(mu: f64, theta: f64, sigma: f64, num_actions: usize) -> Result<Self> {
        let state = Tensor::ones(num_actions, DType::F32, &Device::Cpu)?;
        Ok(Self {
            mu,
            theta,
            sigma,
            state,
        })
    }

    fn sample(&mut self) -> Result<Tensor> {
        // dx = theta * (mu - x) + sigma * N(0, 1)
        let dx = (((self.mu - &self.state)? * self.theta)?
            + (self.state.randn_like(0., 1.)? * self.sigma)?)?;
        self.state = (&self.state + dx)?;
        Ok(self.state.clone())
    }
}

struct ReplayBuffer {
    obs: Vec<Tensor>,
    next_obs: Vec<Tensor>,
    rewards: Vec<Tensor>,
    actions: Vec<Tensor>,
    capacity: usize,
    len: usize,
    i: usize,
}

impl ReplayBuffer {
    fn new(capacity: usize, num_obs: usize, num_actions: usize) -> Result<Self> {
        let cpu = Device::Cpu;
        let obs = vec![Tensor::zeros(num_obs, DType::F32, &cpu)?; capacity];
        let next_obs = vec![Tensor::zeros(num_obs, DType::F32, &cpu)?; capacity];
        let rewards = vec![Tensor::zeros(1, DType::F32, &cpu)?; capacity];
        let actions = vec![Tensor::zeros(num_actions, DType::F32, &cpu)?; capacity];
        Ok(Self {
            obs,
            next_obs,
            rewards,
            actions,
            capacity,
            len: 0,
            i: 0,
        })
    }

    fn push(&mut self, obs: &Tensor, actions: &Tensor, reward: &Tensor, next_obs: &Tensor) {
        // Overwrite the oldest entry once the buffer is full.
        let i = self.i % self.capacity;
        self.obs[i] = obs.clone();
        self.rewards[i] = reward.clone();
        self.actions[i] = actions.clone();
        self.next_obs[i] = next_obs.clone();
        self.i += 1;
        if self.len < self.capacity {
            self.len += 1;
        }
    }

    fn random_batch(&self, batch_size: usize) -> Result<Option<(Tensor, Tensor, Tensor, Tensor)>> {
        if self.len < 3 {
            return Ok(None);
        }

        let batch_size = batch_size.min(self.len - 1);
        let mut rng = rand::thread_rng();
        let batch_indexes: Vec<usize> =
            (0..batch_size).map(|_| rng.gen_range(0..self.len)).collect();

        let to_batch = |v: &[Tensor]| -> Result<Tensor> {
            let rows: Vec<Tensor> = batch_indexes.iter().map(|&i| v[i].clone()).collect();
            Tensor::stack(&rows, 0)
        };
        let states = to_batch(&self.obs)?;
        let next_states = to_batch(&self.next_obs)?;
        let actions = to_batch(&self.actions)?;
        let rewards = to_batch(&self.rewards)?;

        Ok(Some((states, actions, rewards, next_states)))
    }
}

struct Actor {
    varmap: VarMap,
    network: Func<'static>,
    num_obs: usize,
    num_actions: usize,
    opt: AdamW,
    learning_rate: f64,
}

impl Actor {
    fn new(num_obs: usize, num_actions: usize, learning_rate: f64) -> Result<Self> {
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
        let al1 = candle_nn::linear(num_obs, 400, vb.pp("al1"))?;
        let al2 = candle_nn::linear(400, 300, vb.pp("al2"))?;
        let al3 = candle_nn::linear(300, num_actions, vb.pp("al3"))?;
        let network = Func::new(move |xs| {
            xs.apply(&al1)?
                .relu()?
                .apply(&al2)?
                .relu()?
                .apply(&al3)?
                .tanh()
        });
        let adamw_params = ParamsAdamW {
            lr: learning_rate,
            ..Default::default()
        };
        let opt = AdamW::new(varmap.all_vars(), adamw_params)?;
        Ok(Self {
            network,
            num_obs,
            num_actions,
            varmap,
            opt,
            learning_rate,
        })
    }

    fn forward(&self, obs: &Tensor) -> Result<Tensor> {
        obs.apply(&self.network)
    }
}

struct Critic {
    varmap: VarMap,
    network: Func<'static>,
    num_obs: usize,
    num_actions: usize,
    opt: AdamW,
    learning_rate: f64,
}

impl Critic {
    fn new(num_obs: usize, num_actions: usize, learning_rate: f64) -> Result<Self> {
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
        let cl1 = candle_nn::linear(num_obs + num_actions, 400, vb.pp("cl1"))?;
        let cl2 = candle_nn::linear(400, 300, vb.pp("cl2"))?;
        let cl3 = candle_nn::linear(300, 1, vb.pp("cl3"))?;
        let network =
            Func::new(move |xs| xs.apply(&cl1)?.relu()?.apply(&cl2)?.relu()?.apply(&cl3));
        let adamw_params = ParamsAdamW {
            lr: learning_rate,
            ..Default::default()
        };
        let opt = AdamW::new(varmap.all_vars(), adamw_params)?;
        Ok(Self {
            network,
            varmap,
            num_obs,
            num_actions,
            opt,
            learning_rate,
        })
    }

    fn forward(&self, obs: &Tensor, actions: &Tensor) -> Result<Tensor> {
        let xs = Tensor::cat(&[actions, obs], 1)?;
        xs.apply(&self.network)
    }
}

/* TODO: enable tracking
fn track(dest: &mut nn::VarStore, src: &nn::VarStore, tau: f64) {
    tch::no_grad(|| {
        for (dest, src) in dest
            .trainable_variables()
            .iter_mut()
            .zip(src.trainable_variables().iter())
        {
            dest.copy_(&(tau * src + (1.0 - tau) * &*dest));
        }
    })
}
*/

struct Agent {
    actor: Actor,
    actor_target: Actor,

    critic: Critic,
    critic_target: Critic,

    replay_buffer: ReplayBuffer,

    ou_noise: OuNoise,

    train: bool,

    gamma: f64,
    tau: f64,
}

impl Agent {
    fn new(
        actor: Actor,
        critic: Critic,
        ou_noise: OuNoise,
        replay_buffer_capacity: usize,
        train: bool,
        gamma: f64,
        tau: f64,
    ) -> Result<Self> {
        // TODO: the target networks should start as copies of the main
        // networks and then track them, see the `track` sketch above.
        let actor_target = Actor::new(actor.num_obs, actor.num_actions, actor.learning_rate)?;
        let critic_target = Critic::new(critic.num_obs, critic.num_actions, critic.learning_rate)?;
        let replay_buffer =
            ReplayBuffer::new(replay_buffer_capacity, actor.num_obs, actor.num_actions)?;
        Ok(Self {
            actor,
            actor_target,
            critic,
            critic_target,
            replay_buffer,
            ou_noise,
            train,
            gamma,
            tau,
        })
    }

    fn actions(&mut self, obs: &Tensor) -> Result<Tensor> {
        // The actor works on batched inputs so temporarily add a batch dim;
        // detach as exploration should not track gradients.
        let actions = self.actor.forward(&obs.unsqueeze(0)?)?.squeeze(0)?.detach()?;
        let actions = if self.train {
            (actions + self.ou_noise.sample()?)?
        } else {
            actions
        };
        Ok(actions)
    }

    fn remember(&mut self, obs: &Tensor, actions: &Tensor, reward: &Tensor, next_obs: &Tensor) {
        self.replay_buffer.push(obs, actions, reward, next_obs);
    }

    fn train(&mut self, batch_size: usize) -> Result<()> {
        let (states, actions, rewards, next_states) =
            match self.replay_buffer.random_batch(batch_size)? {
                Some(v) => v,
                _ => return Ok(()), // We don't have enough samples for training yet.
            };

        let q_target = self
            .critic_target
            .forward(&next_states, &self.actor_target.forward(&next_states)?)?;
        let q_target = (rewards + (q_target * self.gamma)?.detach()?)?;

        let q = self.critic.forward(&states, &actions)?;

        let diff = (q_target - q)?;
        let critic_loss = diff.sqr()?.mean_all()?;
        self.critic.opt.backward_step(&critic_loss)?;

        let actor_loss = self
            .critic
            .forward(&states, &self.actor.forward(&states)?)?
            .mean_all()?
            .neg()?;
        self.actor.opt.backward_step(&actor_loss)?;

        // TODO: once tracking is enabled, softly update the target networks
        // here using `self.tau`, see the `track` sketch above.
        Ok(())
    }
}

pub fn run() -> Result<()> {
    let env = GymEnv::new("Pendulum-v1")?;
    println!("action space: {}", env.action_space());
    println!("observation space: {:?}", env.observation_space());

    let num_obs = env.observation_space().iter().product::<usize>();
    let num_actions = env.action_space();

    let actor = Actor::new(num_obs, num_actions, ACTOR_LEARNING_RATE)?;
    let critic = Critic::new(num_obs, num_actions, CRITIC_LEARNING_RATE)?;
    let ou_noise = OuNoise::new(MU, THETA, SIGMA, num_actions)?;
    let mut agent = Agent::new(
        actor,
        critic,
        ou_noise,
        REPLAY_BUFFER_CAPACITY,
        true,
        GAMMA,
        TAU,
    )?;

    for episode in 0..MAX_EPISODES as u64 {
        let mut obs = env.reset(episode)?;

        let mut total_reward = 0.0;
        for _ in 0..EPISODE_LENGTH {
            // Pendulum-v1 expects actions in [-2, 2] while the actor emits
            // tanh outputs in [-1, 1], hence the scaling before stepping.
            let action = 2.0 * agent.actions(&obs)?.to_vec1::<f32>()?[0];
            let action = action.clamp(-2.0, 2.0);
            let step = env.step(vec![action])?;
            total_reward += step.reward;

            let action_t = Tensor::new(vec![action], &Device::Cpu)?;
            let reward_t = Tensor::new(vec![step.reward as f32], &Device::Cpu)?;
            agent.remember(&obs, &action_t, &reward_t, &step.obs);

            if step.is_done {
                break;
            }
            obs = step.obs;
        }

        println!("episode {episode} with total reward of {total_reward}");

        for _ in 0..TRAINING_ITERATIONS {
            agent.train(TRAINING_BATCH_SIZE)?;
        }
    }

    Ok(())
}
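For reference, the quantities this file computes (and the target tracking it still leaves as a TODO) are the standard DDPG updates from Lillicrap et al. 2015, written out here with the constants defined above:

```latex
% Ornstein-Uhlenbeck exploration noise (MU, THETA, SIGMA):
x_{t+1} = x_t + \theta\,(\mu - x_t) + \sigma\,\varepsilon_t,
\qquad \varepsilon_t \sim \mathcal{N}(0, I)
% Critic target and loss (GAMMA):
y = r + \gamma\, Q'(s', \mu'(s')), \qquad
L_{\mathrm{critic}} = \mathbb{E}\left[(Q(s, a) - y)^2\right]
% Actor loss and soft target update (TAU):
L_{\mathrm{actor}} = -\,\mathbb{E}\left[Q(s, \mu(s))\right], \qquad
\theta' \leftarrow \tau\,\theta + (1 - \tau)\,\theta'
```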

candle-examples/examples/reinforcement-learning/gym_env.rs (new file, 108 lines)
@@ -0,0 +1,108 @@
#![allow(unused)]
//! Wrappers around the Python API of Gymnasium (the new version of OpenAI gym)
use candle::{Device, Result, Tensor};
use pyo3::prelude::*;
use pyo3::types::PyDict;

/// The return value for a step.
#[derive(Debug)]
pub struct Step<A> {
    pub obs: Tensor,
    pub action: A,
    pub reward: f64,
    pub is_done: bool,
}

impl<A: Copy> Step<A> {
    /// Returns a copy of this step changing the observation tensor.
    pub fn copy_with_obs(&self, obs: &Tensor) -> Step<A> {
        Step {
            obs: obs.clone(),
            action: self.action,
            reward: self.reward,
            is_done: self.is_done,
        }
    }
}

/// An OpenAI Gym session.
pub struct GymEnv {
    env: PyObject,
    action_space: usize,
    observation_space: Vec<usize>,
}

fn w(res: PyErr) -> candle::Error {
    candle::Error::wrap(res)
}

impl GymEnv {
    /// Creates a new session of the specified OpenAI Gym environment.
    pub fn new(name: &str) -> Result<GymEnv> {
        Python::with_gil(|py| {
            let gym = py.import("gymnasium")?;
            let make = gym.getattr("make")?;
            let env = make.call1((name,))?;
            let action_space = env.getattr("action_space")?;
            let action_space = if let Ok(val) = action_space.getattr("n") {
                val.extract()?
            } else {
                let action_space: Vec<usize> = action_space.getattr("shape")?.extract()?;
                action_space[0]
            };
            let observation_space = env.getattr("observation_space")?;
            let observation_space = observation_space.getattr("shape")?.extract()?;
            Ok(GymEnv {
                env: env.into(),
                action_space,
                observation_space,
            })
        })
        .map_err(w)
    }

    /// Resets the environment, returning the observation tensor.
    pub fn reset(&self, seed: u64) -> Result<Tensor> {
        let obs: Vec<f32> = Python::with_gil(|py| {
            let kwargs = PyDict::new(py);
            kwargs.set_item("seed", seed)?;
            let obs = self.env.call_method(py, "reset", (), Some(kwargs))?;
            obs.as_ref(py).get_item(0)?.extract()
        })
        .map_err(w)?;
        Tensor::new(obs, &Device::Cpu)
    }

    /// Applies an environment step using the specified action.
    pub fn step<A: pyo3::IntoPy<pyo3::Py<pyo3::PyAny>> + Clone>(
        &self,
        action: A,
    ) -> Result<Step<A>> {
        let (obs, reward, is_done) = Python::with_gil(|py| {
            let step = self.env.call_method(py, "step", (action.clone(),), None)?;
            let step = step.as_ref(py);
            let obs: Vec<f32> = step.get_item(0)?.extract()?;
            let reward: f64 = step.get_item(1)?.extract()?;
            let is_done: bool = step.get_item(2)?.extract()?;
            Ok((obs, reward, is_done))
        })
        .map_err(w)?;
        let obs = Tensor::new(obs, &Device::Cpu)?;
        Ok(Step {
            obs,
            reward,
            is_done,
            action,
        })
    }

    /// Returns the number of allowed actions for this environment.
    pub fn action_space(&self) -> usize {
        self.action_space
    }

    /// Returns the shape of the observation tensors.
    pub fn observation_space(&self) -> &[usize] {
        &self.observation_space
    }
}
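A minimal sketch of how this wrapper is meant to be driven, using only the `GymEnv` API above; the `policy` closure is a hypothetical stand-in for an agent and not part of the example:

```rust
// Minimal episode rollout against GymEnv (sketch; `policy` is hypothetical,
// `GymEnv`, `Tensor` and `Result` come from the file above).
fn rollout(env: &GymEnv, policy: impl Fn(&Tensor) -> f32) -> Result<f64> {
    let mut obs = env.reset(/* seed */ 0)?;
    let mut total_reward = 0.0;
    loop {
        // `step` is generic over the action type; a Vec<f32> works for
        // continuous one-dimensional action spaces such as Pendulum-v1.
        let step = env.step(vec![policy(&obs)])?;
        total_reward += step.reward;
        if step.is_done {
            break;
        }
        obs = step.obs;
    }
    Ok(total_reward)
}
```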

candle-examples/examples/reinforcement-learning/main.rs (new file, 76 lines)
@@ -0,0 +1,76 @@
#![allow(unused)]

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

mod ddpg;
mod gym_env;
mod vec_gym_env;

use candle::Result;
use clap::Parser;
use rand::Rng;

// The total number of episodes.
const MAX_EPISODES: usize = 100;
// The maximum length of an episode.
const EPISODE_LENGTH: usize = 200;

#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();

    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    let env = gym_env::GymEnv::new("Pendulum-v1")?;
    println!("action space: {}", env.action_space());
    println!("observation space: {:?}", env.observation_space());

    let _num_obs = env.observation_space().iter().product::<usize>();
    let _num_actions = env.action_space();

    let mut rng = rand::thread_rng();

    for episode in 0..MAX_EPISODES {
        let mut obs = env.reset(episode as u64)?;

        let mut total_reward = 0.0;
        for _ in 0..EPISODE_LENGTH {
            let actions = rng.gen_range(-2.0..2.0);

            let step = env.step(vec![actions])?;
            total_reward += step.reward;

            if step.is_done {
                break;
            }
            obs = step.obs;
        }

        println!("episode {episode} with total reward of {total_reward}");
    }
    Ok(())
}

candle-examples/examples/reinforcement-learning/vec_gym_env.rs (new file, 91 lines)
@@ -0,0 +1,91 @@
#![allow(unused)]
//! Vectorized version of the gym environment.
use candle::{DType, Device, Result, Tensor};
use pyo3::prelude::*;
use pyo3::types::PyDict;

#[derive(Debug)]
pub struct Step {
    pub obs: Tensor,
    pub reward: Tensor,
    pub is_done: Tensor,
}

pub struct VecGymEnv {
    env: PyObject,
    action_space: usize,
    observation_space: Vec<usize>,
}

fn w(res: PyErr) -> candle::Error {
    candle::Error::wrap(res)
}

impl VecGymEnv {
    pub fn new(name: &str, img_dir: Option<&str>, nprocesses: usize) -> Result<VecGymEnv> {
        Python::with_gil(|py| {
            let sys = py.import("sys")?;
            let path = sys.getattr("path")?;
            let _ = path.call_method1(
                "append",
                ("candle-examples/examples/reinforcement-learning",),
            )?;
            let gym = py.import("atari_wrappers")?;
            let make = gym.getattr("make")?;
            let env = make.call1((name, img_dir, nprocesses))?;
            let action_space = env.getattr("action_space")?;
            let action_space = action_space.getattr("n")?.extract()?;
            let observation_space = env.getattr("observation_space")?;
            let observation_space: Vec<usize> = observation_space.getattr("shape")?.extract()?;
            let observation_space =
                [vec![nprocesses].as_slice(), observation_space.as_slice()].concat();
            Ok(VecGymEnv {
                env: env.into(),
                action_space,
                observation_space,
            })
        })
        .map_err(w)
    }

    pub fn reset(&self) -> Result<Tensor> {
        let obs = Python::with_gil(|py| {
            let obs = self.env.call_method0(py, "reset")?;
            let obs = obs.call_method0(py, "flatten")?;
            obs.extract::<Vec<f32>>(py)
        })
        .map_err(w)?;
        Tensor::new(obs, &Device::Cpu)?.reshape(self.observation_space.as_slice())
    }

    pub fn step(&self, action: Vec<usize>) -> Result<Step> {
        let (obs, reward, is_done) = Python::with_gil(|py| {
            let step = self.env.call_method(py, "step", (action,), None)?;
            let step = step.as_ref(py);
            let obs = step.get_item(0)?.call_method("flatten", (), None)?;
            let obs_buffer = pyo3::buffer::PyBuffer::get(obs)?;
            let obs: Vec<u8> = obs_buffer.to_vec(py)?;
            let reward: Vec<f32> = step.get_item(1)?.extract()?;
            let is_done: Vec<f32> = step.get_item(2)?.extract()?;
            Ok((obs, reward, is_done))
        })
        .map_err(w)?;
        let obs = Tensor::from_vec(obs, self.observation_space.as_slice(), &Device::Cpu)?
            .to_dtype(DType::F32)?;
        let reward = Tensor::new(reward, &Device::Cpu)?;
        let is_done = Tensor::new(is_done, &Device::Cpu)?;
        Ok(Step {
            obs,
            reward,
            is_done,
        })
    }

    pub fn action_space(&self) -> usize {
        self.action_space
    }

    pub fn observation_space(&self) -> &[usize] {
        &self.observation_space
    }
}

candle-examples/examples/segment-anything/README.md (new file, 44 lines)
@@ -0,0 +1,44 @@
# candle-segment-anything: Segment-Anything Model

This example is based on Meta AI [Segment-Anything
Model](https://github.com/facebookresearch/segment-anything). This model
provides a robust and fast image segmentation pipeline that can be tweaked via
some prompting (requesting some points to be in the target mask, requesting some
points to be part of the background so _not_ in the target mask, specifying some
bounding box).

The default backbone can be replaced by the smaller and faster TinyViT model
based on [MobileSAM](https://github.com/ChaoningZhang/MobileSAM).

## Running some example.

```bash
cargo run --example segment-anything --release -- \
  --image candle-examples/examples/yolo-v8/assets/bike.jpg \
  --use-tiny \
  --point 0.6,0.6 --point 0.6,0.55
```

Running this command generates a `sam_merged.jpg` file containing the original
image with a blue overlay of the selected mask. The red dots represent the prompt
specified by `--point 0.6,0.6 --point 0.6,0.55`; this prompt is assumed to be part
of the target mask.

The values used for `--point` should be a comma delimited pair of float values.
They are proportional to the image dimension, i.e. use 0.5 for the image center.

Original image:
![Leading image]()

Segment results by prompting with a single point `--point 0.6,0.55`:
![Leading image]()

Segment results by prompting with multiple points `--point 0.6,0.6 --point 0.6,0.55`:
![Leading image]()

### Command-line flags
- `--use-tiny`: use the TinyViT based MobileSAM backbone rather than the default
  one.
- `--point`: specifies the location of the target points.
- `--threshold`: sets the threshold value to be part of the mask, a negative
  value results in a larger mask and can be specified via `--threshold=-1.2`.
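To make the proportional coordinates concrete, here is a small sketch (not part of the example itself) mirroring the scaling the example applies before drawing the prompt dots:

```rust
// Sketch: map a proportional --point value onto pixel coordinates.
fn to_pixels(point: (f64, f64), width: u32, height: u32) -> (i32, i32) {
    let (x, y) = point;
    ((x * width as f64) as i32, (y * height as f64) as i32)
}

// For a 640x480 image, --point 0.6,0.55 lands at pixel (384, 264).
```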
candle-examples/examples/segment-anything/assets/sam_merged.jpg (new binary file, 157 KiB, not shown)
(two more binary image assets added, 158 KiB each, not shown)
candle-examples/examples/segment-anything/main.rs (modified)
@@ -7,107 +7,11 @@ extern crate intel_mkl_src;
 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;
 
-pub mod model_image_encoder;
-pub mod model_mask_decoder;
-pub mod model_prompt_encoder;
-pub mod model_sam;
-pub mod model_transformer;
-
-use candle::{DType, Result, Tensor};
-use candle_nn::{Module, VarBuilder};
+use candle::DType;
+use candle_nn::VarBuilder;
+use candle_transformers::models::segment_anything::sam;
 use clap::Parser;
 
-pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
-    let inner = if bias {
-        candle_nn::linear(in_dim, out_dim, vb)?
-    } else {
-        candle_nn::linear_no_bias(in_dim, out_dim, vb)?
-    };
-    let span = tracing::span!(tracing::Level::TRACE, "linear");
-    Ok(Linear { inner, span })
-}
-
-#[derive(Debug)]
-pub struct LayerNorm2d {
-    weight: Tensor,
-    bias: Tensor,
-    num_channels: usize,
-    eps: f64,
-}
-
-impl LayerNorm2d {
-    pub fn new(num_channels: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
-        let weight = vb.get(num_channels, "weight")?;
-        let bias = vb.get(num_channels, "bias")?;
-        Ok(Self {
-            weight,
-            bias,
-            num_channels,
-            eps,
-        })
-    }
-}
-
-impl Module for LayerNorm2d {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let u = xs.mean_keepdim(1)?;
-        let xs = xs.broadcast_sub(&u)?;
-        let s = xs.sqr()?.mean_keepdim(1)?;
-        let xs = xs.broadcast_div(&(s + self.eps)?.sqrt()?)?;
-        xs.broadcast_mul(&self.weight.reshape((1, self.num_channels, 1, 1))?)?
-            .broadcast_add(&self.bias.reshape((1, self.num_channels, 1, 1))?)
-    }
-}
-
-#[derive(Debug)]
-pub struct MlpBlock {
-    lin1: Linear,
-    lin2: Linear,
-    activation: candle_nn::Activation,
-    span: tracing::Span,
-}
-
-impl MlpBlock {
-    pub fn new(
-        embedding_dim: usize,
-        mlp_dim: usize,
-        activation: candle_nn::Activation,
-        vb: VarBuilder,
-    ) -> Result<Self> {
-        let lin1 = linear(vb.pp("lin1"), embedding_dim, mlp_dim, true)?;
-        let lin2 = linear(vb.pp("lin2"), mlp_dim, embedding_dim, true)?;
-        let span = tracing::span!(tracing::Level::TRACE, "mlp-block");
-        Ok(Self {
-            lin1,
-            lin2,
-            activation,
-            span,
-        })
-    }
-}
-
-impl Module for MlpBlock {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        xs.apply(&self.lin1)?
-            .apply(&self.activation)?
-            .apply(&self.lin2)
-    }
-}
-
-#[derive(Debug)]
-pub struct Linear {
-    inner: candle_nn::Linear,
-    span: tracing::Span,
-}
-
-impl Module for Linear {
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        self.inner.forward(x)
-    }
-}
-
 #[derive(Parser)]
 struct Args {
     #[arg(long)]
@@ -123,15 +27,28 @@ struct Args {
     #[arg(long)]
     generate_masks: bool,
 
-    #[arg(long, default_value_t = 0.5)]
-    point_x: f64,
+    /// List of x,y coordinates, between 0 and 1 (0.5 is at the middle of the image). These points
+    /// should be part of the generated mask.
+    #[arg(long)]
+    point: Vec<String>,
 
-    #[arg(long, default_value_t = 0.5)]
-    point_y: f64,
+    /// List of x,y coordinates, between 0 and 1 (0.5 is at the middle of the image). These points
+    /// should not be part of the generated mask and should be part of the background instead.
+    #[arg(long)]
+    neg_point: Vec<String>,
+
+    /// The detection threshold for the mask, 0 is the default value, negative values mean a larger
+    /// mask, positive makes the mask more selective.
+    #[arg(long, default_value_t = 0.)]
+    threshold: f32,
 
     /// Enable tracing (generates a trace-timestamp.json file).
     #[arg(long)]
     tracing: bool,
+
+    /// Use the TinyViT based models from MobileSAM
+    #[arg(long)]
+    use_tiny: bool,
 }
 
 pub fn main() -> anyhow::Result<()> {
@@ -149,28 +66,9 @@ pub fn main() -> anyhow::Result<()> {
 
     let device = candle_examples::device(args.cpu)?;
 
-    let (image, initial_h, initial_w) = if args.image.ends_with(".safetensors") {
-        let mut tensors = candle::safetensors::load(&args.image, &device)?;
-        let image = match tensors.remove("image") {
-            Some(image) => image,
-            None => {
-                if tensors.len() != 1 {
-                    anyhow::bail!("multiple tensors in '{}'", args.image)
-                }
-                tensors.into_values().next().unwrap()
-            }
-        };
-        let image = if image.rank() == 4 {
-            image.get(0)?
-        } else {
-            image
-        };
-        let (_c, h, w) = image.dims3()?;
-        (image, h, w)
-    } else {
-        let (image, h, w) = candle_examples::load_image(&args.image, Some(model_sam::IMAGE_SIZE))?;
-        (image.to_device(&device)?, h, w)
-    };
+    let (image, initial_h, initial_w) =
+        candle_examples::load_image(&args.image, Some(sam::IMAGE_SIZE))?;
+    let image = image.to_device(&device)?;
     println!("loaded image {image:?}");
 
     let model = match args.model {
@@ -178,13 +76,20 @@ pub fn main() -> anyhow::Result<()> {
         None => {
             let api = hf_hub::api::sync::Api::new()?;
             let api = api.model("lmz/candle-sam".to_string());
-            api.get("sam_vit_b_01ec64.safetensors")?
+            let filename = if args.use_tiny {
+                "mobile_sam-tiny-vitt.safetensors"
+            } else {
+                "sam_vit_b_01ec64.safetensors"
+            };
+            api.get(filename)?
         }
     };
-    let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
-    let weights = weights.deserialize()?;
-    let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
-    let sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
+    let sam = if args.use_tiny {
+        sam::Sam::new_tiny(vb)? // tiny vit_t
+    } else {
+        sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b
+    };
 
     if args.generate_masks {
         // Default options similar to the Python version.
@@ -196,7 +101,7 @@ pub fn main() -> anyhow::Result<()> {
             /* crop_n_points_downscale_factor */ 1,
         )?;
         for (idx, bbox) in bboxes.iter().enumerate() {
-            println!("{bbox:?}");
+            println!("{idx} {bbox:?}");
             let mask = (&bbox.data.to_dtype(DType::U8)? * 255.)?;
             let (h, w) = mask.dims2()?;
             let mask = mask.broadcast_as((3, h, w))?;
@@ -208,66 +113,69 @@ pub fn main() -> anyhow::Result<()> {
             )?;
         }
     } else {
-        let point = Some((args.point_x, args.point_y));
+        let iter_points = args.point.iter().map(|p| (p, true));
+        let iter_neg_points = args.neg_point.iter().map(|p| (p, false));
+        let points = iter_points
+            .chain(iter_neg_points)
+            .map(|(point, b)| {
+                use std::str::FromStr;
+                let xy = point.split(',').collect::<Vec<_>>();
+                if xy.len() != 2 {
+                    anyhow::bail!("expected format for points is 0.4,0.2")
+                }
+                Ok((f64::from_str(xy[0])?, f64::from_str(xy[1])?, b))
+            })
+            .collect::<anyhow::Result<Vec<_>>>()?;
         let start_time = std::time::Instant::now();
-        let (mask, iou_predictions) = sam.forward(&image, point, false)?;
+        let (mask, iou_predictions) = sam.forward(&image, &points, false)?;
         println!(
            "mask generated in {:.2}s",
            start_time.elapsed().as_secs_f32()
         );
         println!("mask:\n{mask}");
-        println!("iou_predictions: {iou_predictions:?}");
+        println!("iou_predictions: {iou_predictions}");
 
-        // Save the mask as an image.
-        let mask = (mask.ge(0f32)? * 255.)?;
+        let mask = (mask.ge(args.threshold)? * 255.)?;
         let (_one, h, w) = mask.dims3()?;
         let mask = mask.expand((3, h, w))?;
-        candle_examples::save_image_resize(&mask, "sam_mask.png", initial_h, initial_w)?;
 
-        if !args.image.ends_with(".safetensors") {
-            let mut img = image::io::Reader::open(&args.image)?
-                .decode()
-                .map_err(candle::Error::wrap)?;
-            let mask_pixels = mask.permute((1, 2, 0))?.flatten_all()?.to_vec1::<u8>()?;
-            let mask_img: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
-                match image::ImageBuffer::from_raw(w as u32, h as u32, mask_pixels) {
-                    Some(image) => image,
-                    None => anyhow::bail!("error saving merged image"),
-                };
-            let mask_img = image::DynamicImage::from(mask_img).resize_to_fill(
-                img.width(),
-                img.height(),
-                image::imageops::FilterType::CatmullRom,
-            );
-            for x in 0..img.width() {
-                for y in 0..img.height() {
-                    let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_img, x, y);
-                    if mask_p.0[0] > 100 {
-                        let mut img_p = imageproc::drawing::Canvas::get_pixel(&img, x, y);
-                        img_p.0[2] = 255 - (255 - img_p.0[2]) / 2;
-                        img_p.0[1] /= 2;
-                        img_p.0[0] /= 2;
-                        imageproc::drawing::Canvas::draw_pixel(&mut img, x, y, img_p)
-                    }
-                }
-            }
-            match point {
-                Some((x, y)) => {
-                    let (x, y) = (
-                        (x * img.width() as f64) as i32,
-                        (y * img.height() as f64) as i32,
-                    );
-                    imageproc::drawing::draw_filled_circle(
-                        &img,
-                        (x, y),
-                        3,
-                        image::Rgba([255, 0, 0, 200]),
-                    )
-                    .save("sam_merged.jpg")?
-                }
-                None => img.save("sam_merged.jpg")?,
-            };
-        }
+        let mut img = image::io::Reader::open(&args.image)?
+            .decode()
+            .map_err(candle::Error::wrap)?;
+        let mask_pixels = mask.permute((1, 2, 0))?.flatten_all()?.to_vec1::<u8>()?;
+        let mask_img: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
+            match image::ImageBuffer::from_raw(w as u32, h as u32, mask_pixels) {
+                Some(image) => image,
+                None => anyhow::bail!("error saving merged image"),
+            };
+        let mask_img = image::DynamicImage::from(mask_img).resize_to_fill(
+            img.width(),
+            img.height(),
+            image::imageops::FilterType::CatmullRom,
+        );
+        for x in 0..img.width() {
+            for y in 0..img.height() {
+                let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_img, x, y);
+                if mask_p.0[0] > 100 {
+                    let mut img_p = imageproc::drawing::Canvas::get_pixel(&img, x, y);
+                    img_p.0[2] = 255 - (255 - img_p.0[2]) / 2;
+                    img_p.0[1] /= 2;
+                    img_p.0[0] /= 2;
+                    imageproc::drawing::Canvas::draw_pixel(&mut img, x, y, img_p)
+                }
+            }
+        }
+        for (x, y, b) in points {
+            let x = (x * img.width() as f64) as i32;
+            let y = (y * img.height() as f64) as i32;
+            let color = if b {
+                image::Rgba([255, 0, 0, 200])
+            } else {
+                image::Rgba([0, 255, 0, 200])
+            };
+            imageproc::drawing::draw_filled_circle_mut(&mut img, (x, y), 3, color);
+        }
+        img.save("sam_merged.jpg")?
     }
     Ok(())
 }
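For reference, the `LayerNorm2d` helper removed above (it now lives in `candle-transformers`) normalizes each spatial position over the channel dimension, matching its `mean_keepdim(1)` calls:

```latex
\mu_{hw} = \frac{1}{C}\sum_{c=1}^{C} x_{chw},\qquad
\sigma^2_{hw} = \frac{1}{C}\sum_{c=1}^{C}\left(x_{chw} - \mu_{hw}\right)^2,\qquad
y_{chw} = \frac{x_{chw} - \mu_{hw}}{\sqrt{\sigma^2_{hw} + \epsilon}}\, w_c + b_c
```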

candle-examples/examples/stable-diffusion/README.md (new file, 63 lines)
@@ -0,0 +1,63 @@
# candle-stable-diffusion: A Diffusers API in Rust/Candle

![rusty robot holding a candle]()

_A rusty robot holding a fire torch in its hand_, generated by Stable Diffusion
XL using Rust and [candle](https://github.com/huggingface/candle).

The `stable-diffusion` example is a conversion of
[diffusers-rs](https://github.com/LaurentMazare/diffusers-rs) using candle
rather than libtorch. This implementation supports Stable Diffusion v1.5, v2.1,
as well as Stable Diffusion XL 1.0.

## Getting the weights

The weights are automatically downloaded for you from the [HuggingFace
Hub](https://huggingface.co/) on the first run. There are various command line
flags to use local files instead, run with `--help` to learn about them.

## Running some example.

```bash
cargo run --example stable-diffusion --release --features=cuda,cudnn \
  -- --prompt "a cosmonaut on a horse (hd, realistic, high-def)"
```

The final image is named `sd_final.png` by default.
The default scheduler is the Denoising Diffusion Implicit Model scheduler (DDIM). The
original paper and some code can be found in the [associated repo](https://github.com/ermongroup/ddim).

### Command-line flags

- `--prompt`: the prompt to be used to generate the image.
- `--uncond-prompt`: the optional unconditional prompt.
- `--sd-version`: the Stable Diffusion version to use, can be `v1-5`, `v2-1`, or
  `xl`.
- `--cpu`: use the cpu rather than the gpu (much slower).
- `--height`, `--width`: set the height and width for the generated image.
- `--n-steps`: the number of steps to be used in the diffusion process.
- `--num-samples`: the number of samples to generate.
- `--final-image`: the filename for the generated image(s).

### Using flash-attention

Using flash attention makes image generation a lot faster and uses less memory.
The downside is some long compilation time. You can set the
`CANDLE_FLASH_ATTN_BUILD_DIR` environment variable to something like
`/home/user/.candle` to ensure that the compilation artifacts are properly
cached.

Enabling flash-attention requires both a feature flag, `--features flash-attn`,
and using the command line flag `--use-flash-attn`.

## Image to Image Pipeline
...

## FAQ

### Memory Issues

This requires a GPU with more than 8GB of memory; as a fallback, the CPU version
can be used with the `--cpu` flag but is much slower.
Alternatively, reducing the height and width with the `--height` and `--width`
flag is likely to reduce memory usage significantly.
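As background on the DDIM scheduler mentioned above (a reminder of the DDIM paper's deterministic update, not math taken from this repository): with noise prediction $\epsilon_\theta$ and cumulative signal rate $\bar\alpha_t$, the $\eta = 0$ step is

```latex
\hat{x}_0 = \frac{x_t - \sqrt{1 - \bar\alpha_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar\alpha_t}},
\qquad
x_{t-1} = \sqrt{\bar\alpha_{t-1}}\,\hat{x}_0 + \sqrt{1 - \bar\alpha_{t-1}}\,\epsilon_\theta(x_t, t)
```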
(binary image asset added, 36 KiB, not shown)
candle-examples/examples/stable-diffusion/main.rs (modified)
@@ -4,20 +4,10 @@ extern crate accelerate_src;
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
 
-mod attention;
-mod clip;
-mod ddim;
-mod embeddings;
-mod resnet;
-mod schedulers;
-mod stable_diffusion;
-mod unet_2d;
-mod unet_2d_blocks;
-mod utils;
-mod vae;
+use candle_transformers::models::stable_diffusion;
 
 use anyhow::{Error as E, Result};
-use candle::{DType, Device, IndexOp, Tensor, D};
+use candle::{DType, Device, IndexOp, Module, Tensor, D};
 use clap::Parser;
 use tokenizers::Tokenizer;
 
@@ -107,14 +97,13 @@ struct Args {
     img2img_strength: f64,
 }
 
-#[derive(Debug, Clone, Copy, clap::ValueEnum)]
+#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)]
 enum StableDiffusionVersion {
     V1_5,
     V2_1,
     Xl,
 }
 
-#[allow(unused)]
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 enum ModelFile {
     Tokenizer,
@@ -214,7 +203,18 @@ impl ModelFile {
             Self::Clip => (version.repo(), version.clip_file(use_f16)),
             Self::Clip2 => (version.repo(), version.clip2_file(use_f16)),
             Self::Unet => (version.repo(), version.unet_file(use_f16)),
-            Self::Vae => (version.repo(), version.vae_file(use_f16)),
+            Self::Vae => {
+                // Override for SDXL when using f16 weights.
+                // See https://github.com/huggingface/candle/issues/1060
+                if version == StableDiffusionVersion::Xl && use_f16 {
+                    (
+                        "madebyollin/sdxl-vae-fp16-fix",
+                        "diffusion_pytorch_model.safetensors",
+                    )
+                } else {
+                    (version.repo(), version.vae_file(use_f16))
+                }
+            }
         };
         let filename = Api::new()?.model(repo.to_string()).get(path)?;
         Ok(filename)
@@ -494,9 +494,8 @@ fn run(args: Args) -> Result<()> {
             num_samples
         );
         let image = vae.decode(&(&latents / 0.18215)?)?;
-        // TODO: Add the clamping between 0 and 1.
         let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
-        let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
+        let image = (image.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?.i(0)?;
         let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
         candle_examples::save_image(&image, image_filename)?
     }

candle-examples/examples/stable-diffusion/utils.rs (deleted file, 39 lines)
@@ -1,39 +0,0 @@
use candle::{Device, Result, Tensor};
use candle_nn::Module;

pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
    if steps < 1 {
        candle::bail!("cannot use linspace with steps {steps} <= 1")
    }
    let delta = (stop - start) / (steps - 1) as f64;
    let vs = (0..steps)
        .map(|step| start + step as f64 * delta)
        .collect::<Vec<_>>();
    Tensor::from_vec(vs, steps, &Device::Cpu)
}

// Wrap the conv2d op to provide some tracing.
#[derive(Debug)]
pub struct Conv2d {
    inner: candle_nn::Conv2d,
    span: tracing::Span,
}

impl Conv2d {
    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        self.inner.forward(x)
    }
}

pub fn conv2d(
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    cfg: candle_nn::Conv2dConfig,
    vs: candle_nn::VarBuilder,
) -> Result<Conv2d> {
    let span = tracing::span!(tracing::Level::TRACE, "conv2d");
    let inner = candle_nn::conv2d(in_channels, out_channels, kernel_size, cfg, vs)?;
    Ok(Conv2d { inner, span })
}
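The removed `linspace` helper produced evenly spaced values on the CPU; for `steps > 1` it returns

```latex
v_i = \mathrm{start} + i \cdot \frac{\mathrm{stop} - \mathrm{start}}{\mathrm{steps} - 1},
\qquad i = 0, 1, \dots, \mathrm{steps} - 1
```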
25
candle-examples/examples/stable-lm/README.md
Normal file
25
candle-examples/examples/stable-lm/README.md
Normal file
@ -0,0 +1,25 @@
# candle-stable-lm

StableLM-3B-4E1T is a 3 billion parameter decoder-only language model
pre-trained on 1 trillion tokens of diverse English and code datasets for 4
epochs. See the [HuggingFace Hub Model
Card](https://huggingface.co/stabilityai/stablelm-3b-4e1t).

Note that this model is gated so you will have to request access on the Hub in
order to be able to use it (a minimal authentication sketch follows the
transcript below).

## Running an example

```bash
$ cargo run --example stable-lm --release --features cuda -- --prompt 'What is the most efficient programming language in use?' --sample-len 150
avx: true, neon: false, simd128: false, f16c: true
temp: 0.00 repeat-penalty: 1.10 repeat-last-n: 64
retrieved the files in 126.593µs
loaded the model in 3.474148965s
What is the most efficient programming language in use?
The answer to this question depends on what you mean by "efficient". If you're talking about speed, then C++ and Java are probably your best bets. But if you're talking about ease of development, then Python is probably the way to go.
Python is a high-level, interpreted language that is easy to learn and use. It has a large community of developers who are always working on new features and improvements.
C++ is a low-level, compiled language that can be used for both desktop applications and web development. It's more difficult to learn than Python but offers greater control over the code.
Java is another high-level language that is popular with programmers because it runs on many different platforms (including Android phones
150 tokens generated (37.61 token/s)
```
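Since the weights are gated, the example can only fetch them once a Hub token with access is stored locally. A minimal setup sketch (assuming the `huggingface-cli` tool is installed and that the `hf-hub` crate picks up the cached token):

```bash
# Request access on the model page first, then authenticate; the token is
# saved in the local HF cache where hf-hub is assumed to find it.
huggingface-cli login
```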
candle-examples/examples/stable-lm/main.rs (new file, 268 lines)
@@ -0,0 +1,268 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::Parser;

use candle_transformers::models::quantized_stable_lm::Model as QStableLM;
use candle_transformers::models::stable_lm::{Config, Model as StableLM};

use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

enum Model {
    StableLM(StableLM),
    Quantized(QStableLM),
}

struct TextGeneration {
    model: Model,
    device: Device,
    tokenizer: TokenOutputStream,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            tokenizer: TokenOutputStream::new(tokenizer),
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            device: device.clone(),
        }
    }

    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        self.tokenizer.clear();
        let mut tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        for &t in tokens.iter() {
            if let Some(t) = self.tokenizer.next_token(t)? {
                print!("{t}")
            }
        }
        std::io::stdout().flush()?;

        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
            Some(token) => token,
            None => anyhow::bail!("cannot find the <|endoftext|> token"),
        };
        let start_gen = std::time::Instant::now();
        for index in 0..sample_len {
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let start_pos = tokens.len().saturating_sub(context_size);
            let ctxt = &tokens[start_pos..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = match &mut self.model {
                Model::StableLM(m) => m.forward(&input, start_pos)?,
                Model::Quantized(m) => m.forward(&input, start_pos)?,
            };
            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token {
                break;
            }
            if let Some(t) = self.tokenizer.next_token(next_token)? {
                print!("{t}");
                std::io::stdout().flush()?;
            }
        }
        let dt = start_gen.elapsed();
        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
            print!("{rest}");
        }
        std::io::stdout().flush()?;
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        Ok(())
    }
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    use_flash_attn: bool,

    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 100)]
    sample_len: usize,

    #[arg(long, default_value = "lmz/candle-stablelm-3b-4e1t")]
    model_id: String,

    #[arg(long, default_value = "main")]
    revision: String,

    #[arg(long)]
    tokenizer_file: Option<String>,

    #[arg(long)]
    weight_files: Option<String>,

    #[arg(long)]
    quantized: bool,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let repo = api.repo(Repo::with_revision(
        args.model_id,
        RepoType::Model,
        args.revision,
    ));
    let tokenizer_filename = match args.tokenizer_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("tokenizer.json")?,
    };
    let filenames = match args.weight_files {
        Some(files) => files
            .split(',')
            .map(std::path::PathBuf::from)
            .collect::<Vec<_>>(),
        None => {
            if args.quantized {
                vec![repo.get("model-q4k.gguf")?]
            } else {
                vec![repo.get("model.safetensors")?]
            }
        }
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let config = Config::stablelm_3b_4e1t(args.use_flash_attn);
    let (model, device) = if args.quantized {
        let filename = &filenames[0];
        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
        let model = QStableLM::new(&config, vb)?;
        (Model::Quantized(model), Device::Cpu)
    } else {
        let device = candle_examples::device(args.cpu)?;
        let dtype = if device.is_cuda() {
            DType::BF16
        } else {
            DType::F32
        };
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
        let model = StableLM::new(&config, vb)?;
        (Model::StableLM(model), device)
    };

    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        &device,
    );
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}
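The `apply_repeat_penalty` branch in `run` above implements the usual repetition penalty: tokens seen in the recent context get their logits pushed toward "less likely" before sampling. A self-contained sketch of the idea on a plain slice (an illustration, not candle's exact implementation):

```rust
// Dampen the logits of recently seen tokens so the sampler is less likely to
// loop: positive logits are divided by the penalty, negative ones multiplied,
// moving both toward lower probability.
fn repeat_penalty(logits: &mut [f32], penalty: f32, context: &[u32]) {
    for &tok in context {
        if let Some(l) = logits.get_mut(tok as usize) {
            *l = if *l >= 0.0 { *l / penalty } else { *l * penalty };
        }
    }
}

fn main() {
    let mut logits = vec![2.0, -1.0, 0.5];
    repeat_penalty(&mut logits, 1.1, &[0, 1]);
    assert!((logits[0] - 2.0 / 1.1).abs() < 1e-6);
    assert!((logits[1] + 1.1).abs() < 1e-6);
    assert!((logits[2] - 0.5).abs() < 1e-6); // token 2 not in context, untouched
}
```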
candle-examples/examples/t5/README.md (new file, 25 lines)
@@ -0,0 +1,25 @@
# candle-t5

## Encoder-decoder example:

```bash
$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "translate to German: A beautiful candle." --decode
...
Running on CPU, to run on GPU, build this example with `--features cuda`
Eine schöne Kerze.
9 tokens generated (2.42 token/s)
```

## Sentence embedding example:

```bash
$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "A beautiful candle."
...
[[[ 0.0515, -0.0541, -0.0761, ..., -0.0392,  0.1511, -0.0265],
  [-0.0974,  0.0998, -0.1659, ..., -0.2450,  0.1738, -0.0164],
  [ 0.0624, -0.1024,  0.0430, ..., -0.1388,  0.0564, -0.2962],
  [-0.0389, -0.1173,  0.0026, ...,  0.1064, -0.1065,  0.0990],
  [ 0.1300,  0.0027, -0.0326, ...,  0.0026, -0.0317,  0.0851]]]
Tensor[[1, 5, 512], f32]
Took 303.766583ms
```
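In `main.rs` below, a non-positive temperature disables sampling (the `LogitsProcessor` receives `None` and picks the arg-max token), so a deterministic variant of the translation example is:

```bash
$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "translate to German: A beautiful candle." --decode --temperature 0
```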
candle-examples/examples/t5/main.rs (new file, 299 lines)
@@ -0,0 +1,299 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use std::io::Write;
use std::path::PathBuf;

use candle_transformers::models::t5;

use anyhow::{Error as E, Result};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

const DTYPE: DType = DType::F32;

#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// The model repository to use on the HuggingFace hub.
    #[arg(long)]
    model_id: Option<String>,

    #[arg(long)]
    revision: Option<String>,

    /// Enable decoding.
    #[arg(long)]
    decode: bool,

    // Disable the key-value cache.
    #[arg(long, default_value = "false")]
    disable_cache: bool,

    /// Use this prompt, otherwise compute sentence similarities.
    #[arg(long)]
    prompt: Option<String>,

    /// If set along with --decode, will use this prompt to initialize the decoder.
    #[arg(long)]
    decoder_prompt: Option<String>,

    /// L2 normalization for embeddings.
    #[arg(long, default_value = "true")]
    normalize_embeddings: bool,

    /// The temperature used to generate samples.
    #[arg(long, default_value_t = 0.8)]
    temperature: f64,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

struct T5ModelBuilder {
    device: Device,
    config: t5::Config,
    weights_filename: Vec<PathBuf>,
}

impl T5ModelBuilder {
    pub fn load(args: &Args) -> Result<(Self, Tokenizer)> {
        let device = candle_examples::device(args.cpu)?;
        let default_model = "t5-small".to_string();
        let default_revision = "refs/pr/15".to_string();
        let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
            (Some(model_id), Some(revision)) => (model_id, revision),
            (Some(model_id), None) => (model_id, "main".to_string()),
            (None, Some(revision)) => (default_model, revision),
            (None, None) => (default_model, default_revision),
        };

        let repo = Repo::with_revision(model_id.clone(), RepoType::Model, revision);
        let api = Api::new()?;
        let api = api.repo(repo);
        let config_filename = api.get("config.json")?;
        let tokenizer_filename = api.get("tokenizer.json")?;
        let weights_filename = if model_id == "google/flan-t5-xxl" {
            vec![
                api.get("model-00001-of-00005.safetensors")?,
                api.get("model-00002-of-00005.safetensors")?,
                api.get("model-00003-of-00005.safetensors")?,
                api.get("model-00004-of-00005.safetensors")?,
                api.get("model-00005-of-00005.safetensors")?,
            ]
        } else {
            vec![api.get("model.safetensors")?]
        };
        let config = std::fs::read_to_string(config_filename)?;
        let mut config: t5::Config = serde_json::from_str(&config)?;
        config.use_cache = !args.disable_cache;
        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
        Ok((
            Self {
                device,
                config,
                weights_filename,
            },
            tokenizer,
        ))
    }

    pub fn build_encoder(&self) -> Result<t5::T5EncoderModel> {
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&self.weights_filename, DTYPE, &self.device)?
        };
        Ok(t5::T5EncoderModel::load(vb, &self.config)?)
    }

    pub fn build_conditional_generation(&self) -> Result<t5::T5ForConditionalGeneration> {
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&self.weights_filename, DTYPE, &self.device)?
        };
        Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
    }
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();

    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?;
    let device = &builder.device;
    let tokenizer = tokenizer
        .with_padding(None)
        .with_truncation(None)
        .map_err(E::msg)?;
    match args.prompt {
        Some(prompt) => {
            let tokens = tokenizer
                .encode(prompt, true)
                .map_err(E::msg)?
                .get_ids()
                .to_vec();
            let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
            if !args.decode {
                let mut model = builder.build_encoder()?;
                let start = std::time::Instant::now();
                let ys = model.forward(&input_token_ids)?;
                println!("{ys}");
                println!("Took {:?}", start.elapsed());
            } else {
                let mut model = builder.build_conditional_generation()?;
                let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
                if let Some(decoder_prompt) = &args.decoder_prompt {
                    print!("{decoder_prompt}");
                    output_token_ids.extend(
                        tokenizer
                            .encode(decoder_prompt.to_string(), false)
                            .map_err(E::msg)?
                            .get_ids()
                            .to_vec(),
                    );
                }
                let temperature = if args.temperature <= 0. {
                    None
                } else {
                    Some(args.temperature)
                };
                let mut logits_processor = LogitsProcessor::new(299792458, temperature, args.top_p);
                let encoder_output = model.encode(&input_token_ids)?;
                let start = std::time::Instant::now();

                for index in 0.. {
                    if output_token_ids.len() > 512 {
                        break;
                    }
                    let decoder_token_ids = if index == 0 || !builder.config.use_cache {
                        Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)?
                    } else {
                        let last_token = *output_token_ids.last().unwrap();
                        Tensor::new(&[last_token], device)?.unsqueeze(0)?
                    };
                    let logits = model
                        .decode(&decoder_token_ids, &encoder_output)?
                        .squeeze(0)?;
                    let logits = if args.repeat_penalty == 1. {
                        logits
                    } else {
                        let start_at = output_token_ids.len().saturating_sub(args.repeat_last_n);
                        candle_transformers::utils::apply_repeat_penalty(
                            &logits,
                            args.repeat_penalty,
                            &output_token_ids[start_at..],
                        )?
                    };

                    let next_token_id = logits_processor.sample(&logits)?;
                    if next_token_id as usize == builder.config.eos_token_id {
                        break;
                    }
                    output_token_ids.push(next_token_id);
                    if let Some(text) = tokenizer.id_to_token(next_token_id) {
                        let text = text.replace('▁', " ").replace("<0x0A>", "\n");
                        print!("{text}");
                        std::io::stdout().flush()?;
                    }
                }
                let dt = start.elapsed();
                println!(
                    "\n{} tokens generated ({:.2} token/s)\n",
                    output_token_ids.len(),
                    output_token_ids.len() as f64 / dt.as_secs_f64(),
                );
            }
        }
        None => {
            let mut model = builder.build_encoder()?;
            let sentences = [
                "The cat sits outside",
                "A man is playing guitar",
                "I love pasta",
                "The new movie is awesome",
                "The cat plays in the garden",
                "A woman watches TV",
                "The new movie is so great",
                "Do you like pizza?",
            ];
            let n_sentences = sentences.len();
            let mut all_embeddings = Vec::with_capacity(n_sentences);
            for sentence in sentences {
                let tokens = tokenizer
                    .encode(sentence, true)
                    .map_err(E::msg)?
                    .get_ids()
                    .to_vec();
                let token_ids = Tensor::new(&tokens[..], model.device())?.unsqueeze(0)?;
                let embeddings = model.forward(&token_ids)?;
                println!("generated embeddings {:?}", embeddings.shape());
                // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
                let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
                let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
                let embeddings = if args.normalize_embeddings {
                    normalize_l2(&embeddings)?
                } else {
                    embeddings
                };
                println!("pooled embeddings {:?}", embeddings.shape());
                all_embeddings.push(embeddings)
            }

            let mut similarities = vec![];
            for (i, e_i) in all_embeddings.iter().enumerate() {
                for (j, e_j) in all_embeddings
                    .iter()
                    .enumerate()
                    .take(n_sentences)
                    .skip(i + 1)
                {
                    let sum_ij = (e_i * e_j)?.sum_all()?.to_scalar::<f32>()?;
                    let sum_i2 = (e_i * e_i)?.sum_all()?.to_scalar::<f32>()?;
                    let sum_j2 = (e_j * e_j)?.sum_all()?.to_scalar::<f32>()?;
                    let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
                    similarities.push((cosine_similarity, i, j))
                }
            }
            similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
            for &(score, i, j) in similarities[..5].iter() {
                println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
            }
        }
    }
    Ok(())
}

pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
    Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
}
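The similarity loop above computes plain cosine similarity over the mean-pooled sentence embeddings:

$$ \cos(e_i, e_j) = \frac{\sum_k e_{i,k}\, e_{j,k}}{\sqrt{\sum_k e_{i,k}^2}\, \sqrt{\sum_k e_{j,k}^2}} $$

which is exactly `sum_ij / (sum_i2 * sum_j2).sqrt()`. With `--normalize-embeddings` (the default), `normalize_l2` has already scaled each pooled vector to unit norm, so both square-root factors equal 1 and the dot product alone would produce the same ranking.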
candle-examples/examples/whisper/README.md (new file, 39 lines)
@@ -0,0 +1,39 @@
# candle-whisper: speech recognition

An implementation of [OpenAI Whisper](https://github.com/openai/whisper) using
candle. Whisper is a general purpose speech recognition model: it can be used to
convert audio files (in the `.wav` format) to text. Supported features include
language detection as well as multilingual speech recognition.

## Running an example

If no audio file is passed as input, a [sample
file](https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_jfk.wav) is automatically downloaded
from the hub.

```bash
cargo run --example whisper --release

> No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav
> loaded wav data: Header { audio_format: 1, channel_count: 1, sampling_rate: 16000, bytes_per_second: 32000, bytes_per_sample: 2, bits_per_sample: 16 }
> pcm data loaded 176000
> loaded mel: [1, 80, 3000]
> 0.0s -- 30.0s: And so my fellow Americans ask not what your country can do for you ask what you can do for your country
```

In order to use the multilingual mode, specify a multilingual model via the
`--model` flag; see the flag details and the sample invocation below.

## Command line flags

- `--input`: the audio file to be converted to text, in wav format.
- `--language`: force the language to some specific value rather than being
  detected, e.g. `en`.
- `--task`: the task to be performed, can be `transcribe` (return the text data
  in the original language) or `translate` (translate the text to English).
- `--timestamps`: enable the timestamp mode where some timestamps are reported
  for each recognized audio extract.
- `--model`: the model to be used. Models that do not end with `-en` are
  multilingual models, the other ones are English-only models. The supported models
  are `tiny`, `tiny.en`, `base`, `base.en`, `small`, `small.en`, `medium`,
  `medium.en`, `large`, and `large-v2`.
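For instance, a multilingual transcription run that forces French could combine the flags above as follows (a hypothetical invocation: `audio.wav` stands in for a real input file):

```bash
cargo run --example whisper --release -- --model medium --input audio.wav --language fr --task transcribe
```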
Some files were not shown because too many files have changed in this diff.